#pragma once

#include "common.h"
#include "Tensor.h"

namespace nunchaku::kernels {

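// Fused W4A4 GEMM: out = act @ wgt^T at the core, with optional epilogues
// selected by passing non-empty tensors: LoRA up/down projection, bias, SiLU,
// re-quantization of the output for the next layer (qout/oscales/smooth_factor),
// QK RMSNorm + RoPE (norm_q/norm_k/rotary_emb), attention-layout packing
// (out_q/out_k/out_v), and linear-attention vk accumulation
// (out_vk/out_linearattn). In the shape comments below, M = tokens,
// N = output features, K = input features, R = LoRA rank.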
void gemm_w4a4(Tensor act,            // packed act [M, K / 2]
               Tensor wgt,            // packed wgt [N, K / 2]
               Tensor out,            // linear     [M, N]
               Tensor qout,           // packed act [M, N / 2]
               Tensor ascales,        // packed as  [K / 64, M]
               Tensor wscales,        // packed ws  [K / 64, N]
               Tensor oscales,        // packed as  [N / 64, M]
               Tensor poolout,        // linear     [M / PoolSize, N]
               Tensor lora_act_in,    // packed lora_act [M, R]
               Tensor lora_up,        // packed lora_wgt [N, R]
               Tensor lora_down,      // packed lora_wgt [N, R]
               Tensor lora_act_out,   // packed lora_act [M, R]
               Tensor norm_q,         // linear     [HEAD_DIM]
               Tensor norm_k,         // linear     [HEAD_DIM]
               Tensor rotary_emb,     // linear     [M, HEAD_DIM / 2, 2, 2]
               Tensor bias,           // packed ws  [N]
               Tensor smooth_factor,  // packed ws  [N], for quantization of the next layer
               Tensor out_vk,         // linear     [B, num_heads, head_dim + 1, head_dim]
               Tensor out_linearattn, // linear     [B, (M), N / 3]
               bool act_unsigned,
               std::vector<float> lora_scales, // [R / 16]
               bool fuse_silu,
               bool fp4,
               float alpha,
               Tensor wcscales,
               Tensor out_q, // packed attention [B, H, M, D]
               Tensor out_k, // packed attention [B, H, M, D]
               Tensor out_v, // packed attention [B, H, M, D]
               int attn_tokens);
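
// A minimal call sketch (illustrative only: `{}` for an empty Tensor is an
// assumption about the Tensor API, and every fusion input is left unused),
// computing a plain quantized matmul out = act @ wgt^T:
//
//   gemm_w4a4(act, wgt, out,
//             /*qout=*/{}, ascales, wscales, /*oscales=*/{}, /*poolout=*/{},
//             /*lora_act_in=*/{}, /*lora_up=*/{}, /*lora_down=*/{},
//             /*lora_act_out=*/{}, /*norm_q=*/{}, /*norm_k=*/{},
//             /*rotary_emb=*/{}, /*bias=*/{}, /*smooth_factor=*/{},
//             /*out_vk=*/{}, /*out_linearattn=*/{},
//             /*act_unsigned=*/false, /*lora_scales=*/{}, /*fuse_silu=*/false,
//             /*fp4=*/false, /*alpha=*/1.0f, /*wcscales=*/{},
//             /*out_q=*/{}, /*out_k=*/{}, /*out_v=*/{}, /*attn_tokens=*/0);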
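
// Inferred from the name and the out_vk shape above (a hedged reading, not
// documented in the source): completes the linear-attention output by
// multiplying q with the accumulated vk state, whose extra `head_dim + 1`-th
// row carries the normalizer.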
void linearattn_vk_mul_q(Tensor q, Tensor vk);

void quantize_w4a4_act_fuse_lora(Tensor input,
                                 Tensor output,
                                 Tensor oscales,
                                 Tensor lora_down,
                                 Tensor lora_act_out,
                                 Tensor smooth = {},
                                 bool fuse_glu = false,
                                 bool fp4      = false);
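
// Assumed pairing (from the parameter names): fuses activation quantization
// with the LoRA down-projection so its outputs feed gemm_w4a4 directly, e.g.
//
//   quantize_w4a4_act_fuse_lora(x, act_q, ascales, lora_down, lora_act);
//   gemm_w4a4(act_q, wgt, out, ..., /*lora_act_in=*/lora_act, ...);
//
// (abbreviated call; see the full argument list sketched above).
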
void quantize_w4a4_act(Tensor input, Tensor output, Tensor oscales);
void quantize_w4a4_wgt(Tensor input, Tensor output, Tensor oscales);
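
// Assumed usage split: weights are packed once at load time, activations on
// every forward pass, e.g.
//
//   quantize_w4a4_wgt(wgt_fp16, wgt_q, wscales); // once, at load time
//   quantize_w4a4_act(act_fp16, act_q, ascales); // per forward pass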

void gemm_w8a8(Tensor act,     // [M, K]
               Tensor wgt,     // [N, K]
               Tensor out,     // [M, N]
               Tensor ascales, // [1, M]
               Tensor wscales, // [1, N]
               Tensor bias     // packed ws  [N]
);

void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_glu);
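
// A minimal W8A8 round trip (hedged sketch; pre-allocated buffers and a packed
// int8 weight `wgt_q` are assumed):
//
//   quantize_w8a8_act(x, act_q, ascales, /*fuse_glu=*/false);
//   gemm_w8a8(act_q, wgt_q, out, ascales, wscales, /*bias=*/{});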

// void gemm_w8a8_fuse_litela(
//     Tensor act,      // [B, (M), K]
//     Tensor wgt,      // [N, K]
//     Tensor out_q,    // [B, (M), N / 3]
//     Tensor out_vk,   // [B, num_heads, head_dim + 1, head_dim]
//     Tensor ascales,  // [1, M]
//     Tensor wscales   // [1, N]
// );

void attention_fp16(Tensor q, // packed [Batch, Head, TokensQ, HEAD_DIM]
                    Tensor k, // packed [Batch, Head, TokensKV, HEAD_DIM]
                    Tensor v, // packed [Batch, Head, TokensKV, HEAD_DIM]
                    Tensor o, // linear [Batch, TokensQ, Head * HEAD_DIM]
                    float scale);
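
// Typical call (sketch): the usual softmax scale for head size HEAD_DIM is
// 1/sqrt(HEAD_DIM), e.g.
//
//   attention_fp16(q, k, v, o, 1.0f / std::sqrt(float(HEAD_DIM)));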

// EXPERIMENTAL, for sm_75
void set_faster_i2f_mode(std::string mode);

// FOR TEST ONLY
void test_rmsnorm_rope(Tensor input, Tensor output, Tensor norm_q, Tensor norm_k, Tensor rotary_emb);
void test_pack_qkv(Tensor input, Tensor out_q, Tensor out_k, Tensor out_v, int numTokens);

}; // namespace nunchaku::kernels