// moe_ops.h
#pragma once

3
#include <torch/all.h>
4

5
6
void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
                  torch::Tensor& token_expert_indices,
7
                  torch::Tensor& gating_output, bool renormalize);
8
9
10
11
12
13
14

// Declares moe_sum; no definition is visible in this header.
//
// Returns nothing and takes both tensors by non-const reference.
//
// NOTE(review): by name, `input` is presumably reduced (summed) into `output`
// in place — confirm against the implementation.
void moe_sum(torch::Tensor& input, torch::Tensor& output);

// Declares moe_align_block_size; no definition is visible in this header.
//
// Returns nothing. Every tensor is passed by value and the two sizes are
// 64-bit integers.
//
// NOTE(review): by name, `topk_ids` is presumably the per-token expert
// selection, and `sorted_token_ids`, `experts_ids` and `num_tokens_post_pad`
// are presumably pre-allocated outputs filled so that each expert's token
// count is padded to a multiple of `block_size` — confirm against the
// implementation.
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                          int64_t block_size, torch::Tensor sorted_token_ids,
                          torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad);
15
16
17
18
19
20
21
22

// Declares batched_moe_align_block_size; no definition is visible in this
// header.
//
// Returns nothing. `expert_num_tokens` is taken by const reference (read-only
// from the callee's side); the remaining tensors are passed by value.
//
// NOTE(review): by name, this is presumably the batched counterpart of
// moe_align_block_size, with `sorted_ids`, `expert_ids` and
// `num_tokens_post_pad` acting as pre-allocated outputs — confirm against the
// implementation.
void batched_moe_align_block_size(int64_t max_tokens_per_batch,
                                  int64_t block_size,
                                  torch::Tensor const& expert_num_tokens,
                                  torch::Tensor sorted_ids,
                                  torch::Tensor expert_ids,
                                  torch::Tensor num_tokens_post_pad);

23
24
25
void moe_lora_align_block_size(torch::Tensor topk_ids,
                               torch::Tensor token_lora_mapping,
                               int64_t num_experts, int64_t block_size,
26
27
                               int64_t max_loras, int64_t max_num_tokens_padded,
                               int64_t max_num_m_blocks,
28
29
30
                               torch::Tensor sorted_token_ids,
                               torch::Tensor expert_ids,
                               torch::Tensor num_tokens_post_pad);
31
#ifndef USE_ROCM
32
33
34
35
36
37
38
39
40
// Declares moe_wna16_gemm; no definition is visible in this header. In this
// file the declaration sits inside an `#ifndef USE_ROCM` region, so it is
// only visible when USE_ROCM is not defined.
//
// Returns a torch::Tensor. `b_qzeros` and `topk_weights` are std::optional
// and may therefore be absent; all other tensors are passed by value, and the
// block sizes, `top_k` and `bit` are 64-bit integers.
//
// NOTE(review): by name, this is presumably a weight-only-quantized MoE GEMM
// where `b_qweight` / `b_scales` / `b_qzeros` describe the quantized weights,
// `bit` is the weight bit width, and `sorted_token_ids` / `expert_ids` /
// `num_tokens_post_pad` come from a block-alignment step — confirm against
// the implementation, including whether the returned tensor is `output`.
torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                             torch::Tensor b_qweight, torch::Tensor b_scales,
                             std::optional<torch::Tensor> b_qzeros,
                             std::optional<torch::Tensor> topk_weights,
                             torch::Tensor sorted_token_ids,
                             torch::Tensor expert_ids,
                             torch::Tensor num_tokens_post_pad, int64_t top_k,
                             int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
                             int64_t BLOCK_SIZE_K, int64_t bit);
41
42
43
44
45

// Declares grouped_topk; no definition is visible in this header. In this
// file the declaration sits inside an `#ifndef USE_ROCM` region, so it is
// only visible when USE_ROCM is not defined.
//
// Takes two tensors by const reference plus integer/bool/double scalars and
// returns a pair of tensors.
//
// NOTE(review): by name, the result is presumably (top-k weights, top-k
// expert indices) chosen by first picking `topk_group` of `n_group` expert
// groups and then `topk` experts, with `scores_with_bias` used for selection
// and `routed_scaling_factor` applied to the weights — confirm against the
// implementation.
std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
    torch::Tensor const& scores, torch::Tensor const& scores_with_bias,
    int64_t n_group, int64_t topk_group, int64_t topk, bool renormalize,
    double routed_scaling_factor);
46
47
#endif

48
49
50
51
52
// Declares moe_permute_unpermute_supported; no definition is visible in this
// header. Takes no arguments and returns a bool.
//
// NOTE(review): by name, presumably reports whether the MoE permute/unpermute
// kernels are available in the current build — confirm against the
// implementation.
bool moe_permute_unpermute_supported();

// Declares shuffle_rows; no definition is visible in this header.
//
// Returns nothing. `input_tensor` and `dst2src_map` are taken by const
// reference; `output_tensor` is taken by non-const reference.
//
// NOTE(review): by name, presumably writes into each destination row of
// `output_tensor` the row of `input_tensor` selected by `dst2src_map` —
// confirm the exact indexing against the implementation.
void shuffle_rows(const torch::Tensor& input_tensor,
                  const torch::Tensor& dst2src_map,
                  torch::Tensor& output_tensor);