Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
19dcc02a
Unverified
Commit
19dcc02a
authored
Apr 25, 2025
by
Cyrus Leung
Committed by
GitHub
Apr 25, 2025
Browse files
[Bugfix] Fix mistral model tests (#17181)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
7feae92c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
40 additions
and
28 deletions
+40
-28
tests/models/decoder_only/language/test_mistral.py
tests/models/decoder_only/language/test_mistral.py
+36
-28
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
+4
-0
No files found.
tests/models/decoder_only/language/test_mistral.py
View file @
19dcc02a
...
...
@@ -10,8 +10,8 @@ import jsonschema
import
jsonschema.exceptions
import
pytest
from
vllm.entrypoints.openai.tool_parsers.mistral_tool_parser
import
(
# noqa
MistralToolParser
)
from
vllm.entrypoints.openai.tool_parsers.mistral_tool_parser
import
(
MistralToolCall
,
MistralToolParser
)
from
vllm.sampling_params
import
GuidedDecodingParams
,
SamplingParams
from
...utils
import
check_logprobs_close
...
...
@@ -194,7 +194,6 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
)
@
pytest
.
mark
.
skip
(
"RE-ENABLE: test is currently failing on main."
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MISTRAL_FORMAT_MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
64
])
...
...
@@ -246,10 +245,8 @@ def test_mistral_symbolic_languages(vllm_runner, model: str,
assert
"�"
not
in
outputs
[
0
].
outputs
[
0
].
text
.
strip
()
@
pytest
.
mark
.
skip
(
"RE-ENABLE: test is currently failing on main."
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MISTRAL_FORMAT_MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"model"
,
MISTRAL_FORMAT_MODELS
)
# v1 can't do func calling
def
test_mistral_function_calling
(
vllm_runner
,
model
:
str
,
dtype
:
str
)
->
None
:
with
vllm_runner
(
model
,
dtype
=
dtype
,
...
...
@@ -270,7 +267,8 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
parsed_message
=
tool_parser
.
extract_tool_calls
(
model_output
,
None
)
assert
parsed_message
.
tools_called
assert
parsed_message
.
tool_calls
[
0
].
id
==
"0UAqFzWsD"
assert
MistralToolCall
.
is_valid_id
(
parsed_message
.
tool_calls
[
0
].
id
)
assert
parsed_message
.
tool_calls
[
0
].
function
.
name
==
"get_current_weather"
assert
parsed_message
.
tool_calls
[
...
...
@@ -281,28 +279,38 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"guided_backend"
,
[
"outlines"
,
"lm-format-enforcer"
,
"xgrammar"
])
def
test_mistral_guided_decoding
(
vllm_runner
,
model
:
str
,
guided_backend
:
str
)
->
None
:
with
vllm_runner
(
model
,
dtype
=
'bfloat16'
,
tokenizer_mode
=
"mistral"
)
as
vllm_model
:
def
test_mistral_guided_decoding
(
monkeypatch
:
pytest
.
MonkeyPatch
,
vllm_runner
,
model
:
str
,
guided_backend
:
str
,
)
->
None
:
with
monkeypatch
.
context
()
as
m
:
# Guided JSON not supported in xgrammar + V1 yet
m
.
setenv
(
"VLLM_USE_V1"
,
"0"
)
guided_decoding
=
GuidedDecodingParams
(
json
=
SAMPLE_JSON_SCHEMA
,
backend
=
guided_backend
)
params
=
SamplingParams
(
max_tokens
=
512
,
temperature
=
0.7
,
guided_decoding
=
guided_decoding
)
messages
=
[{
"role"
:
"system"
,
"content"
:
"you are a helpful assistant"
},
{
"role"
:
"user"
,
"content"
:
f
"Give an example JSON for an employee profile that "
f
"fits this schema:
{
SAMPLE_JSON_SCHEMA
}
"
}]
outputs
=
vllm_model
.
model
.
chat
(
messages
,
sampling_params
=
params
)
with
vllm_runner
(
model
,
dtype
=
'bfloat16'
,
tokenizer_mode
=
"mistral"
,
guided_decoding_backend
=
guided_backend
,
)
as
vllm_model
:
guided_decoding
=
GuidedDecodingParams
(
json
=
SAMPLE_JSON_SCHEMA
)
params
=
SamplingParams
(
max_tokens
=
512
,
temperature
=
0.7
,
guided_decoding
=
guided_decoding
)
messages
=
[{
"role"
:
"system"
,
"content"
:
"you are a helpful assistant"
},
{
"role"
:
"user"
,
"content"
:
f
"Give an example JSON for an employee profile that "
f
"fits this schema:
{
SAMPLE_JSON_SCHEMA
}
"
}]
outputs
=
vllm_model
.
model
.
chat
(
messages
,
sampling_params
=
params
)
generated_text
=
outputs
[
0
].
outputs
[
0
].
text
json_response
=
json
.
loads
(
generated_text
)
...
...
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
View file @
19dcc02a
...
...
@@ -38,6 +38,10 @@ class MistralToolCall(ToolCall):
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
return
""
.
join
(
choices
(
ALPHANUMERIC
,
k
=
9
))
@
staticmethod
def
is_valid_id
(
id
:
str
)
->
bool
:
return
id
.
isalnum
()
and
len
(
id
)
==
9
@
ToolParserManager
.
register_module
(
"mistral"
)
class
MistralToolParser
(
ToolParser
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment