"...git@developer.sourcefind.cn:OpenDAS/torch-harmonics.git" did not exist on "f9e8ad1849805c2c65424435214d0e6829852fa5"
Unverified commit 53c2934d, authored by Arthur Cheng, committed by GitHub
Browse files

[Router] Consolidate ConnectionMode enum to core module (#11937)

parent e321c971
......@@ -3,6 +3,7 @@ use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use super::ConfigResult;
use crate::core::ConnectionMode;
/// Main router configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
......@@ -208,16 +209,6 @@ impl std::fmt::Debug for OracleConfig {
}
}
/// Config-side connection mode selecting how the router talks to workers.
///
/// Serialized in serde's internally tagged form: the variant is written under a
/// `"type"` key as `"http"` or `"grpc"`. `Http` is the `Default` value.
// NOTE(review): the hunk header (-208,16 +209,6) and the commit title suggest this whole
// definition is deleted here in favor of `crate::core::ConnectionMode` — confirm against
// the real diff, since the +/- markers are missing in this view.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
#[serde(tag = "type")]
pub enum ConnectionMode {
// Variant returned by `ConnectionMode::default()`.
#[default]
#[serde(rename = "http")]
Http,
// Unit variant: this older definition carries no port information.
#[serde(rename = "grpc")]
Grpc,
}
/// Routing mode configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
......
use super::*;
use crate::core::ConnectionMode;
/// Configuration validator
///
/// Field-less unit struct used as a namespace: checks are invoked as associated
/// functions, e.g. `ConfigValidator::validate(&config)`.
pub struct ConfigValidator;
......@@ -476,7 +477,7 @@ impl ConfigValidator {
}
// Validate gRPC connection mode requires tokenizer configuration
if config.connection_mode == ConnectionMode::Grpc
if matches!(config.connection_mode, ConnectionMode::Grpc { .. })
&& config.tokenizer_path.is_none()
&& config.model_path.is_none()
{
......@@ -832,7 +833,7 @@ mod tests {
);
// Set connection mode to gRPC without tokenizer config
config.connection_mode = ConnectionMode::Grpc;
config.connection_mode = ConnectionMode::Grpc { port: None };
config.tokenizer_path = None;
config.model_path = None;
......@@ -852,7 +853,7 @@ mod tests {
PolicyConfig::Random,
);
config.connection_mode = ConnectionMode::Grpc;
config.connection_mode = ConnectionMode::Grpc { port: None };
config.model_path = Some("meta-llama/Llama-3-8B".to_string());
let result = ConfigValidator::validate(&config);
......@@ -868,7 +869,7 @@ mod tests {
PolicyConfig::Random,
);
config.connection_mode = ConnectionMode::Grpc;
config.connection_mode = ConnectionMode::Grpc { port: None };
config.tokenizer_path = Some("/path/to/tokenizer.json".to_string());
let result = ConfigValidator::validate(&config);
......
......@@ -8,6 +8,7 @@ use std::{
};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json;
use tokio::{sync::RwLock, time};
......@@ -240,13 +241,17 @@ pub trait Worker: Send + Sync + fmt::Debug {
}
/// Connection mode for worker communication
// NOTE(review): this view has lost its +/- diff markers. The first derive line below
// appears to be the pre-change version and the second its replacement (adding
// `Serialize`, `Deserialize`, `Default`); only one of them can exist in the real file.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
// Internally tagged serde form: the variant name is lowercased and stored under a
// "type" key, i.e. "http" or "grpc".
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ConnectionMode {
/// HTTP/REST connection
// Variant returned by `ConnectionMode::default()`.
#[default]
Http,
/// gRPC connection
Grpc {
/// Optional port for gRPC endpoint (if different from URL)
// Omitted from serialized output when `None`; a missing field deserializes to `None`.
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(default)]
port: Option<u16>,
},
}
......
......@@ -194,7 +194,7 @@ struct Router {
queue_size: usize,
queue_timeout_secs: u64,
rate_limit_tokens_per_second: Option<i32>,
connection_mode: config::ConnectionMode,
connection_mode: core::ConnectionMode,
model_path: Option<String>,
tokenizer_path: Option<String>,
chat_template: Option<String>,
......@@ -211,13 +211,13 @@ struct Router {
impl Router {
/// Determine connection mode from worker URLs
///
/// Returns the gRPC mode as soon as any URL starts with `grpc://` or `grpcs://`;
/// otherwise (including for an empty slice) returns the HTTP mode.
// NOTE(review): this view has lost its +/- diff markers, so each changed line appears
// twice: first the `config::ConnectionMode` form, then the `core::ConnectionMode` form
// that presumably replaces it. Only one line of each pair exists in the real file.
fn determine_connection_mode(worker_urls: &[String]) -> config::ConnectionMode {
fn determine_connection_mode(worker_urls: &[String]) -> core::ConnectionMode {
for url in worker_urls {
// A single gRPC-scheme URL decides the mode for the whole set.
if url.starts_with("grpc://") || url.starts_with("grpcs://") {
return config::ConnectionMode::Grpc;
// The struct-variant form is built with no port override.
return core::ConnectionMode::Grpc { port: None };
}
}
// No gRPC-scheme URL was found.
config::ConnectionMode::Http
core::ConnectionMode::Http
}
pub fn to_router_config(&self) -> config::ConfigResult<config::RouterConfig> {
......
......@@ -3,10 +3,11 @@ use std::collections::HashMap;
use clap::{ArgAction, Parser, ValueEnum};
use sglang_router_rs::{
config::{
CircuitBreakerConfig, ConfigError, ConfigResult, ConnectionMode, DiscoveryConfig,
HealthCheckConfig, HistoryBackend, MetricsConfig, OracleConfig, PolicyConfig, RetryConfig,
RouterConfig, RoutingMode, TokenizerCacheConfig,
CircuitBreakerConfig, ConfigError, ConfigResult, DiscoveryConfig, HealthCheckConfig,
HistoryBackend, MetricsConfig, OracleConfig, PolicyConfig, RetryConfig, RouterConfig,
RoutingMode, TokenizerCacheConfig,
},
core::ConnectionMode,
metrics::PrometheusConfig,
server::{self, ServerConfig},
service_discovery::ServiceDiscoveryConfig,
......@@ -325,7 +326,7 @@ impl CliArgs {
fn determine_connection_mode(worker_urls: &[String]) -> ConnectionMode {
for url in worker_urls {
if url.starts_with("grpc://") || url.starts_with("grpcs://") {
return ConnectionMode::Grpc;
return ConnectionMode::Grpc { port: None };
}
}
ConnectionMode::Http
......
......@@ -9,7 +9,8 @@ use super::{
RouterTrait,
};
use crate::{
config::{ConnectionMode, PolicyConfig, RoutingMode},
config::{PolicyConfig, RoutingMode},
core::ConnectionMode,
policies::PolicyFactory,
server::AppContext,
};
......@@ -21,7 +22,7 @@ impl RouterFactory {
/// Create a router instance from application context
pub async fn create_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
match ctx.router_config.connection_mode {
ConnectionMode::Grpc => match &ctx.router_config.mode {
ConnectionMode::Grpc { .. } => match &ctx.router_config.mode {
RoutingMode::Regular { .. } => Self::create_grpc_router(ctx).await,
RoutingMode::PrefillDecode {
prefill_policy,
......
......@@ -18,8 +18,8 @@ use serde_json::Value;
use tracing::{debug, info, warn};
use crate::{
config::{ConnectionMode, RoutingMode},
core::{WorkerRegistry, WorkerType},
config::RoutingMode,
core::{ConnectionMode, WorkerRegistry, WorkerType},
protocols::{
chat::ChatCompletionRequest,
classify::ClassifyRequest,
......@@ -148,13 +148,13 @@ impl RouterManager {
(ConnectionMode::Http, RoutingMode::OpenAI { .. }) => {
RouterId::new("http-openai".to_string())
}
(ConnectionMode::Grpc, RoutingMode::Regular { .. }) => {
(ConnectionMode::Grpc { .. }, RoutingMode::Regular { .. }) => {
RouterId::new("grpc-regular".to_string())
}
(ConnectionMode::Grpc, RoutingMode::PrefillDecode { .. }) => {
(ConnectionMode::Grpc { .. }, RoutingMode::PrefillDecode { .. }) => {
RouterId::new("grpc-pd".to_string())
}
(ConnectionMode::Grpc, RoutingMode::OpenAI { .. }) => {
(ConnectionMode::Grpc { .. }, RoutingMode::OpenAI { .. }) => {
RouterId::new("grpc-regular".to_string())
}
}
......
......@@ -20,14 +20,15 @@ use tokio::{net::TcpListener, signal, spawn};
use tracing::{error, info, warn, Level};
use crate::{
config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode},
config::{HistoryBackend, RouterConfig, RoutingMode},
core::{
worker_to_info,
workflow::{
create_worker_registration_workflow, create_worker_removal_workflow, LoggingSubscriber,
WorkflowEngine,
},
Job, JobQueue, JobQueueConfig, LoadMonitor, WorkerManager, WorkerRegistry, WorkerType,
ConnectionMode, Job, JobQueue, JobQueueConfig, LoadMonitor, WorkerManager, WorkerRegistry,
WorkerType,
},
data_connector::{
MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage,
......@@ -825,11 +826,10 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
};
// Initialize tokenizer and parser factories for gRPC mode
let (tokenizer, reasoning_parser_factory, tool_parser_factory) = if config
.router_config
.connection_mode
== ConnectionMode::Grpc
{
let (tokenizer, reasoning_parser_factory, tool_parser_factory) = if matches!(
config.router_config.connection_mode,
ConnectionMode::Grpc { .. }
) {
let tokenizer_path = config
.router_config
.tokenizer_path
......
......@@ -11,10 +11,8 @@ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType
use reqwest::Client;
use serde_json::json;
use sglang_router_rs::{
config::{
CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
},
core::Job,
config::{CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode},
core::{ConnectionMode, Job},
routers::{RouterFactory, RouterTrait},
server::AppContext,
};
......
......@@ -16,9 +16,10 @@ use common::{
};
use sglang_router_rs::{
config::{
CircuitBreakerConfig, ConnectionMode, HealthCheckConfig, PolicyConfig, RetryConfig,
RouterConfig, RoutingMode,
CircuitBreakerConfig, HealthCheckConfig, PolicyConfig, RetryConfig, RouterConfig,
RoutingMode,
},
core::ConnectionMode,
routers::RouterFactory,
};
......
......@@ -2,11 +2,8 @@
mod test_pd_routing {
use serde_json::json;
use sglang_router_rs::{
config::{
CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig,
RoutingMode,
},
core::{BasicWorkerBuilder, Worker, WorkerType},
config::{CircuitBreakerConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode},
core::{BasicWorkerBuilder, ConnectionMode, Worker, WorkerType},
routers::{http::pd_types::PDSelectionPolicy, RouterFactory},
};
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment