Unverified Commit 44e86480 authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

fuse allreduce and residual_rmsnorm (#8731)

parent 8c07fabd
...@@ -441,7 +441,7 @@ class CommunicateWithAllReduceAndLayerNormFn: ...@@ -441,7 +441,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
and _is_flashinfer_available and _is_flashinfer_available
and hasattr(layernorm, "forward_with_allreduce_fusion") and hasattr(layernorm, "forward_with_allreduce_fusion")
and global_server_args_dict["enable_flashinfer_allreduce_fusion"] and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
and hidden_states.shape[0] <= 128 and hidden_states.shape[0] <= 2048
): ):
hidden_states, residual = layernorm.forward_with_allreduce_fusion( hidden_states, residual = layernorm.forward_with_allreduce_fusion(
hidden_states, residual hidden_states, residual
......
...@@ -92,7 +92,7 @@ _workspace_manager = FlashInferWorkspaceManager() ...@@ -92,7 +92,7 @@ _workspace_manager = FlashInferWorkspaceManager()
def ensure_workspace_initialized( def ensure_workspace_initialized(
max_token_num: int = 128, hidden_dim: int = 4096, use_fp32_lamport: bool = False max_token_num: int = 2048, hidden_dim: int = 4096, use_fp32_lamport: bool = False
): ):
"""Ensure workspace is initialized""" """Ensure workspace is initialized"""
if not is_flashinfer_available() or _flashinfer_comm is None: if not is_flashinfer_available() or _flashinfer_comm is None:
...@@ -124,8 +124,8 @@ def flashinfer_allreduce_residual_rmsnorm( ...@@ -124,8 +124,8 @@ def flashinfer_allreduce_residual_rmsnorm(
residual: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
eps: float = 1e-6, eps: float = 1e-6,
max_token_num: int = 128, max_token_num: int = 2048,
use_oneshot: bool = True, use_oneshot: Optional[bool] = None,
trigger_completion_at_end: bool = False, trigger_completion_at_end: bool = False,
fp32_acc: bool = False, fp32_acc: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
......
...@@ -1294,6 +1294,7 @@ class RowParallelLinear(LinearBase): ...@@ -1294,6 +1294,7 @@ class RowParallelLinear(LinearBase):
with use_symmetric_memory(parallel_state.get_tp_group()) as sm: with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_) output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
sm.tag(output_parallel) sm.tag(output_parallel)
if self.reduce_results and self.tp_size > 1 and not skip_all_reduce: if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
output = tensor_model_parallel_all_reduce(output_parallel) output = tensor_model_parallel_all_reduce(output_parallel)
else: else:
......
...@@ -847,10 +847,14 @@ class FusedMoE(torch.nn.Module): ...@@ -847,10 +847,14 @@ class FusedMoE(torch.nn.Module):
) )
sm.tag(final_hidden_states) sm.tag(final_hidden_states)
final_hidden_states = final_hidden_states[
..., :origin_hidden_states_dim
].contiguous()
if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1): if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states[..., :origin_hidden_states_dim].contiguous() return final_hidden_states
@classmethod @classmethod
def make_expert_params_mapping( def make_expert_params_mapping(
......
...@@ -212,7 +212,7 @@ class DeepseekV2MLP(nn.Module): ...@@ -212,7 +212,7 @@ class DeepseekV2MLP(nn.Module):
self, self,
x, x,
forward_batch=None, forward_batch=None,
can_fuse_mlp_allreduce: bool = False, should_allreduce_fusion: bool = False,
use_reduce_scatter: bool = False, use_reduce_scatter: bool = False,
): ):
if (self.tp_size == 1) and x.shape[0] == 0: if (self.tp_size == 1) and x.shape[0] == 0:
...@@ -221,7 +221,7 @@ class DeepseekV2MLP(nn.Module): ...@@ -221,7 +221,7 @@ class DeepseekV2MLP(nn.Module):
gate_up, _ = self.gate_up_proj(x) gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up) x = self.act_fn(gate_up)
x, _ = self.down_proj( x, _ = self.down_proj(
x, skip_all_reduce=can_fuse_mlp_allreduce or use_reduce_scatter x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
) )
return x return x
...@@ -448,7 +448,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -448,7 +448,7 @@ class DeepseekV2MoE(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
forward_batch: Optional[ForwardBatch] = None, forward_batch: Optional[ForwardBatch] = None,
can_fuse_mlp_allreduce: bool = False, should_allreduce_fusion: bool = False,
use_reduce_scatter: bool = False, use_reduce_scatter: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
if not self._enable_deepep_moe: if not self._enable_deepep_moe:
...@@ -459,11 +459,11 @@ class DeepseekV2MoE(nn.Module): ...@@ -459,11 +459,11 @@ class DeepseekV2MoE(nn.Module):
and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
): ):
return self.forward_normal_dual_stream( return self.forward_normal_dual_stream(
hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter hidden_states, should_allreduce_fusion, use_reduce_scatter
) )
else: else:
return self.forward_normal( return self.forward_normal(
hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter hidden_states, should_allreduce_fusion, use_reduce_scatter
) )
else: else:
return self.forward_deepep(hidden_states, forward_batch) return self.forward_deepep(hidden_states, forward_batch)
...@@ -471,7 +471,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -471,7 +471,7 @@ class DeepseekV2MoE(nn.Module):
def forward_normal_dual_stream( def forward_normal_dual_stream(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
can_fuse_mlp_allreduce: bool = False, should_allreduce_fusion: bool = False,
use_reduce_scatter: bool = False, use_reduce_scatter: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -500,20 +500,20 @@ class DeepseekV2MoE(nn.Module): ...@@ -500,20 +500,20 @@ class DeepseekV2MoE(nn.Module):
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out) torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
final_hidden_states = final_hidden_states_out final_hidden_states = final_hidden_states_out
sm.tag(final_hidden_states) sm.tag(final_hidden_states)
if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter: if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states return final_hidden_states
def forward_normal( def forward_normal(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
can_fuse_mlp_allreduce: bool = False, should_allreduce_fusion: bool = False,
use_reduce_scatter: bool = False, use_reduce_scatter: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
if hasattr(self, "shared_experts") and use_intel_amx_backend( if hasattr(self, "shared_experts") and use_intel_amx_backend(
self.shared_experts.gate_up_proj self.shared_experts.gate_up_proj
): ):
return self.forward_cpu(hidden_states, can_fuse_mlp_allreduce) return self.forward_cpu(hidden_states, should_allreduce_fusion)
shared_output = self._forward_shared_experts(hidden_states) shared_output = self._forward_shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
...@@ -537,12 +537,14 @@ class DeepseekV2MoE(nn.Module): ...@@ -537,12 +537,14 @@ class DeepseekV2MoE(nn.Module):
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out) torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
final_hidden_states = final_hidden_states_out final_hidden_states = final_hidden_states_out
sm.tag(final_hidden_states) sm.tag(final_hidden_states)
if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter: if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states return final_hidden_states
def forward_cpu( def forward_cpu(
self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False self,
hidden_states: torch.Tensor,
should_allreduce_fusion: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states) router_logits = self.gate(hidden_states)
...@@ -593,7 +595,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -593,7 +595,7 @@ class DeepseekV2MoE(nn.Module):
None, # a2_scale None, # a2_scale
True, # is_vnni True, # is_vnni
) )
if self.tp_size > 1 and not can_fuse_mlp_allreduce: if self.tp_size > 1 and not should_allreduce_fusion:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states return final_hidden_states
...@@ -1842,6 +1844,8 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1842,6 +1844,8 @@ class DeepseekV2DecoderLayer(nn.Module):
allow_reduce_scatter=True, allow_reduce_scatter=True,
) )
self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool: def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
return is_nextn or ( return is_nextn or (
self.config.n_routed_experts is not None self.config.n_routed_experts is not None
...@@ -1850,27 +1854,18 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1850,27 +1854,18 @@ class DeepseekV2DecoderLayer(nn.Module):
) )
def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool: def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
"""Check if MLP allreduce can be fused with next layer's add_rmsnorm""" """Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
if (
self.layer_id == self.config.num_hidden_layers - 1
or get_tensor_model_parallel_world_size() <= 1
):
return False
if not global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False):
return False
if not _is_sm100_supported or not _is_flashinfer_available: batch_size = (
return False forward_batch.input_ids.shape[0]
if hasattr(forward_batch, "input_ids")
else 0
)
if hasattr(forward_batch, "input_ids") and ( if batch_size > 128:
forward_batch.input_ids.shape[0] == 0
or forward_batch.input_ids.shape[0] > 128
):
return False return False
return True return self._fuse_allreduce_lookup_table.get(batch_size, False)
def forward( def forward(
self, self,
...@@ -1896,7 +1891,7 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1896,7 +1891,7 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states, residual, forward_batch hidden_states, residual, forward_batch
) )
can_fuse_mlp_allreduce = ( should_allreduce_fusion = (
self._should_fuse_mlp_allreduce_with_next_layer(forward_batch) self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
and not (self.enable_dp_attention and self.speculative_algorithm.is_eagle()) and not (self.enable_dp_attention and self.speculative_algorithm.is_eagle())
and not self.is_nextn and not self.is_nextn
...@@ -1907,13 +1902,13 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1907,13 +1902,13 @@ class DeepseekV2DecoderLayer(nn.Module):
forward_batch forward_batch
) )
hidden_states = self.mlp( hidden_states = self.mlp(
hidden_states, forward_batch, can_fuse_mlp_allreduce, use_reduce_scatter hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
) )
if can_fuse_mlp_allreduce: if should_allreduce_fusion:
hidden_states._sglang_needs_allreduce_fusion = True hidden_states._sglang_needs_allreduce_fusion = True
if not can_fuse_mlp_allreduce: if not should_allreduce_fusion:
hidden_states, residual = self.layer_communicator.postprocess_layer( hidden_states, residual = self.layer_communicator.postprocess_layer(
hidden_states, residual, forward_batch hidden_states, residual, forward_batch
) )
...@@ -1990,6 +1985,26 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1990,6 +1985,26 @@ class DeepseekV2DecoderLayer(nn.Module):
) )
return output return output
def _build_fuse_allreduce_lookup_table(self):
static_conditions_met = (
self.layer_id != self.config.num_hidden_layers - 1
and get_tensor_model_parallel_world_size() > 1
and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
and _is_sm100_supported
and _is_flashinfer_available
)
if not static_conditions_met:
return {}
lookup_table = {}
for batch_size in range(129): # 0 to 128
is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
lookup_table[batch_size] = should_fuse
return lookup_table
class DeepseekV2Model(nn.Module): class DeepseekV2Model(nn.Module):
fall_back_to_pt_during_load = False fall_back_to_pt_during_load = False
......
...@@ -154,13 +154,13 @@ class Glm4MoeMLP(nn.Module): ...@@ -154,13 +154,13 @@ class Glm4MoeMLP(nn.Module):
) )
self.act_fn = SiluAndMul() self.act_fn = SiluAndMul()
def forward(self, x, forward_batch=None, can_fuse_mlp_allreduce=False): def forward(self, x, forward_batch=None, should_allreduce_fusion=False):
if (self.tp_size == 1) and x.shape[0] == 0: if (self.tp_size == 1) and x.shape[0] == 0:
return x return x
gate_up, _ = self.gate_up_proj(x) gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up) x = self.act_fn(gate_up)
x, _ = self.down_proj(x, skip_all_reduce=can_fuse_mlp_allreduce) x, _ = self.down_proj(x, skip_all_reduce=should_allreduce_fusion)
return x return x
...@@ -529,7 +529,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): ...@@ -529,7 +529,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
def forward_normal_dual_stream( def forward_normal_dual_stream(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
can_fuse_mlp_allreduce: bool = False, should_allreduce_fusion: bool = False,
use_reduce_scatter: bool = False, use_reduce_scatter: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -553,7 +553,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): ...@@ -553,7 +553,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
if self.ep_size > 1: if self.ep_size > 1:
if ( if (
self.tp_size > 1 self.tp_size > 1
and not can_fuse_mlp_allreduce and not should_allreduce_fusion
and not use_reduce_scatter and not use_reduce_scatter
): ):
final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states = tensor_model_parallel_all_reduce(
...@@ -564,7 +564,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): ...@@ -564,7 +564,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
final_hidden_states += shared_output final_hidden_states += shared_output
if ( if (
self.tp_size > 1 self.tp_size > 1
and not can_fuse_mlp_allreduce and not should_allreduce_fusion
and not use_reduce_scatter and not use_reduce_scatter
): ):
final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states = tensor_model_parallel_all_reduce(
...@@ -575,13 +575,13 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): ...@@ -575,13 +575,13 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
def forward_normal( def forward_normal(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
can_fuse_mlp_allreduce: bool = False, should_allreduce_fusion: bool = False,
use_reduce_scatter: bool = False, use_reduce_scatter: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
if hasattr(self, "shared_experts") and use_intel_amx_backend( if hasattr(self, "shared_experts") and use_intel_amx_backend(
self.shared_experts.gate_up_proj self.shared_experts.gate_up_proj
): ):
return self.forward_cpu(hidden_states, can_fuse_mlp_allreduce) return self.forward_cpu(hidden_states, should_allreduce_fusion)
shared_output = self._forward_shared_experts(hidden_states) shared_output = self._forward_shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
...@@ -596,7 +596,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): ...@@ -596,7 +596,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
# fused in biased_grouped_topk so we can skip here # fused in biased_grouped_topk so we can skip here
final_hidden_states *= self.routed_scaling_factor final_hidden_states *= self.routed_scaling_factor
if self.ep_size > 1: if self.ep_size > 1:
if self.tp_size > 1 and not can_fuse_mlp_allreduce: if self.tp_size > 1 and not should_allreduce_fusion:
final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states final_hidden_states
) )
...@@ -605,7 +605,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): ...@@ -605,7 +605,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
else: else:
if shared_output is not None: if shared_output is not None:
final_hidden_states += shared_output final_hidden_states += shared_output
if self.tp_size > 1 and not can_fuse_mlp_allreduce: if self.tp_size > 1 and not should_allreduce_fusion:
final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states final_hidden_states
) )
......
...@@ -56,7 +56,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig ...@@ -56,7 +56,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4 from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id from sglang.srt.layers.utils import PPMissingLayer, get_layer_id, is_sm100_supported
from sglang.srt.layers.vocab_parallel_embedding import ( from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
...@@ -64,7 +64,10 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ...@@ -64,7 +64,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, make_layers from sglang.srt.utils import add_prefix, is_cuda, is_flashinfer_available, make_layers
_is_flashinfer_available = is_flashinfer_available()
_is_sm100_supported = is_cuda() and is_sm100_supported()
class GptOssConfig(PretrainedConfig): class GptOssConfig(PretrainedConfig):
...@@ -151,10 +154,13 @@ class GptOssSparseMoeBlock(nn.Module): ...@@ -151,10 +154,13 @@ class GptOssSparseMoeBlock(nn.Module):
) )
def forward( def forward(
self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None self,
hidden_states: torch.Tensor,
forward_batch: Optional[ForwardBatch] = None,
should_allreduce_fusion: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
if not global_server_args_dict["moe_a2a_backend"].is_deepep(): if not global_server_args_dict["moe_a2a_backend"].is_deepep():
return self.forward_normal(hidden_states) return self.forward_normal(hidden_states, should_allreduce_fusion)
else: else:
raise Exception("forward_deepep branch not implemented yet") raise Exception("forward_deepep branch not implemented yet")
...@@ -165,7 +171,11 @@ class GptOssSparseMoeBlock(nn.Module): ...@@ -165,7 +171,11 @@ class GptOssSparseMoeBlock(nn.Module):
if name not in ["correction_bias"] if name not in ["correction_bias"]
] ]
def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward_normal(
self,
hidden_states: torch.Tensor,
should_allreduce_fusion: bool = False,
) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
...@@ -179,7 +189,7 @@ class GptOssSparseMoeBlock(nn.Module): ...@@ -179,7 +189,7 @@ class GptOssSparseMoeBlock(nn.Module):
kwargs["topk_output"] = (self.top_k, router_logits) kwargs["topk_output"] = (self.top_k, router_logits)
final_hidden_states = self.experts(**kwargs) final_hidden_states = self.experts(**kwargs)
if self.tp_size > 1: if self.tp_size > 1 and not should_allreduce_fusion:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
ans = final_hidden_states.view(num_tokens, hidden_dim) ans = final_hidden_states.view(num_tokens, hidden_dim)
...@@ -370,6 +380,7 @@ class GptOssDecoderLayer(nn.Module): ...@@ -370,6 +380,7 @@ class GptOssDecoderLayer(nn.Module):
# GptOss all layers are sparse and have no nextn now # GptOss all layers are sparse and have no nextn now
self.is_layer_sparse = True self.is_layer_sparse = True
self.is_nextn = False
is_previous_layer_sparse = True is_previous_layer_sparse = True
self.layer_scatter_modes = LayerScatterModes.init_new( self.layer_scatter_modes = LayerScatterModes.init_new(
...@@ -402,6 +413,42 @@ class GptOssDecoderLayer(nn.Module): ...@@ -402,6 +413,42 @@ class GptOssDecoderLayer(nn.Module):
post_attention_layernorm=self.post_attention_layernorm, post_attention_layernorm=self.post_attention_layernorm,
) )
self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
"""Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
batch_size = (
forward_batch.input_ids.shape[0]
if hasattr(forward_batch, "input_ids")
else 0
)
if batch_size > 128:
return False
return self._fuse_allreduce_lookup_table.get(batch_size, False)
def _build_fuse_allreduce_lookup_table(self):
static_conditions_met = (
self.layer_id != self.config.num_hidden_layers - 1
and get_tensor_model_parallel_world_size() > 1
and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
and _is_sm100_supported
and _is_flashinfer_available
)
if not static_conditions_met:
return {}
lookup_table = {}
for batch_size in range(129): # 0 to 128
is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
lookup_table[batch_size] = should_fuse
return lookup_table
def forward( def forward(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
...@@ -424,8 +471,17 @@ class GptOssDecoderLayer(nn.Module): ...@@ -424,8 +471,17 @@ class GptOssDecoderLayer(nn.Module):
hidden_states, residual, forward_batch hidden_states, residual, forward_batch
) )
hidden_states = self.mlp(hidden_states, forward_batch) should_allreduce_fusion = (
self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
and not self.is_nextn
)
hidden_states = self.mlp(hidden_states, forward_batch, should_allreduce_fusion)
if should_allreduce_fusion:
hidden_states._sglang_needs_allreduce_fusion = True
if not should_allreduce_fusion:
hidden_states, residual = self.layer_communicator.postprocess_layer( hidden_states, residual = self.layer_communicator.postprocess_layer(
hidden_states, residual, forward_batch hidden_states, residual, forward_batch
) )
......
...@@ -1435,7 +1435,7 @@ class ServerArgs: ...@@ -1435,7 +1435,7 @@ class ServerArgs:
parser.add_argument( parser.add_argument(
"--enable-flashinfer-allreduce-fusion", "--enable-flashinfer-allreduce-fusion",
action="store_true", action="store_true",
help="Enable FlashInfer allreduce fusion for Add_RMSNorm.", help="Enable FlashInfer allreduce fusion with Residual RMSNorm.",
) )
parser.add_argument( parser.add_argument(
"--deepep-mode", "--deepep-mode",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment