Unverified commit 84b006b2, authored by Cheng Wan and committed by GitHub

Cleanup MoE Refactor (#9223)

parent 8ca07bd9
@@ -573,6 +573,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
topk_output: TopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
+from sglang.srt.layers.moe.topk import TopKOutputChecker
if self.use_flashinfer:
# Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
x_quant, x_scale = mxfp8_quantize(
@@ -580,8 +583,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
) # to mxfp8
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
assert x_quant.shape[-1] == self.hidden_size
+assert TopKOutputChecker.format_is_bypassed(topk_output)
-top_k, router_logits = topk_output
+top_k = topk_output.topk_config.top_k
+router_logits = topk_output.router_logits
trtllm_gen_output = trtllm_fp4_block_scale_moe(
router_logits.to(torch.bfloat16),
@@ -602,8 +607,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
None, # output2_scale_scalar
layer.num_experts,
top_k,
-None, # n_group
-None, # topk_group
+None, # n_group # TODO: support n_group
+None, # topk_group # TODO: support topk_group
self.intermediate_size, # padded to multiple of 256
layer.moe_ep_rank * layer.num_local_experts, # local_expert_offset
layer.num_local_experts, # local num experts
......
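The Mxfp4MoEMethod hunks above stop tuple-unpacking the TopKOutput and instead assert the bypassed format, then read the top-k value and the router logits from named fields. Below is a minimal sketch of that access pattern; TopKConfig and BypassedTopKOutput are simplified stand-ins for the real classes in sglang.srt.layers.moe.topk and model only the fields the hunk touches.

from dataclasses import dataclass

import torch


@dataclass
class TopKConfig:
    top_k: int  # number of experts selected per token


@dataclass
class BypassedTopKOutput:
    # "Bypassed" format: routing is deferred to the fused kernel, so the object
    # carries the raw router logits plus the top-k configuration instead of
    # materialized expert ids and weights.
    topk_config: TopKConfig
    router_logits: torch.Tensor


def read_bypassed(topk_output: BypassedTopKOutput) -> tuple[int, torch.Tensor]:
    # Old code unpacked positionally: top_k, router_logits = topk_output
    # New code reads the named fields, as in the diff above.
    top_k = topk_output.topk_config.top_k
    router_logits = topk_output.router_logits.to(torch.bfloat16)
    return top_k, router_logits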
@@ -459,15 +459,15 @@ class DeepseekV2MoE(nn.Module):
with torch.cuda.stream(self.alt_stream):
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
-kwargs = {"hidden_states": hidden_states}
-kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-final_hidden_states = self.experts(**kwargs)
+topk_output = self.topk(hidden_states, router_logits)
+final_hidden_states = self.experts(hidden_states, topk_output)
if not _is_cuda:
final_hidden_states *= self.routed_scaling_factor
current_stream.wait_stream(self.alt_stream)
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
final_hidden_states_out = torch.empty_like(final_hidden_states)
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
final_hidden_states = final_hidden_states_out
sm.tag(final_hidden_states)
@@ -489,10 +489,9 @@ class DeepseekV2MoE(nn.Module):
shared_output = self._forward_shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
-kwargs = {"hidden_states": hidden_states}
-kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-final_hidden_states = self.experts(**kwargs)
+topk_output = self.topk(hidden_states, router_logits)
+final_hidden_states = self.experts(hidden_states, topk_output)
if not _is_cuda and not _use_aiter:
# fused in biased_grouped_topk so we can skip here
final_hidden_states *= self.routed_scaling_factor
......
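The DeepseekV2MoE hunks above, and the identical Glm4MoeSparseMoeBlock hunks below, drop the intermediate kwargs dict and call the experts module with both arguments directly. Here is a minimal sketch of the old versus new call site, using placeholder topk and experts callables rather than the real TopK and FusedMoE modules:

import torch


def topk(hidden_states: torch.Tensor, router_logits: torch.Tensor):
    # Placeholder for self.topk: returns (weights, expert ids) per token.
    return router_logits.topk(k=2, dim=-1)


def experts(hidden_states: torch.Tensor, topk_output) -> torch.Tensor:
    # Placeholder for self.experts: any callable taking (hidden_states, topk_output).
    weights, _ids = topk_output
    return hidden_states * weights.sum(dim=-1, keepdim=True)


hidden_states = torch.randn(4, 8)
router_logits = torch.randn(4, 16)

# Old call site built a dict and splatted it:
#     kwargs = {"hidden_states": hidden_states}
#     kwargs["topk_output"] = topk(hidden_states, router_logits)
#     final_hidden_states = experts(**kwargs)
# New call site passes both arguments positionally, as in the diff:
topk_output = topk(hidden_states, router_logits)
final_hidden_states = experts(hidden_states, topk_output)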
@@ -509,9 +509,8 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
with torch.cuda.stream(self.alt_stream):
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
-kwargs = {"hidden_states": hidden_states}
-kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-final_hidden_states = self.experts(**kwargs)
+topk_output = self.topk(hidden_states, router_logits)
+final_hidden_states = self.experts(hidden_states, topk_output)
if not _is_cuda:
final_hidden_states *= self.routed_scaling_factor
current_stream.wait_stream(self.alt_stream)
@@ -552,9 +551,8 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
shared_output = self._forward_shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
-kwargs = {"hidden_states": hidden_states}
-kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-final_hidden_states = self.experts(**kwargs)
+topk_output = self.topk(hidden_states, router_logits)
+final_hidden_states = self.experts(hidden_states, topk_output)
if not _is_cuda and not _use_aiter:
# fused in biased_grouped_topk so we can skip here
final_hidden_states *= self.routed_scaling_factor
......