// conv2d_fwd.cpp
#include <functional>
#include <tuple>
#include <vector>

#include <gtest/gtest.h>

#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/conv_util.hpp"
#include "test/convnd_fwd/conv_util.hpp"

namespace {

12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
class Conv2dFwdNHWCInstances : public ::testing::Test
{
    public:
    template <typename T>
    bool test_conv2d_nhwc_instances(const std::vector<test::conv::DeviceConvFwdNoOpPtr>& conv_ptrs,
                                    const ck::utils::conv::ConvParams& params)
    {
        using namespace std::placeholders;
        using namespace ck::utils;

        conv::ConvFwdOpInstance<T,
                                T,
                                T,
                                ck::tensor_layout::convolution::NHWC,
                                ck::tensor_layout::convolution::KYXC,
                                ck::tensor_layout::convolution::NHWK,
                                ck::tensor_operation::element_wise::PassThrough,
                                ck::tensor_operation::element_wise::PassThrough,
                                ck::tensor_operation::element_wise::PassThrough,
                                FillUniformDistributionIntegerValue<T>,
                                FillUniformDistributionIntegerValue<T>>
            conv_instance(params,
                          true,
                          FillUniformDistributionIntegerValue<T>{},
                          FillUniformDistributionIntegerValue<T>{});
        auto reference_conv_fwd_fun =
            std::bind(conv::run_reference_convolution_forward<2, T, T, T>, params, _1, _2, _3);
        OpInstanceRunEngine<T, T, T> run_engine(conv_instance, reference_conv_fwd_fun);
        run_engine.SetAtol(atol_);
        run_engine.SetRtol(rtol_);
        return run_engine.Test(conv_ptrs);
    }

    template <typename T>
    bool test_default(bool use_convnd = false)
    {
        if(use_convnd)
        {
            return test_conv2d_nhwc_instances<T>(
                test::conv::ConvolutionNDFwdInstances<T, T, T>::Get(2), params_default_);
        }
        else
        {
            return test_conv2d_nhwc_instances<T>(
                ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<2>(),
                params_default_);
        }
    }

    template <typename T>
    bool test_filter1x1_stride1_pad0(bool use_convnd = false)
    {
        if(use_convnd)
        {
            return test_conv2d_nhwc_instances<T>(
                test::conv::ConvolutionNDFwdInstances<T, T, T>::Get(2),
                params_filter1x1_stride1_pad0_);
        }
        else
        {
            return test_conv2d_nhwc_instances<T>(
                ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<2>(),
                params_filter1x1_stride1_pad0_);
        }
    }

    template <typename T>
    bool test_filter1x1_pad0(bool use_convnd = false)
    {
        if(use_convnd)
        {
            return test_conv2d_nhwc_instances<T>(
                test::conv::ConvolutionNDFwdInstances<T, T, T>::Get(2), params_filter1x1_pad0_);
        }
        else
        {
            return test_conv2d_nhwc_instances<T>(
                ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<2>(),
                params_filter1x1_pad0_);
        }
    }

    template <typename T>
    bool test_oddC()
    {
        return test_conv2d_nhwc_instances<T>(
            ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<2>(), params_oddC_);
    }

    static inline ck::utils::conv::ConvParams params_default_{
        2, 4, 256, 64, {3, 3}, {36, 36}, {2, 2}, {2, 2}, {2, 2}, {2, 2}};
    static inline ck::utils::conv::ConvParams params_filter1x1_stride1_pad0_{
        2, 4, 256, 64, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
    static inline ck::utils::conv::ConvParams params_filter1x1_pad0_{
        2, 4, 256, 64, {1, 1}, {28, 28}, {2, 2}, {1, 1}, {0, 0}, {0, 0}};
    static inline ck::utils::conv::ConvParams params_oddC_{
        2, 4, 256, 3, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};

    private:
    double atol_{1e-5};
    double rtol_{1e-4};
};

} // anonymous namespace

// float 2-D forward convolution over the test instance list, with
// integer-valued random inputs, checked against the host reference.
TEST(Conv2DFwdNHWC, IntegerValues)
{
    using namespace std::placeholders;
    using namespace ck::utils;
    using T = float;

    // Presumed field order: num_dim, N, K, C, filter, input, strides, dilations,
    // left pads, right pads (TODO confirm against ConvParams).
    ck::utils::conv::ConvParams params{
        2, 4, 256, 64, {3, 3}, {36, 36}, {1, 1}, {2, 2}, {2, 2}, {2, 2}};

    std::vector<test::conv::DeviceConvFwdNoOpPtr> conv_ptrs;
    test::conv::get_test_convolution_fwd_instance<2, T, T, T, T>(conv_ptrs);
    conv::ConvFwdOpInstance<T,
                            T,
                            T,
                            ck::tensor_layout::convolution::NHWC,
                            ck::tensor_layout::convolution::KYXC,
                            ck::tensor_layout::convolution::NHWK,
                            ck::tensor_operation::element_wise::PassThrough,
                            ck::tensor_operation::element_wise::PassThrough,
                            ck::tensor_operation::element_wise::PassThrough,
                            FillUniformDistributionIntegerValue<T>,
                            FillUniformDistributionIntegerValue<T>>
        conv_instance(params,
                      true,
                      FillUniformDistributionIntegerValue<T>{},
                      FillUniformDistributionIntegerValue<T>{});

    // Bind the problem description; the engine supplies the three tensors.
    auto reference_conv_fwd_fun =
        std::bind(conv::run_reference_convolution_forward<2, T, T, T>, params, _1, _2, _3);
    OpInstanceRunEngine<T, T, T> run_engine(conv_instance, reference_conv_fwd_fun);
    run_engine.SetAtol(1e-5);
    run_engine.SetRtol(1e-4);
    EXPECT_TRUE(run_engine.Test(conv_ptrs));
}

// half_t 2-D forward convolution over the test instance list (last template
// argument float — presumably the accumulation type; TODO confirm), with
// uniformly distributed floating-point inputs, checked against the reference.
TEST(Conv2DFwdNHWC, FloatingPointValues)
{
    using namespace std::placeholders;
    using namespace ck::utils;
    using T = ck::half_t;

    ck::utils::conv::ConvParams params{
        2, 4, 256, 64, {3, 3}, {36, 36}, {2, 2}, {2, 2}, {2, 2}, {2, 2}};

    std::vector<test::conv::DeviceConvFwdNoOpPtr> conv_ptrs;
    test::conv::get_test_convolution_fwd_instance<2, T, T, T, float>(conv_ptrs);
    conv::ConvFwdOpInstance<T,
                            T,
                            T,
                            ck::tensor_layout::convolution::NHWC,
                            ck::tensor_layout::convolution::KYXC,
                            ck::tensor_layout::convolution::NHWK,
                            ck::tensor_operation::element_wise::PassThrough,
                            ck::tensor_operation::element_wise::PassThrough,
                            ck::tensor_operation::element_wise::PassThrough,
                            FillUniformDistribution<T>,
                            FillUniformDistribution<T>>
        conv_instance(params, true, FillUniformDistribution<T>{}, FillUniformDistribution<T>{});

    auto reference_conv_fwd_fun =
        std::bind(conv::run_reference_convolution_forward<2, T, T, T>, params, _1, _2, _3);
    OpInstanceRunEngine<T, T, T> run_engine(conv_instance, reference_conv_fwd_fun);
    // Looser tolerances than the float/integer-valued test, for half precision.
    run_engine.SetAtol(2e-4);
    run_engine.SetRtol(1e-3);
    EXPECT_TRUE(run_engine.Test(conv_ptrs));
}

// bf16, 2-D specific instance set (use_convnd defaults to false).
TEST_F(Conv2dFwdNHWCInstances, BF16_default) { EXPECT_TRUE(this->test_default<ck::bhalf_t>()); }
TEST_F(Conv2dFwdNHWCInstances, BF16_filter1x1_stride1_pad0)
{
    EXPECT_TRUE(this->test_filter1x1_stride1_pad0<ck::bhalf_t>());
}
TEST_F(Conv2dFwdNHWCInstances, BF16_filter1x1_pad0)
{
    EXPECT_TRUE(this->test_filter1x1_pad0<ck::bhalf_t>());
}

// fp16, 2-D specific instance set (use_convnd defaults to false).
TEST_F(Conv2dFwdNHWCInstances, F16_default) { EXPECT_TRUE(this->test_default<ck::half_t>()); }
TEST_F(Conv2dFwdNHWCInstances, F16_filter1x1_stride1_pad0)
{
    EXPECT_TRUE(this->test_filter1x1_stride1_pad0<ck::half_t>());
}
// fp16 cases that continue the group above.
TEST_F(Conv2dFwdNHWCInstances, F16_filter1x1_pad0)
{
    EXPECT_TRUE(this->test_filter1x1_pad0<ck::half_t>());
}

// Odd channel count; only the 2-D specific instance set is used.
TEST_F(Conv2dFwdNHWCInstances, F16_oddC)
{
    EXPECT_TRUE(this->test_oddC<ck::half_t>());
}

// fp32, 2-D specific instance set.
TEST_F(Conv2dFwdNHWCInstances, F32_default)
{
    EXPECT_TRUE(this->test_default<float>());
}

TEST_F(Conv2dFwdNHWCInstances, F32_filter1x1_stride1_pad0)
{
    EXPECT_TRUE(this->test_filter1x1_stride1_pad0<float>());
}

TEST_F(Conv2dFwdNHWCInstances, F32_filter1x1_pad0)
{
    EXPECT_TRUE(this->test_filter1x1_pad0<float>());
}
// int8, 2-D specific instance set (use_convnd defaults to false).
TEST_F(Conv2dFwdNHWCInstances, I8_default) { EXPECT_TRUE(this->test_default<int8_t>()); }
TEST_F(Conv2dFwdNHWCInstances, I8_filter1x1_stride1_pad0)
{
    EXPECT_TRUE(this->test_filter1x1_stride1_pad0<int8_t>());
}
TEST_F(Conv2dFwdNHWCInstances, I8_filter1x1_pad0)
{
    EXPECT_TRUE(this->test_filter1x1_pad0<int8_t>());
}

// The same problems run through the N-D instance set (use_convnd == true).

// bf16
TEST_F(Conv2dFwdNHWCInstances, ND_BF16_default)
{
    EXPECT_TRUE(this->test_default<ck::bhalf_t>(true));
}

TEST_F(Conv2dFwdNHWCInstances, ND_BF16_filter1x1_stride1_pad0)
{
    EXPECT_TRUE(this->test_filter1x1_stride1_pad0<ck::bhalf_t>(true));
}

TEST_F(Conv2dFwdNHWCInstances, ND_BF16_filter1x1_pad0)
{
    EXPECT_TRUE(this->test_filter1x1_pad0<ck::bhalf_t>(true));
}

// fp16
TEST_F(Conv2dFwdNHWCInstances, ND_F16_default)
{
    EXPECT_TRUE(this->test_default<ck::half_t>(true));
}

TEST_F(Conv2dFwdNHWCInstances, ND_F16_filter1x1_stride1_pad0)
{
    EXPECT_TRUE(this->test_filter1x1_stride1_pad0<ck::half_t>(true));
}

TEST_F(Conv2dFwdNHWCInstances, ND_F16_filter1x1_pad0)
{
    EXPECT_TRUE(this->test_filter1x1_pad0<ck::half_t>(true));
}

// fp32
TEST_F(Conv2dFwdNHWCInstances, ND_F32_default)
{
    EXPECT_TRUE(this->test_default<float>(true));
}

TEST_F(Conv2dFwdNHWCInstances, ND_F32_filter1x1_stride1_pad0)
{
    EXPECT_TRUE(this->test_filter1x1_stride1_pad0<float>(true));
}

TEST_F(Conv2dFwdNHWCInstances, ND_F32_filter1x1_pad0)
{
    EXPECT_TRUE(this->test_filter1x1_pad0<float>(true));
}

// int8
TEST_F(Conv2dFwdNHWCInstances, ND_I8_default)
{
    EXPECT_TRUE(this->test_default<int8_t>(true));
}

TEST_F(Conv2dFwdNHWCInstances, ND_I8_filter1x1_stride1_pad0)
{
    EXPECT_TRUE(this->test_filter1x1_stride1_pad0<int8_t>(true));
}
// int8, 1x1 filter with stride 2, N-D instance set.
TEST_F(Conv2dFwdNHWCInstances, ND_I8_filter1x1_pad0)
{
    EXPECT_TRUE(this->test_filter1x1_pad0<int8_t>(true));
}