// blockwise_direct_convolution.cuh
#pragma once
#include "constant_tensor_descriptor.cuh"
#include "threadwise_tensor_op.cuh"
#include "threadwise_direct_convolution.cuh"

// Direct convolution over one block-sized problem, with the work divided among
// the threads that call this function.
//
// The three descriptor arguments are used only for their types; each one is
// default-constructed as a constexpr object inside the function. All three
// tensors are indexed with four coordinates:
//   input  : [N, C, Hi, Wi]  read through p_in_block
//   weight : [K, C, S,  R ]  read through p_wei_block
//   output : [N, K, Ho, Wo]  read and written through p_out_block
//
// Template parameters:
//   TFloat        element type of all three buffers
//   InBlockDesc   descriptor type of the input block tensor
//   WeiBlockDesc  descriptor type of the weight block tensor
//   OutBlockDesc  descriptor type of the output block tensor
//   OutTileSizeH  height of the output tile handled by one work item
//   OutTileSizeW  width of the output tile handled by one work item
//   BlockSize     stride used when a thread advances to its next work item
//
// Work decomposition: one work item is an (n, tile row, tile column) triple
// covering OutTileSizeH x OutTileSizeW output elements. A thread starts at
// work item threadIdx.x and advances by BlockSize. For each work item it loads
// the matching input window for all C channels once, then for every k it loads
// the weights of k and the current output tile, calls
// threadwise_direct_convolution on the thread-private copies, and stores the
// output tile back.
//
// NOTE(review): only threadIdx.x is consulted and the stride is BlockSize, so
// this presumably expects a 1d thread block of exactly BlockSize threads --
// confirm against the launch configuration.
// NOTE(review): no barrier is issued inside this function; any synchronization
// needed around the block buffers has to come from the caller.
// NOTE(review): YPerBlock / XPerBlock are rounded up, but nothing here clips a
// partial tile, so Ho / Wo are presumably required to be multiples of
// OutTileSizeH / OutTileSizeW -- confirm with callers.
template <class TFloat,
          class InBlockDesc,
          class WeiBlockDesc,
          class OutBlockDesc,
          unsigned OutTileSizeH,
          unsigned OutTileSizeW,
          unsigned BlockSize>
__device__ void blockwise_convolution(InBlockDesc,
                                      TFloat* const __restrict__ p_in_block,
                                      WeiBlockDesc,
                                      TFloat* const __restrict__ p_wei_block,
                                      OutBlockDesc,
                                      TFloat* __restrict__ p_out_block)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_block_desc  = InBlockDesc{};
    constexpr auto wei_block_desc = WeiBlockDesc{};
    constexpr auto out_block_desc = OutBlockDesc{};

    // filter window extents, taken from the last two weight dimensions
    constexpr unsigned S = wei_block_desc.GetLength(I2);
    constexpr unsigned R = wei_block_desc.GetLength(I3);

    constexpr unsigned NPerBlock = out_block_desc.GetLength(I0);
    constexpr unsigned KPerBlock = out_block_desc.GetLength(I1);

    // number of output tiles along H and W (rounded up)
    constexpr unsigned YPerBlock = (out_block_desc.GetLength(I2) + OutTileSizeH - 1) / OutTileSizeH;
    constexpr unsigned XPerBlock = (out_block_desc.GetLength(I3) + OutTileSizeW - 1) / OutTileSizeW;

    constexpr unsigned CPerBlock = in_block_desc.GetLength(I1);

    // input window needed to produce one output tile
    constexpr unsigned InTileSizeH = OutTileSizeH + S - 1;
    constexpr unsigned InTileSizeW = OutTileSizeW + R - 1;

    constexpr unsigned TilesPerImage = YPerBlock * XPerBlock;
    constexpr unsigned NumThreadWork = NPerBlock * TilesPerImage;

#if 0
    if(threadIdx.x == 0)
    {
        print_ConstantTensorDescriptor(in_block_desc);
        print_ConstantTensorDescriptor(wei_block_desc);
        print_ConstantTensorDescriptor(out_block_desc);
    }
#endif

    // "src" descriptors: per-thread tile lengths combined with the strides of
    // the block tensors, used to address a tile inside the block buffers
    constexpr auto in_thread_src_desc = make_ConstantTensorDescriptor(
        Sequence<1, CPerBlock, InTileSizeH, InTileSizeW>{}, in_block_desc.GetStrides());

    constexpr auto wei_thread_src_desc =
        make_ConstantTensorDescriptor(Sequence<1, CPerBlock, S, R>{}, wei_block_desc.GetStrides());

    constexpr auto out_thread_src_desc = make_ConstantTensorDescriptor(
        Sequence<1, 1, OutTileSizeH, OutTileSizeW>{}, out_block_desc.GetStrides());

    // "dst" descriptors: same lengths, built from lengths only; these address
    // the thread-private buffers
    constexpr auto in_thread_dst_desc =
        make_ConstantTensorDescriptor(in_thread_src_desc.GetLengths());

    constexpr auto wei_thread_dst_desc =
        make_ConstantTensorDescriptor(wei_thread_src_desc.GetLengths());

    constexpr auto out_thread_dst_desc =
        make_ConstantTensorDescriptor(out_thread_src_desc.GetLengths());

    const unsigned thread_id = threadIdx.x;

    for(unsigned thread_work_id = thread_id; thread_work_id < NumThreadWork;
        thread_work_id += BlockSize)
    {
        // decompose the flat work id into (n, tile row, tile column)
        unsigned itmp                   = thread_work_id;
        const unsigned n_thread_work_id = itmp / TilesPerImage;
        itmp -= n_thread_work_id * TilesPerImage;
        const unsigned y_thread_work_id = itmp / XPerBlock;
        const unsigned x_thread_work_id = itmp - y_thread_work_id * XPerBlock;

        const unsigned n_thread_work_begin  = n_thread_work_id * 1;
        const unsigned ho_thread_work_begin = y_thread_work_id * OutTileSizeH;
        const unsigned wo_thread_work_begin = x_thread_work_id * OutTileSizeW;

        // no padding is subtracted here, so the input window starts at the
        // same (h, w) as the output tile
        const unsigned hi_thread_work_begin = ho_thread_work_begin;
        const unsigned wi_thread_work_begin = wo_thread_work_begin;

        // Thread-private buffers are only ever addressed through the *_dst_desc
        // descriptors, so they are sized from those rather than from the
        // block-strided *_src_desc descriptors.
        TFloat p_in_thread[in_thread_dst_desc.GetElementSpace()];
        TFloat p_wei_thread[wei_thread_dst_desc.GetElementSpace()];
        TFloat p_out_thread[out_thread_dst_desc.GetElementSpace()];

        // copy the input window (all C channels) into the thread-private buffer
        threadwise_4d_tensor_copy(
            in_thread_src_desc,
            p_in_block + in_block_desc.Get1dIndex(
                             n_thread_work_begin, 0, hi_thread_work_begin, wi_thread_work_begin),
            in_thread_dst_desc,
            p_in_thread);

        for(unsigned k_thread_work_begin = 0; k_thread_work_begin < KPerBlock;
            ++k_thread_work_begin)
        {
            // location of this (n, k) output tile inside the block output
            // buffer; used for both the load and the store below
            TFloat* const p_out_tile =
                p_out_block + out_block_desc.Get1dIndex(n_thread_work_begin,
                                                        k_thread_work_begin,
                                                        ho_thread_work_begin,
                                                        wo_thread_work_begin);

            // copy the weights of output channel k into the thread-private buffer
            threadwise_4d_tensor_copy(wei_thread_src_desc,
                                      p_wei_block +
                                          wei_block_desc.Get1dIndex(k_thread_work_begin, 0, 0, 0),
                                      wei_thread_dst_desc,
                                      p_wei_thread);

            // copy the current output tile into the thread-private buffer
            threadwise_4d_tensor_copy(
                out_thread_src_desc, p_out_tile, out_thread_dst_desc, p_out_thread);

            // threadwise convolution on the thread-private copies
            // (presumably accumulates into p_out_thread, since the existing
            // output values are loaded first -- see threadwise_direct_convolution)
            threadwise_direct_convolution(in_thread_dst_desc,
                                          p_in_thread,
                                          wei_thread_dst_desc,
                                          p_wei_thread,
                                          out_thread_dst_desc,
                                          p_out_thread);

            // write the updated output tile back to the block output buffer
            threadwise_4d_tensor_copy(
                out_thread_dst_desc, p_out_thread, out_thread_src_desc, p_out_tile);
        }
    }
}