blockwise_convolution.cuh 7.08 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#pragma once
#include "constant_tensor_descriptor.cuh"
#include "threadwise_tensor_op.cuh"
#include "threadwise_convolution.cuh"

// Block-level direct convolution built from per-thread register tiles.
//
// Template parameters:
//   TFloat        - element type of the input, weight and output buffers.
//   InBlockDesc   - compile-time descriptor of the input block; dim 1 is read as the
//                   channel count C, dims 2/3 are indexed by input row/column.
//   WeiBlockDesc  - compile-time descriptor of the weight block; dim 0 is indexed by the
//                   output channel k, dim 1 spans C, dims 2/3 are the filter extents S x R.
//   OutBlockDesc  - compile-time descriptor of the output block; dims are read as
//                   (N, K, Ho, Wo).
//   OutTileSizeH/OutTileSizeW - output tile computed by one work item.
//   BlockSize     - number of cooperating threads; work ids advance by this amount.
//
// Work decomposition: every (n, tile-row, tile-col) triple is one work item. Thread
// threadIdx.x handles work ids threadIdx.x, threadIdx.x + BlockSize, ... For each item
// it copies a (1, C, OutTileSizeH+S-1, OutTileSizeW+R-1) input patch into registers once.
// Then, for every output channel k, it loads the weights and the current output tile,
// runs threadwise_direct_convolution, and stores the tile back.
//
// The input origin equals the output origin and the input tile is OutTile + filter - 1,
// i.e. unit stride and no padding are assumed.
//
// NOTE(review): assumes a 1-D thread block of exactly BlockSize threads (only
// threadIdx.x is read) - confirm at the call sites.
// NOTE(review): YPerBlock/XPerBlock use ceil-division but tiles are not clipped. If Ho or
// Wo is not a multiple of OutTileSizeH/OutTileSizeW, the last tile touches elements past
// the block's extent. Callers presumably use evenly divisible sizes - confirm.
// This function contains no barrier. Distinct work items write disjoint output tiles,
// but the caller must synchronize before other threads consume p_out_block.
template <class TFloat,
          class InBlockDesc,
          class WeiBlockDesc,
          class OutBlockDesc,
          unsigned OutTileSizeH,
          unsigned OutTileSizeW,
          unsigned BlockSize>
__device__ void blockwise_convolution(InBlockDesc,
                                      TFloat* const __restrict__ p_in_block,
                                      WeiBlockDesc,
                                      TFloat* const __restrict__ p_wei_block,
                                      OutBlockDesc,
                                      TFloat* __restrict__ p_out_block)
{
    constexpr auto I0 = Index<0>{};
    constexpr auto I1 = Index<1>{};
    constexpr auto I2 = Index<2>{};
    constexpr auto I3 = Index<3>{};

    constexpr auto in_block_desc  = InBlockDesc{};
    constexpr auto wei_block_desc = WeiBlockDesc{};
    constexpr auto out_block_desc = OutBlockDesc{};

    // filter extents
    constexpr unsigned S = wei_block_desc.GetLength(I2);
    constexpr unsigned R = wei_block_desc.GetLength(I3);

    // number of work items along each output dimension
    constexpr unsigned NPerBlock = out_block_desc.GetLength(I0);
    constexpr unsigned KPerBlock = out_block_desc.GetLength(I1);
    constexpr unsigned YPerBlock = (out_block_desc.GetLength(I2) + OutTileSizeH - 1) / OutTileSizeH;
    constexpr unsigned XPerBlock = (out_block_desc.GetLength(I3) + OutTileSizeW - 1) / OutTileSizeW;

    constexpr unsigned CPerBlock = in_block_desc.GetLength(I1);

    // input patch needed to produce one output tile (unit stride)
    constexpr unsigned InTileSizeH = OutTileSizeH + S - 1;
    constexpr unsigned InTileSizeW = OutTileSizeW + R - 1;

#if 0
    if(threadIdx.x == 0)
    {
        print_ConstantTensorDescriptor(in_block_desc);
        print_ConstantTensorDescriptor(wei_block_desc);
        print_ConstantTensorDescriptor(out_block_desc);
    }
#endif

    // Views of one thread's tile inside the block buffers (block strides).
    constexpr auto in_thread_src_desc = make_ConstantTensorDescriptor(
        Sequence<1, CPerBlock, InTileSizeH, InTileSizeW>{}, in_block_desc.GetStrides());

    constexpr auto wei_thread_src_desc =
        make_ConstantTensorDescriptor(Sequence<1, CPerBlock, S, R>{}, wei_block_desc.GetStrides());

    constexpr auto out_thread_src_desc = make_ConstantTensorDescriptor(
        Sequence<1, 1, OutTileSizeH, OutTileSizeW>{}, out_block_desc.GetStrides());

    // Layouts of the thread-private copies (built from lengths only).
    constexpr auto in_thread_dst_desc =
        make_ConstantTensorDescriptor(in_thread_src_desc.GetLengths());

    constexpr auto wei_thread_dst_desc =
        make_ConstantTensorDescriptor(wei_thread_src_desc.GetLengths());

    constexpr auto out_thread_dst_desc =
        make_ConstantTensorDescriptor(out_thread_src_desc.GetLengths());

    // Element-wise copy used for every block <-> thread transfer; loop-invariant.
    auto f_copy = [](const TFloat& src, TFloat& dst) { dst = src; };

    const unsigned thread_id = threadIdx.x;

    for(unsigned thread_work_id = thread_id; thread_work_id < NPerBlock * YPerBlock * XPerBlock;
        thread_work_id += BlockSize)
    {
        // decompose the flat work id into (n, tile-row, tile-col)
        unsigned itmp             = thread_work_id;
        unsigned n_thread_work_id = itmp / (YPerBlock * XPerBlock);
        itmp -= n_thread_work_id * (YPerBlock * XPerBlock);
        unsigned y_thread_work_id = itmp / XPerBlock;
        unsigned x_thread_work_id = itmp - y_thread_work_id * XPerBlock;

        unsigned n_thread_work_begin  = n_thread_work_id * 1;
        unsigned ho_thread_work_begin = y_thread_work_id * OutTileSizeH;
        unsigned wo_thread_work_begin = x_thread_work_id * OutTileSizeW;

        unsigned hi_thread_work_begin = ho_thread_work_begin; // minus padding
        unsigned wi_thread_work_begin = wo_thread_work_begin; // minus padding

        // Thread-private buffers. Every access to them goes through the *_dst_desc
        // descriptors, so size them by those. The *_src_desc descriptors carry the
        // block strides, so their element space covers the whole strided footprint and
        // over-allocates registers / local memory.
        TFloat p_in_thread[in_thread_dst_desc.GetElementSpace()];
        TFloat p_wei_thread[wei_thread_dst_desc.GetElementSpace()];
        TFloat p_out_thread[out_thread_dst_desc.GetElementSpace()];

        // copy input patch into registers; reused for every output channel k
        threadwise_4d_tensor_op_binary<TFloat,
                                       decltype(in_thread_src_desc),
                                       decltype(in_thread_dst_desc),
                                       decltype(f_copy)>(
            in_thread_src_desc,
            p_in_block + in_block_desc.Get1dIndex(
                           n_thread_work_begin, 0, hi_thread_work_begin, wi_thread_work_begin),
            in_thread_dst_desc,
            p_in_thread,
            f_copy);

        for(unsigned k_thread_work_begin = 0; k_thread_work_begin < KPerBlock;
            ++k_thread_work_begin)
        {
            // location of this thread's output tile for channel k (used for load and store)
            const auto out_block_offset = out_block_desc.Get1dIndex(n_thread_work_begin,
                                                                    k_thread_work_begin,
                                                                    ho_thread_work_begin,
                                                                    wo_thread_work_begin);

            // copy weights of output channel k into registers
            threadwise_4d_tensor_op_binary<TFloat,
                                           decltype(wei_thread_src_desc),
                                           decltype(wei_thread_dst_desc),
                                           decltype(f_copy)>(
                wei_thread_src_desc,
                p_wei_block + wei_block_desc.Get1dIndex(k_thread_work_begin, 0, 0, 0),
                wei_thread_dst_desc,
                p_wei_thread,
                f_copy);

            // copy current output tile into registers; presumably
            // threadwise_direct_convolution accumulates into it - confirm
            threadwise_4d_tensor_op_binary<TFloat,
                                           decltype(out_thread_src_desc),
                                           decltype(out_thread_dst_desc),
                                           decltype(f_copy)>(
                out_thread_src_desc,
                p_out_block + out_block_offset,
                out_thread_dst_desc,
                p_out_thread,
                f_copy);

            // threadwise convolution on the register tiles
            threadwise_direct_convolution<TFloat,
                                          decltype(in_thread_dst_desc),
                                          decltype(wei_thread_dst_desc),
                                          decltype(out_thread_dst_desc)>(in_thread_dst_desc,
                                                                         p_in_thread,
                                                                         wei_thread_dst_desc,
                                                                         p_wei_thread,
                                                                         out_thread_dst_desc,
                                                                         p_out_thread);

            // store the updated output tile back into the block buffer (plain copy)
            threadwise_4d_tensor_op_binary<TFloat,
                                           decltype(out_thread_dst_desc),
                                           decltype(out_thread_src_desc),
                                           decltype(f_copy)>(
                out_thread_dst_desc,
                p_out_thread,
                out_thread_src_desc,
                p_out_block + out_block_offset,
                f_copy);
        }
    }
}