"magic_pdf/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "7990e7dfbbf510e5de1f7e06c56827ad159e92c6"
conv_util.cpp 9.96 KB
Newer Older
1
2
3
#include <iostream>
#include <string>
#include <vector>
Adam Osewski's avatar
Adam Osewski committed
4
#include <gtest/gtest.h>
5
6

#include "config.hpp"
Adam Osewski's avatar
Adam Osewski committed
7
#include "conv_util.hpp"
8
#include "tensor_layout.hpp"
9
#include "check_err.hpp"
10
11
12

namespace {

13
class TestConvUtil : public ::testing::Test
14
{
15
16
17
    public:
    void SetNDParams(std::size_t ndims)
    {
Adam Osewski's avatar
Adam Osewski committed
18
19
20
21
22
23
24
        conv_params.num_dim_spatial_        = ndims;
        conv_params.filter_spatial_lengths_ = std::vector<ck::index_t>(ndims, 3);
        conv_params.input_spatial_lengths_  = std::vector<ck::index_t>(ndims, 71);
        conv_params.conv_filter_strides_    = std::vector<ck::index_t>(ndims, 2);
        conv_params.conv_filter_dilations_  = std::vector<ck::index_t>(ndims, 1);
        conv_params.input_left_pads_        = std::vector<ck::index_t>(ndims, 1);
        conv_params.input_right_pads_       = std::vector<ck::index_t>(ndims, 1);
25
26
27
28
    }

    protected:
    // -------  default 2D -------
29
30
31
32
33
    // input NCHW {128,192,71,71},
    // weights KCYX {256,192,3,3},
    // stride {2,2},
    // dilations {1,1},
    // padding {{1,1}, {1,1}}
34
    ck::utils::conv::ConvParams conv_params;
35
36
37
38
39
40
41
};

} // namespace

// Checks ConvParams::GetOutputSpatialLengths() for the default (2D)
// configuration, then for modified stride, padding and dilation. Each step
// keeps the parameters left by the previous step unless it overwrites them.
TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths2D)
{
    // Use the fixture's `conv_params` directly. It is default-constructed for
    // every test, and declaring a local with the same name would shadow it.
    std::vector<ck::index_t> out_spatial_len = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
                                     std::vector<ck::index_t>{36, 36},
                                     "Error: ConvParams 2D default constructor."));

    // Unit stride.
    conv_params.conv_filter_strides_ = std::vector<ck::index_t>{1, 1};
    out_spatial_len                  = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(
        out_spatial_len, std::vector<ck::index_t>{71, 71}, "Error: ConvParams 2D stride {1,1}."));

    // Stride 2 with padding 2 on both sides.
    conv_params.conv_filter_strides_ = std::vector<ck::index_t>{2, 2};
    conv_params.input_left_pads_     = std::vector<ck::index_t>{2, 2};
    conv_params.input_right_pads_    = std::vector<ck::index_t>{2, 2};
    out_spatial_len                  = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(out_spatial_len,
                                     std::vector<ck::index_t>{37, 37},
                                     "Error: ConvParams 2D padding left/right {2,2}."));

    // Add dilation 2 on top of the previous stride/padding.
    conv_params.conv_filter_dilations_ = std::vector<ck::index_t>{2, 2};
    out_spatial_len                    = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(
        out_spatial_len, std::vector<ck::index_t>{36, 36}, "Error: ConvParams 2D dilation {2,2}."));

    // Stride 3, padding 1, dilation 2.
    conv_params.conv_filter_strides_   = std::vector<ck::index_t>{3, 3};
    conv_params.input_left_pads_       = std::vector<ck::index_t>{1, 1};
    conv_params.input_right_pads_      = std::vector<ck::index_t>{1, 1};
    conv_params.conv_filter_dilations_ = std::vector<ck::index_t>{2, 2};
    out_spatial_len                    = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(
        ck::utils::check_err(out_spatial_len,
                             std::vector<ck::index_t>{23, 23},
                             "Error: ConvParams 2D strides{3,3}, padding {1,1}, dilations {2,2}."));
}

// 1D variant of the output-length checks. SetNDParams(1) configures input
// length 71, filter 3, stride 2, dilation 1 and pads 1/1. Each step below
// overwrites some parameters and keeps the rest from the previous step.
TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths1D)
{
    using Lengths = std::vector<ck::index_t>;

    SetNDParams(1);

    // Baseline configuration from SetNDParams(1).
    Lengths lengths = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(lengths, Lengths{36}, "Error: ConvParams 1D."));

    // Unit stride.
    conv_params.conv_filter_strides_ = Lengths{1};
    lengths                          = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(
        ck::utils::check_err(lengths, Lengths{71}, "Error: ConvParams 1D stride {1}."));

    // Stride 2 with padding 2 on both sides.
    conv_params.conv_filter_strides_ = Lengths{2};
    conv_params.input_left_pads_     = Lengths{2};
    conv_params.input_right_pads_    = Lengths{2};
    lengths                          = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(
        lengths, Lengths{37}, "Error: ConvParams 1D padding left/right {2}."));

    // Add dilation 2 on top of the previous stride/padding.
    conv_params.conv_filter_dilations_ = Lengths{2};
    lengths                            = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(
        ck::utils::check_err(lengths, Lengths{36}, "Error: ConvParams 1D dilation {2}."));

    // Stride 3, padding 1, dilation 2.
    conv_params.conv_filter_strides_   = Lengths{3};
    conv_params.input_left_pads_       = Lengths{1};
    conv_params.input_right_pads_      = Lengths{1};
    conv_params.conv_filter_dilations_ = Lengths{2};
    lengths                            = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(
        lengths, Lengths{23}, "Error: ConvParams 1D strides{3}, padding {1}, dilations {2}."));
}

// 3D variant of the output-length checks. SetNDParams(3) gives every spatial
// dimension input length 71, filter 3, stride 2, dilation 1 and pads 1/1.
// Each step below overwrites some parameters and keeps the rest from the
// previous step.
TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths3D)
{
    using Lengths = std::vector<ck::index_t>;

    SetNDParams(3);

    // Baseline configuration from SetNDParams(3).
    Lengths lengths = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(
        ck::utils::check_err(lengths, Lengths{36, 36, 36}, "Error: ConvParams 3D."));

    // Unit stride.
    conv_params.conv_filter_strides_ = Lengths{1, 1, 1};
    lengths                          = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(
        lengths, Lengths{71, 71, 71}, "Error: ConvParams 3D stride {1, 1, 1}."));

    // Stride 2 with padding 2 on both sides.
    conv_params.conv_filter_strides_ = Lengths{2, 2, 2};
    conv_params.input_left_pads_     = Lengths{2, 2, 2};
    conv_params.input_right_pads_    = Lengths{2, 2, 2};
    lengths                          = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(lengths,
                                     Lengths{37, 37, 37},
                                     "Error: ConvParams 3D padding left/right {2, 2, 2}."));

    // Add dilation 2 on top of the previous stride/padding.
    conv_params.conv_filter_dilations_ = Lengths{2, 2, 2};
    lengths                            = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(
        lengths, Lengths{36, 36, 36}, "Error: ConvParams 3D dilation {2, 2, 2}."));

    // Stride 3, padding 1, dilation 2.
    conv_params.conv_filter_strides_   = Lengths{3, 3, 3};
    conv_params.input_left_pads_       = Lengths{1, 1, 1};
    conv_params.input_right_pads_      = Lengths{1, 1, 1};
    conv_params.conv_filter_dilations_ = Lengths{2, 2, 2};
    lengths                            = conv_params.GetOutputSpatialLengths();
    EXPECT_TRUE(ck::utils::check_err(
        lengths,
        Lengths{23, 23, 23},
        "Error: ConvParams 3D strides{3, 3, 3}, padding {1, 1, 1}, dilations {2, 2, 2}."));
}

// get_host_tensor_descriptor() is given lengths in N, C, <spatial...> order
// plus a layout tag. As the expectations below show, the returned descriptor
// keeps the lengths in that order, while the strides follow the memory order
// of the requested layout (channels-last vs channels-first).
TEST(ConvUtil, GetHostTensorDescriptor)
{
    namespace tl = ck::tensor_layout::convolution;
    namespace cu = ck::utils::conv;

    // ---- 2D: N=2, C=3, H=4, W=5 ----
    std::vector<std::size_t> lengths_2d{2, 3, 4, 5};

    HostTensorDescriptor nhwc = cu::get_host_tensor_descriptor(lengths_2d, tl::NHWC{});
    EXPECT_TRUE(ck::utils::check_err(
        nhwc.GetLengths(), {2, 3, 4, 5}, "Error: wrong NHWC dimensions lengths!"));
    EXPECT_TRUE(ck::utils::check_err(
        nhwc.GetStrides(), {3 * 4 * 5, 1, 3 * 5, 3}, "Error: wrong NHWC dimensions strides!"));

    HostTensorDescriptor nchw = cu::get_host_tensor_descriptor(lengths_2d, tl::NCHW{});
    EXPECT_TRUE(ck::utils::check_err(
        nchw.GetLengths(), {2, 3, 4, 5}, "Error: wrong NCHW dimensions lengths!"));
    EXPECT_TRUE(ck::utils::check_err(
        nchw.GetStrides(), {3 * 4 * 5, 4 * 5, 5, 1}, "Error: wrong NCHW dimensions strides!"));

    // ---- 1D: N=2, C=3, W=4 ----
    std::vector<std::size_t> lengths_1d{2, 3, 4};

    HostTensorDescriptor nwc = cu::get_host_tensor_descriptor(lengths_1d, tl::NWC{});
    EXPECT_TRUE(
        ck::utils::check_err(nwc.GetLengths(), {2, 3, 4}, "Error: wrong NWC dimensions lengths!"));
    EXPECT_TRUE(ck::utils::check_err(
        nwc.GetStrides(), {3 * 4, 1, 3}, "Error: wrong NWC dimensions strides!"));

    HostTensorDescriptor ncw = cu::get_host_tensor_descriptor(lengths_1d, tl::NCW{});
    EXPECT_TRUE(
        ck::utils::check_err(ncw.GetLengths(), {2, 3, 4}, "Error: wrong NCW dimensions lengths!"));
    EXPECT_TRUE(ck::utils::check_err(
        ncw.GetStrides(), {3 * 4, 4, 1}, "Error: wrong NCW dimensions strides!"));

    // ---- 3D: N=2, C=3, D=4, H=5, W=6 ----
    std::vector<std::size_t> lengths_3d{2, 3, 4, 5, 6};

    HostTensorDescriptor ndhwc = cu::get_host_tensor_descriptor(lengths_3d, tl::NDHWC{});
    EXPECT_TRUE(ck::utils::check_err(
        ndhwc.GetLengths(), lengths_3d, "Error: wrong NDHWC dimensions lengths!"));
    EXPECT_TRUE(ck::utils::check_err(ndhwc.GetStrides(),
                                     {3 * 4 * 5 * 6, // N
                                      1,             // C
                                      3 * 5 * 6,     // D
                                      3 * 6,         // H
                                      3},            // W
                                     "Error: wrong NDHWC dimensions strides!"));

    HostTensorDescriptor ncdhw = cu::get_host_tensor_descriptor(lengths_3d, tl::NCDHW{});
    EXPECT_TRUE(ck::utils::check_err(
        ncdhw.GetLengths(), lengths_3d, "Error: wrong NCDHW dimensions lengths!"));
    EXPECT_TRUE(ck::utils::check_err(ncdhw.GetStrides(),
                                     {3 * 4 * 5 * 6, // N
                                      4 * 5 * 6,     // C
                                      5 * 6,         // D
                                      6,             // H
                                      1},            // W
                                     "Error: wrong NCDHW dimensions strides!"));
}