Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
83caf35e
Unverified
Commit
83caf35e
authored
Oct 03, 2024
by
Guillaume Calmettes
Committed by
GitHub
Oct 03, 2024
Browse files
[BugFix] Enforce Mistral ToolCall id constraint when using the Mistral tool call parser (#9020)
parent
01843c89
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
22 additions
and
6 deletions
+22
-6
tests/tool_use/test_parallel_tool_calls.py
tests/tool_use/test_parallel_tool_calls.py
+2
-2
tests/tool_use/test_tool_calls.py
tests/tool_use/test_tool_calls.py
+2
-2
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
+18
-2
No files found.
tests/tool_use/test_parallel_tool_calls.py
View file @
83caf35e
...
@@ -45,7 +45,7 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
...
@@ -45,7 +45,7 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
assert
tool_call
.
type
==
"function"
assert
tool_call
.
type
==
"function"
assert
tool_call
.
function
is
not
None
assert
tool_call
.
function
is
not
None
assert
isinstance
(
tool_call
.
id
,
str
)
assert
isinstance
(
tool_call
.
id
,
str
)
assert
len
(
tool_call
.
id
)
>
16
assert
len
(
tool_call
.
id
)
>
=
9
# make sure the weather tool was called correctly
# make sure the weather tool was called correctly
assert
tool_call
.
function
.
name
==
WEATHER_TOOL
[
"function"
][
"name"
]
assert
tool_call
.
function
.
name
==
WEATHER_TOOL
[
"function"
][
"name"
]
...
@@ -108,7 +108,7 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
...
@@ -108,7 +108,7 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
if
tool_call
.
id
:
if
tool_call
.
id
:
tool_call_id_count
+=
1
tool_call_id_count
+=
1
assert
(
isinstance
(
tool_call
.
id
,
str
)
assert
(
isinstance
(
tool_call
.
id
,
str
)
and
(
len
(
tool_call
.
id
)
>
16
))
and
(
len
(
tool_call
.
id
)
>
=
9
))
# if parts of the function start being streamed
# if parts of the function start being streamed
if
tool_call
.
function
:
if
tool_call
.
function
:
...
...
tests/tool_use/test_tool_calls.py
View file @
83caf35e
...
@@ -33,7 +33,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
...
@@ -33,7 +33,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
assert
tool_calls
[
0
].
type
==
'function'
assert
tool_calls
[
0
].
type
==
'function'
assert
tool_calls
[
0
].
function
is
not
None
assert
tool_calls
[
0
].
function
is
not
None
assert
isinstance
(
tool_calls
[
0
].
id
,
str
)
assert
isinstance
(
tool_calls
[
0
].
id
,
str
)
assert
len
(
tool_calls
[
0
].
id
)
>
16
assert
len
(
tool_calls
[
0
].
id
)
>
=
9
# make sure the weather tool was called (classic example) with arguments
# make sure the weather tool was called (classic example) with arguments
assert
tool_calls
[
0
].
function
.
name
==
WEATHER_TOOL
[
"function"
][
"name"
]
assert
tool_calls
[
0
].
function
.
name
==
WEATHER_TOOL
[
"function"
][
"name"
]
...
@@ -106,7 +106,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
...
@@ -106,7 +106,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
assert
finish_reason_count
==
1
assert
finish_reason_count
==
1
assert
role_name
==
'assistant'
assert
role_name
==
'assistant'
assert
isinstance
(
tool_call_id
,
str
)
and
(
len
(
tool_call_id
)
>
16
)
assert
isinstance
(
tool_call_id
,
str
)
and
(
len
(
tool_call_id
)
>
=
9
)
# validate the name and arguments
# validate the name and arguments
assert
function_name
==
WEATHER_TOOL
[
"function"
][
"name"
]
assert
function_name
==
WEATHER_TOOL
[
"function"
][
"name"
]
...
...
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
View file @
83caf35e
import
json
import
json
import
re
import
re
from
random
import
choices
from
string
import
ascii_letters
,
digits
from
typing
import
Dict
,
List
,
Sequence
,
Union
from
typing
import
Dict
,
List
,
Sequence
,
Union
import
partial_json_parser
import
partial_json_parser
from
partial_json_parser.core.options
import
Allow
from
partial_json_parser.core.options
import
Allow
from
pydantic
import
Field
from
vllm.entrypoints.openai.protocol
import
(
DeltaFunctionCall
,
DeltaMessage
,
from
vllm.entrypoints.openai.protocol
import
(
DeltaFunctionCall
,
DeltaMessage
,
DeltaToolCall
,
DeltaToolCall
,
...
@@ -19,6 +22,19 @@ from vllm.utils import random_uuid
...
@@ -19,6 +22,19 @@ from vllm.utils import random_uuid
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
ALPHANUMERIC
=
ascii_letters
+
digits
class
MistralToolCall
(
ToolCall
):
id
:
str
=
Field
(
default_factory
=
lambda
:
MistralToolCall
.
generate_random_id
())
@
staticmethod
def
generate_random_id
():
# Mistral Tool Call Ids must be alphanumeric with a maximum length of 9.
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
return
""
.
join
(
choices
(
ALPHANUMERIC
,
k
=
9
))
class
MistralToolParser
(
ToolParser
):
class
MistralToolParser
(
ToolParser
):
"""
"""
...
@@ -71,8 +87,8 @@ class MistralToolParser(ToolParser):
...
@@ -71,8 +87,8 @@ class MistralToolParser(ToolParser):
# load the JSON, and then use it to build the Function and
# load the JSON, and then use it to build the Function and
# Tool Call
# Tool Call
function_call_arr
=
json
.
loads
(
raw_tool_call
)
function_call_arr
=
json
.
loads
(
raw_tool_call
)
tool_calls
:
List
[
ToolCall
]
=
[
tool_calls
:
List
[
Mistral
ToolCall
]
=
[
ToolCall
(
Mistral
ToolCall
(
type
=
"function"
,
type
=
"function"
,
function
=
FunctionCall
(
function
=
FunctionCall
(
name
=
raw_function_call
[
"name"
],
name
=
raw_function_call
[
"name"
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment