sglang · commit 11553c1a (unverified)
Authored May 18, 2025 by libra; committed by GitHub on May 18, 2025

Add pipeline parallelism for Qwen2 and Qwen3 Model (#6250)

Parent: 01dd39ba
Showing 5 changed files with 340 additions and 73 deletions (+340 −73)

python/sglang/srt/models/qwen2.py      +95  −26
python/sglang/srt/models/qwen2_moe.py  +89  −27
python/sglang/srt/models/qwen3.py      +52  −10
python/sglang/srt/models/qwen3_moe.py  +49  −10
test/srt/test_pp_single_node.py        +55  −0
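All five diffs below share one pipeline-parallel pattern: the first PP rank embeds the input ids, every rank runs only its own slice of decoder layers, and every rank except the last returns a PPProxyTensors bundle ("hidden_states" and "residual") instead of logits, which the next rank consumes through the new pp_proxy_tensors argument. The standalone sketch below illustrates that control flow with plain torch modules; ToyStage and its fields are illustrative stand-ins, not sglang's actual Qwen2Model, and the inter-rank transport is omitted:

from typing import Dict, Optional, Union

import torch
from torch import nn


class ToyStage(nn.Module):
    """Illustrative PP stage mirroring the control flow added in this commit."""

    def __init__(self, is_first: bool, is_last: bool, hidden: int = 16, num_layers: int = 2):
        super().__init__()
        self.is_first, self.is_last = is_first, is_last
        # Only the first rank owns the embedding; only the last rank owns the final norm.
        self.embed = nn.Embedding(100, hidden) if is_first else None
        self.layers = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_layers))
        self.norm = nn.LayerNorm(hidden) if is_last else None

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        proxy: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        if self.is_first:
            hidden_states = self.embed(input_ids)
        else:
            assert proxy is not None          # non-first ranks must receive proxy tensors
            hidden_states = proxy["hidden_states"]
        for layer in self.layers:             # each rank runs only its own layer slice
            hidden_states = layer(hidden_states)
        if not self.is_last:
            # Hand the activations to the next rank instead of producing logits.
            return {"hidden_states": hidden_states, "residual": hidden_states}
        return self.norm(hidden_states)


# Single-process walk-through of a two-stage pipeline (transport omitted).
first, last = ToyStage(is_first=True, is_last=False), ToyStage(is_first=False, is_last=True)
proxy = first(input_ids=torch.randint(0, 100, (1, 4)))
print(last(proxy=proxy).shape)  # torch.Size([1, 4, 16])

In a real deployment each stage lives on a different rank and the bundle travels over the pipeline's communication channel; here both stages run in one process purely to show the hand-off.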
python/sglang/srt/models/qwen2.py  View file @ 11553c1a

@@ -15,12 +15,14 @@
 # Adapted from llama2.py
 # Modify details for the adaptation of Qwen2 model.
 """Inference-only Qwen2 model compatible with HuggingFace weights."""
-from typing import Any, Dict, Iterable, Optional, Tuple
+import logging
+from typing import Any, Dict, Iterable, Optional, Tuple, Union
 
 import torch
 from torch import nn
 
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
@@ -36,11 +38,12 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
@@ -50,6 +53,9 @@ from sglang.srt.utils import add_prefix, make_layers
 Qwen2Config = None
 
+logger = logging.getLogger(__name__)
+
 
 class Qwen2MLP(nn.Module):
     def __init__(
         self,
@@ -245,15 +251,21 @@ class Qwen2Model(nn.Module):
         self.config = config
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
+        self.pp_group = get_pp_group()
 
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            quant_config=quant_config,
-            prefix=add_prefix("embed_tokens", prefix),
-        )
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
         # Use the provided decoder layer type or default to Qwen2DecoderLayer
         decoder_layer_type = decoder_layer_type or Qwen2DecoderLayer
-        self.layers = make_layers(
+        self.layers, self.start_layer, self.end_layer = make_layers(
             config.num_hidden_layers,
             lambda idx, prefix: decoder_layer_type(
                 layer_id=idx,
@@ -261,9 +273,14 @@ class Qwen2Model(nn.Module):
                 quant_config=quant_config,
                 prefix=prefix,
             ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
             prefix=add_prefix("layers", prefix),
         )
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
 
     def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
         if hasattr(self.config, "scale_emb"):
@@ -280,13 +297,20 @@ class Qwen2Model(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> torch.Tensor:
-        if input_embeds is None:
-            hidden_states = self.embed_tokens(input_ids)
-        else:
-            hidden_states = input_embeds
-        residual = None
-        for i in range(len(self.layers)):
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
@@ -294,6 +318,14 @@ class Qwen2Model(nn.Module):
                 forward_batch,
                 residual,
             )
-        hidden_states, _ = self.norm(hidden_states, residual)
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
@@ -348,6 +380,7 @@ class Qwen2ForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen2Model(
@@ -379,14 +412,33 @@ class Qwen2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
         get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        if not get_embedding:
-            return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
-            )
-        else:
-            return self.pooler(hidden_states, forward_batch)
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+        if self.pp_group.is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids, hidden_states, self.lm_head, forward_batch
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -400,6 +452,17 @@ class Qwen2ForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())
 
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 continue
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
@@ -426,9 +489,15 @@ class Qwen2ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")
 
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
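A note on the load_weights change above: each rank now skips any checkpoint tensor whose decoder-layer index falls outside its own [start_layer, end_layer) range, and names that are legitimately absent from params_dict on this rank (for example the embedding on a non-first rank) only produce a warning instead of a KeyError. Below is a minimal standalone version of that filter, with parse_layer_id as a hypothetical stand-in for sglang's get_layer_id:

import re
from typing import Iterable, Iterator, Optional, Tuple

import torch


def parse_layer_id(name: str) -> Optional[int]:
    # Hypothetical equivalent of sglang.srt.layers.utils.get_layer_id: pull the
    # integer out of names such as "model.layers.17.mlp.up_proj.weight".
    match = re.search(r"\blayers\.(\d+)\.", name)
    return int(match.group(1)) if match else None


def weights_for_rank(
    weights: Iterable[Tuple[str, torch.Tensor]],
    start_layer: int,
    end_layer: int,
) -> Iterator[Tuple[str, torch.Tensor]]:
    """Yield only the checkpoint tensors this pipeline rank should load."""
    for name, tensor in weights:
        layer_id = parse_layer_id(name)
        if layer_id is not None and not (start_layer <= layer_id < end_layer):
            continue  # this layer belongs to another pipeline stage
        yield name, tensor


# Example: a rank owning layers [8, 16) keeps layer 10 and drops layer 2.
fake = [
    ("model.layers.2.self_attn.q_proj.weight", torch.zeros(1)),
    ("model.layers.10.self_attn.q_proj.weight", torch.zeros(1)),
]
print([name for name, _ in weights_for_rank(fake, 8, 16)])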
python/sglang/srt/models/qwen2_moe.py  View file @ 11553c1a

@@ -16,9 +16,10 @@
 # https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen2_moe.py
 """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
+import logging
 from dataclasses import dataclass
 from enum import Enum, auto
-from typing import Any, Dict, Iterable, Optional, Tuple
+from typing import Any, Dict, Iterable, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -26,6 +27,7 @@ from torch import nn
 from transformers import PretrainedConfig
 
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
@@ -52,18 +54,21 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, make_layers
 
 expert_distribution_recorder = ExpertDistributionRecorder()
+
+logger = logging.getLogger(__name__)
 
 
 class Qwen2MoeMLP(nn.Module):
     def __init__(
@@ -535,16 +540,21 @@ class Qwen2MoeModel(nn.Module):
         super().__init__()
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
+        self.pp_group = get_pp_group()
 
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
-            prefix=add_prefix("embed_tokens", prefix),
-        )
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                enable_tp=not global_server_args_dict["enable_dp_attention"],
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
         # Use the provided decoder layer type or default to Qwen2MoeDecoderLayer
         decoder_layer_type = decoder_layer_type or Qwen2MoeDecoderLayer
-        self.layers = make_layers(
+        self.layers, self.start_layer, self.end_layer = make_layers(
             config.num_hidden_layers,
             lambda idx, prefix: decoder_layer_type(
                 layer_id=idx,
@@ -552,9 +562,14 @@ class Qwen2MoeModel(nn.Module):
                 quant_config=quant_config,
                 prefix=prefix,
             ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
             prefix=add_prefix("layers", prefix),
         )
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
 
     def forward(
         self,
@@ -562,18 +577,33 @@ class Qwen2MoeModel(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> torch.Tensor:
-        if input_embeds is None:
-            hidden_states = self.embed_tokens(input_ids)
-        else:
-            hidden_states = input_embeds
-        residual = None
-        for i in range(len(self.layers)):
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        for i in range(self.start_layer, self.end_layer):
             expert_distribution_recorder.set_current_layer(i)
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions, hidden_states, forward_batch, residual
             )
-        if hidden_states.shape[0] != 0:
-            hidden_states, _ = self.norm(hidden_states, residual)
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            if hidden_states.shape[0] != 0:
+                hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
@@ -589,6 +619,7 @@ class Qwen2MoeForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen2MoeModel(
@@ -609,11 +640,29 @@ class Qwen2MoeForCausalLM(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> LogitsProcessorOutput:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
-        )
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -636,6 +685,16 @@ class Qwen2MoeForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
             if "rotary_emb.inv_freq" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -684,11 +743,14 @@ class Qwen2MoeForCausalLM(nn.Module):
                 if name not in params_dict:
                     continue
 
-                param = params_dict[name]
-                weight_loader = getattr(
-                    param, "weight_loader", default_weight_loader
-                )
-                weight_loader(param, loaded_weight)
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")
 
 
 EntryClass = Qwen2MoeForCausalLM
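Both the dense and MoE models now unpack (layers, start_layer, end_layer) from make_layers, which additionally receives pp_rank and pp_size. The helper below shows one way such a partition could be computed, splitting num_hidden_layers as evenly as possible across ranks; it is an illustration of the idea, not sglang's make_layers implementation:

from typing import Tuple


def partition_layers(num_layers: int, pp_rank: int, pp_size: int) -> Tuple[int, int]:
    """Return the [start, end) decoder-layer slice owned by pp_rank.

    Near-even split: the first `remainder` ranks each take one extra layer.
    """
    base, remainder = divmod(num_layers, pp_size)
    start = pp_rank * base + min(pp_rank, remainder)
    end = start + base + (1 if pp_rank < remainder else 0)
    return start, end


# Example: 36 layers over pp_size=2 gives rank 0 -> (0, 18) and rank 1 -> (18, 36).
for rank in range(2):
    print(rank, partition_layers(36, rank, 2))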
python/sglang/srt/models/qwen3.py  View file @ 11553c1a

 # Adapted from qwen2.py
+import logging
 from functools import partial
 from typing import Any, Dict, Iterable, Optional, Tuple
@@ -7,6 +8,7 @@ import torch
 from torch import nn
 
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     split_tensor_along_last_dim,
@@ -19,8 +21,9 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
@@ -28,6 +31,8 @@ from sglang.srt.utils import add_prefix
 Qwen3Config = None
 
+logger = logging.getLogger(__name__)
+
 
 class Qwen3Attention(nn.Module):
     def __init__(
@@ -238,6 +243,7 @@ class Qwen3ForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen3Model(
@@ -266,14 +272,33 @@ class Qwen3ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
         get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        if not get_embedding:
-            return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
-            )
-        else:
-            return self.pooler(hidden_states, forward_batch)
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+        if self.pp_group.is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids, hidden_states, self.lm_head, forward_batch
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -287,6 +312,17 @@ class Qwen3ForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())
 
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 continue
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
@@ -313,9 +349,15 @@ class Qwen3ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")
 
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
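On ranks that do not own a module, these diffs plug in PPMissingLayer, with return_tuple=True for the final norm so the (hidden_states, residual) calling convention survives on non-last ranks. The snippet below sketches a placeholder with that kind of pass-through behaviour; it is an assumed stand-in for illustration, not the PPMissingLayer class from sglang.srt.layers.utils:

import torch
from torch import nn


class PassThroughLayer(nn.Module):
    """Identity stand-in for a sub-module that lives on another pipeline rank."""

    def __init__(self, return_tuple: bool = False):
        super().__init__()
        self.return_tuple = return_tuple

    def forward(self, *args):
        # Return the inputs untouched so call sites need no rank-specific branches.
        if self.return_tuple:
            return args                      # e.g. norm(hidden_states, residual)
        return args[0] if len(args) == 1 else args


hidden = torch.ones(2, 4)
print(PassThroughLayer()(hidden).shape)                        # embedding stand-in
print(len(PassThroughLayer(return_tuple=True)(hidden, None)))  # tuple, like the real norm call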
python/sglang/srt/models/qwen3_moe.py  View file @ 11553c1a

@@ -17,6 +17,7 @@
 """Inference-only Qwen3MoE model compatible with HuggingFace weights."""
+import logging
 from dataclasses import dataclass
 from enum import Enum, auto
 from functools import partial
@@ -28,6 +29,7 @@ from torch import nn
 from transformers.configuration_utils import PretrainedConfig
 
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     split_tensor_along_last_dim,
@@ -57,12 +59,13 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
@@ -70,6 +73,8 @@ from sglang.srt.utils import add_prefix
 Qwen3MoeConfig = None
 
+logger = logging.getLogger(__name__)
+
 
 class Qwen3MoeSparseMoeBlock(nn.Module):
     def __init__(
@@ -516,6 +521,7 @@ class Qwen3MoeForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen3MoeModel(
@@ -536,11 +542,30 @@ class Qwen3MoeForCausalLM(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> LogitsProcessorOutput:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
-        )
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -563,6 +588,17 @@ class Qwen3MoeForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
             if "rotary_emb.inv_freq" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -611,11 +647,14 @@ class Qwen3MoeForCausalLM(nn.Module):
                 if name not in params_dict:
                     continue
 
-                param = params_dict[name]
-                weight_loader = getattr(
-                    param, "weight_loader", default_weight_loader
-                )
-                weight_loader(param, loaded_weight)
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")
 
 
 EntryClass = Qwen3MoeForCausalLM
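Each non-last rank above wraps its outputs as PPProxyTensors({"hidden_states": ..., "residual": ...}) and the next rank reads them back by key, so the object behaves like a small mapping of named tensors. A minimal container with that interface is sketched below; sglang's real PPProxyTensors in forward_batch_info may carry extra behaviour (device placement, serialization) beyond this:

from typing import Dict, Iterator

import torch


class ProxyTensors:
    """Minimal dict-like bundle mirroring how the diffs use PPProxyTensors."""

    def __init__(self, tensors: Dict[str, torch.Tensor]):
        self.tensors = dict(tensors)

    def __getitem__(self, key: str) -> torch.Tensor:
        return self.tensors[key]

    def __contains__(self, key: str) -> bool:
        return key in self.tensors

    def __iter__(self) -> Iterator[str]:
        return iter(self.tensors)


proxy = ProxyTensors({"hidden_states": torch.zeros(2, 8), "residual": torch.zeros(2, 8)})
print(proxy["hidden_states"].shape, "residual" in proxy, list(proxy))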
test/srt/test_pp_single_node.py  View file @ 11553c1a

 """
 Usage:
 python3 -m unittest test_pp_single_node.TestPPAccuracy.test_gsm8k
+python3 -m unittest test_pp_single_node.TestQwenPPAccuracy.test_pp_consistency
 python3 -m unittest test_pp_single_node.TestFixedBugs.test_chunked_prefill_with_small_bs
 """
@@ -61,6 +62,60 @@ class TestPPAccuracy(unittest.TestCase):
         time.sleep(5)
 
 
+class TestQwenPPAccuracy(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.base_url = "http://127.0.0.1:23334"  # different ports to avoid conflicts
+        cls.model_name = "Qwen/Qwen3-8B"  # replace with your Qwen Model if needed
+
+    def run_gsm8k_test(self, pp_size):
+        process = popen_launch_server(
+            self.model_name,
+            self.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--pp-size",
+                pp_size,
+                "--chunked-prefill-size",
+                256,
+            ],
+        )
+        try:
+            args = SimpleNamespace(
+                num_shots=5,
+                data_path=None,
+                num_questions=200,
+                max_new_tokens=512,
+                parallel=128,
+                host="http://127.0.0.1",
+                port=int(self.base_url.split(":")[-1]),
+            )
+            metrics = run_eval(args)
+            time.sleep(5)
+            return metrics
+        finally:
+            kill_process_tree(process.pid)
+
+    def test_baseline_accuracy(self):
+        metrics = self.run_gsm8k_test(pp_size=1)
+        print(f"[Qwen Baseline] {metrics=}")
+        self.assertGreater(metrics["accuracy"], 0.74)
+
+    def test_pp_consistency(self):
+        baseline = self.run_gsm8k_test(pp_size=1)
+        pp_metrics = self.run_gsm8k_test(pp_size=2)
+        print(f"[Qwen PP Comparison] Baseline: {baseline} | PP: {pp_metrics}")
+        self.assertAlmostEqual(
+            pp_metrics["accuracy"],
+            baseline["accuracy"],
+            delta=0.01,
+            msg=f"PP accuracy exceeds 1% (baseline: {baseline['accuracy']}, pp: {pp_metrics['accuracy']})",
+        )
+
+
 class TestFixedBugs(unittest.TestCase):
     def test_chunked_prefill_with_small_bs(self):
         model = DEFAULT_MODEL_NAME_FOR_TEST