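// conv_util.cpp
// Unit tests for ck::utils::conv::ConvParams::GetOutputSpatialLengths() and
// ck::utils::conv::get_host_tensor_descriptor().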
#include <iostream>
#include <string>
#include <vector>
Adam Osewski's avatar
Adam Osewski committed
4
#include <gtest/gtest.h>
5

Chao Liu's avatar
Chao Liu committed
6
7
8
9
10
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/conv_util.hpp"
11
12
13

namespace {

14
class TestConvUtil : public ::testing::Test
15
{
16
17
18
    public:
    void SetNDParams(std::size_t ndims)
    {
Adam Osewski's avatar
Adam Osewski committed
19
20
21
22
23
24
25
        conv_params.num_dim_spatial_        = ndims;
        conv_params.filter_spatial_lengths_ = std::vector<ck::index_t>(ndims, 3);
        conv_params.input_spatial_lengths_  = std::vector<ck::index_t>(ndims, 71);
        conv_params.conv_filter_strides_    = std::vector<ck::index_t>(ndims, 2);
        conv_params.conv_filter_dilations_  = std::vector<ck::index_t>(ndims, 1);
        conv_params.input_left_pads_        = std::vector<ck::index_t>(ndims, 1);
        conv_params.input_right_pads_       = std::vector<ck::index_t>(ndims, 1);
26
27
28
29
    }

    protected:
    // -------  default 2D -------
30
31
32
33
34
    // input NCHW {128,192,71,71},
    // weights KCYX {256,192,3,3},
    // stride {2,2},
    // dilations {1,1},
    // padding {{1,1}, {1,1}}
35
    ck::utils::conv::ConvParams conv_params;
36
37
38
39
40
41
42
};

} // namespace

// Verifies ConvParams::GetOutputSpatialLengths() for two spatial dimensions.
//
// The test starts from a default-constructed ConvParams and then changes
// stride, padding and dilation step by step. Each step builds on the state
// left behind by the previous one, so the order of the steps matters.
TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths2D)
{
    using Lengths = std::vector<ck::index_t>;

    // GoogleTest builds a fresh fixture for every test, so the fixture's
    // conv_params is still default-constructed here. It is used directly
    // rather than declaring a local of the same name that would shadow it.
    Lengths out_spatial_len = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
                                     Lengths{36, 36},
                                     "Error: ConvParams 2D default constructor."));

    // Unit stride.
    conv_params.conv_filter_strides_ = Lengths{1, 1};
    out_spatial_len                  = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(
        out_spatial_len, Lengths{71, 71}, "Error: ConvParams 2D stride {1,1}."));

    // Back to stride 2, with wider padding on both sides.
    conv_params.conv_filter_strides_ = Lengths{2, 2};
    conv_params.input_left_pads_     = Lengths{2, 2};
    conv_params.input_right_pads_    = Lengths{2, 2};
    out_spatial_len                  = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
                                     Lengths{37, 37},
                                     "Error: ConvParams 2D padding left/right {2,2}."));

    // Same stride and padding as the previous step, now with dilation 2.
    conv_params.conv_filter_dilations_ = Lengths{2, 2};
    out_spatial_len                    = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(
        out_spatial_len, Lengths{36, 36}, "Error: ConvParams 2D dilation {2,2}."));

    // Stride, padding and dilation all set explicitly in one step.
    conv_params.conv_filter_strides_   = Lengths{3, 3};
    conv_params.input_left_pads_       = Lengths{1, 1};
    conv_params.input_right_pads_      = Lengths{1, 1};
    conv_params.conv_filter_dilations_ = Lengths{2, 2};
    out_spatial_len                    = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(
        ck::utils::check_err(out_spatial_len,
                             Lengths{23, 23},
                             "Error: ConvParams 2D strides{3,3}, padding {1,1}, dilations {2,2}."));
}
76

77
78
79
// Verifies ConvParams::GetOutputSpatialLengths() for one spatial dimension.
//
// SetNDParams(1) gives the baseline (filter 3, input 71, stride 2, dilation 1,
// padding 1). The following steps change stride, padding and dilation; each
// step builds on the state left behind by the previous one.
TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths1D)
{
    using Lengths = std::vector<ck::index_t>;

    SetNDParams(1);

    // Baseline configuration.
    Lengths out_spatial_len = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(out_spatial_len, Lengths{36}, "Error: ConvParams 1D."));

    // Unit stride.
    conv_params.conv_filter_strides_ = Lengths{1};
    out_spatial_len                  = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(
        out_spatial_len, Lengths{71}, "Error: ConvParams 1D stride {1}."));

    // Back to stride 2, with wider padding on both sides.
    conv_params.conv_filter_strides_ = Lengths{2};
    conv_params.input_left_pads_     = Lengths{2};
    conv_params.input_right_pads_    = Lengths{2};
    out_spatial_len                  = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
                                     Lengths{37},
                                     "Error: ConvParams 1D padding left/right {2}."));

    // Same stride and padding as the previous step, now with dilation 2.
    conv_params.conv_filter_dilations_ = Lengths{2};
    out_spatial_len                    = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(
        out_spatial_len, Lengths{36}, "Error: ConvParams 1D dilation {2}."));

    // Stride, padding and dilation all set explicitly in one step.
    conv_params.conv_filter_strides_   = Lengths{3};
    conv_params.input_left_pads_       = Lengths{1};
    conv_params.input_right_pads_      = Lengths{1};
    conv_params.conv_filter_dilations_ = Lengths{2};
    out_spatial_len                    = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(
        ck::utils::check_err(out_spatial_len,
                             Lengths{23},
                             "Error: ConvParams 1D strides{3}, padding {1}, dilations {2}."));
}

// Verifies ConvParams::GetOutputSpatialLengths() for three spatial dimensions.
//
// SetNDParams(3) gives the baseline (filter 3, input 71, stride 2, dilation 1,
// padding 1 in every dimension). The following steps change stride, padding
// and dilation; each step builds on the state left behind by the previous one.
TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths3D)
{
    using Lengths = std::vector<ck::index_t>;

    SetNDParams(3);

    // Baseline configuration.
    Lengths out_spatial_len = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(
        ck::utils::check_err(out_spatial_len, Lengths{36, 36, 36}, "Error: ConvParams 3D."));

    // Unit stride.
    conv_params.conv_filter_strides_ = Lengths{1, 1, 1};
    out_spatial_len                  = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
                                     Lengths{71, 71, 71},
                                     "Error: ConvParams 3D stride {1, 1, 1}."));

    // Back to stride 2, with wider padding on both sides.
    conv_params.conv_filter_strides_ = Lengths{2, 2, 2};
    conv_params.input_left_pads_     = Lengths{2, 2, 2};
    conv_params.input_right_pads_    = Lengths{2, 2, 2};
    out_spatial_len                  = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
                                     Lengths{37, 37, 37},
                                     "Error: ConvParams 3D padding left/right {2, 2, 2}."));

    // Same stride and padding as the previous step, now with dilation 2.
    conv_params.conv_filter_dilations_ = Lengths{2, 2, 2};
    out_spatial_len                    = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
                                     Lengths{36, 36, 36},
                                     "Error: ConvParams 3D dilation {2, 2, 2}."));

    // Stride, padding and dilation all set explicitly in one step.
    conv_params.conv_filter_strides_   = Lengths{3, 3, 3};
    conv_params.input_left_pads_       = Lengths{1, 1, 1};
    conv_params.input_right_pads_      = Lengths{1, 1, 1};
    conv_params.conv_filter_dilations_ = Lengths{2, 2, 2};
    out_spatial_len                    = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(
        out_spatial_len,
        Lengths{23, 23, 23},
        "Error: ConvParams 3D strides{3, 3, 3}, padding {1, 1, 1}, dilations {2, 2, 2}."));
}

153
// Verifies the lengths and strides reported by the HostTensorDescriptor that
// ck::utils::conv::get_host_tensor_descriptor() builds for each layout tag.
//
// The same `dims` vector is passed for the channels-last and channels-first
// variant of each rank. The expected lengths equal `dims` in both cases; only
// the expected strides differ between the two layouts.
TEST(ConvUtil, GetHostTensorDescriptor)
{
    namespace tl = ck::tensor_layout::convolution;

    // ---- 4 dimensions: NHWC and NCHW ----
    std::vector<std::size_t> dims{2, 3, 4, 5};
    HostTensorDescriptor h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NHWC{});
    EXPECT_TRUE(ck::utils::check_err(
        h.GetLengths(), {2, 3, 4, 5}, "Error: wrong NHWC dimensions lengths!"));
    EXPECT_TRUE(ck::utils::check_err(
        h.GetStrides(), {3 * 4 * 5, 1, 3 * 5, 3}, "Error: wrong NHWC dimensions strides!"));

    h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NCHW{});
    EXPECT_TRUE(ck::utils::check_err(
        h.GetLengths(), {2, 3, 4, 5}, "Error: wrong NCHW dimensions lengths!"));
    EXPECT_TRUE(ck::utils::check_err(
        h.GetStrides(), {3 * 4 * 5, 4 * 5, 5, 1}, "Error: wrong NCHW dimensions strides!"));

    // ---- 3 dimensions: NWC and NCW ----
    dims = std::vector<std::size_t>{2, 3, 4};
    h    = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NWC{});
    EXPECT_TRUE(
        ck::utils::check_err(h.GetLengths(), {2, 3, 4}, "Error: wrong NWC dimensions lengths!"));
    EXPECT_TRUE(ck::utils::check_err(
        h.GetStrides(), {3 * 4, 1, 3}, "Error: wrong NWC dimensions strides!"));

    h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NCW{});
    EXPECT_TRUE(
        ck::utils::check_err(h.GetLengths(), {2, 3, 4}, "Error: wrong NCW dimensions lengths!"));
    EXPECT_TRUE(ck::utils::check_err(
        h.GetStrides(), {3 * 4, 4, 1}, "Error: wrong NCW dimensions strides!"));

    // ---- 5 dimensions: NDHWC and NCDHW ----
    dims = std::vector<std::size_t>{2, 3, 4, 5, 6};
    h    = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NDHWC{});
    EXPECT_TRUE(
        ck::utils::check_err(h.GetLengths(), dims, "Error: wrong NDHWC dimensions lengths!"));
    EXPECT_TRUE(ck::utils::check_err(h.GetStrides(),
                                     {3 * 4 * 5 * 6, // N
                                      1,             // C
                                      3 * 5 * 6,     // D
                                      3 * 6,         // H
                                      3},            // W
                                     "Error: wrong NDHWC dimensions strides!"));

    h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NCDHW{});
    EXPECT_TRUE(
        ck::utils::check_err(h.GetLengths(), dims, "Error: wrong NCDHW dimensions lengths!"));
    EXPECT_TRUE(ck::utils::check_err(h.GetStrides(),
                                     {3 * 4 * 5 * 6, // N
                                      4 * 5 * 6,     // C
                                      5 * 6,         // D
                                      6,             // H
                                      1},            // W
                                     "Error: wrong NCDHW dimensions strides!"));
}