Unverified Commit 00974e4f authored by Shangming Cai's avatar Shangming Cai Committed by GitHub
Browse files

[CI] Refactor disaggregation tests (#10068)


Signed-off-by: default avatarShangming Cai <csmthu@gmail.com>
parent 5f1eb204
import time
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
CustomTestCase,
popen_with_error_check,
)
class TestDisaggregationBase(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.process_lb, cls.process_decode, cls.process_prefill = None, None, None
pass
@classmethod
def launch_lb(cls):
lb_command = [
"python3",
"-m",
"sglang_router.launch_router",
"--pd-disaggregation",
"--mini-lb", # FIXME: remove this
"--prefill",
cls.prefill_url,
"--decode",
cls.decode_url,
"--host",
cls.base_host,
"--port",
cls.lb_port,
]
print("Starting load balancer:", " ".join(lb_command))
cls.process_lb = popen_with_error_check(lb_command)
cls.wait_server_ready(cls.lb_url + "/health")
@classmethod
def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
start_time = time.perf_counter()
while True:
try:
response = requests.get(url)
if response.status_code == 200:
print(f"Server {url} is ready")
return
except Exception:
pass
if time.perf_counter() - start_time > timeout:
raise RuntimeError(f"Server {url} failed to start in {timeout}s")
time.sleep(1)
@classmethod
def tearDownClass(cls):
for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
if process:
try:
kill_process_tree(process.pid)
except Exception as e:
print(f"Error killing process {process.pid}: {e}")
# wait for 5 seconds
time.sleep(5)
......@@ -139,6 +139,7 @@ suites = {
TestFile("lora/test_lora_llama4.py", 600),
TestFile("test_disaggregation.py", 499),
TestFile("test_disaggregation_different_tp.py", 155),
TestFile("test_disaggregation_pp.py", 60),
TestFile("test_full_deepseek_v3.py", 333),
],
"per-commit-8-gpu-b200": [
......
......@@ -7,21 +7,19 @@ from urllib.parse import urlparse
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_disaggregation_utils import TestDisaggregationBase
from sglang.test.test_utils import (
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_pd_server,
popen_with_error_check,
)
class TestDisaggregationAccuracy(CustomTestCase):
class TestDisaggregationAccuracy(TestDisaggregationBase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
......@@ -44,25 +42,7 @@ class TestDisaggregationAccuracy(CustomTestCase):
cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health")
lb_command = [
"python3",
"-m",
"sglang_router.launch_router",
"--pd-disaggregation",
"--mini-lb", # FIXME: remove this
"--prefill",
cls.prefill_url,
"--decode",
cls.decode_url,
"--host",
cls.base_host,
"--port",
cls.lb_port,
]
print("Starting load balancer:", " ".join(lb_command))
cls.process_lb = popen_with_error_check(lb_command)
cls.wait_server_ready(cls.lb_url + "/health")
cls.launch_lb()
@classmethod
def start_prefill(cls):
......@@ -102,34 +82,6 @@ class TestDisaggregationAccuracy(CustomTestCase):
other_args=decode_args,
)
@classmethod
def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
start_time = time.perf_counter()
while True:
try:
response = requests.get(url)
if response.status_code == 200:
print(f"Server {url} is ready")
return
except Exception:
pass
if time.perf_counter() - start_time > timeout:
raise RuntimeError(f"Server {url} failed to start in {timeout}s")
time.sleep(1)
@classmethod
def tearDownClass(cls):
for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
if process:
try:
kill_process_tree(process.pid)
except Exception as e:
print(f"Error killing process {process.pid}: {e}")
# wait for 5 seconds
time.sleep(5)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
......@@ -199,7 +151,7 @@ class TestDisaggregationAccuracy(CustomTestCase):
json.loads(output)
class TestDisaggregationMooncakeFailure(CustomTestCase):
class TestDisaggregationMooncakeFailure(TestDisaggregationBase):
@classmethod
def setUpClass(cls):
# set DISAGGREGATION_TEST_FAILURE_PROB to simulate failure
......@@ -225,25 +177,12 @@ class TestDisaggregationMooncakeFailure(CustomTestCase):
cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health")
lb_command = [
"python3",
"-m",
"sglang_router.launch_router",
"--pd-disaggregation",
"--mini-lb", # FIXME: remove this
"--prefill",
cls.prefill_url,
"--decode",
cls.decode_url,
"--host",
cls.base_host,
"--port",
cls.lb_port,
]
cls.launch_lb()
print("Starting load balancer:", " ".join(lb_command))
cls.process_lb = popen_with_error_check(lb_command)
cls.wait_server_ready(cls.lb_url + "/health")
@classmethod
def tearDownClass(cls):
os.environ.pop("DISAGGREGATION_TEST_FAILURE_PROB")
super().tearDownClass()
@classmethod
def start_prefill(cls):
......@@ -283,36 +222,6 @@ class TestDisaggregationMooncakeFailure(CustomTestCase):
other_args=decode_args,
)
@classmethod
def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
start_time = time.perf_counter()
while True:
try:
response = requests.get(url)
if response.status_code == 200:
print(f"Server {url} is ready")
return
except Exception:
pass
if time.perf_counter() - start_time > timeout:
raise RuntimeError(f"Server {url} failed to start in {timeout}s")
time.sleep(1)
@classmethod
def tearDownClass(cls):
# unset DISAGGREGATION_TEST_FAILURE_PROB
os.environ.pop("DISAGGREGATION_TEST_FAILURE_PROB")
for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
if process:
try:
kill_process_tree(process.pid)
except Exception as e:
print(f"Error killing process {process.pid}: {e}")
# wait for 5 seconds
time.sleep(5)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
......@@ -341,7 +250,7 @@ class TestDisaggregationMooncakeFailure(CustomTestCase):
raise e from health_check_error
class TestDisaggregationMooncakeSpec(CustomTestCase):
class TestDisaggregationMooncakeSpec(TestDisaggregationBase):
@classmethod
def setUpClass(cls):
......@@ -380,41 +289,7 @@ class TestDisaggregationMooncakeSpec(CustomTestCase):
cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health")
lb_command = [
"python3",
"-m",
"sglang_router.launch_router",
"--pd-disaggregation",
"--mini-lb", # FIXME: remove this
"--prefill",
cls.prefill_url,
"--decode",
cls.decode_url,
"--host",
cls.base_host,
"--port",
cls.lb_port,
]
print("Starting load balancer:", " ".join(lb_command))
cls.process_lb = popen_with_error_check(lb_command)
cls.wait_server_ready(cls.lb_url + "/health")
@classmethod
def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
start_time = time.perf_counter()
while True:
try:
response = requests.get(url)
if response.status_code == 200:
print(f"Server {url} is ready")
return
except Exception:
pass
if time.perf_counter() - start_time > timeout:
raise RuntimeError(f"Server {url} failed to start in {timeout}s")
time.sleep(1)
cls.launch_lb()
@classmethod
def start_prefill(cls):
......@@ -454,18 +329,6 @@ class TestDisaggregationMooncakeSpec(CustomTestCase):
other_args=decode_args,
)
@classmethod
def tearDownClass(cls):
for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
if process:
try:
kill_process_tree(process.pid)
except Exception as e:
print(f"Error killing process {process.pid}: {e}")
# wait for 5 seconds
time.sleep(5)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
......@@ -482,7 +345,7 @@ class TestDisaggregationMooncakeSpec(CustomTestCase):
self.assertGreater(metrics["accuracy"], 0.20)
class TestDisaggregationSimulatedRetract(CustomTestCase):
class TestDisaggregationSimulatedRetract(TestDisaggregationBase):
@classmethod
def setUpClass(cls):
os.environ["SGLANG_TEST_RETRACT"] = "true"
......@@ -506,25 +369,12 @@ class TestDisaggregationSimulatedRetract(CustomTestCase):
cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health")
lb_command = [
"python3",
"-m",
"sglang_router.launch_router",
"--pd-disaggregation",
"--mini-lb", # FIXME: remove this
"--prefill",
cls.prefill_url,
"--decode",
cls.decode_url,
"--host",
cls.base_host,
"--port",
cls.lb_port,
]
cls.launch_lb()
print("Starting load balancer:", " ".join(lb_command))
cls.process_lb = popen_with_error_check(lb_command)
cls.wait_server_ready(cls.lb_url + "/health")
@classmethod
def tearDownClass(cls):
os.environ.pop("SGLANG_TEST_RETRACT")
super().tearDownClass()
@classmethod
def start_prefill(cls):
......@@ -564,35 +414,6 @@ class TestDisaggregationSimulatedRetract(CustomTestCase):
other_args=decode_args,
)
@classmethod
def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
start_time = time.perf_counter()
while True:
try:
response = requests.get(url)
if response.status_code == 200:
print(f"Server {url} is ready")
return
except Exception:
pass
if time.perf_counter() - start_time > timeout:
raise RuntimeError(f"Server {url} failed to start in {timeout}s")
time.sleep(1)
@classmethod
def tearDownClass(cls):
os.environ.pop("SGLANG_TEST_RETRACT")
for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
if process:
try:
kill_process_tree(process.pid)
except Exception as e:
print(f"Error killing process {process.pid}: {e}")
# wait for 5 seconds
time.sleep(5)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
......
import os
import subprocess
import time
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_disaggregation_utils import TestDisaggregationBase
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST_MLA,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_pd_server,
popen_with_error_check,
)
class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase):
class TestDisaggregationMooncakePrefillLargerTP(TestDisaggregationBase):
@classmethod
def setUpClass(cls):
# Temporarily disable JIT DeepGEMM
......@@ -46,25 +41,7 @@ class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase):
cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health")
lb_command = [
"python3",
"-m",
"sglang_router.launch_router",
"--pd-disaggregation",
"--mini-lb", # FIXME: remove this
"--prefill",
cls.prefill_url,
"--decode",
cls.decode_url,
"--host",
cls.base_host,
"--port",
cls.lb_port,
]
print("Starting load balancer:", " ".join(lb_command))
cls.process_lb = popen_with_error_check(lb_command)
cls.wait_server_ready(cls.lb_url + "/health")
cls.launch_lb()
@classmethod
def start_prefill(cls):
......@@ -104,39 +81,6 @@ class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase):
other_args=decode_args,
)
@classmethod
def wait_server_ready(cls, url, timeout=60):
start_time = time.perf_counter()
while True:
try:
response = requests.get(url)
if response.status_code == 200:
print(f"Server {url} is ready")
return
except Exception:
pass
if time.perf_counter() - start_time > timeout:
raise RuntimeError(f"Server {url} failed to start in {timeout}s")
time.sleep(1)
@classmethod
def tearDownClass(cls):
# Restore JIT DeepGEMM environment variable
if cls.original_jit_deepgemm is not None:
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = cls.original_jit_deepgemm
else:
os.environ.pop("SGL_ENABLE_JIT_DEEPGEMM", None)
for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
if process:
try:
kill_process_tree(process.pid)
except Exception as e:
print(f"Error killing process {process.pid}: {e}")
# wait for 5 seconds
time.sleep(5)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
......@@ -153,7 +97,7 @@ class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase):
self.assertGreater(metrics["accuracy"], 0.60)
class TestDisaggregationMooncakeDecodeLargerTP(CustomTestCase):
class TestDisaggregationMooncakeDecodeLargerTP(TestDisaggregationBase):
@classmethod
def setUpClass(cls):
# Temporarily disable JIT DeepGEMM
......@@ -180,25 +124,7 @@ class TestDisaggregationMooncakeDecodeLargerTP(CustomTestCase):
cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health")
lb_command = [
"python3",
"-m",
"sglang_router.launch_router",
"--pd-disaggregation",
"--mini-lb", # FIXME: remove this
"--prefill",
cls.prefill_url,
"--decode",
cls.decode_url,
"--host",
cls.base_host,
"--port",
cls.lb_port,
]
print("Starting load balancer:", " ".join(lb_command))
cls.process_lb = popen_with_error_check(lb_command)
cls.wait_server_ready(cls.lb_url + "/health")
cls.launch_lb()
@classmethod
def start_prefill(cls):
......@@ -238,39 +164,6 @@ class TestDisaggregationMooncakeDecodeLargerTP(CustomTestCase):
other_args=decode_args,
)
@classmethod
def wait_server_ready(cls, url, timeout=60):
start_time = time.perf_counter()
while True:
try:
response = requests.get(url)
if response.status_code == 200:
print(f"Server {url} is ready")
return
except Exception:
pass
if time.perf_counter() - start_time > timeout:
raise RuntimeError(f"Server {url} failed to start in {timeout}s")
time.sleep(1)
@classmethod
def tearDownClass(cls):
# Restore JIT DeepGEMM environment variable
if cls.original_jit_deepgemm is not None:
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = cls.original_jit_deepgemm
else:
os.environ.pop("SGL_ENABLE_JIT_DEEPGEMM", None)
for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
if process:
try:
kill_process_tree(process.pid)
except Exception as e:
print(f"Error killing process {process.pid}: {e}")
# wait for 5 seconds
time.sleep(5)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
......
import json
import os
import random
import time
import unittest
from concurrent.futures import ThreadPoolExecutor
from types import SimpleNamespace
from typing import List, Optional
from urllib.parse import urlparse
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.runners import DEFAULT_PROMPTS
from sglang.test.test_disaggregation_utils import TestDisaggregationBase
from sglang.test.test_utils import (
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
popen_launch_pd_server,
)
class TestPDPPAccuracy(unittest.TestCase):
class TestDisaggregationPPAccuracy(TestDisaggregationBase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
......@@ -46,27 +36,7 @@ class TestPDPPAccuracy(unittest.TestCase):
cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health")
lb_command = [
"python3",
"-m",
"sglang_router.launch_router",
"--pd-disaggregation",
"--mini-lb", # FIXME: remove this
"--prefill",
cls.prefill_url,
"--decode",
cls.decode_url,
"--host",
cls.base_host,
"--port",
cls.lb_port,
]
print("Starting load balancer:", " ".join(lb_command))
cls.process_lb = subprocess.Popen(
lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
cls.wait_server_ready(cls.lb_url + "/health")
cls.launch_lb()
@classmethod
def start_prefill(cls):
......@@ -75,11 +45,11 @@ class TestPDPPAccuracy(unittest.TestCase):
"--disaggregation-mode",
"prefill",
"--tp-size",
"2",
"1",
"--pp-size",
"2",
"--disaggregation-ib-device",
"mlx5_roce0",
"mlx5_roce0,mlx5_roce1",
"--disable-overlap-schedule",
]
cls.process_prefill = popen_launch_pd_server(
......@@ -98,9 +68,9 @@ class TestPDPPAccuracy(unittest.TestCase):
"--tp",
"1",
"--base-gpu-id",
"1",
"2",
"--disaggregation-ib-device",
"mlx5_roce1",
"mlx5_roce2",
]
cls.process_decode = popen_launch_pd_server(
cls.model,
......@@ -109,10 +79,6 @@ class TestPDPPAccuracy(unittest.TestCase):
other_args=decode_args,
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
......@@ -120,8 +86,8 @@ class TestPDPPAccuracy(unittest.TestCase):
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
host=f"http://{self.base_host}",
port=int(self.lb_port),
)
metrics = run_eval(args)
print(f"{metrics=}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment