conv3d_fwd.cpp 10 KB
Newer Older
1
2
3
4
5
#include <half.hpp>
#include <iostream>
#include <stdexcept>
#include <tuple>
#include <vector>
6
#include "gtest/gtest.h"
7
8
9

#include "data_type.hpp"
#include "element_wise_operation.hpp"
Adam Osewski's avatar
Adam Osewski committed
10
#include "library/include/ck/library/utility/conv_util.hpp"
11
#include "conv_util.hpp"
12
13
14

namespace {

15
16
17
18
19
20
21
22
template <typename T>
bool test_conv3d_ndhwc_instances(const std::vector<test::conv::DeviceConvFwdNoOpPtr>& conv_ptrs)
{
    using namespace std::placeholders;
    using namespace ck::utils;
    namespace ctl = ck::tensor_layout::convolution;

    conv::ConvParams params;
Adam Osewski's avatar
Adam Osewski committed
23
24
25
26
27
28
29
30
    params.N_                      = 64;
    params.num_dim_spatial_        = 3;
    params.filter_spatial_lengths_ = std::vector<ck::index_t>{3, 3, 2};
    params.input_spatial_lengths_  = std::vector<ck::index_t>{32, 32, 2};
    params.conv_filter_strides_    = std::vector<ck::index_t>{2, 2, 2};
    params.conv_filter_dilations_  = std::vector<ck::index_t>{1, 1, 1};
    params.input_left_pads_        = std::vector<ck::index_t>{1, 1, 1};
    params.input_right_pads_       = std::vector<ck::index_t>{1, 1, 1};
31
32
33
34
35
36
37
38
39
40
41
42

    conv::ConvFwdOpInstance<T, T, T, ctl::NDHWC, ctl::KZYXC, ctl::NDHWK> conv_instance(params);

    auto reference_conv_fwd_fun =
        std::bind(conv::run_reference_convolution_forward<3, T, T, T>, params, _1, _2, _3);
    OpInstanceRunEngine<T, T, T> run_engine(conv_instance, reference_conv_fwd_fun);
    return run_engine.Test(conv_ptrs);
}

} // anonymous namespace

/// Runs the 3D test convolution instances on a small float problem
/// (N=2, K=16, C=4, 3x3x3 filter over a 16x16x16 input, unit stride and
/// dilation, padding 1) and expects the run engine to report success with
/// atol 1e-5 and rtol 1e-4.
TEST(Conv3DFwdNDHWC, TestConv3D)
{
    using namespace ck::utils;
    namespace layout = ck::tensor_layout::convolution;

    using Lengths = std::vector<ck::index_t>;
    using Instance =
        conv::ConvFwdOpInstance<float, float, float, layout::NDHWC, layout::KZYXC, layout::NDHWK>;
    using Engine = OpInstanceRunEngine<float, float, float>;

    // Problem description.
    conv::ConvParams params;
    params.num_dim_spatial_        = 3;
    params.N_                      = 2;
    params.K_                      = 16;
    params.C_                      = 4;
    params.input_spatial_lengths_  = Lengths{16, 16, 16};
    params.filter_spatial_lengths_ = Lengths{3, 3, 3};
    params.conv_filter_strides_    = Lengths{1, 1, 1};
    params.conv_filter_dilations_  = Lengths{1, 1, 1};
    params.input_left_pads_        = Lengths{1, 1, 1};
    params.input_right_pads_       = Lengths{1, 1, 1};

    // Device instances under test.
    std::vector<test::conv::DeviceConvFwdNoOpPtr> conv_ptrs;
    test::conv::get_test_convolution_fwd_instance<3>(conv_ptrs);

    Instance conv_instance(params);

    // Reference implementation with the problem description bound in; the
    // remaining three arguments are supplied by the engine.
    auto reference_op =
        std::bind(conv::run_reference_convolution_forward<3, float, float, float>,
                  params,
                  std::placeholders::_1,
                  std::placeholders::_2,
                  std::placeholders::_3);

    Engine run_engine(conv_instance, reference_op);
    run_engine.SetAtol(1e-5);
    run_engine.SetRtol(1e-4);

    const bool all_passed = run_engine.Test(conv_ptrs);
    EXPECT_TRUE(all_passed);
}

73
/// Expects the last 3D test convolution instance to reject a problem whose
/// input tensor is intended to exceed 2GB.
///
/// N * C * D * H * W = 2 * 32 * 32 * 1000 * 1000 = 2,048,000,000 input
/// elements. Only the argument validation is exercised; no kernel is launched.
TEST(Conv3DFwdNDHWC, InputOver2GB)
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using namespace ck::utils;

    // >2GB Input
    conv::ConvParams params;
    params.num_dim_spatial_        = 3;
    params.N_                      = 2;
    params.K_                      = 16;
    params.C_                      = 32;
    params.filter_spatial_lengths_ = std::vector<ck::index_t>{3, 3, 3};
    params.input_spatial_lengths_  = std::vector<ck::index_t>{32, 1000, 1000};
    params.conv_filter_strides_    = std::vector<ck::index_t>{1, 1, 1};
    params.conv_filter_dilations_  = std::vector<ck::index_t>{1, 1, 1};
    params.input_left_pads_        = std::vector<ck::index_t>{1, 1, 1};
    params.input_right_pads_       = std::vector<ck::index_t>{1, 1, 1};

    std::vector<test::conv::DeviceConvFwdNoOpPtr> conv_ptrs;
    test::conv::get_test_convolution_fwd_instance<3>(conv_ptrs);

    // back() on an empty vector is undefined behaviour; fail the test cleanly
    // if the factory produced no instance instead of dereferencing garbage.
    ASSERT_FALSE(conv_ptrs.empty());
    const auto& conv_op = conv_ptrs.back();

    // The three tensor pointers are passed as null: the argument is only
    // handed to IsSupportedArgument below.
    auto arg = conv_op->MakeArgumentPointer(nullptr,
                                            nullptr,
                                            nullptr,
                                            params.N_,
                                            params.K_,
                                            params.C_,
                                            params.input_spatial_lengths_,
                                            params.filter_spatial_lengths_,
                                            params.GetOutputSpatialLengths(),
                                            params.conv_filter_strides_,
                                            params.conv_filter_dilations_,
                                            params.input_left_pads_,
                                            params.input_right_pads_,
                                            PassThrough{},
                                            PassThrough{},
                                            PassThrough{});
    EXPECT_FALSE(conv_op->IsSupportedArgument(arg.get()));
}

113
/// Expects the last 3D test convolution instance to reject a problem whose
/// filter tensor is intended to exceed 2GB.
///
/// K * C * Z * Y * X = 16 * 32 * 4 * 1000 * 1000 = 2,048,000,000 filter
/// elements. Only the argument validation is exercised; no kernel is launched.
TEST(Conv3DFwdNDHWC, FiltersOver2GB)
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using namespace ck::utils;

    // >2GB Filters
    conv::ConvParams params;
    params.num_dim_spatial_        = 3;
    params.N_                      = 2;
    params.K_                      = 16;
    params.C_                      = 32;
    params.filter_spatial_lengths_ = std::vector<ck::index_t>{4, 1000, 1000};
    params.input_spatial_lengths_  = std::vector<ck::index_t>{16, 16, 16};
    params.conv_filter_strides_    = std::vector<ck::index_t>{1, 1, 1};
    params.conv_filter_dilations_  = std::vector<ck::index_t>{1, 1, 1};
    params.input_left_pads_        = std::vector<ck::index_t>{1, 1, 1};
    params.input_right_pads_       = std::vector<ck::index_t>{1, 1, 1};

    std::vector<test::conv::DeviceConvFwdNoOpPtr> conv_ptrs;
    test::conv::get_test_convolution_fwd_instance<3>(conv_ptrs);

    // back() on an empty vector is undefined behaviour; fail the test cleanly
    // if the factory produced no instance instead of dereferencing garbage.
    ASSERT_FALSE(conv_ptrs.empty());
    const auto& conv_op = conv_ptrs.back();

    // The three tensor pointers are passed as null: the argument is only
    // handed to IsSupportedArgument below.
    auto arg = conv_op->MakeArgumentPointer(nullptr,
                                            nullptr,
                                            nullptr,
                                            params.N_,
                                            params.K_,
                                            params.C_,
                                            params.input_spatial_lengths_,
                                            params.filter_spatial_lengths_,
                                            params.GetOutputSpatialLengths(),
                                            params.conv_filter_strides_,
                                            params.conv_filter_dilations_,
                                            params.input_left_pads_,
                                            params.input_right_pads_,
                                            PassThrough{},
                                            PassThrough{},
                                            PassThrough{});
    EXPECT_FALSE(conv_op->IsSupportedArgument(arg.get()));
}

153
/// Expects the last 3D test convolution instance to reject a problem whose
/// output tensor is intended to exceed 2GB.
///
/// A 1x1x1 filter with unit stride over a 1000x1000x30 input padded by 2 on
/// every side is used with N=2 and K=16; the output extents come from
/// ConvParams::GetOutputSpatialLengths(). Only the argument validation is
/// exercised; no kernel is launched.
TEST(Conv3DFwdNDHWC, OutputOver2GB)
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using namespace ck::utils;

    // >2GB Output
    conv::ConvParams params;
    params.num_dim_spatial_        = 3;
    params.N_                      = 2;
    params.K_                      = 16;
    params.C_                      = 2;
    params.filter_spatial_lengths_ = std::vector<ck::index_t>{1, 1, 1};
    params.input_spatial_lengths_  = std::vector<ck::index_t>{1000, 1000, 30};
    params.conv_filter_strides_    = std::vector<ck::index_t>{1, 1, 1};
    params.conv_filter_dilations_  = std::vector<ck::index_t>{1, 1, 1};
    params.input_left_pads_        = std::vector<ck::index_t>{2, 2, 2};
    params.input_right_pads_       = std::vector<ck::index_t>{2, 2, 2};

    std::vector<test::conv::DeviceConvFwdNoOpPtr> conv_ptrs;
    test::conv::get_test_convolution_fwd_instance<3>(conv_ptrs);

    // back() on an empty vector is undefined behaviour; fail the test cleanly
    // if the factory produced no instance instead of dereferencing garbage.
    ASSERT_FALSE(conv_ptrs.empty());
    const auto& conv_op = conv_ptrs.back();

    // The three tensor pointers are passed as null: the argument is only
    // handed to IsSupportedArgument below.
    auto arg = conv_op->MakeArgumentPointer(nullptr,
                                            nullptr,
                                            nullptr,
                                            params.N_,
                                            params.K_,
                                            params.C_,
                                            params.input_spatial_lengths_,
                                            params.filter_spatial_lengths_,
                                            params.GetOutputSpatialLengths(),
                                            params.conv_filter_strides_,
                                            params.conv_filter_dilations_,
                                            params.input_left_pads_,
                                            params.input_right_pads_,
                                            PassThrough{},
                                            PassThrough{},
                                            PassThrough{});
    EXPECT_FALSE(conv_op->IsSupportedArgument(arg.get()));
}

192
/// Feeds the instances returned by ConvolutionFwdInstances<bhalf_t,...>::Get<3>()
/// to the shared NDHWC helper and expects it to report success.
TEST(Conv3DFwdNDHWC, Bf16Instances)
{
    using DataType  = ck::bhalf_t;
    using Instances = ck::utils::conv::ConvolutionFwdInstances<DataType, DataType, DataType>;

    const auto& conv_ptrs = Instances::Get<3>();
    EXPECT_TRUE(test_conv3d_ndhwc_instances<DataType>(conv_ptrs));
}

198
/// Feeds the instances returned by ConvolutionFwdInstances<half_t,...>::Get<3>()
/// to the shared NDHWC helper and expects it to report success.
TEST(Conv3DFwdNDHWC, F16Instances)
{
    using DataType  = ck::half_t;
    using Instances = ck::utils::conv::ConvolutionFwdInstances<DataType, DataType, DataType>;

    const auto& conv_ptrs = Instances::Get<3>();
    EXPECT_TRUE(test_conv3d_ndhwc_instances<DataType>(conv_ptrs));
}

204
/// Feeds the instances returned by ConvolutionFwdInstances<float,...>::Get<3>()
/// to the shared NDHWC helper and expects it to report success.
TEST(Conv3DFwdNDHWC, F32Instances)
{
    using DataType  = float;
    using Instances = ck::utils::conv::ConvolutionFwdInstances<DataType, DataType, DataType>;

    const auto& conv_ptrs = Instances::Get<3>();
    EXPECT_TRUE(test_conv3d_ndhwc_instances<DataType>(conv_ptrs));
}

210
/// Feeds the instances returned by ConvolutionFwdInstances<int8_t,...>::Get<3>()
/// to the shared NDHWC helper and expects it to report success.
TEST(Conv3DFwdNDHWC, Int8Instances)
{
    using DataType  = int8_t;
    using Instances = ck::utils::conv::ConvolutionFwdInstances<DataType, DataType, DataType>;

    const auto& conv_ptrs = Instances::Get<3>();
    EXPECT_TRUE(test_conv3d_ndhwc_instances<DataType>(conv_ptrs));
}