Unverified Commit dc01313d authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

[router] Add rustfmt and set group imports by default (#11732)

parent 7a7f99be
...@@ -3,38 +3,40 @@ ...@@ -3,38 +3,40 @@
//! This module contains shared streaming logic for both Regular and PD routers, //! This module contains shared streaming logic for both Regular and PD routers,
//! eliminating ~600 lines of duplication. //! eliminating ~600 lines of duplication.
use axum::response::Response; use std::{collections::HashMap, io, sync::Arc, time::Instant};
use axum::{body::Body, http::StatusCode};
use axum::{body::Body, http::StatusCode, response::Response};
use bytes::Bytes; use bytes::Bytes;
use http::header::{HeaderValue, CONTENT_TYPE}; use http::header::{HeaderValue, CONTENT_TYPE};
use proto::{
generate_complete::MatchedStop::{MatchedStopStr, MatchedTokenId},
generate_response::Response::{Chunk, Complete, Error},
};
use serde_json::{json, Value}; use serde_json::{json, Value};
use std::collections::HashMap; use tokio::sync::{mpsc, mpsc::UnboundedSender};
use std::io; use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt};
use std::sync::Arc;
use tokio::sync::mpsc::UnboundedSender;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::StreamExt;
use tracing::{debug, error, warn}; use tracing::{debug, error, warn};
use super::context; use super::{context, utils};
use super::utils; use crate::{
use crate::grpc_client::proto; grpc_client::proto,
use crate::protocols::chat::{ protocols::{
ChatCompletionRequest, ChatCompletionStreamResponse, ChatMessageDelta, ChatStreamChoice, chat::{
}; ChatCompletionRequest, ChatCompletionStreamResponse, ChatMessageDelta, ChatStreamChoice,
use crate::protocols::common::{ },
ChatLogProbs, FunctionCallDelta, StringOrArray, Tool, ToolCallDelta, ToolChoice, common::{
ToolChoiceValue, Usage, ChatLogProbs, FunctionCallDelta, StringOrArray, Tool, ToolCallDelta, ToolChoice,
ToolChoiceValue, Usage,
},
generate::GenerateRequest,
},
reasoning_parser::ReasoningParser,
tokenizer::{
stop::{SequenceDecoderOutput, StopSequenceDecoder},
traits::Tokenizer,
},
tool_parser::ToolParser,
}; };
use crate::protocols::generate::GenerateRequest;
use crate::reasoning_parser::ReasoningParser;
use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoder};
use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ToolParser;
use proto::generate_complete::MatchedStop::{MatchedStopStr, MatchedTokenId};
use proto::generate_response::Response::{Chunk, Complete, Error};
use std::time::Instant;
use tokio::sync::mpsc;
/// Shared streaming processor for both single and dual dispatch modes /// Shared streaming processor for both single and dual dispatch modes
#[derive(Clone)] #[derive(Clone)]
......
//! Shared utilities for gRPC routers //! Shared utilities for gRPC routers
use super::ProcessedMessages; use std::{collections::HashMap, sync::Arc};
use crate::core::Worker;
use crate::grpc_client::sglang_scheduler::AbortOnDropStream;
use crate::grpc_client::{proto, SglangSchedulerClient};
use crate::protocols::chat::{ChatCompletionRequest, ChatMessage};
use crate::protocols::common::{
ChatLogProbs, ChatLogProbsContent, FunctionCallResponse, StringOrArray, Tool, ToolCall,
ToolChoice, ToolChoiceValue, TopLogProb,
};
use crate::protocols::generate::GenerateFinishReason;
use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
use crate::tokenizer::traits::Tokenizer;
use crate::tokenizer::HuggingFaceTokenizer;
pub use crate::tokenizer::StopSequenceDecoder;
use axum::{ use axum::{
http::StatusCode, http::StatusCode,
response::{IntoResponse, Response}, response::{IntoResponse, Response},
...@@ -21,11 +9,29 @@ use axum::{ ...@@ -21,11 +9,29 @@ use axum::{
}; };
use futures::StreamExt; use futures::StreamExt;
use serde_json::{json, Map, Value}; use serde_json::{json, Map, Value};
use std::collections::HashMap;
use std::sync::Arc;
use tracing::{error, warn}; use tracing::{error, warn};
use uuid::Uuid; use uuid::Uuid;
use super::ProcessedMessages;
pub use crate::tokenizer::StopSequenceDecoder;
use crate::{
core::Worker,
grpc_client::{proto, sglang_scheduler::AbortOnDropStream, SglangSchedulerClient},
protocols::{
chat::{ChatCompletionRequest, ChatMessage},
common::{
ChatLogProbs, ChatLogProbsContent, FunctionCallResponse, StringOrArray, Tool, ToolCall,
ToolChoice, ToolChoiceValue, TopLogProb,
},
generate::GenerateFinishReason,
},
tokenizer::{
chat_template::{ChatTemplateContentFormat, ChatTemplateParams},
traits::Tokenizer,
HuggingFaceTokenizer,
},
};
/// Get gRPC client from worker, returning appropriate error response on failure /// Get gRPC client from worker, returning appropriate error response on failure
pub async fn get_grpc_client_from_worker( pub async fn get_grpc_client_from_worker(
worker: &Arc<dyn Worker>, worker: &Arc<dyn Worker>,
...@@ -953,12 +959,17 @@ pub fn parse_finish_reason(reason_str: &str, completion_tokens: i32) -> Generate ...@@ -953,12 +959,17 @@ pub fn parse_finish_reason(reason_str: &str, completion_tokens: i32) -> Generate
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*;
use crate::protocols::chat::{ChatMessage, UserMessageContent};
use crate::protocols::common::{ContentPart, ImageUrl};
use crate::tokenizer::chat_template::ChatTemplateContentFormat;
use serde_json::json; use serde_json::json;
use super::*;
use crate::{
protocols::{
chat::{ChatMessage, UserMessageContent},
common::{ContentPart, ImageUrl},
},
tokenizer::chat_template::ChatTemplateContentFormat,
};
#[test] #[test]
fn test_transform_messages_string_format() { fn test_transform_messages_string_format() {
let messages = vec![ChatMessage::User { let messages = vec![ChatMessage::User {
......
use axum::body::Body; use axum::{body::Body, extract::Request, http::HeaderMap};
use axum::extract::Request;
use axum::http::HeaderMap;
/// Copy request headers to a Vec of name-value string pairs /// Copy request headers to a Vec of name-value string pairs
/// Used for forwarding headers to backend workers /// Used for forwarding headers to backend workers
......
use super::pd_types::api_path; use std::{sync::Arc, time::Instant};
use crate::config::types::RetryConfig;
use crate::core::{
is_retryable_status, RetryExecutor, Worker, WorkerLoadGuard, WorkerRegistry, WorkerType,
};
use crate::metrics::RouterMetrics;
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
use crate::protocols::chat::{ChatCompletionRequest, ChatMessage, UserMessageContent};
use crate::protocols::common::{InputIds, StringOrArray};
use crate::protocols::completion::CompletionRequest;
use crate::protocols::embedding::EmbeddingRequest;
use crate::protocols::generate::GenerateRequest;
use crate::protocols::rerank::RerankRequest;
use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest};
use crate::routers::header_utils;
use crate::routers::RouterTrait;
use async_trait::async_trait; use async_trait::async_trait;
use axum::{ use axum::{
body::Body, body::Body,
...@@ -25,11 +11,29 @@ use futures_util::StreamExt; ...@@ -25,11 +11,29 @@ use futures_util::StreamExt;
use reqwest::Client; use reqwest::Client;
use serde::Serialize; use serde::Serialize;
use serde_json::{json, Value}; use serde_json::{json, Value};
use std::sync::Arc;
use std::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, warn}; use tracing::{debug, error, warn};
use super::pd_types::api_path;
use crate::{
config::types::RetryConfig,
core::{
is_retryable_status, RetryExecutor, Worker, WorkerLoadGuard, WorkerRegistry, WorkerType,
},
metrics::RouterMetrics,
policies::{LoadBalancingPolicy, PolicyRegistry},
protocols::{
chat::{ChatCompletionRequest, ChatMessage, UserMessageContent},
common::{InputIds, StringOrArray},
completion::CompletionRequest,
embedding::EmbeddingRequest,
generate::GenerateRequest,
rerank::RerankRequest,
responses::{ResponsesGetParams, ResponsesRequest},
},
routers::{header_utils, RouterTrait},
};
#[derive(Debug)] #[derive(Debug)]
pub struct PDRouter { pub struct PDRouter {
pub worker_registry: Arc<WorkerRegistry>, pub worker_registry: Arc<WorkerRegistry>,
......
use crate::config::types::RetryConfig; use std::{sync::Arc, time::Instant};
use crate::core::{
is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerRegistry, WorkerType,
};
use crate::metrics::RouterMetrics;
use crate::policies::PolicyRegistry;
use crate::protocols::chat::ChatCompletionRequest;
use crate::protocols::common::GenerationRequest;
use crate::protocols::completion::CompletionRequest;
use crate::protocols::embedding::EmbeddingRequest;
use crate::protocols::generate::GenerateRequest;
use crate::protocols::rerank::{RerankRequest, RerankResponse, RerankResult};
use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest};
use crate::routers::header_utils;
use crate::routers::RouterTrait;
use axum::body::to_bytes;
use axum::{ use axum::{
body::Body, body::{to_bytes, Body},
extract::Request, extract::Request,
http::{ http::{
header::CONTENT_LENGTH, header::CONTENT_TYPE, HeaderMap, HeaderValue, Method, StatusCode, header::{CONTENT_LENGTH, CONTENT_TYPE},
HeaderMap, HeaderValue, Method, StatusCode,
}, },
response::{IntoResponse, Response}, response::{IntoResponse, Response},
Json, Json,
}; };
use futures_util::StreamExt; use futures_util::StreamExt;
use reqwest::Client; use reqwest::Client;
use std::sync::Arc;
use std::time::Instant;
use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error}; use tracing::{debug, error};
use crate::{
config::types::RetryConfig,
core::{
is_retryable_status, ConnectionMode, RetryExecutor, Worker, WorkerRegistry, WorkerType,
},
metrics::RouterMetrics,
policies::PolicyRegistry,
protocols::{
chat::ChatCompletionRequest,
common::GenerationRequest,
completion::CompletionRequest,
embedding::EmbeddingRequest,
generate::GenerateRequest,
rerank::{RerankRequest, RerankResponse, RerankResult},
responses::{ResponsesGetParams, ResponsesRequest},
},
routers::{header_utils, RouterTrait},
};
/// Regular router that uses injected load balancing policies /// Regular router that uses injected load balancing policies
#[derive(Debug)] #[derive(Debug)]
pub struct Router { pub struct Router {
......
//! Router implementations //! Router implementations
use std::fmt::Debug;
use async_trait::async_trait; use async_trait::async_trait;
use axum::{ use axum::{
body::Body, body::Body,
...@@ -7,16 +9,17 @@ use axum::{ ...@@ -7,16 +9,17 @@ use axum::{
http::{HeaderMap, StatusCode}, http::{HeaderMap, StatusCode},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
}; };
use std::fmt::Debug;
use crate::protocols::chat::ChatCompletionRequest;
use crate::protocols::completion::CompletionRequest;
use crate::protocols::embedding::EmbeddingRequest;
use crate::protocols::generate::GenerateRequest;
use crate::protocols::rerank::RerankRequest;
use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest};
use serde_json::Value; use serde_json::Value;
use crate::protocols::{
chat::ChatCompletionRequest,
completion::CompletionRequest,
embedding::EmbeddingRequest,
generate::GenerateRequest,
rerank::RerankRequest,
responses::{ResponsesGetParams, ResponsesRequest},
};
pub mod factory; pub mod factory;
pub mod grpc; pub mod grpc;
pub mod header_utils; pub mod header_utils;
...@@ -25,7 +28,6 @@ pub mod openai; // New refactored OpenAI router module ...@@ -25,7 +28,6 @@ pub mod openai; // New refactored OpenAI router module
pub mod router_manager; pub mod router_manager;
pub use factory::RouterFactory; pub use factory::RouterFactory;
// Re-export HTTP routers for convenience // Re-export HTTP routers for convenience
pub use http::{pd_router, pd_types, router}; pub use http::{pd_router, pd_types, router};
......
//! Conversation CRUD operations and persistence //! Conversation CRUD operations and persistence
use crate::data_connector::{ use std::{collections::HashMap, sync::Arc};
conversation_items::ListParams, conversation_items::SortOrder, Conversation, ConversationId,
ConversationItemId, ConversationItemStorage, ConversationStorage, NewConversation, use axum::{
NewConversationItem, ResponseId, ResponseStorage, SharedConversationItemStorage, http::StatusCode,
SharedConversationStorage, response::{IntoResponse, Response},
Json,
}; };
use crate::protocols::responses::{ResponseInput, ResponseInputOutputItem, ResponsesRequest};
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::Json;
use chrono::Utc; use chrono::Utc;
use serde_json::{json, Value}; use serde_json::{json, Value};
use std::collections::HashMap;
use std::sync::Arc;
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
use super::responses::build_stored_response; use super::responses::build_stored_response;
use crate::{
data_connector::{
conversation_items::{ListParams, SortOrder},
Conversation, ConversationId, ConversationItemId, ConversationItemStorage,
ConversationStorage, NewConversation, NewConversationItem, ResponseId, ResponseStorage,
SharedConversationItemStorage, SharedConversationStorage,
},
protocols::responses::{ResponseInput, ResponseInputOutputItem, ResponsesRequest},
};
/// Maximum number of properties allowed in conversation metadata /// Maximum number of properties allowed in conversation metadata
pub(crate) const MAX_METADATA_PROPERTIES: usize = 16; pub(crate) const MAX_METADATA_PROPERTIES: usize = 16;
......
...@@ -8,19 +8,20 @@ ...@@ -8,19 +8,20 @@
//! - Payload transformation for MCP tool interception //! - Payload transformation for MCP tool interception
//! - Metadata injection for MCP operations //! - Metadata injection for MCP operations
use crate::mcp::McpClientManager; use std::{io, sync::Arc};
use crate::protocols::responses::{
ResponseInput, ResponseTool, ResponseToolType, ResponsesRequest,
};
use crate::routers::header_utils::apply_request_headers;
use axum::http::HeaderMap; use axum::http::HeaderMap;
use bytes::Bytes; use bytes::Bytes;
use serde_json::{json, to_value, Value}; use serde_json::{json, to_value, Value};
use std::{io, sync::Arc};
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tracing::{info, warn}; use tracing::{info, warn};
use super::utils::event_types; use super::utils::event_types;
use crate::{
mcp::McpClientManager,
protocols::responses::{ResponseInput, ResponseTool, ResponseToolType, ResponsesRequest},
routers::header_utils::apply_request_headers,
};
// ============================================================================ // ============================================================================
// Configuration and State Types // Configuration and State Types
......
//! Response storage, patching, and extraction utilities //! Response storage, patching, and extraction utilities
use crate::data_connector::{ResponseId, StoredResponse};
use crate::protocols::responses::{ResponseInput, ResponseToolType, ResponsesRequest};
use serde_json::{json, Value};
use std::collections::HashMap; use std::collections::HashMap;
use serde_json::{json, Value};
use tracing::warn; use tracing::warn;
use super::utils::event_types; use super::utils::event_types;
use crate::{
data_connector::{ResponseId, StoredResponse},
protocols::responses::{ResponseInput, ResponseToolType, ResponsesRequest},
};
// ============================================================================ // ============================================================================
// Response Storage Operations // Response Storage Operations
......
//! OpenAI router - main coordinator that delegates to specialized modules //! OpenAI router - main coordinator that delegates to specialized modules
use crate::config::CircuitBreakerConfig; use std::{
use crate::core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig}; any::Any,
use crate::data_connector::{ sync::{atomic::AtomicBool, Arc},
conversation_items::ListParams, conversation_items::SortOrder, ConversationId, ResponseId,
SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage,
};
use crate::protocols::chat::ChatCompletionRequest;
use crate::protocols::completion::CompletionRequest;
use crate::protocols::embedding::EmbeddingRequest;
use crate::protocols::generate::GenerateRequest;
use crate::protocols::rerank::RerankRequest;
use crate::protocols::responses::{
ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponsesGetParams,
ResponsesRequest,
}; };
use crate::routers::header_utils::apply_request_headers;
use axum::{ use axum::{
body::Body, body::Body,
extract::Request, extract::Request,
...@@ -25,10 +14,6 @@ use axum::{ ...@@ -25,10 +14,6 @@ use axum::{
}; };
use futures_util::StreamExt; use futures_util::StreamExt;
use serde_json::{json, to_value, Value}; use serde_json::{json, to_value, Value};
use std::{
any::Any,
sync::{atomic::AtomicBool, Arc},
};
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::warn; use tracing::warn;
...@@ -39,12 +24,35 @@ use super::conversations::{ ...@@ -39,12 +24,35 @@ use super::conversations::{
get_conversation, get_conversation_item, list_conversation_items, persist_conversation_items, get_conversation, get_conversation_item, list_conversation_items, persist_conversation_items,
update_conversation, update_conversation,
}; };
use super::mcp::{ use super::{
execute_tool_loop, mcp_manager_from_request_tools, prepare_mcp_payload_for_streaming, mcp::{
McpLoopConfig, execute_tool_loop, mcp_manager_from_request_tools, prepare_mcp_payload_for_streaming,
McpLoopConfig,
},
responses::{mask_tools_as_mcp, patch_streaming_response_json},
streaming::handle_streaming_response,
};
use crate::{
config::CircuitBreakerConfig,
core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig},
data_connector::{
conversation_items::{ListParams, SortOrder},
ConversationId, ResponseId, SharedConversationItemStorage, SharedConversationStorage,
SharedResponseStorage,
},
protocols::{
chat::ChatCompletionRequest,
completion::CompletionRequest,
embedding::EmbeddingRequest,
generate::GenerateRequest,
rerank::RerankRequest,
responses::{
ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponsesGetParams,
ResponsesRequest,
},
},
routers::header_utils::apply_request_headers,
}; };
use super::responses::{mask_tools_as_mcp, patch_streaming_response_json};
use super::streaming::handle_streaming_response;
// ============================================================================ // ============================================================================
// OpenAIRouter Struct // OpenAIRouter Struct
......
...@@ -7,11 +7,8 @@ ...@@ -7,11 +7,8 @@
//! - MCP tool execution loops within streaming responses //! - MCP tool execution loops within streaming responses
//! - Event transformation and output index remapping //! - Event transformation and output index remapping
use crate::data_connector::{ use std::{borrow::Cow, io, sync::Arc};
SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage,
};
use crate::protocols::responses::{ResponseToolType, ResponsesRequest};
use crate::routers::header_utils::{apply_request_headers, preserve_response_headers};
use axum::{ use axum::{
body::Body, body::Body,
http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode}, http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
...@@ -20,20 +17,28 @@ use axum::{ ...@@ -20,20 +17,28 @@ use axum::{
use bytes::Bytes; use bytes::Bytes;
use futures_util::StreamExt; use futures_util::StreamExt;
use serde_json::{json, Value}; use serde_json::{json, Value};
use std::{borrow::Cow, io, sync::Arc};
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::warn; use tracing::warn;
// Import from sibling modules // Import from sibling modules
use super::conversations::persist_conversation_items; use super::conversations::persist_conversation_items;
use super::mcp::{ use super::{
build_resume_payload, execute_streaming_tool_calls, inject_mcp_metadata_streaming, mcp::{
mcp_manager_from_request_tools, prepare_mcp_payload_for_streaming, send_mcp_list_tools_events, build_resume_payload, execute_streaming_tool_calls, inject_mcp_metadata_streaming,
McpLoopConfig, ToolLoopState, mcp_manager_from_request_tools, prepare_mcp_payload_for_streaming,
send_mcp_list_tools_events, McpLoopConfig, ToolLoopState,
},
responses::{mask_tools_as_mcp, patch_streaming_response_json, rewrite_streaming_block},
utils::{event_types, FunctionCallInProgress, OutputIndexMapper, StreamAction},
};
use crate::{
data_connector::{
SharedConversationItemStorage, SharedConversationStorage, SharedResponseStorage,
},
protocols::responses::{ResponseToolType, ResponsesRequest},
routers::header_utils::{apply_request_headers, preserve_response_headers},
}; };
use super::responses::{mask_tools_as_mcp, patch_streaming_response_json, rewrite_streaming_block};
use super::utils::{event_types, FunctionCallInProgress, OutputIndexMapper, StreamAction};
// ============================================================================ // ============================================================================
// Streaming Response Accumulator // Streaming Response Accumulator
......
...@@ -4,16 +4,8 @@ ...@@ -4,16 +4,8 @@
//! - Single Router Mode (enable_igw=false): Router owns workers directly //! - Single Router Mode (enable_igw=false): Router owns workers directly
//! - Multi-Router Mode (enable_igw=true): RouterManager coordinates everything //! - Multi-Router Mode (enable_igw=true): RouterManager coordinates everything
use crate::config::{ConnectionMode, RoutingMode}; use std::sync::Arc;
use crate::core::{WorkerRegistry, WorkerType};
use crate::protocols::chat::ChatCompletionRequest;
use crate::protocols::completion::CompletionRequest;
use crate::protocols::embedding::EmbeddingRequest;
use crate::protocols::generate::GenerateRequest;
use crate::protocols::rerank::RerankRequest;
use crate::protocols::responses::{ResponsesGetParams, ResponsesRequest};
use crate::routers::RouterTrait;
use crate::server::{AppContext, ServerConfig};
use async_trait::async_trait; use async_trait::async_trait;
use axum::{ use axum::{
body::Body, body::Body,
...@@ -23,9 +15,23 @@ use axum::{ ...@@ -23,9 +15,23 @@ use axum::{
}; };
use dashmap::DashMap; use dashmap::DashMap;
use serde_json::Value; use serde_json::Value;
use std::sync::Arc;
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
use crate::{
config::{ConnectionMode, RoutingMode},
core::{WorkerRegistry, WorkerType},
protocols::{
chat::ChatCompletionRequest,
completion::CompletionRequest,
embedding::EmbeddingRequest,
generate::GenerateRequest,
rerank::RerankRequest,
responses::{ResponsesGetParams, ResponsesRequest},
},
routers::RouterTrait,
server::{AppContext, ServerConfig},
};
#[derive(Debug, Clone, Hash, Eq, PartialEq)] #[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct RouterId(String); pub struct RouterId(String);
......
use std::{
sync::{
atomic::{AtomicBool, Ordering},
Arc, OnceLock,
},
time::Duration,
};
use axum::{
extract::{Path, Query, Request, State},
http::StatusCode,
response::{IntoResponse, Response},
routing::{delete, get, post},
serve, Json, Router,
};
use reqwest::Client;
use serde::Deserialize;
use serde_json::{json, Value};
use tokio::{net::TcpListener, signal, spawn};
use tracing::{error, info, warn, Level};
use crate::{ use crate::{
config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode}, config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode},
core::{ core::{
...@@ -30,24 +51,6 @@ use crate::{ ...@@ -30,24 +51,6 @@ use crate::{
tokenizer::{factory as tokenizer_factory, traits::Tokenizer}, tokenizer::{factory as tokenizer_factory, traits::Tokenizer},
tool_parser::ParserFactory as ToolParserFactory, tool_parser::ParserFactory as ToolParserFactory,
}; };
use axum::{
extract::{Path, Query, Request, State},
http::StatusCode,
response::{IntoResponse, Response},
routing::{delete, get, post},
serve, Json, Router,
};
use reqwest::Client;
use serde::Deserialize;
use serde_json::{json, Value};
use std::sync::OnceLock;
use std::{
sync::atomic::{AtomicBool, Ordering},
sync::Arc,
time::Duration,
};
use tokio::{net::TcpListener, signal, spawn};
use tracing::{error, info, warn, Level};
// //
......
use crate::core::WorkerManager; use std::{
use crate::protocols::worker_spec::WorkerConfigRequest; collections::{HashMap, HashSet},
use crate::server::AppContext; sync::{Arc, Mutex},
time::Duration,
};
use futures::{StreamExt, TryStreamExt}; use futures::{StreamExt, TryStreamExt};
use k8s_openapi::api::core::v1::Pod; use k8s_openapi::api::core::v1::Pod;
use kube::{ use kube::{
api::Api, api::Api,
runtime::watcher::{watcher, Config}, runtime::{
runtime::WatchStreamExt, watcher::{watcher, Config},
WatchStreamExt,
},
Client, Client,
}; };
use std::collections::{HashMap, HashSet};
use rustls; use rustls;
use std::sync::{Arc, Mutex}; use tokio::{task, time};
use std::time::Duration;
use tokio::task;
use tokio::time;
use tracing::{debug, error, info, warn}; use tracing::{debug, error, info, warn};
use crate::{core::WorkerManager, protocols::worker_spec::WorkerConfigRequest, server::AppContext};
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ServiceDiscoveryConfig { pub struct ServiceDiscoveryConfig {
pub enabled: bool, pub enabled: bool,
...@@ -452,10 +453,12 @@ async fn handle_pod_deletion( ...@@ -452,10 +453,12 @@ async fn handle_pod_deletion(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use k8s_openapi::{
api::core::v1::{Pod, PodCondition, PodSpec, PodStatus},
apimachinery::pkg::apis::meta::v1::{ObjectMeta, Time},
};
use super::*; use super::*;
use k8s_openapi::api::core::v1::{Pod, PodCondition, PodSpec, PodStatus};
use k8s_openapi::apimachinery::pkg::apis::meta::v1::ObjectMeta;
use k8s_openapi::apimachinery::pkg::apis::meta::v1::Time;
fn create_k8s_pod( fn create_k8s_pod(
name: Option<&str>, name: Option<&str>,
...@@ -535,8 +538,7 @@ mod tests { ...@@ -535,8 +538,7 @@ mod tests {
} }
async fn create_test_app_context() -> Arc<AppContext> { async fn create_test_app_context() -> Arc<AppContext> {
use crate::config::RouterConfig; use crate::{config::RouterConfig, middleware::TokenBucket};
use crate::middleware::TokenBucket;
let router_config = RouterConfig { let router_config = RouterConfig {
worker_startup_timeout_secs: 1, worker_startup_timeout_secs: 1,
......
...@@ -3,11 +3,15 @@ ...@@ -3,11 +3,15 @@
//! This module provides functionality to apply chat templates to messages, //! This module provides functionality to apply chat templates to messages,
//! similar to HuggingFace transformers' apply_chat_template method. //! similar to HuggingFace transformers' apply_chat_template method.
use std::collections::HashMap;
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use minijinja::machinery::ast::{Expr, Stmt}; use minijinja::{
use minijinja::{context, Environment, Value}; context,
machinery::ast::{Expr, Stmt},
Environment, Value,
};
use serde_json; use serde_json;
use std::collections::HashMap;
/// Chat template content format /// Chat template content format
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
...@@ -319,8 +323,10 @@ impl<'a> Detector<'a> { ...@@ -319,8 +323,10 @@ impl<'a> Detector<'a> {
/// AST-based detection using minijinja's unstable machinery /// AST-based detection using minijinja's unstable machinery
/// Single-pass detector with scope tracking /// Single-pass detector with scope tracking
fn detect_format_with_ast(template: &str) -> Option<ChatTemplateContentFormat> { fn detect_format_with_ast(template: &str) -> Option<ChatTemplateContentFormat> {
use minijinja::machinery::{parse, WhitespaceConfig}; use minijinja::{
use minijinja::syntax::SyntaxConfig; machinery::{parse, WhitespaceConfig},
syntax::SyntaxConfig,
};
let ast = match parse( let ast = match parse(
template, template,
......
use super::traits; use std::{fs::File, io::Read, path::Path, sync::Arc};
use anyhow::{Error, Result}; use anyhow::{Error, Result};
use std::fs::File;
use std::io::Read;
use std::path::Path;
use std::sync::Arc;
use tracing::{debug, info}; use tracing::{debug, info};
use super::huggingface::HuggingFaceTokenizer; use super::{huggingface::HuggingFaceTokenizer, tiktoken::TiktokenTokenizer, traits};
use super::tiktoken::TiktokenTokenizer;
use crate::tokenizer::hub::download_tokenizer_from_hf; use crate::tokenizer::hub::download_tokenizer_from_hf;
/// Represents the type of tokenizer being used /// Represents the type of tokenizer being used
...@@ -379,8 +375,7 @@ pub fn get_tokenizer_info(file_path: &str) -> Result<TokenizerType> { ...@@ -379,8 +375,7 @@ pub fn get_tokenizer_info(file_path: &str) -> Result<TokenizerType> {
Some("json") => Ok(TokenizerType::HuggingFace(file_path.to_string())), Some("json") => Ok(TokenizerType::HuggingFace(file_path.to_string())),
_ => { _ => {
// Try auto-detection // Try auto-detection
use std::fs::File; use std::{fs::File, io::Read};
use std::io::Read;
let mut file = File::open(file_path)?; let mut file = File::open(file_path)?;
let mut buffer = vec![0u8; 512]; let mut buffer = vec![0u8; 512];
......
use std::{
env,
path::{Path, PathBuf},
};
use hf_hub::api::tokio::ApiBuilder; use hf_hub::api::tokio::ApiBuilder;
use std::env;
use std::path::{Path, PathBuf};
const IGNORED: [&str; 5] = [ const IGNORED: [&str; 5] = [
".gitattributes", ".gitattributes",
......
...@@ -3,12 +3,12 @@ use std::collections::HashMap; ...@@ -3,12 +3,12 @@ use std::collections::HashMap;
use anyhow::{Error, Result}; use anyhow::{Error, Result};
use tokenizers::tokenizer::Tokenizer as HfTokenizer; use tokenizers::tokenizer::Tokenizer as HfTokenizer;
use super::chat_template::{ use super::{
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams, chat_template::{
ChatTemplateProcessor, detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams,
}; ChatTemplateProcessor,
use super::traits::{ },
Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait, traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait},
}; };
/// HuggingFace tokenizer wrapper /// HuggingFace tokenizer wrapper
......
//! Mock tokenizer implementation for testing //! Mock tokenizer implementation for testing
use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
use anyhow::Result;
use std::collections::HashMap; use std::collections::HashMap;
use anyhow::Result;
use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
/// Mock tokenizer for testing purposes /// Mock tokenizer for testing purposes
pub struct MockTokenizer { pub struct MockTokenizer {
vocab: HashMap<String, u32>, vocab: HashMap<String, u32>,
......
use std::{ops::Deref, sync::Arc};
use anyhow::Result; use anyhow::Result;
use std::ops::Deref;
use std::sync::Arc;
pub mod factory; pub mod factory;
pub mod hub; pub mod hub;
...@@ -27,14 +27,12 @@ pub use factory::{ ...@@ -27,14 +27,12 @@ pub use factory::{
create_tokenizer_from_file, create_tokenizer_with_chat_template, create_tokenizer_from_file, create_tokenizer_with_chat_template,
create_tokenizer_with_chat_template_blocking, TokenizerType, create_tokenizer_with_chat_template_blocking, TokenizerType,
}; };
pub use huggingface::HuggingFaceTokenizer;
pub use sequence::Sequence; pub use sequence::Sequence;
pub use stop::{SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder}; pub use stop::{SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder};
pub use stream::DecodeStream; pub use stream::DecodeStream;
pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
pub use huggingface::HuggingFaceTokenizer;
pub use tiktoken::{TiktokenModel, TiktokenTokenizer}; pub use tiktoken::{TiktokenModel, TiktokenTokenizer};
pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
/// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations /// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations
#[derive(Clone)] #[derive(Clone)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment