#pragma once
#include "threadwise_gemm.hip.hpp"

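// Blockwise batched GEMM: for each batch b, accumulate transpose(A[b]) * B[b] into a
// per-thread tile of C. A and B are block matrices (typically in LDS); C lives in
// thread registers. The BlockSize threads form BatchSize / BatchPerThread copies of a
// two-level (Level1 over Level0) thread cluster; each thread owns an
// MPerThread x NPerThread tile of C made of MRepeat x NRepeat sub-tiles of size
// MPerThreadSubC x NPerThreadSubC. A block-matrix stride of 0 means that matrix is
// shared by all batches and is read only once per k iteration.
//
// A minimal instantiation sketch (all numbers hypothetical, for illustration only):
//   constexpr auto a_blk = make_ConstantMatrixDescriptor(Number<8>{}, Number<128>{}); // K x M
//   constexpr auto b_blk = make_ConstantMatrixDescriptor(Number<8>{}, Number<128>{}); // K x N
//   constexpr auto c_thd = make_ConstantMatrixDescriptor(Number<8>{}, Number<8>{});
//   using BlockGemm = BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
//       256, decltype(a_blk), decltype(b_blk), decltype(c_thd),
//       0, 0, 64, 1,       // BlockMatrixStrideA/B, ThreadMatrixStrideC, BatchSize
//       4, 4, 4, 4, 4, 4,  // M/NPerThreadSubC, M/NLevel0Cluster, M/NLevel1Cluster
//       1, 1, 4, 4>;       // KPerThreadLoop, BatchPerThread, DataPerReadA/B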
template <index_t BlockSize,
          class BlockMatrixA,
          class BlockMatrixB,
          class ThreadMatrixC,
          index_t BlockMatrixStrideA,
          index_t BlockMatrixStrideB,
          index_t ThreadMatrixStrideC,
          index_t BatchSize,
          index_t MPerThreadSubC,
          index_t NPerThreadSubC,
          index_t MLevel0Cluster,
          index_t NLevel0Cluster,
          index_t MLevel1Cluster,
          index_t NLevel1Cluster,
          index_t KPerThreadLoop,
          index_t BatchPerThread,
          index_t DataPerReadA,
          index_t DataPerReadB>
struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
{
    index_t mMyThreadOffsetA = 0;
    index_t mMyThreadOffsetB = 0;

    struct MatrixIndex
    {
        index_t batch;
        index_t row;
        index_t col;
    };

    __device__ BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2()
    {
        static_assert(BatchSize % BatchPerThread == 0,
                      "wrong! BatchSize is not divisible by BatchPerThread");

        constexpr index_t BatchThreadWork = BatchSize / BatchPerThread;

        constexpr index_t ThreadPerLevel1Cluster =
            MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;

        static_assert(BlockSize == BatchThreadWork * ThreadPerLevel1Cluster,
                      "wrong! BlockSize does not match the thread arrangement\n");

        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
                      "wrong! K dimension not consistent\n");

        constexpr index_t M = a_block_mtx.NCol(); // A is transposed
        constexpr index_t N = b_block_mtx.NCol();
        constexpr index_t K = a_block_mtx.NRow();

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0),
                      "wrong! Cannot evenly divide thread work among repeats\n");

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        static_assert((M % MRepeat == 0) && (N % NRepeat == 0),
                      "wrong! Cannot evenly divide work among repeats\n");

        constexpr index_t MPerLevel1Cluster = M / MRepeat;
        constexpr index_t NPerLevel1Cluster = N / NRepeat;

        static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) &&
                          (NPerLevel1Cluster % NLevel1Cluster == 0),
                      "wrong! Cannot evenly divide work among Level1Cluster\n");

        constexpr index_t MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster;
        constexpr index_t NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster;
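
        // worked example (hypothetical numbers): with MPerThreadSubC = NPerThreadSubC = 4,
        // MLevel0Cluster = NLevel0Cluster = 4 and MLevel1Cluster = NLevel1Cluster = 4,
        // a Level0 cluster covers a 16x16 tile of C, a Level1 cluster covers 64x64,
        // M = MRepeat * 64, N = NRepeat * 64, and ThreadPerLevel1Cluster = 256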

        static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) &&
                          (NPerLevel0Cluster % NLevel0Cluster == 0),
                      "wrong! Cannot evenly divide work among Level0Cluster\n");

        static_assert((MPerThreadSubC == MPerLevel0Cluster / MLevel0Cluster) &&
                          (NPerThreadSubC == NPerLevel0Cluster / NLevel0Cluster),
                      "wrong! inconsistent thread sub-tile size\n");

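        // this thread's read origins: A is stored K x M (transposed) and B is K x N, so
        // the thread starts at column 'row' of A and column 'col' of B, plus the offset
        // of its batch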
        const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        mMyThreadOffsetA = c_thread_mtx_index.batch * BlockMatrixStrideA +
                           a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row);

        mMyThreadOffsetB = c_thread_mtx_index.batch * BlockMatrixStrideB +
                           b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col);

#if 0
        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
        {
            print_ConstantMatrixDescriptor(BlockMatrixA{}, "a_block_mtx: ");
            print_ConstantMatrixDescriptor(BlockMatrixB{}, "b_block_mtx: ");
            print_ConstantMatrixDescriptor(ThreadMatrixC{}, "c_thread_mtx: ");

            printf("%u %u, %u %u %u, %u %u\n",
                   get_block_1d_id(),
                   get_thread_local_1d_id(),
                   c_thread_mtx_index.batch,
                   c_thread_mtx_index.row,
                   c_thread_mtx_index.col,
                   mMyThreadOffsetA,
                   mMyThreadOffsetB);
        }
#endif
    }

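    // map a flat thread id to the (batch, row, col) origin of that thread's C tile;
    // ids are batch-major, and the remainder is decomposed row-major into Level1 and
    // then Level0 cluster coordinates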
    __device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const
    {
        constexpr index_t BatchThreadWork = BatchSize / BatchPerThread;

        constexpr index_t ThreadPerLevel1Cluster =
            MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;

        constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;

        index_t batch_work_id = thread_id / ThreadPerLevel1Cluster;
        index_t cluster_id    = thread_id - batch_work_id * ThreadPerLevel1Cluster;

        index_t level1_id   = cluster_id / ThreadPerLevel0Cluster;
        index_t level1_m_id = level1_id / NLevel1Cluster;
        index_t level1_n_id = level1_id % NLevel1Cluster;

        index_t level0_id   = cluster_id % ThreadPerLevel0Cluster;
        index_t level0_m_id = level0_id / NLevel0Cluster;
        index_t level0_n_id = level0_id % NLevel0Cluster;

        constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
        constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;

        return MatrixIndex{batch_work_id * BatchPerThread,
                           level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
                           level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
    }

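    // convert a (batch, m, n) coordinate within this thread's C tile into its distance
    // from the tile origin inside the block matrices; successive sub-tile repeats are
    // MPerLevel1Cluster / NPerLevel1Cluster elements apart in the block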
    // this should be optimized away because the inputs are known at compile time
    __device__ static MatrixIndex
    GetDistanceFromBeginOfThreadMatrixC(index_t batch_in_c, index_t m_in_c, index_t n_in_c)
    {
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        index_t m_repeat = m_in_c / MPerThreadSubC;
        index_t n_repeat = n_in_c / NPerThreadSubC;

        index_t m_in_sub_c = m_in_c % MPerThreadSubC;
        index_t n_in_sub_c = n_in_c % NPerThreadSubC;

        return MatrixIndex{batch_in_c,
                           m_repeat * MPerLevel1Cluster + m_in_sub_c,
                           n_repeat * NPerLevel1Cluster + n_in_sub_c};
    }

    template <class FloatA, class FloatB, class FloatC>
    __device__ void Run(const FloatA* __restrict__ p_a_block,
                        const FloatB* __restrict__ p_b_block,
                        FloatC* __restrict__ p_c_thread) const
    {
        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};

        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        // thread A, B for GEMM
        //   A is transposed, B is not
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});

        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});

        // thread A-sub, B-sub for copy
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});

        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

// loop over k
#pragma unroll
        for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
        {
// loop over batch
#pragma unroll
            for(index_t ib = 0; ib < BatchPerThread; ++ib)
            {
                // read next batch of a, b
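                // a zero BlockMatrixStrideA means A is shared across batches, so it is
                // copied only on the first batch iteration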
                if(BlockMatrixStrideA != 0 or ib == 0)
                {
#pragma unroll
                    for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
                    {
                        threadwise_matrix_copy(
                            a_block_mtx,
                            p_a_block +
                                a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
                                ib * BlockMatrixStrideA + mMyThreadOffsetA,
                            a_thread_mtx,
                            p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
                            a_thread_sub_mtx.GetLengths(),
                            Number<DataPerReadA>{});
                    }
                }

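                // likewise, a zero BlockMatrixStrideB means B is shared across batches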
                if(BlockMatrixStrideB != 0 or ib == 0)
                {
#pragma unroll
                    for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
                    {
                        threadwise_matrix_copy(
                            b_block_mtx,
                            p_b_block +
                                b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) +
                                ib * BlockMatrixStrideB + mMyThreadOffsetB,
                            b_thread_mtx,
                            p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
                            b_thread_sub_mtx.GetLengths(),
                            Number<DataPerReadB>{});
                    }
                }

#if 0
                if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
                {
                        printf("a: %f %f %f %f %f %f %f %f, b: %f %f %f %f %f %f %f %f\n",
                               p_a_thread[0], p_a_thread[1], p_a_thread[2], p_a_thread[3], p_a_thread[4], p_a_thread[5], p_a_thread[6], p_a_thread[7],
                               p_b_thread[0], p_b_thread[1], p_b_thread[2], p_b_thread[3], p_b_thread[4], p_b_thread[5], p_b_thread[6], p_b_thread[7]);
                }
#endif

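                // C(ib) += transpose(a_thread) * b_thread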
                threadwise_gemm(a_thread_mtx,
                                True,
                                p_a_thread,
                                b_thread_mtx,
                                False,
                                p_b_thread,
                                c_thread_mtx,
                                False,
                                p_c_thread + ib * ThreadMatrixStrideC);
            }
        }
    }

#if DEVICE_BACKEND_HIP
    template <class FloatA, class FloatB, class FloatC>
    __device__ void Run_asm(const FloatA* __restrict__ p_a_block,
                            const FloatB* __restrict__ p_b_block,
                            FloatC* __restrict__ p_c_thread) const
    {
        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};

        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t M = a_block_mtx.NCol();
        constexpr index_t N = b_block_mtx.NCol();
        constexpr index_t K = a_block_mtx.NRow(); // A is transposed

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        // thread A, B for GEMM
        //   A is transposed, B is not
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});

        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});

        // thread A-sub, B-sub for copy
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});

        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        // assertion for inline asm
        static_assert(is_same<FloatA, float>::value && is_same<FloatB, float>::value &&
                          is_same<FloatC, float>::value,
                      "Run_asm only deals with float\n");

        static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
                          MPerThread == 8 && NPerThread == 8,
                      "Run_asm cannot deal with this GEMM shape yet\n");

        static_assert(DataPerReadA == 4 && DataPerReadB == 4, "Run_asm only does float4 reads\n");

        static_assert(
            BlockMatrixStrideA == 0 && BatchPerThread == 1,
            "Run_asm can only deal with BlockMatrixStrideA == 0 && BatchPerThread == 1 for now\n");

        using Float4 = vector_type<float, 4>::MemoryType;

        Float4* reg_a = (Float4*)(p_a_thread);
        Float4* reg_b = (Float4*)(p_b_thread);
        Float4* reg_c = (Float4*)(p_c_thread);

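        // software-pipelined 8x8 outer product on float4 fragments: preload the k = 0
        // fragments of A and B, then in each k iteration interleave fetching the next
        // fragments with outerProduct4x4 on the previously loaded ones to hide load
        // latency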
        reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA]);
        reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB]);
        reg_b[1] =
            *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + NPerLevel1Cluster]);
        reg_a[1] =
            *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + MPerLevel1Cluster]);
        outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
        outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);

#pragma unroll
        for(index_t k = 1; k < K; ++k)
        {
            reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + k * M]);
            outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
            reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + k * N]);
            outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
            reg_b[1] = *reinterpret_cast<const Float4*>(
                &p_b_block[mMyThreadOffsetB + k * N + NPerLevel1Cluster]);
            reg_a[1] = *reinterpret_cast<const Float4*>(
                &p_a_block[mMyThreadOffsetA + k * M + MPerLevel1Cluster]);
            outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
            outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
        }
        outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
        outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
    }
#endif

    template <class BlockMatrixC, index_t BlockMatrixStrideC, class FloatC>
    __device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread,
                                                    FloatC* __restrict__ p_c_block) const
    {
        constexpr auto c_block_mtx  = BlockMatrixC{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        constexpr auto c_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        const auto c_thread_mtx_begin = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        const index_t c_thread_offset =
            c_thread_mtx_begin.batch * BlockMatrixStrideC +
            c_block_mtx.Get1dIndex(c_thread_mtx_begin.row, c_thread_mtx_begin.col);

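        // copy each MPerThreadSubC x NPerThreadSubC sub-tile of this thread's C back to
        // its strided position in the block matrix C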
        for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
        {
            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
            {
                threadwise_matrix_copy(
                    c_thread_sub_mtx,
                    p_c_thread +
                        c_thread_sub_mtx.Get1dIndex(m_repeat * MPerLevel1Cluster,
                                                    n_repeat * NPerLevel1Cluster),
                    c_block_mtx,
                    p_c_block +
                        c_block_mtx.Get1dIndex(m_repeat * MPerLevel1Cluster,
                                               n_repeat * NPerLevel1Cluster) +
                        c_thread_offset,
                    c_thread_sub_mtx.GetLengths());
            }
        }
    }
};