Unverified Commit 83804bc6 authored by Chang Su, committed by GitHub

[router][grpc] Restructure modules and code clean up (#12598)

parent d5fa019c
//! Pipeline stages for regular (non-harmony) model processing
//!
//! This module defines stages specific to regular tokenizer-based models.
pub mod chat;
pub mod generate;
mod preparation;
mod request_building;
mod response_processing;
pub use chat::{ChatPreparationStage, ChatRequestBuildingStage, ChatResponseProcessingStage};
pub use generate::{
GeneratePreparationStage, GenerateRequestBuildingStage, GenerateResponseProcessingStage,
};
pub use preparation::PreparationStage;
pub use request_building::RequestBuildingStage;
pub use response_processing::ResponseProcessingStage;
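For orientation, the PipelineStage contract these modules implement (defined in common::stages) can be inferred from the impls below. A sketch only, not the exact definition:

use async_trait::async_trait;
use axum::response::Response;
use crate::routers::grpc::context::RequestContext;

// Sketch of the stage contract inferred from the impls in this commit.
// Ok(None)       -> stage done; the pipeline runs the next stage
// Ok(Some(resp)) -> early final response; the pipeline stops
// Err(resp)      -> error response; the pipeline stops
#[async_trait]
pub trait PipelineStage: Send + Sync {
    async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response>;
    fn name(&self) -> &'static str;
}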
//! Preparation stage that delegates to endpoint-specific implementations
//!
//! This stage checks RequestType at runtime and delegates to the appropriate
//! endpoint-specific stage (ChatPreparationStage or GeneratePreparationStage).
use async_trait::async_trait;
use axum::response::Response;
use super::{chat::ChatPreparationStage, generate::GeneratePreparationStage};
use crate::routers::grpc::{
common::stages::PipelineStage,
context::{RequestContext, RequestType},
};
/// Preparation stage (delegates to endpoint-specific implementations)
pub struct PreparationStage {
chat_stage: ChatPreparationStage,
generate_stage: GeneratePreparationStage,
}
impl PreparationStage {
pub fn new() -> Self {
Self {
chat_stage: ChatPreparationStage,
generate_stage: GeneratePreparationStage,
}
}
}
impl Default for PreparationStage {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl PipelineStage for PreparationStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
match &ctx.input.request_type {
RequestType::Chat(_) => self.chat_stage.execute(ctx).await,
RequestType::Generate(_) => self.generate_stage.execute(ctx).await,
RequestType::Responses(_) => {
// Responses API has its own preparation handled elsewhere
Ok(None)
}
}
}
fn name(&self) -> &'static str {
"Preparation"
}
}
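The Result<Option<Response>, Response> convention lets a driver run stages in order and short-circuit on an early response or an error. A minimal, hypothetical driver loop (the real RequestPipeline wiring lives elsewhere in this crate):

// Hypothetical driver; assumes `stages` holds boxed PipelineStage objects
// and `error::internal_error` builds an error Response as used below.
async fn run_stages(stages: &[Box<dyn PipelineStage>], ctx: &mut RequestContext) -> Response {
    for stage in stages {
        match stage.execute(ctx).await {
            Ok(None) => continue,                  // stage done, keep going
            Ok(Some(response)) => return response, // early final response
            Err(error_response) => return error_response,
        }
    }
    error::internal_error("pipeline finished without producing a response")
}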
//! Request building stage that delegates to endpoint-specific implementations
use async_trait::async_trait;
use axum::response::Response;
use uuid::Uuid;
use super::{chat::ChatRequestBuildingStage, generate::GenerateRequestBuildingStage};
use crate::{
grpc_client::proto,
routers::grpc::{
common::stages::PipelineStage,
context::{RequestContext, RequestType},
},
};
/// Request building stage (delegates to endpoint-specific implementations)
pub struct RequestBuildingStage {
chat_stage: ChatRequestBuildingStage,
generate_stage: GenerateRequestBuildingStage,
}
impl RequestBuildingStage {
pub fn new(inject_pd_metadata: bool) -> Self {
Self {
chat_stage: ChatRequestBuildingStage::new(inject_pd_metadata),
generate_stage: GenerateRequestBuildingStage::new(inject_pd_metadata),
}
}
}
#[async_trait]
impl PipelineStage for RequestBuildingStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
match &ctx.input.request_type {
RequestType::Chat(_) => self.chat_stage.execute(ctx).await,
RequestType::Generate(_) => self.generate_stage.execute(ctx).await,
RequestType::Responses(_request) => {
// The Responses API builds its request during the MCP loop.
// For now, create a minimal request; the responses handler will populate it.
let request_id = format!("resp-{}", Uuid::new_v4());
ctx.state.proto_request = Some(proto::GenerateRequest {
request_id,
..Default::default()
});
Ok(None)
}
}
}
fn name(&self) -> &'static str {
"RequestBuilding"
}
}
//! Response processing stage that delegates to endpoint-specific implementations
use std::sync::Arc;
use async_trait::async_trait;
use axum::response::Response;
use super::{chat::ChatResponseProcessingStage, generate::GenerateResponseProcessingStage};
use crate::routers::grpc::{
common::stages::PipelineStage,
context::{RequestContext, RequestType},
error,
regular::{processor, streaming},
};
/// Response processing stage (delegates to endpoint-specific implementations)
pub struct ResponseProcessingStage {
chat_stage: ChatResponseProcessingStage,
generate_stage: GenerateResponseProcessingStage,
}
impl ResponseProcessingStage {
pub fn new(
processor: processor::ResponseProcessor,
streaming_processor: Arc<streaming::StreamingProcessor>,
) -> Self {
Self {
chat_stage: ChatResponseProcessingStage::new(
processor.clone(),
streaming_processor.clone(),
),
generate_stage: GenerateResponseProcessingStage::new(processor, streaming_processor),
}
}
}
#[async_trait]
impl PipelineStage for ResponseProcessingStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
match &ctx.input.request_type {
RequestType::Chat(_) => self.chat_stage.execute(ctx).await,
RequestType::Generate(_) => self.generate_stage.execute(ctx).await,
RequestType::Responses(_) => Err(error::bad_request(
"Responses API processing must be handled by responses handler".to_string(),
)),
}
}
fn name(&self) -> &'static str {
"ResponseProcessing"
}
}
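Taken together, the three delegating stages compose into a single pipeline. A hedged assembly sketch (ordering and constructor arguments are assumptions; the real wiring is in RequestPipeline):

// Hypothetical assembly of the regular (non-harmony) pipeline; `processor`
// and `streaming_processor` are assumed to be in scope.
let stages: Vec<Box<dyn PipelineStage>> = vec![
    Box::new(PreparationStage::new()),
    Box::new(RequestBuildingStage::new(/* inject_pd_metadata */ false)),
    Box::new(ResponseProcessingStage::new(processor, streaming_processor)),
];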
//! Streaming response processor for gRPC routers
//!
//! This module contains shared streaming logic for both Regular and PD routers,
//! eliminating ~600 lines of duplication.
//! This module contains shared streaming logic for both Regular and PD routers.
use std::{collections::HashMap, io, sync::Arc, time::Instant};
@@ -17,9 +16,8 @@ use tokio::sync::{mpsc, mpsc::UnboundedSender};
use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt};
use tracing::{debug, error, warn};
use super::{context, utils};
use crate::{
grpc_client::proto,
grpc_client::{proto, sglang_scheduler::AbortOnDropStream},
protocols::{
chat::{
ChatCompletionRequest, ChatCompletionStreamResponse, ChatMessageDelta, ChatStreamChoice,
@@ -30,20 +28,21 @@ use crate::{
},
generate::GenerateRequest,
},
reasoning_parser::ReasoningParser,
reasoning_parser::{ParserFactory as ReasoningParserFactory, ParserResult, ReasoningParser},
routers::grpc::{context, utils},
tokenizer::{
stop::{SequenceDecoderOutput, StopSequenceDecoder},
traits::Tokenizer,
},
tool_parser::ToolParser,
tool_parser::{ParserFactory as ToolParserFactory, StreamingParseResult, ToolParser},
};
/// Shared streaming processor for both single and dual dispatch modes
#[derive(Clone)]
pub struct StreamingProcessor {
tokenizer: Arc<dyn Tokenizer>,
tool_parser_factory: crate::tool_parser::ParserFactory,
reasoning_parser_factory: crate::reasoning_parser::ParserFactory,
tool_parser_factory: ToolParserFactory,
reasoning_parser_factory: ReasoningParserFactory,
configured_tool_parser: Option<String>,
configured_reasoning_parser: Option<String>,
}
@@ -51,8 +50,8 @@ pub struct StreamingProcessor {
impl StreamingProcessor {
pub fn new(
tokenizer: Arc<dyn Tokenizer>,
tool_parser_factory: crate::tool_parser::ParserFactory,
reasoning_parser_factory: crate::reasoning_parser::ParserFactory,
tool_parser_factory: ToolParserFactory,
reasoning_parser_factory: ReasoningParserFactory,
configured_tool_parser: Option<String>,
configured_reasoning_parser: Option<String>,
) -> Self {
@@ -161,7 +160,7 @@ impl StreamingProcessor {
/// Process streaming chunks from a single stream (Regular mode)
pub async fn process_streaming_chunks(
&self,
mut grpc_stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
mut grpc_stream: AbortOnDropStream,
dispatch: context::DispatchMetadata,
stop_params: (Option<StringOrArray>, Option<Vec<u32>>, bool, bool),
original_request: Arc<ChatCompletionRequest>,
@@ -576,8 +575,8 @@ impl StreamingProcessor {
/// Process dual streaming chunks (prefill + decode) - PD mode
pub async fn process_dual_streaming_chunks(
&self,
mut prefill_stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
decode_stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
mut prefill_stream: AbortOnDropStream,
decode_stream: AbortOnDropStream,
dispatch: context::DispatchMetadata,
stop_params: (Option<StringOrArray>, Option<Vec<u32>>, bool, bool),
original_request: Arc<ChatCompletionRequest>,
@@ -696,7 +695,7 @@ impl StreamingProcessor {
/// Process streaming chunks for generate endpoint (no tool/reasoning parsing)
async fn process_generate_streaming(
tokenizer: Arc<dyn Tokenizer>,
mut stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
mut stream: AbortOnDropStream,
request_id: String,
weight_version: String,
_include_logprobs: bool,
@@ -800,8 +799,8 @@ impl StreamingProcessor {
/// Process dual streaming for generate endpoint (PD mode with logprobs support)
async fn process_generate_streaming_dual(
tokenizer: Arc<dyn Tokenizer>,
mut prefill_stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
decode_stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
mut prefill_stream: AbortOnDropStream,
decode_stream: AbortOnDropStream,
request_id: String,
weight_version: String,
return_logprob: bool,
@@ -857,7 +856,7 @@ impl StreamingProcessor {
/// Process generate streaming with optional input_logprobs
async fn process_generate_streaming_with_input_logprobs(
tokenizer: Arc<dyn Tokenizer>,
mut stream: crate::grpc_client::sglang_scheduler::AbortOnDropStream,
mut stream: AbortOnDropStream,
request_id: String,
weight_version: String,
_include_logprobs: bool,
@@ -1051,7 +1050,7 @@ impl StreamingProcessor {
};
match parse_result {
Ok(crate::reasoning_parser::ParserResult {
Ok(ParserResult {
reasoning_text,
normal_text,
}) => {
@@ -1122,7 +1121,7 @@
let mut parser = pooled_parser.lock().await;
match parser.parse_incremental(delta, tools).await {
Ok(crate::tool_parser::StreamingParseResult { normal_text, calls }) => {
Ok(StreamingParseResult { normal_text, calls }) => {
// Emit normal text if present
if !normal_text.is_empty() {
chunks.push(ChatCompletionStreamResponse {
......
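The hunks above fix the per-delta order of operations: the reasoning parser first splits each text delta into reasoning_text and normal_text, and the tool parser's parse_incremental then extracts tool calls from what remains. A heavily condensed sketch; split_reasoning and the emit_* helpers are placeholders for logic elided by the diff:

// Hypothetical condensation of the streaming delta handling shown above.
async fn handle_text_delta(
    delta: &str,
    tools: &[Tool],
    tool_parser: &ToolPooledParser,
) -> Result<(), String> {
    // 1. Split reasoning content out of the delta (see the ParserResult hunk).
    let ParserResult { reasoning_text, normal_text } = split_reasoning(delta)?;
    if !reasoning_text.is_empty() {
        emit_reasoning_chunk(&reasoning_text); // placeholder emitter
    }
    // 2. Feed the remaining text to the pooled tool parser (see the
    //    parse_incremental hunk above).
    let mut parser = tool_parser.lock().await;
    let StreamingParseResult { normal_text, calls } =
        parser.parse_incremental(&normal_text, tools).await?;
    if !normal_text.is_empty() {
        emit_content_chunk(&normal_text); // placeholder emitter
    }
    for call in calls {
        emit_tool_call_chunk(call); // placeholder emitter
    }
    Ok(())
}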
// gRPC Router Implementation
use std::sync::Arc;
use async_trait::async_trait;
@@ -12,13 +10,14 @@ use axum::{
use tracing::debug;
use super::{
common::responses::handlers::{cancel_response_impl, get_response_impl},
context::SharedComponents,
harmony::{
serve_harmony_responses, serve_harmony_responses_stream, HarmonyDetector,
HarmonyResponsesContext,
},
pipeline::RequestPipeline,
responses,
regular::responses,
};
use crate::{
app_context::AppContext,
@@ -43,9 +42,7 @@ pub struct GrpcRouter {
pipeline: RequestPipeline,
harmony_pipeline: RequestPipeline,
shared_components: Arc<SharedComponents>,
// Responses context (bundles all /v1/responses dependencies: storage, MCP, background_tasks)
responses_context: responses::ResponsesContext,
// Harmony responses context (uses harmony pipeline)
harmony_responses_context: responses::ResponsesContext,
}
@@ -156,7 +153,6 @@ impl GrpcRouter {
&self.pipeline
};
// Use selected pipeline for ALL requests (streaming and non-streaming)
pipeline
.execute_chat(
Arc::new(body.clone()),
@@ -176,7 +172,6 @@
) -> Response {
debug!("Processing generate request for model: {:?}", model_id);
// Use pipeline for ALL requests (streaming and non-streaming)
self.pipeline
.execute_generate(
Arc::new(body.clone()),
@@ -187,35 +182,51 @@
.await
}
/// Main route_responses implementation (pipeline-based for Harmony)
/// Main route_responses implementation
///
/// Routes to either the Harmony or the regular responses implementation, based on model detection.
async fn route_responses_impl(
&self,
_headers: Option<&HeaderMap>,
headers: Option<&HeaderMap>,
body: &ResponsesRequest,
model_id: Option<&str>,
) -> Response {
debug!(
"Processing Harmony responses request for model: {:?}, streaming: {:?}",
model_id, body.stream
);
// Choose implementation based on Harmony model detection
let is_harmony = HarmonyDetector::is_harmony_model(&body.model);
// Create HarmonyResponsesContext from existing responses context
let harmony_ctx = HarmonyResponsesContext::new(
Arc::new(self.harmony_pipeline.clone()),
self.shared_components.clone(),
self.harmony_responses_context.mcp_manager.clone(),
self.harmony_responses_context.response_storage.clone(),
debug!(
"Processing responses request for model: {:?}, using_harmony={}",
model_id, is_harmony
);
// Check if streaming is requested
if body.stream.unwrap_or(false) {
serve_harmony_responses_stream(&harmony_ctx, body.clone()).await
} else {
// Use non-streaming version for standard JSON responses
match serve_harmony_responses(&harmony_ctx, body.clone()).await {
Ok(response) => axum::Json(response).into_response(),
Err(error_response) => error_response,
if is_harmony {
debug!(
"Processing Harmony responses request for model: {:?}, streaming: {:?}",
model_id, body.stream
);
let harmony_ctx = HarmonyResponsesContext::new(
Arc::new(self.harmony_pipeline.clone()),
self.shared_components.clone(),
self.harmony_responses_context.mcp_manager.clone(),
self.harmony_responses_context.response_storage.clone(),
);
if body.stream.unwrap_or(false) {
serve_harmony_responses_stream(&harmony_ctx, body.clone()).await
} else {
match serve_harmony_responses(&harmony_ctx, body.clone()).await {
Ok(response) => axum::Json(response).into_response(),
Err(error_response) => error_response,
}
}
} else {
responses::route_responses(
&self.responses_context,
Arc::new(body.clone()),
headers.cloned(),
model_id.map(|s| s.to_string()),
)
.await
}
}
}
@@ -236,7 +247,6 @@ impl RouterTrait for GrpcRouter {
}
async fn health_generate(&self, _req: Request<Body>) -> Response {
// TODO: Implement actual generation test for gRPC
(
StatusCode::NOT_IMPLEMENTED,
"Health generate not yet implemented for gRPC",
@@ -289,27 +299,7 @@ impl RouterTrait for GrpcRouter {
body: &ResponsesRequest,
model_id: Option<&str>,
) -> Response {
// Choose implementation based on Harmony model detection
let is_harmony = HarmonyDetector::is_harmony_model(&body.model);
debug!(
"Processing responses request for model: {:?}, using_harmony={}",
model_id, is_harmony
);
if is_harmony {
// Use pipeline-based implementation for Harmony models
self.route_responses_impl(headers, body, model_id).await
} else {
// Use legacy responses module for non-Harmony models
responses::route_responses(
&self.responses_context,
Arc::new(body.clone()),
headers.cloned(),
model_id.map(|s| s.to_string()),
)
.await
}
self.route_responses_impl(headers, body, model_id).await
}
async fn get_response(
@@ -318,11 +308,11 @@ impl RouterTrait for GrpcRouter {
response_id: &str,
_params: &ResponsesGetParams,
) -> Response {
responses::get_response_impl(&self.responses_context, response_id).await
get_response_impl(&self.responses_context, response_id).await
}
async fn cancel_response(&self, _headers: Option<&HeaderMap>, response_id: &str) -> Response {
responses::cancel_response_impl(&self.responses_context, response_id).await
cancel_response_impl(&self.responses_context, response_id).await
}
async fn route_embeddings(
......
//! Request building stage: Build proto GenerateRequest
use std::sync::Arc;
use async_trait::async_trait;
use axum::response::Response;
use proto::DisaggregatedParams;
use rand::Rng;
use tracing::debug;
use uuid::Uuid;
use super::PipelineStage;
use crate::{
core::Worker,
grpc_client::proto,
routers::grpc::{
context::{ClientSelection, RequestContext, RequestType, WorkerSelection},
error,
},
};
/// Request building stage: Build proto GenerateRequest
pub struct RequestBuildingStage {
inject_pd_metadata: bool,
}
impl RequestBuildingStage {
pub fn new(inject_pd_metadata: bool) -> Self {
Self { inject_pd_metadata }
}
}
#[async_trait]
impl PipelineStage for RequestBuildingStage {
async fn execute(&self, ctx: &mut RequestContext) -> Result<Option<Response>, Response> {
let prep = ctx
.state
.preparation
.as_ref()
.ok_or_else(|| error::internal_error("Preparation not completed"))?;
let clients = ctx
.state
.clients
.as_ref()
.ok_or_else(|| error::internal_error("Client acquisition not completed"))?;
// Get client for building request (use prefill client if PD mode)
let builder_client = match clients {
ClientSelection::Single { client } => client,
ClientSelection::Dual { prefill, .. } => prefill,
};
let mut proto_request = match &ctx.input.request_type {
RequestType::Chat(request) => {
let request_id = format!("chatcmpl-{}", Uuid::new_v4());
let body_ref = prep.filtered_request.as_ref().unwrap_or(request);
builder_client
.build_generate_request(
request_id,
body_ref,
prep.processed_messages.as_ref().unwrap().text.clone(),
prep.token_ids.clone(),
prep.processed_messages
.as_ref()
.unwrap()
.multimodal_inputs
.clone(),
prep.tool_constraints.clone(),
)
.map_err(|e| error::bad_request(format!("Invalid request parameters: {}", e)))?
}
RequestType::Generate(request) => {
let request_id = request
.rid
.clone()
.unwrap_or_else(|| format!("gen-{}", Uuid::new_v4()));
builder_client
.build_plain_generate_request(
request_id,
request,
prep.original_text.clone(),
prep.token_ids.clone(),
)
.map_err(error::bad_request)?
}
RequestType::Responses(_request) => {
// The Responses API builds its request during the MCP loop.
// For now, create a minimal request; the responses handler will populate it.
let request_id = format!("resp-{}", Uuid::new_v4());
proto::GenerateRequest {
request_id,
..Default::default()
}
}
};
// Inject PD metadata if needed
if self.inject_pd_metadata {
if let WorkerSelection::Dual { prefill, .. } = ctx.state.workers.as_ref().unwrap() {
self.inject_bootstrap_metadata(&mut proto_request, prefill);
}
}
ctx.state.proto_request = Some(proto_request);
Ok(None)
}
fn name(&self) -> &'static str {
"RequestBuilding"
}
}
impl RequestBuildingStage {
fn inject_bootstrap_metadata(
&self,
request: &mut proto::GenerateRequest,
prefill_worker: &Arc<dyn Worker>,
) {
let hostname = prefill_worker.bootstrap_host();
let bootstrap_port = prefill_worker.bootstrap_port().unwrap_or(8998);
// Generate room ID for bootstrap
let room_id = rand::rng().random_range(0..i32::MAX);
// Create DisaggregatedParams
let disagg_params = DisaggregatedParams {
bootstrap_host: hostname.to_string(),
bootstrap_port: bootstrap_port as i32,
bootstrap_room: room_id,
};
// Inject metadata directly into request
request.disaggregated_params = Some(disagg_params);
debug!(
"Injected bootstrap metadata: host={}, port={}, room={}",
hostname, bootstrap_port, room_id
);
}
}
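For PD mode, the injected metadata is what lets the decode worker rendezvous with the prefill worker. An illustrative value (field names from proto::DisaggregatedParams above; the host, port, and room values are made-up examples):

// Example of the metadata injected into a PD-mode GenerateRequest.
let disagg_params = proto::DisaggregatedParams {
    bootstrap_host: "prefill-0.internal".to_string(), // hypothetical prefill host
    bootstrap_port: 8998,    // default when the prefill worker reports no port
    bootstrap_room: 123_456, // random per-request room id pairing the two workers
};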
@@ -21,12 +21,19 @@ use crate::{
},
generate::GenerateFinishReason,
},
reasoning_parser::{
ParserFactory as ReasoningParserFactory, PooledParser as ReasoningPooledParser,
ReasoningParser,
},
tokenizer::{
cache::CachedTokenizer,
chat_template::{ChatTemplateContentFormat, ChatTemplateParams},
traits::Tokenizer,
HuggingFaceTokenizer,
},
tool_parser::{
ParserFactory as ToolParserFactory, PooledParser as ToolPooledParser, ToolParser,
},
};
/// Get gRPC client from worker, returning appropriate error response on failure
@@ -44,20 +51,17 @@ pub async fn get_grpc_client_from_worker(
/// Process tool call arguments in messages
/// Per Transformers docs, tool call arguments in assistant messages should be dicts
pub fn process_tool_call_arguments(messages: &mut [Value]) -> Result<(), String> {
fn process_tool_call_arguments(messages: &mut [Value]) -> Result<(), String> {
for msg in messages {
// Early return if not assistant message
let role = msg.get("role").and_then(|v| v.as_str());
if role != Some("assistant") {
continue;
}
// Early return if no tool_calls
let Some(tool_calls) = msg.get_mut("tool_calls").and_then(|tc| tc.as_array_mut()) else {
continue;
};
// Process each tool call's arguments
for call in tool_calls {
let Some(function) = call.get_mut("function") else {
continue;
@@ -107,10 +111,7 @@ pub fn process_content_format(
}
/// Transform a single content field based on content format
pub fn transform_content_field(
content_value: &mut Value,
content_format: ChatTemplateContentFormat,
) {
fn transform_content_field(content_value: &mut Value, content_format: ChatTemplateContentFormat) {
let Some(content_array) = content_value.as_array() else {
return; // Not multimodal, keep as-is
};
@@ -209,7 +210,7 @@ pub fn generate_tool_constraints(
/// Build JSON schema for required tool calls (array with minItems: 1)
/// Includes $defs consolidation from all tools (matching Python's behavior)
pub fn build_required_array_schema(tools: &[Tool]) -> Result<String, String> {
fn build_required_array_schema(tools: &[Tool]) -> Result<String, String> {
// Build anyOf schemas for each tool
let mut any_of_schemas = Vec::new();
for tool in tools {
@@ -651,7 +652,7 @@ pub fn generate_tool_call_id(
/// Check if a reasoning parser is available for the given model
pub fn check_reasoning_parser_availability(
reasoning_parser_factory: &crate::reasoning_parser::ParserFactory,
reasoning_parser_factory: &ReasoningParserFactory,
configured_parser: Option<&String>,
model: &str,
) -> bool {
@@ -666,7 +667,7 @@ pub fn check_reasoning_parser_availability(
/// Check if a tool parser is available for the given model
pub fn check_tool_parser_availability(
tool_parser_factory: &crate::tool_parser::ParserFactory,
tool_parser_factory: &ToolParserFactory,
configured_parser: Option<&String>,
model: &str,
) -> bool {
@@ -683,10 +684,10 @@ pub fn check_tool_parser_availability(
/// Otherwise, auto-detect based on the model name.
/// Get a pooled reasoning parser (for non-streaming where state doesn't matter)
pub fn get_reasoning_parser(
reasoning_parser_factory: &crate::reasoning_parser::ParserFactory,
reasoning_parser_factory: &ReasoningParserFactory,
configured_parser: Option<&String>,
model: &str,
) -> crate::reasoning_parser::PooledParser {
) -> ReasoningPooledParser {
if let Some(parser_name) = configured_parser {
// Use configured parser if specified
reasoning_parser_factory
@@ -707,10 +708,10 @@ pub fn get_reasoning_parser(
/// Create a fresh reasoning parser instance (for streaming where state isolation is needed)
pub fn create_reasoning_parser(
reasoning_parser_factory: &crate::reasoning_parser::ParserFactory,
reasoning_parser_factory: &ReasoningParserFactory,
configured_parser: Option<&String>,
model: &str,
) -> Option<Box<dyn crate::reasoning_parser::ReasoningParser>> {
) -> Option<Box<dyn ReasoningParser>> {
if let Some(parser_name) = configured_parser {
// Use configured parser if specified
reasoning_parser_factory
@@ -735,10 +736,10 @@ pub fn create_reasoning_parser(
/// Otherwise, auto-detect based on the model name.
/// Get a pooled tool parser (for non-streaming where state doesn't matter)
pub fn get_tool_parser(
tool_parser_factory: &crate::tool_parser::ParserFactory,
tool_parser_factory: &ToolParserFactory,
configured_parser: Option<&String>,
model: &str,
) -> crate::tool_parser::PooledParser {
) -> ToolPooledParser {
if let Some(parser_name) = configured_parser {
// Use configured parser if specified
tool_parser_factory
@@ -759,10 +760,10 @@ pub fn get_tool_parser(
/// Create a fresh tool parser instance (for streaming where state isolation is needed)
pub fn create_tool_parser(
tool_parser_factory: &crate::tool_parser::ParserFactory,
tool_parser_factory: &ToolParserFactory,
configured_parser: Option<&String>,
model: &str,
) -> Option<Box<dyn crate::tool_parser::ToolParser>> {
) -> Option<Box<dyn ToolParser>> {
if let Some(parser_name) = configured_parser {
// Use configured parser if specified
tool_parser_factory
......
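Note the deliberate split above: get_*_parser returns a pooled, shared instance (fine for non-streaming parses, where no state persists across calls), while create_*_parser returns a fresh boxed instance whose incremental state is isolated to one stream. A usage sketch, assuming a factory, a configured parser name, and a model string are in scope:

// Non-streaming: pooled parser, state does not carry across calls.
let pooled = get_tool_parser(&tool_parser_factory, configured_parser.as_ref(), model);

// Streaming: fresh parser per stream so incremental state never leaks
// between concurrent requests.
let fresh: Option<Box<dyn ToolParser>> =
    create_tool_parser(&tool_parser_factory, configured_parser.as_ref(), model);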
@@ -18,20 +18,15 @@ use minijinja::{
use serde_json;
/// Chat template content format
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ChatTemplateContentFormat {
/// Content is a simple string
#[default]
String,
/// Content is a list of structured parts (OpenAI format)
OpenAI,
}
impl Default for ChatTemplateContentFormat {
fn default() -> Self {
Self::String
}
}
impl std::fmt::Display for ChatTemplateContentFormat {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
......
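Swapping the manual impl Default for #[derive(Default)] with the #[default] variant attribute is behavior-preserving; derive support for enums stabilized in Rust 1.62.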
@@ -171,7 +171,7 @@ fn test_chatml_template() {
let processor = ChatTemplateProcessor::new(template.to_string());
let messages = vec![
let messages = [
ChatMessage::User {
content: UserMessageContent::Text("Hello".to_string()),
name: None,
......