moe_ck.h 8.47 KB
Newer Older
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
// SPDX-License-Identifier: MIT
 
#pragma once

#include <torch/extension.h>

/// @brief Fused mixture-of-experts (MoE) entry point.
///
/// Shape letters used below: m = number of input tokens, k = hidden size,
/// n = per-expert intermediate size, e = number of experts. Weight tensors are
/// expected in pre-shuffled layout ([e, nr, kr, w]).
///
/// @param hidden_states        [m, k] input tokens.
/// @param w1                   [e, n, k] or [e, 2*n, k] gate(up) weights, pre-shuffled.
/// @param w2                   Down weights, pre-shuffled. NOTE(review): previously
///                             documented as [e, n, k], but the w2_zp / w2_scale shapes
///                             below ([e, k, n/group], [e, 1, k]) suggest [e, k, n] — confirm.
/// @param topk_weights         [tokens, topk] top-k weights.
/// @param topk_ids             [tokens, topk] top-k expert ids.
/// @param use_int8_w8a16       Use int8 w8a16 quantization.
/// @param use_int4_w4a16       Use int4 w4a16 quantization.
/// @param use_int8_w8a8_block  Use int8 w8a8 block quantization.
/// @param use_int4_w4a8_block  Use int4 w4a8 block quantization.
/// @param w1_zp                [e, 2*n, k/group] gate(up) zero-point.
/// @param w2_zp                [e, k, n/group] down zero-point.
/// @param w1_scale             [e, 1, n] gate(up) scale.
/// @param w2_scale             [e, 1, k] down scale.
/// @param a1_scale             [m, 1] per-token scale.
/// @param a2_scale             [e, 1, n] smooth-quant scale for the second GEMM's input.
/// @param block_shape_n        Quantization block size along n.
/// @param block_shape_k        Quantization block size along k.
/// @param block_m              MoE partition size for tokens along m (defaults to 32).
/// @param solution_id          Kernel solution id (defaults to 0).
/// @param expert_mask          Optional expert mask (defaults to none).
///                             NOTE(review): mask semantics are not visible in this header.
/// @return The MoE output tensor. NOTE(review): presumably [m, k] — confirm against the definition.
torch::Tensor ck_moe(
    torch::Tensor &hidden_states,
    torch::Tensor &w1,
    torch::Tensor &w2,
    torch::Tensor &topk_weights,
    torch::Tensor &topk_ids,
    std::optional<bool> use_int8_w8a16,
    std::optional<bool> use_int4_w4a16,
    std::optional<bool> use_int8_w8a8_block,
    std::optional<bool> use_int4_w4a8_block,
    std::optional<torch::Tensor> w1_zp,
    std::optional<torch::Tensor> w2_zp,
    std::optional<torch::Tensor> w1_scale,
    std::optional<torch::Tensor> w2_scale,
    std::optional<torch::Tensor> a1_scale,
    std::optional<torch::Tensor> a2_scale,
    std::optional<int> block_shape_n,
    std::optional<int> block_shape_k,
    std::optional<int> block_m = 32,
    std::optional<int> solution_id = 0,
    std::optional<torch::Tensor> expert_mask = std::nullopt);
/// @brief MoE entry point with the same parameter list as ck_moe.
///
/// NOTE(review): how this differs from ck_moe (presumably the weight
/// shuffle/layout handling implied by the name) is not visible in this
/// header — confirm against the definition.
///
/// Shape letters: m = tokens, k = hidden size, n = intermediate size, e = experts.
///
/// @param hidden_states        [m, k] input tokens.
/// @param w1                   [e, n, k] or [e, 2*n, k] gate(up) weights, pre-shuffled ([e, nr, kr, w]).
/// @param w2                   Down weights, pre-shuffled ([e, nr, kr, w]). NOTE(review): previously
///                             documented as [e, n, k]; zero-point shape [e, k, n/group] suggests
///                             [e, k, n] — confirm.
/// @param topk_weights         [tokens, topk] top-k weights.
/// @param topk_ids             [tokens, topk] top-k expert ids.
/// @param use_int8_w8a16       Use int8 w8a16 quantization.
/// @param use_int4_w4a16       Use int4 w4a16 quantization.
/// @param use_int8_w8a8_block  Use int8 w8a8 block quantization.
/// @param use_int4_w4a8_block  Use int4 w4a8 block quantization.
/// @param w1_zp                [e, 2*n, k/group] gate(up) zero-point.
/// @param w2_zp                [e, k, n/group] down zero-point.
/// @param w1_scale             Gate(up) scale, e.g. [e, 1, n]. NOTE(review): the original comment
///                             was truncated ("or ..."); other layouts may be accepted — confirm.
/// @param w2_scale             Down scale, e.g. [e, 1, k]. NOTE(review): same truncation as w1_scale.
/// @param a1_scale             [m, 1] per-token scale.
/// @param a2_scale             [e, 1, n] smooth-quant scale for the second GEMM's input.
/// @param block_shape_n        Quantization block size along n.
/// @param block_shape_k        Quantization block size along k.
/// @param block_m              MoE partition size for tokens along m (defaults to 32).
/// @param solution_id          Kernel solution id (defaults to 0).
/// @param expert_mask          Optional expert mask (defaults to none).
/// @return The MoE output tensor. NOTE(review): presumably [m, k] — confirm.
torch::Tensor ck_shuffle_moe(
    torch::Tensor &hidden_states,
    torch::Tensor &w1,
    torch::Tensor &w2,
    torch::Tensor &topk_weights,
    torch::Tensor &topk_ids,
    std::optional<bool> use_int8_w8a16,
    std::optional<bool> use_int4_w4a16,
    std::optional<bool> use_int8_w8a8_block,
    std::optional<bool> use_int4_w4a8_block,
    std::optional<torch::Tensor> w1_zp,
    std::optional<torch::Tensor> w2_zp,
    std::optional<torch::Tensor> w1_scale,
    std::optional<torch::Tensor> w2_scale,
    std::optional<torch::Tensor> a1_scale,
    std::optional<torch::Tensor> a2_scale,
    std::optional<int> block_shape_n,
    std::optional<int> block_shape_k,
    std::optional<int> block_m = 32,
    std::optional<int> solution_id = 0,
    std::optional<torch::Tensor> expert_mask = std::nullopt);

/// @brief Query kernel solutions for an MoE problem.
///
/// Takes the same inputs as ck_moe except `solution_id`.
/// NOTE(review): the returned ints are presumably the candidate ids that can be
/// passed as ck_moe's `solution_id` — confirm against the definition.
///
/// Shape letters: m = tokens, k = hidden size, n = intermediate size, e = experts.
///
/// @param hidden_states        [m, k] input tokens.
/// @param w1                   [e, n, k] or [e, 2*n, k] gate(up) weights, pre-shuffled ([e, nr, kr, w]).
/// @param w2                   Down weights, pre-shuffled ([e, nr, kr, w]). NOTE(review): previously
///                             documented as [e, n, k] — confirm orientation.
/// @param topk_weights         [tokens, topk] top-k weights.
/// @param topk_ids             [tokens, topk] top-k expert ids.
/// @param use_int8_w8a16       Use int8 w8a16 quantization.
/// @param use_int4_w4a16       Use int4 w4a16 quantization.
/// @param use_int8_w8a8_block  Use int8 w8a8 block quantization.
/// @param use_int4_w4a8_block  Use int4 w4a8 block quantization.
/// @param w1_zp                [e, 2*n, k/group] gate(up) zero-point.
/// @param w2_zp                [e, k, n/group] down zero-point.
/// @param w1_scale             Gate(up) scale, e.g. [e, 1, n] (original comment truncated — confirm other layouts).
/// @param w2_scale             Down scale, e.g. [e, 1, k] (original comment truncated — confirm other layouts).
/// @param a1_scale             [m, 1] per-token scale.
/// @param a2_scale             [e, 1, n] smooth-quant scale for the second GEMM's input.
/// @param block_shape_n        Quantization block size along n.
/// @param block_shape_k        Quantization block size along k.
/// @param block_m              MoE partition size for tokens along m (defaults to 32).
/// @param expert_mask          Optional expert mask (defaults to none).
/// @return A list of integer solution ids.
std::vector<int> ck_moe_get_solutions(
    torch::Tensor &hidden_states,
    torch::Tensor &w1,
    torch::Tensor &w2,
    torch::Tensor &topk_weights,
    torch::Tensor &topk_ids,
    std::optional<bool> use_int8_w8a16,
    std::optional<bool> use_int4_w4a16,
    std::optional<bool> use_int8_w8a8_block,
    std::optional<bool> use_int4_w4a8_block,
    std::optional<torch::Tensor> w1_zp,
    std::optional<torch::Tensor> w2_zp,
    std::optional<torch::Tensor> w1_scale,
    std::optional<torch::Tensor> w2_scale,
    std::optional<torch::Tensor> a1_scale,
    std::optional<torch::Tensor> a2_scale,
    std::optional<int> block_shape_n,
    std::optional<int> block_shape_k,
    std::optional<int> block_m = 32,
    std::optional<torch::Tensor> expert_mask = std::nullopt);

/// @brief Per-token quantization helper for the MoE path.
///
/// Results go into the caller-provided output tensors; nothing is returned.
///
/// @param input      [m, k] input tokens.
/// @param out_quant  [m, k] destination for the quantized tokens.
/// @param out_scale  [m, 1] destination for the per-token scales.
void ck_moe_per_token_quant(
    torch::Tensor &input,
    torch::Tensor &out_quant,
    torch::Tensor &out_scale);

/// @brief First stage of the two-stage MoE path (operates on sorted token/expert ids).
///
/// Results are written into `out`; nothing is returned.
///
/// @param hidden_states               [m, k] input tokens.
/// @param w1                          [e, n, k] or [e, 2*n, k] gate(up) weights, pre-shuffled ([e, nr, kr, w]).
/// @param w2                          Down weights, pre-shuffled ([e, nr, kr, w]). NOTE(review): previously
///                                    documented as [e, n, k]; why stage 1 needs w2 is not visible here.
/// @param sorted_token_ids            [max_num_tokens_padded] sorted token ids.
/// @param sorted_expert_ids           [max_num_m_blocks] expert id per m-block.
/// @param tokens_positions_per_expert [num_experts*2].
/// @param num_valid_ids               [1] count of valid ids.
/// @param out                         [max_num_tokens_padded, inter_dim] output buffer.
/// @param topk                        Number of experts selected per token.
/// @param use_int8_w8a8_block         Use int8 w8a8 block quantization.
/// @param use_fp8_w8a8_block          Use fp8 w8a8 block quantization.
/// @param w1_scale                    [e, 1, n] gate(up) scale.
/// @param a1_scale                    [m, 1] per-token scale.
/// @param block_shape_n               Quantization block size along n.
/// @param block_shape_k               Quantization block size along k.
/// @param block_m                     NOTE(review): undocumented; presumably the m-partition size as in ck_moe.
/// @param sorted_weights              NOTE(review): undocumented; presumably [max_num_tokens_padded] like stage 2.
/// @param act_op                      NOTE(review): undocumented; presumably selects the activation — confirm.
void ck_moe_stage_1(
    torch::Tensor &hidden_states,
    torch::Tensor &w1,
    torch::Tensor &w2,
    torch::Tensor &sorted_token_ids,
    torch::Tensor &sorted_expert_ids,
    torch::Tensor &tokens_positions_per_expert,
    torch::Tensor &num_valid_ids,
    torch::Tensor &out,
    int topk,
    std::optional<bool> use_int8_w8a8_block,
    std::optional<bool> use_fp8_w8a8_block,
    std::optional<torch::Tensor> w1_scale,
    std::optional<torch::Tensor> a1_scale,
    std::optional<int> block_shape_n,
    std::optional<int> block_shape_k,
    std::optional<int> block_m,
    std::optional<torch::Tensor> sorted_weights,
    std::optional<int> act_op);

// Second stage of the two-stage MoE path. It mirrors ck_moe_stage_1 but takes
// `inter_states` and the w2 / a2 scales. Results are written into `out`;
// nothing is returned.
// The per-parameter comments below replace ones copy-pasted from stage 1.
// Those mislabelled w2_scale as the gate(up) scale, called inter_states
// "input token", and left a stray ");" in the last comment.
void ck_moe_stage_2(torch::Tensor &inter_states,      // intermediate states; presumably stage-1 output — NOTE(review): old comment said [m, k], confirm shape
                    torch::Tensor &w1,                // [e, n, k]/[e, 2*n, k] gate(up) weights, pre-shuffle([e, nr, kr, w])
                    torch::Tensor &w2,                // down weights, pre-shuffle([e, nr, kr, w]); NOTE(review): old comment said [e, n, k], confirm orientation
                    torch::Tensor &sorted_token_ids,  // [max_num_tokens_padded]
                    torch::Tensor &sorted_expert_ids, // [max_num_m_blocks]
                    torch::Tensor &tokens_positions_per_expert, // [num_experts*2]
                    torch::Tensor &num_valid_ids,     // [1]
                    torch::Tensor &out,               // output buffer; NOTE(review): old comment reused stage 1's [max_num_tokens_padded, inter_dim], confirm
                    int topk,                         // number of experts selected per token
                    std::optional<bool> use_int8_w8a8_block,// use int8 w8a8 block quantization
                    std::optional<bool> use_fp8_w8a8_block, // use fp8 w8a8 block quantization
                    std::optional<torch::Tensor> w2_scale, // down scale; NOTE(review): ck_moe documents w2_scale as [e, 1, k], old comment here said [e, 1, n]
                    std::optional<torch::Tensor> a2_scale, // scale for the 2nd GEMM input; NOTE(review): old comment said [m, 1] token scale, ck_moe says [e, 1, n]
                    std::optional<int> block_shape_n,      // quant block n size
                    std::optional<int> block_shape_k,      // quant block k size
                    std::optional<int> block_m,            // NOTE(review): presumably the moe partition size in m, as in ck_moe
                    std::optional<torch::Tensor> sorted_weights);    // [max_num_tokens_padded]