direct_convolution.cuh 20.8 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
#pragma once
Chao Liu's avatar
Chao Liu committed
2
#include "constant_tensor_descriptor.cuh"
3
4
#include "blockwise_tensor_op.cuh"
#include "threadwise_tensor_op.cuh"
Chao Liu's avatar
Chao Liu committed
5

Chao Liu's avatar
Chao Liu committed
6
7
// Direct convolution of one 4-d tile, executed serially by the calling thread.
//
// The three descriptor arguments are unnamed: only their types are used, and a
// value of each type is default-constructed inside the function to query lengths
// and strides at compile time.
//
// Index conventions, as used by the loops below:
//   in  is addressed as (n, c, hi, wi)
//   wei is addressed as (k, c, s,  r )
//   out is addressed as (n, k, ho, wo)
//
// For every output element the function accumulates
//   p_out[n, k, ho, wo] += p_wei[k, c, s, r] * p_in[n, c, ho + s, wo + r]
// over all (c, s, r), where the ranges of c, s and r come from dimensions 1, 2
// and 3 of the weight descriptor and the ranges of n, k, ho and wo come from the
// output descriptor.
//
// p_out is only ever added to, never cleared, so the caller must initialise it.
// No bounds checks are performed: the input buffer must be valid for every
// (ho + s, wo + r) that the loops generate.
template <class TFloat, class InDesc, class WeiDesc, class OutDesc>
__device__ void threadwise_direct_convolution(InDesc,
                                              TFloat* const __restrict__ p_in,
                                              WeiDesc,
                                              TFloat* const __restrict__ p_wei,
                                              OutDesc,
                                              TFloat* __restrict__ p_out)
{
    constexpr auto I0 = Index<0>{};
    constexpr auto I1 = Index<1>{};
    constexpr auto I2 = Index<2>{};
    constexpr auto I3 = Index<3>{};

    constexpr auto in_desc  = InDesc{};
    constexpr auto wei_desc = WeiDesc{};
    constexpr auto out_desc = OutDesc{};

#if 0
    if(threadIdx.x == 0)
    {
        print_ConstantTensorDescriptor(in_desc);
        print_ConstantTensorDescriptor(wei_desc);
        print_ConstantTensorDescriptor(out_desc);
    }
#endif

    for(unsigned n = 0; n < out_desc.GetLength(I0); ++n)
    {
        for(unsigned k = 0; k < out_desc.GetLength(I1); ++k)
        {
            for(unsigned ho = 0; ho < out_desc.GetLength(I2); ++ho)
            {
                for(unsigned wo = 0; wo < out_desc.GetLength(I3); ++wo)
                {
                    // The output offset does not depend on (c, s, r), so it is
                    // computed once per output element.
                    const unsigned out_index =
                        out_desc.GetStride(I0) * n + out_desc.GetStride(I1) * k +
                        out_desc.GetStride(I2) * ho + out_desc.GetStride(I3) * wo;

                    for(unsigned c = 0; c < wei_desc.GetLength(I1); ++c)
                    {
                        for(unsigned s = 0; s < wei_desc.GetLength(I2); ++s)
                        {
                            for(unsigned r = 0; r < wei_desc.GetLength(I3); ++r)
                            {
                                // input position read for this filter tap
                                const unsigned hi = ho + s;
                                const unsigned wi = wo + r;

                                const unsigned in_index =
                                    in_desc.GetStride(I0) * n + in_desc.GetStride(I1) * c +
                                    in_desc.GetStride(I2) * hi + in_desc.GetStride(I3) * wi;

                                // Every term uses the weight descriptor's own strides;
                                // using the input's innermost stride here would address
                                // the wrong weight whenever the two strides differ.
                                const unsigned wei_index =
                                    wei_desc.GetStride(I0) * k + wei_desc.GetStride(I1) * c +
                                    wei_desc.GetStride(I2) * s + wei_desc.GetStride(I3) * r;

                                p_out[out_index] += p_wei[wei_index] * p_in[in_index];

#if 0
                                if(threadIdx.x == 0)
                                {
                                    printf("threadwise_direct_convolution: 1: \t"
                                           "threadIdx.x %u\t"
                                           "out_index %u, p_out[out_index] %f, \t"
                                           "wei_index %u, p_wei[wei_index] %f, \t"
                                           "in_index %u, p_in[in_index] %f\n",
                                           threadIdx.x,
                                           out_index,
                                           p_out[out_index],
                                           wei_index,
                                           p_wei[wei_index],
                                           in_index,
                                           p_in[in_index]);
                                }
#endif
                            }
                        }
                    }
                }
            }
        }
    }
}

// Convolution of one block-level tile, with the work shared by the threads of a block.
//
// The output tensor described by OutDesc is cut into tiles of size
// 1 x 1 x OutTileSizeH x OutTileSizeW. Tiles are numbered linearly and the thread
// with index t = threadIdx.x handles tiles t, t + BlockSize, t + 2 * BlockSize, ...
// For each tile a thread
//   1. copies the matching input window, the weights of one output channel and
//      the current contents of the output tile into thread-private local arrays,
//   2. runs threadwise_direct_convolution on those arrays, which adds the
//      convolution result onto the output tile,
//   3. copies the output tile back to p_out.
//
// The descriptor arguments are unnamed; only their types are used.
// This function contains no barrier of its own.
//
// NOTE(review): the number of tiles along H and W is rounded up and no tail
// handling is visible here, so the output H/W lengths are presumably expected to
// be multiples of OutTileSizeH / OutTileSizeW - confirm against callers.
template <class TFloat,
          class InDesc,
          class WeiDesc,
          class OutDesc,
          unsigned OutTileSizeH,
          unsigned OutTileSizeW,
          unsigned BlockSize>
__device__ void blockwise_convolution(InDesc,
                                      TFloat* const __restrict__ p_in,
                                      WeiDesc,
                                      TFloat* const __restrict__ p_wei,
                                      OutDesc,
                                      TFloat* __restrict__ p_out)
{
    constexpr auto I0 = Index<0>{};
    constexpr auto I1 = Index<1>{};
    constexpr auto I2 = Index<2>{};
    constexpr auto I3 = Index<3>{};

    constexpr auto in_desc  = InDesc{};
    constexpr auto wei_desc = WeiDesc{};
    constexpr auto out_desc = OutDesc{};

    // filter window: dimensions 2 and 3 of the weight tensor
    constexpr unsigned S = wei_desc.GetLength(I2);
    constexpr unsigned R = wei_desc.GetLength(I3);

    // number of work items along each output dimension
    constexpr unsigned NPerBlock = out_desc.GetLength(I0);
    constexpr unsigned KPerBlock = out_desc.GetLength(I1);
    constexpr unsigned YPerBlock = (out_desc.GetLength(I2) + OutTileSizeH - 1) / OutTileSizeH;
    constexpr unsigned XPerBlock = (out_desc.GetLength(I3) + OutTileSizeW - 1) / OutTileSizeW;

    constexpr unsigned CPerBlockLoop = in_desc.GetLength(I1);

    // input window needed to produce one output tile
    constexpr unsigned InTileSizeH = OutTileSizeH + S - 1;
    constexpr unsigned InTileSizeW = OutTileSizeW + R - 1;

#if 0
    if(threadIdx.x == 0)
    {
        print_ConstantTensorDescriptor(in_desc);
        print_ConstantTensorDescriptor(wei_desc);
        print_ConstantTensorDescriptor(out_desc);
    }
#endif

    // "src" descriptors: per-thread tile lengths combined with the strides of the
    // block-level tensors the tiles are cut from.
    constexpr auto in_thread_src_desc = make_ConstantTensorDescriptor(
        Sequence<1, CPerBlockLoop, InTileSizeH, InTileSizeW>{}, in_desc.GetStrides());

    constexpr auto wei_thread_src_desc =
        make_ConstantTensorDescriptor(Sequence<1, CPerBlockLoop, S, R>{}, wei_desc.GetStrides());

    constexpr auto out_thread_src_desc = make_ConstantTensorDescriptor(
        Sequence<1, 1, OutTileSizeH, OutTileSizeW>{}, out_desc.GetStrides());

    // "dst" descriptors: same lengths, built from the lengths alone; they describe
    // the thread-private arrays declared inside the loop below.
    constexpr auto in_thread_dst_desc =
        make_ConstantTensorDescriptor(in_thread_src_desc.GetLengths());

    constexpr auto wei_thread_dst_desc =
        make_ConstantTensorDescriptor(wei_thread_src_desc.GetLengths());

    constexpr auto out_thread_dst_desc =
        make_ConstantTensorDescriptor(out_thread_src_desc.GetLengths());

    using InSrcDesc  = decltype(in_thread_src_desc);
    using InDstDesc  = decltype(in_thread_dst_desc);
    using WeiSrcDesc = decltype(wei_thread_src_desc);
    using WeiDstDesc = decltype(wei_thread_dst_desc);
    using OutSrcDesc = decltype(out_thread_src_desc);
    using OutDstDesc = decltype(out_thread_dst_desc);

    // element-wise functor handed to threadwise_4d_tensor_op: plain assignment
    auto f_copy = [](const TFloat& src, TFloat& dst) { dst = src; };
    using CopyOp = decltype(f_copy);

    // sizes used to split a linear work id into (n, k, y, x)
    constexpr unsigned WorkPerK  = YPerBlock * XPerBlock;
    constexpr unsigned WorkPerN  = KPerBlock * WorkPerK;
    constexpr unsigned TotalWork = NPerBlock * WorkPerN;

    for(unsigned work_id = threadIdx.x; work_id < TotalWork; work_id += BlockSize)
    {
        // linear work id -> (n, k, y, x)
        const unsigned n_work_id    = work_id / WorkPerN;
        const unsigned within_n     = work_id % WorkPerN;
        const unsigned k_work_id    = within_n / WorkPerK;
        const unsigned within_k     = within_n % WorkPerK;
        const unsigned y_work_id    = within_k / XPerBlock;
        const unsigned x_work_id    = within_k % XPerBlock;

        // first element of this work item along each dimension
        const unsigned n_begin  = n_work_id * 1;
        const unsigned k_begin  = k_work_id * 1;
        const unsigned ho_begin = y_work_id * OutTileSizeH;
        const unsigned wo_begin = x_work_id * OutTileSizeW;

        const unsigned hi_begin = ho_begin; // minus padding
        const unsigned wi_begin = wo_begin; // minus padding

        // thread-private staging buffers
        TFloat p_in_thread[1 * CPerBlockLoop * InTileSizeH * InTileSizeW];
        TFloat p_wei_thread[1 * CPerBlockLoop * S * R];
        TFloat p_out_thread[1 * 1 * OutTileSizeH * OutTileSizeW];

        // stage the input window: all channels of image n_begin
        threadwise_4d_tensor_op<TFloat, InSrcDesc, InDstDesc, CopyOp>(
            in_thread_src_desc,
            p_in + in_desc.Get1dIndex(n_begin, 0, hi_begin, wi_begin),
            in_thread_dst_desc,
            p_in_thread,
            f_copy);

        // stage the weights: all channels of output channel k_begin
        threadwise_4d_tensor_op<TFloat, WeiSrcDesc, WeiDstDesc, CopyOp>(
            wei_thread_src_desc,
            p_wei + wei_desc.Get1dIndex(k_begin, 0, 0, 0),
            wei_thread_dst_desc,
            p_wei_thread,
            f_copy);

        // stage the current output tile; the convolution below adds onto it
        threadwise_4d_tensor_op<TFloat, OutSrcDesc, OutDstDesc, CopyOp>(
            out_thread_src_desc,
            p_out + out_desc.Get1dIndex(n_begin, k_begin, ho_begin, wo_begin),
            out_thread_dst_desc,
            p_out_thread,
            f_copy);

        // convolve the staged tiles
        threadwise_direct_convolution<TFloat, InDstDesc, WeiDstDesc, OutDstDesc>(
            in_thread_dst_desc,
            p_in_thread,
            wei_thread_dst_desc,
            p_wei_thread,
            out_thread_dst_desc,
            p_out_thread);

        // write the updated output tile back to where it was read from
        threadwise_4d_tensor_op<TFloat, OutDstDesc, OutSrcDesc, CopyOp>(
            out_thread_dst_desc,
            p_out_thread,
            out_thread_src_desc,
            p_out + out_desc.Get1dIndex(n_begin, k_begin, ho_begin, wo_begin),
            f_copy);
    }
}

// Kernel entry point of the direct convolution.
//
// Each thread block handles one output tile of size
//   NPerBlock x KPerBlock x (OutTileSizeH * YPerBlock) x (OutTileSizeW * XPerBlock).
// The tile is selected by splitting blockIdx.x into (n, k, y, x) work ids.
//
// The block walks over the input channels in steps of CPerBlockLoop. In every
// step it
//   1. stages the input, weight and output tiles in three statically sized
//      __shared__ arrays via blockwise_4d_tensor_op with a plain-assignment functor,
//   2. synchronises the block,
//   3. runs blockwise_convolution on the staged tiles,
//   4. synchronises the block again,
//   5. copies the output tile from shared memory back to p_out.
//
// The descriptor arguments are unnamed; only their types are used. GridSize is
// not referenced in this body.
//
// NOTE(review): only blockIdx.x is read here, so a 1-d grid is presumably
// expected - confirm against the host-side launcher.
// NOTE(review): the work counts are rounded up and no tail handling is visible
// here, so the tensor lengths are presumably expected to be multiples of the
// per-block tile sizes (and the channel count a multiple of CPerBlockLoop) -
// confirm against callers.
template <class TFloat,
          class InDesc,
          class WeiDesc,
          class OutDesc,
          unsigned NPerBlock,
          unsigned KPerBlock,
          unsigned CPerBlockLoop,
          unsigned OutTileSizeH,
          unsigned OutTileSizeW,
          unsigned YPerBlock,
          unsigned XPerBlock,
          unsigned NBlockCopyLen0,
          unsigned NBlockCopyLen1,
          unsigned NBlockCopyLen2,
          unsigned NBlockCopyLen3,
          unsigned BlockSize,
          unsigned GridSize>
__global__ void gridwise_convolution(InDesc,
                                     TFloat* const __restrict__ p_in,
                                     WeiDesc,
                                     TFloat* const __restrict__ p_wei,
                                     OutDesc,
                                     TFloat* __restrict__ p_out)
{
    constexpr auto I0 = Index<0>{};
    constexpr auto I1 = Index<1>{};
    constexpr auto I2 = Index<2>{};
    constexpr auto I3 = Index<3>{};

    constexpr auto in_desc  = InDesc{};
    constexpr auto wei_desc = WeiDesc{};
    constexpr auto out_desc = OutDesc{};

    // filter window: dimensions 2 and 3 of the weight tensor
    constexpr unsigned S = wei_desc.GetLength(I2);
    constexpr unsigned R = wei_desc.GetLength(I3);

    // spatial extent of the per-block output tile and of the input window it needs
    constexpr unsigned HoPerBlock = OutTileSizeH * YPerBlock;
    constexpr unsigned WoPerBlock = OutTileSizeW * XPerBlock;

    constexpr unsigned HiPerBlock = YPerBlock * OutTileSizeH + S - 1;
    constexpr unsigned WiPerBlock = XPerBlock * OutTileSizeW + R - 1;

    // number of per-block tiles along each output dimension (rounded up)
    constexpr unsigned NBlockWork = (out_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
    constexpr unsigned KBlockWork = (out_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
    constexpr unsigned YBlockWork = (out_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
    constexpr unsigned XBlockWork = (out_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;

    // "glb" descriptors: per-block tile lengths with the strides of the full tensors
    constexpr auto in_block_glb_desc = make_ConstantTensorDescriptor(
        Sequence<NPerBlock, CPerBlockLoop, HiPerBlock, WiPerBlock>{}, in_desc.GetStrides());

    constexpr auto wei_block_glb_desc = make_ConstantTensorDescriptor(
        Sequence<KPerBlock, CPerBlockLoop, S, R>{}, wei_desc.GetStrides());

    constexpr auto out_block_glb_desc = make_ConstantTensorDescriptor(
        Sequence<NPerBlock, KPerBlock, HoPerBlock, WoPerBlock>{}, out_desc.GetStrides());

    // "lds" descriptors: same lengths, built from the lengths alone; they describe
    // the __shared__ staging arrays below
    constexpr auto in_block_lds_desc =
        make_ConstantTensorDescriptor(in_block_glb_desc.GetLengths());
    constexpr auto wei_block_lds_desc =
        make_ConstantTensorDescriptor(wei_block_glb_desc.GetLengths());
    constexpr auto out_block_lds_desc =
        make_ConstantTensorDescriptor(out_block_glb_desc.GetLengths());

    using InGlbDesc  = decltype(in_block_glb_desc);
    using InLdsDesc  = decltype(in_block_lds_desc);
    using WeiGlbDesc = decltype(wei_block_glb_desc);
    using WeiLdsDesc = decltype(wei_block_lds_desc);
    using OutGlbDesc = decltype(out_block_glb_desc);
    using OutLdsDesc = decltype(out_block_lds_desc);

    constexpr unsigned in_block_size  = in_block_lds_desc.GetElementSize();
    constexpr unsigned wei_block_size = wei_block_lds_desc.GetElementSize();
    constexpr unsigned out_block_size = out_block_lds_desc.GetElementSize();

    __shared__ TFloat p_in_block[in_block_size];
    __shared__ TFloat p_wei_block[wei_block_size];
    __shared__ TFloat p_out_block[out_block_size];

    // split the linear block id into (n, k, y, x) work ids
    const unsigned block_id = blockIdx.x;

    constexpr unsigned YXBlockWork  = YBlockWork * XBlockWork;
    constexpr unsigned KYXBlockWork = KBlockWork * YXBlockWork;

    const unsigned n_block_work_id = block_id / KYXBlockWork;
    const unsigned within_n        = block_id % KYXBlockWork;
    const unsigned k_block_work_id = within_n / YXBlockWork;
    const unsigned within_k        = within_n % YXBlockWork;
    const unsigned y_block_work_id = within_k / XBlockWork;
    const unsigned x_block_work_id = within_k % XBlockWork;

    // first element of this block's tile along each dimension
    const unsigned n_block_work_begin = n_block_work_id * NPerBlock;
    const unsigned k_block_work_begin = k_block_work_id * KPerBlock;
    const unsigned y_block_work_begin = y_block_work_id * YPerBlock;
    const unsigned x_block_work_begin = x_block_work_id * XPerBlock;

    const unsigned ho_block_work_begin = y_block_work_begin * OutTileSizeH;
    const unsigned wo_block_work_begin = x_block_work_begin * OutTileSizeW;

    const unsigned hi_block_work_begin = ho_block_work_begin; // minus padding
    const unsigned wi_block_work_begin = wo_block_work_begin; // minus padding

#if 0
    if(threadIdx.x == 0)
    {
        print_ConstantTensorDescriptor( in_desc, "gridwise_convolution:  in_desc: ");
        print_ConstantTensorDescriptor(wei_desc, "gridwise_convolution: wei_desc: ");
        print_ConstantTensorDescriptor(out_desc, "gridwise_convolution: out_desc: ");
        print_ConstantTensorDescriptor( in_block_glb_desc, "gridwise_convolution:  in_block_glb_desc: ");
        print_ConstantTensorDescriptor(wei_block_glb_desc, "gridwise_convolution: wei_block_glb_desc: ");
        print_ConstantTensorDescriptor(out_block_glb_desc, "gridwise_convolution: out_block_glb_desc: ");
        print_ConstantTensorDescriptor( in_block_lds_desc, "gridwise_convolution:  in_block_lds_desc: ");
        print_ConstantTensorDescriptor(wei_block_lds_desc, "gridwise_convolution: wei_block_lds_desc: ");
        print_ConstantTensorDescriptor(out_block_lds_desc, "gridwise_convolution: out_block_lds_desc: ");

        printf("NBlockWork %u, KBlockWork %u, YBlockWork %u, XBlockWork %u \t"
               "block_id %u, n_block_work_id %u, k_block_work_id %u, y_block_work_id %u, "
               "x_block_work_id %u\n",
               NBlockWork,
               KBlockWork,
               YBlockWork,
               XBlockWork,
               block_id,
               n_block_work_id,
               k_block_work_id,
               y_block_work_id,
               x_block_work_id);
    }
#endif

    // element-wise functor handed to blockwise_4d_tensor_op: plain assignment
    auto f_copy = [](const TFloat& src, TFloat& dst) { dst = src; };
    using CopyOp = decltype(f_copy);

    // The output tile offset does not depend on the channel step, so it is
    // computed once and reused for both the load and the store below.
    const auto out_glb_offset = out_block_glb_desc.Get1dIndex(
        n_block_work_begin, k_block_work_begin, ho_block_work_begin, wo_block_work_begin);

    for(unsigned c_block_work_begin = 0; c_block_work_begin < in_desc.GetLength(I1);
        c_block_work_begin += CPerBlockLoop)
    {
        // stage the input tile for this channel step
        blockwise_4d_tensor_op<TFloat,
                               InGlbDesc,
                               InLdsDesc,
                               NBlockCopyLen0,
                               NBlockCopyLen1,
                               NBlockCopyLen2,
                               NBlockCopyLen3,
                               CopyOp,
                               BlockSize>(
            in_block_glb_desc,
            p_in + in_block_glb_desc.Get1dIndex(n_block_work_begin,
                                                c_block_work_begin,
                                                hi_block_work_begin,
                                                wi_block_work_begin),
            in_block_lds_desc,
            p_in_block,
            f_copy);

        // stage the weight tile for this channel step
        blockwise_4d_tensor_op<TFloat,
                               WeiGlbDesc,
                               WeiLdsDesc,
                               NBlockCopyLen0,
                               NBlockCopyLen1,
                               NBlockCopyLen2,
                               NBlockCopyLen3,
                               CopyOp,
                               BlockSize>(
            wei_block_glb_desc,
            p_wei + wei_block_glb_desc.Get1dIndex(k_block_work_begin, c_block_work_begin, 0, 0),
            wei_block_lds_desc,
            p_wei_block,
            f_copy);

        // stage the current output tile; the convolution below adds onto it
        blockwise_4d_tensor_op<TFloat,
                               OutGlbDesc,
                               OutLdsDesc,
                               NBlockCopyLen0,
                               NBlockCopyLen1,
                               NBlockCopyLen2,
                               NBlockCopyLen3,
                               CopyOp,
                               BlockSize>(
            out_block_glb_desc, p_out + out_glb_offset, out_block_lds_desc, p_out_block, f_copy);

        // all three shared tiles must be complete before any thread reads them
        __syncthreads();

        // convolve the staged tiles
        blockwise_convolution<TFloat,
                              InLdsDesc,
                              WeiLdsDesc,
                              OutLdsDesc,
                              OutTileSizeH,
                              OutTileSizeW,
                              BlockSize>(in_block_lds_desc,
                                         p_in_block,
                                         wei_block_lds_desc,
                                         p_wei_block,
                                         out_block_lds_desc,
                                         p_out_block);

        // every thread must have finished its part before the tile is written out
        __syncthreads();

        // copy the updated output tile from shared memory back to p_out
        blockwise_4d_tensor_op<TFloat,
                               OutLdsDesc,
                               OutGlbDesc,
                               NBlockCopyLen0,
                               NBlockCopyLen1,
                               NBlockCopyLen2,
                               NBlockCopyLen3,
                               CopyOp,
                               BlockSize>(
            out_block_lds_desc, p_out_block, out_block_glb_desc, p_out + out_glb_offset, f_copy);
    }
}