Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
5dccf697
Unverified
Commit
5dccf697
authored
Oct 22, 2025
by
Simo Lin
Committed by
GitHub
Oct 22, 2025
Browse files
[router] create worker removal step and clean up worker manager (#11921)
parent
eec9e471
Changes
23
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
731 additions
and
1887 deletions
+731
-1887
sgl-router/py_test/e2e/conftest.py
sgl-router/py_test/e2e/conftest.py
+2
-0
sgl-router/py_test/e2e/test_e2e_embeddings.py
sgl-router/py_test/e2e/test_e2e_embeddings.py
+27
-2
sgl-router/py_test/e2e/test_pd_router.py
sgl-router/py_test/e2e/test_pd_router.py
+2
-0
sgl-router/py_test/e2e/test_regular_router.py
sgl-router/py_test/e2e/test_regular_router.py
+65
-15
sgl-router/py_test/e2e_grpc/fixtures.py
sgl-router/py_test/e2e_grpc/fixtures.py
+2
-0
sgl-router/py_test/fixtures/router_manager.py
sgl-router/py_test/fixtures/router_manager.py
+79
-8
sgl-router/py_test/integration/conftest.py
sgl-router/py_test/integration/conftest.py
+2
-2
sgl-router/src/core/job_queue.rs
sgl-router/src/core/job_queue.rs
+43
-6
sgl-router/src/core/mod.rs
sgl-router/src/core/mod.rs
+1
-1
sgl-router/src/core/worker_manager.rs
sgl-router/src/core/worker_manager.rs
+2
-1504
sgl-router/src/core/workflow/mod.rs
sgl-router/src/core/workflow/mod.rs
+1
-1
sgl-router/src/core/workflow/steps/mod.rs
sgl-router/src/core/workflow/steps/mod.rs
+6
-0
sgl-router/src/core/workflow/steps/worker_registration.rs
sgl-router/src/core/workflow/steps/worker_registration.rs
+94
-49
sgl-router/src/core/workflow/steps/worker_removal.rs
sgl-router/src/core/workflow/steps/worker_removal.rs
+310
-0
sgl-router/src/routers/grpc/responses/conversions.rs
sgl-router/src/routers/grpc/responses/conversions.rs
+6
-2
sgl-router/src/routers/grpc/responses/handlers.rs
sgl-router/src/routers/grpc/responses/handlers.rs
+10
-8
sgl-router/src/routers/grpc/responses/tool_loop.rs
sgl-router/src/routers/grpc/responses/tool_loop.rs
+5
-4
sgl-router/src/server.rs
sgl-router/src/server.rs
+16
-53
sgl-router/src/service_discovery.rs
sgl-router/src/service_discovery.rs
+20
-8
sgl-router/tests/api_endpoints_test.rs
sgl-router/tests/api_endpoints_test.rs
+38
-224
No files found.
sgl-router/py_test/e2e/conftest.py
View file @
5dccf697
...
@@ -85,6 +85,8 @@ def _popen_launch_router(
...
@@ -85,6 +85,8 @@ def _popen_launch_router(
str
(
prom_port
),
str
(
prom_port
),
"--router-prometheus-host"
,
"--router-prometheus-host"
,
"127.0.0.1"
,
"127.0.0.1"
,
"--router-log-level"
,
"warn"
,
]
]
proc
=
subprocess
.
Popen
(
cmd
)
proc
=
subprocess
.
Popen
(
cmd
)
...
...
sgl-router/py_test/e2e/test_e2e_embeddings.py
View file @
5dccf697
import
time
from
types
import
SimpleNamespace
from
types
import
SimpleNamespace
import
pytest
import
pytest
import
requests
import
requests
def
_wait_for_workers
(
base_url
:
str
,
expected_count
:
int
,
timeout
:
float
=
60.0
,
headers
:
dict
=
None
)
->
None
:
"""Poll /workers endpoint until expected number of workers are registered."""
start
=
time
.
perf_counter
()
with
requests
.
Session
()
as
session
:
while
time
.
perf_counter
()
-
start
<
timeout
:
try
:
r
=
session
.
get
(
f
"
{
base_url
}
/workers"
,
headers
=
headers
,
timeout
=
5
)
if
r
.
status_code
==
200
:
workers
=
r
.
json
().
get
(
"workers"
,
[])
if
len
(
workers
)
>=
expected_count
:
return
except
requests
.
RequestException
:
pass
time
.
sleep
(
0.5
)
raise
TimeoutError
(
f
"Expected
{
expected_count
}
workers at
{
base_url
}
, timed out after
{
timeout
}
s"
)
@
pytest
.
mark
.
e2e
@
pytest
.
mark
.
e2e
def
test_embeddings_basic
(
def
test_embeddings_basic
(
e2e_router_only_rr
,
e2e_primary_embedding_worker
,
e2e_embedding_model
e2e_router_only_rr
,
e2e_primary_embedding_worker
,
e2e_embedding_model
...
@@ -12,8 +34,11 @@ def test_embeddings_basic(
...
@@ -12,8 +34,11 @@ def test_embeddings_basic(
worker_url
=
e2e_primary_embedding_worker
.
url
worker_url
=
e2e_primary_embedding_worker
.
url
# Attach embedding worker to router-only instance
# Attach embedding worker to router-only instance
r
=
requests
.
post
(
f
"
{
base
}
/add_worker"
,
params
=
{
"url"
:
worker_url
},
timeout
=
180
)
r
=
requests
.
post
(
f
"
{
base
}
/workers"
,
json
=
{
"url"
:
worker_url
},
timeout
=
180
)
r
.
raise_for_status
()
assert
r
.
status_code
==
202
,
f
"Expected 202 ACCEPTED, got
{
r
.
status_code
}
:
{
r
.
text
}
"
# Wait for worker to be registered
_wait_for_workers
(
base
,
expected_count
=
1
,
timeout
=
60.0
)
# Simple embedding request with two inputs
# Simple embedding request with two inputs
payload
=
{
payload
=
{
...
...
sgl-router/py_test/e2e/test_pd_router.py
View file @
5dccf697
...
@@ -198,6 +198,8 @@ def pd_cluster(e2e_model: str):
...
@@ -198,6 +198,8 @@ def pd_cluster(e2e_model: str):
"--policy"
,
"--policy"
,
"round_robin"
,
"round_robin"
,
"--pd-disaggregation"
,
"--pd-disaggregation"
,
"--log-level"
,
"warn"
,
]
]
for
url
,
bport
in
prefill
:
for
url
,
bport
in
prefill
:
cmd
+=
[
"--prefill"
,
url
,
str
(
bport
)]
cmd
+=
[
"--prefill"
,
url
,
str
(
bport
)]
...
...
sgl-router/py_test/e2e/test_regular_router.py
View file @
5dccf697
...
@@ -8,13 +8,39 @@ import requests
...
@@ -8,13 +8,39 @@ import requests
from
sglang.test.run_eval
import
run_eval
from
sglang.test.run_eval
import
run_eval
def
_wait_for_workers
(
base_url
:
str
,
expected_count
:
int
,
timeout
:
float
=
60.0
,
headers
:
dict
=
None
)
->
None
:
"""Poll /workers endpoint until expected number of workers are registered."""
start
=
time
.
perf_counter
()
with
requests
.
Session
()
as
session
:
while
time
.
perf_counter
()
-
start
<
timeout
:
try
:
r
=
session
.
get
(
f
"
{
base_url
}
/workers"
,
headers
=
headers
,
timeout
=
5
)
if
r
.
status_code
==
200
:
workers
=
r
.
json
().
get
(
"workers"
,
[])
if
len
(
workers
)
>=
expected_count
:
return
except
requests
.
RequestException
:
pass
time
.
sleep
(
0.5
)
raise
TimeoutError
(
f
"Expected
{
expected_count
}
workers at
{
base_url
}
, timed out after
{
timeout
}
s"
)
@
pytest
.
mark
.
e2e
@
pytest
.
mark
.
e2e
def
test_mmlu
(
e2e_router_only_rr
,
e2e_two_workers_dp2
,
e2e_model
):
def
test_mmlu
(
e2e_router_only_rr
,
e2e_two_workers_dp2
,
e2e_model
):
# Attach two dp=2 workers (total 4 GPUs) to a fresh router-only instance
# Attach two dp=2 workers (total 4 GPUs) to a fresh router-only instance
base
=
e2e_router_only_rr
.
url
base
=
e2e_router_only_rr
.
url
for
w
in
e2e_two_workers_dp2
:
for
w
in
e2e_two_workers_dp2
:
r
=
requests
.
post
(
f
"
{
base
}
/add_worker"
,
params
=
{
"url"
:
w
.
url
},
timeout
=
180
)
r
=
requests
.
post
(
f
"
{
base
}
/workers"
,
json
=
{
"url"
:
w
.
url
},
timeout
=
180
)
r
.
raise_for_status
()
assert
(
r
.
status_code
==
202
),
f
"Expected 202 ACCEPTED, got
{
r
.
status_code
}
:
{
r
.
text
}
"
# Wait for workers to be registered
_wait_for_workers
(
base
,
expected_count
=
2
,
timeout
=
60.0
)
args
=
SimpleNamespace
(
args
=
SimpleNamespace
(
base_url
=
base
,
base_url
=
base
,
...
@@ -35,8 +61,13 @@ def test_genai_bench(
...
@@ -35,8 +61,13 @@ def test_genai_bench(
"""Attach a worker to the regular router and run a short genai-bench."""
"""Attach a worker to the regular router and run a short genai-bench."""
base
=
e2e_router_only_rr
.
url
base
=
e2e_router_only_rr
.
url
for
w
in
e2e_two_workers_dp2
:
for
w
in
e2e_two_workers_dp2
:
r
=
requests
.
post
(
f
"
{
base
}
/add_worker"
,
params
=
{
"url"
:
w
.
url
},
timeout
=
180
)
r
=
requests
.
post
(
f
"
{
base
}
/workers"
,
json
=
{
"url"
:
w
.
url
},
timeout
=
180
)
r
.
raise_for_status
()
assert
(
r
.
status_code
==
202
),
f
"Expected 202 ACCEPTED, got
{
r
.
status_code
}
:
{
r
.
text
}
"
# Wait for workers to be registered
_wait_for_workers
(
base
,
expected_count
=
2
,
timeout
=
60.0
)
genai_bench_runner
(
genai_bench_runner
(
router_url
=
base
,
router_url
=
base
,
...
@@ -59,8 +90,11 @@ def test_add_and_remove_worker_live(e2e_router_only_rr, e2e_primary_worker, e2e_
...
@@ -59,8 +90,11 @@ def test_add_and_remove_worker_live(e2e_router_only_rr, e2e_primary_worker, e2e_
base
=
e2e_router_only_rr
.
url
base
=
e2e_router_only_rr
.
url
worker_url
=
e2e_primary_worker
.
url
worker_url
=
e2e_primary_worker
.
url
r
=
requests
.
post
(
f
"
{
base
}
/add_worker"
,
params
=
{
"url"
:
worker_url
},
timeout
=
180
)
r
=
requests
.
post
(
f
"
{
base
}
/workers"
,
json
=
{
"url"
:
worker_url
},
timeout
=
180
)
r
.
raise_for_status
()
assert
r
.
status_code
==
202
,
f
"Expected 202 ACCEPTED, got
{
r
.
status_code
}
:
{
r
.
text
}
"
# Wait for worker to be registered
_wait_for_workers
(
base
,
expected_count
=
1
,
timeout
=
60.0
)
with
requests
.
Session
()
as
s
:
with
requests
.
Session
()
as
s
:
for
i
in
range
(
8
):
for
i
in
range
(
8
):
...
@@ -77,8 +111,11 @@ def test_add_and_remove_worker_live(e2e_router_only_rr, e2e_primary_worker, e2e_
...
@@ -77,8 +111,11 @@ def test_add_and_remove_worker_live(e2e_router_only_rr, e2e_primary_worker, e2e_
r
.
raise_for_status
()
r
.
raise_for_status
()
# Remove the worker
# Remove the worker
r
=
requests
.
post
(
f
"
{
base
}
/remove_worker"
,
params
=
{
"url"
:
worker_url
},
timeout
=
60
)
from
urllib.parse
import
quote
r
.
raise_for_status
()
encoded_url
=
quote
(
worker_url
,
safe
=
""
)
r
=
requests
.
delete
(
f
"
{
base
}
/workers/
{
encoded_url
}
"
,
timeout
=
60
)
assert
r
.
status_code
==
202
,
f
"Expected 202 ACCEPTED, got
{
r
.
status_code
}
:
{
r
.
text
}
"
@
pytest
.
mark
.
e2e
@
pytest
.
mark
.
e2e
...
@@ -86,8 +123,11 @@ def test_lazy_fault_tolerance_live(e2e_router_only_rr, e2e_primary_worker, e2e_m
...
@@ -86,8 +123,11 @@ def test_lazy_fault_tolerance_live(e2e_router_only_rr, e2e_primary_worker, e2e_m
base
=
e2e_router_only_rr
.
url
base
=
e2e_router_only_rr
.
url
worker
=
e2e_primary_worker
worker
=
e2e_primary_worker
r
=
requests
.
post
(
f
"
{
base
}
/add_worker"
,
params
=
{
"url"
:
worker
.
url
},
timeout
=
180
)
r
=
requests
.
post
(
f
"
{
base
}
/workers"
,
json
=
{
"url"
:
worker
.
url
},
timeout
=
180
)
r
.
raise_for_status
()
assert
r
.
status_code
==
202
,
f
"Expected 202 ACCEPTED, got
{
r
.
status_code
}
:
{
r
.
text
}
"
# Wait for worker to be registered
_wait_for_workers
(
base
,
expected_count
=
1
,
timeout
=
60.0
)
def
killer
():
def
killer
():
time
.
sleep
(
10
)
time
.
sleep
(
10
)
...
@@ -129,20 +169,30 @@ def test_dp_aware_worker_expansion_and_api_key(
...
@@ -129,20 +169,30 @@ def test_dp_aware_worker_expansion_and_api_key(
# Attach worker; router should expand to dp_size logical workers
# Attach worker; router should expand to dp_size logical workers
r
=
requests
.
post
(
r
=
requests
.
post
(
f
"
{
router_url
}
/
add_
worker"
,
f
"
{
router_url
}
/worker
s
"
,
params
=
{
"url"
:
worker_url
,
"api_key"
:
api_key
},
json
=
{
"url"
:
worker_url
,
"api_key"
:
api_key
},
headers
=
{
"Authorization"
:
f
"Bearer
{
api_key
}
"
},
headers
=
{
"Authorization"
:
f
"Bearer
{
api_key
}
"
},
timeout
=
180
,
timeout
=
180
,
)
)
r
.
raise_for_status
()
assert
r
.
status_code
==
202
,
f
"Expected 202 ACCEPTED, got
{
r
.
status_code
}
:
{
r
.
text
}
"
# Wait for workers to be registered and expanded
_wait_for_workers
(
router_url
,
expected_count
=
2
,
timeout
=
60.0
,
headers
=
{
"Authorization"
:
f
"Bearer
{
api_key
}
"
},
)
# Verify the expanded workers have correct URLs
r
=
requests
.
get
(
r
=
requests
.
get
(
f
"
{
router_url
}
/
list_
workers"
,
f
"
{
router_url
}
/workers"
,
headers
=
{
"Authorization"
:
f
"Bearer
{
api_key
}
"
},
headers
=
{
"Authorization"
:
f
"Bearer
{
api_key
}
"
},
timeout
=
30
,
timeout
=
30
,
)
)
r
.
raise_for_status
()
r
.
raise_for_status
()
urls
=
r
.
json
().
get
(
"urls"
,
[])
workers
=
r
.
json
().
get
(
"workers"
,
[])
urls
=
[
w
[
"url"
]
for
w
in
workers
]
assert
len
(
urls
)
==
2
assert
len
(
urls
)
==
2
assert
set
(
urls
)
==
{
f
"
{
worker_url
}
@0"
,
f
"
{
worker_url
}
@1"
}
assert
set
(
urls
)
==
{
f
"
{
worker_url
}
@0"
,
f
"
{
worker_url
}
@1"
}
...
...
sgl-router/py_test/e2e_grpc/fixtures.py
View file @
5dccf697
...
@@ -267,6 +267,8 @@ def popen_launch_workers_and_router(
...
@@ -267,6 +267,8 @@ def popen_launch_workers_and_router(
policy
,
policy
,
"--model-path"
,
"--model-path"
,
model
,
model
,
"--log-level"
,
"warn"
,
]
]
# Add worker URLs
# Add worker URLs
...
...
sgl-router/py_test/fixtures/router_manager.py
View file @
5dccf697
...
@@ -133,19 +133,90 @@ class RouterManager:
...
@@ -133,19 +133,90 @@ class RouterManager:
time
.
sleep
(
0.2
)
time
.
sleep
(
0.2
)
raise
TimeoutError
(
f
"Router at
{
base_url
}
did not become healthy"
)
raise
TimeoutError
(
f
"Router at
{
base_url
}
did not become healthy"
)
def
add_worker
(
self
,
base_url
:
str
,
worker_url
:
str
)
->
None
:
def
add_worker
(
self
,
base_url
:
str
,
worker_url
:
str
,
timeout
:
float
=
30.0
)
->
None
:
r
=
requests
.
post
(
f
"
{
base_url
}
/add_worker"
,
params
=
{
"url"
:
worker_url
})
r
=
requests
.
post
(
f
"
{
base_url
}
/workers"
,
json
=
{
"url"
:
worker_url
})
assert
r
.
status_code
==
200
,
f
"add_worker failed:
{
r
.
status_code
}
{
r
.
text
}
"
assert
(
r
.
status_code
==
202
),
f
"add_worker failed:
{
r
.
status_code
}
{
r
.
text
}
"
# ACCEPTED status
def
remove_worker
(
self
,
base_url
:
str
,
worker_url
:
str
)
->
None
:
# Poll until worker is actually added and healthy
r
=
requests
.
post
(
f
"
{
base_url
}
/remove_worker"
,
params
=
{
"url"
:
worker_url
})
from
urllib.parse
import
quote
assert
r
.
status_code
==
200
,
f
"remove_worker failed:
{
r
.
status_code
}
{
r
.
text
}
"
encoded_url
=
quote
(
worker_url
,
safe
=
""
)
start
=
time
.
time
()
with
requests
.
Session
()
as
s
:
while
time
.
time
()
-
start
<
timeout
:
try
:
r
=
s
.
get
(
f
"
{
base_url
}
/workers/
{
encoded_url
}
"
,
timeout
=
2
)
if
r
.
status_code
==
200
:
data
=
r
.
json
()
# Check if registration job failed
job_status
=
data
.
get
(
"job_status"
)
if
job_status
and
job_status
.
get
(
"state"
)
==
"failed"
:
raise
RuntimeError
(
f
"Worker registration failed:
{
job_status
.
get
(
'message'
,
'Unknown error'
)
}
"
)
# Check if worker is healthy and registered (not just in job queue)
if
data
.
get
(
"is_healthy"
,
False
):
return
# Worker not ready yet, continue polling
except
requests
.
RequestException
:
pass
time
.
sleep
(
0.1
)
raise
TimeoutError
(
f
"Worker
{
worker_url
}
was not added and healthy after
{
timeout
}
s"
)
def
remove_worker
(
self
,
base_url
:
str
,
worker_url
:
str
,
timeout
:
float
=
30.0
)
->
None
:
# URL encode the worker_url for path parameter
from
urllib.parse
import
quote
encoded_url
=
quote
(
worker_url
,
safe
=
""
)
r
=
requests
.
delete
(
f
"
{
base_url
}
/workers/
{
encoded_url
}
"
)
assert
(
r
.
status_code
==
202
),
f
"remove_worker failed:
{
r
.
status_code
}
{
r
.
text
}
"
# ACCEPTED status
# Poll until worker is actually removed (GET returns 404) or timeout
start
=
time
.
time
()
last_status
=
None
with
requests
.
Session
()
as
s
:
while
time
.
time
()
-
start
<
timeout
:
try
:
r
=
s
.
get
(
f
"
{
base_url
}
/workers/
{
encoded_url
}
"
,
timeout
=
2
)
if
r
.
status_code
==
404
:
# Worker successfully removed
return
elif
r
.
status_code
==
200
:
# Check if removal job failed
data
=
r
.
json
()
job_status
=
data
.
get
(
"job_status"
)
if
job_status
:
last_status
=
job_status
if
job_status
.
get
(
"state"
)
==
"failed"
:
raise
RuntimeError
(
f
"Worker removal failed:
{
job_status
.
get
(
'message'
,
'Unknown error'
)
}
"
)
# Worker still being processed, continue polling
except
requests
.
RequestException
:
pass
time
.
sleep
(
0.1
)
# Provide detailed timeout error with last known status
error_msg
=
f
"Worker
{
worker_url
}
was not removed after
{
timeout
}
s"
if
last_status
:
error_msg
+=
f
". Last job status:
{
last_status
}
"
raise
TimeoutError
(
error_msg
)
def
list_workers
(
self
,
base_url
:
str
)
->
list
[
str
]:
def
list_workers
(
self
,
base_url
:
str
)
->
list
[
str
]:
r
=
requests
.
get
(
f
"
{
base_url
}
/
list_
workers"
)
r
=
requests
.
get
(
f
"
{
base_url
}
/workers"
)
assert
r
.
status_code
==
200
,
f
"list_workers failed:
{
r
.
status_code
}
{
r
.
text
}
"
assert
r
.
status_code
==
200
,
f
"list_workers failed:
{
r
.
status_code
}
{
r
.
text
}
"
data
=
r
.
json
()
data
=
r
.
json
()
return
data
.
get
(
"urls"
,
[])
# Extract URLs from WorkerInfo objects
workers
=
data
.
get
(
"workers"
,
[])
return
[
w
[
"url"
]
for
w
in
workers
]
def
stop_all
(
self
):
def
stop_all
(
self
):
for
p
in
self
.
_children
:
for
p
in
self
.
_children
:
...
...
sgl-router/py_test/integration/conftest.py
View file @
5dccf697
...
@@ -2,7 +2,7 @@ import os
...
@@ -2,7 +2,7 @@ import os
import
subprocess
import
subprocess
import
time
import
time
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Dict
,
Iterable
,
List
,
Tuple
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Tuple
import
pytest
import
pytest
import
requests
import
requests
...
@@ -84,7 +84,7 @@ def mock_workers():
...
@@ -84,7 +84,7 @@ def mock_workers():
procs
:
List
[
subprocess
.
Popen
]
=
[]
procs
:
List
[
subprocess
.
Popen
]
=
[]
def
_start
(
n
:
int
,
args
:
List
[
str
]
|
None
=
None
):
def
_start
(
n
:
int
,
args
:
Optional
[
List
[
str
]
]
=
None
):
args
=
args
or
[]
args
=
args
or
[]
new_procs
:
List
[
subprocess
.
Popen
]
=
[]
new_procs
:
List
[
subprocess
.
Popen
]
=
[]
urls
:
List
[
str
]
=
[]
urls
:
List
[
str
]
=
[]
...
...
sgl-router/src/core/job_queue.rs
View file @
5dccf697
...
@@ -15,11 +15,9 @@ use tracing::{debug, error, info, warn};
...
@@ -15,11 +15,9 @@ use tracing::{debug, error, info, warn};
use
crate
::{
use
crate
::{
config
::{
RouterConfig
,
RoutingMode
},
config
::{
RouterConfig
,
RoutingMode
},
core
::{
core
::
workflow
::{
workflow
::{
steps
::
WorkerRemovalRequest
,
WorkflowContext
,
WorkflowEngine
,
WorkflowId
,
WorkflowContext
,
WorkflowEngine
,
WorkflowId
,
WorkflowInstanceId
,
WorkflowStatus
,
WorkflowInstanceId
,
WorkflowStatus
,
},
WorkerManager
,
},
},
metrics
::
RouterMetrics
,
metrics
::
RouterMetrics
,
protocols
::
worker_spec
::{
JobStatus
,
WorkerConfigRequest
},
protocols
::
worker_spec
::{
JobStatus
,
WorkerConfigRequest
},
...
@@ -320,11 +318,29 @@ impl JobQueue {
...
@@ -320,11 +318,29 @@ impl JobQueue {
.await
.await
}
}
Job
::
RemoveWorker
{
url
}
=>
{
Job
::
RemoveWorker
{
url
}
=>
{
let
result
=
WorkerManager
::
remove_worker
(
url
,
context
);
let
engine
=
context
.workflow_engine
.get
()
.ok_or_else
(||
"Workflow engine not initialized"
.to_string
())
?
;
let
instance_id
=
Self
::
start_worker_removal_workflow
(
engine
,
url
,
context
)
.await
?
;
debug!
(
"Started worker removal workflow for {} (instance: {})"
,
url
,
instance_id
);
let
timeout_duration
=
Duration
::
from_secs
(
30
);
let
result
=
Self
::
wait_for_workflow_completion
(
engine
,
instance_id
,
url
,
timeout_duration
)
.await
;
// Clean up job status when removing worker
// Clean up job status when removing worker
if
let
Some
(
queue
)
=
context
.worker_job_queue
.get
()
{
if
let
Some
(
queue
)
=
context
.worker_job_queue
.get
()
{
queue
.remove_status
(
url
);
queue
.remove_status
(
url
);
}
}
result
result
}
}
Job
::
InitializeWorkersFromConfig
{
router_config
}
=>
{
Job
::
InitializeWorkersFromConfig
{
router_config
}
=>
{
...
@@ -424,6 +440,27 @@ impl JobQueue {
...
@@ -424,6 +440,27 @@ impl JobQueue {
.map_err
(|
e
|
format!
(
"Failed to start worker registration workflow: {:?}"
,
e
))
.map_err
(|
e
|
format!
(
"Failed to start worker registration workflow: {:?}"
,
e
))
}
}
/// Start worker removal workflow
async
fn
start_worker_removal_workflow
(
engine
:
&
Arc
<
WorkflowEngine
>
,
url
:
&
str
,
context
:
&
Arc
<
AppContext
>
,
)
->
Result
<
WorkflowInstanceId
,
String
>
{
let
removal_request
=
WorkerRemovalRequest
{
url
:
url
.to_string
(),
dp_aware
:
context
.router_config.dp_aware
,
};
let
mut
workflow_context
=
WorkflowContext
::
new
(
WorkflowInstanceId
::
new
());
workflow_context
.set
(
"removal_request"
,
removal_request
);
workflow_context
.set_arc
(
"app_context"
,
Arc
::
clone
(
context
));
engine
.start_workflow
(
WorkflowId
::
new
(
"worker_removal"
),
workflow_context
)
.await
.map_err
(|
e
|
format!
(
"Failed to start worker removal workflow: {:?}"
,
e
))
}
/// Wait for workflow completion with adaptive polling
/// Wait for workflow completion with adaptive polling
async
fn
wait_for_workflow_completion
(
async
fn
wait_for_workflow_completion
(
engine
:
&
Arc
<
WorkflowEngine
>
,
engine
:
&
Arc
<
WorkflowEngine
>
,
...
...
sgl-router/src/core/mod.rs
View file @
5dccf697
...
@@ -29,5 +29,5 @@ pub use worker::{
...
@@ -29,5 +29,5 @@ pub use worker::{
Worker
,
WorkerFactory
,
WorkerLoadGuard
,
WorkerType
,
Worker
,
WorkerFactory
,
WorkerLoadGuard
,
WorkerType
,
};
};
pub
use
worker_builder
::{
BasicWorkerBuilder
,
DPAwareWorkerBuilder
};
pub
use
worker_builder
::{
BasicWorkerBuilder
,
DPAwareWorkerBuilder
};
pub
use
worker_manager
::{
DpInfo
,
LoadMonitor
,
ServerInfo
,
WorkerManager
};
pub
use
worker_manager
::{
LoadMonitor
,
WorkerManager
};
pub
use
worker_registry
::{
WorkerId
,
WorkerRegistry
,
WorkerRegistryStats
};
pub
use
worker_registry
::{
WorkerId
,
WorkerRegistry
,
WorkerRegistryStats
};
sgl-router/src/core/worker_manager.rs
View file @
5dccf697
This diff is collapsed.
Click to expand it.
sgl-router/src/core/workflow/mod.rs
View file @
5dccf697
...
@@ -14,5 +14,5 @@ pub use engine::WorkflowEngine;
...
@@ -14,5 +14,5 @@ pub use engine::WorkflowEngine;
pub
use
event
::{
EventBus
,
EventSubscriber
,
LoggingSubscriber
,
WorkflowEvent
};
pub
use
event
::{
EventBus
,
EventSubscriber
,
LoggingSubscriber
,
WorkflowEvent
};
pub
use
executor
::{
FunctionStep
,
StepExecutor
};
pub
use
executor
::{
FunctionStep
,
StepExecutor
};
pub
use
state
::
WorkflowStateStore
;
pub
use
state
::
WorkflowStateStore
;
pub
use
steps
::
create_worker_registration_workflow
;
pub
use
steps
::
{
create_worker_registration_workflow
,
create_worker_removal_workflow
}
;
pub
use
types
::
*
;
pub
use
types
::
*
;
sgl-router/src/core/workflow/steps/mod.rs
View file @
5dccf697
...
@@ -2,11 +2,17 @@
...
@@ -2,11 +2,17 @@
//!
//!
//! This module contains concrete step implementations for various workflows:
//! This module contains concrete step implementations for various workflows:
//! - Worker registration and activation
//! - Worker registration and activation
//! - Worker removal
//! - Future: Tokenizer fetching, LoRA updates, etc.
//! - Future: Tokenizer fetching, LoRA updates, etc.
pub
mod
worker_registration
;
pub
mod
worker_registration
;
pub
mod
worker_removal
;
pub
use
worker_registration
::{
pub
use
worker_registration
::{
create_worker_registration_workflow
,
ActivateWorkerStep
,
CreateWorkerStep
,
create_worker_registration_workflow
,
ActivateWorkerStep
,
CreateWorkerStep
,
DetectConnectionModeStep
,
DiscoverMetadataStep
,
RegisterWorkerStep
,
UpdatePoliciesStep
,
DetectConnectionModeStep
,
DiscoverMetadataStep
,
RegisterWorkerStep
,
UpdatePoliciesStep
,
};
};
pub
use
worker_removal
::{
create_worker_removal_workflow
,
FindWorkersToRemoveStep
,
RemoveFromPolicyRegistryStep
,
RemoveFromWorkerRegistryStep
,
UpdateRemainingPoliciesStep
,
WorkerRemovalRequest
,
};
sgl-router/src/core/workflow/steps/worker_registration.rs
View file @
5dccf697
...
@@ -16,13 +16,14 @@ use std::{collections::HashMap, sync::Arc, time::Duration};
...
@@ -16,13 +16,14 @@ use std::{collections::HashMap, sync::Arc, time::Duration};
use
async_trait
::
async_trait
;
use
async_trait
::
async_trait
;
use
once_cell
::
sync
::
Lazy
;
use
once_cell
::
sync
::
Lazy
;
use
reqwest
::
Client
;
use
reqwest
::
Client
;
use
serde
::{
Deserialize
,
Serialize
};
use
serde_json
::
Value
;
use
serde_json
::
Value
;
use
tracing
::{
debug
,
info
,
warn
};
use
tracing
::{
debug
,
info
,
warn
};
use
crate
::{
use
crate
::{
core
::{
core
::{
workflow
::
*
,
BasicWorkerBuilder
,
CircuitBreakerConfig
,
ConnectionMode
,
workflow
::
*
,
BasicWorkerBuilder
,
CircuitBreakerConfig
,
ConnectionMode
,
DPAwareWorkerBuilder
,
DpInfo
,
HealthConfig
,
Worker
,
WorkerManager
,
WorkerType
,
DPAwareWorkerBuilder
,
HealthConfig
,
Worker
,
WorkerType
,
},
},
grpc_client
::
SglangSchedulerClient
,
grpc_client
::
SglangSchedulerClient
,
protocols
::
worker_spec
::
WorkerConfigRequest
,
protocols
::
worker_spec
::
WorkerConfigRequest
,
...
@@ -37,6 +38,82 @@ static HTTP_CLIENT: Lazy<Client> = Lazy::new(|| {
...
@@ -37,6 +38,82 @@ static HTTP_CLIENT: Lazy<Client> = Lazy::new(|| {
.expect
(
"Failed to create HTTP client"
)
.expect
(
"Failed to create HTTP client"
)
});
});
/// Server information returned from worker endpoints
#[derive(Debug,
Clone,
Deserialize,
Serialize)]
struct
ServerInfo
{
#[serde(alias
=
"model"
)]
model_id
:
Option
<
String
>
,
model_path
:
Option
<
String
>
,
dp_size
:
Option
<
usize
>
,
version
:
Option
<
String
>
,
max_batch_size
:
Option
<
usize
>
,
max_total_tokens
:
Option
<
usize
>
,
max_prefill_tokens
:
Option
<
usize
>
,
max_running_requests
:
Option
<
usize
>
,
max_num_reqs
:
Option
<
usize
>
,
}
#[derive(Debug,
Clone)]
pub
struct
DpInfo
{
pub
dp_size
:
usize
,
pub
model_id
:
String
,
}
/// Parse server info from JSON response using serde
fn
parse_server_info
(
json
:
Value
)
->
Result
<
ServerInfo
,
String
>
{
serde_json
::
from_value
(
json
)
.map_err
(|
e
|
format!
(
"Failed to parse server info: {}"
,
e
))
}
/// Get server info from /get_server_info endpoint
async
fn
get_server_info
(
url
:
&
str
,
api_key
:
Option
<&
str
>
)
->
Result
<
ServerInfo
,
String
>
{
let
base_url
=
url
.trim_end_matches
(
'/'
);
let
server_info_url
=
format!
(
"{}/get_server_info"
,
base_url
);
let
mut
req
=
HTTP_CLIENT
.get
(
&
server_info_url
);
if
let
Some
(
key
)
=
api_key
{
req
=
req
.bearer_auth
(
key
);
}
let
response
=
req
.send
()
.await
.map_err
(|
e
|
format!
(
"Failed to connect to {}: {}"
,
server_info_url
,
e
))
?
;
if
!
response
.status
()
.is_success
()
{
return
Err
(
format!
(
"Server returned status {} from {}"
,
response
.status
(),
server_info_url
));
}
let
json
=
response
.json
::
<
Value
>
()
.await
.map_err
(|
e
|
format!
(
"Failed to parse response from {}: {}"
,
server_info_url
,
e
))
?
;
parse_server_info
(
json
)
}
/// Get DP info for a worker URL
async
fn
get_dp_info
(
url
:
&
str
,
api_key
:
Option
<&
str
>
)
->
Result
<
DpInfo
,
String
>
{
let
info
=
get_server_info
(
url
,
api_key
)
.await
?
;
let
dp_size
=
info
.dp_size
.ok_or_else
(||
format!
(
"No dp_size in response from {}"
,
url
))
?
;
let
model_id
=
info
.model_id
.or_else
(||
{
info
.model_path
.and_then
(|
path
|
path
.split
(
'/'
)
.next_back
()
.map
(|
s
|
s
.to_string
()))
})
.unwrap_or_else
(||
"unknown"
.to_string
());
Ok
(
DpInfo
{
dp_size
,
model_id
})
}
/// Helper: Strip protocol prefix from URL
/// Helper: Strip protocol prefix from URL
fn
strip_protocol
(
url
:
&
str
)
->
String
{
fn
strip_protocol
(
url
:
&
str
)
->
String
{
url
.trim_start_matches
(
"http://"
)
url
.trim_start_matches
(
"http://"
)
...
@@ -83,49 +160,6 @@ async fn try_grpc_health_check(url: &str, timeout_secs: u64) -> Result<(), Strin
...
@@ -83,49 +160,6 @@ async fn try_grpc_health_check(url: &str, timeout_secs: u64) -> Result<(), Strin
Ok
(())
Ok
(())
}
}
/// Helper: Fetch HTTP metadata
async
fn
fetch_http_metadata
(
url
:
&
str
,
api_key
:
Option
<&
str
>
,
)
->
Result
<
HashMap
<
String
,
String
>
,
String
>
{
let
clean_url
=
strip_protocol
(
url
);
let
info_url
=
if
clean_url
.starts_with
(
"http://"
)
||
clean_url
.starts_with
(
"https://"
)
{
format!
(
"{}/get_server_info"
,
clean_url
)
}
else
{
format!
(
"http://{}/get_server_info"
,
clean_url
)
};
let
mut
request
=
HTTP_CLIENT
.get
(
&
info_url
);
if
let
Some
(
key
)
=
api_key
{
request
=
request
.header
(
"Authorization"
,
format!
(
"Bearer {}"
,
key
));
}
let
response
=
request
.send
()
.await
.map_err
(|
e
|
format!
(
"Failed to fetch HTTP metadata: {}"
,
e
))
?
;
let
server_info
:
Value
=
response
.json
()
.await
.map_err
(|
e
|
format!
(
"Failed to parse HTTP metadata: {}"
,
e
))
?
;
let
mut
labels
=
HashMap
::
new
();
if
let
Some
(
model_path
)
=
server_info
.get
(
"model_path"
)
.and_then
(|
v
|
v
.as_str
())
{
if
!
model_path
.is_empty
()
{
labels
.insert
(
"model_path"
.to_string
(),
model_path
.to_string
());
}
}
if
let
Some
(
tokenizer_path
)
=
server_info
.get
(
"tokenizer_path"
)
.and_then
(|
v
|
v
.as_str
())
{
if
!
tokenizer_path
.is_empty
()
{
labels
.insert
(
"tokenizer_path"
.to_string
(),
tokenizer_path
.to_string
());
}
}
Ok
(
labels
)
}
/// Helper: Fetch gRPC metadata
/// Helper: Fetch gRPC metadata
async
fn
fetch_grpc_metadata
(
url
:
&
str
)
->
Result
<
HashMap
<
String
,
String
>
,
String
>
{
async
fn
fetch_grpc_metadata
(
url
:
&
str
)
->
Result
<
HashMap
<
String
,
String
>
,
String
>
{
let
grpc_url
=
if
url
.starts_with
(
"grpc://"
)
{
let
grpc_url
=
if
url
.starts_with
(
"grpc://"
)
{
...
@@ -266,7 +300,18 @@ impl StepExecutor for DiscoverMetadataStep {
...
@@ -266,7 +300,18 @@ impl StepExecutor for DiscoverMetadataStep {
let
discovered_labels
=
match
connection_mode
.as_ref
()
{
let
discovered_labels
=
match
connection_mode
.as_ref
()
{
ConnectionMode
::
Http
=>
{
ConnectionMode
::
Http
=>
{
fetch_http_metadata
(
&
config
.url
,
config
.api_key
.as_deref
())
.await
match
get_server_info
(
&
config
.url
,
config
.api_key
.as_deref
())
.await
{
Ok
(
server_info
)
=>
{
let
mut
labels
=
HashMap
::
new
();
if
let
Some
(
model_path
)
=
server_info
.model_path
{
if
!
model_path
.is_empty
()
{
labels
.insert
(
"model_path"
.to_string
(),
model_path
);
}
}
Ok
(
labels
)
}
Err
(
e
)
=>
Err
(
e
),
}
}
}
ConnectionMode
::
Grpc
{
..
}
=>
fetch_grpc_metadata
(
&
config
.url
)
.await
,
ConnectionMode
::
Grpc
{
..
}
=>
fetch_grpc_metadata
(
&
config
.url
)
.await
,
}
}
...
@@ -314,7 +359,7 @@ impl StepExecutor for DiscoverDPInfoStep {
...
@@ -314,7 +359,7 @@ impl StepExecutor for DiscoverDPInfoStep {
debug!
(
"Discovering DP info for {} (DP-aware)"
,
config
.url
);
debug!
(
"Discovering DP info for {} (DP-aware)"
,
config
.url
);
// Get DP info from worker
// Get DP info from worker
let
dp_info
=
WorkerManager
::
get_dp_info
(
&
config
.url
,
config
.api_key
.as_deref
())
let
dp_info
=
get_dp_info
(
&
config
.url
,
config
.api_key
.as_deref
())
.await
.await
.map_err
(|
e
|
WorkflowError
::
StepFailed
{
.map_err
(|
e
|
WorkflowError
::
StepFailed
{
step_id
:
StepId
::
new
(
"discover_dp_info"
),
step_id
:
StepId
::
new
(
"discover_dp_info"
),
...
@@ -327,7 +372,7 @@ impl StepExecutor for DiscoverDPInfoStep {
...
@@ -327,7 +372,7 @@ impl StepExecutor for DiscoverDPInfoStep {
);
);
// Store DP info in context
// Store DP info in context
context
.set
(
"dp_info"
,
Arc
::
new
(
dp_info
)
)
;
context
.set
(
"dp_info"
,
dp_info
);
Ok
(
StepResult
::
Success
)
Ok
(
StepResult
::
Success
)
}
}
...
@@ -522,7 +567,7 @@ impl StepExecutor for CreateWorkerStep {
...
@@ -522,7 +567,7 @@ impl StepExecutor for CreateWorkerStep {
}
}
// Store workers (plural) and labels in context
// Store workers (plural) and labels in context
context
.set
(
"workers"
,
Arc
::
new
(
workers
)
)
;
context
.set
(
"workers"
,
workers
);
context
.set
(
"labels"
,
final_labels
);
context
.set
(
"labels"
,
final_labels
);
Ok
(
StepResult
::
Success
)
Ok
(
StepResult
::
Success
)
...
@@ -595,7 +640,7 @@ impl StepExecutor for RegisterWorkerStep {
...
@@ -595,7 +640,7 @@ impl StepExecutor for RegisterWorkerStep {
);
);
}
}
context
.set
(
"worker_ids"
,
Arc
::
new
(
worker_ids
)
)
;
context
.set
(
"worker_ids"
,
worker_ids
);
Ok
(
StepResult
::
Success
)
Ok
(
StepResult
::
Success
)
}
else
{
}
else
{
// Non-DP-aware path: Register single worker
// Non-DP-aware path: Register single worker
...
...
sgl-router/src/core/workflow/steps/worker_removal.rs
0 → 100644
View file @
5dccf697
//! Worker Removal Workflow Steps
//!
//! This module implements the workflow steps for removing workers from the router.
//! Handles both single worker removal and DP-aware worker removal with prefix matching.
//!
//! Steps:
//! 1. FindWorkersToRemove - Identify workers to remove based on URL (handles DP-aware prefix matching)
//! 2. RemoveFromPolicyRegistry - Remove workers from policy registry and cache-aware policies
//! 3. RemoveFromWorkerRegistry - Remove workers from worker registry
//! 4. UpdateRemainingPolicies - Update cache-aware policies for remaining workers
use
std
::{
collections
::
HashSet
,
sync
::
Arc
};
use
async_trait
::
async_trait
;
use
tracing
::{
debug
,
info
};
use
crate
::{
core
::{
workflow
::
*
,
Worker
},
server
::
AppContext
,
};
/// Request structure for worker removal
#[derive(Debug,
Clone)]
pub
struct
WorkerRemovalRequest
{
pub
url
:
String
,
pub
dp_aware
:
bool
,
}
/// Step 1: Find workers to remove based on URL
pub
struct
FindWorkersToRemoveStep
;
#[async_trait]
impl
StepExecutor
for
FindWorkersToRemoveStep
{
async
fn
execute
(
&
self
,
context
:
&
mut
WorkflowContext
)
->
WorkflowResult
<
StepResult
>
{
let
request
:
Arc
<
WorkerRemovalRequest
>
=
context
.get
(
"removal_request"
)
.ok_or_else
(||
WorkflowError
::
ContextValueNotFound
(
"removal_request"
.to_string
()))
?
;
let
app_context
:
Arc
<
AppContext
>
=
context
.get
(
"app_context"
)
.ok_or_else
(||
WorkflowError
::
ContextValueNotFound
(
"app_context"
.to_string
()))
?
;
debug!
(
"Finding workers to remove for {} (dp_aware: {})"
,
request
.url
,
request
.dp_aware
);
let
workers_to_remove
:
Vec
<
Arc
<
dyn
Worker
>>
=
if
request
.dp_aware
{
// DP-aware: Find all workers with matching prefix
let
worker_url_prefix
=
format!
(
"{}@"
,
request
.url
);
let
all_workers
=
app_context
.worker_registry
.get_all
();
all_workers
.iter
()
.filter
(|
worker
|
worker
.url
()
.starts_with
(
&
worker_url_prefix
))
.cloned
()
.collect
()
}
else
{
// Non-DP-aware: Find single worker by exact URL
match
app_context
.worker_registry
.get_by_url
(
&
request
.url
)
{
Some
(
worker
)
=>
vec!
[
worker
],
None
=>
Vec
::
new
(),
}
};
if
workers_to_remove
.is_empty
()
{
let
error_msg
=
if
request
.dp_aware
{
format!
(
"No workers found with prefix {}@"
,
request
.url
)
}
else
{
format!
(
"Worker {} not found"
,
request
.url
)
};
return
Err
(
WorkflowError
::
StepFailed
{
step_id
:
StepId
::
new
(
"find_workers_to_remove"
),
message
:
error_msg
,
});
}
debug!
(
"Found {} worker(s) to remove for {}"
,
workers_to_remove
.len
(),
request
.url
);
// Store workers and their model IDs for subsequent steps
let
worker_urls
:
Vec
<
String
>
=
workers_to_remove
.iter
()
.map
(|
w
|
w
.url
()
.to_string
())
.collect
();
let
affected_models
:
HashSet
<
String
>
=
workers_to_remove
.iter
()
.map
(|
w
|
w
.model_id
()
.to_string
())
.collect
();
context
.set
(
"workers_to_remove"
,
workers_to_remove
);
context
.set
(
"worker_urls"
,
worker_urls
);
context
.set
(
"affected_models"
,
affected_models
);
Ok
(
StepResult
::
Success
)
}
fn
is_retryable
(
&
self
,
_
error
:
&
WorkflowError
)
->
bool
{
false
// Worker not found is not retryable
}
}
/// Step 2: Remove workers from policy registry
pub
struct
RemoveFromPolicyRegistryStep
;
#[async_trait]
impl
StepExecutor
for
RemoveFromPolicyRegistryStep
{
async
fn
execute
(
&
self
,
context
:
&
mut
WorkflowContext
)
->
WorkflowResult
<
StepResult
>
{
let
app_context
:
Arc
<
AppContext
>
=
context
.get
(
"app_context"
)
.ok_or_else
(||
WorkflowError
::
ContextValueNotFound
(
"app_context"
.to_string
()))
?
;
let
workers_to_remove
:
Arc
<
Vec
<
Arc
<
dyn
Worker
>>>
=
context
.get
(
"workers_to_remove"
)
.ok_or_else
(||
WorkflowError
::
ContextValueNotFound
(
"workers_to_remove"
.to_string
()))
?
;
debug!
(
"Removing {} worker(s) from policy registry"
,
workers_to_remove
.len
()
);
for
worker
in
workers_to_remove
.iter
()
{
let
model_id
=
worker
.model_id
()
.to_string
();
let
worker_url
=
worker
.url
();
// Remove from cache-aware policy
app_context
.policy_registry
.remove_worker_from_cache_aware
(
&
model_id
,
worker_url
);
// Notify policy registry
app_context
.policy_registry
.on_worker_removed
(
&
model_id
);
debug!
(
"Removed worker {} from policy registry (model: {})"
,
worker_url
,
model_id
);
}
Ok
(
StepResult
::
Success
)
}
fn
is_retryable
(
&
self
,
_
error
:
&
WorkflowError
)
->
bool
{
false
// Policy removal is not retryable
}
}
/// Step 3: Remove workers from worker registry
pub
struct
RemoveFromWorkerRegistryStep
;
#[async_trait]
impl
StepExecutor
for
RemoveFromWorkerRegistryStep
{
async
fn
execute
(
&
self
,
context
:
&
mut
WorkflowContext
)
->
WorkflowResult
<
StepResult
>
{
let
app_context
:
Arc
<
AppContext
>
=
context
.get
(
"app_context"
)
.ok_or_else
(||
WorkflowError
::
ContextValueNotFound
(
"app_context"
.to_string
()))
?
;
let
worker_urls
:
Arc
<
Vec
<
String
>>
=
context
.get
(
"worker_urls"
)
.ok_or_else
(||
WorkflowError
::
ContextValueNotFound
(
"worker_urls"
.to_string
()))
?
;
debug!
(
"Removing {} worker(s) from worker registry"
,
worker_urls
.len
()
);
let
mut
removed_count
=
0
;
for
worker_url
in
worker_urls
.iter
()
{
if
app_context
.worker_registry
.remove_by_url
(
worker_url
)
.is_some
()
{
removed_count
+=
1
;
debug!
(
"Removed worker {} from registry"
,
worker_url
);
}
}
if
removed_count
!=
worker_urls
.len
()
{
return
Err
(
WorkflowError
::
StepFailed
{
step_id
:
StepId
::
new
(
"remove_from_worker_registry"
),
message
:
format!
(
"Expected to remove {} workers but only removed {}"
,
worker_urls
.len
(),
removed_count
),
});
}
Ok
(
StepResult
::
Success
)
}
fn
is_retryable
(
&
self
,
_
error
:
&
WorkflowError
)
->
bool
{
false
// Worker removal is not retryable
}
}
/// Step 4: Update cache-aware policies for remaining workers
pub
struct
UpdateRemainingPoliciesStep
;
#[async_trait]
impl
StepExecutor
for
UpdateRemainingPoliciesStep
{
async
fn
execute
(
&
self
,
context
:
&
mut
WorkflowContext
)
->
WorkflowResult
<
StepResult
>
{
let
app_context
:
Arc
<
AppContext
>
=
context
.get
(
"app_context"
)
.ok_or_else
(||
WorkflowError
::
ContextValueNotFound
(
"app_context"
.to_string
()))
?
;
let
affected_models
:
Arc
<
HashSet
<
String
>>
=
context
.get
(
"affected_models"
)
.ok_or_else
(||
WorkflowError
::
ContextValueNotFound
(
"affected_models"
.to_string
()))
?
;
let
worker_urls
:
Arc
<
Vec
<
String
>>
=
context
.get
(
"worker_urls"
)
.ok_or_else
(||
WorkflowError
::
ContextValueNotFound
(
"worker_urls"
.to_string
()))
?
;
debug!
(
"Updating cache-aware policies for {} affected model(s)"
,
affected_models
.len
()
);
for
model_id
in
affected_models
.iter
()
{
let
remaining_workers
=
app_context
.worker_registry
.get_by_model_fast
(
model_id
);
if
let
Some
(
policy
)
=
app_context
.policy_registry
.get_policy
(
model_id
)
{
if
policy
.name
()
==
"cache_aware"
&&
!
remaining_workers
.is_empty
()
{
app_context
.policy_registry
.init_cache_aware_policy
(
model_id
,
&
remaining_workers
);
debug!
(
"Updated cache-aware policy for model {} ({} remaining workers)"
,
model_id
,
remaining_workers
.len
()
);
}
}
}
// Log final result at info level
if
worker_urls
.len
()
==
1
{
info!
(
"Removed worker {}"
,
worker_urls
[
0
]);
}
else
{
info!
(
"Removed {} DP-aware workers: {:?}"
,
worker_urls
.len
(),
worker_urls
);
}
Ok
(
StepResult
::
Success
)
}
fn
is_retryable
(
&
self
,
_
error
:
&
WorkflowError
)
->
bool
{
false
// Policy update is not retryable
}
}
/// Create a worker removal workflow definition
pub
fn
create_worker_removal_workflow
()
->
WorkflowDefinition
{
use
std
::
time
::
Duration
;
WorkflowDefinition
::
new
(
"worker_removal"
,
"Remove worker from router"
)
.add_step
(
StepDefinition
::
new
(
"find_workers_to_remove"
,
"Find workers to remove"
,
Arc
::
new
(
FindWorkersToRemoveStep
),
)
.with_timeout
(
Duration
::
from_secs
(
10
))
.with_retry
(
RetryPolicy
{
max_attempts
:
1
,
backoff
:
BackoffStrategy
::
Fixed
(
Duration
::
from_secs
(
0
)),
}),
)
.add_step
(
StepDefinition
::
new
(
"remove_from_policy_registry"
,
"Remove workers from policy registry"
,
Arc
::
new
(
RemoveFromPolicyRegistryStep
),
)
.with_timeout
(
Duration
::
from_secs
(
10
))
.with_retry
(
RetryPolicy
{
max_attempts
:
1
,
backoff
:
BackoffStrategy
::
Fixed
(
Duration
::
from_secs
(
0
)),
}),
)
.add_step
(
StepDefinition
::
new
(
"remove_from_worker_registry"
,
"Remove workers from worker registry"
,
Arc
::
new
(
RemoveFromWorkerRegistryStep
),
)
.with_timeout
(
Duration
::
from_secs
(
10
))
.with_retry
(
RetryPolicy
{
max_attempts
:
1
,
backoff
:
BackoffStrategy
::
Fixed
(
Duration
::
from_secs
(
0
)),
}),
)
.add_step
(
StepDefinition
::
new
(
"update_remaining_policies"
,
"Update cache-aware policies for remaining workers"
,
Arc
::
new
(
UpdateRemainingPoliciesStep
),
)
.with_timeout
(
Duration
::
from_secs
(
10
))
.with_retry
(
RetryPolicy
{
max_attempts
:
1
,
backoff
:
BackoffStrategy
::
Fixed
(
Duration
::
from_secs
(
0
)),
}),
)
}
sgl-router/src/routers/grpc/responses/conversions.rs
View file @
5dccf697
...
@@ -149,7 +149,11 @@ pub fn responses_to_chat(req: &ResponsesRequest) -> Result<ChatCompletionRequest
...
@@ -149,7 +149,11 @@ pub fn responses_to_chat(req: &ResponsesRequest) -> Result<ChatCompletionRequest
Ok
(
ChatCompletionRequest
{
Ok
(
ChatCompletionRequest
{
messages
,
messages
,
model
:
req
.model
.clone
()
.unwrap_or_else
(||
"default"
.to_string
()),
model
:
if
req
.model
.is_empty
()
{
"default"
.to_string
()
}
else
{
req
.model
.clone
()
},
temperature
:
req
.temperature
,
temperature
:
req
.temperature
,
max_completion_tokens
:
req
.max_output_tokens
,
max_completion_tokens
:
req
.max_output_tokens
,
stream
:
is_streaming
,
stream
:
is_streaming
,
...
@@ -311,7 +315,7 @@ mod tests {
...
@@ -311,7 +315,7 @@ mod tests {
let
req
=
ResponsesRequest
{
let
req
=
ResponsesRequest
{
input
:
ResponseInput
::
Text
(
"Hello, world!"
.to_string
()),
input
:
ResponseInput
::
Text
(
"Hello, world!"
.to_string
()),
instructions
:
Some
(
"You are a helpful assistant."
.to_string
()),
instructions
:
Some
(
"You are a helpful assistant."
.to_string
()),
model
:
Some
(
"gpt-4"
.to_string
()
)
,
model
:
"gpt-4"
.to_string
(),
temperature
:
Some
(
0.7
),
temperature
:
Some
(
0.7
),
..
Default
::
default
()
..
Default
::
default
()
};
};
...
...
sgl-router/src/routers/grpc/responses/handlers.rs
View file @
5dccf697
...
@@ -324,10 +324,11 @@ async fn route_responses_background(
...
@@ -324,10 +324,11 @@ async fn route_responses_background(
incomplete_details
:
None
,
incomplete_details
:
None
,
instructions
:
request
.instructions
.clone
(),
instructions
:
request
.instructions
.clone
(),
max_output_tokens
:
request
.max_output_tokens
,
max_output_tokens
:
request
.max_output_tokens
,
model
:
request
model
:
if
request
.model
.is_empty
()
{
.model
"default"
.to_string
()
.clone
()
}
else
{
.unwrap_or_else
(||
"default"
.to_string
()),
request
.model
.clone
()
},
output
:
Vec
::
new
(),
output
:
Vec
::
new
(),
parallel_tool_calls
:
request
.parallel_tool_calls
.unwrap_or
(
true
),
parallel_tool_calls
:
request
.parallel_tool_calls
.unwrap_or
(
true
),
previous_response_id
:
request
.previous_response_id
.clone
(),
previous_response_id
:
request
.previous_response_id
.clone
(),
...
@@ -622,10 +623,11 @@ async fn process_and_transform_sse_stream(
...
@@ -622,10 +623,11 @@ async fn process_and_transform_sse_stream(
// Create event emitter for OpenAI-compatible streaming
// Create event emitter for OpenAI-compatible streaming
let
response_id
=
format!
(
"resp_{}"
,
Uuid
::
new_v4
());
let
response_id
=
format!
(
"resp_{}"
,
Uuid
::
new_v4
());
let
model
=
original_request
let
model
=
if
original_request
.model
.is_empty
()
{
.model
"default"
.to_string
()
.clone
()
}
else
{
.unwrap_or_else
(||
"default"
.to_string
());
original_request
.model
.clone
()
};
let
created_at
=
chrono
::
Utc
::
now
()
.timestamp
()
as
u64
;
let
created_at
=
chrono
::
Utc
::
now
()
.timestamp
()
as
u64
;
let
mut
event_emitter
=
ResponseStreamEventEmitter
::
new
(
response_id
,
model
,
created_at
);
let
mut
event_emitter
=
ResponseStreamEventEmitter
::
new
(
response_id
,
model
,
created_at
);
...
...
sgl-router/src/routers/grpc/responses/tool_loop.rs
View file @
5dccf697
...
@@ -608,10 +608,11 @@ async fn execute_tool_loop_streaming_internal(
...
@@ -608,10 +608,11 @@ async fn execute_tool_loop_streaming_internal(
// Create response event emitter
// Create response event emitter
let
response_id
=
format!
(
"resp_{}"
,
Uuid
::
new_v4
());
let
response_id
=
format!
(
"resp_{}"
,
Uuid
::
new_v4
());
let
model
=
current_request
let
model
=
if
current_request
.model
.is_empty
()
{
.model
"default"
.to_string
()
.clone
()
}
else
{
.unwrap_or_else
(||
"default"
.to_string
());
current_request
.model
.clone
()
};
let
created_at
=
SystemTime
::
now
()
let
created_at
=
SystemTime
::
now
()
.duration_since
(
UNIX_EPOCH
)
.duration_since
(
UNIX_EPOCH
)
.unwrap
()
.unwrap
()
...
...
sgl-router/src/server.rs
View file @
5dccf697
...
@@ -22,8 +22,12 @@ use tracing::{error, info, warn, Level};
...
@@ -22,8 +22,12 @@ use tracing::{error, info, warn, Level};
use
crate
::{
use
crate
::{
config
::{
ConnectionMode
,
HistoryBackend
,
RouterConfig
,
RoutingMode
},
config
::{
ConnectionMode
,
HistoryBackend
,
RouterConfig
,
RoutingMode
},
core
::{
core
::{
worker_to_info
,
workflow
::
WorkflowEngine
,
Job
,
JobQueue
,
JobQueueConfig
,
LoadMonitor
,
worker_to_info
,
WorkerManager
,
WorkerRegistry
,
WorkerType
,
workflow
::{
create_worker_registration_workflow
,
create_worker_removal_workflow
,
LoggingSubscriber
,
WorkflowEngine
,
},
Job
,
JobQueue
,
JobQueueConfig
,
LoadMonitor
,
WorkerManager
,
WorkerRegistry
,
WorkerType
,
},
},
data_connector
::{
data_connector
::{
MemoryConversationItemStorage
,
MemoryConversationStorage
,
MemoryResponseStorage
,
MemoryConversationItemStorage
,
MemoryConversationStorage
,
MemoryResponseStorage
,
...
@@ -439,51 +443,6 @@ async fn v1_conversations_delete_item(
...
@@ -439,51 +443,6 @@ async fn v1_conversations_delete_item(
.await
.await
}
}
#[derive(Deserialize)]
struct
AddWorkerQuery
{
url
:
String
,
api_key
:
Option
<
String
>
,
}
async
fn
add_worker
(
State
(
state
):
State
<
Arc
<
AppState
>>
,
Query
(
AddWorkerQuery
{
url
,
api_key
}):
Query
<
AddWorkerQuery
>
,
)
->
Response
{
// Warn if router has API key but worker is being added without one
if
state
.context.router_config.api_key
.is_some
()
&&
api_key
.is_none
()
{
warn!
(
"Adding worker {} without API key while router has API key configured.
\
Worker will be accessible without authentication.
\
If the worker requires the same API key as the router, please specify it explicitly."
,
url
);
}
let
result
=
WorkerManager
::
add_worker
(
&
url
,
&
api_key
,
&
state
.context
)
.await
;
match
result
{
Ok
(
message
)
=>
(
StatusCode
::
OK
,
message
)
.into_response
(),
Err
(
error
)
=>
(
StatusCode
::
BAD_REQUEST
,
error
)
.into_response
(),
}
}
/// Handler for `GET /list_workers`: returns a JSON object with the URLs of
/// all workers currently in the registry, e.g. `{"urls": [...]}`.
async fn list_workers(State(state): State<Arc<AppState>>) -> Response {
    let urls = WorkerManager::get_worker_urls(&state.context.worker_registry);
    Json(json!({ "urls": urls })).into_response()
}
async
fn
remove_worker
(
State
(
state
):
State
<
Arc
<
AppState
>>
,
Query
(
AddWorkerQuery
{
url
,
..
}):
Query
<
AddWorkerQuery
>
,
)
->
Response
{
let
result
=
WorkerManager
::
remove_worker
(
&
url
,
&
state
.context
);
match
result
{
Ok
(
message
)
=>
(
StatusCode
::
OK
,
message
)
.into_response
(),
Err
(
error
)
=>
(
StatusCode
::
BAD_REQUEST
,
error
)
.into_response
(),
}
}
async
fn
flush_cache
(
State
(
state
):
State
<
Arc
<
AppState
>>
,
_
req
:
Request
)
->
Response
{
async
fn
flush_cache
(
State
(
state
):
State
<
Arc
<
AppState
>>
,
_
req
:
Request
)
->
Response
{
match
WorkerManager
::
flush_cache_all
(
&
state
.context.worker_registry
,
&
state
.context.client
)
match
WorkerManager
::
flush_cache_all
(
&
state
.context.worker_registry
,
&
state
.context.client
)
.await
.await
...
@@ -566,6 +525,12 @@ async fn create_worker(
...
@@ -566,6 +525,12 @@ async fn create_worker(
);
);
}
}
// Populate dp_aware from router's configuration
let
config
=
WorkerConfigRequest
{
dp_aware
:
state
.context.router_config.dp_aware
,
..
config
};
// Submit job for async processing
// Submit job for async processing
let
worker_url
=
config
.url
.clone
();
let
worker_url
=
config
.url
.clone
();
let
job
=
Job
::
AddWorker
{
let
job
=
Job
::
AddWorker
{
...
@@ -761,9 +726,6 @@ pub fn build_app(
...
@@ -761,9 +726,6 @@ pub fn build_app(
.route
(
"/get_server_info"
,
get
(
get_server_info
));
.route
(
"/get_server_info"
,
get
(
get_server_info
));
let
admin_routes
=
Router
::
new
()
let
admin_routes
=
Router
::
new
()
.route
(
"/add_worker"
,
post
(
add_worker
))
.route
(
"/remove_worker"
,
post
(
remove_worker
))
.route
(
"/list_workers"
,
get
(
list_workers
))
.route
(
"/flush_cache"
,
post
(
flush_cache
))
.route
(
"/flush_cache"
,
post
(
flush_cache
))
.route
(
"/get_loads"
,
get
(
get_loads
))
.route
(
"/get_loads"
,
get
(
get_loads
))
.route_layer
(
axum
::
middleware
::
from_fn_with_state
(
.route_layer
(
axum
::
middleware
::
from_fn_with_state
(
...
@@ -1018,15 +980,16 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
...
@@ -1018,15 +980,16 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
engine
engine
.event_bus
()
.event_bus
()
.subscribe
(
Arc
::
new
(
crate
::
core
::
workflow
::
LoggingSubscriber
))
.subscribe
(
Arc
::
new
(
LoggingSubscriber
))
.await
;
.await
;
engine
.register_workflow
(
crate
::
core
::
workflow
::
create_worker_registration_workflow
());
engine
.register_workflow
(
create_worker_registration_workflow
());
engine
.register_workflow
(
create_worker_removal_workflow
());
app_context
app_context
.workflow_engine
.workflow_engine
.set
(
engine
)
.set
(
engine
)
.expect
(
"WorkflowEngine should only be initialized once"
);
.expect
(
"WorkflowEngine should only be initialized once"
);
info!
(
"Workflow engine initialized with worker registration workflow"
);
info!
(
"Workflow engine initialized with worker registration
and removal
workflow
s
"
);
info!
(
info!
(
"Initializing workers for routing mode: {:?}"
,
"Initializing workers for routing mode: {:?}"
,
...
...
sgl-router/src/service_discovery.rs
View file @
5dccf697
...
@@ -18,11 +18,7 @@ use rustls;
...
@@ -18,11 +18,7 @@ use rustls;
use
tokio
::{
task
,
time
};
use
tokio
::{
task
,
time
};
use
tracing
::{
debug
,
error
,
info
,
warn
};
use
tracing
::{
debug
,
error
,
info
,
warn
};
use
crate
::{
use
crate
::{
core
::
Job
,
protocols
::
worker_spec
::
WorkerConfigRequest
,
server
::
AppContext
};
core
::{
Job
,
WorkerManager
},
protocols
::
worker_spec
::
WorkerConfigRequest
,
server
::
AppContext
,
};
#[derive(Debug,
Clone)]
#[derive(Debug,
Clone)]
pub
struct
ServiceDiscoveryConfig
{
pub
struct
ServiceDiscoveryConfig
{
...
@@ -386,7 +382,7 @@ async fn handle_pod_event(
...
@@ -386,7 +382,7 @@ async fn handle_pod_event(
reasoning_parser
:
None
,
reasoning_parser
:
None
,
tool_parser
:
None
,
tool_parser
:
None
,
chat_template
:
None
,
chat_template
:
None
,
api_key
:
N
one
,
api_key
:
app_context
.router_config.api_key
.cl
one
()
,
health_check_timeout_secs
:
app_context
.router_config.health_check.timeout_secs
,
health_check_timeout_secs
:
app_context
.router_config.health_check.timeout_secs
,
health_check_interval_secs
:
app_context
health_check_interval_secs
:
app_context
.router_config
.router_config
...
@@ -453,8 +449,24 @@ async fn handle_pod_deletion(
...
@@ -453,8 +449,24 @@ async fn handle_pod_deletion(
pod_info
.name
,
pod_info
.pod_type
,
worker_url
pod_info
.name
,
pod_info
.pod_type
,
worker_url
);
);
if
let
Err
(
e
)
=
WorkerManager
::
remove_worker
(
&
worker_url
,
&
app_context
)
{
let
job
=
Job
::
RemoveWorker
{
error!
(
"Failed to remove worker {}: {}"
,
worker_url
,
e
);
url
:
worker_url
.clone
(),
};
if
let
Some
(
job_queue
)
=
app_context
.worker_job_queue
.get
()
{
if
let
Err
(
e
)
=
job_queue
.submit
(
job
)
.await
{
error!
(
"Failed to submit worker removal job for {}: {}"
,
worker_url
,
e
);
}
else
{
debug!
(
"Submitted worker removal job for {}"
,
worker_url
);
}
}
else
{
error!
(
"JobQueue not initialized, cannot remove worker {}"
,
worker_url
);
}
}
}
else
{
}
else
{
debug!
(
debug!
(
...
...
sgl-router/tests/api_endpoints_test.rs
View file @
5dccf697
...
@@ -14,7 +14,7 @@ use sglang_router_rs::{
...
@@ -14,7 +14,7 @@ use sglang_router_rs::{
config
::{
config
::{
CircuitBreakerConfig
,
ConnectionMode
,
PolicyConfig
,
RetryConfig
,
RouterConfig
,
RoutingMode
,
CircuitBreakerConfig
,
ConnectionMode
,
PolicyConfig
,
RetryConfig
,
RouterConfig
,
RoutingMode
,
},
},
core
::
WorkerManager
,
core
::
Job
,
routers
::{
RouterFactory
,
RouterTrait
},
routers
::{
RouterFactory
,
RouterTrait
},
server
::
AppContext
,
server
::
AppContext
,
};
};
...
@@ -112,22 +112,51 @@ impl TestContext {
...
@@ -112,22 +112,51 @@ impl TestContext {
// Create app context
// Create app context
let
app_context
=
common
::
create_test_context
(
config
.clone
());
let
app_context
=
common
::
create_test_context
(
config
.clone
());
//
Initialize
worker
s
in
the registry before creating router
//
Submit
worker in
itialization job (same as real server does)
if
!
worker_urls
.is_empty
()
{
if
!
worker_urls
.is_empty
()
{
WorkerManager
::
initialize_workers
(
&
config
,
&
app_context
.worker_registry
,
None
)
let
job_queue
=
app_context
.worker_job_queue
.get
()
.expect
(
"JobQueue should be initialized"
);
let
job
=
Job
::
InitializeWorkersFromConfig
{
router_config
:
Box
::
new
(
config
.clone
()),
};
job_queue
.submit
(
job
)
.await
.await
.expect
(
"Failed to initialize workers"
);
.expect
(
"Failed to submit worker initialization job"
);
// Poll until all workers are healthy (up to 10 seconds)
let
expected_count
=
worker_urls
.len
();
let
start
=
tokio
::
time
::
Instant
::
now
();
let
timeout_duration
=
tokio
::
time
::
Duration
::
from_secs
(
10
);
loop
{
let
healthy_workers
=
app_context
.worker_registry
.get_all
()
.iter
()
.filter
(|
w
|
w
.is_healthy
())
.count
();
if
healthy_workers
>=
expected_count
{
break
;
}
if
start
.elapsed
()
>
timeout_duration
{
panic!
(
"Timeout waiting for {} workers to become healthy (only {} ready)"
,
expected_count
,
healthy_workers
);
}
tokio
::
time
::
sleep
(
tokio
::
time
::
Duration
::
from_millis
(
100
))
.await
;
}
}
}
// Create router
// Create router
let
router
=
RouterFactory
::
create_router
(
&
app_context
)
.await
.unwrap
();
let
router
=
RouterFactory
::
create_router
(
&
app_context
)
.await
.unwrap
();
let
router
=
Arc
::
from
(
router
);
let
router
=
Arc
::
from
(
router
);
// Wait for router to discover workers
if
!
workers
.is_empty
()
{
tokio
::
time
::
sleep
(
tokio
::
time
::
Duration
::
from_millis
(
500
))
.await
;
}
Self
{
Self
{
workers
,
workers
,
router
,
router
,
...
@@ -711,221 +740,6 @@ mod model_info_tests {
...
@@ -711,221 +740,6 @@ mod model_info_tests {
}
}
}
}
#[cfg(test)]
mod
worker_management_tests
{
use
super
::
*
;
#[tokio::test]
async
fn
test_add_new_worker
()
{
let
ctx
=
TestContext
::
new
(
vec!
[])
.await
;
let
app
=
ctx
.create_app
()
.await
;
// Start a mock worker
let
mut
worker
=
MockWorker
::
new
(
MockWorkerConfig
{
port
:
18301
,
worker_type
:
WorkerType
::
Regular
,
health_status
:
HealthStatus
::
Healthy
,
response_delay_ms
:
0
,
fail_rate
:
0.0
,
});
let
url
=
worker
.start
()
.await
.unwrap
();
// Add the worker
let
req
=
Request
::
builder
()
.method
(
"POST"
)
.uri
(
format!
(
"/add_worker?url={}"
,
url
))
.body
(
Body
::
empty
())
.unwrap
();
let
resp
=
app
.clone
()
.oneshot
(
req
)
.await
.unwrap
();
assert_eq!
(
resp
.status
(),
StatusCode
::
OK
);
// List workers to verify
let
req
=
Request
::
builder
()
.method
(
"GET"
)
.uri
(
"/list_workers"
)
.body
(
Body
::
empty
())
.unwrap
();
let
resp
=
app
.oneshot
(
req
)
.await
.unwrap
();
assert_eq!
(
resp
.status
(),
StatusCode
::
OK
);
let
body
=
axum
::
body
::
to_bytes
(
resp
.into_body
(),
usize
::
MAX
)
.await
.unwrap
();
let
body_json
:
serde_json
::
Value
=
serde_json
::
from_slice
(
&
body
)
.unwrap
();
let
workers
=
body_json
[
"urls"
]
.as_array
()
.unwrap
();
assert
!
(
workers
.iter
()
.any
(|
w
|
w
.as_str
()
.unwrap
()
==
url
));
worker
.stop
()
.await
;
ctx
.shutdown
()
.await
;
}
#[tokio::test]
async
fn
test_remove_existing_worker
()
{
let
ctx
=
TestContext
::
new
(
vec!
[
MockWorkerConfig
{
port
:
18302
,
worker_type
:
WorkerType
::
Regular
,
health_status
:
HealthStatus
::
Healthy
,
response_delay_ms
:
0
,
fail_rate
:
0.0
,
}])
.await
;
let
app
=
ctx
.create_app
()
.await
;
// Get the worker URL
let
req
=
Request
::
builder
()
.method
(
"GET"
)
.uri
(
"/list_workers"
)
.body
(
Body
::
empty
())
.unwrap
();
let
resp
=
app
.clone
()
.oneshot
(
req
)
.await
.unwrap
();
let
body
=
axum
::
body
::
to_bytes
(
resp
.into_body
(),
usize
::
MAX
)
.await
.unwrap
();
let
body_json
:
serde_json
::
Value
=
serde_json
::
from_slice
(
&
body
)
.unwrap
();
let
workers
=
body_json
[
"urls"
]
.as_array
()
.unwrap
();
let
worker_url
=
workers
[
0
]
.as_str
()
.unwrap
();
// Remove the worker
let
req
=
Request
::
builder
()
.method
(
"POST"
)
.uri
(
format!
(
"/remove_worker?url={}"
,
worker_url
))
.body
(
Body
::
empty
())
.unwrap
();
let
resp
=
app
.clone
()
.oneshot
(
req
)
.await
.unwrap
();
assert_eq!
(
resp
.status
(),
StatusCode
::
OK
);
let
req
=
Request
::
builder
()
.method
(
"GET"
)
.uri
(
"/list_workers"
)
.body
(
Body
::
empty
())
.unwrap
();
let
resp
=
app
.oneshot
(
req
)
.await
.unwrap
();
let
body
=
axum
::
body
::
to_bytes
(
resp
.into_body
(),
usize
::
MAX
)
.await
.unwrap
();
let
body_json
:
serde_json
::
Value
=
serde_json
::
from_slice
(
&
body
)
.unwrap
();
let
workers
=
body_json
[
"urls"
]
.as_array
()
.unwrap
();
assert
!
(
workers
.is_empty
());
ctx
.shutdown
()
.await
;
}
#[tokio::test]
async
fn
test_add_worker_invalid_url
()
{
let
ctx
=
TestContext
::
new
(
vec!
[])
.await
;
let
app
=
ctx
.create_app
()
.await
;
// Invalid URL format
let
req
=
Request
::
builder
()
.method
(
"POST"
)
.uri
(
"/add_worker?url=not-a-valid-url"
)
.body
(
Body
::
empty
())
.unwrap
();
let
resp
=
app
.clone
()
.oneshot
(
req
)
.await
.unwrap
();
assert_eq!
(
resp
.status
(),
StatusCode
::
BAD_REQUEST
);
// Missing URL parameter
let
req
=
Request
::
builder
()
.method
(
"POST"
)
.uri
(
"/add_worker"
)
.body
(
Body
::
empty
())
.unwrap
();
let
resp
=
app
.clone
()
.oneshot
(
req
)
.await
.unwrap
();
assert_eq!
(
resp
.status
(),
StatusCode
::
BAD_REQUEST
);
// Empty URL
let
req
=
Request
::
builder
()
.method
(
"POST"
)
.uri
(
"/add_worker?url="
)
.body
(
Body
::
empty
())
.unwrap
();
let
resp
=
app
.oneshot
(
req
)
.await
.unwrap
();
assert_eq!
(
resp
.status
(),
StatusCode
::
BAD_REQUEST
);
ctx
.shutdown
()
.await
;
}
#[tokio::test]
async
fn
test_add_duplicate_worker
()
{
// Start a mock worker
let
mut
worker
=
MockWorker
::
new
(
MockWorkerConfig
{
port
:
18303
,
worker_type
:
WorkerType
::
Regular
,
health_status
:
HealthStatus
::
Healthy
,
response_delay_ms
:
0
,
fail_rate
:
0.0
,
});
let
url
=
worker
.start
()
.await
.unwrap
();
let
ctx
=
TestContext
::
new
(
vec!
[])
.await
;
let
app
=
ctx
.create_app
()
.await
;
// Add worker first time
let
req
=
Request
::
builder
()
.method
(
"POST"
)
.uri
(
format!
(
"/add_worker?url={}"
,
url
))
.body
(
Body
::
empty
())
.unwrap
();
let
resp
=
app
.clone
()
.oneshot
(
req
)
.await
.unwrap
();
assert_eq!
(
resp
.status
(),
StatusCode
::
OK
);
tokio
::
time
::
sleep
(
tokio
::
time
::
Duration
::
from_millis
(
500
))
.await
;
// Try to add same worker again
let
req
=
Request
::
builder
()
.method
(
"POST"
)
.uri
(
format!
(
"/add_worker?url={}"
,
url
))
.body
(
Body
::
empty
())
.unwrap
();
let
resp
=
app
.oneshot
(
req
)
.await
.unwrap
();
// Should return error for duplicate
assert_eq!
(
resp
.status
(),
StatusCode
::
BAD_REQUEST
);
worker
.stop
()
.await
;
ctx
.shutdown
()
.await
;
}
#[tokio::test]
async
fn
test_add_unhealthy_worker
()
{
// Start unhealthy worker
let
mut
worker
=
MockWorker
::
new
(
MockWorkerConfig
{
port
:
18304
,
worker_type
:
WorkerType
::
Regular
,
health_status
:
HealthStatus
::
Unhealthy
,
response_delay_ms
:
0
,
fail_rate
:
0.0
,
});
let
url
=
worker
.start
()
.await
.unwrap
();
let
ctx
=
TestContext
::
new
(
vec!
[])
.await
;
let
app
=
ctx
.create_app
()
.await
;
// Try to add unhealthy worker
let
req
=
Request
::
builder
()
.method
(
"POST"
)
.uri
(
format!
(
"/add_worker?url={}"
,
url
))
.body
(
Body
::
empty
())
.unwrap
();
let
resp
=
app
.oneshot
(
req
)
.await
.unwrap
();
// Router should reject unhealthy workers
assert
!
(
resp
.status
()
==
StatusCode
::
BAD_REQUEST
||
resp
.status
()
==
StatusCode
::
SERVICE_UNAVAILABLE
);
worker
.stop
()
.await
;
ctx
.shutdown
()
.await
;
}
}
#[cfg(test)]
#[cfg(test)]
mod
router_policy_tests
{
mod
router_policy_tests
{
use
super
::
*
;
use
super
::
*
;
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment