// moe_sorting.h
// Committed by Xiaowei.zhang.
#pragma once
// SPDX-License-Identifier: MIT
 
#include <torch/extension.h>

/// Declaration of `moe_sorting_fwd`; the definition is not part of this header.
///
/// NOTE(review): judging by the name, this is presumably the forward pass of a
/// mixture-of-experts token-sorting op -- confirm against the implementation.
/// The bracketed shapes below are carried over from the annotations that
/// accompany this declaration. Which tensors are read versus written, and
/// their dtypes, cannot be determined from this header -- TODO confirm.
///
/// @param topk_ids                     shape [m, topk]
/// @param topk_weights                 shape [m, topk]
/// @param sorted_token_ids             shape [max_num_tokens_padded]
/// @param sorted_weights               shape [max_num_tokens_padded]
/// @param sorted_expert_ids            shape [max_num_m_blocks]
/// @param tokens_positions_per_expert  shape [num_experts*2]
/// @param num_valid_ids                shape [1]
/// @param moe_buf                      shape [max_num_tokens_padded]
/// @param num_experts                  presumably the expert count that sizes
///                                     tokens_positions_per_expert -- verify
/// @param unit_size                    meaning not visible here -- TODO confirm
/// @param local_expert_mask            optional tensor; defaults to std::nullopt
///                                     when the caller omits it
void moe_sorting_fwd(
    torch::Tensor& topk_ids,
    torch::Tensor& topk_weights,
    torch::Tensor& sorted_token_ids,
    torch::Tensor& sorted_weights,
    torch::Tensor& sorted_expert_ids,
    torch::Tensor& tokens_positions_per_expert,
    torch::Tensor& num_valid_ids,
    torch::Tensor& moe_buf,
    int num_experts,
    int unit_size,
    std::optional<torch::Tensor> local_expert_mask = std::nullopt);