// moe_align_sum.h
// Committed by Xiaowei.zhang.
#pragma once

#include <torch/all.h>
#include <torch/extension.h>
#include "aiter_enum.h"

// void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
//                   torch::Tensor& token_expert_indices,
//                   torch::Tensor& gating_output);

// void moe_sum(torch::Tensor& input, torch::Tensor& output);

/// Declaration only — the definition of moe_align_block_size is not in this
/// header.
///
/// NOTE(review): judging only by the function and parameter names (TODO
/// confirm against the definition), this groups the entries of `topk_ids` by
/// expert and pads each expert's group to a multiple of `block_size`.
///
/// @param topk_ids             expert ids chosen per token / top-k slot
///                             (layout and dtype not visible here — confirm).
/// @param num_experts          number of experts.
/// @param block_size           alignment granularity for the padding.
/// @param sorted_token_ids     presumably an output buffer — confirm.
/// @param experts_ids          presumably an output buffer — confirm.
/// @param num_tokens_post_pad  presumably an output holding the padded token
///                             count — confirm.
///
/// Keep this signature byte-for-byte in sync with the definition: the tensors
/// are taken by value, and switching to e.g. `const torch::Tensor&` would
/// declare a different overload and fail to link.
void moe_align_block_size(torch::Tensor topk_ids,
                          int64_t num_experts,
                          int64_t block_size,
                          torch::Tensor sorted_token_ids,
                          torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad);

/// Declaration only — the definition of sgl_moe_align_block_size is not in
/// this header.
///
/// Takes the same parameter list as moe_align_block_size. NOTE(review): the
/// `sgl_` prefix presumably marks an alternative (SGLang-derived?)
/// implementation of the same alignment step — TODO confirm against the
/// definition.
///
/// @param topk_ids             expert ids chosen per token / top-k slot
///                             (layout and dtype not visible here — confirm).
/// @param num_experts          number of experts.
/// @param block_size           alignment granularity for the padding.
/// @param sorted_token_ids     presumably an output buffer — confirm.
/// @param experts_ids          presumably an output buffer — confirm.
/// @param num_tokens_post_pad  presumably an output holding the padded token
///                             count — confirm.
///
/// Keep this signature byte-for-byte in sync with the definition; changing
/// the by-value tensor parameters would declare a different overload.
void sgl_moe_align_block_size(torch::Tensor topk_ids,
                              int64_t num_experts,
                              int64_t block_size,
                              torch::Tensor sorted_token_ids,
                              torch::Tensor experts_ids,
                              torch::Tensor num_tokens_post_pad);
// #ifndef USE_ROCM
// torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
//                              torch::Tensor b_qweight, torch::Tensor b_scales,
//                              std::optional<torch::Tensor> b_qzeros,
//                              std::optional<torch::Tensor> topk_weights,
//                              torch::Tensor sorted_token_ids,
//                              torch::Tensor expert_ids,
//                              torch::Tensor num_tokens_post_pad, int64_t top_k,
//                              int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
//                              int64_t BLOCK_SIZE_K, int64_t bit);
// #endif
// std::vector<torch::Tensor> moe_fused_gate(
//     torch::Tensor& input,
//     torch::Tensor& bias,
//     int64_t num_expert_group,
//     int64_t topk_group,
//     int64_t topk,
//     int64_t n_share_experts_fusion,
//     double routed_scaling_factor);