// moe_op.h
#pragma once
// SPDX-License-Identifier: MIT
 
#include <torch/extension.h>
#include "aiter_enum.h"

// fmoe: declaration only; the definition is not part of this header.
//
// NOTE(review): judging by the name this is presumably the plain fused
// mixture-of-experts entry point (no quantization scale arguments) - confirm
// against the definition.
//
// Every tensor is taken by non-const reference and the function returns void,
// so the result is presumably written into `out` - TODO confirm.
//
// The bracketed shapes below are the original author's annotations; nothing in
// this header enforces them.
// NOTE(review): `down` is annotated [expert, hidden_dim, dim] here, while
// fmoe_int8_g1u0_a16 in this header annotates it [expert, dim, inter_dim];
// confirm which layout the kernel expects.
void fmoe(torch::Tensor &out,               // [token_cnt, dim]
          torch::Tensor &input,             // [token_cnt, dim] M,K
          torch::Tensor &gate,              // [expert, hidden_dim, dim] N,K
          torch::Tensor &down,              // [expert, hidden_dim, dim]
          torch::Tensor &sorted_token_ids,  // [max_num_tokens_padded]
          torch::Tensor &sorted_weights,    // [max_num_tokens_padded]
          torch::Tensor &sorted_expert_ids, // [max_num_m_blocks]
          torch::Tensor &num_valid_ids,     // [1]
          uint32_t topk                     // presumably experts routed per token - TODO confirm
);

// fmoe_int8_g1u0: declaration only; the definition is not part of this header.
//
// Takes the same first nine parameters as fmoe() plus four scale tensors and
// an activation selector.
// NOTE(review): the name suggests an int8-quantized variant with a gate
// projection and no up projection ("g1u0") - confirm against the definition.
//
// `activation` defaults to ActivationType::Silu. ActivationType is not declared
// in this header; it presumably comes from "aiter_enum.h" - TODO confirm.
//
// All tensors, including the four scales, are required (non-const references).
// The bracketed shapes are the original author's annotations and are not
// enforced here.
void fmoe_int8_g1u0(torch::Tensor &out,               // [token_cnt, dim]
                    torch::Tensor &input,             // [token_cnt, dim] M,K
                    torch::Tensor &gate,              // [expert, hidden_dim, dim] N,K
                    torch::Tensor &down,              // [expert, hidden_dim, dim]
                    torch::Tensor &sorted_token_ids,  // [max_num_tokens_padded]
                    torch::Tensor &sorted_weights,    // [max_num_tokens_padded]
                    torch::Tensor &sorted_expert_ids, // [max_num_m_blocks]
                    torch::Tensor &num_valid_ids,     // [1]
                    uint32_t topk,                    // presumably experts routed per token - TODO confirm
                    torch::Tensor &input_scale,       // [token_cnt, 1]
                    torch::Tensor &fc1_scale,         // [expert, 1, hidden_dim]
                    torch::Tensor &fc2_scale,         // [expert, 1, dim]
                    torch::Tensor &fc2_smooth_scale,  // [expert, 1, hidden_dim]
                    ActivationType activation = ActivationType::Silu);

// fmoe_g1u1: declaration only; the definition is not part of this header.
//
// NOTE(review): "g1u1" presumably means both gate and up projections are
// present, which would match `gate` being annotated with hidden_dim*2 -
// confirm against the definition.
//
// Differences from fmoe_int8_g1u0 visible in the signature:
//   - `fc2_smooth_scale` is a std::optional passed by value and defaults to
//     std::nullopt, so callers may omit it.
//   - `gate` is annotated [expert, hidden_dim*2, dim].
// `activation` defaults to ActivationType::Silu.
//
// The bracketed shapes are the original author's annotations and are not
// enforced here.
void fmoe_g1u1(torch::Tensor &out,                                           // [token_cnt, dim]
               torch::Tensor &input,                                         // [token_cnt, dim] M,K
               torch::Tensor &gate,                                          // [expert, hidden_dim*2, dim] N,K
               torch::Tensor &down,                                          // [expert, hidden_dim, dim]
               torch::Tensor &sorted_token_ids,                              // [max_num_tokens_padded]
               torch::Tensor &sorted_weights,                                // [max_num_tokens_padded]
               torch::Tensor &sorted_expert_ids,                             // [max_num_m_blocks]
               torch::Tensor &num_valid_ids,                                 // [1]
               uint32_t topk,                                                // presumably experts routed per token - TODO confirm
               torch::Tensor &input_scale,                                   // [token_cnt, 1]
               torch::Tensor &fc1_scale,                                     // [expert, 1, hidden_dim]
               torch::Tensor &fc2_scale,                                     // [expert, 1, dim]
               std::optional<torch::Tensor> fc2_smooth_scale = std::nullopt, // [expert, 1, hidden_dim]
               ActivationType activation = ActivationType::Silu);

// fmoe_g1u1_tkw1: declaration only; the definition is not part of this header.
//
// The parameter list (names, types, order, defaults) is identical to
// fmoe_g1u1 in this header; whatever distinguishes the two lives in the
// definitions.
// NOTE(review): the meaning of the "tkw1" suffix cannot be determined from
// this header - document it once confirmed.
//
// `fc2_smooth_scale` defaults to std::nullopt and `activation` defaults to
// ActivationType::Silu.
//
// The bracketed shapes are the original author's annotations and are not
// enforced here.
void fmoe_g1u1_tkw1(torch::Tensor &out,                                           // [token_cnt, dim]
                    torch::Tensor &input,                                         // [token_cnt, dim] M,K
                    torch::Tensor &gate,                                          // [expert, hidden_dim*2, dim] N,K
                    torch::Tensor &down,                                          // [expert, hidden_dim, dim]
                    torch::Tensor &sorted_token_ids,                              // [max_num_tokens_padded]
                    torch::Tensor &sorted_weights,                                // [max_num_tokens_padded]
                    torch::Tensor &sorted_expert_ids,                             // [max_num_m_blocks]
                    torch::Tensor &num_valid_ids,                                 // [1]
                    uint32_t topk,                                                // presumably experts routed per token - TODO confirm
                    torch::Tensor &input_scale,                                   // [token_cnt, 1]
                    torch::Tensor &fc1_scale,                                     // [expert, 1, hidden_dim]
                    torch::Tensor &fc2_scale,                                     // [expert, 1, dim]
                    std::optional<torch::Tensor> fc2_smooth_scale = std::nullopt, // [expert, 1, hidden_dim]
                    ActivationType activation = ActivationType::Silu);

// fmoe_int8_g1u0_a16: declaration only; the definition is not part of this
// header.
//
// Compared with fmoe_int8_g1u0, the signature has no `input_scale` and no
// `activation` parameter, and adds a required `fc1_smooth_scale`.
// NOTE(review): "a16" presumably means 16-bit activations, which would explain
// the missing input scale - confirm against the definition.
//
// The bracketed shapes are the original author's annotations and are not
// enforced here.
// NOTE(review): the weight annotations use `inter_dim` while the scale
// annotations use `hidden_dim`; they presumably denote the same extent -
// TODO confirm and unify.
void fmoe_int8_g1u0_a16(torch::Tensor &out,               // [token_cnt, dim]
                        torch::Tensor &input,             // [token_cnt, dim] M,K
                        torch::Tensor &gate,              // [expert, inter_dim, dim] N,K
                        torch::Tensor &down,              // [expert, dim, inter_dim]
                        torch::Tensor &sorted_token_ids,  // [max_num_tokens_padded]
                        torch::Tensor &sorted_weights,    // [max_num_tokens_padded]
                        torch::Tensor &sorted_expert_ids, // [max_num_m_blocks]
                        torch::Tensor &num_valid_ids,     // [1]
                        uint32_t topk,                    // presumably experts routed per token - TODO confirm
                        torch::Tensor &fc1_scale,         // [expert, 1, hidden_dim]
                        torch::Tensor &fc2_scale,         // [expert, 1, dim]
                        torch::Tensor &fc1_smooth_scale,  // [expert, 1, hidden_dim]
                        torch::Tensor &fc2_smooth_scale   // [expert, 1, hidden_dim]
);
// fmoe_g1u1_a16: declaration only; the definition is not part of this header.
//
// The parameter list is identical to fmoe_int8_g1u0_a16 in this header: no
// `input_scale`, no `activation`, and both smooth-scale tensors are required.
//
// The bracketed shapes are the original author's annotations and are not
// enforced here.
// NOTE(review): `gate` is annotated [expert, inter_dim, dim] here, whereas the
// other g1u1 declarations in this header annotate the second extent as
// hidden_dim*2 / inter_dim*2. This may be a copy-paste leftover - confirm.
// NOTE(review): the scale annotations use `hidden_dim` while the weights use
// `inter_dim`; presumably the same extent - TODO confirm.
void fmoe_g1u1_a16(torch::Tensor &out,               // [token_cnt, dim]
                   torch::Tensor &input,             // [token_cnt, dim] M,K
                   torch::Tensor &gate,              // [expert, inter_dim, dim] N,K
                   torch::Tensor &down,              // [expert, dim, inter_dim]
                   torch::Tensor &sorted_token_ids,  // [max_num_tokens_padded]
                   torch::Tensor &sorted_weights,    // [max_num_tokens_padded]
                   torch::Tensor &sorted_expert_ids, // [max_num_m_blocks]
                   torch::Tensor &num_valid_ids,     // [1]
                   uint32_t topk,                    // presumably experts routed per token - TODO confirm
                   torch::Tensor &fc1_scale,         // [expert, 1, hidden_dim]
                   torch::Tensor &fc2_scale,         // [expert, 1, dim]
                   torch::Tensor &fc1_smooth_scale,  // [expert, 1, hidden_dim]
                   torch::Tensor &fc2_smooth_scale   // [expert, 1, hidden_dim]
);

// fmoe_fp8_blockscale_g1u1: declaration only; the definition is not part of
// this header.
//
// Same leading parameters as fmoe_g1u1, with two extra integers
// (`fc_scale_blkn`, `fc_scale_blkk`, both defaulting to 128) placed before the
// optional `fc2_smooth_scale`.
// NOTE(review): the name suggests FP8 weights with block-wise scales, the two
// integers being the scale block extents along N and K - confirm against the
// definition.
//
// `fc2_smooth_scale` defaults to std::nullopt and `activation` defaults to
// ActivationType::Silu.
//
// The bracketed shapes are the original author's annotations and are not
// enforced here.
// NOTE(review): `input_scale` is annotated [expert, 1, dim] here, but
// [token_cnt, 1] on the other declarations in this header that take it; with
// block scaling the per-expert shapes for fc1_scale/fc2_scale also look
// suspicious - TODO confirm all three.
void fmoe_fp8_blockscale_g1u1(torch::Tensor &out,                                           // [token_cnt, dim]
                              torch::Tensor &input,                                         // [token_cnt, dim] M,K
                              torch::Tensor &gate,                                          // [expert, inter_dim*2, dim] N,K
                              torch::Tensor &down,                                          // [expert, dim, inter_dim]
                              torch::Tensor &sorted_token_ids,                              // [max_num_tokens_padded]
                              torch::Tensor &sorted_weights,                                // [max_num_tokens_padded]
                              torch::Tensor &sorted_expert_ids,                             // [max_num_m_blocks]
                              torch::Tensor &num_valid_ids,                                 // [1]
                              uint32_t topk,                                                // presumably experts routed per token - TODO confirm
                              torch::Tensor &input_scale,                                   // [expert, 1, dim]
                              torch::Tensor &fc1_scale,                                     // [expert, 1, inter_dim]
                              torch::Tensor &fc2_scale,                                     // [expert, 1, dim]
                              int fc_scale_blkn = 128,                                      // presumably scale block extent along N - TODO confirm
                              int fc_scale_blkk = 128,                                      // presumably scale block extent along K - TODO confirm
                              std::optional<torch::Tensor> fc2_smooth_scale = std::nullopt, // [expert, 1, inter_dim]
                              ActivationType activation = ActivationType::Silu);

// moe_stage1_g1u1: declaration only; the definition is not part of this header.
//
// NOTE(review): the name and the `out` annotation ([token_cnt, topk,
// inter_dim]) suggest this runs only the first (gate/up) stage of the MoE and
// leaves the down projection to a later step - confirm against the definition.
// `w2` is nevertheless part of the signature.
//
// Signature points worth knowing as a caller:
//   - `out` is the seventh parameter here, not the first as in the fmoe*
//     declarations in this header.
//   - `kernelName` is a non-const lvalue reference, so a temporary or string
//     literal cannot be passed directly. Whether the callee modifies it is not
//     visible here.
//   - `activation`, `quant_type` and the three std::optional tensors have no
//     default values; pass std::nullopt explicitly to omit a tensor.
//   - ActivationType and QuantType are not declared in this header; they
//     presumably come from "aiter_enum.h" - TODO confirm.
//
// The bracketed shapes are the original author's annotations and are not
// enforced here.
void moe_stage1_g1u1(torch::Tensor &input,             // [token_cnt, model_dim] M,K
                     torch::Tensor &w1,                // [expert, inter_dim*2, model_dim] N,K
                     torch::Tensor &w2,                // [expert, model_dim, inter_dim]
                     torch::Tensor &sorted_token_ids,  // [max_num_tokens_padded]
                     torch::Tensor &sorted_expert_ids, // [max_num_m_blocks]
                     torch::Tensor &num_valid_ids,     // [1]
                     torch::Tensor &out,               // [token_cnt, topk, inter_dim]
                     int inter_dim,
                     std::string &kernelName,
                     int block_m,
                     int ksplit,
                     ActivationType activation,
                     QuantType quant_type,
                     std::optional<torch::Tensor> a1_scale, // [token_cnt, 1], token scale
                     std::optional<torch::Tensor> w1_scale, // [expert, 1, inter_dim], gate(up) scale
                     std::optional<torch::Tensor> sorted_weights);