Commit a3f8d5dd authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.13.0rc2' into v0.13.0rc2-ori

parents 8d75f22e f34eca5f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Test that the interaction between EPLB and FusedMoE Layer is okay for DP w/ NVFP4
from dataclasses import dataclass
import pytest
import torch
from tests.kernels.moe.utils import make_test_quant_config
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
from vllm.distributed.parallel_state import (
ensure_model_parallel_initialized,
get_dp_group,
)
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptNvFp4Config,
ModelOptNvFp4FusedMoE,
)
from .eplb_utils import distributed_run, set_env_vars_and_device
@dataclass
class TestConfig:
num_layers: int
num_experts: int
num_local_experts: int
num_topk: int
hidden_size: int
intermediate_size: int
num_tokens: int
def make_fused_moe_layer(
rank: int,
layer_idx: int,
test_config: TestConfig,
) -> FusedMoE:
quant_config = None
device = torch.device(f"cuda:{rank}")
quant_config = ModelOptNvFp4Config(
is_checkpoint_nvfp4_serialized=True,
kv_cache_quant_algo=None,
exclude_modules=[],
)
fml = FusedMoE(
num_experts=test_config.num_experts,
top_k=test_config.num_topk,
hidden_size=test_config.hidden_size,
intermediate_size=test_config.intermediate_size,
prefix=f"dummy_layer_{layer_idx}",
activation="silu",
is_act_and_mul=True,
params_dtype=torch.bfloat16,
quant_config=quant_config,
)
nvfp4_fused_moe = ModelOptNvFp4FusedMoE(quant_config, fml)
nvfp4_fused_moe.create_weights(
fml,
test_config.num_local_experts,
test_config.hidden_size,
test_config.intermediate_size,
params_dtype=torch.uint8,
global_num_experts=test_config.num_experts,
)
fml = fml.to(device)
w1_q, w2_q, quant_config = make_test_quant_config(
test_config.num_local_experts,
test_config.intermediate_size,
test_config.hidden_size,
in_dtype=torch.bfloat16,
quant_dtype="nvfp4",
block_shape=None,
per_act_token_quant=False,
)
fml.w13_weight.data = w1_q
fml.w2_weight.data = w2_q
fml.w2_input_scale.data = torch.randn_like(fml.w2_input_scale.data) / 5
fml.w13_input_scale.data = torch.randn_like(fml.w13_input_scale.data) / 5
fml.w2_weight_scale_2.data = torch.randn_like(fml.w2_weight_scale_2.data) / 5
fml.w13_weight_scale_2.data = torch.randn_like(fml.w13_weight_scale_2.data) / 5
fml.w2_weight_scale.data = (
torch.randn(fml.w2_weight_scale.data.shape, device=device) / 5
).to(fml.w2_weight_scale.data.dtype)
fml.w13_weight_scale.data = (
torch.randn(fml.w13_weight_scale.data.shape, device=device) / 5
).to(fml.w13_weight_scale.data.dtype)
nvfp4_fused_moe.process_weights_after_loading(fml)
fml.maybe_init_modular_kernel()
return fml
def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
set_env_vars_and_device(env)
vllm_config = VllmConfig()
vllm_config.parallel_config.data_parallel_size = world_size
vllm_config.parallel_config.enable_expert_parallel = True
with set_current_vllm_config(vllm_config):
ensure_model_parallel_initialized(
tensor_model_parallel_size=1, pipeline_model_parallel_size=1
)
ep_group = get_dp_group().cpu_group
ep_rank = torch.distributed.get_rank()
device = torch.device(f"cuda:{ep_rank}")
fml_layers = [
make_fused_moe_layer(ep_rank, layer_idx, test_config).to(device)
for layer_idx in range(test_config.num_layers)
]
rank_expert_weights = [fml.get_expert_weights() for fml in fml_layers]
hidden_states = []
router_logits = []
for layer_idx in range(test_config.num_layers):
hidden_states.append(
torch.randn(
(test_config.num_tokens, test_config.hidden_size),
dtype=torch.bfloat16,
device=device,
)
)
router_logits.append(
torch.randn(
(test_config.num_tokens, test_config.num_experts),
dtype=torch.bfloat16,
device=device,
)
)
out_before_shuffle = []
with set_forward_context(
{},
num_tokens=test_config.num_tokens,
num_tokens_across_dp=torch.tensor(
[test_config.num_tokens] * world_size, device="cpu", dtype=torch.int
),
vllm_config=vllm_config,
):
for lidx, fml in enumerate(fml_layers):
out_before_shuffle.append(
fml(hidden_states[lidx].clone(), router_logits[lidx].clone())
)
indices = torch.zeros(
test_config.num_layers, test_config.num_experts, dtype=torch.long
)
for lidx in range(test_config.num_layers):
indices[lidx] = torch.Tensor(range(test_config.num_experts))
shuffled_indices = torch.zeros_like(indices)
for lidx in range(test_config.num_layers):
shuffled_indices[lidx] = torch.randperm(test_config.num_experts)
rearrange_expert_weights_inplace(
indices,
shuffled_indices,
rank_expert_weights,
ep_group,
is_profile=False,
)
num_global_experts = test_config.num_experts
logical_to_physical_map_list = []
for lidx, fml in enumerate(fml_layers):
physical_to_logical_map = shuffled_indices[lidx].to(device)
logical_to_physical_map = torch.empty(
(num_global_experts,), dtype=torch.int32, device=device
)
logical_to_physical_map[physical_to_logical_map] = torch.arange(
0, num_global_experts, dtype=torch.int32, device=device
)
logical_to_physical_map_list.append(
logical_to_physical_map.reshape(num_global_experts, 1)
)
logical_to_physical_map = torch.stack(logical_to_physical_map_list)
for lidx, fml in enumerate(fml_layers):
logical_replica_count = torch.ones(
(test_config.num_layers, num_global_experts),
dtype=torch.int32,
device=device,
)
fml.enable_eplb = True
fml.set_eplb_state(
lidx,
torch.zeros(
(test_config.num_layers, num_global_experts),
dtype=torch.int32,
device=device,
),
logical_to_physical_map,
logical_replica_count,
)
out_after_shuffle = []
with set_forward_context(
{},
num_tokens=test_config.num_tokens,
num_tokens_across_dp=torch.tensor(
[test_config.num_tokens] * world_size, device="cpu", dtype=torch.int
),
vllm_config=vllm_config,
):
for lidx, fml in enumerate(fml_layers):
out_after_shuffle.append(
fml(hidden_states[lidx].clone(), router_logits[lidx].clone())
)
for lidx in range(test_config.num_layers):
torch.testing.assert_close(
out_before_shuffle[lidx], out_after_shuffle[lidx], atol=1e-1, rtol=1e-1
)
@pytest.mark.parametrize("world_size", [2, 4])
@pytest.mark.parametrize("num_layers", [8])
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("intermediate_size", [256])
@pytest.mark.parametrize("num_tokens", [256])
@pytest.mark.parametrize("backend", ["latency", "throughput"])
def test_eplb_fml(
world_size: int,
num_layers: int,
num_experts: int,
hidden_size: int,
intermediate_size: int,
num_tokens: int,
backend: str,
monkeypatch,
):
monkeypatch.setenv("VLLM_USE_FLASHINFER_MOE_FP4", "1")
monkeypatch.setenv("VLLM_FLASHINFER_MOE_BACKEND", backend)
if torch.cuda.device_count() < world_size:
pytest.skip(f"Need at least {world_size} GPUs to run the test")
num_local_experts = num_experts // world_size
num_topk = 4
test_config = TestConfig(
num_layers=num_layers,
num_experts=num_experts,
num_local_experts=num_local_experts,
num_topk=num_topk,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_tokens=num_tokens,
)
distributed_run(
_test_eplb_fml,
world_size,
test_config,
)
......@@ -350,21 +350,35 @@ def test_human_readable_model_len():
assert args.max_model_len == 1_000_000
args = parser.parse_args(["--max-model-len", "10k"])
assert args.max_model_len == 10_000
args = parser.parse_args(["--max-model-len", "2g"])
assert args.max_model_len == 2_000_000_000
args = parser.parse_args(["--max-model-len", "2t"])
assert args.max_model_len == 2_000_000_000_000
# Capital
args = parser.parse_args(["--max-model-len", "3K"])
assert args.max_model_len == 1024 * 3
assert args.max_model_len == 2**10 * 3
args = parser.parse_args(["--max-model-len", "10M"])
assert args.max_model_len == 2**20 * 10
args = parser.parse_args(["--max-model-len", "4G"])
assert args.max_model_len == 2**30 * 4
args = parser.parse_args(["--max-model-len", "4T"])
assert args.max_model_len == 2**40 * 4
# Decimal values
args = parser.parse_args(["--max-model-len", "10.2k"])
assert args.max_model_len == 10200
# ..truncated to the nearest int
args = parser.parse_args(["--max-model-len", "10.212345k"])
args = parser.parse_args(["--max-model-len", "10.2123451234567k"])
assert args.max_model_len == 10212
args = parser.parse_args(["--max-model-len", "10.2123451234567m"])
assert args.max_model_len == 10212345
args = parser.parse_args(["--max-model-len", "10.2123451234567g"])
assert args.max_model_len == 10212345123
args = parser.parse_args(["--max-model-len", "10.2123451234567t"])
assert args.max_model_len == 10212345123456
# Invalid (do not allow decimals with binary multipliers)
for invalid in ["1a", "pwd", "10.24", "1.23M"]:
for invalid in ["1a", "pwd", "10.24", "1.23M", "1.22T"]:
with pytest.raises(ArgumentError):
args = parser.parse_args(["--max-model-len", invalid])
parser.parse_args(["--max-model-len", invalid])
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from openai.types.responses import ResponseFunctionToolCall, ResponseReasoningItem
from openai.types.responses.response_output_item import McpCall
from openai_harmony import Author, Message, Role, TextContent
from tests.entrypoints.openai.utils import verify_harmony_messages
from vllm.entrypoints.openai.parser.harmony_utils import (
auto_drop_analysis_messages,
get_encoding,
has_custom_tools,
parse_chat_input_to_harmony_message,
parse_chat_output,
parse_input_to_harmony_message,
parse_output_message,
)
class TestParseInputToHarmonyMessage:
"""Tests for parse_input_to_harmony_message function."""
class TestCommonParseInputToHarmonyMessage:
"""
Tests for scenarios that are common to both Chat Completion
parse_chat_input_to_harmony_message and Responsees API
parse_input_to_harmony_message functions.
"""
@pytest.fixture(
params=[parse_chat_input_to_harmony_message, parse_input_to_harmony_message]
)
def parse_function(self, request):
return request.param
def test_assistant_message_with_tool_calls(self):
def test_assistant_message_with_tool_calls(self, parse_function):
"""Test parsing assistant message with tool calls."""
chat_msg = {
"role": "assistant",
......@@ -35,7 +51,7 @@ class TestParseInputToHarmonyMessage:
],
}
messages = parse_input_to_harmony_message(chat_msg)
messages = parse_function(chat_msg)
assert len(messages) == 2
......@@ -53,7 +69,7 @@ class TestParseInputToHarmonyMessage:
assert messages[1].recipient == "functions.search_web"
assert messages[1].content_type == "json"
def test_assistant_message_with_empty_tool_call_arguments(self):
def test_assistant_message_with_empty_tool_call_arguments(self, parse_function):
"""Test parsing assistant message with tool call having None arguments."""
chat_msg = {
"role": "assistant",
......@@ -67,12 +83,152 @@ class TestParseInputToHarmonyMessage:
],
}
messages = parse_input_to_harmony_message(chat_msg)
messages = parse_function(chat_msg)
assert len(messages) == 1
assert messages[0].content[0].text == ""
assert messages[0].recipient == "functions.get_current_time"
def test_system_message(self, parse_function):
"""Test parsing system message."""
chat_msg = {
"role": "system",
"content": "You are a helpful assistant",
}
messages = parse_function(chat_msg)
assert len(messages) == 1
# System messages are converted using Message.from_dict
# which should preserve the role
assert messages[0].author.role == Role.SYSTEM
def test_developer_message(self, parse_function):
"""Test parsing developer message."""
chat_msg = {
"role": "developer",
"content": "Use concise language",
}
messages = parse_function(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.DEVELOPER
def test_user_message_with_string_content(self, parse_function):
"""Test parsing user message with string content."""
chat_msg = {
"role": "user",
"content": "What's the weather in San Francisco?",
}
messages = parse_function(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.USER
assert messages[0].content[0].text == "What's the weather in San Francisco?"
def test_user_message_with_array_content(self, parse_function):
"""Test parsing user message with array content."""
chat_msg = {
"role": "user",
"content": [
{"text": "What's in this image? "},
{"text": "Please describe it."},
],
}
messages = parse_function(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.USER
assert len(messages[0].content) == 2
assert messages[0].content[0].text == "What's in this image? "
assert messages[0].content[1].text == "Please describe it."
def test_assistant_message_with_string_content(self, parse_function):
"""Test parsing assistant message with string content (no tool calls)."""
chat_msg = {
"role": "assistant",
"content": "Hello! How can I help you today?",
}
messages = parse_function(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.ASSISTANT
assert messages[0].content[0].text == "Hello! How can I help you today?"
def test_pydantic_model_input(self, parse_function):
"""Test parsing Pydantic model input (has model_dump method)."""
class MockPydanticModel:
def model_dump(self, exclude_none=True):
return {
"role": "user",
"content": "Test message",
}
chat_msg = MockPydanticModel()
messages = parse_function(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.USER
assert messages[0].content[0].text == "Test message"
def test_tool_call_with_missing_function_fields(self, parse_function):
"""Test parsing tool call with missing name or arguments."""
chat_msg = {
"role": "assistant",
"tool_calls": [
{
"function": {} # Missing both name and arguments
}
],
}
messages = parse_function(chat_msg)
assert len(messages) == 1
assert messages[0].recipient == "functions."
assert messages[0].content[0].text == ""
def test_array_content_with_missing_text(self, parse_function):
"""Test parsing array content where text field is missing."""
chat_msg = {
"role": "user",
"content": [
{}, # Missing text field
{"text": "actual text"},
],
}
messages = parse_function(chat_msg)
assert len(messages) == 1
assert len(messages[0].content) == 2
assert messages[0].content[0].text == ""
assert messages[0].content[1].text == "actual text"
class TestParseInputToHarmonyMessage:
"""
Tests for scenarios that are specific to the Responses API
parse_input_to_harmony_message function.
"""
def test_message_with_empty_content(self):
"""Test parsing message with empty string content."""
chat_msg = {
"role": "user",
"content": "",
}
messages = parse_input_to_harmony_message(chat_msg)
assert len(messages) == 1
assert messages[0].content[0].text == ""
def test_tool_message_with_string_content(self):
"""Test parsing tool message with string content."""
chat_msg = {
......@@ -111,6 +267,7 @@ class TestParseInputToHarmonyMessage:
assert len(messages) == 1
assert messages[0].author.role == Role.TOOL
assert messages[0].author.name == "functions.search_results"
assert messages[0].content[0].text == "Result 1: Result 2: Result 3"
def test_tool_message_with_empty_content(self):
......@@ -124,140 +281,564 @@ class TestParseInputToHarmonyMessage:
messages = parse_input_to_harmony_message(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.TOOL
assert messages[0].author.name == "functions.empty_tool"
assert messages[0].content[0].text == ""
def test_system_message(self):
"""Test parsing system message."""
class TestParseChatInputToHarmonyMessage:
"""
Tests for scenarios that are specific to the Chat Completion API
parse_chat_input_to_harmony_message function.
"""
def test_user_message_with_empty_content(self):
chat_msg = {
"role": "system",
"content": "You are a helpful assistant",
"role": "user",
"content": "",
}
messages = parse_input_to_harmony_message(chat_msg)
messages = parse_chat_input_to_harmony_message(chat_msg)
assert len(messages) == 1
# System messages are converted using Message.from_dict
# which should preserve the role
assert messages[0].author.role == Role.SYSTEM
verify_harmony_messages(
messages,
[
{
"role": "user",
"content": "",
},
],
)
def test_developer_message(self):
"""Test parsing developer message."""
def test_user_message_with_none_content(self):
chat_msg = {
"role": "developer",
"content": "Use concise language",
"role": "user",
"content": None,
}
messages = parse_input_to_harmony_message(chat_msg)
messages = parse_chat_input_to_harmony_message(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.DEVELOPER
verify_harmony_messages(
messages,
[
{
"role": "user",
"content": "",
},
],
)
def test_user_message_with_string_content(self):
"""Test parsing user message with string content."""
def test_assistant_message_with_empty_content(self):
chat_msg = {
"role": "user",
"content": "What's the weather in San Francisco?",
"role": "assistant",
"content": "",
}
messages = parse_input_to_harmony_message(chat_msg)
messages = parse_chat_input_to_harmony_message(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.USER
assert messages[0].content[0].text == "What's the weather in San Francisco?"
assert len(messages) == 0
def test_user_message_with_array_content(self):
"""Test parsing user message with array content."""
def test_assistant_message_with_none_content(self):
chat_msg = {
"role": "user",
"content": [
{"text": "What's in this image? "},
{"text": "Please describe it."},
"role": "assistant",
"content": None,
}
messages = parse_chat_input_to_harmony_message(chat_msg)
assert len(messages) == 0
def test_assistant_message_with_content_but_empty_reasoning(self):
chat_msg = {
"role": "assistant",
"content": "The answer is 4.",
"reasoning": "",
}
messages = parse_chat_input_to_harmony_message(chat_msg)
verify_harmony_messages(
messages,
[
{
"role": "assistant",
"channel": "final",
"content": "The answer is 4.",
},
],
)
def test_assistant_message_with_reasoning_but_empty_content(self):
chat_msg = {
"role": "assistant",
"reasoning": "I'm thinking about the user's question.",
"content": "",
}
messages = parse_input_to_harmony_message(chat_msg)
messages = parse_chat_input_to_harmony_message(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.USER
assert len(messages[0].content) == 2
assert messages[0].content[0].text == "What's in this image? "
assert messages[0].content[1].text == "Please describe it."
verify_harmony_messages(
messages,
[
{
"role": "assistant",
"channel": "analysis",
"content": "I'm thinking about the user's question.",
},
],
)
def test_assistant_message_with_string_content(self):
"""Test parsing assistant message with string content (no tool calls)."""
def test_assistant_message_with_reasoning_but_none_content(self):
chat_msg = {
"role": "assistant",
"content": "Hello! How can I help you today?",
"reasoning": "I'm thinking about the user's question.",
"content": None,
}
messages = parse_input_to_harmony_message(chat_msg)
messages = parse_chat_input_to_harmony_message(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.ASSISTANT
assert messages[0].content[0].text == "Hello! How can I help you today?"
verify_harmony_messages(
messages,
[
{
"role": "assistant",
"channel": "analysis",
"content": "I'm thinking about the user's question.",
},
],
)
def test_pydantic_model_input(self):
"""Test parsing Pydantic model input (has model_dump method)."""
def test_assistant_message_with_tool_calls_but_no_content(self):
chat_msg = {
"role": "assistant",
"tool_calls": [
{
"function": {
"name": "get_weather",
"arguments": '{"location": "San Francisco"}',
}
}
],
}
class MockPydanticModel:
def model_dump(self, exclude_none=True):
return {
"role": "user",
"content": "Test message",
messages = parse_chat_input_to_harmony_message(chat_msg)
verify_harmony_messages(
messages,
[
{
"role": "assistant",
"channel": "commentary",
"recipient": "functions.get_weather",
"content": '{"location": "San Francisco"}',
"content_type": "json",
},
],
)
def test_assistant_message_with_tool_calls_and_content(self):
chat_msg = {
"role": "assistant",
"tool_calls": [
{
"function": {
"name": "get_weather",
"arguments": '{"location": "San Francisco"}',
}
}
],
"content": "I'll call the tool.",
}
chat_msg = MockPydanticModel()
messages = parse_input_to_harmony_message(chat_msg)
messages = parse_chat_input_to_harmony_message(chat_msg)
assert len(messages) == 1
assert messages[0].author.role == Role.USER
assert messages[0].content[0].text == "Test message"
verify_harmony_messages(
messages,
[
{
"role": "assistant",
"channel": "commentary",
"content": "I'll call the tool.",
},
{
"role": "assistant",
"channel": "commentary",
"recipient": "functions.get_weather",
"content": '{"location": "San Francisco"}',
"content_type": "json",
},
],
)
def test_message_with_empty_content(self):
"""Test parsing message with empty string content."""
def test_assistant_message_with_tool_calls_and_reasoning(self):
chat_msg = {
"role": "user",
"content": "",
"role": "assistant",
"tool_calls": [
{
"function": {
"name": "get_weather",
"arguments": '{"location": "San Francisco"}',
}
}
],
"reasoning": "I should use the get_weather tool.",
}
messages = parse_input_to_harmony_message(chat_msg)
messages = parse_chat_input_to_harmony_message(chat_msg)
assert len(messages) == 1
assert messages[0].content[0].text == ""
verify_harmony_messages(
messages,
[
{
"role": "assistant",
"channel": "analysis",
"content": "I should use the get_weather tool.",
},
{
"role": "assistant",
"channel": "commentary",
"recipient": "functions.get_weather",
"content": '{"location": "San Francisco"}',
"content_type": "json",
},
],
)
def test_tool_call_with_missing_function_fields(self):
"""Test parsing tool call with missing name or arguments."""
def test_assistant_message_with_tool_calls_and_reasoning_and_content(self):
chat_msg = {
"role": "assistant",
"tool_calls": [
{
"function": {} # Missing both name and arguments
"function": {
"name": "get_weather",
"arguments": '{"location": "San Francisco"}',
}
}
],
"reasoning": "I should use the get_weather tool.",
"content": "I'll call the tool.",
}
messages = parse_input_to_harmony_message(chat_msg)
messages = parse_chat_input_to_harmony_message(chat_msg)
assert len(messages) == 1
assert messages[0].recipient == "functions."
assert messages[0].content[0].text == ""
verify_harmony_messages(
messages,
[
{
"role": "assistant",
"channel": "commentary",
"content": "I'll call the tool.",
},
{
"role": "assistant",
"channel": "analysis",
"content": "I should use the get_weather tool.",
},
{
"role": "assistant",
"channel": "commentary",
"recipient": "functions.get_weather",
"content": '{"location": "San Francisco"}',
"content_type": "json",
},
],
)
def test_array_content_with_missing_text(self):
"""Test parsing array content where text field is missing."""
def test_tool_message_with_string_content(self):
tool_id_names = {
"call_123": "get_weather",
}
chat_msg = {
"role": "user",
"role": "tool",
"tool_call_id": "call_123",
"content": "The weather in San Francisco is sunny, 72°F",
}
messages = parse_chat_input_to_harmony_message(
chat_msg, tool_id_names=tool_id_names
)
verify_harmony_messages(
messages,
[
{
"role": "tool",
"name": "functions.get_weather",
"content": "The weather in San Francisco is sunny, 72°F",
"channel": "commentary",
},
],
)
def test_tool_message_with_array_content(self):
tool_id_names = {
"call_123": "search_results",
}
chat_msg = {
"role": "tool",
"tool_call_id": "call_123",
"content": [
{}, # Missing text field
{"text": "actual text"},
{"type": "text", "text": "Result 1: "},
{"type": "text", "text": "Result 2: "},
{
"type": "image",
"url": "http://example.com/img.png",
}, # Should be ignored
{"type": "text", "text": "Result 3"},
],
}
messages = parse_input_to_harmony_message(chat_msg)
messages = parse_chat_input_to_harmony_message(
chat_msg, tool_id_names=tool_id_names
)
assert len(messages) == 1
assert len(messages[0].content) == 2
assert messages[0].content[0].text == ""
assert messages[0].content[1].text == "actual text"
verify_harmony_messages(
messages,
[
{
"role": "tool",
"name": "functions.search_results",
"content": "Result 1: Result 2: Result 3",
"channel": "commentary",
},
],
)
def test_tool_message_with_empty_content(self):
tool_id_names = {
"call_123": "empty_tool",
}
chat_msg = {
"role": "tool",
"tool_call_id": "call_123",
"content": "",
}
messages = parse_chat_input_to_harmony_message(
chat_msg, tool_id_names=tool_id_names
)
verify_harmony_messages(
messages,
[
{
"role": "tool",
"name": "functions.empty_tool",
"content": "",
"channel": "commentary",
},
],
)
def test_tool_message_with_none_content(self):
tool_id_names = {
"call_123": "empty_tool",
}
chat_msg = {
"role": "tool",
"tool_call_id": "call_123",
"content": None,
}
messages = parse_chat_input_to_harmony_message(
chat_msg, tool_id_names=tool_id_names
)
verify_harmony_messages(
messages,
[
{
"role": "tool",
"name": "functions.empty_tool",
"content": "",
"channel": "commentary",
},
],
)
class TestAutoDropAnalysisMessages:
def test_no_analysis_messages(self) -> None:
messages = [
Message.from_role_and_content(
Role.ASSISTANT, "The answer is 4."
).with_channel("final"),
]
cleaned_messages = auto_drop_analysis_messages(messages)
assert cleaned_messages == messages
def test_only_analysis_message(self) -> None:
messages = [
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking about the user's question."
).with_channel("analysis"),
]
cleaned_messages = auto_drop_analysis_messages(messages)
assert cleaned_messages == messages
def test_multiple_analysis_messages_without_final_message(self) -> None:
messages = [
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking about the user's question."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking more."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking even more."
).with_channel("analysis"),
]
cleaned_messages = auto_drop_analysis_messages(messages)
assert cleaned_messages == messages
def test_only_final_message(self) -> None:
messages = [
Message.from_role_and_content(
Role.ASSISTANT, "The answer is 4."
).with_channel("final"),
]
cleaned_messages = auto_drop_analysis_messages(messages)
assert cleaned_messages == messages
def test_drops_one_analysis_messages_before_final_message(self) -> None:
messages = [
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking about the user's question."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "The answer is 4."
).with_channel("final"),
Message.from_role_and_content(
Role.ASSISTANT, "I should think harder."
).with_channel("analysis"),
]
cleaned_messages = auto_drop_analysis_messages(messages)
# Should have dropped the first analysis message
assert cleaned_messages == messages[1:]
def test_drops_all_analysis_messages_before_final_message(self) -> None:
messages = [
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking about the user's question."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking more."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking even more."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "The answer is 4."
).with_channel("final"),
Message.from_role_and_content(
Role.ASSISTANT, "I should think harder."
).with_channel("analysis"),
]
cleaned_messages = auto_drop_analysis_messages(messages)
# Should have dropped the first 3 analysis messages
assert cleaned_messages == messages[3:]
def test_multiple_analysis_messages_with_multiple_final_messages(self) -> None:
messages = [
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking about the user's question."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking more."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "I'm thinking even more."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "The answer is 4."
).with_channel("final"),
Message.from_role_and_content(
Role.ASSISTANT, "I should think harder."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "The answer is 5."
).with_channel("final"),
]
cleaned_messages = auto_drop_analysis_messages(messages)
# Should have dropped all those analysis messages
assert len(cleaned_messages) == 2
assert cleaned_messages[0].content[0].text == "The answer is 4."
assert cleaned_messages[1].content[0].text == "The answer is 5."
def test_drops_non_assistant_analysis_messages(self) -> None:
messages = [
Message.from_role_and_content(
Role.TOOL, "The tool thinks we should think harder."
).with_channel("analysis"),
Message.from_role_and_content(
Role.ASSISTANT, "The answer is 4."
).with_channel("final"),
]
cleaned_messages = auto_drop_analysis_messages(messages)
# Should have dropped the analysis message
assert cleaned_messages == messages[1:]
class TestParseChatOutput:
def test_parse_chat_output_interrupted_first_message(self) -> None:
harmony_str = "<|channel|>final<|message|>I'm in the middle of answering"
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
reasoning, final_content, _ = parse_chat_output(token_ids)
assert reasoning is None
assert final_content == "I'm in the middle of answering"
def test_parse_chat_output_interrupted_reasoning_first_message(self) -> None:
harmony_str = "<|channel|>analysis<|message|>I'm in the middle of thinking"
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
reasoning, final_content, _ = parse_chat_output(token_ids)
assert reasoning == "I'm in the middle of thinking"
assert final_content is None
def test_parse_chat_output_complete_reasoning_interrupted_content(self) -> None:
harmony_str = (
"<|channel|>analysis<|message|>I'm thinking.<|end|>"
"<|start|>assistant<|channel|>final"
"<|message|>I'm in the middle of answering"
)
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
reasoning, final_content, _ = parse_chat_output(token_ids)
assert reasoning == "I'm thinking."
assert final_content == "I'm in the middle of answering"
def test_parse_chat_output_complete_content(self) -> None:
harmony_str = "<|channel|>final<|message|>The answer is 4.<|end|>"
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
reasoning, final_content, _ = parse_chat_output(token_ids)
assert reasoning is None
assert final_content == "The answer is 4."
def test_parse_chat_output_complete_commentary(self) -> None:
harmony_str = (
"<|channel|>commentary<|message|>I need to call some tools.<|end|>"
)
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
reasoning, final_content, _ = parse_chat_output(token_ids)
assert reasoning is None
assert final_content == "I need to call some tools."
def test_parse_chat_output_complete_reasoning(self) -> None:
harmony_str = (
"<|channel|>analysis<|message|>I've thought hard about this.<|end|>"
)
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
reasoning, final_content, _ = parse_chat_output(token_ids)
assert reasoning == "I've thought hard about this."
assert final_content is None
def test_parse_chat_output_complete_reasoning_and_content(self) -> None:
harmony_str = (
"<|channel|>analysis<|message|>I've thought hard about this.<|end|>"
"<|start|>assistant<|channel|>final<|message|>The answer is 4.<|end|>"
)
token_ids = get_encoding().encode(harmony_str, allowed_special="all")
reasoning, final_content, _ = parse_chat_output(token_ids)
assert reasoning == "I've thought hard about this."
assert final_content == "The answer is 4."
class TestParseOutputMessage:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ErrorResponse
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.v1.engine.async_llm import AsyncLLM
MODEL_NAME = "openai-community/gpt2"
MODEL_NAME_SHORT = "gpt2"
BASE_MODEL_PATHS = [
BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT),
]
@dataclass
class MockHFConfig:
model_type: str = "any"
@dataclass
class MockModelConfig:
task = "generate"
runner_type = "generate"
tokenizer = MODEL_NAME
trust_remote_code = False
tokenizer_mode = "auto"
max_model_len = 100
tokenizer_revision = None
multimodal_config = MultiModalConfig()
hf_config = MockHFConfig()
logits_processor_pattern = None
logits_processors: list[str] | None = None
diff_sampling_param: dict | None = None
allowed_local_media_path: str = ""
allowed_media_domains: list[str] | None = None
encoder_config = None
generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
models = OpenAIServingModels(
engine_client=engine,
base_model_paths=BASE_MODEL_PATHS,
)
serving_chat = OpenAIServingChat(
engine,
models,
response_role="assistant",
request_logger=None,
chat_template=None,
chat_template_content_format="auto",
)
async def _fake_process_inputs(
request_id,
engine_prompt,
sampling_params,
*,
lora_request,
trace_headers,
priority,
):
return dict(engine_prompt), {}
async def _fake_preprocess_chat(*args, **kwargs):
# return conversation, engine_prompts
return (
[{"role": "user", "content": "Test"}],
[{"prompt_token_ids": [1, 2, 3]}],
)
serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
serving_chat._preprocess_chat = AsyncMock(side_effect=_fake_preprocess_chat)
return serving_chat
@pytest.mark.asyncio
async def test_chat_error_non_stream():
"""test finish_reason='error' returns 500 InternalServerError (non-streaming)"""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
serving_chat = _build_serving_chat(mock_engine)
completion_output = CompletionOutput(
index=0,
text="",
token_ids=[],
cumulative_logprob=None,
logprobs=None,
finish_reason="error",
)
request_output = RequestOutput(
request_id="test-id",
prompt="Test prompt",
prompt_token_ids=[1, 2, 3],
prompt_logprobs=None,
outputs=[completion_output],
finished=True,
metrics=None,
lora_request=None,
encoder_prompt=None,
encoder_prompt_token_ids=None,
)
async def mock_generate(*args, **kwargs):
yield request_output
mock_engine.generate = MagicMock(side_effect=mock_generate)
request = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{"role": "user", "content": "Test prompt"}],
max_tokens=10,
stream=False,
)
response = await serving_chat.create_chat_completion(request)
assert isinstance(response, ErrorResponse)
assert response.error.type == "InternalServerError"
assert response.error.message == "Internal server error"
assert response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR
@pytest.mark.asyncio
async def test_chat_error_stream():
"""test finish_reason='error' returns 500 InternalServerError (streaming)"""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
serving_chat = _build_serving_chat(mock_engine)
completion_output_1 = CompletionOutput(
index=0,
text="Hello",
token_ids=[100],
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
)
request_output_1 = RequestOutput(
request_id="test-id",
prompt="Test prompt",
prompt_token_ids=[1, 2, 3],
prompt_logprobs=None,
outputs=[completion_output_1],
finished=False,
metrics=None,
lora_request=None,
encoder_prompt=None,
encoder_prompt_token_ids=None,
)
completion_output_2 = CompletionOutput(
index=0,
text="Hello",
token_ids=[100],
cumulative_logprob=None,
logprobs=None,
finish_reason="error",
)
request_output_2 = RequestOutput(
request_id="test-id",
prompt="Test prompt",
prompt_token_ids=[1, 2, 3],
prompt_logprobs=None,
outputs=[completion_output_2],
finished=True,
metrics=None,
lora_request=None,
encoder_prompt=None,
encoder_prompt_token_ids=None,
)
async def mock_generate(*args, **kwargs):
yield request_output_1
yield request_output_2
mock_engine.generate = MagicMock(side_effect=mock_generate)
request = ChatCompletionRequest(
model=MODEL_NAME,
messages=[{"role": "user", "content": "Test prompt"}],
max_tokens=10,
stream=True,
)
response = await serving_chat.create_chat_completion(request)
chunks = []
async for chunk in response:
chunks.append(chunk)
assert len(chunks) >= 2
assert any("Internal server error" in chunk for chunk in chunks), (
f"Expected error message in chunks: {chunks}"
)
assert chunks[-1] == "data: [DONE]\n\n"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.protocol import CompletionRequest, ErrorResponse
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.v1.engine.async_llm import AsyncLLM
MODEL_NAME = "openai-community/gpt2"
MODEL_NAME_SHORT = "gpt2"
BASE_MODEL_PATHS = [
BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT),
]
@dataclass
class MockHFConfig:
model_type: str = "any"
@dataclass
class MockModelConfig:
task = "generate"
runner_type = "generate"
tokenizer = MODEL_NAME
trust_remote_code = False
tokenizer_mode = "auto"
max_model_len = 100
tokenizer_revision = None
multimodal_config = MultiModalConfig()
hf_config = MockHFConfig()
logits_processor_pattern = None
logits_processors: list[str] | None = None
diff_sampling_param: dict | None = None
allowed_local_media_path: str = ""
allowed_media_domains: list[str] | None = None
encoder_config = None
generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
skip_tokenizer_init = False
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
models = OpenAIServingModels(
engine_client=engine,
base_model_paths=BASE_MODEL_PATHS,
)
serving_completion = OpenAIServingCompletion(
engine,
models,
request_logger=None,
)
async def _fake_process_inputs(
request_id,
engine_prompt,
sampling_params,
*,
lora_request,
trace_headers,
priority,
):
return dict(engine_prompt), {}
serving_completion._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
return serving_completion
@pytest.mark.asyncio
async def test_completion_error_non_stream():
"""test finish_reason='error' returns 500 InternalServerError (non-streaming)"""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
serving_completion = _build_serving_completion(mock_engine)
completion_output = CompletionOutput(
index=0,
text="",
token_ids=[],
cumulative_logprob=None,
logprobs=None,
finish_reason="error",
)
request_output = RequestOutput(
request_id="test-id",
prompt="Test prompt",
prompt_token_ids=[1, 2, 3],
prompt_logprobs=None,
outputs=[completion_output],
finished=True,
metrics=None,
lora_request=None,
encoder_prompt=None,
encoder_prompt_token_ids=None,
)
async def mock_generate(*args, **kwargs):
yield request_output
mock_engine.generate = MagicMock(side_effect=mock_generate)
request = CompletionRequest(
model=MODEL_NAME,
prompt="Test prompt",
max_tokens=10,
stream=False,
)
response = await serving_completion.create_completion(request)
assert isinstance(response, ErrorResponse)
assert response.error.type == "InternalServerError"
assert response.error.message == "Internal server error"
assert response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR
@pytest.mark.asyncio
async def test_completion_error_stream():
"""test finish_reason='error' returns 500 InternalServerError (streaming)"""
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
serving_completion = _build_serving_completion(mock_engine)
completion_output_1 = CompletionOutput(
index=0,
text="Hello",
token_ids=[100],
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
)
request_output_1 = RequestOutput(
request_id="test-id",
prompt="Test prompt",
prompt_token_ids=[1, 2, 3],
prompt_logprobs=None,
outputs=[completion_output_1],
finished=False,
metrics=None,
lora_request=None,
encoder_prompt=None,
encoder_prompt_token_ids=None,
)
completion_output_2 = CompletionOutput(
index=0,
text="Hello",
token_ids=[100],
cumulative_logprob=None,
logprobs=None,
finish_reason="error",
)
request_output_2 = RequestOutput(
request_id="test-id",
prompt="Test prompt",
prompt_token_ids=[1, 2, 3],
prompt_logprobs=None,
outputs=[completion_output_2],
finished=True,
metrics=None,
lora_request=None,
encoder_prompt=None,
encoder_prompt_token_ids=None,
)
async def mock_generate(*args, **kwargs):
yield request_output_1
yield request_output_2
mock_engine.generate = MagicMock(side_effect=mock_generate)
request = CompletionRequest(
model=MODEL_NAME,
prompt="Test prompt",
max_tokens=10,
stream=True,
)
response = await serving_completion.create_completion(request)
chunks = []
async for chunk in response:
chunks.append(chunk)
assert len(chunks) >= 2
assert any("Internal server error" in chunk for chunk in chunks), (
f"Expected error message in chunks: {chunks}"
)
assert chunks[-1] == "data: [DONE]\n\n"
......@@ -79,9 +79,12 @@ async def test_anthropic_streaming(client: anthropic.AsyncAnthropic):
assert chunk_count > 0
assert first_chunk is not None, "message_start chunk was never observed"
assert first_chunk.usage is not None, "first chunk should include usage stats"
assert first_chunk.usage["output_tokens"] == 0
assert first_chunk.usage["input_tokens"] > 5
assert first_chunk.message is not None, "first chunk should include message"
assert first_chunk.message.usage is not None, (
"first chunk should include usage stats"
)
assert first_chunk.message.usage.output_tokens == 0
assert first_chunk.message.usage.input_tokens > 5
@pytest.mark.asyncio
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from http import HTTPStatus
from unittest.mock import MagicMock
import pytest
from vllm.entrypoints.openai.protocol import ErrorResponse
from vllm.entrypoints.openai.serving_engine import GenerationError, OpenAIServing
@pytest.mark.asyncio
async def test_raise_if_error_raises_generation_error():
"""test _raise_if_error raises GenerationError"""
# create a minimal OpenAIServing instance
mock_engine = MagicMock()
mock_engine.model_config = MagicMock()
mock_engine.model_config.max_model_len = 100
mock_models = MagicMock()
serving = OpenAIServing(
engine_client=mock_engine,
models=mock_models,
request_logger=None,
)
# test that error finish_reason raises GenerationError
with pytest.raises(GenerationError) as exc_info:
serving._raise_if_error("error", "test-request-id")
assert str(exc_info.value) == "Internal server error"
assert exc_info.value.status_code == HTTPStatus.INTERNAL_SERVER_ERROR
# test that other finish_reasons don't raise
serving._raise_if_error("stop", "test-request-id") # should not raise
serving._raise_if_error("length", "test-request-id") # should not raise
serving._raise_if_error(None, "test-request-id") # should not raise
@pytest.mark.asyncio
async def test_convert_generation_error_to_response():
"""test _convert_generation_error_to_response creates proper ErrorResponse"""
mock_engine = MagicMock()
mock_engine.model_config = MagicMock()
mock_engine.model_config.max_model_len = 100
mock_models = MagicMock()
serving = OpenAIServing(
engine_client=mock_engine,
models=mock_models,
request_logger=None,
)
# create a GenerationError
gen_error = GenerationError("Internal server error")
# convert to ErrorResponse
error_response = serving._convert_generation_error_to_response(gen_error)
assert isinstance(error_response, ErrorResponse)
assert error_response.error.type == "InternalServerError"
assert error_response.error.message == "Internal server error"
assert error_response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR
@pytest.mark.asyncio
async def test_convert_generation_error_to_streaming_response():
"""test _convert_generation_error_to_streaming_response output"""
mock_engine = MagicMock()
mock_engine.model_config = MagicMock()
mock_engine.model_config.max_model_len = 100
mock_models = MagicMock()
serving = OpenAIServing(
engine_client=mock_engine,
models=mock_models,
request_logger=None,
)
# create a GenerationError
gen_error = GenerationError("Internal server error")
# convert to streaming error response
error_json = serving._convert_generation_error_to_streaming_response(gen_error)
assert isinstance(error_json, str)
assert "Internal server error" in error_json
assert "InternalServerError" in error_json
......@@ -11,13 +11,25 @@ import pytest_asyncio
from openai import OpenAI
from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.parser.harmony_utils import get_encoding
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
RequestResponseMetadata,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.tokenizers import get_tokenizer
from vllm.tool_parsers import ToolParserManager
from vllm.v1.engine.async_llm import AsyncLLM
from ...utils import RemoteOpenAIServer
from .utils import (
accumulate_streaming_response,
verify_chat_response,
verify_harmony_messages,
)
GPT_OSS_MODEL_NAME = "openai/gpt-oss-20b"
......@@ -728,3 +740,635 @@ async def test_serving_chat_data_parallel_rank_extraction():
# Verify that data_parallel_rank defaults to None
assert "data_parallel_rank" in mock_engine.generate.call_args.kwargs
assert mock_engine.generate.call_args.kwargs["data_parallel_rank"] is None
class TestServingChatWithHarmony:
"""
These tests ensure Chat Completion requests are being properly converted into
Harmony messages and Harmony response messages back into Chat Completion responses.
These tests are not exhaustive, but each one was created to cover a specific case
that we got wrong but is now fixed.
Any changes to the tests and their expectations may result in changes to the
accuracy of model prompting and responses generated. It is suggested to run
an evaluation or benchmarking suite (such as bfcl multi_turn) to understand
any impact of changes in how we prompt Harmony models.
"""
@pytest.fixture(params=[False, True], ids=["non_streaming", "streaming"])
def stream(self, request) -> bool:
"""Parameterize tests to run in both non-streaming and streaming modes."""
return request.param
@pytest.fixture()
def mock_engine(self) -> AsyncLLM:
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
return mock_engine
@pytest.fixture()
def serving_chat(self, mock_engine) -> OpenAIServingChat:
chat = _build_serving_chat(mock_engine)
chat.use_harmony = True
chat.tool_parser = ToolParserManager.get_tool_parser("openai")
return chat
def mock_request_output_from_req_and_token_ids(
self, req: ChatCompletionRequest, token_ids: list[int], finished: bool = False
) -> RequestOutput:
# Our tests don't use most fields, so just get the token ids correct
completion_output = CompletionOutput(
index=0,
text="",
token_ids=token_ids,
cumulative_logprob=0.0,
logprobs=None,
)
return RequestOutput(
request_id=req.request_id,
prompt=[],
prompt_token_ids=[],
prompt_logprobs=None,
outputs=[completion_output],
finished=finished,
)
@pytest.fixture
def weather_tools(self) -> list[dict[str, Any]]:
return [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {"type": "string"},
},
"required": ["location"],
},
},
},
]
@pytest.fixture
def weather_messages_start(self) -> list[dict[str, Any]]:
return [
{
"role": "user",
"content": "What's the weather like in Paris today?",
},
]
async def generate_response_from_harmony_str(
self,
serving_chat: OpenAIServingChat,
req: ChatCompletionRequest,
harmony_str: str,
stream: bool = False,
) -> ChatCompletionResponse:
harmony_token_ids = get_encoding().encode(harmony_str, allowed_special="all")
async def result_generator():
if stream:
for token_id in harmony_token_ids:
yield self.mock_request_output_from_req_and_token_ids(
req, [token_id]
)
yield self.mock_request_output_from_req_and_token_ids(
req, [], finished=True
)
else:
yield self.mock_request_output_from_req_and_token_ids(
req, harmony_token_ids, finished=True
)
generator_func = (
serving_chat.chat_completion_stream_generator
if stream
else serving_chat.chat_completion_full_generator
)
result = generator_func(
request=req,
result_generator=result_generator(),
request_id=req.request_id,
model_name=req.model,
conversation=[],
tokenizer=get_tokenizer(req.model),
request_metadata=RequestResponseMetadata(
request_id=req.request_id,
model_name=req.model,
),
)
if stream:
return await accumulate_streaming_response(result)
return await result
@pytest.mark.asyncio
async def test_simple_chat(self, serving_chat, stream):
messages = [{"role": "user", "content": "what is 1+1?"}]
# Test the Harmony messages for the first turn's input
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages, _ = serving_chat._make_request_with_harmony(req)
verify_harmony_messages(
input_messages,
[
{"role": "system"},
{"role": "developer"},
{"role": "user", "content": messages[0]["content"]},
],
)
# Test the Chat Completion response for the first turn's output
reasoning_str = "We need to think really hard about this."
final_str = "The answer is 2."
response_str = (
f"<|channel|>analysis<|message|>{reasoning_str}<|end|>"
f"<|start|>assistant<|channel|>final<|message|>{final_str}<|end|>"
)
response = await self.generate_response_from_harmony_str(
serving_chat, req, response_str, stream=stream
)
verify_chat_response(response, content=final_str, reasoning=reasoning_str)
# Add the output messages from the first turn as input to the second turn
for choice in response.choices:
messages.append(choice.message.model_dump(exclude_none=True))
# Test the Harmony messages for the second turn's input
req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages_2, _ = serving_chat._make_request_with_harmony(req_2)
verify_harmony_messages(
input_messages_2,
[
{"role": "system"},
{"role": "developer"},
{"role": "user"},
# The analysis message should be dropped on subsequent inputs because
# of the subsequent assistant message to the final channel.
{"role": "assistant", "channel": "final", "content": final_str},
],
)
@pytest.mark.asyncio
async def test_tool_call_response_with_content(
self, serving_chat, stream, weather_tools, weather_messages_start
):
tools = weather_tools
messages = list(weather_messages_start)
# Test the Harmony messages for the first turn's input
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools)
input_messages, _ = serving_chat._make_request_with_harmony(req)
verify_harmony_messages(
input_messages,
[
{"role": "system"},
{"role": "developer", "tool_definitions": ["get_weather"]},
{"role": "user", "content": messages[0]["content"]},
],
)
# Test the Chat Completion response for the first turn's output
commentary_str = "We'll call get_weather."
tool_args_str = '{"location": "Paris"}'
response_str = (
f"<|channel|>commentary<|message|>{commentary_str}<|end|>"
"<|start|>assistant to=functions.get_weather<|channel|>commentary"
f"<|constrain|>json<|message|>{tool_args_str}<|call|>"
)
response = await self.generate_response_from_harmony_str(
serving_chat, req, response_str, stream=stream
)
verify_chat_response(
response,
content=commentary_str,
tool_calls=[("get_weather", tool_args_str)],
)
tool_call = response.choices[0].message.tool_calls[0]
# Add the output messages from the first turn as input to the second turn
for choice in response.choices:
messages.append(choice.message.model_dump(exclude_none=True))
# Add our tool output message
messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"content": "20 degrees Celsius",
},
)
# Test the Harmony messages for the second turn's input
req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages_2, _ = serving_chat._make_request_with_harmony(req_2)
verify_harmony_messages(
input_messages_2,
[
{"role": "system"},
{"role": "developer"},
{"role": "user"},
{
"role": "assistant",
"channel": "commentary",
"content": commentary_str,
},
{
"role": "assistant",
"channel": "commentary",
"recipient": "functions.get_weather",
"content": tool_args_str,
},
{
"role": "tool",
"author_name": "functions.get_weather",
"channel": "commentary",
"recipient": "assistant",
"content": "20 degrees Celsius",
},
],
)
@pytest.mark.asyncio
async def test_tools_and_reasoning(
self, serving_chat, stream, weather_tools, weather_messages_start
):
tools = weather_tools
messages = list(weather_messages_start)
# Test the Harmony messages for the first turn's input
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools)
input_messages, _ = serving_chat._make_request_with_harmony(req)
verify_harmony_messages(
input_messages,
[
{"role": "system"},
{"role": "developer", "tool_definitions": ["get_weather"]},
{"role": "user", "content": messages[0]["content"]},
],
)
# Test the Chat Completion response for the first turn's output
reasoning_str = "I'll call get_weather."
tool_args_str = '{"location": "Paris"}'
response_str = (
f"<|channel|>analysis<|message|>{reasoning_str}<|end|>"
"<|start|>assistant to=functions.get_weather<|channel|>commentary"
f"<|constrain|>json<|message|>{tool_args_str}<|call|>"
)
response = await self.generate_response_from_harmony_str(
serving_chat, req, response_str, stream=stream
)
verify_chat_response(
response,
reasoning=reasoning_str,
tool_calls=[("get_weather", tool_args_str)],
)
tool_call = response.choices[0].message.tool_calls[0]
# Add the output messages from the first turn as input to the second turn
for choice in response.choices:
messages.append(choice.message.model_dump(exclude_none=True))
# Add our tool output message
messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"content": "20 degrees Celsius",
},
)
# Test the Harmony messages for the second turn's input
req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages_2, _ = serving_chat._make_request_with_harmony(req_2)
verify_harmony_messages(
input_messages_2,
[
{"role": "system"},
{"role": "developer"},
{"role": "user"},
{
"role": "assistant",
"channel": "analysis",
"content": reasoning_str,
},
{
"role": "assistant",
"channel": "commentary",
"recipient": "functions.get_weather",
"content": tool_args_str,
},
{
"role": "tool",
"author_name": "functions.get_weather",
"channel": "commentary",
"recipient": "assistant",
"content": "20 degrees Celsius",
},
],
)
@pytest.mark.asyncio
async def test_multi_turn_tools_and_reasoning(
self, serving_chat, stream, weather_tools, weather_messages_start
):
tools = weather_tools
messages = list(weather_messages_start)
# Test the Harmony messages for the first turn's input
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools)
input_messages, _ = serving_chat._make_request_with_harmony(req)
verify_harmony_messages(
input_messages,
[
{"role": "system"},
{"role": "developer", "tool_definitions": ["get_weather"]},
{"role": "user", "content": messages[0]["content"]},
],
)
# Test the Chat Completion response for the first turn's output
reasoning_str = "I'll call get_weather."
paris_tool_args_str = '{"location": "Paris"}'
response_str = (
f"<|channel|>analysis<|message|>{reasoning_str}<|end|>"
"<|start|>assistant to=functions.get_weather<|channel|>commentary"
f"<|constrain|>json<|message|>{paris_tool_args_str}<|call|>"
)
response = await self.generate_response_from_harmony_str(
serving_chat, req, response_str, stream=stream
)
verify_chat_response(
response,
reasoning=reasoning_str,
tool_calls=[("get_weather", paris_tool_args_str)],
)
tool_call = response.choices[0].message.tool_calls[0]
# Add the output messages from the first turn as input to the second turn
for choice in response.choices:
messages.append(choice.message.model_dump(exclude_none=True))
# Add our tool output message
messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"content": "20 degrees Celsius",
},
)
# Test the Harmony messages for the second turn's input
req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages_2, _ = serving_chat._make_request_with_harmony(req_2)
verify_harmony_messages(
input_messages_2,
[
{"role": "system"},
{"role": "developer"},
{"role": "user"},
{
"role": "assistant",
"channel": "analysis",
"content": reasoning_str,
},
{
"role": "assistant",
"channel": "commentary",
"recipient": "functions.get_weather",
"content": paris_tool_args_str,
},
{
"role": "tool",
"author_name": "functions.get_weather",
"channel": "commentary",
"recipient": "assistant",
"content": "20 degrees Celsius",
},
],
)
# Test the Chat Completion response for the second turn's output
paris_weather_str = "The weather in Paris today is 20 degrees Celsius."
response_str = f"<|channel|>final<|message|>{paris_weather_str}<|end|>"
response_2 = await self.generate_response_from_harmony_str(
serving_chat, req_2, response_str, stream=stream
)
verify_chat_response(response_2, content=paris_weather_str)
# Add the output messages from the second turn as input to the third turn
for choice in response_2.choices:
messages.append(choice.message.model_dump(exclude_none=True))
# Add a new user message for the third turn
messages.append(
{
"role": "user",
"content": "What's the weather like in Boston today?",
},
)
# Test the Harmony messages for the third turn's input
req_3 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages_3, _ = serving_chat._make_request_with_harmony(req_3)
verify_harmony_messages(
input_messages_3,
[
{"role": "system"},
{"role": "developer"},
{"role": "user"},
{
"role": "assistant",
"channel": "commentary",
"recipient": "functions.get_weather",
"content": paris_tool_args_str,
},
{
"role": "tool",
"author_name": "functions.get_weather",
"channel": "commentary",
"recipient": "assistant",
"content": "20 degrees Celsius",
},
{
"role": "assistant",
"channel": "final",
"content": paris_weather_str,
},
{"role": "user", "content": messages[-1]["content"]},
],
)
# Test the Chat Completion response for the third turn's output
reasoning_str = "I'll call get_weather."
boston_tool_args_str = '{"location": "Boston"}'
response_str = (
f"<|channel|>analysis<|message|>{reasoning_str}<|end|>"
"<|start|>assistant to=functions.get_weather<|channel|>commentary"
f"<|constrain|>json<|message|>{boston_tool_args_str}<|call|>"
)
response_3 = await self.generate_response_from_harmony_str(
serving_chat, req, response_str, stream=stream
)
verify_chat_response(
response_3,
reasoning=reasoning_str,
tool_calls=[("get_weather", boston_tool_args_str)],
)
tool_call = response_3.choices[0].message.tool_calls[0]
# Add the output messages from the third turn as input to the fourth turn
for choice in response_3.choices:
messages.append(choice.message.model_dump(exclude_none=True))
# Add our tool output message
messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"content": "10 degrees Celsius",
},
)
# Test the Harmony messages for the fourth turn's input
req_4 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages_4, _ = serving_chat._make_request_with_harmony(req_4)
verify_harmony_messages(
input_messages_4,
[
{"role": "system"},
{"role": "developer"},
{"role": "user"},
{"role": "assistant"},
{"role": "tool"},
{
"role": "assistant",
"channel": "final",
},
{"role": "user"},
{
"role": "assistant",
"channel": "analysis",
"content": reasoning_str,
},
{
"role": "assistant",
"channel": "commentary",
"recipient": "functions.get_weather",
"content": boston_tool_args_str,
},
{
"role": "tool",
"author_name": "functions.get_weather",
"channel": "commentary",
"recipient": "assistant",
"content": "10 degrees Celsius",
},
],
)
@pytest.mark.asyncio
async def test_non_tool_reasoning(self, serving_chat):
messages: list[dict[str, Any]] = [
{
"role": "user",
"content": "What's 2+2?",
},
{
"role": "assistant",
"reasoning": "Adding 2 and 2 is easy. The result is 4.",
"content": "4",
},
]
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages, _ = serving_chat._make_request_with_harmony(req)
verify_harmony_messages(
input_messages,
[
{"role": "system"},
{"role": "developer"},
{"role": "user", "content": messages[0]["content"]},
# The reasoning that would have resulted in an analysis message is
# dropped because of a later assistant message to the final channel.
{
"role": "assistant",
"channel": "final",
"content": messages[1]["content"],
},
],
)
@pytest.mark.asyncio
async def test_non_tool_reasoning_empty_content(self, serving_chat):
messages: list[dict[str, Any]] = [
{
"role": "user",
"content": "What's 2+2?",
},
{
"role": "assistant",
"reasoning": "Adding 2 and 2 is easy. The result is 4.",
"content": "",
},
]
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages, _ = serving_chat._make_request_with_harmony(req)
verify_harmony_messages(
input_messages,
[
{"role": "system"},
{"role": "developer"},
{"role": "user", "content": messages[0]["content"]},
{
"role": "assistant",
"channel": "analysis",
"content": messages[1]["reasoning"],
},
],
)
@pytest.mark.asyncio
async def test_non_tool_reasoning_empty_content_list(self, serving_chat):
messages: list[dict[str, Any]] = [
{
"role": "user",
"content": "What's 2+2?",
},
{
"role": "assistant",
"reasoning": "Adding 2 and 2 is easy. The result is 4.",
"content": [],
},
]
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages, _ = serving_chat._make_request_with_harmony(req)
verify_harmony_messages(
input_messages,
[
{"role": "system"},
{"role": "developer"},
{"role": "user", "content": messages[0]["content"]},
{
"role": "assistant",
"channel": "analysis",
"content": messages[1]["reasoning"],
},
],
)
......@@ -10,7 +10,7 @@ import pytest
from vllm.config import ModelConfig
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.tokenizers import MistralTokenizer
from vllm.tokenizers.mistral import MistralTokenizer
@pytest.fixture()
......
......@@ -21,7 +21,7 @@ from vllm.entrypoints.openai.serving_responses import (
extract_tool_types,
)
from vllm.entrypoints.tool_server import ToolServer
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.data import TokensPrompt
class MockConversationContext(ConversationContext):
......@@ -237,7 +237,7 @@ class TestValidateGeneratorInput:
"""Test _validate_generator_input with valid prompt length"""
# Create an engine prompt with valid length (less than max_model_len)
valid_prompt_token_ids = list(range(5)) # 5 tokens < 100 max_model_len
engine_prompt = EngineTokensPrompt(prompt_token_ids=valid_prompt_token_ids)
engine_prompt = TokensPrompt(prompt_token_ids=valid_prompt_token_ids)
# Call the method
result = serving_responses_instance._validate_generator_input(engine_prompt)
......@@ -247,7 +247,7 @@ class TestValidateGeneratorInput:
# create an invalid engine prompt
invalid_prompt_token_ids = list(range(200)) # 100 tokens >= 100 max_model_len
engine_prompt = EngineTokensPrompt(prompt_token_ids=invalid_prompt_token_ids)
engine_prompt = TokensPrompt(prompt_token_ids=invalid_prompt_token_ids)
# Call the method
result = serving_responses_instance._validate_generator_input(engine_prompt)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Sparse tensor validation in embedding APIs.
Tests verify that malicious sparse tensors are rejected before they can trigger
out-of-bounds memory writes during to_dense() operations.
"""
import base64
import io
import pytest
import torch
from vllm.entrypoints.renderer import CompletionRenderer
from vllm.multimodal.audio import AudioEmbeddingMediaIO
from vllm.multimodal.image import ImageEmbeddingMediaIO
def _encode_tensor(tensor: torch.Tensor) -> bytes:
"""Helper to encode a tensor as base64 bytes."""
buffer = io.BytesIO()
torch.save(tensor, buffer)
buffer.seek(0)
return base64.b64encode(buffer.read())
def _create_malicious_sparse_tensor() -> torch.Tensor:
"""
Create a malicious sparse COO tensor with out-of-bounds indices.
This tensor has indices that point beyond the declared shape, which would
cause an out-of-bounds write when converted to dense format without
validation.
"""
# Create a 3x3 sparse tensor but with indices pointing to (10, 10)
indices = torch.tensor([[10], [10]]) # Out of bounds for 3x3 shape
values = torch.tensor([1.0])
shape = (3, 3)
# Create sparse tensor (this will be invalid)
sparse_tensor = torch.sparse_coo_tensor(indices, values, shape, dtype=torch.float32)
return sparse_tensor
def _create_valid_sparse_tensor() -> torch.Tensor:
"""Create a valid sparse COO tensor for baseline testing."""
indices = torch.tensor([[0, 1, 2], [0, 1, 2]])
values = torch.tensor([1.0, 2.0, 3.0])
shape = (3, 3)
sparse_tensor = torch.sparse_coo_tensor(indices, values, shape, dtype=torch.float32)
return sparse_tensor
def _create_valid_dense_tensor() -> torch.Tensor:
"""Create a valid dense tensor for baseline testing."""
return torch.randn(10, 768, dtype=torch.float32) # (seq_len, hidden_size)
class TestPromptEmbedsValidation:
"""Test sparse tensor validation in prompt embeddings (Completions API)."""
def test_valid_dense_tensor_accepted(self, model_config):
"""Baseline: Valid dense tensors should work normally."""
renderer = CompletionRenderer(model_config)
valid_tensor = _create_valid_dense_tensor()
encoded = _encode_tensor(valid_tensor)
# Should not raise any exception
result = renderer.load_prompt_embeds(encoded)
assert len(result) == 1
assert result[0]["prompt_embeds"].shape == valid_tensor.shape
def test_valid_sparse_tensor_accepted(self):
"""Baseline: Valid sparse tensors should load successfully."""
io_handler = ImageEmbeddingMediaIO()
valid_sparse = _create_valid_sparse_tensor()
encoded = _encode_tensor(valid_sparse)
# Should not raise any exception (sparse tensors remain sparse)
result = io_handler.load_base64("", encoded.decode("utf-8"))
assert result.shape == valid_sparse.shape
def test_malicious_sparse_tensor_rejected(self, model_config):
"""Security: Malicious sparse tensors should be rejected."""
renderer = CompletionRenderer(model_config)
malicious_tensor = _create_malicious_sparse_tensor()
encoded = _encode_tensor(malicious_tensor)
# Should raise RuntimeError due to invalid sparse tensor
with pytest.raises((RuntimeError, ValueError)) as exc_info:
renderer.load_prompt_embeds(encoded)
# Error should indicate sparse tensor validation failure
error_msg = str(exc_info.value).lower()
assert "sparse" in error_msg or "index" in error_msg or "bounds" in error_msg
def test_extremely_large_indices_rejected(self, model_config):
"""Security: Sparse tensors with extremely large indices should be rejected."""
renderer = CompletionRenderer(model_config)
# Create tensor with indices far beyond reasonable bounds
indices = torch.tensor([[999999], [999999]])
values = torch.tensor([1.0])
shape = (10, 10)
malicious_tensor = torch.sparse_coo_tensor(
indices, values, shape, dtype=torch.float32
)
encoded = _encode_tensor(malicious_tensor)
with pytest.raises((RuntimeError, ValueError)):
renderer.load_prompt_embeds(encoded)
def test_negative_indices_rejected(self, model_config):
"""Security: Sparse tensors with negative indices should be rejected."""
renderer = CompletionRenderer(model_config)
# Create tensor with negative indices
indices = torch.tensor([[-1], [-1]])
values = torch.tensor([1.0])
shape = (10, 10)
malicious_tensor = torch.sparse_coo_tensor(
indices, values, shape, dtype=torch.float32
)
encoded = _encode_tensor(malicious_tensor)
with pytest.raises((RuntimeError, ValueError)):
renderer.load_prompt_embeds(encoded)
class TestImageEmbedsValidation:
"""Test sparse tensor validation in image embeddings (Chat API)."""
def test_valid_dense_tensor_accepted(self):
"""Baseline: Valid dense tensors should work normally."""
io_handler = ImageEmbeddingMediaIO()
valid_tensor = _create_valid_dense_tensor()
encoded = _encode_tensor(valid_tensor)
# Should not raise any exception
result = io_handler.load_base64("", encoded.decode("utf-8"))
assert result.shape == valid_tensor.shape
def test_valid_sparse_tensor_accepted(self):
"""Baseline: Valid sparse tensors should load successfully."""
io_handler = AudioEmbeddingMediaIO()
valid_sparse = _create_valid_sparse_tensor()
encoded = _encode_tensor(valid_sparse)
# Should not raise any exception (sparse tensors remain sparse)
result = io_handler.load_base64("", encoded.decode("utf-8"))
assert result.shape == valid_sparse.shape
def test_malicious_sparse_tensor_rejected(self):
"""Security: Malicious sparse tensors should be rejected."""
io_handler = ImageEmbeddingMediaIO()
malicious_tensor = _create_malicious_sparse_tensor()
encoded = _encode_tensor(malicious_tensor)
# Should raise RuntimeError due to invalid sparse tensor
with pytest.raises((RuntimeError, ValueError)) as exc_info:
io_handler.load_base64("", encoded.decode("utf-8"))
error_msg = str(exc_info.value).lower()
assert "sparse" in error_msg or "index" in error_msg or "bounds" in error_msg
def test_load_bytes_validates(self):
"""Security: Validation should also work for load_bytes method."""
io_handler = ImageEmbeddingMediaIO()
malicious_tensor = _create_malicious_sparse_tensor()
buffer = io.BytesIO()
torch.save(malicious_tensor, buffer)
buffer.seek(0)
with pytest.raises((RuntimeError, ValueError)):
io_handler.load_bytes(buffer.read())
class TestAudioEmbedsValidation:
"""Test sparse tensor validation in audio embeddings (Chat API)."""
def test_valid_dense_tensor_accepted(self):
"""Baseline: Valid dense tensors should work normally."""
io_handler = AudioEmbeddingMediaIO()
valid_tensor = _create_valid_dense_tensor()
encoded = _encode_tensor(valid_tensor)
# Should not raise any exception
result = io_handler.load_base64("", encoded.decode("utf-8"))
assert result.shape == valid_tensor.shape
def test_valid_sparse_tensor_accepted(self):
"""Baseline: Valid sparse tensors should be converted successfully."""
io_handler = AudioEmbeddingMediaIO()
valid_sparse = _create_valid_sparse_tensor()
encoded = _encode_tensor(valid_sparse)
# Should not raise any exception
result = io_handler.load_base64("", encoded.decode("utf-8"))
assert result.is_sparse is False
def test_malicious_sparse_tensor_rejected(self):
"""Security: Malicious sparse tensors should be rejected."""
io_handler = AudioEmbeddingMediaIO()
malicious_tensor = _create_malicious_sparse_tensor()
encoded = _encode_tensor(malicious_tensor)
# Should raise RuntimeError due to invalid sparse tensor
with pytest.raises((RuntimeError, ValueError)) as exc_info:
io_handler.load_base64("", encoded.decode("utf-8"))
error_msg = str(exc_info.value).lower()
assert "sparse" in error_msg or "index" in error_msg or "bounds" in error_msg
def test_load_bytes_validates(self):
"""Security: Validation should also work for load_bytes method."""
io_handler = AudioEmbeddingMediaIO()
malicious_tensor = _create_malicious_sparse_tensor()
buffer = io.BytesIO()
torch.save(malicious_tensor, buffer)
buffer.seek(0)
with pytest.raises((RuntimeError, ValueError)):
io_handler.load_bytes(buffer.read())
class TestSparseTensorValidationIntegration:
"""
These tests verify the complete attack chain is blocked at all entry points.
"""
def test_attack_scenario_completions_api(self, model_config):
"""
Simulate a complete attack through the Completions API.
Attack scenario:
1. Attacker crafts malicious sparse tensor
2. Encodes it as base64
3. Sends to /v1/completions with prompt_embeds parameter
4. Server should reject before memory corruption occurs
"""
renderer = CompletionRenderer(model_config)
# Step 1-2: Attacker creates malicious payload
attack_payload = _encode_tensor(_create_malicious_sparse_tensor())
# Step 3-4: Server processes and should reject
with pytest.raises((RuntimeError, ValueError)):
renderer.load_prompt_embeds(attack_payload)
def test_attack_scenario_chat_api_image(self):
"""
Simulate attack through Chat API with image_embeds.
Verifies the image embeddings path is protected.
"""
io_handler = ImageEmbeddingMediaIO()
attack_payload = _encode_tensor(_create_malicious_sparse_tensor())
with pytest.raises((RuntimeError, ValueError)):
io_handler.load_base64("", attack_payload.decode("utf-8"))
def test_attack_scenario_chat_api_audio(self):
"""
Simulate attack through Chat API with audio_embeds.
Verifies the audio embeddings path is protected.
"""
io_handler = AudioEmbeddingMediaIO()
attack_payload = _encode_tensor(_create_malicious_sparse_tensor())
with pytest.raises((RuntimeError, ValueError)):
io_handler.load_base64("", attack_payload.decode("utf-8"))
def test_multiple_valid_embeddings_in_batch(self, model_config):
"""
Regression test: Multiple valid embeddings should still work.
Ensures the fix doesn't break legitimate batch processing.
"""
renderer = CompletionRenderer(model_config)
valid_tensors = [
_encode_tensor(_create_valid_dense_tensor()),
_encode_tensor(_create_valid_dense_tensor()),
_encode_tensor(_create_valid_dense_tensor()),
]
# Should process all without error
result = renderer.load_prompt_embeds(valid_tensors)
assert len(result) == 3
def test_mixed_valid_and_malicious_rejected(self, model_config):
"""
Security: Batch with one malicious tensor should be rejected.
Even if most tensors are valid, a single malicious one should
cause rejection of the entire batch.
"""
renderer = CompletionRenderer(model_config)
mixed_batch = [
_encode_tensor(_create_valid_dense_tensor()),
_encode_tensor(_create_malicious_sparse_tensor()), # Malicious
_encode_tensor(_create_valid_dense_tensor()),
]
# Should fail on the malicious tensor
with pytest.raises((RuntimeError, ValueError)):
renderer.load_prompt_embeds(mixed_batch)
# Pytest fixtures
@pytest.fixture
def model_config():
"""Mock ModelConfig for testing."""
from vllm.config import ModelConfig
return ModelConfig(
model="facebook/opt-125m",
tokenizer="facebook/opt-125m",
tokenizer_mode="auto",
trust_remote_code=False,
dtype="float32",
seed=0,
enable_prompt_embeds=True, # Required for prompt embeds tests
)
......@@ -10,8 +10,8 @@ from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction_streaming,
)
from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser, ToolParserManager
SIMPLE_ARGS_DICT = {
"action": "create",
......
......@@ -6,8 +6,8 @@ import json
import pytest
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
from ....utils import RemoteOpenAIServer
......
......@@ -12,7 +12,7 @@ from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction_streaming,
)
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.tool_parsers import ToolParser, ToolParserManager
def make_tool_call(name, arguments):
......
......@@ -6,8 +6,8 @@ from unittest.mock import MagicMock, patch
import pytest
from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation
from vllm.entrypoints.openai.tool_parsers.llama_tool_parser import Llama3JsonToolParser
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.llama_tool_parser import Llama3JsonToolParser
@pytest.fixture
......
......@@ -10,8 +10,8 @@ from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction_streaming,
)
from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser, ToolParserManager
# Test cases similar to pythonic parser but with Llama4 specific format
SIMPLE_FUNCTION_OUTPUT = "[get_weather(city='LA', metric='C')]"
......
......@@ -10,8 +10,8 @@ from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction_streaming,
)
from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser, ToolParserManager
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
......
......@@ -10,8 +10,8 @@ from tests.entrypoints.openai.tool_parsers.utils import (
run_tool_extraction_streaming,
)
from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser, ToolParserManager
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
......
......@@ -10,8 +10,8 @@ from vllm.entrypoints.openai.protocol import (
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser
class StreamingToolReconstructor:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import AsyncGenerator
from typing import Any
from vllm.entrypoints.openai.protocol import (
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionStreamResponse,
ChatMessage,
UsageInfo,
)
async def accumulate_streaming_response(
stream_generator: AsyncGenerator[str, None],
) -> ChatCompletionResponse:
"""
Accumulate streaming SSE chunks into a complete ChatCompletionResponse.
This helper parses the SSE format and builds up the complete response
by combining all the delta chunks.
"""
accumulated_content = ""
accumulated_reasoning = None
accumulated_tool_calls: list[dict[str, Any]] = []
role = None
finish_reason = None
response_id = None
created = None
model = None
index = 0
async for chunk_str in stream_generator:
# Skip empty lines and [DONE] marker
if not chunk_str.strip() or chunk_str.strip() == "data: [DONE]":
continue
# Parse SSE format: "data: {json}\n\n"
if chunk_str.startswith("data: "):
json_str = chunk_str[6:].strip()
try:
chunk_data = json.loads(json_str)
# print(f"DEBUG: Parsed chunk_data: {chunk_data}")
chunk = ChatCompletionStreamResponse(**chunk_data)
# Store metadata from first chunk
if response_id is None:
response_id = chunk.id
created = chunk.created
model = chunk.model
# Process each choice in the chunk
for choice in chunk.choices:
if choice.delta.role:
role = choice.delta.role
if choice.delta.content:
accumulated_content += choice.delta.content
if choice.delta.reasoning:
if accumulated_reasoning is None:
accumulated_reasoning = ""
accumulated_reasoning += choice.delta.reasoning
if choice.delta.tool_calls:
# Accumulate tool calls
for tool_call_delta in choice.delta.tool_calls:
# Find or create the tool call at this index
while len(accumulated_tool_calls) <= tool_call_delta.index:
accumulated_tool_calls.append(
{
"id": None,
"type": "function",
"function": {"name": "", "arguments": ""},
}
)
if tool_call_delta.id:
accumulated_tool_calls[tool_call_delta.index]["id"] = (
tool_call_delta.id
)
if tool_call_delta.function:
if tool_call_delta.function.name:
accumulated_tool_calls[tool_call_delta.index][
"function"
]["name"] += tool_call_delta.function.name
if tool_call_delta.function.arguments:
accumulated_tool_calls[tool_call_delta.index][
"function"
]["arguments"] += tool_call_delta.function.arguments
if choice.finish_reason:
finish_reason = choice.finish_reason
if choice.index is not None:
index = choice.index
except json.JSONDecodeError:
continue
# Build the final message
message_kwargs = {
"role": role or "assistant",
"content": accumulated_content if accumulated_content else None,
"reasoning": accumulated_reasoning,
}
# Only include tool_calls if there are any
if accumulated_tool_calls:
message_kwargs["tool_calls"] = [
{"id": tc["id"], "type": tc["type"], "function": tc["function"]}
for tc in accumulated_tool_calls
]
message = ChatMessage(**message_kwargs)
# Build the final response
choice = ChatCompletionResponseChoice(
index=index,
message=message,
finish_reason=finish_reason or "stop",
)
# Create usage info (with dummy values for tests)
usage = UsageInfo(
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
)
response = ChatCompletionResponse(
id=response_id or "chatcmpl-test",
object="chat.completion",
created=created or 0,
model=model or "test-model",
choices=[choice],
usage=usage,
)
return response
def verify_harmony_messages(
messages: list[Any], expected_messages: list[dict[str, Any]]
):
assert len(messages) == len(expected_messages)
for msg, expected in zip(messages, expected_messages):
if "role" in expected:
assert msg.author.role == expected["role"]
if "author_name" in expected:
assert msg.author.name == expected["author_name"]
if "channel" in expected:
assert msg.channel == expected["channel"]
if "recipient" in expected:
assert msg.recipient == expected["recipient"]
if "content" in expected:
assert msg.content[0].text == expected["content"]
if "content_type" in expected:
assert msg.content_type == expected["content_type"]
if "tool_definitions" in expected:
# Check that the tool definitions match the expected list of tool names
actual_tools = [t.name for t in msg.content[0].tools["functions"].tools]
assert actual_tools == expected["tool_definitions"]
def verify_chat_response(
response: ChatCompletionResponse,
content: str | None = None,
reasoning: str | None = None,
tool_calls: list[tuple[str, str]] | None = None,
):
assert len(response.choices) == 1
message = response.choices[0].message
if content is not None:
assert message.content == content
else:
assert not message.content
if reasoning is not None:
assert message.reasoning == reasoning
else:
assert not message.reasoning
if tool_calls:
assert message.tool_calls is not None
assert len(message.tool_calls) == len(tool_calls)
for tc, (expected_name, expected_args) in zip(message.tool_calls, tool_calls):
assert tc.function.name == expected_name
assert tc.function.arguments == expected_args
else:
assert not message.tool_calls
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment