Unverified commit 7ff93e61, authored by Chang Su, committed by GitHub

router(grpc): Implement route for chat_cmpl endpoint (#10761)

parent b24b2e7e
@@ -13,6 +13,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         // Generate both client and server code
         .build_server(true)
         .build_client(true)
+        // Add protoc arguments for proto3 optional support
+        .protoc_arg("--experimental_allow_proto3_optional")
         // Add a module-level attribute for documentation and clippy warnings
         .server_mod_attribute(
             "sglang.grpc.scheduler",
...
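For context: protoc releases before 3.15 only accept `optional` fields in proto3 behind this flag, so forwarding it keeps code generation working on older toolchains. A minimal sketch of such a build script, assuming tonic-build's `compile` entry point; the proto path is illustrative, not the repo's actual layout:

```rust
// Minimal build.rs sketch (illustrative path, not the repo's real layout).
fn main() -> Result<(), Box<dyn std::error::Error>> {
    tonic_build::configure()
        .build_server(true)
        .build_client(true)
        // protoc < 3.15 rejects `optional` in proto3 without this flag
        .protoc_arg("--experimental_allow_proto3_optional")
        .compile(&["proto/sglang_scheduler.proto"], &["proto"])?;
    Ok(())
}
```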
@@ -97,7 +97,7 @@ mod tests {
     fn test_generate_request_construction() {
         let sampling_params = proto::SamplingParams {
             temperature: 0.7,
-            max_new_tokens: 128,
+            max_new_tokens: Some(128),
             top_p: 0.9,
             top_k: 50,
             stop: vec!["</s>".to_string()],
@@ -126,7 +126,7 @@ mod tests {
         let params = gen_req.sampling_params.unwrap();
         assert_eq!(params.temperature, 0.7);
-        assert_eq!(params.max_new_tokens, 128);
+        assert_eq!(params.max_new_tokens, Some(128));
         assert_eq!(params.stop, vec!["</s>"]);
     }
@@ -155,7 +155,7 @@ mod tests {
     fn test_sampling_params_defaults() {
         let params = proto::SamplingParams::default();
         assert_eq!(params.temperature, 0.0);
-        assert_eq!(params.max_new_tokens, 0);
+        assert_eq!(params.max_new_tokens, None);
         assert_eq!(params.top_p, 0.0);
         assert_eq!(params.top_k, 0);
         assert!(params.stop.is_empty());
...
@@ -36,7 +36,7 @@ message SamplingParams {
     float presence_penalty = 6;
     float repetition_penalty = 7;
-    int32 max_new_tokens = 8;
+    optional int32 max_new_tokens = 8;
     repeated string stop = 9;
     repeated int32 stop_token_ids = 10;
     bool skip_special_tokens = 11;
...
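This is the core change of the proto edit: prost maps a proto3 `optional` scalar to `Option<T>` instead of a plain default-valued field, which is what the test updates above reflect. A hand-written approximation, not the actual generated code:

```rust
// Hand-written approximation of the prost-generated struct for these fields.
#[derive(Debug, Default, PartialEq)]
pub struct SamplingParams {
    pub repetition_penalty: f32,     // `float` -> f32, default 0.0
    pub max_new_tokens: Option<i32>, // `optional int32` -> Option<i32>, default None
    pub stop: Vec<String>,           // `repeated string` -> Vec<String>
}

fn main() {
    let params = SamplingParams::default();
    // Mirrors the updated default test: an unset optional field is None, not 0.
    assert_eq!(params.max_new_tokens, None);
}
```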
@@ -4,12 +4,16 @@ use crate::config::types::RetryConfig;
 use crate::core::{
     BasicWorkerBuilder, CircuitBreakerConfig, HealthConfig, WorkerRegistry, WorkerType,
 };
-use crate::grpc::SglangSchedulerClient;
+use crate::grpc::{proto, SglangSchedulerClient};
 use crate::metrics::RouterMetrics;
 use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
+use crate::protocols::spec::{
+    ChatCompletionRequest, ChatMessage, ContentPart, ResponseFormat, StringOrArray,
+    UserMessageContent,
+};
 use crate::reasoning_parser::ParserFactory;
 use crate::routers::RouterTrait;
-use crate::tokenizer::traits::Tokenizer;
+use crate::tokenizer::{chat_template::ChatMessage as TokenizerChatMessage, traits::Tokenizer};
 use crate::tool_parser::ParserRegistry;
 use async_trait::async_trait;
 use axum::{
@@ -21,7 +25,16 @@ use axum::{
 use std::collections::HashMap;
 use std::sync::Arc;
 use std::time::Duration;
-use tracing::{info, warn};
+use tracing::{debug, error, info, warn};
+use uuid::Uuid;
+
+// Data structures for processing
+#[derive(Debug)]
+pub struct ProcessedMessages {
+    pub text: String,
+    pub multimodal_inputs: Option<proto::MultimodalInputs>,
+    pub stop_sequences: Option<StringOrArray>,
+}
 
 /// gRPC router implementation for SGLang
 #[allow(dead_code)] // Fields will be used once implementation is complete
@@ -161,6 +174,345 @@ impl GrpcRouter {
             circuit_breaker_config: core_cb_config,
         })
     }
// ============ Chat Implementation ============
/// Main route_chat implementation
async fn route_chat_impl(
&self,
_headers: Option<&HeaderMap>,
body: &ChatCompletionRequest,
model_id: Option<&str>,
) -> Response {
debug!(
"Processing chat completion request for model: {:?}",
model_id
);
// Step 1: Select worker (fail fast if no workers available)
let worker = match self.select_worker_for_request(model_id, None) {
Some(w) => w,
None => {
warn!("No available workers for model: {:?}", model_id);
return (StatusCode::SERVICE_UNAVAILABLE, "No available workers").into_response();
}
};
debug!("Selected worker: {}", worker.url());
// Step 2: Get gRPC client for worker (fail fast if can't connect)
let client = match self.get_or_create_grpc_client(worker.url()).await {
Ok(c) => c,
Err(e) => {
error!("Failed to get gRPC client: {}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to get gRPC client: {}", e),
)
.into_response();
}
};
// Step 3: Process messages and apply chat template
let processed_messages = match self.process_chat_messages(body) {
Ok(msgs) => msgs,
Err(e) => {
error!("Failed to process chat messages: {}", e);
return (StatusCode::BAD_REQUEST, e.to_string()).into_response();
}
};
// Step 4: Tokenize the processed text
let encoding = match self.tokenizer.encode(&processed_messages.text) {
Ok(encoding) => encoding,
Err(e) => {
error!("Tokenization failed: {}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Tokenization failed: {}", e),
)
.into_response();
}
};
let token_ids = encoding.token_ids().to_vec();
debug!("Tokenized {} tokens from input", token_ids.len());
// Step 5: Build tool constraints if needed
let structural_tag = if let Some(tools) = &body.tools {
self.generate_tool_constraints(tools, &body.tool_choice, &body.model)
} else {
None
};
// Step 6: Build SamplingParams for gRPC
let sampling_params = match self.build_grpc_sampling_params(body, structural_tag) {
Ok(params) => params,
Err(e) => {
error!("Failed to build sampling parameters: {}", e);
return (
StatusCode::BAD_REQUEST,
format!("Invalid sampling parameters: {}", e),
)
.into_response();
}
};
// Step 7: Create GenerateRequest
let grpc_request = proto::GenerateRequest {
request_id: format!("chatcmpl-{}", Uuid::new_v4()),
tokenized: Some(proto::TokenizedInput {
original_text: processed_messages.text.clone(),
input_ids: token_ids.into_iter().map(|id| id as i32).collect(),
}),
mm_inputs: processed_messages.multimodal_inputs,
sampling_params: Some(sampling_params),
return_logprob: body.logprobs,
logprob_start_len: -1,
top_logprobs_num: body.top_logprobs.unwrap_or(0) as i32,
return_hidden_states: body.return_hidden_states,
..Default::default()
};
// Step 8: Handle streaming vs non-streaming
if body.stream {
self.handle_streaming_chat(client, grpc_request, body).await
} else {
self.handle_non_streaming_chat(client, grpc_request, body)
.await
}
}
// ============ Helper Methods ============
/// Process chat messages and apply template
fn process_chat_messages(
&self,
request: &ChatCompletionRequest,
) -> Result<ProcessedMessages, String> {
let tokenizer_messages = self.convert_messages_for_tokenizer(&request.messages)?;
// Use the tokenizer's chat template - we require HuggingFace tokenizer for gRPC
let formatted_text = if let Some(hf_tokenizer) =
self.tokenizer
.as_any()
.downcast_ref::<crate::tokenizer::HuggingFaceTokenizer>()
{
hf_tokenizer
.apply_chat_template(&tokenizer_messages, true)
.map_err(|e| format!("Failed to apply chat template: {}", e))?
} else {
return Err(
"gRPC router requires HuggingFace tokenizer with chat template support".to_string(),
);
};
// Placeholder for multimodal inputs
let multimodal_inputs = None;
Ok(ProcessedMessages {
text: formatted_text,
multimodal_inputs,
stop_sequences: request.stop.clone(),
})
}
/// Convert spec ChatMessage enum to tokenizer ChatMessage struct
fn convert_messages_for_tokenizer(
&self,
messages: &[ChatMessage],
) -> Result<Vec<TokenizerChatMessage>, String> {
let mut converted = Vec::new();
for message in messages {
let tokenizer_msg = match message {
ChatMessage::System { content, .. } => TokenizerChatMessage::new("system", content),
ChatMessage::User { content, .. } => {
let text_content = match content {
UserMessageContent::Text(text) => text.clone(),
UserMessageContent::Parts(parts) => {
// Simple text extraction for now - multimodal is placeholder
parts
.iter()
.filter_map(|part| match part {
ContentPart::Text { text } => Some(text.as_str()),
ContentPart::ImageUrl { .. } => None, // Skip images for now
})
.collect::<Vec<&str>>()
.join(" ")
}
};
TokenizerChatMessage::new("user", text_content)
}
ChatMessage::Assistant { content, .. } => {
// Simple content extraction - no special tool/reasoning formatting
TokenizerChatMessage::new("assistant", content.as_deref().unwrap_or(""))
}
ChatMessage::Tool { content, .. } => TokenizerChatMessage::new("tool", content),
ChatMessage::Function { content, .. } => {
TokenizerChatMessage::new("function", content)
}
};
converted.push(tokenizer_msg);
}
Ok(converted)
}
/// Build gRPC SamplingParams from OpenAI request
fn build_grpc_sampling_params(
&self,
request: &ChatCompletionRequest,
structural_tag: Option<String>,
) -> Result<proto::SamplingParams, String> {
let stop_sequences = self.extract_stop_strings(request);
// Handle max tokens: prefer max_completion_tokens (new) over max_tokens (deprecated)
// If neither is specified, use None to let the backend decide the default
#[allow(deprecated)]
let max_new_tokens = request
.max_completion_tokens
.or(request.max_tokens)
.map(|v| v as i32);
#[allow(deprecated)]
Ok(proto::SamplingParams {
temperature: request.temperature.unwrap_or(1.0),
top_p: request.top_p.unwrap_or(1.0),
top_k: request.top_k.unwrap_or(-1),
min_p: request.min_p.unwrap_or(0.0),
frequency_penalty: request.frequency_penalty.unwrap_or(0.0),
presence_penalty: request.presence_penalty.unwrap_or(0.0),
repetition_penalty: request.repetition_penalty.unwrap_or(1.0),
max_new_tokens,
stop: stop_sequences,
stop_token_ids: request.stop_token_ids.clone().unwrap_or_default(),
skip_special_tokens: request.skip_special_tokens,
n: request.n.unwrap_or(1) as i32,
structural_tag: structural_tag.unwrap_or_default(),
constraint: self.build_constraint(request)?,
..Default::default()
})
}
/// Extract stop strings from request
fn extract_stop_strings(&self, request: &ChatCompletionRequest) -> Vec<String> {
match &request.stop {
Some(StringOrArray::String(s)) => vec![s.clone()],
Some(StringOrArray::Array(arr)) => arr.clone(),
None => vec![],
}
}
/// Build constraint for structured generation
fn build_constraint(
&self,
request: &ChatCompletionRequest,
) -> Result<Option<proto::sampling_params::Constraint>, String> {
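        // Note: the constraint variants are mutually exclusive; the first
        // matching request field wins, checked in order: json_schema, ebnf, regex.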
if let Some(ResponseFormat::JsonSchema { json_schema }) = &request.response_format {
let schema_str = serde_json::to_string(&json_schema.schema)
.map_err(|e| format!("Failed to serialize JSON schema: {}", e))?;
return Ok(Some(proto::sampling_params::Constraint::JsonSchema(
schema_str,
)));
}
if let Some(ebnf) = &request.ebnf {
return Ok(Some(proto::sampling_params::Constraint::EbnfGrammar(
ebnf.clone(),
)));
}
if let Some(regex) = &request.regex {
return Ok(Some(proto::sampling_params::Constraint::Regex(
regex.clone(),
)));
}
Ok(None)
}
/// Generate tool constraints for structured generation
fn generate_tool_constraints(
&self,
_tools: &[crate::protocols::spec::Tool],
_tool_choice: &Option<crate::protocols::spec::ToolChoice>,
model: &str,
) -> Option<String> {
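        // Placeholder: verify a tool parser exists for this model, but do not
        // emit a structural tag yet (constraint generation is not implemented).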
let _parser = self.tool_parser_registry.get_parser(model)?;
None
}
/// Select a worker for the request
fn select_worker_for_request(
&self,
model_id: Option<&str>,
text: Option<&str>,
) -> Option<Arc<dyn crate::core::Worker>> {
// Get workers for the specified model, filtered by connection mode
let workers = self.worker_registry.get_workers_filtered(
model_id,
Some(WorkerType::Regular),
Some(crate::core::ConnectionMode::Grpc { port: None }),
false, // get all workers, we'll filter by is_available() next
);
// Filter by availability (health + circuit breaker)
let available: Vec<Arc<dyn crate::core::Worker>> = workers
.iter()
.filter(|w| w.is_available())
.cloned()
.collect();
if available.is_empty() {
return None;
}
// Get the appropriate policy for this model
let policy = match model_id {
Some(model) => self.policy_registry.get_policy_or_default(model),
None => self.policy_registry.get_default_policy(),
};
// Select worker using the policy
let idx = policy.select_worker(&available, text)?;
Some(available[idx].clone())
}
/// Get or create a gRPC client for the worker
async fn get_or_create_grpc_client(
&self,
worker_url: &str,
) -> Result<SglangSchedulerClient, String> {
debug!("Creating new gRPC client for worker: {}", worker_url);
SglangSchedulerClient::connect(worker_url)
.await
.map_err(|e| format!("Failed to connect to gRPC server: {}", e))
}
/// Placeholder for streaming handler (to be implemented in Phase 2)
async fn handle_streaming_chat(
&self,
_client: SglangSchedulerClient,
_request: proto::GenerateRequest,
_original_request: &ChatCompletionRequest,
) -> Response {
(StatusCode::NOT_IMPLEMENTED, "Streaming not yet implemented").into_response()
}
/// Placeholder for non-streaming handler (to be implemented in Phase 3)
async fn handle_non_streaming_chat(
&self,
_client: SglangSchedulerClient,
_request: proto::GenerateRequest,
_original_request: &ChatCompletionRequest,
) -> Response {
(
StatusCode::NOT_IMPLEMENTED,
"Non-streaming not yet implemented",
)
.into_response()
}
}

impl std::fmt::Debug for GrpcRouter {
@@ -212,11 +564,11 @@ impl RouterTrait for GrpcRouter {
     async fn route_chat(
         &self,
-        _headers: Option<&HeaderMap>,
-        _body: &crate::protocols::spec::ChatCompletionRequest,
-        _model_id: Option<&str>,
+        headers: Option<&HeaderMap>,
+        body: &crate::protocols::spec::ChatCompletionRequest,
+        model_id: Option<&str>,
     ) -> Response {
-        (StatusCode::NOT_IMPLEMENTED).into_response()
+        self.route_chat_impl(headers, body, model_id).await
     }
 
     async fn route_completion(
...
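One behavior worth noting from `build_grpc_sampling_params` above: the router prefers the newer `max_completion_tokens` field over the deprecated `max_tokens`, and when neither is set it leaves `max_new_tokens` as `None` so the backend picks its own default. A standalone sketch of that precedence (the `Option<u32>` field types are an assumption, matching the `as i32` cast in the router):

```rust
// Standalone illustration of the max-token precedence implemented above.
fn resolve_max_new_tokens(
    max_completion_tokens: Option<u32>,
    max_tokens: Option<u32>,
) -> Option<i32> {
    max_completion_tokens.or(max_tokens).map(|v| v as i32)
}

fn main() {
    assert_eq!(resolve_max_new_tokens(Some(256), Some(128)), Some(256)); // newer field wins
    assert_eq!(resolve_max_new_tokens(None, Some(128)), Some(128)); // deprecated fallback
    assert_eq!(resolve_max_new_tokens(None, None), None); // backend decides the default
}
```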
@@ -210,6 +210,10 @@ impl TokenizerTrait for HuggingFaceTokenizer {
     fn id_to_token(&self, id: TokenIdType) -> Option<String> {
         self.reverse_vocab.get(&id).cloned()
     }
+
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
 }
 
 #[cfg(test)]
...
@@ -109,4 +109,8 @@ impl TokenizerTrait for MockTokenizer {
     fn id_to_token(&self, id: u32) -> Option<String> {
         self.reverse_vocab.get(&id).cloned()
     }
+
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
 }
@@ -170,6 +170,10 @@ impl TokenizerTrait for TiktokenTokenizer {
         // We can only decode IDs to text
         None
     }
+
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
 }
 
 #[cfg(test)]
...
@@ -22,6 +22,9 @@ pub trait Tokenizer: Encoder + Decoder {
     fn get_special_tokens(&self) -> &SpecialTokens;
     fn token_to_id(&self, token: &str) -> Option<TokenIdType>;
     fn id_to_token(&self, id: TokenIdType) -> Option<String>;
+
+    /// Enable downcasting to concrete types
+    fn as_any(&self) -> &dyn std::any::Any;
 }
 
 /// Contains the results of tokenizing text: token IDs, string tokens, and their spans
...
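These `as_any` additions are what let `process_chat_messages` in the router downcast the boxed tokenizer to `HuggingFaceTokenizer`. A self-contained sketch of the pattern, with illustrative stand-ins for the real types:

```rust
use std::any::Any;

// Illustrative stand-ins for the crate's Tokenizer trait and concrete type.
trait Tokenizer {
    /// Enable downcasting to concrete types
    fn as_any(&self) -> &dyn Any;
}

struct HuggingFaceTokenizer;

impl Tokenizer for HuggingFaceTokenizer {
    fn as_any(&self) -> &dyn Any {
        self
    }
}

fn main() {
    let tokenizer: Box<dyn Tokenizer> = Box::new(HuggingFaceTokenizer);
    // downcast_ref returns Some only when the concrete type matches.
    if tokenizer
        .as_any()
        .downcast_ref::<HuggingFaceTokenizer>()
        .is_some()
    {
        println!("chat-template path available");
    }
}
```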