Commit 0da93439 authored by zhuwenwen
Browse files

Merge tag 'v0.18.1rc0' into v0.18.1rc0-ori

parents 25f2f756 298e5108
...@@ -31,7 +31,7 @@ class TestAiterMlaFp8SupportCheck: ...@@ -31,7 +31,7 @@ class TestAiterMlaFp8SupportCheck:
# Should return False without raising # Should return False without raising
with patch( with patch(
"vllm._aiter_ops.inspect.signature", "inspect.signature",
side_effect=ImportError("No module"), side_effect=ImportError("No module"),
): ):
result = _check_aiter_mla_fp8_support() result = _check_aiter_mla_fp8_support()
...@@ -46,7 +46,7 @@ class TestAiterMlaFp8SupportCheck: ...@@ -46,7 +46,7 @@ class TestAiterMlaFp8SupportCheck:
aiter_ops._AITER_MLA_SUPPORTS_FP8 = None aiter_ops._AITER_MLA_SUPPORTS_FP8 = None
with patch( with patch(
"vllm._aiter_ops.inspect.signature", "inspect.signature",
side_effect=ModuleNotFoundError("Module not found"), side_effect=ModuleNotFoundError("Module not found"),
): ):
# Should return False without raising # Should return False without raising
...@@ -63,7 +63,7 @@ class TestAiterMlaFp8SupportCheck: ...@@ -63,7 +63,7 @@ class TestAiterMlaFp8SupportCheck:
aiter_ops._AITER_MLA_SUPPORTS_FP8 = None aiter_ops._AITER_MLA_SUPPORTS_FP8 = None
with patch( with patch(
"vllm._aiter_ops.inspect.signature", "inspect.signature",
side_effect=AttributeError("No attribute"), side_effect=AttributeError("No attribute"),
): ):
assert _check_aiter_mla_fp8_support() is False assert _check_aiter_mla_fp8_support() is False
...@@ -78,7 +78,7 @@ class TestAiterMlaFp8SupportCheck: ...@@ -78,7 +78,7 @@ class TestAiterMlaFp8SupportCheck:
aiter_ops._AITER_MLA_SUPPORTS_FP8 = None aiter_ops._AITER_MLA_SUPPORTS_FP8 = None
with patch( with patch(
"vllm._aiter_ops.inspect.signature", "inspect.signature",
side_effect=ValueError("No signature"), side_effect=ValueError("No signature"),
): ):
assert _check_aiter_mla_fp8_support() is False assert _check_aiter_mla_fp8_support() is False
...@@ -93,7 +93,7 @@ class TestAiterMlaFp8SupportCheck: ...@@ -93,7 +93,7 @@ class TestAiterMlaFp8SupportCheck:
aiter_ops._AITER_MLA_SUPPORTS_FP8 = None aiter_ops._AITER_MLA_SUPPORTS_FP8 = None
with patch( with patch(
"vllm._aiter_ops.inspect.signature", "inspect.signature",
side_effect=TypeError("Not a callable"), side_effect=TypeError("Not a callable"),
): ):
assert _check_aiter_mla_fp8_support() is False assert _check_aiter_mla_fp8_support() is False
......
...@@ -74,7 +74,7 @@ def test_embed_dimensions(model_info: EmbedModelInfo): ...@@ -74,7 +74,7 @@ def test_embed_dimensions(model_info: EmbedModelInfo):
pooling_params.verify(model_config) pooling_params.verify(model_config)
@pytest.mark.parametrize("task", ["score", "classify"]) @pytest.mark.parametrize("task", ["classify"])
def test_classify(task): def test_classify(task):
model_config = MockModelConfig(pooler_config=PoolerConfig(seq_pooling_type="CLS")) model_config = MockModelConfig(pooler_config=PoolerConfig(seq_pooling_type="CLS"))
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from dataclasses import dataclass, field
from types import NoneType
from typing import Any
import pytest
from tests.tool_parsers.utils import run_tool_extraction
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParserManager
@dataclass
class ToolParserTestConfig:
"""Configuration for a tool parser's common tests.
This dataclass contains all the test data and expected results needed
to run the common test suite for a parser. Each parser test file
creates one instance of this config with parser-specific values.
Attributes:
parser_name: Name used with ToolParserManager (e.g., "mistral")
Test data (model outputs):
no_tool_calls_output: Plain text without any tool syntax
single_tool_call_output: One tool call with simple arguments
parallel_tool_calls_output: Multiple tool calls in one response
various_data_types_output: Tool with various data types
empty_arguments_output: Tool call with no parameters
surrounding_text_output: Tool call mixed with regular text
escaped_strings_output: Tool call with escaped chars
malformed_input_outputs: List of invalid inputs
Expected results:
single_tool_call_expected_name: Expected function name
single_tool_call_expected_args: Expected arguments dict
parallel_tool_calls_count: Number of tools in parallel test
parallel_tool_calls_names: Function names in order
single_tool_call_expected_content: Content field when tool called
parallel_tool_calls_expected_content: Content for parallel test
xfail markers:
xfail_streaming: Mapping test name to xfail reason (streaming only)
xfail_nonstreaming: Mapping test name to xfail reason (non-streaming)
Special flags:
allow_empty_or_json_empty_args: True if "" or "{}" both valid for empty args
supports_typed_arguments: True if the parser supports typed function arguments
"""
# Parser identification
parser_name: str
# Test data - model outputs for each common test
no_tool_calls_output: str
single_tool_call_output: str
parallel_tool_calls_output: str
various_data_types_output: str
empty_arguments_output: str
surrounding_text_output: str
escaped_strings_output: str
malformed_input_outputs: list[str]
# Expected results for specific tests (optional overrides)
single_tool_call_expected_name: str = "get_weather"
single_tool_call_expected_args: dict[str, Any] = field(
default_factory=lambda: {"city": "Tokyo"}
)
parallel_tool_calls_count: int = 2
parallel_tool_calls_names: list[str] = field(
default_factory=lambda: ["get_weather", "get_time"]
)
# xfail configuration - maps test name to xfail reason
xfail_streaming: dict[str, str] = field(default_factory=dict)
xfail_nonstreaming: dict[str, str] = field(default_factory=dict)
# Content expectations (some parsers strip content, others don't)
single_tool_call_expected_content: str | None = None
parallel_tool_calls_expected_content: str | None = None
# Special assertions for edge cases
allow_empty_or_json_empty_args: bool = True # "{}" or "" for empty args
supports_typed_arguments: bool = True
class ToolParserTests:
    """Reusable pytest mixin holding the tests every tool parser should pass.

    How to adopt it in a parser-specific test module:
        1. Subclass this mixin.
        2. Override the ``test_config`` fixture so that it returns a
           ``ToolParserTestConfig`` describing the parser under test.
        3. Add any parser-specific tests as extra methods.

    Example:
        class TestMistralToolParser(ToolParserTests):
            @pytest.fixture
            def test_config(self) -> ToolParserTestConfig:
                return ToolParserTestConfig(
                    parser_name="mistral",
                    no_tool_calls_output="Plain text...",
                    # ... other config ...
                )

            # Parser-specific tests
            def test_mistral_specific_feature(self, tool_parser):
                # Custom test logic
                pass
    """

    # ------------------------------------------------------------------
    # Fixtures
    # ------------------------------------------------------------------

    @pytest.fixture
    def test_config(self) -> ToolParserTestConfig:
        """Placeholder; subclasses must override it with their own config."""
        raise NotImplementedError(
            "Subclass must provide test_config fixture returning ToolParserTestConfig"
        )

    @pytest.fixture
    def tokenizer(self, default_tokenizer: TokenizerLike) -> TokenizerLike:
        """Tokenizer handed to the parser; override for a parser-specific one."""
        return default_tokenizer

    @pytest.fixture
    def tool_parser(self, test_config: ToolParserTestConfig, tokenizer: TokenizerLike):
        """Instantiate the parser registered under ``test_config.parser_name``."""
        parser_cls = ToolParserManager.get_tool_parser(test_config.parser_name)
        return parser_cls(tokenizer)

    @pytest.fixture(params=[True, False])
    def streaming(self, request: pytest.FixtureRequest) -> bool:
        """Run each dependent test once per mode: streaming and non-streaming."""
        return request.param

    # ------------------------------------------------------------------
    # Common tests
    # ------------------------------------------------------------------

    def test_no_tool_calls(
        self,
        request: pytest.FixtureRequest,
        tool_parser: Any,
        test_config: ToolParserTestConfig,
        streaming: bool,
    ):
        """Plain text without tool syntax passes through as content."""
        self.apply_xfail_mark(request, test_config, "test_no_tool_calls", streaming)

        plain_text = test_config.no_tool_calls_output
        text, calls = run_tool_extraction(tool_parser, plain_text, streaming=streaming)

        assert text == plain_text, f"Expected content to match input, got {text}"
        assert len(calls) == 0, f"Expected no tool calls, got {len(calls)}"

    def test_single_tool_call_simple_args(
        self,
        request: pytest.FixtureRequest,
        tool_parser: Any,
        test_config: ToolParserTestConfig,
        streaming: bool,
    ):
        """A single call with simple arguments is extracted correctly."""
        self.apply_xfail_mark(
            request, test_config, "test_single_tool_call_simple_args", streaming
        )

        text, calls = run_tool_extraction(
            tool_parser, test_config.single_tool_call_output, streaming=streaming
        )

        # Content is compared only when the config states an expectation,
        # because some parsers strip it.
        expected_content = test_config.single_tool_call_expected_content
        if expected_content is not None:
            assert text == expected_content

        assert len(calls) == 1, f"Expected 1 tool call, got {len(calls)}"
        call = calls[0]
        assert call.type == "function"
        assert call.function.name == test_config.single_tool_call_expected_name

        parsed_args = json.loads(call.function.arguments)
        for name, expected in test_config.single_tool_call_expected_args.items():
            actual = parsed_args.get(name)
            assert actual == expected, f"Expected {name}={expected}, got {actual}"

    def test_parallel_tool_calls(
        self,
        request: pytest.FixtureRequest,
        tool_parser: Any,
        test_config: ToolParserTestConfig,
        streaming: bool,
    ):
        """Several calls in one response are all extracted, in order."""
        self.apply_xfail_mark(
            request, test_config, "test_parallel_tool_calls", streaming
        )

        _, calls = run_tool_extraction(
            tool_parser,
            test_config.parallel_tool_calls_output,
            streaming=streaming,
        )

        expected_count = test_config.parallel_tool_calls_count
        assert len(calls) == expected_count, (
            f"Expected {expected_count} "
            f"tool calls, got {len(calls)}"
        )

        # Names must line up positionally with the configured list.
        for position, expected_name in enumerate(test_config.parallel_tool_calls_names):
            call = calls[position]
            assert call.type == "function"
            assert call.function.name == expected_name

        # No two calls may share an id.
        seen_ids = [call.id for call in calls]
        assert len(seen_ids) == len(set(seen_ids)), "Tool call IDs should be unique"

    def test_various_data_types(
        self,
        request: pytest.FixtureRequest,
        tool_parser: Any,
        test_config: ToolParserTestConfig,
        streaming: bool,
    ):
        """Arguments covering every JSON type are parsed."""
        self.apply_xfail_mark(
            request, test_config, "test_various_data_types", streaming
        )

        _, calls = run_tool_extraction(
            tool_parser,
            test_config.various_data_types_output,
            streaming=streaming,
        )
        assert len(calls) == 1, f"Expected 1 tool call, got {len(calls)}"
        parsed_args = json.loads(calls[0].function.arguments)

        # (field name, Python type produced by json.loads for that field)
        expectations = (
            ("string_field", str),
            ("int_field", int),
            ("float_field", float),
            ("bool_field", bool),
            ("null_field", NoneType),
            ("array_field", list),
            ("object_field", dict),
        )
        for required_field, expected_type in expectations:
            assert required_field in parsed_args, (
                f"Expected field '{required_field}' in arguments"
            )
            # Type identity is checked only for parsers that keep types.
            if not test_config.supports_typed_arguments:
                continue
            found_type = type(parsed_args[required_field])
            assert found_type is expected_type, (
                f"Expected field '{required_field}' to have type {expected_type}, "
                f"got {found_type}"
            )

    def test_empty_arguments(
        self,
        request: pytest.FixtureRequest,
        tool_parser: Any,
        test_config: ToolParserTestConfig,
        streaming: bool,
    ):
        """A call without parameters yields empty arguments."""
        self.apply_xfail_mark(request, test_config, "test_empty_arguments", streaming)

        _, calls = run_tool_extraction(
            tool_parser, test_config.empty_arguments_output, streaming=streaming
        )
        assert len(calls) == 1, f"Expected 1 tool call, got {len(calls)}"

        raw_args = calls[0].function.arguments
        if not test_config.allow_empty_or_json_empty_args:
            assert raw_args == "{}", f"Expected {{}}, got {raw_args}"
        else:
            assert raw_args in ("{}", ""), f"Expected empty args, got {raw_args}"

    def test_surrounding_text(
        self,
        request: pytest.FixtureRequest,
        tool_parser: Any,
        test_config: ToolParserTestConfig,
        streaming: bool,
    ):
        """Calls embedded in regular text are still found."""
        self.apply_xfail_mark(request, test_config, "test_surrounding_text", streaming)

        _, calls = run_tool_extraction(
            tool_parser, test_config.surrounding_text_output, streaming=streaming
        )
        assert len(calls) >= 1, f"Expected at least 1 tool call, got {len(calls)}"

    def test_escaped_strings(
        self,
        request: pytest.FixtureRequest,
        tool_parser: Any,
        test_config: ToolParserTestConfig,
        streaming: bool,
    ):
        """Arguments containing escaped characters remain parseable."""
        self.apply_xfail_mark(request, test_config, "test_escaped_strings", streaming)

        _, calls = run_tool_extraction(
            tool_parser, test_config.escaped_strings_output, streaming=streaming
        )
        assert len(calls) == 1, f"Expected 1 tool call, got {len(calls)}"

        # Exact escaping differs between parsers, so only require that the
        # arguments decode as JSON and are non-empty.
        parsed_args = json.loads(calls[0].function.arguments)
        assert len(parsed_args) > 0, "Expected some arguments with escaped strings"

    def test_malformed_input(
        self,
        request: pytest.FixtureRequest,
        tool_parser: Any,
        test_config: ToolParserTestConfig,
        streaming: bool,
    ):
        """Invalid syntax must not make the extraction raise."""
        self.apply_xfail_mark(request, test_config, "test_malformed_input", streaming)

        for broken_output in test_config.malformed_input_outputs:
            # Completing without an exception is the whole check; what the
            # parser returns for bad input varies by parser.
            run_tool_extraction(tool_parser, broken_output, streaming=streaming)

    def test_streaming_reconstruction(
        self,
        request: pytest.FixtureRequest,
        tool_parser: Any,
        test_config: ToolParserTestConfig,
    ):
        """Streaming and non-streaming extraction agree on the single call."""
        # This test always consults the streaming xfail table.
        self.apply_xfail_mark(
            request, test_config, "test_streaming_reconstruction", True
        )

        model_output = test_config.single_tool_call_output
        batch_content, batch_calls = run_tool_extraction(
            tool_parser, model_output, streaming=False
        )
        stream_content, stream_calls = run_tool_extraction(
            tool_parser, model_output, streaming=True
        )

        assert batch_content == stream_content, "Content should match between modes"
        assert len(batch_calls) == len(stream_calls), "Tool count should match"
        if len(batch_calls) == 0:
            return

        batch_fn = batch_calls[0].function
        stream_fn = stream_calls[0].function
        assert batch_fn.name == stream_fn.name
        assert batch_fn.arguments == stream_fn.arguments

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def apply_xfail_mark(self, request, test_config, test_name, streaming):
        """Attach a strict xfail marker when the config lists ``test_name``.

        The table consulted is ``xfail_streaming`` when ``streaming`` is
        truthy, otherwise ``xfail_nonstreaming``. Nothing happens when the
        test name is not listed there.
        """
        if streaming:
            reasons = test_config.xfail_streaming
        else:
            reasons = test_config.xfail_nonstreaming

        reason = reasons[test_name] if test_name in reasons else None
        if reason is None:
            return
        request.node.add_marker(pytest.mark.xfail(reason=reason, strict=True))
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from transformers import AutoTokenizer
from vllm.tokenizers import TokenizerLike
@pytest.fixture(scope="module")
def default_tokenizer() -> TokenizerLike:
    """Module-scoped fixture returning the "gpt2" tokenizer from AutoTokenizer."""
    gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    return gpt2_tokenizer
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for DeepSeekV32ToolParser.
These tests use a minimal mock tokenizer so no real model weights are required.
"""
import json
from unittest.mock import MagicMock
import pytest
from vllm.tool_parsers.deepseekv32_tool_parser import DeepSeekV32ToolParser
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
# The V32 parser logic does not use token IDs; per the original note it only
# checks `if not self.model_tokenizer`, so a truthy stand-in object whose
# get_vocab() returns an empty dict is all these tests need.
MOCK_TOKENIZER = MagicMock(**{"get_vocab.return_value": {}})
def make_parser() -> DeepSeekV32ToolParser:
    """Return a new DeepSeekV32ToolParser built on the shared mock tokenizer."""
    parser = DeepSeekV32ToolParser(MOCK_TOKENIZER)
    return parser
def make_tool_param(name: str, params: dict) -> MagicMock:
    """Return a mock tool exposing ``function.name`` and ``function.parameters``.

    The mock imitates the ChatCompletionToolsParam shape just far enough for
    the tests: only those two attributes are populated.
    """
    mock_tool = MagicMock()
    # configure_mock assigns plain attributes, so ``name`` becomes a regular
    # attribute value rather than the mock's own display name.
    mock_tool.function.configure_mock(name=name, parameters=params)
    return mock_tool
def make_request(tools=None) -> MagicMock:
    """Return a mock request whose ``tools`` attribute is the given value."""
    return MagicMock(tools=tools)
# DSML markup fragments shared by every test below. Each one is assembled
# from the common marker so the repeated prefix is spelled only once.
_DSML = "|DSML|"
FC_START = f"<{_DSML}function_calls>"
FC_END = f"</{_DSML}function_calls>"
INV_START = f'<{_DSML}invoke name="'
INV_END = f"</{_DSML}invoke>"
PARAM_START = f'<{_DSML}parameter name="'
PARAM_END = f"</{_DSML}parameter>"
def build_tool_call(func_name: str, params: dict[str, str]) -> str:
    """Render one complete function-calls block invoking ``func_name``.

    Every entry of ``params`` becomes a parameter element flagged
    ``string="true"``; the elements are concatenated without separators and
    placed on a single line between the invoke start and end markers.
    """
    rendered_params = []
    for key, value in params.items():
        rendered_params.append(f'{PARAM_START}{key}" string="true">{value}{PARAM_END}')

    lines = [
        FC_START,
        f'{INV_START}{func_name}">',
        "".join(rendered_params),
        INV_END,
        FC_END,
    ]
    return "\n".join(lines)
# ---------------------------------------------------------------------------
# Tests: DeepSeekV32ToolParser._convert_param_value
# ---------------------------------------------------------------------------
class TestConvertParamValue:
    """Checks for ``_convert_param_value`` across the declared parameter types."""

    @pytest.fixture
    def parser(self):
        """Fresh parser for every test."""
        return make_parser()

    def test_null(self, parser):
        # "null" maps to None in any letter case, whatever the declared type.
        convert = parser._convert_param_value
        assert convert("null", "string") is None
        assert convert("NULL", "integer") is None

    def test_string(self, parser):
        converted = parser._convert_param_value("hello", "string")
        assert converted == "hello"

    def test_integer_valid(self, parser):
        converted = parser._convert_param_value("42", "integer")
        assert converted == 42

    def test_integer_invalid_falls_back_to_str(self, parser):
        converted = parser._convert_param_value("abc", "int")
        assert converted == "abc"

    def test_number_float(self, parser):
        converted = parser._convert_param_value("3.14", "number")
        assert converted == pytest.approx(3.14)

    def test_number_whole_returns_int(self, parser):
        # A whole-valued "number" is expected to come back as an int.
        convert = parser._convert_param_value
        assert convert("5.0", "number") == 5
        assert isinstance(convert("5.0", "number"), int)

    def test_boolean_true(self, parser):
        convert = parser._convert_param_value
        assert convert("true", "boolean") is True
        assert convert("1", "bool") is True

    def test_boolean_false(self, parser):
        convert = parser._convert_param_value
        assert convert("false", "boolean") is False
        assert convert("False", "bool") is False

    def test_object_valid_json(self, parser):
        converted = parser._convert_param_value('{"k": 1}', "object")
        assert converted == {"k": 1}

    def test_object_invalid_json_falls_back(self, parser):
        converted = parser._convert_param_value("not-json", "object")
        assert converted == "not-json"

    def test_array_valid_json(self, parser):
        converted = parser._convert_param_value("[1, 2]", "array")
        assert converted == [1, 2]

    def test_unknown_type_tries_json_then_string(self, parser):
        convert = parser._convert_param_value
        assert convert("123", "unknown") == 123
        assert convert("hello", "unknown") == "hello"
# ---------------------------------------------------------------------------
# Tests: extract_tool_calls (non-streaming)
# ---------------------------------------------------------------------------
class TestExtractToolCalls:
    """Checks for the non-streaming ``extract_tool_calls`` entry point."""

    @pytest.fixture
    def parser(self):
        """Fresh parser for every test."""
        return make_parser()

    def test_no_tool_call(self, parser):
        plain = "just some text"
        outcome = parser.extract_tool_calls(plain, None)
        assert not outcome.tools_called
        assert outcome.tool_calls == []
        assert outcome.content == "just some text"

    def test_single_tool_no_params(self, parser):
        text = "\n".join([FC_START, f'{INV_START}get_time">', INV_END, FC_END])
        outcome = parser.extract_tool_calls(text, None)
        assert outcome.tools_called
        assert len(outcome.tool_calls) == 1
        only_call = outcome.tool_calls[0]
        assert only_call.function.name == "get_time"
        assert json.loads(only_call.function.arguments) == {}

    def test_single_tool_with_params(self, parser):
        text = build_tool_call("get_weather", {"location": "SF", "date": "2024-01-16"})
        outcome = parser.extract_tool_calls(text, None)
        assert outcome.tools_called
        assert len(outcome.tool_calls) == 1
        fn = outcome.tool_calls[0].function
        assert fn.name == "get_weather"
        expected_args = {"location": "SF", "date": "2024-01-16"}
        assert json.loads(fn.arguments) == expected_args

    def test_content_before_tool_call(self, parser):
        prefix = "Sure, let me check! "
        text = prefix + build_tool_call("get_weather", {"location": "NYC"})
        outcome = parser.extract_tool_calls(text, None)
        assert outcome.tools_called
        assert outcome.content == "Sure, let me check! "

    def test_no_content_prefix_returns_none(self, parser):
        text = build_tool_call("get_weather", {"location": "NYC"})
        outcome = parser.extract_tool_calls(text, None)
        assert outcome.tools_called
        assert outcome.content is None

    def test_multiple_tools(self, parser):
        # Two invocations of the same function inside one function-calls block.
        text = "\n".join(
            [
                FC_START,
                f'{INV_START}get_weather">',
                f'{PARAM_START}location" string="true">SF{PARAM_END}',
                INV_END,
                f'{INV_START}get_weather">',
                f'{PARAM_START}location" string="true">NYC{PARAM_END}',
                INV_END,
                FC_END,
            ]
        )
        outcome = parser.extract_tool_calls(text, None)
        assert outcome.tools_called
        assert len(outcome.tool_calls) == 2
        first_call, second_call = outcome.tool_calls[0], outcome.tool_calls[1]
        assert json.loads(first_call.function.arguments) == {"location": "SF"}
        assert json.loads(second_call.function.arguments) == {"location": "NYC"}
# ---------------------------------------------------------------------------
# Tests: extract_tool_calls_streaming
# ---------------------------------------------------------------------------
class TestExtractToolCallsStreaming:
    """Feed text to ``extract_tool_calls_streaming`` in chunks and check deltas.

    Two chunking strategies are exercised:

    * ``_stream`` delivers one line per call, so every marker string arrives
      whole inside a single delta.
    * ``_stream_chunked`` delivers fixed-size slices, so marker strings can be
      split across consecutive deltas.

    Both strategies share ``_drive``, which performs the actual calls.
    """

    @pytest.fixture
    def parser(self):
        """Fresh parser for every test."""
        return make_parser()

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _drive(self, parser, chunks, request=None):
        """Send ``chunks`` to the parser in order and return the non-None deltas.

        Args:
            parser: Parser whose ``extract_tool_calls_streaming`` is called.
            chunks: Text pieces to deliver, one per call.
            request: Request object forwarded on every call; a mock request
                without tools is created when omitted.

        Returns:
            The results the parser returned that were not ``None``, in order.
        """
        if request is None:
            request = make_request()
        deltas = []
        prev = ""
        for chunk in chunks:
            curr = prev + chunk
            result = parser.extract_tool_calls_streaming(
                previous_text=prev,
                current_text=curr,
                delta_text=chunk,
                previous_token_ids=[],
                current_token_ids=[],
                # Token ids are placeholders; one dummy id per chunk.
                delta_token_ids=[1],
                request=request,
            )
            prev = curr
            if result is not None:
                deltas.append(result)
        return deltas

    def _stream(self, parser, full_text: str, request=None):
        """Drive the parser line-by-line and collect non-None deltas.

        Real tokenizers emit multi-character chunks, not individual characters.
        Streaming character-by-character would never deliver the full sentinel
        token (e.g. '|DSML|') in a single delta, so we split on newlines to
        ensure each sentinel always lands in one chunk.
        """
        # Split on "\n" only, keeping the newline at the end of each chunk.
        # A trailing piece without a newline is kept if it is non-empty.
        pieces = full_text.split("\n")
        chunks: list[str] = [piece + "\n" for piece in pieces[:-1]]
        if pieces[-1]:
            chunks.append(pieces[-1])
        return self._drive(parser, chunks, request)

    def _stream_chunked(self, parser, full_text: str, chunk_size: int, request=None):
        """Drive the parser with fixed-size chunks (simulates stream interval).

        Unlike ``_stream`` which splits on newlines, this splits the text
        into ``chunk_size``-character pieces so the start token can be
        split across chunks — exactly what happens with stream interval > 1.
        """
        chunks = [
            full_text[i : i + chunk_size] for i in range(0, len(full_text), chunk_size)
        ]
        return self._drive(parser, chunks, request)

    def _reconstruct_args(self, deltas, tool_index=0) -> str:
        """Concatenate all argument fragments for a given tool index."""
        fragments = []
        for d in deltas:
            if d.tool_calls:
                for tc in d.tool_calls:
                    if tc.index == tool_index and tc.function and tc.function.arguments:
                        fragments.append(tc.function.arguments)
        return "".join(fragments)

    def _function_names(self, deltas) -> list:
        """Return every non-empty function name found in ``deltas``, in order."""
        return [
            tc.function.name
            for d in deltas
            if d.tool_calls
            for tc in d.tool_calls
            if tc.function and tc.function.name
        ]

    # ------------------------------------------------------------------
    # Line-by-line streaming
    # ------------------------------------------------------------------

    def test_plain_content_no_tool(self, parser):
        full_text = "Hello, world!"
        deltas = self._stream(parser, full_text)
        content = "".join(d.content for d in deltas if d.content is not None)
        assert "Hello, world!" in content
        assert all(not d.tool_calls for d in deltas)

    def test_single_tool_streaming(self, parser):
        full_text = build_tool_call("get_weather", {"location": "SF"})
        deltas = self._stream(parser, full_text)
        args_str = self._reconstruct_args(deltas)
        assert json.loads(args_str) == {"location": "SF"}

    def test_tool_name_emitted(self, parser):
        full_text = build_tool_call("my_func", {"x": "1"})
        deltas = self._stream(parser, full_text)
        func_names = self._function_names(deltas)
        assert any("my_func" in n for n in func_names)

    def test_content_before_tool_call_streaming(self, parser):
        full_text = "Thinking... " + build_tool_call("fn", {"a": "b"})
        deltas = self._stream(parser, full_text)
        content = "".join(d.content for d in deltas if d.content is not None)
        assert "Thinking" in content

    def test_type_conversion_in_streaming(self, parser):
        # The tool schema declares integers, so "3"/"4" must come back as ints.
        tool = make_tool_param(
            "add",
            {
                "type": "object",
                "properties": {
                    "x": {"type": "integer"},
                    "y": {"type": "integer"},
                },
            },
        )
        request = make_request(tools=[tool])
        full_text = build_tool_call("add", {"x": "3", "y": "4"})
        deltas = self._stream(parser, full_text, request=request)
        args_str = self._reconstruct_args(deltas)
        assert json.loads(args_str) == {"x": 3, "y": 4}

    def test_multiple_tools_streaming(self, parser):
        full_text = (
            f"{FC_START}\n"
            f'{INV_START}func_a">\n'
            f'{PARAM_START}p" string="true">v1{PARAM_END}\n'
            f"{INV_END}\n"
            f'{INV_START}func_b">\n'
            f'{PARAM_START}q" string="true">v2{PARAM_END}\n'
            f"{INV_END}\n"
            f"{FC_END}"
        )
        deltas = self._stream(parser, full_text)

        # Collect function names by index
        names_by_index: dict[int, str] = {}
        for d in deltas:
            if d.tool_calls:
                for tc in d.tool_calls:
                    if tc.function and tc.function.name:
                        names_by_index[tc.index] = tc.function.name

        assert names_by_index.get(0) == "func_a"
        assert names_by_index.get(1) == "func_b"
        assert json.loads(self._reconstruct_args(deltas, tool_index=0)) == {"p": "v1"}
        assert json.loads(self._reconstruct_args(deltas, tool_index=1)) == {"q": "v2"}

    def test_state_reset_on_new_stream(self, parser):
        """A second stream (previous_text == '') must reset state cleanly."""
        full_text = build_tool_call("fn", {"k": "v"})
        # First stream
        self._stream(parser, full_text)
        # Second stream - should produce identical results
        deltas2 = self._stream(parser, full_text)
        assert json.loads(self._reconstruct_args(deltas2)) == {"k": "v"}

    def test_empty_arguments_streaming(self, parser):
        """Invoke block with zero parameters should produce empty JSON."""
        full_text = f'{FC_START}\n{INV_START}get_time">\n{INV_END}\n{FC_END}'
        deltas = self._stream(parser, full_text)
        args_str = self._reconstruct_args(deltas)
        assert json.loads(args_str) == {}

    def test_unique_tool_call_ids(self, parser):
        """Each tool call in a parallel stream must get a distinct id."""
        full_text = (
            f"{FC_START}\n"
            f'{INV_START}fn_a">\n'
            f'{PARAM_START}x" string="true">1{PARAM_END}\n'
            f"{INV_END}\n"
            f'{INV_START}fn_b">\n'
            f'{PARAM_START}y" string="true">2{PARAM_END}\n'
            f"{INV_END}\n"
            f"{FC_END}"
        )
        deltas = self._stream(parser, full_text)
        ids = [
            tc.id
            for d in deltas
            if d.tool_calls
            for tc in d.tool_calls
            if tc.id is not None
        ]
        assert len(ids) == 2
        assert ids[0] != ids[1]

    def test_eos_after_tool_calls(self, parser):
        """EOS token (empty delta_text, non-empty delta_token_ids) returns
        a non-None DeltaMessage so the serving framework can finalize."""
        full_text = build_tool_call("fn", {"k": "v"})
        # Drive through the full text first
        deltas = self._stream(parser, full_text)
        assert any(d.tool_calls for d in deltas)
        # Now simulate EOS: empty delta_text, but token ids present
        prev = full_text
        result = parser.extract_tool_calls_streaming(
            previous_text=prev,
            current_text=prev,
            delta_text="",
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[2],  # EOS token id
            request=make_request(),
        )
        assert result is not None

    def test_streaming_matches_non_streaming(self, parser):
        """Streaming and non-streaming must produce the same result."""
        full_text = build_tool_call(
            "get_weather", {"location": "SF", "date": "2024-01-16"}
        )
        # Non-streaming
        non_stream = parser.extract_tool_calls(full_text, None)
        assert non_stream.tools_called
        ns_name = non_stream.tool_calls[0].function.name
        ns_args = json.loads(non_stream.tool_calls[0].function.arguments)

        # Streaming
        deltas = self._stream(parser, full_text)
        s_names = self._function_names(deltas)
        # Fail with a readable message instead of an IndexError below.
        assert s_names, "Streaming emitted no function name"
        s_args = json.loads(self._reconstruct_args(deltas))

        assert s_names[0] == ns_name
        assert s_args == ns_args

    # ------------------------------------------------------------------
    # Fixed-size chunk streaming
    # ------------------------------------------------------------------

    def test_single_tool_chunked_stream_interval(self, parser):
        """Start token split across chunks (stream interval > 1)."""
        full_text = build_tool_call("get_weather", {"location": "SF"})
        # Use a chunk size that splits the start token
        deltas = self._stream_chunked(parser, full_text, chunk_size=5)
        args_str = self._reconstruct_args(deltas)
        assert json.loads(args_str) == {"location": "SF"}

    def test_content_before_tool_chunked(self, parser):
        """Content before tool call with chunked streaming."""
        full_text = "Thinking... " + build_tool_call("fn", {"a": "b"})
        deltas = self._stream_chunked(parser, full_text, chunk_size=7)
        content = "".join(d.content for d in deltas if d.content is not None)
        assert "Thinking" in content
        args_str = self._reconstruct_args(deltas)
        assert json.loads(args_str) == {"a": "b"}

    def test_multiple_tools_chunked(self, parser):
        """Multiple tools with chunked streaming."""
        full_text = (
            f"{FC_START}\n"
            f'{INV_START}func_a">\n'
            f'{PARAM_START}p" string="true">v1{PARAM_END}\n'
            f"{INV_END}\n"
            f'{INV_START}func_b">\n'
            f'{PARAM_START}q" string="true">v2{PARAM_END}\n'
            f"{INV_END}\n"
            f"{FC_END}"
        )
        deltas = self._stream_chunked(parser, full_text, chunk_size=10)
        assert json.loads(self._reconstruct_args(deltas, tool_index=0)) == {"p": "v1"}
        assert json.loads(self._reconstruct_args(deltas, tool_index=1)) == {"q": "v2"}

    def test_no_emission_while_incomplete(self, parser):
        """No tool calls should be emitted until an invoke block completes."""
        # Stream only a partial invoke (no closing tag)
        partial_text = (
            f"{FC_START}\n"
            f'{INV_START}fn">\n'
            f'{PARAM_START}k" string="true">val{PARAM_END}\n'
        )
        deltas = self._stream(parser, partial_text)
        # Should have no tool call deltas yet
        assert all(not d.tool_calls for d in deltas)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from tests.tool_parsers.common_tests import (
ToolParserTestConfig,
ToolParserTests,
)
from vllm.tokenizers import TokenizerLike, get_tokenizer
class TestDeepSeekV3ToolParser(ToolParserTests):
    """Fixtures that configure ``ToolParserTests`` for the ``deepseek_v3`` parser.

    This class defines no test methods of its own; it only provides the
    ``tokenizer`` and ``test_config`` fixtures.
    NOTE(review): the tests themselves presumably come from
    ``ToolParserTests`` - confirm in ``tests.tool_parsers.common_tests``.
    """

    @pytest.fixture(scope="class")
    def tokenizer(self) -> TokenizerLike:
        """Return the tokenizer loaded for ``deepseek-ai/DeepSeek-V3`` (class-scoped)."""
        return get_tokenizer("deepseek-ai/DeepSeek-V3")

    @pytest.fixture
    def test_config(self) -> ToolParserTestConfig:
        """Return sample model outputs and expected results for ``deepseek_v3``.

        Each tool call in the samples is wrapped in the DeepSeek begin/end
        marker tokens and carries its arguments as a fenced ``json`` block.
        """
        return ToolParserTestConfig(
            parser_name="deepseek_v3",
            # Test data
            no_tool_calls_output=(
                "How can I help you today? I can check weather for you."
            ),
            # The continuation lines of the triple-quoted payloads below start
            # at column 0 on purpose: any indentation would become part of the
            # string handed to the parser.
            single_tool_call_output="""<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
```json
{"city": "Tokyo", "unit": "celsius"}
```<|tool▁call▁end|><|tool▁calls▁end|>""",
            parallel_tool_calls_output="""<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
```json
{"city": "Tokyo", "unit": "celsius"}
```<|tool▁call▁end|><|tool▁call▁begin|>function<|tool▁sep|>search_hotels
```json
{"location": "Tokyo", "check_in": "2025-01-15"}
```<|tool▁call▁end|><|tool▁calls▁end|>""",
            # One call whose arguments cover string, int, float, bool, null,
            # array, object and empty containers.
            various_data_types_output=(
                """<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>test_function
```json
"""
                """{"string_field": "hello", "int_field": 42, "float_field": 3.14, """
                """"bool_field": true, "null_field": null, """
                """"array_field": ["a", "b", "c"], """
                """"object_field": {"nested": "value"}, """
                """"empty_array": [], "empty_object": {}}
```<|tool▁call▁end|><|tool▁calls▁end|>"""
            ),
            empty_arguments_output="""<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_time
```json
{}
```<|tool▁call▁end|><|tool▁calls▁end|>""",
            # Plain text immediately followed by a tool call.
            surrounding_text_output=(
                """Let me check the weather for you."""
                """<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
```json
{"city": "Paris"}
```<|tool▁call▁end|><|tool▁calls▁end|>"""
            ),
            # The doubled backslashes are Python escapes: the JSON text the
            # parser sees contains \" , \\ and \n escape sequences.
            escaped_strings_output=(
                """<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>send_message
```json
"""
                """{"text": "He said \\"hello\\"", "path": "C:\\\\Users\\\\file", """
                """"newline": "line1\\nline2"}
```<|tool▁call▁end|><|tool▁calls▁end|>"""
            ),
            # First entry: JSON object is missing its closing brace.
            # Second entry: the per-call begin/end tokens are missing.
            malformed_input_outputs=[
                """<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
```json
{"city": "Tokyo"
```<|tool▁call▁end|><|tool▁calls▁end|>""",
                """<|tool▁calls▁begin|>function<|tool▁sep|>get_weather
```json
{"city": "Tokyo"}
```<|tool▁calls▁end|>""",
            ],
            # Expected results
            single_tool_call_expected_name="get_weather",
            single_tool_call_expected_args={"city": "Tokyo", "unit": "celsius"},
            single_tool_call_expected_content=None,
            parallel_tool_calls_count=2,
            parallel_tool_calls_names=["get_weather", "search_hotels"],
            # xfail markers (test name -> reason)
            xfail_streaming={},
            xfail_nonstreaming={
                "test_malformed_input": (
                    "Parser sets tools_called=True even when tool_calls is "
                    "empty (detects start token but fails to parse)"
                ),
            },
        )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
"""Tests for the GLM-4.7 tool call parser."""
import json
from unittest.mock import Mock
import pytest
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
ChatCompletionToolsParam,
FunctionDefinition,
)
from vllm.tokenizers import get_tokenizer
from vllm.tool_parsers.glm47_moe_tool_parser import Glm47MoeModelToolParser
MODEL = "zai-org/GLM-4.5"
@pytest.fixture(scope="module")
def glm47_tokenizer():
    """Module-scoped tokenizer obtained via ``get_tokenizer`` for ``MODEL``."""
    tokenizer = get_tokenizer(tokenizer_name=MODEL)
    return tokenizer
@pytest.fixture
def glm47_tool_parser(glm47_tokenizer):
    """Fresh ``Glm47MoeModelToolParser`` built on the shared tokenizer."""
    parser = Glm47MoeModelToolParser(glm47_tokenizer)
    return parser
@pytest.fixture
def mock_request() -> ChatCompletionRequest:
    """Mocked chat request exposing two tools and ``tool_choice="auto"``.

    The tools are ``get_current_date`` (empty parameter schema) and
    ``get_weather`` (string parameters ``city`` and ``date``).
    """
    weather_schema = {
        "type": "object",
        "properties": {
            "city": {"type": "string"},
            "date": {"type": "string"},
        },
    }
    date_tool = ChatCompletionToolsParam(
        function=FunctionDefinition(name="get_current_date", parameters={}),
    )
    weather_tool = ChatCompletionToolsParam(
        function=FunctionDefinition(name="get_weather", parameters=weather_schema),
    )

    request = Mock(spec=ChatCompletionRequest)
    request.tools = [date_tool, weather_tool]
    request.tool_choice = "auto"
    return request
class TestGlm47ExtractToolCalls:
    """Non-streaming ``extract_tool_calls`` checks for the GLM-4.7 parser."""

    def test_no_tool_call(self, glm47_tool_parser, mock_request):
        """Plain text is returned unchanged and no tool call is reported."""
        model_output = "This is a plain response."
        result = glm47_tool_parser.extract_tool_calls(
            model_output, request=mock_request
        )
        assert not result.tools_called
        assert result.content == model_output

    def test_zero_arg_inline(self, glm47_tool_parser, mock_request):
        """A call with only a function name yields empty arguments and no content."""
        model_output = "<tool_call>get_current_date</tool_call>"
        result = glm47_tool_parser.extract_tool_calls(
            model_output, request=mock_request
        )
        assert result.tools_called
        first_call = result.tool_calls[0].function
        assert first_call.name == "get_current_date"
        assert json.loads(result.tool_calls[0].function.arguments) == {}
        assert result.content is None

    def test_zero_arg_newline(self, glm47_tool_parser, mock_request):
        """A newline after the function name does not change the parsed name."""
        model_output = "<tool_call>get_current_date\n</tool_call>"
        result = glm47_tool_parser.extract_tool_calls(
            model_output, request=mock_request
        )
        assert result.tools_called
        assert result.tool_calls[0].function.name == "get_current_date"

    def test_args_same_line(self, glm47_tool_parser, mock_request):
        """Key/value tags written on one line are parsed into arguments."""
        model_output = (
            "<tool_call>get_weather"
            "<arg_key>city</arg_key><arg_value>Beijing</arg_value>"
            "</tool_call>"
        )
        result = glm47_tool_parser.extract_tool_calls(
            model_output, request=mock_request
        )
        assert result.tools_called
        parsed_args = json.loads(result.tool_calls[0].function.arguments)
        assert parsed_args == {"city": "Beijing"}

    def test_args_with_newlines(self, glm47_tool_parser, mock_request):
        """Key/value tags separated by newlines are parsed into arguments."""
        model_output = (
            "<tool_call>get_weather\n"
            "<arg_key>city</arg_key>\n"
            "<arg_value>Beijing</arg_value>\n"
            "</tool_call>"
        )
        result = glm47_tool_parser.extract_tool_calls(
            model_output, request=mock_request
        )
        assert result.tools_called
        parsed_args = json.loads(result.tool_calls[0].function.arguments)
        assert parsed_args == {"city": "Beijing"}

    def test_content_before(self, glm47_tool_parser, mock_request):
        """Text ahead of the tool call is returned as content."""
        model_output = "Checking.<tool_call>get_current_date</tool_call>"
        result = glm47_tool_parser.extract_tool_calls(
            model_output, request=mock_request
        )
        assert result.tools_called
        assert result.content == "Checking."

    def test_multiple(self, glm47_tool_parser, mock_request):
        """Two back-to-back tool calls are both extracted."""
        model_output = "".join(
            "<tool_call>get_weather<arg_key>city</arg_key>"
            f"<arg_value>{city}</arg_value></tool_call>"
            for city in ("Beijing", "Shanghai")
        )
        result = glm47_tool_parser.extract_tool_calls(
            model_output, request=mock_request
        )
        assert len(result.tool_calls) == 2

    def test_empty_content_none(self, glm47_tool_parser, mock_request):
        """Output consisting only of a tool call has ``None`` content."""
        model_output = "<tool_call>get_current_date</tool_call>"
        result = glm47_tool_parser.extract_tool_calls(
            model_output, request=mock_request
        )
        assert result.content is None

    def test_whitespace_content_none(self, glm47_tool_parser, mock_request):
        """Whitespace-only text ahead of a tool call is reported as ``None``."""
        model_output = " \n <tool_call>get_current_date</tool_call>"
        result = glm47_tool_parser.extract_tool_calls(
            model_output, request=mock_request
        )
        assert result.content is None
def _reset(parser):
parser._buffer = ""
parser._in_tool_call = False
parser.current_tool_name_sent = False
parser._current_tool_name = None
parser._pending_key = None
parser._streaming_string_value = False
parser.prev_tool_call_arr = []
parser.current_tool_id = -1
parser.streamed_args_for_tool = []
parser._tool_call_ids = []
parser._args_started = []
parser._args_closed = []
parser._seen_keys = []
class TestGlm47Streaming:
    """Streaming ``extract_tool_calls_streaming`` checks for the GLM-4.7 parser."""

    @staticmethod
    def _feed(parser, request, pieces):
        """Send each piece to the parser as its own streaming delta.

        Only ``delta_text`` varies; the previous/current text and all token-id
        lists are passed empty on every call.
        """
        for piece in pieces:
            parser.extract_tool_calls_streaming(
                previous_text="",
                current_text="",
                delta_text=piece,
                previous_token_ids=[],
                current_token_ids=[],
                delta_token_ids=[],
                request=request,
            )

    def test_no_args(self, glm47_tool_parser, mock_request):
        """A streamed call without arguments is recorded by the parser."""
        _reset(glm47_tool_parser)
        self._feed(
            glm47_tool_parser,
            mock_request,
            ["<tool_call>", "get_current_date", "</tool_call>"],
        )
        assert len(glm47_tool_parser.prev_tool_call_arr) >= 1

    def test_with_args(self, glm47_tool_parser, mock_request):
        """A streamed string argument ends up in the recorded tool call."""
        _reset(glm47_tool_parser)
        # The value, its closing tag and the tool-call closing tag arrive in
        # separate deltas so the incremental string streaming path handles
        # each of them in its own call.
        pieces = [
            "<tool_call>",
            "get_weather\n",
            "<arg_key>city</arg_key>",
            "<arg_value>",
            "Beijing",
            "</arg_value>",
            "</tool_call>",
        ]
        self._feed(glm47_tool_parser, mock_request, pieces)

        recorded_args = glm47_tool_parser.prev_tool_call_arr[0]["arguments"]
        assert recorded_args["city"] == "Beijing"
...@@ -107,7 +107,7 @@ def test_extract_tool_calls_no_tools(glm4_moe_tool_parser, mock_request): ...@@ -107,7 +107,7 @@ def test_extract_tool_calls_no_tools(glm4_moe_tool_parser, mock_request):
) )
) )
], ],
"", None,
), ),
( (
"""<tool_call>get_current_weather """<tool_call>get_current_weather
...@@ -152,7 +152,7 @@ def test_extract_tool_calls_no_tools(glm4_moe_tool_parser, mock_request): ...@@ -152,7 +152,7 @@ def test_extract_tool_calls_no_tools(glm4_moe_tool_parser, mock_request):
) )
), ),
], ],
"", None,
), ),
( (
"""I'll help you check the weather. <tool_call>get_current_weather """I'll help you check the weather. <tool_call>get_current_weather
...@@ -202,7 +202,7 @@ def test_extract_tool_calls_no_tools(glm4_moe_tool_parser, mock_request): ...@@ -202,7 +202,7 @@ def test_extract_tool_calls_no_tools(glm4_moe_tool_parser, mock_request):
) )
) )
], ],
"", None,
), ),
( (
"""I will help you get the weather.<tool_call>get_weather """I will help you get the weather.<tool_call>get_weather
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from tests.tool_parsers.common_tests import (
ToolParserTestConfig,
ToolParserTests,
)
class TestGranite20bFcToolParser(ToolParserTests):
    """Fixture that configures ``ToolParserTests`` for the ``granite-20b-fc`` parser.

    This class defines no test methods of its own; it only provides the
    ``test_config`` fixture.
    NOTE(review): the tests themselves presumably come from
    ``ToolParserTests`` - confirm in ``tests.tool_parsers.common_tests``.
    """

    @pytest.fixture
    def test_config(self) -> ToolParserTestConfig:
        """Return sample model outputs and expected results for ``granite-20b-fc``.

        Every tool call in the samples is a ``<function_call>`` marker followed
        by one JSON object with ``name`` and ``arguments`` keys.
        """
        return ToolParserTestConfig(
            parser_name="granite-20b-fc",
            # Test data
            no_tool_calls_output="This is a regular response without any tool calls.",
            single_tool_call_output=(
                '<function_call> {"name": "get_weather", '
                '"arguments": {"city": "Tokyo"}}'
            ),
            # Two calls, one per line, each with its own marker.
            parallel_tool_calls_output=(
                '<function_call> {"name": "get_weather", '
                '"arguments": {"city": "Tokyo"}}\n'
                '<function_call> {"name": "get_time", '
                '"arguments": {"timezone": "Asia/Tokyo"}}'
            ),
            # Continuation lines of the triple-quoted payloads below are part
            # of the string passed to the parser.
            various_data_types_output="""<function_call> {
"name": "test_function",
"arguments": {
"string_field": "hello",
"int_field": 42,
"float_field": 3.14,
"bool_field": true,
"null_field": null,
"array_field": ["a", "b", "c"],
"object_field": {"nested": "value"},
"empty_array": [],
"empty_object": {}
}
}""",
            empty_arguments_output=(
                '<function_call> {"name": "refresh", "arguments": {}}'
            ),
            # Plain text on the first line, tool call on the second.
            surrounding_text_output="""Let me check the weather for you.
<function_call> {"name": "get_weather", "arguments": {"city": "Tokyo"}}""",
            # The doubled backslashes are Python escapes: the JSON text the
            # parser sees contains \" , \\ and \n escape sequences.
            escaped_strings_output="""<function_call> {
"name": "test_function",
"arguments": {
"quoted": "He said \\"hello\\"",
"path": "C:\\\\Users\\\\file.txt",
"newline": "line1\\nline2",
"unicode": "emoji: 🎉"
}
}""",
            # In order: truncated JSON, a JSON array after the marker, JSON
            # without any marker, and a numeric ``name``.
            malformed_input_outputs=[
                '<function_call> {"name": "func", "arguments": {',
                '<function_call> [{"name": "func", "arguments": {}}]',
                '{"name": "func", "arguments": {}}',
                '<function_call> {"name": 123}',
            ],
            # Expected results
            single_tool_call_expected_name="get_weather",
            single_tool_call_expected_args={"city": "Tokyo"},
            single_tool_call_expected_content=None,
            parallel_tool_calls_count=2,
            parallel_tool_calls_names=["get_weather", "get_time"],
            # xfail markers (test name -> reason)
            xfail_streaming={
                "test_surrounding_text": (
                    "Granite 20B FC streaming requires <function_call> at start"
                ),
            },
            xfail_nonstreaming={},
        )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from tests.tool_parsers.common_tests import (
ToolParserTestConfig,
ToolParserTests,
)
from tests.tool_parsers.utils import run_tool_extraction
class TestGraniteToolParser(ToolParserTests):
    """Fixture and extra tests that run ``ToolParserTests`` for the ``granite`` parser.

    Besides the ``test_config`` fixture, the class adds two tests covering the
    ``<|tool_call|>`` and ``<tool_call>`` prefixes.
    """

    @pytest.fixture
    def test_config(self) -> ToolParserTestConfig:
        """Return sample model outputs and expected results for ``granite``.

        Tool calls in the samples are a prefix marker followed by a JSON array
        of objects with ``name`` and ``arguments`` keys. Most samples use the
        ``<|tool_call|>`` prefix; ``various_data_types_output`` and
        ``escaped_strings_output`` use ``<tool_call>``.
        """
        return ToolParserTestConfig(
            parser_name="granite",
            # Test data
            no_tool_calls_output="This is a regular response without any tool calls.",
            single_tool_call_output=(
                '<|tool_call|> [{"name": "get_weather", '
                '"arguments": {"city": "Tokyo"}}]'
            ),
            # Continuation lines of the triple-quoted payloads below are part
            # of the string passed to the parser.
            parallel_tool_calls_output="""<|tool_call|> [
{"name": "get_weather", "arguments": {"city": "Tokyo"}},
{"name": "get_time", "arguments": {"timezone": "Asia/Tokyo"}}
]""",
            various_data_types_output="""<tool_call> [{
"name": "test_function",
"arguments": {
"string_field": "hello",
"int_field": 42,
"float_field": 3.14,
"bool_field": true,
"null_field": null,
"array_field": ["a", "b", "c"],
"object_field": {"nested": "value"},
"empty_array": [],
"empty_object": {}
}
}]""",
            empty_arguments_output=(
                '<|tool_call|> [{"name": "refresh", "arguments": {}}]'
            ),
            # Plain text both before and after the tool call line.
            surrounding_text_output="""Let me check the weather for you.
<|tool_call|> [{"name": "get_weather", "arguments": {"city": "Tokyo"}}]
I'll get that information.""",
            # The doubled backslashes are Python escapes: the JSON text the
            # parser sees contains \" , \\ and \n escape sequences.
            escaped_strings_output="""<tool_call> [{
"name": "test_function",
"arguments": {
"quoted": "He said \\"hello\\"",
"path": "C:\\\\Users\\\\file.txt",
"newline": "line1\\nline2",
"unicode": "emoji: 🎉"
}
}]""",
            malformed_input_outputs=[
                '<|tool_call|> [{"name": "func", "arguments": {',
                '<|tool_call|> {"name": "func", "arguments": {}}',  # Not an array
                '[{"name": "func", "arguments": "not a dict"}]',
                'Some text [{"name": "func"}]',  # JSON but not tool call format
            ],
            # Expected results
            single_tool_call_expected_name="get_weather",
            single_tool_call_expected_args={"city": "Tokyo"},
            # Granite strips content when tool calls present
            single_tool_call_expected_content=None,
            parallel_tool_calls_count=2,
            parallel_tool_calls_names=["get_weather", "get_time"],
            # xfail markers (test name -> reason)
            xfail_streaming={
                "test_malformed_input": (
                    "Streaming mode incorrectly creates tool call from malformed JSON"
                ),
                "test_surrounding_text": (
                    "Parser doesn't handle surrounding text correctly in streaming"
                ),
                "test_streaming_reconstruction": (
                    "Streaming mode doesn't strip <|tool_call|> marker from content"
                ),
            },
            xfail_nonstreaming={
                "test_surrounding_text": (
                    "Parser doesn't handle surrounding text correctly in non-streaming"
                ),
            },
        )

    # Granite-Specific Tests
    @pytest.mark.parametrize("streaming", [True, False])
    def test_granite_token_prefix_format(self, tool_parser, streaming):
        """Verify parser handles Granite 3.0 <|tool_call|> token format."""
        single_tool_call_token = (
            '<|tool_call|> [{"name": "get_weather", "arguments": {"city": "Tokyo"}}]'
        )
        # ``content`` is unpacked but not checked; only the tool calls are.
        content, tool_calls = run_tool_extraction(
            tool_parser, single_tool_call_token, streaming=streaming
        )
        assert len(tool_calls) == 1, (
            f"Expected 1 tool call from token format, got {len(tool_calls)}"
        )
        assert tool_calls[0].function.name == "get_weather"

    @pytest.mark.parametrize("streaming", [True, False])
    def test_granite_string_prefix_format(self, tool_parser, streaming):
        """Verify parser handles Granite 3.1 <tool_call> string format."""
        single_tool_call_string = (
            '<tool_call> [{"name": "get_weather", "arguments": {"city": "Tokyo"}}]'
        )
        # ``content`` is unpacked but not checked; only the tool calls are.
        content, tool_calls = run_tool_extraction(
            tool_parser, single_tool_call_string, streaming=streaming
        )
        assert len(tool_calls) == 1, (
            f"Expected 1 tool call from string format, got {len(tool_calls)}"
        )
        assert tool_calls[0].function.name == "get_weather"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock
import pytest
from tests.tool_parsers.common_tests import (
ToolParserTestConfig,
ToolParserTests,
)
from vllm.tokenizers import TokenizerLike
class TestInternLM2ToolParser(ToolParserTests):
    """Fixtures that configure ``ToolParserTests`` for the ``internlm`` parser.

    This class defines no test methods of its own; it only provides the
    ``tokenizer`` and ``test_config`` fixtures.
    """

    @pytest.fixture
    def tokenizer(self, default_tokenizer: TokenizerLike) -> TokenizerLike:
        """Add some internlm2 specific tokens to the default vocab.

        ``get_vocab`` on the incoming tokenizer is replaced by a ``MagicMock``
        that returns the original vocabulary extended with the three action
        tokens below.
        NOTE(review): this patches ``default_tokenizer`` in place and never
        restores it; if that fixture is shared between tests the patched
        ``get_vocab`` persists - confirm the fixture's scope.
        """
        # Read the real vocabulary before get_vocab is swapped for a mock.
        tokenizer_vocab = default_tokenizer.get_vocab()
        default_tokenizer.get_vocab = MagicMock()
        tokenizer_vocab.update(
            {
                "<|action_start|>": 92540,
                "<|plugin|>": 92541,
                "<|action_end|>": 92542,
            }
        )
        default_tokenizer.get_vocab.return_value = tokenizer_vocab
        return default_tokenizer

    @pytest.fixture
    def test_config(self) -> ToolParserTestConfig:
        """Return sample model outputs and expected results for ``internlm``.

        Tool calls in the samples are one JSON object with ``name`` and
        ``parameters`` keys, placed between ``<|action_start|><|plugin|>`` and
        ``<|action_end|>``.
        """
        return ToolParserTestConfig(
            parser_name="internlm",
            # Test data
            no_tool_calls_output="This is a regular response without any tool calls.",
            single_tool_call_output=(
                '<|action_start|><|plugin|>{"name": "get_weather", '
                '"parameters": {"city": "Tokyo"}}<|action_end|>'
            ),
            # InternLM2 doesn't support parallel calls
            parallel_tool_calls_output=(
                '<|action_start|><|plugin|>{"name": "get_weather", '
                '"parameters": {"city": "Tokyo"}}<|action_end|>'
            ),
            # Continuation lines of the triple-quoted payloads below are part
            # of the string passed to the parser.
            various_data_types_output="""<|action_start|><|plugin|>{
"name": "test_function",
"parameters": {
"string_field": "hello",
"int_field": 42,
"float_field": 3.14,
"bool_field": true,
"null_field": null,
"array_field": ["a", "b", "c"],
"object_field": {"nested": "value"},
"empty_array": [],
"empty_object": {}
}
}<|action_end|>""",
            empty_arguments_output=(
                '<|action_start|><|plugin|>{"name": "refresh", '
                '"parameters": {}}<|action_end|>'
            ),
            surrounding_text_output=(
                "Let me check the weather for you. "
                '<|action_start|><|plugin|>{"name": "get_weather", '
                '"parameters": {"city": "Tokyo"}}<|action_end|>'
            ),
            # The doubled backslashes are Python escapes: the JSON text the
            # parser sees contains \" , \\ and \n escape sequences.
            escaped_strings_output="""<|action_start|><|plugin|>{
"name": "test_function",
"parameters": {
"quoted": "He said \\"hello\\"",
"path": "C:\\\\Users\\\\file.txt",
"newline": "line1\\nline2",
"unicode": "emoji: 🎉"
}
}<|action_end|>""",
            # In order: truncated JSON, non-dict parameters, non-JSON body,
            # markers with no body, and a call missing the plugin marker.
            malformed_input_outputs=[
                '<|action_start|><|plugin|>{"name": "func", "parameters": {',
                (
                    '<|action_start|><|plugin|>{"name": "func", '
                    '"parameters": "not a dict"}<|action_end|>'
                ),
                "<|action_start|><|plugin|>not json<|action_end|>",
                "<|action_start|><|plugin|>",
                '<|action_start|>{"name": "func"}',
            ],
            # Expected results
            single_tool_call_expected_name="get_weather",
            single_tool_call_expected_args={"city": "Tokyo"},
            single_tool_call_expected_content=None,
            parallel_tool_calls_count=1,  # InternLM2 only supports single tool calls
            parallel_tool_calls_names=["get_weather"],
            # Parser-specific settings
            allow_empty_or_json_empty_args=True,
            # xfail markers (test name -> reason)
            xfail_streaming={
                "test_single_tool_call_simple_args": (
                    "InternLM2 streaming not fully implemented"
                ),
                "test_parallel_tool_calls": (
                    "InternLM2 streaming not fully implemented"
                ),
                "test_various_data_types": (
                    "InternLM2 streaming not fully implemented"
                ),
                "test_empty_arguments": ("InternLM2 streaming not fully implemented"),
                "test_surrounding_text": ("InternLM2 streaming not fully implemented"),
                "test_escaped_strings": ("InternLM2 streaming not fully implemented"),
                "test_streaming_reconstruction": (
                    "InternLM2 streaming parser returns '<|action_start|' as "
                    "content instead of None - streaming/non-streaming inconsistency"
                ),
            },
            xfail_nonstreaming={
                "test_malformed_input": (
                    "InternLM2 parser raises JSONDecodeError on malformed JSON "
                    "instead of gracefully handling it"
                ),
            },
        )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock
import pytest
from tests.tool_parsers.common_tests import (
ToolParserTestConfig,
ToolParserTests,
)
from vllm.tokenizers import TokenizerLike
class TestLongCatToolParser(ToolParserTests):
    """Fixtures that configure ``ToolParserTests`` for the ``longcat`` parser.

    This class defines no test methods of its own; it only provides the
    ``tokenizer`` and ``test_config`` fixtures.
    """

    @pytest.fixture
    def tokenizer(self, default_tokenizer: TokenizerLike) -> TokenizerLike:
        """Add some longcat specific tokens to the default vocab.

        ``get_vocab`` on the incoming tokenizer is replaced by a ``MagicMock``
        that returns the original vocabulary extended with the opening and
        closing ``longcat_tool_call`` tags.
        NOTE(review): this patches ``default_tokenizer`` in place and never
        restores it; if that fixture is shared between tests the patched
        ``get_vocab`` persists - confirm the fixture's scope.
        """
        tokenizer = default_tokenizer
        # Read the real vocabulary before get_vocab is swapped for a mock.
        tokenizer_vocab = tokenizer.get_vocab()
        tokenizer.get_vocab = MagicMock()
        tokenizer_vocab.update(
            {
                "<longcat_tool_call>": 32000,
                "</longcat_tool_call>": 32001,
            }
        )
        tokenizer.get_vocab.return_value = tokenizer_vocab
        return tokenizer

    @pytest.fixture
    def test_config(self) -> ToolParserTestConfig:
        """Return sample model outputs and expected results for ``longcat``.

        Tool calls in the samples are one JSON object with ``name`` and
        ``arguments`` keys wrapped in ``<longcat_tool_call>`` tags.
        """
        return ToolParserTestConfig(
            parser_name="longcat",
            # Test data
            no_tool_calls_output="This is a regular response without any tool calls.",
            single_tool_call_output=(
                '<longcat_tool_call>{"name": "get_weather", '
                '"arguments": {"city": "Tokyo"}}</longcat_tool_call>'
            ),
            # Two tagged calls separated by a newline.
            parallel_tool_calls_output=(
                '<longcat_tool_call>{"name": "get_weather", '
                '"arguments": {"city": "Tokyo"}}</longcat_tool_call>\n'
                '<longcat_tool_call>{"name": "get_time", '
                '"arguments": {"timezone": "Asia/Tokyo"}}</longcat_tool_call>'
            ),
            # Continuation lines of the triple-quoted payloads below are part
            # of the string passed to the parser.
            various_data_types_output="""<longcat_tool_call>{
"name": "test_function",
"arguments": {
"string_field": "hello",
"int_field": 42,
"float_field": 3.14,
"bool_field": true,
"null_field": null,
"array_field": ["a", "b", "c"],
"object_field": {"nested": "value"},
"empty_array": [],
"empty_object": {}
}
}</longcat_tool_call>""",
            empty_arguments_output=(
                '<longcat_tool_call>{"name": "refresh", "arguments": {}}'
                "</longcat_tool_call>"
            ),
            # Plain text both before and after the tool call.
            surrounding_text_output=(
                "Let me check the weather for you.\n"
                '<longcat_tool_call>{"name": "get_weather", '
                '"arguments": {"city": "Tokyo"}}</longcat_tool_call>\n'
                "Here is the result."
            ),
            # The doubled backslashes are Python escapes: the JSON text the
            # parser sees contains \" , \\ and \n escape sequences.
            escaped_strings_output="""<longcat_tool_call>{
"name": "test_function",
"arguments": {
"quoted": "He said \\"hello\\"",
"path": "C:\\\\Users\\\\file.txt",
"newline": "line1\\nline2",
"unicode": "emoji: 🎉"
}
}</longcat_tool_call>""",
            # In order: truncated JSON, non-dict arguments, and an opening
            # tag followed by non-JSON text.
            malformed_input_outputs=[
                '<longcat_tool_call>{"name": "func", "arguments": {',
                (
                    '<longcat_tool_call>{"name": "func", '
                    '"arguments": "not a dict"}</longcat_tool_call>'
                ),
                "Some text with <longcat_tool_call>invalid json",
            ],
            # Expected results
            single_tool_call_expected_name="get_weather",
            single_tool_call_expected_args={"city": "Tokyo"},
            single_tool_call_expected_content=None,
            parallel_tool_calls_count=2,
            parallel_tool_calls_names=["get_weather", "get_time"],
            # xfail markers (test name -> reason)
            xfail_streaming={
                "test_malformed_input": "Streaming has complex buffering behavior",
            },
            xfail_nonstreaming={},
            # Configuration
            allow_empty_or_json_empty_args=True,
        )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock
import pytest
from tests.tool_parsers.common_tests import (
ToolParserTestConfig,
ToolParserTests,
)
from vllm.tokenizers import TokenizerLike
class TestPhi4MiniToolParser(ToolParserTests):
    """Fixtures that configure ``ToolParserTests`` for the ``phi4_mini_json`` parser.

    This class defines no test methods of its own; it only provides the
    ``tokenizer`` and ``test_config`` fixtures.
    """

    @pytest.fixture
    def tokenizer(self, default_tokenizer: TokenizerLike) -> TokenizerLike:
        """Add some phi4mini specific tokens to the default vocab.

        ``get_vocab`` on the incoming tokenizer is replaced by a ``MagicMock``
        that returns the original vocabulary extended with a ``functools``
        entry.
        NOTE(review): this patches ``default_tokenizer`` in place and never
        restores it; if that fixture is shared between tests the patched
        ``get_vocab`` persists - confirm the fixture's scope.
        """
        tokenizer = default_tokenizer
        # Read the real vocabulary before get_vocab is swapped for a mock.
        tokenizer_vocab = tokenizer.get_vocab()
        tokenizer.get_vocab = MagicMock()
        tokenizer_vocab.update(
            {
                "functools": 32000,
            }
        )
        tokenizer.get_vocab.return_value = tokenizer_vocab
        return tokenizer

    @pytest.fixture
    def test_config(self) -> ToolParserTestConfig:
        """Return sample model outputs and expected results for ``phi4_mini_json``.

        Tool calls in the samples are the literal prefix ``functools`` followed
        by a JSON array of objects with ``name`` and ``arguments`` keys.
        """
        return ToolParserTestConfig(
            parser_name="phi4_mini_json",
            # Test data
            no_tool_calls_output="This is a regular response without any tool calls.",
            single_tool_call_output=(
                'functools[{"name": "get_weather", "arguments": {"city": "Tokyo"}}]'
            ),
            # Continuation lines of the triple-quoted payloads below are part
            # of the string passed to the parser.
            parallel_tool_calls_output="""functools[
{"name": "get_weather", "arguments": {"city": "Tokyo"}},
{"name": "get_time", "arguments": {"timezone": "Asia/Tokyo"}}
]""",
            various_data_types_output="""functools[{
"name": "test_function",
"arguments": {
"string_field": "hello",
"int_field": 42,
"float_field": 3.14,
"bool_field": true,
"null_field": null,
"array_field": ["a", "b", "c"],
"object_field": {"nested": "value"},
"empty_array": [],
"empty_object": {}
}
}]""",
            empty_arguments_output='functools[{"name": "refresh", "arguments": {}}]',
            # Plain text both before and after the tool call line.
            surrounding_text_output="""Let me check the weather for you.
functools[{"name": "get_weather", "arguments": {"city": "Tokyo"}}]
Would you like to know more?""",
            # The doubled backslashes are Python escapes: the JSON text the
            # parser sees contains \" , \\ and \n escape sequences.
            escaped_strings_output="""functools[{
"name": "test_function",
"arguments": {
"quoted": "He said \\"hello\\"",
"path": "C:\\\\Users\\\\file.txt",
"newline": "line1\\nline2",
"unicode": "emoji: 🎉"
}
}]""",
            malformed_input_outputs=[
                'functools[{"name": "func", "arguments": {',
                'functools[{"name": "func", "arguments": "not a dict"}]',
                'functools{"name": "func"}',  # Missing brackets
                'functools[{"name": "func"}]',  # Missing arguments/parameters
                "functools[] This is just text",  # Empty functools
                "functools[ This is just text ]",  # functools with invalid JSON
            ],
            # Expected results
            single_tool_call_expected_name="get_weather",
            single_tool_call_expected_args={"city": "Tokyo"},
            # Phi-4 Mini strips content when tool calls present
            single_tool_call_expected_content=None,
            parallel_tool_calls_count=2,
            parallel_tool_calls_names=["get_weather", "get_time"],
            parallel_tool_calls_expected_content=None,
            # xfail markers (test name -> reason)
            xfail_streaming={
                "test_no_tool_calls": "Phi4 Mini streaming not implemented",
                "test_single_tool_call_simple_args": (
                    "Phi4 Mini streaming not implemented"
                ),
                "test_parallel_tool_calls": "Phi4 Mini streaming not implemented",
                "test_various_data_types": "Phi4 Mini streaming not implemented",
                "test_empty_arguments": "Phi4 Mini streaming not implemented",
                "test_surrounding_text": "Phi4 Mini streaming not implemented",
                "test_escaped_strings": "Phi4 Mini streaming not implemented",
                "test_streaming_reconstruction": "Phi4 Mini streaming not implemented",
            },
            xfail_nonstreaming={
                "test_various_data_types": (
                    "Phi4MiniJsonToolParser regex has nesting limitations "
                    "with nested objects"
                ),
                "test_malformed_input": (
                    "Phi4MiniJsonToolParser incorrectly sets "
                    "tools_called=True on empty array"
                ),
            },
        )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from tests.tool_parsers.common_tests import (
ToolParserTestConfig,
ToolParserTests,
)
class TestQwen3xmlToolParser(ToolParserTests):
    """Shared ``ToolParserTests`` fixtures configured for the ``qwen3_xml`` parser."""

    @pytest.fixture
    def test_config(self) -> ToolParserTestConfig:
        """Assemble the sample outputs and expectations for ``qwen3_xml``.

        Tool calls in the samples use ``<tool_call>`` / ``<function=...>`` /
        ``<parameter=...>`` tags with raw (non-JSON-escaped) parameter text.
        """
        # Building blocks reused by several sample outputs.
        weather_call = "<tool_call>\n<function=get_weather>\n<parameter=city>Tokyo</parameter>\n</function>\n</tool_call>"
        time_call = "<tool_call>\n<function=get_time>\n<parameter=timezone>Asia/Tokyo</parameter>\n</function>\n</tool_call>"

        typed_params_call = "".join(
            [
                "<tool_call>\n<function=test_function>\n",
                "<parameter=string_field>hello</parameter>\n",
                "<parameter=int_field>42</parameter>\n",
                "<parameter=float_field>3.14</parameter>\n",
                "<parameter=bool_field>true</parameter>\n",
                "<parameter=null_field>null</parameter>\n",
                '<parameter=array_field>["a", "b", "c"]</parameter>\n',
                '<parameter=object_field>{"nested": "value"}</parameter>\n',
                "</function>\n</tool_call>",
            ]
        )
        special_chars_call = "".join(
            [
                "<tool_call>\n<function=test_function>\n",
                '<parameter=quoted>He said "hello"</parameter>\n',
                "<parameter=path>C:\\Users\\file.txt</parameter>\n",
                "<parameter=newline>line1\nline2</parameter>\n",
                "</function>\n</tool_call>",
            ]
        )

        # Six streaming tests share one xfail reason; two have their own.
        systematic_reason = "Qwen3XML streaming has systematic issues"
        streaming_xfails = dict.fromkeys(
            (
                "test_single_tool_call_simple_args",
                "test_parallel_tool_calls",
                "test_various_data_types",
                "test_empty_arguments",
                "test_surrounding_text",
                "test_escaped_strings",
            ),
            systematic_reason,
        )
        streaming_xfails["test_malformed_input"] = (
            "Qwen3XML parser is lenient with malformed input"
        )
        streaming_xfails["test_streaming_reconstruction"] = (
            "Qwen3XML streaming reconstruction has known issues"
        )

        return ToolParserTestConfig(
            parser_name="qwen3_xml",
            # Test data
            no_tool_calls_output="This is a regular response without any tool calls.",
            single_tool_call_output=weather_call,
            parallel_tool_calls_output=weather_call + time_call,
            various_data_types_output=typed_params_call,
            empty_arguments_output="<tool_call>\n<function=refresh>\n</function>\n</tool_call>",
            surrounding_text_output=(
                "Let me check the weather for you.\n\n"
                + weather_call
                + "\n\n"
                + "I will get that information."
            ),
            escaped_strings_output=special_chars_call,
            malformed_input_outputs=[
                "<tool_call><function=func>",
                "<tool_call><function=></function></tool_call>",
            ],
            # Expected results
            single_tool_call_expected_name="get_weather",
            single_tool_call_expected_args={"city": "Tokyo"},
            parallel_tool_calls_count=2,
            parallel_tool_calls_names=["get_weather", "get_time"],
            # xfail markers - Qwen3XML has systematic streaming issues
            xfail_streaming=streaming_xfails,
            supports_typed_arguments=False,
        )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from tests.tool_parsers.common_tests import (
ToolParserTestConfig,
ToolParserTests,
)
from vllm.tokenizers import TokenizerLike, get_tokenizer
class TestStep3ToolParser(ToolParserTests):
    """Shared ``ToolParserTests`` fixtures configured for the ``step3`` parser."""

    @pytest.fixture(scope="class")
    def tokenizer(self) -> TokenizerLike:
        """Class-scoped tokenizer obtained for ``stepfun-ai/step3``."""
        step3_tokenizer = get_tokenizer("stepfun-ai/step3")
        return step3_tokenizer

    @pytest.fixture
    def test_config(self) -> ToolParserTestConfig:
        """Assemble the sample outputs and expectations for ``step3``.

        Tool calls in the samples are ``steptml:invoke`` elements placed
        between the ``<|tool_calls_begin|><|tool_call_begin|>`` and
        ``<|tool_call_end|><|tool_calls_end|>`` markers.
        """
        # Marker sequences and the weather invocation reused across samples.
        calls_open = "<|tool_calls_begin|><|tool_call_begin|>"
        calls_close = "</steptml:invoke><|tool_call_end|><|tool_calls_end|>"
        weather_invoke = (
            '<steptml:invoke name="get_weather">'
            '<steptml:parameter name="city">Tokyo</steptml:parameter>'
        )
        weather_only = calls_open + weather_invoke + calls_close

        weather_then_time = "".join(
            [
                calls_open,
                weather_invoke,
                "</steptml:invoke><|tool_call_end|><|tool_sep|>",
                '<|tool_call_begin|><steptml:invoke name="get_time">',
                '<steptml:parameter name="timezone">Asia/Tokyo</steptml:parameter>',
                calls_close,
            ]
        )
        typed_params_call = "".join(
            [
                calls_open,
                '<steptml:invoke name="test_function">',
                '<steptml:parameter name="string_field">hello</steptml:parameter>',
                '<steptml:parameter name="int_field">42</steptml:parameter>',
                '<steptml:parameter name="float_field">3.14</steptml:parameter>',
                '<steptml:parameter name="bool_field">true</steptml:parameter>',
                '<steptml:parameter name="null_field">null</steptml:parameter>',
                '<steptml:parameter name="array_field">',
                '["a", "b", "c"]</steptml:parameter>',
                '<steptml:parameter name="object_field">',
                '{"nested": "value"}</steptml:parameter>',
                calls_close,
            ]
        )
        special_chars_call = "".join(
            [
                calls_open,
                '<steptml:invoke name="test_function">',
                '<steptml:parameter name="quoted">He said "hello"</steptml:parameter>',
                '<steptml:parameter name="path">C:\\Users\\file.txt</steptml:parameter>',
                '<steptml:parameter name="newline">line1\nline2</steptml:parameter>',
                calls_close,
            ]
        )

        # Every listed non-streaming test shares the same xfail reason.
        nonstreaming_xfails = dict.fromkeys(
            (
                "test_single_tool_call_simple_args",
                "test_parallel_tool_calls",
                "test_various_data_types",
                "test_empty_arguments",
                "test_surrounding_text",
                "test_escaped_strings",
            ),
            "Step3 parser non-streaming has bugs",
        )
        streaming_xfails = {
            "test_parallel_tool_calls": (
                "Step3 parser has significant bugs in both streaming "
                "and non-streaming"
            ),
            "test_streaming_reconstruction": (
                "Step3 parser non-streaming has bugs, so streaming "
                "doesn't match non-streaming"
            ),
        }

        return ToolParserTestConfig(
            parser_name="step3",
            # Test data
            no_tool_calls_output="This is a regular response without any tool calls.",
            single_tool_call_output=weather_only,
            parallel_tool_calls_output=weather_then_time,
            various_data_types_output=typed_params_call,
            empty_arguments_output=(
                "<|tool_calls_begin|><|tool_call_begin|>"
                '<steptml:invoke name="refresh"></steptml:invoke>'
                "<|tool_call_end|><|tool_calls_end|>"
            ),
            surrounding_text_output=(
                "Let me check the weather for you.\n\n"
                + weather_only
                + "\n\n"
                + "I'll get that information."
            ),
            escaped_strings_output=special_chars_call,
            malformed_input_outputs=[
                # Invoke opened but never closed.
                calls_open + '<steptml:invoke name="func">',
                # Complete call without the outer tool_calls markers.
                (
                    '<|tool_call_begin|><steptml:invoke name="func">'
                    "</steptml:invoke><|tool_call_end|>"
                ),
            ],
            # Expected results
            single_tool_call_expected_name="get_weather",
            single_tool_call_expected_args={"city": "Tokyo"},
            parallel_tool_calls_count=2,
            parallel_tool_calls_names=["get_weather", "get_time"],
            # xfail markers
            xfail_nonstreaming=nonstreaming_xfails,
            xfail_streaming=streaming_xfails,
            supports_typed_arguments=False,
        )
...@@ -6,6 +6,7 @@ import pytest ...@@ -6,6 +6,7 @@ import pytest
from .utils import ( from .utils import (
MESSAGES_WITHOUT_TOOLS, MESSAGES_WITHOUT_TOOLS,
SEED,
WEATHER_TOOL, WEATHER_TOOL,
ServerConfig, ServerConfig,
ensure_system_prompt, ensure_system_prompt,
...@@ -27,6 +28,7 @@ async def test_chat_completion_without_tools( ...@@ -27,6 +28,7 @@ async def test_chat_completion_without_tools(
max_completion_tokens=150, max_completion_tokens=150,
model=model_name, model=model_name,
logprobs=False, logprobs=False,
seed=SEED,
) )
choice = chat_completion.choices[0] choice = chat_completion.choices[0]
stop_reason = chat_completion.choices[0].finish_reason stop_reason = chat_completion.choices[0].finish_reason
...@@ -47,6 +49,7 @@ async def test_chat_completion_without_tools( ...@@ -47,6 +49,7 @@ async def test_chat_completion_without_tools(
max_completion_tokens=150, max_completion_tokens=150,
model=model_name, model=model_name,
logprobs=False, logprobs=False,
seed=SEED,
stream=True, stream=True,
) )
chunks: list[str] = [] chunks: list[str] = []
...@@ -97,6 +100,7 @@ async def test_chat_completion_with_tools( ...@@ -97,6 +100,7 @@ async def test_chat_completion_with_tools(
model=model_name, model=model_name,
tools=[WEATHER_TOOL], tools=[WEATHER_TOOL],
logprobs=False, logprobs=False,
seed=SEED,
) )
choice = chat_completion.choices[0] choice = chat_completion.choices[0]
stop_reason = chat_completion.choices[0].finish_reason stop_reason = chat_completion.choices[0].finish_reason
...@@ -118,6 +122,7 @@ async def test_chat_completion_with_tools( ...@@ -118,6 +122,7 @@ async def test_chat_completion_with_tools(
model=model_name, model=model_name,
logprobs=False, logprobs=False,
tools=[WEATHER_TOOL], tools=[WEATHER_TOOL],
seed=SEED,
stream=True, stream=True,
) )
......
...@@ -10,6 +10,7 @@ from .utils import ( ...@@ -10,6 +10,7 @@ from .utils import (
MESSAGES_ASKING_FOR_PARALLEL_TOOLS, MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
SEARCH_TOOL, SEARCH_TOOL,
SEED,
WEATHER_TOOL, WEATHER_TOOL,
ServerConfig, ServerConfig,
) )
...@@ -39,6 +40,7 @@ async def test_parallel_tool_calls( ...@@ -39,6 +40,7 @@ async def test_parallel_tool_calls(
model=model_name, model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL], tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False, logprobs=False,
seed=SEED,
) )
choice = chat_completion.choices[0] choice = chat_completion.choices[0]
...@@ -76,6 +78,7 @@ async def test_parallel_tool_calls( ...@@ -76,6 +78,7 @@ async def test_parallel_tool_calls(
max_completion_tokens=200, max_completion_tokens=200,
tools=[WEATHER_TOOL, SEARCH_TOOL], tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False, logprobs=False,
seed=SEED,
stream=True, stream=True,
) )
...@@ -166,6 +169,7 @@ async def test_parallel_tool_calls_with_results( ...@@ -166,6 +169,7 @@ async def test_parallel_tool_calls_with_results(
model=model_name, model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL], tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False, logprobs=False,
seed=SEED,
) )
choice = chat_completion.choices[0] choice = chat_completion.choices[0]
...@@ -184,6 +188,7 @@ async def test_parallel_tool_calls_with_results( ...@@ -184,6 +188,7 @@ async def test_parallel_tool_calls_with_results(
model=model_name, model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL], tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False, logprobs=False,
seed=SEED,
stream=True, stream=True,
) )
...@@ -229,6 +234,7 @@ async def test_parallel_tool_calls_false(client: openai.AsyncOpenAI): ...@@ -229,6 +234,7 @@ async def test_parallel_tool_calls_false(client: openai.AsyncOpenAI):
model=model_name, model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL], tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False, logprobs=False,
seed=SEED,
parallel_tool_calls=False, parallel_tool_calls=False,
) )
...@@ -247,6 +253,7 @@ async def test_parallel_tool_calls_false(client: openai.AsyncOpenAI): ...@@ -247,6 +253,7 @@ async def test_parallel_tool_calls_false(client: openai.AsyncOpenAI):
max_completion_tokens=200, max_completion_tokens=200,
tools=[WEATHER_TOOL, SEARCH_TOOL], tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False, logprobs=False,
seed=SEED,
parallel_tool_calls=False, parallel_tool_calls=False,
stream=True, stream=True,
) )
......
...@@ -10,6 +10,7 @@ from .utils import ( ...@@ -10,6 +10,7 @@ from .utils import (
MESSAGES_ASKING_FOR_TOOLS, MESSAGES_ASKING_FOR_TOOLS,
MESSAGES_WITH_TOOL_RESPONSE, MESSAGES_WITH_TOOL_RESPONSE,
SEARCH_TOOL, SEARCH_TOOL,
SEED,
WEATHER_TOOL, WEATHER_TOOL,
) )
...@@ -27,6 +28,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI): ...@@ -27,6 +28,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
model=model_name, model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL], tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False, logprobs=False,
seed=SEED,
) )
choice = chat_completion.choices[0] choice = chat_completion.choices[0]
...@@ -71,6 +73,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI): ...@@ -71,6 +73,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
max_completion_tokens=100, max_completion_tokens=100,
tools=[WEATHER_TOOL, SEARCH_TOOL], tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False, logprobs=False,
seed=SEED,
stream=True, stream=True,
) )
...@@ -154,6 +157,7 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI): ...@@ -154,6 +157,7 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI):
model=model_name, model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL], tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False, logprobs=False,
seed=SEED,
) )
choice = chat_completion.choices[0] choice = chat_completion.choices[0]
...@@ -171,6 +175,7 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI): ...@@ -171,6 +175,7 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI):
model=model_name, model=model_name,
tools=[WEATHER_TOOL, SEARCH_TOOL], tools=[WEATHER_TOOL, SEARCH_TOOL],
logprobs=False, logprobs=False,
seed=SEED,
stream=True, stream=True,
) )
......
...@@ -42,6 +42,8 @@ def ensure_system_prompt( ...@@ -42,6 +42,8 @@ def ensure_system_prompt(
# universal args for all models go here. also good if you need to test locally # universal args for all models go here. also good if you need to test locally
# and change type or KV cache quantization or something. # and change type or KV cache quantization or something.
SEED = 42
ARGS: list[str] = [ ARGS: list[str] = [
"--enable-auto-tool-choice", "--enable-auto-tool-choice",
"--max-model-len", "--max-model-len",
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment