Unverified Commit 45f3f3f5 authored by Sage Moore's avatar Sage Moore Committed by GitHub
Browse files

[ROCm][Bugfix] Ensure that the moe_wna16_gemm kernel is not built on ROCm platforms. (#14629)


Signed-off-by: default avatarSage Moore <sage@neuralmagic.com>
parent ff47aab0
...@@ -559,7 +559,6 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1) ...@@ -559,7 +559,6 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
set(VLLM_MOE_EXT_SRC set(VLLM_MOE_EXT_SRC
"csrc/moe/torch_bindings.cpp" "csrc/moe/torch_bindings.cpp"
"csrc/moe/moe_align_sum_kernels.cu" "csrc/moe/moe_align_sum_kernels.cu"
"csrc/moe/moe_wna16.cu"
"csrc/moe/topk_softmax_kernels.cu") "csrc/moe/topk_softmax_kernels.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
...@@ -574,6 +573,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -574,6 +573,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
SRCS "${VLLM_MOE_WNA16_SRC}" SRCS "${VLLM_MOE_WNA16_SRC}"
CUDA_ARCHS "${CUDA_ARCHS}") CUDA_ARCHS "${CUDA_ARCHS}")
list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}") cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
if (MARLIN_MOE_ARCHS) if (MARLIN_MOE_ARCHS)
set(MARLIN_MOE_SRC set(MARLIN_MOE_SRC
......
...@@ -18,7 +18,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, ...@@ -18,7 +18,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
torch::Tensor sorted_token_ids, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids, torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad); torch::Tensor num_tokens_post_pad);
#ifndef USE_ROCM
torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
torch::Tensor b_qweight, torch::Tensor b_scales, torch::Tensor b_qweight, torch::Tensor b_scales,
std::optional<torch::Tensor> b_qzeros, std::optional<torch::Tensor> b_qzeros,
...@@ -28,3 +28,4 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, ...@@ -28,3 +28,4 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
torch::Tensor num_tokens_post_pad, int64_t top_k, torch::Tensor num_tokens_post_pad, int64_t top_k,
int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
int64_t BLOCK_SIZE_K, int64_t bit); int64_t BLOCK_SIZE_K, int64_t bit);
#endif
\ No newline at end of file
...@@ -31,6 +31,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { ...@@ -31,6 +31,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
" Tensor! num_tokens_post_pad) -> ()"); " Tensor! num_tokens_post_pad) -> ()");
m.impl("sgl_moe_align_block_size", torch::kCUDA, &sgl_moe_align_block_size); m.impl("sgl_moe_align_block_size", torch::kCUDA, &sgl_moe_align_block_size);
#ifndef USE_ROCM
m.def( m.def(
"moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, " "moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "
"Tensor b_scales, Tensor? b_qzeros, " "Tensor b_scales, Tensor? b_qzeros, "
...@@ -41,7 +42,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { ...@@ -41,7 +42,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m.impl("moe_wna16_gemm", torch::kCUDA, &moe_wna16_gemm); m.impl("moe_wna16_gemm", torch::kCUDA, &moe_wna16_gemm);
#ifndef USE_ROCM
m.def( m.def(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
......
...@@ -1146,6 +1146,10 @@ def moe_wna16_gemm(input: torch.Tensor, output: torch.Tensor, ...@@ -1146,6 +1146,10 @@ def moe_wna16_gemm(input: torch.Tensor, output: torch.Tensor,
num_tokens_post_pad: torch.Tensor, top_k: int, num_tokens_post_pad: torch.Tensor, top_k: int,
BLOCK_SIZE_M: int, BLOCK_SIZE_N: int, BLOCK_SIZE_K: int, BLOCK_SIZE_M: int, BLOCK_SIZE_N: int, BLOCK_SIZE_K: int,
bit: int) -> torch.Tensor: bit: int) -> torch.Tensor:
if not current_platform.is_cuda():
raise NotImplementedError(
"The optimized moe_wna16_gemm kernel is only "
"available on CUDA platforms")
torch.ops._moe_C.moe_wna16_gemm(input, output, b_qweight, b_scales, torch.ops._moe_C.moe_wna16_gemm(input, output, b_qweight, b_scales,
b_qzeros, topk_weights, sorted_token_ids, b_qzeros, topk_weights, sorted_token_ids,
experts_ids, num_tokens_post_pad, top_k, experts_ids, num_tokens_post_pad, top_k,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment