pool2d_fwd_fp16.cpp 3.64 KB
Newer Older
Qianfeng's avatar
Qianfeng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#include <cstdio>
#include <cstdlib>
#include <exception>
#include <iostream>
#include <string>

#include "config.hpp"
#include "tensor_layout.hpp"
#include "reduction_enums.hpp"

#include "pool2d_fwd_common.hpp"

// Element types for this example: half-precision input/output tensors,
// single-precision accumulator, 32-bit indices.
using InDataType    = ck::half_t;
using OutDataType   = ck::half_t;
using AccDataType   = float;
using IndexDataType = int32_t;

// Input and output tensors share the NHWC layout.
using InLayout  = ck::tensor_layout::convolution::NHWC;
using OutLayout = ck::tensor_layout::convolution::NHWC;

// Reduction applied over each pooling window. Set UseMaxPool to false to build
// the average-pooling variant instead of max pooling.
static constexpr bool UseMaxPool = true;
static constexpr auto ReduceOpId = UseMaxPool ? ck::ReduceTensorOp::MAX : ck::ReduceTensorOp::AVG;

// Compile-time switches forwarded to pool_test as template arguments.
static constexpr bool PropagateNan = false;
static constexpr bool OutputIndex  = false;

// Prints the command-line help for this example to stdout.
static void print_usage()
{
    std::printf("arg1: verification (0=no, 1=yes)\n");
    std::printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
    std::printf("arg3: time kernel (0=no, 1=yes)\n");
    std::printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, "
                "RightPx\n");
}

// Entry point of the fp16 2-D pooling forward example.
//
// Accepted argument counts:
//   no arguments   - default shape, verification on, init_method 1, kernel timing on
//   3 arguments    - verification, init method, time kernel (default shape)
//   15 arguments   - the three above followed by N, C, Y, X, Hi, Wi, Sy, Sx,
//                    LeftPy, LeftPx, RightPy, RightPx
// Any other argument count prints the usage text and exits with status 0.
//
// Returns 0 when pool_test reports success, and 1 when it fails or when the
// arguments cannot be parsed or describe an invalid pool shape.
int main(int argc, char* argv[])
{
    // NOTE(review): no host_reduce member is named directly below; kept in case
    // pool_test (or something it relies on) is found through this directive.
    using namespace ck::host_reduce;

    if(argc != 1 && argc != 4 && argc != 16)
    {
        print_usage();
        return 0;
    }

    // Run configuration; these defaults are used when no arguments are given.
    bool do_verification = true;
    int init_method      = 1;
    bool time_kernel     = true;

    // Pool shape; these defaults are used unless all 15 arguments are given.
    ck::index_t N               = 128;
    ck::index_t C               = 192;
    ck::index_t Y               = 3;
    ck::index_t X               = 3;
    ck::index_t Hi              = 71;
    ck::index_t Wi              = 71;
    ck::index_t window_stride_h = 2;
    ck::index_t window_stride_w = 2;
    ck::index_t in_left_pad_h   = 1;
    ck::index_t in_left_pad_w   = 1;
    ck::index_t in_right_pad_h  = 1;
    ck::index_t in_right_pad_w  = 1;

    // std::stoi throws std::invalid_argument / std::out_of_range on malformed
    // input; report that instead of letting the exception escape main.
    try
    {
        if(argc >= 4)
        {
            do_verification = static_cast<bool>(std::stoi(argv[1]));
            init_method     = std::stoi(argv[2]);
            time_kernel     = static_cast<bool>(std::stoi(argv[3]));
        }

        if(argc == 16)
        {
            N               = std::stoi(argv[4]);
            C               = std::stoi(argv[5]);
            Y               = std::stoi(argv[6]);
            X               = std::stoi(argv[7]);
            Hi              = std::stoi(argv[8]);
            Wi              = std::stoi(argv[9]);
            window_stride_h = std::stoi(argv[10]);
            window_stride_w = std::stoi(argv[11]);
            in_left_pad_h   = std::stoi(argv[12]);
            in_left_pad_w   = std::stoi(argv[13]);
            in_right_pad_h  = std::stoi(argv[14]);
            in_right_pad_w  = std::stoi(argv[15]);
        }
    }
    catch(const std::exception& e)
    {
        std::fprintf(stderr, "Failed to parse arguments: %s\n", e.what());
        print_usage();
        return 1;
    }

    // Reject shapes that cannot describe a valid pooling. Presumably pool_test
    // uses the usual (Hi + pads - Y) / stride + 1 output-size formula (TODO
    // confirm), which divides by zero for a zero stride and gives an empty or
    // bogus extent when the window exceeds the padded input.
    // The padded extents are summed in long long to avoid int overflow.
    const bool sizes_ok   = N > 0 && C > 0 && Y > 0 && X > 0 && Hi > 0 && Wi > 0;
    const bool strides_ok = window_stride_h > 0 && window_stride_w > 0;
    const bool pads_ok    = in_left_pad_h >= 0 && in_left_pad_w >= 0 && in_right_pad_h >= 0 &&
                         in_right_pad_w >= 0;
    const bool window_fits =
        static_cast<long long>(Hi) + in_left_pad_h + in_right_pad_h >= Y &&
        static_cast<long long>(Wi) + in_left_pad_w + in_right_pad_w >= X;

    if(!(sizes_ok && strides_ok && pads_ok && window_fits))
    {
        std::fprintf(stderr,
                     "Invalid pool shape: sizes and strides must be positive, pads must be "
                     "non-negative, and the window must fit inside the padded input\n");
        return 1;
    }

    bool pass = pool_test<InDataType,
                          OutDataType,
                          AccDataType,
                          IndexDataType,
                          InLayout,
                          OutLayout,
                          ReduceOpId,
                          PropagateNan,
                          OutputIndex>(do_verification,
                                       init_method,
                                       time_kernel,
                                       N,
                                       C,
                                       Y,
                                       X,
                                       Hi,
                                       Wi,
                                       window_stride_h,
                                       window_stride_w,
                                       in_left_pad_h,
                                       in_left_pad_w,
                                       in_right_pad_h,
                                       in_right_pad_w);

    return pass ? 0 : 1;
}