#ifndef CK_BLOCKWISE_BATCHED_GEMM_HPP
#define CK_BLOCKWISE_BATCHED_GEMM_HPP

#include "threadwise_gemm.hpp"
namespace ck {

template <index_t BlockSize,
Chao Liu's avatar
Chao Liu committed
9
10
11
          class BlockMatrixA,
          class BlockMatrixB,
          class ThreadMatrixC,
Chao Liu's avatar
Chao Liu committed
12
13
14
15
16
17
18
19
20
21
22
          index_t BlockMatrixStrideA,
          index_t BlockMatrixStrideB,
          index_t ThreadMatrixStrideC,
          index_t BatchSize,
          index_t MPerThreadSubC,
          index_t NPerThreadSubC,
          index_t MLevel0Cluster,
          index_t NLevel0Cluster,
          index_t MLevel1Cluster,
          index_t NLevel1Cluster,
          index_t KPerThreadLoop,
Chao Liu's avatar
Chao Liu committed
23
24
25
          index_t BatchPerThread,
          index_t DataPerReadA,
          index_t DataPerReadB>
Chao Liu's avatar
Chao Liu committed
26
27
struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
{
Chao Liu's avatar
Chao Liu committed
28
29
    index_t mMyThreadOffsetA = 0;
    index_t mMyThreadOffsetB = 0;
Chao Liu's avatar
Chao Liu committed
30
31
32

    struct MatrixIndex
    {
Chao Liu's avatar
Chao Liu committed
33
34
35
        index_t batch;
        index_t row;
        index_t col;
Chao Liu's avatar
Chao Liu committed
36
37
38
39
40
41
42
    };

    __device__ BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2()
    {
        static_assert(BatchSize % BatchPerThread == 0,
                      "wrong! BatchSize is not dividable by BatchPerThread");

Chao Liu's avatar
Chao Liu committed
43
        constexpr index_t BatchThreadWork = BatchSize / BatchPerThread;
Chao Liu's avatar
Chao Liu committed
44

Chao Liu's avatar
Chao Liu committed
45
        constexpr index_t ThreadPerLevel1Cluster =
Chao Liu's avatar
Chao Liu committed
46
47
48
49
50
51
52
53
54
55
56
57
            MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;

        static_assert(BlockSize == BatchThreadWork * ThreadPerLevel1Cluster,
                      "wrong! wrong blocksize\n");

        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
                      "wrong! K dimension not consistent\n");

Chao Liu's avatar
Chao Liu committed
58
59
        constexpr index_t M = a_block_mtx.NCol(); // A is transposed
        constexpr index_t N = b_block_mtx.NCol();
Chao Liu's avatar
Chao Liu committed
60

Chao Liu's avatar
Chao Liu committed
61
62
        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();
Chao Liu's avatar
Chao Liu committed
63
64
65
66

        static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
                      "wrong! Cannot evenly divide thread work among repeat \n");

Chao Liu's avatar
Chao Liu committed
67
68
        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
Chao Liu's avatar
Chao Liu committed
69
70
71
72

        static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
                      "wrong! Cannot evenly divide work among repeat\n");

Chao Liu's avatar
Chao Liu committed
73
74
        constexpr index_t MPerLevel1Cluster = M / MRepeat;
        constexpr index_t NPerLevel1Cluster = N / NRepeat;
Chao Liu's avatar
Chao Liu committed
75
76
77
78
79

        static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
                          (NPerLevel1Cluster % NLevel1Cluster == 0),
                      "wrong! Cannot evenly divide work among Level1Cluster\n");

Chao Liu's avatar
Chao Liu committed
80
81
        constexpr index_t MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
        constexpr index_t NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
Chao Liu's avatar
Chao Liu committed
82
83
84
85
86
87
88
89
90
91
92
93

        static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
                          (NPerLevel0Cluster % NLevel0Cluster == 0),
                      "wrong! Cannot evenly divide work among Level0Cluster\n");

        static_assert((MPerThreadSubC == MPerLevel0Cluster / MLevel0Cluster) &&
                          (NPerThreadSubC == NPerLevel0Cluster / NLevel0Cluster),
                      "wrong! thread work size is wrong\n");

        const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        mMyThreadOffsetA = c_thread_mtx_index.batch * BlockMatrixStrideA +
94
                           a_block_mtx.GetOffsetFromMultiIndex(0, c_thread_mtx_index.row);
Chao Liu's avatar
Chao Liu committed
95
96

        mMyThreadOffsetB = c_thread_mtx_index.batch * BlockMatrixStrideB +
97
                           b_block_mtx.GetOffsetFromMultiIndex(0, c_thread_mtx_index.col);
Chao Liu's avatar
Chao Liu committed
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117

#if 0
        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
        {
            print_ConstantMatrixDescriptor(BlockMatrixA{}, "a_block_mtx: ");
            print_ConstantMatrixDescriptor(BlockMatrixB{}, "b_block_mtx: ");
            print_ConstantMatrixDescriptor(ThreadMatrixC{}, "c_thread_mtx: ");

            printf("%u %u, %u %u %u, %u %u\n",
                   get_block_1d_id(),
                   get_thread_local_1d_id(),
                   c_thread_mtx_index.batch,
                   c_thread_mtx_index.row,
                   c_thread_mtx_index.col,
                   mMyThreadOffsetA,
                   mMyThreadOffsetB);
        }
#endif
    }

Chao Liu's avatar
Chao Liu committed
118
    __device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
Chao Liu's avatar
Chao Liu committed
119
    {
Chao Liu's avatar
Chao Liu committed
120
        constexpr index_t ThreadPerLevel1Cluster =
Chao Liu's avatar
Chao Liu committed
121
122
            MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;

Chao Liu's avatar
Chao Liu committed
123
        constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;
Chao Liu's avatar
Chao Liu committed
124

Chao Liu's avatar
Chao Liu committed
125
126
        index_t batch_work_id = thread_id / ThreadPerLevel1Cluster;
        index_t cluster_id    = thread_id - batch_work_id * ThreadPerLevel1Cluster;
Chao Liu's avatar
Chao Liu committed
127

Chao Liu's avatar
Chao Liu committed
128
129
130
        index_t level1_id   = cluster_id / ThreadPerLevel0Cluster;
        index_t level1_m_id = level1_id / NLevel1Cluster;
        index_t level1_n_id = level1_id % NLevel1Cluster;
Chao Liu's avatar
Chao Liu committed
131

Chao Liu's avatar
Chao Liu committed
132
133
134
        index_t level0_id   = cluster_id % ThreadPerLevel0Cluster;
        index_t level0_m_id = level0_id / NLevel0Cluster;
        index_t level0_n_id = level0_id % NLevel0Cluster;
Chao Liu's avatar
Chao Liu committed
135

Chao Liu's avatar
Chao Liu committed
136
137
        constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
        constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;
Chao Liu's avatar
Chao Liu committed
138
139
140
141
142
143

        return MatrixIndex{batch_work_id * BatchPerThread,
                           level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
                           level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
    }

Chao Liu's avatar
Chao Liu committed
144
    // this should be optimized away because input will be known at compile time
Chao Liu's avatar
Chao Liu committed
145
    __device__ static MatrixIndex
Chao Liu's avatar
Chao Liu committed
146
    GetDistanceFromBeginOfThreadMatrixC(index_t batch_in_c, index_t m_in_c, index_t n_in_c)
Chao Liu's avatar
Chao Liu committed
147
148
149
    {
        constexpr auto c_thread_mtx = ThreadMatrixC{};

Chao Liu's avatar
Chao Liu committed
150
151
        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();
Chao Liu's avatar
Chao Liu committed
152

Chao Liu's avatar
Chao Liu committed
153
154
        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
Chao Liu's avatar
Chao Liu committed
155

Chao Liu's avatar
Chao Liu committed
156
157
        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
Chao Liu's avatar
Chao Liu committed
158

Chao Liu's avatar
Chao Liu committed
159
160
        index_t m_repeat = m_in_c / MPerThreadSubC;
        index_t n_repeat = n_in_c / NPerThreadSubC;
Chao Liu's avatar
Chao Liu committed
161

Chao Liu's avatar
Chao Liu committed
162
163
        index_t m_in_sub_c = m_in_c % MPerThreadSubC;
        index_t n_in_sub_c = n_in_c % NPerThreadSubC;
Chao Liu's avatar
Chao Liu committed
164
165
166
167
168
169

        return MatrixIndex{batch_in_c,
                           m_repeat * MPerLevel1Cluster + m_in_sub_c,
                           n_repeat * NPerLevel1Cluster + n_in_sub_c};
    }

Chao Liu's avatar
Chao Liu committed
170
    template <class FloatA, class FloatB, class FloatC>
Chao Liu's avatar
Chao Liu committed
171
172
    __device__ void Run(const FloatA* __restrict__ p_a_block,
                        const FloatB* __restrict__ p_b_block,
Chao Liu's avatar
Chao Liu committed
173
                        FloatC* __restrict__ p_c_thread) const
Chao Liu's avatar
Chao Liu committed
174
175
176
177
178
179
180
181
    {
        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};

        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

Chao Liu's avatar
Chao Liu committed
182
        constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed
Chao Liu's avatar
Chao Liu committed
183

Chao Liu's avatar
Chao Liu committed
184
185
        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();
Chao Liu's avatar
Chao Liu committed
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204

        // thread A, B for GEMM
        //   A is transposed, b is not
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});

        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});

        // thread A-sub, B-sub for copy
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});

        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

Chao Liu's avatar
Chao Liu committed
205
206
        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
Chao Liu's avatar
Chao Liu committed
207

Chao Liu's avatar
Chao Liu committed
208
209
        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
Chao Liu's avatar
Chao Liu committed
210
211
212

// loop over k
#pragma unroll
Chao Liu's avatar
Chao Liu committed
213
        for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
Chao Liu's avatar
Chao Liu committed
214
215
216
        {
// loop over batch
#pragma unroll
Chao Liu's avatar
Chao Liu committed
217
            for(index_t ib = 0; ib < BatchPerThread; ++ib)
Chao Liu's avatar
Chao Liu committed
218
219
            {
                // read next batch of a, b
Chao Liu's avatar
Chao Liu committed
220
                if(BlockMatrixStrideA != 0 or ib == 0)
Chao Liu's avatar
Chao Liu committed
221
222
                {
#pragma unroll
Chao Liu's avatar
Chao Liu committed
223
                    for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
Chao Liu's avatar
Chao Liu committed
224
225
226
227
                    {
                        threadwise_matrix_copy(
                            a_block_mtx,
                            p_a_block +
228
229
                                a_block_mtx.GetOffsetFromMultiIndex(k_begin,
                                                                    m_repeat * MPerLevel1Cluster) +
Chao Liu's avatar
Chao Liu committed
230
                                ib * BlockMatrixStrideA + mMyThreadOffsetA,
Chao Liu's avatar
Chao Liu committed
231
                            a_thread_mtx,
232
233
                            p_a_thread +
                                a_thread_mtx.GetOffsetFromMultiIndex(0, m_repeat * MPerThreadSubC),
Chao Liu's avatar
Chao Liu committed
234
235
                            a_thread_sub_mtx.GetLengths(),
                            Number<DataPerReadA>{});
Chao Liu's avatar
Chao Liu committed
236
237
238
                    }
                }

Chao Liu's avatar
Chao Liu committed
239
                if(BlockMatrixStrideB != 0 or ib == 0)
Chao Liu's avatar
Chao Liu committed
240
241
                {
#pragma unroll
Chao Liu's avatar
Chao Liu committed
242
                    for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
Chao Liu's avatar
Chao Liu committed
243
244
245
246
                    {
                        threadwise_matrix_copy(
                            b_block_mtx,
                            p_b_block +
247
248
                                b_block_mtx.GetOffsetFromMultiIndex(k_begin,
                                                                    n_repeat * NPerLevel1Cluster) +
Chao Liu's avatar
Chao Liu committed
249
                                ib * BlockMatrixStrideB + mMyThreadOffsetB,
Chao Liu's avatar
Chao Liu committed
250
                            b_thread_mtx,
251
252
                            p_b_thread +
                                b_thread_mtx.GetOffsetFromMultiIndex(0, n_repeat * NPerThreadSubC),
Chao Liu's avatar
Chao Liu committed
253
254
                            b_thread_sub_mtx.GetLengths(),
                            Number<DataPerReadB>{});
Chao Liu's avatar
Chao Liu committed
255
256
257
                    }
                }

258
259
260
#if 0
                if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
                {
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
                    printf("a: %f %f %f %f %f %f %f %f, b: %f %f %f %f %f %f %f %f\n",
                           p_a_thread[0],
                           p_a_thread[1],
                           p_a_thread[2],
                           p_a_thread[3],
                           p_a_thread[4],
                           p_a_thread[5],
                           p_a_thread[6],
                           p_a_thread[7],
                           p_b_thread[0],
                           p_b_thread[1],
                           p_b_thread[2],
                           p_b_thread[3],
                           p_b_thread[4],
                           p_b_thread[5],
                           p_b_thread[6],
                           p_b_thread[7]);
278
279
280
                }
#endif

Chao Liu's avatar
Chao Liu committed
281
282
283
284
285
286
287
288
289
290
                threadwise_gemm(a_thread_mtx,
                                True,
                                p_a_thread,
                                b_thread_mtx,
                                False,
                                p_b_thread,
                                c_thread_mtx,
                                False,
                                p_c_thread + ib * ThreadMatrixStrideC);
            }
Chao Liu's avatar
Chao Liu committed
291
292
293
        }
    }

294
#if CK_USE_AMD_INLINE_ASM
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
    template <class FloatA, class FloatB, class FloatC>
    __device__ void Run_asm(const FloatA* __restrict__ p_a_block,
                            const FloatB* __restrict__ p_b_block,
                            FloatC* __restrict__ p_c_thread) const
    {
        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t K = a_block_mtx.NRow(); // A is transposed

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        // thread A, B for GEMM
        //   A is transposed, b is not
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});

        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});

        // thread A-sub, B-sub for copy
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});

        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        // assertion for inline asm
        static_assert(is_same<FloatA, float>::value && is_same<FloatB, float>::value &&
                          is_same<FloatC, float>::value,
                      "Run_asm only deal with float\n");

        static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
                          MPerThread == 8 && NPerThread == 8,
                      "Run_asm cannot deal with this GEMM shape yet\n");

339
340
        static_assert(DataPerReadA == 4 && DataPerReadB == 4, "Run_asm only do float4 read\n");

341
342
343
344
345
346
347
348
349
350
351
352
        static_assert(
            BlockMatrixStrideA == 0 && BatchPerThread == 1,
            "Run_asm can only deal with BlockMatrixStrideA == 0 && BatchPerThread == 1 for now\n");

        using Float4 = vector_type<float, 4>::MemoryType;

        Float4* reg_a = (Float4*)(p_a_thread);
        Float4* reg_b = (Float4*)(p_b_thread);
        Float4* reg_c = (Float4*)(p_c_thread);

        reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA]);
        reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB]);
Chao Liu's avatar
Chao Liu committed
353
        reg_b[1] = *reinterpret_cast<const Float4*>(
354
355
            &p_b_block[b_block_mtx.GetOffsetFromMultiIndex(0, NPerLevel1Cluster) +
                       mMyThreadOffsetB]);
Chao Liu's avatar
Chao Liu committed
356
        reg_a[1] = *reinterpret_cast<const Float4*>(
357
358
            &p_a_block[a_block_mtx.GetOffsetFromMultiIndex(0, MPerLevel1Cluster) +
                       mMyThreadOffsetA]);
359
360
361
362
363
364
        outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
        outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);

#pragma unroll
        for(index_t k = 1; k < K; ++k)
        {
Chao Liu's avatar
Chao Liu committed
365
            reg_a[0] = *reinterpret_cast<const Float4*>(
366
                &p_a_block[a_block_mtx.GetOffsetFromMultiIndex(k, 0) + mMyThreadOffsetA]);
367
            outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
Chao Liu's avatar
Chao Liu committed
368
            reg_b[0] = *reinterpret_cast<const Float4*>(
369
                &p_b_block[b_block_mtx.GetOffsetFromMultiIndex(k, 0) + mMyThreadOffsetB]);
370
371
            outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
            reg_b[1] = *reinterpret_cast<const Float4*>(
372
373
                &p_b_block[b_block_mtx.GetOffsetFromMultiIndex(k, NPerLevel1Cluster) +
                           mMyThreadOffsetB]);
374
            reg_a[1] = *reinterpret_cast<const Float4*>(
375
376
                &p_a_block[a_block_mtx.GetOffsetFromMultiIndex(k, MPerLevel1Cluster) +
                           mMyThreadOffsetA]);
377
378
379
380
381
382
            outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
            outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
        }
        outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
        outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
    }
Chao Liu's avatar
Chao Liu committed
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444

    template <class FloatA, class FloatB, class FloatC>
    __device__ void Run_asm_v2(const FloatA* __restrict__ p_a_block,
                               const FloatB* __restrict__ p_b_block,
                               FloatC* __restrict__ p_c_thread) const
    {
        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t M = a_block_mtx.NCol();
        constexpr index_t N = b_block_mtx.NCol();
        constexpr index_t K = a_block_mtx.NRow(); // A is transposed

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        // thread A, B for GEMM
        //   A is transposed, b is not
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});

        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});

        // thread A-sub, B-sub for copy
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});

        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        // assertion for inline asm
        static_assert(is_same<FloatA, float>::value && is_same<FloatB, float>::value &&
                          is_same<FloatC, float>::value,
                      "Run_asm only deal with float\n");

        static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
                          MPerThread == 8 && NPerThread == 8,
                      "Run_asm cannot deal with this GEMM shape yet\n");

        static_assert(DataPerReadA == 4 && DataPerReadB == 4, "Run_asm only do float4 read\n");

        static_assert(
            BlockMatrixStrideA == 0 && BatchPerThread == 1,
            "Run_asm can only deal with BlockMatrixStrideA == 0 && BatchPerThread == 1 for now\n");

        using Float4 = vector_type<float, 4>::MemoryType;

        Float4* reg_a = (Float4*)(p_a_thread);
        Float4* reg_b = (Float4*)(p_b_thread);
        Float4* reg_c = (Float4*)(p_c_thread);

        void* a_lds_loc = (void*)(p_a_block + mMyThreadOffsetA);
        void* b_lds_loc = (void*)(p_b_block + mMyThreadOffsetB);

Chao Liu's avatar
Chao Liu committed
445
446
447
448
        constexpr index_t a_lds_row_stride         = sizeof(float) * a_block_mtx.RowStride();
        constexpr index_t b_lds_row_stride         = sizeof(float) * b_block_mtx.RowStride();
        constexpr index_t a_lds_cluster_col_stride = sizeof(float) * MPerLevel1Cluster;
        constexpr index_t b_lds_cluster_col_stride = sizeof(float) * NPerLevel1Cluster;
Chao Liu's avatar
Chao Liu committed
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478

        ds_read_b128(reg_a[0], a_lds_loc, 0);
        ds_read_b128(reg_b[0], b_lds_loc, 0);
        ds_read_b128(reg_b[1], b_lds_loc, b_lds_cluster_col_stride);
        ds_read_b128(reg_a[1], a_lds_loc, a_lds_cluster_col_stride);
        lgkmcnt(2);
        outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
        lgkmcnt(1);
        outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);

#pragma unroll
        for(index_t k = 1; k < K; ++k)
        {
            ds_read_b128(reg_a[0], a_lds_loc, k * a_lds_row_stride);
            lgkmcnt(1);
            outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
            ds_read_b128(reg_b[0], b_lds_loc, k * b_lds_row_stride);
            outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
            ds_read_b128(reg_b[1], b_lds_loc, b_lds_cluster_col_stride + k * b_lds_row_stride);
            ds_read_b128(reg_a[1], a_lds_loc, a_lds_cluster_col_stride + k * a_lds_row_stride);
            lgkmcnt(2);
            outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
            lgkmcnt(1);
            outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
        }

        lgkmcnt(0);
        outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
        outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
    }
479
480
#endif

Chao Liu's avatar
Chao Liu committed
481
    template <class BlockMatrixC, index_t BlockMatrixStrideC, class FloatC>
Chao Liu's avatar
Chao Liu committed
482
483
484
485
486
487
    __device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread,
                                                    FloatC* __restrict__ p_c_block) const
    {
        constexpr auto c_block_mtx  = BlockMatrixC{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

Chao Liu's avatar
Chao Liu committed
488
489
        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();
Chao Liu's avatar
Chao Liu committed
490
491
492
493

        constexpr auto c_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

Chao Liu's avatar
Chao Liu committed
494
495
        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
Chao Liu's avatar
Chao Liu committed
496

Chao Liu's avatar
Chao Liu committed
497
498
        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
Chao Liu's avatar
Chao Liu committed
499
500
501

        const auto c_thread_mtx_begin = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

Chao Liu's avatar
Chao Liu committed
502
        const index_t c_thread_offset =
Chao Liu's avatar
Chao Liu committed
503
            c_thread_mtx_begin.batch * BlockMatrixStrideC +
504
            c_block_mtx.GetOffsetFromMultiIndex(c_thread_mtx_begin.row, c_thread_mtx_begin.col);
Chao Liu's avatar
Chao Liu committed
505

Chao Liu's avatar
Chao Liu committed
506
        for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
Chao Liu's avatar
Chao Liu committed
507
        {
Chao Liu's avatar
Chao Liu committed
508
            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
Chao Liu's avatar
Chao Liu committed
509
510
511
            {
                threadwise_matrix_copy(
                    c_thread_sub_mtx,
Chao Liu's avatar
Chao Liu committed
512
                    p_c_thread +
513
514
                        c_thread_sub_mtx.GetOffsetFromMultiIndex(m_repeat * MPerLevel1Cluster,
                                                                 n_repeat * NPerLevel1Cluster),
Chao Liu's avatar
Chao Liu committed
515
516
                    c_block_mtx,
                    p_c_block +
517
518
                        c_block_mtx.GetOffsetFromMultiIndex(m_repeat * MPerLevel1Cluster,
                                                            n_repeat * NPerLevel1Cluster) +
Chao Liu's avatar
Chao Liu committed
519
520
521
522
523
524
                        c_thread_offset,
                    c_thread_sub_mtx.GetLengths());
            }
        }
    }
};

} // namespace
#endif