Unverified Commit ec15c836 authored by Yuan Luo, committed by GitHub

Optimize Qwen3-moe model by using flashinfer fused allreduce (#9973)


Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
parent 106c2b31
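
Background for the diff below: with tensor parallelism, each decoder layer normally ends with an all-reduce of the partial MLP output, a residual add, and the RMSNorm that feeds the next layer. FlashInfer's allreduce fusion runs these steps as a single kernel. The sketch below only restates the unfused reference semantics that the fused path replaces; the function and parameter names are illustrative and not taken from sglang or FlashInfer.

# Unfused reference semantics (illustrative only; assumes torch.distributed
# is already initialized and the default group is the tensor-parallel group).
import torch
import torch.distributed as dist

def unfused_allreduce_add_rmsnorm(
    hidden_states: torch.Tensor,   # partial MLP output on this TP rank
    residual: torch.Tensor,        # residual stream entering the layer
    norm_weight: torch.Tensor,     # weight of the next RMSNorm
    eps: float = 1e-6,
):
    # 1) sum the partial MLP outputs across tensor-parallel ranks
    dist.all_reduce(hidden_states)
    # 2) add the residual; the sum becomes the new residual
    residual = residual + hidden_states
    # 3) RMSNorm of the new residual is the input to the next layer
    variance = residual.float().pow(2).mean(-1, keepdim=True)
    normed = residual.float() * torch.rsqrt(variance + eps)
    return (normed * norm_weight.float()).to(hidden_states.dtype), residual
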
@@ -42,9 +42,15 @@ from sglang.srt.layers.moe import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import is_cuda, is_flashinfer_available, is_sm100_supported
+from sglang.srt.utils import (
+    is_cuda,
+    is_flashinfer_available,
+    is_sm90_supported,
+    is_sm100_supported,
+)
 
 _is_flashinfer_available = is_flashinfer_available()
+_is_sm90_supported = is_cuda() and is_sm90_supported()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
 
 FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048
@@ -484,11 +490,11 @@ class CommunicateWithAllReduceAndLayerNormFn:
         # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
         # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
         if (
-            _is_sm100_supported
+            (_is_sm100_supported or _is_sm90_supported)
             and _is_flashinfer_available
             and hasattr(layernorm, "forward_with_allreduce_fusion")
             and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
-            and hidden_states.shape[0] <= 2048
+            and hidden_states.shape[0] <= 4096
         ):
             hidden_states, residual = layernorm.forward_with_allreduce_fusion(
                 hidden_states, residual
......
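
The net effect of the two hunks above: SM90 (Hopper) GPUs now qualify for the fused path in addition to SM100, and the per-forward token cap in the condition rises from 2048 to 4096. Restated as a standalone predicate (illustrative helper, not part of sglang; the placeholder flags stand in for the module-level capability checks added above):

# Placeholder capability flags; in sglang these come from is_cuda(),
# is_sm90_supported(), is_sm100_supported() and is_flashinfer_available().
_is_flashinfer_available = True
_is_sm90_supported = True
_is_sm100_supported = False

def can_use_fused_allreduce(hidden_states, layernorm, server_args: dict) -> bool:
    # Mirrors the updated `if` condition in the hunk above.
    return (
        (_is_sm100_supported or _is_sm90_supported)
        and _is_flashinfer_available
        and hasattr(layernorm, "forward_with_allreduce_fusion")
        and bool(server_args["enable_flashinfer_allreduce_fusion"])
        and hidden_states.shape[0] <= 4096
    )
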
@@ -105,11 +105,14 @@ class Qwen2MoeMLP(nn.Module):
     def forward(
         self,
         x,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x, skip_all_reduce=use_reduce_scatter)
+        x, _ = self.down_proj(
+            x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
+        )
         return x
 
 
......
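
Design note on the hunk above: down_proj is a row-parallel projection, so it would normally all-reduce its own output. When the decoder layer has already decided to fuse (should_allreduce_fusion), or when reduce-scatter is used instead, the projection skips its reduction so that the communication happens exactly once, in the fused allreduce+RMSNorm kernel (or the reduce-scatter path) later in the layer.
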
@@ -42,7 +42,10 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe import get_moe_a2a_backend
+from sglang.srt.layers.moe import (
+    get_moe_a2a_backend,
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
+)
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
@@ -57,10 +60,17 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
-from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty
+from sglang.srt.utils import (
+    add_prefix,
+    is_cuda,
+    is_flashinfer_available,
+    is_non_idle_and_non_empty,
+)
 
 Qwen3MoeConfig = None
 
+_is_flashinfer_available = is_flashinfer_available()
+
 logger = logging.getLogger(__name__)
 
 _is_cuda = is_cuda()
@@ -119,11 +129,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         self,
         hidden_states: torch.Tensor,
         forward_batch: Optional[ForwardBatch] = None,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
 
         if not get_moe_a2a_backend().is_deepep():
-            return self.forward_normal(hidden_states, use_reduce_scatter)
+            return self.forward_normal(
+                hidden_states, should_allreduce_fusion, use_reduce_scatter
+            )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
 
@@ -137,6 +150,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
     def forward_normal(
         self,
         hidden_states: torch.Tensor,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
@@ -146,7 +160,12 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         router_logits, _ = self.gate(hidden_states)
         topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = self.experts(hidden_states, topk_output)
-        if self.tp_size > 1 and not use_reduce_scatter:
+        if (
+            self.tp_size > 1
+            and not should_allreduce_fusion
+            and not use_reduce_scatter
+            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+        ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states.view(num_tokens, hidden_dim)
 
@@ -500,6 +519,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
             allow_reduce_scatter=True,
+            is_last_layer=(self.layer_id == self.config.num_hidden_layers - 1),
         )
 
     def forward(
@@ -525,17 +545,28 @@
             hidden_states, residual, forward_batch
         )
 
+        should_allreduce_fusion = (
+            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+                forward_batch
+            )
+        )
+
         # For DP with padding, reduce scatter can be used instead of all-reduce.
         use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
             forward_batch
         )
 
-        hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
+        hidden_states = self.mlp(
+            hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
+        )
 
-        hidden_states, residual = self.layer_communicator.postprocess_layer(
-            hidden_states, residual, forward_batch
-        )
+        if should_allreduce_fusion:
+            hidden_states._sglang_needs_allreduce_fusion = True
+        else:
+            hidden_states, residual = self.layer_communicator.postprocess_layer(
+                hidden_states, residual, forward_batch
+            )
 
         return hidden_states, residual
 
     def op_comm_prepare_attn(
......
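
Putting the decoder-layer hunks together: each layer asks the LayerCommunicator whether its MLP all-reduce can be fused with the next layer's layernorm (the new is_last_layer argument presumably lets the final layer opt out, since there is no following layernorm to fuse into); if so, the output tensor is only tagged with _sglang_needs_allreduce_fusion and the usual postprocess all-reduce is skipped. The path is gated by the enable_flashinfer_allreduce_fusion server argument. A hypothetical launch sketch, assuming that key maps to a ServerArgs field of the same name; the checkpoint and tp_size are placeholders:

# Hypothetical usage sketch (offline engine); the flag below is assumed to be
# the opt-in switch read via global_server_args_dict in the hunks above.
import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(
        model_path="Qwen/Qwen3-30B-A3B",          # placeholder Qwen3-MoE checkpoint
        tp_size=4,                                # fusion only matters when tp_size > 1
        enable_flashinfer_allreduce_fusion=True,  # assumed ServerArgs field
    )
    print(llm.generate("Hello", {"max_new_tokens": 8}))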