// gemm_w4a4_launch.cuh
#include "gemm_w4a4.cuh"
#include "epilogues.cuh"

namespace nunchaku::kernels {

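// Host-side launch wrapper for the GEMM_W4A4 device kernels instantiated with a
// given tile Config. USE_FP4 presumably selects the FP4 path, which relies on the
// packed_amscale_t / packed_wmscale_t microscale types declared below.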
template<typename Config, bool USE_FP4>
class GEMM_W4A4_Launch {
    using GEMM = GEMM_W4A4<Config>;
    using Epilogues = ::nunchaku::kernels::Epilogues<Config>;
    using Lora = ::nunchaku::kernels::Lora<Config>;

    using packed_act_t    = typename GEMM::packed_act_t;
    using packed_wgt_t    = typename GEMM::packed_wgt_t;
    using packed_ascale_t = typename GEMM::packed_ascale_t;
    using packed_wscale_t = typename GEMM::packed_wscale_t;
    using packed_amscale_t = typename GEMM::packed_amscale_t;
    using packed_wmscale_t = typename GEMM::packed_wmscale_t;
    using packed_fpsum_t  = typename GEMM::packed_fpsum_t;
    using half_t          = typename GEMM::half_t;

public:
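    // Fused W4A4 GEMM. Beyond the 4-bit matmul itself, the arguments cover
    // optional fused epilogues: bias, LoRA up/down projection, Q/K norm +
    // rotary embedding, token pooling (poolout), SiLU, linear-attention VK
    // accumulation (out_vk / out_linearattn), and re-quantization of the
    // output (qout / oscales / smooth_factor) for the next layer.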
    static void gemm_w4a4(
        Tensor act,          // packed act [M, K / 2]
        Tensor wgt,          // packed wgt [N, K / 2]
        Tensor out,          // linear     [M, N]
        Tensor qout,         // packed act [M, N / 2]
        Tensor ascales,      // packed as  [K / 64, M]
        Tensor wscales,      // packed ws  [K / 64, N]
        Tensor oscales,      // packed as  [N / 64, M]
        Tensor poolout,      // linear     [M / PoolSize, N]
        Tensor lora_act_in,  // packed lora_act [M, R]
        Tensor lora_up,      // packed lora_wgt [N, R]
        Tensor lora_down,    // packed lora_wgt [N, R]
        Tensor lora_act_out, // packed lora_act [M, R]
        Tensor norm_q,       // linear     [HEAD_DIM]
        Tensor norm_k,       // linear     [HEAD_DIM]
        Tensor rotary_emb,   // linear     [M, HEAD_DIM / 2, 2, 2]
        Tensor bias,         // packed ws  [N]
        Tensor smooth_factor, // packed ws  [N], for quantization of the next layer
        Tensor out_vk,        // linear     [B, num_heads, head_dim + 1, head_dim]
        Tensor out_linearattn, // linear    [B, (M), N / 3]
        bool act_unsigned,
        std::vector<float> lora_scales,  // [R / 16]
        bool fuse_silu,
        bool fp4,
        float alpha,
        Tensor wcscales,       // packed ws  [N]
        Tensor out_q,          // packed attention [B, H, M, D]
        Tensor out_k,          // packed attention [B, H, M, D]
        Tensor out_v,          // packed attention [B, H, M, D]
        int attn_tokens
    );
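
    // Quantizes activations to packed 4-bit, fusing the LoRA down-projection
    // (lora_down -> lora_act_out), the smoothing factor, and optionally a GLU
    // into the same pass.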
    static void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4);
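
    // Standalone quantization helpers: pack activations / weights into the
    // 4-bit layouts above and produce the matching per-group scales.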
    static void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales);
    static void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales);

    // Multiplies queries by the accumulated VK state (out_vk) for the linear-attention path.
    static void linearattn_vk_mul_q(Tensor q, Tensor vk);
};
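
// A minimal usage sketch (illustrative only: the concrete Config type, tensor
// variables, and elided arguments below are assumptions, not part of this header):
//
//   using Launch = GEMM_W4A4_Launch<SomeW4A4Config, /*USE_FP4=*/false>;
//   Launch::quantize_w4a4_wgt(wgt_fp16, wgt_q, wscales);  // offline: pack weights + scales
//   Launch::quantize_w4a4_act(act_fp16, act_q, ascales);  // per call: pack activations + scales
//   Launch::gemm_w4a4(act_q, wgt_q, out, /* ...remaining epilogue tensors and flags... */);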


}  // namespace nunchaku::kernels