Unverified Commit 927975ea authored by Flora Feng's avatar Flora Feng Committed by GitHub
Browse files

[Parser] Migrate response api streaming to unified parser (#38755)


Signed-off-by: default avatarsfeng33 <4florafeng@gmail.com>
Signed-off-by: default avatarAndrew Xia <axia@meta.com>
parent 9ea7d670
...@@ -628,6 +628,31 @@ def _identity_increment(event): ...@@ -628,6 +628,31 @@ def _identity_increment(event):
return event return event
def _mock_parser_with_reasoning(serving, delta_sequence: list[DeltaMessage]):
"""Set up serving.parser so that it returns a mock parser instance
with a reasoning parser that returns the given delta_sequence.
The mock has reasoning_parser set (truthy) but tool_parser as None,
so the parser's parse_delta enters the reasoning-only branch.
"""
call_count = 0
def mock_parse_delta(**kwargs):
nonlocal call_count
if call_count >= len(delta_sequence):
return None
result = delta_sequence[call_count]
call_count += 1
return result
mock_parser_instance = MagicMock()
mock_parser_instance.reasoning_parser = MagicMock() # truthy
mock_parser_instance.tool_parser = None
mock_parser_instance.parse_delta = mock_parse_delta
mock_parser_instance.is_reasoning_end = MagicMock(return_value=False)
serving.parser = MagicMock(return_value=mock_parser_instance)
class TestStreamingReasoningToContentTransition: class TestStreamingReasoningToContentTransition:
"""Tests for _process_simple_streaming_events reasoning-to-content """Tests for _process_simple_streaming_events reasoning-to-content
transition, specifically the fix for mixed deltas that carry both transition, specifically the fix for mixed deltas that carry both
...@@ -646,27 +671,13 @@ class TestStreamingReasoningToContentTransition: ...@@ -646,27 +671,13 @@ class TestStreamingReasoningToContentTransition:
monkeypatch.setattr(envs, "VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT", False) monkeypatch.setattr(envs, "VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT", False)
serving = _make_serving_instance_with_reasoning() serving = _make_serving_instance_with_reasoning()
# Sequence of DeltaMessages the mock reasoning parser will return # Sequence of DeltaMessages the mock orchestrator will return
delta_sequence = [ delta_sequence = [
DeltaMessage(reasoning="thinking..."), DeltaMessage(reasoning="thinking..."),
DeltaMessage(reasoning=" end", content="hello"), # mixed delta DeltaMessage(reasoning=" end", content="hello"), # mixed delta
DeltaMessage(content=" world"), DeltaMessage(content=" world"),
] ]
call_count = 0 _mock_parser_with_reasoning(serving, delta_sequence)
def mock_extract_reasoning_streaming(**kwargs):
nonlocal call_count
result = delta_sequence[call_count]
call_count += 1
return result
# Mock the reasoning parser on the serving instance
mock_parser = MagicMock()
mock_parser.extract_reasoning_streaming = mock_extract_reasoning_streaming
mock_parser.extract_tool_calls_streaming = mock_extract_reasoning_streaming
serving.parser = MagicMock()
serving.parser.reasoning_parser_cls = MagicMock(return_value=mock_parser)
serving.parser.tool_parser_cls = MagicMock(return_value=mock_parser)
# Create contexts for each streaming chunk # Create contexts for each streaming chunk
contexts = [ contexts = [
_make_simple_context_with_output("chunk1", [10]), _make_simple_context_with_output("chunk1", [10]),
...@@ -734,20 +745,7 @@ class TestStreamingReasoningToContentTransition: ...@@ -734,20 +745,7 @@ class TestStreamingReasoningToContentTransition:
DeltaMessage(reasoning="thinking"), DeltaMessage(reasoning="thinking"),
DeltaMessage(content="answer"), DeltaMessage(content="answer"),
] ]
call_count = 0 _mock_parser_with_reasoning(serving, delta_sequence)
def mock_extract_reasoning_streaming(**kwargs):
nonlocal call_count
result = delta_sequence[call_count]
call_count += 1
return result
mock_parser = MagicMock()
mock_parser.extract_reasoning_streaming = mock_extract_reasoning_streaming
mock_parser.extract_tool_calls_streaming = mock_extract_reasoning_streaming
serving.parser = MagicMock()
serving.parser.reasoning_parser_cls = MagicMock(return_value=mock_parser)
serving.parser.tool_parser_cls = MagicMock(return_value=mock_parser)
contexts = [ contexts = [
_make_simple_context_with_output("chunk1", [10]), _make_simple_context_with_output("chunk1", [10]),
...@@ -809,20 +807,7 @@ class TestStreamingReasoningToContentTransition: ...@@ -809,20 +807,7 @@ class TestStreamingReasoningToContentTransition:
DeltaMessage(reasoning="step 1"), DeltaMessage(reasoning="step 1"),
DeltaMessage(reasoning=" step 2"), DeltaMessage(reasoning=" step 2"),
] ]
call_count = 0 _mock_parser_with_reasoning(serving, delta_sequence)
def mock_extract_reasoning_streaming(**kwargs):
nonlocal call_count
result = delta_sequence[call_count]
call_count += 1
return result
mock_parser = MagicMock()
mock_parser.extract_reasoning_streaming = mock_extract_reasoning_streaming
mock_parser.extract_tool_calls_streaming = mock_extract_reasoning_streaming
serving.parser = MagicMock()
serving.parser.reasoning_parser_cls = MagicMock(return_value=mock_parser)
serving.parser.tool_parser_cls = MagicMock(return_value=mock_parser)
contexts = [ contexts = [
_make_simple_context_with_output("chunk1", [10]), _make_simple_context_with_output("chunk1", [10]),
......
...@@ -1339,101 +1339,31 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1339,101 +1339,31 @@ class OpenAIServingResponses(OpenAIServing):
current_content_index = 0 current_content_index = 0
current_output_index = 0 current_output_index = 0
current_item_id = "" current_item_id = ""
reasoning_parser = None parser = self.parser(tokenizer, request.tools) if self.parser else None
if self.parser and self.parser.reasoning_parser_cls:
reasoning_parser = self.parser.reasoning_parser_cls(tokenizer)
tool_parser = None
if self.parser and self.parser.tool_parser_cls:
tool_parser = self.parser.tool_parser_cls(tokenizer, request.tools)
reasoning_ended = False
tool_call_text_started = False
previous_text = ""
previous_token_ids: list[int] = []
prompt_is_reasoning_end = None
first_delta_sent = False first_delta_sent = False
previous_delta_messages: list[DeltaMessage] = [] previous_delta_messages: list[DeltaMessage] = []
async for ctx in result_generator: async for ctx in result_generator:
assert isinstance(ctx, SimpleContext) assert isinstance(ctx, SimpleContext)
if ctx.last_output is None: if ctx.last_output is None:
continue continue
if reasoning_parser and prompt_is_reasoning_end is None:
prompt_is_reasoning_end = reasoning_parser.is_reasoning_end(
ctx.last_output.prompt_token_ids
)
if ctx.last_output.outputs: if ctx.last_output.outputs:
output = ctx.last_output.outputs[0] output = ctx.last_output.outputs[0]
# finish_reason='error' indicates a retryable error # finish_reason='error' indicates a retryable error
self._raise_if_error(output.finish_reason, request.request_id) self._raise_if_error(output.finish_reason, request.request_id)
delta_text = output.text delta_text = output.text
delta_token_ids = as_list(output.token_ids) delta_token_ids = as_list(output.token_ids)
current_text = previous_text + delta_text
current_token_ids = previous_token_ids + delta_token_ids if parser:
delta_message = parser.parse_delta(
if reasoning_parser and tool_parser:
if prompt_is_reasoning_end:
reasoning_ended = True
if not reasoning_ended:
delta_message = reasoning_parser.extract_reasoning_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids,
)
if reasoning_parser.is_reasoning_end(delta_token_ids):
reasoning_ended = True
current_token_ids = reasoning_parser.extract_content_ids(
delta_token_ids
)
if delta_message and delta_message.content:
current_text = delta_message.content
delta_message.content = None
else:
current_text = ""
if reasoning_ended:
if not tool_call_text_started:
tool_call_text_started = True
previous_text = ""
previous_token_ids = []
delta_text = current_text
delta_token_ids = current_token_ids
delta_message = tool_parser.extract_tool_calls_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids,
request=request, # type: ignore[arg-type]
)
elif reasoning_parser:
delta_message = reasoning_parser.extract_reasoning_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids,
)
elif tool_parser:
delta_message = tool_parser.extract_tool_calls_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text, delta_text=delta_text,
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids, delta_token_ids=delta_token_ids,
request=request, # type: ignore[arg-type] request=request,
prompt_token_ids=ctx.last_output.prompt_token_ids,
) )
else: else:
delta_message = DeltaMessage( delta_message = DeltaMessage(
content=output.text, content=output.text,
) )
previous_text = current_text
previous_token_ids = current_token_ids
if not delta_message: if not delta_message:
continue continue
if not first_delta_sent: if not first_delta_sent:
......
...@@ -5,6 +5,7 @@ import contextlib ...@@ -5,6 +5,7 @@ import contextlib
import json import json
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass, field
from functools import cached_property from functools import cached_property
from openai.types.responses import ( from openai.types.responses import (
...@@ -43,6 +44,17 @@ from vllm.utils import random_uuid ...@@ -43,6 +44,17 @@ from vllm.utils import random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
@dataclass
class StreamState:
"""Mutable state for ``Parser.parse_delta()``. One per stream."""
reasoning_ended: bool = False
tool_call_text_started: bool = False
prompt_reasoning_checked: bool = False
previous_text: str = ""
previous_token_ids: list[int] = field(default_factory=list)
class Parser: class Parser:
""" """
Abstract Parser class that unifies ReasoningParser and ToolParser into Abstract Parser class that unifies ReasoningParser and ToolParser into
...@@ -80,6 +92,7 @@ class Parser: ...@@ -80,6 +92,7 @@ class Parser:
self.model_tokenizer = tokenizer self.model_tokenizer = tokenizer
self._reasoning_parser: ReasoningParser | None = None self._reasoning_parser: ReasoningParser | None = None
self._tool_parser: ToolParser | None = None self._tool_parser: ToolParser | None = None
self._stream_state = StreamState()
@cached_property @cached_property
def vocab(self) -> dict[str, int]: def vocab(self) -> dict[str, int]:
...@@ -291,6 +304,18 @@ class Parser: ...@@ -291,6 +304,18 @@ class Parser:
A DeltaMessage with tool_calls field, or None. A DeltaMessage with tool_calls field, or None.
""" """
@abstractmethod
def parse_delta(
self,
delta_text: str,
delta_token_ids: list[int],
request: ChatCompletionRequest | ResponsesRequest,
prompt_token_ids: list[int] | None = None,
) -> DeltaMessage | None:
"""Parse a single streaming delta, orchestrating reasoning then
tool call extraction via internal stream state.
"""
class DelegatingParser(Parser): class DelegatingParser(Parser):
""" """
...@@ -524,6 +549,100 @@ class DelegatingParser(Parser): ...@@ -524,6 +549,100 @@ class DelegatingParser(Parser):
request, request,
) )
def is_reasoning_end(self, input_ids: list[int]) -> bool:
if self._reasoning_parser is None:
return False
return self._reasoning_parser.is_reasoning_end(input_ids)
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
if self._reasoning_parser is None:
return input_ids
return self._reasoning_parser.extract_content_ids(input_ids)
def parse_delta(
self,
delta_text: str,
delta_token_ids: list[int],
request: ChatCompletionRequest | ResponsesRequest,
prompt_token_ids: list[int] | None = None,
) -> DeltaMessage | None:
state = self._stream_state
if not state.prompt_reasoning_checked and prompt_token_ids is not None:
state.prompt_reasoning_checked = True
if self.is_reasoning_end(prompt_token_ids):
state.reasoning_ended = True
current_text = state.previous_text + delta_text
current_token_ids = state.previous_token_ids + delta_token_ids
delta_message: DeltaMessage | None = None
if self._reasoning_parser and self._tool_parser:
if not state.reasoning_ended:
delta_message = self.extract_reasoning_streaming(
previous_text=state.previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=state.previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids,
)
if self.is_reasoning_end(delta_token_ids):
state.reasoning_ended = True
current_token_ids = self.extract_content_ids(delta_token_ids)
if delta_message and delta_message.content:
current_text = delta_message.content
delta_message.content = None
else:
current_text = ""
if state.reasoning_ended:
if not state.tool_call_text_started:
state.tool_call_text_started = True
state.previous_text = ""
state.previous_token_ids = []
delta_text = current_text
delta_token_ids = current_token_ids
delta_message = self.extract_tool_calls_streaming(
previous_text=state.previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=state.previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids,
request=request, # type: ignore[arg-type]
)
elif self._reasoning_parser:
delta_message = self.extract_reasoning_streaming(
previous_text=state.previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=state.previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids,
)
elif self._tool_parser:
delta_message = self.extract_tool_calls_streaming(
previous_text=state.previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=state.previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids,
request=request, # type: ignore[arg-type]
)
else:
delta_message = DeltaMessage(content=delta_text)
state.previous_text = current_text
state.previous_token_ids = current_token_ids
return delta_message
class _WrappedParser(DelegatingParser): class _WrappedParser(DelegatingParser):
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment