"vscode:/vscode.git/clone" did not exist on "b6d4702301365fc2a3dce6b4739ec534e09ce36f"
Unverified Commit 963175d5 authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

[router][grpc] Support streaming for v1/chat/completions (#11179)

parent 0618ad6d
......@@ -578,7 +578,7 @@ class GrpcRequestManager:
batch_out.cached_tokens[i] if batch_out.cached_tokens else 0
),
"finish_reason": (
str(batch_out.finished_reasons[i])
batch_out.finished_reasons[i]
if batch_out.finished_reasons[i]
else None
),
......
......@@ -112,7 +112,6 @@ def _launch_scheduler_process_only(
pp_rank,
None,
writer,
None,
),
)
......@@ -583,6 +582,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
cached_tokens=meta_info.get("cached_tokens", 0),
output_logprobs=output_logprobs_proto,
input_logprobs=input_logprobs_proto,
index=output.get("index", 0),
),
)
......@@ -640,6 +640,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
cached_tokens=meta_info.get("cached_tokens", 0),
output_logprobs=output_logprobs_proto,
input_logprobs=input_logprobs_proto,
index=output.get("index", 0),
**matched_stop_kwargs,
),
)
......
......@@ -179,6 +179,9 @@ message GenerateStreamChunk {
// Input logprobs (if requested) - only in first chunk
InputLogProbs input_logprobs = 7;
// Index for ordering when n>1 (for parallel request multiplexing)
uint32 index = 8;
}
message GenerateComplete {
......@@ -207,6 +210,9 @@ message GenerateComplete {
// Input logprobs if requested (for prompt tokens)
InputLogProbs input_logprobs = 10;
// Index for ordering when n>1 (for parallel request multiplexing)
uint32 index = 11;
}
message GenerateError {
......
......@@ -160,7 +160,7 @@ class GenerateResponse(_message.Message):
def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ...
class GenerateStreamChunk(_message.Message):
__slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "hidden_states", "input_logprobs")
__slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "hidden_states", "input_logprobs", "index")
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
......@@ -168,6 +168,7 @@ class GenerateStreamChunk(_message.Message):
OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
INDEX_FIELD_NUMBER: _ClassVar[int]
token_ids: _containers.RepeatedScalarFieldContainer[int]
prompt_tokens: int
completion_tokens: int
......@@ -175,10 +176,11 @@ class GenerateStreamChunk(_message.Message):
output_logprobs: OutputLogProbs
hidden_states: _containers.RepeatedScalarFieldContainer[float]
input_logprobs: InputLogProbs
def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ...) -> None: ...
index: int
def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ..., index: _Optional[int] = ...) -> None: ...
class GenerateComplete(_message.Message):
__slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str", "input_logprobs")
__slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str", "input_logprobs", "index")
OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int]
FINISH_REASON_FIELD_NUMBER: _ClassVar[int]
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
......@@ -189,6 +191,7 @@ class GenerateComplete(_message.Message):
MATCHED_TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
MATCHED_STOP_STR_FIELD_NUMBER: _ClassVar[int]
INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
INDEX_FIELD_NUMBER: _ClassVar[int]
output_ids: _containers.RepeatedScalarFieldContainer[int]
finish_reason: str
prompt_tokens: int
......@@ -199,7 +202,8 @@ class GenerateComplete(_message.Message):
matched_token_id: int
matched_stop_str: str
input_logprobs: InputLogProbs
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ...) -> None: ...
index: int
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ..., index: _Optional[int] = ...) -> None: ...
class GenerateError(_message.Message):
__slots__ = ("message", "http_status_code", "details")
......
......@@ -192,7 +192,6 @@ fn create_large_chat_completion_request() -> ChatCompletionRequest {
content: Some(format!("Answer {}: This is a detailed response about topic {} that covers multiple aspects and provides comprehensive analysis of the interconnected systems you mentioned.", i, i)),
name: None,
tool_calls: None,
function_call: None,
reasoning_content: None,
});
}
......
......@@ -179,6 +179,9 @@ message GenerateStreamChunk {
// Input logprobs (if requested) - only in first chunk
InputLogProbs input_logprobs = 7;
// Index for ordering when n>1 (for parallel request multiplexing)
uint32 index = 8;
}
message GenerateComplete {
......@@ -207,6 +210,9 @@ message GenerateComplete {
// Input logprobs if requested (for prompt tokens)
InputLogProbs input_logprobs = 10;
// Index for ordering when n>1 (for parallel request multiplexing)
uint32 index = 11;
}
message GenerateError {
......
......@@ -72,8 +72,6 @@ pub enum ChatMessage {
name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<ToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
function_call: Option<FunctionCallResponse>,
/// Reasoning content for O1-style models (SGLang extension)
#[serde(skip_serializing_if = "Option::is_none")]
reasoning_content: Option<String>,
......@@ -140,8 +138,6 @@ pub struct ChatMessageDelta {
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCallDelta>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<FunctionCallDelta>,
/// Reasoning content delta for O1-style models (SGLang extension)
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_content: Option<String>,
......@@ -473,6 +469,8 @@ pub struct ChatStreamChoice {
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<ChatLogProbs>,
pub finish_reason: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub matched_stop: Option<Value>,
}
// Completions API request types (v1/completions) - DEPRECATED but still supported
......
......@@ -44,7 +44,7 @@ graph TB
end
subgraph Factory Layer
MID --> PF[ParserFactory]
MID --> PF[ReasoningParserFactory]
PF --> REG[ParserRegistry]
REG --> PM[Pattern Matching]
PM --> PP[Parser Pool]
......@@ -93,7 +93,7 @@ graph TB
```mermaid
sequenceDiagram
participant C as Client
participant F as ParserFactory
participant F as ReasoningParserFactory
participant R as Registry
participant P as Parser Pool
participant BP as BaseParser
......@@ -206,7 +206,7 @@ classDiagram
+new() Self
}
class ParserFactory {
class ReasoningParserFactory {
-registry: ParserRegistry
+new() Self
+get_pooled(model_id: &str) PooledParser
......@@ -240,7 +240,7 @@ classDiagram
Step3Parser o-- BaseReasoningParser
BaseReasoningParser o-- ParserConfig
ParserFactory o-- ParserRegistry
ReasoningParserFactory o-- ParserRegistry
ParserRegistry o-- ReasoningParser
```
......@@ -302,7 +302,7 @@ classDiagram
- Delegate to get_pooled_parser
- Case-insensitive comparison
**ParserFactory Methods**:
**ReasoningParserFactory Methods**:
1. **`new()`**:
- Register all built-in parsers
......@@ -437,7 +437,7 @@ impl ReasoningParser for MyModelParser {
**Step 2: Register in Factory**
```rust
// In factory.rs ParserFactory::new()
// In factory.rs ReasoningParserFactory::new()
registry.register_parser("mymodel", || {
Box::new(MyModelParser::new())
});
......
......@@ -128,11 +128,11 @@ impl Default for ParserRegistry {
/// Factory for creating reasoning parsers based on model type.
#[derive(Clone)]
pub struct ParserFactory {
pub struct ReasoningParserFactory {
registry: ParserRegistry,
}
impl ParserFactory {
impl ReasoningParserFactory {
/// Create a new factory with default parsers registered.
pub fn new() -> Self {
let registry = ParserRegistry::new();
......@@ -237,7 +237,7 @@ impl ParserFactory {
}
}
impl Default for ParserFactory {
impl Default for ReasoningParserFactory {
fn default() -> Self {
Self::new()
}
......@@ -249,35 +249,35 @@ mod tests {
#[test]
fn test_factory_creates_deepseek_r1() {
let factory = ParserFactory::new();
let factory = ReasoningParserFactory::new();
let parser = factory.create("deepseek-r1-distill").unwrap();
assert_eq!(parser.model_type(), "deepseek_r1");
}
#[test]
fn test_factory_creates_qwen3() {
let factory = ParserFactory::new();
let factory = ReasoningParserFactory::new();
let parser = factory.create("qwen3-7b").unwrap();
assert_eq!(parser.model_type(), "qwen3");
}
#[test]
fn test_factory_creates_kimi() {
let factory = ParserFactory::new();
let factory = ReasoningParserFactory::new();
let parser = factory.create("kimi-chat").unwrap();
assert_eq!(parser.model_type(), "kimi");
}
#[test]
fn test_factory_fallback_to_passthrough() {
let factory = ParserFactory::new();
let factory = ReasoningParserFactory::new();
let parser = factory.create("unknown-model").unwrap();
assert_eq!(parser.model_type(), "passthrough");
}
#[test]
fn test_case_insensitive_matching() {
let factory = ParserFactory::new();
let factory = ReasoningParserFactory::new();
let parser1 = factory.create("DeepSeek-R1").unwrap();
let parser2 = factory.create("QWEN3").unwrap();
let parser3 = factory.create("Kimi").unwrap();
......@@ -289,21 +289,21 @@ mod tests {
#[test]
fn test_step3_model() {
let factory = ParserFactory::new();
let factory = ReasoningParserFactory::new();
let step3 = factory.create("step3-model").unwrap();
assert_eq!(step3.model_type(), "step3");
}
#[test]
fn test_glm45_model() {
let factory = ParserFactory::new();
let factory = ReasoningParserFactory::new();
let glm45 = factory.create("glm45-v2").unwrap();
assert_eq!(glm45.model_type(), "glm45");
}
#[test]
fn test_pooled_parser_reuse() {
let factory = ParserFactory::new();
let factory = ReasoningParserFactory::new();
// Get the same parser twice - should be the same instance
let parser1 = factory.get_pooled("deepseek-r1");
......@@ -321,7 +321,7 @@ mod tests {
fn test_pooled_parser_concurrent_access() {
use std::thread;
let factory = ParserFactory::new();
let factory = ReasoningParserFactory::new();
let parser = factory.get_pooled("deepseek-r1");
// Spawn multiple threads that use the same parser
......@@ -347,7 +347,7 @@ mod tests {
#[test]
fn test_pool_clearing() {
let factory = ParserFactory::new();
let factory = ReasoningParserFactory::new();
// Get a pooled parser
let parser1 = factory.get_pooled("deepseek-r1");
......@@ -364,7 +364,7 @@ mod tests {
#[test]
fn test_passthrough_parser_pooling() {
let factory = ParserFactory::new();
let factory = ReasoningParserFactory::new();
// Unknown models should get passthrough parser
let parser1 = factory.get_pooled("unknown-model-1");
......@@ -383,7 +383,7 @@ mod tests {
use std::thread;
use std::time::Instant;
let factory = ParserFactory::new();
let factory = ReasoningParserFactory::new();
let num_threads = 100;
let requests_per_thread = 50;
let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"];
......@@ -527,7 +527,7 @@ mod tests {
fn test_concurrent_pool_modifications() {
use std::thread;
let factory = ParserFactory::new();
let factory = ReasoningParserFactory::new();
let mut handles = vec![];
// Thread 1: Continuously get parsers
......
......@@ -2,7 +2,7 @@ pub mod factory;
pub mod parsers;
pub mod traits;
pub use factory::{ParserFactory, ParserRegistry, PooledParser};
pub use factory::{ParserRegistry, PooledParser, ReasoningParserFactory};
pub use parsers::{
BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, Qwen3Parser,
QwenThinkingParser, Step3Parser,
......
......@@ -4,7 +4,7 @@ use crate::config::types::RetryConfig;
use crate::core::{WorkerRegistry, WorkerType};
use crate::metrics::RouterMetrics;
use crate::policies::PolicyRegistry;
use crate::reasoning_parser::ParserFactory;
use crate::reasoning_parser::ReasoningParserFactory;
use crate::routers::RouterTrait;
use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ToolParserFactory;
......@@ -24,7 +24,7 @@ pub struct GrpcPDRouter {
worker_registry: Arc<WorkerRegistry>,
policy_registry: Arc<PolicyRegistry>,
tokenizer: Arc<dyn Tokenizer>,
reasoning_parser_factory: ParserFactory,
reasoning_parser_factory: ReasoningParserFactory,
tool_parser_factory: ToolParserFactory,
dp_aware: bool,
......
This diff is collapsed.
......@@ -15,7 +15,7 @@ use crate::{
},
worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse},
},
reasoning_parser::ParserFactory,
reasoning_parser::ReasoningParserFactory,
routers::{router_manager::RouterManager, RouterTrait},
service_discovery::{start_service_discovery, ServiceDiscoveryConfig},
tokenizer::{factory as tokenizer_factory, traits::Tokenizer},
......@@ -45,7 +45,7 @@ pub struct AppContext {
pub router_config: RouterConfig,
pub rate_limiter: Arc<TokenBucket>,
pub tokenizer: Option<Arc<dyn Tokenizer>>,
pub reasoning_parser_factory: Option<ParserFactory>,
pub reasoning_parser_factory: Option<ReasoningParserFactory>,
pub tool_parser_factory: Option<ToolParserFactory>,
pub worker_registry: Arc<WorkerRegistry>,
pub policy_registry: Arc<PolicyRegistry>,
......@@ -79,7 +79,7 @@ impl AppContext {
tokenizer_factory::create_tokenizer(&tokenizer_path)
.map_err(|e| format!("Failed to create tokenizer: {e}"))?,
);
let reasoning_parser_factory = Some(ParserFactory::new());
let reasoning_parser_factory = Some(ReasoningParserFactory::new());
let tool_parser_factory = Some(ToolParserFactory::new());
(tokenizer, reasoning_parser_factory, tool_parser_factory)
......
......@@ -123,12 +123,7 @@ impl DeepSeekParser {
let arguments = serde_json::to_string(&args)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
// Generate ID
let id = format!("deepseek_call_{}", uuid::Uuid::new_v4());
Ok(ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall {
name: func_name.to_string(),
arguments,
......@@ -320,4 +315,8 @@ impl ToolParser for DeepSeekParser {
fn detect_format(&self, text: &str) -> bool {
self.has_tool_markers(text)
}
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
}
}
......@@ -129,12 +129,7 @@ impl Glm4MoeParser {
let arguments_str = serde_json::to_string(&arguments)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
// Generate ID
let id = format!("glm4_call_{}", uuid::Uuid::new_v4());
Ok(Some(ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall {
name: func_name.to_string(),
arguments: arguments_str,
......@@ -321,4 +316,8 @@ impl ToolParser for Glm4MoeParser {
fn detect_format(&self, text: &str) -> bool {
self.has_tool_markers(text)
}
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
}
}
......@@ -113,12 +113,7 @@ impl ToolParser for GptOssParser {
}
};
// Generate unique ID
let id = format!("gpt_oss_call_{}", uuid::Uuid::new_v4());
tools.push(ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall {
name: function_name,
arguments,
......
......@@ -14,6 +14,48 @@ pub fn get_tool_indices(tools: &[Tool]) -> HashMap<String, usize> {
.collect()
}
/// Get unstreamed tool call arguments
/// Returns tool call items for arguments that have been parsed but not yet streamed
/// This ensures tool calls are properly completed even if the model generates final arguments in the last chunk
pub fn get_unstreamed_args(
prev_tool_call_arr: &[Value],
streamed_args_for_tool: &[String],
) -> Option<Vec<ToolCallItem>> {
// Check if we have tool calls being tracked
if prev_tool_call_arr.is_empty() || streamed_args_for_tool.is_empty() {
return None;
}
// Get the last tool call that was being processed
let tool_index = prev_tool_call_arr.len() - 1;
if tool_index >= streamed_args_for_tool.len() {
return None;
}
// Get expected vs actual arguments
let expected_args = prev_tool_call_arr[tool_index].get("arguments")?;
let expected_str = serde_json::to_string(expected_args).ok()?;
let actual_str = &streamed_args_for_tool[tool_index];
// Check if there are remaining arguments to send
let remaining = if expected_str.starts_with(actual_str) {
&expected_str[actual_str.len()..]
} else {
return None;
};
if remaining.is_empty() {
return None;
}
// Return the remaining arguments as a ToolCallItem
Some(vec![ToolCallItem {
tool_index,
name: None, // No name for argument deltas
parameters: remaining.to_string(),
}])
}
/// Check if a buffer ends with a partial occurrence of a token
/// Returns Some(length) if there's a partial match, None otherwise
pub fn ends_with_partial_token(buffer: &str, token: &str) -> Option<usize> {
......
......@@ -8,7 +8,7 @@ use crate::tool_parser::{
parsers::helpers,
partial_json::PartialJson,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall},
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
};
/// JSON format parser for tool calls
......@@ -136,16 +136,7 @@ impl JsonParser {
let arguments = serde_json::to_string(args)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
// Generate a unique ID if not provided
let id = obj
.get("id")
.and_then(|v| v.as_str())
.map(String::from)
.unwrap_or_else(|| format!("call_{}", uuid::Uuid::new_v4()));
Ok(Some(ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall {
name: name.to_string(),
arguments,
......@@ -274,4 +265,8 @@ impl ToolParser for JsonParser {
let trimmed = text.trim();
(trimmed.starts_with('[') || trimmed.starts_with('{')) && trimmed.contains(r#""name""#)
}
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
}
}
......@@ -131,12 +131,7 @@ impl ToolParser for KimiK2Parser {
// Try to parse JSON arguments
match serde_json::from_str::<serde_json::Value>(function_args) {
Ok(_) => {
// Generate unique ID
let id = format!("kimi_call_{}", uuid::Uuid::new_v4());
tools.push(ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall {
name: func_name,
arguments: function_args.to_string(),
......@@ -339,4 +334,8 @@ impl ToolParser for KimiK2Parser {
fn detect_format(&self, text: &str) -> bool {
self.has_tool_markers(text) || text.contains("<|tool_call_begin|>")
}
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment