// moe_sorting.h (last committed by Xiaowei.zhang)
#pragma once
// SPDX-License-Identifier: MIT
 
#include <torch/extension.h>

void moe_sorting_fwd(torch::Tensor &topk_ids,              // [m, topk]
                     torch::Tensor &topk_weights,          // [m, topk]
                     torch::Tensor &sorted_token_ids,      // [max_num_tokens_padded]
                     torch::Tensor &sorted_weights,        // [max_num_tokens_padded]
                     torch::Tensor &sorted_expert_ids,     // [max_num_m_blocks]
                     torch::Tensor &tokens_positions_per_expert,     // [num_experts*2]
                     torch::Tensor &num_valid_ids,         // [1]
13
14
15
                     std::optional<torch::Tensor> moe_buf = std::nullopt,  // [max_num_tokens_padded], set to None to skip zero-fill
                     int num_experts = 0,
                     int unit_size = 0,
Xiaowei.zhang's avatar
Xiaowei.zhang committed
16
                     std::optional<torch::Tensor> local_expert_mask = std::nullopt);