Unverified Commit dc01313d authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

[router] Add rustfmt and set group imports by default (#11732)

parent 7a7f99be
use std::collections::HashMap;
use async_trait::async_trait; use async_trait::async_trait;
use regex::Regex; use regex::Regex;
use serde_json::Value; use serde_json::Value;
use std::collections::HashMap;
use crate::protocols::common::Tool;
use crate::tool_parser::{ use crate::{
errors::{ParserError, ParserResult}, protocols::common::Tool,
parsers::helpers, tool_parser::{
traits::ToolParser, errors::{ParserError, ParserResult},
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem}, parsers::helpers,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
},
}; };
/// Step3 format parser for tool calls /// Step3 format parser for tool calls
......
use serde_json::{Map, Value};
use crate::tool_parser::{ use crate::tool_parser::{
errors::{ParserError, ParserResult}, errors::{ParserError, ParserResult},
traits::PartialJsonParser, traits::PartialJsonParser,
}; };
use serde_json::{Map, Value};
/// Parser for incomplete JSON /// Parser for incomplete JSON
pub struct PartialJson { pub struct PartialJson {
......
use super::*; use super::*;
use crate::tool_parser::parsers::JsonParser; use crate::tool_parser::{
use crate::tool_parser::partial_json::{ parsers::JsonParser,
compute_diff, find_common_prefix, is_complete_json, PartialJson, partial_json::{compute_diff, find_common_prefix, is_complete_json, PartialJson},
traits::ToolParser,
}; };
use crate::tool_parser::traits::ToolParser;
#[tokio::test] #[tokio::test]
async fn test_tool_parser_factory() { async fn test_tool_parser_factory() {
......
use crate::protocols::common::Tool;
use crate::tool_parser::{
errors::ParserResult,
types::{StreamingParseResult, ToolCall},
};
use async_trait::async_trait; use async_trait::async_trait;
use crate::{
protocols::common::Tool,
tool_parser::{
errors::ParserResult,
types::{StreamingParseResult, ToolCall},
},
};
/// Core trait for all tool parsers /// Core trait for all tool parsers
#[async_trait] #[async_trait]
pub trait ToolParser: Send + Sync { pub trait ToolParser: Send + Sync {
......
use dashmap::mapref::entry::Entry; use std::{
use dashmap::DashMap; cmp::Reverse,
collections::{BinaryHeap, HashMap, VecDeque},
sync::{Arc, RwLock},
time::{Duration, SystemTime, UNIX_EPOCH},
};
use dashmap::{mapref::entry::Entry, DashMap};
use tracing::info; use tracing::info;
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::collections::HashMap;
use std::collections::VecDeque;
use std::sync::Arc;
use std::sync::RwLock;
use std::time::Duration;
use std::time::{SystemTime, UNIX_EPOCH};
type NodeRef = Arc<Node>; type NodeRef = Arc<Node>;
#[derive(Debug)] #[derive(Debug)]
...@@ -666,12 +662,12 @@ impl Tree { ...@@ -666,12 +662,12 @@ impl Tree {
// Unit tests // Unit tests
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use rand::distr::Alphanumeric; use std::{thread, time::Instant};
use rand::distr::SampleString;
use rand::rng as thread_rng; use rand::{
use rand::Rng; distr::{Alphanumeric, SampleString},
use std::thread; rng as thread_rng, Rng,
use std::time::Instant; };
use super::*; use super::*;
......
mod common; mod common;
use std::sync::Arc;
use axum::{ use axum::{
body::Body, body::Body,
extract::Request, extract::Request,
...@@ -8,13 +10,14 @@ use axum::{ ...@@ -8,13 +10,14 @@ use axum::{
use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType}; use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
use reqwest::Client; use reqwest::Client;
use serde_json::json; use serde_json::json;
use sglang_router_rs::config::{ use sglang_router_rs::{
CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode, config::{
CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
},
core::WorkerManager,
routers::{RouterFactory, RouterTrait},
server::AppContext,
}; };
use sglang_router_rs::core::WorkerManager;
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
use sglang_router_rs::server::AppContext;
use std::sync::Arc;
use tower::ServiceExt; use tower::ServiceExt;
/// Test context that manages mock workers /// Test context that manages mock workers
...@@ -995,9 +998,10 @@ mod router_policy_tests { ...@@ -995,9 +998,10 @@ mod router_policy_tests {
#[cfg(test)] #[cfg(test)]
mod responses_endpoint_tests { mod responses_endpoint_tests {
use super::*;
use reqwest::Client as HttpClient; use reqwest::Client as HttpClient;
use super::*;
#[tokio::test] #[tokio::test]
async fn test_v1_responses_non_streaming() { async fn test_v1_responses_non_streaming() {
let ctx = TestContext::new(vec![MockWorkerConfig { let ctx = TestContext::new(vec![MockWorkerConfig {
......
use sglang_router_rs::core::{BasicWorkerBuilder, Worker, WorkerType}; use std::{collections::HashMap, sync::Arc};
use sglang_router_rs::policies::{CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy};
use std::collections::HashMap; use sglang_router_rs::{
use std::sync::Arc; core::{BasicWorkerBuilder, Worker, WorkerType},
policies::{CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy},
};
#[test] #[test]
fn test_backward_compatibility_with_empty_model_id() { fn test_backward_compatibility_with_empty_model_id() {
......
use sglang_router_rs::protocols::chat::{ChatMessage, UserMessageContent}; use sglang_router_rs::{
use sglang_router_rs::tokenizer::chat_template::{ protocols::chat::{ChatMessage, UserMessageContent},
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams, tokenizer::chat_template::{
ChatTemplateProcessor, detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams,
ChatTemplateProcessor,
},
}; };
#[test] #[test]
......
use sglang_router_rs::protocols::chat::{ChatMessage, UserMessageContent}; use sglang_router_rs::{
use sglang_router_rs::protocols::common::{ContentPart, ImageUrl}; protocols::{
use sglang_router_rs::tokenizer::chat_template::{ chat::{ChatMessage, UserMessageContent},
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams, common::{ContentPart, ImageUrl},
ChatTemplateProcessor, },
tokenizer::chat_template::{
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams,
ChatTemplateProcessor,
},
}; };
#[test] #[test]
......
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use sglang_router_rs::protocols::chat::{ChatMessage, UserMessageContent};
use sglang_router_rs::tokenizer::chat_template::ChatTemplateParams;
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
use std::fs; use std::fs;
use sglang_router_rs::{
protocols::chat::{ChatMessage, UserMessageContent},
tokenizer::{chat_template::ChatTemplateParams, huggingface::HuggingFaceTokenizer},
};
use tempfile::TempDir; use tempfile::TempDir;
#[test] #[test]
......
...@@ -148,8 +148,7 @@ mod tests { ...@@ -148,8 +148,7 @@ mod tests {
async fn test_mock_server_with_rmcp_client() { async fn test_mock_server_with_rmcp_client() {
let mut server = MockMCPServer::start().await.unwrap(); let mut server = MockMCPServer::start().await.unwrap();
use rmcp::transport::StreamableHttpClientTransport; use rmcp::{transport::StreamableHttpClientTransport, ServiceExt};
use rmcp::ServiceExt;
let transport = StreamableHttpClientTransport::from_uri(server.url().as_str()); let transport = StreamableHttpClientTransport::from_uri(server.url().as_str());
let client = ().serve(transport).await; let client = ().serve(transport).await;
......
...@@ -2,19 +2,21 @@ ...@@ -2,19 +2,21 @@
#![allow(dead_code)] #![allow(dead_code)]
use std::{net::SocketAddr, sync::Arc};
use axum::{ use axum::{
body::Body, body::Body,
extract::{Request, State}, extract::{Request, State},
http::{HeaderValue, StatusCode}, http::{HeaderValue, StatusCode},
response::sse::{Event, KeepAlive}, response::{
response::{IntoResponse, Response, Sse}, sse::{Event, KeepAlive},
IntoResponse, Response, Sse,
},
routing::post, routing::post,
Json, Router, Json, Router,
}; };
use futures_util::stream::{self, StreamExt}; use futures_util::stream::{self, StreamExt};
use serde_json::json; use serde_json::json;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::TcpListener; use tokio::net::TcpListener;
/// Mock OpenAI API server for testing /// Mock OpenAI API server for testing
......
// Mock worker for testing - these functions are used by integration tests // Mock worker for testing - these functions are used by integration tests
#![allow(dead_code)] #![allow(dead_code)]
use std::{
collections::{HashMap, HashSet},
convert::Infallible,
sync::{Arc, Mutex, OnceLock},
time::{SystemTime, UNIX_EPOCH},
};
use axum::{ use axum::{
extract::{Json, Path, State}, extract::{Json, Path, State},
http::StatusCode, http::StatusCode,
response::sse::{Event, KeepAlive}, response::{
response::{IntoResponse, Response, Sse}, sse::{Event, KeepAlive},
IntoResponse, Response, Sse,
},
routing::{get, post}, routing::{get, post},
Router, Router,
}; };
use futures_util::stream::{self, StreamExt}; use futures_util::stream::{self, StreamExt};
use serde_json::json; use serde_json::json;
use std::collections::{HashMap, HashSet};
use std::convert::Infallible;
use std::sync::{Arc, Mutex, OnceLock};
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock; use tokio::sync::RwLock;
use uuid::Uuid; use uuid::Uuid;
......
...@@ -7,19 +7,24 @@ pub mod mock_worker; ...@@ -7,19 +7,24 @@ pub mod mock_worker;
pub mod streaming_helpers; pub mod streaming_helpers;
pub mod test_app; pub mod test_app;
use std::{
fs,
path::PathBuf,
sync::{Arc, Mutex, OnceLock},
};
use serde_json::json; use serde_json::json;
use sglang_router_rs::config::RouterConfig; use sglang_router_rs::{
use sglang_router_rs::core::{LoadMonitor, WorkerRegistry}; config::RouterConfig,
use sglang_router_rs::data_connector::{ core::{LoadMonitor, WorkerRegistry},
MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage, data_connector::{
MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage,
},
middleware::TokenBucket,
policies::PolicyRegistry,
protocols::common::{Function, Tool},
server::AppContext,
}; };
use sglang_router_rs::middleware::TokenBucket;
use sglang_router_rs::policies::PolicyRegistry;
use sglang_router_rs::protocols::common::{Function, Tool};
use sglang_router_rs::server::AppContext;
use std::fs;
use std::path::PathBuf;
use std::sync::{Arc, Mutex, OnceLock};
/// Helper function to create AppContext for tests /// Helper function to create AppContext for tests
pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> { pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
......
use std::sync::{Arc, OnceLock};
use axum::Router; use axum::Router;
use reqwest::Client; use reqwest::Client;
use sglang_router_rs::{ use sglang_router_rs::{
...@@ -11,7 +13,6 @@ use sglang_router_rs::{ ...@@ -11,7 +13,6 @@ use sglang_router_rs::{
routers::RouterTrait, routers::RouterTrait,
server::{build_app, AppContext, AppState}, server::{build_app, AppContext, AppState},
}; };
use std::sync::{Arc, OnceLock};
/// Create a test Axum application using the actual server's build_app function /// Create a test Axum application using the actual server's build_app function
#[allow(dead_code)] #[allow(dead_code)]
......
...@@ -9,10 +9,11 @@ ...@@ -9,10 +9,11 @@
mod common; mod common;
use std::collections::HashMap;
use common::mock_mcp_server::MockMCPServer; use common::mock_mcp_server::MockMCPServer;
use serde_json::json; use serde_json::json;
use sglang_router_rs::mcp::{McpClientManager, McpConfig, McpError, McpServerConfig, McpTransport}; use sglang_router_rs::mcp::{McpClientManager, McpConfig, McpError, McpServerConfig, McpTransport};
use std::collections::HashMap;
/// Create a new mock server for testing (each test gets its own) /// Create a new mock server for testing (each test gets its own)
async fn create_mock_server() -> MockMCPServer { async fn create_mock_server() -> MockMCPServer {
......
//! Integration tests for PolicyRegistry with RouterManager //! Integration tests for PolicyRegistry with RouterManager
use sglang_router_rs::config::PolicyConfig; use std::{collections::HashMap, sync::Arc};
use sglang_router_rs::core::WorkerRegistry;
use sglang_router_rs::policies::PolicyRegistry; use sglang_router_rs::{
use sglang_router_rs::protocols::worker_spec::WorkerConfigRequest; config::PolicyConfig, core::WorkerRegistry, policies::PolicyRegistry,
use sglang_router_rs::routers::router_manager::RouterManager; protocols::worker_spec::WorkerConfigRequest, routers::router_manager::RouterManager,
use std::collections::HashMap; };
use std::sync::Arc;
#[tokio::test] #[tokio::test]
async fn test_policy_registry_with_router_manager() { async fn test_policy_registry_with_router_manager() {
...@@ -95,8 +94,7 @@ async fn test_policy_registry_with_router_manager() { ...@@ -95,8 +94,7 @@ async fn test_policy_registry_with_router_manager() {
#[test] #[test]
fn test_policy_registry_cleanup() { fn test_policy_registry_cleanup() {
use sglang_router_rs::config::PolicyConfig; use sglang_router_rs::{config::PolicyConfig, policies::PolicyRegistry};
use sglang_router_rs::policies::PolicyRegistry;
let registry = PolicyRegistry::new(PolicyConfig::RoundRobin); let registry = PolicyRegistry::new(PolicyConfig::RoundRobin);
...@@ -123,8 +121,7 @@ fn test_policy_registry_cleanup() { ...@@ -123,8 +121,7 @@ fn test_policy_registry_cleanup() {
#[test] #[test]
fn test_policy_registry_multiple_models() { fn test_policy_registry_multiple_models() {
use sglang_router_rs::config::PolicyConfig; use sglang_router_rs::{config::PolicyConfig, policies::PolicyRegistry};
use sglang_router_rs::policies::PolicyRegistry;
let registry = PolicyRegistry::new(PolicyConfig::RoundRobin); let registry = PolicyRegistry::new(PolicyConfig::RoundRobin);
......
mod common; mod common;
use std::sync::Arc;
use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType}; use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
use reqwest::Client; use reqwest::Client;
use serde_json::json; use serde_json::json;
use sglang_router_rs::config::{RouterConfig, RoutingMode}; use sglang_router_rs::{
use sglang_router_rs::core::WorkerManager; config::{RouterConfig, RoutingMode},
use sglang_router_rs::routers::{RouterFactory, RouterTrait}; core::WorkerManager,
use std::sync::Arc; routers::{RouterFactory, RouterTrait},
};
/// Test context that manages mock workers /// Test context that manages mock workers
struct TestContext { struct TestContext {
......
// Integration test for Responses API // Integration test for Responses API
use axum::http::StatusCode; use axum::http::StatusCode;
use sglang_router_rs::protocols::common::{ use sglang_router_rs::protocols::{
GenerationRequest, ToolChoice, ToolChoiceValue, UsageInfo, common::{GenerationRequest, ToolChoice, ToolChoiceValue, UsageInfo},
}; responses::{
use sglang_router_rs::protocols::responses::{ ReasoningEffort, ResponseInput, ResponseReasoningParam, ResponseTool, ResponseToolType,
ReasoningEffort, ResponseInput, ResponseReasoningParam, ResponseTool, ResponseToolType, ResponsesRequest, ServiceTier, Truncation,
ResponsesRequest, ServiceTier, Truncation, },
}; };
mod common; mod common;
use common::mock_mcp_server::MockMCPServer; use common::{
use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType}; mock_mcp_server::MockMCPServer,
use sglang_router_rs::config::{ mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType},
CircuitBreakerConfig, ConnectionMode, HealthCheckConfig, PolicyConfig, RetryConfig, };
RouterConfig, RoutingMode, use sglang_router_rs::{
config::{
CircuitBreakerConfig, ConnectionMode, HealthCheckConfig, PolicyConfig, RetryConfig,
RouterConfig, RoutingMode,
},
routers::RouterFactory,
}; };
use sglang_router_rs::routers::RouterFactory;
#[tokio::test] #[tokio::test]
async fn test_non_streaming_mcp_minimal_e2e_with_persistence() { async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
......
use serde_json::json; use serde_json::json;
use sglang_router_rs::protocols::chat::{ChatCompletionRequest, ChatMessage, UserMessageContent}; use sglang_router_rs::protocols::{
use sglang_router_rs::protocols::common::{ chat::{ChatCompletionRequest, ChatMessage, UserMessageContent},
Function, FunctionCall, FunctionChoice, StreamOptions, Tool, ToolChoice, ToolChoiceValue, common::{
ToolReference, Function, FunctionCall, FunctionChoice, StreamOptions, Tool, ToolChoice, ToolChoiceValue,
ToolReference,
},
validated::Normalizable,
}; };
use sglang_router_rs::protocols::validated::Normalizable;
use validator::Validate; use validator::Validate;
// Deprecated fields normalization tests // Deprecated fields normalization tests
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment