"docs/git@developer.sourcefind.cn:change/sglang.git" did not exist on "bcc213df61bcfe6a9805ee80076352112c56e5ac"
Unverified Commit 963175d5 authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

[router][grpc] Support streaming for v1/chat/completions (#11179)

parent 0618ad6d
...@@ -578,7 +578,7 @@ class GrpcRequestManager: ...@@ -578,7 +578,7 @@ class GrpcRequestManager:
batch_out.cached_tokens[i] if batch_out.cached_tokens else 0 batch_out.cached_tokens[i] if batch_out.cached_tokens else 0
), ),
"finish_reason": ( "finish_reason": (
str(batch_out.finished_reasons[i]) batch_out.finished_reasons[i]
if batch_out.finished_reasons[i] if batch_out.finished_reasons[i]
else None else None
), ),
......
...@@ -112,7 +112,6 @@ def _launch_scheduler_process_only( ...@@ -112,7 +112,6 @@ def _launch_scheduler_process_only(
pp_rank, pp_rank,
None, None,
writer, writer,
None,
), ),
) )
...@@ -583,6 +582,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) ...@@ -583,6 +582,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
cached_tokens=meta_info.get("cached_tokens", 0), cached_tokens=meta_info.get("cached_tokens", 0),
output_logprobs=output_logprobs_proto, output_logprobs=output_logprobs_proto,
input_logprobs=input_logprobs_proto, input_logprobs=input_logprobs_proto,
index=output.get("index", 0),
), ),
) )
...@@ -640,6 +640,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer) ...@@ -640,6 +640,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
cached_tokens=meta_info.get("cached_tokens", 0), cached_tokens=meta_info.get("cached_tokens", 0),
output_logprobs=output_logprobs_proto, output_logprobs=output_logprobs_proto,
input_logprobs=input_logprobs_proto, input_logprobs=input_logprobs_proto,
index=output.get("index", 0),
**matched_stop_kwargs, **matched_stop_kwargs,
), ),
) )
......
...@@ -179,6 +179,9 @@ message GenerateStreamChunk { ...@@ -179,6 +179,9 @@ message GenerateStreamChunk {
// Input logprobs (if requested) - only in first chunk // Input logprobs (if requested) - only in first chunk
InputLogProbs input_logprobs = 7; InputLogProbs input_logprobs = 7;
// Index for ordering when n>1 (for parallel request multiplexing)
uint32 index = 8;
} }
message GenerateComplete { message GenerateComplete {
...@@ -207,6 +210,9 @@ message GenerateComplete { ...@@ -207,6 +210,9 @@ message GenerateComplete {
// Input logprobs if requested (for prompt tokens) // Input logprobs if requested (for prompt tokens)
InputLogProbs input_logprobs = 10; InputLogProbs input_logprobs = 10;
// Index for ordering when n>1 (for parallel request multiplexing)
uint32 index = 11;
} }
message GenerateError { message GenerateError {
......
...@@ -160,7 +160,7 @@ class GenerateResponse(_message.Message): ...@@ -160,7 +160,7 @@ class GenerateResponse(_message.Message):
def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ... def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ...
class GenerateStreamChunk(_message.Message): class GenerateStreamChunk(_message.Message):
__slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "hidden_states", "input_logprobs") __slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "hidden_states", "input_logprobs", "index")
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int] TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int] PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int] COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
...@@ -168,6 +168,7 @@ class GenerateStreamChunk(_message.Message): ...@@ -168,6 +168,7 @@ class GenerateStreamChunk(_message.Message):
OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int] OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int] HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int] INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
INDEX_FIELD_NUMBER: _ClassVar[int]
token_ids: _containers.RepeatedScalarFieldContainer[int] token_ids: _containers.RepeatedScalarFieldContainer[int]
prompt_tokens: int prompt_tokens: int
completion_tokens: int completion_tokens: int
...@@ -175,10 +176,11 @@ class GenerateStreamChunk(_message.Message): ...@@ -175,10 +176,11 @@ class GenerateStreamChunk(_message.Message):
output_logprobs: OutputLogProbs output_logprobs: OutputLogProbs
hidden_states: _containers.RepeatedScalarFieldContainer[float] hidden_states: _containers.RepeatedScalarFieldContainer[float]
input_logprobs: InputLogProbs input_logprobs: InputLogProbs
def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ...) -> None: ... index: int
def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ..., index: _Optional[int] = ...) -> None: ...
class GenerateComplete(_message.Message): class GenerateComplete(_message.Message):
__slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str", "input_logprobs") __slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str", "input_logprobs", "index")
OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int] OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int]
FINISH_REASON_FIELD_NUMBER: _ClassVar[int] FINISH_REASON_FIELD_NUMBER: _ClassVar[int]
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int] PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
...@@ -189,6 +191,7 @@ class GenerateComplete(_message.Message): ...@@ -189,6 +191,7 @@ class GenerateComplete(_message.Message):
MATCHED_TOKEN_ID_FIELD_NUMBER: _ClassVar[int] MATCHED_TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
MATCHED_STOP_STR_FIELD_NUMBER: _ClassVar[int] MATCHED_STOP_STR_FIELD_NUMBER: _ClassVar[int]
INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int] INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
INDEX_FIELD_NUMBER: _ClassVar[int]
output_ids: _containers.RepeatedScalarFieldContainer[int] output_ids: _containers.RepeatedScalarFieldContainer[int]
finish_reason: str finish_reason: str
prompt_tokens: int prompt_tokens: int
...@@ -199,7 +202,8 @@ class GenerateComplete(_message.Message): ...@@ -199,7 +202,8 @@ class GenerateComplete(_message.Message):
matched_token_id: int matched_token_id: int
matched_stop_str: str matched_stop_str: str
input_logprobs: InputLogProbs input_logprobs: InputLogProbs
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ...) -> None: ... index: int
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ..., index: _Optional[int] = ...) -> None: ...
class GenerateError(_message.Message): class GenerateError(_message.Message):
__slots__ = ("message", "http_status_code", "details") __slots__ = ("message", "http_status_code", "details")
......
...@@ -192,7 +192,6 @@ fn create_large_chat_completion_request() -> ChatCompletionRequest { ...@@ -192,7 +192,6 @@ fn create_large_chat_completion_request() -> ChatCompletionRequest {
content: Some(format!("Answer {}: This is a detailed response about topic {} that covers multiple aspects and provides comprehensive analysis of the interconnected systems you mentioned.", i, i)), content: Some(format!("Answer {}: This is a detailed response about topic {} that covers multiple aspects and provides comprehensive analysis of the interconnected systems you mentioned.", i, i)),
name: None, name: None,
tool_calls: None, tool_calls: None,
function_call: None,
reasoning_content: None, reasoning_content: None,
}); });
} }
......
...@@ -179,6 +179,9 @@ message GenerateStreamChunk { ...@@ -179,6 +179,9 @@ message GenerateStreamChunk {
// Input logprobs (if requested) - only in first chunk // Input logprobs (if requested) - only in first chunk
InputLogProbs input_logprobs = 7; InputLogProbs input_logprobs = 7;
// Index for ordering when n>1 (for parallel request multiplexing)
uint32 index = 8;
} }
message GenerateComplete { message GenerateComplete {
...@@ -207,6 +210,9 @@ message GenerateComplete { ...@@ -207,6 +210,9 @@ message GenerateComplete {
// Input logprobs if requested (for prompt tokens) // Input logprobs if requested (for prompt tokens)
InputLogProbs input_logprobs = 10; InputLogProbs input_logprobs = 10;
// Index for ordering when n>1 (for parallel request multiplexing)
uint32 index = 11;
} }
message GenerateError { message GenerateError {
......
...@@ -72,8 +72,6 @@ pub enum ChatMessage { ...@@ -72,8 +72,6 @@ pub enum ChatMessage {
name: Option<String>, name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<ToolCall>>, tool_calls: Option<Vec<ToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
function_call: Option<FunctionCallResponse>,
/// Reasoning content for O1-style models (SGLang extension) /// Reasoning content for O1-style models (SGLang extension)
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
reasoning_content: Option<String>, reasoning_content: Option<String>,
...@@ -140,8 +138,6 @@ pub struct ChatMessageDelta { ...@@ -140,8 +138,6 @@ pub struct ChatMessageDelta {
pub content: Option<String>, pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCallDelta>>, pub tool_calls: Option<Vec<ToolCallDelta>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<FunctionCallDelta>,
/// Reasoning content delta for O1-style models (SGLang extension) /// Reasoning content delta for O1-style models (SGLang extension)
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_content: Option<String>, pub reasoning_content: Option<String>,
...@@ -473,6 +469,8 @@ pub struct ChatStreamChoice { ...@@ -473,6 +469,8 @@ pub struct ChatStreamChoice {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<ChatLogProbs>, pub logprobs: Option<ChatLogProbs>,
pub finish_reason: Option<String>, pub finish_reason: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub matched_stop: Option<Value>,
} }
// Completions API request types (v1/completions) - DEPRECATED but still supported // Completions API request types (v1/completions) - DEPRECATED but still supported
......
...@@ -44,7 +44,7 @@ graph TB ...@@ -44,7 +44,7 @@ graph TB
end end
subgraph Factory Layer subgraph Factory Layer
MID --> PF[ParserFactory] MID --> PF[ReasoningParserFactory]
PF --> REG[ParserRegistry] PF --> REG[ParserRegistry]
REG --> PM[Pattern Matching] REG --> PM[Pattern Matching]
PM --> PP[Parser Pool] PM --> PP[Parser Pool]
...@@ -93,7 +93,7 @@ graph TB ...@@ -93,7 +93,7 @@ graph TB
```mermaid ```mermaid
sequenceDiagram sequenceDiagram
participant C as Client participant C as Client
participant F as ParserFactory participant F as ReasoningParserFactory
participant R as Registry participant R as Registry
participant P as Parser Pool participant P as Parser Pool
participant BP as BaseParser participant BP as BaseParser
...@@ -206,7 +206,7 @@ classDiagram ...@@ -206,7 +206,7 @@ classDiagram
+new() Self +new() Self
} }
class ParserFactory { class ReasoningParserFactory {
-registry: ParserRegistry -registry: ParserRegistry
+new() Self +new() Self
+get_pooled(model_id: &str) PooledParser +get_pooled(model_id: &str) PooledParser
...@@ -240,7 +240,7 @@ classDiagram ...@@ -240,7 +240,7 @@ classDiagram
Step3Parser o-- BaseReasoningParser Step3Parser o-- BaseReasoningParser
BaseReasoningParser o-- ParserConfig BaseReasoningParser o-- ParserConfig
ParserFactory o-- ParserRegistry ReasoningParserFactory o-- ParserRegistry
ParserRegistry o-- ReasoningParser ParserRegistry o-- ReasoningParser
``` ```
...@@ -302,7 +302,7 @@ classDiagram ...@@ -302,7 +302,7 @@ classDiagram
- Delegate to get_pooled_parser - Delegate to get_pooled_parser
- Case-insensitive comparison - Case-insensitive comparison
**ParserFactory Methods**: **ReasoningParserFactory Methods**:
1. **`new()`**: 1. **`new()`**:
- Register all built-in parsers - Register all built-in parsers
...@@ -437,7 +437,7 @@ impl ReasoningParser for MyModelParser { ...@@ -437,7 +437,7 @@ impl ReasoningParser for MyModelParser {
**Step 2: Register in Factory** **Step 2: Register in Factory**
```rust ```rust
// In factory.rs ParserFactory::new() // In factory.rs ReasoningParserFactory::new()
registry.register_parser("mymodel", || { registry.register_parser("mymodel", || {
Box::new(MyModelParser::new()) Box::new(MyModelParser::new())
}); });
......
...@@ -128,11 +128,11 @@ impl Default for ParserRegistry { ...@@ -128,11 +128,11 @@ impl Default for ParserRegistry {
/// Factory for creating reasoning parsers based on model type. /// Factory for creating reasoning parsers based on model type.
#[derive(Clone)] #[derive(Clone)]
pub struct ParserFactory { pub struct ReasoningParserFactory {
registry: ParserRegistry, registry: ParserRegistry,
} }
impl ParserFactory { impl ReasoningParserFactory {
/// Create a new factory with default parsers registered. /// Create a new factory with default parsers registered.
pub fn new() -> Self { pub fn new() -> Self {
let registry = ParserRegistry::new(); let registry = ParserRegistry::new();
...@@ -237,7 +237,7 @@ impl ParserFactory { ...@@ -237,7 +237,7 @@ impl ParserFactory {
} }
} }
impl Default for ParserFactory { impl Default for ReasoningParserFactory {
fn default() -> Self { fn default() -> Self {
Self::new() Self::new()
} }
...@@ -249,35 +249,35 @@ mod tests { ...@@ -249,35 +249,35 @@ mod tests {
#[test] #[test]
fn test_factory_creates_deepseek_r1() { fn test_factory_creates_deepseek_r1() {
let factory = ParserFactory::new(); let factory = ReasoningParserFactory::new();
let parser = factory.create("deepseek-r1-distill").unwrap(); let parser = factory.create("deepseek-r1-distill").unwrap();
assert_eq!(parser.model_type(), "deepseek_r1"); assert_eq!(parser.model_type(), "deepseek_r1");
} }
#[test] #[test]
fn test_factory_creates_qwen3() { fn test_factory_creates_qwen3() {
let factory = ParserFactory::new(); let factory = ReasoningParserFactory::new();
let parser = factory.create("qwen3-7b").unwrap(); let parser = factory.create("qwen3-7b").unwrap();
assert_eq!(parser.model_type(), "qwen3"); assert_eq!(parser.model_type(), "qwen3");
} }
#[test] #[test]
fn test_factory_creates_kimi() { fn test_factory_creates_kimi() {
let factory = ParserFactory::new(); let factory = ReasoningParserFactory::new();
let parser = factory.create("kimi-chat").unwrap(); let parser = factory.create("kimi-chat").unwrap();
assert_eq!(parser.model_type(), "kimi"); assert_eq!(parser.model_type(), "kimi");
} }
#[test] #[test]
fn test_factory_fallback_to_passthrough() { fn test_factory_fallback_to_passthrough() {
let factory = ParserFactory::new(); let factory = ReasoningParserFactory::new();
let parser = factory.create("unknown-model").unwrap(); let parser = factory.create("unknown-model").unwrap();
assert_eq!(parser.model_type(), "passthrough"); assert_eq!(parser.model_type(), "passthrough");
} }
#[test] #[test]
fn test_case_insensitive_matching() { fn test_case_insensitive_matching() {
let factory = ParserFactory::new(); let factory = ReasoningParserFactory::new();
let parser1 = factory.create("DeepSeek-R1").unwrap(); let parser1 = factory.create("DeepSeek-R1").unwrap();
let parser2 = factory.create("QWEN3").unwrap(); let parser2 = factory.create("QWEN3").unwrap();
let parser3 = factory.create("Kimi").unwrap(); let parser3 = factory.create("Kimi").unwrap();
...@@ -289,21 +289,21 @@ mod tests { ...@@ -289,21 +289,21 @@ mod tests {
#[test] #[test]
fn test_step3_model() { fn test_step3_model() {
let factory = ParserFactory::new(); let factory = ReasoningParserFactory::new();
let step3 = factory.create("step3-model").unwrap(); let step3 = factory.create("step3-model").unwrap();
assert_eq!(step3.model_type(), "step3"); assert_eq!(step3.model_type(), "step3");
} }
#[test] #[test]
fn test_glm45_model() { fn test_glm45_model() {
let factory = ParserFactory::new(); let factory = ReasoningParserFactory::new();
let glm45 = factory.create("glm45-v2").unwrap(); let glm45 = factory.create("glm45-v2").unwrap();
assert_eq!(glm45.model_type(), "glm45"); assert_eq!(glm45.model_type(), "glm45");
} }
#[test] #[test]
fn test_pooled_parser_reuse() { fn test_pooled_parser_reuse() {
let factory = ParserFactory::new(); let factory = ReasoningParserFactory::new();
// Get the same parser twice - should be the same instance // Get the same parser twice - should be the same instance
let parser1 = factory.get_pooled("deepseek-r1"); let parser1 = factory.get_pooled("deepseek-r1");
...@@ -321,7 +321,7 @@ mod tests { ...@@ -321,7 +321,7 @@ mod tests {
fn test_pooled_parser_concurrent_access() { fn test_pooled_parser_concurrent_access() {
use std::thread; use std::thread;
let factory = ParserFactory::new(); let factory = ReasoningParserFactory::new();
let parser = factory.get_pooled("deepseek-r1"); let parser = factory.get_pooled("deepseek-r1");
// Spawn multiple threads that use the same parser // Spawn multiple threads that use the same parser
...@@ -347,7 +347,7 @@ mod tests { ...@@ -347,7 +347,7 @@ mod tests {
#[test] #[test]
fn test_pool_clearing() { fn test_pool_clearing() {
let factory = ParserFactory::new(); let factory = ReasoningParserFactory::new();
// Get a pooled parser // Get a pooled parser
let parser1 = factory.get_pooled("deepseek-r1"); let parser1 = factory.get_pooled("deepseek-r1");
...@@ -364,7 +364,7 @@ mod tests { ...@@ -364,7 +364,7 @@ mod tests {
#[test] #[test]
fn test_passthrough_parser_pooling() { fn test_passthrough_parser_pooling() {
let factory = ParserFactory::new(); let factory = ReasoningParserFactory::new();
// Unknown models should get passthrough parser // Unknown models should get passthrough parser
let parser1 = factory.get_pooled("unknown-model-1"); let parser1 = factory.get_pooled("unknown-model-1");
...@@ -383,7 +383,7 @@ mod tests { ...@@ -383,7 +383,7 @@ mod tests {
use std::thread; use std::thread;
use std::time::Instant; use std::time::Instant;
let factory = ParserFactory::new(); let factory = ReasoningParserFactory::new();
let num_threads = 100; let num_threads = 100;
let requests_per_thread = 50; let requests_per_thread = 50;
let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"]; let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"];
...@@ -527,7 +527,7 @@ mod tests { ...@@ -527,7 +527,7 @@ mod tests {
fn test_concurrent_pool_modifications() { fn test_concurrent_pool_modifications() {
use std::thread; use std::thread;
let factory = ParserFactory::new(); let factory = ReasoningParserFactory::new();
let mut handles = vec![]; let mut handles = vec![];
// Thread 1: Continuously get parsers // Thread 1: Continuously get parsers
......
...@@ -2,7 +2,7 @@ pub mod factory; ...@@ -2,7 +2,7 @@ pub mod factory;
pub mod parsers; pub mod parsers;
pub mod traits; pub mod traits;
pub use factory::{ParserFactory, ParserRegistry, PooledParser}; pub use factory::{ParserRegistry, PooledParser, ReasoningParserFactory};
pub use parsers::{ pub use parsers::{
BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, Qwen3Parser, BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, Qwen3Parser,
QwenThinkingParser, Step3Parser, QwenThinkingParser, Step3Parser,
......
...@@ -4,7 +4,7 @@ use crate::config::types::RetryConfig; ...@@ -4,7 +4,7 @@ use crate::config::types::RetryConfig;
use crate::core::{WorkerRegistry, WorkerType}; use crate::core::{WorkerRegistry, WorkerType};
use crate::metrics::RouterMetrics; use crate::metrics::RouterMetrics;
use crate::policies::PolicyRegistry; use crate::policies::PolicyRegistry;
use crate::reasoning_parser::ParserFactory; use crate::reasoning_parser::ReasoningParserFactory;
use crate::routers::RouterTrait; use crate::routers::RouterTrait;
use crate::tokenizer::traits::Tokenizer; use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ToolParserFactory; use crate::tool_parser::ToolParserFactory;
...@@ -24,7 +24,7 @@ pub struct GrpcPDRouter { ...@@ -24,7 +24,7 @@ pub struct GrpcPDRouter {
worker_registry: Arc<WorkerRegistry>, worker_registry: Arc<WorkerRegistry>,
policy_registry: Arc<PolicyRegistry>, policy_registry: Arc<PolicyRegistry>,
tokenizer: Arc<dyn Tokenizer>, tokenizer: Arc<dyn Tokenizer>,
reasoning_parser_factory: ParserFactory, reasoning_parser_factory: ReasoningParserFactory,
tool_parser_factory: ToolParserFactory, tool_parser_factory: ToolParserFactory,
dp_aware: bool, dp_aware: bool,
......
This diff is collapsed.
...@@ -15,7 +15,7 @@ use crate::{ ...@@ -15,7 +15,7 @@ use crate::{
}, },
worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse}, worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse},
}, },
reasoning_parser::ParserFactory, reasoning_parser::ReasoningParserFactory,
routers::{router_manager::RouterManager, RouterTrait}, routers::{router_manager::RouterManager, RouterTrait},
service_discovery::{start_service_discovery, ServiceDiscoveryConfig}, service_discovery::{start_service_discovery, ServiceDiscoveryConfig},
tokenizer::{factory as tokenizer_factory, traits::Tokenizer}, tokenizer::{factory as tokenizer_factory, traits::Tokenizer},
...@@ -45,7 +45,7 @@ pub struct AppContext { ...@@ -45,7 +45,7 @@ pub struct AppContext {
pub router_config: RouterConfig, pub router_config: RouterConfig,
pub rate_limiter: Arc<TokenBucket>, pub rate_limiter: Arc<TokenBucket>,
pub tokenizer: Option<Arc<dyn Tokenizer>>, pub tokenizer: Option<Arc<dyn Tokenizer>>,
pub reasoning_parser_factory: Option<ParserFactory>, pub reasoning_parser_factory: Option<ReasoningParserFactory>,
pub tool_parser_factory: Option<ToolParserFactory>, pub tool_parser_factory: Option<ToolParserFactory>,
pub worker_registry: Arc<WorkerRegistry>, pub worker_registry: Arc<WorkerRegistry>,
pub policy_registry: Arc<PolicyRegistry>, pub policy_registry: Arc<PolicyRegistry>,
...@@ -79,7 +79,7 @@ impl AppContext { ...@@ -79,7 +79,7 @@ impl AppContext {
tokenizer_factory::create_tokenizer(&tokenizer_path) tokenizer_factory::create_tokenizer(&tokenizer_path)
.map_err(|e| format!("Failed to create tokenizer: {e}"))?, .map_err(|e| format!("Failed to create tokenizer: {e}"))?,
); );
let reasoning_parser_factory = Some(ParserFactory::new()); let reasoning_parser_factory = Some(ReasoningParserFactory::new());
let tool_parser_factory = Some(ToolParserFactory::new()); let tool_parser_factory = Some(ToolParserFactory::new());
(tokenizer, reasoning_parser_factory, tool_parser_factory) (tokenizer, reasoning_parser_factory, tool_parser_factory)
......
...@@ -123,12 +123,7 @@ impl DeepSeekParser { ...@@ -123,12 +123,7 @@ impl DeepSeekParser {
let arguments = serde_json::to_string(&args) let arguments = serde_json::to_string(&args)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
// Generate ID
let id = format!("deepseek_call_{}", uuid::Uuid::new_v4());
Ok(ToolCall { Ok(ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall { function: FunctionCall {
name: func_name.to_string(), name: func_name.to_string(),
arguments, arguments,
...@@ -320,4 +315,8 @@ impl ToolParser for DeepSeekParser { ...@@ -320,4 +315,8 @@ impl ToolParser for DeepSeekParser {
fn detect_format(&self, text: &str) -> bool { fn detect_format(&self, text: &str) -> bool {
self.has_tool_markers(text) self.has_tool_markers(text)
} }
/// Returns any argument text for the most recent tool call that has been
/// parsed but not yet streamed, or `None` when nothing is pending.
///
/// Thin delegation to `helpers::get_unstreamed_args` using this parser's
/// tracked tool calls and its per-tool streamed-argument buffers.
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
    let tracked_calls = &self.prev_tool_call_arr;
    let streamed_so_far = &self.streamed_args_for_tool;
    helpers::get_unstreamed_args(tracked_calls, streamed_so_far)
}
} }
...@@ -129,12 +129,7 @@ impl Glm4MoeParser { ...@@ -129,12 +129,7 @@ impl Glm4MoeParser {
let arguments_str = serde_json::to_string(&arguments) let arguments_str = serde_json::to_string(&arguments)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
// Generate ID
let id = format!("glm4_call_{}", uuid::Uuid::new_v4());
Ok(Some(ToolCall { Ok(Some(ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall { function: FunctionCall {
name: func_name.to_string(), name: func_name.to_string(),
arguments: arguments_str, arguments: arguments_str,
...@@ -321,4 +316,8 @@ impl ToolParser for Glm4MoeParser { ...@@ -321,4 +316,8 @@ impl ToolParser for Glm4MoeParser {
fn detect_format(&self, text: &str) -> bool { fn detect_format(&self, text: &str) -> bool {
self.has_tool_markers(text) self.has_tool_markers(text)
} }
/// Returns any argument text for the most recent tool call that has been
/// parsed but not yet streamed, or `None` when nothing is pending.
///
/// Thin delegation to `helpers::get_unstreamed_args` using this parser's
/// tracked tool calls and its per-tool streamed-argument buffers.
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
    let tracked_calls = &self.prev_tool_call_arr;
    let streamed_so_far = &self.streamed_args_for_tool;
    helpers::get_unstreamed_args(tracked_calls, streamed_so_far)
}
} }
...@@ -113,12 +113,7 @@ impl ToolParser for GptOssParser { ...@@ -113,12 +113,7 @@ impl ToolParser for GptOssParser {
} }
}; };
// Generate unique ID
let id = format!("gpt_oss_call_{}", uuid::Uuid::new_v4());
tools.push(ToolCall { tools.push(ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall { function: FunctionCall {
name: function_name, name: function_name,
arguments, arguments,
......
...@@ -14,6 +14,48 @@ pub fn get_tool_indices(tools: &[Tool]) -> HashMap<String, usize> { ...@@ -14,6 +14,48 @@ pub fn get_tool_indices(tools: &[Tool]) -> HashMap<String, usize> {
.collect() .collect()
} }
/// Get unstreamed tool call arguments.
///
/// Looks at the last entry of `prev_tool_call_arr`, serializes its
/// `"arguments"` value to JSON, and compares it with the text already
/// streamed for that tool (the entry of `streamed_args_for_tool` at the same
/// index). If the streamed text is a strict prefix of the serialized
/// arguments, the missing tail is returned as a single argument-delta
/// [`ToolCallItem`]. This is intended to let callers complete a tool call
/// whose final arguments only arrived in the last chunk.
///
/// Returns `None` when:
/// - no tool call is being tracked, or there is no streamed-argument entry
///   for the last tool call;
/// - the last tool call has no `"arguments"` key, or it cannot be serialized;
/// - the streamed text is not a prefix of the serialized arguments;
/// - nothing remains to be sent.
pub fn get_unstreamed_args(
    prev_tool_call_arr: &[Value],
    streamed_args_for_tool: &[String],
) -> Option<Vec<ToolCallItem>> {
    // Only the most recently tracked tool call can still have pending output.
    let last_call = prev_tool_call_arr.last()?;
    let tool_index = prev_tool_call_arr.len() - 1;

    // Checked lookup: covers both an empty buffer list and a list that is
    // shorter than the tool-call list.
    let already_streamed = streamed_args_for_tool.get(tool_index)?;

    // Expected (fully parsed) arguments, serialized the same way every time.
    let expected = serde_json::to_string(last_call.get("arguments")?).ok()?;

    // `strip_prefix` yields the unsent tail only when the streamed text is a
    // genuine prefix; an empty tail means everything was already streamed.
    let remaining = expected
        .strip_prefix(already_streamed.as_str())
        .filter(|tail| !tail.is_empty())?;

    Some(vec![ToolCallItem {
        tool_index,
        name: None, // No name for argument deltas
        parameters: remaining.to_string(),
    }])
}
/// Check if a buffer ends with a partial occurrence of a token /// Check if a buffer ends with a partial occurrence of a token
/// Returns Some(length) if there's a partial match, None otherwise /// Returns Some(length) if there's a partial match, None otherwise
pub fn ends_with_partial_token(buffer: &str, token: &str) -> Option<usize> { pub fn ends_with_partial_token(buffer: &str, token: &str) -> Option<usize> {
......
...@@ -8,7 +8,7 @@ use crate::tool_parser::{ ...@@ -8,7 +8,7 @@ use crate::tool_parser::{
parsers::helpers, parsers::helpers,
partial_json::PartialJson, partial_json::PartialJson,
traits::ToolParser, traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall}, types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
}; };
/// JSON format parser for tool calls /// JSON format parser for tool calls
...@@ -136,16 +136,7 @@ impl JsonParser { ...@@ -136,16 +136,7 @@ impl JsonParser {
let arguments = serde_json::to_string(args) let arguments = serde_json::to_string(args)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?; .map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
// Generate a unique ID if not provided
let id = obj
.get("id")
.and_then(|v| v.as_str())
.map(String::from)
.unwrap_or_else(|| format!("call_{}", uuid::Uuid::new_v4()));
Ok(Some(ToolCall { Ok(Some(ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall { function: FunctionCall {
name: name.to_string(), name: name.to_string(),
arguments, arguments,
...@@ -274,4 +265,8 @@ impl ToolParser for JsonParser { ...@@ -274,4 +265,8 @@ impl ToolParser for JsonParser {
let trimmed = text.trim(); let trimmed = text.trim();
(trimmed.starts_with('[') || trimmed.starts_with('{')) && trimmed.contains(r#""name""#) (trimmed.starts_with('[') || trimmed.starts_with('{')) && trimmed.contains(r#""name""#)
} }
/// Returns any argument text for the most recent tool call that has been
/// parsed but not yet streamed, or `None` when nothing is pending.
///
/// Thin delegation to `helpers::get_unstreamed_args` using this parser's
/// tracked tool calls and its per-tool streamed-argument buffers.
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
    let tracked_calls = &self.prev_tool_call_arr;
    let streamed_so_far = &self.streamed_args_for_tool;
    helpers::get_unstreamed_args(tracked_calls, streamed_so_far)
}
} }
...@@ -131,12 +131,7 @@ impl ToolParser for KimiK2Parser { ...@@ -131,12 +131,7 @@ impl ToolParser for KimiK2Parser {
// Try to parse JSON arguments // Try to parse JSON arguments
match serde_json::from_str::<serde_json::Value>(function_args) { match serde_json::from_str::<serde_json::Value>(function_args) {
Ok(_) => { Ok(_) => {
// Generate unique ID
let id = format!("kimi_call_{}", uuid::Uuid::new_v4());
tools.push(ToolCall { tools.push(ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall { function: FunctionCall {
name: func_name, name: func_name,
arguments: function_args.to_string(), arguments: function_args.to_string(),
...@@ -339,4 +334,8 @@ impl ToolParser for KimiK2Parser { ...@@ -339,4 +334,8 @@ impl ToolParser for KimiK2Parser {
fn detect_format(&self, text: &str) -> bool { fn detect_format(&self, text: &str) -> bool {
self.has_tool_markers(text) || text.contains("<|tool_call_begin|>") self.has_tool_markers(text) || text.contains("<|tool_call_begin|>")
} }
/// Returns any argument text for the most recent tool call that has been
/// parsed but not yet streamed, or `None` when nothing is pending.
///
/// Thin delegation to `helpers::get_unstreamed_args` using this parser's
/// tracked tool calls and its per-tool streamed-argument buffers.
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
    let tracked_calls = &self.prev_tool_call_arr;
    let streamed_so_far = &self.streamed_args_for_tool;
    helpers::get_unstreamed_args(tracked_calls, streamed_so_far)
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment