Unverified Commit 3828db43 authored by Keyang Ru's avatar Keyang Ru Committed by GitHub
Browse files

[router] Add IGW (Inference Gateway) Feature Flag (#9371)


Co-authored-by: default avatarYineng Zhang <me@zhyncs.com>
parent 88fbc31b
...@@ -53,7 +53,7 @@ jobs: ...@@ -53,7 +53,7 @@ jobs:
cargo check --benches cargo check --benches
- name: Quick benchmark sanity check - name: Quick benchmark sanity check
timeout-minutes: 10 timeout-minutes: 15
run: | run: |
source "$HOME/.cargo/env" source "$HOME/.cargo/env"
cd sgl-router/ cd sgl-router/
......
...@@ -51,6 +51,9 @@ pub struct RouterConfig { ...@@ -51,6 +51,9 @@ pub struct RouterConfig {
pub disable_circuit_breaker: bool, pub disable_circuit_breaker: bool,
/// Health check configuration /// Health check configuration
pub health_check: HealthCheckConfig, pub health_check: HealthCheckConfig,
/// Enable Inference Gateway mode (false = proxy mode, true = IGW mode)
#[serde(default)]
pub enable_igw: bool,
} }
/// Routing mode configuration /// Routing mode configuration
...@@ -323,6 +326,7 @@ impl Default for RouterConfig { ...@@ -323,6 +326,7 @@ impl Default for RouterConfig {
disable_retries: false, disable_retries: false,
disable_circuit_breaker: false, disable_circuit_breaker: false,
health_check: HealthCheckConfig::default(), health_check: HealthCheckConfig::default(),
enable_igw: false,
} }
} }
} }
...@@ -377,6 +381,11 @@ impl RouterConfig { ...@@ -377,6 +381,11 @@ impl RouterConfig {
} }
cfg cfg
} }
/// Check if running in IGW (Inference Gateway) mode
pub fn is_igw_mode(&self) -> bool {
self.enable_igw
}
} }
#[cfg(test)] #[cfg(test)]
...@@ -456,6 +465,7 @@ mod tests { ...@@ -456,6 +465,7 @@ mod tests {
disable_retries: false, disable_retries: false,
disable_circuit_breaker: false, disable_circuit_breaker: false,
health_check: HealthCheckConfig::default(), health_check: HealthCheckConfig::default(),
enable_igw: false,
}; };
let json = serde_json::to_string(&config).unwrap(); let json = serde_json::to_string(&config).unwrap();
...@@ -888,6 +898,7 @@ mod tests { ...@@ -888,6 +898,7 @@ mod tests {
disable_retries: false, disable_retries: false,
disable_circuit_breaker: false, disable_circuit_breaker: false,
health_check: HealthCheckConfig::default(), health_check: HealthCheckConfig::default(),
enable_igw: false,
}; };
assert!(config.mode.is_pd_mode()); assert!(config.mode.is_pd_mode());
...@@ -944,6 +955,7 @@ mod tests { ...@@ -944,6 +955,7 @@ mod tests {
disable_retries: false, disable_retries: false,
disable_circuit_breaker: false, disable_circuit_breaker: false,
health_check: HealthCheckConfig::default(), health_check: HealthCheckConfig::default(),
enable_igw: false,
}; };
assert!(!config.mode.is_pd_mode()); assert!(!config.mode.is_pd_mode());
...@@ -996,6 +1008,7 @@ mod tests { ...@@ -996,6 +1008,7 @@ mod tests {
disable_retries: false, disable_retries: false,
disable_circuit_breaker: false, disable_circuit_breaker: false,
health_check: HealthCheckConfig::default(), health_check: HealthCheckConfig::default(),
enable_igw: false,
}; };
assert!(config.has_service_discovery()); assert!(config.has_service_discovery());
......
...@@ -344,6 +344,11 @@ impl ConfigValidator { ...@@ -344,6 +344,11 @@ impl ConfigValidator {
/// Validate compatibility between different configuration sections /// Validate compatibility between different configuration sections
fn validate_compatibility(config: &RouterConfig) -> ConfigResult<()> { fn validate_compatibility(config: &RouterConfig) -> ConfigResult<()> {
// IGW mode is independent - skip other compatibility checks when enabled
if config.enable_igw {
return Ok(());
}
// All policies are now supported for both router types thanks to the unified trait design // All policies are now supported for both router types thanks to the unified trait design
// No mode/policy restrictions needed anymore // No mode/policy restrictions needed anymore
......
...@@ -82,6 +82,8 @@ struct Router { ...@@ -82,6 +82,8 @@ struct Router {
health_check_timeout_secs: u64, health_check_timeout_secs: u64,
health_check_interval_secs: u64, health_check_interval_secs: u64,
health_check_endpoint: String, health_check_endpoint: String,
// IGW (Inference Gateway) configuration
enable_igw: bool,
} }
impl Router { impl Router {
...@@ -110,7 +112,12 @@ impl Router { ...@@ -110,7 +112,12 @@ impl Router {
}; };
// Determine routing mode // Determine routing mode
let mode = if self.pd_disaggregation { let mode = if self.enable_igw {
// IGW mode - routing mode is not used in IGW, but we need to provide a placeholder
RoutingMode::Regular {
worker_urls: vec![],
}
} else if self.pd_disaggregation {
RoutingMode::PrefillDecode { RoutingMode::PrefillDecode {
prefill_urls: self.prefill_urls.clone().unwrap_or_default(), prefill_urls: self.prefill_urls.clone().unwrap_or_default(),
decode_urls: self.decode_urls.clone().unwrap_or_default(), decode_urls: self.decode_urls.clone().unwrap_or_default(),
...@@ -191,6 +198,7 @@ impl Router { ...@@ -191,6 +198,7 @@ impl Router {
check_interval_secs: self.health_check_interval_secs, check_interval_secs: self.health_check_interval_secs,
endpoint: self.health_check_endpoint.clone(), endpoint: self.health_check_endpoint.clone(),
}, },
enable_igw: self.enable_igw,
}) })
} }
} }
...@@ -252,6 +260,8 @@ impl Router { ...@@ -252,6 +260,8 @@ impl Router {
health_check_timeout_secs = 5, health_check_timeout_secs = 5,
health_check_interval_secs = 60, health_check_interval_secs = 60,
health_check_endpoint = String::from("/health"), health_check_endpoint = String::from("/health"),
// IGW defaults
enable_igw = false,
))] ))]
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn new( fn new(
...@@ -305,6 +315,7 @@ impl Router { ...@@ -305,6 +315,7 @@ impl Router {
health_check_timeout_secs: u64, health_check_timeout_secs: u64,
health_check_interval_secs: u64, health_check_interval_secs: u64,
health_check_endpoint: String, health_check_endpoint: String,
enable_igw: bool,
) -> PyResult<Self> { ) -> PyResult<Self> {
Ok(Router { Ok(Router {
host, host,
...@@ -357,6 +368,7 @@ impl Router { ...@@ -357,6 +368,7 @@ impl Router {
health_check_timeout_secs, health_check_timeout_secs,
health_check_interval_secs, health_check_interval_secs,
health_check_endpoint, health_check_endpoint,
enable_igw,
}) })
} }
......
...@@ -70,6 +70,7 @@ Examples: ...@@ -70,6 +70,7 @@ Examples:
--decode http://127.0.0.3:30003 \ --decode http://127.0.0.3:30003 \
--decode http://127.0.0.4:30004 \ --decode http://127.0.0.4:30004 \
--prefill-policy cache_aware --decode-policy power_of_two --prefill-policy cache_aware --decode-policy power_of_two
"#)] "#)]
struct CliArgs { struct CliArgs {
/// Host address to bind the router server /// Host address to bind the router server
...@@ -266,6 +267,11 @@ struct CliArgs { ...@@ -266,6 +267,11 @@ struct CliArgs {
/// Health check endpoint path /// Health check endpoint path
#[arg(long, default_value = "/health")] #[arg(long, default_value = "/health")]
health_check_endpoint: String, health_check_endpoint: String,
// IGW (Inference Gateway) configuration
/// Enable Inference Gateway mode
#[arg(long, default_value_t = false)]
enable_igw: bool,
} }
impl CliArgs { impl CliArgs {
...@@ -307,7 +313,12 @@ impl CliArgs { ...@@ -307,7 +313,12 @@ impl CliArgs {
prefill_urls: Vec<(String, Option<u16>)>, prefill_urls: Vec<(String, Option<u16>)>,
) -> ConfigResult<RouterConfig> { ) -> ConfigResult<RouterConfig> {
// Determine routing mode // Determine routing mode
let mode = if self.pd_disaggregation { let mode = if self.enable_igw {
// IGW mode - routing mode is not used in IGW, but we need to provide a placeholder
RoutingMode::Regular {
worker_urls: vec![],
}
} else if self.pd_disaggregation {
let decode_urls = self.decode.clone(); let decode_urls = self.decode.clone();
// Validate PD configuration if not using service discovery // Validate PD configuration if not using service discovery
...@@ -406,6 +417,7 @@ impl CliArgs { ...@@ -406,6 +417,7 @@ impl CliArgs {
check_interval_secs: self.health_check_interval_secs, check_interval_secs: self.health_check_interval_secs,
endpoint: self.health_check_endpoint.clone(), endpoint: self.health_check_endpoint.clone(),
}, },
enable_igw: self.enable_igw,
}) })
} }
...@@ -487,17 +499,22 @@ fn main() -> Result<(), Box<dyn std::error::Error>> { ...@@ -487,17 +499,22 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("Host: {}:{}", cli_args.host, cli_args.port); println!("Host: {}:{}", cli_args.host, cli_args.port);
println!( println!(
"Mode: {}", "Mode: {}",
if cli_args.pd_disaggregation { if cli_args.enable_igw {
"IGW (Inference Gateway)"
} else if cli_args.pd_disaggregation {
"PD Disaggregated" "PD Disaggregated"
} else { } else {
"Regular" "Regular"
} }
); );
println!("Policy: {}", cli_args.policy);
if cli_args.pd_disaggregation && !prefill_urls.is_empty() { if !cli_args.enable_igw {
println!("Prefill nodes: {:?}", prefill_urls); println!("Policy: {}", cli_args.policy);
println!("Decode nodes: {:?}", cli_args.decode);
if cli_args.pd_disaggregation && !prefill_urls.is_empty() {
println!("Prefill nodes: {:?}", prefill_urls);
println!("Decode nodes: {:?}", cli_args.decode);
}
} }
// Convert to RouterConfig // Convert to RouterConfig
......
...@@ -12,6 +12,12 @@ pub struct RouterFactory; ...@@ -12,6 +12,12 @@ pub struct RouterFactory;
impl RouterFactory { impl RouterFactory {
/// Create a router instance from application context /// Create a router instance from application context
pub async fn create_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> { pub async fn create_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
// Check if IGW mode is enabled
if ctx.router_config.enable_igw {
return Self::create_igw_router(ctx).await;
}
// Default to proxy mode
match &ctx.router_config.mode { match &ctx.router_config.mode {
RoutingMode::Regular { worker_urls } => { RoutingMode::Regular { worker_urls } => {
Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx).await Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx).await
...@@ -94,4 +100,10 @@ impl RouterFactory { ...@@ -94,4 +100,10 @@ impl RouterFactory {
Ok(Box::new(router)) Ok(Box::new(router))
} }
/// Create an IGW router (placeholder for future implementation)
async fn create_igw_router(_ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
// For now, return an error indicating IGW is not yet implemented
Err("IGW mode is not yet implemented".to_string())
}
} }
...@@ -51,6 +51,7 @@ impl TestContext { ...@@ -51,6 +51,7 @@ impl TestContext {
disable_retries: false, disable_retries: false,
disable_circuit_breaker: false, disable_circuit_breaker: false,
health_check: sglang_router_rs::config::HealthCheckConfig::default(), health_check: sglang_router_rs::config::HealthCheckConfig::default(),
enable_igw: false,
}; };
Self::new_with_config(config, worker_configs).await Self::new_with_config(config, worker_configs).await
...@@ -1093,6 +1094,7 @@ mod error_tests { ...@@ -1093,6 +1094,7 @@ mod error_tests {
disable_retries: false, disable_retries: false,
disable_circuit_breaker: false, disable_circuit_breaker: false,
health_check: sglang_router_rs::config::HealthCheckConfig::default(), health_check: sglang_router_rs::config::HealthCheckConfig::default(),
enable_igw: false,
}; };
let ctx = TestContext::new_with_config( let ctx = TestContext::new_with_config(
...@@ -1444,6 +1446,7 @@ mod pd_mode_tests { ...@@ -1444,6 +1446,7 @@ mod pd_mode_tests {
disable_retries: false, disable_retries: false,
disable_circuit_breaker: false, disable_circuit_breaker: false,
health_check: sglang_router_rs::config::HealthCheckConfig::default(), health_check: sglang_router_rs::config::HealthCheckConfig::default(),
enable_igw: false,
}; };
// Create app context // Create app context
...@@ -1599,6 +1602,7 @@ mod request_id_tests { ...@@ -1599,6 +1602,7 @@ mod request_id_tests {
disable_retries: false, disable_retries: false,
disable_circuit_breaker: false, disable_circuit_breaker: false,
health_check: sglang_router_rs::config::HealthCheckConfig::default(), health_check: sglang_router_rs::config::HealthCheckConfig::default(),
enable_igw: false,
}; };
let ctx = TestContext::new_with_config( let ctx = TestContext::new_with_config(
......
...@@ -42,6 +42,7 @@ impl TestContext { ...@@ -42,6 +42,7 @@ impl TestContext {
disable_retries: false, disable_retries: false,
disable_circuit_breaker: false, disable_circuit_breaker: false,
health_check: sglang_router_rs::config::HealthCheckConfig::default(), health_check: sglang_router_rs::config::HealthCheckConfig::default(),
enable_igw: false,
}; };
let mut workers = Vec::new(); let mut workers = Vec::new();
......
...@@ -43,6 +43,7 @@ impl TestContext { ...@@ -43,6 +43,7 @@ impl TestContext {
disable_retries: false, disable_retries: false,
disable_circuit_breaker: false, disable_circuit_breaker: false,
health_check: sglang_router_rs::config::HealthCheckConfig::default(), health_check: sglang_router_rs::config::HealthCheckConfig::default(),
enable_igw: false,
}; };
let mut workers = Vec::new(); let mut workers = Vec::new();
......
...@@ -184,6 +184,7 @@ mod test_pd_routing { ...@@ -184,6 +184,7 @@ mod test_pd_routing {
disable_retries: false, disable_retries: false,
disable_circuit_breaker: false, disable_circuit_breaker: false,
health_check: sglang_router_rs::config::HealthCheckConfig::default(), health_check: sglang_router_rs::config::HealthCheckConfig::default(),
enable_igw: false,
}; };
// Router creation will fail due to health checks, but config should be valid // Router creation will fail due to health checks, but config should be valid
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment