Unverified commit 489934be authored by Yi Zhang, committed by GitHub

Fuse renormalization into the MoE top-k softmax kernel (Python side) (#7751)


Co-authored-by: ispobock <ispobaoke@gmail.com>
Co-authored-by: zhyncs <me@zhyncs.com>
parent 43f93f63
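
For context: this change moves the top-k renormalization step out of the Python post-processing and into the topk_softmax kernel shipped with sgl-kernel 0.2.2. Below is a minimal torch-native sketch of the semantics the fused kernel is expected to cover; the function name is illustrative, not the kernel's API.

import torch

def topk_softmax_reference(gating_output: torch.Tensor, topk: int, renormalize: bool):
    # Softmax over the expert dimension, then select the top-k experts per token.
    probs = torch.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(probs, topk, dim=-1)
    # With this commit, the renormalization below is performed inside the CUDA
    # kernel instead of in a separate torch.compile'd Python step.
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)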
@@ -52,7 +52,7 @@ runtime_common = [
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.2.1",
+    "sgl-kernel==0.2.2",
     "torch==2.7.1",
     "torchaudio==2.7.1",
     "torchvision==0.22.1",
......
@@ -650,7 +650,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda:
         assert_pkg_version(
             "sgl-kernel",
-            "0.2.1",
+            "0.2.2",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
......
@@ -108,38 +108,14 @@ def fused_topk(
         M, topk, dtype=torch.float32, device=hidden_states.device
     )
     topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
-    token_expert_indicies = torch.empty(
-        M, topk, dtype=torch.int32, device=hidden_states.device
-    )
     topk_softmax(
         topk_weights,
         topk_ids,
-        token_expert_indicies,
         gating_output.float(),
+        renormalize,
     )
-    del token_expert_indicies
     return _fused_topk_postprocess(
         topk_weights=topk_weights,
         topk_ids=topk_ids,
-        renormalize=renormalize,
         expert_location_dispatch_info=expert_location_dispatch_info,
         num_token_non_padded=num_token_non_padded,
     )


 @torch.compile(dynamic=True, backend=get_compiler_backend())
 def _fused_topk_postprocess(
     topk_weights,
     topk_ids,
-    renormalize,
     expert_location_dispatch_info,
     num_token_non_padded,
 ):
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
     topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
     _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
     return topk_weights, topk_ids
......
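
A hedged usage sketch of the updated call path, assuming the kernel interface shown in the diff (pre-allocated output buffers, float gating logits, and a trailing renormalize flag); the import path, shapes, and device are assumptions for illustration, not a confirmed API contract.

import torch
from sgl_kernel import topk_softmax  # assumed import path for sgl-kernel >= 0.2.2

M, num_experts, topk = 8, 64, 2
gating_output = torch.randn(M, num_experts, device="cuda")

topk_weights = torch.empty(M, topk, dtype=torch.float32, device="cuda")
topk_ids = torch.empty(M, topk, dtype=torch.int32, device="cuda")

# Renormalization is applied inside the kernel, so no Python-side
# topk_weights / topk_weights.sum(dim=-1, keepdim=True) step follows.
topk_softmax(topk_weights, topk_ids, gating_output.float(), True)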