// conv3d_fwd.cpp
#include <half.hpp>
#include <iostream>
#include <stdexcept>
#include <tuple>
#include <vector>
#include "gtest/gtest.h"
#include "data_type.hpp"
#include "element_wise_operation.hpp"
#include "library/include/ck/library/utility/conv_util.hpp"
#include "conv_util.hpp"
namespace {

class Conv3dFwdNDHWCInstances : public ::testing::Test
{
    public:
    template <typename T>
    bool test_conv3d_nwc_instances(const std::vector<test::conv::DeviceConvFwdNoOpPtr>& conv_ptrs,
                                   const ck::utils::conv::ConvParams& params)
    {
        using namespace std::placeholders;
        using namespace ck::utils;
        namespace ctl = ck::tensor_layout::convolution;

        conv::ConvFwdOpInstance<T,
                                T,
                                T,
                                ctl::NDHWC,
                                ctl::KZYXC,
                                ctl::NDHWK,
                                ck::tensor_operation::element_wise::PassThrough,
                                ck::tensor_operation::element_wise::PassThrough,
                                ck::tensor_operation::element_wise::PassThrough,
                                FillUniformDistributionIntegerValue<T>,
                                FillUniformDistributionIntegerValue<T>>
            conv_instance(params,
                          true,
                          FillUniformDistributionIntegerValue<T>{},
                          FillUniformDistributionIntegerValue<T>{});
        auto reference_conv_fwd_fun =
            std::bind(conv::run_reference_convolution_forward<3, T, T, T>, params, _1, _2, _3);
        OpInstanceRunEngine<T, T, T> run_engine(conv_instance, reference_conv_fwd_fun);
        run_engine.SetAtol(atol_);
        run_engine.SetRtol(rtol_);
        return run_engine.Test(conv_ptrs);
    }

    template <typename T>
    bool test_default()
    {
        return test_conv3d_nwc_instances<T>(
            ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<3>(), params_default_);
    }

    template <typename T>
    bool test_filter1x1_stride1_pad0()
    {
        return test_conv3d_nwc_instances<T>(
            ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<3>(),
            params_filter1x1_stride1_pad0_);
    }

    template <typename T>
    bool test_filter1x1_pad0()
    {
        return test_conv3d_nwc_instances<T>(
            ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<3>(),
            params_filter1x1_pad0_);
    }

    static inline ck::utils::conv::ConvParams params_default_{
        3, 4, 256, 64, {3, 3, 3}, {28, 28, 28}, {2, 2, 2}, {2, 2, 2}, {2, 2, 2}, {2, 2, 2}};
    static inline ck::utils::conv::ConvParams params_filter1x1_stride1_pad0_{
        3, 4, 256, 64, {1, 1, 1}, {28, 28, 28}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
    static inline ck::utils::conv::ConvParams params_filter1x1_pad0_{
        3, 4, 256, 64, {1, 1, 1}, {28, 28, 28}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};

    private:
    double atol_{1e-5};
    double rtol_{1e-4};
};

} // anonymous namespace

// Checks the 3D forward-convolution test instance for float data against the
// reference convolution, with input and weight tensors filled by
// FillUniformDistributionIntegerValue.
TEST(Conv3DFwdNDHWC, IntegerValues)
{
    using namespace std::placeholders;
    using namespace ck::utils;
    namespace ctl = ck::tensor_layout::convolution;
    using T       = float;

    ck::utils::conv::ConvParams params{
        3, 4, 256, 64, {3, 3, 3}, {18, 18, 18}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, {2, 2, 2}};

    std::vector<test::conv::DeviceConvFwdNoOpPtr> conv_ptrs;
    test::conv::get_test_convolution_fwd_instance<3, T, T, T, T>(conv_ptrs);
    conv::ConvFwdOpInstance<T,
                            T,
                            T,
                            ctl::NDHWC,
                            ctl::KZYXC,
                            ctl::NDHWK,
                            ck::tensor_operation::element_wise::PassThrough,
                            ck::tensor_operation::element_wise::PassThrough,
                            ck::tensor_operation::element_wise::PassThrough,
                            FillUniformDistributionIntegerValue<T>,
                            FillUniformDistributionIntegerValue<T>>
        conv_instance(params,
                      true,
                      FillUniformDistributionIntegerValue<T>{},
                      FillUniformDistributionIntegerValue<T>{});

    // The problem description is bound here; the run engine supplies the three
    // remaining arguments when it calls the reference.
    auto reference_conv_fwd_fun =
        std::bind(conv::run_reference_convolution_forward<3, T, T, T>, params, _1, _2, _3);
    OpInstanceRunEngine<T, T, T> run_engine(conv_instance, reference_conv_fwd_fun);
    run_engine.SetAtol(1e-5);
    run_engine.SetRtol(1e-3);
    EXPECT_TRUE(run_engine.Test(conv_ptrs));
}

// Checks the 3D forward-convolution test instance for ck::half_t data against the
// reference convolution, with input and weight tensors filled by
// FillUniformDistribution. Both tolerances are set to 1e-3.
TEST(Conv3DFwdNDHWC, FloatingPointValues)
{
    using namespace std::placeholders;
    using namespace ck::utils;
    namespace ctl = ck::tensor_layout::convolution;
    using T       = ck::half_t;

    ck::utils::conv::ConvParams params{
        3, 4, 256, 64, {3, 3, 3}, {18, 18, 18}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, {2, 2, 2}};

    std::vector<test::conv::DeviceConvFwdNoOpPtr> conv_ptrs;
    // NOTE(review): the trailing `float` (other tests pass T here) is presumably the
    // accumulation type -- confirm against get_test_convolution_fwd_instance.
    test::conv::get_test_convolution_fwd_instance<3, T, T, T, float>(conv_ptrs);
    conv::ConvFwdOpInstance<T,
                            T,
                            T,
                            ctl::NDHWC,
                            ctl::KZYXC,
                            ctl::NDHWK,
                            ck::tensor_operation::element_wise::PassThrough,
                            ck::tensor_operation::element_wise::PassThrough,
                            ck::tensor_operation::element_wise::PassThrough,
                            FillUniformDistribution<T>,
                            FillUniformDistribution<T>>
        conv_instance(params, true, FillUniformDistribution<T>{}, FillUniformDistribution<T>{});

    // The problem description is bound here; the run engine supplies the three
    // remaining arguments when it calls the reference.
    auto reference_conv_fwd_fun =
        std::bind(conv::run_reference_convolution_forward<3, T, T, T>, params, _1, _2, _3);
    OpInstanceRunEngine<T, T, T> run_engine(conv_instance, reference_conv_fwd_fun);
    run_engine.SetAtol(1e-3);
    run_engine.SetRtol(1e-3);
    EXPECT_TRUE(run_engine.Test(conv_ptrs));
}

// Verifies that the test instance reports a problem with an oversized input tensor
// as unsupported. The argument is built with null data pointers and is only handed
// to IsSupportedArgument.
TEST(Conv3DFwdNDHWC, InputOver2GB)
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using namespace ck::utils;
    using T = float;

    // >2GB Input
    // N * C * D * H * W = 2 * 32 * 32 * 1000 * 1000 = 2,048,000,000 float elements.
    conv::ConvParams params;
    params.num_dim_spatial_        = 3;
    params.N_                      = 2;
    params.K_                      = 16;
    params.C_                      = 32;
    params.filter_spatial_lengths_ = std::vector<ck::index_t>{3, 3, 3};
    params.input_spatial_lengths_  = std::vector<ck::index_t>{32, 1000, 1000};
    params.conv_filter_strides_    = std::vector<ck::index_t>{1, 1, 1};
    params.conv_filter_dilations_  = std::vector<ck::index_t>{1, 1, 1};
    params.input_left_pads_        = std::vector<ck::index_t>{1, 1, 1};
    params.input_right_pads_       = std::vector<ck::index_t>{1, 1, 1};

    std::vector<test::conv::DeviceConvFwdNoOpPtr> conv_ptrs;
    test::conv::get_test_convolution_fwd_instance<3, T, T, T, T>(conv_ptrs);
    // back() on an empty vector is undefined behaviour; fail the test instead.
    ASSERT_FALSE(conv_ptrs.empty());

    auto arg = conv_ptrs.back()->MakeArgumentPointer(nullptr,
                                                     nullptr,
                                                     nullptr,
                                                     params.N_,
                                                     params.K_,
                                                     params.C_,
                                                     params.input_spatial_lengths_,
                                                     params.filter_spatial_lengths_,
                                                     params.GetOutputSpatialLengths(),
                                                     params.conv_filter_strides_,
                                                     params.conv_filter_dilations_,
                                                     params.input_left_pads_,
                                                     params.input_right_pads_,
                                                     PassThrough{},
                                                     PassThrough{},
                                                     PassThrough{});
    EXPECT_FALSE(conv_ptrs.back()->IsSupportedArgument(arg.get()));
}

// Verifies that the test instance reports a problem with an oversized filter tensor
// as unsupported. The argument is built with null data pointers and is only handed
// to IsSupportedArgument.
TEST(Conv3DFwdNDHWC, FiltersOver2GB)
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using namespace ck::utils;
    using T = float;

    // >2GB Filters
    // K * C * Z * Y * X = 16 * 32 * 4 * 1000 * 1000 = 2,048,000,000 float elements.
    conv::ConvParams params;
    params.num_dim_spatial_        = 3;
    params.N_                      = 2;
    params.K_                      = 16;
    params.C_                      = 32;
    params.filter_spatial_lengths_ = std::vector<ck::index_t>{4, 1000, 1000};
    params.input_spatial_lengths_  = std::vector<ck::index_t>{16, 16, 16};
    params.conv_filter_strides_    = std::vector<ck::index_t>{1, 1, 1};
    params.conv_filter_dilations_  = std::vector<ck::index_t>{1, 1, 1};
    params.input_left_pads_        = std::vector<ck::index_t>{1, 1, 1};
    params.input_right_pads_       = std::vector<ck::index_t>{1, 1, 1};

    std::vector<test::conv::DeviceConvFwdNoOpPtr> conv_ptrs;
    test::conv::get_test_convolution_fwd_instance<3, T, T, T, T>(conv_ptrs);
    // back() on an empty vector is undefined behaviour; fail the test instead.
    ASSERT_FALSE(conv_ptrs.empty());

    auto arg = conv_ptrs.back()->MakeArgumentPointer(nullptr,
                                                     nullptr,
                                                     nullptr,
                                                     params.N_,
                                                     params.K_,
                                                     params.C_,
                                                     params.input_spatial_lengths_,
                                                     params.filter_spatial_lengths_,
                                                     params.GetOutputSpatialLengths(),
                                                     params.conv_filter_strides_,
                                                     params.conv_filter_dilations_,
                                                     params.input_left_pads_,
                                                     params.input_right_pads_,
                                                     PassThrough{},
                                                     PassThrough{},
                                                     PassThrough{});
    EXPECT_FALSE(conv_ptrs.back()->IsSupportedArgument(arg.get()));
}

// Verifies that the test instance reports a problem with an oversized output tensor
// as unsupported. The argument is built with null data pointers and is only handed
// to IsSupportedArgument.
TEST(Conv3DFwdNDHWC, OutputOver2GB)
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using namespace ck::utils;
    using T = float;

    // >2GB Output
    // NOTE(review): with a 1x1x1 filter, unit stride and padding of 2 on each side the
    // output is presumably 1004 x 1004 x 34 per (N, K) pair -- confirm against
    // ConvParams::GetOutputSpatialLengths.
    conv::ConvParams params;
    params.num_dim_spatial_        = 3;
    params.N_                      = 2;
    params.K_                      = 16;
    params.C_                      = 2;
    params.filter_spatial_lengths_ = std::vector<ck::index_t>{1, 1, 1};
    params.input_spatial_lengths_  = std::vector<ck::index_t>{1000, 1000, 30};
    params.conv_filter_strides_    = std::vector<ck::index_t>{1, 1, 1};
    params.conv_filter_dilations_  = std::vector<ck::index_t>{1, 1, 1};
    params.input_left_pads_        = std::vector<ck::index_t>{2, 2, 2};
    params.input_right_pads_       = std::vector<ck::index_t>{2, 2, 2};

    std::vector<test::conv::DeviceConvFwdNoOpPtr> conv_ptrs;
    test::conv::get_test_convolution_fwd_instance<3, T, T, T, T>(conv_ptrs);
    // back() on an empty vector is undefined behaviour; fail the test instead.
    ASSERT_FALSE(conv_ptrs.empty());

    auto arg = conv_ptrs.back()->MakeArgumentPointer(nullptr,
                                                     nullptr,
                                                     nullptr,
                                                     params.N_,
                                                     params.K_,
                                                     params.C_,
                                                     params.input_spatial_lengths_,
                                                     params.filter_spatial_lengths_,
                                                     params.GetOutputSpatialLengths(),
                                                     params.conv_filter_strides_,
                                                     params.conv_filter_dilations_,
                                                     params.input_left_pads_,
                                                     params.input_right_pads_,
                                                     PassThrough{},
                                                     PassThrough{},
                                                     PassThrough{});
    EXPECT_FALSE(conv_ptrs.back()->IsSupportedArgument(arg.get()));
}

// ck::bhalf_t: run the fixture's three parameter sets.
TEST_F(Conv3dFwdNDHWCInstances, BF16_default) { EXPECT_TRUE(this->test_default<ck::bhalf_t>()); }
TEST_F(Conv3dFwdNDHWCInstances, BF16_filter1x1_stride1_pad0)
{
    EXPECT_TRUE(this->test_filter1x1_stride1_pad0<ck::bhalf_t>());
}
TEST_F(Conv3dFwdNDHWCInstances, BF16_filter1x1_pad0)
{
    EXPECT_TRUE(this->test_filter1x1_pad0<ck::bhalf_t>());
}

// ck::half_t: run the fixture's three parameter sets.
TEST_F(Conv3dFwdNDHWCInstances, F16_default) { EXPECT_TRUE(this->test_default<ck::half_t>()); }
TEST_F(Conv3dFwdNDHWCInstances, F16_filter1x1_stride1_pad0)
{
    EXPECT_TRUE(this->test_filter1x1_stride1_pad0<ck::half_t>());
}
TEST_F(Conv3dFwdNDHWCInstances, F16_filter1x1_pad0)
{
    EXPECT_TRUE(this->test_filter1x1_pad0<ck::half_t>());
}

// float: run the fixture's three parameter sets.
TEST_F(Conv3dFwdNDHWCInstances, F32_default) { EXPECT_TRUE(this->test_default<float>()); }
TEST_F(Conv3dFwdNDHWCInstances, F32_filter1x1_stride1_pad0)
{
    EXPECT_TRUE(this->test_filter1x1_stride1_pad0<float>());
}
TEST_F(Conv3dFwdNDHWCInstances, F32_filter1x1_pad0)
{
    EXPECT_TRUE(this->test_filter1x1_pad0<float>());
}

// int8_t: run the fixture's three parameter sets.
TEST_F(Conv3dFwdNDHWCInstances, I8_default) { EXPECT_TRUE(this->test_default<int8_t>()); }
TEST_F(Conv3dFwdNDHWCInstances, I8_filter1x1_stride1_pad0)
{
    EXPECT_TRUE(this->test_filter1x1_stride1_pad0<int8_t>());
}
TEST_F(Conv3dFwdNDHWCInstances, I8_filter1x1_pad0)
{
    EXPECT_TRUE(this->test_filter1x1_pad0<int8_t>());
}