Unverified commit d736e0b6 authored by Simo Lin, committed by GitHub

[router] add grpc router pd mode for chat and generate (#11140)

parent ffd03a9b
@@ -67,8 +67,8 @@ dependencies = [
"uvicorn",
"uvloop",
"xgrammar==0.1.24",
- "grpcio==1.74.0", # keep it align with compile_proto.py
- "grpcio-tools==1.74.0" # keep it align with compile_proto.py
+ "grpcio==1.75.1", # keep it align with compile_proto.py
+ "grpcio-tools==1.75.1" # keep it align with compile_proto.py
]
[project.optional-dependencies]
......
@@ -19,7 +19,6 @@ import grpc
import zmq
import zmq.asyncio
from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.managers.io_struct import (
AbortReq,
BatchEmbeddingOutput,
@@ -111,6 +110,7 @@ class GrpcRequestManager:
self,
server_args: ServerArgs,
port_args: PortArgs,
bootstrap_server=None,
):
"""Initialize the gRPC request manager."""
self.server_args = server_args
@@ -147,8 +147,8 @@ class GrpcRequestManager:
self.crash_dump_request_list = []
self.crash_dump_performed = False
- # Bootstrap server for disaggregation mode
- self.bootstrap_server = start_disagg_service(server_args)
+ # Bootstrap server (passed from serve_grpc, not started here)
+ self.bootstrap_server = bootstrap_server
logger.info(
f"GrpcRequestManager initialized with ZMQ IPC: "
@@ -157,7 +157,7 @@
)
if self.bootstrap_server:
logger.info(
- f"Bootstrap server started for disaggregation mode: "
+ f"Bootstrap server initialized for disaggregation mode: "
f"{server_args.disaggregation_mode}"
)
......
@@ -16,11 +16,13 @@ from typing import AsyncIterator, Dict, Optional, Tuple
import grpc
from grpc_reflection.v1alpha import reflection
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode
from sglang.srt.entrypoints.grpc_request_manager import GrpcRequestManager
from sglang.srt.grpc import sglang_scheduler_pb2, sglang_scheduler_pb2_grpc
from sglang.srt.managers.data_parallel_controller import (
run_data_parallel_controller_process,
)
from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.managers.io_struct import (
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
@@ -331,6 +333,10 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
token_ids_logprob=None,
)
if self.server_args.disaggregation_mode != DisaggregationMode.NULL:
health_request.bootstrap_host = FAKE_BOOTSTRAP_HOST
health_request.bootstrap_room = 0
logger.info(f"Sending health check request to request manager...")
# Submit and wait for response
@@ -406,6 +412,15 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
# Convert sampling params
sampling_params = self._convert_sampling_params(grpc_req.sampling_params)
# Extract disaggregated params if present
bootstrap_host = None
bootstrap_port = None
bootstrap_room = None
if grpc_req.HasField("disaggregated_params"):
bootstrap_host = grpc_req.disaggregated_params.bootstrap_host or None
bootstrap_port = grpc_req.disaggregated_params.bootstrap_port or None
bootstrap_room = grpc_req.disaggregated_params.bootstrap_room or None
# Create request
return TokenizedGenerateReqInput(
rid=grpc_req.request_id,
@@ -425,6 +440,9 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
token_ids_logprob=(
list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None
),
bootstrap_host=bootstrap_host,
bootstrap_port=bootstrap_port,
bootstrap_room=bootstrap_room,
)
def _convert_embed_request(
@@ -659,6 +677,16 @@ async def serve_grpc(
):
"""Start the standalone gRPC server with integrated scheduler."""
# Start bootstrap server BEFORE launching scheduler processes (only in PREFILL mode)
# This ensures the bootstrap server is ready when prefill schedulers try to register
bootstrap_server = None
if server_args.disaggregation_mode == "prefill":
bootstrap_server = start_disagg_service(server_args)
if bootstrap_server:
logger.info(
f"Bootstrap server started for disaggregation mode on {server_args.host}:{server_args.disaggregation_bootstrap_port}"
)
# Launch only the scheduler process(es) (no tokenizer/detokenizer needed for gRPC)
logger.info("Launching scheduler process(es)...")
scheduler_info, port_args, scheduler_procs = _launch_scheduler_process_only(
@@ -682,9 +710,11 @@
}
# Create request manager with the correct port args
# Note: bootstrap_server (started above in prefill mode) is passed through to the request manager
request_manager = GrpcRequestManager(
server_args=server_args,
port_args=port_args,
bootstrap_server=bootstrap_server,
)
# Create gRPC server
@@ -764,79 +794,9 @@ def main():
mp.set_start_method("spawn", force=True)
parser = argparse.ArgumentParser(description="SGLang Standalone gRPC Server")
ServerArgs.add_cli_args(parser)
# Server arguments
parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
parser.add_argument("--port", type=int, default=30000, help="gRPC server port")
# Model arguments
parser.add_argument("--model-path", type=str, required=True, help="Model path")
parser.add_argument("--tokenizer-path", type=str, help="Tokenizer path")
parser.add_argument("--context-length", type=int, help="Context length")
parser.add_argument("--tp-size", type=int, default=1, help="Tensor parallel size")
parser.add_argument("--dp-size", type=int, default=1, help="Data parallel size")
# Runtime arguments
parser.add_argument(
"--max-running-requests", type=int, default=2048, help="Max concurrent requests"
)
parser.add_argument(
"--max-total-tokens", type=int, default=1000000, help="Max total tokens"
)
parser.add_argument(
"--max-prefill-tokens", type=int, default=16384, help="Max prefill tokens"
)
parser.add_argument(
"--attention-backend", type=str, default="flashinfer", help="Attention backend"
)
parser.add_argument("--lora-paths", type=str, help="LoRA adapter paths")
# Logging
parser.add_argument("--log-level", type=str, default="INFO", help="Logging level")
# Disaggregation mode arguments
parser.add_argument(
"--disaggregation-mode",
type=str,
default="null",
choices=["null", "prefill", "decode"],
help='Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated',
)
parser.add_argument(
"--disaggregation-transfer-backend",
type=str,
default="mooncake",
choices=["mooncake", "nixl", "ascend", "fake"],
help="The backend for disaggregation transfer. Default is mooncake.",
)
parser.add_argument(
"--disaggregation-bootstrap-port",
type=int,
default=8998,
help="Bootstrap server port on the prefill server. Default is 8998.",
)
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
# Convert to ServerArgs with gRPC host and port
server_args = ServerArgs(
model_path=args.model_path,
tokenizer_path=args.tokenizer_path or args.model_path,
context_length=args.context_length,
tp_size=args.tp_size,
dp_size=args.dp_size,
max_running_requests=args.max_running_requests,
max_total_tokens=args.max_total_tokens,
max_prefill_tokens=args.max_prefill_tokens,
attention_backend=args.attention_backend,
lora_paths=args.lora_paths.split(",") if args.lora_paths else None,
log_level=args.log_level,
disaggregation_mode=args.disaggregation_mode,
disaggregation_transfer_backend=args.disaggregation_transfer_backend,
disaggregation_bootstrap_port=args.disaggregation_bootstrap_port,
host=args.host,
port=args.port,
)
# Run server
asyncio.run(
......
@@ -31,6 +31,18 @@ pub trait Worker: Send + Sync + fmt::Debug {
/// Get the worker's connection mode (HTTP or gRPC)
fn connection_mode(&self) -> ConnectionMode;
/// Get the bootstrap hostname for PD mode
/// Returns cached hostname parsed from URL at construction time
fn bootstrap_host(&self) -> &str {
&self.metadata().bootstrap_host
}
/// Get the bootstrap port for PD mode
/// Returns cached port from WorkerType::Prefill
fn bootstrap_port(&self) -> Option<u16> {
self.metadata().bootstrap_port
}
/// Check if the worker is currently healthy
fn is_healthy(&self) -> bool;
@@ -147,21 +159,6 @@ pub trait Worker: Send + Sync + fmt::Debug {
true
}
// TODO: - Enhanced Worker Discovery
// The Worker trait should handle async discovery of metadata from the worker itself
// rather than having service discovery or other components query /get_server_info.
// This keeps service discovery decoupled from worker-specific APIs.
//
// Proposed additions:
// - async fn discover_metadata(&mut self) -> Result<(), Error>
// Query /get_server_info and populate metadata labels with model_id, priority, cost, etc.
// - async fn validate_configuration(&self) -> Result<(), Error>
// Ensure worker has required configuration for its mode (e.g., tokenizer for gRPC)
// - Make worker creation async to allow metadata discovery during initialization
//
// This way service discovery just calls router.add_worker() and the worker
// handles its own metadata discovery internally.
/// Get the model ID this worker serves
fn model_id(&self) -> &str {
self.metadata()
@@ -325,6 +322,10 @@ pub struct WorkerMetadata {
pub health_config: HealthConfig,
/// API key
pub api_key: Option<String>,
/// Cached bootstrap hostname (parsed from URL at construction time)
pub bootstrap_host: String,
/// Cached bootstrap port (from WorkerType::Prefill)
pub bootstrap_port: Option<u16>,
}
/// Basic worker implementation
......
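The two trait methods added above read values that WorkerMetadata now caches at construction time, so the PD router can resolve a prefill worker's bootstrap endpoint without re-parsing its URL per request. A minimal usage sketch, assuming the Worker trait as extended in this diff; the helper name bootstrap_target is illustrative, and the 8998 fallback mirrors the default used by inject_bootstrap_metadata later in this commit:

// Illustrative helper, not part of the commit.
fn bootstrap_target(worker: &dyn Worker) -> (String, u16) {
    let host = worker.bootstrap_host().to_string(); // cached, parsed from the URL at construction
    let port = worker.bootstrap_port().unwrap_or(8998); // None for non-prefill workers
    (host, port)
}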
@@ -96,12 +96,29 @@ impl BasicWorkerBuilder {
/// Build the BasicWorker instance
pub fn build(self) -> BasicWorker {
use std::borrow::Cow;
use std::sync::{
atomic::{AtomicBool, AtomicUsize},
Arc,
};
use tokio::sync::{Mutex, RwLock};
let url_to_parse = if self.url.contains("://") {
Cow::from(&self.url)
} else {
Cow::from(format!("http://{}", self.url))
};
let bootstrap_host = match url::Url::parse(&url_to_parse) {
Ok(parsed) => parsed.host_str().unwrap_or("localhost").to_string(),
Err(_) => "localhost".to_string(),
};
let bootstrap_port = match self.worker_type {
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
_ => None,
};
let metadata = WorkerMetadata {
url: self.url.clone(),
api_key: self.api_key,
@@ -109,6 +126,8 @@ impl BasicWorkerBuilder {
connection_mode: self.connection_mode,
labels: self.labels,
health_config: self.health_config,
bootstrap_host,
bootstrap_port,
};
let grpc_client = Arc::new(RwLock::new(
......
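The host extraction above tolerates both full URLs and bare host:port strings: anything without a scheme gets an http:// prefix before url::Url::parse, and unparsable input falls back to "localhost". A self-contained sketch of the same logic (the free function bootstrap_host_of is illustrative and assumes the url crate already used by the router):

// Illustrative standalone version of the builder's host extraction (not part of the commit).
use std::borrow::Cow;

fn bootstrap_host_of(url: &str) -> String {
    // Bare "host:port" strings are given a scheme so url::Url::parse accepts them.
    let url_to_parse = if url.contains("://") {
        Cow::Borrowed(url)
    } else {
        Cow::Owned(format!("http://{}", url))
    };
    match url::Url::parse(&url_to_parse) {
        Ok(parsed) => parsed.host_str().unwrap_or("localhost").to_string(),
        Err(_) => "localhost".to_string(),
    }
}

fn main() {
    assert_eq!(bootstrap_host_of("10.0.0.5:30001"), "10.0.0.5");
    assert_eq!(bootstrap_host_of("http://prefill-0.svc:20000"), "prefill-0.svc");
    assert_eq!(bootstrap_host_of(""), "localhost"); // unparsable input falls back
}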
@@ -342,6 +342,12 @@ impl SglangSchedulerClient {
.map_err(|_| "min_tokens must fit into a 32-bit signed integer".to_string())?;
}
// Handle n with conversion
if let Some(n) = p.n {
sampling.n = i32::try_from(n)
.map_err(|_| "n must fit into a 32-bit signed integer".to_string())?;
}
// Handle constraints (exactly one allowed)
sampling.constraint = Self::build_single_constraint_from_plain(p)?;
......
@@ -2,6 +2,11 @@ use serde::{Deserialize, Serialize};
use serde_json::{to_value, Map, Number, Value};
use std::collections::HashMap;
// Default model value when not specified
fn default_model() -> String {
"unknown".to_string()
}
// # Protocol Specifications
//
// This module contains all protocol definitions for OpenAI and SGLang APIs.
@@ -169,6 +174,7 @@ pub struct ChatCompletionRequest {
pub messages: Vec<ChatMessage>,
/// ID of the model to use
#[serde(default = "default_model")]
pub model: String,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
......
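With the serde default above, a chat completion request that omits "model" deserializes with model set to "unknown" rather than being rejected. A minimal sketch of the same pattern on a stand-in struct (not the real ChatCompletionRequest):

use serde::Deserialize;

fn default_model() -> String {
    "unknown".to_string()
}

// Stand-in struct; the real ChatCompletionRequest has many more fields.
#[derive(Deserialize, Debug)]
struct MinimalChatRequest {
    #[serde(default = "default_model")]
    model: String,
    messages: Vec<String>,
}

fn main() {
    let req: MinimalChatRequest = serde_json::from_str(r#"{"messages": ["hi"]}"#).unwrap();
    assert_eq!(req.model, "unknown"); // "model" omitted, default applied
}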
//! gRPC router implementations
use crate::grpc_client::proto;
use crate::protocols::spec::StringOrArray;
pub mod pd_router;
pub mod router;
pub mod utils;
/// Processed chat messages ready for gRPC generation
#[derive(Debug)]
pub struct ProcessedMessages {
pub text: String,
pub multimodal_inputs: Option<proto::MultimodalInputs>,
pub stop_sequences: Option<StringOrArray>,
}
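ProcessedMessages is the hand-off between chat-template rendering and gRPC request construction in the routers below. A small construction sketch with illustrative values, assuming the types declared above are in scope:

// Illustrative only (values are made up); the rendered text depends on the model's chat template.
fn example_processed_messages() -> ProcessedMessages {
    ProcessedMessages {
        text: "<|user|>Hello<|assistant|>".to_string(), // rendered chat-template output
        multimodal_inputs: None,                         // no image/audio inputs in this sketch
        stop_sequences: None,                            // per-request stops handled elsewhere
    }
}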
// PD (Prefill-Decode) gRPC Router Implementation
use crate::config::types::RetryConfig;
- use crate::core::{WorkerRegistry, WorkerType};
- use crate::metrics::RouterMetrics;
+ use crate::core::{ConnectionMode, Worker, WorkerRegistry, WorkerType};
+ use crate::grpc_client::proto;
use crate::grpc_client::SglangSchedulerClient;
use crate::policies::PolicyRegistry;
- use crate::reasoning_parser::ReasoningParserFactory;
- use crate::routers::RouterTrait;
+ use crate::protocols::spec::{
+ ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionStreamResponse, ChatLogProbs, ChatLogProbsContent, ChatMessageDelta,
ChatStreamChoice, CompletionRequest, EmbeddingRequest, FunctionCallDelta, FunctionCallResponse,
GenerateRequest, InputIds, RerankRequest, ResponsesGetParams, ResponsesRequest, StringOrArray,
Tool, ToolCall, ToolCallDelta, ToolChoice, ToolChoiceValue, TopLogProb, Usage,
};
use crate::reasoning_parser::{ParserResult, ReasoningParser, ReasoningParserFactory};
use crate::routers::http::pd_types::generate_room_id;
use crate::routers::{grpc, RouterTrait};
use crate::server::AppContext;
use crate::tokenizer::traits::Tokenizer;
- use crate::tool_parser::ToolParserFactory;
+ use crate::tokenizer::{SequenceDecoderOutput, StopSequenceDecoder};
use crate::tool_parser::{StreamingParseResult, ToolParser, ToolParserFactory};
use async_trait::async_trait;
use axum::{
body::Body,
extract::Request,
- http::{HeaderMap, StatusCode},
+ http::{header, HeaderMap, HeaderValue, StatusCode},
response::{IntoResponse, Response},
Json,
};
use grpc::utils;
use proto::generate_response::Response::{Chunk, Complete, Error};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
- use tracing::info;
+ use std::time::Instant;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::mpsc::unbounded_channel;
use tokio::sync::mpsc::UnboundedSender;
use tokio_stream::Stream;
use tokio_stream::StreamExt;
use tracing::{debug, error, warn};
use uuid::Uuid;
/// gRPC PD (Prefill-Decode) router implementation for SGLang
#[derive(Clone)]
#[allow(dead_code)] // Fields will be used once implementation is complete
pub struct GrpcPDRouter {
worker_registry: Arc<WorkerRegistry>,
@@ -26,7 +50,6 @@ pub struct GrpcPDRouter {
tokenizer: Arc<dyn Tokenizer>,
reasoning_parser_factory: ReasoningParserFactory,
tool_parser_factory: ToolParserFactory,
dp_aware: bool,
api_key: Option<String>,
retry_config: RetryConfig,
@@ -34,7 +57,7 @@
impl GrpcPDRouter {
/// Create a new gRPC PD router
- pub async fn new(ctx: &Arc<crate::server::AppContext>) -> Result<Self, String> {
+ pub async fn new(ctx: &Arc<AppContext>) -> Result<Self, String> {
// Get registries from context
let worker_registry = ctx.worker_registry.clone();
let policy_registry = ctx.policy_registry.clone();
@@ -56,33 +79,6 @@ impl GrpcPDRouter {
.ok_or_else(|| "gRPC PD router requires tool parser factory".to_string())?
.clone();
// Get prefill and decode workers from registry - they should have been created by WorkerManager
let prefill_workers = worker_registry.get_workers_filtered(
None, // any model
Some(WorkerType::Prefill {
bootstrap_port: None,
}),
Some(crate::core::ConnectionMode::Grpc { port: None }),
false, // include unhealthy workers during initialization
);
let decode_workers = worker_registry.get_workers_filtered(
None, // any model
Some(WorkerType::Decode),
Some(crate::core::ConnectionMode::Grpc { port: None }),
false, // include unhealthy workers during initialization
);
// Update metrics
RouterMetrics::set_active_workers(prefill_workers.len() + decode_workers.len());
info!(
"gRPC PD router found {} prefill and {} decode workers in registry",
prefill_workers.len(),
decode_workers.len()
);
// No need for local health checkers - WorkerRegistry handles health checking
Ok(GrpcPDRouter {
worker_registry,
policy_registry,
@@ -94,6 +90,1895 @@ impl GrpcPDRouter {
retry_config: ctx.router_config.effective_retry_config(),
})
}
/// Select a prefill-decode worker pair using load balancing policies
async fn select_pd_pair(
&self,
request_text: Option<&str>,
model_id: Option<&str>,
) -> Result<(Arc<dyn Worker>, Arc<dyn Worker>), String> {
let effective_model_id = if !self.dp_aware { None } else { model_id };
debug!(
"Selecting PD pair: dp_aware={}, model_id={:?}, effective_model_id={:?}",
self.dp_aware, model_id, effective_model_id
);
// Get prefill workers
let prefill_workers = if let Some(model) = effective_model_id {
self.worker_registry
.get_by_model_fast(model)
.into_iter()
.filter(|w| matches!(w.worker_type(), WorkerType::Prefill { .. }))
.collect()
} else {
self.worker_registry.get_workers_filtered(
None,
Some(WorkerType::Prefill {
bootstrap_port: None,
}),
Some(ConnectionMode::Grpc { port: None }),
true, // only healthy workers
)
};
// Get decode workers
let decode_workers = if let Some(model) = effective_model_id {
self.worker_registry
.get_by_model_fast(model)
.into_iter()
.filter(|w| matches!(w.worker_type(), WorkerType::Decode))
.collect()
} else {
self.worker_registry.get_workers_filtered(
None,
Some(WorkerType::Decode),
Some(ConnectionMode::Grpc { port: None }),
true, // only healthy workers
)
};
if prefill_workers.is_empty() {
return Err("No healthy prefill workers available".to_string());
}
if decode_workers.is_empty() {
return Err("No healthy decode workers available".to_string());
}
debug!(
"Found {} prefill workers and {} decode workers",
prefill_workers.len(),
decode_workers.len()
);
let prefill_policy = self.policy_registry.get_prefill_policy();
let decode_policy = self.policy_registry.get_decode_policy();
let prefill_idx = prefill_policy
.select_worker(&prefill_workers, request_text)
.ok_or_else(|| "Failed to select prefill worker".to_string())?;
let decode_idx = decode_policy
.select_worker(&decode_workers, request_text)
.ok_or_else(|| "Failed to select decode worker".to_string())?;
let prefill = prefill_workers[prefill_idx].clone();
let decode = decode_workers[decode_idx].clone();
debug!(
"Selected PD pair: prefill={}, decode={}",
prefill.url(),
decode.url()
);
Ok((prefill, decode))
}
/// Main route_generate implementation with PD dual dispatch
async fn route_generate_impl(
&self,
_headers: Option<&HeaderMap>,
body: &GenerateRequest,
model_id: Option<&str>,
) -> Response {
debug!(
"Processing generate request for model: {:?} (PD mode)",
model_id
);
// Step 1: Resolve input (text or input_ids)
let (original_text, token_ids) = match self.resolve_generate_input(body) {
Ok(res) => res,
Err(msg) => {
error!("Invalid generate request: {}", msg);
return (StatusCode::BAD_REQUEST, msg).into_response();
}
};
debug!("Resolved input with {} tokens", token_ids.len());
// Step 2: Select prefill-decode worker pair
let (prefill_worker, decode_worker) = match self
.select_pd_pair(original_text.as_deref(), model_id)
.await
{
Ok(pair) => pair,
Err(e) => {
warn!("Failed to select PD worker pair: {}", e);
return (StatusCode::SERVICE_UNAVAILABLE, e).into_response();
}
};
debug!(
"Selected PD pair: prefill={}, decode={}",
prefill_worker.url(),
decode_worker.url()
);
// Step 3: Get gRPC clients for both workers
let prefill_client = match utils::get_grpc_client_from_worker(&prefill_worker).await {
Ok(client) => client,
Err(response) => return response,
};
let decode_client = match utils::get_grpc_client_from_worker(&decode_worker).await {
Ok(client) => client,
Err(response) => return response,
};
// Step 4: Build the gRPC request
let request_id = body
.rid
.clone()
.unwrap_or_else(|| format!("gen-{}", Uuid::new_v4()));
let mut request = match prefill_client.build_plain_generate_request(
request_id.clone(),
body,
original_text.clone(),
token_ids,
) {
Ok(req) => req,
Err(e) => {
error!("Failed to build generate request: {}", e);
return (StatusCode::BAD_REQUEST, e).into_response();
}
};
// Step 5: Inject bootstrap metadata
if let Err(e) = Self::inject_bootstrap_metadata(&mut request, &*prefill_worker) {
error!("Failed to inject bootstrap metadata: {}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, e).into_response();
}
// Step 6: Get weight version for response metadata
let weight_version = decode_worker
.metadata()
.labels
.get("weight_version")
.cloned()
.unwrap_or_else(|| "default".to_string());
// Step 7: Handle streaming vs non-streaming
if body.stream {
self.handle_streaming_generate(
prefill_client,
decode_client,
request,
body,
request_id,
weight_version,
)
.await
} else {
self.handle_non_streaming_generate(
prefill_client,
decode_client,
request,
body,
request_id,
weight_version,
)
.await
}
}
/// Inject bootstrap metadata into a protobuf GenerateRequest
fn inject_bootstrap_metadata(
request: &mut proto::GenerateRequest,
prefill_worker: &dyn Worker,
) -> Result<(), String> {
let hostname = prefill_worker.bootstrap_host();
let bootstrap_port = prefill_worker.bootstrap_port().unwrap_or(8998);
let room_id = generate_room_id();
// Create DisaggregatedParams
let disagg_params = proto::DisaggregatedParams {
bootstrap_host: hostname.to_string(),
bootstrap_port: bootstrap_port as i32,
bootstrap_room: room_id as i32,
};
// Inject metadata
request.disaggregated_params = Some(disagg_params);
debug!(
"Injected bootstrap metadata: host={}, port={}, room={}",
hostname, bootstrap_port, room_id
);
Ok(())
}
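For reference, the injected disaggregated_params is what the Python _convert_generate_request earlier in this diff reads back as bootstrap_host/port/room. An illustrative payload, assuming the generated proto types used above (values are made up):

// Illustrative values only; in practice host/port come from the selected prefill worker
// and the room id is generated per request to pair the prefill and decode sides.
let _example = proto::DisaggregatedParams {
    bootstrap_host: "10.0.0.5".to_string(),
    bootstrap_port: 8998,
    bootstrap_room: 42,
};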
/// Main route_chat implementation with PD dual dispatch
async fn route_chat_impl(
&self,
_headers: Option<&HeaderMap>,
body: &ChatCompletionRequest,
model_id: Option<&str>,
) -> Response {
debug!(
"Processing chat completion request for model: {:?} (PD mode)",
model_id
);
// Step 1: Filter tools if needed for allowed_tools or specific function
let body_ref = utils::filter_tools_for_request(body);
// Step 2: Process messages and apply chat template
let processed_messages = match utils::process_chat_messages(&body_ref, &*self.tokenizer) {
Ok(msgs) => msgs,
Err(e) => {
error!("Failed to process chat messages: {}", e);
return (StatusCode::BAD_REQUEST, e.to_string()).into_response();
}
};
// Step 3: Tokenize the processed text
let encoding = match self.tokenizer.encode(&processed_messages.text) {
Ok(encoding) => encoding,
Err(e) => {
error!("Tokenization failed: {}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Tokenization failed: {}", e),
)
.into_response();
}
};
// Step 4: Build tool constraints if needed
// body_ref already has filtered tools if needed
let tool_call_constraint = body_ref.tools.as_ref().and_then(|tools| {
utils::generate_tool_constraints(tools, &body.tool_choice, &body.model)
});
let token_ids = encoding.token_ids().to_vec();
debug!("Tokenized {} tokens from input", token_ids.len());
// Step 5: Select prefill-decode worker pair
let (prefill_worker, decode_worker) = match self
.select_pd_pair(Some(&processed_messages.text), model_id)
.await
{
Ok(pair) => pair,
Err(e) => {
warn!("Failed to select PD worker pair: {}", e);
return (StatusCode::SERVICE_UNAVAILABLE, e).into_response();
}
};
debug!(
"Selected PD pair: prefill={}, decode={}",
prefill_worker.url(),
decode_worker.url()
);
// Step 6: Get gRPC clients for both workers
let prefill_client = match utils::get_grpc_client_from_worker(&prefill_worker).await {
Ok(client) => client,
Err(response) => return response,
};
let decode_client = match utils::get_grpc_client_from_worker(&decode_worker).await {
Ok(client) => client,
Err(response) => return response,
};
// Step 7: Build the base gRPC request
let request_id = format!("chatcmpl-{}", Uuid::new_v4());
let mut request = match prefill_client.build_generate_request(
request_id.clone(),
&body_ref,
processed_messages.text.clone(),
token_ids,
processed_messages.multimodal_inputs,
tool_call_constraint,
) {
Ok(request) => request,
Err(e) => {
error!("Failed to build gRPC request: {}", e);
return (
StatusCode::BAD_REQUEST,
format!("Invalid request parameters: {}", e),
)
.into_response();
}
};
// Step 8: Inject bootstrap metadata into the request
if let Err(e) = Self::inject_bootstrap_metadata(&mut request, &*prefill_worker) {
error!("Failed to inject bootstrap metadata: {}", e);
return (StatusCode::INTERNAL_SERVER_ERROR, e).into_response();
}
// Step 9: Handle streaming vs non-streaming
if body.stream {
self.handle_streaming_chat(prefill_client, decode_client, request, body)
.await
} else {
self.handle_non_streaming_chat(prefill_client, decode_client, request, body)
.await
}
}
/// Resolve the generate input into optional original text and token IDs
fn resolve_generate_input(
&self,
request: &GenerateRequest,
) -> Result<(Option<String>, Vec<u32>), String> {
if let Some(text) = &request.text {
let encoding = self
.tokenizer
.encode(text)
.map_err(|e| format!("Tokenization failed: {}", e))?;
return Ok((Some(text.to_string()), encoding.token_ids().to_vec()));
}
// Handle input_ids - validate and convert
if let Some(input_ids) = &request.input_ids {
return match input_ids {
InputIds::Single(ids) => ids
.iter()
.map(|&id| u32::try_from(id))
.collect::<Result<Vec<u32>, _>>()
.map(|converted| (None, converted))
.map_err(|_| "input_ids must be non-negative".to_string()),
InputIds::Batch(_) => {
Err("Batch input_ids are not supported in PD mode".to_string())
}
};
}
Err("Either `text` or `input_ids` must be provided".to_string())
}
/// Submit request and handle streaming response for chat completions (PD mode)
async fn handle_streaming_chat(
&self,
mut prefill_client: SglangSchedulerClient,
mut decode_client: SglangSchedulerClient,
request: proto::GenerateRequest,
original_request: &ChatCompletionRequest,
) -> Response {
let request_id = request.request_id.clone();
let model = original_request.model.clone();
// Create channel for SSE streaming
let (tx, rx) = unbounded_channel::<Result<bytes::Bytes, std::io::Error>>();
// Send requests in parallel to both prefill and decode workers
debug!("Starting concurrent streaming requests to prefill and decode workers");
let prefill_request = request.clone();
let decode_request = request;
let (prefill_result, decode_result) = tokio::join!(
prefill_client.generate(prefill_request),
decode_client.generate(decode_request)
);
// Get prefill stream
let prefill_stream = match prefill_result {
Ok(s) => s,
Err(e) => {
error!("Failed to start prefill generation: {}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Prefill worker failed to start: {}", e),
)
.into_response();
}
};
// Get decode stream - this is what we'll process for output
let decode_stream = match decode_result {
Ok(s) => s,
Err(e) => {
error!("Failed to start decode generation: {}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Decode worker failed to start: {}", e),
)
.into_response();
}
};
let stop_params = (
original_request.stop.clone(),
original_request.stop_token_ids.clone(),
original_request.skip_special_tokens,
original_request.no_stop_trim,
);
// Spawn processing task for both streams
let self_clone = self.clone();
let original_request_clone = original_request.clone();
tokio::spawn(async move {
let result = Self::process_dual_streaming_chunks(
&self_clone,
prefill_stream,
decode_stream,
request_id,
model,
stop_params,
original_request_clone,
&tx,
)
.await;
if let Err(e) = result {
let error_chunk = format!(
"data: {}\n\n",
serde_json::json!({
"error": {
"message": e,
"type": "internal_error"
}
})
);
let _ = tx.send(Ok(bytes::Bytes::from(error_chunk)));
}
// Send DONE marker
let _ = tx.send(Ok(bytes::Bytes::from("data: [DONE]\n\n")));
});
// Create response with SSE headers
let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
let mut response = Response::new(Body::from_stream(stream));
*response.status_mut() = StatusCode::OK;
response.headers_mut().insert(
header::CONTENT_TYPE,
HeaderValue::from_static("text/event-stream"),
);
response
.headers_mut()
.insert("Cache-Control", HeaderValue::from_static("no-cache"));
response
.headers_mut()
.insert("Connection", HeaderValue::from_static("keep-alive"));
response
}
/// Submit request and handle streaming response for generate endpoint (PD mode)
async fn handle_streaming_generate(
&self,
mut prefill_client: SglangSchedulerClient,
mut decode_client: SglangSchedulerClient,
request: proto::GenerateRequest,
original_request: &GenerateRequest,
request_id: String,
weight_version: String,
) -> Response {
// Create channel for SSE streaming
let (tx, rx) = unbounded_channel::<Result<bytes::Bytes, std::io::Error>>();
// Send requests in parallel to both prefill and decode workers
debug!("Starting concurrent streaming generate requests to prefill and decode workers");
let prefill_request = request.clone();
let decode_request = request;
let (prefill_result, decode_result) = tokio::join!(
prefill_client.generate(prefill_request),
decode_client.generate(decode_request)
);
// Get prefill stream (for input_logprobs if needed)
let prefill_stream = match prefill_result {
Ok(s) => s,
Err(e) => {
error!("Failed to start prefill generation: {}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Prefill worker failed to start: {}", e),
)
.into_response();
}
};
// Get decode stream - this is what we'll process for output
let decode_stream = match decode_result {
Ok(s) => s,
Err(e) => {
error!("Failed to start decode generation: {}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Decode worker failed to start: {}", e),
)
.into_response();
}
};
// Spawn processing task for both streams
let tokenizer = self.tokenizer.clone();
let return_logprob = original_request.return_logprob;
tokio::spawn(async move {
let result = Self::process_generate_streaming(
tokenizer,
prefill_stream,
decode_stream,
request_id,
weight_version,
return_logprob,
&tx,
)
.await;
if let Err(e) = result {
let error_chunk = format!(
"data: {}\n\n",
serde_json::json!({
"error": {
"message": e,
"type": "internal_error"
}
})
);
let _ = tx.send(Ok(bytes::Bytes::from(error_chunk)));
}
// Send DONE marker
let _ = tx.send(Ok(bytes::Bytes::from("data: [DONE]\n\n")));
});
// Create response with SSE headers
let stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
let mut response = Response::new(Body::from_stream(stream));
*response.status_mut() = StatusCode::OK;
response.headers_mut().insert(
header::CONTENT_TYPE,
HeaderValue::from_static("text/event-stream"),
);
response
.headers_mut()
.insert("Cache-Control", HeaderValue::from_static("no-cache"));
response
.headers_mut()
.insert("Connection", HeaderValue::from_static("keep-alive"));
response
}
/// Process generate streaming (simplified - no tool calls or reasoning)
#[allow(clippy::too_many_arguments)]
async fn process_generate_streaming(
tokenizer: Arc<dyn Tokenizer>,
mut prefill_stream: impl Stream<Item = Result<proto::GenerateResponse, tonic::Status>> + Unpin,
mut decode_stream: impl Stream<Item = Result<proto::GenerateResponse, tonic::Status>> + Unpin,
request_id: String,
weight_version: String,
include_logprobs: bool,
tx: &UnboundedSender<Result<bytes::Bytes, std::io::Error>>,
) -> Result<(), String> {
let start_time = Instant::now();
// Phase 1: Collect input_logprobs from prefill stream if requested
// TODO: Store and emit input_logprobs when implementing prompt logprobs in streaming
if include_logprobs {
while let Some(response) = prefill_stream.next().await {
let gen_response = response.map_err(|e| format!("Prefill stream error: {}", e))?;
match gen_response.response {
Some(Complete(_complete)) => {
// Input logprobs collected but not yet used in streaming
break;
}
Some(Error(error)) => {
return Err(format!("Prefill error: {}", error.message));
}
_ => continue,
}
}
}
// Phase 2: Main streaming loop (decode stream)
// Track state per index for n>1 case
let mut accumulated_texts: HashMap<u32, String> = HashMap::new();
let mut completion_tokens_map: HashMap<u32, u32> = HashMap::new();
let mut current_index: u32 = 0;
while let Some(response) = decode_stream.next().await {
let gen_response = response.map_err(|e| format!("Decode stream error: {}", e))?;
match gen_response.response {
Some(Chunk(chunk)) => {
// Use our tracked index instead of chunk.index (PD backend bug workaround)
let index = current_index;
debug!(
"Received chunk with backend_index={}, using_index={}, tokens={:?}",
chunk.index, index, chunk.token_ids
);
let completion_tokens = completion_tokens_map.entry(index).or_insert(0);
*completion_tokens += chunk.token_ids.len() as u32;
let chunk_text = tokenizer.decode(&chunk.token_ids, true).unwrap_or_default();
let accumulated_text = accumulated_texts.entry(index).or_default();
accumulated_text.push_str(&chunk_text);
let index_id = format!("{}-{}", request_id, index);
let chunk_response = serde_json::json!({
"text": accumulated_text.clone(),
"output_ids": chunk.token_ids,
"meta_info": {
"id": index_id,
"finish_reason": null,
"prompt_tokens": chunk.prompt_tokens,
"weight_version": weight_version,
"completion_tokens": *completion_tokens,
"cached_tokens": chunk.cached_tokens
},
"index": index
});
let sse_chunk = format!(
"data: {}\n\n",
serde_json::to_string(&chunk_response).unwrap()
);
tx.send(Ok(bytes::Bytes::from(sse_chunk)))
.map_err(|_| "Failed to send chunk".to_string())?;
}
Some(Complete(complete)) => {
let index = current_index;
debug!(
"Received Complete with backend_index={}, using_index={}, finish_reason={}",
complete.index, index, complete.finish_reason
);
let accumulated_text =
accumulated_texts.get(&index).cloned().unwrap_or_default();
let completion_tokens = *completion_tokens_map.get(&index).unwrap_or(&0);
let index_id = format!("{}-{}", request_id, index);
let e2e_latency = start_time.elapsed().as_secs_f64();
// Send final chunk with finish_reason (no new tokens in Complete, they were already sent in Chunks)
let finish_response = serde_json::json!({
"text": accumulated_text,
"output_ids": complete.output_ids[complete.output_ids.len().saturating_sub(1)..].to_vec(),
"meta_info": {
"id": index_id,
"finish_reason": complete.finish_reason,
"prompt_tokens": complete.prompt_tokens,
"weight_version": weight_version,
"completion_tokens": completion_tokens,
"cached_tokens": complete.cached_tokens,
"e2e_latency": e2e_latency
},
"index": index
});
let sse_chunk = format!(
"data: {}\n\n",
serde_json::to_string(&finish_response).unwrap()
);
tx.send(Ok(bytes::Bytes::from(sse_chunk)))
.map_err(|_| "Failed to send finish chunk".to_string())?;
// Move to next completion
current_index += 1;
}
Some(Error(error)) => {
return Err(error.message);
}
None => continue,
}
}
Ok(())
}
/// Process dual streaming chunks (prefill + decode) and send SSE events (PD mode)
#[allow(clippy::too_many_arguments)]
async fn process_dual_streaming_chunks(
router: &GrpcPDRouter,
mut prefill_stream: impl Stream<Item = Result<proto::GenerateResponse, tonic::Status>> + Unpin,
mut decode_stream: impl Stream<Item = Result<proto::GenerateResponse, tonic::Status>> + Unpin,
request_id: String,
model: String,
stop_params: (Option<StringOrArray>, Option<Vec<u32>>, bool, bool),
original_request: ChatCompletionRequest,
tx: &UnboundedSender<Result<bytes::Bytes, std::io::Error>>,
) -> Result<(), String> {
// Extract request parameters
let separate_reasoning = original_request.separate_reasoning;
let tool_choice = &original_request.tool_choice;
let tools = &original_request.tools;
let history_tool_calls_count = utils::get_history_tool_calls_count(&original_request);
let stream_options = &original_request.stream_options;
// Phase 1: Initialize state tracking (per-index for n>1 support)
let mut is_firsts: HashMap<u32, bool> = HashMap::new();
let mut stream_buffers: HashMap<u32, String> = HashMap::new();
let mut finish_reasons: HashMap<u32, String> = HashMap::new();
let mut matched_stops: HashMap<u32, Option<Value>> = HashMap::new();
let mut prompt_tokens: HashMap<u32, u32> = HashMap::new();
let mut completion_tokens: HashMap<u32, u32> = HashMap::new();
let mut cached_tokens: HashMap<u32, u32> = HashMap::new();
// Parser state (lazy initialization per index)
type PooledReasoningParser = Arc<std::sync::Mutex<Box<dyn ReasoningParser>>>;
let mut reasoning_parsers: HashMap<u32, PooledReasoningParser> = HashMap::new();
type PooledToolParser = Arc<tokio::sync::Mutex<Box<dyn ToolParser>>>;
let mut tool_parsers: HashMap<u32, PooledToolParser> = HashMap::new();
let mut has_tool_calls: HashMap<u32, bool> = HashMap::new();
// Create stop decoder
let (stop, stop_token_ids, skip_special_tokens, no_stop_trim) = stop_params;
let mut stop_decoder = utils::create_stop_decoder(
&router.tokenizer,
stop.as_ref(),
stop_token_ids.as_ref(),
skip_special_tokens,
no_stop_trim,
);
let created = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
// Phase 1.5: Collect input_logprobs from prefill stream if requested
// Note: In PD mode, input_logprobs come from prefill worker
// TODO: Store and emit input_logprobs when implementing prompt logprobs in streaming
if original_request.logprobs {
while let Some(response) = prefill_stream.next().await {
let gen_response = response.map_err(|e| format!("Prefill stream error: {}", e))?;
match gen_response.response {
Some(Complete(_complete)) => {
// Input logprobs collected but not yet used in streaming
// (OpenAI spec doesn't require prompt logprobs in streaming responses)
break;
}
Some(Error(error)) => {
return Err(format!("Prefill error: {}", error.message));
}
_ => continue,
}
}
}
// Phase 2: Main streaming loop (decode stream)
while let Some(response) = decode_stream.next().await {
let gen_response = response.map_err(|e| format!("Stream error: {}", e))?;
match gen_response.response {
Some(Chunk(chunk)) => {
let index = chunk.index;
// Process tokens through stop decoder
let (chunk_text, _should_stop) =
Self::process_chunk_tokens(&mut stop_decoder, &chunk.token_ids);
if chunk_text.is_empty() {
continue;
}
// Process logprobs if present
let choice_logprobs = if let Some(ref proto_logprobs) = chunk.output_logprobs {
match router.convert_proto_to_openai_logprobs(proto_logprobs) {
Ok(logprobs) => Some(logprobs),
Err(e) => {
warn!("Failed to process logprobs: {}", e);
None
}
}
} else {
None
};
// Initialize stream buffer if first time
let stream_buffer = stream_buffers.entry(index).or_default();
// Send first chunk with role
if is_firsts.get(&index).copied().unwrap_or(true) {
let first_chunk = ChatCompletionStreamResponse {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: None,
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
};
tx.send(Ok(bytes::Bytes::from(Self::format_sse_chunk(&first_chunk))))
.map_err(|_| "Failed to send first chunk".to_string())?;
is_firsts.insert(index, false);
}
// Calculate delta
let mut delta = chunk_text;
stream_buffer.push_str(&delta);
// Reasoning content handling
if separate_reasoning {
let (normal_text, reasoning_chunk) = router.process_reasoning_stream(
&delta,
index,
&mut reasoning_parsers,
&request_id,
&model,
created,
);
if let Some(chunk) = reasoning_chunk {
tx.send(Ok(bytes::Bytes::from(Self::format_sse_chunk(&chunk))))
.map_err(|_| "Failed to send reasoning chunk".to_string())?;
}
delta = normal_text;
}
// Tool call handling
let tool_choice_enabled =
!matches!(tool_choice, Some(ToolChoice::Value(ToolChoiceValue::None)));
if tool_choice_enabled && tools.is_some() {
let (should_skip, tool_chunks) = router
.process_tool_calls_stream(
&delta,
index,
&mut tool_parsers,
&mut has_tool_calls,
tools.as_ref().unwrap(),
&request_id,
&model,
created,
history_tool_calls_count,
)
.await;
for chunk in tool_chunks {
tx.send(Ok(bytes::Bytes::from(Self::format_sse_chunk(&chunk))))
.map_err(|_| "Failed to send tool call chunk".to_string())?;
}
if should_skip {
continue;
}
}
// Regular content emission
if !delta.is_empty() {
let content_chunk = Self::create_content_chunk(
delta,
index,
&request_id,
&model,
created,
choice_logprobs,
);
tx.send(Ok(bytes::Bytes::from(Self::format_sse_chunk(
&content_chunk,
))))
.map_err(|_| "Failed to send content chunk".to_string())?;
}
}
Some(Complete(complete)) => {
// Flush any remaining text
if let SequenceDecoderOutput::Text(text) = stop_decoder.flush() {
if !text.is_empty() {
let index = complete.index;
let stream_buffer = stream_buffers.entry(index).or_default();
stream_buffer.push_str(&text);
let content_chunk = ChatCompletionStreamResponse {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: Some(text),
tool_calls: None,
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
};
let sse_chunk = serde_json::to_string(&content_chunk)
.map_err(|e| format!("Failed to serialize content chunk: {}", e))?;
tx.send(Ok(bytes::Bytes::from(format!("data: {}\n\n", sse_chunk))))
.map_err(|_| "Failed to send flushed content".to_string())?;
}
}
// Store metadata
let index = complete.index;
prompt_tokens.insert(index, complete.prompt_tokens as u32);
completion_tokens.insert(index, complete.completion_tokens as u32);
cached_tokens.insert(index, complete.cached_tokens as u32);
finish_reasons.insert(index, complete.finish_reason.clone());
// Extract matched_stop
let matched_stop_value = match &complete.matched_stop {
Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => {
Some(Value::Number(serde_json::Number::from(*token_id)))
}
Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => {
Some(Value::String(stop_str.clone()))
}
None => None,
};
matched_stops.insert(index, matched_stop_value);
break;
}
Some(Error(error)) => {
return Err(error.message);
}
None => continue,
}
}
// Phase 3: Check unstreamed tool args
for (index, parser) in &tool_parsers {
let parser_guard = parser.lock().await;
if let Some(unstreamed_items) = parser_guard.get_unstreamed_tool_args() {
for tool_call_item in unstreamed_items {
let tool_call_delta = ToolCallDelta {
index: tool_call_item.tool_index as u32,
id: None,
tool_type: None,
function: Some(FunctionCallDelta {
name: None,
arguments: if !tool_call_item.parameters.is_empty() {
Some(tool_call_item.parameters)
} else {
None
},
}),
};
let tool_chunk = ChatCompletionStreamResponse {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index: *index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: Some(vec![tool_call_delta]),
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
};
let sse_chunk = serde_json::to_string(&tool_chunk)
.map_err(|e| format!("Failed to serialize tool chunk: {}", e))?;
tx.send(Ok(bytes::Bytes::from(format!("data: {}\n\n", sse_chunk))))
.map_err(|_| "Failed to send unstreamed tool args".to_string())?;
}
}
}
// Phase 4: Finish reason chunks
for (index, finish_reason) in finish_reasons.iter() {
let final_finish_reason =
if has_tool_calls.get(index).copied().unwrap_or(false) && finish_reason == "stop" {
"tool_calls".to_string()
} else {
finish_reason.clone()
};
let matched_stop_value = matched_stops.get(index).and_then(|v| v.clone());
let finish_chunk = ChatCompletionStreamResponse {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index: *index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: None,
reasoning_content: None,
},
logprobs: None,
finish_reason: Some(final_finish_reason),
matched_stop: matched_stop_value,
}],
usage: None,
};
let sse_chunk = serde_json::to_string(&finish_chunk)
.map_err(|e| format!("Failed to serialize finish chunk: {}", e))?;
tx.send(Ok(bytes::Bytes::from(format!("data: {}\n\n", sse_chunk))))
.map_err(|_| "Failed to send finish chunk".to_string())?;
}
// Phase 5: Usage chunk
if let Some(stream_opts) = stream_options {
if stream_opts.include_usage.unwrap_or(false) {
let total_prompt: u32 = prompt_tokens.values().sum();
let total_completion: u32 = completion_tokens.values().sum();
let usage_chunk = ChatCompletionStreamResponse {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
system_fingerprint: None,
choices: vec![],
usage: Some(Usage {
prompt_tokens: total_prompt,
completion_tokens: total_completion,
total_tokens: total_prompt + total_completion,
completion_tokens_details: None,
}),
};
let sse_chunk = serde_json::to_string(&usage_chunk)
.map_err(|e| format!("Failed to serialize usage chunk: {}", e))?;
tx.send(Ok(bytes::Bytes::from(format!("data: {}\n\n", sse_chunk))))
.map_err(|_| "Failed to send usage chunk".to_string())?;
}
}
Ok(())
}
/// Helper: Process reasoning content in streaming mode
fn process_reasoning_stream(
&self,
delta: &str,
index: u32,
reasoning_parsers: &mut HashMap<u32, Arc<std::sync::Mutex<Box<dyn ReasoningParser>>>>,
request_id: &str,
model: &str,
created: u64,
) -> (String, Option<ChatCompletionStreamResponse>) {
// Get or create parser for this index
reasoning_parsers
.entry(index)
.or_insert_with(|| self.reasoning_parser_factory.get_pooled(model));
if let Some(pooled_parser) = reasoning_parsers.get(&index) {
let parse_result = {
let mut parser = pooled_parser.lock().unwrap();
parser.parse_reasoning_streaming_incremental(delta)
};
match parse_result {
Ok(ParserResult {
reasoning_text,
normal_text,
}) => {
let chunk = if !reasoning_text.is_empty() {
Some(ChatCompletionStreamResponse {
id: request_id.to_string(),
object: "chat.completion.chunk".to_string(),
created,
model: model.to_string(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: None,
reasoning_content: Some(reasoning_text),
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
})
} else {
None
};
return (normal_text, chunk);
}
Err(e) => {
warn!("Reasoning parsing error: {}", e);
}
}
}
(delta.to_string(), None)
}
/// Helper: Process tool calls in streaming mode
#[allow(clippy::too_many_arguments)]
async fn process_tool_calls_stream(
&self,
delta: &str,
index: u32,
tool_parsers: &mut HashMap<u32, Arc<tokio::sync::Mutex<Box<dyn ToolParser>>>>,
has_tool_calls: &mut HashMap<u32, bool>,
tools: &[Tool],
request_id: &str,
model: &str,
created: u64,
history_tool_calls_count: usize,
) -> (bool, Vec<ChatCompletionStreamResponse>) {
let mut chunks = Vec::new();
// Get or create parser for this index
tool_parsers
.entry(index)
.or_insert_with(|| self.tool_parser_factory.get_pooled(model));
if let Some(pooled_parser) = tool_parsers.get(&index) {
let mut parser = pooled_parser.lock().await;
match parser.parse_incremental(delta, tools).await {
Ok(StreamingParseResult { normal_text, calls }) => {
// Emit normal text if present
if !normal_text.is_empty() {
chunks.push(ChatCompletionStreamResponse {
id: request_id.to_string(),
object: "chat.completion.chunk".to_string(),
created,
model: model.to_string(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: Some(normal_text),
tool_calls: None,
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
});
}
// Emit tool call chunks
for tool_call_item in calls {
has_tool_calls.insert(index, true);
let tool_call_id = if let Some(ref name) = tool_call_item.name {
Some(utils::generate_tool_call_id(
model,
name,
tool_call_item.tool_index,
history_tool_calls_count,
))
} else {
None
};
let tool_call_delta = ToolCallDelta {
index: tool_call_item.tool_index as u32,
id: tool_call_id,
tool_type: if tool_call_item.name.is_some() {
Some("function".to_string())
} else {
None
},
function: Some(FunctionCallDelta {
name: tool_call_item.name,
arguments: if !tool_call_item.parameters.is_empty() {
Some(tool_call_item.parameters)
} else {
None
},
}),
};
chunks.push(ChatCompletionStreamResponse {
id: request_id.to_string(),
object: "chat.completion.chunk".to_string(),
created,
model: model.to_string(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: Some(vec![tool_call_delta]),
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
});
}
// If we emitted chunks, skip regular content
return (!chunks.is_empty(), chunks);
}
Err(e) => {
warn!("Tool call parsing error: {}", e);
}
}
}
(false, chunks)
}
/// Helper: Create content chunk
fn create_content_chunk(
content: String,
index: u32,
request_id: &str,
model: &str,
created: u64,
logprobs: Option<ChatLogProbs>,
) -> ChatCompletionStreamResponse {
ChatCompletionStreamResponse {
id: request_id.to_string(),
object: "chat.completion.chunk".to_string(),
created,
model: model.to_string(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: Some(content),
tool_calls: None,
reasoning_content: None,
},
logprobs,
finish_reason: None,
matched_stop: None,
}],
usage: None,
}
}
/// Helper: Format response as SSE chunk
fn format_sse_chunk(response: &ChatCompletionStreamResponse) -> String {
format!(
"data: {}\n\n",
serde_json::to_string(response).unwrap_or_default()
)
}
/// Process a chunk of tokens through the stop decoder
fn process_chunk_tokens(
stop_decoder: &mut StopSequenceDecoder,
token_ids: &[u32],
) -> (String, bool) {
let mut chunk_text = String::new();
for &token_id in token_ids {
match stop_decoder.process_token(token_id).unwrap_or_else(|e| {
debug!(
"Error processing token {}: {}. Treating as Held.",
token_id, e
);
SequenceDecoderOutput::Held
}) {
SequenceDecoderOutput::Text(text) => {
chunk_text.push_str(&text);
}
SequenceDecoderOutput::StoppedWithText(text) => {
chunk_text.push_str(&text);
return (chunk_text, true);
}
SequenceDecoderOutput::Stopped => {
return (chunk_text, true);
}
SequenceDecoderOutput::Held => {}
}
}
(chunk_text, false)
}
/// Submit request and handle non-streaming response for chat completions (PD mode)
async fn handle_non_streaming_chat(
&self,
mut prefill_client: SglangSchedulerClient,
mut decode_client: SglangSchedulerClient,
request: proto::GenerateRequest,
original_request: &ChatCompletionRequest,
) -> Response {
// Step 1: Create stop decoder
let mut stop_decoder = utils::create_stop_decoder(
&self.tokenizer,
original_request.stop.as_ref(),
original_request.stop_token_ids.as_ref(),
original_request.skip_special_tokens,
original_request.no_stop_trim,
);
// Step 2: Send requests in parallel
debug!("Sending concurrent requests to prefill and decode workers");
let prefill_request = request.clone();
let decode_request = request;
let (prefill_result, decode_result) = tokio::join!(
prefill_client.generate(prefill_request),
decode_client.generate(decode_request)
);
// Step 3: Process prefill stream in parallel - if it fails, assume decode fails
let prefill_stream = match prefill_result {
Ok(s) => s,
Err(e) => {
error!("Failed to start prefill generation: {}", e);
return utils::internal_error_message(format!(
"Prefill worker failed to start: {}",
e
));
}
};
let decode_stream = match decode_result {
Ok(s) => s,
Err(e) => {
error!("Failed to start decode generation: {}", e);
return utils::internal_error_message(format!(
"Decode worker failed to start: {}",
e
));
}
};
// Collect prefill responses (needed for input_logprobs when requested)
let prefill_responses =
match utils::collect_stream_responses(prefill_stream, "Prefill").await {
Ok(responses) => responses,
Err(error_response) => return error_response,
};
// Extract input_logprobs from prefill response if available
let prefill_input_logprobs = prefill_responses
.first()
.and_then(|r| r.input_logprobs.clone());
// Step 4: Process decode stream (collect all responses for n>1 support)
let all_responses = match utils::collect_stream_responses(decode_stream, "Decode").await {
Ok(responses) => responses,
Err(error_response) => return error_response,
};
if all_responses.is_empty() {
return utils::internal_error_static("No responses from decode worker");
}
// Process each response into a ChatChoice
let history_tool_calls_count = utils::get_history_tool_calls_count(original_request);
let mut choices = Vec::new();
for (index, complete) in all_responses.iter().enumerate() {
// Merge prefill input_logprobs if available and requested
let mut complete_with_logprobs = complete.clone();
if prefill_input_logprobs.is_some() && original_request.logprobs {
complete_with_logprobs.input_logprobs = prefill_input_logprobs.clone();
}
match self
.process_single_choice(
&complete_with_logprobs,
index,
original_request,
&mut stop_decoder,
history_tool_calls_count,
)
.await
{
Ok(choice) => choices.push(choice),
Err(e) => {
return utils::internal_error_message(format!(
"Failed to process choice {}: {}",
index, e
));
}
}
}
// Aggregate usage information from all responses
let total_prompt_tokens: u32 = all_responses.iter().map(|r| r.prompt_tokens as u32).sum();
let total_completion_tokens: u32 = all_responses
.iter()
.map(|r| r.completion_tokens as u32)
.sum();
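// Note: each decode response reports its own prompt_tokens, so with n>1 the
// prompt is counted once per returned choice in the aggregated usage.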
let usage = Usage {
prompt_tokens: total_prompt_tokens,
completion_tokens: total_completion_tokens,
total_tokens: total_prompt_tokens + total_completion_tokens,
completion_tokens_details: None,
};
// Build final ChatCompletionResponse
let response = ChatCompletionResponse {
id: format!("chatcmpl-{}", Uuid::new_v4()),
object: "chat.completion".to_string(),
created: SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
model: original_request.model.clone(),
choices,
usage: Some(usage),
system_fingerprint: None,
};
// Serialize and return JSON response
Json(response).into_response()
}
/// Submit request and handle non-streaming response for generate endpoint (PD mode)
async fn handle_non_streaming_generate(
&self,
mut prefill_client: SglangSchedulerClient,
mut decode_client: SglangSchedulerClient,
request: proto::GenerateRequest,
original_request: &GenerateRequest,
request_id: String,
weight_version: String,
) -> Response {
use std::time::Instant;
let start_time = Instant::now();
// Send requests in parallel
debug!("Sending concurrent generate requests to prefill and decode workers");
let prefill_request = request.clone();
let decode_request = request;
let (prefill_result, decode_result) = tokio::join!(
prefill_client.generate(prefill_request),
decode_client.generate(decode_request)
);
// Process prefill stream
let prefill_stream = match prefill_result {
Ok(s) => s,
Err(e) => {
error!("Failed to start prefill generation: {}", e);
return utils::internal_error_message(format!(
"Prefill worker failed to start: {}",
e
));
}
};
let decode_stream = match decode_result {
Ok(s) => s,
Err(e) => {
error!("Failed to start decode generation: {}", e);
return utils::internal_error_message(format!(
"Decode worker failed to start: {}",
e
));
}
};
// Collect prefill responses
// TODO: add logprob support for the generate endpoint
let _prefill_responses =
match utils::collect_stream_responses(prefill_stream, "Prefill").await {
Ok(responses) => responses,
Err(error_response) => return error_response,
};
// Collect decode responses
let decode_responses = match utils::collect_stream_responses(decode_stream, "Decode").await
{
Ok(responses) => responses,
Err(error_response) => return error_response,
};
if decode_responses.is_empty() {
return utils::internal_error_static("No completion received from decode worker");
}
// Create stop decoder from sampling params
let params = original_request.sampling_params.as_ref();
let mut stop_decoder = utils::create_stop_decoder(
&self.tokenizer,
params.and_then(|p| p.stop.as_ref()),
params.and_then(|p| p.stop_token_ids.as_ref()),
params.and_then(|p| p.skip_special_tokens).unwrap_or(true),
params.and_then(|p| p.no_stop_trim).unwrap_or(false),
);
// Process each completion
let mut result_array = Vec::new();
for mut complete in decode_responses {
stop_decoder.reset();
// Process tokens through stop decoder
let outputs = match stop_decoder.process_tokens(&complete.output_ids) {
Ok(outputs) => outputs,
Err(e) => {
return utils::internal_error_message(format!(
"Failed to process tokens: {}",
e
))
}
};
// Accumulate text with early breaks
let mut decoded_text = String::new();
for output in outputs {
match output {
SequenceDecoderOutput::Text(t) => decoded_text.push_str(&t),
SequenceDecoderOutput::StoppedWithText(t) => {
decoded_text.push_str(&t);
break;
}
SequenceDecoderOutput::Stopped => break,
SequenceDecoderOutput::Held => {}
}
}
// Flush remaining text
if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() {
decoded_text.push_str(&t);
}
let output_ids = complete.output_ids.clone();
// Build base meta_info
let mut meta_info = serde_json::json!({
"id": request_id.clone(),
"finish_reason": complete.finish_reason.clone(),
"prompt_tokens": complete.prompt_tokens,
"weight_version": weight_version.clone(),
"completion_tokens": complete.completion_tokens,
"cached_tokens": complete.cached_tokens,
"e2e_latency": start_time.elapsed().as_secs_f64(),
});
let meta_obj = meta_info.as_object_mut().unwrap();
// Add matched_stop if present
if let Some(matched) = complete.matched_stop.take() {
use proto::generate_complete::MatchedStop;
let matched_value = match matched {
MatchedStop::MatchedTokenId(id) => serde_json::json!(id),
MatchedStop::MatchedStopStr(s) => serde_json::json!(s),
};
meta_obj.insert("matched_stop".to_string(), matched_value);
}
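// Each element mirrors SGLang's native /generate response shape, e.g.
//   {"text": "...", "output_ids": [...], "meta_info": {"id": "...", "finish_reason": ..., ...}}
// with values taken from the decode worker's completion above.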
result_array.push(serde_json::json!({
"text": decoded_text,
"output_ids": output_ids,
"meta_info": meta_info,
}));
}
Json(result_array).into_response()
}
/// Process a single GenerateComplete response into a ChatChoice
async fn process_single_choice(
&self,
complete: &proto::GenerateComplete,
index: usize,
original_request: &ChatCompletionRequest,
stop_decoder: &mut StopSequenceDecoder,
history_tool_calls_count: usize,
) -> Result<ChatChoice, String> {
stop_decoder.reset();
// Decode tokens
let outputs = stop_decoder
.process_tokens(&complete.output_ids)
.map_err(|e| format!("Failed to process tokens: {}", e))?;
// Accumulate text with early breaks
let mut final_text = String::new();
for output in outputs {
match output {
SequenceDecoderOutput::Text(t) => final_text.push_str(&t),
SequenceDecoderOutput::StoppedWithText(t) => {
final_text.push_str(&t);
break;
}
SequenceDecoderOutput::Stopped => break,
SequenceDecoderOutput::Held => {}
}
}
// Flush remaining text
if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() {
final_text.push_str(&t);
}
// Step 1: Handle reasoning content parsing
let mut reasoning_text: Option<String> = None;
let mut processed_text = final_text;
// Run the reasoning parser only when separate_reasoning is requested
if original_request.separate_reasoning {
let pooled_parser = self
.reasoning_parser_factory
.get_pooled(&original_request.model);
let mut parser = pooled_parser
.lock()
.map_err(|e| format!("Failed to acquire reasoning parser lock: {}", e))?;
match parser.detect_and_parse_reasoning(&processed_text) {
Ok(result) => {
if !result.reasoning_text.is_empty() {
reasoning_text = Some(result.reasoning_text);
}
processed_text = result.normal_text;
}
Err(e) => {
return Err(format!("Reasoning parsing error: {}", e));
}
}
}
// Step 2: Handle tool call parsing
let mut tool_calls: Option<Vec<ToolCall>> = None;
// Check if tool calls should be processed
let tool_choice_enabled = !matches!(
&original_request.tool_choice,
Some(ToolChoice::Value(ToolChoiceValue::None))
);
if tool_choice_enabled && original_request.tools.is_some() {
// Check if JSON schema constraint was used (specific function or required mode)
let used_json_schema = match &original_request.tool_choice {
Some(ToolChoice::Function { .. }) => true,
Some(ToolChoice::Value(ToolChoiceValue::Required)) => true,
Some(ToolChoice::AllowedTools { mode, .. }) => mode == "required",
_ => false,
};
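// Two parsing paths: constrained outputs (specific function or required mode) are
// plain JSON produced under a schema and go through parse_json_schema_response;
// otherwise the model-specific tool parser scans the free-form text for
// tool-call markup.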
if used_json_schema {
(tool_calls, processed_text) = utils::parse_json_schema_response(
&processed_text,
&original_request.tool_choice,
);
} else {
(tool_calls, processed_text) = self
.parse_tool_calls(
&processed_text,
&original_request.model,
history_tool_calls_count,
)
.await;
}
}
// Step 3: Use finish reason directly from proto (already OpenAI-compatible string)
let finish_reason_str = &complete.finish_reason;
// Override finish reason if we have tool calls
let final_finish_reason_str = if tool_calls.is_some() {
"tool_calls"
} else {
finish_reason_str
};
// Extract matched_stop information from proto
let matched_stop = match &complete.matched_stop {
Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => {
Some(Value::Number(serde_json::Number::from(*token_id)))
}
Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => {
Some(Value::String(stop_str.clone()))
}
None => None,
};
// Step 4: Convert output logprobs if present
// Note: complete.input_logprobs exists in proto but is not used for chat completions
// (input logprobs are only used in /v1/completions endpoint with echo=true)
let logprobs = if let Some(proto_logprobs) = &complete.output_logprobs {
match self.convert_proto_to_openai_logprobs(proto_logprobs) {
Ok(logprobs) => Some(logprobs),
Err(e) => {
error!("Failed to convert logprobs: {}", e);
None
}
}
} else {
None
};
// Step 5: Build ChatCompletionMessage (proper response message type)
let chat_message = ChatCompletionMessage {
role: "assistant".to_string(),
content: if processed_text.is_empty() {
None
} else {
Some(processed_text)
},
tool_calls,
reasoning_content: reasoning_text,
};
// Step 6: Build ChatChoice
let choice = ChatChoice {
index: index as u32,
message: chat_message,
logprobs,
finish_reason: Some(final_finish_reason_str.to_string()),
matched_stop,
hidden_states: None,
};
Ok(choice)
}
/// Parse tool calls using model-specific parser
async fn parse_tool_calls(
&self,
processed_text: &str,
model: &str,
history_tool_calls_count: usize,
) -> (Option<Vec<ToolCall>>, String) {
// Get pooled parser for this model
let pooled_parser = self.tool_parser_factory.get_pooled(model);
// Check format detection first
let can_parse = {
let parser = pooled_parser.lock().await;
parser.detect_format(processed_text)
// Lock is dropped here
};
if !can_parse {
return (None, processed_text.to_string());
}
// Lock again for async parsing
let result = {
let parser = pooled_parser.lock().await;
parser.parse_complete(processed_text).await
// Lock is dropped here
};
match result {
Ok((normal_text, parsed_tool_calls)) => {
if parsed_tool_calls.is_empty() {
return (None, normal_text);
}
let spec_tool_calls = parsed_tool_calls
.into_iter()
.enumerate()
.map(|(index, tc)| {
// Generate ID for this tool call
let id = utils::generate_tool_call_id(
model,
&tc.function.name,
index,
history_tool_calls_count,
);
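// The index is offset by history_tool_calls_count so formats that need globally
// unique call indices across the conversation (see get_history_tool_calls_count)
// do not reuse ids from earlier turns.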
ToolCall {
id,
tool_type: "function".to_string(),
function: FunctionCallResponse {
name: tc.function.name,
arguments: Some(
serde_json::to_string(&tc.function.arguments)
.unwrap_or_else(|_| "{}".to_string()),
),
},
}
})
.collect();
(Some(spec_tool_calls), normal_text)
}
Err(e) => {
error!("Tool call parsing error: {}", e);
(None, processed_text.to_string())
}
}
}
/// Convert proto LogProbs to OpenAI ChatLogProbs format
/// Note: Always decodes with skip_special_tokens=false to show actual tokens generated
fn convert_proto_to_openai_logprobs(
&self,
proto_logprobs: &proto::OutputLogProbs,
) -> Result<ChatLogProbs, String> {
let mut content_items = Vec::new();
// Decode token IDs to text (always with skip_special_tokens=false for logprobs)
let token_texts: Vec<String> = proto_logprobs
.token_ids
.iter()
.map(|&token_id| {
self.tokenizer
.decode(&[token_id as u32], false)
.unwrap_or_else(|_| format!("<token_{}>", token_id))
})
.collect();
// Build ChatLogProbsContent for each token
for (i, &logprob) in proto_logprobs.token_logprobs.iter().enumerate() {
let token_text = token_texts.get(i).cloned().unwrap_or_default();
let bytes = Some(token_text.as_bytes().to_vec());
// Build top_logprobs for this position
let mut top_logprobs = Vec::new();
if let Some(top_logprobs_entry) = proto_logprobs.top_logprobs.get(i) {
// Decode top token IDs (always with skip_special_tokens=false)
let top_token_texts: Vec<String> = top_logprobs_entry
.token_ids
.iter()
.map(|&tid| {
self.tokenizer
.decode(&[tid as u32], false)
.unwrap_or_else(|_| format!("<token_{}>", tid))
})
.collect();
for (j, (&top_logprob, &_top_token_id)) in top_logprobs_entry
.values
.iter()
.zip(top_logprobs_entry.token_ids.iter())
.enumerate()
{
if let Some(top_token_text) = top_token_texts.get(j) {
top_logprobs.push(TopLogProb {
token: top_token_text.clone(),
logprob: top_logprob,
bytes: Some(top_token_text.as_bytes().to_vec()),
});
}
}
}
content_items.push(ChatLogProbsContent {
token: token_text,
logprob,
bytes,
top_logprobs,
});
}
Ok(ChatLogProbs::Detailed {
content: (!content_items.is_empty()).then_some(content_items),
})
}
} }
impl std::fmt::Debug for GrpcPDRouter { impl std::fmt::Debug for GrpcPDRouter {
...@@ -103,13 +1988,13 @@ impl std::fmt::Debug for GrpcPDRouter { ...@@ -103,13 +1988,13 @@ impl std::fmt::Debug for GrpcPDRouter {
Some(WorkerType::Prefill { Some(WorkerType::Prefill {
bootstrap_port: None, bootstrap_port: None,
}), }),
Some(crate::core::ConnectionMode::Grpc { port: None }), Some(ConnectionMode::Grpc { port: None }),
false, false,
); );
let decode_workers = self.worker_registry.get_workers_filtered( let decode_workers = self.worker_registry.get_workers_filtered(
None, None,
Some(WorkerType::Decode), Some(WorkerType::Decode),
Some(crate::core::ConnectionMode::Grpc { port: None }), Some(ConnectionMode::Grpc { port: None }),
false, false,
); );
f.debug_struct("GrpcPDRouter") f.debug_struct("GrpcPDRouter")
...@@ -149,26 +2034,26 @@ impl RouterTrait for GrpcPDRouter { ...@@ -149,26 +2034,26 @@ impl RouterTrait for GrpcPDRouter {
async fn route_generate( async fn route_generate(
&self, &self,
_headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::GenerateRequest, body: &GenerateRequest,
_model_id: Option<&str>, model_id: Option<&str>,
) -> Response { ) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response() self.route_generate_impl(headers, body, model_id).await
} }
async fn route_chat( async fn route_chat(
&self, &self,
_headers: Option<&HeaderMap>, headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::ChatCompletionRequest, body: &ChatCompletionRequest,
_model_id: Option<&str>, model_id: Option<&str>,
) -> Response { ) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response() self.route_chat_impl(headers, body, model_id).await
} }
async fn route_completion( async fn route_completion(
&self, &self,
_headers: Option<&HeaderMap>, _headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::CompletionRequest, _body: &CompletionRequest,
_model_id: Option<&str>, _model_id: Option<&str>,
) -> Response { ) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response() (StatusCode::NOT_IMPLEMENTED).into_response()
...@@ -177,7 +2062,7 @@ impl RouterTrait for GrpcPDRouter { ...@@ -177,7 +2062,7 @@ impl RouterTrait for GrpcPDRouter {
async fn route_responses( async fn route_responses(
&self, &self,
_headers: Option<&HeaderMap>, _headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::ResponsesRequest, _body: &ResponsesRequest,
_model_id: Option<&str>, _model_id: Option<&str>,
) -> Response { ) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response() (StatusCode::NOT_IMPLEMENTED).into_response()
...@@ -187,7 +2072,7 @@ impl RouterTrait for GrpcPDRouter { ...@@ -187,7 +2072,7 @@ impl RouterTrait for GrpcPDRouter {
&self, &self,
_headers: Option<&HeaderMap>, _headers: Option<&HeaderMap>,
_response_id: &str, _response_id: &str,
_params: &crate::protocols::spec::ResponsesGetParams, _params: &ResponsesGetParams,
) -> Response { ) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response() (StatusCode::NOT_IMPLEMENTED).into_response()
} }
...@@ -199,7 +2084,7 @@ impl RouterTrait for GrpcPDRouter { ...@@ -199,7 +2084,7 @@ impl RouterTrait for GrpcPDRouter {
async fn route_embeddings( async fn route_embeddings(
&self, &self,
_headers: Option<&HeaderMap>, _headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::EmbeddingRequest, _body: &EmbeddingRequest,
_model_id: Option<&str>, _model_id: Option<&str>,
) -> Response { ) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response() (StatusCode::NOT_IMPLEMENTED).into_response()
...@@ -208,7 +2093,7 @@ impl RouterTrait for GrpcPDRouter { ...@@ -208,7 +2093,7 @@ impl RouterTrait for GrpcPDRouter {
async fn route_rerank( async fn route_rerank(
&self, &self,
_headers: Option<&HeaderMap>, _headers: Option<&HeaderMap>,
_body: &crate::protocols::spec::RerankRequest, _body: &RerankRequest,
_model_id: Option<&str>, _model_id: Option<&str>,
) -> Response { ) -> Response {
(StatusCode::NOT_IMPLEMENTED).into_response() (StatusCode::NOT_IMPLEMENTED).into_response()
......
...@@ -15,45 +15,32 @@ use bytes::Bytes; ...@@ -15,45 +15,32 @@ use bytes::Bytes;
use std::io; use std::io;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, info, warn}; use tracing::{debug, error, warn};
use crate::config::types::RetryConfig; use crate::config::types::RetryConfig;
use crate::core::{ConnectionMode, Worker, WorkerRegistry, WorkerType}; use crate::core::{ConnectionMode, Worker, WorkerRegistry, WorkerType};
use crate::grpc_client::{proto, SglangSchedulerClient}; use crate::grpc_client::{proto, SglangSchedulerClient};
use crate::metrics::RouterMetrics;
use crate::policies::PolicyRegistry; use crate::policies::PolicyRegistry;
use crate::protocols::spec::ChatMessage;
use crate::protocols::spec::{ use crate::protocols::spec::{
ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse, ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionStreamResponse, ChatMessageDelta, ChatStreamChoice, CompletionRequest, ChatCompletionStreamResponse, ChatMessage, ChatMessageDelta, ChatStreamChoice,
EmbeddingRequest, FunctionCallDelta, FunctionCallResponse, GenerateRequest, RerankRequest, CompletionRequest, EmbeddingRequest, FunctionCallDelta, FunctionCallResponse, GenerateRequest,
ResponsesGetParams, ResponsesRequest, StringOrArray, Tool, ToolCall, ToolCallDelta, ToolChoice, RerankRequest, ResponsesGetParams, ResponsesRequest, StringOrArray, ToolCall, ToolCallDelta,
ToolChoiceValue, Usage, ToolChoice, ToolChoiceValue, Usage,
}; };
use crate::reasoning_parser::{ParserResult, ReasoningParserFactory}; use crate::reasoning_parser::{ParserResult, ReasoningParserFactory};
use crate::routers::RouterTrait; use crate::routers::{grpc, RouterTrait};
use crate::server::AppContext; use crate::server::AppContext;
use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams}; use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoder};
use crate::tokenizer::stop::{
SequenceDecoderOutput, StopSequenceDecoder, StopSequenceDecoderBuilder,
};
use crate::tokenizer::traits::Tokenizer; use crate::tokenizer::traits::Tokenizer;
use crate::tokenizer::HuggingFaceTokenizer;
use crate::tool_parser::{StreamingParseResult, ToolParserFactory}; use crate::tool_parser::{StreamingParseResult, ToolParserFactory};
use grpc::utils;
use proto::generate_response::Response::{Chunk, Complete, Error}; use proto::generate_response::Response::{Chunk, Complete, Error};
use serde_json::{json, Map, Value}; use serde_json::{json, Value};
use std::time::{Instant, SystemTime, UNIX_EPOCH}; use std::time::{Instant, SystemTime, UNIX_EPOCH};
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use uuid::Uuid; use uuid::Uuid;
// Data structures for processing
#[derive(Debug)]
pub struct ProcessedMessages {
pub text: String,
pub multimodal_inputs: Option<proto::MultimodalInputs>,
pub stop_sequences: Option<StringOrArray>,
}
/// gRPC router implementation for SGLang /// gRPC router implementation for SGLang
#[derive(Clone)] #[derive(Clone)]
#[allow(dead_code)] #[allow(dead_code)]
...@@ -91,16 +78,6 @@ impl GrpcRouter { ...@@ -91,16 +78,6 @@ impl GrpcRouter {
let worker_registry = ctx.worker_registry.clone(); let worker_registry = ctx.worker_registry.clone();
let policy_registry = ctx.policy_registry.clone(); let policy_registry = ctx.policy_registry.clone();
let workers = worker_registry.get_workers_filtered(
None,
Some(WorkerType::Regular),
Some(ConnectionMode::Grpc { port: None }),
false,
);
RouterMetrics::set_active_workers(workers.len());
info!("gRPC router found {} workers in registry", workers.len());
Ok(GrpcRouter { Ok(GrpcRouter {
worker_registry, worker_registry,
policy_registry, policy_registry,
...@@ -125,56 +102,11 @@ impl GrpcRouter { ...@@ -125,56 +102,11 @@ impl GrpcRouter {
model_id model_id
); );
// Step 1: Select worker (fail fast if no workers available) // Step 1: Filter tools if needed for allowed_tools or specific function
let worker = match self.select_worker_for_request(model_id, None) { let body_ref = utils::filter_tools_for_request(body);
Some(w) => w,
None => {
warn!("No available workers for model: {:?}", model_id);
return (StatusCode::SERVICE_UNAVAILABLE, "No available workers").into_response();
}
};
debug!("Selected worker: {}", worker.url());
// Step 2: Get gRPC client from worker
let client = match Self::get_grpc_client_from_worker(&worker).await {
Ok(client) => client,
Err(response) => return response,
};
// Step 3: Filter tools if needed for allowed_tools or specific function
// Only clone body if we need to modify tools
let mut body_with_filtered_tools;
let body_ref = match &body.tool_choice {
Some(ToolChoice::AllowedTools { tools: allowed, .. }) if body.tools.is_some() => {
body_with_filtered_tools = body.clone();
let all_tools = body_with_filtered_tools.tools.as_ref().unwrap();
let allowed_names: std::collections::HashSet<&str> =
allowed.iter().map(|t| t.name.as_str()).collect();
let filtered_tools: Vec<Tool> = all_tools
.iter()
.filter(|t| allowed_names.contains(t.function.name.as_str()))
.cloned()
.collect();
body_with_filtered_tools.tools = Some(filtered_tools);
&body_with_filtered_tools
}
Some(ToolChoice::Function { function, .. }) if body.tools.is_some() => {
body_with_filtered_tools = body.clone();
let all_tools = body_with_filtered_tools.tools.as_ref().unwrap();
let filtered_tools: Vec<Tool> = all_tools
.iter()
.filter(|t| t.function.name == function.name)
.cloned()
.collect();
body_with_filtered_tools.tools = Some(filtered_tools);
&body_with_filtered_tools
}
_ => body, // No filtering needed, use original
};
// Step 4: Process messages and apply chat template // Step 2: Process messages and apply chat template
let processed_messages = match self.process_chat_messages(body_ref) { let processed_messages = match utils::process_chat_messages(&body_ref, &*self.tokenizer) {
Ok(msgs) => msgs, Ok(msgs) => msgs,
Err(e) => { Err(e) => {
error!("Failed to process chat messages: {}", e); error!("Failed to process chat messages: {}", e);
...@@ -182,7 +114,7 @@ impl GrpcRouter { ...@@ -182,7 +114,7 @@ impl GrpcRouter {
} }
}; };
// Step 5: Tokenize the processed text // Step 3: Tokenize the processed text
let encoding = match self.tokenizer.encode(&processed_messages.text) { let encoding = match self.tokenizer.encode(&processed_messages.text) {
Ok(encoding) => encoding, Ok(encoding) => encoding,
Err(e) => { Err(e) => {
...@@ -198,17 +130,35 @@ impl GrpcRouter { ...@@ -198,17 +130,35 @@ impl GrpcRouter {
let token_ids = encoding.token_ids().to_vec(); let token_ids = encoding.token_ids().to_vec();
debug!("Tokenized {} tokens from input", token_ids.len()); debug!("Tokenized {} tokens from input", token_ids.len());
// Step 6: Build tool constraints if needed // Step 4: Build tool constraints if needed
// body_ref already has filtered tools if needed // body_ref already has filtered tools if needed
let tool_call_constraint = body_ref.tools.as_ref().and_then(|tools| { let tool_call_constraint = body_ref.tools.as_ref().and_then(|tools| {
self.generate_tool_constraints(tools, &body.tool_choice, &body.model) utils::generate_tool_constraints(tools, &body.tool_choice, &body.model)
}); });
// Step 5: Select worker
let worker = match self.select_worker_for_request(model_id, Some(&processed_messages.text))
{
Some(w) => w,
None => {
warn!("No available workers for model: {:?}", model_id);
return (StatusCode::SERVICE_UNAVAILABLE, "No available workers").into_response();
}
};
debug!("Selected worker: {}", worker.url());
// Step 6: Get gRPC client from worker
let client = match utils::get_grpc_client_from_worker(&worker).await {
Ok(client) => client,
Err(response) => return response,
};
// Step 7: Build the base gRPC request (use body_ref with filtered tools if applicable) // Step 7: Build the base gRPC request (use body_ref with filtered tools if applicable)
let request_id = format!("chatcmpl-{}", Uuid::new_v4()); let request_id = format!("chatcmpl-{}", Uuid::new_v4());
let request = match client.build_generate_request( let request = match client.build_generate_request(
request_id, request_id,
body_ref, &body_ref,
processed_messages.text.clone(), processed_messages.text.clone(),
token_ids, token_ids,
processed_messages.multimodal_inputs, processed_messages.multimodal_inputs,
...@@ -265,7 +215,7 @@ impl GrpcRouter { ...@@ -265,7 +215,7 @@ impl GrpcRouter {
debug!("Selected worker: {}", worker.url()); debug!("Selected worker: {}", worker.url());
// Step 3: Get gRPC client from worker // Step 3: Get gRPC client from worker
let client = match Self::get_grpc_client_from_worker(&worker).await { let client = match utils::get_grpc_client_from_worker(&worker).await {
Ok(client) => client, Ok(client) => client,
Err(response) => return response, Err(response) => return response,
}; };
...@@ -299,44 +249,12 @@ impl GrpcRouter { ...@@ -299,44 +249,12 @@ impl GrpcRouter {
// Step 6: Handle streaming vs non-streaming // Step 6: Handle streaming vs non-streaming
if body.stream { if body.stream {
// TODO: Implement streaming support for generate endpoint self.handle_streaming_generate(client, request, body, request_id, weight_version)
return ( .await
StatusCode::NOT_IMPLEMENTED, } else {
"Streaming generate over gRPC is not supported yet", self.handle_non_streaming_generate(client, request, body, request_id, weight_version)
) .await
.into_response();
} }
self.handle_non_streaming_generate(client, request, body, request_id, weight_version)
.await
}
/// Get gRPC client from worker, returning appropriate error response on failure
async fn get_grpc_client_from_worker(
worker: &Arc<dyn Worker>,
) -> Result<SglangSchedulerClient, Response> {
let client_arc = worker
.get_grpc_client()
.await
.map_err(|e| {
error!("Failed to get gRPC client from worker: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to get gRPC client: {}", e),
)
.into_response()
})?
.ok_or_else(|| {
error!("Selected worker is not a gRPC worker");
(
StatusCode::INTERNAL_SERVER_ERROR,
"Selected worker is not configured for gRPC",
)
.into_response()
})?;
let client = client_arc.lock().await.clone();
Ok(client)
} }
/// Select a worker for the request /// Select a worker for the request
...@@ -375,412 +293,6 @@ impl GrpcRouter { ...@@ -375,412 +293,6 @@ impl GrpcRouter {
Some(available[idx].clone()) Some(available[idx].clone())
} }
/// Process chat messages and apply template
fn process_chat_messages(
&self,
request: &ChatCompletionRequest,
) -> Result<ProcessedMessages, String> {
// Use the tokenizer's chat template - we require HuggingFace tokenizer for gRPC
let formatted_text = if let Some(hf_tokenizer) = self
.tokenizer
.as_any()
.downcast_ref::<HuggingFaceTokenizer>()
{
// Get content format and transform messages accordingly
let content_format = hf_tokenizer.chat_template_content_format();
let mut transformed_messages =
Self::process_content_format(&request.messages, content_format)?;
// Process tool call arguments in assistant messages
Self::process_tool_call_arguments(&mut transformed_messages)?;
// Convert tools to JSON values for template processing
let tools_json: Option<Vec<Value>> = request
.tools
.as_ref()
.map(|tools| {
tools
.iter()
.map(serde_json::to_value)
.collect::<Result<Vec<_>, _>>()
})
.transpose()
.map_err(|e| format!("Failed to serialize tools: {}", e))?;
// Build template kwargs, merging reasoning_effort if present
let mut combined_template_kwargs = std::collections::HashMap::new();
// Add reasoning_effort if present (like Python does)
if let Some(reasoning_effort) = &request.reasoning_effort {
combined_template_kwargs.insert(
"reasoning_effort".to_string(),
Value::String(reasoning_effort.clone()),
);
}
// Add any additional template kwargs from request
if let Some(template_kwargs) = &request.chat_template_kwargs {
for (key, value) in template_kwargs {
combined_template_kwargs.insert(key.clone(), value.clone());
}
}
let final_template_kwargs = if combined_template_kwargs.is_empty() {
None
} else {
Some(&combined_template_kwargs)
};
let params = ChatTemplateParams {
add_generation_prompt: true,
continue_final_message: request.continue_final_message,
tools: tools_json.as_deref(),
template_kwargs: final_template_kwargs,
..Default::default()
};
// Handle assistant prefix for continue_final_message
let assistant_prefix = if request.continue_final_message
&& !transformed_messages.is_empty()
&& transformed_messages
.last()
.and_then(|msg| msg.get("role"))
.and_then(|v| v.as_str())
== Some("assistant")
{
// Pop the last message to handle it separately
let last_msg = transformed_messages.pop().unwrap();
last_msg
.get("content")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
} else {
None
};
// Apply chat template with the (now possibly shorter) list of messages
let rendered = hf_tokenizer
.apply_chat_template(&transformed_messages, params)
.map_err(|e| format!("Failed to apply chat template: {}", e))?;
// Append assistant prefix if we have one
if let Some(prefix) = assistant_prefix {
format!("{}{}", rendered, prefix)
} else {
rendered
}
} else {
return Err(
"gRPC router requires HuggingFace tokenizer with chat template support".to_string(),
);
};
// Placeholder for multimodal inputs
let multimodal_inputs = None;
Ok(ProcessedMessages {
text: formatted_text,
multimodal_inputs,
stop_sequences: request.stop.clone(),
})
}
/// Process messages based on content format for ANY message type
fn process_content_format(
messages: &[ChatMessage],
content_format: ChatTemplateContentFormat,
) -> Result<Vec<Value>, String> {
messages
.iter()
.map(|message| {
let mut message_json = serde_json::to_value(message)
.map_err(|e| format!("Failed to serialize message: {}", e))?;
if let Some(obj) = message_json.as_object_mut() {
if let Some(content_value) = obj.get_mut("content") {
Self::transform_content_field(content_value, content_format);
}
}
Ok(message_json)
})
.collect()
}
/// Transform a single content field based on content format
fn transform_content_field(
content_value: &mut Value,
content_format: ChatTemplateContentFormat,
) {
let Some(content_array) = content_value.as_array() else {
return; // Not multimodal, keep as-is
};
match content_format {
ChatTemplateContentFormat::String => {
// Extract and join text parts only
let text_parts: Vec<String> = content_array
.iter()
.filter_map(|part| {
part.as_object()?
.get("type")?
.as_str()
.filter(|&t| t == "text")
.and_then(|_| part.as_object()?.get("text")?.as_str())
.map(String::from)
})
.collect();
if !text_parts.is_empty() {
*content_value = Value::String(text_parts.join(" "));
}
}
ChatTemplateContentFormat::OpenAI => {
// Replace media URLs with simple type placeholders
let processed_parts: Vec<Value> = content_array
.iter()
.map(|part| {
part.as_object()
.and_then(|obj| obj.get("type")?.as_str())
.and_then(|type_str| match type_str {
"image_url" => Some(json!({"type": "image"})),
"video_url" => Some(json!({"type": "video"})),
"audio_url" => Some(json!({"type": "audio"})),
_ => None,
})
.unwrap_or_else(|| part.clone())
})
.collect();
*content_value = Value::Array(processed_parts);
}
}
}
/// Process tool call arguments in messages
/// Per Transformers docs, tool call arguments in assistant messages should be dicts
fn process_tool_call_arguments(messages: &mut [Value]) -> Result<(), String> {
for msg in messages {
// Early return if not assistant message
let role = msg.get("role").and_then(|v| v.as_str());
if role != Some("assistant") {
continue;
}
// Early return if no tool_calls
let Some(tool_calls) = msg.get_mut("tool_calls").and_then(|tc| tc.as_array_mut())
else {
continue;
};
// Process each tool call's arguments
for call in tool_calls {
let Some(function) = call.get_mut("function") else {
continue;
};
let Some(args) = function.get_mut("arguments") else {
continue;
};
let Some(args_str) = args.as_str() else {
continue;
};
// Parse JSON string to object (like Python json.loads)
match serde_json::from_str::<Value>(args_str) {
Ok(parsed) => *args = parsed,
Err(e) => {
return Err(format!(
"Failed to parse tool call arguments as JSON: '{}'. Error: {}",
args_str, e
))
}
}
}
}
Ok(())
}
/// Generate tool constraints for structured generation
/// Note: tools should already be filtered if needed (by allowed_tools or specific function)
fn generate_tool_constraints(
&self,
tools: &[Tool],
tool_choice: &Option<ToolChoice>,
_model: &str,
) -> Option<(String, String)> {
let choice = tool_choice.as_ref()?;
match choice {
// Specific function: Return parameters schema directly
// tools should already be filtered to contain only the specific function
ToolChoice::Function { .. } => {
if tools.is_empty() {
return None;
}
let tool = &tools[0];
// Return the tool's parameters schema directly (not wrapped in array)
let params_schema = serde_json::to_string(&tool.function.parameters).ok()?;
Some(("json_schema".to_string(), params_schema))
}
// Required: Array of tool calls with minItems: 1
ToolChoice::Value(ToolChoiceValue::Required) => {
let schema = self.build_required_array_schema(tools)?;
Some(("json_schema".to_string(), schema))
}
// AllowedTools with required mode: tools are already filtered
ToolChoice::AllowedTools { mode, .. } => {
if mode == "required" {
if tools.is_empty() {
return None;
}
let schema = self.build_required_array_schema(tools)?;
Some(("json_schema".to_string(), schema))
} else {
// "auto" mode - no constraint needed
None
}
}
// "auto" or "none" - no constraint
_ => None,
}
}
/// Build JSON schema for required tool calls (array with minItems: 1)
/// Includes $defs consolidation from all tools (matching Python's behavior)
fn build_required_array_schema(&self, tools: &[Tool]) -> Option<String> {
// Build anyOf schemas for each tool
let mut any_of_schemas = Vec::new();
for tool in tools {
let tool_schema = json!({
"properties": {
"name": {
"type": "string",
"enum": [tool.function.name]
},
"parameters": tool.function.parameters
},
"required": ["name", "parameters"]
});
any_of_schemas.push(tool_schema);
}
// Consolidate $defs from all tools (matching Python's _get_tool_schema_defs)
let mut all_defs: HashMap<String, Value> = HashMap::new();
for tool in tools {
if let Value::Object(params) = &tool.function.parameters {
if let Some(Value::Object(defs)) = params.get("$defs") {
for (def_name, def_schema) in defs {
if let Some(existing) = all_defs.get(def_name) {
// Check for conflicts
if existing != def_schema {
error!(
"Tool definition '{}' has multiple schemas, which is not supported",
def_name
);
return None;
}
} else {
all_defs.insert(def_name.clone(), def_schema.clone());
}
}
}
}
}
// Build the full array schema
let mut array_schema = json!({
"type": "array",
"minItems": 1,
"items": {
"type": "object",
"anyOf": any_of_schemas
}
});
// Add $defs if any were found (matching Python's behavior)
if !all_defs.is_empty() {
if let Value::Object(ref mut schema_obj) = array_schema {
let defs_value =
Value::Object(all_defs.into_iter().collect::<Map<String, Value>>());
schema_obj.insert("$defs".to_string(), defs_value);
}
}
serde_json::to_string(&array_schema).ok()
}
/// Parse tool calls from JSON schema constrained response
fn parse_json_schema_response(
&self,
processed_text: &str,
tool_choice: &Option<ToolChoice>,
) -> (Option<Vec<ToolCall>>, String) {
match tool_choice {
Some(ToolChoice::Function { function, .. }) => {
// Specific function: Parse parameters directly
match serde_json::from_str::<Value>(processed_text) {
Ok(params) => {
let tool_call = ToolCall {
id: format!("call_{}", uuid::Uuid::new_v4()),
tool_type: "function".to_string(),
function: FunctionCallResponse {
name: function.name.clone(),
arguments: Some(
serde_json::to_string(&params)
.unwrap_or_else(|_| "{}".to_string()),
),
},
};
(Some(vec![tool_call]), String::new())
}
Err(e) => {
error!("Failed to parse specific function parameters: {}", e);
(None, processed_text.to_string())
}
}
}
Some(ToolChoice::Value(ToolChoiceValue::Required))
| Some(ToolChoice::AllowedTools { .. }) => {
// Required mode: Parse array of tool calls
match serde_json::from_str::<Vec<Value>>(processed_text) {
Ok(parsed_array) => {
let spec_tool_calls: Vec<ToolCall> = parsed_array
.into_iter()
.enumerate()
.filter_map(|(i, item)| {
let obj = item.as_object()?;
let name = obj.get("name")?.as_str()?.to_string();
let parameters = obj.get("parameters")?;
Some(ToolCall {
id: format!("call_{}_{}", i, uuid::Uuid::new_v4()),
tool_type: "function".to_string(),
function: FunctionCallResponse {
name,
arguments: Some(
serde_json::to_string(parameters)
.unwrap_or_else(|_| "{}".to_string()),
),
},
})
})
.collect();
(Some(spec_tool_calls), String::new())
}
Err(e) => {
error!("Failed to parse required tool call array: {}", e);
(None, processed_text.to_string())
}
}
}
_ => (None, processed_text.to_string()),
}
}
/// Parse tool calls using model-specific parser /// Parse tool calls using model-specific parser
async fn parse_tool_calls( async fn parse_tool_calls(
&self, &self,
...@@ -895,48 +407,6 @@ impl GrpcRouter { ...@@ -895,48 +407,6 @@ impl GrpcRouter {
(StatusCode::INTERNAL_SERVER_ERROR, message).into_response() (StatusCode::INTERNAL_SERVER_ERROR, message).into_response()
} }
/// Create a StopSequenceDecoder from stop parameters
fn create_stop_decoder(
&self,
stop: Option<&StringOrArray>,
stop_token_ids: Option<&Vec<u32>>,
skip_special_tokens: bool,
no_stop_trim: bool,
) -> StopSequenceDecoder {
// Extract stop sequences
let stop_sequences: Vec<String> = match stop {
Some(StringOrArray::String(s)) => vec![s.clone()],
Some(StringOrArray::Array(arr)) => arr.clone(),
None => vec![],
};
// Build stop sequence decoder
let mut builder = StopSequenceDecoderBuilder::new(self.tokenizer.clone())
.skip_special_tokens(skip_special_tokens);
// Add stop sequences (visible if no_stop_trim is true, hidden otherwise)
for seq in stop_sequences {
builder = if no_stop_trim {
builder.visible_stop_sequence(seq)
} else {
builder.stop_sequence(seq)
};
}
// Add stop token IDs (visible if no_stop_trim is true, hidden otherwise)
if let Some(token_ids) = stop_token_ids {
for &token_id in token_ids {
builder = if no_stop_trim {
builder.visible_stop_token(token_id)
} else {
builder.stop_token(token_id)
};
}
}
builder.build()
}
/// Count the number of tool calls in the request message history /// Count the number of tool calls in the request message history
/// This is used for KimiK2 format which needs globally unique indices /// This is used for KimiK2 format which needs globally unique indices
fn get_history_tool_calls_count(request: &ChatCompletionRequest) -> usize { fn get_history_tool_calls_count(request: &ChatCompletionRequest) -> usize {
...@@ -1354,7 +824,8 @@ impl GrpcRouter { ...@@ -1354,7 +824,8 @@ impl GrpcRouter {
// Create stop decoder // Create stop decoder
let (stop, stop_token_ids, skip_special_tokens, no_stop_trim) = stop_params; let (stop, stop_token_ids, skip_special_tokens, no_stop_trim) = stop_params;
let mut stop_decoder = router.create_stop_decoder( let mut stop_decoder = utils::create_stop_decoder(
&router.tokenizer,
stop.as_ref(), stop.as_ref(),
stop_token_ids.as_ref(), stop_token_ids.as_ref(),
skip_special_tokens, skip_special_tokens,
...@@ -1678,7 +1149,8 @@ impl GrpcRouter { ...@@ -1678,7 +1149,8 @@ impl GrpcRouter {
request: proto::GenerateRequest, request: proto::GenerateRequest,
original_request: &ChatCompletionRequest, original_request: &ChatCompletionRequest,
) -> Response { ) -> Response {
let mut stop_decoder = self.create_stop_decoder( let mut stop_decoder = utils::create_stop_decoder(
&self.tokenizer,
original_request.stop.as_ref(), original_request.stop.as_ref(),
original_request.stop_token_ids.as_ref(), original_request.stop_token_ids.as_ref(),
original_request.skip_special_tokens, original_request.skip_special_tokens,
...@@ -1686,42 +1158,17 @@ impl GrpcRouter { ...@@ -1686,42 +1158,17 @@ impl GrpcRouter {
); );
// Start generation // Start generation
let mut stream = match client.generate(request).await { let stream = match client.generate(request).await {
Ok(s) => s, Ok(s) => s,
Err(e) => { Err(e) => {
return Self::internal_error_message(format!("Failed to start generation: {}", e)) return Self::internal_error_message(format!("Failed to start generation: {}", e))
} }
}; };
// Collect all responses (for n>1 support) let all_responses = match utils::collect_stream_responses(stream, "Regular").await {
let mut all_responses = Vec::new(); Ok(responses) => responses,
while let Some(response) = stream.next().await { Err(err_response) => return err_response,
match response { };
Ok(gen_response) => match gen_response.response {
Some(Complete(complete)) => {
all_responses.push(complete);
}
Some(Error(err)) => {
return Self::internal_error_message(format!(
"Generation failed: {}",
err.message
));
}
Some(Chunk(_)) => {
return Self::internal_error_static(
"Unexpected chunk response for non-streaming request",
)
}
None => return Self::internal_error_static("Empty response from server"),
},
Err(e) => {
return Self::internal_error_message(format!(
"Failed to get GenerateResponse: {}",
e
))
}
}
}
if all_responses.is_empty() { if all_responses.is_empty() {
return Self::internal_error_static("No responses from server"); return Self::internal_error_static("No responses from server");
...@@ -1793,115 +1240,262 @@ impl GrpcRouter { ...@@ -1793,115 +1240,262 @@ impl GrpcRouter {
) -> Response { ) -> Response {
let start_time = Instant::now(); let start_time = Instant::now();
let mut stream = match client.generate(request).await { let stream = match client.generate(request).await {
Ok(stream) => stream, Ok(stream) => stream,
Err(e) => { Err(e) => {
return Self::internal_error_message(format!("Failed to start generation: {}", e)) return Self::internal_error_message(format!("Failed to start generation: {}", e))
} }
}; };
let mut final_completion: Option<proto::GenerateComplete> = None; // Collect all responses using utils helper
let responses = match utils::collect_stream_responses(stream, "Generate").await {
Ok(responses) => responses,
Err(error_response) => return error_response,
};
while let Some(result) = stream.next().await { if responses.is_empty() {
match result { return Self::internal_error_static("No completion received from scheduler");
Ok(gen_response) => match gen_response.response {
Some(Complete(complete)) => {
final_completion = Some(complete);
break;
}
Some(Error(err)) => {
return Self::internal_error_message(format!(
"Generation failed: {}",
err.message
));
}
Some(Chunk(_)) | None => continue,
},
Err(e) => {
return Self::internal_error_message(format!(
"Failed to receive generate response: {}",
e
))
}
}
} }
let mut complete = match final_completion {
Some(c) => c,
None => {
return Self::internal_error_static("No completion received from scheduler");
}
};
// Create stop decoder from sampling params // Create stop decoder from sampling params
let params = original_request.sampling_params.as_ref(); let params = original_request.sampling_params.as_ref();
let mut stop_decoder = self.create_stop_decoder( let mut stop_decoder = utils::create_stop_decoder(
&self.tokenizer,
params.and_then(|p| p.stop.as_ref()), params.and_then(|p| p.stop.as_ref()),
params.and_then(|p| p.stop_token_ids.as_ref()), params.and_then(|p| p.stop_token_ids.as_ref()),
params.and_then(|p| p.skip_special_tokens).unwrap_or(true), params.and_then(|p| p.skip_special_tokens).unwrap_or(true),
params.and_then(|p| p.no_stop_trim).unwrap_or(false), params.and_then(|p| p.no_stop_trim).unwrap_or(false),
); );
// Process tokens through stop decoder // Process each completion
let outputs = match stop_decoder.process_tokens(&complete.output_ids) { let mut result_array = Vec::new();
Ok(outputs) => outputs, for mut complete in responses {
Err(e) => { stop_decoder.reset();
return Self::internal_error_message(format!("Failed to process tokens: {}", e))
}
};
// Accumulate text with early breaks // Process tokens through stop decoder
let mut decoded_text = String::new(); let outputs = match stop_decoder.process_tokens(&complete.output_ids) {
for output in outputs { Ok(outputs) => outputs,
match output { Err(e) => {
SequenceDecoderOutput::Text(t) => decoded_text.push_str(&t), return Self::internal_error_message(format!("Failed to process tokens: {}", e))
SequenceDecoderOutput::StoppedWithText(t) => { }
decoded_text.push_str(&t); };
break;
// Accumulate text with early breaks
let mut decoded_text = String::new();
for output in outputs {
match output {
SequenceDecoderOutput::Text(t) => decoded_text.push_str(&t),
SequenceDecoderOutput::StoppedWithText(t) => {
decoded_text.push_str(&t);
break;
}
SequenceDecoderOutput::Stopped => break,
SequenceDecoderOutput::Held => {}
} }
SequenceDecoderOutput::Stopped => break,
SequenceDecoderOutput::Held => {}
} }
}
// Flush remaining text // Flush remaining text
if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() { if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() {
decoded_text.push_str(&t); decoded_text.push_str(&t);
} }
let output_ids = std::mem::take(&mut complete.output_ids); let output_ids = std::mem::take(&mut complete.output_ids);
let finish_reason = std::mem::take(&mut complete.finish_reason); let finish_reason = std::mem::take(&mut complete.finish_reason);
// Build base meta_info using json! macro // Build base meta_info using json! macro
let mut meta_info = json!({ let mut meta_info = json!({
"finish_reason": finish_reason, "id": request_id.clone(),
"prompt_tokens": complete.prompt_tokens, "finish_reason": finish_reason,
"completion_tokens": complete.completion_tokens, "prompt_tokens": complete.prompt_tokens,
"cached_tokens": complete.cached_tokens, "weight_version": weight_version.clone(),
"id": request_id, "completion_tokens": complete.completion_tokens,
"weight_version": weight_version, "cached_tokens": complete.cached_tokens,
"e2e_latency": start_time.elapsed().as_secs_f64(), "e2e_latency": start_time.elapsed().as_secs_f64(),
}); });
let meta_obj = meta_info.as_object_mut().unwrap(); let meta_obj = meta_info.as_object_mut().unwrap();
// Add matched_stop if present // Add matched_stop if present
if let Some(matched) = complete.matched_stop.take() { if let Some(matched) = complete.matched_stop.take() {
use proto::generate_complete::MatchedStop; use proto::generate_complete::MatchedStop;
let matched_value = match matched { let matched_value = match matched {
MatchedStop::MatchedTokenId(id) => json!(id), MatchedStop::MatchedTokenId(id) => json!(id),
MatchedStop::MatchedStopStr(s) => json!(s), MatchedStop::MatchedStopStr(s) => json!(s),
}; };
meta_obj.insert("matched_stop".to_string(), matched_value); meta_obj.insert("matched_stop".to_string(), matched_value);
}
result_array.push(json!({
"text": decoded_text,
"output_ids": output_ids,
"meta_info": meta_info,
}));
} }
let response_body = json!({ Json(result_array).into_response()
"text": decoded_text, }
"output_ids": output_ids,
"meta_info": meta_info, /// Submit request and handle streaming response for the `/generate` endpoint
async fn handle_streaming_generate(
&self,
mut client: SglangSchedulerClient,
request: proto::GenerateRequest,
original_request: &GenerateRequest,
request_id: String,
weight_version: String,
) -> Response {
let tokenizer = self.tokenizer.clone();
let return_logprob = original_request.return_logprob;
// Create channel for SSE streaming
let (tx, rx) =
tokio::sync::mpsc::unbounded_channel::<Result<bytes::Bytes, std::io::Error>>();
// Start the stream
let stream = match client.generate(request).await {
Ok(stream) => stream,
Err(e) => {
return Self::internal_error_message(format!("Failed to start generation: {}", e))
}
};
// Spawn async task to process stream
tokio::spawn(async move {
let result = Self::process_generate_streaming(
tokenizer,
stream,
request_id,
weight_version,
return_logprob,
&tx,
)
.await;
if let Err(e) = result {
let error_chunk = format!("data: {{\"error\": \"{}\"}}\n\n", e);
let _ = tx.send(Ok(bytes::Bytes::from(error_chunk)));
}
// Send [DONE] marker
let _ = tx.send(Ok(bytes::Bytes::from("data: [DONE]\n\n")));
}); });
Json(response_body).into_response() // Create SSE response stream
let body_stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);
Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "text/event-stream")
.header("Cache-Control", "no-cache")
.header("Connection", "keep-alive")
.body(axum::body::Body::from_stream(body_stream))
.unwrap()
}
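// The response above is a plain SSE stream: every frame sent on `tx` is forwarded
// to the client as-is, and the spawned task always terminates the stream with a
// "data: [DONE]" frame (errors are surfaced as a JSON error frame first).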
/// Process streaming chunks for generate endpoint
async fn process_generate_streaming(
tokenizer: Arc<dyn Tokenizer>,
mut stream: impl tokio_stream::Stream<Item = Result<proto::GenerateResponse, tonic::Status>>
+ Unpin,
request_id: String,
weight_version: String,
_include_logprobs: bool,
tx: &tokio::sync::mpsc::UnboundedSender<Result<bytes::Bytes, std::io::Error>>,
) -> Result<(), String> {
use proto::generate_response::Response::{Chunk, Complete, Error};
use std::time::Instant;
use tokio_stream::StreamExt;
let start_time = Instant::now();
// Track state per index for n>1 case
use std::collections::HashMap;
let mut accumulated_texts: HashMap<u32, String> = HashMap::new();
let mut completion_tokens_map: HashMap<u32, u32> = HashMap::new();
while let Some(response) = stream.next().await {
let gen_response = response.map_err(|e| format!("Stream error: {}", e))?;
match gen_response.response {
Some(Chunk(chunk)) => {
let index = chunk.index;
// Update completion tokens for this index
let completion_tokens = completion_tokens_map.entry(index).or_insert(0);
*completion_tokens += chunk.token_ids.len() as u32;
// Decode tokens to text (skip_special_tokens=true to handle newlines correctly)
let chunk_text = tokenizer.decode(&chunk.token_ids, true).unwrap_or_default();
// Accumulate text for this index
let accumulated_text = accumulated_texts.entry(index).or_default();
accumulated_text.push_str(&chunk_text);
// Generate unique ID per index
let index_id = format!("{}-{}", request_id, index);
// Build streaming response chunk (SGLang format)
let chunk_response = serde_json::json!({
"text": accumulated_text.clone(),
"output_ids": chunk.token_ids,
"meta_info": {
"id": index_id,
"finish_reason": null,
"prompt_tokens": chunk.prompt_tokens,
"weight_version": weight_version,
"completion_tokens": *completion_tokens,
"cached_tokens": chunk.cached_tokens
},
"index": index
});
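// Note: "text" carries the full text accumulated for this index so far (SGLang
// streaming semantics), while "output_ids" holds only the token ids from this
// chunk.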
let sse_chunk = format!(
"data: {}\n\n",
serde_json::to_string(&chunk_response).unwrap()
);
tx.send(Ok(bytes::Bytes::from(sse_chunk)))
.map_err(|_| "Failed to send chunk".to_string())?;
}
Some(Complete(complete)) => {
let index = complete.index;
let accumulated_text =
accumulated_texts.get(&index).cloned().unwrap_or_default();
let completion_tokens = *completion_tokens_map.get(&index).unwrap_or(&0);
let index_id = format!("{}-{}", request_id, index);
let e2e_latency = start_time.elapsed().as_secs_f64();
// Send the final chunk with finish_reason; only the trailing output id is echoed, since the tokens were already streamed in earlier Chunks
let finish_response = serde_json::json!({
"text": accumulated_text,
"output_ids": complete.output_ids[complete.output_ids.len().saturating_sub(1)..].to_vec(),
"meta_info": {
"id": index_id,
"finish_reason": complete.finish_reason,
"prompt_tokens": complete.prompt_tokens,
"weight_version": weight_version,
"completion_tokens": completion_tokens,
"cached_tokens": complete.cached_tokens,
"e2e_latency": e2e_latency
},
"index": index
});
let sse_chunk = format!(
"data: {}\n\n",
serde_json::to_string(&finish_response).unwrap()
);
tx.send(Ok(bytes::Bytes::from(sse_chunk)))
.map_err(|_| "Failed to send finish chunk".to_string())?;
// Continue to process all completions if n>1
}
Some(Error(error)) => {
return Err(error.message);
}
None => continue,
}
}
Ok(())
} }
/// Convert proto LogProbs to OpenAI ChatLogProbs format /// Convert proto LogProbs to OpenAI ChatLogProbs format
...@@ -2036,28 +1630,28 @@ impl GrpcRouter { ...@@ -2036,28 +1630,28 @@ impl GrpcRouter {
} }
// Step 2: Handle tool call parsing // Step 2: Handle tool call parsing
let mut tool_calls: Option<Vec<crate::protocols::spec::ToolCall>> = None; let mut tool_calls: Option<Vec<ToolCall>> = None;
// Check if tool calls should be processed // Check if tool calls should be processed
let tool_choice_enabled = !matches!( let tool_choice_enabled = !matches!(
&original_request.tool_choice, &original_request.tool_choice,
Some(ToolChoice::Value( Some(ToolChoice::Value(ToolChoiceValue::None))
crate::protocols::spec::ToolChoiceValue::None
))
); );
if tool_choice_enabled && original_request.tools.is_some() { if tool_choice_enabled && original_request.tools.is_some() {
// Check if JSON schema constraint was used (specific function or required mode) // Check if JSON schema constraint was used (specific function or required mode)
let used_json_schema = match &original_request.tool_choice { let used_json_schema = match &original_request.tool_choice {
Some(ToolChoice::Function { .. }) => true, Some(ToolChoice::Function { .. }) => true,
Some(ToolChoice::Value(crate::protocols::spec::ToolChoiceValue::Required)) => true, Some(ToolChoice::Value(ToolChoiceValue::Required)) => true,
Some(ToolChoice::AllowedTools { mode, .. }) => mode == "required", Some(ToolChoice::AllowedTools { mode, .. }) => mode == "required",
_ => false, _ => false,
}; };
if used_json_schema { if used_json_schema {
(tool_calls, processed_text) = (tool_calls, processed_text) = utils::parse_json_schema_response(
self.parse_json_schema_response(&processed_text, &original_request.tool_choice); &processed_text,
&original_request.tool_choice,
);
} else { } else {
(tool_calls, processed_text) = self (tool_calls, processed_text) = self
.parse_tool_calls( .parse_tool_calls(
...@@ -2081,11 +1675,11 @@ impl GrpcRouter { ...@@ -2081,11 +1675,11 @@ impl GrpcRouter {
// Extract matched_stop information from proto // Extract matched_stop information from proto
let matched_stop = match &complete.matched_stop { let matched_stop = match &complete.matched_stop {
Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => Some( Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => {
serde_json::Value::Number(serde_json::Number::from(*token_id)), Some(Value::Number(serde_json::Number::from(*token_id)))
), }
Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => { Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => {
Some(serde_json::Value::String(stop_str.clone())) Some(Value::String(stop_str.clone()))
} }
None => None, None => None,
}; };
...@@ -2239,240 +1833,3 @@ impl RouterTrait for GrpcRouter { ...@@ -2239,240 +1833,3 @@ impl RouterTrait for GrpcRouter {
"grpc" "grpc"
} }
} }
#[cfg(test)]
mod tests {
use super::*;
use crate::protocols::spec::{ChatMessage, ContentPart, ImageUrl, UserMessageContent};
use crate::tokenizer::chat_template::ChatTemplateContentFormat;
use serde_json::json;
#[test]
fn test_transform_messages_string_format() {
let messages = vec![ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Parts(vec![
ContentPart::Text {
text: "Hello".to_string(),
},
ContentPart::ImageUrl {
image_url: ImageUrl {
url: "https://example.com/image.jpg".to_string(),
detail: None,
},
},
ContentPart::Text {
text: "World".to_string(),
},
]),
name: None,
}];
let result =
GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String)
.unwrap();
assert_eq!(result.len(), 1);
let transformed_message = &result[0];
// Should flatten multimodal content to text only
assert_eq!(
transformed_message["content"].as_str().unwrap(),
"Hello World"
);
assert_eq!(transformed_message["role"].as_str().unwrap(), "user");
}
#[test]
fn test_transform_messages_openai_format() {
let messages = vec![ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Parts(vec![
ContentPart::Text {
text: "Describe this image:".to_string(),
},
ContentPart::ImageUrl {
image_url: ImageUrl {
url: "https://example.com/image.jpg".to_string(),
detail: Some("high".to_string()),
},
},
]),
name: None,
}];
let result =
GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::OpenAI)
.unwrap();
assert_eq!(result.len(), 1);
let transformed_message = &result[0];
// Should replace media URLs with simple type placeholders
let content_array = transformed_message["content"].as_array().unwrap();
assert_eq!(content_array.len(), 2);
// Text part should remain unchanged
assert_eq!(content_array[0]["type"], "text");
assert_eq!(content_array[0]["text"], "Describe this image:");
// Image part should be replaced with simple type placeholder
assert_eq!(content_array[1], json!({"type": "image"}));
}
#[test]
fn test_transform_messages_simple_string_content() {
let messages = vec![ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Text("Simple text message".to_string()),
name: None,
}];
let result =
GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String)
.unwrap();
assert_eq!(result.len(), 1);
let transformed_message = &result[0];
// Simple string content should remain unchanged
assert_eq!(
transformed_message["content"].as_str().unwrap(),
"Simple text message"
);
}
#[test]
fn test_transform_messages_assistant_message() {
let messages = vec![ChatMessage::Assistant {
role: "assistant".to_string(),
content: Some("Assistant response".to_string()),
name: None,
tool_calls: None,
reasoning_content: None,
}];
let result =
GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String)
.unwrap();
assert_eq!(result.len(), 1);
let transformed_message = &result[0];
assert_eq!(transformed_message["role"].as_str().unwrap(), "assistant");
assert_eq!(
transformed_message["content"].as_str().unwrap(),
"Assistant response"
);
}
#[test]
fn test_transform_messages_multiple_messages() {
let messages = vec![
ChatMessage::System {
role: "system".to_string(),
content: "System prompt".to_string(),
name: None,
},
ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Parts(vec![
ContentPart::Text {
text: "User message".to_string(),
},
ContentPart::ImageUrl {
image_url: ImageUrl {
url: "https://example.com/image.jpg".to_string(),
detail: None,
},
},
]),
name: None,
},
];
let result =
GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String)
.unwrap();
assert_eq!(result.len(), 2);
// System message should remain unchanged
assert_eq!(result[0]["role"].as_str().unwrap(), "system");
assert_eq!(result[0]["content"].as_str().unwrap(), "System prompt");
// User message should be flattened to text only
assert_eq!(result[1]["role"].as_str().unwrap(), "user");
assert_eq!(result[1]["content"].as_str().unwrap(), "User message");
}
#[test]
fn test_transform_messages_empty_text_parts() {
let messages = vec![ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Parts(vec![ContentPart::ImageUrl {
image_url: ImageUrl {
url: "https://example.com/image.jpg".to_string(),
detail: None,
},
}]),
name: None,
}];
let result =
GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String)
.unwrap();
assert_eq!(result.len(), 1);
let transformed_message = &result[0];
// Should keep original multimodal content when no text parts exist
assert!(transformed_message["content"].is_array());
}
#[test]
fn test_transform_messages_mixed_content_types() {
let messages = vec![
ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Text("Plain text".to_string()),
name: None,
},
ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Parts(vec![
ContentPart::Text {
text: "With image".to_string(),
},
ContentPart::ImageUrl {
image_url: ImageUrl {
url: "https://example.com/image.jpg".to_string(),
detail: Some("low".to_string()),
},
},
]),
name: None,
},
];
let result_string =
GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String)
.unwrap();
assert_eq!(result_string.len(), 2);
assert_eq!(result_string[0]["content"].as_str().unwrap(), "Plain text");
assert_eq!(result_string[1]["content"].as_str().unwrap(), "With image");
let result_openai =
GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::OpenAI)
.unwrap();
assert_eq!(result_openai.len(), 2);
assert_eq!(result_openai[0]["content"].as_str().unwrap(), "Plain text");
let content_array = result_openai[1]["content"].as_array().unwrap();
assert_eq!(content_array.len(), 2);
assert_eq!(content_array[0]["type"], "text");
assert_eq!(content_array[1], json!({"type": "image"}));
}
}
//! Shared utilities for gRPC routers
use super::ProcessedMessages;
use crate::core::Worker;
use crate::grpc_client::{proto, SglangSchedulerClient};
use crate::protocols::spec::{
ChatCompletionRequest, ChatMessage, FunctionCallResponse, StringOrArray, Tool, ToolCall,
ToolChoice, ToolChoiceValue,
};
use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
use crate::tokenizer::traits::Tokenizer;
use crate::tokenizer::HuggingFaceTokenizer;
pub use crate::tokenizer::StopSequenceDecoder;
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
};
use futures::StreamExt;
use serde_json::{json, Map, Value};
use std::collections::HashMap;
use std::sync::Arc;
use tonic::codec::Streaming;
use tracing::{debug, error};
use uuid::Uuid;
/// Get gRPC client from worker, returning appropriate error response on failure
pub async fn get_grpc_client_from_worker(
worker: &Arc<dyn Worker>,
) -> Result<SglangSchedulerClient, Response> {
let client_arc = worker
.get_grpc_client()
.await
.map_err(|e| {
error!("Failed to get gRPC client from worker: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to get gRPC client: {}", e),
)
.into_response()
})?
.ok_or_else(|| {
error!("Selected worker is not a gRPC worker");
(
StatusCode::INTERNAL_SERVER_ERROR,
"Selected worker is not configured for gRPC",
)
.into_response()
})?;
let client = client_arc.lock().await.clone();
Ok(client)
}
/// Process tool call arguments in messages
/// Per Transformers docs, tool call arguments in assistant messages should be dicts
pub fn process_tool_call_arguments(messages: &mut [Value]) -> Result<(), String> {
for msg in messages {
// Early return if not assistant message
let role = msg.get("role").and_then(|v| v.as_str());
if role != Some("assistant") {
continue;
}
// Early return if no tool_calls
let Some(tool_calls) = msg.get_mut("tool_calls").and_then(|tc| tc.as_array_mut()) else {
continue;
};
// Process each tool call's arguments
for call in tool_calls {
let Some(function) = call.get_mut("function") else {
continue;
};
let Some(args) = function.get_mut("arguments") else {
continue;
};
let Some(args_str) = args.as_str() else {
continue;
};
// Parse JSON string to object (like Python json.loads)
match serde_json::from_str::<Value>(args_str) {
Ok(parsed) => *args = parsed,
Err(e) => {
return Err(format!(
"Failed to parse tool call arguments as JSON: '{}'. Error: {}",
args_str, e
))
}
}
}
}
Ok(())
}
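// Hedged usage sketch for `process_tool_call_arguments`: the message shape and
// the "get_weather" tool below are illustrative, not taken from any real API.
// It shows the in-place conversion of string-encoded `arguments` into objects.
#[cfg(test)]
mod process_tool_call_arguments_sketch {
    use super::*;
    use serde_json::json;

    #[test]
    fn parses_string_arguments_into_objects() {
        let mut messages = vec![json!({
            "role": "assistant",
            "tool_calls": [{
                "function": {
                    "name": "get_weather",
                    "arguments": "{\"city\": \"Paris\"}"
                }
            }]
        })];
        process_tool_call_arguments(&mut messages).unwrap();
        // The stringified arguments are now a proper JSON object, as chat
        // templates expect per the Transformers docs referenced above.
        assert_eq!(
            messages[0]["tool_calls"][0]["function"]["arguments"],
            json!({"city": "Paris"})
        );
    }
}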
/// Process messages based on content format for ANY message type
pub fn process_content_format(
messages: &[ChatMessage],
content_format: ChatTemplateContentFormat,
) -> Result<Vec<Value>, String> {
messages
.iter()
.map(|message| {
let mut message_json = serde_json::to_value(message)
.map_err(|e| format!("Failed to serialize message: {}", e))?;
if let Some(obj) = message_json.as_object_mut() {
if let Some(content_value) = obj.get_mut("content") {
transform_content_field(content_value, content_format);
}
}
Ok(message_json)
})
.collect()
}
/// Transform a single content field based on content format
pub fn transform_content_field(
content_value: &mut Value,
content_format: ChatTemplateContentFormat,
) {
let Some(content_array) = content_value.as_array() else {
return; // Not multimodal, keep as-is
};
match content_format {
ChatTemplateContentFormat::String => {
// Extract and join text parts only
let text_parts: Vec<String> = content_array
.iter()
.filter_map(|part| {
part.as_object()?
.get("type")?
.as_str()
.filter(|&t| t == "text")
.and_then(|_| part.as_object()?.get("text")?.as_str())
.map(String::from)
})
.collect();
if !text_parts.is_empty() {
*content_value = Value::String(text_parts.join(" "));
}
}
ChatTemplateContentFormat::OpenAI => {
// Replace media URLs with simple type placeholders
let processed_parts: Vec<Value> = content_array
.iter()
.map(|part| {
part.as_object()
.and_then(|obj| obj.get("type")?.as_str())
.and_then(|type_str| match type_str {
"image_url" => Some(json!({"type": "image"})),
"video_url" => Some(json!({"type": "video"})),
"audio_url" => Some(json!({"type": "audio"})),
_ => None,
})
.unwrap_or_else(|| part.clone())
})
.collect();
*content_value = Value::Array(processed_parts);
}
}
}
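// Hedged sketch of `transform_content_field` on its own: the URL and text are
// made up. Unlike the `process_content_format` tests below, this shows the
// in-place mutation of a single `content` value.
#[cfg(test)]
mod transform_content_field_sketch {
    use super::*;
    use serde_json::json;

    #[test]
    fn string_format_flattens_text_parts_in_place() {
        let mut content = json!([
            {"type": "text", "text": "Hello"},
            {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}}
        ]);
        transform_content_field(&mut content, ChatTemplateContentFormat::String);
        // Only the text parts survive (joined with spaces when there are several).
        assert_eq!(content, json!("Hello"));
    }
}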
/// Generate tool constraints for structured generation
/// Note: tools should already be filtered if needed (by allowed_tools or specific function)
pub fn generate_tool_constraints(
tools: &[Tool],
tool_choice: &Option<ToolChoice>,
_model: &str,
) -> Option<(String, String)> {
let choice = tool_choice.as_ref()?;
match choice {
// Specific function: Return parameters schema directly
// tools should already be filtered to contain only the specific function
ToolChoice::Function { .. } => {
if tools.is_empty() {
return None;
}
let tool = &tools[0];
// Return the tool's parameters schema directly (not wrapped in array)
let params_schema = serde_json::to_string(&tool.function.parameters).ok()?;
Some(("json_schema".to_string(), params_schema))
}
// Required: Array of tool calls with minItems: 1
ToolChoice::Value(ToolChoiceValue::Required) => {
let schema = build_required_array_schema(tools)?;
Some(("json_schema".to_string(), schema))
}
// AllowedTools with required mode: tools are already filtered
ToolChoice::AllowedTools { mode, .. } => {
if mode == "required" {
if tools.is_empty() {
return None;
}
let schema = build_required_array_schema(tools)?;
Some(("json_schema".to_string(), schema))
} else {
// "auto" mode - no constraint needed
None
}
}
// "auto" or "none" - no constraint
_ => None,
}
}
/// Build JSON schema for required tool calls (array with minItems: 1)
/// Includes $defs consolidation from all tools (matching Python's behavior)
pub fn build_required_array_schema(tools: &[Tool]) -> Option<String> {
// Build anyOf schemas for each tool
let mut any_of_schemas = Vec::new();
for tool in tools {
let tool_schema = json!({
"properties": {
"name": {
"type": "string",
"enum": [tool.function.name]
},
"parameters": tool.function.parameters
},
"required": ["name", "parameters"]
});
any_of_schemas.push(tool_schema);
}
// Consolidate $defs from all tools (matching Python's _get_tool_schema_defs)
let mut all_defs: HashMap<String, Value> = HashMap::new();
for tool in tools {
if let Value::Object(params) = &tool.function.parameters {
if let Some(Value::Object(defs)) = params.get("$defs") {
for (def_name, def_schema) in defs {
if let Some(existing) = all_defs.get(def_name) {
// Check for conflicts
if existing != def_schema {
error!(
"Tool definition '{}' has multiple schemas, which is not supported",
def_name
);
return None;
}
} else {
all_defs.insert(def_name.clone(), def_schema.clone());
}
}
}
}
}
// Build the full array schema
let mut array_schema = json!({
"type": "array",
"minItems": 1,
"items": {
"type": "object",
"anyOf": any_of_schemas
}
});
// Add $defs if any were found (matching Python's behavior)
if !all_defs.is_empty() {
if let Value::Object(ref mut schema_obj) = array_schema {
let defs_value = Value::Object(all_defs.into_iter().collect::<Map<String, Value>>());
schema_obj.insert("$defs".to_string(), defs_value);
}
}
serde_json::to_string(&array_schema).ok()
}
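// Illustrative-only sketch of the schema shape `build_required_array_schema`
// emits for a single hypothetical "get_weather" tool; it does not call the
// function (constructing a `Tool` here would require fields not shown above).
#[cfg(test)]
mod required_schema_shape_sketch {
    use serde_json::json;

    #[test]
    fn required_mode_schema_shape() {
        let expected_shape = json!({
            "type": "array",
            "minItems": 1,
            "items": {
                "type": "object",
                "anyOf": [{
                    "properties": {
                        "name": {"type": "string", "enum": ["get_weather"]},
                        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}
                    },
                    "required": ["name", "parameters"]
                }]
            }
        });
        // The constraint is handed to the scheduler as a ("json_schema", String) pair.
        assert!(serde_json::to_string(&expected_shape).is_ok());
    }
}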
/// Filter tools based on tool_choice (shared by both routers)
/// Returns a reference to the original body if no filtering needed,
/// otherwise returns a cloned and filtered body
pub fn filter_tools_for_request(
body: &ChatCompletionRequest,
) -> std::borrow::Cow<'_, ChatCompletionRequest> {
match &body.tool_choice {
Some(ToolChoice::AllowedTools { tools: allowed, .. }) if body.tools.is_some() => {
let mut filtered_body = body.clone();
let all_tools = filtered_body.tools.as_ref().unwrap();
let allowed_names: std::collections::HashSet<&str> =
allowed.iter().map(|t| t.name.as_str()).collect();
let filtered_tools: Vec<Tool> = all_tools
.iter()
.filter(|t| allowed_names.contains(t.function.name.as_str()))
.cloned()
.collect();
filtered_body.tools = Some(filtered_tools);
std::borrow::Cow::Owned(filtered_body)
}
Some(ToolChoice::Function { function, .. }) if body.tools.is_some() => {
let mut filtered_body = body.clone();
let all_tools = filtered_body.tools.as_ref().unwrap();
let filtered_tools: Vec<Tool> = all_tools
.iter()
.filter(|t| t.function.name == function.name)
.cloned()
.collect();
filtered_body.tools = Some(filtered_tools);
std::borrow::Cow::Owned(filtered_body)
}
_ => std::borrow::Cow::Borrowed(body), // No filtering needed, use original
}
}
/// Process chat messages and apply template (shared by both routers)
/// Requires HuggingFace tokenizer with chat template support
pub fn process_chat_messages(
request: &ChatCompletionRequest,
tokenizer: &dyn Tokenizer,
) -> Result<ProcessedMessages, String> {
// Use the tokenizer's chat template - we require HuggingFace tokenizer for gRPC
let formatted_text = if let Some(hf_tokenizer) =
tokenizer.as_any().downcast_ref::<HuggingFaceTokenizer>()
{
// Get content format and transform messages accordingly
let content_format = hf_tokenizer.chat_template_content_format();
let mut transformed_messages = process_content_format(&request.messages, content_format)?;
// Process tool call arguments in assistant messages
process_tool_call_arguments(&mut transformed_messages)?;
// Convert tools to JSON values for template processing
let tools_json: Option<Vec<Value>> = request
.tools
.as_ref()
.map(|tools| {
tools
.iter()
.map(serde_json::to_value)
.collect::<Result<Vec<_>, _>>()
})
.transpose()
.map_err(|e| format!("Failed to serialize tools: {}", e))?;
// Build template kwargs, merging reasoning_effort if present
let mut combined_template_kwargs = HashMap::new();
// Add reasoning_effort if present (like Python does)
if let Some(reasoning_effort) = &request.reasoning_effort {
combined_template_kwargs.insert(
"reasoning_effort".to_string(),
Value::String(reasoning_effort.clone()),
);
}
// Add any additional template kwargs from request
if let Some(template_kwargs) = &request.chat_template_kwargs {
for (key, value) in template_kwargs {
combined_template_kwargs.insert(key.clone(), value.clone());
}
}
let final_template_kwargs = if combined_template_kwargs.is_empty() {
None
} else {
Some(&combined_template_kwargs)
};
let params = ChatTemplateParams {
add_generation_prompt: true,
continue_final_message: request.continue_final_message,
tools: tools_json.as_deref(),
template_kwargs: final_template_kwargs,
..Default::default()
};
// Handle assistant prefix for continue_final_message
let assistant_prefix = if request.continue_final_message
&& !transformed_messages.is_empty()
&& transformed_messages
.last()
.and_then(|msg| msg.get("role"))
.and_then(|v| v.as_str())
== Some("assistant")
{
// Pop the last message to handle it separately
let last_msg = transformed_messages.pop().unwrap();
last_msg
.get("content")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
} else {
None
};
// Apply chat template with the (now possibly shorter) list of messages
let rendered = hf_tokenizer
.apply_chat_template(&transformed_messages, params)
.map_err(|e| format!("Failed to apply chat template: {}", e))?;
// Append assistant prefix if we have one
if let Some(prefix) = assistant_prefix {
format!("{}{}", rendered, prefix)
} else {
rendered
}
} else {
return Err(
"gRPC router requires HuggingFace tokenizer with chat template support".to_string(),
);
};
// Placeholder for multimodal inputs
let multimodal_inputs = None;
Ok(ProcessedMessages {
text: formatted_text,
multimodal_inputs,
stop_sequences: request.stop.clone(),
})
}
/// Error response helpers (shared between regular and PD routers)
pub fn internal_error_static(msg: &'static str) -> Response {
error!("{}", msg);
(StatusCode::INTERNAL_SERVER_ERROR, msg).into_response()
}
pub fn internal_error_message(message: String) -> Response {
error!("{}", message);
(StatusCode::INTERNAL_SERVER_ERROR, message).into_response()
}
/// Create a StopSequenceDecoder from stop parameters
pub fn create_stop_decoder(
tokenizer: &Arc<dyn Tokenizer>,
stop: Option<&StringOrArray>,
stop_token_ids: Option<&Vec<u32>>,
skip_special_tokens: bool,
no_stop_trim: bool,
) -> StopSequenceDecoder {
use crate::tokenizer::stop::StopSequenceDecoderBuilder;
// Extract stop sequences
let stop_sequences: Vec<String> = match stop {
Some(StringOrArray::String(s)) => vec![s.clone()],
Some(StringOrArray::Array(arr)) => arr.clone(),
None => vec![],
};
// Build stop sequence decoder
let mut builder =
StopSequenceDecoderBuilder::new(tokenizer.clone()).skip_special_tokens(skip_special_tokens);
// Add stop sequences (visible if no_stop_trim is true, hidden otherwise)
for seq in stop_sequences {
builder = if no_stop_trim {
builder.visible_stop_sequence(seq)
} else {
builder.stop_sequence(seq)
};
}
// Add stop token IDs (visible if no_stop_trim is true, hidden otherwise)
if let Some(token_ids) = stop_token_ids {
for &token_id in token_ids {
builder = if no_stop_trim {
builder.visible_stop_token(token_id)
} else {
builder.stop_token(token_id)
};
}
}
builder.build()
}
/// Parse tool calls from JSON schema constrained response
pub fn parse_json_schema_response(
processed_text: &str,
tool_choice: &Option<ToolChoice>,
) -> (Option<Vec<ToolCall>>, String) {
match tool_choice {
Some(ToolChoice::Function { function, .. }) => {
// Specific function: Parse parameters directly
match serde_json::from_str::<Value>(processed_text) {
Ok(params) => {
let tool_call = ToolCall {
id: format!("call_{}", uuid::Uuid::new_v4()),
tool_type: "function".to_string(),
function: FunctionCallResponse {
name: function.name.clone(),
arguments: Some(
serde_json::to_string(&params).unwrap_or_else(|_| "{}".to_string()),
),
},
};
(Some(vec![tool_call]), String::new())
}
Err(e) => {
error!("Failed to parse specific function parameters: {}", e);
(None, processed_text.to_string())
}
}
}
Some(ToolChoice::Value(ToolChoiceValue::Required))
| Some(ToolChoice::AllowedTools { .. }) => {
// Required mode: Parse array of tool calls
match serde_json::from_str::<Vec<Value>>(processed_text) {
Ok(parsed_array) => {
let spec_tool_calls: Vec<ToolCall> = parsed_array
.into_iter()
.enumerate()
.filter_map(|(i, item)| {
let obj = item.as_object()?;
let name = obj.get("name")?.as_str()?.to_string();
let parameters = obj.get("parameters")?;
Some(ToolCall {
id: format!("call_{}_{}", i, uuid::Uuid::new_v4()),
tool_type: "function".to_string(),
function: FunctionCallResponse {
name,
arguments: Some(
serde_json::to_string(parameters)
.unwrap_or_else(|_| "{}".to_string()),
),
},
})
})
.collect();
(Some(spec_tool_calls), String::new())
}
Err(e) => {
error!("Failed to parse required tool call array: {}", e);
(None, processed_text.to_string())
}
}
}
_ => (None, processed_text.to_string()),
}
}
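// Hedged, runnable sketch of `parse_json_schema_response` in "required" mode:
// the constrained output is a JSON array of {name, parameters} objects, each
// mapped to a ToolCall. The tool name and parameters below are illustrative.
#[cfg(test)]
mod parse_json_schema_response_sketch {
    use super::*;

    #[test]
    fn required_mode_array_becomes_tool_calls() {
        let raw = r#"[{"name": "get_weather", "parameters": {"city": "Paris"}}]"#;
        let (tool_calls, remaining_text) =
            parse_json_schema_response(raw, &Some(ToolChoice::Value(ToolChoiceValue::Required)));
        let calls = tool_calls.expect("valid array should yield tool calls");
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].function.name, "get_weather");
        // On success the content text is cleared; on parse failure it is returned unchanged.
        assert!(remaining_text.is_empty());
    }
}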
/// Collect responses from a gRPC stream
///
/// This helper processes a gRPC GenerateResponse stream and collects all Complete responses.
/// Used by both regular and PD routers for non-streaming requests.
///
/// # Arguments
/// * `stream` - The gRPC response stream to consume
/// * `worker_name` - Name for logging (e.g., "Prefill", "Decode", "Worker")
///
/// # Returns
/// * `Ok(Vec<GenerateComplete>)` - All complete responses collected from the stream
/// * `Err(Response)` - Error response if the stream fails or returns an error
pub async fn collect_stream_responses(
mut stream: Streaming<proto::GenerateResponse>,
worker_name: &str,
) -> Result<Vec<proto::GenerateComplete>, Response> {
use proto::generate_response::Response::*;
let mut all_responses = Vec::new();
while let Some(response) = stream.next().await {
match response {
Ok(gen_response) => {
match gen_response.response {
Some(Complete(complete)) => {
debug!(
"{} completed: prompt_tokens={}, completion_tokens={}, finish_reason={}",
worker_name, complete.prompt_tokens, complete.completion_tokens, complete.finish_reason
);
all_responses.push(complete);
}
Some(Error(err)) => {
error!("{} error: {}", worker_name, err.message);
return Err(internal_error_message(format!(
"{} generation failed: {}",
worker_name, err.message
)));
}
Some(Chunk(chunk)) => {
debug!("{} chunk: {} tokens", worker_name, chunk.token_ids.len());
}
None => {
debug!("{}: empty response", worker_name);
}
}
}
Err(e) => {
error!("{} stream error: {:?}", worker_name, e);
return Err(internal_error_message(format!(
"{} stream failed: {}",
worker_name, e
)));
}
}
}
debug!("{} stream closed", worker_name);
Ok(all_responses)
}
/// Count the number of tool calls in the request message history
/// This is used for KimiK2 format which needs globally unique indices
pub fn get_history_tool_calls_count(request: &ChatCompletionRequest) -> usize {
request
.messages
.iter()
.filter_map(|msg| {
if let ChatMessage::Assistant { tool_calls, .. } = msg {
tool_calls.as_ref().map(|calls| calls.len())
} else {
None
}
})
.sum()
}
/// Generate a tool call ID based on model format
///
/// # Arguments
/// * `model` - Model name to determine ID format
/// * `tool_name` - Name of the tool being called
/// * `tool_index` - Index of this tool call within the current message
/// * `history_count` - Number of tool calls in previous messages
///
/// # Returns
/// A unique ID string. KimiK2 uses `functions.{name}:{global_index}`, others use `call_{uuid}`
pub fn generate_tool_call_id(
model: &str,
tool_name: &str,
tool_index: usize,
history_count: usize,
) -> String {
if model.to_lowercase().contains("kimi") {
// KimiK2 format: functions.{name}:{global_index}
format!("functions.{}:{}", tool_name, history_count + tool_index)
} else {
// Standard OpenAI format: call_{24-char-uuid}
format!("call_{}", &Uuid::new_v4().simple().to_string()[..24])
}
}
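// Hedged sketch of the two ID formats produced by `generate_tool_call_id`
// (model and tool names are illustrative).
#[cfg(test)]
mod tool_call_id_sketch {
    use super::*;

    #[test]
    fn kimi_ids_use_global_index_others_use_uuid() {
        // 2 tool calls already in history + index 1 within this message => global index 3.
        assert_eq!(
            generate_tool_call_id("kimi-k2-instruct", "get_weather", 1, 2),
            "functions.get_weather:3"
        );
        // Any other model gets the OpenAI-style call_<24-char-uuid> format.
        let id = generate_tool_call_id("some-other-model", "get_weather", 0, 0);
        assert!(id.starts_with("call_"));
        assert_eq!(id.len(), "call_".len() + 24);
    }
}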
#[cfg(test)]
mod tests {
use super::*;
use crate::protocols::spec::{ChatMessage, ContentPart, ImageUrl, UserMessageContent};
use crate::tokenizer::chat_template::ChatTemplateContentFormat;
use serde_json::json;
#[test]
fn test_transform_messages_string_format() {
let messages = vec![ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Parts(vec![
ContentPart::Text {
text: "Hello".to_string(),
},
ContentPart::ImageUrl {
image_url: ImageUrl {
url: "https://example.com/image.jpg".to_string(),
detail: None,
},
},
ContentPart::Text {
text: "World".to_string(),
},
]),
name: None,
}];
let result = process_content_format(&messages, ChatTemplateContentFormat::String).unwrap();
assert_eq!(result.len(), 1);
let transformed_message = &result[0];
// Should flatten multimodal content to text only
assert_eq!(
transformed_message["content"].as_str().unwrap(),
"Hello World"
);
assert_eq!(transformed_message["role"].as_str().unwrap(), "user");
}
#[test]
fn test_transform_messages_openai_format() {
let messages = vec![ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Parts(vec![
ContentPart::Text {
text: "Describe this image:".to_string(),
},
ContentPart::ImageUrl {
image_url: ImageUrl {
url: "https://example.com/image.jpg".to_string(),
detail: Some("high".to_string()),
},
},
]),
name: None,
}];
let result = process_content_format(&messages, ChatTemplateContentFormat::OpenAI).unwrap();
assert_eq!(result.len(), 1);
let transformed_message = &result[0];
// Should replace media URLs with simple type placeholders
let content_array = transformed_message["content"].as_array().unwrap();
assert_eq!(content_array.len(), 2);
// Text part should remain unchanged
assert_eq!(content_array[0]["type"], "text");
assert_eq!(content_array[0]["text"], "Describe this image:");
// Image part should be replaced with simple type placeholder
assert_eq!(content_array[1], json!({"type": "image"}));
}
#[test]
fn test_transform_messages_simple_string_content() {
let messages = vec![ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Text("Simple text message".to_string()),
name: None,
}];
let result = process_content_format(&messages, ChatTemplateContentFormat::String).unwrap();
assert_eq!(result.len(), 1);
let transformed_message = &result[0];
// Simple string content should remain unchanged
assert_eq!(
transformed_message["content"].as_str().unwrap(),
"Simple text message"
);
}
#[test]
fn test_transform_messages_multiple_messages() {
let messages = vec![
ChatMessage::System {
role: "system".to_string(),
content: "System prompt".to_string(),
name: None,
},
ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Parts(vec![
ContentPart::Text {
text: "User message".to_string(),
},
ContentPart::ImageUrl {
image_url: ImageUrl {
url: "https://example.com/image.jpg".to_string(),
detail: None,
},
},
]),
name: None,
},
];
let result = process_content_format(&messages, ChatTemplateContentFormat::String).unwrap();
assert_eq!(result.len(), 2);
// System message should remain unchanged
assert_eq!(result[0]["role"].as_str().unwrap(), "system");
assert_eq!(result[0]["content"].as_str().unwrap(), "System prompt");
// User message should be flattened to text only
assert_eq!(result[1]["role"].as_str().unwrap(), "user");
assert_eq!(result[1]["content"].as_str().unwrap(), "User message");
}
#[test]
fn test_transform_messages_empty_text_parts() {
let messages = vec![ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Parts(vec![ContentPart::ImageUrl {
image_url: ImageUrl {
url: "https://example.com/image.jpg".to_string(),
detail: None,
},
}]),
name: None,
}];
let result = process_content_format(&messages, ChatTemplateContentFormat::String).unwrap();
assert_eq!(result.len(), 1);
let transformed_message = &result[0];
// Should keep original multimodal content when no text parts exist
assert!(transformed_message["content"].is_array());
}
#[test]
fn test_transform_messages_mixed_content_types() {
let messages = vec![
ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Text("Plain text".to_string()),
name: None,
},
ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Parts(vec![
ContentPart::Text {
text: "With image".to_string(),
},
ContentPart::ImageUrl {
image_url: ImageUrl {
url: "https://example.com/image.jpg".to_string(),
detail: Some("low".to_string()),
},
},
]),
name: None,
},
];
let result_string =
process_content_format(&messages, ChatTemplateContentFormat::String).unwrap();
assert_eq!(result_string.len(), 2);
assert_eq!(result_string[0]["content"].as_str().unwrap(), "Plain text");
assert_eq!(result_string[1]["content"].as_str().unwrap(), "With image");
let result_openai =
process_content_format(&messages, ChatTemplateContentFormat::OpenAI).unwrap();
assert_eq!(result_openai.len(), 2);
assert_eq!(result_openai[0]["content"].as_str().unwrap(), "Plain text");
let content_array = result_openai[1]["content"].as_array().unwrap();
assert_eq!(content_array.len(), 2);
assert_eq!(content_array[0]["type"], "text");
assert_eq!(content_array[1], json!({"type": "image"}));
}
}