Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
11cd1ae6
Unverified
Commit
11cd1ae6
authored
Nov 15, 2024
by
Patrick von Platen
Committed by
GitHub
Nov 15, 2024
Browse files
[Tool parsing] Improve / correct mistral tool parsing (#10333)
parent
554af922
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
172 additions
and
59 deletions
+172
-59
tests/models/decoder_only/language/test_mistral.py
tests/models/decoder_only/language/test_mistral.py
+82
-11
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+5
-34
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
+17
-8
vllm/transformers_utils/tokenizers/__init__.py
vllm/transformers_utils/tokenizers/__init__.py
+2
-2
vllm/transformers_utils/tokenizers/mistral.py
vllm/transformers_utils/tokenizers/mistral.py
+66
-4
No files found.
tests/models/decoder_only/language/test_mistral.py
View file @
11cd1ae6
...
@@ -2,9 +2,13 @@
...
@@ -2,9 +2,13 @@
Run `pytest tests/models/decoder_only/language/test_mistral.py`.
Run `pytest tests/models/decoder_only/language/test_mistral.py`.
"""
"""
import
copy
import
pytest
import
pytest
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
from
vllm.entrypoints.openai.tool_parsers.mistral_tool_parser
import
(
# noqa
MistralToolParser
)
from
...utils
import
check_logprobs_close
from
...utils
import
check_logprobs_close
...
@@ -58,17 +62,69 @@ TOOLS = [{
...
@@ -58,17 +62,69 @@ TOOLS = [{
},
},
"required"
:
[
"city"
,
"state"
,
"unit"
]
"required"
:
[
"city"
,
"state"
,
"unit"
]
}
}
},
},
{
"type"
:
"function"
,
"function"
:
{
"name"
:
"rewrite"
,
"description"
:
"Rewrites text"
,
"parameters"
:
{
"type"
:
"object"
,
"required"
:
[],
"properties"
:
{
"text"
:
{
"type"
:
"string"
,
"description"
:
"The input text to rewrite."
}
}
}
}
}
}]
}]
MSGS
=
[{
MSGS
=
[
{
"role"
:
"system"
,
"content"
:
"You are an assistant."
},
{
"role"
:
"user"
,
"content"
:
"Could you please rewrite the below article?
\n\n
My English needs improvving, maybe I make errors."
# noqa
},
{
"role"
:
"assistant"
,
"content"
:
""
,
"tool_calls"
:
[{
"id"
:
"bbc5b7ede"
,
"type"
:
"function"
,
"function"
:
{
"name"
:
"rewrite"
,
"arguments"
:
'{
\"
text
\"
:
\"
My English needs improvving, maybe I make errors.
\"
}'
# noqa
}
}]
},
{
"role"
:
"tool"
,
"content"
:
"{
\"
action
\"
:
\"
rewrite
\"
,
\"
outcome
\"
:
\"
My English needs improving, maybe I make errors.
\"
}"
,
# noqa
"tool_call_id"
:
"bbc5b7ede"
,
"name"
:
"rewrite"
},
{
"role"
:
"assistant"
,
"content"
:
"---
\n\n
My English needs improving, maybe I make errors"
},
{
"role"
:
"role"
:
"user"
,
"user"
,
"content"
:
(
"Can you tell me what the temperate"
"content"
:
(
"Can you tell me what the temperate"
" will be in Dallas, in fahrenheit?"
)
" will be in Dallas, in fahrenheit?"
)
}]
}
EXPECTED_FUNC_CALL
=
(
]
'[{"name": "get_current_weather", "arguments": '
'{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]'
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
...
@@ -175,8 +231,23 @@ def test_mistral_function_calling(
...
@@ -175,8 +231,23 @@ def test_mistral_function_calling(
tokenizer_mode
=
"mistral"
,
tokenizer_mode
=
"mistral"
,
config_format
=
"mistral"
,
config_format
=
"mistral"
,
load_format
=
"mistral"
)
as
vllm_model
:
load_format
=
"mistral"
)
as
vllm_model
:
outputs
=
vllm_model
.
model
.
chat
(
MSGS
,
msgs
=
copy
.
deepcopy
(
MSGS
)
outputs
=
vllm_model
.
model
.
chat
(
msgs
,
tools
=
TOOLS
,
tools
=
TOOLS
,
sampling_params
=
SAMPLING_PARAMS
)
sampling_params
=
SAMPLING_PARAMS
)
assert
outputs
[
0
].
outputs
[
0
].
text
.
strip
()
==
EXPECTED_FUNC_CALL
tokenizer
=
vllm_model
.
model
.
get_tokenizer
()
tool_parser
=
MistralToolParser
(
tokenizer
)
model_output
=
outputs
[
0
].
outputs
[
0
].
text
.
strip
()
assert
model_output
.
startswith
(
tool_parser
.
bot_token
),
model_output
parsed_message
=
tool_parser
.
extract_tool_calls
(
model_output
,
None
)
assert
parsed_message
.
tools_called
assert
parsed_message
.
tool_calls
[
0
].
id
==
"0UAqFzWsD"
assert
parsed_message
.
tool_calls
[
0
].
function
.
name
==
"get_current_weather"
assert
parsed_message
.
tool_calls
[
0
].
function
.
arguments
==
'{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}'
# noqa
assert
parsed_message
.
content
is
None
vllm/entrypoints/openai/serving_chat.py
View file @
11cd1ae6
...
@@ -30,6 +30,7 @@ from vllm.outputs import CompletionOutput, RequestOutput
...
@@ -30,6 +30,7 @@ from vllm.outputs import CompletionOutput, RequestOutput
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.sequence
import
Logprob
from
vllm.sequence
import
Logprob
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.transformers_utils.tokenizers
import
maybe_serialize_tool_calls
from
vllm.utils
import
iterate_with_cancellation
from
vllm.utils
import
iterate_with_cancellation
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -127,41 +128,11 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -127,41 +128,11 @@ class OpenAIServingChat(OpenAIServing):
return
self
.
create_error_response
(
return
self
.
create_error_response
(
"tool_choice =
\"
required
\"
is not supported!"
)
"tool_choice =
\"
required
\"
is not supported!"
)
# NOTE: There is currently a bug in pydantic where attributes
# because of issues with pydantic we need to potentially
# declared as iterables are replaced in the instances by
# re-serialize the tool_calls field of the request
# pydantic-core ValidatorIterator instance. In particular, this
# for more info: see comment in `maybe_serialize_tool_calls`
# affects tool_calls defined in ChatCompletionAssistantMessageParam
# model:
# see:
# - https://github.com/pydantic/pydantic/issues/9467
# As a result, tool_calls from assistant messages are never
# deserialized in the request object if the tool_calls iterator is
# not consumed. This affects messages passed to the MistralTokenizer
# since no chat template is applied and therefore the tools_calls
# iterator is not directly consumed.
# Issue is tracked on Pydantic side, with resolution planned for
# v2.11 release. In the meantime, the official workaround is to
# consume the iterator so the tool_calls are correctly deserialized
# in the OpenAI ChatCompletionAssistantMessageParam object
# https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291 # noqa: E501
# Official Pydantic Issues:
# - https://github.com/pydantic/pydantic/issues/9541
# TODO: remove when pydantic v2.11 is released
if
isinstance
(
tokenizer
,
MistralTokenizer
):
if
isinstance
(
tokenizer
,
MistralTokenizer
):
for
i
,
message
in
enumerate
(
request
.
messages
):
maybe_serialize_tool_calls
(
request
)
if
message
.
get
(
"role"
)
==
'assistant'
:
tool_calls_validator
=
message
.
get
(
"tool_calls"
,
().
__iter__
())
validated_tool_calls
=
[]
while
True
:
try
:
tool_call
=
next
(
tool_calls_validator
)
# type: ignore
validated_tool_calls
.
append
(
tool_call
)
except
StopIteration
:
break
request
.
messages
[
i
][
"tool_calls"
]
=
validated_tool_calls
if
(
request
.
tool_choice
==
"auto"
and
if
(
request
.
tool_choice
==
"auto"
and
not
(
self
.
enable_auto_tools
and
tool_parser
is
not
None
)
not
(
self
.
enable_auto_tools
and
tool_parser
is
not
None
)
...
...
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
View file @
11cd1ae6
...
@@ -62,7 +62,7 @@ class MistralToolParser(ToolParser):
...
@@ -62,7 +62,7 @@ class MistralToolParser(ToolParser):
]
# map what has been streamed for each tool so far to a list
]
# map what has been streamed for each tool so far to a list
self
.
bot_token
=
"[TOOL_CALLS]"
self
.
bot_token
=
"[TOOL_CALLS]"
self
.
bot_token_id
=
self
.
vocab
.
get
(
self
.
bot_token
)
self
.
bot_token_id
=
self
.
vocab
.
get
(
self
.
bot_token
)
self
.
tool_call_regex
=
re
.
compile
(
r
"\[{.*
?
}\]"
,
re
.
DOTALL
)
self
.
tool_call_regex
=
re
.
compile
(
r
"\[{.*}\]"
,
re
.
DOTALL
)
if
self
.
bot_token_id
is
None
:
if
self
.
bot_token_id
is
None
:
raise
RuntimeError
(
raise
RuntimeError
(
"Mistral Tool Parser could not locate the tool call token in "
"Mistral Tool Parser could not locate the tool call token in "
...
@@ -84,16 +84,25 @@ class MistralToolParser(ToolParser):
...
@@ -84,16 +84,25 @@ class MistralToolParser(ToolParser):
return
ExtractedToolCallInformation
(
tools_called
=
False
,
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
tool_calls
=
[],
content
=
model_output
)
content
=
model_output
)
# first remove the BOT token
tool_content
=
model_output
.
replace
(
self
.
bot_token
,
""
).
strip
()
try
:
try
:
# use a regex to find the tool call. remove the BOT token
# we first try to directly load the json as parsing very nested
# and make sure to replace single quotes with double quotes
# jsons is difficult
raw_tool_call
=
self
.
tool_call_regex
.
findall
(
try
:
model_output
.
replace
(
self
.
bot_token
,
""
))[
0
]
function_call_arr
=
json
.
loads
(
tool_content
)
except
json
.
JSONDecodeError
:
# use a regex to find the part corresponding to the tool call.
# NOTE: This use case should not happen if the model is trained
# correctly. It's an easy fix so it's included, but
# can be brittle for very complex / highly nested tool calls
raw_tool_call
=
self
.
tool_call_regex
.
findall
(
tool_content
)[
0
]
function_call_arr
=
json
.
loads
(
raw_tool_call
)
# load the JSON, and then use it to build the Function and
# Tool Call
# Tool Call
function_call_arr
=
json
.
loads
(
raw_tool_call
)
tool_calls
:
List
[
MistralToolCall
]
=
[
tool_calls
:
List
[
MistralToolCall
]
=
[
MistralToolCall
(
MistralToolCall
(
type
=
"function"
,
type
=
"function"
,
...
@@ -116,7 +125,7 @@ class MistralToolParser(ToolParser):
...
@@ -116,7 +125,7 @@ class MistralToolParser(ToolParser):
# return information to just treat the tool call as regular JSON
# return information to just treat the tool call as regular JSON
return
ExtractedToolCallInformation
(
tools_called
=
False
,
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
tool_calls
=
[],
content
=
model_outpu
t
)
content
=
tool_conten
t
)
def
extract_tool_calls_streaming
(
def
extract_tool_calls_streaming
(
self
,
self
,
...
...
vllm/transformers_utils/tokenizers/__init__.py
View file @
11cd1ae6
from
.mistral
import
MistralTokenizer
from
.mistral
import
MistralTokenizer
,
maybe_serialize_tool_calls
__all__
=
[
"MistralTokenizer"
]
__all__
=
[
"MistralTokenizer"
,
"maybe_serialize_tool_calls"
]
vllm/transformers_utils/tokenizers/mistral.py
View file @
11cd1ae6
...
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
...
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
import
huggingface_hub
import
huggingface_hub
from
huggingface_hub
import
HfApi
,
hf_hub_download
from
huggingface_hub
import
HfApi
,
hf_hub_download
from
mistral_common.protocol.instruct.request
import
ChatCompletionRequest
from
mistral_common.protocol.instruct.request
import
ChatCompletionRequest
from
mistral_common.tokens.tokenizers.base
import
SpecialTokens
# yapf: disable
# yapf: disable
from
mistral_common.tokens.tokenizers.mistral
import
(
from
mistral_common.tokens.tokenizers.mistral
import
(
MistralTokenizer
as
PublicMistralTokenizer
)
MistralTokenizer
as
PublicMistralTokenizer
)
...
@@ -29,6 +30,43 @@ class Encoding:
...
@@ -29,6 +30,43 @@ class Encoding:
input_ids
:
List
[
int
]
input_ids
:
List
[
int
]
def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
    """Materialize the ``tool_calls`` of every assistant message in place.

    For each entry of ``request.messages`` whose ``"role"`` is
    ``"assistant"``, the ``"tool_calls"`` value is consumed and replaced by
    a plain ``list`` holding the same items. A missing or ``None`` value
    becomes an empty list. Messages with any other role are left untouched.

    Args:
        request: Object exposing a ``messages`` sequence of mapping-like
            messages (they are accessed with ``.get`` and item assignment).
            NOTE(review): the annotation names the mistral_common request
            type while the body treats messages as dicts -- confirm which
            request class callers actually pass.

    Returns:
        None. ``request.messages`` is mutated in place.
    """
    # SEE: https://github.com/vllm-project/vllm/pull/9951
    # Credits go to: @gcalmettes
    # NOTE: There is currently a bug in pydantic where attributes
    # declared as iterables are replaced in the instances by
    # pydantic-core ValidatorIterator instance. In particular, this
    # affects tool_calls defined in ChatCompletionAssistantMessageParam
    # model:
    # see:
    #   - https://github.com/pydantic/pydantic/issues/9467
    # As a result, tool_calls from assistant messages are never
    # deserialized in the request object if the tool_calls iterator is
    # not consumed. This affects messages passed to the MistralTokenizer
    # since no chat template is applied and therefore the tool_calls
    # iterator is not directly consumed.
    # Issue is tracked on Pydantic side, with resolution planned for
    # v2.11 release. In the meantime, the official workaround is to
    # consume the iterator so the tool_calls are correctly deserialized
    # in the OpenAI ChatCompletionAssistantMessageParam object
    # https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291 # noqa: E501
    # Official Pydantic Issues:
    #   - https://github.com/pydantic/pydantic/issues/9541
    # TODO: remove when pydantic v2.11 is released
    for i, message in enumerate(request.messages):
        if message.get("role") != "assistant":
            continue

        tool_calls = message.get("tool_calls")
        # `list()` drains a lazy iterator and also accepts an already
        # materialized list/tuple; calling `next()` directly on those (or
        # on an explicit `None`) would raise TypeError.
        validated_tool_calls = [] if tool_calls is None else list(tool_calls)

        request.messages[i]["tool_calls"] = validated_tool_calls
def
list_local_repo_files
(
repo_id
:
str
,
revision
:
Optional
[
str
])
->
List
[
str
]:
def
list_local_repo_files
(
repo_id
:
str
,
revision
:
Optional
[
str
])
->
List
[
str
]:
repo_cache
=
os
.
path
.
join
(
repo_cache
=
os
.
path
.
join
(
huggingface_hub
.
constants
.
HF_HUB_CACHE
,
huggingface_hub
.
constants
.
HF_HUB_CACHE
,
...
@@ -222,7 +260,8 @@ class MistralTokenizer:
...
@@ -222,7 +260,8 @@ class MistralTokenizer:
if
self
.
is_tekken
:
if
self
.
is_tekken
:
tokens
=
[
tokens
=
[
t
for
t
in
tokens
t
for
t
in
tokens
if
t
not
in
self
.
tokenizer
.
_all_special_tokens
if
(
t
is
SpecialTokens
.
tool_calls
or
t
not
in
self
.
tokenizer
.
_all_special_tokens
)
]
]
if
any
(
isinstance
(
t
,
bytes
)
for
t
in
tokens
):
if
any
(
isinstance
(
t
,
bytes
)
for
t
in
tokens
):
...
@@ -246,7 +285,27 @@ class MistralTokenizer:
...
@@ -246,7 +285,27 @@ class MistralTokenizer:
else
:
else
:
decoded
=
""
.
join
(
tokens
)
decoded
=
""
.
join
(
tokens
)
else
:
else
:
decoded
=
self
.
tokenizer
.
decode
(
tokens
)
# type: ignore[arg-type]
# make sure certain special tokens like Tool calls are
# not decoded
special_tokens
=
{
SpecialTokens
.
tool_calls
}
regular_tokens
:
List
[
str
]
=
[]
decoded_list
=
[]
for
token
in
tokens
:
if
token
in
special_tokens
:
if
regular_tokens
:
decoded_list
.
append
(
self
.
tokenizer
.
decode
(
regular_tokens
))
regular_tokens
=
[]
decoded_list
.
append
(
token
)
else
:
regular_tokens
.
append
(
token
)
if
regular_tokens
:
decoded_list
.
append
(
self
.
decode
(
regular_tokens
))
# type: ignore
decoded
=
''
.
join
(
decoded_list
)
return
decoded
return
decoded
...
@@ -274,8 +333,11 @@ class MistralTokenizer:
...
@@ -274,8 +333,11 @@ class MistralTokenizer:
assert
self
.
is_tekken
or
self
.
is_spm
,
type
(
self
.
tokenizer
)
assert
self
.
is_tekken
or
self
.
is_spm
,
type
(
self
.
tokenizer
)
if
self
.
is_tekken
:
if
self
.
is_tekken
:
# skip special tokens
# skip special tokens except tool call
ids
=
[
i
for
i
in
ids
if
i
>
self
.
tokenizer
.
num_special_tokens
]
ids
=
[
i
for
i
in
ids
if
i
>
self
.
tokenizer
.
num_special_tokens
or
i
==
self
.
tokenizer
.
get_control_token
(
SpecialTokens
.
tool_calls
)
]
tokens
=
[
self
.
tokenizer
.
id_to_piece
(
id
)
for
id
in
ids
]
tokens
=
[
self
.
tokenizer
.
id_to_piece
(
id
)
for
id
in
ids
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment