driver.cpp 36.9 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
#include <iostream>
Chao Liu's avatar
Chao Liu committed
2
3
#include <numeric>
#include <initializer_list>
Chao Liu's avatar
Chao Liu committed
4
#include <cstdlib>
Chao Liu's avatar
Chao Liu committed
5
#include <stdlib.h>
Chao Liu's avatar
Chao Liu committed
6
7
8
#include "config.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "device.hpp"
Chao Liu's avatar
Chao Liu committed
9
#include "conv_common.hpp"
Chao Liu's avatar
Chao Liu committed
10
#include "device_convolution_direct_v2_nchw_kcyx_nkhw.hpp"
11
12
13
14
15
16
// #include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp"
// #include "device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp"
// #include "device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp"
// #include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
//#include "device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v5_nchw_kcyx_nkhw.hpp"
Chao Liu's avatar
Chao Liu committed
17

18
19
using namespace ck;

20
21
22
23
#define CONV_DIRECTION_FWD_DATA 0
#define CONV_DIRECTION_BWD_DATA 0
#define CONV_DIRECTION_BWD_WEIT 1

Chao Liu's avatar
Chao Liu committed
24
// Generator that produces the constant 1 for every coordinate.
struct GeneratorTensor_1
{
    // Accepts any number of index arguments and ignores all of them.
    template <class... Is>
    double operator()(Is...) const
    {
        return 1;
    }
};

Chao Liu's avatar
Chao Liu committed
33
34
35
// Generator that produces pseudo-random integers obtained from std::rand().
// With min_value < max_value every result lies in [min_value, max_value).
struct GeneratorTensor_2
{
    int min_value = 0;
    int max_value = 16;

    // The index arguments are ignored; the result depends only on std::rand()
    // and on the configured range.
    template <class... Is>
    double operator()(Is...) const
    {
        const int span = max_value - min_value;

        // An empty range would make the modulo below divide by zero, which is
        // undefined behaviour; min_value is the only value such a range can denote.
        if(span == 0)
        {
            return min_value;
        }

        return (std::rand() % span) + min_value;
    }
};

45
46
47
48
49
50
51
struct GeneratorTensor_3
{
    template <class... Is>
    double operator()(Is... is)
    {
        std::array<index_t, sizeof...(Is)> dims = {{static_cast<index_t>(is)...}};

52
        auto f_acc = [](auto a, auto b) { return 100 * a + b; };
53

54
        return std::accumulate(dims.begin(), dims.end(), index_t(0), f_acc);
55
56
    }
};
57
58
59
60
61
62
63
64
65
66
67
68
69
70
struct GeneratorTensor_fixed
{
    template <class... Is>
    double operator()(Is... is)
    {
        std::array<index_t, sizeof...(Is)> dims = {{static_cast<index_t>(is)...}};

        if(dims[0] == 0)
            return (dims[1]*16 + dims[2]*4 + dims[3]);
        else
            return 1;
    }
};

71

Chao Liu's avatar
Chao Liu committed
72
73
74
75
76
struct GeneratorTensor_Checkboard
{
    template <class... Ts>
    double operator()(Ts... Xs) const
    {
77
        std::array<index_t, sizeof...(Ts)> dims = {{static_cast<index_t>(Xs)...}};
Chao Liu's avatar
Chao Liu committed
78
79
80
        return std::accumulate(dims.begin(),
                               dims.end(),
                               true,
Chao Liu's avatar
Chao Liu committed
81
                               [](bool init, index_t x) -> int { return init != (x % 2); })
Chao Liu's avatar
Chao Liu committed
82
83
84
85
86
                   ? 1
                   : -1;
    }
};

Chao Liu's avatar
Chao Liu committed
87
88
89
90
91
92
// Prints the lengths and strides of a compile-time tensor descriptor as
//   Lengths: {l0, l1, l2, l3}, Strides: {s0, s1, s2, s3}
// followed by std::endl, to `os` (std::cout by default).
// this is ugly, only for 4d
template <class TConstTensorDesc>
void ostream_ConstantTensorDescriptor(TConstTensorDesc, std::ostream& os = std::cout)
{
    static_assert(TConstTensorDesc::nDim == 4, "nDim is not 4");

    // Compile-time index tags, one per dimension.
    constexpr auto I0   = Number<0>{};
    constexpr auto I1   = Number<1>{};
    constexpr auto I2   = Number<2>{};
    constexpr auto I3   = Number<3>{};
    constexpr auto desc = TConstTensorDesc{};

    os << "Lengths: {" << desc.GetLength(I0) << ", " << desc.GetLength(I1) << ", "
       << desc.GetLength(I2) << ", " << desc.GetLength(I3) << "}, "
       << "Strides: {" << desc.GetStride(I0) << ", " << desc.GetStride(I1) << ", "
       << desc.GetStride(I2) << ", " << desc.GetStride(I3) << "}" << std::endl;
}

// this is ugly, only for 4d
template <class TConstTensorDesc>
auto make_TensorDescriptor(TConstTensorDesc)
{
    static_assert(TConstTensorDesc::nDim == 4, "nDim is not 4");

Chao Liu's avatar
Chao Liu committed
111
112
113
114
    constexpr auto I0   = Number<0>{};
    constexpr auto I1   = Number<1>{};
    constexpr auto I2   = Number<2>{};
    constexpr auto I3   = Number<3>{};
Chao Liu's avatar
Chao Liu committed
115
116
    constexpr auto desc = TConstTensorDesc{};

Chao Liu's avatar
Chao Liu committed
117
    std::initializer_list<index_t> lengths = {
Chao Liu's avatar
Chao Liu committed
118
        desc.GetLength(I0), desc.GetLength(I1), desc.GetLength(I2), desc.GetLength(I3)};
Chao Liu's avatar
Chao Liu committed
119
    std::initializer_list<index_t> strides = {
Chao Liu's avatar
Chao Liu committed
120
121
122
123
124
        desc.GetStride(I0), desc.GetStride(I1), desc.GetStride(I2), desc.GetStride(I3)};

    return TensorDescriptor(lengths, strides);
}

125
126
127
128
129
130
131
template <class TIn,
          class TWei,
          class TOut,
          class ConvStrides,
          class ConvDilations,
          class LowerPads,
          class UpperPads>
132
133
134
void host_direct_convolution(const Tensor<TIn>& in_nchw,
                             const Tensor<TWei>& wei_kcyx,
                             Tensor<TOut>& out_nkhw,
135
136
                             ConvStrides,
                             ConvDilations,
137
138
                             LowerPads,
                             UpperPads)
Chao Liu's avatar
Chao Liu committed
139
{
Chao Liu's avatar
Chao Liu committed
140
141
    index_t h_pad_low = LowerPads{}.Get(Number<0>{});
    index_t w_pad_low = LowerPads{}.Get(Number<1>{});
142

Chao Liu's avatar
Chao Liu committed
143
144
    index_t h_pad_up = UpperPads{}.Get(Number<0>{});
    index_t w_pad_up = UpperPads{}.Get(Number<1>{});
145

Chao Liu's avatar
Chao Liu committed
146
147
    auto f = [&](auto n, auto k, auto ho, auto wo) {
        double v = 0;
Chao Liu's avatar
Chao Liu committed
148
        for(int c = 0; c < wei_kcyx.mDesc.GetLengths()[1]; ++c)
Chao Liu's avatar
Chao Liu committed
149
        {
Chao Liu's avatar
Chao Liu committed
150
            for(int y = 0; y < wei_kcyx.mDesc.GetLengths()[2]; ++y)
Chao Liu's avatar
Chao Liu committed
151
            {
152
                int hi = ho * ConvStrides{}[0] + y * ConvDilations{}[0] - h_pad_low;
Chao Liu's avatar
Chao Liu committed
153
                for(int x = 0; x < wei_kcyx.mDesc.GetLengths()[3]; ++x)
Chao Liu's avatar
Chao Liu committed
154
                {
155
                    int wi = wo * ConvStrides{}[1] + x * ConvDilations{}[1] - w_pad_low;
156
157
158
                    if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 &&
                       wi < in_nchw.mDesc.GetLengths()[3])
                    {
159
                        v += double(in_nchw(n, c, hi, wi)) * double(wei_kcyx(k, c, y, x));
160
                    }
Chao Liu's avatar
Chao Liu committed
161
162
163
                }
            }
        }
164
        out_nkhw(n, k, ho, wo) = v;
Chao Liu's avatar
Chao Liu committed
165
166
167
    };

    auto f_par = make_ParallelTensorFunctor(f,
168
169
170
171
                                            out_nkhw.mDesc.GetLengths()[0],
                                            out_nkhw.mDesc.GetLengths()[1],
                                            out_nkhw.mDesc.GetLengths()[2],
                                            out_nkhw.mDesc.GetLengths()[3]);
Chao Liu's avatar
Chao Liu committed
172

Chao Liu's avatar
Chao Liu committed
173
    f_par(std::thread::hardware_concurrency());
Chao Liu's avatar
Chao Liu committed
174
175
}

176
177
178
179
180
181
template <class TIn, class TWei, class TOut, class LowerPads, class UpperPads>
void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
                                   const Tensor<TWei>& wei_kcyx,
                                   Tensor<TOut>& out_nkhw,
                                   LowerPads,
                                   UpperPads)
Chao Liu's avatar
Chao Liu committed
182
{
Chao Liu's avatar
Chao Liu committed
183
184
    constexpr std::size_t HoPerTile = 2;
    constexpr std::size_t WoPerTile = 2;
Chao Liu's avatar
Chao Liu committed
185

Chao Liu's avatar
Chao Liu committed
186
187
188
189
    std::size_t N  = in_nchw.mDesc.GetLengths()[0];
    std::size_t C  = in_nchw.mDesc.GetLengths()[1];
    std::size_t HI = in_nchw.mDesc.GetLengths()[2];
    std::size_t WI = in_nchw.mDesc.GetLengths()[3];
Chao Liu's avatar
Chao Liu committed
190

Chao Liu's avatar
Chao Liu committed
191
192
193
    std::size_t K = wei_kcyx.mDesc.GetLengths()[0];
    std::size_t Y = wei_kcyx.mDesc.GetLengths()[2];
    std::size_t X = wei_kcyx.mDesc.GetLengths()[3];
Chao Liu's avatar
Chao Liu committed
194

195
196
    std::size_t HO = out_nkhw.mDesc.GetLengths()[2];
    std::size_t WO = out_nkhw.mDesc.GetLengths()[3];
Chao Liu's avatar
Chao Liu committed
197

Chao Liu's avatar
Chao Liu committed
198
199
    index_t h_pad_low = LowerPads{}.Get(Number<0>{});
    index_t w_pad_low = LowerPads{}.Get(Number<1>{});
200

Chao Liu's avatar
Chao Liu committed
201
202
    index_t h_pad_up = UpperPads{}.Get(Number<0>{});
    index_t w_pad_up = UpperPads{}.Get(Number<1>{});
203

Chao Liu's avatar
Chao Liu committed
204
205
    std::size_t HiPerTile = HoPerTile + Y - 1;
    std::size_t WiPerTile = WoPerTile + X - 1;
Chao Liu's avatar
Chao Liu committed
206

Chao Liu's avatar
Chao Liu committed
207
208
    std::size_t HTile = (HO + HoPerTile - 1) / HoPerTile;
    std::size_t WTile = (WO + WoPerTile - 1) / WoPerTile;
Chao Liu's avatar
Chao Liu committed
209

210
211
212
213
214
    Tensor<double> in_hold({N, C, HTile, WTile, HiPerTile, WiPerTile});
    Tensor<double> in_transform({N, C, HTile, WTile, HiPerTile, WiPerTile});
    Tensor<double> wei_transform({K, C, HiPerTile, WiPerTile});
    Tensor<double> out_transform({N, K, HTile, WTile, HiPerTile, HiPerTile});
    Tensor<double> out_hold({N, K, HTile, WTile, HoPerTile, WoPerTile});
Chao Liu's avatar
Chao Liu committed
215

Chao Liu's avatar
Chao Liu committed
216
217
    auto f_in_hold = [&](auto n, auto c, auto htile, auto wtile) {
        for(int j = 0; j < HiPerTile; ++j)
Chao Liu's avatar
Chao Liu committed
218
        {
Chao Liu's avatar
Chao Liu committed
219
220
            int hi = HoPerTile * htile + j - h_pad_low;
            for(int i = 0; i < WiPerTile; ++i)
Chao Liu's avatar
Chao Liu committed
221
            {
Chao Liu's avatar
Chao Liu committed
222
                int wi = WoPerTile * wtile + i - w_pad_low;
223
224
225
226

                if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 &&
                   wi < in_nchw.mDesc.GetLengths()[3])
                {
Chao Liu's avatar
Chao Liu committed
227
                    in_hold(n, c, htile, wtile, j, i) = in_nchw(n, c, hi, wi);
228
229
230
                }
                else
                {
231
                    in_hold(n, c, htile, wtile, j, i) = TIn(0);
232
                }
Chao Liu's avatar
Chao Liu committed
233
234
235
236
            }
        }
    };

Chao Liu's avatar
Chao Liu committed
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
    auto f_in_transform = [&](auto n, auto c, auto htile, auto wtile) {
        in_transform(n, c, htile, wtile, 0, 0) =
            in_hold(n, c, htile, wtile, 0, 0) - in_hold(n, c, htile, wtile, 0, 2) -
            in_hold(n, c, htile, wtile, 2, 0) + in_hold(n, c, htile, wtile, 2, 2);
        in_transform(n, c, htile, wtile, 0, 1) =
            in_hold(n, c, htile, wtile, 0, 1) + in_hold(n, c, htile, wtile, 0, 2) -
            in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 2);
        in_transform(n, c, htile, wtile, 0, 2) =
            -in_hold(n, c, htile, wtile, 0, 1) + in_hold(n, c, htile, wtile, 0, 2) +
            in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 2);
        in_transform(n, c, htile, wtile, 0, 3) =
            in_hold(n, c, htile, wtile, 0, 1) - in_hold(n, c, htile, wtile, 0, 3) -
            in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 3);

        in_transform(n, c, htile, wtile, 1, 0) =
            in_hold(n, c, htile, wtile, 1, 0) - in_hold(n, c, htile, wtile, 1, 2) +
            in_hold(n, c, htile, wtile, 2, 0) - in_hold(n, c, htile, wtile, 2, 2);
        in_transform(n, c, htile, wtile, 1, 1) =
            in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) +
            in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2);
        in_transform(n, c, htile, wtile, 1, 2) =
            -in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) -
            in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2);
        in_transform(n, c, htile, wtile, 1, 3) =
            in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 3) +
            in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 3);

        in_transform(n, c, htile, wtile, 2, 0) =
            -in_hold(n, c, htile, wtile, 1, 0) + in_hold(n, c, htile, wtile, 1, 2) +
            in_hold(n, c, htile, wtile, 2, 0) - in_hold(n, c, htile, wtile, 2, 2);
        in_transform(n, c, htile, wtile, 2, 1) =
            -in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 2) +
            in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2);
        in_transform(n, c, htile, wtile, 2, 2) =
            in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 2) -
            in_hold(n, c, htile, wtile, 2, 1) + in_hold(n, c, htile, wtile, 2, 2);
        in_transform(n, c, htile, wtile, 2, 3) =
            -in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 3) +
            in_hold(n, c, htile, wtile, 2, 1) - in_hold(n, c, htile, wtile, 2, 3);

        in_transform(n, c, htile, wtile, 3, 0) =
            in_hold(n, c, htile, wtile, 1, 0) - in_hold(n, c, htile, wtile, 1, 2) -
            in_hold(n, c, htile, wtile, 3, 0) + in_hold(n, c, htile, wtile, 3, 2);
        in_transform(n, c, htile, wtile, 3, 1) =
            in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) -
            in_hold(n, c, htile, wtile, 3, 1) - in_hold(n, c, htile, wtile, 3, 2);
        in_transform(n, c, htile, wtile, 3, 2) =
            -in_hold(n, c, htile, wtile, 1, 1) + in_hold(n, c, htile, wtile, 1, 2) +
            in_hold(n, c, htile, wtile, 3, 1) - in_hold(n, c, htile, wtile, 3, 2);
        in_transform(n, c, htile, wtile, 3, 3) =
            in_hold(n, c, htile, wtile, 1, 1) - in_hold(n, c, htile, wtile, 1, 3) -
            in_hold(n, c, htile, wtile, 3, 1) + in_hold(n, c, htile, wtile, 3, 3);
Chao Liu's avatar
Chao Liu committed
289
290
291
    };

    auto f_wei_transform = [&](auto k, auto c) {
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
        wei_transform(k, c, 0, 0) = double(wei_kcyx(k, c, 0, 0));
        wei_transform(k, c, 0, 1) = 0.5 * double(wei_kcyx(k, c, 0, 0)) +
                                    0.5 * double(wei_kcyx(k, c, 0, 1)) +
                                    0.5 * double(wei_kcyx(k, c, 0, 2));
        wei_transform(k, c, 0, 2) = 0.5 * double(wei_kcyx(k, c, 0, 0)) -
                                    0.5 * double(wei_kcyx(k, c, 0, 1)) +
                                    0.5 * double(wei_kcyx(k, c, 0, 2));
        wei_transform(k, c, 0, 3) = double(wei_kcyx(k, c, 0, 2));

        wei_transform(k, c, 1, 0) = 0.5 * double(wei_kcyx(k, c, 0, 0)) +
                                    0.5 * double(wei_kcyx(k, c, 1, 0)) +
                                    0.5 * double(wei_kcyx(k, c, 2, 0));
        wei_transform(k, c, 1, 1) =
            0.25 * double(wei_kcyx(k, c, 0, 0)) + 0.25 * double(wei_kcyx(k, c, 0, 1)) +
            0.25 * double(wei_kcyx(k, c, 0, 2)) + 0.25 * double(wei_kcyx(k, c, 1, 0)) +
            0.25 * double(wei_kcyx(k, c, 1, 1)) + 0.25 * double(wei_kcyx(k, c, 1, 2)) +
            0.25 * double(wei_kcyx(k, c, 2, 0)) + 0.25 * double(wei_kcyx(k, c, 2, 1)) +
            0.25 * double(wei_kcyx(k, c, 2, 2));
        wei_transform(k, c, 1, 2) =
            0.25 * double(wei_kcyx(k, c, 0, 0)) - 0.25 * double(wei_kcyx(k, c, 0, 1)) +
            0.25 * double(wei_kcyx(k, c, 0, 2)) + 0.25 * double(wei_kcyx(k, c, 1, 0)) -
            0.25 * double(wei_kcyx(k, c, 1, 1)) + 0.25 * double(wei_kcyx(k, c, 1, 2)) +
            0.25 * double(wei_kcyx(k, c, 2, 0)) - 0.25 * double(wei_kcyx(k, c, 2, 1)) +
            0.25 * double(wei_kcyx(k, c, 2, 2));
        wei_transform(k, c, 1, 3) = 0.5 * double(wei_kcyx(k, c, 0, 2)) +
                                    0.5 * double(wei_kcyx(k, c, 1, 2)) +
                                    0.5 * double(wei_kcyx(k, c, 2, 2));

        wei_transform(k, c, 2, 0) = 0.5 * double(wei_kcyx(k, c, 0, 0)) -
                                    0.5 * double(wei_kcyx(k, c, 1, 0)) +
                                    0.5 * double(wei_kcyx(k, c, 2, 0));
        wei_transform(k, c, 2, 1) =
            0.25 * double(wei_kcyx(k, c, 0, 0)) + 0.25 * double(wei_kcyx(k, c, 0, 1)) +
            0.25 * double(wei_kcyx(k, c, 0, 2)) - 0.25 * double(wei_kcyx(k, c, 1, 0)) -
            0.25 * double(wei_kcyx(k, c, 1, 1)) - 0.25 * double(wei_kcyx(k, c, 1, 2)) +
            0.25 * double(wei_kcyx(k, c, 2, 0)) + 0.25 * double(wei_kcyx(k, c, 2, 1)) +
            0.25 * double(wei_kcyx(k, c, 2, 2));
        wei_transform(k, c, 2, 2) =
            0.25 * double(wei_kcyx(k, c, 0, 0)) - 0.25 * double(wei_kcyx(k, c, 0, 1)) +
            0.25 * double(wei_kcyx(k, c, 0, 2)) - 0.25 * double(wei_kcyx(k, c, 1, 0)) +
            0.25 * double(wei_kcyx(k, c, 1, 1)) - 0.25 * double(wei_kcyx(k, c, 1, 2)) +
            0.25 * double(wei_kcyx(k, c, 2, 0)) - 0.25 * double(wei_kcyx(k, c, 2, 1)) +
            0.25 * double(wei_kcyx(k, c, 2, 2));
        wei_transform(k, c, 2, 3) = 0.5 * double(wei_kcyx(k, c, 0, 2)) -
                                    0.5 * double(wei_kcyx(k, c, 1, 2)) +
                                    0.5 * double(wei_kcyx(k, c, 2, 2));

        wei_transform(k, c, 3, 0) = double(wei_kcyx(k, c, 2, 0));
        wei_transform(k, c, 3, 1) = 0.5 * double(wei_kcyx(k, c, 2, 0)) +
                                    0.5 * double(wei_kcyx(k, c, 2, 1)) +
                                    0.5 * double(wei_kcyx(k, c, 2, 2));
        wei_transform(k, c, 3, 2) = 0.5 * double(wei_kcyx(k, c, 2, 0)) -
                                    0.5 * double(wei_kcyx(k, c, 2, 1)) +
                                    0.5 * double(wei_kcyx(k, c, 2, 2));
        wei_transform(k, c, 3, 3) = double(wei_kcyx(k, c, 2, 2));
Chao Liu's avatar
Chao Liu committed
347
348
    };

Chao Liu's avatar
Chao Liu committed
349
350
    auto f_out_transform = [&](auto n, auto k, auto htile, auto wtile) {
        for(int j = 0; j < HiPerTile; ++j)
Chao Liu's avatar
Chao Liu committed
351
        {
Chao Liu's avatar
Chao Liu committed
352
            for(int i = 0; i < WiPerTile; ++i)
Chao Liu's avatar
Chao Liu committed
353
354
355
356
            {
                double v = 0;
                for(int c = 0; c < C; ++c)
                {
Chao Liu's avatar
Chao Liu committed
357
                    v += in_transform(n, c, htile, wtile, j, i) * wei_transform(k, c, j, i);
Chao Liu's avatar
Chao Liu committed
358
359
                }

Chao Liu's avatar
Chao Liu committed
360
                out_transform(n, k, htile, wtile, j, i) = v;
Chao Liu's avatar
Chao Liu committed
361
362
363
364
            }
        }
    };

Chao Liu's avatar
Chao Liu committed
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
    auto f_out_hold = [&](auto n, auto k, auto htile, auto wtile) {
        out_hold(n, k, htile, wtile, 0, 0) =
            out_transform(n, k, htile, wtile, 0, 0) + out_transform(n, k, htile, wtile, 0, 1) +
            out_transform(n, k, htile, wtile, 0, 2) + out_transform(n, k, htile, wtile, 1, 0) +
            out_transform(n, k, htile, wtile, 1, 1) + out_transform(n, k, htile, wtile, 1, 2) +
            out_transform(n, k, htile, wtile, 2, 0) + out_transform(n, k, htile, wtile, 2, 1) +
            out_transform(n, k, htile, wtile, 2, 2);
        out_hold(n, k, htile, wtile, 0, 1) =
            out_transform(n, k, htile, wtile, 0, 1) - out_transform(n, k, htile, wtile, 0, 2) -
            out_transform(n, k, htile, wtile, 0, 3) + out_transform(n, k, htile, wtile, 1, 1) -
            out_transform(n, k, htile, wtile, 1, 2) - out_transform(n, k, htile, wtile, 1, 3) +
            out_transform(n, k, htile, wtile, 2, 1) - out_transform(n, k, htile, wtile, 2, 2) -
            out_transform(n, k, htile, wtile, 2, 3);
        out_hold(n, k, htile, wtile, 1, 0) =
            out_transform(n, k, htile, wtile, 1, 0) + out_transform(n, k, htile, wtile, 1, 1) +
            out_transform(n, k, htile, wtile, 1, 2) - out_transform(n, k, htile, wtile, 2, 0) -
            out_transform(n, k, htile, wtile, 2, 1) - out_transform(n, k, htile, wtile, 2, 2) -
            out_transform(n, k, htile, wtile, 3, 0) - out_transform(n, k, htile, wtile, 3, 1) -
            out_transform(n, k, htile, wtile, 3, 2);
        out_hold(n, k, htile, wtile, 1, 1) =
            out_transform(n, k, htile, wtile, 1, 1) - out_transform(n, k, htile, wtile, 1, 2) -
            out_transform(n, k, htile, wtile, 1, 3) - out_transform(n, k, htile, wtile, 2, 1) +
            out_transform(n, k, htile, wtile, 2, 2) + out_transform(n, k, htile, wtile, 2, 3) -
            out_transform(n, k, htile, wtile, 3, 1) + out_transform(n, k, htile, wtile, 3, 2) +
            out_transform(n, k, htile, wtile, 3, 3);
Chao Liu's avatar
Chao Liu committed
390
391
    };

Chao Liu's avatar
Chao Liu committed
392
393
    auto f_out = [&](auto n, auto k, auto htile, auto wtile) {
        for(int j = 0; j < HoPerTile; ++j)
Chao Liu's avatar
Chao Liu committed
394
        {
Chao Liu's avatar
Chao Liu committed
395
396
            std::size_t ho = HoPerTile * htile + j;
            for(int i = 0; i < WoPerTile; ++i)
Chao Liu's avatar
Chao Liu committed
397
            {
398
                std::size_t wo = WoPerTile * wtile + i;
399
                out_nkhw(n, k, ho, wo) = out_hold(n, k, htile, wtile, j, i);
Chao Liu's avatar
Chao Liu committed
400
401
402
403
404
405
            }
        }
    };

    std::size_t num_thread = std::thread::hardware_concurrency();

Chao Liu's avatar
Chao Liu committed
406
407
    make_ParallelTensorFunctor(f_in_hold, N, C, HTile, WTile)(num_thread);
    make_ParallelTensorFunctor(f_in_transform, N, C, HTile, WTile)(num_thread);
Chao Liu's avatar
Chao Liu committed
408
    make_ParallelTensorFunctor(f_wei_transform, K, C)(num_thread);
Chao Liu's avatar
Chao Liu committed
409
410
411
    make_ParallelTensorFunctor(f_out_transform, N, K, HTile, WTile)(num_thread);
    make_ParallelTensorFunctor(f_out_hold, N, K, HTile, WTile)(num_thread);
    make_ParallelTensorFunctor(f_out, N, K, HTile, WTile)(num_thread);
Chao Liu's avatar
Chao Liu committed
412
413
414
415
416
417
}

// Compares `result` against `ref` element by element and reports on std::cout:
//   - every "result,ref" pair (space separated, on one line),
//   - "error": the sum of the absolute differences,
//   - "max_diff": the largest absolute difference together with the reference
//     and result values at which it occurred.
// The loop runs over ref.mData; result.mData is read at the same positions.
template <class T>
void check_error(const Tensor<T>& ref, const Tensor<T>& result)
{
    // Summed in double: each term is already computed in double, and a float
    // accumulator drops small terms once the running sum has grown large.
    double error    = 0;
    float max_diff  = -1;
    float ref_value = 0, result_value = 0;

    // std::size_t matches the unsigned type size() is compared in and cannot
    // overflow the way an int counter can.
    for(std::size_t i = 0; i < ref.mData.size(); ++i)
    {
        std::cout << result.mData[i] << "," << ref.mData[i] << " ";

        const double abs_diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
        error += abs_diff;

        float diff = abs_diff;
        if(max_diff < diff)
        {
            max_diff     = diff;
            ref_value    = ref.mData[i];
            result_value = result.mData[i];
        }
    }

    std::cout << std::endl;
    std::cout << "error: " << error << std::endl;
    std::cout << "max_diff: " << max_diff << ", " << ref_value << ", " << result_value << std::endl;
}

Chao Liu's avatar
Chao Liu committed
438
int main(int argc, char* argv[])
Chao Liu's avatar
Chao Liu committed
439
{
Chao Liu's avatar
Chao Liu committed
440
441
#if 0
    constexpr index_t N  = 8;
Chao Liu's avatar
Chao Liu committed
442
    constexpr index_t C  = 16;
Chao Liu's avatar
Chao Liu committed
443
444
445
446
447
448
449
450
    constexpr index_t HI = 3;
    constexpr index_t WI = 18;
    constexpr index_t K  = 128;
    constexpr index_t Y  = 3;
    constexpr index_t X  = 3;

    constexpr index_t HPad = 0;
    constexpr index_t WPad = 0;
451
#elif 0
452
    // 3x3, 34x34
453
    constexpr index_t N  = 128;
454
    constexpr index_t C  = 256;
455
456
    constexpr index_t HI = 34;
    constexpr index_t WI = 34;
457
458
459
    constexpr index_t K  = 128;
    constexpr index_t Y  = 3;
    constexpr index_t X  = 3;
Chao Liu's avatar
Chao Liu committed
460

461
462
463
    using ConvStrides   = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;

Chao Liu's avatar
Chao Liu committed
464
465
    constexpr index_t HPad = 0;
    constexpr index_t WPad = 0;
Chao Liu's avatar
Chao Liu committed
466
#elif 0
467
    // 3x3, 56x56
Chao Liu's avatar
Chao Liu committed
468
469
    constexpr index_t N  = 64;
    constexpr index_t C  = 64;
470
471
    constexpr index_t HI = 56;
    constexpr index_t WI = 56;
Chao Liu's avatar
Chao Liu committed
472
473
474
    constexpr index_t K  = 128;
    constexpr index_t Y  = 3;
    constexpr index_t X  = 3;
Chao Liu's avatar
Chao Liu committed
475
476
477

    constexpr index_t HPad = 0;
    constexpr index_t WPad = 0;
Chao Liu's avatar
Chao Liu committed
478
#elif 0
Chao Liu's avatar
Chao Liu committed
479
480
481
482
483
    // 3x3 filter, 28x28 image
    constexpr index_t N  = 128;
    constexpr index_t C  = 256;
    constexpr index_t HI = 28;
    constexpr index_t WI = 28;
484
    constexpr index_t K  = 128;
Chao Liu's avatar
Chao Liu committed
485
486
487
    constexpr index_t Y  = 3;
    constexpr index_t X  = 3;

Chao Liu's avatar
Chao Liu committed
488
    using ConvStrides   = Sequence<1, 1>;
489
490
    using ConvDilations = Sequence<1, 1>;

Chao Liu's avatar
Chao Liu committed
491
492
    constexpr index_t HPad = 0;
    constexpr index_t WPad = 0;
Chao Liu's avatar
Chao Liu committed
493
#elif 0
Chao Liu's avatar
Chao Liu committed
494
    // 1x1 filter, 28x28 image
495
496
    constexpr index_t N  = 128;
    constexpr index_t C  = 512;
Chao Liu's avatar
Chao Liu committed
497
498
499
500
501
502
    constexpr index_t HI = 28;
    constexpr index_t WI = 28;
    constexpr index_t K  = 512;
    constexpr index_t Y  = 1;
    constexpr index_t X  = 1;

Chao Liu's avatar
Chao Liu committed
503
504
505
    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

Chao Liu's avatar
Chao Liu committed
506
507
    constexpr index_t HPad = 0;
    constexpr index_t WPad = 0;
508
509
#elif 0
    // 3x3 filter, 20x84 image, 1x1 padding
Chao Liu's avatar
Chao Liu committed
510
511
512
513
514
515
516
517
518
519
    constexpr index_t N  = 16;
    constexpr index_t C  = 256;
    constexpr index_t HI = 20;
    constexpr index_t WI = 84;
    constexpr index_t K  = 256;
    constexpr index_t Y  = 3;
    constexpr index_t X  = 3;

    constexpr index_t HPad = 1;
    constexpr index_t WPad = 1;
Chao Liu's avatar
Chao Liu committed
520
521
#elif 0
    // 3x3 filter, 112x112 image, 1x1 padding
Chao Liu's avatar
Chao Liu committed
522
523
524
525
526
527
528
529
530
531
    constexpr index_t N  = 16;
    constexpr index_t C  = 64;
    constexpr index_t HI = 112;
    constexpr index_t WI = 112;
    constexpr index_t K  = 128;
    constexpr index_t Y  = 3;
    constexpr index_t X  = 3;

    constexpr index_t HPad = 1;
    constexpr index_t WPad = 1;
532
#elif 0
533
534
535
536
537
538
539
540
541
542
543
    // 5x5 filter, 20x86 image
    constexpr index_t N  = 16;
    constexpr index_t C  = 256;
    constexpr index_t HI = 20;
    constexpr index_t WI = 86;
    constexpr index_t K  = 512;
    constexpr index_t Y  = 5;
    constexpr index_t X  = 5;

    constexpr index_t HPad = 0;
    constexpr index_t WPad = 0;
Chao Liu's avatar
Chao Liu committed
544
545
#elif 0
    // 5x5 filter, 20x86 image, 1x1 padding
Chao Liu's avatar
Chao Liu committed
546
547
548
549
550
551
552
553
554
555
    constexpr index_t N  = 16;
    constexpr index_t C  = 256;
    constexpr index_t HI = 20;
    constexpr index_t WI = 86;
    constexpr index_t K  = 512;
    constexpr index_t Y  = 5;
    constexpr index_t X  = 5;

    constexpr index_t HPad = 1;
    constexpr index_t WPad = 1;
Chao Liu's avatar
Chao Liu committed
556
557
#elif 0
    // 5x5 filter, 28x28 image, 2x2 padding
Chao Liu's avatar
Chao Liu committed
558
559
560
561
562
563
564
565
566
567
    constexpr index_t N  = 16;
    constexpr index_t C  = 192;
    constexpr index_t HI = 28;
    constexpr index_t WI = 28;
    constexpr index_t K  = 32;
    constexpr index_t Y  = 5;
    constexpr index_t X  = 5;

    constexpr index_t HPad = 2;
    constexpr index_t WPad = 2;
Chao Liu's avatar
Chao Liu committed
568
#elif 0
569
    // 3x3 filter, 14x14 image
Chao Liu's avatar
Chao Liu committed
570
    constexpr index_t N  = 128;
571
    constexpr index_t C  = 256;
Chao Liu's avatar
Chao Liu committed
572
573
    constexpr index_t HI = 14;
    constexpr index_t WI = 14;
574
575
576
    constexpr index_t K  = 128;
    constexpr index_t Y  = 3;
    constexpr index_t X  = 3;
Chao Liu's avatar
Chao Liu committed
577
578
579

    constexpr index_t HPad = 0;
    constexpr index_t WPad = 0;
Chao Liu's avatar
Chao Liu committed
580
#elif 0
581
    // 1x1 filter, 14x14 image
Chao Liu's avatar
Chao Liu committed
582
583
584
585
586
587
588
589
    constexpr index_t N  = 128;
    constexpr index_t C  = 512;
    constexpr index_t HI = 14;
    constexpr index_t WI = 14;
    constexpr index_t K  = 512;
    constexpr index_t Y  = 1;
    constexpr index_t X  = 1;

590
591
592
    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

Chao Liu's avatar
Chao Liu committed
593
594
595
596
597
598
599
600
601
602
603
604
    constexpr index_t HPad = 0;
    constexpr index_t WPad = 0;
#elif 0
    // 1x1 filter, 7x7 image
    constexpr index_t N  = 128;
    constexpr index_t C  = 512;
    constexpr index_t HI = 7;
    constexpr index_t WI = 7;
    constexpr index_t K  = 2048;
    constexpr index_t Y  = 1;
    constexpr index_t X  = 1;

605
606
    constexpr index_t HPad = 0;
    constexpr index_t WPad = 0;
Chao Liu's avatar
Chao Liu committed
607
#elif 0
608
609
    // 1x1 filter, 73x73 image
    constexpr index_t N  = 128;
Chao Liu's avatar
Chao Liu committed
610
    constexpr index_t C  = 512;
611
612
613
614
615
616
    constexpr index_t HI = 73;
    constexpr index_t WI = 73;
    constexpr index_t K  = 128;
    constexpr index_t Y  = 1;
    constexpr index_t X  = 1;

Chao Liu's avatar
Chao Liu committed
617
618
    constexpr index_t HPad = 0;
    constexpr index_t WPad = 0;
Chao Liu's avatar
Chao Liu committed
619
#elif 0
Chao Liu's avatar
Chao Liu committed
620
    // 1x1 filter, 8x8 image
Chao Liu's avatar
Chao Liu committed
621
    // cudnn@V100 68%, ck@V100 72%, ck@P100 52%, ck@VII 42%
Chao Liu's avatar
Chao Liu committed
622
623
624
625
626
627
628
629
630
631
632
633
634
    constexpr index_t N  = 64;
    constexpr index_t C  = 1536;
    constexpr index_t HI = 8;
    constexpr index_t WI = 8;
    constexpr index_t K  = 256;
    constexpr index_t Y  = 1;
    constexpr index_t X  = 1;

    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    constexpr index_t HPad = 0;
    constexpr index_t WPad = 0;
Chao Liu's avatar
Chao Liu committed
635
#elif 0
Chao Liu's avatar
Chao Liu committed
636
    // 1x1 filter, 8x8 image
Chao Liu's avatar
Chao Liu committed
637
    // cudnn@V100 77%, ck@V100 76%, ck@P100 79%, ck@VII 51%
Chao Liu's avatar
Chao Liu committed
638
639
640
641
642
643
644
645
646
647
648
649
650
    constexpr index_t N  = 128;
    constexpr index_t C  = 2048;
    constexpr index_t HI = 8;
    constexpr index_t WI = 8;
    constexpr index_t K  = 384;
    constexpr index_t Y  = 1;
    constexpr index_t X  = 1;

    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    constexpr index_t HPad = 0;
    constexpr index_t WPad = 0;
Chao Liu's avatar
Chao Liu committed
651
#elif 0
Chao Liu's avatar
Chao Liu committed
652
    // 1x1 filter, 7x7 image
Chao Liu's avatar
Chao Liu committed
653
    // cudnn@V100 82%, ck@V100 76%, ck@P100 67%, ck@VII 64%
Chao Liu's avatar
Chao Liu committed
654
655
656
657
658
659
660
661
662
663
664
665
666
    constexpr index_t N  = 128;
    constexpr index_t C  = 832;
    constexpr index_t HI = 7;
    constexpr index_t WI = 7;
    constexpr index_t K  = 384;
    constexpr index_t Y  = 1;
    constexpr index_t X  = 1;

    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    constexpr index_t HPad = 0;
    constexpr index_t WPad = 0;
Chao Liu's avatar
Chao Liu committed
667
#elif 0
Chao Liu's avatar
Chao Liu committed
668
    // 1x1 filter, 8x8 image
Chao Liu's avatar
Chao Liu committed
669
    // cudnn@V100 83%, ck@V100 75%, ck@P100 78%, ck@VII 65%
Chao Liu's avatar
Chao Liu committed
670
671
672
673
674
675
676
677
678
679
680
681
682
    constexpr index_t N  = 128;
    constexpr index_t C  = 1280;
    constexpr index_t HI = 8;
    constexpr index_t WI = 8;
    constexpr index_t K  = 384;
    constexpr index_t Y  = 1;
    constexpr index_t X  = 1;

    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    constexpr index_t HPad = 0;
    constexpr index_t WPad = 0;
Chao Liu's avatar
Chao Liu committed
683
#elif 0
Chao Liu's avatar
Chao Liu committed
684
    // 1x1 filter, 14x14 image
Chao Liu's avatar
Chao Liu committed
685
    // cudnn@V100 62%, ck@V100 68%, ck@P100 70%, ck@VII 50%
Chao Liu's avatar
Chao Liu committed
686
687
688
689
690
691
692
693
694
695
696
697
698
    constexpr index_t N  = 128;
    constexpr index_t C  = 512;
    constexpr index_t HI = 14;
    constexpr index_t WI = 14;
    constexpr index_t K  = 128;
    constexpr index_t Y  = 1;
    constexpr index_t X  = 1;

    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    constexpr index_t HPad = 0;
    constexpr index_t WPad = 0;
Chao Liu's avatar
Chao Liu committed
699
#elif 0
Chao Liu's avatar
Chao Liu committed
700
    // 1x1 filter, 8x8 image
Chao Liu's avatar
Chao Liu committed
701
    // cudnn@V100 74%, ck@V100 57%, ck@P100 78%, ck@VII 61%
Chao Liu's avatar
Chao Liu committed
702
703
704
705
706
707
708
709
710
711
712
713
714
    constexpr index_t N  = 64;
    constexpr index_t C  = 1536;
    constexpr index_t HI = 8;
    constexpr index_t WI = 8;
    constexpr index_t K  = 384;
    constexpr index_t Y  = 1;
    constexpr index_t X  = 1;

    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    constexpr index_t HPad = 0;
    constexpr index_t WPad = 0;
Chao Liu's avatar
Chao Liu committed
715
#elif 0
Chao Liu's avatar
Chao Liu committed
716
    // 1x1 filter, 28x28 image
Chao Liu's avatar
Chao Liu committed
717
    // cudnn@V100 86%, ck@V100 84%, ck@P100 80%, ck@VII 69%
Chao Liu's avatar
Chao Liu committed
718
719
720
721
722
723
724
725
726
727
728
729
730
    constexpr index_t N  = 128;
    constexpr index_t C  = 256;
    constexpr index_t HI = 28;
    constexpr index_t WI = 28;
    constexpr index_t K  = 128;
    constexpr index_t Y  = 1;
    constexpr index_t X  = 1;

    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    constexpr index_t HPad = 0;
    constexpr index_t WPad = 0;
Chao Liu's avatar
Chao Liu committed
731
#elif 0
Chao Liu's avatar
Chao Liu committed
732
    // 1x1 filter, 7x7 image
Chao Liu's avatar
Chao Liu committed
733
    // cudnn@V100 71%, ck@V100 55%, ck@P100 70%, ck@VII 62%
Chao Liu's avatar
Chao Liu committed
734
735
736
737
738
739
740
741
742
743
744
745
746
    constexpr index_t N  = 128;
    constexpr index_t C  = 832;
    constexpr index_t HI = 7;
    constexpr index_t WI = 7;
    constexpr index_t K  = 256;
    constexpr index_t Y  = 1;
    constexpr index_t X  = 1;

    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    constexpr index_t HPad = 0;
    constexpr index_t WPad = 0;
Chao Liu's avatar
Chao Liu committed
747
#elif 0
Chao Liu's avatar
Chao Liu committed
748
    // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
Chao Liu's avatar
Chao Liu committed
749
    // cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
Chao Liu's avatar
Chao Liu committed
750
751
752
753
754
755
756
757
758
759
760
761
762
    constexpr index_t N  = 128;
    constexpr index_t C  = 288;
    constexpr index_t HI = 35;
    constexpr index_t WI = 35;
    constexpr index_t K  = 384;
    constexpr index_t Y  = 3;
    constexpr index_t X  = 3;

    using ConvStrides   = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;

    constexpr index_t HPad = 0;
    constexpr index_t WPad = 0;
Chao Liu's avatar
Chao Liu committed
763
#elif 0
Chao Liu's avatar
Chao Liu committed
764
    // 1x1 filter, 17x17 input
Chao Liu's avatar
Chao Liu committed
765
    // cudnn@V100 81%, ck@V100 76%, ck@P100 70%, ck@VII 76%
Chao Liu's avatar
Chao Liu committed
766
767
768
769
770
771
772
773
774
775
776
777
778
    constexpr index_t N  = 128;
    constexpr index_t C  = 768;
    constexpr index_t HI = 17;
    constexpr index_t WI = 17;
    constexpr index_t K  = 128;
    constexpr index_t Y  = 1;
    constexpr index_t X  = 1;

    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    constexpr index_t HPad = 0;
    constexpr index_t WPad = 0;
Chao Liu's avatar
Chao Liu committed
779
#elif 0
Chao Liu's avatar
Chao Liu committed
780
    // 1x1 filter, 14x14 image
Chao Liu's avatar
Chao Liu committed
781
    // cudnn@V100 73%, ck@V100 71%, ck@P100 70%, ck@VII 64%
Chao Liu's avatar
Chao Liu committed
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
    constexpr index_t N  = 128;
    constexpr index_t C  = 528;
    constexpr index_t HI = 14;
    constexpr index_t WI = 14;
    constexpr index_t K  = 128;
    constexpr index_t Y  = 1;
    constexpr index_t X  = 1;

    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    constexpr index_t HPad = 0;
    constexpr index_t WPad = 0;
#elif 0
    // 1x1 filter, 14x14 image
Chao Liu's avatar
Chao Liu committed
797
    // cudnn@V100 73%, ck@V100 72%, ck@P100 79%, ck@VII 75%
Chao Liu's avatar
Chao Liu committed
798
799
800
801
802
803
804
805
806
807
808
809
810
    constexpr index_t N  = 128;
    constexpr index_t C  = 528;
    constexpr index_t HI = 14;
    constexpr index_t WI = 14;
    constexpr index_t K  = 256;
    constexpr index_t Y  = 1;
    constexpr index_t X  = 1;

    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    constexpr index_t HPad = 0;
    constexpr index_t WPad = 0;
811
#elif 0
Chao Liu's avatar
Chao Liu committed
812
    // 3x3 filter, 28x28 image
813
814
815
816
817
818
819
820
821
822
823
824
825
826
    constexpr index_t N  = 32;
    constexpr index_t C  = 128;
    constexpr index_t HI = 28;
    constexpr index_t WI = 28;
    constexpr index_t K  = 192;
    constexpr index_t Y  = 3;
    constexpr index_t X  = 3;

    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    constexpr index_t HPad = 0;
    constexpr index_t WPad = 0;
#elif 1
827
    constexpr index_t N  = 32;
828
829
830
831
    constexpr index_t C  = 64;
    constexpr index_t HI = 4;
    constexpr index_t WI = 4;
    constexpr index_t K  = 64;
Chao Liu's avatar
Chao Liu committed
832
833
    constexpr index_t Y  = 1;
    constexpr index_t X  = 1;
Chao Liu's avatar
Chao Liu committed
834
835
836
837

    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

Chao Liu's avatar
Chao Liu committed
838
    constexpr index_t HPad = 0;
839
    constexpr index_t WPad = 0;    
840
841
842

    constexpr index_t HO = 4;
    constexpr index_t WO = 4;    
Chao Liu's avatar
Chao Liu committed
843
#endif
Chao Liu's avatar
Chao Liu committed
844

845
846
847
    auto lower_pads = Sequence<HPad, WPad>{};
    auto upper_pads = Sequence<HPad, WPad>{};

848
#if CONV_DIRECTION_FWD_DATA    
Chao Liu's avatar
Chao Liu committed
849
850
    auto in_nchw_desc  = make_ConstantTensorDescriptor_packed(Sequence<N, C, HI, WI>{});
    auto wei_kcyx_desc = make_ConstantTensorDescriptor_packed(Sequence<K, C, Y, X>{});
851
    auto out_nkhw_desc = get_convolution_with_padding_output_default_4d_tensor_descriptor(
852
        in_nchw_desc, wei_kcyx_desc, ConvStrides{}, ConvDilations{}, lower_pads, upper_pads);
853
854
855
856
857
#elif CONV_DIRECTION_BWD_WEIT        
    auto in_nchw_desc  = make_ConstantTensorDescriptor_packed(Sequence<C, N, HI, WI>{});
    auto wei_kcyx_desc = make_ConstantTensorDescriptor_packed(Sequence<C, K, Y, X>{});
    auto out_nkhw_desc = make_ConstantTensorDescriptor_packed(Sequence<K, N, HO, WO>{});
#endif     
Chao Liu's avatar
Chao Liu committed
858

Chao Liu's avatar
Chao Liu committed
859
    ostream_ConstantTensorDescriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
Chao Liu's avatar
Chao Liu committed
860
    ostream_ConstantTensorDescriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
Chao Liu's avatar
Chao Liu committed
861
    ostream_ConstantTensorDescriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");
Chao Liu's avatar
Chao Liu committed
862

863
864
    using in_data_t  = float;
    using out_data_t = float;
865
866

#if CONV_DIRECTION_FWD_DATA    
867
868
869
    Tensor<in_data_t> in_nchw(make_TensorDescriptor(in_nchw_desc));
    Tensor<in_data_t> wei_kcyx(make_TensorDescriptor(wei_kcyx_desc));
    Tensor<out_data_t> out_nkhw_host(make_TensorDescriptor(out_nkhw_desc));
870
871
872
873
874
875
876
877
    Tensor<out_data_t> out_nkhw_device(make_TensorDescriptor(out_nkhw_desc));  
#elif CONV_DIRECTION_BWD_WEIT    
    Tensor<in_data_t> in_nchw(make_TensorDescriptor(in_nchw_desc));
    Tensor<out_data_t> wei_kcyx(make_TensorDescriptor(wei_kcyx_desc));
    Tensor<out_data_t> wei_kcyx_host(make_TensorDescriptor(wei_kcyx_desc));  
    Tensor<in_data_t> out_nkhw_host(make_TensorDescriptor(out_nkhw_desc));
    Tensor<in_data_t> out_nkhw(make_TensorDescriptor(out_nkhw_desc));
#endif     
Chao Liu's avatar
Chao Liu committed
878

Chao Liu's avatar
Chao Liu committed
879
    std::size_t num_thread = std::thread::hardware_concurrency();
Chao Liu's avatar
Chao Liu committed
880

Chao Liu's avatar
Chao Liu committed
881
882
883
884
885
886
887
    if(argc != 3)
    {
        printf("arg1: do_verification, arg2: nrepeat\n");
        exit(1);
    }

    bool do_verification = atoi(argv[1]);
Chao Liu's avatar
Chao Liu committed
888
    index_t nrepeat      = atoi(argv[2]);
889
890
891

    if(do_verification)
    {
892
#if 1
893
#if CONV_DIRECTION_FWD_DATA // fwd data
894
        in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
Chao Liu's avatar
Chao Liu committed
895
        wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
896
897
898
899
900
#elif CONV_DIRECTION_BWD_WEIT // bwd wrw
        in_nchw.GenerateTensorValue(GeneratorTensor_2{}, num_thread);
        //out_nkhw_host.GenerateTensorValue(GeneratorTensor_2{}, num_thread);
        out_nkhw.GenerateTensorValue(GeneratorTensor_2{}, num_thread);
#endif         
Chao Liu's avatar
Chao Liu committed
901
902
903
#elif 0
        in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
        wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
904
905
906
#elif 0
        in_nchw.GenerateTensorValue(GeneratorTensor_3{}, num_thread);
        wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
907
#elif 0
908
        in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
Chao Liu's avatar
Chao Liu committed
909
        wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
Chao Liu's avatar
Chao Liu committed
910
#elif 0
911
912
913
914
915
916
        in_nchw.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);

        auto gen_wei = [](auto... is) {
            return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
        };
        wei_kcyx.GenerateTensorValue(gen_wei, num_thread);
Chao Liu's avatar
Chao Liu committed
917
#endif
918
    }
Chao Liu's avatar
Chao Liu committed
919

Chao Liu's avatar
Chao Liu committed
920
#if 1
Chao Liu's avatar
Chao Liu committed
921
#if 0
Chao Liu's avatar
Chao Liu committed
922
    device_convolution_direct_v2_nchw_kcyx_nkhw
Chao Liu's avatar
Chao Liu committed
923
#elif 0
924
    device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
Chao Liu's avatar
Chao Liu committed
925
#elif 0
Chao Liu's avatar
Chao Liu committed
926
    device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw
927
#elif 0
Chao Liu's avatar
Chao Liu committed
928
    device_convolution_implicit_gemm_v2_chwn_cyxk_khwn
Chao Liu's avatar
Chao Liu committed
929
#elif 0
Chao Liu's avatar
Chao Liu committed
930
    device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw
931
#elif 0
932
    device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw
933
934
#elif 1
    device_convolution_implicit_gemm_v5_nchw_kcyx_nkhw
935
#endif
936
#if CONV_DIRECTION_FWD_DATA // fwd data
937
938
939
940
941
942
943
944
945
    (in_nchw_desc,
     in_nchw,
     wei_kcyx_desc,
     wei_kcyx,
     out_nkhw_desc,
     out_nkhw_device,
     ConvStrides{},
     ConvDilations{},
     nrepeat);
946
947
948
949
950
951
952
953
954
955
956
#elif CONV_DIRECTION_BWD_WEIT // bwd wrw
    (in_nchw_desc,
     in_nchw,
     out_nkhw_desc,
     out_nkhw,     
     wei_kcyx_desc,
     wei_kcyx,
     ConvDilations{},
     ConvStrides{},     
     nrepeat);
#endif 
957

958
#elif 0
Chao Liu's avatar
Chao Liu committed
959
    device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(in_nchw_desc,
Chao Liu's avatar
Chao Liu committed
960
                                                             in_nchw,
Chao Liu's avatar
Chao Liu committed
961
962
                                                             wei_kcyx_desc,
                                                             wei_kcyx,
Chao Liu's avatar
Chao Liu committed
963
964
965
966
967
                                                             out_nkhw_desc,
                                                             out_nkhw_device,
                                                             lower_pads,
                                                             upper_pads,
                                                             nrepeat);
968
#endif
Chao Liu's avatar
Chao Liu committed
969

970
    if(do_verification)
971
    {
972
#if 0
973
974
        if(Y == 3 && X == 3 && ConvStrides{}[0] == 1 && ConvStrides{}[1] == 1 &&
           ConvDilations{}[0] == 1 && ConvDilations{}[1] == 1)
975
        {
Chao Liu's avatar
Chao Liu committed
976
            host_winograd_3x3_convolution(in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads);
977
978
        }
        else
Chao Liu's avatar
Chao Liu committed
979
#endif
980
        {
981
982

#if CONV_DIRECTION_FWD_DATA // fwd data
983
984
985
986
987
988
989
            host_direct_convolution(in_nchw,
                                    wei_kcyx,
                                    out_nkhw_host,
                                    ConvStrides{},
                                    ConvDilations{},
                                    lower_pads,
                                    upper_pads);
990
991
992
993
994
995
996
997
998
999
#elif CONV_DIRECTION_BWD_WEIT // bwd  wrw
            host_direct_convolution(in_nchw,
                                    out_nkhw,
                                    wei_kcyx_host,
                                    ConvDilations{},
                                    ConvStrides{},
                                    lower_pads,
                                    upper_pads);
#endif 

1000
        }
1001
#if CONV_DIRECTION_FWD_DATA // fwd data
1002
        check_error(out_nkhw_host, out_nkhw_device);
1003
1004
1005
1006
1007
1008
#elif CONV_DIRECTION_BWD_WEIT // bwd  wrw
        check_error(wei_kcyx_host, wei_kcyx);
#endif 
        LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;        
        LogRange(std::cout << "out_nkhw_device  : ", out_nkhw.mData, ",") << std::endl;        
        //LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
Chao Liu's avatar
Chao Liu committed
1009
#if 0
1010
        LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
Chao Liu's avatar
Chao Liu committed
1011
        LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl;
1012
1013
        LogRange(std::cout << "out_nkhw_host  : ", out_nkhw_host.mData, ",") << std::endl;
        LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
Chao Liu's avatar
Chao Liu committed
1014
#endif
1015
    }
1016
}