Commit d2b52805 authored by zhuwenwen
Browse files

Merge tag 'v0.10.2rc1' into v0.10.2rc1-ori

parents 9a521c23 5438967f
...@@ -45,8 +45,6 @@ void moe_permute( ...@@ -45,8 +45,6 @@ void moe_permute(
auto copy_topk_ids = topk_ids.clone(); // copy topk_ids for preprocess auto copy_topk_ids = topk_ids.clone(); // copy topk_ids for preprocess
auto permuted_experts_id = torch::empty_like(topk_ids); auto permuted_experts_id = torch::empty_like(topk_ids);
auto sorted_row_idx = torch::empty_like(inv_permuted_idx); auto sorted_row_idx = torch::empty_like(inv_permuted_idx);
auto align_expert_first_token_offset =
torch::zeros_like(expert_first_token_offset);
CubKeyValueSorter sorter{}; CubKeyValueSorter sorter{};
int64_t* valid_num_ptr = nullptr; int64_t* valid_num_ptr = nullptr;
...@@ -85,12 +83,14 @@ void moe_permute( ...@@ -85,12 +83,14 @@ void moe_permute(
}); });
// get m_indices and update expert_first_token_offset with align block // get m_indices and update expert_first_token_offset with align block
getMIndices(get_ptr<int64_t>(expert_first_token_offset), // this is only required for DeepGemm and not required for CUTLASS group gemm
get_ptr<int64_t>(align_expert_first_token_offset),
get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
stream);
if (align_block_size.has_value()) { if (align_block_size.has_value()) {
// update align_expert_first_token_offset auto align_expert_first_token_offset =
torch::zeros_like(expert_first_token_offset);
getMIndices(get_ptr<int64_t>(expert_first_token_offset),
get_ptr<int64_t>(align_expert_first_token_offset),
get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
stream);
expert_first_token_offset.copy_(align_expert_first_token_offset); expert_first_token_offset.copy_(align_expert_first_token_offset);
} }
} }
...@@ -195,19 +195,14 @@ void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights, ...@@ -195,19 +195,14 @@ void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights,
torch::Tensor& expert_first_token_offset, torch::Tensor& expert_first_token_offset,
torch::Tensor& src_row_id2dst_row_id_map, torch::Tensor& src_row_id2dst_row_id_map,
torch::Tensor& m_indices) { torch::Tensor& m_indices) {
TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0"); TORCH_CHECK(false, "moe_permute is not supported on CUDA < 12.0");
} }
void moe_unpermute(const torch::Tensor& input, void moe_unpermute(
const torch::Tensor& topk_weights, torch::Tensor& topk_ids, const torch::Tensor& permuted_hidden_states,
const torch::Tensor& token_expert_indices, const torch::Tensor& topk_weights, const torch::Tensor& inv_permuted_idx,
const std::optional<torch::Tensor>& expert_map, const std::optional<torch::Tensor>& expert_first_token_offset, int64_t topk,
int64_t n_expert, int64_t n_local_expert, int64_t topk, torch::Tensor& hidden_states) {
const std::optional<int64_t>& align_block_size,
torch::Tensor& permuted_input,
torch::Tensor& expert_first_token_offset,
torch::Tensor& src_row_id2dst_row_id_map,
torch::Tensor& m_indices) {
TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0"); TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0");
} }
...@@ -224,4 +219,4 @@ bool moe_permute_unpermute_supported() { ...@@ -224,4 +219,4 @@ bool moe_permute_unpermute_supported() {
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("moe_permute", &moe_permute); m.impl("moe_permute", &moe_permute);
m.impl("moe_unpermute", &moe_unpermute); m.impl("moe_unpermute", &moe_unpermute);
} }
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment