sglang · Commits · 5039d547

Commit 5039d547 (unverified), authored Apr 09, 2025 by fzyzcjy, committed via GitHub on Apr 08, 2025
Parent: d09a51f1

Support 2x8xH100 for Llama 4 (#5159)

Showing 1 changed file with 77 additions and 19 deletions:

python/sglang/srt/models/llama4.py (+77, -19)
```diff
@@ -27,6 +27,13 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.dp_attention import (
+    dp_gather_partial,
+    dp_scatter,
+    get_attention_dp_size,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+)
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
```
```diff
@@ -38,6 +45,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
 from sglang.srt.utils import add_prefix, get_compiler_backend, make_layers
```
```diff
@@ -143,20 +151,24 @@ class Llama4Attention(nn.Module):
         self.hidden_size = hidden_size
         self.use_rope = int((layer_id + 1) % 4 != 0)
         self.use_qk_norm = config.use_qk_norm and self.use_rope
         tp_size = get_tensor_model_parallel_world_size()
+        self.dp_size = get_attention_dp_size()
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+
         self.total_num_heads = num_heads
-        assert self.total_num_heads % tp_size == 0
-        self.num_heads = self.total_num_heads // tp_size
+        assert self.total_num_heads % attn_tp_size == 0
+        self.num_heads = self.total_num_heads // attn_tp_size
         self.total_num_kv_heads = num_kv_heads
-        if self.total_num_kv_heads >= tp_size:
+        if self.total_num_kv_heads >= attn_tp_size:
             # Number of KV heads is greater than TP size, so we partition
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads % tp_size == 0
+            assert self.total_num_kv_heads % attn_tp_size == 0
         else:
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
-            assert tp_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+            assert attn_tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
         self.head_dim = config.head_dim
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
```
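The core of this hunk is that query and KV heads are now partitioned by the attention-TP group size rather than the global TP world size, so a head count only needs to divide the per-group size, not the full world size, once DP attention splits the ranks. A minimal standalone sketch of that arithmetic (not sglang code; the head counts are illustrative assumptions, not read from any Llama 4 checkpoint):

```python
# Standalone sketch of the partitioning rule in the hunk above.
def partition_heads(total_num_heads: int, total_num_kv_heads: int, attn_tp_size: int):
    assert total_num_heads % attn_tp_size == 0
    num_heads = total_num_heads // attn_tp_size
    if total_num_kv_heads >= attn_tp_size:
        # Enough KV heads: shard them across the attention-TP ranks.
        assert total_num_kv_heads % attn_tp_size == 0
    else:
        # Fewer KV heads than ranks: replicate them instead.
        assert attn_tp_size % total_num_kv_heads == 0
    num_kv_heads = max(1, total_num_kv_heads // attn_tp_size)
    return num_heads, num_kv_heads

# 40 query heads / 8 KV heads with an attention-TP group of 8
# (e.g. dp=2 over 16 GPUs, since attn_tp_size = tp_size // dp_size):
print(partition_heads(40, 8, attn_tp_size=8))   # (5, 1)
# The same model split only 4 ways keeps 2 KV heads per rank:
print(partition_heads(40, 8, attn_tp_size=4))   # (10, 2)
```

With DP attention, the attention-TP group is the TP world divided by the DP degree, which is why the assertions above are checked against the smaller group size.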
```diff
@@ -183,6 +195,8 @@ class Llama4Attention(nn.Module):
             bias=bias,
             quant_config=quant_config,
             prefix=add_prefix("qkv_proj", prefix),
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
         )
 
         self.o_proj = RowParallelLinear(
```
```diff
@@ -191,6 +205,9 @@ class Llama4Attention(nn.Module):
             bias=bias_o_proj,
             quant_config=quant_config,
             prefix=add_prefix("o_proj", prefix),
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+            reduce_results=False,
         )
         is_neox_style = True
         is_gguf = quant_config and quant_config.get_name() == "gguf"
```
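Passing the attention-TP rank and size into both projections keeps their sharding consistent with the new head math, and `reduce_results=False` on `o_proj` defers the usual tensor-parallel all-reduce so the decoder layer can fold it into the residual add and DP gather shown in a later hunk. A toy single-process illustration (an assumption-level sketch, not sglang's `RowParallelLinear`) of why the un-reduced output is only a partial sum:

```python
# Each "rank" multiplies its slice of the input by its slice of the weight;
# the true output is the sum over ranks, which the caller must reduce itself
# when reduce_results=False.
import torch

torch.manual_seed(0)
tp_size, tokens, hidden = 2, 3, 8
x = torch.randn(tokens, hidden)
w = torch.randn(hidden, hidden)  # stand-in for the full o_proj weight

x_shards = x.chunk(tp_size, dim=1)   # each rank sees part of the input features
w_shards = w.chunk(tp_size, dim=0)   # row-parallel split of the weight
partial_outputs = [xs @ ws for xs, ws in zip(x_shards, w_shards)]

# The deferred reduction recovers the full matmul.
reduced = sum(partial_outputs)
assert torch.allclose(reduced, x @ w, atol=1e-5)
```

Deferring the reduction is what lets the layer either all-reduce the partial sums or gather them across DP ranks, instead of always paying the all-reduce inside the linear layer.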
```diff
@@ -274,6 +291,9 @@ class Llama4DecoderLayer(nn.Module):
         rope_theta = config.rope_theta
         rope_scaling = config.rope_scaling
         max_position_embeddings = config.max_position_embeddings
+        self.dp_size = get_attention_dp_size()
+        self.attn_tp_size = get_attention_tp_size()
+        self.attn_tp_rank = get_attention_tp_rank()
 
         self.self_attn = Llama4Attention(
             config=config,
```
```diff
@@ -316,21 +336,58 @@ class Llama4DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # Self Attention
-        if residual is None:
+        if hidden_states.shape[0] == 0:
             residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
         else:
-            hidden_states, residual = self.input_layernorm(hidden_states, residual)
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-            forward_batch=forward_batch,
-        )
+            # Self Attention
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+            )
 
+        # Gather
+        if get_tensor_model_parallel_world_size() > 1:
+            # all gather and all reduce
+            if self.dp_size != 1:
+                if self.attn_tp_rank == 0:
+                    hidden_states += residual
+                hidden_states, local_hidden_states = (
+                    forward_batch.gathered_buffer,
+                    hidden_states,
+                )
+                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
+                dp_scatter(residual, hidden_states, forward_batch)
+                hidden_states = self.post_attention_layernorm(hidden_states)
+            else:
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+                hidden_states, residual = self.post_attention_layernorm(
+                    hidden_states, residual
+                )
+        else:
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
+            )
+
         # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
         hidden_states = self.feed_forward(hidden_states)
+
+        # TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
+        # Scatter
+        if self.dp_size != 1:
+            # important: forward batch.gathered_buffer is used both after scatter and after gather.
+            # be careful about this!
+            hidden_states, global_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
 
         return hidden_states, residual
```
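The rewritten forward first short-circuits empty local batches (under DP attention a rank can receive zero tokens in a step), then, when DP attention is enabled, gathers every rank's attention output into `forward_batch.gathered_buffer` so the TP-sharded MLP sees the full batch, and finally scatters each rank's rows back out after the MLP. A rough single-process sketch of that bookkeeping (the real `dp_gather_partial`/`dp_scatter` are collective ops; the semantics below are my assumption, reduced to concatenation and slicing):

```python
import torch

hidden = 4
tokens_per_dp_rank = [3, 2]                       # per-rank token counts (illustrative)
local_chunks = [torch.randn(n, hidden) for n in tokens_per_dp_rank]

# "dp_gather": every rank's tokens land in one shared buffer.
gathered_buffer = torch.cat(local_chunks, dim=0)  # shape [5, hidden]

# The MLP (stand-in for self.feed_forward) runs on the full gathered batch.
mlp_out = gathered_buffer * 2.0

# "dp_scatter": each rank slices its own rows back out of the global result.
offsets = [0]
for n in tokens_per_dp_rank:
    offsets.append(offsets[-1] + n)
local_again = [mlp_out[offsets[r] : offsets[r + 1]] for r in range(len(tokens_per_dp_rank))]

assert torch.allclose(local_again[0], local_chunks[0] * 2.0)
assert torch.allclose(local_again[1], local_chunks[1] * 2.0)
```

Note the source's own warning that `gathered_buffer` is reused both after the gather and after the scatter, which is why the scatter path slices the buffer to the local `input_ids` length before writing into it.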
```diff
@@ -350,6 +407,7 @@ class Llama4Model(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("embed_tokens", prefix),
+            enable_tp=not global_server_args_dict["enable_dp_attention"],
         )
         self.layers = make_layers(
             config.num_hidden_layers,
```
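When DP attention is on, the token embedding is built with `enable_tp=False`, presumably keeping the vocab table replicated so each attention-DP rank can embed its own batch locally instead of relying on a vocab-sharded lookup plus a world-wide reduction; without DP attention it stays vocab-parallel as before. A purely illustrative toy contrasting the two modes (the shard/replicate mechanics here are assumptions, not `VocabParallelEmbedding`'s actual implementation):

```python
import torch

vocab, hidden, tp_size = 8, 4, 2
full_table = torch.randn(vocab, hidden)
token_ids = torch.tensor([1, 5])

# enable_tp=True: each rank owns vocab // tp_size rows; out-of-shard ids hit zeros,
# so the per-rank partial lookups must be summed across ranks.
shard = vocab // tp_size
partials = []
for rank in range(tp_size):
    table = torch.zeros_like(full_table)
    table[rank * shard : (rank + 1) * shard] = full_table[rank * shard : (rank + 1) * shard]
    partials.append(table[token_ids])
assert torch.allclose(sum(partials), full_table[token_ids])

# enable_tp=False (the DP-attention path): every rank holds the full table,
# so the lookup needs no communication at all.
local_lookup = full_table[token_ids]
assert torch.allclose(local_lookup, sum(partials))
```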
```diff
@@ -385,7 +443,8 @@ class Llama4Model(nn.Module):
                 forward_batch,
                 residual,
             )
-        hidden_states, _ = self.norm(hidden_states, residual)
+        if not forward_batch.forward_mode.is_idle():
+            hidden_states, _ = self.norm(hidden_states, residual)
 
         if len(aux_hidden_states) == 0:
             return hidden_states
```
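The final norm is now skipped when the batch is in idle mode, which, as I read it, is the DP-attention case where this rank holds no tokens for the current step, so normalizing a zero-row tensor would be wasted work. A tiny sketch of that guard, with a hand-rolled RMS norm standing in for the model's `self.norm`:

```python
# Minimal sketch; rms_norm is a stand-in for RMSNorm, and is_idle models
# forward_batch.forward_mode.is_idle() as "this rank has zero local tokens".
import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

def final_norm(hidden_states: torch.Tensor, is_idle: bool) -> torch.Tensor:
    if not is_idle:
        hidden_states = rms_norm(hidden_states)
    return hidden_states

print(final_norm(torch.randn(0, 8), is_idle=True).shape)   # torch.Size([0, 8])
print(final_norm(torch.randn(3, 8), is_idle=False).shape)  # torch.Size([3, 8])
```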
```diff
@@ -394,7 +453,6 @@ class Llama4Model(nn.Module):
-
 
 
 class Llama4ForCausalLM(LlamaForCausalLM):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],
```