"magic_pdf/vscode:/vscode.git/clone" did not exist on "f0a8886c7b631e8755f72484ced2b97a6cb1b13f"
blockwise_gemm.hip.hpp 19.1 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
#pragma once
Chao Liu's avatar
Chao Liu committed
2
#include "common.hip.hpp"
3
#include "threadwise_gemm.hip.hpp"
Chao Liu's avatar
Chao Liu committed
4

Chao Liu's avatar
Chao Liu committed
5
6
// if following number are power of 2, index calculation shall be greatly reduced:
//    MPerThreadSubC, NPerThreadSubC, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster
Chao Liu's avatar
Chao Liu committed
7
template <index_t BlockSize,
Chao Liu's avatar
Chao Liu committed
8
9
10
          class BlockMatrixA,
          class BlockMatrixB,
          class ThreadMatrixC,
Chao Liu's avatar
Chao Liu committed
11
12
13
14
15
16
          index_t MPerThreadSubC,
          index_t NPerThreadSubC,
          index_t MLevel0Cluster,
          index_t NLevel0Cluster,
          index_t MLevel1Cluster,
          index_t NLevel1Cluster,
Chao Liu's avatar
Chao Liu committed
17
18
19
          index_t KPerThreadLoop,
          index_t DataPerReadA,
          index_t DataPerReadB>
Chao Liu's avatar
Chao Liu committed
20
21
22
23
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
{
    struct MatrixIndex
    {
Chao Liu's avatar
Chao Liu committed
24
25
        index_t row;
        index_t col;
Chao Liu's avatar
Chao Liu committed
26
27
    };

Chao Liu's avatar
Chao Liu committed
28
29
    index_t mMyThreadOffsetA;
    index_t mMyThreadOffsetB;
Chao Liu's avatar
Chao Liu committed
30
31
32

    __device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2()
    {
Chao Liu's avatar
Chao Liu committed
33
        constexpr index_t ThreadPerLevel1Cluster =
Chao Liu's avatar
Chao Liu committed
34
35
36
37
            MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;

        static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");

Chao Liu's avatar
Chao Liu committed
38
39
40
        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};
Chao Liu's avatar
Chao Liu committed
41
42
43
44

        static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
                      "wrong! K dimension not consistent\n");

Chao Liu's avatar
Chao Liu committed
45
46
47
        constexpr index_t M = a_block_mtx.NCol(); // A is transposed
        constexpr index_t N = b_block_mtx.NCol();
        constexpr index_t K = a_block_mtx.NRow();
Chao Liu's avatar
Chao Liu committed
48

Chao Liu's avatar
Chao Liu committed
49
50
        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();
Chao Liu's avatar
Chao Liu committed
51
52
53
54

        static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
                      "wrong! Cannot evenly divide thread work among repeat \n");

Chao Liu's avatar
Chao Liu committed
55
56
        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
Chao Liu's avatar
Chao Liu committed
57
58
59
60

        static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
                      "wrong! Cannot evenly divide work among repeat\n");

Chao Liu's avatar
Chao Liu committed
61
62
        constexpr index_t MPerLevel1Cluster = M / MRepeat;
        constexpr index_t NPerLevel1Cluster = N / NRepeat;
Chao Liu's avatar
Chao Liu committed
63
64
65
66
67

        static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
                          (NPerLevel1Cluster % NLevel1Cluster == 0),
                      "wrong! Cannot evenly divide work among Level1Cluster\n");

Chao Liu's avatar
Chao Liu committed
68
69
        constexpr index_t MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
        constexpr index_t NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
Chao Liu's avatar
Chao Liu committed
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84

        static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
                          (NPerLevel0Cluster % NLevel0Cluster == 0),
                      "wrong! Cannot evenly divide work among Level0Cluster\n");

        static_assert((MPerThreadSubC == MPerLevel0Cluster / MLevel0Cluster) &&
                          (NPerThreadSubC == NPerLevel0Cluster / NLevel0Cluster),
                      "wrong! thread work size is wrong\n");

        auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        mMyThreadOffsetA = a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row);
        mMyThreadOffsetB = b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col);
    }

Chao Liu's avatar
Chao Liu committed
85
    __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
Chao Liu's avatar
Chao Liu committed
86
    {
Chao Liu's avatar
Chao Liu committed
87
        constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
Chao Liu's avatar
Chao Liu committed
88

Chao Liu's avatar
Chao Liu committed
89
90
91
        index_t level1_id   = thread_id / ThreadPerLevel0Cluster;
        index_t level1_m_id = level1_id / NLevel1Cluster;
        index_t level1_n_id = level1_id % NLevel1Cluster;
Chao Liu's avatar
Chao Liu committed
92

Chao Liu's avatar
Chao Liu committed
93
94
95
        index_t level0_id   = thread_id % ThreadPerLevel0Cluster;
        index_t level0_m_id = level0_id / NLevel0Cluster;
        index_t level0_n_id = level0_id % NLevel0Cluster;
Chao Liu's avatar
Chao Liu committed
96

Chao Liu's avatar
Chao Liu committed
97
98
        constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
        constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
Chao Liu's avatar
Chao Liu committed
99
100
101
102
103
104

        return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
                           level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
    }

    // this should be optimized away if input is known
Chao Liu's avatar
Chao Liu committed
105
106
    __device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c,
                                                                      index_t n_in_c)
Chao Liu's avatar
Chao Liu committed
107
    {
Chao Liu's avatar
Chao Liu committed
108
        constexpr auto c_thread_mtx = ThreadMatrixC{};
Chao Liu's avatar
Chao Liu committed
109

Chao Liu's avatar
Chao Liu committed
110
111
        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();
Chao Liu's avatar
Chao Liu committed
112

Chao Liu's avatar
Chao Liu committed
113
114
        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
Chao Liu's avatar
Chao Liu committed
115

Chao Liu's avatar
Chao Liu committed
116
117
        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
Chao Liu's avatar
Chao Liu committed
118

Chao Liu's avatar
Chao Liu committed
119
120
        index_t m_repeat = m_in_c / MPerThreadSubC;
        index_t n_repeat = n_in_c / NPerThreadSubC;
Chao Liu's avatar
Chao Liu committed
121

Chao Liu's avatar
Chao Liu committed
122
123
        index_t m_in_sub_c = m_in_c % MPerThreadSubC;
        index_t n_in_sub_c = n_in_c % NPerThreadSubC;
Chao Liu's avatar
Chao Liu committed
124
125
126
127
128

        return MatrixIndex{m_repeat * MPerLevel1Cluster + m_in_sub_c,
                           n_repeat * NPerLevel1Cluster + n_in_sub_c};
    }

129
130
#if DEVICE_BACKEND_HIP
    template <class FloatA, class FloatB, class FloatC>
Chao Liu's avatar
Chao Liu committed
131
132
    __device__ void Run_asm(const FloatA* __restrict__ p_a_block,
                            const FloatB* __restrict__ p_b_block,
133
                            FloatC* __restrict__ p_c_thread) const
Chao Liu's avatar
Chao Liu committed
134
    {
Chao Liu's avatar
Chao Liu committed
135
136
137
138
        static_assert(is_same<FloatA, float>::value && is_same<FloatB, float>::value &&
                          is_same<FloatC, float>::value,
                      "Run_asm only deal with float\n");

Chao Liu's avatar
Chao Liu committed
139
140
        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};
Chao Liu's avatar
Chao Liu committed
141

Chao Liu's avatar
Chao Liu committed
142
143
144
        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};
Chao Liu's avatar
Chao Liu committed
145

Chao Liu's avatar
Chao Liu committed
146
147
148
        constexpr index_t M = a_block_mtx.NCol();
        constexpr index_t N = b_block_mtx.NCol();
        constexpr index_t K = a_block_mtx.NRow();
Chao Liu's avatar
Chao Liu committed
149

Chao Liu's avatar
Chao Liu committed
150
151
        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();
Chao Liu's avatar
Chao Liu committed
152
153

        // thread A, B for GEMM
Chao Liu's avatar
Chao Liu committed
154
155
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
Chao Liu's avatar
Chao Liu committed
156

Chao Liu's avatar
Chao Liu committed
157
158
        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
Chao Liu's avatar
Chao Liu committed
159
160

        // thread A-sub, B-sub for copy
Chao Liu's avatar
Chao Liu committed
161
162
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
Chao Liu's avatar
Chao Liu committed
163

Chao Liu's avatar
Chao Liu committed
164
165
        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
Chao Liu's avatar
Chao Liu committed
166

Chao Liu's avatar
Chao Liu committed
167
168
169
170
171
172
        static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
                          MPerThread == 8 && NPerThread == 8,
                      "Run_asm cannot deal with this GEMM shape yet\n");

        using Float4 = vector_type<float, 4>::MemoryType;

Chao Liu's avatar
Chao Liu committed
173
        float p_thread[a_thread_mtx.GetElementSpace() + b_thread_mtx.GetElementSpace()];
Jing Zhang's avatar
Jing Zhang committed
174

Chao Liu's avatar
Chao Liu committed
175
176
        FloatA* p_a_thread = p_thread;
        FloatB* p_b_thread = p_thread + a_thread_mtx.GetElementSpace();
Chao Liu's avatar
Chao Liu committed
177

Chao Liu's avatar
Chao Liu committed
178
179
        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
Chao Liu's avatar
Chao Liu committed
180

Chao Liu's avatar
Chao Liu committed
181
182
        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
Chao Liu's avatar
Chao Liu committed
183

Jing Zhang's avatar
Jing Zhang committed
184
185
186
        Float4* reg_a = (Float4*)(p_a_thread);
        Float4* reg_b = (Float4*)(p_b_thread);
        Float4* reg_c = (Float4*)(p_c_thread);
Chao Liu's avatar
Chao Liu committed
187
188
        void* a_loc   = (void*)(p_a_block + mMyThreadOffsetA);
        void* b_loc   = (void*)(p_b_block + mMyThreadOffsetB);
Jing Zhang's avatar
Jing Zhang committed
189

Jing Zhang's avatar
Jing Zhang committed
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
        int lds_a_block_off   = sizeof(Float) * M;
        int lds_b_block_off   = sizeof(Float) * N;
        int lds_a_block_off_1 = MPerLevel1Cluster * sizeof(Float);
        int lds_b_block_off_1 = NPerLevel1Cluster * sizeof(Float);
        ds_read_b128(reg_a[0], a_loc, 0);
        ds_read_b128(reg_b[0], b_loc, 0);
        ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1);
        ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1);
        lgkmcnt(2);
        outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
        lgkmcnt(1);
        outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
        lgkmcnt(0);
#pragma unroll
        for(int k_i = 1; k_i < K; k_i++)
        {
            ds_read_b128(reg_a[0], a_loc, k_i * lds_a_block_off);
Jing Zhang's avatar
Jing Zhang committed
207
            outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
Jing Zhang's avatar
Jing Zhang committed
208
            ds_read_b128(reg_b[0], b_loc, k_i * lds_b_block_off);
Jing Zhang's avatar
Jing Zhang committed
209
            outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
Jing Zhang's avatar
Jing Zhang committed
210
211
            ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1 + k_i * lds_b_block_off);
            ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1 + k_i * lds_a_block_off);
Jing Zhang's avatar
Jing Zhang committed
212
213
214
215
216
            lgkmcnt(2);
            outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
            lgkmcnt(1);
            outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
            lgkmcnt(0);
Chao Liu's avatar
Chao Liu committed
217
        }
Jing Zhang's avatar
Jing Zhang committed
218
219
        outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
        outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
Chao Liu's avatar
Chao Liu committed
220
    }
221
#endif
222

223
    template <class FloatA, class FloatB, class FloatC>
Chao Liu's avatar
Chao Liu committed
224
225
    __device__ void Run(const FloatA* const __restrict__ p_a_block,
                        const FloatB* const __restrict__ p_b_block,
226
                        FloatC* const __restrict__ p_c_thread) const
Chao Liu's avatar
Chao Liu committed
227
228
229
230
231
232
233
234
    {
        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};

        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

Chao Liu's avatar
Chao Liu committed
235
236
237
        constexpr index_t M = a_block_mtx.NCol();
        constexpr index_t N = b_block_mtx.NCol();
        constexpr index_t K = a_block_mtx.NRow();
Chao Liu's avatar
Chao Liu committed
238

Chao Liu's avatar
Chao Liu committed
239
240
        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();
Chao Liu's avatar
Chao Liu committed
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258

        // thread A, B for GEMM
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});

        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});

        // thread A-sub, B-sub for copy
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});

        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

Chao Liu's avatar
Chao Liu committed
259
260
        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
Chao Liu's avatar
Chao Liu committed
261

Chao Liu's avatar
Chao Liu committed
262
263
        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
Chao Liu's avatar
Chao Liu committed
264

Chao Liu's avatar
Chao Liu committed
265
266
        const FloatA* const p_a_block_thread_offset = p_a_block + mMyThreadOffsetA;

Chao Liu's avatar
Chao Liu committed
267
268
#pragma unroll
        // loop over k
Chao Liu's avatar
Chao Liu committed
269
        for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
Chao Liu's avatar
Chao Liu committed
270
        {
Chao Liu's avatar
Chao Liu committed
271
#pragma unroll
Chao Liu's avatar
Chao Liu committed
272
            // copy A-sub to form A
Chao Liu's avatar
Chao Liu committed
273
            for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
Chao Liu's avatar
Chao Liu committed
274
275
276
277
278
279
            {
                threadwise_matrix_copy(
                    a_block_mtx,
                    p_a_block + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
                        mMyThreadOffsetA,
                    a_thread_mtx,
Chao Liu's avatar
Chao Liu committed
280
                    p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
Chao Liu's avatar
Chao Liu committed
281
282
                    a_thread_sub_mtx.GetLengths(),
                    Number<DataPerReadA>{});
Chao Liu's avatar
Chao Liu committed
283
            }
Chao Liu's avatar
Chao Liu committed
284

Chao Liu's avatar
Chao Liu committed
285
#pragma unroll
Chao Liu's avatar
Chao Liu committed
286
            // copy B-sub to form B
Chao Liu's avatar
Chao Liu committed
287
            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
Chao Liu's avatar
Chao Liu committed
288
289
290
291
292
293
294
            {
                threadwise_matrix_copy(
                    b_block_mtx,
                    p_b_block + b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) +
                        mMyThreadOffsetB,
                    b_thread_mtx,
                    p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
Chao Liu's avatar
Chao Liu committed
295
296
                    b_thread_sub_mtx.GetLengths(),
                    Number<DataPerReadB>{});
Chao Liu's avatar
Chao Liu committed
297
298
            }

Chao Liu's avatar
Chao Liu committed
299
            // C = A * B
Chao Liu's avatar
Chao Liu committed
300
301
302
303
304
305
306
307
            threadwise_gemm(a_thread_mtx,
                            True,
                            p_a_thread,
                            b_thread_mtx,
                            False,
                            p_b_thread,
                            c_thread_mtx,
                            False,
308
                            p_c_thread);
Chao Liu's avatar
Chao Liu committed
309
310
311
        }
    }

312
    template <class FloatA, class FloatB, class FloatC>
313
314
    __device__ void Run_RegisterDoubleBuffer(FloatA* const p_a_block,
                                             FloatB* const p_b_block,
315
                                             FloatC* p_c_thread) const
316
    {
Chao Liu's avatar
Chao Liu committed
317
318
        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};
319

Chao Liu's avatar
Chao Liu committed
320
321
322
        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};
323

Chao Liu's avatar
Chao Liu committed
324
325
326
        constexpr index_t M = a_block_mtx.NCol();
        constexpr index_t N = b_block_mtx.NCol();
        constexpr index_t K = a_block_mtx.NRow();
327

Chao Liu's avatar
Chao Liu committed
328
329
        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();
330
331

        // thread A, B for GEMM
Chao Liu's avatar
Chao Liu committed
332
333
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
334

Chao Liu's avatar
Chao Liu committed
335
336
        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
337
338

        // thread A-sub, B-sub for copy
Chao Liu's avatar
Chao Liu committed
339
340
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
341

Chao Liu's avatar
Chao Liu committed
342
343
        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
344

345
        // register
346
347
348
349
350
351
        FloatA p_a_thread_0[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread_0[b_thread_mtx.GetElementSpace()];

        FloatA p_a_thread_1[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread_1[b_thread_mtx.GetElementSpace()];

Chao Liu's avatar
Chao Liu committed
352
353
        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
354

Chao Liu's avatar
Chao Liu committed
355
356
        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
357

Chao Liu's avatar
Chao Liu committed
358
// preload A, B
359
#pragma unroll
Chao Liu's avatar
Chao Liu committed
360
        for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
361
362
363
364
365
        { // copy A-sub to form A
            threadwise_matrix_copy(a_block_mtx,
                                   p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster,
                                   a_thread_sub_mtx,
                                   p_a_thread_0 + m_repeat * MPerThreadSubC,
Chao Liu's avatar
Chao Liu committed
366
367
                                   a_thread_sub_mtx.GetLengths(),
                                   Number<DataPerReadA>{});
368
369
370
        }

#pragma unroll
Chao Liu's avatar
Chao Liu committed
371
        for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
372
373
374
375
376
        { // copy B-sub to form B
            threadwise_matrix_copy(b_block_mtx,
                                   p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster,
                                   b_thread_sub_mtx,
                                   p_b_thread_0 + n_repeat * NPerThreadSubC,
Chao Liu's avatar
Chao Liu committed
377
378
                                   b_thread_sub_mtx.GetLengths(),
                                   Number<DataPerReadB>{});
379
380
381
382
383
        }

        bool even_loop = true;

#pragma unroll
Chao Liu's avatar
Chao Liu committed
384
        for(index_t k_begin = 0; k_begin + KPerThreadLoop < K;
385
386
387
388
389
390
391
392
            k_begin += KPerThreadLoop, even_loop = !even_loop)
        { // loop over k
            FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
            FloatB* p_b_thread_now = even_loop ? p_b_thread_0 : p_b_thread_1;

            FloatA* p_a_thread_next = even_loop ? p_a_thread_1 : p_a_thread_0;
            FloatB* p_b_thread_next = even_loop ? p_b_thread_1 : p_b_thread_0;

Chao Liu's avatar
Chao Liu committed
393
// preload next A, B
394
#pragma unroll
Chao Liu's avatar
Chao Liu committed
395
            for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
396
397
398
399
400
401
402
            { // copy A-sub to form A
                threadwise_matrix_copy(a_block_mtx,
                                       p_a_block + mMyThreadOffsetA +
                                           (k_begin + 1) * a_block_mtx.RowStride() +
                                           m_repeat * MPerLevel1Cluster,
                                       a_thread_sub_mtx,
                                       p_a_thread_next + m_repeat * MPerThreadSubC,
Chao Liu's avatar
Chao Liu committed
403
404
                                       a_thread_sub_mtx.GetLengths(),
                                       Number<DataPerReadA>{});
405
406
407
            }

#pragma unroll
Chao Liu's avatar
Chao Liu committed
408
            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
409
410
411
412
413
414
415
            { // copy B-sub to form B
                threadwise_matrix_copy(b_block_mtx,
                                       p_b_block + mMyThreadOffsetB +
                                           (k_begin + 1) * b_block_mtx.RowStride() +
                                           n_repeat * NPerLevel1Cluster,
                                       b_thread_sub_mtx,
                                       p_b_thread_next + n_repeat * NPerThreadSubC,
Chao Liu's avatar
Chao Liu committed
416
417
                                       b_thread_sub_mtx.GetLengths(),
                                       Number<DataPerReadB>{});
418
419
420
421
422
423
424
425
426
427
428
            }

            // C = A * B
            threadwise_gemm(a_thread_mtx,
                            True,
                            p_a_thread_now,
                            b_thread_mtx,
                            False,
                            p_b_thread_now,
                            c_thread_mtx,
                            False,
429
                            p_c_thread);
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
        }

        // last loop
        {
            FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
            FloatB* p_b_thread_now = even_loop ? p_b_thread_0 : p_b_thread_1;

            // C = A * B
            threadwise_gemm(a_thread_mtx,
                            True,
                            p_a_thread_now,
                            b_thread_mtx,
                            False,
                            p_b_thread_now,
                            c_thread_mtx,
                            False,
446
                            p_c_thread);
Chao Liu's avatar
Chao Liu committed
447
448
        }
    }
Chao Liu's avatar
Chao Liu committed
449
};