"...resnet50_tensorflow.git" did not exist on "f3e7cc255cf0064b1cdbb4f10bc619b62cbb2630"
conv_util.cpp 10.6 KB
Newer Older
1
2
3
4
5
#include <iostream>
#include <string>
#include <vector>

#include "config.hpp"
6
#include "conv_fwd_util.hpp"
7
#include "tensor_layout.hpp"
8
#include "check_err.hpp"
9
10
11

namespace {

12
bool test_conv_params_get_output_spatial_lengths()
13
14
15
16
17
18
19
20
{
    bool res{true};
    // -------------------------- default 2D ------------------------------------
    // input NCHW {128,192,71,71},
    // weights KCYX {256,192,3,3},
    // stride {2,2},
    // dilations {1,1},
    // padding {{1,1}, {1,1}}
21
    ck::utils::conv::ConvParams conv_params;
22
    std::vector<ck::index_t> out_spatial_len = conv_params.GetOutputSpatialLengths();
23
24
25
    res                                      = ck::utils::check_err(out_spatial_len,
                               std::vector<ck::index_t>{36, 36},
                               "Error: ConvParams 2D default constructor.");
26
27
28

    conv_params.conv_filter_strides = std::vector<ck::index_t>{1, 1};
    out_spatial_len                 = conv_params.GetOutputSpatialLengths();
29
    res                             = ck::utils::check_err(
30
31
32
33
34
35
        out_spatial_len, std::vector<ck::index_t>{71, 71}, "Error: ConvParams 2D stride {1,1}.");

    conv_params.conv_filter_strides = std::vector<ck::index_t>{2, 2};
    conv_params.input_left_pads     = std::vector<ck::index_t>{2, 2};
    conv_params.input_right_pads    = std::vector<ck::index_t>{2, 2};
    out_spatial_len                 = conv_params.GetOutputSpatialLengths();
36
37
38
    res                             = ck::utils::check_err(out_spatial_len,
                               std::vector<ck::index_t>{37, 37},
                               "Error: ConvParams 2D padding left/right {2,2}.");
39
40
41

    conv_params.conv_filter_dilations = std::vector<ck::index_t>{2, 2};
    out_spatial_len                   = conv_params.GetOutputSpatialLengths();
42
    res                               = ck::utils::check_err(
43
44
45
46
47
48
49
        out_spatial_len, std::vector<ck::index_t>{36, 36}, "Error: ConvParams 2D dilation {2,2}.");

    conv_params.conv_filter_strides   = std::vector<ck::index_t>{3, 3};
    conv_params.input_left_pads       = std::vector<ck::index_t>{1, 1};
    conv_params.input_right_pads      = std::vector<ck::index_t>{1, 1};
    conv_params.conv_filter_dilations = std::vector<ck::index_t>{2, 2};
    out_spatial_len                   = conv_params.GetOutputSpatialLengths();
50
51
52
53
    res =
        ck::utils::check_err(out_spatial_len,
                             std::vector<ck::index_t>{23, 23},
                             "Error: ConvParams 2D strides{3,3}, padding {1,1}, dilations {2,2}.");
54
55
56
57
58
59
60
61
62
63
64

    // -------------------------- 1D ------------------------------------
    conv_params.num_dim_spatial        = 1;
    conv_params.filter_spatial_lengths = std::vector<ck::index_t>{3};
    conv_params.input_spatial_lengths  = std::vector<ck::index_t>{71};
    conv_params.conv_filter_strides    = std::vector<ck::index_t>{2};
    conv_params.conv_filter_dilations  = std::vector<ck::index_t>{1};
    conv_params.input_left_pads        = std::vector<ck::index_t>{1};
    conv_params.input_right_pads       = std::vector<ck::index_t>{1};

    out_spatial_len = conv_params.GetOutputSpatialLengths();
65
66
    res             = ck::utils::check_err(
        out_spatial_len, std::vector<ck::index_t>{36}, "Error: ConvParams 1D.");
67

68
    conv_params.conv_filter_strides = std::vector<ck::index_t>{1};
69
    out_spatial_len                 = conv_params.GetOutputSpatialLengths();
70
    res                             = ck::utils::check_err(
71
        out_spatial_len, std::vector<ck::index_t>{71}, "Error: ConvParams 1D stride {1}.");
72
73
74
75
76

    conv_params.conv_filter_strides = std::vector<ck::index_t>{2};
    conv_params.input_left_pads     = std::vector<ck::index_t>{2};
    conv_params.input_right_pads    = std::vector<ck::index_t>{2};
    out_spatial_len                 = conv_params.GetOutputSpatialLengths();
77
78
79
    res                             = ck::utils::check_err(out_spatial_len,
                               std::vector<ck::index_t>{37},
                               "Error: ConvParams 1D padding left/right {2}.");
80
81
82

    conv_params.conv_filter_dilations = std::vector<ck::index_t>{2};
    out_spatial_len                   = conv_params.GetOutputSpatialLengths();
83
    res                               = ck::utils::check_err(
84
85
86
87
88
89
90
        out_spatial_len, std::vector<ck::index_t>{36}, "Error: ConvParams 1D dilation {2}.");

    conv_params.conv_filter_strides   = std::vector<ck::index_t>{3};
    conv_params.input_left_pads       = std::vector<ck::index_t>{1};
    conv_params.input_right_pads      = std::vector<ck::index_t>{1};
    conv_params.conv_filter_dilations = std::vector<ck::index_t>{2};
    out_spatial_len                   = conv_params.GetOutputSpatialLengths();
91
92
93
    res                               = ck::utils::check_err(out_spatial_len,
                               std::vector<ck::index_t>{23},
                               "Error: ConvParams 1D strides{3}, padding {1}, dilations {2}.");
94
95
96
97
98
99
100
101
102
103
104

    // -------------------------- 3D ------------------------------------
    conv_params.num_dim_spatial        = 3;
    conv_params.filter_spatial_lengths = std::vector<ck::index_t>{3, 3, 3};
    conv_params.input_spatial_lengths  = std::vector<ck::index_t>{71, 71, 71};
    conv_params.conv_filter_strides    = std::vector<ck::index_t>{2, 2, 2};
    conv_params.conv_filter_dilations  = std::vector<ck::index_t>{1, 1, 1};
    conv_params.input_left_pads        = std::vector<ck::index_t>{1, 1, 1};
    conv_params.input_right_pads       = std::vector<ck::index_t>{1, 1, 1};

    out_spatial_len = conv_params.GetOutputSpatialLengths();
105
    res             = ck::utils::check_err(
106
107
108
109
        out_spatial_len, std::vector<ck::index_t>{36, 36, 36}, "Error: ConvParams 3D.");

    conv_params.conv_filter_strides = std::vector<ck::index_t>{1, 1, 1};
    out_spatial_len                 = conv_params.GetOutputSpatialLengths();
110
111
112
    res                             = ck::utils::check_err(out_spatial_len,
                               std::vector<ck::index_t>{71, 71, 71},
                               "Error: ConvParams 3D stride {1, 1, 1}.");
113
114
115
116
117

    conv_params.conv_filter_strides = std::vector<ck::index_t>{2, 2, 2};
    conv_params.input_left_pads     = std::vector<ck::index_t>{2, 2, 2};
    conv_params.input_right_pads    = std::vector<ck::index_t>{2, 2, 2};
    out_spatial_len                 = conv_params.GetOutputSpatialLengths();
118
119
120
    res                             = ck::utils::check_err(out_spatial_len,
                               std::vector<ck::index_t>{37, 37, 37},
                               "Error: ConvParams 3D padding left/right {2, 2, 2}.");
121
122
123

    conv_params.conv_filter_dilations = std::vector<ck::index_t>{2, 2, 2};
    out_spatial_len                   = conv_params.GetOutputSpatialLengths();
124
125
126
    res                               = ck::utils::check_err(out_spatial_len,
                               std::vector<ck::index_t>{36, 36, 36},
                               "Error: ConvParams 3D dilation {2, 2, 2}.");
127
128
129
130
131
132

    conv_params.conv_filter_strides   = std::vector<ck::index_t>{3, 3, 3};
    conv_params.input_left_pads       = std::vector<ck::index_t>{1, 1, 1};
    conv_params.input_right_pads      = std::vector<ck::index_t>{1, 1, 1};
    conv_params.conv_filter_dilations = std::vector<ck::index_t>{2, 2, 2};
    out_spatial_len                   = conv_params.GetOutputSpatialLengths();
133
    res                               = ck::utils::check_err(
134
135
136
        out_spatial_len,
        std::vector<ck::index_t>{23, 23, 23},
        "Error: ConvParams 3D strides{3, 3, 3}, padding {1, 1, 1}, dilations {2, 2, 2}.");
137
138
139
140

    return res;
}

141
bool test_get_host_tensor_descriptor()
142
143
144
145
{
    bool res{true};
    namespace tl = ck::tensor_layout::convolution;
    std::vector<std::size_t> dims{2, 3, 4, 5};
146
147
148
149
    HostTensorDescriptor h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NHWC{});
    res =
        ck::utils::check_err(h.GetLengths(), {2, 3, 4, 5}, "Error: wrong NHWC dimensions lengths!");
    res = ck::utils::check_err(
150
        h.GetStrides(), {3 * 4 * 5, 1, 3 * 5, 3}, "Error: wrong NHWC dimensions strides!");
151

152
153
154
155
    h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NCHW{});
    res =
        ck::utils::check_err(h.GetLengths(), {2, 3, 4, 5}, "Error: wrong NCHW dimensions lengths!");
    res = ck::utils::check_err(
156
        h.GetStrides(), {3 * 4 * 5, 4 * 5, 5, 1}, "Error: wrong NCHW dimensions strides!");
157
158

    dims = std::vector<std::size_t>{2, 3, 4};
159
160
161
162
    h    = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NWC{});
    res  = ck::utils::check_err(h.GetLengths(), {2, 3, 4}, "Error: wrong NWC dimensions lengths!");
    res =
        ck::utils::check_err(h.GetStrides(), {3 * 4, 1, 3}, "Error: wrong NWC dimensions strides!");
163

164
165
166
167
    h   = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NCW{});
    res = ck::utils::check_err(h.GetLengths(), {2, 3, 4}, "Error: wrong NCW dimensions lengths!");
    res =
        ck::utils::check_err(h.GetStrides(), {3 * 4, 4, 1}, "Error: wrong NCW dimensions strides!");
168
169

    dims = std::vector<std::size_t>{2, 3, 4, 5, 6};
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
    h    = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NDHWC{});
    res  = ck::utils::check_err(h.GetLengths(), dims, "Error: wrong NDHWC dimensions lengths!");
    res  = ck::utils::check_err(h.GetStrides(),
                               {3 * 4 * 5 * 6, // N
                                1,             // C
                                3 * 5 * 6,     // D
                                3 * 6,         // H
                                3},            // W
                               "Error: wrong NDHWC dimensions strides!");

    h   = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NCDHW{});
    res = ck::utils::check_err(h.GetLengths(), dims, "Error: wrong NCDHW dimensions lengths!");
    res = ck::utils::check_err(h.GetStrides(),
                               {3 * 4 * 5 * 6, // N
                                4 * 5 * 6,     // C
                                5 * 6,         // D
                                6,             // H
                                1},            // W
                               "Error: wrong NCDHW dimensions strides!");
189
190
191
192
193
194
195
196

    return res;
}

} // namespace

int main(void)
{
197
198
199
200
201
    bool res = test_conv_params_get_output_spatial_lengths();
    std::cout << "test_conv_params_get_output_spatial_lengths ..... "
              << (res ? "SUCCESS" : "FAILURE") << std::endl;
    res = test_get_host_tensor_descriptor();
    std::cout << "test_get_host_tensor_descriptor ..... " << (res ? "SUCCESS" : "FAILURE")
202
              << std::endl;
203
    return res ? 0 : 1;
204
}