#pragma once
#include "common.hip.hpp"
#include "threadwise_gemm.hip.hpp"

// If the following numbers are powers of 2, the index calculation is greatly reduced:
//    MPerThreadSubC, NPerThreadSubC, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster
template <index_t BlockSize,
Chao Liu's avatar
Chao Liu committed
8
9
          class BlockMatrixA,
          class BlockMatrixB,
Chao Liu's avatar
Chao Liu committed
10
          class ThreadMatrixC,
Chao Liu's avatar
Chao Liu committed
11
12
13
14
15
16
          index_t MPerThreadSubC,
          index_t NPerThreadSubC,
          index_t MLevel0Cluster,
          index_t NLevel0Cluster,
          index_t MLevel1Cluster,
          index_t NLevel1Cluster,
Chao Liu's avatar
Chao Liu committed
17
18
19
          index_t KPerThreadLoop,
          index_t DataPerReadA,
          index_t DataPerReadB>
Chao Liu's avatar
Chao Liu committed
20
21
22
23
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
{
    struct MatrixIndex
    {
Chao Liu's avatar
Chao Liu committed
24
25
        index_t row;
        index_t col;
Chao Liu's avatar
Chao Liu committed
26
27
    };

Chao Liu's avatar
Chao Liu committed
28
29
    index_t mMyThreadOffsetA;
    index_t mMyThreadOffsetB;
Chao Liu's avatar
Chao Liu committed
30
31
32

    __device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2()
    {
Chao Liu's avatar
Chao Liu committed
33
        constexpr index_t ThreadPerLevel1Cluster =
Chao Liu's avatar
Chao Liu committed
34
35
36
37
            MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;

        static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");

Chao Liu's avatar
Chao Liu committed
38
        static_assert(BlockMatrixA::NRow() == BlockMatrixB::NRow(),
Chao Liu's avatar
Chao Liu committed
39
40
                      "wrong! K dimension not consistent\n");

Chao Liu's avatar
Chao Liu committed
41
42
43
        constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
        constexpr index_t N = BlockMatrixB::NCol();
        constexpr index_t K = BlockMatrixA::NRow();
Chao Liu's avatar
Chao Liu committed
44

Chao Liu's avatar
Chao Liu committed
45
46
47
        static_assert(M % (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster) == 0 &&
                          N % (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster) == 0,
                      "wrong! Cannot evenly divide work among\n");
Chao Liu's avatar
Chao Liu committed
48

Chao Liu's avatar
Chao Liu committed
49
        static_assert(is_same_type(ThreadMatrixC::GetLengths(), GetThreadMatrixCLengths()),
Chao Liu's avatar
Chao Liu committed
50
                      "wrong! ThreadMatrixC lengths is wrong");
Chao Liu's avatar
Chao Liu committed
51

Chao Liu's avatar
Chao Liu committed
52
        auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
Chao Liu's avatar
Chao Liu committed
53

54
55
        mMyThreadOffsetA = BlockMatrixA::GetOffsetFromMultiIndex(0, c_thread_mtx_index.row);
        mMyThreadOffsetB = BlockMatrixB::GetOffsetFromMultiIndex(0, c_thread_mtx_index.col);
Chao Liu's avatar
Chao Liu committed
56
    }
Chao Liu's avatar
Chao Liu committed
57

Chao Liu's avatar
Chao Liu committed
58
    __device__ static constexpr auto GetThreadMatrixCLengths()
Chao Liu's avatar
Chao Liu committed
59
60
61
    {
        constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
        constexpr index_t N = BlockMatrixB::NCol();
Chao Liu's avatar
Chao Liu committed
62

Chao Liu's avatar
Chao Liu committed
63
64
        constexpr index_t MRepeat = M / (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster);
        constexpr index_t NRepeat = N / (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster);
Chao Liu's avatar
Chao Liu committed
65

Chao Liu's avatar
Chao Liu committed
66
        return Sequence<MRepeat * MPerThreadSubC, NRepeat * NPerThreadSubC>{};
Chao Liu's avatar
Chao Liu committed
67
68
    }

Chao Liu's avatar
Chao Liu committed
69
    __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
Chao Liu's avatar
Chao Liu committed
70
    {
Chao Liu's avatar
Chao Liu committed
71
        constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
Chao Liu's avatar
Chao Liu committed
72

Chao Liu's avatar
Chao Liu committed
73
74
75
        index_t level1_id   = thread_id / ThreadPerLevel0Cluster;
        index_t level1_m_id = level1_id / NLevel1Cluster;
        index_t level1_n_id = level1_id % NLevel1Cluster;
Chao Liu's avatar
Chao Liu committed
76

Chao Liu's avatar
Chao Liu committed
77
78
79
        index_t level0_id   = thread_id % ThreadPerLevel0Cluster;
        index_t level0_m_id = level0_id / NLevel0Cluster;
        index_t level0_n_id = level0_id % NLevel0Cluster;
Chao Liu's avatar
Chao Liu committed
80

Chao Liu's avatar
Chao Liu committed
81
82
        constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
        constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
Chao Liu's avatar
Chao Liu committed
83
84
85
86
87

        return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
                           level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
    }

Chao Liu's avatar
Chao Liu committed
88
89
    __device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c,
                                                                      index_t n_in_c)
Chao Liu's avatar
Chao Liu committed
90
    {
Chao Liu's avatar
Chao Liu committed
91
        constexpr auto c_thread_mtx = ThreadMatrixC{};
Chao Liu's avatar
Chao Liu committed
92

Chao Liu's avatar
Chao Liu committed
93
94
        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();
Chao Liu's avatar
Chao Liu committed
95

Chao Liu's avatar
Chao Liu committed
96
97
        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
Chao Liu's avatar
Chao Liu committed
98

Chao Liu's avatar
Chao Liu committed
99
100
        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
Chao Liu's avatar
Chao Liu committed
101

Chao Liu's avatar
Chao Liu committed
102
103
        index_t m_repeat = m_in_c / MPerThreadSubC;
        index_t n_repeat = n_in_c / NPerThreadSubC;
Chao Liu's avatar
Chao Liu committed
104

Chao Liu's avatar
Chao Liu committed
105
106
        index_t m_in_sub_c = m_in_c % MPerThreadSubC;
        index_t n_in_sub_c = n_in_c % NPerThreadSubC;
Chao Liu's avatar
Chao Liu committed
107
108
109
110
111

        return MatrixIndex{m_repeat * MPerLevel1Cluster + m_in_sub_c,
                           n_repeat * NPerLevel1Cluster + n_in_sub_c};
    }

Chao Liu's avatar
Chao Liu committed
112
#if USE_AMD_INLINE_ASM
113
    // TODO: this is not working correctly
114
    template <class FloatA, class FloatB, class FloatC>
Chao Liu's avatar
Chao Liu committed
115
116
    __device__ void Run_asm(const FloatA* __restrict__ p_a_block,
                            const FloatB* __restrict__ p_b_block,
117
                            FloatC* __restrict__ p_c_thread) const
Chao Liu's avatar
Chao Liu committed
118
    {
Chao Liu's avatar
Chao Liu committed
119
120
        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};
Chao Liu's avatar
Chao Liu committed
121

Chao Liu's avatar
Chao Liu committed
122
123
124
        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};
Chao Liu's avatar
Chao Liu committed
125

Chao Liu's avatar
Chao Liu committed
126
127
128
        constexpr index_t M = a_block_mtx.NCol();
        constexpr index_t N = b_block_mtx.NCol();
        constexpr index_t K = a_block_mtx.NRow();
Chao Liu's avatar
Chao Liu committed
129

Chao Liu's avatar
Chao Liu committed
130
131
        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();
Chao Liu's avatar
Chao Liu committed
132
133

        // thread A, B for GEMM
Chao Liu's avatar
Chao Liu committed
134
135
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
Chao Liu's avatar
Chao Liu committed
136

Chao Liu's avatar
Chao Liu committed
137
138
        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
Chao Liu's avatar
Chao Liu committed
139
140

        // thread A-sub, B-sub for copy
Chao Liu's avatar
Chao Liu committed
141
142
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
Chao Liu's avatar
Chao Liu committed
143

Chao Liu's avatar
Chao Liu committed
144
145
        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
Chao Liu's avatar
Chao Liu committed
146

147
        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
Jing Zhang's avatar
Jing Zhang committed
148
        //FloatB p_b_thread[b_thread_mtx.GetElementSpace() * 2];
149
150
151
152
153
154
155
156
157
158
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        // assertion for inline asm
        static_assert(is_same<FloatA, float>::value && is_same<FloatB, float>::value &&
                          is_same<FloatC, float>::value,
                      "Run_asm only deal with float\n");

Chao Liu's avatar
Chao Liu committed
159
160
161
162
        static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
                          MPerThread == 8 && NPerThread == 8,
                      "Run_asm cannot deal with this GEMM shape yet\n");

163
164
        static_assert(DataPerReadA == 4 && DataPerReadB == 4, "Run_asm only do float4 read\n");

Chao Liu's avatar
Chao Liu committed
165
166
        using Float4 = vector_type<float, 4>::MemoryType;

Jing Zhang's avatar
Jing Zhang committed
167
168
169
        Float4* reg_a = (Float4*)(p_a_thread);
        Float4* reg_b = (Float4*)(p_b_thread);
        Float4* reg_c = (Float4*)(p_c_thread);
170
171
172
173
174
175
176

        reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA]);
        reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB]);
        reg_b[1] =
            *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + NPerLevel1Cluster]);
        reg_a[1] =
            *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + MPerLevel1Cluster]);
Jing Zhang's avatar
Jing Zhang committed
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209

#if 0
#pragma unroll
        for(index_t k = 1; k < K; ++k)
        {
            int b_reg_0 = (k % 2) * 2;
            int b_reg_1 = ((k - 1) % 2) * 2;
            reg_b[b_reg_0 + 0] =
                *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + k * N]);
            reg_b[b_reg_0 + 1] = *reinterpret_cast<const Float4*>(
                &p_b_block[mMyThreadOffsetB + k * N + NPerLevel1Cluster]);
            outerProduct4x4(reg_a[0], reg_b[b_reg_1 + 0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
            outerProduct4x4(reg_a[0], reg_b[b_reg_1 + 1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
            reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + k * M]);
            outerProduct4x4(
                reg_a[1], reg_b[b_reg_1 + 0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
            outerProduct4x4(
                reg_a[1], reg_b[b_reg_1 + 1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
            reg_a[1] = *reinterpret_cast<const Float4*>(
                &p_a_block[mMyThreadOffsetA + k * M + MPerLevel1Cluster]);
        }
        outerProduct4x4(reg_a[0], reg_b[2], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
        outerProduct4x4(reg_a[0], reg_b[3], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
        outerProduct4x4(reg_a[1], reg_b[2], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
        outerProduct4x4(reg_a[1], reg_b[3], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);

#else
        reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA]);
        reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB]);
        reg_b[1] =
            *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + NPerLevel1Cluster]);
        reg_a[1] =
            *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + MPerLevel1Cluster]);
Jing Zhang's avatar
Jing Zhang committed
210
211
212
        outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
        outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
#pragma unroll
213
        for(index_t k = 1; k < K; ++k)
Jing Zhang's avatar
Jing Zhang committed
214
        {
215
            reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + k * M]);
Jing Zhang's avatar
Jing Zhang committed
216
            outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
217
            reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + k * N]);
Jing Zhang's avatar
Jing Zhang committed
218
            outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
219
220
221
222
            reg_b[1] = *reinterpret_cast<const Float4*>(
                &p_b_block[mMyThreadOffsetB + k * N + NPerLevel1Cluster]);
            reg_a[1] = *reinterpret_cast<const Float4*>(
                &p_a_block[mMyThreadOffsetA + k * M + MPerLevel1Cluster]);
Jing Zhang's avatar
Jing Zhang committed
223
224
            outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
            outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
Chao Liu's avatar
Chao Liu committed
225
        }
Jing Zhang's avatar
Jing Zhang committed
226
227
        outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
        outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
Jing Zhang's avatar
Jing Zhang committed
228
#endif
Chao Liu's avatar
Chao Liu committed
229
    }
230
#endif
231

232
    template <class FloatA, class FloatB, class FloatC>
Chao Liu's avatar
Chao Liu committed
233
234
    __device__ void Run(const FloatA* const __restrict__ p_a_block,
                        const FloatB* const __restrict__ p_b_block,
235
                        FloatC* const __restrict__ p_c_thread) const
Chao Liu's avatar
Chao Liu committed
236
237
238
239
240
241
242
243
    {
        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};

        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

Chao Liu's avatar
Chao Liu committed
244
245
246
        constexpr index_t M = a_block_mtx.NCol();
        constexpr index_t N = b_block_mtx.NCol();
        constexpr index_t K = a_block_mtx.NRow();
Chao Liu's avatar
Chao Liu committed
247

Chao Liu's avatar
Chao Liu committed
248
249
        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();
Chao Liu's avatar
Chao Liu committed
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267

        // thread A, B for GEMM
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});

        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});

        // thread A-sub, B-sub for copy
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});

        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

Chao Liu's avatar
Chao Liu committed
268
269
        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
Chao Liu's avatar
Chao Liu committed
270

Chao Liu's avatar
Chao Liu committed
271
272
        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
Chao Liu's avatar
Chao Liu committed
273

Chao Liu's avatar
Chao Liu committed
274
275
        const FloatA* const p_a_block_thread_offset = p_a_block + mMyThreadOffsetA;

Chao Liu's avatar
Chao Liu committed
276
277
#pragma unroll
        // loop over k
Chao Liu's avatar
Chao Liu committed
278
        for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
Chao Liu's avatar
Chao Liu committed
279
        {
Chao Liu's avatar
Chao Liu committed
280
#pragma unroll
Chao Liu's avatar
Chao Liu committed
281
            // copy A-sub to form A
Chao Liu's avatar
Chao Liu committed
282
            for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
Chao Liu's avatar
Chao Liu committed
283
284
285
            {
                threadwise_matrix_copy(
                    a_block_mtx,
286
287
                    p_a_block +
                        a_block_mtx.GetOffsetFromMultiIndex(k_begin, m_repeat * MPerLevel1Cluster) +
Chao Liu's avatar
Chao Liu committed
288
289
                        mMyThreadOffsetA,
                    a_thread_mtx,
290
                    p_a_thread + a_thread_mtx.GetOffsetFromMultiIndex(0, m_repeat * MPerThreadSubC),
Chao Liu's avatar
Chao Liu committed
291
292
                    a_thread_sub_mtx.GetLengths(),
                    Number<DataPerReadA>{});
Chao Liu's avatar
Chao Liu committed
293
            }
Chao Liu's avatar
Chao Liu committed
294

Chao Liu's avatar
Chao Liu committed
295
#pragma unroll
Chao Liu's avatar
Chao Liu committed
296
            // copy B-sub to form B
Chao Liu's avatar
Chao Liu committed
297
            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
Chao Liu's avatar
Chao Liu committed
298
299
300
            {
                threadwise_matrix_copy(
                    b_block_mtx,
301
302
                    p_b_block +
                        b_block_mtx.GetOffsetFromMultiIndex(k_begin, n_repeat * NPerLevel1Cluster) +
Chao Liu's avatar
Chao Liu committed
303
304
                        mMyThreadOffsetB,
                    b_thread_mtx,
305
                    p_b_thread + b_thread_mtx.GetOffsetFromMultiIndex(0, n_repeat * NPerThreadSubC),
Chao Liu's avatar
Chao Liu committed
306
307
                    b_thread_sub_mtx.GetLengths(),
                    Number<DataPerReadB>{});
Chao Liu's avatar
Chao Liu committed
308
309
            }

Chao Liu's avatar
Chao Liu committed
310
            // C = A * B
Chao Liu's avatar
Chao Liu committed
311
312
313
314
315
316
317
318
            threadwise_gemm(a_thread_mtx,
                            True,
                            p_a_thread,
                            b_thread_mtx,
                            False,
                            p_b_thread,
                            c_thread_mtx,
                            False,
319
                            p_c_thread);
Chao Liu's avatar
Chao Liu committed
320
321
322
        }
    }

323
    template <class FloatA, class FloatB, class FloatC>
324
325
    __device__ void Run_RegisterDoubleBuffer(FloatA* const p_a_block,
                                             FloatB* const p_b_block,
326
                                             FloatC* p_c_thread) const
327
    {
Chao Liu's avatar
Chao Liu committed
328
329
        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};
330

Chao Liu's avatar
Chao Liu committed
331
332
333
        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};
334

Chao Liu's avatar
Chao Liu committed
335
336
337
        constexpr index_t M = a_block_mtx.NCol();
        constexpr index_t N = b_block_mtx.NCol();
        constexpr index_t K = a_block_mtx.NRow();
338

Chao Liu's avatar
Chao Liu committed
339
340
        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();
341
342

        // thread A, B for GEMM
Chao Liu's avatar
Chao Liu committed
343
344
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
345

Chao Liu's avatar
Chao Liu committed
346
347
        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
348
349

        // thread A-sub, B-sub for copy
Chao Liu's avatar
Chao Liu committed
350
351
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
352

Chao Liu's avatar
Chao Liu committed
353
354
        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
355

356
        // register
357
358
359
360
361
362
        FloatA p_a_thread_0[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread_0[b_thread_mtx.GetElementSpace()];

        FloatA p_a_thread_1[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread_1[b_thread_mtx.GetElementSpace()];

Chao Liu's avatar
Chao Liu committed
363
364
        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
365

Chao Liu's avatar
Chao Liu committed
366
367
        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
368

Chao Liu's avatar
Chao Liu committed
369
// preload A, B
370
#pragma unroll
Chao Liu's avatar
Chao Liu committed
371
        for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
372
373
374
375
376
        { // copy A-sub to form A
            threadwise_matrix_copy(a_block_mtx,
                                   p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster,
                                   a_thread_sub_mtx,
                                   p_a_thread_0 + m_repeat * MPerThreadSubC,
Chao Liu's avatar
Chao Liu committed
377
378
                                   a_thread_sub_mtx.GetLengths(),
                                   Number<DataPerReadA>{});
379
380
381
        }

#pragma unroll
Chao Liu's avatar
Chao Liu committed
382
        for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
383
384
385
386
387
        { // copy B-sub to form B
            threadwise_matrix_copy(b_block_mtx,
                                   p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster,
                                   b_thread_sub_mtx,
                                   p_b_thread_0 + n_repeat * NPerThreadSubC,
Chao Liu's avatar
Chao Liu committed
388
389
                                   b_thread_sub_mtx.GetLengths(),
                                   Number<DataPerReadB>{});
390
391
392
393
394
        }

        bool even_loop = true;

#pragma unroll
Chao Liu's avatar
Chao Liu committed
395
        for(index_t k_begin = 0; k_begin + KPerThreadLoop < K;
396
397
398
399
400
401
402
403
            k_begin += KPerThreadLoop, even_loop = !even_loop)
        { // loop over k
            FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
            FloatB* p_b_thread_now = even_loop ? p_b_thread_0 : p_b_thread_1;

            FloatA* p_a_thread_next = even_loop ? p_a_thread_1 : p_a_thread_0;
            FloatB* p_b_thread_next = even_loop ? p_b_thread_1 : p_b_thread_0;

Chao Liu's avatar
Chao Liu committed
404
// preload next A, B
405
#pragma unroll
Chao Liu's avatar
Chao Liu committed
406
            for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
407
408
409
410
411
412
413
            { // copy A-sub to form A
                threadwise_matrix_copy(a_block_mtx,
                                       p_a_block + mMyThreadOffsetA +
                                           (k_begin + 1) * a_block_mtx.RowStride() +
                                           m_repeat * MPerLevel1Cluster,
                                       a_thread_sub_mtx,
                                       p_a_thread_next + m_repeat * MPerThreadSubC,
Chao Liu's avatar
Chao Liu committed
414
415
                                       a_thread_sub_mtx.GetLengths(),
                                       Number<DataPerReadA>{});
416
417
418
            }

#pragma unroll
Chao Liu's avatar
Chao Liu committed
419
            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
420
421
422
423
424
425
426
            { // copy B-sub to form B
                threadwise_matrix_copy(b_block_mtx,
                                       p_b_block + mMyThreadOffsetB +
                                           (k_begin + 1) * b_block_mtx.RowStride() +
                                           n_repeat * NPerLevel1Cluster,
                                       b_thread_sub_mtx,
                                       p_b_thread_next + n_repeat * NPerThreadSubC,
Chao Liu's avatar
Chao Liu committed
427
428
                                       b_thread_sub_mtx.GetLengths(),
                                       Number<DataPerReadB>{});
429
430
431
432
433
434
435
436
437
438
439
            }

            // C = A * B
            threadwise_gemm(a_thread_mtx,
                            True,
                            p_a_thread_now,
                            b_thread_mtx,
                            False,
                            p_b_thread_now,
                            c_thread_mtx,
                            False,
440
                            p_c_thread);
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
        }

        // last loop
        {
            FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
            FloatB* p_b_thread_now = even_loop ? p_b_thread_0 : p_b_thread_1;

            // C = A * B
            threadwise_gemm(a_thread_mtx,
                            True,
                            p_a_thread_now,
                            b_thread_mtx,
                            False,
                            p_b_thread_now,
                            c_thread_mtx,
                            False,
457
                            p_c_thread);
Chao Liu's avatar
Chao Liu committed
458
459
        }
    }
Chao Liu's avatar
Chao Liu committed
460
};