// blockwise_direct_convolution.hip.hpp
#pragma once

#include "ConstantTensorDescriptor.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "threadwise_direct_convolution.hip.hpp"

template <unsigned BlockSize,
Chao Liu's avatar
Chao Liu committed
7
          class Float,
Chao Liu's avatar
Chao Liu committed
8
9
10
          class InBlockDesc,
          class WeiBlockDesc,
          class OutBlockDesc,
Chao Liu's avatar
Chao Liu committed
11
12
          unsigned NPerThread,
          unsigned KPerThread,
Chao Liu's avatar
Chao Liu committed
13
14
15
          unsigned CPerThread,
          unsigned HoPerThread,
          unsigned WoPerThread>
Chao Liu's avatar
Chao Liu committed
16
__device__ void blockwise_direct_convolution(InBlockDesc,
Chao Liu's avatar
Chao Liu committed
17
                                             Float* const __restrict__ p_in_block,
Chao Liu's avatar
Chao Liu committed
18
                                             WeiBlockDesc,
Chao Liu's avatar
Chao Liu committed
19
                                             Float* const __restrict__ p_wei_block,
Chao Liu's avatar
Chao Liu committed
20
                                             OutBlockDesc,
Chao Liu's avatar
Chao Liu committed
21
                                             Float* __restrict__ p_out_block)
Chao Liu's avatar
Chao Liu committed
22
{
Chao Liu's avatar
Chao Liu committed
23
24
25
26
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
Chao Liu's avatar
Chao Liu committed
27
28
29
30
31

    constexpr auto in_block_desc  = InBlockDesc{};
    constexpr auto wei_block_desc = WeiBlockDesc{};
    constexpr auto out_block_desc = OutBlockDesc{};

Chao Liu's avatar
Chao Liu committed
32
33
    constexpr unsigned Y = wei_block_desc.GetLength(I2);
    constexpr unsigned X = wei_block_desc.GetLength(I3);
Chao Liu's avatar
Chao Liu committed
34

Chao Liu's avatar
Chao Liu committed
35
36
    constexpr unsigned InTileSizeH = HoPerThread + Y - 1;
    constexpr unsigned InTileSizeW = WoPerThread + X - 1;
Chao Liu's avatar
Chao Liu committed
37

Chao Liu's avatar
Chao Liu committed
38
39
40
    // divide thread work
    constexpr unsigned NThreadWork = (out_block_desc.GetLength(I0) + NPerThread - 1) / NPerThread;
    constexpr unsigned KThreadWork = (out_block_desc.GetLength(I1) + KPerThread - 1) / KPerThread;
Chao Liu's avatar
Chao Liu committed
41
42
    constexpr unsigned YThreadWork = (out_block_desc.GetLength(I2) + HoPerThread - 1) / HoPerThread;
    constexpr unsigned XThreadWork = (out_block_desc.GetLength(I3) + WoPerThread - 1) / WoPerThread;
Chao Liu's avatar
Chao Liu committed
43

Chao Liu's avatar
Chao Liu committed
44
45
46
47
48
49
50
51
52
#if 0
    if(threadIdx.x == 0)
    {
        print_ConstantTensorDescriptor(in_block_desc);
        print_ConstantTensorDescriptor(wei_block_desc);
        print_ConstantTensorDescriptor(out_block_desc);
    }
#endif

Chao Liu's avatar
Chao Liu committed
53
54
    constexpr auto in_thread_desc =
        make_ConstantTensorDescriptor(Sequence<NPerThread, CPerThread, InTileSizeH, InTileSizeW>{});
Chao Liu's avatar
Chao Liu committed
55

Chao Liu's avatar
Chao Liu committed
56
    constexpr auto wei_thread_desc =
Chao Liu's avatar
Chao Liu committed
57
        make_ConstantTensorDescriptor(Sequence<KPerThread, CPerThread, Y, X>{});
Chao Liu's avatar
Chao Liu committed
58

Chao Liu's avatar
Chao Liu committed
59
    constexpr auto out_thread_desc =
Chao Liu's avatar
Chao Liu committed
60
        get_convolution_output_default_4d_tensor_descriptor(in_thread_desc, wei_thread_desc);
Chao Liu's avatar
Chao Liu committed
61

Chao Liu's avatar
Chao Liu committed
62
63
    constexpr auto in_thread_block_desc =
        make_ConstantTensorDescriptor(in_thread_desc.GetLengths(), in_block_desc.GetStrides());
Chao Liu's avatar
Chao Liu committed
64

Chao Liu's avatar
Chao Liu committed
65
66
    constexpr auto wei_thread_block_desc =
        make_ConstantTensorDescriptor(wei_thread_desc.GetLengths(), wei_block_desc.GetStrides());
Chao Liu's avatar
Chao Liu committed
67

Chao Liu's avatar
Chao Liu committed
68
69
    constexpr auto out_thread_block_desc =
        make_ConstantTensorDescriptor(out_thread_desc.GetLengths(), out_block_desc.GetStrides());
Chao Liu's avatar
Chao Liu committed
70
71
72

    const unsigned thread_id = threadIdx.x;

Chao Liu's avatar
Chao Liu committed
73
74
    for(unsigned thread_work_id = thread_id;
        thread_work_id < NThreadWork * KThreadWork * YThreadWork * XThreadWork;
Chao Liu's avatar
Chao Liu committed
75
76
77
        thread_work_id += BlockSize)
    {
        unsigned itmp             = thread_work_id;
Chao Liu's avatar
Chao Liu committed
78
79
80
81
82
83
84
85
86
        unsigned n_thread_work_id = itmp / (KThreadWork * YThreadWork * XThreadWork);
        itmp -= n_thread_work_id * (KThreadWork * YThreadWork * XThreadWork);
        unsigned k_thread_work_id = itmp / (YThreadWork * XThreadWork);
        itmp -= k_thread_work_id * (YThreadWork * XThreadWork);
        unsigned y_thread_work_id = itmp / XThreadWork;
        unsigned x_thread_work_id = itmp - y_thread_work_id * XThreadWork;

        unsigned n_thread_data_begin  = n_thread_work_id * NPerThread;
        unsigned k_thread_data_begin  = k_thread_work_id * KPerThread;
Chao Liu's avatar
Chao Liu committed
87
88
        unsigned ho_thread_data_begin = y_thread_work_id * HoPerThread;
        unsigned wo_thread_data_begin = x_thread_work_id * WoPerThread;
Chao Liu's avatar
Chao Liu committed
89
90
91
92

        unsigned hi_thread_data_begin = ho_thread_data_begin; // minus padding
        unsigned wi_thread_data_begin = wo_thread_data_begin; // minus padding

Chao Liu's avatar
Chao Liu committed
93
        Float p_out_thread[out_thread_desc.GetElementSpace()];
Chao Liu's avatar
Chao Liu committed
94

Chao Liu's avatar
Chao Liu committed
95
        threadwise_4d_tensor_copy(out_block_desc,
Chao Liu's avatar
Chao Liu committed
96
97
98
99
100
                                  p_out_block +
                                      out_block_desc.Get1dIndex(n_thread_data_begin,
                                                                k_thread_data_begin,
                                                                ho_thread_data_begin,
                                                                wo_thread_data_begin),
Chao Liu's avatar
Chao Liu committed
101
                                  out_thread_desc,
Chao Liu's avatar
Chao Liu committed
102
                                  p_out_thread,
Chao Liu's avatar
Chao Liu committed
103
                                  out_thread_desc.GetLengths());
Chao Liu's avatar
Chao Liu committed
104
105
106

        for(unsigned c_thread_data_begin = 0; c_thread_data_begin < in_block_desc.GetLength(I1);
            c_thread_data_begin += CPerThread)
Chao Liu's avatar
Chao Liu committed
107
        {
Chao Liu's avatar
Chao Liu committed
108
109
110
            // threadwise convolution
            threadwise_direct_convolution_2(
                in_thread_block_desc,
Chao Liu's avatar
Chao Liu committed
111
112
113
114
115
                p_in_block +
                    in_block_desc.Get1dIndex(n_thread_data_begin,
                                             c_thread_data_begin,
                                             hi_thread_data_begin,
                                             wi_thread_data_begin),
Chao Liu's avatar
Chao Liu committed
116
117
118
                wei_thread_block_desc,
                p_wei_block +
                    wei_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data_begin, 0, 0),
Chao Liu's avatar
Chao Liu committed
119
120
                out_thread_desc,
                p_out_thread);
Chao Liu's avatar
Chao Liu committed
121
        }
Chao Liu's avatar
Chao Liu committed
122
123
124
125

        // copy output into LDS
        threadwise_4d_tensor_copy(out_thread_desc,
                                  p_out_thread,
Chao Liu's avatar
Chao Liu committed
126
                                  out_block_desc,
Chao Liu's avatar
Chao Liu committed
127
128
129
130
131
                                  p_out_block +
                                      out_block_desc.Get1dIndex(n_thread_data_begin,
                                                                k_thread_data_begin,
                                                                ho_thread_data_begin,
                                                                wo_thread_data_begin),
Chao Liu's avatar
Chao Liu committed
132
                                  out_thread_desc.GetLengths());
Chao Liu's avatar
Chao Liu committed
133
134
    }
}