Unverified Commit dc01313d authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

[router] Add rustfmt and set group imports by default (#11732)

parent 7a7f99be
use std::{
fmt::{Display, Formatter},
sync::Arc,
};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use rand::RngCore;
use serde::{Deserialize, Serialize};
use serde_json::{Map as JsonMap, Value};
use std::fmt::{Display, Formatter};
use std::sync::Arc;
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
pub struct ConversationId(pub String);
......
use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use super::responses::{ResponseChain, ResponseId, ResponseStorage, Result, StoredResponse};
......
use crate::config::OracleConfig;
use crate::data_connector::responses::{
ResponseChain, ResponseId, ResponseStorage, ResponseStorageError, Result as StorageResult,
StoredResponse,
};
use std::{collections::HashMap, path::Path, sync::Arc, time::Duration};
use async_trait::async_trait;
use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult};
use oracle::{Connection, Row};
use serde_json::Value;
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
use crate::{
config::OracleConfig,
data_connector::responses::{
ResponseChain, ResponseId, ResponseStorage, ResponseStorageError, Result as StorageResult,
StoredResponse,
},
};
const SELECT_BASE: &str = "SELECT id, previous_response_id, input, instructions, output, \
tool_calls, metadata, created_at, user_id, model, conversation_id, raw_response FROM responses";
......@@ -510,9 +511,10 @@ impl OracleErrorExt for ResponseStorageError {
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
use super::*;
#[test]
fn parse_tool_calls_handles_empty_input() {
assert!(parse_tool_calls(None).unwrap().is_empty());
......
use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
/// Response identifier
#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
......
use std::convert::TryFrom;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use std::{
convert::TryFrom,
pin::Pin,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
task::{Context, Poll},
time::Duration,
};
use tonic::{transport::Channel, Request, Streaming};
use tracing::{debug, warn};
use crate::protocols::chat::ChatCompletionRequest;
use crate::protocols::common::{ResponseFormat, StringOrArray, ToolChoice, ToolChoiceValue};
use crate::protocols::generate::GenerateRequest;
use crate::protocols::sampling_params::SamplingParams as GenerateSamplingParams;
use crate::protocols::{
chat::ChatCompletionRequest,
common::{ResponseFormat, StringOrArray, ToolChoice, ToolChoiceValue},
generate::GenerateRequest,
sampling_params::SamplingParams as GenerateSamplingParams,
};
// Include the generated protobuf code
pub mod proto {
......
use std::path::PathBuf;
use tracing::Level;
use tracing_appender::non_blocking::WorkerGuard;
use tracing_appender::rolling::{RollingFileAppender, Rotation};
use tracing_appender::{
non_blocking::WorkerGuard,
rolling::{RollingFileAppender, Rotation},
};
use tracing_log::LogTracer;
use tracing_subscriber::fmt::time::ChronoUtc;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Layer};
use tracing_subscriber::{
fmt::time::ChronoUtc, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer,
};
#[derive(Debug, Clone)]
pub struct LoggingConfig {
......
use std::collections::HashMap;
use clap::{ArgAction, Parser, ValueEnum};
use sglang_router_rs::config::{
CircuitBreakerConfig, ConfigError, ConfigResult, ConnectionMode, DiscoveryConfig,
HealthCheckConfig, HistoryBackend, MetricsConfig, OracleConfig, PolicyConfig, RetryConfig,
RouterConfig, RoutingMode,
use sglang_router_rs::{
config::{
CircuitBreakerConfig, ConfigError, ConfigResult, ConnectionMode, DiscoveryConfig,
HealthCheckConfig, HistoryBackend, MetricsConfig, OracleConfig, PolicyConfig, RetryConfig,
RouterConfig, RoutingMode,
},
metrics::PrometheusConfig,
server::{self, ServerConfig},
service_discovery::ServiceDiscoveryConfig,
};
use sglang_router_rs::metrics::PrometheusConfig;
use sglang_router_rs::server::{self, ServerConfig};
use sglang_router_rs::service_discovery::ServiceDiscoveryConfig;
use std::collections::HashMap;
fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
let args: Vec<String> = std::env::args().collect();
......
use std::{borrow::Cow, collections::HashMap, time::Duration};
use backoff::ExponentialBackoffBuilder;
use dashmap::DashMap;
use rmcp::{
......@@ -13,7 +15,6 @@ use rmcp::{
RoleClient, ServiceExt,
};
use serde::{Deserialize, Serialize};
use std::{borrow::Cow, collections::HashMap, time::Duration};
use crate::mcp::{
config::{McpConfig, McpServerConfig, McpTransport},
......
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpConfig {
pub servers: Vec<McpServerConfig>,
......
// OAuth authentication support for MCP servers
use std::{net::SocketAddr, sync::Arc};
use axum::{
extract::{Query, State},
response::Html,
......@@ -8,7 +10,6 @@ use axum::{
};
use rmcp::transport::auth::OAuthState;
use serde::Deserialize;
use std::{net::SocketAddr, sync::Arc};
use tokio::sync::{oneshot, Mutex};
use crate::mcp::error::{McpError, McpResult};
......
use std::{
net::{IpAddr, Ipv4Addr, SocketAddr},
time::Duration,
};
use metrics::{counter, describe_counter, describe_gauge, describe_histogram, gauge, histogram};
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::time::Duration;
#[derive(Debug, Clone)]
pub struct PrometheusConfig {
......@@ -620,9 +623,10 @@ impl TokenizerMetrics {
#[cfg(test)]
mod tests {
use super::*;
use std::net::TcpListener;
use super::*;
#[test]
fn test_prometheus_config_default() {
let config = PrometheusConfig::default();
......@@ -912,9 +916,13 @@ mod tests {
#[test]
fn test_concurrent_metric_updates() {
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;
use std::{
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
thread,
};
let done = Arc::new(AtomicBool::new(false));
let mut handles = vec![];
......
use std::{
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
time::{Duration, Instant},
};
use axum::{
body::Body, extract::Request, extract::State, http::header, http::HeaderValue,
http::StatusCode, middleware::Next, response::IntoResponse, response::Response,
body::Body,
extract::{Request, State},
http::{header, HeaderValue, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use rand::Rng;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;
use subtle::ConstantTimeEq;
use tokio::sync::{mpsc, oneshot};
use tower::{Layer, Service};
......@@ -14,9 +21,7 @@ use tower_http::trace::{MakeSpan, OnRequest, OnResponse, TraceLayer};
use tracing::{debug, error, field::Empty, info, info_span, warn, Span};
pub use crate::core::token_bucket::TokenBucket;
use crate::metrics::RouterMetrics;
use crate::server::AppState;
use crate::{metrics::RouterMetrics, server::AppState};
#[derive(Clone)]
pub struct AuthConfig {
......
......@@ -59,17 +59,15 @@
during the next eviction cycle.
*/
use super::{get_healthy_worker_indices, CacheAwareConfig, LoadBalancingPolicy};
use crate::core::Worker;
use crate::metrics::RouterMetrics;
use crate::tree::Tree;
use std::{sync::Arc, thread, time::Duration};
use dashmap::DashMap;
use rand::Rng;
use std::sync::Arc;
use std::thread;
use std::time::Duration;
use tracing::debug;
use super::{get_healthy_worker_indices, CacheAwareConfig, LoadBalancingPolicy};
use crate::{core::Worker, metrics::RouterMetrics, tree::Tree};
/// Cache-aware routing policy
///
/// Routes requests based on cache affinity when load is balanced,
......
//! Factory for creating load balancing policies
use std::sync::Arc;
use super::{
CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy, PowerOfTwoPolicy, RandomPolicy,
RoundRobinPolicy,
};
use crate::config::PolicyConfig;
use std::sync::Arc;
/// Factory for creating policy instances
pub struct PolicyFactory;
......
......@@ -3,9 +3,9 @@
//! This module provides a unified abstraction for routing policies that work
//! across both regular and prefill-decode (PD) routing modes.
use std::{fmt::Debug, sync::Arc};
use crate::core::Worker;
use std::fmt::Debug;
use std::sync::Arc;
mod cache_aware;
mod factory;
......
//! Power-of-two choices load balancing policy
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
use crate::core::Worker;
use crate::metrics::RouterMetrics;
use std::{
collections::HashMap,
sync::{Arc, RwLock},
};
use rand::Rng;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tracing::info;
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
use crate::{core::Worker, metrics::RouterMetrics};
/// Power-of-two choices policy
///
/// Randomly selects two workers and routes to the one with lower load.
......
//! Random load balancing policy
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
use crate::core::Worker;
use crate::metrics::RouterMetrics;
use rand::Rng;
use std::sync::Arc;
use rand::Rng;
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
use crate::{core::Worker, metrics::RouterMetrics};
/// Random selection policy
///
/// Selects workers randomly with uniform distribution among healthy workers.
......@@ -50,9 +51,10 @@ impl LoadBalancingPolicy for RandomPolicy {
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use super::*;
use crate::core::{BasicWorkerBuilder, WorkerType};
use std::collections::HashMap;
#[test]
fn test_random_selection() {
......
use std::{
collections::HashMap,
sync::{Arc, RwLock},
};
use tracing::{debug, info, warn};
/// Policy Registry for managing model-to-policy mappings
///
/// This registry manages the dynamic assignment of load balancing policies to models.
......@@ -8,11 +15,7 @@ use super::{
CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy, PowerOfTwoPolicy, RandomPolicy,
RoundRobinPolicy,
};
use crate::config::types::PolicyConfig;
use crate::core::Worker;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tracing::{debug, info, warn};
use crate::{config::types::PolicyConfig, core::Worker};
/// Registry for managing model-to-policy mappings
#[derive(Clone)]
......
//! Round-robin load balancing policy
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
use crate::core::Worker;
use crate::metrics::RouterMetrics;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use crate::{core::Worker, metrics::RouterMetrics};
/// Round-robin selection policy
///
......
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use validator::Validate;
use super::common::*;
use super::sampling_params::{validate_top_k_value, validate_top_p_value};
use super::{
common::*,
sampling_params::{validate_top_k_value, validate_top_p_value},
};
use crate::protocols::validated::Normalizable;
// ============================================================================
......@@ -532,11 +535,12 @@ impl Normalizable for ChatCompletionRequest {
// Apply tool_choice defaults
if self.tool_choice.is_none() {
if let Some(tools) = &self.tools {
self.tool_choice = if !tools.is_empty() {
Some(ToolChoice::Value(ToolChoiceValue::Auto))
let choice_value = if !tools.is_empty() {
ToolChoiceValue::Auto
} else {
Some(ToolChoice::Value(ToolChoiceValue::None))
ToolChoiceValue::None
};
self.tool_choice = Some(ToolChoice::Value(choice_value));
}
// If tools is None, leave tool_choice as None (don't set it)
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment