blockwise_gemm.hip.hpp 18.3 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
#pragma once
Chao Liu's avatar
Chao Liu committed
2
#include "common.hip.hpp"
3
#include "threadwise_gemm.hip.hpp"
Chao Liu's avatar
Chao Liu committed
4

Chao Liu's avatar
Chao Liu committed
5
6
// If the following numbers are powers of 2, index calculation is greatly simplified:
//    MPerThreadSubC, NPerThreadSubC, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster
Chao Liu's avatar
Chao Liu committed
7
template <index_t BlockSize,
Chao Liu's avatar
Chao Liu committed
8
9
10
          class BlockMatrixA,
          class BlockMatrixB,
          class ThreadMatrixC,
Chao Liu's avatar
Chao Liu committed
11
12
13
14
15
16
17
          index_t MPerThreadSubC,
          index_t NPerThreadSubC,
          index_t MLevel0Cluster,
          index_t NLevel0Cluster,
          index_t MLevel1Cluster,
          index_t NLevel1Cluster,
          index_t KPerThreadLoop>
Chao Liu's avatar
Chao Liu committed
18
19
20
21
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
{
    struct MatrixIndex
    {
Chao Liu's avatar
Chao Liu committed
22
23
        index_t row;
        index_t col;
Chao Liu's avatar
Chao Liu committed
24
25
    };

Chao Liu's avatar
Chao Liu committed
26
27
    index_t mMyThreadOffsetA;
    index_t mMyThreadOffsetB;
Chao Liu's avatar
Chao Liu committed
28
29
30

    __device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2()
    {
Chao Liu's avatar
Chao Liu committed
31
        constexpr index_t ThreadPerLevel1Cluster =
Chao Liu's avatar
Chao Liu committed
32
33
34
35
            MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;

        static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");

Chao Liu's avatar
Chao Liu committed
36
37
38
        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};
Chao Liu's avatar
Chao Liu committed
39
40
41
42

        static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
                      "wrong! K dimension not consistent\n");

Chao Liu's avatar
Chao Liu committed
43
44
45
        constexpr index_t M = a_block_mtx.NCol(); // A is transposed
        constexpr index_t N = b_block_mtx.NCol();
        constexpr index_t K = a_block_mtx.NRow();
Chao Liu's avatar
Chao Liu committed
46

Chao Liu's avatar
Chao Liu committed
47
48
        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();
Chao Liu's avatar
Chao Liu committed
49
50
51
52

        static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
                      "wrong! Cannot evenly divide thread work among repeat \n");

Chao Liu's avatar
Chao Liu committed
53
54
        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
Chao Liu's avatar
Chao Liu committed
55
56
57
58

        static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
                      "wrong! Cannot evenly divide work among repeat\n");

Chao Liu's avatar
Chao Liu committed
59
60
        constexpr index_t MPerLevel1Cluster = M / MRepeat;
        constexpr index_t NPerLevel1Cluster = N / NRepeat;
Chao Liu's avatar
Chao Liu committed
61
62
63
64
65

        static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
                          (NPerLevel1Cluster % NLevel1Cluster == 0),
                      "wrong! Cannot evenly divide work among Level1Cluster\n");

Chao Liu's avatar
Chao Liu committed
66
67
        constexpr index_t MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
        constexpr index_t NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
Chao Liu's avatar
Chao Liu committed
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82

        static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
                          (NPerLevel0Cluster % NLevel0Cluster == 0),
                      "wrong! Cannot evenly divide work among Level0Cluster\n");

        static_assert((MPerThreadSubC == MPerLevel0Cluster / MLevel0Cluster) &&
                          (NPerThreadSubC == NPerLevel0Cluster / NLevel0Cluster),
                      "wrong! thread work size is wrong\n");

        auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        mMyThreadOffsetA = a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row);
        mMyThreadOffsetB = b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col);
    }

Chao Liu's avatar
Chao Liu committed
83
    __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
Chao Liu's avatar
Chao Liu committed
84
    {
Chao Liu's avatar
Chao Liu committed
85
        constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
Chao Liu's avatar
Chao Liu committed
86

Chao Liu's avatar
Chao Liu committed
87
88
89
        index_t level1_id   = thread_id / ThreadPerLevel0Cluster;
        index_t level1_m_id = level1_id / NLevel1Cluster;
        index_t level1_n_id = level1_id % NLevel1Cluster;
Chao Liu's avatar
Chao Liu committed
90

Chao Liu's avatar
Chao Liu committed
91
92
93
        index_t level0_id   = thread_id % ThreadPerLevel0Cluster;
        index_t level0_m_id = level0_id / NLevel0Cluster;
        index_t level0_n_id = level0_id % NLevel0Cluster;
Chao Liu's avatar
Chao Liu committed
94

Chao Liu's avatar
Chao Liu committed
95
96
        constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
        constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
Chao Liu's avatar
Chao Liu committed
97
98
99
100
101
102

        return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
                           level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
    }

    // this should be optimized away if input is known
Chao Liu's avatar
Chao Liu committed
103
104
    __device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c,
                                                                      index_t n_in_c)
Chao Liu's avatar
Chao Liu committed
105
    {
Chao Liu's avatar
Chao Liu committed
106
        constexpr auto c_thread_mtx = ThreadMatrixC{};
Chao Liu's avatar
Chao Liu committed
107

Chao Liu's avatar
Chao Liu committed
108
109
        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();
Chao Liu's avatar
Chao Liu committed
110

Chao Liu's avatar
Chao Liu committed
111
112
        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
Chao Liu's avatar
Chao Liu committed
113

Chao Liu's avatar
Chao Liu committed
114
115
        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
Chao Liu's avatar
Chao Liu committed
116

Chao Liu's avatar
Chao Liu committed
117
118
        index_t m_repeat = m_in_c / MPerThreadSubC;
        index_t n_repeat = n_in_c / NPerThreadSubC;
Chao Liu's avatar
Chao Liu committed
119

Chao Liu's avatar
Chao Liu committed
120
121
        index_t m_in_sub_c = m_in_c % MPerThreadSubC;
        index_t n_in_sub_c = n_in_c % NPerThreadSubC;
Chao Liu's avatar
Chao Liu committed
122
123
124
125
126

        return MatrixIndex{m_repeat * MPerLevel1Cluster + m_in_sub_c,
                           n_repeat * NPerLevel1Cluster + n_in_sub_c};
    }

127
128
#if DEVICE_BACKEND_HIP
    template <class FloatA, class FloatB, class FloatC>
Chao Liu's avatar
Chao Liu committed
129
130
    __device__ void Run_asm(const FloatA* __restrict__ p_a_block,
                            const FloatB* __restrict__ p_b_block,
131
                            FloatC* __restrict__ p_c_thread) const
Chao Liu's avatar
Chao Liu committed
132
    {
Chao Liu's avatar
Chao Liu committed
133
134
        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};
Chao Liu's avatar
Chao Liu committed
135

Chao Liu's avatar
Chao Liu committed
136
137
138
        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};
Chao Liu's avatar
Chao Liu committed
139

Chao Liu's avatar
Chao Liu committed
140
141
142
        constexpr index_t M = a_block_mtx.NCol();
        constexpr index_t N = b_block_mtx.NCol();
        constexpr index_t K = a_block_mtx.NRow();
Chao Liu's avatar
Chao Liu committed
143

Chao Liu's avatar
Chao Liu committed
144
145
        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();
Chao Liu's avatar
Chao Liu committed
146
147

        // thread A, B for GEMM
Chao Liu's avatar
Chao Liu committed
148
149
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
Chao Liu's avatar
Chao Liu committed
150

Chao Liu's avatar
Chao Liu committed
151
152
        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
Chao Liu's avatar
Chao Liu committed
153
154

        // thread A-sub, B-sub for copy
Chao Liu's avatar
Chao Liu committed
155
156
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
Chao Liu's avatar
Chao Liu committed
157

Chao Liu's avatar
Chao Liu committed
158
159
        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
Chao Liu's avatar
Chao Liu committed
160

Chao Liu's avatar
Chao Liu committed
161
        float p_thread[a_thread_mtx.GetElementSpace() + b_thread_mtx.GetElementSpace()];
Jing Zhang's avatar
Jing Zhang committed
162

Chao Liu's avatar
Chao Liu committed
163
164
        FloatA* p_a_thread = p_thread;
        FloatB* p_b_thread = p_thread + a_thread_mtx.GetElementSpace();
Chao Liu's avatar
Chao Liu committed
165

Chao Liu's avatar
Chao Liu committed
166
167
        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
Chao Liu's avatar
Chao Liu committed
168

Chao Liu's avatar
Chao Liu committed
169
170
        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
Chao Liu's avatar
Chao Liu committed
171

Jing Zhang's avatar
Jing Zhang committed
172
173
174
        Float4* reg_a = (Float4*)(p_a_thread);
        Float4* reg_b = (Float4*)(p_b_thread);
        Float4* reg_c = (Float4*)(p_c_thread);
Chao Liu's avatar
Chao Liu committed
175
176
        void* a_loc   = (void*)(p_a_block + mMyThreadOffsetA);
        void* b_loc   = (void*)(p_b_block + mMyThreadOffsetB);
Jing Zhang's avatar
Jing Zhang committed
177

Jing Zhang's avatar
Jing Zhang committed
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
        int lds_a_block_off   = sizeof(Float) * M;
        int lds_b_block_off   = sizeof(Float) * N;
        int lds_a_block_off_1 = MPerLevel1Cluster * sizeof(Float);
        int lds_b_block_off_1 = NPerLevel1Cluster * sizeof(Float);
        ds_read_b128(reg_a[0], a_loc, 0);
        ds_read_b128(reg_b[0], b_loc, 0);
        ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1);
        ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1);
        lgkmcnt(2);
        outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
        lgkmcnt(1);
        outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
        lgkmcnt(0);
#pragma unroll
        for(int k_i = 1; k_i < K; k_i++)
        {
            ds_read_b128(reg_a[0], a_loc, k_i * lds_a_block_off);
Jing Zhang's avatar
Jing Zhang committed
195
            outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
Jing Zhang's avatar
Jing Zhang committed
196
            ds_read_b128(reg_b[0], b_loc, k_i * lds_b_block_off);
Jing Zhang's avatar
Jing Zhang committed
197
            outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
Jing Zhang's avatar
Jing Zhang committed
198
199
            ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1 + k_i * lds_b_block_off);
            ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1 + k_i * lds_a_block_off);
Jing Zhang's avatar
Jing Zhang committed
200
201
202
203
204
            lgkmcnt(2);
            outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
            lgkmcnt(1);
            outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
            lgkmcnt(0);
Chao Liu's avatar
Chao Liu committed
205
        }
Jing Zhang's avatar
Jing Zhang committed
206
207
        outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
        outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
Chao Liu's avatar
Chao Liu committed
208
    }
209
#endif
210

211
    template <class FloatA, class FloatB, class FloatC>
Chao Liu's avatar
Chao Liu committed
212
213
    __device__ void Run(const FloatA* const __restrict__ p_a_block,
                        const FloatB* const __restrict__ p_b_block,
214
                        FloatC* const __restrict__ p_c_thread) const
Chao Liu's avatar
Chao Liu committed
215
216
217
218
219
220
221
222
    {
        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};

        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

Chao Liu's avatar
Chao Liu committed
223
224
225
        constexpr index_t M = a_block_mtx.NCol();
        constexpr index_t N = b_block_mtx.NCol();
        constexpr index_t K = a_block_mtx.NRow();
Chao Liu's avatar
Chao Liu committed
226

Chao Liu's avatar
Chao Liu committed
227
228
        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();
Chao Liu's avatar
Chao Liu committed
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246

        // thread A, B for GEMM
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});

        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});

        // thread A-sub, B-sub for copy
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});

        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

Chao Liu's avatar
Chao Liu committed
247
248
        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
Chao Liu's avatar
Chao Liu committed
249

Chao Liu's avatar
Chao Liu committed
250
251
        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
Chao Liu's avatar
Chao Liu committed
252

Chao Liu's avatar
Chao Liu committed
253
254
        const FloatA* const p_a_block_thread_offset = p_a_block + mMyThreadOffsetA;

Chao Liu's avatar
Chao Liu committed
255
256
#pragma unroll
        // loop over k
Chao Liu's avatar
Chao Liu committed
257
        for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
Chao Liu's avatar
Chao Liu committed
258
        {
Chao Liu's avatar
Chao Liu committed
259
#pragma unroll
Chao Liu's avatar
Chao Liu committed
260
            // copy A-sub to form A
Chao Liu's avatar
Chao Liu committed
261
            for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
Chao Liu's avatar
Chao Liu committed
262
263
264
265
266
267
            {
                threadwise_matrix_copy(
                    a_block_mtx,
                    p_a_block + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
                        mMyThreadOffsetA,
                    a_thread_mtx,
Chao Liu's avatar
Chao Liu committed
268
                    p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
Chao Liu's avatar
Chao Liu committed
269
270
                    a_thread_sub_mtx.GetLengths());
            }
Chao Liu's avatar
Chao Liu committed
271

Chao Liu's avatar
Chao Liu committed
272
#pragma unroll
Chao Liu's avatar
Chao Liu committed
273
            // copy B-sub to form B
Chao Liu's avatar
Chao Liu committed
274
            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
Chao Liu's avatar
Chao Liu committed
275
276
277
278
279
280
281
282
283
284
            {
                threadwise_matrix_copy(
                    b_block_mtx,
                    p_b_block + b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) +
                        mMyThreadOffsetB,
                    b_thread_mtx,
                    p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
                    b_thread_sub_mtx.GetLengths());
            }

Chao Liu's avatar
Chao Liu committed
285
            // C = A * B
Chao Liu's avatar
Chao Liu committed
286
287
288
289
290
291
292
293
            threadwise_gemm(a_thread_mtx,
                            True,
                            p_a_thread,
                            b_thread_mtx,
                            False,
                            p_b_thread,
                            c_thread_mtx,
                            False,
294
                            p_c_thread);
Chao Liu's avatar
Chao Liu committed
295
296
297
        }
    }

298
    template <class FloatA, class FloatB, class FloatC>
299
300
    __device__ void Run_RegisterDoubleBuffer(FloatA* const p_a_block,
                                             FloatB* const p_b_block,
301
                                             FloatC* p_c_thread) const
302
    {
Chao Liu's avatar
Chao Liu committed
303
304
        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};
305

Chao Liu's avatar
Chao Liu committed
306
307
308
        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};
309

Chao Liu's avatar
Chao Liu committed
310
311
312
        constexpr index_t M = a_block_mtx.NCol();
        constexpr index_t N = b_block_mtx.NCol();
        constexpr index_t K = a_block_mtx.NRow();
313

Chao Liu's avatar
Chao Liu committed
314
315
        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();
316
317

        // thread A, B for GEMM
Chao Liu's avatar
Chao Liu committed
318
319
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
320

Chao Liu's avatar
Chao Liu committed
321
322
        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
323
324

        // thread A-sub, B-sub for copy
Chao Liu's avatar
Chao Liu committed
325
326
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
327

Chao Liu's avatar
Chao Liu committed
328
329
        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
330

331
        // register
332
333
334
335
336
337
        FloatA p_a_thread_0[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread_0[b_thread_mtx.GetElementSpace()];

        FloatA p_a_thread_1[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread_1[b_thread_mtx.GetElementSpace()];

Chao Liu's avatar
Chao Liu committed
338
339
        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
340

Chao Liu's avatar
Chao Liu committed
341
342
        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
343

Chao Liu's avatar
Chao Liu committed
344
// preload A, B
345
#pragma unroll
Chao Liu's avatar
Chao Liu committed
346
        for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
347
348
349
350
351
352
353
354
355
        { // copy A-sub to form A
            threadwise_matrix_copy(a_block_mtx,
                                   p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster,
                                   a_thread_sub_mtx,
                                   p_a_thread_0 + m_repeat * MPerThreadSubC,
                                   a_thread_sub_mtx.GetLengths());
        }

#pragma unroll
Chao Liu's avatar
Chao Liu committed
356
        for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
357
358
359
360
361
362
363
364
365
366
367
        { // copy B-sub to form B
            threadwise_matrix_copy(b_block_mtx,
                                   p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster,
                                   b_thread_sub_mtx,
                                   p_b_thread_0 + n_repeat * NPerThreadSubC,
                                   b_thread_sub_mtx.GetLengths());
        }

        bool even_loop = true;

#pragma unroll
Chao Liu's avatar
Chao Liu committed
368
        for(index_t k_begin = 0; k_begin + KPerThreadLoop < K;
369
370
371
372
373
374
375
376
            k_begin += KPerThreadLoop, even_loop = !even_loop)
        { // loop over k
            FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
            FloatB* p_b_thread_now = even_loop ? p_b_thread_0 : p_b_thread_1;

            FloatA* p_a_thread_next = even_loop ? p_a_thread_1 : p_a_thread_0;
            FloatB* p_b_thread_next = even_loop ? p_b_thread_1 : p_b_thread_0;

Chao Liu's avatar
Chao Liu committed
377
// preload next A, B
378
#pragma unroll
Chao Liu's avatar
Chao Liu committed
379
            for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
380
381
382
383
384
385
386
387
388
389
390
            { // copy A-sub to form A
                threadwise_matrix_copy(a_block_mtx,
                                       p_a_block + mMyThreadOffsetA +
                                           (k_begin + 1) * a_block_mtx.RowStride() +
                                           m_repeat * MPerLevel1Cluster,
                                       a_thread_sub_mtx,
                                       p_a_thread_next + m_repeat * MPerThreadSubC,
                                       a_thread_sub_mtx.GetLengths());
            }

#pragma unroll
Chao Liu's avatar
Chao Liu committed
391
            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
            { // copy B-sub to form B
                threadwise_matrix_copy(b_block_mtx,
                                       p_b_block + mMyThreadOffsetB +
                                           (k_begin + 1) * b_block_mtx.RowStride() +
                                           n_repeat * NPerLevel1Cluster,
                                       b_thread_sub_mtx,
                                       p_b_thread_next + n_repeat * NPerThreadSubC,
                                       b_thread_sub_mtx.GetLengths());
            }

            // C = A * B
            threadwise_gemm(a_thread_mtx,
                            True,
                            p_a_thread_now,
                            b_thread_mtx,
                            False,
                            p_b_thread_now,
                            c_thread_mtx,
                            False,
411
                            p_c_thread);
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
        }

        // last loop
        {
            FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
            FloatB* p_b_thread_now = even_loop ? p_b_thread_0 : p_b_thread_1;

            // C = A * B
            threadwise_gemm(a_thread_mtx,
                            True,
                            p_a_thread_now,
                            b_thread_mtx,
                            False,
                            p_b_thread_now,
                            c_thread_mtx,
                            False,
428
                            p_c_thread);
Chao Liu's avatar
Chao Liu committed
429
430
        }
    }
Chao Liu's avatar
Chao Liu committed
431
};