"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "7f43f65235085b976188e6edb519d4cd79b6a509"
Unverified commit e984d507, authored by valarLip, committed by GitHub

enable aiter_biased_grouped_topk kernel (#7423)

parent 755f3147
@@ -30,6 +30,7 @@ from sglang.srt.managers.expert_location_dispatch import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     cpu_has_amx_support,
+    get_bool_env_var,
     get_compiler_backend,
     is_cpu,
     is_cuda,
@@ -38,6 +39,7 @@ from sglang.srt.utils import (
 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
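For context, the new `_use_aiter` flag is just an environment-variable check combined with a ROCm-build check. A minimal self-contained sketch of the pattern (the exact set of truthy spellings accepted by sglang's `get_bool_env_var` is an assumption here):

```python
import os

import torch


def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Assumed behavior: treat common truthy spellings as True.
    return os.getenv(name, default).lower() in ("true", "1")


# torch.version.hip is non-None only on ROCm builds of PyTorch,
# so the aiter path can never trigger on CUDA or CPU installs.
_is_hip = torch.version.hip is not None
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
```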
@@ -46,6 +48,11 @@ if _is_cuda:
 if _is_cuda or _is_hip:
     from sgl_kernel import topk_softmax

+if _use_aiter:
+    try:
+        from aiter import biased_grouped_topk as aiter_biased_grouped_topk
+    except ImportError:
+        raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")

 def fused_topk_torch_native(
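With this guard, `aiter` only has to be installed when the flag is actually set. Opting in from Python could look like the following sketch; the variable must be set before the sglang modules are imported, since the module-level `_use_aiter` flag is evaluated at import time:

```python
import os

# Must run before importing sglang: the module-level _use_aiter
# flag reads this variable once, at import time.
os.environ["SGLANG_USE_AITER"] = "1"
```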
@@ -347,6 +354,25 @@ def biased_grouped_topk_gpu(
             topk_ids, expert_location_dispatch_info, num_token_non_padded
         )
         return topk_weights, topk_ids
+    elif _use_aiter:
+        token = gating_output.shape[0]
+        device = gating_output.device
+        assert (
+            hidden_states.shape[0] == gating_output.shape[0]
+        ), f"Number of tokens mismatch: hidden_states.shape[0] = {hidden_states.shape[0]}, gating_output.shape[0] = {gating_output.shape[0]}"
+        topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
+        topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
+        aiter_biased_grouped_topk(
+            gating_output,
+            correction_bias,
+            topk_weights,
+            topk_ids,
+            num_expert_group,
+            topk_group,
+            renormalize,
+            routed_scaling_factor,
+        )
+        return topk_weights, topk_ids
     else:
         biased_grouped_topk_fn = (
             torch.compile(
...
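Note the in-place output convention in the hunk above: `topk_weights` (float32) and `topk_ids` (int32) are caller-allocated and filled by the HIP kernel rather than returned as new tensors. For readers unfamiliar with the routing being accelerated, below is a rough pure-PyTorch sketch of biased grouped top-k selection, modeled on the reference `biased_grouped_topk` path; the HIP kernel's exact tie-breaking and dtype handling may differ:

```python
import torch


def biased_grouped_topk_ref(
    gating_output: torch.Tensor,    # [num_tokens, num_experts]
    correction_bias: torch.Tensor,  # [num_experts]
    topk: int,
    num_expert_group: int,
    topk_group: int,
    renormalize: bool,
    routed_scaling_factor: float,
):
    num_tokens, num_experts = gating_output.shape
    scores = gating_output.sigmoid()
    # The bias steers expert *selection* only; output weights use raw scores.
    scores_for_choice = scores + correction_bias.unsqueeze(0)
    grouped = scores_for_choice.view(num_tokens, num_expert_group, -1)
    # Rank expert groups by the sum of their two best biased scores.
    group_scores = grouped.topk(2, dim=-1).values.sum(dim=-1)
    group_idx = group_scores.topk(topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
    # Expand the group mask back to per-expert granularity.
    score_mask = group_mask.repeat_interleave(num_experts // num_expert_group, dim=1)
    masked = scores_for_choice.masked_fill(score_mask == 0, float("-inf"))
    topk_ids = masked.topk(topk, dim=-1).indices
    topk_weights = scores.gather(1, topk_ids)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    # The aiter kernel fuses routed_scaling_factor into the weights, which is
    # why the DeepseekV2MoE hunk below skips the post-hoc multiply.
    return topk_weights * routed_scaling_factor, topk_ids.to(torch.int32)
```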
@@ -421,7 +421,7 @@ class CudaGraphRunner:
             empty_cache=False,
         )
         capture_range.set_description(
-            f"Capturing batches ({avail_mem=:.2f} GB)"
+            f"Capturing batches ({bs=} {avail_mem=:.2f} GB)"
         )
         with patch_model(
...
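The log-message tweak uses Python 3.8+ self-documenting f-strings: `{bs=}` expands to both the variable name and its value, so the capture progress bar now also reports the batch size. For example:

```python
bs, avail_mem = 16, 37.4821
print(f"Capturing batches ({bs=} {avail_mem=:.2f} GB)")
# -> Capturing batches (bs=16 avail_mem=37.48 GB)
```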
@@ -388,7 +388,8 @@ class DeepseekV2MoE(nn.Module):
         final_hidden_states = self.experts(
             hidden_states=hidden_states, router_logits=router_logits
         )
-        if not _is_cuda:
+        if not _is_cuda and not _use_aiter:
+            # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
...
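Why the multiply can be skipped on the aiter path: scaling is linear, so folding `routed_scaling_factor` into the router weights inside the top-k kernel is equivalent to scaling the combined expert output afterwards. A toy check (shapes are illustrative only):

```python
import torch

routed_scaling_factor = 2.5
expert_out = torch.randn(4, 8, 16)  # [tokens, topk, hidden], illustrative
weights = torch.rand(4, 8)          # [tokens, topk]

# Unfused path: combine experts with raw weights, then scale the result.
unfused = torch.einsum("tk,tkh->th", weights, expert_out) * routed_scaling_factor
# Fused path (aiter / CUDA): weights arrive pre-scaled from the top-k kernel.
fused = torch.einsum("tk,tkh->th", weights * routed_scaling_factor, expert_out)
assert torch.allclose(unfused, fused, atol=1e-5)
```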