Unverified Commit f96413c4, authored by Xiaoyu Zhang, committed by GitHub

Refactor allreduce add rmsnorm pattern (#9278)

parent 08ebdf79
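
The gist of the refactor, as a minimal runnable sketch (not part of the diff below): the duplicated per-model checks (_should_fuse_mlp_allreduce_with_next_layer and _build_fuse_allreduce_lookup_table in DeepseekV2DecoderLayer and GptOssDecoderLayer) are replaced by a single LayerCommunicator.should_fuse_mlp_allreduce_with_next_layer helper, and the batch-size cap moves from a hard-coded 128 to FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048. The stub class and function names below are placeholders for illustration, not sglang APIs; the real helper also checks TP size, SM100 support, flashinfer availability, and the EAGLE speculative-decoding case, and takes a ForwardBatch rather than a raw batch size.

# Illustrative sketch only -- simplified stand-ins, not sglang code.
FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048  # mirrors the new constant in the communicator


class StubLayerCommunicator:
    """Placeholder for LayerCommunicator: the fusion decision now lives here,
    parameterized by is_last_layer instead of per-model lookup tables."""

    def __init__(self, is_last_layer: bool):
        self.is_last_layer = is_last_layer

    def should_fuse_mlp_allreduce_with_next_layer(self, batch_size: int) -> bool:
        # The real helper also requires tp_size > 1, the
        # enable_flashinfer_allreduce_fusion server arg, SM100 + flashinfer,
        # and no EAGLE speculative decoding with DP attention.
        return (not self.is_last_layer) and 0 < batch_size <= FUSE_ALLREDUCE_MAX_BATCH_SIZE


def decoder_layer_forward(comm: StubLayerCommunicator, batch_size: int) -> bool:
    # A decoder layer now simply asks the shared communicator whether the MLP
    # all-reduce can be fused with the next layer's residual add + RMSNorm.
    return comm.should_fuse_mlp_allreduce_with_next_layer(batch_size)


if __name__ == "__main__":
    print(decoder_layer_forward(StubLayerCommunicator(is_last_layer=False), 64))    # True
    print(decoder_layer_forward(StubLayerCommunicator(is_last_layer=True), 64))     # False
    print(decoder_layer_forward(StubLayerCommunicator(is_last_layer=False), 4096))  # False
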
@@ -34,6 +34,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_tp_size,
     get_global_dp_buffer,
     get_local_dp_buffer,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.moe import (
     get_moe_a2a_backend,
@@ -47,6 +48,8 @@ from sglang.srt.utils import is_cuda, is_flashinfer_available
 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()

+FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048
+

 class ScatterMode(Enum):
     """
@@ -162,11 +165,13 @@ class LayerCommunicator:
         post_attention_layernorm: torch.nn.Module,
         # Reduce scatter requires skipping all-reduce in model code after MoE/MLP, so only enable for models which have that implemented. Remove flag once done for all models that use LayerCommunicator.
         allow_reduce_scatter: bool = False,
+        is_last_layer: bool = False,
     ):
         self.layer_scatter_modes = layer_scatter_modes
         self.input_layernorm = input_layernorm
         self.post_attention_layernorm = post_attention_layernorm
         self.allow_reduce_scatter = allow_reduce_scatter
+        self.is_last_layer = is_last_layer

         self._context = CommunicateContext.init_new()
         self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
@@ -264,6 +269,42 @@ class LayerCommunicator:
             and forward_batch.dp_padding_mode.is_max_len()
         )
+    def should_fuse_mlp_allreduce_with_next_layer(
+        self, forward_batch: ForwardBatch
+    ) -> bool:
+        speculative_algo = global_server_args_dict.get("speculative_algorithm", None)
+        if (
+            is_dp_attention_enabled()
+            and speculative_algo is not None
+            and speculative_algo.is_eagle()
+        ):
+            return False
+
+        batch_size = (
+            forward_batch.input_ids.shape[0]
+            if hasattr(forward_batch, "input_ids")
+            else 0
+        )
+        if batch_size > FUSE_ALLREDUCE_MAX_BATCH_SIZE:
+            return False
+
+        static_conditions_met = (
+            (not self.is_last_layer)
+            and (self._context.tp_size > 1)
+            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
+            and _is_sm100_supported
+            and _is_flashinfer_available
+        )
+        if not static_conditions_met:
+            return False
+
+        return (
+            batch_size > 0
+            and batch_size <= FUSE_ALLREDUCE_MAX_BATCH_SIZE
+            and (not self.is_last_layer)
+        )
 @dataclass
 class CommunicateContext:
......
@@ -1852,10 +1852,11 @@ class DeepseekV2DecoderLayer(nn.Module):
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
             allow_reduce_scatter=True,
+            is_last_layer=(
+                is_nextn or (self.layer_id == self.config.num_hidden_layers - 1)
+            ),
         )

-        self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
-
     def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
         return is_nextn or (
             self.config.n_routed_experts is not None
@@ -1863,20 +1864,6 @@ class DeepseekV2DecoderLayer(nn.Module):
             and layer_id % self.config.moe_layer_freq == 0
         )
-    def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
-        """Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
-        batch_size = (
-            forward_batch.input_ids.shape[0]
-            if hasattr(forward_batch, "input_ids")
-            else 0
-        )
-        if batch_size > 128:
-            return False
-
-        return self._fuse_allreduce_lookup_table.get(batch_size, False)
-
     def forward(
         self,
         positions: torch.Tensor,
@@ -1902,11 +1889,9 @@ class DeepseekV2DecoderLayer(nn.Module):
         )
-        should_allreduce_fusion = (
-            self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
-            and not (
-                is_dp_attention_enabled() and self.speculative_algorithm.is_eagle()
-            )
-            and not self.is_nextn
-        )
+        should_allreduce_fusion = (
+            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+                forward_batch
+            )
+        )
         # For DP with padding, reduce scatter can be used instead of all-reduce.
@@ -1997,26 +1982,6 @@ class DeepseekV2DecoderLayer(nn.Module):
         )
         return output
-    def _build_fuse_allreduce_lookup_table(self):
-        static_conditions_met = (
-            self.layer_id != self.config.num_hidden_layers - 1
-            and get_tensor_model_parallel_world_size() > 1
-            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
-            and _is_sm100_supported
-            and _is_flashinfer_available
-        )
-
-        if not static_conditions_met:
-            return {}
-
-        lookup_table = {}
-        for batch_size in range(129):  # 0 to 128
-            is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
-            should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
-            lookup_table[batch_size] = should_fuse
-
-        return lookup_table
-
 class DeepseekV2Model(nn.Module):
     fall_back_to_pt_during_load = False
......
@@ -453,44 +453,11 @@ class GptOssDecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            is_last_layer=(
+                self.is_nextn or (self.layer_id == self.config.num_hidden_layers - 1)
+            ),
         )
-        self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
-
-    def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
-        """Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
-        batch_size = (
-            forward_batch.input_ids.shape[0]
-            if hasattr(forward_batch, "input_ids")
-            else 0
-        )
-        if batch_size > 128:
-            return False
-
-        return self._fuse_allreduce_lookup_table.get(batch_size, False)
-
-    def _build_fuse_allreduce_lookup_table(self):
-        static_conditions_met = (
-            self.layer_id != self.config.num_hidden_layers - 1
-            and get_tensor_model_parallel_world_size() > 1
-            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
-            and _is_sm100_supported
-            and _is_flashinfer_available
-        )
-
-        if not static_conditions_met:
-            return {}
-
-        lookup_table = {}
-        for batch_size in range(129):  # 0 to 128
-            is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
-            should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
-            lookup_table[batch_size] = should_fuse
-
-        return lookup_table
-
     def forward(
         self,
         positions: torch.Tensor,
@@ -514,8 +481,9 @@ class GptOssDecoderLayer(nn.Module):
         )
-        should_allreduce_fusion = (
-            self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
-            and not self.is_nextn
-        )
+        should_allreduce_fusion = (
+            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+                forward_batch
+            )
+        )

         hidden_states = self.mlp(hidden_states, forward_batch, should_allreduce_fusion)
......