Unverified commit 21b9a4b4, authored by Keyang Ru, committed by GitHub

[router] Introduce router integration tests (#10086)

parent db37422c
@@ -95,7 +95,15 @@ jobs:
         cd sgl-router
         source "$HOME/.cargo/env"
         pip install pytest pytest-cov pytest-xdist
-        pytest -q py_test/unit
+        pytest -q py_test/unit --cov=sglang_router --cov-report=term-missing --cov-fail-under=80
+    - name: Run Python integration tests
+      run: |
+        cd sgl-router
+        source "$HOME/.cargo/env"
+        # Integration tests use FastAPI/uvicorn for mock workers
+        pip install fastapi uvicorn orjson
+        pytest -q -m integration
     - name: Run e2e test
       run: |
...
"""Test package root for router Python tests."""
"""Shared fixtures for router integration tests."""
"""
Lightweight mock worker HTTP server for router integration tests.
Implements minimal endpoints used by the router:
- GET /health, /health_generate
- POST /generate, /v1/completions, /v1/chat/completions
- POST /flush_cache
- GET /get_server_info, /get_model_info, /v1/models
Behavior knobs are controlled via CLI flags to simulate failures, latency, and load.
"""
import argparse
import asyncio
import json
import os
import random
import signal
import sys
import time
from contextlib import asynccontextmanager
from typing import Optional
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, PlainTextResponse, StreamingResponse
# Global state (per-process)
_inflight = 0
_failures_seen = 0
def _parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser()
p.add_argument("--host", default="127.0.0.1")
p.add_argument("--port", type=int, required=True)
p.add_argument("--worker-id", default=None)
p.add_argument("--latency-ms", type=int, default=0)
p.add_argument("--timeout", action="store_true")
p.add_argument("--status-code", type=int, default=200)
p.add_argument("--fail-first-n", type=int, default=0)
p.add_argument("--random-fail-rate", type=float, default=0.0)
p.add_argument("--require-api-key", action="store_true")
p.add_argument("--api-key", default=None)
p.add_argument("--max-payload-bytes", type=int, default=10 * 1024 * 1024)
p.add_argument("--stream", action="store_true")
p.add_argument("--crash-on-request", action="store_true")
p.add_argument("--health-fail-after-ms", type=int, default=0)
return p.parse_args()
def _extract_worker_id(args: argparse.Namespace) -> str:
if args.worker_id:
return str(args.worker_id)
# default to port (unique enough for tests)
return f"worker-{args.port}"
def create_app(args: argparse.Namespace) -> FastAPI:
app = FastAPI()
worker_id = _extract_worker_id(args)
start_ts = time.time()
crashed = {"done": False}
async def maybe_delay():
if args.latency_ms > 0:
await asyncio.sleep(args.latency_ms / 1000.0)
def should_fail() -> Optional[int]:
global _failures_seen
# Fail first N requests (500)
if args.fail_first_n > 0 and _failures_seen < args.fail_first_n:
_failures_seen += 1
return 500
# Random failure probability (500)
if args.random_fail_rate > 0.0 and random.random() < args.random_fail_rate:
return 500
# Forced status code override (non-200) for all responses
if args.status_code != 200:
return int(args.status_code)
return None
def check_api_key(request: Request):
if not args.require_api_key:
return
auth = request.headers.get("Authorization")
if not auth or not auth.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Unauthorized")
key = auth.split(" ", 1)[1]
if args.api_key and key != args.api_key:
raise HTTPException(status_code=401, detail="Unauthorized")
@asynccontextmanager
async def track_inflight():
global _inflight
_inflight += 1
try:
yield
finally:
_inflight -= 1
@app.get("/health")
async def health():
if (
args.health_fail_after_ms
and (time.time() - start_ts) * 1000.0 >= args.health_fail_after_ms
):
return PlainTextResponse("bad", status_code=500)
return PlainTextResponse("ok", status_code=200)
@app.get("/health_generate")
async def health_generate():
return PlainTextResponse("ok", status_code=200)
@app.post("/flush_cache")
async def flush_cache():
return PlainTextResponse("ok", status_code=200)
@app.get("/get_model_info")
async def get_model_info():
return JSONResponse({"model": "mock", "vocab_size": 32000})
@app.get("/v1/models")
async def list_models():
return JSONResponse({"data": [{"id": "mock", "object": "model"}]})
@app.get("/get_server_info")
async def get_server_info():
return JSONResponse(
{
"worker_id": worker_id,
"load_in_flight": _inflight,
"cache": {"size": 0, "hit_rate": 0.0},
}
)
@app.get("/get_load")
async def get_load():
return JSONResponse({"load": _inflight})
def make_json_response(obj: dict, status_code: int = 200) -> JSONResponse:
resp = JSONResponse(obj, status_code=status_code)
resp.headers["X-Worker-Id"] = worker_id
return resp
async def handle_text_request(request: Request):
# Authorization
check_api_key(request)
# Payload limit
body = await request.body()
if len(body) > args.max_payload_bytes:
return make_json_response({"error": "payload too large"}, status_code=413)
# Simulate crash on first request
if args.crash_on_request and not crashed["done"]:
crashed["done"] = True
os._exit(1)
# Optional timeout (simulate hang)
if args.timeout:
await asyncio.sleep(3600)
# Optional latency
await maybe_delay()
# Optional failures
fail_code = should_fail()
if fail_code is not None and fail_code != 200:
return make_json_response(
{"error": f"mock failure {fail_code}"}, status_code=fail_code
)
# Build response echoing minimal shape
try:
data = await request.json()
except (json.JSONDecodeError, ValueError):
data = {}
now = time.time()
ret = {
"id": f"cmpl-{int(now*1000)}",
"object": "text_completion",
"created": int(now),
"model": "mock",
"choices": [
{
"text": "ok",
"index": 0,
"finish_reason": "stop",
}
],
"worker_id": worker_id,
"echo": data,
}
return make_json_response(ret, status_code=200)
async def handle_stream_request(request: Request):
check_api_key(request)
async def gen():
# minimal 2-chunk stream then [DONE]
for i in range(2):
await asyncio.sleep(0.01)
chunk = {
"choices": [{"delta": {"content": "x"}}],
"worker_id": worker_id,
}
yield f"data: {json.dumps(chunk)}\n\n"
yield "data: [DONE]\n\n"
headers = {"X-Worker-Id": worker_id}
return StreamingResponse(gen(), media_type="text/event-stream", headers=headers)
@app.post("/generate")
async def generate(request: Request):
async with track_inflight():
if args.stream:
return await handle_stream_request(request)
return await handle_text_request(request)
@app.post("/v1/completions")
async def completions(request: Request):
async with track_inflight():
if args.stream:
return await handle_stream_request(request)
return await handle_text_request(request)
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
async with track_inflight():
if args.stream:
return await handle_stream_request(request)
return await handle_text_request(request)
return app
def main() -> None:
args = _parse_args()
app = create_app(args)
# Handle SIGTERM gracefully for fast test teardown
signal.signal(signal.SIGTERM, lambda *_: sys.exit(0))
uvicorn.run(app, host=args.host, port=args.port, log_level="warning")
if __name__ == "__main__":
main()
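For illustration only (not part of this commit), a minimal sketch of driving the mock worker by hand, assuming it is launched from the sgl-router directory; the flag names and endpoints match _parse_args and create_app above.

# Illustrative only: spawn a worker that fails its first two requests, then recovers.
import subprocess
import time

import requests

proc = subprocess.Popen(
    [
        "python3", "py_test/fixtures/mock_worker.py",
        "--port", "18123",            # any free port works
        "--fail-first-n", "2",
        "--latency-ms", "50",
    ]
)
try:
    time.sleep(1.0)  # crude startup wait; the conftest below polls /health instead
    base = "http://127.0.0.1:18123"
    assert requests.get(f"{base}/health", timeout=2).status_code == 200
    codes = [
        requests.post(f"{base}/generate", json={"text": "hi"}, timeout=2).status_code
        for _ in range(3)
    ]
    print(codes)  # expected: [500, 500, 200]
finally:
    proc.terminate()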
import socket
def find_free_port() -> int:
"""Return an available TCP port on localhost."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
import subprocess
import time
from dataclasses import dataclass
from typing import Dict, List, Optional
import requests
from .ports import find_free_port
@dataclass
class ProcHandle:
process: subprocess.Popen
url: str
class RouterManager:
"""Helper to spawn a router process and interact with admin endpoints."""
def __init__(self):
self._children: List[subprocess.Popen] = []
def start_router(
self,
worker_urls: Optional[List[str]] = None,
policy: str = "round_robin",
port: Optional[int] = None,
extra: Optional[Dict] = None,
# PD options
pd_disaggregation: bool = False,
prefill_urls: Optional[List[tuple]] = None,
decode_urls: Optional[List[str]] = None,
prefill_policy: Optional[str] = None,
decode_policy: Optional[str] = None,
) -> ProcHandle:
worker_urls = worker_urls or []
port = port or find_free_port()
cmd = [
"python3",
"-m",
"sglang_router.launch_router",
"--host",
"127.0.0.1",
"--port",
str(port),
"--policy",
policy,
]
# Avoid Prometheus port collisions by assigning a free port per router
prom_port = find_free_port()
cmd.extend(
["--prometheus-port", str(prom_port), "--prometheus-host", "127.0.0.1"]
)
if worker_urls:
cmd.extend(["--worker-urls", *worker_urls])
# PD routing configuration
if pd_disaggregation:
cmd.append("--pd-disaggregation")
if prefill_urls:
for url, bport in prefill_urls:
if bport is None:
cmd.extend(["--prefill", url, "none"])
else:
cmd.extend(["--prefill", url, str(bport)])
if decode_urls:
for url in decode_urls:
cmd.extend(["--decode", url])
if prefill_policy:
cmd.extend(["--prefill-policy", prefill_policy])
if decode_policy:
cmd.extend(["--decode-policy", decode_policy])
# Map supported extras to CLI flags (subset for integration)
if extra:
flag_map = {
"max_payload_size": "--max-payload-size",
"dp_aware": "--dp-aware",
"api_key": "--api-key",
# Health/monitoring
"worker_startup_check_interval": "--worker-startup-check-interval",
# Cache-aware tuning
"cache_threshold": "--cache-threshold",
"balance_abs_threshold": "--balance-abs-threshold",
"balance_rel_threshold": "--balance-rel-threshold",
# Retry
"retry_max_retries": "--retry-max-retries",
"retry_initial_backoff_ms": "--retry-initial-backoff-ms",
"retry_max_backoff_ms": "--retry-max-backoff-ms",
"retry_backoff_multiplier": "--retry-backoff-multiplier",
"retry_jitter_factor": "--retry-jitter-factor",
"disable_retries": "--disable-retries",
# Circuit breaker
"cb_failure_threshold": "--cb-failure-threshold",
"cb_success_threshold": "--cb-success-threshold",
"cb_timeout_duration_secs": "--cb-timeout-duration-secs",
"cb_window_duration_secs": "--cb-window-duration-secs",
"disable_circuit_breaker": "--disable-circuit-breaker",
# Rate limiting
"max_concurrent_requests": "--max-concurrent-requests",
"queue_size": "--queue-size",
"queue_timeout_secs": "--queue-timeout-secs",
"rate_limit_tokens_per_second": "--rate-limit-tokens-per-second",
}
for k, v in extra.items():
if v is None:
continue
flag = flag_map.get(k)
if not flag:
continue
if isinstance(v, bool):
if v:
cmd.append(flag)
else:
cmd.extend([flag, str(v)])
proc = subprocess.Popen(cmd)
self._children.append(proc)
url = f"http://127.0.0.1:{port}"
self._wait_health(url)
return ProcHandle(process=proc, url=url)
def _wait_health(self, base_url: str, timeout: float = 30.0):
start = time.time()
with requests.Session() as s:
while time.time() - start < timeout:
try:
r = s.get(f"{base_url}/health", timeout=2)
if r.status_code == 200:
return
except requests.RequestException:
pass
time.sleep(0.2)
raise TimeoutError(f"Router at {base_url} did not become healthy")
def add_worker(self, base_url: str, worker_url: str) -> None:
r = requests.post(f"{base_url}/add_worker", params={"url": worker_url})
assert r.status_code == 200, f"add_worker failed: {r.status_code} {r.text}"
def remove_worker(self, base_url: str, worker_url: str) -> None:
r = requests.post(f"{base_url}/remove_worker", params={"url": worker_url})
assert r.status_code == 200, f"remove_worker failed: {r.status_code} {r.text}"
def list_workers(self, base_url: str) -> list[str]:
r = requests.get(f"{base_url}/list_workers")
assert r.status_code == 200, f"list_workers failed: {r.status_code} {r.text}"
data = r.json()
return data.get("urls", [])
def stop_all(self):
for p in self._children:
if p.poll() is None:
p.terminate()
try:
p.wait(timeout=5)
except subprocess.TimeoutExpired:
p.kill()
self._children.clear()
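For illustration only (not part of this commit), a minimal sketch of RouterManager usage outside pytest. The import path is assumed from the package layout, and the worker URL must point at an already-running mock worker or the health wait will time out.

import requests

from py_test.fixtures.router_manager import RouterManager  # assumed module path

mgr = RouterManager()
try:
    rh = mgr.start_router(
        worker_urls=["http://127.0.0.1:31001"],  # a running mock worker (assumed)
        policy="round_robin",
        extra={"retry_max_retries": 3},
    )
    r = requests.post(
        f"{rh.url}/v1/completions",
        json={"model": "mock", "prompt": "hi", "max_tokens": 1, "stream": False},
        timeout=5,
    )
    print(r.status_code, mgr.list_workers(rh.url))
finally:
    mgr.stop_all()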
"""Integration test package for the router."""
import subprocess
import time
from pathlib import Path
from typing import Iterable, List, Tuple
import pytest
import requests
from ..fixtures.ports import find_free_port
from ..fixtures.router_manager import RouterManager
def pytest_configure(config):
config.addinivalue_line("markers", "integration: mark as router integration test")
@pytest.fixture
def router_manager() -> Iterable[RouterManager]:
mgr = RouterManager()
try:
yield mgr
finally:
mgr.stop_all()
def _spawn_mock_worker(args: List[str]) -> Tuple[subprocess.Popen, str, str]:
repo_root = Path(__file__).resolve().parents[2]
script = repo_root / "py_test" / "fixtures" / "mock_worker.py"
port = find_free_port()
worker_id = f"worker-{port}"
base_cmd = [
"python3",
str(script),
"--port",
str(port),
"--worker-id",
worker_id,
]
cmd = base_cmd + args
proc = subprocess.Popen(cmd)
url = f"http://127.0.0.1:{port}"
_wait_health(url)
return proc, url, worker_id
def _wait_health(url: str, timeout: float = 10.0):
start = time.time()
with requests.Session() as s:
while time.time() - start < timeout:
try:
r = s.get(f"{url}/health", timeout=1)
if r.status_code == 200:
return
except requests.RequestException:
pass
time.sleep(0.1)
raise TimeoutError(f"Mock worker at {url} did not become healthy")
@pytest.fixture
def mock_worker():
"""Start a single healthy mock worker; yields (process, url, worker_id)."""
proc, url, worker_id = _spawn_mock_worker([])
try:
yield proc, url, worker_id
finally:
if proc.poll() is None:
proc.terminate()
try:
proc.wait(timeout=3)
except subprocess.TimeoutExpired:
proc.kill()
@pytest.fixture
def mock_workers():
"""Factory to start N workers with custom args.
Usage:
procs, urls, ids = mock_workers(n=3, args=["--latency-ms", "5"]) # same args for all
...
"""
procs: List[subprocess.Popen] = []
def _start(n: int, args: List[str] | None = None):
args = args or []
new_procs: List[subprocess.Popen] = []
urls: List[str] = []
ids: List[str] = []
for _ in range(n):
p, url, wid = _spawn_mock_worker(args)
procs.append(p)
new_procs.append(p)
urls.append(url)
ids.append(wid)
return new_procs, urls, ids
try:
yield _start
finally:
for p in procs:
if p.poll() is None:
p.terminate()
try:
p.wait(timeout=3)
except subprocess.TimeoutExpired:
p.kill()
import collections
import concurrent.futures
import uuid
import pytest
import requests
@pytest.mark.integration
def test_cache_aware_affinity(mock_workers, router_manager):
# Two workers; same prompt should stick to one due to cache tree
_, urls, ids = mock_workers(n=2)
rh = router_manager.start_router(worker_urls=urls, policy="cache_aware")
counts = collections.Counter()
with requests.Session() as s:
for i in range(12):
r = s.post(
f"{rh.url}/v1/completions",
json={
"model": "test-model",
"prompt": "repeated prompt for cache",
"max_tokens": 1,
"stream": False,
},
)
assert r.status_code == 200
wid = r.headers.get("X-Worker-Id") or r.json().get("worker_id")
counts[wid] += 1
# Expect strong skew toward one worker (tree match); majority > 80%
top = max(counts.values())
assert top >= 10, counts
@pytest.mark.integration
def test_cache_aware_diverse_prompts_balances(mock_workers, router_manager):
# Add latency so concurrent requests overlap and influence load-based selection
_, urls, ids = mock_workers(n=3, args=["--latency-ms", "30"])
rh = router_manager.start_router(
worker_urls=urls,
policy="cache_aware",
extra={
"cache_threshold": 0.99,
"balance_abs_threshold": 0,
"balance_rel_threshold": 1.0,
},
)
counts = collections.Counter()
def call(i):
# Use diverse, unrelated prompts to avoid prefix matches entirely
prompt = str(uuid.uuid4())
r = requests.post(
f"{rh.url}/v1/completions",
json={
"model": "test-model",
"prompt": prompt,
"max_tokens": 1,
"stream": False,
},
timeout=5,
)
assert r.status_code == 200
return r.headers.get("X-Worker-Id") or r.json().get("worker_id")
with concurrent.futures.ThreadPoolExecutor(max_workers=16) as ex:
for wid in ex.map(call, range(40)):
counts[wid] += 1
# Expect participation of at least two workers
assert sum(1 for v in counts.values() if v > 0) >= 2, counts
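The two cache-aware tests above lean on a presumed decision rule; the following is a rough sketch of that rule under the knobs used here. This is an assumption about the router's internals, not its actual code: prefix affinity wins only above cache_threshold, and load balancing takes over once the abs/rel imbalance thresholds are exceeded.

# Assumed behavior only: how cache_threshold / balance_abs_threshold /
# balance_rel_threshold are presumed to interact in the tests above.
def pick_worker(match_rates, loads, cache_threshold, abs_thr, rel_thr):
    hi, lo = max(loads.values()), min(loads.values())
    imbalanced = (hi - lo) > abs_thr and hi > rel_thr * max(lo, 1)
    if imbalanced:
        return min(loads, key=loads.get)            # fall back to least-loaded
    best = max(match_rates, key=match_rates.get)    # best prefix match
    if match_rates[best] >= cache_threshold:
        return best                                 # cache affinity
    return min(loads, key=loads.get)

# With cache_threshold=0.99, abs_thr=0, rel_thr=1.0 (the second test), almost
# any load difference forces balancing, and prefix affinity is rarely taken.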
import collections
import concurrent.futures
import time
import pytest
import requests
@pytest.mark.integration
def test_power_of_two_prefers_less_loaded(mock_workers, router_manager):
# Start one slow (higher inflight) and one fast worker via the fixture factory.
# The router monitors /get_load and power-of-two uses the cached loads to choose.
procs_slow, urls_slow, ids_slow = mock_workers(n=1, args=["--latency-ms", "200"])
procs_fast, urls_fast, ids_fast = mock_workers(n=1, args=["--latency-ms", "0"])
procs = procs_slow + procs_fast
urls = urls_slow + urls_fast
ids = ids_slow + ids_fast
slow_id = ids_slow[0]
slow_url = urls_slow[0]  # used below to apply direct background load on the slow worker
rh = router_manager.start_router(
worker_urls=urls,
policy="power_of_two",
extra={"worker_startup_check_interval": 1},
)
# Prime: fire a burst to create measurable load on slow worker, then wait for monitor tick
def _prime_call(i):
try:
requests.post(
f"{rh.url}/v1/completions",
json={
"model": "test-model",
"prompt": f"warm-{i}",
"max_tokens": 1,
"stream": False,
},
timeout=5,
)
except Exception:
pass
with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex:
list(ex.map(_prime_call, range(128)))
time.sleep(2)
# Apply direct background load on the slow worker to amplify load diff
def _direct_load(i):
try:
requests.post(
f"{slow_url}/v1/completions",
json={
"model": "test-model",
"prompt": f"bg-{i}",
"max_tokens": 1,
"stream": False,
},
timeout=5,
)
except Exception:
pass
with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex:
list(ex.map(_direct_load, range(128)))
time.sleep(1)
def call(i):
r = requests.post(
f"{rh.url}/v1/completions",
json={
"model": "test-model",
"prompt": f"p{i}",
"max_tokens": 1,
"stream": False,
},
timeout=5,
)
assert r.status_code == 200
return r.headers.get("X-Worker-Id") or r.json().get("worker_id")
counts = collections.Counter()
with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex:
for wid in ex.map(call, range(200)):
counts[wid] += 1
# Expect the slow worker (higher latency/inflight) to receive fewer requests
fast_worker_id = [i for i in ids if i != slow_id][0]
assert counts[slow_id] < counts[fast_worker_id], counts
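For reference, a minimal sketch of the power-of-two-choices rule this test relies on (a standard technique; the exact router implementation is assumed): sample two workers at random and route to the one whose cached load is lower.

import random

def power_of_two_pick(cached_loads: dict) -> str:
    # cached_loads: worker_id -> last load reported by /get_load
    a, b = random.sample(list(cached_loads), 2)
    return a if cached_loads[a] <= cached_loads[b] else b

# e.g. with {"slow": 12, "fast": 1}, "fast" wins whenever both are sampled,
# so the slow worker ends up with fewer requests, as asserted above.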
import collections
import pytest
import requests
@pytest.mark.integration
def test_random_distribution(mock_workers, router_manager):
procs, urls, ids = mock_workers(n=4)
rh = router_manager.start_router(worker_urls=urls, policy="random")
counts = collections.Counter()
N = 200
with requests.Session() as s:
for i in range(N):
r = s.post(
f"{rh.url}/v1/completions",
json={
"model": "test-model",
"prompt": f"p{i}",
"max_tokens": 1,
"stream": False,
},
)
assert r.status_code == 200
wid = r.headers.get("X-Worker-Id") or r.json().get("worker_id")
counts[wid] += 1
# simple statistical tolerance: each worker should be within ±50% of mean
mean = N / len(ids)
for wid in ids:
assert 0.5 * mean <= counts[wid] <= 1.5 * mean, counts
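Why the ±50% band is comfortably safe (illustrative reasoning, not part of the test): each request lands on a given one of the four workers independently, so per-worker counts are binomial.

import math

n, p = 200, 1 / 4                    # 200 requests, 4 equally likely workers
mean = n * p                         # 50 per worker on average
sigma = math.sqrt(n * p * (1 - p))   # ≈ 6.1 (binomial standard deviation)
# The asserted band [25, 75] is about ±4 sigma, so a false failure is very unlikely.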
import collections
import pytest
import requests
@pytest.mark.integration
def test_round_robin_distribution(mock_workers, router_manager):
procs, urls, ids = mock_workers(n=3)
rh = router_manager.start_router(worker_urls=urls, policy="round_robin")
counts = collections.Counter()
with requests.Session() as s:
for i in range(30):
r = s.post(
f"{rh.url}/v1/completions",
json={
"model": "test-model",
"prompt": f"hello {i}",
"max_tokens": 1,
"stream": False,
},
)
assert r.status_code == 200
wid = r.headers.get("X-Worker-Id") or r.json().get("worker_id")
assert wid in ids
counts[wid] += 1
# Expect near-even distribution across 3 workers
# 30 requests -> ideally 10 each; allow small tolerance ±3
for wid in ids:
assert 7 <= counts[wid] <= 13, counts
import pytest
import requests
@pytest.mark.integration
def test_router_api_key_enforcement(router_manager, mock_workers):
# Start backend requiring API key; router should forward Authorization header transparently
_, urls, _ = mock_workers(
n=1, args=["--require-api-key", "--api-key", "correct_api_key"]
)
rh = router_manager.start_router(
worker_urls=urls,
policy="round_robin",
extra={},
)
# No auth -> 401
r = requests.post(
f"{rh.url}/v1/completions",
json={"model": "test-model", "prompt": "x", "max_tokens": 1, "stream": False},
)
assert r.status_code == 401
# Invalid auth -> 401
r = requests.post(
f"{rh.url}/v1/completions",
json={"model": "test-model", "prompt": "x", "max_tokens": 1, "stream": False},
headers={"Authorization": "Bearer wrong"},
)
assert r.status_code == 401
# Correct auth -> 200
r = requests.post(
f"{rh.url}/v1/completions",
json={"model": "test-model", "prompt": "x", "max_tokens": 1, "stream": False},
headers={"Authorization": "Bearer correct_api_key"},
)
assert r.status_code == 200
import time
import pytest
import requests
@pytest.mark.integration
def test_circuit_breaker_opens_and_recovers(router_manager, mock_workers):
# A single worker that fails first 3 requests, then succeeds
_, [wurl], _ = mock_workers(n=1, args=["--fail-first-n", "3"]) # fails first 3
rh = router_manager.start_router(
worker_urls=[wurl],
policy="round_robin",
extra={
"cb_failure_threshold": 3,
"cb_success_threshold": 2,
"cb_timeout_duration_secs": 3,
"cb_window_duration_secs": 10,
"disable_retries": True,
},
)
def post_once():
return requests.post(
f"{rh.url}/v1/completions",
json={
"model": "test-model",
"prompt": "trigger",
"max_tokens": 1,
"stream": False,
},
timeout=3,
)
saw_503 = False
for _ in range(8):
r = post_once()
if r.status_code == 503:
saw_503 = True
break
assert saw_503, "circuit breaker did not open to return 503"
time.sleep(4)
r1 = post_once()
r2 = post_once()
assert r1.status_code == 200 and r2.status_code == 200
@pytest.mark.integration
def test_circuit_breaker_half_open_failure_reopens(router_manager, mock_workers):
_, [wurl], _ = mock_workers(n=1, args=["--status-code", "500"]) # always fail
rh = router_manager.start_router(
worker_urls=[wurl],
policy="round_robin",
extra={
"cb_failure_threshold": 2,
"cb_success_threshold": 2,
"cb_timeout_duration_secs": 2,
"cb_window_duration_secs": 5,
"disable_retries": True,
},
)
def post_once():
return requests.post(
f"{rh.url}/v1/completions",
json={
"model": "test-model",
"prompt": "x",
"max_tokens": 1,
"stream": False,
},
timeout=3,
)
opened = False
for _ in range(8):
r = post_once()
if r.status_code == 503:
opened = True
break
assert opened, "circuit breaker did not open"
time.sleep(3)
r = post_once()
assert r.status_code == 500
r2 = post_once()
assert r2.status_code == 503
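The two tests above trace the circuit-breaker lifecycle end to end. A minimal state-machine sketch of the presumed semantics follows (an assumption mirroring the knob names, not the router's actual code).

import time

class CircuitBreakerSketch:
    # closed -> open after cb_failure_threshold failures (router serves 503);
    # open -> half_open after cb_timeout_duration_secs (probes pass through);
    # half_open -> closed after cb_success_threshold successes, or straight
    # back to open on any failure, which is what the second test asserts.
    def __init__(self, failure_threshold, success_threshold, timeout_secs):
        self.failure_threshold = failure_threshold
        self.success_threshold = success_threshold
        self.timeout_secs = timeout_secs
        self.state, self.failures, self.successes = "closed", 0, 0
        self.opened_at = 0.0

    def allow(self) -> bool:
        if self.state == "open" and time.time() - self.opened_at >= self.timeout_secs:
            self.state, self.successes = "half_open", 0
        return self.state != "open"  # open -> reject with 503

    def record(self, ok: bool) -> None:
        if ok:
            self.failures = 0
            if self.state == "half_open":
                self.successes += 1
                if self.successes >= self.success_threshold:
                    self.state = "closed"
        else:
            self.failures += 1
            if self.state == "half_open" or self.failures >= self.failure_threshold:
                self.state, self.opened_at = "open", time.time()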
@pytest.mark.integration
def test_circuit_breaker_disable_flag(router_manager, mock_workers):
_, [wurl], _ = mock_workers(n=1, args=["--status-code", "500"]) # always fail
rh = router_manager.start_router(
worker_urls=[wurl],
policy="round_robin",
extra={
"disable_circuit_breaker": True,
"disable_retries": True,
},
)
r = requests.post(
f"{rh.url}/v1/completions",
json={
"model": "test-model",
"prompt": "x",
"max_tokens": 1,
"stream": False,
},
timeout=3,
)
assert r.status_code == 500
@pytest.mark.integration
def test_circuit_breaker_per_worker_isolation(router_manager, mock_workers):
_, [fail_url], _ = mock_workers(n=1, args=["--status-code", "500"]) # always fail
_, [ok_url], _ = mock_workers(n=1)
rh = router_manager.start_router(
worker_urls=[fail_url, ok_url],
policy="round_robin",
extra={
"cb_failure_threshold": 2,
"cb_success_threshold": 1,
"cb_timeout_duration_secs": 2,
"cb_window_duration_secs": 10,
"disable_retries": True,
},
)
def post_once():
return requests.post(
f"{rh.url}/v1/completions",
json={
"model": "test-model",
"prompt": "y",
"max_tokens": 1,
"stream": False,
},
timeout=3,
)
failures = 0
successes_after_open = 0
opened = False
for _ in range(30):
r = post_once()
if not opened:
if r.status_code == 500:
failures += 1
if failures >= 2:
_ = post_once()
_ = post_once()
opened = True
else:
if r.status_code == 200:
successes_after_open += 1
else:
assert False, f"Unexpected non-200 after CB open: {r.status_code}"
assert opened and successes_after_open >= 5
@pytest.mark.integration
def test_circuit_breaker_with_retries(router_manager, mock_workers):
_, [fail_url], _ = mock_workers(n=1, args=["--status-code", "500"]) # always fail
_, [ok_url], _ = mock_workers(n=1)
rh = router_manager.start_router(
worker_urls=[fail_url, ok_url],
policy="round_robin",
extra={
"retry_max_retries": 3,
"retry_initial_backoff_ms": 10,
"retry_max_backoff_ms": 50,
"cb_failure_threshold": 2,
"cb_success_threshold": 1,
"cb_timeout_duration_secs": 2,
"cb_window_duration_secs": 10,
},
)
r = requests.post(
f"{rh.url}/v1/completions",
json={
"model": "test-model",
"prompt": "z",
"max_tokens": 1,
"stream": False,
},
timeout=5,
)
assert r.status_code == 200
import pytest
import requests
@pytest.mark.integration
def test_worker_crash_reroute_with_retries(router_manager, mock_workers):
# Start one healthy and one that will crash on first request
_, [ok_url], _ = mock_workers(n=1)
_, [crash_url], _ = mock_workers(n=1, args=["--crash-on-request"])
rh = router_manager.start_router(
worker_urls=[crash_url, ok_url],
policy="round_robin",
extra={
"retry_max_retries": 3,
"retry_initial_backoff_ms": 10,
"retry_max_backoff_ms": 50,
},
)
# A single request should succeed via retry to the healthy worker
r = requests.post(
f"{rh.url}/v1/completions",
json={
"model": "test-model",
"prompt": "crash",
"max_tokens": 1,
"stream": False,
},
timeout=5,
)
assert r.status_code == 200
# mock_workers fixture handles cleanup
import collections
import concurrent.futures
import time
import pytest
import requests
@pytest.mark.integration
def test_pd_power_of_two_decode_attribution(router_manager, mock_workers):
# Start two prefill and three decode mock workers via fixture
_, prefill_urls_raw, prefill_ids = mock_workers(n=2)
_, decode_urls_raw, decode_ids_list = mock_workers(n=3)
prefill_urls = [(u, None) for u in prefill_urls_raw]
decode_urls = list(decode_urls_raw)
decode_ids = set(decode_ids_list)
rh = router_manager.start_router(
policy="power_of_two",
pd_disaggregation=True,
prefill_urls=prefill_urls,
decode_urls=decode_urls,
extra={"worker_startup_check_interval": 1},
)
counts = collections.Counter()
with requests.Session() as s:
for i in range(30):
r = s.post(
f"{rh.url}/v1/completions",
json={
"model": "test-model",
"prompt": f"p{i}",
"max_tokens": 1,
"stream": False,
},
)
assert r.status_code == 200
wid = r.headers.get("X-Worker-Id") or r.json().get("worker_id")
assert wid in decode_ids
counts[wid] += 1
assert sum(1 for v in counts.values() if v > 0) >= 2
@pytest.mark.integration
def test_pd_power_of_two_skews_to_faster_decode(router_manager, mock_workers):
# Start two prefill workers (fast)
_, prefill_urls_raw, _ = mock_workers(n=2)
# Start two decode workers: one slow, one fast
_, [decode_slow_url], [slow_id] = mock_workers(
n=1, args=["--latency-ms", "300"]
) # slower decode
_, [decode_fast_url], [fast_id] = mock_workers(n=1)
decode_urls_raw = [decode_slow_url, decode_fast_url]
prefill_urls = [(u, None) for u in prefill_urls_raw]
decode_urls = list(decode_urls_raw)
rh = router_manager.start_router(
policy="power_of_two",
pd_disaggregation=True,
prefill_urls=prefill_urls,
decode_urls=decode_urls,
extra={"worker_startup_check_interval": 1},
)
def _prime_call(i):
try:
requests.post(
f"{rh.url}/v1/completions",
json={
"model": "test-model",
"prompt": f"warm-{i}",
"max_tokens": 1,
"stream": False,
},
timeout=8,
)
except Exception:
pass
with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex:
list(ex.map(_prime_call, range(128)))
time.sleep(2)
def _direct_decode_load(i):
try:
requests.post(
f"{decode_slow_url}/v1/completions",
json={
"model": "test-model",
"prompt": f"bg-{i}",
"max_tokens": 1,
"stream": False,
},
timeout=8,
)
except Exception:
pass
with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex:
list(ex.map(_direct_decode_load, range(128)))
time.sleep(1)
def call(i):
r = requests.post(
f"{rh.url}/v1/completions",
json={
"model": "test-model",
"prompt": f"p{i}",
"max_tokens": 1,
"stream": False,
},
timeout=8,
)
assert r.status_code == 200
return r.headers.get("X-Worker-Id") or r.json().get("worker_id")
counts = collections.Counter()
with concurrent.futures.ThreadPoolExecutor(max_workers=32) as ex:
for wid in ex.map(call, range(200)):
counts[wid] += 1
assert counts[slow_id] < counts[fast_id], counts
import concurrent.futures
import pytest
import requests
@pytest.mark.integration
def test_rate_limit_and_queue(router_manager, mock_workers):
# One fast backend
_, urls, _ = mock_workers(n=1)
rh = router_manager.start_router(
worker_urls=urls,
policy="round_robin",
extra={
"max_concurrent_requests": 2,
"queue_size": 0, # no queue -> immediate 429 when limit exceeded
},
)
def call_once(i):
try:
r = requests.post(
f"{rh.url}/v1/completions",
json={
"model": "test-model",
"prompt": f"p{i}",
"max_tokens": 1,
"stream": False,
},
timeout=3,
)
return r.status_code
except Exception:
return 599
# Fire a burst of concurrent requests
with concurrent.futures.ThreadPoolExecutor(max_workers=16) as ex:
results = list(ex.map(call_once, range(16)))
# Expect some to succeed and some to be rate limited (429)
assert any(code == 200 for code in results)
assert any(code == 429 for code in results)
@pytest.mark.integration
def test_rate_limit_queue_and_timeout(router_manager, mock_workers):
# Slow backend: ~2s per request ensures queue wait > timeout
_, urls, _ = mock_workers(n=1, args=["--latency-ms", "2000"]) # 2.0s per request
# Allow 1 concurrent, queue up to 1, with 1s queue timeout
rh = router_manager.start_router(
worker_urls=urls,
policy="round_robin",
extra={
"max_concurrent_requests": 1,
"queue_size": 1,
"queue_timeout_secs": 1,
},
)
def call_once(i):
try:
r = requests.post(
f"{rh.url}/v1/completions",
json={
"model": "test-model",
"prompt": f"q{i}",
"max_tokens": 1,
"stream": False,
},
timeout=5,
)
return r.status_code
except Exception:
return 599
# Fire 4 concurrent requests: 1 runs (~2s), 1 queued (times out at 1s -> 408), 2 overflow -> 429
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as ex:
results = list(ex.map(call_once, range(4)))
# We expect:
# - Some 200s (processed)
# - At least one 408 (queued too long and timed out)
# - Remaining non-200s are either 429 (queue overflow) or additional 408s depending on scheduling
assert any(code == 200 for code in results)
assert any(code == 408 for code in results), results
non200 = [c for c in results if c != 200]
assert len(non200) >= 2 and all(c in (408, 429) for c in non200), results
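A minimal sketch of the admission control these two tests expect (assumed semantics matching the knob names, not the router's actual code): a concurrency cap plus a bounded wait queue with a timeout.

import threading

class AdmissionSketch:
    def __init__(self, max_concurrent: int, queue_size: int, queue_timeout: float):
        self.slots = threading.Semaphore(max_concurrent)  # running requests
        self.queue = threading.Semaphore(queue_size) if queue_size > 0 else None
        self.timeout = queue_timeout

    def admit(self) -> int:
        if self.slots.acquire(blocking=False):
            return 200                 # free slot: run now (caller releases when done)
        if self.queue is None or not self.queue.acquire(blocking=False):
            return 429                 # queue full (or queue_size 0): reject at once
        try:
            ok = self.slots.acquire(timeout=self.timeout)
            return 200 if ok else 408  # waited in the queue longer than the timeout
        finally:
            self.queue.release()

# With max_concurrent=2, queue_size=0 a 16-request burst yields a mix of 200/429;
# with max_concurrent=1, queue_size=1, timeout=1s against a 2s backend, four
# concurrent requests split into one 200, one 408, and 429s, as asserted above.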
import pytest
import requests
@pytest.mark.integration
def test_retry_reroutes_to_healthy_worker(router_manager, mock_workers):
# Worker A always returns 500; workers B and C are healthy
_, [url_a], [id_a] = mock_workers(n=1, args=["--status-code", "500"]) # fail
_, [url_b], [id_b] = mock_workers(n=1)
_, [url_c], [id_c] = mock_workers(n=1)
rh = router_manager.start_router(
worker_urls=[url_a, url_b, url_c],
policy="round_robin",
extra={
"retry_max_retries": 3,
"retry_initial_backoff_ms": 10,
"retry_max_backoff_ms": 50,
},
)
r = requests.post(
f"{rh.url}/v1/completions",
json={
"model": "test-model",
"prompt": "x",
"max_tokens": 1,
"stream": False,
},
timeout=5,
)
assert r.status_code == 200
wid = r.headers.get("X-Worker-Id") or r.json().get("worker_id")
assert wid == id_b # should have retried onto healthy worker
# mock_workers fixture handles cleanup
@pytest.mark.integration
def test_disable_retries_surfaces_failure(router_manager, mock_workers):
# Single failing worker, retries disabled -> should return 500
_, [url], [wid] = mock_workers(n=1, args=["--status-code", "500"]) # always fail
rh = router_manager.start_router(
worker_urls=[url],
policy="round_robin",
extra={
"disable_retries": True,
},
)
r = requests.post(
f"{rh.url}/v1/completions",
json={
"model": "test-model",
"prompt": "x",
"max_tokens": 1,
"stream": False,
},
timeout=5,
)
assert r.status_code == 500
# mock_workers fixture handles cleanup
import pytest
import requests
@pytest.mark.integration
def test_discovery_shim_add_remove(router_manager, mock_workers):
# Start router without workers
rh = router_manager.start_router(worker_urls=[], policy="round_robin")
# Initially empty
urls = router_manager.list_workers(rh.url)
assert urls == []
# Add a worker (simulate discovery event)
_, [wurl], [wid] = mock_workers(n=1)
router_manager.add_worker(rh.url, wurl)
urls = router_manager.list_workers(rh.url)
assert wurl in urls
# Can serve a request
r = requests.post(
f"{rh.url}/v1/completions",
json={
"model": "test-model",
"prompt": "hi",
"max_tokens": 1,
"stream": False,
},
)
assert r.status_code == 200
# Remove worker (simulate pod deletion)
router_manager.remove_worker(rh.url, wurl)
urls = router_manager.list_workers(rh.url)
assert wurl not in urls
# mock_workers fixture handles cleanup