Unverified Commit 37f3325b authored by Chang Su, committed by GitHub

[router][grpc] Support E2E non-stream chat completions (#10980)

parent bd95944c
@@ -13,7 +13,7 @@ import sys
 import threading
 import time
 import uuid
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, AsyncGenerator, Dict, List, Optional, Union

 import grpc
 import zmq
@@ -156,7 +156,7 @@ class GrpcRequestManager:
         obj: TokenizedGenerateReqInput,
         request_id: Optional[str] = None,
         grpc_context: Optional[grpc.aio.ServicerContext] = None,
-    ):
+    ) -> AsyncGenerator[Union[Dict, List[Dict]], None]:
         """
         Submit a generation request to the scheduler with n>1 parallel sampling support.
......
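Note on the new return type: generate_request now hands back an async generator instead of a queue, so callers pull results with async for / __anext__ rather than queue.get(). A minimal consumption sketch, assuming each yielded item is the output dict (or list of dicts for n>1) described by the annotation above; the helper name below is illustrative, not part of the commit:

async def collect_outputs(manager, tokenized_req, request_id=None):
    """Drain the generator and return everything it yields (non-streaming use)."""
    outputs = []
    async for item in manager.generate_request(tokenized_req, request_id=request_id):
        outputs.append(item)
    return outputs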
@@ -321,14 +321,14 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
         logger.info(f"Sending health check request to request manager...")

         # Submit and wait for response
-        output_queue = await self.request_manager.generate_request(
+        output_generator = self.request_manager.generate_request(
             health_request, request_id=rid
         )

         try:
-            # Wait for response with configurable timeout
+            # Get first response with timeout
             response = await asyncio.wait_for(
-                output_queue.get(), timeout=HEALTH_CHECK_TIMEOUT
+                output_generator.__anext__(), timeout=HEALTH_CHECK_TIMEOUT
             )

             # Clean up
@@ -492,13 +492,32 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
     ) -> sglang_scheduler_pb2.GenerateResponse:
         """Create a completion response."""

-        # Determine finish reason
-        finish_reason = sglang_scheduler_pb2.GenerateComplete.STOP
+        # Extract meta info and finish reason details
         meta_info = output.get("meta_info", {})
-        if meta_info.get("finish_reason") == "length":
-            finish_reason = sglang_scheduler_pb2.GenerateComplete.LENGTH
-        elif meta_info.get("finish_reason") == "eos_token":
-            finish_reason = sglang_scheduler_pb2.GenerateComplete.EOS_TOKEN
+        finish_reason_data = meta_info.get("finish_reason")
+
+        # Determine finish reason, default is stop
+        finish_reason = "stop"
+        if finish_reason_data:
+            if isinstance(finish_reason_data, dict):
+                finish_reason_type = finish_reason_data.get("type")
+            else:
+                # Handle legacy string format
+                finish_reason_type = finish_reason_data
+
+            if finish_reason_type == "length":
+                finish_reason = "length"
+            elif finish_reason_type == "abort":
+                finish_reason = "abort"
+
+        # Extract matched_stop information
+        matched_stop_kwargs = {}
+        if isinstance(finish_reason_data, dict) and "matched" in finish_reason_data:
+            matched = finish_reason_data["matched"]
+            if isinstance(matched, int):
+                matched_stop_kwargs["matched_token_id"] = matched
+            elif isinstance(matched, str):
+                matched_stop_kwargs["matched_stop_str"] = matched

         return sglang_scheduler_pb2.GenerateResponse(
             request_id=request_id,
@@ -510,6 +529,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
                     "completion_tokens", len(output.get("token_ids", []))
                 ),
                 cached_tokens=meta_info.get("cached_tokens", 0),
+                **matched_stop_kwargs,
             ),
         )
......
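For orientation, the meta_info["finish_reason"] values this mapping accepts take two shapes; the token id and strings below are illustrative examples, not values from the commit:

# Structured form: a dict with a "type" plus, on stop hits, the matched token or string.
{"finish_reason": {"type": "stop", "matched": 128009}}  # -> finish_reason="stop", matched_token_id=128009
{"finish_reason": {"type": "stop", "matched": "###"}}   # -> finish_reason="stop", matched_stop_str="###"
{"finish_reason": {"type": "length"}}                   # -> finish_reason="length"

# Legacy form: a bare string, handled by the fallback branch.
{"finish_reason": "abort"}                              # -> finish_reason="abort"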
@@ -185,20 +185,8 @@ message GenerateComplete {
   // Final output
   repeated uint32 output_ids = 1;

-  // Finish reason
-  enum FinishReason {
-    // The model generated a stop sequence.
-    STOP = 0;
-    // The model reached the maximum generation length.
-    LENGTH = 1;
-    // The model generated an end-of-sequence (EOS) token.
-    EOS_TOKEN = 2;
-    // The model generated a user-provided stop string.
-    STOP_STR = 3;
-    // The request was aborted by the user or system.
-    ABORT = 4;
-  }
-  FinishReason finish_reason = 2;
+  // Finish reason as OpenAI-compatible string ("stop", "length", "abort")
+  string finish_reason = 2;

   // Token usage counts
   int32 prompt_tokens = 3;
@@ -210,6 +198,12 @@ message GenerateComplete {
   // All hidden states if requested
   repeated HiddenStates all_hidden_states = 7;
+
+  // Matched stop information (for stop sequences)
+  oneof matched_stop {
+    uint32 matched_token_id = 8;
+    string matched_stop_str = 9;
+  }
 }

 message GenerateError {
......
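With regenerated Python stubs, finish_reason is a plain string and the matched stop arrives through the matched_stop oneof, which can be inspected with the standard protobuf WhichOneof API. A small sketch of reading the new fields; the import path and all field values are assumptions for illustration, only the message and field names come from the schema above:

from sglang.srt.grpc import sglang_scheduler_pb2  # import path is an assumption

complete = sglang_scheduler_pb2.GenerateComplete(
    output_ids=[11, 22, 33],
    finish_reason="stop",    # OpenAI-compatible string instead of the old enum
    prompt_tokens=12,
    completion_tokens=3,
    matched_stop_str="###",  # setting either member populates the oneof
)

which = complete.WhichOneof("matched_stop")  # "matched_token_id", "matched_stop_str", or None
if which == "matched_token_id":
    print("stopped on token id", complete.matched_token_id)
elif which == "matched_stop_str":
    print("stopped on stop string", complete.matched_stop_str)
else:
    print("no matched stop recorded")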
@@ -3,7 +3,6 @@ import datetime
 from google.protobuf import timestamp_pb2 as _timestamp_pb2
 from google.protobuf import struct_pb2 as _struct_pb2
 from google.protobuf.internal import containers as _containers
-from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
 from google.protobuf import descriptor as _descriptor
 from google.protobuf import message as _message
 from collections.abc import Iterable as _Iterable, Mapping as _Mapping
@@ -179,19 +178,7 @@ class GenerateStreamChunk(_message.Message):
     def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., logprobs: _Optional[_Union[LogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ...) -> None: ...

 class GenerateComplete(_message.Message):
-    __slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "all_logprobs", "all_hidden_states")
-    class FinishReason(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
-        __slots__ = ()
-        STOP: _ClassVar[GenerateComplete.FinishReason]
-        LENGTH: _ClassVar[GenerateComplete.FinishReason]
-        EOS_TOKEN: _ClassVar[GenerateComplete.FinishReason]
-        STOP_STR: _ClassVar[GenerateComplete.FinishReason]
-        ABORT: _ClassVar[GenerateComplete.FinishReason]
-    STOP: GenerateComplete.FinishReason
-    LENGTH: GenerateComplete.FinishReason
-    EOS_TOKEN: GenerateComplete.FinishReason
-    STOP_STR: GenerateComplete.FinishReason
-    ABORT: GenerateComplete.FinishReason
+    __slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "all_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str")
     OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int]
     FINISH_REASON_FIELD_NUMBER: _ClassVar[int]
     PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
@@ -199,14 +186,18 @@ class GenerateComplete(_message.Message):
     CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
     ALL_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
     ALL_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
+    MATCHED_TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
+    MATCHED_STOP_STR_FIELD_NUMBER: _ClassVar[int]
     output_ids: _containers.RepeatedScalarFieldContainer[int]
-    finish_reason: GenerateComplete.FinishReason
+    finish_reason: str
     prompt_tokens: int
     completion_tokens: int
     cached_tokens: int
     all_logprobs: _containers.RepeatedCompositeFieldContainer[LogProbs]
     all_hidden_states: _containers.RepeatedCompositeFieldContainer[HiddenStates]
-    def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[_Union[GenerateComplete.FinishReason, str]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., all_logprobs: _Optional[_Iterable[_Union[LogProbs, _Mapping]]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ...) -> None: ...
+    matched_token_id: int
+    matched_stop_str: str
+    def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., all_logprobs: _Optional[_Iterable[_Union[LogProbs, _Mapping]]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ...) -> None: ...

 class GenerateError(_message.Message):
     __slots__ = ("message", "http_status_code", "details")
......
@@ -185,20 +185,8 @@ message GenerateComplete {
   // Final output
   repeated uint32 output_ids = 1;

-  // Finish reason
-  enum FinishReason {
-    // The model generated a stop sequence.
-    STOP = 0;
-    // The model reached the maximum generation length.
-    LENGTH = 1;
-    // The model generated an end-of-sequence (EOS) token.
-    EOS_TOKEN = 2;
-    // The model generated a user-provided stop string.
-    STOP_STR = 3;
-    // The request was aborted by the user or system.
-    ABORT = 4;
-  }
-  FinishReason finish_reason = 2;
+  // Finish reason as OpenAI-compatible string ("stop", "length", "abort")
+  string finish_reason = 2;

   // Token usage counts
   int32 prompt_tokens = 3;
@@ -210,6 +198,12 @@ message GenerateComplete {
   // All hidden states if requested
   repeated HiddenStates all_hidden_states = 7;
+
+  // Matched stop information (for stop sequences)
+  oneof matched_stop {
+    uint32 matched_token_id = 8;
+    string matched_stop_str = 9;
+  }
 }

 message GenerateError {
......
@@ -423,10 +423,25 @@ pub struct ChatCompletionResponse {
     pub system_fingerprint: Option<String>,
 }

+/// Response message structure for ChatCompletionResponse (different from request ChatMessage)
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct ChatCompletionMessage {
+    pub role: String, // Always "assistant" for responses
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub content: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_calls: Option<Vec<ToolCall>>,
+    /// Reasoning content for O1-style models (SGLang extension)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reasoning_content: Option<String>,
+    // Note: function_call is deprecated and not included
+    // Note: refusal, annotations, audio are not added yet
+}
+
 #[derive(Debug, Clone, Deserialize, Serialize)]
 pub struct ChatChoice {
     pub index: u32,
-    pub message: ChatMessage,
+    pub message: ChatCompletionMessage,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub logprobs: Option<ChatLogProbs>,
     pub finish_reason: Option<String>, // "stop", "length", "tool_calls", "content_filter", "function_call"
......
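Because every optional field on ChatCompletionMessage carries skip_serializing_if, unset fields are omitted from the serialized JSON. Roughly, the message body looks like the following (values are illustrative; shown as Python dicts for brevity):

# Plain text answer: only role and content are emitted.
{"role": "assistant", "content": "Hi!"}

# With separate_reasoning, the SGLang extension field appears alongside the content.
{"role": "assistant", "content": "Final answer.", "reasoning_content": "parsed reasoning text"}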
@@ -8,6 +8,7 @@ use axum::{
     extract::Request,
     http::{HeaderMap, StatusCode},
     response::{IntoResponse, Response},
+    Json,
 };
 use tracing::{debug, error, info, warn};
@@ -18,8 +19,9 @@ use crate::metrics::RouterMetrics;
 use crate::policies::PolicyRegistry;
 use crate::protocols::spec::ChatMessage;
 use crate::protocols::spec::{
-    ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
-    ResponsesGetParams, ResponsesRequest, StringOrArray, Tool, ToolChoice,
+    ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse,
+    CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest, ResponsesGetParams,
+    ResponsesRequest, StringOrArray, Tool, ToolChoice, Usage,
 };
 use crate::reasoning_parser::ParserFactory;
 use crate::routers::RouterTrait;
@@ -30,6 +32,7 @@ use crate::tokenizer::traits::Tokenizer;
 use crate::tokenizer::HuggingFaceTokenizer;
 use crate::tool_parser::ParserRegistry;
 use serde_json::Value;
+use std::time::{SystemTime, UNIX_EPOCH};
 use tokio_stream::StreamExt;
 use uuid::Uuid;
@@ -648,36 +651,99 @@ impl GrpcRouter {
             Err(e) => return fail_fmt("Failed to start generation: ", &e),
         };

-        // Get the single Complete response
-        let gen_response = match stream.next().await {
-            Some(Ok(r)) => r,
-            Some(Err(e)) => return fail_fmt("Failed to get GenerateResponse: ", &e),
-            None => return fail_str("No response from server"),
-        };
-
-        // Extract the expected variant early
-        let complete = match gen_response.response {
-            Some(proto::generate_response::Response::Complete(c)) => c,
-            Some(proto::generate_response::Response::Error(err)) => {
-                error!("Generation failed: {}", err.message);
-                return (
-                    StatusCode::INTERNAL_SERVER_ERROR,
-                    format!("Generation failed: {}", err.message),
-                )
-                    .into_response();
-            }
-            Some(proto::generate_response::Response::Chunk(_)) => {
-                return fail_str("Unexpected chunk response for non-streaming request")
-            }
-            None => return fail_str("Empty response from server"),
-        };
-
-        // Decode tokens
-        let outputs = match stop_decoder.process_tokens(&complete.output_ids) {
-            Ok(o) => o,
-            Err(e) => return fail_fmt("Failed to process tokens: ", &e),
-        };
+        // Collect all responses (for n>1 support)
+        let mut all_responses = Vec::new();
+        while let Some(response) = stream.next().await {
+            match response {
+                Ok(gen_response) => match gen_response.response {
+                    Some(proto::generate_response::Response::Complete(complete)) => {
+                        all_responses.push(complete);
+                    }
+                    Some(proto::generate_response::Response::Error(err)) => {
+                        error!("Generation failed for one choice: {}", err.message);
+                        return (
+                            StatusCode::INTERNAL_SERVER_ERROR,
+                            format!("Generation failed: {}", err.message),
+                        )
+                            .into_response();
+                    }
+                    Some(proto::generate_response::Response::Chunk(_)) => {
+                        return fail_str("Unexpected chunk response for non-streaming request")
+                    }
+                    None => return fail_str("Empty response from server"),
+                },
+                Err(e) => return fail_fmt("Failed to get GenerateResponse: ", &e),
+            }
+        }
+
+        if all_responses.is_empty() {
+            return fail_str("No responses from server");
+        }
+
+        // Process each response into a ChatChoice
+        let mut choices = Vec::new();
+        for (index, complete) in all_responses.iter().enumerate() {
+            match self
+                .process_single_choice(complete, index, original_request, &mut stop_decoder)
+                .await
+            {
+                Ok(choice) => choices.push(choice),
+                Err(e) => {
+                    error!("Failed to process choice {}: {}", index, e);
+                    return (
+                        StatusCode::INTERNAL_SERVER_ERROR,
+                        format!("Failed to process choice {}: {}", index, e),
+                    )
+                        .into_response();
+                }
+            }
+        }
+
+        // Aggregate usage information from all responses
+        let total_prompt_tokens: u32 = all_responses.iter().map(|r| r.prompt_tokens as u32).sum();
+        let total_completion_tokens: u32 = all_responses
+            .iter()
+            .map(|r| r.completion_tokens as u32)
+            .sum();
+
+        let usage = Usage {
+            prompt_tokens: total_prompt_tokens,
+            completion_tokens: total_completion_tokens,
+            total_tokens: total_prompt_tokens + total_completion_tokens,
+            completion_tokens_details: None,
+        };
+
+        // Build final ChatCompletionResponse
+        let response = ChatCompletionResponse {
+            id: format!("chatcmpl-{}", Uuid::new_v4()),
+            object: "chat.completion".to_string(),
+            created: SystemTime::now()
+                .duration_since(UNIX_EPOCH)
+                .unwrap_or_default()
+                .as_secs(),
+            model: original_request.model.clone(),
+            choices,
+            usage: Some(usage),
+            system_fingerprint: None,
+        };
+
+        // Serialize and return JSON response
+        Json(response).into_response()
+    }
+
+    /// Process a single GenerateComplete response into a ChatChoice
+    async fn process_single_choice(
+        &self,
+        complete: &proto::GenerateComplete,
+        index: usize,
+        original_request: &ChatCompletionRequest,
+        stop_decoder: &mut crate::tokenizer::stop::StopSequenceDecoder,
+    ) -> Result<ChatChoice, String> {
+        stop_decoder.reset();
+
+        // Decode tokens
+        let outputs = stop_decoder
+            .process_tokens(&complete.output_ids)
+            .map_err(|e| format!("Failed to process tokens: {}", e))?;

         // Accumulate text with early breaks
         let mut final_text = String::new();
         for output in outputs {
@@ -697,8 +763,119 @@ impl GrpcRouter {
             final_text.push_str(&t);
         }

-        // TODO: Create proper OpenAI-compatible response
-        (StatusCode::OK, format!("Final text: {}", final_text)).into_response()
+        // Step 1: Handle reasoning content parsing
+        let mut reasoning_text: Option<String> = None;
+        let mut processed_text = final_text;
+
+        // Check if reasoning parsing is enabled and separate_reasoning is requested
+        if original_request.separate_reasoning {
+            if let Ok(mut parser) = self
+                .reasoning_parser_factory
+                .create(&original_request.model)
+            {
+                match parser.detect_and_parse_reasoning(&processed_text) {
+                    Ok(result) => {
+                        if !result.reasoning_text.is_empty() {
+                            reasoning_text = Some(result.reasoning_text);
+                        }
+                        processed_text = result.normal_text;
+                    }
+                    Err(e) => {
+                        return Err(format!("Reasoning parsing error: {}", e));
+                    }
+                }
+            }
+        }
+
+        // Step 2: Handle tool call parsing
+        let mut tool_calls: Option<Vec<crate::protocols::spec::ToolCall>> = None;
+
+        // Check if tool calls should be processed
+        let tool_choice_enabled = !matches!(
+            &original_request.tool_choice,
+            Some(ToolChoice::Value(
+                crate::protocols::spec::ToolChoiceValue::None
+            ))
+        );
+
+        if tool_choice_enabled && original_request.tools.is_some() {
+            if let Some(parser) = self
+                .tool_parser_registry
+                .get_parser(&original_request.model)
+            {
+                match parser.parse_complete(&processed_text).await {
+                    Ok(parsed_tool_calls) => {
+                        if !parsed_tool_calls.is_empty() {
+                            let spec_tool_calls = parsed_tool_calls
+                                .into_iter()
+                                .map(|tc| crate::protocols::spec::ToolCall {
+                                    id: tc.id,
+                                    tool_type: "function".to_string(),
+                                    function: crate::protocols::spec::FunctionCallResponse {
+                                        name: tc.function.name,
+                                        arguments: Some(
+                                            serde_json::to_string(&tc.function.arguments)
+                                                .unwrap_or_else(|_| "{}".to_string()),
+                                        ),
+                                    },
+                                })
+                                .collect();
+                            tool_calls = Some(spec_tool_calls);
+                            processed_text = String::new();
+                        }
+                    }
+                    Err(e) => {
+                        error!("Tool call parsing error: {}", e);
+                        // Continue without tool calls rather than failing
+                    }
+                }
+            }
+        }
+
+        // Step 3: Use finish reason directly from proto (already OpenAI-compatible string)
+        let finish_reason_str = &complete.finish_reason;
+
+        // Override finish reason if we have tool calls
+        let final_finish_reason_str = if tool_calls.is_some() {
+            "tool_calls"
+        } else {
+            finish_reason_str
+        };
+
+        // Extract matched_stop information from proto
+        let matched_stop = match &complete.matched_stop {
+            Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => Some(
+                serde_json::Value::Number(serde_json::Number::from(*token_id)),
+            ),
+            Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => {
+                Some(serde_json::Value::String(stop_str.clone()))
+            }
+            None => None,
+        };
+
+        // Step 4: Build ChatCompletionMessage (proper response message type)
+        let chat_message = ChatCompletionMessage {
+            role: "assistant".to_string(),
+            content: if processed_text.is_empty() {
+                None
+            } else {
+                Some(processed_text)
+            },
+            tool_calls,
+            reasoning_content: reasoning_text,
+        };
+
+        // Step 5: Build ChatChoice
+        let choice = ChatChoice {
+            index: index as u32,
+            message: chat_message,
+            logprobs: None,
+            finish_reason: Some(final_finish_reason_str.to_string()),
+            matched_stop,
+            hidden_states: None,
+        };
+
+        Ok(choice)
     }
 }
......
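End to end, the non-streaming path now returns a real chat.completion JSON body (id, created, model, choices, aggregated usage) instead of the earlier plain-text placeholder. A rough smoke-test sketch, assuming the router exposes the standard OpenAI-compatible /v1/chat/completions route on localhost:30000; the port and model name are placeholders:

import requests

resp = requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "my-model",
        "messages": [{"role": "user", "content": "Say hi in one word."}],
        "stream": False,
        "n": 2,  # the collection loop above yields one choice per Complete response
    },
    timeout=60,
)
resp.raise_for_status()
body = resp.json()

for choice in body["choices"]:
    print(choice["index"], choice["finish_reason"], choice["message"]["content"])

# usage is summed across all n choices by the router
print(body["usage"])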