Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
045ab92d
Unverified
Commit
045ab92d
authored
Sep 05, 2025
by
Keyang Ru
Committed by
GitHub
Sep 05, 2025
Browse files
[router] add py binding unit tests to coverage 80% (#10043)
parent
bd7f8821
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
2649 additions
and
1 deletion
+2649
-1
.github/workflows/pr-test-rust.yml
.github/workflows/pr-test-rust.yml
+10
-1
sgl-router/.coveragerc
sgl-router/.coveragerc
+9
-0
sgl-router/py_test/conftest.py
sgl-router/py_test/conftest.py
+8
-0
sgl-router/py_test/unit/__init__.py
sgl-router/py_test/unit/__init__.py
+7
-0
sgl-router/py_test/unit/test_arg_parser.py
sgl-router/py_test/unit/test_arg_parser.py
+628
-0
sgl-router/py_test/unit/test_router_config.py
sgl-router/py_test/unit/test_router_config.py
+421
-0
sgl-router/py_test/unit/test_startup_sequence.py
sgl-router/py_test/unit/test_startup_sequence.py
+1053
-0
sgl-router/py_test/unit/test_validation.py
sgl-router/py_test/unit/test_validation.py
+506
-0
sgl-router/pyproject.toml
sgl-router/pyproject.toml
+1
-0
sgl-router/pytest.ini
sgl-router/pytest.ini
+6
-0
No files found.
.github/workflows/pr-test-rust.yml
View file @
045ab92d
...
...
@@ -39,7 +39,7 @@ jobs:
cd sgl-router/
cargo fmt -- --check
-
name
:
Run test
-
name
:
Run
Rust
test
s
timeout-minutes
:
20
run
:
|
source "$HOME/.cargo/env"
...
...
@@ -83,6 +83,15 @@ jobs:
pip install setuptools-rust wheel build
python3 -m build
pip install --force-reinstall dist/*.whl
-
name
:
Run Python unit tests
run
:
|
cd sgl-router
source "$HOME/.cargo/env"
pip install pytest pytest-cov pytest-xdist
pytest -q py_test/unit
-
name
:
Run e2e test
run
:
|
bash scripts/killall_sglang.sh "nuk_gpus"
...
...
sgl-router/.coveragerc
0 → 100644
View file @
045ab92d
[run]
source = py_src/sglang_router
omit =
py_src/sglang_router/mini_lb.py
[report]
fail_under = 80
omit =
py_src/sglang_router/mini_lb.py
sgl-router/py_test/conftest.py
0 → 100644
View file @
045ab92d
import
sys
from
pathlib
import
Path
# Ensure local sources in py_src are importable ahead of any installed package
_ROOT
=
Path
(
__file__
).
resolve
().
parents
[
1
]
_SRC
=
_ROOT
/
"py_src"
if
str
(
_SRC
)
not
in
sys
.
path
:
sys
.
path
.
insert
(
0
,
str
(
_SRC
))
sgl-router/py_test/unit/__init__.py
0 → 100644
View file @
045ab92d
"""
Unit tests for sglang_router.
This package contains fast, isolated unit tests for Python components
of the SGLang router. These tests focus on testing individual functions
and classes in isolation without starting actual router instances.
"""
sgl-router/py_test/unit/test_arg_parser.py
0 → 100644
View file @
045ab92d
"""
Unit tests for argument parsing functionality in sglang_router.
These tests focus on testing the argument parsing logic in isolation,
without starting actual router instances.
"""
import
argparse
from
types
import
SimpleNamespace
from
unittest.mock
import
MagicMock
,
patch
import
pytest
from
sglang_router.launch_router
import
RouterArgs
,
parse_router_args
from
sglang_router.router
import
policy_from_str
class
TestRouterArgs
:
"""Test RouterArgs dataclass and its methods."""
def
test_default_values
(
self
):
"""Test that RouterArgs has correct default values."""
args
=
RouterArgs
()
# Test basic defaults
assert
args
.
host
==
"127.0.0.1"
assert
args
.
port
==
30000
assert
args
.
policy
==
"cache_aware"
assert
args
.
worker_urls
==
[]
assert
args
.
pd_disaggregation
is
False
assert
args
.
prefill_urls
==
[]
assert
args
.
decode_urls
==
[]
# Test PD-specific defaults
assert
args
.
prefill_policy
is
None
assert
args
.
decode_policy
is
None
# Test service discovery defaults
assert
args
.
service_discovery
is
False
assert
args
.
selector
==
{}
assert
args
.
service_discovery_port
==
80
assert
args
.
service_discovery_namespace
is
None
# Test retry and circuit breaker defaults
assert
args
.
retry_max_retries
==
5
assert
args
.
cb_failure_threshold
==
10
assert
args
.
disable_retries
is
False
assert
args
.
disable_circuit_breaker
is
False
def
test_parse_selector_valid
(
self
):
"""Test parsing valid selector arguments."""
# Test single key-value pair
result
=
RouterArgs
.
_parse_selector
([
"app=worker"
])
assert
result
==
{
"app"
:
"worker"
}
# Test multiple key-value pairs
result
=
RouterArgs
.
_parse_selector
([
"app=worker"
,
"env=prod"
,
"version=v1"
])
assert
result
==
{
"app"
:
"worker"
,
"env"
:
"prod"
,
"version"
:
"v1"
}
# Test empty list
result
=
RouterArgs
.
_parse_selector
([])
assert
result
==
{}
# Test None
result
=
RouterArgs
.
_parse_selector
(
None
)
assert
result
==
{}
def
test_parse_selector_invalid
(
self
):
"""Test parsing invalid selector arguments."""
# Test malformed selector (no equals sign)
result
=
RouterArgs
.
_parse_selector
([
"app"
])
assert
result
==
{}
# Test multiple equals signs (should use first one)
result
=
RouterArgs
.
_parse_selector
([
"app=worker=extra"
])
assert
result
==
{
"app"
:
"worker=extra"
}
def
test_parse_prefill_urls_valid
(
self
):
"""Test parsing valid prefill URL arguments."""
# Test with bootstrap port
result
=
RouterArgs
.
_parse_prefill_urls
([[
"http://prefill1:8000"
,
"9000"
]])
assert
result
==
[(
"http://prefill1:8000"
,
9000
)]
# Test with 'none' bootstrap port
result
=
RouterArgs
.
_parse_prefill_urls
([[
"http://prefill1:8000"
,
"none"
]])
assert
result
==
[(
"http://prefill1:8000"
,
None
)]
# Test without bootstrap port
result
=
RouterArgs
.
_parse_prefill_urls
([[
"http://prefill1:8000"
]])
assert
result
==
[(
"http://prefill1:8000"
,
None
)]
# Test multiple prefill URLs
result
=
RouterArgs
.
_parse_prefill_urls
(
[
[
"http://prefill1:8000"
,
"9000"
],
[
"http://prefill2:8000"
,
"none"
],
[
"http://prefill3:8000"
],
]
)
expected
=
[
(
"http://prefill1:8000"
,
9000
),
(
"http://prefill2:8000"
,
None
),
(
"http://prefill3:8000"
,
None
),
]
assert
result
==
expected
# Test empty list
result
=
RouterArgs
.
_parse_prefill_urls
([])
assert
result
==
[]
# Test None
result
=
RouterArgs
.
_parse_prefill_urls
(
None
)
assert
result
==
[]
def
test_parse_prefill_urls_invalid
(
self
):
"""Test parsing invalid prefill URL arguments."""
# Test invalid bootstrap port
with
pytest
.
raises
(
ValueError
,
match
=
"Invalid bootstrap port"
):
RouterArgs
.
_parse_prefill_urls
([[
"http://prefill1:8000"
,
"invalid"
]])
def
test_parse_decode_urls_valid
(
self
):
"""Test parsing valid decode URL arguments."""
# Test single decode URL
result
=
RouterArgs
.
_parse_decode_urls
([[
"http://decode1:8001"
]])
assert
result
==
[
"http://decode1:8001"
]
# Test multiple decode URLs
result
=
RouterArgs
.
_parse_decode_urls
(
[[
"http://decode1:8001"
],
[
"http://decode2:8001"
]]
)
assert
result
==
[
"http://decode1:8001"
,
"http://decode2:8001"
]
# Test empty list
result
=
RouterArgs
.
_parse_decode_urls
([])
assert
result
==
[]
# Test None
result
=
RouterArgs
.
_parse_decode_urls
(
None
)
assert
result
==
[]
def
test_from_cli_args_basic
(
self
):
"""Test creating RouterArgs from basic CLI arguments."""
args
=
SimpleNamespace
(
host
=
"0.0.0.0"
,
port
=
30001
,
worker_urls
=
[
"http://worker1:8000"
,
"http://worker2:8000"
],
policy
=
"round_robin"
,
prefill
=
None
,
decode
=
None
,
router_policy
=
"round_robin"
,
router_pd_disaggregation
=
False
,
router_prefill_policy
=
None
,
router_decode_policy
=
None
,
router_worker_startup_timeout_secs
=
300
,
router_worker_startup_check_interval
=
15
,
router_cache_threshold
=
0.7
,
router_balance_abs_threshold
=
128
,
router_balance_rel_threshold
=
2.0
,
router_eviction_interval
=
180
,
router_max_tree_size
=
2
**
28
,
router_max_payload_size
=
1024
*
1024
*
1024
,
# 1GB
router_dp_aware
=
True
,
router_api_key
=
"test-key"
,
router_log_dir
=
"/tmp/logs"
,
router_log_level
=
"debug"
,
router_service_discovery
=
True
,
router_selector
=
[
"app=worker"
,
"env=test"
],
router_service_discovery_port
=
8080
,
router_service_discovery_namespace
=
"default"
,
router_prefill_selector
=
[
"app=prefill"
],
router_decode_selector
=
[
"app=decode"
],
router_prometheus_port
=
29000
,
router_prometheus_host
=
"0.0.0.0"
,
router_request_id_headers
=
[
"x-request-id"
,
"x-trace-id"
],
router_request_timeout_secs
=
1200
,
router_max_concurrent_requests
=
512
,
router_queue_size
=
200
,
router_queue_timeout_secs
=
120
,
router_rate_limit_tokens_per_second
=
100
,
router_cors_allowed_origins
=
[
"http://localhost:3000"
],
router_retry_max_retries
=
3
,
router_retry_initial_backoff_ms
=
100
,
router_retry_max_backoff_ms
=
10000
,
router_retry_backoff_multiplier
=
2.0
,
router_retry_jitter_factor
=
0.1
,
router_cb_failure_threshold
=
5
,
router_cb_success_threshold
=
2
,
router_cb_timeout_duration_secs
=
30
,
router_cb_window_duration_secs
=
60
,
router_disable_retries
=
False
,
router_disable_circuit_breaker
=
False
,
router_health_failure_threshold
=
2
,
router_health_success_threshold
=
1
,
router_health_check_timeout_secs
=
3
,
router_health_check_interval_secs
=
30
,
router_health_check_endpoint
=
"/healthz"
,
)
router_args
=
RouterArgs
.
from_cli_args
(
args
,
use_router_prefix
=
True
)
# Test basic configuration
assert
router_args
.
host
==
"0.0.0.0"
assert
router_args
.
port
==
30001
assert
router_args
.
worker_urls
==
[
"http://worker1:8000"
,
"http://worker2:8000"
]
assert
router_args
.
policy
==
"round_robin"
# Test PD configuration
assert
router_args
.
pd_disaggregation
is
False
assert
router_args
.
prefill_urls
==
[]
assert
router_args
.
decode_urls
==
[]
# Test service discovery
assert
router_args
.
service_discovery
is
True
assert
router_args
.
selector
==
{
"app"
:
"worker"
,
"env"
:
"test"
}
assert
router_args
.
service_discovery_port
==
8080
assert
router_args
.
service_discovery_namespace
==
"default"
assert
router_args
.
prefill_selector
==
{
"app"
:
"prefill"
}
assert
router_args
.
decode_selector
==
{
"app"
:
"decode"
}
# Test other configurations
assert
router_args
.
dp_aware
is
True
assert
router_args
.
api_key
==
"test-key"
assert
router_args
.
log_dir
==
"/tmp/logs"
assert
router_args
.
log_level
==
"debug"
assert
router_args
.
prometheus_port
==
29000
assert
router_args
.
prometheus_host
==
"0.0.0.0"
assert
router_args
.
request_id_headers
==
[
"x-request-id"
,
"x-trace-id"
]
assert
router_args
.
request_timeout_secs
==
1200
assert
router_args
.
max_concurrent_requests
==
512
assert
router_args
.
queue_size
==
200
assert
router_args
.
queue_timeout_secs
==
120
assert
router_args
.
rate_limit_tokens_per_second
==
100
assert
router_args
.
cors_allowed_origins
==
[
"http://localhost:3000"
]
# Test retry configuration
assert
router_args
.
retry_max_retries
==
3
assert
router_args
.
retry_initial_backoff_ms
==
100
assert
router_args
.
retry_max_backoff_ms
==
10000
assert
router_args
.
retry_backoff_multiplier
==
2.0
assert
router_args
.
retry_jitter_factor
==
0.1
# Test circuit breaker configuration
assert
router_args
.
cb_failure_threshold
==
5
assert
router_args
.
cb_success_threshold
==
2
assert
router_args
.
cb_timeout_duration_secs
==
30
assert
router_args
.
cb_window_duration_secs
==
60
assert
router_args
.
disable_retries
is
False
assert
router_args
.
disable_circuit_breaker
is
False
# Test health check configuration
assert
router_args
.
health_failure_threshold
==
2
assert
router_args
.
health_success_threshold
==
1
assert
router_args
.
health_check_timeout_secs
==
3
assert
router_args
.
health_check_interval_secs
==
30
assert
router_args
.
health_check_endpoint
==
"/healthz"
# Note: model_path and tokenizer_path are not available in current RouterArgs
def
test_from_cli_args_pd_mode
(
self
):
"""Test creating RouterArgs from CLI arguments in PD mode."""
args
=
SimpleNamespace
(
host
=
"127.0.0.1"
,
port
=
30000
,
worker_urls
=
[],
policy
=
"cache_aware"
,
prefill
=
[
[
"http://prefill1:8000"
,
"9000"
],
[
"http://prefill2:8000"
,
"none"
],
],
decode
=
[[
"http://decode1:8001"
],
[
"http://decode2:8001"
]],
router_prefill
=
[
[
"http://prefill1:8000"
,
"9000"
],
[
"http://prefill2:8000"
,
"none"
],
],
router_decode
=
[[
"http://decode1:8001"
],
[
"http://decode2:8001"
]],
router_policy
=
"cache_aware"
,
router_pd_disaggregation
=
True
,
router_prefill_policy
=
"power_of_two"
,
router_decode_policy
=
"round_robin"
,
# Include all required fields with defaults
router_worker_startup_timeout_secs
=
600
,
router_worker_startup_check_interval
=
30
,
router_cache_threshold
=
0.3
,
router_balance_abs_threshold
=
64
,
router_balance_rel_threshold
=
1.5
,
router_eviction_interval
=
120
,
router_max_tree_size
=
2
**
26
,
router_max_payload_size
=
512
*
1024
*
1024
,
router_dp_aware
=
False
,
router_api_key
=
None
,
router_log_dir
=
None
,
router_log_level
=
None
,
router_service_discovery
=
False
,
router_selector
=
None
,
router_service_discovery_port
=
80
,
router_service_discovery_namespace
=
None
,
router_prefill_selector
=
None
,
router_decode_selector
=
None
,
router_prometheus_port
=
None
,
router_prometheus_host
=
None
,
router_request_id_headers
=
None
,
router_request_timeout_secs
=
1800
,
router_max_concurrent_requests
=
256
,
router_queue_size
=
100
,
router_queue_timeout_secs
=
60
,
router_rate_limit_tokens_per_second
=
None
,
router_cors_allowed_origins
=
[],
router_retry_max_retries
=
5
,
router_retry_initial_backoff_ms
=
50
,
router_retry_max_backoff_ms
=
30000
,
router_retry_backoff_multiplier
=
1.5
,
router_retry_jitter_factor
=
0.2
,
router_cb_failure_threshold
=
10
,
router_cb_success_threshold
=
3
,
router_cb_timeout_duration_secs
=
60
,
router_cb_window_duration_secs
=
120
,
router_disable_retries
=
False
,
router_disable_circuit_breaker
=
False
,
router_health_failure_threshold
=
3
,
router_health_success_threshold
=
2
,
router_health_check_timeout_secs
=
5
,
router_health_check_interval_secs
=
60
,
router_health_check_endpoint
=
"/health"
,
)
router_args
=
RouterArgs
.
from_cli_args
(
args
,
use_router_prefix
=
True
)
# Test PD configuration
assert
router_args
.
pd_disaggregation
is
True
assert
router_args
.
prefill_urls
==
[
(
"http://prefill1:8000"
,
9000
),
(
"http://prefill2:8000"
,
None
),
]
assert
router_args
.
decode_urls
==
[
"http://decode1:8001"
,
"http://decode2:8001"
]
assert
router_args
.
prefill_policy
==
"power_of_two"
assert
router_args
.
decode_policy
==
"round_robin"
assert
router_args
.
policy
==
"cache_aware"
# Main policy still set
def
test_from_cli_args_without_prefix
(
self
):
"""Test creating RouterArgs from CLI arguments without router prefix."""
args
=
SimpleNamespace
(
host
=
"127.0.0.1"
,
port
=
30000
,
worker_urls
=
[
"http://worker1:8000"
],
policy
=
"random"
,
prefill
=
None
,
decode
=
None
,
pd_disaggregation
=
False
,
prefill_policy
=
None
,
decode_policy
=
None
,
worker_startup_timeout_secs
=
600
,
worker_startup_check_interval
=
30
,
cache_threshold
=
0.3
,
balance_abs_threshold
=
64
,
balance_rel_threshold
=
1.5
,
eviction_interval
=
120
,
max_tree_size
=
2
**
26
,
max_payload_size
=
512
*
1024
*
1024
,
dp_aware
=
False
,
api_key
=
None
,
log_dir
=
None
,
log_level
=
None
,
service_discovery
=
False
,
selector
=
None
,
service_discovery_port
=
80
,
service_discovery_namespace
=
None
,
prefill_selector
=
None
,
decode_selector
=
None
,
prometheus_port
=
None
,
prometheus_host
=
None
,
request_id_headers
=
None
,
request_timeout_secs
=
1800
,
max_concurrent_requests
=
256
,
queue_size
=
100
,
queue_timeout_secs
=
60
,
rate_limit_tokens_per_second
=
None
,
cors_allowed_origins
=
[],
retry_max_retries
=
5
,
retry_initial_backoff_ms
=
50
,
retry_max_backoff_ms
=
30000
,
retry_backoff_multiplier
=
1.5
,
retry_jitter_factor
=
0.2
,
cb_failure_threshold
=
10
,
cb_success_threshold
=
3
,
cb_timeout_duration_secs
=
60
,
cb_window_duration_secs
=
120
,
disable_retries
=
False
,
disable_circuit_breaker
=
False
,
health_failure_threshold
=
3
,
health_success_threshold
=
2
,
health_check_timeout_secs
=
5
,
health_check_interval_secs
=
60
,
health_check_endpoint
=
"/health"
,
model_path
=
None
,
tokenizer_path
=
None
,
)
router_args
=
RouterArgs
.
from_cli_args
(
args
,
use_router_prefix
=
False
)
assert
router_args
.
host
==
"127.0.0.1"
assert
router_args
.
port
==
30000
assert
router_args
.
worker_urls
==
[
"http://worker1:8000"
]
assert
router_args
.
policy
==
"random"
assert
router_args
.
pd_disaggregation
is
False
class
TestPolicyFromStr
:
"""Test policy string to enum conversion."""
def
test_valid_policies
(
self
):
"""Test conversion of valid policy strings."""
from
sglang_router_rs
import
PolicyType
assert
policy_from_str
(
"random"
)
==
PolicyType
.
Random
assert
policy_from_str
(
"round_robin"
)
==
PolicyType
.
RoundRobin
assert
policy_from_str
(
"cache_aware"
)
==
PolicyType
.
CacheAware
assert
policy_from_str
(
"power_of_two"
)
==
PolicyType
.
PowerOfTwo
def
test_invalid_policy
(
self
):
"""Test conversion of invalid policy string."""
with
pytest
.
raises
(
KeyError
):
policy_from_str
(
"invalid_policy"
)
class
TestParseRouterArgs
:
"""Test the parse_router_args function."""
def
test_parse_basic_args
(
self
):
"""Test parsing basic router arguments."""
args
=
[
"--host"
,
"0.0.0.0"
,
"--port"
,
"30001"
,
"--worker-urls"
,
"http://worker1:8000"
,
"http://worker2:8000"
,
"--policy"
,
"round_robin"
,
]
router_args
=
parse_router_args
(
args
)
assert
router_args
.
host
==
"0.0.0.0"
assert
router_args
.
port
==
30001
assert
router_args
.
worker_urls
==
[
"http://worker1:8000"
,
"http://worker2:8000"
]
assert
router_args
.
policy
==
"round_robin"
def
test_parse_pd_args
(
self
):
"""Test parsing PD disaggregated mode arguments."""
args
=
[
"--pd-disaggregation"
,
"--prefill"
,
"http://prefill1:8000"
,
"9000"
,
"--prefill"
,
"http://prefill2:8000"
,
"none"
,
"--decode"
,
"http://decode1:8001"
,
"--decode"
,
"http://decode2:8001"
,
"--prefill-policy"
,
"power_of_two"
,
"--decode-policy"
,
"round_robin"
,
]
router_args
=
parse_router_args
(
args
)
assert
router_args
.
pd_disaggregation
is
True
assert
router_args
.
prefill_urls
==
[
(
"http://prefill1:8000"
,
9000
),
(
"http://prefill2:8000"
,
None
),
]
assert
router_args
.
decode_urls
==
[
"http://decode1:8001"
,
"http://decode2:8001"
]
assert
router_args
.
prefill_policy
==
"power_of_two"
assert
router_args
.
decode_policy
==
"round_robin"
def
test_parse_service_discovery_args
(
self
):
"""Test parsing service discovery arguments."""
args
=
[
"--service-discovery"
,
"--selector"
,
"app=worker"
,
"env=prod"
,
"--service-discovery-port"
,
"8080"
,
"--service-discovery-namespace"
,
"default"
,
]
router_args
=
parse_router_args
(
args
)
assert
router_args
.
service_discovery
is
True
assert
router_args
.
selector
==
{
"app"
:
"worker"
,
"env"
:
"prod"
}
assert
router_args
.
service_discovery_port
==
8080
assert
router_args
.
service_discovery_namespace
==
"default"
def
test_parse_retry_and_circuit_breaker_args
(
self
):
"""Test parsing retry and circuit breaker arguments."""
args
=
[
"--retry-max-retries"
,
"3"
,
"--retry-initial-backoff-ms"
,
"100"
,
"--retry-max-backoff-ms"
,
"10000"
,
"--retry-backoff-multiplier"
,
"2.0"
,
"--retry-jitter-factor"
,
"0.1"
,
"--disable-retries"
,
"--cb-failure-threshold"
,
"5"
,
"--cb-success-threshold"
,
"2"
,
"--cb-timeout-duration-secs"
,
"30"
,
"--cb-window-duration-secs"
,
"60"
,
"--disable-circuit-breaker"
,
]
router_args
=
parse_router_args
(
args
)
# Test retry configuration
assert
router_args
.
retry_max_retries
==
3
assert
router_args
.
retry_initial_backoff_ms
==
100
assert
router_args
.
retry_max_backoff_ms
==
10000
assert
router_args
.
retry_backoff_multiplier
==
2.0
assert
router_args
.
retry_jitter_factor
==
0.1
assert
router_args
.
disable_retries
is
True
# Test circuit breaker configuration
assert
router_args
.
cb_failure_threshold
==
5
assert
router_args
.
cb_success_threshold
==
2
assert
router_args
.
cb_timeout_duration_secs
==
30
assert
router_args
.
cb_window_duration_secs
==
60
assert
router_args
.
disable_circuit_breaker
is
True
def
test_parse_rate_limiting_args
(
self
):
"""Test parsing rate limiting arguments."""
args
=
[
"--max-concurrent-requests"
,
"512"
,
"--queue-size"
,
"200"
,
"--queue-timeout-secs"
,
"120"
,
"--rate-limit-tokens-per-second"
,
"100"
,
]
router_args
=
parse_router_args
(
args
)
assert
router_args
.
max_concurrent_requests
==
512
assert
router_args
.
queue_size
==
200
assert
router_args
.
queue_timeout_secs
==
120
assert
router_args
.
rate_limit_tokens_per_second
==
100
def
test_parse_health_check_args
(
self
):
"""Test parsing health check arguments."""
args
=
[
"--health-failure-threshold"
,
"2"
,
"--health-success-threshold"
,
"1"
,
"--health-check-timeout-secs"
,
"3"
,
"--health-check-interval-secs"
,
"30"
,
"--health-check-endpoint"
,
"/healthz"
,
]
router_args
=
parse_router_args
(
args
)
assert
router_args
.
health_failure_threshold
==
2
assert
router_args
.
health_success_threshold
==
1
assert
router_args
.
health_check_timeout_secs
==
3
assert
router_args
.
health_check_interval_secs
==
30
assert
router_args
.
health_check_endpoint
==
"/healthz"
def
test_parse_cors_args
(
self
):
"""Test parsing CORS arguments."""
args
=
[
"--cors-allowed-origins"
,
"http://localhost:3000"
,
"https://example.com"
,
]
router_args
=
parse_router_args
(
args
)
assert
router_args
.
cors_allowed_origins
==
[
"http://localhost:3000"
,
"https://example.com"
,
]
def
test_parse_tokenizer_args
(
self
):
"""Test parsing tokenizer arguments."""
# Note: model-path and tokenizer-path arguments are not available in current implementation
# This test is skipped until those arguments are added
pytest
.
skip
(
"Tokenizer arguments not available in current implementation"
)
def
test_parse_invalid_args
(
self
):
"""Test parsing invalid arguments."""
# Test invalid policy
with
pytest
.
raises
(
SystemExit
):
parse_router_args
([
"--policy"
,
"invalid_policy"
])
# Test invalid bootstrap port
with
pytest
.
raises
(
ValueError
,
match
=
"Invalid bootstrap port"
):
parse_router_args
(
[
"--pd-disaggregation"
,
"--prefill"
,
"http://prefill1:8000"
,
"invalid_port"
,
]
)
def
test_help_output
(
self
):
"""Test that help output is generated correctly."""
with
pytest
.
raises
(
SystemExit
)
as
exc_info
:
parse_router_args
([
"--help"
])
# SystemExit with code 0 indicates help was displayed
assert
exc_info
.
value
.
code
==
0
sgl-router/py_test/unit/test_router_config.py
0 → 100644
View file @
045ab92d
"""
Unit tests for router configuration validation and setup.
These tests focus on testing the router configuration logic in isolation,
including validation of configuration parameters and their interactions.
"""
from
types
import
SimpleNamespace
from
unittest.mock
import
MagicMock
,
patch
import
pytest
from
sglang_router.launch_router
import
RouterArgs
,
launch_router
from
sglang_router.router
import
policy_from_str
from
sglang_router_rs
import
PolicyType
class
TestRouterConfigValidation
:
"""Test router configuration validation logic."""
def
test_valid_basic_config
(
self
):
"""Test that a valid basic configuration passes validation."""
args
=
RouterArgs
(
host
=
"127.0.0.1"
,
port
=
30000
,
worker_urls
=
[
"http://worker1:8000"
,
"http://worker2:8000"
],
policy
=
"cache_aware"
,
)
# Should not raise any exceptions
assert
args
.
host
==
"127.0.0.1"
assert
args
.
port
==
30000
assert
args
.
worker_urls
==
[
"http://worker1:8000"
,
"http://worker2:8000"
]
assert
args
.
policy
==
"cache_aware"
def
test_valid_pd_config
(
self
):
"""Test that a valid PD configuration passes validation."""
args
=
RouterArgs
(
host
=
"127.0.0.1"
,
port
=
30000
,
pd_disaggregation
=
True
,
prefill_urls
=
[
(
"http://prefill1:8000"
,
9000
),
(
"http://prefill2:8000"
,
None
),
],
decode_urls
=
[
"http://decode1:8001"
,
"http://decode2:8001"
],
policy
=
"cache_aware"
,
)
assert
args
.
pd_disaggregation
is
True
assert
args
.
prefill_urls
==
[
(
"http://prefill1:8000"
,
9000
),
(
"http://prefill2:8000"
,
None
),
]
assert
args
.
decode_urls
==
[
"http://decode1:8001"
,
"http://decode2:8001"
]
assert
args
.
policy
==
"cache_aware"
def
test_pd_config_without_urls_raises_error
(
self
):
"""Test that PD mode without URLs raises validation error."""
args
=
RouterArgs
(
pd_disaggregation
=
True
,
prefill_urls
=
[],
decode_urls
=
[],
service_discovery
=
False
,
)
# This should raise an error when trying to launch
with
pytest
.
raises
(
ValueError
,
match
=
"PD disaggregation mode requires --prefill"
):
launch_router
(
args
)
def
test_pd_config_with_service_discovery_allows_empty_urls
(
self
):
"""Test that PD mode with service discovery allows empty URLs."""
args
=
RouterArgs
(
pd_disaggregation
=
True
,
prefill_urls
=
[],
decode_urls
=
[],
service_discovery
=
True
,
)
# Should not raise validation error when service discovery is enabled
with
patch
(
"sglang_router.launch_router.Router"
)
as
router_mod
:
mock_router_instance
=
MagicMock
()
router_mod
.
from_args
=
MagicMock
(
return_value
=
mock_router_instance
)
launch_router
(
args
)
# Should create router instance via from_args
router_mod
.
from_args
.
assert_called_once
()
def
test_regular_mode_without_workers_allows_empty_urls
(
self
):
"""Test that regular mode allows empty worker URLs."""
args
=
RouterArgs
(
worker_urls
=
[],
service_discovery
=
False
)
# Should not raise validation error
with
patch
(
"sglang_router.launch_router.Router"
)
as
router_mod
:
mock_router_instance
=
MagicMock
()
router_mod
.
from_args
=
MagicMock
(
return_value
=
mock_router_instance
)
launch_router
(
args
)
# Should create router instance via from_args
router_mod
.
from_args
.
assert_called_once
()
def
test_cache_threshold_validation
(
self
):
"""Test cache threshold validation."""
# Valid cache threshold
args
=
RouterArgs
(
cache_threshold
=
0.5
)
assert
args
.
cache_threshold
==
0.5
# Edge cases
args
=
RouterArgs
(
cache_threshold
=
0.0
)
assert
args
.
cache_threshold
==
0.0
args
=
RouterArgs
(
cache_threshold
=
1.0
)
assert
args
.
cache_threshold
==
1.0
def
test_balance_threshold_validation
(
self
):
"""Test load balancing threshold validation."""
# Valid thresholds
args
=
RouterArgs
(
balance_abs_threshold
=
64
,
balance_rel_threshold
=
1.5
)
assert
args
.
balance_abs_threshold
==
64
assert
args
.
balance_rel_threshold
==
1.5
# Edge cases
args
=
RouterArgs
(
balance_abs_threshold
=
0
,
balance_rel_threshold
=
1.0
)
assert
args
.
balance_abs_threshold
==
0
assert
args
.
balance_rel_threshold
==
1.0
def
test_timeout_validation
(
self
):
"""Test timeout parameter validation."""
# Valid timeouts
args
=
RouterArgs
(
worker_startup_timeout_secs
=
600
,
worker_startup_check_interval
=
30
,
request_timeout_secs
=
1800
,
queue_timeout_secs
=
60
,
)
assert
args
.
worker_startup_timeout_secs
==
600
assert
args
.
worker_startup_check_interval
==
30
assert
args
.
request_timeout_secs
==
1800
assert
args
.
queue_timeout_secs
==
60
def
test_retry_config_validation
(
self
):
"""Test retry configuration validation."""
# Valid retry config
args
=
RouterArgs
(
retry_max_retries
=
5
,
retry_initial_backoff_ms
=
50
,
retry_max_backoff_ms
=
30000
,
retry_backoff_multiplier
=
1.5
,
retry_jitter_factor
=
0.2
,
disable_retries
=
False
,
)
assert
args
.
retry_max_retries
==
5
assert
args
.
retry_initial_backoff_ms
==
50
assert
args
.
retry_max_backoff_ms
==
30000
assert
args
.
retry_backoff_multiplier
==
1.5
assert
args
.
retry_jitter_factor
==
0.2
assert
args
.
disable_retries
is
False
def
test_circuit_breaker_config_validation
(
self
):
"""Test circuit breaker configuration validation."""
# Valid circuit breaker config
args
=
RouterArgs
(
cb_failure_threshold
=
10
,
cb_success_threshold
=
3
,
cb_timeout_duration_secs
=
60
,
cb_window_duration_secs
=
120
,
disable_circuit_breaker
=
False
,
)
assert
args
.
cb_failure_threshold
==
10
assert
args
.
cb_success_threshold
==
3
assert
args
.
cb_timeout_duration_secs
==
60
assert
args
.
cb_window_duration_secs
==
120
assert
args
.
disable_circuit_breaker
is
False
def
test_health_check_config_validation
(
self
):
"""Test health check configuration validation."""
# Valid health check config
args
=
RouterArgs
(
health_failure_threshold
=
3
,
health_success_threshold
=
2
,
health_check_timeout_secs
=
5
,
health_check_interval_secs
=
60
,
health_check_endpoint
=
"/health"
,
)
assert
args
.
health_failure_threshold
==
3
assert
args
.
health_success_threshold
==
2
assert
args
.
health_check_timeout_secs
==
5
assert
args
.
health_check_interval_secs
==
60
assert
args
.
health_check_endpoint
==
"/health"
def
test_rate_limiting_config_validation
(
self
):
"""Test rate limiting configuration validation."""
# Valid rate limiting config
args
=
RouterArgs
(
max_concurrent_requests
=
256
,
queue_size
=
100
,
queue_timeout_secs
=
60
,
rate_limit_tokens_per_second
=
100
,
)
assert
args
.
max_concurrent_requests
==
256
assert
args
.
queue_size
==
100
assert
args
.
queue_timeout_secs
==
60
assert
args
.
rate_limit_tokens_per_second
==
100
def
test_service_discovery_config_validation
(
self
):
"""Test service discovery configuration validation."""
# Valid service discovery config
args
=
RouterArgs
(
service_discovery
=
True
,
selector
=
{
"app"
:
"worker"
,
"env"
:
"prod"
},
service_discovery_port
=
8080
,
service_discovery_namespace
=
"default"
,
)
assert
args
.
service_discovery
is
True
assert
args
.
selector
==
{
"app"
:
"worker"
,
"env"
:
"prod"
}
assert
args
.
service_discovery_port
==
8080
assert
args
.
service_discovery_namespace
==
"default"
def
test_pd_service_discovery_config_validation
(
self
):
"""Test PD service discovery configuration validation."""
# Valid PD service discovery config
args
=
RouterArgs
(
pd_disaggregation
=
True
,
service_discovery
=
True
,
prefill_selector
=
{
"app"
:
"prefill"
},
decode_selector
=
{
"app"
:
"decode"
},
bootstrap_port_annotation
=
"sglang.ai/bootstrap-port"
,
)
assert
args
.
pd_disaggregation
is
True
assert
args
.
service_discovery
is
True
assert
args
.
prefill_selector
==
{
"app"
:
"prefill"
}
assert
args
.
decode_selector
==
{
"app"
:
"decode"
}
assert
args
.
bootstrap_port_annotation
==
"sglang.ai/bootstrap-port"
def
test_prometheus_config_validation
(
self
):
"""Test Prometheus configuration validation."""
# Valid Prometheus config
args
=
RouterArgs
(
prometheus_port
=
29000
,
prometheus_host
=
"127.0.0.1"
)
assert
args
.
prometheus_port
==
29000
assert
args
.
prometheus_host
==
"127.0.0.1"
def
test_cors_config_validation
(
self
):
"""Test CORS configuration validation."""
# Valid CORS config
args
=
RouterArgs
(
cors_allowed_origins
=
[
"http://localhost:3000"
,
"https://example.com"
]
)
assert
args
.
cors_allowed_origins
==
[
"http://localhost:3000"
,
"https://example.com"
,
]
def
test_tokenizer_config_validation
(
self
):
"""Test tokenizer configuration validation."""
# Note: model_path and tokenizer_path are not available in current RouterArgs
pytest
.
skip
(
"Tokenizer configuration not available in current implementation"
)
def
test_dp_aware_config_validation
(
self
):
"""Test data parallelism aware configuration validation."""
# Valid DP aware config
args
=
RouterArgs
(
dp_aware
=
True
,
api_key
=
"test-api-key"
)
assert
args
.
dp_aware
is
True
assert
args
.
api_key
==
"test-api-key"
def
test_request_id_headers_validation
(
self
):
"""Test request ID headers configuration validation."""
# Valid request ID headers config
args
=
RouterArgs
(
request_id_headers
=
[
"x-request-id"
,
"x-trace-id"
,
"x-correlation-id"
]
)
assert
args
.
request_id_headers
==
[
"x-request-id"
,
"x-trace-id"
,
"x-correlation-id"
,
]
def
test_policy_consistency_validation
(
self
):
"""Test policy consistency validation in PD mode."""
# Test with both prefill and decode policies specified
args
=
RouterArgs
(
pd_disaggregation
=
True
,
prefill_urls
=
[(
"http://prefill1:8000"
,
None
)],
decode_urls
=
[
"http://decode1:8001"
],
policy
=
"cache_aware"
,
prefill_policy
=
"power_of_two"
,
decode_policy
=
"round_robin"
,
)
# Should not raise validation error
with
patch
(
"sglang_router.launch_router.Router"
)
as
router_mod
:
mock_router_instance
=
MagicMock
()
router_mod
.
from_args
=
MagicMock
(
return_value
=
mock_router_instance
)
launch_router
(
args
)
# Should create router instance via from_args
router_mod
.
from_args
.
assert_called_once
()
def
test_policy_fallback_validation
(
self
):
"""Test policy fallback validation in PD mode."""
# Test with only prefill policy specified
args
=
RouterArgs
(
pd_disaggregation
=
True
,
prefill_urls
=
[(
"http://prefill1:8000"
,
None
)],
decode_urls
=
[
"http://decode1:8001"
],
policy
=
"cache_aware"
,
prefill_policy
=
"power_of_two"
,
decode_policy
=
None
,
)
# Should not raise validation error
with
patch
(
"sglang_router.launch_router.Router"
)
as
router_mod
:
mock_router_instance
=
MagicMock
()
router_mod
.
from_args
=
MagicMock
(
return_value
=
mock_router_instance
)
launch_router
(
args
)
# Should create router instance via from_args
router_mod
.
from_args
.
assert_called_once
()
def
test_policy_enum_conversion
(
self
):
"""Test policy string to enum conversion."""
# Test all valid policy conversions
assert
policy_from_str
(
"random"
)
==
PolicyType
.
Random
assert
policy_from_str
(
"round_robin"
)
==
PolicyType
.
RoundRobin
assert
policy_from_str
(
"cache_aware"
)
==
PolicyType
.
CacheAware
assert
policy_from_str
(
"power_of_two"
)
==
PolicyType
.
PowerOfTwo
def
test_invalid_policy_enum_conversion
(
self
):
"""Test invalid policy string to enum conversion."""
with
pytest
.
raises
(
KeyError
):
policy_from_str
(
"invalid_policy"
)
def
test_config_immutability
(
self
):
"""Test that configuration objects are properly immutable."""
args
=
RouterArgs
(
host
=
"127.0.0.1"
,
port
=
30000
,
worker_urls
=
[
"http://worker1:8000"
]
)
# Test that we can't modify the configuration after creation
# (This is more of a design test - dataclasses are mutable by default)
original_host
=
args
.
host
args
.
host
=
"0.0.0.0"
assert
args
.
host
==
"0.0.0.0"
# Dataclasses are mutable
assert
args
.
host
!=
original_host
def
test_config_defaults_consistency
(
self
):
"""Test that configuration defaults are consistent."""
args1
=
RouterArgs
()
args2
=
RouterArgs
()
# Both instances should have the same defaults
assert
args1
.
host
==
args2
.
host
assert
args1
.
port
==
args2
.
port
assert
args1
.
policy
==
args2
.
policy
assert
args1
.
worker_urls
==
args2
.
worker_urls
assert
args1
.
pd_disaggregation
==
args2
.
pd_disaggregation
def
test_config_serialization
(
self
):
"""Test that configuration can be serialized/deserialized."""
args
=
RouterArgs
(
host
=
"127.0.0.1"
,
port
=
30000
,
worker_urls
=
[
"http://worker1:8000"
],
policy
=
"cache_aware"
,
cache_threshold
=
0.5
,
)
# Test that we can access all attributes
assert
hasattr
(
args
,
"host"
)
assert
hasattr
(
args
,
"port"
)
assert
hasattr
(
args
,
"worker_urls"
)
assert
hasattr
(
args
,
"policy"
)
assert
hasattr
(
args
,
"cache_threshold"
)
def
test_config_with_none_values
(
self
):
"""Test configuration with None values."""
args
=
RouterArgs
(
api_key
=
None
,
log_dir
=
None
,
log_level
=
None
,
prometheus_port
=
None
,
prometheus_host
=
None
,
request_id_headers
=
None
,
rate_limit_tokens_per_second
=
None
,
service_discovery_namespace
=
None
,
)
# All None values should be preserved
assert
args
.
api_key
is
None
assert
args
.
log_dir
is
None
assert
args
.
log_level
is
None
assert
args
.
prometheus_port
is
None
assert
args
.
prometheus_host
is
None
assert
args
.
request_id_headers
is
None
assert
args
.
rate_limit_tokens_per_second
is
None
assert
args
.
service_discovery_namespace
is
None
def
test_config_with_empty_lists
(
self
):
"""Test configuration with empty lists."""
args
=
RouterArgs
(
worker_urls
=
[],
prefill_urls
=
[],
decode_urls
=
[],
cors_allowed_origins
=
[]
)
# All empty lists should be preserved
assert
args
.
worker_urls
==
[]
assert
args
.
prefill_urls
==
[]
assert
args
.
decode_urls
==
[]
assert
args
.
cors_allowed_origins
==
[]
def
test_config_with_empty_dicts
(
self
):
"""Test configuration with empty dictionaries."""
args
=
RouterArgs
(
selector
=
{},
prefill_selector
=
{},
decode_selector
=
{})
# All empty dictionaries should be preserved
assert
args
.
selector
==
{}
assert
args
.
prefill_selector
==
{}
assert
args
.
decode_selector
==
{}
sgl-router/py_test/unit/test_startup_sequence.py
0 → 100644
View file @
045ab92d
"""
Unit tests for startup sequence logic in sglang_router.
These tests focus on testing the startup sequence logic in isolation,
including router initialization, configuration validation, and startup flow.
"""
import
logging
from
types
import
SimpleNamespace
from
unittest.mock
import
MagicMock
,
call
,
patch
import
pytest
from
sglang_router.launch_router
import
RouterArgs
,
launch_router
from
sglang_router.router
import
policy_from_str
# Local helper mirroring the router logger setup used in production
def
setup_logger
():
logger
=
logging
.
getLogger
(
"router"
)
logger
.
setLevel
(
logging
.
INFO
)
if
not
logger
.
handlers
:
formatter
=
logging
.
Formatter
(
"[Router (Python)] %(asctime)s - %(levelname)s - %(message)s"
,
datefmt
=
"%Y-%m-%d %H:%M:%S"
,
)
handler
=
logging
.
StreamHandler
()
handler
.
setFormatter
(
formatter
)
logger
.
addHandler
(
handler
)
return
logger
from
sglang_router_rs
import
PolicyType
class
TestSetupLogger
:
"""Test logger setup functionality."""
def
test_setup_logger_returns_logger
(
self
):
"""Test that setup_logger returns a logger instance."""
logger
=
setup_logger
()
assert
isinstance
(
logger
,
logging
.
Logger
)
assert
logger
.
name
==
"router"
assert
logger
.
level
==
logging
.
INFO
def
test_setup_logger_has_handler
(
self
):
"""Test that setup_logger configures a handler."""
logger
=
setup_logger
()
assert
len
(
logger
.
handlers
)
>
0
handler
=
logger
.
handlers
[
0
]
assert
isinstance
(
handler
,
logging
.
StreamHandler
)
def
test_setup_logger_has_formatter
(
self
):
"""Test that setup_logger configures a formatter."""
logger
=
setup_logger
()
handler
=
logger
.
handlers
[
0
]
formatter
=
handler
.
formatter
assert
formatter
is
not
None
assert
"[Router (Python)]"
in
formatter
.
_fmt
def
test_setup_logger_multiple_calls
(
self
):
"""Test that multiple calls to setup_logger work correctly."""
logger1
=
setup_logger
()
logger2
=
setup_logger
()
# Should return the same logger instance
assert
logger1
is
logger2
class
TestPolicyFromStr
:
"""Test policy string to enum conversion in startup context."""
def
test_policy_conversion_in_startup
(
self
):
"""Test policy conversion during startup sequence."""
# Test all valid policies
policies
=
[
"random"
,
"round_robin"
,
"cache_aware"
,
"power_of_two"
]
expected_enums
=
[
PolicyType
.
Random
,
PolicyType
.
RoundRobin
,
PolicyType
.
CacheAware
,
PolicyType
.
PowerOfTwo
,
]
for
policy_str
,
expected_enum
in
zip
(
policies
,
expected_enums
):
result
=
policy_from_str
(
policy_str
)
assert
result
==
expected_enum
def
test_invalid_policy_in_startup
(
self
):
"""Test handling of invalid policy during startup."""
with
pytest
.
raises
(
KeyError
):
policy_from_str
(
"invalid_policy"
)
class
TestRouterInitialization
:
"""Test router initialization logic."""
def
test_router_initialization_basic
(
self
):
"""Test basic router initialization."""
args
=
RouterArgs
(
host
=
"127.0.0.1"
,
port
=
30000
,
worker_urls
=
[
"http://worker1:8000"
],
policy
=
"cache_aware"
,
)
with
patch
(
"sglang_router.launch_router.Router"
)
as
router_mod
:
captured_args
=
{}
mock_router_instance
=
MagicMock
()
def
fake_from_args
(
router_args
):
# capture needed fields from RouterArgs
captured_args
.
update
(
dict
(
host
=
router_args
.
host
,
port
=
router_args
.
port
,
worker_urls
=
router_args
.
worker_urls
,
policy
=
policy_from_str
(
router_args
.
policy
),
)
)
return
mock_router_instance
router_mod
.
from_args
=
MagicMock
(
side_effect
=
fake_from_args
)
result
=
launch_router
(
args
)
# Verify Router.from_args was called and captured fields match
router_mod
.
from_args
.
assert_called_once
()
assert
captured_args
[
"host"
]
==
"127.0.0.1"
assert
captured_args
[
"port"
]
==
30000
assert
captured_args
[
"worker_urls"
]
==
[
"http://worker1:8000"
]
assert
captured_args
[
"policy"
]
==
PolicyType
.
CacheAware
# Verify router.start() was called
mock_router_instance
.
start
.
assert_called_once
()
# Function returns None; ensure start was invoked
def
test_router_initialization_pd_mode
(
self
):
"""Test router initialization in PD mode."""
args
=
RouterArgs
(
pd_disaggregation
=
True
,
prefill_urls
=
[(
"http://prefill1:8000"
,
9000
)],
decode_urls
=
[
"http://decode1:8001"
],
policy
=
"power_of_two"
,
)
with
patch
(
"sglang_router.launch_router.Router"
)
as
router_mod
:
captured_args
=
{}
mock_router_instance
=
MagicMock
()
def
fake_from_args
(
router_args
):
captured_args
.
update
(
dict
(
pd_disaggregation
=
router_args
.
pd_disaggregation
,
prefill_urls
=
router_args
.
prefill_urls
,
decode_urls
=
router_args
.
decode_urls
,
policy
=
policy_from_str
(
router_args
.
policy
),
)
)
return
mock_router_instance
router_mod
.
from_args
=
MagicMock
(
side_effect
=
fake_from_args
)
result
=
launch_router
(
args
)
# Verify Router.from_args was called with PD parameters
router_mod
.
from_args
.
assert_called_once
()
assert
captured_args
[
"pd_disaggregation"
]
is
True
assert
captured_args
[
"prefill_urls"
]
==
[(
"http://prefill1:8000"
,
9000
)]
assert
captured_args
[
"decode_urls"
]
==
[
"http://decode1:8001"
]
assert
captured_args
[
"policy"
]
==
PolicyType
.
PowerOfTwo
# Verify router.start() was called
mock_router_instance
.
start
.
assert_called_once
()
# Function returns None; ensure start was invoked
def
test_router_initialization_with_service_discovery
(
self
):
"""Test router initialization with service discovery."""
args
=
RouterArgs
(
service_discovery
=
True
,
selector
=
{
"app"
:
"worker"
,
"env"
:
"prod"
},
service_discovery_port
=
8080
,
service_discovery_namespace
=
"default"
,
)
with
patch
(
"sglang_router.launch_router.Router"
)
as
router_mod
:
captured_args
=
{}
mock_router_instance
=
MagicMock
()
def
fake_from_args
(
router_args
):
captured_args
.
update
(
dict
(
service_discovery
=
router_args
.
service_discovery
,
selector
=
router_args
.
selector
,
service_discovery_port
=
router_args
.
service_discovery_port
,
service_discovery_namespace
=
router_args
.
service_discovery_namespace
,
)
)
return
mock_router_instance
router_mod
.
from_args
=
MagicMock
(
side_effect
=
fake_from_args
)
result
=
launch_router
(
args
)
# Verify Router.from_args was called with service discovery parameters
router_mod
.
from_args
.
assert_called_once
()
assert
captured_args
[
"service_discovery"
]
is
True
assert
captured_args
[
"selector"
]
==
{
"app"
:
"worker"
,
"env"
:
"prod"
}
assert
captured_args
[
"service_discovery_port"
]
==
8080
assert
captured_args
[
"service_discovery_namespace"
]
==
"default"
# Verify router.start() was called
mock_router_instance
.
start
.
assert_called_once
()
# Function returns None; ensure start was invoked
def
test_router_initialization_with_retry_config
(
self
):
"""Test router initialization with retry configuration."""
args
=
RouterArgs
(
retry_max_retries
=
3
,
retry_initial_backoff_ms
=
100
,
retry_max_backoff_ms
=
10000
,
retry_backoff_multiplier
=
2.0
,
retry_jitter_factor
=
0.1
,
disable_retries
=
False
,
)
with
patch
(
"sglang_router.launch_router.Router"
)
as
router_mod
:
captured_args
=
{}
mock_router_instance
=
MagicMock
()
def
fake_from_args
(
router_args
):
captured_args
.
update
(
dict
(
retry_max_retries
=
router_args
.
retry_max_retries
,
retry_initial_backoff_ms
=
router_args
.
retry_initial_backoff_ms
,
retry_max_backoff_ms
=
router_args
.
retry_max_backoff_ms
,
retry_backoff_multiplier
=
router_args
.
retry_backoff_multiplier
,
retry_jitter_factor
=
router_args
.
retry_jitter_factor
,
disable_retries
=
router_args
.
disable_retries
,
)
)
return
mock_router_instance
router_mod
.
from_args
=
MagicMock
(
side_effect
=
fake_from_args
)
result
=
launch_router
(
args
)
# Verify router was created with retry parameters
router_mod
.
from_args
.
assert_called_once
()
assert
captured_args
[
"retry_max_retries"
]
==
3
assert
captured_args
[
"retry_initial_backoff_ms"
]
==
100
assert
captured_args
[
"retry_max_backoff_ms"
]
==
10000
assert
captured_args
[
"retry_backoff_multiplier"
]
==
2.0
assert
captured_args
[
"retry_jitter_factor"
]
==
0.1
assert
captured_args
[
"disable_retries"
]
is
False
# Verify router.start() was called
mock_router_instance
.
start
.
assert_called_once
()
# Function returns None; ensure start was invoked
def
test_router_initialization_with_circuit_breaker_config
(
self
):
"""Test router initialization with circuit breaker configuration."""
args
=
RouterArgs
(
cb_failure_threshold
=
5
,
cb_success_threshold
=
2
,
cb_timeout_duration_secs
=
30
,
cb_window_duration_secs
=
60
,
disable_circuit_breaker
=
False
,
)
with
patch
(
"sglang_router.launch_router.Router"
)
as
router_mod
:
captured_args
=
{}
mock_router_instance
=
MagicMock
()
def
fake_from_args
(
router_args
):
captured_args
.
update
(
dict
(
cb_failure_threshold
=
router_args
.
cb_failure_threshold
,
cb_success_threshold
=
router_args
.
cb_success_threshold
,
cb_timeout_duration_secs
=
router_args
.
cb_timeout_duration_secs
,
cb_window_duration_secs
=
router_args
.
cb_window_duration_secs
,
disable_circuit_breaker
=
router_args
.
disable_circuit_breaker
,
)
)
return
mock_router_instance
router_mod
.
from_args
=
MagicMock
(
side_effect
=
fake_from_args
)
result
=
launch_router
(
args
)
# Verify router was created with circuit breaker parameters
router_mod
.
from_args
.
assert_called_once
()
assert
captured_args
[
"cb_failure_threshold"
]
==
5
assert
captured_args
[
"cb_success_threshold"
]
==
2
assert
captured_args
[
"cb_timeout_duration_secs"
]
==
30
assert
captured_args
[
"cb_window_duration_secs"
]
==
60
assert
captured_args
[
"disable_circuit_breaker"
]
is
False
# Verify router.start() was called
mock_router_instance
.
start
.
assert_called_once
()
# Function returns None; ensure start was invoked
def
test_router_initialization_with_rate_limiting_config
(
self
):
"""Test router initialization with rate limiting configuration."""
args
=
RouterArgs
(
max_concurrent_requests
=
512
,
queue_size
=
200
,
queue_timeout_secs
=
120
,
rate_limit_tokens_per_second
=
100
,
)
with
patch
(
"sglang_router.launch_router.Router"
)
as
router_mod
:
captured_args
=
{}
mock_router_instance
=
MagicMock
()
def
fake_from_args
(
router_args
):
captured_args
.
update
(
dict
(
max_concurrent_requests
=
router_args
.
max_concurrent_requests
,
queue_size
=
router_args
.
queue_size
,
queue_timeout_secs
=
router_args
.
queue_timeout_secs
,
rate_limit_tokens_per_second
=
router_args
.
rate_limit_tokens_per_second
,
)
)
return
mock_router_instance
router_mod
.
from_args
=
MagicMock
(
side_effect
=
fake_from_args
)
result
=
launch_router
(
args
)
# Verify router was created with rate limiting parameters
router_mod
.
from_args
.
assert_called_once
()
assert
captured_args
[
"max_concurrent_requests"
]
==
512
assert
captured_args
[
"queue_size"
]
==
200
assert
captured_args
[
"queue_timeout_secs"
]
==
120
assert
captured_args
[
"rate_limit_tokens_per_second"
]
==
100
# Verify router.start() was called
mock_router_instance
.
start
.
assert_called_once
()
# Function returns None; ensure start was invoked
def
test_router_initialization_with_health_check_config
(
self
):
"""Test router initialization with health check configuration."""
args
=
RouterArgs
(
health_failure_threshold
=
2
,
health_success_threshold
=
1
,
health_check_timeout_secs
=
3
,
health_check_interval_secs
=
30
,
health_check_endpoint
=
"/healthz"
,
)
with
patch
(
"sglang_router.launch_router.Router"
)
as
router_mod
:
captured_args
=
{}
mock_router_instance
=
MagicMock
()
def
fake_from_args
(
router_args
):
captured_args
.
update
(
dict
(
health_failure_threshold
=
router_args
.
health_failure_threshold
,
health_success_threshold
=
router_args
.
health_success_threshold
,
health_check_timeout_secs
=
router_args
.
health_check_timeout_secs
,
health_check_interval_secs
=
router_args
.
health_check_interval_secs
,
health_check_endpoint
=
router_args
.
health_check_endpoint
,
)
)
return
mock_router_instance
router_mod
.
from_args
=
MagicMock
(
side_effect
=
fake_from_args
)
result
=
launch_router
(
args
)
# Verify router was created with health check parameters
router_mod
.
from_args
.
assert_called_once
()
assert
captured_args
[
"health_failure_threshold"
]
==
2
assert
captured_args
[
"health_success_threshold"
]
==
1
assert
captured_args
[
"health_check_timeout_secs"
]
==
3
assert
captured_args
[
"health_check_interval_secs"
]
==
30
assert
captured_args
[
"health_check_endpoint"
]
==
"/healthz"
# Verify router.start() was called
mock_router_instance
.
start
.
assert_called_once
()
# Function returns None; ensure start was invoked
def
test_router_initialization_with_prometheus_config
(
self
):
"""Test router initialization with Prometheus configuration."""
args
=
RouterArgs
(
prometheus_port
=
29000
,
prometheus_host
=
"127.0.0.1"
)
with
patch
(
"sglang_router.launch_router.Router"
)
as
router_mod
:
captured_args
=
{}
mock_router_instance
=
MagicMock
()
def
fake_from_args
(
router_args
):
captured_args
.
update
(
dict
(
prometheus_port
=
router_args
.
prometheus_port
,
prometheus_host
=
router_args
.
prometheus_host
,
)
)
return
mock_router_instance
router_mod
.
from_args
=
MagicMock
(
side_effect
=
fake_from_args
)
result
=
launch_router
(
args
)
# Verify router was created with Prometheus parameters
router_mod
.
from_args
.
assert_called_once
()
assert
captured_args
[
"prometheus_port"
]
==
29000
assert
captured_args
[
"prometheus_host"
]
==
"127.0.0.1"
# Verify router.start() was called
mock_router_instance
.
start
.
assert_called_once
()
# Function returns None; ensure start was invoked
def
test_router_initialization_with_cors_config
(
self
):
"""Test router initialization with CORS configuration."""
args
=
RouterArgs
(
cors_allowed_origins
=
[
"http://localhost:3000"
,
"https://example.com"
]
)
with
patch
(
"sglang_router.launch_router.Router"
)
as
router_mod
:
captured_args
=
{}
mock_router_instance
=
MagicMock
()
def
fake_from_args
(
router_args
):
captured_args
.
update
(
dict
(
cors_allowed_origins
=
router_args
.
cors_allowed_origins
)
)
return
mock_router_instance
router_mod
.
from_args
=
MagicMock
(
side_effect
=
fake_from_args
)
result
=
launch_router
(
args
)
# Verify router was created with CORS parameters
router_mod
.
from_args
.
assert_called_once
()
assert
captured_args
[
"cors_allowed_origins"
]
==
[
"http://localhost:3000"
,
"https://example.com"
,
]
# Verify router.start() was called
mock_router_instance
.
start
.
assert_called_once
()
# Function returns None; ensure start was invoked
def
test_router_initialization_with_tokenizer_config
(
self
):
"""Test router initialization with tokenizer configuration."""
# Note: model_path and tokenizer_path are not available in current RouterArgs
pytest
.
skip
(
"Tokenizer configuration not available in current implementation"
)
class
TestStartupValidation
:
"""Test startup validation logic."""
def
test_pd_mode_validation_during_startup
(
self
):
"""Test PD mode validation during startup."""
# PD mode without URLs should fail
args
=
RouterArgs
(
pd_disaggregation
=
True
,
prefill_urls
=
[],
decode_urls
=
[],
service_discovery
=
False
,
)
with
pytest
.
raises
(
ValueError
,
match
=
"PD disaggregation mode requires --prefill"
):
launch_router
(
args
)
def
test_pd_mode_with_service_discovery_validation
(
self
):
"""Test PD mode with service discovery validation during startup."""
args
=
RouterArgs
(
pd_disaggregation
=
True
,
prefill_urls
=
[],
decode_urls
=
[],
service_discovery
=
True
,
)
# Should not raise validation error
with
patch
(
"sglang_router.launch_router.Router"
)
as
router_mod
:
mock_router_instance
=
MagicMock
()
router_mod
.
from_args
=
MagicMock
(
return_value
=
mock_router_instance
)
result
=
launch_router
(
args
)
# Should create router instance
router_mod
.
from_args
.
assert_called_once
()
def
test_policy_warning_during_startup
(
self
):
"""Test policy warning during startup in PD mode."""
args
=
RouterArgs
(
pd_disaggregation
=
True
,
prefill_urls
=
[(
"http://prefill1:8000"
,
None
)],
decode_urls
=
[
"http://decode1:8001"
],
policy
=
"cache_aware"
,
prefill_policy
=
"power_of_two"
,
decode_policy
=
"round_robin"
,
)
with
patch
(
"sglang_router.launch_router.Router"
)
as
router_mod
:
mock_router_instance
=
MagicMock
()
router_mod
.
from_args
=
MagicMock
(
return_value
=
mock_router_instance
)
# The policy messages are emitted by router_args logger
with
patch
(
"sglang_router.router_args.logger"
)
as
mock_logger
:
result
=
launch_router
(
args
)
# Should log warning about policy usage
mock_logger
.
warning
.
assert_called_once
()
warning_call
=
mock_logger
.
warning
.
call_args
[
0
][
0
]
assert
(
"Both --prefill-policy and --decode-policy are specified"
in
warning_call
)
# Should create router instance
router_mod
.
from_args
.
assert_called_once
()
def
test_policy_info_during_startup
(
self
):
"""Test policy info logging during startup in PD mode."""
# Test with only prefill policy specified
args
=
RouterArgs
(
pd_disaggregation
=
True
,
prefill_urls
=
[(
"http://prefill1:8000"
,
None
)],
decode_urls
=
[
"http://decode1:8001"
],
policy
=
"cache_aware"
,
prefill_policy
=
"power_of_two"
,
decode_policy
=
None
,
)
with
patch
(
"sglang_router.launch_router.Router"
)
as
router_mod
:
mock_router_instance
=
MagicMock
()
router_mod
.
from_args
=
MagicMock
(
return_value
=
mock_router_instance
)
# The policy messages are emitted by router_args logger
with
patch
(
"sglang_router.router_args.logger"
)
as
mock_logger
:
result
=
launch_router
(
args
)
# Should log info about policy usage
mock_logger
.
info
.
assert_called_once
()
info_call
=
mock_logger
.
info
.
call_args
[
0
][
0
]
assert
"Using --prefill-policy 'power_of_two'"
in
info_call
assert
"and --policy 'cache_aware'"
in
info_call
# Should create router instance
router_mod
.
from_args
.
assert_called_once
()
def
test_policy_info_decode_only_during_startup
(
self
):
"""Test policy info logging during startup with only decode policy specified."""
args
=
RouterArgs
(
pd_disaggregation
=
True
,
prefill_urls
=
[(
"http://prefill1:8000"
,
None
)],
decode_urls
=
[
"http://decode1:8001"
],
policy
=
"cache_aware"
,
prefill_policy
=
None
,
decode_policy
=
"round_robin"
,
)
with
patch
(
"sglang_router.launch_router.Router"
)
as
router_mod
:
mock_router_instance
=
MagicMock
()
router_mod
.
from_args
=
MagicMock
(
return_value
=
mock_router_instance
)
# The policy messages are emitted by router_args logger
with
patch
(
"sglang_router.router_args.logger"
)
as
mock_logger
:
result
=
launch_router
(
args
)
# Should log info about policy usage
mock_logger
.
info
.
assert_called_once
()
info_call
=
mock_logger
.
info
.
call_args
[
0
][
0
]
assert
"Using --policy 'cache_aware'"
in
info_call
assert
"and --decode-policy 'round_robin'"
in
info_call
# Should create router instance
router_mod
.
from_args
.
assert_called_once
()
class
TestStartupErrorHandling
:
"""Test startup error handling logic."""
def
test_router_creation_error_handling
(
self
):
"""Test error handling when router creation fails."""
args
=
RouterArgs
(
host
=
"127.0.0.1"
,
port
=
30000
,
worker_urls
=
[
"http://worker1:8000"
]
)
with
patch
(
"sglang_router.launch_router.Router"
)
as
router_mod
:
# Simulate router creation failure in from_args
router_mod
.
from_args
=
MagicMock
(
side_effect
=
Exception
(
"Router creation failed"
)
)
with
patch
(
"sglang_router.launch_router.logger"
)
as
mock_logger
:
with
pytest
.
raises
(
Exception
,
match
=
"Router creation failed"
):
launch_router
(
args
)
# Should log error
mock_logger
.
error
.
assert_called_once
()
error_call
=
mock_logger
.
error
.
call_args
[
0
][
0
]
assert
"Error starting router: Router creation failed"
in
error_call
def
test_router_start_error_handling
(
self
):
"""Test error handling when router start fails."""
args
=
RouterArgs
(
host
=
"127.0.0.1"
,
port
=
30000
,
worker_urls
=
[
"http://worker1:8000"
]
)
with
patch
(
"sglang_router.launch_router.Router"
)
as
router_mod
:
mock_router_instance
=
MagicMock
()
router_mod
.
from_args
=
MagicMock
(
return_value
=
mock_router_instance
)
# Simulate router start failure
mock_router_instance
.
start
.
side_effect
=
Exception
(
"Router start failed"
)
with
patch
(
"sglang_router.launch_router.logger"
)
as
mock_logger
:
with
pytest
.
raises
(
Exception
,
match
=
"Router start failed"
):
launch_router
(
args
)
# Should log error
mock_logger
.
error
.
assert_called_once
()
error_call
=
mock_logger
.
error
.
call_args
[
0
][
0
]
assert
"Error starting router: Router start failed"
in
error_call
# --- Added unit tests for Router wrapper and launch_server helpers ---
def
_install_sglang_stubs
(
monkeypatch
):
"""Install lightweight stubs for sglang.srt to avoid heavy deps during unit tests."""
import
sys
import
types
sglang_mod
=
types
.
ModuleType
(
"sglang"
)
srt_mod
=
types
.
ModuleType
(
"sglang.srt"
)
entry_mod
=
types
.
ModuleType
(
"sglang.srt.entrypoints"
)
http_server_mod
=
types
.
ModuleType
(
"sglang.srt.entrypoints.http_server"
)
server_args_mod
=
types
.
ModuleType
(
"sglang.srt.server_args"
)
utils_mod
=
types
.
ModuleType
(
"sglang.srt.utils"
)
def
launch_server
(
_args
):
return
None
class
ServerArgs
:
# Minimal fields used by launch_server_process
def
__init__
(
self
):
self
.
port
=
0
self
.
base_gpu_id
=
0
self
.
dp_size
=
1
self
.
tp_size
=
1
@
staticmethod
def
add_cli_args
(
_parser
):
return
None
@
staticmethod
def
from_cli_args
(
_args
):
sa
=
ServerArgs
()
if
hasattr
(
_args
,
"dp_size"
):
sa
.
dp_size
=
_args
.
dp_size
if
hasattr
(
_args
,
"tp_size"
):
sa
.
tp_size
=
_args
.
tp_size
if
hasattr
(
_args
,
"host"
):
sa
.
host
=
_args
.
host
else
:
sa
.
host
=
"127.0.0.1"
return
sa
def
is_port_available
(
_port
:
int
)
->
bool
:
return
True
http_server_mod
.
launch_server
=
launch_server
server_args_mod
.
ServerArgs
=
ServerArgs
utils_mod
.
is_port_available
=
is_port_available
# Also stub external deps imported at module top-level
def
_dummy_get
(
*
_a
,
**
_k
):
raise
NotImplementedError
requests_stub
=
types
.
SimpleNamespace
(
exceptions
=
types
.
SimpleNamespace
(
RequestException
=
Exception
),
get
=
_dummy_get
)
setproctitle_stub
=
types
.
SimpleNamespace
(
setproctitle
=
lambda
*
_a
,
**
_k
:
None
)
monkeypatch
.
setitem
(
sys
.
modules
,
"requests"
,
requests_stub
)
monkeypatch
.
setitem
(
sys
.
modules
,
"setproctitle"
,
setproctitle_stub
)
monkeypatch
.
setitem
(
sys
.
modules
,
"sglang"
,
sglang_mod
)
monkeypatch
.
setitem
(
sys
.
modules
,
"sglang.srt"
,
srt_mod
)
monkeypatch
.
setitem
(
sys
.
modules
,
"sglang.srt.entrypoints"
,
entry_mod
)
monkeypatch
.
setitem
(
sys
.
modules
,
"sglang.srt.entrypoints.http_server"
,
http_server_mod
)
monkeypatch
.
setitem
(
sys
.
modules
,
"sglang.srt.server_args"
,
server_args_mod
)
monkeypatch
.
setitem
(
sys
.
modules
,
"sglang.srt.utils"
,
utils_mod
)
def
test_router_defaults_and_start
(
monkeypatch
):
"""Router wrapper: defaults normalization and start() call.
Mocks the Rust-backed _Router to avoid native deps.
"""
from
sglang_router
import
router
as
router_mod
captured
=
{}
class
FakeRouter
:
def
__init__
(
self
,
**
kwargs
):
captured
.
update
(
kwargs
)
def
start
(
self
):
captured
[
"started"
]
=
True
monkeypatch
.
setattr
(
router_mod
,
"_Router"
,
FakeRouter
,
raising
=
True
)
from
sglang_router.router_args
import
RouterArgs
as
_RouterArgs
Router
=
router_mod
.
Router
args
=
_RouterArgs
(
worker_urls
=
[
"http://w1:8000"
],
policy
=
"round_robin"
,
selector
=
None
,
prefill_selector
=
None
,
decode_selector
=
None
,
cors_allowed_origins
=
None
,
)
r
=
Router
.
from_args
(
args
)
# Defaults preserved/normalized by Router.from_args
assert
captured
[
"selector"
]
is
None
assert
captured
[
"prefill_selector"
]
is
None
assert
captured
[
"decode_selector"
]
is
None
assert
captured
[
"cors_allowed_origins"
]
is
None
assert
captured
[
"worker_urls"
]
==
[
"http://w1:8000"
]
from
sglang_router_rs
import
PolicyType
assert
captured
[
"policy"
]
==
PolicyType
.
RoundRobin
r
.
start
()
assert
captured
.
get
(
"started"
)
is
True
def
test_find_available_ports_and_wait_health
(
monkeypatch
):
"""launch_server helpers: port finding and health waiting with transient error."""
_install_sglang_stubs
(
monkeypatch
)
import
importlib
ls
=
importlib
.
import_module
(
"sglang_router.launch_server"
)
# Deterministic increments
monkeypatch
.
setattr
(
ls
.
random
,
"randint"
,
lambda
a
,
b
:
100
)
ports
=
ls
.
find_available_ports
(
30000
,
3
)
assert
ports
==
[
30000
,
30100
,
30200
]
calls
=
{
"n"
:
0
}
class
Ok
:
status_code
=
200
def
fake_get
(
_url
,
timeout
=
5
):
calls
[
"n"
]
+=
1
if
calls
[
"n"
]
==
1
:
raise
ls
.
requests
.
exceptions
.
RequestException
(
"boom"
)
return
Ok
()
monkeypatch
.
setattr
(
ls
.
requests
,
"get"
,
fake_get
)
monkeypatch
.
setattr
(
ls
.
time
,
"sleep"
,
lambda
_s
:
None
)
base
=
{
"t"
:
0.0
}
monkeypatch
.
setattr
(
ls
.
time
,
"perf_counter"
,
lambda
:
(
base
.
__setitem__
(
"t"
,
base
[
"t"
]
+
0.1
)
or
base
[
"t"
]),
)
assert
ls
.
wait_for_server_health
(
"127.0.0.1"
,
12345
,
timeout
=
1
)
def
test_launch_server_process_and_cleanup
(
monkeypatch
):
"""launch_server: process creation args and cleanup SIGTERM/SIGKILL logic."""
_install_sglang_stubs
(
monkeypatch
)
import
importlib
ls
=
importlib
.
import_module
(
"sglang_router.launch_server"
)
created
=
{}
class
FakeProcess
:
def
__init__
(
self
,
target
,
args
):
created
[
"target"
]
=
target
created
[
"args"
]
=
args
self
.
pid
=
4242
self
.
_alive
=
True
def
start
(
self
):
created
[
"started"
]
=
True
def
join
(
self
,
timeout
=
None
):
return
None
def
is_alive
(
self
):
return
self
.
_alive
monkeypatch
.
setattr
(
ls
.
mp
,
"Process"
,
FakeProcess
)
import
sys
as
_sys
SA
=
_sys
.
modules
[
"sglang.srt.server_args"
].
ServerArgs
sa
=
SA
()
sa
.
tp_size
=
2
proc
=
ls
.
launch_server_process
(
sa
,
worker_port
=
31001
,
dp_id
=
3
)
assert
created
.
get
(
"started"
)
is
True
targ
,
targ_args
=
created
[
"target"
],
created
[
"args"
]
assert
targ
is
ls
.
run_server
passed_sa
=
targ_args
[
0
]
assert
passed_sa
.
port
==
31001
assert
passed_sa
.
base_gpu_id
==
3
*
2
assert
passed_sa
.
dp_size
==
1
# cleanup_processes
p1
=
FakeProcess
(
target
=
None
,
args
=
())
p1
.
_alive
=
False
p2
=
FakeProcess
(
target
=
None
,
args
=
())
p2
.
_alive
=
True
calls
=
[]
def
fake_killpg
(
pid
,
sig
):
calls
.
append
((
pid
,
sig
))
monkeypatch
.
setattr
(
ls
.
os
,
"killpg"
,
fake_killpg
)
ls
.
cleanup_processes
([
p1
,
p2
])
import
signal
as
_sig
assert
(
p1
.
pid
,
_sig
.
SIGTERM
)
in
calls
and
(
p2
.
pid
,
_sig
.
SIGTERM
)
in
calls
assert
(
p2
.
pid
,
_sig
.
SIGKILL
)
in
calls
def
test_validation_error_handling
(
self
):
"""Test error handling when validation fails."""
args
=
RouterArgs
(
pd_disaggregation
=
True
,
prefill_urls
=
[],
decode_urls
=
[],
service_discovery
=
False
,
)
with
patch
(
"sglang_router.launch_router.logger"
)
as
mock_logger
:
with
pytest
.
raises
(
ValueError
,
match
=
"PD disaggregation mode requires --prefill"
):
launch_router
(
args
)
# Should log error for validation failures
mock_logger
.
error
.
assert_called_once
()
class
TestStartupFlow
:
"""Test complete startup flow."""
def
test_complete_startup_flow_basic
(
self
):
"""Test complete startup flow for basic configuration."""
args
=
RouterArgs
(
host
=
"127.0.0.1"
,
port
=
30000
,
worker_urls
=
[
"http://worker1:8000"
,
"http://worker2:8000"
],
policy
=
"cache_aware"
,
cache_threshold
=
0.5
,
balance_abs_threshold
=
32
,
balance_rel_threshold
=
1.5
,
)
with
patch
(
"sglang_router.launch_router.Router"
)
as
router_mod
:
mock_router_instance
=
MagicMock
()
router_mod
.
from_args
=
MagicMock
(
return_value
=
mock_router_instance
)
result
=
launch_router
(
args
)
# Verify complete flow
router_mod
.
from_args
.
assert_called_once
()
mock_router_instance
.
start
.
assert_called_once
()
def
test_complete_startup_flow_pd_mode
(
self
):
"""Test complete startup flow for PD mode configuration."""
args
=
RouterArgs
(
pd_disaggregation
=
True
,
prefill_urls
=
[
(
"http://prefill1:8000"
,
9000
),
(
"http://prefill2:8000"
,
None
),
],
decode_urls
=
[
"http://decode1:8001"
,
"http://decode2:8001"
],
policy
=
"power_of_two"
,
prefill_policy
=
"cache_aware"
,
decode_policy
=
"round_robin"
,
)
with
patch
(
"sglang_router.launch_router.Router"
)
as
router_mod
:
mock_router_instance
=
MagicMock
()
router_mod
.
from_args
=
MagicMock
(
return_value
=
mock_router_instance
)
with
patch
(
"sglang_router.router_args.logger"
)
as
mock_logger
:
result
=
launch_router
(
args
)
# Verify complete flow
router_mod
.
from_args
.
assert_called_once
()
mock_router_instance
.
start
.
assert_called_once
()
# Verify policy warning was logged
mock_logger
.
warning
.
assert_called_once
()
def
test_complete_startup_flow_with_all_features
(
self
):
"""Test complete startup flow with all features enabled."""
args
=
RouterArgs
(
host
=
"0.0.0.0"
,
port
=
30001
,
worker_urls
=
[
"http://worker1:8000"
],
policy
=
"round_robin"
,
service_discovery
=
True
,
selector
=
{
"app"
:
"worker"
},
service_discovery_port
=
8080
,
service_discovery_namespace
=
"default"
,
dp_aware
=
True
,
api_key
=
"test-key"
,
log_dir
=
"/tmp/logs"
,
log_level
=
"debug"
,
prometheus_port
=
29000
,
prometheus_host
=
"0.0.0.0"
,
request_id_headers
=
[
"x-request-id"
,
"x-trace-id"
],
request_timeout_secs
=
1200
,
max_concurrent_requests
=
512
,
queue_size
=
200
,
queue_timeout_secs
=
120
,
rate_limit_tokens_per_second
=
100
,
cors_allowed_origins
=
[
"http://localhost:3000"
],
retry_max_retries
=
3
,
retry_initial_backoff_ms
=
100
,
retry_max_backoff_ms
=
10000
,
retry_backoff_multiplier
=
2.0
,
retry_jitter_factor
=
0.1
,
cb_failure_threshold
=
5
,
cb_success_threshold
=
2
,
cb_timeout_duration_secs
=
30
,
cb_window_duration_secs
=
60
,
health_failure_threshold
=
2
,
health_success_threshold
=
1
,
health_check_timeout_secs
=
3
,
health_check_interval_secs
=
30
,
health_check_endpoint
=
"/healthz"
,
)
with
patch
(
"sglang_router.launch_router.Router"
)
as
router_mod
:
captured_args
=
{}
mock_router_instance
=
MagicMock
()
def
fake_from_args
(
router_args
):
captured_args
.
update
(
dict
(
host
=
router_args
.
host
,
port
=
router_args
.
port
,
worker_urls
=
router_args
.
worker_urls
,
policy
=
policy_from_str
(
router_args
.
policy
),
service_discovery
=
router_args
.
service_discovery
,
selector
=
router_args
.
selector
,
service_discovery_port
=
router_args
.
service_discovery_port
,
service_discovery_namespace
=
router_args
.
service_discovery_namespace
,
dp_aware
=
router_args
.
dp_aware
,
api_key
=
router_args
.
api_key
,
log_dir
=
router_args
.
log_dir
,
log_level
=
router_args
.
log_level
,
prometheus_port
=
router_args
.
prometheus_port
,
prometheus_host
=
router_args
.
prometheus_host
,
request_id_headers
=
router_args
.
request_id_headers
,
request_timeout_secs
=
router_args
.
request_timeout_secs
,
max_concurrent_requests
=
router_args
.
max_concurrent_requests
,
queue_size
=
router_args
.
queue_size
,
queue_timeout_secs
=
router_args
.
queue_timeout_secs
,
rate_limit_tokens_per_second
=
router_args
.
rate_limit_tokens_per_second
,
cors_allowed_origins
=
router_args
.
cors_allowed_origins
,
retry_max_retries
=
router_args
.
retry_max_retries
,
retry_initial_backoff_ms
=
router_args
.
retry_initial_backoff_ms
,
retry_max_backoff_ms
=
router_args
.
retry_max_backoff_ms
,
retry_backoff_multiplier
=
router_args
.
retry_backoff_multiplier
,
retry_jitter_factor
=
router_args
.
retry_jitter_factor
,
cb_failure_threshold
=
router_args
.
cb_failure_threshold
,
cb_success_threshold
=
router_args
.
cb_success_threshold
,
cb_timeout_duration_secs
=
router_args
.
cb_timeout_duration_secs
,
cb_window_duration_secs
=
router_args
.
cb_window_duration_secs
,
health_failure_threshold
=
router_args
.
health_failure_threshold
,
health_success_threshold
=
router_args
.
health_success_threshold
,
health_check_timeout_secs
=
router_args
.
health_check_timeout_secs
,
health_check_interval_secs
=
router_args
.
health_check_interval_secs
,
health_check_endpoint
=
router_args
.
health_check_endpoint
,
)
)
return
mock_router_instance
router_mod
.
from_args
=
MagicMock
(
side_effect
=
fake_from_args
)
result
=
launch_router
(
args
)
# Verify complete flow
router_mod
.
from_args
.
assert_called_once
()
mock_router_instance
.
start
.
assert_called_once
()
# Verify key parameters were propagated into RouterArgs
assert
captured_args
[
"host"
]
==
"0.0.0.0"
assert
captured_args
[
"port"
]
==
30001
assert
captured_args
[
"worker_urls"
]
==
[
"http://worker1:8000"
]
assert
captured_args
[
"policy"
]
==
PolicyType
.
RoundRobin
assert
captured_args
[
"service_discovery"
]
is
True
assert
captured_args
[
"selector"
]
==
{
"app"
:
"worker"
}
assert
captured_args
[
"service_discovery_port"
]
==
8080
assert
captured_args
[
"service_discovery_namespace"
]
==
"default"
assert
captured_args
[
"dp_aware"
]
is
True
assert
captured_args
[
"api_key"
]
==
"test-key"
assert
captured_args
[
"log_dir"
]
==
"/tmp/logs"
assert
captured_args
[
"log_level"
]
==
"debug"
assert
captured_args
[
"prometheus_port"
]
==
29000
assert
captured_args
[
"prometheus_host"
]
==
"0.0.0.0"
assert
captured_args
[
"request_id_headers"
]
==
[
"x-request-id"
,
"x-trace-id"
]
assert
captured_args
[
"request_timeout_secs"
]
==
1200
assert
captured_args
[
"max_concurrent_requests"
]
==
512
assert
captured_args
[
"queue_size"
]
==
200
assert
captured_args
[
"queue_timeout_secs"
]
==
120
assert
captured_args
[
"rate_limit_tokens_per_second"
]
==
100
assert
captured_args
[
"cors_allowed_origins"
]
==
[
"http://localhost:3000"
]
assert
captured_args
[
"retry_max_retries"
]
==
3
assert
captured_args
[
"retry_initial_backoff_ms"
]
==
100
assert
captured_args
[
"retry_max_backoff_ms"
]
==
10000
assert
captured_args
[
"retry_backoff_multiplier"
]
==
2.0
assert
captured_args
[
"retry_jitter_factor"
]
==
0.1
assert
captured_args
[
"cb_failure_threshold"
]
==
5
assert
captured_args
[
"cb_success_threshold"
]
==
2
assert
captured_args
[
"cb_timeout_duration_secs"
]
==
30
assert
captured_args
[
"cb_window_duration_secs"
]
==
60
assert
captured_args
[
"health_failure_threshold"
]
==
2
assert
captured_args
[
"health_success_threshold"
]
==
1
assert
captured_args
[
"health_check_timeout_secs"
]
==
3
assert
captured_args
[
"health_check_interval_secs"
]
==
30
assert
captured_args
[
"health_check_endpoint"
]
==
"/healthz"
sgl-router/py_test/unit/test_validation.py
0 → 100644
View file @
045ab92d
"""
Unit tests for validation logic in sglang_router.
These tests focus on testing the validation logic in isolation,
including parameter validation, URL validation, and configuration validation.
"""
from
types
import
SimpleNamespace
from
unittest.mock
import
MagicMock
,
patch
import
pytest
from
sglang_router.launch_router
import
RouterArgs
,
launch_router
class
TestURLValidation
:
"""Test URL validation logic."""
def
test_valid_worker_urls
(
self
):
"""Test validation of valid worker URLs."""
valid_urls
=
[
"http://worker1:8000"
,
"https://worker2:8000"
,
"http://localhost:8000"
,
"http://127.0.0.1:8000"
,
"http://192.168.1.100:8000"
,
"http://worker.example.com:8000"
,
]
for
url
in
valid_urls
:
args
=
RouterArgs
(
worker_urls
=
[
url
])
# Should not raise any validation errors
assert
url
in
args
.
worker_urls
def
test_valid_prefill_urls
(
self
):
"""Test validation of valid prefill URLs."""
valid_prefill_urls
=
[
(
"http://prefill1:8000"
,
9000
),
(
"https://prefill2:8000"
,
None
),
(
"http://localhost:8000"
,
9000
),
(
"http://127.0.0.1:8000"
,
None
),
]
for
url
,
bootstrap_port
in
valid_prefill_urls
:
args
=
RouterArgs
(
prefill_urls
=
[(
url
,
bootstrap_port
)])
# Should not raise any validation errors
assert
(
url
,
bootstrap_port
)
in
args
.
prefill_urls
def
test_valid_decode_urls
(
self
):
"""Test validation of valid decode URLs."""
valid_decode_urls
=
[
"http://decode1:8001"
,
"https://decode2:8001"
,
"http://localhost:8001"
,
"http://127.0.0.1:8001"
,
]
for
url
in
valid_decode_urls
:
args
=
RouterArgs
(
decode_urls
=
[
url
])
# Should not raise any validation errors
assert
url
in
args
.
decode_urls
def
test_malformed_urls
(
self
):
"""Test handling of malformed URLs."""
# Note: The current implementation doesn't validate URL format
# This test documents the current behavior
malformed_urls
=
[
"not-a-url"
,
"ftp://worker1:8000"
,
# Wrong protocol
"http://"
,
# Missing host
":8000"
,
# Missing protocol and host
"http://worker1"
,
# Missing port
]
for
url
in
malformed_urls
:
args
=
RouterArgs
(
worker_urls
=
[
url
])
# Currently, malformed URLs are accepted
# This might be something to improve in the future
assert
url
in
args
.
worker_urls
class
TestPortValidation
:
"""Test port validation logic."""
def
test_valid_ports
(
self
):
"""Test validation of valid port numbers."""
valid_ports
=
[
1
,
80
,
8000
,
30000
,
65535
]
for
port
in
valid_ports
:
args
=
RouterArgs
(
port
=
port
)
assert
args
.
port
==
port
def
test_invalid_ports
(
self
):
"""Test handling of invalid port numbers."""
# Note: The current implementation doesn't validate port ranges
# This test documents the current behavior
invalid_ports
=
[
0
,
-
1
,
65536
,
70000
]
for
port
in
invalid_ports
:
args
=
RouterArgs
(
port
=
port
)
# Currently, invalid ports are accepted
# This might be something to improve in the future
assert
args
.
port
==
port
def
test_bootstrap_port_validation
(
self
):
"""Test validation of bootstrap ports in PD mode."""
valid_bootstrap_ports
=
[
1
,
80
,
9000
,
30000
,
65535
,
None
]
for
bootstrap_port
in
valid_bootstrap_ports
:
args
=
RouterArgs
(
prefill_urls
=
[(
"http://prefill1:8000"
,
bootstrap_port
)])
assert
args
.
prefill_urls
[
0
][
1
]
==
bootstrap_port
class
TestParameterValidation
:
"""Test parameter validation logic."""
def
test_cache_threshold_validation
(
self
):
"""Test cache threshold parameter validation."""
# Valid cache thresholds
valid_thresholds
=
[
0.0
,
0.1
,
0.5
,
0.9
,
1.0
]
for
threshold
in
valid_thresholds
:
args
=
RouterArgs
(
cache_threshold
=
threshold
)
assert
args
.
cache_threshold
==
threshold
def
test_balance_threshold_validation
(
self
):
"""Test load balancing threshold parameter validation."""
# Valid absolute thresholds
valid_abs_thresholds
=
[
0
,
1
,
32
,
64
,
128
,
1000
]
for
threshold
in
valid_abs_thresholds
:
args
=
RouterArgs
(
balance_abs_threshold
=
threshold
)
assert
args
.
balance_abs_threshold
==
threshold
# Valid relative thresholds
valid_rel_thresholds
=
[
1.0
,
1.1
,
1.5
,
2.0
,
10.0
]
for
threshold
in
valid_rel_thresholds
:
args
=
RouterArgs
(
balance_rel_threshold
=
threshold
)
assert
args
.
balance_rel_threshold
==
threshold
def
test_timeout_validation
(
self
):
"""Test timeout parameter validation."""
# Valid timeouts
valid_timeouts
=
[
1
,
30
,
60
,
300
,
600
,
1800
,
3600
]
for
timeout
in
valid_timeouts
:
args
=
RouterArgs
(
worker_startup_timeout_secs
=
timeout
,
worker_startup_check_interval
=
timeout
,
request_timeout_secs
=
timeout
,
queue_timeout_secs
=
timeout
,
)
assert
args
.
worker_startup_timeout_secs
==
timeout
assert
args
.
worker_startup_check_interval
==
timeout
assert
args
.
request_timeout_secs
==
timeout
assert
args
.
queue_timeout_secs
==
timeout
def
test_retry_parameter_validation
(
self
):
"""Test retry parameter validation."""
# Valid retry parameters
valid_retry_counts
=
[
0
,
1
,
3
,
5
,
10
]
for
count
in
valid_retry_counts
:
args
=
RouterArgs
(
retry_max_retries
=
count
)
assert
args
.
retry_max_retries
==
count
# Valid backoff parameters
valid_backoff_ms
=
[
1
,
50
,
100
,
1000
,
30000
]
for
backoff
in
valid_backoff_ms
:
args
=
RouterArgs
(
retry_initial_backoff_ms
=
backoff
,
retry_max_backoff_ms
=
backoff
)
assert
args
.
retry_initial_backoff_ms
==
backoff
assert
args
.
retry_max_backoff_ms
==
backoff
# Valid multiplier parameters
valid_multipliers
=
[
1.0
,
1.5
,
2.0
,
3.0
]
for
multiplier
in
valid_multipliers
:
args
=
RouterArgs
(
retry_backoff_multiplier
=
multiplier
)
assert
args
.
retry_backoff_multiplier
==
multiplier
# Valid jitter parameters
valid_jitter
=
[
0.0
,
0.1
,
0.2
,
0.5
]
for
jitter
in
valid_jitter
:
args
=
RouterArgs
(
retry_jitter_factor
=
jitter
)
assert
args
.
retry_jitter_factor
==
jitter
def
test_circuit_breaker_parameter_validation
(
self
):
"""Test circuit breaker parameter validation."""
# Valid failure thresholds
valid_failure_thresholds
=
[
1
,
3
,
5
,
10
,
20
]
for
threshold
in
valid_failure_thresholds
:
args
=
RouterArgs
(
cb_failure_threshold
=
threshold
)
assert
args
.
cb_failure_threshold
==
threshold
# Valid success thresholds
valid_success_thresholds
=
[
1
,
2
,
3
,
5
]
for
threshold
in
valid_success_thresholds
:
args
=
RouterArgs
(
cb_success_threshold
=
threshold
)
assert
args
.
cb_success_threshold
==
threshold
# Valid timeout durations
valid_timeouts
=
[
10
,
30
,
60
,
120
,
300
]
for
timeout
in
valid_timeouts
:
args
=
RouterArgs
(
cb_timeout_duration_secs
=
timeout
,
cb_window_duration_secs
=
timeout
)
assert
args
.
cb_timeout_duration_secs
==
timeout
assert
args
.
cb_window_duration_secs
==
timeout
def
test_health_check_parameter_validation
(
self
):
"""Test health check parameter validation."""
# Valid failure thresholds
valid_failure_thresholds
=
[
1
,
2
,
3
,
5
,
10
]
for
threshold
in
valid_failure_thresholds
:
args
=
RouterArgs
(
health_failure_threshold
=
threshold
)
assert
args
.
health_failure_threshold
==
threshold
# Valid success thresholds
valid_success_thresholds
=
[
1
,
2
,
3
,
5
]
for
threshold
in
valid_success_thresholds
:
args
=
RouterArgs
(
health_success_threshold
=
threshold
)
assert
args
.
health_success_threshold
==
threshold
# Valid timeouts and intervals
valid_times
=
[
1
,
5
,
10
,
30
,
60
,
120
]
for
time_val
in
valid_times
:
args
=
RouterArgs
(
health_check_timeout_secs
=
time_val
,
health_check_interval_secs
=
time_val
)
assert
args
.
health_check_timeout_secs
==
time_val
assert
args
.
health_check_interval_secs
==
time_val
def
test_rate_limiting_parameter_validation
(
self
):
"""Test rate limiting parameter validation."""
# Valid concurrent request limits
valid_limits
=
[
1
,
10
,
64
,
256
,
512
,
1000
]
for
limit
in
valid_limits
:
args
=
RouterArgs
(
max_concurrent_requests
=
limit
)
assert
args
.
max_concurrent_requests
==
limit
# Valid queue sizes
valid_queue_sizes
=
[
0
,
10
,
50
,
100
,
500
,
1000
]
for
size
in
valid_queue_sizes
:
args
=
RouterArgs
(
queue_size
=
size
)
assert
args
.
queue_size
==
size
# Valid token rates
valid_rates
=
[
1
,
10
,
50
,
100
,
500
,
1000
]
for
rate
in
valid_rates
:
args
=
RouterArgs
(
rate_limit_tokens_per_second
=
rate
)
assert
args
.
rate_limit_tokens_per_second
==
rate
def
test_tree_size_validation
(
self
):
"""Test tree size parameter validation."""
# Valid tree sizes (powers of 2)
valid_sizes
=
[
2
**
10
,
2
**
20
,
2
**
24
,
2
**
26
,
2
**
28
,
2
**
30
]
for
size
in
valid_sizes
:
args
=
RouterArgs
(
max_tree_size
=
size
)
assert
args
.
max_tree_size
==
size
def
test_payload_size_validation
(
self
):
"""Test payload size parameter validation."""
# Valid payload sizes
valid_sizes
=
[
1024
,
# 1KB
1024
*
1024
,
# 1MB
10
*
1024
*
1024
,
# 10MB
100
*
1024
*
1024
,
# 100MB
512
*
1024
*
1024
,
# 512MB
1024
*
1024
*
1024
,
# 1GB
]
for
size
in
valid_sizes
:
args
=
RouterArgs
(
max_payload_size
=
size
)
assert
args
.
max_payload_size
==
size
class
TestConfigurationValidation
:
"""Test configuration validation logic."""
def
test_pd_mode_validation
(
self
):
"""Test PD mode configuration validation."""
# Valid PD configuration
args
=
RouterArgs
(
pd_disaggregation
=
True
,
prefill_urls
=
[(
"http://prefill1:8000"
,
9000
)],
decode_urls
=
[
"http://decode1:8001"
],
)
assert
args
.
pd_disaggregation
is
True
assert
len
(
args
.
prefill_urls
)
>
0
assert
len
(
args
.
decode_urls
)
>
0
def
test_service_discovery_validation
(
self
):
"""Test service discovery configuration validation."""
# Valid service discovery configuration
args
=
RouterArgs
(
service_discovery
=
True
,
selector
=
{
"app"
:
"worker"
,
"env"
:
"prod"
},
service_discovery_port
=
8080
,
service_discovery_namespace
=
"default"
,
)
assert
args
.
service_discovery
is
True
assert
args
.
selector
==
{
"app"
:
"worker"
,
"env"
:
"prod"
}
assert
args
.
service_discovery_port
==
8080
assert
args
.
service_discovery_namespace
==
"default"
def
test_pd_service_discovery_validation
(
self
):
"""Test PD service discovery configuration validation."""
# Valid PD service discovery configuration
args
=
RouterArgs
(
pd_disaggregation
=
True
,
service_discovery
=
True
,
prefill_selector
=
{
"app"
:
"prefill"
},
decode_selector
=
{
"app"
:
"decode"
},
)
assert
args
.
pd_disaggregation
is
True
assert
args
.
service_discovery
is
True
assert
args
.
prefill_selector
==
{
"app"
:
"prefill"
}
assert
args
.
decode_selector
==
{
"app"
:
"decode"
}
def
test_policy_validation
(
self
):
"""Test policy configuration validation."""
# Valid policies
valid_policies
=
[
"random"
,
"round_robin"
,
"cache_aware"
,
"power_of_two"
]
for
policy
in
valid_policies
:
args
=
RouterArgs
(
policy
=
policy
)
assert
args
.
policy
==
policy
def
test_pd_policy_validation
(
self
):
"""Test PD policy configuration validation."""
# Valid PD policies
valid_policies
=
[
"random"
,
"round_robin"
,
"cache_aware"
,
"power_of_two"
]
for
prefill_policy
in
valid_policies
:
for
decode_policy
in
valid_policies
:
args
=
RouterArgs
(
pd_disaggregation
=
True
,
prefill_urls
=
[(
"http://prefill1:8000"
,
None
)],
decode_urls
=
[
"http://decode1:8001"
],
prefill_policy
=
prefill_policy
,
decode_policy
=
decode_policy
,
)
assert
args
.
prefill_policy
==
prefill_policy
assert
args
.
decode_policy
==
decode_policy
def
test_cors_validation
(
self
):
"""Test CORS configuration validation."""
# Valid CORS origins
valid_origins
=
[
[],
[
"http://localhost:3000"
],
[
"https://example.com"
],
[
"http://localhost:3000"
,
"https://example.com"
],
[
"*"
],
# Wildcard (if supported)
]
for
origins
in
valid_origins
:
args
=
RouterArgs
(
cors_allowed_origins
=
origins
)
assert
args
.
cors_allowed_origins
==
origins
def
test_logging_validation
(
self
):
"""Test logging configuration validation."""
# Valid log levels
valid_log_levels
=
[
"debug"
,
"info"
,
"warning"
,
"error"
,
"critical"
]
for
level
in
valid_log_levels
:
args
=
RouterArgs
(
log_level
=
level
)
assert
args
.
log_level
==
level
def
test_prometheus_validation
(
self
):
"""Test Prometheus configuration validation."""
# Valid Prometheus configuration
args
=
RouterArgs
(
prometheus_port
=
29000
,
prometheus_host
=
"127.0.0.1"
)
assert
args
.
prometheus_port
==
29000
assert
args
.
prometheus_host
==
"127.0.0.1"
def
test_tokenizer_validation
(
self
):
"""Test tokenizer configuration validation."""
# Note: model_path and tokenizer_path are not available in current RouterArgs
pytest
.
skip
(
"Tokenizer configuration not available in current implementation"
)
def
test_request_id_headers_validation
(
self
):
"""Test request ID headers configuration validation."""
# Valid request ID headers
valid_headers
=
[
[
"x-request-id"
],
[
"x-request-id"
,
"x-trace-id"
],
[
"x-request-id"
,
"x-trace-id"
,
"x-correlation-id"
],
[
"custom-header"
],
]
for
headers
in
valid_headers
:
args
=
RouterArgs
(
request_id_headers
=
headers
)
assert
args
.
request_id_headers
==
headers
class
TestLaunchValidation
:
"""Test launch-time validation logic."""
def
test_pd_mode_requires_urls
(
self
):
"""Test that PD mode requires prefill and decode URLs."""
# PD mode without URLs should fail
args
=
RouterArgs
(
pd_disaggregation
=
True
,
prefill_urls
=
[],
decode_urls
=
[],
service_discovery
=
False
,
)
with
pytest
.
raises
(
ValueError
,
match
=
"PD disaggregation mode requires --prefill"
):
launch_router
(
args
)
def
test_pd_mode_with_service_discovery_allows_empty_urls
(
self
):
"""Test that PD mode with service discovery allows empty URLs."""
args
=
RouterArgs
(
pd_disaggregation
=
True
,
prefill_urls
=
[],
decode_urls
=
[],
service_discovery
=
True
,
)
# Should not raise validation error
with
patch
(
"sglang_router.launch_router.Router"
)
as
router_mod
:
mock_router_instance
=
MagicMock
()
router_mod
.
from_args
=
MagicMock
(
return_value
=
mock_router_instance
)
launch_router
(
args
)
# Should create router instance via from_args
router_mod
.
from_args
.
assert_called_once
()
def
test_regular_mode_allows_empty_worker_urls
(
self
):
"""Test that regular mode allows empty worker URLs."""
args
=
RouterArgs
(
worker_urls
=
[],
service_discovery
=
False
)
# Should not raise validation error
with
patch
(
"sglang_router.launch_router.Router"
)
as
router_mod
:
mock_router_instance
=
MagicMock
()
router_mod
.
from_args
=
MagicMock
(
return_value
=
mock_router_instance
)
launch_router
(
args
)
# Should create router instance via from_args
router_mod
.
from_args
.
assert_called_once
()
def
test_launch_with_valid_config
(
self
):
"""Test launching with valid configuration."""
args
=
RouterArgs
(
host
=
"127.0.0.1"
,
port
=
30000
,
worker_urls
=
[
"http://worker1:8000"
],
policy
=
"cache_aware"
,
)
# Should not raise validation error
with
patch
(
"sglang_router.launch_router.Router"
)
as
router_mod
:
mock_router_instance
=
MagicMock
()
router_mod
.
from_args
=
MagicMock
(
return_value
=
mock_router_instance
)
launch_router
(
args
)
# Should create router instance via from_args
router_mod
.
from_args
.
assert_called_once
()
def
test_launch_with_pd_config
(
self
):
"""Test launching with valid PD configuration."""
args
=
RouterArgs
(
pd_disaggregation
=
True
,
prefill_urls
=
[(
"http://prefill1:8000"
,
9000
)],
decode_urls
=
[
"http://decode1:8001"
],
policy
=
"cache_aware"
,
)
# Should not raise validation error
with
patch
(
"sglang_router.launch_router.Router"
)
as
router_mod
:
mock_router_instance
=
MagicMock
()
router_mod
.
from_args
=
MagicMock
(
return_value
=
mock_router_instance
)
launch_router
(
args
)
# Should create router instance via from_args
router_mod
.
from_args
.
assert_called_once
()
def
test_launch_with_service_discovery_config
(
self
):
"""Test launching with valid service discovery configuration."""
args
=
RouterArgs
(
service_discovery
=
True
,
selector
=
{
"app"
:
"worker"
},
service_discovery_port
=
8080
,
)
# Should not raise validation error
with
patch
(
"sglang_router.launch_router.Router"
)
as
router_mod
:
mock_router_instance
=
MagicMock
()
router_mod
.
from_args
=
MagicMock
(
return_value
=
mock_router_instance
)
launch_router
(
args
)
# Should create router instance via from_args
router_mod
.
from_args
.
assert_called_once
()
sgl-router/pyproject.toml
View file @
045ab92d
...
...
@@ -21,6 +21,7 @@ dev = [
"requests>=2.25.0"
,
]
# https://github.com/PyO3/setuptools-rust?tab=readme-ov-file
[tool.setuptools.packages]
find
=
{
where
=
["py_src"]
}
...
...
sgl-router/pytest.ini
0 → 100644
View file @
045ab92d
[pytest]
testpaths
=
py_test
python_files
=
test_*.py
python_classes
=
Test*
python_functions
=
test_*
addopts
=
--cov=sglang_router --cov-report=term-missing
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment