gemm_w4a4.cu
#include "zgemm.h"
#include "gemm_w4a4_launch.cuh"

namespace nunchaku::kernels {

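// Invokes the given launcher functor with the GEMM config type matching the
// runtime scalar dtype (FP16 or BF16); any other dtype is rejected.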
template<typename F>
static void invoke_launch(Tensor::ScalarType dtype, F &&launch) {
    if (dtype == Tensor::FP16) {
        launch.template operator()<GEMMConfig_W4A4_FP16>();
    } else if (dtype == Tensor::BF16) {
        launch.template operator()<GEMMConfig_W4A4_BF16>();
    } else {
        assert(false);
    }
}

void gemm_w4a4(
    Tensor act,           // packed act [M, K / 2]
    Tensor wgt,           // packed wgt [N, K / 2]
    Tensor out,           // linear     [M, N]
    Tensor qout,          // packed act [M, N / 2]
    Tensor ascales,       // packed as  [K / 64, M]
    Tensor wscales,       // packed ws  [K / 64, N]
    Tensor oscales,       // packed as  [N / 64, M]
    Tensor poolout,       // linear     [M / PoolSize, N]
    Tensor lora_act_in,   // packed lora_act [M, R]
    Tensor lora_up,       // packed lora_wgt [N, R]
    Tensor lora_down,     // packed lora_wgt [N, R]
    Tensor lora_act_out,  // packed lora_act [M, R]
    Tensor norm_q,        // linear     [HEAD_DIM]
    Tensor norm_k,        // linear     [HEAD_DIM]
    Tensor rotary_emb,    // linear     [M, HEAD_DIM / 2, 2, 2]
    Tensor bias,          // packed ws  [N]
    Tensor smooth_factor, // packed ws  [N], for quantization of the next layer
    Tensor out_vk,        // linear     [B, num_heads, head_dim + 1, head_dim]
    Tensor out_linearattn,// linear     [B, (M), N / 3]
    bool act_unsigned,
    std::vector<float> lora_scales,  // [R / 16]
    bool fuse_silu,
    bool fp4,
    float alpha,
    Tensor wcscales,
    Tensor out_q,          // packed attention [B, H, M, D]
    Tensor out_k,          // packed attention [B, H, M, D]
    Tensor out_v,           // packed attention [B, H, M, D]
    int attn_tokens
) {
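    // Determine the 16-bit compute dtype for the launch. The INT4 path takes it
    // from the activation scales; the FP4 path infers it from whichever of the
    // auxiliary tensors (out, bias, lora_up, lora_down, poolout, wcscales) is
    // valid, asserting that all of them agree.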
    Tensor::ScalarType dtype = Tensor::INVALID_SCALAR_TYPE;
    if (!fp4) {
        dtype = ascales.dtype();
    } else {
        for (auto tensor : {out, bias, lora_up, lora_down, poolout, wcscales}) {
            if (tensor.valid()) {
                assert(dtype == Tensor::INVALID_SCALAR_TYPE || dtype == tensor.dtype());
                dtype = tensor.dtype();
            }
        }
    }
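    // Dispatch on the selected dtype config and on FP4 vs. INT4 at compile time,
    // then forward all arguments to the templated kernel launcher.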
    invoke_launch(dtype, [&]<typename Config>() {
        dispatchBool(fp4, [&]<bool USE_FP4>() {
            GEMM_W4A4_Launch<Config, USE_FP4>::gemm_w4a4(
                act,           
                wgt,           
                out,           
                qout,          
                ascales,       
                wscales,       
                oscales,       
                poolout,       
                lora_act_in,   
                lora_up,       
                lora_down,     
                lora_act_out,  
                norm_q,        
                norm_k,        
                rotary_emb,    
                bias,          
                smooth_factor,
                out_vk,
                out_linearattn, 
                act_unsigned,
                lora_scales,
                fuse_silu,
                fp4,
                alpha,
                wcscales,
                out_q, 
                out_k,
                out_v,
                attn_tokens
            );
        });
    });
}

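// Linear-attention helper: multiplies q by the vk state; dispatched on q's
// dtype (always the non-FP4 instantiation).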
void linearattn_vk_mul_q(Tensor q, Tensor vk) {
    invoke_launch(q.dtype(), [&]<typename Config>() {
        GEMM_W4A4_Launch<Config, false>::linearattn_vk_mul_q(q, vk);
    });
}

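// Quantizes activations to packed 4-bit form and fuses the LoRA-down projection
// into the same pass, optionally fusing a GLU and applying the smoothing factor.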
void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth, bool fuse_glu, bool fp4) {
    invoke_launch(input.dtype(), [&]<typename Config>() {
        dispatchBool(fp4, [&]<bool USE_FP4>() {
            GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(
                input, output, oscales, lora_down, lora_act_out, smooth, fuse_glu, fp4
            );
        });
    });
}

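// Standalone 4-bit activation quantization: writes the packed output and its
// scales (always the non-FP4 instantiation).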
void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) {
    invoke_launch(input.dtype(), [&]<typename Config>() {
        GEMM_W4A4_Launch<Config, false>::quantize_w4a4_act(
            input, output, oscales
        );
    });
}
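
// 4-bit weight quantization: writes the packed weight tensor and its scales
// (always the non-FP4 instantiation).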
void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales) {
    invoke_launch(input.dtype(), [&]<typename Config>() {
        GEMM_W4A4_Launch<Config, false>::quantize_w4a4_wgt(
            input, output, oscales
        );
    });
}

} // namespace nunchaku::kernels