Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
09ae5b20
Unverified
Commit
09ae5b20
authored
Jun 18, 2025
by
Simo Lin
Committed by
GitHub
Jun 19, 2025
Browse files
Merge PDLB (Prefill-Decode Load Balancer) into SGLang Router (#7096)
parent
712bf9ec
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
4045 additions
and
187 deletions
+4045
-187
python/sglang/srt/disaggregation/mini_lb.py
python/sglang/srt/disaggregation/mini_lb.py
+27
-3
sgl-router/Cargo.toml
sgl-router/Cargo.toml
+3
-1
sgl-router/py_src/sglang_router/launch_router.py
sgl-router/py_src/sglang_router/launch_router.py
+107
-6
sgl-router/py_src/sglang_router/router.py
sgl-router/py_src/sglang_router/router.py
+12
-2
sgl-router/py_test/test_launch_router.py
sgl-router/py_test/test_launch_router.py
+116
-2
sgl-router/src/lib.rs
sgl-router/src/lib.rs
+85
-18
sgl-router/src/openai_api_types.rs
sgl-router/src/openai_api_types.rs
+704
-0
sgl-router/src/pd_router.rs
sgl-router/src/pd_router.rs
+1002
-0
sgl-router/src/pd_types.rs
sgl-router/src/pd_types.rs
+245
-0
sgl-router/src/request_adapter.rs
sgl-router/src/request_adapter.rs
+264
-0
sgl-router/src/router.rs
sgl-router/src/router.rs
+420
-122
sgl-router/src/server.rs
sgl-router/src/server.rs
+156
-33
sgl-router/tests/test_pd_routing.rs
sgl-router/tests/test_pd_routing.rs
+904
-0
No files found.
python/sglang/srt/disaggregation/mini_lb.py
View file @
09ae5b20
...
...
@@ -218,15 +218,39 @@ async def get_server_info():
)
prefill_infos
=
[]
decode_infos
=
[]
all_internal_states
=
[]
async
with
aiohttp
.
ClientSession
()
as
session
:
for
server
in
chain
(
prefill_servers
):
server_info
=
await
session
.
get
(
f
"
{
server
}
/get_server_info"
)
prefill_infos
.
append
(
await
server_info
.
json
())
for
server
in
chain
(
decode_servers
):
server_info
=
await
session
.
get
(
f
"
{
server
}
/get_server_info"
)
decode_infos
.
append
(
await
server_info
.
json
())
return
{
"prefill"
:
prefill_infos
,
"decode"
:
decode_infos
}
info_json
=
await
server_info
.
json
()
decode_infos
.
append
(
info_json
)
# Extract internal_states from decode servers
if
"internal_states"
in
info_json
:
all_internal_states
.
extend
(
info_json
[
"internal_states"
])
# Return format expected by bench_one_batch_server.py
if
all_internal_states
:
return
{
"internal_states"
:
all_internal_states
,
"prefill"
:
prefill_infos
,
"decode"
:
decode_infos
,
}
else
:
# Fallback with dummy data if no internal states found
return
{
"internal_states"
:
[
{
"last_gen_throughput"
:
0.0
,
"avg_spec_accept_length"
:
None
,
}
],
"prefill"
:
prefill_infos
,
"decode"
:
decode_infos
,
}
@
app
.
get
(
"/get_model_info"
)
...
...
sgl-router/Cargo.toml
View file @
09ae5b20
...
...
@@ -15,7 +15,7 @@ serde = { version = "1.0", features = ["derive"] }
clap
=
{
version
=
"4.4"
,
features
=
["derive"]
}
bytes
=
"1.8.0"
rand
=
"0.8.5"
reqwest
=
{
version
=
"0.12.8"
,
features
=
[
"stream"
,
"blocking"
]
}
reqwest
=
{
version
=
"0.12.8"
,
features
=
[
"stream"
,
"blocking"
,
"json"
]
}
futures-util
=
"0.3"
serde_json
=
"1.0"
pyo3
=
{
version
=
"0.22.5"
,
features
=
["extension-module"]
}
...
...
@@ -33,6 +33,8 @@ futures = "0.3"
# Added for metrics
metrics
=
"0.24.2"
metrics-exporter-prometheus
=
"0.17.0"
# Added for request tracing
uuid
=
{
version
=
"1.10"
,
features
=
[
"v4"
,
"serde"
]
}
[profile.release]
lto
=
"thin"
codegen-units
=
1
sgl-router/py_src/sglang_router/launch_router.py
View file @
09ae5b20
...
...
@@ -31,6 +31,13 @@ class RouterArgs:
host
:
str
=
"127.0.0.1"
port
:
int
=
30000
# PD-specific configuration
pd_disaggregated
:
bool
=
False
# Enable PD disaggregated mode
prefill_urls
:
List
[
tuple
]
=
dataclasses
.
field
(
default_factory
=
list
)
# List of (url, bootstrap_port)
decode_urls
:
List
[
str
]
=
dataclasses
.
field
(
default_factory
=
list
)
# Routing policy
policy
:
str
=
"cache_aware"
worker_startup_timeout_secs
:
int
=
300
...
...
@@ -40,7 +47,7 @@ class RouterArgs:
balance_rel_threshold
:
float
=
1.0001
eviction_interval
:
int
=
60
max_tree_size
:
int
=
2
**
24
max_payload_size
:
int
=
4
*
1024
*
1024
#
4MB
max_payload_size
:
int
=
256
*
1024
*
1024
#
256MB default for large batches
verbose
:
bool
=
False
log_dir
:
Optional
[
str
]
=
None
# Service discovery configuration
...
...
@@ -95,8 +102,29 @@ class RouterArgs:
f
"--
{
prefix
}
policy"
,
type
=
str
,
default
=
RouterArgs
.
policy
,
choices
=
[
"random"
,
"round_robin"
,
"cache_aware"
],
help
=
"Load balancing policy to use"
,
choices
=
[
"random"
,
"round_robin"
,
"cache_aware"
,
"power_of_two"
],
help
=
"Load balancing policy to use. Note: power_of_two is only available in PD disaggregated mode"
,
)
# PD-specific arguments
parser
.
add_argument
(
f
"--
{
prefix
}
pd-disaggregated"
,
action
=
"store_true"
,
help
=
"Enable PD (Prefill-Decode) disaggregated mode"
,
)
parser
.
add_argument
(
f
"--
{
prefix
}
prefill"
,
nargs
=
2
,
action
=
"append"
,
metavar
=
(
"URL"
,
"BOOTSTRAP_PORT"
),
help
=
"Prefill server URL and bootstrap port. Can be specified multiple times. BOOTSTRAP_PORT can be 'none' for no bootstrap port."
,
)
parser
.
add_argument
(
f
"--
{
prefix
}
decode"
,
nargs
=
1
,
action
=
"append"
,
metavar
=
(
"URL"
,),
help
=
"Decode server URL. Can be specified multiple times."
,
)
parser
.
add_argument
(
f
"--
{
prefix
}
worker-startup-timeout-secs"
,
...
...
@@ -205,11 +233,19 @@ class RouterArgs:
use_router_prefix: If True, look for arguments with 'router-' prefix
"""
prefix
=
"router_"
if
use_router_prefix
else
""
worker_urls
=
args
.
worker_urls
if
args
.
worker_urls
is
not
None
else
[]
worker_urls
=
getattr
(
args
,
"worker_urls"
,
[])
# Parse PD URLs
prefill_urls
=
cls
.
_parse_prefill_urls
(
getattr
(
args
,
f
"
{
prefix
}
prefill"
,
None
))
decode_urls
=
cls
.
_parse_decode_urls
(
getattr
(
args
,
f
"
{
prefix
}
decode"
,
None
))
return
cls
(
worker_urls
=
worker_urls
,
host
=
args
.
host
,
port
=
args
.
port
,
pd_disaggregated
=
getattr
(
args
,
f
"
{
prefix
}
pd_disaggregated"
,
False
),
prefill_urls
=
prefill_urls
,
decode_urls
=
decode_urls
,
policy
=
getattr
(
args
,
f
"
{
prefix
}
policy"
),
worker_startup_timeout_secs
=
getattr
(
args
,
f
"
{
prefix
}
worker_startup_timeout_secs"
...
...
@@ -247,6 +283,46 @@ class RouterArgs:
selector
[
key
]
=
value
return
selector
@
staticmethod
def
_parse_prefill_urls
(
prefill_list
):
"""Parse prefill URLs from --prefill arguments.
Format: --prefill URL BOOTSTRAP_PORT
Example: --prefill http://prefill1:8080 9000 --prefill http://prefill2:8080 none
"""
if
not
prefill_list
:
return
[]
prefill_urls
=
[]
for
url
,
bootstrap_port_str
in
prefill_list
:
# Handle 'none' as None
if
bootstrap_port_str
.
lower
()
==
"none"
:
bootstrap_port
=
None
else
:
try
:
bootstrap_port
=
int
(
bootstrap_port_str
)
except
ValueError
:
raise
ValueError
(
f
"Invalid bootstrap port:
{
bootstrap_port_str
}
. Must be a number or 'none'"
)
prefill_urls
.
append
((
url
,
bootstrap_port
))
return
prefill_urls
@
staticmethod
def
_parse_decode_urls
(
decode_list
):
"""Parse decode URLs from --decode arguments.
Format: --decode URL
Example: --decode http://decode1:8081 --decode http://decode2:8081
"""
if
not
decode_list
:
return
[]
# decode_list is a list of single-element lists due to nargs=1
return
[
url
[
0
]
for
url
in
decode_list
]
def
policy_from_str
(
policy_str
:
str
)
->
PolicyType
:
"""Convert policy string to PolicyType enum."""
...
...
@@ -254,6 +330,7 @@ def policy_from_str(policy_str: str) -> PolicyType:
"random"
:
PolicyType
.
Random
,
"round_robin"
:
PolicyType
.
RoundRobin
,
"cache_aware"
:
PolicyType
.
CacheAware
,
"power_of_two"
:
PolicyType
.
PowerOfTwo
,
}
return
policy_map
[
policy_str
]
...
...
@@ -277,8 +354,19 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
else
:
router_args
=
args
# Validate configuration based on mode
if
router_args
.
pd_disaggregated
:
# Validate PD configuration
if
not
router_args
.
prefill_urls
:
raise
ValueError
(
"PD disaggregated mode requires --prefill"
)
if
not
router_args
.
decode_urls
:
raise
ValueError
(
"PD disaggregated mode requires --decode"
)
# Create router with unified constructor
router
=
Router
(
worker_urls
=
router_args
.
worker_urls
,
worker_urls
=
(
router_args
.
worker_urls
if
not
router_args
.
pd_disaggregated
else
[]
),
host
=
router_args
.
host
,
port
=
router_args
.
port
,
policy
=
policy_from_str
(
router_args
.
policy
),
...
...
@@ -298,6 +386,13 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
service_discovery_namespace
=
router_args
.
service_discovery_namespace
,
prometheus_port
=
router_args
.
prometheus_port
,
prometheus_host
=
router_args
.
prometheus_host
,
pd_disaggregated
=
router_args
.
pd_disaggregated
,
prefill_urls
=
(
router_args
.
prefill_urls
if
router_args
.
pd_disaggregated
else
None
),
decode_urls
=
(
router_args
.
decode_urls
if
router_args
.
pd_disaggregated
else
None
),
)
router
.
start
()
...
...
@@ -326,8 +421,14 @@ This launcher enables starting a router with individual worker instances. It is
multi-node setups or when you want to start workers and router separately.
Examples:
# Regular mode
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000 --cache-threshold 0.7 --balance-abs-threshold 64 --balance-rel-threshold 1.2
# PD disaggregated mode
python -m sglang_router.launch_router --pd-disaggregated
\\
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none
\\
--decode http://decode1:8001 --decode http://decode2:8001
\\
--policy cache_aware
"""
,
formatter_class
=
CustomHelpFormatter
,
...
...
sgl-router/py_src/sglang_router/router.py
View file @
09ae5b20
...
...
@@ -15,6 +15,7 @@ class Router:
- PolicyType.Random: Randomly select workers
- PolicyType.RoundRobin: Distribute requests in round-robin fashion
- PolicyType.CacheAware: Distribute requests based on cache state and load balance
- PolicyType.PowerOfTwo: Select best of two random workers based on load (PD mode only)
host: Host address to bind the router server. Default: '127.0.0.1'
port: Port number to bind the router server. Default: 3001
worker_startup_timeout_secs: Timeout in seconds for worker startup. Default: 300
...
...
@@ -28,7 +29,7 @@ class Router:
AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 1.0001
eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware
routing. Default: 60
max_payload_size: Maximum payload size in bytes. Default:
4
MB
max_payload_size: Maximum payload size in bytes. Default:
256
MB
max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
verbose: Enable verbose logging. Default: False
log_dir: Directory to store log files. If None, logs are only output to console. Default: None
...
...
@@ -42,6 +43,9 @@ class Router:
watches pods across all namespaces (requires cluster-wide permissions). Default: None
prometheus_port: Port to expose Prometheus metrics. Default: None
prometheus_host: Host address to bind the Prometheus metrics server. Default: None
pd_disaggregated: Enable PD (Prefill-Decode) disaggregated mode. Default: False
prefill_urls: List of (url, bootstrap_port) tuples for prefill servers (PD mode only)
decode_urls: List of URLs for decode servers (PD mode only)
"""
def
__init__
(
...
...
@@ -57,7 +61,7 @@ class Router:
balance_rel_threshold
:
float
=
1.0001
,
eviction_interval_secs
:
int
=
60
,
max_tree_size
:
int
=
2
**
24
,
max_payload_size
:
int
=
4
*
1024
*
1024
,
#
4
MB
max_payload_size
:
int
=
256
*
1024
*
1024
,
#
256
MB
verbose
:
bool
=
False
,
log_dir
:
Optional
[
str
]
=
None
,
service_discovery
:
bool
=
False
,
...
...
@@ -66,6 +70,9 @@ class Router:
service_discovery_namespace
:
Optional
[
str
]
=
None
,
prometheus_port
:
Optional
[
int
]
=
None
,
prometheus_host
:
Optional
[
str
]
=
None
,
pd_disaggregated
:
bool
=
False
,
prefill_urls
:
Optional
[
List
[
tuple
]]
=
None
,
decode_urls
:
Optional
[
List
[
str
]]
=
None
,
):
if
selector
is
None
:
selector
=
{}
...
...
@@ -91,6 +98,9 @@ class Router:
service_discovery_namespace
=
service_discovery_namespace
,
prometheus_port
=
prometheus_port
,
prometheus_host
=
prometheus_host
,
pd_disaggregated
=
pd_disaggregated
,
prefill_urls
=
prefill_urls
,
decode_urls
=
decode_urls
,
)
def
start
(
self
)
->
None
:
...
...
sgl-router/py_test/test_launch_router.py
View file @
09ae5b20
...
...
@@ -35,13 +35,21 @@ class TestLaunchRouter(unittest.TestCase):
balance_rel_threshold
=
1.0001
,
eviction_interval
=
60
,
max_tree_size
=
2
**
24
,
max_payload_size
=
4
*
1024
*
1024
,
#
4
MB
max_payload_size
=
256
*
1024
*
1024
,
#
256
MB
verbose
=
False
,
log_dir
=
None
,
service_discovery
=
False
,
selector
=
None
,
service_discovery_port
=
80
,
service_discovery_namespace
=
None
,
prometheus_port
=
None
,
prometheus_host
=
None
,
# PD-specific attributes
pd_disaggregated
=
False
,
prefill
=
None
,
decode
=
None
,
# Keep worker_urls for regular mode
worker_urls
=
[],
)
def
create_router_args
(
self
,
**
kwargs
):
...
...
@@ -81,7 +89,7 @@ class TestLaunchRouter(unittest.TestCase):
def
test_launch_router_with_empty_worker_urls
(
self
):
args
=
self
.
create_router_args
(
worker_urls
=
[])
self
.
run_router_process
(
args
)
self
.
run_router_process
(
args
)
# Expected error
def
test_launch_router_with_service_discovery
(
self
):
# Test router startup with service discovery enabled but no selectors
...
...
@@ -100,6 +108,112 @@ class TestLaunchRouter(unittest.TestCase):
)
self
.
run_router_process
(
args
)
def
test_launch_router_pd_mode_basic
(
self
):
"""Test basic PD router functionality without actually starting servers."""
# This test just verifies the PD router can be created and configured
# without actually starting it (which would require real prefill/decode servers)
from
sglang_router
import
Router
from
sglang_router.launch_router
import
RouterArgs
from
sglang_router_rs
import
PolicyType
# Test RouterArgs parsing for PD mode
# Simulate the parsed args structure from argparse with action="append"
args
=
self
.
create_router_args
(
pd_disaggregated
=
True
,
policy
=
"power_of_two"
,
# PowerOfTwo is only valid in PD mode
prefill
=
[
[
"http://prefill1:8080"
,
"9000"
],
[
"http://prefill2:8080"
,
"none"
],
],
decode
=
[
[
"http://decode1:8081"
],
[
"http://decode2:8081"
],
],
worker_urls
=
[],
# Empty for PD mode
)
router_args
=
RouterArgs
.
from_cli_args
(
args
)
self
.
assertTrue
(
router_args
.
pd_disaggregated
)
self
.
assertEqual
(
router_args
.
policy
,
"power_of_two"
)
self
.
assertEqual
(
len
(
router_args
.
prefill_urls
),
2
)
self
.
assertEqual
(
len
(
router_args
.
decode_urls
),
2
)
# Verify the parsed URLs and bootstrap ports
self
.
assertEqual
(
router_args
.
prefill_urls
[
0
],
(
"http://prefill1:8080"
,
9000
))
self
.
assertEqual
(
router_args
.
prefill_urls
[
1
],
(
"http://prefill2:8080"
,
None
))
self
.
assertEqual
(
router_args
.
decode_urls
[
0
],
"http://decode1:8081"
)
self
.
assertEqual
(
router_args
.
decode_urls
[
1
],
"http://decode2:8081"
)
# Test Router creation in PD mode
router
=
Router
(
worker_urls
=
[],
# Empty for PD mode
pd_disaggregated
=
True
,
prefill_urls
=
[
(
"http://prefill1:8080"
,
9000
),
(
"http://prefill2:8080"
,
None
),
],
decode_urls
=
[
"http://decode1:8081"
,
"http://decode2:8081"
],
policy
=
PolicyType
.
CacheAware
,
host
=
"127.0.0.1"
,
port
=
3001
,
)
self
.
assertIsNotNone
(
router
)
def
test_policy_validation
(
self
):
"""Test that policy validation works correctly for PD and regular modes."""
from
sglang_router.launch_router
import
RouterArgs
,
launch_router
# Test 1: PowerOfTwo is only valid in PD mode
args
=
self
.
create_router_args
(
pd_disaggregated
=
False
,
policy
=
"power_of_two"
,
worker_urls
=
[
"http://localhost:8000"
],
)
# Should raise error
with
self
.
assertRaises
(
ValueError
)
as
cm
:
launch_router
(
args
)
self
.
assertIn
(
"PowerOfTwo policy is only supported in PD disaggregated mode"
,
str
(
cm
.
exception
),
)
# Test 2: RoundRobin is not valid in PD mode
args
=
self
.
create_router_args
(
pd_disaggregated
=
True
,
policy
=
"round_robin"
,
prefill
=
[[
"http://prefill1:8080"
,
"9000"
]],
decode
=
[[
"http://decode1:8081"
]],
worker_urls
=
[],
)
# Should raise error
with
self
.
assertRaises
(
ValueError
)
as
cm
:
launch_router
(
args
)
self
.
assertIn
(
"RoundRobin policy is not supported in PD disaggregated mode"
,
str
(
cm
.
exception
),
)
# Test 3: Valid combinations should not raise errors
# Regular mode with RoundRobin
args
=
self
.
create_router_args
(
pd_disaggregated
=
False
,
policy
=
"round_robin"
,
worker_urls
=
[
"http://localhost:8000"
],
)
# This should not raise (though it may fail to connect)
# PD mode with PowerOfTwo
args
=
self
.
create_router_args
(
pd_disaggregated
=
True
,
policy
=
"power_of_two"
,
prefill
=
[[
"http://prefill1:8080"
,
"9000"
]],
decode
=
[[
"http://decode1:8081"
]],
worker_urls
=
[],
)
# This should not raise (though it may fail to connect)
if
__name__
==
"__main__"
:
unittest
.
main
()
sgl-router/src/lib.rs
View file @
09ae5b20
use
pyo3
::
prelude
::
*
;
pub
mod
logging
;
use
std
::
collections
::
HashMap
;
pub
mod
openai_api_types
;
pub
mod
pd_router
;
pub
mod
pd_types
;
pub
mod
prometheus
;
pub
mod
request_adapter
;
pub
mod
router
;
pub
mod
server
;
pub
mod
service_discovery
;
...
...
@@ -14,6 +18,7 @@ pub enum PolicyType {
Random
,
RoundRobin
,
CacheAware
,
PowerOfTwo
,
// Moved from PD-specific, now shared
}
#[pyclass]
...
...
@@ -39,6 +44,12 @@ struct Router {
service_discovery_namespace
:
Option
<
String
>
,
prometheus_port
:
Option
<
u16
>
,
prometheus_host
:
Option
<
String
>
,
request_timeout_secs
:
u64
,
// PD mode flag
pd_disaggregated
:
bool
,
// PD-specific fields (only used when pd_disaggregated is true)
prefill_urls
:
Option
<
Vec
<
(
String
,
Option
<
u16
>
)
>>
,
decode_urls
:
Option
<
Vec
<
String
>>
,
}
#[pymethods]
...
...
@@ -56,7 +67,7 @@ impl Router {
balance_rel_threshold
=
1.0001
,
eviction_interval_secs
=
60
,
max_tree_size
=
2u
size
.
pow(
24
),
max_payload_size
=
4
*
1024
*
1024
,
max_payload_size
=
256
*
1024
*
1024
,
// 256MB default for large batches
verbose
=
false
,
log_dir
=
None,
service_discovery
=
false
,
...
...
@@ -64,7 +75,11 @@ impl Router {
service_discovery_port
=
80
,
service_discovery_namespace
=
None,
prometheus_port
=
None,
prometheus_host
=
None
prometheus_host
=
None,
request_timeout_secs
=
600
,
// Add configurable request timeout
pd_disaggregated
=
false
,
// New flag for PD mode
prefill_urls
=
None,
decode_urls
=
None
))]
fn
new
(
worker_urls
:
Vec
<
String
>
,
...
...
@@ -87,6 +102,10 @@ impl Router {
service_discovery_namespace
:
Option
<
String
>
,
prometheus_port
:
Option
<
u16
>
,
prometheus_host
:
Option
<
String
>
,
request_timeout_secs
:
u64
,
pd_disaggregated
:
bool
,
prefill_urls
:
Option
<
Vec
<
(
String
,
Option
<
u16
>
)
>>
,
decode_urls
:
Option
<
Vec
<
String
>>
,
)
->
PyResult
<
Self
>
{
Ok
(
Router
{
host
,
...
...
@@ -109,28 +128,75 @@ impl Router {
service_discovery_namespace
,
prometheus_port
,
prometheus_host
,
request_timeout_secs
,
pd_disaggregated
,
prefill_urls
,
decode_urls
,
})
}
fn
start
(
&
self
)
->
PyResult
<
()
>
{
let
policy_config
=
match
&
self
.policy
{
PolicyType
::
Random
=>
router
::
PolicyConfig
::
RandomConfig
{
timeout_secs
:
self
.worker_startup_timeout_secs
,
interval_secs
:
self
.worker_startup_check_interval
,
},
PolicyType
::
RoundRobin
=>
router
::
PolicyConfig
::
RoundRobinConfig
{
timeout_secs
:
self
.worker_startup_timeout_secs
,
interval_secs
:
self
.worker_startup_check_interval
,
},
PolicyType
::
CacheAware
=>
router
::
PolicyConfig
::
CacheAwareConfig
{
let
policy_config
=
if
self
.pd_disaggregated
{
// PD mode - map PolicyType to PDSelectionPolicy
let
pd_selection_policy
=
match
&
self
.policy
{
PolicyType
::
Random
=>
pd_types
::
PDSelectionPolicy
::
Random
,
PolicyType
::
PowerOfTwo
=>
pd_types
::
PDSelectionPolicy
::
PowerOfTwo
,
PolicyType
::
CacheAware
=>
pd_types
::
PDSelectionPolicy
::
CacheAware
{
cache_threshold
:
self
.cache_threshold
,
balance_abs_threshold
:
self
.balance_abs_threshold
,
balance_rel_threshold
:
self
.balance_rel_threshold
,
},
PolicyType
::
RoundRobin
=>
{
return
Err
(
pyo3
::
exceptions
::
PyValueError
::
new_err
(
"RoundRobin policy is not supported in PD disaggregated mode"
,
));
}
};
let
prefill_urls
=
self
.prefill_urls
.as_ref
()
.ok_or_else
(||
{
pyo3
::
exceptions
::
PyValueError
::
new_err
(
"PD disaggregated mode requires prefill_urls"
,
)
})
?
;
let
decode_urls
=
self
.decode_urls
.as_ref
()
.ok_or_else
(||
{
pyo3
::
exceptions
::
PyValueError
::
new_err
(
"PD disaggregated mode requires decode_urls"
,
)
})
?
;
router
::
PolicyConfig
::
PrefillDecodeConfig
{
selection_policy
:
pd_selection_policy
,
prefill_urls
:
prefill_urls
.clone
(),
decode_urls
:
decode_urls
.clone
(),
timeout_secs
:
self
.worker_startup_timeout_secs
,
interval_secs
:
self
.worker_startup_check_interval
,
cache_threshold
:
self
.cache_threshold
,
balance_abs_threshold
:
self
.balance_abs_threshold
,
balance_rel_threshold
:
self
.balance_rel_threshold
,
eviction_interval_secs
:
self
.eviction_interval_secs
,
max_tree_size
:
self
.max_tree_size
,
},
}
}
else
{
// Regular mode
match
&
self
.policy
{
PolicyType
::
Random
=>
router
::
PolicyConfig
::
RandomConfig
{
timeout_secs
:
self
.worker_startup_timeout_secs
,
interval_secs
:
self
.worker_startup_check_interval
,
},
PolicyType
::
RoundRobin
=>
router
::
PolicyConfig
::
RoundRobinConfig
{
timeout_secs
:
self
.worker_startup_timeout_secs
,
interval_secs
:
self
.worker_startup_check_interval
,
},
PolicyType
::
CacheAware
=>
router
::
PolicyConfig
::
CacheAwareConfig
{
timeout_secs
:
self
.worker_startup_timeout_secs
,
interval_secs
:
self
.worker_startup_check_interval
,
cache_threshold
:
self
.cache_threshold
,
balance_abs_threshold
:
self
.balance_abs_threshold
,
balance_rel_threshold
:
self
.balance_rel_threshold
,
eviction_interval_secs
:
self
.eviction_interval_secs
,
max_tree_size
:
self
.max_tree_size
,
},
PolicyType
::
PowerOfTwo
=>
{
return
Err
(
pyo3
::
exceptions
::
PyValueError
::
new_err
(
"PowerOfTwo policy is only supported in PD disaggregated mode"
,
));
}
}
};
// Create service discovery config if enabled
...
...
@@ -166,6 +232,7 @@ impl Router {
log_dir
:
self
.log_dir
.clone
(),
service_discovery_config
,
prometheus_config
,
request_timeout_secs
:
self
.request_timeout_secs
,
})
.await
.map_err
(|
e
|
pyo3
::
exceptions
::
PyRuntimeError
::
new_err
(
e
.to_string
()))
?
;
...
...
sgl-router/src/openai_api_types.rs
0 → 100644
View file @
09ae5b20
// OpenAI-compatible API types for text generation
// Based on OpenAI's API specification: https://platform.openai.com/docs/api-reference
// Reference: Azure OpenAI API documentation which follows OpenAI's specification
use
serde
::{
Deserialize
,
Serialize
};
use
serde_json
::
Value
;
use
std
::
collections
::
HashMap
;
/// Common trait for all generation requests
pub
trait
GenerationRequest
:
Send
+
Sync
{
/// Check if the request is for streaming
fn
is_stream
(
&
self
)
->
bool
;
/// Get the model name if specified
fn
get_model
(
&
self
)
->
Option
<&
str
>
;
/// Extract text content for routing decisions
fn
extract_text_for_routing
(
&
self
)
->
String
;
}
// ============= Completions API (v1/completions) - DEPRECATED but still supported =============
#[derive(Debug,
Clone,
Deserialize,
Serialize)]
pub
struct
CompletionRequest
{
/// ID of the model to use (required for OpenAI, optional for some implementations, such as SGLang)
pub
model
:
String
,
/// The prompt(s) to generate completions for
pub
prompt
:
StringOrArray
,
/// The suffix that comes after a completion of inserted text
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
suffix
:
Option
<
String
>
,
/// The maximum number of tokens to generate
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
max_tokens
:
Option
<
u32
>
,
/// What sampling temperature to use, between 0 and 2
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
temperature
:
Option
<
f32
>
,
/// An alternative to sampling with temperature (nucleus sampling)
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
top_p
:
Option
<
f32
>
,
/// How many completions to generate for each prompt
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
n
:
Option
<
u32
>
,
/// Whether to stream back partial progress
#[serde(default)]
pub
stream
:
bool
,
/// Include the log probabilities on the logprobs most likely tokens
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
logprobs
:
Option
<
u32
>
,
/// Echo back the prompt in addition to the completion
#[serde(default)]
pub
echo
:
bool
,
/// Up to 4 sequences where the API will stop generating further tokens
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
stop
:
Option
<
StringOrArray
>
,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
presence_penalty
:
Option
<
f32
>
,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
frequency_penalty
:
Option
<
f32
>
,
/// Generates best_of completions server-side and returns the "best"
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
best_of
:
Option
<
u32
>
,
/// Modify the likelihood of specified tokens appearing in the completion
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
logit_bias
:
Option
<
HashMap
<
String
,
f32
>>
,
/// A unique identifier representing your end-user
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
user
:
Option
<
String
>
,
/// If specified, our system will make a best effort to sample deterministically
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
seed
:
Option
<
i64
>
,
}
impl
GenerationRequest
for
CompletionRequest
{
fn
is_stream
(
&
self
)
->
bool
{
self
.stream
}
fn
get_model
(
&
self
)
->
Option
<&
str
>
{
Some
(
&
self
.model
)
}
fn
extract_text_for_routing
(
&
self
)
->
String
{
match
&
self
.prompt
{
StringOrArray
::
String
(
s
)
=>
s
.clone
(),
StringOrArray
::
Array
(
v
)
=>
v
.join
(
" "
),
}
}
}
// ============= Chat Completions API (v1/chat/completions) =============
#[derive(Debug,
Clone,
Deserialize,
Serialize)]
pub
struct
ChatCompletionRequest
{
/// ID of the model to use
pub
model
:
String
,
/// A list of messages comprising the conversation so far
pub
messages
:
Vec
<
ChatMessage
>
,
/// What sampling temperature to use, between 0 and 2
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
temperature
:
Option
<
f32
>
,
/// An alternative to sampling with temperature
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
top_p
:
Option
<
f32
>
,
/// How many chat completion choices to generate for each input message
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
n
:
Option
<
u32
>
,
/// If set, partial message deltas will be sent
#[serde(default)]
pub
stream
:
bool
,
/// Up to 4 sequences where the API will stop generating further tokens
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
stop
:
Option
<
StringOrArray
>
,
/// The maximum number of tokens to generate
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
max_tokens
:
Option
<
u32
>
,
/// An upper bound for the number of tokens that can be generated for a completion
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
max_completion_tokens
:
Option
<
u32
>
,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
presence_penalty
:
Option
<
f32
>
,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
frequency_penalty
:
Option
<
f32
>
,
/// Modify the likelihood of specified tokens appearing in the completion
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
logit_bias
:
Option
<
HashMap
<
String
,
i32
>>
,
/// A unique identifier representing your end-user
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
user
:
Option
<
String
>
,
/// If specified, our system will make a best effort to sample deterministically
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
seed
:
Option
<
i64
>
,
/// Whether to return log probabilities of the output tokens
#[serde(default)]
pub
logprobs
:
bool
,
/// An integer between 0 and 20 specifying the number of most likely tokens to return
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
top_logprobs
:
Option
<
u32
>
,
/// An object specifying the format that the model must output
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
response_format
:
Option
<
ResponseFormat
>
,
/// A list of tools the model may call
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
tools
:
Option
<
Vec
<
Tool
>>
,
/// Controls which (if any) tool is called by the model
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
tool_choice
:
Option
<
ToolChoice
>
,
/// Whether to enable parallel function calling during tool use
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
parallel_tool_calls
:
Option
<
bool
>
,
/// Deprecated: use tools instead
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
functions
:
Option
<
Vec
<
Function
>>
,
/// Deprecated: use tool_choice instead
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
function_call
:
Option
<
FunctionCall
>
,
}
#[derive(Debug,
Clone,
Deserialize,
Serialize)]
#[serde(untagged)]
pub
enum
ChatMessage
{
System
{
role
:
String
,
// "system"
content
:
String
,
#[serde(skip_serializing_if
=
"Option::is_none"
)]
name
:
Option
<
String
>
,
},
User
{
role
:
String
,
// "user"
content
:
UserMessageContent
,
#[serde(skip_serializing_if
=
"Option::is_none"
)]
name
:
Option
<
String
>
,
},
Assistant
{
role
:
String
,
// "assistant"
#[serde(skip_serializing_if
=
"Option::is_none"
)]
content
:
Option
<
String
>
,
#[serde(skip_serializing_if
=
"Option::is_none"
)]
name
:
Option
<
String
>
,
#[serde(skip_serializing_if
=
"Option::is_none"
)]
tool_calls
:
Option
<
Vec
<
ToolCall
>>
,
#[serde(skip_serializing_if
=
"Option::is_none"
)]
function_call
:
Option
<
FunctionCallResponse
>
,
},
Tool
{
role
:
String
,
// "tool"
content
:
String
,
tool_call_id
:
String
,
},
Function
{
role
:
String
,
// "function"
content
:
String
,
name
:
String
,
},
}
#[derive(Debug,
Clone,
Deserialize,
Serialize)]
#[serde(untagged)]
pub
enum
UserMessageContent
{
Text
(
String
),
Parts
(
Vec
<
ContentPart
>
),
}
#[derive(Debug,
Clone,
Deserialize,
Serialize)]
#[serde(tag
=
"type"
)]
pub
enum
ContentPart
{
#[serde(rename
=
"text"
)]
Text
{
text
:
String
},
#[serde(rename
=
"image_url"
)]
ImageUrl
{
image_url
:
ImageUrl
},
}
#[derive(Debug,
Clone,
Deserialize,
Serialize)]
pub
struct
ImageUrl
{
pub
url
:
String
,
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
detail
:
Option
<
String
>
,
// "auto", "low", or "high"
}
#[derive(Debug,
Clone,
Deserialize,
Serialize)]
#[serde(tag
=
"type"
)]
pub
enum
ResponseFormat
{
#[serde(rename
=
"text"
)]
Text
,
#[serde(rename
=
"json_object"
)]
JsonObject
,
#[serde(rename
=
"json_schema"
)]
JsonSchema
{
json_schema
:
JsonSchemaFormat
},
}
#[derive(Debug,
Clone,
Deserialize,
Serialize)]
pub
struct
JsonSchemaFormat
{
pub
name
:
String
,
pub
schema
:
Value
,
#[serde(skip_serializing_if
=
"Option::is_none"
)]
pub
strict
:
Option
<
bool
>
,
}
#[derive(Debug,
Clone,
Deserialize,
Serialize)]
pub
struct
Tool
{
#[serde(rename
=
"type"
)]
pub
tool_type
:
String
,
// "function"
pub
function
:
Function
,
}
/// Function declaration carried inside a [`Tool`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Function {
    /// Function name the model will reference in tool calls.
    pub name: String,
    /// Human-readable description shown to the model; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Parameter declaration as an opaque JSON Schema value.
    pub parameters: Value, // JSON Schema
}
/// OpenAI `tool_choice` field.
///
/// NOTE(review): with `#[serde(untagged)]`, the unit variants `None`/`Auto`/
/// `Required` only deserialize from JSON `null` — not from the API's string
/// forms `"none"`/`"auto"`/`"required"`. Verify against real client payloads;
/// a custom `Deserialize` impl may be needed to accept the string forms.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ToolChoice {
    None,
    Auto,
    Required,
    /// `{"type": "function", "function": {"name": "..."}}`
    Function {
        #[serde(rename = "type")]
        tool_type: String, // "function"
        function: FunctionChoice,
    },
}
/// Named function selected by `tool_choice`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FunctionChoice {
    pub name: String,
}
/// A tool call emitted by the assistant in a response message.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ToolCall {
    /// Unique identifier for this call, echoed back by `tool` messages.
    pub id: String,
    #[serde(rename = "type")]
    pub tool_type: String, // "function"
    pub function: FunctionCallResponse,
}
/// Deprecated OpenAI `function_call` request field.
///
/// NOTE(review): as with [`ToolChoice`], the untagged unit variants `None`/
/// `Auto` only match JSON `null`, not the strings `"none"`/`"auto"` — confirm
/// against real client payloads.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum FunctionCall {
    None,
    Auto,
    /// `{"name": "..."}` — force a specific function.
    Function { name: String },
}
/// Function name + arguments as produced by the model in a tool/function call.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FunctionCallResponse {
    pub name: String,
    /// Arguments encoded as a JSON string (not a parsed object).
    pub arguments: String, // JSON string
}
/// Routing-trait implementation for OpenAI chat-completion requests.
impl GenerationRequest for ChatCompletionRequest {
    /// Whether the caller asked for a streamed (SSE) response.
    fn is_stream(&self) -> bool {
        self.stream
    }

    /// Chat requests always carry an explicit model name.
    fn get_model(&self) -> Option<&str> {
        Some(self.model.as_str())
    }

    /// Concatenate the textual content of every message (space-separated)
    /// so the router can make text-based (e.g. cache-aware) decisions.
    fn extract_text_for_routing(&self) -> String {
        let mut pieces: Vec<String> = Vec::new();
        for msg in &self.messages {
            match msg {
                ChatMessage::System { content, .. } => pieces.push(content.clone()),
                ChatMessage::User { content, .. } => match content {
                    UserMessageContent::Text(text) => pieces.push(text.clone()),
                    UserMessageContent::Parts(parts) => {
                        // Only text parts contribute; image parts are skipped.
                        let texts: Vec<String> = parts
                            .iter()
                            .filter_map(|part| match part {
                                ContentPart::Text { text } => Some(text.clone()),
                                _ => None,
                            })
                            .collect();
                        pieces.push(texts.join(" "));
                    }
                },
                // Assistant content is optional (e.g. pure tool-call turns).
                ChatMessage::Assistant { content, .. } => {
                    if let Some(text) = content {
                        pieces.push(text.clone());
                    }
                }
                ChatMessage::Tool { content, .. } => pieces.push(content.clone()),
                ChatMessage::Function { content, .. } => pieces.push(content.clone()),
            }
        }
        pieces.join(" ")
    }
}
// ============= Generate API (/generate) =============
/// Request body for the native `/generate` endpoint; accepts both the
/// OpenAI-style `prompt` and SGLang-native `text`/`input_ids` inputs.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct GenerateRequest {
    /// The prompt to generate from (OpenAI style)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<StringOrArray>,
    /// Text input - SGLang native format
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    /// Input IDs for tokenized input
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_ids: Option<InputIds>,
    /// Generation parameters (TGI style); `default` lets the field be absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parameters: Option<GenerateParameters>,
    /// Sampling parameters (sglang style)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sampling_params: Option<SamplingParams>,
    /// Whether to stream the response (defaults to false).
    #[serde(default)]
    pub stream: bool,
    /// Whether to return logprobs (defaults to false).
    #[serde(default)]
    pub return_logprob: bool,
}
/// Pre-tokenized input: either one sequence of token IDs or a batch of them.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum InputIds {
    Single(Vec<i32>),
    Batch(Vec<Vec<i32>>),
}
/// TGI-style generation parameters; every field is optional and omitted
/// from the serialized JSON when unset.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct GenerateParameters {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub best_of: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub decoder_input_details: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub do_sample: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_new_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub repetition_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub return_full_text: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<u64>,
    /// Stop sequences that terminate generation.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub truncate: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub typical_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub watermark: Option<bool>,
}
/// SGLang-native sampling parameters; every field is optional and omitted
/// from the serialized JSON when unset.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct SamplingParams {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_new_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Note: `i32` here (vs `u32` in GenerateParameters) — SGLang uses -1
    /// to mean "disabled".
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub repetition_penalty: Option<f32>,
    /// Stop sequence(s): a single string or an array of strings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StringOrArray>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ignore_eos: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub skip_special_tokens: Option<bool>,
    /// Constrained decoding schema, passed through as a raw string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub json_schema: Option<String>,
}
/// Routing-trait implementation for the native `/generate` request format.
impl GenerationRequest for GenerateRequest {
    /// Whether the caller asked for a streamed response.
    fn is_stream(&self) -> bool {
        self.stream
    }

    /// `/generate` requests carry no model field.
    fn get_model(&self) -> Option<&str> {
        // Generate requests typically don't have a model field
        None
    }

    /// Build the routing text, preferring `text`, then `prompt`, then
    /// `input_ids` (token IDs rendered as space-separated decimals).
    fn extract_text_for_routing(&self) -> String {
        if let Some(text) = self.text.as_ref() {
            return text.clone();
        }
        if let Some(prompt) = self.prompt.as_ref() {
            return match prompt {
                StringOrArray::String(s) => s.clone(),
                StringOrArray::Array(items) => items.join(" "),
            };
        }
        match self.input_ids.as_ref() {
            Some(InputIds::Single(ids)) => {
                let rendered: Vec<String> = ids.iter().map(i32::to_string).collect();
                rendered.join(" ")
            }
            Some(InputIds::Batch(batches)) => {
                // Flatten all batches into one space-joined token string.
                let rendered: Vec<String> =
                    batches.iter().flatten().map(i32::to_string).collect();
                rendered.join(" ")
            }
            // No text input found.
            None => String::new(),
        }
    }
}
// ============= Helper Types =============
/// A JSON value that may be either a single string or an array of strings
/// (used for prompts and stop sequences).
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum StringOrArray {
    String(String),
    Array(Vec<String>),
}
// ============= Response Types =============
/// Non-streaming response body for `/v1/completions`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String, // "text_completion"
    /// Unix timestamp (seconds) of creation.
    pub created: u64,
    pub model: String,
    pub choices: Vec<CompletionChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
}
/// One generated completion within a [`CompletionResponse`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionChoice {
    pub text: String,
    pub index: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<LogProbs>,
    pub finish_reason: Option<String>, // "stop", "length", "content_filter", etc.
}
/// Completions-API logprobs: parallel arrays indexed by token position.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LogProbs {
    pub tokens: Vec<String>,
    pub token_logprobs: Vec<Option<f32>>,
    /// Per-position map of top alternative tokens to their logprobs.
    pub top_logprobs: Vec<Option<HashMap<String, f32>>>,
    /// Character offset of each token within the generated text.
    pub text_offset: Vec<u32>,
}
/// Non-streaming response body for `/v1/chat/completions`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub object: String, // "chat.completion"
    /// Unix timestamp (seconds) of creation.
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
}
/// One generated chat turn within a [`ChatCompletionResponse`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatChoice {
    pub index: u32,
    pub message: ChatMessage,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<ChatLogProbs>,
    pub finish_reason: Option<String>, // "stop", "length", "tool_calls", "content_filter", "function_call"
}
/// Chat-API logprobs container (one entry per generated token).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatLogProbs {
    pub content: Option<Vec<ChatLogProbsContent>>,
}
/// Logprob info for a single generated token in the chat API.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatLogProbsContent {
    pub token: String,
    pub logprob: f32,
    /// Raw UTF-8 bytes of the token, when applicable.
    pub bytes: Option<Vec<u8>>,
    /// Most likely alternative tokens at this position.
    pub top_logprobs: Vec<TopLogProb>,
}
/// One alternative token candidate inside [`ChatLogProbsContent`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TopLogProb {
    pub token: String,
    pub logprob: f32,
    pub bytes: Option<Vec<u8>>,
}
/// Token accounting attached to a response.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion_tokens_details: Option<CompletionTokensDetails>,
}
/// Breakdown of completion tokens (e.g. reasoning tokens for o1-style models).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionTokensDetails {
    pub reasoning_tokens: Option<u32>,
}
// ============= Streaming Response Types =============
/// One SSE chunk of a streaming `/v1/completions` response.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionStreamResponse {
    pub id: String,
    pub object: String, // "text_completion"
    pub created: u64,
    pub choices: Vec<CompletionStreamChoice>,
    pub model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
}
/// One choice delta inside a [`CompletionStreamResponse`] chunk.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionStreamChoice {
    pub text: String,
    pub index: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<LogProbs>,
    pub finish_reason: Option<String>,
}
/// One SSE chunk of a streaming `/v1/chat/completions` response.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionStreamResponse {
    pub id: String,
    pub object: String, // "chat.completion.chunk"
    pub created: u64,
    pub model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    pub choices: Vec<ChatStreamChoice>,
    /// Only present on the final chunk when usage reporting is enabled.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
}
/// One choice delta inside a [`ChatCompletionStreamResponse`] chunk.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatStreamChoice {
    pub index: u32,
    pub delta: ChatMessageDelta,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<ChatLogProbs>,
    pub finish_reason: Option<String>,
}
/// Incremental message content in a streaming chat chunk; all fields are
/// optional because each chunk carries only what changed.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatMessageDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallDelta>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_call: Option<FunctionCallDelta>,
}
/// Incremental tool-call fragment in a streaming chat chunk; fragments with
/// the same `index` are accumulated by the client.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ToolCallDelta {
    pub index: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(rename = "type")]
    pub tool_type: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function: Option<FunctionCallDelta>,
}
/// Incremental function-call fragment (name arrives once, arguments stream).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FunctionCallDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
}
// ============= Error Response Types =============
/// OpenAI-style error envelope: `{"error": {...}}`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ErrorResponse {
    pub error: ErrorDetail,
}
/// Body of an [`ErrorResponse`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ErrorDetail {
    pub message: String,
    #[serde(rename = "type")]
    pub error_type: String,
    /// Offending request parameter, when known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub param: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub code: Option<String>,
}
sgl-router/src/pd_router.rs
0 → 100644
View file @
09ae5b20
// PD (Prefill-Decode) Router Implementation
// This module handles routing for disaggregated prefill-decode systems
use
crate
::
pd_types
::{
Bootstrap
,
ChatReqInput
,
EngineInfo
,
GenerateReqInput
,
PDSelectionPolicy
};
use
crate
::
tree
::
Tree
;
use
actix_web
::
http
::
header
::{
HeaderValue
,
CONTENT_TYPE
};
use
actix_web
::{
HttpRequest
,
HttpResponse
};
use
futures_util
::{
StreamExt
,
TryStreamExt
};
use
metrics
::{
counter
,
histogram
};
use
serde_json
::
Value
;
use
std
::
collections
::
HashMap
;
use
std
::
sync
::
atomic
::{
AtomicUsize
,
Ordering
};
use
std
::
sync
::{
Arc
,
Mutex
,
RwLock
};
use
std
::
time
::{
Duration
,
Instant
};
use
tracing
::{
debug
,
error
,
info
,
warn
};
use
uuid
::
Uuid
;
// Removed over-engineered ProxyResponse - using HttpResponse directly
/// Router for disaggregated prefill/decode (PD) serving: every request is
/// dispatched to one prefill engine and one decode engine in parallel.
#[derive(Debug)]
pub struct PDRouter {
    /// Registered prefill engines (RwLock to allow future dynamic add/remove).
    pub prefill_workers: Arc<RwLock<Vec<EngineInfo>>>,
    /// Registered decode engines.
    pub decode_workers: Arc<RwLock<Vec<EngineInfo>>>,
    /// How prefill/decode pairs are chosen (random / power-of-two / cache-aware).
    pub selection_policy: PDSelectionPolicy,
    /// In-flight request counters per worker URL, maintained by `LoadGuard`.
    pub load_tracking: Arc<dashmap::DashMap<String, Arc<AtomicUsize>>>,
    /// Radix tree over prefill workers; only populated for cache-aware policy.
    pub prefill_tree: Option<Arc<Mutex<Tree>>>,
    /// HTTP timeout used for worker requests and startup health checks.
    pub timeout_secs: u64,
    /// Polling interval for health checks and load monitoring.
    pub interval_secs: u64,
    /// Latest `/get_load` snapshot per worker, fed by the monitor task.
    pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
    /// Handle to the background load-monitor task (PowerOfTwo policy only).
    pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
    /// Shared HTTP client for all outbound requests.
    pub http_client: reqwest::Client,
}
// RAII guard for load tracking to ensure cleanup even on panic
struct LoadGuard<'a> {
    /// Shared per-URL in-flight counters (borrowed from the router).
    tracking: &'a Arc<dashmap::DashMap<String, Arc<AtomicUsize>>>,
    /// URLs whose counters this guard incremented and must decrement on drop.
    urls: Vec<String>,
}
impl<'a> LoadGuard<'a> {
    /// Increment the in-flight counter for every given URL; the matching
    /// decrements happen in `Drop`, so counts stay balanced even on panic.
    fn new(
        tracking: &'a Arc<dashmap::DashMap<String, Arc<AtomicUsize>>>,
        urls: Vec<String>,
    ) -> Self {
        urls.iter().for_each(|url| {
            tracking
                .entry(url.clone())
                .or_insert_with(|| Arc::new(AtomicUsize::new(0)))
                .fetch_add(1, Ordering::Relaxed);
        });
        Self { tracking, urls }
    }
}
impl Drop for LoadGuard<'_> {
    /// Guaranteed cleanup even on panic: decrement every counter this
    /// guard incremented. Entries are never removed, only decremented,
    /// so a missing entry (None) is simply skipped.
    fn drop(&mut self) {
        for url in &self.urls {
            if let Some(counter) = self.tracking.get(url) {
                counter.fetch_sub(1, Ordering::Relaxed);
            }
        }
    }
}
impl
PDRouter
{
    // TODO: Add methods for dynamic worker management to support /register endpoint:
    // - add_prefill_server(url: String, bootstrap_port: Option<u16>)
    // - add_decode_server(url: String)
    // - remove_prefill_server(url: &str)
    // - remove_decode_server(url: &str)
    // These methods will be used when service discovery is implemented for PD mode

    /// Build a PD router: converts URLs to `EngineInfo`, blocks until every
    /// worker reports healthy, then wires up load tracking, the optional
    /// cache-aware tree, and (for PowerOfTwo) a background load monitor.
    ///
    /// Returns `Err(String)` if workers never become healthy or the HTTP
    /// client cannot be built.
    pub fn new(
        prefill_urls: Vec<(String, Option<u16>)>,
        decode_urls: Vec<String>,
        selection_policy: PDSelectionPolicy,
        timeout_secs: u64,
        interval_secs: u64,
    ) -> Result<Self, String> {
        // Convert URLs to EngineInfo
        let prefill_workers: Vec<EngineInfo> = prefill_urls
            .into_iter()
            .map(|(url, port)| EngineInfo::new_prefill(url, port))
            .collect();

        let decode_workers: Vec<EngineInfo> = decode_urls
            .into_iter()
            .map(EngineInfo::new_decode)
            .collect();

        // Wait for PD workers to be healthy (blocking; reuses regular-router logic)
        let all_urls: Vec<String> = prefill_workers
            .iter()
            .chain(decode_workers.iter())
            .map(|engine| engine.url.clone())
            .collect();
        crate::router::Router::wait_for_healthy_workers(&all_urls, timeout_secs, interval_secs)?;

        // Initialize load tracking with atomic counters (one per worker URL)
        let load_tracking = Arc::new(dashmap::DashMap::new());
        for engine in &prefill_workers {
            load_tracking.insert(engine.url.clone(), Arc::new(AtomicUsize::new(0)));
        }
        for engine in &decode_workers {
            load_tracking.insert(engine.url.clone(), Arc::new(AtomicUsize::new(0)));
        }

        // Initialize cache-aware components if needed
        let prefill_tree = match &selection_policy {
            PDSelectionPolicy::CacheAware { .. } => {
                let tree = Arc::new(Mutex::new(Tree::new()));
                // Initialize tree with prefill workers (empty prefix for each)
                for engine in &prefill_workers {
                    tree.lock().unwrap().insert("", &engine.url);
                }
                Some(tree)
            }
            _ => None,
        };

        // Set up background load monitoring for power-of-two selection.
        // The watch channel always holds the latest load snapshot.
        let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
        let worker_loads = Arc::new(rx);

        // Create a shared HTTP client for all operations
        let http_client = reqwest::Client::builder()
            .timeout(Duration::from_secs(timeout_secs))
            .build()
            .map_err(|e| format!("Failed to create HTTP client: {}", e))?;

        // The monitor task is only spawned for PowerOfTwo; other policies
        // never read `worker_loads`, so the channel just stays empty.
        let load_monitor_handle = if matches!(selection_policy, PDSelectionPolicy::PowerOfTwo) {
            let monitor_urls = all_urls.clone();
            let monitor_interval = interval_secs;
            let monitor_client = http_client.clone();
            Some(Arc::new(tokio::spawn(async move {
                Self::monitor_worker_loads_with_client(
                    monitor_urls,
                    tx,
                    monitor_interval,
                    monitor_client,
                )
                .await;
            })))
        } else {
            None
        };

        Ok(PDRouter {
            prefill_workers: Arc::new(RwLock::new(prefill_workers)),
            decode_workers: Arc::new(RwLock::new(decode_workers)),
            selection_policy,
            load_tracking,
            prefill_tree,
            timeout_secs,
            interval_secs,
            worker_loads,
            load_monitor_handle,
            http_client,
        })
    }
    // Route a typed generate request
    /// Select a prefill/decode pair, inject bootstrap info into the request,
    /// and dispatch it to both engines, returning the decode-side response.
    ///
    /// NOTE(review): this is nearly identical to `route_chat`; the shared
    /// select/bootstrap/serialize sequence could be factored into one helper.
    pub async fn route_generate(
        &self,
        client: &reqwest::Client,
        req: &HttpRequest,
        mut typed_req: GenerateReqInput,
        route: &str,
    ) -> HttpResponse {
        let start = Instant::now();
        let _request_id = Uuid::new_v4();

        // Get stream flag and return_logprob flag before moving the request
        let is_stream = typed_req.is_stream();
        let return_logprob = typed_req
            .other
            .get("return_logprob")
            .and_then(|v| v.as_bool())
            .unwrap_or(false);

        // Select servers
        let (prefill, decode) = match self.select_pd_pair(client).await {
            Ok(pair) => pair,
            Err(e) => {
                error!("Failed to select PD pair: {}", e);
                counter!("sgl_router_pd_errors_total", "error" => "server_selection").increment(1);
                return HttpResponse::ServiceUnavailable()
                    .body(format!("No available servers: {}", e));
            }
        };

        // Log routing decision
        info!(
            "PD routing: {} -> prefill={}, decode={}",
            route, prefill.url, decode.url
        );

        // Add bootstrap info using the trait method
        if let Err(e) = typed_req.add_bootstrap_info(&prefill) {
            error!("Failed to add bootstrap info: {}", e);
            counter!("sgl_router_pd_errors_total", "error" => "bootstrap_injection").increment(1);
            return HttpResponse::InternalServerError()
                .body(format!("Bootstrap injection failed: {}", e));
        }

        // Convert to JSON after bootstrap injection
        let json_with_bootstrap = match serde_json::to_value(&typed_req) {
            Ok(json) => json,
            Err(e) => {
                error!("Failed to serialize request: {}", e);
                return HttpResponse::InternalServerError().body("Failed to serialize request");
            }
        };

        // Execute dual dispatch
        self.execute_dual_dispatch(
            client,
            req,
            json_with_bootstrap,
            route,
            &prefill,
            &decode,
            is_stream,
            return_logprob,
            start,
        )
        .await
    }
    // Route a typed chat request
    /// Chat-completions twin of `route_generate`: select a PD pair, inject
    /// bootstrap info, and dual-dispatch, returning the decode-side response.
    pub async fn route_chat(
        &self,
        client: &reqwest::Client,
        req: &HttpRequest,
        mut typed_req: ChatReqInput,
        route: &str,
    ) -> HttpResponse {
        let start = Instant::now();

        // Get stream flag and return_logprob flag before moving the request
        let is_stream = typed_req.is_stream();
        let return_logprob = typed_req
            .other
            .get("return_logprob")
            .and_then(|v| v.as_bool())
            .unwrap_or(false);

        // Select servers
        let (prefill, decode) = match self.select_pd_pair(client).await {
            Ok(pair) => pair,
            Err(e) => {
                error!("Failed to select PD pair: {}", e);
                counter!("sgl_router_pd_errors_total", "error" => "server_selection").increment(1);
                return HttpResponse::ServiceUnavailable()
                    .body(format!("No available servers: {}", e));
            }
        };

        // Log routing decision
        info!(
            "PD routing: {} -> prefill={}, decode={}",
            route, prefill.url, decode.url
        );

        // Add bootstrap info using the trait method
        if let Err(e) = typed_req.add_bootstrap_info(&prefill) {
            error!("Failed to add bootstrap info: {}", e);
            counter!("sgl_router_pd_errors_total", "error" => "bootstrap_injection").increment(1);
            return HttpResponse::InternalServerError()
                .body(format!("Bootstrap injection failed: {}", e));
        }

        // Convert to JSON after bootstrap injection
        let json_with_bootstrap = match serde_json::to_value(&typed_req) {
            Ok(json) => json,
            Err(e) => {
                error!("Failed to serialize request: {}", e);
                return HttpResponse::InternalServerError().body("Failed to serialize request");
            }
        };

        // Execute dual dispatch
        self.execute_dual_dispatch(
            client,
            req,
            json_with_bootstrap,
            route,
            &prefill,
            &decode,
            is_stream,
            return_logprob,
            start,
        )
        .await
    }
    // Execute the dual dispatch to prefill and decode servers
    /// Send the same JSON payload to the prefill and decode engines
    /// concurrently. The decode response is what the client receives;
    /// the prefill response is only consulted for input-token logprobs
    /// (when `return_logprob`), and prefill failures are non-fatal.
    #[allow(clippy::too_many_arguments)]
    async fn execute_dual_dispatch(
        &self,
        client: &reqwest::Client,
        req: &HttpRequest,
        json_request: serde_json::Value,
        route: &str,
        prefill: &EngineInfo,
        decode: &EngineInfo,
        is_stream: bool,
        return_logprob: bool,
        start_time: Instant,
    ) -> HttpResponse {
        // Update load tracking for both workers; the RAII guard decrements
        // on every exit path (including panics).
        let _guard = LoadGuard::new(
            &self.load_tracking,
            vec![prefill.url.clone(), decode.url.clone()],
        );

        // Build requests using .json() method
        let mut prefill_request = client.post(prefill.api_path(route)).json(&json_request);
        let mut decode_request = client.post(decode.api_path(route)).json(&json_request);

        // Copy headers from original request (content-type/length are set by .json())
        for (name, value) in crate::router::copy_request_headers(req) {
            if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" {
                prefill_request = prefill_request.header(&name, &value);
                decode_request = decode_request.header(&name, &value);
            }
        }

        // Send both requests concurrently
        let (prefill_result, decode_result) =
            tokio::join!(prefill_request.send(), decode_request.send());

        // Update metrics
        let duration = start_time.elapsed();
        histogram!("sgl_router_pd_request_duration_seconds", "route" => route.to_string())
            .record(duration.as_secs_f64());
        counter!("sgl_router_pd_requests_total", "route" => route.to_string()).increment(1);
        counter!("sgl_router_pd_prefill_requests_total", "worker" => prefill.url.to_string())
            .increment(1);
        counter!("sgl_router_pd_decode_requests_total", "worker" => decode.url.to_string())
            .increment(1);

        // Process decode response (the authoritative one for the client)
        match decode_result {
            Ok(res) => {
                // reqwest and actix use different StatusCode types; convert via u16.
                let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
                    .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);

                if !status.is_success() {
                    counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url.to_string())
                        .increment(1);
                    error!("Decode server {} returned error status: {}", decode.url, status);
                    // Return the error response from decode server
                    match res.bytes().await {
                        Ok(error_body) => {
                            return HttpResponse::build(status).body(error_body.to_vec());
                        }
                        Err(e) => {
                            return HttpResponse::build(status)
                                .body(format!("Decode server error: {}", e));
                        }
                    }
                }

                // Log prefill errors for debugging (does not fail the request)
                if let Err(e) = &prefill_result {
                    error!("Prefill server {} failed (non-critical): {}", prefill.url, e);
                    counter!("sgl_router_pd_prefill_errors_total", "worker" => prefill.url.to_string())
                        .increment(1);
                }

                if is_stream {
                    // Streaming response
                    if return_logprob {
                        // Get prefill logprobs for merging (whole prefill body is
                        // read up front; per-chunk merging happens in the stream).
                        let prefill_logprobs = match prefill_result {
                            Ok(prefill_res) => match prefill_res.bytes().await {
                                Ok(body) => serde_json::from_slice::<Value>(&body)
                                    .ok()
                                    .and_then(|json| {
                                        json.pointer("/meta_info/input_token_logprobs").cloned()
                                    }),
                                Err(_) => None,
                            },
                            Err(_) => None,
                        };

                        // Stream with logprob merging; chunks that fail to merge
                        // are forwarded unmodified.
                        HttpResponse::build(status)
                            .insert_header((
                                CONTENT_TYPE,
                                HeaderValue::from_static("text/event-stream"),
                            ))
                            .streaming(res.bytes_stream().map(move |chunk_result| {
                                match chunk_result {
                                    Ok(chunk) => {
                                        // Try to merge logprobs
                                        if let Ok(merged) = Self::merge_streaming_logprobs(
                                            prefill_logprobs.clone(),
                                            &chunk,
                                        ) {
                                            Ok(merged)
                                        } else {
                                            Ok(chunk)
                                        }
                                    }
                                    Err(e) => Err(actix_web::error::ErrorInternalServerError(
                                        format!("Stream error: {}", e),
                                    )),
                                }
                            }))
                    } else {
                        // No logprob merging needed; forward bytes as-is
                        HttpResponse::build(status)
                            .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
                            .streaming({
                                let decode_url = decode.url.clone();
                                res.bytes_stream().map_err(move |e| {
                                    error!(
                                        "Stream error from decode server {}: {}",
                                        decode_url, e
                                    );
                                    counter!("sgl_router_pd_stream_errors_total", "worker" => decode_url.to_string())
                                        .increment(1);
                                    actix_web::error::ErrorInternalServerError(format!(
                                        "Stream error: {}",
                                        e
                                    ))
                                })
                            })
                    }
                } else {
                    // Non-streaming response
                    match res.bytes().await {
                        Ok(decode_body) => {
                            if return_logprob {
                                self.merge_logprobs(prefill_result, decode_body, status).await
                            } else {
                                HttpResponse::build(status).body(decode_body.to_vec())
                            }
                        }
                        Err(e) => {
                            error!("Failed to read decode response: {}", e);
                            HttpResponse::InternalServerError().body("Failed to read response")
                        }
                    }
                }
            }
            Err(e) => {
                error!("Decode request failed: {}", e);
                counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url.to_string())
                    .increment(1);
                HttpResponse::BadGateway().body(format!("Decode server error: {}", e))
            }
        }
    }
    // Merge logprobs from prefill and decode responses
    /// For non-streaming logprob requests: prepend the prefill server's
    /// `meta_info.input_token_logprobs` to the decode server's, and return
    /// the merged JSON. Any failure along the way (prefill error, body read
    /// error, parse failure) degrades gracefully to the raw decode body.
    async fn merge_logprobs(
        &self,
        prefill_result: Result<reqwest::Response, reqwest::Error>,
        decode_body: bytes::Bytes,
        status: actix_web::http::StatusCode,
    ) -> HttpResponse {
        match prefill_result {
            Ok(prefill_res) => {
                match prefill_res.bytes().await {
                    Ok(prefill_body) => {
                        match (
                            serde_json::from_slice::<Value>(&prefill_body),
                            serde_json::from_slice::<Value>(&decode_body),
                        ) {
                            (Ok(prefill_json), Ok(mut decode_json)) => {
                                // Merge input_token_logprobs: prefill tokens first,
                                // then decode tokens.
                                if let (Some(prefill_meta), Some(decode_meta)) = (
                                    prefill_json.get("meta_info"),
                                    decode_json.get_mut("meta_info"),
                                ) {
                                    if let (Some(prefill_logprobs), Some(decode_logprobs)) = (
                                        prefill_meta.get("input_token_logprobs"),
                                        decode_meta.get_mut("input_token_logprobs"),
                                    ) {
                                        if let (Some(p_arr), Some(d_arr)) = (
                                            prefill_logprobs.as_array(),
                                            decode_logprobs.as_array(),
                                        ) {
                                            let mut merged = p_arr.clone();
                                            merged.extend(d_arr.clone());
                                            decode_meta["input_token_logprobs"] =
                                                Value::Array(merged);
                                        }
                                    }
                                }
                                HttpResponse::build(status).json(&decode_json)
                            }
                            _ => {
                                warn!("Failed to parse responses for logprob merging");
                                HttpResponse::build(status).body(decode_body.to_vec())
                            }
                        }
                    }
                    Err(e) => {
                        warn!("Failed to read prefill response: {}", e);
                        HttpResponse::build(status).body(decode_body.to_vec())
                    }
                }
            }
            // Prefill already failed: just forward the decode body unmerged.
            Err(_) => HttpResponse::build(status).body(decode_body.to_vec()),
        }
    }
    // Select a pair of prefill and decode servers
    /// Pick a (prefill, decode) engine pair according to `selection_policy`.
    /// Errors if either worker list is empty or a lock is poisoned.
    /// The `_client` parameter is currently unused (reserved for policies
    /// that need to probe workers during selection).
    async fn select_pd_pair(
        &self,
        _client: &reqwest::Client,
    ) -> Result<(EngineInfo, EngineInfo), String> {
        // Check we have workers
        if self
            .prefill_workers
            .read()
            .map_err(|e| format!("Failed to acquire prefill workers lock: {}", e))?
            .is_empty()
        {
            return Err(
                "No prefill workers available. Please check if prefill servers are configured and healthy."
                    .to_string(),
            );
        }
        if self
            .decode_workers
            .read()
            .map_err(|e| format!("Failed to acquire decode workers lock: {}", e))?
            .is_empty()
        {
            return Err(
                "No decode workers available. Please check if decode servers are configured and healthy."
                    .to_string(),
            );
        }

        match &self.selection_policy {
            PDSelectionPolicy::Random => self.select_random(),
            PDSelectionPolicy::PowerOfTwo => self.select_power_of_two().await,
            PDSelectionPolicy::CacheAware { .. } => {
                // TODO: Implement cache-aware selection; falls back to
                // power-of-two in the meantime.
                self.select_power_of_two().await
            }
        }
    }
fn
select_random
(
&
self
)
->
Result
<
(
EngineInfo
,
EngineInfo
),
String
>
{
let
prefill_list
=
self
.prefill_workers
.read
()
.map_err
(|
_
|
"Lock error"
)
?
;
let
decode_list
=
self
.decode_workers
.read
()
.map_err
(|
_
|
"Lock error"
)
?
;
let
prefill
=
prefill_list
[
rand
::
random
::
<
usize
>
()
%
prefill_list
.len
()]
.clone
();
let
decode
=
decode_list
[
rand
::
random
::
<
usize
>
()
%
decode_list
.len
()]
.clone
();
Ok
((
prefill
,
decode
))
}
    /// Power-of-two-choices selection: sample two candidates from each pool
    /// and keep the one with the lower load according to the latest
    /// `worker_loads` snapshot (workers missing from the snapshot count as 0).
    async fn select_power_of_two(&self) -> Result<(EngineInfo, EngineInfo), String> {
        let prefill_list = self.prefill_workers.read().map_err(|_| "Lock error")?;
        let decode_list = self.decode_workers.read().map_err(|_| "Lock error")?;

        let (p1_idx, p2_idx) = get_two_random_indices(prefill_list.len());
        let (d1_idx, d2_idx) = get_two_random_indices(decode_list.len());

        // Snapshot of loads published by the background monitor task.
        let loads = self.worker_loads.borrow();

        let p1_load = loads.get(&prefill_list[p1_idx].url).copied().unwrap_or(0);
        let p2_load = loads.get(&prefill_list[p2_idx].url).copied().unwrap_or(0);
        let d1_load = loads.get(&decode_list[d1_idx].url).copied().unwrap_or(0);
        let d2_load = loads.get(&decode_list[d2_idx].url).copied().unwrap_or(0);

        info!(
            "Power-of-two selection - Prefill: {}={} vs {}={} | Decode: {}={} vs {}={}",
            prefill_list[p1_idx].url,
            p1_load,
            prefill_list[p2_idx].url,
            p2_load,
            decode_list[d1_idx].url,
            d1_load,
            decode_list[d2_idx].url,
            d2_load
        );

        // Ties go to the first sampled candidate (<=).
        let selected_prefill = if p1_load <= p2_load {
            prefill_list[p1_idx].clone()
        } else {
            prefill_list[p2_idx].clone()
        };

        let selected_decode = if d1_load <= d2_load {
            decode_list[d1_idx].clone()
        } else {
            decode_list[d2_idx].clone()
        };

        Ok((selected_prefill, selected_decode))
    }
    // Background task to monitor worker loads with shared client
    /// Poll every worker's `/get_load` endpoint each `interval_secs` and
    /// publish the snapshot through the watch channel. Unreachable workers
    /// are recorded with load 0. The loop exits when the receiver side
    /// (the router) has been dropped.
    async fn monitor_worker_loads_with_client(
        worker_urls: Vec<String>,
        tx: tokio::sync::watch::Sender<HashMap<String, isize>>,
        interval_secs: u64,
        client: reqwest::Client,
    ) {
        loop {
            let mut loads = HashMap::new();

            // Query all workers concurrently.
            let futures: Vec<_> = worker_urls
                .iter()
                .map(|url| {
                    let client = client.clone();
                    let url = url.clone();
                    async move {
                        let load = get_worker_load(&client, &url).await.unwrap_or(0);
                        (url, load)
                    }
                })
                .collect();

            let results = futures_util::future::join_all(futures).await;
            for (url, load) in results {
                loads.insert(url, load);
            }

            debug!("Worker loads updated: {:?}", loads);

            // Check if receiver is still active
            if tx.send(loads).is_err() {
                info!("Load monitor receiver dropped, shutting down monitor task");
                break;
            }

            tokio::time::sleep(Duration::from_secs(interval_secs)).await;
        }
    }
    // Simple helper to merge logprobs in streaming responses
    /// Prepend cached prefill `input_token_logprobs` into one SSE data chunk
    /// from the decode stream. Returns `Err(())` for anything that should be
    /// passed through unmodified (non-data chunks, `[DONE]`, parse failures).
    ///
    /// NOTE(review): this assumes each network chunk contains exactly one
    /// complete `data: ...` event; multi-event or split chunks fall through
    /// unmerged — confirm against the decode server's flushing behavior.
    fn merge_streaming_logprobs(
        prefill_logprobs: Option<Value>,
        decode_chunk: &[u8],
    ) -> Result<bytes::Bytes, ()> {
        // Skip non-data chunks
        let chunk_str = std::str::from_utf8(decode_chunk).map_err(|_| ())?;
        if !chunk_str.starts_with("data: ") || chunk_str.contains("[DONE]") {
            return Err(());
        }

        // Parse JSON from chunk
        let json_str = chunk_str.trim_start_matches("data: ").trim();
        let mut decode_json: Value = serde_json::from_str(json_str).map_err(|_| ())?;

        // Merge prefill logprobs if available (prefill tokens first)
        if let Some(ref p_logprobs) = prefill_logprobs {
            if let Some(meta) = decode_json.get_mut("meta_info") {
                if let Some(d_logprobs) = meta.get_mut("input_token_logprobs") {
                    if let (Some(p_arr), Some(d_arr)) =
                        (p_logprobs.as_array(), d_logprobs.as_array())
                    {
                        let mut merged = p_arr.clone();
                        merged.extend(d_arr.clone());
                        *d_logprobs = Value::Array(merged);
                    }
                }
            }
        }

        // Re-serialize back into SSE framing
        let merged_str = format!(
            "data: {}\n\n",
            serde_json::to_string(&decode_json).unwrap_or_default()
        );
        Ok(bytes::Bytes::from(merged_str))
    }
}
// Helper functions

/// Pick two worker indices uniformly at random from `0..len`.
///
/// When only one worker exists both picks collapse to index 0; otherwise the
/// second index is rejection-sampled until it differs from the first.
/// Precondition: `len >= 1` (a zero length would make `% len` panic).
fn get_two_random_indices(len: usize) -> (usize, usize) {
    if len == 1 {
        return (0, 0);
    }
    let first = rand::random::<usize>() % len;
    // Re-draw until the second pick lands on a different worker.
    let second = loop {
        let candidate = rand::random::<usize>() % len;
        if candidate != first {
            break candidate;
        }
    };
    (first, second)
}
/// Query a worker's `/get_load` endpoint and extract the reported load.
///
/// Every failure mode — transport error, non-2xx status, unreadable body,
/// unparsable JSON, missing/non-integer `load` field — is logged at debug
/// level and collapses to `None`.
async fn get_worker_load(client: &reqwest::Client, worker_url: &str) -> Option<isize> {
    let response = match client.get(format!("{}/get_load", worker_url)).send().await {
        Ok(r) => r,
        Err(e) => {
            debug!("Failed to get load from {}: {}", worker_url, e);
            return None;
        }
    };

    if !response.status().is_success() {
        debug!(
            "Worker {} returned non-success status: {}",
            worker_url,
            response.status()
        );
        return None;
    }

    let body = match response.bytes().await {
        Ok(b) => b,
        Err(e) => {
            debug!("Failed to read load response from {}: {}", worker_url, e);
            return None;
        }
    };

    let parsed = match serde_json::from_slice::<Value>(&body) {
        Ok(v) => v,
        Err(e) => {
            debug!("Failed to parse load response from {}: {}", worker_url, e);
            return None;
        }
    };

    parsed.get("load").and_then(|v| v.as_i64()).map(|v| v as isize)
}
// PD-specific endpoints
impl
PDRouter
{
    /// Fan out a `/health_generate` probe to every prefill and decode worker
    /// and aggregate the result.
    ///
    /// Returns 200 when every worker answers 2xx, otherwise 503 with the list
    /// of failing servers. The read guards on the worker lists are temporaries
    /// dropped at the end of each `for` statement, before any await point.
    pub async fn health_generate(&self, client: &reqwest::Client) -> HttpResponse {
        let mut all_healthy = true;
        let mut unhealthy_servers = Vec::new();

        // Collect all worker URLs with their types
        let mut worker_infos = Vec::new();
        for worker in self.prefill_workers.read().unwrap().iter() {
            worker_infos.push((worker.url.clone(), "prefill"));
        }
        for worker in self.decode_workers.read().unwrap().iter() {
            worker_infos.push((worker.url.clone(), "decode"));
        }

        // Create tasks with URL tracking; the futures are built eagerly here
        // and only driven concurrently by join_all below.
        let tasks: Vec<_> = worker_infos
            .iter()
            .map(|(url, _)| {
                let health_url = format!("{}/health_generate", url);
                client.get(&health_url).send()
            })
            .collect();

        let results = futures_util::future::join_all(tasks).await;

        // `tasks` was built in worker_infos order, so zip re-associates each
        // result with its (url, type) pair.
        for ((url, worker_type), result) in worker_infos.iter().zip(results.into_iter()) {
            match result {
                Ok(res) if res.status().is_success() => {
                    debug!("Health check passed for {} server: {}", worker_type, url);
                }
                Ok(res) => {
                    all_healthy = false;
                    let msg = format!("{} server {} returned status {}", worker_type, url, res.status());
                    error!("{}", msg);
                    unhealthy_servers.push(msg);
                }
                Err(e) => {
                    all_healthy = false;
                    let msg = format!("{} server {} error: {}", worker_type, url, e);
                    error!("{}", msg);
                    unhealthy_servers.push(msg);
                }
            }
        }

        if all_healthy {
            HttpResponse::Ok().body("Health check passed on all servers")
        } else {
            HttpResponse::ServiceUnavailable().body(format!("Health check failed: {:?}", unhealthy_servers))
        }
    }
    /// Collect `/get_server_info` from every decode server (where generation
    /// happens) and aggregate all of their `internal_states` entries into one
    /// response shaped for bench_one_batch_server.py.
    pub async fn get_server_info(&self, client: &reqwest::Client) -> HttpResponse {
        // Get info from all decode servers (where generation happens)
        let mut all_internal_states = Vec::new();
        let mut decode_infos = Vec::new();

        // Clone URLs to avoid holding lock across await
        let worker_urls: Vec<String> = self
            .decode_workers
            .read()
            .unwrap()
            .iter()
            .map(|w| w.url.clone())
            .collect();

        for worker_url in worker_urls {
            match client.get(format!("{}/get_server_info", worker_url)).send().await {
                Ok(res) if res.status().is_success() => {
                    match res.json::<Value>().await {
                        Ok(info) => {
                            // Extract internal_states from each decode server
                            if let Some(states) = info.get("internal_states") {
                                if let Some(states_array) = states.as_array() {
                                    all_internal_states.extend(states_array.clone());
                                }
                            }
                            decode_infos.push(info);
                        }
                        Err(e) => error!("Failed to parse server info: {}", e),
                    }
                }
                // Transport errors and non-2xx replies are silently skipped;
                // the fallback branch below covers the all-failed case.
                _ => {}
            }
        }

        // If we have internal states, return in the format expected by bench_one_batch_server.py
        if !all_internal_states.is_empty() {
            // Aggregated internal states from every reachable decode server.
            HttpResponse::Ok().json(serde_json::json!({
                "internal_states": all_internal_states,
                // Include original format for compatibility
                "decode_servers": decode_infos,
            }))
        } else {
            // Fallback: create a dummy internal_states entry
            HttpResponse::Ok().json(serde_json::json!({
                "internal_states": [{
                    "last_gen_throughput": 0.0,
                    "avg_spec_accept_length": null,
                }],
                "decode_servers": decode_infos,
            }))
        }
    }
    /// Proxy `GET /v1/models` to the first prefill worker, forwarding the
    /// caller's headers (minus content framing) and relaying the upstream
    /// status and body verbatim.
    pub async fn get_models(&self, client: &reqwest::Client, req: &HttpRequest) -> HttpResponse {
        // Get first prefill worker URL to avoid holding lock across await
        let first_worker_url = if let Ok(workers) = self.prefill_workers.read() {
            workers.first().map(|w| w.url.clone())
        } else {
            // A poisoned lock means another thread panicked while writing.
            return HttpResponse::InternalServerError().body("Failed to access prefill workers");
        };

        if let Some(worker_url) = first_worker_url {
            // Send request directly without going through Router
            let mut request_builder = client.get(format!("{}/v1/models", worker_url));
            for (name, value) in crate::router::copy_request_headers(req) {
                // Skip framing headers; reqwest computes them for this request.
                if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" {
                    request_builder = request_builder.header(name, value);
                }
            }
            match request_builder.send().await {
                Ok(res) => {
                    // Translate the reqwest status into an actix status,
                    // defaulting to 500 if the code is somehow invalid.
                    let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
                        .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
                    match res.bytes().await {
                        Ok(body) => HttpResponse::build(status).body(body.to_vec()),
                        Err(e) => HttpResponse::InternalServerError()
                            .body(format!("Failed to read response body: {}", e)),
                    }
                }
                Err(e) => HttpResponse::InternalServerError()
                    .body(format!("Failed to send request: {}", e)),
            }
        } else {
            HttpResponse::ServiceUnavailable().body("No prefill servers available")
        }
    }
    /// Report the current load of every prefill and decode worker as JSON:
    /// `{"prefill": [{"engine", "load"}...], "decode": [...]}`.
    ///
    /// Workers that fail the load query are reported with load -1. Workers
    /// are polled sequentially; URLs are cloned first so no lock is held
    /// across an await.
    pub async fn get_loads(&self, client: &reqwest::Client) -> HttpResponse {
        let p_urls: Vec<_> = self
            .prefill_workers
            .read()
            .unwrap()
            .iter()
            .map(|w| w.url.clone())
            .collect();
        let d_urls: Vec<_> = self
            .decode_workers
            .read()
            .unwrap()
            .iter()
            .map(|w| w.url.clone())
            .collect();

        let mut prefill_loads = Vec::new();
        let mut decode_loads = Vec::new();

        for url in &p_urls {
            // -1 marks an unreachable/unparsable worker.
            let load = get_worker_load(client, url).await.unwrap_or(-1);
            prefill_loads.push(serde_json::json!({
                "engine": format!("(Prefill@{})", url),
                "load": load as i64
            }));
        }
        for url in &d_urls {
            let load = get_worker_load(client, url).await.unwrap_or(-1);
            decode_loads.push(serde_json::json!({
                "engine": format!("(Decode@{})", url),
                "load": load as i64
            }));
        }

        HttpResponse::Ok().json(serde_json::json!({
            "prefill": prefill_loads,
            "decode": decode_loads
        }))
    }
    /// Proxy `GET /get_model_info` to the first prefill worker, forwarding
    /// the caller's headers (minus content framing) and relaying the
    /// upstream status and body verbatim.
    pub async fn get_model_info(
        &self,
        client: &reqwest::Client,
        req: &HttpRequest,
    ) -> HttpResponse {
        // Get model info from the first prefill server (matches original Rust PDLB behavior)
        // Get first prefill worker URL to avoid holding lock across await
        let first_worker_url = if let Ok(workers) = self.prefill_workers.read() {
            workers.first().map(|w| w.url.clone())
        } else {
            // A poisoned lock means another thread panicked while writing.
            return HttpResponse::InternalServerError().body("Failed to access prefill workers");
        };

        if let Some(worker_url) = first_worker_url {
            let mut request_builder = client.get(format!("{}/get_model_info", worker_url));
            for (name, value) in crate::router::copy_request_headers(req) {
                // Skip framing headers; reqwest computes them for this request.
                if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" {
                    request_builder = request_builder.header(name, value);
                }
            }
            match request_builder.send().await {
                Ok(res) => {
                    // Translate the reqwest status into an actix status,
                    // defaulting to 500 if the code is somehow invalid.
                    let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
                        .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
                    match res.bytes().await {
                        Ok(body) => HttpResponse::build(status).body(body.to_vec()),
                        Err(e) => HttpResponse::InternalServerError()
                            .body(format!("Failed to read response body: {}", e)),
                    }
                }
                Err(e) => HttpResponse::InternalServerError()
                    .body(format!("Failed to send request: {}", e)),
            }
        } else {
            HttpResponse::ServiceUnavailable().body("No prefill servers available")
        }
    }
    /// POST `/flush_cache` to every prefill and decode worker concurrently.
    ///
    /// Returns 200 only when every worker answers 2xx; any non-success status
    /// or transport error yields a 500. Failures are logged by the worker's
    /// index in the combined prefill+decode task list, not by URL.
    pub async fn flush_cache(&self, client: &reqwest::Client) -> HttpResponse {
        let mut tasks = Vec::new();

        // Flush cache on all prefill servers
        for worker in self.prefill_workers.read().unwrap().iter() {
            let url = format!("{}/flush_cache", worker.url);
            tasks.push(client.post(&url).send());
        }

        // Flush cache on all decode servers
        for worker in self.decode_workers.read().unwrap().iter() {
            let url = format!("{}/flush_cache", worker.url);
            tasks.push(client.post(&url).send());
        }

        // Drive all requests to completion together.
        let results = futures_util::future::join_all(tasks).await;

        let mut all_success = true;
        for (i, result) in results.into_iter().enumerate() {
            match result {
                Ok(res) if res.status().is_success() => {}
                Ok(res) => {
                    all_success = false;
                    warn!("Server {} returned status {} for flush_cache", i, res.status());
                }
                Err(e) => {
                    all_success = false;
                    error!("Server {} error during flush_cache: {}", i, e);
                }
            }
        }

        if all_success {
            HttpResponse::Ok().body("Cache flushed on all servers")
        } else {
            HttpResponse::InternalServerError().body("Cache flush failed on one or more servers")
        }
    }
}
sgl-router/src/pd_types.rs
0 → 100644
View file @
09ae5b20
// Essential PDLB types extracted for PD routing
use
serde
::{
Deserialize
,
Serialize
};
use
serde_json
::
Value
;
#[derive(Debug,
Clone)]
pub
enum
EngineType
{
Prefill
,
Decode
,
}
/// Descriptor of a single prefill or decode engine.
#[derive(Debug, Clone)]
pub struct EngineInfo {
    // Role of this engine in the PD split.
    pub engine_type: EngineType,
    // Base URL of the engine, e.g. "http://host:port".
    pub url: String,
    // Bootstrap port advertised by prefill engines; None for decode engines.
    pub bootstrap_port: Option<u16>,
}
impl EngineInfo {
    /// Build an `EngineInfo` for a prefill engine, optionally carrying the
    /// bootstrap port it advertises to decode peers.
    pub fn new_prefill(url: String, bootstrap_port: Option<u16>) -> Self {
        EngineInfo {
            engine_type: EngineType::Prefill,
            url,
            bootstrap_port,
        }
    }

    /// Build an `EngineInfo` for a decode engine (decode engines carry no
    /// bootstrap port).
    pub fn new_decode(url: String) -> Self {
        EngineInfo {
            engine_type: EngineType::Decode,
            url,
            bootstrap_port: None,
        }
    }

    /// Join the engine base URL with an API path, inserting a slash only
    /// when the path does not already start with one.
    pub fn api_path(&self, api_path: &str) -> String {
        if api_path.starts_with("/") {
            format!("{}{}", self.url, api_path)
        } else {
            format!("{}/{}", self.url, api_path)
        }
    }

    /// Extract the hostname from the engine URL without external dependencies.
    ///
    /// Fix: strip any path component before splitting off the port, so a
    /// port-less URL such as "http://host/path" yields "host" instead of
    /// "host/path" as the previous split-on-':' alone did.
    /// NOTE(review): bracketed IPv6 literals ("[::1]:8080") are still not
    /// handled — confirm engines are always addressed by hostname or IPv4.
    pub fn get_hostname(&self) -> String {
        // Simple hostname extraction without external dependencies
        let stripped = self
            .url
            .trim_start_matches("http://")
            .trim_start_matches("https://");
        // Authority is everything before the first '/', if any.
        let authority = stripped.split('/').next().unwrap_or(stripped);
        authority.split(':').next().unwrap_or("localhost").to_string()
    }
}
// PD-specific routing policies
#[derive(Debug,
Clone,
PartialEq)]
pub
enum
PDSelectionPolicy
{
Random
,
PowerOfTwo
,
CacheAware
{
cache_threshold
:
f32
,
balance_abs_threshold
:
usize
,
balance_rel_threshold
:
f32
,
},
}
// Bootstrap types from PDLB

/// A value that may arrive either as a scalar or as an array.
/// `#[serde(untagged)]` lets serde accept both JSON shapes transparently.
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum SingleOrBatch<T> {
    Single(T),
    Batch(Vec<T>),
}

// Token-id input: one sequence or a batch of sequences.
pub type InputIds = SingleOrBatch<Vec<i32>>;
// Text input: one prompt or a batch of prompts.
pub type InputText = SingleOrBatch<String>;
// Bootstrap hostname(s) of the chosen prefill engine.
pub type BootstrapHost = SingleOrBatch<String>;
// Bootstrap port(s); each entry may be absent.
pub type BootstrapPort = SingleOrBatch<Option<u16>>;
// Random room id(s) pairing a prefill request with its decode counterpart.
pub type BootstrapRoom = SingleOrBatch<u64>;
// Bootstrap trait for request handling
pub trait Bootstrap: Send + Sync {
    /// Whether the client asked for a streaming response.
    fn is_stream(&self) -> bool;

    /// Batch size implied by the request; `Ok(None)` for single requests.
    fn get_batch_size(&self) -> Result<Option<usize>, String>;

    /// Store the prefill bootstrap coordinates on the request.
    fn set_bootstrap_info(
        &mut self,
        bootstrap_host: BootstrapHost,
        bootstrap_port: BootstrapPort,
        bootstrap_room: BootstrapRoom,
    );

    /// Attach the chosen prefill engine's host/port plus freshly generated
    /// room ids, replicated per batch entry when the request is a batch.
    fn add_bootstrap_info(&mut self, prefill_info: &EngineInfo) -> Result<(), String> {
        let batch_size = self.get_batch_size()?;

        if let Some(batch_size) = batch_size {
            self.set_bootstrap_info(
                BootstrapHost::Batch(vec![prefill_info.get_hostname(); batch_size]),
                BootstrapPort::Batch(vec![prefill_info.bootstrap_port; batch_size]),
                // Use high-quality random numbers to minimize collision risk
                BootstrapRoom::Batch(
                    (0..batch_size)
                        .map(|_| {
                            // Combine multiple sources of randomness for better distribution
                            let r1 = rand::random::<u64>();
                            let r2 = rand::random::<u64>();
                            r1.wrapping_add(r2.rotate_left(32))
                        })
                        .collect(),
                ),
            );
        } else {
            self.set_bootstrap_info(
                BootstrapHost::Single(prefill_info.get_hostname()),
                BootstrapPort::Single(prefill_info.bootstrap_port),
                BootstrapRoom::Single({
                    // Use high-quality random number for single requests too
                    let r1 = rand::random::<u64>();
                    let r2 = rand::random::<u64>();
                    r1.wrapping_add(r2.rotate_left(32))
                }),
            );
        }
        Ok(())
    }
}
// Request types

/// PD-routable form of a `/generate` request.
#[derive(Debug, Deserialize, Serialize)]
pub struct GenerateReqInput {
    // Exactly one of `text` / `input_ids` should be set; both at once is
    // rejected by get_batch_size().
    pub text: Option<InputText>,
    pub input_ids: Option<InputIds>,
    // Streaming flag; defaults to false when absent from the JSON body.
    #[serde(default)]
    pub stream: bool,
    // Bootstrap coordinates injected by the router before dispatch.
    pub bootstrap_host: Option<BootstrapHost>,
    pub bootstrap_port: Option<BootstrapPort>,
    pub bootstrap_room: Option<BootstrapRoom>,
    // Every remaining request field passes through untouched.
    #[serde(flatten)]
    pub other: Value,
}
impl GenerateReqInput {
    /// Determine the batch size implied by the request inputs.
    ///
    /// Returns `Ok(None)` for single (non-batch) inputs, `Ok(Some(n))` for a
    /// batch of `n` entries, and `Err` when the inputs are ambiguous, empty,
    /// or exceed the production batch limit.
    pub fn get_batch_size(&self) -> Result<Option<usize>, String> {
        // Cap on batch entries accepted in production.
        const MAX_BATCH: usize = 10000;

        // Exactly one input form may be supplied.
        if self.text.is_some() && self.input_ids.is_some() {
            return Err("Both text and input_ids are present in the request".to_string());
        }

        // Batched text input.
        if let Some(InputText::Batch(entries)) = &self.text {
            return match entries.len() {
                0 => Err("Batch text array is empty".to_string()),
                n if n > MAX_BATCH => {
                    Err(format!("Batch size {} exceeds maximum allowed (10000)", n))
                }
                n => Ok(Some(n)),
            };
        }

        // Batched token-id input.
        if let Some(InputIds::Batch(sequences)) = &self.input_ids {
            if sequences.is_empty() {
                return Err("Batch input_ids array is empty".to_string());
            }
            if sequences.len() > MAX_BATCH {
                return Err(format!(
                    "Batch size {} exceeds maximum allowed (10000)",
                    sequences.len()
                ));
            }
            // Every token sequence must be non-empty.
            if let Some(bad_index) = sequences.iter().position(|seq| seq.is_empty()) {
                return Err(format!("Input sequence at index {} is empty", bad_index));
            }
            return Ok(Some(sequences.len()));
        }

        // Single text, single id sequence, or no input at all.
        Ok(None)
    }
}
impl Bootstrap for GenerateReqInput {
    fn is_stream(&self) -> bool {
        self.stream
    }

    // Delegates to the inherent method: Rust resolves `self.get_batch_size()`
    // to the inherent impl before the trait method, so this does not recurse.
    fn get_batch_size(&self) -> Result<Option<usize>, String> {
        self.get_batch_size()
    }

    fn set_bootstrap_info(
        &mut self,
        bootstrap_host: BootstrapHost,
        bootstrap_port: BootstrapPort,
        bootstrap_room: BootstrapRoom,
    ) {
        self.bootstrap_host = Some(bootstrap_host);
        self.bootstrap_port = Some(bootstrap_port);
        self.bootstrap_room = Some(bootstrap_room);
    }
}
/// PD-routable form of a chat-completion request. Unlike GenerateReqInput
/// the message payload itself stays inside the flattened `other` map.
#[derive(Debug, Deserialize, Serialize)]
pub struct ChatReqInput {
    // Streaming flag; defaults to false when absent from the JSON body.
    #[serde(default)]
    pub stream: bool,
    // Bootstrap coordinates injected by the router before dispatch.
    pub bootstrap_host: Option<BootstrapHost>,
    pub bootstrap_port: Option<BootstrapPort>,
    pub bootstrap_room: Option<BootstrapRoom>,
    // Every remaining request field passes through untouched.
    #[serde(flatten)]
    pub other: Value,
}
impl
Bootstrap
for
ChatReqInput
{
fn
is_stream
(
&
self
)
->
bool
{
self
.stream
}
fn
get_batch_size
(
&
self
)
->
Result
<
Option
<
usize
>
,
String
>
{
// Check if 'n' parameter is present and > 1
if
let
Some
(
n_value
)
=
self
.other
.get
(
"n"
)
{
if
let
Some
(
n
)
=
n_value
.as_u64
()
{
if
n
>
1
{
return
Ok
(
Some
(
n
as
usize
));
}
}
}
Ok
(
None
)
}
fn
set_bootstrap_info
(
&
mut
self
,
bootstrap_host
:
BootstrapHost
,
bootstrap_port
:
BootstrapPort
,
bootstrap_room
:
BootstrapRoom
,
)
{
self
.bootstrap_host
=
Some
(
bootstrap_host
);
self
.bootstrap_port
=
Some
(
bootstrap_port
);
self
.bootstrap_room
=
Some
(
bootstrap_room
);
}
}
sgl-router/src/request_adapter.rs
0 → 100644
View file @
09ae5b20
// Request adapter to bridge OpenAI API types with PD routing requirements
use
crate
::
openai_api_types
::{
ChatCompletionRequest
,
CompletionRequest
,
GenerateRequest
,
GenerationRequest
,
StringOrArray
,
};
use
crate
::
pd_types
::{
Bootstrap
,
ChatReqInput
,
GenerateReqInput
,
SingleOrBatch
};
use
serde_json
::
Value
;
/// Adapter trait to convert OpenAI requests to PD-compatible requests
pub trait ToPdRequest {
    /// PD-side request type produced by the conversion; must support the
    /// bootstrap handshake.
    type Output: Bootstrap;

    /// Consume the API request and produce its PD routing equivalent.
    fn to_pd_request(self) -> Self::Output;
}
// Helper macro to insert optional fields into a map
macro_rules!
insert_if_some
{
(
$map:expr
,
$
(
$field:expr
=>
$key:expr
),
*
$
(,)
?
)
=>
{
$
(
if
let
Some
(
value
)
=
$field
{
$map
.insert
(
$key
.to_string
(),
serde_json
::
to_value
(
value
)
.unwrap_or
(
Value
::
Null
));
}
)
*
};
}
// Helper macro for simple value insertions
macro_rules!
insert_value
{
(
$map:expr
,
$
(
$field:expr
=>
$key:expr
),
*
$
(,)
?
)
=>
{
$
(
$map
.insert
(
$key
.to_string
(),
$field
.into
());
)
*
};
}
// ============= Generate Request Adapter =============
impl ToPdRequest for GenerateRequest {
    type Output = GenerateReqInput;

    /// Convert an SGLang `/generate` request into the PD-routable form,
    /// hoisting commonly used sampling fields to the top level and
    /// flattening everything else into the `other` JSON map.
    fn to_pd_request(self) -> Self::Output {
        // Build the other fields first
        let mut other = serde_json::Map::new();

        // Handle text input - check in priority order: text (SGLang), prompt (OpenAI)
        let (text, input_ids) = if let Some(text_str) = self.text {
            // SGLang native format
            (Some(SingleOrBatch::Single(text_str)), None)
        } else if let Some(prompt) = self.prompt {
            // OpenAI style prompt
            let text = match prompt {
                StringOrArray::String(s) => Some(SingleOrBatch::Single(s)),
                StringOrArray::Array(v) => Some(SingleOrBatch::Batch(v)),
            };
            (text, None)
        } else if let Some(ids) = self.input_ids {
            // Input IDs case
            let input_ids = match ids {
                crate::openai_api_types::InputIds::Single(ids) => Some(SingleOrBatch::Single(ids)),
                crate::openai_api_types::InputIds::Batch(ids) => Some(SingleOrBatch::Batch(ids)),
            };
            (None, input_ids)
        } else {
            // No input provided
            (None, None)
        };

        // Add parameters to other - handle both old and new style
        if let Some(params) = self.parameters {
            // For generate endpoint, extract max_new_tokens to top level if present
            let mut params_value = serde_json::to_value(&params).unwrap_or(Value::Null);
            if let Value::Object(ref mut params_map) = params_value {
                // Move max_new_tokens to top level if it exists
                if let Some(max_new_tokens) = params_map.remove("max_new_tokens") {
                    other.insert("max_new_tokens".to_string(), max_new_tokens);
                }
                // Move temperature to top level if it exists
                if let Some(temperature) = params_map.remove("temperature") {
                    other.insert("temperature".to_string(), temperature);
                }
            }
            // Only add parameters if there are remaining fields
            if !params_value.is_null()
                && params_value.as_object().map_or(false, |m| !m.is_empty())
            {
                other.insert("parameters".to_string(), params_value);
            }
        }

        // Add sampling_params if present. Note: these inserts run after the
        // `parameters` hoisting above, so sampling_params' max_new_tokens /
        // temperature win when both blocks supply them (Map::insert replaces).
        if let Some(sampling_params) = self.sampling_params {
            let params_value = serde_json::to_value(&sampling_params).unwrap_or(Value::Null);
            if !params_value.is_null() {
                // Extract commonly used fields to top level (copied, not
                // removed — they also remain inside "sampling_params").
                if let Value::Object(ref params_map) = params_value {
                    if let Some(max_new_tokens) = params_map.get("max_new_tokens") {
                        other.insert("max_new_tokens".to_string(), max_new_tokens.clone());
                    }
                    if let Some(temperature) = params_map.get("temperature") {
                        other.insert("temperature".to_string(), temperature.clone());
                    }
                }
                other.insert("sampling_params".to_string(), params_value);
            }
        }

        // Add other fields.
        // NOTE(review): assumes `stream` and `return_logprob` convert via
        // Into<Value> — confirm their declared types in openai_api_types.
        insert_value!(other,
            self.stream => "stream",
            self.return_logprob => "return_logprob"
        );

        GenerateReqInput {
            text,
            input_ids,
            stream: self.stream,
            bootstrap_host: None,
            bootstrap_port: None,
            bootstrap_room: None,
            other: Value::Object(other),
        }
    }
}
// ============= Completion Request Adapter =============
impl ToPdRequest for CompletionRequest {
    type Output = GenerateReqInput;

    /// Convert an OpenAI `/v1/completions` request into the PD-routable
    /// generate form, translating OpenAI parameter names to SGLang ones.
    fn to_pd_request(self) -> Self::Output {
        // Convert CompletionRequest to GenerateReqInput
        let text = match self.prompt {
            StringOrArray::String(s) => Some(SingleOrBatch::Single(s)),
            StringOrArray::Array(v) => Some(SingleOrBatch::Batch(v)),
        };

        // Map OpenAI parameters to generate parameters
        let mut other = serde_json::Map::new();

        // Create parameters object
        let mut params = serde_json::Map::new();

        // Map OpenAI fields to internal parameter names.
        // NOTE(review): `n => "best_of"` and `logprobs => "top_n_tokens"`
        // are lossy renamings inherited from PDLB — confirm the backend
        // interprets them as intended.
        insert_if_some!(params,
            self.max_tokens => "max_new_tokens",
            self.temperature => "temperature",
            self.top_p => "top_p",
            self.n => "best_of",
            self.logprobs => "top_n_tokens",
            self.seed => "seed"
        );

        // Special handling for fields that need transformation.
        // NOTE(review): presence_penalty is additive in the OpenAI API while
        // repetition_penalty is multiplicative; 1.0 + presence is a heuristic
        // mapping, not an exact equivalence.
        if let Some(presence_penalty) = self.presence_penalty {
            params.insert(
                "repetition_penalty".to_string(),
                (1.0 + presence_penalty).into(),
            );
        }

        // Normalize `stop` to a list of stop sequences.
        if let Some(stop) = self.stop {
            let stop_sequences = match stop {
                StringOrArray::String(s) => vec![s],
                StringOrArray::Array(v) => v,
            };
            params.insert("stop".to_string(), stop_sequences.into());
        }

        // OpenAI `echo` maps onto returning the prompt with the output.
        if self.echo {
            params.insert("return_full_text".to_string(), true.into());
        }

        other.insert("parameters".to_string(), Value::Object(params));

        // Store original model and stream flag
        insert_value!(other,
            self.model => "model",
            self.stream => "stream"
        );

        GenerateReqInput {
            text,
            input_ids: None,
            stream: self.stream,
            bootstrap_host: None,
            bootstrap_port: None,
            bootstrap_room: None,
            other: Value::Object(other),
        }
    }
}
// ============= Chat Completion Request Adapter =============
impl ToPdRequest for ChatCompletionRequest {
    type Output = ChatReqInput;

    /// Convert an OpenAI `/v1/chat/completions` request into the PD-routable
    /// chat form. All fields are carried through inside the flattened `other`
    /// map; only `stream` is mirrored as a typed field.
    fn to_pd_request(self) -> Self::Output {
        let mut other = serde_json::Map::new();

        // Add required fields. Wrapping messages in Some(...) reuses the
        // Option-aware macro for an always-present field.
        insert_if_some!(other, Some(&self.messages) => "messages");
        insert_value!(other,
            self.model => "model",
            self.stream => "stream"
        );

        // Add all optional fields (passed through under their OpenAI names).
        insert_if_some!(other,
            self.temperature => "temperature",
            self.top_p => "top_p",
            self.n => "n",
            self.stop => "stop",
            self.max_tokens => "max_tokens",
            self.max_completion_tokens => "max_completion_tokens",
            self.presence_penalty => "presence_penalty",
            self.frequency_penalty => "frequency_penalty",
            self.logit_bias => "logit_bias",
            self.user => "user",
            self.seed => "seed",
            self.top_logprobs => "top_logprobs",
            self.response_format => "response_format",
            self.tools => "tools",
            self.tool_choice => "tool_choice",
            self.parallel_tool_calls => "parallel_tool_calls",
            self.functions => "functions",
            self.function_call => "function_call"
        );

        // Handle boolean logprobs flag: only emitted when true.
        if self.logprobs {
            other.insert("logprobs".to_string(), true.into());
        }

        ChatReqInput {
            stream: self.stream,
            bootstrap_host: None,
            bootstrap_port: None,
            bootstrap_room: None,
            other: Value::Object(other),
        }
    }
}
// ============= Direct routing support for regular router =============
/// Extension trait for routing without PD conversion
pub
trait
RouteableRequest
:
GenerationRequest
+
serde
::
Serialize
+
Clone
{
/// Convert to JSON for sending to backend
fn
to_json
(
&
self
)
->
Result
<
Value
,
serde_json
::
Error
>
{
serde_json
::
to_value
(
self
)
}
/// Convert to bytes for legacy routing
fn
to_bytes
(
&
self
)
->
Result
<
bytes
::
Bytes
,
serde_json
::
Error
>
{
let
json
=
serde_json
::
to_vec
(
self
)
?
;
Ok
(
bytes
::
Bytes
::
from
(
json
))
}
}
// Marker impls: these request types route via the trait's default methods.
impl RouteableRequest for GenerateRequest {}
impl RouteableRequest for CompletionRequest {}
impl RouteableRequest for ChatCompletionRequest {}
sgl-router/src/router.rs
View file @
09ae5b20
use
crate
::
pd_router
::
PDRouter
;
use
crate
::
pd_types
::
PDSelectionPolicy
;
use
crate
::
tree
::
Tree
;
use
::
metrics
::{
counter
,
gauge
,
histogram
};
use
actix_web
::
http
::
header
::{
HeaderValue
,
CONTENT_TYPE
};
use
actix_web
::{
HttpRequest
,
HttpResponse
};
use
bytes
::
Bytes
;
use
futures_util
::{
StreamExt
,
TryStreamExt
};
use
serde_json
::
Value
;
use
std
::
collections
::
HashMap
;
use
std
::
fmt
::
Debug
;
use
std
::
sync
::
atomic
::
AtomicUsize
;
...
...
@@ -15,7 +15,7 @@ use std::time::Instant;
use
tokio
;
use
tracing
::{
debug
,
error
,
info
,
warn
};
fn
copy_request_headers
(
req
:
&
HttpRequest
)
->
Vec
<
(
String
,
String
)
>
{
pub
fn
copy_request_headers
(
req
:
&
HttpRequest
)
->
Vec
<
(
String
,
String
)
>
{
req
.headers
()
.iter
()
.filter_map
(|(
name
,
value
)|
{
...
...
@@ -40,6 +40,9 @@ pub enum Router {
timeout_secs
:
u64
,
interval_secs
:
u64
,
},
PrefillDecode
{
pd_router
:
Arc
<
PDRouter
>
,
},
CacheAware
{
/*
Cache-Aware Load Balancing Router
...
...
@@ -133,6 +136,13 @@ pub enum PolicyConfig {
timeout_secs
:
u64
,
interval_secs
:
u64
,
},
PrefillDecodeConfig
{
selection_policy
:
PDSelectionPolicy
,
prefill_urls
:
Vec
<
(
String
,
Option
<
u16
>
)
>
,
// (url, bootstrap_port)
decode_urls
:
Vec
<
String
>
,
timeout_secs
:
u64
,
interval_secs
:
u64
,
},
}
impl
Router
{
...
...
@@ -155,10 +165,24 @@ impl Router {
interval_secs
,
..
}
=>
(
*
timeout_secs
,
*
interval_secs
),
PolicyConfig
::
PrefillDecodeConfig
{
timeout_secs
,
interval_secs
,
..
}
=>
(
*
timeout_secs
,
*
interval_secs
),
};
// Wait until all workers are healthy
Self
::
wait_for_healthy_workers
(
&
worker_urls
,
timeout_secs
,
interval_secs
)
?
;
// For PrefillDecode, we need to handle workers differently
match
&
policy_config
{
PolicyConfig
::
PrefillDecodeConfig
{
..
}
=>
{
// PD mode doesn't use the worker_urls parameter
// We'll validate PD workers separately
}
_
=>
{
// Wait until all workers are healthy for regular modes
Self
::
wait_for_healthy_workers
(
&
worker_urls
,
timeout_secs
,
interval_secs
)
?
;
}
}
// Create router based on policy...
Ok
(
match
policy_config
{
...
...
@@ -226,7 +250,7 @@ impl Router {
});
for
url
in
&
worker_urls
{
tree
.lock
()
.unwrap
()
.insert
(
&
""
.to_string
()
,
url
);
tree
.lock
()
.unwrap
()
.insert
(
""
,
url
);
}
Router
::
CacheAware
{
...
...
@@ -242,6 +266,26 @@ impl Router {
_
eviction_thread
:
Some
(
eviction_thread
),
}
}
PolicyConfig
::
PrefillDecodeConfig
{
selection_policy
,
prefill_urls
,
decode_urls
,
timeout_secs
,
interval_secs
,
}
=>
{
// Create PDRouter instance
let
pd_router
=
PDRouter
::
new
(
prefill_urls
,
decode_urls
,
selection_policy
,
timeout_secs
,
interval_secs
,
)
?
;
Router
::
PrefillDecode
{
pd_router
:
Arc
::
new
(
pd_router
),
}
}
})
}
...
...
@@ -251,16 +295,23 @@ impl Router {
Router
::
RoundRobin
{
worker_urls
,
..
}
=>
Arc
::
clone
(
worker_urls
),
Router
::
Random
{
worker_urls
,
..
}
=>
Arc
::
clone
(
worker_urls
),
Router
::
CacheAware
{
worker_urls
,
..
}
=>
Arc
::
clone
(
worker_urls
),
Router
::
PrefillDecode
{
..
}
=>
{
// For PD mode, return empty list since we manage workers differently
Arc
::
new
(
RwLock
::
new
(
Vec
::
new
()))
}
}
}
fn
wait_for_healthy_workers
(
pub
fn
wait_for_healthy_workers
(
worker_urls
:
&
[
String
],
timeout_secs
:
u64
,
interval_secs
:
u64
,
)
->
Result
<
(),
String
>
{
let
start_time
=
std
::
time
::
Instant
::
now
();
let
sync_client
=
reqwest
::
blocking
::
Client
::
new
();
let
sync_client
=
reqwest
::
blocking
::
Client
::
builder
()
.timeout
(
Duration
::
from_secs
(
timeout_secs
))
.build
()
.map_err
(|
e
|
format!
(
"Failed to create HTTP client: {}"
,
e
))
?
;
loop
{
if
start_time
.elapsed
()
>
Duration
::
from_secs
(
timeout_secs
)
{
...
...
@@ -323,10 +374,14 @@ impl Router {
Ok
(
worker_urls
.read
()
.unwrap
()[
0
]
.clone
())
}
}
Router
::
PrefillDecode
{
..
}
=>
{
// For PD mode, we don't need this method as routing is handled by PDRouter
Err
(
"PrefillDecode mode doesn't use select_first_worker"
.to_string
())
}
}
}
async
fn
send_request
(
pub
async
fn
send_request
(
&
self
,
client
:
&
reqwest
::
Client
,
worker_url
:
&
str
,
...
...
@@ -339,7 +394,11 @@ impl Router {
// Copy all headers from original request except for /health because it does not need authorization
if
route
!=
"/health"
{
for
(
name
,
value
)
in
copy_request_headers
(
req
)
{
request_builder
=
request_builder
.header
(
name
,
value
);
// Skip Content-Type and Content-Length as .json() sets them
if
name
.to_lowercase
()
!=
"content-type"
&&
name
.to_lowercase
()
!=
"content-length"
{
request_builder
=
request_builder
.header
(
name
,
value
);
}
}
}
...
...
@@ -433,50 +492,193 @@ impl Router {
HttpResponse
::
InternalServerError
()
.body
(
"All retry attempts failed"
)
}
fn
get_text_from_request
(
&
self
,
body
:
&
Bytes
,
route
:
&
str
)
->
String
{
// Convert body to JSON
let
json
:
Value
=
match
serde_json
::
from_slice
(
body
)
{
Ok
(
j
)
=>
j
,
Err
(
_
)
=>
{
warn!
(
"Failed to parse JSON from request body."
);
return
String
::
new
();
pub
async
fn
route_to_all
(
&
self
,
client
:
&
reqwest
::
Client
,
route
:
&
str
,
req
:
&
HttpRequest
,
)
->
HttpResponse
{
// Get all worker URLs based on router type
let
worker_urls
=
match
self
{
Router
::
PrefillDecode
{
..
}
=>
{
// For PD mode, route_to_all is not supported directly
// It should be handled by PDRouter if needed
return
HttpResponse
::
NotImplemented
()
.body
(
"route_to_all not implemented for PrefillDecode mode"
);
}
_
=>
self
.get_worker_urls
()
.read
()
.unwrap
()
.clone
(),
};
match
route
{
"/generate"
=>
{
// For /generate, always use the "text" field.
match
json
.get
(
"text"
)
.and_then
(
Value
::
as_str
)
{
Some
(
text
)
=>
text
.to_string
(),
None
=>
{
warn!
(
"No 'text' field found in request body for route /generate."
);
String
::
new
()
}
}
// Send requests to all workers concurrently
let
mut
tasks
=
Vec
::
new
();
for
worker_url
in
&
worker_urls
{
let
mut
request_builder
=
client
.post
(
format!
(
"{}{}"
,
worker_url
,
route
));
// Copy headers from original request
for
(
name
,
value
)
in
copy_request_headers
(
req
)
{
request_builder
=
request_builder
.header
(
name
,
value
);
}
"/v1/chat/completions"
|
"/v1/completions"
=>
{
// For these routes, try "messages", then "prompt", then "text".
if
let
Some
(
messages
)
=
json
.get
(
"messages"
)
{
serde_json
::
to_string
(
messages
)
.unwrap_or_default
()
}
else
if
let
Some
(
prompt
)
=
json
.get
(
"prompt"
)
.and_then
(
Value
::
as_str
)
{
prompt
.to_string
()
}
else
{
warn!
(
"Failed to find 'messages', 'prompt' in request body."
);
String
::
new
()
}
tasks
.push
(
request_builder
.send
());
}
// Wait for all responses
let
results
=
futures_util
::
future
::
join_all
(
tasks
)
.await
;
// Check if all succeeded
let
all_success
=
results
.iter
()
.all
(|
r
|
{
r
.as_ref
()
.map
(|
res
|
res
.status
()
.is_success
())
.unwrap_or
(
false
)
});
if
all_success
{
HttpResponse
::
Ok
()
.body
(
"Operation completed on all servers"
)
}
else
{
HttpResponse
::
InternalServerError
()
.body
(
"Operation failed on one or more servers"
)
}
}
pub
async
fn
get_all_loads
(
&
self
,
client
:
&
reqwest
::
Client
,
_
req
:
&
HttpRequest
,
)
->
HttpResponse
{
// For PD mode, delegate to PDRouter
match
self
{
Router
::
PrefillDecode
{
pd_router
}
=>
{
return
pd_router
.get_loads
(
client
)
.await
;
}
_
=>
{
warn!
(
"Unknown route: {} - defaulting to fallback string"
,
route
);
String
::
new
()
// For non-PD routers, handle normally
}
}
let
urls
=
self
.get_worker_urls
()
.read
()
.unwrap
()
.clone
();
let
prefill_urls
:
Vec
<
String
>
=
Vec
::
new
();
let
decode_urls
=
urls
;
// Collect loads from all servers
let
mut
prefill_loads
=
Vec
::
new
();
let
mut
decode_loads
=
Vec
::
new
();
// Get prefill loads
for
url
in
&
prefill_urls
{
let
load
=
self
.get_worker_load
(
client
,
url
)
.await
.unwrap_or
(
-
1
);
prefill_loads
.push
(
serde_json
::
json!
({
"engine"
:
format!
(
"(Prefill@{})"
,
url
),
"load"
:
load
as
i64
}));
}
// Get decode loads
for
url
in
&
decode_urls
{
let
load
=
self
.get_worker_load
(
client
,
url
)
.await
.unwrap_or
(
-
1
);
decode_loads
.push
(
serde_json
::
json!
({
"engine"
:
format!
(
"(Decode@{})"
,
url
),
"load"
:
load
as
i64
}));
}
HttpResponse
::
Ok
()
.json
(
serde_json
::
json!
({
"prefill"
:
prefill_loads
,
"decode"
:
decode_loads
}))
}
// TODO: return Result<String, String> instead of panicking
fn
select_generate_worker
(
&
self
,
body
:
&
Bytes
,
route
:
&
str
)
->
String
{
let
text
=
self
.get_text_from_request
(
&
body
,
route
);
// New method to route typed requests directly
pub
async
fn
route_typed_request
<
T
:
crate
::
openai_api_types
::
GenerationRequest
+
serde
::
Serialize
+
Clone
,
>
(
&
self
,
client
:
&
reqwest
::
Client
,
req
:
&
HttpRequest
,
typed_req
:
&
T
,
route
:
&
str
,
)
->
HttpResponse
{
match
self
{
Router
::
PrefillDecode
{
..
}
=>
HttpResponse
::
InternalServerError
()
.body
(
"PD routing should use specialized typed handlers"
),
_
=>
{
// Handle retries like the original implementation
let
start
=
Instant
::
now
();
const
MAX_REQUEST_RETRIES
:
u32
=
3
;
const
MAX_TOTAL_RETRIES
:
u32
=
6
;
let
mut
total_retries
=
0
;
while
total_retries
<
MAX_TOTAL_RETRIES
{
// Extract routing text directly from typed request
let
text
=
typed_req
.extract_text_for_routing
();
let
is_stream
=
typed_req
.is_stream
();
// Select worker based on text
let
worker_url
=
self
.select_generate_worker_from_text
(
&
text
);
let
mut
request_retries
=
0
;
// Try the same worker multiple times
while
request_retries
<
MAX_REQUEST_RETRIES
{
if
total_retries
>=
1
{
info!
(
"Retrying request after {} failed attempts"
,
total_retries
);
counter!
(
"sgl_router_retries_total"
,
"route"
=>
route
.to_string
())
.increment
(
1
);
}
// Send typed request directly
let
response
=
self
.send_typed_request
(
client
,
req
,
typed_req
,
route
,
&
worker_url
,
is_stream
,
)
.await
;
if
response
.status
()
.is_success
()
{
let
duration
=
start
.elapsed
();
histogram!
(
"sgl_router_generate_duration_seconds"
,
"route"
=>
route
.to_string
())
.record
(
duration
.as_secs_f64
());
return
response
;
}
else
{
// if the worker is healthy, it means the request is bad, so return the error response
let
health_response
=
self
.send_request
(
client
,
&
worker_url
,
"/health"
,
req
)
.await
;
if
health_response
.status
()
.is_success
()
{
counter!
(
"sgl_router_request_errors_total"
,
"route"
=>
route
.to_string
())
.increment
(
1
);
return
response
;
}
}
warn!
(
"Generate request to {} failed (attempt {}/{})"
,
worker_url
,
request_retries
+
1
,
MAX_REQUEST_RETRIES
);
request_retries
+=
1
;
total_retries
+=
1
;
let
worker_url
=
match
self
{
if
request_retries
==
MAX_REQUEST_RETRIES
{
warn!
(
"Removing failed worker: {}"
,
worker_url
);
self
.remove_worker
(
&
worker_url
);
break
;
}
}
}
counter!
(
"sgl_router_request_errors_total"
,
"route"
=>
route
.to_string
())
.increment
(
1
);
HttpResponse
::
InternalServerError
()
.body
(
"All retry attempts failed"
)
}
}
}
// Helper method to select worker from text
fn
select_generate_worker_from_text
(
&
self
,
text
:
&
str
)
->
String
{
match
self
{
Router
::
RoundRobin
{
worker_urls
,
current_index
,
...
...
@@ -506,8 +708,6 @@ impl Router {
balance_rel_threshold
,
..
}
=>
{
// TODO: delay scheduling if cache hit rate is high because it may cause imbalance. prioritize low hit rate ones
let
tree
=
tree
.lock
()
.unwrap
();
let
mut
running_queue
=
running_queue
.lock
()
.unwrap
();
...
...
@@ -572,35 +772,48 @@ impl Router {
selected_url
}
};
worker_url
Router
::
PrefillDecode
{
..
}
=>
{
// For PD mode, we don't use this method
return
"PD_MODE_ERROR"
.to_string
();
}
}
}
async
fn
send_generate_request
(
// Send typed request directly without conversion
async
fn
send_typed_request
<
T
:
serde
::
Serialize
>
(
&
self
,
client
:
&
reqwest
::
Client
,
req
:
&
HttpRequest
,
body
:
&
Bytes
,
typed_req
:
&
T
,
route
:
&
str
,
worker_url
:
&
str
,
is_stream
:
bool
,
)
->
HttpResponse
{
let
is_stream
=
serde_json
::
from_slice
::
<
serde_json
::
Value
>
(
&
body
)
.map
(|
v
|
v
.get
(
"stream"
)
.and_then
(|
s
|
s
.as_bool
())
.unwrap_or
(
false
))
.unwrap_or
(
false
);
let
start
=
Instant
::
now
();
// Debug: Log what we're sending
if
let
Ok
(
json_str
)
=
serde_json
::
to_string_pretty
(
typed_req
)
{
debug!
(
"Sending request to {}: {}"
,
route
,
json_str
);
}
let
mut
request_builder
=
client
.post
(
format!
(
"{}{}"
,
worker_url
,
route
))
.
body
(
body
.to_vec
());
.
json
(
typed_req
);
// Use json() directly with typed request
// Copy all headers from original request
for
(
name
,
value
)
in
copy_request_headers
(
req
)
{
request_builder
=
request_builder
.header
(
name
,
value
);
// Skip Content-Type and Content-Length as .json() sets them
if
name
.to_lowercase
()
!=
"content-type"
&&
name
.to_lowercase
()
!=
"content-length"
{
request_builder
=
request_builder
.header
(
&
name
,
&
value
);
}
}
let
res
=
match
request_builder
.send
()
.await
{
Ok
(
res
)
=>
res
,
Err
(
_
)
=>
return
HttpResponse
::
InternalServerError
()
.finish
(),
Err
(
e
)
=>
{
error!
(
"Failed to send request to {}: {}"
,
worker_url
,
e
);
return
HttpResponse
::
InternalServerError
()
.body
(
format!
(
"Request failed: {}"
,
e
));
}
};
let
status
=
actix_web
::
http
::
StatusCode
::
from_u16
(
res
.status
()
.as_u16
())
...
...
@@ -625,6 +838,12 @@ impl Router {
}
}
// Record metrics
let
duration
=
start
.elapsed
();
histogram!
(
"sgl_router_generate_duration_seconds"
,
"route"
=>
route
.to_string
())
.record
(
duration
.as_secs_f64
());
counter!
(
"sgl_router_requests_total"
,
"route"
=>
route
.to_string
())
.increment
(
1
);
response
}
else
if
let
Router
::
CacheAware
{
running_queue
,
..
}
=
self
{
let
running_queue
=
Arc
::
clone
(
running_queue
);
...
...
@@ -660,70 +879,6 @@ impl Router {
}
}
/// Route an untyped (raw-bytes) generate-style request to a selected worker,
/// retrying on failure.
///
/// Same retry policy as `route_typed_request`: at most MAX_REQUEST_RETRIES
/// attempts per worker and MAX_TOTAL_RETRIES overall; a worker that fails
/// its whole per-worker budget is removed. A failure from a worker whose
/// /health probe still succeeds is assumed to be the caller's fault and the
/// error response is forwarded as-is.
pub async fn route_generate_request(
    &self,
    client: &reqwest::Client,
    req: &HttpRequest,
    body: &Bytes,
    route: &str,
) -> HttpResponse {
    let start = Instant::now();
    const MAX_REQUEST_RETRIES: u32 = 3;
    const MAX_TOTAL_RETRIES: u32 = 6;
    let mut total_retries = 0;

    while total_retries < MAX_TOTAL_RETRIES {
        // Worker selection re-runs each outer iteration so a removed worker
        // is not picked again.
        let worker_url = self.select_generate_worker(body, route);
        let mut request_retries = 0;

        // Try the same worker multiple times
        while request_retries < MAX_REQUEST_RETRIES {
            if total_retries >= 1 {
                info!("Retrying request after {} failed attempts", total_retries);
                counter!("sgl_router_retries_total", "route" => route.to_string())
                    .increment(1);
            }
            let response = self
                .send_generate_request(client, req, body, route, &worker_url)
                .await;

            if response.status().is_success() {
                let duration = start.elapsed();
                histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string())
                    .record(duration.as_secs_f64());
                return response;
            } else {
                // if the worker is healthy, it means the request is bad, so return the error response
                let health_response =
                    self.send_request(client, &worker_url, "/health", req).await;
                if health_response.status().is_success() {
                    counter!("sgl_router_request_errors_total", "route" => route.to_string())
                        .increment(1);
                    return response;
                }
            }

            warn!(
                "Generate request to {} failed (attempt {}/{})",
                worker_url,
                request_retries + 1,
                MAX_REQUEST_RETRIES
            );

            request_retries += 1;
            total_retries += 1;

            if request_retries == MAX_REQUEST_RETRIES {
                // Worker failed every attempt: drop it from rotation.
                warn!("Removing failed worker: {}", worker_url);
                self.remove_worker(&worker_url);
                break;
            }
        }
    }

    counter!("sgl_router_request_errors_total", "route" => route.to_string()).increment(1);
    HttpResponse::InternalServerError().body("All retry attempts failed")
}
pub
async
fn
add_worker
(
&
self
,
worker_url
:
&
str
)
->
Result
<
String
,
String
>
{
let
(
timeout_secs
,
interval_secs
)
=
match
self
{
Router
::
Random
{
...
...
@@ -741,10 +896,17 @@ impl Router {
interval_secs
,
..
}
=>
(
*
timeout_secs
,
*
interval_secs
),
Router
::
PrefillDecode
{
..
}
=>
{
// For PD mode, we don't support adding workers via this method
return
Err
(
"Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods."
.to_string
());
}
};
let
start_time
=
std
::
time
::
Instant
::
now
();
let
client
=
reqwest
::
Client
::
new
();
let
client
=
reqwest
::
Client
::
builder
()
.timeout
(
Duration
::
from_secs
(
timeout_secs
))
.build
()
.map_err
(|
e
|
format!
(
"Failed to create HTTP client: {}"
,
e
))
?
;
loop
{
if
start_time
.elapsed
()
>
Duration
::
from_secs
(
timeout_secs
)
{
...
...
@@ -774,6 +936,9 @@ impl Router {
urls
.push
(
worker_url
.to_string
());
gauge!
(
"sgl_router_active_workers"
)
.set
(
urls
.len
()
as
f64
);
}
Router
::
PrefillDecode
{
..
}
=>
{
return
Err
(
"Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods."
.to_string
());
}
}
// If cache aware, initialize the queues for the new worker
...
...
@@ -797,7 +962,7 @@ impl Router {
.insert
(
worker_url
.to_string
(),
0
);
// Add worker to tree
tree
.lock
()
.unwrap
()
.insert
(
&
""
.to_string
()
,
&
worker_url
);
tree
.lock
()
.unwrap
()
.insert
(
""
,
worker_url
);
}
return
Ok
(
format!
(
"Successfully added worker: {}"
,
worker_url
));
...
...
@@ -850,6 +1015,10 @@ impl Router {
return
;
}
}
Router
::
PrefillDecode
{
..
}
=>
{
warn!
(
"Removing workers from PrefillDecode router not supported via remove_worker. Use dedicated PD management methods."
);
return
;
}
}
// if cache aware, remove the worker from the tree
...
...
@@ -875,4 +1044,133 @@ impl Router {
);
}
}
/// Query a single worker's `/get_load` endpoint and return its reported
/// "load" value, or `None` on any failure (network error, non-2xx status,
/// unreadable body, unparsable JSON, or missing/non-integer "load" field).
/// All failure modes are logged at debug level only, since callers treat
/// `None` as a soft "unknown" (e.g. `get_all_loads` maps it to -1).
async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option<isize> {
    match client.get(&format!("{}/get_load", worker_url)).send().await {
        // Success path: drill down body bytes -> JSON -> "load" field.
        Ok(res) if res.status().is_success() => match res.bytes().await {
            Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
                Ok(data) => data
                    .get("load")
                    .and_then(|v| v.as_i64())
                    .map(|v| v as isize),
                Err(e) => {
                    debug!("Failed to parse load response from {}: {}", worker_url, e);
                    None
                }
            },
            Err(e) => {
                debug!("Failed to read load response from {}: {}", worker_url, e);
                None
            }
        },
        // Reached the worker but it answered with an error status.
        Ok(res) => {
            debug!(
                "Worker {} returned non-success status: {}",
                worker_url,
                res.status()
            );
            None
        }
        // Transport-level failure (connection refused, timeout, ...).
        Err(e) => {
            debug!("Failed to get load from {}: {}", worker_url, e);
            None
        }
    }
}
// PD-specific wrapper methods that delegate to PDRouter
pub
async
fn
route_pd_health_generate
(
&
self
,
_
client
:
&
reqwest
::
Client
,
_
req
:
&
HttpRequest
,
)
->
HttpResponse
{
match
self
{
Router
::
PrefillDecode
{
pd_router
}
=>
{
pd_router
.health_generate
(
&
pd_router
.http_client
)
.await
}
_
=>
HttpResponse
::
InternalServerError
()
.body
(
"Not in PrefillDecode mode"
),
}
}
pub
async
fn
route_pd_generate_typed
(
&
self
,
_
client
:
&
reqwest
::
Client
,
req
:
&
HttpRequest
,
typed_req
:
crate
::
pd_types
::
GenerateReqInput
,
route
:
&
str
,
)
->
HttpResponse
{
match
self
{
Router
::
PrefillDecode
{
pd_router
}
=>
{
pd_router
.route_generate
(
&
pd_router
.http_client
,
req
,
typed_req
,
route
)
.await
}
_
=>
HttpResponse
::
InternalServerError
()
.body
(
"Not in PrefillDecode mode"
),
}
}
pub
async
fn
route_pd_chat_typed
(
&
self
,
_
client
:
&
reqwest
::
Client
,
req
:
&
HttpRequest
,
typed_req
:
crate
::
pd_types
::
ChatReqInput
,
route
:
&
str
,
)
->
HttpResponse
{
match
self
{
Router
::
PrefillDecode
{
pd_router
}
=>
{
pd_router
.route_chat
(
&
pd_router
.http_client
,
req
,
typed_req
,
route
)
.await
}
_
=>
HttpResponse
::
InternalServerError
()
.body
(
"Not in PrefillDecode mode"
),
}
}
pub
async
fn
get_pd_server_info
(
&
self
,
_
client
:
&
reqwest
::
Client
,
_
req
:
&
HttpRequest
,
)
->
HttpResponse
{
match
self
{
Router
::
PrefillDecode
{
pd_router
}
=>
{
pd_router
.get_server_info
(
&
pd_router
.http_client
)
.await
}
_
=>
HttpResponse
::
InternalServerError
()
.body
(
"Not in PrefillDecode mode"
),
}
}
pub
async
fn
get_pd_models
(
&
self
,
_
client
:
&
reqwest
::
Client
,
req
:
&
HttpRequest
,
)
->
HttpResponse
{
match
self
{
Router
::
PrefillDecode
{
pd_router
}
=>
{
pd_router
.get_models
(
&
pd_router
.http_client
,
req
)
.await
}
_
=>
HttpResponse
::
InternalServerError
()
.body
(
"Not in PrefillDecode mode"
),
}
}
pub
async
fn
route_pd_flush_cache
(
&
self
,
_
client
:
&
reqwest
::
Client
)
->
HttpResponse
{
match
self
{
Router
::
PrefillDecode
{
pd_router
}
=>
{
pd_router
.flush_cache
(
&
pd_router
.http_client
)
.await
}
_
=>
HttpResponse
::
InternalServerError
()
.body
(
"Not in PrefillDecode mode"
),
}
}
pub
async
fn
get_pd_model_info
(
&
self
,
_
client
:
&
reqwest
::
Client
,
req
:
&
HttpRequest
,
)
->
HttpResponse
{
match
self
{
Router
::
PrefillDecode
{
pd_router
}
=>
{
pd_router
.get_model_info
(
&
pd_router
.http_client
,
req
)
.await
}
_
=>
HttpResponse
::
InternalServerError
()
.body
(
"Not in PrefillDecode mode"
),
}
}
}
sgl-router/src/server.rs
View file @
09ae5b20
use
crate
::
logging
::{
self
,
LoggingConfig
};
use
crate
::
openai_api_types
::{
ChatCompletionRequest
,
CompletionRequest
,
GenerateRequest
};
use
crate
::
prometheus
::{
self
,
PrometheusConfig
};
use
crate
::
request_adapter
::
ToPdRequest
;
use
crate
::
router
::
PolicyConfig
;
use
crate
::
router
::
Router
;
use
crate
::
service_discovery
::{
start_service_discovery
,
ServiceDiscoveryConfig
};
use
actix_web
::{
error
,
get
,
post
,
web
,
App
,
Error
,
HttpRequest
,
HttpResponse
,
HttpServer
,
Responder
,
};
use
bytes
::
Bytes
;
use
futures_util
::
StreamExt
;
use
reqwest
::
Client
;
use
std
::
collections
::
HashMap
;
...
...
@@ -20,6 +21,7 @@ use tracing::{error, info, warn, Level};
pub
struct
AppState
{
router
:
Arc
<
Router
>
,
client
:
Client
,
is_pd_mode
:
bool
,
// Add flag to track PD mode
}
impl
AppState
{
...
...
@@ -28,9 +30,16 @@ impl AppState {
client
:
Client
,
policy_config
:
PolicyConfig
,
)
->
Result
<
Self
,
String
>
{
// Check if this is PD mode from policy config
let
is_pd_mode
=
matches!
(
policy_config
,
PolicyConfig
::
PrefillDecodeConfig
{
..
});
// Create router based on policy
let
router
=
Arc
::
new
(
Router
::
new
(
worker_urls
,
policy_config
)
?
);
Ok
(
Self
{
router
,
client
})
Ok
(
Self
{
router
,
client
,
is_pd_mode
,
})
}
}
...
...
@@ -46,8 +55,25 @@ async fn sink_handler(_req: HttpRequest, mut payload: web::Payload) -> Result<Ht
}
// Custom error handler for JSON payload errors.
fn
json_error_handler
(
_
err
:
error
::
JsonPayloadError
,
_
req
:
&
HttpRequest
)
->
Error
{
error
::
ErrorPayloadTooLarge
(
"Payload too large"
)
fn
json_error_handler
(
err
:
error
::
JsonPayloadError
,
_
req
:
&
HttpRequest
)
->
Error
{
error!
(
"JSON payload error: {:?}"
,
err
);
match
&
err
{
error
::
JsonPayloadError
::
OverflowKnownLength
{
length
,
limit
}
=>
{
error!
(
"Payload too large: {} bytes exceeds limit of {} bytes"
,
length
,
limit
);
error
::
ErrorPayloadTooLarge
(
format!
(
"Payload too large: {} bytes exceeds limit of {} bytes"
,
length
,
limit
))
}
error
::
JsonPayloadError
::
Overflow
{
limit
}
=>
{
error!
(
"Payload overflow: exceeds limit of {} bytes"
,
limit
);
error
::
ErrorPayloadTooLarge
(
format!
(
"Payload exceeds limit of {} bytes"
,
limit
))
}
_
=>
error
::
ErrorBadRequest
(
format!
(
"Invalid JSON payload: {}"
,
err
)),
}
}
#[get(
"/health"
)]
...
...
@@ -59,59 +85,134 @@ async fn health(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
/// Health probe that exercises the generation path. In PD mode the probe
/// fans out to every prefill and decode server; otherwise it is forwarded
/// to the first regular worker.
#[get("/health_generate")]
async fn health_generate(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
    // Check if we're in PD mode
    if data.is_pd_mode {
        // For PD mode, check health on all servers
        data.router.route_pd_health_generate(&data.client, &req).await
    } else {
        // Regular mode
        data.router
            .route_to_first(&data.client, "/health_generate", &req)
            .await
    }
}
/// Return server info. In PD mode the response aggregates both prefill and
/// decode servers; otherwise the first worker's info is returned.
#[get("/get_server_info")]
async fn get_server_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
    if data.is_pd_mode {
        // For PD mode, aggregate info from both prefill and decode servers
        data.router.get_pd_server_info(&data.client, &req).await
    } else {
        // Regular mode - return first server's info
        data.router
            .route_to_first(&data.client, "/get_server_info", &req)
            .await
    }
}
/// OpenAI-compatible model listing. PD mode answers from the prefill fleet;
/// regular mode forwards to the first worker.
#[get("/v1/models")]
async fn v1_models(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
    if data.is_pd_mode {
        // For PD mode, return models from the first prefill server
        data.router.get_pd_models(&data.client, &req).await
    } else {
        // Regular mode
        data.router
            .route_to_first(&data.client, "/v1/models", &req)
            .await
    }
}
/// Return model info, sourced from the prefill fleet in PD mode or the
/// first worker otherwise.
#[get("/get_model_info")]
async fn get_model_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
    if data.is_pd_mode {
        // For PD mode, get model info from the first prefill server
        data.router.get_pd_model_info(&data.client, &req).await
    } else {
        data.router
            .route_to_first(&data.client, "/get_model_info", &req)
            .await
    }
}
#[post(
"/generate"
)]
async
fn
generate
(
req
:
HttpRequest
,
body
:
Bytes
,
data
:
web
::
Data
<
AppState
>
)
->
impl
Responder
{
data
.router
.route_generate_request
(
&
data
.client
,
&
req
,
&
body
,
"/generate"
)
.await
async
fn
generate
(
req
:
HttpRequest
,
body
:
web
::
Json
<
GenerateRequest
>
,
state
:
web
::
Data
<
AppState
>
,
)
->
Result
<
HttpResponse
,
Error
>
{
let
client
=
&
state
.client
;
let
router
=
&
state
.router
;
// Use typed request directly for both PD and regular routing
if
state
.is_pd_mode
{
// For PD mode, convert to PD request with bootstrap
let
pd_request
=
body
.into_inner
()
.to_pd_request
();
Ok
(
router
.route_pd_generate_typed
(
&
client
,
&
req
,
pd_request
,
"/generate"
)
.await
)
}
else
{
// For regular mode, use typed request directly
let
request
=
body
.into_inner
();
Ok
(
router
.route_typed_request
(
&
client
,
&
req
,
&
request
,
"/generate"
)
.await
)
}
}
#[post(
"/v1/chat/completions"
)]
async
fn
v1_chat_completions
(
req
:
HttpRequest
,
body
:
Bytes
,
data
:
web
::
Data
<
AppState
>
,
)
->
impl
Responder
{
data
.router
.route_generate_request
(
&
data
.client
,
&
req
,
&
body
,
"/v1/chat/completions"
)
.await
body
:
web
::
Json
<
ChatCompletionRequest
>
,
state
:
web
::
Data
<
AppState
>
,
)
->
Result
<
HttpResponse
,
Error
>
{
let
client
=
&
state
.client
;
let
router
=
&
state
.router
;
// Use typed request directly for both PD and regular routing
if
state
.is_pd_mode
{
// For PD mode, convert to PD request with bootstrap
let
pd_request
=
body
.into_inner
()
.to_pd_request
();
Ok
(
router
.route_pd_chat_typed
(
&
client
,
&
req
,
pd_request
,
"/v1/chat/completions"
)
.await
)
}
else
{
// For regular mode, use typed request directly
let
request
=
body
.into_inner
();
Ok
(
router
.route_typed_request
(
&
client
,
&
req
,
&
request
,
"/v1/chat/completions"
)
.await
)
}
}
#[post(
"/v1/completions"
)]
async
fn
v1_completions
(
req
:
HttpRequest
,
body
:
Bytes
,
data
:
web
::
Data
<
AppState
>
,
)
->
impl
Responder
{
data
.router
.route_generate_request
(
&
data
.client
,
&
req
,
&
body
,
"/v1/completions"
)
.await
body
:
web
::
Json
<
CompletionRequest
>
,
state
:
web
::
Data
<
AppState
>
,
)
->
Result
<
HttpResponse
,
Error
>
{
let
client
=
&
state
.client
;
let
router
=
&
state
.router
;
// Use typed request directly for both PD and regular routing
if
state
.is_pd_mode
{
// For PD mode, convert to PD request with bootstrap
let
pd_request
=
body
.into_inner
()
.to_pd_request
();
Ok
(
router
.route_pd_generate_typed
(
&
client
,
&
req
,
pd_request
,
"/v1/completions"
)
.await
)
}
else
{
// For regular mode, use typed request directly
let
request
=
body
.into_inner
();
Ok
(
router
.route_typed_request
(
&
client
,
&
req
,
&
request
,
"/v1/completions"
)
.await
)
}
}
#[post(
"/add_worker"
)]
...
...
@@ -153,6 +254,25 @@ async fn remove_worker(
HttpResponse
::
Ok
()
.body
(
format!
(
"Successfully removed worker: {}"
,
worker_url
))
}
/// Flush worker caches: in PD mode this hits both prefill and decode fleets
/// via the PD router; otherwise the request is broadcast to every worker.
#[post("/flush_cache")]
async fn flush_cache(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
    if data.is_pd_mode {
        // For PD mode, flush cache on both prefill and decode servers
        data.router.route_pd_flush_cache(&data.client).await
    } else {
        // Route to all workers for cache flushing
        data.router
            .route_to_all(&data.client, "/flush_cache", &req)
            .await
    }
}
/// Report per-worker load; `Router::get_all_loads` handles both PD and
/// regular modes internally, so no mode branch is needed here.
#[get("/get_loads")]
async fn get_loads(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
    // Get loads from all workers
    data.router.get_all_loads(&data.client, &req).await
}
pub
struct
ServerConfig
{
pub
host
:
String
,
pub
port
:
u16
,
...
...
@@ -163,6 +283,7 @@ pub struct ServerConfig {
pub
log_dir
:
Option
<
String
>
,
pub
service_discovery_config
:
Option
<
ServiceDiscoveryConfig
>
,
pub
prometheus_config
:
Option
<
PrometheusConfig
>
,
pub
request_timeout_secs
:
u64
,
}
pub
async
fn
startup
(
config
:
ServerConfig
)
->
std
::
io
::
Result
<
()
>
{
...
...
@@ -215,6 +336,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
let
client
=
Client
::
builder
()
.pool_idle_timeout
(
Some
(
Duration
::
from_secs
(
50
)))
.timeout
(
Duration
::
from_secs
(
config
.request_timeout_secs
))
// Use configurable timeout
.build
()
.expect
(
"Failed to create HTTP client"
);
...
...
@@ -276,7 +398,8 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
.service
(
add_worker
)
.service
(
remove_worker
)
.service
(
list_workers
)
// Default handler for unmatched routes.
.service
(
flush_cache
)
.service
(
get_loads
)
.default_service
(
web
::
route
()
.to
(
sink_handler
))
})
.bind_auto_h2c
((
config
.host
,
config
.port
))
?
...
...
sgl-router/tests/test_pd_routing.rs
0 → 100644
View file @
09ae5b20
//! Comprehensive tests for PrefillDecode (PD) routing functionality
//!
//! This test suite covers:
//! - Phase 1: Basic PD router creation and configuration
//! - Phase 2: Bootstrap injection and request handling
//! - Phase 3: Cache-aware selection (when implemented)
//!
//! Note: PD mode is enabled via the pd_disaggregated flag, not as a policy type.
//! The policy type (Random, PowerOfTwo, CacheAware) determines the selection algorithm within PD mode.
#[cfg(test)]
mod
test_pd_routing
{
use
rand
::
Rng
;
use
serde_json
::
json
;
use
sglang_router_rs
::
pd_types
::{
EngineInfo
,
EngineType
,
PDSelectionPolicy
};
use
sglang_router_rs
::
router
::{
PolicyConfig
,
Router
};
// Test-only struct to help validate PD request parsing
// Test-only helper mirroring the PD-relevant fields of an incoming request.
#[derive(Debug)]
struct PDRequest {
    pub is_stream: bool,
    pub batch_size: Option<usize>,
}

impl PDRequest {
    /// Build a `PDRequest` from a raw JSON request body.
    ///
    /// `is_stream` comes from the boolean "stream" field (defaulting to
    /// false); `batch_size` is the length of "text" when it is an array,
    /// else the length of "input_ids" when that is an array, else `None`.
    pub fn from_json(json: &serde_json::Value) -> Self {
        let is_stream = match json.get("stream") {
            Some(v) => v.as_bool().unwrap_or(false),
            None => false,
        };

        // "text" takes precedence over "input_ids" when both keys exist.
        let batch_source = json.get("text").or_else(|| json.get("input_ids"));
        let batch_size = batch_source
            .and_then(|v| v.as_array())
            .map(|items| items.len());

        PDRequest {
            is_stream,
            batch_size,
        }
    }
}
// ========================================================================
// Phase 1: Basic PD Components and Router Creation
// ========================================================================
#[test]
fn
test_engine_info_creation
()
{
// Test EngineInfo creation for prefill servers
let
prefill_engine
=
EngineInfo
::
new_prefill
(
"http://prefill:8080"
.to_string
(),
Some
(
9000
));
match
prefill_engine
.engine_type
{
EngineType
::
Prefill
=>
(),
_
=>
panic!
(
"Expected Prefill engine type"
),
}
assert_eq!
(
prefill_engine
.url
,
"http://prefill:8080"
);
assert_eq!
(
prefill_engine
.bootstrap_port
,
Some
(
9000
));
assert_eq!
(
prefill_engine
.get_hostname
(),
"prefill"
);
// Test EngineInfo creation for decode servers
let
decode_engine
=
EngineInfo
::
new_decode
(
"http://decode:8080"
.to_string
());
match
decode_engine
.engine_type
{
EngineType
::
Decode
=>
(),
_
=>
panic!
(
"Expected Decode engine type"
),
}
assert_eq!
(
decode_engine
.url
,
"http://decode:8080"
);
assert_eq!
(
decode_engine
.bootstrap_port
,
None
);
assert_eq!
(
decode_engine
.get_hostname
(),
"decode"
);
// Test API path generation
assert_eq!
(
prefill_engine
.api_path
(
"/generate"
),
"http://prefill:8080/generate"
);
assert_eq!
(
prefill_engine
.api_path
(
"health"
),
"http://prefill:8080/health"
);
assert_eq!
(
decode_engine
.api_path
(
"/v1/chat/completions"
),
"http://decode:8080/v1/chat/completions"
);
}
#[test]
fn
test_pd_selection_policies
()
{
// Test all PD selection policy variants
// Note: These policies are only used when pd_disaggregated=true
let
policies
=
vec!
[
PDSelectionPolicy
::
Random
,
PDSelectionPolicy
::
PowerOfTwo
,
PDSelectionPolicy
::
CacheAware
{
cache_threshold
:
0.5
,
balance_abs_threshold
:
32
,
balance_rel_threshold
:
1.1
,
},
];
for
policy
in
policies
{
// Verify each policy can be created and matched
match
&
policy
{
PDSelectionPolicy
::
Random
=>
{
assert
!
(
matches!
(
policy
,
PDSelectionPolicy
::
Random
));
}
PDSelectionPolicy
::
PowerOfTwo
=>
{
assert
!
(
matches!
(
policy
,
PDSelectionPolicy
::
PowerOfTwo
));
}
PDSelectionPolicy
::
CacheAware
{
cache_threshold
,
..
}
=>
{
assert
!
(
*
cache_threshold
>=
0.0
&&
*
cache_threshold
<=
1.0
);
}
}
}
}
#[test]
fn
test_pd_router_configuration
()
{
// Test PrefillDecodeConfig creation with various policies
// This config is used when pd_disaggregated=true
let
configs
=
vec!
[
PolicyConfig
::
PrefillDecodeConfig
{
selection_policy
:
PDSelectionPolicy
::
Random
,
prefill_urls
:
vec!
[
(
"http://prefill1:8080"
.to_string
(),
Some
(
9000
)),
(
"http://prefill2:8080"
.to_string
(),
None
),
],
decode_urls
:
vec!
[
"http://decode1:8080"
.to_string
(),
"http://decode2:8080"
.to_string
(),
],
timeout_secs
:
10
,
interval_secs
:
1
,
},
PolicyConfig
::
PrefillDecodeConfig
{
selection_policy
:
PDSelectionPolicy
::
PowerOfTwo
,
prefill_urls
:
vec!
[(
"http://prefill:8080"
.to_string
(),
Some
(
9000
))],
decode_urls
:
vec!
[
"http://decode:8080"
.to_string
()],
timeout_secs
:
5
,
interval_secs
:
1
,
},
PolicyConfig
::
PrefillDecodeConfig
{
selection_policy
:
PDSelectionPolicy
::
CacheAware
{
cache_threshold
:
0.7
,
balance_abs_threshold
:
20
,
balance_rel_threshold
:
1.2
,
},
prefill_urls
:
vec!
[
(
"http://p1:8080"
.to_string
(),
Some
(
9000
)),
(
"http://p2:8080"
.to_string
(),
Some
(
9001
)),
(
"http://p3:8080"
.to_string
(),
Some
(
9002
)),
],
decode_urls
:
vec!
[
"http://d1:8080"
.to_string
(),
"http://d2:8080"
.to_string
()],
timeout_secs
:
10
,
interval_secs
:
2
,
},
];
for
config
in
configs
{
// Router creation will fail due to health checks, but config should be valid
let
result
=
Router
::
new
(
vec!
[],
config
);
assert
!
(
result
.is_err
());
let
error_msg
=
result
.unwrap_err
();
// Error should be about health/timeout, not configuration
assert
!
(
error_msg
.contains
(
"healthy"
)
||
error_msg
.contains
(
"timeout"
),
"Unexpected error: {}"
,
error_msg
);
}
}
// ========================================================================
// Phase 2: Bootstrap Injection and Request Handling
// ========================================================================
#[test]
fn
test_pd_request_from_json
()
{
// Test PDRequest parsing from single text request
let
single_json
=
json!
({
"text"
:
"Hello world"
,
"stream"
:
false
,
"temperature"
:
0.7
,
"max_tokens"
:
100
});
let
pd_req
=
PDRequest
::
from_json
(
&
single_json
);
assert
!
(
!
pd_req
.is_stream
);
assert_eq!
(
pd_req
.batch_size
,
None
);
// Test PDRequest parsing from batch text request
let
batch_json
=
json!
({
"text"
:
[
"Hello"
,
"World"
,
"Test"
],
"stream"
:
true
,
"temperature"
:
0.5
});
let
pd_req
=
PDRequest
::
from_json
(
&
batch_json
);
assert
!
(
pd_req
.is_stream
);
assert_eq!
(
pd_req
.batch_size
,
Some
(
3
));
// Test PDRequest parsing from input_ids request
let
ids_json
=
json!
({
"input_ids"
:
[[
1
,
2
,
3
],
[
4
,
5
,
6
]],
"stream"
:
false
});
let
pd_req
=
PDRequest
::
from_json
(
&
ids_json
);
assert
!
(
!
pd_req
.is_stream
);
assert_eq!
(
pd_req
.batch_size
,
Some
(
2
));
// Test PDRequest parsing from chat request
let
chat_json
=
json!
({
"messages"
:
[
{
"role"
:
"system"
,
"content"
:
"You are a helpful assistant"
},
{
"role"
:
"user"
,
"content"
:
"Hello"
}
],
"stream"
:
true
});
let
pd_req
=
PDRequest
::
from_json
(
&
chat_json
);
assert
!
(
pd_req
.is_stream
);
assert_eq!
(
pd_req
.batch_size
,
None
);
}
#[test]
fn
test_bootstrap_injection_simulation
()
{
// Since we can't test the actual inject_bootstrap_fields function here
// (it's private in the router module), we'll test the expected behavior
// Simulate bootstrap injection for single request
let
mut
single_json
=
json!
({
"text"
:
"Hello world"
,
"stream"
:
false
,
"temperature"
:
0.7
});
// Simulate what inject_bootstrap_fields would do
let
prefill_info
=
EngineInfo
::
new_prefill
(
"http://prefill1:8080"
.to_string
(),
Some
(
9000
));
single_json
[
"bootstrap_host"
]
=
json!
(
prefill_info
.get_hostname
());
single_json
[
"bootstrap_port"
]
=
json!
(
prefill_info
.bootstrap_port
);
single_json
[
"bootstrap_room"
]
=
json!
(
12345u64
);
// Random room ID
// Verify bootstrap fields are added correctly
assert_eq!
(
single_json
[
"bootstrap_host"
],
"prefill1"
);
assert_eq!
(
single_json
[
"bootstrap_port"
],
9000
);
assert
!
(
single_json
[
"bootstrap_room"
]
.is_u64
());
assert_eq!
(
single_json
[
"temperature"
],
0.7
);
// Original field preserved
// Simulate bootstrap injection for batch request
let
mut
batch_json
=
json!
({
"text"
:
[
"Hello"
,
"World"
,
"Test"
],
"stream"
:
true
});
let
batch_size
=
3
;
batch_json
[
"bootstrap_host"
]
=
json!
(
vec!
[
prefill_info
.get_hostname
();
batch_size
]);
batch_json
[
"bootstrap_port"
]
=
json!
(
vec!
[
prefill_info
.bootstrap_port
;
batch_size
]);
batch_json
[
"bootstrap_room"
]
=
json!
(
vec!
[
111u64
,
222u64
,
333u64
]);
// Verify batch bootstrap fields
assert
!
(
batch_json
[
"bootstrap_host"
]
.is_array
());
assert_eq!
(
batch_json
[
"bootstrap_host"
]
.as_array
()
.unwrap
()
.len
(),
batch_size
);
assert
!
(
batch_json
[
"bootstrap_port"
]
.is_array
());
assert
!
(
batch_json
[
"bootstrap_room"
]
.is_array
());
assert_eq!
(
batch_json
[
"stream"
],
true
);
// Original field preserved
}
#[test]
fn
test_request_serialization
()
{
// Test that requests can be properly serialized and deserialized
let
request
=
json!
({
"text"
:
"Test prompt"
,
"stream"
:
false
,
"temperature"
:
0.7
,
"max_tokens"
:
100
,
"top_p"
:
0.9
,
"frequency_penalty"
:
0.5
,
"bootstrap_host"
:
"prefill1"
,
"bootstrap_port"
:
9000
,
"bootstrap_room"
:
12345u64
});
// Convert to bytes (as would happen in the router)
let
bytes
=
serde_json
::
to_vec
(
&
request
)
.unwrap
();
// Parse back from bytes
let
parsed
:
serde_json
::
Value
=
serde_json
::
from_slice
(
&
bytes
)
.unwrap
();
// Verify all fields are preserved
assert_eq!
(
parsed
[
"text"
],
"Test prompt"
);
assert_eq!
(
parsed
[
"stream"
],
false
);
assert_eq!
(
parsed
[
"temperature"
],
0.7
);
assert_eq!
(
parsed
[
"max_tokens"
],
100
);
assert_eq!
(
parsed
[
"bootstrap_host"
],
"prefill1"
);
assert_eq!
(
parsed
[
"bootstrap_port"
],
9000
);
assert_eq!
(
parsed
[
"bootstrap_room"
],
12345
);
}
#[test]
fn
test_engine_info_hostname_extraction
()
{
// Test various URL formats
let
test_cases
=
vec!
[
(
"http://localhost:8080"
,
"localhost"
),
(
"http://10.0.0.1:8080"
,
"10.0.0.1"
),
(
"https://api.example.com:443"
,
"api.example.com"
),
(
"http://prefill-server"
,
"prefill-server"
),
(
"http://[::1]:8080"
,
"["
),
// IPv6 edge case
(
"prefill:8080"
,
"prefill"
),
// No protocol
];
for
(
url
,
expected_hostname
)
in
test_cases
{
let
engine
=
EngineInfo
::
new_prefill
(
url
.to_string
(),
None
);
assert_eq!
(
engine
.get_hostname
(),
expected_hostname
);
}
}
#[test]
fn test_pd_request_edge_cases() {
    // Completely empty request: stream defaults to off, no batch size.
    let req = PDRequest::from_json(&json!({}));
    assert!(!req.is_stream);
    assert_eq!(req.batch_size, None);

    // Only the stream flag is set.
    let req = PDRequest::from_json(&json!({ "stream": true }));
    assert!(req.is_stream);
    assert_eq!(req.batch_size, None);

    // An empty text array still counts as a (zero-sized) batch.
    let req = PDRequest::from_json(&json!({ "text": [] }));
    assert_eq!(req.batch_size, Some(0));

    // A scalar "text" is not an array, so no batch size is derived.
    let req = PDRequest::from_json(&json!({ "text": "single string" }));
    assert_eq!(req.batch_size, None);
}
// ========================================================================
// Phase 2: Background Load Monitoring Tests
// ========================================================================
#[tokio::test]
async fn test_background_load_monitoring() {
    use std::collections::HashMap;
    use tokio::sync::watch;

    // The load monitor publishes per-worker loads over a watch channel.
    let (tx, rx) = watch::channel(HashMap::new());

    // Publish one snapshot covering two prefill and two decode workers.
    let mut snapshot = HashMap::new();
    for (url, load) in [
        ("http://prefill1:8080", 10),
        ("http://prefill2:8080", 20),
        ("http://decode1:8080", 5),
        ("http://decode2:8080", 15),
    ] {
        snapshot.insert(url.to_string(), load);
    }
    tx.send(snapshot.clone()).unwrap();

    // The receiver must observe exactly the published values.
    let observed = rx.borrow();
    assert_eq!(observed.get("http://prefill1:8080"), Some(&10));
    assert_eq!(observed.get("http://prefill2:8080"), Some(&20));
    assert_eq!(observed.get("http://decode1:8080"), Some(&5));
    assert_eq!(observed.get("http://decode2:8080"), Some(&15));
}
#[test]
fn test_power_of_two_load_selection() {
    // Power-of-two selection: of two sampled workers, pick the one with the
    // lower reported load; on a tie the first candidate wins (<= comparison).
    // A worker with no load report is treated as usize::MAX (the
    // unwrap_or(usize::MAX) behavior), so it loses to any reporting worker.
    //
    // NOTE: the previous version of this test only asserted relations between
    // literal constants (e.g. `assert!(10 < 100)`), so it could never fail.
    // This version exercises the selection rule itself.
    fn pick<'a>(a: (&'a str, Option<usize>), b: (&'a str, Option<usize>)) -> &'a str {
        let load_a = a.1.unwrap_or(usize::MAX);
        let load_b = b.1.unwrap_or(usize::MAX);
        if load_a <= load_b {
            a.0
        } else {
            b.0
        }
    }

    // Scenario 1: clear winner for both prefill and decode.
    assert_eq!(pick(("prefill1", Some(100)), ("prefill2", Some(10))), "prefill2");
    assert_eq!(pick(("decode1", Some(50)), ("decode2", Some(5))), "decode2");

    // Scenario 2: equal loads - the first candidate wins the tie.
    assert_eq!(pick(("prefill1", Some(20)), ("prefill2", Some(20))), "prefill1");
    assert_eq!(pick(("decode1", Some(30)), ("decode2", Some(30))), "decode1");

    // Scenario 3: missing load data defaults to usize::MAX and always loses...
    assert_eq!(pick(("known", Some(10)), ("unknown", None)), "known");
    // ...unless both are missing, in which case the tie rule applies.
    assert_eq!(pick(("first", None), ("second", None)), "first");
}
#[test]
fn test_load_monitoring_configuration() {
    // Background load monitoring is only enabled for the PowerOfTwo policy.
    let policies = vec![
        (PDSelectionPolicy::Random, false),
        (PDSelectionPolicy::PowerOfTwo, true),
        (
            PDSelectionPolicy::CacheAware {
                cache_threshold: 0.5,
                balance_abs_threshold: 32,
                balance_rel_threshold: 1.1,
            },
            false,
        ),
    ];
    for (policy, expect_monitoring) in policies {
        // Only PowerOfTwo should trigger the monitor.
        let monitors = matches!(policy, PDSelectionPolicy::PowerOfTwo);
        assert_eq!(monitors, expect_monitoring);
    }
}
#[tokio::test]
async fn test_watch_channel_behavior() {
    use std::collections::HashMap;
    use tokio::sync::watch;

    // A watch channel broadcasts the *latest* value to every receiver.
    let (tx, primary_rx) = watch::channel(HashMap::new());
    let secondary_rx = primary_rx.clone();

    // Initial state: both receivers see the empty map.
    assert!(primary_rx.borrow().is_empty());
    assert!(secondary_rx.borrow().is_empty());

    // First update.
    let mut state = HashMap::new();
    state.insert("worker1".to_string(), 10);
    tx.send(state.clone()).unwrap();

    // Both receivers see it.
    assert_eq!(primary_rx.borrow().get("worker1"), Some(&10));
    assert_eq!(secondary_rx.borrow().get("worker1"), Some(&10));

    // Second update overwrites the previous snapshot entirely.
    state.insert("worker1".to_string(), 20);
    state.insert("worker2".to_string(), 30);
    tx.send(state).unwrap();

    // Both receivers now see only the latest state.
    assert_eq!(primary_rx.borrow().get("worker1"), Some(&20));
    assert_eq!(secondary_rx.borrow().get("worker2"), Some(&30));
}
// ========================================================================
// Tests based on bench_one_batch_server.py patterns
// ========================================================================
#[test]
fn test_generate_request_formats() {
    // Request shapes mirrored from bench_one_batch_server.py.

    // 1. Streaming batch request with input_ids (the common benchmark shape).
    let batch_request = json!({
        "input_ids": [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
        "sampling_params": {
            "temperature": 0.0,
            "max_new_tokens": 16,
            "ignore_eos": true,
        },
        "return_logprob": false,
        "stream": true
    });
    let parsed = PDRequest::from_json(&batch_request);
    assert!(parsed.is_stream);
    assert_eq!(parsed.batch_size, Some(3));

    // 2. Request asking for logprobs (critical for PD routing).
    let logprob_request = json!({
        "input_ids": [[1, 2, 3]],
        "sampling_params": {
            "temperature": 0.7,
            "max_new_tokens": 8,
        },
        "return_logprob": true,
        "stream": false
    });
    assert_eq!(logprob_request["return_logprob"], true);
    assert_eq!(logprob_request["stream"], false);

    // 3. The batch sizes exercised by bench_one_batch_server.py.
    for bs in [1, 16, 64] {
        let request = json!({
            "input_ids": vec![vec![1, 2, 3]; bs],
            "sampling_params": {
                "temperature": 0.0,
                "max_new_tokens": 16,
            },
            "stream": true
        });
        let parsed = PDRequest::from_json(&request);
        assert_eq!(parsed.batch_size, Some(bs));
    }
}
#[test]
fn test_sampling_params_handling() {
    // Sampling-parameter variants observed in bench_one_batch_server.py.
    let variants = [
        json!({"temperature": 0.0, "max_new_tokens": 8, "ignore_eos": true}),
        json!({
            "temperature": 0.7,
            "max_new_tokens": 16,
            "ignore_eos": false,
            "top_p": 0.9,
            "frequency_penalty": 0.5
        }),
        json!({
            "temperature": 1.0,
            "max_new_tokens": 64,
            "json_schema": "$$ANY$$" // Structured output
        }),
    ];
    for params in variants {
        let request = json!({
            "input_ids": [[1, 2, 3]],
            "sampling_params": params.clone(),
            "stream": false
        });
        // The params object must be embedded untouched.
        assert_eq!(request["sampling_params"], params);
    }
}
#[test]
fn test_streaming_response_parsing() {
    // SSE chunks as emitted by a streaming /generate response; the last
    // entry is the stream terminator.
    let sse_chunks = [
        "data: {\"text\":\"Hello\",\"meta_info\":{\"completion_tokens\":1,\"finish_reason\":null}}",
        "data: {\"text\":\"world\",\"meta_info\":{\"completion_tokens\":2,\"finish_reason\":null}}",
        "data: {\"text\":\"!\",\"meta_info\":{\"completion_tokens\":3,\"finish_reason\":{\"type\":\"length\"}}}",
        "data: [DONE]",
    ];

    // Each data chunk carries the "data: " prefix followed by valid JSON
    // whose meta_info tracks the running completion-token count.
    for chunk in &sse_chunks[..3] {
        let payload = chunk
            .strip_prefix("data: ")
            .expect("SSE chunk must start with the data prefix");
        let parsed: serde_json::Value = serde_json::from_str(payload).unwrap();
        assert!(parsed["meta_info"]["completion_tokens"].is_u64());
    }

    // The terminator is the literal [DONE] marker.
    assert_eq!(sse_chunks[3], "data: [DONE]");
}
#[test]
fn test_ttft_calculation() {
    // TTFT (time to first token) is measured on the chunk where
    // completion_tokens first reaches 1 and finish_reason is still unset.
    let first_token_chunk = json!({
        "text": "Hello",
        "meta_info": {
            "completion_tokens": 1,
            "finish_reason": null
        }
    });
    let meta = &first_token_chunk["meta_info"];
    assert_eq!(meta["completion_tokens"], 1);
    assert!(meta["finish_reason"].is_null());
}
#[test]
fn test_throughput_metrics() {
    // Throughput formulas taken from bench_one_batch_server.py.
    let batch_size = 16;
    let input_len = 1024;
    let output_len = 16;
    let ttft = 0.5; // seconds until the first token
    let total_latency = 2.0; // seconds for the whole request

    // Input throughput: all prompt tokens are processed before the first
    // token arrives, so it is batch_size * input_len / ttft.
    let input_throughput = f64::from(batch_size * input_len) / ttft;
    assert!((input_throughput - 32768.0).abs() < 0.01);

    // Output throughput: generated tokens over the decode phase only,
    // i.e. batch_size * output_len / (latency - ttft).
    let output_throughput = f64::from(batch_size * output_len) / (total_latency - ttft);
    assert!((output_throughput - 170.67).abs() < 0.01);
}
#[test]
fn test_error_response_handling() {
    // Error responses carry a single string "error" field, matching the
    // format used by bench_one_batch_server.py.
    let error_response = json!({
        "error": "Request has failed. Invalid input format."
    });

    // The field must exist...
    assert!(error_response.get("error").is_some());
    // ...and its message must mention the failure.
    let message = error_response["error"].as_str().unwrap();
    assert!(message.contains("failed"));
}
#[test]
fn test_structured_output_request() {
    // A json_schema entry in sampling_params requests structured output.
    let structured_request = json!({
        "text": "What is the capital of France? Answer in JSON.",
        "sampling_params": {
            "temperature": 0.0,
            "max_new_tokens": 64,
            "json_schema": "$$ANY$$"
        },
        "stream": false
    });

    let schema = &structured_request["sampling_params"]["json_schema"];
    assert_eq!(*schema, json!("$$ANY$$"));
}
#[test]
fn test_bootstrap_injection_with_benchmark_requests() {
    let batch_size = 16;

    // A realistic benchmark request: batch of 16, streaming, with logprobs.
    let mut benchmark_request = json!({
        "input_ids": vec![vec![1, 2, 3, 4]; batch_size],
        "sampling_params": {
            "temperature": 0.0,
            "max_new_tokens": 8,
            "ignore_eos": true
        },
        "return_logprob": true,
        "stream": true
    });

    // Inject bootstrap metadata — one entry per batch element — exactly as
    // the PD router does before forwarding to a prefill worker.
    let prefill_info = EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
    benchmark_request["bootstrap_host"] = json!(vec![prefill_info.get_hostname(); batch_size]);
    benchmark_request["bootstrap_port"] = json!(vec![prefill_info.bootstrap_port; batch_size]);
    benchmark_request["bootstrap_room"] =
        json!((0..batch_size).map(|_| 12345u64).collect::<Vec<_>>());

    // Every bootstrap array must be exactly batch_size long.
    for field in ["bootstrap_host", "bootstrap_port", "bootstrap_room"] {
        assert_eq!(
            benchmark_request[field].as_array().unwrap().len(),
            batch_size
        );
    }

    // The pre-existing request fields must be left untouched.
    assert_eq!(benchmark_request["return_logprob"], true);
    assert_eq!(benchmark_request["stream"], true);
}
#[test]
fn test_server_info_response_format() {
    // The /get_server_info shape that bench_one_batch_server.py consumes.
    let server_info = json!({
        "internal_states": [{
            "avg_spec_accept_length": 3.5,
            "last_gen_throughput": 2048.5,
            "load": 16
        }],
        "prefill": [
            {"url": "http://prefill1:8080", "load": 10},
            {"url": "http://prefill2:8080", "load": 20}
        ],
        "decode": [
            {"url": "http://decode1:8080", "load": 5},
            {"url": "http://decode2:8080", "load": 15}
        ]
    });

    // The benchmark reads two floats from the first internal state...
    let first_state = &server_info["internal_states"][0];
    assert!(first_state["avg_spec_accept_length"].is_f64());
    assert!(first_state["last_gen_throughput"].is_f64());
    // ...and expects per-role worker arrays.
    assert!(server_info["prefill"].is_array());
    assert!(server_info["decode"].is_array());
}
// ========================================================================
// Comprehensive Endpoint Coverage Test
// ========================================================================
#[test]
fn test_pd_endpoints_coverage() {
    // Endpoint inventory taken from the Python mini_lb.py, annotated with
    // implementation status: (path, method, implemented).
    let implemented_endpoints = vec![
        ("/health", "GET", true),
        ("/health_generate", "GET", true), // Note: Python uses POST, we use GET
        ("/get_server_info", "GET", true),
        ("/v1/models", "GET", true),
        ("/get_model_info", "GET", true),
        ("/generate", "POST", true),
        ("/v1/chat/completions", "POST", true),
        ("/v1/completions", "POST", true),
        ("/flush_cache", "POST", true),
        ("/get_loads", "GET", true),
        ("/register", "POST", false), // NOT IMPLEMENTED - needs dynamic worker management
    ];

    // 10 of 11 endpoints are implemented (register is not needed for Phase 1/2).
    let done = implemented_endpoints
        .iter()
        .filter(|(_, _, implemented)| *implemented)
        .count();
    assert_eq!(done, 10);
    assert_eq!(implemented_endpoints.len(), 11);

    // Document exactly which endpoints remain unimplemented.
    let missing: Vec<_> = implemented_endpoints
        .iter()
        .filter(|(_, _, implemented)| !implemented)
        .map(|(endpoint, method, _)| format!("{} {}", method, endpoint))
        .collect();
    assert_eq!(missing, vec!["POST /register"]);
}
#[test]
fn test_large_batch_bootstrap_injection() {
    // Bootstrap injection must stay fast even at benchmark-sized batches
    // (the bench_one_batch_server.py scenario).
    for batch_size in [1024, 4096, 8192] {
        let started = std::time::Instant::now();

        // A large streaming batch request.
        let mut large_batch_request = json!({
            "input_ids": vec![vec![1, 2, 3, 4]; batch_size],
            "sampling_params": {
                "temperature": 0.0,
                "max_new_tokens": 16,
            },
            "stream": true
        });

        // Inject one bootstrap entry per batch element.
        let prefill_info =
            EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
        large_batch_request["bootstrap_host"] =
            json!(vec![prefill_info.get_hostname(); batch_size]);
        large_batch_request["bootstrap_port"] =
            json!(vec![prefill_info.bootstrap_port; batch_size]);
        large_batch_request["bootstrap_room"] = json!((0..batch_size)
            .map(|_| rand::thread_rng().gen::<u64>())
            .collect::<Vec<_>>());

        let elapsed = started.elapsed();

        // Every bootstrap array must be exactly batch_size long.
        for field in ["bootstrap_host", "bootstrap_port", "bootstrap_room"] {
            assert_eq!(
                large_batch_request[field].as_array().unwrap().len(),
                batch_size
            );
        }

        // Injection should be cheap relative to the request itself.
        println!(
            "Bootstrap injection for batch_size {} took {:?}",
            batch_size, elapsed
        );
        assert!(
            elapsed.as_millis() < 1000,
            "Bootstrap injection took too long for batch size {}",
            batch_size
        );
    }
}
#[test]
fn test_payload_size_calculation() {
    // Rough payload-size estimates for bench_one_batch_server.py scenarios:
    // (batch_size, input_len, output_len).
    let scenarios = [
        (1, 1024, 16),   // Small batch
        (16, 1024, 16),  // Medium batch
        (64, 1024, 16),  // Large batch
        (8192, 4096, 5), // Benchmark scenario
    ];

    for (batch_size, input_len, _output_len) in scenarios {
        // Estimate: ~4 bytes per token (i32) plus ~100 bytes of JSON
        // overhead per request in the batch.
        let token_bytes = batch_size * input_len * 4;
        let overhead_bytes = batch_size * 100;
        let total_size = token_bytes + overhead_bytes;

        println!(
            "Batch size: {}, Input len: {}, Estimated payload: {} MB",
            batch_size,
            input_len,
            total_size / (1024 * 1024)
        );

        // The benchmark case (8192 x 4096) should land between 100 and 200 MB.
        if batch_size == 8192 && input_len == 4096 {
            assert!(
                total_size > 100 * 1024 * 1024,
                "Benchmark payload should be > 100MB"
            );
            assert!(
                total_size < 200 * 1024 * 1024,
                "Benchmark payload should be < 200MB"
            );
        }
    }
}
#[test]
fn test_policy_type_to_pd_selection_policy_mapping() {
    // Documents the PolicyType -> PDSelectionPolicy mapping performed in
    // lib.rs when pd_disaggregated=true:
    //   PolicyType::Random     -> PDSelectionPolicy::Random
    //   PolicyType::PowerOfTwo -> PDSelectionPolicy::PowerOfTwo
    //   PolicyType::CacheAware -> PDSelectionPolicy::CacheAware { ... }
    //   PolicyType::RoundRobin -> ERROR (not supported in PD mode)

    // PDSelectionPolicy deliberately has no RoundRobin variant.
    let pd_policy_count = 3; // Random, PowerOfTwo, CacheAware
    assert_eq!(
        pd_policy_count, 3,
        "PDSelectionPolicy should have exactly 3 variants"
    );

    // Every variant must remain constructible.
    let _random = PDSelectionPolicy::Random;
    let _po2 = PDSelectionPolicy::PowerOfTwo;
    let _cache_aware = PDSelectionPolicy::CacheAware {
        cache_threshold: 0.5,
        balance_abs_threshold: 32,
        balance_rel_threshold: 1.1,
    };
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment