Unverified commit 3f23d8cd, authored by Shenggui Li and committed by GitHub

Added support for tied weights in Qwen pipeline parallelism (#6546)

parent 1a399799
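
Besides a small CI timeout bump, the model-side change boils down to a point-to-point handshake at construction time: when tie_word_embeddings is set and the model is split across pipeline-parallel ranks, the first PP rank (which owns embed_tokens) sends the embedding weight to the last PP rank (which owns lm_head), and the last rank copies it into its lm_head. The diff does this through sglang's pp_group.send / pp_group.recv helpers; the sketch below only illustrates the same pattern with raw torch.distributed point-to-point calls, so the function name, rank bookkeeping, and device handling are illustrative assumptions, not the actual implementation.

# Minimal sketch (not the sglang implementation): tie lm_head to embed_tokens
# across pipeline-parallel ranks with raw torch.distributed point-to-point ops.
# Assumes the process group is already initialized and that first_rank/last_rank
# are the global ranks of the first and last pipeline stages.
import torch
import torch.distributed as dist


def tie_lm_head_to_embeddings(model, lm_head, vocab_size, hidden_size,
                              first_rank, last_rank):
    rank = dist.get_rank()
    if rank == first_rank:
        # The first stage owns embed_tokens: ship its weight to the last stage.
        dist.send(model.embed_tokens.weight.data, dst=last_rank)
    elif rank == last_rank:
        # The last stage owns lm_head: receive the embedding weight and copy it in.
        buf = torch.empty(
            (vocab_size, hidden_size),
            dtype=next(model.parameters()).dtype,
            device=lm_head.weight.device,
        )
        dist.recv(buf, src=first_rank)
        with torch.no_grad():
            lm_head.weight.copy_(buf)
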
@@ -84,7 +84,7 @@ jobs:
           bash scripts/ci_install_dependency.sh
       - name: Run test
-        timeout-minutes: 25
+        timeout-minutes: 30
         run: |
           cd test/srt
           python3 run_suite.py --suite per-commit-2-gpu
...
@@ -386,7 +386,10 @@ class Qwen2ForCausalLM(nn.Module):
         self.model = Qwen2Model(
             config, quant_config=quant_config, prefix=add_prefix("model", prefix)
         )
-        if config.tie_word_embeddings:
+        # handle the lm head on different pp ranks
+        if self.pp_group.is_last_rank:
+            if self.pp_group.world_size == 1 and config.tie_word_embeddings:
                 self.lm_head = self.model.embed_tokens
             else:
                 self.lm_head = ParallelLMHead(
@@ -395,6 +398,24 @@ class Qwen2ForCausalLM(nn.Module):
                     quant_config=quant_config,
                     prefix=add_prefix("lm_head", prefix),
                 )
+        else:
+            # ranks other than the last rank will have a placeholder layer
+            self.lm_head = PPMissingLayer()
+
+        # perform weight tying for PP
+        if self.pp_group.world_size > 1 and config.tie_word_embeddings:
+            if self.pp_group.is_first_rank:
+                self.pp_group.send(
+                    self.model.embed_tokens.weight, dst=self.pp_group.last_rank
+                )
+            else:
+                emb_token_weight = self.pp_group.recv(
+                    size=(config.vocab_size, config.hidden_size),
+                    dtype=next(self.model.parameters()).dtype,
+                    src=self.pp_group.first_rank,
+                )
+                self.lm_head.weight.copy_(emb_token_weight)
+
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
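
In the else branch above, ranks that are not the last pipeline stage get a PPMissingLayer in place of lm_head (the qwen2.py hunk uses it without a new import, while qwen3.py adds the import further down). Conceptually it is a pass-through placeholder so attribute access and forward plumbing keep working on ranks that do not own the module. A rough sketch of such a placeholder, under the assumption that it behaves like an identity layer; sglang's real PPMissingLayer may differ in details:

import torch


class PlaceholderLayer(torch.nn.Module):
    # Hypothetical stand-in for a module that lives on another pipeline stage.
    # It returns its first input unchanged so non-owning ranks do not crash
    # when the attribute is touched; no parameters are allocated here.
    def forward(self, *args, **kwargs):
        return args[0] if args else next(iter(kwargs.values()))
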
@@ -470,6 +491,14 @@ class Qwen2ForCausalLM(nn.Module):
                 # the checkpoint. Skip them.
                 continue
             if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                if self.pp_group.world_size > 1 and self.pp_group.is_last_rank:
+                    # Handle pp weight tying here
+                    # find the embed_tokens.weight in the weights
+                    embed_token_weights = next(
+                        filter(lambda x: x[0] == "model.embed_tokens.weight", weights)
+                    )[1]
+                    loaded_weight = embed_token_weights
+                else:
                     continue
             if name.startswith("model.vision_tower") and name not in params_dict:
                 continue
...
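
During checkpoint loading (the hunk above), the last PP rank can no longer simply skip lm_head.weight when embeddings are tied, because its lm_head is a real ParallelLMHead rather than an alias of embed_tokens; it therefore looks up model.embed_tokens.weight among the checkpoint's (name, tensor) pairs and loads that tensor into lm_head. A standalone sketch of that lookup, with a hypothetical helper name:

from typing import Iterable, Tuple

import torch


def find_embedding_weight(
    weights: Iterable[Tuple[str, torch.Tensor]],
) -> torch.Tensor:
    # Hypothetical helper: locate the embedding weight in a stream of
    # (name, tensor) pairs so it can be loaded into lm_head on the last PP rank.
    for name, tensor in weights:
        if name == "model.embed_tokens.weight":
            return tensor
    raise KeyError("model.embed_tokens.weight not found in checkpoint weights")

The diff expresses the same lookup inline with next(filter(...)) over the weights iterable inside load_weights.
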
@@ -21,7 +21,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
-from sglang.srt.layers.utils import get_layer_id
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -249,7 +249,10 @@ class Qwen3ForCausalLM(nn.Module):
         self.model = Qwen3Model(
             config, quant_config=quant_config, prefix=add_prefix("model", prefix)
         )
-        if config.tie_word_embeddings:
+        # handle the lm head on different pp ranks
+        if self.pp_group.is_last_rank:
+            if self.pp_group.world_size == 1 and config.tie_word_embeddings:
                 self.lm_head = self.model.embed_tokens
             else:
                 self.lm_head = ParallelLMHead(
@@ -258,6 +261,24 @@ class Qwen3ForCausalLM(nn.Module):
                     quant_config=quant_config,
                     prefix=add_prefix("lm_head", prefix),
                 )
+        else:
+            # ranks other than the last rank will have a placeholder layer
+            self.lm_head = PPMissingLayer()
+
+        # perform weight tying for PP
+        if self.pp_group.world_size > 1 and config.tie_word_embeddings:
+            if self.pp_group.is_first_rank:
+                self.pp_group.send(
+                    self.model.embed_tokens.weight, dst=self.pp_group.last_rank
+                )
+            else:
+                emb_token_weight = self.pp_group.recv(
+                    size=(config.vocab_size, config.hidden_size),
+                    dtype=next(self.model.parameters()).dtype,
+                    src=self.pp_group.first_rank,
+                )
+                self.lm_head.weight.copy_(emb_token_weight)
+
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@@ -330,6 +351,14 @@ class Qwen3ForCausalLM(nn.Module):
                 # the checkpoint. Skip them.
                 continue
             if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                if self.pp_group.world_size > 1 and self.pp_group.is_last_rank:
+                    # Handle pp weight tying here
+                    # find the embed_tokens.weight in the weights
+                    embed_token_weights = next(
+                        filter(lambda x: x[0] == "model.embed_tokens.weight", weights)
+                    )[1]
+                    loaded_weight = embed_token_weights
+                else:
                     continue
             if name.startswith("model.vision_tower") and name not in params_dict:
                 continue
...
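
The new test below relies on Qwen/Qwen3-0.6B shipping with tie_word_embeddings = True in its config (as the inline comment notes for the sub-8B Qwen3 checkpoints). To double-check that for a given checkpoint before running the test, one option is to read the config through Hugging Face transformers, which sglang also uses to load model configs; the snippet assumes the transformers package is installed and the model id is reachable.

# Quick check of whether a checkpoint ties its input and output embeddings.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("Qwen/Qwen3-0.6B")
print(cfg.tie_word_embeddings)  # expected to be True for the small Qwen3 models
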
@@ -116,6 +116,62 @@ class TestQwenPPAccuracy(unittest.TestCase):
         )
+
+
+class TestQwenPPTieWeightsAccuracy(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.base_url = "http://127.0.0.1:23334"  # different ports to avoid conflicts
+        cls.model_name = (
+            "Qwen/Qwen3-0.6B"  # qwen3 < 8B all have tie_word_embeddings = True
+        )
+
+    def run_gsm8k_test(self, pp_size):
+        process = popen_launch_server(
+            self.model_name,
+            self.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--pp-size",
+                pp_size,
+                "--chunked-prefill-size",
+                256,
+            ],
+        )
+        try:
+            args = SimpleNamespace(
+                num_shots=5,
+                data_path=None,
+                num_questions=200,
+                max_new_tokens=512,
+                parallel=128,
+                host="http://127.0.0.1",
+                port=int(self.base_url.split(":")[-1]),
+            )
+            metrics = run_eval(args)
+            time.sleep(5)
+            return metrics
+        finally:
+            kill_process_tree(process.pid)
+
+    def test_baseline_accuracy(self):
+        metrics = self.run_gsm8k_test(pp_size=1)
+        print(f"[Qwen Baseline] {metrics=}")
+        self.assertGreater(metrics["accuracy"], 0.39)
+
+    def test_pp_consistency(self):
+        baseline = self.run_gsm8k_test(pp_size=1)
+        pp_metrics = self.run_gsm8k_test(pp_size=2)
+        print(f"[Qwen PP Comparison] Baseline: {baseline} | PP: {pp_metrics}")
+        self.assertAlmostEqual(
+            pp_metrics["accuracy"],
+            baseline["accuracy"],
+            delta=0.01,
+            msg=f"PP accuracy exceeds 1% (baseline: {baseline['accuracy']}, pp: {pp_metrics['accuracy']})",
+        )
+
+
 class TestFixedBugs(unittest.TestCase):
     def test_chunked_prefill_with_small_bs(self):
         model = DEFAULT_MODEL_NAME_FOR_TEST
...