Unverified Commit c0722f22 authored by Julien Denize's avatar Julien Denize Committed by GitHub
Browse files

[Mistral Grammar] Fix tool and reasoning parsing (#39217)


Signed-off-by: default avatarjuliendenize <julien.denize@mistral.ai>
parent 951dca80
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from dataclasses import dataclass, field
import openai import openai
import pytest import pytest
from tests.tool_use.utils import MESSAGES_ASKING_FOR_TOOLS, WEATHER_TOOL from tests.tool_use.utils import (
MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
MESSAGES_ASKING_FOR_TOOLS,
MESSAGES_WITH_TOOL_RESPONSE,
MESSAGES_WITHOUT_TOOLS,
SEARCH_TOOL,
SEED,
WEATHER_TOOL,
ensure_system_prompt,
)
from .utils import ServerConfig
def _requires_tool_parser(server_config: ServerConfig) -> None:
r"""Skip test if server was not started with --tool-call-parser."""
if "--tool-call-parser" not in server_config.get("arguments", []):
pytest.skip(
f"Skipping: {server_config['model']} not configured with --tool-call-parser"
)
def _is_pre_v11(server_config: ServerConfig) -> bool:
r"""Pre-v11 Mistral models lack grammar-based tool call enforcement."""
return "7B" in server_config.get("model", "")
@dataclass
class StreamedToolCallResult:
r"""Accumulated result from streaming a single tool call."""
function_name: str | None = None
function_args_str: str = ""
tool_call_id: str | None = None
role_name: str | None = None
finish_reason_count: int = 0
finish_reason: str | None = None
async def _collect_streamed_tool_call(
stream: openai.AsyncStream,
*,
expected_finish_reason: str = "tool_calls",
) -> StreamedToolCallResult:
result = StreamedToolCallResult()
async for chunk in stream:
if chunk.choices[0].finish_reason:
result.finish_reason_count += 1
result.finish_reason = chunk.choices[0].finish_reason
assert chunk.choices[0].finish_reason == expected_finish_reason
if chunk.choices[0].delta.role:
assert not result.role_name or result.role_name == "assistant"
result.role_name = "assistant"
streamed_tool_calls = chunk.choices[0].delta.tool_calls
if streamed_tool_calls and len(streamed_tool_calls) > 0:
assert len(streamed_tool_calls) == 1
tool_call = streamed_tool_calls[0]
if tool_call.id:
assert not result.tool_call_id
result.tool_call_id = tool_call.id
if tool_call.function:
if tool_call.function.name:
assert result.function_name is None
result.function_name = tool_call.function.name
if tool_call.function.arguments:
result.function_args_str += tool_call.function.arguments
return result
@dataclass
class StreamedContentResult:
r"""Accumulated result from streaming a content-only response."""
chunks: list[str] = field(default_factory=list)
finish_reason_count: int = 0
finish_reason: str | None = None
role_sent: bool = False
async def _collect_streamed_content(
stream: openai.AsyncStream,
*,
expected_finish_reason: str | None = None,
no_tool_calls: bool = True,
) -> StreamedContentResult:
r"""Consume a streaming response and collect text content."""
result = StreamedContentResult()
async for chunk in stream:
delta = chunk.choices[0].delta
if delta.role:
assert not result.role_sent
assert delta.role == "assistant"
result.role_sent = True
if delta.content:
result.chunks.append(delta.content)
if chunk.choices[0].finish_reason is not None:
result.finish_reason_count += 1
result.finish_reason = chunk.choices[0].finish_reason
if expected_finish_reason is not None:
assert result.finish_reason == expected_finish_reason
if no_tool_calls:
assert not delta.tool_calls or len(delta.tool_calls) == 0
return result
@dataclass
class StreamedParallelToolCallResult:
r"""Accumulated result from streaming parallel tool calls."""
function_names: list[str] = field(default_factory=list)
function_args_strs: list[str] = field(default_factory=list)
tool_call_ids: list[str] = field(default_factory=list)
role_name: str | None = None
finish_reason_count: int = 0
async def _collect_streamed_parallel_tool_calls(
stream: openai.AsyncStream,
) -> StreamedParallelToolCallResult:
r"""Consume a streaming response and collect parallel tool calls."""
result = StreamedParallelToolCallResult()
tool_call_idx: int = -1
async for chunk in stream:
if chunk.choices[0].finish_reason:
result.finish_reason_count += 1
assert chunk.choices[0].finish_reason == "tool_calls"
if chunk.choices[0].delta.role:
assert not result.role_name or result.role_name == "assistant"
result.role_name = "assistant"
streamed_tool_calls = chunk.choices[0].delta.tool_calls
if streamed_tool_calls and len(streamed_tool_calls) > 0:
assert len(streamed_tool_calls) == 1
tool_call = streamed_tool_calls[0]
if tool_call.index != tool_call_idx:
tool_call_idx = tool_call.index
result.function_args_strs.append("")
result.tool_call_ids.append("")
if tool_call.id:
result.tool_call_ids[tool_call.index] = tool_call.id
if tool_call.function:
if tool_call.function.name:
result.function_names.append(tool_call.function.name)
if tool_call.function.arguments:
result.function_args_strs[tool_call.index] += (
tool_call.function.arguments
)
return result
# test: a tool_choice with mistral-tokenizer results in an ID of length 9 # test: a tool_choice with mistral-tokenizer results in an ID of length 9
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_tool_call_with_tool_choice(client: openai.AsyncOpenAI): async def test_tool_call_with_tool_choice(
client: openai.AsyncOpenAI, server_config: ServerConfig
) -> None:
_requires_tool_parser(server_config)
models = await client.models.list() models = await client.models.list()
model_name: str = models.data[0].id model_name: str = models.data[0].id
chat_completion = await client.chat.completions.create( chat_completion = await client.chat.completions.create(
messages=MESSAGES_ASKING_FOR_TOOLS, messages=ensure_system_prompt(MESSAGES_ASKING_FOR_TOOLS, server_config),
temperature=0, temperature=0,
max_completion_tokens=100, max_completion_tokens=100,
model=model_name, model=model_name,
tools=[WEATHER_TOOL], tools=[WEATHER_TOOL],
tool_choice=WEATHER_TOOL, tool_choice=WEATHER_TOOL,
logprobs=False, logprobs=False,
seed=SEED,
) )
choice = chat_completion.choices[0] choice = chat_completion.choices[0]
...@@ -28,3 +201,307 @@ async def test_tool_call_with_tool_choice(client: openai.AsyncOpenAI): ...@@ -28,3 +201,307 @@ async def test_tool_call_with_tool_choice(client: openai.AsyncOpenAI):
assert choice.message.role == "assistant" assert choice.message.role == "assistant"
assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 1 assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 1
assert len(choice.message.tool_calls[0].id) == 9 # length of 9 for mistral assert len(choice.message.tool_calls[0].id) == 9 # length of 9 for mistral
_NOT_SET = object()
@pytest.mark.asyncio
@pytest.mark.parametrize(
"tools, tool_choice, streaming_id_len_pre_v11",
[
pytest.param(
[WEATHER_TOOL, SEARCH_TOOL],
_NOT_SET,
9,
id="auto",
),
pytest.param(
[WEATHER_TOOL],
"required",
30,
id="required",
),
],
)
async def test_tool_call_auto_or_required(
client: openai.AsyncOpenAI,
server_config: ServerConfig,
tools: list,
tool_choice: object,
streaming_id_len_pre_v11: int,
) -> None:
_requires_tool_parser(server_config)
models = await client.models.list()
model_name: str = models.data[0].id
create_kwargs: dict = {
"messages": ensure_system_prompt(MESSAGES_ASKING_FOR_TOOLS, server_config),
"temperature": 0,
"max_completion_tokens": 100,
"model": model_name,
"tools": tools,
"logprobs": False,
"seed": SEED,
}
if tool_choice is not _NOT_SET:
create_kwargs["tool_choice"] = tool_choice
# --- non-streaming ---
chat_completion = await client.chat.completions.create(**create_kwargs)
choice = chat_completion.choices[0]
tool_calls = choice.message.tool_calls
assert choice.finish_reason == "tool_calls"
assert tool_calls is not None and len(tool_calls) >= 1
assert tool_calls[0].function.name == "get_current_weather"
parsed_arguments = json.loads(tool_calls[0].function.arguments)
assert "city" in parsed_arguments
assert len(tool_calls[0].id) == 9
# --- streaming ---
stream = await client.chat.completions.create(**create_kwargs, stream=True)
result = await _collect_streamed_tool_call(stream)
assert result.finish_reason_count == 1
assert result.role_name == "assistant"
assert result.function_name == "get_current_weather"
streamed_args = json.loads(result.function_args_str)
assert isinstance(result.tool_call_id, str)
if _is_pre_v11(server_config):
assert len(result.tool_call_id) == streaming_id_len_pre_v11
else:
assert len(result.tool_call_id) == 9
assert parsed_arguments == streamed_args
@pytest.mark.asyncio
async def test_tool_call_none_with_tools(
client: openai.AsyncOpenAI, server_config: ServerConfig
) -> None:
_requires_tool_parser(server_config)
models = await client.models.list()
model_name: str = models.data[0].id
# --- non-streaming ---
chat_completion = await client.chat.completions.create(
messages=ensure_system_prompt(MESSAGES_ASKING_FOR_TOOLS, server_config),
temperature=0,
max_completion_tokens=100,
model=model_name,
tools=[WEATHER_TOOL],
tool_choice="none",
logprobs=False,
seed=SEED,
)
choice = chat_completion.choices[0]
assert choice.finish_reason != "tool_calls"
assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0
assert choice.message.content is not None
# Without grammar enforcement, pre-v11 models may still emit [TOOL_CALLS]
if not _is_pre_v11(server_config):
assert "[TOOL_CALLS]" not in choice.message.content
non_streaming_content = choice.message.content
# --- streaming ---
stream = await client.chat.completions.create(
messages=ensure_system_prompt(MESSAGES_ASKING_FOR_TOOLS, server_config),
temperature=0,
max_completion_tokens=100,
model=model_name,
tools=[WEATHER_TOOL],
tool_choice="none",
logprobs=False,
seed=SEED,
stream=True,
)
# Pre-v11 models lack grammar enforcement, so the model may still
# emit tool calls even with tool_choice="none".
pre_v11 = _is_pre_v11(server_config)
result = await _collect_streamed_content(stream, no_tool_calls=not pre_v11)
assert result.finish_reason_count == 1
if not pre_v11:
assert result.finish_reason != "tool_calls"
streamed_content = "".join(result.chunks)
if not pre_v11:
assert "[TOOL_CALLS]" not in streamed_content
assert streamed_content == non_streaming_content
@pytest.mark.asyncio
async def test_chat_without_tools(
client: openai.AsyncOpenAI, server_config: ServerConfig
) -> None:
models = await client.models.list()
model_name: str = models.data[0].id
# --- non-streaming ---
chat_completion = await client.chat.completions.create(
messages=ensure_system_prompt(MESSAGES_WITHOUT_TOOLS, server_config),
temperature=0,
max_completion_tokens=150,
model=model_name,
logprobs=False,
seed=SEED,
)
choice = chat_completion.choices[0]
output_text = choice.message.content
assert output_text is not None and len(output_text) > 0
assert choice.finish_reason != "tool_calls"
assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0
# --- streaming ---
stream = await client.chat.completions.create(
messages=ensure_system_prompt(MESSAGES_WITHOUT_TOOLS, server_config),
temperature=0,
max_completion_tokens=150,
model=model_name,
logprobs=False,
seed=SEED,
stream=True,
)
result = await _collect_streamed_content(
stream, expected_finish_reason=choice.finish_reason
)
assert result.role_sent
assert result.finish_reason_count == 1
assert len(result.chunks)
assert "".join(result.chunks) == output_text
@pytest.mark.asyncio
async def test_tool_call_with_results(
client: openai.AsyncOpenAI, server_config: ServerConfig
) -> None:
_requires_tool_parser(server_config)
models = await client.models.list()
model_name: str = models.data[0].id
# --- non-streaming ---
chat_completion = await client.chat.completions.create(
messages=ensure_system_prompt(MESSAGES_WITH_TOOL_RESPONSE, server_config),
temperature=0,
max_completion_tokens=100,
model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False,
seed=SEED,
)
choice = chat_completion.choices[0]
assert choice.finish_reason != "tool_calls"
assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0
assert choice.message.content is not None
assert "98" in choice.message.content
# --- streaming ---
stream = await client.chat.completions.create(
messages=ensure_system_prompt(MESSAGES_WITH_TOOL_RESPONSE, server_config),
temperature=0,
max_completion_tokens=100,
model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False,
seed=SEED,
stream=True,
)
result = await _collect_streamed_content(
stream, expected_finish_reason=choice.finish_reason
)
assert result.role_sent
assert result.finish_reason_count == 1
assert len(result.chunks)
assert "".join(result.chunks) == choice.message.content
def _requires_parallel(server_config: ServerConfig) -> None:
r"""Skip test if the model does not support parallel tool calls."""
if not server_config.get("supports_parallel"):
pytest.skip(
f"Skipping: {server_config['model']} does not support parallel tool calls"
)
@pytest.mark.asyncio
async def test_tool_call_parallel(
client: openai.AsyncOpenAI, server_config: ServerConfig
) -> None:
_requires_tool_parser(server_config)
_requires_parallel(server_config)
models = await client.models.list()
model_name: str = models.data[0].id
# --- non-streaming ---
chat_completion = await client.chat.completions.create(
messages=ensure_system_prompt(
MESSAGES_ASKING_FOR_PARALLEL_TOOLS, server_config
),
temperature=0,
max_completion_tokens=200,
model=model_name,
tools=[WEATHER_TOOL],
logprobs=False,
seed=SEED,
)
choice = chat_completion.choices[0]
tool_calls = choice.message.tool_calls
assert choice.finish_reason == "tool_calls"
assert tool_calls is not None and len(tool_calls) >= 2
for tc in tool_calls:
assert tc.type == "function"
assert tc.function.name == "get_current_weather"
assert isinstance(tc.function.arguments, str)
parsed = json.loads(tc.function.arguments)
assert "city" in parsed
assert len(tc.id) == 9
non_streaming_tool_calls = tool_calls
# --- streaming ---
stream = await client.chat.completions.create(
messages=ensure_system_prompt(
MESSAGES_ASKING_FOR_PARALLEL_TOOLS, server_config
),
temperature=0,
max_completion_tokens=200,
model=model_name,
tools=[WEATHER_TOOL],
logprobs=False,
seed=SEED,
stream=True,
)
result = await _collect_streamed_parallel_tool_calls(stream)
assert result.finish_reason_count == 1
assert result.role_name == "assistant"
assert len(result.function_names) >= 2
assert all(name == "get_current_weather" for name in result.function_names)
assert len(result.tool_call_ids) >= 2
assert all(isinstance(tid, str) and len(tid) == 9 for tid in result.tool_call_ids)
for args_str in result.function_args_strs:
streamed_args = json.loads(args_str)
assert "city" in streamed_args
assert len(result.function_names) == len(non_streaming_tool_calls)
...@@ -2,16 +2,7 @@ ...@@ -2,16 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing_extensions import TypedDict from tests.tool_use.utils import ServerConfig
class ServerConfig(TypedDict, total=False):
model: str
arguments: list[str]
system_prompt: str | None
supports_parallel: bool | None
supports_rocm: bool | None
ARGS: list[str] = ["--max-model-len", "1024"] ARGS: list[str] = ["--max-model-len", "1024"]
...@@ -21,6 +12,11 @@ CONFIGS: dict[str, ServerConfig] = { ...@@ -21,6 +12,11 @@ CONFIGS: dict[str, ServerConfig] = {
"arguments": [ "arguments": [
"--tokenizer-mode", "--tokenizer-mode",
"mistral", "mistral",
"--tool-call-parser",
"mistral",
"--enable-auto-tool-choice",
"--enforce-eager",
"--no-enable-prefix-caching",
'--ignore-patterns="consolidated.safetensors"', '--ignore-patterns="consolidated.safetensors"',
], ],
"system_prompt": "You are a helpful assistant with access to tools. If a tool" "system_prompt": "You are a helpful assistant with access to tools. If a tool"
...@@ -29,4 +25,22 @@ CONFIGS: dict[str, ServerConfig] = { ...@@ -29,4 +25,22 @@ CONFIGS: dict[str, ServerConfig] = {
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
"to the user's question - just respond to it normally.", "to the user's question - just respond to it normally.",
}, },
"ministral-3b": {
"model": "mistralai/Ministral-3-3B-Instruct-2512",
"arguments": [
"--tokenizer-mode",
"mistral",
"--tool-call-parser",
"mistral",
"--enable-auto-tool-choice",
"--enforce-eager",
"--no-enable-prefix-caching",
],
"system_prompt": "You are a helpful assistant with access to tools. If a tool"
" that you have would be helpful to answer a user query, "
"call the tool. Otherwise, answer the user's query directly "
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
"to the user's question - just respond to it normally.",
"supports_parallel": True,
},
} }
...@@ -11,7 +11,7 @@ from openai.types.chat.chat_completion_audio import ( ...@@ -11,7 +11,7 @@ from openai.types.chat.chat_completion_audio import (
ChatCompletionAudio as OpenAIChatCompletionAudio, ChatCompletionAudio as OpenAIChatCompletionAudio,
) )
from openai.types.chat.chat_completion_message import Annotation as OpenAIAnnotation from openai.types.chat.chat_completion_message import Annotation as OpenAIAnnotation
from pydantic import Field, model_validator from pydantic import Field, PrivateAttr, model_validator
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.config.utils import replace from vllm.config.utils import replace
...@@ -398,6 +398,9 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -398,6 +398,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
msg["tool_calls"] = list(tool_calls) msg["tool_calls"] = list(tool_calls)
return self return self
_grammar_from_tool_parser: bool = PrivateAttr(default=False)
"""CAUTION: Should only be set by ``ToolParser.adjust_request``."""
def build_chat_params( def build_chat_params(
self, self,
default_template: str | None, default_template: str | None,
...@@ -822,13 +825,6 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -822,13 +825,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
return data return data
@model_validator(mode="before")
@classmethod
def set_include_reasoning_for_none_effort(cls, data: Any) -> Any:
if data.get("reasoning_effort") == "none":
data["include_reasoning"] = False
return data
class BatchChatCompletionRequest(OpenAIBaseModel): class BatchChatCompletionRequest(OpenAIBaseModel):
"""Request model for the /v1/chat/completions/batch endpoint. """Request model for the /v1/chat/completions/batch endpoint.
......
...@@ -73,7 +73,10 @@ from vllm.reasoning import ReasoningParser ...@@ -73,7 +73,10 @@ from vllm.reasoning import ReasoningParser
from vllm.renderers import ChatParams from vllm.renderers import ChatParams
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.mistral_tool_parser import MistralToolCall from vllm.tool_parsers.mistral_tool_parser import (
MistralToolCall,
MistralToolParser,
)
from vllm.tool_parsers.utils import partial_json_loads from vllm.tool_parsers.utils import partial_json_loads
from vllm.utils.collection_utils import as_list from vllm.utils.collection_utils import as_list
from vllm.utils.mistral import is_mistral_tokenizer from vllm.utils.mistral import is_mistral_tokenizer
...@@ -140,6 +143,12 @@ class OpenAIServingChat(OpenAIServing): ...@@ -140,6 +143,12 @@ class OpenAIServingChat(OpenAIServing):
enable_auto_tools=enable_auto_tools, enable_auto_tools=enable_auto_tools,
model_name=self.model_config.model, model_name=self.model_config.model,
) )
_is_mistral_tool_parser = self.tool_parser is not None and issubclass(
self.tool_parser, MistralToolParser
)
if _is_mistral_tool_parser and self.reasoning_parser_cls is not None:
MistralToolParser.model_can_reason = True
self.exclude_tools_when_tool_choice_none = exclude_tools_when_tool_choice_none self.exclude_tools_when_tool_choice_none = exclude_tools_when_tool_choice_none
self.enable_prompt_tokens_details = enable_prompt_tokens_details self.enable_prompt_tokens_details = enable_prompt_tokens_details
...@@ -310,6 +319,11 @@ class OpenAIServingChat(OpenAIServing): ...@@ -310,6 +319,11 @@ class OpenAIServingChat(OpenAIServing):
else: else:
if not request.include_reasoning: if not request.include_reasoning:
reasoning_ended = True reasoning_ended = True
elif request._grammar_from_tool_parser:
# The Mistral grammar already includes an optional
# `think?` rule that handles both reasoning and
# non-reasoning outputs.
reasoning_ended = True
elif reasoning_parser: elif reasoning_parser:
reasoning_ended = reasoning_parser.is_reasoning_end( reasoning_ended = reasoning_parser.is_reasoning_end(
prompt_token_ids or [] prompt_token_ids or []
...@@ -530,6 +544,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -530,6 +544,8 @@ class OpenAIServingChat(OpenAIServing):
harmony_tools_streamed = [False] * num_choices harmony_tools_streamed = [False] * num_choices
tools_streamed = [False] * num_choices tools_streamed = [False] * num_choices
is_mistral_grammar_path = request._grammar_from_tool_parser
if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam): if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
tool_choice_function_name = request.tool_choice.function.name tool_choice_function_name = request.tool_choice.function.name
else: else:
...@@ -553,7 +569,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -553,7 +569,7 @@ class OpenAIServingChat(OpenAIServing):
# Only one of these will be used, thus previous_texts and # Only one of these will be used, thus previous_texts and
# all_previous_token_ids will not be used twice in the same iteration. # all_previous_token_ids will not be used twice in the same iteration.
if tool_choice_auto or reasoning_parser: if is_mistral_grammar_path or tool_choice_auto or reasoning_parser:
# These are only required in "auto" tool choice case # These are only required in "auto" tool choice case
all_previous_token_ids = [[] for _ in range(num_choices)] all_previous_token_ids = [[] for _ in range(num_choices)]
reasoning_end_arr = [False] * num_choices reasoning_end_arr = [False] * num_choices
...@@ -748,7 +764,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -748,7 +764,7 @@ class OpenAIServingChat(OpenAIServing):
delta_message: DeltaMessage | None delta_message: DeltaMessage | None
# just update previous_texts and previous_token_ids # just update previous_texts and previous_token_ids
if tool_choice_auto or reasoning_parser: if is_mistral_grammar_path or tool_choice_auto or reasoning_parser:
assert previous_texts is not None assert previous_texts is not None
assert all_previous_token_ids is not None assert all_previous_token_ids is not None
previous_text = previous_texts[i] previous_text = previous_texts[i]
...@@ -772,6 +788,30 @@ class OpenAIServingChat(OpenAIServing): ...@@ -772,6 +788,30 @@ class OpenAIServingChat(OpenAIServing):
) )
) )
harmony_tools_streamed[i] |= tools_streamed_flag harmony_tools_streamed[i] |= tools_streamed_flag
# Mistral grammar path: combined reasoning + tool streaming
elif is_mistral_grammar_path:
assert tool_parser is not None
assert isinstance(tool_parser, MistralToolParser)
assert reasoning_end_arr is not None
output_token_ids = as_list(output.token_ids)
result = tool_parser.extract_maybe_reasoning_and_tool_streaming(
reasoning_parser=reasoning_parser,
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
output_token_ids=output_token_ids,
reasoning_ended=reasoning_end_arr[i],
prompt_is_reasoning_end=(prompt_is_reasoning_end_arr[i]),
request=request,
)
delta_message = result.delta_message
reasoning_end_arr[i] = result.reasoning_ended
current_text = result.current_text
current_token_ids = result.current_token_ids
if result.tools_called:
tools_streamed[i] = True
# handle streaming deltas for tools with named tool_choice # handle streaming deltas for tools with named tool_choice
elif tool_choice_function_name: elif tool_choice_function_name:
# When encountering think end id in prompt_token_ids # When encountering think end id in prompt_token_ids
...@@ -925,7 +965,9 @@ class OpenAIServingChat(OpenAIServing): ...@@ -925,7 +965,9 @@ class OpenAIServingChat(OpenAIServing):
delta_message = DeltaMessage(content=delta_text) delta_message = DeltaMessage(content=delta_text)
# update the previous values for the next iteration # update the previous values for the next iteration
if (tool_choice_auto or reasoning_parser) and not self.use_harmony: if (
is_mistral_grammar_path or tool_choice_auto or reasoning_parser
) and not self.use_harmony:
assert previous_texts is not None assert previous_texts is not None
assert all_previous_token_ids is not None assert all_previous_token_ids is not None
previous_texts[i] = current_text previous_texts[i] = current_text
...@@ -1312,7 +1354,24 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1312,7 +1354,24 @@ class OpenAIServingChat(OpenAIServing):
tool_call_class = ( tool_call_class = (
MistralToolCall if is_mistral_tokenizer(tokenizer) else ToolCall MistralToolCall if is_mistral_tokenizer(tokenizer) else ToolCall
) )
if (not self.enable_auto_tools or not self.tool_parser) and (
use_mistral_tool_parser = request._grammar_from_tool_parser
if use_mistral_tool_parser:
tool_call_items = MistralToolParser.build_non_streaming_tool_calls(
tool_calls
)
if tool_call_items:
auto_tools_called = (
request.tool_choice is None or request.tool_choice == "auto"
)
message = ChatMessage(
role=role,
reasoning=reasoning,
content=content,
tool_calls=tool_call_items,
)
elif (not self.enable_auto_tools or not self.tool_parser) and (
not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam) not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam)
and request.tool_choice != "required" and request.tool_choice != "required"
): ):
......
...@@ -65,6 +65,7 @@ from vllm.renderers.inputs.preprocess import ( ...@@ -65,6 +65,7 @@ from vllm.renderers.inputs.preprocess import (
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser from vllm.tool_parsers import ToolParser
from vllm.tool_parsers.mistral_tool_parser import MistralToolParser
from vllm.tracing import ( from vllm.tracing import (
contains_trace_headers, contains_trace_headers,
extract_trace_headers, extract_trace_headers,
...@@ -610,16 +611,31 @@ class OpenAIServing: ...@@ -610,16 +611,31 @@ class OpenAIServing:
tool_parser_cls: type[ToolParser] | None, tool_parser_cls: type[ToolParser] | None,
content: str | None = None, content: str | None = None,
) -> tuple[list[FunctionCall] | None, str | None]: ) -> tuple[list[FunctionCall] | None, str | None]:
# When the Mistral grammar factory injected structured outputs,
# let the parser handle the output.
use_mistral_tool_parser = (
isinstance(request, ChatCompletionRequest)
and tool_parser_cls is not None
and issubclass(tool_parser_cls, MistralToolParser)
and request._grammar_from_tool_parser
)
function_calls = list[FunctionCall]() function_calls = list[FunctionCall]()
if request.tool_choice and isinstance(request.tool_choice, ToolChoiceFunction): if (
not use_mistral_tool_parser
and request.tool_choice
and isinstance(request.tool_choice, ToolChoiceFunction)
):
assert content is not None assert content is not None
# Forced Function Call # Forced Function Call
function_calls.append( function_calls.append(
FunctionCall(name=request.tool_choice.name, arguments=content) FunctionCall(name=request.tool_choice.name, arguments=content)
) )
content = None # Clear content since tool is called. content = None # Clear content since tool is called.
elif request.tool_choice and isinstance( elif (
request.tool_choice, ChatCompletionNamedToolChoiceParam not use_mistral_tool_parser
and request.tool_choice
and isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam)
): ):
assert content is not None assert content is not None
# Forced Function Call # Forced Function Call
...@@ -627,7 +643,7 @@ class OpenAIServing: ...@@ -627,7 +643,7 @@ class OpenAIServing:
FunctionCall(name=request.tool_choice.function.name, arguments=content) FunctionCall(name=request.tool_choice.function.name, arguments=content)
) )
content = None # Clear content since tool is called. content = None # Clear content since tool is called.
elif request.tool_choice == "required": elif not use_mistral_tool_parser and request.tool_choice == "required":
tool_calls = [] tool_calls = []
with contextlib.suppress(ValidationError): with contextlib.suppress(ValidationError):
content = content or "" content = content or ""
...@@ -642,10 +658,12 @@ class OpenAIServing: ...@@ -642,10 +658,12 @@ class OpenAIServing:
) )
) )
content = None # Clear content since tool is called. content = None # Clear content since tool is called.
elif ( elif tool_parser_cls and (
tool_parser_cls use_mistral_tool_parser
and enable_auto_tools or (
and (request.tool_choice == "auto" or request.tool_choice is None) enable_auto_tools
and (request.tool_choice == "auto" or request.tool_choice is None)
)
): ):
if tokenizer is None: if tokenizer is None:
raise ValueError( raise ValueError(
......
...@@ -53,6 +53,7 @@ from vllm.renderers.inputs.preprocess import ( ...@@ -53,6 +53,7 @@ from vllm.renderers.inputs.preprocess import (
prompt_to_seq, prompt_to_seq,
) )
from vllm.tool_parsers import ToolParser from vllm.tool_parsers import ToolParser
from vllm.tool_parsers.mistral_tool_parser import MistralToolParser
from vllm.utils import random_uuid from vllm.utils import random_uuid
from vllm.utils.mistral import is_mistral_tokenizer from vllm.utils.mistral import is_mistral_tokenizer
from vllm.utils.mistral import mt as _mt from vllm.utils.mistral import mt as _mt
...@@ -555,9 +556,19 @@ class OpenAIServingRender: ...@@ -555,9 +556,19 @@ class OpenAIServingRender:
# tool parsing is done only if a tool_parser has been set and if # tool parsing is done only if a tool_parser has been set and if
# tool_choice is not "none" (if tool_choice is "none" but a tool_parser # tool_choice is not "none" (if tool_choice is "none" but a tool_parser
# is set, we want to prevent parsing a tool_call hallucinated by the LLM # is set, we want to prevent parsing a tool_call hallucinated by the LLM
#
# Exception: Mistral grammar-capable tokenizers always call
# adjust_request — even for tool_choice="none" — so that the grammar
# factory can prevent special-token leakage.
if tool_parser is not None: if tool_parser is not None:
tool_choice = getattr(request, "tool_choice", "none") tool_choice = getattr(request, "tool_choice", "none")
if tool_choice != "none": tokenizer = renderer.get_tokenizer()
is_mistral_grammar_eligible = (
issubclass(tool_parser, MistralToolParser)
and is_mistral_tokenizer(tokenizer)
and tokenizer.supports_grammar
)
if tool_choice != "none" or is_mistral_grammar_eligible:
if not isinstance(request, ChatCompletionRequest | ResponsesRequest): if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
msg = ( msg = (
"Tool usage is only supported " "Tool usage is only supported "
...@@ -565,7 +576,6 @@ class OpenAIServingRender: ...@@ -565,7 +576,6 @@ class OpenAIServingRender:
f"but got {type(request).__name__}" f"but got {type(request).__name__}"
) )
raise NotImplementedError(msg) raise NotImplementedError(msg)
tokenizer = renderer.get_tokenizer()
request = tool_parser(tokenizer, request.tools).adjust_request( request = tool_parser(tokenizer, request.tools).adjust_request(
request=request request=request
) )
......
...@@ -157,6 +157,10 @@ def _is_non_tekken_mistral(tokenizer: TokenizerLike) -> bool: ...@@ -157,6 +157,10 @@ def _is_non_tekken_mistral(tokenizer: TokenizerLike) -> bool:
return is_mistral_tokenizer(tokenizer) and not tokenizer.is_tekken return is_mistral_tokenizer(tokenizer) and not tokenizer.is_tekken
def _get_llg_tokenizer(tokenizer: TokenizerLike) -> Any:
return tokenizer.llg_tokenizer if is_mistral_tokenizer(tokenizer) else None
class SamplingParams( class SamplingParams(
PydanticMsgspecMixin, PydanticMsgspecMixin,
msgspec.Struct, msgspec.Struct,
...@@ -816,7 +820,10 @@ class SamplingParams( ...@@ -816,7 +820,10 @@ class SamplingParams(
# allows <|special_token|> and similar, see # allows <|special_token|> and similar, see
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars. # Without tokenizer these are disallowed in grammars.
validate_guidance_grammar(self, tokenizer=None) validate_guidance_grammar(
self,
tokenizer=_get_llg_tokenizer(tokenizer),
)
elif backend == "outlines": elif backend == "outlines":
# outlines backend # outlines backend
validate_structured_output_request_outlines(self) validate_structured_output_request_outlines(self)
...@@ -862,7 +869,10 @@ class SamplingParams( ...@@ -862,7 +869,10 @@ class SamplingParams(
self.structured_outputs._backend = "outlines" self.structured_outputs._backend = "outlines"
else: else:
# Fall back to guidance by default. # Fall back to guidance by default.
validate_guidance_grammar(self, tokenizer=None) validate_guidance_grammar(
self,
tokenizer=_get_llg_tokenizer(tokenizer),
)
self.structured_outputs._backend = "guidance" self.structured_outputs._backend = "guidance"
# Remember that this backend was set automatically # Remember that this backend was set automatically
self.structured_outputs._backend_was_auto = True self.structured_outputs._backend_was_auto = True
......
...@@ -54,6 +54,50 @@ if TYPE_CHECKING: ...@@ -54,6 +54,50 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
def _pop_unallowed_keys_and_warn(
dictionary: dict[str, Any], allowed_keys: set[str], err_dict_name: str
):
keys = list(dictionary.keys())
for key in keys:
if key not in allowed_keys:
dictionary.pop(key)
logger.warning_once(
f"'{key=}' is not supported by mistral-common "
f"for {err_dict_name}. It has been popped from the "
"object."
)
# TODO(juliendenize): remove this once OpenAI API is better supported by
# `mistral-common`.
def adapt_inplace_to_mistral_tool(
tool: dict[str, Any],
) -> dict[str, Any]:
tools_fields = set(Tool.model_fields.keys())
function_fields = set(Function.model_fields.keys())
# The Mistral client, in comparison to the OpenAI client, requires the
# "parameters" dict and the "description" string to be present
# even if they are empty.
if function := tool.get("function"):
if function.get("parameters") is None:
function["parameters"] = {}
if function.get("description") is None:
function["description"] = ""
_pop_unallowed_keys_and_warn(
dictionary=function,
allowed_keys=function_fields,
err_dict_name="function",
)
_pop_unallowed_keys_and_warn(
dictionary=tool, allowed_keys=tools_fields, err_dict_name="tools"
)
return tool
def maybe_serialize_tool_calls(request: "MistralChatCompletionRequest"): def maybe_serialize_tool_calls(request: "MistralChatCompletionRequest"):
# SEE: https://github.com/vllm-project/vllm/pull/9951 # SEE: https://github.com/vllm-project/vllm/pull/9951
# Credits go to: @gcalmettes # Credits go to: @gcalmettes
...@@ -159,44 +203,11 @@ def _prepare_apply_chat_template_tools_and_messages( ...@@ -159,44 +203,11 @@ def _prepare_apply_chat_template_tools_and_messages(
# Remove reasoning as unsupported by Mistral # Remove reasoning as unsupported by Mistral
_ = message.pop("reasoning", None) # type: ignore _ = message.pop("reasoning", None) # type: ignore
# The Mistral client, in comparison to the OpenAI client, requires the tools = (
# "parameters" dict and the "description" string to be present [adapt_inplace_to_mistral_tool(tool=tool) for tool in tools]
# even if they are empty. if tools is not None
if tools: else None
for function in [ )
tool["function"] for tool in tools if tool["type"] == "function"
]:
if function.get("parameters") is None:
function["parameters"] = {}
if function.get("description") is None:
function["description"] = ""
# We filter not supported arguments to avoid throwing an error.
# TODO(juliendenize): remove this once OpenAI API is better supported by
# `mistral-common`.
tools_fields = set(Tool.model_fields.keys())
function_fields = set(Function.model_fields.keys())
for tool in tools:
tool_keys = list(tool.keys())
for tool_key in tool_keys:
if tool_key not in tools_fields:
tool.pop(tool_key)
logger.warning_once(
f"'{tool_key}' is not supported by mistral-common for tools. "
"It has been popped from the tool definition."
)
if tool["type"] == "function":
function_keys = list(tool["function"].keys())
for function_key in function_keys:
if function_key not in function_fields:
tool["function"].pop(function_key)
logger.warning_once(
f"'{function_key}' is not supported by mistral-common "
"for function tools. It has been popped from the "
"function definition."
)
else:
raise ValueError("mistral-common only supports function tools.")
return messages, tools return messages, tools
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import json import json
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum, auto from enum import Enum, auto
from random import choices from random import choices
from string import ascii_letters, digits from string import ascii_letters, digits
from typing import Any from typing import TYPE_CHECKING, Any
import ijson import ijson
import regex as re import regex as re
...@@ -37,14 +40,19 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -37,14 +40,19 @@ from vllm.entrypoints.openai.engine.protocol import (
) )
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.reasoning.mistral_reasoning_parser import MistralReasoningParser
from vllm.sampling_params import StructuredOutputsParams from vllm.sampling_params import StructuredOutputsParams
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer, adapt_inplace_to_mistral_tool
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
Tool, Tool,
ToolParser, ToolParser,
) )
from vllm.utils.mistral import is_mistral_tokenizer from vllm.utils.mistral import is_mistral_tokenizer
if TYPE_CHECKING:
from vllm.reasoning import ReasoningParser
logger = init_logger(__name__) logger = init_logger(__name__)
ALPHANUMERIC = ascii_letters + digits ALPHANUMERIC = ascii_letters + digits
...@@ -86,13 +94,28 @@ def _is_pre_v11_tokeniser(model_tokenizer: TokenizerLike) -> bool: ...@@ -86,13 +94,28 @@ def _is_pre_v11_tokeniser(model_tokenizer: TokenizerLike) -> bool:
return not (is_mistral_tokenizer(model_tokenizer) and model_tokenizer.version >= 11) return not (is_mistral_tokenizer(model_tokenizer) and model_tokenizer.version >= 11)
class MistralToolParser(ToolParser): @dataclass
class MistralStreamingResult:
r"""Encapsulates the mutable state returned from
`MistralToolParser.extract_maybe_reasoning_and_tool_streaming`.
""" """
Tool call parser for Mistral 7B Instruct v0.3, intended for use with
- [`mistral_common`](https://github.com/mistralai/mistral-common/)
- the examples/tool_chat_template_mistral.jinja template.
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set delta_message: DeltaMessage | None
reasoning_ended: bool
tools_called: bool
current_text: str
current_token_ids: list[int]
class MistralToolParser(ToolParser):
r"""Tool call parser for Mistral models, intended for use with either:
- `mistral_common <https://github.com/mistralai/mistral-common/>`_
(recommended)
- the `examples/tool_chat_template_mistral.jinja` template.
Used when `--enable-auto-tool-choice --tool-call-parser mistral` are all
set.
""" """
# Used to generate correct grammar in `adjust_request` # Used to generate correct grammar in `adjust_request`
...@@ -210,9 +233,11 @@ class MistralToolParser(ToolParser): ...@@ -210,9 +233,11 @@ class MistralToolParser(ToolParser):
reasoning=self.model_can_reason reasoning=self.model_can_reason
) )
tools = ( mistral_tools = (
[ [
MistralTool.from_openai(openai_tool=tool.model_dump()) MistralTool.model_validate(
adapt_inplace_to_mistral_tool(tool.model_dump())
)
for tool in request.tools for tool in request.tools
] ]
if request.tools is not None if request.tools is not None
...@@ -244,15 +269,158 @@ class MistralToolParser(ToolParser): ...@@ -244,15 +269,158 @@ class MistralToolParser(ToolParser):
lark_grammar = grammar_factory.get_lark_from_jinja( lark_grammar = grammar_factory.get_lark_from_jinja(
template=template, template=template,
mode=tool_choice, mode=tool_choice,
tools=tools, tools=mistral_tools,
json_schema=json_schema, json_schema=json_schema,
parallel_tool_calls=request.parallel_tool_calls, parallel_tool_calls=request.parallel_tool_calls,
json_only=False, json_only=False,
) )
request.structured_outputs = StructuredOutputsParams(grammar=lark_grammar) request.structured_outputs = StructuredOutputsParams(grammar=lark_grammar)
request._grammar_from_tool_parser = True
return request return request
def extract_maybe_reasoning_and_tool_streaming(
self,
*,
reasoning_parser: ReasoningParser | None,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: list[int],
current_token_ids: list[int],
output_token_ids: Sequence[int],
reasoning_ended: bool,
prompt_is_reasoning_end: bool | None,
request: ChatCompletionRequest,
) -> MistralStreamingResult:
r"""Streaming extraction with reasoning followed by tool-call parsing.
This method encapsulates the combined reasoning extraction and
tool-call streaming logic so that the serving layer only needs a
thin routing branch.
The flow is:
1. If a *reasoning_parser* is present and reasoning has **not** ended,
extract reasoning tokens. Pre-v15 models may have pre-filled
`[THINK]...[/THINK]` in system prompts, so we skip the
prompt-level reasoning-end check for those.
2. Once reasoning ends (or if there is no reasoning parser), delegate
to `extract_tool_calls_streaming` and track whether tools were
called.
Args:
reasoning_parser: Optional reasoning parser instance.
previous_text: Accumulated text from prior chunks.
current_text: Full accumulated text including current chunk.
delta_text: New text in this chunk.
previous_token_ids: Token ids from prior chunks.
current_token_ids: Full token ids including current chunk.
output_token_ids: Raw output token ids from the engine.
reasoning_ended: Whether reasoning has already ended.
prompt_is_reasoning_end: Whether the prompt itself ends reasoning.
request: The originating chat completion request.
"""
delta_message: DeltaMessage | None = None
tools_called = False
reasoning_ended_at_entry = reasoning_ended
# For MistralReasoningParser, only enter the reasoning block when
# the model has actually emitted a [THINK] token. Other reasoning
# parsers always expect thinking to be present.
expect_thinking = (
not isinstance(reasoning_parser, MistralReasoningParser)
or reasoning_parser.start_token_id in current_token_ids
)
if reasoning_parser is not None and not reasoning_ended and expect_thinking:
# Pre-v15 models may have pre-filled [THINK]...[/THINK] in
# system prompts, so skip the prompt-level reasoning-end
# check and wait for the output's own end-of-think.
is_pre_v15 = (
isinstance(self.model_tokenizer, MistralTokenizer)
and self.model_tokenizer.version < 15
)
if not is_pre_v15 and prompt_is_reasoning_end:
reasoning_ended = True
current_token_ids = list(output_token_ids)
else:
delta_message = reasoning_parser.extract_reasoning_streaming(
previous_text,
current_text,
delta_text,
previous_token_ids,
current_token_ids,
output_token_ids,
)
if reasoning_parser.is_reasoning_end_streaming(
current_token_ids, output_token_ids
):
reasoning_ended = True
current_token_ids = reasoning_parser.extract_content_ids(
list(output_token_ids)
)
if delta_message and delta_message.content:
current_text = delta_message.content
delta_message.content = None
else:
current_text = ""
if not reasoning_ended:
return MistralStreamingResult(
delta_message=delta_message,
reasoning_ended=False,
tools_called=False,
current_text=current_text,
current_token_ids=current_token_ids,
)
delta_token_ids = list(output_token_ids)
# On the iteration where reasoning just ended, reset the text/token
# state so the tool parser sees a clean history instead of the
# accumulated reasoning text.
if not reasoning_ended_at_entry and reasoning_ended:
previous_text = ""
previous_token_ids = []
delta_text = current_text
delta_token_ids = current_token_ids
delta_message = self.extract_tool_calls_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids,
request=request,
)
if delta_message and delta_message.tool_calls:
tools_called = True
return MistralStreamingResult(
delta_message=delta_message,
reasoning_ended=reasoning_ended,
tools_called=tools_called,
current_text=current_text,
current_token_ids=current_token_ids,
)
@staticmethod
def build_non_streaming_tool_calls(
tool_calls: list[FunctionCall] | None,
) -> list[ToolCall]:
r"""Build `MistralToolCall` items for non-streaming responses."""
if not tool_calls:
return []
return [
MistralToolCall(id=tc.id, function=tc)
if tc.id
else MistralToolCall(function=tc)
for tc in tool_calls
]
def extract_tool_calls( def extract_tool_calls(
self, self,
model_output: str, model_output: str,
...@@ -323,7 +491,7 @@ class MistralToolParser(ToolParser): ...@@ -323,7 +491,7 @@ class MistralToolParser(ToolParser):
)[0] )[0]
tool_calls = json.loads(raw_tool_call) tool_calls = json.loads(raw_tool_call)
except (IndexError, json.JSONDecodeError): except (IndexError, json.JSONDecodeError):
logger.exception("Error in extracting tool call from response: {e}") logger.exception("Error in extracting tool call from response.")
# If raw decoding and decoding post regex rule fails, then just # If raw decoding and decoding post regex rule fails, then just
# return content. # return content.
return ExtractedToolCallInformation( return ExtractedToolCallInformation(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment