Unverified Commit 813bd6f8 authored by Yuan Luo's avatar Yuan Luo Committed by GitHub
Browse files

[2/2] Use moe_sum_reduce cuda kernel (#10654)


Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Co-authored-by: huangtingwei <141888744+huangtingwei9988@users.noreply.github.com>
parent 729f612d
...@@ -36,7 +36,7 @@ _is_cpu = is_cpu() ...@@ -36,7 +36,7 @@ _is_cpu = is_cpu()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
if _is_cuda: if _is_cuda:
from sgl_kernel import gelu_and_mul, silu_and_mul from sgl_kernel import gelu_and_mul, moe_sum_reduce, silu_and_mul
elif _is_cpu and _is_cpu_amx_available: elif _is_cpu and _is_cpu_amx_available:
pass pass
elif _is_hip: elif _is_hip:
...@@ -569,11 +569,12 @@ def fused_experts_impl( ...@@ -569,11 +569,12 @@ def fused_experts_impl(
routed_scaling_factor, routed_scaling_factor,
) )
else: else:
moe_sum_reduce_triton( moe_sum_reduce(
intermediate_cache3.view(*intermediate_cache3.shape), intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states[begin_chunk_idx:end_chunk_idx], out_hidden_states[begin_chunk_idx:end_chunk_idx],
routed_scaling_factor, routed_scaling_factor,
) )
elif _is_hip: elif _is_hip:
if _use_aiter: if _use_aiter:
moe_sum( moe_sum(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment