#include "Linear.h"
#include "kernels/zgemm/zgemm.h"
#include "kernels/gemm_f16.h"
#include "kernels/misc_kernels.h"
#include "kernels/awq/gemv_awq.h"
#include "kernels/dwconv.h"

#include <nvtx3/nvToolsExt.h>

using namespace nunchaku;

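// GEMM_F16: unquantized linear layer. The weight is stored as [out_features, in_features]
// in the requested dtype; forward() dispatches directly to the gemm_f16 kernel.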
GEMM_F16::GEMM_F16(int in_features, int out_features, bool use_bias, Tensor::ScalarType dtype, Device device) :
    in_features(in_features), out_features(out_features)
{
    this->weight = Tensor::allocate({out_features, in_features}, dtype, device);
    this->bias = use_bias ? Tensor::allocate({out_features}, dtype, device) : Tensor{};

    registerParams
        (weight, "weight", ParamFlags::LazyLoad)
        (bias, "bias")
    ;
}

Tensor GEMM_F16::forward(Tensor x) {
    Tensor out = gemm_f16(x, this->weight, {}, this->bias, 1.0f);
    return out;
}

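// GEMV_AWQ: AWQ-style weight-only 4-bit layer. Weights are packed into INT32 words with
// per-group scales and zeros (group_size = 64). An optional low-rank LoRA branch
// (lora_down / lora_up) is added on top of the quantized GEMV result.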
GEMV_AWQ::GEMV_AWQ(int in_features, int out_features, bool use_bias, Tensor::ScalarType dtype, Device device) : 
    in_features(in_features), out_features(out_features), group_size(64), lora_rank(0), lora_scale(1.0f), device(device)
{
    this->qweight = Tensor::allocate({out_features / 4, ceilDiv(in_features, 8) * 4}, Tensor::INT32, device);
    this->wscales = Tensor::allocate({ceilDiv(in_features, group_size), out_features}, dtype, device);
    this->wzeros  = Tensor::allocate({ceilDiv(in_features, group_size), out_features}, dtype, device);
    this->bias = use_bias ? Tensor::allocate({out_features}, dtype, device) : Tensor{};

    // !!! lora layout is different from w4a4 !!!
    this->lora_down = Tensor::allocate({lora_rank, in_features}, dtype, device, true);
    this->lora_up = Tensor::allocate({out_features, lora_rank}, dtype, device, true);

    registerParams
        (qweight, "qweight", ParamFlags::LazyLoad)
        (wscales, "wscales")
        (wzeros, "wzeros")
        (bias, "bias")
        (lora_down, "lora_down", ParamFlags::Optional)
        (lora_up, "lora_up", ParamFlags::Optional)
    ;
}

void GEMV_AWQ::loadParam(std::string key, Tensor &dst, Tensor src) {
    if (key == "lora_down" || key == "lora_up") {
        assert(src.ndims() == 2);
        if (dst.shape.dataExtent != src.shape.dataExtent) {
            dst = Tensor::allocate(src.shape.dataExtent, dst.scalar_type(), this->device);
            Module::loadParam(key, dst, src);
            if (key == "lora_down") {
                const int new_rank = dst.shape[0];
                this->lora_rank = new_rank;
            }
        } else {
            Module::loadParam(key, dst, src);
        }
    } else {
        Module::loadParam(key, dst, src);
    }
}

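// forward: out = gemv_awq(x) (+ bias); if a LoRA is loaded (lora_rank > 0), adds
// lora_scale * (x @ lora_down^T) @ lora_up^T, computed with two gemm_f16 calls.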
Tensor GEMV_AWQ::forward(Tensor x) {
    debug("x", x);

    const int M = (int)x.numel() / x.shape[-1];
    Tensor out = gemv_awq(x, this->qweight, this->wscales, this->wzeros, M, out_features, in_features, group_size);
    if (bias.valid()) {
        // TODO: batch
        // assert(out.numel() == bias.numel());
        // out = kernels::add(out, bias.view(out.shape.dataExtent));
        kernels::mul_add_batch(out, {}, false, 0.0, bias, false);
    }

    debug("out_before_lora", out);

    if (this->lora_rank > 0) {
        Tensor lora_act = gemm_f16(x, this->lora_down, {}, {}, 1.0f);
        debug("lora_act", lora_act);

        Tensor lora_out = gemm_f16(lora_act, this->lora_up, {}, {}, this->lora_scale);
        debug("lora_out", lora_out);

        out = kernels::add(out, lora_out);
    }

    debug("out", out);
    
    return out;
}


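// When NO_LORA_FUSION is non-zero, the LoRA projections run as separate cuBLAS HGEMMs
// (reference / debugging path) instead of being fused into the W4A4 kernels below.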
#define NO_LORA_FUSION 0

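// GEMM_W4A4: 4-bit weight x 4-bit activation GEMM. Feature dimensions are padded to
// multiples of 128 and qweight packs two 4-bit values per INT8 byte. Scales are stored
// per group of 16 (FP8 E4M3) in FP4 mode, per group of 64 otherwise. The LoRA branch and
// the activation smoothing factor are fused into the quantization / GEMM kernels.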
GEMM_W4A4::GEMM_W4A4(int in_features, int out_features, bool bias, bool use_fp4, Tensor::ScalarType dtype, Device device) : 
    in_features(in_features), out_features(out_features), 
    in_features_pad(ceilDiv(in_features, 128) * 128), out_features_pad(ceilDiv(out_features, 128) * 128),
    use_fp4(use_fp4),
    lora_rank(0), dtype(dtype), device(device)
{
    this->qweight = Tensor::allocate({out_features_pad, in_features_pad / 2}, Tensor::INT8, device, true);
    if (use_fp4) {
        this->wscales = Tensor::allocate({in_features_pad / 16, out_features_pad}, Tensor::FP8_E4M3, device, true);
    } else {
        this->wscales = Tensor::allocate({in_features_pad / 64, out_features_pad}, dtype, device, true);
    }

    this->bias = bias ? Tensor::allocate({out_features_pad}, dtype, device, true) : Tensor{};

    this->lora_down = Tensor::allocate({in_features_pad, lora_rank}, dtype, device, true);
    this->lora_up = Tensor::allocate({out_features_pad, lora_rank}, dtype, device, true);

    // TODO: smooth factor in non-Lora fusion
    this->smooth = Tensor::allocate({in_features_pad}, dtype, device, true);

    // FIXME: reset wtscale and wcscales to default values when reloading the weights
    this->wtscale = Tensor::allocate({1}, Tensor::FP32, Device::cpu(), true);
    *this->wtscale.data_ptr<float>() = 1.0f;

    this->wcscales = Tensor::allocate({0}, dtype, device, true);

    registerParams
        (qweight, "qweight", ParamFlags::LazyLoad)
        (wscales, "wscales")
        (this->bias, "bias")
        (lora_down, "lora_down", ParamFlags::Optional)
        (lora_up, "lora_up", ParamFlags::Optional)
        (smooth, "smooth")
        (wtscale, "wtscale", ParamFlags::Optional)
        (wcscales, "wcscales", ParamFlags::Optional)
    ;

#if NO_LORA_FUSION
    checkCUBLAS(cublasCreate(&handle));
#endif
}

void GEMM_W4A4::loadParam(std::string key, Tensor &dst, Tensor src) {
    if (key == "lora_down" || key == "lora_up") {
        assert(src.ndims() == 2);
        if (dst.shape.dataExtent != src.shape.dataExtent) {
            dst = Tensor::allocate(src.shape.dataExtent, dst.scalar_type(), this->device);
            Module::loadParam(key, dst, src);
            this->lora_rank = dst.shape[1];
            this->lora_scales.resize(ceilDiv(this->lora_rank, 16), 1.0f);
        } else {
            Module::loadParam(key, dst, src);
        }
    } else if (key == "wcscales") {
        assert(src.ndims() == 1);
        assert(src.shape[0] == out_features_pad);
        dst = Tensor::allocate(src.shape.dataExtent, dst.scalar_type(), this->device);
        Module::loadParam(key, dst, src);
    } else if (key == "wtscale") {
        assert(src.numel() == 1);
        if (src.dtype() == Tensor::BF16) {
            *dst.data_ptr<float>() = float(*src.data_ptr<__nv_bfloat16>());
        } else if (src.dtype() == Tensor::FP16) {
            *dst.data_ptr<float>() = float(*src.data_ptr<half>());
        } else if (src.dtype() == Tensor::FP32) {
            Module::loadParam(key, dst, src);
        } else {
            assert(false);
        }
    } else {
        Module::loadParam(key, dst, src);
    }
}

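// Convenience wrappers: quantize the input and run the fused GEMM, without (forward)
// or with (forward_silu) the fused SiLU epilogue.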
Tensor GEMM_W4A4::forward(Tensor x) {
    return std::get<Tensor>(this->forward(x, FuseOptions::EMPTY, nullptr));
}

Tensor GEMM_W4A4::forward_silu(Tensor x) {
    return std::get<Tensor>(this->forward(x, FuseOptions::SILU, nullptr));
}

std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward(Tensor x, FuseOptions fuse, GEMM_W4A4 *nextGEMM) {
    return forward_quant(quantize(x, false), fuse, nextGEMM);
}

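// QKV-projection variant: writes into a caller-provided output buffer and lets the fused
// kernel apply norm_q / norm_k, rotary embeddings, and the split into out_q / out_k / out_v.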
void GEMM_W4A4::forward(Tensor x, Tensor out, Tensor pool, Tensor norm_q, Tensor norm_k, Tensor rotary_emb, Tensor out_q, Tensor out_k, Tensor out_v, int numTokens) {
    QuantizedActivation qact = quantize(x, false);

#if !NO_LORA_FUSION

#if 0
    Tensor dummy = Tensor::empty_like(qact.lora_act);
    dummy.zero_();

    gemm_w4a4(qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, dummy, this->lora_up, {}, {}, norm_q, norm_k, rotary_emb, this->bias, {}, qact.is_unsigned);
    debug("gemm.nolora.out", out);
#endif

    kernels::gemm_w4a4(
        qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, qact.lora_act, this->lora_up, {}, {}, norm_q, norm_k, rotary_emb, this->bias, {}, {}, {}, qact.is_unsigned, this->lora_scales, false,
        use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales: Tensor{},
        out_q, out_k, out_v, numTokens
    );

    debug("gemm.out", out);
#else
    const int M = (int)qact.act.numel() / qact.act.shape[-1];

    kernels::gemm_w4a4(qact.act, qweight, out, {}, qact.ascales, wscales, {}, pool, {}, {}, {}, {}, norm_q, norm_k, rotary_emb, this->bias, {}, qact.is_unsigned, this->lora_scales);

    nvtxRangePushA("LoraUp");

    static const half one = 1.0;
    static const half zero = 0.0;
    // lora_up: [M, R] * [OC, R]^T => [M, OC]
    // cublas view: [OC, R] * [M, R]^T
    checkCUBLAS(cublasHgemm(
        handle, 
        CUBLAS_OP_T, CUBLAS_OP_N, 
        this->out_features, M, this->lora_rank,
        &one,
        this->lora_up.data_ptr<half>(),
        this->lora_rank,
        qact.lora_act.data_ptr<half>(),
        this->lora_rank,
        &one, 
        out.data_ptr<half>(),
        this->out_features));

    nvtxRangePop();
#endif
}

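// Core fused path. With EMPTY or SILU the result is dequantized into a regular tensor;
// otherwise the epilogue re-quantizes the activation for nextGEMM, folding in its
// lora_down projection and smoothing factor so the next layer can consume it directly.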
std::variant<Tensor, GEMM_W4A4::QuantizedActivation> GEMM_W4A4::forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *nextGEMM) {
    Tensor out;
    QuantizedActivation qout;

    Tensor next_lora;
    Tensor next_smooth;

    const int M = (int)qact.act.numel() / qact.act.shape[-1];

    if (fuse == FuseOptions::EMPTY || fuse == FuseOptions::SILU) {
        // auto shape = TensorShape(qact.act.shape.dataExtent);
        // shape[-1] = out_features;
        auto shape = TensorShape(qact.actShape.dataExtent);
        shape[-1] = out_features;
        out = Tensor::allocate(shape, dtype, device);
    } else {
        qout.act = Tensor::allocate({M, out_features_pad / 2}, Tensor::INT8, device);
        if (use_fp4) {
            qout.ascales = Tensor::allocate({out_features_pad / 16, M}, Tensor::FP8_E4M3, device);
        } else {
            qout.ascales = Tensor::allocate({out_features_pad / 64, M}, dtype, device);
        }
        qout.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, device);
        qout.is_unsigned = !use_fp4;
        qout.actShape = qact.actShape;

        next_lora = nextGEMM->lora_down;
        next_smooth = nextGEMM->smooth;
    }

#if !NO_LORA_FUSION

#if 0
    Tensor dummy = Tensor::empty_like(qact.lora_act);
    dummy.zero_();

    gemm_w4a4(qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, dummy, this->lora_up, next_lora, qout.lora_act, {}, {}, {}, this->bias, next_smooth, qact.is_unsigned);

    if (fuse == FuseOptions::EMPTY) {
        debug("gemm.nolora.out", out);
    } else {
        debug("gemm.nolora.qout", qout.act);
        debug("gemm.nolora.oscales", qout.ascales);
        debug("gemm.nolora.lora_act_out", qout.lora_act);
    }
#endif

    kernels::gemm_w4a4(
        qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, qact.lora_act, this->lora_up, next_lora, qout.lora_act, {}, {}, {}, this->bias, next_smooth, {}, {}, qact.is_unsigned, this->lora_scales, fuse == FuseOptions::SILU,
        use_fp4, *this->wtscale.data_ptr<float>(), wcscales.numel() > 0 ? wcscales: Tensor{},
        {}, {}, {}, 0
    );

    if (fuse == FuseOptions::EMPTY || fuse == FuseOptions::SILU) {
        debug("gemm.out", out);
    } else {
        debug("gemm.qout", qout.act);
        debug("gemm.oscales", qout.ascales);
        debug("gemm.lora_act_out", qout.lora_act);
    }

    
#else
    if (!out.valid()) {
        auto shape = TensorShape(qact.act.shape.dataExtent);
        shape[-1] = out_features;
        out = Tensor::allocate(shape, Tensor::FP16, qweight.device());
    }

    kernels::gemm_w4a4(qact.act, qweight, out, qout.act, qact.ascales, wscales, qout.ascales, {}, {}, {}, {}, {}, {}, {}, {}, this->bias, next_smooth, qact.is_unsigned, this->lora_scales);

    nvtxRangePushA("LoraUp");

    static const half one = 1.0;
    static const half zero = 0.0;

    // lora_up: [M, R] * [OC, R]^T => [M, OC]
    // cublas view: [R, OC]^T * [R, M] => [OC, M]
    // lora_up layout wrong?
    checkCUBLAS(cublasHgemm(
        handle, 
        CUBLAS_OP_T, CUBLAS_OP_N, 
        this->out_features, M, this->lora_rank,
        &one,
        this->lora_up.data_ptr<half>(),
        this->lora_rank,
        qact.lora_act.data_ptr<half>(),
        this->lora_rank,
        &one, 
        out.data_ptr<half>(),
        this->out_features));

    nvtxRangePop();

    if (fuse == FuseOptions::GELU_QUANT) {
        nvtxRangePushA("LoraDown");
        // IC here is the input dim of the next layer's LoRA down (i.e. the OC of this layer)
        // lora_down: [M, IC] * [IC, R] => [M, R]
        // cublas view: [R, IC] * [IC, M] => [R, M]
        checkCUBLAS(cublasHgemm(
            handle, 
            CUBLAS_OP_N, CUBLAS_OP_N, 
            this->lora_rank, M, this->out_features,
            &one,
            next_lora.data_ptr<half>(),
            this->lora_rank,
            out.data_ptr<half>(),
            this->out_features,
            &zero, 
            qout.lora_act.data_ptr<half>(),
            this->lora_rank));

        out = {};

        nvtxRangePop();
    }

#endif
    if (out.valid()) {
        return out;
    }
    return qout;
}

Tensor GEMM_W4A4::forward_quant(QuantizedActivation qact) {
    return std::get<Tensor>(this->forward_quant(qact, FuseOptions::EMPTY, nullptr));
}

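// Quantizes activations to 4 bits with per-group scales and, in the fused kernel, also
// computes the LoRA-down projection and applies the smoothing factor. The token count M
// is padded to a multiple of 256 (assumed to match the kernel's tile granularity).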
GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
    const int actualM = x.numel() / x.shape[-1];
    const int M = ceilDiv(actualM, 256) * 256;

    // auto shape = TensorShape(x.shape.dataExtent);
    // shape[-1] = in_features / 2;

    QuantizedActivation qact;
    qact.act = Tensor::allocate({M, in_features_pad / 2}, Tensor::INT8, device);
    if (use_fp4) {
        qact.ascales = Tensor::allocate({in_features_pad / 16, M}, Tensor::FP8_E4M3, device);
    } else {
        qact.ascales = Tensor::allocate({in_features_pad / 64, M}, dtype, device);
    }
    qact.lora_act = Tensor::allocate({M, lora_rank}, Tensor::FP32, device);
    qact.is_unsigned = false;
    qact.actShape = x.shape.dataExtent;

#if !NO_LORA_FUSION
    debug("quantize.x", x);
    debug("quantize.smooth", this->smooth);

    kernels::quantize_w4a4_act_fuse_lora(x, qact.act, qact.ascales, this->lora_down, qact.lora_act, this->smooth, fuse_glu, use_fp4);

    debug("quantize.qact", qact.act);
    debug("quantize.ascales", qact.ascales);
    debug("quantize.lora_act", qact.lora_act);
#else 
    static const half one = 1.0;
    static const half zero = 0.0;

    nvtxRangePushA("LoraDown");

    // lora_down: [M, IC] * [IC, R] => [M, R]
    // cublas view: [R, IC] * [IC, M]
    checkCUBLAS(cublasHgemm(
        handle, 
        CUBLAS_OP_N, CUBLAS_OP_N, 
        this->lora_rank, M, this->in_features,
        &one,
        lora_down.data_ptr<half>(),
        this->lora_rank,
        x.data_ptr<half>(),
        this->in_features,
        &zero, 
        qact.lora_act.data_ptr<half>(),
        this->lora_rank));

    nvtxRangePop();

    kernels::quantize_w4a4_act(x, qact.act, qact.ascales);

#endif

    return qact;
}

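// GEMM_W8A8: 8-bit weight x 8-bit activation GEMM with one weight scale per output
// channel and one activation scale per token.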
GEMM_W8A8::GEMM_W8A8(int in_features, int out_features, bool bias, Tensor::ScalarType dtype, Device device) : 
    in_features(in_features), out_features(out_features), dtype(dtype)
{
    this->qweight = Tensor::allocate({out_features, in_features}, Tensor::INT8, device);
    this->wscales = Tensor::allocate({out_features}, dtype, device);
    this->bias = bias ? Tensor::allocate({out_features}, dtype, device, true) : Tensor{};

    registerParams
        (qweight, "qweight", ParamFlags::LazyLoad)
        (wscales, "wscales")
        (this->bias, "bias")
    ;
}

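// Quantizes activations to INT8 with a per-token scale; with fuse_glu the GLU is applied
// inside the kernel, halving the channel dimension of the quantized output.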
GEMM_W8A8::QuantizedActivation GEMM_W8A8::quantize(Tensor x, bool fuse_glu) {
    QuantizedActivation qact;
    auto qshape = x.shape;
    if (fuse_glu) {
        qshape[-1] /= 2;
    }
    qact.act = Tensor::allocate(qshape, Tensor::INT8, x.device());
    qact.ascales = Tensor::allocate({(int)x.numel() / x.shape[-1]}, this->dtype, x.device());

    debug("quantize.x", x);

    kernels::quantize_w8a8_act(x, qact.act, qact.ascales, fuse_glu);

    debug("quantize.qact", qact.act);
    debug("quantize.ascales", qact.ascales);

    return qact;
}

Tensor GEMM_W8A8::forward_quant(QuantizedActivation qact) {
    auto shape = TensorShape(qact.act.shape.dataExtent);
    shape[-1] = out_features;
    Tensor out = Tensor::allocate(shape, this->dtype, qact.act.device());
    kernels::gemm_w8a8(qact.act, this->qweight, out, qact.ascales, this->wscales, this->bias);

    debug("gemm.out", out);
    return out;
}

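// DWCONV: depthwise 3x3 convolution (one 3x3 filter per channel), dispatched to dwconv_f16.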
DWCONV::DWCONV(int in_features, bool use_bias, Tensor::ScalarType dtype, Device device) : 
    in_features(in_features)
{
    this->weight = Tensor::allocate({in_features, 3, 3, 1}, dtype, device);
    this->bias = use_bias ? Tensor::allocate({in_features}, dtype, device) : Tensor{};

    registerParams
        (this->weight, "weight")
        (this->bias, "bias")
    ;
}

Tensor DWCONV::forward(Tensor x) {
    return dwconv_f16(x, this->weight, {}, this->bias);
}