// gemm_w4a4_launch_impl.cuh: host-side launchers for the W4A4 GEMM, LiteLA vk*q and 4-bit quantization kernels.
#include "gemm_w4a4_launch.cuh"

namespace nunchaku::kernels {

// Host-side launcher for the fused W4A4 GEMM (act x wgt) followed by a chain of
// epilogues chosen from which optional output tensors are valid.
//
// Output mode is selected by the first matching condition:
//   1. qout && oscales -> quantize the result for the next layer (EpilogueQuantize) with a
//                         GELU mid-epilogue; additionally stores to `out` when it is valid.
//   2. out_linearattn  -> LiteLA linear-attention epilogue (requires out_vk).
//   3. rotary_emb      -> RMSNorm + rotary embedding; packs Q/K/V into out_q/out_k/out_v when
//                         out_q is valid, otherwise stores to `out`.
//   4. out             -> plain store, optionally with a fused SiLU mid-epilogue.
// In every mode the composed chain is:
//   bias/wcscales -> LoRA up (if lora_up) -> mid epilogue -> LoRA down (if lora_down) -> output epilogue.
//
// `fp4` must equal the USE_FP4 template parameter. The INT4 kernel requires alpha == 1;
// the FP4 kernel expects FP8_E4M3 ascales/wscales and applies `alpha` when it differs from 1.
#ifndef __INTELLISENSE__
template<typename Config, bool USE_FP4>
void GEMM_W4A4_Launch<Config, USE_FP4>::gemm_w4a4(
#else
template<>
void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
#endif
    Tensor act,            // packed act [M, K / 2]
    Tensor wgt,            // packed wgt [N, K / 2]
    Tensor out,            // linear     [M, N]
    Tensor qout,           // packed act [M, N / 2]
    Tensor ascales,        // packed as  [K / 64, M]
    Tensor wscales,        // packed ws  [K / 64, N]
    Tensor oscales,        // packed as  [N / 64, M]
    Tensor poolout,        // linear     [M / PoolSize, N] (currently unused by this launcher)
    Tensor lora_act_in,    // packed lora_act [M, R]
    Tensor lora_up,        // packed lora_wgt [N, R]
    Tensor lora_down,      // packed lora_wgt [N, R]
    Tensor lora_act_out,   // packed lora_act [M, R]
    Tensor norm_q,         // linear     [HEAD_DIM]
    Tensor norm_k,         // linear     [HEAD_DIM]
    Tensor rotary_emb,     // FP32       [*, *, HEAD_DIM] with shape[0] * shape[1] == M
    Tensor bias,           // packed ws  [N]
    Tensor smooth_factor,  // packed ws  [N], for quantization of the next layer
    Tensor out_vk,         // linear     [B, num_heads, head_dim + 1, head_dim]
    Tensor out_linearattn, // linear     [B, (M), N / 3]
    bool act_unsigned,
    std::vector<float> lora_scales, // [R / 16]
    bool fuse_silu,
    bool fp4,
    float alpha,
    Tensor wcscales, // packed ws  [N]
    Tensor out_q,    // packed attention [B, H, M, D]
    Tensor out_k,    // packed attention [B, H, M, D]
    Tensor out_v,    // packed attention [B, H, M, D]
    int attn_tokens) {
#ifdef __INTELLISENSE__
    static constexpr bool USE_FP4 = false;
#endif
    assert(fp4 == USE_FP4);

    // act/wgt store K / 2 packed elements per row, hence the factor of 2.
    int M = act.numel() / act.shape[-1];
    int N = wgt.shape[0];
    int K = act.shape[-1] * 2;
    assert(K == wgt.shape[1] * 2);

    // Unpadded output extent when `out` is provided; M/N may exceed it by less than one block.
    int actualM = 0;
    int actualN = 0;
    if (out.valid()) {
        actualM = out.numel() / out.shape[-1];
        actualN = out.shape[-1];

        assert(actualM <= M && M - actualM < GEMM::BLOCK_M);
        assert(actualN <= N && N - actualN < GEMM::BLOCK_N);
    }

    spdlog::trace("gemm_w4a4: M={} N={} K={}", M, N, K);
    spdlog::trace("act at {}", act.data_ptr());
    spdlog::trace("wgt at {}", wgt.data_ptr());
    spdlog::trace("ascales at {}", ascales.data_ptr());
    spdlog::trace("wscales at {}", wscales.data_ptr());
    if (bias.valid()) {
        spdlog::trace("bias at {}", bias.data_ptr());
    }

    // Dynamic shared memory for the GEMM kernel; only the LiteLA epilogue raises it below.
    int shmem = 0;

    // Launches the GEMM kernel (INT4 or FP4 variant) with a fully composed epilogue type.
    auto launch = [&]<typename Epilogue>(Epilogue::Arguments args) {
        assert(M % GEMM::BLOCK_M == 0);
        assert(N % GEMM::BLOCK_N == 0);
        dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);

        // For tall problems (M > 2N) the grid axes are swapped; the kernel is told via swapBlockMN.
        bool swapBlockMN = M > N * 2;
        if (swapBlockMN) {
            std::swap(grid.x, grid.y);
        }

        if constexpr (!USE_FP4) {
            // INT4 path: signed or unsigned activations, no alpha scaling.
            dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() {
                auto func = invoke_kernel<typename GEMM::gemm_w4a4_kernel<Epilogue, ACT_UNSIGNED>,
                                          const packed_act_t *,
                                          const packed_wgt_t *,
                                          const packed_ascale_t *,
                                          const packed_wscale_t *,
                                          int,
                                          int,
                                          int,
                                          typename Epilogue::Arguments,
                                          bool,
                                          bool>;

                // Raise the dynamic shared-memory limit only when a large buffer is requested.
                if (shmem >= 24 * 1024) {
                    checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
                }

                assert(alpha == 1.0f);

                func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(
                    act.data_ptr<packed_act_t>(),
                    wgt.data_ptr<packed_wgt_t>(),
                    ascales.data_ptr<packed_ascale_t>(),
                    wscales.data_ptr<packed_wscale_t>(),
                    M,
                    N,
                    K,
                    args,
                    swapBlockMN,
                    false);
                checkCUDA(cudaGetLastError());
            });
            return;
        }

        if constexpr (USE_FP4) {
            // FP4 path: FP8_E4M3 scales; alpha is applied only when it differs from 1.
            dispatchBool(alpha != 1.0f, [&]<bool USE_ALPHA>() {
                assert(!act_unsigned);

                auto func = invoke_kernel<typename GEMM::gemm_w4a4_fp4_kernel<Epilogue, USE_ALPHA>,
                                          const packed_act_t *,
                                          const packed_wgt_t *,
                                          const packed_amscale_t *,
                                          const packed_wmscale_t *,
                                          float,
                                          int,
                                          int,
                                          int,
                                          typename Epilogue::Arguments,
                                          bool,
                                          bool>;

                if (shmem >= 24 * 1024) {
                    checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
                }

                assert(ascales.dtype() == Tensor::FP8_E4M3);
                assert(wscales.dtype() == Tensor::FP8_E4M3);

                func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(
                    act.data_ptr<packed_act_t>(),
                    wgt.data_ptr<packed_wgt_t>(),
                    ascales.data_ptr<packed_amscale_t>(),
                    wscales.data_ptr<packed_wmscale_t>(),
                    alpha,
                    M,
                    N,
                    K,
                    args,
                    swapBlockMN,
                    false);
                checkCUDA(cudaGetLastError());
            });
            return;
        }
    };

    // Prepends the bias / per-channel weight-scale epilogue (each enabled only when its tensor is valid).
    auto launch_bias = [&]<typename NextEpilogue>(NextEpilogue::Arguments nextArgs) {
        assert(!bias.valid() || bias.numel() == N);
        assert(!wcscales.valid() || wcscales.numel() == N);

        dispatchBool(bias.valid(), [&]<bool USE_BIAS>() {
            dispatchBool(wcscales.valid(), [&]<bool USE_SCALE>() {
                using EpilogueBias = typename GEMM::EpilogueBias<USE_BIAS, USE_SCALE>;
                // append EpilogueNop to workaround mismatched memory layout of std::tuple between device and host code
                // on Windows
                // ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
                using Epilogue =
                    typename GEMM::EpilogueCombination<EpilogueBias, NextEpilogue, typename GEMM::EpilogueNop>;
                return launch.template operator()<Epilogue>(
                    {typename EpilogueBias::Arguments{
                         .bias  = USE_BIAS ? bias.data_ptr<packed_wscale_t>() : nullptr,
                         .scale = USE_SCALE ? wcscales.data_ptr<packed_wscale_t>() : nullptr,
                     },
                     nextArgs,
                     {}});
            });
        });
    };

    // Inserts LoRA epilogues around MidEpilogue: up-projection before it, down-projection after it.
    // Without lora_up the chain is just Mid -> Next.
    auto launch_lora = [&]<typename NextEpilogue, typename MidEpilogue>(NextEpilogue::Arguments nextArgs,
                                                                        MidEpilogue::Arguments midArgs) {
        assert(lora_up.valid() == lora_act_in.valid());
        assert(lora_down.valid() == lora_act_out.valid());

        const int rank_up   = lora_up.valid() ? lora_up.shape[1] : 0;
        const int rank_down = lora_down.valid() ? lora_down.shape[1] : 0;

        if (rank_up == 0) {
            assert(rank_down == 0);
            return launch_bias.template operator()<typename GEMM::EpilogueCombination<MidEpilogue, NextEpilogue>>(
                {midArgs, nextArgs});
        }

        assert(rank_up % 16 == 0);

        assert(lora_up.shape[0] == N);
        assert(lora_act_in.shape[0] == M);
        assert(lora_act_in.shape[1] == rank_up);

        using LoraUp  = Lora;
        using scale_t = typename LoraUp::scale_t;

        // One scale per 16 ranks; slots the caller did not provide are zero.
        scale_t scales;
        if constexpr (scales.size() > 0) {
            for (size_t i = 0; i < scales.size(); i++) {
                scales[i] = i < lora_scales.size() ? lora_scales[i] : 0.0f;
            }
        }

        if (rank_down == 0) {
            using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp,
                                                                MidEpilogue,
                                                                NextEpilogue,
                                                                typename GEMM::EpilogueNop>;
            return launch_bias.template operator()<Epilogue>({typename LoraUp::EpilogueLoraUp::Arguments{
                                                                  .lora_act    = lora_act_in.data_ptr<float>(),
                                                                  .lora_wgt_up = lora_up.data_ptr<packed_fpsum_t>(),
                                                                  .rank        = rank_up,
                                                                  .scales      = scales,
                                                                  .alwaysfalse = false,
                                                              },
                                                              midArgs,
                                                              nextArgs,
                                                              {}});
        }

        assert(rank_down % 16 == 0);

        assert(lora_down.shape[0] == N);
        assert(lora_act_out.shape[0] == M);
        assert(lora_act_out.shape[1] == rank_down);

        // lora_act_out is cleared before the launch that writes the down-projection into it.
        lora_act_out.zero_();

        using LoraDown = LoraUp;

        using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp,
                                                            MidEpilogue,
                                                            typename LoraDown::EpilogueLoraDown,
                                                            NextEpilogue,
                                                            typename GEMM::EpilogueNop>;
        return launch_bias.template operator()<Epilogue>({typename LoraUp::EpilogueLoraUp::Arguments{
                                                              .lora_act    = lora_act_in.data_ptr<float>(),
                                                              .lora_wgt_up = lora_up.data_ptr<packed_fpsum_t>(),
                                                              .rank        = rank_up,
                                                              .scales      = scales,
                                                              .alwaysfalse = false,
                                                          },
                                                          midArgs,
                                                          typename LoraDown::EpilogueLoraDown::Arguments{
                                                              .lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
                                                              .lora_act      = lora_act_out.data_ptr<float>(),
                                                              .rank          = rank_down,
                                                              .alwaysfalse   = false,
                                                          },
                                                          nextArgs,
                                                          {}});
    };

    if (qout.valid() && oscales.valid()) {
        // Shift passed to the unsigned (INT4) quantizer; presumably offsets GELU's negative
        // minimum (~-0.17) so outputs are non-negative. NOTE(review): confirm in EpilogueQuantize.
        static constexpr float SHIFT_GELU = 0.171875f;

        constexpr bool USE_UNSIGNED = !USE_FP4;
        using EpilogueQuantize      = typename GEMM::EpilogueQuantize<false, USE_UNSIGNED, USE_FP4>;
        auto argsQuantize =
            typename EpilogueQuantize::Arguments{.qout    = qout.data_ptr<packed_act_t>(),
                                                 .oscales = oscales.data_ptr<typename EpilogueQuantize::oscales_t>(),
                                                 .shift_value   = USE_FP4 ? 0.0f : SHIFT_GELU,
                                                 .smooth_factor = smooth_factor.data_ptr<packed_wscale_t>()};

        // TODO: check if gelu is needed
        if (out.valid()) {
            // Additionally store the half-precision result to `out` alongside the quantized output.
            launch_lora.template
            operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>,
                       typename Epilogues::EpilogueGelu>({typename GEMM::EpilogueDefault::Arguments{
                                                              .out     = out.data_ptr<half_t>(),
                                                              .actualM = actualM,
                                                              .actualN = actualN,
                                                          },
                                                          argsQuantize},
                                                         {});
        } else {
            launch_lora.template operator()<EpilogueQuantize, typename Epilogues::EpilogueGelu>(argsQuantize, {});
        }

    } else if (out_linearattn.valid()) {

        assert(out_vk.valid());

        using Epilogue = typename Epilogues::EpilogueLiteLA;

        assert(out_vk.dtype() == Tensor::FP32);
        assert(out_vk.ndims() == 4);
        assert(out_vk.shape[2] == Epilogue::LITELA_HEAD_DIM + 1);
        assert(out_vk.shape[3] == Epilogue::LITELA_HEAD_DIM);
        assert(out_vk.shape[1] * Epilogue::LITELA_HEAD_DIM * 3 == N);
        int batch_size = out_vk.shape[0];

        assert(isTypeMatch<half_t>(out_linearattn.dtype()));
        assert(out_linearattn.ndims() == 3);
        assert(out_linearattn.shape[0] == batch_size);
        assert(out_linearattn.shape[2] * 3 == N);
        int num_tokens = out_linearattn.shape[1];

        assert(num_tokens % GEMM::BLOCK_M == 0);
        int num_blocks_per_batch = ceilDiv(num_tokens, GEMM::BLOCK_M);

        shmem = std::max(shmem, Epilogue::SHMEM_SIZE);

        // out_vk is cleared before launch (presumably the epilogue accumulates into it).
        out_vk.zero_();

        launch_lora.template operator()<Epilogue, typename GEMM::EpilogueNop>(
            typename Epilogue::Arguments{
                .out_q                = out_linearattn.data_ptr<half_t>(),
                .out_vk               = out_vk.data_ptr<float>(),
                .num_blocks_per_batch = num_blocks_per_batch,
                .actualM              = M,
            },
            {});

    } else if (rotary_emb.valid()) {
        assert(norm_q.valid());
        assert(norm_k.valid());
        assert(rotary_emb.scalar_type() == Tensor::FP32);
        assert(rotary_emb.ndims() == 3);
        assert(rotary_emb.shape[0] * rotary_emb.shape[1] == M);
        assert(rotary_emb.shape[2] == Epilogues::EpilogueRMSNormRope::HEAD_DIM);

        using EpilogueRope = typename Epilogues::EpilogueRMSNormRope;
        auto argsRope      = typename Epilogues::EpilogueRMSNormRope::Arguments{
                 .rotary_emb       = rotary_emb.data_ptr<typename EpilogueRope::packed_rotemb_t>(),
                 .rmsnorm_weight_q = norm_q.data_ptr<half_t>(),
                 .rmsnorm_weight_k = norm_k.data_ptr<half_t>(),
                 .epsilon          = 1e-6,
        };

        if (out_q.valid()) {
            // Q/K/V are written into separate attention-layout tensors; all three are dereferenced.
            assert(out_k.valid());
            assert(out_v.valid());
            launch_lora.template
            operator()<typename GEMM::EpilogueCombination<EpilogueRope, typename Epilogues::EpiloguePackQKV>,
                       typename GEMM::EpilogueNop>(
                {argsRope,
                 typename Epilogues::EpiloguePackQKV::Arguments{
                     .out_q        = out_q.data_ptr<typename Epilogues::EpiloguePackQKV::packed_qkv_t>(),
                     .out_k        = out_k.data_ptr<typename Epilogues::EpiloguePackQKV::packed_qkv_t>(),
                     .out_v        = out_v.data_ptr<typename Epilogues::EpiloguePackQKV::packed_qkv_t>(),
                     .actualM      = attn_tokens,
                     .strideHead_q = int(out_q.stride(1) * out_q.scalar_size() /
                                         sizeof(typename Epilogues::EpiloguePackQKV::packed_qkv_t)),
                     .strideHead_k = int(out_k.stride(1) * out_k.scalar_size() /
                                         sizeof(typename Epilogues::EpiloguePackQKV::packed_qkv_t)),
                     .strideHead_v = int(out_v.stride(1) * out_v.scalar_size() /
                                         sizeof(typename Epilogues::EpiloguePackQKV::packed_qkv_t)),
                 }},
                {});
        } else {
            // Without packed Q/K/V outputs the result goes to `out`, which must therefore exist.
            assert(out.valid());
            launch_lora
                .template operator()<typename GEMM::EpilogueCombination<EpilogueRope, typename GEMM::EpilogueDefault>,
                                     typename GEMM::EpilogueNop>({argsRope,
                                                                  typename GEMM::EpilogueDefault::Arguments{
                                                                      .out     = out.data_ptr<half_t>(),
                                                                      .actualM = actualM,
                                                                      .actualN = actualN,
                                                                  }},
                                                                 {});
        }

    } else if (out.valid()) {

        using Epilogue = typename GEMM::EpilogueDefault;
        typename Epilogue::Arguments args{
            .out     = out.data_ptr<half_t>(),
            .actualM = actualM,
            .actualN = actualN,
        };

        if (fuse_silu) {
            launch_lora.template operator()<Epilogue, typename GEMM::EpilogueSilu>(args, {});
        } else {
            launch_lora.template operator()<Epilogue, typename GEMM::EpilogueNop>(args, {});
        }
    } else {
        // No usable output tensor was provided.
        assert(false);
    }
}

// Launches EpilogueLiteLA::vk_mul_q_kernel over `q` using the accumulated `vk` tensor.
// Grid: (ceilDiv(num_tokens, BLOCK_SIZE), num_heads, batch_size), with BLOCK_SIZE threads per block.
// q:  half_t; q.shape[1] is the token count (presumably [B, tokens, ...] — confirm with caller).
// vk: FP32; vk.shape[0] is the batch size and vk.shape[1] the number of heads.
template<typename Config, bool USE_FP4>
void GEMM_W4A4_Launch<Config, USE_FP4>::linearattn_vk_mul_q(Tensor q, Tensor vk) {
    using Epilogue = typename Epilogues::EpilogueLiteLA;

    const int batch_size = vk.shape[0];
    const int num_heads  = vk.shape[1];
    const int num_tokens = q.shape[1];

    assert(isTypeMatch<half_t>(q.scalar_type()));
    assert(vk.scalar_type() == Tensor::FP32);

    // Use 256-thread blocks when the token count divides evenly, otherwise 128.
    const int BLOCK_SIZE = (num_tokens % 256 == 0) ? 256 : 128;

    // 1e-6f is forwarded to the kernel (presumably a numerical epsilon).
    invoke_kernel<typename Epilogue::vk_mul_q_kernel>
        <<<dim3(ceilDiv(num_tokens, BLOCK_SIZE), num_heads, batch_size), BLOCK_SIZE, 0, getCurrentCUDAStream()>>>(
            q.data_ptr<half_t>(), vk.data_ptr<float>(), 1e-6f, num_tokens);
    checkCUDA(cudaGetLastError());
}

// Quantizes `input` to packed 4-bit activations and, in the same launch, fills the
// LoRA down-projection activations.
//   input:        half_t, actualM rows x actualN columns (leading dims flattened).
//   output:       INT8 storage, M rows x N / 2, with M and N padded up to whole blocks.
//   oscales:      FP8_E4M3 with M * N / WARP_K * 4 entries when fp4, else half_t with M * N / WARP_K.
//   lora_down:    [N, rank], rank a multiple of 16.
//   lora_act_out: float [M, rank]; zeroed here before launch.
//   smooth:       optional smoothing factors; nullptr is passed when invalid.
//   fuse_glu:     when true, the quantized width is half of the input width.
//   fp4:          must equal the USE_FP4 template parameter (which selects the kernel).
template<typename Config, bool USE_FP4>
void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input,
                                                                    Tensor output,
                                                                    Tensor oscales,
                                                                    Tensor lora_down,
                                                                    Tensor lora_act_out,
                                                                    Tensor smooth,
                                                                    bool fuse_glu,
                                                                    bool fp4) {
    // The kernel is chosen by USE_FP4 while the scale checks below use `fp4`; they must agree.
    assert(fp4 == USE_FP4);

    const int actualM = input.numel() / input.shape[-1];
    const int actualN = input.shape[-1];

    const int M = ceilDiv(actualM, GEMM::BLOCK_M) * GEMM::BLOCK_M;
    const int N = ceilDiv(actualN / (fuse_glu ? 2 : 1), GEMM::BLOCK_N) * GEMM::BLOCK_N;

    assert(output.dtype() == Tensor::INT8);
    assert(output.numel() / output.shape[-1] == M);
    assert(output.shape[-1] == N / 2);

    if (fp4) {
        assert(oscales.dtype() == Tensor::FP8_E4M3);
        assert(oscales.numel() == M * N / GEMM::WARP_K * 4);
    } else {
        assert(isTypeMatch<half_t>(oscales.dtype()));
        assert(oscales.numel() == M * N / GEMM::WARP_K);
    }

    const int rank = lora_down.shape[1];

    assert(rank % 16 == 0);

    assert(lora_down.shape[0] == N);
    assert(lora_act_out.shape[0] == M);
    assert(lora_act_out.shape[1] == rank);

    lora_act_out.zero_();

    dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);

    dispatchBool(fuse_glu, [&]<bool FUSE_GLU>() {
        using kernel = typename GEMM::quantize_w4a4_fuse_lora_kernel<FUSE_GLU, USE_FP4>;

        auto func = invoke_kernel<kernel, typename kernel::Arguments>;

        checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));

        func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
            typename kernel::Arguments{
                .input         = input.data_ptr<half_t>(),
                .smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
                .output        = output.data_ptr<packed_act_t>(),
                .oscales       = oscales.data_ptr<typename kernel::oscales_t>(),
                .lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
                .lora_act      = lora_act_out.data_ptr<float>(),
                .lora_rank     = rank,
                .M             = M,
                .N             = N,
                .actualM       = actualM,
                .actualN       = actualN,
                .alwaysfalse   = false,
            });
        checkCUDA(cudaGetLastError());
    });
}

// Quantizes half_t activations `input` (M rows x K columns) into packed 4-bit values
// (`output`, INT8 storage, M x K / 2) with one half_t scale per WARP_K-wide group (`oscales`).
// Not implemented for the FP4 configuration.
template<typename Config, bool USE_FP4>
void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) {
    if constexpr (USE_FP4) {
        assert(false); // not implemented
        return;
    }

    const int M = input.numel() / input.shape[-1];
    const int K = input.shape[-1];

    assert(output.dtype() == Tensor::INT8);
    assert(output.numel() / output.shape[-1] == M);
    assert(output.shape[-1] == K / 2);

    assert(isTypeMatch<half_t>(oscales.dtype()));
    assert(oscales.numel() == M * K / GEMM::WARP_K);

    // The grid has one single-warp block per WARP_M x WARP_K tile; the integer division below
    // would silently drop any partial tile, so the sizes must divide evenly.
    assert(M % GEMM::WARP_M == 0);
    assert(K % GEMM::WARP_K == 0);

    dim3 grid(M / GEMM::WARP_M, K / GEMM::WARP_K);
    invoke_kernel<typename GEMM::quantize_w4a4_act_kernel><<<grid, GEMM::WARP_SIZE, 0, getCurrentCUDAStream()>>>(
        input.data_ptr<half_t>(), output.data_ptr<packed_act_t>(), oscales.data_ptr<packed_ascale_t>(), K);
    checkCUDA(cudaGetLastError());
}

// Quantizes half_t weights `input` (N rows x K columns) into packed 4-bit values
// (`output`, INT8 [N, K / 2]) with one half_t scale per WARP_K-wide group (`oscales`).
// Not implemented for the FP4 configuration.
template<typename Config, bool USE_FP4>
void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales) {
    if constexpr (USE_FP4) {
        assert(false);
        return;
    }

    const int N = input.numel() / input.shape[-1];
    const int K = input.shape[-1];

    assert(output.dtype() == Tensor::INT8);
    assert(output.ndims() == 2);
    assert(output.shape[0] == N);
    assert(output.shape[1] == K / 2);

    assert(isTypeMatch<half_t>(oscales.dtype()));
    assert(oscales.numel() == N * K / GEMM::WARP_K);

    // The grid has one single-warp block per WARP_N x WARP_K tile; the integer division below
    // would silently drop any partial tile, so the sizes must divide evenly.
    assert(N % GEMM::WARP_N == 0);
    assert(K % GEMM::WARP_K == 0);

    dim3 grid(N / GEMM::WARP_N, K / GEMM::WARP_K);
    invoke_kernel<typename GEMM::quantize_w4a4_wgt_kernel><<<grid, GEMM::WARP_SIZE, 0, getCurrentCUDAStream()>>>(
        input.data_ptr<half_t>(), output.data_ptr<packed_wgt_t>(), oscales.data_ptr<packed_wscale_t>(), K);
    checkCUDA(cudaGetLastError());
}
}; // namespace nunchaku::kernels