gemm_w4a4.h 1.94 KB
Newer Older
Zhekai Zhang's avatar
Zhekai Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#pragma once

#include "common.h"
#include "Tensor.h"

void gemm_w4a4( 
        Tensor act,           // packed act [M, K / 2]
        Tensor wgt,           // packed act [N, K / 2]
        Tensor out,           // linear     [M, N]
        Tensor qout,          // packed act [M, N / 2]
        Tensor ascales,       // packed as  [K / 64, M]
        Tensor wscales,       // packed ws  [K / 64, N]
        Tensor oscales,       // packed as  [N / 64, M]
        Tensor poolout,       // linear     [M / PoolSize, N]
        Tensor lora_act_in,   // packed lora_act [M, R]
        Tensor lora_up,       // packed lora_wgt [N, R]
        Tensor lora_down,     // packed lora_wgt [N, R]
        Tensor lora_act_out,  // packed lora_act [M, R]
        Tensor norm_q,        // linear     [HEAD_DIM]
        Tensor norm_k,        // linear     [HEAD_DIM]
        Tensor rotary_emb,    // linear     [M, HEAD_DIM / 2, 2, 2]
        Tensor bias,          // packed ws  [N]
        Tensor smooth_factor, // packed ws  [N], for quantization of the next layer
        bool act_unsigned,
        std::vector<float> lora_scales  // [R / 16]
);

void quantize_w4a4_act_fuse_lora(Tensor input, Tensor output, Tensor oscales, Tensor lora_down, Tensor lora_act_out, Tensor smooth = {});
void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales);
void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales);

void gemm_w8a8(Tensor act,      // [M, K]
               Tensor wgt,      // [N, K]
               Tensor out,      // [M, N]
               Tensor ascales,  // [1, M]
               Tensor wscales   // [1, N]
               );

void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_glu);

void gemm_w8a8_fuse_litela(
    Tensor act,      // [B, (M), K]
    Tensor wgt,      // [N, K]
    Tensor out_q,    // [B, (M), N / 3]
    Tensor out_vk,   // [B, num_heads, head_dim + 1, head_dim]
    Tensor ascales,  // [1, M]
    Tensor wscales   // [1, N]
);