Unverified Commit c0e84297 authored by Trevor Morris, committed by GitHub

Use reduce scatter for DP (#8539)

parent 92cc32d9
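
The idea behind the change: when data-parallel attention pads every DP rank to the same token count, the all-reduce that normally follows the MLP/MoE plus the per-rank scatter that comes after it can be collapsed into a single reduce-scatter, which moves roughly 1/world_size of the data. The sketch below is illustrative only, using plain torch.distributed and hypothetical function names rather than SGLang's wrappers:

# Illustrative sketch, not part of this commit: both functions yield the same
# per-rank slice, but the reduce-scatter path avoids materializing the fully
# reduced buffer on every rank.
import torch
import torch.distributed as dist

def scatter_after_all_reduce(global_hidden: torch.Tensor) -> torch.Tensor:
    # Baseline: sum the padded global buffer on every rank, then keep this
    # rank's slice of it.
    dist.all_reduce(global_hidden)
    return global_hidden.tensor_split(dist.get_world_size())[dist.get_rank()]

def fused_reduce_scatter(global_hidden: torch.Tensor) -> torch.Tensor:
    # Optimized: each rank directly receives its already-summed slice. Requires
    # the row count to be divisible by the world size, which max-length DP
    # padding guarantees.
    local = torch.empty(
        global_hidden.shape[0] // dist.get_world_size(),
        *global_hidden.shape[1:],
        dtype=global_hidden.dtype,
        device=global_hidden.device,
    )
    dist.reduce_scatter_tensor(local, global_hidden)
    return local
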
@@ -27,6 +27,7 @@ from sglang.srt.layers.dp_attention import (
    attn_tp_all_gather_into_tensor,
    attn_tp_reduce_scatter_tensor,
    dp_gather_partial,
    dp_reduce_scatter_tensor,
    dp_scatter,
    get_attention_dp_size,
    get_attention_tp_rank,
@@ -149,10 +150,13 @@ class LayerCommunicator:
        layer_scatter_modes: LayerScatterModes,
        input_layernorm: torch.nn.Module,
        post_attention_layernorm: torch.nn.Module,
        # Reduce scatter requires skipping all-reduce in model code after MoE/MLP, so only enable for models which have that implemented. Remove flag once done for all models that use LayerCommunicator.
        allow_reduce_scatter: bool = False,
    ):
        self.layer_scatter_modes = layer_scatter_modes
        self.input_layernorm = input_layernorm
        self.post_attention_layernorm = post_attention_layernorm
        self.allow_reduce_scatter = allow_reduce_scatter

        self._context = CommunicateContext.init_new()
        self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
@@ -239,6 +243,15 @@ class LayerCommunicator:
            residual=residual,
            forward_batch=forward_batch,
            context=self._context,
            allow_reduce_scatter=self.allow_reduce_scatter,
        )

    def should_use_reduce_scatter(self, forward_batch: ForwardBatch):
        return (
            self.allow_reduce_scatter
            and self._communicate_summable_tensor_pair_fn
            is CommunicateSummableTensorPairFn._scatter_hidden_states
            and forward_batch.dp_padding_mode.is_max_len()
        )
@@ -524,6 +537,7 @@ class CommunicateSummableTensorPairFn:
        residual: torch.Tensor,
        forward_batch: ForwardBatch,
        context: CommunicateContext,
        **kwargs,
    ):
        return hidden_states, residual
@@ -533,15 +547,17 @@ class CommunicateSummableTensorPairFn:
        residual: torch.Tensor,
        forward_batch: ForwardBatch,
        context: CommunicateContext,
        allow_reduce_scatter: bool = False,
    ):
        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
        # important: forward_batch.gathered_buffer is used both after scatter and after gather.
        # be careful about this!
        hidden_states, global_hidden_states = (
            forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
            hidden_states,
        )
        if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
            # When using padding, all_reduce is skipped after MLP and MoE and reduce-scatter is used here instead.
            dp_reduce_scatter_tensor(hidden_states, global_hidden_states)
        else:
            dp_scatter(hidden_states, global_hidden_states, forward_batch)
        return hidden_states, residual
    @staticmethod
@@ -550,6 +566,7 @@ class CommunicateSummableTensorPairFn:
        residual: torch.Tensor,
        forward_batch: ForwardBatch,
        context: CommunicateContext,
        **kwargs,
    ):
        hidden_states += residual
        residual = None
...
@@ -12,6 +12,7 @@ import triton.language as tl
from sglang.srt.distributed import (
    GroupCoordinator,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    get_tp_group,
    tensor_model_parallel_all_reduce,
@@ -355,6 +356,17 @@ def dp_scatter(
    )


def dp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
    if get_tensor_model_parallel_world_size() == get_attention_dp_size():
        get_tp_group().reduce_scatter_tensor(output, input)
    else:
        scattered_local_tokens = input.tensor_split(
            get_tensor_model_parallel_world_size()
        )[get_tensor_model_parallel_rank()]
        get_tp_group().reduce_scatter_tensor(scattered_local_tokens, input)
        get_attention_tp_group().all_gather_into_tensor(output, scattered_local_tokens)


def attn_tp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
    return get_attention_tp_group().reduce_scatter_tensor(output, input)
...
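
For reference, a rough sketch of what the fallback branch of dp_reduce_scatter_tensor above does when the full TP group is larger than the attention DP size: reduce-scatter across the whole TP group first, then all-gather the resulting slice within the attention-TP sub-group so every rank of a DP replica ends up holding that replica's full output. The sketch uses plain torch.distributed and assumed group handles rather than SGLang's GroupCoordinator API.

# Sketch only; assumed shapes: input is [tokens_per_dp_rank * dp_size, hidden]
# and output is [tokens_per_dp_rank, hidden], with tp_size = dp_size * attn_tp_size.
import torch
import torch.distributed as dist

def dp_reduce_scatter_sketch(
    output: torch.Tensor,
    input: torch.Tensor,
    tp_group: dist.ProcessGroup,       # full tensor-parallel group (assumed handle)
    attn_tp_group: dist.ProcessGroup,  # attention-TP sub-group (assumed handle)
) -> None:
    tp_size = dist.get_world_size(tp_group)
    tp_rank = dist.get_rank(tp_group)
    # Step 1: every TP rank receives a summed 1/tp_size slice of the buffer.
    local_slice = input.tensor_split(tp_size)[tp_rank].contiguous()
    dist.reduce_scatter_tensor(local_slice, input, group=tp_group)
    # Step 2: ranks belonging to the same DP replica reassemble their slices,
    # so each of them holds that replica's tokens_per_dp_rank rows.
    dist.all_gather_into_tensor(output, local_slice, group=attn_tp_group)
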
@@ -1277,7 +1277,7 @@ class RowParallelLinear(LinearBase):
            # It does not support additional parameters.
            param.load_row_parallel_weight(loaded_weight)

    def forward(self, input_, can_fuse_mlp_allreduce=False):
    def forward(self, input_, skip_all_reduce=False):
        if self.input_is_parallel:
            input_parallel = input_
        else:
@@ -1294,7 +1294,7 @@ class RowParallelLinear(LinearBase):
        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
            output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
            sm.tag(output_parallel)
        if self.reduce_results and self.tp_size > 1 and not can_fuse_mlp_allreduce:
        if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel
...
@@ -628,8 +628,10 @@ class ForwardBatch:
        self.dp_padding_mode = dp_padding_mode

        if dp_padding_mode.is_max_len():
            # when DP gather mode is all gather, we will use
            # all_gather_into_tensor to gather hidden states, where transferred
            # tokens should be padded to the same length. We will also use
            # reduce-scatter instead of all-reduce after MLP.
            max_num_tokens = max(global_num_tokens)
            global_num_tokens = [max_num_tokens] * sync_group_size
            buffer_len = max_num_tokens * sync_group_size
...
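
The padded-buffer arithmetic above is what makes the reduce-scatter legal: reduce_scatter_tensor needs the gathered buffer to split evenly across ranks, so every DP rank is padded to the maximum token count. A small worked example with made-up numbers:

# Hypothetical numbers, for illustration only.
global_num_tokens = [3, 7, 5, 6]               # tokens per DP rank before padding
sync_group_size = len(global_num_tokens)       # 4 DP ranks
max_num_tokens = max(global_num_tokens)        # 7
buffer_len = max_num_tokens * sync_group_size  # 28 rows: evenly divisible by 4,
                                               # so each rank's slice is 7 rows
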
@@ -208,13 +208,21 @@ class DeepseekV2MLP(nn.Module):
        )
        self.act_fn = SiluAndMul()

    def forward(
        self,
        x,
        forward_batch=None,
        can_fuse_mlp_allreduce: bool = False,
        use_reduce_scatter: bool = False,
    ):
        if (self.tp_size == 1) and x.shape[0] == 0:
            return x

        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(
            x, skip_all_reduce=can_fuse_mlp_allreduce or use_reduce_scatter
        )
        return x
@@ -441,6 +449,7 @@ class DeepseekV2MoE(nn.Module):
        hidden_states: torch.Tensor,
        forward_batch: Optional[ForwardBatch] = None,
        can_fuse_mlp_allreduce: bool = False,
        use_reduce_scatter: bool = False,
    ) -> torch.Tensor:
        if not self._enable_deepep_moe:
            DUAL_STREAM_TOKEN_THRESHOLD = 1024
@@ -450,15 +459,20 @@ class DeepseekV2MoE(nn.Module):
                and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
            ):
                return self.forward_normal_dual_stream(
                    hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter
                )
            else:
                return self.forward_normal(
                    hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter
                )
        else:
            return self.forward_deepep(hidden_states, forward_batch)

    def forward_normal_dual_stream(
        self,
        hidden_states: torch.Tensor,
        can_fuse_mlp_allreduce: bool = False,
        use_reduce_scatter: bool = False,
    ) -> torch.Tensor:
        current_stream = torch.cuda.current_stream()
@@ -486,12 +500,15 @@ class DeepseekV2MoE(nn.Module):
            torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
            final_hidden_states = final_hidden_states_out
            sm.tag(final_hidden_states)
        if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states

    def forward_normal(
        self,
        hidden_states: torch.Tensor,
        can_fuse_mlp_allreduce: bool = False,
        use_reduce_scatter: bool = False,
    ) -> torch.Tensor:
        if hasattr(self, "shared_experts") and use_intel_amx_backend(
            self.shared_experts.gate_up_proj
@@ -520,7 +537,7 @@ class DeepseekV2MoE(nn.Module):
            torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
            final_hidden_states = final_hidden_states_out
            sm.tag(final_hidden_states)
        if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states
@@ -1822,6 +1839,7 @@ class DeepseekV2DecoderLayer(nn.Module):
            layer_scatter_modes=self.layer_scatter_modes,
            input_layernorm=self.input_layernorm,
            post_attention_layernorm=self.post_attention_layernorm,
            allow_reduce_scatter=True,
        )

    def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
@@ -1884,7 +1902,13 @@ class DeepseekV2DecoderLayer(nn.Module):
            and not self.is_nextn
        )

        # For DP with padding, reduce scatter can be used instead of all-reduce.
        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
            forward_batch
        )
        hidden_states = self.mlp(
            hidden_states, forward_batch, can_fuse_mlp_allreduce, use_reduce_scatter
        )

        if can_fuse_mlp_allreduce:
            hidden_states._sglang_needs_allreduce_fusion = True
...
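
The DeepseekV2 changes above rely on a simple identity: summing the per-TP-rank partial outputs and then taking a DP rank's slice gives the same result as slicing first and summing the slices, which is exactly what a reduce-scatter computes. That is why down_proj can skip its all-reduce whenever the LayerCommunicator will reduce-scatter afterwards. A toy, single-process check of that identity (made-up sizes, no distributed setup, not SGLang code):

# Simulate tp_size partial MoE/MLP outputs; each covers all DP ranks' padded tokens.
import torch

tp_size, dp_size, tokens_per_dp, hidden = 2, 2, 4, 8
partials = [torch.randn(tokens_per_dp * dp_size, hidden) for _ in range(tp_size)]

# Old path: all-reduce (sum over TP ranks), then scatter each DP rank's slice.
reduced = torch.stack(partials).sum(dim=0)
old_slices = reduced.tensor_split(dp_size)

# New path: reduce-scatter directly produces each DP rank's summed slice.
new_slices = [
    torch.stack([p.tensor_split(dp_size)[r] for p in partials]).sum(dim=0)
    for r in range(dp_size)
]

assert all(torch.allclose(a, b) for a, b in zip(old_slices, new_slices))
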
@@ -160,7 +160,7 @@ class Glm4MoeMLP(nn.Module):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x, can_fuse_mlp_allreduce=can_fuse_mlp_allreduce)
        x, _ = self.down_proj(x, skip_all_reduce=can_fuse_mlp_allreduce)
        return x
...