Unverified Commit c0e84297 authored by Trevor Morris, committed by GitHub

Use reduce scatter for DP (#8539)

parent 92cc32d9
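For context on the change below: a reduce-scatter gives each rank the sum over ranks of that rank's shard of the input, so it can replace the previous all-reduce followed by dp_scatter while moving less data. A minimal single-process sketch of that equivalence (sizes and names here are illustrative, not taken from the diff):

```python
# Single-process sketch of the collective semantics this commit relies on:
# summing per-rank partial outputs and then slicing out one rank's shard
# (all-reduce + scatter) equals reduce-scattering the partials directly.
import torch

world_size, tokens_per_rank, hidden = 4, 3, 8
# Each rank holds partial results for ALL tokens, e.g. after a row-parallel matmul.
partials = [torch.randn(world_size * tokens_per_rank, hidden) for _ in range(world_size)]

# Old path: all-reduce (sum over ranks), then scatter out rank 0's slice.
reduced = torch.stack(partials).sum(dim=0)
old_shard = reduced[:tokens_per_rank]

# New path: reduce-scatter yields the summed slice directly, without
# materializing the full reduced tensor on every rank.
new_shard = sum(p[:tokens_per_rank] for p in partials)

assert torch.allclose(old_shard, new_shard)
```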
......@@ -27,6 +27,7 @@ from sglang.srt.layers.dp_attention import (
attn_tp_all_gather_into_tensor,
attn_tp_reduce_scatter_tensor,
dp_gather_partial,
dp_reduce_scatter_tensor,
dp_scatter,
get_attention_dp_size,
get_attention_tp_rank,
......@@ -149,10 +150,13 @@ class LayerCommunicator:
layer_scatter_modes: LayerScatterModes,
input_layernorm: torch.nn.Module,
post_attention_layernorm: torch.nn.Module,
# Reduce-scatter requires the model code to skip its all-reduce after MoE/MLP, so only enable this for models that implement that. Remove this flag once all models that use LayerCommunicator do.
allow_reduce_scatter: bool = False,
):
self.layer_scatter_modes = layer_scatter_modes
self.input_layernorm = input_layernorm
self.post_attention_layernorm = post_attention_layernorm
self.allow_reduce_scatter = allow_reduce_scatter
self._context = CommunicateContext.init_new()
self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
......@@ -239,6 +243,15 @@ class LayerCommunicator:
residual=residual,
forward_batch=forward_batch,
context=self._context,
allow_reduce_scatter=self.allow_reduce_scatter,
)
def should_use_reduce_scatter(self, forward_batch: ForwardBatch):
return (
self.allow_reduce_scatter
and self._communicate_summable_tensor_pair_fn
is CommunicateSummableTensorPairFn._scatter_hidden_states
and forward_batch.dp_padding_mode.is_max_len()
)
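The comment in __init__ above is the key invariant: if a layer keeps its tensor-parallel all-reduce and the communicator also reduce-scatters, the activations are summed twice. A single-process sketch of that failure mode, with illustrative sizes (none of these names come from the diff):

```python
# Why the model code must skip its all-reduce when allow_reduce_scatter is on:
# reduce-scattering an already all-reduced buffer scales it by the group size.
import torch

tp_size, tokens, hidden = 2, 4, 8
partials = [torch.randn(tokens, hidden) for _ in range(tp_size)]
reduced = torch.stack(partials).sum(dim=0)

# Correct: keep partial sums in the layer; the reduce-scatter sums them once.
correct_shard = reduced.tensor_split(tp_size)[0]

# Incorrect: all-reduce in the layer first, then reduce-scatter again. Every
# rank now contributes the fully reduced buffer, so the shard is scaled by tp_size.
double_counted_shard = sum(reduced for _ in range(tp_size)).tensor_split(tp_size)[0]

assert torch.allclose(double_counted_shard, correct_shard * tp_size)
```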
......@@ -524,6 +537,7 @@ class CommunicateSummableTensorPairFn:
residual: torch.Tensor,
forward_batch: ForwardBatch,
context: CommunicateContext,
**kwargs,
):
return hidden_states, residual
......@@ -533,15 +547,17 @@ class CommunicateSummableTensorPairFn:
residual: torch.Tensor,
forward_batch: ForwardBatch,
context: CommunicateContext,
allow_reduce_scatter: bool = False,
):
# TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
# Important: forward_batch.gathered_buffer is used both after scatter and after gather.
# Be careful about this!
hidden_states, global_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
dp_scatter(hidden_states, global_hidden_states, forward_batch)
if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
# When max-len padding is used, the all-reduce after MLP/MoE is skipped and a reduce-scatter is done here instead.
dp_reduce_scatter_tensor(hidden_states, global_hidden_states)
else:
dp_scatter(hidden_states, global_hidden_states, forward_batch)
return hidden_states, residual
@staticmethod
......@@ -550,6 +566,7 @@ class CommunicateSummableTensorPairFn:
residual: torch.Tensor,
forward_batch: ForwardBatch,
context: CommunicateContext,
**kwargs,
):
hidden_states += residual
residual = None
......
......@@ -12,6 +12,7 @@ import triton.language as tl
from sglang.srt.distributed import (
GroupCoordinator,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_group,
tensor_model_parallel_all_reduce,
......@@ -355,6 +356,17 @@ def dp_scatter(
)
def dp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
if get_tensor_model_parallel_world_size() == get_attention_dp_size():
get_tp_group().reduce_scatter_tensor(output, input)
else:
scattered_local_tokens = input.tensor_split(
get_tensor_model_parallel_world_size()
)[get_tensor_model_parallel_rank()]
get_tp_group().reduce_scatter_tensor(scattered_local_tokens, input)
get_attention_tp_group().all_gather_into_tensor(output, scattered_local_tokens)
def attn_tp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
return get_attention_tp_group().reduce_scatter_tensor(output, input)
......
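A single-process sketch of the shape arithmetic behind the second branch of dp_reduce_scatter_tensor (the case where the full TP world size is larger than the attention-DP size, i.e. attention TP size > 1). The sizes below are illustrative; the point is that a reduce-scatter over the full TP group followed by an all-gather within the attention-TP group reassembles exactly one DP rank's padded token slice.

```python
# Shape sketch: reduce-scatter over the full TP group, then all-gather within
# the attention-TP subgroup, equals slicing one DP rank's rows out of the
# fully reduced buffer. Sizes are illustrative.
import torch

attn_dp_size, attn_tp_size, hidden = 2, 2, 4
tp_size = attn_dp_size * attn_tp_size     # full tensor-parallel world size
max_tokens = 6                            # per-DP-rank padded token count

# Every TP rank holds partial sums over the whole gathered buffer.
partials = [torch.randn(max_tokens * attn_dp_size, hidden) for _ in range(tp_size)]
reduced = torch.stack(partials).sum(dim=0)          # what an all-reduce would give

# Step 1: reduce-scatter over the full TP group leaves each rank with
# (max_tokens * attn_dp_size) // tp_size rows of the reduced buffer.
chunks = reduced.tensor_split(tp_size)
assert chunks[0].shape[0] == max_tokens * attn_dp_size // tp_size

# Step 2: all-gather within one attention-TP group (attn_tp_size consecutive
# TP ranks) concatenates its chunks back into that DP rank's max_tokens rows.
dp_rank0_rows = torch.cat(chunks[:attn_tp_size])
assert torch.equal(dp_rank0_rows, reduced[:max_tokens])
```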
......@@ -1277,7 +1277,7 @@ class RowParallelLinear(LinearBase):
# It does not support additional parameters.
param.load_row_parallel_weight(loaded_weight)
def forward(self, input_, can_fuse_mlp_allreduce=False):
def forward(self, input_, skip_all_reduce=False):
if self.input_is_parallel:
input_parallel = input_
else:
......@@ -1294,7 +1294,7 @@ class RowParallelLinear(LinearBase):
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
sm.tag(output_parallel)
if self.reduce_results and self.tp_size > 1 and not can_fuse_mlp_allreduce:
if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
......
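The rename from can_fuse_mlp_allreduce to skip_all_reduce widens the contract: a caller that passes skip_all_reduce=True receives per-rank partial sums and owns the reduction (in this commit, the later dp_reduce_scatter_tensor). A single-process sketch of that contract with stand-in shapes (not real SGLang calls):

```python
# With skip_all_reduce=True, a row-parallel linear returns partial sums; the
# caller must still reduce them (this commit does so via reduce-scatter).
import torch

tp_size, tokens, in_features, out_features = 2, 4, 8, 6
x_shards = [torch.randn(tokens, in_features // tp_size, dtype=torch.float64)
            for _ in range(tp_size)]
w_shards = [torch.randn(in_features // tp_size, out_features, dtype=torch.float64)
            for _ in range(tp_size)]

# What each rank returns when the internal all-reduce is skipped.
partials = [x_shards[r] @ w_shards[r] for r in range(tp_size)]

# The caller's reduction recovers the unsharded matmul result.
reduced = sum(partials)
full = torch.cat(x_shards, dim=-1) @ torch.cat(w_shards, dim=0)
assert torch.allclose(reduced, full)
```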
......@@ -628,8 +628,10 @@ class ForwardBatch:
self.dp_padding_mode = dp_padding_mode
if dp_padding_mode.is_max_len():
# when DP gather mode is all gather, we will use all_gather_into_tensor to gather hidden states,
# where transferred tokens should be padded to the same length.
# when DP gather mode is all gather, we will use
# all_gather_into_tensor to gather hidden states, where transferred
# tokens should be padded to the same length. We will also use
# reduce-scatter instead of all-reduce after MLP.
max_num_tokens = max(global_num_tokens)
global_num_tokens = [max_num_tokens] * sync_group_size
buffer_len = max_num_tokens * sync_group_size
......
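The padding comment above is what makes the new path legal: both all_gather_into_tensor and reduce-scatter require every rank's shard to be the same size. A small sketch of the arithmetic with made-up token counts:

```python
# Max-len padding arithmetic (values illustrative): every DP rank is padded to
# the largest per-rank token count, so the gathered buffer splits into equal
# shards, which is what reduce-scatter (and all_gather_into_tensor) require.
global_num_tokens = [5, 9, 7, 9]                          # raw token counts per DP rank
sync_group_size = len(global_num_tokens)
max_num_tokens = max(global_num_tokens)                   # 9
global_num_tokens = [max_num_tokens] * sync_group_size    # [9, 9, 9, 9]
buffer_len = max_num_tokens * sync_group_size             # 36 rows in gathered_buffer
assert buffer_len % sync_group_size == 0                  # equal shards per rank
```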
......@@ -208,13 +208,21 @@ class DeepseekV2MLP(nn.Module):
)
self.act_fn = SiluAndMul()
def forward(self, x, forward_batch=None, can_fuse_mlp_allreduce=False):
def forward(
self,
x,
forward_batch=None,
can_fuse_mlp_allreduce: bool = False,
use_reduce_scatter: bool = False,
):
if (self.tp_size == 1) and x.shape[0] == 0:
return x
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x, can_fuse_mlp_allreduce=can_fuse_mlp_allreduce)
x, _ = self.down_proj(
x, skip_all_reduce=can_fuse_mlp_allreduce or use_reduce_scatter
)
return x
......@@ -441,6 +449,7 @@ class DeepseekV2MoE(nn.Module):
hidden_states: torch.Tensor,
forward_batch: Optional[ForwardBatch] = None,
can_fuse_mlp_allreduce: bool = False,
use_reduce_scatter: bool = False,
) -> torch.Tensor:
if not self._enable_deepep_moe:
DUAL_STREAM_TOKEN_THRESHOLD = 1024
......@@ -450,15 +459,20 @@ class DeepseekV2MoE(nn.Module):
and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
):
return self.forward_normal_dual_stream(
hidden_states, can_fuse_mlp_allreduce
hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter
)
else:
return self.forward_normal(hidden_states, can_fuse_mlp_allreduce)
return self.forward_normal(
hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter
)
else:
return self.forward_deepep(hidden_states, forward_batch)
def forward_normal_dual_stream(
self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
self,
hidden_states: torch.Tensor,
can_fuse_mlp_allreduce: bool = False,
use_reduce_scatter: bool = False,
) -> torch.Tensor:
current_stream = torch.cuda.current_stream()
......@@ -486,12 +500,15 @@ class DeepseekV2MoE(nn.Module):
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
final_hidden_states = final_hidden_states_out
sm.tag(final_hidden_states)
if self.tp_size > 1 and not can_fuse_mlp_allreduce:
if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states
def forward_normal(
self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
self,
hidden_states: torch.Tensor,
can_fuse_mlp_allreduce: bool = False,
use_reduce_scatter: bool = False,
) -> torch.Tensor:
if hasattr(self, "shared_experts") and use_intel_amx_backend(
self.shared_experts.gate_up_proj
......@@ -520,7 +537,7 @@ class DeepseekV2MoE(nn.Module):
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
final_hidden_states = final_hidden_states_out
sm.tag(final_hidden_states)
if self.tp_size > 1 and not can_fuse_mlp_allreduce:
if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states
......@@ -1822,6 +1839,7 @@ class DeepseekV2DecoderLayer(nn.Module):
layer_scatter_modes=self.layer_scatter_modes,
input_layernorm=self.input_layernorm,
post_attention_layernorm=self.post_attention_layernorm,
allow_reduce_scatter=True,
)
def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
......@@ -1884,7 +1902,13 @@ class DeepseekV2DecoderLayer(nn.Module):
and not self.is_nextn
)
hidden_states = self.mlp(hidden_states, forward_batch, can_fuse_mlp_allreduce)
# For DP with max-len padding, reduce-scatter can be used instead of all-reduce.
use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
forward_batch
)
hidden_states = self.mlp(
hidden_states, forward_batch, can_fuse_mlp_allreduce, use_reduce_scatter
)
if can_fuse_mlp_allreduce:
hidden_states._sglang_needs_allreduce_fusion = True
......
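Putting the DeepSeek changes together, the flag plumbing is: LayerCommunicator(allow_reduce_scatter=True) -> should_use_reduce_scatter(forward_batch) -> DeepseekV2MoE/MLP(use_reduce_scatter=...) -> RowParallelLinear(skip_all_reduce=True), with the communicator's dp_reduce_scatter_tensor performing the one remaining reduction. A hedged stub sketch of the per-batch decision (all names below are stand-ins summarizing the diff, not real SGLang APIs):

```python
# Stub sketch of which collectives run per layer depending on the padding mode.
def pick_collectives(allow_reduce_scatter: bool, padding_is_max_len: bool):
    use_reduce_scatter = allow_reduce_scatter and padding_is_max_len
    layer_epilogue = "keep partial sums" if use_reduce_scatter else "all_reduce"
    communicator = "dp_reduce_scatter_tensor" if use_reduce_scatter else "dp_scatter"
    return layer_epilogue, communicator

# Max-len padding: the MoE/MLP skips its all-reduce and the communicator
# reduce-scatters; otherwise the original all-reduce + dp_scatter path is kept.
assert pick_collectives(True, True) == ("keep partial sums", "dp_reduce_scatter_tensor")
assert pick_collectives(True, False) == ("all_reduce", "dp_scatter")
```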
......@@ -160,7 +160,7 @@ class Glm4MoeMLP(nn.Module):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x, can_fuse_mlp_allreduce=can_fuse_mlp_allreduce)
x, _ = self.down_proj(x, skip_all_reduce=can_fuse_mlp_allreduce)
return x
......