Commit 9eb50ecc (Unverified)
Authored by Keyang Ru on Sep 06, 2025; committed by GitHub on Sep 06, 2025
[router] Improve the router e2e tests (#10102)
Parent: b3e7a2ce
Showing 8 changed files with 422 additions and 1120 deletions.
.github/workflows/pr-test-rust.yml                    +3    -3
sgl-router/py_test/e2e/conftest.py                    +235  -0
sgl-router/py_test/e2e/test_e2e_router.py             +146  -0
sgl-router/py_test/fixtures/mock_worker.py            +5    -1
sgl-router/py_test/integration/test_payload_size.py   +33   -0
sgl-router/py_test/run_suite.py                       +0    -27
sgl-router/py_test/test_launch_router.py              +0    -354
sgl-router/py_test/test_launch_server.py              +0    -735
.github/workflows/pr-test-rust.yml
@@ -105,11 +105,11 @@ jobs:
           pip install fastapi uvicorn orjson
           pytest -q -m integration
-      - name: Run e2e test
+      - name: Run Python E2E tests
         run: |
           bash scripts/killall_sglang.sh "nuk_gpus"
-          cd sgl-router/py_test
-          python3 run_suite.py
+          cd sgl-router
+          pytest -m e2e -s -vv -o log_cli=true --log-cli-level=INFO
   finish:
     needs: [unit-test-rust, e2e-python]
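For local reproduction, the new CI step boils down to running the e2e pytest marker from the sgl-router directory. A minimal Python sketch follows, assuming it is executed from the repository root; the script path, the "nuk_gpus" argument, and the pytest flags are copied from the hunk above, nothing else is implied.

import subprocess

# Mirror the "Run Python E2E tests" CI step shown above.
subprocess.run(["bash", "scripts/killall_sglang.sh", "nuk_gpus"], check=True)
subprocess.run(
    ["pytest", "-m", "e2e", "-s", "-vv", "-o", "log_cli=true", "--log-cli-level=INFO"],
    cwd="sgl-router",
    check=True,
)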
sgl-router/py_test/e2e/conftest.py
new file (mode 100644)

import socket
import subprocess
import time
from types import SimpleNamespace
from urllib.parse import urlparse

import pytest
import requests

from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
)


def _find_available_port() -> int:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def _parse_url(base_url: str) -> tuple[str, str]:
    """Parse a base URL and return (host, port) as strings.

    This is more robust than simple string splitting and supports different schemes
    and URL shapes like trailing paths.
    """
    parsed = urlparse(base_url)
    return parsed.hostname or "127.0.0.1", (
        str(parsed.port) if parsed.port is not None else ""
    )


def _wait_router_health(base_url: str, timeout: float) -> None:
    start = time.perf_counter()
    with requests.Session() as session:
        while time.perf_counter() - start < timeout:
            try:
                r = session.get(f"{base_url}/health", timeout=5)
                if r.status_code == 200:
                    return
            except requests.RequestException:
                pass
            time.sleep(2)
    raise TimeoutError("Router failed to become healthy in time")


def _popen_launch_router(
    model: str,
    base_url: str,
    dp_size: int,
    timeout: float,
    policy: str = "cache_aware",
) -> subprocess.Popen:
    host, port = _parse_url(base_url)
    prom_port = _find_available_port()
    cmd = [
        "python3",
        "-m",
        "sglang_router.launch_server",
        "--model-path",
        model,
        "--host",
        host,
        "--port",
        port,
        "--dp",
        str(dp_size),
        "--router-policy",
        policy,
        "--allow-auto-truncate",
        "--router-prometheus-port",
        str(prom_port),
        "--router-prometheus-host",
        "127.0.0.1",
    ]
    proc = subprocess.Popen(cmd)
    _wait_router_health(base_url, timeout)
    return proc


def _popen_launch_worker(
    model: str,
    base_url: str,
    *,
    dp_size: int | None = None,
    api_key: str | None = None,
) -> subprocess.Popen:
    host, port = _parse_url(base_url)
    cmd = [
        "python3",
        "-m",
        "sglang.launch_server",
        "--model-path",
        model,
        "--host",
        host,
        "--port",
        port,
        "--base-gpu-id",
        "0",
    ]
    if dp_size is not None:
        cmd += ["--dp-size", str(dp_size)]
    if api_key is not None:
        cmd += ["--api-key", api_key]
    return subprocess.Popen(cmd)


def _popen_launch_router_only(
    base_url: str,
    policy: str = "round_robin",
    timeout: float = 120.0,
    *,
    dp_aware: bool = False,
    api_key: str | None = None,
) -> subprocess.Popen:
    host, port = _parse_url(base_url)
    prom_port = _find_available_port()
    cmd = [
        "python3",
        "-m",
        "sglang_router.launch_router",
        "--host",
        host,
        "--port",
        port,
        "--policy",
        policy,
    ]
    if dp_aware:
        cmd += ["--dp-aware"]
    if api_key is not None:
        cmd += ["--api-key", api_key]
    cmd += [
        "--prometheus-port",
        str(prom_port),
        "--prometheus-host",
        "127.0.0.1",
    ]
    proc = subprocess.Popen(cmd)
    _wait_router_health(base_url, timeout)
    return proc


def _terminate(proc: subprocess.Popen, timeout: float = 120) -> None:
    if proc is None:
        return
    proc.terminate()
    start = time.perf_counter()
    while proc.poll() is None:
        if time.perf_counter() - start > timeout:
            proc.kill()
            break
        time.sleep(1)


def pytest_configure(config):
    config.addinivalue_line("markers", "e2e: mark as end-to-end test")


@pytest.fixture(scope="session")
def e2e_model() -> str:
    # Always use the default test model
    return DEFAULT_MODEL_NAME_FOR_TEST


@pytest.fixture
def e2e_router(e2e_model: str):
    # Keep this available but tests below use router-only to avoid GPU contention
    base_url = DEFAULT_URL_FOR_TEST
    proc = _popen_launch_router(
        e2e_model, base_url, dp_size=2, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
    )
    try:
        yield SimpleNamespace(proc=proc, url=base_url)
    finally:
        _terminate(proc)


@pytest.fixture
def e2e_router_only_rr():
    port = _find_available_port()
    base_url = f"http://127.0.0.1:{port}"
    proc = _popen_launch_router_only(base_url, policy="round_robin")
    try:
        yield SimpleNamespace(proc=proc, url=base_url)
    finally:
        _terminate(proc)


@pytest.fixture(scope="session")
def e2e_primary_worker(e2e_model: str):
    port = _find_available_port()
    base_url = f"http://127.0.0.1:{port}"
    proc = _popen_launch_worker(e2e_model, base_url)
    # Router health gate will handle worker readiness
    try:
        yield SimpleNamespace(proc=proc, url=base_url)
    finally:
        _terminate(proc)


@pytest.fixture
def e2e_router_only_rr_dp_aware_api():
    """Router-only with dp-aware enabled and an API key."""
    port = _find_available_port()
    base_url = f"http://127.0.0.1:{port}"
    api_key = "secret"
    proc = _popen_launch_router_only(
        base_url, policy="round_robin", timeout=180.0, dp_aware=True, api_key=api_key
    )
    try:
        yield SimpleNamespace(proc=proc, url=base_url, api_key=api_key)
    finally:
        _terminate(proc)


@pytest.fixture
def e2e_worker_dp2_api(e2e_model: str, e2e_router_only_rr_dp_aware_api):
    """Worker with dp-size=2 and the same API key as the dp-aware router."""
    port = _find_available_port()
    base_url = f"http://127.0.0.1:{port}"
    api_key = e2e_router_only_rr_dp_aware_api.api_key
    proc = _popen_launch_worker(e2e_model, base_url, dp_size=2, api_key=api_key)
    try:
        yield SimpleNamespace(proc=proc, url=base_url)
    finally:
        _terminate(proc)
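As an aside on the _parse_url helper above: a small standalone sketch (not part of the commit) showing why urlparse copes with URL shapes that the old base_url.split(":") logic in the removed test_launch_server.py could not.

from urllib.parse import urlparse

def parse(base_url: str):
    # Same logic as _parse_url above, reproduced for a quick demonstration.
    parsed = urlparse(base_url)
    return parsed.hostname or "127.0.0.1", str(parsed.port) if parsed.port is not None else ""

print(parse("http://127.0.0.1:30000"))        # ('127.0.0.1', '30000')
print(parse("https://router.local:8443/v1"))  # ('router.local', '8443') - trailing path is fine
print(parse("http://localhost"))              # ('localhost', '') - no port present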
sgl-router/py_test/e2e/test_e2e_router.py
new file (mode 100644)

import threading
import time
from types import SimpleNamespace

import pytest
import requests

from sglang.test.run_eval import run_eval


@pytest.mark.e2e
def test_mmlu(e2e_router_only_rr, e2e_primary_worker, e2e_model):
    # Attach the primary worker to a fresh router-only instance (single model)
    base = e2e_router_only_rr.url
    r = requests.post(
        f"{base}/add_worker", params={"url": e2e_primary_worker.url}, timeout=180
    )
    r.raise_for_status()

    args = SimpleNamespace(
        base_url=base,
        model=e2e_model,
        eval_name="mmlu",
        num_examples=64,
        num_threads=32,
        temperature=0.1,
    )
    metrics = run_eval(args)
    assert metrics["score"] >= 0.65


@pytest.mark.e2e
def test_add_and_remove_worker_live(e2e_router_only_rr, e2e_primary_worker, e2e_model):
    base = e2e_router_only_rr.url
    worker_url = e2e_primary_worker.url

    r = requests.post(f"{base}/add_worker", params={"url": worker_url}, timeout=180)
    r.raise_for_status()

    with requests.Session() as s:
        for i in range(8):
            r = s.post(
                f"{base}/v1/completions",
                json={
                    "model": e2e_model,
                    "prompt": f"x{i}",
                    "max_tokens": 1,
                    "stream": False,
                },
                timeout=120,
            )
            r.raise_for_status()

    # Remove the worker
    r = requests.post(f"{base}/remove_worker", params={"url": worker_url}, timeout=60)
    r.raise_for_status()


@pytest.mark.e2e
def test_lazy_fault_tolerance_live(e2e_router_only_rr, e2e_primary_worker, e2e_model):
    base = e2e_router_only_rr.url
    worker = e2e_primary_worker

    r = requests.post(f"{base}/add_worker", params={"url": worker.url}, timeout=180)
    r.raise_for_status()

    def killer():
        time.sleep(10)
        try:
            worker.proc.terminate()
        except Exception:
            pass

    t = threading.Thread(target=killer, daemon=True)
    t.start()

    args = SimpleNamespace(
        base_url=base,
        model=e2e_model,
        eval_name="mmlu",
        num_examples=32,
        num_threads=16,
        temperature=0.0,
    )
    metrics = run_eval(args)
    assert 0.0 <= metrics["score"] <= 1.0


@pytest.mark.e2e
def test_dp_aware_worker_expansion_and_api_key(
    e2e_model,
    e2e_router_only_rr_dp_aware_api,
    e2e_worker_dp2_api,
):
    """
    Launch a router-only instance in dp_aware mode and a single worker with dp_size=2
    and API key protection. Verify expansion, auth enforcement, and basic eval.
    """
    import os

    router_url = e2e_router_only_rr_dp_aware_api.url
    worker_url = e2e_worker_dp2_api.url
    api_key = e2e_router_only_rr_dp_aware_api.api_key

    # Attach worker; router should expand to dp_size logical workers
    r = requests.post(
        f"{router_url}/add_worker", params={"url": worker_url}, timeout=180
    )
    r.raise_for_status()

    r = requests.get(f"{router_url}/list_workers", timeout=30)
    r.raise_for_status()
    urls = r.json().get("urls", [])
    assert len(urls) == 2
    assert set(urls) == {f"{worker_url}@0", f"{worker_url}@1"}

    # Verify API key enforcement pass-through
    # 1) Without Authorization -> 401 from backend
    r = requests.post(
        f"{router_url}/v1/completions",
        json={"model": e2e_model, "prompt": "hi", "max_tokens": 1},
        timeout=60,
    )
    assert r.status_code == 401

    # 2) With correct Authorization -> 200
    r = requests.post(
        f"{router_url}/v1/completions",
        json={"model": e2e_model, "prompt": "hi", "max_tokens": 1},
        headers={"Authorization": f"Bearer {api_key}"},
        timeout=60,
    )
    assert r.status_code == 200

    # Finally, run MMLU eval through the router with auth
    os.environ["OPENAI_API_KEY"] = api_key
    args = SimpleNamespace(
        base_url=router_url,
        model=e2e_model,
        eval_name="mmlu",
        num_examples=64,
        num_threads=32,
        temperature=0.1,
    )
    metrics = run_eval(args)
    assert metrics["score"] >= 0.65
sgl-router/py_test/fixtures/mock_worker.py
@@ -44,6 +44,7 @@ def _parse_args() -> argparse.Namespace:
     p.add_argument("--api-key", default=None)
     p.add_argument("--max-payload-bytes", type=int, default=10 * 1024 * 1024)
     p.add_argument("--stream", action="store_true")
+    p.add_argument("--dp-size", type=int, default=1)
     p.add_argument("--crash-on-request", action="store_true")
     p.add_argument("--health-fail-after-ms", type=int, default=0)
     return p.parse_args()

@@ -125,12 +126,15 @@ def create_app(args: argparse.Namespace) -> FastAPI:
         return JSONResponse({"data": [{"id": "mock", "object": "model"}]})

     @app.get("/get_server_info")
-    async def get_server_info():
+    async def get_server_info(request: Request):
+        # Enforce API key on server info when required (used by dp_aware probing)
+        check_api_key(request)
         return JSONResponse(
             {
                 "worker_id": worker_id,
                 "load_in_flight": _inflight,
                 "cache": {"size": 0, "hit_rate": 0.0},
+                "dp_size": int(args.dp_size),
             }
         )
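The new check in get_server_info matters because a dp-aware router probes this endpoint to learn dp_size before registering a worker. A hedged sketch of such a probe follows; the function name and structure are assumptions for illustration only, while the /get_server_info path, the dp_size field, and the Bearer header convention come from this commit.

import requests

def probe_dp_size(worker_url: str, api_key: str | None = None, timeout: float = 10.0) -> int:
    # Hypothetical router-side probe: ask the worker for its server info,
    # forwarding the same Authorization header the mock worker now enforces.
    headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
    r = requests.get(f"{worker_url}/get_server_info", headers=headers, timeout=timeout)
    r.raise_for_status()  # a 401 here is what makes add_worker fail without the key
    return int(r.json().get("dp_size", 1))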
sgl-router/py_test/integration/test_payload_size.py
new file (mode 100644)

import pytest
import requests


@pytest.mark.integration
def test_payload_size_limit(router_manager, mock_workers):
    # Start one backend and a router with a 1MB payload limit
    _, urls, _ = mock_workers(n=1)
    rh = router_manager.start_router(
        worker_urls=urls,
        policy="round_robin",
        extra={"max_payload_size": 1 * 1024 * 1024},  # 1MB
    )

    # Payload just under 1MB should succeed
    payload_small = {
        "model": "test-model",
        "prompt": "x" * int(0.5 * 1024 * 1024),  # ~0.5MB
        "max_tokens": 1,
        "stream": False,
    }
    r = requests.post(f"{rh.url}/v1/completions", json=payload_small)
    assert r.status_code == 200

    # Payload over 1MB should fail with 413
    payload_large = {
        "model": "test-model",
        "prompt": "x" * int(1.2 * 1024 * 1024),  # ~1.2MB
        "max_tokens": 1,
        "stream": False,
    }
    r = requests.post(f"{rh.url}/v1/completions", json=payload_large)
    assert r.status_code == 413
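A quick sanity check on the chosen payload sizes (a standalone sketch, not part of the commit): the JSON envelope adds only a little overhead on top of the raw prompt, so a ~0.5 MB prompt stays clearly under the 1 MB router limit and a ~1.2 MB prompt stays clearly over it.

import json

limit = 1 * 1024 * 1024  # the router's configured max_payload_size in the test above
for factor in (0.5, 1.2):
    prompt = "x" * int(factor * 1024 * 1024)
    body = json.dumps({"model": "test-model", "prompt": prompt, "max_tokens": 1, "stream": False})
    print(f"{factor} MB prompt -> {len(body)} bytes on the wire, over limit: {len(body) > limit}")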
sgl-router/py_test/run_suite.py
deleted file (mode 100644 → 0); previous version at b3e7a2ce

import argparse
import glob

from sglang.test.test_utils import TestFile, run_unittest_files

if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--timeout-per-file",
        type=int,
        default=2000,
        help="The time limit for running one file in seconds.",
    )
    args = arg_parser.parse_args()

    files = glob.glob("**/test_*.py", recursive=True)
    # Exclude integration tests from the e2e suite; those are run separately via pytest -m integration
    files = [
        f for f in files if "/integration/" not in f and not f.startswith("integration/")
    ]
    files.sort()

    test_files = [TestFile(name=file) for file in files]
    exit_code = run_unittest_files(test_files, args.timeout_per_file)
    exit(exit_code)
sgl-router/py_test/test_launch_router.py
deleted file (mode 100644 → 0); previous version at b3e7a2ce

import multiprocessing
import time
import unittest
from types import SimpleNamespace


def terminate_process(process: multiprocessing.Process, timeout: float = 1.0) -> None:
    """Terminate a process gracefully, with forced kill as fallback.

    Args:
        process: The process to terminate
        timeout: Seconds to wait for graceful termination before forcing kill
    """
    if not process.is_alive():
        return

    process.terminate()
    process.join(timeout=timeout)

    if process.is_alive():
        process.kill()  # Force kill if terminate didn't work
        process.join()


class TestLaunchRouter(unittest.TestCase):
    def setUp(self):
        """Set up default arguments for router tests."""
        self.default_args = SimpleNamespace(
            host="127.0.0.1",
            port=30000,
            policy="cache_aware",
            worker_startup_timeout_secs=600,
            worker_startup_check_interval=10,
            cache_threshold=0.5,
            balance_abs_threshold=32,
            balance_rel_threshold=1.0001,
            eviction_interval_secs=60,
            max_tree_size=2**24,
            max_payload_size=256 * 1024 * 1024,  # 256MB
            verbose=False,
            log_dir=None,
            log_level=None,
            service_discovery=False,
            selector=None,
            service_discovery_port=80,
            service_discovery_namespace=None,
            dp_aware=False,
            prometheus_port=None,
            prometheus_host=None,
            request_timeout_secs=60,
            max_concurrent_requests=64,
            cors_allowed_origins=[],
            pd_disaggregation=False,
            prefill=None,
            decode=None,
            worker_urls=[],
            retry_max_retries=3,
            retry_initial_backoff_ms=100,
            retry_max_backoff_ms=10_000,
            retry_backoff_multiplier=2.0,
            retry_jitter_factor=0.1,
            cb_failure_threshold=5,
            cb_success_threshold=2,
            cb_timeout_duration_secs=30,
            cb_window_duration_secs=60,
            disable_retries=False,
            disable_circuit_breaker=False,
            model_path=None,
            tokenizer_path=None,
        )

    def create_router_args(self, **kwargs):
        """Create router arguments by updating default args with provided kwargs."""
        args_dict = vars(self.default_args).copy()
        args_dict.update(kwargs)
        return SimpleNamespace(**args_dict)

    def run_router_process(self, args):
        """Run router in a separate process and verify it starts successfully."""

        def run_router():
            try:
                from sglang_router.launch_router import launch_router

                router = launch_router(args)
                if router is None:
                    return 1
                return 0
            except Exception as e:
                print(e)
                return 1

        process = multiprocessing.Process(target=run_router)
        try:
            process.start()
            # Wait 3 seconds
            time.sleep(3)
            # Process is still running means router started successfully
            self.assertTrue(process.is_alive())
        finally:
            terminate_process(process)

    def test_launch_router_common(self):
        args = self.create_router_args(worker_urls=["http://localhost:8000"])
        self.run_router_process(args)

    def test_launch_router_with_empty_worker_urls(self):
        args = self.create_router_args(worker_urls=[])
        self.run_router_process(args)  # Should start successfully with empty worker list

    def test_launch_router_with_service_discovery(self):
        # Test router startup with service discovery enabled but no selectors
        args = self.create_router_args(
            worker_urls=[], service_discovery=True, selector=["app=test-worker"]
        )
        self.run_router_process(args)

    def test_launch_router_with_service_discovery_namespace(self):
        # Test router startup with service discovery enabled and namespace specified
        args = self.create_router_args(
            worker_urls=[],
            service_discovery=True,
            selector=["app=test-worker"],
            service_discovery_namespace="test-namespace",
        )
        self.run_router_process(args)

    def test_launch_router_common_with_dp_aware(self):
        args = self.create_router_args(
            worker_urls=["http://localhost:8000"],
            dp_aware=True,
        )
        self.run_router_process(args)

    def test_launch_router_with_empty_worker_urls_with_dp_aware(self):
        args = self.create_router_args(
            worker_urls=[],
            dp_aware=True,
        )
        self.run_router_process(args)

    def test_launch_router_common_with_dp_aware_service_discovery(self):
        # Test launching the router with both service_discovery and dp_aware enabled.
        # Should fail since service_discovery and dp_aware conflict.
        args = self.create_router_args(
            worker_urls=["http://localhost:8000"],
            dp_aware=True,
            service_discovery=True,
            selector=["app=test-worker"],
        )

        def run_router():
            try:
                from sglang_router.launch_router import launch_router

                router = launch_router(args)
                if router is None:
                    return 1
                return 0
            except Exception as e:
                print(e)
                return 1

        process = multiprocessing.Process(target=run_router)
        try:
            process.start()
            # Wait 3 seconds
            time.sleep(3)
            # Should fail since service_discovery and dp_aware conflict
            self.assertFalse(process.is_alive())
        finally:
            terminate_process(process)

    def test_launch_router_pd_mode_basic(self):
        """Test basic PD router functionality without actually starting servers."""
        # This test just verifies the PD router can be created and configured
        # without actually starting it (which would require real prefill/decode servers)
        from sglang_router.launch_router import RouterArgs
        from sglang_router.router import PolicyType, Router

        # Test RouterArgs parsing for PD mode
        # Simulate the parsed args structure from argparse with action="append"
        args = self.create_router_args(
            pd_disaggregation=True,
            policy="power_of_two",  # PowerOfTwo is only valid in PD mode
            prefill=[
                ["http://prefill1:8080", "9000"],
                ["http://prefill2:8080", "none"],
            ],
            decode=[
                ["http://decode1:8081"],
                ["http://decode2:8081"],
            ],
            worker_urls=[],  # Empty for PD mode
        )

        router_args = RouterArgs.from_cli_args(args)
        self.assertTrue(router_args.pd_disaggregation)
        self.assertEqual(router_args.policy, "power_of_two")
        self.assertEqual(len(router_args.prefill_urls), 2)
        self.assertEqual(len(router_args.decode_urls), 2)

        # Verify the parsed URLs and bootstrap ports
        self.assertEqual(router_args.prefill_urls[0], ("http://prefill1:8080", 9000))
        self.assertEqual(router_args.prefill_urls[1], ("http://prefill2:8080", None))
        self.assertEqual(router_args.decode_urls[0], "http://decode1:8081")
        self.assertEqual(router_args.decode_urls[1], "http://decode2:8081")

        # Test Router creation in PD mode
        router = Router.from_args(router_args)
        self.assertIsNotNone(router)

    def test_policy_validation(self):
        """Test that policy validation works correctly for PD and regular modes."""
        from sglang_router.launch_router import RouterArgs, launch_router

        # Test 1: PowerOfTwo requires at least 2 workers
        args = self.create_router_args(
            pd_disaggregation=False,
            policy="power_of_two",
            worker_urls=["http://localhost:8000"],  # Only 1 worker
        )
        # Should raise error
        with self.assertRaises(ValueError) as cm:
            launch_router(args)
        self.assertIn(
            "Power-of-two policy requires at least 2 workers",
            str(cm.exception),
        )

        # Test 2: PowerOfTwo with sufficient workers should succeed
        args = self.create_router_args(
            pd_disaggregation=False,
            policy="power_of_two",
            worker_urls=["http://localhost:8000", "http://localhost:8001"],  # 2 workers
        )
        # This should not raise an error (validation passes)

        # Test 3: All policies now work in both modes
        # Regular mode with RoundRobin
        args = self.create_router_args(
            pd_disaggregation=False,
            policy="round_robin",
            worker_urls=["http://localhost:8000"],
        )
        # This should not raise validation error

        # PD mode with RoundRobin (now supported!)
        args = self.create_router_args(
            pd_disaggregation=True,
            policy="round_robin",
            prefill=[["http://prefill1:8080", "9000"]],
            decode=[["http://decode1:8081"]],
            worker_urls=[],
        )
        # This should not raise validation error

    def test_pd_service_discovery_args_parsing(self):
        """Test PD service discovery CLI argument parsing."""
        import argparse

        from sglang_router.launch_router import RouterArgs

        parser = argparse.ArgumentParser()
        RouterArgs.add_cli_args(parser)
        args = parser.parse_args(
            [
                "--pd-disaggregation",
                "--service-discovery",
                "--prefill-selector",
                "app=sglang",
                "component=prefill",
                "--decode-selector",
                "app=sglang",
                "component=decode",
                "--service-discovery-port",
                "8000",
                "--service-discovery-namespace",
                "production",
                "--policy",
                "cache_aware",
            ]
        )

        router_args = RouterArgs.from_cli_args(args)
        self.assertTrue(router_args.pd_disaggregation)
        self.assertTrue(router_args.service_discovery)
        self.assertEqual(
            router_args.prefill_selector, {"app": "sglang", "component": "prefill"}
        )
        self.assertEqual(
            router_args.decode_selector, {"app": "sglang", "component": "decode"}
        )
        self.assertEqual(router_args.service_discovery_port, 8000)
        self.assertEqual(router_args.service_discovery_namespace, "production")

    def test_regular_service_discovery_args_parsing(self):
        """Test regular mode service discovery CLI argument parsing."""
        import argparse

        from sglang_router.launch_router import RouterArgs

        parser = argparse.ArgumentParser()
        RouterArgs.add_cli_args(parser)
        args = parser.parse_args(
            [
                "--service-discovery",
                "--selector",
                "app=sglang-worker",
                "environment=staging",
                "--service-discovery-port",
                "8000",
                "--policy",
                "round_robin",
            ]
        )

        router_args = RouterArgs.from_cli_args(args)
        self.assertFalse(router_args.pd_disaggregation)
        self.assertTrue(router_args.service_discovery)
        self.assertEqual(
            router_args.selector, {"app": "sglang-worker", "environment": "staging"}
        )
        self.assertEqual(router_args.prefill_selector, {})
        self.assertEqual(router_args.decode_selector, {})

    def test_empty_worker_urls_args_parsing(self):
        """Test that router accepts no worker URLs and defaults to empty list."""
        import argparse

        from sglang_router.launch_router import RouterArgs

        parser = argparse.ArgumentParser()
        RouterArgs.add_cli_args(parser)

        # Test with no --worker-urls argument at all
        args = parser.parse_args(["--policy", "random", "--port", "30000"])
        router_args = RouterArgs.from_cli_args(args)
        self.assertEqual(router_args.worker_urls, [])

        # Test with explicit empty --worker-urls
        args = parser.parse_args(["--worker-urls", "--policy", "random"])
        router_args = RouterArgs.from_cli_args(args)
        self.assertEqual(router_args.worker_urls, [])


if __name__ == "__main__":
    unittest.main()
sgl-router/py_test/test_launch_server.py
deleted file (mode 100644 → 0); previous version at b3e7a2ce

import socket
import subprocess
import time
import unittest
from types import SimpleNamespace

import requests

from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
)


def popen_launch_router(
    model: str,
    base_url: str,
    dp_size: int,
    timeout: float,
    policy: str = "cache_aware",
    max_payload_size: int = None,
    api_key: str = None,
    log_dir: str = None,
    service_discovery: bool = False,
    selector: list = None,
    service_discovery_port: int = 80,
    service_discovery_namespace: str = None,
    prometheus_port: int = None,
    prometheus_host: str = None,
    dp_aware: bool = False,
    # Router retry/CB tuning (optional)
    router_retry_max_retries: int = None,
    router_retry_initial_backoff_ms: int = None,
    router_retry_max_backoff_ms: int = None,
    router_retry_backoff_multiplier: float = None,
    router_retry_jitter_factor: float = None,
    router_cb_failure_threshold: int = None,
    router_cb_success_threshold: int = None,
    router_cb_timeout_duration_secs: int = None,
    router_cb_window_duration_secs: int = None,
):
    """
    Launch the router server process.

    Args:
        model: Model path/name
        base_url: Server base URL
        dp_size: Data parallel size
        timeout: Server launch timeout
        policy: Router policy, one of "cache_aware", "round_robin", "random"
        max_payload_size: Maximum payload size in bytes
        api_key: API key for the router
        log_dir: Directory to store log files. If None, logs are only output to console.
        service_discovery: Enable Kubernetes service discovery
        selector: List of label selectors in format ["key1=value1", "key2=value2"]
        service_discovery_port: Port to use for service discovery
        service_discovery_namespace: Kubernetes namespace to watch for pods. If None, watches all namespaces.
        prometheus_port: Port to expose Prometheus metrics. If None, Prometheus metrics are disabled.
        prometheus_host: Host address to bind the Prometheus metrics server.
        dp_aware: Enable data parallelism aware routing strategy.
    """
    _, host, port = base_url.split(":")
    host = host[2:]

    command = [
        "python3",
        "-m",
        "sglang_router.launch_server",
        "--model-path",
        model,
        "--host",
        host,
        "--port",
        port,
        "--dp",
        str(dp_size),
        "--router-eviction-interval-secs",
        "5",
        "--router-policy",
        policy,
        "--allow-auto-truncate",
    ]

    if api_key is not None:
        command.extend(["--api-key", api_key])
        command.extend(["--router-api-key", api_key])

    if max_payload_size is not None:
        command.extend(["--router-max-payload-size", str(max_payload_size)])

    if service_discovery:
        command.append("--router-service-discovery")

    if selector:
        command.extend(["--router-selector"] + selector)

    if service_discovery_port != 80:
        command.extend(["--router-service-discovery-port", str(service_discovery_port)])

    if service_discovery_namespace:
        command.extend(
            ["--router-service-discovery-namespace", service_discovery_namespace]
        )

    if prometheus_port is not None:
        command.extend(["--router-prometheus-port", str(prometheus_port)])

    if prometheus_host is not None:
        command.extend(["--router-prometheus-host", prometheus_host])

    if log_dir is not None:
        command.extend(["--log-dir", log_dir])

    if dp_aware:
        command.append("--router-dp-aware")

    # Append router retry/CB tuning flags if provided
    def _add(flag: str, val):
        if val is not None:
            command.extend([flag, str(val)])

    _add("--router-retry-max-retries", router_retry_max_retries)
    _add("--router-retry-initial-backoff-ms", router_retry_initial_backoff_ms)
    _add("--router-retry-max-backoff-ms", router_retry_max_backoff_ms)
    _add("--router-retry-backoff-multiplier", router_retry_backoff_multiplier)
    _add("--router-retry-jitter-factor", router_retry_jitter_factor)
    _add("--router-cb-failure-threshold", router_cb_failure_threshold)
    _add("--router-cb-success-threshold", router_cb_success_threshold)
    _add("--router-cb-timeout-duration-secs", router_cb_timeout_duration_secs)
    _add("--router-cb-window-duration-secs", router_cb_window_duration_secs)

    process = subprocess.Popen(command, stdout=None, stderr=None)

    start_time = time.perf_counter()
    with requests.Session() as session:
        while time.perf_counter() - start_time < timeout:
            try:
                response = session.get(f"{base_url}/health")
                if response.status_code == 200:
                    print(f"Router {base_url} is healthy")
                    return process
            except requests.RequestException:
                pass
            time.sleep(10)

    raise TimeoutError("Router failed to start within the timeout period.")


def find_available_port():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def popen_launch_server(
    model: str,
    base_url: str,
    timeout: float,
    api_key: str = None,
):
    _, host, port = base_url.split(":")
    host = host[2:]

    command = [
        "python3",
        "-m",
        "sglang.launch_server",
        "--model-path",
        model,
        "--host",
        host,
        "--port",
        port,
        "--base-gpu-id",
        "1",
    ]
    if api_key is not None:
        command.extend(["--api-key", api_key])

    process = subprocess.Popen(command, stdout=None, stderr=None)
    # intentionally don't wait and defer the job to the router health check
    return process


def terminate_and_wait(process, timeout=300):
    """Terminate a process and wait until it is terminated.

    Args:
        process: subprocess.Popen object
        timeout: maximum time to wait in seconds

    Raises:
        TimeoutError: if process does not terminate within timeout
    """
    if process is None:
        return

    process.terminate()
    start_time = time.perf_counter()

    while process.poll() is None:
        print(f"Terminating process {process.pid}")
        if time.perf_counter() - start_time > timeout:
            raise TimeoutError(
                f"Process {process.pid} failed to terminate within {timeout}s"
            )
        time.sleep(1)

    print(f"Process {process.pid} is successfully terminated")


class TestLaunchServer(unittest.TestCase):
    def setUp(self):
        self.model = DEFAULT_MODEL_NAME_FOR_TEST
        self.base_url = DEFAULT_URL_FOR_TEST
        self.process = None
        self.other_process = []

    def tearDown(self):
        print("Running tearDown...")
        if self.process:
            terminate_and_wait(self.process)
        for process in self.other_process:
            terminate_and_wait(process)
        print("tearDown done")

    def test_1_mmlu(self):
        print("Running test_1_mmlu...")
        # DP size = 2
        self.process = popen_launch_router(
            self.model,
            self.base_url,
            dp_size=2,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            policy="cache_aware",
        )

        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
            temperature=0.1,
        )

        metrics = run_eval(args)
        score = metrics["score"]
        THRESHOLD = 0.635
        passed = score >= THRESHOLD
        msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
        self.assertGreaterEqual(score, THRESHOLD, msg)

    def test_2_add_and_remove_worker(self):
        print("Running test_2_add_and_remove_worker...")
        # DP size = 1
        self.process = popen_launch_router(
            self.model,
            self.base_url,
            dp_size=1,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            policy="round_robin",  # use round robin to make sure every worker processes requests
        )

        # 1. start a worker
        port = find_available_port()
        worker_url = f"http://127.0.0.1:{port}"
        worker_process = popen_launch_server(
            self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
        )
        self.other_process.append(worker_process)

        # 2. use /add_worker api to add it to the router. It will be used by the router after it is healthy
        with requests.Session() as session:
            response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
            print(f"status code: {response.status_code}, response: {response.text}")
            self.assertEqual(response.status_code, 200)

        # 3. run mmlu
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
            temperature=0.1,
        )
        metrics = run_eval(args)
        score = metrics["score"]
        THRESHOLD = 0.635
        passed = score >= THRESHOLD
        msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
        self.assertGreaterEqual(score, THRESHOLD, msg)

        # 4. use /remove_worker api to remove it from the router
        with requests.Session() as session:
            response = session.post(f"{self.base_url}/remove_worker?url={worker_url}")
            print(f"status code: {response.status_code}, response: {response.text}")
            self.assertEqual(response.status_code, 200)

        # 5. run mmlu again
        metrics = run_eval(args)
        score = metrics["score"]
        THRESHOLD = 0.635
        passed = score >= THRESHOLD
        msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
        self.assertGreaterEqual(score, THRESHOLD, msg)

    def test_3_lazy_fault_tolerance(self):
        print("Running test_3_lazy_fault_tolerance...")
        # DP size = 1
        self.process = popen_launch_router(
            self.model,
            self.base_url,
            dp_size=1,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            policy="round_robin",
        )

        # 1. start a worker
        port = find_available_port()
        worker_url = f"http://127.0.0.1:{port}"
        worker_process = popen_launch_server(
            self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
        )
        self.other_process.append(worker_process)

        # 2. use /add_worker api to add it to the router. It will be used by the router after it is healthy
        with requests.Session() as session:
            response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
            print(f"status code: {response.status_code}, response: {response.text}")
            self.assertEqual(response.status_code, 200)

        # Start a thread to kill the worker after 10 seconds to mimic abrupt worker failure
        def kill_worker():
            time.sleep(10)
            kill_process_tree(worker_process.pid)
            print("Worker process killed")

        import threading

        kill_thread = threading.Thread(target=kill_worker)
        kill_thread.daemon = True
        kill_thread.start()

        # 3. run mmlu
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=256,
            num_threads=32,
            temperature=0.1,
        )
        metrics = run_eval(args)
        score = metrics["score"]
        THRESHOLD = 0.635
        passed = score >= THRESHOLD
        msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
        self.assertGreaterEqual(score, THRESHOLD, msg)

    def test_4_payload_size(self):
        print("Running test_4_payload_size...")
        # Start router with 1MB limit
        self.process = popen_launch_router(
            self.model,
            self.base_url,
            dp_size=1,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            policy="round_robin",
            max_payload_size=1 * 1024 * 1024,  # 1MB limit
        )

        # Test case 1: Payload just under 1MB should succeed
        payload_0_5_mb = {
            "text": "x" * int(0.5 * 1024 * 1024),  # 0.5MB of text
            "temperature": 0.0,
        }

        with requests.Session() as session:
            response = session.post(
                f"{self.base_url}/generate",
                json=payload_0_5_mb,
                headers={"Content-Type": "application/json"},
            )
            self.assertEqual(
                response.status_code,
                200,
                f"0.5MB payload should succeed but got status {response.status_code}",
            )

        # Test case 2: Payload over 1MB should fail
        payload_1_plus_mb = {
            "text": "x" * int((1.2 * 1024 * 1024)),  # 1.2MB of text
            "temperature": 0.0,
        }

        with requests.Session() as session:
            response = session.post(
                f"{self.base_url}/generate",
                json=payload_1_plus_mb,
                headers={"Content-Type": "application/json"},
            )
            self.assertEqual(
                response.status_code,
                413,  # Payload Too Large
                f"1.2MB payload should fail with 413 but got status {response.status_code}",
            )

    def test_5_api_key(self):
        print("Running test_5_api_key...")
        self.process = popen_launch_router(
            self.model,
            self.base_url,
            dp_size=1,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            policy="round_robin",
            api_key="correct_api_key",
        )

        # Test case 1: request without api key should fail
        with requests.Session() as session:
            response = session.post(
                f"{self.base_url}/generate",
                json={"text": "Kanye west is, ", "temperature": 0},
            )
            print(f"status code: {response.status_code}, response: {response.text}")
            self.assertEqual(
                response.status_code,
                401,
                "Request without api key should fail with 401",
            )

        # Test case 2: request with invalid api key should fail
        with requests.Session() as session:
            response = requests.post(
                f"{self.base_url}/generate",
                json={"text": "Kanye west is, ", "temperature": 0},
                headers={"Authorization": "Bearer 123"},
            )
            print(f"status code: {response.status_code}, response: {response.text}")
            self.assertEqual(
                response.status_code,
                401,
                "Request with invalid api key should fail with 401",
            )

        # Test case 3: request with correct api key should succeed
        with requests.Session() as session:
            response = session.post(
                f"{self.base_url}/generate",
                json={"text": "Kanye west is ", "temperature": 0},
                headers={"Authorization": "Bearer correct_api_key"},
            )
            print(f"status code: {response.status_code}, response: {response.text}")
            self.assertEqual(
                response.status_code, 200, "Request with correct api key should succeed"
            )

    def test_6_mmlu_with_dp_aware(self):
        print("Running test_6_mmlu_with_dp_aware...")
        # DP size = 2
        self.process = popen_launch_router(
            self.model,
            self.base_url,
            dp_size=2,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            policy="cache_aware",
            dp_aware=True,
        )

        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
            temperature=0.1,
        )

        metrics = run_eval(args)
        score = metrics["score"]
        THRESHOLD = 0.635
        passed = score >= THRESHOLD
        msg = f"dp aware MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
        self.assertGreaterEqual(score, THRESHOLD, msg)

    def test_7_add_and_remove_worker_with_dp_aware(self):
        print("Running test_7_add_and_remove_worker_with_dp_aware...")
        # Set dp_size = 1
        self.process = popen_launch_router(
            self.model,
            self.base_url,
            dp_size=1,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            policy="round_robin",  # make sure every worker processes requests
            dp_aware=True,  # dp aware strategy should work well with RR
        )

        # 1. Start a worker
        port = find_available_port()
        worker_url = f"http://127.0.0.1:{port}"
        worker_process = popen_launch_server(
            self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
        )
        self.other_process.append(worker_process)

        # 2. Use the /add_worker API to add it to the router
        #    It will be used by router after it is healthy
        with requests.Session() as session:
            response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
            print(f"status code: {response.status_code}, response: {response.text}")
            self.assertEqual(response.status_code, 200)

        # 3. Run mmlu
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
            temperature=0.1,
        )
        metrics = run_eval(args)
        score = metrics["score"]
        THRESHOLD = 0.635
        passed = score >= THRESHOLD
        msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
        self.assertGreaterEqual(score, THRESHOLD, msg)

        # 4. Use the /remove_worker API to remove it from the router
        with requests.Session() as session:
            response = session.post(f"{self.base_url}/remove_worker?url={worker_url}")
            print(f"status code: {response.status_code}, response: {response.text}")
            self.assertEqual(response.status_code, 200)

        # 5. Run mmlu again
        metrics = run_eval(args)
        score = metrics["score"]
        THRESHOLD = 0.635
        passed = score >= THRESHOLD
        msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
        self.assertGreaterEqual(score, THRESHOLD, msg)

        # 6. Start another worker with api_key set
        terminate_and_wait(worker_process)  # terminate the old worker process
        port = find_available_port()
        worker_url = f"http://127.0.0.1:{port}"
        worker_process = popen_launch_server(
            self.model,
            worker_url,
            DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key="correct_api_key",
        )
        self.other_process.append(worker_process)

        # 7. Use the /add_worker API to add it to the router
        #    Should fail since the router would contact the worker's
        #    /get_server_info endpoint for the dp_size info, but it
        #    has no knowledge of the api key
        with requests.Session() as session:
            response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
            print(f"status code: {response.status_code}, response: {response.text}")
            self.assertNotEqual(response.status_code, 200)

    def test_8_lazy_fault_tolerance_with_dp_aware(self):
        print("Running test_8_lazy_fault_tolerance_with_dp_aware...")
        # Set dp_size = 1
        self.process = popen_launch_router(
            self.model,
            self.base_url,
            dp_size=1,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            policy="round_robin",
            dp_aware=True,
        )

        # 1. Start a worker
        port = find_available_port()
        worker_url = f"http://127.0.0.1:{port}"
        worker_process = popen_launch_server(
            self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
        )
        self.other_process.append(worker_process)

        # 2. Use the /add_worker API to add it to the router
        #    It will be used by router after it is healthy
        with requests.Session() as session:
            response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
            print(f"status code: {response.status_code}, response: {response.text}")
            self.assertEqual(response.status_code, 200)

        # Start a thread to kill the worker after 10 seconds to mimic
        # abrupt worker failure
        def kill_worker():
            time.sleep(10)
            kill_process_tree(worker_process.pid)
            print("Worker process killed")

        import threading

        kill_thread = threading.Thread(target=kill_worker)
        kill_thread.daemon = True
        kill_thread.start()

        # 3. Run mmlu
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=256,
            num_threads=32,
            temperature=0.1,
        )
        metrics = run_eval(args)
        score = metrics["score"]
        THRESHOLD = 0.635
        passed = score >= THRESHOLD
        msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
        self.assertGreaterEqual(score, THRESHOLD, msg)

    def test_9_payload_size_with_dp_aware(self):
        print("Running test_9_payload_size_with_dp_aware...")
        # Start the router with 1MB limit
        self.process = popen_launch_router(
            self.model,
            self.base_url,
            dp_size=1,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            policy="round_robin",
            max_payload_size=1 * 1024 * 1024,  # 1MB limit
            dp_aware=True,
        )

        # Test case 1: Payload just under 1MB should succeed
        payload_0_5_mb = {
            "text": "x" * int(0.5 * 1024 * 1024),  # 0.5MB of text
            "temperature": 0.0,
        }

        with requests.Session() as session:
            response = session.post(
                f"{self.base_url}/generate",
                json=payload_0_5_mb,
                headers={"Content-Type": "application/json"},
            )
            self.assertEqual(
                response.status_code,
                200,
                f"0.5MB payload should succeed but got status {response.status_code}",
            )

        # Test case 2: Payload over 1MB should fail
        payload_1_plus_mb = {
            "text": "x" * int((1.2 * 1024 * 1024)),  # 1.2MB of text
            "temperature": 0.0,
        }

        with requests.Session() as session:
            response = session.post(
                f"{self.base_url}/generate",
                json=payload_1_plus_mb,
                headers={"Content-Type": "application/json"},
            )
            self.assertEqual(
                response.status_code,
                413,  # Payload Too Large
                f"1.2MB payload should fail with 413 but got status {response.status_code}",
            )

    def test_10_api_key_with_dp_aware(self):
        print("Running test_10_api_key_with_dp_aware...")
        self.process = popen_launch_router(
            self.model,
            self.base_url,
            dp_size=1,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            policy="round_robin",
            api_key="correct_api_key",
            dp_aware=True,
        )

        # Test case 1: request without api key should fail
        with requests.Session() as session:
            response = session.post(
                f"{self.base_url}/generate",
                json={"text": "Kanye west is, ", "temperature": 0},
            )
            print(f"status code: {response.status_code}, response: {response.text}")
            self.assertEqual(
                response.status_code,
                401,
                f"Request without api key should fail with 401 but got status {response.status_code}",
            )

        # Test case 2: request with invalid api key should fail
        with requests.Session() as session:
            response = requests.post(
                f"{self.base_url}/generate",
                json={"text": "Kanye west is, ", "temperature": 0},
                headers={"Authorization": "Bearer 123"},
            )
            print(f"status code: {response.status_code}, response: {response.text}")
            self.assertEqual(
                response.status_code,
                401,
                f"Request without api key should fail with 401 but got status {response.status_code}",
            )

        # Test case 3: request with correct api key should succeed
        with requests.Session() as session:
            response = session.post(
                f"{self.base_url}/generate",
                json={"text": "Kanye west is ", "temperature": 0},
                headers={"Authorization": "Bearer correct_api_key"},
            )
            print(f"status code: {response.status_code}, response: {response.text}")
            self.assertEqual(
                response.status_code,
                200,
                f"Request with correct api key should succeed but got status {response.status_code}",
            )


if __name__ == "__main__":
    unittest.main()