conv2d_fwd.cpp 10.1 KB
Newer Older
1
2
#include <tuple>
#include <vector>
3
#include "gtest/gtest.h"
4

Adam Osewski's avatar
Adam Osewski committed
5
#include "ck/library/utility/conv_util.hpp"
6
#include "config.hpp"
7
#include "conv_util.hpp"
8
9
10
#include "data_type.hpp"
#include "element_wise_operation.hpp"
#include "fill.hpp"
11
12
13

namespace {

14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
class Conv2dFwdNHWCInstances : public ::testing::Test
{
    public:
    template <typename T>
    bool test_conv2d_nhwc_instances(const std::vector<test::conv::DeviceConvFwdNoOpPtr>& conv_ptrs,
                                    const ck::utils::conv::ConvParams& params)
    {
        using namespace std::placeholders;
        using namespace ck::utils;

        conv::ConvFwdOpInstance<T,
                                T,
                                T,
                                ck::tensor_layout::convolution::NHWC,
                                ck::tensor_layout::convolution::KYXC,
                                ck::tensor_layout::convolution::NHWK,
                                ck::tensor_operation::element_wise::PassThrough,
                                ck::tensor_operation::element_wise::PassThrough,
                                ck::tensor_operation::element_wise::PassThrough,
                                FillUniformDistributionIntegerValue<T>,
                                FillUniformDistributionIntegerValue<T>>
            conv_instance(params,
                          true,
                          FillUniformDistributionIntegerValue<T>{},
                          FillUniformDistributionIntegerValue<T>{});
        auto reference_conv_fwd_fun =
            std::bind(conv::run_reference_convolution_forward<2, T, T, T>, params, _1, _2, _3);
        OpInstanceRunEngine<T, T, T> run_engine(conv_instance, reference_conv_fwd_fun);
        run_engine.SetAtol(atol_);
        run_engine.SetRtol(rtol_);
        return run_engine.Test(conv_ptrs);
    }

    template <typename T>
    bool test_default(bool use_convnd = false)
    {
        if(use_convnd)
        {
            return test_conv2d_nhwc_instances<T>(
                test::conv::ConvolutionNDFwdInstances<T, T, T>::Get(2), params_default_);
        }
        else
        {
            return test_conv2d_nhwc_instances<T>(
                ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<2>(),
                params_default_);
        }
    }

    template <typename T>
    bool test_filter1x1_stride1_pad0(bool use_convnd = false)
    {
        if(use_convnd)
        {
            return test_conv2d_nhwc_instances<T>(
                test::conv::ConvolutionNDFwdInstances<T, T, T>::Get(2),
                params_filter1x1_stride1_pad0_);
        }
        else
        {
            return test_conv2d_nhwc_instances<T>(
                ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<2>(),
                params_filter1x1_stride1_pad0_);
        }
    }

    template <typename T>
    bool test_filter1x1_pad0(bool use_convnd = false)
    {
        if(use_convnd)
        {
            return test_conv2d_nhwc_instances<T>(
                test::conv::ConvolutionNDFwdInstances<T, T, T>::Get(2), params_filter1x1_pad0_);
        }
        else
        {
            return test_conv2d_nhwc_instances<T>(
                ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<2>(),
                params_filter1x1_pad0_);
        }
    }

    template <typename T>
    bool test_oddC()
    {
        return test_conv2d_nhwc_instances<T>(
            ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<2>(), params_oddC_);
    }

    static inline ck::utils::conv::ConvParams params_default_{
        2, 4, 256, 64, {3, 3}, {36, 36}, {2, 2}, {2, 2}, {2, 2}, {2, 2}};
    static inline ck::utils::conv::ConvParams params_filter1x1_stride1_pad0_{
        2, 4, 256, 64, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
    static inline ck::utils::conv::ConvParams params_filter1x1_pad0_{
        2, 4, 256, 64, {1, 1}, {28, 28}, {2, 2}, {1, 1}, {0, 0}, {0, 0}};
    static inline ck::utils::conv::ConvParams params_oddC_{
        2, 4, 256, 3, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};

    private:
    double atol_{1e-5};
    double rtol_{1e-4};
};

} // anonymous namespace

// Runs the 2D forward-convolution test instances on float data whose tensors
// are initialised by FillUniformDistributionIntegerValue, and compares the
// result with the reference convolution (atol 1e-5, rtol 1e-4).
TEST(Conv2DFwdNHWC, IntegerValues)
{
    using namespace std::placeholders;
    using namespace ck::utils;

    using DataType    = float;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using Filler      = FillUniformDistributionIntegerValue<DataType>;
    using OpInstance  = conv::ConvFwdOpInstance<DataType,
                                               DataType,
                                               DataType,
                                               ck::tensor_layout::convolution::NHWC,
                                               ck::tensor_layout::convolution::KYXC,
                                               ck::tensor_layout::convolution::NHWK,
                                               PassThrough,
                                               PassThrough,
                                               PassThrough,
                                               Filler,
                                               Filler>;

    ck::utils::conv::ConvParams problem{
        2, 4, 256, 64, {3, 3}, {36, 36}, {1, 1}, {2, 2}, {2, 2}, {2, 2}};

    // Collect the device instances under test.
    std::vector<test::conv::DeviceConvFwdNoOpPtr> instances;
    test::conv::get_test_convolution_fwd_instance<2, DataType, DataType, DataType, DataType>(
        instances);

    // NOTE(review): the bare `true` is presumably a "do initialise" flag —
    // confirm against ConvFwdOpInstance's constructor.
    OpInstance op_instance(problem, true, Filler{}, Filler{});

    // The reference gets a bound copy of `problem`; the engine supplies the
    // remaining three arguments.
    auto reference = std::bind(
        conv::run_reference_convolution_forward<2, DataType, DataType, DataType>,
        problem,
        _1,
        _2,
        _3);

    OpInstanceRunEngine<DataType, DataType, DataType> engine(op_instance, reference);
    engine.SetAtol(1e-5);
    engine.SetRtol(1e-4);
    EXPECT_TRUE(engine.Test(instances));
}

154
// Runs the 2D forward-convolution test instances on ck::half_t data whose
// tensors are initialised by FillUniformDistribution, and compares the result
// with the reference convolution (atol 2e-4, rtol 1e-3).
TEST(Conv2DFwdNHWC, FloatingPointValues)
{
    using namespace std::placeholders;
    using namespace ck::utils;

    using DataType    = ck::half_t;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using Filler      = FillUniformDistribution<DataType>;
    using OpInstance  = conv::ConvFwdOpInstance<DataType,
                                               DataType,
                                               DataType,
                                               ck::tensor_layout::convolution::NHWC,
                                               ck::tensor_layout::convolution::KYXC,
                                               ck::tensor_layout::convolution::NHWK,
                                               PassThrough,
                                               PassThrough,
                                               PassThrough,
                                               Filler,
                                               Filler>;

    ck::utils::conv::ConvParams problem{
        2, 4, 256, 64, {3, 3}, {36, 36}, {2, 2}, {2, 2}, {2, 2}, {2, 2}};

    // Collect the device instances under test; note the last template argument
    // is float here rather than DataType.
    std::vector<test::conv::DeviceConvFwdNoOpPtr> instances;
    test::conv::get_test_convolution_fwd_instance<2, DataType, DataType, DataType, float>(
        instances);

    // NOTE(review): the bare `true` is presumably a "do initialise" flag —
    // confirm against ConvFwdOpInstance's constructor.
    OpInstance op_instance(problem, true, Filler{}, Filler{});

    // The reference gets a bound copy of `problem`; the engine supplies the
    // remaining three arguments.
    auto reference = std::bind(
        conv::run_reference_convolution_forward<2, DataType, DataType, DataType>,
        problem,
        _1,
        _2,
        _3);

    OpInstanceRunEngine<DataType, DataType, DataType> engine(op_instance, reference);
    engine.SetAtol(2e-4);
    engine.SetRtol(1e-3);
    EXPECT_TRUE(engine.Test(instances));
}

186
187
// ck::bhalf_t cases; use_convnd is left at its default.
TEST_F(Conv2dFwdNHWCInstances, BF16_default)
{
    EXPECT_TRUE(test_default<ck::bhalf_t>());
}

TEST_F(Conv2dFwdNHWCInstances, BF16_filter1x1_stride1_pad0)
{
    EXPECT_TRUE(test_filter1x1_stride1_pad0<ck::bhalf_t>());
}

TEST_F(Conv2dFwdNHWCInstances, BF16_filter1x1_pad0)
{
    EXPECT_TRUE(test_filter1x1_pad0<ck::bhalf_t>());
}
195
196
// ck::half_t cases; use_convnd is left at its default.
TEST_F(Conv2dFwdNHWCInstances, F16_default)
{
    EXPECT_TRUE(test_default<ck::half_t>());
}

TEST_F(Conv2dFwdNHWCInstances, F16_filter1x1_stride1_pad0)
{
    EXPECT_TRUE(test_filter1x1_stride1_pad0<ck::half_t>());
}

TEST_F(Conv2dFwdNHWCInstances, F16_filter1x1_pad0)
{
    EXPECT_TRUE(test_filter1x1_pad0<ck::half_t>());
}

TEST_F(Conv2dFwdNHWCInstances, F16_oddC)
{
    EXPECT_TRUE(test_oddC<ck::half_t>());
}
// float cases; use_convnd is left at its default.
TEST_F(Conv2dFwdNHWCInstances, F32_default)
{
    EXPECT_TRUE(test_default<float>());
}

TEST_F(Conv2dFwdNHWCInstances, F32_filter1x1_stride1_pad0)
{
    EXPECT_TRUE(test_filter1x1_stride1_pad0<float>());
}

TEST_F(Conv2dFwdNHWCInstances, F32_filter1x1_pad0)
{
    EXPECT_TRUE(test_filter1x1_pad0<float>());
}
// int8_t cases; use_convnd is left at its default.
TEST_F(Conv2dFwdNHWCInstances, I8_default)
{
    EXPECT_TRUE(test_default<int8_t>());
}

TEST_F(Conv2dFwdNHWCInstances, I8_filter1x1_stride1_pad0)
{
    EXPECT_TRUE(test_filter1x1_stride1_pad0<int8_t>());
}

TEST_F(Conv2dFwdNHWCInstances, I8_filter1x1_pad0)
{
    EXPECT_TRUE(test_filter1x1_pad0<int8_t>());
}
223

224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
// ck::bhalf_t cases run with use_convnd = true.
TEST_F(Conv2dFwdNHWCInstances, ND_BF16_default)
{
    EXPECT_TRUE(test_default<ck::bhalf_t>(/*use_convnd=*/true));
}

TEST_F(Conv2dFwdNHWCInstances, ND_BF16_filter1x1_stride1_pad0)
{
    EXPECT_TRUE(test_filter1x1_stride1_pad0<ck::bhalf_t>(/*use_convnd=*/true));
}

TEST_F(Conv2dFwdNHWCInstances, ND_BF16_filter1x1_pad0)
{
    EXPECT_TRUE(test_filter1x1_pad0<ck::bhalf_t>(/*use_convnd=*/true));
}
// ck::half_t cases run with use_convnd = true.
TEST_F(Conv2dFwdNHWCInstances, ND_F16_default)
{
    EXPECT_TRUE(test_default<ck::half_t>(/*use_convnd=*/true));
}

TEST_F(Conv2dFwdNHWCInstances, ND_F16_filter1x1_stride1_pad0)
{
    EXPECT_TRUE(test_filter1x1_stride1_pad0<ck::half_t>(/*use_convnd=*/true));
}

TEST_F(Conv2dFwdNHWCInstances, ND_F16_filter1x1_pad0)
{
    EXPECT_TRUE(test_filter1x1_pad0<ck::half_t>(/*use_convnd=*/true));
}
// float cases run with use_convnd = true.
TEST_F(Conv2dFwdNHWCInstances, ND_F32_default)
{
    EXPECT_TRUE(test_default<float>(/*use_convnd=*/true));
}

TEST_F(Conv2dFwdNHWCInstances, ND_F32_filter1x1_stride1_pad0)
{
    EXPECT_TRUE(test_filter1x1_stride1_pad0<float>(/*use_convnd=*/true));
}

TEST_F(Conv2dFwdNHWCInstances, ND_F32_filter1x1_pad0)
{
    EXPECT_TRUE(test_filter1x1_pad0<float>(/*use_convnd=*/true));
}
// int8_t cases run with use_convnd = true.
TEST_F(Conv2dFwdNHWCInstances, ND_I8_default)
{
    EXPECT_TRUE(test_default<int8_t>(/*use_convnd=*/true));
}

TEST_F(Conv2dFwdNHWCInstances, ND_I8_filter1x1_stride1_pad0)
{
    EXPECT_TRUE(test_filter1x1_stride1_pad0<int8_t>(/*use_convnd=*/true));
}

TEST_F(Conv2dFwdNHWCInstances, ND_I8_filter1x1_pad0)
{
    EXPECT_TRUE(test_filter1x1_pad0<int8_t>(/*use_convnd=*/true));
}