"examples/vscode:/vscode.git/clone" did not exist on "3fd0ab3d74c9dc9e185580eb0a863f3c24465a67"
Unverified Commit 4e1bd700 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

fix(frontend): vllm processor works with stream_interval > 1 (#6816)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent 35f99f93
...@@ -8,8 +8,13 @@ from collections.abc import Sequence ...@@ -8,8 +8,13 @@ from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.engine.protocol import DeltaMessage, DeltaToolCall from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
)
from vllm.reasoning import ReasoningParser from vllm.reasoning import ReasoningParser
from vllm.renderers import ChatParams from vllm.renderers import ChatParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
...@@ -255,6 +260,11 @@ class StreamingPostProcessor: ...@@ -255,6 +260,11 @@ class StreamingPostProcessor:
self.previous_token_ids: list[int] = [] self.previous_token_ids: list[int] = []
self.reasoning_is_done = False self.reasoning_is_done = False
self.in_progress_tool_calls: dict[int, DeltaToolCall] = {} self.in_progress_tool_calls: dict[int, DeltaToolCall] = {}
# Buffer for post-reasoning tool text when </think> and <tool_call>
# arrive in the same chunk. The streaming tool parser cannot handle
# this correctly, so we accumulate text here and fall back to the
# non-streaming extract_tool_calls() once the buffer is complete.
self._tool_text_buffer: str | None = None
@staticmethod @staticmethod
def _merge_tool_call( def _merge_tool_call(
...@@ -290,6 +300,102 @@ class StreamingPostProcessor: ...@@ -290,6 +300,102 @@ class StreamingPostProcessor:
stripped = stripped.replace(marker, "") stripped = stripped.replace(marker, "")
return stripped.strip() == "" return stripped.strip() == ""
def _should_parse_tools(self) -> bool:
return (
self.tool_parser is not None
and self.request_for_sampling.tool_choice != "none"
)
@staticmethod
def _compose_delta_message(
reasoning: str | None, content: str | None
) -> DeltaMessage | None:
delta_message = DeltaMessage(reasoning=reasoning, content=content)
if not delta_message.reasoning and not delta_message.content:
return None
return delta_message
def _add_tool_call_from_extracted(self, index: int, tool_call: Any) -> None:
tool_delta = DeltaToolCall(
index=index,
type="function",
id=(tool_call.id if tool_call.id else make_tool_call_id()),
function=DeltaFunctionCall(
name=tool_call.function.name,
arguments=tool_call.function.arguments,
),
)
existing = self.in_progress_tool_calls.get(index)
self.in_progress_tool_calls[index] = self._merge_tool_call(existing, tool_delta)
def _extract_tool_calls_from_text(
self, text: str, *, saved_reasoning: str | None = None
) -> DeltaMessage | None:
if self.tool_parser is None:
return self._compose_delta_message(saved_reasoning, None)
extracted = self.tool_parser.extract_tool_calls(text, self.request_for_sampling)
if extracted.tools_called:
for i, tool_call in enumerate(extracted.tool_calls):
self._add_tool_call_from_extracted(i, tool_call)
return self._compose_delta_message(saved_reasoning, None)
return self._compose_delta_message(saved_reasoning, extracted.content or None)
def _extract_tool_calls_streaming(
self,
*,
current_text: str,
delta_text: str,
delta_token_ids: list[int],
current_token_ids: list[int],
) -> DeltaMessage | None:
if self.tool_parser is None:
return None
return self.tool_parser.extract_tool_calls_streaming(
previous_text=self.previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=self.previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids,
request=self.request_for_sampling,
)
def _merge_streaming_tool_calls(self, tool_calls: list[DeltaToolCall]) -> None:
for tool_delta in tool_calls:
existing = self.in_progress_tool_calls.get(tool_delta.index)
merged = self._merge_tool_call(existing, tool_delta)
self.in_progress_tool_calls[tool_delta.index] = merged
def _dump_in_progress_tool_calls(self) -> list[dict[str, Any]]:
return [
tool_call.model_dump(exclude_none=True)
for _, tool_call in self.in_progress_tool_calls.items()
]
def _emit_tool_calls_choice(self, output: Any) -> dict[str, Any]:
choice = {
"index": output.index,
"delta": {
"role": "assistant",
"tool_calls": self._dump_in_progress_tool_calls(),
},
"finish_reason": output.finish_reason,
"logprobs": output.logprobs,
}
self.in_progress_tool_calls.clear()
return choice
@staticmethod
def _build_choice(output: Any, delta: dict[str, Any]) -> dict[str, Any]:
return {
"index": output.index,
"delta": delta,
"finish_reason": output.finish_reason,
"logprobs": output.logprobs,
}
def process_output(self, output: Any) -> dict[str, Any] | None: def process_output(self, output: Any) -> dict[str, Any] | None:
delta_token_ids = list(output.token_ids or []) delta_token_ids = list(output.token_ids or [])
# vLLM output_processor already applies stop-token/stop-string trimming # vLLM output_processor already applies stop-token/stop-string trimming
...@@ -306,19 +412,36 @@ class StreamingPostProcessor: ...@@ -306,19 +412,36 @@ class StreamingPostProcessor:
delta = {} delta = {}
else: else:
return None return None
return { return self._build_choice(output, delta)
"index": output.index,
"delta": delta,
"finish_reason": output.finish_reason,
"logprobs": output.logprobs,
}
current_text = self.previous_text + delta_text current_text = self.previous_text + delta_text
current_token_ids = self.previous_token_ids + delta_token_ids current_token_ids = self.previous_token_ids + delta_token_ids
delta_message: DeltaMessage | None = DeltaMessage(content=delta_text) delta_message: DeltaMessage | None = DeltaMessage(content=delta_text)
if not self.reasoning_is_done and self.reasoning_parser: # ------------------------------------------------------------------
# Drain the tool-text buffer (populated when </think> and <tool_call>
# arrived in the same chunk). The streaming tool parser cannot
# handle that transition correctly, so we accumulate text here and
# use the non-streaming extract_tool_calls() once complete.
# ------------------------------------------------------------------
if self._tool_text_buffer is not None:
self._tool_text_buffer += delta_text
tool_call_end = getattr(self.tool_parser, "tool_call_end_token", None)
buffer_complete = (
tool_call_end and tool_call_end in self._tool_text_buffer
) or output.finish_reason
if buffer_complete:
buffered_text = self._tool_text_buffer
self._tool_text_buffer = None
delta_message = self._extract_tool_calls_from_text(buffered_text)
else:
# Still accumulating; emit nothing for this chunk.
self.previous_text = current_text
self.previous_token_ids = current_token_ids
return None
elif not self.reasoning_is_done and self.reasoning_parser:
delta_message = self.reasoning_parser.extract_reasoning_streaming( delta_message = self.reasoning_parser.extract_reasoning_streaming(
self.previous_text, self.previous_text,
current_text, current_text,
...@@ -328,68 +451,96 @@ class StreamingPostProcessor: ...@@ -328,68 +451,96 @@ class StreamingPostProcessor:
delta_token_ids, delta_token_ids,
) )
should_parse_tools = ( # When reasoning ends in this chunk, reset accumulated state.
self.tool_parser is not None # If there is post-reasoning content (e.g. <tool_call> markup),
and self.request_for_sampling.tool_choice != "none" # buffer it for non-streaming extraction rather than feeding it
# to the streaming tool parser which cannot handle the combined
# reasoning-end + tool-start in a single chunk.
if self.reasoning_parser.is_reasoning_end_streaming(
current_token_ids, delta_token_ids
):
self.reasoning_is_done = True
saved_reasoning = delta_message.reasoning if delta_message else None
post_content = (delta_message.content if delta_message else None) or ""
self.previous_text = ""
self.previous_token_ids = []
current_text = ""
current_token_ids = []
tool_call_start = getattr(
self.tool_parser, "tool_call_start_token", None
)
if post_content and tool_call_start and tool_call_start in post_content:
# Tool call markup present — buffer for non-streaming
# extraction (streaming parser can't handle the combined
# reasoning-end + tool-start in a single chunk).
self._tool_text_buffer = post_content
if output.finish_reason:
# If finish_reason is already set, this is the final
# chunk; parse buffered text now instead of waiting for
# a later call that will never happen.
buffered_text = self._tool_text_buffer
self._tool_text_buffer = None
delta_message = self._extract_tool_calls_from_text(
buffered_text,
saved_reasoning=saved_reasoning,
)
else:
delta_message = self._compose_delta_message(
saved_reasoning,
None,
)
else:
# Plain content (or no content) after reasoning end.
delta_message = self._compose_delta_message(
reasoning=saved_reasoning,
content=post_content if post_content else None,
)
elif (
delta_message
and delta_message.content
and not delta_message.reasoning
and self._should_parse_tools()
):
# Reasoning parser returned content (not reasoning).
# The model may have skipped reasoning and gone straight
# to tool calls (e.g. Mistral [TOOL_CALLS] without
# [THINK]...[/THINK]). Let the tool parser decide.
delta_message = self._extract_tool_calls_streaming(
current_text=current_text,
delta_text=delta_text,
current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids,
) )
if should_parse_tools: else:
if self._should_parse_tools():
no_prev_reasoning = ( no_prev_reasoning = (
delta_message and delta_message.content and not delta_message.reasoning delta_message
and delta_message.content
and not delta_message.reasoning
) )
if self.reasoning_is_done or no_prev_reasoning: if self.reasoning_is_done or no_prev_reasoning:
delta_message = self.tool_parser.extract_tool_calls_streaming( delta_message = self._extract_tool_calls_streaming(
previous_text=self.previous_text,
current_text=current_text, current_text=current_text,
delta_text=delta_text, delta_text=delta_text,
previous_token_ids=self.previous_token_ids,
current_token_ids=current_token_ids, current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids, delta_token_ids=delta_token_ids,
request=self.request_for_sampling,
) )
if (
not self.reasoning_is_done
and self.reasoning_parser
and self.reasoning_parser.is_reasoning_end_streaming(
current_token_ids, delta_token_ids
)
):
self.reasoning_is_done = True
self.previous_text = ""
self.previous_token_ids = []
current_text = ""
current_token_ids = []
choice = None choice = None
if delta_message is None: if delta_message is None:
if self.in_progress_tool_calls: if self.in_progress_tool_calls:
choice = { choice = self._emit_tool_calls_choice(output)
"index": output.index,
"delta": {
"role": "assistant",
"tool_calls": [
tool_call.model_dump(exclude_none=True)
for _, tool_call in sorted(
self.in_progress_tool_calls.items()
)
],
},
"finish_reason": output.finish_reason,
"logprobs": output.logprobs,
}
self.in_progress_tool_calls.clear()
elif output.finish_reason: elif output.finish_reason:
choice = { choice = self._build_choice(output, {})
"index": output.index,
"delta": {},
"finish_reason": output.finish_reason,
"logprobs": output.logprobs,
}
elif delta_message.tool_calls: elif delta_message.tool_calls:
for tool_delta in delta_message.tool_calls: self._merge_streaming_tool_calls(delta_message.tool_calls)
existing = self.in_progress_tool_calls.get(tool_delta.index) if output.finish_reason and self.in_progress_tool_calls:
merged = self._merge_tool_call(existing, tool_delta) # Tool calls and finish_reason arrived in the same chunk.
self.in_progress_tool_calls[tool_delta.index] = merged # Emit now — there will be no subsequent process_output call
# to drain the buffer.
choice = self._emit_tool_calls_choice(output)
elif delta_message.content or delta_message.reasoning: elif delta_message.content or delta_message.reasoning:
delta: dict[str, Any] = {"role": "assistant"} delta: dict[str, Any] = {"role": "assistant"}
content = delta_message.content content = delta_message.content
...@@ -400,39 +551,14 @@ class StreamingPostProcessor: ...@@ -400,39 +551,14 @@ class StreamingPostProcessor:
if delta_message.reasoning: if delta_message.reasoning:
delta["reasoning_content"] = delta_message.reasoning delta["reasoning_content"] = delta_message.reasoning
if self.in_progress_tool_calls: if self.in_progress_tool_calls:
delta["tool_calls"] = [ delta["tool_calls"] = self._dump_in_progress_tool_calls()
tool_call.model_dump(exclude_none=True)
for _, tool_call in sorted(self.in_progress_tool_calls.items())
]
self.in_progress_tool_calls.clear() self.in_progress_tool_calls.clear()
if len(delta) > 1: if len(delta) > 1:
choice = { choice = self._build_choice(output, delta)
"index": output.index,
"delta": delta,
"finish_reason": output.finish_reason,
"logprobs": output.logprobs,
}
elif self.in_progress_tool_calls: elif self.in_progress_tool_calls:
choice = { choice = self._emit_tool_calls_choice(output)
"index": output.index,
"delta": {
"role": "assistant",
"tool_calls": [
tool_call.model_dump(exclude_none=True)
for _, tool_call in sorted(self.in_progress_tool_calls.items())
],
},
"finish_reason": output.finish_reason,
"logprobs": output.logprobs,
}
self.in_progress_tool_calls.clear()
elif output.finish_reason: elif output.finish_reason:
choice = { choice = self._build_choice(output, {})
"index": output.index,
"delta": {},
"finish_reason": output.finish_reason,
"logprobs": output.logprobs,
}
self.previous_text = current_text self.previous_text = current_text
self.previous_token_ids = current_token_ids self.previous_token_ids = current_token_ids
......
...@@ -114,6 +114,7 @@ def _init_worker( ...@@ -114,6 +114,7 @@ def _init_worker(
) -> None: ) -> None:
"""Initialize a worker process with its own VllmConfig and InputProcessor.""" """Initialize a worker process with its own VllmConfig and InputProcessor."""
global _w_input_processor, _w_tokenizer, _w_tool_parser_class global _w_input_processor, _w_tokenizer, _w_tool_parser_class
global _w_reasoning_parser_class
model_config = ModelConfig( model_config = ModelConfig(
model=model_path, model=model_path,
......
...@@ -18,6 +18,7 @@ kr8s==0.20.13 ...@@ -18,6 +18,7 @@ kr8s==0.20.13
kubernetes==32.0.1 kubernetes==32.0.1
kubernetes_asyncio==32.0.0 kubernetes_asyncio==32.0.0
matplotlib==3.10.7 matplotlib==3.10.7
mistral-common==1.9.1
# For NATS object store verification in router tests # For NATS object store verification in router tests
nats-py==2.12.0 nats-py==2.12.0
pmdarima==2.1.1 pmdarima==2.1.1
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import importlib
def check_module_available(module_name: str) -> bool:
"""For tests / pre-commit"""
if importlib.util.find_spec(module_name) is None:
return False
try:
importlib.import_module(module_name)
return True
except ImportError:
return False
This diff is collapsed.
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit test for StreamingPostProcessor with Mistral reasoning + tool calling."""
# mypy seems to be running both sides of the HAS_VLLM if statement
# mypy: ignore-errors
import json
import pytest
from .common import check_module_available
HAS_VLLM = check_module_available("vllm")
if HAS_VLLM:
from mistral_common.tokens.tokenizers.base import SpecialTokens
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
ChatCompletionToolsParam,
)
from vllm.entrypoints.openai.engine.protocol import FunctionDefinition
from vllm.outputs import CompletionOutput
from vllm.reasoning.mistral_reasoning_parser import MistralReasoningParser
from vllm.sampling_params import SamplingParams
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.tool_parsers.mistral_tool_parser import MistralToolParser
from dynamo.frontend.prepost import StreamingPostProcessor
else:
# Fake some types so that `pre-commit` passes
class MistralTokenizer:
pass
class CompletionOutput:
def __init__(*args, **kwargs):
pass
pytestmark = [
pytest.mark.vllm,
pytest.mark.gpu_0, # "Hardware"
pytest.mark.pre_merge, # "Lifecyle"
pytest.mark.unit, # "Test Type"
pytest.mark.skipif(not HAS_VLLM, reason="requires vllm"),
]
# ---------------------------------------------------------------------------
# Mock MistralTokenizer
# ---------------------------------------------------------------------------
# Token IDs from unit_test_4.txt
TOOL_CALLS_TOKEN_ID = 9
EOS_TOKEN_ID = 2
BOS_TOKEN_ID = 1
# Arbitrary IDs for think tokens (not present in this test's output, but
# needed to initialise MistralReasoningParser).
THINK_START_TOKEN_ID = 7
THINK_END_TOKEN_ID = 8
class _InnerTokenizer:
"""Mimics the inner ``tokenizer.tokenizer`` accessed by MistralReasoningParser."""
def get_control_token(self, token):
return {
SpecialTokens.begin_think: THINK_START_TOKEN_ID,
SpecialTokens.end_think: THINK_END_TOKEN_ID,
}.get(token)
class MockMistralTokenizer(MistralTokenizer):
"""Lightweight MistralTokenizer subclass for testing.
Passes ``isinstance(tok, MistralTokenizer)`` without needing model files.
"""
def __new__(cls):
# Bypass MistralTokenizer.__init__ (needs model artefacts).
return object.__new__(cls)
def __init__(self):
self.version = 11
self._vocab_dict = {"[TOOL_CALLS]": TOOL_CALLS_TOKEN_ID}
self.tokenizer = _InnerTokenizer()
self._special_tokens = ["[TOOL_CALLS]"]
def __bool__(self):
# Needed because MistralReasoningParser does ``if not self.model_tokenizer``
# which triggers __len__ → vocab_size on the real MistralTokenizer.
return True
def get_vocab(self):
return dict(self._vocab_dict)
@property
def all_special_tokens(self):
return self._special_tokens
# ---------------------------------------------------------------------------
# Test data from unit_test_4.txt (stream_interval=1, Mistral format)
#
# Output: [TOOL_CALLS]search_gutenberg_books{"search_terms": ["James Joyce"]}
# No reasoning tokens at all — the model jumps straight to tool calls.
# ---------------------------------------------------------------------------
OUTPUTS_INTERVAL_1 = [
CompletionOutput(
index=0,
text="[TOOL_CALLS]",
token_ids=[9],
routed_experts=None,
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
stop_reason=None,
),
CompletionOutput(
index=0,
text="search",
token_ids=[8928],
routed_experts=None,
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
stop_reason=None,
),
CompletionOutput(
index=0,
text="_g",
token_ids=[11898],
routed_experts=None,
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
stop_reason=None,
),
CompletionOutput(
index=0,
text="uten",
token_ids=[8318],
routed_experts=None,
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
stop_reason=None,
),
CompletionOutput(
index=0,
text="berg",
token_ids=[6415],
routed_experts=None,
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
stop_reason=None,
),
CompletionOutput(
index=0,
text="_",
token_ids=[1095],
routed_experts=None,
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
stop_reason=None,
),
CompletionOutput(
index=0,
text="books",
token_ids=[32493],
routed_experts=None,
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
stop_reason=None,
),
CompletionOutput(
index=0,
text="",
token_ids=[32],
routed_experts=None,
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
stop_reason=None,
),
CompletionOutput(
index=0,
text='{"',
token_ids=[19227],
routed_experts=None,
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
stop_reason=None,
),
CompletionOutput(
index=0,
text="search",
token_ids=[8928],
routed_experts=None,
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
stop_reason=None,
),
CompletionOutput(
index=0,
text="_",
token_ids=[1095],
routed_experts=None,
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
stop_reason=None,
),
CompletionOutput(
index=0,
text="terms",
token_ids=[62244],
routed_experts=None,
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
stop_reason=None,
),
CompletionOutput(
index=0,
text='":',
token_ids=[2811],
routed_experts=None,
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
stop_reason=None,
),
CompletionOutput(
index=0,
text=' ["',
token_ids=[12161],
routed_experts=None,
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
stop_reason=None,
),
CompletionOutput(
index=0,
text="James",
token_ids=[31872],
routed_experts=None,
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
stop_reason=None,
),
CompletionOutput(
index=0,
text=" Joyce",
token_ids=[58617],
routed_experts=None,
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
stop_reason=None,
),
CompletionOutput(
index=0,
text='"]',
token_ids=[4964],
routed_experts=None,
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
stop_reason=None,
),
CompletionOutput(
index=0,
text="}",
token_ids=[1125],
routed_experts=None,
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
stop_reason=None,
),
CompletionOutput(
index=0,
text="",
token_ids=[2],
routed_experts=None,
cumulative_logprob=None,
logprobs=None,
finish_reason="stop",
stop_reason=None,
),
]
# ---------------------------------------------------------------------------
# Test data from unit_test_5.txt (stream_interval=20, Mistral format)
#
# Only 2 chunks: [TOOL_CALLS] alone, then the entire function name + JSON
# arguments + EOS in a single CompletionOutput with finish_reason=stop.
# ---------------------------------------------------------------------------
OUTPUTS_INTERVAL_20 = [
CompletionOutput(
index=0,
text="[TOOL_CALLS]",
token_ids=[9],
routed_experts=None,
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
stop_reason=None,
),
CompletionOutput(
index=0,
text='search_gutenberg_books{"search_terms": ["James Joyce books"]}',
token_ids=[
8928,
11898,
8318,
6415,
1095,
32493,
32,
19227,
8928,
1095,
62244,
2811,
12161,
31872,
58617,
12796,
4964,
1125,
2,
],
routed_experts=None,
cumulative_logprob=None,
logprobs=None,
finish_reason="stop",
stop_reason=None,
),
]
PROMPT_TOKEN_IDS = [
1,
5,
1091,
19227,
4994,
2811,
1429,
5165,
1897,
1429,
5165,
2811,
16753,
2391,
2811,
1429,
8928,
11898,
8318,
6415,
1095,
32493,
1897,
1429,
14653,
2811,
1429,
8483,
1394,
12796,
1294,
1278,
13217,
111317,
6415,
11329,
1897,
1429,
26204,
2811,
16753,
4994,
2811,
1429,
6371,
1897,
1429,
48649,
2811,
16753,
8928,
1095,
62244,
2811,
16753,
4994,
2811,
1429,
5477,
1897,
1429,
11089,
2811,
16753,
4994,
2811,
1429,
3607,
50666,
1429,
14653,
2811,
1429,
2525,
1307,
6123,
6856,
1317,
3081,
12796,
1034,
47579,
1429,
15760,
2811,
12161,
8928,
1095,
62244,
4964,
2821,
27028,
6,
3,
7493,
1584,
1278,
26864,
1307,
2269,
7456,
58617,
12796,
1063,
13516,
1278,
9519,
1317,
6123,
1046,
4,
]
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def tokenizer():
return MockMistralTokenizer()
@pytest.fixture
def request_for_sampling():
"""Construct a ChatCompletionRequest matching the Mistral test spec."""
return ChatCompletionRequest.model_construct(
messages=[
{
"content": "What are the titles of some James Joyce books? "
"Use the tool to search.",
"role": "user",
}
],
model="mistralai/Ministral-3-3B-Reasoning-2512",
tools=[
ChatCompletionToolsParam(
type="function",
function=FunctionDefinition(
name="search_gutenberg_books",
description="Search for books in the Project Gutenberg library",
parameters={
"type": "object",
"properties": {
"search_terms": {
"type": "array",
"items": {"type": "string"},
"description": "List of search terms to find books",
}
},
"required": ["search_terms"],
},
),
)
],
tool_choice="auto",
include_reasoning=True,
stream=False,
n=1,
frequency_penalty=0.0,
presence_penalty=0.0,
temperature=None,
top_p=None,
skip_special_tokens=True,
chat_template_kwargs=None,
reasoning_effort=None,
parallel_tool_calls=True,
)
@pytest.fixture
def sampling_params():
return SamplingParams(
n=1,
presence_penalty=0.0,
frequency_penalty=0.0,
repetition_penalty=1.0,
temperature=1.0,
top_p=1.0,
top_k=0,
min_p=0.0,
seed=None,
stop=[],
stop_token_ids=[],
include_stop_str_in_output=False,
ignore_eos=False,
max_tokens=100000,
min_tokens=0,
logprobs=None,
prompt_logprobs=None,
skip_special_tokens=True,
spaces_between_special_tokens=True,
truncate_prompt_tokens=None,
)
@pytest.fixture
def processor(tokenizer, request_for_sampling, sampling_params):
tool_parser = MistralToolParser(tokenizer)
return StreamingPostProcessor(
tokenizer=tokenizer,
request_for_sampling=request_for_sampling,
sampling_params=sampling_params,
prompt_token_ids=PROMPT_TOKEN_IDS,
tool_parser=tool_parser,
reasoning_parser_class=MistralReasoningParser,
chat_template_kwargs={"reasoning_effort": None},
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _collect_results(processor, outputs):
"""Run all outputs through process_output and collect non-None results."""
results = []
for output in outputs:
result = processor.process_output(output)
if result is not None:
results.append(result)
return results
def _collect_reasoning(results):
"""Extract and join all reasoning_content from results."""
parts = []
for r in results:
rc = r.get("delta", {}).get("reasoning_content")
if rc is not None:
parts.append(rc)
return "".join(parts)
def _collect_tool_calls(results):
"""Merge all streamed tool_call deltas into complete tool calls."""
merged: dict[int, dict] = {}
for r in results:
tc_list = r.get("delta", {}).get("tool_calls")
if not tc_list:
continue
for tc in tc_list:
idx = tc["index"]
if idx not in merged:
merged[idx] = {
"id": tc.get("id"),
"type": tc.get("type"),
"function": {
"name": tc.get("function", {}).get("name"),
"arguments": tc.get("function", {}).get("arguments", ""),
},
}
else:
existing = merged[idx]
if tc.get("id") and not existing["id"]:
existing["id"] = tc["id"]
if tc.get("type") and not existing["type"]:
existing["type"] = tc["type"]
fn = tc.get("function", {})
if fn.get("name") and not existing["function"]["name"]:
existing["function"]["name"] = fn["name"]
if fn.get("arguments"):
existing["function"]["arguments"] += fn["arguments"]
return [merged[k] for k in sorted(merged)]
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.mark.vllm
def test_mistral_tool_call(processor):
"""Mistral tool call with no reasoning.
The model output is:
[TOOL_CALLS]search_gutenberg_books{"search_terms": ["James Joyce"]}
with no [THINK]...[/THINK] reasoning block.
The tool parser should extract the tool call correctly, not leak the
tool-call markup as plain content.
"""
results = _collect_results(processor, OUTPUTS_INTERVAL_1)
tool_calls = _collect_tool_calls(results)
# -- tool calls must be parsed correctly --------------------------------
assert len(tool_calls) == 1, (
f"Expected 1 tool call but got {len(tool_calls)}. "
"Tool-call markup was likely emitted as plain content."
)
tc = tool_calls[0]
assert tc["function"]["name"] == "search_gutenberg_books"
assert json.loads(tc["function"]["arguments"]) == {
"search_terms": ["James Joyce"],
}
assert tc["id"] is not None
assert tc["type"] == "function"
# -- no reasoning content should be present -----------------------------
reasoning = _collect_reasoning(results)
assert reasoning == "", f"Unexpected reasoning content: {reasoning!r}"
# -- [TOOL_CALLS] markup should not appear in content -------------------
all_content = "".join(r.get("delta", {}).get("content", "") for r in results)
assert (
"[TOOL_CALLS]" not in all_content
), f"Raw [TOOL_CALLS] markup leaked into content: {all_content!r}"
# -- finish reason ------------------------------------------------------
finish_reasons = [r["finish_reason"] for r in results if r.get("finish_reason")]
assert "stop" in finish_reasons
@pytest.mark.vllm
def test_mistral_tool_call_interval_20(
tokenizer, request_for_sampling, sampling_params
):
"""stream_interval=20: function name + args + EOS in a single chunk.
Only 2 CompletionOutput objects:
1. [TOOL_CALLS] alone
2. search_gutenberg_books{"search_terms": ["James Joyce books"]}
with finish_reason=stop
The tool call and finish_reason arrive together. The processor must
still emit the parsed tool call and the finish_reason.
"""
tool_parser = MistralToolParser(tokenizer)
proc = StreamingPostProcessor(
tokenizer=tokenizer,
request_for_sampling=request_for_sampling,
sampling_params=sampling_params,
prompt_token_ids=PROMPT_TOKEN_IDS,
tool_parser=tool_parser,
reasoning_parser_class=MistralReasoningParser,
chat_template_kwargs={"reasoning_effort": None},
)
results = _collect_results(proc, OUTPUTS_INTERVAL_20)
tool_calls = _collect_tool_calls(results)
# -- tool calls must be parsed correctly --------------------------------
assert len(tool_calls) == 1, (
f"Expected 1 tool call but got {len(tool_calls)}. "
"Tool-call markup was likely emitted as plain content."
)
tc = tool_calls[0]
assert tc["function"]["name"] == "search_gutenberg_books"
assert json.loads(tc["function"]["arguments"]) == {
"search_terms": ["James Joyce books"],
}
assert tc["id"] is not None
assert tc["type"] == "function"
# -- no reasoning content should be present -----------------------------
reasoning = _collect_reasoning(results)
assert reasoning == "", f"Unexpected reasoning content: {reasoning!r}"
# -- [TOOL_CALLS] markup should not appear in content -------------------
all_content = "".join(r.get("delta", {}).get("content", "") for r in results)
assert (
"[TOOL_CALLS]" not in all_content
), f"Raw [TOOL_CALLS] markup leaked into content: {all_content!r}"
# -- finish reason ------------------------------------------------------
finish_reasons = [r["finish_reason"] for r in results if r.get("finish_reason")]
assert "stop" in finish_reasons
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment