Unverified commit 9b8333d9, authored by yiakwy-xpu-ml-framework-team, committed by GitHub

[ROCm] enable moe topk softmax in amd (#4448)

parent f5bbf603
@@ -61,6 +61,10 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
       "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
       "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
   m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
+  m.def(
+      "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
+      "token_expert_indices, Tensor gating_output) -> ()");
+  m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
 }

 REGISTER_EXTENSION(common_ops)
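
With this registration, the op is exposed on the `sgl_kernel` namespace and becomes callable from Python as `torch.ops.sgl_kernel.topk_softmax`. A minimal call-site sketch follows; the tensor shapes and dtypes are assumptions inferred from the schema above (outputs are written in place per the `Tensor!` annotations), not something this commit specifies:

```python
import torch
import sgl_kernel  # assumed import; loading the extension registers the ops

num_tokens, num_experts, topk = 8, 16, 2

# Router logits over all experts for each token (read-only input).
gating_output = torch.randn(num_tokens, num_experts, device="cuda", dtype=torch.float32)

# Output buffers; "Tensor!" in the schema means these are mutated in place.
topk_weights = torch.empty(num_tokens, topk, device="cuda", dtype=torch.float32)
topk_indices = torch.empty(num_tokens, topk, device="cuda", dtype=torch.int32)
token_expert_indices = torch.empty(num_tokens, topk, device="cuda", dtype=torch.int32)

# Softmax over gating_output, then per-token top-k selection (assumed semantics).
torch.ops.sgl_kernel.topk_softmax(topk_weights, topk_indices, token_expert_indices, gating_output)
```

Note that on ROCm builds of PyTorch the `"cuda"` device string maps to HIP devices, so the same call site works on AMD GPUs.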
@@ -41,6 +41,7 @@ include_dirs = [
 sources = [
     "csrc/allreduce/custom_all_reduce.hip",
     "csrc/moe/moe_align_kernel.cu",
+    "csrc/moe/moe_topk_softmax_kernels.cu",
     "csrc/torch_extension_rocm.cc",
 ]
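
After adding the kernel source to the ROCm build and reinstalling, one quick sanity check is to look the op up on the `torch.ops` namespace. A minimal sketch, assuming the extension is importable as `sgl_kernel`:

```python
import torch
import sgl_kernel  # assumed import; registers the ops as a side effect

# Raises AttributeError if the kernel was not compiled into the extension.
print(torch.ops.sgl_kernel.topk_softmax)
```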