#pragma once
#include "threadwise_gemm.hip.hpp"

// Blockwise batched GEMM: for each batch, C += transpose(A) * B, where A and B
// are block-scope matrices (expected to live in LDS) and C is accumulated in
// thread-private registers.
//
// Thread organization: the block is split into BatchSize / BatchPerThread
// batch groups; within a group, threads form a two-level M x N cluster
// (Level0 nested inside Level1). Each thread owns an MPerThread x NPerThread
// tile of C per batch, itself tiled into MPerThreadSubC x NPerThreadSubC
// sub-tiles repeated MRepeat x NRepeat times across the cluster.
//
// Template parameters:
//   BlockMatrixA/B       - constant matrix descriptors for the block tiles of
//                          A (stored transposed: K rows x M cols) and B (K x N).
//   ThreadMatrixC        - descriptor of one thread's C tile (MPerThread x NPerThread).
//   BlockMatrixStrideA/B - element stride between consecutive batches of A/B in
//                          block memory; 0 means the matrix is shared by all batches.
//   ThreadMatrixStrideC  - stride between consecutive batches in the thread's C.
//   DataPerReadA/B       - vector width used when copying A/B into registers.
template <index_t BlockSize,
          class BlockMatrixA,
          class BlockMatrixB,
          class ThreadMatrixC,
          index_t BlockMatrixStrideA,
          index_t BlockMatrixStrideB,
          index_t ThreadMatrixStrideC,
          index_t BatchSize,
          index_t MPerThreadSubC,
          index_t NPerThreadSubC,
          index_t MLevel0Cluster,
          index_t NLevel0Cluster,
          index_t MLevel1Cluster,
          index_t NLevel1Cluster,
          index_t KPerThreadLoop,
          index_t BatchPerThread,
          index_t DataPerReadA,
          index_t DataPerReadB>
struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
{
    // Per-thread starting offsets (in elements) into the block matrices A and B,
    // computed once in the constructor from this thread's cluster position.
    index_t mMyThreadOffsetA = 0;
    index_t mMyThreadOffsetB = 0;

    // (batch, row, col) coordinate of a thread's tile inside the block-level C.
    struct MatrixIndex
    {
        index_t batch;
        index_t row;
        index_t col;
    };

    __device__ BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2()
    {
        static_assert(BatchSize % BatchPerThread == 0,
                      "wrong! BatchSize is not dividable by BatchPerThread");

        constexpr index_t BatchThreadWork = BatchSize / BatchPerThread;

        constexpr index_t ThreadPerLevel1Cluster =
            MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;

        // every thread in the block must map to exactly one (batch group, cluster slot)
        static_assert(BlockSize == BatchThreadWork * ThreadPerLevel1Cluster,
                      "wrong! wrong blocksize\n");

        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
                      "wrong! K dimension not consistent\n");

        constexpr index_t M = a_block_mtx.NCol(); // A is transposed
        constexpr index_t N = b_block_mtx.NCol();

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
                      "wrong! Cannot evenly divide thread work among repeat \n");

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
                      "wrong! Cannot evenly divide work among repeat\n");

        constexpr index_t MPerLevel1Cluster = M / MRepeat;
        constexpr index_t NPerLevel1Cluster = N / NRepeat;

        static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
                          (NPerLevel1Cluster % NLevel1Cluster == 0),
                      "wrong! Cannot evenly divide work among Level1Cluster\n");

        constexpr index_t MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
        constexpr index_t NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;

        static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
                          (NPerLevel0Cluster % NLevel0Cluster == 0),
                      "wrong! Cannot evenly divide work among Level0Cluster\n");

        // the cluster hierarchy must bottom out exactly at the per-thread sub-tile
        static_assert((MPerThreadSubC == MPerLevel0Cluster / MLevel0Cluster) &&
                          (NPerThreadSubC == NPerLevel0Cluster / NLevel0Cluster),
                      "wrong! thread work size is wrong\n");

        const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        // A is K x M (transposed), so the thread's M-coordinate selects a column;
        // likewise B is K x N and the N-coordinate selects a column.
        mMyThreadOffsetA = c_thread_mtx_index.batch * BlockMatrixStrideA +
                           a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row);

        mMyThreadOffsetB = c_thread_mtx_index.batch * BlockMatrixStrideB +
                           b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col);
    }

    // Map a flat thread id to the (batch, row, col) origin of its C tile.
    // Decomposition: batch group first, then Level1 cluster slot, then Level0
    // slot within it; N varies fastest at both cluster levels.
    __device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
    {
        constexpr index_t ThreadPerLevel1Cluster =
            MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;

        constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;

        index_t batch_work_id = thread_id / ThreadPerLevel1Cluster;
        index_t cluster_id    = thread_id - batch_work_id * ThreadPerLevel1Cluster;

        index_t level1_id   = cluster_id / ThreadPerLevel0Cluster;
        index_t level1_m_id = level1_id / NLevel1Cluster;
        index_t level1_n_id = level1_id % NLevel1Cluster;

        index_t level0_id   = cluster_id % ThreadPerLevel0Cluster;
        index_t level0_m_id = level0_id / NLevel0Cluster;
        index_t level0_n_id = level0_id % NLevel0Cluster;

        constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
        constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;

        return MatrixIndex{batch_work_id * BatchPerThread,
                           level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
                           level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
    }

    // Translate an in-tile C coordinate into a distance in the block-level C:
    // the repeat index advances by a whole Level1 cluster, the remainder stays
    // within the sub-tile. This should be optimized away because input will be
    // known at compile time.
    __device__ static MatrixIndex
    GetDistanceFromBeginOfThreadMatrixC(index_t batch_in_c, index_t m_in_c, index_t n_in_c)
    {
        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        index_t m_repeat = m_in_c / MPerThreadSubC;
        index_t n_repeat = n_in_c / NPerThreadSubC;

        index_t m_in_sub_c = m_in_c % MPerThreadSubC;
        index_t n_in_sub_c = n_in_c % NPerThreadSubC;

        return MatrixIndex{batch_in_c,
                           m_repeat * MPerLevel1Cluster + m_in_sub_c,
                           n_repeat * NPerLevel1Cluster + n_in_sub_c};
    }

    // Generic accumulation: C[ib] += transpose(A[ib]) * B[ib] for this thread's
    // tile, looping over K in KPerThreadLoop steps and over BatchPerThread
    // batches. A stride of 0 for A (or B) means the matrix is shared across
    // batches and is copied into registers only once (at ib == 0).
    template <class FloatA, class FloatB, class FloatC>
    __device__ void Run(const FloatA* __restrict__ p_a_block,
                        const FloatB* __restrict__ p_b_block,
                        FloatC* __restrict__ p_c_thread) const
    {
        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};

        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        // thread A, B for GEMM
        //   A is transposed, b is not
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});

        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});

        // thread A-sub, B-sub for copy (sub-tile lengths with the full thread
        // matrix as row stride, so repeats land side by side in registers)
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});

        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

// loop over k
#pragma unroll
        for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
        {
// loop over batch
#pragma unroll
            for(index_t ib = 0; ib < BatchPerThread; ++ib)
            {
                // read next batch of a; skip re-reading when A is batch-shared
                if(BlockMatrixStrideA != 0 or ib == 0)
                {
#pragma unroll
                    for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
                    {
                        threadwise_matrix_copy(
                            a_block_mtx,
                            p_a_block +
                                a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
                                ib * BlockMatrixStrideA + mMyThreadOffsetA,
                            a_thread_mtx,
                            p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
                            a_thread_sub_mtx.GetLengths(),
                            Number<DataPerReadA>{});
                    }
                }

                // read next batch of b; skip re-reading when B is batch-shared
                if(BlockMatrixStrideB != 0 or ib == 0)
                {
#pragma unroll
                    for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
                    {
                        threadwise_matrix_copy(
                            b_block_mtx,
                            p_b_block +
                                b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) +
                                ib * BlockMatrixStrideB + mMyThreadOffsetB,
                            b_thread_mtx,
                            p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
                            b_thread_sub_mtx.GetLengths(),
                            Number<DataPerReadB>{});
                    }
                }

                // C[ib] += transpose(a_thread) * b_thread
                threadwise_gemm(a_thread_mtx,
                                True,
                                p_a_thread,
                                b_thread_mtx,
                                False,
                                p_b_thread,
                                c_thread_mtx,
                                False,
                                p_c_thread + ib * ThreadMatrixStrideC);
            }
        }
    }

#if DEVICE_BACKEND_HIP
    // Hand-scheduled float4 variant of Run(): the 8x8 thread tile is computed
    // as four 4x4 outer products, software-pipelined so loads for iteration k
    // overlap math on iteration k-1. Restricted to float, 8x8 tiles,
    // KPerThreadLoop == 1, float4 reads, and batch-shared A with
    // BatchPerThread == 1 (see static_asserts).
    template <class FloatA, class FloatB, class FloatC>
    __device__ void Run_asm(const FloatA* __restrict__ p_a_block,
                            const FloatB* __restrict__ p_b_block,
                            FloatC* __restrict__ p_c_thread) const
    {
        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t M = a_block_mtx.NCol();
        constexpr index_t N = b_block_mtx.NCol();
        constexpr index_t K = a_block_mtx.NRow(); // A is transposed

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        // thread A, B for GEMM
        //   A is transposed, b is not
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});

        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});

        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        // assertion for inline asm
        static_assert(is_same<FloatA, float>::value && is_same<FloatB, float>::value &&
                          is_same<FloatC, float>::value,
                      "Run_asm only deal with float\n");

        static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
                          MPerThread == 8 && NPerThread == 8,
                      "Run_asm cannot deal with this GEMM shape yet\n");

        static_assert(DataPerReadA == 4 && DataPerReadB == 4, "Run_asm only do float4 read\n");

        static_assert(
            BlockMatrixStrideA == 0 && BatchPerThread == 1,
            "Run_asm can only deal with BlockMatrixStrideA == 0 && BatchPerThread == 1 for now\n");

        using Float4 = vector_type<float, 4>::MemoryType;

        // view register arrays as float4 lanes: 2 for A, 2 for B, 16 for C
        Float4* reg_a = (Float4*)(p_a_thread);
        Float4* reg_b = (Float4*)(p_b_thread);
        Float4* reg_c = (Float4*)(p_c_thread);

        // prologue: load k = 0 operands and start the two upper-half products
        reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA]);
        reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB]);
        reg_b[1] =
            *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + NPerLevel1Cluster]);
        reg_a[1] =
            *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + MPerLevel1Cluster]);
        outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
        outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);

// steady state: interleave k-th loads with (k-1)-th lower-half products
#pragma unroll
        for(index_t k = 1; k < K; ++k)
        {
            reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + k * M]);
            outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
            reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + k * N]);
            outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
            reg_b[1] = *reinterpret_cast<const Float4*>(
                &p_b_block[mMyThreadOffsetB + k * N + NPerLevel1Cluster]);
            reg_a[1] = *reinterpret_cast<const Float4*>(
                &p_a_block[mMyThreadOffsetA + k * M + MPerLevel1Cluster]);
            outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
            outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
        }

        // epilogue: finish the last iteration's lower-half products
        outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
        outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
    }

    // Like Run_asm(), but issues explicit LDS reads (ds_read_b128) with manual
    // lgkmcnt waits so math can overlap outstanding LDS traffic. Same shape and
    // type restrictions as Run_asm().
    template <class FloatA, class FloatB, class FloatC>
    __device__ void Run_asm_v2(const FloatA* __restrict__ p_a_block,
                               const FloatB* __restrict__ p_b_block,
                               FloatC* __restrict__ p_c_thread) const
    {
        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t M = a_block_mtx.NCol();
        constexpr index_t N = b_block_mtx.NCol();
        constexpr index_t K = a_block_mtx.NRow(); // A is transposed

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        // thread A, B for GEMM
        //   A is transposed, b is not
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});

        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});

        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        // assertion for inline asm
        static_assert(is_same<FloatA, float>::value && is_same<FloatB, float>::value &&
                          is_same<FloatC, float>::value,
                      "Run_asm_v2 only deal with float\n");

        static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
                          MPerThread == 8 && NPerThread == 8,
                      "Run_asm_v2 cannot deal with this GEMM shape yet\n");

        static_assert(DataPerReadA == 4 && DataPerReadB == 4,
                      "Run_asm_v2 only do float4 read\n");

        static_assert(BlockMatrixStrideA == 0 && BatchPerThread == 1,
                      "Run_asm_v2 can only deal with BlockMatrixStrideA == 0 && BatchPerThread == "
                      "1 for now\n");

        using Float4 = vector_type<float, 4>::MemoryType;

        Float4* reg_a = (Float4*)(p_a_thread);
        Float4* reg_b = (Float4*)(p_b_thread);
        Float4* reg_c = (Float4*)(p_c_thread);

        void* a_lds_loc = (void*)(p_a_block + mMyThreadOffsetA);
        void* b_lds_loc = (void*)(p_b_block + mMyThreadOffsetB);

        // byte strides for ds_read offsets (fix: original used sizeof(Float),
        // an undeclared name; the static_asserts above pin all types to float)
        constexpr index_t a_lds_row_stride         = sizeof(float) * M;
        constexpr index_t b_lds_row_stride         = sizeof(float) * N;
        constexpr index_t a_lds_cluster_col_stride = sizeof(float) * MPerLevel1Cluster;
        constexpr index_t b_lds_cluster_col_stride = sizeof(float) * NPerLevel1Cluster;

        // prologue: issue k = 0 loads, wait only for what each product needs
        ds_read_b128(reg_a[0], a_lds_loc, 0);
        ds_read_b128(reg_b[0], b_lds_loc, 0);
        ds_read_b128(reg_b[1], b_lds_loc, b_lds_cluster_col_stride);
        ds_read_b128(reg_a[1], a_lds_loc, a_lds_cluster_col_stride);
        lgkmcnt(2); // reg_a[0], reg_b[0] landed
        outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
        lgkmcnt(1); // reg_b[1] landed
        outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);

// steady state: k-th loads overlap (k-1)-th lower-half products
#pragma unroll
        for(index_t k = 1; k < K; ++k)
        {
            ds_read_b128(reg_a[0], a_lds_loc, k * a_lds_row_stride);
            lgkmcnt(1); // reg_a[1] from previous iteration landed
            outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
            ds_read_b128(reg_b[0], b_lds_loc, k * b_lds_row_stride);
            outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
            ds_read_b128(reg_b[1], b_lds_loc, b_lds_cluster_col_stride + k * b_lds_row_stride);
            ds_read_b128(reg_a[1], a_lds_loc, a_lds_cluster_col_stride + k * a_lds_row_stride);
            lgkmcnt(2); // reg_a[0], reg_b[0] for this k landed
            outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
            lgkmcnt(1); // reg_b[1] for this k landed
            outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
        }

        lgkmcnt(0); // drain all outstanding LDS reads before the epilogue
        outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
        outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
    }
#endif

    // Scatter this thread's C accumulator back into a block-level C matrix.
    // Each MPerThreadSubC x NPerThreadSubC sub-tile is copied from its compact
    // position in the thread matrix to its Level1-cluster-strided position in
    // the block matrix.
    // NOTE(review): only the thread's first batch is copied (no BatchPerThread
    // loop, no ThreadMatrixStrideC offset) — presumably callers use
    // BatchPerThread == 1 here; confirm against call sites.
    template <class BlockMatrixC, index_t BlockMatrixStrideC, class FloatC>
    __device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread,
                                                    FloatC* __restrict__ p_c_block) const
    {
        constexpr auto c_block_mtx  = BlockMatrixC{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        constexpr auto c_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        const auto c_thread_mtx_begin = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        const index_t c_thread_offset =
            c_thread_mtx_begin.batch * BlockMatrixStrideC +
            c_block_mtx.Get1dIndex(c_thread_mtx_begin.row, c_thread_mtx_begin.col);

        for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
        {
            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
            {
                // Fix: the source (thread matrix) is only MPerThread x NPerThread,
                // so sub-tiles are addressed with MPerThreadSubC/NPerThreadSubC
                // (matching how Run() indexes p_a_thread/p_b_thread); the original
                // used the Level1-cluster strides here, overrunning the
                // accumulator. The destination keeps the cluster strides.
                threadwise_matrix_copy(
                    c_thread_sub_mtx,
                    p_c_thread +
                        c_thread_sub_mtx.Get1dIndex(m_repeat * MPerThreadSubC,
                                                    n_repeat * NPerThreadSubC),
                    c_block_mtx,
                    p_c_block +
                        c_block_mtx.Get1dIndex(m_repeat * MPerLevel1Cluster,
                                               n_repeat * NPerLevel1Cluster) +
                        c_thread_offset,
                    c_thread_sub_mtx.GetLengths());
            }
        }
    }
};