Unverified Commit 21af5c04 authored by Cheng Wan, committed by GitHub

[Fix] Compatibility between DP attention and pipeline parallelism (#10100)
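With pipeline parallelism, only the last PP rank runs the LM head and produces logits; earlier stages hand hidden states to the next stage. The DP-attention post-forward sync must therefore be restricted to the last rank. A minimal sketch of the guard this change adds, using the names from the diff below (not sglang's full forward path):

```python
# Hedged sketch, not the actual ModelRunner code: the DP-attention sync only
# makes sense where logits exist, i.e. on the final pipeline-parallel stage.
from typing import Any


def maybe_post_forward_sync(forward_batch: Any, ret: Any, pp_group: Any) -> None:
    """Run the DP-attention post-forward sync only on the last PP rank."""
    if (
        forward_batch.global_num_tokens_cpu is not None  # DP attention in use
        and pp_group.is_last_rank                        # final PP stage only
    ):
        forward_batch.post_forward_mlp_sync_batch(ret)
```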


Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 012584ec
@@ -32,6 +32,7 @@ from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
 from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tp_group,
     get_world_group,
     init_distributed_environment,
@@ -639,6 +640,7 @@ class ModelRunner:
             cpu_group=get_world_group().cpu_group,
         )
         self.tp_group = get_tp_group()
+        self.pp_group = get_pp_group()
         self.attention_tp_group = get_attention_tp_group()
         # Check memory for tensor parallelism
@@ -1825,7 +1827,10 @@ class ModelRunner:
         else:
             raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
-        if forward_batch.global_num_tokens_cpu is not None:
+        if (
+            forward_batch.global_num_tokens_cpu is not None
+            and self.pp_group.is_last_rank
+        ):
             forward_batch.post_forward_mlp_sync_batch(ret)
         return ret, can_run_cuda_graph
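The newly stored `self.pp_group` handle mirrors the existing `self.tp_group`. As a rough illustration of how callers can branch on pipeline position, assuming only `get_pp_group()` and its `is_last_rank` flag as used in the diff above (the branch bodies are placeholders, not sglang internals):

```python
# Illustration only: branching on pipeline position via the process-group handle.
# get_pp_group() and is_last_rank come from the diff; the return values are placeholders.
from sglang.srt.distributed import get_pp_group


def describe_pp_role() -> str:
    pp_group = get_pp_group()
    if pp_group.is_last_rank:
        # Final stage: logits are produced here, so batch-level postprocessing
        # such as post_forward_mlp_sync_batch is valid.
        return "last stage: run logits postprocessing"
    # Earlier stages: the forward pass returns hidden states for the next stage.
    return "intermediate stage: forward hidden states downstream"
```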
@@ -14,11 +14,14 @@ import requests
 from sglang.bench_one_batch_server import BenchArgs as OneBatchBenchArgs
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import kill_process_tree
-from sglang.test.few_shot_gsm8k import run_eval
+from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
+from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
+    DEFAULT_MLA_MODEL_NAME_FOR_TEST,
     DEFAULT_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
     is_in_ci,
     popen_launch_server,
     run_bench_one_batch_server,
@@ -57,7 +60,7 @@ class TestPPAccuracy(unittest.TestCase):
             host="http://127.0.0.1",
             port=int(self.base_url.split(":")[-1]),
         )
-        metrics = run_eval(args)
+        metrics = run_eval_few_shot_gsm8k(args)
         print(f"{metrics=}")
         self.assertGreater(metrics["accuracy"], 0.74)
@@ -88,6 +91,45 @@ class TestPPAccuracy(unittest.TestCase):
         assert len(output_top_logprobs) == 16
+class TestDPAttentionDP2PP2(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--trust-remote-code",
+                "--tp",
+                "2",
+                "--pp-size",
+                "2",
+                "--enable-dp-attention",
+                "--dp",
+                "2",
+            ],
+        )
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+    def test_mgsm_en(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mgsm_en",
+            num_examples=None,
+            num_threads=1024,
+        )
+        metrics = run_eval(args)
+        print(f"{metrics=}")
+        self.assertGreater(metrics["score"], 0.8)
 class TestQwenPPAccuracy(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
@@ -117,7 +159,7 @@ class TestQwenPPAccuracy(unittest.TestCase):
                 host="http://127.0.0.1",
                 port=int(self.base_url.split(":")[-1]),
             )
-            metrics = run_eval(args)
+            metrics = run_eval_few_shot_gsm8k(args)
            time.sleep(5)
             return metrics
         finally:
@@ -172,7 +214,7 @@ class TestQwenPPTieWeightsAccuracy(unittest.TestCase):
                 host="http://127.0.0.1",
                 port=int(self.base_url.split(":")[-1]),
             )
-            metrics = run_eval(args)
+            metrics = run_eval_few_shot_gsm8k(args)
             time.sleep(5)
             return metrics
         finally:
@@ -224,7 +266,7 @@ class TestQwenMoePPAccuracy(unittest.TestCase):
                 host="http://127.0.0.1",
                 port=int(self.base_url.split(":")[-1]),
             )
-            metrics = run_eval(args)
+            metrics = run_eval_few_shot_gsm8k(args)
             time.sleep(5)
             return metrics
         finally:
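For context, the new TestDPAttentionDP2PP2 case exercises the combination end to end: it launches one server with --tp 2 --pp-size 2 --enable-dp-attention --dp 2 and scores the mgsm_en eval. A rough standalone version of that eval step, assuming a server is already running; the URL and model name are placeholders rather than values from this change:

```python
# Standalone sketch of the eval performed by the new test. Argument names follow
# the SimpleNamespace built in the test above; base_url and model are placeholders.
from types import SimpleNamespace

from sglang.test.run_eval import run_eval

args = SimpleNamespace(
    base_url="http://127.0.0.1:30000",  # assumed: server launched with
                                        # --tp 2 --pp-size 2 --enable-dp-attention --dp 2
    model="<model-path>",               # placeholder model identifier
    eval_name="mgsm_en",
    num_examples=None,
    num_threads=1024,
)
metrics = run_eval(args)
print(f"{metrics=}")
assert metrics["score"] > 0.8  # same threshold the test asserts
```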