Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fef56c18
Unverified
Commit
fef56c18
authored
Apr 06, 2026
by
Julien Denize
Committed by
GitHub
Apr 06, 2026
Browse files
[Mistral Grammar] Support Grammar Factory (#38150)
Signed-off-by:
juliendenize
<
julien.denize@mistral.ai
>
parent
c5e3454e
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
601 additions
and
29 deletions
+601
-29
requirements/common.txt
requirements/common.txt
+1
-1
requirements/rocm-test.txt
requirements/rocm-test.txt
+1
-1
requirements/test.txt
requirements/test.txt
+1
-1
tests/tokenizers_/test_mistral.py
tests/tokenizers_/test_mistral.py
+28
-0
tests/tool_parsers/test_mistral_tool_parser.py
tests/tool_parsers/test_mistral_tool_parser.py
+344
-3
tests/v1/structured_output/test_backend_guidance.py
tests/v1/structured_output/test_backend_guidance.py
+44
-0
vllm/sampling_params.py
vllm/sampling_params.py
+17
-11
vllm/tokenizers/mistral.py
vllm/tokenizers/mistral.py
+25
-0
vllm/tool_parsers/mistral_tool_parser.py
vllm/tool_parsers/mistral_tool_parser.py
+133
-9
vllm/v1/structured_output/backend_guidance.py
vllm/v1/structured_output/backend_guidance.py
+7
-3
No files found.
requirements/common.txt
View file @
fef56c18
...
...
@@ -31,7 +31,7 @@ partial-json-parser # used for parsing partial JSON outputs
pyzmq >= 25.0.0
msgspec
gguf >= 0.17.0
mistral_common[image] >= 1.1
0
.0
mistral_common[image] >= 1.1
1
.0
opencv-python-headless >= 4.13.0 # required for video IO
pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
...
...
requirements/rocm-test.txt
View file @
fef56c18
...
...
@@ -604,7 +604,7 @@ mcp==1.27.0
# via -r requirements/common.txt
mdurl==0.1.2
# via markdown-it-py
mistral-common==1.1
0
.0
mistral-common==1.1
1
.0
# via
# -c requirements/common.txt
# -r requirements/common.txt
...
...
requirements/test.txt
View file @
fef56c18
...
...
@@ -508,7 +508,7 @@ mbstrdecoder==1.1.3
# typepy
mdurl==0.1.2
# via markdown-it-py
mistral-common==1.1
0
.0
mistral-common==1.1
1
.0
# via
# -c requirements/common.txt
# -r requirements/test.in
...
...
tests/tokenizers_/test_mistral.py
View file @
fef56c18
...
...
@@ -3,8 +3,10 @@
from
typing
import
Any
import
llguidance
import
pytest
from
mistral_common.exceptions
import
InvalidMessageStructureException
from
mistral_common.guidance.grammar_factory
import
GrammarFactory
from
mistral_common.tokens.tokenizers.base
import
SpecialTokenPolicy
from
vllm.tokenizers.mistral
import
(
...
...
@@ -2407,3 +2409,29 @@ class TestMistralTokenizer:
assert
actual_tokens
==
expected_tokens
assert
mistral_tokenizer
.
convert_ids_to_tokens
([])
==
[]
def
test_grammar_factory
(
self
,
mistral_tokenizer
:
MistralTokenizer
)
->
None
:
# works in this case cause Mistral 7B is < v11 and SPM
if
not
mistral_tokenizer
.
is_tekken
:
with
pytest
.
raises
(
AttributeError
):
mistral_tokenizer
.
grammar_factory
# noqa: B018
return
factory
=
mistral_tokenizer
.
grammar_factory
assert
isinstance
(
factory
,
GrammarFactory
)
# Test caching
factory_2
=
mistral_tokenizer
.
grammar_factory
assert
factory
is
factory_2
def
test_llg_tokenizer
(
self
,
mistral_tokenizer
:
MistralTokenizer
)
->
None
:
if
not
mistral_tokenizer
.
is_tekken
:
with
pytest
.
raises
(
ValueError
):
mistral_tokenizer
.
llg_tokenizer
# noqa: B018
return
llg_tokenizer
=
mistral_tokenizer
.
llg_tokenizer
assert
isinstance
(
llg_tokenizer
,
llguidance
.
LLTokenizer
)
# Test caching
llg_tokenizer_2
=
mistral_tokenizer
.
llg_tokenizer
assert
llg_tokenizer
is
llg_tokenizer_2
tests/tool_parsers/test_mistral_tool_parser.py
View file @
fef56c18
...
...
@@ -3,19 +3,43 @@
import
json
from
collections.abc
import
Generator
from
unittest.mock
import
MagicMock
,
patch
import
partial_json_parser
import
pytest
from
mistral_common.protocol.instruct.messages
import
AssistantMessage
from
mistral_common.protocol.instruct.request
import
InstructRequest
from
mistral_common.protocol.instruct.tool_calls
import
FunctionCall
,
ToolCall
from
mistral_common.protocol.instruct.tool_calls
import
(
FunctionCall
,
ToolCall
,
)
from
mistral_common.protocol.instruct.tool_calls
import
(
NamedToolChoice
as
MistralNamedToolChoice
,
)
from
mistral_common.protocol.instruct.tool_calls
import
(
ToolChoice
as
MistralToolChoice
,
)
from
mistral_common.protocol.instruct.tool_calls
import
(
ToolChoiceEnum
as
MistralToolChoiceEnum
,
)
from
partial_json_parser.core.options
import
Allow
from
vllm.entrypoints.openai.engine.protocol
import
DeltaMessage
,
DeltaToolCall
from
vllm.entrypoints.openai.chat_completion.protocol
import
(
ChatCompletionRequest
,
)
from
vllm.entrypoints.openai.engine.protocol
import
(
DeltaMessage
,
DeltaToolCall
,
StructuralTagResponseFormat
,
)
from
vllm.sampling_params
import
StructuredOutputsParams
from
vllm.tokenizers
import
TokenizerLike
,
get_tokenizer
from
vllm.tokenizers.detokenizer_utils
import
detokenize_incrementally
from
vllm.tokenizers.mistral
import
MistralTokenizer
from
vllm.tool_parsers.mistral_tool_parser
import
MistralToolParser
from
vllm.tool_parsers.mistral_tool_parser
import
(
_DEFAULT_JSON_SCHEMA
,
MistralToolParser
,
)
@
pytest
.
fixture
(
scope
=
"module"
)
...
...
@@ -40,6 +64,13 @@ def mistral_tool_parser(mistral_tokenizer):
return
MistralToolParser
(
mistral_tokenizer
)
@
pytest
.
fixture
def
non_mistral_parser
()
->
MistralToolParser
:
mock_tokenizer
=
MagicMock
()
mock_tokenizer
.
get_vocab
.
return_value
=
{
"[TOOL_CALLS]"
:
1
}
return
MistralToolParser
(
mock_tokenizer
)
def
assert_tool_calls
(
actual_tool_calls
:
list
[
ToolCall
]
|
list
[
DeltaToolCall
],
expected_tool_calls
:
list
[
ToolCall
],
...
...
@@ -951,3 +982,313 @@ def test_fast_detokenization_text_detection_pre_v11(
assert
len
(
delta_message
.
tool_calls
)
>
0
assert
delta_message
.
tool_calls
[
0
].
function
is
not
None
assert
delta_message
.
tool_calls
[
0
].
function
.
name
==
"add"
SAMPLE_TOOLS_DICTS
=
[
{
"type"
:
"function"
,
"function"
:
{
"name"
:
"get_weather"
,
"description"
:
"Get the weather"
,
"parameters"
:
{
"type"
:
"object"
,
"properties"
:
{
"city"
:
{
"type"
:
"string"
}},
"required"
:
[
"city"
],
},
},
},
{
"type"
:
"function"
,
"function"
:
{
"name"
:
"add"
,
"description"
:
"Add two numbers"
,
"parameters"
:
{
"type"
:
"object"
,
"properties"
:
{
"a"
:
{
"type"
:
"number"
},
"b"
:
{
"type"
:
"number"
},
},
"required"
:
[
"a"
,
"b"
],
},
},
},
]
def
_make_request
(
**
kwargs
)
->
ChatCompletionRequest
:
defaults
:
dict
=
{
"messages"
:
[],
"model"
:
"mistralai/Mistral-Small-3.2-24B-Instruct-2506"
,
"tools"
:
SAMPLE_TOOLS_DICTS
,
"tool_choice"
:
"auto"
,
}
defaults
.
update
(
kwargs
)
return
ChatCompletionRequest
(
**
defaults
)
@
pytest
.
mark
.
parametrize
(
"request_kwargs,expected_mode,expected_parallel"
,
[
({
"tool_choice"
:
"auto"
},
MistralToolChoiceEnum
.
auto
,
True
),
({
"tool_choice"
:
"none"
},
MistralToolChoiceEnum
.
none
,
True
),
({
"tool_choice"
:
"required"
},
MistralToolChoiceEnum
.
required
,
True
),
({
"tool_choice"
:
None
,
"tools"
:
None
},
MistralToolChoiceEnum
.
auto
,
True
),
(
{
"tool_choice"
:
{
"type"
:
"function"
,
"function"
:
{
"name"
:
"get_weather"
},
}
},
MistralNamedToolChoice
.
model_validate
(
{
"type"
:
"function"
,
"function"
:
{
"name"
:
"get_weather"
}}
),
True
,
),
(
{
"tool_choice"
:
"auto"
,
"parallel_tool_calls"
:
False
},
MistralToolChoiceEnum
.
auto
,
False
,
),
(
{
"tool_choice"
:
"auto"
,
"response_format"
:
{
"type"
:
"text"
}},
MistralToolChoiceEnum
.
auto
,
True
,
),
],
ids
=
[
"auto"
,
"none"
,
"required"
,
"null_tool_choice"
,
"named_tool_choice"
,
"parallel_false"
,
"response_format_text"
,
],
)
def
test_adjust_request_grammar_factory
(
mistral_tool_parser
:
MistralToolParser
,
request_kwargs
:
dict
,
expected_mode
:
MistralToolChoice
,
expected_parallel
:
bool
,
)
->
None
:
request
=
_make_request
(
**
request_kwargs
)
factory
=
mistral_tool_parser
.
model_tokenizer
.
grammar_factory
with
patch
.
object
(
factory
,
"get_lark_from_jinja"
,
wraps
=
factory
.
get_lark_from_jinja
,
)
as
mock_get_lark
:
result
=
mistral_tool_parser
.
adjust_request
(
request
)
mock_get_lark
.
assert_called_once
()
call_kwargs
=
mock_get_lark
.
call_args
assert
call_kwargs
.
kwargs
[
"mode"
]
==
expected_mode
assert
call_kwargs
.
kwargs
[
"json_schema"
]
is
None
assert
call_kwargs
.
kwargs
[
"parallel_tool_calls"
]
==
expected_parallel
assert
result
.
structured_outputs
is
not
None
assert
isinstance
(
result
.
structured_outputs
.
grammar
,
str
)
assert
len
(
result
.
structured_outputs
.
grammar
)
>
0
def
test_adjust_request_unsupported_grammar_for_tokenizer
(
mistral_tokenizer
)
->
None
:
with
patch
.
object
(
type
(
mistral_tokenizer
),
"supports_grammar"
,
new_callable
=
lambda
:
property
(
lambda
self
:
False
),
):
parser
=
MistralToolParser
(
mistral_tokenizer
)
request
=
_make_request
()
result
=
parser
.
adjust_request
(
request
)
assert
result
.
structured_outputs
is
None
@
pytest
.
mark
.
parametrize
(
"tool_choice,expected_skip"
,
[(
"auto"
,
False
),
(
"none"
,
True
)],
ids
=
[
"auto_skip_false"
,
"none_skip_true"
],
)
def
test_adjust_request_non_mistral_tokenizer
(
non_mistral_parser
:
MistralToolParser
,
tool_choice
:
str
,
expected_skip
:
bool
,
)
->
None
:
request
=
_make_request
(
tool_choice
=
tool_choice
)
result
=
non_mistral_parser
.
adjust_request
(
request
)
assert
result
.
skip_special_tokens
is
expected_skip
@
pytest
.
mark
.
parametrize
(
"so_kwargs"
,
[
{
"regex"
:
r
"\d+"
},
{
"choice"
:
[
"a"
,
"b"
]},
{
"structural_tag"
:
'{"key": "value"}'
},
{
"grammar"
:
"start: 'hello'"
},
],
ids
=
[
"regex"
,
"choice"
,
"structural_tag"
,
"grammar"
],
)
def
test_adjust_request_unsupported_structured_outputs
(
mistral_tool_parser
:
MistralToolParser
,
so_kwargs
:
dict
,
)
->
None
:
request
=
_make_request
(
structured_outputs
=
StructuredOutputsParams
(
**
so_kwargs
),
)
result
=
mistral_tool_parser
.
adjust_request
(
request
)
assert
result
.
structured_outputs
==
request
.
structured_outputs
def
test_adjust_request_unsupported_response_format
(
mistral_tool_parser
:
MistralToolParser
,
)
->
None
:
request
=
_make_request
(
response_format
=
StructuralTagResponseFormat
(
type
=
"structural_tag"
,
format
=
{
"some"
:
"config"
}
),
)
result
=
mistral_tool_parser
.
adjust_request
(
request
)
assert
result
.
structured_outputs
is
None
assert
result
.
response_format
==
request
.
response_format
@
pytest
.
mark
.
parametrize
(
"so_kwargs,expected_json_schema"
,
[
({
"json_object"
:
True
},
_DEFAULT_JSON_SCHEMA
),
({
"json"
:
'{"type": "object"}'
},
{
"type"
:
"object"
}),
(
{
"json"
:
{
"type"
:
"object"
,
"properties"
:
{
"x"
:
{
"type"
:
"integer"
}}}},
{
"type"
:
"object"
,
"properties"
:
{
"x"
:
{
"type"
:
"integer"
}}},
),
],
ids
=
[
"json_object"
,
"json_str"
,
"json_dict"
],
)
def
test_adjust_request_structured_outputs_generates_grammar
(
mistral_tool_parser
:
MistralToolParser
,
so_kwargs
:
dict
,
expected_json_schema
:
str
,
)
->
None
:
request
=
_make_request
(
structured_outputs
=
StructuredOutputsParams
(
**
so_kwargs
),
)
factory
=
mistral_tool_parser
.
model_tokenizer
.
grammar_factory
with
patch
.
object
(
factory
,
"get_lark_from_jinja"
,
wraps
=
factory
.
get_lark_from_jinja
,
)
as
mock_get_lark
:
result
=
mistral_tool_parser
.
adjust_request
(
request
)
mock_get_lark
.
assert_called_once
()
assert
mock_get_lark
.
call_args
.
kwargs
[
"json_schema"
]
==
expected_json_schema
assert
result
.
structured_outputs
is
not
None
assert
isinstance
(
result
.
structured_outputs
.
grammar
,
str
)
assert
len
(
result
.
structured_outputs
.
grammar
)
>
0
@
pytest
.
mark
.
parametrize
(
"response_format_kwargs,expected_json_schema"
,
[
({
"type"
:
"json_object"
},
_DEFAULT_JSON_SCHEMA
),
(
{
"type"
:
"json_schema"
,
"json_schema"
:
{
"name"
:
"my_schema"
,
"schema"
:
{
"type"
:
"object"
,
"properties"
:
{
"x"
:
{
"type"
:
"integer"
}},
},
},
},
{
"type"
:
"object"
,
"properties"
:
{
"x"
:
{
"type"
:
"integer"
}}},
),
],
ids
=
[
"json_object"
,
"json_schema_with_schema"
],
)
def
test_adjust_request_response_format_generates_grammar
(
mistral_tool_parser
:
MistralToolParser
,
response_format_kwargs
:
dict
,
expected_json_schema
:
str
,
)
->
None
:
request
=
_make_request
(
response_format
=
response_format_kwargs
)
factory
=
mistral_tool_parser
.
model_tokenizer
.
grammar_factory
with
patch
.
object
(
factory
,
"get_lark_from_jinja"
,
wraps
=
factory
.
get_lark_from_jinja
,
)
as
mock_get_lark
:
result
=
mistral_tool_parser
.
adjust_request
(
request
)
mock_get_lark
.
assert_called_once
()
assert
mock_get_lark
.
call_args
.
kwargs
[
"json_schema"
]
==
expected_json_schema
assert
result
.
structured_outputs
is
not
None
assert
isinstance
(
result
.
structured_outputs
.
grammar
,
str
)
assert
len
(
result
.
structured_outputs
.
grammar
)
>
0
def
test_adjust_request_tool_choice_none_with_json_schema_uses_json_schema_factory
(
mistral_tool_parser
:
MistralToolParser
,
)
->
None
:
request
=
_make_request
(
tool_choice
=
"none"
,
structured_outputs
=
StructuredOutputsParams
(
json
=
'{"type": "object"}'
),
)
factory
=
mistral_tool_parser
.
model_tokenizer
.
grammar_factory
with
patch
.
object
(
factory
,
"get_lark_for_json_schema"
,
wraps
=
factory
.
get_lark_for_json_schema
,
)
as
mock_json_schema
:
result
=
mistral_tool_parser
.
adjust_request
(
request
)
mock_json_schema
.
assert_called_once
()
assert
mock_json_schema
.
call_args
.
kwargs
[
"json_schema"
]
==
{
"type"
:
"object"
}
assert
result
.
structured_outputs
is
not
None
assert
isinstance
(
result
.
structured_outputs
.
grammar
,
str
)
assert
len
(
result
.
structured_outputs
.
grammar
)
>
0
def
test_adjust_request_tool_choice_auto_with_json_schema_uses_jinja_factory
(
mistral_tool_parser
:
MistralToolParser
,
)
->
None
:
request
=
_make_request
(
tool_choice
=
"auto"
,
structured_outputs
=
StructuredOutputsParams
(
json
=
'{"type": "object"}'
),
)
factory
=
mistral_tool_parser
.
model_tokenizer
.
grammar_factory
with
(
patch
.
object
(
factory
,
"get_lark_for_json_schema"
,
wraps
=
factory
.
get_lark_for_json_schema
,
)
as
mock_json_schema
,
patch
.
object
(
factory
,
"get_lark_from_jinja"
,
wraps
=
factory
.
get_lark_from_jinja
,
)
as
mock_jinja
,
):
result
=
mistral_tool_parser
.
adjust_request
(
request
)
mock_jinja
.
assert_called_once
()
assert
mock_jinja
.
call_args
.
kwargs
[
"json_schema"
]
==
{
"type"
:
"object"
}
mock_json_schema
.
assert_not_called
()
assert
result
.
structured_outputs
is
not
None
assert
isinstance
(
result
.
structured_outputs
.
grammar
,
str
)
assert
len
(
result
.
structured_outputs
.
grammar
)
>
0
tests/v1/structured_output/test_backend_guidance.py
View file @
fef56c18
...
...
@@ -11,6 +11,7 @@ from vllm.config.model import ModelConfig
from
vllm.config.parallel
import
ParallelConfig
from
vllm.config.speculative
import
SpeculativeConfig
from
vllm.sampling_params
import
SamplingParams
,
StructuredOutputsParams
from
vllm.tokenizers
import
get_tokenizer
from
vllm.v1.request
import
Request
from
vllm.v1.structured_output
import
StructuredOutputManager
from
vllm.v1.structured_output.backend_guidance
import
GuidanceBackend
...
...
@@ -19,6 +20,14 @@ from vllm.v1.structured_output.backend_types import StructuredOutputOptions
TOKENIZER
=
"gpt2"
@
pytest
.
fixture
(
scope
=
"module"
)
def
mistral_tokenizer
():
return
get_tokenizer
(
tokenizer_name
=
"mistralai/Mistral-Small-3.2-24B-Instruct-2506"
,
tokenizer_mode
=
"mistral"
,
)
def
test_backend_guidance_rollback_terminated
():
# Test that the backend guidance successfully rollbacks from a
# terminated state. This can happen with speculative decoding,
...
...
@@ -187,3 +196,38 @@ def test_grammar_init_async_and_sync(async_grammar):
# Verify the grammar can accept valid tokens
assert
grammar
.
accept_tokens
(
request
.
request_id
,
prompt
)
@
pytest
.
mark
.
parametrize
(
"request_type,grammar_spec"
,
[
pytest
.
param
(
StructuredOutputOptions
.
JSON
,
'{"type": "object"}'
,
id
=
"json"
,
),
pytest
.
param
(
StructuredOutputOptions
.
GRAMMAR
,
'start: "hello" | "world"'
,
id
=
"lark"
,
),
],
)
def
test_mistral_tokenizer_compile_grammar
(
mistral_tokenizer
,
request_type
:
StructuredOutputOptions
,
grammar_spec
:
str
,
)
->
None
:
vllm_config
=
VllmConfig
(
structured_outputs_config
=
StructuredOutputsConfig
(
backend
=
"guidance"
),
)
backend
=
GuidanceBackend
(
vllm_config
,
tokenizer
=
mistral_tokenizer
,
vocab_size
=
mistral_tokenizer
.
vocab_size
,
)
assert
backend
.
ll_tokenizer
is
mistral_tokenizer
.
llg_tokenizer
grammar
=
backend
.
compile_grammar
(
request_type
,
grammar_spec
)
assert
grammar
is
not
None
assert
not
grammar
.
is_terminated
()
vllm/sampling_params.py
View file @
fef56c18
...
...
@@ -153,6 +153,10 @@ class RequestOutputKind(Enum):
FINAL_ONLY
=
2
def
_is_non_tekken_mistral
(
tokenizer
:
TokenizerLike
)
->
bool
:
return
is_mistral_tokenizer
(
tokenizer
)
and
not
tokenizer
.
is_tekken
class
SamplingParams
(
PydanticMsgspecMixin
,
msgspec
.
Struct
,
...
...
@@ -801,16 +805,17 @@ class SamplingParams(
# xgrammar with no fallback
validate_xgrammar_grammar
(
self
)
elif
backend
.
startswith
(
"guidance"
):
if
_is_non_tekken_mistral
(
tokenizer
=
tokenizer
):
raise
ValueError
(
"Non-tekken Mistral tokenizers are not supported for the 'guidance'"
" structured output backend. Please either use a more recent "
"Mistral model, the ['xgrammar', 'outlines'] "
"backends or tokenizer_mode='hf' instead."
)
# TODO: ideally we would have the LLTokenizer here as Lark syntax
# allows <|special_token|> and similar, see
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars.
if
is_mistral_tokenizer
(
tokenizer
):
raise
ValueError
(
"Mistral tokenizer is not supported for the 'guidance' "
"structured output backend. Please use ['xgrammar', 'outlines'] "
"backends or tokenizer_mode='hf' instead."
)
validate_guidance_grammar
(
self
,
tokenizer
=
None
)
elif
backend
==
"outlines"
:
# outlines backend
...
...
@@ -839,19 +844,20 @@ class SamplingParams(
# or includes some jsonschema feature(s) that
# are not supported in xgrammar.
skip_guidance
=
_is_non_tekken_mistral
(
tokenizer
)
# Check if schema has features unsupported by guidance
so_params
=
self
.
structured_outputs
skip_guidance
=
False
if
so_params
.
json
:
if
not
skip_guidance
and
so_params
.
json
:
if
isinstance
(
so_params
.
json
,
str
):
schema
=
json_mod
.
loads
(
so_params
.
json
)
else
:
schema
=
so_params
.
json
skip_guidance
=
has_guidance_unsupported_json_features
(
schema
)
if
is_mistral_tokenizer
(
tokenizer
)
or
skip_guidance
:
# Fall back to outlines if the tokenizer is Mistral
#
or if
schema contains features unsupported by guidance
if
skip_guidance
:
# Fall back to outlines if the tokenizer is
non-tekken
Mistral
or
#
the
schema contains features unsupported by guidance
validate_structured_output_request_outlines
(
self
)
self
.
structured_outputs
.
_backend
=
"outlines"
else
:
...
...
vllm/tokenizers/mistral.py
View file @
fef56c18
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Sequence
from
functools
import
cached_property
from
pathlib
import
Path
from
typing
import
TYPE_CHECKING
,
Any
,
cast
,
overload
from
mistral_common.guidance.grammar_factory
import
GrammarFactory
from
mistral_common.guidance.tokenizer
import
from_mistral_tokenizer
from
mistral_common.protocol.instruct.request
import
(
ChatCompletionRequest
as
MistralChatCompletionRequest
,
)
...
...
@@ -45,6 +48,7 @@ except ImportError:
)
if
TYPE_CHECKING
:
import
llguidance
from
transformers
import
BatchEncoding
logger
=
init_logger
(
__name__
)
...
...
@@ -574,3 +578,24 @@ class MistralTokenizer(TokenizerLike):
]
return
tokens
@
property
def
supports_grammar
(
self
)
->
bool
:
return
GrammarFactory
.
is_supported
(
self
.
mistral
)
@
cached_property
def
grammar_factory
(
self
)
->
GrammarFactory
:
if
not
self
.
supports_grammar
:
raise
AttributeError
(
"This tokenizer does not support `grammar_factory`. "
"This is only supported for tekken tokenizers with "
"version >= 11."
)
# Cache grammar factory to avoid creating a llguidance tokenizer at every usage.
return
GrammarFactory
(
self
.
mistral
)
@
cached_property
def
llg_tokenizer
(
self
)
->
"llguidance.LLTokenizer"
:
if
not
self
.
is_tekken
:
raise
ValueError
(
"`llg_tokenizer` is only supported for Tekkenizers."
)
return
from_mistral_tokenizer
(
self
.
mistral
)
vllm/tool_parsers/mistral_tool_parser.py
View file @
fef56c18
...
...
@@ -10,6 +10,18 @@ from typing import Any
import
ijson
import
regex
as
re
from
mistral_common.protocol.instruct.tool_calls
import
(
NamedToolChoice
as
MistralNamedToolChoice
,
)
from
mistral_common.protocol.instruct.tool_calls
import
(
Tool
as
MistralTool
,
)
from
mistral_common.protocol.instruct.tool_calls
import
(
ToolChoice
as
MistralToolChoice
,
)
from
mistral_common.protocol.instruct.tool_calls
import
(
ToolChoiceEnum
as
MistralToolChoiceEnum
,
)
from
pydantic
import
Field
from
vllm.entrypoints.openai.chat_completion.protocol
import
(
...
...
@@ -25,6 +37,7 @@ from vllm.entrypoints.openai.engine.protocol import (
)
from
vllm.entrypoints.openai.responses.protocol
import
ResponsesRequest
from
vllm.logger
import
init_logger
from
vllm.sampling_params
import
StructuredOutputsParams
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tool_parsers.abstract_tool_parser
import
(
Tool
,
...
...
@@ -36,6 +49,8 @@ logger = init_logger(__name__)
ALPHANUMERIC
=
ascii_letters
+
digits
_DEFAULT_JSON_SCHEMA
=
{
"anyOf"
:
[{
"type"
:
"object"
},
{
"type"
:
"array"
}]}
class
StreamingState
(
Enum
):
"""Enum for tracking the current streaming parsing state."""
...
...
@@ -80,6 +95,9 @@ class MistralToolParser(ToolParser):
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
"""
# Used to generate correct grammar in `adjust_request`
model_can_reason
:
bool
=
False
def
__init__
(
self
,
tokenizer
:
TokenizerLike
,
tools
:
list
[
Tool
]
|
None
=
None
):
super
().
__init__
(
tokenizer
,
tools
)
...
...
@@ -115,12 +133,34 @@ class MistralToolParser(ToolParser):
def
adjust_request
(
self
,
request
:
ChatCompletionRequest
|
ResponsesRequest
)
->
ChatCompletionRequest
|
ResponsesRequest
:
request
=
super
().
adjust_request
(
request
)
so_non_supported_attributes
=
[
"regex"
,
"choice"
,
"grammar"
,
# whitespace_pattern is not a constraint type but an option;
# Mistral grammar factory does not support it.
"whitespace_pattern"
,
"structural_tag"
,
]
any_so_non_supported_active
=
request
.
structured_outputs
is
not
None
and
any
(
getattr
(
request
.
structured_outputs
,
attribute
)
is
not
None
for
attribute
in
so_non_supported_attributes
)
response_format_non_supported_active
=
(
isinstance
(
request
,
ResponsesRequest
)
or
request
.
response_format
is
not
None
and
request
.
response_format
.
type
==
"structural_tag"
)
if
(
not
is_mistral_tokenizer
(
self
.
model_tokenizer
)
and
request
.
tools
and
request
.
tool_choice
!=
"none"
or
isinstance
(
request
,
ResponsesRequest
)
or
not
self
.
model_tokenizer
.
supports_grammar
or
any_so_non_supported_active
or
response_format_non_supported_active
):
request
=
super
().
adjust_request
(
request
)
if
request
.
tools
and
request
.
tool_choice
!=
"none"
:
# Do not skip special tokens when using chat template
# with Mistral parser as TOOL_CALL token is needed
# for tool detection.
...
...
@@ -129,6 +169,90 @@ class MistralToolParser(ToolParser):
request
.
skip_special_tokens
=
False
return
request
json_schema
:
dict
[
str
,
Any
]
|
None
=
None
if
request
.
structured_outputs
is
not
None
:
if
request
.
structured_outputs
.
json_object
is
not
None
:
json_schema
=
_DEFAULT_JSON_SCHEMA
elif
request
.
structured_outputs
.
json
is
not
None
:
if
isinstance
(
request
.
structured_outputs
.
json
,
str
):
json_schema
=
json
.
loads
(
request
.
structured_outputs
.
json
)
else
:
json_schema
=
request
.
structured_outputs
.
json
else
:
raise
ValueError
(
"Unsupported request.structured_outputs for MistralToolParser. "
"Only `json` and `json_object` are supported."
)
elif
(
request
.
response_format
is
not
None
and
request
.
response_format
.
type
!=
"text"
):
if
request
.
response_format
.
type
==
"json_object"
:
json_schema
=
_DEFAULT_JSON_SCHEMA
elif
request
.
response_format
.
type
==
"json_schema"
:
if
request
.
response_format
.
json_schema
is
not
None
:
json_schema
=
request
.
response_format
.
json_schema
.
json_schema
else
:
json_schema
=
_DEFAULT_JSON_SCHEMA
else
:
raise
ValueError
(
"MistralToolParser only accepts `text`, `json_object` or "
f
"`json_schema`, got
{
request
.
response_format
=
}
"
)
# Structured Outputs will be defined.
request
.
response_format
=
None
grammar_factory
=
self
.
model_tokenizer
.
grammar_factory
# TODO: Once unified parser, improve this.
# The issue is figuring out when a model is a reasoning one or not.
template
=
grammar_factory
.
select_jinja_template
(
reasoning
=
self
.
model_can_reason
)
tools
=
(
[
MistralTool
.
from_openai
(
openai_tool
=
tool
.
model_dump
())
for
tool
in
request
.
tools
]
if
request
.
tools
is
not
None
else
None
)
tool_choice
:
MistralToolChoice
match
request
.
tool_choice
:
case
"none"
|
"auto"
|
"required"
:
tool_choice
=
MistralToolChoiceEnum
(
request
.
tool_choice
)
case
None
:
tool_choice
=
MistralToolChoiceEnum
.
auto
# _ == Named tool choice
case
_
:
tool_choice
=
MistralNamedToolChoice
.
model_validate
(
{
"type"
:
"function"
,
"function"
:
{
"name"
:
request
.
tool_choice
.
function
.
name
},
}
)
# Rendering grammar is cached in mistral-common given tools, template and mode.
match
tool_choice
,
json_schema
is
not
None
:
case
MistralToolChoiceEnum
.
none
,
True
:
lark_grammar
=
grammar_factory
.
get_lark_for_json_schema
(
template
=
template
,
json_schema
=
json_schema
)
case
_
,
_
:
lark_grammar
=
grammar_factory
.
get_lark_from_jinja
(
template
=
template
,
mode
=
tool_choice
,
tools
=
tools
,
json_schema
=
json_schema
,
parallel_tool_calls
=
request
.
parallel_tool_calls
,
json_only
=
False
,
)
request
.
structured_outputs
=
StructuredOutputsParams
(
grammar
=
lark_grammar
)
return
request
def
extract_tool_calls
(
self
,
model_output
:
str
,
...
...
vllm/v1/structured_output/backend_guidance.py
View file @
fef56c18
...
...
@@ -12,6 +12,7 @@ import torch
from
vllm.logger
import
init_logger
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils.import_utils
import
LazyLoader
from
vllm.utils.mistral
import
is_mistral_tokenizer
from
vllm.v1.structured_output.backend_types
import
(
StructuredOutputBackend
,
StructuredOutputGrammar
,
...
...
@@ -92,6 +93,9 @@ class GuidanceBackend(StructuredOutputBackend):
self
.
vllm_config
.
structured_outputs_config
.
disable_additional_properties
)
if
is_mistral_tokenizer
(
self
.
tokenizer
):
self
.
ll_tokenizer
=
self
.
tokenizer
.
llg_tokenizer
else
:
self
.
ll_tokenizer
=
llguidance_hf
.
from_tokenizer
(
self
.
tokenizer
,
max
(
self
.
vocab_size
,
len
(
self
.
tokenizer
))
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment