sglang / Commits

Commit 9d68bdb2 (unverified)
Authored Aug 11, 2025 by Simo Lin; committed via GitHub on Aug 11, 2025
Parent: a2184901

[router] Add Rust Binary Entrypoint for SGLang Router (#9089)

Showing 12 changed files with 638 additions and 78 deletions (+638, -78)
Changed files:
  sgl-router/Cargo.toml                      +5    -0
  sgl-router/README.md                       +31   -2
  sgl-router/src/main.rs                     +490  -0
  sgl-router/src/routers/factory.rs          +19   -14
  sgl-router/src/routers/pd_router.rs        +4    -2
  sgl-router/src/routers/router.rs           +69   -26
  sgl-router/src/server.rs                   +1    -1
  sgl-router/src/service_discovery.rs        +11   -10
  sgl-router/tests/api_endpoints_test.rs     +3    -10
  sgl-router/tests/request_formats_test.rs   +1    -5
  sgl-router/tests/streaming_tests.rs        +1    -5
  sgl-router/tests/test_pd_routing.rs        +3    -3
sgl-router/Cargo.toml

@@ -9,7 +9,12 @@ name = "sglang_router_rs"
 # Python/C binding + Rust library: Use ["cdylib", "rlib"]
 crate-type = ["cdylib", "rlib"]
+
+[[bin]]
+name = "sglang-router"
+path = "src/main.rs"

 [dependencies]
+clap = { version = "4", features = ["derive"] }
 axum = { version = "0.8.4", features = ["macros", "ws", "tracing"] }
 tower = { version = "0.5", features = ["full"] }
 tower-http = { version = "0.6", features = ["trace", "compression-gzip", "cors", "timeout", "limit", "request-id", "util"] }
sgl-router/README.md

@@ -56,7 +56,21 @@ pip install -e .
 cargo build
 ```

-#### Launch Router with Worker URLs in regular mode
+#### Using the Rust Binary Directly (Alternative to Python)
+
+```bash
+# Build the Rust binary
+cargo build --release
+
+# Launch router with worker URLs in regular mode
+./target/release/sglang-router \
+    --worker-urls http://worker1:8000 http://worker2:8000
+
+# Or use cargo run
+cargo run --release -- \
+    --worker-urls http://worker1:8000 http://worker2:8000
+```
+
+#### Launch Router with Python (Original Method)

 ```bash
 # Launch router with worker URLs
 python -m sglang_router.launch_router \

@@ -68,7 +82,22 @@ python -m sglang_router.launch_router \
 # Note that the prefill and decode URLs must be provided in the following format:
 # http://<ip>:<port> for decode nodes
 # http://<ip>:<port> bootstrap-port for prefill nodes, where bootstrap-port is optional

-# Launch router with worker URLs
+# Using Rust binary directly
+./target/release/sglang-router \
+    --pd-disaggregation \
+    --policy cache_aware \
+    --prefill http://127.0.0.1:30001 9001 \
+    --prefill http://127.0.0.2:30002 9002 \
+    --prefill http://127.0.0.3:30003 9003 \
+    --prefill http://127.0.0.4:30004 9004 \
+    --decode http://127.0.0.5:30005 \
+    --decode http://127.0.0.6:30006 \
+    --decode http://127.0.0.7:30007 \
+    --host 0.0.0.0 \
+    --port 8080
+
+# Or using Python launcher
 python -m sglang_router.launch_router \
     --pd-disaggregation \
     --policy cache_aware \
sgl-router/src/main.rs (new file, mode 100644)

use clap::{ArgAction, Parser};
use sglang_router_rs::config::{
    CircuitBreakerConfig, ConfigError, ConfigResult, DiscoveryConfig, MetricsConfig, PolicyConfig,
    RetryConfig, RouterConfig, RoutingMode,
};
use sglang_router_rs::metrics::PrometheusConfig;
use sglang_router_rs::server::{self, ServerConfig};
use sglang_router_rs::service_discovery::ServiceDiscoveryConfig;
use std::collections::HashMap;

// Helper function to parse prefill arguments from command line
fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
    let args: Vec<String> = std::env::args().collect();
    let mut prefill_entries = Vec::new();

    let mut i = 0;
    while i < args.len() {
        if args[i] == "--prefill" && i + 1 < args.len() {
            let url = args[i + 1].clone();
            let bootstrap_port = if i + 2 < args.len() && !args[i + 2].starts_with("--") {
                // Check if next arg is a port number
                if let Ok(port) = args[i + 2].parse::<u16>() {
                    i += 1; // Skip the port argument
                    Some(port)
                } else if args[i + 2].to_lowercase() == "none" {
                    i += 1; // Skip the "none" argument
                    None
                } else {
                    None
                }
            } else {
                None
            };
            prefill_entries.push((url, bootstrap_port));
            i += 2; // Skip --prefill and URL
        } else {
            i += 1;
        }
    }

    prefill_entries
}

#[derive(Parser, Debug)]
#[command(name = "sglang-router")]
#[command(about = "SGLang Router - High-performance request distribution across worker nodes")]
#[command(long_about = r#"
SGLang Router - High-performance request distribution across worker nodes

Usage:
This launcher enables starting a router with individual worker instances. It is useful for
multi-node setups or when you want to start workers and router separately.

Examples:
  # Regular mode
  sglang-router --worker-urls http://worker1:8000 http://worker2:8000

  # PD disaggregated mode with same policy for both
  sglang-router --pd-disaggregation \
    --prefill http://127.0.0.1:30001 9001 \
    --prefill http://127.0.0.2:30002 9002 \
    --decode http://127.0.0.3:30003 \
    --decode http://127.0.0.4:30004 \
    --policy cache_aware

  # PD mode with different policies for prefill and decode
  sglang-router --pd-disaggregation \
    --prefill http://127.0.0.1:30001 9001 \
    --prefill http://127.0.0.2:30002 \
    --decode http://127.0.0.3:30003 \
    --decode http://127.0.0.4:30004 \
    --prefill-policy cache_aware --decode-policy power_of_two
"#)]
struct CliArgs {
    /// Host address to bind the router server
    #[arg(long, default_value = "127.0.0.1")]
    host: String,

    /// Port number to bind the router server
    #[arg(long, default_value_t = 30000)]
    port: u16,

    /// List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)
    #[arg(long, num_args = 0..)]
    worker_urls: Vec<String>,

    /// Load balancing policy to use
    #[arg(long, default_value = "cache_aware", value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
    policy: String,

    /// Enable PD (Prefill-Decode) disaggregated mode
    #[arg(long, default_value_t = false)]
    pd_disaggregation: bool,

    /// Decode server URL (can be specified multiple times)
    #[arg(long, action = ArgAction::Append)]
    decode: Vec<String>,

    /// Specific policy for prefill nodes in PD mode
    #[arg(long, value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
    prefill_policy: Option<String>,

    /// Specific policy for decode nodes in PD mode
    #[arg(long, value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
    decode_policy: Option<String>,

    /// Timeout in seconds for worker startup
    #[arg(long, default_value_t = 300)]
    worker_startup_timeout_secs: u64,

    /// Interval in seconds between checks for worker startup
    #[arg(long, default_value_t = 10)]
    worker_startup_check_interval: u64,

    /// Cache threshold (0.0-1.0) for cache-aware routing
    #[arg(long, default_value_t = 0.5)]
    cache_threshold: f32,

    /// Absolute threshold for load balancing
    #[arg(long, default_value_t = 32)]
    balance_abs_threshold: usize,

    /// Relative threshold for load balancing
    #[arg(long, default_value_t = 1.0001)]
    balance_rel_threshold: f32,

    /// Interval in seconds between cache eviction operations
    #[arg(long, default_value_t = 60)]
    eviction_interval: u64,

    /// Maximum size of the approximation tree for cache-aware routing
    #[arg(long, default_value_t = 16777216)] // 2^24
    max_tree_size: usize,

    /// Maximum payload size in bytes
    #[arg(long, default_value_t = 268435456)] // 256MB
    max_payload_size: usize,

    /// Enable data parallelism aware schedule
    #[arg(long, default_value_t = false)]
    dp_aware: bool,

    /// API key for worker authorization
    #[arg(long)]
    api_key: Option<String>,

    /// Directory to store log files
    #[arg(long)]
    log_dir: Option<String>,

    /// Set the logging level
    #[arg(long, default_value = "info", value_parser = ["debug", "info", "warn", "error"])]
    log_level: String,

    /// Enable Kubernetes service discovery
    #[arg(long, default_value_t = false)]
    service_discovery: bool,

    /// Label selector for Kubernetes service discovery (format: key1=value1 key2=value2)
    #[arg(long, num_args = 0..)]
    selector: Vec<String>,

    /// Port to use for discovered worker pods
    #[arg(long, default_value_t = 80)]
    service_discovery_port: u16,

    /// Kubernetes namespace to watch for pods
    #[arg(long)]
    service_discovery_namespace: Option<String>,

    /// Label selector for prefill server pods in PD mode
    #[arg(long, num_args = 0..)]
    prefill_selector: Vec<String>,

    /// Label selector for decode server pods in PD mode
    #[arg(long, num_args = 0..)]
    decode_selector: Vec<String>,

    /// Port to expose Prometheus metrics
    #[arg(long, default_value_t = 29000)]
    prometheus_port: u16,

    /// Host address to bind the Prometheus metrics server
    #[arg(long, default_value = "127.0.0.1")]
    prometheus_host: String,

    /// Custom HTTP headers to check for request IDs
    #[arg(long, num_args = 0..)]
    request_id_headers: Vec<String>,

    /// Request timeout in seconds
    #[arg(long, default_value_t = 600)]
    request_timeout_secs: u64,

    /// Maximum number of concurrent requests allowed
    #[arg(long, default_value_t = 64)]
    max_concurrent_requests: usize,

    /// CORS allowed origins
    #[arg(long, num_args = 0..)]
    cors_allowed_origins: Vec<String>,

    // Retry configuration
    /// Maximum number of retries
    #[arg(long, default_value_t = 3)]
    retry_max_retries: u32,

    /// Initial backoff in milliseconds for retries
    #[arg(long, default_value_t = 100)]
    retry_initial_backoff_ms: u64,

    /// Maximum backoff in milliseconds for retries
    #[arg(long, default_value_t = 10000)]
    retry_max_backoff_ms: u64,

    /// Backoff multiplier for exponential backoff
    #[arg(long, default_value_t = 2.0)]
    retry_backoff_multiplier: f32,

    /// Jitter factor for retry backoff
    #[arg(long, default_value_t = 0.1)]
    retry_jitter_factor: f32,

    /// Disable retries
    #[arg(long, default_value_t = false)]
    disable_retries: bool,

    // Circuit breaker configuration
    /// Number of failures before circuit breaker opens
    #[arg(long, default_value_t = 5)]
    cb_failure_threshold: u32,

    /// Number of successes before circuit breaker closes
    #[arg(long, default_value_t = 2)]
    cb_success_threshold: u32,

    /// Timeout duration in seconds for circuit breaker
    #[arg(long, default_value_t = 30)]
    cb_timeout_duration_secs: u64,

    /// Window duration in seconds for circuit breaker
    #[arg(long, default_value_t = 60)]
    cb_window_duration_secs: u64,

    /// Disable circuit breaker
    #[arg(long, default_value_t = false)]
    disable_circuit_breaker: bool,
}

impl CliArgs {
    /// Parse selector strings into HashMap
    fn parse_selector(selector_list: &[String]) -> HashMap<String, String> {
        let mut map = HashMap::new();
        for item in selector_list {
            if let Some(eq_pos) = item.find('=') {
                let key = item[..eq_pos].to_string();
                let value = item[eq_pos + 1..].to_string();
                map.insert(key, value);
            }
        }
        map
    }

    /// Convert policy string to PolicyConfig
    fn parse_policy(&self, policy_str: &str) -> PolicyConfig {
        match policy_str {
            "random" => PolicyConfig::Random,
            "round_robin" => PolicyConfig::RoundRobin,
            "cache_aware" => PolicyConfig::CacheAware {
                cache_threshold: self.cache_threshold,
                balance_abs_threshold: self.balance_abs_threshold,
                balance_rel_threshold: self.balance_rel_threshold,
                eviction_interval_secs: self.eviction_interval,
                max_tree_size: self.max_tree_size,
            },
            "power_of_two" => PolicyConfig::PowerOfTwo {
                load_check_interval_secs: 5, // Default value
            },
            _ => PolicyConfig::RoundRobin, // Fallback
        }
    }

    /// Convert CLI arguments to RouterConfig
    fn to_router_config(
        &self,
        prefill_urls: Vec<(String, Option<u16>)>,
    ) -> ConfigResult<RouterConfig> {
        // Determine routing mode
        let mode = if self.pd_disaggregation {
            let decode_urls = self.decode.clone();

            // Validate PD configuration if not using service discovery
            if !self.service_discovery && (prefill_urls.is_empty() || decode_urls.is_empty()) {
                return Err(ConfigError::ValidationFailed {
                    reason: "PD disaggregation mode requires --prefill and --decode URLs when not using service discovery".to_string(),
                });
            }

            RoutingMode::PrefillDecode {
                prefill_urls,
                decode_urls,
                prefill_policy: self.prefill_policy.as_ref().map(|p| self.parse_policy(p)),
                decode_policy: self.decode_policy.as_ref().map(|p| self.parse_policy(p)),
            }
        } else {
            // Regular mode
            if !self.service_discovery && self.worker_urls.is_empty() {
                return Err(ConfigError::ValidationFailed {
                    reason: "Regular mode requires --worker-urls when not using service discovery".to_string(),
                });
            }

            RoutingMode::Regular {
                worker_urls: self.worker_urls.clone(),
            }
        };

        // Main policy
        let policy = self.parse_policy(&self.policy);

        // Service discovery configuration
        let discovery = if self.service_discovery {
            Some(DiscoveryConfig {
                enabled: true,
                namespace: self.service_discovery_namespace.clone(),
                port: self.service_discovery_port,
                check_interval_secs: 60,
                selector: Self::parse_selector(&self.selector),
                prefill_selector: Self::parse_selector(&self.prefill_selector),
                decode_selector: Self::parse_selector(&self.decode_selector),
                bootstrap_port_annotation: "sglang.ai/bootstrap-port".to_string(),
            })
        } else {
            None
        };

        // Metrics configuration
        let metrics = Some(MetricsConfig {
            port: self.prometheus_port,
            host: self.prometheus_host.clone(),
        });

        // Build RouterConfig
        Ok(RouterConfig {
            mode,
            policy,
            host: self.host.clone(),
            port: self.port,
            max_payload_size: self.max_payload_size,
            request_timeout_secs: self.request_timeout_secs,
            worker_startup_timeout_secs: self.worker_startup_timeout_secs,
            worker_startup_check_interval_secs: self.worker_startup_check_interval,
            dp_aware: self.dp_aware,
            api_key: self.api_key.clone(),
            discovery,
            metrics,
            log_dir: self.log_dir.clone(),
            log_level: Some(self.log_level.clone()),
            request_id_headers: if self.request_id_headers.is_empty() {
                None
            } else {
                Some(self.request_id_headers.clone())
            },
            max_concurrent_requests: self.max_concurrent_requests,
            cors_allowed_origins: self.cors_allowed_origins.clone(),
            retry: RetryConfig {
                max_retries: self.retry_max_retries,
                initial_backoff_ms: self.retry_initial_backoff_ms,
                max_backoff_ms: self.retry_max_backoff_ms,
                backoff_multiplier: self.retry_backoff_multiplier,
                jitter_factor: self.retry_jitter_factor,
            },
            circuit_breaker: CircuitBreakerConfig {
                failure_threshold: self.cb_failure_threshold,
                success_threshold: self.cb_success_threshold,
                timeout_duration_secs: self.cb_timeout_duration_secs,
                window_duration_secs: self.cb_window_duration_secs,
            },
            disable_retries: self.disable_retries,
            disable_circuit_breaker: self.disable_circuit_breaker,
        })
    }

    /// Create ServerConfig from CLI args and RouterConfig
    fn to_server_config(&self, router_config: RouterConfig) -> ServerConfig {
        // Create service discovery config if enabled
        let service_discovery_config = if self.service_discovery {
            Some(ServiceDiscoveryConfig {
                enabled: true,
                selector: Self::parse_selector(&self.selector),
                check_interval: std::time::Duration::from_secs(60),
                port: self.service_discovery_port,
                namespace: self.service_discovery_namespace.clone(),
                pd_mode: self.pd_disaggregation,
                prefill_selector: Self::parse_selector(&self.prefill_selector),
                decode_selector: Self::parse_selector(&self.decode_selector),
                bootstrap_port_annotation: "sglang.ai/bootstrap-port".to_string(),
            })
        } else {
            None
        };

        // Create Prometheus config
        let prometheus_config = Some(PrometheusConfig {
            port: self.prometheus_port,
            host: self.prometheus_host.clone(),
        });

        ServerConfig {
            host: self.host.clone(),
            port: self.port,
            router_config,
            max_payload_size: self.max_payload_size,
            log_dir: self.log_dir.clone(),
            log_level: Some(self.log_level.clone()),
            service_discovery_config,
            prometheus_config,
            request_timeout_secs: self.request_timeout_secs,
            request_id_headers: if self.request_id_headers.is_empty() {
                None
            } else {
                Some(self.request_id_headers.clone())
            },
        }
    }
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Parse prefill arguments manually before clap parsing
    let prefill_urls = parse_prefill_args();

    // Filter out prefill arguments and their values before passing to clap
    let mut filtered_args: Vec<String> = Vec::new();
    let raw_args: Vec<String> = std::env::args().collect();
    let mut i = 0;
    while i < raw_args.len() {
        if raw_args[i] == "--prefill" && i + 1 < raw_args.len() {
            // Skip --prefill and its URL
            i += 2;
            // Also skip bootstrap port if present
            if i < raw_args.len() && !raw_args[i].starts_with("--") {
                if raw_args[i].parse::<u16>().is_ok() || raw_args[i].to_lowercase() == "none" {
                    i += 1;
                }
            }
        } else {
            filtered_args.push(raw_args[i].clone());
            i += 1;
        }
    }

    // Parse CLI arguments with clap using filtered args
    let cli_args = CliArgs::parse_from(filtered_args);

    // Print startup info
    println!("SGLang Router starting...");
    println!("Host: {}:{}", cli_args.host, cli_args.port);
    println!(
        "Mode: {}",
        if cli_args.pd_disaggregation {
            "PD Disaggregated"
        } else {
            "Regular"
        }
    );
    println!("Policy: {}", cli_args.policy);

    if cli_args.pd_disaggregation && !prefill_urls.is_empty() {
        println!("Prefill nodes: {:?}", prefill_urls);
        println!("Decode nodes: {:?}", cli_args.decode);
    }

    // Convert to RouterConfig
    let router_config = cli_args.to_router_config(prefill_urls)?;

    // Validate configuration
    router_config.validate()?;

    // Create ServerConfig
    let server_config = cli_args.to_server_config(router_config);

    // Create a new runtime for the server (like Python binding does)
    let runtime = tokio::runtime::Runtime::new()?;

    // Block on the async startup function
    runtime.block_on(async move { server::startup(server_config).await })?;

    Ok(())
}
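The binary pre-scans argv for `--prefill <url> [bootstrap-port|none]` groups before handing the remaining arguments to clap, since a flag whose value is a URL followed by an optional bare port does not map cleanly onto clap's derive API. Below is a minimal sketch of that scan run over a hard-coded argument list instead of std::env::args; the `scan_prefill` name is hypothetical and not part of the commit.

```rust
// Sketch only: mirrors the --prefill scanning loop in main.rs on a fixed argv.
fn scan_prefill(args: &[&str]) -> Vec<(String, Option<u16>)> {
    let mut entries = Vec::new();
    let mut i = 0;
    while i < args.len() {
        if args[i] == "--prefill" && i + 1 < args.len() {
            let url = args[i + 1].to_string();
            let mut port = None;
            // Optional third token: a numeric bootstrap port, or the literal "none".
            if i + 2 < args.len() && !args[i + 2].starts_with("--") {
                if let Ok(p) = args[i + 2].parse::<u16>() {
                    port = Some(p);
                    i += 1; // consume the port token
                } else if args[i + 2].eq_ignore_ascii_case("none") {
                    i += 1; // consume the "none" token, no port
                }
            }
            entries.push((url, port));
            i += 2; // skip "--prefill" and the URL
        } else {
            i += 1;
        }
    }
    entries
}

fn main() {
    let argv = [
        "sglang-router",
        "--prefill", "http://127.0.0.1:30001", "9001",
        "--prefill", "http://127.0.0.2:30002", "none",
        "--decode", "http://127.0.0.3:30003",
    ];
    // Prints: [("http://127.0.0.1:30001", Some(9001)), ("http://127.0.0.2:30002", None)]
    println!("{:?}", scan_prefill(&argv));
}
```

`--decode` needs no such treatment because it is a plain repeatable flag handled by clap's `ArgAction::Append`.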
sgl-router/src/routers/factory.rs

@@ -11,29 +11,32 @@ pub struct RouterFactory;
 impl RouterFactory {
     /// Create a router instance from application context
-    pub fn create_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
+    pub async fn create_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
         match &ctx.router_config.mode {
             RoutingMode::Regular { worker_urls } => {
-                Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx)
+                Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx).await
             }
             RoutingMode::PrefillDecode {
                 prefill_urls,
                 decode_urls,
                 prefill_policy,
                 decode_policy,
-            } => Self::create_pd_router(
-                prefill_urls,
-                decode_urls,
-                prefill_policy.as_ref(),
-                decode_policy.as_ref(),
-                &ctx.router_config.policy,
-                ctx,
-            ),
+            } => {
+                Self::create_pd_router(
+                    prefill_urls,
+                    decode_urls,
+                    prefill_policy.as_ref(),
+                    decode_policy.as_ref(),
+                    &ctx.router_config.policy,
+                    ctx,
+                )
+                .await
+            }
         }
     }

     /// Create a regular router with injected policy
-    fn create_regular_router(
+    async fn create_regular_router(
         worker_urls: &[String],
         policy_config: &PolicyConfig,
         ctx: &Arc<AppContext>,

@@ -52,13 +55,14 @@ impl RouterFactory {
             ctx.router_config.api_key.clone(),
             ctx.router_config.retry.clone(),
             ctx.router_config.circuit_breaker.clone(),
-        )?;
+        )
+        .await?;

         Ok(Box::new(router))
     }

     /// Create a PD router with injected policy
-    fn create_pd_router(
+    async fn create_pd_router(
         prefill_urls: &[(String, Option<u16>)],
         decode_urls: &[String],
         prefill_policy_config: Option<&PolicyConfig>,

@@ -83,7 +87,8 @@ impl RouterFactory {
             ctx.router_config.worker_startup_check_interval_secs,
             ctx.router_config.retry.clone(),
             ctx.router_config.circuit_breaker.clone(),
-        )?;
+        )
+        .await?;

         Ok(Box::new(router))
     }
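Most of the remaining diff is the mechanical fallout of this change: the router constructors now perform their worker health checks asynchronously, so `create_router` and everything above it on the call path becomes `async`, and every caller (the server's `startup`, the test suites) gains an `.await`. Here is a minimal, self-contained sketch of that propagation pattern; the `Worker`/`build_router` names are hypothetical stand-ins, not the real sglang_router_rs types, and it assumes the `tokio` crate with its macros and time features.

```rust
// Hypothetical types, only to illustrate how async-ness propagates up the call chain.
struct Worker {
    url: String,
}

impl Worker {
    // Before this commit the constructor blocked on health checks; once the
    // check is async, the constructor has to be async too.
    async fn new(url: &str) -> Result<Self, String> {
        // Stand-in for an async /health probe.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        Ok(Self { url: url.to_string() })
    }
}

// ...which forces the factory function to become async...
async fn build_router(urls: &[&str]) -> Result<Vec<Worker>, String> {
    let mut workers = Vec::new();
    for url in urls {
        workers.push(Worker::new(url).await?);
    }
    Ok(workers)
}

// ...and the caller to `.await` it from an async context (the server's startup
// function, or a #[tokio::test] in the test files touched by this commit).
#[tokio::main]
async fn main() -> Result<(), String> {
    let workers = build_router(&["http://worker1:8000", "http://worker2:8000"]).await?;
    println!("created {} workers", workers.len());
    Ok(())
}
```

The alternative of keeping the constructors synchronous and wrapping each call in `spawn_blocking` is exactly what the test files further down stop doing.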
sgl-router/src/routers/pd_router.rs

@@ -67,6 +67,7 @@ impl PDRouter {
             self.timeout_secs,
             self.interval_secs,
         )
+        .await
         .map_err(|_| PDRouterError::HealthCheckFailed {
             url: url.to_string(),
         })

@@ -349,7 +350,7 @@ impl PDRouter {
         Ok(format!("Successfully removed decode server: {}", url))
     }

-    pub fn new(
+    pub async fn new(
         prefill_urls: Vec<(String, Option<u16>)>,
         decode_urls: Vec<String>,
         prefill_policy: Arc<dyn LoadBalancingPolicy>,

@@ -392,7 +393,8 @@ impl PDRouter {
             &all_urls,
             timeout_secs,
             interval_secs,
-        )?;
+        )
+        .await?;
         }

         // Initialize cache-aware policies with workers
sgl-router/src/routers/router.rs

@@ -17,7 +17,6 @@ use futures_util::StreamExt;
 use reqwest::Client;
 use std::collections::HashMap;
 use std::sync::{Arc, RwLock};
-use std::thread;
 use std::time::{Duration, Instant};
 use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{debug, error, info, warn};

@@ -52,7 +51,7 @@ pub struct Router {
 impl Router {
     /// Create a new router with injected policy and client
-    pub fn new(
+    pub async fn new(
         worker_urls: Vec<String>,
         policy: Arc<dyn LoadBalancingPolicy>,
         client: Client,

@@ -68,7 +67,7 @@ impl Router {
         // Wait for workers to be healthy (skip if empty - for service discovery mode)
         if !worker_urls.is_empty() {
-            Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?;
+            Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs).await?;
         }

         let worker_urls = if dp_aware {

@@ -156,7 +155,7 @@ impl Router {
             .collect()
     }

-    pub fn wait_for_healthy_workers(
+    pub async fn wait_for_healthy_workers(
         worker_urls: &[String],
         timeout_secs: u64,
         interval_secs: u64,

@@ -167,9 +166,24 @@ impl Router {
             );
         }

+        // Perform health check asynchronously
+        Self::wait_for_healthy_workers_async(worker_urls, timeout_secs, interval_secs).await
+    }
+
+    async fn wait_for_healthy_workers_async(
+        worker_urls: &[String],
+        timeout_secs: u64,
+        interval_secs: u64,
+    ) -> Result<(), String> {
+        info!(
+            "Waiting for {} workers to become healthy (timeout: {}s)",
+            worker_urls.len(),
+            timeout_secs
+        );
+
         let start_time = std::time::Instant::now();
-        let sync_client = reqwest::blocking::Client::builder()
-            .timeout(Duration::from_secs(timeout_secs))
+        let client = reqwest::Client::builder()
+            .timeout(Duration::from_secs(2))
             .build()
             .map_err(|e| format!("Failed to create HTTP client: {}", e))?;

@@ -185,20 +199,48 @@ impl Router {
             ));
         }

+        // Perform all health checks concurrently
+        let mut health_checks = Vec::new();
+        for url in worker_urls {
+            let client_clone = client.clone();
+            let url_clone = url.clone();
+            let check_health = tokio::spawn(async move {
+                let health_url = format!("{}/health", url_clone);
+                match client_clone.get(&health_url).send().await {
+                    Ok(res) => {
+                        if res.status().is_success() {
+                            None
+                        } else {
+                            Some((url_clone, format!("status: {}", res.status())))
+                        }
+                    }
+                    Err(_) => Some((url_clone, "not ready".to_string())),
+                }
+            });
+            health_checks.push(check_health);
+        }
+
+        // Wait for all health checks to complete
+        let results = futures::future::join_all(health_checks).await;
+
         let mut all_healthy = true;
         let mut unhealthy_workers = Vec::new();

-        for url in worker_urls {
-            match sync_client.get(&format!("{}/health", url)).send() {
-                Ok(res) => {
-                    if !res.status().is_success() {
-                        all_healthy = false;
-                        unhealthy_workers.push((url, format!("status: {}", res.status())));
-                    }
-                }
-                Err(_) => {
+        for result in results {
+            match result {
+                Ok(None) => {
+                    // Worker is healthy
+                }
+                Ok(Some((url, reason))) => {
                     all_healthy = false;
-                    unhealthy_workers.push((url, "not ready".to_string()));
+                    unhealthy_workers.push((url, reason));
+                }
+                Err(e) => {
+                    all_healthy = false;
+                    unhealthy_workers.push(("unknown".to_string(), format!("task error: {}", e)));
                 }
             }
         }

@@ -208,11 +250,12 @@ impl Router {
             return Ok(());
         } else {
             debug!(
-                "Waiting for {} workers to become healthy ({} unhealthy)",
+                "Waiting for {} workers to become healthy ({} unhealthy: {:?})",
                 worker_urls.len(),
-                unhealthy_workers.len()
+                unhealthy_workers.len(),
+                unhealthy_workers
             );
-            thread::sleep(Duration::from_secs(interval_secs));
+            tokio::time::sleep(Duration::from_secs(interval_secs)).await;
         }
     }

@@ -1246,19 +1289,19 @@ mod tests {
         assert_eq!(result.unwrap(), "http://worker1:8080");
     }

-    #[test]
-    fn test_wait_for_healthy_workers_empty_list() {
-        // Empty list will timeout as there are no workers to check
-        let result = Router::wait_for_healthy_workers(&[], 1, 1);
+    #[tokio::test]
+    async fn test_wait_for_healthy_workers_empty_list() {
+        // Empty list will return error immediately
+        let result = Router::wait_for_healthy_workers(&[], 1, 1).await;
         assert!(result.is_err());
-        assert!(result.unwrap_err().contains("Timeout"));
+        assert!(result.unwrap_err().contains("no workers provided"));
     }

-    #[test]
-    fn test_wait_for_healthy_workers_invalid_urls() {
+    #[tokio::test]
+    async fn test_wait_for_healthy_workers_invalid_urls() {
         // This test will timeout quickly since the URLs are invalid
         let result =
-            Router::wait_for_healthy_workers(&["http://nonexistent:8080".to_string()], 1, 1);
+            Router::wait_for_healthy_workers(&["http://nonexistent:8080".to_string()], 1, 1).await;
         assert!(result.is_err());
         assert!(result.unwrap_err().contains("Timeout"));
     }
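The core behavioural change in router.rs is the health-check loop: instead of probing each worker sequentially with a `reqwest::blocking` client, the new code spawns one task per URL and joins them all. A stripped-down sketch of that fan-out/join pattern follows; the `check_workers` helper and the URL are hypothetical, and it assumes the `tokio`, `reqwest`, and `futures` crates are available.

```rust
use std::time::Duration;

// Probe each worker's /health endpoint concurrently and collect the unhealthy ones.
async fn check_workers(urls: &[String]) -> Result<(), Vec<(String, String)>> {
    let client = reqwest::Client::builder()
        .timeout(Duration::from_secs(2)) // short per-request timeout, as in the new code
        .build()
        .expect("failed to build HTTP client");

    // Fan out: one spawned task per worker URL.
    let checks = urls.iter().cloned().map(|url| {
        let client = client.clone();
        tokio::spawn(async move {
            match client.get(format!("{}/health", url)).send().await {
                Ok(res) if res.status().is_success() => None,
                Ok(res) => Some((url, format!("status: {}", res.status()))),
                Err(_) => Some((url, "not ready".to_string())),
            }
        })
    });

    // Join: wait for every probe and gather the failures.
    let mut unhealthy = Vec::new();
    for result in futures::future::join_all(checks).await {
        match result {
            Ok(None) => {} // healthy
            Ok(Some(entry)) => unhealthy.push(entry),
            Err(e) => unhealthy.push(("unknown".to_string(), format!("task error: {}", e))),
        }
    }

    if unhealthy.is_empty() { Ok(()) } else { Err(unhealthy) }
}

#[tokio::main]
async fn main() {
    let urls = vec!["http://127.0.0.1:8000".to_string()];
    match check_workers(&urls).await {
        Ok(()) => println!("all workers healthy"),
        Err(bad) => println!("unhealthy workers: {:?}", bad),
    }
}
```

The per-request timeout is deliberately short (2 seconds in the diff); the outer loop in the real code keeps retrying at `interval_secs` until the overall worker-startup timeout expires.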
sgl-router/src/server.rs

@@ -285,7 +285,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
     ));

     // Create router with the context
-    let router = RouterFactory::create_router(&app_context)?;
+    let router = RouterFactory::create_router(&app_context).await?;

     // Create app state with router and context
     let app_state = Arc::new(AppState {
sgl-router/src/service_discovery.rs

@@ -576,7 +576,7 @@ mod tests {
     }

     // Helper to create a Router instance for testing event handlers
-    fn create_test_router() -> Arc<dyn RouterTrait> {
+    async fn create_test_router() -> Arc<dyn RouterTrait> {
         use crate::config::PolicyConfig;
         use crate::policies::PolicyFactory;
         use crate::routers::router::Router;

@@ -593,6 +593,7 @@ mod tests {
             crate::config::types::RetryConfig::default(),
             crate::config::types::CircuitBreakerConfig::default(),
         )
+        .await
         .unwrap();

         Arc::new(router) as Arc<dyn RouterTrait>
     }

@@ -896,7 +897,7 @@ mod tests {
     #[tokio::test]
     async fn test_handle_pod_event_add_unhealthy_pod() {
-        let router = create_test_router();
+        let router = create_test_router().await;
         let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
         let pod_info = PodInfo {
             name: "pod1".into(),

@@ -925,7 +926,7 @@ mod tests {
     #[tokio::test]
     async fn test_handle_pod_deletion_non_existing_pod() {
-        let router = create_test_router();
+        let router = create_test_router().await;
         let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
         let pod_info = PodInfo {
             name: "pod1".into(),

@@ -952,7 +953,7 @@ mod tests {
     #[tokio::test]
     async fn test_handle_pd_pod_event_prefill_pod() {
-        let router = create_test_router();
+        let router = create_test_router().await;
         let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
         let pod_info = PodInfo {
             name: "prefill-pod".into(),

@@ -981,7 +982,7 @@ mod tests {
     #[tokio::test]
     async fn test_handle_pd_pod_event_decode_pod() {
-        let router = create_test_router();
+        let router = create_test_router().await;
         let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
         let pod_info = PodInfo {
             name: "decode-pod".into(),

@@ -1008,7 +1009,7 @@ mod tests {
     #[tokio::test]
     async fn test_handle_pd_pod_deletion_tracked_pod() {
-        let router = create_test_router();
+        let router = create_test_router().await;
         let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
         let pod_info = PodInfo {
             name: "test-pod".into(),

@@ -1042,7 +1043,7 @@ mod tests {
     #[tokio::test]
     async fn test_handle_pd_pod_deletion_untracked_pod() {
-        let router = create_test_router();
+        let router = create_test_router().await;
         let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
         let pod_info = PodInfo {
             name: "untracked-pod".into(),

@@ -1071,7 +1072,7 @@ mod tests {
     #[tokio::test]
     async fn test_unified_handler_regular_mode() {
-        let router = create_test_router();
+        let router = create_test_router().await;
         let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
         let pod_info = PodInfo {
             name: "regular-pod".into(),

@@ -1099,7 +1100,7 @@ mod tests {
     #[tokio::test]
     async fn test_unified_handler_pd_mode_with_prefill() {
-        let router = create_test_router();
+        let router = create_test_router().await;
         let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
         let pod_info = PodInfo {
             name: "prefill-pod".into(),

@@ -1127,7 +1128,7 @@ mod tests {
     #[tokio::test]
     async fn test_unified_handler_deletion_with_pd_mode() {
-        let router = create_test_router();
+        let router = create_test_router().await;
         let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
         let pod_info = PodInfo {
             name: "decode-pod".into(),
sgl-router/tests/api_endpoints_test.rs

@@ -92,12 +92,8 @@ impl TestContext {
         // Create app context
         let app_context = common::create_test_context(config.clone());

-        // Create router using sync factory in a blocking context
-        let router =
-            tokio::task::spawn_blocking(move || RouterFactory::create_router(&app_context))
-                .await
-                .unwrap()
-                .unwrap();
+        // Create router
+        let router = RouterFactory::create_router(&app_context).await.unwrap();
         let router = Arc::from(router);

         // Wait for router to discover workers

@@ -1451,10 +1447,7 @@ mod pd_mode_tests {
         let app_context = common::create_test_context(config);

         // Create router - this might fail due to health check issues
-        let router_result =
-            tokio::task::spawn_blocking(move || RouterFactory::create_router(&app_context))
-                .await
-                .unwrap();
+        let router_result = RouterFactory::create_router(&app_context).await;

         // Clean up workers
         prefill_worker.stop().await;
sgl-router/tests/request_formats_test.rs

@@ -60,11 +60,7 @@ impl TestContext {
         config.mode = RoutingMode::Regular { worker_urls };

         let app_context = common::create_test_context(config);
-        let router =
-            tokio::task::spawn_blocking(move || RouterFactory::create_router(&app_context))
-                .await
-                .unwrap()
-                .unwrap();
+        let router = RouterFactory::create_router(&app_context).await.unwrap();
         let router = Arc::from(router);

         if !workers.is_empty() {
sgl-router/tests/streaming_tests.rs

@@ -61,11 +61,7 @@ impl TestContext {
         config.mode = RoutingMode::Regular { worker_urls };

         let app_context = common::create_test_context(config);
-        let router =
-            tokio::task::spawn_blocking(move || RouterFactory::create_router(&app_context))
-                .await
-                .unwrap()
-                .unwrap();
+        let router = RouterFactory::create_router(&app_context).await.unwrap();
        let router = Arc::from(router);

         if !workers.is_empty() {
sgl-router/tests/test_pd_routing.rs

@@ -109,8 +109,8 @@ mod test_pd_routing {
         }
     }

-    #[test]
-    fn test_pd_router_configuration() {
+    #[tokio::test]
+    async fn test_pd_router_configuration() {
         // Test PD router configuration with various policies
         // In the new structure, RoutingMode and PolicyConfig are separate
         let test_cases = vec![

@@ -190,7 +190,7 @@ mod test_pd_routing {
         let app_context =
             sglang_router_rs::server::AppContext::new(config, reqwest::Client::new(), 64);
         let app_context = std::sync::Arc::new(app_context);
-        let result = RouterFactory::create_router(&app_context);
+        let result = RouterFactory::create_router(&app_context).await;
         assert!(result.is_err());
         let error_msg = result.unwrap_err();
         // Error should be about health/timeout, not configuration