Unverified Commit a73eb8cd authored by Keyang Ru's avatar Keyang Ru Committed by GitHub
Browse files

[router] Support Oracle DB(ATP) Data Connector (#10845)

parent e7387035
......@@ -19,7 +19,7 @@ name = "sglang-router"
path = "src/main.rs"
[dependencies]
clap = { version = "4", features = ["derive"] }
clap = { version = "4", features = ["derive", "env"] }
axum = { version = "0.8.4", features = ["macros", "ws", "tracing"] }
tower = { version = "0.5", features = ["full"] }
tower-http = { version = "0.6", features = ["trace", "compression-gzip", "cors", "timeout", "limit", "request-id", "util"] }
......@@ -69,6 +69,7 @@ rmcp = { version = "0.6.3", features = ["client", "server",
"reqwest",
"auth"] }
serde_yaml = "0.9"
oracle = { version = "0.6.3", features = ["chrono"] }
subtle = "2.6"
# gRPC and Protobuf dependencies
......
......@@ -70,6 +70,9 @@ pub struct RouterConfig {
/// History backend configuration (memory or none, default: memory)
#[serde(default = "default_history_backend")]
pub history_backend: HistoryBackend,
/// Oracle history backend configuration (required when `history_backend` = "oracle")
#[serde(skip_serializing_if = "Option::is_none")]
pub oracle: Option<OracleConfig>,
}
fn default_history_backend() -> HistoryBackend {
......@@ -84,6 +87,70 @@ pub enum HistoryBackend {
Memory,
/// No history storage
None,
/// Oracle ATP-backed storage
Oracle,
}
/// Oracle history backend configuration
#[derive(Clone, Serialize, Deserialize, PartialEq)]
pub struct OracleConfig {
/// Directory containing the ATP wallet or TLS config files (optional)
#[serde(skip_serializing_if = "Option::is_none")]
pub wallet_path: Option<String>,
/// Connection descriptor / DSN (e.g. `tcps://host:port/service`)
pub connect_descriptor: String,
/// Database username
pub username: String,
/// Database password
pub password: String,
/// Minimum number of pooled connections to keep ready
#[serde(default = "default_pool_min")]
pub pool_min: usize,
/// Maximum number of pooled connections
#[serde(default = "default_pool_max")]
pub pool_max: usize,
/// Maximum time to wait for a connection from the pool (seconds)
#[serde(default = "default_pool_timeout_secs")]
pub pool_timeout_secs: u64,
}
impl OracleConfig {
pub fn default_pool_min() -> usize {
default_pool_min()
}
pub fn default_pool_max() -> usize {
default_pool_max()
}
pub fn default_pool_timeout_secs() -> u64 {
default_pool_timeout_secs()
}
}
fn default_pool_min() -> usize {
1
}
fn default_pool_max() -> usize {
16
}
fn default_pool_timeout_secs() -> u64 {
30
}
impl std::fmt::Debug for OracleConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OracleConfig")
.field("wallet_path", &self.wallet_path)
.field("connect_descriptor", &self.connect_descriptor)
.field("username", &self.username)
.field("pool_min", &self.pool_min)
.field("pool_max", &self.pool_max)
.field("pool_timeout_secs", &self.pool_timeout_secs)
.finish()
}
}
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
......@@ -381,6 +448,7 @@ impl Default for RouterConfig {
model_path: None,
tokenizer_path: None,
history_backend: default_history_backend(),
oracle: None,
}
}
}
......@@ -948,6 +1016,7 @@ mod tests {
model_path: None,
tokenizer_path: None,
history_backend: default_history_backend(),
oracle: None,
};
assert!(config.mode.is_pd_mode());
......@@ -1012,6 +1081,7 @@ mod tests {
model_path: None,
tokenizer_path: None,
history_backend: default_history_backend(),
oracle: None,
};
assert!(!config.mode.is_pd_mode());
......@@ -1072,6 +1142,7 @@ mod tests {
model_path: None,
tokenizer_path: None,
history_backend: default_history_backend(),
oracle: None,
};
assert!(config.has_service_discovery());
......
......@@ -29,6 +29,12 @@ impl ConfigValidator {
Self::validate_retry(&retry_cfg)?;
Self::validate_circuit_breaker(&cb_cfg)?;
if config.history_backend == HistoryBackend::Oracle && config.oracle.is_none() {
return Err(ConfigError::MissingRequired {
field: "oracle".to_string(),
});
}
Ok(())
}
......
// Data connector module for response storage
pub mod response_memory_store;
pub mod response_noop_store;
pub mod response_oracle_store;
pub mod responses;
pub use response_memory_store::MemoryResponseStorage;
pub use response_noop_store::NoOpResponseStorage;
pub use response_oracle_store::OracleResponseStorage;
pub use responses::{
ResponseChain, ResponseId, ResponseStorage, ResponseStorageError, SharedResponseStorage,
StoredResponse,
......
use crate::config::OracleConfig;
use crate::data_connector::responses::{
ResponseChain, ResponseId, ResponseStorage, ResponseStorageError, Result as StorageResult,
StoredResponse,
};
use async_trait::async_trait;
use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult};
use oracle::{Connection, Row};
use serde_json::Value;
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
const SELECT_BASE: &str = "SELECT id, previous_response_id, input, instructions, output, \
tool_calls, metadata, created_at, user_id, model, raw_response FROM responses";
#[derive(Clone)]
pub struct OracleResponseStorage {
pool: Pool<OracleConnectionManager>,
}
impl OracleResponseStorage {
pub fn new(config: OracleConfig) -> StorageResult<Self> {
let config = Arc::new(config);
configure_oracle_client(&config)?;
initialize_schema(&config)?;
let manager = OracleConnectionManager::new(config.clone());
let mut builder = Pool::builder(manager)
.max_size(config.pool_max)
.runtime(deadpool::Runtime::Tokio1);
if config.pool_timeout_secs > 0 {
builder = builder.wait_timeout(Some(Duration::from_secs(config.pool_timeout_secs)));
}
let pool = builder.build().map_err(|err| {
ResponseStorageError::StorageError(format!(
"failed to build Oracle connection pool: {err}"
))
})?;
Ok(Self { pool })
}
async fn with_connection<F, T>(&self, func: F) -> StorageResult<T>
where
F: FnOnce(&Connection) -> StorageResult<T> + Send + 'static,
T: Send + 'static,
{
let connection = self.pool.get().await.map_err(map_pool_error)?;
tokio::task::spawn_blocking(move || {
let result = func(&connection);
drop(connection);
result
})
.await
.map_err(|err| {
ResponseStorageError::StorageError(format!(
"failed to execute Oracle query task: {err}"
))
})?
}
fn build_response_from_row(row: &Row) -> StorageResult<StoredResponse> {
let id: String = row
.get(0)
.map_err(|err| map_oracle_error(err).into_storage_error("fetch id"))?;
let previous: Option<String> = row.get(1).map_err(|err| {
map_oracle_error(err).into_storage_error("fetch previous_response_id")
})?;
let input: String = row
.get(2)
.map_err(|err| map_oracle_error(err).into_storage_error("fetch input"))?;
let instructions: Option<String> = row
.get(3)
.map_err(|err| map_oracle_error(err).into_storage_error("fetch instructions"))?;
let output: String = row
.get(4)
.map_err(|err| map_oracle_error(err).into_storage_error("fetch output"))?;
let tool_calls_json: Option<String> = row
.get(5)
.map_err(|err| map_oracle_error(err).into_storage_error("fetch tool_calls"))?;
let metadata_json: Option<String> = row
.get(6)
.map_err(|err| map_oracle_error(err).into_storage_error("fetch metadata"))?;
let created_at: chrono::DateTime<chrono::Utc> = row
.get(7)
.map_err(|err| map_oracle_error(err).into_storage_error("fetch created_at"))?;
let user_id: Option<String> = row
.get(8)
.map_err(|err| map_oracle_error(err).into_storage_error("fetch user_id"))?;
let model: Option<String> = row
.get(9)
.map_err(|err| map_oracle_error(err).into_storage_error("fetch model"))?;
let raw_response_json: Option<String> = row
.get(10)
.map_err(|err| map_oracle_error(err).into_storage_error("fetch raw_response"))?;
let previous_response_id = previous.map(ResponseId);
let tool_calls = parse_tool_calls(tool_calls_json)?;
let metadata = parse_metadata(metadata_json)?;
let raw_response = parse_raw_response(raw_response_json)?;
Ok(StoredResponse {
id: ResponseId(id),
previous_response_id,
input,
instructions,
output,
tool_calls,
metadata,
created_at,
user: user_id,
model,
raw_response,
})
}
}
#[async_trait]
impl ResponseStorage for OracleResponseStorage {
async fn store_response(&self, response: StoredResponse) -> StorageResult<ResponseId> {
let StoredResponse {
id,
previous_response_id,
input,
instructions,
output,
tool_calls,
metadata,
created_at,
user,
model,
raw_response,
} = response;
let response_id = id.clone();
let response_id_str = response_id.0.clone();
let previous_id = previous_response_id.map(|r| r.0);
let json_tool_calls = serde_json::to_string(&tool_calls)?;
let json_metadata = serde_json::to_string(&metadata)?;
let json_raw_response = serde_json::to_string(&raw_response)?;
self.with_connection(move |conn| {
conn.execute(
"INSERT INTO responses (id, previous_response_id, input, instructions, output, \
tool_calls, metadata, created_at, user_id, model, raw_response) \
VALUES (:1, :2, :3, :4, :5, :6, :7, :8, :9, :10, :11)",
&[
&response_id_str,
&previous_id,
&input,
&instructions,
&output,
&json_tool_calls,
&json_metadata,
&created_at,
&user,
&model,
&json_raw_response,
],
)
.map(|_| ())
.map_err(map_oracle_error)
})
.await?;
Ok(response_id)
}
async fn get_response(
&self,
response_id: &ResponseId,
) -> StorageResult<Option<StoredResponse>> {
let id = response_id.0.clone();
self.with_connection(move |conn| {
let mut stmt = conn
.statement(&format!("{} WHERE id = :1", SELECT_BASE))
.build()
.map_err(map_oracle_error)?;
let mut rows = stmt.query(&[&id]).map_err(map_oracle_error)?;
match rows.next() {
Some(row) => {
let row = row.map_err(map_oracle_error)?;
OracleResponseStorage::build_response_from_row(&row).map(Some)
}
None => Ok(None),
}
})
.await
}
async fn delete_response(&self, response_id: &ResponseId) -> StorageResult<()> {
let id = response_id.0.clone();
self.with_connection(move |conn| {
conn.execute("DELETE FROM responses WHERE id = :1", &[&id])
.map(|_| ())
.map_err(map_oracle_error)
})
.await
}
async fn get_response_chain(
&self,
response_id: &ResponseId,
max_depth: Option<usize>,
) -> StorageResult<ResponseChain> {
let mut chain = ResponseChain::new();
let mut current_id = Some(response_id.clone());
let mut visited = 0usize;
while let Some(ref lookup_id) = current_id {
if let Some(limit) = max_depth {
if visited >= limit {
break;
}
}
let fetched = self.get_response(lookup_id).await?;
match fetched {
Some(response) => {
current_id = response.previous_response_id.clone();
chain.responses.push(response);
visited += 1;
}
None => break,
}
}
chain.responses.reverse();
Ok(chain)
}
async fn list_user_responses(
&self,
user: &str,
limit: Option<usize>,
) -> StorageResult<Vec<StoredResponse>> {
let user = user.to_string();
self.with_connection(move |conn| {
let sql = if let Some(limit) = limit {
format!(
"SELECT * FROM ({} WHERE user_id = :1 ORDER BY created_at DESC) WHERE ROWNUM <= {}",
SELECT_BASE, limit
)
} else {
format!("{} WHERE user_id = :1 ORDER BY created_at DESC", SELECT_BASE)
};
let mut stmt = conn.statement(&sql).build().map_err(map_oracle_error)?;
let mut rows = stmt.query(&[&user]).map_err(map_oracle_error)?;
let mut results = Vec::new();
for row in &mut rows {
let row = row.map_err(map_oracle_error)?;
results.push(OracleResponseStorage::build_response_from_row(&row)?);
}
Ok(results)
})
.await
}
async fn delete_user_responses(&self, user: &str) -> StorageResult<usize> {
let user = user.to_string();
let affected = self
.with_connection(move |conn| {
conn.execute("DELETE FROM responses WHERE user_id = :1", &[&user])
.map_err(map_oracle_error)
})
.await?;
let deleted = affected.row_count().map_err(map_oracle_error)? as usize;
Ok(deleted)
}
}
#[derive(Clone)]
struct OracleConnectionManager {
params: Arc<OracleConnectParams>,
}
impl OracleConnectionManager {
fn new(config: Arc<OracleConfig>) -> Self {
let params = OracleConnectParams {
username: config.username.clone(),
password: config.password.clone(),
connect_descriptor: config.connect_descriptor.clone(),
};
Self {
params: Arc::new(params),
}
}
}
impl std::fmt::Debug for OracleConnectionManager {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OracleConnectionManager")
.field("username", &self.params.username)
.field("connect_descriptor", &self.params.connect_descriptor)
.finish()
}
}
#[derive(Clone)]
struct OracleConnectParams {
username: String,
password: String,
connect_descriptor: String,
}
impl std::fmt::Debug for OracleConnectParams {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OracleConnectParams")
.field("username", &self.username)
.field("connect_descriptor", &self.connect_descriptor)
.finish()
}
}
#[async_trait]
impl Manager for OracleConnectionManager {
type Type = Connection;
type Error = oracle::Error;
fn create(
&self,
) -> impl std::future::Future<Output = Result<Connection, oracle::Error>> + Send {
let params = self.params.clone();
async move {
let mut conn = Connection::connect(
&params.username,
&params.password,
&params.connect_descriptor,
)?;
conn.set_autocommit(true);
Ok(conn)
}
}
#[allow(clippy::manual_async_fn)]
fn recycle(
&self,
conn: &mut Connection,
_: &Metrics,
) -> impl std::future::Future<Output = RecycleResult<Self::Error>> + Send {
async move { conn.ping().map_err(RecycleError::Backend) }
}
}
fn configure_oracle_client(config: &OracleConfig) -> StorageResult<()> {
if let Some(wallet_path) = &config.wallet_path {
let wallet_path = Path::new(wallet_path);
if !wallet_path.is_dir() {
return Err(ResponseStorageError::StorageError(format!(
"Oracle wallet/config path '{}' is not a directory",
wallet_path.display()
)));
}
if !wallet_path.join("tnsnames.ora").exists() && !wallet_path.join("sqlnet.ora").exists() {
return Err(ResponseStorageError::StorageError(format!(
"Oracle wallet/config path '{}' is missing tnsnames.ora or sqlnet.ora",
wallet_path.display()
)));
}
std::env::set_var("TNS_ADMIN", wallet_path);
}
Ok(())
}
fn initialize_schema(config: &OracleConfig) -> StorageResult<()> {
let conn = Connection::connect(
&config.username,
&config.password,
&config.connect_descriptor,
)
.map_err(map_oracle_error)?;
let exists: i64 = conn
.query_row_as(
"SELECT COUNT(*) FROM user_tables WHERE table_name = 'RESPONSES'",
&[],
)
.map_err(map_oracle_error)?;
if exists == 0 {
conn.execute(
"CREATE TABLE responses (
id VARCHAR2(64) PRIMARY KEY,
previous_response_id VARCHAR2(64),
input CLOB,
instructions CLOB,
output CLOB,
tool_calls CLOB,
metadata CLOB,
created_at TIMESTAMP WITH TIME ZONE,
user_id VARCHAR2(128),
model VARCHAR2(128),
raw_response CLOB
)",
&[],
)
.map_err(map_oracle_error)?;
}
create_index_if_missing(
&conn,
"RESPONSES_PREV_IDX",
"CREATE INDEX responses_prev_idx ON responses(previous_response_id)",
)?;
create_index_if_missing(
&conn,
"RESPONSES_USER_IDX",
"CREATE INDEX responses_user_idx ON responses(user_id)",
)?;
Ok(())
}
fn create_index_if_missing(conn: &Connection, index_name: &str, ddl: &str) -> StorageResult<()> {
let count: i64 = conn
.query_row_as(
"SELECT COUNT(*) FROM user_indexes WHERE table_name = 'RESPONSES' AND index_name = :1",
&[&index_name],
)
.map_err(map_oracle_error)?;
if count == 0 {
if let Err(err) = conn.execute(ddl, &[]) {
if err.db_error().map(|db| db.code()) != Some(1408) {
return Err(map_oracle_error(err));
}
}
}
Ok(())
}
fn parse_tool_calls(raw: Option<String>) -> StorageResult<Vec<Value>> {
match raw {
Some(s) if !s.is_empty() => {
serde_json::from_str(&s).map_err(ResponseStorageError::SerializationError)
}
_ => Ok(Vec::new()),
}
}
fn parse_metadata(raw: Option<String>) -> StorageResult<HashMap<String, Value>> {
match raw {
Some(s) if !s.is_empty() => {
serde_json::from_str(&s).map_err(ResponseStorageError::SerializationError)
}
_ => Ok(HashMap::new()),
}
}
fn parse_raw_response(raw: Option<String>) -> StorageResult<Value> {
match raw {
Some(s) if !s.is_empty() => {
serde_json::from_str(&s).map_err(ResponseStorageError::SerializationError)
}
_ => Ok(Value::Null),
}
}
fn map_pool_error(err: PoolError<oracle::Error>) -> ResponseStorageError {
match err {
PoolError::Backend(e) => map_oracle_error(e),
other => ResponseStorageError::StorageError(format!(
"failed to obtain Oracle connection: {other}"
)),
}
}
fn map_oracle_error(err: oracle::Error) -> ResponseStorageError {
if let Some(db_err) = err.db_error() {
ResponseStorageError::StorageError(format!(
"Oracle error (code {}): {}",
db_err.code(),
db_err.message()
))
} else {
ResponseStorageError::StorageError(err.to_string())
}
}
trait OracleErrorExt {
fn into_storage_error(self, context: &str) -> ResponseStorageError;
}
impl OracleErrorExt for ResponseStorageError {
fn into_storage_error(self, context: &str) -> ResponseStorageError {
ResponseStorageError::StorageError(format!("{context}: {self}"))
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn parse_tool_calls_handles_empty_input() {
assert!(parse_tool_calls(None).unwrap().is_empty());
assert!(parse_tool_calls(Some(String::new())).unwrap().is_empty());
}
#[test]
fn parse_tool_calls_round_trips() {
let payload = json!([{ "type": "test", "value": 1 }]).to_string();
let parsed = parse_tool_calls(Some(payload)).unwrap();
assert_eq!(parsed.len(), 1);
assert_eq!(parsed[0]["type"], "test");
assert_eq!(parsed[0]["value"], 1);
}
#[test]
fn parse_metadata_defaults_to_empty_map() {
assert!(parse_metadata(None).unwrap().is_empty());
}
#[test]
fn parse_metadata_round_trips() {
let payload = json!({"key": "value", "nested": {"bool": true}}).to_string();
let parsed = parse_metadata(Some(payload)).unwrap();
assert_eq!(parsed.get("key").unwrap(), "value");
assert_eq!(parsed["nested"]["bool"], true);
}
#[test]
fn parse_raw_response_handles_null() {
assert_eq!(parse_raw_response(None).unwrap(), Value::Null);
}
#[test]
fn parse_raw_response_round_trips() {
let payload = json!({"id": "abc"}).to_string();
let parsed = parse_raw_response(Some(payload)).unwrap();
assert_eq!(parsed["id"], "abc");
}
}
......@@ -231,6 +231,7 @@ impl Router {
model_path: self.model_path.clone(),
tokenizer_path: self.tokenizer_path.clone(),
history_backend: config::HistoryBackend::Memory,
oracle: None,
})
}
}
......
use clap::{ArgAction, Parser, ValueEnum};
use sglang_router_rs::config::{
CircuitBreakerConfig, ConfigError, ConfigResult, ConnectionMode, DiscoveryConfig,
HealthCheckConfig, HistoryBackend, MetricsConfig, PolicyConfig, RetryConfig, RouterConfig,
RoutingMode,
HealthCheckConfig, HistoryBackend, MetricsConfig, OracleConfig, PolicyConfig, RetryConfig,
RouterConfig, RoutingMode,
};
use sglang_router_rs::metrics::PrometheusConfig;
use sglang_router_rs::server::{self, ServerConfig};
......@@ -314,9 +314,46 @@ struct CliArgs {
#[arg(long)]
tokenizer_path: Option<String>,
/// History backend configuration (memory or none)
#[arg(long, default_value = "memory", value_parser = ["memory", "none"])]
/// History backend configuration (memory, none, or oracle)
#[arg(long, default_value = "memory", value_parser = ["memory", "none", "oracle"])]
history_backend: String,
/// Directory containing the Oracle ATP wallet/config files (optional)
#[arg(long, env = "ATP_WALLET_PATH")]
oracle_wallet_path: Option<String>,
/// Wallet TNS alias to use (e.g. `<db_name>_low`)
#[arg(long, env = "ATP_TNS_ALIAS")]
oracle_tns_alias: Option<String>,
/// Oracle connection descriptor / DSN (e.g. `tcps://host:port/service_name`)
#[arg(long, env = "ATP_DSN")]
oracle_dsn: Option<String>,
/// Oracle ATP username
#[arg(long, env = "ATP_USER")]
oracle_user: Option<String>,
/// Oracle ATP password
#[arg(long, env = "ATP_PASSWORD")]
oracle_password: Option<String>,
/// Minimum number of pooled ATP connections (defaults to 1 when omitted)
#[arg(long, env = "ATP_POOL_MIN")]
oracle_pool_min: Option<usize>,
/// Maximum number of pooled ATP connections (defaults to 16 when omitted)
#[arg(long, env = "ATP_POOL_MAX")]
oracle_pool_max: Option<usize>,
/// Connection acquisition timeout in seconds (defaults to 30 when omitted)
#[arg(long, env = "ATP_POOL_TIMEOUT_SECS")]
oracle_pool_timeout_secs: Option<u64>,
}
enum OracleConnectSource {
Dsn { descriptor: String },
Wallet { path: String, alias: String },
}
impl CliArgs {
......@@ -364,6 +401,87 @@ impl CliArgs {
}
}
fn resolve_oracle_connect_details(&self) -> ConfigResult<OracleConnectSource> {
if let Some(dsn) = self.oracle_dsn.clone() {
return Ok(OracleConnectSource::Dsn { descriptor: dsn });
}
let wallet_path = self
.oracle_wallet_path
.clone()
.ok_or(ConfigError::MissingRequired {
field: "oracle_wallet_path or ATP_WALLET_PATH".to_string(),
})?;
let tns_alias = self
.oracle_tns_alias
.clone()
.ok_or(ConfigError::MissingRequired {
field: "oracle_tns_alias or ATP_TNS_ALIAS".to_string(),
})?;
Ok(OracleConnectSource::Wallet {
path: wallet_path,
alias: tns_alias,
})
}
fn build_oracle_config(&self) -> ConfigResult<OracleConfig> {
let (wallet_path, connect_descriptor) = match self.resolve_oracle_connect_details()? {
OracleConnectSource::Dsn { descriptor } => (None, descriptor),
OracleConnectSource::Wallet { path, alias } => (Some(path), alias),
};
let username = self
.oracle_user
.clone()
.ok_or(ConfigError::MissingRequired {
field: "oracle_user or ATP_USER".to_string(),
})?;
let password = self
.oracle_password
.clone()
.ok_or(ConfigError::MissingRequired {
field: "oracle_password or ATP_PASSWORD".to_string(),
})?;
let pool_min = self
.oracle_pool_min
.unwrap_or_else(OracleConfig::default_pool_min);
let pool_max = self
.oracle_pool_max
.unwrap_or_else(OracleConfig::default_pool_max);
if pool_min == 0 {
return Err(ConfigError::InvalidValue {
field: "oracle_pool_min".to_string(),
value: pool_min.to_string(),
reason: "pool minimum must be at least 1".to_string(),
});
}
if pool_max < pool_min {
return Err(ConfigError::InvalidValue {
field: "oracle_pool_max".to_string(),
value: pool_max.to_string(),
reason: "pool maximum must be greater than or equal to minimum".to_string(),
});
}
let pool_timeout_secs = self
.oracle_pool_timeout_secs
.unwrap_or_else(OracleConfig::default_pool_timeout_secs);
Ok(OracleConfig {
wallet_path,
connect_descriptor,
username,
password,
pool_min,
pool_max,
pool_timeout_secs,
})
}
/// Convert CLI arguments to RouterConfig
fn to_router_config(
&self,
......@@ -459,6 +577,18 @@ impl CliArgs {
_ => Self::determine_connection_mode(&all_urls),
};
let history_backend = match self.history_backend.as_str() {
"none" => HistoryBackend::None,
"oracle" => HistoryBackend::Oracle,
_ => HistoryBackend::Memory,
};
let oracle = if history_backend == HistoryBackend::Oracle {
Some(self.build_oracle_config()?)
} else {
None
};
// Build RouterConfig
Ok(RouterConfig {
mode,
......@@ -511,10 +641,8 @@ impl CliArgs {
rate_limit_tokens_per_second: None,
model_path: self.model_path.clone(),
tokenizer_path: self.tokenizer_path.clone(),
history_backend: match self.history_backend.as_str() {
"none" => HistoryBackend::None,
_ => HistoryBackend::Memory,
},
history_backend,
oracle,
})
}
......
use crate::{
config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode},
core::{WorkerManager, WorkerRegistry, WorkerType},
data_connector::{MemoryResponseStorage, NoOpResponseStorage, SharedResponseStorage},
data_connector::{
MemoryResponseStorage, NoOpResponseStorage, OracleResponseStorage, SharedResponseStorage,
},
logging::{self, LoggingConfig},
metrics::{self, PrometheusConfig},
middleware::{self, AuthConfig, QueuedRequest, TokenBucket},
......@@ -92,6 +94,17 @@ impl AppContext {
let response_storage: SharedResponseStorage = match router_config.history_backend {
HistoryBackend::Memory => Arc::new(MemoryResponseStorage::new()),
HistoryBackend::None => Arc::new(NoOpResponseStorage::new()),
HistoryBackend::Oracle => {
let oracle_cfg = router_config.oracle.clone().ok_or_else(|| {
"oracle configuration is required when history_backend=oracle".to_string()
})?;
let storage = OracleResponseStorage::new(oracle_cfg).map_err(|err| {
format!("failed to initialize Oracle response storage: {err}")
})?;
Arc::new(storage)
}
};
Ok(Self {
......
......@@ -62,6 +62,7 @@ impl TestContext {
model_path: None,
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
};
Self::new_with_config(config, worker_configs).await
......@@ -1401,6 +1402,7 @@ mod error_tests {
model_path: None,
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
};
let ctx = TestContext::new_with_config(
......@@ -1760,6 +1762,7 @@ mod pd_mode_tests {
model_path: None,
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
};
// Create app context
......@@ -1923,6 +1926,7 @@ mod request_id_tests {
model_path: None,
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
};
let ctx = TestContext::new_with_config(
......
......@@ -10,7 +10,9 @@ use axum::{
};
use serde_json::json;
use sglang_router_rs::{
config::{RouterConfig, RoutingMode},
config::{
ConfigError, ConfigValidator, HistoryBackend, OracleConfig, RouterConfig, RoutingMode,
},
data_connector::{MemoryResponseStorage, ResponseId, ResponseStorage, StoredResponse},
protocols::spec::{
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, ResponseInput,
......@@ -823,3 +825,69 @@ async fn test_openai_router_models_auth_forwarding() {
let models: serde_json::Value = serde_json::from_str(&body_str).unwrap();
assert_eq!(models["object"], "list");
}
#[test]
fn oracle_config_validation_requires_config_when_enabled() {
let config = RouterConfig {
mode: RoutingMode::OpenAI {
worker_urls: vec!["https://api.openai.com".to_string()],
},
history_backend: HistoryBackend::Oracle,
oracle: None,
..Default::default()
};
let err =
ConfigValidator::validate(&config).expect_err("config should fail without oracle details");
match err {
ConfigError::MissingRequired { field } => {
assert_eq!(field, "oracle");
}
other => panic!("unexpected error: {:?}", other),
}
}
#[test]
fn oracle_config_validation_accepts_dsn_only() {
let config = RouterConfig {
mode: RoutingMode::OpenAI {
worker_urls: vec!["https://api.openai.com".to_string()],
},
history_backend: HistoryBackend::Oracle,
oracle: Some(OracleConfig {
wallet_path: None,
connect_descriptor: "tcps://db.example.com:1522/service".to_string(),
username: "scott".to_string(),
password: "tiger".to_string(),
pool_min: 1,
pool_max: 4,
pool_timeout_secs: 30,
}),
..Default::default()
};
ConfigValidator::validate(&config).expect("dsn-based config should validate");
}
#[test]
fn oracle_config_validation_accepts_wallet_alias() {
let config = RouterConfig {
mode: RoutingMode::OpenAI {
worker_urls: vec!["https://api.openai.com".to_string()],
},
history_backend: HistoryBackend::Oracle,
oracle: Some(OracleConfig {
wallet_path: Some("/etc/sglang/oracle-wallet".to_string()),
connect_descriptor: "db_low".to_string(),
username: "app_user".to_string(),
password: "secret".to_string(),
pool_min: 1,
pool_max: 8,
pool_timeout_secs: 45,
}),
..Default::default()
};
ConfigValidator::validate(&config).expect("wallet-based config should validate");
}
......@@ -208,6 +208,7 @@ mod test_pd_routing {
model_path: None,
tokenizer_path: None,
history_backend: sglang_router_rs::config::HistoryBackend::Memory,
oracle: None,
};
let app_context =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment