Unverified Commit 8c298031 authored by Yi Zhang, committed by GitHub

refactor llama4 dp attention logic (#7729)

parent 4de03953
......@@ -27,9 +27,8 @@ from sglang.srt.distributed import (
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import (
    dp_gather_partial,
    dp_scatter,
    get_attention_tp_rank,
    get_attention_tp_size,
    get_local_attention_dp_size,
......@@ -367,7 +366,10 @@ class Llama4DecoderLayer(nn.Module):
            bias_o_proj=False,
            prefix=add_prefix("self_attn", prefix),
        )
        is_moe_layer = (layer_id + 1) % config.interleave_moe_layer_step == 0
        self.config = config
        is_moe_layer = self._is_moe_layer(layer_id)
        is_previous_moe_layer = self._is_moe_layer(layer_id - 1)

        if is_moe_layer:
            self.feed_forward = Llama4MoE(
                config=config,
......@@ -387,6 +389,22 @@ class Llama4DecoderLayer(nn.Module):
            config.hidden_size, eps=config.rms_norm_eps
        )

        self.layer_scatter_modes = LayerScatterModes.init_new(
            layer_id=layer_id,
            num_layers=config.num_hidden_layers,
            is_layer_sparse=is_moe_layer,
            is_previous_layer_sparse=is_previous_moe_layer,
        )
        self.layer_communicator = LayerCommunicator(
            layer_scatter_modes=self.layer_scatter_modes,
            input_layernorm=self.input_layernorm,
            post_attention_layernorm=self.post_attention_layernorm,
        )

    def _is_moe_layer(self, layer_id: int) -> bool:
        return (layer_id + 1) % self.config.interleave_moe_layer_step == 0

    def forward(
        self,
        positions: torch.Tensor,
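For context, the interleave rule that _is_moe_layer encodes picks every interleave_moe_layer_step-th layer, counting from 1, to use the MoE feed-forward. A standalone sketch; the step and layer count below are hypothetical values for illustration, not taken from this diff:

# Which layer indices the (layer_id + 1) % interleave_moe_layer_step == 0 rule
# marks as sparse (MoE); hypothetical values, for illustration only.
interleave_moe_layer_step = 2
num_hidden_layers = 8

moe_layer_ids = [
    layer_id
    for layer_id in range(num_hidden_layers)
    if (layer_id + 1) % interleave_moe_layer_step == 0
]
print(moe_layer_ids)  # [1, 3, 5, 7]: every second layer gets the MoE feed_forward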
......@@ -394,57 +412,26 @@ class Llama4DecoderLayer(nn.Module):
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if hidden_states.shape[0] == 0:
            residual = hidden_states
        else:
            # Self Attention
            if residual is None:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
            else:
                hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states, residual = self.layer_communicator.prepare_attn(
            hidden_states, residual, forward_batch
        )

        if hidden_states.shape[0] != 0:
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                forward_batch=forward_batch,
            )

        # Gather
        if get_tensor_model_parallel_world_size() > 1:
            # all gather and all reduce
            if self.local_dp_size != 1:
                if self.attn_tp_rank == 0:
                    hidden_states += residual
                hidden_states, local_hidden_states = (
                    forward_batch.gathered_buffer,
                    hidden_states,
                )
                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
                dp_scatter(residual, hidden_states, forward_batch)
                hidden_states = self.post_attention_layernorm(hidden_states)
            else:
                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
                hidden_states, residual = self.post_attention_layernorm(
                    hidden_states, residual
                )
        else:
            hidden_states, residual = self.post_attention_layernorm(
                hidden_states, residual
            )
        hidden_states, residual = self.layer_communicator.prepare_mlp(
            hidden_states, residual, forward_batch
        )

        # Fully Connected
        hidden_states = self.feed_forward(hidden_states, forward_batch)

        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
        # Scatter
        if self.local_dp_size != 1:
            # important: forward batch.gathered_buffer is used both after scatter and after gather.
            # be careful about this!
            hidden_states, global_hidden_states = (
                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                hidden_states,
            )
            dp_scatter(hidden_states, global_hidden_states, forward_batch)
        hidden_states, residual = self.layer_communicator.postprocess_layer(
            hidden_states, residual, forward_batch
        )

        return hidden_states, residual
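For readers unfamiliar with the pattern, below is a toy single-process sketch of the three-hook structure the new forward() relies on. It uses plain torch with no TP or DP groups; ToyLayerCommunicator and ToyDecoderLayer are hypothetical names and this is not the real sglang LayerCommunicator, whose hooks also take forward_batch and, guided by the configured LayerScatterModes, decide when to all-reduce across TP ranks or gather/scatter across DP ranks, which is the branching the old forward() did inline with dp_gather_partial, dp_scatter, and tensor_model_parallel_all_reduce.

# Toy sketch: hypothetical ToyLayerCommunicator / ToyDecoderLayer, plain torch,
# single process, no TP/DP. Not the sglang implementation.
import torch
import torch.nn as nn


class ToyLayerCommunicator:
    """Owns residual bookkeeping; the real class also owns TP/DP communication."""

    def __init__(self, input_layernorm, post_attention_layernorm):
        self.input_layernorm = input_layernorm
        self.post_attention_layernorm = post_attention_layernorm

    def prepare_attn(self, hidden_states, residual):
        # Pre-attention norm plus residual update (sglang fuses the add + norm).
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            residual = hidden_states + residual
            hidden_states = self.input_layernorm(residual)
        return hidden_states, residual

    def prepare_mlp(self, hidden_states, residual):
        # Single process: just add + post-attention norm. With TP/DP enabled,
        # this is where the all-reduce or DP gather would be issued.
        residual = hidden_states + residual
        hidden_states = self.post_attention_layernorm(residual)
        return hidden_states, residual

    def postprocess_layer(self, hidden_states, residual):
        # With DP attention, this is where activations would be scattered back
        # to the local rank's slice; a no-op in this toy.
        return hidden_states, residual


class ToyDecoderLayer(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.self_attn = nn.Linear(hidden_size, hidden_size)     # stand-in for attention
        self.feed_forward = nn.Linear(hidden_size, hidden_size)  # stand-in for MoE/MLP
        self.input_layernorm = nn.RMSNorm(hidden_size)           # needs torch >= 2.4
        self.post_attention_layernorm = nn.RMSNorm(hidden_size)
        self.layer_communicator = ToyLayerCommunicator(
            input_layernorm=self.input_layernorm,
            post_attention_layernorm=self.post_attention_layernorm,
        )

    def forward(self, hidden_states, residual=None):
        hidden_states, residual = self.layer_communicator.prepare_attn(hidden_states, residual)
        hidden_states = self.self_attn(hidden_states)
        hidden_states, residual = self.layer_communicator.prepare_mlp(hidden_states, residual)
        hidden_states = self.feed_forward(hidden_states)
        return self.layer_communicator.postprocess_layer(hidden_states, residual)


hidden_states, residual = ToyDecoderLayer(64)(torch.randn(4, 64))
print(hidden_states.shape, residual.shape)  # torch.Size([4, 64]) torch.Size([4, 64])

The point of the refactor is that the layer body reads the same whether DP attention is enabled or not; the communication policy lives behind the hooks instead of inline in the model code.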