gridwise_direct_convolution_1.hip.hpp 6.59 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
#pragma once
2
3
4
5
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"
#include "blockwise_4d_tensor_op.hip.hpp"
#include "blockwise_direct_convolution.hip.hpp"
Chao Liu's avatar
Chao Liu committed
6

Chao Liu's avatar
Chao Liu committed
7
template <class Float,
Chao Liu's avatar
Chao Liu committed
8
9
10
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
Chao Liu's avatar
Chao Liu committed
11
12
13
14
15
16
17
18
19
20
21
22
          index_t NPerBlock,
          index_t KPerBlock,
          index_t CPerBlock,
          index_t HoPerBlock,
          index_t WoPerBlock,
          index_t NPerThread,
          index_t KPerThread,
          index_t CPerThread,
          index_t HoPerThread,
          index_t WoPerThread,
          index_t BlockSize,
          index_t GridSize>
Chao Liu's avatar
Chao Liu committed
23
24
25
__global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_in_global,
                                              const Float* const __restrict__ p_wei_global,
                                              Float* const __restrict__ p_out_global)
Chao Liu's avatar
Chao Liu committed
26
{
Chao Liu's avatar
Chao Liu committed
27
28
29
30
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
Chao Liu's avatar
Chao Liu committed
31

Chao Liu's avatar
Chao Liu committed
32
33
34
    constexpr auto in_global_desc  = InGlobalDesc{};
    constexpr auto wei_global_desc = WeiGlobalDesc{};
    constexpr auto out_global_desc = OutGlobalDesc{};
Chao Liu's avatar
Chao Liu committed
35

Chao Liu's avatar
Chao Liu committed
36
37
    constexpr index_t Y = wei_global_desc.GetLength(I2);
    constexpr index_t X = wei_global_desc.GetLength(I3);
Chao Liu's avatar
Chao Liu committed
38

Chao Liu's avatar
Chao Liu committed
39
40
    constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
    constexpr index_t WiPerBlock = WoPerBlock + X - 1;
Chao Liu's avatar
Chao Liu committed
41

Chao Liu's avatar
Chao Liu committed
42
43
44
45
    constexpr index_t NBlockWork = (out_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock;
    constexpr index_t KBlockWork = (out_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock;
    constexpr index_t HBlockWork = (out_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock;
    constexpr index_t WBlockWork = (out_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock;
Chao Liu's avatar
Chao Liu committed
46

Chao Liu's avatar
Chao Liu committed
47
    constexpr auto in_block_global_desc = make_ConstantTensorDescriptor(
Chao Liu's avatar
Chao Liu committed
48
        Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{}, in_global_desc.GetStrides());
Chao Liu's avatar
Chao Liu committed
49

Chao Liu's avatar
Chao Liu committed
50
    constexpr auto wei_block_global_desc = make_ConstantTensorDescriptor(
Chao Liu's avatar
Chao Liu committed
51
        Sequence<KPerBlock, CPerBlock, Y, X>{}, wei_global_desc.GetStrides());
Chao Liu's avatar
Chao Liu committed
52

Chao Liu's avatar
Chao Liu committed
53
    constexpr auto out_block_global_desc = make_ConstantTensorDescriptor(
Chao Liu's avatar
Chao Liu committed
54
        Sequence<NPerBlock, KPerBlock, HoPerBlock, WoPerBlock>{}, out_global_desc.GetStrides());
Chao Liu's avatar
Chao Liu committed
55

Chao Liu's avatar
Chao Liu committed
56
57
58
59
60
    constexpr auto in_block_desc = make_ConstantTensorDescriptor(in_block_global_desc.GetLengths());
    constexpr auto wei_block_desc =
        make_ConstantTensorDescriptor(wei_block_global_desc.GetLengths());
    constexpr auto out_block_desc =
        make_ConstantTensorDescriptor(out_block_global_desc.GetLengths());
Chao Liu's avatar
Chao Liu committed
61

Chao Liu's avatar
Chao Liu committed
62
63
64
    constexpr index_t in_block_element_size  = in_block_desc.GetElementSpace();
    constexpr index_t wei_block_element_size = wei_block_desc.GetElementSpace();
    constexpr index_t out_block_size         = out_block_desc.GetElementSpace();
Chao Liu's avatar
Chao Liu committed
65

Chao Liu's avatar
Chao Liu committed
66
67
    __shared__ Float p_in_block[in_block_element_size];
    __shared__ Float p_wei_block[wei_block_element_size];
Chao Liu's avatar
Chao Liu committed
68
    __shared__ Float p_out_block[out_block_size];
Chao Liu's avatar
Chao Liu committed
69

Chao Liu's avatar
Chao Liu committed
70
    const index_t block_id = blockIdx.x;
Chao Liu's avatar
Chao Liu committed
71

Chao Liu's avatar
Chao Liu committed
72
73
    index_t itmp            = block_id;
    index_t n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork);
Chao Liu's avatar
Chao Liu committed
74
    itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork);
Chao Liu's avatar
Chao Liu committed
75
    index_t k_block_work_id = itmp / (HBlockWork * WBlockWork);
Chao Liu's avatar
Chao Liu committed
76
    itmp -= k_block_work_id * (HBlockWork * WBlockWork);
Chao Liu's avatar
Chao Liu committed
77
78
    index_t h_block_work_id = itmp / WBlockWork;
    index_t w_block_work_id = itmp - h_block_work_id * WBlockWork;
Chao Liu's avatar
Chao Liu committed
79

Chao Liu's avatar
Chao Liu committed
80
81
82
83
    index_t n_block_work_begin  = n_block_work_id * NPerBlock;
    index_t k_block_work_begin  = k_block_work_id * KPerBlock;
    index_t ho_block_work_begin = h_block_work_id * HoPerBlock;
    index_t wo_block_work_begin = w_block_work_id * WoPerBlock;
Chao Liu's avatar
Chao Liu committed
84

Chao Liu's avatar
Chao Liu committed
85
86
    index_t hi_block_work_begin = ho_block_work_begin; // minus padding
    index_t wi_block_work_begin = wo_block_work_begin; // minus padding
Chao Liu's avatar
Chao Liu committed
87

Chao Liu's avatar
Chao Liu committed
88
    constexpr auto blockwise_in_copy =
89
90
91
92
93
        Blockwise4dTensorCopy1<BlockSize,
                               Float,
                               decltype(in_block_global_desc),
                               decltype(in_block_desc),
                               decltype(in_block_desc.GetLengths())>{};
Chao Liu's avatar
Chao Liu committed
94
95

    constexpr auto blockwise_wei_copy =
96
97
98
99
100
        Blockwise4dTensorCopy1<BlockSize,
                               Float,
                               decltype(wei_block_global_desc),
                               decltype(wei_block_desc),
                               decltype(wei_block_desc.GetLengths())>{};
Chao Liu's avatar
Chao Liu committed
101
102

    constexpr auto blockwise_out_copy =
103
104
105
106
107
        Blockwise4dTensorCopy1<BlockSize,
                               Float,
                               decltype(out_block_desc),
                               decltype(out_block_global_desc),
                               decltype(out_block_desc.GetLengths())>{};
Chao Liu's avatar
Chao Liu committed
108

Chao Liu's avatar
Chao Liu committed
109
    // set output tensor in LDS to 0
Chao Liu's avatar
Chao Liu committed
110
    blockwise_4d_tensor_set_zero<BlockSize>(out_block_desc, p_out_block);
Chao Liu's avatar
faster  
Chao Liu committed
111

Chao Liu's avatar
Chao Liu committed
112
    for(index_t c_block_work_begin = 0; c_block_work_begin < in_global_desc.GetLength(I1);
Chao Liu's avatar
Chao Liu committed
113
        c_block_work_begin += CPerBlock)
Chao Liu's avatar
Chao Liu committed
114
115
    {
        // copy input tensor to LDS
Chao Liu's avatar
tidy up  
Chao Liu committed
116
117
118
119
120
        blockwise_in_copy.Run(p_in_global +
                                  in_global_desc.Get1dIndex(n_block_work_begin,
                                                            c_block_work_begin,
                                                            hi_block_work_begin,
                                                            wi_block_work_begin),
Chao Liu's avatar
Chao Liu committed
121
                              p_in_block);
Chao Liu's avatar
Chao Liu committed
122
123

        // copy weight tensor to LDS
124
        blockwise_wei_copy.Run(
Chao Liu's avatar
Chao Liu committed
125
            p_wei_global + wei_global_desc.Get1dIndex(k_block_work_begin, c_block_work_begin, 0, 0),
Chao Liu's avatar
Chao Liu committed
126
            p_wei_block);
Chao Liu's avatar
Chao Liu committed
127
128
129
130

        __syncthreads();

        // blockwise convolution
Chao Liu's avatar
Chao Liu committed
131
        blockwise_direct_convolution<BlockSize,
Chao Liu's avatar
Chao Liu committed
132
                                     Float,
Chao Liu's avatar
Chao Liu committed
133
134
135
136
137
                                     decltype(in_block_desc),
                                     decltype(wei_block_desc),
                                     decltype(out_block_desc),
                                     NPerThread,
                                     KPerThread,
Chao Liu's avatar
Chao Liu committed
138
139
140
                                     CPerThread,
                                     HoPerThread,
                                     WoPerThread>(
Chao Liu's avatar
Chao Liu committed
141
            in_block_desc, p_in_block, wei_block_desc, p_wei_block, out_block_desc, p_out_block);
Chao Liu's avatar
Chao Liu committed
142
143

        __syncthreads();
Chao Liu's avatar
faster  
Chao Liu committed
144
    }
Chao Liu's avatar
Chao Liu committed
145

Chao Liu's avatar
faster  
Chao Liu committed
146
    // copy output tensor from LDS to device mem
Chao Liu's avatar
tidy up  
Chao Liu committed
147
148
149
150
151
    blockwise_out_copy.Run(
        p_out_block,
        p_out_global +
            out_global_desc.Get1dIndex(
                n_block_work_begin, k_block_work_begin, ho_block_work_begin, wo_block_work_begin));
Chao Liu's avatar
Chao Liu committed
152
}