Unverified commit 8c298031, authored by Yi Zhang, committed by GitHub

refactor llama4 dp attention logic (#7729)

parent 4de03953
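
This change removes the hand-rolled data-parallel (DP) attention gather/scatter logic from `Llama4DecoderLayer` and routes it through the shared `LayerCommunicator` abstraction: the layer declares its scatter modes once at construction time via `LayerScatterModes.init_new`, and the communicator's `prepare_attn` / `prepare_mlp` / `postprocess_layer` hooks take over the all-gather, all-reduce, and scatter handling that `forward()` previously inlined. The direct `dp_gather_partial` / `dp_scatter` imports are dropped accordingly.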
@@ -27,9 +27,8 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
-    dp_gather_partial,
-    dp_scatter,
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
@@ -367,7 +366,10 @@ class Llama4DecoderLayer(nn.Module):
             bias_o_proj=False,
             prefix=add_prefix("self_attn", prefix),
         )
-        is_moe_layer = (layer_id + 1) % config.interleave_moe_layer_step == 0
+        self.config = config
+        is_moe_layer = self._is_moe_layer(layer_id)
+        is_previous_moe_layer = self._is_moe_layer(layer_id - 1)
+
         if is_moe_layer:
             self.feed_forward = Llama4MoE(
                 config=config,
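
The MoE check itself is unchanged; it is just moved behind a helper so the constructor can evaluate it for both the current and the previous layer. As a standalone sketch of the rule (the step value of 2 below is an illustrative assumption, not taken from the Llama4 config):

```python
# Standalone sketch of the interleaving rule behind _is_moe_layer.
# interleave_moe_layer_step = 2 is an example value; the real value comes
# from config.interleave_moe_layer_step.
def is_moe_layer(layer_id: int, interleave_moe_layer_step: int) -> bool:
    return (layer_id + 1) % interleave_moe_layer_step == 0

print([is_moe_layer(i, 2) for i in range(6)])
# [False, True, False, True, False, True] -> odd layer_ids are MoE layers
```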
@@ -387,6 +389,22 @@ class Llama4DecoderLayer(nn.Module):
             config.hidden_size, eps=config.rms_norm_eps
         )
 
+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=config.num_hidden_layers,
+            is_layer_sparse=is_moe_layer,
+            is_previous_layer_sparse=is_previous_moe_layer,
+        )
+
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.post_attention_layernorm,
+        )
+
+    def _is_moe_layer(self, layer_id: int) -> bool:
+        return (layer_id + 1) % self.config.interleave_moe_layer_step == 0
+
     def forward(
         self,
         positions: torch.Tensor,
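
With the scatter modes and the two norms handed to `LayerCommunicator` up front, the layer body no longer needs to branch on TP/DP world sizes. The single-rank toy below is not sglang code: `IdentityCommunicator` is a made-up stand-in whose hooks collapse to the fused add-then-normalize residual bookkeeping visible in the removed branches; the real `LayerCommunicator` additionally performs the DP gather/scatter and TP all-reduce internally.

```python
import torch
from torch import nn

class IdentityCommunicator:
    """Hypothetical single-rank stand-in for sglang's LayerCommunicator."""

    def __init__(self, input_layernorm: nn.Module, post_attention_layernorm: nn.Module):
        self.input_layernorm = input_layernorm
        self.post_attention_layernorm = post_attention_layernorm

    @staticmethod
    def _add_norm(norm, hidden_states, residual):
        # Fused "add residual, then normalize": the residual carried forward
        # is the un-normalized sum, matching the removed layernorm branches.
        if residual is None:
            return norm(hidden_states), hidden_states
        hidden_states = hidden_states + residual
        return norm(hidden_states), hidden_states

    def prepare_attn(self, hidden_states, residual, forward_batch):
        return self._add_norm(self.input_layernorm, hidden_states, residual)

    def prepare_mlp(self, hidden_states, residual, forward_batch):
        return self._add_norm(self.post_attention_layernorm, hidden_states, residual)

    def postprocess_layer(self, hidden_states, residual, forward_batch):
        return hidden_states, residual  # nothing to scatter on one rank


comm = IdentityCommunicator(nn.LayerNorm(16), nn.LayerNorm(16))
h, res = comm.prepare_attn(torch.randn(4, 16), None, forward_batch=None)
# ... self-attention would run here ...
h, res = comm.prepare_mlp(h, res, forward_batch=None)
# ... feed_forward would run here ...
h, res = comm.postprocess_layer(h, res, forward_batch=None)
print(h.shape, res.shape)  # torch.Size([4, 16]) torch.Size([4, 16])
```

The payoff of this design is that the decoder layer's `forward` reads identically whether the layer is dense or sparse, TP-only or DP-attention; all of that variance lives in the scatter modes chosen at construction.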
@@ -394,57 +412,26 @@ class Llama4DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if hidden_states.shape[0] == 0:
-            residual = hidden_states
-        else:
-            # Self Attention
-            if residual is None:
-                residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
-            else:
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states, residual = self.layer_communicator.prepare_attn(
+            hidden_states, residual, forward_batch
+        )
+
+        if hidden_states.shape[0] != 0:
             hidden_states = self.self_attn(
                 positions=positions,
                 hidden_states=hidden_states,
                 forward_batch=forward_batch,
             )
 
-        # Gather
-        if get_tensor_model_parallel_world_size() > 1:
-            # all gather and all reduce
-            if self.local_dp_size != 1:
-                if self.attn_tp_rank == 0:
-                    hidden_states += residual
-                hidden_states, local_hidden_states = (
-                    forward_batch.gathered_buffer,
-                    hidden_states,
-                )
-                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
-                dp_scatter(residual, hidden_states, forward_batch)
-                hidden_states = self.post_attention_layernorm(hidden_states)
-            else:
-                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-                hidden_states, residual = self.post_attention_layernorm(
-                    hidden_states, residual
-                )
-        else:
-            hidden_states, residual = self.post_attention_layernorm(
-                hidden_states, residual
-            )
+        hidden_states, residual = self.layer_communicator.prepare_mlp(
+            hidden_states, residual, forward_batch
+        )
 
         # Fully Connected
         hidden_states = self.feed_forward(hidden_states, forward_batch)
 
-        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
-        # Scatter
-        if self.local_dp_size != 1:
-            # important: forward_batch.gathered_buffer is used both after scatter and after gather.
-            # be careful about this!
-            hidden_states, global_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            hidden_states, residual, forward_batch
+        )
 
         return hidden_states, residual
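
One behavioral guard survives the rewrite: attention is still skipped when this rank's batch is empty. Under DP attention a rank can end up holding zero local tokens (the old code had the same `hidden_states.shape[0] == 0` early-out), so the refactored `forward()` keeps the check around `self_attn`. A minimal plain-PyTorch illustration (the hidden size is an arbitrary example; nothing here is sglang API):

```python
import torch

# Zero local tokens on this DP rank: the empty tensors flow through while
# attention itself is skipped, mirroring the guard in the new forward().
hidden_states = torch.empty(0, 5120)  # 5120 is an illustrative hidden size
ran_attention = False

if hidden_states.shape[0] != 0:
    ran_attention = True  # self_attn(...) would run here

print(ran_attention, hidden_states.shape)
# False torch.Size([0, 5120])
```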