Unverified Commit dc01313d authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

[router] Add rustfmt and set group imports by default (#11732)

parent 7a7f99be
...@@ -54,7 +54,9 @@ jobs: ...@@ -54,7 +54,9 @@ jobs:
run: | run: |
source "$HOME/.cargo/env" source "$HOME/.cargo/env"
cd sgl-router/ cd sgl-router/
cargo fmt -- --check rustup component add --toolchain nightly-x86_64-unknown-linux-gnu rustfmt
rustup toolchain install nightly --profile minimal
cargo +nightly fmt -- --check
- name: Run Rust tests - name: Run Rust tests
timeout-minutes: 20 timeout-minutes: 20
......
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use serde_json::{from_str, to_string, to_value, to_vec};
use std::time::Instant; use std::time::Instant;
use sglang_router_rs::core::{BasicWorker, BasicWorkerBuilder, Worker, WorkerType}; use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use sglang_router_rs::protocols::chat::{ChatCompletionRequest, ChatMessage, UserMessageContent}; use serde_json::{from_str, to_string, to_value, to_vec};
use sglang_router_rs::protocols::common::StringOrArray; use sglang_router_rs::{
use sglang_router_rs::protocols::completion::CompletionRequest; core::{BasicWorker, BasicWorkerBuilder, Worker, WorkerType},
use sglang_router_rs::protocols::generate::GenerateRequest; protocols::{
use sglang_router_rs::protocols::sampling_params::SamplingParams; chat::{ChatCompletionRequest, ChatMessage, UserMessageContent},
use sglang_router_rs::routers::http::pd_types::{generate_room_id, RequestWithBootstrap}; common::StringOrArray,
completion::CompletionRequest,
generate::GenerateRequest,
sampling_params::SamplingParams,
},
routers::http::pd_types::{generate_room_id, RequestWithBootstrap},
};
fn create_test_worker() -> BasicWorker { fn create_test_worker() -> BasicWorker {
BasicWorkerBuilder::new("http://test-server:8000") BasicWorkerBuilder::new("http://test-server:8000")
......
//! Comprehensive tokenizer benchmark with clean summary output //! Comprehensive tokenizer benchmark with clean summary output
//! Each test adds a row to the final summary table //! Each test adds a row to the final summary table
use std::{
collections::BTreeMap,
path::PathBuf,
sync::{
atomic::{AtomicBool, AtomicU64, Ordering},
Arc, Mutex, OnceLock,
},
thread,
time::{Duration, Instant},
};
use criterion::{black_box, criterion_group, BenchmarkId, Criterion, Throughput}; use criterion::{black_box, criterion_group, BenchmarkId, Criterion, Throughput};
use sglang_router_rs::tokenizer::{ use sglang_router_rs::tokenizer::{
huggingface::HuggingFaceTokenizer, sequence::Sequence, stop::*, stream::DecodeStream, traits::*, huggingface::HuggingFaceTokenizer, sequence::Sequence, stop::*, stream::DecodeStream, traits::*,
}; };
use std::collections::BTreeMap;
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use std::thread;
use std::time::{Duration, Instant};
// Include the common test utilities // Include the common test utilities
#[path = "../tests/common/mod.rs"] #[path = "../tests/common/mod.rs"]
......
...@@ -7,15 +7,22 @@ ...@@ -7,15 +7,22 @@
//! - Streaming vs complete parsing //! - Streaming vs complete parsing
//! - Different model formats (JSON, Mistral, Qwen, Pythonic, etc.) //! - Different model formats (JSON, Mistral, Qwen, Pythonic, etc.)
use std::{
collections::BTreeMap,
sync::{
atomic::{AtomicBool, AtomicU64, Ordering},
Arc, Mutex,
},
thread,
time::{Duration, Instant},
};
use criterion::{black_box, criterion_group, BenchmarkId, Criterion, Throughput}; use criterion::{black_box, criterion_group, BenchmarkId, Criterion, Throughput};
use serde_json::json; use serde_json::json;
use sglang_router_rs::protocols::common::{Function, Tool}; use sglang_router_rs::{
use sglang_router_rs::tool_parser::{JsonParser, ParserFactory as ToolParserFactory, ToolParser}; protocols::common::{Function, Tool},
use std::collections::BTreeMap; tool_parser::{JsonParser, ParserFactory as ToolParserFactory, ToolParser},
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; };
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::{Duration, Instant};
use tokio::runtime::Runtime; use tokio::runtime::Runtime;
// Test data for different parser formats - realistic complex examples // Test data for different parser formats - realistic complex examples
......
# Rust formatting configuration
# Enforce grouped imports by crate
imports_granularity = "Crate"
# Group std, external crates, and local crate imports separately
group_imports = "StdExternalCrate"
reorder_imports = true
reorder_modules = true
use super::ConfigResult;
use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use super::ConfigResult;
/// Main router configuration /// Main router configuration
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouterConfig { pub struct RouterConfig {
......
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; use std::{
use std::sync::{Arc, RwLock}; sync::{
use std::time::{Duration, Instant}; atomic::{AtomicU32, AtomicU64, Ordering},
Arc, RwLock,
},
time::{Duration, Instant},
};
use tracing::info; use tracing::info;
/// Circuit breaker configuration /// Circuit breaker configuration
...@@ -316,9 +321,10 @@ pub struct CircuitBreakerStats { ...@@ -316,9 +321,10 @@ pub struct CircuitBreakerStats {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*;
use std::thread; use std::thread;
use super::*;
#[test] #[test]
fn test_circuit_breaker_initial_state() { fn test_circuit_breaker_initial_state() {
let cb = CircuitBreaker::new(); let cb = CircuitBreaker::new();
......
...@@ -68,9 +68,10 @@ impl From<reqwest::Error> for WorkerError { ...@@ -68,9 +68,10 @@ impl From<reqwest::Error> for WorkerError {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*;
use std::error::Error; use std::error::Error;
use super::*;
#[test] #[test]
fn test_health_check_failed_display() { fn test_health_check_failed_display() {
let error = WorkerError::HealthCheckFailed { let error = WorkerError::HealthCheckFailed {
......
...@@ -3,16 +3,22 @@ ...@@ -3,16 +3,22 @@
//! Provides non-blocking worker management by queuing operations and processing //! Provides non-blocking worker management by queuing operations and processing
//! them asynchronously in background worker tasks. //! them asynchronously in background worker tasks.
use crate::core::WorkerManager; use std::{
use crate::protocols::worker_spec::{JobStatus, WorkerConfigRequest}; sync::{Arc, Weak},
use crate::server::AppContext; time::{Duration, SystemTime},
};
use dashmap::DashMap; use dashmap::DashMap;
use metrics::{counter, gauge, histogram}; use metrics::{counter, gauge, histogram};
use std::sync::{Arc, Weak};
use std::time::{Duration, SystemTime};
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tracing::{debug, error, info, warn}; use tracing::{debug, error, info, warn};
use crate::{
core::WorkerManager,
protocols::worker_spec::{JobStatus, WorkerConfigRequest},
server::AppContext,
};
/// Job types for control plane operations /// Job types for control plane operations
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum Job { pub enum Job {
......
use crate::config::types::RetryConfig;
use axum::http::StatusCode;
use axum::response::Response;
use rand::Rng;
use std::time::Duration; use std::time::Duration;
use axum::{http::StatusCode, response::Response};
use rand::Rng;
use tracing::debug; use tracing::debug;
use crate::config::types::RetryConfig;
/// Check if an HTTP status code indicates a retryable error /// Check if an HTTP status code indicates a retryable error
pub fn is_retryable_status(status: StatusCode) -> bool { pub fn is_retryable_status(status: StatusCode) -> bool {
matches!( matches!(
...@@ -162,11 +163,14 @@ impl RetryExecutor { ...@@ -162,11 +163,14 @@ impl RetryExecutor {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::sync::{
atomic::{AtomicU32, Ordering},
Arc,
};
use axum::{http::StatusCode, response::IntoResponse};
use super::*; use super::*;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
fn base_retry_config() -> RetryConfig { fn base_retry_config() -> RetryConfig {
RetryConfig { RetryConfig {
......
use std::sync::Arc; use std::{
use std::time::{Duration, Instant}; sync::Arc,
time::{Duration, Instant},
};
use tokio::sync::{Mutex, Notify}; use tokio::sync::{Mutex, Notify};
use tracing::{debug, trace}; use tracing::{debug, trace};
......
use super::{CircuitBreaker, WorkerError, WorkerResult}; use std::{
use crate::core::CircuitState; fmt,
use crate::core::{BasicWorkerBuilder, DPAwareWorkerBuilder}; sync::{
use crate::grpc_client::SglangSchedulerClient; atomic::{AtomicBool, AtomicUsize, Ordering},
use crate::metrics::RouterMetrics; Arc, LazyLock,
use crate::protocols::worker_spec::WorkerInfo; },
time::{Duration, Instant},
};
use async_trait::async_trait; use async_trait::async_trait;
use futures; use futures;
use serde_json; use serde_json;
use std::fmt; use tokio::{
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; sync::{Mutex, RwLock},
use std::sync::{Arc, LazyLock}; time,
use std::time::Duration; };
use std::time::Instant;
use tokio::sync::{Mutex, RwLock}; use super::{CircuitBreaker, WorkerError, WorkerResult};
use tokio::time; use crate::{
core::{BasicWorkerBuilder, CircuitState, DPAwareWorkerBuilder},
grpc_client::SglangSchedulerClient,
metrics::RouterMetrics,
protocols::worker_spec::WorkerInfo,
};
static WORKER_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| { static WORKER_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
reqwest::Client::builder() reqwest::Client::builder()
...@@ -1024,10 +1032,10 @@ pub fn worker_to_info(worker: &Arc<dyn Worker>) -> WorkerInfo { ...@@ -1024,10 +1032,10 @@ pub fn worker_to_info(worker: &Arc<dyn Worker>) -> WorkerInfo {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::{thread, time::Duration};
use super::*; use super::*;
use crate::core::CircuitBreakerConfig; use crate::core::CircuitBreakerConfig;
use std::thread;
use std::time::Duration;
#[test] #[test]
fn test_worker_type_display() { fn test_worker_type_display() {
...@@ -1502,9 +1510,10 @@ mod tests { ...@@ -1502,9 +1510,10 @@ mod tests {
#[test] #[test]
fn test_load_counter_performance() { fn test_load_counter_performance() {
use crate::core::BasicWorkerBuilder;
use std::time::Instant; use std::time::Instant;
use crate::core::BasicWorkerBuilder;
let worker = BasicWorkerBuilder::new("http://test:8080") let worker = BasicWorkerBuilder::new("http://test:8080")
.worker_type(WorkerType::Regular) .worker_type(WorkerType::Regular)
.build(); .build();
......
use super::circuit_breaker::{CircuitBreaker, CircuitBreakerConfig}; use std::collections::HashMap;
use super::worker::{
BasicWorker, ConnectionMode, DPAwareWorker, HealthConfig, WorkerMetadata, WorkerType, use super::{
circuit_breaker::{CircuitBreaker, CircuitBreakerConfig},
worker::{
BasicWorker, ConnectionMode, DPAwareWorker, HealthConfig, WorkerMetadata, WorkerType,
},
}; };
use crate::grpc_client::SglangSchedulerClient; use crate::grpc_client::SglangSchedulerClient;
use std::collections::HashMap;
/// Builder for creating BasicWorker instances with fluent API /// Builder for creating BasicWorker instances with fluent API
pub struct BasicWorkerBuilder { pub struct BasicWorkerBuilder {
...@@ -100,6 +103,7 @@ impl BasicWorkerBuilder { ...@@ -100,6 +103,7 @@ impl BasicWorkerBuilder {
atomic::{AtomicBool, AtomicUsize}, atomic::{AtomicBool, AtomicUsize},
Arc, Arc,
}; };
use tokio::sync::{Mutex, RwLock}; use tokio::sync::{Mutex, RwLock};
let bootstrap_host = match url::Url::parse(&self.url) { let bootstrap_host = match url::Url::parse(&self.url) {
...@@ -282,9 +286,10 @@ impl DPAwareWorkerBuilder { ...@@ -282,9 +286,10 @@ impl DPAwareWorkerBuilder {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::time::Duration;
use super::*; use super::*;
use crate::core::worker::Worker; use crate::core::worker::Worker;
use std::time::Duration;
#[test] #[test]
fn test_basic_worker_builder_minimal() { fn test_basic_worker_builder_minimal() {
......
...@@ -3,31 +3,35 @@ ...@@ -3,31 +3,35 @@
//! Handles all aspects of worker lifecycle including discovery, initialization, //! Handles all aspects of worker lifecycle including discovery, initialization,
//! runtime management, and health monitoring. //! runtime management, and health monitoring.
use crate::config::types::{ use std::{collections::HashMap, sync::Arc, time::Duration};
CircuitBreakerConfig as ConfigCircuitBreakerConfig, ConnectionMode as ConfigConnectionMode,
HealthCheckConfig, RouterConfig, RoutingMode,
};
use crate::core::{
BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, DPAwareWorkerBuilder, HealthConfig,
Worker, WorkerFactory, WorkerRegistry, WorkerType,
};
use crate::grpc_client::SglangSchedulerClient;
use crate::policies::PolicyRegistry;
use crate::protocols::worker_spec::{
FlushCacheResult, WorkerConfigRequest, WorkerLoadInfo, WorkerLoadsResult,
};
use crate::server::AppContext;
use futures::future; use futures::future;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
use std::collections::HashMap; use tokio::{
use std::sync::Arc; sync::{watch, Mutex},
use std::time::Duration; task::JoinHandle,
use tokio::sync::{watch, Mutex}; };
use tokio::task::JoinHandle;
use tracing::{debug, error, info, warn}; use tracing::{debug, error, info, warn};
use crate::{
config::types::{
CircuitBreakerConfig as ConfigCircuitBreakerConfig, ConnectionMode as ConfigConnectionMode,
HealthCheckConfig, RouterConfig, RoutingMode,
},
core::{
BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, DPAwareWorkerBuilder,
HealthConfig, Worker, WorkerFactory, WorkerRegistry, WorkerType,
},
grpc_client::SglangSchedulerClient,
policies::PolicyRegistry,
protocols::worker_spec::{
FlushCacheResult, WorkerConfigRequest, WorkerLoadInfo, WorkerLoadsResult,
},
server::AppContext,
};
static HTTP_CLIENT: Lazy<reqwest::Client> = Lazy::new(|| { static HTTP_CLIENT: Lazy<reqwest::Client> = Lazy::new(|| {
reqwest::Client::builder() reqwest::Client::builder()
.timeout(Duration::from_secs(10)) .timeout(Duration::from_secs(10))
...@@ -1803,9 +1807,10 @@ impl Drop for LoadMonitor { ...@@ -1803,9 +1807,10 @@ impl Drop for LoadMonitor {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*;
use std::collections::HashMap; use std::collections::HashMap;
use super::*;
#[test] #[test]
fn test_parse_server_info() { fn test_parse_server_info() {
let json = serde_json::json!({ let json = serde_json::json!({
......
...@@ -2,11 +2,13 @@ ...@@ -2,11 +2,13 @@
//! //!
//! Provides centralized registry for workers with model-based indexing //! Provides centralized registry for workers with model-based indexing
use crate::core::{ConnectionMode, Worker, WorkerType};
use dashmap::DashMap;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use dashmap::DashMap;
use uuid::Uuid; use uuid::Uuid;
use crate::core::{ConnectionMode, Worker, WorkerType};
/// Unique identifier for a worker /// Unique identifier for a worker
#[derive(Debug, Clone, Hash, Eq, PartialEq)] #[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct WorkerId(String); pub struct WorkerId(String);
...@@ -363,8 +365,10 @@ impl WorkerRegistry { ...@@ -363,8 +365,10 @@ impl WorkerRegistry {
/// Start a health checker for all workers in the registry /// Start a health checker for all workers in the registry
/// This should be called once after the registry is populated with workers /// This should be called once after the registry is populated with workers
pub fn start_health_checker(&self, check_interval_secs: u64) -> crate::core::HealthChecker { pub fn start_health_checker(&self, check_interval_secs: u64) -> crate::core::HealthChecker {
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{
use std::sync::Arc; atomic::{AtomicBool, Ordering},
Arc,
};
let shutdown = Arc::new(AtomicBool::new(false)); let shutdown = Arc::new(AtomicBool::new(false));
let shutdown_clone = shutdown.clone(); let shutdown_clone = shutdown.clone();
...@@ -433,9 +437,10 @@ pub struct WorkerRegistryStats { ...@@ -433,9 +437,10 @@ pub struct WorkerRegistryStats {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::collections::HashMap;
use super::*; use super::*;
use crate::core::{BasicWorkerBuilder, CircuitBreakerConfig}; use crate::core::{BasicWorkerBuilder, CircuitBreakerConfig};
use std::collections::HashMap;
#[test] #[test]
fn test_worker_registry() { fn test_worker_registry() {
......
use std::collections::{BTreeMap, HashMap}; use std::{
use std::sync::RwLock; collections::{BTreeMap, HashMap},
sync::RwLock,
};
use async_trait::async_trait; use async_trait::async_trait;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use super::conversation_items::{ use super::{
make_item_id, ConversationItem, ConversationItemId, ConversationItemStorage, ListParams, conversation_items::{
Result, SortOrder, make_item_id, ConversationItem, ConversationItemId, ConversationItemStorage, ListParams,
Result, SortOrder,
},
conversations::ConversationId,
}; };
use super::conversations::ConversationId;
#[derive(Default)] #[derive(Default)]
pub struct MemoryConversationItemStorage { pub struct MemoryConversationItemStorage {
...@@ -190,9 +194,10 @@ impl ConversationItemStorage for MemoryConversationItemStorage { ...@@ -190,9 +194,10 @@ impl ConversationItemStorage for MemoryConversationItemStorage {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*;
use chrono::{TimeZone, Utc}; use chrono::{TimeZone, Utc};
use super::*;
fn make_item( fn make_item(
item_type: &str, item_type: &str,
role: Option<&str>, role: Option<&str>,
......
use crate::config::OracleConfig; use std::{path::Path, sync::Arc, time::Duration};
use crate::data_connector::conversation_items::{
make_item_id, ConversationItem, ConversationItemId, ConversationItemStorage,
ConversationItemStorageError, ListParams, Result as ItemResult, SortOrder,
};
use crate::data_connector::conversations::ConversationId;
use async_trait::async_trait; use async_trait::async_trait;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult}; use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult};
use oracle::sql_type::ToSql; use oracle::{sql_type::ToSql, Connection};
use oracle::Connection;
use serde_json::Value; use serde_json::Value;
use std::path::Path;
use std::sync::Arc; use crate::{
use std::time::Duration; config::OracleConfig,
data_connector::{
conversation_items::{
make_item_id, ConversationItem, ConversationItemId, ConversationItemStorage,
ConversationItemStorageError, ListParams, Result as ItemResult, SortOrder,
},
conversations::ConversationId,
},
};
#[derive(Clone)] #[derive(Clone)]
pub struct OracleConversationItemStorage { pub struct OracleConversationItemStorage {
......
use std::{
fmt::{Display, Formatter},
sync::Arc,
};
use async_trait::async_trait; use async_trait::async_trait;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use rand::RngCore; use rand::RngCore;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
use std::fmt::{Display, Formatter};
use std::sync::Arc;
use super::conversations::ConversationId; use super::conversations::ConversationId;
......
use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait; use async_trait::async_trait;
use parking_lot::RwLock; use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use super::conversations::{ use super::conversations::{
Conversation, ConversationId, ConversationMetadata, ConversationStorage, NewConversation, Conversation, ConversationId, ConversationMetadata, ConversationStorage, NewConversation,
......
use crate::config::OracleConfig; use std::{path::Path, sync::Arc, time::Duration};
use crate::data_connector::conversations::{
Conversation, ConversationId, ConversationMetadata, ConversationStorage,
ConversationStorageError, NewConversation, Result,
};
use async_trait::async_trait; use async_trait::async_trait;
use chrono::Utc; use chrono::Utc;
use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult}; use deadpool::managed::{Manager, Metrics, Pool, PoolError, RecycleError, RecycleResult};
use oracle::{sql_type::OracleType, Connection}; use oracle::{sql_type::OracleType, Connection};
use serde_json::Value; use serde_json::Value;
use std::path::Path;
use std::sync::Arc; use crate::{
use std::time::Duration; config::OracleConfig,
data_connector::conversations::{
Conversation, ConversationId, ConversationMetadata, ConversationStorage,
ConversationStorageError, NewConversation, Result,
},
};
#[derive(Clone)] #[derive(Clone)]
pub struct OracleConversationStorage { pub struct OracleConversationStorage {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment