Unverified Commit 83804bc6 authored by Chang Su, committed by GitHub

[router][grpc] Restructure modules and code clean up (#12598)

parent d5fa019c
//! Pipeline stages for regular (non-harmony) model processing
//!
//! This module defines stages specific to regular tokenizer-based models.

pub mod chat;
pub mod generate;

mod preparation;
mod request_building;
mod response_processing;

pub use chat::{ChatPreparationStage, ChatRequestBuildingStage, ChatResponseProcessingStage};
pub use generate::{
    GeneratePreparationStage, GenerateRequestBuildingStage, GenerateResponseProcessingStage,
};
pub use preparation::PreparationStage;
pub use request_building::RequestBuildingStage;
pub use response_processing::ResponseProcessingStage;
//! Preparation stage that delegates to endpoint-specific implementations
//!
//! This stage checks RequestType at runtime and delegates to the appropriate
//! endpoint-specific stage (ChatPreparationStage or GeneratePreparationStage).

use async_trait::async_trait;
use axum::response::Response;

use super::{chat::ChatPreparationStage, generate::GeneratePreparationStage};
use crate::routers::grpc::{
    common::stages::PipelineStage,
    context::{RequestContext, RequestType},
};

/// Preparation stage (delegates to endpoint-specific implementations)
pub struct PreparationStage {
    chat_stage: ChatPreparationStage,
    generate_stage: GeneratePreparationStage,
}

impl PreparationStage {
    pub fn new() -> Self {
        Self {
            chat_stage: ChatPreparationStage,
            generate_stage: GeneratePreparationStage,
        }
    }
}

impl Default for PreparationStage {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl PipelineStage for PreparationStage {
    async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
        match &ctx.input.request_type {
            RequestType::Chat(_) => self.chat_stage.execute(ctx).await,
            RequestType::Generate(_) => self.generate_stage.execute(ctx).await,
            RequestType::Responses(_) => {
                // Responses API has its own preparation handled elsewhere
                Ok(None)
            }
        }
    }

    fn name(&self) -> &'static str {
        "Preparation"
    }
}
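All three delegating stages introduced by this restructure (preparation, request building, response processing) share one shape: hold an endpoint-specific stage per request type and match on the request type at runtime. A minimal, self-contained sketch of that dispatch pattern follows; Stage, Ctx, and Kind are hypothetical stand-ins, synchronous for brevity, not the router's real async PipelineStage and RequestContext.

// Minimal sketch of the runtime-delegation pattern used by PreparationStage,
// RequestBuildingStage, and ResponseProcessingStage. Stage, Ctx, and Kind are
// hypothetical stand-ins; the real trait is async and returns axum Responses.

enum Kind {
    Chat,
    Generate,
    Responses,
}

struct Ctx {
    kind: Kind,
    log: Vec<&'static str>,
}

trait Stage {
    fn execute(&self, ctx: &mut Ctx) -> Result<(), String>;
    fn name(&self) -> &'static str;
}

struct ChatPrep;

impl Stage for ChatPrep {
    fn execute(&self, ctx: &mut Ctx) -> Result<(), String> {
        ctx.log.push("chat preparation");
        Ok(())
    }

    fn name(&self) -> &'static str {
        "ChatPrep"
    }
}

struct GeneratePrep;

impl Stage for GeneratePrep {
    fn execute(&self, ctx: &mut Ctx) -> Result<(), String> {
        ctx.log.push("generate preparation");
        Ok(())
    }

    fn name(&self) -> &'static str {
        "GeneratePrep"
    }
}

// Delegating stage: inspects the request kind at runtime and forwards to the
// matching endpoint-specific stage, exactly the shape of the stages above.
struct Prep {
    chat: ChatPrep,
    generate: GeneratePrep,
}

impl Stage for Prep {
    fn execute(&self, ctx: &mut Ctx) -> Result<(), String> {
        match ctx.kind {
            Kind::Chat => self.chat.execute(ctx),
            Kind::Generate => self.generate.execute(ctx),
            // Responses requests are prepared elsewhere, so this is a no-op.
            Kind::Responses => Ok(()),
        }
    }

    fn name(&self) -> &'static str {
        "Preparation"
    }
}

fn main() {
    let stage = Prep { chat: ChatPrep, generate: GeneratePrep };
    let mut ctx = Ctx { kind: Kind::Chat, log: Vec::new() };
    stage.execute(&mut ctx).unwrap();
    assert_eq!(ctx.log, ["chat preparation"]);
    println!("ran stage: {}", stage.name());
}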
//! Request building stage that delegates to endpoint-specific implementations

use async_trait::async_trait;
use axum::response::Response;
use uuid::Uuid;

use super::{chat::ChatRequestBuildingStage, generate::GenerateRequestBuildingStage};
use crate::{
    grpc_client::proto,
    routers::grpc::{
        common::stages::PipelineStage,
        context::{RequestContext, RequestType},
    },
};

/// Request building stage (delegates to endpoint-specific implementations)
pub struct RequestBuildingStage {
    chat_stage: ChatRequestBuildingStage,
    generate_stage: GenerateRequestBuildingStage,
}

impl RequestBuildingStage {
    pub fn new(inject_pd_metadata: bool) -> Self {
        Self {
            chat_stage: ChatRequestBuildingStage::new(inject_pd_metadata),
            generate_stage: GenerateRequestBuildingStage::new(inject_pd_metadata),
        }
    }
}

#[async_trait]
impl PipelineStage for RequestBuildingStage {
    async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
        match &ctx.input.request_type {
            RequestType::Chat(_) => self.chat_stage.execute(ctx).await,
            RequestType::Generate(_) => self.generate_stage.execute(ctx).await,
            RequestType::Responses(_request) => {
                // The Responses API builds its request during the MCP loop.
                // For now, create a minimal request; the responses handler will populate it.
                let request_id = format!("resp-{}", Uuid::new_v4());
                ctx.state.proto_request = Some(proto::GenerateRequest {
                    request_id,
                    ..Default::default()
                });
                Ok(None)
            }
        }
    }

    fn name(&self) -> &'static str {
        "RequestBuilding"
    }
}
//! Response processing stage that delegates to endpoint-specific implementations

use std::sync::Arc;

use async_trait::async_trait;
use axum::response::Response;

use super::{chat::ChatResponseProcessingStage, generate::GenerateResponseProcessingStage};
use crate::routers::grpc::{
    common::stages::PipelineStage,
    context::{RequestContext, RequestType},
    error,
    regular::{processor, streaming},
};

/// Response processing stage (delegates to endpoint-specific implementations)
pub struct ResponseProcessingStage {
    chat_stage: ChatResponseProcessingStage,
    generate_stage: GenerateResponseProcessingStage,
}

impl ResponseProcessingStage {
    pub fn new(
        processor: processor::ResponseProcessor,
        streaming_processor: Arc<streaming::StreamingProcessor>,
    ) -> Self {
        Self {
            chat_stage: ChatResponseProcessingStage::new(
                processor.clone(),
                streaming_processor.clone(),
            ),
            generate_stage: GenerateResponseProcessingStage::new(processor, streaming_processor),
        }
    }
}

#[async_trait]
impl PipelineStage for ResponseProcessingStage {
    async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
        match &ctx.input.request_type {
            RequestType::Chat(_) => self.chat_stage.execute(ctx).await,
            RequestType::Generate(_) => self.generate_stage.execute(ctx).await,
            RequestType::Responses(_) => Err(error::bad_request(
                "Responses API processing must be handled by responses handler".to_string(),
            )),
        }
    }

    fn name(&self) -> &'static str {
        "ResponseProcessing"
    }
}
//! Streaming response processor for gRPC routers
//!
-//! This module contains shared streaming logic for both Regular and PD routers,
-//! eliminating ~600 lines of duplication.
+//! This module contains shared streaming logic for both the Regular and PD routers.

use std::{collections::HashMap, io, sync::Arc, time::Instant};
@@ -17,9 +16,8 @@ use tokio::sync::{mpsc, mpsc::UnboundedSender};
use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt};
use tracing::{debug, error, warn};

-use super::{context, utils};
use crate::{
-    grpc_client::proto,
+    grpc_client::{proto, sglang_scheduler::AbortOnDropStream},
    protocols::{
        chat::{
            ChatCompletionRequest, ChatCompletionStreamResponse, ChatMessageDelta, ChatStreamChoice,
@@ -30,20 +28,21 @@ use crate::{
        },
        generate::GenerateRequest,
    },
-    reasoning_parser::ReasoningParser,
+    reasoning_parser::{ParserFactory as ReasoningParserFactory, ParserResult, ReasoningParser},
+    routers::grpc::{context, utils},
    tokenizer::{
        stop::{SequenceDecoderOutput, StopSequenceDecoder},
        traits::Tokenizer,
    },
-    tool_parser::ToolParser,
+    tool_parser::{ParserFactory as ToolParserFactory, StreamingParseResult, ToolParser},
};

/// Shared streaming processor for both single and dual dispatch modes
#[derive(Clone)]
pub struct StreamingProcessor {
    tokenizer: Arc<dyn Tokenizer>,
-    tool_parser_factory: crate::tool_parser::ParserFactory,
-    reasoning_parser_factory: crate::reasoning_parser::ParserFactory,
+    tool_parser_factory: ToolParserFactory,
+    reasoning_parser_factory: ReasoningParserFactory,
    configured_tool_parser: Option<String>,
    configured_reasoning_parser: Option<String>,
}
@@ -51,8 +50,8 @@ pub struct StreamingProcessor {
impl StreamingProcessor {
    pub fn new(
        tokenizer: Arc<dyn Tokenizer>,
-        tool_parser_factory: crate::tool_parser::ParserFactory,
-        reasoning_parser_factory: crate::reasoning_parser::ParserFactory,
+        tool_parser_factory: ToolParserFactory,
+        reasoning_parser_factory: ReasoningParserFactory,
        configured_tool_parser: Option<String>,
        configured_reasoning_parser: Option<String>,
    ) -> Self {
@@ -161,7 +160,7 @@ impl StreamingProcessor {
    /// Process streaming chunks from a single stream (Regular mode)
    pub async fn process_streaming_chunks(
        &self,
-        mut grpc_stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
+        mut grpc_stream: AbortOnDropStream,
        dispatch: context::DispatchMetadata,
        stop_params: (Option<StringOrArray>, Option<Vec<u32>>, bool, bool),
        original_request: Arc<ChatCompletionRequest>,
@@ -576,8 +575,8 @@ impl StreamingProcessor {
    /// Process dual streaming chunks (prefill + decode) - PD mode
    pub async fn process_dual_streaming_chunks(
        &self,
-        mut prefill_stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
-        decode_stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
+        mut prefill_stream: AbortOnDropStream,
+        decode_stream: AbortOnDropStream,
        dispatch: context::DispatchMetadata,
        stop_params: (Option<StringOrArray>, Option<Vec<u32>>, bool, bool),
        original_request: Arc<ChatCompletionRequest>,
@@ -696,7 +695,7 @@ impl StreamingProcessor {
    /// Process streaming chunks for generate endpoint (no tool/reasoning parsing)
    async fn process_generate_streaming(
        tokenizer: Arc<dyn Tokenizer>,
-        mut stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
+        mut stream: AbortOnDropStream,
        request_id: String,
        weight_version: String,
        _include_logprobs: bool,
@@ -800,8 +799,8 @@ impl StreamingProcessor {
    /// Process dual streaming for generate endpoint (PD mode with logprobs support)
    async fn process_generate_streaming_dual(
        tokenizer: Arc<dyn Tokenizer>,
-        mut prefill_stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
-        decode_stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
+        mut prefill_stream: AbortOnDropStream,
+        decode_stream: AbortOnDropStream,
        request_id: String,
        weight_version: String,
        return_logprob: bool,
@@ -857,7 +856,7 @@ impl StreamingProcessor {
    /// Process generate streaming with optional input_logprobs
    async fn process_generate_streaming_with_input_logprobs(
        tokenizer: Arc<dyn Tokenizer>,
-        mut stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
+        mut stream: AbortOnDropStream,
        request_id: String,
        weight_version: String,
        _include_logprobs: bool,
@@ -1051,7 +1050,7 @@ impl StreamingProcessor {
        };

        match parse_result {
-            Ok(crate::reasoning_parser::ParserResult {
+            Ok(ParserResult {
                reasoning_text,
                normal_text,
            }) => {
@@ -1122,7 +1121,7 @@ impl StreamingProcessor {
            let mut parser = pooled_parser.lock().await;
            match parser.parse_incremental(delta, tools).await {
-                Ok(crate::tool_parser::StreamingParseResult { normal_text, calls }) => {
+                Ok(StreamingParseResult { normal_text, calls }) => {
                    // Emit normal text if present
                    if !normal_text.is_empty() {
                        chunks.push(ChatCompletionStreamResponse {
...
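A note on the import change above: the long path crate::grpc_client::sglang_scheduler::AbortOnDropStream is now imported once and used by its short name throughout. Judging by the name, the wrapper cancels the scheduler-side request when the stream is dropped. Below is a toy, self-contained illustration of that RAII shape; the channel plumbing is hypothetical and far simpler than the real type.

use tokio::sync::mpsc;

// Toy abort-on-drop stream: the wrapper owns the receiving end of a channel,
// so dropping it closes the channel, and the producer observes the send error
// as a cancellation signal. The real AbortOnDropStream additionally aborts
// the request on the scheduler side; this only shows the RAII idea.
struct AbortOnDropStream<T> {
    rx: mpsc::UnboundedReceiver<T>,
}

impl<T> AbortOnDropStream<T> {
    async fn next(&mut self) -> Option<T> {
        self.rx.recv().await
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::unbounded_channel::<u32>();

    let producer = tokio::spawn(async move {
        let mut sent = 0u32;
        loop {
            // A closed channel means the consumer dropped its stream: abort.
            if tx.send(sent).is_err() {
                return sent;
            }
            sent += 1;
            tokio::task::yield_now().await;
        }
    });

    let mut stream = AbortOnDropStream { rx };
    assert_eq!(stream.next().await, Some(0));
    drop(stream); // consumer gives up mid-stream

    let sent_before_abort = producer.await.unwrap();
    println!("producer aborted after sending {} chunks", sent_before_abort);
}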
// gRPC Router Implementation

use std::sync::Arc;

use async_trait::async_trait;
@@ -12,13 +10,14 @@ use axum::{
use tracing::debug;

use super::{
+    common::responses::handlers::{cancel_response_impl, get_response_impl},
    context::SharedComponents,
    harmony::{
        serve_harmony_responses, serve_harmony_responses_stream, HarmonyDetector,
        HarmonyResponsesContext,
    },
    pipeline::RequestPipeline,
-    responses,
+    regular::responses,
};
use crate::{
    app_context::AppContext,
@@ -43,9 +42,7 @@ pub struct GrpcRouter {
    pipeline: RequestPipeline,
    harmony_pipeline: RequestPipeline,
    shared_components: Arc<SharedComponents>,
-    // Responses context (bundles all /v1/responses dependencies: storage, MCP, background_tasks)
    responses_context: responses::ResponsesContext,
-    // Harmony responses context (uses harmony pipeline)
    harmony_responses_context: responses::ResponsesContext,
}
@@ -156,7 +153,6 @@ impl GrpcRouter {
            &self.pipeline
        };

-        // Use selected pipeline for ALL requests (streaming and non-streaming)
        pipeline
            .execute_chat(
                Arc::new(body.clone()),
@@ -176,7 +172,6 @@ impl GrpcRouter {
    ) -> Response {
        debug!("Processing generate request for model: {:?}", model_id);

-        // Use pipeline for ALL requests (streaming and non-streaming)
        self.pipeline
            .execute_generate(
                Arc::new(body.clone()),
@@ -187,35 +182,51 @@ impl GrpcRouter {
            .await
    }

-    /// Main route_responses implementation (pipeline-based for Harmony)
+    /// Main route_responses implementation
+    ///
+    /// Routes to either the Harmony or the regular responses implementation based on model detection
    async fn route_responses_impl(
        &self,
-        _headers: Option<&HeaderMap>,
+        headers: Option<&HeaderMap>,
        body: &ResponsesRequest,
        model_id: Option<&str>,
    ) -> Response {
-        debug!(
-            "Processing Harmony responses request for model: {:?}, streaming: {:?}",
-            model_id, body.stream
-        );
-
-        // Create HarmonyResponsesContext from existing responses context
-        let harmony_ctx = HarmonyResponsesContext::new(
-            Arc::new(self.harmony_pipeline.clone()),
-            self.shared_components.clone(),
-            self.harmony_responses_context.mcp_manager.clone(),
-            self.harmony_responses_context.response_storage.clone(),
-        );
-
-        // Check if streaming is requested
-        if body.stream.unwrap_or(false) {
-            serve_harmony_responses_stream(&harmony_ctx, body.clone()).await
-        } else {
-            // Use non-streaming version for standard JSON responses
-            match serve_harmony_responses(&harmony_ctx, body.clone()).await {
-                Ok(response) => axum::Json(response).into_response(),
-                Err(error_response) => error_response,
+        // Choose implementation based on Harmony model detection
+        let is_harmony = HarmonyDetector::is_harmony_model(&body.model);
+
+        debug!(
+            "Processing responses request for model: {:?}, using_harmony={}",
+            model_id, is_harmony
+        );
+
+        if is_harmony {
+            debug!(
+                "Processing Harmony responses request for model: {:?}, streaming: {:?}",
+                model_id, body.stream
+            );
+
+            let harmony_ctx = HarmonyResponsesContext::new(
+                Arc::new(self.harmony_pipeline.clone()),
+                self.shared_components.clone(),
+                self.harmony_responses_context.mcp_manager.clone(),
+                self.harmony_responses_context.response_storage.clone(),
+            );
+
+            if body.stream.unwrap_or(false) {
+                serve_harmony_responses_stream(&harmony_ctx, body.clone()).await
+            } else {
+                match serve_harmony_responses(&harmony_ctx, body.clone()).await {
+                    Ok(response) => axum::Json(response).into_response(),
+                    Err(error_response) => error_response,
+                }
            }
+        } else {
+            responses::route_responses(
+                &self.responses_context,
+                Arc::new(body.clone()),
+                headers.cloned(),
+                model_id.map(|s| s.to_string()),
+            )
+            .await
        }
    }
}
@@ -236,7 +247,6 @@ impl RouterTrait for GrpcRouter {
    }

    async fn health_generate(&self, _req: Request<Body>) -> Response {
-        // TODO: Implement actual generation test for gRPC
        (
            StatusCode::NOT_IMPLEMENTED,
            "Health generate not yet implemented for gRPC",
@@ -289,27 +299,7 @@ impl RouterTrait for GrpcRouter {
        body: &ResponsesRequest,
        model_id: Option<&str>,
    ) -> Response {
-        // Choose implementation based on Harmony model detection
-        let is_harmony = HarmonyDetector::is_harmony_model(&body.model);
-
-        debug!(
-            "Processing responses request for model: {:?}, using_harmony={}",
-            model_id, is_harmony
-        );
-
-        if is_harmony {
-            // Use pipeline-based implementation for Harmony models
-            self.route_responses_impl(headers, body, model_id).await
-        } else {
-            // Use legacy responses module for non-Harmony models
-            responses::route_responses(
-                &self.responses_context,
-                Arc::new(body.clone()),
-                headers.cloned(),
-                model_id.map(|s| s.to_string()),
-            )
-            .await
-        }
+        self.route_responses_impl(headers, body, model_id).await
    }

    async fn get_response(
@@ -318,11 +308,11 @@ impl RouterTrait for GrpcRouter {
        response_id: &str,
        _params: &ResponsesGetParams,
    ) -> Response {
-        responses::get_response_impl(&self.responses_context, response_id).await
+        get_response_impl(&self.responses_context, response_id).await
    }

    async fn cancel_response(&self, _headers: Option<&HeaderMap>, response_id: &str) -> Response {
-        responses::cancel_response_impl(&self.responses_context, response_id).await
+        cancel_response_impl(&self.responses_context, response_id).await
    }

    async fn route_embeddings(
...
//! Request building stage: Build proto GenerateRequest

use std::sync::Arc;

use async_trait::async_trait;
use axum::response::Response;
use proto::DisaggregatedParams;
use rand::Rng;
use tracing::debug;
use uuid::Uuid;

use super::PipelineStage;
use crate::{
    core::Worker,
    grpc_client::proto,
    routers::grpc::{
        context::{ClientSelection, RequestContext, RequestType, WorkerSelection},
        error,
    },
};

/// Request building stage: Build proto GenerateRequest
pub struct RequestBuildingStage {
    inject_pd_metadata: bool,
}

impl RequestBuildingStage {
    pub fn new(inject_pd_metadata: bool) -> Self {
        Self { inject_pd_metadata }
    }
}

#[async_trait]
impl PipelineStage for RequestBuildingStage {
    async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
        let prep = ctx
            .state
            .preparation
            .as_ref()
            .ok_or_else(|| error::internal_error("Preparation not completed"))?;
        let clients = ctx
            .state
            .clients
            .as_ref()
            .ok_or_else(|| error::internal_error("Client acquisition not completed"))?;

        // Get the client for building the request (use the prefill client in PD mode)
        let builder_client = match clients {
            ClientSelection::Single { client } => client,
            ClientSelection::Dual { prefill, .. } => prefill,
        };

        let mut proto_request = match &ctx.input.request_type {
            RequestType::Chat(request) => {
                let request_id = format!("chatcmpl-{}", Uuid::new_v4());
                let body_ref = prep.filtered_request.as_ref().unwrap_or(request);
                builder_client
                    .build_generate_request(
                        request_id,
                        body_ref,
                        prep.processed_messages.as_ref().unwrap().text.clone(),
                        prep.token_ids.clone(),
                        prep.processed_messages
                            .as_ref()
                            .unwrap()
                            .multimodal_inputs
                            .clone(),
                        prep.tool_constraints.clone(),
                    )
                    .map_err(|e| error::bad_request(format!("Invalid request parameters: {}", e)))?
            }
            RequestType::Generate(request) => {
                let request_id = request
                    .rid
                    .clone()
                    .unwrap_or_else(|| format!("gen-{}", Uuid::new_v4()));
                builder_client
                    .build_plain_generate_request(
                        request_id,
                        request,
                        prep.original_text.clone(),
                        prep.token_ids.clone(),
                    )
                    .map_err(error::bad_request)?
            }
            RequestType::Responses(_request) => {
                // The Responses API builds its request during the MCP loop.
                // For now, create a minimal request; the responses handler will populate it.
                let request_id = format!("resp-{}", Uuid::new_v4());
                proto::GenerateRequest {
                    request_id,
                    ..Default::default()
                }
            }
        };

        // Inject PD metadata if needed
        if self.inject_pd_metadata {
            if let WorkerSelection::Dual { prefill, .. } = ctx.state.workers.as_ref().unwrap() {
                self.inject_bootstrap_metadata(&mut proto_request, prefill);
            }
        }

        ctx.state.proto_request = Some(proto_request);
        Ok(None)
    }

    fn name(&self) -> &'static str {
        "RequestBuilding"
    }
}

impl RequestBuildingStage {
    fn inject_bootstrap_metadata(
        &self,
        request: &mut proto::GenerateRequest,
        prefill_worker: &Arc<dyn Worker>,
    ) {
        let hostname = prefill_worker.bootstrap_host();
        let bootstrap_port = prefill_worker.bootstrap_port().unwrap_or(8998);

        // Generate a room ID for bootstrap
        let room_id = rand::rng().random_range(0..i32::MAX);

        // Create DisaggregatedParams
        let disagg_params = DisaggregatedParams {
            bootstrap_host: hostname.to_string(),
            bootstrap_port: bootstrap_port as i32,
            bootstrap_room: room_id,
        };

        // Inject the metadata directly into the request
        request.disaggregated_params = Some(disagg_params);

        debug!(
            "Injected bootstrap metadata: host={}, port={}, room={}",
            hostname, bootstrap_port, room_id
        );
    }
}
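For orientation, the PD bootstrap injection above reduces to attaching a (host, port, random room id) triple to the outgoing request so the prefill and decode sides of one request can rendezvous. A standalone sketch follows with a stand-in struct; the real proto::DisaggregatedParams is prost-generated, and the rand 0.9 API (rand::rng, random_range) is assumed, matching the call in the stage above.

use rand::Rng;

// Stand-in for the prost-generated proto::DisaggregatedParams.
#[derive(Debug)]
struct DisaggregatedParams {
    bootstrap_host: String,
    bootstrap_port: i32,
    bootstrap_room: i32,
}

fn inject_bootstrap_metadata(host: &str, port: Option<u16>) -> DisaggregatedParams {
    // Fall back to the default bootstrap port when the worker reports none.
    let bootstrap_port = port.unwrap_or(8998);
    // The random room id pairs the prefill and decode sides of one request.
    let room_id = rand::rng().random_range(0..i32::MAX);
    DisaggregatedParams {
        bootstrap_host: host.to_string(),
        bootstrap_port: bootstrap_port as i32,
        bootstrap_room: room_id,
    }
}

fn main() {
    let params = inject_bootstrap_metadata("10.0.0.5", None);
    assert_eq!(params.bootstrap_port, 8998);
    println!("{:?}", params);
}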
@@ -21,12 +21,19 @@ use crate::{
        },
        generate::GenerateFinishReason,
    },
+    reasoning_parser::{
+        ParserFactory as ReasoningParserFactory, PooledParser as ReasoningPooledParser,
+        ReasoningParser,
+    },
    tokenizer::{
        cache::CachedTokenizer,
        chat_template::{ChatTemplateContentFormat, ChatTemplateParams},
        traits::Tokenizer,
        HuggingFaceTokenizer,
    },
+    tool_parser::{
+        ParserFactory as ToolParserFactory, PooledParser as ToolPooledParser, ToolParser,
+    },
};

/// Get gRPC client from worker, returning appropriate error response on failure
@@ -44,20 +51,17 @@ pub async fn get_grpc_client_from_worker(
/// Process tool call arguments in messages
/// Per Transformers docs, tool call arguments in assistant messages should be dicts
-pub fn process_tool_call_arguments(messages: &mut [Value]) -> Result<(), String> {
+fn process_tool_call_arguments(messages: &mut [Value]) -> Result<(), String> {
    for msg in messages {
-        // Early return if not assistant message
        let role = msg.get("role").and_then(|v| v.as_str());
        if role != Some("assistant") {
            continue;
        }

-        // Early return if no tool_calls
        let Some(tool_calls) = msg.get_mut("tool_calls").and_then(|tc| tc.as_array_mut()) else {
            continue;
        };

-        // Process each tool call's arguments
        for call in tool_calls {
            let Some(function) = call.get_mut("function") else {
                continue;
@@ -107,10 +111,7 @@ pub fn process_content_format(
}

/// Transform a single content field based on content format
-pub fn transform_content_field(
-    content_value: &mut Value,
-    content_format: ChatTemplateContentFormat,
-) {
+fn transform_content_field(content_value: &mut Value, content_format: ChatTemplateContentFormat) {
    let Some(content_array) = content_value.as_array() else {
        return; // Not multimodal, keep as-is
    };
@@ -209,7 +210,7 @@ pub fn generate_tool_constraints(
/// Build JSON schema for required tool calls (array with minItems: 1)
/// Includes $defs consolidation from all tools (matching Python's behavior)
-pub fn build_required_array_schema(tools: &[Tool]) -> Result<String, String> {
+fn build_required_array_schema(tools: &[Tool]) -> Result<String, String> {
    // Build anyOf schemas for each tool
    let mut any_of_schemas = Vec::new();
    for tool in tools {
@@ -651,7 +652,7 @@ pub fn generate_tool_call_id(
/// Check if a reasoning parser is available for the given model
pub fn check_reasoning_parser_availability(
-    reasoning_parser_factory: &crate::reasoning_parser::ParserFactory,
+    reasoning_parser_factory: &ReasoningParserFactory,
    configured_parser: Option<&String>,
    model: &str,
) -> bool {
@@ -666,7 +667,7 @@ pub fn check_reasoning_parser_availability(
/// Check if a tool parser is available for the given model
pub fn check_tool_parser_availability(
-    tool_parser_factory: &crate::tool_parser::ParserFactory,
+    tool_parser_factory: &ToolParserFactory,
    configured_parser: Option<&String>,
    model: &str,
) -> bool {
@@ -683,10 +684,10 @@ pub fn check_tool_parser_availability(
/// Otherwise, auto-detect based on the model name.
/// Get a pooled reasoning parser (for non-streaming where state doesn't matter)
pub fn get_reasoning_parser(
-    reasoning_parser_factory: &crate::reasoning_parser::ParserFactory,
+    reasoning_parser_factory: &ReasoningParserFactory,
    configured_parser: Option<&String>,
    model: &str,
-) -> crate::reasoning_parser::PooledParser {
+) -> ReasoningPooledParser {
    if let Some(parser_name) = configured_parser {
        // Use configured parser if specified
        reasoning_parser_factory
@@ -707,10 +708,10 @@ pub fn get_reasoning_parser(
/// Create a fresh reasoning parser instance (for streaming where state isolation is needed)
pub fn create_reasoning_parser(
-    reasoning_parser_factory: &crate::reasoning_parser::ParserFactory,
+    reasoning_parser_factory: &ReasoningParserFactory,
    configured_parser: Option<&String>,
    model: &str,
-) -> Option<Box<dyn crate::reasoning_parser::ReasoningParser>> {
+) -> Option<Box<dyn ReasoningParser>> {
    if let Some(parser_name) = configured_parser {
        // Use configured parser if specified
        reasoning_parser_factory
@@ -735,10 +736,10 @@ pub fn create_reasoning_parser(
/// Otherwise, auto-detect based on the model name.
/// Get a pooled tool parser (for non-streaming where state doesn't matter)
pub fn get_tool_parser(
-    tool_parser_factory: &crate::tool_parser::ParserFactory,
+    tool_parser_factory: &ToolParserFactory,
    configured_parser: Option<&String>,
    model: &str,
-) -> crate::tool_parser::PooledParser {
+) -> ToolPooledParser {
    if let Some(parser_name) = configured_parser {
        // Use configured parser if specified
        tool_parser_factory
@@ -759,10 +760,10 @@ pub fn get_tool_parser(
/// Create a fresh tool parser instance (for streaming where state isolation is needed)
pub fn create_tool_parser(
-    tool_parser_factory: &crate::tool_parser::ParserFactory,
+    tool_parser_factory: &ToolParserFactory,
    configured_parser: Option<&String>,
    model: &str,
-) -> Option<Box<dyn crate::tool_parser::ToolParser>> {
+) -> Option<Box<dyn ToolParser>> {
    if let Some(parser_name) = configured_parser {
        // Use configured parser if specified
        tool_parser_factory
...
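The get_*/create_* pairs above encode one rule: a pooled parser is fine for non-streaming parsing, where leftover state is irrelevant, but each streaming request needs a fresh instance because incremental parsers accumulate state across chunks. A self-contained sketch of why follows; Factory and IncrementalParser are hypothetical stand-ins, not the router's real ParserFactory.

use std::collections::HashMap;
use std::sync::{Arc, Mutex};

// Hypothetical stateful parser: remembers all text fed so far.
#[derive(Default)]
struct IncrementalParser {
    buffer: String,
}

impl IncrementalParser {
    fn feed(&mut self, chunk: &str) -> usize {
        self.buffer.push_str(chunk);
        self.buffer.len()
    }
}

// Hypothetical factory: pools one parser per model name, or creates fresh ones.
#[derive(Default)]
struct Factory {
    pool: Mutex<HashMap<String, Arc<Mutex<IncrementalParser>>>>,
}

impl Factory {
    // Pooled: fine for one-shot (non-streaming) parsing, where leftover state
    // is irrelevant, and it avoids re-allocating a parser per request.
    fn get_pooled(&self, model: &str) -> Arc<Mutex<IncrementalParser>> {
        self.pool
            .lock()
            .unwrap()
            .entry(model.to_string())
            .or_default()
            .clone()
    }

    // Fresh: required for streaming, where two requests sharing one parser
    // would interleave their buffered chunks.
    fn create_fresh(&self, _model: &str) -> IncrementalParser {
        IncrementalParser::default()
    }
}

fn main() {
    let factory = Factory::default();
    let shared = factory.get_pooled("llama");
    shared.lock().unwrap().feed("hello");
    // A second "request" using the pooled parser sees the first one's state:
    assert_eq!(shared.lock().unwrap().feed(" world"), 11);
    // A fresh parser starts clean, isolating per-request streaming state.
    assert_eq!(factory.create_fresh("llama").feed("hello"), 5);
}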
@@ -18,20 +18,15 @@ use minijinja::{
use serde_json;

/// Chat template content format
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ChatTemplateContentFormat {
    /// Content is a simple string
+    #[default]
    String,
    /// Content is a list of structured parts (OpenAI format)
    OpenAI,
}

-impl Default for ChatTemplateContentFormat {
-    fn default() -> Self {
-        Self::String
-    }
-}
-
impl std::fmt::Display for ChatTemplateContentFormat {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
...
@@ -171,7 +171,7 @@ fn test_chatml_template() {
    let processor = ChatTemplateProcessor::new(template.to_string());

-    let messages = vec![
+    let messages = [
        ChatMessage::User {
            content: UserMessageContent::Text("Hello".to_string()),
            name: None,
...