"src/include/blockwise_2d_tensor_op.cuh" did not exist on "3bd51021ab676dcc17fe0674d6a39411acbfce5a"
blockwise_gemm.hip.hpp 18.7 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
#pragma once
Chao Liu's avatar
Chao Liu committed
2
#include "common.hip.hpp"
3
#include "threadwise_gemm.hip.hpp"
Chao Liu's avatar
Chao Liu committed
4

Chao Liu's avatar
Chao Liu committed
5
6
// If the following numbers are powers of 2, the index calculation is greatly simplified:
//    MPerThreadSubC, NPerThreadSubC, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster
Chao Liu's avatar
Chao Liu committed
7
template <index_t BlockSize,
Chao Liu's avatar
Chao Liu committed
8
9
10
          class BlockMatrixA,
          class BlockMatrixB,
          class ThreadMatrixC,
Chao Liu's avatar
Chao Liu committed
11
12
13
14
15
16
17
          index_t MPerThreadSubC,
          index_t NPerThreadSubC,
          index_t MLevel0Cluster,
          index_t NLevel0Cluster,
          index_t MLevel1Cluster,
          index_t NLevel1Cluster,
          index_t KPerThreadLoop>
Chao Liu's avatar
Chao Liu committed
18
19
20
21
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
{
    // (row, col) coordinate pair returned by the index helpers of this struct.
    // Field order matters: instances are brace-initialized as {row, col}.
    struct MatrixIndex
    {
        index_t row; // row coordinate (M direction of C)
        index_t col; // column coordinate (N direction of C)
    };

Chao Liu's avatar
Chao Liu committed
26
27
    index_t mMyThreadOffsetA;
    index_t mMyThreadOffsetB;
Chao Liu's avatar
Chao Liu committed
28
29
30

    // Checks at compile time that the tiling parameters are mutually consistent, then
    // stores the offsets at which the calling thread starts reading inside the A and B
    // block matrices (both taken at row 0 of the respective block descriptor).
    __device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2()
    {
        // every thread of the block occupies exactly one slot of the two-level cluster
        constexpr index_t threads_in_cluster_grid =
            MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;

        static_assert(BlockSize == threads_in_cluster_grid, "wrong! wrong blocksize\n");

        constexpr auto a_desc = BlockMatrixA{};
        constexpr auto b_desc = BlockMatrixB{};
        constexpr auto c_desc = ThreadMatrixC{};

        // rows of both block matrices run along K
        static_assert(a_desc.NRow() == b_desc.NRow(), "wrong! K dimension not consistent\n");

        constexpr index_t block_m = a_desc.NCol(); // A is transposed
        constexpr index_t block_n = b_desc.NCol();

        constexpr index_t thread_m = c_desc.NRow();
        constexpr index_t thread_n = c_desc.NCol();

        // the per-thread C tile is built from whole sub-tiles
        static_assert((thread_m % MPerThreadSubC == 0) && (thread_n % NPerThreadSubC == 0),
                      "wrong! Cannot evenly divide thread work among repeat \n");

        constexpr index_t m_repeats = thread_m / MPerThreadSubC;
        constexpr index_t n_repeats = thread_n / NPerThreadSubC;

        static_assert((block_m % m_repeats == 0) && (block_n % n_repeats == 0),
                      "wrong! Cannot evenly divide work among repeat\n");

        // extent covered by the whole level-1 cluster in one repeat
        constexpr index_t m_per_l1 = block_m / m_repeats;
        constexpr index_t n_per_l1 = block_n / n_repeats;

        static_assert((m_per_l1 % MLevel1Cluster == 0) && (n_per_l1 % NLevel1Cluster == 0),
                      "wrong! Cannot evenly divide work among Level1Cluster\n");

        // extent covered by a single level-0 cluster
        constexpr index_t m_per_l0 = m_per_l1 / MLevel1Cluster;
        constexpr index_t n_per_l0 = n_per_l1 / NLevel1Cluster;

        static_assert((m_per_l0 % MLevel0Cluster == 0) && (n_per_l0 % NLevel0Cluster == 0),
                      "wrong! Cannot evenly divide work among Level0Cluster\n");

        static_assert((MPerThreadSubC == m_per_l0 / MLevel0Cluster) &&
                          (NPerThreadSubC == n_per_l0 / NLevel0Cluster),
                      "wrong! thread work size is wrong\n");

        // the thread's C origin selects the column of A (row index of C) and the
        // column of B (column index of C) at which this thread starts reading
        const auto c_origin = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        mMyThreadOffsetA = a_desc.Get1dIndex(0, c_origin.row);
        mMyThreadOffsetB = b_desc.Get1dIndex(0, c_origin.col);
    }

Chao Liu's avatar
Chao Liu committed
83
    // Returns the (row, col) at which the first sub-tile of thread `thread_id` begins.
    // The thread id is split into a level-1 cluster id (quotient) and a level-0 id
    // (remainder); each of those is then split row-major into an (m, n) pair.
    __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
    {
        constexpr index_t threads_per_l0 = MLevel0Cluster * NLevel0Cluster;

        // extent of one level-0 cluster in each direction
        constexpr index_t m_span_l0 = MPerThreadSubC * MLevel0Cluster;
        constexpr index_t n_span_l0 = NPerThreadSubC * NLevel0Cluster;

        const index_t l1_id = thread_id / threads_per_l0;
        const index_t l0_id = thread_id % threads_per_l0;

        const index_t row =
            (l1_id / NLevel1Cluster) * m_span_l0 + (l0_id / NLevel0Cluster) * MPerThreadSubC;
        const index_t col =
            (l1_id % NLevel1Cluster) * n_span_l0 + (l0_id % NLevel0Cluster) * NPerThreadSubC;

        return MatrixIndex{row, col};
    }

    // this should be optimized away if input is known
Chao Liu's avatar
Chao Liu committed
103
104
    __device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c,
                                                                      index_t n_in_c)
Chao Liu's avatar
Chao Liu committed
105
    {
Chao Liu's avatar
Chao Liu committed
106
        constexpr auto c_thread_mtx = ThreadMatrixC{};
Chao Liu's avatar
Chao Liu committed
107

Chao Liu's avatar
Chao Liu committed
108
109
        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();
Chao Liu's avatar
Chao Liu committed
110

Chao Liu's avatar
Chao Liu committed
111
112
        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
Chao Liu's avatar
Chao Liu committed
113

Chao Liu's avatar
Chao Liu committed
114
115
        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
Chao Liu's avatar
Chao Liu committed
116

Chao Liu's avatar
Chao Liu committed
117
118
        index_t m_repeat = m_in_c / MPerThreadSubC;
        index_t n_repeat = n_in_c / NPerThreadSubC;
Chao Liu's avatar
Chao Liu committed
119

Chao Liu's avatar
Chao Liu committed
120
121
        index_t m_in_sub_c = m_in_c % MPerThreadSubC;
        index_t n_in_sub_c = n_in_c % NPerThreadSubC;
Chao Liu's avatar
Chao Liu committed
122
123
124
125
126

        return MatrixIndex{m_repeat * MPerLevel1Cluster + m_in_sub_c,
                           n_repeat * NPerLevel1Cluster + n_in_sub_c};
    }

127
128
#if DEVICE_BACKEND_HIP
    template <class FloatA, class FloatB, class FloatC>
Chao Liu's avatar
Chao Liu committed
129
130
    __device__ void Run_asm(const FloatA* __restrict__ p_a_block,
                            const FloatB* __restrict__ p_b_block,
131
                            FloatC* __restrict__ p_c_thread) const
Chao Liu's avatar
Chao Liu committed
132
    {
Chao Liu's avatar
Chao Liu committed
133
134
135
136
        static_assert(is_same<FloatA, float>::value && is_same<FloatB, float>::value &&
                          is_same<FloatC, float>::value,
                      "Run_asm only deal with float\n");

Chao Liu's avatar
Chao Liu committed
137
138
        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};
Chao Liu's avatar
Chao Liu committed
139

Chao Liu's avatar
Chao Liu committed
140
141
142
        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};
Chao Liu's avatar
Chao Liu committed
143

Chao Liu's avatar
Chao Liu committed
144
145
146
        constexpr index_t M = a_block_mtx.NCol();
        constexpr index_t N = b_block_mtx.NCol();
        constexpr index_t K = a_block_mtx.NRow();
Chao Liu's avatar
Chao Liu committed
147

Chao Liu's avatar
Chao Liu committed
148
149
        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();
Chao Liu's avatar
Chao Liu committed
150
151

        // thread A, B for GEMM
Chao Liu's avatar
Chao Liu committed
152
153
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
Chao Liu's avatar
Chao Liu committed
154

Chao Liu's avatar
Chao Liu committed
155
156
        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
Chao Liu's avatar
Chao Liu committed
157
158

        // thread A-sub, B-sub for copy
Chao Liu's avatar
Chao Liu committed
159
160
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
Chao Liu's avatar
Chao Liu committed
161

Chao Liu's avatar
Chao Liu committed
162
163
        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
Chao Liu's avatar
Chao Liu committed
164

Chao Liu's avatar
Chao Liu committed
165
166
167
168
169
170
        static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
                          MPerThread == 8 && NPerThread == 8,
                      "Run_asm cannot deal with this GEMM shape yet\n");

        using Float4 = vector_type<float, 4>::MemoryType;

Chao Liu's avatar
Chao Liu committed
171
        float p_thread[a_thread_mtx.GetElementSpace() + b_thread_mtx.GetElementSpace()];
Jing Zhang's avatar
Jing Zhang committed
172

Chao Liu's avatar
Chao Liu committed
173
174
        FloatA* p_a_thread = p_thread;
        FloatB* p_b_thread = p_thread + a_thread_mtx.GetElementSpace();
Chao Liu's avatar
Chao Liu committed
175

Chao Liu's avatar
Chao Liu committed
176
177
        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
Chao Liu's avatar
Chao Liu committed
178

Chao Liu's avatar
Chao Liu committed
179
180
        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
Chao Liu's avatar
Chao Liu committed
181

Jing Zhang's avatar
Jing Zhang committed
182
183
184
        Float4* reg_a = (Float4*)(p_a_thread);
        Float4* reg_b = (Float4*)(p_b_thread);
        Float4* reg_c = (Float4*)(p_c_thread);
Chao Liu's avatar
Chao Liu committed
185
186
        void* a_loc   = (void*)(p_a_block + mMyThreadOffsetA);
        void* b_loc   = (void*)(p_b_block + mMyThreadOffsetB);
Jing Zhang's avatar
Jing Zhang committed
187

Jing Zhang's avatar
Jing Zhang committed
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
        int lds_a_block_off   = sizeof(Float) * M;
        int lds_b_block_off   = sizeof(Float) * N;
        int lds_a_block_off_1 = MPerLevel1Cluster * sizeof(Float);
        int lds_b_block_off_1 = NPerLevel1Cluster * sizeof(Float);
        ds_read_b128(reg_a[0], a_loc, 0);
        ds_read_b128(reg_b[0], b_loc, 0);
        ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1);
        ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1);
        lgkmcnt(2);
        outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
        lgkmcnt(1);
        outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
        lgkmcnt(0);
#pragma unroll
        for(int k_i = 1; k_i < K; k_i++)
        {
            ds_read_b128(reg_a[0], a_loc, k_i * lds_a_block_off);
Jing Zhang's avatar
Jing Zhang committed
205
            outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
Jing Zhang's avatar
Jing Zhang committed
206
            ds_read_b128(reg_b[0], b_loc, k_i * lds_b_block_off);
Jing Zhang's avatar
Jing Zhang committed
207
            outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
Jing Zhang's avatar
Jing Zhang committed
208
209
            ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1 + k_i * lds_b_block_off);
            ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1 + k_i * lds_a_block_off);
Jing Zhang's avatar
Jing Zhang committed
210
211
212
213
214
            lgkmcnt(2);
            outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
            lgkmcnt(1);
            outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
            lgkmcnt(0);
Chao Liu's avatar
Chao Liu committed
215
        }
Jing Zhang's avatar
Jing Zhang committed
216
217
        outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
        outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
Chao Liu's avatar
Chao Liu committed
218
    }
219
#endif
220

221
    template <class FloatA, class FloatB, class FloatC>
Chao Liu's avatar
Chao Liu committed
222
223
    __device__ void Run(const FloatA* const __restrict__ p_a_block,
                        const FloatB* const __restrict__ p_b_block,
224
                        FloatC* const __restrict__ p_c_thread) const
Chao Liu's avatar
Chao Liu committed
225
226
227
228
229
230
231
232
    {
        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};

        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

Chao Liu's avatar
Chao Liu committed
233
234
235
        constexpr index_t M = a_block_mtx.NCol();
        constexpr index_t N = b_block_mtx.NCol();
        constexpr index_t K = a_block_mtx.NRow();
Chao Liu's avatar
Chao Liu committed
236

Chao Liu's avatar
Chao Liu committed
237
238
        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();
Chao Liu's avatar
Chao Liu committed
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256

        // thread A, B for GEMM
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});

        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});

        // thread A-sub, B-sub for copy
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});

        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

Chao Liu's avatar
Chao Liu committed
257
258
        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
Chao Liu's avatar
Chao Liu committed
259

Chao Liu's avatar
Chao Liu committed
260
261
        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
Chao Liu's avatar
Chao Liu committed
262

Chao Liu's avatar
Chao Liu committed
263
264
        const FloatA* const p_a_block_thread_offset = p_a_block + mMyThreadOffsetA;

Chao Liu's avatar
Chao Liu committed
265
266
#pragma unroll
        // loop over k
Chao Liu's avatar
Chao Liu committed
267
        for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
Chao Liu's avatar
Chao Liu committed
268
        {
Chao Liu's avatar
Chao Liu committed
269
#pragma unroll
Chao Liu's avatar
Chao Liu committed
270
            // copy A-sub to form A
Chao Liu's avatar
Chao Liu committed
271
            for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
Chao Liu's avatar
Chao Liu committed
272
273
274
275
276
277
            {
                threadwise_matrix_copy(
                    a_block_mtx,
                    p_a_block + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
                        mMyThreadOffsetA,
                    a_thread_mtx,
Chao Liu's avatar
Chao Liu committed
278
                    p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
Chao Liu's avatar
Chao Liu committed
279
280
                    a_thread_sub_mtx.GetLengths());
            }
Chao Liu's avatar
Chao Liu committed
281

Chao Liu's avatar
Chao Liu committed
282
#pragma unroll
Chao Liu's avatar
Chao Liu committed
283
            // copy B-sub to form B
Chao Liu's avatar
Chao Liu committed
284
            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
Chao Liu's avatar
Chao Liu committed
285
286
287
288
289
290
291
292
293
294
            {
                threadwise_matrix_copy(
                    b_block_mtx,
                    p_b_block + b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) +
                        mMyThreadOffsetB,
                    b_thread_mtx,
                    p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
                    b_thread_sub_mtx.GetLengths());
            }

Chao Liu's avatar
Chao Liu committed
295
            // C = A * B
Chao Liu's avatar
Chao Liu committed
296
297
298
299
300
301
302
303
            threadwise_gemm(a_thread_mtx,
                            True,
                            p_a_thread,
                            b_thread_mtx,
                            False,
                            p_b_thread,
                            c_thread_mtx,
                            False,
304
                            p_c_thread);
Chao Liu's avatar
Chao Liu committed
305
306
307
        }
    }

308
    template <class FloatA, class FloatB, class FloatC>
309
310
    __device__ void Run_RegisterDoubleBuffer(FloatA* const p_a_block,
                                             FloatB* const p_b_block,
311
                                             FloatC* p_c_thread) const
312
    {
Chao Liu's avatar
Chao Liu committed
313
314
        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};
315

Chao Liu's avatar
Chao Liu committed
316
317
318
        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};
319

Chao Liu's avatar
Chao Liu committed
320
321
322
        constexpr index_t M = a_block_mtx.NCol();
        constexpr index_t N = b_block_mtx.NCol();
        constexpr index_t K = a_block_mtx.NRow();
323

Chao Liu's avatar
Chao Liu committed
324
325
        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();
326
327

        // thread A, B for GEMM
Chao Liu's avatar
Chao Liu committed
328
329
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
330

Chao Liu's avatar
Chao Liu committed
331
332
        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
333
334

        // thread A-sub, B-sub for copy
Chao Liu's avatar
Chao Liu committed
335
336
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
337

Chao Liu's avatar
Chao Liu committed
338
339
        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
340

341
        // register
342
343
344
345
346
347
        FloatA p_a_thread_0[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread_0[b_thread_mtx.GetElementSpace()];

        FloatA p_a_thread_1[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread_1[b_thread_mtx.GetElementSpace()];

Chao Liu's avatar
Chao Liu committed
348
349
        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
350

Chao Liu's avatar
Chao Liu committed
351
352
        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
353

Chao Liu's avatar
Chao Liu committed
354
// preload A, B
355
#pragma unroll
Chao Liu's avatar
Chao Liu committed
356
        for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
357
358
359
360
361
362
363
364
365
        { // copy A-sub to form A
            threadwise_matrix_copy(a_block_mtx,
                                   p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster,
                                   a_thread_sub_mtx,
                                   p_a_thread_0 + m_repeat * MPerThreadSubC,
                                   a_thread_sub_mtx.GetLengths());
        }

#pragma unroll
Chao Liu's avatar
Chao Liu committed
366
        for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
367
368
369
370
371
372
373
374
375
376
377
        { // copy B-sub to form B
            threadwise_matrix_copy(b_block_mtx,
                                   p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster,
                                   b_thread_sub_mtx,
                                   p_b_thread_0 + n_repeat * NPerThreadSubC,
                                   b_thread_sub_mtx.GetLengths());
        }

        bool even_loop = true;

#pragma unroll
Chao Liu's avatar
Chao Liu committed
378
        for(index_t k_begin = 0; k_begin + KPerThreadLoop < K;
379
380
381
382
383
384
385
386
            k_begin += KPerThreadLoop, even_loop = !even_loop)
        { // loop over k
            FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
            FloatB* p_b_thread_now = even_loop ? p_b_thread_0 : p_b_thread_1;

            FloatA* p_a_thread_next = even_loop ? p_a_thread_1 : p_a_thread_0;
            FloatB* p_b_thread_next = even_loop ? p_b_thread_1 : p_b_thread_0;

Chao Liu's avatar
Chao Liu committed
387
// preload next A, B
388
#pragma unroll
Chao Liu's avatar
Chao Liu committed
389
            for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
390
391
392
393
394
395
396
397
398
399
400
            { // copy A-sub to form A
                threadwise_matrix_copy(a_block_mtx,
                                       p_a_block + mMyThreadOffsetA +
                                           (k_begin + 1) * a_block_mtx.RowStride() +
                                           m_repeat * MPerLevel1Cluster,
                                       a_thread_sub_mtx,
                                       p_a_thread_next + m_repeat * MPerThreadSubC,
                                       a_thread_sub_mtx.GetLengths());
            }

#pragma unroll
Chao Liu's avatar
Chao Liu committed
401
            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
            { // copy B-sub to form B
                threadwise_matrix_copy(b_block_mtx,
                                       p_b_block + mMyThreadOffsetB +
                                           (k_begin + 1) * b_block_mtx.RowStride() +
                                           n_repeat * NPerLevel1Cluster,
                                       b_thread_sub_mtx,
                                       p_b_thread_next + n_repeat * NPerThreadSubC,
                                       b_thread_sub_mtx.GetLengths());
            }

            // C = A * B
            threadwise_gemm(a_thread_mtx,
                            True,
                            p_a_thread_now,
                            b_thread_mtx,
                            False,
                            p_b_thread_now,
                            c_thread_mtx,
                            False,
421
                            p_c_thread);
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
        }

        // last loop
        {
            FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
            FloatB* p_b_thread_now = even_loop ? p_b_thread_0 : p_b_thread_1;

            // C = A * B
            threadwise_gemm(a_thread_mtx,
                            True,
                            p_a_thread_now,
                            b_thread_mtx,
                            False,
                            p_b_thread_now,
                            c_thread_mtx,
                            False,
438
                            p_c_thread);
Chao Liu's avatar
Chao Liu committed
439
440
        }
    }
Chao Liu's avatar
Chao Liu committed
441
};