Commit 99324e25 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.9.2' into v0.9.2-ori

parents cc7f22a8 a5dd03c1
......@@ -5,6 +5,7 @@
import asyncio
import hashlib
import json
import logging
import pickle
import socket
from collections.abc import AsyncIterator
......@@ -19,10 +20,11 @@ from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
MemorySnapshot, PlaceholderModule, StoreBoolean,
bind_kv_cache, common_broadcastable_dtype,
deprecate_kwargs, get_open_port, is_lossless_cast,
make_zmq_path, make_zmq_socket, memory_profiling,
merge_async_iterators, sha256, split_zmq_path,
supports_kw, swap_dict_values)
deprecate_kwargs, get_open_port, get_tcp_uri,
is_lossless_cast, join_host_port, make_zmq_path,
make_zmq_socket, memory_profiling,
merge_async_iterators, sha256, split_host_port,
split_zmq_path, supports_kw, swap_dict_values)
from .utils import create_new_process_for_each_test, error_on_warning
......@@ -142,6 +144,7 @@ def parser():
parser.add_argument('--batch-size', type=int)
parser.add_argument('--enable-feature', action='store_true')
parser.add_argument('--hf-overrides', type=json.loads)
parser.add_argument('-O', '--compilation-config', type=json.loads)
return parser
......@@ -265,6 +268,11 @@ def test_dict_args(parser):
"val2",
"--hf-overrides.key2.key4",
"val3",
# Test compile config and compilation level
"-O.use_inductor=true",
"-O.backend",
"custom",
"-O1",
# Test = sign
"--hf-overrides.key5=val4",
# Test underscore to dash conversion
......@@ -272,6 +280,22 @@ def test_dict_args(parser):
"val5",
"--hf_overrides.key-7.key_8",
"val6",
# Test data type detection
"--hf_overrides.key9",
"100",
"--hf_overrides.key10",
"100.0",
"--hf_overrides.key11",
"true",
"--hf_overrides.key12.key13",
"null",
# Test '-' and '.' in value
"--hf_overrides.key14.key15",
"-minus.and.dot",
# Test array values
"-O.custom_ops+",
"-quant_fp8",
"-O.custom_ops+=+silu_mul,-rms_norm",
]
parsed_args = parser.parse_args(args)
assert parsed_args.model_name == "something.something"
......@@ -286,7 +310,46 @@ def test_dict_args(parser):
"key-7": {
"key_8": "val6",
},
"key9": 100,
"key10": 100.0,
"key11": True,
"key12": {
"key13": None,
},
"key14": {
"key15": "-minus.and.dot",
}
}
assert parsed_args.compilation_config == {
"level": 1,
"use_inductor": True,
"backend": "custom",
"custom_ops": ["-quant_fp8", "+silu_mul", "-rms_norm"],
}
def test_duplicate_dict_args(caplog_vllm, parser):
args = [
"--model-name=something.something",
"--hf-overrides.key1",
"val1",
"--hf-overrides.key1",
"val2",
"-O1",
"-O.level",
"2",
"-O3",
]
parsed_args = parser.parse_args(args)
# Should be the last value
assert parsed_args.hf_overrides == {"key1": "val2"}
assert parsed_args.compilation_config == {"level": 3}
assert len(caplog_vllm.records) == 1
assert "duplicate" in caplog_vllm.text
assert "--hf-overrides.key1" in caplog_vllm.text
assert "-O.level" in caplog_vllm.text
# yapf: enable
......@@ -814,3 +877,44 @@ def test_make_zmq_socket_ipv6():
def test_make_zmq_path():
assert make_zmq_path("tcp", "127.0.0.1", "5555") == "tcp://127.0.0.1:5555"
assert make_zmq_path("tcp", "::1", "5555") == "tcp://[::1]:5555"
def test_get_tcp_uri():
assert get_tcp_uri("127.0.0.1", 5555) == "tcp://127.0.0.1:5555"
assert get_tcp_uri("::1", 5555) == "tcp://[::1]:5555"
def test_split_host_port():
# valid ipv4
assert split_host_port("127.0.0.1:5555") == ("127.0.0.1", 5555)
# invalid ipv4
with pytest.raises(ValueError):
# multi colon
assert split_host_port("127.0.0.1::5555")
with pytest.raises(ValueError):
# tailing colon
assert split_host_port("127.0.0.1:5555:")
with pytest.raises(ValueError):
# no colon
assert split_host_port("127.0.0.15555")
with pytest.raises(ValueError):
# none int port
assert split_host_port("127.0.0.1:5555a")
# valid ipv6
assert split_host_port("[::1]:5555") == ("::1", 5555)
# invalid ipv6
with pytest.raises(ValueError):
# multi colon
assert split_host_port("[::1]::5555")
with pytest.raises(IndexError):
# no colon
assert split_host_port("[::1]5555")
with pytest.raises(ValueError):
# none int port
assert split_host_port("[::1]:5555a")
def test_join_host_port():
assert join_host_port("127.0.0.1", 5555) == "127.0.0.1:5555"
assert join_host_port("::1", 5555) == "[::1]:5555"
......@@ -68,6 +68,7 @@ def _run_incremental_decode(tokenizer,
None,
params,
None,
None,
0.0,
None,
cache_salt=None,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import json
import pytest
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
from vllm.entrypoints.openai.tool_parsers import MinimaxToolParser
from vllm.transformers_utils.tokenizer import get_tokenizer
# Use a common model that is likely to be available
MODEL = "MiniMaxAi/MiniMax-M1-40k"
@pytest.fixture(scope="module")
def minimax_tokenizer():
return get_tokenizer(tokenizer_name=MODEL)
@pytest.fixture
def minimax_tool_parser(minimax_tokenizer):
return MinimaxToolParser(minimax_tokenizer)
def assert_tool_calls(actual_tool_calls: list[ToolCall],
expected_tool_calls: list[ToolCall]):
assert len(actual_tool_calls) == len(expected_tool_calls)
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
expected_tool_calls):
assert isinstance(actual_tool_call.id, str)
assert len(actual_tool_call.id) > 16
assert actual_tool_call.type == "function"
assert actual_tool_call.function == expected_tool_call.function
def test_extract_tool_calls_no_tools(minimax_tool_parser):
model_output = "This is a test"
extracted_tool_calls = minimax_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
assert not extracted_tool_calls.tools_called
assert extracted_tool_calls.tool_calls == []
assert extracted_tool_calls.content == model_output
@pytest.mark.parametrize(
ids=[
"single_tool_call",
"multiple_tool_calls",
"tool_call_with_content_before",
"tool_call_with_single_line_json",
"tool_call_incomplete_tag",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
(
"""<tool_calls>
{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}
</tool_calls>""",
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
))
],
None,
),
(
"""<tool_calls>
{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}
{"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}
</tool_calls>""",
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
)),
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Orlando",
"state": "FL",
"unit": "fahrenheit",
}),
)),
],
None,
),
(
"""I'll help you check the weather. <tool_calls>
{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}
</tool_calls>""",
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Seattle",
"state": "WA",
"unit": "celsius",
}),
))
],
"I'll help you check the weather.",
),
(
"""<tool_calls>
{"name": "get_current_weather", "arguments": {"city": "New York", "state": "NY", "unit": "celsius"}}
</tool_calls>""",
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "New York",
"state": "NY",
"unit": "celsius",
}),
))
],
None,
),
(
"""<tool_calls>
{"name": "get_current_weather", "arguments": {"city": "Boston", "state": "MA"}}""",
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Boston",
"state": "MA",
}),
))
],
None,
),
],
)
def test_extract_tool_calls(minimax_tool_parser, model_output,
expected_tool_calls, expected_content):
extracted_tool_calls = minimax_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
assert extracted_tool_calls.content == expected_content
def test_preprocess_model_output_with_thinking_tags(minimax_tool_parser):
"""Test that tool calls within thinking tags are removed during preprocessing."""
model_output = """<think>Let me think about this. <tool_calls>
{"name": "fake_tool", "arguments": {"param": "value"}}
</tool_calls> This should be removed.</think>
I'll help you with that. <tool_calls>
{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA"}}
</tool_calls>"""
processed_output = minimax_tool_parser.preprocess_model_output(
model_output)
# The tool call within thinking tags should be removed
assert "fake_tool" not in processed_output
# But the thinking tag itself should remain
assert "<think>" in processed_output
assert "</think>" in processed_output
# The actual tool call outside thinking tags should remain
assert "get_current_weather" in processed_output
def test_extract_tool_calls_with_thinking_tags(minimax_tool_parser):
"""Test tool extraction when thinking tags contain tool calls that should be ignored."""
model_output = """<think>I should use a tool. <tool_calls>
{"name": "ignored_tool", "arguments": {"should": "ignore"}}
</tool_calls></think>
Let me help you with the weather. <tool_calls>
{"name": "get_current_weather", "arguments": {"city": "Miami", "state": "FL", "unit": "fahrenheit"}}
</tool_calls>"""
extracted_tool_calls = minimax_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called
assert len(extracted_tool_calls.tool_calls) == 1
assert extracted_tool_calls.tool_calls[
0].function.name == "get_current_weather"
# Content extraction is based on the position of the first <tool_calls> in the original model_output
# Since preprocessing removes tool calls within thinking tags, the actual first <tool_calls> is the external one
expected_content = """<think>I should use a tool. <tool_calls>
{"name": "ignored_tool", "arguments": {"should": "ignore"}}
</tool_calls></think>
Let me help you with the weather."""
assert extracted_tool_calls.content == expected_content
def test_extract_tool_calls_invalid_json(minimax_tool_parser):
"""Test that invalid JSON in tool calls is handled gracefully."""
model_output = """<tool_calls>
{"name": "valid_tool", "arguments": {"city": "Seattle"}}
{invalid json here}
{"name": "another_valid_tool", "arguments": {"param": "value"}}
</tool_calls>"""
extracted_tool_calls = minimax_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called
# Should extract only the valid JSON tool calls
assert len(extracted_tool_calls.tool_calls) == 2
assert extracted_tool_calls.tool_calls[0].function.name == "valid_tool"
assert extracted_tool_calls.tool_calls[
1].function.name == "another_valid_tool"
def test_extract_tool_calls_missing_name_or_arguments(minimax_tool_parser):
"""Test that tool calls missing name or arguments are filtered out."""
model_output = """<tool_calls>
{"name": "valid_tool", "arguments": {"city": "Seattle"}}
{"name": "missing_args"}
{"arguments": {"city": "Portland"}}
{"name": "another_valid_tool", "arguments": {"param": "value"}}
</tool_calls>"""
extracted_tool_calls = minimax_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called
# Should extract only the valid tool calls with both name and arguments
assert len(extracted_tool_calls.tool_calls) == 2
assert extracted_tool_calls.tool_calls[0].function.name == "valid_tool"
assert extracted_tool_calls.tool_calls[
1].function.name == "another_valid_tool"
def test_streaming_basic_functionality(minimax_tool_parser):
"""Test basic streaming functionality."""
# Reset streaming state
minimax_tool_parser.current_tool_name_sent = False
minimax_tool_parser.prev_tool_call_arr = []
minimax_tool_parser.current_tool_id = -1
minimax_tool_parser.streamed_args_for_tool = []
# Test with a simple tool call
current_text = """<tool_calls>
{"name": "get_current_weather", "arguments": {"city": "Seattle"}}
</tool_calls>"""
# First call should handle the initial setup
result = minimax_tool_parser.extract_tool_calls_streaming(
previous_text="",
current_text=current_text,
delta_text="</tool_calls>",
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
request=None,
)
# The result might be None or contain tool call information
# This depends on the internal state management
if result is not None and hasattr(result,
'tool_calls') and result.tool_calls:
assert len(result.tool_calls) >= 0
def test_streaming_with_content_before_tool_calls(minimax_tool_parser):
"""Test streaming when there's content before tool calls."""
# Reset streaming state
minimax_tool_parser.current_tool_name_sent = False
minimax_tool_parser.prev_tool_call_arr = []
minimax_tool_parser.current_tool_id = -1
minimax_tool_parser.streamed_args_for_tool = []
current_text = "I'll help you with that. <tool_calls>"
# When there's content before tool calls, it should be returned as content
result = minimax_tool_parser.extract_tool_calls_streaming(
previous_text="I'll help you",
current_text=current_text,
delta_text=" with that. <tool_calls>",
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
request=None,
)
if result is not None and hasattr(result, 'content'):
# Should contain some content
assert result.content is not None
def test_streaming_no_tool_calls(minimax_tool_parser):
"""Test streaming when there are no tool calls."""
current_text = "This is just regular text without any tool calls."
result = minimax_tool_parser.extract_tool_calls_streaming(
previous_text="This is just regular text",
current_text=current_text,
delta_text=" without any tool calls.",
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
request=None,
)
# Should return the delta text as content
assert result is not None
assert hasattr(result, 'content')
assert result.content == " without any tool calls."
def test_streaming_with_thinking_tags(minimax_tool_parser):
"""Test streaming with thinking tags that contain tool calls."""
# Reset streaming state
minimax_tool_parser.current_tool_name_sent = False
minimax_tool_parser.prev_tool_call_arr = []
minimax_tool_parser.current_tool_id = -1
minimax_tool_parser.streamed_args_for_tool = []
current_text = """<think><tool_calls>{"name": "ignored", "arguments": {}}</tool_calls></think><tool_calls>{"name": "real_tool", "arguments": {"param": "value"}}</tool_calls>"""
result = minimax_tool_parser.extract_tool_calls_streaming(
previous_text="",
current_text=current_text,
delta_text=current_text,
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
request=None,
)
# The preprocessing should remove tool calls from thinking tags
# and only process the real tool call
if result is not None and hasattr(result,
'tool_calls') and result.tool_calls:
for tool_call in result.tool_calls:
assert tool_call.function.name != "ignored"
def test_extract_tool_calls_multiline_json_not_supported(minimax_tool_parser):
"""Test that multiline JSON in tool calls is not currently supported."""
model_output = """<tool_calls>
{
"name": "get_current_weather",
"arguments": {
"city": "New York",
"state": "NY",
"unit": "celsius"
}
}
</tool_calls>"""
extracted_tool_calls = minimax_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
# Multiline JSON is currently not supported, should return no tools called
assert not extracted_tool_calls.tools_called
assert extracted_tool_calls.tool_calls == []
assert extracted_tool_calls.content is None
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import pytest
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
from vllm.entrypoints.openai.tool_parsers import xLAMToolParser
from vllm.transformers_utils.tokenizer import get_tokenizer
# Use a common model that is likely to be available
MODEL = "Salesforce/Llama-xLAM-2-8B-fc-r"
@pytest.fixture(scope="module")
def xlam_tokenizer():
return get_tokenizer(tokenizer_name=MODEL)
@pytest.fixture
def xlam_tool_parser(xlam_tokenizer):
return xLAMToolParser(xlam_tokenizer)
def assert_tool_calls(actual_tool_calls: list[ToolCall],
expected_tool_calls: list[ToolCall]):
assert len(actual_tool_calls) == len(expected_tool_calls)
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
expected_tool_calls):
assert isinstance(actual_tool_call.id, str)
assert len(actual_tool_call.id) > 16
assert actual_tool_call.type == "function"
assert actual_tool_call.function == expected_tool_call.function
def test_extract_tool_calls_no_tools(xlam_tool_parser):
model_output = "This is a test"
extracted_tool_calls = xlam_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
assert not extracted_tool_calls.tools_called
assert extracted_tool_calls.tool_calls == []
assert extracted_tool_calls.content == model_output
@pytest.mark.parametrize(
ids=[
"parallel_tool_calls",
"single_tool_with_think_tag",
"single_tool_with_json_code_block",
"single_tool_with_tool_calls_tag",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
(
"""[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""", # noqa: E501
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
)),
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Orlando",
"state": "FL",
"unit": "fahrenheit",
}),
)),
],
None,
),
(
"""<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
))
],
"<think>I'll help you with that.</think>",
),
(
"""I'll help you with that.\n```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""", # noqa: E501
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
))
],
"I'll help you with that.",
),
(
"""I'll check the weather for you.[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
))
],
"I'll check the weather for you.",
),
],
)
def test_extract_tool_calls(xlam_tool_parser, model_output,
expected_tool_calls, expected_content):
extracted_tool_calls = xlam_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
assert extracted_tool_calls.content == expected_content
@pytest.mark.parametrize(
ids=["list_structured_tool_call"],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
(
"""[{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}]""", # noqa: E501
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Seattle",
"state": "WA",
"unit": "celsius",
}),
))
],
None,
),
],
)
def test_extract_tool_calls_list_structure(xlam_tool_parser, model_output,
expected_tool_calls,
expected_content):
"""Test extraction of tool calls when the model outputs a list-structured tool call.""" # noqa: E501
extracted_tool_calls = xlam_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
assert extracted_tool_calls.content == expected_content
# Test for preprocess_model_output method
def test_preprocess_model_output(xlam_tool_parser):
# Test with list structure
model_output = """[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]""" # noqa: E501
content, potential_tool_calls = xlam_tool_parser.preprocess_model_output(
model_output)
assert content is None
assert potential_tool_calls == model_output
# Test with thinking tag
model_output = """<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]""" # noqa: E501
content, potential_tool_calls = xlam_tool_parser.preprocess_model_output(
model_output)
assert content == "<think>I'll help you with that.</think>"
assert (
potential_tool_calls ==
'[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]')
# Test with JSON code block
model_output = """I'll help you with that.
```json
[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]
```"""
content, potential_tool_calls = xlam_tool_parser.preprocess_model_output(
model_output)
assert content == "I'll help you with that."
assert "get_current_weather" in potential_tool_calls
# Test with no tool calls
model_output = """I'll help you with that."""
content, potential_tool_calls = xlam_tool_parser.preprocess_model_output(
model_output)
assert content == model_output
assert potential_tool_calls is None
# Simulate streaming to test extract_tool_calls_streaming
def test_streaming_with_list_structure(xlam_tool_parser):
# Reset streaming state
xlam_tool_parser.prev_tool_calls = []
xlam_tool_parser.current_tools_sent = []
xlam_tool_parser.streamed_args = []
xlam_tool_parser.current_tool_id = -1
# Simulate receiving a message with list structure
current_text = """[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]""" # noqa: E501
# First call to set up the tool
xlam_tool_parser.extract_tool_calls_streaming(
previous_text="",
current_text=current_text,
delta_text="]",
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
request=None,
)
# Make sure the tool is set up correctly
assert (xlam_tool_parser.current_tool_id
>= 0), "Tool index should be initialized"
# Manually set up the state for sending the tool name
xlam_tool_parser.current_tools_sent = [False]
# Call to send the function name
result = xlam_tool_parser.extract_tool_calls_streaming(
previous_text=current_text,
current_text=current_text,
delta_text="",
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
request=None,
)
# Check that we get a result with the proper tool call
if result is not None:
assert hasattr(result, "tool_calls")
assert len(result.tool_calls) == 1
assert result.tool_calls[0].function.name == "get_current_weather"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import pytest
from tools.validate_config import validate_ast
_TestConfig1 = '''
@config
class _TestConfig1:
pass
'''
_TestConfig2 = '''
@config
@dataclass
class _TestConfig2:
a: int
"""docstring"""
'''
_TestConfig3 = '''
@config
@dataclass
class _TestConfig3:
a: int = 1
'''
_TestConfig4 = '''
@config
@dataclass
class _TestConfig4:
a: Union[Literal[1], Literal[2]] = 1
"""docstring"""
'''
@pytest.mark.parametrize(("test_config", "expected_error"), [
(_TestConfig1, "must be a dataclass"),
(_TestConfig2, "must have a default"),
(_TestConfig3, "must have a docstring"),
(_TestConfig4, "must use a single Literal"),
])
def test_config(test_config, expected_error):
tree = ast.parse(test_config)
with pytest.raises(Exception, match=expected_error):
validate_ast(tree)
......@@ -667,42 +667,54 @@ def get_physical_device_indices(devices):
@_nvml()
def wait_for_gpu_memory_to_clear(devices: list[int],
threshold_bytes: int,
def wait_for_gpu_memory_to_clear(*,
devices: list[int],
threshold_bytes: Optional[int] = None,
threshold_ratio: Optional[float] = None,
timeout_s: float = 120) -> None:
assert threshold_bytes is not None or threshold_ratio is not None
# Use nvml instead of pytorch to reduce measurement error from torch cuda
# context.
devices = get_physical_device_indices(devices)
start_time = time.time()
while True:
output: dict[int, str] = {}
output_raw: dict[int, float] = {}
output_raw: dict[int, tuple[float, float]] = {}
for device in devices:
if current_platform.is_rocm():
dev_handle = amdsmi_get_processor_handles()[device]
mem_info = amdsmi_get_gpu_vram_usage(dev_handle)
gb_used = mem_info["vram_used"] / 2**10
gb_total = mem_info["vram_total"] / 2**10
else:
dev_handle = nvmlDeviceGetHandleByIndex(device)
mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
gb_used = mem_info.used / 2**30
output_raw[device] = gb_used
output[device] = f'{gb_used:.02f}'
gb_total = mem_info.total / 2**30
output_raw[device] = (gb_used, gb_total)
output[device] = f'{gb_used:.02f}/{gb_total:.02f}'
print('gpu memory used (GB): ', end='')
print('gpu memory used/total (GiB): ', end='')
for k, v in output.items():
print(f'{k}={v}; ', end='')
print('')
if threshold_bytes is not None:
is_free = lambda used, total: used <= threshold_bytes / 2**30
threshold = f"{threshold_bytes/2**30} GiB"
else:
is_free = lambda used, total: used / total <= threshold_ratio
threshold = f"{threshold_ratio:.2f}"
dur_s = time.time() - start_time
if all(v <= (threshold_bytes / 2**30) for v in output_raw.values()):
if all(is_free(used, total) for used, total in output_raw.values()):
print(f'Done waiting for free GPU memory on devices {devices=} '
f'({threshold_bytes/2**30=}) {dur_s=:.02f}')
f'({threshold=}) {dur_s=:.02f}')
break
if dur_s >= timeout_s:
raise ValueError(f'Memory of devices {devices=} not free after '
f'{dur_s=:.02f} ({threshold_bytes/2**30=})')
f'{dur_s=:.02f} ({threshold=})')
time.sleep(5)
......
......@@ -43,6 +43,7 @@ def make_request(request_id,
multi_modal_hashes=mm_hashes,
multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams(max_tokens=17),
pooling_params=None,
eos_token_id=100,
lora_request=None,
cache_salt=cache_salt,
......@@ -900,3 +901,19 @@ def test_get_kv_cache_config():
with pytest.raises(NotImplementedError):
get_kv_cache_config(vllm_config, kv_cache_specs_hybrid,
mem_per_block_per_layer * 2 * 32)
# Test num_gpu_blocks_override
vllm_config.cache_config.num_gpu_blocks_override = 16
kv_cache_config_override_blocks = get_kv_cache_config(
vllm_config, kv_cache_specs_full, mem_per_block_per_layer * 2 * 32)
assert kv_cache_config_override_blocks == KVCacheConfig(
num_blocks=16,
kv_cache_tensors=[
KVCacheTensor(size=mem_per_block_per_layer * 16,
shared_by=["layer_1"]),
KVCacheTensor(size=mem_per_block_per_layer * 16,
shared_by=["layer_2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec())
])
\ No newline at end of file
......@@ -39,6 +39,7 @@ def make_request(request_id,
multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams(max_tokens=17,
prompt_logprobs=prompt_logprobs),
pooling_params=None,
eos_token_id=100,
lora_request=None,
cache_salt=cache_salt,
......
......@@ -9,14 +9,15 @@ import torch
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.structured_output.request import StructuredOutputRequest
EOS_TOKEN_ID = 50256
......@@ -33,6 +34,7 @@ def create_scheduler(
block_size: int = 16,
max_model_len: Optional[int] = None,
num_speculative_tokens: Optional[int] = None,
skip_tokenizer_init: bool = False,
) -> Scheduler:
'''Create scheduler under test.
......@@ -65,6 +67,7 @@ def create_scheduler(
trust_remote_code=True,
dtype="float16",
seed=42,
skip_tokenizer_init=skip_tokenizer_init,
)
# Cache config, optionally force APC
kwargs_cache = ({} if enable_prefix_caching is None else {
......@@ -135,6 +138,7 @@ def create_requests(num_requests: int,
request_id=f"{i}",
prompt_token_ids=[i] * num_tokens,
sampling_params=sampling_params,
pooling_params=None,
multi_modal_inputs=mm_inputs,
multi_modal_placeholders=mm_position,
multi_modal_hashes=None,
......@@ -185,7 +189,7 @@ def test_get_num_unfinished_requests():
])
def test_schedule(enable_prefix_caching: Optional[bool],
prompt_logprobs: Optional[int]):
'''Test scheduling.
'''Test scheduling.
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
'''
scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching)
......@@ -197,7 +201,7 @@ def test_schedule(enable_prefix_caching: Optional[bool],
# Test initial scheduling
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests)
assert len(output.scheduled_cached_reqs) == 0
assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0
# Verify all requests are scheduled.
for req_id, num_tokens in output.num_scheduled_tokens.items():
......@@ -224,7 +228,7 @@ def test_schedule_multimodal_requests():
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests)
assert len(output.scheduled_cached_reqs) == 0
assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0
for req_id, num_tokens in output.num_scheduled_tokens.items():
assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
......@@ -258,7 +262,7 @@ def test_schedule_partial_requests():
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 3
assert len(output.scheduled_cached_reqs) == 0
assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0
assert scheduler.max_num_encoder_input_tokens == 1024
......@@ -283,6 +287,7 @@ def test_schedule_partial_requests():
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_runner_output)
......@@ -293,7 +298,7 @@ def test_schedule_partial_requests():
output = scheduler.schedule()
assert len(scheduler.running) == 3
assert len(output.scheduled_new_reqs) == 0
assert len(output.scheduled_cached_reqs) == 2
assert output.scheduled_cached_reqs.num_reqs == 2
assert len(output.finished_req_ids) == 0
assert output.num_scheduled_tokens[requests[0].request_id] == 1
assert output.num_scheduled_tokens[requests[1].request_id] == 700
......@@ -317,7 +322,7 @@ def test_no_mm_input_chunking():
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 1
assert len(output.scheduled_cached_reqs) == 0
assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0
# We want to only see the 400 text tokens at the start scheduled
assert output.num_scheduled_tokens[requests[0].request_id] == 400
......@@ -333,13 +338,14 @@ def test_no_mm_input_chunking():
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_runner_output)
output = scheduler.schedule()
assert len(scheduler.running) == 1
assert len(output.scheduled_new_reqs) == 0
assert len(output.scheduled_cached_reqs) == 1
assert output.scheduled_cached_reqs.num_reqs == 1
assert len(output.finished_req_ids) == 0
assert output.num_scheduled_tokens[requests[0].request_id] == 800
......@@ -376,7 +382,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 3
assert len(output.scheduled_cached_reqs) == 0
assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0
# The first request is scheduled partially - 400.
......@@ -396,6 +402,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_runner_output)
......@@ -404,7 +411,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
output1 = scheduler.schedule()
assert len(scheduler.running) == 3
assert len(output1.scheduled_new_reqs) == 0
assert len(output1.scheduled_cached_reqs) == 3
assert output1.scheduled_cached_reqs.num_reqs == 3
assert len(output1.finished_req_ids) == 0
assert output1.num_scheduled_tokens[requests[0].request_id] == 400
assert output1.num_scheduled_tokens[requests[1].request_id] == 400
......@@ -420,12 +427,13 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output1, model_runner_output)
output2 = scheduler.schedule()
assert len(scheduler.running) == 3
assert len(output2.scheduled_new_reqs) == 0
assert len(output2.scheduled_cached_reqs) == 3
assert output2.scheduled_cached_reqs.num_reqs == 3
assert len(output2.finished_req_ids) == 0
assert output2.num_scheduled_tokens[requests[0].request_id] == 1
assert output2.num_scheduled_tokens[requests[1].request_id] == 1
......@@ -444,23 +452,24 @@ def test_stop_via_update_from_output():
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 1,
requests[1].request_id: 2
},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [],
requests[1].request_id: [10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={
requests[0].request_id: 1,
requests[1].request_id: 2
},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [],
requests[1].request_id: [10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
......@@ -473,7 +482,8 @@ def test_stop_via_update_from_output():
11]], # First request hits EOS, second continues
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output)
......@@ -495,23 +505,25 @@ def test_stop_via_update_from_output():
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 2
},
total_num_scheduled_tokens=5,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 42],
requests[1].request_id: [13]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 2
},
total_num_scheduled_tokens=5,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 42],
requests[1].request_id: [13]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
......@@ -523,7 +535,8 @@ def test_stop_via_update_from_output():
[13, 14]], # First request hits stop token
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output)
......@@ -544,23 +557,25 @@ def test_stop_via_update_from_output():
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 1
},
total_num_scheduled_tokens=4,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 11],
requests[1].request_id: []
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 1
},
total_num_scheduled_tokens=4,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 11],
requests[1].request_id: []
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
......@@ -572,7 +587,8 @@ def test_stop_via_update_from_output():
[13]], # First request exceeds max_tokens
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output)
......@@ -595,7 +611,7 @@ def test_stop_via_update_from_output():
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={requests[0].request_id: 3},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
......@@ -614,7 +630,8 @@ def test_stop_via_update_from_output():
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output)
......@@ -663,6 +680,7 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(scheduler_output0, model_runner_output)
......@@ -680,6 +698,7 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(scheduler_output1, model_runner_output)
......@@ -730,6 +749,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
spec_token_ids=spec_tokens,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
engine_core_outputs = scheduler.update_from_output(output,
model_runner_output)
......@@ -769,6 +789,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
engine_core_outputs = scheduler.update_from_output(output,
model_runner_output)
......@@ -896,6 +917,7 @@ def test_kv_connector_basic():
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
# Ensure ScheduleOutput is correct.
......@@ -941,6 +963,7 @@ def test_kv_connector_basic():
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
# We should get a local cache hit of NUM_TOKENS_PREFIX and
......@@ -1007,6 +1030,7 @@ def test_kv_connector_unable_to_allocate():
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
# Just one request should be running.
......@@ -1087,6 +1111,7 @@ def test_kv_connector_handles_preemption():
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
# All can be scheduled - 1st token.
......@@ -1133,7 +1158,6 @@ def test_kv_connector_handles_preemption():
assert len(scheduler.running) == 1
_ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1
# All memory should be freed since nothing is running.
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \
== NUM_BLOCKS - 1
......@@ -1181,6 +1205,7 @@ def make_output(scheduler: Scheduler):
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
......@@ -1191,7 +1216,6 @@ def assert_scheduler_empty(scheduler: Scheduler):
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 0
assert len(scheduler.finished_req_ids) == 0
assert len(scheduler._cached_reqs_data) == 0
# EncoderCacheManager.
assert len(scheduler.encoder_cache_manager.freed) == 0
......@@ -1247,3 +1271,628 @@ def test_memory_leak():
# Confirm no memory leak.
assert_scheduler_empty(scheduler)
def create_scheduler_with_priority(
model: str = "facebook/opt-125m",
max_num_seqs: int = 16,
max_num_batched_tokens: int = 8192,
enable_prefix_caching: Optional[bool] = None,
long_prefill_token_threshold: int = 0,
disable_chunked_mm_input: bool = False,
use_kv_connector: bool = False,
num_blocks: int = 10000,
block_size: int = 16,
max_model_len: Optional[int] = None,
num_speculative_tokens: Optional[int] = None,
) -> Scheduler:
'''Create scheduler with priority policy enabled.
Args:
model: model under test
max_num_seqs: max sequences to schedule
max_num_batch_tokens: max num tokens to batch
enable_prefix_caching: optionally force APC config
(True/False) or use default
(None)
Returns:
{class}`Scheduler` instance with priority scheduling
'''
if max_model_len is None:
max_model_len = max_num_batched_tokens
scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
max_model_len=max_model_len,
long_prefill_token_threshold=long_prefill_token_threshold,
disable_chunked_mm_input=disable_chunked_mm_input,
enable_chunked_prefill=True,
policy="priority", # Enable priority scheduling
)
model_config = ModelConfig(
model=model,
task="auto",
tokenizer=model,
tokenizer_mode="auto",
trust_remote_code=True,
dtype="float16",
seed=42,
)
# Cache config, optionally force APC
kwargs_cache = ({} if enable_prefix_caching is None else {
'enable_prefix_caching': enable_prefix_caching
})
cache_config = CacheConfig(
block_size=block_size,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
**kwargs_cache,
)
kv_transfer_config = KVTransferConfig(
kv_connector="SharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": "local_storage"},
) if use_kv_connector else None
speculative_config: Optional[SpeculativeConfig] = None
if num_speculative_tokens is not None:
speculative_config = SpeculativeConfig(
model="ngram", num_speculative_tokens=num_speculative_tokens)
vllm_config = VllmConfig(
scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
kv_transfer_config=kv_transfer_config,
speculative_config=speculative_config,
)
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks, # A large number of blocks to hold all requests
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(['layer'],
FullAttentionSpec(block_size, 1, 1, torch.float32,
False))
],
)
cache_config.num_gpu_blocks = num_blocks
return Scheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config),
)
def create_requests_with_priority(
num_requests: int,
priorities: list[int],
arrival_times: Optional[list[float]] = None,
num_tokens: int = 10,
mm_positions: Optional[list[PlaceholderRange]] = None,
max_tokens: int = 16,
stop_token_ids: Optional[list[int]] = None,
prompt_logprobs: Optional[int] = None):
"""Create requests with specified priorities and arrival times."""
assert len(priorities) == num_requests
if arrival_times is not None:
assert len(arrival_times) == num_requests
else:
arrival_times = [float(i) for i in range(num_requests)]
sampling_params = SamplingParams(ignore_eos=False,
max_tokens=max_tokens,
stop_token_ids=stop_token_ids,
prompt_logprobs=prompt_logprobs)
requests = []
for i in range(num_requests):
if mm_positions is not None:
mm_position = mm_positions[i]
mm_inputs = [MultiModalKwargs({})] * len(mm_position)
else:
mm_position = None
mm_inputs = None
request = Request(
request_id=f"{i}",
prompt_token_ids=[i] * num_tokens,
sampling_params=sampling_params,
pooling_params=None,
multi_modal_inputs=mm_inputs,
multi_modal_placeholders=mm_position,
multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID,
arrival_time=arrival_times[i],
priority=priorities[i],
)
requests.append(request)
return requests
def test_priority_scheduling_basic_ordering():
"""Test that requests are scheduled in priority order
(lower value = higher priority)."""
scheduler = create_scheduler_with_priority()
# Create requests with different priorities
# Priority 0 (highest), 1, 2 (lowest)
priorities = [2, 0, 1] # Add in non-priority order
arrival_times = [1.0, 2.0, 3.0] # All different arrival times
requests = create_requests_with_priority(num_requests=3,
priorities=priorities,
arrival_times=arrival_times)
# Add requests in non-priority order
for request in requests:
scheduler.add_request(request)
# Schedule and verify priority order
output = scheduler.schedule()
# Should schedule all requests since they fit in budget
assert len(output.scheduled_new_reqs) == 3
# Verify they are scheduled in priority order:
# req_1 (priority 0), req_2 (priority 1), req_0 (priority 2)
scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
assert scheduled_req_ids == ["1", "2", "0"]
def test_priority_scheduling_arrival_time_tiebreaker():
"""Test that arrival time is used
as tiebreaker when priorities are equal."""
scheduler = create_scheduler_with_priority()
# Create requests with same priority but different arrival times
priorities = [1, 1, 1] # All same priority
arrival_times = [3.0, 1.0, 2.0] # Different arrival times
requests = create_requests_with_priority(num_requests=3,
priorities=priorities,
arrival_times=arrival_times)
# Add requests in non-arrival order
for request in requests:
scheduler.add_request(request)
# Schedule and verify arrival time order
output = scheduler.schedule()
# Should schedule all requests since they fit in budget
assert len(output.scheduled_new_reqs) == 3
# Verify they are scheduled in arrival time order:
# req_1 (1.0), req_2 (2.0), req_0 (3.0)
scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
assert scheduled_req_ids == ["1", "2", "0"]
def test_priority_scheduling_mixed_priority_and_arrival():
"""Test priority scheduling with mixed priorities and arrival times."""
scheduler = create_scheduler_with_priority()
# Create requests with mixed priorities and arrival times
priorities = [2, 1, 1, 0] # Mixed priorities
arrival_times = [1.0, 3.0, 2.0, 4.0] # Mixed arrival times
requests = create_requests_with_priority(num_requests=4,
priorities=priorities,
arrival_times=arrival_times)
# Add requests
for request in requests:
scheduler.add_request(request)
# Schedule and verify order
output = scheduler.schedule()
# Should schedule all requests since they fit in budget
assert len(output.scheduled_new_reqs) == 4
# Expected order:
# 1. req_3 (priority 0, arrival 4.0)
# 2. req_2 (priority 1, arrival 2.0) - earlier arrival than req_1
# 3. req_1 (priority 1, arrival 3.0)
# 4. req_0 (priority 2, arrival 1.0)
scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
assert scheduled_req_ids == ["3", "2", "1", "0"]
def test_priority_scheduling_preemption():
"""Test that priority scheduling preempts
lower priority requests when memory is constrained."""
# Create scheduler with very limited memory to force preemption
scheduler = create_scheduler_with_priority(
max_num_seqs=3, # Allow multiple requests
max_num_batched_tokens=200,
num_blocks=6, # Very limited blocks to force memory pressure
block_size=16, # Standard block size
)
# Create initial low-priority requests that will consume most memory
low_priority_requests = create_requests_with_priority(
num_requests=2,
priorities=[5, 5], # Low priority
arrival_times=[1.0, 2.0],
num_tokens=30 # Large enough to consume significant memory
)
# Add and schedule low priority requests
for request in low_priority_requests:
scheduler.add_request(request)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 2
# Simulate model execution to move requests to running state
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in low_priority_requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(low_priority_requests)
},
sampled_token_ids=[[100] for _ in low_priority_requests],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_output)
# Verify both requests are running
assert len(scheduler.running) == 2
# Now add a high-priority request that requires memory allocation
# This should trigger preemption due to memory constraints
high_priority_request = create_requests_with_priority(
num_requests=1,
priorities=[0], # High priority
arrival_times=[3.0],
num_tokens=30 # Large enough to require significant memory
)[0]
scheduler.add_request(high_priority_request)
# Schedule again - this should trigger
# preemption when trying to allocate memory
output = scheduler.schedule()
# Due to the scheduler's design, if preemption happens
# during running request scheduling,
# waiting requests won't be scheduled in the same step
# Let's check if preemption occurred by looking at the waiting queue
# If preemption happened, we should see requests in the
# waiting queue
if len(scheduler.waiting) > 1: # high priority + preempted request
# Preemption occurred - verify the high priority request
# gets scheduled next
output2 = scheduler.schedule()
assert len(output2.scheduled_new_reqs) == 1
# High priority request
assert output2.scheduled_new_reqs[0].req_id == "0"
else:
# No preemption needed - all requests fit
# This is also valid behavior if memory allows
assert len(output.scheduled_new_reqs) == 1
# High priority request
assert output.scheduled_new_reqs[0].req_id == "0"
def test_priority_scheduling_no_preemption_when_space_available():
"""Test that preemption doesn't happen
when there's space for new requests."""
scheduler = create_scheduler_with_priority(
max_num_seqs=3, # Allow 3 concurrent requests
max_num_batched_tokens=200, # Sufficient token budget
)
# Add two low-priority running requests
low_priority_requests = create_requests_with_priority(
num_requests=2,
priorities=[5, 5],
arrival_times=[1.0, 2.0],
num_tokens=30)
for request in low_priority_requests:
scheduler.add_request(request)
output = scheduler.schedule()
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in low_priority_requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(low_priority_requests)
},
sampled_token_ids=[[100] for _ in low_priority_requests],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_output)
# Add high-priority request
high_priority_request = create_requests_with_priority(num_requests=1,
priorities=[0],
arrival_times=[3.0],
num_tokens=30)[0]
scheduler.add_request(high_priority_request)
# Schedule - should not preempt since there's space
output = scheduler.schedule()
# Should schedule the new request without preemption
assert len(output.scheduled_new_reqs) == 1
assert len(scheduler.running) == 3 # All three requests running
assert len(scheduler.waiting) == 0 # No requests waiting
def test_priority_scheduling_preemption_victim_selection():
"""Test that the correct victim is selected for
preemption based on priority and arrival time."""
# This test verifies the priority-based victim selection logic
# by checking the waiting queue order after adding requests with different
# priorities
scheduler = create_scheduler_with_priority(
max_num_seqs=1, # Force sequential processing to test priority order
)
# Create requests with different priorities
requests = create_requests_with_priority(
num_requests=3,
priorities=[3, 2, 0], # Different priorities: low, medium, high
arrival_times=[1.0, 2.0, 3.0],
num_tokens=10)
# Add all requests
for request in requests:
scheduler.add_request(request)
# Schedule - should only schedule the highest priority request
# (req_2, priority 0)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 1
assert output.scheduled_new_reqs[0].req_id == "2" # Highest priority
# Verify the waiting queue has the remaining requests in priority order
assert len(scheduler.waiting) == 2
# Extract waiting requests and verify priority order
waiting_requests = list(scheduler.waiting)
waiting_priorities = [req.priority for req in waiting_requests]
waiting_req_ids = [req.request_id for req in waiting_requests]
# Should be req_1 (priority 2) then req_0 (priority 3)
assert waiting_priorities == [2, 3]
assert waiting_req_ids == ["1", "0"]
def test_priority_scheduling_equal_priority_preemption():
"""Test arrival time tiebreaker when requests have equal priority."""
# This test verifies that arrival time is used as a tiebreaker for equal
# priorities
scheduler = create_scheduler_with_priority(
max_num_seqs=1, # Force sequential processing
)
# Create requests with same priority but different arrival times
requests = create_requests_with_priority(
num_requests=3,
priorities=[2, 2, 2], # Same priority
arrival_times=[3.0, 1.0, 2.0], # Different arrival times
num_tokens=10)
# Add all requests
for request in requests:
scheduler.add_request(request)
# Schedule - should schedule the request with earliest arrival time
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 1
assert output.scheduled_new_reqs[0].req_id == "1" # Earliest arrival (1.0)
# Verify the waiting queue has remaining requests in arrival time order
assert len(scheduler.waiting) == 2
# Extract waiting requests and verify arrival time order
waiting_requests = list(scheduler.waiting)
waiting_arrival_times = [req.arrival_time for req in waiting_requests]
waiting_req_ids = [req.request_id for req in waiting_requests]
# Should be req_2 (arrival 2.0) then req_0 (arrival 3.0)
assert waiting_arrival_times == [2.0, 3.0]
assert waiting_req_ids == ["2", "0"]
def test_priority_scheduling_waiting_queue_order():
"""Test that the waiting queue maintains priority order."""
scheduler = create_scheduler_with_priority(
max_num_seqs=1, # Only one request can run at a time
)
# Create multiple requests with different priorities
requests = create_requests_with_priority(
num_requests=4,
priorities=[3, 1, 2, 0], # Mixed priorities
arrival_times=[1.0, 2.0, 3.0, 4.0],
num_tokens=10)
# Add all requests
for request in requests:
scheduler.add_request(request)
# Schedule - should only schedule the highest priority request
# (req_3, priority 0)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 1
assert output.scheduled_new_reqs[0].req_id == "3"
# Verify waiting queue has remaining requests in priority order
assert len(scheduler.waiting) == 3
# Extract requests from waiting queue
# (it's a heap, so we need to pop to see order)
waiting_requests = list(scheduler.waiting)
waiting_priorities = [req.priority for req in waiting_requests]
waiting_req_ids = [req.request_id for req in waiting_requests]
# Should be ordered by priority: req_1 (1), req_2 (2), req_0 (3)
assert waiting_req_ids == ["1", "2", "0"]
assert waiting_priorities == [1, 2, 3]
def test_priority_scheduling_fcfs_fallback():
"""Test that FCFS behavior is maintained when all
requests have same priority."""
scheduler = create_scheduler_with_priority()
# Create requests with same priority but different arrival times
priorities = [1, 1, 1, 1] # All same priority
arrival_times = [4.0, 1.0, 3.0, 2.0] # Different arrival times
requests = create_requests_with_priority(num_requests=4,
priorities=priorities,
arrival_times=arrival_times)
# Add requests
for request in requests:
scheduler.add_request(request)
# Schedule
output = scheduler.schedule()
# Should schedule all requests in arrival time order
assert len(output.scheduled_new_reqs) == 4
scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
# Expected order by arrival time:
# req_1 (1.0), req_3 (2.0), req_2 (3.0), req_0 (4.0)
assert scheduled_req_ids == ["1", "3", "2", "0"]
def test_priority_scheduling_with_limited_slots():
"""Test priority scheduling when max_num_seqs limits concurrent requests."""
scheduler = create_scheduler_with_priority(
max_num_seqs=2, # Only allow 2 concurrent requests
max_num_batched_tokens=1000, # Plenty of token budget
)
# Create requests with different priorities
requests = create_requests_with_priority(
num_requests=4,
priorities=[3, 1, 2, 0], # Mixed priorities
arrival_times=[1.0, 2.0, 3.0, 4.0],
num_tokens=10)
# Add all requests
for request in requests:
scheduler.add_request(request)
# Schedule - should only schedule the 2 highest priority requests
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 2
# Should schedule req_3 (priority 0) and req_1 (priority 1)
scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs]
assert "3" in scheduled_req_ids # Priority 0
assert "1" in scheduled_req_ids # Priority 1
# Remaining requests should be in waiting queue in priority order
assert len(scheduler.waiting) == 2
# Extract waiting requests and verify order
waiting_requests = list(scheduler.waiting)
waiting_priorities = [req.priority for req in waiting_requests]
waiting_req_ids = [req.request_id for req in waiting_requests]
# Should be req_2 (priority 2) then req_0 (priority 3)
assert waiting_priorities == [2, 3]
assert waiting_req_ids == ["2", "0"]
def test_priority_scheduling_heap_property():
"""Test that the waiting queue maintains heap
property for priority scheduling."""
scheduler = create_scheduler_with_priority(
max_num_seqs=1, # Only one request can run at a time
)
# Add requests in random priority order
priorities = [5, 1, 8, 3, 2, 7, 4, 6]
arrival_times = [float(i) for i in range(len(priorities))]
requests = create_requests_with_priority(num_requests=len(priorities),
priorities=priorities,
arrival_times=arrival_times,
num_tokens=10)
# Add all requests
for request in requests:
scheduler.add_request(request)
# Schedule one request at a time and verify priority order
scheduled_priorities = []
while scheduler.waiting:
output = scheduler.schedule()
if output.scheduled_new_reqs:
req = output.scheduled_new_reqs[0]
scheduled_priorities.append(requests[int(req.req_id)].priority)
# Simulate completion to make room for next request
model_output = ModelRunnerOutput(
req_ids=[req.req_id],
req_id_to_index={req.req_id: 0},
sampled_token_ids=[[100]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_output)
# Finish the request to make room for the next one
scheduler.finish_requests(req.req_id,
RequestStatus.FINISHED_STOPPED)
# Verify requests were scheduled in priority order (lowest value first)
expected_priorities = sorted(priorities)
assert scheduled_priorities == expected_priorities
def test_schedule_skip_tokenizer_init():
scheduler = create_scheduler(skip_tokenizer_init=True)
requests = create_requests(num_requests=5)
for request in requests:
scheduler.add_request(request)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests)
assert output.grammar_bitmask is None
def test_schedule_skip_tokenizer_init_structured_output_request():
scheduler = create_scheduler(skip_tokenizer_init=True)
guided_params = GuidedDecodingParams(regex="[0-9]+")
sampling_params = SamplingParams(
ignore_eos=False,
max_tokens=16,
guided_decoding=guided_params,
)
request = Request(
request_id="0",
prompt_token_ids=[0, 1],
multi_modal_inputs=None,
multi_modal_hashes=None,
multi_modal_placeholders=None,
sampling_params=sampling_params,
pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
structured_output_request=StructuredOutputRequest(sampling_params),
)
scheduler.add_request(request)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 0
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1
......@@ -30,7 +30,7 @@ model_config = {
])
@pytest.mark.parametrize("batch_size", [5])
@pytest.mark.parametrize("seed", [1])
def test_sliding_window_retrival(monkeypatch, model, batch_size, seed):
def test_sliding_window_retrieval(monkeypatch, model, batch_size, seed):
"""
The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then
asks for value of one of them (which is outside the sliding window).
......
......@@ -15,6 +15,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs import PromptType
from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind
from vllm.utils import set_default_torch_num_threads
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.metrics.loggers import LoggingStatLogger
......@@ -107,7 +108,8 @@ async def test_load(
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")
engine = AsyncLLM.from_engine_args(engine_args)
with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown)
NUM_REQUESTS = 100
......@@ -154,7 +156,8 @@ async def test_abort(
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")
engine = AsyncLLM.from_engine_args(engine_args)
with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown)
NUM_REQUESTS = 100
......@@ -226,7 +229,8 @@ async def test_finished_flag(
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")
engine = AsyncLLM.from_engine_args(engine_args)
with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown)
sampling_params = SamplingParams(
......@@ -260,7 +264,8 @@ async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch,
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")
engine = AsyncLLM.from_engine_args(engine_args)
with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown)
NUM_REQUESTS = 100
......@@ -322,10 +327,11 @@ async def test_customize_loggers(monkeypatch):
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")
engine = AsyncLLM.from_engine_args(
TEXT_ENGINE_ARGS,
stat_loggers=[MockLoggingStatLogger],
)
with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(
TEXT_ENGINE_ARGS,
stat_loggers=[MockLoggingStatLogger],
)
after.callback(engine.shutdown)
await engine.do_log_stats()
......@@ -340,7 +346,8 @@ async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")
engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
after.callback(engine.shutdown)
sampling_params = SamplingParams(max_tokens=100,
......@@ -362,3 +369,33 @@ async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch):
sampling_params=sampling_params,
data_parallel_rank=1):
pass
@pytest.mark.asyncio
async def test_check_health(monkeypatch: pytest.MonkeyPatch):
"""Test that check_health returns normally for healthy engine
and raises EngineDeadError when the engine is dead.
"""
from unittest.mock import patch
from vllm.v1.engine.exceptions import EngineDeadError
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")
with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
after.callback(engine.shutdown)
# Test 1: Healthy engine should not raise any exception
await engine.check_health()
# Test 2: Mock the errored property to simulate a dead engine
with patch.object(type(engine),
'errored',
new_callable=lambda: property(lambda self: True)
), pytest.raises(EngineDeadError):
await engine.check_health()
# Test 3: Verify healthy engine still works after mock
await engine.check_health()
......@@ -12,13 +12,14 @@ from transformers import AutoTokenizer
from vllm import SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.utils import set_default_torch_num_threads
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core import EngineCore
from vllm.v1.executor.abstract import Executor, UniProcExecutor
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import ModelRunnerOutput
from ...utils import create_new_process_for_each_test
from ...utils import create_new_process_for_each_test, multi_gpu_test
if not current_platform.is_cuda():
pytest.skip(reason="V1 currently only supported on CUDA.",
......@@ -38,6 +39,7 @@ def make_request() -> EngineCoreRequest:
mm_hashes=None,
mm_placeholders=None,
sampling_params=SamplingParams(),
pooling_params=None,
eos_token_id=None,
arrival_time=time.time(),
lora_request=None,
......@@ -56,9 +58,10 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
vllm_config = engine_args.create_engine_config()
executor_class = Executor.get_class(vllm_config)
engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True)
with set_default_torch_num_threads(1):
engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True)
"""Test basic request lifecycle."""
# First request.
......@@ -190,9 +193,10 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch):
vllm_config = engine_args.create_engine_config()
executor_class = Executor.get_class(vllm_config)
engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True)
with set_default_torch_num_threads(1):
engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True)
"""Test basic request lifecycle."""
# First request.
request: EngineCoreRequest = make_request()
......@@ -286,9 +290,10 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
enforce_eager=True,
)
vllm_config = engine_args.create_engine_config()
engine_core = EngineCore(vllm_config=vllm_config,
log_stats=False,
executor_class=DummyExecutor)
with set_default_torch_num_threads(1):
engine_core = EngineCore(vllm_config=vllm_config,
log_stats=False,
executor_class=DummyExecutor)
assert engine_core.batch_queue is not None
# Add two requests in a row. Each request have 12 prompt tokens.
......@@ -374,3 +379,37 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
# Odd steps schedules a new batch.
assert output is None
step += 1
@multi_gpu_test(num_gpus=2)
def test_engine_core_tp(monkeypatch: pytest.MonkeyPatch):
"""
Test engine can initialize worker in tp properly
"""
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
"""Setup the EngineCore."""
engine_args = EngineArgs(
model=MODEL_NAME,
tensor_parallel_size=2,
# Reduce startup time.
enforce_eager=True,
)
vllm_config = engine_args.create_engine_config()
executor_class = Executor.get_class(vllm_config)
with set_default_torch_num_threads(1):
engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True)
def get_worker_cache_config_field(worker, key: str):
return getattr(worker.cache_config, key)
num_gpu_blocks = engine_core.collective_rpc(
get_worker_cache_config_field, args=("num_gpu_blocks", ))
num_cpu_blocks = engine_core.collective_rpc(
get_worker_cache_config_field, args=("num_cpu_blocks", ))
assert all(x is not None for x in num_gpu_blocks)
assert all(x is not None for x in num_cpu_blocks)
......@@ -8,8 +8,10 @@ import time
import uuid
from threading import Thread
from typing import Optional
from unittest.mock import MagicMock
import pytest
import torch
from transformers import AutoTokenizer
from tests.utils import multi_gpu_test
......@@ -19,12 +21,13 @@ from vllm.distributed.kv_events import (BlockStored, KVEventBatch,
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.usage.usage_lib import UsageContext
from vllm.utils import set_default_torch_num_threads
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core import EngineCore
from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
SyncMPClient)
from vllm.v1.engine.utils import CoreEngineProcManager
from vllm.v1.executor.abstract import Executor
from vllm.v1.utils import CoreEngineProcManager
from ...distributed.conftest import MockSubscriber
from ...utils import create_new_process_for_each_test
......@@ -52,6 +55,7 @@ def make_request(
mm_hashes=None,
mm_placeholders=None,
sampling_params=params,
pooling_params=None,
eos_token_id=None,
arrival_time=time.time(),
lora_request=None,
......@@ -138,13 +142,15 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch,
vllm_config = engine_args.create_engine_config(
UsageContext.UNKNOWN_CONTEXT)
executor_class = Executor.get_class(vllm_config)
client = EngineCoreClient.make_client(
multiprocess_mode=multiprocessing_mode,
asyncio_mode=False,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False,
)
with set_default_torch_num_threads(1):
client = EngineCoreClient.make_client(
multiprocess_mode=multiprocessing_mode,
asyncio_mode=False,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False,
)
MAX_TOKENS = 20
params = SamplingParams(max_tokens=MAX_TOKENS)
......@@ -223,13 +229,15 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
vllm_config = engine_args.create_engine_config(
usage_context=UsageContext.UNKNOWN_CONTEXT)
executor_class = Executor.get_class(vllm_config)
client = EngineCoreClient.make_client(
multiprocess_mode=True,
asyncio_mode=True,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True,
)
with set_default_torch_num_threads(1):
client = EngineCoreClient.make_client(
multiprocess_mode=True,
asyncio_mode=True,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True,
)
try:
MAX_TOKENS = 20
......@@ -312,13 +320,14 @@ def test_kv_cache_events(
UsageContext.UNKNOWN_CONTEXT)
executor_class = Executor.get_class(vllm_config)
client = EngineCoreClient.make_client(
multiprocess_mode=multiprocessing_mode,
asyncio_mode=False,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False,
)
with set_default_torch_num_threads(1):
client = EngineCoreClient.make_client(
multiprocess_mode=multiprocessing_mode,
asyncio_mode=False,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False,
)
endpoint = publisher_config.endpoint.replace("*", "127.0.0.1")
subscriber = MockSubscriber(endpoint,
topic=publisher_config.topic,
......@@ -394,13 +403,14 @@ async def test_kv_cache_events_dp(
UsageContext.UNKNOWN_CONTEXT)
executor_class = Executor.get_class(vllm_config)
client = EngineCoreClient.make_client(
multiprocess_mode=multiprocessing_mode,
asyncio_mode=True,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False,
)
with set_default_torch_num_threads(1):
client = EngineCoreClient.make_client(
multiprocess_mode=multiprocessing_mode,
asyncio_mode=True,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False,
)
await asyncio.sleep(1)
# Build endpoints for all DP ranks
......@@ -509,3 +519,72 @@ def test_startup_failure(monkeypatch: pytest.MonkeyPatch):
)
assert "Engine core initialization failed" in str(e_info.value)
@create_new_process_for_each_test()
def test_engine_core_proc_instantiation_cuda_empty(
monkeypatch: pytest.MonkeyPatch):
"""
Test that EngineCoreProc can be instantiated when CUDA_VISIBLE_DEVICES
is empty. This ensures the engine frontend does not need access to GPUs.
"""
from vllm.v1.engine.core import EngineCoreProc
from vllm.v1.executor.abstract import Executor
# Create a simple mock executor instead of a complex custom class
mock_executor_class = MagicMock(spec=Executor)
def create_mock_executor(vllm_config):
mock_executor = MagicMock()
# Only implement the methods that are actually called during init
from vllm.v1.kv_cache_interface import FullAttentionSpec
mock_spec = FullAttentionSpec(block_size=16,
num_kv_heads=1,
head_size=64,
dtype=torch.float16,
use_mla=False)
mock_executor.get_kv_cache_specs.return_value = [{
"default": mock_spec
}]
mock_executor.determine_available_memory.return_value = [
1024 * 1024 * 1024
]
mock_executor.initialize_from_config.return_value = None
mock_executor.max_concurrent_batches = 1
return mock_executor
mock_executor_class.side_effect = create_mock_executor
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
m.setenv("CUDA_VISIBLE_DEVICES", "") # No CUDA devices
from vllm.v1.engine.utils import EngineZmqAddresses
def mock_startup_handshake(self, handshake_socket, on_head_node,
parallel_config):
return EngineZmqAddresses(inputs=["tcp://127.0.0.1:5555"],
outputs=["tcp://127.0.0.1:5556"],
coordinator_input=None,
coordinator_output=None)
# Background processes are not important here
m.setattr(EngineCoreProc, "startup_handshake", mock_startup_handshake)
vllm_config = EngineArgs(
model="deepseek-ai/DeepSeek-V2-Lite",
trust_remote_code=True).create_engine_config()
engine_core_proc = EngineCoreProc(
vllm_config=vllm_config,
local_client=True,
handshake_address="tcp://127.0.0.1:12345",
executor_class=mock_executor_class,
log_stats=False,
engine_index=0,
)
engine_core_proc.shutdown()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from transformers import AutoTokenizer
from vllm.sampling_params import SamplingParams
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
# ruff: noqa: E501
def test_fast_inc_detok_invalid_utf8_err_case():
"""
Test edge case where tokenizer can produce non-monotonic,
invalid UTF-8 output, which breaks the internal state of
tokenizers' DecodeStream.
See https://github.com/vllm-project/vllm/issues/17448.
Thanks to reproducer from @fpaupier:
https://gist.github.com/fpaupier/0ed1375bd7633c5be6c894b1c7ac1be3.
"""
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
# Create a test request
prompt_token_ids = [107, 4606, 236787, 107]
params = SamplingParams(skip_special_tokens=True)
request = EngineCoreRequest(
"test",
prompt_token_ids,
None,
None,
None,
params,
None,
None,
0.0,
None,
cache_salt=None,
data_parallel_rank=None,
)
detokenizer = IncrementalDetokenizer.from_new_request(tokenizer, request)
assert detokenizer.__class__.__name__ == "FastIncrementalDetokenizer", \
"Should use FastIncrementalDetokenizer by default"
# Process tokens incrementally
test_tokens = [
236840, 107, 138, 236782, 107, 140, 236775, 6265, 1083, 623, 121908,
147418, 827, 107, 140, 236775, 6265, 236779, 2084, 1083, 623, 203292,
827, 107, 140, 236775, 6265, 236779, 7777, 1083, 623, 121908, 147418,
569, 537, 236789, 65880, 569, 537, 236789, 62580, 853, 115693, 210118,
35178, 16055, 1270, 759, 215817, 4758, 1925, 1117, 827, 107, 140,
236775, 5654, 1083, 623, 110733, 46291, 827, 107, 140, 236775, 5654,
236779, 2084, 1083, 623, 136955, 56731, 827, 107, 140, 236775, 5654,
236779, 7777, 1083, 623, 194776, 2947, 496, 109811, 1608, 890, 215817,
4758, 1925, 1117, 2789, 432, 398, 602, 31118, 569, 124866, 134772, 509,
19478, 1640, 33779, 236743, 236770, 236819, 236825, 236771, 432, 398,
432, 237167, 827, 107, 140, 236775, 77984, 1083, 623, 2709, 236745,
2555, 513, 236789, 602, 31118, 569
]
output = ""
for i, token_id in enumerate(test_tokens):
detokenizer.update([token_id], False)
finished = i == len(test_tokens) - 1
output += detokenizer.get_next_output_text(finished, delta=True)
# fmt: off
assert output == r'''[
{
"source": "Résultats",
"source_type": "CONCEPT",
"source_description": "Résultats de l'analyse de l'impact des opérations israéliennes sur la frontière libanaise",
"target": "Israël",
"target_type": "ORGANIZATION",
"target_description": "Pays qui a obtenu à sa frontière libanaise « un niveau de calme inédit depuis les années 1960 »",
"relationship": "Obtention d'un niveau de'''
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import random
from typing import Optional
from typing import TYPE_CHECKING, Optional
import pytest
from vllm import LLM, SamplingParams
from vllm import LLM
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Metric, Vector
if TYPE_CHECKING:
from tests.conftest import VllmRunner
MODEL = "facebook/opt-125m"
DTYPE = "half"
def _vllm_model(apc: bool, vllm_runner, monkeypatch):
def _vllm_model(
apc: bool,
vllm_runner: type[VllmRunner],
monkeypatch: pytest.MonkeyPatch,
*,
skip_tokenizer_init: bool = False,
):
"""Set up VllmRunner instance."""
monkeypatch.setenv("VLLM_USE_V1", "1")
return vllm_runner(
......@@ -23,6 +34,7 @@ def _vllm_model(apc: bool, vllm_runner, monkeypatch):
enforce_eager=True,
enable_prefix_caching=apc,
gpu_memory_utilization=0.5,
skip_tokenizer_init=skip_tokenizer_init,
)
......@@ -45,9 +57,27 @@ def vllm_model_apc(vllm_runner, monkeypatch):
yield vllm_model
@pytest.fixture(
# Function scope decouples tests & allows
# env var adjustment via monkeypatch
scope="function",
# Prefix caching
params=[False, True])
def vllm_model_skip_tokenizer_init(vllm_runner, request, monkeypatch):
"""VllmRunner test fixture with APC."""
with _vllm_model(
request.param,
vllm_runner,
monkeypatch,
skip_tokenizer_init=True,
) as vllm_model:
yield vllm_model
def _get_test_sampling_params(
prompt_list: list[str],
seed: Optional[int] = 42,
structured_outputs: bool = False,
) -> tuple[list[SamplingParams], list[int]]:
"""Generate random sampling params for a batch."""
......@@ -62,14 +92,34 @@ def _get_test_sampling_params(
n_list = [get_mostly_n_gt1() for _ in range(len(prompt_list))]
# High temperature to maximize the chance of unique completions
return [
SamplingParams(temperature=0.95, top_p=0.95, n=n, seed=seed)
for n in n_list
SamplingParams(
temperature=0.95,
top_p=0.95,
n=n,
seed=seed,
guided_decoding=GuidedDecodingParams(
regex="[0-9]+") if structured_outputs else None,
) for n in n_list
], n_list
def test_compatibility_with_skip_tokenizer_init(
vllm_model_skip_tokenizer_init: VllmRunner,
example_prompts: list[str],
):
# Case 1: Structured output request should raise an error.
sampling_params_list, _ = _get_test_sampling_params(
example_prompts,
structured_outputs=True,
)
model: LLM = vllm_model_skip_tokenizer_init.model
with pytest.raises(ValueError):
_ = model.generate(example_prompts, sampling_params_list)
def test_parallel_sampling(vllm_model, example_prompts) -> None:
"""Test passes if parallel sampling `n>1` yields `n` unique completions.
Args:
vllm_model: VllmRunner instance under test.
example_prompt: test fixture providing prompts for testing.
......
......@@ -66,7 +66,8 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
output_kind=request_output_kind,
stop=[],
include_stop_str_in_output=False,
))
),
pooling_params=None)
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]
......@@ -416,7 +417,8 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
include_stop_str_in_output=False,
logprobs=num_sample_logprobs,
prompt_logprobs=num_prompt_logprobs,
))
),
pooling_params=None)
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]
......@@ -582,7 +584,8 @@ def test_stop_token(include_stop_str_in_output: bool,
logprobs=num_sample_logprobs,
prompt_logprobs=None,
ignore_eos=ignore_eos,
))
),
pooling_params=None)
# Add request to the detokenizer.
output_processor.add_request(request, prompt_string)
......@@ -678,7 +681,8 @@ def test_stop_string(include_stop_str_in_output: bool,
include_stop_str_in_output=include_stop_str_in_output,
logprobs=num_sample_logprobs,
prompt_logprobs=None,
))
),
pooling_params=None)
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]
......@@ -786,6 +790,7 @@ def test_iteration_stats(dummy_test_vectors):
cache_salt=None,
data_parallel_rank=None,
sampling_params=SamplingParams(),
pooling_params=None,
) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
]
......
......@@ -38,7 +38,7 @@ def default_server_args():
]])
def server(default_server_args, request):
if request.param:
default_server_args.extend(request.param)
default_server_args = default_server_args + request.param
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
yield remote_server
......
......@@ -2,10 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import os
import re
import openai # use the official client for correctness check
import pytest
import pytest_asyncio
import requests
from tests.utils import RemoteOpenAIServer
......@@ -14,6 +16,122 @@ MODEL_NAME = "ibm-research/PowerMoE-3b"
DP_SIZE = os.getenv("DP_SIZE", "1")
def get_prometheus_metrics(
server: RemoteOpenAIServer) -> dict[str, dict[str, float]]:
"""Fetch and parse Prometheus metrics from the /metrics endpoint.
Returns:
Dict mapping metric names to their values grouped by labels.
For example: {"vllm:request_success": {
"engine=0": 5.0, "engine=1": 3.0}
}
"""
try:
response = requests.get(server.url_for("metrics"), timeout=10)
response.raise_for_status()
metrics: dict[str, dict[str, float]] = {}
# Regex patterns for Prometheus metrics
metric_with_labels = re.compile(
r'^([a-zA-Z_:][a-zA-Z0-9_:]*)\{([^}]*)\}\s+([\d\.\-\+e]+)$')
metric_simple = re.compile(
r'^([a-zA-Z_:][a-zA-Z0-9_:]*)\s+([\d\.\-\+e]+)$')
for line in response.text.split('\n'):
line = line.strip()
# Skip comments and empty lines
if not line or line.startswith('#'):
continue
# Try to match metric with labels first
match = metric_with_labels.match(line)
if match:
metric_name, labels_part, value_str = match.groups()
try:
value = float(value_str)
if metric_name not in metrics:
metrics[metric_name] = {}
metrics[metric_name][f'{{{labels_part}}}'] = value
except ValueError:
continue
else:
# Try simple metric without labels
match = metric_simple.match(line)
if match:
metric_name, value_str = match.groups()
try:
value = float(value_str)
if metric_name not in metrics:
metrics[metric_name] = {}
metrics[metric_name][''] = value
except ValueError:
continue
return metrics
except Exception as e:
pytest.fail(f"Failed to fetch Prometheus metrics: {e}")
return {}
def get_engine_request_counts(
metrics: dict[str, dict[str, float]]) -> dict[str, float]:
"""Extract request counts per engine from Prometheus metrics.
Returns:
Dict mapping engine indices to request counts.
For example: {"0": 15.0, "1": 12.0}
"""
engine_counts = {}
# Look for request success metrics with engine labels
success_metrics = metrics.get("vllm:request_success_total", {})
engine_pattern = re.compile(r'engine="([^"]*)"')
for labels, count in success_metrics.items():
# Extract engine ID from labels using regex
match = engine_pattern.search(labels)
if match:
engine_id = match.group(1)
if engine_id not in engine_counts:
engine_counts[engine_id] = 0.0
engine_counts[engine_id] += count
return engine_counts
def check_request_balancing(server: RemoteOpenAIServer):
"""Check request balancing via Prometheus metrics if DP_SIZE > 1.
Args:
server: The RemoteOpenAIServer instance
"""
dp_size = int(DP_SIZE)
if dp_size <= 1:
return
# Get metrics after all requests are completed
metrics = get_prometheus_metrics(server)
engine_counts = get_engine_request_counts(metrics)
# Check that multiple engines received requests
engines_with_requests = [
engine for engine, count in engine_counts.items() if count > 0
]
assert len(engines_with_requests) == dp_size, (
f"Expected requests to be distributed across multiple engines,"
f" but only engine(s) {engines_with_requests} received "
f"requests. Engine counts: {engine_counts}")
# Verify that the load is reasonably balanced
# (no engine should handle all requests)
total_requests = sum(engine_counts.values())
for count in engine_counts.values():
assert count > total_requests // (dp_size + 1), (
f"requests are imbalanced: {engine_counts}")
@pytest.fixture(scope="module")
def default_server_args():
return [
......@@ -50,6 +168,7 @@ async def client(server):
[MODEL_NAME],
)
async def test_single_completion(client: openai.AsyncOpenAI,
server: RemoteOpenAIServer,
model_name: str) -> None:
async def make_request():
......@@ -97,6 +216,9 @@ async def test_single_completion(client: openai.AsyncOpenAI,
assert len(results) == num_requests
assert all(completion is not None for completion in results)
# Check request balancing via Prometheus metrics if DP_SIZE > 1
check_request_balancing(server)
@pytest.mark.asyncio
@pytest.mark.parametrize(
......@@ -104,6 +226,7 @@ async def test_single_completion(client: openai.AsyncOpenAI,
[MODEL_NAME],
)
async def test_completion_streaming(client: openai.AsyncOpenAI,
server: RemoteOpenAIServer,
model_name: str) -> None:
prompt = "What is an LLM?"
......@@ -170,3 +293,6 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
results
) == num_requests, f"Expected {num_requests} results, got {len(results)}"
assert all(results), "Not all streaming requests completed successfully."
# Check request balancing via Prometheus metrics if DP_SIZE > 1
check_request_balancing(server)
......@@ -196,8 +196,7 @@ async def stream_service_response(client_info: dict, endpoint: str,
yield chunk
@app.post("/v1/completions")
async def handle_completions(request: Request):
async def _handle_completions(api: str, request: Request):
try:
req_data = await request.json()
request_id = str(uuid.uuid4())
......@@ -206,9 +205,8 @@ async def handle_completions(request: Request):
prefill_client_info = get_next_client(request.app, 'prefill')
# Send request to prefill service
response = await send_request_to_service(prefill_client_info,
"/completions", req_data,
request_id)
response = await send_request_to_service(prefill_client_info, api,
req_data, request_id)
# Extract the needed fields
response_json = response.json()
......@@ -224,7 +222,7 @@ async def handle_completions(request: Request):
# Stream response from decode service
async def generate_stream():
async for chunk in stream_service_response(decode_client_info,
"/completions",
api,
req_data,
request_id=request_id):
yield chunk
......@@ -237,12 +235,22 @@ async def handle_completions(request: Request):
import traceback
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server"
" - completions endpoint")
f" - {api} endpoint")
print(e)
print("".join(traceback.format_exception(*exc_info)))
raise
@app.post("/v1/completions")
async def handle_completions(request: Request):
return await _handle_completions("/completions", request)
@app.post("/v1/chat/completions")
async def handle_chat_completions(request: Request):
return await _handle_completions("/chat/completions", request)
@app.get("/healthcheck")
async def healthcheck():
"""Simple endpoint to check if the server is running."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment