gemm_base.cuh 32.8 KB
Newer Older
muyangli's avatar
muyangli committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
#pragma once

#include "common.h"

#include "../utils.cuh"
#include "../dispatch_utils.h"
#include "gemm_utils.cuh"


#pragma nv_diag_suppress 177

// Force-inline attribute; in this file it is applied to lambda declarators.
// MSVC host compilers get the [[msvc::forceinline]] spelling, every other compiler the GNU attribute.
#ifdef _MSC_VER
#define ALWAYSINLINE [[msvc::forceinline]]
#else
#define ALWAYSINLINE __attribute__((always_inline))
#endif

// Optional NaN diagnostics. With ENABLE_NAN_CHECK set to a non-zero value, CHECK_NAN(data, name)
// expands to checkNan(data, "<name> at <line>") (checkNan is declared outside this file);
// otherwise it expands to nothing.
// #define ENABLE_NAN_CHECK 1
#if ENABLE_NAN_CHECK
// two-level expansion so that __LINE__ is expanded before being stringized
#define STRINGIZE(x) STRINGIZE2(x)
#define STRINGIZE2(x) #x
#define CHECK_NAN(data, name) checkNan(data, name " at " STRINGIZE(__LINE__))
#else
#define CHECK_NAN(data, name)
#endif

namespace nunchaku::kernels {

/**
 * Compile-time tiling configuration, consumed through GEMMBase<Config>.
 * (Presumably for the 4-bit weight / 4-bit activation kernels, judging by the name - confirm.)
 *
 * @tparam bf16  true selects __nv_bfloat16 / __nv_bfloat162 as half_t / half2_t,
 *               false selects half / half2.
 */
template<bool bf16>
class GEMMConfig_W4A4 {
public:
    // BE CAREFUL: weights need to be repacked when the tiling size changes

    // tile handled by one thread block, and the warp layout inside the block
    static constexpr int BLOCK_M = 256;
    static constexpr int BLOCK_N = 128;
    static constexpr int WARP_SIZE = 32;
    static constexpr int NUM_WARPS = 8;

    // shape of one instruction tile; GEMMBase derives its per-warp tile counts from these
    static constexpr int INSN_M = 16;
    static constexpr int INSN_N = 16;
    static constexpr int INSN_K = 64;

    // 16-bit floating point type (scalar and 2-wide vector) used for scales and outputs
    using half_t  = typename std::conditional_t<bf16, __nv_bfloat16,  half>;
    using half2_t = typename std::conditional_t<bf16, __nv_bfloat162, half2>;
};

// Convenience aliases for the two half_t choices.
using GEMMConfig_W4A4_FP16 = GEMMConfig_W4A4<false>;
using GEMMConfig_W4A4_BF16 = GEMMConfig_W4A4<true>;


/**
 * Compile-time tiling configuration, consumed through GEMMBase<Config>.
 * Same block / warp shape as GEMMConfig_W4A4, but INSN_K is 32 and half_t is fixed at compile time.
 * (Presumably for the 8-bit weight / 8-bit activation kernels, judging by the name - confirm.)
 */
class GEMMConfig_W8A8 {
public:
    // tile handled by one thread block, and the warp layout inside the block
    static constexpr int BLOCK_M = 256;
    static constexpr int BLOCK_N = 128;
    static constexpr int WARP_SIZE = 32;
    static constexpr int NUM_WARPS = 8;

    // shape of one instruction tile
    static constexpr int INSN_M = 16;
    static constexpr int INSN_N = 16;
    static constexpr int INSN_K = 32;

    // The fp16 branch is compiled out; change the condition to switch half_t to half / half2.
#if 0
    using half_t  = half;
    using half2_t = half2;
#else
    using half_t  = __nv_bfloat16;
    using half2_t = __nv_bfloat162;
#endif
};

template<class Config>
class GEMMBase : public Config {
public:
    using Config::BLOCK_M;
    using Config::BLOCK_N;
    using Config::WARP_SIZE;
    using Config::NUM_WARPS;
    using Config::INSN_M;
    using Config::INSN_N;
    using Config::INSN_K;

    using typename Config::half_t;
    using typename Config::half2_t;

    // Warp tile: the block's rows are split evenly across the warps, every warp covers all
    // BLOCK_N columns, and one step along K spans a single instruction (INSN_K).
    static constexpr int WARP_M = BLOCK_M / NUM_WARPS;
    static constexpr int WARP_N = BLOCK_N;
    static constexpr int WARP_K = INSN_K;

    // Number of instruction tiles along each dimension of the warp tile.
    static constexpr int WARP_M_TILES = WARP_M / INSN_M;
    static constexpr int WARP_N_TILES = WARP_N / INSN_N;
    static constexpr int WARP_K_TILES = WARP_K / INSN_K;

    /**
     * refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-16864-c
     * 
     * wscales store order: (pack = 4)
     *  0   1   8   9   <-- load by lane 0, broadcast to lane {0, 4, 8, ..., 28} (8x)
     *  2   3   10  11  <-- load by lane 1, broadcast to lane {1, 5, 9, ..., 29} (8x)
     *  4   5   12  13  <-- load by lane 2, broadcast to lane {2, 6, 10, ..., 30} (8x)
     *  6   7   14  15  <-- load by lane 3, broadcast to lane {3, 7, 11, ..., 31} (8x)
     * 
     *  16  17  24  25  <-- load by lane 4, broadcast to lane {0, 4, 8, ..., 28} (8x)
     *  ...
     *  22  23  30  31  <-- load by lane 7, broadcast to lane {3, 7, 11, ..., 31} (8x)
     *  ... ...
     *  112 113 120 121 <-- load by lane 28, broadcast to lane {0, 4, 8, ..., 28} (8x)
     *  ...
     *  118 119 126 127 <-- load by lane 31, broadcast to lane {3, 7, 11, ..., 31} (8x)
     *  
     * wscales store order: (pack = 8)
     *  0   1   8   9   16  17  24  25  <-- load by lane 0, broadcast to lane {0, 4, 8, ..., 28} (8x)
     *  2   3   10  11  18  19  26  27  <-- load by lane 1, broadcast to lane {1, 5, 9, ..., 29} (8x)
     *  4   5   12  13  20  21  28  29  <-- load by lane 2, broadcast to lane {2, 6, 10, ..., 30} (8x)
     *  6   7   14  15  22  23  30  31  <-- load by lane 3, broadcast to lane {3, 7, 11, ..., 31} (8x)
     * 
     *  224 225 232 233 240 241 248 249 <-- load by lane 28, broadcast to lane {0, 4, 8, ..., 28} (8x)
     *  ...
     *  230 231 238 239 246 247 254 255 <-- load by lane 31, broadcast to lane {3, 7, 11, ..., 31} (8x)
     * 
     * {k}-th wscale used by lane {i} => {k // (WSCALES_PACK_SIZE * WARP_SIZE)}-th pack, in lane {4*(k // WSCALES_PACK_SIZE) + i % 4}, element {k % WSCALES_PACK_SIZE}
     * 
     * max pack size set to 8 since max load size is 16 bytes / lane
     * min pack size set to 2 since shuffle granularity is 32b 2*half
     * */ 
    // scales held by one lane in one pack: WARP_N / WARP_SIZE, clamped to 4..16 bytes worth of half
    static constexpr int WSCALES_PACK_SIZE = clamp(WARP_N / WARP_SIZE, 4 / sizeof(half), 16 / sizeof(half));
    // packs needed so that a full warp of lanes covers all WARP_N scales
    static constexpr int WSCALES_NUM_PACKS = ceilDiv(WARP_N, (WSCALES_PACK_SIZE * WARP_SIZE));
    // lanes that hold scale data within one pack (never more than a full warp)
    static constexpr int WSCALES_VALID_LANES = std::min(WARP_SIZE, WARP_N / WSCALES_PACK_SIZE);

    /**
     * ascales store order: (pack = 2)
     *  0   8   <-- load by lane 0, broadcast to lane {0, 1, 2, 3} (4x)
     *  1   9   <-- load by lane 1, broadcast to lane {4, 5, 6, 7} (4x)
     *  2   10
     *  ...
     *  6   14
     *  7   15  <-- load by lane 7, broadcast to lane {28, 29, 30, 31} (4x)
     *  ... ...
     *  48  56  <-- load by lane 24, broadcast to lane {0, 1, 2, 3} (4x)
     *  49  57
     *  ...
     *  54  62
     *  55  63  <-- load by lane 31, broadcast to lane {28, 29, 30, 31}  (4x)
     * 
     * {k}-th ascale used by lane {i} => {k // (ASCALES_PACK_SIZE * WARP_SIZE)}-th pack, in lane {8*(k // ASCALES_PACK_SIZE) + i // 4}, element {k % ASCALES_PACK_SIZE}
     */
    // scales held by one lane in one pack: WARP_M / WARP_SIZE, clamped to 4..16 bytes worth of half
    static constexpr int ASCALES_PACK_SIZE = clamp(WARP_M / WARP_SIZE, 4 / sizeof(half), 16 / sizeof(half));
    // packs needed so that a full warp of lanes covers all WARP_M scales
    static constexpr int ASCALES_NUM_PACKS = ceilDiv(WARP_M, (ASCALES_PACK_SIZE * WARP_SIZE));
    // lanes that hold scale data within one pack (never more than a full warp)
    static constexpr int ASCALES_VALID_LANES = std::min(WARP_SIZE, WARP_M / ASCALES_PACK_SIZE);

    // Per-lane fragments of one instruction tile.
    // The 128-bit activation / weight packs are opaque in this file; their bit layout is decided
    // by the code that packs the tensors.
    using packed_act_t = uint4;
    using packed_wgt_t = uint4;
    
    // 8 x int32 per lane; apply_scales() reads them as integer accumulators
    struct alignas(32) packed_psum_t {
        int data[8];
    };
    // 4 x half2_t (8 values) per lane; data[0]/data[1] cover the first 8 columns of the tile
    // (rows r and r + 8), data[2]/data[3] the next 8 columns - see unpack_fpsum
    struct alignas(16) packed_fpsum_t {
        half2_t data[4];
    };
    // 4 x half_t per lane (not used in the visible part of this file)
    struct alignas(8) packed_gated_fpsum_t {
        half_t data[4];
    };
    // 16 * 16 matrix: 8 fp32 values per lane across the 32 lanes of a warp
    struct alignas(32) packed_f32psum_t {
        float data[8];

        // Returns a fragment whose eight entries are all 0.0f.
        static constexpr packed_f32psum_t zeros() {
            // value-initialization zero-fills the whole array
            packed_f32psum_t cleared{};
            return cleared;
        }
    };

    // One lane's pack of weight scales: WSCALES_PACK_SIZE values stored as half2_t pairs
    struct packed_wscale_t {
        half2_t data[WSCALES_PACK_SIZE / 2];
    };
    // One lane's pack of activation scales: ASCALES_PACK_SIZE values stored as half2_t pairs
    struct packed_ascale_t {
        half2_t data[ASCALES_PACK_SIZE / 2];
    };

    // Per-lane fragments of a whole warp tile:
    //   act_warp / wgt_warp        one 128-bit pack per M-tile / per N-tile
    //   ascale_warp / wscale_warp  one scale pack per ASCALES_NUM_PACKS / WSCALES_NUM_PACKS
    //   *psum_warp                 one accumulator fragment per (M-tile i, N-tile j),
    //                              stored at index [i * WARP_N_TILES + j]
    using act_warp = std::array<packed_act_t, WARP_M_TILES>;
    using wgt_warp = std::array<packed_wgt_t, WARP_N_TILES>;
    using ascale_warp = std::array<packed_ascale_t, ASCALES_NUM_PACKS>;
    using wscale_warp = std::array<packed_wscale_t, WSCALES_NUM_PACKS>;
    using fpsum_warp = std::array<packed_fpsum_t, WARP_M_TILES * WARP_N_TILES>;
    using f32psum_warp = std::array<packed_f32psum_t, WARP_M_TILES * WARP_N_TILES>;
    using gated_fpsum_warp = std::array<packed_gated_fpsum_t, WARP_M_TILES * WARP_N_TILES>;


    // Tile coordinates handed to the epilogues.
    struct BlockInfo {
        // block index along M; EpilogueDefault turns it into a row offset via bm * BLOCK_M
        int bm;
        // block index along N; used as bn * BLOCK_N for columns and to select the bias slice
        int bn;
        // presumably the number of blocks along M / N - not read in the visible part of this
        // file, confirm against the callers
        int numBlocksM;
        int numBlocksN;
    };

    /**
     * Warp-wide 16x16x16 matrix multiply-accumulate with fp32 accumulation: returns psum + a * b.
     *
     * Issued as two m16n8k16 `mma.sync` PTX instructions that share the whole A fragment:
     *   - b.data[0..1] accumulate into psum.data[0..3]
     *   - b.data[2..3] accumulate into psum.data[4..7]
     * The f16 or bf16 variant of the instruction is selected at compile time from half_t.
     *
     * @param a     A-operand fragment (row-major operand of the instruction), 4 x 32-bit registers
     * @param b     B-operand fragment (column-major operand), 2 x 32-bit registers per instruction
     * @param psum  accumulator fragment; taken by value, updated and returned
     * @return      the updated accumulator fragment
     */
    __device__ __forceinline__
    static packed_f32psum_t mma_f16xf16_f32(packed_fpsum_t a, packed_fpsum_t b, packed_f32psum_t psum) {
        static_assert(std::is_same_v<half_t, half> || std::is_same_v<half_t, __nv_bfloat16>);

        if constexpr (std::is_same_v<half_t, half>) {
            // first 8 output columns: B registers 0..1, accumulators 0..3
            asm volatile(
                "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
                "{%0,  %1,  %2,  %3},"
                "{%4,  %5,  %6,  %7},"
                "{%8,  %9},"
                "{%10,  %11,  %12,  %13};\n"
                : 
                "=f"(psum.data[0]), "=f"(psum.data[1]), "=f"(psum.data[2]), "=f"(psum.data[3])
                : 
                "r"(*reinterpret_cast<unsigned int *>(&a.data[0])), 
                "r"(*reinterpret_cast<unsigned int *>(&a.data[1])), 
                "r"(*reinterpret_cast<unsigned int *>(&a.data[2])), 
                "r"(*reinterpret_cast<unsigned int *>(&a.data[3])),
                "r"(*reinterpret_cast<unsigned int *>(&b.data[0])), 
                "r"(*reinterpret_cast<unsigned int *>(&b.data[1])),

                // "r"(0), "r"(0), "r"(0), "r"(0)
                "f"(psum.data[0]), "f"(psum.data[1]), "f"(psum.data[2]), "f"(psum.data[3])
            );
            // second 8 output columns: B registers 2..3, accumulators 4..7
            asm volatile(
                "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
                "{%0,  %1,  %2,  %3},"
                "{%4,  %5,  %6,  %7},"
                "{%8,  %9},"
                "{%10,  %11,  %12,  %13};\n"
                : 
                "=f"(psum.data[4]), "=f"(psum.data[5]), "=f"(psum.data[6]), "=f"(psum.data[7])
                : 
                "r"(*reinterpret_cast<unsigned int *>(&a.data[0])), 
                "r"(*reinterpret_cast<unsigned int *>(&a.data[1])), 
                "r"(*reinterpret_cast<unsigned int *>(&a.data[2])), 
                "r"(*reinterpret_cast<unsigned int *>(&a.data[3])),
                "r"(*reinterpret_cast<unsigned int *>(&b.data[2])), 
                "r"(*reinterpret_cast<unsigned int *>(&b.data[3])),
                // "r"(0), "r"(0), "r"(0), "r"(0)
                "f"(psum.data[4]), "f"(psum.data[5]), "f"(psum.data[6]), "f"(psum.data[7])
            );
        }

        if constexpr (std::is_same_v<half_t, __nv_bfloat16>) {
            // same two-instruction scheme as above, with bf16 operands
            asm volatile(
                "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
                "{%0,  %1,  %2,  %3},"
                "{%4,  %5,  %6,  %7},"
                "{%8,  %9},"
                "{%10,  %11,  %12,  %13};\n"
                : 
                "=f"(psum.data[0]), "=f"(psum.data[1]), "=f"(psum.data[2]), "=f"(psum.data[3])
                : 
                "r"(*reinterpret_cast<unsigned int *>(&a.data[0])), 
                "r"(*reinterpret_cast<unsigned int *>(&a.data[1])), 
                "r"(*reinterpret_cast<unsigned int *>(&a.data[2])), 
                "r"(*reinterpret_cast<unsigned int *>(&a.data[3])),
                "r"(*reinterpret_cast<unsigned int *>(&b.data[0])), 
                "r"(*reinterpret_cast<unsigned int *>(&b.data[1])),

                // "r"(0), "r"(0), "r"(0), "r"(0)
                "f"(psum.data[0]), "f"(psum.data[1]), "f"(psum.data[2]), "f"(psum.data[3])
            );
            asm volatile(
                "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
                "{%0,  %1,  %2,  %3},"
                "{%4,  %5,  %6,  %7},"
                "{%8,  %9},"
                "{%10,  %11,  %12,  %13};\n"
                : 
                "=f"(psum.data[4]), "=f"(psum.data[5]), "=f"(psum.data[6]), "=f"(psum.data[7])
                : 
                "r"(*reinterpret_cast<unsigned int *>(&a.data[0])), 
                "r"(*reinterpret_cast<unsigned int *>(&a.data[1])), 
                "r"(*reinterpret_cast<unsigned int *>(&a.data[2])), 
                "r"(*reinterpret_cast<unsigned int *>(&a.data[3])),
                "r"(*reinterpret_cast<unsigned int *>(&b.data[2])), 
                "r"(*reinterpret_cast<unsigned int *>(&b.data[3])),
                // "r"(0), "r"(0), "r"(0), "r"(0)
                "f"(psum.data[4]), "f"(psum.data[5]), "f"(psum.data[6]), "f"(psum.data[7])
            );
        }
        
        return psum;
    }

    /**
     * Narrows one lane's fp32 accumulator fragment to 16-bit storage.
     *
     * Consecutive pairs (input.data[2 * i], input.data[2 * i + 1]) become results.data[i];
     * the conversion itself is delegated to float22half2<half2_t>.
     *
     * @param input  8 fp32 values
     * @return       the same 8 values as 4 half2_t
     */
    __device__ __forceinline__
    static packed_fpsum_t packed_fp32_to_fp16(packed_f32psum_t input) {
        packed_fpsum_t results;
    #pragma unroll
        for (int i = 0; i < 4; i++) {
            // make_float2 is the CUDA builtin used everywhere else in this file; spelling this as
            // float2(a, b) only compiles through C++20 parenthesized aggregate initialization
            results.data[i] = float22half2<half2_t>(make_float2(input.data[i * 2], input.data[i * 2 + 1]));
        }
        return results;
    }

    /**
     * Widens one lane's 16-bit fragment to fp32: input.data[p] supplies output entries
     * 2 * p (its .x) and 2 * p + 1 (its .y).
     */
    __device__ __forceinline__
    static packed_f32psum_t packed_fp16_to_fp32(packed_fpsum_t input) {
        packed_f32psum_t widened;
        for (int pair = 0; pair < 4; pair++) {
            float2 both = half22float2(input.data[pair]);
            float *dst = &widened.data[pair * 2];
            dst[0] = both.x;
            dst[1] = both.y;
        }
        return widened;
    }

    /**
     * Warp-tile overload: narrows every fp32 accumulator fragment of the warp tile with the
     * per-fragment packed_fp32_to_fp16, keeping the fragment order.
     */
    __device__ __forceinline__
    static fpsum_warp packed_fp32_to_fp16(f32psum_warp input) {
        fpsum_warp narrowed;
    #pragma unroll
        for (int tile = 0; tile < narrowed.size(); tile++) {
            const packed_f32psum_t &wide = input[tile];
            narrowed[tile] = packed_fp32_to_fp16(wide);
        }
        return narrowed;
    }

    // activation: row major, [M / BLOCK_M, K / WARP_K, NUM_WARPS, WARP_M_TILES, WARP_SIZE] of packed_act_t
    /**
     * Loads this lane's activation packs for K-step `k`, one per M-tile of the calling warp.
     *
     * @param act   start of the current block's activations (layout above, leading index already applied)
     * @param k     K-step index
     * @param K     unused
     * @param out   receives WARP_M_TILES packs; left untouched when pred is false
     * @param pred  when false, nothing is loaded
     */
    __device__ __forceinline__ 
    static void load_act(const packed_act_t *act, int k, int K, act_warp &out, bool pred) {
        const int lane = threadIdx.x % WARP_SIZE;
        const int warp = threadIdx.x / WARP_SIZE;
        // index of this lane's pack in M-tile 0; consecutive M-tiles are WARP_SIZE packs apart
        const int firstIdx = (k * NUM_WARPS + warp) * WARP_M_TILES * WARP_SIZE + lane;
    #pragma unroll
        for (int tile = 0; tile < WARP_M_TILES; tile++) {
            if (pred) {
                out[tile] = load(&act[firstIdx + tile * WARP_SIZE]);
            }
        }
    }

    // weight: column major: [N / BLOCK_N, 1, K / WARP_K, WARP_N_TILES, WARP_SIZE] of packed_wgt_t
    /**
     * Loads this lane's weight packs for K-step `k`, one per N-tile.
     *
     * @param wgt   start of the current block's weights (layout above, leading index already applied)
     * @param k     K-step index
     * @param K     unused
     * @param out   receives WARP_N_TILES packs; left untouched when pred is false
     * @param pred  when false, nothing is loaded
     */
    __device__ __forceinline__ 
    static void load_wgt(const packed_wgt_t *wgt, int k, int K, wgt_warp &out, bool pred) {
        const int lane = threadIdx.x % WARP_SIZE;

        // this lane's pack in N-tile 0 of step k; consecutive N-tiles are WARP_SIZE packs apart
        const packed_wgt_t *lanePtr = &wgt[k * WARP_N_TILES * WARP_SIZE + lane];
    #pragma unroll
        for (int tile = 0; tile < WARP_N_TILES; tile++) {
            if (pred) {
                out[tile] = load(lanePtr + tile * WARP_SIZE);
            }
        }
    }

    // ascales: row major [M / BLOCK_M, K / group size, NUM_WARPS, ASCALES_NUM_PACKS, ASCALES_VALID_LANES] of packed_ascale_t
    /**
     * Loads the calling warp's activation-scale packs for quantization group `group`.
     * Only lanes below ASCALES_VALID_LANES read; all other lanes leave `out` untouched.
     *
     * @param ascales  start of the current block's activation scales (layout above)
     * @param group    group index along K
     * @param M        unused
     * @param out      receives ASCALES_NUM_PACKS packs
     * @param pred     when false, nothing is loaded
     */
    __device__ __forceinline__ 
    static void load_ascale(const packed_ascale_t *ascales, int group, int M, ascale_warp &out, bool pred) {
        const int lane = threadIdx.x % WARP_SIZE;
        const int warp = threadIdx.x / WARP_SIZE;
        const bool active = pred && lane < ASCALES_VALID_LANES;
        // first pack belonging to (group, warp)
        const int warpBase = (group * NUM_WARPS + warp) * ASCALES_NUM_PACKS * ASCALES_VALID_LANES;
    #pragma unroll
        for (int pack = 0; pack < ASCALES_NUM_PACKS; pack++) {
            if (active) {
                out[pack] = ascales[warpBase + pack * ASCALES_VALID_LANES + lane];
            }
        }
    }

    // wscales: column major [N / BLOCK_N, K / group size, 1, WSCALES_NUM_PACKS, WSCALES_VALID_LANES] of packed_wscale_t
    /**
     * Loads the weight-scale packs for quantization group `group`.
     * Only lanes below WSCALES_VALID_LANES read; all other lanes leave `out` untouched.
     *
     * @param wscales  start of the current block's weight scales (layout above)
     * @param group    group index along K
     * @param N        unused
     * @param out      receives WSCALES_NUM_PACKS packs
     * @param pred     when false, nothing is loaded
     */
    __device__ __forceinline__
    static void load_wscale(const packed_wscale_t *wscales, int group, int N, wscale_warp &out, bool pred) {
        const int lane = threadIdx.x % WARP_SIZE;
        const bool active = pred && lane < WSCALES_VALID_LANES;

    #pragma unroll
        for (int pack = 0; pack < WSCALES_NUM_PACKS; pack++) {
            if (active) {
                const int idx = (group * WSCALES_NUM_PACKS + pack) * WSCALES_VALID_LANES + lane;
                out[pack] = load(&wscales[idx]);
            }
        }
    }

    // get {k}-th and {k+1}-th wscale from the block, k must be multiples of 2, k must be uniform across all lanes
    // (the value is fetched with a full-warp shuffle from the lane that holds it, so every lane must call this)
    __device__ __forceinline__
    static half2_t broadcast_wscale(wscale_warp block, int k, int laneId) {
        const int lanesPerPackIdx = k / WSCALES_PACK_SIZE;
        const int pack = k / (WSCALES_PACK_SIZE * WARP_SIZE);
        const int pairInPack = k % WSCALES_PACK_SIZE / 2;
        // lanes with the same laneId % 4 read from the same source lane
        const int holder = 4 * lanesPerPackIdx + laneId % 4;
        return __shfl_sync(0xffffffffu, block[pack].data[pairInPack], holder);
    }
    // get {k}-th and {k+1}-th ascale from the block, k must be multiples of 2, k must be uniform across all lanes
    // (the value is fetched with a full-warp shuffle from the lane that holds it, so every lane must call this)
    __device__ __forceinline__
    static half2_t broadcast_ascale(ascale_warp block, int k, int laneId) {
        const int lanesPerPackIdx = k / ASCALES_PACK_SIZE;
        const int pack = k / (ASCALES_PACK_SIZE * WARP_SIZE);
        const int pairInPack = k % ASCALES_PACK_SIZE / 2;
        // lanes with the same laneId / 4 read from the same source lane
        const int holder = 8 * lanesPerPackIdx + laneId / 4;
        return __shfl_sync(0xffffffffu, block[pack].data[pairInPack], holder);
    }

    /**
     * Dequantizes integer partial sums and accumulates them into the 16-bit accumulators:
     *   fpsum += half(psum) * (activation scale * weight scale)
     * where the scale product is formed in half2_t arithmetic.
     *
     * @param getpsum  callable (i, j) yielding the packed_psum_t of M-tile i, N-tile j
     * @param ascale   activation scales of the current group (see load_ascale)
     * @param wscale   weight scales of the current group (see load_wscale)
     * @param fpsum    accumulators, updated in place
     *
     * Uses warp shuffles internally, so every lane of the warp must call it.
     */
    template<typename F>
    __device__ __forceinline__
    static void apply_scales(F &&getpsum, ascale_warp ascale, wscale_warp wscale, fpsum_warp &fpsum) {
        const int lane = threadIdx.x % WARP_SIZE;

        // per M-tile: scale of this lane's first row and of the row 8 below it, each duplicated
        // into both halves of a half2_t
        half2_t rowScaleTop[WARP_M_TILES];
        half2_t rowScaleBottom[WARP_M_TILES];

        for (int m = 0; m < WARP_M_TILES; m++) {
            half2_t rowPair = broadcast_ascale(ascale, m * 2, lane);
            rowScaleTop[m] = half2_t(rowPair.x, rowPair.x);
            rowScaleBottom[m] = half2_t(rowPair.y, rowPair.y);
        }

        for (int n = 0; n < WARP_N_TILES; n++) {
            // column scales for the first and the second 8-column half of the N-tile
            half2_t colScaleLeft = broadcast_wscale(wscale, n * 4, lane);
            half2_t colScaleRight = broadcast_wscale(wscale, n * 4 + 2, lane);

            for (int m = 0; m < WARP_M_TILES; m++) {
                packed_fpsum_t &acc = fpsum[m * WARP_N_TILES + n];

                packed_psum_t raw = getpsum(m, n);

                // combined scale for each of the four half2_t slots of the fragment
                half2_t scale[4] = {
                    __hmul2(rowScaleTop[m], colScaleLeft),
                    __hmul2(rowScaleBottom[m], colScaleLeft),
                    __hmul2(rowScaleTop[m], colScaleRight),
                    __hmul2(rowScaleBottom[m], colScaleRight),
                };

    #pragma unroll
                for (int slot = 0; slot < 4; slot++) {
                    half2_t value = float22half2<half2_t>(make_float2(__int2float_rn(raw.data[slot * 2]), __int2float_rn(raw.data[slot * 2 + 1])));
                    acc.data[slot] = __hfma2(value, scale[slot], acc.data[slot]);
                }
            }
        }
    }

    /**
     * fp32 overload: dequantizes integer partial sums and accumulates them into fp32 accumulators:
     *   fpsum += float(psum) * (activation scale * weight scale)
     * with the scales widened to float before they are multiplied.
     *
     * @param getpsum  callable (i, j) yielding the packed_psum_t of M-tile i, N-tile j
     * @param ascale   activation scales of the current group (see load_ascale)
     * @param wscale   weight scales of the current group (see load_wscale)
     * @param fpsum    fp32 accumulators, updated in place
     *
     * Uses warp shuffles internally, so every lane of the warp must call it.
     */
    template<typename F>
    __device__ __forceinline__
    static void apply_scales(F &&getpsum, ascale_warp ascale, wscale_warp wscale, f32psum_warp &fpsum) {
        const int lane = threadIdx.x % WARP_SIZE;

        // per M-tile: scale of this lane's first row and of the row 8 below it, duplicated in x and y
        float2 rowScaleTop[WARP_M_TILES];
        float2 rowScaleBottom[WARP_M_TILES];

        for (int m = 0; m < WARP_M_TILES; m++) {
            half2_t rowPair = broadcast_ascale(ascale, m * 2, lane);
            rowScaleTop[m] = half22float2(half2_t(rowPair.x, rowPair.x));
            rowScaleBottom[m] = half22float2(half2_t(rowPair.y, rowPair.y));
        }

        for (int n = 0; n < WARP_N_TILES; n++) {
            // column scales for the first and the second 8-column half of the N-tile
            float2 colScaleLeft = half22float2(broadcast_wscale(wscale, n * 4, lane));
            float2 colScaleRight = half22float2(broadcast_wscale(wscale, n * 4 + 2, lane));

            for (int m = 0; m < WARP_M_TILES; m++) {
                packed_f32psum_t &acc = fpsum[m * WARP_N_TILES + n];

                packed_psum_t raw = getpsum(m, n);

                // combined scale for each consecutive pair of accumulator entries
                float2 scale[4] = {
                    rowScaleTop[m] * colScaleLeft,
                    rowScaleBottom[m] * colScaleLeft,
                    rowScaleTop[m] * colScaleRight,
                    rowScaleBottom[m] * colScaleRight,
                };

    #pragma unroll
                for (int slot = 0; slot < 4; slot++) {
                    acc.data[slot * 2] += __int2float_rn(raw.data[slot * 2]) * scale[slot].x;
                    acc.data[slot * 2 + 1] += __int2float_rn(raw.data[slot * 2 + 1]) * scale[slot].y;
                }
            }
        }
    }

    /**
     * Rearranges one warp's activation scales into the packed per-lane layout read by load_ascale.
     *
     * input: WARP_M of half (in shared memory, per-warp)
     * output: [..., ASCALES_NUM_PACKS, ASCALES_VALID_LANES] in global memory, per-warp
     *
     * Lanes at or above ASCALES_VALID_LANES do nothing.
     */
    __device__ __forceinline__
    static void pack_ascales(const half_t *input, packed_ascale_t *output) {
        const int lane = threadIdx.x % WARP_SIZE;
        if (lane >= ASCALES_VALID_LANES) {
            return;
        }

        // offset of the first scale this lane gathers inside one pack's slice of `input`;
        // the remaining elements of the pack follow at a stride of 8
        const int laneBase = lane / 8 * 8 * ASCALES_PACK_SIZE + lane % 8;
    #pragma unroll
        for (int pack = 0; pack < ASCALES_NUM_PACKS; pack++) {
            const half_t *src = input + pack * ASCALES_PACK_SIZE * WARP_SIZE + laneBase;
            packed_ascale_t gathered;
    #pragma unroll
            for (int e = 0; e < ASCALES_PACK_SIZE; e += 2) {
                gathered.data[e / 2].x = src[e * 8];
                gathered.data[e / 2].y = src[(e + 1) * 8];
            }
            output[pack * ASCALES_VALID_LANES + lane] = gathered;
        }
    }

    /**
     * Rearranges one warp's weight scales into the packed per-lane layout read by load_wscale.
     *
     * input: WARP_N of half (in shared memory, per-warp)
     * output: [..., WSCALES_NUM_PACKS, WSCALES_VALID_LANES] in global memory, per-warp
     *
     * Scales are copied two at a time as half2_t; lanes at or above WSCALES_VALID_LANES do nothing.
     */
    __device__ __forceinline__
    static void pack_wscales(const half_t *input, packed_wscale_t *output) {
        const int lane = threadIdx.x % WARP_SIZE;
        if (lane >= WSCALES_VALID_LANES) {
            return;
        }

        // offset of the first pair this lane gathers inside one pack's slice of `input`;
        // the remaining pairs of the pack follow at a stride of 8 scales
        const int laneBase = lane / 4 * 4 * WSCALES_PACK_SIZE + lane % 4 * 2;
    #pragma unroll
        for (int pack = 0; pack < WSCALES_NUM_PACKS; pack++) {
            const half_t *src = input + pack * WSCALES_PACK_SIZE * WARP_SIZE + laneBase;
            packed_wscale_t gathered;
    #pragma unroll
            for (int e = 0; e < WSCALES_PACK_SIZE; e += 2) {
                gathered.data[e / 2] = *reinterpret_cast<const half2_t *>(&src[e * 4]);
            }
            store(&output[pack * WSCALES_VALID_LANES + lane], gathered);
        }
    }

    /**
     * Writes one warp's accumulator tile to `output` in plain row-major order.
     *
     * The fragments are staged through a per-warp scratch matrix (8 rows at a time) so that each
     * lane ends up holding PACK_SIZE consecutive columns of a single row, which it then stores.
     * Only __syncwarp() is used for ordering, so the scratch buffer must not be shared between warps,
     * and every lane of the warp must take part in the call.
     */
    struct unpack_fpsum {
        // +8 to prevent bank conflicts
        using matrix_t = half_t[8][WARP_N + 8];

        // bytes of scratch needed per warp
        static constexpr int SHMEM_SIZE = sizeof(matrix_t);
        // consecutive columns handled by one lane
        static constexpr int PACK_SIZE = WARP_N / WARP_SIZE;
        using pack_t = std::array<half_t, PACK_SIZE>;

        // F (int rowId, pack_t &pack)
        /**
         * @param fpsum    accumulator fragments of the warp tile
         * @param output   address of the warp tile's first element in the destination
         * @param stride   distance between consecutive output rows, in elements
         * @param maxRows  rows at or beyond this index (relative to the warp tile) are not stored
         * @param maxCols  a lane's pack is stored only if its first column is below this value;
         *                 the whole pack is then written
         * @param shmem    per-warp scratch of at least SHMEM_SIZE bytes (EpilogueDefault passes
         *                 __shared__ storage)
         * @param plugins  callables invoked as plugin(rowId, pack) for every row before it is stored,
         *                 including rows/packs that end up masked out; they may modify the pack
         */
        template<typename ...F>
        __device__ __forceinline__
        void operator()(fpsum_warp fpsum, half_t *output, int stride, int maxRows, int maxCols, void *shmem, F &&...plugins) {
            const int laneId = threadIdx.x % WARP_SIZE;
            
            matrix_t &mat = *reinterpret_cast<matrix_t *>(shmem);


            // pack_t reduce_tmp;
            // constexpr bool enableReduce = !std::is_void_v<FuncReduce>;

            // if constexpr (enableReduce) {
            //     reduce_tmp.fill(reduce_initval);
            //     // reduce_tmp = load<true>(reinterpret_cast<pack_t *>(&reduce_result[laneId * PACK_SIZE]));
            // }
            // auto doReduce = [&reduce_tmp](pack_t pack) {
            //     if constexpr (enableReduce) {
            //         for (int i = 0; i < PACK_SIZE; i++) {
            //             reduce_tmp[i] = FuncReduce()(reduce_tmp[i], pack[i]);
            //         }
            //     }
            // };

            #pragma unroll
            for (int i = 0; i < WARP_M_TILES; i++) {
                // stage rows 0..7 of M-tile i: data[0] holds columns [col, col+1], data[2] columns [col+8, col+9]
            #pragma unroll
                for (int j = 0; j < WARP_N_TILES; j++) {
                    packed_fpsum_t &fsum = fpsum[i * WARP_N_TILES + j];
                    int row = laneId / 4;
                    int col = laneId % 4 * 2 + j * INSN_N;
                    *reinterpret_cast<half2_t *>(&mat[row][col + 0]) = fsum.data[0];
                    *reinterpret_cast<half2_t *>(&mat[row][col + 8]) = fsum.data[2];
                }
                // make the staged rows visible to the other lanes before they are read back
                __syncwarp();

                // read back row by row; each lane takes PACK_SIZE consecutive columns
            #pragma unroll
                for (int row = 0; row < 8; row++) {
                    pack_t pack = *reinterpret_cast<pack_t *>(&mat[row][laneId * PACK_SIZE]);

                    // if constexpr (enableReduce) {
                    //     doReduce(pack);
                    // }

                    (plugins(i * INSN_M + row, pack), ...);

                    bool pred = i * INSN_M + row < maxRows && laneId * PACK_SIZE < maxCols;
                    if (pred) {
                        store(reinterpret_cast<pack_t *>(&output[(i * INSN_M + row) * stride + laneId * PACK_SIZE]), pack);
                    }
                }
                // all reads must finish before the scratch matrix is overwritten
                __syncwarp();

                // stage rows 8..15 of M-tile i: data[1] and data[3] are the same columns, 8 rows further down
            #pragma unroll
                for (int j = 0; j < WARP_N_TILES; j++) {
                    packed_fpsum_t &fsum = fpsum[i * WARP_N_TILES + j];
                    int row = laneId / 4;
                    int col = laneId % 4 * 2 + j * INSN_N;
                    *reinterpret_cast<half2_t *>(&mat[row][col + 0]) = fsum.data[1];
                    *reinterpret_cast<half2_t *>(&mat[row][col + 8]) = fsum.data[3];
                }
                __syncwarp();

            #pragma unroll
                for (int row = 0; row < 8; row++) {
                    pack_t pack = *reinterpret_cast<pack_t *>(&mat[row][laneId * PACK_SIZE]);

                    // if constexpr (enableReduce) {
                    //     doReduce(pack);
                    // }

                    (plugins(i * INSN_M + 8 + row, pack), ...);

                    bool pred = i * INSN_M + 8 + row < maxRows && laneId * PACK_SIZE < maxCols;
                    if (pred) {
                        store(reinterpret_cast<pack_t *>(&output[(i * INSN_M + 8 + row) * stride + laneId * PACK_SIZE]), pack);
                    }
                }
                // protect the scratch matrix before the next M-tile starts writing
                __syncwarp();
            }
            // if (enableReduce) {
            //     store<true>(reinterpret_cast<pack_t *>(&reduce_result[laneId * PACK_SIZE]), reduce_tmp);
            // }
        }
    };

    
    /**
     * Applies `func` to every scalar of the warp tile and returns the transformed tile.
     *
     * @param fpsum  accumulator fragments (taken by value, not modified)
     * @param func   scalar function, called on the .x and then the .y half of every half2_t
     * @return       fragments holding func(value) at the same positions
     */
    template<typename F>
    __device__ __forceinline__
    static fpsum_warp apply_act(fpsum_warp fpsum, F func) {
        fpsum_warp activated;
        // the (M-tile, N-tile) fragments are stored flat, so a single loop visits them in storage order
    #pragma unroll
        for (int tile = 0; tile < WARP_M_TILES * WARP_N_TILES; tile++) {
    #pragma unroll
            for (int slot = 0; slot < 4; slot++) {
                half2_t in = fpsum[tile].data[slot];
                half2_t &out = activated[tile].data[slot];
                out.x = func(in.x);
                out.y = func(in.y);
            }
        }
        return activated;
    }
    

    /**
     * Epilogue that writes the warp tile to a row-major output matrix.
     *
     * Rows at or beyond actualM and column packs starting at or beyond actualN are masked out by
     * unpack_fpsum. For fp16 outputs every value is first clamped to [-65504, 65504].
     */
    struct EpilogueDefault {
        struct Arguments {
            half_t *out;            // output matrix, row-major with a row stride of actualN elements
            int actualM, actualN;   // logical extent of the output
        };

        /**
         * @param binfo    tile coordinates of the calling block
         * @param fpsum    accumulator fragments of the calling warp
         * @param M, N, K  unused
         * @param args     output pointer and logical extent
         */
        __device__ __forceinline__
        void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, Arguments args) {
            const int warpId = threadIdx.x / WARP_SIZE;
            
            // one scratch area per warp, each rounded up to a multiple of 128 bytes
            __shared__ alignas(128) uint8_t shmem[NUM_WARPS][ceilDiv(unpack_fpsum::SHMEM_SIZE, 128) * 128];

            const int m_offset = binfo.bm * BLOCK_M + warpId * WARP_M;
            const int n_offset = binfo.bn * BLOCK_N;

            // The element offset is row * actualN; form it in size_t so that outputs with
            // 2^31 or more elements do not overflow 32-bit int arithmetic.
            half_t *tile_out = args.out + static_cast<size_t>(m_offset) * args.actualN + n_offset;

            unpack_fpsum()(
                fpsum, 
                tile_out, 
                args.actualN, 
                args.actualM - m_offset,
                args.actualN - n_offset,
                shmem[warpId],
                [&](int rowId, unpack_fpsum::pack_t &pack) ALWAYSINLINE {
                    // clamp to the largest finite fp16 magnitude; bf16 output is left as is
                    if constexpr (std::is_same_v<half_t, half>) {
                    #pragma unroll
                        for (int i = 0; i < pack.size(); i++) {
                            pack[i] = __hmin(pack[i], (half)65504);
                            pack[i] = __hmax(pack[i], (half)-65504);
                        }
                    }
                }
            );
        }
    };

    // Epilogue that does nothing; the accumulators are passed by value and discarded.
    struct EpilogueNop {
        // workaround for layout mismatch between host and device code
        struct Arguments { size_t unused; };

        __device__ __forceinline__
        void operator()(const BlockInfo binfo, fpsum_warp fpsum, int M, int N, int K, Arguments args) {
        }
    };

    /**
     * Epilogue that adds a per-column bias to the accumulators in place.
     * The bias is stored in the same packed layout as the weight scales.
     */
    struct EpilogueBias {
        struct Arguments {
            const packed_wscale_t *bias;  // [N / BLOCK_N, WSCALES_NUM_PACKS, WSCALES_VALID_LANES] of packed_wscale_t
        };

        /**
         * Adds the bias slice starting at `bias` to every fragment of the warp tile.
         * Uses warp shuffles internally, so every lane of the warp must call it.
         *
         * @param fpsum  accumulators, updated in place
         * @param M, K   unused
         * @param N      forwarded to load_wscale
         * @param bias   first pack of the current block's bias slice
         */
        __device__ __forceinline__
        void apply_bias(fpsum_warp &fpsum, int M, int N, int K, const packed_wscale_t *bias) {
            const int lane = threadIdx.x % WARP_SIZE;

            wscale_warp biasPacks;
            load_wscale(bias, 0, N, biasPacks, true);

            for (int n = 0; n < WARP_N_TILES; n++) {
                // bias of this lane's columns in the first and the second 8-column half of the N-tile
                half2_t biasLeft = broadcast_wscale(biasPacks, n * 4, lane);
                half2_t biasRight = broadcast_wscale(biasPacks, n * 4 + 2, lane);

                for (int m = 0; m < WARP_M_TILES; m++) {
                    packed_fpsum_t &acc = fpsum[m * WARP_N_TILES + n];

                    // slots 0 and 1 share the left columns, slots 2 and 3 the right ones
    #pragma unroll
                    for (int slot = 0; slot < 4; slot++) {
                        acc.data[slot] = __hadd2(acc.data[slot], slot < 2 ? biasLeft : biasRight);
                    }
                }
            }
        }

        __device__ __forceinline__
        void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) {
            // skip to the slice of the block column this thread block works on
            const packed_wscale_t *blockBias = args.bias + binfo.bn * WSCALES_NUM_PACKS * WSCALES_VALID_LANES;
            apply_bias(fpsum, M, N, K, blockBias);
        }
    };

    // Epilogue that replaces every accumulator value x with silu(x), in place.
    struct EpilogueSilu {
        // placeholder member; the epilogue takes no real arguments
        struct Arguments { size_t unused; };

        __device__ __forceinline__
        void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) {
            fpsum = apply_act(fpsum, [](half_t x) { return silu(x); });
        }
    };

    /**
     * Chains several epilogues: each one is default-constructed and invoked in template-argument
     * order on the same accumulators, receiving its own entry of the argument tuple.
     *
     * @tparam Epilogues  epilogue types, each providing a nested Arguments type and an
     *                    operator()(binfo, fpsum, M, N, K, args)
     */
    template<typename ...Epilogues>
    struct EpilogueCombination {
        // one Arguments object per epilogue, in the same order as Epilogues
        using Arguments = std::tuple<typename Epilogues::Arguments...>;

        __device__ __forceinline__
        void operator()(const BlockInfo binfo, fpsum_warp &fpsum, int M, int N, int K, Arguments args) {
            // this function makes intellisense crashes :(
    #if __INTELLISENSE__
            __trap();   // should not happen when actually compiling
    #else
            std::tuple<Epilogues...> epilogues;
            // runs the idx-th epilogue with the idx-th argument pack
            auto run = [&]<size_t idx>() {
                std::get<idx>(epilogues).operator()(binfo, fpsum, M, N, K, std::get<idx>(args));
            };
            // the comma fold evaluates run<0>, run<1>, ... from left to right
            auto foreach = [&]<size_t ...Is>(std::index_sequence<Is...>) {
                (run.template operator()<Is>(), ...);
            };
            foreach(std::make_index_sequence<sizeof...(Epilogues)>());
    #endif
        }
    };


};

// Brings members of GEMMBase<config> into the scope where the macro is expanded.
// It first defines the alias `Base` for GEMMBase<config>, then re-exports the listed
// constants (plain `using Base::X`) and nested types (`using typename Base::X`) so they
// can be named without the `Base::` prefix.
// NOTE(review): the `using typename Base::...` form suggests this is meant to be expanded
// inside a class template that derives from GEMMBase<config> — confirm against callers.
// Only EpilogueDefault, EpilogueNop and EpilogueBias are re-exported among the epilogues;
// EpilogueSilu and EpilogueCombination are not listed here.
#define IMPORT_GEMM_BASE(config) \
    using Base = GEMMBase<config>;  \
    using Base::BLOCK_M;    \
    using Base::BLOCK_N;    \
    using Base::WARP_SIZE;  \
    using Base::NUM_WARPS;  \
    using Base::INSN_M; \
    using Base::INSN_N; \
    using Base::INSN_K; \
    using typename Base::half_t;    \
    using typename Base::half2_t;   \
    using Base::WARP_M; \
    using Base::WARP_N; \
    using Base::WARP_K; \
    using Base::WARP_M_TILES;   \
    using Base::WARP_N_TILES;   \
    using Base::WARP_K_TILES;   \
    using Base::WSCALES_PACK_SIZE;  \
    using Base::WSCALES_NUM_PACKS;  \
    using Base::WSCALES_VALID_LANES;    \
    using Base::ASCALES_PACK_SIZE;  \
    using Base::ASCALES_NUM_PACKS;  \
    using Base::ASCALES_VALID_LANES;    \
    using typename Base::packed_act_t;  \
    using typename Base::packed_wgt_t;  \
    using typename Base::packed_psum_t; \
    using typename Base::packed_fpsum_t;    \
    using typename Base::packed_gated_fpsum_t;  \
    using typename Base::packed_f32psum_t;  \
    using typename Base::packed_wscale_t;   \
    using typename Base::packed_ascale_t;   \
    using typename Base::act_warp;  \
    using typename Base::wgt_warp;  \
    using typename Base::ascale_warp;   \
    using typename Base::wscale_warp;   \
    using typename Base::fpsum_warp;    \
    using typename Base::f32psum_warp;  \
    using typename Base::gated_fpsum_warp;  \
    using typename Base::BlockInfo; \
    using typename Base::unpack_fpsum;  \
    using typename Base::EpilogueDefault;   \
    using typename Base::EpilogueNop;   \
    using typename Base::EpilogueBias;


// Generic __global__ entry point: value-initializes a temporary of the functor type
// `kernel` on the device and calls its operator() with the launch arguments.
template<typename kernel, typename ...Args>
__global__
static void invoke_kernel(Args ...args) {
    kernel().operator()(args...);
}

// Debug kernel: prints sizeof(T) as evaluated in device code.
template<typename T>
__global__
static void test_sizeof_device() {
    constexpr int deviceBytes = static_cast<int>(sizeof(T));
    printf("sizeof on device = %d\n", deviceBytes);
}

// Debug helper: prints sizeof(T) as evaluated in host code.
template<typename T>
static void test_sizeof_host() {
    constexpr int hostBytes = static_cast<int>(sizeof(T));
    printf("sizeof on host = %d\n", hostBytes);
}

// Debug helper: prints the implementation-defined type name of T, then sizeof(T) as
// evaluated on the host and on the device, so the two values can be compared by eye.
// Launches a single-thread kernel and blocks until it has finished.
template<typename T>
static void test_sizeof() {
    printf("typeid = %s\n", typeid(T).name());
    test_sizeof_host<T>();
    test_sizeof_device<T><<<1, 1>>>();
    // A kernel launch returns no status itself; launch failures are reported through
    // cudaGetLastError(), so check it right after the launch.
    checkCUDA(cudaGetLastError());
    // Wait for the kernel so its printf output is flushed and execution errors surface here.
    checkCUDA(cudaDeviceSynchronize());
}

};  // namespace nunchaku::kernels