change / sglang · Commits · 11553c1a

Unverified commit 11553c1a, authored May 18, 2025 by libra, committed by GitHub on May 18, 2025.

Add pipeline parallelism for Qwen2 and Qwen3 Model (#6250)

parent 01dd39ba

Showing 5 changed files with 340 additions and 73 deletions (+340 -73)
python/sglang/srt/models/qwen2.py      +95  -26
python/sglang/srt/models/qwen2_moe.py  +89  -27
python/sglang/srt/models/qwen3.py      +52  -10
python/sglang/srt/models/qwen3_moe.py  +49  -10
test/srt/test_pp_single_node.py        +55   -0
python/sglang/srt/models/qwen2.py

@@ -15,12 +15,14 @@
 # Adapted from llama2.py
 # Modify details for the adaptation of Qwen2 model.
 """Inference-only Qwen2 model compatible with HuggingFace weights."""
-from typing import Any, Dict, Iterable, Optional, Tuple
+import logging
+from typing import Any, Dict, Iterable, Optional, Tuple, Union

 import torch
 from torch import nn

 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )

@@ -36,11 +38,12 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,

@@ -50,6 +53,9 @@ from sglang.srt.utils import add_prefix, make_layers
 Qwen2Config = None

+logger = logging.getLogger(__name__)
+

 class Qwen2MLP(nn.Module):
     def __init__(
         self,

@@ -245,15 +251,21 @@ class Qwen2Model(nn.Module):
         self.config = config
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            quant_config=quant_config,
-            prefix=add_prefix("embed_tokens", prefix),
-        )
+        self.pp_group = get_pp_group()
+
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
         # Use the provided decoder layer type or default to Qwen2DecoderLayer
         decoder_layer_type = decoder_layer_type or Qwen2DecoderLayer
-        self.layers = make_layers(
+        self.layers, self.start_layer, self.end_layer = make_layers(
             config.num_hidden_layers,
             lambda idx, prefix: decoder_layer_type(
                 layer_id=idx,

@@ -261,9 +273,14 @@ class Qwen2Model(nn.Module):
                 quant_config=quant_config,
                 prefix=prefix,
             ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
             prefix=add_prefix("layers", prefix),
         )
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)

     def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
         if hasattr(self.config, "scale_emb"):

@@ -280,13 +297,20 @@ class Qwen2Model(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> torch.Tensor:
-        if input_embeds is None:
-            hidden_states = self.embed_tokens(input_ids)
-        else:
-            hidden_states = input_embeds
-        residual = None
-        for i in range(len(self.layers)):
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,

@@ -294,7 +318,15 @@ class Qwen2Model(nn.Module):
                 forward_batch,
                 residual,
             )
-        hidden_states, _ = self.norm(hidden_states, residual)
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

     # If this function is called, it should always initialize KV cache scale

@@ -348,6 +380,7 @@ class Qwen2ForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen2Model(

@@ -379,14 +412,33 @@ class Qwen2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
         get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        if not get_embedding:
-            return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
-            )
-        else:
-            return self.pooler(hidden_states, forward_batch)
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+
+        if self.pp_group.is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids, hidden_states, self.lm_head, forward_batch
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [

@@ -400,6 +452,17 @@ class Qwen2ForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())

         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 continue
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:

@@ -426,9 +489,15 @@ class Qwen2ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")

     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
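Note: the qwen2.py changes carry the whole pipeline-parallel pattern. The first rank owns the embedding, every rank runs only its [start_layer, end_layer) slice of decoder layers, and every rank except the last hands a hidden_states/residual pair to the next stage instead of applying the final norm and logits. As a reading aid, here is a minimal self-contained sketch of that stage-local forward; ToyPPStage, the plain dict standing in for PPProxyTensors, and the toy layer shapes are illustrative assumptions rather than sglang APIs, and the residual stream is dropped for brevity.

# Minimal sketch of the stage-local forward added above (illustrative only;
# PPProxyTensors is replaced by a plain dict, decoder layers by toy Linear blocks,
# and the residual carried by the real code is omitted).
from typing import Dict, Optional, Union

import torch
from torch import nn


class ToyPPStage(nn.Module):
    def __init__(self, num_layers: int, hidden: int, pp_rank: int, pp_size: int):
        super().__init__()
        self.is_first_rank = pp_rank == 0
        self.is_last_rank = pp_rank == pp_size - 1
        # Contiguous slice of layers owned by this stage (even split; last stage
        # absorbs any remainder).
        per_stage = num_layers // pp_size
        self.start_layer = pp_rank * per_stage
        self.end_layer = num_layers if self.is_last_rank else self.start_layer + per_stage
        self.embed = nn.Embedding(1000, hidden) if self.is_first_rank else None
        self.layers = nn.ModuleDict(
            {str(i): nn.Linear(hidden, hidden) for i in range(self.start_layer, self.end_layer)}
        )
        self.norm = nn.LayerNorm(hidden) if self.is_last_rank else None

    def forward(
        self,
        input_ids: torch.Tensor,
        proxy: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        if self.is_first_rank:
            hidden_states = self.embed(input_ids)
        else:
            # Intermediate and last stages consume the previous stage's output.
            assert proxy is not None
            hidden_states = proxy["hidden_states"]
        for i in range(self.start_layer, self.end_layer):
            hidden_states = torch.relu(self.layers[str(i)](hidden_states))
        if not self.is_last_rank:
            return {"hidden_states": hidden_states}  # handed to the next stage
        return self.norm(hidden_states)


# Chaining two toy stages in one process stands in for the send/recv that a real
# pipeline runner performs between ranks.
stage0 = ToyPPStage(num_layers=4, hidden=8, pp_rank=0, pp_size=2)
stage1 = ToyPPStage(num_layers=4, hidden=8, pp_rank=1, pp_size=2)
proxy = stage0(torch.randint(0, 1000, (2, 5)))
out = stage1(torch.zeros(2, 5, dtype=torch.long), proxy=proxy)
print(out.shape)  # torch.Size([2, 5, 8])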
python/sglang/srt/models/qwen2_moe.py

@@ -16,9 +16,10 @@
 # https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen2_moe.py
 """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
+import logging
 from dataclasses import dataclass
 from enum import Enum, auto
-from typing import Any, Dict, Iterable, Optional, Tuple
+from typing import Any, Dict, Iterable, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F

@@ -26,6 +27,7 @@ from torch import nn
 from transformers import PretrainedConfig

 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )

@@ -52,18 +54,21 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, make_layers

 expert_distribution_recorder = ExpertDistributionRecorder()

+logger = logging.getLogger(__name__)
+

 class Qwen2MoeMLP(nn.Module):
     def __init__(

@@ -535,16 +540,21 @@ class Qwen2MoeModel(nn.Module):
         super().__init__()
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            enable_tp=not global_server_args_dict["enable_dp_attention"],
-            prefix=add_prefix("embed_tokens", prefix),
-        )
+        self.pp_group = get_pp_group()
+
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                enable_tp=not global_server_args_dict["enable_dp_attention"],
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
         # Use the provided decoder layer type or default to Qwen2MoeDecoderLayer
         decoder_layer_type = decoder_layer_type or Qwen2MoeDecoderLayer
-        self.layers = make_layers(
+        self.layers, self.start_layer, self.end_layer = make_layers(
             config.num_hidden_layers,
             lambda idx, prefix: decoder_layer_type(
                 layer_id=idx,

@@ -552,9 +562,14 @@ class Qwen2MoeModel(nn.Module):
                 quant_config=quant_config,
                 prefix=prefix,
             ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
             prefix=add_prefix("layers", prefix),
         )
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)

     def forward(
         self,

@@ -562,20 +577,35 @@ class Qwen2MoeModel(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> torch.Tensor:
-        if input_embeds is None:
-            hidden_states = self.embed_tokens(input_ids)
-        else:
-            hidden_states = input_embeds
-        residual = None
-        for i in range(len(self.layers)):
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        for i in range(self.start_layer, self.end_layer):
             expert_distribution_recorder.set_current_layer(i)
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions, hidden_states, forward_batch, residual
             )
-        if hidden_states.shape[0] != 0:
-            hidden_states, _ = self.norm(hidden_states, residual)
-        return hidden_states
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            if hidden_states.shape[0] != 0:
+                hidden_states, _ = self.norm(hidden_states, residual)
+            return hidden_states

@@ -589,6 +619,7 @@ class Qwen2MoeForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen2MoeModel(

@@ -609,11 +640,29 @@ class Qwen2MoeForCausalLM(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> LogitsProcessorOutput:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
-        )
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [

@@ -636,6 +685,16 @@ class Qwen2MoeForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())

         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
             if "rotary_emb.inv_freq" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:

@@ -684,11 +743,14 @@ class Qwen2MoeForCausalLM(nn.Module):
                 if name not in params_dict:
                     continue
-                param = params_dict[name]
-                weight_loader = getattr(
-                    param, "weight_loader", default_weight_loader
-                )
-                weight_loader(param, loaded_weight)
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")

 EntryClass = Qwen2MoeForCausalLM
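Note: qwen2_moe.py follows the same recipe, with the extra detail that make_layers is now asked to build only the locally owned slice (via pp_rank/pp_size) and to report the (layers, start_layer, end_layer) triple, while PPMissingLayer keeps the module tree shape-stable on ranks that do not own the embedding or the final norm. The sketch below illustrates that kind of contiguous partitioning; partition_layers and MissingLayer are hypothetical stand-ins, not sglang's make_layers or PPMissingLayer.

# Hypothetical sketch of contiguous layer partitioning for pipeline parallelism.
# Real code defers to sglang.srt.utils.make_layers; this only illustrates the
# start/end bookkeeping and the placeholder idea behind PPMissingLayer.
from typing import Callable, Tuple

from torch import nn


class MissingLayer(nn.Identity):
    """Placeholder for a layer that lives on another pipeline stage."""


def partition_layers(
    num_layers: int,
    build_layer: Callable[[int], nn.Module],
    pp_rank: int,
    pp_size: int,
) -> Tuple[nn.ModuleList, int, int]:
    # Even contiguous split; the last stage absorbs any remainder.
    per_stage = num_layers // pp_size
    start = pp_rank * per_stage
    end = num_layers if pp_rank == pp_size - 1 else start + per_stage
    layers = nn.ModuleList(
        [build_layer(i) if start <= i < end else MissingLayer() for i in range(num_layers)]
    )
    return layers, start, end


# Example: 8 layers over 2 stages -> rank 1 owns layers [4, 8) and keeps
# placeholders for [0, 4) so global layer indices remain valid.
layers, start, end = partition_layers(8, lambda i: nn.Linear(16, 16), pp_rank=1, pp_size=2)
print(start, end, sum(isinstance(m, MissingLayer) for m in layers))  # 4 8 4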
python/sglang/srt/models/qwen3.py

 # Adapted from qwen2.py
+import logging
 from functools import partial
 from typing import Any, Dict, Iterable, Optional, Tuple

@@ -7,6 +8,7 @@ import torch
 from torch import nn

 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     split_tensor_along_last_dim,

@@ -19,8 +21,9 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model

@@ -28,6 +31,8 @@ from sglang.srt.utils import add_prefix
 Qwen3Config = None

+logger = logging.getLogger(__name__)
+

 class Qwen3Attention(nn.Module):
     def __init__(

@@ -238,6 +243,7 @@ class Qwen3ForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen3Model(

@@ -266,14 +272,33 @@ class Qwen3ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
         get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        if not get_embedding:
-            return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
-            )
-        else:
-            return self.pooler(hidden_states, forward_batch)
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+
+        if self.pp_group.is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids, hidden_states, self.lm_head, forward_batch
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [

@@ -287,6 +312,17 @@ class Qwen3ForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())

         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 continue
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:

@@ -313,9 +349,15 @@ class Qwen3ForCausalLM(nn.Module):
             # Skip loading extra bias for GPTQ models.
             if name.endswith(".bias") and name not in params_dict:
                 continue
-            param = params_dict[name]
-            weight_loader = getattr(param, "weight_loader", default_weight_loader)
-            weight_loader(param, loaded_weight)
+            if name in params_dict.keys():
+                param = params_dict[name]
+                weight_loader = getattr(
+                    param, "weight_loader", default_weight_loader
+                )
+                weight_loader(param, loaded_weight)
+            else:
+                logger.warning(f"Parameter {name} not found in params_dict")

     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
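Note: qwen3.py reuses Qwen2Model, so its diff is mostly the ForCausalLM wrapper plus the same load_weights filter that every file in this commit gains: a checkpoint tensor is only loaded if its decoder layer falls inside the local [start_layer, end_layer) range, and names that still miss params_dict are logged instead of raising. Below is a condensed sketch of that filter; extract_layer_id and the direct copy_ call are assumed stand-ins for sglang's get_layer_id and weight_loader machinery.

# Condensed sketch of the per-stage weight filter added to load_weights.
# extract_layer_id is an illustrative stand-in for sglang.srt.layers.utils.get_layer_id.
import logging
import re
from typing import Iterable, Optional, Tuple

import torch

logger = logging.getLogger(__name__)


def extract_layer_id(name: str) -> Optional[int]:
    match = re.search(r"\.layers\.(\d+)\.", name)
    return int(match.group(1)) if match else None


def load_stage_weights(
    params_dict: dict,
    weights: Iterable[Tuple[str, torch.Tensor]],
    start_layer: int,
    end_layer: int,
) -> None:
    for name, loaded_weight in weights:
        layer_id = extract_layer_id(name)
        # Skip tensors that belong to decoder layers owned by another stage.
        if layer_id is not None and not (start_layer <= layer_id < end_layer):
            continue
        if name in params_dict:
            params_dict[name].data.copy_(loaded_weight)
        else:
            # Embedding / lm_head / norm tensors of other stages land here.
            logger.warning(f"Parameter {name} not found in params_dict")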
python/sglang/srt/models/qwen3_moe.py

@@ -17,6 +17,7 @@
 """Inference-only Qwen3MoE model compatible with HuggingFace weights."""
+import logging
 from dataclasses import dataclass
 from enum import Enum, auto
 from functools import partial

@@ -28,6 +29,7 @@ from torch import nn
 from transformers.configuration_utils import PretrainedConfig

 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     split_tensor_along_last_dim,

@@ -57,12 +59,13 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel

@@ -70,6 +73,8 @@ from sglang.srt.utils import add_prefix
 Qwen3MoeConfig = None

+logger = logging.getLogger(__name__)
+

 class Qwen3MoeSparseMoeBlock(nn.Module):
     def __init__(

@@ -516,6 +521,7 @@ class Qwen3MoeForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen3MoeModel(

@@ -536,12 +542,31 @@ class Qwen3MoeForCausalLM(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> LogitsProcessorOutput:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head, forward_batch
-        )
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+        if self.pp_group.is_last_rank:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)

@@ -563,6 +588,17 @@ class Qwen3MoeForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())

         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
             if "rotary_emb.inv_freq" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:

@@ -611,11 +647,14 @@ class Qwen3MoeForCausalLM(nn.Module):
                 if name not in params_dict:
                     continue
-                param = params_dict[name]
-                weight_loader = getattr(
-                    param, "weight_loader", default_weight_loader
-                )
-                weight_loader(param, loaded_weight)
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")

 EntryClass = Qwen3MoeForCausalLM
test/srt/test_pp_single_node.py

 """
 Usage:
 python3 -m unittest test_pp_single_node.TestPPAccuracy.test_gsm8k
+python3 -m unittest test_pp_single_node.TestQwenPPAccuracy.test_pp_consistency
 python3 -m unittest test_pp_single_node.TestFixedBugs.test_chunked_prefill_with_small_bs
 """

@@ -61,6 +62,60 @@ class TestPPAccuracy(unittest.TestCase):
         time.sleep(5)

+class TestQwenPPAccuracy(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.base_url = "http://127.0.0.1:23334"  # different ports to avoid conflicts
+        cls.model_name = "Qwen/Qwen3-8B"  # replace with your Qwen Model if needed
+
+    def run_gsm8k_test(self, pp_size):
+        process = popen_launch_server(
+            self.model_name,
+            self.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--pp-size",
+                pp_size,
+                "--chunked-prefill-size",
+                256,
+            ],
+        )
+        try:
+            args = SimpleNamespace(
+                num_shots=5,
+                data_path=None,
+                num_questions=200,
+                max_new_tokens=512,
+                parallel=128,
+                host="http://127.0.0.1",
+                port=int(self.base_url.split(":")[-1]),
+            )
+            metrics = run_eval(args)
+            time.sleep(5)
+            return metrics
+        finally:
+            kill_process_tree(process.pid)
+
+    def test_baseline_accuracy(self):
+        metrics = self.run_gsm8k_test(pp_size=1)
+        print(f"[Qwen Baseline] {metrics=}")
+        self.assertGreater(metrics["accuracy"], 0.74)
+
+    def test_pp_consistency(self):
+        baseline = self.run_gsm8k_test(pp_size=1)
+        pp_metrics = self.run_gsm8k_test(pp_size=2)
+        print(f"[Qwen PP Comparison] Baseline: {baseline} | PP: {pp_metrics}")
+        self.assertAlmostEqual(
+            pp_metrics["accuracy"],
+            baseline["accuracy"],
+            delta=0.01,
+            msg=f"PP accuracy exceeds 1% (baseline: {baseline['accuracy']}, pp: {pp_metrics['accuracy']})",
+        )
+
+
 class TestFixedBugs(unittest.TestCase):
     def test_chunked_prefill_with_small_bs(self):
         model = DEFAULT_MODEL_NAME_FOR_TEST
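Note: the new TestQwenPPAccuracy exercises the feature end to end by launching the same Qwen model with --pp-size 1 and --pp-size 2 and requiring GSM8K accuracy to agree within one point. The same launch/eval flow can be reused outside unittest; the snippet below is a sketch assembled from the helpers the test already calls, and the import paths, port, model name, and question count are assumptions rather than values taken from this commit.

# Sketch: launch a 2-stage pipeline-parallel server and score it on GSM8K.
# Import locations are assumed to match typical sglang test utilities.
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, popen_launch_server

base_url = "http://127.0.0.1:24000"  # placeholder port
process = popen_launch_server(
    "Qwen/Qwen3-8B",  # placeholder model; any supported Qwen2/Qwen3 checkpoint
    base_url,
    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    other_args=["--pp-size", 2, "--chunked-prefill-size", 256],
)
try:
    metrics = run_eval(
        SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=50,
            max_new_tokens=512,
            parallel=32,
            host="http://127.0.0.1",
            port=int(base_url.split(":")[-1]),
        )
    )
    print(metrics["accuracy"])
finally:
    kill_process_tree(process.pid)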