sglang / Commits / ab62b135

Unverified commit ab62b135, authored Sep 05, 2025 by gongwei-130, committed by GitHub on Sep 05, 2025

support Llama4 with non uniformed intermediate size across layers for… (#10047)
parent 273b2834

Showing 7 changed files with 123 additions and 13 deletions (+123, -13)
python/sglang/srt/lora/mem_pool.py       +24  -10
python/sglang/srt/lora/utils.py           +2   -2
python/sglang/srt/models/gemma3n_mm.py    +1   -1
python/sglang/srt/models/llama4.py        +9   -0
python/sglang/srt/models/mllama4.py      +25   -0
test/srt/lora/test_lora_llama4.py        +61   -0
test/srt/run_suite.py                     +1   -0
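Background for the change (an illustrative sketch, not code from this commit): Llama4 interleaves MoE and dense decoder layers, so the feed-forward intermediate size is not the same in every layer. MoE layers use config.intermediate_size while dense layers use config.intermediate_size_mlp, as the get_intermediate_size helper added in llama4.py below makes explicit. The sketch uses invented config values to show the per-layer variation that the LoRA memory pool now has to account for.

# Illustrative sketch only; config values are invented, not Llama4's real ones.
# It mirrors the interleaving rule visible in Llama4DecoderLayer below:
# a layer is MoE when (layer_id + 1) % interleave_moe_layer_step == 0.
from types import SimpleNamespace

cfg = SimpleNamespace(
    num_local_experts=16,          # hypothetical
    interleave_moe_layer_step=2,   # hypothetical: every 2nd layer is MoE
    intermediate_size=8192,        # hypothetical MoE expert FFN width
    intermediate_size_mlp=16384,   # hypothetical dense MLP FFN width
)

def intermediate_size_for_layer(layer_id: int) -> int:
    is_moe = (
        cfg.num_local_experts > 0
        and (layer_id + 1) % cfg.interleave_moe_layer_step == 0
    )
    return cfg.intermediate_size if is_moe else cfg.intermediate_size_mlp

print([intermediate_size_for_layer(i) for i in range(4)])
# [16384, 8192, 16384, 8192] -> the down_proj input dim differs by layer,
# which is why the LoRA memory pool now needs a layer_idx.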
python/sglang/srt/lora/mem_pool.py
@@ -104,12 +104,18 @@ class LoRAMemoryPool:
         return all(_can_support(x) for x in config)
 
     def get_lora_A_shape(
-        self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
+        self,
+        module_name: str,
+        base_model: torch.nn.Module,
+        max_lora_dim: int,
+        layer_idx: int,
     ) -> Tuple[int]:
         """
         Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
-        input_dim, _ = get_hidden_dim(module_name, self.base_hf_config, base_model)
+        input_dim, _ = get_hidden_dim(
+            module_name, self.base_hf_config, base_model, layer_idx
+        )
         c = get_stacked_multiply(module_name)
         if self.tp_size > 1 and module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
             input_dim = divide(input_dim, self.tp_size)
@@ -120,12 +126,18 @@ class LoRAMemoryPool:
         )
 
     def get_lora_B_shape(
-        self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
+        self,
+        module_name: str,
+        base_model: torch.nn.Module,
+        max_lora_dim: int,
+        layer_idx: int,
     ) -> Tuple[int]:
         """
         Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
-        _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
+        _, output_dim = get_hidden_dim(
+            module_name, self.base_hf_config, base_model, layer_idx
+        )
         if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
             output_dim = divide(output_dim, self.tp_size)
         return (
@@ -140,19 +152,21 @@ class LoRAMemoryPool:
         def init_buffer(
             buffer: Dict[str, List[torch.Tensor]],
             target_modules: Set[str],
-            get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
+            get_lora_shape_fn: Callable[[str, torch.nn.Module, int, int], Tuple[int]],
         ):
             for module_name in target_modules:
-                lora_shape = get_lora_shape_fn(
-                    module_name, base_model, self.max_lora_rank
-                )
                 buffer[module_name] = [
                     torch.empty(
-                        lora_shape,
+                        get_lora_shape_fn(
+                            module_name,
+                            base_model,
+                            self.max_lora_rank,
+                            idx,
+                        ),
                         dtype=self.dtype,
                         device=device,
                     )
-                    for _ in range(self.num_layer)
+                    for idx in range(self.num_layer)
                 ]
 
         init_buffer(
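Taken together, the hunks above make the shape functions layer-aware and allocate one tensor per layer instead of assuming a uniform shape. A minimal standalone sketch of the layer-dependent part (not the sglang API; sizes are invented, and the real buffers carry extra dimensions such as the number of concurrently loaded adapters):

# Standalone sketch with made-up sizes; mirrors the per-layer allocation above.
import torch

max_lora_rank = 64
hidden_size = 5120                                      # hypothetical
per_layer_intermediate = [16384, 8192, 16384, 8192]     # hypothetical, varies by layer

def lora_a_shape_down_proj(layer_idx: int):
    # LoRA A for down_proj maps from the layer's intermediate size down to the rank.
    return (max_lora_rank, per_layer_intermediate[layer_idx])

down_proj_a_buffers = [
    torch.empty(lora_a_shape_down_proj(idx), dtype=torch.bfloat16)
    for idx in range(len(per_layer_intermediate))
]
print([tuple(t.shape) for t in down_proj_a_buffers])
# [(64, 16384), (64, 8192), (64, 16384), (64, 8192)]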
python/sglang/srt/lora/utils.py
@@ -48,14 +48,14 @@ def get_layer_id(name: str) -> int:
 
 def get_hidden_dim(
-    module_name: str, config: AutoConfig, base_model: torch.nn.Module
+    module_name: str, config: AutoConfig, base_model: torch.nn.Module, layer_idx: int
 ) -> Tuple[int]:
     """
     Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
     """
     if hasattr(base_model, "get_hidden_dim"):
-        return base_model.get_hidden_dim(module_name)
+        return base_model.get_hidden_dim(module_name, layer_idx)
     else:
         """
         WARNING: get_hidden_dim() is not defined,
python/sglang/srt/models/gemma3n_mm.py
@@ -499,7 +499,7 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
     def should_apply_lora(self, module_name: str) -> bool:
         return bool(self.lora_pattern.match(module_name))
 
-    def get_hidden_dim(self, module_name):
+    def get_hidden_dim(self, module_name, layer_idx):
         # return input_dim, output_dim
         if module_name == "qkv_proj":
             return (
python/sglang/srt/models/llama4.py
@@ -423,6 +423,12 @@ class Llama4DecoderLayer(nn.Module):
             return self.config.num_local_experts > 0
         return (layer_id + 1) % self.config.interleave_moe_layer_step == 0
 
+    def get_intermediate_size(self) -> int:
+        if isinstance(self.feed_forward, Llama4MoE):
+            return self.config.intermediate_size
+        else:
+            return self.config.intermediate_size_mlp
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -540,6 +546,9 @@ class Llama4ForCausalLM(LlamaForCausalLM):
     def get_input_embeddings(self):
         return self.model.embed_tokens
 
+    def get_layers(self):
+        return self.model.layers
+
     def _init_model(
         self,
         config: Llama4TextConfig,
python/sglang/srt/models/mllama4.py
@@ -961,5 +961,30 @@ class Llama4ForConditionalGeneration(nn.Module):
     def set_embed(self, embed):
         return self.language_model.set_embed(embed)
 
+    def get_hidden_dim(self, module_name, layer_idx):
+        # return input_dim, output_dim
+        if module_name == "qkv_proj":
+            return (
+                self.config.hidden_size,
+                self.config.head_dim
+                * (
+                    self.config.num_attention_heads
+                    + self.config.num_key_value_heads * 2
+                ),
+            )
+        elif module_name == "o_proj":
+            return (
+                self.config.head_dim * self.config.num_attention_heads,
+                self.config.hidden_size,
+            )
+        elif module_name == "gate_up_proj":
+            return self.config.hidden_size, self.config.intermediate_size * 2
+        elif module_name == "down_proj":
+            decoder_layer = self.language_model.get_layers()[layer_idx]
+            intermediate_size = decoder_layer.get_intermediate_size()
+            return intermediate_size, self.config.hidden_size
+        else:
+            raise NotImplementedError()
+
 
 EntryClass = Llama4ForConditionalGeneration
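To see the call pattern this enables, here is a hedged stub sketch (none of the classes or values below are sglang code): lora/utils.get_hidden_dim forwards module_name plus layer_idx to the model, and for "down_proj" the model defers to the selected decoder layer's get_intermediate_size(), so two layers can report different input dims.

# Stub sketch only; classes and sizes are invented to show the call pattern.
class _StubLayer:
    def __init__(self, intermediate_size):
        self._intermediate_size = intermediate_size

    def get_intermediate_size(self):
        return self._intermediate_size


class _StubModel:
    hidden_size = 5120  # hypothetical

    def __init__(self, layers):
        self._layers = layers

    def get_layers(self):
        return self._layers

    def get_hidden_dim(self, module_name, layer_idx):
        if module_name == "down_proj":
            layer = self.get_layers()[layer_idx]
            return layer.get_intermediate_size(), self.hidden_size
        raise NotImplementedError()


model = _StubModel([_StubLayer(16384), _StubLayer(8192)])
print(model.get_hidden_dim("down_proj", 0))  # (16384, 5120)
print(model.get_hidden_dim("down_proj", 1))  # (8192, 5120)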
test/srt/lora/test_lora_llama4.py
new file (mode 100644)
import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)

MODELS = [
    SimpleNamespace(
        model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
        tp_size=8,
    ),
]


class TestLlama4LoRA(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.base_url = DEFAULT_URL_FOR_TEST

    def test_bringup(self):
        for model in MODELS:
            try:
                process = popen_launch_server(
                    model.model,
                    self.base_url,
                    timeout=3 * DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
                    other_args=[
                        "--enable-lora",
                        "--max-lora-rank",
                        "64",
                        "--lora-target-modules",
                        "all",
                        "--tp-size",
                        str(model.tp_size),
                        "--context-length",
                        "1048576",
                        "--attention-backend",
                        "fa3",
                    ],
                )
            except Exception as e:
                print(f"Error testing {model.model}: {e}")
                self.fail(f"Test failed for {model.model}: {e}")
            finally:
                # Ensure process cleanup happens regardless of success/failure
                if process is not None and process.poll() is None:
                    print(f"Cleaning up process {process.pid}")
                    try:
                        kill_process_tree(process.pid)
                    except Exception as e:
                        print(f"Error killing process: {e}")


if __name__ == "__main__":
    unittest.main()
test/srt/run_suite.py
@@ -136,6 +136,7 @@ suites = {
     "per-commit-8-gpu": [
         # Disabled because it hangs on the CI.
         # TestFile("ep/test_moe_ep.py", 181),
+        TestFile("lora/test_lora_llama4.py", 600),
         TestFile("test_disaggregation.py", 499),
         TestFile("test_disaggregation_different_tp.py", 155),
         TestFile("test_full_deepseek_v3.py", 333),