Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
314cfade
Unverified
Commit
314cfade
authored
Feb 12, 2025
by
Rafael Vasquez
Committed by
GitHub
Feb 12, 2025
Browse files
[Frontend] Generate valid tool call IDs when using `tokenizer-mode=mistral` (#12332)
parent
985b4a2b
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
149 additions
and
8 deletions
+149
-8
tests/mistral_tool_use/__init__.py
tests/mistral_tool_use/__init__.py
+0
-0
tests/mistral_tool_use/conftest.py
tests/mistral_tool_use/conftest.py
+40
-0
tests/mistral_tool_use/test_mistral_tool_calls.py
tests/mistral_tool_use/test_mistral_tool_calls.py
+29
-0
tests/mistral_tool_use/utils.py
tests/mistral_tool_use/utils.py
+33
-0
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+11
-5
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
+1
-1
vllm/transformers_utils/tokenizers/__init__.py
vllm/transformers_utils/tokenizers/__init__.py
+5
-2
vllm/transformers_utils/tokenizers/mistral.py
vllm/transformers_utils/tokenizers/mistral.py
+30
-0
No files found.
tests/mistral_tool_use/__init__.py
0 → 100644
View file @
314cfade
tests/mistral_tool_use/conftest.py
0 → 100644
View file @
314cfade
# SPDX-License-Identifier: Apache-2.0
import
pytest
import
pytest_asyncio
from
huggingface_hub
import
snapshot_download
from
tests.utils
import
RemoteOpenAIServer
from
vllm.platforms
import
current_platform
from
.utils
import
ARGS
,
CONFIGS
,
ServerConfig
# for each server config, download the model and return the config
@
pytest
.
fixture
(
scope
=
"session"
,
params
=
CONFIGS
.
keys
())
def
server_config
(
request
):
config
=
CONFIGS
[
request
.
param
]
if
current_platform
.
is_rocm
()
and
not
config
.
get
(
"supports_rocm"
,
True
):
pytest
.
skip
(
"The {} model can't be tested on the ROCm platform"
.
format
(
config
[
"model"
]))
# download model and tokenizer using transformers
snapshot_download
(
config
[
"model"
])
yield
CONFIGS
[
request
.
param
]
# run this for each server config
@
pytest
.
fixture
(
scope
=
"session"
)
def
server
(
request
,
server_config
:
ServerConfig
):
model
=
server_config
[
"model"
]
args_for_model
=
server_config
[
"arguments"
]
with
RemoteOpenAIServer
(
model
,
ARGS
+
args_for_model
,
max_wait_seconds
=
480
)
as
server
:
yield
server
@
pytest_asyncio
.
fixture
async
def
client
(
server
:
RemoteOpenAIServer
):
async
with
server
.
get_async_client
()
as
async_client
:
yield
async_client
tests/mistral_tool_use/test_mistral_tool_calls.py
0 → 100644
View file @
314cfade
# SPDX-License-Identifier: Apache-2.0
import
openai
import
pytest
from
tests.tool_use.utils
import
MESSAGES_ASKING_FOR_TOOLS
,
WEATHER_TOOL
# test: a tool_choice with mistral-tokenizer results in an ID of length 9
@
pytest
.
mark
.
asyncio
async
def
test_tool_call_with_tool_choice
(
client
:
openai
.
AsyncOpenAI
):
models
=
await
client
.
models
.
list
()
model_name
:
str
=
models
.
data
[
0
].
id
chat_completion
=
await
client
.
chat
.
completions
.
create
(
messages
=
MESSAGES_ASKING_FOR_TOOLS
,
temperature
=
0
,
max_completion_tokens
=
100
,
model
=
model_name
,
tools
=
[
WEATHER_TOOL
],
tool_choice
=
WEATHER_TOOL
,
logprobs
=
False
)
choice
=
chat_completion
.
choices
[
0
]
assert
choice
.
finish_reason
!=
"tool_calls"
# "stop" or "length"
assert
choice
.
message
.
role
==
"assistant"
assert
choice
.
message
.
tool_calls
is
None
\
or
len
(
choice
.
message
.
tool_calls
)
==
1
assert
len
(
choice
.
message
.
tool_calls
[
0
].
id
)
==
9
# length of 9 for mistral
tests/mistral_tool_use/utils.py
0 → 100644
View file @
314cfade
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Dict
,
List
,
Optional
from
typing_extensions
import
TypedDict
class
ServerConfig
(
TypedDict
,
total
=
False
):
model
:
str
arguments
:
List
[
str
]
system_prompt
:
Optional
[
str
]
supports_parallel
:
Optional
[
bool
]
supports_rocm
:
Optional
[
bool
]
ARGS
:
List
[
str
]
=
[
"--max-model-len"
,
"1024"
]
CONFIGS
:
Dict
[
str
,
ServerConfig
]
=
{
"mistral"
:
{
"model"
:
"mistralai/Mistral-7B-Instruct-v0.3"
,
"arguments"
:
[
"--tokenizer-mode"
,
"mistral"
,
"--ignore-patterns=
\"
consolidated.safetensors
\"
"
],
"system_prompt"
:
"You are a helpful assistant with access to tools. If a tool"
" that you have would be helpful to answer a user query, "
"call the tool. Otherwise, answer the user's query directly "
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
"to the user's question - just respond to it normally."
},
}
vllm/entrypoints/openai/serving_chat.py
View file @
314cfade
...
@@ -28,12 +28,15 @@ from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
...
@@ -28,12 +28,15 @@ from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
from
vllm.entrypoints.openai.serving_engine
import
OpenAIServing
from
vllm.entrypoints.openai.serving_engine
import
OpenAIServing
from
vllm.entrypoints.openai.serving_models
import
OpenAIServingModels
from
vllm.entrypoints.openai.serving_models
import
OpenAIServingModels
from
vllm.entrypoints.openai.tool_parsers
import
ToolParser
,
ToolParserManager
from
vllm.entrypoints.openai.tool_parsers
import
ToolParser
,
ToolParserManager
from
vllm.entrypoints.openai.tool_parsers.mistral_tool_parser
import
(
MistralToolCall
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.sequence
import
Logprob
from
vllm.sequence
import
Logprob
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.transformers_utils.tokenizers
import
maybe_serialize_tool_calls
from
vllm.transformers_utils.tokenizers
import
(
maybe_serialize_tool_calls
,
truncate_tool_call_ids
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -150,11 +153,12 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -150,11 +153,12 @@ class OpenAIServingChat(OpenAIServing):
return
self
.
create_error_response
(
return
self
.
create_error_response
(
"tool_choice =
\"
required
\"
is not supported!"
)
"tool_choice =
\"
required
\"
is not supported!"
)
if
isinstance
(
tokenizer
,
MistralTokenizer
):
# because of issues with pydantic we need to potentially
# because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request
# re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls`
# for more info: see comment in `maybe_serialize_tool_calls`
if
isinstance
(
tokenizer
,
MistralTokenizer
):
maybe_serialize_tool_calls
(
request
)
maybe_serialize_tool_calls
(
request
)
truncate_tool_call_ids
(
request
)
if
(
request
.
tool_choice
==
"auto"
and
if
(
request
.
tool_choice
==
"auto"
and
not
(
self
.
enable_auto_tools
and
tool_parser
is
not
None
)
not
(
self
.
enable_auto_tools
and
tool_parser
is
not
None
)
...
@@ -745,11 +749,13 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -745,11 +749,13 @@ class OpenAIServingChat(OpenAIServing):
elif
request
.
tool_choice
and
type
(
elif
request
.
tool_choice
and
type
(
request
.
tool_choice
)
is
ChatCompletionNamedToolChoiceParam
:
request
.
tool_choice
)
is
ChatCompletionNamedToolChoiceParam
:
tool_call_class
=
MistralToolCall
if
isinstance
(
tokenizer
,
MistralTokenizer
)
else
ToolCall
message
=
ChatMessage
(
message
=
ChatMessage
(
role
=
role
,
role
=
role
,
content
=
""
,
content
=
""
,
tool_calls
=
[
tool_calls
=
[
T
ool
C
all
(
function
=
FunctionCall
(
t
ool
_c
all
_class
(
function
=
FunctionCall
(
name
=
request
.
tool_choice
.
function
.
name
,
name
=
request
.
tool_choice
.
function
.
name
,
arguments
=
output
.
text
))
arguments
=
output
.
text
))
])
])
...
...
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
View file @
314cfade
...
@@ -33,7 +33,7 @@ class MistralToolCall(ToolCall):
...
@@ -33,7 +33,7 @@ class MistralToolCall(ToolCall):
@
staticmethod
@
staticmethod
def
generate_random_id
():
def
generate_random_id
():
# Mistral Tool Call Ids must be alphanumeric with a
maximum
length of 9.
# Mistral Tool Call Ids must be alphanumeric with a length of 9.
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
return
""
.
join
(
choices
(
ALPHANUMERIC
,
k
=
9
))
return
""
.
join
(
choices
(
ALPHANUMERIC
,
k
=
9
))
...
...
vllm/transformers_utils/tokenizers/__init__.py
View file @
314cfade
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
.mistral
import
MistralTokenizer
,
maybe_serialize_tool_calls
from
.mistral
import
(
MistralTokenizer
,
maybe_serialize_tool_calls
,
truncate_tool_call_ids
)
__all__
=
[
"MistralTokenizer"
,
"maybe_serialize_tool_calls"
]
__all__
=
[
"MistralTokenizer"
,
"maybe_serialize_tool_calls"
,
"truncate_tool_call_ids"
]
vllm/transformers_utils/tokenizers/mistral.py
View file @
314cfade
...
@@ -68,6 +68,36 @@ def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
...
@@ -68,6 +68,36 @@ def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
request
.
messages
[
i
][
"tool_calls"
]
=
validated_tool_calls
request
.
messages
[
i
][
"tool_calls"
]
=
validated_tool_calls
def
truncate_tool_call_ids
(
request
:
"ChatCompletionRequest"
):
"""Truncates tool call IDs for Mistral's ID requirements."""
for
i
,
message
in
enumerate
(
request
.
messages
):
if
message
.
get
(
"role"
)
==
'assistant'
:
tool_calls
=
message
.
get
(
"tool_calls"
,
[])
for
tool_call
in
tool_calls
:
if
len
(
tool_call
[
"id"
])
>
9
:
logger
.
warning
(
"Truncating tool call ID: %s to %s"
,
tool_call
[
"id"
],
tool_call
[
"id"
][
-
9
:],
)
tool_call
[
"id"
]
=
tool_call
[
"id"
][
-
9
:]
request
.
messages
[
i
][
"tool_calls"
]
=
tool_calls
elif
message
.
get
(
"role"
)
in
{
"tool_results"
,
"tool"
}:
if
"tool_call_id"
in
message
:
tool_call_id
=
message
[
"tool_call_id"
]
if
len
(
tool_call_id
)
>
9
:
logger
.
warning
(
"Truncating tool_call_id: %s to %s"
,
tool_call_id
,
tool_call_id
[
-
9
:],
)
tool_call_id
=
tool_call_id
[
-
9
:]
request
.
messages
[
i
][
"tool_call_id"
]
=
tool_call_id
def
list_local_repo_files
(
repo_id
:
str
,
revision
:
Optional
[
str
])
->
List
[
str
]:
def
list_local_repo_files
(
repo_id
:
str
,
revision
:
Optional
[
str
])
->
List
[
str
]:
repo_cache
=
os
.
path
.
join
(
repo_cache
=
os
.
path
.
join
(
huggingface_hub
.
constants
.
HF_HUB_CACHE
,
huggingface_hub
.
constants
.
HF_HUB_CACHE
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment