// gemm_w4a4_launch_impl.cuh
#include "gemm_w4a4_launch.cuh"

#include <stdexcept>

namespace nunchaku::kernels {

// Host-side launcher for the fused W4A4 GEMM.
//
// Computes the [M, N] product of 4-bit packed activations `act` ([M, K] logical) and 4-bit
// packed weights `wgt` ([N, K] logical), using per-group scales `ascales` / `wscales`.
// It then runs an epilogue chain chosen by which output tensors are valid, checked in this order:
//   1. qout && oscales : GEMM::EpilogueQuantize writes a quantized result for the next layer
//                        (with `smooth_factor`). If `out` is also valid, GEMM::EpilogueDefault
//                        writes the unquantized result there first. The mid epilogue is
//                        Epilogues::EpilogueGelu.
//   2. out_linearattn  : Epilogues::EpilogueLiteLA writes out_linearattn and out_vk
//                        (out_vk is zeroed before launch).
//   3. rotary_emb      : Epilogues::EpilogueRMSNormRope (norm_q / norm_k, eps 1e-6). Its output
//                        goes to out_q / out_k / out_v via EpiloguePackQKV when out_q is valid,
//                        otherwise to `out` via GEMM::EpilogueDefault.
//   4. out             : GEMM::EpilogueDefault; GEMM::EpilogueSilu is the mid epilogue when
//                        fuse_silu is set.
// If none of these is valid, std::runtime_error is thrown.
//
// Every path goes launch_lora -> launch_bias -> launch, so the final epilogue order is
//   EpilogueBias(bias, wcscales) -> [LoraUp] -> mid -> [LoraDown] -> next -> EpilogueNop.
// LoRA-up is included when lora_up / lora_act_in are valid. LoRA-down is included when
// lora_down / lora_act_out are also valid; lora_act_out is zeroed before launch.
//
// Kernel selection:
//   - USE_FP4 == false: GEMM::gemm_w4a4_kernel<Epilogue, act_unsigned>; alpha must be 1.0f.
//   - USE_FP4 == true:  GEMM::gemm_w4a4_fp4_kernel<Epilogue, alpha != 1.0f>, with FP8_E4M3
//                       scales; act_unsigned must be false.
// Launch shape:
//   - The runtime `fp4` flag must equal USE_FP4.
//   - M and N must be multiples of BLOCK_M / BLOCK_N. `out` may be smaller than M x N by less
//     than one block (actualM / actualN).
//   - The grid axes are swapped when M > 2N, and the kernel is told via swapBlockMN.
#ifndef __INTELLISENSE__
template<typename Config, bool USE_FP4>
void GEMM_W4A4_Launch<Config, USE_FP4>::gemm_w4a4(
#else
template<>
void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16, false>::gemm_w4a4(
#endif
    Tensor act,           // packed act [M, K / 2]
    Tensor wgt,           // packed wgt [N, K / 2]
    Tensor out,           // linear     [M, N]; may be smaller than M / N by less than one block
    Tensor qout,          // packed act [M, N / 2]
    Tensor ascales,       // packed as  [K / 64, M] (FP8_E4M3 when USE_FP4)
    Tensor wscales,       // packed ws  [K / 64, N] (FP8_E4M3 when USE_FP4)
    Tensor oscales,       // packed as  [N / 64, M]
    Tensor poolout,       // linear     [M / PoolSize, N]; currently unused (see commented-out QKVProj path)
    Tensor lora_act_in,   // packed lora_act [M, R_up], FP32
    Tensor lora_up,       // packed lora_wgt [N, R_up]
    Tensor lora_down,     // packed lora_wgt [N, R_down]
    Tensor lora_act_out,  // packed lora_act [M, R_down], FP32; zeroed before launch
    Tensor norm_q,        // linear     [HEAD_DIM]
    Tensor norm_k,        // linear     [HEAD_DIM]
    Tensor rotary_emb,    // FP32, 3-D  [S0, S1, HEAD_DIM] with S0 * S1 == M
    Tensor bias,          // packed ws  [N]
    Tensor smooth_factor, // packed ws  [N], for quantization of the next layer
    Tensor out_vk,        // linear     [B, num_heads, head_dim + 1, head_dim], FP32
    Tensor out_linearattn,// linear     [B, (M), N / 3]
    bool act_unsigned,    // non-FP4 only: selects the unsigned-activation kernel variant
    std::vector<float> lora_scales,  // [R / 16]; missing entries are treated as 0
    bool fuse_silu,       // plain-`out` path only: use GEMM::EpilogueSilu as mid epilogue
    bool fp4,             // must equal USE_FP4
    float alpha,          // FP4 only; must be 1.0f for the INT4 path
    Tensor wcscales,       // packed ws  [N]
    Tensor out_q,          // packed attention [B, H, M, D]
    Tensor out_k,          // packed attention [B, H, M, D]
    Tensor out_v,          // packed attention [B, H, M, D]
    int attn_tokens        // actualM passed to EpiloguePackQKV
) {
#ifdef __INTELLISENSE__
    static constexpr bool USE_FP4 = false;
#endif
    assert(fp4 == USE_FP4);

    int M = act.numel() / act.shape[-1];
    int N = wgt.shape[0];
    int K = act.shape[-1] * 2;
    assert(K == wgt.shape[1] * 2);

    // actualM / actualN describe the (possibly block-unaligned) extent of `out`.
    int actualM = 0;
    int actualN = 0;
    if (out.valid()) {
        actualM = out.numel() / out.shape[-1];
        actualN = out.shape[-1];

        assert(actualM <= M && M - actualM < GEMM::BLOCK_M);
        assert(actualN <= N && N - actualN < GEMM::BLOCK_N);
    }

    spdlog::trace("gemm_w4a4: M={} N={} K={}", M, N, K);
    spdlog::trace("act at {}", act.data_ptr());
    spdlog::trace("wgt at {}", wgt.data_ptr());
    spdlog::trace("ascales at {}", ascales.data_ptr());
    spdlog::trace("wscales at {}", wscales.data_ptr());
    if (bias.valid()) {
        spdlog::trace("bias at {}", bias.data_ptr());
    }

    // Dynamic shared memory requested by the epilogue chain (raised by EpilogueLiteLA).
    int shmem = 0;

    // Final stage: picks the INT4 or FP4 kernel for the given epilogue type and launches it.
    auto launch = [&]<typename Epilogue>(Epilogue::Arguments args) {
        assert(M % GEMM::BLOCK_M == 0);
        assert(N % GEMM::BLOCK_N == 0);
        dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);

        // Tall problems swap the grid axes; the kernel is informed through swapBlockMN.
        bool swapBlockMN = M > N * 2;
        if (swapBlockMN) {
            std::swap(grid.x, grid.y);
        }

        // constexpr bool FP4_AVAILABLE = __CUDA_ARCH__ >= 1200;

        if constexpr (!USE_FP4) {
            dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() {
                auto func = invoke_kernel<typename GEMM::gemm_w4a4_kernel<Epilogue, ACT_UNSIGNED>, 
                    const packed_act_t *, 
                    const packed_wgt_t *, 
                    const packed_ascale_t *,
                    const packed_wscale_t *,
                    int, int, int,
                    typename Epilogue::Arguments,
                    bool,
                    bool>;

                // Opt in to larger dynamic shared memory only when it is actually needed.
                if (shmem >= 24 * 1024) {
                    checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
                }

                // The INT4 kernel takes no alpha argument.
                assert(alpha == 1.0f);

                func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(
                    act.data_ptr<packed_act_t>(),
                    wgt.data_ptr<packed_wgt_t>(),
                    ascales.data_ptr<packed_ascale_t>(),
                    wscales.data_ptr<packed_wscale_t>(),
                    M, N, K,
                    args,
                    swapBlockMN,
                    false
                );
                checkCUDA(cudaGetLastError());
            });
            return;
        }

        if constexpr (USE_FP4) {
            dispatchBool(alpha != 1.0f, [&]<bool USE_ALPHA>() {
                assert(!act_unsigned);

                auto func = invoke_kernel<typename GEMM::gemm_w4a4_fp4_kernel<Epilogue, USE_ALPHA>, 
                    const packed_act_t *, 
                    const packed_wgt_t *, 
                    const packed_amscale_t *,
                    const packed_wmscale_t *,
                    float,
                    int, int, int,
                    typename Epilogue::Arguments,
                    bool,
                    bool>;

                if (shmem >= 24 * 1024) {
                    checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
                }

                assert(ascales.dtype() == Tensor::FP8_E4M3);
                assert(wscales.dtype() == Tensor::FP8_E4M3);

                func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(
                    act.data_ptr<packed_act_t>(),
                    wgt.data_ptr<packed_wgt_t>(),
                    ascales.data_ptr<packed_amscale_t>(),
                    wscales.data_ptr<packed_wmscale_t>(),
                    alpha,
                    M, N, K,
                    args,
                    swapBlockMN,
                    false
                );
                checkCUDA(cudaGetLastError());
            });

            return;
        }

        // if constexpr (USE_FP4 && !FP4_AVAILABLE) {
        //     throw std::runtime_error("FP4 kernel is not available");
        // }
    };

    // Prepends EpilogueBias (bias and/or per-channel wcscales) to the given epilogue.
    auto launch_bias = [&]<typename NextEpilogue>(NextEpilogue::Arguments nextArgs) {
        assert(!bias.valid() || bias.numel() == N);
        assert(!wcscales.valid() || wcscales.numel() == N);

        dispatchBool(bias.valid(), [&]<bool USE_BIAS>() {
            dispatchBool(wcscales.valid(), [&]<bool USE_SCALE>() {
                using EpilogueBias = typename GEMM::EpilogueBias<USE_BIAS, USE_SCALE>;
                // append EpilogueNop to work around mismatched memory layout of std::tuple between device and host code on Windows
                // ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
                using Epilogue = typename GEMM::EpilogueCombination<EpilogueBias, NextEpilogue, typename GEMM::EpilogueNop>;
                return launch.template operator()<Epilogue>({
                    typename EpilogueBias::Arguments{
                        .bias = USE_BIAS ? bias.data_ptr<packed_wscale_t>() : nullptr,
                        .scale = USE_SCALE ? wcscales.data_ptr<packed_wscale_t>() : nullptr,
                    },
                    nextArgs,
                    {}
                });
            });
        });
    };

    // Wraps MidEpilogue / NextEpilogue with optional LoRA-up and LoRA-down epilogues:
    //   [LoraUp] -> MidEpilogue -> [LoraDown] -> NextEpilogue
    auto launch_lora = [&]<typename NextEpilogue, typename MidEpilogue>(NextEpilogue::Arguments nextArgs, MidEpilogue::Arguments midArgs) {
        assert(lora_up.valid() == lora_act_in.valid());
        assert(lora_down.valid() == lora_act_out.valid());

        const int rank_up = lora_up.valid() ? lora_up.shape[1] : 0;
        const int rank_down = lora_down.valid() ? lora_down.shape[1] : 0;

        if (rank_up == 0) {
            assert(rank_down == 0);
            return launch_bias.template operator()<typename GEMM::EpilogueCombination<MidEpilogue, NextEpilogue>>({midArgs, nextArgs});
        }

        assert(rank_up % 16 == 0);

        assert(lora_up.shape[0] == N);
        // assert(lora_up.shape[1] == Lora::LORA_RANK);
        assert(lora_act_in.shape[0] == M);
        assert(lora_act_in.shape[1] == rank_up);

        using LoraUp = Lora;
        using scale_t = typename LoraUp::scale_t;

        // Copy the per-16-rank scales; entries beyond lora_scales.size() are zero.
        scale_t scales;
        if constexpr (scales.size() > 0) {
            for (size_t i = 0; i < scales.size(); i++) {
                scales[i] = i < lora_scales.size() ? lora_scales[i] : 0.0f;
            }
        }

        if (rank_down == 0) {
            using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, NextEpilogue, typename GEMM::EpilogueNop>;
            return launch_bias.template operator()<Epilogue>({
                typename LoraUp::EpilogueLoraUp::Arguments{
                    .lora_act = lora_act_in.data_ptr<float>(),
                    .lora_wgt_up = lora_up.data_ptr<packed_fpsum_t>(),
                    .rank = rank_up,
                    .scales = scales,
                    .alwaysfalse = false,
                },
                midArgs,
                nextArgs,
                {}
            });
        }

        // assert(rank_down == rank_up);
        assert(rank_down % 16 == 0);

        assert(lora_down.shape[0] == N);
        // assert(lora_down.shape[1] == Lora::LORA_RANK);
        assert(lora_act_out.shape[0] == M);
        assert(lora_act_out.shape[1] == rank_down);

        // LoRA-down results are written into lora_act_out, which must start at zero.
        lora_act_out.zero_();

        using LoraDown = LoraUp; // GEMM::Lora<RANK_DOWN>;
        using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, typename LoraDown::EpilogueLoraDown, NextEpilogue, typename GEMM::EpilogueNop>;
        return launch_bias.template operator()<Epilogue>({
            typename LoraUp::EpilogueLoraUp::Arguments{
                .lora_act = lora_act_in.data_ptr<float>(),
                .lora_wgt_up = lora_up.data_ptr<packed_fpsum_t>(),
                .rank = rank_up,
                .scales = scales,
                .alwaysfalse = false,
            },
            midArgs,
            typename LoraDown::EpilogueLoraDown::Arguments{
                .lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
                .lora_act = lora_act_out.data_ptr<float>(),
                .rank = rank_down,
                .alwaysfalse = false,
            },
            nextArgs,
            {}
        });
    };

    if (qout.valid() && oscales.valid()) {

        static constexpr float SHIFT_GELU = 0.171875f;

        // INT4 output is quantized unsigned with a GELU shift; FP4 output is signed with no shift.
        constexpr bool USE_UNSIGNED = !USE_FP4;
        using EpilogueQuantize = typename GEMM::EpilogueQuantize<false, USE_UNSIGNED, USE_FP4>;
        auto argsQuantize = typename EpilogueQuantize::Arguments{
            .qout = qout.data_ptr<packed_act_t>(),
            .oscales = oscales.data_ptr<typename EpilogueQuantize::oscales_t>(),
            .shift_value = USE_FP4 ? 0.0f : SHIFT_GELU,
            .smooth_factor = smooth_factor.data_ptr<packed_wscale_t>()
        };

        // TODO: check if gelu is needed
        if (out.valid()) {
            launch_lora.template operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>, typename Epilogues::EpilogueGelu>({
                typename GEMM::EpilogueDefault::Arguments{
                    .out = out.data_ptr<half_t>(),
                    .actualM = actualM,
                    .actualN = actualN,
                },
                argsQuantize
            }, {});
        } else {
            launch_lora.template operator()<EpilogueQuantize, typename Epilogues::EpilogueGelu>(argsQuantize, {});
        }

    } else if (out_linearattn.valid()) {

        assert(out_vk.valid());

        using Epilogue = typename Epilogues::EpilogueLiteLA;

        assert(out_vk.dtype() == Tensor::FP32);
        assert(out_vk.ndims() == 4);
        assert(out_vk.shape[2] == Epilogue::LITELA_HEAD_DIM + 1);
        assert(out_vk.shape[3] == Epilogue::LITELA_HEAD_DIM);
        assert(out_vk.shape[1] * Epilogue::LITELA_HEAD_DIM * 3 == N);
        int batch_size = out_vk.shape[0];
        int num_heads = out_vk.shape[1];

        assert(isTypeMatch<half_t>(out_linearattn.dtype()));
        assert(out_linearattn.ndims() == 3);
        assert(out_linearattn.shape[0] == batch_size);
        assert(out_linearattn.shape[2] * 3 == N);
        int num_tokens = out_linearattn.shape[1];

        assert(num_tokens % GEMM::BLOCK_M == 0);
        int num_blocks_per_batch = ceilDiv(num_tokens, GEMM::BLOCK_M);

        shmem = std::max(shmem, Epilogue::SHMEM_SIZE);

        // The epilogue writes out_vk, which must start at zero.
        out_vk.zero_();

        launch_lora.template operator()<Epilogue, typename GEMM::EpilogueNop>(typename Epilogue::Arguments{
            .out_q = out_linearattn.data_ptr<half_t>(),
            .out_vk = out_vk.data_ptr<float>(),
            .num_blocks_per_batch = num_blocks_per_batch,
            .actualM = M,
        }, {});

    } else if (rotary_emb.valid()) {
        assert(norm_q.valid());
        assert(norm_k.valid());
        // assert(isTypeMatch<half_t>(rotary_emb.scalar_type()));
        assert(rotary_emb.scalar_type() == Tensor::FP32);
        assert(rotary_emb.ndims() == 3);
        assert(rotary_emb.shape[0] * rotary_emb.shape[1] == M);
        assert(rotary_emb.shape[2] == Epilogues::EpilogueRMSNormRope::HEAD_DIM);

        // assert(rotary_emb.numel() == M * GEMM::EpilogueQKVProj::HEAD_DIM / 2 * GEMM::EpilogueQKVProj::ROTARY_EMB_NUM_ELEMENTS);
        // launch_lora.template operator()<typename GEMM::EpilogueQKVProj, typename GEMM::EpilogueNop>(typename GEMM::EpilogueQKVProj::Arguments{
        //     .out = out.data_ptr<half_t>(),
        //     .actualM = actualM,
        //     .actualN = actualN,
        //     .pool_out = poolout.valid() ? poolout.data_ptr<half_t>() : nullptr,
        //     .rotary_emb = rotary_emb.data_ptr<float>(),
        //     .rmsnorm_weight_q = norm_q.data_ptr<half_t>(),
        //     .rmsnorm_weight_k = norm_k.data_ptr<half_t>(),
        //     .epsilon = 1e-6,
        // }, {});

        using EpilogueRope = typename Epilogues::EpilogueRMSNormRope;
        auto argsRope = typename Epilogues::EpilogueRMSNormRope::Arguments{
            .rotary_emb = rotary_emb.data_ptr<typename EpilogueRope::packed_rotemb_t>(),
            .rmsnorm_weight_q = norm_q.data_ptr<half_t>(),
            .rmsnorm_weight_k = norm_k.data_ptr<half_t>(),
            .epsilon = 1e-6,
        };

        if (out_q.valid()) {
            // EpiloguePackQKV dereferences all three outputs.
            assert(out_k.valid());
            assert(out_v.valid());

            launch_lora.template operator()<typename GEMM::EpilogueCombination<EpilogueRope, typename Epilogues::EpiloguePackQKV>, typename GEMM::EpilogueNop>({
                argsRope,
                typename Epilogues::EpiloguePackQKV::Arguments{
                    .out_q = out_q.data_ptr<typename Epilogues::EpiloguePackQKV::packed_qkv_t>(),
                    .out_k = out_k.data_ptr<typename Epilogues::EpiloguePackQKV::packed_qkv_t>(),
                    .out_v = out_v.data_ptr<typename Epilogues::EpiloguePackQKV::packed_qkv_t>(),
                    .actualM = attn_tokens,
                    // Head strides converted from scalar elements to packed_qkv_t units.
                    .strideHead_q = int(out_q.stride(1) * out_q.scalar_size() / sizeof(typename Epilogues::EpiloguePackQKV::packed_qkv_t)),
                    .strideHead_k = int(out_k.stride(1) * out_k.scalar_size() / sizeof(typename Epilogues::EpiloguePackQKV::packed_qkv_t)),
                    .strideHead_v = int(out_v.stride(1) * out_v.scalar_size() / sizeof(typename Epilogues::EpiloguePackQKV::packed_qkv_t)),
                }
            }, {});
        } else {
            // Without packed QKV outputs the result must go to `out`; actualM/actualN are
            // only meaningful when `out` is valid.
            assert(out.valid());

            launch_lora.template operator()<typename GEMM::EpilogueCombination<EpilogueRope, typename GEMM::EpilogueDefault>, typename GEMM::EpilogueNop>({
                argsRope,
                typename GEMM::EpilogueDefault::Arguments{
                    .out = out.data_ptr<half_t>(),
                    .actualM = actualM,
                    .actualN = actualN,
                }
            }, {});
        }

    } else if (out.valid()) {

        using Epilogue = typename GEMM::EpilogueDefault;
        typename Epilogue::Arguments args{
            .out = out.data_ptr<half_t>(),
            .actualM = actualM,
            .actualN = actualN,
        };

        if (fuse_silu) {
            launch_lora.template operator()<Epilogue, typename GEMM::EpilogueSilu>(args, {});
        } else {
            launch_lora.template operator()<Epilogue, typename GEMM::EpilogueNop>(args, {});
        }
    } else {
        // assert(false) vanishes under NDEBUG and would silently skip the GEMM; fail loudly instead.
        throw std::runtime_error("gemm_w4a4: no valid output tensor (qout+oscales, out_linearattn, rotary_emb or out)");
    }
}

// Launches Epilogues::EpilogueLiteLA::vk_mul_q_kernel on the current CUDA stream.
//
// q  : half_t tensor; q.shape[1] is used as the number of tokens.
// vk : FP32 tensor; vk.shape[0] is the batch size and vk.shape[1] the number of heads.
//
// Grid is (ceilDiv(num_tokens, BLOCK_SIZE), num_heads, batch_size), with BLOCK_SIZE threads
// per block. BLOCK_SIZE is 256 when num_tokens is a multiple of 256, otherwise 128.
// The constant 1e-6f is passed to the kernel (presumably a division epsilon; TODO confirm
// against vk_mul_q_kernel).
template<typename Config, bool USE_FP4>
void GEMM_W4A4_Launch<Config, USE_FP4>::linearattn_vk_mul_q(Tensor q, Tensor vk) {
    using Epilogue = typename Epilogues::EpilogueLiteLA;

    const int batch_size = vk.shape[0];
    const int num_heads = vk.shape[1];
    const int num_tokens = q.shape[1];

    assert(isTypeMatch<half_t>(q.scalar_type()));
    assert(vk.scalar_type() == Tensor::FP32);

    const int BLOCK_SIZE = (num_tokens % 256 == 0) ? 256 : 128;
    const dim3 grid(ceilDiv(num_tokens, BLOCK_SIZE), num_heads, batch_size);

    invoke_kernel<typename Epilogue::vk_mul_q_kernel><<<grid, BLOCK_SIZE, 0, getCurrentCUDAStream()>>>(
        q.data_ptr<half_t>(),
        vk.data_ptr<float>(),
        1e-6f,
        num_tokens
    );
    checkCUDA(cudaGetLastError());
}

// Quantizes `input` into packed 4-bit activations and, in the same kernel, computes the
// LoRA-down projection into `lora_act_out`.
//
// input        : half_t, flattened to [actualM, actualN].
// output       : INT8 holding packed 4-bit values, [M, N / 2]. M and N are actualM and the
//                channel count, each rounded up to BLOCK_M / BLOCK_N.
// oscales      : FP8_E4M3 with M * N / WARP_K * 4 elements when fp4, otherwise half_t with
//                M * N / WARP_K elements.
// lora_down    : [N, rank], with rank % 16 == 0.
// lora_act_out : FP32 [M, rank]; zeroed before launch.
// smooth       : optional; the kernel receives nullptr when it is invalid.
// fuse_glu     : when true, N is derived from half of the input width (presumably a fused
//                GLU gate; TODO confirm in quantize_w4a4_fuse_lora_kernel).
// fp4          : must equal USE_FP4, which is what actually selects the kernel variant.
//
// The kernel runs on a (M / BLOCK_M, N / BLOCK_N) grid, with WARP_SIZE * NUM_WARPS threads
// and kernel::SHMEM_SIZE bytes of dynamic shared memory.
template<typename Config, bool USE_FP4>
void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) {
    // The kernel is chosen by the USE_FP4 template parameter, while `fp4` picks which oscales
    // layout is validated below. A mismatch would validate one layout and launch the other.
    assert(fp4 == USE_FP4);

    const int actualM = input.numel() / input.shape[-1];
    const int actualN = input.shape[-1];

    const int M = ceilDiv(actualM, GEMM::BLOCK_M) * GEMM::BLOCK_M;
    const int N = ceilDiv(actualN / (fuse_glu ? 2 : 1), GEMM::BLOCK_N) * GEMM::BLOCK_N;

    assert(output.dtype() == Tensor::INT8);
    assert(output.numel() / output.shape[-1] == M);
    assert(output.shape[-1] == N / 2);

    if (fp4) {
        assert(oscales.dtype() == Tensor::FP8_E4M3);
        assert(oscales.numel() == M * N / GEMM::WARP_K * 4);
    } else {
        assert(isTypeMatch<half_t>(oscales.dtype()));
        assert(oscales.numel() == M * N / GEMM::WARP_K);
    }

    const int rank = lora_down.shape[1];

    assert(rank % 16 == 0);

    assert(lora_down.shape[0] == N);
    // assert(lora_down.shape[1] == Lora::LORA_RANK);
    assert(lora_act_out.shape[0] == M);
    assert(lora_act_out.shape[1] == rank);

    // The kernel writes its LoRA-down results into lora_act_out, which must start at zero.
    lora_act_out.zero_();

    dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);

    dispatchBool(fuse_glu, [&]<bool FUSE_GLU>() {
        using kernel = typename GEMM::quantize_w4a4_fuse_lora_kernel<FUSE_GLU, USE_FP4>;

        auto func = invoke_kernel<kernel, typename kernel::Arguments>;

        checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));

        func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE, getCurrentCUDAStream()>>>(
            typename kernel::Arguments{
                .input = input.data_ptr<half_t>(),
                .smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
                .output = output.data_ptr<packed_act_t>(),
                .oscales = oscales.data_ptr<typename kernel::oscales_t>(),
                .lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
                .lora_act = lora_act_out.data_ptr<float>(),
                .lora_rank = rank,
                .M = M,
                .N = N,
                .actualM = actualM,
                .actualN = actualN,
                .alwaysfalse = false,
            }
        );
        checkCUDA(cudaGetLastError());
    });
}

// Quantizes half_t activations `input` (flattened to [M, K]) into packed 4-bit values.
//
// output  : INT8 [M, K / 2]
// oscales : half_t with M * K / WARP_K elements
//
// The kernel runs on a (M / WARP_M, K / WARP_K) grid with one warp (WARP_SIZE threads) per
// block, so M and K must divide evenly; otherwise trailing rows/groups would be skipped.
// Only the INT4 path exists: the FP4 instantiation throws std::runtime_error. A debug-only
// assert would silently leave `output` untouched in release builds.
template<typename Config, bool USE_FP4>
void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) {
    if constexpr (USE_FP4) {
        throw std::runtime_error("quantize_w4a4_act: FP4 is not implemented");
    }

    int M = input.numel() / input.shape[-1];
    int K = input.shape[-1];

    assert(output.dtype() == Tensor::INT8);
    assert(output.numel() / output.shape[-1] == M);
    assert(output.shape[-1] == K / 2);

    // assert(oscales.dtype() == Tensor::FP16);
    assert(isTypeMatch<half_t>(oscales.dtype()));
    assert(oscales.numel() == M * K / GEMM::WARP_K);

    // The grid below uses truncating division.
    assert(M % GEMM::WARP_M == 0);
    assert(K % GEMM::WARP_K == 0);

    dim3 grid(M / GEMM::WARP_M, K / GEMM::WARP_K);
    invoke_kernel<typename GEMM::quantize_w4a4_act_kernel><<<grid, GEMM::WARP_SIZE, 0, getCurrentCUDAStream()>>>(
        input.data_ptr<half_t>(),
        output.data_ptr<packed_act_t>(),
        oscales.data_ptr<packed_ascale_t>(),
        K
    );
    checkCUDA(cudaGetLastError());
}

// Quantizes half_t weights `input` (flattened to [N, K]) into packed 4-bit values.
//
// output  : INT8, 2-D [N, K / 2]
// oscales : half_t with N * K / WARP_K elements
//
// The kernel runs on a (N / WARP_N, K / WARP_K) grid with one warp (WARP_SIZE threads) per
// block, so N and K must divide evenly; otherwise trailing rows/groups would be skipped.
// Only the INT4 path exists: the FP4 instantiation throws std::runtime_error. A debug-only
// assert would silently leave `output` untouched in release builds.
template<typename Config, bool USE_FP4>
void GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales) {
    if constexpr (USE_FP4) {
        throw std::runtime_error("quantize_w4a4_wgt: FP4 is not implemented");
    }

    int N = input.numel() / input.shape[-1];
    int K = input.shape[-1];

    assert(output.dtype() == Tensor::INT8);
    assert(output.ndims() == 2);
    assert(output.shape[0] == N);
    assert(output.shape[1] == K / 2);

    assert(isTypeMatch<half_t>(oscales.dtype()));
    // assert(oscales.dtype() == Tensor::FP16);
    assert(oscales.numel() == N * K / GEMM::WARP_K);

    // The grid below uses truncating division.
    assert(N % GEMM::WARP_N == 0);
    assert(K % GEMM::WARP_K == 0);

    dim3 grid(N / GEMM::WARP_N, K / GEMM::WARP_K);
    invoke_kernel<typename GEMM::quantize_w4a4_wgt_kernel><<<grid, GEMM::WARP_SIZE, 0, getCurrentCUDAStream()>>>(
        input.data_ptr<half_t>(),
        output.data_ptr<packed_wgt_t>(),
        oscales.data_ptr<packed_wscale_t>(),
        K
    );
    checkCUDA(cudaGetLastError());
}

};  // namespace nunchaku::kernels