Unverified Commit c0722f22 authored by Julien Denize's avatar Julien Denize Committed by GitHub
Browse files

[Mistral Grammar] Fix tool and reasoning parsing (#39217)


Signed-off-by: default avatarjuliendenize <julien.denize@mistral.ai>
parent 951dca80
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from dataclasses import dataclass, field
import openai
import pytest
from tests.tool_use.utils import MESSAGES_ASKING_FOR_TOOLS, WEATHER_TOOL
from tests.tool_use.utils import (
MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
MESSAGES_ASKING_FOR_TOOLS,
MESSAGES_WITH_TOOL_RESPONSE,
MESSAGES_WITHOUT_TOOLS,
SEARCH_TOOL,
SEED,
WEATHER_TOOL,
ensure_system_prompt,
)
from .utils import ServerConfig
def _requires_tool_parser(server_config: ServerConfig) -> None:
r"""Skip test if server was not started with --tool-call-parser."""
if "--tool-call-parser" not in server_config.get("arguments", []):
pytest.skip(
f"Skipping: {server_config['model']} not configured with --tool-call-parser"
)
def _is_pre_v11(server_config: ServerConfig) -> bool:
r"""Pre-v11 Mistral models lack grammar-based tool call enforcement."""
return "7B" in server_config.get("model", "")
@dataclass
class StreamedToolCallResult:
r"""Accumulated result from streaming a single tool call."""
function_name: str | None = None
function_args_str: str = ""
tool_call_id: str | None = None
role_name: str | None = None
finish_reason_count: int = 0
finish_reason: str | None = None
async def _collect_streamed_tool_call(
stream: openai.AsyncStream,
*,
expected_finish_reason: str = "tool_calls",
) -> StreamedToolCallResult:
result = StreamedToolCallResult()
async for chunk in stream:
if chunk.choices[0].finish_reason:
result.finish_reason_count += 1
result.finish_reason = chunk.choices[0].finish_reason
assert chunk.choices[0].finish_reason == expected_finish_reason
if chunk.choices[0].delta.role:
assert not result.role_name or result.role_name == "assistant"
result.role_name = "assistant"
streamed_tool_calls = chunk.choices[0].delta.tool_calls
if streamed_tool_calls and len(streamed_tool_calls) > 0:
assert len(streamed_tool_calls) == 1
tool_call = streamed_tool_calls[0]
if tool_call.id:
assert not result.tool_call_id
result.tool_call_id = tool_call.id
if tool_call.function:
if tool_call.function.name:
assert result.function_name is None
result.function_name = tool_call.function.name
if tool_call.function.arguments:
result.function_args_str += tool_call.function.arguments
return result
@dataclass
class StreamedContentResult:
r"""Accumulated result from streaming a content-only response."""
chunks: list[str] = field(default_factory=list)
finish_reason_count: int = 0
finish_reason: str | None = None
role_sent: bool = False
async def _collect_streamed_content(
stream: openai.AsyncStream,
*,
expected_finish_reason: str | None = None,
no_tool_calls: bool = True,
) -> StreamedContentResult:
r"""Consume a streaming response and collect text content."""
result = StreamedContentResult()
async for chunk in stream:
delta = chunk.choices[0].delta
if delta.role:
assert not result.role_sent
assert delta.role == "assistant"
result.role_sent = True
if delta.content:
result.chunks.append(delta.content)
if chunk.choices[0].finish_reason is not None:
result.finish_reason_count += 1
result.finish_reason = chunk.choices[0].finish_reason
if expected_finish_reason is not None:
assert result.finish_reason == expected_finish_reason
if no_tool_calls:
assert not delta.tool_calls or len(delta.tool_calls) == 0
return result
@dataclass
class StreamedParallelToolCallResult:
r"""Accumulated result from streaming parallel tool calls."""
function_names: list[str] = field(default_factory=list)
function_args_strs: list[str] = field(default_factory=list)
tool_call_ids: list[str] = field(default_factory=list)
role_name: str | None = None
finish_reason_count: int = 0
async def _collect_streamed_parallel_tool_calls(
stream: openai.AsyncStream,
) -> StreamedParallelToolCallResult:
r"""Consume a streaming response and collect parallel tool calls."""
result = StreamedParallelToolCallResult()
tool_call_idx: int = -1
async for chunk in stream:
if chunk.choices[0].finish_reason:
result.finish_reason_count += 1
assert chunk.choices[0].finish_reason == "tool_calls"
if chunk.choices[0].delta.role:
assert not result.role_name or result.role_name == "assistant"
result.role_name = "assistant"
streamed_tool_calls = chunk.choices[0].delta.tool_calls
if streamed_tool_calls and len(streamed_tool_calls) > 0:
assert len(streamed_tool_calls) == 1
tool_call = streamed_tool_calls[0]
if tool_call.index != tool_call_idx:
tool_call_idx = tool_call.index
result.function_args_strs.append("")
result.tool_call_ids.append("")
if tool_call.id:
result.tool_call_ids[tool_call.index] = tool_call.id
if tool_call.function:
if tool_call.function.name:
result.function_names.append(tool_call.function.name)
if tool_call.function.arguments:
result.function_args_strs[tool_call.index] += (
tool_call.function.arguments
)
return result
# test: a tool_choice with mistral-tokenizer results in an ID of length 9
@pytest.mark.asyncio
async def test_tool_call_with_tool_choice(client: openai.AsyncOpenAI):
async def test_tool_call_with_tool_choice(
client: openai.AsyncOpenAI, server_config: ServerConfig
) -> None:
_requires_tool_parser(server_config)
models = await client.models.list()
model_name: str = models.data[0].id
chat_completion = await client.chat.completions.create(
messages=MESSAGES_ASKING_FOR_TOOLS,
messages=ensure_system_prompt(MESSAGES_ASKING_FOR_TOOLS, server_config),
temperature=0,
max_completion_tokens=100,
model=model_name,
tools=[WEATHER_TOOL],
tool_choice=WEATHER_TOOL,
logprobs=False,
seed=SEED,
)
choice = chat_completion.choices[0]
......@@ -28,3 +201,307 @@ async def test_tool_call_with_tool_choice(client: openai.AsyncOpenAI):
assert choice.message.role == "assistant"
assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 1
assert len(choice.message.tool_calls[0].id) == 9 # length of 9 for mistral
_NOT_SET = object()
@pytest.mark.asyncio
@pytest.mark.parametrize(
"tools, tool_choice, streaming_id_len_pre_v11",
[
pytest.param(
[WEATHER_TOOL, SEARCH_TOOL],
_NOT_SET,
9,
id="auto",
),
pytest.param(
[WEATHER_TOOL],
"required",
30,
id="required",
),
],
)
async def test_tool_call_auto_or_required(
client: openai.AsyncOpenAI,
server_config: ServerConfig,
tools: list,
tool_choice: object,
streaming_id_len_pre_v11: int,
) -> None:
_requires_tool_parser(server_config)
models = await client.models.list()
model_name: str = models.data[0].id
create_kwargs: dict = {
"messages": ensure_system_prompt(MESSAGES_ASKING_FOR_TOOLS, server_config),
"temperature": 0,
"max_completion_tokens": 100,
"model": model_name,
"tools": tools,
"logprobs": False,
"seed": SEED,
}
if tool_choice is not _NOT_SET:
create_kwargs["tool_choice"] = tool_choice
# --- non-streaming ---
chat_completion = await client.chat.completions.create(**create_kwargs)
choice = chat_completion.choices[0]
tool_calls = choice.message.tool_calls
assert choice.finish_reason == "tool_calls"
assert tool_calls is not None and len(tool_calls) >= 1
assert tool_calls[0].function.name == "get_current_weather"
parsed_arguments = json.loads(tool_calls[0].function.arguments)
assert "city" in parsed_arguments
assert len(tool_calls[0].id) == 9
# --- streaming ---
stream = await client.chat.completions.create(**create_kwargs, stream=True)
result = await _collect_streamed_tool_call(stream)
assert result.finish_reason_count == 1
assert result.role_name == "assistant"
assert result.function_name == "get_current_weather"
streamed_args = json.loads(result.function_args_str)
assert isinstance(result.tool_call_id, str)
if _is_pre_v11(server_config):
assert len(result.tool_call_id) == streaming_id_len_pre_v11
else:
assert len(result.tool_call_id) == 9
assert parsed_arguments == streamed_args
@pytest.mark.asyncio
async def test_tool_call_none_with_tools(
client: openai.AsyncOpenAI, server_config: ServerConfig
) -> None:
_requires_tool_parser(server_config)
models = await client.models.list()
model_name: str = models.data[0].id
# --- non-streaming ---
chat_completion = await client.chat.completions.create(
messages=ensure_system_prompt(MESSAGES_ASKING_FOR_TOOLS, server_config),
temperature=0,
max_completion_tokens=100,
model=model_name,
tools=[WEATHER_TOOL],
tool_choice="none",
logprobs=False,
seed=SEED,
)
choice = chat_completion.choices[0]
assert choice.finish_reason != "tool_calls"
assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0
assert choice.message.content is not None
# Without grammar enforcement, pre-v11 models may still emit [TOOL_CALLS]
if not _is_pre_v11(server_config):
assert "[TOOL_CALLS]" not in choice.message.content
non_streaming_content = choice.message.content
# --- streaming ---
stream = await client.chat.completions.create(
messages=ensure_system_prompt(MESSAGES_ASKING_FOR_TOOLS, server_config),
temperature=0,
max_completion_tokens=100,
model=model_name,
tools=[WEATHER_TOOL],
tool_choice="none",
logprobs=False,
seed=SEED,
stream=True,
)
# Pre-v11 models lack grammar enforcement, so the model may still
# emit tool calls even with tool_choice="none".
pre_v11 = _is_pre_v11(server_config)
result = await _collect_streamed_content(stream, no_tool_calls=not pre_v11)
assert result.finish_reason_count == 1
if not pre_v11:
assert result.finish_reason != "tool_calls"
streamed_content = "".join(result.chunks)
if not pre_v11:
assert "[TOOL_CALLS]" not in streamed_content
assert streamed_content == non_streaming_content
@pytest.mark.asyncio
async def test_chat_without_tools(
client: openai.AsyncOpenAI, server_config: ServerConfig
) -> None:
models = await client.models.list()
model_name: str = models.data[0].id
# --- non-streaming ---
chat_completion = await client.chat.completions.create(
messages=ensure_system_prompt(MESSAGES_WITHOUT_TOOLS, server_config),
temperature=0,
max_completion_tokens=150,
model=model_name,
logprobs=False,
seed=SEED,
)
choice = chat_completion.choices[0]
output_text = choice.message.content
assert output_text is not None and len(output_text) > 0
assert choice.finish_reason != "tool_calls"
assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0
# --- streaming ---
stream = await client.chat.completions.create(
messages=ensure_system_prompt(MESSAGES_WITHOUT_TOOLS, server_config),
temperature=0,
max_completion_tokens=150,
model=model_name,
logprobs=False,
seed=SEED,
stream=True,
)
result = await _collect_streamed_content(
stream, expected_finish_reason=choice.finish_reason
)
assert result.role_sent
assert result.finish_reason_count == 1
assert len(result.chunks)
assert "".join(result.chunks) == output_text
@pytest.mark.asyncio
async def test_tool_call_with_results(
client: openai.AsyncOpenAI, server_config: ServerConfig
) -> None:
_requires_tool_parser(server_config)
models = await client.models.list()
model_name: str = models.data[0].id
# --- non-streaming ---
chat_completion = await client.chat.completions.create(
messages=ensure_system_prompt(MESSAGES_WITH_TOOL_RESPONSE, server_config),
temperature=0,
max_completion_tokens=100,
model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False,
seed=SEED,
)
choice = chat_completion.choices[0]
assert choice.finish_reason != "tool_calls"
assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0
assert choice.message.content is not None
assert "98" in choice.message.content
# --- streaming ---
stream = await client.chat.completions.create(
messages=ensure_system_prompt(MESSAGES_WITH_TOOL_RESPONSE, server_config),
temperature=0,
max_completion_tokens=100,
model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False,
seed=SEED,
stream=True,
)
result = await _collect_streamed_content(
stream, expected_finish_reason=choice.finish_reason
)
assert result.role_sent
assert result.finish_reason_count == 1
assert len(result.chunks)
assert "".join(result.chunks) == choice.message.content
def _requires_parallel(server_config: ServerConfig) -> None:
r"""Skip test if the model does not support parallel tool calls."""
if not server_config.get("supports_parallel"):
pytest.skip(
f"Skipping: {server_config['model']} does not support parallel tool calls"
)
@pytest.mark.asyncio
async def test_tool_call_parallel(
client: openai.AsyncOpenAI, server_config: ServerConfig
) -> None:
_requires_tool_parser(server_config)
_requires_parallel(server_config)
models = await client.models.list()
model_name: str = models.data[0].id
# --- non-streaming ---
chat_completion = await client.chat.completions.create(
messages=ensure_system_prompt(
MESSAGES_ASKING_FOR_PARALLEL_TOOLS, server_config
),
temperature=0,
max_completion_tokens=200,
model=model_name,
tools=[WEATHER_TOOL],
logprobs=False,
seed=SEED,
)
choice = chat_completion.choices[0]
tool_calls = choice.message.tool_calls
assert choice.finish_reason == "tool_calls"
assert tool_calls is not None and len(tool_calls) >= 2
for tc in tool_calls:
assert tc.type == "function"
assert tc.function.name == "get_current_weather"
assert isinstance(tc.function.arguments, str)
parsed = json.loads(tc.function.arguments)
assert "city" in parsed
assert len(tc.id) == 9
non_streaming_tool_calls = tool_calls
# --- streaming ---
stream = await client.chat.completions.create(
messages=ensure_system_prompt(
MESSAGES_ASKING_FOR_PARALLEL_TOOLS, server_config
),
temperature=0,
max_completion_tokens=200,
model=model_name,
tools=[WEATHER_TOOL],
logprobs=False,
seed=SEED,
stream=True,
)
result = await _collect_streamed_parallel_tool_calls(stream)
assert result.finish_reason_count == 1
assert result.role_name == "assistant"
assert len(result.function_names) >= 2
assert all(name == "get_current_weather" for name in result.function_names)
assert len(result.tool_call_ids) >= 2
assert all(isinstance(tid, str) and len(tid) == 9 for tid in result.tool_call_ids)
for args_str in result.function_args_strs:
streamed_args = json.loads(args_str)
assert "city" in streamed_args
assert len(result.function_names) == len(non_streaming_tool_calls)
......@@ -2,16 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing_extensions import TypedDict
class ServerConfig(TypedDict, total=False):
model: str
arguments: list[str]
system_prompt: str | None
supports_parallel: bool | None
supports_rocm: bool | None
from tests.tool_use.utils import ServerConfig
ARGS: list[str] = ["--max-model-len", "1024"]
......@@ -21,6 +12,11 @@ CONFIGS: dict[str, ServerConfig] = {
"arguments": [
"--tokenizer-mode",
"mistral",
"--tool-call-parser",
"mistral",
"--enable-auto-tool-choice",
"--enforce-eager",
"--no-enable-prefix-caching",
'--ignore-patterns="consolidated.safetensors"',
],
"system_prompt": "You are a helpful assistant with access to tools. If a tool"
......@@ -29,4 +25,22 @@ CONFIGS: dict[str, ServerConfig] = {
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
"to the user's question - just respond to it normally.",
},
"ministral-3b": {
"model": "mistralai/Ministral-3-3B-Instruct-2512",
"arguments": [
"--tokenizer-mode",
"mistral",
"--tool-call-parser",
"mistral",
"--enable-auto-tool-choice",
"--enforce-eager",
"--no-enable-prefix-caching",
],
"system_prompt": "You are a helpful assistant with access to tools. If a tool"
" that you have would be helpful to answer a user query, "
"call the tool. Otherwise, answer the user's query directly "
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
"to the user's question - just respond to it normally.",
"supports_parallel": True,
},
}
......@@ -11,7 +11,7 @@ from openai.types.chat.chat_completion_audio import (
ChatCompletionAudio as OpenAIChatCompletionAudio,
)
from openai.types.chat.chat_completion_message import Annotation as OpenAIAnnotation
from pydantic import Field, model_validator
from pydantic import Field, PrivateAttr, model_validator
from vllm.config import ModelConfig
from vllm.config.utils import replace
......@@ -398,6 +398,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
msg["tool_calls"] = list(tool_calls)
return self
_grammar_from_tool_parser: bool = PrivateAttr(default=False)
"""CAUTION: Should only be set by ``ToolParser.adjust_request``."""
def build_chat_params(
self,
default_template: str | None,
......@@ -822,13 +825,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
return data
@model_validator(mode="before")
@classmethod
def set_include_reasoning_for_none_effort(cls, data: Any) -> Any:
if data.get("reasoning_effort") == "none":
data["include_reasoning"] = False
return data
class BatchChatCompletionRequest(OpenAIBaseModel):
"""Request model for the /v1/chat/completions/batch endpoint.
......
......@@ -73,7 +73,10 @@ from vllm.reasoning import ReasoningParser
from vllm.renderers import ChatParams
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.mistral_tool_parser import MistralToolCall
from vllm.tool_parsers.mistral_tool_parser import (
MistralToolCall,
MistralToolParser,
)
from vllm.tool_parsers.utils import partial_json_loads
from vllm.utils.collection_utils import as_list
from vllm.utils.mistral import is_mistral_tokenizer
......@@ -140,6 +143,12 @@ class OpenAIServingChat(OpenAIServing):
enable_auto_tools=enable_auto_tools,
model_name=self.model_config.model,
)
_is_mistral_tool_parser = self.tool_parser is not None and issubclass(
self.tool_parser, MistralToolParser
)
if _is_mistral_tool_parser and self.reasoning_parser_cls is not None:
MistralToolParser.model_can_reason = True
self.exclude_tools_when_tool_choice_none = exclude_tools_when_tool_choice_none
self.enable_prompt_tokens_details = enable_prompt_tokens_details
......@@ -310,6 +319,11 @@ class OpenAIServingChat(OpenAIServing):
else:
if not request.include_reasoning:
reasoning_ended = True
elif request._grammar_from_tool_parser:
# The Mistral grammar already includes an optional
# `think?` rule that handles both reasoning and
# non-reasoning outputs.
reasoning_ended = True
elif reasoning_parser:
reasoning_ended = reasoning_parser.is_reasoning_end(
prompt_token_ids or []
......@@ -530,6 +544,8 @@ class OpenAIServingChat(OpenAIServing):
harmony_tools_streamed = [False] * num_choices
tools_streamed = [False] * num_choices
is_mistral_grammar_path = request._grammar_from_tool_parser
if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
tool_choice_function_name = request.tool_choice.function.name
else:
......@@ -553,7 +569,7 @@ class OpenAIServingChat(OpenAIServing):
# Only one of these will be used, thus previous_texts and
# all_previous_token_ids will not be used twice in the same iteration.
if tool_choice_auto or reasoning_parser:
if is_mistral_grammar_path or tool_choice_auto or reasoning_parser:
# These are only required in "auto" tool choice case
all_previous_token_ids = [[] for _ in range(num_choices)]
reasoning_end_arr = [False] * num_choices
......@@ -748,7 +764,7 @@ class OpenAIServingChat(OpenAIServing):
delta_message: DeltaMessage | None
# just update previous_texts and previous_token_ids
if tool_choice_auto or reasoning_parser:
if is_mistral_grammar_path or tool_choice_auto or reasoning_parser:
assert previous_texts is not None
assert all_previous_token_ids is not None
previous_text = previous_texts[i]
......@@ -772,6 +788,30 @@ class OpenAIServingChat(OpenAIServing):
)
)
harmony_tools_streamed[i] |= tools_streamed_flag
# Mistral grammar path: combined reasoning + tool streaming
elif is_mistral_grammar_path:
assert tool_parser is not None
assert isinstance(tool_parser, MistralToolParser)
assert reasoning_end_arr is not None
output_token_ids = as_list(output.token_ids)
result = tool_parser.extract_maybe_reasoning_and_tool_streaming(
reasoning_parser=reasoning_parser,
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
output_token_ids=output_token_ids,
reasoning_ended=reasoning_end_arr[i],
prompt_is_reasoning_end=(prompt_is_reasoning_end_arr[i]),
request=request,
)
delta_message = result.delta_message
reasoning_end_arr[i] = result.reasoning_ended
current_text = result.current_text
current_token_ids = result.current_token_ids
if result.tools_called:
tools_streamed[i] = True
# handle streaming deltas for tools with named tool_choice
elif tool_choice_function_name:
# When encountering think end id in prompt_token_ids
......@@ -925,7 +965,9 @@ class OpenAIServingChat(OpenAIServing):
delta_message = DeltaMessage(content=delta_text)
# update the previous values for the next iteration
if (tool_choice_auto or reasoning_parser) and not self.use_harmony:
if (
is_mistral_grammar_path or tool_choice_auto or reasoning_parser
) and not self.use_harmony:
assert previous_texts is not None
assert all_previous_token_ids is not None
previous_texts[i] = current_text
......@@ -1312,7 +1354,24 @@ class OpenAIServingChat(OpenAIServing):
tool_call_class = (
MistralToolCall if is_mistral_tokenizer(tokenizer) else ToolCall
)
if (not self.enable_auto_tools or not self.tool_parser) and (
use_mistral_tool_parser = request._grammar_from_tool_parser
if use_mistral_tool_parser:
tool_call_items = MistralToolParser.build_non_streaming_tool_calls(
tool_calls
)
if tool_call_items:
auto_tools_called = (
request.tool_choice is None or request.tool_choice == "auto"
)
message = ChatMessage(
role=role,
reasoning=reasoning,
content=content,
tool_calls=tool_call_items,
)
elif (not self.enable_auto_tools or not self.tool_parser) and (
not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam)
and request.tool_choice != "required"
):
......
......@@ -65,6 +65,7 @@ from vllm.renderers.inputs.preprocess import (
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser
from vllm.tool_parsers.mistral_tool_parser import MistralToolParser
from vllm.tracing import (
contains_trace_headers,
extract_trace_headers,
......@@ -610,16 +611,31 @@ class OpenAIServing:
tool_parser_cls: type[ToolParser] | None,
content: str | None = None,
) -> tuple[list[FunctionCall] | None, str | None]:
# When the Mistral grammar factory injected structured outputs,
# let the parser handle the output.
use_mistral_tool_parser = (
isinstance(request, ChatCompletionRequest)
and tool_parser_cls is not None
and issubclass(tool_parser_cls, MistralToolParser)
and request._grammar_from_tool_parser
)
function_calls = list[FunctionCall]()
if request.tool_choice and isinstance(request.tool_choice, ToolChoiceFunction):
if (
not use_mistral_tool_parser
and request.tool_choice
and isinstance(request.tool_choice, ToolChoiceFunction)
):
assert content is not None
# Forced Function Call
function_calls.append(
FunctionCall(name=request.tool_choice.name, arguments=content)
)
content = None # Clear content since tool is called.
elif request.tool_choice and isinstance(
request.tool_choice, ChatCompletionNamedToolChoiceParam
elif (
not use_mistral_tool_parser
and request.tool_choice
and isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam)
):
assert content is not None
# Forced Function Call
......@@ -627,7 +643,7 @@ class OpenAIServing:
FunctionCall(name=request.tool_choice.function.name, arguments=content)
)
content = None # Clear content since tool is called.
elif request.tool_choice == "required":
elif not use_mistral_tool_parser and request.tool_choice == "required":
tool_calls = []
with contextlib.suppress(ValidationError):
content = content or ""
......@@ -642,10 +658,12 @@ class OpenAIServing:
)
)
content = None # Clear content since tool is called.
elif (
tool_parser_cls
and enable_auto_tools
and (request.tool_choice == "auto" or request.tool_choice is None)
elif tool_parser_cls and (
use_mistral_tool_parser
or (
enable_auto_tools
and (request.tool_choice == "auto" or request.tool_choice is None)
)
):
if tokenizer is None:
raise ValueError(
......
......@@ -53,6 +53,7 @@ from vllm.renderers.inputs.preprocess import (
prompt_to_seq,
)
from vllm.tool_parsers import ToolParser
from vllm.tool_parsers.mistral_tool_parser import MistralToolParser
from vllm.utils import random_uuid
from vllm.utils.mistral import is_mistral_tokenizer
from vllm.utils.mistral import mt as _mt
......@@ -555,9 +556,19 @@ class OpenAIServingRender:
# tool parsing is done only if a tool_parser has been set and if
# tool_choice is not "none" (if tool_choice is "none" but a tool_parser
# is set, we want to prevent parsing a tool_call hallucinated by the LLM
#
# Exception: Mistral grammar-capable tokenizers always call
# adjust_request — even for tool_choice="none" — so that the grammar
# factory can prevent special-token leakage.
if tool_parser is not None:
tool_choice = getattr(request, "tool_choice", "none")
if tool_choice != "none":
tokenizer = renderer.get_tokenizer()
is_mistral_grammar_eligible = (
issubclass(tool_parser, MistralToolParser)
and is_mistral_tokenizer(tokenizer)
and tokenizer.supports_grammar
)
if tool_choice != "none" or is_mistral_grammar_eligible:
if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
msg = (
"Tool usage is only supported "
......@@ -565,7 +576,6 @@ class OpenAIServingRender:
f"but got {type(request).__name__}"
)
raise NotImplementedError(msg)
tokenizer = renderer.get_tokenizer()
request = tool_parser(tokenizer, request.tools).adjust_request(
request=request
)
......
......@@ -157,6 +157,10 @@ def _is_non_tekken_mistral(tokenizer: TokenizerLike) -> bool:
return is_mistral_tokenizer(tokenizer) and not tokenizer.is_tekken
def _get_llg_tokenizer(tokenizer: TokenizerLike) -> Any:
return tokenizer.llg_tokenizer if is_mistral_tokenizer(tokenizer) else None
class SamplingParams(
PydanticMsgspecMixin,
msgspec.Struct,
......@@ -816,7 +820,10 @@ class SamplingParams(
# allows <|special_token|> and similar, see
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars.
validate_guidance_grammar(self, tokenizer=None)
validate_guidance_grammar(
self,
tokenizer=_get_llg_tokenizer(tokenizer),
)
elif backend == "outlines":
# outlines backend
validate_structured_output_request_outlines(self)
......@@ -862,7 +869,10 @@ class SamplingParams(
self.structured_outputs._backend = "outlines"
else:
# Fall back to guidance by default.
validate_guidance_grammar(self, tokenizer=None)
validate_guidance_grammar(
self,
tokenizer=_get_llg_tokenizer(tokenizer),
)
self.structured_outputs._backend = "guidance"
# Remember that this backend was set automatically
self.structured_outputs._backend_was_auto = True
......
......@@ -54,6 +54,50 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
def _pop_unallowed_keys_and_warn(
dictionary: dict[str, Any], allowed_keys: set[str], err_dict_name: str
):
keys = list(dictionary.keys())
for key in keys:
if key not in allowed_keys:
dictionary.pop(key)
logger.warning_once(
f"'{key=}' is not supported by mistral-common "
f"for {err_dict_name}. It has been popped from the "
"object."
)
# TODO(juliendenize): remove this once OpenAI API is better supported by
# `mistral-common`.
def adapt_inplace_to_mistral_tool(
tool: dict[str, Any],
) -> dict[str, Any]:
tools_fields = set(Tool.model_fields.keys())
function_fields = set(Function.model_fields.keys())
# The Mistral client, in comparison to the OpenAI client, requires the
# "parameters" dict and the "description" string to be present
# even if they are empty.
if function := tool.get("function"):
if function.get("parameters") is None:
function["parameters"] = {}
if function.get("description") is None:
function["description"] = ""
_pop_unallowed_keys_and_warn(
dictionary=function,
allowed_keys=function_fields,
err_dict_name="function",
)
_pop_unallowed_keys_and_warn(
dictionary=tool, allowed_keys=tools_fields, err_dict_name="tools"
)
return tool
def maybe_serialize_tool_calls(request: "MistralChatCompletionRequest"):
# SEE: https://github.com/vllm-project/vllm/pull/9951
# Credits go to: @gcalmettes
......@@ -159,44 +203,11 @@ def _prepare_apply_chat_template_tools_and_messages(
# Remove reasoning as unsupported by Mistral
_ = message.pop("reasoning", None) # type: ignore
# The Mistral client, in comparison to the OpenAI client, requires the
# "parameters" dict and the "description" string to be present
# even if they are empty.
if tools:
for function in [
tool["function"] for tool in tools if tool["type"] == "function"
]:
if function.get("parameters") is None:
function["parameters"] = {}
if function.get("description") is None:
function["description"] = ""
# We filter not supported arguments to avoid throwing an error.
# TODO(juliendenize): remove this once OpenAI API is better supported by
# `mistral-common`.
tools_fields = set(Tool.model_fields.keys())
function_fields = set(Function.model_fields.keys())
for tool in tools:
tool_keys = list(tool.keys())
for tool_key in tool_keys:
if tool_key not in tools_fields:
tool.pop(tool_key)
logger.warning_once(
f"'{tool_key}' is not supported by mistral-common for tools. "
"It has been popped from the tool definition."
)
if tool["type"] == "function":
function_keys = list(tool["function"].keys())
for function_key in function_keys:
if function_key not in function_fields:
tool["function"].pop(function_key)
logger.warning_once(
f"'{function_key}' is not supported by mistral-common "
"for function tools. It has been popped from the "
"function definition."
)
else:
raise ValueError("mistral-common only supports function tools.")
tools = (
[adapt_inplace_to_mistral_tool(tool=tool) for tool in tools]
if tools is not None
else None
)
return messages, tools
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import json
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum, auto
from random import choices
from string import ascii_letters, digits
from typing import Any
from typing import TYPE_CHECKING, Any
import ijson
import regex as re
......@@ -37,14 +40,19 @@ from vllm.entrypoints.openai.engine.protocol import (
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger
from vllm.reasoning.mistral_reasoning_parser import MistralReasoningParser
from vllm.sampling_params import StructuredOutputsParams
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer, adapt_inplace_to_mistral_tool
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
from vllm.utils.mistral import is_mistral_tokenizer
if TYPE_CHECKING:
from vllm.reasoning import ReasoningParser
logger = init_logger(__name__)
ALPHANUMERIC = ascii_letters + digits
......@@ -86,13 +94,28 @@ def _is_pre_v11_tokeniser(model_tokenizer: TokenizerLike) -> bool:
return not (is_mistral_tokenizer(model_tokenizer) and model_tokenizer.version >= 11)
class MistralToolParser(ToolParser):
@dataclass
class MistralStreamingResult:
r"""Encapsulates the mutable state returned from
`MistralToolParser.extract_maybe_reasoning_and_tool_streaming`.
"""
Tool call parser for Mistral 7B Instruct v0.3, intended for use with
- [`mistral_common`](https://github.com/mistralai/mistral-common/)
- the examples/tool_chat_template_mistral.jinja template.
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
delta_message: DeltaMessage | None
reasoning_ended: bool
tools_called: bool
current_text: str
current_token_ids: list[int]
class MistralToolParser(ToolParser):
r"""Tool call parser for Mistral models, intended for use with either:
- `mistral_common <https://github.com/mistralai/mistral-common/>`_
(recommended)
- the `examples/tool_chat_template_mistral.jinja` template.
Used when `--enable-auto-tool-choice --tool-call-parser mistral` are all
set.
"""
# Used to generate correct grammar in `adjust_request`
......@@ -210,9 +233,11 @@ class MistralToolParser(ToolParser):
reasoning=self.model_can_reason
)
tools = (
mistral_tools = (
[
MistralTool.from_openai(openai_tool=tool.model_dump())
MistralTool.model_validate(
adapt_inplace_to_mistral_tool(tool.model_dump())
)
for tool in request.tools
]
if request.tools is not None
......@@ -244,15 +269,158 @@ class MistralToolParser(ToolParser):
lark_grammar = grammar_factory.get_lark_from_jinja(
template=template,
mode=tool_choice,
tools=tools,
tools=mistral_tools,
json_schema=json_schema,
parallel_tool_calls=request.parallel_tool_calls,
json_only=False,
)
request.structured_outputs = StructuredOutputsParams(grammar=lark_grammar)
request._grammar_from_tool_parser = True
return request
def extract_maybe_reasoning_and_tool_streaming(
self,
*,
reasoning_parser: ReasoningParser | None,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: list[int],
current_token_ids: list[int],
output_token_ids: Sequence[int],
reasoning_ended: bool,
prompt_is_reasoning_end: bool | None,
request: ChatCompletionRequest,
) -> MistralStreamingResult:
r"""Streaming extraction with reasoning followed by tool-call parsing.
This method encapsulates the combined reasoning extraction and
tool-call streaming logic so that the serving layer only needs a
thin routing branch.
The flow is:
1. If a *reasoning_parser* is present and reasoning has **not** ended,
extract reasoning tokens. Pre-v15 models may have pre-filled
`[THINK]...[/THINK]` in system prompts, so we skip the
prompt-level reasoning-end check for those.
2. Once reasoning ends (or if there is no reasoning parser), delegate
to `extract_tool_calls_streaming` and track whether tools were
called.
Args:
reasoning_parser: Optional reasoning parser instance.
previous_text: Accumulated text from prior chunks.
current_text: Full accumulated text including current chunk.
delta_text: New text in this chunk.
previous_token_ids: Token ids from prior chunks.
current_token_ids: Full token ids including current chunk.
output_token_ids: Raw output token ids from the engine.
reasoning_ended: Whether reasoning has already ended.
prompt_is_reasoning_end: Whether the prompt itself ends reasoning.
request: The originating chat completion request.
"""
delta_message: DeltaMessage | None = None
tools_called = False
reasoning_ended_at_entry = reasoning_ended
# For MistralReasoningParser, only enter the reasoning block when
# the model has actually emitted a [THINK] token. Other reasoning
# parsers always expect thinking to be present.
expect_thinking = (
not isinstance(reasoning_parser, MistralReasoningParser)
or reasoning_parser.start_token_id in current_token_ids
)
if reasoning_parser is not None and not reasoning_ended and expect_thinking:
# Pre-v15 models may have pre-filled [THINK]...[/THINK] in
# system prompts, so skip the prompt-level reasoning-end
# check and wait for the output's own end-of-think.
is_pre_v15 = (
isinstance(self.model_tokenizer, MistralTokenizer)
and self.model_tokenizer.version < 15
)
if not is_pre_v15 and prompt_is_reasoning_end:
reasoning_ended = True
current_token_ids = list(output_token_ids)
else:
delta_message = reasoning_parser.extract_reasoning_streaming(
previous_text,
current_text,
delta_text,
previous_token_ids,
current_token_ids,
output_token_ids,
)
if reasoning_parser.is_reasoning_end_streaming(
current_token_ids, output_token_ids
):
reasoning_ended = True
current_token_ids = reasoning_parser.extract_content_ids(
list(output_token_ids)
)
if delta_message and delta_message.content:
current_text = delta_message.content
delta_message.content = None
else:
current_text = ""
if not reasoning_ended:
return MistralStreamingResult(
delta_message=delta_message,
reasoning_ended=False,
tools_called=False,
current_text=current_text,
current_token_ids=current_token_ids,
)
delta_token_ids = list(output_token_ids)
# On the iteration where reasoning just ended, reset the text/token
# state so the tool parser sees a clean history instead of the
# accumulated reasoning text.
if not reasoning_ended_at_entry and reasoning_ended:
previous_text = ""
previous_token_ids = []
delta_text = current_text
delta_token_ids = current_token_ids
delta_message = self.extract_tool_calls_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids,
request=request,
)
if delta_message and delta_message.tool_calls:
tools_called = True
return MistralStreamingResult(
delta_message=delta_message,
reasoning_ended=reasoning_ended,
tools_called=tools_called,
current_text=current_text,
current_token_ids=current_token_ids,
)
@staticmethod
def build_non_streaming_tool_calls(
tool_calls: list[FunctionCall] | None,
) -> list[ToolCall]:
r"""Build `MistralToolCall` items for non-streaming responses."""
if not tool_calls:
return []
return [
MistralToolCall(id=tc.id, function=tc)
if tc.id
else MistralToolCall(function=tc)
for tc in tool_calls
]
def extract_tool_calls(
self,
model_output: str,
......@@ -323,7 +491,7 @@ class MistralToolParser(ToolParser):
)[0]
tool_calls = json.loads(raw_tool_call)
except (IndexError, json.JSONDecodeError):
logger.exception("Error in extracting tool call from response: {e}")
logger.exception("Error in extracting tool call from response.")
# If raw decoding and decoding post regex rule fails, then just
# return content.
return ExtractedToolCallInformation(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment