Commit 99324e25 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.9.2' into v0.9.2-ori

parents cc7f22a8 a5dd03c1
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import asyncio import asyncio
import hashlib import hashlib
import json import json
import logging
import pickle import pickle
import socket import socket
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
...@@ -19,10 +20,11 @@ from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config ...@@ -19,10 +20,11 @@ from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache, from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
MemorySnapshot, PlaceholderModule, StoreBoolean, MemorySnapshot, PlaceholderModule, StoreBoolean,
bind_kv_cache, common_broadcastable_dtype, bind_kv_cache, common_broadcastable_dtype,
deprecate_kwargs, get_open_port, is_lossless_cast, deprecate_kwargs, get_open_port, get_tcp_uri,
make_zmq_path, make_zmq_socket, memory_profiling, is_lossless_cast, join_host_port, make_zmq_path,
merge_async_iterators, sha256, split_zmq_path, make_zmq_socket, memory_profiling,
supports_kw, swap_dict_values) merge_async_iterators, sha256, split_host_port,
split_zmq_path, supports_kw, swap_dict_values)
from .utils import create_new_process_for_each_test, error_on_warning from .utils import create_new_process_for_each_test, error_on_warning
...@@ -142,6 +144,7 @@ def parser(): ...@@ -142,6 +144,7 @@ def parser():
parser.add_argument('--batch-size', type=int) parser.add_argument('--batch-size', type=int)
parser.add_argument('--enable-feature', action='store_true') parser.add_argument('--enable-feature', action='store_true')
parser.add_argument('--hf-overrides', type=json.loads) parser.add_argument('--hf-overrides', type=json.loads)
parser.add_argument('-O', '--compilation-config', type=json.loads)
return parser return parser
...@@ -265,6 +268,11 @@ def test_dict_args(parser): ...@@ -265,6 +268,11 @@ def test_dict_args(parser):
"val2", "val2",
"--hf-overrides.key2.key4", "--hf-overrides.key2.key4",
"val3", "val3",
# Test compile config and compilation level
"-O.use_inductor=true",
"-O.backend",
"custom",
"-O1",
# Test = sign # Test = sign
"--hf-overrides.key5=val4", "--hf-overrides.key5=val4",
# Test underscore to dash conversion # Test underscore to dash conversion
...@@ -272,6 +280,22 @@ def test_dict_args(parser): ...@@ -272,6 +280,22 @@ def test_dict_args(parser):
"val5", "val5",
"--hf_overrides.key-7.key_8", "--hf_overrides.key-7.key_8",
"val6", "val6",
# Test data type detection
"--hf_overrides.key9",
"100",
"--hf_overrides.key10",
"100.0",
"--hf_overrides.key11",
"true",
"--hf_overrides.key12.key13",
"null",
# Test '-' and '.' in value
"--hf_overrides.key14.key15",
"-minus.and.dot",
# Test array values
"-O.custom_ops+",
"-quant_fp8",
"-O.custom_ops+=+silu_mul,-rms_norm",
] ]
parsed_args = parser.parse_args(args) parsed_args = parser.parse_args(args)
assert parsed_args.model_name == "something.something" assert parsed_args.model_name == "something.something"
...@@ -286,7 +310,46 @@ def test_dict_args(parser): ...@@ -286,7 +310,46 @@ def test_dict_args(parser):
"key-7": { "key-7": {
"key_8": "val6", "key_8": "val6",
}, },
"key9": 100,
"key10": 100.0,
"key11": True,
"key12": {
"key13": None,
},
"key14": {
"key15": "-minus.and.dot",
}
} }
assert parsed_args.compilation_config == {
"level": 1,
"use_inductor": True,
"backend": "custom",
"custom_ops": ["-quant_fp8", "+silu_mul", "-rms_norm"],
}
def test_duplicate_dict_args(caplog_vllm, parser):
    """Repeated dict-style flags keep the last value and log one record."""
    cli_args = [
        "--model-name=something.something",
        "--hf-overrides.key1",
        "val1",
        "--hf-overrides.key1",
        "val2",
        "-O1",
        "-O.level",
        "2",
        "-O3",
    ]
    namespace = parser.parse_args(cli_args)

    # The final occurrence of every duplicated key is the one that is kept.
    assert namespace.hf_overrides == {"key1": "val2"}
    assert namespace.compilation_config == {"level": 3}

    # Exactly one log record is expected, and it names both duplicated keys.
    assert len(caplog_vllm.records) == 1
    for fragment in ("duplicate", "--hf-overrides.key1", "-O.level"):
        assert fragment in caplog_vllm.text
# yapf: enable # yapf: enable
...@@ -814,3 +877,44 @@ def test_make_zmq_socket_ipv6(): ...@@ -814,3 +877,44 @@ def test_make_zmq_socket_ipv6():
def test_make_zmq_path(): def test_make_zmq_path():
assert make_zmq_path("tcp", "127.0.0.1", "5555") == "tcp://127.0.0.1:5555" assert make_zmq_path("tcp", "127.0.0.1", "5555") == "tcp://127.0.0.1:5555"
assert make_zmq_path("tcp", "::1", "5555") == "tcp://[::1]:5555" assert make_zmq_path("tcp", "::1", "5555") == "tcp://[::1]:5555"
def test_get_tcp_uri():
    """get_tcp_uri yields a tcp:// URI and brackets the IPv6 host."""
    expected_uri_by_host = {
        "127.0.0.1": "tcp://127.0.0.1:5555",
        "::1": "tcp://[::1]:5555",
    }
    for host, expected_uri in expected_uri_by_host.items():
        assert get_tcp_uri(host, 5555) == expected_uri
def test_split_host_port():
    """Check split_host_port on well-formed and malformed host:port strings."""
    # Valid IPv4.
    assert split_host_port("127.0.0.1:5555") == ("127.0.0.1", 5555)

    # Invalid IPv4 inputs, each expected to raise ValueError.
    malformed_ipv4 = (
        "127.0.0.1::5555",  # multiple colons
        "127.0.0.1:5555:",  # trailing colon
        "127.0.0.15555",  # no colon
        "127.0.0.1:5555a",  # non-integer port
    )
    for address in malformed_ipv4:
        with pytest.raises(ValueError):
            split_host_port(address)

    # Valid IPv6: the brackets are not part of the returned host.
    assert split_host_port("[::1]:5555") == ("::1", 5555)

    # Invalid IPv6 inputs, paired with the exception type each one raises.
    malformed_ipv6 = (
        ("[::1]::5555", ValueError),  # multiple colons
        ("[::1]5555", IndexError),  # no colon
        ("[::1]:5555a", ValueError),  # non-integer port
    )
    for address, expected_exception in malformed_ipv6:
        with pytest.raises(expected_exception):
            split_host_port(address)
def test_join_host_port():
    """join_host_port joins host and port, bracketing the IPv6 host."""
    cases = (
        ("127.0.0.1", "127.0.0.1:5555"),
        ("::1", "[::1]:5555"),
    )
    for host, expected in cases:
        assert join_host_port(host, 5555) == expected
...@@ -68,6 +68,7 @@ def _run_incremental_decode(tokenizer, ...@@ -68,6 +68,7 @@ def _run_incremental_decode(tokenizer,
None, None,
params, params,
None, None,
None,
0.0, 0.0,
None, None,
cache_salt=None, cache_salt=None,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import json
import pytest
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
from vllm.entrypoints.openai.tool_parsers import MinimaxToolParser
from vllm.transformers_utils.tokenizer import get_tokenizer
# Use a common model that is likely to be available
MODEL = "MiniMaxAi/MiniMax-M1-40k"
@pytest.fixture(scope="module")
def minimax_tokenizer():
    """Module-scoped fixture returning the tokenizer for ``MODEL``."""
    tokenizer = get_tokenizer(tokenizer_name=MODEL)
    return tokenizer
@pytest.fixture
def minimax_tool_parser(minimax_tokenizer):
    """Per-test fixture: a new MinimaxToolParser built on the tokenizer."""
    parser = MinimaxToolParser(minimax_tokenizer)
    return parser
def assert_tool_calls(actual_tool_calls: list[ToolCall],
                      expected_tool_calls: list[ToolCall]):
    """Assert that two tool-call lists match pairwise.

    The lists must have equal length. For each pair, the actual call must be
    of type ``"function"`` and its ``function`` must equal the expected one.
    The id is only checked for shape (a string longer than 16 characters);
    it is never compared with the expected call's id.
    """
    assert len(actual_tool_calls) == len(expected_tool_calls)

    for position, actual in enumerate(actual_tool_calls):
        expected = expected_tool_calls[position]
        assert isinstance(actual.id, str)
        assert len(actual.id) > 16
        assert actual.type == "function"
        assert actual.function == expected.function
def test_extract_tool_calls_no_tools(minimax_tool_parser):
    """Plain text without tool-call markup comes back unchanged as content."""
    plain_text = "This is a test"
    extraction = minimax_tool_parser.extract_tool_calls(
        plain_text, request=None)  # type: ignore[arg-type]

    assert not extraction.tools_called
    assert extraction.tool_calls == []
    assert extraction.content == plain_text
def _minimax_weather_call(**arguments) -> ToolCall:
    """Build the expected ``get_current_weather`` call.

    The keyword arguments are JSON-encoded, in the order given, into the
    call's ``arguments`` string.
    """
    return ToolCall(function=FunctionCall(
        name="get_current_weather",
        arguments=json.dumps(arguments),
    ))


@pytest.mark.parametrize(
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        # One call inside a closed <tool_calls> block.
        (
            """<tool_calls>
{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}
</tool_calls>""",
            [
                _minimax_weather_call(city="Dallas",
                                      state="TX",
                                      unit="fahrenheit"),
            ],
            None,
        ),
        # Two calls, one JSON object per line.
        (
            """<tool_calls>
{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}
{"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}
</tool_calls>""",
            [
                _minimax_weather_call(city="Dallas",
                                      state="TX",
                                      unit="fahrenheit"),
                _minimax_weather_call(city="Orlando",
                                      state="FL",
                                      unit="fahrenheit"),
            ],
            None,
        ),
        # Free text before the block is expected back as content.
        (
            """I'll help you check the weather. <tool_calls>
{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}
</tool_calls>""",
            [
                _minimax_weather_call(city="Seattle",
                                      state="WA",
                                      unit="celsius"),
            ],
            "I'll help you check the weather.",
        ),
        # Single-line JSON payload.
        (
            """<tool_calls>
{"name": "get_current_weather", "arguments": {"city": "New York", "state": "NY", "unit": "celsius"}}
</tool_calls>""",
            [
                _minimax_weather_call(city="New York",
                                      state="NY",
                                      unit="celsius"),
            ],
            None,
        ),
        # The closing </tool_calls> tag is missing.
        (
            """<tool_calls>
{"name": "get_current_weather", "arguments": {"city": "Boston", "state": "MA"}}""",
            [
                _minimax_weather_call(city="Boston", state="MA"),
            ],
            None,
        ),
    ],
    ids=[
        "single_tool_call",
        "multiple_tool_calls",
        "tool_call_with_content_before",
        "tool_call_with_single_line_json",
        "tool_call_incomplete_tag",
    ],
)
def test_extract_tool_calls(minimax_tool_parser, model_output,
                            expected_tool_calls, expected_content):
    """Each parametrized output yields the expected calls and content."""
    extraction = minimax_tool_parser.extract_tool_calls(
        model_output, request=None)  # type: ignore[arg-type]

    assert extraction.tools_called
    assert_tool_calls(extraction.tool_calls, expected_tool_calls)
    assert extraction.content == expected_content
def test_preprocess_model_output_with_thinking_tags(minimax_tool_parser):
    """Preprocessing drops tool calls nested inside ``<think>`` tags."""
    raw_output = """<think>Let me think about this. <tool_calls>
{"name": "fake_tool", "arguments": {"param": "value"}}
</tool_calls> This should be removed.</think>
I'll help you with that. <tool_calls>
{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA"}}
</tool_calls>"""

    cleaned = minimax_tool_parser.preprocess_model_output(raw_output)

    # The call nested in the thinking section must be gone ...
    assert "fake_tool" not in cleaned
    # ... while the thinking tags themselves and the call outside them stay.
    for kept_fragment in ("<think>", "</think>", "get_current_weather"):
        assert kept_fragment in cleaned
def test_extract_tool_calls_with_thinking_tags(minimax_tool_parser):
    """Only the tool call outside the ``<think>`` section is extracted."""
    raw_output = """<think>I should use a tool. <tool_calls>
{"name": "ignored_tool", "arguments": {"should": "ignore"}}
</tool_calls></think>
Let me help you with the weather. <tool_calls>
{"name": "get_current_weather", "arguments": {"city": "Miami", "state": "FL", "unit": "fahrenheit"}}
</tool_calls>"""

    extraction = minimax_tool_parser.extract_tool_calls(
        raw_output, request=None)  # type: ignore[arg-type]

    assert extraction.tools_called
    assert len(extraction.tool_calls) == 1
    only_call = extraction.tool_calls[0]
    assert only_call.function.name == "get_current_weather"

    # The expected content is everything in the original output that comes
    # before the <tool_calls> block outside the thinking section; the
    # thinking section, including its nested block, is part of it.
    expected_content = """<think>I should use a tool. <tool_calls>
{"name": "ignored_tool", "arguments": {"should": "ignore"}}
</tool_calls></think>
Let me help you with the weather."""
    assert extraction.content == expected_content
def test_extract_tool_calls_invalid_json(minimax_tool_parser):
    """A malformed JSON line is skipped; the valid calls are still returned."""
    raw_output = """<tool_calls>
{"name": "valid_tool", "arguments": {"city": "Seattle"}}
{invalid json here}
{"name": "another_valid_tool", "arguments": {"param": "value"}}
</tool_calls>"""

    extraction = minimax_tool_parser.extract_tool_calls(
        raw_output, request=None)  # type: ignore[arg-type]

    assert extraction.tools_called
    # Only the two well-formed lines should produce tool calls.
    assert len(extraction.tool_calls) == 2
    first_call, second_call = extraction.tool_calls
    assert first_call.function.name == "valid_tool"
    assert second_call.function.name == "another_valid_tool"
def test_extract_tool_calls_missing_name_or_arguments(minimax_tool_parser):
    """Entries lacking ``name`` or ``arguments`` are dropped from the result."""
    raw_output = """<tool_calls>
{"name": "valid_tool", "arguments": {"city": "Seattle"}}
{"name": "missing_args"}
{"arguments": {"city": "Portland"}}
{"name": "another_valid_tool", "arguments": {"param": "value"}}
</tool_calls>"""

    extraction = minimax_tool_parser.extract_tool_calls(
        raw_output, request=None)  # type: ignore[arg-type]

    assert extraction.tools_called
    # Only entries carrying both a name and arguments should survive.
    assert len(extraction.tool_calls) == 2
    first_call, second_call = extraction.tool_calls
    assert first_call.function.name == "valid_tool"
    assert second_call.function.name == "another_valid_tool"
def test_streaming_basic_functionality(minimax_tool_parser):
    """Smoke-test streaming extraction on one complete tool-call block."""
    parser = minimax_tool_parser

    # Start from a clean streaming state.
    parser.current_tool_name_sent = False
    parser.prev_tool_call_arr = []
    parser.current_tool_id = -1
    parser.streamed_args_for_tool = []

    # The whole block is already present; the delta is just the closing tag.
    full_text = """<tool_calls>
{"name": "get_current_weather", "arguments": {"city": "Seattle"}}
</tool_calls>"""

    delta = parser.extract_tool_calls_streaming(
        previous_text="",
        current_text=full_text,
        delta_text="</tool_calls>",
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    # The result may be None or may carry tool calls, so nothing is required
    # of it here. NOTE(review): the check below can never fail (a length is
    # always >= 0); the test only verifies that the call does not raise.
    if delta is None or not hasattr(delta, 'tool_calls'):
        return
    if delta.tool_calls:
        assert len(delta.tool_calls) >= 0
def test_streaming_with_content_before_tool_calls(minimax_tool_parser):
    """Stream a delta that carries plain text followed by ``<tool_calls>``."""
    parser = minimax_tool_parser

    # Start from a clean streaming state.
    parser.current_tool_name_sent = False
    parser.prev_tool_call_arr = []
    parser.current_tool_id = -1
    parser.streamed_args_for_tool = []

    text_so_far = "I'll help you with that. <tool_calls>"

    delta = parser.extract_tool_calls_streaming(
        previous_text="I'll help you",
        current_text=text_so_far,
        delta_text=" with that. <tool_calls>",
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    # Text ahead of the tool-call block is expected back as content; the
    # check applies only when a result exposing ``content`` is returned.
    if delta is None or not hasattr(delta, 'content'):
        return
    assert delta.content is not None
def test_streaming_no_tool_calls(minimax_tool_parser):
    """Without tool-call markup the streamed delta is returned as content."""
    new_chunk = " without any tool calls."
    text_so_far = "This is just regular text without any tool calls."

    delta = minimax_tool_parser.extract_tool_calls_streaming(
        previous_text="This is just regular text",
        current_text=text_so_far,
        delta_text=" without any tool calls.",
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    # The delta text must come back verbatim as content.
    assert delta is not None
    assert hasattr(delta, 'content')
    assert delta.content == new_chunk
def test_streaming_with_thinking_tags(minimax_tool_parser):
    """A tool call nested in ``<think>`` must not be streamed out."""
    parser = minimax_tool_parser

    # Start from a clean streaming state.
    parser.current_tool_name_sent = False
    parser.prev_tool_call_arr = []
    parser.current_tool_id = -1
    parser.streamed_args_for_tool = []

    full_text = """<think><tool_calls>{"name": "ignored", "arguments": {}}</tool_calls></think><tool_calls>{"name": "real_tool", "arguments": {"param": "value"}}</tool_calls>"""

    # The entire text arrives as a single delta.
    delta = parser.extract_tool_calls_streaming(
        previous_text="",
        current_text=full_text,
        delta_text=full_text,
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    # Nothing is asserted unless the result actually carries tool calls.
    if delta is None or not hasattr(delta, 'tool_calls'):
        return
    if not delta.tool_calls:
        return
    # None of the emitted calls may be the one from the thinking section.
    assert all(call.function.name != "ignored" for call in delta.tool_calls)
def test_extract_tool_calls_multiline_json_not_supported(minimax_tool_parser):
    """A JSON object spread over several lines yields no tool calls."""
    raw_output = """<tool_calls>
{
"name": "get_current_weather",
"arguments": {
"city": "New York",
"state": "NY",
"unit": "celsius"
}
}
</tool_calls>"""

    extraction = minimax_tool_parser.extract_tool_calls(
        raw_output, request=None)  # type: ignore[arg-type]

    # Multiline JSON is not supported at present: no tools, no content.
    assert not extraction.tools_called
    assert extraction.tool_calls == []
    assert extraction.content is None
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import pytest
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
from vllm.entrypoints.openai.tool_parsers import xLAMToolParser
from vllm.transformers_utils.tokenizer import get_tokenizer
# Use a common model that is likely to be available
MODEL = "Salesforce/Llama-xLAM-2-8B-fc-r"
@pytest.fixture(scope="module")
def xlam_tokenizer():
    """Module-scoped fixture returning the tokenizer for ``MODEL``."""
    tokenizer = get_tokenizer(tokenizer_name=MODEL)
    return tokenizer
@pytest.fixture
def xlam_tool_parser(xlam_tokenizer):
    """Per-test fixture: a new xLAMToolParser built on the tokenizer."""
    parser = xLAMToolParser(xlam_tokenizer)
    return parser
def assert_tool_calls(actual_tool_calls: list[ToolCall],
                      expected_tool_calls: list[ToolCall]):
    """Assert that two tool-call lists match pairwise.

    The lists must have equal length. For each pair, the actual call must be
    of type ``"function"`` and its ``function`` must equal the expected one.
    The id is only checked for shape (a string longer than 16 characters);
    it is never compared with the expected call's id.
    """
    assert len(actual_tool_calls) == len(expected_tool_calls)

    for position, actual in enumerate(actual_tool_calls):
        expected = expected_tool_calls[position]
        assert isinstance(actual.id, str)
        assert len(actual.id) > 16
        assert actual.type == "function"
        assert actual.function == expected.function
def test_extract_tool_calls_no_tools(xlam_tool_parser):
    """Plain text without tool-call markup comes back unchanged as content."""
    plain_text = "This is a test"
    extraction = xlam_tool_parser.extract_tool_calls(
        plain_text, request=None)  # type: ignore[arg-type]

    assert not extraction.tools_called
    assert extraction.tool_calls == []
    assert extraction.content == plain_text
def _xlam_weather_call(**arguments) -> ToolCall:
    """Build the expected ``get_current_weather`` call.

    The keyword arguments are JSON-encoded, in the order given, into the
    call's ``arguments`` string.
    """
    return ToolCall(function=FunctionCall(
        name="get_current_weather",
        arguments=json.dumps(arguments),
    ))


@pytest.mark.parametrize(
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        # Bare JSON list holding two calls.
        (
            """[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                _xlam_weather_call(city="Dallas",
                                   state="TX",
                                   unit="fahrenheit"),
                _xlam_weather_call(city="Orlando",
                                   state="FL",
                                   unit="fahrenheit"),
            ],
            None,
        ),
        # A <think> section ahead of the list is expected back as content.
        (
            """<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                _xlam_weather_call(city="Dallas",
                                   state="TX",
                                   unit="fahrenheit"),
            ],
            "<think>I'll help you with that.</think>",
        ),
        # List wrapped in a fenced json code block after some text.
        (
            """I'll help you with that.\n```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""",  # noqa: E501
            [
                _xlam_weather_call(city="Dallas",
                                   state="TX",
                                   unit="fahrenheit"),
            ],
            "I'll help you with that.",
        ),
        # List introduced by a [TOOL_CALLS] marker after some text.
        (
            """I'll check the weather for you.[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                _xlam_weather_call(city="Dallas",
                                   state="TX",
                                   unit="fahrenheit"),
            ],
            "I'll check the weather for you.",
        ),
    ],
    ids=[
        "parallel_tool_calls",
        "single_tool_with_think_tag",
        "single_tool_with_json_code_block",
        "single_tool_with_tool_calls_tag",
    ],
)
def test_extract_tool_calls(xlam_tool_parser, model_output,
                            expected_tool_calls, expected_content):
    """Each parametrized output yields the expected calls and content."""
    extraction = xlam_tool_parser.extract_tool_calls(
        model_output, request=None)  # type: ignore[arg-type]

    assert extraction.tools_called
    assert_tool_calls(extraction.tool_calls, expected_tool_calls)
    assert extraction.content == expected_content
@pytest.mark.parametrize(
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """[{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}]""",  # noqa: E501
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps(
                        dict(city="Seattle", state="WA", unit="celsius")),
                )),
            ],
            None,
        ),
    ],
    ids=["list_structured_tool_call"],
)
def test_extract_tool_calls_list_structure(xlam_tool_parser, model_output,
                                           expected_tool_calls,
                                           expected_content):
    """A bare JSON list with one call is extracted, leaving no content."""
    extraction = xlam_tool_parser.extract_tool_calls(
        model_output, request=None)  # type: ignore[arg-type]

    assert extraction.tools_called
    assert_tool_calls(extraction.tool_calls, expected_tool_calls)
    assert extraction.content == expected_content
# Exercises xLAMToolParser.preprocess_model_output directly.
def test_preprocess_model_output(xlam_tool_parser):
    """preprocess_model_output splits output into (content, tool-call text)."""
    preprocess = xlam_tool_parser.preprocess_model_output

    # Bare list: no content, and the whole output is the tool-call text.
    bare_list = """[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]"""  # noqa: E501
    content, potential_tool_calls = preprocess(bare_list)
    assert content is None
    assert potential_tool_calls == bare_list

    # Thinking tag: the <think> section is content, the list is the calls.
    with_think = """<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]"""  # noqa: E501
    content, potential_tool_calls = preprocess(with_think)
    assert content == "<think>I'll help you with that.</think>"
    assert (
        potential_tool_calls ==
        '[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]')

    # Fenced json code block: the text before the fence is the content.
    with_code_block = """I'll help you with that.
```json
[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]
```"""
    content, potential_tool_calls = preprocess(with_code_block)
    assert content == "I'll help you with that."
    assert "get_current_weather" in potential_tool_calls

    # No tool calls: everything is content and there is no tool-call text.
    text_only = """I'll help you with that."""
    content, potential_tool_calls = preprocess(text_only)
    assert content == text_only
    assert potential_tool_calls is None
# Drives extract_tool_calls_streaming through a simulated stream.
def test_streaming_with_list_structure(xlam_tool_parser):
    """Stream a complete JSON list, then check the emitted tool call."""
    parser = xlam_tool_parser

    # Start from a clean streaming state.
    parser.prev_tool_calls = []
    parser.current_tools_sent = []
    parser.streamed_args = []
    parser.current_tool_id = -1

    # The message is a list holding a single call.
    full_text = """[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]"""  # noqa: E501

    # First pass: the closing bracket arrives and the tool is set up.
    parser.extract_tool_calls_streaming(
        previous_text="",
        current_text=full_text,
        delta_text="]",
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    assert (parser.current_tool_id
            >= 0), "Tool index should be initialized"

    # Mark the tool as not yet sent so the next pass emits its name.
    parser.current_tools_sent = [False]

    # Second pass: no new text; the function name should be sent.
    delta = parser.extract_tool_calls_streaming(
        previous_text=full_text,
        current_text=full_text,
        delta_text="",
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    # A None result is tolerated; otherwise it must carry exactly our call.
    if delta is None:
        return
    assert hasattr(delta, "tool_calls")
    assert len(delta.tool_calls) == 1
    assert delta.tool_calls[0].function.name == "get_current_weather"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import pytest
from tools.validate_config import validate_ast
# Source snippets handed to ``ast.parse`` and then ``validate_ast`` by
# ``test_config``. Each one must be syntactically valid Python, so the class
# bodies are indented; without that ``ast.parse`` raises IndentationError
# before validation runs.

# Decorated with @config only (no @dataclass).
_TestConfig1 = '''
@config
class _TestConfig1:
    pass
'''

# Field with a docstring but no default value.
_TestConfig2 = '''
@config
@dataclass
class _TestConfig2:
    a: int
    """docstring"""
'''

# Field with a default value but no docstring.
_TestConfig3 = '''
@config
@dataclass
class _TestConfig3:
    a: int = 1
'''

# Field annotated as a Union of two separate Literal types.
_TestConfig4 = '''
@config
@dataclass
class _TestConfig4:
    a: Union[Literal[1], Literal[2]] = 1
    """docstring"""
'''
@pytest.mark.parametrize(
    ("test_config", "expected_error"),
    [
        (_TestConfig1, "must be a dataclass"),
        (_TestConfig2, "must have a default"),
        (_TestConfig3, "must have a docstring"),
        (_TestConfig4, "must use a single Literal"),
    ],
)
def test_config(test_config, expected_error):
    """validate_ast raises with the expected message for each bad config.

    ``test_config`` is Python source text; it is parsed into an AST and the
    raised exception's message is matched against ``expected_error``.
    """
    parsed_module = ast.parse(test_config)
    with pytest.raises(Exception, match=expected_error):
        validate_ast(parsed_module)
...@@ -667,42 +667,54 @@ def get_physical_device_indices(devices): ...@@ -667,42 +667,54 @@ def get_physical_device_indices(devices):
@_nvml() @_nvml()
def wait_for_gpu_memory_to_clear(devices: list[int], def wait_for_gpu_memory_to_clear(*,
threshold_bytes: int, devices: list[int],
threshold_bytes: Optional[int] = None,
threshold_ratio: Optional[float] = None,
timeout_s: float = 120) -> None: timeout_s: float = 120) -> None:
assert threshold_bytes is not None or threshold_ratio is not None
# Use nvml instead of pytorch to reduce measurement error from torch cuda # Use nvml instead of pytorch to reduce measurement error from torch cuda
# context. # context.
devices = get_physical_device_indices(devices) devices = get_physical_device_indices(devices)
start_time = time.time() start_time = time.time()
while True: while True:
output: dict[int, str] = {} output: dict[int, str] = {}
output_raw: dict[int, float] = {} output_raw: dict[int, tuple[float, float]] = {}
for device in devices: for device in devices:
if current_platform.is_rocm(): if current_platform.is_rocm():
dev_handle = amdsmi_get_processor_handles()[device] dev_handle = amdsmi_get_processor_handles()[device]
mem_info = amdsmi_get_gpu_vram_usage(dev_handle) mem_info = amdsmi_get_gpu_vram_usage(dev_handle)
gb_used = mem_info["vram_used"] / 2**10 gb_used = mem_info["vram_used"] / 2**10
gb_total = mem_info["vram_total"] / 2**10
else: else:
dev_handle = nvmlDeviceGetHandleByIndex(device) dev_handle = nvmlDeviceGetHandleByIndex(device)
mem_info = nvmlDeviceGetMemoryInfo(dev_handle) mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
gb_used = mem_info.used / 2**30 gb_used = mem_info.used / 2**30
output_raw[device] = gb_used gb_total = mem_info.total / 2**30
output[device] = f'{gb_used:.02f}' output_raw[device] = (gb_used, gb_total)
output[device] = f'{gb_used:.02f}/{gb_total:.02f}'
print('gpu memory used (GB): ', end='') print('gpu memory used/total (GiB): ', end='')
for k, v in output.items(): for k, v in output.items():
print(f'{k}={v}; ', end='') print(f'{k}={v}; ', end='')
print('') print('')
if threshold_bytes is not None:
is_free = lambda used, total: used <= threshold_bytes / 2**30
threshold = f"{threshold_bytes/2**30} GiB"
else:
is_free = lambda used, total: used / total <= threshold_ratio
threshold = f"{threshold_ratio:.2f}"
dur_s = time.time() - start_time dur_s = time.time() - start_time
if all(v <= (threshold_bytes / 2**30) for v in output_raw.values()): if all(is_free(used, total) for used, total in output_raw.values()):
print(f'Done waiting for free GPU memory on devices {devices=} ' print(f'Done waiting for free GPU memory on devices {devices=} '
f'({threshold_bytes/2**30=}) {dur_s=:.02f}') f'({threshold=}) {dur_s=:.02f}')
break break
if dur_s >= timeout_s: if dur_s >= timeout_s:
raise ValueError(f'Memory of devices {devices=} not free after ' raise ValueError(f'Memory of devices {devices=} not free after '
f'{dur_s=:.02f} ({threshold_bytes/2**30=})') f'{dur_s=:.02f} ({threshold=})')
time.sleep(5) time.sleep(5)
......
...@@ -43,6 +43,7 @@ def make_request(request_id, ...@@ -43,6 +43,7 @@ def make_request(request_id,
multi_modal_hashes=mm_hashes, multi_modal_hashes=mm_hashes,
multi_modal_placeholders=mm_positions, multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams(max_tokens=17), sampling_params=SamplingParams(max_tokens=17),
pooling_params=None,
eos_token_id=100, eos_token_id=100,
lora_request=None, lora_request=None,
cache_salt=cache_salt, cache_salt=cache_salt,
...@@ -900,3 +901,19 @@ def test_get_kv_cache_config(): ...@@ -900,3 +901,19 @@ def test_get_kv_cache_config():
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
get_kv_cache_config(vllm_config, kv_cache_specs_hybrid, get_kv_cache_config(vllm_config, kv_cache_specs_hybrid,
mem_per_block_per_layer * 2 * 32) mem_per_block_per_layer * 2 * 32)
# Test num_gpu_blocks_override
vllm_config.cache_config.num_gpu_blocks_override = 16
kv_cache_config_override_blocks = get_kv_cache_config(
vllm_config, kv_cache_specs_full, mem_per_block_per_layer * 2 * 32)
assert kv_cache_config_override_blocks == KVCacheConfig(
num_blocks=16,
kv_cache_tensors=[
KVCacheTensor(size=mem_per_block_per_layer * 16,
shared_by=["layer_1"]),
KVCacheTensor(size=mem_per_block_per_layer * 16,
shared_by=["layer_2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec())
])
\ No newline at end of file
...@@ -39,6 +39,7 @@ def make_request(request_id, ...@@ -39,6 +39,7 @@ def make_request(request_id,
multi_modal_placeholders=mm_positions, multi_modal_placeholders=mm_positions,
sampling_params=SamplingParams(max_tokens=17, sampling_params=SamplingParams(max_tokens=17,
prompt_logprobs=prompt_logprobs), prompt_logprobs=prompt_logprobs),
pooling_params=None,
eos_token_id=100, eos_token_id=100,
lora_request=None, lora_request=None,
cache_salt=cache_salt, cache_salt=cache_salt,
......
...@@ -9,14 +9,15 @@ import torch ...@@ -9,14 +9,15 @@ import torch
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig) SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec) KVCacheGroupSpec)
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.structured_output.request import StructuredOutputRequest
EOS_TOKEN_ID = 50256 EOS_TOKEN_ID = 50256
...@@ -33,6 +34,7 @@ def create_scheduler( ...@@ -33,6 +34,7 @@ def create_scheduler(
block_size: int = 16, block_size: int = 16,
max_model_len: Optional[int] = None, max_model_len: Optional[int] = None,
num_speculative_tokens: Optional[int] = None, num_speculative_tokens: Optional[int] = None,
skip_tokenizer_init: bool = False,
) -> Scheduler: ) -> Scheduler:
'''Create scheduler under test. '''Create scheduler under test.
...@@ -65,6 +67,7 @@ def create_scheduler( ...@@ -65,6 +67,7 @@ def create_scheduler(
trust_remote_code=True, trust_remote_code=True,
dtype="float16", dtype="float16",
seed=42, seed=42,
skip_tokenizer_init=skip_tokenizer_init,
) )
# Cache config, optionally force APC # Cache config, optionally force APC
kwargs_cache = ({} if enable_prefix_caching is None else { kwargs_cache = ({} if enable_prefix_caching is None else {
...@@ -135,6 +138,7 @@ def create_requests(num_requests: int, ...@@ -135,6 +138,7 @@ def create_requests(num_requests: int,
request_id=f"{i}", request_id=f"{i}",
prompt_token_ids=[i] * num_tokens, prompt_token_ids=[i] * num_tokens,
sampling_params=sampling_params, sampling_params=sampling_params,
pooling_params=None,
multi_modal_inputs=mm_inputs, multi_modal_inputs=mm_inputs,
multi_modal_placeholders=mm_position, multi_modal_placeholders=mm_position,
multi_modal_hashes=None, multi_modal_hashes=None,
...@@ -185,7 +189,7 @@ def test_get_num_unfinished_requests(): ...@@ -185,7 +189,7 @@ def test_get_num_unfinished_requests():
]) ])
def test_schedule(enable_prefix_caching: Optional[bool], def test_schedule(enable_prefix_caching: Optional[bool],
prompt_logprobs: Optional[int]): prompt_logprobs: Optional[int]):
'''Test scheduling. '''Test scheduling.
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
''' '''
scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching) scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching)
...@@ -197,7 +201,7 @@ def test_schedule(enable_prefix_caching: Optional[bool], ...@@ -197,7 +201,7 @@ def test_schedule(enable_prefix_caching: Optional[bool],
# Test initial scheduling # Test initial scheduling
output = scheduler.schedule() output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests) assert len(output.scheduled_new_reqs) == len(requests)
assert len(output.scheduled_cached_reqs) == 0 assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0 assert len(output.finished_req_ids) == 0
# Verify all requests are scheduled. # Verify all requests are scheduled.
for req_id, num_tokens in output.num_scheduled_tokens.items(): for req_id, num_tokens in output.num_scheduled_tokens.items():
...@@ -224,7 +228,7 @@ def test_schedule_multimodal_requests(): ...@@ -224,7 +228,7 @@ def test_schedule_multimodal_requests():
output = scheduler.schedule() output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests) assert len(output.scheduled_new_reqs) == len(requests)
assert len(output.scheduled_cached_reqs) == 0 assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0 assert len(output.finished_req_ids) == 0
for req_id, num_tokens in output.num_scheduled_tokens.items(): for req_id, num_tokens in output.num_scheduled_tokens.items():
assert num_tokens == len(requests[int(req_id)].prompt_token_ids) assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
...@@ -258,7 +262,7 @@ def test_schedule_partial_requests(): ...@@ -258,7 +262,7 @@ def test_schedule_partial_requests():
output = scheduler.schedule() output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 3 assert len(output.scheduled_new_reqs) == 3
assert len(output.scheduled_cached_reqs) == 0 assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0 assert len(output.finished_req_ids) == 0
assert scheduler.max_num_encoder_input_tokens == 1024 assert scheduler.max_num_encoder_input_tokens == 1024
...@@ -283,6 +287,7 @@ def test_schedule_partial_requests(): ...@@ -283,6 +287,7 @@ def test_schedule_partial_requests():
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=[],
) )
scheduler.update_from_output(output, model_runner_output) scheduler.update_from_output(output, model_runner_output)
...@@ -293,7 +298,7 @@ def test_schedule_partial_requests(): ...@@ -293,7 +298,7 @@ def test_schedule_partial_requests():
output = scheduler.schedule() output = scheduler.schedule()
assert len(scheduler.running) == 3 assert len(scheduler.running) == 3
assert len(output.scheduled_new_reqs) == 0 assert len(output.scheduled_new_reqs) == 0
assert len(output.scheduled_cached_reqs) == 2 assert output.scheduled_cached_reqs.num_reqs == 2
assert len(output.finished_req_ids) == 0 assert len(output.finished_req_ids) == 0
assert output.num_scheduled_tokens[requests[0].request_id] == 1 assert output.num_scheduled_tokens[requests[0].request_id] == 1
assert output.num_scheduled_tokens[requests[1].request_id] == 700 assert output.num_scheduled_tokens[requests[1].request_id] == 700
...@@ -317,7 +322,7 @@ def test_no_mm_input_chunking(): ...@@ -317,7 +322,7 @@ def test_no_mm_input_chunking():
output = scheduler.schedule() output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 1 assert len(output.scheduled_new_reqs) == 1
assert len(output.scheduled_cached_reqs) == 0 assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0 assert len(output.finished_req_ids) == 0
# We want to only see the 400 text tokens at the start scheduled # We want to only see the 400 text tokens at the start scheduled
assert output.num_scheduled_tokens[requests[0].request_id] == 400 assert output.num_scheduled_tokens[requests[0].request_id] == 400
...@@ -333,13 +338,14 @@ def test_no_mm_input_chunking(): ...@@ -333,13 +338,14 @@ def test_no_mm_input_chunking():
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=[],
) )
scheduler.update_from_output(output, model_runner_output) scheduler.update_from_output(output, model_runner_output)
output = scheduler.schedule() output = scheduler.schedule()
assert len(scheduler.running) == 1 assert len(scheduler.running) == 1
assert len(output.scheduled_new_reqs) == 0 assert len(output.scheduled_new_reqs) == 0
assert len(output.scheduled_cached_reqs) == 1 assert output.scheduled_cached_reqs.num_reqs == 1
assert len(output.finished_req_ids) == 0 assert len(output.finished_req_ids) == 0
assert output.num_scheduled_tokens[requests[0].request_id] == 800 assert output.num_scheduled_tokens[requests[0].request_id] == 800
...@@ -376,7 +382,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): ...@@ -376,7 +382,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
output = scheduler.schedule() output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 3 assert len(output.scheduled_new_reqs) == 3
assert len(output.scheduled_cached_reqs) == 0 assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0 assert len(output.finished_req_ids) == 0
# The first request is scheduled partially - 400. # The first request is scheduled partially - 400.
...@@ -396,6 +402,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): ...@@ -396,6 +402,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=[],
) )
scheduler.update_from_output(output, model_runner_output) scheduler.update_from_output(output, model_runner_output)
...@@ -404,7 +411,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): ...@@ -404,7 +411,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
output1 = scheduler.schedule() output1 = scheduler.schedule()
assert len(scheduler.running) == 3 assert len(scheduler.running) == 3
assert len(output1.scheduled_new_reqs) == 0 assert len(output1.scheduled_new_reqs) == 0
assert len(output1.scheduled_cached_reqs) == 3 assert output1.scheduled_cached_reqs.num_reqs == 3
assert len(output1.finished_req_ids) == 0 assert len(output1.finished_req_ids) == 0
assert output1.num_scheduled_tokens[requests[0].request_id] == 400 assert output1.num_scheduled_tokens[requests[0].request_id] == 400
assert output1.num_scheduled_tokens[requests[1].request_id] == 400 assert output1.num_scheduled_tokens[requests[1].request_id] == 400
...@@ -420,12 +427,13 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): ...@@ -420,12 +427,13 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=[],
) )
scheduler.update_from_output(output1, model_runner_output) scheduler.update_from_output(output1, model_runner_output)
output2 = scheduler.schedule() output2 = scheduler.schedule()
assert len(scheduler.running) == 3 assert len(scheduler.running) == 3
assert len(output2.scheduled_new_reqs) == 0 assert len(output2.scheduled_new_reqs) == 0
assert len(output2.scheduled_cached_reqs) == 3 assert output2.scheduled_cached_reqs.num_reqs == 3
assert len(output2.finished_req_ids) == 0 assert len(output2.finished_req_ids) == 0
assert output2.num_scheduled_tokens[requests[0].request_id] == 1 assert output2.num_scheduled_tokens[requests[0].request_id] == 1
assert output2.num_scheduled_tokens[requests[1].request_id] == 1 assert output2.num_scheduled_tokens[requests[1].request_id] == 1
...@@ -444,23 +452,24 @@ def test_stop_via_update_from_output(): ...@@ -444,23 +452,24 @@ def test_stop_via_update_from_output():
scheduler.requests[req.request_id] = req scheduler.requests[req.request_id] = req
scheduler.running.append(req) scheduler.running.append(req)
scheduler_output = SchedulerOutput(scheduled_new_reqs=[], scheduler_output = SchedulerOutput(
scheduled_cached_reqs=[], scheduled_new_reqs=[],
num_scheduled_tokens={ scheduled_cached_reqs=CachedRequestData.make_empty(),
requests[0].request_id: 1, num_scheduled_tokens={
requests[1].request_id: 2 requests[0].request_id: 1,
}, requests[1].request_id: 2
total_num_scheduled_tokens=3, },
scheduled_encoder_inputs={}, total_num_scheduled_tokens=3,
scheduled_spec_decode_tokens={ scheduled_encoder_inputs={},
requests[0].request_id: [], scheduled_spec_decode_tokens={
requests[1].request_id: [10] requests[0].request_id: [],
}, requests[1].request_id: [10]
num_common_prefix_blocks=0, },
finished_req_ids=set(), num_common_prefix_blocks=0,
free_encoder_input_ids=[], finished_req_ids=set(),
structured_output_request_ids={}, free_encoder_input_ids=[],
grammar_bitmask=None) structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput( model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests], req_ids=[req.request_id for req in requests],
...@@ -473,7 +482,8 @@ def test_stop_via_update_from_output(): ...@@ -473,7 +482,8 @@ def test_stop_via_update_from_output():
11]], # First request hits EOS, second continues 11]], # First request hits EOS, second continues
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}) prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output) scheduler.update_from_output(scheduler_output, model_output)
...@@ -495,23 +505,25 @@ def test_stop_via_update_from_output(): ...@@ -495,23 +505,25 @@ def test_stop_via_update_from_output():
scheduler.requests[req.request_id] = req scheduler.requests[req.request_id] = req
scheduler.running.append(req) scheduler.running.append(req)
scheduler_output = SchedulerOutput(scheduled_new_reqs=[], scheduler_output = SchedulerOutput(
scheduled_cached_reqs=[], scheduled_new_reqs=[],
num_scheduled_tokens={ scheduled_cached_reqs=CachedRequestData.make_empty(),
requests[0].request_id: 3, num_scheduled_tokens={
requests[1].request_id: 2 requests[0].request_id: 3,
}, requests[1].request_id: 2
total_num_scheduled_tokens=5, },
scheduled_encoder_inputs={}, total_num_scheduled_tokens=5,
scheduled_spec_decode_tokens={ scheduled_encoder_inputs={},
requests[0].request_id: [10, 42], scheduled_spec_decode_tokens={
requests[1].request_id: [13] requests[0].request_id: [10, 42],
}, requests[1].request_id: [13]
num_common_prefix_blocks=0, },
finished_req_ids=set(), num_common_prefix_blocks=0,
free_encoder_input_ids=[], finished_req_ids=set(),
structured_output_request_ids={}, free_encoder_input_ids=[],
grammar_bitmask=None) structured_output_request_ids={},
grammar_bitmask=None,
)
model_output = ModelRunnerOutput( model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests], req_ids=[req.request_id for req in requests],
...@@ -523,7 +535,8 @@ def test_stop_via_update_from_output(): ...@@ -523,7 +535,8 @@ def test_stop_via_update_from_output():
[13, 14]], # First request hits stop token [13, 14]], # First request hits stop token
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}) prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output) scheduler.update_from_output(scheduler_output, model_output)
...@@ -544,23 +557,25 @@ def test_stop_via_update_from_output(): ...@@ -544,23 +557,25 @@ def test_stop_via_update_from_output():
scheduler.requests[req.request_id] = req scheduler.requests[req.request_id] = req
scheduler.running.append(req) scheduler.running.append(req)
scheduler_output = SchedulerOutput(scheduled_new_reqs=[], scheduler_output = SchedulerOutput(
scheduled_cached_reqs=[], scheduled_new_reqs=[],
num_scheduled_tokens={ scheduled_cached_reqs=CachedRequestData.make_empty(),
requests[0].request_id: 3, num_scheduled_tokens={
requests[1].request_id: 1 requests[0].request_id: 3,
}, requests[1].request_id: 1
total_num_scheduled_tokens=4, },
scheduled_encoder_inputs={}, total_num_scheduled_tokens=4,
scheduled_spec_decode_tokens={ scheduled_encoder_inputs={},
requests[0].request_id: [10, 11], scheduled_spec_decode_tokens={
requests[1].request_id: [] requests[0].request_id: [10, 11],
}, requests[1].request_id: []
num_common_prefix_blocks=0, },
finished_req_ids=set(), num_common_prefix_blocks=0,
free_encoder_input_ids=[], finished_req_ids=set(),
structured_output_request_ids={}, free_encoder_input_ids=[],
grammar_bitmask=None) structured_output_request_ids={},
grammar_bitmask=None,
)
model_output = ModelRunnerOutput( model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests], req_ids=[req.request_id for req in requests],
...@@ -572,7 +587,8 @@ def test_stop_via_update_from_output(): ...@@ -572,7 +587,8 @@ def test_stop_via_update_from_output():
[13]], # First request exceeds max_tokens [13]], # First request exceeds max_tokens
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}) prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output) scheduler.update_from_output(scheduler_output, model_output)
...@@ -595,7 +611,7 @@ def test_stop_via_update_from_output(): ...@@ -595,7 +611,7 @@ def test_stop_via_update_from_output():
scheduler_output = SchedulerOutput( scheduler_output = SchedulerOutput(
scheduled_new_reqs=[], scheduled_new_reqs=[],
scheduled_cached_reqs=[], scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={requests[0].request_id: 3}, num_scheduled_tokens={requests[0].request_id: 3},
total_num_scheduled_tokens=3, total_num_scheduled_tokens=3,
scheduled_encoder_inputs={}, scheduled_encoder_inputs={},
...@@ -614,7 +630,8 @@ def test_stop_via_update_from_output(): ...@@ -614,7 +630,8 @@ def test_stop_via_update_from_output():
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]], sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}) prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output) scheduler.update_from_output(scheduler_output, model_output)
...@@ -663,6 +680,7 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool], ...@@ -663,6 +680,7 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=[],
) )
scheduler.update_from_output(scheduler_output0, model_runner_output) scheduler.update_from_output(scheduler_output0, model_runner_output)
...@@ -680,6 +698,7 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool], ...@@ -680,6 +698,7 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=[],
) )
scheduler.update_from_output(scheduler_output1, model_runner_output) scheduler.update_from_output(scheduler_output1, model_runner_output)
...@@ -730,6 +749,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): ...@@ -730,6 +749,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
spec_token_ids=spec_tokens, spec_token_ids=spec_tokens,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=[],
) )
engine_core_outputs = scheduler.update_from_output(output, engine_core_outputs = scheduler.update_from_output(output,
model_runner_output) model_runner_output)
...@@ -769,6 +789,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): ...@@ -769,6 +789,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=[],
) )
engine_core_outputs = scheduler.update_from_output(output, engine_core_outputs = scheduler.update_from_output(output,
model_runner_output) model_runner_output)
...@@ -896,6 +917,7 @@ def test_kv_connector_basic(): ...@@ -896,6 +917,7 @@ def test_kv_connector_basic():
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=[],
) )
# Ensure ScheduleOutput is correct. # Ensure ScheduleOutput is correct.
...@@ -941,6 +963,7 @@ def test_kv_connector_basic(): ...@@ -941,6 +963,7 @@ def test_kv_connector_basic():
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=[],
) )
# We should get a local cache hit of NUM_TOKENS_PREFIX and # We should get a local cache hit of NUM_TOKENS_PREFIX and
...@@ -1007,6 +1030,7 @@ def test_kv_connector_unable_to_allocate(): ...@@ -1007,6 +1030,7 @@ def test_kv_connector_unable_to_allocate():
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=[],
) )
# Just one request should be running. # Just one request should be running.
...@@ -1087,6 +1111,7 @@ def test_kv_connector_handles_preemption(): ...@@ -1087,6 +1111,7 @@ def test_kv_connector_handles_preemption():
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=[],
) )
# All can be scheduled - 1st token. # All can be scheduled - 1st token.
...@@ -1133,7 +1158,6 @@ def test_kv_connector_handles_preemption(): ...@@ -1133,7 +1158,6 @@ def test_kv_connector_handles_preemption():
assert len(scheduler.running) == 1 assert len(scheduler.running) == 1
_ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
assert len(scheduler.running) == 0 assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1
# All memory should be freed since nothing is running. # All memory should be freed since nothing is running.
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \ assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \
== NUM_BLOCKS - 1 == NUM_BLOCKS - 1
...@@ -1181,6 +1205,7 @@ def make_output(scheduler: Scheduler): ...@@ -1181,6 +1205,7 @@ def make_output(scheduler: Scheduler):
spec_token_ids=None, spec_token_ids=None,
logprobs=None, logprobs=None,
prompt_logprobs_dict={}, prompt_logprobs_dict={},
pooler_output=[],
) )
...@@ -1191,7 +1216,6 @@ def assert_scheduler_empty(scheduler: Scheduler): ...@@ -1191,7 +1216,6 @@ def assert_scheduler_empty(scheduler: Scheduler):
assert len(scheduler.waiting) == 0 assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 0 assert len(scheduler.running) == 0
assert len(scheduler.finished_req_ids) == 0 assert len(scheduler.finished_req_ids) == 0
assert len(scheduler._cached_reqs_data) == 0
# EncoderCacheManager. # EncoderCacheManager.
assert len(scheduler.encoder_cache_manager.freed) == 0 assert len(scheduler.encoder_cache_manager.freed) == 0
...@@ -1247,3 +1271,628 @@ def test_memory_leak(): ...@@ -1247,3 +1271,628 @@ def test_memory_leak():
# Confirm no memory leak. # Confirm no memory leak.
assert_scheduler_empty(scheduler) assert_scheduler_empty(scheduler)
def create_scheduler_with_priority(
    model: str = "facebook/opt-125m",
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 8192,
    enable_prefix_caching: Optional[bool] = None,
    long_prefill_token_threshold: int = 0,
    disable_chunked_mm_input: bool = False,
    use_kv_connector: bool = False,
    num_blocks: int = 10000,
    block_size: int = 16,
    max_model_len: Optional[int] = None,
    num_speculative_tokens: Optional[int] = None,
) -> Scheduler:
    '''Create a scheduler under test that uses the "priority" policy.

    Args:
      model: model under test (also used as the tokenizer)
      max_num_seqs: max sequences to schedule
      max_num_batched_tokens: max num tokens to batch
      enable_prefix_caching: optionally force APC config
                             (True/False) or use default
                             (None)
      long_prefill_token_threshold: forwarded to the scheduler config
      disable_chunked_mm_input: forwarded to the scheduler config
      use_kv_connector: if True, configure a "SharedStorageConnector"
                        KV transfer config; otherwise none is used
      num_blocks: number of KV cache blocks
      block_size: KV cache block size
      max_model_len: max model length; falls back to
                     max_num_batched_tokens when None
      num_speculative_tokens: if given, enable "ngram" speculative
                              decoding with this many tokens

    Returns:
      {class}`Scheduler` instance with priority scheduling
    '''
    effective_max_model_len = (max_num_batched_tokens
                               if max_model_len is None else max_model_len)

    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        max_model_len=effective_max_model_len,
        long_prefill_token_threshold=long_prefill_token_threshold,
        disable_chunked_mm_input=disable_chunked_mm_input,
        enable_chunked_prefill=True,
        policy="priority",  # The only scheduler setting specific to this helper
    )
    model_config = ModelConfig(
        model=model,
        task="auto",
        tokenizer=model,
        tokenizer_mode="auto",
        trust_remote_code=True,
        dtype="float16",
        seed=42,
    )

    # Only pass enable_prefix_caching when the caller forces it, so that
    # CacheConfig's own default applies otherwise.
    cache_overrides = {}
    if enable_prefix_caching is not None:
        cache_overrides['enable_prefix_caching'] = enable_prefix_caching
    cache_config = CacheConfig(
        block_size=block_size,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
        **cache_overrides,
    )

    kv_transfer_config = None
    if use_kv_connector:
        kv_transfer_config = KVTransferConfig(
            kv_connector="SharedStorageConnector",
            kv_role="kv_both",
            kv_connector_extra_config={
                "shared_storage_path": "local_storage"
            },
        )

    speculative_config: Optional[SpeculativeConfig] = (
        None if num_speculative_tokens is None else SpeculativeConfig(
            model="ngram", num_speculative_tokens=num_speculative_tokens))

    vllm_config = VllmConfig(
        scheduler_config=scheduler_config,
        model_config=model_config,
        cache_config=cache_config,
        kv_transfer_config=kv_transfer_config,
        speculative_config=speculative_config,
    )

    # Single KV cache group with one full-attention layer.
    layer_spec = FullAttentionSpec(block_size, 1, 1, torch.float32, False)
    kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,  # A large number of blocks to hold all requests
        kv_cache_tensors=[],
        kv_cache_groups=[KVCacheGroupSpec(['layer'], layer_spec)],
    )
    cache_config.num_gpu_blocks = num_blocks

    return Scheduler(
        vllm_config=vllm_config,
        kv_cache_config=kv_cache_config,
        log_stats=True,
        structured_output_manager=StructuredOutputManager(vllm_config),
    )
def create_requests_with_priority(
        num_requests: int,
        priorities: list[int],
        arrival_times: Optional[list[float]] = None,
        num_tokens: int = 10,
        mm_positions: Optional[list[list[PlaceholderRange]]] = None,
        max_tokens: int = 16,
        stop_token_ids: Optional[list[int]] = None,
        prompt_logprobs: Optional[int] = None,
        starting_idx: int = 0) -> list[Request]:
    """Create requests with specified priorities and arrival times.

    Args:
        num_requests: number of requests to create
        priorities: priority of each request; must have num_requests
            entries
        arrival_times: arrival time of each request; must have
            num_requests entries. Defaults to 0.0, 1.0, 2.0, ...
        num_tokens: prompt length of every request
        mm_positions: optional per-request list of multimodal
            placeholder ranges
        max_tokens: forwarded to SamplingParams
        stop_token_ids: forwarded to SamplingParams
        prompt_logprobs: forwarded to SamplingParams
        starting_idx: offset added to the request index when deriving
            the request id and prompt token. Pass a non-zero value when
            calling this helper more than once for the same scheduler,
            otherwise every batch reuses the ids "0", "1", ...

    Returns:
        The created requests, in index order. All of them share one
        SamplingParams instance.
    """
    assert len(priorities) == num_requests
    if arrival_times is not None:
        assert len(arrival_times) == num_requests
    else:
        arrival_times = [float(i) for i in range(num_requests)]

    sampling_params = SamplingParams(ignore_eos=False,
                                     max_tokens=max_tokens,
                                     stop_token_ids=stop_token_ids,
                                     prompt_logprobs=prompt_logprobs)
    requests = []
    for i in range(num_requests):
        if mm_positions is not None:
            mm_position = mm_positions[i]
            # One (empty) kwargs entry per placeholder of this request.
            mm_inputs = [MultiModalKwargs({})] * len(mm_position)
        else:
            mm_position = None
            mm_inputs = None
        # With the default starting_idx=0 this is exactly the request
        # index, which keeps ids and prompts unchanged for existing callers.
        req_idx = i + starting_idx
        request = Request(
            request_id=f"{req_idx}",
            prompt_token_ids=[req_idx] * num_tokens,
            sampling_params=sampling_params,
            pooling_params=None,
            multi_modal_inputs=mm_inputs,
            multi_modal_placeholders=mm_position,
            multi_modal_hashes=None,
            eos_token_id=EOS_TOKEN_ID,
            arrival_time=arrival_times[i],
            priority=priorities[i],
        )
        requests.append(request)
    return requests
def test_priority_scheduling_basic_ordering():
    """Requests must be scheduled by ascending priority value
    (lower value = higher priority)."""
    scheduler = create_scheduler_with_priority()

    # Submission order deliberately differs from priority order:
    # request "0" -> priority 2, "1" -> priority 0, "2" -> priority 1.
    # Arrival times are all distinct.
    submitted = create_requests_with_priority(
        num_requests=3,
        priorities=[2, 0, 1],
        arrival_times=[1.0, 2.0, 3.0],
    )
    for req in submitted:
        scheduler.add_request(req)

    output = scheduler.schedule()

    # All three fit in the budget, so all are scheduled in this step.
    assert len(output.scheduled_new_reqs) == 3

    # Expected order: "1" (priority 0), "2" (priority 1), "0" (priority 2).
    order = [new_req.req_id for new_req in output.scheduled_new_reqs]
    assert order == ["1", "2", "0"]
def test_priority_scheduling_arrival_time_tiebreaker():
    """With equal priorities, earlier arrival time
    must be scheduled first."""
    scheduler = create_scheduler_with_priority()

    # Same priority everywhere; requests are submitted out of arrival order:
    # request "0" arrives at 3.0, "1" at 1.0, "2" at 2.0.
    submitted = create_requests_with_priority(
        num_requests=3,
        priorities=[1, 1, 1],
        arrival_times=[3.0, 1.0, 2.0],
    )
    for req in submitted:
        scheduler.add_request(req)

    output = scheduler.schedule()

    # All three fit in the budget, so all are scheduled in this step.
    assert len(output.scheduled_new_reqs) == 3

    # Expected order by arrival: "1" (1.0), "2" (2.0), "0" (3.0).
    order = [new_req.req_id for new_req in output.scheduled_new_reqs]
    assert order == ["1", "2", "0"]
def test_priority_scheduling_mixed_priority_and_arrival():
    """Priority decides the order first; arrival time breaks ties."""
    scheduler = create_scheduler_with_priority()

    # id   priority   arrival
    # "0"     2         1.0
    # "1"     1         3.0
    # "2"     1         2.0
    # "3"     0         4.0
    submitted = create_requests_with_priority(
        num_requests=4,
        priorities=[2, 1, 1, 0],
        arrival_times=[1.0, 3.0, 2.0, 4.0],
    )
    for req in submitted:
        scheduler.add_request(req)

    output = scheduler.schedule()

    # All four fit in the budget, so all are scheduled in this step.
    assert len(output.scheduled_new_reqs) == 4

    # "3" has the best priority; "2" beats "1" on arrival time within
    # priority 1; "0" has the worst priority despite arriving first.
    order = [new_req.req_id for new_req in output.scheduled_new_reqs]
    assert order == ["3", "2", "1", "0"]
def test_priority_scheduling_preemption():
    """A high-priority request added under memory pressure must be the
    next new request scheduled, with or without preemption."""
    # NOTE(review): create_requests_with_priority numbers ids from "0" on
    # every call, so the high-priority request below shares id "0" with the
    # first low-priority request. The final req_id checks therefore cannot
    # tell the two apart, and both branches of the final conditional pass.
    # Consider giving the high-priority request a distinct id.

    # Very few blocks, to force memory pressure.
    scheduler = create_scheduler_with_priority(
        max_num_seqs=3,
        max_num_batched_tokens=200,
        num_blocks=6,
        block_size=16,
    )

    # Two low-priority requests, large enough to consume significant memory.
    low_priority_requests = create_requests_with_priority(
        num_requests=2,
        priorities=[5, 5],
        arrival_times=[1.0, 2.0],
        num_tokens=30,
    )
    for req in low_priority_requests:
        scheduler.add_request(req)

    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == 2

    # Simulate one model step: each low-priority request samples one token.
    low_ids = [req.request_id for req in low_priority_requests]
    model_output = ModelRunnerOutput(
        req_ids=low_ids,
        req_id_to_index={req_id: idx
                         for idx, req_id in enumerate(low_ids)},
        sampled_token_ids=[[100] for _ in low_ids],
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    scheduler.update_from_output(output, model_output)

    # Both low-priority requests are running.
    assert len(scheduler.running) == 2

    # Add a high-priority request that also needs memory.
    high_priority_request = create_requests_with_priority(
        num_requests=1,
        priorities=[0],
        arrival_times=[3.0],
        num_tokens=30,
    )[0]
    scheduler.add_request(high_priority_request)

    # Scheduling again may trigger preemption while allocating memory.
    output = scheduler.schedule()

    # If preemption happens while the running requests are being scheduled,
    # waiting requests are not scheduled in that same step. Preemption is
    # therefore detected through the waiting queue: it then holds the
    # high-priority request plus at least one preempted request.
    if len(scheduler.waiting) <= 1:
        # No preemption was needed; the new request was scheduled directly.
        assert len(output.scheduled_new_reqs) == 1
        assert output.scheduled_new_reqs[0].req_id == "0"
    else:
        # Preemption occurred; the high-priority request goes next.
        output2 = scheduler.schedule()
        assert len(output2.scheduled_new_reqs) == 1
        assert output2.scheduled_new_reqs[0].req_id == "0"
def test_priority_scheduling_no_preemption_when_space_available():
    """With room for every request, adding a high-priority request
    must not preempt the running ones."""
    # NOTE(review): the high-priority request below is created by a second
    # call to create_requests_with_priority, so it shares id "0" with the
    # first low-priority request. Consider giving it a distinct id.

    # Three concurrent requests allowed and a sufficient token budget.
    scheduler = create_scheduler_with_priority(
        max_num_seqs=3,
        max_num_batched_tokens=200,
    )

    # Start two low-priority requests.
    low_priority_requests = create_requests_with_priority(
        num_requests=2,
        priorities=[5, 5],
        arrival_times=[1.0, 2.0],
        num_tokens=30,
    )
    for req in low_priority_requests:
        scheduler.add_request(req)
    output = scheduler.schedule()

    # Simulate one model step: each low-priority request samples one token.
    low_ids = [req.request_id for req in low_priority_requests]
    model_output = ModelRunnerOutput(
        req_ids=low_ids,
        req_id_to_index={req_id: idx
                         for idx, req_id in enumerate(low_ids)},
        sampled_token_ids=[[100] for _ in low_ids],
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    scheduler.update_from_output(output, model_output)

    # Now add the high-priority request.
    high_priority_request = create_requests_with_priority(
        num_requests=1,
        priorities=[0],
        arrival_times=[3.0],
        num_tokens=30,
    )[0]
    scheduler.add_request(high_priority_request)

    output = scheduler.schedule()

    # The new request is scheduled and nothing is pushed back to waiting.
    assert len(output.scheduled_new_reqs) == 1
    assert len(scheduler.running) == 3
    assert len(scheduler.waiting) == 0
def test_priority_scheduling_preemption_victim_selection():
    """Check the waiting-queue order for requests of different priority.

    With ``max_num_seqs=1`` only one of three requests can be scheduled.
    The test asserts that the priority-0 request is the one scheduled and
    that the other two remain in the waiting queue ordered by priority.
    """
    # One slot only, so priority ordering decides who runs first.
    scheduler = create_scheduler_with_priority(max_num_seqs=1)

    # Priorities: req 0 -> 3 (low), req 1 -> 2 (medium), req 2 -> 0 (high).
    candidates = create_requests_with_priority(
        num_requests=3,
        priorities=[3, 2, 0],
        arrival_times=[1.0, 2.0, 3.0],
        num_tokens=10,
    )
    for candidate in candidates:
        scheduler.add_request(candidate)

    # Only the highest-priority request (id "2") should be scheduled.
    step = scheduler.schedule()
    assert len(step.scheduled_new_reqs) == 1
    assert step.scheduled_new_reqs[0].req_id == "2"  # highest priority

    # The two leftovers must still be queued ...
    assert len(scheduler.waiting) == 2

    # ... in priority order: req 1 (priority 2), then req 0 (priority 3).
    queued = list(scheduler.waiting)
    queued_priorities = [req.priority for req in queued]
    queued_ids = [req.request_id for req in queued]
    assert queued_priorities == [2, 3]
    assert queued_ids == ["1", "0"]
def test_priority_scheduling_equal_priority_preemption():
    """Check that arrival time breaks ties between equal priorities.

    Three requests share priority 2 but arrive at different times. With a
    single sequence slot, the earliest arrival must be scheduled and the
    other two must wait in arrival-time order.
    """
    # One slot only, forcing sequential processing.
    scheduler = create_scheduler_with_priority(max_num_seqs=1)

    # Same priority everywhere; only the arrival times differ.
    tied_requests = create_requests_with_priority(
        num_requests=3,
        priorities=[2, 2, 2],
        arrival_times=[3.0, 1.0, 2.0],
        num_tokens=10,
    )
    for tied in tied_requests:
        scheduler.add_request(tied)

    # Request "1" arrived first (1.0) and must win the slot.
    step = scheduler.schedule()
    assert len(step.scheduled_new_reqs) == 1
    assert step.scheduled_new_reqs[0].req_id == "1"  # Earliest arrival (1.0)

    # Two requests remain queued ...
    assert len(scheduler.waiting) == 2

    # ... ordered by arrival: req 2 (2.0), then req 0 (3.0).
    queued = list(scheduler.waiting)
    queued_arrivals = [req.arrival_time for req in queued]
    queued_ids = [req.request_id for req in queued]
    assert queued_arrivals == [2.0, 3.0]
    assert queued_ids == ["2", "0"]
def test_priority_scheduling_waiting_queue_order():
    """Check that iterating the waiting queue yields priority order.

    Four requests with mixed priorities are queued on a single-slot
    scheduler. After one scheduling step the priority-0 request is
    running, and iterating ``scheduler.waiting`` must yield the remaining
    three requests by ascending priority value.
    """
    # Only one request may run at a time.
    scheduler = create_scheduler_with_priority(max_num_seqs=1)

    mixed_requests = create_requests_with_priority(
        num_requests=4,
        priorities=[3, 1, 2, 0],  # mixed priorities
        arrival_times=[1.0, 2.0, 3.0, 4.0],
        num_tokens=10,
    )
    for mixed in mixed_requests:
        scheduler.add_request(mixed)

    # Request "3" has priority 0 and must be the one scheduled.
    step = scheduler.schedule()
    assert len(step.scheduled_new_reqs) == 1
    assert step.scheduled_new_reqs[0].req_id == "3"

    # The other three remain queued.
    assert len(scheduler.waiting) == 3

    # Snapshot the queue by iterating it; the snapshot is expected to be
    # ordered by priority: req 1 (1), req 2 (2), req 0 (3).
    queued = list(scheduler.waiting)
    queued_priorities = [req.priority for req in queued]
    queued_ids = [req.request_id for req in queued]
    assert queued_ids == ["1", "2", "0"]
    assert queued_priorities == [1, 2, 3]
def test_priority_scheduling_fcfs_fallback():
    """Check first-come-first-served order when all priorities are equal.

    Four requests with the same priority and shuffled arrival times are
    queued on a scheduler created with default limits. All four must be
    scheduled in one step, ordered by arrival time.
    """
    scheduler = create_scheduler_with_priority()

    # Identical priorities; arrival times deliberately out of id order.
    same_priority = [1, 1, 1, 1]
    shuffled_arrivals = [4.0, 1.0, 3.0, 2.0]
    fcfs_requests = create_requests_with_priority(
        num_requests=4,
        priorities=same_priority,
        arrival_times=shuffled_arrivals,
    )
    for fcfs in fcfs_requests:
        scheduler.add_request(fcfs)

    step = scheduler.schedule()

    # Everything fits in one step ...
    assert len(step.scheduled_new_reqs) == 4

    # ... in arrival order: req 1 (1.0), req 3 (2.0), req 2 (3.0), req 0 (4.0).
    scheduled_ids = [new_req.req_id for new_req in step.scheduled_new_reqs]
    assert scheduled_ids == ["1", "3", "2", "0"]
def test_priority_scheduling_with_limited_slots():
    """Check priority scheduling when ``max_num_seqs`` caps concurrency.

    With two slots and four requests, the two requests with the lowest
    priority values must be scheduled and the other two must wait in
    priority order.
    """
    scheduler = create_scheduler_with_priority(
        max_num_seqs=2,  # at most two concurrent requests
        max_num_batched_tokens=1000,  # token budget is not the constraint
    )

    mixed_requests = create_requests_with_priority(
        num_requests=4,
        priorities=[3, 1, 2, 0],  # mixed priorities
        arrival_times=[1.0, 2.0, 3.0, 4.0],
        num_tokens=10,
    )
    for mixed in mixed_requests:
        scheduler.add_request(mixed)

    # Exactly two requests fit into the available slots.
    step = scheduler.schedule()
    assert len(step.scheduled_new_reqs) == 2

    # They must be req 3 (priority 0) and req 1 (priority 1).
    scheduled_ids = [new_req.req_id for new_req in step.scheduled_new_reqs]
    assert "3" in scheduled_ids  # Priority 0
    assert "1" in scheduled_ids  # Priority 1

    # The rest stay queued ...
    assert len(scheduler.waiting) == 2

    # ... as req 2 (priority 2) followed by req 0 (priority 3).
    queued = list(scheduler.waiting)
    queued_priorities = [req.priority for req in queued]
    queued_ids = [req.request_id for req in queued]
    assert queued_priorities == [2, 3]
    assert queued_ids == ["2", "0"]
def test_priority_scheduling_heap_property():
    """Test that requests leave the waiting queue in priority order.

    Eight requests with shuffled priorities are queued on a scheduler that
    runs one request at a time. Each newly scheduled request is stepped
    once and then finished so the next one can be admitted. The priorities
    observed in scheduling order must equal the sorted priority list
    (lowest value first).

    The drain loop is bounded: if the scheduler stops admitting requests
    while the waiting queue is non-empty, the test fails with an assertion
    instead of spinning forever.
    """
    scheduler = create_scheduler_with_priority(
        max_num_seqs=1,  # Only one request can run at a time
    )

    # Add requests in random priority order
    priorities = [5, 1, 8, 3, 2, 7, 4, 6]
    arrival_times = [float(i) for i in range(len(priorities))]
    requests = create_requests_with_priority(num_requests=len(priorities),
                                             priorities=priorities,
                                             arrival_times=arrival_times,
                                             num_tokens=10)

    # Add all requests
    for request in requests:
        scheduler.add_request(request)

    # Generous step budget: the happy path needs one step per request, so
    # this only trips when the scheduler makes no progress at all.
    max_steps = 10 * len(requests)
    steps = 0

    # Schedule one request at a time and verify priority order
    scheduled_priorities = []
    while scheduler.waiting:
        steps += 1
        assert steps <= max_steps, (
            "scheduler made no progress draining the waiting queue "
            f"after {max_steps} steps")

        output = scheduler.schedule()
        if not output.scheduled_new_reqs:
            # Nothing new was admitted this step; try again (bounded above).
            continue

        req = output.scheduled_new_reqs[0]
        # Request ids are the stringified indices into `requests`.
        scheduled_priorities.append(requests[int(req.req_id)].priority)

        # Simulate completion to make room for next request
        model_output = ModelRunnerOutput(
            req_ids=[req.req_id],
            req_id_to_index={req.req_id: 0},
            sampled_token_ids=[[100]],
            spec_token_ids=None,
            logprobs=None,
            prompt_logprobs_dict={},
            pooler_output=[],
        )
        scheduler.update_from_output(output, model_output)

        # Finish the request to make room for the next one
        scheduler.finish_requests(req.req_id, RequestStatus.FINISHED_STOPPED)

    # Verify requests were scheduled in priority order (lowest value first)
    expected_priorities = sorted(priorities)
    assert scheduled_priorities == expected_priorities
def test_schedule_skip_tokenizer_init():
    """Check plain requests are scheduled with ``skip_tokenizer_init=True``.

    All five requests must be scheduled in one step and the scheduler
    output must carry no grammar bitmask.
    """
    scheduler = create_scheduler(skip_tokenizer_init=True)

    plain_requests = create_requests(num_requests=5)
    for plain in plain_requests:
        scheduler.add_request(plain)

    step = scheduler.schedule()

    assert len(step.scheduled_new_reqs) == len(plain_requests)
    assert step.grammar_bitmask is None
def test_schedule_skip_tokenizer_init_structured_output_request():
    """Check a structured-output request is not scheduled without a tokenizer.

    A single regex-guided request is added to a scheduler created with
    ``skip_tokenizer_init=True``. After one scheduling step nothing is
    scheduled or running, and the request is still in the waiting queue.
    """
    scheduler = create_scheduler(skip_tokenizer_init=True)

    # Sampling parameters that ask for regex-guided decoding.
    digits_only = GuidedDecodingParams(regex="[0-9]+")
    guided_sampling = SamplingParams(
        ignore_eos=False,
        max_tokens=16,
        guided_decoding=digits_only,
    )

    guided_request = Request(
        request_id="0",
        prompt_token_ids=[0, 1],
        multi_modal_inputs=None,
        multi_modal_hashes=None,
        multi_modal_placeholders=None,
        sampling_params=guided_sampling,
        pooling_params=None,
        eos_token_id=EOS_TOKEN_ID,
        structured_output_request=StructuredOutputRequest(guided_sampling),
    )
    scheduler.add_request(guided_request)

    step = scheduler.schedule()

    # The request must remain queued: not scheduled, not running.
    assert len(step.scheduled_new_reqs) == 0
    assert len(scheduler.running) == 0
    assert len(scheduler.waiting) == 1
...@@ -30,7 +30,7 @@ model_config = { ...@@ -30,7 +30,7 @@ model_config = {
]) ])
@pytest.mark.parametrize("batch_size", [5]) @pytest.mark.parametrize("batch_size", [5])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
def test_sliding_window_retrival(monkeypatch, model, batch_size, seed): def test_sliding_window_retrieval(monkeypatch, model, batch_size, seed):
""" """
The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then
asks for value of one of them (which is outside the sliding window). asks for value of one of them (which is outside the sliding window).
......
...@@ -15,6 +15,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs ...@@ -15,6 +15,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs import PromptType from vllm.inputs import PromptType
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind from vllm.sampling_params import RequestOutputKind
from vllm.utils import set_default_torch_num_threads
from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.metrics.loggers import LoggingStatLogger from vllm.v1.metrics.loggers import LoggingStatLogger
...@@ -107,7 +108,8 @@ async def test_load( ...@@ -107,7 +108,8 @@ async def test_load(
with monkeypatch.context() as m, ExitStack() as after: with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
engine = AsyncLLM.from_engine_args(engine_args) with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown) after.callback(engine.shutdown)
NUM_REQUESTS = 100 NUM_REQUESTS = 100
...@@ -154,7 +156,8 @@ async def test_abort( ...@@ -154,7 +156,8 @@ async def test_abort(
with monkeypatch.context() as m, ExitStack() as after: with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
engine = AsyncLLM.from_engine_args(engine_args) with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown) after.callback(engine.shutdown)
NUM_REQUESTS = 100 NUM_REQUESTS = 100
...@@ -226,7 +229,8 @@ async def test_finished_flag( ...@@ -226,7 +229,8 @@ async def test_finished_flag(
with monkeypatch.context() as m, ExitStack() as after: with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
engine = AsyncLLM.from_engine_args(engine_args) with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown) after.callback(engine.shutdown)
sampling_params = SamplingParams( sampling_params = SamplingParams(
...@@ -260,7 +264,8 @@ async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch, ...@@ -260,7 +264,8 @@ async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch,
with monkeypatch.context() as m, ExitStack() as after: with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
engine = AsyncLLM.from_engine_args(engine_args) with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown) after.callback(engine.shutdown)
NUM_REQUESTS = 100 NUM_REQUESTS = 100
...@@ -322,10 +327,11 @@ async def test_customize_loggers(monkeypatch): ...@@ -322,10 +327,11 @@ async def test_customize_loggers(monkeypatch):
with monkeypatch.context() as m, ExitStack() as after: with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
engine = AsyncLLM.from_engine_args( with set_default_torch_num_threads(1):
TEXT_ENGINE_ARGS, engine = AsyncLLM.from_engine_args(
stat_loggers=[MockLoggingStatLogger], TEXT_ENGINE_ARGS,
) stat_loggers=[MockLoggingStatLogger],
)
after.callback(engine.shutdown) after.callback(engine.shutdown)
await engine.do_log_stats() await engine.do_log_stats()
...@@ -340,7 +346,8 @@ async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch): ...@@ -340,7 +346,8 @@ async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m, ExitStack() as after: with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS) with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
after.callback(engine.shutdown) after.callback(engine.shutdown)
sampling_params = SamplingParams(max_tokens=100, sampling_params = SamplingParams(max_tokens=100,
...@@ -362,3 +369,33 @@ async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch): ...@@ -362,3 +369,33 @@ async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch):
sampling_params=sampling_params, sampling_params=sampling_params,
data_parallel_rank=1): data_parallel_rank=1):
pass pass
@pytest.mark.asyncio
async def test_check_health(monkeypatch: pytest.MonkeyPatch):
    """Exercise ``AsyncLLM.check_health`` on a healthy and a dead engine.

    ``check_health`` must return normally for a healthy engine, raise
    ``EngineDeadError`` while the ``errored`` property is patched to
    report ``True``, and return normally again once the patch is removed.
    """
    from unittest.mock import patch

    from vllm.v1.engine.exceptions import EngineDeadError

    with monkeypatch.context() as env, ExitStack() as cleanup:
        env.setenv("VLLM_USE_V1", "1")

        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        cleanup.callback(engine.shutdown)

        # Case 1: a healthy engine must not raise.
        await engine.check_health()

        # Case 2: patch `errored` on the engine's class so the engine looks
        # dead for the duration of the `with` block.
        simulate_dead = patch.object(
            type(engine),
            'errored',
            new_callable=lambda: property(lambda self: True),
        )
        with simulate_dead:
            with pytest.raises(EngineDeadError):
                await engine.check_health()

        # Case 3: with the patch undone, the engine is healthy again.
        await engine.check_health()
...@@ -12,13 +12,14 @@ from transformers import AutoTokenizer ...@@ -12,13 +12,14 @@ from transformers import AutoTokenizer
from vllm import SamplingParams from vllm import SamplingParams
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import set_default_torch_num_threads
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core import EngineCore from vllm.v1.engine.core import EngineCore
from vllm.v1.executor.abstract import Executor, UniProcExecutor from vllm.v1.executor.abstract import Executor, UniProcExecutor
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from ...utils import create_new_process_for_each_test from ...utils import create_new_process_for_each_test, multi_gpu_test
if not current_platform.is_cuda(): if not current_platform.is_cuda():
pytest.skip(reason="V1 currently only supported on CUDA.", pytest.skip(reason="V1 currently only supported on CUDA.",
...@@ -38,6 +39,7 @@ def make_request() -> EngineCoreRequest: ...@@ -38,6 +39,7 @@ def make_request() -> EngineCoreRequest:
mm_hashes=None, mm_hashes=None,
mm_placeholders=None, mm_placeholders=None,
sampling_params=SamplingParams(), sampling_params=SamplingParams(),
pooling_params=None,
eos_token_id=None, eos_token_id=None,
arrival_time=time.time(), arrival_time=time.time(),
lora_request=None, lora_request=None,
...@@ -56,9 +58,10 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch): ...@@ -56,9 +58,10 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
vllm_config = engine_args.create_engine_config() vllm_config = engine_args.create_engine_config()
executor_class = Executor.get_class(vllm_config) executor_class = Executor.get_class(vllm_config)
engine_core = EngineCore(vllm_config=vllm_config, with set_default_torch_num_threads(1):
executor_class=executor_class, engine_core = EngineCore(vllm_config=vllm_config,
log_stats=True) executor_class=executor_class,
log_stats=True)
"""Test basic request lifecycle.""" """Test basic request lifecycle."""
# First request. # First request.
...@@ -190,9 +193,10 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch): ...@@ -190,9 +193,10 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch):
vllm_config = engine_args.create_engine_config() vllm_config = engine_args.create_engine_config()
executor_class = Executor.get_class(vllm_config) executor_class = Executor.get_class(vllm_config)
engine_core = EngineCore(vllm_config=vllm_config, with set_default_torch_num_threads(1):
executor_class=executor_class, engine_core = EngineCore(vllm_config=vllm_config,
log_stats=True) executor_class=executor_class,
log_stats=True)
"""Test basic request lifecycle.""" """Test basic request lifecycle."""
# First request. # First request.
request: EngineCoreRequest = make_request() request: EngineCoreRequest = make_request()
...@@ -286,9 +290,10 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): ...@@ -286,9 +290,10 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
enforce_eager=True, enforce_eager=True,
) )
vllm_config = engine_args.create_engine_config() vllm_config = engine_args.create_engine_config()
engine_core = EngineCore(vllm_config=vllm_config, with set_default_torch_num_threads(1):
log_stats=False, engine_core = EngineCore(vllm_config=vllm_config,
executor_class=DummyExecutor) log_stats=False,
executor_class=DummyExecutor)
assert engine_core.batch_queue is not None assert engine_core.batch_queue is not None
# Add two requests in a row. Each request have 12 prompt tokens. # Add two requests in a row. Each request have 12 prompt tokens.
...@@ -374,3 +379,37 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch): ...@@ -374,3 +379,37 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
# Odd steps schedules a new batch. # Odd steps schedules a new batch.
assert output is None assert output is None
step += 1 step += 1
@multi_gpu_test(num_gpus=2)
def test_engine_core_tp(monkeypatch: pytest.MonkeyPatch):
    """Check that EngineCore initializes its workers with tensor parallelism.

    An EngineCore is built with ``tensor_parallel_size=2``. Each worker is
    then asked, via ``collective_rpc``, for the ``num_gpu_blocks`` and
    ``num_cpu_blocks`` fields of its cache config; none of the returned
    values may be ``None``.
    """
    with monkeypatch.context() as env:
        env.setenv("VLLM_USE_V1", "1")

        # Set up the EngineCore.
        tp_engine_args = EngineArgs(
            model=MODEL_NAME,
            tensor_parallel_size=2,
            # Reduce startup time.
            enforce_eager=True,
        )
        vllm_config = tp_engine_args.create_engine_config()
        executor_class = Executor.get_class(vllm_config)

        with set_default_torch_num_threads(1):
            engine_core = EngineCore(
                vllm_config=vllm_config,
                executor_class=executor_class,
                log_stats=True,
            )

        def get_worker_cache_config_field(worker, key: str):
            # Runs on each worker: read one field of its cache config.
            return getattr(worker.cache_config, key)

        # Query both fields first (GPU, then CPU), then validate them.
        blocks_by_field = {
            field: engine_core.collective_rpc(get_worker_cache_config_field,
                                              args=(field, ))
            for field in ("num_gpu_blocks", "num_cpu_blocks")
        }
        assert all(x is not None for x in blocks_by_field["num_gpu_blocks"])
        assert all(x is not None for x in blocks_by_field["num_cpu_blocks"])
...@@ -8,8 +8,10 @@ import time ...@@ -8,8 +8,10 @@ import time
import uuid import uuid
from threading import Thread from threading import Thread
from typing import Optional from typing import Optional
from unittest.mock import MagicMock
import pytest import pytest
import torch
from transformers import AutoTokenizer from transformers import AutoTokenizer
from tests.utils import multi_gpu_test from tests.utils import multi_gpu_test
...@@ -19,12 +21,13 @@ from vllm.distributed.kv_events import (BlockStored, KVEventBatch, ...@@ -19,12 +21,13 @@ from vllm.distributed.kv_events import (BlockStored, KVEventBatch,
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import set_default_torch_num_threads
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core import EngineCore from vllm.v1.engine.core import EngineCore
from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient, from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
SyncMPClient) SyncMPClient)
from vllm.v1.engine.utils import CoreEngineProcManager
from vllm.v1.executor.abstract import Executor from vllm.v1.executor.abstract import Executor
from vllm.v1.utils import CoreEngineProcManager
from ...distributed.conftest import MockSubscriber from ...distributed.conftest import MockSubscriber
from ...utils import create_new_process_for_each_test from ...utils import create_new_process_for_each_test
...@@ -52,6 +55,7 @@ def make_request( ...@@ -52,6 +55,7 @@ def make_request(
mm_hashes=None, mm_hashes=None,
mm_placeholders=None, mm_placeholders=None,
sampling_params=params, sampling_params=params,
pooling_params=None,
eos_token_id=None, eos_token_id=None,
arrival_time=time.time(), arrival_time=time.time(),
lora_request=None, lora_request=None,
...@@ -138,13 +142,15 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch, ...@@ -138,13 +142,15 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch,
vllm_config = engine_args.create_engine_config( vllm_config = engine_args.create_engine_config(
UsageContext.UNKNOWN_CONTEXT) UsageContext.UNKNOWN_CONTEXT)
executor_class = Executor.get_class(vllm_config) executor_class = Executor.get_class(vllm_config)
client = EngineCoreClient.make_client(
multiprocess_mode=multiprocessing_mode, with set_default_torch_num_threads(1):
asyncio_mode=False, client = EngineCoreClient.make_client(
vllm_config=vllm_config, multiprocess_mode=multiprocessing_mode,
executor_class=executor_class, asyncio_mode=False,
log_stats=False, vllm_config=vllm_config,
) executor_class=executor_class,
log_stats=False,
)
MAX_TOKENS = 20 MAX_TOKENS = 20
params = SamplingParams(max_tokens=MAX_TOKENS) params = SamplingParams(max_tokens=MAX_TOKENS)
...@@ -223,13 +229,15 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch): ...@@ -223,13 +229,15 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
vllm_config = engine_args.create_engine_config( vllm_config = engine_args.create_engine_config(
usage_context=UsageContext.UNKNOWN_CONTEXT) usage_context=UsageContext.UNKNOWN_CONTEXT)
executor_class = Executor.get_class(vllm_config) executor_class = Executor.get_class(vllm_config)
client = EngineCoreClient.make_client(
multiprocess_mode=True, with set_default_torch_num_threads(1):
asyncio_mode=True, client = EngineCoreClient.make_client(
vllm_config=vllm_config, multiprocess_mode=True,
executor_class=executor_class, asyncio_mode=True,
log_stats=True, vllm_config=vllm_config,
) executor_class=executor_class,
log_stats=True,
)
try: try:
MAX_TOKENS = 20 MAX_TOKENS = 20
...@@ -312,13 +320,14 @@ def test_kv_cache_events( ...@@ -312,13 +320,14 @@ def test_kv_cache_events(
UsageContext.UNKNOWN_CONTEXT) UsageContext.UNKNOWN_CONTEXT)
executor_class = Executor.get_class(vllm_config) executor_class = Executor.get_class(vllm_config)
client = EngineCoreClient.make_client( with set_default_torch_num_threads(1):
multiprocess_mode=multiprocessing_mode, client = EngineCoreClient.make_client(
asyncio_mode=False, multiprocess_mode=multiprocessing_mode,
vllm_config=vllm_config, asyncio_mode=False,
executor_class=executor_class, vllm_config=vllm_config,
log_stats=False, executor_class=executor_class,
) log_stats=False,
)
endpoint = publisher_config.endpoint.replace("*", "127.0.0.1") endpoint = publisher_config.endpoint.replace("*", "127.0.0.1")
subscriber = MockSubscriber(endpoint, subscriber = MockSubscriber(endpoint,
topic=publisher_config.topic, topic=publisher_config.topic,
...@@ -394,13 +403,14 @@ async def test_kv_cache_events_dp( ...@@ -394,13 +403,14 @@ async def test_kv_cache_events_dp(
UsageContext.UNKNOWN_CONTEXT) UsageContext.UNKNOWN_CONTEXT)
executor_class = Executor.get_class(vllm_config) executor_class = Executor.get_class(vllm_config)
client = EngineCoreClient.make_client( with set_default_torch_num_threads(1):
multiprocess_mode=multiprocessing_mode, client = EngineCoreClient.make_client(
asyncio_mode=True, multiprocess_mode=multiprocessing_mode,
vllm_config=vllm_config, asyncio_mode=True,
executor_class=executor_class, vllm_config=vllm_config,
log_stats=False, executor_class=executor_class,
) log_stats=False,
)
await asyncio.sleep(1) await asyncio.sleep(1)
# Build endpoints for all DP ranks # Build endpoints for all DP ranks
...@@ -509,3 +519,72 @@ def test_startup_failure(monkeypatch: pytest.MonkeyPatch): ...@@ -509,3 +519,72 @@ def test_startup_failure(monkeypatch: pytest.MonkeyPatch):
) )
assert "Engine core initialization failed" in str(e_info.value) assert "Engine core initialization failed" in str(e_info.value)
@create_new_process_for_each_test()
def test_engine_core_proc_instantiation_cuda_empty(
        monkeypatch: pytest.MonkeyPatch):
    """Check EngineCoreProc can be built with an empty CUDA_VISIBLE_DEVICES.

    The executor class is replaced by a ``MagicMock`` whose instances
    answer only the calls stubbed below, and the startup handshake is
    patched to return fixed ZMQ addresses. With ``CUDA_VISIBLE_DEVICES``
    set to the empty string, constructing ``EngineCoreProc`` and shutting
    it down must succeed, i.e. the engine frontend does not need GPU
    access.
    """
    from vllm.v1.engine.core import EngineCoreProc
    from vllm.v1.executor.abstract import Executor

    # A mocked executor class keeps the test free of a custom subclass.
    mock_executor_class = MagicMock(spec=Executor)

    def create_mock_executor(vllm_config):
        fake_executor = MagicMock()

        # Stub only the methods that are actually called during init.
        from vllm.v1.kv_cache_interface import FullAttentionSpec
        attention_spec = FullAttentionSpec(
            block_size=16,
            num_kv_heads=1,
            head_size=64,
            dtype=torch.float16,
            use_mla=False,
        )

        fake_executor.get_kv_cache_specs.return_value = [
            {"default": attention_spec},
        ]
        fake_executor.determine_available_memory.return_value = [
            1024 * 1024 * 1024,
        ]
        fake_executor.initialize_from_config.return_value = None
        fake_executor.max_concurrent_batches = 1

        return fake_executor

    # Instantiating the mocked class yields the stubbed executor above.
    mock_executor_class.side_effect = create_mock_executor

    with monkeypatch.context() as env:
        env.setenv("VLLM_USE_V1", "1")
        env.setenv("CUDA_VISIBLE_DEVICES", "")  # No CUDA devices

        from vllm.v1.engine.utils import EngineZmqAddresses

        def mock_startup_handshake(self, handshake_socket, on_head_node,
                                   parallel_config):
            # Skip the real handshake and hand back fixed addresses.
            return EngineZmqAddresses(
                inputs=["tcp://127.0.0.1:5555"],
                outputs=["tcp://127.0.0.1:5556"],
                coordinator_input=None,
                coordinator_output=None,
            )

        # Background processes are not important here
        env.setattr(EngineCoreProc, "startup_handshake",
                    mock_startup_handshake)

        deepseek_args = EngineArgs(
            model="deepseek-ai/DeepSeek-V2-Lite",
            trust_remote_code=True,
        )
        vllm_config = deepseek_args.create_engine_config()

        engine_core_proc = EngineCoreProc(
            vllm_config=vllm_config,
            local_client=True,
            handshake_address="tcp://127.0.0.1:12345",
            executor_class=mock_executor_class,
            log_stats=False,
            engine_index=0,
        )

        engine_core_proc.shutdown()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from transformers import AutoTokenizer
from vllm.sampling_params import SamplingParams
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
# ruff: noqa: E501
def test_fast_inc_detok_invalid_utf8_err_case():
    """Regression test for incremental detokenization of invalid UTF-8.

    Covers the edge case where the tokenizer can produce non-monotonic,
    invalid UTF-8 output, which breaks the internal state of tokenizers'
    DecodeStream. The token sequence below is fed one token at a time and
    the concatenated text deltas are compared against the expected text.

    See https://github.com/vllm-project/vllm/issues/17448.
    Thanks to reproducer from @fpaupier:
    https://gist.github.com/fpaupier/0ed1375bd7633c5be6c894b1c7ac1be3.
    """
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")

    # Build the request that the detokenizer is created from.
    prompt_token_ids = [107, 4606, 236787, 107]
    sampling = SamplingParams(skip_special_tokens=True)
    request = EngineCoreRequest(
        "test",
        prompt_token_ids,
        None,
        None,
        None,
        sampling,
        None,
        None,
        0.0,
        None,
        cache_salt=None,
        data_parallel_rank=None,
    )

    detokenizer = IncrementalDetokenizer.from_new_request(tokenizer, request)

    assert type(detokenizer).__name__ == "FastIncrementalDetokenizer", \
        "Should use FastIncrementalDetokenizer by default"

    # Generated tokens from the reproducer, fed in one at a time below.
    test_tokens = [
        236840, 107, 138, 236782, 107, 140, 236775, 6265, 1083, 623, 121908,
        147418, 827, 107, 140, 236775, 6265, 236779, 2084, 1083, 623, 203292,
        827, 107, 140, 236775, 6265, 236779, 7777, 1083, 623, 121908, 147418,
        569, 537, 236789, 65880, 569, 537, 236789, 62580, 853, 115693, 210118,
        35178, 16055, 1270, 759, 215817, 4758, 1925, 1117, 827, 107, 140,
        236775, 5654, 1083, 623, 110733, 46291, 827, 107, 140, 236775, 5654,
        236779, 2084, 1083, 623, 136955, 56731, 827, 107, 140, 236775, 5654,
        236779, 7777, 1083, 623, 194776, 2947, 496, 109811, 1608, 890, 215817,
        4758, 1925, 1117, 2789, 432, 398, 602, 31118, 569, 124866, 134772, 509,
        19478, 1640, 33779, 236743, 236770, 236819, 236825, 236771, 432, 398,
        432, 237167, 827, 107, 140, 236775, 77984, 1083, 623, 2709, 236745,
        2555, 513, 236789, 602, 31118, 569
    ]

    # Collect the text delta after every token; only the final call is
    # flagged as finished.
    last_index = len(test_tokens) - 1
    deltas = []
    for index, token_id in enumerate(test_tokens):
        detokenizer.update([token_id], False)
        is_last = index == last_index
        deltas.append(detokenizer.get_next_output_text(is_last, delta=True))
    output = "".join(deltas)

    # fmt: off
    assert output == r'''[
{
"source": "Résultats",
"source_type": "CONCEPT",
"source_description": "Résultats de l'analyse de l'impact des opérations israéliennes sur la frontière libanaise",
"target": "Israël",
"target_type": "ORGANIZATION",
"target_description": "Pays qui a obtenu à sa frontière libanaise « un niveau de calme inédit depuis les années 1960 »",
"relationship": "Obtention d'un niveau de'''
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import random import random
from typing import Optional from typing import TYPE_CHECKING, Optional
import pytest import pytest
from vllm import LLM, SamplingParams from vllm import LLM
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Metric, Vector from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Metric, Vector
if TYPE_CHECKING:
from tests.conftest import VllmRunner
MODEL = "facebook/opt-125m" MODEL = "facebook/opt-125m"
DTYPE = "half" DTYPE = "half"
def _vllm_model(apc: bool, vllm_runner, monkeypatch): def _vllm_model(
apc: bool,
vllm_runner: type[VllmRunner],
monkeypatch: pytest.MonkeyPatch,
*,
skip_tokenizer_init: bool = False,
):
"""Set up VllmRunner instance.""" """Set up VllmRunner instance."""
monkeypatch.setenv("VLLM_USE_V1", "1") monkeypatch.setenv("VLLM_USE_V1", "1")
return vllm_runner( return vllm_runner(
...@@ -23,6 +34,7 @@ def _vllm_model(apc: bool, vllm_runner, monkeypatch): ...@@ -23,6 +34,7 @@ def _vllm_model(apc: bool, vllm_runner, monkeypatch):
enforce_eager=True, enforce_eager=True,
enable_prefix_caching=apc, enable_prefix_caching=apc,
gpu_memory_utilization=0.5, gpu_memory_utilization=0.5,
skip_tokenizer_init=skip_tokenizer_init,
) )
...@@ -45,9 +57,27 @@ def vllm_model_apc(vllm_runner, monkeypatch): ...@@ -45,9 +57,27 @@ def vllm_model_apc(vllm_runner, monkeypatch):
yield vllm_model yield vllm_model
@pytest.fixture(
    # Function scope decouples tests & allows
    # env var adjustment via monkeypatch
    scope="function",
    # Prefix caching
    params=[False, True])
def vllm_model_skip_tokenizer_init(vllm_runner, request, monkeypatch):
    """VllmRunner test fixture created with `skip_tokenizer_init=True`.

    Parametrized over automatic prefix caching (APC) disabled and enabled:
    `request.param` is forwarded to `_vllm_model` as its `apc` argument.
    The runner is used as a context manager, so it is torn down when the
    requesting test finishes.
    """
    with _vllm_model(
            request.param,
            vllm_runner,
            monkeypatch,
            skip_tokenizer_init=True,
    ) as vllm_model:
        yield vllm_model
def _get_test_sampling_params( def _get_test_sampling_params(
prompt_list: list[str], prompt_list: list[str],
seed: Optional[int] = 42, seed: Optional[int] = 42,
structured_outputs: bool = False,
) -> tuple[list[SamplingParams], list[int]]: ) -> tuple[list[SamplingParams], list[int]]:
"""Generate random sampling params for a batch.""" """Generate random sampling params for a batch."""
...@@ -62,14 +92,34 @@ def _get_test_sampling_params( ...@@ -62,14 +92,34 @@ def _get_test_sampling_params(
n_list = [get_mostly_n_gt1() for _ in range(len(prompt_list))] n_list = [get_mostly_n_gt1() for _ in range(len(prompt_list))]
# High temperature to maximize the chance of unique completions # High temperature to maximize the chance of unique completions
return [ return [
SamplingParams(temperature=0.95, top_p=0.95, n=n, seed=seed) SamplingParams(
for n in n_list temperature=0.95,
top_p=0.95,
n=n,
seed=seed,
guided_decoding=GuidedDecodingParams(
regex="[0-9]+") if structured_outputs else None,
) for n in n_list
], n_list ], n_list
def test_compatibility_with_skip_tokenizer_init(
    vllm_model_skip_tokenizer_init: VllmRunner,
    example_prompts: list[str],
):
    """Check that structured-output requests are rejected with ValueError
    when the engine was started with tokenizer initialization skipped."""
    # Case 1: Structured output request should raise an error.
    # Only the per-prompt params are needed; the list of `n` values that
    # the helper also returns is irrelevant here.
    structured_params = _get_test_sampling_params(
        example_prompts,
        structured_outputs=True,
    )[0]
    llm: LLM = vllm_model_skip_tokenizer_init.model
    with pytest.raises(ValueError):
        llm.generate(example_prompts, structured_params)
def test_parallel_sampling(vllm_model, example_prompts) -> None: def test_parallel_sampling(vllm_model, example_prompts) -> None:
"""Test passes if parallel sampling `n>1` yields `n` unique completions. """Test passes if parallel sampling `n>1` yields `n` unique completions.
Args: Args:
vllm_model: VllmRunner instance under test. vllm_model: VllmRunner instance under test.
example_prompt: test fixture providing prompts for testing. example_prompt: test fixture providing prompts for testing.
......
...@@ -66,7 +66,8 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind, ...@@ -66,7 +66,8 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind,
output_kind=request_output_kind, output_kind=request_output_kind,
stop=[], stop=[],
include_stop_str_in_output=False, include_stop_str_in_output=False,
)) ),
pooling_params=None)
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
] ]
...@@ -416,7 +417,8 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind, ...@@ -416,7 +417,8 @@ def test_logprobs_processor(request_output_kind: RequestOutputKind,
include_stop_str_in_output=False, include_stop_str_in_output=False,
logprobs=num_sample_logprobs, logprobs=num_sample_logprobs,
prompt_logprobs=num_prompt_logprobs, prompt_logprobs=num_prompt_logprobs,
)) ),
pooling_params=None)
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
] ]
...@@ -582,7 +584,8 @@ def test_stop_token(include_stop_str_in_output: bool, ...@@ -582,7 +584,8 @@ def test_stop_token(include_stop_str_in_output: bool,
logprobs=num_sample_logprobs, logprobs=num_sample_logprobs,
prompt_logprobs=None, prompt_logprobs=None,
ignore_eos=ignore_eos, ignore_eos=ignore_eos,
)) ),
pooling_params=None)
# Add request to the detokenizer. # Add request to the detokenizer.
output_processor.add_request(request, prompt_string) output_processor.add_request(request, prompt_string)
...@@ -678,7 +681,8 @@ def test_stop_string(include_stop_str_in_output: bool, ...@@ -678,7 +681,8 @@ def test_stop_string(include_stop_str_in_output: bool,
include_stop_str_in_output=include_stop_str_in_output, include_stop_str_in_output=include_stop_str_in_output,
logprobs=num_sample_logprobs, logprobs=num_sample_logprobs,
prompt_logprobs=None, prompt_logprobs=None,
)) ),
pooling_params=None)
for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
] ]
...@@ -786,6 +790,7 @@ def test_iteration_stats(dummy_test_vectors): ...@@ -786,6 +790,7 @@ def test_iteration_stats(dummy_test_vectors):
cache_salt=None, cache_salt=None,
data_parallel_rank=None, data_parallel_rank=None,
sampling_params=SamplingParams(), sampling_params=SamplingParams(),
pooling_params=None,
) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens) ) for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
] ]
......
...@@ -38,7 +38,7 @@ def default_server_args(): ...@@ -38,7 +38,7 @@ def default_server_args():
]]) ]])
def server(default_server_args, request): def server(default_server_args, request):
if request.param: if request.param:
default_server_args.extend(request.param) default_server_args = default_server_args + request.param
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
yield remote_server yield remote_server
......
...@@ -2,10 +2,12 @@ ...@@ -2,10 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio import asyncio
import os import os
import re
import openai # use the official client for correctness check import openai # use the official client for correctness check
import pytest import pytest
import pytest_asyncio import pytest_asyncio
import requests
from tests.utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
...@@ -14,6 +16,122 @@ MODEL_NAME = "ibm-research/PowerMoE-3b" ...@@ -14,6 +16,122 @@ MODEL_NAME = "ibm-research/PowerMoE-3b"
DP_SIZE = os.getenv("DP_SIZE", "1") DP_SIZE = os.getenv("DP_SIZE", "1")
def _parse_prometheus_text(text: str) -> dict[str, dict[str, float]]:
    """Parse Prometheus text exposition output into nested dicts."""
    # Regex patterns for Prometheus sample lines.
    # NOTE: the value pattern only accepts digits, '.', '+', '-' and a
    # lowercase 'e', so samples whose value is e.g. NaN or +Inf do not match
    # and are skipped.
    metric_with_labels = re.compile(
        r'^([a-zA-Z_:][a-zA-Z0-9_:]*)\{([^}]*)\}\s+([\d\.\-\+e]+)$')
    metric_simple = re.compile(
        r'^([a-zA-Z_:][a-zA-Z0-9_:]*)\s+([\d\.\-\+e]+)$')

    metrics: dict[str, dict[str, float]] = {}
    for raw_line in text.split('\n'):
        line = raw_line.strip()

        # Skip comments and empty lines
        if not line or line.startswith('#'):
            continue

        # Try to match metric with labels first, then the label-less form
        match = metric_with_labels.match(line)
        if match:
            metric_name, labels_part, value_str = match.groups()
            label_key = f'{{{labels_part}}}'
        else:
            match = metric_simple.match(line)
            if not match:
                continue
            metric_name, value_str = match.groups()
            label_key = ''

        # The value pattern is permissive (e.g. "1.2.3" matches), so the
        # conversion can still fail; such lines are ignored.
        try:
            value = float(value_str)
        except ValueError:
            continue

        metrics.setdefault(metric_name, {})[label_key] = value

    return metrics


def get_prometheus_metrics(
        server: RemoteOpenAIServer) -> dict[str, dict[str, float]]:
    """Fetch and parse Prometheus metrics from the /metrics endpoint.

    The current test is failed via ``pytest.fail`` if the HTTP request
    raises or returns an error status.

    Returns:
        Dict mapping metric names to their values, keyed by the raw label
        string exactly as it appears in the exposition text, braces
        included (or ``''`` for samples without labels).
        For example: {"vllm:request_success_total": {
            '{engine="0"}': 5.0, '{engine="1"}': 3.0}
        }
    """
    url = server.url_for("metrics")
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
    except requests.RequestException as e:
        # pytest.fail always raises, so no fallback return value is needed.
        pytest.fail(f"Failed to fetch Prometheus metrics: {e}")

    return _parse_prometheus_text(response.text)
def get_engine_request_counts(
        metrics: dict[str, dict[str, float]]) -> dict[str, float]:
    """Sum the `vllm:request_success_total` samples per engine.

    Each label string of that metric is searched for an `engine="..."`
    label; samples without one are ignored, and samples sharing the same
    engine value are added together.

    Returns:
        Dict mapping engine indices to request counts.
        For example: {"0": 15.0, "1": 12.0}
    """
    engine_label = re.compile(r'engine="([^"]*)"')
    totals: dict[str, float] = {}

    samples = metrics.get("vllm:request_success_total", {})
    for label_str, value in samples.items():
        found = engine_label.search(label_str)
        if found is None:
            continue
        engine_id = found.group(1)
        totals[engine_id] = totals.get(engine_id, 0.0) + value

    return totals
def check_request_balancing(server: RemoteOpenAIServer):
    """Assert that requests were spread across all data-parallel engines.

    Does nothing when DP_SIZE is 1 or less. Otherwise the server's
    Prometheus metrics are read and two things are asserted: every one of
    the DP_SIZE engines handled at least one request, and each engine's
    count exceeds ``total // (DP_SIZE + 1)``.

    Args:
        server: The RemoteOpenAIServer instance
    """
    dp_size = int(DP_SIZE)
    if dp_size > 1:
        # Read the metrics only now, after all requests have completed
        per_engine = get_engine_request_counts(get_prometheus_metrics(server))

        # Every engine is expected to have served something
        active_engines = [eid for eid, n in per_engine.items() if n > 0]
        assert len(active_engines) == dp_size, (
            "Expected requests to be distributed across multiple engines,"
            f" but only engine(s) {active_engines} received "
            f"requests. Engine counts: {per_engine}")

        # Verify that the load is reasonably balanced
        # (no engine should handle all requests)
        min_expected = sum(per_engine.values()) // (dp_size + 1)
        for n in per_engine.values():
            assert n > min_expected, (
                f"requests are imbalanced: {per_engine}")
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def default_server_args(): def default_server_args():
return [ return [
...@@ -50,6 +168,7 @@ async def client(server): ...@@ -50,6 +168,7 @@ async def client(server):
[MODEL_NAME], [MODEL_NAME],
) )
async def test_single_completion(client: openai.AsyncOpenAI, async def test_single_completion(client: openai.AsyncOpenAI,
server: RemoteOpenAIServer,
model_name: str) -> None: model_name: str) -> None:
async def make_request(): async def make_request():
...@@ -97,6 +216,9 @@ async def test_single_completion(client: openai.AsyncOpenAI, ...@@ -97,6 +216,9 @@ async def test_single_completion(client: openai.AsyncOpenAI,
assert len(results) == num_requests assert len(results) == num_requests
assert all(completion is not None for completion in results) assert all(completion is not None for completion in results)
# Check request balancing via Prometheus metrics if DP_SIZE > 1
check_request_balancing(server)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -104,6 +226,7 @@ async def test_single_completion(client: openai.AsyncOpenAI, ...@@ -104,6 +226,7 @@ async def test_single_completion(client: openai.AsyncOpenAI,
[MODEL_NAME], [MODEL_NAME],
) )
async def test_completion_streaming(client: openai.AsyncOpenAI, async def test_completion_streaming(client: openai.AsyncOpenAI,
server: RemoteOpenAIServer,
model_name: str) -> None: model_name: str) -> None:
prompt = "What is an LLM?" prompt = "What is an LLM?"
...@@ -170,3 +293,6 @@ async def test_completion_streaming(client: openai.AsyncOpenAI, ...@@ -170,3 +293,6 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
results results
) == num_requests, f"Expected {num_requests} results, got {len(results)}" ) == num_requests, f"Expected {num_requests} results, got {len(results)}"
assert all(results), "Not all streaming requests completed successfully." assert all(results), "Not all streaming requests completed successfully."
# Check request balancing via Prometheus metrics if DP_SIZE > 1
check_request_balancing(server)
...@@ -196,8 +196,7 @@ async def stream_service_response(client_info: dict, endpoint: str, ...@@ -196,8 +196,7 @@ async def stream_service_response(client_info: dict, endpoint: str,
yield chunk yield chunk
@app.post("/v1/completions") async def _handle_completions(api: str, request: Request):
async def handle_completions(request: Request):
try: try:
req_data = await request.json() req_data = await request.json()
request_id = str(uuid.uuid4()) request_id = str(uuid.uuid4())
...@@ -206,9 +205,8 @@ async def handle_completions(request: Request): ...@@ -206,9 +205,8 @@ async def handle_completions(request: Request):
prefill_client_info = get_next_client(request.app, 'prefill') prefill_client_info = get_next_client(request.app, 'prefill')
# Send request to prefill service # Send request to prefill service
response = await send_request_to_service(prefill_client_info, response = await send_request_to_service(prefill_client_info, api,
"/completions", req_data, req_data, request_id)
request_id)
# Extract the needed fields # Extract the needed fields
response_json = response.json() response_json = response.json()
...@@ -224,7 +222,7 @@ async def handle_completions(request: Request): ...@@ -224,7 +222,7 @@ async def handle_completions(request: Request):
# Stream response from decode service # Stream response from decode service
async def generate_stream(): async def generate_stream():
async for chunk in stream_service_response(decode_client_info, async for chunk in stream_service_response(decode_client_info,
"/completions", api,
req_data, req_data,
request_id=request_id): request_id=request_id):
yield chunk yield chunk
...@@ -237,12 +235,22 @@ async def handle_completions(request: Request): ...@@ -237,12 +235,22 @@ async def handle_completions(request: Request):
import traceback import traceback
exc_info = sys.exc_info() exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server" print("Error occurred in disagg prefill proxy server"
" - completions endpoint") f" - {api} endpoint")
print(e) print(e)
print("".join(traceback.format_exception(*exc_info))) print("".join(traceback.format_exception(*exc_info)))
raise raise
@app.post("/v1/completions")
async def handle_completions(request: Request):
    """Serve `/v1/completions` through the shared `_handle_completions`
    handler, using `/completions` as the backend API path."""
    return await _handle_completions(api="/completions", request=request)
@app.post("/v1/chat/completions")
async def handle_chat_completions(request: Request):
    """Serve `/v1/chat/completions` through the shared `_handle_completions`
    handler, using `/chat/completions` as the backend API path."""
    return await _handle_completions(api="/chat/completions", request=request)
@app.get("/healthcheck") @app.get("/healthcheck")
async def healthcheck(): async def healthcheck():
"""Simple endpoint to check if the server is running.""" """Simple endpoint to check if the server is running."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment