#include "zgemm.h"
#include "gemm_w4a4_launch.cuh"

namespace nunchaku::kernels {

// Faster I2F (integer-to-float) conversion mode; only relevant on sm_75 (Turing).
struct FasterI2FMode {
    enum Mode {
        Disabled = 0, // never use the faster I2F path
        Enabled,      // use it only when activations are signed
        Always,       // use it unconditionally
    };
    inline static Mode mode = Disabled;
    // Decides whether the FasterI2F kernel config should be used on the current device.
    static bool check(bool act_unsigned);
};


// Dispatches `launch` (a C++20 templated lambda) with the kernel Config type and
// USE_FP4 flag selected from the runtime dtype / use_fp4 / fasterI2F arguments.
template<typename F>
static void invoke_launch(Tensor::ScalarType dtype, bool use_fp4, bool fasterI2F, F &&launch) {
    if (fasterI2F && dtype == Tensor::FP16) {
        launch.template operator()<GEMMConfig_W4A4_FP16_FasterI2F, false>();
    } else {
        dispatchBool(use_fp4, [&]<bool USE_FP4>() {
            if (dtype == Tensor::FP16) {
                launch.template operator()<GEMMConfig_W4A4_FP16, USE_FP4>();
            } else if (dtype == Tensor::BF16) {
                launch.template operator()<GEMMConfig_W4A4_BF16, USE_FP4>();
            } else {
                assert(false && "gemm_w4a4: unsupported activation dtype");
            }
        });
    }
}
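
// All public entry points below (gemm_w4a4, linearattn_vk_mul_q, quantize_*) dispatch
// through invoke_launch with a C++20 templated lambda. A minimal sketch of the pattern,
// mirroring quantize_w4a4_act further down:
//
//     invoke_launch(input.dtype(), /*use_fp4=*/false, /*fasterI2F=*/false,
//                   [&]<typename Config, bool USE_FP4>() {
//         GEMM_W4A4_Launch<Config, false>::quantize_w4a4_act(input, output, oscales);
//     });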

void gemm_w4a4(
    Tensor act,           // packed act [M, K / 2]
    Tensor wgt,           // packed wgt [N, K / 2]
    Tensor out,           // linear     [M, N]
    Tensor qout,          // packed act [M, N / 2]
    Tensor ascales,       // packed as  [K / 64, M]
    Tensor wscales,       // packed ws  [K / 64, N]
    Tensor oscales,       // packed as  [N / 64, M]
    Tensor poolout,       // linear     [M / PoolSize, N]
    Tensor lora_act_in,   // packed lora_act [M, R]
    Tensor lora_up,       // packed lora_wgt [N, R]
    Tensor lora_down,     // packed lora_wgt [N, R]
    Tensor lora_act_out,  // packed lora_act [M, R]
    Tensor norm_q,        // linear     [HEAD_DIM]
    Tensor norm_k,        // linear     [HEAD_DIM]
    Tensor rotary_emb,    // linear     [M, HEAD_DIM / 2, 2, 2]
    Tensor bias,          // packed ws  [N]
    Tensor smooth_factor, // packed ws  [N], for quantization of the next layer
    Tensor out_vk,        // linear     [B, num_heads, head_dim + 1, head_dim]
    Tensor out_linearattn,// linear     [B, (M), N / 3]
    bool act_unsigned,
    std::vector<float> lora_scales,  // [R / 16]
    bool fuse_silu,
    bool fp4,
    float alpha,
    Tensor wcscales,
    Tensor out_q,          // packed attention [B, H, M, D]
    Tensor out_k,          // packed attention [B, H, M, D]
    Tensor out_v,           // packed attention [B, H, M, D]
    int attn_tokens
) {
    // Determine the compute dtype: in the INT4 path it is read from ascales; in the
    // FP4 path it is inferred from the first valid optional tensor (out, bias, LoRA
    // weights, poolout, wcscales).
    Tensor::ScalarType dtype = Tensor::INVALID_SCALAR_TYPE;
    if (!fp4) {
        dtype = ascales.dtype();
    } else {
        for (auto tensor : {out, bias, lora_up, lora_down, poolout, wcscales}) {
            if (tensor.valid()) {
                assert(dtype == Tensor::INVALID_SCALAR_TYPE || dtype == tensor.dtype());
                dtype = tensor.dtype();
            }
        }
    }
    invoke_launch(dtype, fp4, FasterI2FMode::check(act_unsigned), [&]<typename Config, bool USE_FP4>() {
        GEMM_W4A4_Launch<Config, USE_FP4>::gemm_w4a4(
            act,
            wgt,
            out,
            qout,
            ascales,
            wscales,
            oscales,
            poolout,
            lora_act_in,
            lora_up,
            lora_down,
            lora_act_out,
            norm_q,
            norm_k,
            rotary_emb,
            bias,
            smooth_factor,
            out_vk,
            out_linearattn,
            act_unsigned,
            lora_scales,
            fuse_silu,
            fp4,
            alpha,
            wcscales,
            out_q,
            out_k,
            out_v,
            attn_tokens
        );
    });
}

void linearattn_vk_mul_q(Tensor q, Tensor vk) {
    invoke_launch(q.dtype(), false, false, [&]<typename Config, bool USE_FP4>() {
        GEMM_W4A4_Launch<Config, false>::linearattn_vk_mul_q(q, vk);
    });
}

void quantize_w4a4_act_fuse_lora(
    Tensor input,
    Tensor output,
    Tensor oscales,
    Tensor lora_down,
    Tensor lora_act_out,
    Tensor smooth,
    bool fuse_glu,
    bool fp4
) {
    invoke_launch(input.dtype(), fp4, false, [&]<typename Config, bool USE_FP4>() {
        GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(
            input, output, oscales, lora_down, lora_act_out, smooth, fuse_glu, fp4
        );
    });
}

void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) {
    invoke_launch(input.dtype(), false, false, [&]<typename Config, bool USE_FP4>() {
        GEMM_W4A4_Launch<Config, false>::quantize_w4a4_act(
            input, output, oscales
        );
    });
}

void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales) {
    invoke_launch(input.dtype(), false, false, [&]<typename Config, bool USE_FP4>() {
        GEMM_W4A4_Launch<Config, false>::quantize_w4a4_wgt(
            input, output, oscales
        );
    });
}

// The faster I2F conversion path exists only for sm_75; for mode "enabled" it is
// additionally skipped when the activations are unsigned.
bool FasterI2FMode::check(bool act_unsigned) {
    auto *prop = getCurrentDeviceProperties();
    if (prop->major != 7 || prop->minor != 5) {
        return false;
    }

    if (mode == Always) {
        return true;
    } else if (mode == Enabled && !act_unsigned) {
        return true;
    } else {
        return false;
    }
}

// Sets the global faster-I2F mode from a string; throws std::out_of_range for any
// value other than "disabled", "enabled", or "always".
void set_faster_i2f_mode(std::string mode) {
    static const std::map<std::string, FasterI2FMode::Mode> mapping = {
        {"disabled", FasterI2FMode::Disabled},
        {"enabled", FasterI2FMode::Enabled},
        {"always", FasterI2FMode::Always},
    };
    FasterI2FMode::mode = mapping.at(mode);
}
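
// A minimal (hypothetical) host-side call site for enabling the faster I2F path; it
// only takes effect on sm_75 devices, as gated by FasterI2FMode::check above:
//
//     nunchaku::kernels::set_faster_i2f_mode("enabled");  // "disabled" / "enabled" / "always"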

} // namespace nunchaku::kernels