Unverified Commit 9eb50ecc authored by Keyang Ru's avatar Keyang Ru Committed by GitHub
Browse files

[router] Improve the router e2e tests (#10102)

parent b3e7a2ce
...@@ -105,11 +105,11 @@ jobs: ...@@ -105,11 +105,11 @@ jobs:
pip install fastapi uvicorn orjson pip install fastapi uvicorn orjson
pytest -q -m integration pytest -q -m integration
- name: Run e2e test - name: Run Python E2E tests
run: | run: |
bash scripts/killall_sglang.sh "nuk_gpus" bash scripts/killall_sglang.sh "nuk_gpus"
cd sgl-router/py_test cd sgl-router
python3 run_suite.py pytest -m e2e -s -vv -o log_cli=true --log-cli-level=INFO
finish: finish:
needs: [unit-test-rust, e2e-python] needs: [unit-test-rust, e2e-python]
......
import socket
import subprocess
import time
from types import SimpleNamespace
from urllib.parse import urlparse
import pytest
import requests
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
)
def _find_available_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
def _parse_url(base_url: str) -> tuple[str, str]:
"""Parse a base URL and return (host, port) as strings.
This is more robust than simple string splitting and supports different schemes
and URL shapes like trailing paths.
"""
parsed = urlparse(base_url)
return parsed.hostname or "127.0.0.1", (
str(parsed.port) if parsed.port is not None else ""
)
def _wait_router_health(base_url: str, timeout: float) -> None:
start = time.perf_counter()
with requests.Session() as session:
while time.perf_counter() - start < timeout:
try:
r = session.get(f"{base_url}/health", timeout=5)
if r.status_code == 200:
return
except requests.RequestException:
pass
time.sleep(2)
raise TimeoutError("Router failed to become healthy in time")
def _popen_launch_router(
model: str,
base_url: str,
dp_size: int,
timeout: float,
policy: str = "cache_aware",
) -> subprocess.Popen:
host, port = _parse_url(base_url)
prom_port = _find_available_port()
cmd = [
"python3",
"-m",
"sglang_router.launch_server",
"--model-path",
model,
"--host",
host,
"--port",
port,
"--dp",
str(dp_size),
"--router-policy",
policy,
"--allow-auto-truncate",
"--router-prometheus-port",
str(prom_port),
"--router-prometheus-host",
"127.0.0.1",
]
proc = subprocess.Popen(cmd)
_wait_router_health(base_url, timeout)
return proc
def _popen_launch_worker(
model: str,
base_url: str,
*,
dp_size: int | None = None,
api_key: str | None = None,
) -> subprocess.Popen:
host, port = _parse_url(base_url)
cmd = [
"python3",
"-m",
"sglang.launch_server",
"--model-path",
model,
"--host",
host,
"--port",
port,
"--base-gpu-id",
"0",
]
if dp_size is not None:
cmd += ["--dp-size", str(dp_size)]
if api_key is not None:
cmd += ["--api-key", api_key]
return subprocess.Popen(cmd)
def _popen_launch_router_only(
base_url: str,
policy: str = "round_robin",
timeout: float = 120.0,
*,
dp_aware: bool = False,
api_key: str | None = None,
) -> subprocess.Popen:
host, port = _parse_url(base_url)
prom_port = _find_available_port()
cmd = [
"python3",
"-m",
"sglang_router.launch_router",
"--host",
host,
"--port",
port,
"--policy",
policy,
]
if dp_aware:
cmd += ["--dp-aware"]
if api_key is not None:
cmd += ["--api-key", api_key]
cmd += [
"--prometheus-port",
str(prom_port),
"--prometheus-host",
"127.0.0.1",
]
proc = subprocess.Popen(cmd)
_wait_router_health(base_url, timeout)
return proc
def _terminate(proc: subprocess.Popen, timeout: float = 120) -> None:
if proc is None:
return
proc.terminate()
start = time.perf_counter()
while proc.poll() is None:
if time.perf_counter() - start > timeout:
proc.kill()
break
time.sleep(1)
def pytest_configure(config):
config.addinivalue_line("markers", "e2e: mark as end-to-end test")
@pytest.fixture(scope="session")
def e2e_model() -> str:
# Always use the default test model
return DEFAULT_MODEL_NAME_FOR_TEST
@pytest.fixture
def e2e_router(e2e_model: str):
# Keep this available but tests below use router-only to avoid GPU contention
base_url = DEFAULT_URL_FOR_TEST
proc = _popen_launch_router(
e2e_model, base_url, dp_size=2, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
)
try:
yield SimpleNamespace(proc=proc, url=base_url)
finally:
_terminate(proc)
@pytest.fixture
def e2e_router_only_rr():
port = _find_available_port()
base_url = f"http://127.0.0.1:{port}"
proc = _popen_launch_router_only(base_url, policy="round_robin")
try:
yield SimpleNamespace(proc=proc, url=base_url)
finally:
_terminate(proc)
@pytest.fixture(scope="session")
def e2e_primary_worker(e2e_model: str):
port = _find_available_port()
base_url = f"http://127.0.0.1:{port}"
proc = _popen_launch_worker(e2e_model, base_url)
# Router health gate will handle worker readiness
try:
yield SimpleNamespace(proc=proc, url=base_url)
finally:
_terminate(proc)
@pytest.fixture
def e2e_router_only_rr_dp_aware_api():
"""Router-only with dp-aware enabled and an API key."""
port = _find_available_port()
base_url = f"http://127.0.0.1:{port}"
api_key = "secret"
proc = _popen_launch_router_only(
base_url, policy="round_robin", timeout=180.0, dp_aware=True, api_key=api_key
)
try:
yield SimpleNamespace(proc=proc, url=base_url, api_key=api_key)
finally:
_terminate(proc)
@pytest.fixture
def e2e_worker_dp2_api(e2e_model: str, e2e_router_only_rr_dp_aware_api):
"""Worker with dp-size=2 and the same API key as the dp-aware router."""
port = _find_available_port()
base_url = f"http://127.0.0.1:{port}"
api_key = e2e_router_only_rr_dp_aware_api.api_key
proc = _popen_launch_worker(e2e_model, base_url, dp_size=2, api_key=api_key)
try:
yield SimpleNamespace(proc=proc, url=base_url)
finally:
_terminate(proc)
import threading
import time
from types import SimpleNamespace
import pytest
import requests
from sglang.test.run_eval import run_eval
@pytest.mark.e2e
def test_mmlu(e2e_router_only_rr, e2e_primary_worker, e2e_model):
# Attach the primary worker to a fresh router-only instance (single model)
base = e2e_router_only_rr.url
r = requests.post(
f"{base}/add_worker", params={"url": e2e_primary_worker.url}, timeout=180
)
r.raise_for_status()
args = SimpleNamespace(
base_url=base,
model=e2e_model,
eval_name="mmlu",
num_examples=64,
num_threads=32,
temperature=0.1,
)
metrics = run_eval(args)
assert metrics["score"] >= 0.65
@pytest.mark.e2e
def test_add_and_remove_worker_live(e2e_router_only_rr, e2e_primary_worker, e2e_model):
base = e2e_router_only_rr.url
worker_url = e2e_primary_worker.url
r = requests.post(f"{base}/add_worker", params={"url": worker_url}, timeout=180)
r.raise_for_status()
with requests.Session() as s:
for i in range(8):
r = s.post(
f"{base}/v1/completions",
json={
"model": e2e_model,
"prompt": f"x{i}",
"max_tokens": 1,
"stream": False,
},
timeout=120,
)
r.raise_for_status()
# Remove the worker
r = requests.post(f"{base}/remove_worker", params={"url": worker_url}, timeout=60)
r.raise_for_status()
@pytest.mark.e2e
def test_lazy_fault_tolerance_live(e2e_router_only_rr, e2e_primary_worker, e2e_model):
base = e2e_router_only_rr.url
worker = e2e_primary_worker
r = requests.post(f"{base}/add_worker", params={"url": worker.url}, timeout=180)
r.raise_for_status()
def killer():
time.sleep(10)
try:
worker.proc.terminate()
except Exception:
pass
t = threading.Thread(target=killer, daemon=True)
t.start()
args = SimpleNamespace(
base_url=base,
model=e2e_model,
eval_name="mmlu",
num_examples=32,
num_threads=16,
temperature=0.0,
)
metrics = run_eval(args)
assert 0.0 <= metrics["score"] <= 1.0
@pytest.mark.e2e
def test_dp_aware_worker_expansion_and_api_key(
e2e_model,
e2e_router_only_rr_dp_aware_api,
e2e_worker_dp2_api,
):
"""
Launch a router-only instance in dp_aware mode and a single worker with dp_size=2
and API key protection. Verify expansion, auth enforcement, and basic eval.
"""
import os
router_url = e2e_router_only_rr_dp_aware_api.url
worker_url = e2e_worker_dp2_api.url
api_key = e2e_router_only_rr_dp_aware_api.api_key
# Attach worker; router should expand to dp_size logical workers
r = requests.post(
f"{router_url}/add_worker", params={"url": worker_url}, timeout=180
)
r.raise_for_status()
r = requests.get(f"{router_url}/list_workers", timeout=30)
r.raise_for_status()
urls = r.json().get("urls", [])
assert len(urls) == 2
assert set(urls) == {f"{worker_url}@0", f"{worker_url}@1"}
# Verify API key enforcement path-through
# 1) Without Authorization -> 401 from backend
r = requests.post(
f"{router_url}/v1/completions",
json={"model": e2e_model, "prompt": "hi", "max_tokens": 1},
timeout=60,
)
assert r.status_code == 401
# 2) With correct Authorization -> 200
r = requests.post(
f"{router_url}/v1/completions",
json={"model": e2e_model, "prompt": "hi", "max_tokens": 1},
headers={"Authorization": f"Bearer {api_key}"},
timeout=60,
)
assert r.status_code == 200
# Finally, run MMLU eval through the router with auth
os.environ["OPENAI_API_KEY"] = api_key
args = SimpleNamespace(
base_url=router_url,
model=e2e_model,
eval_name="mmlu",
num_examples=64,
num_threads=32,
temperature=0.1,
)
metrics = run_eval(args)
assert metrics["score"] >= 0.65
...@@ -44,6 +44,7 @@ def _parse_args() -> argparse.Namespace: ...@@ -44,6 +44,7 @@ def _parse_args() -> argparse.Namespace:
p.add_argument("--api-key", default=None) p.add_argument("--api-key", default=None)
p.add_argument("--max-payload-bytes", type=int, default=10 * 1024 * 1024) p.add_argument("--max-payload-bytes", type=int, default=10 * 1024 * 1024)
p.add_argument("--stream", action="store_true") p.add_argument("--stream", action="store_true")
p.add_argument("--dp-size", type=int, default=1)
p.add_argument("--crash-on-request", action="store_true") p.add_argument("--crash-on-request", action="store_true")
p.add_argument("--health-fail-after-ms", type=int, default=0) p.add_argument("--health-fail-after-ms", type=int, default=0)
return p.parse_args() return p.parse_args()
...@@ -125,12 +126,15 @@ def create_app(args: argparse.Namespace) -> FastAPI: ...@@ -125,12 +126,15 @@ def create_app(args: argparse.Namespace) -> FastAPI:
return JSONResponse({"data": [{"id": "mock", "object": "model"}]}) return JSONResponse({"data": [{"id": "mock", "object": "model"}]})
@app.get("/get_server_info") @app.get("/get_server_info")
async def get_server_info(): async def get_server_info(request: Request):
# Enforce API key on server info when required (used by dp_aware probing)
check_api_key(request)
return JSONResponse( return JSONResponse(
{ {
"worker_id": worker_id, "worker_id": worker_id,
"load_in_flight": _inflight, "load_in_flight": _inflight,
"cache": {"size": 0, "hit_rate": 0.0}, "cache": {"size": 0, "hit_rate": 0.0},
"dp_size": int(args.dp_size),
} }
) )
......
import pytest
import requests
@pytest.mark.integration
def test_payload_size_limit(router_manager, mock_workers):
# Start one backend and a router with a 1MB payload limit
_, urls, _ = mock_workers(n=1)
rh = router_manager.start_router(
worker_urls=urls,
policy="round_robin",
extra={"max_payload_size": 1 * 1024 * 1024}, # 1MB
)
# Payload just under 1MB should succeed
payload_small = {
"model": "test-model",
"prompt": "x" * int(0.5 * 1024 * 1024), # ~0.5MB
"max_tokens": 1,
"stream": False,
}
r = requests.post(f"{rh.url}/v1/completions", json=payload_small)
assert r.status_code == 200
# Payload over 1MB should fail with 413
payload_large = {
"model": "test-model",
"prompt": "x" * int(1.2 * 1024 * 1024), # ~1.2MB
"max_tokens": 1,
"stream": False,
}
r = requests.post(f"{rh.url}/v1/completions", json=payload_large)
assert r.status_code == 413
import argparse
import glob
from sglang.test.test_utils import TestFile, run_unittest_files
if __name__ == "__main__":
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument(
"--timeout-per-file",
type=int,
default=2000,
help="The time limit for running one file in seconds.",
)
args = arg_parser.parse_args()
files = glob.glob("**/test_*.py", recursive=True)
# Exclude integration tests from the e2e suite; those are run separately via pytest -m integration
files = [
f
for f in files
if "/integration/" not in f and not f.startswith("integration/")
]
files.sort()
test_files = [TestFile(name=file) for file in files]
exit_code = run_unittest_files(test_files, args.timeout_per_file)
exit(exit_code)
import multiprocessing
import time
import unittest
from types import SimpleNamespace
def terminate_process(process: multiprocessing.Process, timeout: float = 1.0) -> None:
"""Terminate a process gracefully, with forced kill as fallback.
Args:
process: The process to terminate
timeout: Seconds to wait for graceful termination before forcing kill
"""
if not process.is_alive():
return
process.terminate()
process.join(timeout=timeout)
if process.is_alive():
process.kill() # Force kill if terminate didn't work
process.join()
class TestLaunchRouter(unittest.TestCase):
def setUp(self):
"""Set up default arguments for router tests."""
self.default_args = SimpleNamespace(
host="127.0.0.1",
port=30000,
policy="cache_aware",
worker_startup_timeout_secs=600,
worker_startup_check_interval=10,
cache_threshold=0.5,
balance_abs_threshold=32,
balance_rel_threshold=1.0001,
eviction_interval_secs=60,
max_tree_size=2**24,
max_payload_size=256 * 1024 * 1024, # 256MB
verbose=False,
log_dir=None,
log_level=None,
service_discovery=False,
selector=None,
service_discovery_port=80,
service_discovery_namespace=None,
dp_aware=False,
prometheus_port=None,
prometheus_host=None,
request_timeout_secs=60,
max_concurrent_requests=64,
cors_allowed_origins=[],
pd_disaggregation=False,
prefill=None,
decode=None,
worker_urls=[],
retry_max_retries=3,
retry_initial_backoff_ms=100,
retry_max_backoff_ms=10_000,
retry_backoff_multiplier=2.0,
retry_jitter_factor=0.1,
cb_failure_threshold=5,
cb_success_threshold=2,
cb_timeout_duration_secs=30,
cb_window_duration_secs=60,
disable_retries=False,
disable_circuit_breaker=False,
model_path=None,
tokenizer_path=None,
)
def create_router_args(self, **kwargs):
"""Create router arguments by updating default args with provided kwargs."""
args_dict = vars(self.default_args).copy()
args_dict.update(kwargs)
return SimpleNamespace(**args_dict)
def run_router_process(self, args):
"""Run router in a separate process and verify it starts successfully."""
def run_router():
try:
from sglang_router.launch_router import launch_router
router = launch_router(args)
if router is None:
return 1
return 0
except Exception as e:
print(e)
return 1
process = multiprocessing.Process(target=run_router)
try:
process.start()
# Wait 3 seconds
time.sleep(3)
# Process is still running means router started successfully
self.assertTrue(process.is_alive())
finally:
terminate_process(process)
def test_launch_router_common(self):
args = self.create_router_args(worker_urls=["http://localhost:8000"])
self.run_router_process(args)
def test_launch_router_with_empty_worker_urls(self):
args = self.create_router_args(worker_urls=[])
self.run_router_process(
args
) # Should start successfully with empty worker list
def test_launch_router_with_service_discovery(self):
# Test router startup with service discovery enabled but no selectors
args = self.create_router_args(
worker_urls=[], service_discovery=True, selector=["app=test-worker"]
)
self.run_router_process(args)
def test_launch_router_with_service_discovery_namespace(self):
# Test router startup with service discovery enabled and namespace specified
args = self.create_router_args(
worker_urls=[],
service_discovery=True,
selector=["app=test-worker"],
service_discovery_namespace="test-namespace",
)
self.run_router_process(args)
def test_launch_router_common_with_dp_aware(self):
args = self.create_router_args(
worker_urls=["http://localhost:8000"],
dp_aware=True,
)
self.run_router_process(args)
def test_launch_router_with_empty_worker_urls_with_dp_aware(self):
args = self.create_router_args(
worker_urls=[],
dp_aware=True,
)
self.run_router_process(args)
def test_launch_router_common_with_dp_aware_service_discovery(self):
# Test launch router with bot srevice_discovery and dp_aware enabled
# Should fail since service_discovery and dp_aware is conflict
args = self.create_router_args(
worker_urls=["http://localhost:8000"],
dp_aware=True,
service_discovery=True,
selector=["app=test-worker"],
)
def run_router():
try:
from sglang_router.launch_router import launch_router
router = launch_router(args)
if router is None:
return 1
return 0
except Exception as e:
print(e)
return 1
process = multiprocessing.Process(target=run_router)
try:
process.start()
# Wait 3 seconds
time.sleep(3)
# Should fail since service_discovery and dp_aware is conflict
self.assertFalse(process.is_alive())
finally:
terminate_process(process)
def test_launch_router_pd_mode_basic(self):
"""Test basic PD router functionality without actually starting servers."""
# This test just verifies the PD router can be created and configured
# without actually starting it (which would require real prefill/decode servers)
from sglang_router.launch_router import RouterArgs
from sglang_router.router import PolicyType, Router
# Test RouterArgs parsing for PD mode
# Simulate the parsed args structure from argparse with action="append"
args = self.create_router_args(
pd_disaggregation=True,
policy="power_of_two", # PowerOfTwo is only valid in PD mode
prefill=[
["http://prefill1:8080", "9000"],
["http://prefill2:8080", "none"],
],
decode=[
["http://decode1:8081"],
["http://decode2:8081"],
],
worker_urls=[], # Empty for PD mode
)
router_args = RouterArgs.from_cli_args(args)
self.assertTrue(router_args.pd_disaggregation)
self.assertEqual(router_args.policy, "power_of_two")
self.assertEqual(len(router_args.prefill_urls), 2)
self.assertEqual(len(router_args.decode_urls), 2)
# Verify the parsed URLs and bootstrap ports
self.assertEqual(router_args.prefill_urls[0], ("http://prefill1:8080", 9000))
self.assertEqual(router_args.prefill_urls[1], ("http://prefill2:8080", None))
self.assertEqual(router_args.decode_urls[0], "http://decode1:8081")
self.assertEqual(router_args.decode_urls[1], "http://decode2:8081")
# Test Router creation in PD mode
router = Router.from_args(router_args)
self.assertIsNotNone(router)
def test_policy_validation(self):
"""Test that policy validation works correctly for PD and regular modes."""
from sglang_router.launch_router import RouterArgs, launch_router
# Test 1: PowerOfTwo requires at least 2 workers
args = self.create_router_args(
pd_disaggregation=False,
policy="power_of_two",
worker_urls=["http://localhost:8000"], # Only 1 worker
)
# Should raise error
with self.assertRaises(ValueError) as cm:
launch_router(args)
self.assertIn(
"Power-of-two policy requires at least 2 workers",
str(cm.exception),
)
# Test 2: PowerOfTwo with sufficient workers should succeed
args = self.create_router_args(
pd_disaggregation=False,
policy="power_of_two",
worker_urls=["http://localhost:8000", "http://localhost:8001"], # 2 workers
)
# This should not raise an error (validation passes)
# Test 3: All policies now work in both modes
# Regular mode with RoundRobin
args = self.create_router_args(
pd_disaggregation=False,
policy="round_robin",
worker_urls=["http://localhost:8000"],
)
# This should not raise validation error
# PD mode with RoundRobin (now supported!)
args = self.create_router_args(
pd_disaggregation=True,
policy="round_robin",
prefill=[["http://prefill1:8080", "9000"]],
decode=[["http://decode1:8081"]],
worker_urls=[],
)
# This should not raise validation error
def test_pd_service_discovery_args_parsing(self):
"""Test PD service discovery CLI argument parsing."""
import argparse
from sglang_router.launch_router import RouterArgs
parser = argparse.ArgumentParser()
RouterArgs.add_cli_args(parser)
args = parser.parse_args(
[
"--pd-disaggregation",
"--service-discovery",
"--prefill-selector",
"app=sglang",
"component=prefill",
"--decode-selector",
"app=sglang",
"component=decode",
"--service-discovery-port",
"8000",
"--service-discovery-namespace",
"production",
"--policy",
"cache_aware",
]
)
router_args = RouterArgs.from_cli_args(args)
self.assertTrue(router_args.pd_disaggregation)
self.assertTrue(router_args.service_discovery)
self.assertEqual(
router_args.prefill_selector, {"app": "sglang", "component": "prefill"}
)
self.assertEqual(
router_args.decode_selector, {"app": "sglang", "component": "decode"}
)
self.assertEqual(router_args.service_discovery_port, 8000)
self.assertEqual(router_args.service_discovery_namespace, "production")
def test_regular_service_discovery_args_parsing(self):
"""Test regular mode service discovery CLI argument parsing."""
import argparse
from sglang_router.launch_router import RouterArgs
parser = argparse.ArgumentParser()
RouterArgs.add_cli_args(parser)
args = parser.parse_args(
[
"--service-discovery",
"--selector",
"app=sglang-worker",
"environment=staging",
"--service-discovery-port",
"8000",
"--policy",
"round_robin",
]
)
router_args = RouterArgs.from_cli_args(args)
self.assertFalse(router_args.pd_disaggregation)
self.assertTrue(router_args.service_discovery)
self.assertEqual(
router_args.selector, {"app": "sglang-worker", "environment": "staging"}
)
self.assertEqual(router_args.prefill_selector, {})
self.assertEqual(router_args.decode_selector, {})
def test_empty_worker_urls_args_parsing(self):
"""Test that router accepts no worker URLs and defaults to empty list."""
import argparse
from sglang_router.launch_router import RouterArgs
parser = argparse.ArgumentParser()
RouterArgs.add_cli_args(parser)
# Test with no --worker-urls argument at all
args = parser.parse_args(["--policy", "random", "--port", "30000"])
router_args = RouterArgs.from_cli_args(args)
self.assertEqual(router_args.worker_urls, [])
# Test with explicit empty --worker-urls
args = parser.parse_args(["--worker-urls", "--policy", "random"])
router_args = RouterArgs.from_cli_args(args)
self.assertEqual(router_args.worker_urls, [])
if __name__ == "__main__":
unittest.main()
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment