Unverified Commit 44e86480 authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

fuse allreduce and residual_rmsnorm (#8731)

parent 8c07fabd
......@@ -441,7 +441,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
and _is_flashinfer_available
and hasattr(layernorm, "forward_with_allreduce_fusion")
and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
and hidden_states.shape[0] <= 128
and hidden_states.shape[0] <= 2048
):
hidden_states, residual = layernorm.forward_with_allreduce_fusion(
hidden_states, residual
......
......@@ -92,7 +92,7 @@ _workspace_manager = FlashInferWorkspaceManager()
def ensure_workspace_initialized(
max_token_num: int = 128, hidden_dim: int = 4096, use_fp32_lamport: bool = False
max_token_num: int = 2048, hidden_dim: int = 4096, use_fp32_lamport: bool = False
):
"""Ensure workspace is initialized"""
if not is_flashinfer_available() or _flashinfer_comm is None:
......@@ -124,8 +124,8 @@ def flashinfer_allreduce_residual_rmsnorm(
residual: torch.Tensor,
weight: torch.Tensor,
eps: float = 1e-6,
max_token_num: int = 128,
use_oneshot: bool = True,
max_token_num: int = 2048,
use_oneshot: Optional[bool] = None,
trigger_completion_at_end: bool = False,
fp32_acc: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
......
......@@ -1294,6 +1294,7 @@ class RowParallelLinear(LinearBase):
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
sm.tag(output_parallel)
if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
......
......@@ -847,10 +847,14 @@ class FusedMoE(torch.nn.Module):
)
sm.tag(final_hidden_states)
final_hidden_states = final_hidden_states[
..., :origin_hidden_states_dim
].contiguous()
if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states[..., :origin_hidden_states_dim].contiguous()
return final_hidden_states
@classmethod
def make_expert_params_mapping(
......
......@@ -212,7 +212,7 @@ class DeepseekV2MLP(nn.Module):
self,
x,
forward_batch=None,
can_fuse_mlp_allreduce: bool = False,
should_allreduce_fusion: bool = False,
use_reduce_scatter: bool = False,
):
if (self.tp_size == 1) and x.shape[0] == 0:
......@@ -221,7 +221,7 @@ class DeepseekV2MLP(nn.Module):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(
x, skip_all_reduce=can_fuse_mlp_allreduce or use_reduce_scatter
x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
)
return x
......@@ -448,7 +448,7 @@ class DeepseekV2MoE(nn.Module):
self,
hidden_states: torch.Tensor,
forward_batch: Optional[ForwardBatch] = None,
can_fuse_mlp_allreduce: bool = False,
should_allreduce_fusion: bool = False,
use_reduce_scatter: bool = False,
) -> torch.Tensor:
if not self._enable_deepep_moe:
......@@ -459,11 +459,11 @@ class DeepseekV2MoE(nn.Module):
and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
):
return self.forward_normal_dual_stream(
hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter
hidden_states, should_allreduce_fusion, use_reduce_scatter
)
else:
return self.forward_normal(
hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter
hidden_states, should_allreduce_fusion, use_reduce_scatter
)
else:
return self.forward_deepep(hidden_states, forward_batch)
......@@ -471,7 +471,7 @@ class DeepseekV2MoE(nn.Module):
def forward_normal_dual_stream(
self,
hidden_states: torch.Tensor,
can_fuse_mlp_allreduce: bool = False,
should_allreduce_fusion: bool = False,
use_reduce_scatter: bool = False,
) -> torch.Tensor:
......@@ -500,20 +500,20 @@ class DeepseekV2MoE(nn.Module):
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
final_hidden_states = final_hidden_states_out
sm.tag(final_hidden_states)
if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter:
if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states
def forward_normal(
self,
hidden_states: torch.Tensor,
can_fuse_mlp_allreduce: bool = False,
should_allreduce_fusion: bool = False,
use_reduce_scatter: bool = False,
) -> torch.Tensor:
if hasattr(self, "shared_experts") and use_intel_amx_backend(
self.shared_experts.gate_up_proj
):
return self.forward_cpu(hidden_states, can_fuse_mlp_allreduce)
return self.forward_cpu(hidden_states, should_allreduce_fusion)
shared_output = self._forward_shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
......@@ -537,12 +537,14 @@ class DeepseekV2MoE(nn.Module):
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
final_hidden_states = final_hidden_states_out
sm.tag(final_hidden_states)
if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter:
if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states
def forward_cpu(
self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
self,
hidden_states: torch.Tensor,
should_allreduce_fusion: bool = False,
) -> torch.Tensor:
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
......@@ -593,7 +595,7 @@ class DeepseekV2MoE(nn.Module):
None, # a2_scale
True, # is_vnni
)
if self.tp_size > 1 and not can_fuse_mlp_allreduce:
if self.tp_size > 1 and not should_allreduce_fusion:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states
......@@ -1842,6 +1844,8 @@ class DeepseekV2DecoderLayer(nn.Module):
allow_reduce_scatter=True,
)
self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
return is_nextn or (
self.config.n_routed_experts is not None
......@@ -1850,27 +1854,18 @@ class DeepseekV2DecoderLayer(nn.Module):
)
def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
"""Check if MLP allreduce can be fused with next layer's add_rmsnorm"""
if (
self.layer_id == self.config.num_hidden_layers - 1
or get_tensor_model_parallel_world_size() <= 1
):
return False
if not global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False):
return False
"""Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
if not _is_sm100_supported or not _is_flashinfer_available:
return False
batch_size = (
forward_batch.input_ids.shape[0]
if hasattr(forward_batch, "input_ids")
else 0
)
if hasattr(forward_batch, "input_ids") and (
forward_batch.input_ids.shape[0] == 0
or forward_batch.input_ids.shape[0] > 128
):
if batch_size > 128:
return False
return True
return self._fuse_allreduce_lookup_table.get(batch_size, False)
def forward(
self,
......@@ -1896,7 +1891,7 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states, residual, forward_batch
)
can_fuse_mlp_allreduce = (
should_allreduce_fusion = (
self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
and not (self.enable_dp_attention and self.speculative_algorithm.is_eagle())
and not self.is_nextn
......@@ -1907,13 +1902,13 @@ class DeepseekV2DecoderLayer(nn.Module):
forward_batch
)
hidden_states = self.mlp(
hidden_states, forward_batch, can_fuse_mlp_allreduce, use_reduce_scatter
hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
)
if can_fuse_mlp_allreduce:
if should_allreduce_fusion:
hidden_states._sglang_needs_allreduce_fusion = True
if not can_fuse_mlp_allreduce:
if not should_allreduce_fusion:
hidden_states, residual = self.layer_communicator.postprocess_layer(
hidden_states, residual, forward_batch
)
......@@ -1990,6 +1985,26 @@ class DeepseekV2DecoderLayer(nn.Module):
)
return output
def _build_fuse_allreduce_lookup_table(self):
static_conditions_met = (
self.layer_id != self.config.num_hidden_layers - 1
and get_tensor_model_parallel_world_size() > 1
and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
and _is_sm100_supported
and _is_flashinfer_available
)
if not static_conditions_met:
return {}
lookup_table = {}
for batch_size in range(129): # 0 to 128
is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
lookup_table[batch_size] = should_fuse
return lookup_table
class DeepseekV2Model(nn.Module):
fall_back_to_pt_during_load = False
......
......@@ -154,13 +154,13 @@ class Glm4MoeMLP(nn.Module):
)
self.act_fn = SiluAndMul()
def forward(self, x, forward_batch=None, can_fuse_mlp_allreduce=False):
def forward(self, x, forward_batch=None, should_allreduce_fusion=False):
if (self.tp_size == 1) and x.shape[0] == 0:
return x
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x, skip_all_reduce=can_fuse_mlp_allreduce)
x, _ = self.down_proj(x, skip_all_reduce=should_allreduce_fusion)
return x
......@@ -529,7 +529,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
def forward_normal_dual_stream(
self,
hidden_states: torch.Tensor,
can_fuse_mlp_allreduce: bool = False,
should_allreduce_fusion: bool = False,
use_reduce_scatter: bool = False,
) -> torch.Tensor:
......@@ -553,7 +553,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
if self.ep_size > 1:
if (
self.tp_size > 1
and not can_fuse_mlp_allreduce
and not should_allreduce_fusion
and not use_reduce_scatter
):
final_hidden_states = tensor_model_parallel_all_reduce(
......@@ -564,7 +564,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
final_hidden_states += shared_output
if (
self.tp_size > 1
and not can_fuse_mlp_allreduce
and not should_allreduce_fusion
and not use_reduce_scatter
):
final_hidden_states = tensor_model_parallel_all_reduce(
......@@ -575,13 +575,13 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
def forward_normal(
self,
hidden_states: torch.Tensor,
can_fuse_mlp_allreduce: bool = False,
should_allreduce_fusion: bool = False,
use_reduce_scatter: bool = False,
) -> torch.Tensor:
if hasattr(self, "shared_experts") and use_intel_amx_backend(
self.shared_experts.gate_up_proj
):
return self.forward_cpu(hidden_states, can_fuse_mlp_allreduce)
return self.forward_cpu(hidden_states, should_allreduce_fusion)
shared_output = self._forward_shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
......@@ -596,7 +596,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
# fused in biased_grouped_topk so we can skip here
final_hidden_states *= self.routed_scaling_factor
if self.ep_size > 1:
if self.tp_size > 1 and not can_fuse_mlp_allreduce:
if self.tp_size > 1 and not should_allreduce_fusion:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states
)
......@@ -605,7 +605,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
else:
if shared_output is not None:
final_hidden_states += shared_output
if self.tp_size > 1 and not can_fuse_mlp_allreduce:
if self.tp_size > 1 and not should_allreduce_fusion:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states
)
......
......@@ -56,7 +56,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id, is_sm100_supported
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
......@@ -64,7 +64,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, make_layers
from sglang.srt.utils import add_prefix, is_cuda, is_flashinfer_available, make_layers
_is_flashinfer_available = is_flashinfer_available()
_is_sm100_supported = is_cuda() and is_sm100_supported()
class GptOssConfig(PretrainedConfig):
......@@ -151,10 +154,13 @@ class GptOssSparseMoeBlock(nn.Module):
)
def forward(
self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
self,
hidden_states: torch.Tensor,
forward_batch: Optional[ForwardBatch] = None,
should_allreduce_fusion: bool = False,
) -> torch.Tensor:
if not global_server_args_dict["moe_a2a_backend"].is_deepep():
return self.forward_normal(hidden_states)
return self.forward_normal(hidden_states, should_allreduce_fusion)
else:
raise Exception("forward_deepep branch not implemented yet")
......@@ -165,7 +171,11 @@ class GptOssSparseMoeBlock(nn.Module):
if name not in ["correction_bias"]
]
def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
def forward_normal(
self,
hidden_states: torch.Tensor,
should_allreduce_fusion: bool = False,
) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
......@@ -179,7 +189,7 @@ class GptOssSparseMoeBlock(nn.Module):
kwargs["topk_output"] = (self.top_k, router_logits)
final_hidden_states = self.experts(**kwargs)
if self.tp_size > 1:
if self.tp_size > 1 and not should_allreduce_fusion:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
ans = final_hidden_states.view(num_tokens, hidden_dim)
......@@ -370,6 +380,7 @@ class GptOssDecoderLayer(nn.Module):
# GptOss all layers are sparse and have no nextn now
self.is_layer_sparse = True
self.is_nextn = False
is_previous_layer_sparse = True
self.layer_scatter_modes = LayerScatterModes.init_new(
......@@ -402,6 +413,42 @@ class GptOssDecoderLayer(nn.Module):
post_attention_layernorm=self.post_attention_layernorm,
)
self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
"""Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
batch_size = (
forward_batch.input_ids.shape[0]
if hasattr(forward_batch, "input_ids")
else 0
)
if batch_size > 128:
return False
return self._fuse_allreduce_lookup_table.get(batch_size, False)
def _build_fuse_allreduce_lookup_table(self):
static_conditions_met = (
self.layer_id != self.config.num_hidden_layers - 1
and get_tensor_model_parallel_world_size() > 1
and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
and _is_sm100_supported
and _is_flashinfer_available
)
if not static_conditions_met:
return {}
lookup_table = {}
for batch_size in range(129): # 0 to 128
is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
lookup_table[batch_size] = should_fuse
return lookup_table
def forward(
self,
positions: torch.Tensor,
......@@ -424,12 +471,21 @@ class GptOssDecoderLayer(nn.Module):
hidden_states, residual, forward_batch
)
hidden_states = self.mlp(hidden_states, forward_batch)
hidden_states, residual = self.layer_communicator.postprocess_layer(
hidden_states, residual, forward_batch
should_allreduce_fusion = (
self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
and not self.is_nextn
)
hidden_states = self.mlp(hidden_states, forward_batch, should_allreduce_fusion)
if should_allreduce_fusion:
hidden_states._sglang_needs_allreduce_fusion = True
if not should_allreduce_fusion:
hidden_states, residual = self.layer_communicator.postprocess_layer(
hidden_states, residual, forward_batch
)
return hidden_states, residual
......
......@@ -1435,7 +1435,7 @@ class ServerArgs:
parser.add_argument(
"--enable-flashinfer-allreduce-fusion",
action="store_true",
help="Enable FlashInfer allreduce fusion for Add_RMSNorm.",
help="Enable FlashInfer allreduce fusion with Residual RMSNorm.",
)
parser.add_argument(
"--deepep-mode",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment