Commit 8d75f22e authored by zhuwenwen
Browse files

Merge tag 'v0.13.0rc1' into v0.13.0rc1-ori

parents ce888aa4 7d80c73d
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from openai_harmony import Role from openai.types.responses import ResponseFunctionToolCall, ResponseReasoningItem
from openai.types.responses.response_output_item import McpCall
from openai_harmony import Author, Message, Role, TextContent
from vllm.entrypoints.harmony_utils import ( from vllm.entrypoints.openai.parser.harmony_utils import (
has_custom_tools, has_custom_tools,
parse_input_to_harmony_message, parse_input_to_harmony_message,
parse_output_message,
) )
...@@ -257,6 +260,193 @@ class TestParseInputToHarmonyMessage: ...@@ -257,6 +260,193 @@ class TestParseInputToHarmonyMessage:
assert messages[0].content[1].text == "actual text" assert messages[0].content[1].text == "actual text"
class TestParseOutputMessage:
    """Checks which output item types ``parse_output_message`` produces."""

    def test_commentary_with_no_recipient_creates_reasoning(self):
        """A commentary message without a recipient yields one reasoning item.

        No recipient is set on the message, which models a preamble:
        explanatory text emitted ahead of function calls.
        """
        text = "I will now search for the weather information."
        # with_recipient() is deliberately not called here.
        message = Message.from_role_and_content(Role.ASSISTANT, text).with_channel(
            "commentary"
        )

        items = parse_output_message(message)

        assert len(items) == 1
        item = items[0]
        assert isinstance(item, ResponseReasoningItem)
        assert item.type == "reasoning"
        assert item.content[0].text == text
        assert item.content[0].type == "reasoning_text"

    def test_commentary_with_function_recipient_creates_function_call(self):
        """A ``functions.X`` recipient on commentary yields a function call."""
        arguments = '{"location": "San Francisco", "units": "celsius"}'
        message = (
            Message.from_role_and_content(Role.ASSISTANT, arguments)
            .with_channel("commentary")
            .with_recipient("functions.get_weather")
        )

        items = parse_output_message(message)

        assert len(items) == 1
        item = items[0]
        assert isinstance(item, ResponseFunctionToolCall)
        assert item.type == "function_call"
        assert item.name == "get_weather"
        assert item.arguments == arguments
        assert item.call_id.startswith("call_")
        assert item.id.startswith("fc_")

    def test_commentary_with_python_recipient_creates_reasoning(self):
        """A ``python`` recipient on commentary yields a reasoning item."""
        code = "import numpy as np\nprint(np.array([1, 2, 3]))"
        message = (
            Message.from_role_and_content(Role.ASSISTANT, code)
            .with_channel("commentary")
            .with_recipient("python")
        )

        items = parse_output_message(message)

        assert len(items) == 1
        item = items[0]
        assert isinstance(item, ResponseReasoningItem)
        assert item.type == "reasoning"
        assert item.content[0].text == code

    def test_commentary_with_browser_recipient_creates_reasoning(self):
        """A ``browser`` recipient on commentary yields a reasoning item."""
        text = "Navigating to the specified URL"
        message = (
            Message.from_role_and_content(Role.ASSISTANT, text)
            .with_channel("commentary")
            .with_recipient("browser")
        )

        items = parse_output_message(message)

        assert len(items) == 1
        item = items[0]
        assert isinstance(item, ResponseReasoningItem)
        assert item.type == "reasoning"
        assert item.content[0].text == text

    def test_commentary_with_container_recipient_creates_reasoning(self):
        """A ``container`` recipient on commentary yields a reasoning item."""
        text = "Running command in container"
        message = (
            Message.from_role_and_content(Role.ASSISTANT, text)
            .with_channel("commentary")
            .with_recipient("container")
        )

        items = parse_output_message(message)

        assert len(items) == 1
        item = items[0]
        assert isinstance(item, ResponseReasoningItem)
        assert item.type == "reasoning"
        assert item.content[0].text == text

    def test_commentary_with_empty_content_and_no_recipient(self):
        """Edge case: empty commentary text with no recipient set."""
        message = Message.from_role_and_content(Role.ASSISTANT, "").with_channel(
            "commentary"
        )

        items = parse_output_message(message)

        assert len(items) == 1
        item = items[0]
        assert isinstance(item, ResponseReasoningItem)
        assert item.content[0].text == ""

    def test_commentary_with_multiple_contents_and_no_recipient(self):
        """Each content entry of a recipient-less commentary becomes reasoning."""
        steps = [
            "Step 1: Analyze the request",
            "Step 2: Prepare to call functions",
        ]
        message = Message.from_role_and_contents(
            Role.ASSISTANT, [TextContent(text=step) for step in steps]
        ).with_channel("commentary")

        items = parse_output_message(message)

        assert len(items) == 2
        assert all(isinstance(item, ResponseReasoningItem) for item in items)
        assert items[0].content[0].text == steps[0]
        assert items[1].content[0].text == steps[1]

    def test_commentary_with_multiple_function_calls(self):
        """Each content entry sent to a function becomes its own call item."""
        payloads = [
            '{"location": "San Francisco"}',
            '{"location": "New York"}',
        ]
        message = (
            Message.from_role_and_contents(
                Role.ASSISTANT, [TextContent(text=payload) for payload in payloads]
            )
            .with_channel("commentary")
            .with_recipient("functions.get_weather")
        )

        items = parse_output_message(message)

        assert len(items) == 2
        assert all(isinstance(item, ResponseFunctionToolCall) for item in items)
        assert items[0].name == "get_weather"
        assert items[1].name == "get_weather"
        assert items[0].arguments == payloads[0]
        assert items[1].arguments == payloads[1]

    def test_commentary_with_unknown_recipient_creates_mcp_call(self):
        """An unrecognised recipient on commentary yields an MCP call."""
        message = (
            Message.from_role_and_content(Role.ASSISTANT, '{"arg": "value"}')
            .with_channel("commentary")
            .with_recipient("custom_tool")
        )

        items = parse_output_message(message)

        assert len(items) == 1
        item = items[0]
        assert isinstance(item, McpCall)
        assert item.type == "mcp_call"
        assert item.name == "custom_tool"
        assert item.server_label == "custom_tool"

    def test_analysis_channel_creates_reasoning(self):
        """A message on the analysis channel yields a reasoning item."""
        text = "Analyzing the problem step by step..."
        message = Message.from_role_and_content(Role.ASSISTANT, text).with_channel(
            "analysis"
        )

        items = parse_output_message(message)

        assert len(items) == 1
        item = items[0]
        assert isinstance(item, ResponseReasoningItem)
        assert item.type == "reasoning"
        assert item.content[0].text == text

    def test_non_assistant_message_returns_empty(self):
        """A message authored by a tool produces no output items.

        The message here is authored with ``Role.TOOL`` (a tool result such
        as a weather lookup); the expectation is an empty list.
        """
        tool_author = Author.new(Role.TOOL, "functions.get_weather")
        message = Message.from_author_and_content(
            tool_author, "The weather is sunny, 72°F"
        )

        items = parse_output_message(message)

        assert len(items) == 0
def test_has_custom_tools() -> None: def test_has_custom_tools() -> None:
assert not has_custom_tools(set()) assert not has_custom_tools(set())
assert not has_custom_tools({"web_search_preview", "code_interpreter", "container"}) assert not has_custom_tools({"web_search_preview", "code_interpreter", "container"})
...@@ -264,3 +454,167 @@ def test_has_custom_tools() -> None: ...@@ -264,3 +454,167 @@ def test_has_custom_tools() -> None:
assert has_custom_tools( assert has_custom_tools(
{"web_search_preview", "code_interpreter", "container", "others"} {"web_search_preview", "code_interpreter", "container", "others"}
) )
def test_parse_mcp_call_basic() -> None:
    """An undotted, non-builtin recipient is reported as a completed MCP call."""
    payload = '{"path": "/tmp"}'
    message = (
        Message.from_role_and_content(Role.ASSISTANT, payload)
        .with_recipient("filesystem")
        .with_channel("commentary")
    )

    items = parse_output_message(message)

    assert len(items) == 1
    call = items[0]
    assert isinstance(call, McpCall)
    assert call.type == "mcp_call"
    # With no dot in the recipient, name and server label are both the recipient.
    assert call.name == "filesystem"
    assert call.server_label == "filesystem"
    assert call.arguments == payload
    assert call.status == "completed"
def test_parse_mcp_call_dotted_recipient() -> None:
    """A ``server.tool`` recipient is split into server label and tool name."""
    message = (
        Message.from_role_and_content(Role.ASSISTANT, '{"cmd": "ls"}')
        .with_recipient("repo_browser.list")
        .with_channel("commentary")
    )

    items = parse_output_message(message)

    assert len(items) == 1
    call = items[0]
    assert isinstance(call, McpCall)
    assert call.name == "list"
    assert call.server_label == "repo_browser"
def test_mcp_vs_function_call() -> None:
    """A ``functions.*`` recipient must yield a function call, not an MCP call."""
    message = (
        Message.from_role_and_content(Role.ASSISTANT, '{"arg": "value"}')
        .with_recipient("functions.my_tool")
        .with_channel("commentary")
    )

    items = parse_output_message(message)

    assert len(items) == 1
    item = items[0]
    assert not isinstance(item, McpCall)
    assert item.type == "function_call"
def test_mcp_vs_builtin_tools() -> None:
    """Test that built-in tools (python, container) are not parsed as MCP calls."""
    # Exercise every built-in recipient the docstring names; each one is
    # expected to be reported as reasoning rather than as an MCP call.
    builtin_cases = (
        ("python", "print('hello')"),
        ("container", "Running command in container"),
    )
    for recipient, content in builtin_cases:
        message = Message.from_role_and_content(Role.ASSISTANT, content)
        message = message.with_recipient(recipient)
        message = message.with_channel("commentary")

        items = parse_output_message(message)

        # The recipient is attached as the assertion message so a failure
        # identifies which built-in tool regressed.
        assert len(items) == 1, recipient
        assert not isinstance(items[0], McpCall), recipient
        assert items[0].type == "reasoning", recipient
def test_parse_remaining_state_commentary_channel() -> None:
    """Check ``parse_remaining_state`` on the commentary channel per recipient."""
    from unittest.mock import Mock

    from vllm.entrypoints.openai.parser.harmony_utils import parse_remaining_state

    def unfinished_parser(content: str, recipient: str) -> Mock:
        # Stand-in exposing only the parser attributes this test sets.
        return Mock(
            current_content=content,
            current_role=Role.ASSISTANT,
            current_channel="commentary",
            current_recipient=recipient,
        )

    # Case 1: a functions.* recipient is reported as an in-progress function call.
    function_items = parse_remaining_state(
        unfinished_parser('{"arg": "value"}', "functions.my_tool")
    )
    assert len(function_items) == 1
    function_item = function_items[0]
    assert not isinstance(function_item, McpCall)
    assert function_item.type == "function_call"
    assert function_item.name == "my_tool"
    assert function_item.status == "in_progress"

    # Case 2: a non-builtin tool recipient is reported as an in-progress MCP call.
    mcp_items = parse_remaining_state(
        unfinished_parser('{"path": "/tmp"}', "filesystem")
    )
    assert len(mcp_items) == 1
    mcp_item = mcp_items[0]
    assert isinstance(mcp_item, McpCall)
    assert mcp_item.type == "mcp_call"
    assert mcp_item.name == "filesystem"
    assert mcp_item.server_label == "filesystem"
    assert mcp_item.status == "in_progress"

    # Case 3: the built-in python recipient must not become an MCP call;
    # it is expected to surface as reasoning instead.
    builtin_items = parse_remaining_state(
        unfinished_parser("print('hello')", "python")
    )
    assert len(builtin_items) == 1
    builtin_item = builtin_items[0]
    assert not isinstance(builtin_item, McpCall)
    assert builtin_item.type == "reasoning"
def test_parse_remaining_state_analysis_channel() -> None:
    """Check ``parse_remaining_state`` on the analysis channel per recipient."""
    from unittest.mock import Mock

    from vllm.entrypoints.openai.parser.harmony_utils import parse_remaining_state

    def unfinished_parser(content: str, recipient: str) -> Mock:
        # Stand-in exposing only the parser attributes this test sets.
        return Mock(
            current_content=content,
            current_role=Role.ASSISTANT,
            current_channel="analysis",
            current_recipient=recipient,
        )

    # Case 1: a functions.* recipient is reported as an in-progress function call.
    function_items = parse_remaining_state(
        unfinished_parser('{"arg": "value"}', "functions.my_tool")
    )
    assert len(function_items) == 1
    function_item = function_items[0]
    assert not isinstance(function_item, McpCall)
    assert function_item.type == "function_call"
    assert function_item.name == "my_tool"
    assert function_item.status == "in_progress"

    # Case 2: a non-builtin tool recipient is reported as an in-progress MCP call.
    mcp_items = parse_remaining_state(
        unfinished_parser('{"query": "test"}', "database")
    )
    assert len(mcp_items) == 1
    mcp_item = mcp_items[0]
    assert isinstance(mcp_item, McpCall)
    assert mcp_item.type == "mcp_call"
    assert mcp_item.name == "database"
    assert mcp_item.server_label == "database"
    assert mcp_item.status == "in_progress"

    # Case 3: the built-in container recipient must not become an MCP call;
    # it is expected to surface as reasoning instead.
    builtin_items = parse_remaining_state(
        unfinished_parser("docker run", "container")
    )
    assert len(builtin_items) == 1
    builtin_item = builtin_items[0]
    assert not isinstance(builtin_item, McpCall)
    assert builtin_item.type == "reasoning"
...@@ -232,7 +232,7 @@ async def test_server_load(server: RemoteOpenAIServer): ...@@ -232,7 +232,7 @@ async def test_server_load(server: RemoteOpenAIServer):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_health_check_engine_dead_error(): async def test_health_check_engine_dead_error():
# Import the health function directly to test it in isolation # Import the health function directly to test it in isolation
from vllm.entrypoints.openai.api_server import health from vllm.entrypoints.serve.instrumentator.health import health
# Create a mock request that simulates what FastAPI would provide # Create a mock request that simulates what FastAPI would provide
mock_request = Mock(spec=Request) mock_request = Mock(spec=Request)
......
...@@ -69,9 +69,20 @@ async def test_anthropic_streaming(client: anthropic.AsyncAnthropic): ...@@ -69,9 +69,20 @@ async def test_anthropic_streaming(client: anthropic.AsyncAnthropic):
stream=True, stream=True,
) )
first_chunk = None
chunk_count = 0
async for chunk in resp: async for chunk in resp:
chunk_count += 1
if first_chunk is None and chunk.type == "message_start":
first_chunk = chunk
print(chunk.model_dump_json()) print(chunk.model_dump_json())
assert chunk_count > 0
assert first_chunk is not None, "message_start chunk was never observed"
assert first_chunk.usage is not None, "first chunk should include usage stats"
assert first_chunk.usage["output_tokens"] == 0
assert first_chunk.usage["input_tokens"] > 5
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_anthropic_tool_call(client: anthropic.AsyncAnthropic): async def test_anthropic_tool_call(client: anthropic.AsyncAnthropic):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
import json
import pytest
import pytest_asyncio
from openai import OpenAI
from ...utils import RemoteOpenAIServer
# Model launched by the ``server`` fixture and used to parametrize the tests.
MODEL_NAME = "Qwen/Qwen3-8B"
@pytest.fixture(scope="module")
def server():
    """Start one server for the whole module and yield its handle.

    The server is launched for ``MODEL_NAME`` with the qwen3 reasoning
    parser, the hermes tool-call parser, the xgrammar structured-output
    backend and the demo tool server. The fixture asserts up front that
    the ``gpt_oss`` package can be found.
    """
    # ``import importlib`` alone does not guarantee that the ``util``
    # submodule has been loaded, so import it explicitly before using it.
    import importlib.util

    assert importlib.util.find_spec("gpt_oss") is not None, (
        "Harmony tests require gpt_oss package to be installed"
    )

    args = [
        "--reasoning-parser",
        "qwen3",
        "--max_model_len",
        "5000",
        "--structured-outputs-config.backend",
        "xgrammar",
        "--enable-auto-tool-choice",
        "--tool-call-parser",
        "hermes",
        "--tool-server",
        "demo",
    ]
    # Environment variables passed to the server process.
    env_dict = {
        "VLLM_ENABLE_RESPONSES_API_STORE": "1",
        "VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT": "1",
        "PYTHON_EXECUTION_BACKEND": "dangerously_use_uv",
    }

    with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server:
        yield remote_server
@pytest_asyncio.fixture
async def client(server):
    """Yield an async client obtained from the module-scoped server."""
    async with server.get_async_client() as api_client:
        yield api_client
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_basic(client: OpenAI, model_name: str):
    """A plain text prompt should produce a completed response."""
    request = {"model": model_name, "input": "What is 13 * 24?"}
    response = await client.responses.create(**request)

    assert response is not None
    print("response: ", response)
    assert response.status == "completed"
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_reasoning_and_function_items(client: OpenAI, model_name: str):
    """Replay a conversation holding reasoning, a function call and its output.

    The input already contains a user message, a reasoning item, a completed
    ``get_weather`` call and the matching call output; the response is
    expected to start with a reasoning item followed by a text message.
    """
    # The call and its output share the same call_id / id pair.
    call_id = "call_5f7b38f3b81e4b8380fd0ba74f3ca3ab"
    item_id = "fc_4fe5d6fc5b6c4d6fa5f24cc80aa27f78"

    user_message = {"type": "message", "content": "Hello.", "role": "user"}
    prior_reasoning = {
        "type": "reasoning",
        "id": "lol",
        "content": [
            {
                "type": "reasoning_text",
                "text": "We need to respond: greeting.",
            }
        ],
        "summary": [],
    }
    prior_call = {
        "arguments": '{"location": "Paris", "unit": "celsius"}',
        "call_id": call_id,
        "name": "get_weather",
        "type": "function_call",
        "id": item_id,
        "status": "completed",
    }
    prior_call_output = {
        "call_id": call_id,
        "id": item_id,
        "output": "The weather in Paris is 20 Celsius",
        "status": "completed",
        "type": "function_call_output",
    }

    response = await client.responses.create(
        model=model_name,
        input=[user_message, prior_reasoning, prior_call, prior_call_output],
        temperature=0.0,
    )

    assert response is not None
    assert response.status == "completed"
    # make sure we get a reasoning and text output
    reasoning_item, message_item = response.output[0], response.output[1]
    assert reasoning_item.type == "reasoning"
    assert message_item.type == "message"
    assert type(message_item.content[0].text) is str
def get_horoscope(sign):
    """Return the canned horoscope sentence for ``sign``."""
    prediction = f"{sign}: Next Tuesday you will befriend a baby otter."
    return prediction
def call_function(name, args):
    """Dispatch a tool call by name.

    ``args`` is unpacked as keyword arguments for the target function.
    Raises ``ValueError`` for any name other than ``"get_horoscope"``.
    """
    if name != "get_horoscope":
        raise ValueError(f"Unknown function: {name}")
    return get_horoscope(**args)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_function_call_first_turn(client: OpenAI, model_name: str):
    """The first turn should be a reasoning item followed by a function call."""
    horoscope_parameters = {
        "type": "object",
        "properties": {
            "sign": {"type": "string"},
        },
        "required": ["sign"],
        "additionalProperties": False,
    }
    horoscope_tool = {
        "type": "function",
        "name": "get_horoscope",
        "description": "Get today's horoscope for an astrological sign.",
        "parameters": horoscope_parameters,
        "strict": True,
    }

    response = await client.responses.create(
        model=model_name,
        input="What is the horoscope for Aquarius today?",
        tools=[horoscope_tool],
        temperature=0.0,
    )

    assert response is not None
    assert response.status == "completed"

    assert len(response.output) == 2
    reasoning_item, function_call = response.output
    assert reasoning_item.type == "reasoning"
    assert function_call.type == "function_call"

    assert function_call.name == "get_horoscope"
    assert function_call.call_id is not None

    # The arguments must be valid JSON that carries the required key.
    parsed_arguments = json.loads(function_call.arguments)
    assert "sign" in parsed_arguments

    # the multi turn function call is tested above in
    # test_reasoning_and_function_items
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_call(client: OpenAI, model_name: str):
    """Request a python calculation and check the sequence of output items.

    Expected order: reasoning, mcp_call, reasoning, then the final message
    containing the product.
    """
    code_interpreter = {"type": "code_interpreter", "container": {"type": "auto"}}
    response = await client.responses.create(
        model=model_name,
        input="What is 13 * 24? Use python to calculate the result.",
        tools=[code_interpreter],
        temperature=0.0,
    )

    assert response is not None
    assert response.status == "completed"

    output = response.output
    assert output[0].type == "reasoning"

    tool_call = output[1]
    assert tool_call.type == "mcp_call"
    assert type(tool_call.arguments) is str
    assert type(tool_call.output) is str

    assert output[2].type == "reasoning"

    # make sure the correct math is in the final output
    final_message = output[3]
    assert final_message.type == "message"
    assert "312" in final_message.content[0].text
...@@ -42,6 +42,24 @@ async def test_basic(client: OpenAI, model_name: str): ...@@ -42,6 +42,24 @@ async def test_basic(client: OpenAI, model_name: str):
assert response.status == "completed" assert response.status == "completed"
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_enable_response_messages(client: OpenAI, model_name: str):
    """With ``enable_response_messages`` set, raw messages are returned.

    The first input and output message entries must each carry a string
    ``message`` longer than 10 characters and integer ``tokens``.
    """
    response = await client.responses.create(
        model=model_name,
        input="Hello?",
        extra_body={"enable_response_messages": True},
    )
    assert response.status == "completed"

    first_input = response.input_messages[0]
    assert first_input["type"] == "raw_message_tokens"
    assert type(first_input["message"]) is str
    assert len(first_input["message"]) > 10
    assert type(first_input["tokens"][0]) is int

    first_output = response.output_messages[0]
    assert type(first_output["message"]) is str
    assert len(first_output["message"]) > 10
    assert type(first_output["tokens"][0]) is int
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_reasoning_item(client: OpenAI, model_name: str): async def test_reasoning_item(client: OpenAI, model_name: str):
......
...@@ -726,7 +726,7 @@ async def test_function_calling_required(client: OpenAI, model_name: str): ...@@ -726,7 +726,7 @@ async def test_function_calling_required(client: OpenAI, model_name: str):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_system_message_with_tools(client: OpenAI, model_name: str): async def test_system_message_with_tools(client: OpenAI, model_name: str):
from vllm.entrypoints.harmony_utils import get_system_message from vllm.entrypoints.openai.parser.harmony_utils import get_system_message
# Test with custom tools enabled - commentary channel should be available # Test with custom tools enabled - commentary channel should be available
sys_msg = get_system_message(with_custom_tools=True) sys_msg = get_system_message(with_custom_tools=True)
......
...@@ -32,24 +32,20 @@ async def whisper_client(server): ...@@ -32,24 +32,20 @@ async def whisper_client(server):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_basic_audio(mary_had_lamb): async def test_basic_audio(whisper_client, mary_had_lamb):
server_args = ["--enforce-eager"]
# Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb. # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
with RemoteOpenAIServer(MODEL_NAME, server_args) as remote_server: transcription = await whisper_client.audio.transcriptions.create(
client = remote_server.get_async_client() model=MODEL_NAME,
transcription = await client.audio.transcriptions.create( file=mary_had_lamb,
model=MODEL_NAME, language="en",
file=mary_had_lamb, response_format="text",
language="en", temperature=0.0,
response_format="text", )
temperature=0.0, out = json.loads(transcription)
) out_text = out["text"]
out = json.loads(transcription) out_usage = out["usage"]
out_text = out["text"] assert "Mary had a little lamb," in out_text
out_usage = out["usage"] assert out_usage["seconds"] == 16, out_usage["seconds"]
assert "Mary had a little lamb," in out_text
assert out_usage["seconds"] == 16, out_usage["seconds"]
@pytest.mark.asyncio @pytest.mark.asyncio
......
...@@ -8,6 +8,7 @@ import pytest ...@@ -8,6 +8,7 @@ import pytest
import pytest_asyncio import pytest_asyncio
from transformers import AutoProcessor from transformers import AutoProcessor
from vllm.multimodal.base import MediaWithBytes
from vllm.multimodal.utils import encode_image_base64, fetch_image from vllm.multimodal.utils import encode_image_base64, fetch_image
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
...@@ -111,7 +112,11 @@ def get_hf_prompt_tokens(model_name, content, image_url): ...@@ -111,7 +112,11 @@ def get_hf_prompt_tokens(model_name, content, image_url):
"content": f"{placeholder}{content}", "content": f"{placeholder}{content}",
} }
] ]
images = [fetch_image(image_url)] image = fetch_image(image_url)
# Unwrap MediaWithBytes if present
if isinstance(image, MediaWithBytes):
image = image.media
images = [image]
prompt = processor.tokenizer.apply_chat_template( prompt = processor.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True messages, tokenize=False, add_generation_prompt=True
......
...@@ -2,64 +2,47 @@ ...@@ -2,64 +2,47 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64 import base64
import io
import numpy as np import numpy as np
import pytest import pytest
import requests import requests
import torch import torch
from ...utils import RemoteOpenAIServer from vllm.utils.serial_utils import tensor2base64
MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11" from ...utils import RemoteOpenAIServer
DTYPE = "float16"
def _terratorch_dummy_inputs(model_name: str): def _terratorch_dummy_messages():
pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16) pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16)
location_coords = torch.full((1, 2), 1.0, dtype=torch.float16) location_coords = torch.full((1, 2), 1.0, dtype=torch.float16)
buffer_tiff = io.BytesIO() return [
torch.save(pixel_values, buffer_tiff) {
buffer_tiff.seek(0) "role": "user",
binary_data = buffer_tiff.read() "content": [
base64_tensor_embedding = base64.b64encode(binary_data).decode("utf-8") {
"type": "image_embeds",
buffer_coord = io.BytesIO() "image_embeds": {
torch.save(location_coords, buffer_coord) "pixel_values": tensor2base64(pixel_values),
buffer_coord.seek(0) "location_coords": tensor2base64(location_coords),
binary_data = buffer_coord.read() },
base64_coord_embedding = base64.b64encode(binary_data).decode("utf-8") }
],
return { }
"model": model_name, ]
"additional_data": {"prompt_token_ids": [1]},
"encoding_format": "base64",
"messages": [
{
"role": "user",
"content": [
{
"type": "image_embeds",
"image_embeds": {
"pixel_values": base64_tensor_embedding,
"location_coords": base64_coord_embedding,
},
}
],
}
],
}
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize(
async def test_single_request(model_name: str): "model_name", ["ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"]
)
def test_single_request(model_name: str):
args = [ args = [
"--runner", "--runner",
"pooling", "pooling",
# use half precision for speed and memory savings in CI environment # use half precision for speed and memory savings in CI environment
"--dtype", "--dtype",
DTYPE, "float16",
"--enforce-eager", "--enforce-eager",
"--trust-remote-code", "--trust-remote-code",
"--max-num-seqs", "--max-num-seqs",
...@@ -70,11 +53,15 @@ async def test_single_request(model_name: str): ...@@ -70,11 +53,15 @@ async def test_single_request(model_name: str):
"--enable-mm-embeds", "--enable-mm-embeds",
] ]
with RemoteOpenAIServer(MODEL_NAME, args) as server: with RemoteOpenAIServer(model_name, args) as server:
prompt = _terratorch_dummy_inputs(model_name) response = requests.post(
server.url_for("pooling"),
# test single pooling json={
response = requests.post(server.url_for("pooling"), json=prompt) "model": model_name,
"messages": _terratorch_dummy_messages(),
"encoding_format": "base64",
},
)
response.raise_for_status() response.raise_for_status()
output = response.json()["data"][0]["data"] output = response.json()["data"][0]["data"]
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import pytest
from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction,
run_tool_extraction_streaming,
)
from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.tokenizers import TokenizerLike
def _build_fixture(arguments):
    """Return ``(function JSON, raw model output, expected FunctionCall)``.

    All three describe one ``manage_user_memory`` call with ``arguments``;
    the raw model output is the JSON prefixed with ``"function call"``.
    """
    function_json = json.dumps(
        {"name": "manage_user_memory", "arguments": arguments},
        ensure_ascii=False,
    )
    expected_call = FunctionCall(
        name="manage_user_memory",
        arguments=json.dumps(arguments, ensure_ascii=False),
    )
    return function_json, "function call" + function_json, expected_call


# Call with two flat string arguments.
SIMPLE_ARGS_DICT = {
    "action": "create",
    "id": "preferences",
}
(
    SIMPLE_FUNCTION_JSON,
    SIMPLE_FUNCTION_OUTPUT,
    SIMPLE_FUNCTION_CALL,
) = _build_fixture(SIMPLE_ARGS_DICT)

# Call with an empty arguments object.
(
    PARAMETERLESS_FUNCTION_JSON,
    PARAMETERLESS_FUNCTION_OUTPUT,
    PARAMETERLESS_FUNCTION_CALL,
) = _build_fixture({})

# Call whose arguments contain a nested object of booleans.
COMPLEX_ARGS_DICT = {
    "action": "create",
    "id": "preferences",
    "content": {
        "short_answers": True,
        "hate_emojis": True,
        "english_ui": False,
        "russian_math_explanations": True,
    },
}
(
    COMPLEX_FUNCTION_JSON,
    COMPLEX_FUNCTION_OUTPUT,
    COMPLEX_FUNCTION_CALL,
) = _build_fixture(COMPLEX_ARGS_DICT)
@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
    """Plain text output passes through as content with no tool calls."""
    parser_cls = ToolParserManager.get_tool_parser("gigachat3")
    tool_parser: ToolParser = parser_cls(default_tokenizer)
    plain_text = "How can I help you today?"

    content, tool_calls = run_tool_extraction(
        tool_parser, plain_text, streaming=streaming
    )

    assert content == plain_text
    assert len(tool_calls) == 0
# One pytest.param per (fixture, mode) pair, in the order simple,
# parameterless, complex, with the streaming variant first. Each param is
# (streaming, model_output, expected_tool_calls, expected_content).
TEST_CASES = [
    pytest.param(
        streaming,
        model_output,
        [expected_call],
        None,
        id=f"{label}_{mode}",
    )
    for label, model_output, expected_call in (
        ("simple", SIMPLE_FUNCTION_OUTPUT, SIMPLE_FUNCTION_CALL),
        ("parameterless", PARAMETERLESS_FUNCTION_OUTPUT, PARAMETERLESS_FUNCTION_CALL),
        ("complex", COMPLEX_FUNCTION_OUTPUT, COMPLEX_FUNCTION_CALL),
    )
    for streaming, mode in ((True, "streaming"), (False, "nonstreaming"))
]
@pytest.mark.parametrize(
    "streaming, model_output, expected_tool_calls, expected_content", TEST_CASES
)
def test_tool_call(
    streaming: bool,
    model_output: str,
    expected_tool_calls: list[FunctionCall],
    expected_content: str | None,
    default_tokenizer: TokenizerLike,
):
    """Extracted tool calls match the expected names and arguments.

    Arguments are compared after JSON decoding, so only the decoded values
    have to agree, not the exact serialized text.
    """
    parser_cls = ToolParserManager.get_tool_parser("gigachat3")
    tool_parser: ToolParser = parser_cls(default_tokenizer)

    content, tool_calls = run_tool_extraction(
        tool_parser, model_output, streaming=streaming
    )

    assert content == expected_content
    assert len(tool_calls) == len(expected_tool_calls)
    for extracted, wanted in zip(tool_calls, expected_tool_calls):
        assert extracted.type == "function"
        assert extracted.function.name == wanted.name
        extracted_args = json.loads(extracted.function.arguments)
        wanted_args = json.loads(wanted.arguments)
        assert extracted_args == wanted_args
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
    """Streaming in three coarse deltas still yields a single complete call.

    The deltas are the ``"function call"`` prefix followed by the complex
    function JSON cut in two at character 40.
    """
    parser_cls = ToolParserManager.get_tool_parser("gigachat3")
    tool_parser: ToolParser = parser_cls(default_tokenizer)

    split_at = 40
    deltas = [
        "function call",
        COMPLEX_FUNCTION_JSON[:split_at],
        COMPLEX_FUNCTION_JSON[split_at:],
    ]

    reconstructor = run_tool_extraction_streaming(
        tool_parser,
        deltas,
        assert_one_tool_per_delta=False,
    )

    assert len(reconstructor.tool_calls) == 1
    tool_call = reconstructor.tool_calls[0]
    assert tool_call.type == "function"
    assert tool_call.function.name == "manage_user_memory"
    decoded_args = json.loads(tool_call.function.arguments)
    assert decoded_args == COMPLEX_ARGS_DICT
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import jsonschema
import openai
import pytest
import pytest_asyncio
from rapidfuzz import fuzz
from ....utils import RemoteOpenAIServer
MODEL_NAME = "openai/gpt-oss-20b"
@pytest.fixture(scope="module")
def server():
    """Yield a module-scoped ``RemoteOpenAIServer`` running ``MODEL_NAME``.

    The server is started with automatic tool choice enabled and the
    ``openai`` tool-call parser selected.
    """
    engine_flags = ["--max-model-len", "8192", "--enforce-eager"]
    tool_flags = ["--enable-auto-tool-choice", "--tool-call-parser", "openai"]
    cli_args = engine_flags + tool_flags

    with RemoteOpenAIServer(MODEL_NAME, cli_args) as running_server:
        yield running_server
@pytest_asyncio.fixture
async def client(server):
    """Yield the async client obtained from the ``server`` fixture.

    The client is created through ``server.get_async_client()`` and is used
    as an async context manager for the duration of the test.
    """
    async with server.get_async_client() as vllm_client:
        yield vllm_client
# ==========================================================
# Tool Definitions
# ==========================================================
# Tool taking one required string argument, ``expression``.
_CALCULATOR_TOOL = {
    "type": "function",
    "function": {
        "name": "calculator",
        "description": "Performs basic arithmetic calculations.",
        "parameters": {
            "type": "object",
            "properties": {
                "expression": {
                    "type": "string",
                    "description": "Arithmetic expression to evaluate, e.g. '123 + 456'.",
                },
            },
            "required": ["expression"],
        },
    },
}

# Tool taking one required string argument, ``city``.
_GET_TIME_TOOL = {
    "type": "function",
    "function": {
        "name": "get_time",
        "description": "Retrieves the current local time for a given city.",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {
                    "type": "string",
                    "description": "City name, e.g. 'New York'.",
                },
            },
            "required": ["city"],
        },
    },
}

# Tool list passed as ``tools=`` to every chat-completion request below.
TOOLS = [_CALCULATOR_TOOL, _GET_TIME_TOOL]
# ==========================================================
# Message Examples
# ==========================================================
# Single-turn prompt that explicitly asks for the calculator tool.
MESSAGES_CALC = [
    dict(role="user", content="Calculate 123 + 456 using the calculator."),
]

# Single-turn prompt asking for the current time in a city.
MESSAGES_GET_TIME = [
    dict(role="user", content="What is the current time in New York?"),
]
# Two-message conversation (system + user) that asks for both tools in one
# turn.  Adjacent string literals are concatenated implicitly, so every
# fragment except the last must end with a space; otherwise words run
# together in the prompt that is sent to the model.
MESSAGES_MULTIPLE_CALLS = [
    {
        "role": "system",
        "content": (
            "You can call multiple tools. "
            "When using more than one, return single JSON object with tool_calls array "
            "containing each tool call with its function name and arguments. "
            "Do not output multiple JSON objects separately."
        ),
    },
    {
        "role": "user",
        "content": "First, calculate 7 * 8 using the calculator. "
        "Then, use get_time to tell me the current time in New York.",
    },
]
# Single-turn prompt that asks for help but explicitly not for a calculation.
MESSAGES_INVALID_CALL = [
    dict(
        role="user",
        content="Can you help with something, but don’t actually perform any calculation?",
    ),
]
# Expected outputs: the function names the tests look for in tool calls,
# plus example serialized argument payloads for each tool.
# NOTE(review): FUNC_ARGS_CALC and FUNC_ARGS_TIME do not appear to be
# referenced by the tests visible here — confirm whether they are still needed.
FUNC_CALC = "calculator"
FUNC_ARGS_CALC = '{"expression":"123 + 456"}'
FUNC_TIME = "get_time"
FUNC_ARGS_TIME = '{"city": "New York"}'
# ==========================================================
# Utility to extract reasoning and tool calls
# ==========================================================
def extract_reasoning_and_calls(chunks: list) -> tuple[str, list[str], list[str]]:
    """Accumulate reasoning text and tool calls from streaming chunks.

    Each chunk is expected to expose ``choices``; only the first choice's
    ``delta`` is inspected.  Chunks with no choices (for example a trailing
    usage-only chunk) or no delta are skipped instead of raising.

    Tool-call fragments are grouped by their ``index`` attribute (0 when
    absent): the function name is overwritten by the latest non-empty value
    and argument fragments are concatenated in arrival order.

    Args:
        chunks: Streamed chat-completion chunks, in arrival order.

    Returns:
        A ``(reasoning_content, arguments, function_names)`` tuple.  The two
        lists are parallel and ordered by ascending tool-call index.
    """
    reasoning_parts: list[str] = []
    tool_calls: dict[int, dict[str, str]] = {}

    for chunk in chunks:
        choices = getattr(chunk, "choices", None)
        if not choices:
            # Nothing to read from a chunk that carries no choices.
            continue
        delta = getattr(choices[0], "delta", None)
        if not delta:
            continue

        reasoning_piece = getattr(delta, "reasoning_content", None)
        if reasoning_piece:
            reasoning_parts.append(reasoning_piece)

        # ``tool_calls`` may be missing or explicitly None on a delta.
        for tc in getattr(delta, "tool_calls", None) or []:
            idx = getattr(tc, "index", 0)
            tool_entry = tool_calls.setdefault(idx, {"name": "", "arguments": ""})
            func = getattr(tc, "function", None)
            if not func:
                continue
            name = getattr(func, "name", None)
            if name:
                tool_entry["name"] = name
            args_piece = getattr(func, "arguments", None)
            if args_piece:
                tool_entry["arguments"] += args_piece

    # Sort once by index and derive both parallel lists from the same order.
    ordered = [tool_calls[idx] for idx in sorted(tool_calls)]
    arguments: list[str] = [entry["arguments"] for entry in ordered]
    function_names: list[str] = [entry["name"] for entry in ordered]
    return "".join(reasoning_parts), arguments, function_names
# ==========================================================
# Test Scenarios
# ==========================================================
@pytest.mark.asyncio
async def test_calculator_tool_call_and_argument_accuracy(client: openai.AsyncOpenAI):
    """Non-streaming request must call ``calculator`` with the right expression.

    Checks that a calculator tool call exists, that its raw arguments mention
    both operands, that they are valid JSON, and that the ``expression``
    field is a close fuzzy match for ``"123 + 456"``.
    """
    response = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=MESSAGES_CALC,
        tools=TOOLS,
        temperature=0.0,
        stream=False,
    )
    tool_calls = getattr(response.choices[0].message, "tool_calls", [])
    assert tool_calls, "No tool calls detected"

    # Pick the first call that targets the calculator, if any.
    calc_call = None
    for candidate in tool_calls:
        if candidate.function.name == FUNC_CALC:
            calc_call = candidate
            break
    assert calc_call, "Calculator function not called"

    raw_args = calc_call.function.arguments
    assert raw_args, "Calculator arguments missing"
    assert all(operand in raw_args for operand in ("123", "456")), (
        f"Expected values not in raw arguments: {raw_args}"
    )

    try:
        parsed_args = json.loads(raw_args)
    except json.JSONDecodeError:
        pytest.fail(f"Invalid JSON in calculator arguments: {raw_args}")

    wanted = "123 + 456"
    received = parsed_args.get("expression", "")
    score = fuzz.ratio(received, wanted)
    assert score > 90, (
        f"Expression mismatch: expected '{wanted}' "
        f"got '{received}' (similarity={score}%)"
    )
@pytest.mark.asyncio
async def test_streaming_tool_call_get_time_with_reasoning(client: openai.AsyncOpenAI):
    """Streamed request must call ``get_time`` and emit relevant reasoning.

    The stream is drained, then the accumulated tool calls must include
    ``get_time`` with "New York" in its arguments, and the reasoning text
    must be non-empty and mention at least one request-related keyword.
    """
    stream = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=MESSAGES_GET_TIME,
        tools=TOOLS,
        temperature=0.0,
        stream=True,
    )

    chunks = []
    async for chunk in stream:
        chunks.append(chunk)

    reasoning, arguments, function_names = extract_reasoning_and_calls(chunks)

    assert FUNC_TIME in function_names, "get_time function not called"

    mentions_city = any("New York" in arg for arg in arguments)
    assert mentions_city, (
        f"Expected get_time arguments for New York not found in {arguments}"
    )

    assert len(reasoning) > 0, "Expected reasoning content missing"

    relevant_keywords = ["New York", "time", "current"]
    assert any(keyword in reasoning for keyword in relevant_keywords), (
        f"Reasoning is not relevant to the request: {reasoning}"
    )
@pytest.mark.asyncio
async def test_streaming_multiple_tools(client: openai.AsyncOpenAI):
    """Test streamed multi-tool response with reasoning.

    The request asks for both the calculator and the get_time tool in one
    turn.  The checks are best-effort: when one of them does not hold, the
    test is reported as an expected failure (with the assertion message as
    the reason) rather than silently passing, so the outcome stays visible
    in the test report without turning it into a hard failure.
    """
    stream = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=MESSAGES_MULTIPLE_CALLS,
        tools=TOOLS,
        temperature=0.0,
        stream=True,
    )
    chunks = [chunk async for chunk in stream]
    reasoning, arguments, function_names = extract_reasoning_and_calls(chunks)

    try:
        assert FUNC_CALC in function_names, (
            f"Calculator tool missing — found {function_names}"
        )
        assert FUNC_TIME in function_names, (
            f"Time tool missing — found {function_names}"
        )
        assert len(reasoning) > 0, "Expected reasoning content in streamed response"
    except AssertionError as e:
        # Previously the failure was only printed and the test passed
        # regardless; surface it as an xfail so it is recorded.
        pytest.xfail(f"ERROR: {e}")
@pytest.mark.asyncio
async def test_invalid_tool_call(client: openai.AsyncOpenAI):
    """
    Check that a prompt which should not need a tool yields a message
    with a ``content`` field and no tool calls.
    """
    request_kwargs = dict(
        model=MODEL_NAME,
        messages=MESSAGES_INVALID_CALL,
        tools=TOOLS,
        temperature=0.0,
        stream=False,
    )
    response = await client.chat.completions.create(**request_kwargs)
    message = response.choices[0].message

    assert message is not None, "Expected message in response"
    assert hasattr(message, "content"), "Expected 'content' field in message"

    tool_calls = getattr(message, "tool_calls", [])
    assert not tool_calls, (
        f"Model unexpectedly attempted a tool call on invalid input: {tool_calls}"
    )
@pytest.mark.asyncio
async def test_tool_call_with_temperature(client: openai.AsyncOpenAI):
    """
    Check that sampling at temperature 0.7 still returns a message that
    carries tool calls, text content, or both.
    """
    request_kwargs = dict(
        model=MODEL_NAME,
        messages=MESSAGES_CALC,
        tools=TOOLS,
        temperature=0.7,
        stream=False,
    )
    response = await client.chat.completions.create(**request_kwargs)
    message = response.choices[0].message

    assert message is not None, "Expected non-empty message in response"
    has_payload = message.tool_calls or message.content
    assert has_payload, "Response missing both text and tool calls"

    # Printed for inspection in the captured test output.
    print(f"\nTool calls: {message.tool_calls}")
    print(f"Text: {message.content}")
@pytest.mark.asyncio
async def test_tool_response_schema_accuracy(client: openai.AsyncOpenAI):
    """Check every returned tool call against the schema declared in TOOLS.

    For each call, the arguments are decoded from JSON, the ``parameters``
    schema of the tool with the same name is looked up, and the arguments
    are validated against it with ``jsonschema``.
    """
    response = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=MESSAGES_MULTIPLE_CALLS,
        tools=TOOLS,
        temperature=0.0,
    )
    calls = response.choices[0].message.tool_calls
    assert calls, "No tool calls produced"

    for call in calls:
        func_name = call.function.name
        args = json.loads(call.function.arguments)

        # First tool definition whose function name matches this call.
        candidate_defs = (tool.get("function") for tool in TOOLS)
        matched_def = next(
            (
                fn_def
                for fn_def in candidate_defs
                if fn_def
                and isinstance(fn_def, dict)
                and fn_def.get("name") == func_name
            ),
            None,
        )
        schema: dict[str, object] | None = (
            matched_def.get("parameters") if matched_def is not None else None
        )

        assert schema is not None, f"No matching tool schema found for {func_name}"
        jsonschema.validate(instance=args, schema=schema)
@pytest.mark.asyncio
async def test_semantic_consistency_with_temperature(client: openai.AsyncOpenAI):
    """Compare message text produced at different sampling temperatures.

    The same request is issued at T=0.0, 0.5 and 1.0; the stripped message
    content of the first two responses must have a fuzzy similarity above 60.
    """
    temperatures = [0.0, 0.5, 1.0]
    responses = []
    for temperature in temperatures:
        completion = await client.chat.completions.create(
            model=MODEL_NAME,
            messages=MESSAGES_CALC,
            tools=TOOLS,
            temperature=temperature,
        )
        content = completion.choices[0].message.content or ""
        responses.append(content.strip())

    # Only the low- and mid-temperature outputs are compared.
    low_text, mid_text = responses[0], responses[1]
    low_mid_sim = fuzz.ratio(low_text, mid_text)
    assert low_mid_sim > 60, (
        f"Semantic drift too large between T=0.0 and T=0.5 ({low_mid_sim}%)"
    )
...@@ -61,11 +61,8 @@ def test_pooling_params(llm: LLM): ...@@ -61,11 +61,8 @@ def test_pooling_params(llm: LLM):
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
def test_encode_api(llm: LLM): def test_token_classify(llm: LLM):
# chunked prefill does not support all pooling llm.encode(prompts, pooling_task="token_classify", use_tqdm=False)
err_msg = "pooling_task must be one of.+"
with pytest.raises(ValueError, match=err_msg):
llm.encode(prompts, pooling_task="token_classify", use_tqdm=False)
def test_score_api(llm: LLM): def test_score_api(llm: LLM):
......
...@@ -255,21 +255,21 @@ async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str): ...@@ -255,21 +255,21 @@ async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str): async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str):
# token_classify uses ALL pooling, which does not support chunked prefill.
task = "token_classify" task = "token_classify"
input_text = ["This product was excellent and exceeded my expectations"]
response = requests.post( response = requests.post(
server.url_for("pooling"), server.url_for("pooling"),
json={ json={
"model": model_name, "model": model_name,
"input": "test", "input": input_text,
"encoding_format": "float", "encoding_format": "float",
"task": task, "task": task,
}, },
) )
assert response.json()["error"]["type"] == "BadRequestError" poolings = PoolingResponse.model_validate(response.json())
assert response.json()["error"]["message"].startswith( assert len(poolings.data) == 1
f"Task {task} is not supported" assert len(poolings.data[0].data) == 8
) assert len(poolings.data[0].data[0]) == 2
@pytest.mark.asyncio @pytest.mark.asyncio
......
...@@ -42,7 +42,7 @@ def llm(): ...@@ -42,7 +42,7 @@ def llm():
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
def test_encode_api(llm: LLM): def test_token_embed(llm: LLM):
outputs = llm.encode(prompts, pooling_task="token_embed", use_tqdm=False) outputs = llm.encode(prompts, pooling_task="token_embed", use_tqdm=False)
multi_vector = outputs[0].outputs.data multi_vector = outputs[0].outputs.data
assert multi_vector.shape == (11, 384) assert multi_vector.shape == (11, 384)
......
...@@ -24,6 +24,7 @@ from vllm.utils.serial_utils import ( ...@@ -24,6 +24,7 @@ from vllm.utils.serial_utils import (
ENDIANNESS, ENDIANNESS,
MetadataItem, MetadataItem,
binary2tensor, binary2tensor,
build_metadata_items,
decode_pooling_output, decode_pooling_output,
) )
...@@ -344,6 +345,55 @@ async def test_bytes_embed_dtype_and_endianness( ...@@ -344,6 +345,55 @@ async def test_bytes_embed_dtype_and_endianness(
) )
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_bytes_only_embed_dtype_and_endianness(
    server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str
):
    """Check ``bytes_only`` embeddings against the float-encoded reference.

    For every embed dtype and endianness, the raw response body is decoded
    with metadata rebuilt locally (the response must not carry a ``metadata``
    header) and compared with the float embeddings within a tolerance.
    """
    input_texts = [
        "The best thing about vLLM is that it supports many different models",
    ] * 2

    responses_float = await client.embeddings.create(
        input=input_texts, model=model_name, encoding_format="float"
    )
    float_data = [item.embedding for item in responses_float.data]
    embedding_size = len(float_data[0])

    for embed_dtype in list(EMBED_DTYPE_TO_TORCH_DTYPE):
        for endianness in ENDIANNESS:
            payload = {
                "model": model_name,
                "input": input_texts,
                "encoding_format": "bytes_only",
                "embed_dtype": embed_dtype,
                "endianness": endianness,
            }
            responses_bytes = requests.post(
                server.url_for("/v1/embeddings"), json=payload
            )
            assert "metadata" not in responses_bytes.headers

            # No metadata is returned, so rebuild it from the known layout.
            items = build_metadata_items(
                embed_dtype=embed_dtype,
                endianness=endianness,
                shape=(embedding_size,),
                n_request=len(input_texts),
            )
            decoded = decode_pooling_output(items=items, body=responses_bytes.content)
            bytes_data = [tensor.to(torch.float32).tolist() for tensor in decoded]

            check_embeddings_close(
                embeddings_0_lst=float_data,
                embeddings_1_lst=bytes_data,
                name_0="float_data",
                name_1="bytes_data",
                tol=1e-2,
            )
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("param_name", ["encoding_format", "embed_dtype", "endianness"]) @pytest.mark.parametrize("param_name", ["encoding_format", "embed_dtype", "endianness"])
......
...@@ -9,6 +9,7 @@ from transformers import AutoProcessor ...@@ -9,6 +9,7 @@ from transformers import AutoProcessor
from tests.utils import VLLM_PATH, RemoteOpenAIServer from tests.utils import VLLM_PATH, RemoteOpenAIServer
from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse
from vllm.multimodal.base import MediaWithBytes
from vllm.multimodal.utils import encode_image_base64, fetch_image from vllm.multimodal.utils import encode_image_base64, fetch_image
MODEL_NAME = "TIGER-Lab/VLM2Vec-Full" MODEL_NAME = "TIGER-Lab/VLM2Vec-Full"
...@@ -62,7 +63,11 @@ def get_hf_prompt_tokens(model_name, content, image_url): ...@@ -62,7 +63,11 @@ def get_hf_prompt_tokens(model_name, content, image_url):
placeholder = "<|image_1|> " placeholder = "<|image_1|> "
prompt = f"{placeholder}{content}" prompt = f"{placeholder}{content}"
images = [fetch_image(image_url)] image = fetch_image(image_url)
# Unwrap MediaWithBytes if present
if isinstance(image, MediaWithBytes):
image = image.media
images = [image]
inputs = processor(prompt, images, return_tensors="pt") inputs = processor(prompt, images, return_tensors="pt")
return inputs.input_ids.shape[1] return inputs.input_ids.shape[1]
......
...@@ -18,6 +18,7 @@ from vllm.utils.serial_utils import ( ...@@ -18,6 +18,7 @@ from vllm.utils.serial_utils import (
ENDIANNESS, ENDIANNESS,
MetadataItem, MetadataItem,
binary2tensor, binary2tensor,
build_metadata_items,
decode_pooling_output, decode_pooling_output,
) )
...@@ -352,6 +353,61 @@ async def test_bytes_embed_dtype_and_endianness( ...@@ -352,6 +353,61 @@ async def test_bytes_embed_dtype_and_endianness(
) )
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_bytes_only_embed_dtype_and_endianness(
    server: RemoteOpenAIServer, model_name: str
):
    """Check ``bytes_only`` pooling output against the float-encoded reference.

    For every embed dtype and endianness, the raw response body is decoded
    with metadata rebuilt locally (the response must not carry a ``metadata``
    header), flattened, and compared with the float data within a tolerance.
    """
    input_texts = [
        "The best thing about vLLM is that it supports many different models",
    ] * 2
    url = server.url_for("pooling")

    float_payload = {
        "model": model_name,
        "input": input_texts,
        "encoding_format": "float",
    }
    float_response = requests.post(url, json=float_payload)
    responses_float = PoolingResponse.model_validate(float_response.json())
    float_data = [
        np.array(item.data).squeeze(-1).tolist() for item in responses_float.data
    ]
    # Per-input token count derived from the reported prompt-token usage.
    n_tokens = responses_float.usage.prompt_tokens // len(input_texts)

    for embed_dtype in list(EMBED_DTYPE_TO_TORCH_DTYPE):
        for endianness in ENDIANNESS:
            bytes_payload = {
                "model": model_name,
                "input": input_texts,
                "encoding_format": "bytes_only",
                "embed_dtype": embed_dtype,
                "endianness": endianness,
            }
            responses_bytes = requests.post(url, json=bytes_payload)
            assert "metadata" not in responses_bytes.headers

            # No metadata is returned, so rebuild it from the known layout.
            items = build_metadata_items(
                embed_dtype=embed_dtype,
                endianness=endianness,
                shape=(n_tokens, 1),
                n_request=len(input_texts),
            )
            decoded = decode_pooling_output(items=items, body=responses_bytes.content)
            bytes_data = [
                tensor.to(torch.float32).view(-1).tolist() for tensor in decoded
            ]

            check_embeddings_close(
                embeddings_0_lst=float_data,
                embeddings_1_lst=bytes_data,
                name_0="float_data",
                name_1="bytes_data",
                tol=1e-2,
            )
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("param_name", ["encoding_format", "embed_dtype", "endianness"]) @pytest.mark.parametrize("param_name", ["encoding_format", "embed_dtype", "endianness"])
......
...@@ -36,6 +36,13 @@ def llm(): ...@@ -36,6 +36,13 @@ def llm():
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
@pytest.mark.skip_global_cleanup
def test_config(llm: LLM):
    """Check that prefix caching and chunked prefill are both enabled."""
    engine_config = llm.llm_engine.vllm_config
    cache_config = engine_config.cache_config
    scheduler_config = engine_config.scheduler_config

    assert cache_config.enable_prefix_caching
    assert scheduler_config.enable_chunked_prefill
def test_pooling_params(llm: LLM): def test_pooling_params(llm: LLM):
def get_outputs(use_activation): def get_outputs(use_activation):
outputs = llm.reward( outputs = llm.reward(
......
...@@ -6,6 +6,7 @@ from collections.abc import Mapping ...@@ -6,6 +6,7 @@ from collections.abc import Mapping
from typing import Literal from typing import Literal
import pytest import pytest
import torch
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from vllm.assets.audio import AudioAsset from vllm.assets.audio import AudioAsset
...@@ -29,6 +30,7 @@ from vllm.multimodal.utils import ( ...@@ -29,6 +30,7 @@ from vllm.multimodal.utils import (
encode_video_base64, encode_video_base64,
) )
from vllm.tokenizers import MistralTokenizer, get_tokenizer from vllm.tokenizers import MistralTokenizer, get_tokenizer
from vllm.utils.serial_utils import tensor2base64
from ..models.registry import HF_EXAMPLE_MODELS from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import VLLM_PATH from ..utils import VLLM_PATH
...@@ -85,11 +87,6 @@ def phi3v_model_config_image_embeds(): ...@@ -85,11 +87,6 @@ def phi3v_model_config_image_embeds():
) )
@pytest.fixture(scope="module")
def phi3v_tokenizer():
return get_tokenizer(PHI3V_MODEL_ID)
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def qwen2_audio_model_config(): def qwen2_audio_model_config():
return ModelConfig( return ModelConfig(
...@@ -115,11 +112,6 @@ def audio_embeds_model_config(): ...@@ -115,11 +112,6 @@ def audio_embeds_model_config():
) )
@pytest.fixture(scope="module")
def qwen2_audio_tokenizer():
return get_tokenizer(QWEN2AUDIO_MODEL_ID)
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def qwen25omni_model_config_mm_interleaved(): def qwen25omni_model_config_mm_interleaved():
return ModelConfig( return ModelConfig(
...@@ -134,11 +126,6 @@ def qwen25omni_model_config_mm_interleaved(): ...@@ -134,11 +126,6 @@ def qwen25omni_model_config_mm_interleaved():
) )
@pytest.fixture(scope="module")
def qwen25omni_tokenizer():
return get_tokenizer(QWEN25OMNI_MODEL_ID)
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def mistral_model_config(): def mistral_model_config():
return ModelConfig( return ModelConfig(
...@@ -150,11 +137,6 @@ def mistral_model_config(): ...@@ -150,11 +137,6 @@ def mistral_model_config():
) )
@pytest.fixture(scope="module")
def mistral_tokenizer():
return get_tokenizer(MISTRAL_MODEL_ID)
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def image_url(): def image_url():
image = ImageAsset("cherry_blossom") image = ImageAsset("cherry_blossom")
...@@ -239,7 +221,6 @@ def _assert_mm_data_inputs( ...@@ -239,7 +221,6 @@ def _assert_mm_data_inputs(
def test_parse_chat_messages_single_image( def test_parse_chat_messages_single_image(
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
image_url, image_url,
): ):
conversation, mm_data, mm_uuids = parse_chat_messages( conversation, mm_data, mm_uuids = parse_chat_messages(
...@@ -253,7 +234,6 @@ def test_parse_chat_messages_single_image( ...@@ -253,7 +234,6 @@ def test_parse_chat_messages_single_image(
} }
], ],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
content_format="string", content_format="string",
) )
...@@ -266,7 +246,6 @@ def test_parse_chat_messages_single_image( ...@@ -266,7 +246,6 @@ def test_parse_chat_messages_single_image(
def test_parse_chat_messages_single_image_with_uuid( def test_parse_chat_messages_single_image_with_uuid(
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
image_url, image_url,
): ):
image_uuid = str(hash(image_url)) image_uuid = str(hash(image_url))
...@@ -287,7 +266,6 @@ def test_parse_chat_messages_single_image_with_uuid( ...@@ -287,7 +266,6 @@ def test_parse_chat_messages_single_image_with_uuid(
} }
], ],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
content_format="string", content_format="string",
) )
...@@ -300,7 +278,6 @@ def test_parse_chat_messages_single_image_with_uuid( ...@@ -300,7 +278,6 @@ def test_parse_chat_messages_single_image_with_uuid(
def test_parse_chat_messages_single_empty_image_with_uuid( def test_parse_chat_messages_single_empty_image_with_uuid(
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
image_url, image_url,
): ):
image_uuid = str(hash(image_url)) image_uuid = str(hash(image_url))
...@@ -319,7 +296,6 @@ def test_parse_chat_messages_single_empty_image_with_uuid( ...@@ -319,7 +296,6 @@ def test_parse_chat_messages_single_empty_image_with_uuid(
} }
], ],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
content_format="string", content_format="string",
) )
...@@ -332,7 +308,6 @@ def test_parse_chat_messages_single_empty_image_with_uuid( ...@@ -332,7 +308,6 @@ def test_parse_chat_messages_single_empty_image_with_uuid(
def test_parse_chat_messages_single_image_with_bad_uuid_format( def test_parse_chat_messages_single_image_with_bad_uuid_format(
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
image_url, image_url,
): ):
image_uuid = str(hash(image_url)) image_uuid = str(hash(image_url))
...@@ -354,7 +329,6 @@ def test_parse_chat_messages_single_image_with_bad_uuid_format( ...@@ -354,7 +329,6 @@ def test_parse_chat_messages_single_image_with_bad_uuid_format(
} }
], ],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
content_format="string", content_format="string",
) )
...@@ -367,7 +341,6 @@ def test_parse_chat_messages_single_image_with_bad_uuid_format( ...@@ -367,7 +341,6 @@ def test_parse_chat_messages_single_image_with_bad_uuid_format(
def test_parse_chat_messages_multiple_images_with_uuids( def test_parse_chat_messages_multiple_images_with_uuids(
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
image_url, image_url,
): ):
image_uuid1 = "my_uuid_1" image_uuid1 = "my_uuid_1"
...@@ -397,7 +370,6 @@ def test_parse_chat_messages_multiple_images_with_uuids( ...@@ -397,7 +370,6 @@ def test_parse_chat_messages_multiple_images_with_uuids(
} }
], ],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
content_format="string", content_format="string",
) )
...@@ -413,7 +385,6 @@ def test_parse_chat_messages_multiple_images_with_uuids( ...@@ -413,7 +385,6 @@ def test_parse_chat_messages_multiple_images_with_uuids(
def test_parse_chat_messages_multiple_empty_images_with_uuids( def test_parse_chat_messages_multiple_empty_images_with_uuids(
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
image_url, image_url,
): ):
image_uuid1 = "my_uuid_1" image_uuid1 = "my_uuid_1"
...@@ -439,7 +410,6 @@ def test_parse_chat_messages_multiple_empty_images_with_uuids( ...@@ -439,7 +410,6 @@ def test_parse_chat_messages_multiple_empty_images_with_uuids(
} }
], ],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
content_format="string", content_format="string",
) )
...@@ -455,7 +425,6 @@ def test_parse_chat_messages_multiple_empty_images_with_uuids( ...@@ -455,7 +425,6 @@ def test_parse_chat_messages_multiple_empty_images_with_uuids(
def test_parse_chat_messages_mixed_empty_images_with_uuids( def test_parse_chat_messages_mixed_empty_images_with_uuids(
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
image_url, image_url,
): ):
image_uuid1 = "my_uuid_1" image_uuid1 = "my_uuid_1"
...@@ -483,7 +452,6 @@ def test_parse_chat_messages_mixed_empty_images_with_uuids( ...@@ -483,7 +452,6 @@ def test_parse_chat_messages_mixed_empty_images_with_uuids(
} }
], ],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
content_format="string", content_format="string",
) )
...@@ -500,7 +468,6 @@ def test_parse_chat_messages_mixed_empty_images_with_uuids( ...@@ -500,7 +468,6 @@ def test_parse_chat_messages_mixed_empty_images_with_uuids(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_parse_chat_messages_single_image_with_uuid_async( async def test_parse_chat_messages_single_image_with_uuid_async(
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
image_url, image_url,
): ):
image_uuid = str(hash(image_url)) image_uuid = str(hash(image_url))
...@@ -519,7 +486,6 @@ async def test_parse_chat_messages_single_image_with_uuid_async( ...@@ -519,7 +486,6 @@ async def test_parse_chat_messages_single_image_with_uuid_async(
} }
], ],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
content_format="string", content_format="string",
) )
...@@ -533,7 +499,6 @@ async def test_parse_chat_messages_single_image_with_uuid_async( ...@@ -533,7 +499,6 @@ async def test_parse_chat_messages_single_image_with_uuid_async(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_parse_chat_messages_empty_image_with_uuid_async( async def test_parse_chat_messages_empty_image_with_uuid_async(
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
image_url, image_url,
): ):
image_uuid = str(hash(image_url)) image_uuid = str(hash(image_url))
...@@ -552,7 +517,6 @@ async def test_parse_chat_messages_empty_image_with_uuid_async( ...@@ -552,7 +517,6 @@ async def test_parse_chat_messages_empty_image_with_uuid_async(
} }
], ],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
content_format="string", content_format="string",
) )
...@@ -566,7 +530,6 @@ async def test_parse_chat_messages_empty_image_with_uuid_async( ...@@ -566,7 +530,6 @@ async def test_parse_chat_messages_empty_image_with_uuid_async(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_parse_chat_messages_multiple_images_with_uuids_async( async def test_parse_chat_messages_multiple_images_with_uuids_async(
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
image_url, image_url,
): ):
image_uuid1 = "my_uuid_1" image_uuid1 = "my_uuid_1"
...@@ -592,7 +555,6 @@ async def test_parse_chat_messages_multiple_images_with_uuids_async( ...@@ -592,7 +555,6 @@ async def test_parse_chat_messages_multiple_images_with_uuids_async(
} }
], ],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
content_format="string", content_format="string",
) )
...@@ -609,7 +571,6 @@ async def test_parse_chat_messages_multiple_images_with_uuids_async( ...@@ -609,7 +571,6 @@ async def test_parse_chat_messages_multiple_images_with_uuids_async(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_parse_chat_messages_multiple_empty_images_with_uuids_async( async def test_parse_chat_messages_multiple_empty_images_with_uuids_async(
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
image_url, image_url,
): ):
image_uuid1 = "my_uuid_1" image_uuid1 = "my_uuid_1"
...@@ -635,7 +596,6 @@ async def test_parse_chat_messages_multiple_empty_images_with_uuids_async( ...@@ -635,7 +596,6 @@ async def test_parse_chat_messages_multiple_empty_images_with_uuids_async(
} }
], ],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
content_format="string", content_format="string",
) )
...@@ -652,7 +612,6 @@ async def test_parse_chat_messages_multiple_empty_images_with_uuids_async( ...@@ -652,7 +612,6 @@ async def test_parse_chat_messages_multiple_empty_images_with_uuids_async(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_parse_chat_messages_multiple_images_with_partial_uuids_async( async def test_parse_chat_messages_multiple_images_with_partial_uuids_async(
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
image_url, image_url,
): ):
image_uuid2 = "my_uuid_2" image_uuid2 = "my_uuid_2"
...@@ -676,7 +635,6 @@ async def test_parse_chat_messages_multiple_images_with_partial_uuids_async( ...@@ -676,7 +635,6 @@ async def test_parse_chat_messages_multiple_images_with_partial_uuids_async(
} }
], ],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
content_format="string", content_format="string",
) )
...@@ -692,7 +650,6 @@ async def test_parse_chat_messages_multiple_images_with_partial_uuids_async( ...@@ -692,7 +650,6 @@ async def test_parse_chat_messages_multiple_images_with_partial_uuids_async(
def test_parse_chat_messages_empty_system( def test_parse_chat_messages_empty_system(
mistral_model_config, mistral_model_config,
mistral_tokenizer,
): ):
# Test string format # Test string format
conversation, _, _ = parse_chat_messages( conversation, _, _ = parse_chat_messages(
...@@ -704,7 +661,6 @@ def test_parse_chat_messages_empty_system( ...@@ -704,7 +661,6 @@ def test_parse_chat_messages_empty_system(
}, },
], ],
mistral_model_config, mistral_model_config,
mistral_tokenizer,
content_format="string", content_format="string",
) )
assert conversation == [ assert conversation == [
...@@ -722,7 +678,6 @@ def test_parse_chat_messages_empty_system( ...@@ -722,7 +678,6 @@ def test_parse_chat_messages_empty_system(
}, },
], ],
mistral_model_config, mistral_model_config,
mistral_tokenizer,
content_format="openai", content_format="openai",
) )
assert conversation == [ assert conversation == [
...@@ -734,7 +689,6 @@ def test_parse_chat_messages_empty_system( ...@@ -734,7 +689,6 @@ def test_parse_chat_messages_empty_system(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_parse_chat_messages_single_image_async( async def test_parse_chat_messages_single_image_async(
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
image_url, image_url,
): ):
conversation, mm_future, mm_uuids = parse_chat_messages_futures( conversation, mm_future, mm_uuids = parse_chat_messages_futures(
...@@ -748,7 +702,6 @@ async def test_parse_chat_messages_single_image_async( ...@@ -748,7 +702,6 @@ async def test_parse_chat_messages_single_image_async(
} }
], ],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
content_format="string", content_format="string",
) )
...@@ -761,7 +714,6 @@ async def test_parse_chat_messages_single_image_async( ...@@ -761,7 +714,6 @@ async def test_parse_chat_messages_single_image_async(
def test_parse_chat_messages_multiple_images( def test_parse_chat_messages_multiple_images(
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
image_url, image_url,
): ):
conversation, mm_data, mm_uuids = parse_chat_messages( conversation, mm_data, mm_uuids = parse_chat_messages(
...@@ -779,7 +731,6 @@ def test_parse_chat_messages_multiple_images( ...@@ -779,7 +731,6 @@ def test_parse_chat_messages_multiple_images(
} }
], ],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
content_format="string", content_format="string",
) )
...@@ -795,7 +746,6 @@ def test_parse_chat_messages_multiple_images( ...@@ -795,7 +746,6 @@ def test_parse_chat_messages_multiple_images(
def test_parse_chat_messages_empty_pil_image_with_uuid( def test_parse_chat_messages_empty_pil_image_with_uuid(
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
): ):
uuid = "abcd" uuid = "abcd"
conversation, mm_data, mm_uuids = parse_chat_messages( conversation, mm_data, mm_uuids = parse_chat_messages(
...@@ -809,7 +759,6 @@ def test_parse_chat_messages_empty_pil_image_with_uuid( ...@@ -809,7 +759,6 @@ def test_parse_chat_messages_empty_pil_image_with_uuid(
} }
], ],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
content_format="string", content_format="string",
) )
...@@ -825,7 +774,6 @@ def test_parse_chat_messages_empty_pil_image_with_uuid( ...@@ -825,7 +774,6 @@ def test_parse_chat_messages_empty_pil_image_with_uuid(
def test_parse_chat_messages_empty_image_embeds_with_uuid( def test_parse_chat_messages_empty_image_embeds_with_uuid(
phi3v_model_config_image_embeds, phi3v_model_config_image_embeds,
phi3v_tokenizer,
): ):
uuid = "abcd" uuid = "abcd"
conversation, mm_data, mm_uuids = parse_chat_messages( conversation, mm_data, mm_uuids = parse_chat_messages(
...@@ -839,7 +787,6 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid( ...@@ -839,7 +787,6 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid(
} }
], ],
phi3v_model_config_image_embeds, phi3v_model_config_image_embeds,
phi3v_tokenizer,
content_format="string", content_format="string",
) )
...@@ -857,7 +804,6 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid( ...@@ -857,7 +804,6 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid(
def test_parse_chat_messages_empty_audio_embeds_with_uuid( def test_parse_chat_messages_empty_audio_embeds_with_uuid(
audio_embeds_model_config, audio_embeds_model_config,
qwen2_audio_tokenizer,
): ):
"""Test audio_embeds with UUID (no actual embeds data).""" """Test audio_embeds with UUID (no actual embeds data)."""
uuid = "test-audio-uuid-123" uuid = "test-audio-uuid-123"
...@@ -873,7 +819,6 @@ def test_parse_chat_messages_empty_audio_embeds_with_uuid( ...@@ -873,7 +819,6 @@ def test_parse_chat_messages_empty_audio_embeds_with_uuid(
} }
], ],
audio_embeds_model_config, audio_embeds_model_config,
qwen2_audio_tokenizer,
content_format="string", content_format="string",
) )
...@@ -889,11 +834,8 @@ def test_parse_chat_messages_empty_audio_embeds_with_uuid( ...@@ -889,11 +834,8 @@ def test_parse_chat_messages_empty_audio_embeds_with_uuid(
def test_parse_chat_messages_audio_embeds_with_string( def test_parse_chat_messages_audio_embeds_with_string(
audio_embeds_model_config, audio_embeds_model_config,
qwen2_audio_tokenizer,
): ):
"""Test audio_embeds with base64 string embedding data.""" """Test audio_embeds with base64 string embedding data."""
import base64
import io
import torch import torch
...@@ -901,11 +843,7 @@ def test_parse_chat_messages_audio_embeds_with_string( ...@@ -901,11 +843,7 @@ def test_parse_chat_messages_audio_embeds_with_string(
audio_embedding = torch.randn(1, 128, 768) audio_embedding = torch.randn(1, 128, 768)
# Encode it as base64 # Encode it as base64
buffer = io.BytesIO() base64_audio_embedding = tensor2base64(audio_embedding)
torch.save(audio_embedding, buffer)
buffer.seek(0)
binary_data = buffer.read()
base64_audio_embedding = base64.b64encode(binary_data).decode("utf-8")
conversation, mm_data, mm_uuids = parse_chat_messages( conversation, mm_data, mm_uuids = parse_chat_messages(
[ [
...@@ -921,7 +859,6 @@ def test_parse_chat_messages_audio_embeds_with_string( ...@@ -921,7 +859,6 @@ def test_parse_chat_messages_audio_embeds_with_string(
} }
], ],
audio_embeds_model_config, audio_embeds_model_config,
qwen2_audio_tokenizer,
content_format="string", content_format="string",
) )
...@@ -939,11 +876,8 @@ def test_parse_chat_messages_audio_embeds_with_string( ...@@ -939,11 +876,8 @@ def test_parse_chat_messages_audio_embeds_with_string(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_parse_chat_messages_audio_embeds_async( async def test_parse_chat_messages_audio_embeds_async(
audio_embeds_model_config, audio_embeds_model_config,
qwen2_audio_tokenizer,
): ):
"""Test audio_embeds with async futures.""" """Test audio_embeds with async futures."""
import base64
import io
import torch import torch
...@@ -951,11 +885,7 @@ async def test_parse_chat_messages_audio_embeds_async( ...@@ -951,11 +885,7 @@ async def test_parse_chat_messages_audio_embeds_async(
audio_embedding = torch.randn(1, 128, 768) audio_embedding = torch.randn(1, 128, 768)
# Encode it as base64 # Encode it as base64
buffer = io.BytesIO() base64_audio_embedding = tensor2base64(audio_embedding)
torch.save(audio_embedding, buffer)
buffer.seek(0)
binary_data = buffer.read()
base64_audio_embedding = base64.b64encode(binary_data).decode("utf-8")
conversation, mm_future, mm_uuids = parse_chat_messages_futures( conversation, mm_future, mm_uuids = parse_chat_messages_futures(
[ [
...@@ -971,7 +901,6 @@ async def test_parse_chat_messages_audio_embeds_async( ...@@ -971,7 +901,6 @@ async def test_parse_chat_messages_audio_embeds_async(
} }
], ],
audio_embeds_model_config, audio_embeds_model_config,
qwen2_audio_tokenizer,
content_format="string", content_format="string",
) )
...@@ -987,10 +916,186 @@ async def test_parse_chat_messages_audio_embeds_async( ...@@ -987,10 +916,186 @@ async def test_parse_chat_messages_audio_embeds_async(
_assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[None]) _assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[None])
def test_parse_chat_messages_multiple_image_embeds(
    phi3v_model_config_image_embeds,
):
    """Check that a single user message may carry several image_embeds parts.

    Two base64-encoded embedding tensors and one text part are sent in the
    same message. The parsed multimodal data is expected to be a list with
    one tensor per part (matched by shape, in submission order), and no
    UUIDs are expected because none were supplied.
    """
    # Two embeddings with different leading dimensions so order is observable.
    source_embeddings = [torch.randn(256, 1024), torch.randn(128, 1024)]

    # One content part per embedding, serialized with the shared helper.
    embed_parts = [
        {"type": "image_embeds", "image_embeds": tensor2base64(tensor)}
        for tensor in source_embeddings
    ]
    messages = [
        {
            "role": "user",
            "content": [
                *embed_parts,
                {"type": "text", "text": "Describe these two images."},
            ],
        }
    ]

    conversation, mm_data, mm_uuids = parse_chat_messages(
        messages,
        phi3v_model_config_image_embeds,
        content_format="string",
    )

    # Placeholders are prepended to the text, one per embedding.
    expected_conversation = [
        {
            "role": "user",
            "content": "<|image_1|>\n<|image_2|>\nDescribe these two images.",
        }
    ]
    assert conversation == expected_conversation

    # The image entry must be a list of embeddings, not a single embedding.
    assert mm_data is not None
    assert "image" in mm_data
    parsed_images = mm_data["image"]
    assert isinstance(parsed_images, list)
    assert len(parsed_images) == 2

    # Each parsed item is a tensor with the shape of its source embedding.
    for parsed, source in zip(parsed_images, source_embeddings):
        assert isinstance(parsed, torch.Tensor)
        assert parsed.shape == source.shape

    # No UUIDs were provided, so every slot is None.
    _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None])
def test_parse_chat_messages_multiple_image_embeds_with_uuids(
    phi3v_model_config_image_embeds,
):
    """Check UUID tracking when several image_embeds parts carry no data.

    Both parts pass ``None`` as the embedding and supply only a ``uuid``.
    The parsed image list is expected to hold ``None`` for each part while
    the UUIDs are reported in the order the parts were given.
    """
    first_uuid = "image-uuid-1"
    second_uuid = "image-uuid-2"

    # Data-less embedding parts followed by the text part.
    content_parts = [
        {"type": "image_embeds", "image_embeds": None, "uuid": item_uuid}
        for item_uuid in (first_uuid, second_uuid)
    ]
    content_parts.append({"type": "text", "text": "Compare these images."})

    conversation, mm_data, mm_uuids = parse_chat_messages(
        [{"role": "user", "content": content_parts}],
        phi3v_model_config_image_embeds,
        content_format="string",
    )

    # Placeholders are still emitted even though no embedding data was sent.
    expected_conversation = [
        {
            "role": "user",
            "content": "<|image_1|>\n<|image_2|>\nCompare these images.",
        }
    ]
    assert conversation == expected_conversation

    # The image entry is a two-element list whose items are all None.
    assert mm_data is not None
    assert "image" in mm_data
    parsed_images = mm_data["image"]
    assert isinstance(parsed_images, list)
    assert len(parsed_images) == 2
    assert parsed_images[0] is None
    assert parsed_images[1] is None

    # UUIDs must come back in submission order.
    _assert_mm_uuids(mm_uuids, 2, expected_uuids=[first_uuid, second_uuid])
@pytest.mark.asyncio
async def test_parse_chat_messages_multiple_image_embeds_async(
    phi3v_model_config_image_embeds,
):
    """Check multiple image_embeds parts through the futures-based parser.

    Mirrors the synchronous multi-embedding test, but goes through
    ``parse_chat_messages_futures`` and awaits the returned future before
    inspecting the multimodal data.
    """
    # Two embeddings with different leading dimensions so order is observable.
    source_embeddings = [torch.randn(200, 768), torch.randn(150, 768)]

    # One content part per embedding, serialized with the shared helper.
    embed_parts = [
        {"type": "image_embeds", "image_embeds": tensor2base64(tensor)}
        for tensor in source_embeddings
    ]
    messages = [
        {
            "role": "user",
            "content": [
                *embed_parts,
                {"type": "text", "text": "What do these images show?"},
            ],
        }
    ]

    conversation, mm_future, mm_uuids = parse_chat_messages_futures(
        messages,
        phi3v_model_config_image_embeds,
        content_format="string",
    )

    # The conversation is available immediately, before the future resolves.
    expected_conversation = [
        {
            "role": "user",
            "content": "<|image_1|>\n<|image_2|>\nWhat do these images show?",
        }
    ]
    assert conversation == expected_conversation

    # Resolve the multimodal data and check it is a two-element list.
    mm_data = await mm_future
    assert mm_data is not None
    assert "image" in mm_data
    parsed_images = mm_data["image"]
    assert isinstance(parsed_images, list)
    assert len(parsed_images) == 2

    # Each parsed item is a tensor with the shape of its source embedding.
    for parsed, source in zip(parsed_images, source_embeddings):
        assert isinstance(parsed, torch.Tensor)
        assert parsed.shape == source.shape

    # No UUIDs were provided, so every slot is None.
    _assert_mm_uuids(mm_uuids, 2, expected_uuids=[None, None])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_parse_chat_messages_empty_image_embeds_with_uuid_async( async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
phi3v_model_config_image_embeds, phi3v_model_config_image_embeds,
phi3v_tokenizer,
): ):
uuid = "abcd" uuid = "abcd"
conversation, mm_future, mm_uuids = parse_chat_messages_futures( conversation, mm_future, mm_uuids = parse_chat_messages_futures(
...@@ -1004,7 +1109,6 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async( ...@@ -1004,7 +1109,6 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
} }
], ],
phi3v_model_config_image_embeds, phi3v_model_config_image_embeds,
phi3v_tokenizer,
content_format="string", content_format="string",
) )
...@@ -1024,7 +1128,6 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async( ...@@ -1024,7 +1128,6 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_parse_chat_messages_multiple_images_async( async def test_parse_chat_messages_multiple_images_async(
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
image_url, image_url,
): ):
conversation, mm_future, mm_uuids = parse_chat_messages_futures( conversation, mm_future, mm_uuids = parse_chat_messages_futures(
...@@ -1042,7 +1145,6 @@ async def test_parse_chat_messages_multiple_images_async( ...@@ -1042,7 +1145,6 @@ async def test_parse_chat_messages_multiple_images_async(
} }
], ],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
content_format="string", content_format="string",
) )
...@@ -1058,7 +1160,6 @@ async def test_parse_chat_messages_multiple_images_async( ...@@ -1058,7 +1160,6 @@ async def test_parse_chat_messages_multiple_images_async(
def test_parse_chat_messages_placeholder_already_in_prompt( def test_parse_chat_messages_placeholder_already_in_prompt(
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
image_url, image_url,
): ):
conversation, mm_data, mm_uuids = parse_chat_messages( conversation, mm_data, mm_uuids = parse_chat_messages(
...@@ -1076,7 +1177,6 @@ def test_parse_chat_messages_placeholder_already_in_prompt( ...@@ -1076,7 +1177,6 @@ def test_parse_chat_messages_placeholder_already_in_prompt(
} }
], ],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
content_format="string", content_format="string",
) )
assert conversation == [ assert conversation == [
...@@ -1091,7 +1191,6 @@ def test_parse_chat_messages_placeholder_already_in_prompt( ...@@ -1091,7 +1191,6 @@ def test_parse_chat_messages_placeholder_already_in_prompt(
def test_parse_chat_messages_placeholder_one_already_in_prompt( def test_parse_chat_messages_placeholder_one_already_in_prompt(
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
image_url, image_url,
): ):
conversation, mm_data, mm_uuids = parse_chat_messages( conversation, mm_data, mm_uuids = parse_chat_messages(
...@@ -1110,7 +1209,6 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt( ...@@ -1110,7 +1209,6 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt(
} }
], ],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
content_format="string", content_format="string",
) )
...@@ -1127,7 +1225,6 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt( ...@@ -1127,7 +1225,6 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt(
def test_parse_chat_messages_multiple_images_across_messages( def test_parse_chat_messages_multiple_images_across_messages(
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
image_url, image_url,
): ):
conversation, mm_data, mm_uuids = parse_chat_messages( conversation, mm_data, mm_uuids = parse_chat_messages(
...@@ -1149,7 +1246,6 @@ def test_parse_chat_messages_multiple_images_across_messages( ...@@ -1149,7 +1246,6 @@ def test_parse_chat_messages_multiple_images_across_messages(
}, },
], ],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
content_format="string", content_format="string",
) )
...@@ -1164,7 +1260,6 @@ def test_parse_chat_messages_multiple_images_across_messages( ...@@ -1164,7 +1260,6 @@ def test_parse_chat_messages_multiple_images_across_messages(
def test_parse_chat_messages_multiple_images_with_uuids_across_messages( def test_parse_chat_messages_multiple_images_with_uuids_across_messages(
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
image_url, image_url,
): ):
image_uuid = str(hash(image_url)) image_uuid = str(hash(image_url))
...@@ -1195,7 +1290,6 @@ def test_parse_chat_messages_multiple_images_with_uuids_across_messages( ...@@ -1195,7 +1290,6 @@ def test_parse_chat_messages_multiple_images_with_uuids_across_messages(
}, },
], ],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
content_format="string", content_format="string",
) )
...@@ -1210,7 +1304,6 @@ def test_parse_chat_messages_multiple_images_with_uuids_across_messages( ...@@ -1210,7 +1304,6 @@ def test_parse_chat_messages_multiple_images_with_uuids_across_messages(
def test_parse_chat_messages_context_text_format( def test_parse_chat_messages_context_text_format(
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
): ):
conversation, mm_data, mm_uuids = parse_chat_messages( conversation, mm_data, mm_uuids = parse_chat_messages(
[ [
...@@ -1222,7 +1315,6 @@ def test_parse_chat_messages_context_text_format( ...@@ -1222,7 +1315,6 @@ def test_parse_chat_messages_context_text_format(
{"role": "user", "content": "What about this one?"}, {"role": "user", "content": "What about this one?"},
], ],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
content_format="openai", content_format="openai",
) )
...@@ -1246,7 +1338,6 @@ def test_parse_chat_messages_context_text_format( ...@@ -1246,7 +1338,6 @@ def test_parse_chat_messages_context_text_format(
def test_parse_chat_messages_rejects_too_many_images_in_one_message( def test_parse_chat_messages_rejects_too_many_images_in_one_message(
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
image_url, image_url,
): ):
with warnings.catch_warnings(): with warnings.catch_warnings():
...@@ -1277,14 +1368,12 @@ def test_parse_chat_messages_rejects_too_many_images_in_one_message( ...@@ -1277,14 +1368,12 @@ def test_parse_chat_messages_rejects_too_many_images_in_one_message(
} }
], ],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
content_format="string", content_format="string",
) )
def test_parse_chat_messages_rejects_too_many_images_across_messages( def test_parse_chat_messages_rejects_too_many_images_across_messages(
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
image_url, image_url,
): ):
with warnings.catch_warnings(): with warnings.catch_warnings():
...@@ -1322,14 +1411,12 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages( ...@@ -1322,14 +1411,12 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages(
}, },
], ],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
content_format="string", content_format="string",
) )
def test_parse_chat_messages_multiple_images_uncommon_input( def test_parse_chat_messages_multiple_images_uncommon_input(
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
image_url, image_url,
): ):
conversation, mm_data, mm_uuids = parse_chat_messages( conversation, mm_data, mm_uuids = parse_chat_messages(
...@@ -1344,7 +1431,6 @@ def test_parse_chat_messages_multiple_images_uncommon_input( ...@@ -1344,7 +1431,6 @@ def test_parse_chat_messages_multiple_images_uncommon_input(
} }
], ],
phi3v_model_config, phi3v_model_config,
phi3v_tokenizer,
content_format="string", content_format="string",
) )
...@@ -1360,7 +1446,6 @@ def test_parse_chat_messages_multiple_images_uncommon_input( ...@@ -1360,7 +1446,6 @@ def test_parse_chat_messages_multiple_images_uncommon_input(
def test_parse_chat_messages_multiple_images_interleave( def test_parse_chat_messages_multiple_images_interleave(
phi3v_model_config_mm_interleaved, phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
image_url, image_url,
): ):
conversation, mm_data, mm_uuids = parse_chat_messages( conversation, mm_data, mm_uuids = parse_chat_messages(
...@@ -1380,7 +1465,6 @@ def test_parse_chat_messages_multiple_images_interleave( ...@@ -1380,7 +1465,6 @@ def test_parse_chat_messages_multiple_images_interleave(
} }
], ],
phi3v_model_config_mm_interleaved, phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
content_format="string", content_format="string",
) )
...@@ -1398,7 +1482,6 @@ def test_parse_chat_messages_multiple_images_interleave( ...@@ -1398,7 +1482,6 @@ def test_parse_chat_messages_multiple_images_interleave(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_parse_chat_messages_multiple_images_interleave_async( async def test_parse_chat_messages_multiple_images_interleave_async(
phi3v_model_config_mm_interleaved, phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
image_url, image_url,
): ):
conversation, mm_data, mm_uuids = parse_chat_messages_futures( conversation, mm_data, mm_uuids = parse_chat_messages_futures(
...@@ -1418,7 +1501,6 @@ async def test_parse_chat_messages_multiple_images_interleave_async( ...@@ -1418,7 +1501,6 @@ async def test_parse_chat_messages_multiple_images_interleave_async(
} }
], ],
phi3v_model_config_mm_interleaved, phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
content_format="string", content_format="string",
) )
...@@ -1436,7 +1518,6 @@ async def test_parse_chat_messages_multiple_images_interleave_async( ...@@ -1436,7 +1518,6 @@ async def test_parse_chat_messages_multiple_images_interleave_async(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_parse_chat_messages_multiple_images_with_uuids_interleave_async( async def test_parse_chat_messages_multiple_images_with_uuids_interleave_async(
phi3v_model_config_mm_interleaved, phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
image_url, image_url,
): ):
image_uuid = str(hash(image_url)) image_uuid = str(hash(image_url))
...@@ -1465,7 +1546,6 @@ async def test_parse_chat_messages_multiple_images_with_uuids_interleave_async( ...@@ -1465,7 +1546,6 @@ async def test_parse_chat_messages_multiple_images_with_uuids_interleave_async(
} }
], ],
phi3v_model_config_mm_interleaved, phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
content_format="string", content_format="string",
) )
...@@ -1482,7 +1562,6 @@ async def test_parse_chat_messages_multiple_images_with_uuids_interleave_async( ...@@ -1482,7 +1562,6 @@ async def test_parse_chat_messages_multiple_images_with_uuids_interleave_async(
def test_parse_chat_messages_multiple_images_multiple_messages_interleave( def test_parse_chat_messages_multiple_images_multiple_messages_interleave(
phi3v_model_config_mm_interleaved, phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
image_url, image_url,
): ):
conversation, mm_data, mm_uuids = parse_chat_messages( conversation, mm_data, mm_uuids = parse_chat_messages(
...@@ -1505,7 +1584,6 @@ def test_parse_chat_messages_multiple_images_multiple_messages_interleave( ...@@ -1505,7 +1584,6 @@ def test_parse_chat_messages_multiple_images_multiple_messages_interleave(
}, },
], ],
phi3v_model_config_mm_interleaved, phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
content_format="string", content_format="string",
) )
...@@ -1523,7 +1601,6 @@ def test_parse_chat_messages_multiple_images_multiple_messages_interleave( ...@@ -1523,7 +1601,6 @@ def test_parse_chat_messages_multiple_images_multiple_messages_interleave(
def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interleave( def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interleave(
phi3v_model_config_mm_interleaved, phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
image_url, image_url,
): ):
image_uuid = str(hash(image_url)) image_uuid = str(hash(image_url))
...@@ -1555,7 +1632,6 @@ def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interl ...@@ -1555,7 +1632,6 @@ def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interl
}, },
], ],
phi3v_model_config_mm_interleaved, phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
content_format="string", content_format="string",
) )
...@@ -1573,7 +1649,6 @@ def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interl ...@@ -1573,7 +1649,6 @@ def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interl
def test_parse_chat_messages_multiple_modals_multiple_messages_interleave( def test_parse_chat_messages_multiple_modals_multiple_messages_interleave(
qwen25omni_model_config_mm_interleaved, qwen25omni_model_config_mm_interleaved,
qwen25omni_tokenizer,
image_url, image_url,
video_url, video_url,
audio_url, audio_url,
...@@ -1601,7 +1676,6 @@ def test_parse_chat_messages_multiple_modals_multiple_messages_interleave( ...@@ -1601,7 +1676,6 @@ def test_parse_chat_messages_multiple_modals_multiple_messages_interleave(
}, },
], ],
qwen25omni_model_config_mm_interleaved, qwen25omni_model_config_mm_interleaved,
qwen25omni_tokenizer,
content_format="string", content_format="string",
) )
...@@ -1627,7 +1701,6 @@ def test_parse_chat_messages_multiple_modals_multiple_messages_interleave( ...@@ -1627,7 +1701,6 @@ def test_parse_chat_messages_multiple_modals_multiple_messages_interleave(
def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interleave( def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interleave(
qwen25omni_model_config_mm_interleaved, qwen25omni_model_config_mm_interleaved,
qwen25omni_tokenizer,
image_url, image_url,
video_url, video_url,
audio_url, audio_url,
...@@ -1671,7 +1744,6 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interl ...@@ -1671,7 +1744,6 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interl
}, },
], ],
qwen25omni_model_config_mm_interleaved, qwen25omni_model_config_mm_interleaved,
qwen25omni_tokenizer,
content_format="string", content_format="string",
) )
...@@ -1699,7 +1771,6 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interl ...@@ -1699,7 +1771,6 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interl
def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_messages_interleave( # noqa: E501 def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_messages_interleave( # noqa: E501
qwen25omni_model_config_mm_interleaved, qwen25omni_model_config_mm_interleaved,
qwen25omni_tokenizer,
image_url, image_url,
video_url, video_url,
audio_url, audio_url,
...@@ -1743,7 +1814,6 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_mes ...@@ -1743,7 +1814,6 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_mes
}, },
], ],
qwen25omni_model_config_mm_interleaved, qwen25omni_model_config_mm_interleaved,
qwen25omni_tokenizer,
content_format="string", content_format="string",
) )
...@@ -1775,7 +1845,6 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_mes ...@@ -1775,7 +1845,6 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_mes
def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_messages_interleave( # noqa: E501 def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_messages_interleave( # noqa: E501
qwen25omni_model_config_mm_interleaved, qwen25omni_model_config_mm_interleaved,
qwen25omni_tokenizer,
image_url, image_url,
video_url, video_url,
audio_url, audio_url,
...@@ -1811,7 +1880,6 @@ def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_message ...@@ -1811,7 +1880,6 @@ def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_message
}, },
], ],
qwen25omni_model_config_mm_interleaved, qwen25omni_model_config_mm_interleaved,
qwen25omni_tokenizer,
content_format="string", content_format="string",
) )
...@@ -1837,7 +1905,6 @@ def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_message ...@@ -1837,7 +1905,6 @@ def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_message
def test_parse_chat_messages_multiple_images_interleave_with_placeholders( def test_parse_chat_messages_multiple_images_interleave_with_placeholders(
phi3v_model_config_mm_interleaved, phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
image_url, image_url,
): ):
with pytest.raises( with pytest.raises(
...@@ -1861,7 +1928,6 @@ def test_parse_chat_messages_multiple_images_interleave_with_placeholders( ...@@ -1861,7 +1928,6 @@ def test_parse_chat_messages_multiple_images_interleave_with_placeholders(
} }
], ],
phi3v_model_config_mm_interleaved, phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
content_format="string", content_format="string",
) )
...@@ -2238,9 +2304,7 @@ def test_resolve_content_format_examples(template_path, expected_format): ...@@ -2238,9 +2304,7 @@ def test_resolve_content_format_examples(template_path, expected_format):
assert resolved_format == expected_format assert resolved_format == expected_format
def test_parse_chat_messages_include_thinking_chunk( def test_parse_chat_messages_include_thinking_chunk(mistral_model_config):
mistral_model_config, mistral_tokenizer
):
messages = [ messages = [
{ {
"role": "system", "role": "system",
...@@ -2270,7 +2334,6 @@ def test_parse_chat_messages_include_thinking_chunk( ...@@ -2270,7 +2334,6 @@ def test_parse_chat_messages_include_thinking_chunk(
conversation_with_thinking, _, _ = parse_chat_messages( conversation_with_thinking, _, _ = parse_chat_messages(
messages, messages,
mistral_model_config, mistral_model_config,
mistral_tokenizer,
content_format="openai", content_format="openai",
) )
...@@ -2354,7 +2417,6 @@ def test_apply_mistral_chat_template_thinking_chunk(): ...@@ -2354,7 +2417,6 @@ def test_apply_mistral_chat_template_thinking_chunk():
def test_parse_chat_messages_single_empty_audio_with_uuid( def test_parse_chat_messages_single_empty_audio_with_uuid(
qwen2_audio_model_config, qwen2_audio_model_config,
qwen2_audio_tokenizer,
): ):
audio_uuid = "abcd" audio_uuid = "abcd"
conversation, mm_data, mm_uuids = parse_chat_messages( conversation, mm_data, mm_uuids = parse_chat_messages(
...@@ -2372,7 +2434,6 @@ def test_parse_chat_messages_single_empty_audio_with_uuid( ...@@ -2372,7 +2434,6 @@ def test_parse_chat_messages_single_empty_audio_with_uuid(
} }
], ],
qwen2_audio_model_config, qwen2_audio_model_config,
qwen2_audio_tokenizer,
content_format="string", content_format="string",
) )
...@@ -2390,7 +2451,6 @@ def test_parse_chat_messages_single_empty_audio_with_uuid( ...@@ -2390,7 +2451,6 @@ def test_parse_chat_messages_single_empty_audio_with_uuid(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_parse_chat_messages_single_empty_audio_with_uuid_async( async def test_parse_chat_messages_single_empty_audio_with_uuid_async(
qwen2_audio_model_config, qwen2_audio_model_config,
qwen2_audio_tokenizer,
): ):
audio_uuid = "abcd" audio_uuid = "abcd"
conversation, mm_future, mm_uuids = parse_chat_messages_futures( conversation, mm_future, mm_uuids = parse_chat_messages_futures(
...@@ -2408,7 +2468,6 @@ async def test_parse_chat_messages_single_empty_audio_with_uuid_async( ...@@ -2408,7 +2468,6 @@ async def test_parse_chat_messages_single_empty_audio_with_uuid_async(
} }
], ],
qwen2_audio_model_config, qwen2_audio_model_config,
qwen2_audio_tokenizer,
content_format="string", content_format="string",
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment