#include <cassert>
#include <map>
#include <string>
#include <vector>

#include "zgemm.h"
#include "gemm_w4a4_launch.cuh"

namespace nunchaku::kernels {

// Faster int-to-float (I2F) conversion mode; takes effect on sm_75 (Turing) only.
struct FasterI2FMode {
    enum Mode {
        Disabled = 0,
        Enabled,
        Always,
    };
    inline static Mode mode = Disabled;
    static bool check(bool act_unsigned);
};
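
// Mode semantics (implemented in FasterI2FMode::check below): the fast path is
// limited to sm_75; "enabled" uses it only for signed activations, while
// "always" uses it unconditionally. A typical host-side call:
//
//   set_faster_i2f_mode("enabled");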

// Dispatches `launch` with the compile-time GEMM configuration matching the
// runtime dtype / FP4 / faster-I2F flags; the faster-I2F configuration exists
// only for FP16 and never combines with FP4.
template<typename F>
static void invoke_launch(Tensor::ScalarType dtype, bool use_fp4, bool fasterI2F, F &&launch) {
    if (fasterI2F && dtype == Tensor::FP16) {
        launch.template operator()<GEMMConfig_W4A4_FP16_FasterI2F, false>();
    } else {
        dispatchBool(use_fp4, [&]<bool USE_FP4>() {
            if (dtype == Tensor::FP16) {
                launch.template operator()<GEMMConfig_W4A4_FP16, USE_FP4>();
            } else if (dtype == Tensor::BF16) {
                launch.template operator()<GEMMConfig_W4A4_BF16, USE_FP4>();
            } else {
                assert(false);
            }
        });
    }
}
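
// Usage sketch (mirrors the call sites below): callers pass a C++20 templated
// lambda, so the kernel configuration is fixed at compile time, e.g.
//
//   invoke_launch(input.dtype(), /*use_fp4=*/false, /*fasterI2F=*/false,
//                 [&]<typename Config, bool USE_FP4>() {
//                     GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act(input, output, oscales);
//                 });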

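// Fused W4A4 GEMM: computes out = act * wgt^T with optional epilogues (bias,
// LoRA up/down, QK norm + RoPE, pooling, linear attention, and requantization
// for the next layer). Unused outputs are passed as invalid tensors.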
void gemm_w4a4(Tensor act,            // packed act [M, K / 2]
               Tensor wgt,            // packed wgt [N, K / 2]
               Tensor out,            // linear     [M, N]
               Tensor qout,           // packed act [M, N / 2]
               Tensor ascales,        // packed as  [K / 64, M]
               Tensor wscales,        // packed ws  [K / 64, N]
               Tensor oscales,        // packed as  [N / 64, M]
               Tensor poolout,        // linear     [M / PoolSize, N]
               Tensor lora_act_in,    // packed lora_act [M, R]
               Tensor lora_up,        // packed lora_wgt [N, R]
               Tensor lora_down,      // packed lora_wgt [N, R]
               Tensor lora_act_out,   // packed lora_act [M, R]
               Tensor norm_q,         // linear     [HEAD_DIM]
               Tensor norm_k,         // linear     [HEAD_DIM]
               Tensor rotary_emb,     // linear     [M, HEAD_DIM / 2, 2, 2]
               Tensor bias,           // packed ws  [N]
               Tensor smooth_factor,  // packed ws  [N], for quantization of the next layer
               Tensor out_vk,         // linear     [B, num_heads, head_dim + 1, head_dim]
               Tensor out_linearattn, // linear     [B, (M), N / 3]
               bool act_unsigned,
               std::vector<float> lora_scales, // [R / 16]
               bool fuse_silu,
               bool fp4,
               float alpha,
               Tensor wcscales,
               Tensor out_q, // packed attention [B, H, M, D]
               Tensor out_k, // packed attention [B, H, M, D]
               Tensor out_v, // packed attention [B, H, M, D]
               int attn_tokens) {
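    // Infer the compute dtype: the INT4 path reads it from the activation
    // scales; the FP4 path takes it from whichever optional floating-point
    // tensor is valid (the FP4 scale tensors do not use the compute dtype).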
    Tensor::ScalarType dtype = Tensor::INVALID_SCALAR_TYPE;
    if (!fp4) {
        dtype = ascales.dtype();
    } else {
        for (auto tensor : {out, bias, lora_up, lora_down, poolout, wcscales}) {
            if (tensor.valid()) {
                assert(dtype == Tensor::INVALID_SCALAR_TYPE || dtype == tensor.dtype());
                dtype = tensor.dtype();
            }
        }
    }
    invoke_launch(dtype, fp4, FasterI2FMode::check(act_unsigned), [&]<typename Config, bool USE_FP4>() {
        GEMM_W4A4_Launch<Config, USE_FP4>::gemm_w4a4(act,
                                                     wgt,
                                                     out,
                                                     qout,
                                                     ascales,
                                                     wscales,
                                                     oscales,
                                                     poolout,
                                                     lora_act_in,
                                                     lora_up,
                                                     lora_down,
                                                     lora_act_out,
                                                     norm_q,
                                                     norm_k,
                                                     rotary_emb,
                                                     bias,
                                                     smooth_factor,
                                                     out_vk,
                                                     out_linearattn,
                                                     act_unsigned,
                                                     lora_scales,
                                                     fuse_silu,
                                                     fp4,
                                                     alpha,
                                                     wcscales,
                                                     out_q,
                                                     out_k,
                                                     out_v,
                                                     attn_tokens);
    });
}

// Multiplies q by the vk state for the linear-attention path (see out_vk
// above); dtype-only dispatch, FP4 and faster-I2F are never used here.
void linearattn_vk_mul_q(Tensor q, Tensor vk) {
    invoke_launch(q.dtype(), false, false, [&]<typename Config, bool USE_FP4>() {
        GEMM_W4A4_Launch<Config, false>::linearattn_vk_mul_q(q, vk);
    });
}

// Quantizes activations for the next W4A4 GEMM and fuses the LoRA
// down-projection (and optionally a GLU) into the same pass.
void quantize_w4a4_act_fuse_lora(Tensor input,
                                 Tensor output,
                                 Tensor oscales,
                                 Tensor lora_down,
                                 Tensor lora_act_out,
                                 Tensor smooth,
                                 bool fuse_glu,
                                 bool fp4) {
    invoke_launch(input.dtype(), fp4, false, [&]<typename Config, bool USE_FP4>() {
        GEMM_W4A4_Launch<Config, USE_FP4>::quantize_w4a4_act_fuse_lora(
            input, output, oscales, lora_down, lora_act_out, smooth, fuse_glu, fp4);
    });
}

// Quantizes activations into the packed W4A4 layout, writing per-group scales
// to oscales.
void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales) {
    invoke_launch(input.dtype(), false, false, [&]<typename Config, bool USE_FP4>() {
        GEMM_W4A4_Launch<Config, false>::quantize_w4a4_act(input, output, oscales);
    });
}

// Quantizes weights into the packed W4A4 layout, writing per-group scales to
// oscales.
void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales) {
    invoke_launch(input.dtype(), false, false, [&]<typename Config, bool USE_FP4>() {
        GEMM_W4A4_Launch<Config, false>::quantize_w4a4_wgt(input, output, oscales);
    });
}
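
// Typical flow combining the helpers above (a sketch; tensor names are
// illustrative, and allocation plus the many optional epilogue arguments of
// gemm_w4a4 are elided):
//
//   quantize_w4a4_wgt(wgt_fp16, wgt_packed, wscales);   // offline
//   quantize_w4a4_act(act_fp16, act_packed, ascales);   // per forward pass
//   gemm_w4a4(act_packed, wgt_packed, out, /*...*/);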

// The fast I2F conversion is specific to sm_75 (Turing); other architectures
// never use it.
bool FasterI2FMode::check(bool act_unsigned) {
    auto *prop = getCurrentDeviceProperties();
    if (prop->major != 7 || prop->minor != 5) {
        return false;
    }

    return mode == Always || (mode == Enabled && !act_unsigned);
}

// Selects the faster-I2F mode by name ("disabled" / "enabled" / "always");
// unknown names throw std::out_of_range.
void set_faster_i2f_mode(std::string mode) {
    static const std::map<std::string, FasterI2FMode::Mode> mapping = {
        {"disabled", FasterI2FMode::Disabled},
        {"enabled", FasterI2FMode::Enabled},
        {"always", FasterI2FMode::Always},
    };
    FasterI2FMode::mode = mapping.at(mode);
}

} // namespace nunchaku::kernels