Unverified commit 0c3db889 authored by Chang Su, committed by GitHub

[router][grpc] Add helper functions for decoder in router.rs and fix specs (#10971)

parent 2bdaf482
@@ -36,9 +36,9 @@ message SamplingParams {
   float presence_penalty = 6;
   float repetition_penalty = 7;
-  int32 max_new_tokens = 8;
+  optional int32 max_new_tokens = 8;
   repeated string stop = 9;
-  repeated int32 stop_token_ids = 10;
+  repeated uint32 stop_token_ids = 10;
   bool skip_special_tokens = 11;
   bool spaces_between_special_tokens = 12;
@@ -98,7 +98,7 @@ message GenerateRequest {
   bool return_logprob = 5;
   int32 logprob_start_len = 6;
   int32 top_logprobs_num = 7;
-  repeated int32 token_ids_logprob = 8;
+  repeated uint32 token_ids_logprob = 8;
   bool return_hidden_states = 9;
   // For disaggregated serving
@@ -129,7 +129,7 @@ message GenerateRequest {
 message TokenizedInput {
   string original_text = 1; // For reference
-  repeated int32 input_ids = 2;
+  repeated uint32 input_ids = 2;
 }
 message MultimodalInputs {
@@ -167,7 +167,7 @@ message GenerateResponse {
 message GenerateStreamChunk {
   // Generated tokens (incremental chunk)
-  repeated int32 token_ids = 1;
+  repeated uint32 token_ids = 1;
   // Cumulative counts
   int32 prompt_tokens = 2;
@@ -183,7 +183,7 @@ message GenerateStreamChunk {
 message GenerateComplete {
   // Final output
-  repeated int32 output_ids = 1;
+  repeated uint32 output_ids = 1;
   // Finish reason
   enum FinishReason {
...
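
For orientation, a hedged sketch of how prost (the code generator tonic uses) maps the changed fields into Rust; the real generated message has more fields than shown here:

    // optional int32  -> Option<i32>
    // repeated string -> Vec<String>
    // repeated uint32 -> Vec<u32>  (matches the tokenizer's u32 token IDs, no casts)
    let params = proto::SamplingParams {
        max_new_tokens: Some(128),
        stop: vec!["</s>".to_string()],
        stop_token_ids: vec![2u32, 32000],
        skip_special_tokens: true,
        ..Default::default() // prost derives Default for generated messages
    };
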
@@ -20,7 +20,7 @@ pub struct SglangSchedulerClient {
 impl SglangSchedulerClient {
     /// Create a new client and connect to the scheduler
-    pub async fn connect(endpoint: &str) -> Result<Self, Box<dyn std::error::Error>> {
+    pub async fn connect(endpoint: &str) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
         debug!("Connecting to SGLang scheduler at {}", endpoint);
         // Convert grpc:// to http:// for tonic
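
Every fallible client method now boxes its error as Box<dyn std::error::Error + Send + Sync>. A minimal sketch of why the extra bounds matter (assuming a multi-threaded tokio runtime, which the router uses): tokio::spawn requires its future to be Send, and a future holding a plain Box<dyn Error> is not.

    use std::error::Error;

    async fn call() -> Result<(), Box<dyn Error + Send + Sync>> {
        Ok(())
    }

    fn spawn_it() {
        // Compiles because the boxed error is Send + Sync; with a plain
        // Box<dyn Error> held across the await, this future would not be Send.
        tokio::spawn(async {
            if let Err(e) = call().await {
                eprintln!("request failed: {e}");
            }
        });
    }
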
@@ -41,10 +41,11 @@ impl SglangSchedulerClient {
     }

     /// Submit a generation request (returns streaming response)
-    pub async fn generate_stream(
+    pub async fn generate(
         &mut self,
         req: proto::GenerateRequest,
-    ) -> Result<tonic::Streaming<proto::GenerateResponse>, Box<dyn std::error::Error>> {
+    ) -> Result<tonic::Streaming<proto::GenerateResponse>, Box<dyn std::error::Error + Send + Sync>>
+    {
         let request = Request::new(req);
         let response = self.client.generate(request).await?;
         Ok(response.into_inner())
@@ -53,7 +54,7 @@ impl SglangSchedulerClient {
     /// Perform health check
     pub async fn health_check(
         &mut self,
-    ) -> Result<proto::HealthCheckResponse, Box<dyn std::error::Error>> {
+    ) -> Result<proto::HealthCheckResponse, Box<dyn std::error::Error + Send + Sync>> {
         debug!("Sending health check request");
         let request = Request::new(proto::HealthCheckRequest {
             tokenized: Some(proto::TokenizedInput {
@@ -72,7 +73,7 @@ impl SglangSchedulerClient {
         &mut self,
         request_id: String,
         reason: String,
-    ) -> Result<(), Box<dyn std::error::Error>> {
+    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
         let request = Request::new(proto::AbortRequest { request_id, reason });
         self.client.abort(request).await?;
@@ -85,7 +86,7 @@ impl SglangSchedulerClient {
         request_id: String,
         body: &ChatCompletionRequest,
         processed_text: String,
-        token_ids: Vec<i32>,
+        token_ids: Vec<u32>,
         multimodal_inputs: Option<proto::MultimodalInputs>,
         tool_call_constraint: Option<(String, String)>, // (constraint_type, constraint_value)
     ) -> Result<proto::GenerateRequest, String> {
@@ -153,6 +154,8 @@ impl SglangSchedulerClient {
             stop: stop_sequences,
             stop_token_ids: request.stop_token_ids.clone().unwrap_or_default(),
             skip_special_tokens,
+            ignore_eos: request.ignore_eos,
+            no_stop_trim: request.no_stop_trim,
             n: request.n.unwrap_or(1) as i32,
             constraint: self.build_constraint(request, tool_call_constraint)?,
             ..Default::default()
...
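
A hypothetical caller sketch, not part of the patch, showing how the renamed generate() (formerly generate_stream()) is consumed; handle_streaming_chat in router.rs below does essentially this:

    use tokio_stream::StreamExt;

    async fn consume(
        mut client: SglangSchedulerClient,
        req: proto::GenerateRequest,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let mut stream = client.generate(req).await?;
        while let Some(item) = stream.next().await {
            let resp = item?; // a tonic::Status converts into the boxed error
            // match resp.response: Chunk(..), Complete(..) or Error(..)
        }
        Ok(())
    }
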
@@ -38,7 +38,7 @@ message SamplingParams {
   optional int32 max_new_tokens = 8;
   repeated string stop = 9;
-  repeated int32 stop_token_ids = 10;
+  repeated uint32 stop_token_ids = 10;
   bool skip_special_tokens = 11;
   bool spaces_between_special_tokens = 12;
@@ -98,7 +98,7 @@ message GenerateRequest {
   bool return_logprob = 5;
   int32 logprob_start_len = 6;
   int32 top_logprobs_num = 7;
-  repeated int32 token_ids_logprob = 8;
+  repeated uint32 token_ids_logprob = 8;
   bool return_hidden_states = 9;
   // For disaggregated serving
@@ -129,7 +129,7 @@ message GenerateRequest {
 message TokenizedInput {
   string original_text = 1; // For reference
-  repeated int32 input_ids = 2;
+  repeated uint32 input_ids = 2;
 }
 message MultimodalInputs {
@@ -167,7 +167,7 @@ message GenerateResponse {
 message GenerateStreamChunk {
   // Generated tokens (incremental chunk)
-  repeated int32 token_ids = 1;
+  repeated uint32 token_ids = 1;
   // Cumulative counts
   int32 prompt_tokens = 2;
@@ -183,7 +183,7 @@ message GenerateStreamChunk {
 message GenerateComplete {
   // Final output
-  repeated int32 output_ids = 1;
+  repeated uint32 output_ids = 1;
   // Finish reason
   enum FinishReason {
...
@@ -313,7 +313,7 @@ pub struct ChatCompletionRequest {
     /// Specific token IDs to use as stop conditions
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub stop_token_ids: Option<Vec<i32>>,
+    pub stop_token_ids: Option<Vec<u32>>,

     /// Skip trimming stop tokens from output
     #[serde(default)]
@@ -564,7 +564,7 @@ pub struct CompletionRequest {
     /// Specific token IDs to use as stop conditions
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub stop_token_ids: Option<Vec<i32>>,
+    pub stop_token_ids: Option<Vec<u32>>,

     /// Skip trimming stop tokens from output
     #[serde(default)]
@@ -1864,7 +1864,7 @@ pub struct SamplingParams {
     #[serde(skip_serializing_if = "Option::is_none")]
     pub min_tokens: Option<u32>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub stop_token_ids: Option<Vec<i32>>,
+    pub stop_token_ids: Option<Vec<u32>>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub no_stop_trim: Option<bool>,
     #[serde(skip_serializing_if = "Option::is_none")]
...
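
On the HTTP-spec side, the same i32 -> u32 switch means negative stop-token IDs are now rejected at deserialization time instead of flowing into the scheduler. A standalone sketch (the StopParams struct is hypothetical; it mirrors the fields above):

    use serde::{Deserialize, Serialize};

    #[derive(Serialize, Deserialize)]
    struct StopParams {
        #[serde(skip_serializing_if = "Option::is_none")]
        stop_token_ids: Option<Vec<u32>>,
    }

    fn main() {
        // Valid: non-negative IDs parse into Vec<u32>.
        let ok: StopParams = serde_json::from_str(r#"{"stop_token_ids":[2,32000]}"#).unwrap();
        assert_eq!(ok.stop_token_ids, Some(vec![2, 32000]));

        // Rejected: a negative ID no longer deserializes (it previously fit in i32).
        assert!(serde_json::from_str::<StopParams>(r#"{"stop_token_ids":[-1]}"#).is_err());

        // A missing field still deserializes to None and is skipped on serialization.
        let none: StopParams = serde_json::from_str("{}").unwrap();
        assert!(none.stop_token_ids.is_none());
    }
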
@@ -17,19 +17,20 @@ use crate::grpc_client::{proto, SglangSchedulerClient};
 use crate::metrics::RouterMetrics;
 use crate::policies::PolicyRegistry;
 use crate::protocols::spec::ChatMessage;
-use crate::protocols::spec::{ChatCompletionRequest, StringOrArray};
 use crate::protocols::spec::{
-    CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, ResponsesGetParams,
-    ResponsesRequest, Tool, ToolChoice,
+    ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
+    ResponsesGetParams, ResponsesRequest, StringOrArray, Tool, ToolChoice,
 };
 use crate::reasoning_parser::ParserFactory;
 use crate::routers::RouterTrait;
 use crate::server::AppContext;
 use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
+use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoderBuilder};
 use crate::tokenizer::traits::Tokenizer;
 use crate::tokenizer::HuggingFaceTokenizer;
 use crate::tool_parser::ParserRegistry;
 use serde_json::Value;
+use tokio_stream::StreamExt;
 use uuid::Uuid;

 // Data structures for processing
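
(Note: tonic::Streaming implements the futures Stream trait, so the stream.next().await calls in the handlers below come from the tokio_stream::StreamExt import added here; without it the while let loops would not compile.)
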
@@ -182,7 +183,7 @@ impl GrpcRouter {
             request_id,
             body,
             processed_messages.text.clone(),
-            token_ids.into_iter().map(|id| id as i32).collect(),
+            token_ids,
             processed_messages.multimodal_inputs,
             tool_call_constraint, // Pass the full tuple (type, value)
         ) {
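
With TokenizedInput.input_ids now repeated uint32, the tokenizer's u32 IDs flow through untouched; the deleted map(|id| id as i32) was a silent wraparound hazard. A tiny standalone illustration (the value is hypothetical, far above any real vocabulary, but the type now rules the cast out entirely):

    fn main() {
        let id: u32 = 3_000_000_000;
        // Under the old Vec<i32> signature this cast wrapped silently:
        assert_eq!(id as i32, -1_294_967_296);
    }
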
@@ -479,28 +480,225 @@ impl GrpcRouter {
         None
     }

-    /// Placeholder for streaming handler (to be implemented in Phase 2)
+    /// Create a StopSequenceDecoder from the chat completion request
+    fn create_stop_decoder(
+        &self,
+        original_request: &ChatCompletionRequest,
+    ) -> crate::tokenizer::stop::StopSequenceDecoder {
+        // Extract stop sequences from request
+        let stop_sequences: Vec<String> = match &original_request.stop {
+            Some(StringOrArray::String(s)) => vec![s.clone()],
+            Some(StringOrArray::Array(arr)) => arr.clone(),
+            None => vec![],
+        };
+
+        // Build stop sequence decoder
+        let mut builder = StopSequenceDecoderBuilder::new(self.tokenizer.clone())
+            .skip_special_tokens(original_request.skip_special_tokens);
+
+        // Add stop sequences (visible if no_stop_trim is true, hidden otherwise)
+        for seq in stop_sequences {
+            builder = if original_request.no_stop_trim {
+                builder.visible_stop_sequence(seq)
+            } else {
+                builder.stop_sequence(seq)
+            };
+        }
+
+        // Add stop token IDs (visible if no_stop_trim is true, hidden otherwise)
+        if let Some(stop_token_ids) = &original_request.stop_token_ids {
+            for &token_id in stop_token_ids {
+                builder = if original_request.no_stop_trim {
+                    builder.visible_stop_token(token_id)
+                } else {
+                    builder.stop_token(token_id)
+                };
+            }
+        }
+
+        builder.build()
+    }
+
+    /// Process a chunk of tokens through the stop decoder
+    fn process_chunk_tokens(
+        stop_decoder: &mut crate::tokenizer::stop::StopSequenceDecoder,
+        token_ids: &[u32],
+    ) -> (String, bool) {
+        let mut chunk_text = String::new();
+        for &token_id in token_ids {
+            match stop_decoder.process_token(token_id).unwrap_or_else(|e| {
+                debug!(
+                    "Error processing token {}: {}. Treating as Held.",
+                    token_id, e
+                );
+                SequenceDecoderOutput::Held
+            }) {
+                SequenceDecoderOutput::Text(text) => {
+                    chunk_text.push_str(&text);
+                }
+                SequenceDecoderOutput::StoppedWithText(text) => {
+                    chunk_text.push_str(&text);
+                    return (chunk_text, true); // Return text and signal to stop
+                }
+                SequenceDecoderOutput::Stopped => {
+                    return (chunk_text, true); // Return text and signal to stop
+                }
+                SequenceDecoderOutput::Held => {
+                    // Text held for potential stop sequence match
+                }
+            }
+        }
+        (chunk_text, false) // Return text and continue processing
+    }
+
+    /// Submit request and handle streaming response for chat completions route
     async fn handle_streaming_chat(
         &self,
-        _client: SglangSchedulerClient,
-        _request: proto::GenerateRequest,
-        _original_request: &ChatCompletionRequest,
+        mut client: SglangSchedulerClient,
+        request: proto::GenerateRequest,
+        original_request: &ChatCompletionRequest,
     ) -> Response {
-        (StatusCode::NOT_IMPLEMENTED, "Streaming not yet implemented").into_response()
+        let mut stop_decoder = self.create_stop_decoder(original_request);
+
+        // Process streaming tokens
+        let mut grpc_stream = match client.generate(request).await {
+            Ok(stream) => stream,
+            Err(e) => {
+                error!("Failed to start generation: {}", e);
+                return (
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    format!("Generation failed: {}", e),
+                )
+                    .into_response();
+            }
+        };
+
+        let mut decoded_text = String::new();
+        while let Some(response) = grpc_stream.next().await {
+            let gen_response = match response {
+                Ok(resp) => resp,
+                Err(e) => {
+                    error!("Stream error: {}", e);
+                    break;
+                }
+            };
+
+            match gen_response.response {
+                Some(proto::generate_response::Response::Chunk(chunk)) => {
+                    // Process tokens and check if we should stop
+                    let (chunk_text, should_stop) =
+                        Self::process_chunk_tokens(&mut stop_decoder, &chunk.token_ids);
+                    decoded_text.push_str(&chunk_text);
+                    if should_stop {
+                        break;
+                    }
+                    continue;
+                }
+                Some(proto::generate_response::Response::Complete(_complete)) => {
+                    // Flush any remaining text
+                    if let SequenceDecoderOutput::Text(text) = stop_decoder.flush() {
+                        if !text.is_empty() {
+                            decoded_text.push_str(&text);
+                            debug!("Flushed text: {}", text);
+                        }
+                    }
+                    break;
+                }
+                Some(proto::generate_response::Response::Error(error)) => {
+                    error!("Generation error: {}", error.message);
+                    break;
+                }
+                None => continue,
+            }
+        }
+
+        // TODO: Replace with proper SSE streaming response
+        // For now, return the complete decoded text
+        (StatusCode::OK, format!("Decoded text: {}", decoded_text)).into_response()
     }

-    /// Placeholder for non-streaming handler (to be implemented in Phase 3)
+    /// Submit request and handle non-streaming response for chat completions route
     async fn handle_non_streaming_chat(
         &self,
-        _client: SglangSchedulerClient,
-        _request: proto::GenerateRequest,
-        _original_request: &ChatCompletionRequest,
+        mut client: SglangSchedulerClient,
+        request: proto::GenerateRequest,
+        original_request: &ChatCompletionRequest,
     ) -> Response {
-        (
-            StatusCode::NOT_IMPLEMENTED,
-            "Non-streaming not yet implemented",
-        )
-            .into_response()
+        let mut stop_decoder = self.create_stop_decoder(original_request);
+
+        // Small helpers to log + return a uniform 500
+        let fail_str = |msg: &'static str| -> Response {
+            error!("{}", msg);
+            (StatusCode::INTERNAL_SERVER_ERROR, msg).into_response()
+        };
+        let fail_fmt = |prefix: &str, e: &dyn std::fmt::Display| -> Response {
+            error!("{}{}", prefix, e);
+            (
+                StatusCode::INTERNAL_SERVER_ERROR,
+                format!("{}{}", prefix, e),
+            )
+                .into_response()
+        };
+
+        // Start generation
+        let mut stream = match client.generate(request).await {
+            Ok(s) => s,
+            Err(e) => return fail_fmt("Failed to start generation: ", &e),
+        };
+
+        // Get the single Complete response
+        let gen_response = match stream.next().await {
+            Some(Ok(r)) => r,
+            Some(Err(e)) => return fail_fmt("Failed to get GenerateResponse: ", &e),
+            None => return fail_str("No response from server"),
+        };
+
+        // Extract the expected variant early
+        let complete = match gen_response.response {
+            Some(proto::generate_response::Response::Complete(c)) => c,
+            Some(proto::generate_response::Response::Error(err)) => {
+                error!("Generation failed: {}", err.message);
+                return (
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    format!("Generation failed: {}", err.message),
+                )
+                    .into_response();
+            }
+            Some(proto::generate_response::Response::Chunk(_)) => {
+                return fail_str("Unexpected chunk response for non-streaming request")
+            }
+            None => return fail_str("Empty response from server"),
+        };
+
+        // Decode tokens
+        let outputs = match stop_decoder.process_tokens(&complete.output_ids) {
+            Ok(o) => o,
+            Err(e) => return fail_fmt("Failed to process tokens: ", &e),
+        };
+
+        // Accumulate text with early breaks
+        let mut final_text = String::new();
+        for output in outputs {
+            match output {
+                SequenceDecoderOutput::Text(t) => final_text.push_str(&t),
+                SequenceDecoderOutput::StoppedWithText(t) => {
+                    final_text.push_str(&t);
+                    break;
+                }
+                SequenceDecoderOutput::Stopped => break,
+                SequenceDecoderOutput::Held => {}
+            }
+        }
+
+        // Flush remaining text
+        if let SequenceDecoderOutput::Text(t) = stop_decoder.flush() {
+            final_text.push_str(&t);
+        }
+
+        // TODO: Create proper OpenAI-compatible response
+        (StatusCode::OK, format!("Final text: {}", final_text)).into_response()
     }
 }
...
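
Taken together, create_stop_decoder and process_chunk_tokens give each request a small decode state machine: build the decoder once, feed tokens as they arrive, flush at the end. Held buffers text that may be the prefix of a stop sequence, and the visible_* variants keep matched stop text in the output when no_stop_trim is set. A within-crate sketch using only the calls visible in this patch (TokenizerHandle stands in for whatever type self.tokenizer actually is):

    use crate::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoderBuilder};

    fn decode_with_stop(tokenizer: TokenizerHandle, ids: &[u32]) -> String {
        let mut decoder = StopSequenceDecoderBuilder::new(tokenizer)
            .skip_special_tokens(true)
            .stop_sequence("</s>".to_string()) // hidden: trimmed from output
            .visible_stop_token(2)             // visible: kept in output
            .build();

        let mut text = String::new();
        for &id in ids {
            match decoder.process_token(id).unwrap_or(SequenceDecoderOutput::Held) {
                SequenceDecoderOutput::Text(t) => text.push_str(&t),
                SequenceDecoderOutput::StoppedWithText(t) => {
                    text.push_str(&t);
                    break;
                }
                SequenceDecoderOutput::Stopped => break,
                SequenceDecoderOutput::Held => {} // possible partial stop match, keep buffering
            }
        }
        // Flush whatever is still held once the stream ends
        if let SequenceDecoderOutput::Text(t) = decoder.flush() {
            text.push_str(&t);
        }
        text
    }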