"torch_scatter/src/gpu.c" did not exist on "85a258235fe64d2d6c2b415bfe99de07d930d790"
Unverified Commit dc01313d authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

[router] Add rustfmt and set group imports by default (#11732)

parent 7a7f99be
use std::{
fmt::{Display, Formatter},
sync::Arc,
};
use async_trait::async_trait; use async_trait::async_trait;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use rand::RngCore; use rand::RngCore;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::{Map as JsonMap, Value}; use serde_json::{Map as JsonMap, Value};
use std::fmt::{Display, Formatter};
use std::sync::Arc;
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)] #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
pub struct ConversationId(pub String); pub struct ConversationId(pub String);
......
use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait; use async_trait::async_trait;
use parking_lot::RwLock; use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use super::responses::{ResponseChain, ResponseId, ResponseStorage, Result, StoredResponse}; use super::responses::{ResponseChain, ResponseId, ResponseStorage, Result, StoredResponse};
......
use crate::config::OracleConfig; use std::{collections::HashMap, path::Path, sync::Arc, time::Duration};
use crate::data_connector::responses::{
ResponseChain, ResponseId, ResponseStorage, ResponseStorageError, Result as StorageResult,
StoredResponse,
};
use async_trait::async_trait; use async_trait::async_trait;
use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult}; use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult};
use oracle::{Connection, Row}; use oracle::{Connection, Row};
use serde_json::Value; use serde_json::Value;
use std::collections::HashMap;
use std::path::Path; use crate::{
use std::sync::Arc; config::OracleConfig,
use std::time::Duration; data_connector::responses::{
ResponseChain, ResponseId, ResponseStorage, ResponseStorageError, Result as StorageResult,
StoredResponse,
},
};
const SELECT_BASE: &str = "SELECT id, previous_response_id, input, instructions, output, \ const SELECT_BASE: &str = "SELECT id, previous_response_id, input, instructions, output, \
tool_calls, metadata, created_at, user_id, model, conversation_id, raw_response FROM responses"; tool_calls, metadata, created_at, user_id, model, conversation_id, raw_response FROM responses";
...@@ -510,9 +511,10 @@ impl OracleErrorExt for ResponseStorageError { ...@@ -510,9 +511,10 @@ impl OracleErrorExt for ResponseStorageError {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*;
use serde_json::json; use serde_json::json;
use super::*;
#[test] #[test]
fn parse_tool_calls_handles_empty_input() { fn parse_tool_calls_handles_empty_input() {
assert!(parse_tool_calls(None).unwrap().is_empty()); assert!(parse_tool_calls(None).unwrap().is_empty());
......
use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait; use async_trait::async_trait;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
/// Response identifier /// Response identifier
#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
......
use std::convert::TryFrom; use std::{
use std::pin::Pin; convert::TryFrom,
use std::sync::atomic::{AtomicBool, Ordering}; pin::Pin,
use std::sync::Arc; sync::{
use std::task::{Context, Poll}; atomic::{AtomicBool, Ordering},
use std::time::Duration; Arc,
},
task::{Context, Poll},
time::Duration,
};
use tonic::{transport::Channel, Request, Streaming}; use tonic::{transport::Channel, Request, Streaming};
use tracing::{debug, warn}; use tracing::{debug, warn};
use crate::protocols::chat::ChatCompletionRequest; use crate::protocols::{
use crate::protocols::common::{ResponseFormat, StringOrArray, ToolChoice, ToolChoiceValue}; chat::ChatCompletionRequest,
use crate::protocols::generate::GenerateRequest; common::{ResponseFormat, StringOrArray, ToolChoice, ToolChoiceValue},
use crate::protocols::sampling_params::SamplingParams as GenerateSamplingParams; generate::GenerateRequest,
sampling_params::SamplingParams as GenerateSamplingParams,
};
// Include the generated protobuf code // Include the generated protobuf code
pub mod proto { pub mod proto {
......
use std::path::PathBuf; use std::path::PathBuf;
use tracing::Level; use tracing::Level;
use tracing_appender::non_blocking::WorkerGuard; use tracing_appender::{
use tracing_appender::rolling::{RollingFileAppender, Rotation}; non_blocking::WorkerGuard,
rolling::{RollingFileAppender, Rotation},
};
use tracing_log::LogTracer; use tracing_log::LogTracer;
use tracing_subscriber::fmt::time::ChronoUtc; use tracing_subscriber::{
use tracing_subscriber::layer::SubscriberExt; fmt::time::ChronoUtc, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer,
use tracing_subscriber::util::SubscriberInitExt; };
use tracing_subscriber::{EnvFilter, Layer};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct LoggingConfig { pub struct LoggingConfig {
......
use std::collections::HashMap;
use clap::{ArgAction, Parser, ValueEnum}; use clap::{ArgAction, Parser, ValueEnum};
use sglang_router_rs::config::{ use sglang_router_rs::{
config::{
CircuitBreakerConfig, ConfigError, ConfigResult, ConnectionMode, DiscoveryConfig, CircuitBreakerConfig, ConfigError, ConfigResult, ConnectionMode, DiscoveryConfig,
HealthCheckConfig, HistoryBackend, MetricsConfig, OracleConfig, PolicyConfig, RetryConfig, HealthCheckConfig, HistoryBackend, MetricsConfig, OracleConfig, PolicyConfig, RetryConfig,
RouterConfig, RoutingMode, RouterConfig, RoutingMode,
},
metrics::PrometheusConfig,
server::{self, ServerConfig},
service_discovery::ServiceDiscoveryConfig,
}; };
use sglang_router_rs::metrics::PrometheusConfig;
use sglang_router_rs::server::{self, ServerConfig};
use sglang_router_rs::service_discovery::ServiceDiscoveryConfig;
use std::collections::HashMap;
fn parse_prefill_args() -> Vec<(String, Option<u16>)> { fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
let args: Vec<String> = std::env::args().collect(); let args: Vec<String> = std::env::args().collect();
......
use std::{borrow::Cow, collections::HashMap, time::Duration};
use backoff::ExponentialBackoffBuilder; use backoff::ExponentialBackoffBuilder;
use dashmap::DashMap; use dashmap::DashMap;
use rmcp::{ use rmcp::{
...@@ -13,7 +15,6 @@ use rmcp::{ ...@@ -13,7 +15,6 @@ use rmcp::{
RoleClient, ServiceExt, RoleClient, ServiceExt,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::{borrow::Cow, collections::HashMap, time::Duration};
use crate::mcp::{ use crate::mcp::{
config::{McpConfig, McpServerConfig, McpTransport}, config::{McpConfig, McpServerConfig, McpTransport},
......
use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Deserialize, Serialize)] #[derive(Debug, Clone, Deserialize, Serialize)]
pub struct McpConfig { pub struct McpConfig {
pub servers: Vec<McpServerConfig>, pub servers: Vec<McpServerConfig>,
......
// OAuth authentication support for MCP servers // OAuth authentication support for MCP servers
use std::{net::SocketAddr, sync::Arc};
use axum::{ use axum::{
extract::{Query, State}, extract::{Query, State},
response::Html, response::Html,
...@@ -8,7 +10,6 @@ use axum::{ ...@@ -8,7 +10,6 @@ use axum::{
}; };
use rmcp::transport::auth::OAuthState; use rmcp::transport::auth::OAuthState;
use serde::Deserialize; use serde::Deserialize;
use std::{net::SocketAddr, sync::Arc};
use tokio::sync::{oneshot, Mutex}; use tokio::sync::{oneshot, Mutex};
use crate::mcp::error::{McpError, McpResult}; use crate::mcp::error::{McpError, McpResult};
......
use std::{
net::{IpAddr, Ipv4Addr, SocketAddr},
time::Duration,
};
use metrics::{counter, describe_counter, describe_gauge, describe_histogram, gauge, histogram}; use metrics::{counter, describe_counter, describe_gauge, describe_histogram, gauge, histogram};
use metrics_exporter_prometheus::{Matcher, PrometheusBuilder}; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::time::Duration;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct PrometheusConfig { pub struct PrometheusConfig {
...@@ -620,9 +623,10 @@ impl TokenizerMetrics { ...@@ -620,9 +623,10 @@ impl TokenizerMetrics {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*;
use std::net::TcpListener; use std::net::TcpListener;
use super::*;
#[test] #[test]
fn test_prometheus_config_default() { fn test_prometheus_config_default() {
let config = PrometheusConfig::default(); let config = PrometheusConfig::default();
...@@ -912,9 +916,13 @@ mod tests { ...@@ -912,9 +916,13 @@ mod tests {
#[test] #[test]
fn test_concurrent_metric_updates() { fn test_concurrent_metric_updates() {
use std::sync::atomic::{AtomicBool, Ordering}; use std::{
use std::sync::Arc; sync::{
use std::thread; atomic::{AtomicBool, Ordering},
Arc,
},
thread,
};
let done = Arc::new(AtomicBool::new(false)); let done = Arc::new(AtomicBool::new(false));
let mut handles = vec![]; let mut handles = vec![];
......
use std::{
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
time::{Duration, Instant},
};
use axum::{ use axum::{
body::Body, extract::Request, extract::State, http::header, http::HeaderValue, body::Body,
http::StatusCode, middleware::Next, response::IntoResponse, response::Response, extract::{Request, State},
http::{header, HeaderValue, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
}; };
use rand::Rng; use rand::Rng;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;
use subtle::ConstantTimeEq; use subtle::ConstantTimeEq;
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot};
use tower::{Layer, Service}; use tower::{Layer, Service};
...@@ -14,9 +21,7 @@ use tower_http::trace::{MakeSpan, OnRequest, OnResponse, TraceLayer}; ...@@ -14,9 +21,7 @@ use tower_http::trace::{MakeSpan, OnRequest, OnResponse, TraceLayer};
use tracing::{debug, error, field::Empty, info, info_span, warn, Span}; use tracing::{debug, error, field::Empty, info, info_span, warn, Span};
pub use crate::core::token_bucket::TokenBucket; pub use crate::core::token_bucket::TokenBucket;
use crate::{metrics::RouterMetrics, server::AppState};
use crate::metrics::RouterMetrics;
use crate::server::AppState;
#[derive(Clone)] #[derive(Clone)]
pub struct AuthConfig { pub struct AuthConfig {
......
...@@ -59,17 +59,15 @@ ...@@ -59,17 +59,15 @@
during the next eviction cycle. during the next eviction cycle.
*/ */
use super::{get_healthy_worker_indices, CacheAwareConfig, LoadBalancingPolicy}; use std::{sync::Arc, thread, time::Duration};
use crate::core::Worker;
use crate::metrics::RouterMetrics;
use crate::tree::Tree;
use dashmap::DashMap; use dashmap::DashMap;
use rand::Rng; use rand::Rng;
use std::sync::Arc;
use std::thread;
use std::time::Duration;
use tracing::debug; use tracing::debug;
use super::{get_healthy_worker_indices, CacheAwareConfig, LoadBalancingPolicy};
use crate::{core::Worker, metrics::RouterMetrics, tree::Tree};
/// Cache-aware routing policy /// Cache-aware routing policy
/// ///
/// Routes requests based on cache affinity when load is balanced, /// Routes requests based on cache affinity when load is balanced,
......
//! Factory for creating load balancing policies //! Factory for creating load balancing policies
use std::sync::Arc;
use super::{ use super::{
CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy, PowerOfTwoPolicy, RandomPolicy, CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy, PowerOfTwoPolicy, RandomPolicy,
RoundRobinPolicy, RoundRobinPolicy,
}; };
use crate::config::PolicyConfig; use crate::config::PolicyConfig;
use std::sync::Arc;
/// Factory for creating policy instances /// Factory for creating policy instances
pub struct PolicyFactory; pub struct PolicyFactory;
......
...@@ -3,9 +3,9 @@ ...@@ -3,9 +3,9 @@
//! This module provides a unified abstraction for routing policies that work //! This module provides a unified abstraction for routing policies that work
//! across both regular and prefill-decode (PD) routing modes. //! across both regular and prefill-decode (PD) routing modes.
use std::{fmt::Debug, sync::Arc};
use crate::core::Worker; use crate::core::Worker;
use std::fmt::Debug;
use std::sync::Arc;
mod cache_aware; mod cache_aware;
mod factory; mod factory;
......
//! Power-of-two choices load balancing policy //! Power-of-two choices load balancing policy
use super::{get_healthy_worker_indices, LoadBalancingPolicy}; use std::{
use crate::core::Worker; collections::HashMap,
use crate::metrics::RouterMetrics; sync::{Arc, RwLock},
};
use rand::Rng; use rand::Rng;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tracing::info; use tracing::info;
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
use crate::{core::Worker, metrics::RouterMetrics};
/// Power-of-two choices policy /// Power-of-two choices policy
/// ///
/// Randomly selects two workers and routes to the one with lower load. /// Randomly selects two workers and routes to the one with lower load.
......
//! Random load balancing policy //! Random load balancing policy
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
use crate::core::Worker;
use crate::metrics::RouterMetrics;
use rand::Rng;
use std::sync::Arc; use std::sync::Arc;
use rand::Rng;
use super::{get_healthy_worker_indices, LoadBalancingPolicy};
use crate::{core::Worker, metrics::RouterMetrics};
/// Random selection policy /// Random selection policy
/// ///
/// Selects workers randomly with uniform distribution among healthy workers. /// Selects workers randomly with uniform distribution among healthy workers.
...@@ -50,9 +51,10 @@ impl LoadBalancingPolicy for RandomPolicy { ...@@ -50,9 +51,10 @@ impl LoadBalancingPolicy for RandomPolicy {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::collections::HashMap;
use super::*; use super::*;
use crate::core::{BasicWorkerBuilder, WorkerType}; use crate::core::{BasicWorkerBuilder, WorkerType};
use std::collections::HashMap;
#[test] #[test]
fn test_random_selection() { fn test_random_selection() {
......
use std::{
collections::HashMap,
sync::{Arc, RwLock},
};
use tracing::{debug, info, warn};
/// Policy Registry for managing model-to-policy mappings /// Policy Registry for managing model-to-policy mappings
/// ///
/// This registry manages the dynamic assignment of load balancing policies to models. /// This registry manages the dynamic assignment of load balancing policies to models.
...@@ -8,11 +15,7 @@ use super::{ ...@@ -8,11 +15,7 @@ use super::{
CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy, PowerOfTwoPolicy, RandomPolicy, CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy, PowerOfTwoPolicy, RandomPolicy,
RoundRobinPolicy, RoundRobinPolicy,
}; };
use crate::config::types::PolicyConfig; use crate::{config::types::PolicyConfig, core::Worker};
use crate::core::Worker;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tracing::{debug, info, warn};
/// Registry for managing model-to-policy mappings /// Registry for managing model-to-policy mappings
#[derive(Clone)] #[derive(Clone)]
......
//! Round-robin load balancing policy //! Round-robin load balancing policy
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
use super::{get_healthy_worker_indices, LoadBalancingPolicy}; use super::{get_healthy_worker_indices, LoadBalancingPolicy};
use crate::core::Worker; use crate::{core::Worker, metrics::RouterMetrics};
use crate::metrics::RouterMetrics;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
/// Round-robin selection policy /// Round-robin selection policy
/// ///
......
use std::collections::HashMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
use std::collections::HashMap;
use validator::Validate; use validator::Validate;
use super::common::*; use super::{
use super::sampling_params::{validate_top_k_value, validate_top_p_value}; common::*,
sampling_params::{validate_top_k_value, validate_top_p_value},
};
use crate::protocols::validated::Normalizable; use crate::protocols::validated::Normalizable;
// ============================================================================ // ============================================================================
...@@ -532,11 +535,12 @@ impl Normalizable for ChatCompletionRequest { ...@@ -532,11 +535,12 @@ impl Normalizable for ChatCompletionRequest {
// Apply tool_choice defaults // Apply tool_choice defaults
if self.tool_choice.is_none() { if self.tool_choice.is_none() {
if let Some(tools) = &self.tools { if let Some(tools) = &self.tools {
self.tool_choice = if !tools.is_empty() { let choice_value = if !tools.is_empty() {
Some(ToolChoice::Value(ToolChoiceValue::Auto)) ToolChoiceValue::Auto
} else { } else {
Some(ToolChoice::Value(ToolChoiceValue::None)) ToolChoiceValue::None
}; };
self.tool_choice = Some(ToolChoice::Value(choice_value));
} }
// If tools is None, leave tool_choice as None (don't set it) // If tools is None, leave tool_choice as None (don't set it)
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment