Commit 8d75f22e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.13.0rc1' into v0.13.0rc1-ori

parents ce888aa4 7d80c73d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from openai_harmony import Role
from openai.types.responses import ResponseFunctionToolCall, ResponseReasoningItem
from openai.types.responses.response_output_item import McpCall
from openai_harmony import Author, Message, Role, TextContent
from vllm.entrypoints.harmony_utils import (
from vllm.entrypoints.openai.parser.harmony_utils import (
has_custom_tools,
parse_input_to_harmony_message,
parse_output_message,
)
......@@ -257,6 +260,193 @@ class TestParseInputToHarmonyMessage:
assert messages[0].content[1].text == "actual text"
class TestParseOutputMessage:
"""Tests for parse_output_message function."""
def test_commentary_with_no_recipient_creates_reasoning(self):
"""Test that commentary with recipient=None (preambles) creates reasoning items.
Per Harmony format, commentary channel can contain preambles to calling
multiple functions - explanatory text with no recipient.
"""
message = Message.from_role_and_content(
Role.ASSISTANT, "I will now search for the weather information."
)
message = message.with_channel("commentary")
# recipient is None by default, representing a preamble
output_items = parse_output_message(message)
assert len(output_items) == 1
assert isinstance(output_items[0], ResponseReasoningItem)
assert output_items[0].type == "reasoning"
assert (
output_items[0].content[0].text
== "I will now search for the weather information."
)
assert output_items[0].content[0].type == "reasoning_text"
def test_commentary_with_function_recipient_creates_function_call(self):
"""Test commentary with recipient='functions.X' creates function calls."""
message = Message.from_role_and_content(
Role.ASSISTANT, '{"location": "San Francisco", "units": "celsius"}'
)
message = message.with_channel("commentary")
message = message.with_recipient("functions.get_weather")
output_items = parse_output_message(message)
assert len(output_items) == 1
assert isinstance(output_items[0], ResponseFunctionToolCall)
assert output_items[0].type == "function_call"
assert output_items[0].name == "get_weather"
assert (
output_items[0].arguments
== '{"location": "San Francisco", "units": "celsius"}'
)
assert output_items[0].call_id.startswith("call_")
assert output_items[0].id.startswith("fc_")
def test_commentary_with_python_recipient_creates_reasoning(self):
"""Test that commentary with recipient='python' creates reasoning items."""
message = Message.from_role_and_content(
Role.ASSISTANT, "import numpy as np\nprint(np.array([1, 2, 3]))"
)
message = message.with_channel("commentary")
message = message.with_recipient("python")
output_items = parse_output_message(message)
assert len(output_items) == 1
assert isinstance(output_items[0], ResponseReasoningItem)
assert output_items[0].type == "reasoning"
assert (
output_items[0].content[0].text
== "import numpy as np\nprint(np.array([1, 2, 3]))"
)
def test_commentary_with_browser_recipient_creates_reasoning(self):
"""Test that commentary with recipient='browser' creates reasoning items."""
message = Message.from_role_and_content(
Role.ASSISTANT, "Navigating to the specified URL"
)
message = message.with_channel("commentary")
message = message.with_recipient("browser")
output_items = parse_output_message(message)
assert len(output_items) == 1
assert isinstance(output_items[0], ResponseReasoningItem)
assert output_items[0].type == "reasoning"
assert output_items[0].content[0].text == "Navigating to the specified URL"
def test_commentary_with_container_recipient_creates_reasoning(self):
"""Test that commentary with recipient='container' creates reasoning items."""
message = Message.from_role_and_content(
Role.ASSISTANT, "Running command in container"
)
message = message.with_channel("commentary")
message = message.with_recipient("container")
output_items = parse_output_message(message)
assert len(output_items) == 1
assert isinstance(output_items[0], ResponseReasoningItem)
assert output_items[0].type == "reasoning"
assert output_items[0].content[0].text == "Running command in container"
def test_commentary_with_empty_content_and_no_recipient(self):
"""Test edge case: empty commentary with recipient=None."""
message = Message.from_role_and_content(Role.ASSISTANT, "")
message = message.with_channel("commentary")
output_items = parse_output_message(message)
assert len(output_items) == 1
assert isinstance(output_items[0], ResponseReasoningItem)
assert output_items[0].content[0].text == ""
def test_commentary_with_multiple_contents_and_no_recipient(self):
"""Test multiple content items in commentary with no recipient."""
contents = [
TextContent(text="Step 1: Analyze the request"),
TextContent(text="Step 2: Prepare to call functions"),
]
message = Message.from_role_and_contents(Role.ASSISTANT, contents)
message = message.with_channel("commentary")
output_items = parse_output_message(message)
assert len(output_items) == 2
assert all(isinstance(item, ResponseReasoningItem) for item in output_items)
assert output_items[0].content[0].text == "Step 1: Analyze the request"
assert output_items[1].content[0].text == "Step 2: Prepare to call functions"
def test_commentary_with_multiple_function_calls(self):
"""Test multiple function calls in commentary channel."""
contents = [
TextContent(text='{"location": "San Francisco"}'),
TextContent(text='{"location": "New York"}'),
]
message = Message.from_role_and_contents(Role.ASSISTANT, contents)
message = message.with_channel("commentary")
message = message.with_recipient("functions.get_weather")
output_items = parse_output_message(message)
assert len(output_items) == 2
assert all(isinstance(item, ResponseFunctionToolCall) for item in output_items)
assert output_items[0].name == "get_weather"
assert output_items[1].name == "get_weather"
assert output_items[0].arguments == '{"location": "San Francisco"}'
assert output_items[1].arguments == '{"location": "New York"}'
def test_commentary_with_unknown_recipient_creates_mcp_call(self):
"""Test that commentary with unknown recipient creates MCP call."""
message = Message.from_role_and_content(Role.ASSISTANT, '{"arg": "value"}')
message = message.with_channel("commentary")
message = message.with_recipient("custom_tool")
output_items = parse_output_message(message)
assert len(output_items) == 1
assert isinstance(output_items[0], McpCall)
assert output_items[0].type == "mcp_call"
assert output_items[0].name == "custom_tool"
assert output_items[0].server_label == "custom_tool"
def test_analysis_channel_creates_reasoning(self):
"""Test that analysis channel creates reasoning items."""
message = Message.from_role_and_content(
Role.ASSISTANT, "Analyzing the problem step by step..."
)
message = message.with_channel("analysis")
output_items = parse_output_message(message)
assert len(output_items) == 1
assert isinstance(output_items[0], ResponseReasoningItem)
assert output_items[0].type == "reasoning"
assert (
output_items[0].content[0].text == "Analyzing the problem step by step..."
)
def test_non_assistant_message_returns_empty(self):
"""Test that non-assistant messages return empty list.
Per the implementation, tool messages to assistant (e.g., search results)
are not included in final output to align with OpenAI behavior.
"""
message = Message.from_author_and_content(
Author.new(Role.TOOL, "functions.get_weather"),
"The weather is sunny, 72°F",
)
output_items = parse_output_message(message)
assert len(output_items) == 0
def test_has_custom_tools() -> None:
assert not has_custom_tools(set())
assert not has_custom_tools({"web_search_preview", "code_interpreter", "container"})
......@@ -264,3 +454,167 @@ def test_has_custom_tools() -> None:
assert has_custom_tools(
{"web_search_preview", "code_interpreter", "container", "others"}
)
def test_parse_mcp_call_basic() -> None:
"""Test that MCP calls are parsed with correct type and server_label."""
message = Message.from_role_and_content(Role.ASSISTANT, '{"path": "/tmp"}')
message = message.with_recipient("filesystem")
message = message.with_channel("commentary")
output_items = parse_output_message(message)
assert len(output_items) == 1
assert isinstance(output_items[0], McpCall)
assert output_items[0].type == "mcp_call"
assert output_items[0].name == "filesystem"
assert output_items[0].server_label == "filesystem"
assert output_items[0].arguments == '{"path": "/tmp"}'
assert output_items[0].status == "completed"
def test_parse_mcp_call_dotted_recipient() -> None:
"""Test that dotted recipients extract the tool name correctly."""
message = Message.from_role_and_content(Role.ASSISTANT, '{"cmd": "ls"}')
message = message.with_recipient("repo_browser.list")
message = message.with_channel("commentary")
output_items = parse_output_message(message)
assert len(output_items) == 1
assert isinstance(output_items[0], McpCall)
assert output_items[0].name == "list"
assert output_items[0].server_label == "repo_browser"
def test_mcp_vs_function_call() -> None:
"""Test that function calls are not parsed as MCP calls."""
func_message = Message.from_role_and_content(Role.ASSISTANT, '{"arg": "value"}')
func_message = func_message.with_recipient("functions.my_tool")
func_message = func_message.with_channel("commentary")
func_items = parse_output_message(func_message)
assert len(func_items) == 1
assert not isinstance(func_items[0], McpCall)
assert func_items[0].type == "function_call"
def test_mcp_vs_builtin_tools() -> None:
"""Test that built-in tools (python, container) are not parsed as MCP calls."""
# Test python (built-in tool) - should be reasoning, not MCP
python_message = Message.from_role_and_content(Role.ASSISTANT, "print('hello')")
python_message = python_message.with_recipient("python")
python_message = python_message.with_channel("commentary")
python_items = parse_output_message(python_message)
assert len(python_items) == 1
assert not isinstance(python_items[0], McpCall)
assert python_items[0].type == "reasoning"
def test_parse_remaining_state_commentary_channel() -> None:
"""Test parse_remaining_state with commentary channel and various recipients."""
from unittest.mock import Mock
from vllm.entrypoints.openai.parser.harmony_utils import parse_remaining_state
# Test 1: functions.* recipient → should return function tool call
parser_func = Mock()
parser_func.current_content = '{"arg": "value"}'
parser_func.current_role = Role.ASSISTANT
parser_func.current_channel = "commentary"
parser_func.current_recipient = "functions.my_tool"
func_items = parse_remaining_state(parser_func)
assert len(func_items) == 1
assert not isinstance(func_items[0], McpCall)
assert func_items[0].type == "function_call"
assert func_items[0].name == "my_tool"
assert func_items[0].status == "in_progress"
# Test 2: MCP tool (not builtin) → should return MCP call
parser_mcp = Mock()
parser_mcp.current_content = '{"path": "/tmp"}'
parser_mcp.current_role = Role.ASSISTANT
parser_mcp.current_channel = "commentary"
parser_mcp.current_recipient = "filesystem"
mcp_items = parse_remaining_state(parser_mcp)
assert len(mcp_items) == 1
assert isinstance(mcp_items[0], McpCall)
assert mcp_items[0].type == "mcp_call"
assert mcp_items[0].name == "filesystem"
assert mcp_items[0].server_label == "filesystem"
assert mcp_items[0].status == "in_progress"
# Test 3: Built-in tool (python)
# should NOT return MCP call, falls through to reasoning
parser_builtin = Mock()
parser_builtin.current_content = "print('hello')"
parser_builtin.current_role = Role.ASSISTANT
parser_builtin.current_channel = "commentary"
parser_builtin.current_recipient = "python"
builtin_items = parse_remaining_state(parser_builtin)
# Should fall through to reasoning logic
assert len(builtin_items) == 1
assert not isinstance(builtin_items[0], McpCall)
assert builtin_items[0].type == "reasoning"
def test_parse_remaining_state_analysis_channel() -> None:
"""Test parse_remaining_state with analysis channel and various recipients."""
from unittest.mock import Mock
from vllm.entrypoints.openai.parser.harmony_utils import parse_remaining_state
# Test 1: functions.* recipient → should return function tool call
parser_func = Mock()
parser_func.current_content = '{"arg": "value"}'
parser_func.current_role = Role.ASSISTANT
parser_func.current_channel = "analysis"
parser_func.current_recipient = "functions.my_tool"
func_items = parse_remaining_state(parser_func)
assert len(func_items) == 1
assert not isinstance(func_items[0], McpCall)
assert func_items[0].type == "function_call"
assert func_items[0].name == "my_tool"
assert func_items[0].status == "in_progress"
# Test 2: MCP tool (not builtin) → should return MCP call
parser_mcp = Mock()
parser_mcp.current_content = '{"query": "test"}'
parser_mcp.current_role = Role.ASSISTANT
parser_mcp.current_channel = "analysis"
parser_mcp.current_recipient = "database"
mcp_items = parse_remaining_state(parser_mcp)
assert len(mcp_items) == 1
assert isinstance(mcp_items[0], McpCall)
assert mcp_items[0].type == "mcp_call"
assert mcp_items[0].name == "database"
assert mcp_items[0].server_label == "database"
assert mcp_items[0].status == "in_progress"
# Test 3: Built-in tool (container)
# should NOT return MCP call, falls through to reasoning
parser_builtin = Mock()
parser_builtin.current_content = "docker run"
parser_builtin.current_role = Role.ASSISTANT
parser_builtin.current_channel = "analysis"
parser_builtin.current_recipient = "container"
builtin_items = parse_remaining_state(parser_builtin)
# Should fall through to reasoning logic
assert len(builtin_items) == 1
assert not isinstance(builtin_items[0], McpCall)
assert builtin_items[0].type == "reasoning"
......@@ -232,7 +232,7 @@ async def test_server_load(server: RemoteOpenAIServer):
@pytest.mark.asyncio
async def test_health_check_engine_dead_error():
# Import the health function directly to test it in isolation
from vllm.entrypoints.openai.api_server import health
from vllm.entrypoints.serve.instrumentator.health import health
# Create a mock request that simulates what FastAPI would provide
mock_request = Mock(spec=Request)
......
......@@ -69,9 +69,20 @@ async def test_anthropic_streaming(client: anthropic.AsyncAnthropic):
stream=True,
)
first_chunk = None
chunk_count = 0
async for chunk in resp:
chunk_count += 1
if first_chunk is None and chunk.type == "message_start":
first_chunk = chunk
print(chunk.model_dump_json())
assert chunk_count > 0
assert first_chunk is not None, "message_start chunk was never observed"
assert first_chunk.usage is not None, "first chunk should include usage stats"
assert first_chunk.usage["output_tokens"] == 0
assert first_chunk.usage["input_tokens"] > 5
@pytest.mark.asyncio
async def test_anthropic_tool_call(client: anthropic.AsyncAnthropic):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
import json
import pytest
import pytest_asyncio
from openai import OpenAI
from ...utils import RemoteOpenAIServer
MODEL_NAME = "Qwen/Qwen3-8B"
@pytest.fixture(scope="module")
def server():
assert importlib.util.find_spec("gpt_oss") is not None, (
"Harmony tests require gpt_oss package to be installed"
)
args = [
"--reasoning-parser",
"qwen3",
"--max_model_len",
"5000",
"--structured-outputs-config.backend",
"xgrammar",
"--enable-auto-tool-choice",
"--tool-call-parser",
"hermes",
"--tool-server",
"demo",
]
env_dict = dict(
VLLM_ENABLE_RESPONSES_API_STORE="1",
VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT="1",
PYTHON_EXECUTION_BACKEND="dangerously_use_uv",
)
with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_basic(client: OpenAI, model_name: str):
response = await client.responses.create(
model=model_name,
input="What is 13 * 24?",
)
assert response is not None
print("response: ", response)
assert response.status == "completed"
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_reasoning_and_function_items(client: OpenAI, model_name: str):
response = await client.responses.create(
model=model_name,
input=[
{"type": "message", "content": "Hello.", "role": "user"},
{
"type": "reasoning",
"id": "lol",
"content": [
{
"type": "reasoning_text",
"text": "We need to respond: greeting.",
}
],
"summary": [],
},
{
"arguments": '{"location": "Paris", "unit": "celsius"}',
"call_id": "call_5f7b38f3b81e4b8380fd0ba74f3ca3ab",
"name": "get_weather",
"type": "function_call",
"id": "fc_4fe5d6fc5b6c4d6fa5f24cc80aa27f78",
"status": "completed",
},
{
"call_id": "call_5f7b38f3b81e4b8380fd0ba74f3ca3ab",
"id": "fc_4fe5d6fc5b6c4d6fa5f24cc80aa27f78",
"output": "The weather in Paris is 20 Celsius",
"status": "completed",
"type": "function_call_output",
},
],
temperature=0.0,
)
assert response is not None
assert response.status == "completed"
# make sure we get a reasoning and text output
assert response.output[0].type == "reasoning"
assert response.output[1].type == "message"
assert type(response.output[1].content[0].text) is str
def get_horoscope(sign):
return f"{sign}: Next Tuesday you will befriend a baby otter."
def call_function(name, args):
if name == "get_horoscope":
return get_horoscope(**args)
else:
raise ValueError(f"Unknown function: {name}")
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_function_call_first_turn(client: OpenAI, model_name: str):
tools = [
{
"type": "function",
"name": "get_horoscope",
"description": "Get today's horoscope for an astrological sign.",
"parameters": {
"type": "object",
"properties": {
"sign": {"type": "string"},
},
"required": ["sign"],
"additionalProperties": False,
},
"strict": True,
}
]
response = await client.responses.create(
model=model_name,
input="What is the horoscope for Aquarius today?",
tools=tools,
temperature=0.0,
)
assert response is not None
assert response.status == "completed"
assert len(response.output) == 2
assert response.output[0].type == "reasoning"
assert response.output[1].type == "function_call"
function_call = response.output[1]
assert function_call.name == "get_horoscope"
assert function_call.call_id is not None
args = json.loads(function_call.arguments)
assert "sign" in args
# the multi turn function call is tested above in
# test_reasoning_and_function_items
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_call(client: OpenAI, model_name: str):
response = await client.responses.create(
model=model_name,
input="What is 13 * 24? Use python to calculate the result.",
tools=[{"type": "code_interpreter", "container": {"type": "auto"}}],
temperature=0.0,
)
assert response is not None
assert response.status == "completed"
assert response.output[0].type == "reasoning"
assert response.output[1].type == "mcp_call"
assert type(response.output[1].arguments) is str
assert type(response.output[1].output) is str
assert response.output[2].type == "reasoning"
# make sure the correct math is in the final output
assert response.output[3].type == "message"
assert "312" in response.output[3].content[0].text
......@@ -42,6 +42,24 @@ async def test_basic(client: OpenAI, model_name: str):
assert response.status == "completed"
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_enable_response_messages(client: OpenAI, model_name: str):
response = await client.responses.create(
model=model_name,
input="Hello?",
extra_body={"enable_response_messages": True},
)
assert response.status == "completed"
assert response.input_messages[0]["type"] == "raw_message_tokens"
assert type(response.input_messages[0]["message"]) is str
assert len(response.input_messages[0]["message"]) > 10
assert type(response.input_messages[0]["tokens"][0]) is int
assert type(response.output_messages[0]["message"]) is str
assert len(response.output_messages[0]["message"]) > 10
assert type(response.output_messages[0]["tokens"][0]) is int
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_reasoning_item(client: OpenAI, model_name: str):
......
......@@ -726,7 +726,7 @@ async def test_function_calling_required(client: OpenAI, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_system_message_with_tools(client: OpenAI, model_name: str):
from vllm.entrypoints.harmony_utils import get_system_message
from vllm.entrypoints.openai.parser.harmony_utils import get_system_message
# Test with custom tools enabled - commentary channel should be available
sys_msg = get_system_message(with_custom_tools=True)
......
......@@ -32,13 +32,9 @@ async def whisper_client(server):
@pytest.mark.asyncio
async def test_basic_audio(mary_had_lamb):
server_args = ["--enforce-eager"]
async def test_basic_audio(whisper_client, mary_had_lamb):
# Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
with RemoteOpenAIServer(MODEL_NAME, server_args) as remote_server:
client = remote_server.get_async_client()
transcription = await client.audio.transcriptions.create(
transcription = await whisper_client.audio.transcriptions.create(
model=MODEL_NAME,
file=mary_had_lamb,
language="en",
......
......@@ -8,6 +8,7 @@ import pytest
import pytest_asyncio
from transformers import AutoProcessor
from vllm.multimodal.base import MediaWithBytes
from vllm.multimodal.utils import encode_image_base64, fetch_image
from ...utils import RemoteOpenAIServer
......@@ -111,7 +112,11 @@ def get_hf_prompt_tokens(model_name, content, image_url):
"content": f"{placeholder}{content}",
}
]
images = [fetch_image(image_url)]
image = fetch_image(image_url)
# Unwrap MediaWithBytes if present
if isinstance(image, MediaWithBytes):
image = image.media
images = [image]
prompt = processor.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
......
......@@ -2,64 +2,47 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import io
import numpy as np
import pytest
import requests
import torch
from ...utils import RemoteOpenAIServer
from vllm.utils.serial_utils import tensor2base64
MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
DTYPE = "float16"
from ...utils import RemoteOpenAIServer
def _terratorch_dummy_inputs(model_name: str):
def _terratorch_dummy_messages():
pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16)
location_coords = torch.full((1, 2), 1.0, dtype=torch.float16)
buffer_tiff = io.BytesIO()
torch.save(pixel_values, buffer_tiff)
buffer_tiff.seek(0)
binary_data = buffer_tiff.read()
base64_tensor_embedding = base64.b64encode(binary_data).decode("utf-8")
buffer_coord = io.BytesIO()
torch.save(location_coords, buffer_coord)
buffer_coord.seek(0)
binary_data = buffer_coord.read()
base64_coord_embedding = base64.b64encode(binary_data).decode("utf-8")
return {
"model": model_name,
"additional_data": {"prompt_token_ids": [1]},
"encoding_format": "base64",
"messages": [
return [
{
"role": "user",
"content": [
{
"type": "image_embeds",
"image_embeds": {
"pixel_values": base64_tensor_embedding,
"location_coords": base64_coord_embedding,
"pixel_values": tensor2base64(pixel_values),
"location_coords": tensor2base64(location_coords),
},
}
],
}
],
}
]
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_single_request(model_name: str):
@pytest.mark.parametrize(
"model_name", ["ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"]
)
def test_single_request(model_name: str):
args = [
"--runner",
"pooling",
# use half precision for speed and memory savings in CI environment
"--dtype",
DTYPE,
"float16",
"--enforce-eager",
"--trust-remote-code",
"--max-num-seqs",
......@@ -70,11 +53,15 @@ async def test_single_request(model_name: str):
"--enable-mm-embeds",
]
with RemoteOpenAIServer(MODEL_NAME, args) as server:
prompt = _terratorch_dummy_inputs(model_name)
# test single pooling
response = requests.post(server.url_for("pooling"), json=prompt)
with RemoteOpenAIServer(model_name, args) as server:
response = requests.post(
server.url_for("pooling"),
json={
"model": model_name,
"messages": _terratorch_dummy_messages(),
"encoding_format": "base64",
},
)
response.raise_for_status()
output = response.json()["data"][0]["data"]
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import pytest
from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction,
run_tool_extraction_streaming,
)
from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.tokenizers import TokenizerLike
SIMPLE_ARGS_DICT = {
"action": "create",
"id": "preferences",
}
SIMPLE_FUNCTION_JSON = json.dumps(
{
"name": "manage_user_memory",
"arguments": SIMPLE_ARGS_DICT,
},
ensure_ascii=False,
)
SIMPLE_FUNCTION_OUTPUT = "function call" + SIMPLE_FUNCTION_JSON
SIMPLE_FUNCTION_CALL = FunctionCall(
name="manage_user_memory",
arguments=json.dumps(SIMPLE_ARGS_DICT, ensure_ascii=False),
)
PARAMETERLESS_FUNCTION_JSON = json.dumps(
{
"name": "manage_user_memory",
"arguments": {},
},
ensure_ascii=False,
)
PARAMETERLESS_FUNCTION_OUTPUT = "function call" + PARAMETERLESS_FUNCTION_JSON
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
name="manage_user_memory",
arguments=json.dumps({}, ensure_ascii=False),
)
COMPLEX_ARGS_DICT = {
"action": "create",
"id": "preferences",
"content": {
"short_answers": True,
"hate_emojis": True,
"english_ui": False,
"russian_math_explanations": True,
},
}
COMPLEX_FUNCTION_JSON = json.dumps(
{
"name": "manage_user_memory",
"arguments": COMPLEX_ARGS_DICT,
},
ensure_ascii=False,
)
COMPLEX_FUNCTION_OUTPUT = "function call" + COMPLEX_FUNCTION_JSON
COMPLEX_FUNCTION_CALL = FunctionCall(
name="manage_user_memory",
arguments=json.dumps(COMPLEX_ARGS_DICT, ensure_ascii=False),
)
@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")(
default_tokenizer
)
model_output = "How can I help you today?"
content, tool_calls = run_tool_extraction(
tool_parser, model_output, streaming=streaming
)
assert content == model_output
assert len(tool_calls) == 0
TEST_CASES = [
pytest.param(
True,
SIMPLE_FUNCTION_OUTPUT,
[SIMPLE_FUNCTION_CALL],
None,
id="simple_streaming",
),
pytest.param(
False,
SIMPLE_FUNCTION_OUTPUT,
[SIMPLE_FUNCTION_CALL],
None,
id="simple_nonstreaming",
),
pytest.param(
True,
PARAMETERLESS_FUNCTION_OUTPUT,
[PARAMETERLESS_FUNCTION_CALL],
None,
id="parameterless_streaming",
),
pytest.param(
False,
PARAMETERLESS_FUNCTION_OUTPUT,
[PARAMETERLESS_FUNCTION_CALL],
None,
id="parameterless_nonstreaming",
),
pytest.param(
True,
COMPLEX_FUNCTION_OUTPUT,
[COMPLEX_FUNCTION_CALL],
None,
id="complex_streaming",
),
pytest.param(
False,
COMPLEX_FUNCTION_OUTPUT,
[COMPLEX_FUNCTION_CALL],
None,
id="complex_nonstreaming",
),
]
@pytest.mark.parametrize(
"streaming, model_output, expected_tool_calls, expected_content", TEST_CASES
)
def test_tool_call(
streaming: bool,
model_output: str,
expected_tool_calls: list[FunctionCall],
expected_content: str | None,
default_tokenizer: TokenizerLike,
):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")(
default_tokenizer
)
content, tool_calls = run_tool_extraction(
tool_parser, model_output, streaming=streaming
)
assert content == expected_content
assert len(tool_calls) == len(expected_tool_calls)
for actual, expected in zip(tool_calls, expected_tool_calls):
assert actual.type == "function"
assert actual.function.name == expected.name
actual_args = json.loads(actual.function.arguments)
expected_args = json.loads(expected.arguments)
assert actual_args == expected_args
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")(
default_tokenizer
)
model_output_deltas = [
"function call",
COMPLEX_FUNCTION_JSON[:40],
COMPLEX_FUNCTION_JSON[40:],
]
reconstructor = run_tool_extraction_streaming(
tool_parser,
model_output_deltas,
assert_one_tool_per_delta=False,
)
assert len(reconstructor.tool_calls) == 1
call = reconstructor.tool_calls[0]
assert call.type == "function"
assert call.function.name == "manage_user_memory"
args_dict = json.loads(call.function.arguments)
assert args_dict == COMPLEX_ARGS_DICT
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import jsonschema
import openai
import pytest
import pytest_asyncio
from rapidfuzz import fuzz
from ....utils import RemoteOpenAIServer
MODEL_NAME = "openai/gpt-oss-20b"
@pytest.fixture(scope="module")
def server():
args = [
"--max-model-len",
"8192",
"--enforce-eager",
"--enable-auto-tool-choice",
"--tool-call-parser",
"openai",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client(server):
"""Async fixture providing an OpenAI-compatible vLLM client."""
async with server.get_async_client() as async_client:
yield async_client
# ==========================================================
# Tool Definitions
# ==========================================================
TOOLS = [
{
"type": "function",
"function": {
"name": "calculator",
"description": "Performs basic arithmetic calculations.",
"parameters": {
"type": "object",
"properties": {
"expression": {
"type": "string",
"description": (
"Arithmetic expression to evaluate, e.g. '123 + 456'."
),
}
},
"required": ["expression"],
},
},
},
{
"type": "function",
"function": {
"name": "get_time",
"description": "Retrieves the current local time for a given city.",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "City name, e.g. 'New York'.",
}
},
"required": ["city"],
},
},
},
]
# ==========================================================
# Message Examples
# ==========================================================
MESSAGES_CALC = [
{"role": "user", "content": "Calculate 123 + 456 using the calculator."}
]
MESSAGES_GET_TIME = [
{"role": "user", "content": "What is the current time in New York?"}
]
MESSAGES_MULTIPLE_CALLS = [
{
"role": "system",
"content": (
"You can call multiple tools. "
"When using more than one, return single JSON object with tool_calls array"
"containing each tool call with its function name and arguments. "
"Do not output multiple JSON objects separately."
),
},
{
"role": "user",
"content": "First, calculate 7 * 8 using the calculator. "
"Then, use get_time to tell me the current time in New York.",
},
]
MESSAGES_INVALID_CALL = [
{
"role": "user",
"content": "Can you help with something, "
"but don’t actually perform any calculation?",
}
]
# Expected outputs
FUNC_CALC = "calculator"
FUNC_ARGS_CALC = '{"expression":"123 + 456"}'
FUNC_TIME = "get_time"
FUNC_ARGS_TIME = '{"city": "New York"}'
# ==========================================================
# Utility to extract reasoning and tool calls
# ==========================================================
def extract_reasoning_and_calls(chunks: list) -> tuple[str, list[str], list[str]]:
"""
Extract accumulated reasoning text and tool call arguments
from streaming chunks.
"""
reasoning_content: str = ""
tool_calls: dict[int, dict[str, str]] = {}
for chunk in chunks:
choice = getattr(chunk.choices[0], "delta", None)
if not choice:
continue
if hasattr(choice, "reasoning_content") and choice.reasoning_content:
reasoning_content += choice.reasoning_content
for tc in getattr(choice, "tool_calls", []) or []:
idx = getattr(tc, "index", 0)
tool_entry = tool_calls.setdefault(idx, {"name": "", "arguments": ""})
if getattr(tc, "function", None):
func = tc.function
if getattr(func, "name", None):
tool_entry["name"] = func.name
if getattr(func, "arguments", None):
tool_entry["arguments"] += func.arguments
function_names: list[str] = [v["name"] for _, v in sorted(tool_calls.items())]
arguments: list[str] = [v["arguments"] for _, v in sorted(tool_calls.items())]
return reasoning_content, arguments, function_names
# ==========================================================
# Test Scenarios
# ==========================================================
@pytest.mark.asyncio
async def test_calculator_tool_call_and_argument_accuracy(client: openai.AsyncOpenAI):
"""Verify calculator tool call is made and arguments are accurate."""
response = await client.chat.completions.create(
model=MODEL_NAME,
messages=MESSAGES_CALC,
tools=TOOLS,
temperature=0.0,
stream=False,
)
message = response.choices[0].message
tool_calls = getattr(message, "tool_calls", [])
assert tool_calls, "No tool calls detected"
calc_call = next((c for c in tool_calls if c.function.name == FUNC_CALC), None)
assert calc_call, "Calculator function not called"
raw_args = calc_call.function.arguments
assert raw_args, "Calculator arguments missing"
assert "123" in raw_args and "456" in raw_args, (
f"Expected values not in raw arguments: {raw_args}"
)
try:
parsed_args = json.loads(raw_args)
except json.JSONDecodeError:
pytest.fail(f"Invalid JSON in calculator arguments: {raw_args}")
expected_expr = "123 + 456"
actual_expr = parsed_args.get("expression", "")
similarity = fuzz.ratio(actual_expr, expected_expr)
assert similarity > 90, (
f"Expression mismatch: expected '{expected_expr}' "
f"got '{actual_expr}' (similarity={similarity}%)"
)
@pytest.mark.asyncio
async def test_streaming_tool_call_get_time_with_reasoning(client: openai.AsyncOpenAI):
"""Verify streamed reasoning and tool call behavior for get_time."""
stream = await client.chat.completions.create(
model=MODEL_NAME,
messages=MESSAGES_GET_TIME,
tools=TOOLS,
temperature=0.0,
stream=True,
)
chunks = [chunk async for chunk in stream]
reasoning, arguments, function_names = extract_reasoning_and_calls(chunks)
assert FUNC_TIME in function_names, "get_time function not called"
assert any("New York" in arg for arg in arguments), (
f"Expected get_time arguments for New York not found in {arguments}"
)
assert len(reasoning) > 0, "Expected reasoning content missing"
assert any(keyword in reasoning for keyword in ["New York", "time", "current"]), (
f"Reasoning is not relevant to the request: {reasoning}"
)
@pytest.mark.asyncio
async def test_streaming_multiple_tools(client: openai.AsyncOpenAI):
"""Test streamed multi-tool response with reasoning."""
stream = await client.chat.completions.create(
model=MODEL_NAME,
messages=MESSAGES_MULTIPLE_CALLS,
tools=TOOLS,
temperature=0.0,
stream=True,
)
chunks = [chunk async for chunk in stream]
reasoning, arguments, function_names = extract_reasoning_and_calls(chunks)
try:
assert FUNC_CALC in function_names, (
f"Calculator tool missing — found {function_names}"
)
assert FUNC_TIME in function_names, (
f"Time tool missing — found {function_names}"
)
assert len(reasoning) > 0, "Expected reasoning content in streamed response"
except AssertionError as e:
print(f"ERROR: {e}")
@pytest.mark.asyncio
async def test_invalid_tool_call(client: openai.AsyncOpenAI):
"""
Verify that ambiguous instructions that should not trigger a tool
do not produce any tool calls.
"""
response = await client.chat.completions.create(
model=MODEL_NAME,
messages=MESSAGES_INVALID_CALL,
tools=TOOLS,
temperature=0.0,
stream=False,
)
message = response.choices[0].message
assert message is not None, "Expected message in response"
assert hasattr(message, "content"), "Expected 'content' field in message"
tool_calls = getattr(message, "tool_calls", [])
assert not tool_calls, (
f"Model unexpectedly attempted a tool call on invalid input: {tool_calls}"
)
@pytest.mark.asyncio
async def test_tool_call_with_temperature(client: openai.AsyncOpenAI):
"""
Verify model produces valid tool or text output
under non-deterministic sampling.
"""
response = await client.chat.completions.create(
model=MODEL_NAME,
messages=MESSAGES_CALC,
tools=TOOLS,
temperature=0.7,
stream=False,
)
message = response.choices[0].message
assert message is not None, "Expected non-empty message in response"
assert message.tool_calls or message.content, (
"Response missing both text and tool calls"
)
print(f"\nTool calls: {message.tool_calls}")
print(f"Text: {message.content}")
@pytest.mark.asyncio
async def test_tool_response_schema_accuracy(client: openai.AsyncOpenAI):
"""Validate that tool call arguments adhere to their declared JSON schema."""
response = await client.chat.completions.create(
model=MODEL_NAME,
messages=MESSAGES_MULTIPLE_CALLS,
tools=TOOLS,
temperature=0.0,
)
calls = response.choices[0].message.tool_calls
assert calls, "No tool calls produced"
for call in calls:
func_name = call.function.name
args = json.loads(call.function.arguments)
schema: dict[str, object] | None = None
for tool_entry in TOOLS:
function_def = tool_entry.get("function")
if (
function_def
and isinstance(function_def, dict)
and function_def.get("name") == func_name
):
schema = function_def.get("parameters")
break
assert schema is not None, f"No matching tool schema found for {func_name}"
jsonschema.validate(instance=args, schema=schema)
@pytest.mark.asyncio
async def test_semantic_consistency_with_temperature(client: openai.AsyncOpenAI):
"""Test that temperature variation doesn't cause contradictory reasoning."""
responses = []
for temp in [0.0, 0.5, 1.0]:
resp = await client.chat.completions.create(
model=MODEL_NAME,
messages=MESSAGES_CALC,
tools=TOOLS,
temperature=temp,
)
text = (resp.choices[0].message.content or "").strip()
responses.append(text)
# Compare fuzzy similarity between low- and mid-temperature outputs
low_mid_sim = fuzz.ratio(responses[0], responses[1])
assert low_mid_sim > 60, (
f"Semantic drift too large between T=0.0 and T=0.5 ({low_mid_sim}%)"
)
......@@ -61,10 +61,7 @@ def test_pooling_params(llm: LLM):
@pytest.mark.skip_global_cleanup
def test_encode_api(llm: LLM):
# chunked prefill does not support all pooling
err_msg = "pooling_task must be one of.+"
with pytest.raises(ValueError, match=err_msg):
def test_token_classify(llm: LLM):
llm.encode(prompts, pooling_task="token_classify", use_tqdm=False)
......
......@@ -255,21 +255,21 @@ async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str):
# token_classify uses ALL pooling, which does not support chunked prefill.
task = "token_classify"
input_text = ["This product was excellent and exceeded my expectations"]
response = requests.post(
server.url_for("pooling"),
json={
"model": model_name,
"input": "test",
"input": input_text,
"encoding_format": "float",
"task": task,
},
)
assert response.json()["error"]["type"] == "BadRequestError"
assert response.json()["error"]["message"].startswith(
f"Task {task} is not supported"
)
poolings = PoolingResponse.model_validate(response.json())
assert len(poolings.data) == 1
assert len(poolings.data[0].data) == 8
assert len(poolings.data[0].data[0]) == 2
@pytest.mark.asyncio
......
......@@ -42,7 +42,7 @@ def llm():
@pytest.mark.skip_global_cleanup
def test_encode_api(llm: LLM):
def test_token_embed(llm: LLM):
outputs = llm.encode(prompts, pooling_task="token_embed", use_tqdm=False)
multi_vector = outputs[0].outputs.data
assert multi_vector.shape == (11, 384)
......
......@@ -24,6 +24,7 @@ from vllm.utils.serial_utils import (
ENDIANNESS,
MetadataItem,
binary2tensor,
build_metadata_items,
decode_pooling_output,
)
......@@ -344,6 +345,55 @@ async def test_bytes_embed_dtype_and_endianness(
)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_bytes_only_embed_dtype_and_endianness(
server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str
):
input_texts = [
"The best thing about vLLM is that it supports many different models",
] * 2
responses_float = await client.embeddings.create(
input=input_texts, model=model_name, encoding_format="float"
)
float_data = [d.embedding for d in responses_float.data]
embedding_size = len(float_data[0])
for embed_dtype in list(EMBED_DTYPE_TO_TORCH_DTYPE.keys()):
for endianness in ENDIANNESS:
responses_bytes = requests.post(
server.url_for("/v1/embeddings"),
json={
"model": model_name,
"input": input_texts,
"encoding_format": "bytes_only",
"embed_dtype": embed_dtype,
"endianness": endianness,
},
)
assert "metadata" not in responses_bytes.headers
body = responses_bytes.content
items = build_metadata_items(
embed_dtype=embed_dtype,
endianness=endianness,
shape=(embedding_size,),
n_request=len(input_texts),
)
bytes_data = decode_pooling_output(items=items, body=body)
bytes_data = [x.to(torch.float32).tolist() for x in bytes_data]
check_embeddings_close(
embeddings_0_lst=float_data,
embeddings_1_lst=bytes_data,
name_0="float_data",
name_1="bytes_data",
tol=1e-2,
)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("param_name", ["encoding_format", "embed_dtype", "endianness"])
......
......@@ -9,6 +9,7 @@ from transformers import AutoProcessor
from tests.utils import VLLM_PATH, RemoteOpenAIServer
from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse
from vllm.multimodal.base import MediaWithBytes
from vllm.multimodal.utils import encode_image_base64, fetch_image
MODEL_NAME = "TIGER-Lab/VLM2Vec-Full"
......@@ -62,7 +63,11 @@ def get_hf_prompt_tokens(model_name, content, image_url):
placeholder = "<|image_1|> "
prompt = f"{placeholder}{content}"
images = [fetch_image(image_url)]
image = fetch_image(image_url)
# Unwrap MediaWithBytes if present
if isinstance(image, MediaWithBytes):
image = image.media
images = [image]
inputs = processor(prompt, images, return_tensors="pt")
return inputs.input_ids.shape[1]
......
......@@ -18,6 +18,7 @@ from vllm.utils.serial_utils import (
ENDIANNESS,
MetadataItem,
binary2tensor,
build_metadata_items,
decode_pooling_output,
)
......@@ -352,6 +353,61 @@ async def test_bytes_embed_dtype_and_endianness(
)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_bytes_only_embed_dtype_and_endianness(
server: RemoteOpenAIServer, model_name: str
):
input_texts = [
"The best thing about vLLM is that it supports many different models",
] * 2
url = server.url_for("pooling")
float_response = requests.post(
url,
json={
"model": model_name,
"input": input_texts,
"encoding_format": "float",
},
)
responses_float = PoolingResponse.model_validate(float_response.json())
float_data = [np.array(d.data).squeeze(-1).tolist() for d in responses_float.data]
n_tokens = responses_float.usage.prompt_tokens // len(input_texts)
for embed_dtype in list(EMBED_DTYPE_TO_TORCH_DTYPE.keys()):
for endianness in ENDIANNESS:
responses_bytes = requests.post(
url,
json={
"model": model_name,
"input": input_texts,
"encoding_format": "bytes_only",
"embed_dtype": embed_dtype,
"endianness": endianness,
},
)
assert "metadata" not in responses_bytes.headers
body = responses_bytes.content
items = build_metadata_items(
embed_dtype=embed_dtype,
endianness=endianness,
shape=(n_tokens, 1),
n_request=len(input_texts),
)
bytes_data = decode_pooling_output(items=items, body=body)
bytes_data = [x.to(torch.float32).view(-1).tolist() for x in bytes_data]
check_embeddings_close(
embeddings_0_lst=float_data,
embeddings_1_lst=bytes_data,
name_0="float_data",
name_1="bytes_data",
tol=1e-2,
)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("param_name", ["encoding_format", "embed_dtype", "endianness"])
......
......@@ -36,6 +36,13 @@ def llm():
cleanup_dist_env_and_memory()
@pytest.mark.skip_global_cleanup
def test_config(llm: LLM):
vllm_config = llm.llm_engine.vllm_config
assert vllm_config.cache_config.enable_prefix_caching
assert vllm_config.scheduler_config.enable_chunked_prefill
def test_pooling_params(llm: LLM):
def get_outputs(use_activation):
outputs = llm.reward(
......
......@@ -6,6 +6,7 @@ from collections.abc import Mapping
from typing import Literal
import pytest
import torch
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from vllm.assets.audio import AudioAsset
......@@ -29,6 +30,7 @@ from vllm.multimodal.utils import (
encode_video_base64,
)
from vllm.tokenizers import MistralTokenizer, get_tokenizer
from vllm.utils.serial_utils import tensor2base64
from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import VLLM_PATH
......@@ -85,11 +87,6 @@ def phi3v_model_config_image_embeds():
)
@pytest.fixture(scope="module")
def phi3v_tokenizer():
return get_tokenizer(PHI3V_MODEL_ID)
@pytest.fixture(scope="function")
def qwen2_audio_model_config():
return ModelConfig(
......@@ -115,11 +112,6 @@ def audio_embeds_model_config():
)
@pytest.fixture(scope="module")
def qwen2_audio_tokenizer():
return get_tokenizer(QWEN2AUDIO_MODEL_ID)
@pytest.fixture(scope="function")
def qwen25omni_model_config_mm_interleaved():
return ModelConfig(
......@@ -134,11 +126,6 @@ def qwen25omni_model_config_mm_interleaved():
)
@pytest.fixture(scope="module")
def qwen25omni_tokenizer():
return get_tokenizer(QWEN25OMNI_MODEL_ID)
@pytest.fixture(scope="function")
def mistral_model_config():
return ModelConfig(
......@@ -150,11 +137,6 @@ def mistral_model_config():
)
@pytest.fixture(scope="module")
def mistral_tokenizer():
return get_tokenizer(MISTRAL_MODEL_ID)
@pytest.fixture(scope="module")
def image_url():
image = ImageAsset("cherry_blossom")
......@@ -239,7 +221,6 @@ def _assert_mm_data_inputs(
def test_parse_chat_messages_single_image(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
conversation, mm_data, mm_uuids = parse_chat_messages(
......@@ -253,7 +234,6 @@ def test_parse_chat_messages_single_image(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
......@@ -266,7 +246,6 @@ def test_parse_chat_messages_single_image(
def test_parse_chat_messages_single_image_with_uuid(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
image_uuid = str(hash(image_url))
......@@ -287,7 +266,6 @@ def test_parse_chat_messages_single_image_with_uuid(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
......@@ -300,7 +278,6 @@ def test_parse_chat_messages_single_image_with_uuid(
def test_parse_chat_messages_single_empty_image_with_uuid(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
image_uuid = str(hash(image_url))
......@@ -319,7 +296,6 @@ def test_parse_chat_messages_single_empty_image_with_uuid(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
......@@ -332,7 +308,6 @@ def test_parse_chat_messages_single_empty_image_with_uuid(
def test_parse_chat_messages_single_image_with_bad_uuid_format(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
image_uuid = str(hash(image_url))
......@@ -354,7 +329,6 @@ def test_parse_chat_messages_single_image_with_bad_uuid_format(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
......@@ -367,7 +341,6 @@ def test_parse_chat_messages_single_image_with_bad_uuid_format(
def test_parse_chat_messages_multiple_images_with_uuids(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
image_uuid1 = "my_uuid_1"
......@@ -397,7 +370,6 @@ def test_parse_chat_messages_multiple_images_with_uuids(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
......@@ -413,7 +385,6 @@ def test_parse_chat_messages_multiple_images_with_uuids(
def test_parse_chat_messages_multiple_empty_images_with_uuids(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
image_uuid1 = "my_uuid_1"
......@@ -439,7 +410,6 @@ def test_parse_chat_messages_multiple_empty_images_with_uuids(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
......@@ -455,7 +425,6 @@ def test_parse_chat_messages_multiple_empty_images_with_uuids(
def test_parse_chat_messages_mixed_empty_images_with_uuids(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
image_uuid1 = "my_uuid_1"
......@@ -483,7 +452,6 @@ def test_parse_chat_messages_mixed_empty_images_with_uuids(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
......@@ -500,7 +468,6 @@ def test_parse_chat_messages_mixed_empty_images_with_uuids(
@pytest.mark.asyncio
async def test_parse_chat_messages_single_image_with_uuid_async(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
image_uuid = str(hash(image_url))
......@@ -519,7 +486,6 @@ async def test_parse_chat_messages_single_image_with_uuid_async(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
......@@ -533,7 +499,6 @@ async def test_parse_chat_messages_single_image_with_uuid_async(
@pytest.mark.asyncio
async def test_parse_chat_messages_empty_image_with_uuid_async(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
image_uuid = str(hash(image_url))
......@@ -552,7 +517,6 @@ async def test_parse_chat_messages_empty_image_with_uuid_async(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
......@@ -566,7 +530,6 @@ async def test_parse_chat_messages_empty_image_with_uuid_async(
@pytest.mark.asyncio
async def test_parse_chat_messages_multiple_images_with_uuids_async(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
image_uuid1 = "my_uuid_1"
......@@ -592,7 +555,6 @@ async def test_parse_chat_messages_multiple_images_with_uuids_async(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
......@@ -609,7 +571,6 @@ async def test_parse_chat_messages_multiple_images_with_uuids_async(
@pytest.mark.asyncio
async def test_parse_chat_messages_multiple_empty_images_with_uuids_async(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
image_uuid1 = "my_uuid_1"
......@@ -635,7 +596,6 @@ async def test_parse_chat_messages_multiple_empty_images_with_uuids_async(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
......@@ -652,7 +612,6 @@ async def test_parse_chat_messages_multiple_empty_images_with_uuids_async(
@pytest.mark.asyncio
async def test_parse_chat_messages_multiple_images_with_partial_uuids_async(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
image_uuid2 = "my_uuid_2"
......@@ -676,7 +635,6 @@ async def test_parse_chat_messages_multiple_images_with_partial_uuids_async(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
......@@ -692,7 +650,6 @@ async def test_parse_chat_messages_multiple_images_with_partial_uuids_async(
def test_parse_chat_messages_empty_system(
mistral_model_config,
mistral_tokenizer,
):
# Test string format
conversation, _, _ = parse_chat_messages(
......@@ -704,7 +661,6 @@ def test_parse_chat_messages_empty_system(
},
],
mistral_model_config,
mistral_tokenizer,
content_format="string",
)
assert conversation == [
......@@ -722,7 +678,6 @@ def test_parse_chat_messages_empty_system(
},
],
mistral_model_config,
mistral_tokenizer,
content_format="openai",
)
assert conversation == [
......@@ -734,7 +689,6 @@ def test_parse_chat_messages_empty_system(
@pytest.mark.asyncio
async def test_parse_chat_messages_single_image_async(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
......@@ -748,7 +702,6 @@ async def test_parse_chat_messages_single_image_async(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
......@@ -761,7 +714,6 @@ async def test_parse_chat_messages_single_image_async(
def test_parse_chat_messages_multiple_images(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
conversation, mm_data, mm_uuids = parse_chat_messages(
......@@ -779,7 +731,6 @@ def test_parse_chat_messages_multiple_images(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
......@@ -795,7 +746,6 @@ def test_parse_chat_messages_multiple_images(
def test_parse_chat_messages_empty_pil_image_with_uuid(
phi3v_model_config,
phi3v_tokenizer,
):
uuid = "abcd"
conversation, mm_data, mm_uuids = parse_chat_messages(
......@@ -809,7 +759,6 @@ def test_parse_chat_messages_empty_pil_image_with_uuid(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
......@@ -825,7 +774,6 @@ def test_parse_chat_messages_empty_pil_image_with_uuid(
def test_parse_chat_messages_empty_image_embeds_with_uuid(
phi3v_model_config_image_embeds,
phi3v_tokenizer,
):
uuid = "abcd"
conversation, mm_data, mm_uuids = parse_chat_messages(
......@@ -839,7 +787,6 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid(
}
],
phi3v_model_config_image_embeds,
phi3v_tokenizer,
content_format="string",
)
......@@ -857,7 +804,6 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid(
def test_parse_chat_messages_empty_audio_embeds_with_uuid(
audio_embeds_model_config,
qwen2_audio_tokenizer,
):
"""Test audio_embeds with UUID (no actual embeds data)."""
uuid = "test-audio-uuid-123"
......@@ -873,7 +819,6 @@ def test_parse_chat_messages_empty_audio_embeds_with_uuid(
}
],
audio_embeds_model_config,
qwen2_audio_tokenizer,
content_format="string",
)
......@@ -889,11 +834,8 @@ def test_parse_chat_messages_empty_audio_embeds_with_uuid(
def test_parse_chat_messages_audio_embeds_with_string(
audio_embeds_model_config,
qwen2_audio_tokenizer,
):
"""Test audio_embeds with base64 string embedding data."""
import base64
import io
import torch
......@@ -901,11 +843,7 @@ def test_parse_chat_messages_audio_embeds_with_string(
audio_embedding = torch.randn(1, 128, 768)
# Encode it as base64
buffer = io.BytesIO()
torch.save(audio_embedding, buffer)
buffer.seek(0)
binary_data = buffer.read()
base64_audio_embedding = base64.b64encode(binary_data).decode("utf-8")
base64_audio_embedding = tensor2base64(audio_embedding)
conversation, mm_data, mm_uuids = parse_chat_messages(
[
......@@ -921,7 +859,6 @@ def test_parse_chat_messages_audio_embeds_with_string(
}
],
audio_embeds_model_config,
qwen2_audio_tokenizer,
content_format="string",
)
......@@ -939,11 +876,8 @@ def test_parse_chat_messages_audio_embeds_with_string(
@pytest.mark.asyncio
async def test_parse_chat_messages_audio_embeds_async(
audio_embeds_model_config,
qwen2_audio_tokenizer,
):
"""Test audio_embeds with async futures."""
import base64
import io
import torch
......@@ -951,11 +885,7 @@ async def test_parse_chat_messages_audio_embeds_async(
audio_embedding = torch.randn(1, 128, 768)
# Encode it as base64
buffer = io.BytesIO()
torch.save(audio_embedding, buffer)
buffer.seek(0)
binary_data = buffer.read()
base64_audio_embedding = base64.b64encode(binary_data).decode("utf-8")
base64_audio_embedding = tensor2base64(audio_embedding)
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
[
......@@ -971,7 +901,6 @@ async def test_parse_chat_messages_audio_embeds_async(
}
],
audio_embeds_model_config,
qwen2_audio_tokenizer,
content_format="string",
)
......@@ -987,10 +916,186 @@ async def test_parse_chat_messages_audio_embeds_async(
_assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[None])
def test_parse_chat_messages_multiple_image_embeds(
phi3v_model_config_image_embeds,
):
"""Test that multiple image_embeds in a single message are now supported.
This test validates the fix for the limitation that previously only allowed
one message with {'type': 'image_embeds'}. Now multiple image embeddings
can be provided in a single request, similar to regular images.
"""
# Create two sample image embedding tensors
image_embedding_1 = torch.randn(256, 1024)
image_embedding_2 = torch.randn(128, 1024)
# Encode them as base64 using the convenience function
base64_image_embedding_1 = tensor2base64(image_embedding_1)
base64_image_embedding_2 = tensor2base64(image_embedding_2)
conversation, mm_data, mm_uuids = parse_chat_messages(
[
{
"role": "user",
"content": [
{
"type": "image_embeds",
"image_embeds": base64_image_embedding_1,
},
{
"type": "image_embeds",
"image_embeds": base64_image_embedding_2,
},
{"type": "text", "text": "Describe these two images."},
],
}
],
phi3v_model_config_image_embeds,
content_format="string",
)
# Verify conversation structure
assert conversation == [
{
"role": "user",
"content": "<|image_1|>\n<|image_2|>\nDescribe these two images.",
}
]
# Verify mm_data contains a list of embeddings (not a single embedding)
assert mm_data is not None
assert "image" in mm_data
assert isinstance(mm_data["image"], list)
assert len(mm_data["image"]) == 2
# Verify each embedding has the correct shape
assert isinstance(mm_data["image"][0], torch.Tensor)
assert mm_data["image"][0].shape == image_embedding_1.shape
assert isinstance(mm_data["image"][1], torch.Tensor)
assert mm_data["image"][1].shape == image_embedding_2.shape
# Verify UUIDs (None since we didn't provide any)
_assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None])
def test_parse_chat_messages_multiple_image_embeds_with_uuids(
phi3v_model_config_image_embeds,
):
"""Test multiple image_embeds with UUIDs.
This validates that UUIDs are properly tracked for multiple embeddings.
"""
uuid1 = "image-uuid-1"
uuid2 = "image-uuid-2"
conversation, mm_data, mm_uuids = parse_chat_messages(
[
{
"role": "user",
"content": [
{
"type": "image_embeds",
"image_embeds": None,
"uuid": uuid1,
},
{
"type": "image_embeds",
"image_embeds": None,
"uuid": uuid2,
},
{"type": "text", "text": "Compare these images."},
],
}
],
phi3v_model_config_image_embeds,
content_format="string",
)
# Verify conversation structure
assert conversation == [
{
"role": "user",
"content": "<|image_1|>\n<|image_2|>\nCompare these images.",
}
]
# Verify mm_data contains a list with None values (UUID references)
assert mm_data is not None
assert "image" in mm_data
assert isinstance(mm_data["image"], list)
assert len(mm_data["image"]) == 2
assert mm_data["image"][0] is None
assert mm_data["image"][1] is None
# Verify UUIDs are correctly tracked
_assert_mm_uuids(mm_uuids, 2, expected_uuids=[uuid1, uuid2])
@pytest.mark.asyncio
async def test_parse_chat_messages_multiple_image_embeds_async(
phi3v_model_config_image_embeds,
):
"""Test multiple image_embeds with async parsing.
This validates the AsyncMultiModalItemTracker also supports multiple embeddings.
"""
# Create two sample image embedding tensors
image_embedding_1 = torch.randn(200, 768)
image_embedding_2 = torch.randn(150, 768)
# Encode them as base64 using the convenience function
base64_image_embedding_1 = tensor2base64(image_embedding_1)
base64_image_embedding_2 = tensor2base64(image_embedding_2)
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
[
{
"role": "user",
"content": [
{
"type": "image_embeds",
"image_embeds": base64_image_embedding_1,
},
{
"type": "image_embeds",
"image_embeds": base64_image_embedding_2,
},
{"type": "text", "text": "What do these images show?"},
],
}
],
phi3v_model_config_image_embeds,
content_format="string",
)
# Verify conversation structure
assert conversation == [
{
"role": "user",
"content": "<|image_1|>\n<|image_2|>\nWhat do these images show?",
}
]
# Await the future and verify mm_data
mm_data = await mm_future
assert mm_data is not None
assert "image" in mm_data
assert isinstance(mm_data["image"], list)
assert len(mm_data["image"]) == 2
# Verify each embedding has the correct shape
assert isinstance(mm_data["image"][0], torch.Tensor)
assert mm_data["image"][0].shape == image_embedding_1.shape
assert isinstance(mm_data["image"][1], torch.Tensor)
assert mm_data["image"][1].shape == image_embedding_2.shape
# Verify UUIDs
_assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None])
@pytest.mark.asyncio
async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
phi3v_model_config_image_embeds,
phi3v_tokenizer,
):
uuid = "abcd"
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
......@@ -1004,7 +1109,6 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
}
],
phi3v_model_config_image_embeds,
phi3v_tokenizer,
content_format="string",
)
......@@ -1024,7 +1128,6 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
@pytest.mark.asyncio
async def test_parse_chat_messages_multiple_images_async(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
......@@ -1042,7 +1145,6 @@ async def test_parse_chat_messages_multiple_images_async(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
......@@ -1058,7 +1160,6 @@ async def test_parse_chat_messages_multiple_images_async(
def test_parse_chat_messages_placeholder_already_in_prompt(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
conversation, mm_data, mm_uuids = parse_chat_messages(
......@@ -1076,7 +1177,6 @@ def test_parse_chat_messages_placeholder_already_in_prompt(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
assert conversation == [
......@@ -1091,7 +1191,6 @@ def test_parse_chat_messages_placeholder_already_in_prompt(
def test_parse_chat_messages_placeholder_one_already_in_prompt(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
conversation, mm_data, mm_uuids = parse_chat_messages(
......@@ -1110,7 +1209,6 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
......@@ -1127,7 +1225,6 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt(
def test_parse_chat_messages_multiple_images_across_messages(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
conversation, mm_data, mm_uuids = parse_chat_messages(
......@@ -1149,7 +1246,6 @@ def test_parse_chat_messages_multiple_images_across_messages(
},
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
......@@ -1164,7 +1260,6 @@ def test_parse_chat_messages_multiple_images_across_messages(
def test_parse_chat_messages_multiple_images_with_uuids_across_messages(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
image_uuid = str(hash(image_url))
......@@ -1195,7 +1290,6 @@ def test_parse_chat_messages_multiple_images_with_uuids_across_messages(
},
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
......@@ -1210,7 +1304,6 @@ def test_parse_chat_messages_multiple_images_with_uuids_across_messages(
def test_parse_chat_messages_context_text_format(
phi3v_model_config,
phi3v_tokenizer,
):
conversation, mm_data, mm_uuids = parse_chat_messages(
[
......@@ -1222,7 +1315,6 @@ def test_parse_chat_messages_context_text_format(
{"role": "user", "content": "What about this one?"},
],
phi3v_model_config,
phi3v_tokenizer,
content_format="openai",
)
......@@ -1246,7 +1338,6 @@ def test_parse_chat_messages_context_text_format(
def test_parse_chat_messages_rejects_too_many_images_in_one_message(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
with warnings.catch_warnings():
......@@ -1277,14 +1368,12 @@ def test_parse_chat_messages_rejects_too_many_images_in_one_message(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
def test_parse_chat_messages_rejects_too_many_images_across_messages(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
with warnings.catch_warnings():
......@@ -1322,14 +1411,12 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages(
},
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
def test_parse_chat_messages_multiple_images_uncommon_input(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
conversation, mm_data, mm_uuids = parse_chat_messages(
......@@ -1344,7 +1431,6 @@ def test_parse_chat_messages_multiple_images_uncommon_input(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
......@@ -1360,7 +1446,6 @@ def test_parse_chat_messages_multiple_images_uncommon_input(
def test_parse_chat_messages_multiple_images_interleave(
phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
image_url,
):
conversation, mm_data, mm_uuids = parse_chat_messages(
......@@ -1380,7 +1465,6 @@ def test_parse_chat_messages_multiple_images_interleave(
}
],
phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
content_format="string",
)
......@@ -1398,7 +1482,6 @@ def test_parse_chat_messages_multiple_images_interleave(
@pytest.mark.asyncio
async def test_parse_chat_messages_multiple_images_interleave_async(
phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
image_url,
):
conversation, mm_data, mm_uuids = parse_chat_messages_futures(
......@@ -1418,7 +1501,6 @@ async def test_parse_chat_messages_multiple_images_interleave_async(
}
],
phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
content_format="string",
)
......@@ -1436,7 +1518,6 @@ async def test_parse_chat_messages_multiple_images_interleave_async(
@pytest.mark.asyncio
async def test_parse_chat_messages_multiple_images_with_uuids_interleave_async(
phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
image_url,
):
image_uuid = str(hash(image_url))
......@@ -1465,7 +1546,6 @@ async def test_parse_chat_messages_multiple_images_with_uuids_interleave_async(
}
],
phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
content_format="string",
)
......@@ -1482,7 +1562,6 @@ async def test_parse_chat_messages_multiple_images_with_uuids_interleave_async(
def test_parse_chat_messages_multiple_images_multiple_messages_interleave(
phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
image_url,
):
conversation, mm_data, mm_uuids = parse_chat_messages(
......@@ -1505,7 +1584,6 @@ def test_parse_chat_messages_multiple_images_multiple_messages_interleave(
},
],
phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
content_format="string",
)
......@@ -1523,7 +1601,6 @@ def test_parse_chat_messages_multiple_images_multiple_messages_interleave(
def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interleave(
phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
image_url,
):
image_uuid = str(hash(image_url))
......@@ -1555,7 +1632,6 @@ def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interl
},
],
phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
content_format="string",
)
......@@ -1573,7 +1649,6 @@ def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interl
def test_parse_chat_messages_multiple_modals_multiple_messages_interleave(
qwen25omni_model_config_mm_interleaved,
qwen25omni_tokenizer,
image_url,
video_url,
audio_url,
......@@ -1601,7 +1676,6 @@ def test_parse_chat_messages_multiple_modals_multiple_messages_interleave(
},
],
qwen25omni_model_config_mm_interleaved,
qwen25omni_tokenizer,
content_format="string",
)
......@@ -1627,7 +1701,6 @@ def test_parse_chat_messages_multiple_modals_multiple_messages_interleave(
def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interleave(
qwen25omni_model_config_mm_interleaved,
qwen25omni_tokenizer,
image_url,
video_url,
audio_url,
......@@ -1671,7 +1744,6 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interl
},
],
qwen25omni_model_config_mm_interleaved,
qwen25omni_tokenizer,
content_format="string",
)
......@@ -1699,7 +1771,6 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interl
def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_messages_interleave( # noqa: E501
qwen25omni_model_config_mm_interleaved,
qwen25omni_tokenizer,
image_url,
video_url,
audio_url,
......@@ -1743,7 +1814,6 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_mes
},
],
qwen25omni_model_config_mm_interleaved,
qwen25omni_tokenizer,
content_format="string",
)
......@@ -1775,7 +1845,6 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_mes
def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_messages_interleave( # noqa: E501
qwen25omni_model_config_mm_interleaved,
qwen25omni_tokenizer,
image_url,
video_url,
audio_url,
......@@ -1811,7 +1880,6 @@ def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_message
},
],
qwen25omni_model_config_mm_interleaved,
qwen25omni_tokenizer,
content_format="string",
)
......@@ -1837,7 +1905,6 @@ def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_message
def test_parse_chat_messages_multiple_images_interleave_with_placeholders(
phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
image_url,
):
with pytest.raises(
......@@ -1861,7 +1928,6 @@ def test_parse_chat_messages_multiple_images_interleave_with_placeholders(
}
],
phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
content_format="string",
)
......@@ -2238,9 +2304,7 @@ def test_resolve_content_format_examples(template_path, expected_format):
assert resolved_format == expected_format
def test_parse_chat_messages_include_thinking_chunk(
mistral_model_config, mistral_tokenizer
):
def test_parse_chat_messages_include_thinking_chunk(mistral_model_config):
messages = [
{
"role": "system",
......@@ -2270,7 +2334,6 @@ def test_parse_chat_messages_include_thinking_chunk(
conversation_with_thinking, _, _ = parse_chat_messages(
messages,
mistral_model_config,
mistral_tokenizer,
content_format="openai",
)
......@@ -2354,7 +2417,6 @@ def test_apply_mistral_chat_template_thinking_chunk():
def test_parse_chat_messages_single_empty_audio_with_uuid(
qwen2_audio_model_config,
qwen2_audio_tokenizer,
):
audio_uuid = "abcd"
conversation, mm_data, mm_uuids = parse_chat_messages(
......@@ -2372,7 +2434,6 @@ def test_parse_chat_messages_single_empty_audio_with_uuid(
}
],
qwen2_audio_model_config,
qwen2_audio_tokenizer,
content_format="string",
)
......@@ -2390,7 +2451,6 @@ def test_parse_chat_messages_single_empty_audio_with_uuid(
@pytest.mark.asyncio
async def test_parse_chat_messages_single_empty_audio_with_uuid_async(
qwen2_audio_model_config,
qwen2_audio_tokenizer,
):
audio_uuid = "abcd"
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
......@@ -2408,7 +2468,6 @@ async def test_parse_chat_messages_single_empty_audio_with_uuid_async(
}
],
qwen2_audio_model_config,
qwen2_audio_tokenizer,
content_format="string",
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment