#include "Linear.h"
#include "kernels/zgemm/zgemm.h"
#include "kernels/gemm_f16.h"
#include "kernels/misc_kernels.h"
#include "kernels/awq/gemv_awq.h"
#include "kernels/dwconv.h"

#include <hip/hip_fp16.h>      // defines __half, __bfloat16, etc.
#include <hip/hip_bfloat16.h>  // explicitly include bfloat16 (required by some ROCm versions)
// #include <nvtx3/nvToolsExt.h>
#include <roctx.h>

using namespace nunchaku;

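// GEMM_F16: unquantized linear layer. Holds a dense [out_features, in_features]
// weight in fp16/bf16 (per `dtype`) plus an optional bias, and forwards through
// the gemm_f16 kernel.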
GEMM_F16::GEMM_F16(int in_features, int out_features, bool use_bias, Tensor::ScalarType dtype, Device device)
    : in_features(in_features), out_features(out_features) {
    this->weight = Tensor::allocate({out_features, in_features}, dtype, device);
    this->bias   = use_bias ? Tensor::allocate({out_features}, dtype, device) : Tensor{};

    registerParams(weight, "weight", ParamFlags::LazyLoad)(bias, "bias");
}

Tensor GEMM_F16::forward(Tensor x) {
    Tensor out = gemm_f16(x, this->weight, {}, this->bias, 1.0f);
    return out;
}

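// GEMV_AWQ: AWQ W4A16 linear layer. `qweight` stores packed 4-bit weights in INT32
// (the exact interleaving is defined by the gemv_awq kernel), with per-group scales
// and zero points (group_size = 64). An optional LoRA branch (lora_down / lora_up)
// is kept in plain fp16/bf16; note its layout differs from the W4A4 path below.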
GEMV_AWQ::GEMV_AWQ(int in_features, int out_features, bool use_bias, Tensor::ScalarType dtype, Device device)
    : in_features(in_features), out_features(out_features), group_size(64), lora_rank(0), lora_scale(1.0f),
      device(device) {
    this->qweight = Tensor::allocate({out_features / 4, ceilDiv(in_features, 8) * 4}, Tensor::INT32, device);
    this->wscales = Tensor::allocate({ceilDiv(in_features, group_size), out_features}, dtype, device);
    this->wzeros  = Tensor::allocate({ceilDiv(in_features, group_size), out_features}, dtype, device);
    this->bias    = use_bias ? Tensor::allocate({out_features}, dtype, device) : Tensor{};

    // !!! lora layout is different from w4a4 !!!
    this->lora_down = Tensor::allocate({lora_rank, in_features}, dtype, device, true);
    this->lora_up   = Tensor::allocate({out_features, lora_rank}, dtype, device, true);

    registerParams(qweight, "qweight", ParamFlags::LazyLoad)(wscales, "wscales")(wzeros, "wzeros")(bias, "bias")(
        lora_down, "lora_down", ParamFlags::Optional)(lora_up, "lora_up", ParamFlags::Optional);
}

void GEMV_AWQ::loadParam(std::string key, Tensor &dst, Tensor src) {
    if (key == "lora_down" || key == "lora_up") {
        assert(src.ndims() == 2);
        if (dst.shape.dataExtent != src.shape.dataExtent) {
            dst = Tensor::allocate(src.shape.dataExtent, dst.scalar_type(), this->device);
            Module::loadParam(key, dst, src);
            if (key == "lora_down") {
                const int new_rank = dst.shape[0];
                this->lora_rank    = new_rank;
            }
        } else {
            Module::loadParam(key, dst, src);
        }
    } else {
        Module::loadParam(key, dst, src);
    }
}

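// Forward: y = gemv_awq(x, qweight, ...), then an in-place bias add, and, if a LoRA
// adapter is loaded (lora_rank > 0), y += lora_scale * (x @ lora_down^T) @ lora_up^T
// computed with two fp16 GEMMs.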
Tensor GEMV_AWQ::forward(Tensor x) {
    debug("x", x);

    const int M = (int)x.numel() / x.shape[-1];
    Tensor out  = gemv_awq(x, this->qweight, this->wscales, this->wzeros, M, out_features, in_features, group_size);
    if (bias.valid()) {
        // TODO: batch
        // assert(out.numel() == bias.numel());
        // out = kernels::add(out, bias.view(out.shape.dataExtent));
        kernels::mul_add_batch(out, {}, false, 0.0, bias, false);
    }

    debug("out_before_lora", out);

    if (this->lora_rank > 0) {
        Tensor lora_act = gemm_f16(x, this->lora_down, {}, {}, 1.0f);
        debug("lora_act", lora_act);

        Tensor lora_out = gemm_f16(lora_act, this->lora_up, {}, {}, this->lora_scale);
        debug("lora_out", lora_out);

        out = kernels::add(out, lora_out);
    }

    debug("out", out);

    return out;
}

#define NO_LORA_FUSION 0

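// GEMM_W4A4: 4-bit weight / 4-bit activation linear layer. Feature dimensions are
// padded to multiples of 128; weight scales are stored per group of 64 inputs
// (INT4 mode) or per group of 16 inputs with FP8 scales (FP4 mode). The LoRA
// down/up projections and the activation smoothing factor are fused into the
// quantization and GEMM kernels.
//
// NO_LORA_FUSION is a debug switch: when set to 1, the LoRA projections run as
// separate cuBLAS Hgemm calls around an unfused gemm_w4a4 instead of the fused
// path used by default.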
GEMM_W4A4::GEMM_W4A4(
    int in_features, int out_features, bool bias, bool use_fp4, Tensor::ScalarType dtype, Device device)
    : in_features(in_features), out_features(out_features), in_features_pad(ceilDiv(in_features, 128) * 128),
      out_features_pad(ceilDiv(out_features, 128) * 128), use_fp4(use_fp4), lora_rank(0), dtype(dtype), device(device) {
    this->qweight = Tensor::allocate({out_features_pad, in_features_pad / 2}, Tensor::INT8, device, true);
    if (use_fp4) {
        this->wscales = Tensor::allocate({in_features_pad / 16, out_features_pad}, Tensor::FP8_E4M3, device, true);
    } else {
        this->wscales = Tensor::allocate({in_features_pad / 64, out_features_pad}, dtype, device, true);
    }

    this->bias = bias ? Tensor::allocate({out_features_pad}, dtype, device, true) : Tensor{};

    this->lora_down = Tensor::allocate({in_features_pad, lora_rank}, dtype, device, true);
    this->lora_up   = Tensor::allocate({out_features_pad, lora_rank}, dtype, device, true);

    // TODO: smooth factor in non-Lora fusion
    this->smooth = Tensor::allocate({in_features_pad}, dtype, device, true);

    // FIXME: reset wtscale and wcscales to default values when reloading the weights
    this->wtscale                    = Tensor::allocate({1}, Tensor::FP32, Device::cpu(), true);
    *this->wtscale.data_ptr<float>() = 1.0f;

    this->wcscales = Tensor::allocate({0}, dtype, device, true);

    registerParams(qweight, "qweight", ParamFlags::LazyLoad)(wscales, "wscales")(this->bias, "bias")(
        lora_down, "lora_down", ParamFlags::Optional)(lora_up, "lora_up", ParamFlags::Optional)(smooth, "smooth")(
        wtscale, "wtscale", ParamFlags::Optional)(wcscales, "wcscales", ParamFlags::Optional);

#if NO_LORA_FUSION
    checkCUBLAS(cublasCreate(&handle));
#endif
}

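// loadParam handles three special cases on top of the plain copies:
//  - lora_down / lora_up may arrive with a different rank than allocated, so the
//    tensors are reallocated to the checkpoint shape and lora_rank (plus the
//    per-16-rank lora_scales) are updated;
//  - wcscales (per-output-channel scales) is allocated lazily on first load;
//  - wtscale is kept as a single fp32 value regardless of whether the checkpoint
//    stored it as bf16, fp16, or fp32.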
void GEMM_W4A4::loadParam(std::string key, Tensor &dst, Tensor src) {
    if (key == "lora_down" || key == "lora_up") {
        assert(src.ndims() == 2);
        if (dst.shape.dataExtent != src.shape.dataExtent) {
            dst = Tensor::allocate(src.shape.dataExtent, dst.scalar_type(), this->device);
            Module::loadParam(key, dst, src);
            this->lora_rank = dst.shape[1];
            this->lora_scales.resize(ceilDiv(this->lora_rank, 16), 1.0f);
        } else {
            Module::loadParam(key, dst, src);
        }
    } else if (key == "wcscales") {
        assert(src.ndims() == 1);
        assert(src.shape[0] == out_features_pad);
        dst = Tensor::allocate(src.shape.dataExtent, dst.scalar_type(), this->device);
        Module::loadParam(key, dst, src);
    } else if (key == "wtscale") {
        assert(src.numel() == 1);
        if (src.dtype() == Tensor::BF16) {
            *dst.data_ptr<float>() = float(*src.data_ptr<hip_bfloat16>());
        } else if (src.dtype() == Tensor::FP16) {
            *dst.data_ptr<float>() = float(*src.data_ptr<half>());
        } else if (src.dtype() == Tensor::FP32) {
            Module::loadParam(key, dst, src);
        } else {
            assert(false);
        }
    } else {
        Module::loadParam(key, dst, src);
    }
}

Tensor GEMM_W4A4::forward(Tensor x) {
    return std::get<Tensor>(this->forward(x, FuseOptions::EMPTY, nullptr));
}

Tensor GEMM_W4A4::forward_silu(Tensor x) {
    return std::get<Tensor>(this->forward(x, FuseOptions::SILU, nullptr));
}

std::variant<Tensor, GEMM_W4A4::QuantizedActivation>
GEMM_W4A4::forward(Tensor x, FuseOptions fuse, GEMM_W4A4 *nextGEMM) {
    return forward_quant(quantize(x, false), fuse, nextGEMM);
}

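// Fused QKV projection path: besides the W4A4 GEMM itself, the kernel epilogue can
// apply the query/key normalization weights (norm_q / norm_k), rotary embeddings,
// an optional pooled output, and write the results directly into out_q / out_k /
// out_v (numTokens presumably being the unpadded token count).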
void GEMM_W4A4::forward(Tensor x,
                        Tensor out,
                        Tensor pool,
                        Tensor norm_q,
                        Tensor norm_k,
                        Tensor rotary_emb,
                        Tensor out_q,
                        Tensor out_k,
                        Tensor out_v,
                        int numTokens) {
    QuantizedActivation qact = quantize(x, false);

#if !NO_LORA_FUSION

#if 0
    Tensor dummy = Tensor::empty_like(qact.lora_act);
    dummy.zero_();

    gemm_w4a4(qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, dummy, this->lora_up, {}, {}, norm_q, norm_k, rotary_emb, this->bias, {}, qact.is_unsigned);
    debug("gemm.nolora.out", out);
#endif

    kernels::gemm_w4a4(qact.act,
                       qweight,
                       out,
                       {},
                       qact.ascales,
                       wscales,
                       {},
                       pool,
                       qact.lora_act,
                       this->lora_up,
                       {},
                       {},
                       norm_q,
                       norm_k,
                       rotary_emb,
                       this->bias,
                       {},
                       {},
                       {},
                       qact.is_unsigned,
                       this->lora_scales,
                       false,
                       use_fp4,
                       *this->wtscale.data_ptr<float>(),
                       wcscales.numel() > 0 ? wcscales : Tensor{},
                       out_q,
                       out_k,
                       out_v,
                       numTokens);

    debug("gemm.out", out);
#else
    const int M = (int)qact.act.numel() / qact.act.shape[-1];

    kernels::gemm_w4a4(qact.act,
                       qweight,
                       out,
                       {},
                       qact.ascales,
                       wscales,
                       {},
                       pool,
                       {},
                       {},
                       {},
                       {},
                       norm_q,
                       norm_k,
                       rotary_emb,
                       this->bias,
                       {},
                       qact.is_unsigned,
                       this->lora_scales);

    nvtxRangePushA("LoraUp");

    static const half one  = 1.0;
    static const half zero = 0.0;
    // lora_up: [M, R] * [OC, R]^T => [M, OC]
    // cublas view: [R, OC]^T * [R, M] => [OC, M]
    checkCUBLAS(cublasHgemm(handle,
                            CUBLAS_OP_T,
                            CUBLAS_OP_N,
                            this->out_features,
                            M,
                            this->lora_rank,
                            &one,
                            this->lora_up.data_ptr<half>(),
                            this->lora_rank,
                            qact.lora_act.data_ptr<half>(),
                            this->lora_rank,
                            &one,
                            out.data_ptr<half>(),
                            this->out_features));

    nvtxRangePop();
#endif
}

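// forward_quant: runs the W4A4 GEMM on an already-quantized activation. With
// FuseOptions::EMPTY or SILU the result comes back as a regular fp16/bf16 tensor
// (SILU is applied in the kernel epilogue). Otherwise (GELU_QUANT) the output is
// re-quantized for the next W4A4 layer: the kernel also emits the next layer's
// activation scales and lora_act, using nextGEMM's lora_down and smooth factor.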
std::variant<Tensor, GEMM_W4A4::QuantizedActivation>
GEMM_W4A4::forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *nextGEMM) {
    Tensor out;
    QuantizedActivation qout;

    Tensor next_lora;
    Tensor next_smooth;

    const int M = (int)qact.act.numel() / qact.act.shape[-1];

    if (fuse == FuseOptions::EMPTY || fuse == FuseOptions::SILU) {
        // auto shape = TensorShape(qact.act.shape.dataExtent);
        // shape[-1] = out_features;
        auto shape = TensorShape(qact.actShape.dataExtent);
        shape[-1]  = out_features;
        out        = Tensor::allocate(shape, dtype, device);
    } else {
        qout.act = Tensor::allocate({M, out_features_pad / 2}, Tensor::INT8, device);
        if (use_fp4) {
            qout.ascales = Tensor::allocate({out_features_pad / 16, M}, Tensor::FP8_E4M3, device);
        } else {
            qout.ascales = Tensor::allocate({out_features_pad / 64, M}, dtype, device);
        }
        qout.lora_act    = Tensor::allocate({M, lora_rank}, Tensor::FP32, device);
        qout.is_unsigned = !use_fp4;
        qout.actShape    = qact.actShape;

        next_lora   = nextGEMM->lora_down;
        next_smooth = nextGEMM->smooth;
    }

#if !NO_LORA_FUSION

#if 0
    Tensor dummy = Tensor::empty_like(qact.lora_act);
    dummy.zero_();

    gemm_w4a4(qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, dummy, this->lora_up, next_lora, qout.lora_act, {}, {}, {}, this->bias, next_smooth, qact.is_unsigned);

    if (fuse == FuseOptions::EMPTY) {
        debug("gemm.nolora.out", out);
    } else {
        debug("gemm.nolora.qout", qout.act);
        debug("gemm.nolora.oscales", qout.ascales);
        debug("gemm.nolora.lora_act_out", qout.lora_act);
    }
#endif

    kernels::gemm_w4a4(qact.act,
                       qweight,
                       out,
                       qout.act,
                       qact.ascales,
                       wscales,
                       qout.ascales,
                       {},
                       qact.lora_act,
                       this->lora_up,
                       next_lora,
                       qout.lora_act,
                       {},
                       {},
                       {},
                       this->bias,
                       next_smooth,
                       {},
                       {},
                       qact.is_unsigned,
                       this->lora_scales,
                       fuse == FuseOptions::SILU,
                       use_fp4,
                       *this->wtscale.data_ptr<float>(),
                       wcscales.numel() > 0 ? wcscales : Tensor{},
                       {},
                       {},
                       {},
                       0);

    if (fuse == FuseOptions::EMPTY || fuse == FuseOptions::SILU) {
        debug("gemm.out", out);
    } else {
        debug("gemm.qout", qout.act);
        debug("gemm.oscales", qout.ascales);
        debug("gemm.lora_act_out", qout.lora_act);
    }

#else
    if (!out.valid()) {
        auto shape = TensorShape(qact.act.shape.dataExtent);
        shape[-1]  = out_features;
        out        = Tensor::allocate(shape, Tensor::FP16, qweight.device());
    }

    kernels::gemm_w4a4(qact.act,
                       qweight,
                       out,
                       qout.act,
                       qact.ascales,
                       wscales,
                       qout.ascales,
                       {},
                       {},
                       {},
                       {},
                       {},
                       {},
                       {},
                       {},
                       this->bias,
                       next_smooth,
                       qact.is_unsigned,
                       this->lora_scales);

    nvtxRangePushA("LoraUp");

    static const half one  = 1.0;
    static const half zero = 0.0;

    // lora_up: [M, R] * [OC, R]^T => [M, OC]
    // cublas view: [R, OC]^T * [R, M] => [OC, M]
    // lora_up layout wrong?
    checkCUBLAS(cublasHgemm(handle,
                            CUBLAS_OP_T,
                            CUBLAS_OP_N,
                            this->out_features,
                            M,
                            this->lora_rank,
                            &one,
                            this->lora_up.data_ptr<half>(),
                            this->lora_rank,
                            qact.lora_act.data_ptr<half>(),
                            this->lora_rank,
                            &one,
                            out.data_ptr<half>(),
                            this->out_features));

    nvtxRangePop();

    if (fuse == FuseOptions::GELU_QUANT) {
        nvtxRangePushA("LoraDown");
        // IC is for next lora (OC of this layer)
        // lora_down: [M, IC] * [IC, R] => [M, R]
        // cublas view: [R, IC] * [IC, M] => [R, M]
        checkCUBLAS(cublasHgemm(handle,
                                CUBLAS_OP_N,
                                CUBLAS_OP_N,
                                this->lora_rank,
                                M,
                                this->out_features,
                                &one,
                                next_lora.data_ptr<half>(),
                                this->lora_rank,
                                out.data_ptr<half>(),
                                this->out_features,
                                &zero,
                                qout.lora_act.data_ptr<half>(),
                                this->lora_rank));

        out = {};

        nvtxRangePop();
    }

#endif
    if (out.valid()) {
        return out;
    }
    return qout;
}

Tensor GEMM_W4A4::forward_quant(QuantizedActivation qact) {
    return std::get<Tensor>(this->forward_quant(qact, FuseOptions::EMPTY, nullptr));
}

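// quantize: converts an fp16/bf16 activation into the packed representation expected
// by gemm_w4a4. The token count is padded to a multiple of 256; the fused kernel
// applies the smoothing factor, packs two 4-bit values per INT8 byte, writes
// per-group activation scales, and computes the LoRA "down" projection (lora_act,
// fp32) in the same pass.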
GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
    const int actualM = x.numel() / x.shape[-1];
    const int M       = ceilDiv(actualM, 256) * 256;

    // auto shape = TensorShape(x.shape.dataExtent);
    // shape[-1] = in_features / 2;

    QuantizedActivation qact;
    qact.act = Tensor::allocate({M, in_features_pad / 2}, Tensor::INT8, device);
    if (use_fp4) {
        qact.ascales = Tensor::allocate({in_features_pad / 16, M}, Tensor::FP8_E4M3, device);
    } else {
        qact.ascales = Tensor::allocate({in_features_pad / 64, M}, dtype, device);
    }
    qact.lora_act    = Tensor::allocate({M, lora_rank}, Tensor::FP32, device);
    qact.is_unsigned = false;
    qact.actShape    = x.shape.dataExtent;

#if !NO_LORA_FUSION
    debug("quantize.x", x);
    debug("quantize.smooth", this->smooth);

    kernels::quantize_w4a4_act_fuse_lora(
        x, qact.act, qact.ascales, this->lora_down, qact.lora_act, this->smooth, fuse_glu, use_fp4);

    debug("quantize.qact", qact.act);
    debug("quantize.ascales", qact.ascales);
    debug("quantize.lora_act", qact.lora_act);
#else
    static const half one  = 1.0;
    static const half zero = 0.0;

    nvtxRangePushA("LoraDown");

    // lora_down: [M, IC] * [IC, R] => [M, R]
    // cublas view: [R, IC] * [IC, M]
    checkCUBLAS(cublasHgemm(handle,
                            CUBLAS_OP_N,
                            CUBLAS_OP_N,
                            this->lora_rank,
                            M,
                            this->in_features,
                            &one,
                            lora_down.data_ptr<half>(),
                            this->lora_rank,
                            x.data_ptr<half>(),
                            this->in_features,
                            &zero,
                            qact.lora_act.data_ptr<half>(),
                            this->lora_rank));

    nvtxRangePop();

    kernels::quantize_w4a4_act(x, qact.act, qact.ascales);

#endif

    return qact;
}

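// GEMM_W8A8: 8-bit weight / 8-bit activation linear layer with per-output-channel
// weight scales and per-token activation scales.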
GEMM_W8A8::GEMM_W8A8(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device)
    : in_features(in_features), out_features(out_features), dtype(dtype) {
    this->qweight = Tensor::allocate({out_features, in_features}, Tensor::INT8, device);
    this->wscales = Tensor::allocate({out_features}, dtype, device);
    this->bias    = bias ? Tensor::allocate({out_features}, dtype, device, true) : Tensor{};

    registerParams(qweight, "qweight", ParamFlags::LazyLoad)(wscales, "wscales")(this->bias, "bias");
}

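// quantize: per-token INT8 quantization of the activation. When fuse_glu is set,
// the last dimension is halved and the GLU gating is folded into the quantization
// kernel.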
GEMM_W8A8::QuantizedActivation GEMM_W8A8::quantize(Tensor x, bool fuse_glu) {
    QuantizedActivation qact;
    auto qshape = x.shape;
    if (fuse_glu) {
        qshape[-1] /= 2;
    }
    qact.act     = Tensor::allocate(qshape, Tensor::INT8, x.device());
    qact.ascales = Tensor::allocate({(int)x.numel() / x.shape[-1]}, this->dtype, x.device());

    debug("quantize.x", x);

    kernels::quantize_w8a8_act(x, qact.act, qact.ascales, fuse_glu);

    debug("quantize.qact", qact.act);
    debug("quantize.ascales", qact.ascales);

    return qact;
}

Tensor GEMM_W8A8::forward_quant(QuantizedActivation qact) {
    auto shape = TensorShape(qact.act.shape.dataExtent);
    shape[-1]  = out_features;
    Tensor out = Tensor::allocate(shape, this->dtype, qact.act.device());
    kernels::gemm_w8a8(qact.act, this->qweight, out, qact.ascales, this->wscales, this->bias);

    debug("gemm.out", out);
    return out;
}

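// DWCONV: 3x3 depthwise convolution; the weight is stored as {channels, 3, 3, 1}
// and forwarded through the dwconv_f16 kernel.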
DWCONV::DWCONV(int in_features, bool use_bias, Tensor::ScalarType dtype, Device device) : in_features(in_features) {
    this->weight = Tensor::allocate({in_features, 3, 3, 1}, dtype, device);
    this->bias   = use_bias ? Tensor::allocate({in_features}, dtype, device) : Tensor{};

    registerParams(this->weight, "weight")(this->bias, "bias");
}

Tensor DWCONV::forward(Tensor x) {
    return dwconv_f16(x, this->weight, {}, this->bias);
}