blockwise_direct_convolution.cuh 6.64 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
#pragma once
#include "constant_tensor_descriptor.cuh"
#include "threadwise_tensor_op.cuh"
Chao Liu's avatar
rename  
Chao Liu committed
4
#include "threadwise_direct_convolution.cuh"
Chao Liu's avatar
Chao Liu committed
5
6
7
8
9
10
11

// Direct convolution on block-level tensors, with the work split across the
// threads of one thread block.
//
// Tensors are addressed through 4-d descriptors, indexed by this function as:
//   input  : (n, c, hi, wi)
//   weight : (k, c, s,  r)   -- S = length of dim 2, R = length of dim 3
//   output : (n, k, ho, wo)
//
// The output block is cut into per-thread tiles of
//   NPerThread x KPerThread x OutTileSizeH x OutTileSizeW
// elements. Work item `id` is handled by the thread with
// threadIdx.x == id % BlockSize. For each work item a thread:
//   1. copies its output tile from p_out_block into a thread-local array,
//   2. for every chunk of CPerThread input channels, copies the matching
//      input tile ((OutTileSizeH + S - 1) x (OutTileSizeW + R - 1) spatially)
//      and weight tile into thread-local arrays and calls
//      threadwise_direct_convolution on them,
//   3. copies the thread-local output tile back to p_out_block.
// threadwise_direct_convolution is defined elsewhere; presumably it
// accumulates into the output tile (the tile is pre-loaded in step 1) --
// TODO confirm against its definition.
//
// Template parameters:
//   TFloat                      element type of all three tensors
//   InBlockDesc / WeiBlockDesc / OutBlockDesc
//                               compile-time descriptor types (lengths and
//                               strides) of the block tensors
//   OutTileSizeH, OutTileSizeW  spatial size of one per-thread output tile
//   NPerThread, KPerThread      n / k extent of one per-thread output tile
//   CPerThread                  input channels consumed per inner iteration
//   BlockSize                   stride of the work loop; must equal the number
//                               of threads that execute this function
//                               (threadIdx.x in [0, BlockSize)), otherwise
//                               work items are skipped or processed twice
//
// Function parameters (the descriptor arguments are tag objects, only their
// types are used):
//   p_in_block   input block tensor, read only
//   p_wei_block  weight block tensor, read only
//   p_out_block  output block tensor, read and then overwritten tile by tile
//
// Preconditions enforced at compile time: every output dimension is a multiple
// of its per-thread tile extent and the input channel count is a multiple of
// CPerThread. The per-thread copies always move full tiles, so a partial tail
// tile would access memory outside the block tensors.
//
// Not checked here (assumed, verify in the caller): the input block is large
// enough to hold the (ho + S - 1, wo + R - 1) halo of every output tile, and
// the n / c / k lengths of the three descriptors agree with each other.
//
// This function contains no barrier; any synchronization between filling the
// block tensors, calling this function, and reading p_out_block is the
// caller's responsibility.
template <class TFloat,
          class InBlockDesc,
          class WeiBlockDesc,
          class OutBlockDesc,
          unsigned OutTileSizeH,
          unsigned OutTileSizeW,
          unsigned NPerThread,
          unsigned KPerThread,
          unsigned CPerThread,
          unsigned BlockSize>
__device__ void blockwise_direct_convolution(InBlockDesc,
                                             TFloat* const __restrict__ p_in_block,
                                             WeiBlockDesc,
                                             TFloat* const __restrict__ p_wei_block,
                                             OutBlockDesc,
                                             TFloat* __restrict__ p_out_block)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_block_desc  = InBlockDesc{};
    constexpr auto wei_block_desc = WeiBlockDesc{};
    constexpr auto out_block_desc = OutBlockDesc{};

    // filter size
    constexpr unsigned S = wei_block_desc.GetLength(I2);
    constexpr unsigned R = wei_block_desc.GetLength(I3);

    // input footprint needed to produce one output tile
    constexpr unsigned InTileSizeH = OutTileSizeH + S - 1;
    constexpr unsigned InTileSizeW = OutTileSizeW + R - 1;

    // block-level extents used for the work split
    constexpr unsigned NPerBlock  = out_block_desc.GetLength(I0);
    constexpr unsigned KPerBlock  = out_block_desc.GetLength(I1);
    constexpr unsigned HoPerBlock = out_block_desc.GetLength(I2);
    constexpr unsigned WoPerBlock = out_block_desc.GetLength(I3);
    constexpr unsigned CPerBlock  = in_block_desc.GetLength(I1);

    // Every per-thread copy below moves a complete tile, so partial tail tiles
    // would read/write outside the block tensors. Reject them at compile time.
    static_assert(NPerBlock % NPerThread == 0,
                  "output N length must be a multiple of NPerThread");
    static_assert(KPerBlock % KPerThread == 0,
                  "output K length must be a multiple of KPerThread");
    static_assert(HoPerBlock % OutTileSizeH == 0,
                  "output H length must be a multiple of OutTileSizeH");
    static_assert(WoPerBlock % OutTileSizeW == 0,
                  "output W length must be a multiple of OutTileSizeW");
    static_assert(CPerBlock % CPerThread == 0,
                  "input C length must be a multiple of CPerThread");

    // divide thread work (exact divisions, guaranteed by the asserts above)
    constexpr unsigned NThreadWork = NPerBlock / NPerThread;
    constexpr unsigned KThreadWork = KPerBlock / KPerThread;
    constexpr unsigned YThreadWork = HoPerBlock / OutTileSizeH;
    constexpr unsigned XThreadWork = WoPerBlock / OutTileSizeW;

    constexpr unsigned TotalThreadWork = NThreadWork * KThreadWork * YThreadWork * XThreadWork;

#if 0
    if(threadIdx.x == 0)
    {
        print_ConstantTensorDescriptor(in_block_desc);
        print_ConstantTensorDescriptor(wei_block_desc);
        print_ConstantTensorDescriptor(out_block_desc);
    }
#endif

    // descriptors of the thread-local tiles
    constexpr auto in_thread_desc =
        make_ConstantTensorDescriptor(Sequence<NPerThread, CPerThread, InTileSizeH, InTileSizeW>{});

    constexpr auto wei_thread_desc =
        make_ConstantTensorDescriptor(Sequence<KPerThread, CPerThread, S, R>{});

    constexpr auto out_thread_desc = make_ConstantTensorDescriptor(
        Sequence<NPerThread, KPerThread, OutTileSizeH, OutTileSizeW>{});

    // same tile lengths, but with the strides of the block tensors: these
    // describe one thread's tile as it sits inside the block tensor
    constexpr auto in_thread_block_desc =
        make_ConstantTensorDescriptor(in_thread_desc.GetLengths(), in_block_desc.GetStrides());

    constexpr auto wei_thread_block_desc =
        make_ConstantTensorDescriptor(wei_thread_desc.GetLengths(), wei_block_desc.GetStrides());

    constexpr auto out_thread_block_desc =
        make_ConstantTensorDescriptor(out_thread_desc.GetLengths(), out_block_desc.GetStrides());

    const unsigned thread_id = threadIdx.x;

    for(unsigned thread_work_id = thread_id; thread_work_id < TotalThreadWork;
        thread_work_id += BlockSize)
    {
        // decompose the flat work id into (n, k, y, x) tile coordinates
        unsigned itmp = thread_work_id;

        const unsigned n_thread_work_id = itmp / (KThreadWork * YThreadWork * XThreadWork);
        itmp -= n_thread_work_id * (KThreadWork * YThreadWork * XThreadWork);

        const unsigned k_thread_work_id = itmp / (YThreadWork * XThreadWork);
        itmp -= k_thread_work_id * (YThreadWork * XThreadWork);

        const unsigned y_thread_work_id = itmp / XThreadWork;
        const unsigned x_thread_work_id = itmp - y_thread_work_id * XThreadWork;

        // origin of this thread's tile inside the block tensors
        const unsigned n_thread_data_begin  = n_thread_work_id * NPerThread;
        const unsigned k_thread_data_begin  = k_thread_work_id * KPerThread;
        const unsigned ho_thread_data_begin = y_thread_work_id * OutTileSizeH;
        const unsigned wo_thread_data_begin = x_thread_work_id * OutTileSizeW;

        // no padding offset is applied: the input tile starts at the same
        // spatial index as the output tile
        const unsigned hi_thread_data_begin = ho_thread_data_begin; // minus padding
        const unsigned wi_thread_data_begin = wo_thread_data_begin; // minus padding

        TFloat p_in_thread[in_thread_desc.GetElementSpace()];
        TFloat p_wei_thread[wei_thread_desc.GetElementSpace()];
        TFloat p_out_thread[out_thread_desc.GetElementSpace()];

        // the output tile is loaded and stored at the same location
        const auto out_thread_offset = out_block_desc.Get1dIndex(
            n_thread_data_begin, k_thread_data_begin, ho_thread_data_begin, wo_thread_data_begin);

        // copy the current output tile into the thread-local array
        threadwise_4d_tensor_copy(
            out_thread_block_desc, p_out_block + out_thread_offset, out_thread_desc, p_out_thread);

        for(unsigned c_thread_data_begin = 0; c_thread_data_begin < CPerBlock;
            c_thread_data_begin += CPerThread)
        {
            // copy input into register
            threadwise_4d_tensor_copy(in_thread_block_desc,
                                      p_in_block + in_block_desc.Get1dIndex(n_thread_data_begin,
                                                                            c_thread_data_begin,
                                                                            hi_thread_data_begin,
                                                                            wi_thread_data_begin),
                                      in_thread_desc,
                                      p_in_thread);

            // copy weight into register
            threadwise_4d_tensor_copy(
                wei_thread_block_desc,
                p_wei_block +
                    wei_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data_begin, 0, 0),
                wei_thread_desc,
                p_wei_thread);

            // threadwise convolution
            threadwise_direct_convolution(in_thread_desc,
                                          p_in_thread,
                                          wei_thread_desc,
                                          p_wei_thread,
                                          out_thread_desc,
                                          p_out_thread);
        }

        // copy output into LDS
        threadwise_4d_tensor_copy(
            out_thread_desc, p_out_thread, out_thread_block_desc, p_out_block + out_thread_offset);
    }
}