Unverified commit fb16fbaf, authored Jul 28, 2025 by Lifu Huang, committed by GitHub on Jul 28, 2025

Fix incorrect KV cache allocation for MTP models. (#8482)

Co-authored-by: Stefan He <hebiaobuaa@gmail.com>

Parent: 0ce84c82
Showing 2 changed files with 18 additions and 13 deletions (+18, −13)
python/sglang/srt/configs/model_config.py (+3, −0)
python/sglang/srt/model_executor/model_runner.py (+15, −13)
python/sglang/srt/configs/model_config.py
@@ -261,6 +261,9 @@ class ModelConfig:
             self.num_key_value_heads = self.num_attention_heads
         self.hidden_size = self.hf_text_config.hidden_size
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
+        self.num_nextn_predict_layers = getattr(
+            self.hf_text_config, "num_nextn_predict_layers", None
+        )
         self.vocab_size = self.hf_text_config.vocab_size

         # Verify quantization
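The new attribute mirrors num_nextn_predict_layers from the HuggingFace text config when the model defines it (MTP models such as DeepSeek-V3 and GLM-4.5 do) and stays None otherwise. A minimal sketch of that lookup, using stand-in objects rather than real transformers configs; the 61/1 values are DeepSeek-V3-like and purely illustrative:

from types import SimpleNamespace

# Stand-ins for hf_text_config; only the fields used here are modeled.
mtp_style = SimpleNamespace(num_hidden_layers=61, num_nextn_predict_layers=1)  # DeepSeek-V3-like
plain = SimpleNamespace(num_hidden_layers=32)  # no MTP layers defined

for cfg in (mtp_style, plain):
    # Same pattern as the diff: fall back to None when the field is absent.
    nextn = getattr(cfg, "num_nextn_predict_layers", None)
    print(cfg.num_hidden_layers, nextn)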
python/sglang/srt/model_executor/model_runner.py
@@ -285,11 +285,21 @@ class ModelRunner:
         if architectures and not any("Llama4" in arch for arch in architectures):
             self.is_hybrid = self.model_config.is_hybrid = True

-        self.start_layer = getattr(self.model, "start_layer", 0)
-        self.end_layer = getattr(
-            self.model, "end_layer", self.model_config.num_hidden_layers
-        )
+        # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft
+        # models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to
+        # determine the number of layers.
+        model_has_mtp_layers = self.model_config.num_nextn_predict_layers is not None
+        model_num_layers = (
+            self.model_config.num_nextn_predict_layers
+            if self.is_draft_worker and model_has_mtp_layers
+            else self.model_config.num_hidden_layers
+        )
+        self.start_layer = getattr(self.model, "start_layer", 0)
+        self.end_layer = getattr(self.model, "end_layer", model_num_layers)
         self.num_effective_layers = self.end_layer - self.start_layer
+        assert (not model_has_mtp_layers) or (
+            self.num_effective_layers == model_num_layers
+        ), "PP is not compatible with MTP models."

         # Apply torchao quantization
         torchao_applied = getattr(self.model, "torchao_applied", False)
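Factored out of ModelRunner for illustration, the layer-count selection added above reduces to the small helper below. The function name and flat argument list are mine, not part of the diff; the branch and the PP assertion mirror it.

from typing import Optional, Tuple

def effective_layer_range(
    num_hidden_layers: int,
    num_nextn_predict_layers: Optional[int],
    is_draft_worker: bool,
    start_layer: int = 0,
    end_layer: Optional[int] = None,
) -> Tuple[int, int, int]:
    """Return (start_layer, end_layer, num_effective_layers) following the diff's branch."""
    model_has_mtp_layers = num_nextn_predict_layers is not None
    model_num_layers = (
        num_nextn_predict_layers
        if is_draft_worker and model_has_mtp_layers
        else num_hidden_layers
    )
    if end_layer is None:
        end_layer = model_num_layers
    num_effective_layers = end_layer - start_layer
    # Pipeline parallelism would give each rank only a slice of the layers,
    # breaking the equality below, hence the explicit assertion.
    assert (not model_has_mtp_layers) or (
        num_effective_layers == model_num_layers
    ), "PP is not compatible with MTP models."
    return start_layer, end_layer, num_effective_layers

# Draft worker of an MTP model: only the MTP layer(s) back the KV cache.
print(effective_layer_range(61, 1, is_draft_worker=True))   # (0, 1, 1)
# Target model: all hidden layers do.
print(effective_layer_range(61, 1, is_draft_worker=False))  # (0, 61, 61)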
@@ -1178,11 +1188,7 @@ class ModelRunner:
                 dtype=self.kv_cache_dtype,
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
-                layer_num=(
-                    self.model_config.num_hidden_layers
-                    if not self.is_draft_worker
-                    else self.model_config.hf_config.num_nextn_predict_layers
-                ),
+                # PP is not compatible with mla backend
+                layer_num=self.num_effective_layers,
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 start_layer=self.start_layer,
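layer_num scales the MLA pool linearly: each cached token stores one entry of width kv_lora_rank + qk_rope_head_dim per layer. A back-of-the-envelope sizing sketch, not the pool's actual accounting, with DeepSeek-V3-like values (kv_lora_rank 512, qk_rope_head_dim 64, 2-byte entries) shows how much the layer count moves the total for a one-layer MTP draft worker versus the full 61-layer target:

def mla_kv_cache_bytes(num_tokens: int, layer_num: int,
                       kv_lora_rank: int = 512, qk_rope_head_dim: int = 64,
                       dtype_bytes: int = 2) -> int:
    # One (kv_lora_rank + qk_rope_head_dim)-wide cache entry per token per layer.
    return num_tokens * layer_num * (kv_lora_rank + qk_rope_head_dim) * dtype_bytes

tokens = 1_000_000
target = mla_kv_cache_bytes(tokens, layer_num=61)  # full target model
draft = mla_kv_cache_bytes(tokens, layer_num=1)    # MTP draft worker
print(f"target: {target / 2**30:.1f} GiB, draft: {draft / 2**30:.2f} GiB")

With num_effective_layers computed as in the earlier hunk, the draft worker's pool is sized for its MTP layer(s) only rather than for the full stack of hidden layers.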
@@ -1195,11 +1201,7 @@ class ModelRunner:
                 dtype=self.kv_cache_dtype,
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
-                layer_num=(
-                    self.model_config.num_hidden_layers
-                    if not self.is_draft_worker
-                    else self.model_config.hf_config.num_nextn_predict_layers
-                ),
+                # PP is not compatible with mla backend
+                layer_num=self.num_effective_layers,
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 start_layer=self.start_layer,