Unverified Commit 045ab92d authored by Keyang Ru's avatar Keyang Ru Committed by GitHub
Browse files

[router] add py binding unit tests to coverage 80% (#10043)

parent bd7f8821
......@@ -39,7 +39,7 @@ jobs:
cd sgl-router/
cargo fmt -- --check
- name: Run test
- name: Run Rust tests
timeout-minutes: 20
run: |
source "$HOME/.cargo/env"
......@@ -83,6 +83,15 @@ jobs:
pip install setuptools-rust wheel build
python3 -m build
pip install --force-reinstall dist/*.whl
- name: Run Python unit tests
run: |
cd sgl-router
source "$HOME/.cargo/env"
pip install pytest pytest-cov pytest-xdist
pytest -q py_test/unit
- name: Run e2e test
run: |
bash scripts/killall_sglang.sh "nuk_gpus"
......
[run]
source = py_src/sglang_router
omit =
py_src/sglang_router/mini_lb.py
[report]
fail_under = 80
omit =
py_src/sglang_router/mini_lb.py
import sys
from pathlib import Path
# Ensure local sources in py_src are importable ahead of any installed package
_ROOT = Path(__file__).resolve().parents[1]
_SRC = _ROOT / "py_src"
if str(_SRC) not in sys.path:
sys.path.insert(0, str(_SRC))
"""
Unit tests for sglang_router.
This package contains fast, isolated unit tests for Python components
of the SGLang router. These tests focus on testing individual functions
and classes in isolation without starting actual router instances.
"""
"""
Unit tests for argument parsing functionality in sglang_router.
These tests focus on testing the argument parsing logic in isolation,
without starting actual router instances.
"""
import argparse
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from sglang_router.launch_router import RouterArgs, parse_router_args
from sglang_router.router import policy_from_str
class TestRouterArgs:
"""Test RouterArgs dataclass and its methods."""
def test_default_values(self):
"""Test that RouterArgs has correct default values."""
args = RouterArgs()
# Test basic defaults
assert args.host == "127.0.0.1"
assert args.port == 30000
assert args.policy == "cache_aware"
assert args.worker_urls == []
assert args.pd_disaggregation is False
assert args.prefill_urls == []
assert args.decode_urls == []
# Test PD-specific defaults
assert args.prefill_policy is None
assert args.decode_policy is None
# Test service discovery defaults
assert args.service_discovery is False
assert args.selector == {}
assert args.service_discovery_port == 80
assert args.service_discovery_namespace is None
# Test retry and circuit breaker defaults
assert args.retry_max_retries == 5
assert args.cb_failure_threshold == 10
assert args.disable_retries is False
assert args.disable_circuit_breaker is False
def test_parse_selector_valid(self):
"""Test parsing valid selector arguments."""
# Test single key-value pair
result = RouterArgs._parse_selector(["app=worker"])
assert result == {"app": "worker"}
# Test multiple key-value pairs
result = RouterArgs._parse_selector(["app=worker", "env=prod", "version=v1"])
assert result == {"app": "worker", "env": "prod", "version": "v1"}
# Test empty list
result = RouterArgs._parse_selector([])
assert result == {}
# Test None
result = RouterArgs._parse_selector(None)
assert result == {}
def test_parse_selector_invalid(self):
"""Test parsing invalid selector arguments."""
# Test malformed selector (no equals sign)
result = RouterArgs._parse_selector(["app"])
assert result == {}
# Test multiple equals signs (should use first one)
result = RouterArgs._parse_selector(["app=worker=extra"])
assert result == {"app": "worker=extra"}
def test_parse_prefill_urls_valid(self):
"""Test parsing valid prefill URL arguments."""
# Test with bootstrap port
result = RouterArgs._parse_prefill_urls([["http://prefill1:8000", "9000"]])
assert result == [("http://prefill1:8000", 9000)]
# Test with 'none' bootstrap port
result = RouterArgs._parse_prefill_urls([["http://prefill1:8000", "none"]])
assert result == [("http://prefill1:8000", None)]
# Test without bootstrap port
result = RouterArgs._parse_prefill_urls([["http://prefill1:8000"]])
assert result == [("http://prefill1:8000", None)]
# Test multiple prefill URLs
result = RouterArgs._parse_prefill_urls(
[
["http://prefill1:8000", "9000"],
["http://prefill2:8000", "none"],
["http://prefill3:8000"],
]
)
expected = [
("http://prefill1:8000", 9000),
("http://prefill2:8000", None),
("http://prefill3:8000", None),
]
assert result == expected
# Test empty list
result = RouterArgs._parse_prefill_urls([])
assert result == []
# Test None
result = RouterArgs._parse_prefill_urls(None)
assert result == []
def test_parse_prefill_urls_invalid(self):
"""Test parsing invalid prefill URL arguments."""
# Test invalid bootstrap port
with pytest.raises(ValueError, match="Invalid bootstrap port"):
RouterArgs._parse_prefill_urls([["http://prefill1:8000", "invalid"]])
def test_parse_decode_urls_valid(self):
"""Test parsing valid decode URL arguments."""
# Test single decode URL
result = RouterArgs._parse_decode_urls([["http://decode1:8001"]])
assert result == ["http://decode1:8001"]
# Test multiple decode URLs
result = RouterArgs._parse_decode_urls(
[["http://decode1:8001"], ["http://decode2:8001"]]
)
assert result == ["http://decode1:8001", "http://decode2:8001"]
# Test empty list
result = RouterArgs._parse_decode_urls([])
assert result == []
# Test None
result = RouterArgs._parse_decode_urls(None)
assert result == []
def test_from_cli_args_basic(self):
"""Test creating RouterArgs from basic CLI arguments."""
args = SimpleNamespace(
host="0.0.0.0",
port=30001,
worker_urls=["http://worker1:8000", "http://worker2:8000"],
policy="round_robin",
prefill=None,
decode=None,
router_policy="round_robin",
router_pd_disaggregation=False,
router_prefill_policy=None,
router_decode_policy=None,
router_worker_startup_timeout_secs=300,
router_worker_startup_check_interval=15,
router_cache_threshold=0.7,
router_balance_abs_threshold=128,
router_balance_rel_threshold=2.0,
router_eviction_interval=180,
router_max_tree_size=2**28,
router_max_payload_size=1024 * 1024 * 1024, # 1GB
router_dp_aware=True,
router_api_key="test-key",
router_log_dir="/tmp/logs",
router_log_level="debug",
router_service_discovery=True,
router_selector=["app=worker", "env=test"],
router_service_discovery_port=8080,
router_service_discovery_namespace="default",
router_prefill_selector=["app=prefill"],
router_decode_selector=["app=decode"],
router_prometheus_port=29000,
router_prometheus_host="0.0.0.0",
router_request_id_headers=["x-request-id", "x-trace-id"],
router_request_timeout_secs=1200,
router_max_concurrent_requests=512,
router_queue_size=200,
router_queue_timeout_secs=120,
router_rate_limit_tokens_per_second=100,
router_cors_allowed_origins=["http://localhost:3000"],
router_retry_max_retries=3,
router_retry_initial_backoff_ms=100,
router_retry_max_backoff_ms=10000,
router_retry_backoff_multiplier=2.0,
router_retry_jitter_factor=0.1,
router_cb_failure_threshold=5,
router_cb_success_threshold=2,
router_cb_timeout_duration_secs=30,
router_cb_window_duration_secs=60,
router_disable_retries=False,
router_disable_circuit_breaker=False,
router_health_failure_threshold=2,
router_health_success_threshold=1,
router_health_check_timeout_secs=3,
router_health_check_interval_secs=30,
router_health_check_endpoint="/healthz",
)
router_args = RouterArgs.from_cli_args(args, use_router_prefix=True)
# Test basic configuration
assert router_args.host == "0.0.0.0"
assert router_args.port == 30001
assert router_args.worker_urls == ["http://worker1:8000", "http://worker2:8000"]
assert router_args.policy == "round_robin"
# Test PD configuration
assert router_args.pd_disaggregation is False
assert router_args.prefill_urls == []
assert router_args.decode_urls == []
# Test service discovery
assert router_args.service_discovery is True
assert router_args.selector == {"app": "worker", "env": "test"}
assert router_args.service_discovery_port == 8080
assert router_args.service_discovery_namespace == "default"
assert router_args.prefill_selector == {"app": "prefill"}
assert router_args.decode_selector == {"app": "decode"}
# Test other configurations
assert router_args.dp_aware is True
assert router_args.api_key == "test-key"
assert router_args.log_dir == "/tmp/logs"
assert router_args.log_level == "debug"
assert router_args.prometheus_port == 29000
assert router_args.prometheus_host == "0.0.0.0"
assert router_args.request_id_headers == ["x-request-id", "x-trace-id"]
assert router_args.request_timeout_secs == 1200
assert router_args.max_concurrent_requests == 512
assert router_args.queue_size == 200
assert router_args.queue_timeout_secs == 120
assert router_args.rate_limit_tokens_per_second == 100
assert router_args.cors_allowed_origins == ["http://localhost:3000"]
# Test retry configuration
assert router_args.retry_max_retries == 3
assert router_args.retry_initial_backoff_ms == 100
assert router_args.retry_max_backoff_ms == 10000
assert router_args.retry_backoff_multiplier == 2.0
assert router_args.retry_jitter_factor == 0.1
# Test circuit breaker configuration
assert router_args.cb_failure_threshold == 5
assert router_args.cb_success_threshold == 2
assert router_args.cb_timeout_duration_secs == 30
assert router_args.cb_window_duration_secs == 60
assert router_args.disable_retries is False
assert router_args.disable_circuit_breaker is False
# Test health check configuration
assert router_args.health_failure_threshold == 2
assert router_args.health_success_threshold == 1
assert router_args.health_check_timeout_secs == 3
assert router_args.health_check_interval_secs == 30
assert router_args.health_check_endpoint == "/healthz"
# Note: model_path and tokenizer_path are not available in current RouterArgs
def test_from_cli_args_pd_mode(self):
"""Test creating RouterArgs from CLI arguments in PD mode."""
args = SimpleNamespace(
host="127.0.0.1",
port=30000,
worker_urls=[],
policy="cache_aware",
prefill=[
["http://prefill1:8000", "9000"],
["http://prefill2:8000", "none"],
],
decode=[["http://decode1:8001"], ["http://decode2:8001"]],
router_prefill=[
["http://prefill1:8000", "9000"],
["http://prefill2:8000", "none"],
],
router_decode=[["http://decode1:8001"], ["http://decode2:8001"]],
router_policy="cache_aware",
router_pd_disaggregation=True,
router_prefill_policy="power_of_two",
router_decode_policy="round_robin",
# Include all required fields with defaults
router_worker_startup_timeout_secs=600,
router_worker_startup_check_interval=30,
router_cache_threshold=0.3,
router_balance_abs_threshold=64,
router_balance_rel_threshold=1.5,
router_eviction_interval=120,
router_max_tree_size=2**26,
router_max_payload_size=512 * 1024 * 1024,
router_dp_aware=False,
router_api_key=None,
router_log_dir=None,
router_log_level=None,
router_service_discovery=False,
router_selector=None,
router_service_discovery_port=80,
router_service_discovery_namespace=None,
router_prefill_selector=None,
router_decode_selector=None,
router_prometheus_port=None,
router_prometheus_host=None,
router_request_id_headers=None,
router_request_timeout_secs=1800,
router_max_concurrent_requests=256,
router_queue_size=100,
router_queue_timeout_secs=60,
router_rate_limit_tokens_per_second=None,
router_cors_allowed_origins=[],
router_retry_max_retries=5,
router_retry_initial_backoff_ms=50,
router_retry_max_backoff_ms=30000,
router_retry_backoff_multiplier=1.5,
router_retry_jitter_factor=0.2,
router_cb_failure_threshold=10,
router_cb_success_threshold=3,
router_cb_timeout_duration_secs=60,
router_cb_window_duration_secs=120,
router_disable_retries=False,
router_disable_circuit_breaker=False,
router_health_failure_threshold=3,
router_health_success_threshold=2,
router_health_check_timeout_secs=5,
router_health_check_interval_secs=60,
router_health_check_endpoint="/health",
)
router_args = RouterArgs.from_cli_args(args, use_router_prefix=True)
# Test PD configuration
assert router_args.pd_disaggregation is True
assert router_args.prefill_urls == [
("http://prefill1:8000", 9000),
("http://prefill2:8000", None),
]
assert router_args.decode_urls == ["http://decode1:8001", "http://decode2:8001"]
assert router_args.prefill_policy == "power_of_two"
assert router_args.decode_policy == "round_robin"
assert router_args.policy == "cache_aware" # Main policy still set
def test_from_cli_args_without_prefix(self):
"""Test creating RouterArgs from CLI arguments without router prefix."""
args = SimpleNamespace(
host="127.0.0.1",
port=30000,
worker_urls=["http://worker1:8000"],
policy="random",
prefill=None,
decode=None,
pd_disaggregation=False,
prefill_policy=None,
decode_policy=None,
worker_startup_timeout_secs=600,
worker_startup_check_interval=30,
cache_threshold=0.3,
balance_abs_threshold=64,
balance_rel_threshold=1.5,
eviction_interval=120,
max_tree_size=2**26,
max_payload_size=512 * 1024 * 1024,
dp_aware=False,
api_key=None,
log_dir=None,
log_level=None,
service_discovery=False,
selector=None,
service_discovery_port=80,
service_discovery_namespace=None,
prefill_selector=None,
decode_selector=None,
prometheus_port=None,
prometheus_host=None,
request_id_headers=None,
request_timeout_secs=1800,
max_concurrent_requests=256,
queue_size=100,
queue_timeout_secs=60,
rate_limit_tokens_per_second=None,
cors_allowed_origins=[],
retry_max_retries=5,
retry_initial_backoff_ms=50,
retry_max_backoff_ms=30000,
retry_backoff_multiplier=1.5,
retry_jitter_factor=0.2,
cb_failure_threshold=10,
cb_success_threshold=3,
cb_timeout_duration_secs=60,
cb_window_duration_secs=120,
disable_retries=False,
disable_circuit_breaker=False,
health_failure_threshold=3,
health_success_threshold=2,
health_check_timeout_secs=5,
health_check_interval_secs=60,
health_check_endpoint="/health",
model_path=None,
tokenizer_path=None,
)
router_args = RouterArgs.from_cli_args(args, use_router_prefix=False)
assert router_args.host == "127.0.0.1"
assert router_args.port == 30000
assert router_args.worker_urls == ["http://worker1:8000"]
assert router_args.policy == "random"
assert router_args.pd_disaggregation is False
class TestPolicyFromStr:
"""Test policy string to enum conversion."""
def test_valid_policies(self):
"""Test conversion of valid policy strings."""
from sglang_router_rs import PolicyType
assert policy_from_str("random") == PolicyType.Random
assert policy_from_str("round_robin") == PolicyType.RoundRobin
assert policy_from_str("cache_aware") == PolicyType.CacheAware
assert policy_from_str("power_of_two") == PolicyType.PowerOfTwo
def test_invalid_policy(self):
"""Test conversion of invalid policy string."""
with pytest.raises(KeyError):
policy_from_str("invalid_policy")
class TestParseRouterArgs:
"""Test the parse_router_args function."""
def test_parse_basic_args(self):
"""Test parsing basic router arguments."""
args = [
"--host",
"0.0.0.0",
"--port",
"30001",
"--worker-urls",
"http://worker1:8000",
"http://worker2:8000",
"--policy",
"round_robin",
]
router_args = parse_router_args(args)
assert router_args.host == "0.0.0.0"
assert router_args.port == 30001
assert router_args.worker_urls == ["http://worker1:8000", "http://worker2:8000"]
assert router_args.policy == "round_robin"
def test_parse_pd_args(self):
"""Test parsing PD disaggregated mode arguments."""
args = [
"--pd-disaggregation",
"--prefill",
"http://prefill1:8000",
"9000",
"--prefill",
"http://prefill2:8000",
"none",
"--decode",
"http://decode1:8001",
"--decode",
"http://decode2:8001",
"--prefill-policy",
"power_of_two",
"--decode-policy",
"round_robin",
]
router_args = parse_router_args(args)
assert router_args.pd_disaggregation is True
assert router_args.prefill_urls == [
("http://prefill1:8000", 9000),
("http://prefill2:8000", None),
]
assert router_args.decode_urls == ["http://decode1:8001", "http://decode2:8001"]
assert router_args.prefill_policy == "power_of_two"
assert router_args.decode_policy == "round_robin"
def test_parse_service_discovery_args(self):
"""Test parsing service discovery arguments."""
args = [
"--service-discovery",
"--selector",
"app=worker",
"env=prod",
"--service-discovery-port",
"8080",
"--service-discovery-namespace",
"default",
]
router_args = parse_router_args(args)
assert router_args.service_discovery is True
assert router_args.selector == {"app": "worker", "env": "prod"}
assert router_args.service_discovery_port == 8080
assert router_args.service_discovery_namespace == "default"
def test_parse_retry_and_circuit_breaker_args(self):
"""Test parsing retry and circuit breaker arguments."""
args = [
"--retry-max-retries",
"3",
"--retry-initial-backoff-ms",
"100",
"--retry-max-backoff-ms",
"10000",
"--retry-backoff-multiplier",
"2.0",
"--retry-jitter-factor",
"0.1",
"--disable-retries",
"--cb-failure-threshold",
"5",
"--cb-success-threshold",
"2",
"--cb-timeout-duration-secs",
"30",
"--cb-window-duration-secs",
"60",
"--disable-circuit-breaker",
]
router_args = parse_router_args(args)
# Test retry configuration
assert router_args.retry_max_retries == 3
assert router_args.retry_initial_backoff_ms == 100
assert router_args.retry_max_backoff_ms == 10000
assert router_args.retry_backoff_multiplier == 2.0
assert router_args.retry_jitter_factor == 0.1
assert router_args.disable_retries is True
# Test circuit breaker configuration
assert router_args.cb_failure_threshold == 5
assert router_args.cb_success_threshold == 2
assert router_args.cb_timeout_duration_secs == 30
assert router_args.cb_window_duration_secs == 60
assert router_args.disable_circuit_breaker is True
def test_parse_rate_limiting_args(self):
"""Test parsing rate limiting arguments."""
args = [
"--max-concurrent-requests",
"512",
"--queue-size",
"200",
"--queue-timeout-secs",
"120",
"--rate-limit-tokens-per-second",
"100",
]
router_args = parse_router_args(args)
assert router_args.max_concurrent_requests == 512
assert router_args.queue_size == 200
assert router_args.queue_timeout_secs == 120
assert router_args.rate_limit_tokens_per_second == 100
def test_parse_health_check_args(self):
"""Test parsing health check arguments."""
args = [
"--health-failure-threshold",
"2",
"--health-success-threshold",
"1",
"--health-check-timeout-secs",
"3",
"--health-check-interval-secs",
"30",
"--health-check-endpoint",
"/healthz",
]
router_args = parse_router_args(args)
assert router_args.health_failure_threshold == 2
assert router_args.health_success_threshold == 1
assert router_args.health_check_timeout_secs == 3
assert router_args.health_check_interval_secs == 30
assert router_args.health_check_endpoint == "/healthz"
def test_parse_cors_args(self):
"""Test parsing CORS arguments."""
args = [
"--cors-allowed-origins",
"http://localhost:3000",
"https://example.com",
]
router_args = parse_router_args(args)
assert router_args.cors_allowed_origins == [
"http://localhost:3000",
"https://example.com",
]
def test_parse_tokenizer_args(self):
"""Test parsing tokenizer arguments."""
# Note: model-path and tokenizer-path arguments are not available in current implementation
# This test is skipped until those arguments are added
pytest.skip("Tokenizer arguments not available in current implementation")
def test_parse_invalid_args(self):
"""Test parsing invalid arguments."""
# Test invalid policy
with pytest.raises(SystemExit):
parse_router_args(["--policy", "invalid_policy"])
# Test invalid bootstrap port
with pytest.raises(ValueError, match="Invalid bootstrap port"):
parse_router_args(
[
"--pd-disaggregation",
"--prefill",
"http://prefill1:8000",
"invalid_port",
]
)
def test_help_output(self):
"""Test that help output is generated correctly."""
with pytest.raises(SystemExit) as exc_info:
parse_router_args(["--help"])
# SystemExit with code 0 indicates help was displayed
assert exc_info.value.code == 0
"""
Unit tests for router configuration validation and setup.
These tests focus on testing the router configuration logic in isolation,
including validation of configuration parameters and their interactions.
"""
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from sglang_router.launch_router import RouterArgs, launch_router
from sglang_router.router import policy_from_str
from sglang_router_rs import PolicyType
class TestRouterConfigValidation:
"""Test router configuration validation logic."""
def test_valid_basic_config(self):
"""Test that a valid basic configuration passes validation."""
args = RouterArgs(
host="127.0.0.1",
port=30000,
worker_urls=["http://worker1:8000", "http://worker2:8000"],
policy="cache_aware",
)
# Should not raise any exceptions
assert args.host == "127.0.0.1"
assert args.port == 30000
assert args.worker_urls == ["http://worker1:8000", "http://worker2:8000"]
assert args.policy == "cache_aware"
def test_valid_pd_config(self):
"""Test that a valid PD configuration passes validation."""
args = RouterArgs(
host="127.0.0.1",
port=30000,
pd_disaggregation=True,
prefill_urls=[
("http://prefill1:8000", 9000),
("http://prefill2:8000", None),
],
decode_urls=["http://decode1:8001", "http://decode2:8001"],
policy="cache_aware",
)
assert args.pd_disaggregation is True
assert args.prefill_urls == [
("http://prefill1:8000", 9000),
("http://prefill2:8000", None),
]
assert args.decode_urls == ["http://decode1:8001", "http://decode2:8001"]
assert args.policy == "cache_aware"
def test_pd_config_without_urls_raises_error(self):
"""Test that PD mode without URLs raises validation error."""
args = RouterArgs(
pd_disaggregation=True,
prefill_urls=[],
decode_urls=[],
service_discovery=False,
)
# This should raise an error when trying to launch
with pytest.raises(
ValueError, match="PD disaggregation mode requires --prefill"
):
launch_router(args)
def test_pd_config_with_service_discovery_allows_empty_urls(self):
"""Test that PD mode with service discovery allows empty URLs."""
args = RouterArgs(
pd_disaggregation=True,
prefill_urls=[],
decode_urls=[],
service_discovery=True,
)
# Should not raise validation error when service discovery is enabled
with patch("sglang_router.launch_router.Router") as router_mod:
mock_router_instance = MagicMock()
router_mod.from_args = MagicMock(return_value=mock_router_instance)
launch_router(args)
# Should create router instance via from_args
router_mod.from_args.assert_called_once()
def test_regular_mode_without_workers_allows_empty_urls(self):
"""Test that regular mode allows empty worker URLs."""
args = RouterArgs(worker_urls=[], service_discovery=False)
# Should not raise validation error
with patch("sglang_router.launch_router.Router") as router_mod:
mock_router_instance = MagicMock()
router_mod.from_args = MagicMock(return_value=mock_router_instance)
launch_router(args)
# Should create router instance via from_args
router_mod.from_args.assert_called_once()
def test_cache_threshold_validation(self):
"""Test cache threshold validation."""
# Valid cache threshold
args = RouterArgs(cache_threshold=0.5)
assert args.cache_threshold == 0.5
# Edge cases
args = RouterArgs(cache_threshold=0.0)
assert args.cache_threshold == 0.0
args = RouterArgs(cache_threshold=1.0)
assert args.cache_threshold == 1.0
def test_balance_threshold_validation(self):
"""Test load balancing threshold validation."""
# Valid thresholds
args = RouterArgs(balance_abs_threshold=64, balance_rel_threshold=1.5)
assert args.balance_abs_threshold == 64
assert args.balance_rel_threshold == 1.5
# Edge cases
args = RouterArgs(balance_abs_threshold=0, balance_rel_threshold=1.0)
assert args.balance_abs_threshold == 0
assert args.balance_rel_threshold == 1.0
def test_timeout_validation(self):
"""Test timeout parameter validation."""
# Valid timeouts
args = RouterArgs(
worker_startup_timeout_secs=600,
worker_startup_check_interval=30,
request_timeout_secs=1800,
queue_timeout_secs=60,
)
assert args.worker_startup_timeout_secs == 600
assert args.worker_startup_check_interval == 30
assert args.request_timeout_secs == 1800
assert args.queue_timeout_secs == 60
def test_retry_config_validation(self):
"""Test retry configuration validation."""
# Valid retry config
args = RouterArgs(
retry_max_retries=5,
retry_initial_backoff_ms=50,
retry_max_backoff_ms=30000,
retry_backoff_multiplier=1.5,
retry_jitter_factor=0.2,
disable_retries=False,
)
assert args.retry_max_retries == 5
assert args.retry_initial_backoff_ms == 50
assert args.retry_max_backoff_ms == 30000
assert args.retry_backoff_multiplier == 1.5
assert args.retry_jitter_factor == 0.2
assert args.disable_retries is False
def test_circuit_breaker_config_validation(self):
"""Test circuit breaker configuration validation."""
# Valid circuit breaker config
args = RouterArgs(
cb_failure_threshold=10,
cb_success_threshold=3,
cb_timeout_duration_secs=60,
cb_window_duration_secs=120,
disable_circuit_breaker=False,
)
assert args.cb_failure_threshold == 10
assert args.cb_success_threshold == 3
assert args.cb_timeout_duration_secs == 60
assert args.cb_window_duration_secs == 120
assert args.disable_circuit_breaker is False
def test_health_check_config_validation(self):
"""Test health check configuration validation."""
# Valid health check config
args = RouterArgs(
health_failure_threshold=3,
health_success_threshold=2,
health_check_timeout_secs=5,
health_check_interval_secs=60,
health_check_endpoint="/health",
)
assert args.health_failure_threshold == 3
assert args.health_success_threshold == 2
assert args.health_check_timeout_secs == 5
assert args.health_check_interval_secs == 60
assert args.health_check_endpoint == "/health"
def test_rate_limiting_config_validation(self):
"""Test rate limiting configuration validation."""
# Valid rate limiting config
args = RouterArgs(
max_concurrent_requests=256,
queue_size=100,
queue_timeout_secs=60,
rate_limit_tokens_per_second=100,
)
assert args.max_concurrent_requests == 256
assert args.queue_size == 100
assert args.queue_timeout_secs == 60
assert args.rate_limit_tokens_per_second == 100
def test_service_discovery_config_validation(self):
"""Test service discovery configuration validation."""
# Valid service discovery config
args = RouterArgs(
service_discovery=True,
selector={"app": "worker", "env": "prod"},
service_discovery_port=8080,
service_discovery_namespace="default",
)
assert args.service_discovery is True
assert args.selector == {"app": "worker", "env": "prod"}
assert args.service_discovery_port == 8080
assert args.service_discovery_namespace == "default"
def test_pd_service_discovery_config_validation(self):
"""Test PD service discovery configuration validation."""
# Valid PD service discovery config
args = RouterArgs(
pd_disaggregation=True,
service_discovery=True,
prefill_selector={"app": "prefill"},
decode_selector={"app": "decode"},
bootstrap_port_annotation="sglang.ai/bootstrap-port",
)
assert args.pd_disaggregation is True
assert args.service_discovery is True
assert args.prefill_selector == {"app": "prefill"}
assert args.decode_selector == {"app": "decode"}
assert args.bootstrap_port_annotation == "sglang.ai/bootstrap-port"
def test_prometheus_config_validation(self):
"""Test Prometheus configuration validation."""
# Valid Prometheus config
args = RouterArgs(prometheus_port=29000, prometheus_host="127.0.0.1")
assert args.prometheus_port == 29000
assert args.prometheus_host == "127.0.0.1"
def test_cors_config_validation(self):
"""Test CORS configuration validation."""
# Valid CORS config
args = RouterArgs(
cors_allowed_origins=["http://localhost:3000", "https://example.com"]
)
assert args.cors_allowed_origins == [
"http://localhost:3000",
"https://example.com",
]
def test_tokenizer_config_validation(self):
"""Test tokenizer configuration validation."""
# Note: model_path and tokenizer_path are not available in current RouterArgs
pytest.skip("Tokenizer configuration not available in current implementation")
def test_dp_aware_config_validation(self):
"""Test data parallelism aware configuration validation."""
# Valid DP aware config
args = RouterArgs(dp_aware=True, api_key="test-api-key")
assert args.dp_aware is True
assert args.api_key == "test-api-key"
def test_request_id_headers_validation(self):
"""Test request ID headers configuration validation."""
# Valid request ID headers config
args = RouterArgs(
request_id_headers=["x-request-id", "x-trace-id", "x-correlation-id"]
)
assert args.request_id_headers == [
"x-request-id",
"x-trace-id",
"x-correlation-id",
]
def test_policy_consistency_validation(self):
"""Test policy consistency validation in PD mode."""
# Test with both prefill and decode policies specified
args = RouterArgs(
pd_disaggregation=True,
prefill_urls=[("http://prefill1:8000", None)],
decode_urls=["http://decode1:8001"],
policy="cache_aware",
prefill_policy="power_of_two",
decode_policy="round_robin",
)
# Should not raise validation error
with patch("sglang_router.launch_router.Router") as router_mod:
mock_router_instance = MagicMock()
router_mod.from_args = MagicMock(return_value=mock_router_instance)
launch_router(args)
# Should create router instance via from_args
router_mod.from_args.assert_called_once()
def test_policy_fallback_validation(self):
"""Test policy fallback validation in PD mode."""
# Test with only prefill policy specified
args = RouterArgs(
pd_disaggregation=True,
prefill_urls=[("http://prefill1:8000", None)],
decode_urls=["http://decode1:8001"],
policy="cache_aware",
prefill_policy="power_of_two",
decode_policy=None,
)
# Should not raise validation error
with patch("sglang_router.launch_router.Router") as router_mod:
mock_router_instance = MagicMock()
router_mod.from_args = MagicMock(return_value=mock_router_instance)
launch_router(args)
# Should create router instance via from_args
router_mod.from_args.assert_called_once()
def test_policy_enum_conversion(self):
"""Test policy string to enum conversion."""
# Test all valid policy conversions
assert policy_from_str("random") == PolicyType.Random
assert policy_from_str("round_robin") == PolicyType.RoundRobin
assert policy_from_str("cache_aware") == PolicyType.CacheAware
assert policy_from_str("power_of_two") == PolicyType.PowerOfTwo
def test_invalid_policy_enum_conversion(self):
"""Test invalid policy string to enum conversion."""
with pytest.raises(KeyError):
policy_from_str("invalid_policy")
def test_config_immutability(self):
"""Test that configuration objects are properly immutable."""
args = RouterArgs(
host="127.0.0.1", port=30000, worker_urls=["http://worker1:8000"]
)
# Test that we can't modify the configuration after creation
# (This is more of a design test - dataclasses are mutable by default)
original_host = args.host
args.host = "0.0.0.0"
assert args.host == "0.0.0.0" # Dataclasses are mutable
assert args.host != original_host
def test_config_defaults_consistency(self):
"""Test that configuration defaults are consistent."""
args1 = RouterArgs()
args2 = RouterArgs()
# Both instances should have the same defaults
assert args1.host == args2.host
assert args1.port == args2.port
assert args1.policy == args2.policy
assert args1.worker_urls == args2.worker_urls
assert args1.pd_disaggregation == args2.pd_disaggregation
def test_config_serialization(self):
"""Test that configuration can be serialized/deserialized."""
args = RouterArgs(
host="127.0.0.1",
port=30000,
worker_urls=["http://worker1:8000"],
policy="cache_aware",
cache_threshold=0.5,
)
# Test that we can access all attributes
assert hasattr(args, "host")
assert hasattr(args, "port")
assert hasattr(args, "worker_urls")
assert hasattr(args, "policy")
assert hasattr(args, "cache_threshold")
def test_config_with_none_values(self):
"""Test configuration with None values."""
args = RouterArgs(
api_key=None,
log_dir=None,
log_level=None,
prometheus_port=None,
prometheus_host=None,
request_id_headers=None,
rate_limit_tokens_per_second=None,
service_discovery_namespace=None,
)
# All None values should be preserved
assert args.api_key is None
assert args.log_dir is None
assert args.log_level is None
assert args.prometheus_port is None
assert args.prometheus_host is None
assert args.request_id_headers is None
assert args.rate_limit_tokens_per_second is None
assert args.service_discovery_namespace is None
def test_config_with_empty_lists(self):
"""Test configuration with empty lists."""
args = RouterArgs(
worker_urls=[], prefill_urls=[], decode_urls=[], cors_allowed_origins=[]
)
# All empty lists should be preserved
assert args.worker_urls == []
assert args.prefill_urls == []
assert args.decode_urls == []
assert args.cors_allowed_origins == []
def test_config_with_empty_dicts(self):
"""Test configuration with empty dictionaries."""
args = RouterArgs(selector={}, prefill_selector={}, decode_selector={})
# All empty dictionaries should be preserved
assert args.selector == {}
assert args.prefill_selector == {}
assert args.decode_selector == {}
"""
Unit tests for startup sequence logic in sglang_router.
These tests focus on testing the startup sequence logic in isolation,
including router initialization, configuration validation, and startup flow.
"""
import logging
from types import SimpleNamespace
from unittest.mock import MagicMock, call, patch
import pytest
from sglang_router.launch_router import RouterArgs, launch_router
from sglang_router.router import policy_from_str
# Local helper mirroring the router logger setup used in production
def setup_logger():
logger = logging.getLogger("router")
logger.setLevel(logging.INFO)
if not logger.handlers:
formatter = logging.Formatter(
"[Router (Python)] %(asctime)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
from sglang_router_rs import PolicyType
class TestSetupLogger:
"""Test logger setup functionality."""
def test_setup_logger_returns_logger(self):
"""Test that setup_logger returns a logger instance."""
logger = setup_logger()
assert isinstance(logger, logging.Logger)
assert logger.name == "router"
assert logger.level == logging.INFO
def test_setup_logger_has_handler(self):
"""Test that setup_logger configures a handler."""
logger = setup_logger()
assert len(logger.handlers) > 0
handler = logger.handlers[0]
assert isinstance(handler, logging.StreamHandler)
def test_setup_logger_has_formatter(self):
"""Test that setup_logger configures a formatter."""
logger = setup_logger()
handler = logger.handlers[0]
formatter = handler.formatter
assert formatter is not None
assert "[Router (Python)]" in formatter._fmt
def test_setup_logger_multiple_calls(self):
"""Test that multiple calls to setup_logger work correctly."""
logger1 = setup_logger()
logger2 = setup_logger()
# Should return the same logger instance
assert logger1 is logger2
class TestPolicyFromStr:
"""Test policy string to enum conversion in startup context."""
def test_policy_conversion_in_startup(self):
"""Test policy conversion during startup sequence."""
# Test all valid policies
policies = ["random", "round_robin", "cache_aware", "power_of_two"]
expected_enums = [
PolicyType.Random,
PolicyType.RoundRobin,
PolicyType.CacheAware,
PolicyType.PowerOfTwo,
]
for policy_str, expected_enum in zip(policies, expected_enums):
result = policy_from_str(policy_str)
assert result == expected_enum
def test_invalid_policy_in_startup(self):
"""Test handling of invalid policy during startup."""
with pytest.raises(KeyError):
policy_from_str("invalid_policy")
class TestRouterInitialization:
"""Test router initialization logic."""
def test_router_initialization_basic(self):
"""Test basic router initialization."""
args = RouterArgs(
host="127.0.0.1",
port=30000,
worker_urls=["http://worker1:8000"],
policy="cache_aware",
)
with patch("sglang_router.launch_router.Router") as router_mod:
captured_args = {}
mock_router_instance = MagicMock()
def fake_from_args(router_args):
# capture needed fields from RouterArgs
captured_args.update(
dict(
host=router_args.host,
port=router_args.port,
worker_urls=router_args.worker_urls,
policy=policy_from_str(router_args.policy),
)
)
return mock_router_instance
router_mod.from_args = MagicMock(side_effect=fake_from_args)
result = launch_router(args)
# Verify Router.from_args was called and captured fields match
router_mod.from_args.assert_called_once()
assert captured_args["host"] == "127.0.0.1"
assert captured_args["port"] == 30000
assert captured_args["worker_urls"] == ["http://worker1:8000"]
assert captured_args["policy"] == PolicyType.CacheAware
# Verify router.start() was called
mock_router_instance.start.assert_called_once()
# Function returns None; ensure start was invoked
def test_router_initialization_pd_mode(self):
"""Test router initialization in PD mode."""
args = RouterArgs(
pd_disaggregation=True,
prefill_urls=[("http://prefill1:8000", 9000)],
decode_urls=["http://decode1:8001"],
policy="power_of_two",
)
with patch("sglang_router.launch_router.Router") as router_mod:
captured_args = {}
mock_router_instance = MagicMock()
def fake_from_args(router_args):
captured_args.update(
dict(
pd_disaggregation=router_args.pd_disaggregation,
prefill_urls=router_args.prefill_urls,
decode_urls=router_args.decode_urls,
policy=policy_from_str(router_args.policy),
)
)
return mock_router_instance
router_mod.from_args = MagicMock(side_effect=fake_from_args)
result = launch_router(args)
# Verify Router.from_args was called with PD parameters
router_mod.from_args.assert_called_once()
assert captured_args["pd_disaggregation"] is True
assert captured_args["prefill_urls"] == [("http://prefill1:8000", 9000)]
assert captured_args["decode_urls"] == ["http://decode1:8001"]
assert captured_args["policy"] == PolicyType.PowerOfTwo
# Verify router.start() was called
mock_router_instance.start.assert_called_once()
# Function returns None; ensure start was invoked
def test_router_initialization_with_service_discovery(self):
"""Test router initialization with service discovery."""
args = RouterArgs(
service_discovery=True,
selector={"app": "worker", "env": "prod"},
service_discovery_port=8080,
service_discovery_namespace="default",
)
with patch("sglang_router.launch_router.Router") as router_mod:
captured_args = {}
mock_router_instance = MagicMock()
def fake_from_args(router_args):
captured_args.update(
dict(
service_discovery=router_args.service_discovery,
selector=router_args.selector,
service_discovery_port=router_args.service_discovery_port,
service_discovery_namespace=router_args.service_discovery_namespace,
)
)
return mock_router_instance
router_mod.from_args = MagicMock(side_effect=fake_from_args)
result = launch_router(args)
# Verify Router.from_args was called with service discovery parameters
router_mod.from_args.assert_called_once()
assert captured_args["service_discovery"] is True
assert captured_args["selector"] == {"app": "worker", "env": "prod"}
assert captured_args["service_discovery_port"] == 8080
assert captured_args["service_discovery_namespace"] == "default"
# Verify router.start() was called
mock_router_instance.start.assert_called_once()
# Function returns None; ensure start was invoked
def test_router_initialization_with_retry_config(self):
"""Test router initialization with retry configuration."""
args = RouterArgs(
retry_max_retries=3,
retry_initial_backoff_ms=100,
retry_max_backoff_ms=10000,
retry_backoff_multiplier=2.0,
retry_jitter_factor=0.1,
disable_retries=False,
)
with patch("sglang_router.launch_router.Router") as router_mod:
captured_args = {}
mock_router_instance = MagicMock()
def fake_from_args(router_args):
captured_args.update(
dict(
retry_max_retries=router_args.retry_max_retries,
retry_initial_backoff_ms=router_args.retry_initial_backoff_ms,
retry_max_backoff_ms=router_args.retry_max_backoff_ms,
retry_backoff_multiplier=router_args.retry_backoff_multiplier,
retry_jitter_factor=router_args.retry_jitter_factor,
disable_retries=router_args.disable_retries,
)
)
return mock_router_instance
router_mod.from_args = MagicMock(side_effect=fake_from_args)
result = launch_router(args)
# Verify router was created with retry parameters
router_mod.from_args.assert_called_once()
assert captured_args["retry_max_retries"] == 3
assert captured_args["retry_initial_backoff_ms"] == 100
assert captured_args["retry_max_backoff_ms"] == 10000
assert captured_args["retry_backoff_multiplier"] == 2.0
assert captured_args["retry_jitter_factor"] == 0.1
assert captured_args["disable_retries"] is False
# Verify router.start() was called
mock_router_instance.start.assert_called_once()
# Function returns None; ensure start was invoked
def test_router_initialization_with_circuit_breaker_config(self):
"""Test router initialization with circuit breaker configuration."""
args = RouterArgs(
cb_failure_threshold=5,
cb_success_threshold=2,
cb_timeout_duration_secs=30,
cb_window_duration_secs=60,
disable_circuit_breaker=False,
)
with patch("sglang_router.launch_router.Router") as router_mod:
captured_args = {}
mock_router_instance = MagicMock()
def fake_from_args(router_args):
captured_args.update(
dict(
cb_failure_threshold=router_args.cb_failure_threshold,
cb_success_threshold=router_args.cb_success_threshold,
cb_timeout_duration_secs=router_args.cb_timeout_duration_secs,
cb_window_duration_secs=router_args.cb_window_duration_secs,
disable_circuit_breaker=router_args.disable_circuit_breaker,
)
)
return mock_router_instance
router_mod.from_args = MagicMock(side_effect=fake_from_args)
result = launch_router(args)
# Verify router was created with circuit breaker parameters
router_mod.from_args.assert_called_once()
assert captured_args["cb_failure_threshold"] == 5
assert captured_args["cb_success_threshold"] == 2
assert captured_args["cb_timeout_duration_secs"] == 30
assert captured_args["cb_window_duration_secs"] == 60
assert captured_args["disable_circuit_breaker"] is False
# Verify router.start() was called
mock_router_instance.start.assert_called_once()
# Function returns None; ensure start was invoked
def test_router_initialization_with_rate_limiting_config(self):
"""Test router initialization with rate limiting configuration."""
args = RouterArgs(
max_concurrent_requests=512,
queue_size=200,
queue_timeout_secs=120,
rate_limit_tokens_per_second=100,
)
with patch("sglang_router.launch_router.Router") as router_mod:
captured_args = {}
mock_router_instance = MagicMock()
def fake_from_args(router_args):
captured_args.update(
dict(
max_concurrent_requests=router_args.max_concurrent_requests,
queue_size=router_args.queue_size,
queue_timeout_secs=router_args.queue_timeout_secs,
rate_limit_tokens_per_second=router_args.rate_limit_tokens_per_second,
)
)
return mock_router_instance
router_mod.from_args = MagicMock(side_effect=fake_from_args)
result = launch_router(args)
# Verify router was created with rate limiting parameters
router_mod.from_args.assert_called_once()
assert captured_args["max_concurrent_requests"] == 512
assert captured_args["queue_size"] == 200
assert captured_args["queue_timeout_secs"] == 120
assert captured_args["rate_limit_tokens_per_second"] == 100
# Verify router.start() was called
mock_router_instance.start.assert_called_once()
# Function returns None; ensure start was invoked
def test_router_initialization_with_health_check_config(self):
"""Test router initialization with health check configuration."""
args = RouterArgs(
health_failure_threshold=2,
health_success_threshold=1,
health_check_timeout_secs=3,
health_check_interval_secs=30,
health_check_endpoint="/healthz",
)
with patch("sglang_router.launch_router.Router") as router_mod:
captured_args = {}
mock_router_instance = MagicMock()
def fake_from_args(router_args):
captured_args.update(
dict(
health_failure_threshold=router_args.health_failure_threshold,
health_success_threshold=router_args.health_success_threshold,
health_check_timeout_secs=router_args.health_check_timeout_secs,
health_check_interval_secs=router_args.health_check_interval_secs,
health_check_endpoint=router_args.health_check_endpoint,
)
)
return mock_router_instance
router_mod.from_args = MagicMock(side_effect=fake_from_args)
result = launch_router(args)
# Verify router was created with health check parameters
router_mod.from_args.assert_called_once()
assert captured_args["health_failure_threshold"] == 2
assert captured_args["health_success_threshold"] == 1
assert captured_args["health_check_timeout_secs"] == 3
assert captured_args["health_check_interval_secs"] == 30
assert captured_args["health_check_endpoint"] == "/healthz"
# Verify router.start() was called
mock_router_instance.start.assert_called_once()
# Function returns None; ensure start was invoked
def test_router_initialization_with_prometheus_config(self):
"""Test router initialization with Prometheus configuration."""
args = RouterArgs(prometheus_port=29000, prometheus_host="127.0.0.1")
with patch("sglang_router.launch_router.Router") as router_mod:
captured_args = {}
mock_router_instance = MagicMock()
def fake_from_args(router_args):
captured_args.update(
dict(
prometheus_port=router_args.prometheus_port,
prometheus_host=router_args.prometheus_host,
)
)
return mock_router_instance
router_mod.from_args = MagicMock(side_effect=fake_from_args)
result = launch_router(args)
# Verify router was created with Prometheus parameters
router_mod.from_args.assert_called_once()
assert captured_args["prometheus_port"] == 29000
assert captured_args["prometheus_host"] == "127.0.0.1"
# Verify router.start() was called
mock_router_instance.start.assert_called_once()
# Function returns None; ensure start was invoked
def test_router_initialization_with_cors_config(self):
"""Test router initialization with CORS configuration."""
args = RouterArgs(
cors_allowed_origins=["http://localhost:3000", "https://example.com"]
)
with patch("sglang_router.launch_router.Router") as router_mod:
captured_args = {}
mock_router_instance = MagicMock()
def fake_from_args(router_args):
captured_args.update(
dict(cors_allowed_origins=router_args.cors_allowed_origins)
)
return mock_router_instance
router_mod.from_args = MagicMock(side_effect=fake_from_args)
result = launch_router(args)
# Verify router was created with CORS parameters
router_mod.from_args.assert_called_once()
assert captured_args["cors_allowed_origins"] == [
"http://localhost:3000",
"https://example.com",
]
# Verify router.start() was called
mock_router_instance.start.assert_called_once()
# Function returns None; ensure start was invoked
def test_router_initialization_with_tokenizer_config(self):
"""Test router initialization with tokenizer configuration."""
# Note: model_path and tokenizer_path are not available in current RouterArgs
pytest.skip("Tokenizer configuration not available in current implementation")
class TestStartupValidation:
"""Test startup validation logic."""
def test_pd_mode_validation_during_startup(self):
"""Test PD mode validation during startup."""
# PD mode without URLs should fail
args = RouterArgs(
pd_disaggregation=True,
prefill_urls=[],
decode_urls=[],
service_discovery=False,
)
with pytest.raises(
ValueError, match="PD disaggregation mode requires --prefill"
):
launch_router(args)
def test_pd_mode_with_service_discovery_validation(self):
"""Test PD mode with service discovery validation during startup."""
args = RouterArgs(
pd_disaggregation=True,
prefill_urls=[],
decode_urls=[],
service_discovery=True,
)
# Should not raise validation error
with patch("sglang_router.launch_router.Router") as router_mod:
mock_router_instance = MagicMock()
router_mod.from_args = MagicMock(return_value=mock_router_instance)
result = launch_router(args)
# Should create router instance
router_mod.from_args.assert_called_once()
def test_policy_warning_during_startup(self):
"""Test policy warning during startup in PD mode."""
args = RouterArgs(
pd_disaggregation=True,
prefill_urls=[("http://prefill1:8000", None)],
decode_urls=["http://decode1:8001"],
policy="cache_aware",
prefill_policy="power_of_two",
decode_policy="round_robin",
)
with patch("sglang_router.launch_router.Router") as router_mod:
mock_router_instance = MagicMock()
router_mod.from_args = MagicMock(return_value=mock_router_instance)
# The policy messages are emitted by router_args logger
with patch("sglang_router.router_args.logger") as mock_logger:
result = launch_router(args)
# Should log warning about policy usage
mock_logger.warning.assert_called_once()
warning_call = mock_logger.warning.call_args[0][0]
assert (
"Both --prefill-policy and --decode-policy are specified"
in warning_call
)
# Should create router instance
router_mod.from_args.assert_called_once()
def test_policy_info_during_startup(self):
"""Test policy info logging during startup in PD mode."""
# Test with only prefill policy specified
args = RouterArgs(
pd_disaggregation=True,
prefill_urls=[("http://prefill1:8000", None)],
decode_urls=["http://decode1:8001"],
policy="cache_aware",
prefill_policy="power_of_two",
decode_policy=None,
)
with patch("sglang_router.launch_router.Router") as router_mod:
mock_router_instance = MagicMock()
router_mod.from_args = MagicMock(return_value=mock_router_instance)
# The policy messages are emitted by router_args logger
with patch("sglang_router.router_args.logger") as mock_logger:
result = launch_router(args)
# Should log info about policy usage
mock_logger.info.assert_called_once()
info_call = mock_logger.info.call_args[0][0]
assert "Using --prefill-policy 'power_of_two'" in info_call
assert "and --policy 'cache_aware'" in info_call
# Should create router instance
router_mod.from_args.assert_called_once()
def test_policy_info_decode_only_during_startup(self):
"""Test policy info logging during startup with only decode policy specified."""
args = RouterArgs(
pd_disaggregation=True,
prefill_urls=[("http://prefill1:8000", None)],
decode_urls=["http://decode1:8001"],
policy="cache_aware",
prefill_policy=None,
decode_policy="round_robin",
)
with patch("sglang_router.launch_router.Router") as router_mod:
mock_router_instance = MagicMock()
router_mod.from_args = MagicMock(return_value=mock_router_instance)
# The policy messages are emitted by router_args logger
with patch("sglang_router.router_args.logger") as mock_logger:
result = launch_router(args)
# Should log info about policy usage
mock_logger.info.assert_called_once()
info_call = mock_logger.info.call_args[0][0]
assert "Using --policy 'cache_aware'" in info_call
assert "and --decode-policy 'round_robin'" in info_call
# Should create router instance
router_mod.from_args.assert_called_once()
class TestStartupErrorHandling:
"""Test startup error handling logic."""
def test_router_creation_error_handling(self):
"""Test error handling when router creation fails."""
args = RouterArgs(
host="127.0.0.1", port=30000, worker_urls=["http://worker1:8000"]
)
with patch("sglang_router.launch_router.Router") as router_mod:
# Simulate router creation failure in from_args
router_mod.from_args = MagicMock(
side_effect=Exception("Router creation failed")
)
with patch("sglang_router.launch_router.logger") as mock_logger:
with pytest.raises(Exception, match="Router creation failed"):
launch_router(args)
# Should log error
mock_logger.error.assert_called_once()
error_call = mock_logger.error.call_args[0][0]
assert "Error starting router: Router creation failed" in error_call
def test_router_start_error_handling(self):
"""Test error handling when router start fails."""
args = RouterArgs(
host="127.0.0.1", port=30000, worker_urls=["http://worker1:8000"]
)
with patch("sglang_router.launch_router.Router") as router_mod:
mock_router_instance = MagicMock()
router_mod.from_args = MagicMock(return_value=mock_router_instance)
# Simulate router start failure
mock_router_instance.start.side_effect = Exception("Router start failed")
with patch("sglang_router.launch_router.logger") as mock_logger:
with pytest.raises(Exception, match="Router start failed"):
launch_router(args)
# Should log error
mock_logger.error.assert_called_once()
error_call = mock_logger.error.call_args[0][0]
assert "Error starting router: Router start failed" in error_call
# --- Added unit tests for Router wrapper and launch_server helpers ---
def _install_sglang_stubs(monkeypatch):
"""Install lightweight stubs for sglang.srt to avoid heavy deps during unit tests."""
import sys
import types
sglang_mod = types.ModuleType("sglang")
srt_mod = types.ModuleType("sglang.srt")
entry_mod = types.ModuleType("sglang.srt.entrypoints")
http_server_mod = types.ModuleType("sglang.srt.entrypoints.http_server")
server_args_mod = types.ModuleType("sglang.srt.server_args")
utils_mod = types.ModuleType("sglang.srt.utils")
def launch_server(_args):
return None
class ServerArgs:
# Minimal fields used by launch_server_process
def __init__(self):
self.port = 0
self.base_gpu_id = 0
self.dp_size = 1
self.tp_size = 1
@staticmethod
def add_cli_args(_parser):
return None
@staticmethod
def from_cli_args(_args):
sa = ServerArgs()
if hasattr(_args, "dp_size"):
sa.dp_size = _args.dp_size
if hasattr(_args, "tp_size"):
sa.tp_size = _args.tp_size
if hasattr(_args, "host"):
sa.host = _args.host
else:
sa.host = "127.0.0.1"
return sa
def is_port_available(_port: int) -> bool:
return True
http_server_mod.launch_server = launch_server
server_args_mod.ServerArgs = ServerArgs
utils_mod.is_port_available = is_port_available
# Also stub external deps imported at module top-level
def _dummy_get(*_a, **_k):
raise NotImplementedError
requests_stub = types.SimpleNamespace(
exceptions=types.SimpleNamespace(RequestException=Exception), get=_dummy_get
)
setproctitle_stub = types.SimpleNamespace(setproctitle=lambda *_a, **_k: None)
monkeypatch.setitem(sys.modules, "requests", requests_stub)
monkeypatch.setitem(sys.modules, "setproctitle", setproctitle_stub)
monkeypatch.setitem(sys.modules, "sglang", sglang_mod)
monkeypatch.setitem(sys.modules, "sglang.srt", srt_mod)
monkeypatch.setitem(sys.modules, "sglang.srt.entrypoints", entry_mod)
monkeypatch.setitem(
sys.modules, "sglang.srt.entrypoints.http_server", http_server_mod
)
monkeypatch.setitem(sys.modules, "sglang.srt.server_args", server_args_mod)
monkeypatch.setitem(sys.modules, "sglang.srt.utils", utils_mod)
def test_router_defaults_and_start(monkeypatch):
"""Router wrapper: defaults normalization and start() call.
Mocks the Rust-backed _Router to avoid native deps.
"""
from sglang_router import router as router_mod
captured = {}
class FakeRouter:
def __init__(self, **kwargs):
captured.update(kwargs)
def start(self):
captured["started"] = True
monkeypatch.setattr(router_mod, "_Router", FakeRouter, raising=True)
from sglang_router.router_args import RouterArgs as _RouterArgs
Router = router_mod.Router
args = _RouterArgs(
worker_urls=["http://w1:8000"],
policy="round_robin",
selector=None,
prefill_selector=None,
decode_selector=None,
cors_allowed_origins=None,
)
r = Router.from_args(args)
# Defaults preserved/normalized by Router.from_args
assert captured["selector"] is None
assert captured["prefill_selector"] is None
assert captured["decode_selector"] is None
assert captured["cors_allowed_origins"] is None
assert captured["worker_urls"] == ["http://w1:8000"]
from sglang_router_rs import PolicyType
assert captured["policy"] == PolicyType.RoundRobin
r.start()
assert captured.get("started") is True
def test_find_available_ports_and_wait_health(monkeypatch):
"""launch_server helpers: port finding and health waiting with transient error."""
_install_sglang_stubs(monkeypatch)
import importlib
ls = importlib.import_module("sglang_router.launch_server")
# Deterministic increments
monkeypatch.setattr(ls.random, "randint", lambda a, b: 100)
ports = ls.find_available_ports(30000, 3)
assert ports == [30000, 30100, 30200]
calls = {"n": 0}
class Ok:
status_code = 200
def fake_get(_url, timeout=5):
calls["n"] += 1
if calls["n"] == 1:
raise ls.requests.exceptions.RequestException("boom")
return Ok()
monkeypatch.setattr(ls.requests, "get", fake_get)
monkeypatch.setattr(ls.time, "sleep", lambda _s: None)
base = {"t": 0.0}
monkeypatch.setattr(
ls.time,
"perf_counter",
lambda: (base.__setitem__("t", base["t"] + 0.1) or base["t"]),
)
assert ls.wait_for_server_health("127.0.0.1", 12345, timeout=1)
def test_launch_server_process_and_cleanup(monkeypatch):
"""launch_server: process creation args and cleanup SIGTERM/SIGKILL logic."""
_install_sglang_stubs(monkeypatch)
import importlib
ls = importlib.import_module("sglang_router.launch_server")
created = {}
class FakeProcess:
def __init__(self, target, args):
created["target"] = target
created["args"] = args
self.pid = 4242
self._alive = True
def start(self):
created["started"] = True
def join(self, timeout=None):
return None
def is_alive(self):
return self._alive
monkeypatch.setattr(ls.mp, "Process", FakeProcess)
import sys as _sys
SA = _sys.modules["sglang.srt.server_args"].ServerArgs
sa = SA()
sa.tp_size = 2
proc = ls.launch_server_process(sa, worker_port=31001, dp_id=3)
assert created.get("started") is True
targ, targ_args = created["target"], created["args"]
assert targ is ls.run_server
passed_sa = targ_args[0]
assert passed_sa.port == 31001
assert passed_sa.base_gpu_id == 3 * 2
assert passed_sa.dp_size == 1
# cleanup_processes
p1 = FakeProcess(target=None, args=())
p1._alive = False
p2 = FakeProcess(target=None, args=())
p2._alive = True
calls = []
def fake_killpg(pid, sig):
calls.append((pid, sig))
monkeypatch.setattr(ls.os, "killpg", fake_killpg)
ls.cleanup_processes([p1, p2])
import signal as _sig
assert (p1.pid, _sig.SIGTERM) in calls and (p2.pid, _sig.SIGTERM) in calls
assert (p2.pid, _sig.SIGKILL) in calls
def test_validation_error_handling(self):
"""Test error handling when validation fails."""
args = RouterArgs(
pd_disaggregation=True,
prefill_urls=[],
decode_urls=[],
service_discovery=False,
)
with patch("sglang_router.launch_router.logger") as mock_logger:
with pytest.raises(
ValueError, match="PD disaggregation mode requires --prefill"
):
launch_router(args)
# Should log error for validation failures
mock_logger.error.assert_called_once()
class TestStartupFlow:
"""Test complete startup flow."""
def test_complete_startup_flow_basic(self):
"""Test complete startup flow for basic configuration."""
args = RouterArgs(
host="127.0.0.1",
port=30000,
worker_urls=["http://worker1:8000", "http://worker2:8000"],
policy="cache_aware",
cache_threshold=0.5,
balance_abs_threshold=32,
balance_rel_threshold=1.5,
)
with patch("sglang_router.launch_router.Router") as router_mod:
mock_router_instance = MagicMock()
router_mod.from_args = MagicMock(return_value=mock_router_instance)
result = launch_router(args)
# Verify complete flow
router_mod.from_args.assert_called_once()
mock_router_instance.start.assert_called_once()
def test_complete_startup_flow_pd_mode(self):
"""Test complete startup flow for PD mode configuration."""
args = RouterArgs(
pd_disaggregation=True,
prefill_urls=[
("http://prefill1:8000", 9000),
("http://prefill2:8000", None),
],
decode_urls=["http://decode1:8001", "http://decode2:8001"],
policy="power_of_two",
prefill_policy="cache_aware",
decode_policy="round_robin",
)
with patch("sglang_router.launch_router.Router") as router_mod:
mock_router_instance = MagicMock()
router_mod.from_args = MagicMock(return_value=mock_router_instance)
with patch("sglang_router.router_args.logger") as mock_logger:
result = launch_router(args)
# Verify complete flow
router_mod.from_args.assert_called_once()
mock_router_instance.start.assert_called_once()
# Verify policy warning was logged
mock_logger.warning.assert_called_once()
def test_complete_startup_flow_with_all_features(self):
"""Test complete startup flow with all features enabled."""
args = RouterArgs(
host="0.0.0.0",
port=30001,
worker_urls=["http://worker1:8000"],
policy="round_robin",
service_discovery=True,
selector={"app": "worker"},
service_discovery_port=8080,
service_discovery_namespace="default",
dp_aware=True,
api_key="test-key",
log_dir="/tmp/logs",
log_level="debug",
prometheus_port=29000,
prometheus_host="0.0.0.0",
request_id_headers=["x-request-id", "x-trace-id"],
request_timeout_secs=1200,
max_concurrent_requests=512,
queue_size=200,
queue_timeout_secs=120,
rate_limit_tokens_per_second=100,
cors_allowed_origins=["http://localhost:3000"],
retry_max_retries=3,
retry_initial_backoff_ms=100,
retry_max_backoff_ms=10000,
retry_backoff_multiplier=2.0,
retry_jitter_factor=0.1,
cb_failure_threshold=5,
cb_success_threshold=2,
cb_timeout_duration_secs=30,
cb_window_duration_secs=60,
health_failure_threshold=2,
health_success_threshold=1,
health_check_timeout_secs=3,
health_check_interval_secs=30,
health_check_endpoint="/healthz",
)
with patch("sglang_router.launch_router.Router") as router_mod:
captured_args = {}
mock_router_instance = MagicMock()
def fake_from_args(router_args):
captured_args.update(
dict(
host=router_args.host,
port=router_args.port,
worker_urls=router_args.worker_urls,
policy=policy_from_str(router_args.policy),
service_discovery=router_args.service_discovery,
selector=router_args.selector,
service_discovery_port=router_args.service_discovery_port,
service_discovery_namespace=router_args.service_discovery_namespace,
dp_aware=router_args.dp_aware,
api_key=router_args.api_key,
log_dir=router_args.log_dir,
log_level=router_args.log_level,
prometheus_port=router_args.prometheus_port,
prometheus_host=router_args.prometheus_host,
request_id_headers=router_args.request_id_headers,
request_timeout_secs=router_args.request_timeout_secs,
max_concurrent_requests=router_args.max_concurrent_requests,
queue_size=router_args.queue_size,
queue_timeout_secs=router_args.queue_timeout_secs,
rate_limit_tokens_per_second=router_args.rate_limit_tokens_per_second,
cors_allowed_origins=router_args.cors_allowed_origins,
retry_max_retries=router_args.retry_max_retries,
retry_initial_backoff_ms=router_args.retry_initial_backoff_ms,
retry_max_backoff_ms=router_args.retry_max_backoff_ms,
retry_backoff_multiplier=router_args.retry_backoff_multiplier,
retry_jitter_factor=router_args.retry_jitter_factor,
cb_failure_threshold=router_args.cb_failure_threshold,
cb_success_threshold=router_args.cb_success_threshold,
cb_timeout_duration_secs=router_args.cb_timeout_duration_secs,
cb_window_duration_secs=router_args.cb_window_duration_secs,
health_failure_threshold=router_args.health_failure_threshold,
health_success_threshold=router_args.health_success_threshold,
health_check_timeout_secs=router_args.health_check_timeout_secs,
health_check_interval_secs=router_args.health_check_interval_secs,
health_check_endpoint=router_args.health_check_endpoint,
)
)
return mock_router_instance
router_mod.from_args = MagicMock(side_effect=fake_from_args)
result = launch_router(args)
# Verify complete flow
router_mod.from_args.assert_called_once()
mock_router_instance.start.assert_called_once()
# Verify key parameters were propagated into RouterArgs
assert captured_args["host"] == "0.0.0.0"
assert captured_args["port"] == 30001
assert captured_args["worker_urls"] == ["http://worker1:8000"]
assert captured_args["policy"] == PolicyType.RoundRobin
assert captured_args["service_discovery"] is True
assert captured_args["selector"] == {"app": "worker"}
assert captured_args["service_discovery_port"] == 8080
assert captured_args["service_discovery_namespace"] == "default"
assert captured_args["dp_aware"] is True
assert captured_args["api_key"] == "test-key"
assert captured_args["log_dir"] == "/tmp/logs"
assert captured_args["log_level"] == "debug"
assert captured_args["prometheus_port"] == 29000
assert captured_args["prometheus_host"] == "0.0.0.0"
assert captured_args["request_id_headers"] == ["x-request-id", "x-trace-id"]
assert captured_args["request_timeout_secs"] == 1200
assert captured_args["max_concurrent_requests"] == 512
assert captured_args["queue_size"] == 200
assert captured_args["queue_timeout_secs"] == 120
assert captured_args["rate_limit_tokens_per_second"] == 100
assert captured_args["cors_allowed_origins"] == ["http://localhost:3000"]
assert captured_args["retry_max_retries"] == 3
assert captured_args["retry_initial_backoff_ms"] == 100
assert captured_args["retry_max_backoff_ms"] == 10000
assert captured_args["retry_backoff_multiplier"] == 2.0
assert captured_args["retry_jitter_factor"] == 0.1
assert captured_args["cb_failure_threshold"] == 5
assert captured_args["cb_success_threshold"] == 2
assert captured_args["cb_timeout_duration_secs"] == 30
assert captured_args["cb_window_duration_secs"] == 60
assert captured_args["health_failure_threshold"] == 2
assert captured_args["health_success_threshold"] == 1
assert captured_args["health_check_timeout_secs"] == 3
assert captured_args["health_check_interval_secs"] == 30
assert captured_args["health_check_endpoint"] == "/healthz"
"""
Unit tests for validation logic in sglang_router.
These tests focus on testing the validation logic in isolation,
including parameter validation, URL validation, and configuration validation.
"""
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from sglang_router.launch_router import RouterArgs, launch_router
class TestURLValidation:
"""Test URL validation logic."""
def test_valid_worker_urls(self):
"""Test validation of valid worker URLs."""
valid_urls = [
"http://worker1:8000",
"https://worker2:8000",
"http://localhost:8000",
"http://127.0.0.1:8000",
"http://192.168.1.100:8000",
"http://worker.example.com:8000",
]
for url in valid_urls:
args = RouterArgs(worker_urls=[url])
# Should not raise any validation errors
assert url in args.worker_urls
def test_valid_prefill_urls(self):
"""Test validation of valid prefill URLs."""
valid_prefill_urls = [
("http://prefill1:8000", 9000),
("https://prefill2:8000", None),
("http://localhost:8000", 9000),
("http://127.0.0.1:8000", None),
]
for url, bootstrap_port in valid_prefill_urls:
args = RouterArgs(prefill_urls=[(url, bootstrap_port)])
# Should not raise any validation errors
assert (url, bootstrap_port) in args.prefill_urls
def test_valid_decode_urls(self):
"""Test validation of valid decode URLs."""
valid_decode_urls = [
"http://decode1:8001",
"https://decode2:8001",
"http://localhost:8001",
"http://127.0.0.1:8001",
]
for url in valid_decode_urls:
args = RouterArgs(decode_urls=[url])
# Should not raise any validation errors
assert url in args.decode_urls
def test_malformed_urls(self):
"""Test handling of malformed URLs."""
# Note: The current implementation doesn't validate URL format
# This test documents the current behavior
malformed_urls = [
"not-a-url",
"ftp://worker1:8000", # Wrong protocol
"http://", # Missing host
":8000", # Missing protocol and host
"http://worker1", # Missing port
]
for url in malformed_urls:
args = RouterArgs(worker_urls=[url])
# Currently, malformed URLs are accepted
# This might be something to improve in the future
assert url in args.worker_urls
class TestPortValidation:
"""Test port validation logic."""
def test_valid_ports(self):
"""Test validation of valid port numbers."""
valid_ports = [1, 80, 8000, 30000, 65535]
for port in valid_ports:
args = RouterArgs(port=port)
assert args.port == port
def test_invalid_ports(self):
"""Test handling of invalid port numbers."""
# Note: The current implementation doesn't validate port ranges
# This test documents the current behavior
invalid_ports = [0, -1, 65536, 70000]
for port in invalid_ports:
args = RouterArgs(port=port)
# Currently, invalid ports are accepted
# This might be something to improve in the future
assert args.port == port
def test_bootstrap_port_validation(self):
"""Test validation of bootstrap ports in PD mode."""
valid_bootstrap_ports = [1, 80, 9000, 30000, 65535, None]
for bootstrap_port in valid_bootstrap_ports:
args = RouterArgs(prefill_urls=[("http://prefill1:8000", bootstrap_port)])
assert args.prefill_urls[0][1] == bootstrap_port
class TestParameterValidation:
"""Test parameter validation logic."""
def test_cache_threshold_validation(self):
"""Test cache threshold parameter validation."""
# Valid cache thresholds
valid_thresholds = [0.0, 0.1, 0.5, 0.9, 1.0]
for threshold in valid_thresholds:
args = RouterArgs(cache_threshold=threshold)
assert args.cache_threshold == threshold
def test_balance_threshold_validation(self):
"""Test load balancing threshold parameter validation."""
# Valid absolute thresholds
valid_abs_thresholds = [0, 1, 32, 64, 128, 1000]
for threshold in valid_abs_thresholds:
args = RouterArgs(balance_abs_threshold=threshold)
assert args.balance_abs_threshold == threshold
# Valid relative thresholds
valid_rel_thresholds = [1.0, 1.1, 1.5, 2.0, 10.0]
for threshold in valid_rel_thresholds:
args = RouterArgs(balance_rel_threshold=threshold)
assert args.balance_rel_threshold == threshold
def test_timeout_validation(self):
"""Test timeout parameter validation."""
# Valid timeouts
valid_timeouts = [1, 30, 60, 300, 600, 1800, 3600]
for timeout in valid_timeouts:
args = RouterArgs(
worker_startup_timeout_secs=timeout,
worker_startup_check_interval=timeout,
request_timeout_secs=timeout,
queue_timeout_secs=timeout,
)
assert args.worker_startup_timeout_secs == timeout
assert args.worker_startup_check_interval == timeout
assert args.request_timeout_secs == timeout
assert args.queue_timeout_secs == timeout
def test_retry_parameter_validation(self):
"""Test retry parameter validation."""
# Valid retry parameters
valid_retry_counts = [0, 1, 3, 5, 10]
for count in valid_retry_counts:
args = RouterArgs(retry_max_retries=count)
assert args.retry_max_retries == count
# Valid backoff parameters
valid_backoff_ms = [1, 50, 100, 1000, 30000]
for backoff in valid_backoff_ms:
args = RouterArgs(
retry_initial_backoff_ms=backoff, retry_max_backoff_ms=backoff
)
assert args.retry_initial_backoff_ms == backoff
assert args.retry_max_backoff_ms == backoff
# Valid multiplier parameters
valid_multipliers = [1.0, 1.5, 2.0, 3.0]
for multiplier in valid_multipliers:
args = RouterArgs(retry_backoff_multiplier=multiplier)
assert args.retry_backoff_multiplier == multiplier
# Valid jitter parameters
valid_jitter = [0.0, 0.1, 0.2, 0.5]
for jitter in valid_jitter:
args = RouterArgs(retry_jitter_factor=jitter)
assert args.retry_jitter_factor == jitter
def test_circuit_breaker_parameter_validation(self):
"""Test circuit breaker parameter validation."""
# Valid failure thresholds
valid_failure_thresholds = [1, 3, 5, 10, 20]
for threshold in valid_failure_thresholds:
args = RouterArgs(cb_failure_threshold=threshold)
assert args.cb_failure_threshold == threshold
# Valid success thresholds
valid_success_thresholds = [1, 2, 3, 5]
for threshold in valid_success_thresholds:
args = RouterArgs(cb_success_threshold=threshold)
assert args.cb_success_threshold == threshold
# Valid timeout durations
valid_timeouts = [10, 30, 60, 120, 300]
for timeout in valid_timeouts:
args = RouterArgs(
cb_timeout_duration_secs=timeout, cb_window_duration_secs=timeout
)
assert args.cb_timeout_duration_secs == timeout
assert args.cb_window_duration_secs == timeout
def test_health_check_parameter_validation(self):
"""Test health check parameter validation."""
# Valid failure thresholds
valid_failure_thresholds = [1, 2, 3, 5, 10]
for threshold in valid_failure_thresholds:
args = RouterArgs(health_failure_threshold=threshold)
assert args.health_failure_threshold == threshold
# Valid success thresholds
valid_success_thresholds = [1, 2, 3, 5]
for threshold in valid_success_thresholds:
args = RouterArgs(health_success_threshold=threshold)
assert args.health_success_threshold == threshold
# Valid timeouts and intervals
valid_times = [1, 5, 10, 30, 60, 120]
for time_val in valid_times:
args = RouterArgs(
health_check_timeout_secs=time_val, health_check_interval_secs=time_val
)
assert args.health_check_timeout_secs == time_val
assert args.health_check_interval_secs == time_val
def test_rate_limiting_parameter_validation(self):
"""Test rate limiting parameter validation."""
# Valid concurrent request limits
valid_limits = [1, 10, 64, 256, 512, 1000]
for limit in valid_limits:
args = RouterArgs(max_concurrent_requests=limit)
assert args.max_concurrent_requests == limit
# Valid queue sizes
valid_queue_sizes = [0, 10, 50, 100, 500, 1000]
for size in valid_queue_sizes:
args = RouterArgs(queue_size=size)
assert args.queue_size == size
# Valid token rates
valid_rates = [1, 10, 50, 100, 500, 1000]
for rate in valid_rates:
args = RouterArgs(rate_limit_tokens_per_second=rate)
assert args.rate_limit_tokens_per_second == rate
def test_tree_size_validation(self):
"""Test tree size parameter validation."""
# Valid tree sizes (powers of 2)
valid_sizes = [2**10, 2**20, 2**24, 2**26, 2**28, 2**30]
for size in valid_sizes:
args = RouterArgs(max_tree_size=size)
assert args.max_tree_size == size
def test_payload_size_validation(self):
"""Test payload size parameter validation."""
# Valid payload sizes
valid_sizes = [
1024, # 1KB
1024 * 1024, # 1MB
10 * 1024 * 1024, # 10MB
100 * 1024 * 1024, # 100MB
512 * 1024 * 1024, # 512MB
1024 * 1024 * 1024, # 1GB
]
for size in valid_sizes:
args = RouterArgs(max_payload_size=size)
assert args.max_payload_size == size
class TestConfigurationValidation:
"""Test configuration validation logic."""
def test_pd_mode_validation(self):
"""Test PD mode configuration validation."""
# Valid PD configuration
args = RouterArgs(
pd_disaggregation=True,
prefill_urls=[("http://prefill1:8000", 9000)],
decode_urls=["http://decode1:8001"],
)
assert args.pd_disaggregation is True
assert len(args.prefill_urls) > 0
assert len(args.decode_urls) > 0
def test_service_discovery_validation(self):
"""Test service discovery configuration validation."""
# Valid service discovery configuration
args = RouterArgs(
service_discovery=True,
selector={"app": "worker", "env": "prod"},
service_discovery_port=8080,
service_discovery_namespace="default",
)
assert args.service_discovery is True
assert args.selector == {"app": "worker", "env": "prod"}
assert args.service_discovery_port == 8080
assert args.service_discovery_namespace == "default"
def test_pd_service_discovery_validation(self):
"""Test PD service discovery configuration validation."""
# Valid PD service discovery configuration
args = RouterArgs(
pd_disaggregation=True,
service_discovery=True,
prefill_selector={"app": "prefill"},
decode_selector={"app": "decode"},
)
assert args.pd_disaggregation is True
assert args.service_discovery is True
assert args.prefill_selector == {"app": "prefill"}
assert args.decode_selector == {"app": "decode"}
def test_policy_validation(self):
"""Test policy configuration validation."""
# Valid policies
valid_policies = ["random", "round_robin", "cache_aware", "power_of_two"]
for policy in valid_policies:
args = RouterArgs(policy=policy)
assert args.policy == policy
def test_pd_policy_validation(self):
"""Test PD policy configuration validation."""
# Valid PD policies
valid_policies = ["random", "round_robin", "cache_aware", "power_of_two"]
for prefill_policy in valid_policies:
for decode_policy in valid_policies:
args = RouterArgs(
pd_disaggregation=True,
prefill_urls=[("http://prefill1:8000", None)],
decode_urls=["http://decode1:8001"],
prefill_policy=prefill_policy,
decode_policy=decode_policy,
)
assert args.prefill_policy == prefill_policy
assert args.decode_policy == decode_policy
def test_cors_validation(self):
"""Test CORS configuration validation."""
# Valid CORS origins
valid_origins = [
[],
["http://localhost:3000"],
["https://example.com"],
["http://localhost:3000", "https://example.com"],
["*"], # Wildcard (if supported)
]
for origins in valid_origins:
args = RouterArgs(cors_allowed_origins=origins)
assert args.cors_allowed_origins == origins
def test_logging_validation(self):
"""Test logging configuration validation."""
# Valid log levels
valid_log_levels = ["debug", "info", "warning", "error", "critical"]
for level in valid_log_levels:
args = RouterArgs(log_level=level)
assert args.log_level == level
def test_prometheus_validation(self):
"""Test Prometheus configuration validation."""
# Valid Prometheus configuration
args = RouterArgs(prometheus_port=29000, prometheus_host="127.0.0.1")
assert args.prometheus_port == 29000
assert args.prometheus_host == "127.0.0.1"
def test_tokenizer_validation(self):
"""Test tokenizer configuration validation."""
# Note: model_path and tokenizer_path are not available in current RouterArgs
pytest.skip("Tokenizer configuration not available in current implementation")
def test_request_id_headers_validation(self):
"""Test request ID headers configuration validation."""
# Valid request ID headers
valid_headers = [
["x-request-id"],
["x-request-id", "x-trace-id"],
["x-request-id", "x-trace-id", "x-correlation-id"],
["custom-header"],
]
for headers in valid_headers:
args = RouterArgs(request_id_headers=headers)
assert args.request_id_headers == headers
class TestLaunchValidation:
"""Test launch-time validation logic."""
def test_pd_mode_requires_urls(self):
"""Test that PD mode requires prefill and decode URLs."""
# PD mode without URLs should fail
args = RouterArgs(
pd_disaggregation=True,
prefill_urls=[],
decode_urls=[],
service_discovery=False,
)
with pytest.raises(
ValueError, match="PD disaggregation mode requires --prefill"
):
launch_router(args)
def test_pd_mode_with_service_discovery_allows_empty_urls(self):
"""Test that PD mode with service discovery allows empty URLs."""
args = RouterArgs(
pd_disaggregation=True,
prefill_urls=[],
decode_urls=[],
service_discovery=True,
)
# Should not raise validation error
with patch("sglang_router.launch_router.Router") as router_mod:
mock_router_instance = MagicMock()
router_mod.from_args = MagicMock(return_value=mock_router_instance)
launch_router(args)
# Should create router instance via from_args
router_mod.from_args.assert_called_once()
def test_regular_mode_allows_empty_worker_urls(self):
"""Test that regular mode allows empty worker URLs."""
args = RouterArgs(worker_urls=[], service_discovery=False)
# Should not raise validation error
with patch("sglang_router.launch_router.Router") as router_mod:
mock_router_instance = MagicMock()
router_mod.from_args = MagicMock(return_value=mock_router_instance)
launch_router(args)
# Should create router instance via from_args
router_mod.from_args.assert_called_once()
def test_launch_with_valid_config(self):
"""Test launching with valid configuration."""
args = RouterArgs(
host="127.0.0.1",
port=30000,
worker_urls=["http://worker1:8000"],
policy="cache_aware",
)
# Should not raise validation error
with patch("sglang_router.launch_router.Router") as router_mod:
mock_router_instance = MagicMock()
router_mod.from_args = MagicMock(return_value=mock_router_instance)
launch_router(args)
# Should create router instance via from_args
router_mod.from_args.assert_called_once()
def test_launch_with_pd_config(self):
"""Test launching with valid PD configuration."""
args = RouterArgs(
pd_disaggregation=True,
prefill_urls=[("http://prefill1:8000", 9000)],
decode_urls=["http://decode1:8001"],
policy="cache_aware",
)
# Should not raise validation error
with patch("sglang_router.launch_router.Router") as router_mod:
mock_router_instance = MagicMock()
router_mod.from_args = MagicMock(return_value=mock_router_instance)
launch_router(args)
# Should create router instance via from_args
router_mod.from_args.assert_called_once()
def test_launch_with_service_discovery_config(self):
"""Test launching with valid service discovery configuration."""
args = RouterArgs(
service_discovery=True,
selector={"app": "worker"},
service_discovery_port=8080,
)
# Should not raise validation error
with patch("sglang_router.launch_router.Router") as router_mod:
mock_router_instance = MagicMock()
router_mod.from_args = MagicMock(return_value=mock_router_instance)
launch_router(args)
# Should create router instance via from_args
router_mod.from_args.assert_called_once()
......@@ -21,6 +21,7 @@ dev = [
"requests>=2.25.0",
]
# https://github.com/PyO3/setuptools-rust?tab=readme-ov-file
[tool.setuptools.packages]
find = { where = ["py_src"] }
......
[pytest]
testpaths = py_test
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts = --cov=sglang_router --cov-report=term-missing
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment