moe_asm.h 5.17 KB
Newer Older
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#pragma once
// SPDX-License-Identifier: MIT

#include <torch/extension.h>
#include "aiter_enum.h"

// Declaration only: the definition of asm_fmoe_stage1 is not in this header.
// NOTE(review): the name suggests the first stage of an assembly-backed fused
// MoE kernel that writes its result into `out` -- confirm against the definition.
//
// Every tensor is taken by non-const lvalue reference, so callers must pass
// lvalues (temporaries will not bind). The std::optional arguments are by value.
// Defaults: the three optional tensors default to std::nullopt, whereas the
// std::optional<int> arguments default to *engaged* values (0, 0, 16, 0), so
// omitting one of them is not the same as passing std::nullopt.
// Bracketed shapes are the original author's annotations, kept verbatim.
void asm_fmoe_stage1(torch::Tensor &out,               // [token_cnt, dim]
              torch::Tensor &input,             // [token_cnt, dim] M,K
              torch::Tensor &gate,              // [expert, hidden_dim, dim] N,K
              torch::Tensor &down,              // [expert, hidden_dim, dim]
              torch::Tensor &sorted_token_ids,  // [max_num_tokens_padded]
              torch::Tensor &sorted_weights,    // [max_num_tokens_padded]
              torch::Tensor &sorted_expert_ids, // [max_num_m_blocks]
              torch::Tensor &num_valid_ids,                     // shape not documented here -- TODO confirm
              uint32_t top_k, // presumably the number of experts chosen per token -- TODO confirm
              std::optional<torch::Tensor> scale_a = std::nullopt, // omitted => no tensor supplied
              std::optional<torch::Tensor> scale_b = std::nullopt, // omitted => no tensor supplied
              std::optional<torch::Tensor> zero_points = std::nullopt, // omitted => no tensor supplied
              std::optional<int> mode = 0, // defaults to 0; meaning of the value is not visible here
              std::optional<int> solidx = 0, // defaults to 0; presumably a solution index -- TODO confirm
              std::optional<int> block_size = 16, // defaults to 16; unit/axis not visible here
              std::optional<int> persist_groups = 0 // defaults to 0; meaning not visible here
);

// Declaration only: the definition of asm_fmoe_stage2 is not in this header.
// The parameter list (names, types, order, defaults) is the same as the one
// declared for asm_fmoe_stage1.
// NOTE(review): the name suggests the second stage of the same fused MoE
// pipeline -- confirm against the definition.
//
// Every tensor is taken by non-const lvalue reference, so callers must pass
// lvalues. The std::optional<int> arguments default to *engaged* values
// (0, 0, 16, 0), not std::nullopt; only the optional tensors default to empty.
// NOTE(review): the bracketed shape comments are identical to those on
// asm_fmoe_stage1; verify that they are accurate for this stage as well.
void asm_fmoe_stage2(torch::Tensor &out,               // [token_cnt, dim]
              torch::Tensor &input,             // [token_cnt, dim] M,K
              torch::Tensor &gate,              // [expert, hidden_dim, dim] N,K
              torch::Tensor &down,              // [expert, hidden_dim, dim]
              torch::Tensor &sorted_token_ids,  // [max_num_tokens_padded]
              torch::Tensor &sorted_weights,    // [max_num_tokens_padded]
              torch::Tensor &sorted_expert_ids, // [max_num_m_blocks]
              torch::Tensor &num_valid_ids,                     // shape not documented here -- TODO confirm
              uint32_t top_k, // presumably the number of experts chosen per token -- TODO confirm
              std::optional<torch::Tensor> scale_a = std::nullopt, // omitted => no tensor supplied
              std::optional<torch::Tensor> scale_b = std::nullopt, // omitted => no tensor supplied
              std::optional<torch::Tensor> zero_points = std::nullopt, // omitted => no tensor supplied
              std::optional<int> mode = 0, // defaults to 0; meaning of the value is not visible here
              std::optional<int> solidx = 0, // defaults to 0; presumably a solution index -- TODO confirm
              std::optional<int> block_size = 16, // defaults to 16; unit/axis not visible here
              std::optional<int> persist_groups = 0 // defaults to 0; meaning not visible here
);

// Declaration only: the definition of asm_fmoe_a8 is not in this header.
// NOTE(review): "a8" presumably refers to 8-bit activations -- confirm against
// the definition before relying on it.
//
// The leading tensor parameters, top_k, the optional scale/zero-point tensors,
// mode and solidx match the asm_fmoe_stage1/asm_fmoe_stage2 prototypes. This
// prototype has no block_size parameter and instead takes out_type and
// use_shuffle (both std::optional<int>, both defaulting to an engaged 0).
// Every tensor is taken by non-const lvalue reference, so callers must pass
// lvalues. Bracketed shapes are the original author's annotations.
void asm_fmoe_a8(torch::Tensor &out,               // [token_cnt, dim]
              torch::Tensor &input,             // [token_cnt, dim] M,K
              torch::Tensor &gate,              // [expert, hidden_dim, dim] N,K
              torch::Tensor &down,              // [expert, hidden_dim, dim]
              torch::Tensor &sorted_token_ids,  // [max_num_tokens_padded]
              torch::Tensor &sorted_weights,    // [max_num_tokens_padded]
              torch::Tensor &sorted_expert_ids, // [max_num_m_blocks]
              torch::Tensor &num_valid_ids,                     // shape not documented here -- TODO confirm
              uint32_t top_k, // presumably the number of experts chosen per token -- TODO confirm
              std::optional<torch::Tensor> scale_a = std::nullopt, // omitted => no tensor supplied
              std::optional<torch::Tensor> scale_b = std::nullopt, // omitted => no tensor supplied
              std::optional<torch::Tensor> zero_points = std::nullopt, // omitted => no tensor supplied
              std::optional<int> mode = 0, // defaults to 0; meaning of the value is not visible here
              std::optional<int> solidx = 0, // defaults to 0; presumably a solution index -- TODO confirm
              std::optional<int> out_type = 0, // defaults to 0; value-to-type mapping not visible here
              std::optional<int> persist_groups = 0, // defaults to 0; meaning not visible here
              std::optional<int> use_shuffle = 0 // int-typed flag defaulting to 0; semantics not visible here

);

// Declaration only: the definition of asm_moe_get_solutions is not in this header.
// Returns a std::vector<std::string> by value.
// NOTE(review): presumably the names of the kernel solutions applicable to the
// given inputs/quantization settings -- confirm against the definition.
//
// Only the last two parameters carry defaults (block_m = 32, expert_mask =
// std::nullopt); every preceding std::optional must be passed explicitly.
// Tensors without std::optional are taken by non-const lvalue reference, so
// callers must pass lvalues. Bracketed shapes are the original author's notes.
std::vector<std::string> asm_moe_get_solutions(torch::Tensor &hidden_states,          // [m, k], input token
                            torch::Tensor &w1,                     // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
                            torch::Tensor &w2,                     // [e, n, k], pre-shuffle([e, nr, kr, w])
                            torch::Tensor &topk_weights,           // [tokens, topk]
                            torch::Tensor &topk_ids,               // [tokens, topk]
                            std::optional<bool> use_int8_w8a16,    // use int8 w8a16 quantization
                            std::optional<bool> use_int4_w4a16,    // use int4 w4a16 quantization
                            std::optional<bool> use_int8_w8a8,     // use int8 w8a8 quantization
                            std::optional<bool> use_int4_w4a8,     // use int4 w4a8 quantization
                            std::optional<bool> use_fp8_w8a8,      // use fp8 w8a8 quantization
                            std::optional<bool> per_channel_quant, // use per-channel quantization
                            std::optional<torch::Tensor> w1_zp,    // [e, 2*n, k/group], gate(up) zero-point
                            std::optional<torch::Tensor> w2_zp,    // [e, k, n/group], down zero-point
                            std::optional<torch::Tensor> w1_scale, // [e, 1, n], gate(up) scale or ...
                            std::optional<torch::Tensor> w2_scale, // [e, 1, k], down scale or ...
                            std::optional<torch::Tensor> a1_scale, // [m, 1], token scale
                            std::optional<torch::Tensor> a2_scale, // [e, 1, n], smooth-quant scale for the 2nd gemm input
                            std::optional<int> block_shape_n,      // quant block size along n
                            std::optional<int> block_shape_k,      // quant block size along k
                            std::optional<int> block_m = 32,       // MoE partition size for tokens in the m direction
                            std::optional<torch::Tensor> expert_mask = std::nullopt); // omitted => no mask tensor supplied