Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
9a0cac1b
Unverified
Commit
9a0cac1b
authored
Sep 01, 2025
by
Chang Su
Committed by
GitHub
Sep 01, 2025
Browse files
[router] add grpc pd and regular router init (#9893)
parent
b5245064
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
783 additions
and
58 deletions
+783
-58
sgl-router/py_src/sglang_router/launch_router.py
sgl-router/py_src/sglang_router/launch_router.py
+20
-0
sgl-router/py_src/sglang_router/router.py
sgl-router/py_src/sglang_router/router.py
+6
-0
sgl-router/py_test/test_launch_router.py
sgl-router/py_test/test_launch_router.py
+2
-0
sgl-router/src/config/types.rs
sgl-router/src/config/types.rs
+32
-0
sgl-router/src/config/validation.rs
sgl-router/src/config/validation.rs
+71
-2
sgl-router/src/lib.rs
sgl-router/src/lib.rs
+59
-0
sgl-router/src/main.rs
sgl-router/src/main.rs
+57
-2
sgl-router/src/routers/factory.rs
sgl-router/src/routers/factory.rs
+131
-36
sgl-router/src/routers/grpc/pd_router.rs
sgl-router/src/routers/grpc/pd_router.rs
+226
-7
sgl-router/src/routers/grpc/router.rs
sgl-router/src/routers/grpc/router.rs
+154
-7
sgl-router/tests/api_endpoints_test.rs
sgl-router/tests/api_endpoints_test.rs
+13
-1
sgl-router/tests/request_formats_test.rs
sgl-router/tests/request_formats_test.rs
+4
-1
sgl-router/tests/streaming_tests.rs
sgl-router/tests/streaming_tests.rs
+4
-1
sgl-router/tests/test_pd_routing.rs
sgl-router/tests/test_pd_routing.rs
+4
-1
No files found.
sgl-router/py_src/sglang_router/launch_router.py
View file @
9a0cac1b
...
@@ -99,6 +99,9 @@ class RouterArgs:
...
@@ -99,6 +99,9 @@ class RouterArgs:
cb_timeout_duration_secs
:
int
=
60
cb_timeout_duration_secs
:
int
=
60
cb_window_duration_secs
:
int
=
120
cb_window_duration_secs
:
int
=
120
disable_circuit_breaker
:
bool
=
False
disable_circuit_breaker
:
bool
=
False
# Tokenizer configuration
model_path
:
Optional
[
str
]
=
None
tokenizer_path
:
Optional
[
str
]
=
None
@
staticmethod
@
staticmethod
def
add_cli_args
(
def
add_cli_args
(
...
@@ -433,6 +436,19 @@ class RouterArgs:
...
@@ -433,6 +436,19 @@ class RouterArgs:
default
=
[],
default
=
[],
help
=
"CORS allowed origins (e.g., http://localhost:3000 https://example.com)"
,
help
=
"CORS allowed origins (e.g., http://localhost:3000 https://example.com)"
,
)
)
# Tokenizer configuration
parser
.
add_argument
(
f
"--
{
prefix
}
model-path"
,
type
=
str
,
default
=
None
,
help
=
"Model path for loading tokenizer (HuggingFace model ID or local path)"
,
)
parser
.
add_argument
(
f
"--
{
prefix
}
tokenizer-path"
,
type
=
str
,
default
=
None
,
help
=
"Explicit tokenizer path (overrides model_path tokenizer if provided)"
,
)
@
classmethod
@
classmethod
def
from_cli_args
(
def
from_cli_args
(
...
@@ -554,6 +570,8 @@ class RouterArgs:
...
@@ -554,6 +570,8 @@ class RouterArgs:
health_check_endpoint
=
getattr
(
health_check_endpoint
=
getattr
(
args
,
f
"
{
prefix
}
health_check_endpoint"
,
RouterArgs
.
health_check_endpoint
args
,
f
"
{
prefix
}
health_check_endpoint"
,
RouterArgs
.
health_check_endpoint
),
),
model_path
=
getattr
(
args
,
f
"
{
prefix
}
model_path"
,
None
),
tokenizer_path
=
getattr
(
args
,
f
"
{
prefix
}
tokenizer_path"
,
None
),
)
)
@
staticmethod
@
staticmethod
...
@@ -759,6 +777,8 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
...
@@ -759,6 +777,8 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
health_check_timeout_secs
=
router_args
.
health_check_timeout_secs
,
health_check_timeout_secs
=
router_args
.
health_check_timeout_secs
,
health_check_interval_secs
=
router_args
.
health_check_interval_secs
,
health_check_interval_secs
=
router_args
.
health_check_interval_secs
,
health_check_endpoint
=
router_args
.
health_check_endpoint
,
health_check_endpoint
=
router_args
.
health_check_endpoint
,
model_path
=
router_args
.
model_path
,
tokenizer_path
=
router_args
.
tokenizer_path
,
)
)
router
.
start
()
router
.
start
()
...
...
sgl-router/py_src/sglang_router/router.py
View file @
9a0cac1b
...
@@ -74,6 +74,8 @@ class Router:
...
@@ -74,6 +74,8 @@ class Router:
health_check_timeout_secs: Timeout in seconds for health check requests. Default: 5
health_check_timeout_secs: Timeout in seconds for health check requests. Default: 5
health_check_interval_secs: Interval in seconds between runtime health checks. Default: 60
health_check_interval_secs: Interval in seconds between runtime health checks. Default: 60
health_check_endpoint: Health check endpoint path. Default: '/health'
health_check_endpoint: Health check endpoint path. Default: '/health'
model_path: Model path for loading tokenizer (HuggingFace model ID or local path). Default: None
tokenizer_path: Explicit tokenizer path (overrides model_path tokenizer if provided). Default: None
"""
"""
def
__init__
(
def
__init__
(
...
@@ -131,6 +133,8 @@ class Router:
...
@@ -131,6 +133,8 @@ class Router:
health_check_timeout_secs
:
int
=
5
,
health_check_timeout_secs
:
int
=
5
,
health_check_interval_secs
:
int
=
60
,
health_check_interval_secs
:
int
=
60
,
health_check_endpoint
:
str
=
"/health"
,
health_check_endpoint
:
str
=
"/health"
,
model_path
:
Optional
[
str
]
=
None
,
tokenizer_path
:
Optional
[
str
]
=
None
,
):
):
if
selector
is
None
:
if
selector
is
None
:
selector
=
{}
selector
=
{}
...
@@ -195,6 +199,8 @@ class Router:
...
@@ -195,6 +199,8 @@ class Router:
health_check_timeout_secs
=
health_check_timeout_secs
,
health_check_timeout_secs
=
health_check_timeout_secs
,
health_check_interval_secs
=
health_check_interval_secs
,
health_check_interval_secs
=
health_check_interval_secs
,
health_check_endpoint
=
health_check_endpoint
,
health_check_endpoint
=
health_check_endpoint
,
model_path
=
model_path
,
tokenizer_path
=
tokenizer_path
,
)
)
def
start
(
self
)
->
None
:
def
start
(
self
)
->
None
:
...
...
sgl-router/py_test/test_launch_router.py
View file @
9a0cac1b
...
@@ -64,6 +64,8 @@ class TestLaunchRouter(unittest.TestCase):
...
@@ -64,6 +64,8 @@ class TestLaunchRouter(unittest.TestCase):
cb_window_duration_secs
=
60
,
cb_window_duration_secs
=
60
,
disable_retries
=
False
,
disable_retries
=
False
,
disable_circuit_breaker
=
False
,
disable_circuit_breaker
=
False
,
model_path
=
None
,
tokenizer_path
=
None
,
)
)
def
create_router_args
(
self
,
**
kwargs
):
def
create_router_args
(
self
,
**
kwargs
):
...
...
sgl-router/src/config/types.rs
View file @
9a0cac1b
...
@@ -7,6 +7,9 @@ use std::collections::HashMap;
...
@@ -7,6 +7,9 @@ use std::collections::HashMap;
pub
struct
RouterConfig
{
pub
struct
RouterConfig
{
/// Routing mode configuration
/// Routing mode configuration
pub
mode
:
RoutingMode
,
pub
mode
:
RoutingMode
,
/// Worker connection mode
#[serde(default)]
pub
connection_mode
:
ConnectionMode
,
/// Policy configuration
/// Policy configuration
pub
policy
:
PolicyConfig
,
pub
policy
:
PolicyConfig
,
/// Server host address
/// Server host address
...
@@ -60,6 +63,20 @@ pub struct RouterConfig {
...
@@ -60,6 +63,20 @@ pub struct RouterConfig {
/// Enable Inference Gateway mode (false = proxy mode, true = IGW mode)
/// Enable Inference Gateway mode (false = proxy mode, true = IGW mode)
#[serde(default)]
#[serde(default)]
pub
enable_igw
:
bool
,
pub
enable_igw
:
bool
,
/// Model path for loading tokenizer (can be a HuggingFace model ID or local path)
pub
model_path
:
Option
<
String
>
,
/// Explicit tokenizer path (overrides model_path tokenizer if provided)
pub
tokenizer_path
:
Option
<
String
>
,
}
#[derive(Debug,
Clone,
Serialize,
Deserialize,
Default,
PartialEq)]
#[serde(tag
=
"type"
)]
pub
enum
ConnectionMode
{
#[default]
#[serde(rename
=
"http"
)]
Http
,
#[serde(rename
=
"grpc"
)]
Grpc
,
}
}
/// Routing mode configuration
/// Routing mode configuration
...
@@ -336,6 +353,9 @@ impl Default for RouterConfig {
...
@@ -336,6 +353,9 @@ impl Default for RouterConfig {
disable_circuit_breaker
:
false
,
disable_circuit_breaker
:
false
,
health_check
:
HealthCheckConfig
::
default
(),
health_check
:
HealthCheckConfig
::
default
(),
enable_igw
:
false
,
enable_igw
:
false
,
connection_mode
:
ConnectionMode
::
Http
,
model_path
:
None
,
tokenizer_path
:
None
,
}
}
}
}
}
}
...
@@ -478,6 +498,9 @@ mod tests {
...
@@ -478,6 +498,9 @@ mod tests {
queue_size
:
100
,
queue_size
:
100
,
queue_timeout_secs
:
60
,
queue_timeout_secs
:
60
,
rate_limit_tokens_per_second
:
None
,
rate_limit_tokens_per_second
:
None
,
connection_mode
:
ConnectionMode
::
Http
,
model_path
:
None
,
tokenizer_path
:
None
,
};
};
let
json
=
serde_json
::
to_string
(
&
config
)
.unwrap
();
let
json
=
serde_json
::
to_string
(
&
config
)
.unwrap
();
...
@@ -914,6 +937,9 @@ mod tests {
...
@@ -914,6 +937,9 @@ mod tests {
queue_size
:
100
,
queue_size
:
100
,
queue_timeout_secs
:
60
,
queue_timeout_secs
:
60
,
rate_limit_tokens_per_second
:
None
,
rate_limit_tokens_per_second
:
None
,
connection_mode
:
ConnectionMode
::
Http
,
model_path
:
None
,
tokenizer_path
:
None
,
};
};
assert
!
(
config
.mode
.is_pd_mode
());
assert
!
(
config
.mode
.is_pd_mode
());
...
@@ -974,6 +1000,9 @@ mod tests {
...
@@ -974,6 +1000,9 @@ mod tests {
queue_size
:
100
,
queue_size
:
100
,
queue_timeout_secs
:
60
,
queue_timeout_secs
:
60
,
rate_limit_tokens_per_second
:
None
,
rate_limit_tokens_per_second
:
None
,
connection_mode
:
ConnectionMode
::
Http
,
model_path
:
None
,
tokenizer_path
:
None
,
};
};
assert
!
(
!
config
.mode
.is_pd_mode
());
assert
!
(
!
config
.mode
.is_pd_mode
());
...
@@ -1030,6 +1059,9 @@ mod tests {
...
@@ -1030,6 +1059,9 @@ mod tests {
queue_size
:
100
,
queue_size
:
100
,
queue_timeout_secs
:
60
,
queue_timeout_secs
:
60
,
rate_limit_tokens_per_second
:
None
,
rate_limit_tokens_per_second
:
None
,
connection_mode
:
ConnectionMode
::
Http
,
model_path
:
None
,
tokenizer_path
:
None
,
};
};
assert
!
(
config
.has_service_discovery
());
assert
!
(
config
.has_service_discovery
());
...
...
sgl-router/src/config/validation.rs
View file @
9a0cac1b
...
@@ -349,6 +349,16 @@ impl ConfigValidator {
...
@@ -349,6 +349,16 @@ impl ConfigValidator {
return
Ok
(());
return
Ok
(());
}
}
// Validate gRPC connection mode requires tokenizer configuration
if
config
.connection_mode
==
ConnectionMode
::
Grpc
&&
config
.tokenizer_path
.is_none
()
&&
config
.model_path
.is_none
()
{
return
Err
(
ConfigError
::
ValidationFailed
{
reason
:
"gRPC connection mode requires either --tokenizer-path or --model-path to be specified"
.to_string
(),
});
}
// All policies are now supported for both router types thanks to the unified trait design
// All policies are now supported for both router types thanks to the unified trait design
// No mode/policy restrictions needed anymore
// No mode/policy restrictions needed anymore
...
@@ -419,11 +429,14 @@ impl ConfigValidator {
...
@@ -419,11 +429,14 @@ impl ConfigValidator {
});
});
}
}
if
!
url
.starts_with
(
"http://"
)
&&
!
url
.starts_with
(
"https://"
)
{
if
!
url
.starts_with
(
"http://"
)
&&
!
url
.starts_with
(
"https://"
)
&&
!
url
.starts_with
(
"grpc://"
)
{
return
Err
(
ConfigError
::
InvalidValue
{
return
Err
(
ConfigError
::
InvalidValue
{
field
:
"worker_url"
.to_string
(),
field
:
"worker_url"
.to_string
(),
value
:
url
.clone
(),
value
:
url
.clone
(),
reason
:
"URL must start with http://
or
https://"
.to_string
(),
reason
:
"URL must start with http://
,
https://
, or grpc://
"
.to_string
(),
});
});
}
}
...
@@ -684,4 +697,60 @@ mod tests {
...
@@ -684,4 +697,60 @@ mod tests {
assert
!
(
e
.to_string
()
.contains
(
"prefill requires at least 2"
));
assert
!
(
e
.to_string
()
.contains
(
"prefill requires at least 2"
));
}
}
}
}
#[test]
fn
test_validate_grpc_requires_tokenizer
()
{
// Test that gRPC connection mode requires tokenizer configuration
let
mut
config
=
RouterConfig
::
new
(
RoutingMode
::
Regular
{
worker_urls
:
vec!
[
"grpc://worker:50051"
.to_string
()],
},
PolicyConfig
::
Random
,
);
// Set connection mode to gRPC without tokenizer config
config
.connection_mode
=
ConnectionMode
::
Grpc
;
config
.tokenizer_path
=
None
;
config
.model_path
=
None
;
let
result
=
ConfigValidator
::
validate
(
&
config
);
assert
!
(
result
.is_err
());
if
let
Err
(
e
)
=
result
{
assert
!
(
e
.to_string
()
.contains
(
"gRPC connection mode requires"
));
}
}
#[test]
fn
test_validate_grpc_with_model_path
()
{
// Test that gRPC works with model_path
let
mut
config
=
RouterConfig
::
new
(
RoutingMode
::
Regular
{
worker_urls
:
vec!
[
"grpc://worker:50051"
.to_string
()],
},
PolicyConfig
::
Random
,
);
config
.connection_mode
=
ConnectionMode
::
Grpc
;
config
.model_path
=
Some
(
"meta-llama/Llama-3-8B"
.to_string
());
let
result
=
ConfigValidator
::
validate
(
&
config
);
assert
!
(
result
.is_ok
());
}
#[test]
fn
test_validate_grpc_with_tokenizer_path
()
{
// Test that gRPC works with tokenizer_path
let
mut
config
=
RouterConfig
::
new
(
RoutingMode
::
Regular
{
worker_urls
:
vec!
[
"grpc://worker:50051"
.to_string
()],
},
PolicyConfig
::
Random
,
);
config
.connection_mode
=
ConnectionMode
::
Grpc
;
config
.tokenizer_path
=
Some
(
"/path/to/tokenizer.json"
.to_string
());
let
result
=
ConfigValidator
::
validate
(
&
config
);
assert
!
(
result
.is_ok
());
}
}
}
sgl-router/src/lib.rs
View file @
9a0cac1b
...
@@ -2,6 +2,7 @@ use pyo3::prelude::*;
...
@@ -2,6 +2,7 @@ use pyo3::prelude::*;
pub
mod
config
;
pub
mod
config
;
pub
mod
logging
;
pub
mod
logging
;
use
std
::
collections
::
HashMap
;
use
std
::
collections
::
HashMap
;
pub
mod
core
;
pub
mod
core
;
#[cfg(feature
=
"grpc-client"
)]
#[cfg(feature
=
"grpc-client"
)]
pub
mod
grpc
;
pub
mod
grpc
;
...
@@ -89,9 +90,39 @@ struct Router {
...
@@ -89,9 +90,39 @@ struct Router {
queue_size
:
usize
,
queue_size
:
usize
,
queue_timeout_secs
:
u64
,
queue_timeout_secs
:
u64
,
rate_limit_tokens_per_second
:
Option
<
usize
>
,
rate_limit_tokens_per_second
:
Option
<
usize
>
,
// Connection mode (determined from worker URLs)
connection_mode
:
config
::
ConnectionMode
,
// Model path for tokenizer
model_path
:
Option
<
String
>
,
// Explicit tokenizer path
tokenizer_path
:
Option
<
String
>
,
}
}
impl
Router
{
impl
Router
{
/// Determine connection mode from worker URLs
fn
determine_connection_mode
(
worker_urls
:
&
[
String
])
->
config
::
ConnectionMode
{
// Check if any URL is a gRPC endpoint (starts with grpc:// or has port that commonly indicates gRPC)
for
url
in
worker_urls
{
if
url
.starts_with
(
"grpc://"
)
||
url
.starts_with
(
"grpcs://"
)
{
return
config
::
ConnectionMode
::
Grpc
;
}
// Also check for common gRPC ports if the scheme isn't specified
if
let
Ok
(
parsed_url
)
=
url
::
Url
::
parse
(
url
)
{
if
let
Some
(
port
)
=
parsed_url
.port
()
{
// Common gRPC ports
if
port
==
50051
||
port
==
9090
||
((
50000
..=
50100
)
.contains
(
&
port
))
{
return
config
::
ConnectionMode
::
Grpc
;
}
}
}
else
if
url
.contains
(
":50051"
)
||
url
.contains
(
":9090"
)
||
url
.contains
(
":5000"
)
{
// Fallback check for URLs that might not parse correctly
return
config
::
ConnectionMode
::
Grpc
;
}
}
// Default to HTTP
config
::
ConnectionMode
::
Http
}
/// Convert PyO3 Router to RouterConfig
/// Convert PyO3 Router to RouterConfig
pub
fn
to_router_config
(
&
self
)
->
config
::
ConfigResult
<
config
::
RouterConfig
>
{
pub
fn
to_router_config
(
&
self
)
->
config
::
ConfigResult
<
config
::
RouterConfig
>
{
use
config
::{
use
config
::{
...
@@ -168,6 +199,7 @@ impl Router {
...
@@ -168,6 +199,7 @@ impl Router {
policy
,
policy
,
host
:
self
.host
.clone
(),
host
:
self
.host
.clone
(),
port
:
self
.port
,
port
:
self
.port
,
connection_mode
:
self
.connection_mode
.clone
(),
max_payload_size
:
self
.max_payload_size
,
max_payload_size
:
self
.max_payload_size
,
request_timeout_secs
:
self
.request_timeout_secs
,
request_timeout_secs
:
self
.request_timeout_secs
,
worker_startup_timeout_secs
:
self
.worker_startup_timeout_secs
,
worker_startup_timeout_secs
:
self
.worker_startup_timeout_secs
,
...
@@ -207,6 +239,8 @@ impl Router {
...
@@ -207,6 +239,8 @@ impl Router {
endpoint
:
self
.health_check_endpoint
.clone
(),
endpoint
:
self
.health_check_endpoint
.clone
(),
},
},
enable_igw
:
self
.enable_igw
,
enable_igw
:
self
.enable_igw
,
model_path
:
self
.model_path
.clone
(),
tokenizer_path
:
self
.tokenizer_path
.clone
(),
})
})
}
}
}
}
...
@@ -273,6 +307,9 @@ impl Router {
...
@@ -273,6 +307,9 @@ impl Router {
queue_size
=
100
,
queue_size
=
100
,
queue_timeout_secs
=
60
,
queue_timeout_secs
=
60
,
rate_limit_tokens_per_second
=
None
,
rate_limit_tokens_per_second
=
None
,
// Tokenizer defaults
model_path
=
None
,
tokenizer_path
=
None
,
))]
))]
#[allow(clippy::too_many_arguments)]
#[allow(clippy::too_many_arguments)]
fn
new
(
fn
new
(
...
@@ -330,7 +367,26 @@ impl Router {
...
@@ -330,7 +367,26 @@ impl Router {
queue_size
:
usize
,
queue_size
:
usize
,
queue_timeout_secs
:
u64
,
queue_timeout_secs
:
u64
,
rate_limit_tokens_per_second
:
Option
<
usize
>
,
rate_limit_tokens_per_second
:
Option
<
usize
>
,
model_path
:
Option
<
String
>
,
tokenizer_path
:
Option
<
String
>
,
)
->
PyResult
<
Self
>
{
)
->
PyResult
<
Self
>
{
// Determine connection mode from worker URLs
let
mut
all_urls
=
worker_urls
.clone
();
// Add prefill URLs if in PD mode
if
let
Some
(
ref
prefill_urls
)
=
prefill_urls
{
for
(
url
,
_
)
in
prefill_urls
{
all_urls
.push
(
url
.clone
());
}
}
// Add decode URLs if in PD mode
if
let
Some
(
ref
decode_urls
)
=
decode_urls
{
all_urls
.extend
(
decode_urls
.clone
());
}
let
connection_mode
=
Self
::
determine_connection_mode
(
&
all_urls
);
Ok
(
Router
{
Ok
(
Router
{
host
,
host
,
port
,
port
,
...
@@ -386,6 +442,9 @@ impl Router {
...
@@ -386,6 +442,9 @@ impl Router {
queue_size
,
queue_size
,
queue_timeout_secs
,
queue_timeout_secs
,
rate_limit_tokens_per_second
,
rate_limit_tokens_per_second
,
connection_mode
,
model_path
,
tokenizer_path
,
})
})
}
}
...
...
sgl-router/src/main.rs
View file @
9a0cac1b
use
clap
::{
ArgAction
,
Parser
};
use
clap
::{
ArgAction
,
Parser
};
use
sglang_router_rs
::
config
::{
use
sglang_router_rs
::
config
::{
CircuitBreakerConfig
,
ConfigError
,
ConfigResult
,
DiscoveryConfig
,
HealthCheck
Config
,
CircuitBreakerConfig
,
ConfigError
,
ConfigResult
,
ConnectionMode
,
Discovery
Config
,
MetricsConfig
,
PolicyConfig
,
RetryConfig
,
RouterConfig
,
RoutingMode
,
HealthCheckConfig
,
MetricsConfig
,
PolicyConfig
,
RetryConfig
,
RouterConfig
,
RoutingMode
,
};
};
use
sglang_router_rs
::
metrics
::
PrometheusConfig
;
use
sglang_router_rs
::
metrics
::
PrometheusConfig
;
use
sglang_router_rs
::
server
::{
self
,
ServerConfig
};
use
sglang_router_rs
::
server
::{
self
,
ServerConfig
};
...
@@ -272,9 +272,42 @@ struct CliArgs {
...
@@ -272,9 +272,42 @@ struct CliArgs {
/// Enable Inference Gateway mode
/// Enable Inference Gateway mode
#[arg(long,
default_value_t
=
false
)]
#[arg(long,
default_value_t
=
false
)]
enable_igw
:
bool
,
enable_igw
:
bool
,
// Tokenizer configuration
/// Model path for loading tokenizer (HuggingFace model ID or local path)
#[arg(long)]
model_path
:
Option
<
String
>
,
/// Explicit tokenizer path (overrides model_path tokenizer if provided)
#[arg(long)]
tokenizer_path
:
Option
<
String
>
,
}
}
impl
CliArgs
{
impl
CliArgs
{
/// Determine connection mode from worker URLs
fn
determine_connection_mode
(
worker_urls
:
&
[
String
])
->
ConnectionMode
{
// Check if any URL is a gRPC endpoint (starts with grpc:// or has port that commonly indicates gRPC)
for
url
in
worker_urls
{
if
url
.starts_with
(
"grpc://"
)
||
url
.starts_with
(
"grpcs://"
)
{
return
ConnectionMode
::
Grpc
;
}
// Also check for common gRPC ports if the scheme isn't specified
if
let
Ok
(
parsed_url
)
=
url
::
Url
::
parse
(
url
)
{
if
let
Some
(
port
)
=
parsed_url
.port
()
{
// Common gRPC ports
if
port
==
50051
||
port
==
9090
||
((
50000
..=
50100
)
.contains
(
&
port
))
{
return
ConnectionMode
::
Grpc
;
}
}
}
else
if
url
.contains
(
":50051"
)
||
url
.contains
(
":9090"
)
||
url
.contains
(
":5000"
)
{
// Fallback check for URLs that might not parse correctly
return
ConnectionMode
::
Grpc
;
}
}
// Default to HTTP
ConnectionMode
::
Http
}
/// Parse selector strings into HashMap
/// Parse selector strings into HashMap
fn
parse_selector
(
selector_list
:
&
[
String
])
->
HashMap
<
String
,
String
>
{
fn
parse_selector
(
selector_list
:
&
[
String
])
->
HashMap
<
String
,
String
>
{
let
mut
map
=
HashMap
::
new
();
let
mut
map
=
HashMap
::
new
();
...
@@ -372,10 +405,30 @@ impl CliArgs {
...
@@ -372,10 +405,30 @@ impl CliArgs {
host
:
self
.prometheus_host
.clone
(),
host
:
self
.prometheus_host
.clone
(),
});
});
// Determine connection mode from all worker URLs
let
mut
all_urls
=
Vec
::
new
();
match
&
mode
{
RoutingMode
::
Regular
{
worker_urls
}
=>
{
all_urls
.extend
(
worker_urls
.clone
());
}
RoutingMode
::
PrefillDecode
{
prefill_urls
,
decode_urls
,
..
}
=>
{
for
(
url
,
_
)
in
prefill_urls
{
all_urls
.push
(
url
.clone
());
}
all_urls
.extend
(
decode_urls
.clone
());
}
}
let
connection_mode
=
Self
::
determine_connection_mode
(
&
all_urls
);
// Build RouterConfig
// Build RouterConfig
Ok
(
RouterConfig
{
Ok
(
RouterConfig
{
mode
,
mode
,
policy
,
policy
,
connection_mode
,
host
:
self
.host
.clone
(),
host
:
self
.host
.clone
(),
port
:
self
.port
,
port
:
self
.port
,
max_payload_size
:
self
.max_payload_size
,
max_payload_size
:
self
.max_payload_size
,
...
@@ -421,6 +474,8 @@ impl CliArgs {
...
@@ -421,6 +474,8 @@ impl CliArgs {
},
},
enable_igw
:
self
.enable_igw
,
enable_igw
:
self
.enable_igw
,
rate_limit_tokens_per_second
:
None
,
rate_limit_tokens_per_second
:
None
,
model_path
:
self
.model_path
.clone
(),
tokenizer_path
:
self
.tokenizer_path
.clone
(),
})
})
}
}
...
...
sgl-router/src/routers/factory.rs
View file @
9a0cac1b
...
@@ -4,7 +4,7 @@ use super::{
...
@@ -4,7 +4,7 @@ use super::{
http
::{
pd_router
::
PDRouter
,
router
::
Router
},
http
::{
pd_router
::
PDRouter
,
router
::
Router
},
RouterTrait
,
RouterTrait
,
};
};
use
crate
::
config
::{
PolicyConfig
,
RoutingMode
};
use
crate
::
config
::{
ConnectionMode
,
PolicyConfig
,
RoutingMode
};
use
crate
::
policies
::
PolicyFactory
;
use
crate
::
policies
::
PolicyFactory
;
use
crate
::
server
::
AppContext
;
use
crate
::
server
::
AppContext
;
use
std
::
sync
::
Arc
;
use
std
::
sync
::
Arc
;
...
@@ -20,28 +20,56 @@ impl RouterFactory {
...
@@ -20,28 +20,56 @@ impl RouterFactory {
return
Self
::
create_igw_router
(
ctx
)
.await
;
return
Self
::
create_igw_router
(
ctx
)
.await
;
}
}
// TODO: Add gRPC mode check here when implementing gRPC support
// Check connection mode and route to appropriate implementation
match
ctx
.router_config.connection_mode
{
// Default to HTTP proxy mode
ConnectionMode
::
Grpc
=>
{
match
&
ctx
.router_config.mode
{
// Route to gRPC implementation based on routing mode
RoutingMode
::
Regular
{
worker_urls
}
=>
{
match
&
ctx
.router_config.mode
{
Self
::
create_regular_router
(
worker_urls
,
&
ctx
.router_config.policy
,
ctx
)
.await
RoutingMode
::
Regular
{
worker_urls
}
=>
{
Self
::
create_grpc_router
(
worker_urls
,
&
ctx
.router_config.policy
,
ctx
)
.await
}
RoutingMode
::
PrefillDecode
{
prefill_urls
,
decode_urls
,
prefill_policy
,
decode_policy
,
}
=>
{
Self
::
create_grpc_pd_router
(
prefill_urls
,
decode_urls
,
prefill_policy
.as_ref
(),
decode_policy
.as_ref
(),
&
ctx
.router_config.policy
,
ctx
,
)
.await
}
}
}
}
RoutingMode
::
PrefillDecode
{
ConnectionMode
::
Http
=>
{
prefill_urls
,
// Route to HTTP implementation based on routing mode
decode_urls
,
match
&
ctx
.router_config.mode
{
prefill_policy
,
RoutingMode
::
Regular
{
worker_urls
}
=>
{
decode_policy
,
Self
::
create_regular_router
(
worker_urls
,
&
ctx
.router_config.policy
,
ctx
)
}
=>
{
.await
Self
::
create_pd_router
(
}
prefill_urls
,
RoutingMode
::
PrefillDecode
{
decode_urls
,
prefill_urls
,
prefill_policy
.as_ref
(),
decode_urls
,
decode_policy
.as_ref
(),
prefill_policy
,
&
ctx
.router_config.policy
,
decode_policy
,
ctx
,
}
=>
{
)
Self
::
create_pd_router
(
.await
prefill_urls
,
decode_urls
,
prefill_policy
.as_ref
(),
decode_policy
.as_ref
(),
&
ctx
.router_config.policy
,
ctx
,
)
.await
}
}
}
}
}
}
}
}
...
@@ -109,25 +137,92 @@ impl RouterFactory {
...
@@ -109,25 +137,92 @@ impl RouterFactory {
/// Create a gRPC router with injected policy
/// Create a gRPC router with injected policy
pub
async
fn
create_grpc_router
(
pub
async
fn
create_grpc_router
(
_
worker_urls
:
&
[
String
],
worker_urls
:
&
[
String
],
_
policy_config
:
&
PolicyConfig
,
policy_config
:
&
PolicyConfig
,
_
ctx
:
&
Arc
<
AppContext
>
,
ctx
:
&
Arc
<
AppContext
>
,
)
->
Result
<
Box
<
dyn
RouterTrait
>
,
String
>
{
)
->
Result
<
Box
<
dyn
RouterTrait
>
,
String
>
{
// For now, return an error as gRPC router is not yet implemented
use
super
::
grpc
::
router
::
GrpcRouter
;
Err
(
"gRPC router is not yet implemented"
.to_string
())
// Create policy
let
policy
=
PolicyFactory
::
create_from_config
(
policy_config
);
// Determine which tokenizer path to use
// Priority: tokenizer_path > model_path
let
tokenizer_path
=
ctx
.router_config
.tokenizer_path
.clone
()
.or_else
(||
ctx
.router_config.model_path
.clone
())
.ok_or_else
(||
{
"gRPC router requires either --tokenizer-path or --model-path to be specified"
.to_string
()
})
?
;
// Create gRPC router
let
router
=
GrpcRouter
::
new
(
worker_urls
.to_vec
(),
policy
,
ctx
.router_config.worker_startup_timeout_secs
,
ctx
.router_config.worker_startup_check_interval_secs
,
ctx
.router_config.dp_aware
,
ctx
.router_config.api_key
.clone
(),
ctx
.router_config
.effective_retry_config
(),
ctx
.router_config
.effective_circuit_breaker_config
(),
ctx
.router_config.health_check
.clone
(),
tokenizer_path
,
)
.await
?
;
Ok
(
Box
::
new
(
router
))
}
}
/// Create a gRPC PD router
(placeholder for now)
/// Create a gRPC PD router
with tokenizer and worker configuration
pub
async
fn
create_grpc_pd_router
(
pub
async
fn
create_grpc_pd_router
(
_
prefill_urls
:
&
[(
String
,
Option
<
u16
>
)],
prefill_urls
:
&
[(
String
,
Option
<
u16
>
)],
_
decode_urls
:
&
[
String
],
decode_urls
:
&
[
String
],
_
prefill_policy_config
:
Option
<&
PolicyConfig
>
,
prefill_policy_config
:
Option
<&
PolicyConfig
>
,
_
decode_policy_config
:
Option
<&
PolicyConfig
>
,
decode_policy_config
:
Option
<&
PolicyConfig
>
,
_
main_policy_config
:
&
PolicyConfig
,
main_policy_config
:
&
PolicyConfig
,
_
ctx
:
&
Arc
<
AppContext
>
,
ctx
:
&
Arc
<
AppContext
>
,
)
->
Result
<
Box
<
dyn
RouterTrait
>
,
String
>
{
)
->
Result
<
Box
<
dyn
RouterTrait
>
,
String
>
{
// For now, return an error as gRPC PD router is not yet implemented
use
super
::
grpc
::
pd_router
::
GrpcPDRouter
;
Err
(
"gRPC PD router is not yet implemented"
.to_string
())
// Create policies - use specific policies if provided, otherwise fall back to main policy
let
prefill_policy
=
PolicyFactory
::
create_from_config
(
prefill_policy_config
.unwrap_or
(
main_policy_config
));
let
decode_policy
=
PolicyFactory
::
create_from_config
(
decode_policy_config
.unwrap_or
(
main_policy_config
));
// Determine which tokenizer path to use
// Priority: tokenizer_path > model_path
let
tokenizer_path
=
ctx
.router_config
.tokenizer_path
.clone
()
.or_else
(||
ctx
.router_config.model_path
.clone
())
.ok_or_else
(||
{
"gRPC PD router requires either --tokenizer-path or --model-path to be specified"
.to_string
()
})
?
;
// Create gRPC PD router
let
router
=
GrpcPDRouter
::
new
(
prefill_urls
.to_vec
(),
decode_urls
.to_vec
(),
prefill_policy
,
decode_policy
,
ctx
.router_config.worker_startup_timeout_secs
,
ctx
.router_config.worker_startup_check_interval_secs
,
ctx
.router_config.dp_aware
,
ctx
.router_config.api_key
.clone
(),
ctx
.router_config
.effective_retry_config
(),
ctx
.router_config
.effective_circuit_breaker_config
(),
ctx
.router_config.health_check
.clone
(),
tokenizer_path
,
)
.await
?
;
Ok
(
Box
::
new
(
router
))
}
}
/// Create an IGW router (placeholder for future implementation)
/// Create an IGW router (placeholder for future implementation)
...
...
sgl-router/src/routers/grpc/pd_router.rs
View file @
9a0cac1b
// PD (Prefill-Decode) gRPC Router Implementation
// PD (Prefill-Decode) gRPC Router Implementation
// TODO: Implement gRPC-based PD router for disaggregated prefill-decode systems
use
crate
::
config
::
types
::{
CircuitBreakerConfig
as
ConfigCircuitBreakerConfig
,
HealthCheckConfig
as
ConfigHealthCheckConfig
,
RetryConfig
,
};
use
crate
::
core
::{
BasicWorker
,
CircuitBreakerConfig
,
HealthChecker
,
HealthConfig
,
Worker
,
WorkerType
,
};
use
crate
::
grpc
::
SglangSchedulerClient
;
use
crate
::
metrics
::
RouterMetrics
;
use
crate
::
policies
::
LoadBalancingPolicy
;
use
crate
::
reasoning_parser
::
ParserFactory
;
use
crate
::
routers
::{
RouterTrait
,
WorkerManagement
};
use
crate
::
routers
::{
RouterTrait
,
WorkerManagement
};
use
crate
::
tokenizer
::{
factory
,
traits
::
Tokenizer
};
use
crate
::
tool_parser
::
ParserRegistry
;
use
async_trait
::
async_trait
;
use
async_trait
::
async_trait
;
use
axum
::{
use
axum
::{
body
::
Body
,
body
::
Body
,
...
@@ -9,15 +21,222 @@ use axum::{
...
@@ -9,15 +21,222 @@ use axum::{
http
::{
HeaderMap
,
StatusCode
},
http
::{
HeaderMap
,
StatusCode
},
response
::{
IntoResponse
,
Response
},
response
::{
IntoResponse
,
Response
},
};
};
use
std
::
collections
::
HashMap
;
use
std
::
sync
::{
Arc
,
RwLock
};
use
std
::
time
::
Duration
;
use
tracing
::{
info
,
warn
};
/// Placeholder for gRPC PD router
/// gRPC PD (Prefill-Decode) router implementation for SGLang
#[derive(Debug)]
#[allow(dead_code)]
// Fields will be used once implementation is complete
pub
struct
GrpcPDRouter
;
pub
struct
GrpcPDRouter
{
/// Prefill worker connections
prefill_workers
:
Arc
<
RwLock
<
Vec
<
Box
<
dyn
Worker
>>>>
,
/// Decode worker connections
decode_workers
:
Arc
<
RwLock
<
Vec
<
Box
<
dyn
Worker
>>>>
,
/// gRPC clients for prefill workers
prefill_grpc_clients
:
Arc
<
RwLock
<
HashMap
<
String
,
SglangSchedulerClient
>>>
,
/// gRPC clients for decode workers
decode_grpc_clients
:
Arc
<
RwLock
<
HashMap
<
String
,
SglangSchedulerClient
>>>
,
/// Load balancing policy for prefill
prefill_policy
:
Arc
<
dyn
LoadBalancingPolicy
>
,
/// Load balancing policy for decode
decode_policy
:
Arc
<
dyn
LoadBalancingPolicy
>
,
/// Tokenizer for handling text encoding/decoding
tokenizer
:
Arc
<
dyn
Tokenizer
>
,
/// Reasoning parser factory for structured reasoning outputs
reasoning_parser_factory
:
ParserFactory
,
/// Tool parser registry for function/tool calls
tool_parser_registry
:
&
'static
ParserRegistry
,
/// Worker health checkers
_
prefill_health_checker
:
Option
<
HealthChecker
>
,
_
decode_health_checker
:
Option
<
HealthChecker
>
,
/// Configuration
timeout_secs
:
u64
,
interval_secs
:
u64
,
dp_aware
:
bool
,
api_key
:
Option
<
String
>
,
retry_config
:
RetryConfig
,
circuit_breaker_config
:
CircuitBreakerConfig
,
}
impl
GrpcPDRouter
{
impl
GrpcPDRouter
{
pub
async
fn
new
()
->
Result
<
Self
,
String
>
{
/// Create a new gRPC PD router
// TODO: Implement gRPC PD router initialization
#[allow(clippy::too_many_arguments)]
Err
(
"gRPC PD router not yet implemented"
.to_string
())
pub
async
fn
new
(
prefill_urls
:
Vec
<
(
String
,
Option
<
u16
>
)
>
,
decode_urls
:
Vec
<
String
>
,
prefill_policy
:
Arc
<
dyn
LoadBalancingPolicy
>
,
decode_policy
:
Arc
<
dyn
LoadBalancingPolicy
>
,
timeout_secs
:
u64
,
interval_secs
:
u64
,
dp_aware
:
bool
,
api_key
:
Option
<
String
>
,
retry_config
:
RetryConfig
,
circuit_breaker_config
:
ConfigCircuitBreakerConfig
,
health_check_config
:
ConfigHealthCheckConfig
,
tokenizer_path_or_model
:
String
,
)
->
Result
<
Self
,
String
>
{
// Update metrics
RouterMetrics
::
set_active_workers
(
prefill_urls
.len
()
+
decode_urls
.len
());
// Initialize tokenizer
let
tokenizer
=
factory
::
create_tokenizer
(
&
tokenizer_path_or_model
)
.map_err
(|
e
|
format!
(
"Failed to create tokenizer: {}"
,
e
))
?
;
// Initialize reasoning parser factory
let
reasoning_parser_factory
=
ParserFactory
::
new
();
// Get tool parser registry
let
tool_parser_registry
=
ParserRegistry
::
new
();
// Convert config CircuitBreakerConfig to core CircuitBreakerConfig
let
core_cb_config
=
CircuitBreakerConfig
{
failure_threshold
:
circuit_breaker_config
.failure_threshold
,
success_threshold
:
circuit_breaker_config
.success_threshold
,
timeout_duration
:
Duration
::
from_secs
(
circuit_breaker_config
.timeout_duration_secs
),
window_duration
:
Duration
::
from_secs
(
circuit_breaker_config
.window_duration_secs
),
};
// Create gRPC clients for prefill workers
let
mut
prefill_grpc_clients
=
HashMap
::
new
();
for
(
url
,
_
bootstrap_port
)
in
&
prefill_urls
{
match
SglangSchedulerClient
::
connect
(
url
)
.await
{
Ok
(
client
)
=>
{
prefill_grpc_clients
.insert
(
url
.clone
(),
client
);
info!
(
"Connected to gRPC prefill worker at {}"
,
url
);
}
Err
(
e
)
=>
{
warn!
(
"Failed to connect to gRPC prefill worker at {}: {}"
,
url
,
e
);
// Continue with other workers
}
}
}
// Create gRPC clients for decode workers
let
mut
decode_grpc_clients
=
HashMap
::
new
();
for
url
in
&
decode_urls
{
match
SglangSchedulerClient
::
connect
(
url
)
.await
{
Ok
(
client
)
=>
{
decode_grpc_clients
.insert
(
url
.clone
(),
client
);
info!
(
"Connected to gRPC decode worker at {}"
,
url
);
}
Err
(
e
)
=>
{
warn!
(
"Failed to connect to gRPC decode worker at {}: {}"
,
url
,
e
);
// Continue with other workers
}
}
}
if
prefill_grpc_clients
.is_empty
()
&&
decode_grpc_clients
.is_empty
()
{
return
Err
(
"Failed to connect to any gRPC workers"
.to_string
());
}
// Create Prefill Worker trait objects with gRPC connection mode
let
prefill_workers
:
Vec
<
Box
<
dyn
Worker
>>
=
prefill_urls
.iter
()
.map
(|(
url
,
bootstrap_port
)|
{
let
worker
=
BasicWorker
::
with_connection_mode
(
url
.clone
(),
WorkerType
::
Prefill
{
bootstrap_port
:
*
bootstrap_port
,
},
crate
::
core
::
ConnectionMode
::
Grpc
{
port
:
*
bootstrap_port
,
},
)
.with_circuit_breaker_config
(
core_cb_config
.clone
())
.with_health_config
(
HealthConfig
{
timeout_secs
:
health_check_config
.timeout_secs
,
check_interval_secs
:
health_check_config
.check_interval_secs
,
endpoint
:
health_check_config
.endpoint
.clone
(),
failure_threshold
:
health_check_config
.failure_threshold
,
success_threshold
:
health_check_config
.success_threshold
,
});
Box
::
new
(
worker
)
as
Box
<
dyn
Worker
>
})
.collect
();
// Create Decode Worker trait objects with gRPC connection mode
let
decode_workers
:
Vec
<
Box
<
dyn
Worker
>>
=
decode_urls
.iter
()
.map
(|
url
|
{
let
worker
=
BasicWorker
::
with_connection_mode
(
url
.clone
(),
WorkerType
::
Decode
,
crate
::
core
::
ConnectionMode
::
Grpc
{
port
:
None
},
)
.with_circuit_breaker_config
(
core_cb_config
.clone
())
.with_health_config
(
HealthConfig
{
timeout_secs
:
health_check_config
.timeout_secs
,
check_interval_secs
:
health_check_config
.check_interval_secs
,
endpoint
:
health_check_config
.endpoint
.clone
(),
failure_threshold
:
health_check_config
.failure_threshold
,
success_threshold
:
health_check_config
.success_threshold
,
});
Box
::
new
(
worker
)
as
Box
<
dyn
Worker
>
})
.collect
();
// Initialize policies with workers if needed
if
let
Some
(
cache_aware
)
=
prefill_policy
.as_any
()
.downcast_ref
::
<
crate
::
policies
::
CacheAwarePolicy
>
()
{
cache_aware
.init_workers
(
&
prefill_workers
);
}
if
let
Some
(
cache_aware
)
=
decode_policy
.as_any
()
.downcast_ref
::
<
crate
::
policies
::
CacheAwarePolicy
>
()
{
cache_aware
.init_workers
(
&
decode_workers
);
}
let
prefill_workers
=
Arc
::
new
(
RwLock
::
new
(
prefill_workers
));
let
decode_workers
=
Arc
::
new
(
RwLock
::
new
(
decode_workers
));
let
prefill_health_checker
=
crate
::
core
::
start_health_checker
(
Arc
::
clone
(
&
prefill_workers
),
interval_secs
);
let
decode_health_checker
=
crate
::
core
::
start_health_checker
(
Arc
::
clone
(
&
decode_workers
),
interval_secs
);
Ok
(
GrpcPDRouter
{
prefill_workers
,
decode_workers
,
prefill_grpc_clients
:
Arc
::
new
(
RwLock
::
new
(
prefill_grpc_clients
)),
decode_grpc_clients
:
Arc
::
new
(
RwLock
::
new
(
decode_grpc_clients
)),
prefill_policy
,
decode_policy
,
tokenizer
,
reasoning_parser_factory
,
tool_parser_registry
,
_
prefill_health_checker
:
Some
(
prefill_health_checker
),
_
decode_health_checker
:
Some
(
decode_health_checker
),
timeout_secs
,
interval_secs
,
dp_aware
,
api_key
,
retry_config
,
circuit_breaker_config
:
core_cb_config
,
})
}
}
/// Compact `Debug` output: worker counts and scalar settings only.
///
/// Clients, policies, the tokenizer and the parsers are not printed.
impl std::fmt::Debug for GrpcPDRouter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A poisoned lock only means a writer panicked while holding it; the
        // vector length is still readable, so recover the guard instead of
        // panicking inside a formatting call.
        let prefill_workers_count = self
            .prefill_workers
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .len();
        let decode_workers_count = self
            .decode_workers
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .len();

        f.debug_struct("GrpcPDRouter")
            .field("prefill_workers_count", &prefill_workers_count)
            .field("decode_workers_count", &decode_workers_count)
            .field("timeout_secs", &self.timeout_secs)
            .field("interval_secs", &self.interval_secs)
            .field("dp_aware", &self.dp_aware)
            .finish()
    }
}
}
}
...
...
sgl-router/src/routers/grpc/router.rs
View file @
9a0cac1b
// gRPC Router Implementation
// gRPC Router Implementation
// TODO: Implement gRPC-based router
use
crate
::
config
::
types
::{
CircuitBreakerConfig
as
ConfigCircuitBreakerConfig
,
HealthCheckConfig
as
ConfigHealthCheckConfig
,
RetryConfig
,
};
use
crate
::
core
::{
BasicWorker
,
CircuitBreakerConfig
,
HealthChecker
,
HealthConfig
,
Worker
,
WorkerType
,
};
use
crate
::
grpc
::
SglangSchedulerClient
;
use
crate
::
metrics
::
RouterMetrics
;
use
crate
::
policies
::
LoadBalancingPolicy
;
use
crate
::
reasoning_parser
::
ParserFactory
;
use
crate
::
routers
::{
RouterTrait
,
WorkerManagement
};
use
crate
::
routers
::{
RouterTrait
,
WorkerManagement
};
use
crate
::
tokenizer
::{
factory
,
traits
::
Tokenizer
};
use
crate
::
tool_parser
::
ParserRegistry
;
use
async_trait
::
async_trait
;
use
async_trait
::
async_trait
;
use
axum
::{
use
axum
::{
body
::
Body
,
body
::
Body
,
...
@@ -9,15 +21,150 @@ use axum::{
...
@@ -9,15 +21,150 @@ use axum::{
http
::{
HeaderMap
,
StatusCode
},
http
::{
HeaderMap
,
StatusCode
},
response
::{
IntoResponse
,
Response
},
response
::{
IntoResponse
,
Response
},
};
};
use
std
::
collections
::
HashMap
;
use
std
::
sync
::{
Arc
,
RwLock
};
use
std
::
time
::
Duration
;
use
tracing
::{
info
,
warn
};
/// Placeholder for gRPC router
/// gRPC router implementation for SGLang
#[derive(Debug)]
#[allow(dead_code)]
// Fields will be used once implementation is complete
pub
struct
GrpcRouter
;
/// State held by the regular (non-PD) gRPC router.
///
/// All fields are populated by `GrpcRouter::new`.
pub struct GrpcRouter {
    /// Worker objects, one per configured worker URL, behind a shared lock.
    /// A worker is created for every URL, even if its gRPC connection failed.
    workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
    /// gRPC clients keyed by worker URL; contains only the workers whose
    /// connection attempt succeeded at construction time.
    grpc_clients: Arc<RwLock<HashMap<String, SglangSchedulerClient>>>,
    /// Load balancing policy
    policy: Arc<dyn LoadBalancingPolicy>,
    /// Tokenizer for handling text encoding/decoding
    tokenizer: Arc<dyn Tokenizer>,
    /// Reasoning parser factory for structured reasoning outputs
    reasoning_parser_factory: ParserFactory,
    /// Tool parser registry for function/tool calls
    tool_parser_registry: &'static ParserRegistry,
    /// Handle returned by `start_health_checker` for `workers`.
    /// NOTE(review): presumably stored so the checker lives as long as the
    /// router — confirm against `HealthChecker`'s drop behavior.
    _health_checker: Option<HealthChecker>,
    /// Configuration values, stored as passed to `new`.
    timeout_secs: u64,
    interval_secs: u64,
    dp_aware: bool,
    api_key: Option<String>,
    retry_config: RetryConfig,
    /// Core-level circuit breaker settings, converted from the config-level
    /// type (second-based fields become `Duration`s).
    circuit_breaker_config: CircuitBreakerConfig,
}
impl
GrpcRouter
{
impl
GrpcRouter
{
pub
async
fn
new
()
->
Result
<
Self
,
String
>
{
/// Create a new gRPC router
// TODO: Implement gRPC router initialization
#[allow(clippy::too_many_arguments)]
Err
(
"gRPC router not yet implemented"
.to_string
())
/// Build a gRPC router over `worker_urls`.
///
/// The constructor publishes the configured worker count to `RouterMetrics`,
/// creates the tokenizer and parsers, attempts a gRPC connection to each URL
/// (logging and skipping failures), creates a `Worker` for every configured
/// URL, seeds a cache-aware policy if one was supplied, and starts the
/// health checker.
///
/// # Errors
///
/// Returns `Err` if the tokenizer cannot be created or if no worker could be
/// reached over gRPC.
#[allow(clippy::too_many_arguments)]
pub async fn new(
    worker_urls: Vec<String>,
    policy: Arc<dyn LoadBalancingPolicy>,
    timeout_secs: u64,
    interval_secs: u64,
    dp_aware: bool,
    api_key: Option<String>,
    retry_config: RetryConfig,
    circuit_breaker_config: ConfigCircuitBreakerConfig,
    health_check_config: ConfigHealthCheckConfig,
    tokenizer_path_or_model: String,
) -> Result<Self, String> {
    // Metrics reflect the configured worker count.
    RouterMetrics::set_active_workers(worker_urls.len());

    // Tokenizer: surface the factory error as a plain string.
    let tokenizer = match factory::create_tokenizer(&tokenizer_path_or_model) {
        Ok(created) => created,
        Err(e) => return Err(format!("Failed to create tokenizer: {}", e)),
    };

    // Parsers for reasoning output and tool calls.
    let reasoning_parser_factory = ParserFactory::new();
    let tool_parser_registry = ParserRegistry::new();

    // Config-level breaker settings carry seconds; the core type wants `Duration`s.
    let timeout_duration = Duration::from_secs(circuit_breaker_config.timeout_duration_secs);
    let window_duration = Duration::from_secs(circuit_breaker_config.window_duration_secs);
    let core_cb_config = CircuitBreakerConfig {
        failure_threshold: circuit_breaker_config.failure_threshold,
        success_threshold: circuit_breaker_config.success_threshold,
        timeout_duration,
        window_duration,
    };

    // One connection attempt per worker; unreachable workers are skipped.
    let mut grpc_clients = HashMap::new();
    for addr in worker_urls.iter() {
        match SglangSchedulerClient::connect(addr).await {
            Err(e) => {
                warn!("Failed to connect to gRPC worker at {}: {}", addr, e);
            }
            Ok(client) => {
                grpc_clients.insert(addr.clone(), client);
                info!("Connected to gRPC worker at {}", addr);
            }
        }
    }

    if grpc_clients.is_empty() {
        return Err(String::from("Failed to connect to any gRPC workers"));
    }

    // A worker object is created for every configured URL.
    let mut workers: Vec<Box<dyn Worker>> = Vec::with_capacity(worker_urls.len());
    for addr in worker_urls.iter() {
        let health = HealthConfig {
            timeout_secs: health_check_config.timeout_secs,
            check_interval_secs: health_check_config.check_interval_secs,
            endpoint: health_check_config.endpoint.clone(),
            failure_threshold: health_check_config.failure_threshold,
            success_threshold: health_check_config.success_threshold,
        };
        let worker = BasicWorker::with_connection_mode(
            addr.clone(),
            WorkerType::Regular,
            crate::core::ConnectionMode::Grpc { port: None },
        )
        .with_circuit_breaker_config(core_cb_config.clone())
        .with_health_config(health);
        workers.push(Box::new(worker));
    }

    // Cache-aware policies need to know the workers up front.
    let cache_aware_policy = policy
        .as_any()
        .downcast_ref::<crate::policies::CacheAwarePolicy>();
    if let Some(cache_aware) = cache_aware_policy {
        cache_aware.init_workers(&workers);
    }

    let workers = Arc::new(RwLock::new(workers));
    let health_checker =
        crate::core::start_health_checker(Arc::clone(&workers), interval_secs);

    Ok(GrpcRouter {
        workers,
        grpc_clients: Arc::new(RwLock::new(grpc_clients)),
        policy,
        tokenizer,
        reasoning_parser_factory,
        tool_parser_registry,
        _health_checker: Some(health_checker),
        timeout_secs,
        interval_secs,
        dp_aware,
        api_key,
        retry_config,
        circuit_breaker_config: core_cb_config,
    })
}
}
/// Compact `Debug` output: worker count and scalar settings only.
///
/// Clients, the policy, the tokenizer and the parsers are not printed.
impl std::fmt::Debug for GrpcRouter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A poisoned lock only means a writer panicked while holding it; the
        // vector length is still readable, so recover the guard instead of
        // panicking inside a formatting call.
        let workers_count = self
            .workers
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .len();

        f.debug_struct("GrpcRouter")
            .field("workers_count", &workers_count)
            .field("timeout_secs", &self.timeout_secs)
            .field("interval_secs", &self.interval_secs)
            .field("dp_aware", &self.dp_aware)
            .finish()
    }
}
}
}
...
...
sgl-router/tests/api_endpoints_test.rs
View file @
9a0cac1b
...
@@ -9,7 +9,7 @@ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType
...
@@ -9,7 +9,7 @@ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType
use
reqwest
::
Client
;
use
reqwest
::
Client
;
use
serde_json
::
json
;
use
serde_json
::
json
;
use
sglang_router_rs
::
config
::{
use
sglang_router_rs
::
config
::{
CircuitBreakerConfig
,
PolicyConfig
,
RetryConfig
,
RouterConfig
,
RoutingMode
,
CircuitBreakerConfig
,
ConnectionMode
,
PolicyConfig
,
RetryConfig
,
RouterConfig
,
RoutingMode
,
};
};
use
sglang_router_rs
::
routers
::{
RouterFactory
,
RouterTrait
};
use
sglang_router_rs
::
routers
::{
RouterFactory
,
RouterTrait
};
use
std
::
sync
::
Arc
;
use
std
::
sync
::
Arc
;
...
@@ -55,6 +55,9 @@ impl TestContext {
...
@@ -55,6 +55,9 @@ impl TestContext {
disable_circuit_breaker
:
false
,
disable_circuit_breaker
:
false
,
health_check
:
sglang_router_rs
::
config
::
HealthCheckConfig
::
default
(),
health_check
:
sglang_router_rs
::
config
::
HealthCheckConfig
::
default
(),
enable_igw
:
false
,
enable_igw
:
false
,
connection_mode
:
ConnectionMode
::
Http
,
model_path
:
None
,
tokenizer_path
:
None
,
};
};
Self
::
new_with_config
(
config
,
worker_configs
)
.await
Self
::
new_with_config
(
config
,
worker_configs
)
.await
...
@@ -1101,6 +1104,9 @@ mod error_tests {
...
@@ -1101,6 +1104,9 @@ mod error_tests {
disable_circuit_breaker
:
false
,
disable_circuit_breaker
:
false
,
health_check
:
sglang_router_rs
::
config
::
HealthCheckConfig
::
default
(),
health_check
:
sglang_router_rs
::
config
::
HealthCheckConfig
::
default
(),
enable_igw
:
false
,
enable_igw
:
false
,
connection_mode
:
ConnectionMode
::
Http
,
model_path
:
None
,
tokenizer_path
:
None
,
};
};
let
ctx
=
TestContext
::
new_with_config
(
let
ctx
=
TestContext
::
new_with_config
(
...
@@ -1456,6 +1462,9 @@ mod pd_mode_tests {
...
@@ -1456,6 +1462,9 @@ mod pd_mode_tests {
disable_circuit_breaker
:
false
,
disable_circuit_breaker
:
false
,
health_check
:
sglang_router_rs
::
config
::
HealthCheckConfig
::
default
(),
health_check
:
sglang_router_rs
::
config
::
HealthCheckConfig
::
default
(),
enable_igw
:
false
,
enable_igw
:
false
,
connection_mode
:
ConnectionMode
::
Http
,
model_path
:
None
,
tokenizer_path
:
None
,
};
};
// Create app context
// Create app context
...
@@ -1615,6 +1624,9 @@ mod request_id_tests {
...
@@ -1615,6 +1624,9 @@ mod request_id_tests {
disable_circuit_breaker
:
false
,
disable_circuit_breaker
:
false
,
health_check
:
sglang_router_rs
::
config
::
HealthCheckConfig
::
default
(),
health_check
:
sglang_router_rs
::
config
::
HealthCheckConfig
::
default
(),
enable_igw
:
false
,
enable_igw
:
false
,
connection_mode
:
ConnectionMode
::
Http
,
model_path
:
None
,
tokenizer_path
:
None
,
};
};
let
ctx
=
TestContext
::
new_with_config
(
let
ctx
=
TestContext
::
new_with_config
(
...
...
sgl-router/tests/request_formats_test.rs
View file @
9a0cac1b
...
@@ -4,7 +4,7 @@ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType
...
@@ -4,7 +4,7 @@ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType
use
reqwest
::
Client
;
use
reqwest
::
Client
;
use
serde_json
::
json
;
use
serde_json
::
json
;
use
sglang_router_rs
::
config
::{
use
sglang_router_rs
::
config
::{
CircuitBreakerConfig
,
PolicyConfig
,
RetryConfig
,
RouterConfig
,
RoutingMode
,
CircuitBreakerConfig
,
ConnectionMode
,
PolicyConfig
,
RetryConfig
,
RouterConfig
,
RoutingMode
,
};
};
use
sglang_router_rs
::
routers
::{
RouterFactory
,
RouterTrait
};
use
sglang_router_rs
::
routers
::{
RouterFactory
,
RouterTrait
};
use
std
::
sync
::
Arc
;
use
std
::
sync
::
Arc
;
...
@@ -46,6 +46,9 @@ impl TestContext {
...
@@ -46,6 +46,9 @@ impl TestContext {
disable_circuit_breaker
:
false
,
disable_circuit_breaker
:
false
,
health_check
:
sglang_router_rs
::
config
::
HealthCheckConfig
::
default
(),
health_check
:
sglang_router_rs
::
config
::
HealthCheckConfig
::
default
(),
enable_igw
:
false
,
enable_igw
:
false
,
connection_mode
:
ConnectionMode
::
Http
,
model_path
:
None
,
tokenizer_path
:
None
,
};
};
let
mut
workers
=
Vec
::
new
();
let
mut
workers
=
Vec
::
new
();
...
...
sgl-router/tests/streaming_tests.rs
View file @
9a0cac1b
...
@@ -5,7 +5,7 @@ use futures_util::StreamExt;
...
@@ -5,7 +5,7 @@ use futures_util::StreamExt;
use
reqwest
::
Client
;
use
reqwest
::
Client
;
use
serde_json
::
json
;
use
serde_json
::
json
;
use
sglang_router_rs
::
config
::{
use
sglang_router_rs
::
config
::{
CircuitBreakerConfig
,
PolicyConfig
,
RetryConfig
,
RouterConfig
,
RoutingMode
,
CircuitBreakerConfig
,
ConnectionMode
,
PolicyConfig
,
RetryConfig
,
RouterConfig
,
RoutingMode
,
};
};
use
sglang_router_rs
::
routers
::{
RouterFactory
,
RouterTrait
};
use
sglang_router_rs
::
routers
::{
RouterFactory
,
RouterTrait
};
use
std
::
sync
::
Arc
;
use
std
::
sync
::
Arc
;
...
@@ -47,6 +47,9 @@ impl TestContext {
...
@@ -47,6 +47,9 @@ impl TestContext {
disable_circuit_breaker
:
false
,
disable_circuit_breaker
:
false
,
health_check
:
sglang_router_rs
::
config
::
HealthCheckConfig
::
default
(),
health_check
:
sglang_router_rs
::
config
::
HealthCheckConfig
::
default
(),
enable_igw
:
false
,
enable_igw
:
false
,
connection_mode
:
ConnectionMode
::
Http
,
model_path
:
None
,
tokenizer_path
:
None
,
};
};
let
mut
workers
=
Vec
::
new
();
let
mut
workers
=
Vec
::
new
();
...
...
sgl-router/tests/test_pd_routing.rs
View file @
9a0cac1b
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
mod
test_pd_routing
{
mod
test_pd_routing
{
use
serde_json
::
json
;
use
serde_json
::
json
;
use
sglang_router_rs
::
config
::{
use
sglang_router_rs
::
config
::{
CircuitBreakerConfig
,
PolicyConfig
,
RetryConfig
,
RouterConfig
,
RoutingMode
,
CircuitBreakerConfig
,
ConnectionMode
,
PolicyConfig
,
RetryConfig
,
RouterConfig
,
RoutingMode
,
};
};
use
sglang_router_rs
::
core
::{
WorkerFactory
,
WorkerType
};
use
sglang_router_rs
::
core
::{
WorkerFactory
,
WorkerType
};
use
sglang_router_rs
::
routers
::
http
::
pd_types
::
get_hostname
;
use
sglang_router_rs
::
routers
::
http
::
pd_types
::
get_hostname
;
...
@@ -188,6 +188,9 @@ mod test_pd_routing {
...
@@ -188,6 +188,9 @@ mod test_pd_routing {
health_check
:
sglang_router_rs
::
config
::
HealthCheckConfig
::
default
(),
health_check
:
sglang_router_rs
::
config
::
HealthCheckConfig
::
default
(),
enable_igw
:
false
,
enable_igw
:
false
,
rate_limit_tokens_per_second
:
None
,
rate_limit_tokens_per_second
:
None
,
connection_mode
:
ConnectionMode
::
Http
,
model_path
:
None
,
tokenizer_path
:
None
,
};
};
// Router creation will fail due to health checks, but config should be valid
// Router creation will fail due to health checks, but config should be valid
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment