Unverified commit 90313fb0, authored by Chang Su and committed by GitHub

[router] add token bucket rate limiter (#9656)

parent 3578eb1e
......@@ -72,6 +72,12 @@ class RouterArgs:
request_timeout_secs: int = 1800
# Max concurrent requests for rate limiting
max_concurrent_requests: int = 256
# Queue size for pending requests when the max concurrent limit is reached
queue_size: int = 100
# Maximum time (in seconds) a request can wait in queue before timing out
queue_timeout_secs: int = 60
# Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests
rate_limit_tokens_per_second: Optional[int] = None
# CORS allowed origins
cors_allowed_origins: List[str] = dataclasses.field(default_factory=list)
# Retry configuration
......@@ -402,6 +408,24 @@ class RouterArgs:
default=RouterArgs.max_concurrent_requests,
help="Maximum number of concurrent requests allowed (for rate limiting)",
)
parser.add_argument(
f"--{prefix}queue-size",
type=int,
default=RouterArgs.queue_size,
help="Queue size for pending requests when max concurrent limit reached (0 = no queue, return 429 immediately)",
)
parser.add_argument(
f"--{prefix}queue-timeout-secs",
type=int,
default=RouterArgs.queue_timeout_secs,
help="Maximum time (in seconds) a request can wait in queue before timing out",
)
parser.add_argument(
f"--{prefix}rate-limit-tokens-per-second",
type=int,
default=RouterArgs.rate_limit_tokens_per_second,
help="Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests",
)
parser.add_argument(
f"--{prefix}cors-allowed-origins",
type=str,
......@@ -478,6 +502,21 @@ class RouterArgs:
f"{prefix}max_concurrent_requests",
RouterArgs.max_concurrent_requests,
),
queue_size=getattr(
args,
f"{prefix}queue_size",
RouterArgs.queue_size,
),
queue_timeout_secs=getattr(
args,
f"{prefix}queue_timeout_secs",
RouterArgs.queue_timeout_secs,
),
rate_limit_tokens_per_second=getattr(
args,
f"{prefix}rate_limit_tokens_per_second",
RouterArgs.rate_limit_tokens_per_second,
),
cors_allowed_origins=getattr(args, f"{prefix}cors_allowed_origins", []),
retry_max_retries=getattr(args, f"{prefix}retry_max_retries"),
retry_initial_backoff_ms=getattr(args, f"{prefix}retry_initial_backoff_ms"),
......@@ -700,6 +739,9 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
),
request_id_headers=router_args.request_id_headers,
max_concurrent_requests=router_args.max_concurrent_requests,
queue_size=router_args.queue_size,
queue_timeout_secs=router_args.queue_timeout_secs,
rate_limit_tokens_per_second=router_args.rate_limit_tokens_per_second,
cors_allowed_origins=router_args.cors_allowed_origins,
retry_max_retries=router_args.retry_max_retries,
retry_initial_backoff_ms=router_args.retry_initial_backoff_ms,
......
......@@ -64,7 +64,10 @@ class Router:
bootstrap_port_annotation: Kubernetes annotation name for bootstrap port (PD mode).
Default: 'sglang.ai/bootstrap-port'
request_timeout_secs: Request timeout in seconds. Default: 600
max_concurrent_requests: Maximum number of concurrent requests allowed for rate limiting. Default: 64
max_concurrent_requests: Maximum number of concurrent requests allowed for rate limiting. Default: 256
queue_size: Queue size for pending requests when the max concurrent limit is reached (0 = no queue, return 429 immediately). Default: 100
queue_timeout_secs: Maximum time (in seconds) a request can wait in queue before timing out. Default: 60
rate_limit_tokens_per_second: Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests. Default: None
cors_allowed_origins: List of allowed origins for CORS. Empty list allows all origins. Default: []
health_failure_threshold: Number of consecutive health check failures before marking worker unhealthy. Default: 3
health_success_threshold: Number of consecutive health check successes before marking worker healthy. Default: 2
......@@ -108,6 +111,9 @@ class Router:
prefill_policy: Optional[PolicyType] = None,
decode_policy: Optional[PolicyType] = None,
max_concurrent_requests: int = 256,
queue_size: int = 100,
queue_timeout_secs: int = 60,
rate_limit_tokens_per_second: Optional[int] = None,
cors_allowed_origins: List[str] = None,
retry_max_retries: int = 5,
retry_initial_backoff_ms: int = 50,
......@@ -169,6 +175,9 @@ class Router:
prefill_policy=prefill_policy,
decode_policy=decode_policy,
max_concurrent_requests=max_concurrent_requests,
queue_size=queue_size,
queue_timeout_secs=queue_timeout_secs,
rate_limit_tokens_per_second=rate_limit_tokens_per_second,
cors_allowed_origins=cors_allowed_origins,
retry_max_retries=retry_max_retries,
retry_initial_backoff_ms=retry_initial_backoff_ms,
......
......@@ -37,6 +37,12 @@ pub struct RouterConfig {
pub request_id_headers: Option<Vec<String>>,
/// Maximum concurrent requests allowed (for rate limiting)
pub max_concurrent_requests: usize,
/// Queue size for pending requests when the max concurrent limit is reached (0 = no queue, return 429 immediately)
pub queue_size: usize,
/// Maximum time (in seconds) a request can wait in queue before timing out
pub queue_timeout_secs: u64,
/// Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests
pub rate_limit_tokens_per_second: Option<usize>,
/// CORS allowed origins
pub cors_allowed_origins: Vec<String>,
/// Retry configuration
......@@ -320,6 +326,9 @@ impl Default for RouterConfig {
log_level: None,
request_id_headers: None,
max_concurrent_requests: 256,
queue_size: 100,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
......@@ -466,6 +475,9 @@ mod tests {
disable_circuit_breaker: false,
health_check: HealthCheckConfig::default(),
enable_igw: false,
queue_size: 100,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
};
let json = serde_json::to_string(&config).unwrap();
......@@ -899,6 +911,9 @@ mod tests {
disable_circuit_breaker: false,
health_check: HealthCheckConfig::default(),
enable_igw: false,
queue_size: 100,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
};
assert!(config.mode.is_pd_mode());
......@@ -956,6 +971,9 @@ mod tests {
disable_circuit_breaker: false,
health_check: HealthCheckConfig::default(),
enable_igw: false,
queue_size: 100,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
};
assert!(!config.mode.is_pd_mode());
......@@ -1009,6 +1027,9 @@ mod tests {
disable_circuit_breaker: false,
health_check: HealthCheckConfig::default(),
enable_igw: false,
queue_size: 100,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
};
assert!(config.has_service_discovery());
......
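Since the new fields are plain `pub` members covered by the `Default` impl above, downstream code can opt into queuing programmatically. A minimal sketch, assuming the crate name used in the tests elsewhere in this diff; `Some(128)` is an illustrative value, not a recommendation:

```rust
use sglang_router_rs::config::RouterConfig;

fn rate_limited_config() -> RouterConfig {
    RouterConfig {
        max_concurrent_requests: 256,            // token bucket capacity (burst size)
        queue_size: 100,                         // 0 = no queue: reject with 429 immediately
        queue_timeout_secs: 60,                  // max seconds a request may wait in the queue
        rate_limit_tokens_per_second: Some(128), // None falls back to max_concurrent_requests
        ..RouterConfig::default()
    }
}
```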
......@@ -9,6 +9,7 @@
pub mod circuit_breaker;
pub mod error;
pub mod retry;
pub mod token_bucket;
pub mod worker;
// Re-export commonly used types at the module level
......
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, Notify};
use tracing::{debug, trace};
/// Token bucket for rate limiting
///
/// This implementation provides:
/// - Smooth rate limiting with configurable refill rate
/// - Burst capacity handling
/// - Fair queuing for waiting requests
#[derive(Clone)]
pub struct TokenBucket {
inner: Arc<Mutex<TokenBucketInner>>,
notify: Arc<Notify>,
capacity: f64,
refill_rate: f64, // tokens per second
}
struct TokenBucketInner {
tokens: f64,
last_refill: Instant,
}
impl TokenBucket {
/// Create a new token bucket
///
/// # Arguments
/// * `capacity` - Maximum number of tokens (burst capacity)
/// * `refill_rate` - Tokens added per second
pub fn new(capacity: usize, refill_rate: usize) -> Self {
let capacity = capacity as f64;
let refill_rate = refill_rate as f64;
// Ensure refill_rate is not zero to prevent division by zero
let refill_rate = if refill_rate > 0.0 {
refill_rate
} else {
1.0 // Default to 1 token per second if zero
};
Self {
inner: Arc::new(Mutex::new(TokenBucketInner {
tokens: capacity, // Start full
last_refill: Instant::now(),
})),
notify: Arc::new(Notify::new()),
capacity,
refill_rate,
}
}
/// Try to acquire tokens immediately
pub async fn try_acquire(&self, tokens: f64) -> Result<(), ()> {
let mut inner = self.inner.lock().await;
// Refill tokens based on elapsed time
let now = Instant::now();
let elapsed = now.duration_since(inner.last_refill).as_secs_f64();
let refill_amount = elapsed * self.refill_rate;
inner.tokens = (inner.tokens + refill_amount).min(self.capacity);
inner.last_refill = now;
trace!(
"Token bucket: {} tokens available, requesting {}",
inner.tokens,
tokens
);
if inner.tokens >= tokens {
inner.tokens -= tokens;
debug!(
"Token bucket: acquired {} tokens, {} remaining",
tokens, inner.tokens
);
Ok(())
} else {
Err(())
}
}
/// Acquire tokens, waiting if necessary
pub async fn acquire(&self, tokens: f64) -> Result<(), tokio::time::error::Elapsed> {
// First try to acquire immediately
if self.try_acquire(tokens).await.is_ok() {
return Ok(());
}
// Calculate wait time
let wait_time = {
let inner = self.inner.lock().await;
let tokens_needed = tokens - inner.tokens;
let wait_secs = tokens_needed / self.refill_rate;
Duration::from_secs_f64(wait_secs)
};
debug!(
"Token bucket: waiting {:?} for {} tokens",
wait_time, tokens
);
// Wait for tokens to be available
tokio::time::timeout(wait_time, async {
loop {
// Check if we can acquire now
if self.try_acquire(tokens).await.is_ok() {
return;
}
// Wait for notification or small interval
tokio::select! {
_ = self.notify.notified() => {},
_ = tokio::time::sleep(Duration::from_millis(10)) => {},
}
}
})
.await?;
Ok(())
}
/// Acquire tokens with custom timeout
pub async fn acquire_timeout(
&self,
tokens: f64,
timeout: Duration,
) -> Result<(), tokio::time::error::Elapsed> {
tokio::time::timeout(timeout, self.acquire(tokens)).await?
}
/// Return tokens to the bucket (for cancelled requests)
pub async fn return_tokens(&self, tokens: f64) {
let mut inner = self.inner.lock().await;
inner.tokens = (inner.tokens + tokens).min(self.capacity);
self.notify.notify_waiters();
debug!(
"Token bucket: returned {} tokens, {} available",
tokens, inner.tokens
);
}
/// Get current available tokens (for monitoring)
pub async fn available_tokens(&self) -> f64 {
let mut inner = self.inner.lock().await;
// Refill before checking
let now = Instant::now();
let elapsed = now.duration_since(inner.last_refill).as_secs_f64();
let refill_amount = elapsed * self.refill_rate;
inner.tokens = (inner.tokens + refill_amount).min(self.capacity);
inner.last_refill = now;
inner.tokens
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_token_bucket_basic() {
let bucket = TokenBucket::new(10, 5); // 10 capacity, 5 per second
// Should succeed - bucket starts full
assert!(bucket.try_acquire(5.0).await.is_ok());
assert!(bucket.try_acquire(5.0).await.is_ok());
// Should fail - no tokens left
assert!(bucket.try_acquire(1.0).await.is_err());
// Wait for refill
tokio::time::sleep(Duration::from_millis(300)).await;
// Should have ~1.5 tokens now
assert!(bucket.try_acquire(1.0).await.is_ok());
}
#[tokio::test]
async fn test_token_bucket_refill() {
let bucket = TokenBucket::new(10, 10); // 10 capacity, 10 per second
// Use all tokens
assert!(bucket.try_acquire(10.0).await.is_ok());
// Wait for partial refill
tokio::time::sleep(Duration::from_millis(500)).await;
// Should have ~5 tokens
let available = bucket.available_tokens().await;
assert!((4.0..=6.0).contains(&available));
}
}
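A usage sketch of the bucket API above (crate path assumed from the `pub mod token_bucket` declaration; in the router, each in-flight request costs exactly one token, returned on completion):

```rust
use std::time::Duration;
use sglang_router_rs::core::token_bucket::TokenBucket;

#[tokio::main]
async fn main() {
    // Capacity doubles as burst size; the refill rate caps sustained admissions.
    let bucket = TokenBucket::new(256, 128);

    // Fast path: take a token without waiting (what the middleware tries first).
    if bucket.try_acquire(1.0).await.is_ok() {
        // ... handle the request ...
        bucket.return_tokens(1.0).await; // hand the token back when done
    }

    // Slow path: wait up to 60 s for a token, as the queue processor does.
    match bucket.acquire_timeout(1.0, Duration::from_secs(60)).await {
        Ok(()) => println!("admitted after waiting"),
        Err(_) => println!("timed out waiting for a token"),
    }
}
```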
......@@ -85,6 +85,9 @@ struct Router {
health_check_endpoint: String,
// IGW (Inference Gateway) configuration
enable_igw: bool,
queue_size: usize,
queue_timeout_secs: u64,
rate_limit_tokens_per_second: Option<usize>,
}
impl Router {
......@@ -176,6 +179,9 @@ impl Router {
log_level: self.log_level.clone(),
request_id_headers: self.request_id_headers.clone(),
max_concurrent_requests: self.max_concurrent_requests,
queue_size: self.queue_size,
queue_timeout_secs: self.queue_timeout_secs,
rate_limit_tokens_per_second: self.rate_limit_tokens_per_second,
cors_allowed_origins: self.cors_allowed_origins.clone(),
retry: config::RetryConfig {
max_retries: self.retry_max_retries,
......@@ -190,8 +196,8 @@ impl Router {
timeout_duration_secs: self.cb_timeout_duration_secs,
window_duration_secs: self.cb_window_duration_secs,
},
disable_retries: false,
disable_circuit_breaker: false,
disable_retries: self.disable_retries,
disable_circuit_breaker: self.disable_circuit_breaker,
health_check: config::HealthCheckConfig {
failure_threshold: self.health_failure_threshold,
success_threshold: self.health_success_threshold,
......@@ -263,6 +269,9 @@ impl Router {
health_check_endpoint = String::from("/health"),
// IGW defaults
enable_igw = false,
queue_size = 100,
queue_timeout_secs = 60,
rate_limit_tokens_per_second = None,
))]
#[allow(clippy::too_many_arguments)]
fn new(
......@@ -317,6 +326,9 @@ impl Router {
health_check_interval_secs: u64,
health_check_endpoint: String,
enable_igw: bool,
queue_size: usize,
queue_timeout_secs: u64,
rate_limit_tokens_per_second: Option<usize>,
) -> PyResult<Self> {
Ok(Router {
host,
......@@ -370,6 +382,9 @@ impl Router {
health_check_interval_secs,
health_check_endpoint,
enable_igw,
queue_size,
queue_timeout_secs,
rate_limit_tokens_per_second,
})
}
......
......@@ -394,6 +394,8 @@ impl CliArgs {
Some(self.request_id_headers.clone())
},
max_concurrent_requests: self.max_concurrent_requests,
queue_size: 100, // Default queue size
queue_timeout_secs: 60, // Default timeout
cors_allowed_origins: self.cors_allowed_origins.clone(),
retry: RetryConfig {
max_retries: self.retry_max_retries,
......@@ -418,6 +420,7 @@ impl CliArgs {
endpoint: self.health_check_endpoint.clone(),
},
enable_igw: self.enable_igw,
rate_limit_tokens_per_second: None,
})
}
......
use axum::{extract::Request, http::HeaderValue, response::Response};
use axum::{
extract::Request, extract::State, http::HeaderValue, http::StatusCode, middleware::Next,
response::IntoResponse, response::Response,
};
use rand::Rng;
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;
use tokio::sync::{mpsc, oneshot};
use tower::{Layer, Service};
use tower_http::trace::{MakeSpan, OnRequest, OnResponse, TraceLayer};
use tracing::{field::Empty, info_span, Span};
use tracing::{debug, error, field::Empty, info, info_span, warn, Span};
pub use crate::core::token_bucket::TokenBucket;
use crate::server::AppState;
/// Generate OpenAI-compatible request ID based on endpoint
fn generate_request_id(path: &str) -> String {
......@@ -313,3 +322,181 @@ pub fn log_request(entry: RequestLogEntry) {
);
}
}
// ============ Concurrency Limiting with Queue Support ============
/// Request queue entry
pub struct QueuedRequest {
/// Time when the request was queued
queued_at: Instant,
/// Channel to send the permit back when acquired
permit_tx: oneshot::Sender<Result<(), StatusCode>>,
}
/// Queue metrics for monitoring
#[derive(Debug, Default)]
pub struct QueueMetrics {
pub total_queued: std::sync::atomic::AtomicU64,
pub current_queued: std::sync::atomic::AtomicU64,
pub total_timeout: std::sync::atomic::AtomicU64,
pub total_rejected: std::sync::atomic::AtomicU64,
}
/// Queue processor that handles queued requests
pub struct QueueProcessor {
token_bucket: Arc<TokenBucket>,
queue_rx: mpsc::Receiver<QueuedRequest>,
queue_timeout: Duration,
}
impl QueueProcessor {
pub fn new(
token_bucket: Arc<TokenBucket>,
queue_rx: mpsc::Receiver<QueuedRequest>,
queue_timeout: Duration,
) -> Self {
Self {
token_bucket,
queue_rx,
queue_timeout,
}
}
pub async fn run(mut self) {
info!("Starting concurrency queue processor");
// Process requests in a single task to reduce overhead
while let Some(queued) = self.queue_rx.recv().await {
// Check timeout immediately
let elapsed = queued.queued_at.elapsed();
if elapsed >= self.queue_timeout {
warn!("Request already timed out in queue");
let _ = queued.permit_tx.send(Err(StatusCode::REQUEST_TIMEOUT));
continue;
}
let remaining_timeout = self.queue_timeout - elapsed;
// Try to acquire token for this request
if self.token_bucket.try_acquire(1.0).await.is_ok() {
// Got token immediately
debug!("Queue: acquired token immediately for queued request");
let _ = queued.permit_tx.send(Ok(()));
} else {
// Need to wait for token
let token_bucket = self.token_bucket.clone();
// Spawn task only when we actually need to wait
tokio::spawn(async move {
if token_bucket
.acquire_timeout(1.0, remaining_timeout)
.await
.is_ok()
{
debug!("Queue: acquired token after waiting");
let _ = queued.permit_tx.send(Ok(()));
} else {
warn!("Queue: request timed out waiting for token");
let _ = queued.permit_tx.send(Err(StatusCode::REQUEST_TIMEOUT));
}
});
}
}
warn!("Concurrency queue processor shutting down");
}
}
/// State for the concurrency limiter
pub struct ConcurrencyLimiter {
pub queue_tx: Option<mpsc::Sender<QueuedRequest>>,
}
impl ConcurrencyLimiter {
/// Create new concurrency limiter with optional queue
pub fn new(
token_bucket: Arc<TokenBucket>,
queue_size: usize,
queue_timeout: Duration,
) -> (Self, Option<QueueProcessor>) {
if queue_size > 0 {
let (queue_tx, queue_rx) = mpsc::channel(queue_size);
let processor = QueueProcessor::new(token_bucket, queue_rx, queue_timeout);
(
Self {
queue_tx: Some(queue_tx),
},
Some(processor),
)
} else {
(Self { queue_tx: None }, None)
}
}
}
/// Middleware function for concurrency limiting with optional queuing
pub async fn concurrency_limit_middleware(
State(app_state): State<Arc<AppState>>,
request: Request<axum::body::Body>,
next: Next,
) -> Response {
let token_bucket = app_state.context.rate_limiter.clone();
// Try to acquire token immediately
if token_bucket.try_acquire(1.0).await.is_ok() {
debug!("Acquired token immediately");
let response = next.run(request).await;
// Return the token to the bucket
token_bucket.return_tokens(1.0).await;
response
} else {
// No tokens available, try to queue if enabled
if let Some(queue_tx) = &app_state.concurrency_queue_tx {
debug!("No tokens available, attempting to queue request");
// Create a channel for the token response
let (permit_tx, permit_rx) = oneshot::channel();
let queued = QueuedRequest {
queued_at: Instant::now(),
permit_tx,
};
// Try to send to queue
match queue_tx.try_send(queued) {
Ok(_) => {
// Wait for token from queue processor
match permit_rx.await {
Ok(Ok(())) => {
debug!("Acquired token from queue");
let response = next.run(request).await;
// Return the token to the bucket
token_bucket.return_tokens(1.0).await;
response
}
Ok(Err(status)) => {
warn!("Queue returned error status: {}", status);
status.into_response()
}
Err(_) => {
error!("Queue response channel closed");
StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
}
}
Err(_) => {
warn!("Request queue is full, returning 429");
StatusCode::TOO_MANY_REQUESTS.into_response()
}
}
} else {
warn!("No tokens available and queuing is disabled, returning 429");
StatusCode::TOO_MANY_REQUESTS.into_response()
}
}
}
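The flow above — fast-path `try_acquire`, queue on failure, processor grant or timeout — can be exercised with a test along these lines. This is a hypothetical sketch, not part of the commit; it has to live inside the middleware module because `QueuedRequest`'s fields are private:

```rust
#[cfg(test)]
mod queue_sketch_tests {
    use super::*;
    use std::sync::Arc;
    use std::time::{Duration, Instant};
    use tokio::sync::oneshot;

    #[tokio::test]
    async fn queued_request_admitted_after_refill() {
        // Capacity 1, refilling at 10 tokens/sec; queue of 8 with a 5 s timeout.
        let bucket = Arc::new(TokenBucket::new(1, 10));
        let (limiter, processor) =
            ConcurrencyLimiter::new(bucket.clone(), 8, Duration::from_secs(5));
        tokio::spawn(processor.expect("queue_size > 0 yields a processor").run());

        // Drain the only token, as an in-flight request would.
        assert!(bucket.try_acquire(1.0).await.is_ok());

        // Enqueue a request; the processor grants it once ~1 token refills (~100 ms).
        let (permit_tx, permit_rx) = oneshot::channel();
        let queued = QueuedRequest { queued_at: Instant::now(), permit_tx };
        limiter.queue_tx.as_ref().unwrap().send(queued).await.unwrap();
        assert_eq!(permit_rx.await.unwrap(), Ok(()));
    }
}
```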
use crate::config::RouterConfig;
use crate::logging::{self, LoggingConfig};
use crate::metrics::{self, PrometheusConfig};
use crate::middleware::TokenBucket;
use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use crate::routers::{RouterFactory, RouterTrait};
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
......@@ -25,7 +26,7 @@ use tracing::{error, info, warn, Level};
pub struct AppContext {
pub client: Client,
pub router_config: RouterConfig,
pub concurrency_limiter: Arc<tokio::sync::Semaphore>,
pub rate_limiter: Arc<TokenBucket>,
// Future dependencies can be added here
}
......@@ -34,12 +35,14 @@ impl AppContext {
router_config: RouterConfig,
client: Client,
max_concurrent_requests: usize,
rate_limit_tokens_per_second: Option<usize>,
) -> Self {
let concurrency_limiter = Arc::new(tokio::sync::Semaphore::new(max_concurrent_requests));
let rate_limit_tokens = rate_limit_tokens_per_second.unwrap_or(max_concurrent_requests);
let rate_limiter = Arc::new(TokenBucket::new(max_concurrent_requests, rate_limit_tokens));
Self {
client,
router_config,
concurrency_limiter,
rate_limiter,
}
}
}
......@@ -48,6 +51,7 @@ impl AppContext {
pub struct AppState {
pub router: Arc<dyn RouterTrait>,
pub context: Arc<AppContext>,
pub concurrency_queue_tx: Option<tokio::sync::mpsc::Sender<crate::middleware::QueuedRequest>>,
}
// Fallback handler for unmatched routes
......@@ -186,7 +190,11 @@ pub fn build_app(
let protected_routes = Router::new()
.route("/generate", post(generate))
.route("/v1/chat/completions", post(v1_chat_completions))
.route("/v1/completions", post(v1_completions));
.route("/v1/completions", post(v1_completions))
.route_layer(axum::middleware::from_fn_with_state(
app_state.clone(),
crate::middleware::concurrency_limit_middleware,
));
let public_routes = Router::new()
.route("/liveness", get(liveness))
......@@ -282,15 +290,33 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
config.router_config.clone(),
client.clone(),
config.router_config.max_concurrent_requests,
config.router_config.rate_limit_tokens_per_second,
));
// Create router with the context
let router = RouterFactory::create_router(&app_context).await?;
// Set up concurrency limiter with queue if configured
let (limiter, processor) = crate::middleware::ConcurrencyLimiter::new(
app_context.rate_limiter.clone(),
config.router_config.queue_size,
Duration::from_secs(config.router_config.queue_timeout_secs),
);
// Start queue processor if enabled
if let Some(processor) = processor {
tokio::spawn(processor.run());
info!(
"Started request queue with size: {}, timeout: {}s",
config.router_config.queue_size, config.router_config.queue_timeout_secs
);
}
// Create app state with router and context
let app_state = Arc::new(AppState {
router: Arc::from(router),
context: app_context.clone(),
concurrency_queue_tx: limiter.queue_tx.clone(),
});
let router_arc = Arc::clone(&app_state.router);
......
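Back-of-the-envelope for the wiring above: with bucket capacity C = `max_concurrent_requests` and refill rate R = `rate_limit_tokens_per_second` (defaulting to C), a cold burst admits C requests immediately and the remainder at roughly R per second until tokens start coming back. A hypothetical helper, ignoring tokens returned by requests that finish mid-burst:

```rust
/// Rough admission delay of the i-th request (0-based) in a cold burst
/// against a full bucket, assuming no tokens are returned meanwhile.
fn admit_after_secs(i: usize, capacity: usize, refill_per_sec: usize) -> f64 {
    if i < capacity {
        0.0 // served from the initial burst capacity
    } else {
        (i + 1 - capacity) as f64 / refill_per_sec as f64
    }
}

fn main() {
    // Defaults (capacity 256, refill 256/s): request #300 waits ~0.18 s,
    // comfortably under queue_timeout_secs = 60.
    println!("{:.2}", admit_after_secs(300, 256, 256));
}
```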
......@@ -45,6 +45,9 @@ impl TestContext {
log_level: None,
request_id_headers: None,
max_concurrent_requests: 64,
queue_size: 0,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
......@@ -1088,6 +1091,9 @@ mod error_tests {
log_level: None,
request_id_headers: None,
max_concurrent_requests: 64,
queue_size: 0,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
......@@ -1440,6 +1446,9 @@ mod pd_mode_tests {
log_level: None,
request_id_headers: None,
max_concurrent_requests: 64,
queue_size: 0,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
......@@ -1596,6 +1605,9 @@ mod request_id_tests {
log_level: None,
request_id_headers: Some(vec!["custom-id".to_string(), "trace-id".to_string()]),
max_concurrent_requests: 64,
queue_size: 0,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
......
......@@ -16,6 +16,7 @@ pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
config.clone(),
reqwest::Client::new(),
config.max_concurrent_requests,
config.rate_limit_tokens_per_second,
))
}
......
......@@ -19,12 +19,14 @@ pub fn create_test_app(
router_config.clone(),
client,
router_config.max_concurrent_requests,
router_config.rate_limit_tokens_per_second,
));
// Create AppState with the test router and context
let app_state = Arc::new(AppState {
router,
context: app_context,
concurrency_queue_tx: None, // No queue for tests
});
// Configure request ID headers (use defaults if not specified)
......
......@@ -36,6 +36,9 @@ impl TestContext {
log_level: None,
request_id_headers: None,
max_concurrent_requests: 64,
queue_size: 0,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
......
......@@ -37,6 +37,9 @@ impl TestContext {
log_level: None,
request_id_headers: None,
max_concurrent_requests: 64,
queue_size: 0,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
......
......@@ -178,6 +178,8 @@ mod test_pd_routing {
log_level: None,
request_id_headers: None,
max_concurrent_requests: 64,
queue_size: 0,
queue_timeout_secs: 60,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
......@@ -185,11 +187,12 @@ mod test_pd_routing {
disable_circuit_breaker: false,
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
enable_igw: false,
rate_limit_tokens_per_second: None,
};
// Router creation will fail due to health checks, but config should be valid
let app_context =
sglang_router_rs::server::AppContext::new(config, reqwest::Client::new(), 64);
sglang_router_rs::server::AppContext::new(config, reqwest::Client::new(), 64, None);
let app_context = std::sync::Arc::new(app_context);
let result = RouterFactory::create_router(&app_context).await;
assert!(result.is_err());
......