#pragma once
#include "constant_tensor_descriptor.cuh"

// Blockwise element-wise operation on a 4-d tensor.
//
// Calls f(p_src[src_offset], p_dst[dst_offset]) once for every coordinate
// (d0, d1, d2, d3) inside SrcDesc's lengths. Each offset is the dot product of the
// coordinate with that descriptor's strides.
//
// The work is split across the threads of the calling block. The flat thread id
// (x fastest, then y, then z) is decomposed into a coordinate on an
// NWorkLen0 x NWorkLen1 x NWorkLen2 x NWorkLen3 thread grid. A thread then handles every
// tensor coordinate that is congruent to its own, dimension by dimension, modulo
// NWorkLen0..3.
//
// Preconditions (NOTE(review): not checked here):
//   - The block should have at least NWorkLen0*NWorkLen1*NWorkLen2*NWorkLen3 threads;
//     with fewer, some coordinates are never visited. Threads beyond that count return
//     without doing any work.
//   - DstDesc is indexed with the same coordinates, so it presumably needs to be at least
//     as large as SrcDesc in every dimension.
//
// No barrier is executed here. Callers that hand the result to other threads must
// synchronize themselves.
template <class TFloat,
          class SrcDesc,
          class DstDesc,
          unsigned NWorkLen0,
          unsigned NWorkLen1,
          unsigned NWorkLen2,
          unsigned NWorkLen3,
          class F>
__device__ void blockwise_4d_tensor_op(
    SrcDesc, TFloat* const __restrict__ p_src, DstDesc, TFloat* __restrict__ p_dst, F f)
{
    constexpr auto I0 = Index<0>{};
    constexpr auto I1 = Index<1>{};
    constexpr auto I2 = Index<2>{};
    constexpr auto I3 = Index<3>{};

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};

#if 0
    if(threadIdx.x == 0)
    {
        print_ConstantTensorDescriptor(src_desc, "blockwise_4d_tensor_op: src_desc: ");
        print_ConstantTensorDescriptor(dst_desc, "blockwise_4d_tensor_op: dst_desc: ");
    }
#endif

    // Row-major strides of the NWorkLen0 x NWorkLen1 x NWorkLen2 x NWorkLen3 thread grid.
    constexpr unsigned NWorkStride3 = 1;
    constexpr unsigned NWorkStride2 = NWorkLen3 * NWorkStride3;
    constexpr unsigned NWorkStride1 = NWorkLen2 * NWorkStride2;
    constexpr unsigned NWorkStride0 = NWorkLen1 * NWorkStride1;
    constexpr unsigned NWorkTotal   = NWorkLen0 * NWorkStride0;

    const unsigned thread_id =
        threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * (blockDim.y * blockDim.x);

    // A thread past the end of the thread grid would start at a dim-0 coordinate
    // >= NWorkLen0. It would then revisit elements already owned by another thread
    // (a write race, and f applied twice), so it has no work of its own.
    if(thread_id >= NWorkTotal)
    {
        return;
    }

    unsigned itmp = thread_id;

    const unsigned did0_begin = itmp / NWorkStride0;

    itmp -= did0_begin * NWorkStride0;

    const unsigned did1_begin = itmp / NWorkStride1;

    itmp -= did1_begin * NWorkStride1;

    const unsigned did2_begin = itmp / NWorkStride2;

    itmp -= did2_begin * NWorkStride2;

    const unsigned did3_begin = itmp / NWorkStride3;

    for(unsigned did0 = did0_begin; did0 < src_desc.GetLength(I0); did0 += NWorkLen0)
    {
        for(unsigned did1 = did1_begin; did1 < src_desc.GetLength(I1); did1 += NWorkLen1)
        {
            for(unsigned did2 = did2_begin; did2 < src_desc.GetLength(I2); did2 += NWorkLen2)
            {
                for(unsigned did3 = did3_begin; did3 < src_desc.GetLength(I3); did3 += NWorkLen3)
                {
                    const unsigned sindex =
                        src_desc.GetStride(I0) * did0 + src_desc.GetStride(I1) * did1 +
                        src_desc.GetStride(I2) * did2 + src_desc.GetStride(I3) * did3;

                    const unsigned dindex =
                        dst_desc.GetStride(I0) * did0 + dst_desc.GetStride(I1) * did1 +
                        dst_desc.GetStride(I2) * did2 + dst_desc.GetStride(I3) * did3;

                    // The source is addressed with the source offset and the destination
                    // with the destination offset.
                    f(p_src[sindex], p_dst[dindex]);

#if 0
                    // if(threadIdx.x == 0)
                    {
                        printf("blockwise_4d_tensor_op: 1: thread id %u, \t"
                               "sindex %u, p_src[sindex] %f, \t"
                               "dindex %u, p_dst[dindex] %f\n",
                               threadIdx.x,
                               sindex,
                               p_src[sindex],
                               dindex,
                               p_dst[dindex]);
                    }
#endif
                }
            }
        }
    }
}

// Single-thread element-wise operation on a 4-d tensor.
//
// The calling thread walks every coordinate (i0, i1, i2, i3) within SrcDesc's lengths in
// row-major order (last dimension fastest). For each coordinate it calls
//     f(p_src[<src offset>], p_dst[<dst offset>])
// where each offset is the dot product of the coordinate with that descriptor's strides.
// The destination is indexed with the source's lengths, so DstDesc presumably must be at
// least as large in every dimension. NOTE(review): not checked.
template <class TFloat, class SrcDesc, class DstDesc, class F>
__device__ void threadwise_4d_tensor_op(
    SrcDesc, TFloat* const __restrict__ p_src, DstDesc, TFloat* __restrict__ p_dst, F f)
{
    constexpr auto dim0 = Index<0>{};
    constexpr auto dim1 = Index<1>{};
    constexpr auto dim2 = Index<2>{};
    constexpr auto dim3 = Index<3>{};

    constexpr auto src = SrcDesc{};
    constexpr auto dst = DstDesc{};

#if 0
    if(threadIdx.x == 0)
    {
        print_ConstantTensorDescriptor(src);
        print_ConstantTensorDescriptor(dst);
    }
#endif

    // Offsets are accumulated one dimension at a time: each loop level adds only its
    // own stride contribution to the partial offset of the enclosing level.
    for(unsigned i0 = 0; i0 < src.GetLength(dim0); ++i0)
    {
        const unsigned src_off0 = src.GetStride(dim0) * i0;
        const unsigned dst_off0 = dst.GetStride(dim0) * i0;

        for(unsigned i1 = 0; i1 < src.GetLength(dim1); ++i1)
        {
            const unsigned src_off1 = src_off0 + src.GetStride(dim1) * i1;
            const unsigned dst_off1 = dst_off0 + dst.GetStride(dim1) * i1;

            for(unsigned i2 = 0; i2 < src.GetLength(dim2); ++i2)
            {
                const unsigned src_off2 = src_off1 + src.GetStride(dim2) * i2;
                const unsigned dst_off2 = dst_off1 + dst.GetStride(dim2) * i2;

                for(unsigned i3 = 0; i3 < src.GetLength(dim3); ++i3)
                {
                    const unsigned src_off = src_off2 + src.GetStride(dim3) * i3;
                    const unsigned dst_off = dst_off2 + dst.GetStride(dim3) * i3;

                    f(p_src[src_off], p_dst[dst_off]);

#if 0
                    if(threadIdx.x == 0)
                    {
                        printf("threadwise_4d_tensor_op: 1: thread id %u, \t"
                               "sindex %u, p_src[sindex] %f, \t"
                               "dindex %u, p_dst[dindex] %f\n",
                               threadIdx.x,
                               src_off,
                               p_src[src_off],
                               dst_off,
                               p_dst[dst_off]);
                    }
#endif
                }
            }
        }
    }
}

// Naive single-thread direct convolution (no padding, unit stride) that accumulates into
// the output:
//     out[n][k][ho][wo] += sum over (c, s, r) of wei[k][c][s][r] * in[n][c][ho + s][wo + r]
//
// Dimension roles follow the index math below:
//   - in is (N, C, Hi, Wi), wei is (K, C, S, R), out is (N, K, Ho, Wo).
//   - Each tensor is addressed with its own descriptor's strides.
//   - Loop bounds come from OutDesc for n/k/ho/wo and from WeiDesc for c/s/r.
//
// NOTE(review): the input is read at (ho + s, wo + r). The caller presumably guarantees
// Hi >= Ho + S - 1 and Wi >= Wo + R - 1, and that InDesc's C matches WeiDesc's C; none of
// this is checked here.
template <class TFloat, class InDesc, class WeiDesc, class OutDesc>
__device__ void threadwise_direct_convolution(InDesc,
                                              TFloat* const __restrict__ p_in,
                                              WeiDesc,
                                              TFloat* const __restrict__ p_wei,
                                              OutDesc,
                                              TFloat* __restrict__ p_out)
{
    constexpr auto I0 = Index<0>{};
    constexpr auto I1 = Index<1>{};
    constexpr auto I2 = Index<2>{};
    constexpr auto I3 = Index<3>{};

    constexpr auto in_desc  = InDesc{};
    constexpr auto wei_desc = WeiDesc{};
    constexpr auto out_desc = OutDesc{};

#if 0
    if(threadIdx.x == 0)
    {
        print_ConstantTensorDescriptor(in_desc);
        print_ConstantTensorDescriptor(wei_desc);
        print_ConstantTensorDescriptor(out_desc);
    }
#endif

    for(unsigned n = 0; n < out_desc.GetLength(I0); ++n)
    {
        for(unsigned k = 0; k < out_desc.GetLength(I1); ++k)
        {
            for(unsigned ho = 0; ho < out_desc.GetLength(I2); ++ho)
            {
                for(unsigned wo = 0; wo < out_desc.GetLength(I3); ++wo)
                {
                    const unsigned out_index =
                        out_desc.GetStride(I0) * n + out_desc.GetStride(I1) * k +
                        out_desc.GetStride(I2) * ho + out_desc.GetStride(I3) * wo;

                    // Accumulate in a register and store once. The additions happen in
                    // the same order as repeatedly updating p_out[out_index] in place.
                    TFloat acc = p_out[out_index];

                    for(unsigned c = 0; c < wei_desc.GetLength(I1); ++c)
                    {
                        for(unsigned s = 0; s < wei_desc.GetLength(I2); ++s)
                        {
                            for(unsigned r = 0; r < wei_desc.GetLength(I3); ++r)
                            {
                                const unsigned hi = ho + s;
                                const unsigned wi = wo + r;

                                const unsigned in_index =
                                    in_desc.GetStride(I0) * n + in_desc.GetStride(I1) * c +
                                    in_desc.GetStride(I2) * hi + in_desc.GetStride(I3) * wi;

                                // Every filter axis, including R, uses the weight
                                // tensor's own strides.
                                const unsigned wei_index =
                                    wei_desc.GetStride(I0) * k + wei_desc.GetStride(I1) * c +
                                    wei_desc.GetStride(I2) * s + wei_desc.GetStride(I3) * r;

                                acc += p_wei[wei_index] * p_in[in_index];

#if 0
                                if(threadIdx.x == 0)
                                {
                                    printf("threadwise_direct_convolution: 1: \t"
                                           "threadIdx.x %u\t"
                                           "out_index %u, p_out[out_index] %f, \t"
                                           "wei_index %u, p_wei[wei_index] %f, \t"
                                           "in_index %u, p_in[in_index] %f\n",
                                           threadIdx.x,
                                           out_index,
                                           acc,
                                           wei_index,
                                           p_wei[wei_index],
                                           in_index,
                                           p_in[in_index]);
                                }
#endif
                            }
                        }
                    }

                    p_out[out_index] = acc;
                }
            }
        }
    }
}

// Block-level direct convolution over tensors visible to the whole block
// (gridwise_convolution passes it __shared__ tiles).
//
// The output (N, K, Ho, Wo) described by OutDesc is cut into work items of
// 1 x 1 x OutTileSizeH x OutTileSizeW. Work items are dealt out to the block's threads in
// a strided loop: a thread takes item `tid`, then `tid + threads_in_block`, and so on.
// For each item the thread:
//   1. copies three pieces into thread-local arrays:
//      - the input window, 1 x C x (OutTileSizeH + S - 1) x (OutTileSizeW + R - 1),
//      - the weights of one output channel, 1 x C x S x R,
//      - the current output tile;
//   2. runs threadwise_direct_convolution, which accumulates into the output tile;
//   3. writes the output tile back to p_out.
//
// Work items cover disjoint output tiles, and no barrier is executed here.
// NOTE(review): Ho/Wo are rounded up to whole tiles. If they are not multiples of
// OutTileSizeH/OutTileSizeW, the last tiles access memory past the tensors; callers
// presumably ensure divisibility.
template <class TFloat,
          class InDesc,
          class WeiDesc,
          class OutDesc,
          unsigned OutTileSizeH,
          unsigned OutTileSizeW>
__device__ void blockwise_convolution(InDesc,
                                      TFloat* const __restrict__ p_in,
                                      WeiDesc,
                                      TFloat* const __restrict__ p_wei,
                                      OutDesc,
                                      TFloat* __restrict__ p_out)
{
    constexpr auto I0 = Index<0>{};
    constexpr auto I1 = Index<1>{};
    constexpr auto I2 = Index<2>{};
    constexpr auto I3 = Index<3>{};

    constexpr auto in_desc  = InDesc{};
    constexpr auto wei_desc = WeiDesc{};
    constexpr auto out_desc = OutDesc{};

    // Filter extent and channel count of this block's slice.
    constexpr unsigned S             = wei_desc.GetLength(I2);
    constexpr unsigned R             = wei_desc.GetLength(I3);
    constexpr unsigned CPerBlockLoop = in_desc.GetLength(I1);

    // Input window needed to produce one output tile.
    constexpr unsigned InTileSizeH = OutTileSizeH + S - 1;
    constexpr unsigned InTileSizeW = OutTileSizeW + R - 1;

    // Work-item grid: one item per (n, k, tile-row, tile-column).
    constexpr unsigned NWork     = out_desc.GetLength(I0);
    constexpr unsigned KWork     = out_desc.GetLength(I1);
    constexpr unsigned YWork     = (out_desc.GetLength(I2) + OutTileSizeH - 1) / OutTileSizeH;
    constexpr unsigned XWork     = (out_desc.GetLength(I3) + OutTileSizeW - 1) / OutTileSizeW;
    constexpr unsigned TotalWork = NWork * KWork * YWork * XWork;

#if 0
    if(threadIdx.x == 0)
    {
        print_ConstantTensorDescriptor(in_desc);
        print_ConstantTensorDescriptor(wei_desc);
        print_ConstantTensorDescriptor(out_desc);
    }
#endif

    // Views of one work item inside the block tensors: tile lengths with the block
    // tensors' strides.
    constexpr auto in_tile_src_desc = make_ConstantTensorDescriptor(
        Sequence<1, CPerBlockLoop, InTileSizeH, InTileSizeW>{}, in_desc.GetStrides());
    constexpr auto wei_tile_src_desc =
        make_ConstantTensorDescriptor(Sequence<1, CPerBlockLoop, S, R>{}, wei_desc.GetStrides());
    constexpr auto out_tile_src_desc = make_ConstantTensorDescriptor(
        Sequence<1, 1, OutTileSizeH, OutTileSizeW>{}, out_desc.GetStrides());

    // The same tiles in thread-local arrays, described by their lengths only.
    constexpr auto in_tile_dst_desc = make_ConstantTensorDescriptor(in_tile_src_desc.GetLengths());
    constexpr auto wei_tile_dst_desc =
        make_ConstantTensorDescriptor(wei_tile_src_desc.GetLengths());
    constexpr auto out_tile_dst_desc =
        make_ConstantTensorDescriptor(out_tile_src_desc.GetLengths());

    const unsigned threads_in_block = blockDim.x * blockDim.y * blockDim.z;
    const unsigned tid =
        threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * (blockDim.y * blockDim.x);

    // Plain element copy, shared by every tile transfer below.
    auto f_copy = [](const TFloat& src, TFloat& dst) { dst = src; };

    for(unsigned work_id = tid; work_id < TotalWork; work_id += threads_in_block)
    {
        // work_id == ((n * KWork + k) * YWork + y) * XWork + x
        const unsigned x = work_id % XWork;
        const unsigned y = (work_id / XWork) % YWork;
        const unsigned k = (work_id / (XWork * YWork)) % KWork;
        const unsigned n = work_id / (XWork * YWork * KWork);

        const unsigned ho_begin = y * OutTileSizeH;
        const unsigned wo_begin = x * OutTileSizeW;

        // No padding: the input window starts at the output tile's coordinate.
        const unsigned hi_begin = ho_begin;
        const unsigned wi_begin = wo_begin;

        const unsigned out_offset = out_desc.Get1dIndex(n, k, ho_begin, wo_begin);

        TFloat in_tile[1 * CPerBlockLoop * InTileSizeH * InTileSizeW];
        TFloat wei_tile[1 * CPerBlockLoop * S * R];
        TFloat out_tile[1 * 1 * OutTileSizeH * OutTileSizeW];

        // gather input window
        threadwise_4d_tensor_op<TFloat,
                                decltype(in_tile_src_desc),
                                decltype(in_tile_dst_desc),
                                decltype(f_copy)>(in_tile_src_desc,
                                                  p_in + in_desc.Get1dIndex(n, 0, hi_begin, wi_begin),
                                                  in_tile_dst_desc,
                                                  in_tile,
                                                  f_copy);

        // gather filter of output channel k
        threadwise_4d_tensor_op<TFloat,
                                decltype(wei_tile_src_desc),
                                decltype(wei_tile_dst_desc),
                                decltype(f_copy)>(wei_tile_src_desc,
                                                  p_wei + wei_desc.Get1dIndex(k, 0, 0, 0),
                                                  wei_tile_dst_desc,
                                                  wei_tile,
                                                  f_copy);

        // gather current output tile
        threadwise_4d_tensor_op<TFloat,
                                decltype(out_tile_src_desc),
                                decltype(out_tile_dst_desc),
                                decltype(f_copy)>(
            out_tile_src_desc, p_out + out_offset, out_tile_dst_desc, out_tile, f_copy);

        // accumulate this work item's contribution
        threadwise_direct_convolution<TFloat,
                                      decltype(in_tile_dst_desc),
                                      decltype(wei_tile_dst_desc),
                                      decltype(out_tile_dst_desc)>(in_tile_dst_desc,
                                                                   in_tile,
                                                                   wei_tile_dst_desc,
                                                                   wei_tile,
                                                                   out_tile_dst_desc,
                                                                   out_tile);

        // scatter the updated output tile back
        threadwise_4d_tensor_op<TFloat,
                                decltype(out_tile_dst_desc),
                                decltype(out_tile_src_desc),
                                decltype(f_copy)>(
            out_tile_dst_desc, out_tile, out_tile_src_desc, p_out + out_offset, f_copy);
    }
}

// Grid-level direct convolution kernel (no padding, unit stride) that accumulates into
// the output:
//     out[n][k][ho][wo] += sum over (c, s, r) of wei[k][c][s][r] * in[n][c][ho + s][wo + r]
// with in (N, C, Hi, Wi), wei (K, C, S, R) and out (N, K, Ho, Wo).
// p_out is read before it is written, so the caller must initialize it (e.g. zero it).
//
// Launch layout:
//   - Each thread block owns one output tile of
//     NPerBlock x KPerBlock x (YPerBlock*OutTileSizeH) x (XPerBlock*OutTileSizeW).
//     The flat block id (x fastest, then y, then z) is decomposed into (n, k, y, x) tile
//     coordinates, so the grid presumably must hold exactly
//     NBlockWork*KBlockWork*YBlockWork*XBlockWork blocks.
//   - The threads of a block cooperate on the global<->shared copies (split over an
//     NBlockCopyLen0 x ... x NBlockCopyLen3 thread grid) and on blockwise_convolution.
//   - Shared memory is static: one input, one weight and one output tile.
//
// NOTE(review): there is no tail handling. N, K, Ho and Wo presumably must be multiples of
// the per-block tile sizes, and C a multiple of CPerBlockLoop. Otherwise tiles read and
// write past the tensors.
template <class TFloat,
          class InDesc,
          class WeiDesc,
          class OutDesc,
          unsigned NPerBlock,
          unsigned KPerBlock,
          unsigned CPerBlockLoop,
          unsigned OutTileSizeH,
          unsigned OutTileSizeW,
          unsigned YPerBlock,
          unsigned XPerBlock,
          unsigned NBlockCopyLen0,
          unsigned NBlockCopyLen1,
          unsigned NBlockCopyLen2,
          unsigned NBlockCopyLen3>
__global__ void gridwise_convolution(InDesc,
                                     TFloat* const __restrict__ p_in,
                                     WeiDesc,
                                     TFloat* const __restrict__ p_wei,
                                     OutDesc,
                                     TFloat* __restrict__ p_out)
{
    constexpr auto I0 = Index<0>{};
    constexpr auto I1 = Index<1>{};
    constexpr auto I2 = Index<2>{};
    constexpr auto I3 = Index<3>{};

    constexpr auto in_desc  = InDesc{};
    constexpr auto wei_desc = WeiDesc{};
    constexpr auto out_desc = OutDesc{};

    constexpr unsigned S = wei_desc.GetLength(I2);
    constexpr unsigned R = wei_desc.GetLength(I3);

    constexpr unsigned HoPerBlock = OutTileSizeH * YPerBlock;
    constexpr unsigned WoPerBlock = OutTileSizeW * XPerBlock;

    // Input window needed for one block's output tile.
    constexpr unsigned HiPerBlock = YPerBlock * OutTileSizeH + S - 1;
    constexpr unsigned WiPerBlock = XPerBlock * OutTileSizeW + R - 1;

    // Number of block tiles along each output dimension.
    constexpr unsigned NBlockWork = (out_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
    constexpr unsigned KBlockWork = (out_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
    constexpr unsigned YBlockWork = (out_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
    constexpr unsigned XBlockWork = (out_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;

    // Block tiles viewed in global memory: tile lengths with the full tensors' strides.
    constexpr auto in_block_glb_desc = make_ConstantTensorDescriptor(
        Sequence<NPerBlock, CPerBlockLoop, HiPerBlock, WiPerBlock>{}, in_desc.GetStrides());

    constexpr auto wei_block_glb_desc = make_ConstantTensorDescriptor(
        Sequence<KPerBlock, CPerBlockLoop, S, R>{}, wei_desc.GetStrides());

    constexpr auto out_block_glb_desc = make_ConstantTensorDescriptor(
        Sequence<NPerBlock, KPerBlock, HoPerBlock, WoPerBlock>{}, out_desc.GetStrides());

    // The same tiles in shared memory, described by their lengths only.
    constexpr auto in_block_lds_desc =
        make_ConstantTensorDescriptor(in_block_glb_desc.GetLengths());
    constexpr auto wei_block_lds_desc =
        make_ConstantTensorDescriptor(wei_block_glb_desc.GetLengths());
    constexpr auto out_block_lds_desc =
        make_ConstantTensorDescriptor(out_block_glb_desc.GetLengths());

    constexpr unsigned in_block_size  = in_block_lds_desc.GetElementSize();
    constexpr unsigned wei_block_size = wei_block_lds_desc.GetElementSize();
    constexpr unsigned out_block_size = out_block_lds_desc.GetElementSize();

    __shared__ TFloat p_in_block[in_block_size];
    __shared__ TFloat p_wei_block[wei_block_size];
    __shared__ TFloat p_out_block[out_block_size];

    // Decompose the flat block id into (n, k, y, x) tile coordinates, x fastest.
    const unsigned block_id =
        blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * (gridDim.y * gridDim.x);

    unsigned itmp            = block_id;
    unsigned n_block_work_id = itmp / (KBlockWork * YBlockWork * XBlockWork);
    itmp -= n_block_work_id * (KBlockWork * YBlockWork * XBlockWork);
    unsigned k_block_work_id = itmp / (YBlockWork * XBlockWork);
    itmp -= k_block_work_id * (YBlockWork * XBlockWork);
    unsigned y_block_work_id = itmp / XBlockWork;
    unsigned x_block_work_id = itmp - y_block_work_id * XBlockWork;

    unsigned n_block_work_begin = n_block_work_id * NPerBlock;
    unsigned k_block_work_begin = k_block_work_id * KPerBlock;
    unsigned y_block_work_begin = y_block_work_id * YPerBlock;
    unsigned x_block_work_begin = x_block_work_id * XPerBlock;

    unsigned ho_block_work_begin = y_block_work_begin * OutTileSizeH;
    unsigned wo_block_work_begin = x_block_work_begin * OutTileSizeW;

    unsigned hi_block_work_begin = ho_block_work_begin; // minus padding
    unsigned wi_block_work_begin = wo_block_work_begin; // minus padding

#if 0
    if(threadIdx.x == 0)
    {
        print_ConstantTensorDescriptor( in_desc, "gridwise_convolution:  in_desc: ");
        print_ConstantTensorDescriptor(wei_desc, "gridwise_convolution: wei_desc: ");
        print_ConstantTensorDescriptor(out_desc, "gridwise_convolution: out_desc: ");
        print_ConstantTensorDescriptor( in_block_glb_desc, "gridwise_convolution:  in_block_glb_desc: ");
        print_ConstantTensorDescriptor(wei_block_glb_desc, "gridwise_convolution: wei_block_glb_desc: ");
        print_ConstantTensorDescriptor(out_block_glb_desc, "gridwise_convolution: out_block_glb_desc: ");
        print_ConstantTensorDescriptor( in_block_lds_desc, "gridwise_convolution:  in_block_lds_desc: ");
        print_ConstantTensorDescriptor(wei_block_lds_desc, "gridwise_convolution: wei_block_lds_desc: ");
        print_ConstantTensorDescriptor(out_block_lds_desc, "gridwise_convolution: out_block_lds_desc: ");

        printf("NBlockWork %u, KBlockWork %u, YBlockWork %u, XBlockWork %u \t"
               "block_id %u, n_block_work_id %u, k_block_work_id %u, y_block_work_id %u, "
               "x_block_work_id %u\n",
               NBlockWork,
               KBlockWork,
               YBlockWork,
               XBlockWork,
               block_id,
               n_block_work_id,
               k_block_work_id,
               y_block_work_id,
               x_block_work_id);
    }
#endif

    auto f_copy = [](const TFloat& src, TFloat& dst) { dst = src; };

    // This block's output tile in global memory. It is the same for every C slice, so it
    // is brought into shared memory once, accumulated over all slices, and written back
    // once.
    TFloat* const p_out_block_glb = p_out + out_block_glb_desc.Get1dIndex(n_block_work_begin,
                                                                          k_block_work_begin,
                                                                          ho_block_work_begin,
                                                                          wo_block_work_begin);

    // load output tile to LDS
    blockwise_4d_tensor_op<TFloat,
                           decltype(out_block_glb_desc),
                           decltype(out_block_lds_desc),
                           NBlockCopyLen0,
                           NBlockCopyLen1,
                           NBlockCopyLen2,
                           NBlockCopyLen3,
                           decltype(f_copy)>(
        out_block_glb_desc, p_out_block_glb, out_block_lds_desc, p_out_block, f_copy);

    // C is consumed in slices of CPerBlockLoop channels. The trip count is the same for
    // every thread of the block, so every thread reaches the barriers below.
    for(unsigned c_block_work_begin = 0; c_block_work_begin < in_desc.GetLength(I1);
        c_block_work_begin += CPerBlockLoop)
    {
        // copy input tensor to LDS
        blockwise_4d_tensor_op<TFloat,
                               decltype(in_block_glb_desc),
                               decltype(in_block_lds_desc),
                               NBlockCopyLen0,
                               NBlockCopyLen1,
                               NBlockCopyLen2,
                               NBlockCopyLen3,
                               decltype(f_copy)>(
            in_block_glb_desc,
            p_in + in_block_glb_desc.Get1dIndex(n_block_work_begin,
                                                c_block_work_begin,
                                                hi_block_work_begin,
                                                wi_block_work_begin),
            in_block_lds_desc,
            p_in_block,
            f_copy);

        // copy weight tensor to LDS
        blockwise_4d_tensor_op<TFloat,
                               decltype(wei_block_glb_desc),
                               decltype(wei_block_lds_desc),
                               NBlockCopyLen0,
                               NBlockCopyLen1,
                               NBlockCopyLen2,
                               NBlockCopyLen3,
                               decltype(f_copy)>(
            wei_block_glb_desc,
            p_wei + wei_block_glb_desc.Get1dIndex(k_block_work_begin, c_block_work_begin, 0, 0),
            wei_block_lds_desc,
            p_wei_block,
            f_copy);

        // The copies above are split across threads differently from the reads in
        // blockwise_convolution. Every thread's LDS writes (input and weight, and on the
        // first pass the output tile) must be visible before any thread reads them.
        __syncthreads();

        // blockwise convolution
        blockwise_convolution<TFloat,
                              decltype(in_block_lds_desc),
                              decltype(wei_block_lds_desc),
                              decltype(out_block_lds_desc),
                              OutTileSizeH,
                              OutTileSizeW>(in_block_lds_desc,
                                            p_in_block,
                                            wei_block_lds_desc,
                                            p_wei_block,
                                            out_block_lds_desc,
                                            p_out_block);

        // This barrier serves two purposes:
        //   - all reads of p_in_block / p_wei_block must finish before the next pass
        //     overwrites them;
        //   - all updates of p_out_block must be visible before the final store.
        __syncthreads();
    }

    // accumulated output tile from LDS back to device mem
    blockwise_4d_tensor_op<TFloat,
                           decltype(out_block_lds_desc),
                           decltype(out_block_glb_desc),
                           NBlockCopyLen0,
                           NBlockCopyLen1,
                           NBlockCopyLen2,
                           NBlockCopyLen3,
                           decltype(f_copy)>(
        out_block_lds_desc, p_out_block, out_block_glb_desc, p_out_block_glb, f_copy);
}