avgpool3d_bwd_fp16.cpp 2.22 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp"

#include "avgpool3d_bwd_common.hpp"

using DOutDataType    = ck::half_t;
using DInDataType     = ck::half_t;
using ComputeDataType = float;

#if 1
using DOutLayout = ck::tensor_layout::convolution::NDHWC;
using DInLayout  = ck::tensor_layout::convolution::NDHWC;
#else
using DOutLayout = ck::tensor_layout::convolution::NCDHW;
using DInLayout  = ck::tensor_layout::convolution::NCDHW;
#endif

using DevicePoolBwdInstance =
    ck::tensor_operation::device::DeviceAvgPool3dBwd_NDHWC_NDHWC<DOutDataType,
                                                                 DInDataType,
                                                                 ComputeDataType,
                                                                 64, // BlockSize
                                                                 64, // ReduceMThreadClusterSize
                                                                 1,  // ReduceKThreadClusterSize
                                                                 1,  // ReduceMThreadSliceSize
                                                                 1,  // ReduceKThreadSliceSize
                                                                 1>; // InSrcOutDstVectorSize

int main()
{
    std::vector<ck::index_t> window_lengths    = {5, 5, 5};
    std::vector<ck::index_t> window_strides    = {2, 2, 2};
    std::vector<ck::index_t> window_dilations  = {2, 2, 2};
    std::vector<ck::index_t> dinput_left_pads  = {0, 0, 0};
    std::vector<ck::index_t> dinput_right_pads = {0, 0, 0};

    ck::index_t N  = 1;
    ck::index_t C  = 16;
    ck::index_t Di = 40;
    ck::index_t Hi = 40;
    ck::index_t Wi = 40;

    pool3d_bwd_test<DevicePoolBwdInstance, DOutDataType, DInDataType, DOutLayout, DInLayout>(
        true,
        false,
        N,
        C,
        Di,
        Hi,
        Wi,
        window_lengths,
        window_strides,
        window_dilations,
        dinput_left_pads,
        dinput_right_pads);
}