blockwise_gemm.hip.hpp 18.2 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
#pragma once
2
#include "threadwise_gemm.hip.hpp"
Chao Liu's avatar
Chao Liu committed
3

Chao Liu's avatar
Chao Liu committed
4
5
// If the following numbers are powers of 2, the index calculation is greatly simplified:
//    MPerThreadSubC, NPerThreadSubC, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster
Chao Liu's avatar
Chao Liu committed
6
template <index_t BlockSize,
Chao Liu's avatar
Chao Liu committed
7
8
9
          class BlockMatrixA,
          class BlockMatrixB,
          class ThreadMatrixC,
Chao Liu's avatar
Chao Liu committed
10
11
12
13
14
15
16
          index_t MPerThreadSubC,
          index_t NPerThreadSubC,
          index_t MLevel0Cluster,
          index_t NLevel0Cluster,
          index_t MLevel1Cluster,
          index_t NLevel1Cluster,
          index_t KPerThreadLoop>
Chao Liu's avatar
Chao Liu committed
17
18
19
20
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
{
    // Plain (row, col) coordinate pair returned by the index helpers of this struct.
    struct MatrixIndex
    {
        index_t row; // coordinate along the M direction of C
        index_t col; // coordinate along the N direction of C
    };

Chao Liu's avatar
Chao Liu committed
25
26
    // 1d offset into the A block matrix of (row 0, this thread's first C row);
    // set once in the constructor via a_block_mtx.Get1dIndex.
    index_t mMyThreadOffsetA;
    // 1d offset into the B block matrix of (row 0, this thread's first C column);
    // set once in the constructor via b_block_mtx.Get1dIndex.
    index_t mMyThreadOffsetB;
Chao Liu's avatar
Chao Liu committed
27
28
29

    // Checks the tiling configuration at compile time, then records where this
    // thread starts reading in the A and B block matrices.
    __device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2()
    {
        // the two-level cluster must account for every thread of the block
        static_assert(BlockSize ==
                          MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster,
                      "wrong! wrong blocksize\n");

        constexpr auto a_desc = BlockMatrixA{};
        constexpr auto b_desc = BlockMatrixB{};
        constexpr auto c_desc = ThreadMatrixC{};

        // A is transposed: its rows run along K and its columns along M
        constexpr index_t k_len = a_desc.NRow();
        constexpr index_t m_len = a_desc.NCol();
        constexpr index_t n_len = b_desc.NCol();

        static_assert(k_len == b_desc.NRow(), "wrong! K dimension not consistent\n");

        constexpr index_t m_per_thread = c_desc.NRow();
        constexpr index_t n_per_thread = c_desc.NCol();

        // the per-thread C tile is built from whole sub-tiles
        static_assert(m_per_thread % MPerThreadSubC == 0,
                      "wrong! Cannot evenly divide thread work among repeat \n");
        static_assert(n_per_thread % NPerThreadSubC == 0,
                      "wrong! Cannot evenly divide thread work among repeat \n");

        constexpr index_t m_repeat_count = m_per_thread / MPerThreadSubC;
        constexpr index_t n_repeat_count = n_per_thread / NPerThreadSubC;

        static_assert(m_len % m_repeat_count == 0,
                      "wrong! Cannot evenly divide work among repeat\n");
        static_assert(n_len % n_repeat_count == 0,
                      "wrong! Cannot evenly divide work among repeat\n");

        // extent handled by all level-1 clusters together during one repeat
        constexpr index_t m_per_repeat = m_len / m_repeat_count;
        constexpr index_t n_per_repeat = n_len / n_repeat_count;

        static_assert(m_per_repeat % MLevel1Cluster == 0,
                      "wrong! Cannot evenly divide work among Level1Cluster\n");
        static_assert(n_per_repeat % NLevel1Cluster == 0,
                      "wrong! Cannot evenly divide work among Level1Cluster\n");

        // extent handled by a single level-0 cluster
        constexpr index_t m_per_level0 = m_per_repeat / MLevel1Cluster;
        constexpr index_t n_per_level0 = n_per_repeat / NLevel1Cluster;

        static_assert(m_per_level0 % MLevel0Cluster == 0,
                      "wrong! Cannot evenly divide work among Level0Cluster\n");
        static_assert(n_per_level0 % NLevel0Cluster == 0,
                      "wrong! Cannot evenly divide work among Level0Cluster\n");

        // what remains per thread must be exactly one sub-tile
        static_assert(MPerThreadSubC == m_per_level0 / MLevel0Cluster,
                      "wrong! thread work size is wrong\n");
        static_assert(NPerThreadSubC == n_per_level0 / NLevel0Cluster,
                      "wrong! thread work size is wrong\n");

        // C row selects the column of (transposed) A, C column selects the column of B
        const auto c_origin = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        mMyThreadOffsetA = a_desc.Get1dIndex(0, c_origin.row);
        mMyThreadOffsetB = b_desc.Get1dIndex(0, c_origin.col);
    }

Chao Liu's avatar
Chao Liu committed
82
    // Maps a thread id within the block to the (row, col) of the first C element
    // owned by that thread. Thread ids are split into a level-1 cluster id (outer)
    // and a level-0 id inside that cluster (inner); both are laid out row-major
    // over (M, N).
    __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
    {
        constexpr index_t level0_threads = MLevel0Cluster * NLevel0Cluster;

        // extent of C covered by one level-0 cluster
        constexpr index_t m_span = MPerThreadSubC * MLevel0Cluster;
        constexpr index_t n_span = NPerThreadSubC * NLevel0Cluster;

        const index_t outer = thread_id / level0_threads;
        const index_t inner = thread_id % level0_threads;

        MatrixIndex origin;
        origin.row = (outer / NLevel1Cluster) * m_span + (inner / NLevel0Cluster) * MPerThreadSubC;
        origin.col = (outer % NLevel1Cluster) * n_span + (inner % NLevel0Cluster) * NPerThreadSubC;
        return origin;
    }

    // this should be optimized away if input is known
Chao Liu's avatar
Chao Liu committed
102
103
    __device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c,
                                                                      index_t n_in_c)
Chao Liu's avatar
Chao Liu committed
104
    {
Chao Liu's avatar
Chao Liu committed
105
        constexpr auto c_thread_mtx = ThreadMatrixC{};
Chao Liu's avatar
Chao Liu committed
106

Chao Liu's avatar
Chao Liu committed
107
108
        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();
Chao Liu's avatar
Chao Liu committed
109

Chao Liu's avatar
Chao Liu committed
110
111
        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
Chao Liu's avatar
Chao Liu committed
112

Chao Liu's avatar
Chao Liu committed
113
114
        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
Chao Liu's avatar
Chao Liu committed
115

Chao Liu's avatar
Chao Liu committed
116
117
        index_t m_repeat = m_in_c / MPerThreadSubC;
        index_t n_repeat = n_in_c / NPerThreadSubC;
Chao Liu's avatar
Chao Liu committed
118

Chao Liu's avatar
Chao Liu committed
119
120
        index_t m_in_sub_c = m_in_c % MPerThreadSubC;
        index_t n_in_sub_c = n_in_c % NPerThreadSubC;
Chao Liu's avatar
Chao Liu committed
121
122
123
124
125

        return MatrixIndex{m_repeat * MPerLevel1Cluster + m_in_sub_c,
                           n_repeat * NPerLevel1Cluster + n_in_sub_c};
    }

126
127
#if DEVICE_BACKEND_HIP
    template <class FloatA, class FloatB, class FloatC>
Chao Liu's avatar
Chao Liu committed
128
129
    __device__ void Run_asm(const FloatA* __restrict__ p_a_block,
                            const FloatB* __restrict__ p_b_block,
130
                            FloatC* __restrict__ p_c_thread) const
Chao Liu's avatar
Chao Liu committed
131
    {
Chao Liu's avatar
Chao Liu committed
132
133
        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};
Chao Liu's avatar
Chao Liu committed
134

Chao Liu's avatar
Chao Liu committed
135
136
137
        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};
Chao Liu's avatar
Chao Liu committed
138

Chao Liu's avatar
Chao Liu committed
139
140
141
        constexpr index_t M = a_block_mtx.NCol();
        constexpr index_t N = b_block_mtx.NCol();
        constexpr index_t K = a_block_mtx.NRow();
Chao Liu's avatar
Chao Liu committed
142

Chao Liu's avatar
Chao Liu committed
143
144
        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();
Chao Liu's avatar
Chao Liu committed
145
146

        // thread A, B for GEMM
Chao Liu's avatar
Chao Liu committed
147
148
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
Chao Liu's avatar
Chao Liu committed
149

Chao Liu's avatar
Chao Liu committed
150
151
        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
Chao Liu's avatar
Chao Liu committed
152
153

        // thread A-sub, B-sub for copy
Chao Liu's avatar
Chao Liu committed
154
155
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
Chao Liu's avatar
Chao Liu committed
156

Chao Liu's avatar
Chao Liu committed
157
158
        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
Chao Liu's avatar
Chao Liu committed
159

Chao Liu's avatar
Chao Liu committed
160
        float p_thread[a_thread_mtx.GetElementSpace() + b_thread_mtx.GetElementSpace()];
Jing Zhang's avatar
Jing Zhang committed
161

Chao Liu's avatar
Chao Liu committed
162
163
        FloatA* p_a_thread = p_thread;
        FloatB* p_b_thread = p_thread + a_thread_mtx.GetElementSpace();
Chao Liu's avatar
Chao Liu committed
164

Chao Liu's avatar
Chao Liu committed
165
166
        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
Chao Liu's avatar
Chao Liu committed
167

Chao Liu's avatar
Chao Liu committed
168
169
        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
Chao Liu's avatar
Chao Liu committed
170

Jing Zhang's avatar
Jing Zhang committed
171
172
173
        Float4* reg_a = (Float4*)(p_a_thread);
        Float4* reg_b = (Float4*)(p_b_thread);
        Float4* reg_c = (Float4*)(p_c_thread);
Chao Liu's avatar
Chao Liu committed
174
175
        void* a_loc   = (void*)(p_a_block + mMyThreadOffsetA);
        void* b_loc   = (void*)(p_b_block + mMyThreadOffsetB);
Jing Zhang's avatar
Jing Zhang committed
176

Jing Zhang's avatar
Jing Zhang committed
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
        int lds_a_block_off   = sizeof(Float) * M;
        int lds_b_block_off   = sizeof(Float) * N;
        int lds_a_block_off_1 = MPerLevel1Cluster * sizeof(Float);
        int lds_b_block_off_1 = NPerLevel1Cluster * sizeof(Float);
        ds_read_b128(reg_a[0], a_loc, 0);
        ds_read_b128(reg_b[0], b_loc, 0);
        ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1);
        ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1);
        lgkmcnt(2);
        outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
        lgkmcnt(1);
        outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
        lgkmcnt(0);
#pragma unroll
        for(int k_i = 1; k_i < K; k_i++)
        {
            ds_read_b128(reg_a[0], a_loc, k_i * lds_a_block_off);
Jing Zhang's avatar
Jing Zhang committed
194
            outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
Jing Zhang's avatar
Jing Zhang committed
195
            ds_read_b128(reg_b[0], b_loc, k_i * lds_b_block_off);
Jing Zhang's avatar
Jing Zhang committed
196
            outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
Jing Zhang's avatar
Jing Zhang committed
197
198
            ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1 + k_i * lds_b_block_off);
            ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1 + k_i * lds_a_block_off);
Jing Zhang's avatar
Jing Zhang committed
199
200
201
202
203
            lgkmcnt(2);
            outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
            lgkmcnt(1);
            outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
            lgkmcnt(0);
Chao Liu's avatar
Chao Liu committed
204
        }
Jing Zhang's avatar
Jing Zhang committed
205
206
        outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
        outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
Chao Liu's avatar
Chao Liu committed
207
    }
208
#endif
209

210
    template <class FloatA, class FloatB, class FloatC>
Chao Liu's avatar
Chao Liu committed
211
212
    __device__ void Run(const FloatA* const __restrict__ p_a_block,
                        const FloatB* const __restrict__ p_b_block,
213
                        FloatC* const __restrict__ p_c_thread) const
Chao Liu's avatar
Chao Liu committed
214
215
216
217
218
219
220
221
    {
        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};

        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

Chao Liu's avatar
Chao Liu committed
222
223
224
        constexpr index_t M = a_block_mtx.NCol();
        constexpr index_t N = b_block_mtx.NCol();
        constexpr index_t K = a_block_mtx.NRow();
Chao Liu's avatar
Chao Liu committed
225

Chao Liu's avatar
Chao Liu committed
226
227
        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();
Chao Liu's avatar
Chao Liu committed
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245

        // thread A, B for GEMM
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});

        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});

        // thread A-sub, B-sub for copy
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});

        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});

        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

Chao Liu's avatar
Chao Liu committed
246
247
        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
Chao Liu's avatar
Chao Liu committed
248

Chao Liu's avatar
Chao Liu committed
249
250
        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
Chao Liu's avatar
Chao Liu committed
251

Chao Liu's avatar
Chao Liu committed
252
253
        const FloatA* const p_a_block_thread_offset = p_a_block + mMyThreadOffsetA;

Chao Liu's avatar
Chao Liu committed
254
255
#pragma unroll
        // loop over k
Chao Liu's avatar
Chao Liu committed
256
        for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
Chao Liu's avatar
Chao Liu committed
257
        {
Chao Liu's avatar
Chao Liu committed
258
#pragma unroll
Chao Liu's avatar
Chao Liu committed
259
            // copy A-sub to form A
Chao Liu's avatar
Chao Liu committed
260
            for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
Chao Liu's avatar
Chao Liu committed
261
262
263
264
265
266
            {
                threadwise_matrix_copy(
                    a_block_mtx,
                    p_a_block + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
                        mMyThreadOffsetA,
                    a_thread_mtx,
Chao Liu's avatar
Chao Liu committed
267
                    p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
Chao Liu's avatar
Chao Liu committed
268
269
                    a_thread_sub_mtx.GetLengths());
            }
Chao Liu's avatar
Chao Liu committed
270

Chao Liu's avatar
Chao Liu committed
271
#pragma unroll
Chao Liu's avatar
Chao Liu committed
272
            // copy B-sub to form B
Chao Liu's avatar
Chao Liu committed
273
            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
Chao Liu's avatar
Chao Liu committed
274
275
276
277
278
279
280
281
282
283
            {
                threadwise_matrix_copy(
                    b_block_mtx,
                    p_b_block + b_block_mtx.Get1dIndex(k_begin, n_repeat * NPerLevel1Cluster) +
                        mMyThreadOffsetB,
                    b_thread_mtx,
                    p_b_thread + b_thread_mtx.Get1dIndex(0, n_repeat * NPerThreadSubC),
                    b_thread_sub_mtx.GetLengths());
            }

Chao Liu's avatar
Chao Liu committed
284
            // C = A * B
Chao Liu's avatar
Chao Liu committed
285
286
287
288
289
290
291
292
            threadwise_gemm(a_thread_mtx,
                            True,
                            p_a_thread,
                            b_thread_mtx,
                            False,
                            p_b_thread,
                            c_thread_mtx,
                            False,
293
                            p_c_thread);
Chao Liu's avatar
Chao Liu committed
294
295
296
        }
    }

297
    template <class FloatA, class FloatB, class FloatC>
298
299
    __device__ void Run_RegisterDoubleBuffer(FloatA* const p_a_block,
                                             FloatB* const p_b_block,
300
                                             FloatC* p_c_thread) const
301
    {
Chao Liu's avatar
Chao Liu committed
302
303
        constexpr auto True  = integral_constant<bool, true>{};
        constexpr auto False = integral_constant<bool, false>{};
304

Chao Liu's avatar
Chao Liu committed
305
306
307
        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};
308

Chao Liu's avatar
Chao Liu committed
309
310
311
        constexpr index_t M = a_block_mtx.NCol();
        constexpr index_t N = b_block_mtx.NCol();
        constexpr index_t K = a_block_mtx.NRow();
312

Chao Liu's avatar
Chao Liu committed
313
314
        constexpr index_t MPerThread = c_thread_mtx.NRow();
        constexpr index_t NPerThread = c_thread_mtx.NCol();
315
316

        // thread A, B for GEMM
Chao Liu's avatar
Chao Liu committed
317
318
        constexpr auto a_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
319

Chao Liu's avatar
Chao Liu committed
320
321
        constexpr auto b_thread_mtx =
            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
322
323

        // thread A-sub, B-sub for copy
Chao Liu's avatar
Chao Liu committed
324
325
        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
326

Chao Liu's avatar
Chao Liu committed
327
328
        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
329

330
        // register
331
332
333
334
335
336
        FloatA p_a_thread_0[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread_0[b_thread_mtx.GetElementSpace()];

        FloatA p_a_thread_1[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread_1[b_thread_mtx.GetElementSpace()];

Chao Liu's avatar
Chao Liu committed
337
338
        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
339

Chao Liu's avatar
Chao Liu committed
340
341
        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
342

Chao Liu's avatar
Chao Liu committed
343
// preload A, B
344
#pragma unroll
Chao Liu's avatar
Chao Liu committed
345
        for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
346
347
348
349
350
351
352
353
354
        { // copy A-sub to form A
            threadwise_matrix_copy(a_block_mtx,
                                   p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster,
                                   a_thread_sub_mtx,
                                   p_a_thread_0 + m_repeat * MPerThreadSubC,
                                   a_thread_sub_mtx.GetLengths());
        }

#pragma unroll
Chao Liu's avatar
Chao Liu committed
355
        for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
356
357
358
359
360
361
362
363
364
365
366
        { // copy B-sub to form B
            threadwise_matrix_copy(b_block_mtx,
                                   p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster,
                                   b_thread_sub_mtx,
                                   p_b_thread_0 + n_repeat * NPerThreadSubC,
                                   b_thread_sub_mtx.GetLengths());
        }

        bool even_loop = true;

#pragma unroll
Chao Liu's avatar
Chao Liu committed
367
        for(index_t k_begin = 0; k_begin + KPerThreadLoop < K;
368
369
370
371
372
373
374
375
            k_begin += KPerThreadLoop, even_loop = !even_loop)
        { // loop over k
            FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
            FloatB* p_b_thread_now = even_loop ? p_b_thread_0 : p_b_thread_1;

            FloatA* p_a_thread_next = even_loop ? p_a_thread_1 : p_a_thread_0;
            FloatB* p_b_thread_next = even_loop ? p_b_thread_1 : p_b_thread_0;

Chao Liu's avatar
Chao Liu committed
376
// preload next A, B
377
#pragma unroll
Chao Liu's avatar
Chao Liu committed
378
            for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
379
380
381
382
383
384
385
386
387
388
389
            { // copy A-sub to form A
                threadwise_matrix_copy(a_block_mtx,
                                       p_a_block + mMyThreadOffsetA +
                                           (k_begin + 1) * a_block_mtx.RowStride() +
                                           m_repeat * MPerLevel1Cluster,
                                       a_thread_sub_mtx,
                                       p_a_thread_next + m_repeat * MPerThreadSubC,
                                       a_thread_sub_mtx.GetLengths());
            }

#pragma unroll
Chao Liu's avatar
Chao Liu committed
390
            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
            { // copy B-sub to form B
                threadwise_matrix_copy(b_block_mtx,
                                       p_b_block + mMyThreadOffsetB +
                                           (k_begin + 1) * b_block_mtx.RowStride() +
                                           n_repeat * NPerLevel1Cluster,
                                       b_thread_sub_mtx,
                                       p_b_thread_next + n_repeat * NPerThreadSubC,
                                       b_thread_sub_mtx.GetLengths());
            }

            // C = A * B
            threadwise_gemm(a_thread_mtx,
                            True,
                            p_a_thread_now,
                            b_thread_mtx,
                            False,
                            p_b_thread_now,
                            c_thread_mtx,
                            False,
410
                            p_c_thread);
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
        }

        // last loop
        {
            FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1;
            FloatB* p_b_thread_now = even_loop ? p_b_thread_0 : p_b_thread_1;

            // C = A * B
            threadwise_gemm(a_thread_mtx,
                            True,
                            p_a_thread_now,
                            b_thread_mtx,
                            False,
                            p_b_thread_now,
                            c_thread_mtx,
                            False,
427
                            p_c_thread);
Chao Liu's avatar
Chao Liu committed
428
429
        }
    }
Chao Liu's avatar
Chao Liu committed
430
};