change / sglang / Commits / a7fe6e10

Commit a7fe6e10 (Unverified), authored Sep 26, 2025 by Simo Lin; committed via GitHub, Sep 26, 2025

[router] remove old/outdated/useless comments (#10967)

Parent: be059b83

Showing 5 changed files with 28 additions and 306 deletions (+28, -306)
sgl-router/src/lib.rs                  +5   -35
sgl-router/src/logging.rs              +1   -34
sgl-router/src/main.rs                 +10  -115
sgl-router/src/metrics.rs              +1   -61
sgl-router/src/service_discovery.rs    +11  -61
sgl-router/src/lib.rs (view file @ a7fe6e10)
```diff
@@ -67,58 +67,47 @@ struct Router {
     decode_policy: Option<PolicyType>,
     max_concurrent_requests: usize,
     cors_allowed_origins: Vec<String>,
-    // Retry configuration
     retry_max_retries: u32,
     retry_initial_backoff_ms: u64,
     retry_max_backoff_ms: u64,
     retry_backoff_multiplier: f32,
     retry_jitter_factor: f32,
     disable_retries: bool,
-    // Circuit breaker configuration
     cb_failure_threshold: u32,
     cb_success_threshold: u32,
     cb_timeout_duration_secs: u64,
     cb_window_duration_secs: u64,
     disable_circuit_breaker: bool,
-    // Health check configuration
     health_failure_threshold: u32,
     health_success_threshold: u32,
     health_check_timeout_secs: u64,
     health_check_interval_secs: u64,
     health_check_endpoint: String,
-    // IGW (Inference Gateway) configuration
     enable_igw: bool,
     queue_size: usize,
     queue_timeout_secs: u64,
     rate_limit_tokens_per_second: Option<usize>,
-    // Connection mode (determined from worker URLs)
     connection_mode: config::ConnectionMode,
-    // Model path for tokenizer
     model_path: Option<String>,
-    // Explicit tokenizer path
     tokenizer_path: Option<String>,
 }

 impl Router {
     /// Determine connection mode from worker URLs
     fn determine_connection_mode(worker_urls: &[String]) -> config::ConnectionMode {
-        // Only consider it gRPC if explicitly specified with grpc:// or grpcs:// scheme
         for url in worker_urls {
             if url.starts_with("grpc://") || url.starts_with("grpcs://") {
                 return config::ConnectionMode::Grpc;
             }
         }
-        // Default to HTTP for all other cases (including http://, https://, or no scheme)
         config::ConnectionMode::Http
     }

-    /// Convert PyO3 Router to RouterConfig
     pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
         use config::{
             DiscoveryConfig, MetricsConfig, PolicyConfig as ConfigPolicyConfig, RoutingMode,
         };
-        // Convert policy helper function
         let convert_policy = |policy: &PolicyType| -> ConfigPolicyConfig {
             match policy {
                 PolicyType::Random => ConfigPolicyConfig::Random,
```
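The connection-mode rule shown above (and repeated in main.rs below) is simple: only an explicit `grpc://` or `grpcs://` scheme selects gRPC, and a single such URL flips the whole worker set. A minimal standalone sketch of that rule, using a stand-in `ConnectionMode` enum rather than the router's `config` type:

```rust
#[derive(Debug, PartialEq)]
enum ConnectionMode {
    Http,
    Grpc,
}

// Mirrors the router's rule: only an explicit grpc:// or grpcs:// scheme
// selects gRPC; http://, https://, or a bare host:port defaults to HTTP.
fn determine_connection_mode(worker_urls: &[String]) -> ConnectionMode {
    for url in worker_urls {
        if url.starts_with("grpc://") || url.starts_with("grpcs://") {
            return ConnectionMode::Grpc;
        }
    }
    ConnectionMode::Http
}

fn main() {
    let urls = vec!["http://w1:8000".to_string(), "grpcs://w2:9000".to_string()];
    // One gRPC-scheme URL is enough to put the router in gRPC mode.
    assert_eq!(determine_connection_mode(&urls), ConnectionMode::Grpc);
    assert_eq!(
        determine_connection_mode(&["w3:8000".to_string()]),
        ConnectionMode::Http
    );
}
```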
```diff
@@ -131,14 +120,12 @@ impl Router {
                     max_tree_size: self.max_tree_size,
                 },
                 PolicyType::PowerOfTwo => ConfigPolicyConfig::PowerOfTwo {
-                    load_check_interval_secs: 5, // Default value
+                    load_check_interval_secs: 5,
                 },
             }
         };
-        // Determine routing mode
         let mode = if self.enable_igw {
-            // IGW mode - routing mode is not used in IGW, but we need to provide a placeholder
             RoutingMode::Regular {
                 worker_urls: vec![],
             }
@@ -155,10 +142,8 @@ impl Router {
             }
         };
-        // Convert main policy
         let policy = convert_policy(&self.policy);
-        // Service discovery configuration
         let discovery = if self.service_discovery {
             Some(DiscoveryConfig {
                 enabled: true,
@@ -174,7 +159,6 @@ impl Router {
             None
         };
-        // Metrics configuration
         let metrics = match (self.prometheus_port, self.prometheus_host.as_ref()) {
             (Some(port), Some(host)) => Some(MetricsConfig {
                 port,
@@ -251,7 +235,7 @@ impl Router {
     balance_rel_threshold = 1.5,
     eviction_interval_secs = 120,
     max_tree_size = 2usize.pow(26),
-    max_payload_size = 512 * 1024 * 1024, // 512MB default for large batches
+    max_payload_size = 512 * 1024 * 1024,
     dp_aware = false,
     api_key = None,
     log_dir = None,
@@ -265,40 +249,35 @@ impl Router {
     bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
     prometheus_port = None,
     prometheus_host = None,
-    request_timeout_secs = 1800, // Add configurable request timeout
-    request_id_headers = None, // Custom request ID headers
-    pd_disaggregation = false, // New flag for PD mode
+    request_timeout_secs = 1800,
+    request_id_headers = None,
+    pd_disaggregation = false,
     prefill_urls = None,
     decode_urls = None,
     prefill_policy = None,
     decode_policy = None,
     max_concurrent_requests = 256,
     cors_allowed_origins = vec![],
-    // Retry defaults
     retry_max_retries = 5,
     retry_initial_backoff_ms = 50,
     retry_max_backoff_ms = 30_000,
     retry_backoff_multiplier = 1.5,
     retry_jitter_factor = 0.2,
     disable_retries = false,
-    // Circuit breaker defaults
     cb_failure_threshold = 10,
     cb_success_threshold = 3,
     cb_timeout_duration_secs = 60,
     cb_window_duration_secs = 120,
     disable_circuit_breaker = false,
-    // Health check defaults
     health_failure_threshold = 3,
     health_success_threshold = 2,
     health_check_timeout_secs = 5,
     health_check_interval_secs = 60,
     health_check_endpoint = String::from("/health"),
-    // IGW defaults
     enable_igw = false,
     queue_size = 100,
     queue_timeout_secs = 60,
     rate_limit_tokens_per_second = None,
-    // Tokenizer defaults
     model_path = None,
     tokenizer_path = None,
 ))]
```
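The retry defaults above (initial backoff 50 ms, multiplier 1.5, cap 30_000 ms, jitter factor 0.2) describe a standard capped exponential backoff. A std-only sketch of how those numbers could translate into per-attempt delays; the exact formula and where jitter is applied are assumptions for illustration, not lifted from the router's retry implementation:

```rust
// Exponential growth from the initial backoff, capped at the maximum.
fn backoff_ms(attempt: u32, initial_ms: u64, multiplier: f32, max_ms: u64) -> u64 {
    let raw = initial_ms as f32 * multiplier.powi(attempt as i32);
    (raw as u64).min(max_ms)
}

fn main() {
    // The defaults from the signature block above.
    let (initial, multiplier, max, jitter) = (50u64, 1.5f32, 30_000u64, 0.2f32);
    for attempt in 0..5 {
        let base = backoff_ms(attempt, initial, multiplier, max);
        let spread = (base as f32 * jitter) as u64;
        // attempt 0: 50ms +/- 10ms, attempt 1: 75ms +/- 15ms, ...
        println!("attempt {attempt}: {base}ms +/- {spread}ms");
    }
}
```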
```diff
@@ -361,17 +340,14 @@ impl Router {
         model_path: Option<String>,
         tokenizer_path: Option<String>,
     ) -> PyResult<Self> {
-        // Determine connection mode from worker URLs
         let mut all_urls = worker_urls.clone();
-        // Add prefill URLs if in PD mode
         if let Some(ref prefill_urls) = prefill_urls {
             for (url, _) in prefill_urls {
                 all_urls.push(url.clone());
             }
         }
-        // Add decode URLs if in PD mode
         if let Some(ref decode_urls) = decode_urls {
             all_urls.extend(decode_urls.clone());
         }
@@ -440,12 +416,10 @@ impl Router {
     }

     fn start(&self) -> PyResult<()> {
-        // Convert to RouterConfig and validate
         let router_config = self.to_router_config().map_err(|e| {
             pyo3::exceptions::PyValueError::new_err(format!("Configuration error: {}", e))
         })?;
-        // Validate the configuration
         router_config.validate().map_err(|e| {
             pyo3::exceptions::PyValueError::new_err(format!(
                 "Configuration validation failed: {}",
@@ -453,7 +427,6 @@ impl Router {
             ))
         })?;
-        // Create service discovery config if enabled
         let service_discovery_config = if self.service_discovery {
             Some(service_discovery::ServiceDiscoveryConfig {
                 enabled: true,
@@ -470,7 +443,6 @@ impl Router {
             None
         };
-        // Create Prometheus config if enabled
         let prometheus_config = Some(PrometheusConfig {
             port: self.prometheus_port.unwrap_or(29000),
             host: self
@@ -479,11 +451,9 @@ impl Router {
                 .unwrap_or_else(|| "127.0.0.1".to_string()),
         });
-        // Use tokio runtime instead of actix-web System for better compatibility
         let runtime = tokio::runtime::Runtime::new()
             .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
-        // Block on the async startup function
         runtime.block_on(async move {
             server::startup(server::ServerConfig {
                 host: self.host.clone(),
```
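The `start()` method above follows a common pattern for exposing an async Rust server through PyO3: build a Tokio runtime on the calling thread and block it on the async startup future. A minimal sketch of that pattern, assuming the tokio crate and using a hypothetical `serve` in place of `server::startup`:

```rust
// Stand-in for server::startup; binds listeners, spawns tasks, etc.
async fn serve() -> Result<(), String> {
    Ok(())
}

fn main() -> Result<(), String> {
    // Build a runtime and block the current (e.g. Python) thread on startup.
    let runtime = tokio::runtime::Runtime::new().map_err(|e| e.to_string())?;
    runtime.block_on(async move { serve().await })
}
```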
sgl-router/src/logging.rs (view file @ a7fe6e10)
```diff
@@ -8,20 +8,13 @@ use tracing_subscriber::layer::SubscriberExt;
 use tracing_subscriber::util::SubscriberInitExt;
 use tracing_subscriber::{EnvFilter, Layer};

-/// Configuration for the logging system
 #[derive(Debug, Clone)]
 pub struct LoggingConfig {
-    /// Log level for the application (default: INFO)
     pub level: Level,
-    /// Whether to use json format for logs (default: false)
     pub json_format: bool,
-    /// Path to store log files. If None, logs will only go to stdout/stderr
     pub log_dir: Option<String>,
-    /// Whether to colorize logs when output is a terminal (default: true)
     pub colorize: bool,
-    /// Log file name to use if log_dir is specified (default: "sgl-router")
     pub log_file_name: String,
-    /// Custom log targets to filter (default: "sglang_router_rs")
     pub log_targets: Option<Vec<String>>,
 }
@@ -38,30 +31,14 @@ impl Default for LoggingConfig {
     }
 }

-/// Guard that keeps the file appender worker thread alive
-///
-/// This must be kept in scope for the duration of the program
-/// to ensure logs are properly written to files
 #[allow(dead_code)]
 pub struct LogGuard {
     _file_guard: Option<WorkerGuard>,
 }

-/// Initialize the logging system with the given configuration
-///
-/// # Arguments
-/// * `config` - Configuration for the logging system
-///
-/// # Returns
-/// A LogGuard that must be kept alive for the duration of the program
-///
-/// # Panics
-/// Will not panic, as initialization errors are handled gracefully
 pub fn init_logging(config: LoggingConfig) -> LogGuard {
-    // Forward logs to tracing - ignore errors to allow for multiple initialization
     let _ = LogTracer::init();
-    // Convert log level to filter string
     let level_filter = match config.level {
         Level::TRACE => "trace",
         Level::DEBUG => "debug",
@@ -70,9 +47,7 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard {
         Level::ERROR => "error",
     };
-    // Create env filter
     let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| {
-        // Format: <target>=<level>,<target2>=<level2>,...
         let filter_string = if let Some(targets) = &config.log_targets {
             targets
                 .iter()
```
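The hunk is truncated right after `.iter()`, but the removed comment documents the target format: `<target>=<level>,<target2>=<level2>,...`. A sketch of how such a filter string could be assembled from `log_targets`; the exact map/join logic beyond `.iter()` is an assumption, since the diff cuts off there:

```rust
// Builds "<target>=<level>,<target2>=<level2>,..." for EnvFilter::new.
fn build_filter(targets: &[String], level: &str) -> String {
    targets
        .iter()
        .map(|t| format!("{t}={level}"))
        .collect::<Vec<_>>()
        .join(",")
}

fn main() {
    let targets = vec!["sglang_router_rs".to_string(), "tower_http".to_string()];
    assert_eq!(
        build_filter(&targets, "debug"),
        "sglang_router_rs=debug,tower_http=debug"
    );
}
```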
```diff
@@ -92,13 +67,10 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard {
         EnvFilter::new(filter_string)
     });
-    // Setup stdout/stderr layer
     let mut layers = Vec::new();
-    // Standard timestamp format: YYYY-MM-DD HH:MM:SS
     let time_format = "%Y-%m-%d %H:%M:%S".to_string();
-    // Configure the console stdout layer
     let stdout_layer = tracing_subscriber::fmt::layer()
         .with_ansi(config.colorize)
         .with_file(true)
@@ -113,14 +85,12 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard {
     layers.push(stdout_layer);
-    // Create a file appender if log_dir is specified
     let mut file_guard = None;
     if let Some(log_dir) = &config.log_dir {
         let file_name = config.log_file_name.clone();
         let log_dir = PathBuf::from(log_dir);
-        // Create log directory if it doesn't exist
         if !log_dir.exists() {
             if let Err(e) = std::fs::create_dir_all(&log_dir) {
                 eprintln!("Failed to create log directory: {}", e);
@@ -134,7 +104,7 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard {
         file_guard = Some(guard);
         let file_layer = tracing_subscriber::fmt::layer()
-            .with_ansi(false) // Never use ANSI colors in log files
+            .with_ansi(false)
             .with_file(true)
             .with_line_number(true)
             .with_timer(ChronoUtc::new(time_format))
@@ -149,14 +119,11 @@ pub fn init_logging(config: LoggingConfig) -> LogGuard {
         layers.push(file_layer);
     }
-    // Initialize the subscriber with all layers
-    // Use try_init to handle errors gracefully in case another subscriber is already set
     let _ = tracing_subscriber::registry()
         .with(env_filter)
         .with(layers)
         .try_init();
-    // Return the guard to keep the file appender worker thread alive
     LogGuard {
         _file_guard: file_guard,
     }
```
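An in-crate usage sketch of the API shown above: build a `LoggingConfig`, call `init_logging`, and keep the returned `LogGuard` bound so the file-appender worker thread keeps flushing (dropping it early would stop file logging). Field values here are illustrative assumptions:

```rust
use tracing::Level;

fn main() {
    let config = LoggingConfig {
        level: Level::INFO,
        json_format: false,
        log_dir: Some("/var/log/sgl-router".to_string()),
        colorize: true,
        log_file_name: "sgl-router".to_string(),
        log_targets: None,
    };
    // Bind the guard for the life of the program.
    let _guard = init_logging(config);
    tracing::info!("router logging initialized");
}
```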
sgl-router/src/main.rs (view file @ a7fe6e10)
```diff
@@ -9,7 +9,6 @@ use sglang_router_rs::server::{self, ServerConfig};
 use sglang_router_rs::service_discovery::ServiceDiscoveryConfig;
 use std::collections::HashMap;

-// Helper function to parse prefill arguments from command line
 fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
     let args: Vec<String> = std::env::args().collect();
     let mut prefill_entries = Vec::new();
@@ -19,12 +18,11 @@ fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
         if args[i] == "--prefill" && i + 1 < args.len() {
             let url = args[i + 1].clone();
             let bootstrap_port = if i + 2 < args.len() && !args[i + 2].starts_with("--") {
-                // Check if next arg is a port number
                 if let Ok(port) = args[i + 2].parse::<u16>() {
-                    i += 1; // Skip the port argument
+                    i += 1;
                     Some(port)
                 } else if args[i + 2].to_lowercase() == "none" {
-                    i += 1; // Skip the "none" argument
+                    i += 1;
                     None
                 } else {
                     None
@@ -33,7 +31,7 @@ fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
                 None
             };
             prefill_entries.push((url, bootstrap_port));
-            i += 2; // Skip --prefill and URL
+            i += 2;
         } else {
             i += 1;
         }
```
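A standalone sketch of the `--prefill` scan above, run over a literal argv instead of `std::env::args()`: each `--prefill` takes a URL plus an optional bootstrap port, where the literal `none` means "no port".

```rust
fn parse_prefill(args: &[&str]) -> Vec<(String, Option<u16>)> {
    let mut out = Vec::new();
    let mut i = 0;
    while i < args.len() {
        if args[i] == "--prefill" && i + 1 < args.len() {
            let url = args[i + 1].to_string();
            let mut port = None;
            if i + 2 < args.len() && !args[i + 2].starts_with("--") {
                if let Ok(p) = args[i + 2].parse::<u16>() {
                    port = Some(p); // consume the port argument
                    i += 1;
                } else if args[i + 2].eq_ignore_ascii_case("none") {
                    i += 1; // consume the "none" argument
                }
            }
            out.push((url, port));
            i += 2; // skip --prefill and the URL
        } else {
            i += 1;
        }
    }
    out
}

fn main() {
    let argv = ["--prefill", "http://p1:8000", "9000", "--prefill", "http://p2:8000", "none"];
    assert_eq!(
        parse_prefill(&argv),
        vec![
            ("http://p1:8000".to_string(), Some(9000)),
            ("http://p2:8000".to_string(), None)
        ]
    );
}
```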
```diff
@@ -101,252 +99,186 @@ Examples:
 "#)]
 struct CliArgs {
-    /// Host address to bind the router server
     #[arg(long, default_value = "127.0.0.1")]
     host: String,
-    /// Port number to bind the router server
     #[arg(long, default_value_t = 30000)]
     port: u16,
-    /// List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)
     #[arg(long, num_args = 0..)]
     worker_urls: Vec<String>,
-    /// Load balancing policy to use
     #[arg(long, default_value = "cache_aware", value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
     policy: String,
-    /// Enable PD (Prefill-Decode) disaggregated mode
     #[arg(long, default_value_t = false)]
     pd_disaggregation: bool,
-    /// Decode server URL (can be specified multiple times)
     #[arg(long, action = ArgAction::Append)]
     decode: Vec<String>,
-    /// Specific policy for prefill nodes in PD mode
     #[arg(long, value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
     prefill_policy: Option<String>,
-    /// Specific policy for decode nodes in PD mode
     #[arg(long, value_parser = ["random", "round_robin", "cache_aware", "power_of_two"])]
     decode_policy: Option<String>,
-    /// Timeout in seconds for worker startup
     #[arg(long, default_value_t = 600)]
     worker_startup_timeout_secs: u64,
-    /// Interval in seconds between checks for worker startup
     #[arg(long, default_value_t = 30)]
     worker_startup_check_interval: u64,
-    /// Cache threshold (0.0-1.0) for cache-aware routing
     #[arg(long, default_value_t = 0.3)]
     cache_threshold: f32,
-    /// Absolute threshold for load balancing
     #[arg(long, default_value_t = 64)]
     balance_abs_threshold: usize,
-    /// Relative threshold for load balancing
     #[arg(long, default_value_t = 1.5)]
     balance_rel_threshold: f32,
-    /// Interval in seconds between cache eviction operations
     #[arg(long, default_value_t = 120)]
     eviction_interval: u64,
-    /// Maximum size of the approximation tree for cache-aware routing
-    #[arg(long, default_value_t = 67108864)] // 2^26
+    #[arg(long, default_value_t = 67108864)]
     max_tree_size: usize,
-    /// Maximum payload size in bytes
-    #[arg(long, default_value_t = 536870912)] // 512MB
+    #[arg(long, default_value_t = 536870912)]
     max_payload_size: usize,
-    /// Enable data parallelism aware schedule
     #[arg(long, default_value_t = false)]
     dp_aware: bool,
-    /// API key for worker authorization
     #[arg(long)]
     api_key: Option<String>,
-    /// Backend to route requests to (sglang, vllm, trtllm, openai, anthropic)
     #[arg(long, value_enum, default_value_t = Backend::Sglang, alias = "runtime")]
     backend: Backend,
-    /// Directory to store log files
     #[arg(long)]
     log_dir: Option<String>,
-    /// Set the logging level
     #[arg(long, default_value = "info", value_parser = ["debug", "info", "warn", "error"])]
     log_level: String,
-    /// Enable Kubernetes service discovery
     #[arg(long, default_value_t = false)]
     service_discovery: bool,
-    /// Label selector for Kubernetes service discovery (format: key1=value1 key2=value2)
     #[arg(long, num_args = 0..)]
     selector: Vec<String>,
-    /// Port to use for discovered worker pods
     #[arg(long, default_value_t = 80)]
     service_discovery_port: u16,
-    /// Kubernetes namespace to watch for pods
     #[arg(long)]
     service_discovery_namespace: Option<String>,
-    /// Label selector for prefill server pods in PD mode
     #[arg(long, num_args = 0..)]
     prefill_selector: Vec<String>,
-    /// Label selector for decode server pods in PD mode
     #[arg(long, num_args = 0..)]
     decode_selector: Vec<String>,
-    /// Port to expose Prometheus metrics
     #[arg(long, default_value_t = 29000)]
     prometheus_port: u16,
-    /// Host address to bind the Prometheus metrics server
     #[arg(long, default_value = "127.0.0.1")]
     prometheus_host: String,
-    /// Custom HTTP headers to check for request IDs
     #[arg(long, num_args = 0..)]
     request_id_headers: Vec<String>,
-    /// Request timeout in seconds
     #[arg(long, default_value_t = 1800)]
     request_timeout_secs: u64,
-    /// Maximum number of concurrent requests allowed
     #[arg(long, default_value_t = 256)]
     max_concurrent_requests: usize,
-    /// CORS allowed origins
     #[arg(long, num_args = 0..)]
     cors_allowed_origins: Vec<String>,
-    // Retry configuration
-    /// Maximum number of retries
     #[arg(long, default_value_t = 5)]
     retry_max_retries: u32,
-    /// Initial backoff in milliseconds for retries
     #[arg(long, default_value_t = 50)]
     retry_initial_backoff_ms: u64,
-    /// Maximum backoff in milliseconds for retries
     #[arg(long, default_value_t = 30000)]
     retry_max_backoff_ms: u64,
-    /// Backoff multiplier for exponential backoff
     #[arg(long, default_value_t = 1.5)]
     retry_backoff_multiplier: f32,
-    /// Jitter factor for retry backoff
     #[arg(long, default_value_t = 0.2)]
     retry_jitter_factor: f32,
-    /// Disable retries
     #[arg(long, default_value_t = false)]
     disable_retries: bool,
-    // Circuit breaker configuration
-    /// Number of failures before circuit breaker opens
     #[arg(long, default_value_t = 10)]
     cb_failure_threshold: u32,
-    /// Number of successes before circuit breaker closes
     #[arg(long, default_value_t = 3)]
     cb_success_threshold: u32,
-    /// Timeout duration in seconds for circuit breaker
     #[arg(long, default_value_t = 60)]
     cb_timeout_duration_secs: u64,
-    /// Window duration in seconds for circuit breaker
     #[arg(long, default_value_t = 120)]
     cb_window_duration_secs: u64,
-    /// Disable circuit breaker
     #[arg(long, default_value_t = false)]
     disable_circuit_breaker: bool,
-    // Health check configuration
-    /// Number of consecutive health check failures before marking worker unhealthy
     #[arg(long, default_value_t = 3)]
     health_failure_threshold: u32,
-    /// Number of consecutive health check successes before marking worker healthy
     #[arg(long, default_value_t = 2)]
     health_success_threshold: u32,
-    /// Timeout in seconds for health check requests
     #[arg(long, default_value_t = 5)]
     health_check_timeout_secs: u64,
-    /// Interval in seconds between runtime health checks
     #[arg(long, default_value_t = 60)]
     health_check_interval_secs: u64,
-    /// Health check endpoint path
     #[arg(long, default_value = "/health")]
     health_check_endpoint: String,
-    // IGW (Inference Gateway) configuration
-    /// Enable Inference Gateway mode
     #[arg(long, default_value_t = false)]
     enable_igw: bool,
-    // Tokenizer configuration
-    /// Model path for loading tokenizer (HuggingFace model ID or local path)
     #[arg(long)]
     model_path: Option<String>,
-    /// Explicit tokenizer path (overrides model_path tokenizer if provided)
     #[arg(long)]
     tokenizer_path: Option<String>,
-    /// History backend configuration (memory, none, or oracle)
     #[arg(long, default_value = "memory", value_parser = ["memory", "none", "oracle"])]
     history_backend: String,
-    /// Directory containing the Oracle ATP wallet/config files (optional)
     #[arg(long, env = "ATP_WALLET_PATH")]
     oracle_wallet_path: Option<String>,
-    /// Wallet TNS alias to use (e.g. `<db_name>_low`)
     #[arg(long, env = "ATP_TNS_ALIAS")]
     oracle_tns_alias: Option<String>,
-    /// Oracle connection descriptor / DSN (e.g. `tcps://host:port/service_name`)
     #[arg(long, env = "ATP_DSN")]
     oracle_dsn: Option<String>,
-    /// Oracle ATP username
     #[arg(long, env = "ATP_USER")]
     oracle_user: Option<String>,
-    /// Oracle ATP password
     #[arg(long, env = "ATP_PASSWORD")]
     oracle_password: Option<String>,
-    /// Minimum number of pooled ATP connections (defaults to 1 when omitted)
     #[arg(long, env = "ATP_POOL_MIN")]
     oracle_pool_min: Option<usize>,
-    /// Maximum number of pooled ATP connections (defaults to 16 when omitted)
     #[arg(long, env = "ATP_POOL_MAX")]
     oracle_pool_max: Option<usize>,
-    /// Connection acquisition timeout in seconds (defaults to 30 when omitted)
     #[arg(long, env = "ATP_POOL_TIMEOUT_SECS")]
     oracle_pool_timeout_secs: Option<u64>,
 }
```
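With clap's derive API, each snake_case field above is exposed as a kebab-case long flag, so `cb_failure_threshold` becomes `--cb-failure-threshold` and `health_check_endpoint` becomes `--health-check-endpoint`. A reduced sketch of that mapping, assuming the clap crate with the `derive` feature (`MiniArgs` is illustrative, not the router's struct):

```rust
use clap::Parser;

#[derive(Parser, Debug)]
struct MiniArgs {
    // Exposed as --cb-failure-threshold.
    #[arg(long, default_value_t = 10)]
    cb_failure_threshold: u32,
    // Exposed as --health-check-endpoint.
    #[arg(long, default_value = "/health")]
    health_check_endpoint: String,
}

fn main() {
    let args = MiniArgs::parse_from(["router", "--cb-failure-threshold", "20"]);
    assert_eq!(args.cb_failure_threshold, 20);
    assert_eq!(args.health_check_endpoint, "/health");
}
```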
```diff
@@ -357,19 +289,15 @@ enum OracleConnectSource {
 }

 impl CliArgs {
-    /// Determine connection mode from worker URLs
     fn determine_connection_mode(worker_urls: &[String]) -> ConnectionMode {
-        // Only consider it gRPC if explicitly specified with grpc:// or grpcs:// scheme
         for url in worker_urls {
             if url.starts_with("grpc://") || url.starts_with("grpcs://") {
                 return ConnectionMode::Grpc;
             }
         }
-        // Default to HTTP for all other cases (including http://, https://, or no scheme)
         ConnectionMode::Http
     }

-    /// Parse selector strings into HashMap
     fn parse_selector(selector_list: &[String]) -> HashMap<String, String> {
         let mut map = HashMap::new();
         for item in selector_list {
@@ -382,7 +310,6 @@ impl CliArgs {
         map
     }

-    /// Convert policy string to PolicyConfig
     fn parse_policy(&self, policy_str: &str) -> PolicyConfig {
         match policy_str {
             "random" => PolicyConfig::Random,
@@ -395,9 +322,9 @@ impl CliArgs {
                 max_tree_size: self.max_tree_size,
             },
             "power_of_two" => PolicyConfig::PowerOfTwo {
-                load_check_interval_secs: 5, // Default value
+                load_check_interval_secs: 5,
             },
-            _ => PolicyConfig::RoundRobin, // Fallback
+            _ => PolicyConfig::RoundRobin,
         }
     }
@@ -482,26 +409,21 @@ impl CliArgs {
         })
     }

-    /// Convert CLI arguments to RouterConfig
     fn to_router_config(
         &self,
         prefill_urls: Vec<(String, Option<u16>)>,
     ) -> ConfigResult<RouterConfig> {
-        // Determine routing mode
         let mode = if self.enable_igw {
-            // IGW mode - routing mode is not used in IGW, but we need to provide a placeholder
             RoutingMode::Regular {
                 worker_urls: vec![],
             }
         } else if matches!(self.backend, Backend::Openai) {
-            // OpenAI backend mode - use worker_urls as base(s)
             RoutingMode::OpenAI {
                 worker_urls: self.worker_urls.clone(),
             }
         } else if self.pd_disaggregation {
             let decode_urls = self.decode.clone();
-            // Validate PD configuration if not using service discovery
             if !self.service_discovery && (prefill_urls.is_empty() || decode_urls.is_empty()) {
                 return Err(ConfigError::ValidationFailed {
                     reason: "PD disaggregation mode requires --prefill and --decode URLs when not using service discovery".to_string(),
@@ -515,7 +437,6 @@ impl CliArgs {
                 decode_policy: self.decode_policy.as_ref().map(|p| self.parse_policy(p)),
             }
         } else {
-            // Regular mode
             if !self.service_discovery && self.worker_urls.is_empty() {
                 return Err(ConfigError::ValidationFailed {
                     reason: "Regular mode requires --worker-urls when not using service discovery"
@@ -527,10 +448,8 @@ impl CliArgs {
             }
         };
-        // Main policy
         let policy = self.parse_policy(&self.policy);
-        // Service discovery configuration
         let discovery = if self.service_discovery {
             Some(DiscoveryConfig {
                 enabled: true,
@@ -546,13 +465,11 @@ impl CliArgs {
             None
         };
-        // Metrics configuration
         let metrics = Some(MetricsConfig {
             port: self.prometheus_port,
             host: self.prometheus_host.clone(),
         });
-        // Determine connection mode from all worker URLs
         let mut all_urls = Vec::new();
         match &mode {
             RoutingMode::Regular { worker_urls } => {
@@ -568,9 +485,7 @@ impl CliArgs {
             }
                 all_urls.extend(decode_urls.clone());
             }
-            RoutingMode::OpenAI { .. } => {
-            }
+            RoutingMode::OpenAI { .. } => {} // For connection-mode detection, skip URLs; OpenAI forces HTTP below.
         }
         let connection_mode = match &mode {
             RoutingMode::OpenAI { .. } => ConnectionMode::Http,
@@ -589,7 +504,6 @@ impl CliArgs {
             None
         };
-        // Build RouterConfig
         Ok(RouterConfig {
             mode,
             policy,
@@ -612,8 +526,8 @@ impl CliArgs {
                 Some(self.request_id_headers.clone())
             },
             max_concurrent_requests: self.max_concurrent_requests,
-            queue_size: 100, // Default queue size
-            queue_timeout_secs: 60, // Default timeout
+            queue_size: 100,
+            queue_timeout_secs: 60,
             cors_allowed_origins: self.cors_allowed_origins.clone(),
             retry: RetryConfig {
                 max_retries: self.retry_max_retries,
@@ -646,9 +560,7 @@ impl CliArgs {
         })
     }

-    /// Create ServerConfig from CLI args and RouterConfig
     fn to_server_config(&self, router_config: RouterConfig) -> ServerConfig {
-        // Create service discovery config if enabled
         let service_discovery_config = if self.service_discovery {
             Some(ServiceDiscoveryConfig {
                 enabled: true,
@@ -665,7 +577,6 @@ impl CliArgs {
             None
         };
-        // Create Prometheus config
         let prometheus_config = Some(PrometheusConfig {
             port: self.prometheus_port,
             host: self.prometheus_host.clone(),
@@ -691,19 +602,15 @@ impl CliArgs {
 }

 fn main() -> Result<(), Box<dyn std::error::Error>> {
-    // Parse prefill arguments manually before clap parsing
     let prefill_urls = parse_prefill_args();
-    // Filter out prefill arguments and their values before passing to clap
     let mut filtered_args: Vec<String> = Vec::new();
     let raw_args: Vec<String> = std::env::args().collect();
     let mut i = 0;
     while i < raw_args.len() {
         if raw_args[i] == "--prefill" && i + 1 < raw_args.len() {
-            // Skip --prefill and its URL
             i += 2;
-            // Also skip bootstrap port if present
             if i < raw_args.len()
                 && !raw_args[i].starts_with("--")
                 && (raw_args[i].parse::<u16>().is_ok() || raw_args[i].to_lowercase() == "none")
@@ -716,10 +623,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         }
     }
-    // Parse CLI arguments with clap using filtered args
     let cli_args = CliArgs::parse_from(filtered_args);
-    // Print startup info
     println!("SGLang Router starting...");
     println!("Host: {}:{}", cli_args.host, cli_args.port);
     let mode_str = if cli_args.enable_igw {
@@ -733,7 +638,6 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
     };
     println!("Mode: {}", mode_str);
-    // Warn for runtimes that are parsed but not yet implemented
     match cli_args.backend {
         Backend::Vllm | Backend::Trtllm | Backend::Anthropic => {
             println!(
@@ -754,19 +658,10 @@ Provide --worker-urls or PD flags as usual.",
         }
     }
-    // Convert to RouterConfig
     let router_config = cli_args.to_router_config(prefill_urls)?;
-    // Validate configuration
     router_config.validate()?;
-    // Create ServerConfig
     let server_config = cli_args.to_server_config(router_config);
-    // Create a new runtime for the server (like Python binding does)
     let runtime = tokio::runtime::Runtime::new()?;
-    // Block on the async startup function
     runtime.block_on(async move { server::startup(server_config).await })?;
     Ok(())
```
sgl-router/src/metrics.rs (view file @ a7fe6e10)
```diff
@@ -19,7 +19,6 @@ impl Default for PrometheusConfig {
 }

 pub fn init_metrics() {
-    // Request metrics
     describe_counter!(
         "sgl_router_requests_total",
         "Total number of requests by route and method"
@@ -45,7 +44,6 @@ pub fn init_metrics() {
         "Total number of requests that exhausted retries by route"
     );
-    // Circuit breaker metrics
     describe_gauge!(
         "sgl_router_cb_state",
         "Circuit breaker state per worker (0=closed, 1=open, 2=half_open)"
@@ -59,7 +57,6 @@ pub fn init_metrics() {
         "Total number of circuit breaker outcomes by worker and outcome type (success/failure)"
     );
-    // Worker metrics
     describe_gauge!(
         "sgl_router_active_workers",
         "Number of currently active workers"
@@ -74,7 +71,6 @@ pub fn init_metrics() {
         "Total requests processed by each worker"
     );
-    // Policy metrics
     describe_counter!(
         "sgl_router_policy_decisions_total",
         "Total routing policy decisions by policy and worker"
@@ -92,7 +88,6 @@ pub fn init_metrics() {
     describe_gauge!("sgl_router_max_load", "Maximum worker load");
     describe_gauge!("sgl_router_min_load", "Minimum worker load");
-    // PD-specific metrics
     describe_counter!("sgl_router_pd_requests_total", "Total PD requests by route");
     describe_counter!(
         "sgl_router_pd_prefill_requests_total",
@@ -123,7 +118,6 @@ pub fn init_metrics() {
         "PD request duration by route"
     );
-    // Service discovery metrics
     describe_counter!(
         "sgl_router_discovery_updates_total",
         "Total service discovery update events"
@@ -137,13 +131,11 @@ pub fn init_metrics() {
         "Number of workers removed in last discovery update"
     );
-    // Generate request specific metrics
     describe_histogram!(
         "sgl_router_generate_duration_seconds",
         "Generate request duration"
     );
-    // Embedding request specific metrics
     describe_counter!("sgl_router_embeddings_total", "Total embedding requests");
     describe_histogram!(
         "sgl_router_embeddings_duration_seconds",
@@ -155,13 +147,11 @@ pub fn init_metrics() {
     );
     describe_gauge!("sgl_router_embeddings_queue_size", "Embedding queue size");
-    // Running requests gauge for cache-aware policy
     describe_gauge!(
         "sgl_router_running_requests",
         "Number of running requests per worker"
     );
-    // Tokenizer metrics
     describe_histogram!(
         "sgl_tokenizer_encode_duration_seconds",
         "Time to encode text to tokens"
@@ -207,7 +197,6 @@ pub fn init_metrics() {
         "Vocabulary size of the loaded tokenizer"
     );
-    // Stop sequence detection metrics
     describe_counter!(
         "sgl_tokenizer_stop_sequences_detected_total",
         "Total stop sequences detected by type"
@@ -221,7 +210,6 @@ pub fn init_metrics() {
         "Time to check for stop sequences per token"
     );
-    // Streaming decode metrics
     describe_counter!(
         "sgl_tokenizer_stream_tokens_total",
         "Total tokens processed in streaming decode"
@@ -235,7 +223,6 @@ pub fn init_metrics() {
         "Time per streaming decode step"
     );
-    // Factory metrics
     describe_counter!(
         "sgl_tokenizer_factory_loads_total",
         "Total tokenizer loads by file type"
@@ -251,7 +238,6 @@ pub fn init_metrics() {
 }

 pub fn start_prometheus(config: PrometheusConfig) {
-    // Initialize metric descriptions
     init_metrics();
     let duration_matcher = Matcher::Suffix(String::from("duration_seconds"));
@@ -280,7 +266,6 @@ pub struct RouterMetrics;
 pub struct TokenizerMetrics;

 impl RouterMetrics {
-    // Request metrics
     pub fn record_request(route: &str) {
         counter!("sgl_router_requests_total",
             "route" => route.to_string()
@@ -324,7 +309,6 @@ impl RouterMetrics {
         .increment(1);
     }
-    // Worker metrics
     pub fn set_active_workers(count: usize) {
         gauge!("sgl_router_active_workers").set(count as f64);
     }
@@ -350,7 +334,6 @@ impl RouterMetrics {
         .increment(1);
     }
-    // Policy metrics
     pub fn record_policy_decision(policy: &str, worker: &str) {
         counter!("sgl_router_policy_decisions_total",
             "policy" => policy.to_string(),
@@ -383,7 +366,6 @@ impl RouterMetrics {
         gauge!("sgl_router_min_load").set(min_load as f64);
     }
-    // PD-specific metrics
     pub fn record_pd_request(route: &str) {
         counter!("sgl_router_pd_requests_total",
             "route" => route.to_string()
@@ -440,19 +422,16 @@ impl RouterMetrics {
         .increment(1);
     }
-    // Service discovery metrics
     pub fn record_discovery_update(added: usize, removed: usize) {
         counter!("sgl_router_discovery_updates_total").increment(1);
         gauge!("sgl_router_discovery_workers_added").set(added as f64);
         gauge!("sgl_router_discovery_workers_removed").set(removed as f64);
     }
-    // Generate request metrics
     pub fn record_generate_duration(duration: Duration) {
         histogram!("sgl_router_generate_duration_seconds").record(duration.as_secs_f64());
     }
-    // Embeddings metrics
     pub fn record_embeddings_request() {
         counter!("sgl_router_embeddings_total").increment(1);
     }
@@ -473,7 +452,6 @@ impl RouterMetrics {
         gauge!("sgl_router_embeddings_queue_size").set(size as f64);
     }
-    // Running requests for cache-aware policy
     pub fn set_running_requests(worker: &str, count: usize) {
         gauge!("sgl_router_running_requests",
             "worker" => worker.to_string()
@@ -481,7 +459,6 @@ impl RouterMetrics {
         .set(count as f64);
     }
-    // Circuit breaker metrics
     pub fn set_cb_state(worker: &str, state_code: u8) {
         gauge!("sgl_router_cb_state",
             "worker" => worker.to_string()
```
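The `sgl_router_cb_state` gauge above documents its encoding as 0=closed, 1=open, 2=half_open, and `set_cb_state` takes that code as a `u8`. A sketch of that mapping; the enum is illustrative, since the router's own circuit-breaker type is not part of this diff:

```rust
#[derive(Clone, Copy)]
enum CbState {
    Closed,
    Open,
    HalfOpen,
}

// Encodes a state for the sgl_router_cb_state gauge: 0=closed, 1=open, 2=half_open.
fn state_code(state: CbState) -> u8 {
    match state {
        CbState::Closed => 0,
        CbState::Open => 1,
        CbState::HalfOpen => 2,
    }
}

fn main() {
    // Would feed RouterMetrics::set_cb_state(worker, state_code(state)).
    assert_eq!(state_code(CbState::HalfOpen), 2);
}
```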
@@ -508,7 +485,6 @@ impl RouterMetrics {
...
@@ -508,7 +485,6 @@ impl RouterMetrics {
}
}
impl
TokenizerMetrics
{
impl
TokenizerMetrics
{
// Encoding metrics
pub
fn
record_encode_request
(
tokenizer_type
:
&
str
)
{
pub
fn
record_encode_request
(
tokenizer_type
:
&
str
)
{
counter!
(
"sgl_tokenizer_encode_requests_total"
,
counter!
(
"sgl_tokenizer_encode_requests_total"
,
"tokenizer_type"
=>
tokenizer_type
.to_string
()
"tokenizer_type"
=>
tokenizer_type
.to_string
()
...
@@ -535,7 +511,6 @@ impl TokenizerMetrics {
...
@@ -535,7 +511,6 @@ impl TokenizerMetrics {
histogram!
(
"sgl_tokenizer_chars_per_encode"
)
.record
(
char_count
as
f64
);
histogram!
(
"sgl_tokenizer_chars_per_encode"
)
.record
(
char_count
as
f64
);
}
}
// Decoding metrics
pub
fn
record_decode_request
(
tokenizer_type
:
&
str
)
{
pub
fn
record_decode_request
(
tokenizer_type
:
&
str
)
{
counter!
(
"sgl_tokenizer_decode_requests_total"
,
counter!
(
"sgl_tokenizer_decode_requests_total"
,
"tokenizer_type"
=>
tokenizer_type
.to_string
()
"tokenizer_type"
=>
tokenizer_type
.to_string
()
...
@@ -558,7 +533,6 @@ impl TokenizerMetrics {
...
@@ -558,7 +533,6 @@ impl TokenizerMetrics {
histogram!
(
"sgl_tokenizer_tokens_per_decode"
)
.record
(
token_count
as
f64
);
histogram!
(
"sgl_tokenizer_tokens_per_decode"
)
.record
(
token_count
as
f64
);
}
}
// Batch encoding metrics
pub
fn
record_encode_batch_duration
(
duration
:
Duration
,
batch_size
:
usize
)
{
pub
fn
record_encode_batch_duration
(
duration
:
Duration
,
batch_size
:
usize
)
{
histogram!
(
"sgl_tokenizer_encode_batch_duration_seconds"
,
histogram!
(
"sgl_tokenizer_encode_batch_duration_seconds"
,
"batch_size"
=>
batch_size
.to_string
()
"batch_size"
=>
batch_size
.to_string
()
...
@@ -566,7 +540,6 @@ impl TokenizerMetrics {
...
@@ -566,7 +540,6 @@ impl TokenizerMetrics {
.record
(
duration
.as_secs_f64
());
.record
(
duration
.as_secs_f64
());
}
}
// Stop sequence detection metrics
pub
fn
record_stop_sequence_detected
(
stop_type
:
&
str
)
{
pub
fn
record_stop_sequence_detected
(
stop_type
:
&
str
)
{
counter!
(
"sgl_tokenizer_stop_sequences_detected_total"
,
counter!
(
"sgl_tokenizer_stop_sequences_detected_total"
,
"type"
=>
stop_type
.to_string
()
"type"
=>
stop_type
.to_string
()
...
@@ -582,7 +555,6 @@ impl TokenizerMetrics {
...
@@ -582,7 +555,6 @@ impl TokenizerMetrics {
histogram!
(
"sgl_tokenizer_stop_detection_duration_seconds"
)
.record
(
duration
.as_secs_f64
());
histogram!
(
"sgl_tokenizer_stop_detection_duration_seconds"
)
.record
(
duration
.as_secs_f64
());
}
}
// Streaming decode metrics
pub
fn
record_stream_token
()
{
pub
fn
record_stream_token
()
{
counter!
(
"sgl_tokenizer_stream_tokens_total"
)
.increment
(
1
);
counter!
(
"sgl_tokenizer_stream_tokens_total"
)
.increment
(
1
);
}
}
...
@@ -595,7 +567,6 @@ impl TokenizerMetrics {
...
@@ -595,7 +567,6 @@ impl TokenizerMetrics {
histogram!
(
"sgl_tokenizer_stream_step_duration_seconds"
)
.record
(
duration
.as_secs_f64
());
histogram!
(
"sgl_tokenizer_stream_step_duration_seconds"
)
.record
(
duration
.as_secs_f64
());
}
}
// Factory metrics
pub
fn
record_factory_load
(
file_type
:
&
str
)
{
pub
fn
record_factory_load
(
file_type
:
&
str
)
{
counter!
(
"sgl_tokenizer_factory_loads_total"
,
counter!
(
"sgl_tokenizer_factory_loads_total"
,
"file_type"
=>
file_type
.to_string
()
"file_type"
=>
file_type
.to_string
()
...
@@ -614,7 +585,6 @@ impl TokenizerMetrics {
...
@@ -614,7 +585,6 @@ impl TokenizerMetrics {
histogram!
(
"sgl_tokenizer_factory_load_duration_seconds"
)
.record
(
duration
.as_secs_f64
());
histogram!
(
"sgl_tokenizer_factory_load_duration_seconds"
)
.record
(
duration
.as_secs_f64
());
}
}
// Vocabulary metrics
pub
fn
set_vocab_size
(
tokenizer_type
:
&
str
,
size
:
usize
)
{
pub
fn
set_vocab_size
(
tokenizer_type
:
&
str
,
size
:
usize
)
{
gauge!
(
"sgl_tokenizer_vocab_size"
,
gauge!
(
"sgl_tokenizer_vocab_size"
,
"tokenizer_type"
=>
tokenizer_type
.to_string
()
"tokenizer_type"
=>
tokenizer_type
.to_string
()
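Note: the helpers in this hunk are thin wrappers over the metrics crate's macro API, which the commit leaves untouched. A minimal sketch of the pattern, with metric and label names invented for illustration (a recorder such as the Prometheus exporter must be installed for anything to be recorded):

use metrics::{counter, histogram};
use std::time::Instant;

// Sketch only: names are illustrative, not the router's real metrics.
fn encode_with_metrics(tokenizer_type: &str, text: &str) -> Vec<u32> {
    // One counter increment per request, labeled by tokenizer type.
    counter!(
        "example_encode_requests_total",
        "tokenizer_type" => tokenizer_type.to_string()
    )
    .increment(1);

    let start = Instant::now();
    let token_ids = vec![0u32; text.len() / 4]; // stand-in for real encoding
    // One histogram observation per measured quantity.
    histogram!("example_encode_duration_seconds").record(start.elapsed().as_secs_f64());
    histogram!("example_tokens_per_encode").record(token_ids.len() as f64);
    token_ids
}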
...
@@ -705,7 +675,6 @@ mod tests {
            .parse()
            .unwrap_or(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)));

-        // Should fall back to 0.0.0.0
        assert_eq!(ip_addr, IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)));
    }
}
...
@@ -780,7 +749,6 @@ mod tests {
    fn test_duration_suffix_matcher() {
        let matcher = Matcher::Suffix(String::from("duration_seconds"));

-        // Test matching behavior
        let _matching_metrics = [
            "request_duration_seconds",
            "response_duration_seconds",
...
@@ -789,8 +757,6 @@ mod tests {
        let _non_matching_metrics = ["duration_total", "duration_seconds_total", "other_metric"];

-        // Note: We can't directly test Matcher matching without the internals,
-        // but we can verify the matcher is created correctly
        match matcher {
            Matcher::Suffix(suffix) => assert_eq!(suffix, "duration_seconds"),
            _ => panic!("Expected Suffix matcher"),
...
@@ -801,7 +767,6 @@ mod tests {
    #[test]
    fn test_prometheus_builder_configuration() {
-        // This test verifies the builder configuration without actually starting Prometheus
        let _config = PrometheusConfig::default();
        let duration_matcher = Matcher::Suffix(String::from("duration_seconds"));
...
@@ -810,10 +775,8 @@ mod tests {
            60.0, 90.0, 120.0, 180.0, 240.0,
        ];

-        // Verify bucket configuration
        assert_eq!(duration_bucket.len(), 20);

-        // Verify matcher is suffix type
        match duration_matcher {
            Matcher::Suffix(s) => assert_eq!(s, "duration_seconds"),
            _ => panic!("Expected Suffix matcher"),
...
@@ -832,14 +795,12 @@ mod tests {
    #[test]
    fn test_custom_buckets_for_different_metrics() {
-        // Test that we can create different bucket configurations
        let request_buckets = [0.001, 0.01, 0.1, 1.0, 10.0];
        let generate_buckets = [0.1, 0.5, 1.0, 5.0, 30.0, 60.0];

        assert_eq!(request_buckets.len(), 5);
        assert_eq!(generate_buckets.len(), 6);

-        // Verify each set is sorted
        for i in 1..request_buckets.len() {
            assert!(request_buckets[i] > request_buckets[i - 1]);
        }
...
@@ -853,7 +814,6 @@ mod tests {
    #[test]
    fn test_metrics_static_methods() {
-        // Test that all static methods can be called without panic
        RouterMetrics::record_request("/generate");
        RouterMetrics::record_request_duration("/generate", Duration::from_millis(100));
        RouterMetrics::record_request_error("/generate", "timeout");
...
@@ -887,41 +847,32 @@ mod tests {
    #[test]
    fn test_tokenizer_metrics_static_methods() {
-        // Test that all tokenizer metric methods can be called without panic
-        // Encoding metrics
        TokenizerMetrics::record_encode_request("huggingface");
        TokenizerMetrics::record_encode_duration(Duration::from_millis(10));
        TokenizerMetrics::record_encode_error("invalid_input");
        TokenizerMetrics::record_tokens_per_encode(100);
        TokenizerMetrics::record_chars_per_encode(500);

-        // Decoding metrics
        TokenizerMetrics::record_decode_request("huggingface");
        TokenizerMetrics::record_decode_duration(Duration::from_millis(5));
        TokenizerMetrics::record_decode_error("invalid_tokens");
        TokenizerMetrics::record_tokens_per_decode(50);

-        // Batch encoding
        TokenizerMetrics::record_encode_batch_duration(Duration::from_millis(100), 10);

-        // Stop sequence detection
        TokenizerMetrics::record_stop_sequence_detected("token");
        TokenizerMetrics::record_stop_sequence_detected("string");
        TokenizerMetrics::record_partial_match();
        TokenizerMetrics::record_stop_detection_duration(Duration::from_micros(100));

-        // Streaming decode
        TokenizerMetrics::record_stream_token();
        TokenizerMetrics::record_incomplete_utf8();
        TokenizerMetrics::record_stream_step_duration(Duration::from_micros(50));

-        // Factory metrics
        TokenizerMetrics::record_factory_load("json");
        TokenizerMetrics::record_factory_error("unsupported_format");
        TokenizerMetrics::record_factory_load_duration(Duration::from_millis(200));

-        // Vocabulary metrics
        TokenizerMetrics::set_vocab_size("huggingface", 50000);
    }
...
@@ -929,17 +880,14 @@ mod tests {
    #[test]
    fn test_port_already_in_use() {
-        // Skip this test if we can't bind to the port
        let port = 29123;
-        // Use a different port to avoid conflicts
        if let Ok(_listener) = TcpListener::bind(("127.0.0.1", port)) {
-            // Port is available, we can test
            let config = PrometheusConfig {
                port,
                host: "127.0.0.1".to_string(),
            };
-            // Just verify config is created correctly
            assert_eq!(config.port, port);
        }
    }
...
@@ -948,8 +896,6 @@ mod tests {
    #[test]
    fn test_metrics_endpoint_accessibility() {
-        // This would be an integration test in practice
-        // Here we just verify the configuration
        let config = PrometheusConfig {
            port: 29000,
            host: "127.0.0.1".to_string(),
...
@@ -963,7 +909,6 @@ mod tests {
    #[test]
    fn test_concurrent_metric_updates() {
-        // Test that metric updates can be called concurrently
        use std::sync::atomic::{AtomicBool, Ordering};
        use std::sync::Arc;
        use std::thread;
...
@@ -984,11 +929,9 @@ mod tests {
            handles.push(handle);
        }

-        // Let threads run briefly
        thread::sleep(Duration::from_millis(10));
        done.store(true, Ordering::Relaxed);

-        // Wait for all threads
        for handle in handles {
            handle.join().unwrap();
        }
...
@@ -998,7 +941,6 @@ mod tests {
    #[test]
    fn test_empty_string_metrics() {
-        // Test that empty strings don't cause issues
        RouterMetrics::record_request("");
        RouterMetrics::set_worker_health("", true);
        RouterMetrics::record_policy_decision("", "");
...
@@ -1030,7 +972,6 @@ mod tests {
    #[test]
    fn test_extreme_metric_values() {
-        // Test extreme values
        RouterMetrics::set_active_workers(0);
        RouterMetrics::set_active_workers(usize::MAX);
...
@@ -1038,7 +979,6 @@ mod tests {
        RouterMetrics::set_worker_load("worker", usize::MAX);
        RouterMetrics::record_request_duration("route", Duration::from_nanos(1));
-        // 24 hours
        RouterMetrics::record_request_duration("route", Duration::from_secs(86400));
    }
}
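The builder tests above only assert on bucket arrays and matcher shape because Matcher's matching logic is internal to metrics-exporter-prometheus; the builder is where matchers and buckets actually meet. A minimal sketch of that wiring, assuming the metrics-exporter-prometheus crate; the listen address and bucket values here are illustrative, not the router's actual setup:

use metrics_exporter_prometheus::{Matcher, PrometheusBuilder};

fn install_prometheus() -> Result<(), Box<dyn std::error::Error>> {
    let duration_buckets = [0.001, 0.01, 0.1, 1.0, 10.0, 60.0];
    PrometheusBuilder::new()
        // Any histogram whose name ends in "duration_seconds" gets these buckets.
        .set_buckets_for_metric(
            Matcher::Suffix("duration_seconds".to_string()),
            &duration_buckets,
        )?
        .with_http_listener(([127, 0, 0, 1], 29000))
        .install()?;
    Ok(())
}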
sgl-router/src/service_discovery.rs View file @ a7fe6e10
...
@@ -19,7 +19,6 @@ use tokio::task;
use tokio::time;
use tracing::{debug, error, info, warn};

-/// Represents the service discovery configuration
#[derive(Debug, Clone)]
pub struct ServiceDiscoveryConfig {
    pub enabled: bool,
...
@@ -41,8 +40,8 @@ impl Default for ServiceDiscoveryConfig {
            enabled: false,
            selector: HashMap::new(),
            check_interval: Duration::from_secs(60),
-            port: 8000, // Standard port for modern services
+            port: 8000,
-            namespace: None, // None means watch all namespaces
+            namespace: None,
            pd_mode: false,
            prefill_selector: HashMap::new(),
            decode_selector: HashMap::new(),
...
@@ -51,7 +50,6 @@ impl Default for ServiceDiscoveryConfig {
    }
}

-/// Pod type for PD mode service discovery
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PodType {
    Prefill,
...
@@ -59,7 +57,6 @@ pub enum PodType {
    Regular,
}

-/// Represents a Kubernetes pod's information used for worker management
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PodInfo {
    pub name: String,
...
@@ -71,7 +68,6 @@ pub struct PodInfo {
}

impl PodInfo {
-    /// Check if a pod matches any of the given selectors
    fn matches_selector(pod: &Pod, selector: &HashMap<String, String>) -> bool {
        if selector.is_empty() {
            return false;
...
@@ -83,19 +79,15 @@ impl PodInfo {
            .is_some_and(|labels| selector.iter().all(|(k, v)| labels.get(k) == Some(v)))
    }

-    /// Check if a pod should be included in service discovery
    pub fn should_include(pod: &Pod, config: &ServiceDiscoveryConfig) -> bool {
        if config.pd_mode {
-            // In PD mode, at least one selector must be non-empty
            if config.prefill_selector.is_empty() && config.decode_selector.is_empty() {
                warn!("PD mode enabled but both prefill_selector and decode_selector are empty");
                return false;
            }
-            // In PD mode, pod must match either prefill or decode selector
            Self::matches_selector(pod, &config.prefill_selector)
                || Self::matches_selector(pod, &config.decode_selector)
        } else {
-            // In regular mode, pod must match the general selector
            if config.selector.is_empty() {
                warn!("Regular mode enabled but selector is empty");
                return false;
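Restated outside the kube types, the selector semantics above are: an empty selector never matches, and otherwise every selector pair must be present verbatim in the pod's labels (AND, not OR). A self-contained sketch of that predicate:

use std::collections::{BTreeMap, HashMap};

// Same logic as matches_selector above, minus the Pod plumbing.
fn selector_matches(
    labels: &BTreeMap<String, String>,
    selector: &HashMap<String, String>,
) -> bool {
    !selector.is_empty() && selector.iter().all(|(k, v)| labels.get(k) == Some(v))
}

fn main() {
    let mut labels = BTreeMap::new();
    labels.insert("app".to_string(), "sglang".to_string());
    labels.insert("role".to_string(), "prefill".to_string());

    let mut selector = HashMap::new();
    selector.insert("app".to_string(), "sglang".to_string());
    assert!(selector_matches(&labels, &selector)); // selector is a subset of labels: match

    selector.insert("zone".to_string(), "us-east".to_string());
    assert!(!selector_matches(&labels, &selector)); // missing key: no match
}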
...
@@ -104,7 +96,6 @@ impl PodInfo {
        }
    }

-    /// Unified PodInfo creation with optional PD configuration
    pub fn from_pod(pod: &Pod, config: Option<&ServiceDiscoveryConfig>) -> Option<Self> {
        let name = pod.metadata.name.clone()?;
        let status = pod.status.clone()?;
...
@@ -120,10 +111,8 @@ impl PodInfo {
        let pod_status = status.phase.unwrap_or_else(|| "Unknown".to_string());

-        // Determine pod type based on labels if config is provided and in PD mode
        let pod_type = if let Some(config) = config {
            if config.pd_mode {
-                // Use simplified helper methods for cleaner logic
                if Self::matches_selector(pod, &config.prefill_selector) {
                    Some(PodType::Prefill)
                } else if Self::matches_selector(pod, &config.decode_selector) {
...
@@ -135,11 +124,9 @@ impl PodInfo {
                Some(PodType::Regular)
            }
        } else {
-            // No config provided, default to None (for backwards compatibility)
            None
        };

-        // Extract bootstrap port from annotations for prefill pods
        let bootstrap_port = if matches!(pod_type, Some(PodType::Prefill)) {
            if let Some(config) = config {
                pod.metadata
...
@@ -164,12 +151,10 @@ impl PodInfo {
        })
    }

-    /// Returns true if the pod is in a state where it can accept traffic
    pub fn is_healthy(&self) -> bool {
        self.is_ready && self.status == "Running"
    }

-    /// Generates a worker URL for this pod
    pub fn worker_url(&self, port: u16) -> String {
        format!("http://{}:{}", self.ip, port)
    }
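is_healthy and worker_url are small enough to restate whole; the point worth keeping in mind is that both readiness and the Running phase must hold before a URL is ever handed to the router. A trimmed-down illustration, with the struct reduced to just the fields these two methods touch:

// Toy stand-in for PodInfo; field names mirror the diff above.
struct MiniPodInfo {
    ip: String,
    status: String,
    is_ready: bool,
}

impl MiniPodInfo {
    fn is_healthy(&self) -> bool {
        // Readiness probe passed *and* the pod phase is Running.
        self.is_ready && self.status == "Running"
    }

    fn worker_url(&self, port: u16) -> String {
        format!("http://{}:{}", self.ip, port)
    }
}

fn main() {
    let pod = MiniPodInfo { ip: "10.0.0.7".into(), status: "Running".into(), is_ready: true };
    assert!(pod.is_healthy());
    assert_eq!(pod.worker_url(8000), "http://10.0.0.7:8000");
}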
...
@@ -179,9 +164,7 @@ pub async fn start_service_discovery(
    config: ServiceDiscoveryConfig,
    app_context: Arc<AppContext>,
) -> Result<task::JoinHandle<()>, kube::Error> {
-    // Don't initialize anything if service discovery is disabled
    if !config.enabled {
-        // Return a generic error when service discovery is disabled
        return Err(kube::Error::Api(kube::error::ErrorResponse {
            status: "Disabled".to_string(),
            message: "Service discovery is disabled".to_string(),
...
@@ -192,7 +175,6 @@ pub async fn start_service_discovery(
    let _ = rustls::crypto::ring::default_provider().install_default();

-    // Initialize Kubernetes client
    let client = Client::try_default().await?;

    // Log the appropriate selectors based on mode
...
@@ -229,12 +211,9 @@ pub async fn start_service_discovery(
        );
    }

-    // Create the task that will run in the background
    let handle = task::spawn(async move {
-        // We'll track pods we've already added to avoid duplicates
        let tracked_pods = Arc::new(Mutex::new(HashSet::new()));

-        // Create a watcher for pods
        let pods: Api<Pod> = if let Some(namespace) = &config.namespace {
            Api::namespaced(client, namespace)
        } else {
...
@@ -243,23 +222,19 @@ pub async fn start_service_discovery(
        debug!("K8s service discovery initialized");

-        // Create Arcs for configuration data
        let config_arc = Arc::new(config.clone());
        let port = config.port;

        let mut retry_delay = Duration::from_secs(1);
-        const MAX_RETRY_DELAY: Duration = Duration::from_secs(300); // 5 minutes max
+        const MAX_RETRY_DELAY: Duration = Duration::from_secs(300);

        loop {
-            // Create a watcher with the proper parameters according to the kube-rs API
            let watcher_config = Config::default();
            let watcher_stream = watcher(pods.clone(), watcher_config).applied_objects();

-            // Clone Arcs for the closures
            let config_clone = Arc::clone(&config_arc);
            let tracked_pods_clone = Arc::clone(&tracked_pods);

-            // Simplified label selector filter using helper method
            let filtered_stream = watcher_stream.filter_map(move |obj_res| {
                let config_inner = Arc::clone(&config_clone);
...
@@ -277,7 +252,6 @@ pub async fn start_service_discovery(
                }
            });

-            // Clone again for the next closure
            let tracked_pods_clone2 = Arc::clone(&tracked_pods_clone);
            let app_context_clone = Arc::clone(&app_context);
            let config_clone2 = Arc::clone(&config_arc);
...
@@ -317,7 +291,6 @@ pub async fn start_service_discovery(
                .await
            {
                Ok(_) => {
-                    // Reset retry delay on success
                    retry_delay = Duration::from_secs(1);
                }
                Err(err) => {
...
@@ -328,12 +301,10 @@ pub async fn start_service_discovery(
                    );
                    time::sleep(retry_delay).await;
-                    // Exponential backoff with jitter
                    retry_delay = std::cmp::min(retry_delay * 2, MAX_RETRY_DELAY);
                }
            }

-            // If the watcher exits for some reason, wait a bit before restarting
            warn!(
                "Kubernetes watcher exited, restarting in {} seconds",
                config_arc.check_interval.as_secs()
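One detail the deleted comment had wrong: the retry path doubles the delay and caps it at MAX_RETRY_DELAY, but applies no jitter, which is presumably why the comment was outdated. The arithmetic, pulled out into a standalone sketch:

use std::time::Duration;

// Doubling backoff with a hard cap, matching the retry logic above.
fn next_retry_delay(current: Duration, max: Duration) -> Duration {
    std::cmp::min(current * 2, max)
}

fn main() {
    let max = Duration::from_secs(300);
    let mut delay = Duration::from_secs(1);
    for _ in 0..10 {
        delay = next_retry_delay(delay, max);
    }
    // 1s doubles past the 300s cap after nine steps and then stays capped.
    assert_eq!(delay, max);
}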
...
@@ -354,9 +325,7 @@ async fn handle_pod_event(
) {
    let worker_url = pod_info.worker_url(port);

-    // If pod is healthy, try to add it (with atomic check-and-insert)
    if pod_info.is_healthy() {
-        // Atomic check-and-insert to prevent race conditions
        let should_add = {
            let mut tracker = match tracked_pods.lock() {
                Ok(tracker) => tracker,
...
@@ -367,9 +336,8 @@ async fn handle_pod_event(
            };

            if tracker.contains(pod_info) {
-                false // Already tracked
+                false
            } else {
-                // Reserve the spot to prevent other threads from adding the same pod
                tracker.insert(pod_info.clone());
                true
            }
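The block above is a classic mutex-guarded check-and-insert: the contains test and the insert happen under a single lock acquisition, so two concurrent events for the same pod cannot both decide to add it. Reduced to its core (keying by string instead of PodInfo for brevity):

use std::collections::HashSet;
use std::sync::{Arc, Mutex};

// Returns true if the caller won the reservation and should proceed.
fn try_reserve(tracked: &Arc<Mutex<HashSet<String>>>, key: &str) -> bool {
    // Hold the guard across both the check and the insert.
    let mut guard = tracked.lock().expect("tracker mutex poisoned");
    if guard.contains(key) {
        false // already tracked
    } else {
        guard.insert(key.to_string());
        true
    }
}

HashSet::insert already returns whether the value was newly added, so the body could collapse to a single guard.insert call; the expanded form above mirrors the handler in the diff.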
...
@@ -381,7 +349,6 @@ async fn handle_pod_event(
            pod_info.name, pod_info.pod_type, worker_url
        );

-        // Build worker config based on pod type and routing mode
        let worker_type = if pd_mode {
            match &pod_info.pod_type {
                Some(PodType::Prefill) => Some("prefill".to_string()),
...
@@ -392,7 +359,6 @@ async fn handle_pod_event(
            None
        };

-        // Only set bootstrap_port for prefill workers in PD mode
        let bootstrap_port = if pd_mode {
            match &pod_info.pod_type {
                Some(PodType::Prefill) => pod_info.bootstrap_port,
...
@@ -425,7 +391,6 @@ async fn handle_pod_event(
            }
            Err(e) => {
                error!("Failed to add worker {} to router: {}", worker_url, e);
-                // Remove from tracking since addition failed
                if let Ok(mut tracker) = tracked_pods.lock() {
                    tracker.remove(pod_info);
                }
...
@@ -464,8 +429,6 @@ async fn handle_pod_deletion(
            error!("Failed to remove worker {}: {}", worker_url, e);
        }
    } else {
-        // This case might occur if a pod is deleted before it was ever marked healthy and added.
-        // Or if the event is duplicated. No action needed on the router if it wasn't tracked (and thus not added).
        debug!(
            "Pod deletion event for untracked/already removed pod: {} (type: {:?}). Worker URL: {}",
            pod_info.name, pod_info.pod_type, worker_url
...
@@ -480,7 +443,6 @@ mod tests {
    use k8s_openapi::apimachinery::pkg::apis::meta::v1::ObjectMeta;
    use k8s_openapi::apimachinery::pkg::apis::meta::v1::Time;

-    // Helper function to create a Pod for testing PodInfo::from_pod
    fn create_k8s_pod(
        name: Option<&str>,
        ip: Option<&str>,
...
@@ -523,7 +485,6 @@ mod tests {
        pod
    }

-    // Helper function to create a Pod with PD-specific labels and annotations
    fn create_pd_k8s_pod(name: &str, ip: &str, pod_type: &str, bootstrap_port: Option<u16>) -> Pod {
        let mut labels = std::collections::BTreeMap::new();
        labels.insert("app".to_string(), "sglang".to_string());
...
@@ -559,18 +520,15 @@ mod tests {
        }
    }

-    // Helper to create an AppContext instance for testing event handlers
    async fn create_test_app_context() -> Arc<AppContext> {
        use crate::config::RouterConfig;
        use crate::middleware::TokenBucket;

-        // Create a minimal RouterConfig for testing with very short timeout
        let router_config = RouterConfig {
            worker_startup_timeout_secs: 1,
            ..Default::default()
-        }; // Very short timeout for tests
+        };

-        // Create AppContext with minimal components
        Arc::new(AppContext {
            client: reqwest::Client::new(),
            router_config: router_config.clone(),
...
@@ -579,16 +537,15 @@ mod tests {
            policy_registry: Arc::new(crate::policies::PolicyRegistry::new(
                router_config.policy.clone(),
            )),
-            tokenizer: None, // HTTP mode doesn't need tokenizer
+            tokenizer: None,
-            reasoning_parser_factory: None, // HTTP mode doesn't need reasoning parser
+            reasoning_parser_factory: None,
-            tool_parser_registry: None, // HTTP mode doesn't need tool parser
+            tool_parser_registry: None,
-            router_manager: None, // Test doesn't need router manager
+            router_manager: None,
            response_storage: Arc::new(crate::data_connector::MemoryResponseStorage::new()),
            load_monitor: None,
        })
    }

-    // Helper to create a PD config for testing
    fn create_pd_config() -> ServiceDiscoveryConfig {
        let mut prefill_selector = HashMap::new();
        prefill_selector.insert("app".to_string(), "sglang".to_string());
...
@@ -615,19 +572,15 @@ mod tests {
    fn test_pod_info_should_include() {
        let config = create_pd_config();

-        // Test prefill pod should be included
        let prefill_pod = create_pd_k8s_pod("prefill-pod", "10.0.0.1", "prefill", Some(8081));
        assert!(PodInfo::should_include(&prefill_pod, &config));

-        // Test decode pod should be included
        let decode_pod = create_pd_k8s_pod("decode-pod", "10.0.0.2", "decode", None);
        assert!(PodInfo::should_include(&decode_pod, &config));

-        // Test unmatched pod should not be included
        let unmatched_pod = create_pd_k8s_pod("other-pod", "10.0.0.3", "other", None);
        assert!(!PodInfo::should_include(&unmatched_pod, &config));

-        // Test regular mode
        let mut regular_config = ServiceDiscoveryConfig::default();
        regular_config
            .selector
...
@@ -654,7 +607,6 @@ mod tests {
    #[test]
    fn test_pod_type_enum() {
-        // Test that PodType enum has expected variants
        let prefill = PodType::Prefill;
        let decode = PodType::Decode;
        let regular = PodType::Regular;
...
@@ -714,7 +666,7 @@ mod tests {
    fn test_pod_info_from_pod_with_pd_config_regular_mode() {
        let k8s_pod = create_pd_k8s_pod("regular-pod", "10.0.0.3", "worker", None);
        let mut config = create_pd_config();
-        config.pd_mode = false; // Set to regular mode
+        config.pd_mode = false;

        let pod_info = PodInfo::from_pod(&k8s_pod, Some(&config)).unwrap();
        assert_eq!(pod_info.name, "regular-pod");
...
@@ -742,7 +694,6 @@ mod tests {
    #[test]
    fn test_pod_info_from_pod_with_pd_config_invalid_bootstrap_port() {
        let mut pod = create_pd_k8s_pod("prefill-pod", "10.0.0.1", "prefill", None);
-        // Add invalid bootstrap port annotation
        pod.metadata.annotations.as_mut().unwrap().insert(
            "sglang.ai/bootstrap-port".to_string(),
            "invalid".to_string(),
...
@@ -751,7 +702,7 @@ mod tests {
        let pod_info = PodInfo::from_pod(&pod, Some(&config)).unwrap();
        assert_eq!(pod_info.pod_type, Some(PodType::Prefill));
-        assert!(pod_info.bootstrap_port.is_none()); // Should be None for invalid port
+        assert!(pod_info.bootstrap_port.is_none());
    }

    #[test]
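The invalid-bootstrap-port test relies on malformed annotation values degrading to None rather than failing pod ingestion. A sketch of how that parse typically looks; the annotation key comes from the test above, but the function itself is illustrative rather than the router's exact code:

use std::collections::BTreeMap;

// "8081" -> Some(8081); "invalid" or a missing key -> None.
fn bootstrap_port_from(annotations: &BTreeMap<String, String>) -> Option<u16> {
    annotations
        .get("sglang.ai/bootstrap-port")
        .and_then(|raw| raw.parse::<u16>().ok())
}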
...
@@ -1077,7 +1028,6 @@ mod tests {
        )
        .await;

-        // Pod should not be tracked since add_worker_from_url will fail for non-running server
        assert!(!tracked_pods.lock().unwrap().contains(&pod_info));
    }
...