Unverified commit 00974e4f, authored by Shangming Cai, committed by GitHub
Browse files

[CI] Refactor disaggregation tests (#10068)


Signed-off-by: Shangming Cai <csmthu@gmail.com>
parent 5f1eb204
import time
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
CustomTestCase,
popen_with_error_check,
)
class TestDisaggregationBase(CustomTestCase):
    """Shared scaffolding for prefill/decode disaggregation test classes.

    Subclasses are expected to set the following class attributes before
    calling :meth:`launch_lb`, because that method reads them directly:
    ``prefill_url``, ``decode_url``, ``base_host``, ``lb_port`` and
    ``lb_url``.  Subclasses store the processes they spawn in
    ``process_prefill`` / ``process_decode``; :meth:`launch_lb` stores the
    load balancer in ``process_lb``.  :meth:`tearDownClass` kills whichever
    of the three are set.
    """

    # Class-level defaults so tearDownClass never hits AttributeError when a
    # subclass overrides setUpClass without calling super(), or when set-up
    # aborts before every process has been launched.
    process_lb = None
    process_decode = None
    process_prefill = None

    # Upper bound, in seconds, for a single health probe.  Without it one
    # stalled connection would block forever and the overall deadline in
    # wait_server_ready would never be evaluated.
    health_check_request_timeout = 10

    @classmethod
    def setUpClass(cls):
        """Reset the three process handles to ``None`` on the test class."""
        cls.process_lb, cls.process_decode, cls.process_prefill = None, None, None

    @classmethod
    def launch_lb(cls):
        """Start the load balancer process and block until it is healthy.

        Builds the ``sglang_router.launch_router`` command line from
        ``cls.prefill_url``, ``cls.decode_url``, ``cls.base_host`` and
        ``cls.lb_port``, launches it with ``popen_with_error_check``, stores
        the handle in ``cls.process_lb`` and then polls
        ``cls.lb_url + "/health"``.

        Raises:
            RuntimeError: propagated from :meth:`wait_server_ready` when the
                health endpoint does not answer 200 in time.
        """
        lb_command = [
            "python3",
            "-m",
            "sglang_router.launch_router",
            "--pd-disaggregation",
            "--mini-lb",  # FIXME: remove this
            "--prefill",
            cls.prefill_url,
            "--decode",
            cls.decode_url,
            "--host",
            cls.base_host,
            "--port",
            # NOTE(review): must be a str; it is joined into the log line below.
            cls.lb_port,
        ]

        print("Starting load balancer:", " ".join(lb_command))
        cls.process_lb = popen_with_error_check(lb_command)
        cls.wait_server_ready(cls.lb_url + "/health")

    @classmethod
    def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
        """Poll ``url`` once per second until it returns HTTP 200.

        Args:
            url: Full URL to probe with a GET request.
            timeout: Overall budget in seconds, measured from the first probe.

        Raises:
            RuntimeError: if no probe returned 200 within ``timeout`` seconds.
                The last request error, if any, is attached as ``__cause__``.
        """
        start_time = time.perf_counter()
        last_error = None
        while True:
            try:
                # Bound each probe so a hung connection cannot defeat the
                # overall deadline checked below.
                response = requests.get(
                    url, timeout=cls.health_check_request_timeout
                )
                if response.status_code == 200:
                    print(f"Server {url} is ready")
                    return
            except requests.RequestException as e:
                # Connection refused / timed out while the server is still
                # starting is expected; remember it for the final error.
                last_error = e

            if time.perf_counter() - start_time > timeout:
                raise RuntimeError(
                    f"Server {url} failed to start in {timeout}s"
                ) from last_error
            time.sleep(1)

    @classmethod
    def tearDownClass(cls):
        """Kill every launched process tree, then pause for five seconds.

        Failures while killing one process are printed and do not prevent
        the remaining processes from being killed.
        """
        for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
            if process:
                try:
                    kill_process_tree(process.pid)
                except Exception as e:
                    # Best-effort cleanup: report and keep going.
                    print(f"Error killing process {process.pid}: {e}")

        # wait for 5 seconds
        # NOTE(review): presumably gives the killed servers time to release
        # their resources before the next test class starts; confirm.
        time.sleep(5)
...@@ -139,6 +139,7 @@ suites = { ...@@ -139,6 +139,7 @@ suites = {
TestFile("lora/test_lora_llama4.py", 600), TestFile("lora/test_lora_llama4.py", 600),
TestFile("test_disaggregation.py", 499), TestFile("test_disaggregation.py", 499),
TestFile("test_disaggregation_different_tp.py", 155), TestFile("test_disaggregation_different_tp.py", 155),
TestFile("test_disaggregation_pp.py", 60),
TestFile("test_full_deepseek_v3.py", 333), TestFile("test_full_deepseek_v3.py", 333),
], ],
"per-commit-8-gpu-b200": [ "per-commit-8-gpu-b200": [
......
...@@ -7,21 +7,19 @@ from urllib.parse import urlparse ...@@ -7,21 +7,19 @@ from urllib.parse import urlparse
import requests import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_disaggregation_utils import TestDisaggregationBase
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST, DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_pd_server, popen_launch_pd_server,
popen_with_error_check,
) )
class TestDisaggregationAccuracy(CustomTestCase): class TestDisaggregationAccuracy(TestDisaggregationBase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST cls.model = DEFAULT_MODEL_NAME_FOR_TEST
...@@ -44,25 +42,7 @@ class TestDisaggregationAccuracy(CustomTestCase): ...@@ -44,25 +42,7 @@ class TestDisaggregationAccuracy(CustomTestCase):
cls.wait_server_ready(cls.prefill_url + "/health") cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health") cls.wait_server_ready(cls.decode_url + "/health")
lb_command = [ cls.launch_lb()
"python3",
"-m",
"sglang_router.launch_router",
"--pd-disaggregation",
"--mini-lb", # FIXME: remove this
"--prefill",
cls.prefill_url,
"--decode",
cls.decode_url,
"--host",
cls.base_host,
"--port",
cls.lb_port,
]
print("Starting load balancer:", " ".join(lb_command))
cls.process_lb = popen_with_error_check(lb_command)
cls.wait_server_ready(cls.lb_url + "/health")
@classmethod @classmethod
def start_prefill(cls): def start_prefill(cls):
...@@ -102,34 +82,6 @@ class TestDisaggregationAccuracy(CustomTestCase): ...@@ -102,34 +82,6 @@ class TestDisaggregationAccuracy(CustomTestCase):
other_args=decode_args, other_args=decode_args,
) )
@classmethod
def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
start_time = time.perf_counter()
while True:
try:
response = requests.get(url)
if response.status_code == 200:
print(f"Server {url} is ready")
return
except Exception:
pass
if time.perf_counter() - start_time > timeout:
raise RuntimeError(f"Server {url} failed to start in {timeout}s")
time.sleep(1)
@classmethod
def tearDownClass(cls):
for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
if process:
try:
kill_process_tree(process.pid)
except Exception as e:
print(f"Error killing process {process.pid}: {e}")
# wait for 5 seconds
time.sleep(5)
def test_gsm8k(self): def test_gsm8k(self):
args = SimpleNamespace( args = SimpleNamespace(
num_shots=5, num_shots=5,
...@@ -199,7 +151,7 @@ class TestDisaggregationAccuracy(CustomTestCase): ...@@ -199,7 +151,7 @@ class TestDisaggregationAccuracy(CustomTestCase):
json.loads(output) json.loads(output)
class TestDisaggregationMooncakeFailure(CustomTestCase): class TestDisaggregationMooncakeFailure(TestDisaggregationBase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
# set DISAGGREGATION_TEST_FAILURE_PROB to simulate failure # set DISAGGREGATION_TEST_FAILURE_PROB to simulate failure
...@@ -225,25 +177,12 @@ class TestDisaggregationMooncakeFailure(CustomTestCase): ...@@ -225,25 +177,12 @@ class TestDisaggregationMooncakeFailure(CustomTestCase):
cls.wait_server_ready(cls.prefill_url + "/health") cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health") cls.wait_server_ready(cls.decode_url + "/health")
lb_command = [ cls.launch_lb()
"python3",
"-m",
"sglang_router.launch_router",
"--pd-disaggregation",
"--mini-lb", # FIXME: remove this
"--prefill",
cls.prefill_url,
"--decode",
cls.decode_url,
"--host",
cls.base_host,
"--port",
cls.lb_port,
]
print("Starting load balancer:", " ".join(lb_command)) @classmethod
cls.process_lb = popen_with_error_check(lb_command) def tearDownClass(cls):
cls.wait_server_ready(cls.lb_url + "/health") os.environ.pop("DISAGGREGATION_TEST_FAILURE_PROB")
super().tearDownClass()
@classmethod @classmethod
def start_prefill(cls): def start_prefill(cls):
...@@ -283,36 +222,6 @@ class TestDisaggregationMooncakeFailure(CustomTestCase): ...@@ -283,36 +222,6 @@ class TestDisaggregationMooncakeFailure(CustomTestCase):
other_args=decode_args, other_args=decode_args,
) )
@classmethod
def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
start_time = time.perf_counter()
while True:
try:
response = requests.get(url)
if response.status_code == 200:
print(f"Server {url} is ready")
return
except Exception:
pass
if time.perf_counter() - start_time > timeout:
raise RuntimeError(f"Server {url} failed to start in {timeout}s")
time.sleep(1)
@classmethod
def tearDownClass(cls):
# unset DISAGGREGATION_TEST_FAILURE_PROB
os.environ.pop("DISAGGREGATION_TEST_FAILURE_PROB")
for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
if process:
try:
kill_process_tree(process.pid)
except Exception as e:
print(f"Error killing process {process.pid}: {e}")
# wait for 5 seconds
time.sleep(5)
def test_gsm8k(self): def test_gsm8k(self):
args = SimpleNamespace( args = SimpleNamespace(
num_shots=5, num_shots=5,
...@@ -341,7 +250,7 @@ class TestDisaggregationMooncakeFailure(CustomTestCase): ...@@ -341,7 +250,7 @@ class TestDisaggregationMooncakeFailure(CustomTestCase):
raise e from health_check_error raise e from health_check_error
class TestDisaggregationMooncakeSpec(CustomTestCase): class TestDisaggregationMooncakeSpec(TestDisaggregationBase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
...@@ -380,41 +289,7 @@ class TestDisaggregationMooncakeSpec(CustomTestCase): ...@@ -380,41 +289,7 @@ class TestDisaggregationMooncakeSpec(CustomTestCase):
cls.wait_server_ready(cls.prefill_url + "/health") cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health") cls.wait_server_ready(cls.decode_url + "/health")
lb_command = [ cls.launch_lb()
"python3",
"-m",
"sglang_router.launch_router",
"--pd-disaggregation",
"--mini-lb", # FIXME: remove this
"--prefill",
cls.prefill_url,
"--decode",
cls.decode_url,
"--host",
cls.base_host,
"--port",
cls.lb_port,
]
print("Starting load balancer:", " ".join(lb_command))
cls.process_lb = popen_with_error_check(lb_command)
cls.wait_server_ready(cls.lb_url + "/health")
@classmethod
def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
start_time = time.perf_counter()
while True:
try:
response = requests.get(url)
if response.status_code == 200:
print(f"Server {url} is ready")
return
except Exception:
pass
if time.perf_counter() - start_time > timeout:
raise RuntimeError(f"Server {url} failed to start in {timeout}s")
time.sleep(1)
@classmethod @classmethod
def start_prefill(cls): def start_prefill(cls):
...@@ -454,18 +329,6 @@ class TestDisaggregationMooncakeSpec(CustomTestCase): ...@@ -454,18 +329,6 @@ class TestDisaggregationMooncakeSpec(CustomTestCase):
other_args=decode_args, other_args=decode_args,
) )
@classmethod
def tearDownClass(cls):
for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
if process:
try:
kill_process_tree(process.pid)
except Exception as e:
print(f"Error killing process {process.pid}: {e}")
# wait for 5 seconds
time.sleep(5)
def test_gsm8k(self): def test_gsm8k(self):
args = SimpleNamespace( args = SimpleNamespace(
num_shots=5, num_shots=5,
...@@ -482,7 +345,7 @@ class TestDisaggregationMooncakeSpec(CustomTestCase): ...@@ -482,7 +345,7 @@ class TestDisaggregationMooncakeSpec(CustomTestCase):
self.assertGreater(metrics["accuracy"], 0.20) self.assertGreater(metrics["accuracy"], 0.20)
class TestDisaggregationSimulatedRetract(CustomTestCase): class TestDisaggregationSimulatedRetract(TestDisaggregationBase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
os.environ["SGLANG_TEST_RETRACT"] = "true" os.environ["SGLANG_TEST_RETRACT"] = "true"
...@@ -506,25 +369,12 @@ class TestDisaggregationSimulatedRetract(CustomTestCase): ...@@ -506,25 +369,12 @@ class TestDisaggregationSimulatedRetract(CustomTestCase):
cls.wait_server_ready(cls.prefill_url + "/health") cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health") cls.wait_server_ready(cls.decode_url + "/health")
lb_command = [ cls.launch_lb()
"python3",
"-m",
"sglang_router.launch_router",
"--pd-disaggregation",
"--mini-lb", # FIXME: remove this
"--prefill",
cls.prefill_url,
"--decode",
cls.decode_url,
"--host",
cls.base_host,
"--port",
cls.lb_port,
]
print("Starting load balancer:", " ".join(lb_command)) @classmethod
cls.process_lb = popen_with_error_check(lb_command) def tearDownClass(cls):
cls.wait_server_ready(cls.lb_url + "/health") os.environ.pop("SGLANG_TEST_RETRACT")
super().tearDownClass()
@classmethod @classmethod
def start_prefill(cls): def start_prefill(cls):
...@@ -564,35 +414,6 @@ class TestDisaggregationSimulatedRetract(CustomTestCase): ...@@ -564,35 +414,6 @@ class TestDisaggregationSimulatedRetract(CustomTestCase):
other_args=decode_args, other_args=decode_args,
) )
@classmethod
def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
start_time = time.perf_counter()
while True:
try:
response = requests.get(url)
if response.status_code == 200:
print(f"Server {url} is ready")
return
except Exception:
pass
if time.perf_counter() - start_time > timeout:
raise RuntimeError(f"Server {url} failed to start in {timeout}s")
time.sleep(1)
@classmethod
def tearDownClass(cls):
os.environ.pop("SGLANG_TEST_RETRACT")
for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
if process:
try:
kill_process_tree(process.pid)
except Exception as e:
print(f"Error killing process {process.pid}: {e}")
# wait for 5 seconds
time.sleep(5)
def test_gsm8k(self): def test_gsm8k(self):
args = SimpleNamespace( args = SimpleNamespace(
num_shots=5, num_shots=5,
......
import os import os
import subprocess
import time import time
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
from urllib.parse import urlparse from urllib.parse import urlparse
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_disaggregation_utils import TestDisaggregationBase
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST_MLA, DEFAULT_MODEL_NAME_FOR_TEST_MLA,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST, DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_pd_server, popen_launch_pd_server,
popen_with_error_check,
) )
class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase): class TestDisaggregationMooncakePrefillLargerTP(TestDisaggregationBase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
# Temporarily disable JIT DeepGEMM # Temporarily disable JIT DeepGEMM
...@@ -46,25 +41,7 @@ class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase): ...@@ -46,25 +41,7 @@ class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase):
cls.wait_server_ready(cls.prefill_url + "/health") cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health") cls.wait_server_ready(cls.decode_url + "/health")
lb_command = [ cls.launch_lb()
"python3",
"-m",
"sglang_router.launch_router",
"--pd-disaggregation",
"--mini-lb", # FIXME: remove this
"--prefill",
cls.prefill_url,
"--decode",
cls.decode_url,
"--host",
cls.base_host,
"--port",
cls.lb_port,
]
print("Starting load balancer:", " ".join(lb_command))
cls.process_lb = popen_with_error_check(lb_command)
cls.wait_server_ready(cls.lb_url + "/health")
@classmethod @classmethod
def start_prefill(cls): def start_prefill(cls):
...@@ -104,39 +81,6 @@ class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase): ...@@ -104,39 +81,6 @@ class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase):
other_args=decode_args, other_args=decode_args,
) )
@classmethod
def wait_server_ready(cls, url, timeout=60):
start_time = time.perf_counter()
while True:
try:
response = requests.get(url)
if response.status_code == 200:
print(f"Server {url} is ready")
return
except Exception:
pass
if time.perf_counter() - start_time > timeout:
raise RuntimeError(f"Server {url} failed to start in {timeout}s")
time.sleep(1)
@classmethod
def tearDownClass(cls):
# Restore JIT DeepGEMM environment variable
if cls.original_jit_deepgemm is not None:
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = cls.original_jit_deepgemm
else:
os.environ.pop("SGL_ENABLE_JIT_DEEPGEMM", None)
for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
if process:
try:
kill_process_tree(process.pid)
except Exception as e:
print(f"Error killing process {process.pid}: {e}")
# wait for 5 seconds
time.sleep(5)
def test_gsm8k(self): def test_gsm8k(self):
args = SimpleNamespace( args = SimpleNamespace(
num_shots=5, num_shots=5,
...@@ -153,7 +97,7 @@ class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase): ...@@ -153,7 +97,7 @@ class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase):
self.assertGreater(metrics["accuracy"], 0.60) self.assertGreater(metrics["accuracy"], 0.60)
class TestDisaggregationMooncakeDecodeLargerTP(CustomTestCase): class TestDisaggregationMooncakeDecodeLargerTP(TestDisaggregationBase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
# Temporarily disable JIT DeepGEMM # Temporarily disable JIT DeepGEMM
...@@ -180,25 +124,7 @@ class TestDisaggregationMooncakeDecodeLargerTP(CustomTestCase): ...@@ -180,25 +124,7 @@ class TestDisaggregationMooncakeDecodeLargerTP(CustomTestCase):
cls.wait_server_ready(cls.prefill_url + "/health") cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health") cls.wait_server_ready(cls.decode_url + "/health")
lb_command = [ cls.launch_lb()
"python3",
"-m",
"sglang_router.launch_router",
"--pd-disaggregation",
"--mini-lb", # FIXME: remove this
"--prefill",
cls.prefill_url,
"--decode",
cls.decode_url,
"--host",
cls.base_host,
"--port",
cls.lb_port,
]
print("Starting load balancer:", " ".join(lb_command))
cls.process_lb = popen_with_error_check(lb_command)
cls.wait_server_ready(cls.lb_url + "/health")
@classmethod @classmethod
def start_prefill(cls): def start_prefill(cls):
...@@ -238,39 +164,6 @@ class TestDisaggregationMooncakeDecodeLargerTP(CustomTestCase): ...@@ -238,39 +164,6 @@ class TestDisaggregationMooncakeDecodeLargerTP(CustomTestCase):
other_args=decode_args, other_args=decode_args,
) )
@classmethod
def wait_server_ready(cls, url, timeout=60):
start_time = time.perf_counter()
while True:
try:
response = requests.get(url)
if response.status_code == 200:
print(f"Server {url} is ready")
return
except Exception:
pass
if time.perf_counter() - start_time > timeout:
raise RuntimeError(f"Server {url} failed to start in {timeout}s")
time.sleep(1)
@classmethod
def tearDownClass(cls):
# Restore JIT DeepGEMM environment variable
if cls.original_jit_deepgemm is not None:
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = cls.original_jit_deepgemm
else:
os.environ.pop("SGL_ENABLE_JIT_DEEPGEMM", None)
for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
if process:
try:
kill_process_tree(process.pid)
except Exception as e:
print(f"Error killing process {process.pid}: {e}")
# wait for 5 seconds
time.sleep(5)
def test_gsm8k(self): def test_gsm8k(self):
args = SimpleNamespace( args = SimpleNamespace(
num_shots=5, num_shots=5,
......
import json
import os
import random
import time import time
import unittest import unittest
from concurrent.futures import ThreadPoolExecutor
from types import SimpleNamespace from types import SimpleNamespace
from typing import List, Optional from urllib.parse import urlparse
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.runners import DEFAULT_PROMPTS from sglang.test.test_disaggregation_utils import TestDisaggregationBase
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST, DEFAULT_URL_FOR_TEST,
CustomTestCase, popen_launch_pd_server,
popen_launch_server,
) )
class TestPDPPAccuracy(unittest.TestCase): class TestDisaggregationPPAccuracy(TestDisaggregationBase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST cls.model = DEFAULT_MODEL_NAME_FOR_TEST
...@@ -46,27 +36,7 @@ class TestPDPPAccuracy(unittest.TestCase): ...@@ -46,27 +36,7 @@ class TestPDPPAccuracy(unittest.TestCase):
cls.wait_server_ready(cls.prefill_url + "/health") cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health") cls.wait_server_ready(cls.decode_url + "/health")
lb_command = [ cls.launch_lb()
"python3",
"-m",
"sglang_router.launch_router",
"--pd-disaggregation",
"--mini-lb", # FIXME: remove this
"--prefill",
cls.prefill_url,
"--decode",
cls.decode_url,
"--host",
cls.base_host,
"--port",
cls.lb_port,
]
print("Starting load balancer:", " ".join(lb_command))
cls.process_lb = subprocess.Popen(
lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
cls.wait_server_ready(cls.lb_url + "/health")
@classmethod @classmethod
def start_prefill(cls): def start_prefill(cls):
...@@ -75,11 +45,11 @@ class TestPDPPAccuracy(unittest.TestCase): ...@@ -75,11 +45,11 @@ class TestPDPPAccuracy(unittest.TestCase):
"--disaggregation-mode", "--disaggregation-mode",
"prefill", "prefill",
"--tp-size", "--tp-size",
"2", "1",
"--pp-size", "--pp-size",
"2", "2",
"--disaggregation-ib-device", "--disaggregation-ib-device",
"mlx5_roce0", "mlx5_roce0,mlx5_roce1",
"--disable-overlap-schedule", "--disable-overlap-schedule",
] ]
cls.process_prefill = popen_launch_pd_server( cls.process_prefill = popen_launch_pd_server(
...@@ -98,9 +68,9 @@ class TestPDPPAccuracy(unittest.TestCase): ...@@ -98,9 +68,9 @@ class TestPDPPAccuracy(unittest.TestCase):
"--tp", "--tp",
"1", "1",
"--base-gpu-id", "--base-gpu-id",
"1", "2",
"--disaggregation-ib-device", "--disaggregation-ib-device",
"mlx5_roce1", "mlx5_roce2",
] ]
cls.process_decode = popen_launch_pd_server( cls.process_decode = popen_launch_pd_server(
cls.model, cls.model,
...@@ -109,10 +79,6 @@ class TestPDPPAccuracy(unittest.TestCase): ...@@ -109,10 +79,6 @@ class TestPDPPAccuracy(unittest.TestCase):
other_args=decode_args, other_args=decode_args,
) )
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self): def test_gsm8k(self):
args = SimpleNamespace( args = SimpleNamespace(
num_shots=5, num_shots=5,
...@@ -120,8 +86,8 @@ class TestPDPPAccuracy(unittest.TestCase): ...@@ -120,8 +86,8 @@ class TestPDPPAccuracy(unittest.TestCase):
num_questions=200, num_questions=200,
max_new_tokens=512, max_new_tokens=512,
parallel=128, parallel=128,
host="http://127.0.0.1", host=f"http://{self.base_host}",
port=int(self.base_url.split(":")[-1]), port=int(self.lb_port),
) )
metrics = run_eval(args) metrics = run_eval(args)
print(f"{metrics=}") print(f"{metrics=}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment