// threadwise_nd_tensor_op.hip.hpp
#pragma once
#include "ConstantTensorDescriptor.hip.hpp"

// need to assume src and dst is aligned
Chao Liu's avatar
Chao Liu committed
5
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, index_t DataPerRead>
6
7
8
9
10
11
12
__device__ void threadwise_6d_tensor_copy(SrcDesc,
                                          const Float* __restrict__ p_src,
                                          DstDesc,
                                          Float* __restrict__ p_dst,
                                          SrcOpLengths,
                                          Number<DataPerRead>)
{
Chao Liu's avatar
Chao Liu committed
13
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39

    static_assert(SrcDesc{}.GetDimension() == 6 && DstDesc{}.GetDimension() == 6 &&
                      SrcOpLengths::nDim == 6,
                  "wrong! should be 6 dimension");

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});

    static_assert(SrcDesc{}.GetStride(I5) == 1 && DstDesc{}.GetStride(I5) == 1,
                  "wrong! only support stride5 == 1!\n");

    static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                  "wrong! only support DataPerRead == 1, 2 or 4!\n");

    static_assert(SrcDesc{}.GetStride(I4) % DataPerRead == 0 &&
                      DstDesc{}.GetStride(I4) % DataPerRead == 0,
                  "wrong! src and dst stride should be multiple of DataPerRead to keep alignment");

Chao Liu's avatar
Chao Liu committed
40
    constexpr index_t L5 = SrcOpLengths{}.Get(I5);
41
42
43

    static_assert(L5 % DataPerRead == 0, "wrong! L5 should be evenly divided by DataPerRead");

Chao Liu's avatar
Chao Liu committed
44
    constexpr index_t nloop_d5 = L5 / DataPerRead;
45

Chao Liu's avatar
Chao Liu committed
46
    for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
47
    {
Chao Liu's avatar
Chao Liu committed
48
        for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
49
        {
Chao Liu's avatar
Chao Liu committed
50
            for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
51
            {
Chao Liu's avatar
Chao Liu committed
52
                for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
53
                {
Chao Liu's avatar
Chao Liu committed
54
                    for(index_t did4 = 0; did4 < ref_desc.GetLength(I4); ++did4)
55
                    {
Chao Liu's avatar
Chao Liu committed
56
                        for(index_t iloop_d5 = 0; iloop_d5 < nloop_d5; ++iloop_d5)
57
                        {
Chao Liu's avatar
Chao Liu committed
58
                            const index_t src_index = src_desc.Get1dIndex(
59
60
                                did0, did1, did2, did3, did4, iloop_d5 * DataPerRead);

Chao Liu's avatar
Chao Liu committed
61
                            const index_t dst_index = dst_desc.Get1dIndex(
62
63
                                did0, did1, did2, did3, did4, iloop_d5 * DataPerRead);

Chao Liu's avatar
Chao Liu committed
64
65
                            *(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
                                *(reinterpret_cast<const vector_t*>(p_src + src_index));
66
67
68
69
70
71
72
73
74
                        }
                    }
                }
            }
        }
    }
}

// need to assume src and dst is aligned
Chao Liu's avatar
Chao Liu committed
75
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, index_t DataPerRead>
76
77
78
79
80
81
82
__device__ void threadwise_8d_tensor_copy(SrcDesc,
                                          const Float* __restrict__ p_src,
                                          DstDesc,
                                          Float* __restrict__ p_dst,
                                          SrcOpLengths,
                                          Number<DataPerRead>)
{
Chao Liu's avatar
Chao Liu committed
83
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111

    static_assert(SrcDesc{}.GetDimension() == 8 && DstDesc{}.GetDimension() == 8 &&
                      SrcOpLengths::nDim == 8,
                  "wrong! should be 8 dimension");

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};
    constexpr auto I6 = Number<6>{};
    constexpr auto I7 = Number<7>{};

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});

    static_assert(SrcDesc{}.GetStride(I7) == 1 && DstDesc{}.GetStride(I7) == 1,
                  "wrong! only support stride7 == 1!\n");

    static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                  "wrong! only support DataPerRead == 1, 2 or 4!\n");

    static_assert(SrcDesc{}.GetStride(I6) % DataPerRead == 0 &&
                      DstDesc{}.GetStride(I6) % DataPerRead == 0,
                  "wrong! src and dst stride should be multiple of DataPerRead to keep alignment");

Chao Liu's avatar
Chao Liu committed
112
    constexpr index_t L7 = SrcOpLengths{}.Get(I7);
113
114
115

    static_assert(L7 % DataPerRead == 0, "wrong! L7 should be evenly divided by DataPerRead");

Chao Liu's avatar
Chao Liu committed
116
    constexpr index_t nloop_d7 = L7 / DataPerRead;
117

Chao Liu's avatar
Chao Liu committed
118
    for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
119
    {
Chao Liu's avatar
Chao Liu committed
120
        for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
121
        {
Chao Liu's avatar
Chao Liu committed
122
            for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
123
            {
Chao Liu's avatar
Chao Liu committed
124
                for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
125
                {
Chao Liu's avatar
Chao Liu committed
126
                    for(index_t did4 = 0; did4 < ref_desc.GetLength(I4); ++did4)
127
                    {
Chao Liu's avatar
Chao Liu committed
128
                        for(index_t did5 = 0; did5 < ref_desc.GetLength(I5); ++did5)
129
                        {
Chao Liu's avatar
Chao Liu committed
130
                            for(index_t did6 = 0; did6 < ref_desc.GetLength(I6); ++did6)
131
                            {
Chao Liu's avatar
Chao Liu committed
132
                                for(index_t iloop_d7 = 0; iloop_d7 < nloop_d7; ++iloop_d7)
133
                                {
Chao Liu's avatar
Chao Liu committed
134
                                    const index_t src_index =
135
136
137
138
139
140
141
142
143
                                        src_desc.Get1dIndex(did0,
                                                            did1,
                                                            did2,
                                                            did3,
                                                            did4,
                                                            did5,
                                                            did6,
                                                            iloop_d7 * DataPerRead);

Chao Liu's avatar
Chao Liu committed
144
                                    const index_t dst_index =
145
146
147
148
149
150
151
152
153
                                        dst_desc.Get1dIndex(did0,
                                                            did1,
                                                            did2,
                                                            did3,
                                                            did4,
                                                            did5,
                                                            did6,
                                                            iloop_d7 * DataPerRead);

Chao Liu's avatar
Chao Liu committed
154
155
                                    *(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
                                        *(reinterpret_cast<const vector_t*>(p_src + src_index));
156
157
158
159
160
161
162
163
164
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}