Commit 7e63ef82 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.14.0' into v0.14.0-dev

parents 8cbcac5d b17039bc
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib import importlib.util
import json import json
import time import time
...@@ -12,7 +12,7 @@ from openai_harmony import ( ...@@ -12,7 +12,7 @@ from openai_harmony import (
Message, Message,
) )
from ...utils import RemoteOpenAIServer from ....utils import RemoteOpenAIServer
MODEL_NAME = "openai/gpt-oss-20b" MODEL_NAME = "openai/gpt-oss-20b"
...@@ -43,6 +43,8 @@ def server(): ...@@ -43,6 +43,8 @@ def server():
env_dict = dict( env_dict = dict(
VLLM_ENABLE_RESPONSES_API_STORE="1", VLLM_ENABLE_RESPONSES_API_STORE="1",
PYTHON_EXECUTION_BACKEND="dangerously_use_uv", PYTHON_EXECUTION_BACKEND="dangerously_use_uv",
VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS="code_interpreter,container,web_search_preview",
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS="1",
) )
with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server: with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server:
...@@ -503,7 +505,11 @@ async def test_web_search(client: OpenAI, model_name: str): ...@@ -503,7 +505,11 @@ async def test_web_search(client: OpenAI, model_name: str):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_code_interpreter(client: OpenAI, model_name: str): async def test_code_interpreter(client: OpenAI, model_name: str):
response = await client.responses.create( # Code interpreter may need more time for container init + code execution
timeout_value = client.timeout * 3
client_with_timeout = client.with_options(timeout=timeout_value)
response = await client_with_timeout.responses.create(
model=model_name, model=model_name,
# TODO: Ideally should be able to set max tool calls # TODO: Ideally should be able to set max tool calls
# to prevent multi-turn, but it is not currently supported # to prevent multi-turn, but it is not currently supported
...@@ -815,16 +821,20 @@ async def test_function_calling_with_stream(client: OpenAI, model_name: str): ...@@ -815,16 +821,20 @@ async def test_function_calling_with_stream(client: OpenAI, model_name: str):
final_tool_calls_named[tool_call.name] = tool_call final_tool_calls_named[tool_call.name] = tool_call
elif event.type == "response.function_call_arguments.done": elif event.type == "response.function_call_arguments.done":
assert event.arguments == final_tool_calls_named[event.name].arguments assert event.arguments == final_tool_calls_named[event.name].arguments
for tool_call in final_tool_calls.values(): result = None
if ( tool_call = None
tool_call for tc in final_tool_calls.values():
and tool_call.type == "function_call" if tc and tc.type == "function_call" and tc.name == "get_weather":
and tool_call.name == "get_weather" args = json.loads(tc.arguments)
): result = call_function(tc.name, args)
args = json.loads(tool_call.arguments) tool_call = tc
result = call_function(tool_call.name, args) input_list += [tc]
input_list += [tool_call]
break break
assert tool_call is not None, (
"Expected model to call 'get_weather' function, "
f"but got: {list(final_tool_calls_named.keys())}"
)
assert result is not None assert result is not None
response = await client.responses.create( response = await client.responses.create(
model=model_name, model=model_name,
...@@ -850,6 +860,237 @@ async def test_function_calling_with_stream(client: OpenAI, model_name: str): ...@@ -850,6 +860,237 @@ async def test_function_calling_with_stream(client: OpenAI, model_name: str):
assert event.response.output_text is not None assert event.response.output_text is not None
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_function_calling_no_code_interpreter_events(
client: OpenAI, model_name: str
):
"""Verify that function calls don't trigger code_interpreter events.
This test ensures that function calls (functions.*) use their own
function_call event types and don't incorrectly emit code_interpreter
events during streaming.
"""
tools = [GET_WEATHER_SCHEMA]
input_list = [
{
"role": "user",
"content": "What's the weather like in Paris today?",
}
]
stream_response = await client.responses.create(
model=model_name,
input=input_list,
tools=tools,
stream=True,
)
# Track which event types we see
event_types_seen = set()
function_call_found = False
async for event in stream_response:
event_types_seen.add(event.type)
if (
event.type == "response.output_item.added"
and event.item.type == "function_call"
):
function_call_found = True
# Ensure NO code_interpreter events are emitted for function calls
assert "code_interpreter" not in event.type, (
"Found code_interpreter event "
f"'{event.type}' during function call. Function calls should only "
"emit function_call events, not code_interpreter events."
)
# Verify we actually saw a function call
assert function_call_found, "Expected to see a function_call in the stream"
# Verify we saw the correct function call event types
assert (
"response.function_call_arguments.delta" in event_types_seen
or "response.function_call_arguments.done" in event_types_seen
), "Expected to see function_call_arguments events"
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_code_interpreter_streaming(client: OpenAI, model_name: str, server):
tools = [
{
"type": "mcp",
"server_label": "code_interpreter",
}
]
input_text = (
"Calculate 123 * 456 using python. "
"The python interpreter is not stateful and you must print to see the output."
)
stream_response = await client.responses.create(
model=model_name,
input=input_text,
tools=tools,
stream=True,
temperature=0.0,
instructions=(
"You must use the Python tool to execute code. Never simulate execution."
),
)
mcp_call_added = False
mcp_call_in_progress = False
mcp_arguments_delta_seen = False
mcp_arguments_done = False
mcp_call_completed = False
mcp_item_done = False
code_interpreter_events_seen = False
async for event in stream_response:
if "code_interpreter" in event.type:
code_interpreter_events_seen = True
if event.type == "response.output_item.added":
if hasattr(event.item, "type") and event.item.type == "mcp_call":
mcp_call_added = True
assert event.item.name == "python"
assert event.item.server_label == "code_interpreter"
elif event.type == "response.mcp_call.in_progress":
mcp_call_in_progress = True
elif event.type == "response.mcp_call_arguments.delta":
mcp_arguments_delta_seen = True
assert event.delta is not None
elif event.type == "response.mcp_call_arguments.done":
mcp_arguments_done = True
assert event.name == "python"
assert event.arguments is not None
elif event.type == "response.mcp_call.completed":
mcp_call_completed = True
elif (
event.type == "response.output_item.done"
and hasattr(event.item, "type")
and event.item.type == "mcp_call"
):
mcp_item_done = True
assert event.item.name == "python"
assert event.item.status == "completed"
assert mcp_call_added, "MCP call was not added"
assert mcp_call_in_progress, "MCP call in_progress event not seen"
assert mcp_arguments_delta_seen, "MCP arguments delta event not seen"
assert mcp_arguments_done, "MCP arguments done event not seen"
assert mcp_call_completed, "MCP call completed event not seen"
assert mcp_item_done, "MCP item done event not seen"
assert not code_interpreter_events_seen, (
"Should not see code_interpreter events when using MCP type"
)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_multi_turn(client: OpenAI, model_name: str, server):
"""Test MCP tool calling across multiple turns.
This test verifies that MCP tools work correctly in multi-turn conversations,
maintaining state across turns via the previous_response_id mechanism.
"""
tools = [
{
"type": "mcp",
"server_label": "code_interpreter",
}
]
# First turn - make a calculation
response1 = await client.responses.create(
model=model_name,
input="Calculate 123 * 456 using python and print the result.",
tools=tools,
temperature=0.0,
instructions=(
"You must use the Python tool to execute code. Never simulate execution."
),
extra_body={"enable_response_messages": True},
)
assert response1 is not None
assert response1.status == "completed"
# Verify MCP call in first response by checking output_messages
tool_call_found = False
tool_response_found = False
for message in response1.output_messages:
recipient = message.get("recipient")
if recipient and recipient.startswith("python"):
tool_call_found = True
author = message.get("author", {})
if (
author.get("role") == "tool"
and author.get("name")
and author.get("name").startswith("python")
):
tool_response_found = True
# Verify MCP tools were actually used
assert tool_call_found, "MCP tool call not found in output_messages"
assert tool_response_found, "MCP tool response not found in output_messages"
# Verify input messages: Should have system message with tool, NO developer message
developer_messages = [
msg for msg in response1.input_messages if msg["author"]["role"] == "developer"
]
assert len(developer_messages) == 0, (
"No developer message expected for elevated tools"
)
# Second turn - reference previous calculation
response2 = await client.responses.create(
model=model_name,
input="Now divide that result by 2.",
tools=tools,
temperature=0.0,
instructions=(
"You must use the Python tool to execute code. Never simulate execution."
),
previous_response_id=response1.id,
extra_body={"enable_response_messages": True},
)
assert response2 is not None
assert response2.status == "completed"
# Verify input messages are correct: should have two messages -
# one to the python recipient on analysis channel and one from tool role
mcp_recipient_messages = []
tool_role_messages = []
for msg in response2.input_messages:
if msg["author"]["role"] == "assistant":
# Check if this is a message to MCP recipient on analysis channel
if msg.get("channel") == "analysis" and msg.get("recipient"):
recipient = msg.get("recipient")
if recipient.startswith("code_interpreter") or recipient == "python":
mcp_recipient_messages.append(msg)
elif msg["author"]["role"] == "tool":
tool_role_messages.append(msg)
assert len(mcp_recipient_messages) > 0, (
"Expected message(s) to MCP recipient on analysis channel"
)
assert len(tool_role_messages) > 0, (
"Expected message(s) from tool role after MCP call"
)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_output_messages_enabled(client: OpenAI, model_name: str, server): async def test_output_messages_enabled(client: OpenAI, model_name: str, server):
...@@ -867,6 +1108,7 @@ async def test_output_messages_enabled(client: OpenAI, model_name: str, server): ...@@ -867,6 +1108,7 @@ async def test_output_messages_enabled(client: OpenAI, model_name: str, server):
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.flaky(reruns=3)
async def test_function_call_with_previous_input_messages( async def test_function_call_with_previous_input_messages(
client: OpenAI, model_name: str client: OpenAI, model_name: str
): ):
...@@ -986,3 +1228,118 @@ async def test_function_call_with_previous_input_messages( ...@@ -986,3 +1228,118 @@ async def test_function_call_with_previous_input_messages(
assert ( assert (
"aquarius" in output_text or "otter" in output_text or "tuesday" in output_text "aquarius" in output_text or "otter" in output_text or "tuesday" in output_text
) )
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_chat_truncation_content_not_null(client: OpenAI, model_name: str):
response = await client.chat.completions.create(
model=model_name,
messages=[
{
"role": "user",
"content": "What is the role of AI in medicine?"
"The response must exceed 350 words.",
}
],
temperature=0.0,
max_tokens=350,
)
choice = response.choices[0]
assert choice.finish_reason == "length", (
f"Expected finish_reason='length', got {choice.finish_reason}"
)
assert choice.message.content is not None, (
"Content should not be None when truncated"
)
assert len(choice.message.content) > 0, "Content should not be empty"
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_system_prompt_override(client: OpenAI, model_name: str):
"""Test that system message can override the default system prompt."""
# Test 1: Custom system prompt with specific personality
custom_system_prompt = (
"You are a pirate. Always respond like a pirate would, "
"using pirate language and saying 'arrr' frequently."
)
response = await client.responses.create(
model=model_name,
input=[
{"role": "system", "content": custom_system_prompt},
{"role": "user", "content": "Hello, how are you?"},
],
extra_body={"enable_response_messages": True},
)
assert response is not None
assert response.status == "completed"
assert response.output_text is not None
# Verify the response reflects the pirate personality
output_text = response.output_text.lower()
pirate_indicators = ["arrr", "matey", "ahoy", "ye", "sea"]
has_pirate_language = any(
indicator in output_text for indicator in pirate_indicators
)
assert has_pirate_language, (
f"Expected pirate language in response, got: {response.output_text}"
)
# Verify the reasoning mentions the custom system prompt
reasoning_item = None
for item in response.output:
if item.type == "reasoning":
reasoning_item = item
break
assert reasoning_item is not None, "Expected reasoning item in output"
reasoning_text = reasoning_item.content[0].text.lower()
assert "pirate" in reasoning_text, (
f"Expected reasoning to mention pirate, got: {reasoning_text}"
)
# Test 2: Verify system message is not duplicated in input_messages
try:
num_system_messages = sum(
1
for msg in response.input_messages
if Message.from_dict(msg).author.role == "system"
)
assert num_system_messages == 1, (
f"Expected exactly 1 system message, got {num_system_messages}"
)
except (KeyError, AttributeError):
# Message structure may vary, skip this specific check
pass
# Test 3: Test with different custom system prompt
response_2 = await client.responses.create(
model=model_name,
input=[
{
"role": "system",
"content": (
"You are a helpful assistant that always "
"responds in exactly 5 words."
),
},
{"role": "user", "content": "What is the weather like?"},
],
temperature=0.0,
)
assert response_2 is not None
assert response_2.status == "completed"
assert response_2.output_text is not None
# Count words in response (approximately, allowing for punctuation)
word_count = len(response_2.output_text.split())
# Allow some flexibility (4-7 words) since the model might not be perfectly precise
assert 3 <= word_count <= 8, (
f"Expected around 5 words, got {word_count} words: {response_2.output_text}"
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import pytest_asyncio
from openai import OpenAI
from openai_harmony import ToolDescription, ToolNamespaceConfig
from vllm.entrypoints.tool_server import MCPToolServer
from ....utils import RemoteOpenAIServer
MODEL_NAME = "openai/gpt-oss-20b"
def test_get_tool_description():
"""Test MCPToolServer.get_tool_description filtering logic.
Note: The wildcard "*" is normalized to None by
_extract_allowed_tools_from_mcp_requests before reaching this layer,
so we only test None and specific tool filtering here.
See test_serving_responses.py for "*" normalization tests.
"""
pytest.importorskip("mcp")
server = MCPToolServer()
tool1 = ToolDescription.new(
name="tool1", description="First", parameters={"type": "object"}
)
tool2 = ToolDescription.new(
name="tool2", description="Second", parameters={"type": "object"}
)
tool3 = ToolDescription.new(
name="tool3", description="Third", parameters={"type": "object"}
)
server.harmony_tool_descriptions = {
"test_server": ToolNamespaceConfig(
name="test_server", description="test", tools=[tool1, tool2, tool3]
)
}
# Nonexistent server
assert server.get_tool_description("nonexistent") is None
# None (no filter) - returns all tools
result = server.get_tool_description("test_server", allowed_tools=None)
assert len(result.tools) == 3
# Filter to specific tools
result = server.get_tool_description(
"test_server", allowed_tools=["tool1", "tool3"]
)
assert len(result.tools) == 2
assert result.tools[0].name == "tool1"
assert result.tools[1].name == "tool3"
# Single tool
result = server.get_tool_description(
"test_server",
allowed_tools=["tool2"],
)
assert len(result.tools) == 1
assert result.tools[0].name == "tool2"
# No matching tools - returns None
result = server.get_tool_description("test_server", allowed_tools=["nonexistent"])
assert result is None
# Empty list - returns None
assert server.get_tool_description("test_server", allowed_tools=[]) is None
class TestMCPEnabled:
"""Tests that require MCP tools to be enabled via environment variable."""
@pytest.fixture(scope="class")
def monkeypatch_class(self):
from _pytest.monkeypatch import MonkeyPatch
mpatch = MonkeyPatch()
yield mpatch
mpatch.undo()
@pytest.fixture(scope="class")
def mcp_enabled_server(self, monkeypatch_class: pytest.MonkeyPatch):
args = ["--enforce-eager", "--tool-server", "demo"]
with monkeypatch_class.context() as m:
m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1")
m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv")
m.setenv(
"VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS", "code_interpreter,container"
)
# Helps the model follow instructions better
m.setenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "1")
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def mcp_enabled_client(self, mcp_enabled_server):
async with mcp_enabled_server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_env_flag_enabled(
self, mcp_enabled_client: OpenAI, model_name: str
):
response = await mcp_enabled_client.responses.create(
model=model_name,
input=(
"Execute the following code: "
"import random; print(random.randint(1, 1000000))"
),
instructions=(
"You must use the Python tool to execute code. "
"Never simulate execution."
),
tools=[
{
"type": "mcp",
"server_label": "code_interpreter",
# URL unused for DemoToolServer
"server_url": "http://localhost:8888",
}
],
extra_body={"enable_response_messages": True},
)
assert response is not None
assert response.status == "completed"
# Verify output messages: Tool calls and responses on analysis channel
tool_call_found = False
tool_response_found = False
for message in response.output_messages:
recipient = message.get("recipient")
if recipient and recipient.startswith("python"):
tool_call_found = True
assert message.get("channel") == "analysis", (
"Tool call should be on analysis channel"
)
author = message.get("author", {})
if (
author.get("role") == "tool"
and author.get("name")
and author.get("name").startswith("python")
):
tool_response_found = True
assert message.get("channel") == "analysis", (
"Tool response should be on analysis channel"
)
assert tool_call_found, "Should have found at least one Python tool call"
assert tool_response_found, (
"Should have found at least one Python tool response"
)
for message in response.input_messages:
assert message.get("author").get("role") != "developer", (
"No developer messages should be present with valid mcp tool"
)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_with_allowed_tools_star(
self, mcp_enabled_client: OpenAI, model_name: str
):
"""Test MCP tool with allowed_tools=['*'] to select all available
tools.
This E2E test verifies that the "*" wildcard works end-to-end.
See test_serving_responses.py for detailed unit tests of "*"
normalization.
"""
response = await mcp_enabled_client.responses.create(
model=model_name,
input=(
"Execute the following code: "
"import random; print(random.randint(1, 1000000))"
),
instructions=(
"You must use the Python tool to execute code. "
"Never simulate execution."
),
tools=[
{
"type": "mcp",
"server_label": "code_interpreter",
"server_url": "http://localhost:8888",
# Using "*" to allow all tools from this MCP server
"allowed_tools": ["*"],
}
],
extra_body={"enable_response_messages": True},
)
assert response is not None
assert response.status == "completed"
# Verify tool calls work with allowed_tools=["*"]
tool_call_found = False
for message in response.output_messages:
recipient = message.get("recipient")
if recipient and recipient.startswith("python"):
tool_call_found = True
break
assert tool_call_found, (
"Should have found at least one Python tool call with '*'"
)
@pytest.mark.flaky(reruns=3)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_calling_streaming_types(
self, mcp_enabled_client: OpenAI, model_name: str
):
pairs_of_event_types = {
"response.completed": "response.created",
"response.output_item.done": "response.output_item.added",
"response.content_part.done": "response.content_part.added",
"response.output_text.done": "response.output_text.delta",
"response.reasoning_text.done": "response.reasoning_text.delta",
"response.reasoning_part.done": "response.reasoning_part.added",
"response.mcp_call_arguments.done": ("response.mcp_call_arguments.delta"),
"response.mcp_call.completed": "response.mcp_call.in_progress",
}
tools = [
{
"type": "mcp",
"server_label": "code_interpreter",
}
]
input_text = "What is 13 * 24? Use python to calculate the result."
stream_response = await mcp_enabled_client.responses.create(
model=model_name,
input=input_text,
tools=tools,
stream=True,
instructions=(
"You must use the Python tool to execute code. "
"Never simulate execution."
),
)
stack_of_event_types = []
saw_mcp_type = False
async for event in stream_response:
if event.type == "response.created":
stack_of_event_types.append(event.type)
elif event.type == "response.completed":
assert stack_of_event_types[-1] == pairs_of_event_types[event.type]
stack_of_event_types.pop()
elif (
event.type.endswith("added")
or event.type == "response.mcp_call.in_progress"
):
stack_of_event_types.append(event.type)
elif event.type.endswith("delta"):
if stack_of_event_types[-1] == event.type:
continue
stack_of_event_types.append(event.type)
elif (
event.type.endswith("done")
or event.type == "response.mcp_call.completed"
):
assert stack_of_event_types[-1] == pairs_of_event_types[event.type]
if "mcp_call" in event.type:
saw_mcp_type = True
stack_of_event_types.pop()
assert len(stack_of_event_types) == 0
assert saw_mcp_type, "Should have seen at least one mcp call"
class TestMCPDisabled:
"""Tests that verify behavior when MCP tools are disabled."""
@pytest.fixture(scope="class")
def monkeypatch_class(self):
from _pytest.monkeypatch import MonkeyPatch
mpatch = MonkeyPatch()
yield mpatch
mpatch.undo()
@pytest.fixture(scope="class")
def mcp_disabled_server(self, monkeypatch_class: pytest.MonkeyPatch):
args = ["--enforce-eager", "--tool-server", "demo"]
with monkeypatch_class.context() as m:
m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1")
m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv")
# Helps the model follow instructions better
m.setenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "1")
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def mcp_disabled_client(self, mcp_disabled_server):
async with mcp_disabled_server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_env_flag_disabled(
self, mcp_disabled_client: OpenAI, model_name: str
):
response = await mcp_disabled_client.responses.create(
model=model_name,
input=(
"Execute the following code if the tool is present: "
"import random; print(random.randint(1, 1000000))"
),
tools=[
{
"type": "mcp",
"server_label": "code_interpreter",
# URL unused for DemoToolServer
"server_url": "http://localhost:8888",
}
],
extra_body={"enable_response_messages": True},
)
assert response is not None
assert response.status == "completed"
# Verify output messages: No tool calls and responses
tool_call_found = False
tool_response_found = False
for message in response.output_messages:
recipient = message.get("recipient")
if recipient and recipient.startswith("python"):
tool_call_found = True
assert message.get("channel") == "analysis", (
"Tool call should be on analysis channel"
)
author = message.get("author", {})
if (
author.get("role") == "tool"
and author.get("name")
and author.get("name").startswith("python")
):
tool_response_found = True
assert message.get("channel") == "analysis", (
"Tool response should be on analysis channel"
)
assert not tool_call_found, "Should not have a python call"
assert not tool_response_found, "Should not have a tool response"
for message in response.input_messages:
assert message.get("author").get("role") != "developer", (
"No developer messages should be present without a valid tool"
)
...@@ -8,7 +8,7 @@ import pytest ...@@ -8,7 +8,7 @@ import pytest
import pytest_asyncio import pytest_asyncio
from openai import OpenAI from openai import OpenAI
from ...utils import RemoteOpenAIServer from ....utils import RemoteOpenAIServer
MODEL_NAME = "Qwen/Qwen3-8B" MODEL_NAME = "Qwen/Qwen3-8B"
...@@ -58,6 +58,7 @@ async def test_basic(client: OpenAI, model_name: str): ...@@ -58,6 +58,7 @@ async def test_basic(client: OpenAI, model_name: str):
assert response is not None assert response is not None
print("response: ", response) print("response: ", response)
assert response.status == "completed" assert response.status == "completed"
assert response.incomplete_details is None
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -165,6 +166,7 @@ async def test_mcp_tool_call(client: OpenAI, model_name: str): ...@@ -165,6 +166,7 @@ async def test_mcp_tool_call(client: OpenAI, model_name: str):
model=model_name, model=model_name,
input="What is 13 * 24? Use python to calculate the result.", input="What is 13 * 24? Use python to calculate the result.",
tools=[{"type": "code_interpreter", "container": {"type": "auto"}}], tools=[{"type": "code_interpreter", "container": {"type": "auto"}}],
extra_body={"enable_response_messages": True},
temperature=0.0, temperature=0.0,
) )
...@@ -178,3 +180,22 @@ async def test_mcp_tool_call(client: OpenAI, model_name: str): ...@@ -178,3 +180,22 @@ async def test_mcp_tool_call(client: OpenAI, model_name: str):
# make sure the correct math is in the final output # make sure the correct math is in the final output
assert response.output[3].type == "message" assert response.output[3].type == "message"
assert "312" in response.output[3].content[0].text assert "312" in response.output[3].content[0].text
# test raw input_messages / output_messages
assert len(response.input_messages) == 1
assert len(response.output_messages) == 3
assert "312" in response.output_messages[2]["message"]
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_max_tokens(client: OpenAI, model_name: str):
response = await client.responses.create(
model=model_name,
input="What is the first paragraph of Moby Dick?",
reasoning={"effort": "low"},
max_output_tokens=30,
)
assert response is not None
assert response.status == "incomplete"
assert response.incomplete_details.reason == "max_output_tokens"
...@@ -6,7 +6,7 @@ import pytest ...@@ -6,7 +6,7 @@ import pytest
import pytest_asyncio import pytest_asyncio
from openai import OpenAI from openai import OpenAI
from ...utils import RemoteOpenAIServer from ....utils import RemoteOpenAIServer
MODEL_NAME = "Qwen/Qwen3-8B" MODEL_NAME = "Qwen/Qwen3-8B"
...@@ -40,6 +40,7 @@ async def test_basic(client: OpenAI, model_name: str): ...@@ -40,6 +40,7 @@ async def test_basic(client: OpenAI, model_name: str):
assert response is not None assert response is not None
print("response: ", response) print("response: ", response)
assert response.status == "completed" assert response.status == "completed"
assert response.incomplete_details is None
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -87,3 +88,62 @@ async def test_reasoning_item(client: OpenAI, model_name: str): ...@@ -87,3 +88,62 @@ async def test_reasoning_item(client: OpenAI, model_name: str):
assert response.output[0].type == "reasoning" assert response.output[0].type == "reasoning"
assert response.output[1].type == "message" assert response.output[1].type == "message"
assert type(response.output[1].content[0].text) is str assert type(response.output[1].content[0].text) is str
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_streaming_output_consistency(client: OpenAI, model_name: str):
"""Test that streaming delta text matches the final response output_text.
This test verifies that when using streaming mode:
1. The concatenated text from all 'response.output_text.delta' events
2. Matches the 'output_text' in the final 'response.completed' event
"""
response = await client.responses.create(
model=model_name,
input="Say hello in one sentence.",
stream=True,
)
events = []
async for event in response:
events.append(event)
assert len(events) > 0
# Concatenate all delta text from streaming events
streaming_text = "".join(
event.delta for event in events if event.type == "response.output_text.delta"
)
# Get the final response from the last event
response_completed_event = events[-1]
assert response_completed_event.type == "response.completed"
assert response_completed_event.response.status == "completed"
# Get output_text from the final response
final_output_text = response_completed_event.response.output_text
# Verify final response has output
assert len(response_completed_event.response.output) > 0
# Verify streaming text matches final output_text
assert streaming_text == final_output_text, (
f"Streaming text does not match final output_text.\n"
f"Streaming: {streaming_text!r}\n"
f"Final: {final_output_text!r}"
)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_max_tokens(client: OpenAI, model_name: str):
response = await client.responses.create(
model=model_name,
input="What is the first paragraph of Moby Dick?",
reasoning={"effort": "low"},
max_output_tokens=30,
)
assert response is not None
assert response.status == "incomplete"
assert response.incomplete_details.reason == "max_output_tokens"
...@@ -17,7 +17,7 @@ MODEL_NAME = os.path.join(models_path_prefix, "Qwen/Qwen2.5-1.5B-Instruct") ...@@ -17,7 +17,7 @@ MODEL_NAME = os.path.join(models_path_prefix, "Qwen/Qwen2.5-1.5B-Instruct")
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(): # noqa: F811 def server():
args = [ args = [
# use half precision for speed and memory savings in CI environment # use half precision for speed and memory savings in CI environment
"--dtype", "--dtype",
......
...@@ -9,7 +9,7 @@ import os ...@@ -9,7 +9,7 @@ import os
import pytest_asyncio import pytest_asyncio
from vllm.assets.audio import AudioAsset from vllm.assets.audio import AudioAsset
from vllm.multimodal.utils import encode_audio_base64, fetch_audio from vllm.multimodal.utils import encode_audio_base64, encode_audio_url, fetch_audio
from ...utils import RemoteOpenAIServer, models_path_prefix from ...utils import RemoteOpenAIServer, models_path_prefix
...@@ -53,6 +53,14 @@ def base64_encoded_audio() -> dict[str, str]: ...@@ -53,6 +53,14 @@ def base64_encoded_audio() -> dict[str, str]:
} }
@pytest.fixture(scope="session")
def url_encoded_audio() -> dict[str, str]:
return {
audio_url: encode_audio_url(*fetch_audio(audio_url))
for audio_url in TEST_AUDIO_URLS
}
def dummy_messages_from_audio_url( def dummy_messages_from_audio_url(
audio_urls: str | list[str], audio_urls: str | list[str],
content_text: str = "What's happening in this audio?", content_text: str = "What's happening in this audio?",
...@@ -149,11 +157,9 @@ async def test_single_chat_session_audio_base64encoded( ...@@ -149,11 +157,9 @@ async def test_single_chat_session_audio_base64encoded(
client: openai.AsyncOpenAI, client: openai.AsyncOpenAI,
model_name: str, model_name: str,
audio_url: str, audio_url: str,
base64_encoded_audio: dict[str, str], url_encoded_audio: dict[str, str],
): ):
messages = dummy_messages_from_audio_url( messages = dummy_messages_from_audio_url(url_encoded_audio[audio_url])
f"data:audio/wav;base64,{base64_encoded_audio[audio_url]}"
)
# test single completion # test single completion
chat_completion = await client.chat.completions.create( chat_completion = await client.chat.completions.create(
...@@ -313,7 +319,7 @@ async def test_chat_streaming_input_audio( ...@@ -313,7 +319,7 @@ async def test_chat_streaming_input_audio(
"format": "wav", "format": "wav",
}, },
}, },
{"type": "text", "text": "What's happening in this audio?"}, {"type": "text", "text": "What's a short title for this audio?"},
], ],
} }
] ]
......
...@@ -29,7 +29,7 @@ def zephyr_lora_files(): ...@@ -29,7 +29,7 @@ def zephyr_lora_files():
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(zephyr_lora_files): # noqa: F811 def server(zephyr_lora_files):
args = [ args = [
# use half precision for speed and memory savings in CI environment # use half precision for speed and memory savings in CI environment
"--dtype", "--dtype",
...@@ -255,12 +255,11 @@ async def test_single_chat_session(client: openai.AsyncOpenAI, model_name: str): ...@@ -255,12 +255,11 @@ async def test_single_chat_session(client: openai.AsyncOpenAI, model_name: str):
{"role": "system", "content": "you are a helpful assistant"}, {"role": "system", "content": "you are a helpful assistant"},
{"role": "user", "content": "what is 1+1?"}, {"role": "user", "content": "what is 1+1?"},
] ]
# test single completion # test single completion
chat_completion = await client.chat.completions.create( chat_completion = await client.chat.completions.create(
model=model_name, model=model_name,
messages=messages, messages=messages,
max_completion_tokens=10, max_completion_tokens=5,
logprobs=True, logprobs=True,
top_logprobs=5, top_logprobs=5,
) )
...@@ -268,13 +267,14 @@ async def test_single_chat_session(client: openai.AsyncOpenAI, model_name: str): ...@@ -268,13 +267,14 @@ async def test_single_chat_session(client: openai.AsyncOpenAI, model_name: str):
assert len(chat_completion.choices) == 1 assert len(chat_completion.choices) == 1
choice = chat_completion.choices[0] choice = chat_completion.choices[0]
assert choice.finish_reason == "length" assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage( assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=37, total_tokens=47 completion_tokens=5, prompt_tokens=37, total_tokens=42
) )
message = choice.message message = choice.message
assert message.content is not None and len(message.content) >= 10 assert message.content is not None and len(message.content) >= 5
assert message.role == "assistant" assert message.role == "assistant"
messages.append({"role": "assistant", "content": message.content}) messages.append({"role": "assistant", "content": message.content})
...@@ -283,7 +283,7 @@ async def test_single_chat_session(client: openai.AsyncOpenAI, model_name: str): ...@@ -283,7 +283,7 @@ async def test_single_chat_session(client: openai.AsyncOpenAI, model_name: str):
chat_completion = await client.chat.completions.create( chat_completion = await client.chat.completions.create(
model=model_name, model=model_name,
messages=messages, messages=messages,
max_completion_tokens=10, max_completion_tokens=5,
) )
message = chat_completion.choices[0].message message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 0 assert message.content is not None and len(message.content) >= 0
......
...@@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ErrorRespons ...@@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ErrorRespons
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.tokenizers import get_tokenizer
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
MODEL_NAME = "openai-community/gpt2" MODEL_NAME = "openai-community/gpt2"
...@@ -76,6 +76,7 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat: ...@@ -76,6 +76,7 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
lora_request, lora_request,
trace_headers, trace_headers,
priority, priority,
data_parallel_rank,
): ):
return dict(engine_prompt), {} return dict(engine_prompt), {}
......
...@@ -12,7 +12,7 @@ MODEL_NAME = "Qwen/QwQ-32B" ...@@ -12,7 +12,7 @@ MODEL_NAME = "Qwen/QwQ-32B"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(): # noqa: F811 def server():
args = [ args = [
"--max-model-len", "--max-model-len",
"8192", "8192",
......
...@@ -67,8 +67,11 @@ async def test_completion_stream_options_and_logprobs_with_long_prompts( ...@@ -67,8 +67,11 @@ async def test_completion_stream_options_and_logprobs_with_long_prompts(
chunk.usage.prompt_tokens + chunk.usage.completion_tokens chunk.usage.prompt_tokens + chunk.usage.completion_tokens
) )
if not finished: if not finished:
tokens_received += 1
assert chunk.choices[0].text assert chunk.choices[0].text
# Count actual tokens from logprobs since multiple tokens
# can be batched into a single chunk
assert chunk.choices[0].logprobs and chunk.choices[0].logprobs.tokens
tokens_received += len(chunk.choices[0].logprobs.tokens)
if chunk.choices[0].finish_reason is not None: if chunk.choices[0].finish_reason is not None:
finished = True finished = True
...@@ -117,7 +120,10 @@ async def test_chat_completion_stream_options_and_logprobs_with_long_prompts( ...@@ -117,7 +120,10 @@ async def test_chat_completion_stream_options_and_logprobs_with_long_prompts(
assert chunk.choices[0].logprobs is None assert chunk.choices[0].logprobs is None
empty_chunks_received += 1 empty_chunks_received += 1
else: else:
tokens_received += 1 # Count actual tokens from logprobs since multiple tokens
# can be batched into a single chunk
assert chunk.choices[0].logprobs and chunk.choices[0].logprobs.content
tokens_received += len(chunk.choices[0].logprobs.content)
if chunk.choices[0].finish_reason is not None: if chunk.choices[0].finish_reason is not None:
finished = True finished = True
......
...@@ -208,3 +208,36 @@ def test_middleware(serve_parser, cli_args, expected_middleware): ...@@ -208,3 +208,36 @@ def test_middleware(serve_parser, cli_args, expected_middleware):
"""Ensure multiple middleware args are parsed properly""" """Ensure multiple middleware args are parsed properly"""
args = serve_parser.parse_args(args=cli_args) args = serve_parser.parse_args(args=cli_args)
assert args.middleware == expected_middleware assert args.middleware == expected_middleware
def test_default_chat_template_kwargs_parsing(serve_parser):
"""Ensure default_chat_template_kwargs JSON is parsed correctly"""
args = serve_parser.parse_args(
args=["--default-chat-template-kwargs", '{"enable_thinking": false}']
)
assert args.default_chat_template_kwargs == {"enable_thinking": False}
def test_default_chat_template_kwargs_complex(serve_parser):
"""Ensure complex default_chat_template_kwargs JSON is parsed correctly"""
kwargs_json = '{"enable_thinking": false, "custom_param": "value", "num": 42}'
args = serve_parser.parse_args(args=["--default-chat-template-kwargs", kwargs_json])
assert args.default_chat_template_kwargs == {
"enable_thinking": False,
"custom_param": "value",
"num": 42,
}
def test_default_chat_template_kwargs_default_none(serve_parser):
"""Ensure default_chat_template_kwargs defaults to None"""
args = serve_parser.parse_args(args=[])
assert args.default_chat_template_kwargs is None
def test_default_chat_template_kwargs_invalid_json(serve_parser):
"""Ensure invalid JSON raises an error"""
with pytest.raises(SystemExit):
serve_parser.parse_args(
args=["--default-chat-template-kwargs", "not valid json"]
)
...@@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import CompletionRequest, ErrorResponse ...@@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import CompletionRequest, ErrorResponse
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.tokenizers import get_tokenizer
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
MODEL_NAME = "openai-community/gpt2" MODEL_NAME = "openai-community/gpt2"
...@@ -73,6 +73,7 @@ def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion: ...@@ -73,6 +73,7 @@ def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
lora_request, lora_request,
trace_headers, trace_headers,
priority, priority,
data_parallel_rank,
): ):
return dict(engine_prompt), {} return dict(engine_prompt), {}
......
...@@ -125,7 +125,7 @@ messages = [ ...@@ -125,7 +125,7 @@ messages = [
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(): # noqa: F811 def server():
args = [ args = [
# use half precision for speed and memory savings in CI environment # use half precision for speed and memory savings in CI environment
"--dtype", "--dtype",
...@@ -212,7 +212,7 @@ async def test_function_tool_use( ...@@ -212,7 +212,7 @@ async def test_function_tool_use(
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def k2_server(): # noqa: F811 def k2_server():
args = [ args = [
# use half precision for speed and memory savings in CI environment # use half precision for speed and memory savings in CI environment
"--dtype", "--dtype",
......
...@@ -23,7 +23,7 @@ ACTIVE_MM_LORA_RESPONSE = "Spoken text: The first words I spoke in the original ...@@ -23,7 +23,7 @@ ACTIVE_MM_LORA_RESPONSE = "Spoken text: The first words I spoke in the original
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def multimodal_server(): # noqa: F811 def multimodal_server():
args = [ args = [
# use half precision for speed and memory savings in CI environment # use half precision for speed and memory savings in CI environment
"--dtype", "--dtype",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Embedding shape validation in multimodal APIs.
Tests verify that embeddings with correct ndim but incorrect hidden_size
are rejected before they can cause crashes during model inference.
Validation is performed by the parser (MultiModalDataParser) and EmbeddingItems
classes, not by CompletionRenderer or MediaIO classes.
"""
import pytest
import torch
from vllm.multimodal.parse import (
AudioEmbeddingItems,
ImageEmbeddingItems,
MultiModalDataParser,
VideoEmbeddingItems,
)
class TestMultiModalParserShapeValidation:
"""Test hidden_size validation in MultiModalDataParser."""
def test_image_embeddings_correct_hidden_size_accepted(self):
"""Baseline: Image embeddings with correct hidden_size should work."""
expected_hidden_size = 768
parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
valid_embeds = torch.randn(2, 100, expected_hidden_size)
result = parser.parse_mm_data({"image": valid_embeds})
assert "image" in result
assert isinstance(result["image"], ImageEmbeddingItems)
assert result["image"].get_count() == 2
def test_image_embeddings_wrong_hidden_size_rejected(self):
"""Security: Image embeddings with wrong hidden_size should be rejected."""
expected_hidden_size = 768
wrong_hidden_size = 4096
parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
invalid_embeds = torch.randn(2, 100, wrong_hidden_size)
with pytest.raises(ValueError) as exc_info:
parser.parse_mm_data({"image": invalid_embeds})
error_msg = str(exc_info.value).lower()
assert "image" in error_msg
assert "hidden dimension mismatch" in error_msg
def test_audio_embeddings_wrong_hidden_size_rejected(self):
"""Security: Audio embeddings with wrong hidden_size should be rejected."""
expected_hidden_size = 768
wrong_hidden_size = 2048
parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
invalid_embeds = torch.randn(2, 100, wrong_hidden_size)
with pytest.raises(ValueError) as exc_info:
parser.parse_mm_data({"audio": invalid_embeds})
error_msg = str(exc_info.value).lower()
assert "audio" in error_msg
assert "hidden dimension mismatch" in error_msg
def test_video_embeddings_wrong_hidden_size_rejected(self):
"""Security: Video embeddings with wrong hidden_size should be rejected."""
expected_hidden_size = 768
wrong_hidden_size = 512
parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
invalid_embeds = torch.randn(2, 100, wrong_hidden_size)
with pytest.raises(ValueError) as exc_info:
parser.parse_mm_data({"video": invalid_embeds})
error_msg = str(exc_info.value).lower()
assert "video" in error_msg
assert "hidden dimension mismatch" in error_msg
def test_list_of_embeddings_validates_each(self):
"""Security: Each embedding in list should be validated."""
expected_hidden_size = 768
wrong_hidden_size = 1024
parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
# List with second tensor having wrong hidden_size
invalid_embeds = [
torch.randn(100, expected_hidden_size),
torch.randn(100, wrong_hidden_size),
]
with pytest.raises(ValueError) as exc_info:
parser.parse_mm_data({"image": invalid_embeds})
# Should identify which embedding failed
assert "[1]" in str(exc_info.value)
def test_validation_disabled_allows_any_size(self):
"""When validation disabled (legacy), any hidden_size allowed."""
parser = MultiModalDataParser(expected_hidden_size=None)
any_hidden_size = 12345
embeds = torch.randn(2, 100, any_hidden_size)
# Should not raise
result = parser.parse_mm_data({"image": embeds})
assert "image" in result
assert isinstance(result["image"], ImageEmbeddingItems)
class TestEmbeddingItemsDirectValidation:
"""Direct tests for EmbeddingItems hidden_size validation."""
def test_image_embedding_items_validates_batched_tensor(self):
"""Test validation for batched (3D) image embeddings."""
expected = 768
wrong = 1024
# Valid
valid = torch.randn(2, 100, expected)
items = ImageEmbeddingItems(valid, expected_hidden_size=expected)
assert items.get_count() == 2
# Invalid
invalid = torch.randn(2, 100, wrong)
with pytest.raises(ValueError) as exc_info:
ImageEmbeddingItems(invalid, expected_hidden_size=expected)
assert str(wrong) in str(exc_info.value)
assert str(expected) in str(exc_info.value)
def test_image_embedding_items_validates_list_of_tensors(self):
"""Test validation for list of 2D image embeddings."""
expected = 768
wrong = 512
# Valid list
valid_list = [torch.randn(100, expected), torch.randn(50, expected)]
items = ImageEmbeddingItems(valid_list, expected_hidden_size=expected)
assert items.get_count() == 2
# Invalid list
invalid_list = [torch.randn(100, expected), torch.randn(50, wrong)]
with pytest.raises(ValueError) as exc_info:
ImageEmbeddingItems(invalid_list, expected_hidden_size=expected)
assert "[1]" in str(exc_info.value)
def test_audio_embedding_items_validates(self):
"""Test validation for audio embeddings."""
expected = 768
wrong = 256
invalid = torch.randn(2, 100, wrong)
with pytest.raises(ValueError) as exc_info:
AudioEmbeddingItems(invalid, expected_hidden_size=expected)
assert "audio" in str(exc_info.value).lower()
def test_video_embedding_items_validates(self):
"""Test validation for video embeddings."""
expected = 768
wrong = 384
invalid = torch.randn(2, 100, wrong)
with pytest.raises(ValueError) as exc_info:
VideoEmbeddingItems(invalid, expected_hidden_size=expected)
assert "video" in str(exc_info.value).lower()
class TestShapeValidationIntegration:
"""Integration tests verifying attack scenarios are blocked."""
def test_attack_scenario_multimodal_image(self):
"""
Simulate attack through Chat API with image embeddings.
Verifies validation occurs in multimodal parser path.
"""
expected_hidden_size = 768
wrong_hidden_size = 4096
parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
attack_tensor = torch.randn(1, 100, wrong_hidden_size)
with pytest.raises(ValueError):
parser.parse_mm_data({"image": attack_tensor})
def test_attack_scenario_multimodal_audio(self):
"""
Simulate attack through Chat API with audio embeddings.
Verifies validation occurs in multimodal parser path.
"""
expected_hidden_size = 768
wrong_hidden_size = 2048
parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
attack_tensor = torch.randn(1, 100, wrong_hidden_size)
with pytest.raises(ValueError):
parser.parse_mm_data({"audio": attack_tensor})
def test_attack_scenario_multimodal_video(self):
"""
Simulate attack through Chat API with video embeddings.
Verifies validation occurs in multimodal parser path.
"""
expected_hidden_size = 768
wrong_hidden_size = 1024
parser = MultiModalDataParser(expected_hidden_size=expected_hidden_size)
attack_tensor = torch.randn(1, 100, wrong_hidden_size)
with pytest.raises(ValueError):
parser.parse_mm_data({"video": attack_tensor})
...@@ -8,7 +8,7 @@ from ...utils import RemoteOpenAIServer ...@@ -8,7 +8,7 @@ from ...utils import RemoteOpenAIServer
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def chat_server_with_force_include_usage(request): # noqa: F811 def chat_server_with_force_include_usage(request):
args = [ args = [
# use half precision for speed and memory savings in CI environment # use half precision for speed and memory savings in CI environment
"--dtype", "--dtype",
......
...@@ -61,13 +61,13 @@ class MockLoRAResolver(LoRAResolver): ...@@ -61,13 +61,13 @@ class MockLoRAResolver(LoRAResolver):
return LoRARequest( return LoRARequest(
lora_name="test-lora", lora_name="test-lora",
lora_int_id=1, lora_int_id=1,
lora_local_path="/fake/path/test-lora", lora_path="/fake/path/test-lora",
) )
elif lora_name == "invalid-lora": elif lora_name == "invalid-lora":
return LoRARequest( return LoRARequest(
lora_name="invalid-lora", lora_name="invalid-lora",
lora_int_id=2, lora_int_id=2,
lora_local_path="/fake/path/invalid-lora", lora_path="/fake/path/invalid-lora",
) )
return None return None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment