Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
97528610
Unverified
Commit
97528610
authored
Sep 13, 2025
by
Binyao Jiang
Committed by
GitHub
Sep 13, 2025
Browse files
[Fix] Support qwen3-next MTP+DP (#10392)
parent
297d3745
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
29 additions
and
18 deletions
+29
-18
python/sglang/srt/configs/model_config.py
python/sglang/srt/configs/model_config.py
+1
-0
python/sglang/srt/layers/logits_processor.py
python/sglang/srt/layers/logits_processor.py
+1
-2
python/sglang/srt/mem_cache/memory_pool.py
python/sglang/srt/mem_cache/memory_pool.py
+22
-14
python/sglang/srt/models/qwen3_next_mtp.py
python/sglang/srt/models/qwen3_next_mtp.py
+5
-2
No files found.
python/sglang/srt/configs/model_config.py
View file @
97528610
...
@@ -170,6 +170,7 @@ class ModelConfig:
...
@@ -170,6 +170,7 @@ class ModelConfig:
if
is_draft_model
and
self
.
hf_config
.
architectures
[
0
]
==
"Qwen3NextForCausalLM"
:
if
is_draft_model
and
self
.
hf_config
.
architectures
[
0
]
==
"Qwen3NextForCausalLM"
:
self
.
hf_config
.
architectures
[
0
]
=
"Qwen3NextForCausalLMMTP"
self
.
hf_config
.
architectures
[
0
]
=
"Qwen3NextForCausalLMMTP"
self
.
hf_config
.
num_nextn_predict_layers
=
1
# Check model type
# Check model type
self
.
is_generation
=
is_generation_model
(
self
.
is_generation
=
is_generation_model
(
...
...
python/sglang/srt/layers/logits_processor.py
View file @
97528610
...
@@ -185,10 +185,9 @@ class LogitsMetadata:
...
@@ -185,10 +185,9 @@ class LogitsMetadata:
)
)
else
:
else
:
dp_local_start_pos
=
cumtokens
[
dp_rank
-
1
]
dp_local_start_pos
=
cumtokens
[
dp_rank
-
1
]
dp_local_num_tokens
=
self
.
global_num_tokens_for_logprob_gpu
[
dp_rank
]
self
.
dp_local_start_pos
=
dp_local_start_pos
self
.
dp_local_start_pos
=
dp_local_start_pos
self
.
dp_local_num_tokens
=
dp_
lo
c
al_num_tokens
self
.
dp_local_num_tokens
=
self
.
g
lo
b
al_num_tokens
_for_logprob_gpu
[
dp_rank
]
hidden_size
=
get_dp_hidden_size
()
hidden_size
=
get_dp_hidden_size
()
dtype
=
get_dp_dtype
()
dtype
=
get_dp_dtype
()
...
...
python/sglang/srt/mem_cache/memory_pool.py
View file @
97528610
...
@@ -52,6 +52,10 @@ if _is_npu:
...
@@ -52,6 +52,10 @@ if _is_npu:
import
torch_npu
import
torch_npu
def get_tensor_size_bytes(t: torch.Tensor) -> int:
    """Return the total storage size of *t* in bytes (element count * element size).

    Uses the torch-native accessors instead of ``np.prod(t.shape) * t.dtype.itemsize``:
    ``np.prod`` yields ``np.int64`` (and the float ``1.0`` for a 0-dim tensor, since
    the product of an empty shape is ``1.0``), whereas this form always returns an
    exact Python ``int``.
    """
    return t.numel() * t.element_size()
class
ReqToTokenPool
:
class
ReqToTokenPool
:
"""A memory pool that maps a request to its token locations."""
"""A memory pool that maps a request to its token locations."""
...
@@ -158,16 +162,23 @@ class MambaPool:
...
@@ -158,16 +162,23 @@ class MambaPool:
intermediate_ssm_state_cache
,
intermediate_ssm_state_cache
,
intermediate_conv_window_cache
,
intermediate_conv_window_cache
,
)
)
logger
.
info
(
f
"Mamba Cache is allocated. "
f
"conv_state size:
{
get_tensor_size_bytes
(
conv_state
)
/
GB
:.
2
f
}
GB, "
f
"ssm_state size:
{
get_tensor_size_bytes
(
temporal_state
)
/
GB
:.
2
f
}
GB "
f
"intermediate_ssm_state_cache size:
{
get_tensor_size_bytes
(
intermediate_ssm_state_cache
)
/
GB
:.
2
f
}
GB "
f
"intermediate_conv_window_cache size:
{
get_tensor_size_bytes
(
intermediate_conv_window_cache
)
/
GB
:.
2
f
}
GB "
)
else
:
else
:
self
.
mamba_cache
=
(
conv_state
,
temporal_state
)
self
.
mamba_cache
=
(
conv_state
,
temporal_state
)
logger
.
info
(
f
"Mamba Cache is allocated. "
f
"conv_state size:
{
get_tensor_size_bytes
(
conv_state
)
/
GB
:.
2
f
}
GB, "
f
"ssm_state size:
{
get_tensor_size_bytes
(
temporal_state
)
/
GB
:.
2
f
}
GB "
)
self
.
size
=
size
self
.
size
=
size
self
.
free_slots
=
list
(
range
(
size
))
self
.
free_slots
=
list
(
range
(
size
))
self
.
mem_usage
=
self
.
get_mamba_size
()
/
GB
self
.
mem_usage
=
self
.
get_mamba_size
()
/
GB
logger
.
info
(
f
"Mamba Cache is allocated. "
f
"conv_state size:
{
conv_state
.
numel
()
*
conv_state
.
itemsize
/
GB
:.
2
f
}
GB, "
f
"ssm_state size:
{
temporal_state
.
numel
()
*
temporal_state
.
itemsize
/
GB
:.
2
f
}
GB "
)
def get_mamba_params_all_layers(self):
    """Return the per-layer mamba cache tensors as a list, in layer order.

    ``list(...)`` copies the container directly; behavior is identical to the
    hand-rolled ``[cache[i] for i in range(len(cache))]`` index loop for the
    tuple/list stored in ``self.mamba_cache``, and is the idiomatic form.
    """
    return list(self.mamba_cache)
...
@@ -176,10 +187,7 @@ class MambaPool:
...
@@ -176,10 +187,7 @@ class MambaPool:
return
[
self
.
mamba_cache
[
i
][
layer_id
]
for
i
in
range
(
len
(
self
.
mamba_cache
))]
return
[
self
.
mamba_cache
[
i
][
layer_id
]
for
i
in
range
(
len
(
self
.
mamba_cache
))]
def get_mamba_size(self):
    """Return the total size in bytes of every tensor held in the mamba cache."""
    # Accumulate explicitly rather than via sum(...) over a generator;
    # the per-tensor byte count comes from the shared helper so all pools
    # report sizes consistently.
    total_bytes = 0
    for cache_tensor in self.mamba_cache:
        total_bytes += get_tensor_size_bytes(cache_tensor)
    return total_bytes
def available_size(self):
    """Return how many unallocated mamba-cache slots remain in the pool."""
    remaining = len(self.free_slots)
    return remaining
...
@@ -492,10 +500,10 @@ class MHATokenToKVPool(KVCache):
...
@@ -492,10 +500,10 @@ class MHATokenToKVPool(KVCache):
assert
hasattr
(
self
,
"v_buffer"
)
assert
hasattr
(
self
,
"v_buffer"
)
k_size_bytes
=
0
k_size_bytes
=
0
for
k_cache
in
self
.
k_buffer
:
for
k_cache
in
self
.
k_buffer
:
k_size_bytes
+=
np
.
prod
(
k_cache
.
shape
)
*
k_cache
.
dtype
.
itemsize
k_size_bytes
+=
get_tensor_size_bytes
(
k_cache
)
v_size_bytes
=
0
v_size_bytes
=
0
for
v_cache
in
self
.
v_buffer
:
for
v_cache
in
self
.
v_buffer
:
v_size_bytes
+=
np
.
prod
(
v_cache
.
shape
)
*
v_cache
.
dtype
.
itemsize
v_size_bytes
+=
get_tensor_size_bytes
(
v_cache
)
return
k_size_bytes
,
v_size_bytes
return
k_size_bytes
,
v_size_bytes
# for disagg
# for disagg
...
@@ -1077,7 +1085,7 @@ class MLATokenToKVPool(KVCache):
...
@@ -1077,7 +1085,7 @@ class MLATokenToKVPool(KVCache):
assert
hasattr
(
self
,
"kv_buffer"
)
assert
hasattr
(
self
,
"kv_buffer"
)
kv_size_bytes
=
0
kv_size_bytes
=
0
for
kv_cache
in
self
.
kv_buffer
:
for
kv_cache
in
self
.
kv_buffer
:
kv_size_bytes
+=
np
.
prod
(
kv_cache
.
shape
)
*
kv_cache
.
dtype
.
itemsize
kv_size_bytes
+=
get_tensor_size_bytes
(
kv_cache
)
return
kv_size_bytes
return
kv_size_bytes
# for disagg
# for disagg
...
@@ -1240,9 +1248,9 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
...
@@ -1240,9 +1248,9 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
assert
hasattr
(
self
,
"v_buffer"
)
assert
hasattr
(
self
,
"v_buffer"
)
kv_size_bytes
=
0
kv_size_bytes
=
0
for
k_cache
in
self
.
k_buffer
:
for
k_cache
in
self
.
k_buffer
:
kv_size_bytes
+=
np
.
prod
(
k_cache
.
shape
)
*
k_cache
.
dtype
.
itemsize
kv_size_bytes
+=
get_tensor_size_bytes
(
k_cache
)
for
v_cache
in
self
.
v_buffer
:
for
v_cache
in
self
.
v_buffer
:
kv_size_bytes
+=
np
.
prod
(
v_cache
.
shape
)
*
v_cache
.
dtype
.
itemsize
kv_size_bytes
+=
get_tensor_size_bytes
(
v_cache
)
return
kv_size_bytes
return
kv_size_bytes
def
get_kv_buffer
(
self
,
layer_id
:
int
):
def
get_kv_buffer
(
self
,
layer_id
:
int
):
...
...
python/sglang/srt/models/qwen3_next_mtp.py
View file @
97528610
...
@@ -85,8 +85,11 @@ class Qwen3NextForCausalLMMTP(Qwen3NextForCausalLM):
...
@@ -85,8 +85,11 @@ class Qwen3NextForCausalLMMTP(Qwen3NextForCausalLM):
if
input_embeds
is
None
:
if
input_embeds
is
None
:
input_embeds
=
self
.
model
.
embed_tokens
(
input_ids
)
input_embeds
=
self
.
model
.
embed_tokens
(
input_ids
)
input_embeds
=
self
.
pre_fc_norm_embedding
(
input_embeds
)
hidden_states
=
forward_batch
.
spec_info
.
hidden_states
hidden_states
=
self
.
pre_fc_norm_hidden
(
forward_batch
.
spec_info
.
hidden_states
)
# Some idle batch has 0 batch size. GemmaRMSNorm.forward would fail due to bs=0.
if
not
forward_batch
.
forward_mode
.
is_idle
():
input_embeds
=
self
.
pre_fc_norm_embedding
(
input_embeds
)
hidden_states
=
self
.
pre_fc_norm_hidden
(
hidden_states
)
hidden_states
=
self
.
fc
(
torch
.
cat
((
input_embeds
,
hidden_states
),
dim
=-
1
))
hidden_states
=
self
.
fc
(
torch
.
cat
((
input_embeds
,
hidden_states
),
dim
=-
1
))
hidden_states
=
self
.
model
(
hidden_states
=
self
.
model
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment