blockwise_gemm.hpp 18.3 KB
Newer Older
1
2
3
#ifndef CK_BLOCKWISE_GEMM_HPP
#define CK_BLOCKWISE_GEMM_HPP

Chao Liu's avatar
Chao Liu committed
4
5
#include "common.hpp"
#include "threadwise_gemm.hpp"
Chao Liu's avatar
Chao Liu committed
6

7
8
namespace ck {

Chao Liu's avatar
Chao Liu committed
9
10
// If the following parameters are powers of 2, the integer division/modulo used for index
// calculation becomes much cheaper:
//    MPerThreadSubC, NPerThreadSubC, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster
Chao Liu's avatar
Chao Liu committed
11
template <index_t BlockSize,
          class BlockMatrixA,
          class BlockMatrixB,
          class ThreadMatrixC,
          index_t MPerThreadSubC,
          index_t NPerThreadSubC,
          index_t MLevel0Cluster,
          index_t NLevel0Cluster,
          index_t MLevel1Cluster,
          index_t NLevel1Cluster,
          index_t KPerThreadLoop,
          index_t DataPerReadA,
          index_t DataPerReadB>
// Block-level GEMM where every thread of a BlockSize-thread block updates its own
// per-thread tile of C = transpose(A) * B:
//   A (BlockMatrixA) is K x M, stored transposed (NRow() == K, NCol() == M),
//   B (BlockMatrixB) is K x N,
//   C (ThreadMatrixC) is this thread's MPerThread x NPerThread result tile.
// Threads form a two-level cluster hierarchy:
//   level 0: MLevel0Cluster x NLevel0Cluster threads, each owning an
//            MPerThreadSubC x NPerThreadSubC sub-tile of C;
//   level 1: MLevel1Cluster x NLevel1Cluster level-0 clusters.
// When M (N) exceeds what one level-1 cluster covers, each thread owns MRepeat (NRepeat)
// sub-tiles spaced MPerLevel1Cluster (NPerLevel1Cluster) apart in the block C matrix.
// None of the Run* functions below issue barriers; the caller is responsible for making the
// block A/B data visible to all threads before calling them (presumably LDS - confirm at caller).
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
{
    // (row, col) coordinate inside the block-level C matrix.
    struct MatrixIndex
    {
        index_t row;
        index_t col;
    };

    // Offsets of this thread's first A column / first B column inside the block matrices,
    // computed once in the constructor from the thread's position in the cluster hierarchy.
    index_t mMyThreadOffsetA;
    index_t mMyThreadOffsetB;

    __device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2()
    {
        constexpr index_t ThreadPerLevel1Cluster =
            MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster;

        static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");

        static_assert(BlockMatrixA::NRow() == BlockMatrixB::NRow(),
                      "wrong! K dimension not consistent\n");

        constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
        constexpr index_t N = BlockMatrixB::NCol();

        static_assert(M % (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster) == 0 &&
                          N % (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster) == 0,
                      "wrong! Cannot evenly divide work among threads\n");

        static_assert(is_same_type(ThreadMatrixC::GetLengths(), GetThreadMatrixCLengths()),
                      "wrong! ThreadMatrixC lengths is wrong");

        const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        // A is K x M and B is K x N, so the thread's C row selects a column of A and its C
        // column selects a column of B; both start at K row 0.
        mMyThreadOffsetA = BlockMatrixA::GetOffsetFromMultiIndex(0, c_thread_mtx_index.row);
        mMyThreadOffsetB = BlockMatrixB::GetOffsetFromMultiIndex(0, c_thread_mtx_index.col);
    }

    // Lengths (rows, cols) that ThreadMatrixC must have: MRepeat * MPerThreadSubC by
    // NRepeat * NPerThreadSubC, where MRepeat/NRepeat are how many times the level-1 cluster
    // footprint tiles M/N.
    __device__ static constexpr auto GetThreadMatrixCLengths()
    {
        constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
        constexpr index_t N = BlockMatrixB::NCol();

        constexpr index_t MRepeat = M / (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster);
        constexpr index_t NRepeat = N / (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster);

        return Sequence<MRepeat * MPerThreadSubC, NRepeat * NPerThreadSubC>{};
    }

    // Top-left coordinate, in block C, of the first sub-tile owned by `thread_id`.
    // thread_id is split into a level-1 cluster id (which level-0 cluster) and a level-0 id
    // (position inside that cluster); both are laid out with the N index varying fastest.
    __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
    {
        constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster;

        index_t level1_id   = thread_id / ThreadPerLevel0Cluster;
        index_t level1_m_id = level1_id / NLevel1Cluster;
        index_t level1_n_id = level1_id % NLevel1Cluster;

        index_t level0_id   = thread_id % ThreadPerLevel0Cluster;
        index_t level0_m_id = level0_id / NLevel0Cluster;
        index_t level0_n_id = level0_id % NLevel0Cluster;

        constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster;
        constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster;

        return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
                           level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
    }

    // Maps a coordinate (m_in_c, n_in_c) inside this thread's packed C tile to its distance
    // from GetBeginOfThreadMatrixC() in block C. Consecutive sub-tiles of the thread tile are
    // MPerLevel1Cluster / NPerLevel1Cluster apart in block C.
    __device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c,
                                                                      index_t n_in_c)
    {
        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        // which sub-tile repeat the coordinate falls into
        index_t m_repeat = m_in_c / MPerThreadSubC;
        index_t n_repeat = n_in_c / NPerThreadSubC;

        // position inside that sub-tile
        index_t m_in_sub_c = m_in_c % MPerThreadSubC;
        index_t n_in_sub_c = n_in_c % NPerThreadSubC;

        return MatrixIndex{m_repeat * MPerLevel1Cluster + m_in_sub_c,
                           n_repeat * NPerLevel1Cluster + n_in_sub_c};
    }

#if CK_USE_AMD_INLINE_ASM
    // Hand-scheduled variant of Run() built on the inline-asm helper outerProduct4x4.
    // Restricted by the static_asserts below to float data, 4x4 sub-tiles, KPerThreadLoop == 1,
    // an 8x8 per-thread C tile and float4 reads.
    // TODO: this is not working correctly
    // NOTE(review): the per-k row advance now uses the block descriptors' row stride; it used
    // to assume row stride == NCol(), which is wrong for padded block matrices.
    template <class FloatA, class FloatB, class FloatC>
    __device__ void Run_asm(const FloatA* __restrict__ p_a_block,
                            const FloatB* __restrict__ p_b_block,
                            FloatC* __restrict__ p_c_thread) const
    {
        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t K = a_block_mtx.NRow();

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        // thread A, B for GEMM
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});

        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});

        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        // assertion for inline asm
        static_assert(is_same<FloatA, float>::value && is_same<FloatB, float>::value &&
                          is_same<FloatC, float>::value,
                      "Run_asm only deal with float\n");

        static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
                          MPerThread == 8 && NPerThread == 8,
                      "Run_asm cannot deal with this GEMM shape yet\n");

        static_assert(DataPerReadA == 4 && DataPerReadB == 4, "Run_asm only do float4 read\n");

        using Float4 = vector_type<float, 4>::MemoryType;

        // reg_a[0]/reg_a[1]: the two 4-wide A sub-tiles of the current k row;
        // reg_b[0]/reg_b[1]: the two 4-wide B sub-tiles;
        // reg_c[2 * r], reg_c[2 * r + 1]: the two float4 halves of C row r (8x8 tile,
        // assumed packed with row stride 8 - confirm ThreadMatrixC has no padding).
        Float4* reg_a = (Float4*)(p_a_thread);
        Float4* reg_b = (Float4*)(p_b_thread);
        Float4* reg_c = (Float4*)(p_c_thread);

        reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA]);
        reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB]);
        reg_b[1] =
            *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + NPerLevel1Cluster]);
        reg_a[1] =
            *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + MPerLevel1Cluster]);
        outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
        outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
#pragma unroll
        for(index_t k = 1; k < K; ++k)
        {
            // offsets of row k honour the block descriptors' (possibly padded) row stride
            const index_t a_row_offset = a_block_mtx.GetOffsetFromMultiIndex(k, 0);
            const index_t b_row_offset = b_block_mtx.GetOffsetFromMultiIndex(k, 0);

            // interleave loads for row k with the outer products still pending for row k - 1
            reg_a[0] =
                *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + a_row_offset]);
            outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
            reg_b[0] =
                *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + b_row_offset]);
            outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
            reg_b[1] = *reinterpret_cast<const Float4*>(
                &p_b_block[mMyThreadOffsetB + b_row_offset + NPerLevel1Cluster]);
            reg_a[1] = *reinterpret_cast<const Float4*>(
                &p_a_block[mMyThreadOffsetA + a_row_offset + MPerLevel1Cluster]);
            outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
            outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
        }
        outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
        outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
    }
#endif

    // For each slice of KPerThreadLoop rows of K:
    //   1. copy this thread's A and B sub-tiles (MRepeat / NRepeat of them, taken from the
    //      block matrices with a MPerLevel1Cluster / NPerLevel1Cluster stride) into
    //      thread-private arrays, packing them contiguously;
    //   2. call threadwise_gemm(A^T, B) on the same p_c_thread every slice, so threadwise_gemm
    //      is expected to accumulate into C.
    template <class FloatA, class FloatB, class FloatC>
    __device__ void Run(const FloatA* const __restrict__ p_a_block,
                        const FloatB* const __restrict__ p_b_block,
                        FloatC* const __restrict__ p_c_thread) const
    {
        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};

        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t K = a_block_mtx.NRow();

        // otherwise the last slice would read rows past K
        static_assert(K % KPerThreadLoop == 0, "wrong! K is not divisible by KPerThreadLoop\n");

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        // thread A, B for GEMM
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});

        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});

        // thread A-sub, B-sub for copy
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});

        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        // loop over k
#pragma unroll
        for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
        {
            // copy A-sub to form A
#pragma unroll
            for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
            {
                threadwise_matrix_copy(
                    a_block_mtx,
                    p_a_block +
                        a_block_mtx.GetOffsetFromMultiIndex(k_begin, m_repeat * MPerLevel1Cluster) +
                        mMyThreadOffsetA,
                    a_thread_mtx,
                    p_a_thread + a_thread_mtx.GetOffsetFromMultiIndex(0, m_repeat * MPerThreadSubC),
                    a_thread_sub_mtx.GetLengths(),
                    Number<DataPerReadA>{});
            }

            // copy B-sub to form B
#pragma unroll
            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
            {
                threadwise_matrix_copy(
                    b_block_mtx,
                    p_b_block +
                        b_block_mtx.GetOffsetFromMultiIndex(k_begin, n_repeat * NPerLevel1Cluster) +
                        mMyThreadOffsetB,
                    b_thread_mtx,
                    p_b_thread + b_thread_mtx.GetOffsetFromMultiIndex(0, n_repeat * NPerThreadSubC),
                    b_thread_sub_mtx.GetLengths(),
                    Number<DataPerReadB>{});
            }

            // C = A * B
            threadwise_gemm(a_thread_mtx,
                            True,
                            p_a_thread,
                            b_thread_mtx,
                            False,
                            p_b_thread,
                            c_thread_mtx,
                            False,
                            p_c_thread);
        }
    }

    // Same computation as Run(), but ping-pongs between two pairs of thread-private A/B
    // buffers: the copy of K slice k_begin + KPerThreadLoop is issued before the GEMM of
    // slice k_begin, and the final slice is multiplied after the loop.
    // p_a_block / p_b_block are only read (taken as pointer-to-const; non-const pointers
    // still bind).
    template <class FloatA, class FloatB, class FloatC>
    __device__ void Run_RegisterDoubleBuffer(const FloatA* const p_a_block,
                                             const FloatB* const p_b_block,
                                             FloatC* p_c_thread) const
    {
        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};

        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t K = a_block_mtx.NRow();

        // otherwise the prefetch of the last slice would read rows past K
        static_assert(K % KPerThreadLoop == 0, "wrong! K is not divisible by KPerThreadLoop\n");

        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();

        // thread A, B for GEMM
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});

        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});

        // thread A-sub, B-sub for copy
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});

        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

        // two thread-private buffer pairs for double buffering
        FloatA p_a_thread_0[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread_0[b_thread_mtx.GetElementSpace()];

        FloatA p_a_thread_1[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread_1[b_thread_mtx.GetElementSpace()];

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        // preload the first K slice of A, B into buffer 0
#pragma unroll
        for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
        { // copy A-sub to form A
            threadwise_matrix_copy(a_block_mtx,
                                   p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster,
                                   a_thread_sub_mtx,
                                   p_a_thread_0 + m_repeat * MPerThreadSubC,
                                   a_thread_sub_mtx.GetLengths(),
                                   Number<DataPerReadA>{});
        }

#pragma unroll
        for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
        { // copy B-sub to form B
            threadwise_matrix_copy(b_block_mtx,
                                   p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster,
                                   b_thread_sub_mtx,
                                   p_b_thread_0 + n_repeat * NPerThreadSubC,
                                   b_thread_sub_mtx.GetLengths(),
                                   Number<DataPerReadB>{});
        }

        // even_loop == true: compute from buffer 0 and prefetch into buffer 1; false: the reverse
        bool even_loop = true;

#pragma unroll
        for(index_t k_begin = 0; k_begin + KPerThreadLoop < K;
            k_begin += KPerThreadLoop, even_loop = !even_loop)
        { // loop over k
            FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
            FloatB* p_b_thread_now = even_loop ? p_b_thread_0 : p_b_thread_1;

            FloatA* p_a_thread_next = even_loop ? p_a_thread_1 : p_a_thread_0;
            FloatB* p_b_thread_next = even_loop ? p_b_thread_1 : p_b_thread_0;

            // first K row of the next slice (the loop advances by KPerThreadLoop, not by 1)
            const index_t k_next = k_begin + KPerThreadLoop;

            // preload next A, B
#pragma unroll
            for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
            { // copy A-sub to form A
                threadwise_matrix_copy(a_block_mtx,
                                       p_a_block + mMyThreadOffsetA +
                                           k_next * a_block_mtx.RowStride() +
                                           m_repeat * MPerLevel1Cluster,
                                       a_thread_sub_mtx,
                                       p_a_thread_next + m_repeat * MPerThreadSubC,
                                       a_thread_sub_mtx.GetLengths(),
                                       Number<DataPerReadA>{});
            }

#pragma unroll
            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
            { // copy B-sub to form B
                threadwise_matrix_copy(b_block_mtx,
                                       p_b_block + mMyThreadOffsetB +
                                           k_next * b_block_mtx.RowStride() +
                                           n_repeat * NPerLevel1Cluster,
                                       b_thread_sub_mtx,
                                       p_b_thread_next + n_repeat * NPerThreadSubC,
                                       b_thread_sub_mtx.GetLengths(),
                                       Number<DataPerReadB>{});
            }

            // C = A * B
            threadwise_gemm(a_thread_mtx,
                            True,
                            p_a_thread_now,
                            b_thread_mtx,
                            False,
                            p_b_thread_now,
                            c_thread_mtx,
                            False,
                            p_c_thread);
        }

        // last loop: the final slice has already been prefetched, only the GEMM remains
        {
            FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
            FloatB* p_b_thread_now = even_loop ? p_b_thread_0 : p_b_thread_1;

            // C = A * B
            threadwise_gemm(a_thread_mtx,
                            True,
                            p_a_thread_now,
                            b_thread_mtx,
                            False,
                            p_b_thread_now,
                            c_thread_mtx,
                            False,
                            p_c_thread);
        }
    }
};
430
431
432

} // namespace ck
#endif