Unverified Commit c4500233 authored by sogalin, committed by GitHub

Add Qwen3-30B-A3B-Thinking-2507 support on AMD GPUs. (#9456)

parent f445a1d9
@@ -49,13 +49,15 @@ if _is_cuda:
 elif _is_cpu and _is_cpu_amx_available:
     pass
 elif _is_hip:
-    from vllm import _custom_ops as vllm_ops  # gelu_and_mul, silu_and_mul
+    from sgl_kernel import gelu_and_mul, silu_and_mul
 
     if _use_aiter:
         try:
             from aiter import moe_sum
         except ImportError:
             raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
+    else:
+        from vllm import _custom_ops as vllm_ops
 
 
 if _is_cuda or _is_hip:
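This first hunk routes the fused activation kernels through sgl_kernel on ROCm and keeps vLLM's _custom_ops only as the non-aiter fallback. A minimal sketch of the platform-guarded import pattern, using hypothetical stand-ins for SGLang's real _is_cuda / _is_hip / _use_aiter flags (the actual helpers read torch build info and environment state):

    import os
    import torch

    # Hypothetical stand-ins for SGLang's platform flags.
    _is_cuda = torch.cuda.is_available() and torch.version.cuda is not None
    _is_hip = torch.cuda.is_available() and torch.version.hip is not None
    _use_aiter = _is_hip and os.environ.get("SGLANG_USE_AITER", "false").lower() in ("true", "1")

    if _is_hip:
        # Prefer the ROCm builds of the fused activation kernels.
        from sgl_kernel import gelu_and_mul, silu_and_mul

        if _use_aiter:
            try:
                from aiter import moe_sum
            except ImportError:
                raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
        else:
            # Only the non-aiter path still needs vLLM's custom ops.
            from vllm import _custom_ops as vllm_ops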
@@ -1537,7 +1539,7 @@ def fused_experts_impl(
                 gemm1_alpha,
                 gemm1_limit,
             )
-        elif _is_cuda:
+        elif _is_cuda or _is_hip:
             silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
         else:
             vllm_ops.silu_and_mul(
@@ -1546,7 +1548,7 @@ def fused_experts_impl(
         elif activation == "gelu":
             assert gemm1_alpha is None, "gemm1_alpha is not supported for gelu"
             assert gemm1_limit is None, "gemm1_limit is not supported for gelu"
-            if _is_cuda:
+            if _is_cuda or _is_hip:
                 gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
             else:
                 vllm_ops.gelu_and_mul(
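Both activation hunks widen the fast path from _is_cuda to _is_cuda or _is_hip, so ROCm now hits the fused sgl_kernel ops instead of vllm_ops. For reference, a plain-PyTorch sketch of what the fused silu_and_mul computes, assuming the usual [gate, up] split of the last dimension implied by the view(-1, N) call above; this illustrates the semantics, not the kernel itself:

    import torch
    import torch.nn.functional as F

    def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
        # Split the last dimension into [gate, up] halves and fuse the
        # activation with the elementwise product: silu(gate) * up.
        gate, up = x.chunk(2, dim=-1)
        return F.silu(gate) * up

    x = torch.randn(6, 8)              # e.g. (tokens * topk, 2 * N) with N = 4
    out = silu_and_mul_reference(x)    # shape (6, 4)

gelu_and_mul follows the same pattern with F.gelu in place of F.silu.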
@@ -1619,10 +1621,19 @@ def fused_experts_impl(
                     out_hidden_states[begin_chunk_idx:end_chunk_idx],
                 )
             else:
-                vllm_ops.moe_sum(
-                    intermediate_cache3.view(*intermediate_cache3.shape),
-                    out_hidden_states[begin_chunk_idx:end_chunk_idx],
-                )
+                # According to micro benchmark results, torch.compile can get better performance for small token counts.
+                if tokens_in_chunk <= 32:
+                    moe_sum_reduce_torch_compile(
+                        intermediate_cache3.view(*intermediate_cache3.shape),
+                        out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                        routed_scaling_factor,
+                    )
+                else:
+                    moe_sum_reduce_triton(
+                        intermediate_cache3.view(*intermediate_cache3.shape),
+                        out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                        routed_scaling_factor,
+                    )
         else:
             vllm_ops.moe_sum(
                 intermediate_cache3.view(*intermediate_cache3.shape),
...
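On the non-aiter ROCm branch, the last hunk swaps the single vllm_ops.moe_sum call for a size-dependent reduce: per the in-code comment, microbenchmarks found torch.compile faster for small chunks (at most 32 tokens here), with a Triton kernel taking over beyond that. A hedged sketch of what the torch.compile variant might look like, with the signature assumed from the call sites above (the real moe_sum_reduce_torch_compile and moe_sum_reduce_triton ship with SGLang):

    import torch

    @torch.compile(dynamic=True)
    def moe_sum_reduce_torch_compile_sketch(
        x: torch.Tensor,               # (tokens, topk, hidden): per-expert partial outputs
        out: torch.Tensor,             # (tokens, hidden): written in place
        routed_scaling_factor: float,  # scale applied after the reduction
    ) -> None:
        # Sum over the top-k expert dimension, then apply the routed scaling.
        torch.sum(x, dim=1, out=out)
        out.mul_(routed_scaling_factor)

The tokens_in_chunk <= 32 threshold encodes the crossover point the comment refers to; larger chunks go to the Triton kernel instead.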