Commit 7e63ef82 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.14.0' into v0.14.0-dev

parents 8cbcac5d b17039bc
...@@ -11,7 +11,7 @@ MODEL_NAME = "Qwen/Qwen3-0.6B" ...@@ -11,7 +11,7 @@ MODEL_NAME = "Qwen/Qwen3-0.6B"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(): # noqa: F811 def server():
args = [ args = [
"--max-model-len", "--max-model-len",
"2048", "2048",
......
...@@ -90,7 +90,10 @@ def before_generate_case(context: schemathesis.hooks.HookContext, strategy): ...@@ -90,7 +90,10 @@ def before_generate_case(context: schemathesis.hooks.HookContext, strategy):
if ( if (
isinstance(content, list) isinstance(content, list)
and len(content) > 0 and len(content) > 0
and any(item.get("type") == "file" for item in content) and any(
isinstance(item, dict) and item.get("type") == "file"
for item in content
)
): ):
return False return False
...@@ -126,7 +129,7 @@ def before_generate_case(context: schemathesis.hooks.HookContext, strategy): ...@@ -126,7 +129,7 @@ def before_generate_case(context: schemathesis.hooks.HookContext, strategy):
@schema.parametrize() @schema.parametrize()
@schema.override(headers={"Content-Type": "application/json"}) @schema.override(headers={"Content-Type": "application/json"})
@settings(deadline=LONG_TIMEOUT_SECONDS * 1000) @settings(deadline=LONG_TIMEOUT_SECONDS * 1000, max_examples=50)
def test_openapi_stateless(case: schemathesis.Case): def test_openapi_stateless(case: schemathesis.Case):
key = ( key = (
case.operation.method.upper(), case.operation.method.upper(),
...@@ -139,6 +142,7 @@ def test_openapi_stateless(case: schemathesis.Case): ...@@ -139,6 +142,7 @@ def test_openapi_stateless(case: schemathesis.Case):
timeout = { timeout = {
# requires a longer timeout # requires a longer timeout
("POST", "/v1/chat/completions"): LONG_TIMEOUT_SECONDS, ("POST", "/v1/chat/completions"): LONG_TIMEOUT_SECONDS,
("POST", "/v1/completions"): LONG_TIMEOUT_SECONDS,
}.get(key, DEFAULT_TIMEOUT_SECONDS) }.get(key, DEFAULT_TIMEOUT_SECONDS)
# No need to verify SSL certificate for localhost # No need to verify SSL certificate for localhost
......
...@@ -39,6 +39,7 @@ def server(request: pytest.FixtureRequest): ...@@ -39,6 +39,7 @@ def server(request: pytest.FixtureRequest):
"2", "2",
*passed_params, *passed_params,
] ]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server yield remote_server
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import pytest_asyncio
from openai import OpenAI
from openai_harmony import ToolDescription, ToolNamespaceConfig
from vllm.entrypoints.tool_server import MCPToolServer
from ...utils import RemoteOpenAIServer
MODEL_NAME = "openai/gpt-oss-20b"
@pytest.fixture(scope="module")
def monkeypatch_module():
from _pytest.monkeypatch import MonkeyPatch
mpatch = MonkeyPatch()
yield mpatch
mpatch.undo()
@pytest.fixture(scope="module")
def mcp_disabled_server(monkeypatch_module: pytest.MonkeyPatch):
args = ["--enforce-eager", "--tool-server", "demo"]
with monkeypatch_module.context() as m:
m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1")
m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv")
# Helps the model follow instructions better
m.setenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "1")
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.fixture(scope="function")
def mcp_enabled_server(monkeypatch_module: pytest.MonkeyPatch):
args = ["--enforce-eager", "--tool-server", "demo"]
with monkeypatch_module.context() as m:
m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1")
m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv")
m.setenv("VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS", "code_interpreter,container")
# Helps the model follow instructions better
m.setenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "1")
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def mcp_disabled_client(mcp_disabled_server):
async with mcp_disabled_server.get_async_client() as async_client:
yield async_client
@pytest_asyncio.fixture
async def mcp_enabled_client(mcp_enabled_server):
async with mcp_enabled_server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_env_flag_enabled(mcp_enabled_client: OpenAI, model_name: str):
response = await mcp_enabled_client.responses.create(
model=model_name,
input=(
"Execute the following code: "
"import random; print(random.randint(1, 1000000))"
),
instructions=(
"You must use the Python tool to execute code. Never simulate execution."
),
tools=[
{
"type": "mcp",
"server_label": "code_interpreter",
# URL unused for DemoToolServer
"server_url": "http://localhost:8888",
}
],
extra_body={"enable_response_messages": True},
)
assert response is not None
assert response.status == "completed"
# Verify output messages: Tool calls and responses on analysis channel
tool_call_found = False
tool_response_found = False
for message in response.output_messages:
recipient = message.get("recipient")
if recipient and recipient.startswith("python"):
tool_call_found = True
assert message.get("channel") == "analysis", (
"Tool call should be on analysis channel"
)
author = message.get("author", {})
if (
author.get("role") == "tool"
and author.get("name")
and author.get("name").startswith("python")
):
tool_response_found = True
assert message.get("channel") == "analysis", (
"Tool response should be on analysis channel"
)
assert tool_call_found, "Should have found at least one Python tool call"
assert tool_response_found, "Should have found at least one Python tool response"
for message in response.input_messages:
assert message.get("author").get("role") != "developer", (
"No developer messages should be present with valid mcp tool"
)
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_with_allowed_tools_star(
mcp_enabled_client: OpenAI, model_name: str
):
"""Test MCP tool with allowed_tools=['*'] to select all available tools.
This E2E test verifies that the "*" wildcard works end-to-end.
See test_serving_responses.py for detailed unit tests of "*" normalization.
"""
response = await mcp_enabled_client.responses.create(
model=model_name,
input=(
"Execute the following code: "
"import random; print(random.randint(1, 1000000))"
),
instructions=(
"You must use the Python tool to execute code. Never simulate execution."
),
tools=[
{
"type": "mcp",
"server_label": "code_interpreter",
"server_url": "http://localhost:8888",
# Using "*" to allow all tools from this MCP server
"allowed_tools": ["*"],
}
],
extra_body={"enable_response_messages": True},
)
assert response is not None
assert response.status == "completed"
# Verify tool calls work with allowed_tools=["*"]
tool_call_found = False
for message in response.output_messages:
recipient = message.get("recipient")
if recipient and recipient.startswith("python"):
tool_call_found = True
break
assert tool_call_found, "Should have found at least one Python tool call with '*'"
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_env_flag_disabled(mcp_disabled_client: OpenAI, model_name: str):
response = await mcp_disabled_client.responses.create(
model=model_name,
input=(
"Execute the following code if the tool is present: "
"import random; print(random.randint(1, 1000000))"
),
tools=[
{
"type": "mcp",
"server_label": "code_interpreter",
# URL unused for DemoToolServer
"server_url": "http://localhost:8888",
}
],
extra_body={"enable_response_messages": True},
)
assert response is not None
assert response.status == "completed"
# Verify output messages: No tool calls and responses
tool_call_found = False
tool_response_found = False
for message in response.output_messages:
recipient = message.get("recipient")
if recipient and recipient.startswith("python"):
tool_call_found = True
assert message.get("channel") == "analysis", (
"Tool call should be on analysis channel"
)
author = message.get("author", {})
if (
author.get("role") == "tool"
and author.get("name")
and author.get("name").startswith("python")
):
tool_response_found = True
assert message.get("channel") == "analysis", (
"Tool response should be on analysis channel"
)
assert not tool_call_found, "Should not have a python call"
assert not tool_response_found, "Should not have a tool response"
for message in response.input_messages:
assert message.get("author").get("role") != "developer", (
"No developer messages should be present without a valid tool"
)
def test_get_tool_description():
"""Test MCPToolServer.get_tool_description filtering logic.
Note: The wildcard "*" is normalized to None by
_extract_allowed_tools_from_mcp_requests before reaching this layer,
so we only test None and specific tool filtering here.
See test_serving_responses.py for "*" normalization tests.
"""
pytest.importorskip("mcp")
server = MCPToolServer()
tool1 = ToolDescription.new(
name="tool1", description="First", parameters={"type": "object"}
)
tool2 = ToolDescription.new(
name="tool2", description="Second", parameters={"type": "object"}
)
tool3 = ToolDescription.new(
name="tool3", description="Third", parameters={"type": "object"}
)
server.harmony_tool_descriptions = {
"test_server": ToolNamespaceConfig(
name="test_server", description="test", tools=[tool1, tool2, tool3]
)
}
# Nonexistent server
assert server.get_tool_description("nonexistent") is None
# None (no filter) - returns all tools
result = server.get_tool_description("test_server", allowed_tools=None)
assert len(result.tools) == 3
# Filter to specific tools
result = server.get_tool_description(
"test_server", allowed_tools=["tool1", "tool3"]
)
assert len(result.tools) == 2
assert result.tools[0].name == "tool1"
assert result.tools[1].name == "tool3"
# Single tool
result = server.get_tool_description("test_server", allowed_tools=["tool2"])
assert len(result.tools) == 1
assert result.tools[0].name == "tool2"
# No matching tools - returns None
result = server.get_tool_description("test_server", allowed_tools=["nonexistent"])
assert result is None
# Empty list - returns None
assert server.get_tool_description("test_server", allowed_tools=[]) is None
...@@ -37,7 +37,7 @@ def default_server_args(qwen3_lora_files): ...@@ -37,7 +37,7 @@ def default_server_args(qwen3_lora_files):
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server_fixture(request, default_server_args): # noqa: F811 def server_fixture(request, default_server_args):
use_server_flag = request.param use_server_flag = request.param
if use_server_flag: if use_server_flag:
args_with_flag = default_server_args + ["--return-tokens-as-token-ids"] args_with_flag = default_server_args + ["--return-tokens-as-token-ids"]
......
...@@ -17,6 +17,7 @@ from vllm.entrypoints.openai.parser.harmony_utils import get_encoding ...@@ -17,6 +17,7 @@ from vllm.entrypoints.openai.parser.harmony_utils import get_encoding
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest, ChatCompletionRequest,
ChatCompletionResponse, ChatCompletionResponse,
ErrorResponse,
RequestResponseMetadata, RequestResponseMetadata,
) )
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
...@@ -54,8 +55,19 @@ def with_tool_parser(request) -> bool: ...@@ -54,8 +55,19 @@ def with_tool_parser(request) -> bool:
return request.param return request.param
@pytest.fixture(
scope="module",
params=[True],
ids=["exclude_tools_when_tool_choice_none"],
)
def exclude_tools_when_tool_choice_none(request) -> bool:
return request.param
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def default_server_args(with_tool_parser: bool): def default_server_args(
with_tool_parser: bool, exclude_tools_when_tool_choice_none: bool
):
args = [ args = [
# use half precision for speed and memory savings in CI environment # use half precision for speed and memory savings in CI environment
"--enforce-eager", "--enforce-eager",
...@@ -74,19 +86,16 @@ def default_server_args(with_tool_parser: bool): ...@@ -74,19 +86,16 @@ def default_server_args(with_tool_parser: bool):
"--enable-auto-tool-choice", "--enable-auto-tool-choice",
] ]
) )
if exclude_tools_when_tool_choice_none:
args.append("--exclude-tools-when-tool-choice-none")
return args return args
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def gptoss_server( def gptoss_server(default_server_args: list[str]):
monkeypatch_module: pytest.MonkeyPatch, default_server_args: list[str] server_args = default_server_args + ["--attention-backend=TRITON_ATTN"]
): with RemoteOpenAIServer(GPT_OSS_MODEL_NAME, server_args) as remote_server:
with monkeypatch_module.context() as m: yield remote_server
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
with RemoteOpenAIServer(
GPT_OSS_MODEL_NAME, default_server_args
) as remote_server:
yield remote_server
@pytest_asyncio.fixture @pytest_asyncio.fixture
...@@ -342,6 +351,69 @@ async def test_gpt_oss_tool_message_array_content( ...@@ -342,6 +351,69 @@ async def test_gpt_oss_tool_message_array_content(
assert response_multi_array.choices[0].message is not None assert response_multi_array.choices[0].message is not None
@pytest.mark.asyncio
async def test_gpt_oss_tool_choice_none(
gptoss_client: OpenAI,
with_tool_parser: bool,
exclude_tools_when_tool_choice_none: bool,
):
if not (with_tool_parser and exclude_tools_when_tool_choice_none):
pytest.skip(
"skip tool_choice tests when non-tool or "
"--exclude-tools-when-tool-choice-none not set"
)
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {"type": "string"},
"state": {"type": "string"},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["city", "state", "unit"],
},
},
}
]
messages = [
{
"role": "user",
"content": "What's the temperature(in degrees Celsius) in Dallas?",
},
]
tool_choice_auto = await gptoss_client.chat.completions.create(
model=GPT_OSS_MODEL_NAME,
messages=messages,
tools=tools,
tool_choice="auto",
temperature=0.0,
)
msg = tool_choice_auto.choices[0].message
assert len(msg.tool_calls) == 1
tool_choice_none = await gptoss_client.chat.completions.create(
model=GPT_OSS_MODEL_NAME,
messages=messages,
tools=tools,
tool_choice="none",
temperature=0.0,
)
msg = tool_choice_none.choices[0].message
assert len(msg.tool_calls) == 0
MODEL_NAME = os.path.join(models_path_prefix, "openai-community/gpt2") MODEL_NAME = os.path.join(models_path_prefix, "openai-community/gpt2")
MODEL_NAME_SHORT = os.path.join(models_path_prefix, "gpt2") MODEL_NAME_SHORT = os.path.join(models_path_prefix, "gpt2")
CHAT_TEMPLATE = "Dummy chat template for testing {}" CHAT_TEMPLATE = "Dummy chat template for testing {}"
...@@ -403,6 +475,7 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat: ...@@ -403,6 +475,7 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
lora_request, lora_request,
trace_headers, trace_headers,
priority, priority,
data_parallel_rank,
): ):
return dict(engine_prompt), {} return dict(engine_prompt), {}
...@@ -884,7 +957,6 @@ class TestServingChatWithHarmony: ...@@ -884,7 +957,6 @@ class TestServingChatWithHarmony:
input_messages, input_messages,
[ [
{"role": "system"}, {"role": "system"},
{"role": "developer"},
{"role": "user", "content": messages[0]["content"]}, {"role": "user", "content": messages[0]["content"]},
], ],
) )
...@@ -912,7 +984,6 @@ class TestServingChatWithHarmony: ...@@ -912,7 +984,6 @@ class TestServingChatWithHarmony:
input_messages_2, input_messages_2,
[ [
{"role": "system"}, {"role": "system"},
{"role": "developer"},
{"role": "user"}, {"role": "user"},
# The analysis message should be dropped on subsequent inputs because # The analysis message should be dropped on subsequent inputs because
# of the subsequent assistant message to the final channel. # of the subsequent assistant message to the final channel.
...@@ -972,7 +1043,7 @@ class TestServingChatWithHarmony: ...@@ -972,7 +1043,7 @@ class TestServingChatWithHarmony:
) )
# Test the Harmony messages for the second turn's input # Test the Harmony messages for the second turn's input
req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages) req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools)
input_messages_2, _ = serving_chat._make_request_with_harmony(req_2) input_messages_2, _ = serving_chat._make_request_with_harmony(req_2)
verify_harmony_messages( verify_harmony_messages(
input_messages_2, input_messages_2,
...@@ -1053,7 +1124,7 @@ class TestServingChatWithHarmony: ...@@ -1053,7 +1124,7 @@ class TestServingChatWithHarmony:
) )
# Test the Harmony messages for the second turn's input # Test the Harmony messages for the second turn's input
req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages) req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools)
input_messages_2, _ = serving_chat._make_request_with_harmony(req_2) input_messages_2, _ = serving_chat._make_request_with_harmony(req_2)
verify_harmony_messages( verify_harmony_messages(
input_messages_2, input_messages_2,
...@@ -1134,7 +1205,7 @@ class TestServingChatWithHarmony: ...@@ -1134,7 +1205,7 @@ class TestServingChatWithHarmony:
) )
# Test the Harmony messages for the second turn's input # Test the Harmony messages for the second turn's input
req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages) req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools)
input_messages_2, _ = serving_chat._make_request_with_harmony(req_2) input_messages_2, _ = serving_chat._make_request_with_harmony(req_2)
verify_harmony_messages( verify_harmony_messages(
input_messages_2, input_messages_2,
...@@ -1184,7 +1255,7 @@ class TestServingChatWithHarmony: ...@@ -1184,7 +1255,7 @@ class TestServingChatWithHarmony:
) )
# Test the Harmony messages for the third turn's input # Test the Harmony messages for the third turn's input
req_3 = ChatCompletionRequest(model=MODEL_NAME, messages=messages) req_3 = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools)
input_messages_3, _ = serving_chat._make_request_with_harmony(req_3) input_messages_3, _ = serving_chat._make_request_with_harmony(req_3)
verify_harmony_messages( verify_harmony_messages(
input_messages_3, input_messages_3,
...@@ -1247,7 +1318,7 @@ class TestServingChatWithHarmony: ...@@ -1247,7 +1318,7 @@ class TestServingChatWithHarmony:
) )
# Test the Harmony messages for the fourth turn's input # Test the Harmony messages for the fourth turn's input
req_4 = ChatCompletionRequest(model=MODEL_NAME, messages=messages) req_4 = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools)
input_messages_4, _ = serving_chat._make_request_with_harmony(req_4) input_messages_4, _ = serving_chat._make_request_with_harmony(req_4)
verify_harmony_messages( verify_harmony_messages(
input_messages_4, input_messages_4,
...@@ -1303,7 +1374,6 @@ class TestServingChatWithHarmony: ...@@ -1303,7 +1374,6 @@ class TestServingChatWithHarmony:
input_messages, input_messages,
[ [
{"role": "system"}, {"role": "system"},
{"role": "developer"},
{"role": "user", "content": messages[0]["content"]}, {"role": "user", "content": messages[0]["content"]},
# The reasoning that would have resulted in an analysis message is # The reasoning that would have resulted in an analysis message is
# dropped because of a later assistant message to the final channel. # dropped because of a later assistant message to the final channel.
...@@ -1335,7 +1405,6 @@ class TestServingChatWithHarmony: ...@@ -1335,7 +1405,6 @@ class TestServingChatWithHarmony:
input_messages, input_messages,
[ [
{"role": "system"}, {"role": "system"},
{"role": "developer"},
{"role": "user", "content": messages[0]["content"]}, {"role": "user", "content": messages[0]["content"]},
{ {
"role": "assistant", "role": "assistant",
...@@ -1365,7 +1434,6 @@ class TestServingChatWithHarmony: ...@@ -1365,7 +1434,6 @@ class TestServingChatWithHarmony:
input_messages, input_messages,
[ [
{"role": "system"}, {"role": "system"},
{"role": "developer"},
{"role": "user", "content": messages[0]["content"]}, {"role": "user", "content": messages[0]["content"]},
{ {
"role": "assistant", "role": "assistant",
...@@ -1374,3 +1442,208 @@ class TestServingChatWithHarmony: ...@@ -1374,3 +1442,208 @@ class TestServingChatWithHarmony:
}, },
], ],
) )
@pytest.mark.asyncio
async def test_tool_choice_validation_without_parser():
"""Test that tool_choice='required' or named tool without tool_parser
returns an appropriate error message."""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
models = OpenAIServingModels(
engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
)
# Create serving_chat without tool_parser (enable_auto_tools=False)
serving_chat = OpenAIServingChat(
mock_engine,
models,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
request_logger=None,
enable_auto_tools=False, # No tool parser
)
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the weather in a given location",
"parameters": {
"type": "object",
"properties": {"location": {"type": "string"}},
"required": ["location"],
},
},
}
]
# Test tool_choice="required" without tool_parser
req_required = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{"role": "user", "content": "What's the weather?"}],
tools=tools,
tool_choice="required",
)
response_required = await serving_chat.create_chat_completion(req_required)
assert isinstance(response_required, ErrorResponse)
assert "tool_choice" in response_required.error.message
assert "--tool-call-parser" in response_required.error.message
# Test named tool_choice without tool_parser
req_named = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{"role": "user", "content": "What's the weather?"}],
tools=tools,
tool_choice={"type": "function", "function": {"name": "get_weather"}},
)
response_named = await serving_chat.create_chat_completion(req_named)
assert isinstance(response_named, ErrorResponse)
assert "tool_choice" in response_named.error.message
assert "--tool-call-parser" in response_named.error.message
class TestCreateRemainingArgsDelta:
"""Tests for _create_remaining_args_delta helper function.
This helper is used when streaming tool calls to preserve id/type/name
fields in the finish chunk, which would otherwise be lost.
"""
def test_preserves_id_type_name(self):
"""Test that id, type, and name are preserved from original delta."""
from vllm.entrypoints.openai.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
original_delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=0,
id="call_abc123",
type="function",
function=DeltaFunctionCall(
name="get_weather",
arguments='{"location": "Paris"}',
),
)
]
)
result = OpenAIServingChat._create_remaining_args_delta(
original_delta, '", "unit": "celsius"}', 0
)
assert len(result.tool_calls) == 1
tc = result.tool_calls[0]
assert tc.index == 0
assert tc.id == "call_abc123"
assert tc.type == "function"
assert tc.function.name == "get_weather"
assert tc.function.arguments == '", "unit": "celsius"}'
def test_matches_by_index(self):
"""Test that the correct tool call is matched by index."""
from vllm.entrypoints.openai.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
original_delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=0,
id="call_first",
type="function",
function=DeltaFunctionCall(name="func_a", arguments="{}"),
),
DeltaToolCall(
index=1,
id="call_second",
type="function",
function=DeltaFunctionCall(name="func_b", arguments="{}"),
),
]
)
result = OpenAIServingChat._create_remaining_args_delta(
original_delta, '{"extra": true}', 1
)
assert len(result.tool_calls) == 1
tc = result.tool_calls[0]
assert tc.index == 1
assert tc.id == "call_second"
assert tc.function.name == "func_b"
def test_no_matching_tool_call(self):
"""Test graceful handling when no matching tool call is found."""
from vllm.entrypoints.openai.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
original_delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=0,
id="call_zero",
type="function",
function=DeltaFunctionCall(name="func", arguments="{}"),
)
]
)
result = OpenAIServingChat._create_remaining_args_delta(
original_delta, '{"arg": 1}', 5
)
assert len(result.tool_calls) == 1
tc = result.tool_calls[0]
assert tc.index == 5
assert tc.id is None
assert tc.type is None
assert tc.function.name is None
assert tc.function.arguments == '{"arg": 1}'
def test_function_is_none(self):
"""Test handling when original tool call has no function."""
from vllm.entrypoints.openai.protocol import DeltaMessage, DeltaToolCall
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
original_delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=0,
id="call_nofunc",
type="function",
function=None,
)
]
)
result = OpenAIServingChat._create_remaining_args_delta(
original_delta, '{"data": "value"}', 0
)
assert len(result.tool_calls) == 1
tc = result.tool_calls[0]
assert tc.index == 0
assert tc.id == "call_nofunc"
assert tc.type == "function"
assert tc.function.name is None
assert tc.function.arguments == '{"data": "value"}'
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for harmony streaming delta extraction.
"""
from dataclasses import dataclass, field
from unittest.mock import patch
import pytest
from vllm.entrypoints.openai.serving_chat_stream_harmony import (
extract_harmony_streaming_delta,
)
@dataclass
class MockMessage:
"""Mock message object for testing."""
channel: str | None = None
recipient: str | None = None
@dataclass
class MockStreamableParser:
"""Mock StreamableParser for testing without openai_harmony dependency."""
messages: list[MockMessage] = field(default_factory=list)
class TestExtractHarmonyStreamingDelta:
"""Tests for extract_harmony_streaming_delta function."""
@pytest.mark.parametrize(
"delta_text,expected_content",
[
("Hello, world!", "Hello, world!"),
("", ""),
],
)
def test_final_channel_returns_content_delta(self, delta_text, expected_content):
"""Test that final channel returns a DeltaMessage with content."""
parser = MockStreamableParser()
delta_message, tools_streamed = extract_harmony_streaming_delta(
harmony_parser=parser,
cur_channel="final",
cur_recipient=None,
prev_recipient=None,
delta_text=delta_text,
include_reasoning=False,
)
assert delta_message is not None
assert delta_message.content == expected_content
assert tools_streamed is False
@pytest.mark.parametrize(
"include_reasoning,expected_has_message",
[
(True, True),
(False, False),
],
)
def test_analysis_channel_reasoning(self, include_reasoning, expected_has_message):
"""Test analysis channel respects include_reasoning flag."""
parser = MockStreamableParser()
delta_message, tools_streamed = extract_harmony_streaming_delta(
harmony_parser=parser,
cur_channel="analysis",
cur_recipient=None,
prev_recipient=None,
delta_text="Let me think...",
include_reasoning=include_reasoning,
)
if expected_has_message:
assert delta_message is not None
assert delta_message.reasoning == "Let me think..."
else:
assert delta_message is None
assert tools_streamed is False
@pytest.mark.parametrize("channel", ["commentary", "analysis"])
@patch("vllm.entrypoints.openai.serving_chat_stream_harmony.make_tool_call_id")
def test_new_tool_call(self, mock_make_tool_call_id, channel):
"""Test new tool call creation when recipient changes."""
mock_make_tool_call_id.return_value = "call_test123"
parser = MockStreamableParser()
delta_message, tools_streamed = extract_harmony_streaming_delta(
harmony_parser=parser,
cur_channel=channel,
cur_recipient="functions.get_weather",
prev_recipient=None,
delta_text="",
include_reasoning=False,
)
assert delta_message is not None
assert len(delta_message.tool_calls) == 1
tool_call = delta_message.tool_calls[0]
assert tool_call.id == "call_test123"
assert tool_call.type == "function"
assert tool_call.function.name == "get_weather"
assert tool_call.function.arguments == ""
assert tool_call.index == 0
assert tools_streamed is True
@pytest.mark.parametrize("channel", ["commentary", "analysis"])
def test_tool_call_argument_streaming(self, channel):
"""Test streaming tool call arguments (same recipient)."""
parser = MockStreamableParser()
delta_message, tools_streamed = extract_harmony_streaming_delta(
harmony_parser=parser,
cur_channel=channel,
cur_recipient="functions.get_weather",
prev_recipient="functions.get_weather",
delta_text='{"location": "Paris"}',
include_reasoning=False,
)
assert delta_message is not None
tool_call = delta_message.tool_calls[0]
assert tool_call.id is None
assert tool_call.function.arguments == '{"location": "Paris"}'
assert tool_call.index == 0
assert tools_streamed is True
@pytest.mark.parametrize("channel", ["commentary", "analysis"])
def test_tool_call_empty_arguments_returns_none(self, channel):
"""Test empty delta_text with same recipient returns None."""
parser = MockStreamableParser()
delta_message, tools_streamed = extract_harmony_streaming_delta(
harmony_parser=parser,
cur_channel=channel,
cur_recipient="functions.get_weather",
prev_recipient="functions.get_weather",
delta_text="",
include_reasoning=False,
)
assert delta_message is None
assert tools_streamed is False
def test_tool_call_index_from_previous_messages(self):
"""Test tool call index accounts for previous function messages."""
messages = [
MockMessage(channel="analysis", recipient=None), # Not counted
MockMessage(channel="commentary", recipient="functions.tool1"), # Counted
MockMessage(channel="final", recipient=None), # Not counted
]
parser = MockStreamableParser(messages=messages)
delta_message, _ = extract_harmony_streaming_delta(
harmony_parser=parser,
cur_channel="commentary",
cur_recipient="functions.tool2",
prev_recipient="functions.tool2",
delta_text="args",
include_reasoning=False,
)
assert delta_message.tool_calls[0].index == 1
@pytest.mark.parametrize(
"channel,recipient",
[
("commentary", None),
("commentary", "browser.search"),
],
)
def test_returns_tool_call_preambles(self, channel, recipient):
"""Test that invalid channel/recipient combinations return None."""
parser = MockStreamableParser()
delta_text = "some text"
delta_message, tools_streamed = extract_harmony_streaming_delta(
harmony_parser=parser,
cur_channel=channel,
cur_recipient=recipient,
prev_recipient=None,
delta_text=delta_text,
include_reasoning=True,
)
assert delta_message.content == delta_text
assert tools_streamed is False
@pytest.mark.parametrize(
"channel,recipient",
[
(None, None),
("unknown_channel", None),
],
)
def test_returns_none_for_invalid_inputs(self, channel, recipient):
"""Test that invalid channel/recipient combinations return None."""
parser = MockStreamableParser()
delta_message, tools_streamed = extract_harmony_streaming_delta(
harmony_parser=parser,
cur_channel=channel,
cur_recipient=recipient,
prev_recipient=None,
delta_text="some text",
include_reasoning=True,
)
assert delta_message is None
assert tools_streamed is False
...@@ -93,6 +93,7 @@ async def test_same_response_as_chat_completions(client, tokenizer, messages): ...@@ -93,6 +93,7 @@ async def test_same_response_as_chat_completions(client, tokenizer, messages):
add_generation_prompt=True, add_generation_prompt=True,
enable_thinking=False, # default with Qwen3 enable_thinking=False, # default with Qwen3
) )
for ignore_eos in [True, False]: for ignore_eos in [True, False]:
payload = { payload = {
"model": MODEL_NAME, "model": MODEL_NAME,
...@@ -108,9 +109,8 @@ async def test_same_response_as_chat_completions(client, tokenizer, messages): ...@@ -108,9 +109,8 @@ async def test_same_response_as_chat_completions(client, tokenizer, messages):
} }
generate_resp = await client.post(GEN_ENDPOINT, json=payload) generate_resp = await client.post(GEN_ENDPOINT, json=payload)
generate_data = generate_resp.json() generate_data = generate_resp.json()
generate_res = tokenizer.decode( gen_token_ids = generate_data["choices"][0]["token_ids"]
generate_data["choices"][0]["token_ids"], skip_special_tokens=True generate_res = tokenizer.decode(gen_token_ids, skip_special_tokens=True)
)
payload = { payload = {
"model": MODEL_NAME, "model": MODEL_NAME,
...@@ -119,12 +119,33 @@ async def test_same_response_as_chat_completions(client, tokenizer, messages): ...@@ -119,12 +119,33 @@ async def test_same_response_as_chat_completions(client, tokenizer, messages):
"temperature": 0.0, "temperature": 0.0,
"stream": False, "stream": False,
"ignore_eos": ignore_eos, "ignore_eos": ignore_eos,
"chat_template_kwargs": dict(enable_thinking=False), "chat_template_kwargs": {"enable_thinking": False},
} }
completions_resp = await client.post("/v1/chat/completions", json=payload) completions_resp = await client.post("/v1/chat/completions", json=payload)
completions_data = completions_resp.json() completions_data = completions_resp.json()
completions_res = completions_data["choices"][0]["message"]["content"] completions_res = completions_data["choices"][0]["message"]["content"]
if ignore_eos:
# When ignoring EOS, only compare up to the first EOS token
# Post-EOS generation is undefined and may differ
eos_tokens = {
tokenizer.eos_token_id,
*tokenizer.additional_special_tokens_ids,
}
# Find first EOS in generated tokens
eos_pos = None
for i, tid in enumerate(gen_token_ids):
if tid in eos_tokens:
eos_pos = i
break
if eos_pos is not None:
gen_token_ids_truncated = gen_token_ids[:eos_pos]
generate_res = tokenizer.decode(
gen_token_ids_truncated, skip_special_tokens=True
)
# Truncate completions_res to same length for comparison
completions_res = completions_res[: len(generate_res)]
assert generate_res == completions_res assert generate_res == completions_res
......
...@@ -10,11 +10,17 @@ import time ...@@ -10,11 +10,17 @@ import time
import openai import openai
import pytest import pytest
from vllm.platforms import current_platform
from vllm.utils.network_utils import get_open_port from vllm.utils.network_utils import get_open_port
from ...utils import models_path_prefix from ...utils import models_path_prefix
MODEL_NAME = os.path.join(models_path_prefix, "hmellor/tiny-random-LlamaForCausalLM") MODEL_NAME = os.path.join(models_path_prefix, "hmellor/tiny-random-LlamaForCausalLM")
# GPU initialization might take take longer
_IS_ROCM = current_platform.is_rocm()
_SERVER_STARTUP_TIMEOUT = 120
_PROCESS_EXIT_TIMEOUT = 15
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_shutdown_on_engine_failure(): async def test_shutdown_on_engine_failure():
...@@ -47,9 +53,11 @@ async def test_shutdown_on_engine_failure(): ...@@ -47,9 +53,11 @@ async def test_shutdown_on_engine_failure():
"2", "2",
"--disable-frontend-multiprocessing", "--disable-frontend-multiprocessing",
], ],
stdout=subprocess.PIPE, # ROCm: Disable stdout/stderr pipe capture. Subprocess hangs when
stderr=subprocess.PIPE, # stdout/stderr pipes are enabled during ROCm GPU initialization.
text=True, stdout=None if _IS_ROCM else subprocess.PIPE,
stderr=None if _IS_ROCM else subprocess.PIPE,
text=None if _IS_ROCM else True,
preexec_fn=lambda: signal.signal(signal.SIGINT, signal.SIG_IGN), preexec_fn=lambda: signal.signal(signal.SIGINT, signal.SIG_IGN),
) )
...@@ -63,7 +71,7 @@ async def test_shutdown_on_engine_failure(): ...@@ -63,7 +71,7 @@ async def test_shutdown_on_engine_failure():
) )
# Poll until server is ready # Poll until server is ready
while time.time() - start_time < 30: while time.time() - start_time < _SERVER_STARTUP_TIMEOUT:
try: try:
await client.completions.create( await client.completions.create(
model=MODEL_NAME, prompt="Hello", max_tokens=1 model=MODEL_NAME, prompt="Hello", max_tokens=1
...@@ -72,14 +80,18 @@ async def test_shutdown_on_engine_failure(): ...@@ -72,14 +80,18 @@ async def test_shutdown_on_engine_failure():
except Exception: except Exception:
time.sleep(0.5) time.sleep(0.5)
if proc.poll() is not None: if proc.poll() is not None:
stdout, stderr = proc.communicate(timeout=1) if _IS_ROCM:
pytest.fail( pytest.fail(f"Server died during startup: {proc.returncode}")
f"Server died during startup. stdout: {stdout}, stderr: {stderr}" else:
) stdout, stderr = proc.communicate(timeout=1)
pytest.fail(
f"Server died during startup. "
f"stdout: {stdout}, stderr: {stderr}"
)
else: else:
proc.terminate() proc.terminate()
proc.wait(timeout=5) proc.wait(timeout=_PROCESS_EXIT_TIMEOUT)
pytest.fail("Server failed to start in 30 seconds") pytest.fail(f"Server failed to start in {_SERVER_STARTUP_TIMEOUT} seconds")
# Kill server to simulate crash # Kill server to simulate crash
proc.terminate() proc.terminate()
...@@ -91,5 +103,5 @@ async def test_shutdown_on_engine_failure(): ...@@ -91,5 +103,5 @@ async def test_shutdown_on_engine_failure():
model=MODEL_NAME, prompt="This should fail", max_tokens=1 model=MODEL_NAME, prompt="This should fail", max_tokens=1
) )
return_code = proc.wait(timeout=5) return_code = proc.wait(timeout=_PROCESS_EXIT_TIMEOUT)
assert return_code is not None assert return_code is not None
...@@ -7,6 +7,7 @@ import json ...@@ -7,6 +7,7 @@ import json
import pytest import pytest
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
from .conftest import add_attention_backend
MISTRAL_FORMAT_ARGS = [ MISTRAL_FORMAT_ARGS = [
"--tokenizer_mode", "--tokenizer_mode",
...@@ -20,12 +21,14 @@ MISTRAL_FORMAT_ARGS = [ ...@@ -20,12 +21,14 @@ MISTRAL_FORMAT_ARGS = [
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", ["mistralai/Voxtral-Mini-3B-2507"]) @pytest.mark.parametrize("model_name", ["mistralai/Voxtral-Mini-3B-2507"])
async def test_basic_audio(mary_had_lamb, model_name): async def test_basic_audio(mary_had_lamb, model_name, rocm_aiter_fa_attention):
server_args = ["--enforce-eager"] server_args = ["--enforce-eager"]
if model_name.startswith("mistralai"): if model_name.startswith("mistralai"):
server_args += MISTRAL_FORMAT_ARGS server_args += MISTRAL_FORMAT_ARGS
add_attention_backend(server_args, rocm_aiter_fa_attention)
# Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb. # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
with RemoteOpenAIServer(model_name, server_args) as remote_server: with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client() client = remote_server.get_async_client()
...@@ -44,8 +47,13 @@ async def test_basic_audio(mary_had_lamb, model_name): ...@@ -44,8 +47,13 @@ async def test_basic_audio(mary_had_lamb, model_name):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_basic_audio_with_lora(mary_had_lamb): async def test_basic_audio_with_lora(mary_had_lamb, rocm_aiter_fa_attention):
"""Ensure STT (transcribe) requests can pass LoRA through to generate.""" """Ensure STT (transcribe) requests can pass LoRA through to generate."""
# ROCm SPECIFIC CONFIGURATION:
# To ensure the test passes on ROCm, we modify the max model length to 512.
# We DO NOT apply this to other platforms to maintain strict upstream parity.
from vllm.platforms import current_platform
model_name = "ibm-granite/granite-speech-3.3-2b" model_name = "ibm-granite/granite-speech-3.3-2b"
lora_model_name = "speech" lora_model_name = "speech"
server_args = [ server_args = [
...@@ -56,11 +64,13 @@ async def test_basic_audio_with_lora(mary_had_lamb): ...@@ -56,11 +64,13 @@ async def test_basic_audio_with_lora(mary_had_lamb):
"--lora-modules", "--lora-modules",
f"{lora_model_name}={model_name}", f"{lora_model_name}={model_name}",
"--max-model-len", "--max-model-len",
"2048", "512" if current_platform.is_rocm() else "2048",
"--max-num-seqs", "--max-num-seqs",
"1", "1",
] ]
add_attention_backend(server_args, rocm_aiter_fa_attention)
# Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb. # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
with RemoteOpenAIServer(model_name, server_args) as remote_server: with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client() client = remote_server.get_async_client()
...@@ -79,12 +89,14 @@ async def test_basic_audio_with_lora(mary_had_lamb): ...@@ -79,12 +89,14 @@ async def test_basic_audio_with_lora(mary_had_lamb):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_basic_audio_gemma(foscolo): async def test_basic_audio_gemma(foscolo, rocm_aiter_fa_attention):
# Gemma accuracy on some of the audio samples we use is particularly bad, # Gemma accuracy on some of the audio samples we use is particularly bad,
# hence we use a different one here. WER is evaluated separately. # hence we use a different one here. WER is evaluated separately.
model_name = "google/gemma-3n-E2B-it" model_name = "google/gemma-3n-E2B-it"
server_args = ["--enforce-eager"] server_args = ["--enforce-eager"]
add_attention_backend(server_args, rocm_aiter_fa_attention)
with RemoteOpenAIServer( with RemoteOpenAIServer(
model_name, server_args, max_wait_seconds=480 model_name, server_args, max_wait_seconds=480
) as remote_server: ) as remote_server:
......
...@@ -244,3 +244,35 @@ async def test_audio_with_timestamp(mary_had_lamb, whisper_client): ...@@ -244,3 +244,35 @@ async def test_audio_with_timestamp(mary_had_lamb, whisper_client):
) )
assert transcription.segments is not None assert transcription.segments is not None
assert len(transcription.segments) > 0 assert len(transcription.segments) > 0
@pytest.mark.asyncio
async def test_audio_with_max_tokens(whisper_client, mary_had_lamb):
transcription = await whisper_client.audio.transcriptions.create(
model=MODEL_NAME,
file=mary_had_lamb,
language="en",
response_format="text",
temperature=0.0,
extra_body={"max_completion_tokens": 1},
)
out = json.loads(transcription)
out_text = out["text"]
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained(MODEL_NAME)
out_tokens = tok(out_text, add_special_tokens=False)["input_ids"]
assert len(out_tokens) == 1
# max_completion_tokens > max_model_len
transcription = await whisper_client.audio.transcriptions.create(
model=MODEL_NAME,
file=mary_had_lamb,
language="en",
response_format="text",
temperature=0.0,
extra_body={"max_completion_tokens": int(1e6)},
)
out = json.loads(transcription)
out_text = out["text"]
out_tokens = tok(out_text, add_special_tokens=False)["input_ids"]
assert len(out_tokens) < 450 # ~Whisper max output len
...@@ -14,16 +14,26 @@ import pytest_asyncio ...@@ -14,16 +14,26 @@ import pytest_asyncio
import soundfile as sf import soundfile as sf
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
from .conftest import add_attention_backend
SERVER_ARGS = ["--enforce-eager"] SERVER_ARGS = ["--enforce-eager"]
def _get_server_args(attention_config):
"""Get server args with attention backend if specified."""
args = SERVER_ARGS.copy()
add_attention_backend(args, attention_config)
return args
@pytest.fixture( @pytest.fixture(
scope="module", params=["openai/whisper-small", "google/gemma-3n-E2B-it"] scope="module", params=["openai/whisper-small", "google/gemma-3n-E2B-it"]
) )
def server(request): def server(request, rocm_aiter_fa_attention):
# Parametrize over model name # Parametrize over model name
with RemoteOpenAIServer(request.param, SERVER_ARGS) as remote_server: with RemoteOpenAIServer(
request.param, _get_server_args(rocm_aiter_fa_attention)
) as remote_server:
yield remote_server, request.param yield remote_server, request.param
...@@ -35,10 +45,12 @@ async def client_and_model(server): ...@@ -35,10 +45,12 @@ async def client_and_model(server):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_non_asr_model(foscolo): async def test_non_asr_model(foscolo, rocm_aiter_fa_attention):
# text to text model # text to text model
model_name = "JackFram/llama-68m" model_name = "JackFram/llama-68m"
with RemoteOpenAIServer(model_name, SERVER_ARGS) as remote_server: with RemoteOpenAIServer(
model_name, _get_server_args(rocm_aiter_fa_attention)
) as remote_server:
client = remote_server.get_async_client() client = remote_server.get_async_client()
res = await client.audio.translations.create( res = await client.audio.translations.create(
model=model_name, file=foscolo, temperature=0.0 model=model_name, file=foscolo, temperature=0.0
...@@ -49,8 +61,13 @@ async def test_non_asr_model(foscolo): ...@@ -49,8 +61,13 @@ async def test_non_asr_model(foscolo):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_basic_audio_with_lora(mary_had_lamb): async def test_basic_audio_with_lora(mary_had_lamb, rocm_aiter_fa_attention):
"""Ensure STT (translate) requests can pass LoRA through to generate.""" """Ensure STT (translate) requests can pass LoRA through to generate."""
# ROCm SPECIFIC CONFIGURATION:
# To ensure the test passes on ROCm, we modify the max model length to 512.
# We DO NOT apply this to other platforms to maintain strict upstream parity.
from vllm.platforms import current_platform
# NOTE - careful to call this test before the module scoped server # NOTE - careful to call this test before the module scoped server
# fixture, otherwise it'll OOMkill the CI # fixture, otherwise it'll OOMkill the CI
model_name = "ibm-granite/granite-speech-3.3-2b" model_name = "ibm-granite/granite-speech-3.3-2b"
...@@ -63,11 +80,13 @@ async def test_basic_audio_with_lora(mary_had_lamb): ...@@ -63,11 +80,13 @@ async def test_basic_audio_with_lora(mary_had_lamb):
"--lora-modules", "--lora-modules",
f"{lora_model_name}={model_name}", f"{lora_model_name}={model_name}",
"--max-model-len", "--max-model-len",
"2048", "512" if current_platform.is_rocm() else "2048",
"--max-num-seqs", "--max-num-seqs",
"1", "1",
] ]
add_attention_backend(server_args, rocm_aiter_fa_attention)
# Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb. # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
with RemoteOpenAIServer(model_name, server_args) as remote_server: with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client() client = remote_server.get_async_client()
...@@ -227,3 +246,36 @@ async def test_long_audio_request(foscolo, client_and_model): ...@@ -227,3 +246,36 @@ async def test_long_audio_request(foscolo, client_and_model):
) )
out = json.loads(translation)["text"].strip().lower() out = json.loads(translation)["text"].strip().lower()
assert out.count("greek sea") == 2 assert out.count("greek sea") == 2
@pytest.mark.asyncio
async def test_audio_with_max_tokens(mary_had_lamb, client_and_model):
client, model_name = client_and_model
transcription = await client.audio.translations.create(
model=model_name,
file=mary_had_lamb,
response_format="text",
temperature=0.0,
extra_body={"max_completion_tokens": 1},
)
out = json.loads(transcription)
out_text = out["text"]
print(out_text)
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained(model_name)
out_tokens = tok(out_text, add_special_tokens=False)["input_ids"]
assert len(out_tokens) == 1
# max_completion_tokens > max_model_len
transcription = await client.audio.transcriptions.create(
model=model_name,
file=mary_had_lamb,
response_format="text",
temperature=0.0,
extra_body={"max_completion_tokens": int(1e6)},
)
out = json.loads(transcription)
out_text = out["text"]
print(out_text)
out_tokens = tok(out_text, add_special_tokens=False)["input_ids"]
assert len(out_tokens) < 450 # ~Whisper max output len
...@@ -8,7 +8,8 @@ import openai ...@@ -8,7 +8,8 @@ import openai
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from vllm.multimodal.utils import encode_video_base64, fetch_video from vllm.multimodal.utils import encode_video_url, fetch_video
from vllm.platforms import current_platform
from ...utils import RemoteOpenAIServer, models_path_prefix, urls_port from ...utils import RemoteOpenAIServer, models_path_prefix, urls_port
...@@ -45,7 +46,16 @@ def server(): ...@@ -45,7 +46,16 @@ def server():
json.dumps({"video": MAXIMUM_VIDEOS}), json.dumps({"video": MAXIMUM_VIDEOS}),
] ]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: # ROCm: Increase timeouts to handle potential network delays and slower
# video processing when downloading multiple videos from external sources
env_overrides = {}
if current_platform.is_rocm():
env_overrides = {
"VLLM_VIDEO_FETCH_TIMEOUT": "120",
"VLLM_ENGINE_ITERATION_TIMEOUT_S": "300",
}
with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_overrides) as remote_server:
yield remote_server yield remote_server
...@@ -56,9 +66,9 @@ async def client(server): ...@@ -56,9 +66,9 @@ async def client(server):
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def base64_encoded_video() -> dict[str, str]: def url_encoded_video() -> dict[str, str]:
return { return {
video_url: encode_video_base64(fetch_video(video_url)[0]) video_url: encode_video_url(fetch_video(video_url)[0])
for video_url in TEST_VIDEO_URLS for video_url in TEST_VIDEO_URLS
} }
...@@ -183,11 +193,9 @@ async def test_single_chat_session_video_base64encoded( ...@@ -183,11 +193,9 @@ async def test_single_chat_session_video_base64encoded(
client: openai.AsyncOpenAI, client: openai.AsyncOpenAI,
model_name: str, model_name: str,
video_url: str, video_url: str,
base64_encoded_video: dict[str, str], url_encoded_video: dict[str, str],
): ):
messages = dummy_messages_from_video_url( messages = dummy_messages_from_video_url(url_encoded_video[video_url])
f"data:video/jpeg;base64,{base64_encoded_video[video_url]}"
)
# test single completion # test single completion
chat_completion = await client.chat.completions.create( chat_completion = await client.chat.completions.create(
...@@ -231,11 +239,9 @@ async def test_single_chat_session_video_base64encoded_beamsearch( ...@@ -231,11 +239,9 @@ async def test_single_chat_session_video_base64encoded_beamsearch(
client: openai.AsyncOpenAI, client: openai.AsyncOpenAI,
model_name: str, model_name: str,
video_url: str, video_url: str,
base64_encoded_video: dict[str, str], url_encoded_video: dict[str, str],
): ):
messages = dummy_messages_from_video_url( messages = dummy_messages_from_video_url(url_encoded_video[video_url])
f"data:video/jpeg;base64,{base64_encoded_video[video_url]}"
)
chat_completion = await client.chat.completions.create( chat_completion = await client.chat.completions.create(
model=model_name, model=model_name,
...@@ -299,6 +305,11 @@ async def test_chat_streaming_video( ...@@ -299,6 +305,11 @@ async def test_chat_streaming_video(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"video_urls", [TEST_VIDEO_URLS[:i] for i in range(2, len(TEST_VIDEO_URLS))] "video_urls", [TEST_VIDEO_URLS[:i] for i in range(2, len(TEST_VIDEO_URLS))]
) )
@pytest.mark.flaky(
reruns=2,
reruns_delay=5,
condition=current_platform.is_rocm(),
)
async def test_multi_video_input( async def test_multi_video_input(
client: openai.AsyncOpenAI, model_name: str, video_urls: list[str] client: openai.AsyncOpenAI, model_name: str, video_urls: list[str]
): ):
......
...@@ -10,7 +10,8 @@ import pytest_asyncio ...@@ -10,7 +10,8 @@ import pytest_asyncio
from transformers import AutoProcessor from transformers import AutoProcessor
from vllm.multimodal.base import MediaWithBytes from vllm.multimodal.base import MediaWithBytes
from vllm.multimodal.utils import encode_image_base64, fetch_image from vllm.multimodal.utils import encode_image_url, fetch_image
from vllm.platforms import current_platform
from ...utils import RemoteOpenAIServer, models_path_prefix, urls_port from ...utils import RemoteOpenAIServer, models_path_prefix, urls_port
...@@ -31,26 +32,35 @@ TEST_IMAGE_ASSETS = [ ...@@ -31,26 +32,35 @@ TEST_IMAGE_ASSETS = [
f"http://localhost:{urls_port}/RGBA_comp.png", f"http://localhost:{urls_port}/RGBA_comp.png",
] ]
EXPECTED_MM_BEAM_SEARCH_RES = [ # Required terms for beam search validation
[ # Each entry is a list of term groups - ALL groups must match
"The image shows a wooden boardwalk leading through a", # Each group is a list of alternatives - at least ONE term in the group must appear
"The image shows a wooden boardwalk extending into a", # This provides semantic validation while allowing wording variation
], REQUIRED_BEAM_SEARCH_TERMS = [
[ # Boardwalk image: must have "boardwalk" AND ("wooden" or "wood")
"The image shows two parrots perched on", [["boardwalk"], ["wooden", "wood"]],
"The image shows two birds perched on a cur", # Parrots image: must have ("parrot" or "bird") AND "two"
], [["parrot", "bird"], ["two"]],
[ # Venn diagram: must have "venn" AND "diagram"
"The image shows a Venn diagram with three over", [["venn"], ["diagram"]],
"The image shows a colorful Venn diagram with", # Gradient image: must have "gradient" AND ("color" or "spectrum")
], [["gradient"], ["color", "spectrum"]],
[
"This image displays a gradient of colors ranging from",
"This image displays a gradient of colors forming a spectrum",
],
] ]
def check_output_matches_terms(content: str, term_groups: list[list[str]]) -> bool:
"""
Check if content matches all required term groups.
Each term group requires at least one of its terms to be present.
All term groups must be satisfied.
"""
content_lower = content.lower()
for group in term_groups:
if not any(term.lower() in content_lower for term in group):
return False
return True
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server(): def server():
args = [ args = [
...@@ -66,7 +76,16 @@ def server(): ...@@ -66,7 +76,16 @@ def server():
json.dumps({"image": MAXIMUM_IMAGES}), json.dumps({"image": MAXIMUM_IMAGES}),
] ]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: # ROCm: Increase timeouts to handle potential network delays and slower
# video processing when downloading multiple videos from external sources
env_overrides = {}
if current_platform.is_rocm():
env_overrides = {
"VLLM_VIDEO_FETCH_TIMEOUT": "120",
"VLLM_ENGINE_ITERATION_TIMEOUT_S": "300",
}
with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_overrides) as remote_server:
yield remote_server yield remote_server
...@@ -77,11 +96,9 @@ async def client(server): ...@@ -77,11 +96,9 @@ async def client(server):
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def base64_encoded_image(local_asset_server) -> dict[str, str]: def url_encoded_image(local_asset_server) -> dict[str, str]:
return { return {
image_asset: encode_image_base64( image_asset: encode_image_url(local_asset_server.get_image_asset(image_asset))
local_asset_server.get_image_asset(image_asset)
)
for image_asset in TEST_IMAGE_ASSETS for image_asset in TEST_IMAGE_ASSETS
} }
...@@ -241,11 +258,11 @@ async def test_single_chat_session_image_base64encoded( ...@@ -241,11 +258,11 @@ async def test_single_chat_session_image_base64encoded(
model_name: str, model_name: str,
raw_image_url: str, raw_image_url: str,
image_url: str, image_url: str,
base64_encoded_image: dict[str, str], url_encoded_image: dict[str, str],
): ):
content_text = "What's in this image?" content_text = "What's in this image?"
messages = dummy_messages_from_image_url( messages = dummy_messages_from_image_url(
f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}", url_encoded_image[raw_image_url],
content_text, content_text,
) )
...@@ -295,15 +312,13 @@ async def test_single_chat_session_image_base64encoded_beamsearch( ...@@ -295,15 +312,13 @@ async def test_single_chat_session_image_base64encoded_beamsearch(
client: openai.AsyncOpenAI, client: openai.AsyncOpenAI,
model_name: str, model_name: str,
image_idx: int, image_idx: int,
base64_encoded_image: dict[str, str], url_encoded_image: dict[str, str],
): ):
# NOTE: This test also validates that we pass MM data through beam search # NOTE: This test validates that we pass MM data through beam search
raw_image_url = TEST_IMAGE_ASSETS[image_idx] raw_image_url = TEST_IMAGE_ASSETS[image_idx]
expected_res = EXPECTED_MM_BEAM_SEARCH_RES[image_idx] required_terms = REQUIRED_BEAM_SEARCH_TERMS[image_idx]
messages = dummy_messages_from_image_url( messages = dummy_messages_from_image_url(url_encoded_image[raw_image_url])
f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}"
)
chat_completion = await client.chat.completions.create( chat_completion = await client.chat.completions.create(
model=model_name, model=model_name,
...@@ -314,8 +329,29 @@ async def test_single_chat_session_image_base64encoded_beamsearch( ...@@ -314,8 +329,29 @@ async def test_single_chat_session_image_base64encoded_beamsearch(
extra_body=dict(use_beam_search=True), extra_body=dict(use_beam_search=True),
) )
assert len(chat_completion.choices) == 2 assert len(chat_completion.choices) == 2
for actual, expected_str in zip(chat_completion.choices, expected_res):
assert actual.message.content == expected_str # Verify beam search produces two different non-empty outputs
content_0 = chat_completion.choices[0].message.content
content_1 = chat_completion.choices[1].message.content
# Emit beam search outputs for debugging
print(
f"Beam search outputs for image {image_idx} ({raw_image_url}): "
f"Output 0: {content_0!r}, Output 1: {content_1!r}"
)
assert content_0, "First beam search output should not be empty"
assert content_1, "Second beam search output should not be empty"
assert content_0 != content_1, "Beam search should produce different outputs"
# Verify each output contains the required terms for this image
for i, content in enumerate([content_0, content_1]):
if not check_output_matches_terms(content, required_terms):
pytest.fail(
f"Output {i} '{content}' doesn't contain required terms. "
f"Expected all of these term groups (at least one from each): "
f"{required_terms}"
)
@pytest.mark.asyncio @pytest.mark.asyncio
......
...@@ -33,6 +33,7 @@ def _terratorch_dummy_messages(): ...@@ -33,6 +33,7 @@ def _terratorch_dummy_messages():
] ]
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_name", ["ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"] "model_name", ["ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"]
) )
......
...@@ -12,11 +12,6 @@ from vllm.distributed import cleanup_dist_env_and_memory ...@@ -12,11 +12,6 @@ from vllm.distributed import cleanup_dist_env_and_memory
from vllm.platforms import current_platform from vllm.platforms import current_platform
from ....utils import models_path_prefix from ....utils import models_path_prefix
if current_platform.is_rocm():
pytest.skip(
"Encoder self-attention is not implemented on ROCm.", allow_module_level=True
)
MODEL_NAME = os.path.join(models_path_prefix, "intfloat/multilingual-e5-small") MODEL_NAME = os.path.join(models_path_prefix, "intfloat/multilingual-e5-small")
PROMPTS = [ PROMPTS = [
...@@ -38,6 +33,12 @@ TOKEN_IDS = [ ...@@ -38,6 +33,12 @@ TOKEN_IDS = [
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def llm(): def llm():
# ROCm: Use FLEX_ATTENTION backend as it's the only attention backend
# that supports encoder-only models on ROCm.
attention_config = None
if current_platform.is_rocm():
attention_config = {"backend": "FLEX_ATTENTION"}
# pytest caches the fixture so we use weakref.proxy to # pytest caches the fixture so we use weakref.proxy to
# enable garbage collection # enable garbage collection
llm = LLM( llm = LLM(
...@@ -47,6 +48,7 @@ def llm(): ...@@ -47,6 +48,7 @@ def llm():
gpu_memory_utilization=0.75, gpu_memory_utilization=0.75,
enforce_eager=True, enforce_eager=True,
seed=0, seed=0,
attention_config=attention_config,
) )
yield weakref.proxy(llm) yield weakref.proxy(llm)
......
...@@ -9,11 +9,6 @@ import pytest_asyncio ...@@ -9,11 +9,6 @@ import pytest_asyncio
from tests.utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
from vllm.platforms import current_platform from vllm.platforms import current_platform
if current_platform.is_rocm():
pytest.skip(
"Encoder self-attention is not implemented on ROCm.", allow_module_level=True
)
MODEL_NAME = "sentence-transformers/all-MiniLM-L12-v2" MODEL_NAME = "sentence-transformers/all-MiniLM-L12-v2"
max_model_len = 128 max_model_len = 128
...@@ -44,6 +39,10 @@ def server(): ...@@ -44,6 +39,10 @@ def server():
str(max_model_len), str(max_model_len),
] ]
# ROCm: Use Flex Attention to support encoder-only self-attention.
if current_platform.is_rocm():
args.extend(["--attention-backend", "FLEX_ATTENTION"])
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server yield remote_server
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Pytest configuration for vLLM pooling embed tests."""
import warnings
import torch
from vllm.platforms import current_platform
def pytest_collection_modifyitems(config, items):
"""Configure ROCm-specific settings based on collected tests."""
if not current_platform.is_rocm():
return
# Disable Flash/MemEfficient SDP on ROCm to avoid HF Transformers
# accuracy issues: https://github.com/vllm-project/vllm/issues/30167
# TODO: Remove once ROCm SDP accuracy issues are resolved on HuggingFace
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
warnings.warn(
"ROCm: Disabled flash_sdp and mem_efficient_sdp, enabled math_sdp "
"to avoid HuggingFace Transformers accuracy issues",
UserWarning,
stacklevel=1,
)
...@@ -4,7 +4,7 @@ import os ...@@ -4,7 +4,7 @@ import os
import pytest import pytest
from tests.models.language.pooling_mteb_test.mteb_utils import ( from tests.models.language.pooling_mteb_test.mteb_embed_utils import (
MTEB_EMBED_TASKS, MTEB_EMBED_TASKS,
MTEB_EMBED_TOL, MTEB_EMBED_TOL,
OpenAIClientMtebEncoder, OpenAIClientMtebEncoder,
...@@ -13,11 +13,6 @@ from tests.models.language.pooling_mteb_test.mteb_utils import ( ...@@ -13,11 +13,6 @@ from tests.models.language.pooling_mteb_test.mteb_utils import (
from tests.utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
from vllm.platforms import current_platform from vllm.platforms import current_platform
if current_platform.is_rocm():
pytest.skip(
"Encoder self-attention is not implemented on ROCm.", allow_module_level=True
)
os.environ["VLLM_LOGGING_LEVEL"] = "WARNING" os.environ["VLLM_LOGGING_LEVEL"] = "WARNING"
MODEL_NAME = "intfloat/e5-small" MODEL_NAME = "intfloat/e5-small"
...@@ -28,6 +23,10 @@ MAIN_SCORE = 0.7422994752439667 ...@@ -28,6 +23,10 @@ MAIN_SCORE = 0.7422994752439667
def server(): def server():
args = ["--runner", "pooling", "--enforce-eager", "--disable-uvicorn-access-log"] args = ["--runner", "pooling", "--enforce-eager", "--disable-uvicorn-access-log"]
# ROCm: Use Flex Attention to support encoder-only self-attention.
if current_platform.is_rocm():
args.extend(["--attention-backend", "FLEX_ATTENTION"])
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server yield remote_server
......
...@@ -11,11 +11,6 @@ from vllm import LLM, PoolingParams ...@@ -11,11 +11,6 @@ from vllm import LLM, PoolingParams
from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed import cleanup_dist_env_and_memory
from vllm.platforms import current_platform from vllm.platforms import current_platform
if current_platform.is_rocm():
pytest.skip(
"Encoder self-attention is not implemented on ROCm.", allow_module_level=True
)
MODEL_NAME = "intfloat/multilingual-e5-small" MODEL_NAME = "intfloat/multilingual-e5-small"
prompts = ["The chef prepared a delicious meal."] prompts = ["The chef prepared a delicious meal."]
...@@ -23,6 +18,12 @@ prompts = ["The chef prepared a delicious meal."] ...@@ -23,6 +18,12 @@ prompts = ["The chef prepared a delicious meal."]
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def llm(): def llm():
# ROCm: Use FLEX_ATTENTION backend as it's the only attention backend
# that supports encoder-only models on ROCm.
attention_config = None
if current_platform.is_rocm():
attention_config = {"backend": "FLEX_ATTENTION"}
# pytest caches the fixture so we use weakref.proxy to # pytest caches the fixture so we use weakref.proxy to
# enable garbage collection # enable garbage collection
llm = LLM( llm = LLM(
...@@ -32,6 +33,7 @@ def llm(): ...@@ -32,6 +33,7 @@ def llm():
gpu_memory_utilization=0.75, gpu_memory_utilization=0.75,
enforce_eager=True, enforce_eager=True,
seed=0, seed=0,
attention_config=attention_config,
) )
yield weakref.proxy(llm) yield weakref.proxy(llm)
...@@ -51,7 +53,9 @@ def test_token_embed(llm: LLM): ...@@ -51,7 +53,9 @@ def test_token_embed(llm: LLM):
def test_pooling_params(llm: LLM): def test_pooling_params(llm: LLM):
def get_outputs(normalize): def get_outputs(normalize):
outputs = llm.embed( outputs = llm.embed(
prompts, pooling_params=PoolingParams(normalize=normalize), use_tqdm=False prompts,
pooling_params=PoolingParams(use_activation=normalize),
use_tqdm=False,
) )
return torch.tensor([x.outputs.embedding for x in outputs]) return torch.tensor([x.outputs.embedding for x in outputs])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment