moe_utils.h 2.44 KB
Newer Older
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#pragma once
#include <torch/all.h>
#include <torch/extension.h>
#include "aiter_enum.h"
namespace aiter {

/// Top-k selection over `gating_output`.
/// NOTE(review): the softmax-then-top-k semantics are inferred from the name;
/// the definition is not in this header — confirm against the implementation.
/// The tensor arguments are taken by non-const reference, presumably so the
/// callee can fill `topk_weights`, `topk_indices` and `token_expert_indices`
/// in place (TODO confirm which ones are outputs).
/// @param need_renorm  presumably renormalizes the selected weights — confirm.
void topk_softmax(torch::Tensor& topk_weights,
                  torch::Tensor& topk_indices,
                  torch::Tensor& token_expert_indices,
                  torch::Tensor& gating_output,
                  bool need_renorm);

/// Reduces `input` into `output`.
/// NOTE(review): the exact reduction axis is not visible here (presumably the
/// top-k dimension of the MoE output) — TODO confirm against the kernel.
void moe_sum(torch::Tensor& input, torch::Tensor& output);

}  // namespace aiter

void biased_grouped_topk(torch::Tensor& gating_output,   // [num_tokens, num_experts]
                         torch::Tensor& correction_bias, // [num_expert]
                         torch::Tensor& topk_weights,    // [num_tokens, topk]
                         torch::Tensor& topk_ids,        // [num_tokens, topk]
                         int num_expert_group,
                         int topk_group,
                         bool renormalize,
                         const float routed_scaling_factor = 1.);

void grouped_topk(torch::Tensor& gating_output, // [num_tokens, num_experts]
                  torch::Tensor& topk_weights,  // [num_tokens, topk]
                  torch::Tensor& topk_ids,      // [num_tokens, topk]
                  int num_expert_group,
                  int topk_grp,
                  bool need_renorm,
                  bool is_softmax                   = true,
                  const float routed_scaling_factor = 1.);

std::vector<at::Tensor> moe_fused_gate(at::Tensor& input,
                                       at::Tensor& bias,
                                       at::Tensor& topk_weights,
                                       at::Tensor& topk_ids,
                                       int64_t num_expert_group,
                                       int64_t topk_group,
                                       int64_t topk,
                                       int64_t n_share_experts_fusion,
                                       double routed_scaling_factor);

/// Block-size alignment of expert routing for `topk_ids`.
/// NOTE(review): the padding/sorting semantics are inferred from the name;
/// presumably the results land in `sorted_token_ids`, `experts_ids` and
/// `num_tokens_post_pad` — confirm against the kernel.
/// Tensors are taken by value here; switching to references would change the
/// linked symbol, so keep this in sync with the definition.
void moe_align_block_size(torch::Tensor topk_ids,
                          int64_t num_experts,
                          int64_t block_size,
                          torch::Tensor sorted_token_ids,
                          torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad);

/// Variant of moe_align_block_size with an identical parameter list.
/// NOTE(review): the "sgl" prefix and how this differs from
/// moe_align_block_size are not visible in this header — TODO document.
/// Outputs presumably go to `sorted_token_ids`, `experts_ids` and
/// `num_tokens_post_pad` — confirm. Tensors stay by value to match the
/// existing definition's symbol.
void sgl_moe_align_block_size(torch::Tensor topk_ids,
                              int64_t num_experts,
                              int64_t block_size,
                              torch::Tensor sorted_token_ids,
                              torch::Tensor experts_ids,
                              torch::Tensor num_tokens_post_pad);