sglang commit 8c298031 (unverified)
Authored Jul 04, 2025 by Yi Zhang, committed by GitHub Jul 03, 2025

refactor llama4 dp attention logic (#7729)
Parent: 4de03953
Showing 1 changed file with 32 additions and 45 deletions
python/sglang/srt/models/llama4.py  +32  -45
@@ -27,9 +27,8 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
-    dp_gather_partial,
-    dp_scatter,
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
@@ -367,7 +366,10 @@ class Llama4DecoderLayer(nn.Module):
             bias_o_proj=False,
             prefix=add_prefix("self_attn", prefix),
         )
-        is_moe_layer = (layer_id + 1) % config.interleave_moe_layer_step == 0
+        self.config = config
+        is_moe_layer = self._is_moe_layer(layer_id)
+        is_previous_moe_layer = self._is_moe_layer(layer_id - 1)
+
         if is_moe_layer:
             self.feed_forward = Llama4MoE(
                 config=config,
@@ -387,6 +389,22 @@ class Llama4DecoderLayer(nn.Module):
             config.hidden_size, eps=config.rms_norm_eps
         )
+
+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=config.num_hidden_layers,
+            is_layer_sparse=is_moe_layer,
+            is_previous_layer_sparse=is_previous_moe_layer,
+        )
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.post_attention_layernorm,
+        )
+
+    def _is_moe_layer(self, layer_id: int) -> bool:
+        return (layer_id + 1) % self.config.interleave_moe_layer_step == 0
 
     def forward(
         self,
         positions: torch.Tensor,
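Aside (not part of the commit): the helper added above keeps the original MoE placement rule, (layer_id + 1) % interleave_moe_layer_step == 0, and the constructor now evaluates it for both the current and the previous layer because the scatter/gather mode chosen at a layer boundary depends on the sparsity of the layers on either side of it. A minimal sketch of that arithmetic, assuming an illustrative interleave_moe_layer_step of 2 (the real value comes from the Llama 4 config, not from this diff):

# Illustrative only: mirrors _is_moe_layer with an assumed step of 2.
interleave_moe_layer_step = 2

def is_moe_layer(layer_id: int) -> bool:
    return (layer_id + 1) % interleave_moe_layer_step == 0

for layer_id in range(4):
    print(layer_id, is_moe_layer(layer_id), is_moe_layer(layer_id - 1))
# 0 False True   <- the previous-layer check for layer 0 evaluates is_moe_layer(-1)
# 1 True False
# 2 False True
# 3 True False

With a step of 2, every second layer is sparse, so the (is_layer_sparse, is_previous_layer_sparse) pair handed to LayerScatterModes.init_new alternates from layer to layer.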
@@ -394,57 +412,26 @@ class Llama4DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if hidden_states.shape[0] == 0:
-            residual = hidden_states
-        else:
-            # Self Attention
-            if residual is None:
-                residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
-            else:
-                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states, residual = self.layer_communicator.prepare_attn(
+            hidden_states, residual, forward_batch
+        )
+
+        if hidden_states.shape[0] != 0:
             hidden_states = self.self_attn(
                 positions=positions,
                 hidden_states=hidden_states,
                 forward_batch=forward_batch,
             )
 
-        # Gather
-        if get_tensor_model_parallel_world_size() > 1:
-            # all gather and all reduce
-            if self.local_dp_size != 1:
-                if self.attn_tp_rank == 0:
-                    hidden_states += residual
-                hidden_states, local_hidden_states = (
-                    forward_batch.gathered_buffer,
-                    hidden_states,
-                )
-                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
-                dp_scatter(residual, hidden_states, forward_batch)
-                hidden_states = self.post_attention_layernorm(hidden_states)
-            else:
-                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-                hidden_states, residual = self.post_attention_layernorm(
-                    hidden_states, residual
-                )
-        else:
-            hidden_states, residual = self.post_attention_layernorm(
-                hidden_states, residual
-            )
+        hidden_states, residual = self.layer_communicator.prepare_mlp(
+            hidden_states, residual, forward_batch
+        )
 
         # Fully Connected
         hidden_states = self.feed_forward(hidden_states, forward_batch)
 
-        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
-        # Scatter
-        if self.local_dp_size != 1:
-            # important: forward batch.gathered_buffer is used both after scatter and after gather.
-            # be careful about this!
-            hidden_states, global_hidden_states = (
-                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
-                hidden_states,
-            )
-            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            hidden_states, residual, forward_batch
+        )
 
         return hidden_states, residual
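For readers unfamiliar with the LayerCommunicator pattern, here is a minimal, self-contained sketch of the control flow the refactored forward() now follows. It is not sglang code: StubLayerCommunicator and decoder_layer_forward are hypothetical names, nn.LayerNorm stands in for the model's RMSNorm, and plain nn.Linear modules stand in for attention and the (dense or MoE) feed-forward. The real LayerCommunicator in sglang.srt.layers.communicator also performs the DP gather/scatter, TP all-reduce, and gathered_buffer handling that the deleted branches above did inline; this stub keeps only the single-rank norm-plus-residual behaviour those branches reduce to when no parallelism is active.

from typing import Optional, Tuple

import torch
from torch import nn


class StubLayerCommunicator:
    """Single-rank stand-in for the three hook points used by the new forward()."""

    def __init__(self, input_layernorm: nn.Module, post_attention_layernorm: nn.Module):
        self.input_layernorm = input_layernorm
        self.post_attention_layernorm = post_attention_layernorm

    def prepare_attn(self, hidden_states, residual, forward_batch=None):
        # Mirrors the deleted input_layernorm / residual bookkeeping.
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            residual = hidden_states + residual
            hidden_states = self.input_layernorm(residual)
        return hidden_states, residual

    def prepare_mlp(self, hidden_states, residual, forward_batch=None):
        # Stands in for the gather / all-reduce + post_attention_layernorm branch.
        residual = hidden_states + residual
        hidden_states = self.post_attention_layernorm(residual)
        return hidden_states, residual

    def postprocess_layer(self, hidden_states, residual, forward_batch=None):
        # Stands in for the final dp_scatter in the multi-DP-rank case.
        return hidden_states, residual


def decoder_layer_forward(
    attn: nn.Module,
    mlp: nn.Module,
    comm: StubLayerCommunicator,
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Same three-phase structure as the refactored Llama4DecoderLayer.forward().
    hidden_states, residual = comm.prepare_attn(hidden_states, residual)
    if hidden_states.shape[0] != 0:  # idle DP ranks can hold an empty local batch
        hidden_states = attn(hidden_states)
    hidden_states, residual = comm.prepare_mlp(hidden_states, residual)
    hidden_states = mlp(hidden_states)
    return comm.postprocess_layer(hidden_states, residual)


if __name__ == "__main__":
    hidden = 16
    comm = StubLayerCommunicator(nn.LayerNorm(hidden), nn.LayerNorm(hidden))
    attn = nn.Linear(hidden, hidden)  # placeholder for self-attention
    mlp = nn.Linear(hidden, hidden)   # placeholder for the dense or MoE feed-forward
    x = torch.randn(4, hidden)
    out, res = decoder_layer_forward(attn, mlp, comm, x, residual=None)
    print(out.shape, res.shape)  # torch.Size([4, 16]) torch.Size([4, 16])

The point of the refactor is visible in decoder_layer_forward: the layer body shrinks to three hook calls around attention and the feed-forward, with the only remaining special case being the guard for an empty local batch on idle DP ranks.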