// gemm_w4a4_launch_impl.cuh
#include "gemm_w4a4_launch.cuh"

namespace nunchaku::kernels {

// Host-side entry point that launches the W4A4 GEMM kernel together with an
// epilogue chain assembled from whichever optional tensors are valid.
//
// Problem size is derived from the inputs:
//   M = act.numel() / act.shape[-1],  N = wgt.shape[0],  K = 2 * act.shape[-1].
// M and N must be multiples of GEMM::BLOCK_M / GEMM::BLOCK_N (asserted right
// before the launch).  When `out` is valid it may be smaller than that extent
// by less than one block in each dimension; its real size is forwarded to the
// epilogue as actualM / actualN.
//
// Epilogue selection (first matching branch wins):
//   1. qout && oscales valid -> EpilogueQuantize with EpilogueGelu as the mid
//                               epilogue; EpilogueDefault is prepended when
//                               `out` is valid as well.
//   2. out_linearattn valid  -> EpilogueLiteLA (requires out_vk; out_vk is zeroed).
//   3. rotary_emb valid      -> EpilogueQKVProj (requires norm_q and norm_k).
//   4. out valid             -> EpilogueDefault, with EpilogueSilu as the mid
//                               epilogue when fuse_silu is set.
//   5. otherwise             -> assert(false).
// Every chain is wrapped by the optional LoRA up/down epilogues (enabled by the
// validity of lora_up / lora_down) and by EpilogueBias (bias / wcscales).
//
// `fp4` selects gemm_w4a4_fp4_kernel instead of gemm_w4a4_kernel.  `alpha` is
// only forwarded on the FP4 path; the non-FP4 path asserts alpha == 1.0f.
// `act_unsigned` is only honoured on the non-FP4 path (the FP4 path asserts it
// is false).
#ifndef __INTELLISENSE__
template<typename Config>
void GEMM_W4A4_Launch<Config>::gemm_w4a4(
#else
template<>
void GEMM_W4A4_Launch<GEMMConfig_W4A4_FP16>::gemm_w4a4(
#endif
    Tensor act,           // packed act [M, K / 2]
    Tensor wgt,           // packed wgt [N, K / 2]
    Tensor out,           // linear     [M, N]
    Tensor qout,          // packed act [M, N / 2]
    Tensor ascales,       // packed as  [K / 64, M]
    Tensor wscales,       // packed ws  [K / 64, N]
    Tensor oscales,       // packed as  [N / 64, M]
    Tensor poolout,       // linear     [M / PoolSize, N]
    Tensor lora_act_in,   // packed lora_act [M, R]
    Tensor lora_up,       // packed lora_wgt [N, R]
    Tensor lora_down,     // packed lora_wgt [N, R]
    Tensor lora_act_out,  // packed lora_act [M, R]
    Tensor norm_q,        // linear     [HEAD_DIM]
    Tensor norm_k,        // linear     [HEAD_DIM]
    Tensor rotary_emb,    // linear     [M, HEAD_DIM / 2, 2, 2]
    Tensor bias,          // packed ws  [N]
    Tensor smooth_factor, // packed ws  [N], for quantization of the next layer
    Tensor out_vk,        // linear     [B, num_heads, head_dim + 1, head_dim]
    Tensor out_linearattn,// linear     [B, (M), N / 3]
    bool act_unsigned,
    std::vector<float> lora_scales,  // [R / 16]
    bool fuse_silu,
    bool fp4,
    float alpha,
    Tensor wcscales       // packed ws  [N]
) {
    // Two 4-bit values are packed per stored element, hence the factor of 2 on K.
    int M = act.numel() / act.shape[-1];
    int N = wgt.shape[0];
    int K = act.shape[-1] * 2;
    assert(K == wgt.shape[1] * 2);

    // Unpadded output extent; stays 0 when `out` is not provided.
    int actualM = 0;
    int actualN = 0;
    if (out.valid()) {
        actualM = out.numel() / out.shape[-1];
        actualN = out.shape[-1];

        // `out` may be smaller than the GEMM extent, but by less than one block.
        assert(actualM <= M && M - actualM < GEMM::BLOCK_M);
        assert(actualN <= N && N - actualN < GEMM::BLOCK_N);
    }

    spdlog::trace("gemm_w4a4: M={} N={} K={}", M, N, K);
    spdlog::trace("act at {}", act.data_ptr());
    spdlog::trace("wgt at {}", wgt.data_ptr());
    spdlog::trace("ascales at {}", ascales.data_ptr());
    spdlog::trace("wscales at {}", wscales.data_ptr());
    if (bias.valid()) {
        spdlog::trace("bias at {}", bias.data_ptr());
    }

    // Dynamic shared memory requested for the launch.  Captured by reference by
    // the lambdas below, so a branch may raise it before calling launch_lora.
    int shmem = 0;

    // Innermost stage: picks the kernel variant (FP4 vs. non-FP4) and launches it
    // with the fully assembled epilogue arguments.
    auto launch = [&]<typename Epilogue>(Epilogue::Arguments args) {
        assert(M % GEMM::BLOCK_M == 0);
        assert(N % GEMM::BLOCK_N == 0);
        dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);

        // For M > 2N the two grid dimensions are exchanged; the kernel is told
        // through the swapBlockMN flag.
        bool swapBlockMN = M > N * 2;
        if (swapBlockMN) {
            std::swap(grid.x, grid.y);
        }

        dispatchBool(fp4, [&]<bool USE_FP4>() {
            // test_sizeof<typename Epilogue::Arguments>();
            // std::apply([](auto ...args) {
            //     (test_sizeof<decltype(args)>(), ...);
            // }, args);

            // constexpr bool FP4_AVAILABLE = __CUDA_ARCH__ >= 1200;

            if constexpr (!USE_FP4) {
                dispatchBool(act_unsigned, [&]<bool ACT_UNSIGNED>() {
                    auto func = invoke_kernel<typename GEMM::gemm_w4a4_kernel<Epilogue, ACT_UNSIGNED>, 
                        const packed_act_t *, 
                        const packed_wgt_t *, 
                        const packed_ascale_t *,
                        const packed_wscale_t *,
                        int, int, int,
                        typename Epilogue::Arguments,
                        bool,
                        bool>;

                    if (shmem >= 24 * 1024) {
                        checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
                    }

                    // alpha is not passed to this kernel, so only 1.0f is accepted.
                    assert(alpha == 1.0f);
                    
                    func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem>>>(
                        act.data_ptr<packed_act_t>(),
                        wgt.data_ptr<packed_wgt_t>(),
                        ascales.data_ptr<packed_ascale_t>(),
                        wscales.data_ptr<packed_wscale_t>(),
                        M, N, K,
                        args,
                        swapBlockMN,
                        false  // NOTE(review): meaning of this flag is defined by the kernel, not visible here
                    );
                    checkCUDA(cudaGetLastError());
                });
                return;
            }

            if constexpr (USE_FP4) {
                dispatchBool(alpha != 1.0f, [&]<bool USE_ALPHA>() {
                    assert(!act_unsigned);

                    auto func = invoke_kernel<typename GEMM::gemm_w4a4_fp4_kernel<Epilogue, USE_ALPHA>, 
                        const packed_act_t *, 
                        const packed_wgt_t *, 
                        const packed_amscale_t *,
                        const packed_wmscale_t *,
                        float,
                        int, int, int,
                        typename Epilogue::Arguments,
                        bool,
                        bool>;

                    if (shmem >= 24 * 1024) {
                        checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
                    }

                    // The FP4 kernel reads its scales as FP8 (E4M3).
                    assert(ascales.dtype() == Tensor::FP8_E4M3);
                    assert(wscales.dtype() == Tensor::FP8_E4M3);
                    
                    func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem>>>(
                        act.data_ptr<packed_act_t>(),
                        wgt.data_ptr<packed_wgt_t>(),
                        ascales.data_ptr<packed_amscale_t>(),
                        wscales.data_ptr<packed_wmscale_t>(),
                        alpha,
                        M, N, K,
                        args,
                        swapBlockMN,
                        false  // NOTE(review): meaning of this flag is defined by the kernel, not visible here
                    );
                    checkCUDA(cudaGetLastError());
                });
                
                return;
            }

            // if constexpr (USE_FP4 && !FP4_AVAILABLE) {
            //     throw std::runtime_error("FP4 kernel is not available");
            // }
        });
    };

    // Middle stage: puts EpilogueBias (optional bias and per-channel scale) in
    // front of the supplied epilogue and forwards to `launch`.
    auto launch_bias = [&]<typename NextEpilogue>(NextEpilogue::Arguments nextArgs) {
        assert(!bias.valid() || bias.numel() == N);
        assert(!wcscales.valid() || wcscales.numel() == N);

        dispatchBool(bias.valid(), [&]<bool USE_BIAS>() {
            dispatchBool(wcscales.valid(), [&]<bool USE_SCALE>() {
                using EpilogueBias = typename GEMM::EpilogueBias<USE_BIAS, USE_SCALE>;
                // append EpilogueNop to workaround mismatched memory layout of std::tuple between device and host code on Windows
                // ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
                using Epilogue = typename GEMM::EpilogueCombination<EpilogueBias, NextEpilogue, typename GEMM::EpilogueNop>;
                return launch.template operator()<Epilogue>({
                    typename EpilogueBias::Arguments{
                        .bias = USE_BIAS ? bias.data_ptr<packed_wscale_t>() : nullptr,
                        .scale = USE_SCALE ? wcscales.data_ptr<packed_wscale_t>() : nullptr,
                    },
                    nextArgs,
                    {}
                });
            });
        });
    };
    // auto launch_bias = launch;

    // Outer stage: wraps MidEpilogue / NextEpilogue with the LoRA epilogues.
    //   - no lora_up              -> Mid, Next
    //   - lora_up only            -> LoraUp, Mid, Next
    //   - lora_up and lora_down   -> LoraUp, Mid, LoraDown, Next
    // (listed in the order they are passed to EpilogueCombination).
    auto launch_lora = [&]<typename NextEpilogue, typename MidEpilogue>(NextEpilogue::Arguments nextArgs, MidEpilogue::Arguments midArgs) {
        assert(lora_up.valid() == lora_act_in.valid());
        assert(lora_down.valid() == lora_act_out.valid());

        if (!lora_up.valid()) {
            // LoRA-down without LoRA-up is not supported.
            assert(!lora_down.valid());
            return launch_bias.template operator()<typename GEMM::EpilogueCombination<MidEpilogue, NextEpilogue>>({midArgs, nextArgs});
        }

        const int rank_up = lora_up.shape[1];

        assert(lora_up.shape[0] == N);
        // assert(lora_up.shape[1] == Lora::LORA_RANK);
        assert(lora_act_in.shape[0] == M);
        assert(lora_act_in.shape[1] == rank_up);

        // The rank is a template parameter of GEMM::Lora, so dispatch over the
        // set of ranks listed in LoraRanks.
        dispatchVal(rank_up, LoraRanks(), [&]<int RANK_UP>() {
            using LoraUp = typename GEMM::Lora<RANK_UP>;
            using scale_t = typename LoraUp::scale_t;

            // Copy the leading scales.size() entries of lora_scales into the
            // fixed-size container expected by the epilogue.
            scale_t scales;
            if constexpr (scales.size() > 0) {
                assert(lora_scales.size() >= scales.size());
                for (size_t i = 0; i < scales.size(); i++) {
                    scales[i] = lora_scales[i];
                }
            }

            if (!lora_down.valid()) {
                using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, NextEpilogue, typename GEMM::EpilogueNop>;
                return launch_bias.template operator()<Epilogue>({
                    typename LoraUp::EpilogueLoraUp::Arguments{
                        .lora_act = lora_act_in.data_ptr<float>(),
                        .lora_wgt_up = lora_up.data_ptr<packed_fpsum_t>(),
                        .scales = scales,
                    },
                    midArgs,
                    nextArgs,
                    {}
                });
            }

            const int rank_down = lora_down.shape[1];

            // Up and down projections currently have to share the same rank.
            assert(rank_down == rank_up);

            assert(lora_down.shape[0] == N);
            // assert(lora_down.shape[1] == Lora::LORA_RANK);
            assert(lora_act_out.shape[0] == M);
            assert(lora_act_out.shape[1] == rank_down);

            // Cleared before the launch; presumably the epilogue accumulates
            // into it -- confirm against EpilogueLoraDown.
            lora_act_out.zero_();

            // dispatchVal(rank_down, std::integer_sequence<int, 16, 32, 48, 64, 80>(), [&]<int RANK_DOWN>() {

            using LoraDown = LoraUp; // GEMM::Lora<RANK_DOWN>;
            using Epilogue = typename GEMM::EpilogueCombination<typename LoraUp::EpilogueLoraUp, MidEpilogue, typename LoraDown::EpilogueLoraDown, NextEpilogue, typename GEMM::EpilogueNop>;
            return launch_bias.template operator()<Epilogue>({
                typename LoraUp::EpilogueLoraUp::Arguments{
                    .lora_act = lora_act_in.data_ptr<float>(),
                    .lora_wgt_up = lora_up.data_ptr<packed_fpsum_t>(),
                    .scales = scales,
                },
                midArgs,
                typename LoraDown::EpilogueLoraDown::Arguments{
                    .lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
                    .lora_act = lora_act_out.data_ptr<float>(),
                },
                nextArgs,
                {}
            });

            // });
        });
    };

    if (qout.valid() && oscales.valid()) {
        // Branch 1: GELU + quantization of the result for the next layer.

        // dispatchBool(qout_unsigned, [&]<bool USE_UNSIGNED>() {

        static constexpr float SHIFT_GELU = 0.171875f;

        dispatchBool(fp4, [&]<bool USE_FP4>() {
            // Non-FP4 output is quantized as unsigned with the GELU shift applied;
            // FP4 output is signed and unshifted.
            constexpr bool USE_UNSIGNED = !USE_FP4;
            using EpilogueQuantize = typename GEMM::EpilogueQuantize<false, USE_UNSIGNED, USE_FP4>;
            auto argsQuantize = typename EpilogueQuantize::Arguments{
                .qout = qout.data_ptr<packed_act_t>(),
                .oscales = oscales.data_ptr<typename EpilogueQuantize::oscales_t>(),
                .shift_value = USE_FP4 ? 0.0f : SHIFT_GELU,
                // NOTE(review): unlike poolout/bias this pointer is not guarded by
                // valid(); confirm smooth_factor is always provided on this path.
                .smooth_factor = smooth_factor.data_ptr<packed_wscale_t>()
            };

            // TODO: check if gelu is needed
            if (out.valid()) {
                launch_lora.template operator()<typename GEMM::EpilogueCombination<typename GEMM::EpilogueDefault, EpilogueQuantize>, typename GEMM::EpilogueGelu>({
                    typename GEMM::EpilogueDefault::Arguments{
                        .out = out.data_ptr<half_t>(),
                        .actualM = actualM,
                        .actualN = actualN,
                    },
                    argsQuantize
                }, {});
            } else {
                launch_lora.template operator()<EpilogueQuantize, typename GEMM::EpilogueGelu>(argsQuantize, {});
            }
        });

    } else if (out_linearattn.valid()) {
        // Branch 2: linear-attention epilogue.

        assert(out_vk.valid());

        using Epilogue = typename GEMM::EpilogueLiteLA;

        assert(out_vk.dtype() == Tensor::FP32);
        assert(out_vk.ndims() == 4);
        assert(out_vk.shape[2] == Epilogue::LITELA_HEAD_DIM + 1);
        assert(out_vk.shape[3] == Epilogue::LITELA_HEAD_DIM);
        // N holds three head-dim-wide slices per head.
        assert(out_vk.shape[1] * Epilogue::LITELA_HEAD_DIM * 3 == N);
        int batch_size = out_vk.shape[0];

        assert(isTypeMatch<half_t>(out_linearattn.dtype()));
        assert(out_linearattn.ndims() == 3);
        assert(out_linearattn.shape[0] == batch_size);
        assert(out_linearattn.shape[2] * 3 == N);
        int num_tokens = out_linearattn.shape[1];

        assert(num_tokens % GEMM::BLOCK_M == 0);
        int num_blocks_per_batch = ceilDiv(num_tokens, GEMM::BLOCK_M);

        shmem = std::max(shmem, Epilogue::SHMEM_SIZE);

        // Cleared before the launch; presumably the epilogue accumulates into
        // it -- confirm against EpilogueLiteLA.
        out_vk.zero_();

        launch_lora.template operator()<Epilogue, typename GEMM::EpilogueNop>(typename Epilogue::Arguments{
            .out_q = out_linearattn.data_ptr<half_t>(),
            .out_vk = out_vk.data_ptr<float>(),
            .num_blocks_per_batch = num_blocks_per_batch,
            .actualM = M,
        }, {});

    } else if (rotary_emb.valid()) {
        // Branch 3: QKV projection with RMS-norm weights and rotary embedding.
        assert(norm_q.valid());
        assert(norm_k.valid());
        // assert(isTypeMatch<half_t>(rotary_emb.scalar_type()));
        assert(rotary_emb.scalar_type() == Tensor::FP32);
        assert(rotary_emb.numel() == M * GEMM::EpilogueQKVProj::HEAD_DIM / 2 * GEMM::EpilogueQKVProj::ROTARY_EMB_NUM_ELEMENTS);
        launch_lora.template operator()<typename GEMM::EpilogueQKVProj, typename GEMM::EpilogueNop>(typename GEMM::EpilogueQKVProj::Arguments{
            .out = out.data_ptr<half_t>(),
            .actualM = actualM,
            .actualN = actualN,
            .pool_out = poolout.valid() ? poolout.data_ptr<half_t>() : nullptr,
            .rotary_emb = rotary_emb.data_ptr<float>(),
            .rmsnorm_weight_q = norm_q.data_ptr<half_t>(),
            .rmsnorm_weight_k = norm_k.data_ptr<half_t>(),
            .epsilon = 1e-6,
        }, {});
    } else if (out.valid()) {
        // Branch 4: plain output, optionally with a fused SiLU.

        using Epilogue = typename GEMM::EpilogueDefault;
        typename Epilogue::Arguments args{
            .out = out.data_ptr<half_t>(),
            .actualM = actualM,
            .actualN = actualN,
        };

        if (fuse_silu) {
            launch_lora.template operator()<Epilogue, typename GEMM::EpilogueSilu>(args, {});
        } else {
            launch_lora.template operator()<Epilogue, typename GEMM::EpilogueNop>(args, {});
        }
    } else {
        // No output tensor was supplied at all.
        assert(false);
    }
}

// Launches Epilogue::vk_mul_q_kernel over `q` using `vk`.
//
// Grid layout: x covers the tokens (q.shape[1]) in chunks of one thread block,
// y is vk.shape[1] (heads) and z is vk.shape[0] (batch).  `q` must match
// half_t and `vk` must be FP32.
template<typename Config>
void GEMM_W4A4_Launch<Config>::linearattn_vk_mul_q(Tensor q, Tensor vk) {
    using Epilogue = typename GEMM::EpilogueLiteLA;

    int batch_size = vk.shape[0];
    int num_heads = vk.shape[1];
    int num_tokens = q.shape[1];

    assert(isTypeMatch<half_t>(q.scalar_type()));
    assert(vk.scalar_type() == Tensor::FP32);

    // 256 threads per block when the token count divides evenly, otherwise 128.
    int threadsPerBlock = (num_tokens % 256 == 0) ? 256 : 128;

    dim3 grid(ceilDiv(num_tokens, threadsPerBlock), num_heads, batch_size);

    invoke_kernel<typename Epilogue::vk_mul_q_kernel><<<grid, threadsPerBlock>>>(
        q.data_ptr<half_t>(),
        vk.data_ptr<float>(),
        1e-6f,
        num_tokens
    );
    checkCUDA(cudaGetLastError());
}

// Quantizes `input` to packed 4-bit activations and, in the same kernel,
// produces the LoRA-down activations.
//
// The logical problem size (actualM x actualN) is taken from `input`; M and N
// are those extents rounded up to whole GEMM blocks (N is based on half of
// actualN when fuse_glu is set).  `output`, `oscales` and `lora_act_out` must
// already be sized for the padded M / N, which is checked by the asserts below.
//
//   input        : read as half_t
//   output       : INT8 storage, [M, N / 2]
//   oscales      : FP8_E4M3 with M * N / WARP_K * 4 elements when fp4 is set,
//                  otherwise half_t-compatible with M * N / WARP_K elements
//   lora_down    : [N, rank]
//   lora_act_out : [M, rank], zeroed here before the launch
//   smooth       : optional; nullptr is passed to the kernel when invalid
template<typename Config>
void GEMM_W4A4_Launch<Config>::quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) {
    const int actualM = input.numel() / input.shape[-1];
    const int actualN = input.shape[-1];

    // Round both extents up to whole blocks.
    const int M = ceilDiv(actualM, GEMM::BLOCK_M) * GEMM::BLOCK_M;
    const int N = ceilDiv(actualN / (fuse_glu ? 2 : 1), GEMM::BLOCK_N) * GEMM::BLOCK_N;

    assert(output.dtype() == Tensor::INT8);
    assert(output.numel() / output.shape[-1] == M);
    assert(output.shape[-1] == N / 2);

    // assert(oscales.dtype() == Tensor::FP16);
    // The element counts are computed in 64-bit: M * N can exceed INT_MAX for
    // large activations, which would make the check itself overflow.
    if (fp4) {
        assert(oscales.dtype() == Tensor::FP8_E4M3);
        assert(oscales.numel() == static_cast<long long>(M) * N / GEMM::WARP_K * 4);
    } else {
        assert(isTypeMatch<half_t>(oscales.dtype()));
        assert(oscales.numel() == static_cast<long long>(M) * N / GEMM::WARP_K);
    }

    const int rank = lora_down.shape[1];

    assert(lora_down.shape[0] == N);
    // assert(lora_down.shape[1] == Lora::LORA_RANK);
    assert(lora_act_out.shape[0] == M);
    assert(lora_act_out.shape[1] == rank);

    // Cleared before the launch; presumably the kernel accumulates into it --
    // confirm against quantize_w4a4_fuse_lora_kernel.
    lora_act_out.zero_();

    // M and N are exact multiples of the block size by construction.
    dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);

    // Rank, GLU fusion and FP4 are template parameters of the kernel, so each
    // run-time value is dispatched to its compile-time counterpart.
    dispatchVal(rank, LoraRanks(), [&]<int RANK>() {
        dispatchBool(fuse_glu, [&]<bool FUSE_GLU>() {
            dispatchBool(fp4, [&]<bool USE_FP4>() {
                using Lora = typename GEMM::Lora<RANK>;
                using kernel = typename Lora::quantize_w4a4_fuse_lora_kernel<FUSE_GLU, USE_FP4>;

                auto func = invoke_kernel<kernel, typename kernel::Arguments>;

                checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kernel::SHMEM_SIZE));

                // log(std::format("quantize_w4a4_act_fuse_lora M={} N={} input={} output={} (size={} numel={})", M, N, input.data_ptr(), output.data_ptr(), output.buffer->getSize(), output.numel()));

                func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, kernel::SHMEM_SIZE>>>(
                    typename kernel::Arguments{
                        .input = input.data_ptr<half_t>(),
                        .smooth_factor = smooth.valid() ? smooth.data_ptr<packed_wscale_t>() : nullptr,
                        .output = output.data_ptr<packed_act_t>(),
                        .oscales = oscales.data_ptr<typename kernel::oscales_t>(),
                        .lora_wgt_down = lora_down.data_ptr<packed_fpsum_t>(),
                        .lora_act = lora_act_out.data_ptr<float>(),
                        .M = M,
                        .N = N,
                        .actualM = actualM,
                        .actualN = actualN,
                    }
                );
                checkCUDA(cudaGetLastError());
            });
        });
    });
}

// Quantizes an activation tensor to packed 4-bit values plus scales.
//
//   input   : read as half_t; treated as [M, K] with K = input.shape[-1]
//   output  : INT8 storage, M rows of K / 2 elements
//   oscales : half_t-compatible, M * K / GEMM::WARP_K elements
//
// One block of GEMM::WARP_SIZE threads is launched per
// (GEMM::WARP_M x GEMM::WARP_K) tile, so M and K must be whole multiples of
// the tile size; otherwise the tail rows/columns would never be written.
template<typename Config>
void GEMM_W4A4_Launch<Config>::quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) {
    int M = input.numel() / input.shape[-1];
    int K = input.shape[-1];

    assert(output.dtype() == Tensor::INT8);
    assert(output.numel() / output.shape[-1] == M);
    assert(output.shape[-1] == K / 2);

    // assert(oscales.dtype() == Tensor::FP16);
    assert(isTypeMatch<half_t>(oscales.dtype()));
    // 64-bit product: M * K can exceed INT_MAX for large inputs.
    assert(oscales.numel() == static_cast<long long>(M) * K / GEMM::WARP_K);

    // The grid below uses truncating division, so reject partial tiles up front
    // instead of silently leaving them unquantized.
    assert(M % GEMM::WARP_M == 0);
    assert(K % GEMM::WARP_K == 0);

    dim3 grid(M / GEMM::WARP_M, K / GEMM::WARP_K);
    invoke_kernel<typename GEMM::quantize_w4a4_act_kernel><<<grid, GEMM::WARP_SIZE>>>(
        input.data_ptr<half_t>(),
        output.data_ptr<packed_act_t>(),
        oscales.data_ptr<packed_ascale_t>(),
        K
    );
    checkCUDA(cudaGetLastError());
}

// Quantizes a weight tensor to packed 4-bit values plus scales.
//
//   input   : read as half_t; treated as [N, K] with K = input.shape[-1]
//   output  : INT8 storage, 2-D, [N, K / 2]
//   oscales : half_t-compatible, N * K / GEMM::WARP_K elements
//
// One block of GEMM::WARP_SIZE threads is launched per
// (GEMM::WARP_N x GEMM::WARP_K) tile, so N and K must be whole multiples of
// the tile size; otherwise the tail rows/columns would never be written.
template<typename Config>
void GEMM_W4A4_Launch<Config>::quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales) {
    int N = input.numel() / input.shape[-1];
    int K = input.shape[-1];

    assert(output.dtype() == Tensor::INT8);
    assert(output.ndims() == 2);
    assert(output.shape[0] == N);
    assert(output.shape[1] == K / 2);
    
    assert(isTypeMatch<half_t>(oscales.dtype()));
    // assert(oscales.dtype() == Tensor::FP16);
    // 64-bit product: N * K can exceed INT_MAX for large weights.
    assert(oscales.numel() == static_cast<long long>(N) * K / GEMM::WARP_K);

    // The grid below uses truncating division, so reject partial tiles up front
    // instead of silently leaving them unquantized.
    assert(N % GEMM::WARP_N == 0);
    assert(K % GEMM::WARP_K == 0);

    dim3 grid(N / GEMM::WARP_N, K / GEMM::WARP_K);
    invoke_kernel<typename GEMM::quantize_w4a4_wgt_kernel><<<grid, GEMM::WARP_SIZE>>>(
        input.data_ptr<half_t>(),
        output.data_ptr<packed_wgt_t>(),
        oscales.data_ptr<packed_wscale_t>(),
        K
    );
    checkCUDA(cudaGetLastError());
}

};  // namespace nunchaku::kernels