// blockwise_2d_tensor_op.cuh
#pragma once
#include "ConstantTensorDescriptor.cuh"
// Apply a unary functor in place to every element of a 2D tensor, with the
// work shared by the threads of one block.
//
// Template parameters:
//   BlockSize - number of cooperating threads. The thread id is read from
//               threadIdx.x, so this assumes a 1D block with
//               blockDim.x == BlockSize (NOTE(review): not checked here; extra
//               threads would revisit elements or index past the tensor).
//   Float     - element type of p_dst.
//   DstDesc   - compile-time tensor descriptor type of the destination; only
//               its type is used (a default-constructed instance is queried).
//   F         - callable invoked as f(Float&) on each element.
//
// Linear ids 0 .. N-1 (N = element count of the lengths of DstDesc) are dealt
// out round-robin: thread t handles t, t + BlockSize, t + 2 * BlockSize, ...
// Each id is split into (did0, did1) and mapped to a memory offset through
// DstDesc::Get1dIndex, so the destination's own strides are honoured.
//
// No __syncthreads() is issued; callers must synchronise before other threads
// read the results.
template <unsigned BlockSize, class Float, class DstDesc, class F>
__device__ void
blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    constexpr auto dst_desc = DstDesc{};

    // Descriptor built from the destination's lengths only. It is used solely to
    // split a linear id into a 2D index. The decomposition below assumes its
    // default strides are packed row-major (NOTE(review): confirm against
    // make_ConstantTensorDescriptor).
    constexpr auto desc = make_ConstantTensorDescriptor(dst_desc.GetLengths());

#if 0
    if(threadIdx.x == 0)
    {
        print_ConstantTensorDescriptor(dst_desc, "blockwise_2d_tensor_op_unary: dst_desc: ");
        print_ConstantTensorDescriptor(desc, "blockwise_2d_tensor_op_unary: desc: ");
    }
#endif

    // Process the element whose linear id (in desc's order) is `is`.
    // Shared by the full rounds and the tail so the index math lives in one place.
    auto f_visit = [&](unsigned is) {
        const unsigned did0 = is / desc.GetStride(I0);

        is -= did0 * desc.GetStride(I0);

        const unsigned did1 = is / desc.GetStride(I1);

        const unsigned dindex = dst_desc.Get1dIndex(did0, did1);

        f(p_dst[dindex]);
    };

    // Number of complete rounds in which every thread has an element. It is a
    // compile-time constant, so the loop below has a fixed trip count.
    constexpr unsigned NLoop = desc.GetElementSize() / BlockSize;

    for(unsigned iloop = 0; iloop < NLoop; ++iloop)
    {
        f_visit(threadIdx.x + iloop * BlockSize);
    }

    // Partial last round when the element count is not a multiple of BlockSize.
    // Only the threads whose id still falls inside the tensor participate.
    constexpr bool has_tail = (desc.GetElementSize() > NLoop * BlockSize);

    if(has_tail)
    {
        const unsigned is = threadIdx.x + NLoop * BlockSize;

        if(is < desc.GetElementSize())
        {
            f_visit(is);
        }
    }
}
// Apply a binary functor f(src_elem, dst_elem) over a 2D window of size
// SrcOpLengths. Each source element is paired with the destination element
// whose index is a permutation of the source index:
//
//   for every i = (i0, i1) inside SrcOpLengths:
//     f(p_src[i0, i1], p_dst[i[DstFromSrcReorder[0]], i[DstFromSrcReorder[1]]])
//
// For example, Sequence<0, 1> pairs p_src[i0, i1] with p_dst[i0, i1], and
// Sequence<1, 0> pairs p_src[i0, i1] with p_dst[i1, i0] (a transpose).
// Each side is addressed through its own descriptor (SrcDesc / DstDesc); only
// the descriptor types are used.
//
// Work distribution: linear ids 0 .. N-1 of the window (N = element count of
// SrcOpLengths) are dealt out round-robin. Thread threadIdx.x handles
// threadIdx.x, threadIdx.x + BlockSize, ...
// NOTE(review): this assumes a 1D block with blockDim.x == BlockSize; not
// checked here. No __syncthreads() is issued.
//
// TODO: in order to optimize mem access for different mem type,
// need to write specialized version
template <unsigned BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class DstFromSrcReorder,
          class F>
__device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
    SrcDesc,
    Float* const __restrict__ p_src,
    DstDesc,
    Float* __restrict__ p_dst,
    SrcOpLengths,
    DstFromSrcReorder,
    F f)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    constexpr unsigned IR0 = DstFromSrcReorder{}.Get(I0);
    constexpr unsigned IR1 = DstFromSrcReorder{}.Get(I1);

    // did[] below has exactly two entries, so the reorder must be a permutation
    // of {0, 1}. An entry >= 2 would read past did[], and a repeated entry would
    // send several source elements to the same destination element.
    static_assert(IR0 < 2 && IR1 < 2 && IR0 != IR1,
                  "DstFromSrcReorder must be a permutation of {0, 1}");

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};

    // Descriptor of the operated window, built from its lengths only. It is used
    // to split a linear id into (i0, i1). The decomposition assumes its default
    // strides are packed row-major (NOTE(review): confirm against
    // make_ConstantTensorDescriptor).
    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});

    // Process the pair belonging to window linear id `is`. Shared by the full
    // rounds and the tail so the index math lives in one place.
    auto f_visit = [&](unsigned is) {
        unsigned did[2];

        did[0] = is / ref_desc.GetStride(I0);

        is -= did[0] * ref_desc.GetStride(I0);

        did[1] = is / ref_desc.GetStride(I1);

        const unsigned aindex = src_desc.Get1dIndex(did[0], did[1]);

        const unsigned bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);

        f(p_src[aindex], p_dst[bindex]);
    };

    // Number of complete rounds in which every thread has an element. It is a
    // compile-time constant, so the loop below has a fixed trip count.
    constexpr unsigned NLoop = ref_desc.GetElementSize() / BlockSize;

    for(unsigned iloop = 0; iloop < NLoop; ++iloop)
    {
        f_visit(threadIdx.x + iloop * BlockSize);
    }

    // Partial last round when the window size is not a multiple of BlockSize.
    constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);

    if(has_tail)
    {
        const unsigned is = threadIdx.x + NLoop * BlockSize;

        if(is < ref_desc.GetElementSize())
        {
            f_visit(is);
        }
    }
}
// Zero-fill a 2D tensor cooperatively across a thread block.
// Every element reached by blockwise_2d_tensor_pointwise_operation_unary is
// assigned Float(0); addressing goes through DstDesc.
template <unsigned BlockSize, class Float, class DstDesc>
__device__ void blockwise_2d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
{
    blockwise_2d_tensor_pointwise_operation_unary<BlockSize>(
        DstDesc{}, p_dst, [](Float& elem) { elem = Float(0); });
}
// Cooperative copy of an SrcOpLengths-sized 2D window from p_src to p_dst with
// a dimension reorder. Source element (i0, i1) is written to destination
// element (i[DstFromSrcReorder[0]], i[DstFromSrcReorder[1]]), where i = (i0, i1).
// Implemented as blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src
// driven by a plain assignment functor.
template <unsigned BlockSize,
          class Float,
          class SrcDesc,
          class DstDesc,
          class SrcOpLengths,
          class DstFromSrcReorder>
__device__ void
blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
                                                     Float* const __restrict__ p_src,
                                                     DstDesc,
                                                     Float* __restrict__ p_dst,
                                                     SrcOpLengths,
                                                     DstFromSrcReorder)
{
    blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src<BlockSize>(
        SrcDesc{},
        p_src,
        DstDesc{},
        p_dst,
        SrcOpLengths{},
        DstFromSrcReorder{},
        [](const Float& from, Float& to) { to = from; });
}
// Cooperative copy of an SrcOpLengths-sized 2D window with no dimension
// reorder: p_dst[i0, i1] = p_src[i0, i1]. Each side is addressed through its
// own descriptor, so source and destination may have different strides.
template <unsigned BlockSize, class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
__device__ void blockwise_2d_tensor_copy(
    SrcDesc, Float* const __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, SrcOpLengths)
{
    // Identity mapping: destination dimension d takes source index d.
    using IdentityReorder = Sequence<0, 1>;

    blockwise_2d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>(
        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, IdentityReorder{});
}