Unverified Commit 505329ca authored by fzyzcjy, committed by GitHub

Support shared experts overlap in cutlass moe (#11611)

parent 8a382fd3
@@ -800,7 +800,7 @@ class FusedMoE(torch.nn.Module):
                 f"Unsupported weight_name {weight_name} for FusedMoE weight_loader_fused. Nothing is loaded."
             )
 
-    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
+    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs):
         origin_hidden_states_dim = hidden_states.shape[-1]
 
         assert self.quant_method is not None
@@ -825,6 +825,7 @@ class FusedMoE(torch.nn.Module):
         combine_input = self.quant_method.apply(
             layer=self,
             dispatch_output=dispatch_output,
+            **kwargs,
         )
 
         final_hidden_states = self.dispatcher.combine(combine_input)
...
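The hunks above only thread `**kwargs` from `FusedMoE.forward` through to `quant_method.apply`, so a caller can hand backend-specific extras (such as the shared-experts callback introduced below) to the quant method without widening the forward signature. A minimal, self-contained sketch of that pass-through pattern; `_MoELayer` and `_QuantMethod` are illustrative stand-ins, not names from this change:

from typing import Any, Callable, Optional

import torch


class _QuantMethod:
    """Stand-in for a quant method whose apply() accepts optional extras."""

    def apply(
        self,
        layer: "_MoELayer",
        dispatch_output: torch.Tensor,
        forward_shared_experts: Optional[Callable[[], None]] = None,
        alt_stream: Optional[Any] = None,
    ) -> torch.Tensor:
        # The extra kwargs default to None, so existing call sites keep working.
        if forward_shared_experts is not None:
            forward_shared_experts()
        return dispatch_output


class _MoELayer(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.quant_method = _QuantMethod()

    def forward(self, hidden_states: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        # forward() does not inspect the kwargs; it simply forwards them.
        return self.quant_method.apply(self, dispatch_output=hidden_states, **kwargs)


layer = _MoELayer()
x = torch.randn(4, 8)
print(layer(x).shape)                                        # old behaviour, no extras
print(layer(x, forward_shared_experts=lambda: None).shape)   # with a callback

Keeping the extras as keyword arguments with None defaults means every other quant method's apply() is untouched by this change.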
@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 import logging
+import os
 from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 import torch
@@ -1347,6 +1348,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         self,
         layer: FusedMoE,
         dispatch_output: StandardDispatchOutput,
+        forward_shared_experts=None,
+        alt_stream=None,
     ) -> CombineInput:
         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
 
@@ -1418,9 +1421,19 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             )[0]
             if should_use_flashinfer_cutlass_moe_fp4_allgather():
                 output, global_output = get_local_dp_buffer(), output
+                if forward_shared_experts is not None:
+                    alt_stream.wait_stream(torch.cuda.current_stream())
+                    with torch.cuda.stream(alt_stream):
+                        forward_shared_experts()
                 get_tp_group().reduce_scatterv(
                     global_output, output=output, sizes=get_dp_global_num_tokens()
                 )
+                if forward_shared_experts is not None:
+                    torch.cuda.current_stream().wait_stream(alt_stream)
                 return StandardCombineInput(hidden_states=output)
 
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
...
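The `apply` hunk is where the overlap actually happens: when a `forward_shared_experts` callback is supplied, it runs on `alt_stream` while the current stream performs the `reduce_scatterv` of the routed-expert output, and the current stream waits on `alt_stream` again before the shared-experts result can be consumed. A rough standalone sketch of that stream-overlap idiom, assuming a CUDA device and using a plain elementwise op as a stand-in for the collective (`shared_mlp`, `routed_out`, and `comm_out` are hypothetical names):

import torch


def overlapped_shared_experts(
    hidden_states: torch.Tensor,
    routed_out: torch.Tensor,
    shared_mlp: torch.nn.Module,
    alt_stream: torch.cuda.Stream,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Make the alternate stream see all work queued so far (e.g. the kernels
    # that produced hidden_states), then launch the shared experts on it.
    alt_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(alt_stream):
        shared_out = shared_mlp(hidden_states)

    # Meanwhile the current stream is free for the communication kernel;
    # a cheap elementwise op stands in for reduce_scatterv here.
    comm_out = routed_out * 2.0

    # Re-join: the current stream must not consume shared_out before the
    # alternate stream has finished producing it.
    torch.cuda.current_stream().wait_stream(alt_stream)
    # Tell the caching allocator that shared_out is now used on the current stream.
    shared_out.record_stream(torch.cuda.current_stream())
    return comm_out, shared_out


if torch.cuda.is_available():
    dev = torch.device("cuda")
    mlp = torch.nn.Linear(16, 16).to(dev)
    x = torch.randn(8, 16, device=dev)
    y = torch.randn(8, 16, device=dev)
    out, shared = overlapped_shared_experts(x, y, mlp, torch.cuda.Stream())
    print(out.shape, shared.shape)

Both `wait_stream` calls matter: the first orders the shared-experts kernels after the producers of `hidden_states`, and the second prevents downstream work from reading the shared-experts output too early.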
@@ -655,6 +655,7 @@ class DeepseekV2MoE(nn.Module):
         self._enable_a2a_moe = (
             get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake()
         )
+        self._fuse_shared_experts_inside_sbo = SboFlags.fuse_shared_experts_inside_sbo()
 
     def get_moe_weights(self):
         return [
@@ -746,6 +747,7 @@ class DeepseekV2MoE(nn.Module):
             return self.forward_cpu(hidden_states, should_allreduce_fusion)
 
         if hidden_states.shape[0] > 0:
+            if not self._fuse_shared_experts_inside_sbo:
                 shared_output = self._forward_shared_experts(
                     hidden_states, gemm_output_zero_allocator
                 )
@@ -756,7 +758,27 @@ class DeepseekV2MoE(nn.Module):
             shared_output = None
             topk_output = self.topk.empty_topk_output(hidden_states.device)
-        final_hidden_states = self.experts(hidden_states, topk_output)
+        if self._fuse_shared_experts_inside_sbo:
+            shared_output = None
+
+            def _forward_shared_experts_and_put_results():
+                nonlocal shared_output
+                shared_output = self._forward_shared_experts(
+                    hidden_states, gemm_output_zero_allocator
+                )
+
+        final_hidden_states = self.experts(
+            hidden_states,
+            topk_output,
+            **(
+                dict(
+                    forward_shared_experts=_forward_shared_experts_and_put_results,
+                    alt_stream=self.alt_stream,
+                )
+                if self._fuse_shared_experts_inside_sbo
+                else {}
+            ),
+        )
 
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
...
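On the model side, `DeepseekV2MoE` skips the eager shared-experts call when `SboFlags.fuse_shared_experts_inside_sbo()` is set and instead passes a closure (plus `self.alt_stream`) into `self.experts(...)`; the closure publishes its result through `nonlocal shared_output`. A toy sketch of that call-site pattern, with hypothetical `experts` and `shared_mlp` stand-ins rather than the real sglang modules:

from typing import Any, Callable, Optional

import torch


def run_moe(
    hidden_states: torch.Tensor,
    experts: Callable[..., torch.Tensor],
    shared_mlp: torch.nn.Module,
    fuse_shared_experts_inside_sbo: bool,
    alt_stream: Optional[Any] = None,
) -> torch.Tensor:
    shared_output: Optional[torch.Tensor] = None

    if not fuse_shared_experts_inside_sbo:
        # Legacy path: compute the shared experts eagerly, before the routed experts.
        shared_output = shared_mlp(hidden_states)

    if fuse_shared_experts_inside_sbo:

        def _forward_shared_experts_and_put_results() -> None:
            # Invoked from inside experts(); publishes its result through the
            # enclosing scope instead of a return value.
            nonlocal shared_output
            shared_output = shared_mlp(hidden_states)

    final_hidden_states = experts(
        hidden_states,
        **(
            dict(
                forward_shared_experts=_forward_shared_experts_and_put_results,
                alt_stream=alt_stream,
            )
            if fuse_shared_experts_inside_sbo
            else {}
        ),
    )

    assert shared_output is not None, "experts() must invoke the callback"
    return final_hidden_states + shared_output


# Toy usage: experts() calls the callback itself, as the quant method does above.
def _toy_experts(x: torch.Tensor, forward_shared_experts=None, alt_stream=None):
    if forward_shared_experts is not None:
        forward_shared_experts()
    return x * 0.5


mlp = torch.nn.Linear(8, 8)
x = torch.randn(4, 8)
print(run_moe(x, _toy_experts, mlp, fuse_shared_experts_inside_sbo=True).shape)

The conditional kwargs dict keeps the non-SBO path byte-for-byte identical to the old call, so backends that ignore the extras are unaffected.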