Unverified commit 53c2934d, authored by Arthur Cheng, committed by GitHub
Browse files

[Router] Consolidate ConnectionMode enum to core module (#11937)

parent e321c971
...@@ -3,6 +3,7 @@ use std::collections::HashMap; ...@@ -3,6 +3,7 @@ use std::collections::HashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use super::ConfigResult; use super::ConfigResult;
use crate::core::ConnectionMode;
/// Main router configuration /// Main router configuration
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
...@@ -208,16 +209,6 @@ impl std::fmt::Debug for OracleConfig { ...@@ -208,16 +209,6 @@ impl std::fmt::Debug for OracleConfig {
} }
} }
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
#[serde(tag = "type")]
pub enum ConnectionMode {
#[default]
#[serde(rename = "http")]
Http,
#[serde(rename = "grpc")]
Grpc,
}
/// Routing mode configuration /// Routing mode configuration
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")] #[serde(tag = "type")]
......
use super::*; use super::*;
use crate::core::ConnectionMode;
/// Configuration validator /// Configuration validator
pub struct ConfigValidator; pub struct ConfigValidator;
...@@ -476,7 +477,7 @@ impl ConfigValidator { ...@@ -476,7 +477,7 @@ impl ConfigValidator {
} }
// Validate gRPC connection mode requires tokenizer configuration // Validate gRPC connection mode requires tokenizer configuration
if config.connection_mode == ConnectionMode::Grpc if matches!(config.connection_mode, ConnectionMode::Grpc { .. })
&& config.tokenizer_path.is_none() && config.tokenizer_path.is_none()
&& config.model_path.is_none() && config.model_path.is_none()
{ {
...@@ -832,7 +833,7 @@ mod tests { ...@@ -832,7 +833,7 @@ mod tests {
); );
// Set connection mode to gRPC without tokenizer config // Set connection mode to gRPC without tokenizer config
config.connection_mode = ConnectionMode::Grpc; config.connection_mode = ConnectionMode::Grpc { port: None };
config.tokenizer_path = None; config.tokenizer_path = None;
config.model_path = None; config.model_path = None;
...@@ -852,7 +853,7 @@ mod tests { ...@@ -852,7 +853,7 @@ mod tests {
PolicyConfig::Random, PolicyConfig::Random,
); );
config.connection_mode = ConnectionMode::Grpc; config.connection_mode = ConnectionMode::Grpc { port: None };
config.model_path = Some("meta-llama/Llama-3-8B".to_string()); config.model_path = Some("meta-llama/Llama-3-8B".to_string());
let result = ConfigValidator::validate(&config); let result = ConfigValidator::validate(&config);
...@@ -868,7 +869,7 @@ mod tests { ...@@ -868,7 +869,7 @@ mod tests {
PolicyConfig::Random, PolicyConfig::Random,
); );
config.connection_mode = ConnectionMode::Grpc; config.connection_mode = ConnectionMode::Grpc { port: None };
config.tokenizer_path = Some("/path/to/tokenizer.json".to_string()); config.tokenizer_path = Some("/path/to/tokenizer.json".to_string());
let result = ConfigValidator::validate(&config); let result = ConfigValidator::validate(&config);
......
...@@ -8,6 +8,7 @@ use std::{ ...@@ -8,6 +8,7 @@ use std::{
}; };
use async_trait::async_trait; use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json; use serde_json;
use tokio::{sync::RwLock, time}; use tokio::{sync::RwLock, time};
...@@ -240,13 +241,17 @@ pub trait Worker: Send + Sync + fmt::Debug { ...@@ -240,13 +241,17 @@ pub trait Worker: Send + Sync + fmt::Debug {
} }
/// Connection mode for worker communication /// Connection mode for worker communication
#[derive(Debug, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ConnectionMode { pub enum ConnectionMode {
/// HTTP/REST connection /// HTTP/REST connection
#[default]
Http, Http,
/// gRPC connection /// gRPC connection
Grpc { Grpc {
/// Optional port for gRPC endpoint (if different from URL) /// Optional port for gRPC endpoint (if different from URL)
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(default)]
port: Option<u16>, port: Option<u16>,
}, },
} }
......
...@@ -194,7 +194,7 @@ struct Router { ...@@ -194,7 +194,7 @@ struct Router {
queue_size: usize, queue_size: usize,
queue_timeout_secs: u64, queue_timeout_secs: u64,
rate_limit_tokens_per_second: Option<i32>, rate_limit_tokens_per_second: Option<i32>,
connection_mode: config::ConnectionMode, connection_mode: core::ConnectionMode,
model_path: Option<String>, model_path: Option<String>,
tokenizer_path: Option<String>, tokenizer_path: Option<String>,
chat_template: Option<String>, chat_template: Option<String>,
...@@ -211,13 +211,13 @@ struct Router { ...@@ -211,13 +211,13 @@ struct Router {
impl Router { impl Router {
/// Determine connection mode from worker URLs /// Determine connection mode from worker URLs
fn determine_connection_mode(worker_urls: &[String]) -> config::ConnectionMode { fn determine_connection_mode(worker_urls: &[String]) -> core::ConnectionMode {
for url in worker_urls { for url in worker_urls {
if url.starts_with("grpc://") || url.starts_with("grpcs://") { if url.starts_with("grpc://") || url.starts_with("grpcs://") {
return config::ConnectionMode::Grpc; return core::ConnectionMode::Grpc { port: None };
} }
} }
config::ConnectionMode::Http core::ConnectionMode::Http
} }
pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> { pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
......
...@@ -3,10 +3,11 @@ use std::collections::HashMap; ...@@ -3,10 +3,11 @@ use std::collections::HashMap;
use clap::{ArgAction, Parser, ValueEnum}; use clap::{ArgAction, Parser, ValueEnum};
use sglang_router_rs::{ use sglang_router_rs::{
config::{ config::{
CircuitBreakerConfig, ConfigError, ConfigResult, ConnectionMode, DiscoveryConfig, CircuitBreakerConfig, ConfigError, ConfigResult, DiscoveryConfig, HealthCheckConfig,
HealthCheckConfig, HistoryBackend, MetricsConfig, OracleConfig, PolicyConfig, RetryConfig, HistoryBackend, MetricsConfig, OracleConfig, PolicyConfig, RetryConfig, RouterConfig,
RouterConfig, RoutingMode, TokenizerCacheConfig, RoutingMode, TokenizerCacheConfig,
}, },
core::ConnectionMode,
metrics::PrometheusConfig, metrics::PrometheusConfig,
server::{self, ServerConfig}, server::{self, ServerConfig},
service_discovery::ServiceDiscoveryConfig, service_discovery::ServiceDiscoveryConfig,
...@@ -325,7 +326,7 @@ impl CliArgs { ...@@ -325,7 +326,7 @@ impl CliArgs {
fn determine_connection_mode(worker_urls: &[String]) -> ConnectionMode { fn determine_connection_mode(worker_urls: &[String]) -> ConnectionMode {
for url in worker_urls { for url in worker_urls {
if url.starts_with("grpc://") || url.starts_with("grpcs://") { if url.starts_with("grpc://") || url.starts_with("grpcs://") {
return ConnectionMode::Grpc; return ConnectionMode::Grpc { port: None };
} }
} }
ConnectionMode::Http ConnectionMode::Http
......
...@@ -9,7 +9,8 @@ use super::{ ...@@ -9,7 +9,8 @@ use super::{
RouterTrait, RouterTrait,
}; };
use crate::{ use crate::{
config::{ConnectionMode, PolicyConfig, RoutingMode}, config::{PolicyConfig, RoutingMode},
core::ConnectionMode,
policies::PolicyFactory, policies::PolicyFactory,
server::AppContext, server::AppContext,
}; };
...@@ -21,7 +22,7 @@ impl RouterFactory { ...@@ -21,7 +22,7 @@ impl RouterFactory {
/// Create a router instance from application context /// Create a router instance from application context
pub async fn create_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> { pub async fn create_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
match ctx.router_config.connection_mode { match ctx.router_config.connection_mode {
ConnectionMode::Grpc => match &ctx.router_config.mode { ConnectionMode::Grpc { .. } => match &ctx.router_config.mode {
RoutingMode::Regular { .. } => Self::create_grpc_router(ctx).await, RoutingMode::Regular { .. } => Self::create_grpc_router(ctx).await,
RoutingMode::PrefillDecode { RoutingMode::PrefillDecode {
prefill_policy, prefill_policy,
......
...@@ -18,8 +18,8 @@ use serde_json::Value; ...@@ -18,8 +18,8 @@ use serde_json::Value;
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
use crate::{ use crate::{
config::{ConnectionMode, RoutingMode}, config::RoutingMode,
core::{WorkerRegistry, WorkerType}, core::{ConnectionMode, WorkerRegistry, WorkerType},
protocols::{ protocols::{
chat::ChatCompletionRequest, chat::ChatCompletionRequest,
classify::ClassifyRequest, classify::ClassifyRequest,
...@@ -148,13 +148,13 @@ impl RouterManager { ...@@ -148,13 +148,13 @@ impl RouterManager {
(ConnectionMode::Http, RoutingMode::OpenAI { .. }) => { (ConnectionMode::Http, RoutingMode::OpenAI { .. }) => {
RouterId::new("http-openai".to_string()) RouterId::new("http-openai".to_string())
} }
(ConnectionMode::Grpc, RoutingMode::Regular { .. }) => { (ConnectionMode::Grpc { .. }, RoutingMode::Regular { .. }) => {
RouterId::new("grpc-regular".to_string()) RouterId::new("grpc-regular".to_string())
} }
(ConnectionMode::Grpc, RoutingMode::PrefillDecode { .. }) => { (ConnectionMode::Grpc { .. }, RoutingMode::PrefillDecode { .. }) => {
RouterId::new("grpc-pd".to_string()) RouterId::new("grpc-pd".to_string())
} }
(ConnectionMode::Grpc, RoutingMode::OpenAI { .. }) => { (ConnectionMode::Grpc { .. }, RoutingMode::OpenAI { .. }) => {
RouterId::new("grpc-regular".to_string()) RouterId::new("grpc-regular".to_string())
} }
} }
......
...@@ -20,14 +20,15 @@ use tokio::{net::TcpListener, signal, spawn}; ...@@ -20,14 +20,15 @@ use tokio::{net::TcpListener, signal, spawn};
use tracing::{error, info, warn, Level}; use tracing::{error, info, warn, Level};
use crate::{ use crate::{
config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode}, config::{HistoryBackend, RouterConfig, RoutingMode},
core::{ core::{
worker_to_info, worker_to_info,
workflow::{ workflow::{
create_worker_registration_workflow, create_worker_removal_workflow, LoggingSubscriber, create_worker_registration_workflow, create_worker_removal_workflow, LoggingSubscriber,
WorkflowEngine, WorkflowEngine,
}, },
Job, JobQueue, JobQueueConfig, LoadMonitor, WorkerManager, WorkerRegistry, WorkerType, ConnectionMode, Job, JobQueue, JobQueueConfig, LoadMonitor, WorkerManager, WorkerRegistry,
WorkerType,
}, },
data_connector::{ data_connector::{
MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage, MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage,
...@@ -825,11 +826,10 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err ...@@ -825,11 +826,10 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
}; };
// Initialize tokenizer and parser factories for gRPC mode // Initialize tokenizer and parser factories for gRPC mode
let (tokenizer, reasoning_parser_factory, tool_parser_factory) = if config let (tokenizer, reasoning_parser_factory, tool_parser_factory) = if matches!(
.router_config config.router_config.connection_mode,
.connection_mode ConnectionMode::Grpc { .. }
== ConnectionMode::Grpc ) {
{
let tokenizer_path = config let tokenizer_path = config
.router_config .router_config
.tokenizer_path .tokenizer_path
......
...@@ -11,10 +11,8 @@ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType ...@@ -11,10 +11,8 @@ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType
use reqwest::Client; use reqwest::Client;
use serde_json::json; use serde_json::json;
use sglang_router_rs::{ use sglang_router_rs::{
config::{ config::{CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode},
CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode, core::{ConnectionMode, Job},
},
core::Job,
routers::{RouterFactory, RouterTrait}, routers::{RouterFactory, RouterTrait},
server::AppContext, server::AppContext,
}; };
......
...@@ -16,9 +16,10 @@ use common::{ ...@@ -16,9 +16,10 @@ use common::{
}; };
use sglang_router_rs::{ use sglang_router_rs::{
config::{ config::{
CircuitBreakerConfig, ConnectionMode, HealthCheckConfig, PolicyConfig, RetryConfig, CircuitBreakerConfig, HealthCheckConfig, PolicyConfig, RetryConfig, RouterConfig,
RouterConfig, RoutingMode, RoutingMode,
}, },
core::ConnectionMode,
routers::RouterFactory, routers::RouterFactory,
}; };
......
...@@ -2,11 +2,8 @@ ...@@ -2,11 +2,8 @@
mod test_pd_routing { mod test_pd_routing {
use serde_json::json; use serde_json::json;
use sglang_router_rs::{ use sglang_router_rs::{
config::{ config::{CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode},
CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, core::{BasicWorkerBuilder, ConnectionMode, Worker, WorkerType},
RoutingMode,
},
core::{BasicWorkerBuilder, Worker, WorkerType},
routers::{http::pd_types::PDSelectionPolicy, RouterFactory}, routers::{http::pd_types::PDSelectionPolicy, RouterFactory},
}; };
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment