Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c0722f22
Unverified
Commit
c0722f22
authored
Apr 16, 2026
by
Julien Denize
Committed by
GitHub
Apr 15, 2026
Browse files
[Mistral Grammar] Fix tool and reasoning parsing (#39217)
Signed-off-by:
juliendenize
<
julien.denize@mistral.ai
>
parent
951dca80
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
1601 additions
and
266 deletions
+1601
-266
tests/tool_parsers/test_mistral_tool_parser.py
tests/tool_parsers/test_mistral_tool_parser.py
+752
-180
tests/tool_use/mistral/test_mistral_tool_calls.py
tests/tool_use/mistral/test_mistral_tool_calls.py
+480
-3
tests/tool_use/mistral/utils.py
tests/tool_use/mistral/utils.py
+24
-10
vllm/entrypoints/openai/chat_completion/protocol.py
vllm/entrypoints/openai/chat_completion/protocol.py
+4
-8
vllm/entrypoints/openai/chat_completion/serving.py
vllm/entrypoints/openai/chat_completion/serving.py
+64
-5
vllm/entrypoints/openai/engine/serving.py
vllm/entrypoints/openai/engine/serving.py
+26
-8
vllm/entrypoints/serve/render/serving.py
vllm/entrypoints/serve/render/serving.py
+12
-2
vllm/sampling_params.py
vllm/sampling_params.py
+12
-2
vllm/tokenizers/mistral.py
vllm/tokenizers/mistral.py
+49
-38
vllm/tool_parsers/mistral_tool_parser.py
vllm/tool_parsers/mistral_tool_parser.py
+178
-10
No files found.
tests/tool_parsers/test_mistral_tool_parser.py
View file @
c0722f22
...
...
@@ -3,6 +3,7 @@
import
json
from
collections.abc
import
Generator
from
typing
import
Any
from
unittest.mock
import
MagicMock
,
patch
import
partial_json_parser
...
...
@@ -23,24 +24,33 @@ from mistral_common.protocol.instruct.tool_calls import (
ToolChoiceEnum
as
MistralToolChoiceEnum
,
)
from
partial_json_parser.core.options
import
Allow
from
pydantic
import
ValidationError
from
vllm.entrypoints.openai.chat_completion.protocol
import
(
ChatCompletionRequest
,
)
from
vllm.entrypoints.openai.engine.protocol
import
(
DeltaFunctionCall
,
DeltaMessage
,
DeltaToolCall
,
ExtractedToolCallInformation
,
StructuralTagResponseFormat
,
)
from
vllm.entrypoints.openai.engine.protocol
import
FunctionCall
as
VllmFunctionCall
from
vllm.reasoning.mistral_reasoning_parser
import
MistralReasoningParser
from
vllm.sampling_params
import
StructuredOutputsParams
from
vllm.tokenizers
import
TokenizerLike
,
get_tokenizer
from
vllm.tokenizers.detokenizer_utils
import
detokenize_incrementally
from
vllm.tokenizers.mistral
import
MistralTokenizer
from
vllm.tool_parsers.mistral_tool_parser
import
(
_DEFAULT_JSON_SCHEMA
,
MistralStreamingResult
,
MistralToolCall
,
MistralToolParser
,
)
_DUMMY_REQUEST
=
ChatCompletionRequest
(
messages
=
[],
model
=
"test"
)
@
pytest
.
fixture
(
scope
=
"module"
)
def
mistral_pre_v11_tokenizer
():
...
...
@@ -205,7 +215,7 @@ def stream_delta_message_generator(
previous_token_ids
,
current_token_ids
,
delta_token_ids
,
request
=
None
,
# type: ignore[arg-type]
request
=
_DUMMY_REQUEST
,
)
if
delta_message
:
yield
delta_message
...
...
@@ -218,14 +228,18 @@ def stream_delta_message_generator(
read_offset
=
new_read_offset
def
test_extract_tool_calls_no_tools
(
mistral_pre_v11_tool_parser
):
@
pytest
.
mark
.
parametrize
(
"parser_fixture"
,
[
"mistral_pre_v11_tool_parser"
,
"mistral_tool_parser"
],
ids
=
[
"pre_v11"
,
"v11"
],
)
def
test_extract_tool_calls_no_tools
(
parser_fixture
,
request
):
parser
=
request
.
getfixturevalue
(
parser_fixture
)
model_output
=
"This is a test"
extracted_tool_calls
=
mistral_pre_v11_tool_parser
.
extract_tool_calls
(
model_output
,
request
=
None
)
# type: ignore[arg-type]
assert
not
extracted_tool_calls
.
tools_called
assert
extracted_tool_calls
.
tool_calls
==
[]
assert
extracted_tool_calls
.
content
==
model_output
result
=
parser
.
extract_tool_calls
(
model_output
,
request
=
_DUMMY_REQUEST
)
assert
result
==
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
)
@
pytest
.
mark
.
parametrize
(
...
...
@@ -234,6 +248,8 @@ def test_extract_tool_calls_no_tools(mistral_pre_v11_tool_parser):
"single_tool_weather"
,
"argument_before_name"
,
"argument_before_name_and_name_in_argument"
,
"multiple_tools"
,
"content_before_tool"
,
],
argnames
=
[
"model_output"
,
"expected_tool_calls"
,
"expected_content"
],
argvalues
=
[
...
...
@@ -292,14 +308,44 @@ def test_extract_tool_calls_no_tools(mistral_pre_v11_tool_parser):
],
None
,
),
(
"""[TOOL_CALLS] [{"name": "add", "arguments": {"a": 3.5, "b": 4}}, {"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]"""
,
# noqa: E501
[
ToolCall
(
function
=
FunctionCall
(
name
=
"add"
,
arguments
=
json
.
dumps
({
"a"
:
3.5
,
"b"
:
4
})
)
),
ToolCall
(
function
=
FunctionCall
(
name
=
"get_current_weather"
,
arguments
=
json
.
dumps
(
{
"city"
:
"San Francisco"
,
"state"
:
"CA"
,
"unit"
:
"celsius"
}
),
)
),
],
None
,
),
(
"""Hello[TOOL_CALLS] [{"name": "add", "arguments":{"a": 1, "b": 2}}]"""
,
# noqa: E501
[
ToolCall
(
function
=
FunctionCall
(
name
=
"add"
,
arguments
=
json
.
dumps
({
"a"
:
1
,
"b"
:
2
})
)
)
],
"Hello"
,
),
],
)
def
test_extract_tool_calls_pre_v11_tokenizer
(
mistral_pre_v11_tool_parser
,
model_output
,
expected_tool_calls
,
expected_content
):
extracted_tool_calls
=
mistral_pre_v11_tool_parser
.
extract_tool_calls
(
model_output
,
request
=
None
)
# type: ignore[arg-type]
model_output
,
request
=
_DUMMY_REQUEST
)
assert
extracted_tool_calls
.
tools_called
assert_tool_calls
(
extracted_tool_calls
.
tool_calls
,
expected_tool_calls
)
...
...
@@ -307,6 +353,46 @@ def test_extract_tool_calls_pre_v11_tokenizer(
assert
extracted_tool_calls
.
content
==
expected_content
def
test_extract_tool_calls_pre_v11_multiple_bot_tokens_raises
(
mistral_pre_v11_tool_parser
,
):
model_output
=
(
'[TOOL_CALLS] [{"name": "add", "arguments":{"a": 1}}]'
'[TOOL_CALLS] [{"name": "sub", "arguments":{"b": 2}}]'
)
with
pytest
.
raises
(
ValueError
,
match
=
"Only one BOT token"
):
mistral_pre_v11_tool_parser
.
extract_tool_calls
(
model_output
,
request
=
_DUMMY_REQUEST
)
def
test_extract_tool_calls_pre_v11_regex_fallback_raises
(
mistral_pre_v11_tool_parser
,
):
"""The regex fallback path finds valid JSON but does not re-serialize
the `arguments` dict to a string, causing a Pydantic
`ValidationError` when constructing `FunctionCall`."""
model_output
=
(
'[TOOL_CALLS] junk [{"name": "add", "arguments":{"a": 1, "b": 2}}] trail'
)
with
pytest
.
raises
(
ValidationError
):
mistral_pre_v11_tool_parser
.
extract_tool_calls
(
model_output
,
request
=
_DUMMY_REQUEST
)
def
test_extract_tool_calls_pre_v11_regex_fallback_fails
(
mistral_pre_v11_tool_parser
,
):
model_output
=
"[TOOL_CALLS] not json at all"
result
=
mistral_pre_v11_tool_parser
.
extract_tool_calls
(
model_output
,
request
=
_DUMMY_REQUEST
)
assert
result
==
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
content
=
"not json at all"
)
@
pytest
.
mark
.
parametrize
(
ids
=
[
"single_tool_add"
,
...
...
@@ -395,8 +481,8 @@ def test_extract_tool_calls(
mistral_tool_parser
,
model_output
,
expected_tool_calls
,
expected_content
):
extracted_tool_calls
=
mistral_tool_parser
.
extract_tool_calls
(
model_output
,
request
=
None
)
# type: ignore[arg-type]
model_output
,
request
=
_DUMMY_REQUEST
)
assert
extracted_tool_calls
.
tools_called
assert_tool_calls
(
extracted_tool_calls
.
tool_calls
,
expected_tool_calls
)
...
...
@@ -404,6 +490,16 @@ def test_extract_tool_calls(
assert
extracted_tool_calls
.
content
==
expected_content
def
test_extract_tool_calls_v11_without_args_skipped
(
mistral_tool_parser
):
model_output
=
"[TOOL_CALLS]toolname_no_args"
result
=
mistral_tool_parser
.
extract_tool_calls
(
model_output
,
request
=
_DUMMY_REQUEST
)
assert
result
==
ExtractedToolCallInformation
(
tools_called
=
True
,
tool_calls
=
[],
content
=
None
)
def
_test_extract_tool_calls_streaming
(
tool_parser
,
tokenizer
,
model_output
,
tools
,
expected_tool_calls
,
expected_content
):
...
...
@@ -669,17 +765,65 @@ def test_extract_tool_calls_streaming(
)
def
test_extract_tool_calls_streaming_v11_no_tools
(
mistral_tool_parser
,
mistral_tokenizer
):
model_output
=
"This is a test"
if
isinstance
(
mistral_tokenizer
,
MistralTokenizer
):
all_token_ids
=
mistral_tokenizer
.
encode
(
model_output
)
else
:
all_token_ids
=
mistral_tokenizer
.
encode
(
model_output
,
add_special_tokens
=
False
)
skip_special
=
isinstance
(
mistral_tokenizer
,
MistralTokenizer
)
collected_content
=
""
previous_text
=
""
previous_tokens
=
None
prefix_offset
=
0
read_offset
=
0
for
i
in
range
(
len
(
all_token_ids
)):
current_token_ids
=
all_token_ids
[:
i
+
1
]
previous_token_ids
=
all_token_ids
[:
i
]
delta_token_ids
=
[
all_token_ids
[
i
]]
new_tokens
,
delta_text
,
prefix_offset
,
read_offset
=
detokenize_incrementally
(
tokenizer
=
mistral_tokenizer
,
all_input_ids
=
current_token_ids
,
prev_tokens
=
previous_tokens
,
prefix_offset
=
prefix_offset
,
read_offset
=
read_offset
,
skip_special_tokens
=
skip_special
,
spaces_between_special_tokens
=
True
,
)
current_text
=
previous_text
+
delta_text
previous_tokens
=
(
previous_tokens
+
new_tokens
if
previous_tokens
else
new_tokens
)
delta_message
=
mistral_tool_parser
.
extract_tool_calls_streaming
(
previous_text
=
previous_text
,
current_text
=
current_text
,
delta_text
=
delta_text
,
previous_token_ids
=
previous_token_ids
,
current_token_ids
=
current_token_ids
,
delta_token_ids
=
delta_token_ids
,
request
=
_DUMMY_REQUEST
,
)
if
delta_message
and
delta_message
.
content
:
collected_content
+=
delta_message
.
content
if
delta_message
:
assert
not
delta_message
.
tool_calls
previous_text
=
current_text
assert
collected_content
==
model_output
@
pytest
.
mark
.
parametrize
(
ids
=
[
"single_tool_add"
,
"single_tool_weather"
,
"multiple_tool_calls"
,
"content_before_tool"
,
"complex"
,
],
argnames
=
[
"model_output"
,
"expected_tool_calls"
,
"expected_content"
],
argvalues
=
[
(
"parser_fixture, tokenizer_fixture, model_output,"
" expected_tool_calls, expected_content"
,
[
pytest
.
param
(
"mistral_tool_parser"
,
"mistral_tokenizer"
,
"""[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}"""
,
# noqa: E501
[
ToolCall
(
...
...
@@ -690,8 +834,11 @@ def test_extract_tool_calls_streaming(
)
],
""
,
id
=
"v11-single_tool_add"
,
),
(
pytest
.
param
(
"mistral_tool_parser"
,
"mistral_tokenizer"
,
"""[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"}"""
,
# noqa: E501
[
ToolCall
(
...
...
@@ -704,8 +851,11 @@ def test_extract_tool_calls_streaming(
)
],
""
,
id
=
"v11-single_tool_weather"
,
),
(
pytest
.
param
(
"mistral_tool_parser"
,
"mistral_tokenizer"
,
"""[TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]multiply{"a": 3, "b": 6}"""
,
# noqa: E501
[
ToolCall
(
...
...
@@ -720,9 +870,11 @@ def test_extract_tool_calls_streaming(
),
],
""
,
id
=
"v11-multiple_tool_calls"
,
),
(
# Additional content should not be after the tool calls
pytest
.
param
(
"mistral_tool_parser"
,
"mistral_tokenizer"
,
"""bla[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}"""
,
# noqa: E501
[
ToolCall
(
...
...
@@ -733,9 +885,11 @@ def test_extract_tool_calls_streaming(
)
],
"bla"
,
id
=
"v11-content_before_tool"
,
),
(
# Complex
pytest
.
param
(
"mistral_tool_parser"
,
"mistral_tokenizer"
,
"""hi{hi[TOOL_CALLS]bash{"command": "print(
\\
"hello world!
\\
")
\\
nre.compile(r
\'
{}
\'
)"}"""
,
# noqa: E501
[
ToolCall
(
...
...
@@ -748,58 +902,19 @@ def test_extract_tool_calls_streaming(
)
],
"hi{hi"
,
),
],
)
def
test_extract_tool_calls_streaming_one_chunk
(
mistral_tool_parser
,
mistral_tokenizer
,
model_output
,
expected_tool_calls
,
expected_content
,
):
if
isinstance
(
mistral_tokenizer
,
MistralTokenizer
):
all_token_ids
=
mistral_tokenizer
.
encode
(
model_output
)
else
:
all_token_ids
=
mistral_tokenizer
.
encode
(
model_output
,
add_special_tokens
=
False
)
all_token_ids
=
fix_tool_call_tokenization
(
all_token_ids
,
mistral_tool_parser
,
mistral_tokenizer
)
delta_message
=
mistral_tool_parser
.
extract_tool_calls_streaming
(
previous_text
=
""
,
current_text
=
model_output
,
delta_text
=
model_output
,
previous_token_ids
=
[],
current_token_ids
=
all_token_ids
,
delta_token_ids
=
all_token_ids
,
request
=
None
,
)
# type: ignore[arg-type]
assert
isinstance
(
delta_message
,
DeltaMessage
)
assert
len
(
delta_message
.
tool_calls
)
==
len
(
expected_tool_calls
)
assert_tool_calls
(
delta_message
.
tool_calls
,
expected_tool_calls
)
if
delta_message
.
content
is
None
:
assert
expected_content
==
""
else
:
assert
delta_message
.
content
==
expected_content
@
pytest
.
mark
.
parametrize
(
ids
=
[
"no_tools"
,
"single_tool_add"
,
"single_tool_add_strings"
,
"single_tool_weather"
,
"argument_before_name"
,
"argument_before_name_and_name_in_argument"
,
"multiple_tools"
,
],
argnames
=
[
"model_output"
,
"expected_tool_calls"
,
"expected_content"
],
argvalues
=
[
(
"""This is a test"""
,
[],
"""This is a test"""
),
(
id
=
"v11-complex"
,
),
pytest
.
param
(
"mistral_pre_v11_tool_parser"
,
"mistral_pre_v11_tokenizer"
,
"""This is a test"""
,
[],
"""This is a test"""
,
id
=
"pre_v11-no_tools"
,
),
pytest
.
param
(
"mistral_pre_v11_tool_parser"
,
"mistral_pre_v11_tokenizer"
,
"""[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]"""
,
# noqa: E501
[
ToolCall
(
...
...
@@ -809,8 +924,11 @@ def test_extract_tool_calls_streaming_one_chunk(
)
],
""
,
id
=
"pre_v11-single_tool_add"
,
),
(
pytest
.
param
(
"mistral_pre_v11_tool_parser"
,
"mistral_pre_v11_tokenizer"
,
"""[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]"""
,
# noqa: E501
[
ToolCall
(
...
...
@@ -820,8 +938,11 @@ def test_extract_tool_calls_streaming_one_chunk(
)
],
""
,
id
=
"pre_v11-single_tool_add_strings"
,
),
(
pytest
.
param
(
"mistral_pre_v11_tool_parser"
,
"mistral_pre_v11_tokenizer"
,
"""[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]"""
,
# noqa: E501
[
ToolCall
(
...
...
@@ -834,8 +955,11 @@ def test_extract_tool_calls_streaming_one_chunk(
)
],
""
,
id
=
"pre_v11-single_tool_weather"
,
),
(
pytest
.
param
(
"mistral_pre_v11_tool_parser"
,
"mistral_pre_v11_tokenizer"
,
"""[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]"""
,
# noqa: E501
[
ToolCall
(
...
...
@@ -848,8 +972,11 @@ def test_extract_tool_calls_streaming_one_chunk(
)
],
""
,
id
=
"pre_v11-argument_before_name"
,
),
(
pytest
.
param
(
"mistral_pre_v11_tool_parser"
,
"mistral_pre_v11_tokenizer"
,
"""[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]"""
,
# noqa: E501
[
ToolCall
(
...
...
@@ -864,8 +991,11 @@ def test_extract_tool_calls_streaming_one_chunk(
)
],
""
,
id
=
"pre_v11-argument_before_name_and_name_in_argument"
,
),
(
pytest
.
param
(
"mistral_pre_v11_tool_parser"
,
"mistral_pre_v11_tokenizer"
,
"""[TOOL_CALLS] [{"arguments": {"a": 3.5, "b": 4}, "name": "add"}, {"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]"""
,
# noqa: E501
[
ToolCall
(
...
...
@@ -883,35 +1013,50 @@ def test_extract_tool_calls_streaming_one_chunk(
),
],
""
,
id
=
"pre_v11-multiple_tools"
,
),
pytest
.
param
(
"mistral_pre_v11_tool_parser"
,
"mistral_pre_v11_tokenizer"
,
"""Some text[TOOL_CALLS] [{"name": "add", "arguments":{"a": 1, "b": 2}}]"""
,
# noqa: E501
[
ToolCall
(
function
=
FunctionCall
(
name
=
"add"
,
arguments
=
json
.
dumps
({
"a"
:
1
,
"b"
:
2
})
)
)
],
"Some text"
,
id
=
"pre_v11-content_before_tool"
,
),
],
)
def
test_extract_tool_calls_streaming_
pre_v11_tokenizer_
one_chunk
(
mistral_pre_v11_tool_parser
,
mistral_pre_v11_
tokenizer
,
def
test_extract_tool_calls_streaming_one_chunk
(
parser_fixture
,
tokenizer
_fixture
,
model_output
,
expected_tool_calls
,
expected_content
,
request
,
):
if
isinstance
(
mistral_pre_v11_tokenizer
,
MistralTokenizer
):
all_token_ids
=
mistral_pre_v11_tokenizer
.
encode
(
model_output
)
tool_parser
=
request
.
getfixturevalue
(
parser_fixture
)
tokenizer
=
request
.
getfixturevalue
(
tokenizer_fixture
)
if
isinstance
(
tokenizer
,
MistralTokenizer
):
all_token_ids
=
tokenizer
.
encode
(
model_output
)
else
:
all_token_ids
=
mistral_pre_v11_tokenizer
.
encode
(
model_output
,
add_special_tokens
=
False
)
all_token_ids
=
fix_tool_call_tokenization
(
all_token_ids
,
mistral_pre_v11_tool_parser
,
mistral_pre_v11_tokenizer
)
all_token_ids
=
tokenizer
.
encode
(
model_output
,
add_special_tokens
=
False
)
all_token_ids
=
fix_tool_call_tokenization
(
all_token_ids
,
tool_parser
,
tokenizer
)
delta_message
=
mistral_pre_v11_
tool_parser
.
extract_tool_calls_streaming
(
delta_message
=
tool_parser
.
extract_tool_calls_streaming
(
previous_text
=
""
,
current_text
=
model_output
,
delta_text
=
model_output
,
previous_token_ids
=
[],
current_token_ids
=
all_token_ids
,
delta_token_ids
=
all_token_ids
,
request
=
None
,
)
# type: ignore[arg-type]
request
=
_DUMMY_REQUEST
,
)
assert
isinstance
(
delta_message
,
DeltaMessage
)
assert
len
(
delta_message
.
tool_calls
)
==
len
(
expected_tool_calls
)
...
...
@@ -923,65 +1068,105 @@ def test_extract_tool_calls_streaming_pre_v11_tokenizer_one_chunk(
assert
delta_message
.
content
==
expected_content
def
test_fast_detokenization_text_detection
(
mistral_tool_parser
):
@
pytest
.
mark
.
parametrize
(
"parser_fixture, model_output, fake_count, two_phase"
,
[
pytest
.
param
(
"mistral_tool_parser"
,
'[TOOL_CALLS]add{"a": 1, "b": 2}'
,
20
,
True
,
id
=
"v11"
,
),
pytest
.
param
(
"mistral_pre_v11_tool_parser"
,
'[TOOL_CALLS] [{"name": "add", "arguments":{"a": 1, "b": 2}}]'
,
30
,
False
,
id
=
"pre_v11"
,
),
],
)
def
test_fast_detokenization_text_detection
(
parser_fixture
,
model_output
,
fake_count
,
two_phase
,
request
):
"""Regression: bot_token in text but not token_ids (PR #37209)."""
model_output
=
'[TOOL_CALLS]add{"a": 1, "b": 2}'
parser
=
request
.
getfixturevalue
(
parser_fixture
)
# Token IDs that do NOT contain bot_token_id.
fake_token_ids
=
list
(
range
(
99
,
99
+
20
))
fake_token_ids
=
list
(
range
(
99
,
99
+
fake_count
))
if
two_phase
:
# First delta: pure content, no bot token yet
delta_message_before
=
mistral_tool_
parser
.
extract_tool_calls_streaming
(
delta_message_before
=
parser
.
extract_tool_calls_streaming
(
previous_text
=
""
,
current_text
=
"Hello"
,
delta_text
=
"Hello"
,
previous_token_ids
=
[],
current_token_ids
=
[
99
],
delta_token_ids
=
[
99
],
request
=
None
,
request
=
_DUMMY_REQUEST
,
)
assert
delta_message_before
is
not
None
assert
delta_message_before
.
content
==
"Hello"
assert
not
delta_message_before
.
tool_calls
# Second delta: bot token in text but NOT in token_ids
delta_message
=
mistral_tool_parser
.
extract_tool_calls_streaming
(
previous_text
=
"Hello"
,
current_text
=
"Hello"
+
model_output
,
previous_text
=
"Hello"
current_text
=
"Hello"
+
model_output
previous_token_ids
=
[
99
]
delta_token_ids
=
fake_token_ids
[
1
:]
else
:
previous_text
=
""
current_text
=
model_output
previous_token_ids
=
[]
delta_token_ids
=
fake_token_ids
delta_message
=
parser
.
extract_tool_calls_streaming
(
previous_text
=
previous_text
,
current_text
=
current_text
,
delta_text
=
model_output
,
previous_token_ids
=
[
99
]
,
previous_token_ids
=
previous_token_ids
,
current_token_ids
=
fake_token_ids
,
delta_token_ids
=
fake
_token_ids
[
1
:]
,
request
=
None
,
delta_token_ids
=
delta
_token_ids
,
request
=
_DUMMY_REQUEST
,
)
assert
delta_message
is
not
None
assert
delta_message
.
tool_calls
is
not
None
assert
len
(
delta_message
.
tool_calls
)
>
0
assert
len
(
delta_message
.
tool_calls
)
==
1
assert
delta_message
.
tool_calls
[
0
].
function
is
not
None
assert
delta_message
.
tool_calls
[
0
].
function
.
name
==
"add"
def
test_fast_detokenization_text_detection_pre_v11
(
mistral_pre_v11_tool_parser
,
@
pytest
.
mark
.
parametrize
(
"parser_fixture, patched_method, current_text"
,
[
(
"mistral_tool_parser"
,
"_extract_tool_calls_streaming"
,
"[TOOL_CALLS]add{}"
,
),
(
"mistral_pre_v11_tool_parser"
,
"_extract_tool_calls_streaming_pre_v11_tokenizer"
,
'[TOOL_CALLS] [{"name":"a","arguments":{}}]'
,
),
],
ids
=
[
"v11"
,
"pre_v11"
],
)
def
test_extract_tool_calls_streaming_exception_returns_none
(
parser_fixture
,
patched_method
,
current_text
,
request
):
"""Regression: bot_token text detection for pre-v11 tokenizer (PR #37209)."""
model_output
=
'[TOOL_CALLS] [{"name": "add", "arguments":{"a": 1, "b": 2}}]'
fake_token_ids
=
list
(
range
(
99
,
99
+
30
))
delta_message
=
mistral_pre_v11_tool_parser
.
extract_tool_calls_streaming
(
parser
=
request
.
getfixturevalue
(
parser_fixture
)
with
patch
.
object
(
parser
,
patched_method
,
side_effect
=
RuntimeError
(
"boom"
)):
result
=
parser
.
extract_tool_calls_streaming
(
previous_text
=
""
,
current_text
=
model_outpu
t
,
delta_text
=
model_outpu
t
,
current_text
=
current_tex
t
,
delta_text
=
current_tex
t
,
previous_token_ids
=
[],
current_token_ids
=
fake
_token_id
s
,
delta_token_ids
=
fake
_token_id
s
,
request
=
None
,
current_token_ids
=
[
parser
.
bot
_token_id
]
,
delta_token_ids
=
[
parser
.
bot
_token_id
]
,
request
=
_DUMMY_REQUEST
,
)
assert
delta_message
is
not
None
assert
delta_message
.
tool_calls
is
not
None
assert
len
(
delta_message
.
tool_calls
)
>
0
assert
delta_message
.
tool_calls
[
0
].
function
is
not
None
assert
delta_message
.
tool_calls
[
0
].
function
.
name
==
"add"
assert
result
is
None
SAMPLE_TOOLS_DICTS
=
[
...
...
@@ -1238,57 +1423,444 @@ def test_adjust_request_response_format_generates_grammar(
assert
len
(
result
.
structured_outputs
.
grammar
)
>
0
def
test_adjust_request_tool_choice_none_with_json_schema_uses_json_schema_factory
(
@
pytest
.
mark
.
parametrize
(
"tool_choice, expected_method, not_called_method"
,
[
(
"none"
,
"get_lark_for_json_schema"
,
None
),
(
"auto"
,
"get_lark_from_jinja"
,
"get_lark_for_json_schema"
),
],
ids
=
[
"none_uses_json_schema_factory"
,
"auto_uses_jinja_factory"
],
)
def
test_adjust_request_tool_choice_with_json_schema_factory_routing
(
mistral_tool_parser
:
MistralToolParser
,
tool_choice
:
str
,
expected_method
:
str
,
not_called_method
:
str
|
None
,
)
->
None
:
request
=
_make_request
(
tool_choice
=
"none"
,
tool_choice
=
tool_choice
,
structured_outputs
=
StructuredOutputsParams
(
json
=
'{"type": "object"}'
),
)
factory
=
mistral_tool_parser
.
model_tokenizer
.
grammar_factory
with
patch
.
object
(
patches
=
{
expected_method
:
patch
.
object
(
factory
,
expected_method
,
wraps
=
getattr
(
factory
,
expected_method
),
),
}
if
not_called_method
:
patches
[
not_called_method
]
=
patch
.
object
(
factory
,
"get_lark_for_json_schema"
,
wraps
=
factory
.
get_lark_for_json_schema
,
)
as
mock_json_schema
:
not_called_method
,
wraps
=
getattr
(
factory
,
not_called_method
),
)
with
patches
[
expected_method
]
as
mock_expected
:
ctx
=
patches
[
not_called_method
]
if
not_called_method
else
None
if
ctx
:
with
ctx
as
mock_not_called
:
result
=
mistral_tool_parser
.
adjust_request
(
request
)
mock_not_called
.
assert_not_called
()
else
:
result
=
mistral_tool_parser
.
adjust_request
(
request
)
mock_
json_schema
.
assert_called_once
()
assert
mock_
json_schema
.
call_args
.
kwargs
[
"json_schema"
]
==
{
"type"
:
"object"
}
mock_
expected
.
assert_called_once
()
assert
mock_
expected
.
call_args
.
kwargs
[
"json_schema"
]
==
{
"type"
:
"object"
}
assert
result
.
structured_outputs
is
not
None
assert
isinstance
(
result
.
structured_outputs
.
grammar
,
str
)
assert
len
(
result
.
structured_outputs
.
grammar
)
>
0
def
test_adjust_request_tool_choice_auto_with_json_schema_uses_jinja_factory
(
def
test_grammar_from_tool_parser_default_false
()
->
None
:
request
=
_make_request
()
assert
request
.
_grammar_from_tool_parser
is
False
def
test_grammar_from_tool_parser_set_by_adjust_request
(
mistral_tool_parser
:
MistralToolParser
,
)
->
None
:
request
=
_make_request
(
tool_choice
=
"auto"
,
structured_outputs
=
StructuredOutputsParams
(
json
=
'{"type": "object"}'
),
)
factory
=
mistral_tool_parser
.
model_tokenizer
.
grammar_factory
request
=
_make_request
()
result
=
mistral_tool_parser
.
adjust_request
(
request
)
assert
result
.
_grammar_from_tool_parser
is
True
with
(
patch
.
object
(
factory
,
"get_lark_for_json_schema"
,
wraps
=
factory
.
get_lark_for_json_schema
,
)
as
mock_json_schema
,
patch
.
object
(
factory
,
"get_lark_from_jinja"
,
wraps
=
factory
.
get_lark_from_jinja
,
)
as
mock_jinja
,
@
pytest
.
mark
.
parametrize
(
"tool_calls, expected_len"
,
[
(
None
,
0
),
([],
0
),
([
VllmFunctionCall
(
id
=
"abc123xyz"
,
name
=
"f"
,
arguments
=
"{}"
)],
1
),
([
VllmFunctionCall
(
name
=
"f"
,
arguments
=
"{}"
)],
1
),
(
[
VllmFunctionCall
(
id
=
"fixed1234"
,
name
=
"a"
,
arguments
=
'{"x": 1}'
),
VllmFunctionCall
(
name
=
"b"
,
arguments
=
'{"y": 2}'
),
],
2
,
),
],
ids
=
[
"none"
,
"empty"
,
"with_id"
,
"without_id"
,
"mixed"
],
)
def
test_build_non_streaming_tool_calls
(
tool_calls
:
list
[
VllmFunctionCall
]
|
None
,
expected_len
:
int
,
)
->
None
:
result
=
MistralToolParser
.
build_non_streaming_tool_calls
(
tool_calls
)
assert
len
(
result
)
==
expected_len
if
tool_calls
is
None
:
return
for
i
,
tc
in
enumerate
(
result
):
assert
isinstance
(
tc
,
MistralToolCall
)
assert
tc
.
type
==
"function"
input_tc
=
tool_calls
[
i
]
if
input_tc
.
id
:
assert
tc
.
id
==
input_tc
.
id
else
:
assert
len
(
tc
.
id
)
==
9
assert
tc
.
id
.
isalnum
()
assert
tc
.
function
.
name
==
input_tc
.
name
assert
tc
.
function
.
arguments
==
input_tc
.
arguments
class
TestExtractMaybeReasoningAndToolStreaming
:
r
"""Tests for `MistralToolParser.extract_maybe_reasoning_and_tool_streaming`."""
@
pytest
.
fixture
def
parser
(
self
)
->
MistralToolParser
:
mock_tokenizer
=
MagicMock
()
mock_tokenizer
.
get_vocab
.
return_value
=
{
"[TOOL_CALLS]"
:
1
}
return
MistralToolParser
(
mock_tokenizer
)
@
pytest
.
fixture
def
request_obj
(
self
)
->
ChatCompletionRequest
:
return
_make_request
()
@
staticmethod
def
_call
(
parser
:
MistralToolParser
,
request
:
ChatCompletionRequest
,
*
,
reasoning_parser
:
Any
=
None
,
previous_text
:
str
=
""
,
current_text
:
str
=
"hello"
,
delta_text
:
str
=
"hello"
,
previous_token_ids
:
list
[
int
]
|
None
=
None
,
current_token_ids
:
list
[
int
]
|
None
=
None
,
output_token_ids
:
list
[
int
]
|
None
=
None
,
reasoning_ended
:
bool
=
False
,
prompt_is_reasoning_end
:
bool
|
None
=
None
,
)
->
MistralStreamingResult
:
return
parser
.
extract_maybe_reasoning_and_tool_streaming
(
reasoning_parser
=
reasoning_parser
,
previous_text
=
previous_text
,
current_text
=
current_text
,
delta_text
=
delta_text
,
previous_token_ids
=
previous_token_ids
or
[],
current_token_ids
=
current_token_ids
or
[
1
,
2
,
3
],
output_token_ids
=
output_token_ids
or
[
1
,
2
,
3
],
reasoning_ended
=
reasoning_ended
,
prompt_is_reasoning_end
=
prompt_is_reasoning_end
,
request
=
request
,
)
def
test_no_reasoning_tools_called
(
self
,
parser
:
MistralToolParser
,
request_obj
:
ChatCompletionRequest
)
->
None
:
tool_delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
0
,
function
=
DeltaFunctionCall
(
name
=
"f"
,
arguments
=
"{}"
),
)
]
)
with
patch
.
object
(
parser
,
"extract_tool_calls_streaming"
,
return_value
=
tool_delta
):
result
=
mistral_tool_parser
.
adjust_request
(
request
)
result
=
self
.
_call
(
parser
,
request_obj
,
reasoning_parser
=
None
)
mock_jinja
.
assert_called_once
()
assert
mock_jinja
.
call_args
.
kwargs
[
"json_schema"
]
==
{
"type"
:
"object"
}
mock_json_schema
.
assert_not_called
()
assert
result
==
MistralStreamingResult
(
delta_message
=
tool_delta
,
reasoning_ended
=
False
,
tools_called
=
True
,
current_text
=
"hello"
,
current_token_ids
=
[
1
,
2
,
3
],
)
assert
result
.
structured_outputs
is
not
None
assert
isinstance
(
result
.
structured_outputs
.
grammar
,
str
)
assert
len
(
result
.
structured_outputs
.
grammar
)
>
0
def
test_no_reasoning_no_tools
(
self
,
parser
:
MistralToolParser
,
request_obj
:
ChatCompletionRequest
)
->
None
:
content_delta
=
DeltaMessage
(
content
=
"hello"
)
with
patch
.
object
(
parser
,
"extract_tool_calls_streaming"
,
return_value
=
content_delta
):
result
=
self
.
_call
(
parser
,
request_obj
,
reasoning_parser
=
None
)
assert
result
==
MistralStreamingResult
(
delta_message
=
content_delta
,
reasoning_ended
=
False
,
tools_called
=
False
,
current_text
=
"hello"
,
current_token_ids
=
[
1
,
2
,
3
],
)
def
test_mistral_reasoning_parser_no_think_token
(
self
,
parser
:
MistralToolParser
,
request_obj
:
ChatCompletionRequest
)
->
None
:
mock_rp
=
MagicMock
(
spec
=
MistralReasoningParser
)
mock_rp
.
start_token_id
=
999
content_delta
=
DeltaMessage
(
content
=
"direct"
)
with
patch
.
object
(
parser
,
"extract_tool_calls_streaming"
,
return_value
=
content_delta
):
result
=
self
.
_call
(
parser
,
request_obj
,
reasoning_parser
=
mock_rp
,
reasoning_ended
=
False
,
current_token_ids
=
[
1
,
2
,
3
],
)
mock_rp
.
extract_reasoning_streaming
.
assert_not_called
()
assert
result
==
MistralStreamingResult
(
delta_message
=
content_delta
,
reasoning_ended
=
False
,
tools_called
=
False
,
current_text
=
"hello"
,
current_token_ids
=
[
1
,
2
,
3
],
)
def
test_mistral_reasoning_parser_with_think_token
(
self
,
parser
:
MistralToolParser
,
request_obj
:
ChatCompletionRequest
)
->
None
:
mock_rp
=
MagicMock
(
spec
=
MistralReasoningParser
)
mock_rp
.
start_token_id
=
999
mock_rp
.
extract_reasoning_streaming
.
return_value
=
DeltaMessage
(
reasoning
=
"thinking..."
)
mock_rp
.
is_reasoning_end_streaming
.
return_value
=
False
result
=
self
.
_call
(
parser
,
request_obj
,
reasoning_parser
=
mock_rp
,
reasoning_ended
=
False
,
current_token_ids
=
[
1
,
999
,
3
],
)
mock_rp
.
extract_reasoning_streaming
.
assert_called_once
()
assert
result
==
MistralStreamingResult
(
delta_message
=
DeltaMessage
(
reasoning
=
"thinking..."
),
reasoning_ended
=
False
,
tools_called
=
False
,
current_text
=
"hello"
,
current_token_ids
=
[
1
,
999
,
3
],
)
def
test_non_mistral_reasoning_parser_always_expects_thinking
(
self
,
parser
:
MistralToolParser
,
request_obj
:
ChatCompletionRequest
)
->
None
:
mock_rp
=
MagicMock
()
mock_rp
.
start_token_id
=
999
mock_rp
.
extract_reasoning_streaming
.
return_value
=
DeltaMessage
(
reasoning
=
"thinking..."
)
mock_rp
.
is_reasoning_end_streaming
.
return_value
=
False
result
=
self
.
_call
(
parser
,
request_obj
,
reasoning_parser
=
mock_rp
,
reasoning_ended
=
False
,
current_token_ids
=
[
1
,
2
,
3
],
)
mock_rp
.
extract_reasoning_streaming
.
assert_called_once
()
assert
result
==
MistralStreamingResult
(
delta_message
=
DeltaMessage
(
reasoning
=
"thinking..."
),
reasoning_ended
=
False
,
tools_called
=
False
,
current_text
=
"hello"
,
current_token_ids
=
[
1
,
2
,
3
],
)
def
test_reasoning_already_ended_no_reset
(
self
,
parser
:
MistralToolParser
,
request_obj
:
ChatCompletionRequest
)
->
None
:
content_delta
=
DeltaMessage
(
content
=
"content"
)
with
patch
.
object
(
parser
,
"extract_tool_calls_streaming"
,
return_value
=
content_delta
)
as
mock_extract
:
result
=
self
.
_call
(
parser
,
request_obj
,
reasoning_parser
=
MagicMock
(),
reasoning_ended
=
True
,
previous_text
=
"prior_tool_text"
,
previous_token_ids
=
[
10
,
20
],
current_text
=
"prior_tool_texthello"
,
current_token_ids
=
[
10
,
20
,
1
,
2
,
3
],
)
_
,
call_kwargs
=
mock_extract
.
call_args
assert
call_kwargs
[
"previous_text"
]
==
"prior_tool_text"
assert
call_kwargs
[
"previous_token_ids"
]
==
[
10
,
20
]
assert
result
==
MistralStreamingResult
(
delta_message
=
content_delta
,
reasoning_ended
=
True
,
tools_called
=
False
,
current_text
=
"prior_tool_texthello"
,
current_token_ids
=
[
10
,
20
,
1
,
2
,
3
],
)
def
test_pre_v15_ignores_prompt_reasoning_end
(
self
,
parser
:
MistralToolParser
,
request_obj
:
ChatCompletionRequest
)
->
None
:
mock_tokenizer
=
MagicMock
(
spec
=
MistralTokenizer
)
mock_tokenizer
.
version
=
13
parser
.
model_tokenizer
=
mock_tokenizer
mock_rp
=
MagicMock
(
spec
=
MistralReasoningParser
)
mock_rp
.
start_token_id
=
999
mock_rp
.
extract_reasoning_streaming
.
return_value
=
DeltaMessage
(
reasoning
=
"thinking..."
)
mock_rp
.
is_reasoning_end_streaming
.
return_value
=
False
result
=
self
.
_call
(
parser
,
request_obj
,
reasoning_parser
=
mock_rp
,
reasoning_ended
=
False
,
prompt_is_reasoning_end
=
True
,
current_token_ids
=
[
999
,
1
,
2
],
)
mock_rp
.
extract_reasoning_streaming
.
assert_called_once
()
assert
result
==
MistralStreamingResult
(
delta_message
=
DeltaMessage
(
reasoning
=
"thinking..."
),
reasoning_ended
=
False
,
tools_called
=
False
,
current_text
=
"hello"
,
current_token_ids
=
[
999
,
1
,
2
],
)
def
test_non_pre_v15_prompt_reasoning_end
(
self
,
parser
:
MistralToolParser
,
request_obj
:
ChatCompletionRequest
)
->
None
:
mock_tokenizer
=
MagicMock
(
spec
=
MistralTokenizer
)
mock_tokenizer
.
version
=
15
parser
.
model_tokenizer
=
mock_tokenizer
mock_rp
=
MagicMock
(
spec
=
MistralReasoningParser
)
mock_rp
.
start_token_id
=
999
content_delta
=
DeltaMessage
(
content
=
"after reasoning"
)
with
patch
.
object
(
parser
,
"extract_tool_calls_streaming"
,
return_value
=
content_delta
):
result
=
self
.
_call
(
parser
,
request_obj
,
reasoning_parser
=
mock_rp
,
reasoning_ended
=
False
,
prompt_is_reasoning_end
=
True
,
current_token_ids
=
[
999
,
1
,
2
],
output_token_ids
=
[
10
,
20
,
30
],
)
mock_rp
.
extract_reasoning_streaming
.
assert_not_called
()
assert
result
==
MistralStreamingResult
(
delta_message
=
content_delta
,
reasoning_ended
=
True
,
tools_called
=
False
,
current_text
=
"hello"
,
current_token_ids
=
[
10
,
20
,
30
],
)
def
test_reasoning_end_transition_with_content
(
self
,
parser
:
MistralToolParser
,
request_obj
:
ChatCompletionRequest
)
->
None
:
"""When reasoning ends and the delta has content, that content is
cleared from delta_message and used as current_text for tool parsing."""
mock_rp
=
MagicMock
()
mock_rp
.
start_token_id
=
999
mock_rp
.
extract_reasoning_streaming
.
return_value
=
DeltaMessage
(
reasoning
=
"think"
,
content
=
"leftover"
)
mock_rp
.
is_reasoning_end_streaming
.
return_value
=
True
mock_rp
.
extract_content_ids
.
return_value
=
[
50
,
51
]
content_delta
=
DeltaMessage
(
content
=
"leftover"
)
with
patch
.
object
(
parser
,
"extract_tool_calls_streaming"
,
return_value
=
content_delta
)
as
mock_extract
:
result
=
self
.
_call
(
parser
,
request_obj
,
reasoning_parser
=
mock_rp
,
reasoning_ended
=
False
,
current_token_ids
=
[
999
,
1
,
2
],
output_token_ids
=
[
10
,
20
,
30
],
)
mock_rp
.
extract_content_ids
.
assert_called_once_with
([
10
,
20
,
30
])
_
,
call_kwargs
=
mock_extract
.
call_args
assert
call_kwargs
[
"previous_text"
]
==
""
assert
call_kwargs
[
"previous_token_ids"
]
==
[]
assert
call_kwargs
[
"delta_text"
]
==
"leftover"
assert
call_kwargs
[
"current_token_ids"
]
==
[
50
,
51
]
assert
result
==
MistralStreamingResult
(
delta_message
=
content_delta
,
reasoning_ended
=
True
,
tools_called
=
False
,
current_text
=
"leftover"
,
current_token_ids
=
[
50
,
51
],
)
def
test_reasoning_end_transition_without_content
(
self
,
parser
:
MistralToolParser
,
request_obj
:
ChatCompletionRequest
)
->
None
:
"""When reasoning ends but the delta has no content, current_text
is set to empty string."""
mock_rp
=
MagicMock
()
mock_rp
.
start_token_id
=
999
mock_rp
.
extract_reasoning_streaming
.
return_value
=
DeltaMessage
(
reasoning
=
"think"
)
mock_rp
.
is_reasoning_end_streaming
.
return_value
=
True
mock_rp
.
extract_content_ids
.
return_value
=
[
50
,
51
]
empty_delta
=
DeltaMessage
(
content
=
""
)
with
patch
.
object
(
parser
,
"extract_tool_calls_streaming"
,
return_value
=
empty_delta
)
as
mock_extract
:
result
=
self
.
_call
(
parser
,
request_obj
,
reasoning_parser
=
mock_rp
,
reasoning_ended
=
False
,
current_token_ids
=
[
999
,
1
,
2
],
output_token_ids
=
[
10
,
20
,
30
],
)
_
,
call_kwargs
=
mock_extract
.
call_args
assert
call_kwargs
[
"delta_text"
]
==
""
assert
call_kwargs
[
"current_token_ids"
]
==
[
50
,
51
]
assert
result
==
MistralStreamingResult
(
delta_message
=
empty_delta
,
reasoning_ended
=
True
,
tools_called
=
False
,
current_text
=
""
,
current_token_ids
=
[
50
,
51
],
)
tests/tool_use/mistral/test_mistral_tool_calls.py
View file @
c0722f22
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
json
from
dataclasses
import
dataclass
,
field
import
openai
import
pytest
from
tests.tool_use.utils
import
MESSAGES_ASKING_FOR_TOOLS
,
WEATHER_TOOL
from
tests.tool_use.utils
import
(
MESSAGES_ASKING_FOR_PARALLEL_TOOLS
,
MESSAGES_ASKING_FOR_TOOLS
,
MESSAGES_WITH_TOOL_RESPONSE
,
MESSAGES_WITHOUT_TOOLS
,
SEARCH_TOOL
,
SEED
,
WEATHER_TOOL
,
ensure_system_prompt
,
)
from
.utils
import
ServerConfig
def
_requires_tool_parser
(
server_config
:
ServerConfig
)
->
None
:
r
"""Skip test if server was not started with --tool-call-parser."""
if
"--tool-call-parser"
not
in
server_config
.
get
(
"arguments"
,
[]):
pytest
.
skip
(
f
"Skipping:
{
server_config
[
'model'
]
}
not configured with --tool-call-parser"
)
def
_is_pre_v11
(
server_config
:
ServerConfig
)
->
bool
:
r
"""Pre-v11 Mistral models lack grammar-based tool call enforcement."""
return
"7B"
in
server_config
.
get
(
"model"
,
""
)
@
dataclass
class
StreamedToolCallResult
:
r
"""Accumulated result from streaming a single tool call."""
function_name
:
str
|
None
=
None
function_args_str
:
str
=
""
tool_call_id
:
str
|
None
=
None
role_name
:
str
|
None
=
None
finish_reason_count
:
int
=
0
finish_reason
:
str
|
None
=
None
async
def
_collect_streamed_tool_call
(
stream
:
openai
.
AsyncStream
,
*
,
expected_finish_reason
:
str
=
"tool_calls"
,
)
->
StreamedToolCallResult
:
result
=
StreamedToolCallResult
()
async
for
chunk
in
stream
:
if
chunk
.
choices
[
0
].
finish_reason
:
result
.
finish_reason_count
+=
1
result
.
finish_reason
=
chunk
.
choices
[
0
].
finish_reason
assert
chunk
.
choices
[
0
].
finish_reason
==
expected_finish_reason
if
chunk
.
choices
[
0
].
delta
.
role
:
assert
not
result
.
role_name
or
result
.
role_name
==
"assistant"
result
.
role_name
=
"assistant"
streamed_tool_calls
=
chunk
.
choices
[
0
].
delta
.
tool_calls
if
streamed_tool_calls
and
len
(
streamed_tool_calls
)
>
0
:
assert
len
(
streamed_tool_calls
)
==
1
tool_call
=
streamed_tool_calls
[
0
]
if
tool_call
.
id
:
assert
not
result
.
tool_call_id
result
.
tool_call_id
=
tool_call
.
id
if
tool_call
.
function
:
if
tool_call
.
function
.
name
:
assert
result
.
function_name
is
None
result
.
function_name
=
tool_call
.
function
.
name
if
tool_call
.
function
.
arguments
:
result
.
function_args_str
+=
tool_call
.
function
.
arguments
return
result
@
dataclass
class
StreamedContentResult
:
r
"""Accumulated result from streaming a content-only response."""
chunks
:
list
[
str
]
=
field
(
default_factory
=
list
)
finish_reason_count
:
int
=
0
finish_reason
:
str
|
None
=
None
role_sent
:
bool
=
False
async
def
_collect_streamed_content
(
stream
:
openai
.
AsyncStream
,
*
,
expected_finish_reason
:
str
|
None
=
None
,
no_tool_calls
:
bool
=
True
,
)
->
StreamedContentResult
:
r
"""Consume a streaming response and collect text content."""
result
=
StreamedContentResult
()
async
for
chunk
in
stream
:
delta
=
chunk
.
choices
[
0
].
delta
if
delta
.
role
:
assert
not
result
.
role_sent
assert
delta
.
role
==
"assistant"
result
.
role_sent
=
True
if
delta
.
content
:
result
.
chunks
.
append
(
delta
.
content
)
if
chunk
.
choices
[
0
].
finish_reason
is
not
None
:
result
.
finish_reason_count
+=
1
result
.
finish_reason
=
chunk
.
choices
[
0
].
finish_reason
if
expected_finish_reason
is
not
None
:
assert
result
.
finish_reason
==
expected_finish_reason
if
no_tool_calls
:
assert
not
delta
.
tool_calls
or
len
(
delta
.
tool_calls
)
==
0
return
result
@
dataclass
class
StreamedParallelToolCallResult
:
r
"""Accumulated result from streaming parallel tool calls."""
function_names
:
list
[
str
]
=
field
(
default_factory
=
list
)
function_args_strs
:
list
[
str
]
=
field
(
default_factory
=
list
)
tool_call_ids
:
list
[
str
]
=
field
(
default_factory
=
list
)
role_name
:
str
|
None
=
None
finish_reason_count
:
int
=
0
async
def
_collect_streamed_parallel_tool_calls
(
stream
:
openai
.
AsyncStream
,
)
->
StreamedParallelToolCallResult
:
r
"""Consume a streaming response and collect parallel tool calls."""
result
=
StreamedParallelToolCallResult
()
tool_call_idx
:
int
=
-
1
async
for
chunk
in
stream
:
if
chunk
.
choices
[
0
].
finish_reason
:
result
.
finish_reason_count
+=
1
assert
chunk
.
choices
[
0
].
finish_reason
==
"tool_calls"
if
chunk
.
choices
[
0
].
delta
.
role
:
assert
not
result
.
role_name
or
result
.
role_name
==
"assistant"
result
.
role_name
=
"assistant"
streamed_tool_calls
=
chunk
.
choices
[
0
].
delta
.
tool_calls
if
streamed_tool_calls
and
len
(
streamed_tool_calls
)
>
0
:
assert
len
(
streamed_tool_calls
)
==
1
tool_call
=
streamed_tool_calls
[
0
]
if
tool_call
.
index
!=
tool_call_idx
:
tool_call_idx
=
tool_call
.
index
result
.
function_args_strs
.
append
(
""
)
result
.
tool_call_ids
.
append
(
""
)
if
tool_call
.
id
:
result
.
tool_call_ids
[
tool_call
.
index
]
=
tool_call
.
id
if
tool_call
.
function
:
if
tool_call
.
function
.
name
:
result
.
function_names
.
append
(
tool_call
.
function
.
name
)
if
tool_call
.
function
.
arguments
:
result
.
function_args_strs
[
tool_call
.
index
]
+=
(
tool_call
.
function
.
arguments
)
return
result
# test: a tool_choice with mistral-tokenizer results in an ID of length 9
@
pytest
.
mark
.
asyncio
async
def
test_tool_call_with_tool_choice
(
client
:
openai
.
AsyncOpenAI
):
async
def
test_tool_call_with_tool_choice
(
client
:
openai
.
AsyncOpenAI
,
server_config
:
ServerConfig
)
->
None
:
_requires_tool_parser
(
server_config
)
models
=
await
client
.
models
.
list
()
model_name
:
str
=
models
.
data
[
0
].
id
chat_completion
=
await
client
.
chat
.
completions
.
create
(
messages
=
MESSAGES_ASKING_FOR_TOOLS
,
messages
=
ensure_system_prompt
(
MESSAGES_ASKING_FOR_TOOLS
,
server_config
),
temperature
=
0
,
max_completion_tokens
=
100
,
model
=
model_name
,
tools
=
[
WEATHER_TOOL
],
tool_choice
=
WEATHER_TOOL
,
logprobs
=
False
,
seed
=
SEED
,
)
choice
=
chat_completion
.
choices
[
0
]
...
...
@@ -28,3 +201,307 @@ async def test_tool_call_with_tool_choice(client: openai.AsyncOpenAI):
assert
choice
.
message
.
role
==
"assistant"
assert
choice
.
message
.
tool_calls
is
None
or
len
(
choice
.
message
.
tool_calls
)
==
1
assert
len
(
choice
.
message
.
tool_calls
[
0
].
id
)
==
9
# length of 9 for mistral
_NOT_SET
=
object
()
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"tools, tool_choice, streaming_id_len_pre_v11"
,
[
pytest
.
param
(
[
WEATHER_TOOL
,
SEARCH_TOOL
],
_NOT_SET
,
9
,
id
=
"auto"
,
),
pytest
.
param
(
[
WEATHER_TOOL
],
"required"
,
30
,
id
=
"required"
,
),
],
)
async
def
test_tool_call_auto_or_required
(
client
:
openai
.
AsyncOpenAI
,
server_config
:
ServerConfig
,
tools
:
list
,
tool_choice
:
object
,
streaming_id_len_pre_v11
:
int
,
)
->
None
:
_requires_tool_parser
(
server_config
)
models
=
await
client
.
models
.
list
()
model_name
:
str
=
models
.
data
[
0
].
id
create_kwargs
:
dict
=
{
"messages"
:
ensure_system_prompt
(
MESSAGES_ASKING_FOR_TOOLS
,
server_config
),
"temperature"
:
0
,
"max_completion_tokens"
:
100
,
"model"
:
model_name
,
"tools"
:
tools
,
"logprobs"
:
False
,
"seed"
:
SEED
,
}
if
tool_choice
is
not
_NOT_SET
:
create_kwargs
[
"tool_choice"
]
=
tool_choice
# --- non-streaming ---
chat_completion
=
await
client
.
chat
.
completions
.
create
(
**
create_kwargs
)
choice
=
chat_completion
.
choices
[
0
]
tool_calls
=
choice
.
message
.
tool_calls
assert
choice
.
finish_reason
==
"tool_calls"
assert
tool_calls
is
not
None
and
len
(
tool_calls
)
>=
1
assert
tool_calls
[
0
].
function
.
name
==
"get_current_weather"
parsed_arguments
=
json
.
loads
(
tool_calls
[
0
].
function
.
arguments
)
assert
"city"
in
parsed_arguments
assert
len
(
tool_calls
[
0
].
id
)
==
9
# --- streaming ---
stream
=
await
client
.
chat
.
completions
.
create
(
**
create_kwargs
,
stream
=
True
)
result
=
await
_collect_streamed_tool_call
(
stream
)
assert
result
.
finish_reason_count
==
1
assert
result
.
role_name
==
"assistant"
assert
result
.
function_name
==
"get_current_weather"
streamed_args
=
json
.
loads
(
result
.
function_args_str
)
assert
isinstance
(
result
.
tool_call_id
,
str
)
if
_is_pre_v11
(
server_config
):
assert
len
(
result
.
tool_call_id
)
==
streaming_id_len_pre_v11
else
:
assert
len
(
result
.
tool_call_id
)
==
9
assert
parsed_arguments
==
streamed_args
@
pytest
.
mark
.
asyncio
async
def
test_tool_call_none_with_tools
(
client
:
openai
.
AsyncOpenAI
,
server_config
:
ServerConfig
)
->
None
:
_requires_tool_parser
(
server_config
)
models
=
await
client
.
models
.
list
()
model_name
:
str
=
models
.
data
[
0
].
id
# --- non-streaming ---
chat_completion
=
await
client
.
chat
.
completions
.
create
(
messages
=
ensure_system_prompt
(
MESSAGES_ASKING_FOR_TOOLS
,
server_config
),
temperature
=
0
,
max_completion_tokens
=
100
,
model
=
model_name
,
tools
=
[
WEATHER_TOOL
],
tool_choice
=
"none"
,
logprobs
=
False
,
seed
=
SEED
,
)
choice
=
chat_completion
.
choices
[
0
]
assert
choice
.
finish_reason
!=
"tool_calls"
assert
choice
.
message
.
tool_calls
is
None
or
len
(
choice
.
message
.
tool_calls
)
==
0
assert
choice
.
message
.
content
is
not
None
# Without grammar enforcement, pre-v11 models may still emit [TOOL_CALLS]
if
not
_is_pre_v11
(
server_config
):
assert
"[TOOL_CALLS]"
not
in
choice
.
message
.
content
non_streaming_content
=
choice
.
message
.
content
# --- streaming ---
stream
=
await
client
.
chat
.
completions
.
create
(
messages
=
ensure_system_prompt
(
MESSAGES_ASKING_FOR_TOOLS
,
server_config
),
temperature
=
0
,
max_completion_tokens
=
100
,
model
=
model_name
,
tools
=
[
WEATHER_TOOL
],
tool_choice
=
"none"
,
logprobs
=
False
,
seed
=
SEED
,
stream
=
True
,
)
# Pre-v11 models lack grammar enforcement, so the model may still
# emit tool calls even with tool_choice="none".
pre_v11
=
_is_pre_v11
(
server_config
)
result
=
await
_collect_streamed_content
(
stream
,
no_tool_calls
=
not
pre_v11
)
assert
result
.
finish_reason_count
==
1
if
not
pre_v11
:
assert
result
.
finish_reason
!=
"tool_calls"
streamed_content
=
""
.
join
(
result
.
chunks
)
if
not
pre_v11
:
assert
"[TOOL_CALLS]"
not
in
streamed_content
assert
streamed_content
==
non_streaming_content
@
pytest
.
mark
.
asyncio
async
def
test_chat_without_tools
(
client
:
openai
.
AsyncOpenAI
,
server_config
:
ServerConfig
)
->
None
:
models
=
await
client
.
models
.
list
()
model_name
:
str
=
models
.
data
[
0
].
id
# --- non-streaming ---
chat_completion
=
await
client
.
chat
.
completions
.
create
(
messages
=
ensure_system_prompt
(
MESSAGES_WITHOUT_TOOLS
,
server_config
),
temperature
=
0
,
max_completion_tokens
=
150
,
model
=
model_name
,
logprobs
=
False
,
seed
=
SEED
,
)
choice
=
chat_completion
.
choices
[
0
]
output_text
=
choice
.
message
.
content
assert
output_text
is
not
None
and
len
(
output_text
)
>
0
assert
choice
.
finish_reason
!=
"tool_calls"
assert
choice
.
message
.
tool_calls
is
None
or
len
(
choice
.
message
.
tool_calls
)
==
0
# --- streaming ---
stream
=
await
client
.
chat
.
completions
.
create
(
messages
=
ensure_system_prompt
(
MESSAGES_WITHOUT_TOOLS
,
server_config
),
temperature
=
0
,
max_completion_tokens
=
150
,
model
=
model_name
,
logprobs
=
False
,
seed
=
SEED
,
stream
=
True
,
)
result
=
await
_collect_streamed_content
(
stream
,
expected_finish_reason
=
choice
.
finish_reason
)
assert
result
.
role_sent
assert
result
.
finish_reason_count
==
1
assert
len
(
result
.
chunks
)
assert
""
.
join
(
result
.
chunks
)
==
output_text
@
pytest
.
mark
.
asyncio
async
def
test_tool_call_with_results
(
client
:
openai
.
AsyncOpenAI
,
server_config
:
ServerConfig
)
->
None
:
_requires_tool_parser
(
server_config
)
models
=
await
client
.
models
.
list
()
model_name
:
str
=
models
.
data
[
0
].
id
# --- non-streaming ---
chat_completion
=
await
client
.
chat
.
completions
.
create
(
messages
=
ensure_system_prompt
(
MESSAGES_WITH_TOOL_RESPONSE
,
server_config
),
temperature
=
0
,
max_completion_tokens
=
100
,
model
=
model_name
,
tools
=
[
WEATHER_TOOL
,
SEARCH_TOOL
],
logprobs
=
False
,
seed
=
SEED
,
)
choice
=
chat_completion
.
choices
[
0
]
assert
choice
.
finish_reason
!=
"tool_calls"
assert
choice
.
message
.
tool_calls
is
None
or
len
(
choice
.
message
.
tool_calls
)
==
0
assert
choice
.
message
.
content
is
not
None
assert
"98"
in
choice
.
message
.
content
# --- streaming ---
stream
=
await
client
.
chat
.
completions
.
create
(
messages
=
ensure_system_prompt
(
MESSAGES_WITH_TOOL_RESPONSE
,
server_config
),
temperature
=
0
,
max_completion_tokens
=
100
,
model
=
model_name
,
tools
=
[
WEATHER_TOOL
,
SEARCH_TOOL
],
logprobs
=
False
,
seed
=
SEED
,
stream
=
True
,
)
result
=
await
_collect_streamed_content
(
stream
,
expected_finish_reason
=
choice
.
finish_reason
)
assert
result
.
role_sent
assert
result
.
finish_reason_count
==
1
assert
len
(
result
.
chunks
)
assert
""
.
join
(
result
.
chunks
)
==
choice
.
message
.
content
def
_requires_parallel
(
server_config
:
ServerConfig
)
->
None
:
r
"""Skip test if the model does not support parallel tool calls."""
if
not
server_config
.
get
(
"supports_parallel"
):
pytest
.
skip
(
f
"Skipping:
{
server_config
[
'model'
]
}
does not support parallel tool calls"
)
@
pytest
.
mark
.
asyncio
async
def
test_tool_call_parallel
(
client
:
openai
.
AsyncOpenAI
,
server_config
:
ServerConfig
)
->
None
:
_requires_tool_parser
(
server_config
)
_requires_parallel
(
server_config
)
models
=
await
client
.
models
.
list
()
model_name
:
str
=
models
.
data
[
0
].
id
# --- non-streaming ---
chat_completion
=
await
client
.
chat
.
completions
.
create
(
messages
=
ensure_system_prompt
(
MESSAGES_ASKING_FOR_PARALLEL_TOOLS
,
server_config
),
temperature
=
0
,
max_completion_tokens
=
200
,
model
=
model_name
,
tools
=
[
WEATHER_TOOL
],
logprobs
=
False
,
seed
=
SEED
,
)
choice
=
chat_completion
.
choices
[
0
]
tool_calls
=
choice
.
message
.
tool_calls
assert
choice
.
finish_reason
==
"tool_calls"
assert
tool_calls
is
not
None
and
len
(
tool_calls
)
>=
2
for
tc
in
tool_calls
:
assert
tc
.
type
==
"function"
assert
tc
.
function
.
name
==
"get_current_weather"
assert
isinstance
(
tc
.
function
.
arguments
,
str
)
parsed
=
json
.
loads
(
tc
.
function
.
arguments
)
assert
"city"
in
parsed
assert
len
(
tc
.
id
)
==
9
non_streaming_tool_calls
=
tool_calls
# --- streaming ---
stream
=
await
client
.
chat
.
completions
.
create
(
messages
=
ensure_system_prompt
(
MESSAGES_ASKING_FOR_PARALLEL_TOOLS
,
server_config
),
temperature
=
0
,
max_completion_tokens
=
200
,
model
=
model_name
,
tools
=
[
WEATHER_TOOL
],
logprobs
=
False
,
seed
=
SEED
,
stream
=
True
,
)
result
=
await
_collect_streamed_parallel_tool_calls
(
stream
)
assert
result
.
finish_reason_count
==
1
assert
result
.
role_name
==
"assistant"
assert
len
(
result
.
function_names
)
>=
2
assert
all
(
name
==
"get_current_weather"
for
name
in
result
.
function_names
)
assert
len
(
result
.
tool_call_ids
)
>=
2
assert
all
(
isinstance
(
tid
,
str
)
and
len
(
tid
)
==
9
for
tid
in
result
.
tool_call_ids
)
for
args_str
in
result
.
function_args_strs
:
streamed_args
=
json
.
loads
(
args_str
)
assert
"city"
in
streamed_args
assert
len
(
result
.
function_names
)
==
len
(
non_streaming_tool_calls
)
tests/tool_use/mistral/utils.py
View file @
c0722f22
...
...
@@ -2,16 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing_extensions
import
TypedDict
class
ServerConfig
(
TypedDict
,
total
=
False
):
model
:
str
arguments
:
list
[
str
]
system_prompt
:
str
|
None
supports_parallel
:
bool
|
None
supports_rocm
:
bool
|
None
from
tests.tool_use.utils
import
ServerConfig
ARGS
:
list
[
str
]
=
[
"--max-model-len"
,
"1024"
]
...
...
@@ -21,6 +12,11 @@ CONFIGS: dict[str, ServerConfig] = {
"arguments"
:
[
"--tokenizer-mode"
,
"mistral"
,
"--tool-call-parser"
,
"mistral"
,
"--enable-auto-tool-choice"
,
"--enforce-eager"
,
"--no-enable-prefix-caching"
,
'--ignore-patterns="consolidated.safetensors"'
,
],
"system_prompt"
:
"You are a helpful assistant with access to tools. If a tool"
...
...
@@ -29,4 +25,22 @@ CONFIGS: dict[str, ServerConfig] = {
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
"to the user's question - just respond to it normally."
,
},
"ministral-3b"
:
{
"model"
:
"mistralai/Ministral-3-3B-Instruct-2512"
,
"arguments"
:
[
"--tokenizer-mode"
,
"mistral"
,
"--tool-call-parser"
,
"mistral"
,
"--enable-auto-tool-choice"
,
"--enforce-eager"
,
"--no-enable-prefix-caching"
,
],
"system_prompt"
:
"You are a helpful assistant with access to tools. If a tool"
" that you have would be helpful to answer a user query, "
"call the tool. Otherwise, answer the user's query directly "
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
"to the user's question - just respond to it normally."
,
"supports_parallel"
:
True
,
},
}
vllm/entrypoints/openai/chat_completion/protocol.py
View file @
c0722f22
...
...
@@ -11,7 +11,7 @@ from openai.types.chat.chat_completion_audio import (
ChatCompletionAudio
as
OpenAIChatCompletionAudio
,
)
from
openai.types.chat.chat_completion_message
import
Annotation
as
OpenAIAnnotation
from
pydantic
import
Field
,
model_validator
from
pydantic
import
Field
,
PrivateAttr
,
model_validator
from
vllm.config
import
ModelConfig
from
vllm.config.utils
import
replace
...
...
@@ -398,6 +398,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
msg
[
"tool_calls"
]
=
list
(
tool_calls
)
return
self
_grammar_from_tool_parser
:
bool
=
PrivateAttr
(
default
=
False
)
"""CAUTION: Should only be set by ``ToolParser.adjust_request``."""
def
build_chat_params
(
self
,
default_template
:
str
|
None
,
...
...
@@ -822,13 +825,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
return
data
@
model_validator
(
mode
=
"before"
)
@
classmethod
def
set_include_reasoning_for_none_effort
(
cls
,
data
:
Any
)
->
Any
:
if
data
.
get
(
"reasoning_effort"
)
==
"none"
:
data
[
"include_reasoning"
]
=
False
return
data
class
BatchChatCompletionRequest
(
OpenAIBaseModel
):
"""Request model for the /v1/chat/completions/batch endpoint.
...
...
vllm/entrypoints/openai/chat_completion/serving.py
View file @
c0722f22
...
...
@@ -73,7 +73,10 @@ from vllm.reasoning import ReasoningParser
from
vllm.renderers
import
ChatParams
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tool_parsers.mistral_tool_parser
import
MistralToolCall
from
vllm.tool_parsers.mistral_tool_parser
import
(
MistralToolCall
,
MistralToolParser
,
)
from
vllm.tool_parsers.utils
import
partial_json_loads
from
vllm.utils.collection_utils
import
as_list
from
vllm.utils.mistral
import
is_mistral_tokenizer
...
...
@@ -140,6 +143,12 @@ class OpenAIServingChat(OpenAIServing):
enable_auto_tools
=
enable_auto_tools
,
model_name
=
self
.
model_config
.
model
,
)
_is_mistral_tool_parser
=
self
.
tool_parser
is
not
None
and
issubclass
(
self
.
tool_parser
,
MistralToolParser
)
if
_is_mistral_tool_parser
and
self
.
reasoning_parser_cls
is
not
None
:
MistralToolParser
.
model_can_reason
=
True
self
.
exclude_tools_when_tool_choice_none
=
exclude_tools_when_tool_choice_none
self
.
enable_prompt_tokens_details
=
enable_prompt_tokens_details
...
...
@@ -310,6 +319,11 @@ class OpenAIServingChat(OpenAIServing):
else
:
if
not
request
.
include_reasoning
:
reasoning_ended
=
True
elif
request
.
_grammar_from_tool_parser
:
# The Mistral grammar already includes an optional
# `think?` rule that handles both reasoning and
# non-reasoning outputs.
reasoning_ended
=
True
elif
reasoning_parser
:
reasoning_ended
=
reasoning_parser
.
is_reasoning_end
(
prompt_token_ids
or
[]
...
...
@@ -530,6 +544,8 @@ class OpenAIServingChat(OpenAIServing):
harmony_tools_streamed
=
[
False
]
*
num_choices
tools_streamed
=
[
False
]
*
num_choices
is_mistral_grammar_path
=
request
.
_grammar_from_tool_parser
if
isinstance
(
request
.
tool_choice
,
ChatCompletionNamedToolChoiceParam
):
tool_choice_function_name
=
request
.
tool_choice
.
function
.
name
else
:
...
...
@@ -553,7 +569,7 @@ class OpenAIServingChat(OpenAIServing):
# Only one of these will be used, thus previous_texts and
# all_previous_token_ids will not be used twice in the same iteration.
if
tool_choice_auto
or
reasoning_parser
:
if
is_mistral_grammar_path
or
tool_choice_auto
or
reasoning_parser
:
# These are only required in "auto" tool choice case
all_previous_token_ids
=
[[]
for
_
in
range
(
num_choices
)]
reasoning_end_arr
=
[
False
]
*
num_choices
...
...
@@ -748,7 +764,7 @@ class OpenAIServingChat(OpenAIServing):
delta_message
:
DeltaMessage
|
None
# just update previous_texts and previous_token_ids
if
tool_choice_auto
or
reasoning_parser
:
if
is_mistral_grammar_path
or
tool_choice_auto
or
reasoning_parser
:
assert
previous_texts
is
not
None
assert
all_previous_token_ids
is
not
None
previous_text
=
previous_texts
[
i
]
...
...
@@ -772,6 +788,30 @@ class OpenAIServingChat(OpenAIServing):
)
)
harmony_tools_streamed
[
i
]
|=
tools_streamed_flag
# Mistral grammar path: combined reasoning + tool streaming
elif
is_mistral_grammar_path
:
assert
tool_parser
is
not
None
assert
isinstance
(
tool_parser
,
MistralToolParser
)
assert
reasoning_end_arr
is
not
None
output_token_ids
=
as_list
(
output
.
token_ids
)
result
=
tool_parser
.
extract_maybe_reasoning_and_tool_streaming
(
reasoning_parser
=
reasoning_parser
,
previous_text
=
previous_text
,
current_text
=
current_text
,
delta_text
=
delta_text
,
previous_token_ids
=
previous_token_ids
,
current_token_ids
=
current_token_ids
,
output_token_ids
=
output_token_ids
,
reasoning_ended
=
reasoning_end_arr
[
i
],
prompt_is_reasoning_end
=
(
prompt_is_reasoning_end_arr
[
i
]),
request
=
request
,
)
delta_message
=
result
.
delta_message
reasoning_end_arr
[
i
]
=
result
.
reasoning_ended
current_text
=
result
.
current_text
current_token_ids
=
result
.
current_token_ids
if
result
.
tools_called
:
tools_streamed
[
i
]
=
True
# handle streaming deltas for tools with named tool_choice
elif
tool_choice_function_name
:
# When encountering think end id in prompt_token_ids
...
...
@@ -925,7 +965,9 @@ class OpenAIServingChat(OpenAIServing):
delta_message
=
DeltaMessage
(
content
=
delta_text
)
# update the previous values for the next iteration
if
(
tool_choice_auto
or
reasoning_parser
)
and
not
self
.
use_harmony
:
if
(
is_mistral_grammar_path
or
tool_choice_auto
or
reasoning_parser
)
and
not
self
.
use_harmony
:
assert
previous_texts
is
not
None
assert
all_previous_token_ids
is
not
None
previous_texts
[
i
]
=
current_text
...
...
@@ -1312,7 +1354,24 @@ class OpenAIServingChat(OpenAIServing):
tool_call_class
=
(
MistralToolCall
if
is_mistral_tokenizer
(
tokenizer
)
else
ToolCall
)
if
(
not
self
.
enable_auto_tools
or
not
self
.
tool_parser
)
and
(
use_mistral_tool_parser
=
request
.
_grammar_from_tool_parser
if
use_mistral_tool_parser
:
tool_call_items
=
MistralToolParser
.
build_non_streaming_tool_calls
(
tool_calls
)
if
tool_call_items
:
auto_tools_called
=
(
request
.
tool_choice
is
None
or
request
.
tool_choice
==
"auto"
)
message
=
ChatMessage
(
role
=
role
,
reasoning
=
reasoning
,
content
=
content
,
tool_calls
=
tool_call_items
,
)
elif
(
not
self
.
enable_auto_tools
or
not
self
.
tool_parser
)
and
(
not
isinstance
(
request
.
tool_choice
,
ChatCompletionNamedToolChoiceParam
)
and
request
.
tool_choice
!=
"required"
):
...
...
vllm/entrypoints/openai/engine/serving.py
View file @
c0722f22
...
...
@@ -65,6 +65,7 @@ from vllm.renderers.inputs.preprocess import (
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tool_parsers
import
ToolParser
from
vllm.tool_parsers.mistral_tool_parser
import
MistralToolParser
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
...
...
@@ -610,16 +611,31 @@ class OpenAIServing:
tool_parser_cls
:
type
[
ToolParser
]
|
None
,
content
:
str
|
None
=
None
,
)
->
tuple
[
list
[
FunctionCall
]
|
None
,
str
|
None
]:
# When the Mistral grammar factory injected structured outputs,
# let the parser handle the output.
use_mistral_tool_parser
=
(
isinstance
(
request
,
ChatCompletionRequest
)
and
tool_parser_cls
is
not
None
and
issubclass
(
tool_parser_cls
,
MistralToolParser
)
and
request
.
_grammar_from_tool_parser
)
function_calls
=
list
[
FunctionCall
]()
if
request
.
tool_choice
and
isinstance
(
request
.
tool_choice
,
ToolChoiceFunction
):
if
(
not
use_mistral_tool_parser
and
request
.
tool_choice
and
isinstance
(
request
.
tool_choice
,
ToolChoiceFunction
)
):
assert
content
is
not
None
# Forced Function Call
function_calls
.
append
(
FunctionCall
(
name
=
request
.
tool_choice
.
name
,
arguments
=
content
)
)
content
=
None
# Clear content since tool is called.
elif
request
.
tool_choice
and
isinstance
(
request
.
tool_choice
,
ChatCompletionNamedToolChoiceParam
elif
(
not
use_mistral_tool_parser
and
request
.
tool_choice
and
isinstance
(
request
.
tool_choice
,
ChatCompletionNamedToolChoiceParam
)
):
assert
content
is
not
None
# Forced Function Call
...
...
@@ -627,7 +643,7 @@ class OpenAIServing:
FunctionCall
(
name
=
request
.
tool_choice
.
function
.
name
,
arguments
=
content
)
)
content
=
None
# Clear content since tool is called.
elif
request
.
tool_choice
==
"required"
:
elif
not
use_mistral_tool_parser
and
request
.
tool_choice
==
"required"
:
tool_calls
=
[]
with
contextlib
.
suppress
(
ValidationError
):
content
=
content
or
""
...
...
@@ -642,10 +658,12 @@ class OpenAIServing:
)
)
content
=
None
# Clear content since tool is called.
elif
(
tool_parser_cls
and
enable_auto_tools
elif
tool_parser_cls
and
(
use_mistral_tool_parser
or
(
enable_auto_tools
and
(
request
.
tool_choice
==
"auto"
or
request
.
tool_choice
is
None
)
)
):
if
tokenizer
is
None
:
raise
ValueError
(
...
...
vllm/entrypoints/serve/render/serving.py
View file @
c0722f22
...
...
@@ -53,6 +53,7 @@ from vllm.renderers.inputs.preprocess import (
prompt_to_seq
,
)
from
vllm.tool_parsers
import
ToolParser
from
vllm.tool_parsers.mistral_tool_parser
import
MistralToolParser
from
vllm.utils
import
random_uuid
from
vllm.utils.mistral
import
is_mistral_tokenizer
from
vllm.utils.mistral
import
mt
as
_mt
...
...
@@ -555,9 +556,19 @@ class OpenAIServingRender:
# tool parsing is done only if a tool_parser has been set and if
# tool_choice is not "none" (if tool_choice is "none" but a tool_parser
# is set, we want to prevent parsing a tool_call hallucinated by the LLM
#
# Exception: Mistral grammar-capable tokenizers always call
# adjust_request — even for tool_choice="none" — so that the grammar
# factory can prevent special-token leakage.
if
tool_parser
is
not
None
:
tool_choice
=
getattr
(
request
,
"tool_choice"
,
"none"
)
if
tool_choice
!=
"none"
:
tokenizer
=
renderer
.
get_tokenizer
()
is_mistral_grammar_eligible
=
(
issubclass
(
tool_parser
,
MistralToolParser
)
and
is_mistral_tokenizer
(
tokenizer
)
and
tokenizer
.
supports_grammar
)
if
tool_choice
!=
"none"
or
is_mistral_grammar_eligible
:
if
not
isinstance
(
request
,
ChatCompletionRequest
|
ResponsesRequest
):
msg
=
(
"Tool usage is only supported "
...
...
@@ -565,7 +576,6 @@ class OpenAIServingRender:
f
"but got
{
type
(
request
).
__name__
}
"
)
raise
NotImplementedError
(
msg
)
tokenizer
=
renderer
.
get_tokenizer
()
request
=
tool_parser
(
tokenizer
,
request
.
tools
).
adjust_request
(
request
=
request
)
...
...
vllm/sampling_params.py
View file @
c0722f22
...
...
@@ -157,6 +157,10 @@ def _is_non_tekken_mistral(tokenizer: TokenizerLike) -> bool:
return
is_mistral_tokenizer
(
tokenizer
)
and
not
tokenizer
.
is_tekken
def
_get_llg_tokenizer
(
tokenizer
:
TokenizerLike
)
->
Any
:
return
tokenizer
.
llg_tokenizer
if
is_mistral_tokenizer
(
tokenizer
)
else
None
class
SamplingParams
(
PydanticMsgspecMixin
,
msgspec
.
Struct
,
...
...
@@ -816,7 +820,10 @@ class SamplingParams(
# allows <|special_token|> and similar, see
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars.
validate_guidance_grammar
(
self
,
tokenizer
=
None
)
validate_guidance_grammar
(
self
,
tokenizer
=
_get_llg_tokenizer
(
tokenizer
),
)
elif
backend
==
"outlines"
:
# outlines backend
validate_structured_output_request_outlines
(
self
)
...
...
@@ -862,7 +869,10 @@ class SamplingParams(
self
.
structured_outputs
.
_backend
=
"outlines"
else
:
# Fall back to guidance by default.
validate_guidance_grammar
(
self
,
tokenizer
=
None
)
validate_guidance_grammar
(
self
,
tokenizer
=
_get_llg_tokenizer
(
tokenizer
),
)
self
.
structured_outputs
.
_backend
=
"guidance"
# Remember that this backend was set automatically
self
.
structured_outputs
.
_backend_was_auto
=
True
...
...
vllm/tokenizers/mistral.py
View file @
c0722f22
...
...
@@ -54,6 +54,50 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
def
_pop_unallowed_keys_and_warn
(
dictionary
:
dict
[
str
,
Any
],
allowed_keys
:
set
[
str
],
err_dict_name
:
str
):
keys
=
list
(
dictionary
.
keys
())
for
key
in
keys
:
if
key
not
in
allowed_keys
:
dictionary
.
pop
(
key
)
logger
.
warning_once
(
f
"'
{
key
=
}
' is not supported by mistral-common "
f
"for
{
err_dict_name
}
. It has been popped from the "
"object."
)
# TODO(juliendenize): remove this once OpenAI API is better supported by
# `mistral-common`.
def
adapt_inplace_to_mistral_tool
(
tool
:
dict
[
str
,
Any
],
)
->
dict
[
str
,
Any
]:
tools_fields
=
set
(
Tool
.
model_fields
.
keys
())
function_fields
=
set
(
Function
.
model_fields
.
keys
())
# The Mistral client, in comparison to the OpenAI client, requires the
# "parameters" dict and the "description" string to be present
# even if they are empty.
if
function
:
=
tool
.
get
(
"function"
):
if
function
.
get
(
"parameters"
)
is
None
:
function
[
"parameters"
]
=
{}
if
function
.
get
(
"description"
)
is
None
:
function
[
"description"
]
=
""
_pop_unallowed_keys_and_warn
(
dictionary
=
function
,
allowed_keys
=
function_fields
,
err_dict_name
=
"function"
,
)
_pop_unallowed_keys_and_warn
(
dictionary
=
tool
,
allowed_keys
=
tools_fields
,
err_dict_name
=
"tools"
)
return
tool
def
maybe_serialize_tool_calls
(
request
:
"MistralChatCompletionRequest"
):
# SEE: https://github.com/vllm-project/vllm/pull/9951
# Credits go to: @gcalmettes
...
...
@@ -159,44 +203,11 @@ def _prepare_apply_chat_template_tools_and_messages(
# Remove reasoning as unsupported by Mistral
_
=
message
.
pop
(
"reasoning"
,
None
)
# type: ignore
# The Mistral client, in comparison to the OpenAI client, requires the
# "parameters" dict and the "description" string to be present
# even if they are empty.
if
tools
:
for
function
in
[
tool
[
"function"
]
for
tool
in
tools
if
tool
[
"type"
]
==
"function"
]:
if
function
.
get
(
"parameters"
)
is
None
:
function
[
"parameters"
]
=
{}
if
function
.
get
(
"description"
)
is
None
:
function
[
"description"
]
=
""
# We filter not supported arguments to avoid throwing an error.
# TODO(juliendenize): remove this once OpenAI API is better supported by
# `mistral-common`.
tools_fields
=
set
(
Tool
.
model_fields
.
keys
())
function_fields
=
set
(
Function
.
model_fields
.
keys
())
for
tool
in
tools
:
tool_keys
=
list
(
tool
.
keys
())
for
tool_key
in
tool_keys
:
if
tool_key
not
in
tools_fields
:
tool
.
pop
(
tool_key
)
logger
.
warning_once
(
f
"'
{
tool_key
}
' is not supported by mistral-common for tools. "
"It has been popped from the tool definition."
)
if
tool
[
"type"
]
==
"function"
:
function_keys
=
list
(
tool
[
"function"
].
keys
())
for
function_key
in
function_keys
:
if
function_key
not
in
function_fields
:
tool
[
"function"
].
pop
(
function_key
)
logger
.
warning_once
(
f
"'
{
function_key
}
' is not supported by mistral-common "
"for function tools. It has been popped from the "
"function definition."
tools
=
(
[
adapt_inplace_to_mistral_tool
(
tool
=
tool
)
for
tool
in
tools
]
if
tools
is
not
None
else
None
)
else
:
raise
ValueError
(
"mistral-common only supports function tools."
)
return
messages
,
tools
...
...
vllm/tool_parsers/mistral_tool_parser.py
View file @
c0722f22
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
import
json
from
collections.abc
import
Sequence
from
dataclasses
import
dataclass
from
enum
import
Enum
,
auto
from
random
import
choices
from
string
import
ascii_letters
,
digits
from
typing
import
Any
from
typing
import
TYPE_CHECKING
,
Any
import
ijson
import
regex
as
re
...
...
@@ -37,14 +40,19 @@ from vllm.entrypoints.openai.engine.protocol import (
)
from
vllm.entrypoints.openai.responses.protocol
import
ResponsesRequest
from
vllm.logger
import
init_logger
from
vllm.reasoning.mistral_reasoning_parser
import
MistralReasoningParser
from
vllm.sampling_params
import
StructuredOutputsParams
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tokenizers.mistral
import
MistralTokenizer
,
adapt_inplace_to_mistral_tool
from
vllm.tool_parsers.abstract_tool_parser
import
(
Tool
,
ToolParser
,
)
from
vllm.utils.mistral
import
is_mistral_tokenizer
if
TYPE_CHECKING
:
from
vllm.reasoning
import
ReasoningParser
logger
=
init_logger
(
__name__
)
ALPHANUMERIC
=
ascii_letters
+
digits
...
...
@@ -86,13 +94,28 @@ def _is_pre_v11_tokeniser(model_tokenizer: TokenizerLike) -> bool:
return
not
(
is_mistral_tokenizer
(
model_tokenizer
)
and
model_tokenizer
.
version
>=
11
)
class
MistralToolParser
(
ToolParser
):
@
dataclass
class
MistralStreamingResult
:
r
"""Encapsulates the mutable state returned from
`MistralToolParser.extract_maybe_reasoning_and_tool_streaming`.
"""
Tool call parser for Mistral 7B Instruct v0.3, intended for use with
- [`mistral_common`](https://github.com/mistralai/mistral-common/)
- the examples/tool_chat_template_mistral.jinja template.
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
delta_message
:
DeltaMessage
|
None
reasoning_ended
:
bool
tools_called
:
bool
current_text
:
str
current_token_ids
:
list
[
int
]
class
MistralToolParser
(
ToolParser
):
r
"""Tool call parser for Mistral models, intended for use with either:
- `mistral_common <https://github.com/mistralai/mistral-common/>`_
(recommended)
- the `examples/tool_chat_template_mistral.jinja` template.
Used when `--enable-auto-tool-choice --tool-call-parser mistral` are all
set.
"""
# Used to generate correct grammar in `adjust_request`
...
...
@@ -210,9 +233,11 @@ class MistralToolParser(ToolParser):
reasoning
=
self
.
model_can_reason
)
tools
=
(
mistral_
tools
=
(
[
MistralTool
.
from_openai
(
openai_tool
=
tool
.
model_dump
())
MistralTool
.
model_validate
(
adapt_inplace_to_mistral_tool
(
tool
.
model_dump
())
)
for
tool
in
request
.
tools
]
if
request
.
tools
is
not
None
...
...
@@ -244,15 +269,158 @@ class MistralToolParser(ToolParser):
lark_grammar
=
grammar_factory
.
get_lark_from_jinja
(
template
=
template
,
mode
=
tool_choice
,
tools
=
tools
,
tools
=
mistral_
tools
,
json_schema
=
json_schema
,
parallel_tool_calls
=
request
.
parallel_tool_calls
,
json_only
=
False
,
)
request
.
structured_outputs
=
StructuredOutputsParams
(
grammar
=
lark_grammar
)
request
.
_grammar_from_tool_parser
=
True
return
request
def
extract_maybe_reasoning_and_tool_streaming
(
self
,
*
,
reasoning_parser
:
ReasoningParser
|
None
,
previous_text
:
str
,
current_text
:
str
,
delta_text
:
str
,
previous_token_ids
:
list
[
int
],
current_token_ids
:
list
[
int
],
output_token_ids
:
Sequence
[
int
],
reasoning_ended
:
bool
,
prompt_is_reasoning_end
:
bool
|
None
,
request
:
ChatCompletionRequest
,
)
->
MistralStreamingResult
:
r
"""Streaming extraction with reasoning followed by tool-call parsing.
This method encapsulates the combined reasoning extraction and
tool-call streaming logic so that the serving layer only needs a
thin routing branch.
The flow is:
1. If a *reasoning_parser* is present and reasoning has **not** ended,
extract reasoning tokens. Pre-v15 models may have pre-filled
`[THINK]...[/THINK]` in system prompts, so we skip the
prompt-level reasoning-end check for those.
2. Once reasoning ends (or if there is no reasoning parser), delegate
to `extract_tool_calls_streaming` and track whether tools were
called.
Args:
reasoning_parser: Optional reasoning parser instance.
previous_text: Accumulated text from prior chunks.
current_text: Full accumulated text including current chunk.
delta_text: New text in this chunk.
previous_token_ids: Token ids from prior chunks.
current_token_ids: Full token ids including current chunk.
output_token_ids: Raw output token ids from the engine.
reasoning_ended: Whether reasoning has already ended.
prompt_is_reasoning_end: Whether the prompt itself ends reasoning.
request: The originating chat completion request.
"""
delta_message
:
DeltaMessage
|
None
=
None
tools_called
=
False
reasoning_ended_at_entry
=
reasoning_ended
# For MistralReasoningParser, only enter the reasoning block when
# the model has actually emitted a [THINK] token. Other reasoning
# parsers always expect thinking to be present.
expect_thinking
=
(
not
isinstance
(
reasoning_parser
,
MistralReasoningParser
)
or
reasoning_parser
.
start_token_id
in
current_token_ids
)
if
reasoning_parser
is
not
None
and
not
reasoning_ended
and
expect_thinking
:
# Pre-v15 models may have pre-filled [THINK]...[/THINK] in
# system prompts, so skip the prompt-level reasoning-end
# check and wait for the output's own end-of-think.
is_pre_v15
=
(
isinstance
(
self
.
model_tokenizer
,
MistralTokenizer
)
and
self
.
model_tokenizer
.
version
<
15
)
if
not
is_pre_v15
and
prompt_is_reasoning_end
:
reasoning_ended
=
True
current_token_ids
=
list
(
output_token_ids
)
else
:
delta_message
=
reasoning_parser
.
extract_reasoning_streaming
(
previous_text
,
current_text
,
delta_text
,
previous_token_ids
,
current_token_ids
,
output_token_ids
,
)
if
reasoning_parser
.
is_reasoning_end_streaming
(
current_token_ids
,
output_token_ids
):
reasoning_ended
=
True
current_token_ids
=
reasoning_parser
.
extract_content_ids
(
list
(
output_token_ids
)
)
if
delta_message
and
delta_message
.
content
:
current_text
=
delta_message
.
content
delta_message
.
content
=
None
else
:
current_text
=
""
if
not
reasoning_ended
:
return
MistralStreamingResult
(
delta_message
=
delta_message
,
reasoning_ended
=
False
,
tools_called
=
False
,
current_text
=
current_text
,
current_token_ids
=
current_token_ids
,
)
delta_token_ids
=
list
(
output_token_ids
)
# On the iteration where reasoning just ended, reset the text/token
# state so the tool parser sees a clean history instead of the
# accumulated reasoning text.
if
not
reasoning_ended_at_entry
and
reasoning_ended
:
previous_text
=
""
previous_token_ids
=
[]
delta_text
=
current_text
delta_token_ids
=
current_token_ids
delta_message
=
self
.
extract_tool_calls_streaming
(
previous_text
=
previous_text
,
current_text
=
current_text
,
delta_text
=
delta_text
,
previous_token_ids
=
previous_token_ids
,
current_token_ids
=
current_token_ids
,
delta_token_ids
=
delta_token_ids
,
request
=
request
,
)
if
delta_message
and
delta_message
.
tool_calls
:
tools_called
=
True
return
MistralStreamingResult
(
delta_message
=
delta_message
,
reasoning_ended
=
reasoning_ended
,
tools_called
=
tools_called
,
current_text
=
current_text
,
current_token_ids
=
current_token_ids
,
)
@
staticmethod
def
build_non_streaming_tool_calls
(
tool_calls
:
list
[
FunctionCall
]
|
None
,
)
->
list
[
ToolCall
]:
r
"""Build `MistralToolCall` items for non-streaming responses."""
if
not
tool_calls
:
return
[]
return
[
MistralToolCall
(
id
=
tc
.
id
,
function
=
tc
)
if
tc
.
id
else
MistralToolCall
(
function
=
tc
)
for
tc
in
tool_calls
]
def
extract_tool_calls
(
self
,
model_output
:
str
,
...
...
@@ -323,7 +491,7 @@ class MistralToolParser(ToolParser):
)[
0
]
tool_calls
=
json
.
loads
(
raw_tool_call
)
except
(
IndexError
,
json
.
JSONDecodeError
):
logger
.
exception
(
"Error in extracting tool call from response
: {e}
"
)
logger
.
exception
(
"Error in extracting tool call from response
.
"
)
# If raw decoding and decoding post regex rule fails, then just
# return content.
return
ExtractedToolCallInformation
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment