Commit 3f23d8cd (unverified), authored May 25, 2025 by Shenggui Li, committed by GitHub on May 25, 2025
added support for tied weights in qwen pipeline parallelism (#6546)
Parent: 1a399799

Showing 4 changed files with 134 additions and 20 deletions (+134 −20)
.github/workflows/pr-test.yml        +1  −1
python/sglang/srt/models/qwen2.py    +38 −9
python/sglang/srt/models/qwen3.py    +39 −10
test/srt/test_pp_single_node.py      +56 −0
.github/workflows/pr-test.yml
@@ -84,7 +84,7 @@ jobs:
       bash scripts/ci_install_dependency.sh
     - name: Run test
-      timeout-minutes: 25
+      timeout-minutes: 30
       run: |
         cd test/srt
         python3 run_suite.py --suite per-commit-2-gpu
python/sglang/srt/models/qwen2.py
@@ -386,15 +386,36 @@ class Qwen2ForCausalLM(nn.Module):
         self.model = Qwen2Model(
             config, quant_config=quant_config, prefix=add_prefix("model", prefix)
         )
-        if config.tie_word_embeddings:
-            self.lm_head = self.model.embed_tokens
-        else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix=add_prefix("lm_head", prefix),
-            )
+
+        # handle the lm head on different pp ranks
+        if self.pp_group.is_last_rank:
+            if self.pp_group.world_size == 1 and config.tie_word_embeddings:
+                self.lm_head = self.model.embed_tokens
+            else:
+                self.lm_head = ParallelLMHead(
+                    config.vocab_size,
+                    config.hidden_size,
+                    quant_config=quant_config,
+                    prefix=add_prefix("lm_head", prefix),
+                )
+        else:
+            # ranks other than the last rank will have a placeholder layer
+            self.lm_head = PPMissingLayer()
+
+        # perform weight tying for PP
+        if self.pp_group.world_size > 1 and config.tie_word_embeddings:
+            if self.pp_group.is_first_rank:
+                self.pp_group.send(
+                    self.model.embed_tokens.weight, dst=self.pp_group.last_rank
+                )
+            else:
+                emb_token_weight = self.pp_group.recv(
+                    size=(config.vocab_size, config.hidden_size),
+                    dtype=next(self.model.parameters()).dtype,
+                    src=self.pp_group.first_rank,
+                )
+                self.lm_head.weight.copy_(emb_token_weight)
+
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

@@ -470,7 +491,15 @@ class Qwen2ForCausalLM(nn.Module):
                 # the checkpoint. Skip them.
                 continue
             if self.config.tie_word_embeddings and "lm_head.weight" in name:
-                continue
+                if self.pp_group.world_size > 1 and self.pp_group.is_last_rank:
+                    # Handle pp weight tying here
+                    # find the embed_tokens.weight in the weights
+                    embed_token_weights = next(
+                        filter(lambda x: x[0] == "model.embed_tokens.weight", weights)
+                    )[1]
+                    loaded_weight = embed_token_weights
+                else:
+                    continue
             if name.startswith("model.vision_tower") and name not in params_dict:
                 continue
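The core of the change above is that, under pipeline parallelism, the first stage owns model.embed_tokens while only the last stage owns lm_head, so tied weights have to be shipped across ranks once at construction time. The sketch below reproduces that hand-off with plain torch.distributed point-to-point calls so it can be read in isolation; it is an illustration only, with made-up sizes and a tie_lm_head_across_pp helper that is not part of sglang, which uses its own pp_group.send/recv wrappers as shown in the diff.

# Minimal sketch (not sglang code) of tying embed_tokens.weight to lm_head.weight
# across a pipeline-parallel group. Assumes torch.distributed is already
# initialized (e.g. via torchrun --nproc-per-node=2) with a gloo backend for
# CPU tensors; NCCL would require the tensors to live on CUDA devices.
import torch
import torch.distributed as dist


def tie_lm_head_across_pp(vocab_size: int, hidden_size: int, dtype=torch.float32):
    rank, world = dist.get_rank(), dist.get_world_size()
    if rank == 0:
        # First pipeline stage owns the embedding table; push it to the last stage.
        embed_weight = torch.randn(vocab_size, hidden_size, dtype=dtype)
        dist.send(embed_weight, dst=world - 1)
        return embed_weight
    if rank == world - 1:
        # Last pipeline stage owns lm_head; receive the tied weights and keep them.
        lm_head_weight = torch.empty(vocab_size, hidden_size, dtype=dtype)
        dist.recv(lm_head_weight, src=0)
        return lm_head_weight
    # Middle stages hold neither tensor (analogous to PPMissingLayer in the diff).
    return None

The diff follows the same shape: only the last rank builds a real ParallelLMHead (or aliases embed_tokens when world_size == 1), every other rank gets a PPMissingLayer placeholder, and the transfer runs only when world_size > 1 and tie_word_embeddings is set.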
python/sglang/srt/models/qwen3.py
@@ -21,7 +21,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
-from sglang.srt.layers.utils import get_layer_id
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader

@@ -249,15 +249,36 @@ class Qwen3ForCausalLM(nn.Module):
         self.model = Qwen3Model(
             config, quant_config=quant_config, prefix=add_prefix("model", prefix)
         )
-        if config.tie_word_embeddings:
-            self.lm_head = self.model.embed_tokens
-        else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix=add_prefix("lm_head", prefix),
-            )
+
+        # handle the lm head on different pp ranks
+        if self.pp_group.is_last_rank:
+            if self.pp_group.world_size == 1 and config.tie_word_embeddings:
+                self.lm_head = self.model.embed_tokens
+            else:
+                self.lm_head = ParallelLMHead(
+                    config.vocab_size,
+                    config.hidden_size,
+                    quant_config=quant_config,
+                    prefix=add_prefix("lm_head", prefix),
+                )
+        else:
+            # ranks other than the last rank will have a placeholder layer
+            self.lm_head = PPMissingLayer()
+
+        # perform weight tying for PP
+        if self.pp_group.world_size > 1 and config.tie_word_embeddings:
+            if self.pp_group.is_first_rank:
+                self.pp_group.send(
+                    self.model.embed_tokens.weight, dst=self.pp_group.last_rank
+                )
+            else:
+                emb_token_weight = self.pp_group.recv(
+                    size=(config.vocab_size, config.hidden_size),
+                    dtype=next(self.model.parameters()).dtype,
+                    src=self.pp_group.first_rank,
+                )
+                self.lm_head.weight.copy_(emb_token_weight)
+
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

@@ -330,7 +351,15 @@ class Qwen3ForCausalLM(nn.Module):
                 # the checkpoint. Skip them.
                 continue
             if self.config.tie_word_embeddings and "lm_head.weight" in name:
-                continue
+                if self.pp_group.world_size > 1 and self.pp_group.is_last_rank:
+                    # Handle pp weight tying here
+                    # find the embed_tokens.weight in the weights
+                    embed_token_weights = next(
+                        filter(lambda x: x[0] == "model.embed_tokens.weight", weights)
+                    )[1]
+                    loaded_weight = embed_token_weights
+                else:
+                    continue
             if name.startswith("model.vision_tower") and name not in params_dict:
                 continue
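The load_weights hunks in qwen2.py and qwen3.py handle the checkpoint side of the same problem: when embeddings are tied and pp_size > 1, the last rank's lm_head is a real parameter that still needs values, so instead of skipping lm_head.weight it substitutes the embedding tensor. Below is a toy, self-contained illustration of that lookup pattern; the tensor names and shapes are made up and the surrounding loader logic is elided, so it is not the sglang loader itself.

# Toy illustration (hypothetical data) of the tied-weight lookup in load_weights.
import torch

# Hypothetical (name, tensor) pairs standing in for the checkpoint iterator.
weights = [
    ("model.embed_tokens.weight", torch.randn(8, 4)),
    ("model.layers.0.self_attn.qkv_proj.weight", torch.randn(12, 4)),
    ("lm_head.weight", torch.zeros(8, 4)),
]

tie_word_embeddings = True
pp_world_size = 2
is_last_rank = True

for name, loaded_weight in weights:
    if tie_word_embeddings and "lm_head.weight" in name:
        if pp_world_size > 1 and is_last_rank:
            # On the last PP rank, substitute the embedding tensor for lm_head.
            loaded_weight = next(
                w for n, w in weights if n == "model.embed_tokens.weight"
            )
        else:
            # Single-stage case: lm_head already aliases embed_tokens, so skip it.
            continue
    print(f"loading {name} with shape {tuple(loaded_weight.shape)}")

Note that this scans the weight list a second time; the diff's next(filter(...)) expression does the equivalent lookup by consuming entries from weights until it finds model.embed_tokens.weight.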
test/srt/test_pp_single_node.py
@@ -116,6 +116,62 @@ class TestQwenPPAccuracy(unittest.TestCase):
         )


+class TestQwenPPTieWeightsAccuracy(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.base_url = "http://127.0.0.1:23334"  # different ports to avoid conflicts
+        cls.model_name = (
+            "Qwen/Qwen3-0.6B"  # qwen3 < 8B all have tie_word_embeddings = True
+        )
+
+    def run_gsm8k_test(self, pp_size):
+        process = popen_launch_server(
+            self.model_name,
+            self.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--pp-size",
+                pp_size,
+                "--chunked-prefill-size",
+                256,
+            ],
+        )
+
+        try:
+            args = SimpleNamespace(
+                num_shots=5,
+                data_path=None,
+                num_questions=200,
+                max_new_tokens=512,
+                parallel=128,
+                host="http://127.0.0.1",
+                port=int(self.base_url.split(":")[-1]),
+            )
+            metrics = run_eval(args)
+            time.sleep(5)
+            return metrics
+        finally:
+            kill_process_tree(process.pid)
+
+    def test_baseline_accuracy(self):
+        metrics = self.run_gsm8k_test(pp_size=1)
+        print(f"[Qwen Baseline] {metrics=}")
+        self.assertGreater(metrics["accuracy"], 0.39)
+
+    def test_pp_consistency(self):
+        baseline = self.run_gsm8k_test(pp_size=1)
+        pp_metrics = self.run_gsm8k_test(pp_size=2)
+        print(f"[Qwen PP Comparison] Baseline: {baseline} | PP: {pp_metrics}")
+        self.assertAlmostEqual(
+            pp_metrics["accuracy"],
+            baseline["accuracy"],
+            delta=0.01,
+            msg=f"PP accuracy exceeds 1% (baseline: {baseline['accuracy']}, pp: {pp_metrics['accuracy']})",
+        )
+
+
 class TestFixedBugs(unittest.TestCase):
     def test_chunked_prefill_with_small_bs(self):
         model = DEFAULT_MODEL_NAME_FOR_TEST
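To run just the new test class locally (this invocation is not part of the commit; it assumes a host with at least two GPUs, the sglang test dependencies installed, and test/srt as the working directory), the standard unittest loader can target it by name:

# Hypothetical local invocation of the new PP tie-weights accuracy tests.
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName(
    "test_pp_single_node.TestQwenPPTieWeightsAccuracy"
)
unittest.TextTestRunner(verbosity=2).run(suite)

The CI change in pr-test.yml raises the per-commit-2-gpu suite timeout from 25 to 30 minutes, presumably to absorb the extra runtime of these PP accuracy tests.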