Unverified Commit 84b006b2 authored by Cheng Wan, committed by GitHub

Cleanup MoE Refactor (#9223)

parent 8ca07bd9
@@ -573,6 +573,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         topk_output: TopKOutput,
         moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
+        from sglang.srt.layers.moe.topk import TopKOutputChecker
+
         if self.use_flashinfer:
             # Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
             x_quant, x_scale = mxfp8_quantize(
@@ -580,8 +583,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             )  # to mxfp8
             x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
             assert x_quant.shape[-1] == self.hidden_size
-            top_k, router_logits = topk_output
+            assert TopKOutputChecker.format_is_bypassed(topk_output)
+            top_k = topk_output.topk_config.top_k
+            router_logits = topk_output.router_logits

             trtllm_gen_output = trtllm_fp4_block_scale_moe(
                 router_logits.to(torch.bfloat16),
@@ -602,8 +607,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 None,  # output2_scale_scalar
                 layer.num_experts,
                 top_k,
-                None,  # n_group
-                None,  # topk_group
+                None,  # n_group  # TODO: support n_group
+                None,  # topk_group  # TODO: support topk_group
                 self.intermediate_size,  # padded to multiple of 256
                 layer.moe_ep_rank * layer.num_local_experts,  # local_expert_offset
                 layer.num_local_experts,  # local num experts
...
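Note on the hunk above: the kernel path no longer tuple-unpacks the TopKOutput; it asserts the bypassed format and reads topk_config.top_k and router_logits from the object. Below is a minimal, self-contained sketch of that access pattern. The _TopKConfig and _BypassedTopKOutput dataclasses are hypothetical stand-ins for sglang's types; only the attribute names come from the diff.

# Sketch only: stand-ins for sglang's TopKConfig / bypassed TopKOutput.
from dataclasses import dataclass

import torch


@dataclass
class _TopKConfig:  # hypothetical stand-in for sglang's TopKConfig
    top_k: int


@dataclass
class _BypassedTopKOutput:  # hypothetical stand-in for a "bypassed" TopKOutput
    topk_config: _TopKConfig
    router_logits: torch.Tensor


def consume_bypassed(topk_output: _BypassedTopKOutput) -> tuple[int, torch.Tensor]:
    # Mirrors the new unpacking above: top_k comes from the config, and the raw
    # router logits are cast to bfloat16 before being handed to the fused kernel.
    top_k = topk_output.topk_config.top_k
    router_logits = topk_output.router_logits.to(torch.bfloat16)
    return top_k, router_logits


if __name__ == "__main__":
    out = _BypassedTopKOutput(_TopKConfig(top_k=8), torch.randn(4, 128))
    print(consume_bypassed(out)[0])  # -> 8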
@@ -459,15 +459,15 @@ class DeepseekV2MoE(nn.Module):
         with torch.cuda.stream(self.alt_stream):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            kwargs = {"hidden_states": hidden_states}
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-            final_hidden_states = self.experts(**kwargs)
+            topk_output = self.topk(hidden_states, router_logits)
+            final_hidden_states = self.experts(hidden_states, topk_output)
             if not _is_cuda:
                 final_hidden_states *= self.routed_scaling_factor
         current_stream.wait_stream(self.alt_stream)

         with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
             final_hidden_states_out = torch.empty_like(final_hidden_states)
             torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
             final_hidden_states = final_hidden_states_out
             sm.tag(final_hidden_states)
@@ -489,10 +489,9 @@ class DeepseekV2MoE(nn.Module):
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
-        kwargs = {"hidden_states": hidden_states}
-        kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-        final_hidden_states = self.experts(**kwargs)
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
...
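The two DeepseekV2MoE hunks above replace the kwargs-dict construction with a direct positional call into self.experts. A minimal sketch of the resulting call shape, using hypothetical stub modules rather than sglang's gate/TopK/FusedMoE classes:

# Sketch only: the module names below are illustrative stand-ins.
import torch
import torch.nn as nn


class _StubExperts(nn.Module):
    """Hypothetical stand-in for the fused experts module."""

    def forward(self, hidden_states: torch.Tensor, topk_output) -> torch.Tensor:
        # A real experts module dispatches tokens according to topk_output;
        # the stub just echoes its input to keep the example runnable.
        return hidden_states


def moe_forward(gate, topk, experts, hidden_states):
    router_logits = gate(hidden_states)  # (num_tokens, n_experts)
    topk_output = topk(hidden_states, router_logits)
    # New call style: positional arguments instead of building a kwargs dict.
    return experts(hidden_states, topk_output)


if __name__ == "__main__":
    gate = nn.Linear(16, 4)
    topk = lambda h, logits: logits.topk(2, dim=-1)  # placeholder top-k
    x = torch.randn(3, 16)
    print(moe_forward(gate, topk, _StubExperts(), x).shape)  # torch.Size([3, 16])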
@@ -509,9 +509,8 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         with torch.cuda.stream(self.alt_stream):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            kwargs = {"hidden_states": hidden_states}
-            kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-            final_hidden_states = self.experts(**kwargs)
+            topk_output = self.topk(hidden_states, router_logits)
+            final_hidden_states = self.experts(hidden_states, topk_output)
             if not _is_cuda:
                 final_hidden_states *= self.routed_scaling_factor
         current_stream.wait_stream(self.alt_stream)
@@ -552,9 +551,8 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
-        kwargs = {"hidden_states": hidden_states}
-        kwargs["topk_output"] = self.topk(hidden_states, router_logits)
-        final_hidden_states = self.experts(**kwargs)
+        topk_output = self.topk(hidden_states, router_logits)
+        final_hidden_states = self.experts(hidden_states, topk_output)
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
...
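Glm4MoeSparseMoeBlock keeps the same dual-stream structure as DeepseekV2MoE: the routed-expert path runs under torch.cuda.stream(self.alt_stream) and the current stream waits on the alternate stream before the shared-expert output is combined. A rough sketch of that overlap pattern (assuming a CUDA device is available; the helper names here are hypothetical, not sglang APIs):

# Sketch only: illustrates overlapping routed and shared expert work on two streams.
import torch


def overlapped_moe(shared_fn, routed_fn, hidden_states, alt_stream):
    current_stream = torch.cuda.current_stream()
    # Make work queued on alt_stream see tensors produced so far on the current stream.
    alt_stream.wait_stream(current_stream)
    with torch.cuda.stream(alt_stream):
        routed_out = routed_fn(hidden_states)  # routed experts on the alternate stream
    shared_out = shared_fn(hidden_states)      # shared experts on the current stream
    current_stream.wait_stream(alt_stream)     # sync before combining the two outputs
    return routed_out + shared_out


if __name__ == "__main__" and torch.cuda.is_available():
    x = torch.randn(8, 16, device="cuda")
    y = overlapped_moe(lambda t: t * 2, lambda t: t + 1, x, torch.cuda.Stream())
    print(y.shape)  # torch.Size([8, 16])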