Commit 8d75f22e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.13.0rc1' into v0.13.0rc1-ori

parents ce888aa4 7d80c73d
...@@ -112,7 +112,7 @@ class TestBaseThinkingReasoningParserMethods: ...@@ -112,7 +112,7 @@ class TestBaseThinkingReasoningParserMethods:
"""Test the is_reasoning_end method.""" """Test the is_reasoning_end method."""
parser = TestThinkingReasoningParser(test_tokenizer) parser = TestThinkingReasoningParser(test_tokenizer)
end_token_id = parser.end_token_id end_token_id = parser.end_token_id
start_token_id = parser.start_token_id
# Test with end token present # Test with end token present
assert parser.is_reasoning_end([1, 2, end_token_id, 4]) is True assert parser.is_reasoning_end([1, 2, end_token_id, 4]) is True
...@@ -122,6 +122,51 @@ class TestBaseThinkingReasoningParserMethods: ...@@ -122,6 +122,51 @@ class TestBaseThinkingReasoningParserMethods:
# Test with empty list # Test with empty list
assert parser.is_reasoning_end([]) is False assert parser.is_reasoning_end([]) is False
# Test with interleaved thinking
assert parser.is_reasoning_end([1, start_token_id, 2, end_token_id]) is True
assert parser.is_reasoning_end([1, start_token_id, 2, 3]) is False
assert (
parser.is_reasoning_end(
[1, start_token_id, 2, end_token_id, 2, 2, start_token_id]
)
is False
)
def test_is_reasoning_end_streaming(self, test_tokenizer):
"""Test the is_reasoning_end_streaming method."""
parser = TestThinkingReasoningParser(test_tokenizer)
end_token_id = parser.end_token_id
start_token_id = parser.start_token_id
assert (
parser.is_reasoning_end_streaming([1, 2, end_token_id], [end_token_id])
is True
)
assert parser.is_reasoning_end_streaming([1, 2, 3, 4], [4]) is False
assert parser.is_reasoning_end_streaming([], []) is False
assert (
parser.is_reasoning_end_streaming(
[1, start_token_id, 2, end_token_id], [end_token_id]
)
is True
)
assert (
parser.is_reasoning_end_streaming([1, start_token_id, 2, 3], [3]) is False
)
assert (
parser.is_reasoning_end_streaming(
[1, start_token_id, 2, end_token_id, 2, start_token_id, 2],
[2],
)
is False
)
assert (
parser.is_reasoning_end_streaming(
[1, start_token_id, 2, end_token_id, 2, 2], [2]
)
is False
)
def test_extract_content_ids(self, test_tokenizer): def test_extract_content_ids(self, test_tokenizer):
"""Test the extract_content_ids method.""" """Test the extract_content_ids method."""
parser = TestThinkingReasoningParser(test_tokenizer) parser = TestThinkingReasoningParser(test_tokenizer)
......
...@@ -40,6 +40,7 @@ def test_identity_reasoning_parser_basic(tokenizer): ...@@ -40,6 +40,7 @@ def test_identity_reasoning_parser_basic(tokenizer):
input_tokens = tokenizer.tokenize(input_text) input_tokens = tokenizer.tokenize(input_text)
input_ids = tokenizer.convert_tokens_to_ids(input_tokens) input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
assert parser.is_reasoning_end(input_ids) is True assert parser.is_reasoning_end(input_ids) is True
assert parser.is_reasoning_end_streaming(input_ids, input_ids) is True
# Test extract_content_ids returns all input_ids # Test extract_content_ids returns all input_ids
assert parser.extract_content_ids(input_ids) == input_ids assert parser.extract_content_ids(input_ids) == input_ids
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from transformers import AutoTokenizer
from tests.reasoning.utils import run_reasoning_extraction
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
from vllm.reasoning.holo2_reasoning_parser import Holo2ReasoningParser
from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser
REASONING_MODEL_NAME = "HCompany/Holo2-4B"
@pytest.fixture(scope="module")
def tokenizer():
return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
@pytest.mark.parametrize(
"thinking,expected_parser_type",
[
(True, DeepSeekR1ReasoningParser),
(False, IdentityReasoningParser),
],
)
def test_parser_selection(tokenizer, thinking, expected_parser_type):
parser = Holo2ReasoningParser(
tokenizer,
chat_template_kwargs={
"thinking": thinking,
},
)
assert isinstance(parser._parser, expected_parser_type)
def test_holo2_default_parser_is_deepseekr1(tokenizer):
parser = Holo2ReasoningParser(tokenizer)
assert isinstance(parser._parser, DeepSeekR1ReasoningParser)
def test_holo2_supports_structured_output(tokenizer):
# Structured output manager uses the reasoning parser to check if the
# reasoning content is ended before applying the grammar. The main function
# used is is_reasoning_end. This test checks if the parser is able to
# correctly identify the end of the reasoning content.
# important to not pass chat_template_kwargs here as it is done in the
# StructuredOutputManager
parser = Holo2ReasoningParser(tokenizer)
end_token_id = tokenizer.encode("</think>", add_special_tokens=False)[0]
assert parser.is_reasoning_end([1, 2, 4, end_token_id])
assert not parser.is_reasoning_end([1, 2, 4])
assert parser.is_reasoning_end([1, 2, 4, end_token_id, 5])
# thinking is True, non-streaming
WITH_THINK = {
"output": "This is a reasoning section</think>This is the rest",
"reasoning": "This is a reasoning section",
"content": "This is the rest",
}
# thinking is True, streaming
WITH_THINK_STREAM = {
"output": "This is a reasoning section</think>This is the rest",
"reasoning": "This is a reasoning section",
"content": "This is the rest",
}
# thinking is False, non-streaming
THINKING_DISABLED = {
"output": "This is the rest",
"reasoning": None,
"content": "This is the rest",
}
# thinking is False, streaming
THINKING_DISABLED_STREAM = {
"output": "This is the rest",
"reasoning": None,
"content": "This is the rest",
}
# thinking is False but the model output </think>, non-streaming
THINKING_DISABLED_WITH_CLOSE_TAG = {
"output": "</think>This is the rest",
"reasoning": None,
"content": "</think>This is the rest",
}
# thinking is False but the model output </think>, streaming
THINKING_DISABLED_WITH_CLOSE_TAG_STREAM = {
"output": "some text</think>This is the rest",
"reasoning": None,
"content": "some text</think>This is the rest",
}
COMPLETE_REASONING = {
"output": "This is a reasoning section</think>",
"reasoning": "This is a reasoning section",
"content": None,
}
TEST_CASES = [
pytest.param(
False,
WITH_THINK,
None,
id="with_think",
),
pytest.param(
True,
WITH_THINK_STREAM,
None,
id="with_think_stream",
),
pytest.param(
False,
WITH_THINK,
{"thinking": True},
id="with_think_enabled",
),
pytest.param(
True,
WITH_THINK_STREAM,
{"thinking": True},
id="with_think_stream_enabled",
),
pytest.param(
False,
THINKING_DISABLED,
{"thinking": False},
id="thinking_disabled",
),
pytest.param(
True,
THINKING_DISABLED_STREAM,
{"thinking": False},
id="thinking_disabled_stream",
),
pytest.param(
False,
THINKING_DISABLED_WITH_CLOSE_TAG,
{"thinking": False},
id="thinking_disabled_with_close_tag",
),
pytest.param(
True,
THINKING_DISABLED_WITH_CLOSE_TAG_STREAM,
{"thinking": False},
id="thinking_disabled_with_close_tag_stream",
),
pytest.param(
False,
COMPLETE_REASONING,
None,
id="complete_reasoning",
),
pytest.param(
True,
COMPLETE_REASONING,
None,
id="complete_reasoning_stream",
),
]
@pytest.mark.parametrize("streaming, param_dict, chat_template_kwargs", TEST_CASES)
def test_reasoning(
streaming: bool,
param_dict: dict,
chat_template_kwargs: dict | None,
tokenizer,
):
output = tokenizer.tokenize(param_dict["output"])
output_tokens: list[str] = [
tokenizer.convert_tokens_to_string([token]) for token in output
]
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser("holo2")(
tokenizer,
chat_template_kwargs=chat_template_kwargs,
)
reasoning, content = run_reasoning_extraction(
parser, output_tokens, streaming=streaming
)
assert reasoning == param_dict["reasoning"]
assert content == param_dict["content"]
...@@ -5,6 +5,10 @@ ...@@ -5,6 +5,10 @@
set -e set -e
set -x set -x
merge_base_commit=$(git merge-base HEAD origin/main)
echo "Current merge base commit with main: $merge_base_commit"
git show --oneline -s $merge_base_commit
cd /vllm-workspace/ cd /vllm-workspace/
# uninstall vllm # uninstall vllm
...@@ -18,7 +22,7 @@ apt autoremove -y ...@@ -18,7 +22,7 @@ apt autoremove -y
echo 'import os; os.system("touch /tmp/changed.file")' >> vllm/__init__.py echo 'import os; os.system("touch /tmp/changed.file")' >> vllm/__init__.py
VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL=1 VLLM_USE_PRECOMPILED=1 pip3 install -vvv -e . VLLM_PRECOMPILED_WHEEL_COMMIT=$merge_base_commit VLLM_USE_PRECOMPILED=1 pip3 install -vvv -e .
# Run the script # Run the script
python3 -c 'import vllm' python3 -c 'import vllm'
......
...@@ -97,7 +97,7 @@ def test_update_config(): ...@@ -97,7 +97,7 @@ def test_update_config():
("intfloat/multilingual-e5-small", "pooling", "none", "embed"), ("intfloat/multilingual-e5-small", "pooling", "none", "embed"),
("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify", "classify"), ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify", "classify"),
("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "none", "classify"), ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "none", "classify"),
("Qwen/Qwen2.5-Math-RM-72B", "pooling", "none", "reward"), ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "none", "embed"),
("openai/whisper-small", "generate", "none", "transcription"), ("openai/whisper-small", "generate", "none", "transcription"),
], ],
) )
...@@ -629,8 +629,8 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files): ...@@ -629,8 +629,8 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files):
( (
"internlm/internlm2-1_8b-reward", "internlm/internlm2-1_8b-reward",
"decoder", "decoder",
False, True,
"Pooling models with all pooling does not support chunked prefill.", "Pooling models with causal attn and all pooling support chunked prefill.",
), ),
( (
"BAAI/bge-base-en", "BAAI/bge-base-en",
...@@ -748,8 +748,8 @@ def test_is_chunked_prefill_supported( ...@@ -748,8 +748,8 @@ def test_is_chunked_prefill_supported(
( (
"internlm/internlm2-1_8b-reward", "internlm/internlm2-1_8b-reward",
"decoder", "decoder",
False, True,
"Pooling models with all pooling does not support prefix caching.", "Pooling models with causal attn and all pooling support prefix caching.",
), ),
( (
"BAAI/bge-base-en", "BAAI/bge-base-en",
......
...@@ -365,3 +365,54 @@ class TestEnvSetWithChoices: ...@@ -365,3 +365,54 @@ class TestEnvSetWithChoices:
with patch.dict(os.environ, {"TEST_ENV": "option1,option1,option2"}): with patch.dict(os.environ, {"TEST_ENV": "option1,option1,option2"}):
env_func = env_set_with_choices("TEST_ENV", [], ["option1", "option2"]) env_func = env_set_with_choices("TEST_ENV", [], ["option1", "option2"])
assert env_func() == {"option1", "option2"} assert env_func() == {"option1", "option2"}
class TestVllmConfigureLogging:
"""Test cases for VLLM_CONFIGURE_LOGGING environment variable."""
def test_configure_logging_defaults_to_true(self):
"""Test that VLLM_CONFIGURE_LOGGING defaults to True when not set."""
# Ensure the env var is not set
with patch.dict(os.environ, {}, clear=False):
if "VLLM_CONFIGURE_LOGGING" in os.environ:
del os.environ["VLLM_CONFIGURE_LOGGING"]
# Clear cache if it exists
if hasattr(envs.__getattr__, "cache_clear"):
envs.__getattr__.cache_clear()
result = envs.VLLM_CONFIGURE_LOGGING
assert result is True
assert isinstance(result, bool)
def test_configure_logging_with_zero_string(self):
"""Test that VLLM_CONFIGURE_LOGGING='0' evaluates to False."""
with patch.dict(os.environ, {"VLLM_CONFIGURE_LOGGING": "0"}):
# Clear cache if it exists
if hasattr(envs.__getattr__, "cache_clear"):
envs.__getattr__.cache_clear()
result = envs.VLLM_CONFIGURE_LOGGING
assert result is False
assert isinstance(result, bool)
def test_configure_logging_with_one_string(self):
"""Test that VLLM_CONFIGURE_LOGGING='1' evaluates to True."""
with patch.dict(os.environ, {"VLLM_CONFIGURE_LOGGING": "1"}):
# Clear cache if it exists
if hasattr(envs.__getattr__, "cache_clear"):
envs.__getattr__.cache_clear()
result = envs.VLLM_CONFIGURE_LOGGING
assert result is True
assert isinstance(result, bool)
def test_configure_logging_with_invalid_value_raises_error(self):
"""Test that invalid VLLM_CONFIGURE_LOGGING value raises ValueError."""
with patch.dict(os.environ, {"VLLM_CONFIGURE_LOGGING": "invalid"}):
# Clear cache if it exists
if hasattr(envs.__getattr__, "cache_clear"):
envs.__getattr__.cache_clear()
with pytest.raises(ValueError, match="invalid literal for int"):
_ = envs.VLLM_CONFIGURE_LOGGING
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Generator
import partial_json_parser
import pytest
from mistral_common.protocol.instruct.messages import AssistantMessage
from mistral_common.protocol.instruct.request import InstructRequest
from mistral_common.protocol.instruct.tool_calls import FunctionCall, ToolCall
from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.protocol import DeltaMessage, DeltaToolCall
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolParser
from vllm.tokenizers import (
MistralTokenizer,
TokenizerLike,
get_tokenizer,
)
from vllm.tokenizers.detokenizer_utils import detokenize_incrementally
@pytest.fixture(scope="module")
def mistral_pre_v11_tokenizer():
MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
return get_tokenizer(tokenizer_name=MODEL)
@pytest.fixture(scope="module")
def mistral_tokenizer():
MODEL = "mistralai/Mistral-Small-3.2-24B-Instruct-2506"
return get_tokenizer(tokenizer_name=MODEL, tokenizer_mode="mistral")
@pytest.fixture
def mistral_pre_v11_tool_parser(mistral_pre_v11_tokenizer):
return MistralToolParser(mistral_pre_v11_tokenizer)
@pytest.fixture
def mistral_tool_parser(mistral_tokenizer):
return MistralToolParser(mistral_tokenizer)
def assert_tool_calls(
actual_tool_calls: list[ToolCall] | list[DeltaToolCall],
expected_tool_calls: list[ToolCall],
):
assert len(actual_tool_calls) == len(expected_tool_calls)
for actual_tool_call, expected_tool_call in zip(
actual_tool_calls, expected_tool_calls
):
assert isinstance(actual_tool_call.id, str)
assert len(actual_tool_call.id) == 9
if isinstance(actual_tool_call, ToolCall):
assert actual_tool_call.type == "function"
elif isinstance(actual_tool_call, DeltaToolCall):
assert actual_tool_call.function is not None
assert actual_tool_call.function.name is not None
assert actual_tool_call.function.arguments is not None
assert actual_tool_call.function is not None
assert actual_tool_call.function.name == expected_tool_call.function.name, (
f"got wrong function name:${actual_tool_call.function.name}"
)
assert (
actual_tool_call.function.arguments == expected_tool_call.function.arguments
), f"got wrong function argument:${actual_tool_call.function.arguments}"
def fix_tool_call_tokenization(
tokens: list[int],
mistral_tool_parser: MistralToolParser,
mistral_tokenizer: TokenizerLike,
):
"""
Replaces the textual token sequence for [TOOL_CALLS]
with its single special token ID.
"""
textual_tool_call_token_ids = mistral_tokenizer.encode(
text=mistral_tool_parser.bot_token,
add_special_tokens=False,
)
# textual_tool_call_token_ids must not contain special tokens like bos, eos etc
special_tool_call_token_ids = [mistral_tool_parser.bot_token_id]
# If the input is too short to contain the sequence, no replacement is possible
if not tokens or len(tokens) < len(textual_tool_call_token_ids):
return tokens
result_tokens = []
i = 0
target_len = len(textual_tool_call_token_ids)
while i < len(tokens):
# Check if the slice from the current position matches the target sequence
if tokens[i : i + target_len] == textual_tool_call_token_ids:
# If it matches, add the replacement and jump the index forward
result_tokens.extend(special_tool_call_token_ids)
i += target_len
else:
# Otherwise, just add the current token and move to the next one
result_tokens.append(tokens[i])
i += 1
return result_tokens
def stream_delta_message_generator(
mistral_tool_parser: MistralToolParser,
mistral_tokenizer: TokenizerLike,
model_output: str | None,
tools: list[tuple[str, str]] | None,
) -> Generator[DeltaMessage, None, None]:
if (
isinstance(mistral_tokenizer, MistralTokenizer)
and mistral_tokenizer.version >= 11
):
# With the newer versions of the tokenizer,
# we cannot tokenize free text
# so we need to create a list of messages to get tokenized
assert tools is not None
assistant_msg = AssistantMessage(
tool_calls=[
ToolCall(
function=FunctionCall(
name=name,
arguments=arg,
)
)
for (name, arg) in tools
],
)
request = InstructRequest(
messages=[assistant_msg],
)
all_token_ids = mistral_tokenizer.instruct.encode_instruct(request).tokens
else:
# Older versions of the tokenizer are
# able to encode directly the model's output (free text) into tokens
assert model_output is not None
all_token_ids = mistral_tokenizer.encode(model_output, add_special_tokens=False)
all_token_ids = fix_tool_call_tokenization(
all_token_ids, mistral_tool_parser, mistral_tokenizer
)
previous_text = ""
previous_tokens = None
prefix_offset = 0
read_offset = 0
for i, delta_token in enumerate(all_token_ids):
delta_token_ids = [delta_token]
previous_token_ids = all_token_ids[:i]
current_token_ids = all_token_ids[: i + 1]
(new_tokens, delta_text, new_prefix_offset, new_read_offset) = (
detokenize_incrementally(
tokenizer=mistral_tokenizer,
all_input_ids=current_token_ids,
prev_tokens=previous_tokens,
prefix_offset=prefix_offset,
read_offset=read_offset,
skip_special_tokens=isinstance(mistral_tokenizer, MistralTokenizer),
spaces_between_special_tokens=True,
)
)
current_text = previous_text + delta_text
delta_message = mistral_tool_parser.extract_tool_calls_streaming(
previous_text,
current_text,
delta_text,
previous_token_ids,
current_token_ids,
delta_token_ids,
request=None, # type: ignore[arg-type]
)
if delta_message:
yield delta_message
previous_text = current_text
previous_tokens = (
previous_tokens + new_tokens if previous_tokens else new_tokens
)
prefix_offset = new_prefix_offset
read_offset = new_read_offset
def test_extract_tool_calls_no_tools(mistral_pre_v11_tool_parser):
model_output = "This is a test"
extracted_tool_calls = mistral_pre_v11_tool_parser.extract_tool_calls(
model_output, request=None
) # type: ignore[arg-type]
assert not extracted_tool_calls.tools_called
assert extracted_tool_calls.tool_calls == []
assert extracted_tool_calls.content == model_output
@pytest.mark.parametrize(
ids=[
"single_tool_add",
"single_tool_weather",
"argument_before_name",
"argument_before_name_and_name_in_argument",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
(
"""[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}]""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="add", arguments=json.dumps({"a": 3.5, "b": 4})
)
)
],
None,
),
(
"""[TOOL_CALLS] [{"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
),
)
)
],
None,
),
(
"""[TOOL_CALLS] [{"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
),
)
)
],
None,
),
(
"""[TOOL_CALLS] [{"arguments":{"name": "John Doe"}, "name": "get_age"}]""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="get_age",
arguments=json.dumps(
{
"name": "John Doe",
}
),
)
)
],
None,
),
],
)
def test_extract_tool_calls_pre_v11_tokenizer(
mistral_pre_v11_tool_parser, model_output, expected_tool_calls, expected_content
):
extracted_tool_calls = mistral_pre_v11_tool_parser.extract_tool_calls(
model_output, request=None
) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
assert extracted_tool_calls.content == expected_content
@pytest.mark.parametrize(
ids=[
"single_tool_add",
"single_tool_weather",
"multiple_tool_calls",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
(
"""[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="add_this_and_that",
arguments=json.dumps({"a": 3.5, "b": 4}),
)
)
],
None,
),
(
"""[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"}""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
),
)
)
],
None,
),
(
"""[TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]multiply{"a": 3, "b": 6}""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="add", arguments=json.dumps({"a": 3.5, "b": 4})
)
),
ToolCall(
function=FunctionCall(
name="multiply", arguments=json.dumps({"a": 3, "b": 6})
)
),
],
None,
),
],
)
def test_extract_tool_calls(
mistral_tool_parser, model_output, expected_tool_calls, expected_content
):
extracted_tool_calls = mistral_tool_parser.extract_tool_calls(
model_output, request=None
) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
assert extracted_tool_calls.content == expected_content
def _test_extract_tool_calls_streaming(
tool_parser, tokenizer, model_output, tools, expected_tool_calls, expected_content
):
other_content: str = ""
function_names: list[str] = []
function_args_strs: list[str] = []
tool_call_idx: int = -1
tool_call_ids: list[str | None] = []
for delta_message in stream_delta_message_generator(
tool_parser, tokenizer, model_output, tools
):
# role should never be streamed from tool parser
assert not delta_message.role
if delta_message.content:
other_content += delta_message.content
streamed_tool_calls = delta_message.tool_calls
if streamed_tool_calls and len(streamed_tool_calls) > 0:
# make sure only one diff is present - correct even for parallel
assert len(streamed_tool_calls) == 1
tool_call = streamed_tool_calls[0]
assert len(tool_parser.prev_tool_call_arr) > 0
# if a new tool is being called, set up empty arguments
if tool_call.index != tool_call_idx:
tool_call_idx = tool_call.index
function_args_strs.append("")
tool_call_ids.append(None)
# if a tool call ID is streamed, make sure one hasn't been already
if tool_call.id and not tool_call_ids[tool_call.index]:
tool_call_ids[tool_call.index] = tool_call.id
# if parts of the function start being streamed
if tool_call.function:
# if the function name is defined, set it. it should be streamed
# IN ENTIRETY, exactly one time.
if tool_call.function.name:
assert isinstance(tool_call.function.name, str)
function_names.append(tool_call.function.name)
if tool_call.function.arguments:
# make sure they're a string and then add them to the list
assert isinstance(tool_call.function.arguments, str)
function_args_strs[tool_call.index] += tool_call.function.arguments
assert other_content == expected_content
actual_tool_calls = [
ToolCall(
id=tool_call_id,
function=FunctionCall(
name=function_name,
arguments=partial_json_parser.ensure_json(
function_args_str, Allow.OBJ | Allow.STR
),
),
)
for tool_call_id, function_name, function_args_str in zip(
tool_call_ids, function_names, function_args_strs
)
]
assert_tool_calls(actual_tool_calls, expected_tool_calls)
@pytest.mark.parametrize(
ids=[
"no_tools",
"single_tool_add",
"single_tool_add_strings",
"single_tool_weather",
"argument_before_name",
"argument_before_name_and_name_in_argument",
"multiple_tools",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
("""This is a test""", [], """This is a test"""),
(
"""[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="add", arguments=json.dumps({"a": 3, "b": 4})
)
)
],
"",
),
(
"""[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="add", arguments=json.dumps({"a": "3", "b": "4"})
)
)
],
"",
),
(
"""[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
),
)
)
],
"",
),
(
"""[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
),
)
)
],
"",
),
(
"""[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="get_age",
arguments=json.dumps(
{
"name": "John Doe",
}
),
)
)
],
"",
),
(
"""[TOOL_CALLS] [{"name": "add", "arguments": {"a": 3.5, "b": 4}}, {"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="add", arguments=json.dumps({"a": 3.5, "b": 4})
)
),
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
),
)
),
],
"",
),
],
)
def test_extract_tool_calls_streaming_pre_v11_tokenizer(
mistral_pre_v11_tool_parser,
mistral_pre_v11_tokenizer,
model_output,
expected_tool_calls,
expected_content,
):
_test_extract_tool_calls_streaming(
mistral_pre_v11_tool_parser,
mistral_pre_v11_tokenizer,
model_output,
None,
expected_tool_calls,
expected_content,
)
@pytest.mark.parametrize(
ids=[
"single_tool_add",
"single_tool_add_strings",
"multiple_tools",
],
argnames=["tools", "expected_tool_calls", "expected_content"],
argvalues=[
(
[("add", '{"a": 3, "b": 4}')],
# [TOOL_CALLS]add{"a": 3, "b": 4}
[
ToolCall(
function=FunctionCall(
name="add", arguments=json.dumps({"a": 3, "b": 4})
)
)
],
"",
),
(
[("add_two_strings", '{"a": "3", "b": "4"}')],
# [TOOL_CALLS]add_two_strings{"a": "3", "b": "4"}
[
ToolCall(
function=FunctionCall(
name="add_two_strings",
arguments=json.dumps({"a": "3", "b": "4"}),
)
)
],
"",
),
(
[
("add", '{"a": 3.5, "b": 4}'),
(
"get_current_weather",
'{"city": "San Francisco", "state": "CA", "unit": "celsius"}', # noqa: E501
),
],
# [TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"} # noqa: E501
[
ToolCall(
function=FunctionCall(
name="add", arguments=json.dumps({"a": 3.5, "b": 4})
)
),
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
),
)
),
],
"",
),
],
)
def test_extract_tool_calls_streaming(
mistral_tool_parser,
mistral_tokenizer,
tools,
expected_tool_calls,
expected_content,
):
_test_extract_tool_calls_streaming(
mistral_tool_parser,
mistral_tokenizer,
None,
tools,
expected_tool_calls,
expected_content,
)
@pytest.mark.parametrize(
ids=[
"single_tool_add",
"single_tool_weather",
"multiple_tool_calls",
"content_before_tool",
"complex",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
(
"""[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="add_this_and_that",
arguments=json.dumps({"a": 3.5, "b": 4}),
)
)
],
"",
),
(
"""[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"}""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
),
)
)
],
"",
),
(
"""[TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]multiply{"a": 3, "b": 6}""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="add", arguments=json.dumps({"a": 3.5, "b": 4})
)
),
ToolCall(
function=FunctionCall(
name="multiply", arguments=json.dumps({"a": 3, "b": 6})
)
),
],
"",
),
(
# Additional content should not be after the tool calls
"""bla[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="add_this_and_that",
arguments=json.dumps({"a": 3.5, "b": 4}),
)
)
],
"bla",
),
(
# Complex
"""[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="bash",
arguments=json.dumps(
{"command": "print(\"hello world!\")\nre.compile(r'{}')"}
),
)
)
],
"",
),
],
)
def test_extract_tool_calls_streaming_one_chunk(
mistral_tool_parser,
mistral_tokenizer,
model_output,
expected_tool_calls,
expected_content,
):
if isinstance(mistral_tokenizer, MistralTokenizer):
all_token_ids = mistral_tokenizer.encode(model_output)
else:
all_token_ids = mistral_tokenizer.encode(model_output, add_special_tokens=False)
all_token_ids = fix_tool_call_tokenization(
all_token_ids, mistral_tool_parser, mistral_tokenizer
)
delta_message = mistral_tool_parser.extract_tool_calls_streaming(
previous_text="",
current_text=model_output,
delta_text=model_output,
previous_token_ids=[],
current_token_ids=all_token_ids,
delta_token_ids=all_token_ids,
request=None,
) # type: ignore[arg-type]
assert isinstance(delta_message, DeltaMessage)
assert len(delta_message.tool_calls) == len(expected_tool_calls)
assert_tool_calls(delta_message.tool_calls, expected_tool_calls)
if delta_message.content is None:
assert expected_content == ""
else:
assert delta_message.content == expected_content
@pytest.mark.parametrize(
ids=[
"no_tools",
"single_tool_add",
"single_tool_add_strings",
"single_tool_weather",
"argument_before_name",
"argument_before_name_and_name_in_argument",
"multiple_tools",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
("""This is a test""", [], """This is a test"""),
(
"""[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="add", arguments=json.dumps({"a": 3, "b": 4})
)
)
],
"",
),
(
"""[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="add", arguments=json.dumps({"a": "3", "b": "4"})
)
)
],
"",
),
(
"""[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
),
)
)
],
"",
),
(
"""[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
),
)
)
],
"",
),
(
"""[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="get_age",
arguments=json.dumps(
{
"name": "John Doe",
}
),
)
)
],
"",
),
(
"""[TOOL_CALLS] [{"arguments": {"a": 3.5, "b": 4}, "name": "add"}, {"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="add", arguments=json.dumps({"a": 3.5, "b": 4})
)
),
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
),
)
),
],
"",
),
],
)
def test_extract_tool_calls_streaming_pre_v11_tokenizer_one_chunk(
mistral_pre_v11_tool_parser,
mistral_pre_v11_tokenizer,
model_output,
expected_tool_calls,
expected_content,
):
if isinstance(mistral_pre_v11_tokenizer, MistralTokenizer):
all_token_ids = mistral_pre_v11_tokenizer.encode(model_output)
else:
all_token_ids = mistral_pre_v11_tokenizer.encode(
model_output, add_special_tokens=False
)
all_token_ids = fix_tool_call_tokenization(
all_token_ids, mistral_pre_v11_tool_parser, mistral_pre_v11_tokenizer
)
delta_message = mistral_pre_v11_tool_parser.extract_tool_calls_streaming(
previous_text="",
current_text=model_output,
delta_text=model_output,
previous_token_ids=[],
current_token_ids=all_token_ids,
delta_token_ids=all_token_ids,
request=None,
) # type: ignore[arg-type]
assert isinstance(delta_message, DeltaMessage)
assert len(delta_message.tool_calls) == len(expected_tool_calls)
assert_tool_calls(delta_message.tool_calls, expected_tool_calls)
if delta_message.content is None:
assert expected_content == ""
else:
assert delta_message.content == expected_content
...@@ -123,7 +123,7 @@ CONFIGS: dict[str, ServerConfig] = { ...@@ -123,7 +123,7 @@ CONFIGS: dict[str, ServerConfig] = {
"supports_parallel": True, "supports_parallel": True,
"extended": True, "extended": True,
}, },
"mistral": { "mistral-7b": {
"model": "mistralai/Mistral-7B-Instruct-v0.3", "model": "mistralai/Mistral-7B-Instruct-v0.3",
"arguments": [ "arguments": [
"--enforce-eager", "--enforce-eager",
...@@ -145,6 +145,32 @@ CONFIGS: dict[str, ServerConfig] = { ...@@ -145,6 +145,32 @@ CONFIGS: dict[str, ServerConfig] = {
"call the tool. Otherwise, answer the user's query directly " "call the tool. Otherwise, answer the user's query directly "
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
"to the user's question - just respond to it normally.", "to the user's question - just respond to it normally.",
"supports_parallel": True,
},
"mistral-small-3.2": {
"model": "mistralai/Mistral-Small-3.2-24B-Instruct-2506",
"arguments": [
"--enforce-eager",
"--no-enable-prefix-caching",
"--tool-call-parser",
"mistral",
"--tokenizer-mode",
"mistral",
"--config-format",
"mistral",
"--load-format",
"mistral",
"--tensor-parallel-size",
"4",
'--ignore-patterns="consolidated.safetensors"',
],
"system_prompt": "You are a helpful assistant with access to tools. If a tool"
" that you have would be helpful to answer a user query, "
"call the tool. Otherwise, answer the user's query directly "
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
"to the user's question - just respond to it normally.",
"supports_parallel": True,
"extended": True,
}, },
# FIXME: This test currently fails, need to debug why. # FIXME: This test currently fails, need to debug why.
# "granite20b": { # "granite20b": {
......
...@@ -5,13 +5,15 @@ from unittest.mock import patch ...@@ -5,13 +5,15 @@ from unittest.mock import patch
import pytest import pytest
from vllm.transformers_utils.gguf_utils import (
is_gguf,
is_remote_gguf,
split_remote_gguf,
)
from vllm.transformers_utils.utils import ( from vllm.transformers_utils.utils import (
is_cloud_storage, is_cloud_storage,
is_gcs, is_gcs,
is_gguf,
is_remote_gguf,
is_s3, is_s3,
split_remote_gguf,
) )
...@@ -132,7 +134,7 @@ class TestSplitRemoteGGUF: ...@@ -132,7 +134,7 @@ class TestSplitRemoteGGUF:
class TestIsGGUF: class TestIsGGUF:
"""Test is_gguf utility function.""" """Test is_gguf utility function."""
@patch("vllm.transformers_utils.utils.check_gguf_file", return_value=True) @patch("vllm.transformers_utils.gguf_utils.check_gguf_file", return_value=True)
def test_is_gguf_with_local_file(self, mock_check_gguf): def test_is_gguf_with_local_file(self, mock_check_gguf):
"""Test is_gguf with local GGUF file.""" """Test is_gguf with local GGUF file."""
assert is_gguf("/path/to/model.gguf") assert is_gguf("/path/to/model.gguf")
...@@ -149,7 +151,7 @@ class TestIsGGUF: ...@@ -149,7 +151,7 @@ class TestIsGGUF:
assert not is_gguf("repo/model:quant") assert not is_gguf("repo/model:quant")
assert not is_gguf("repo/model:INVALID") assert not is_gguf("repo/model:INVALID")
@patch("vllm.transformers_utils.utils.check_gguf_file", return_value=False) @patch("vllm.transformers_utils.gguf_utils.check_gguf_file", return_value=False)
def test_is_gguf_false(self, mock_check_gguf): def test_is_gguf_false(self, mock_check_gguf):
"""Test is_gguf returns False for non-GGUF models.""" """Test is_gguf returns False for non-GGUF models."""
assert not is_gguf("unsloth/Qwen3-0.6B") assert not is_gguf("unsloth/Qwen3-0.6B")
......
...@@ -1225,9 +1225,9 @@ def get_attn_backend_list_based_on_platform() -> list[str]: ...@@ -1225,9 +1225,9 @@ def get_attn_backend_list_based_on_platform() -> list[str]:
try: try:
import aiter # noqa: F401 import aiter # noqa: F401
attn_backend_list.append("FLASH_ATTN") attn_backend_list.append("ROCM_AITER_FA")
except Exception: except Exception:
print("Skip FLASH_ATTN on ROCm as aiter is not installed") print("Skip ROCM_AITER_FA on ROCm as aiter is not installed")
return attn_backend_list return attn_backend_list
elif current_platform.is_xpu(): elif current_platform.is_xpu():
......
...@@ -458,25 +458,3 @@ def test_flat_product(): ...@@ -458,25 +458,3 @@ def test_flat_product():
(3, 4, "a", 5, 6), (3, 4, "a", 5, 6),
(3, 4, "b", 5, 6), (3, 4, "b", 5, 6),
] ]
def test_o_legacy_syntax_deprecation(caplog_vllm):
"""Test that -O.* dotted syntax emits warnings and converts correctly to -cc syntax."""
parser = FlexibleArgumentParser()
parser.add_argument("-cc", "--compilation-config", type=json.loads)
# Test that -O.backend gets converted correctly AND emits warning
args = parser.parse_args(["-O.backend=eager"])
assert args.compilation_config == {"backend": "eager"}
# Check that deprecation warning was logged
assert len(caplog_vllm.records) >= 1
assert (
"The -O.* dotted syntax for --compilation-config is deprecated"
in caplog_vllm.text
)
# Test that -O.mode gets converted correctly
# Note: warning_once won't emit again in same session
args = parser.parse_args(["-O.mode=2"])
assert args.compilation_config == {"mode": 2}
...@@ -13,7 +13,7 @@ from vllm.v1.attention.backends.utils import ( ...@@ -13,7 +13,7 @@ from vllm.v1.attention.backends.utils import (
split_attn_metadata, split_attn_metadata,
split_decodes_and_prefills, split_decodes_and_prefills,
) )
from vllm.v1.worker.ubatch_utils import create_ubatch_slices from vllm.v1.worker.ubatch_utils import maybe_create_ubatch_slices
@pytest.fixture @pytest.fixture
...@@ -154,7 +154,10 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata): ...@@ -154,7 +154,10 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
def apply_split_decodes_and_prefills( def apply_split_decodes_and_prefills(
query_lens: list[int], decode_threshold: int, require_uniform: bool query_lens: list[int],
decode_threshold: int,
require_uniform: bool,
padded_num_tokens: int | None = None,
): ):
"""Helper function to apply split_decodes_and_prefills and return """Helper function to apply split_decodes_and_prefills and return
the results.""" the results."""
...@@ -165,6 +168,10 @@ def apply_split_decodes_and_prefills( ...@@ -165,6 +168,10 @@ def apply_split_decodes_and_prefills(
block_size=16, block_size=16,
device=device, device=device,
) )
if padded_num_tokens is not None:
common_metadata.num_actual_tokens = padded_num_tokens
return split_decodes_and_prefills( return split_decodes_and_prefills(
common_metadata, common_metadata,
decode_threshold=decode_threshold, decode_threshold=decode_threshold,
...@@ -271,6 +278,22 @@ def test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes(): ...@@ -271,6 +278,22 @@ def test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes():
assert num_prefill_tokens == (sum(query_lens) - 2) # rest of the tokens assert num_prefill_tokens == (sum(query_lens) - 2) # rest of the tokens
def test_split_decodes_and_prefills_uniform_padded_batch_all_same():
"""uniform batch where all query lengths are identical with 0 length padded reqs."""
# All query lengths are 2, with decode_threshold=3 (so 2 <= 3)
# This triggers the padded uniform path at line 891
query_lens = [2, 2, 2, 0]
padded_num_tokens = 8
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 3, True, padded_num_tokens)
)
# With uniform batch, all requests are treated as decodes
assert num_decodes == 4
assert num_prefills == 0
assert num_decode_tokens == padded_num_tokens
assert num_prefill_tokens == 0
@pytest.mark.parametrize( @pytest.mark.parametrize(
"seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs", "seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs",
[ [
...@@ -294,8 +317,14 @@ def test_prefill_split_across_ubatches( ...@@ -294,8 +317,14 @@ def test_prefill_split_across_ubatches(
qsl_np = common.query_start_loc_cpu.numpy() qsl_np = common.query_start_loc_cpu.numpy()
num_tokens = common.num_actual_tokens num_tokens = common.num_actual_tokens
ubatch_slices = create_ubatch_slices(num_scheduled_tokens, split_point) ubatch_slices, _ = maybe_create_ubatch_slices(
assert len(ubatch_slices) == 2 True,
num_scheduled_tokens,
num_tokens,
batch_spec.batch_size,
split_point=split_point,
)
assert ubatch_slices is not None and len(ubatch_slices) == 2
first_meta = _make_metadata_with_slice(ubatch_slices[0], common) first_meta = _make_metadata_with_slice(ubatch_slices[0], common)
second_meta = _make_metadata_with_slice(ubatch_slices[1], common) second_meta = _make_metadata_with_slice(ubatch_slices[1], common)
......
...@@ -106,8 +106,8 @@ def create_common_attn_metadata( ...@@ -106,8 +106,8 @@ def create_common_attn_metadata(
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu, query_start_loc_cpu=query_start_loc_cpu,
seq_lens=seq_lens, seq_lens=seq_lens,
seq_lens_cpu=seq_lens_cpu, _seq_lens_cpu=seq_lens_cpu,
num_computed_tokens_cpu=num_computed_tokens_cpu, _num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=batch_spec.batch_size, num_reqs=batch_spec.batch_size,
num_actual_tokens=num_tokens, num_actual_tokens=num_tokens,
max_query_len=max_query_len, max_query_len=max_query_len,
......
...@@ -11,7 +11,9 @@ PROMPTS = [ ...@@ -11,7 +11,9 @@ PROMPTS = [
] ]
def test_reset_prefix_cache_e2e(): def test_reset_prefix_cache_e2e(monkeypatch):
# "spawn" is required for test to be deterministic
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
engine_args = EngineArgs( engine_args = EngineArgs(
model="Qwen/Qwen3-0.6B", model="Qwen/Qwen3-0.6B",
gpu_memory_utilization=0.2, gpu_memory_utilization=0.2,
...@@ -19,6 +21,7 @@ def test_reset_prefix_cache_e2e(): ...@@ -19,6 +21,7 @@ def test_reset_prefix_cache_e2e():
max_num_batched_tokens=32, max_num_batched_tokens=32,
max_model_len=2048, max_model_len=2048,
compilation_config={"mode": 0}, compilation_config={"mode": 0},
dtype="float16",
) )
engine = LLMEngine.from_engine_args(engine_args) engine = LLMEngine.from_engine_args(engine_args)
sampling_params = SamplingParams( sampling_params = SamplingParams(
......
...@@ -1536,7 +1536,7 @@ def create_scheduler_with_priority( ...@@ -1536,7 +1536,7 @@ def create_scheduler_with_priority(
) )
kv_transfer_config = ( kv_transfer_config = (
KVTransferConfig( KVTransferConfig(
kv_connector="SharedStorageConnector", kv_connector="ExampleConnector",
kv_role="kv_both", kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": "local_storage"}, kv_connector_extra_config={"shared_storage_path": "local_storage"},
) )
...@@ -1552,7 +1552,7 @@ def create_scheduler_with_priority( ...@@ -1552,7 +1552,7 @@ def create_scheduler_with_priority(
ec_transfer_config = ( ec_transfer_config = (
ECTransferConfig( ECTransferConfig(
ec_connector="ECSharedStorageConnector", ec_connector="ECExampleConnector",
ec_role=ec_role, ec_role=ec_role,
ec_connector_extra_config={"shared_storage_path": "/tmp/ec_test"}, ec_connector_extra_config={"shared_storage_path": "/tmp/ec_test"},
) )
...@@ -2413,7 +2413,7 @@ def _assert_right_ec_connector_metadata( ...@@ -2413,7 +2413,7 @@ def _assert_right_ec_connector_metadata(
metadata_dict = {mm_data.mm_hash: mm_data for mm_data in metadata.mm_datas} metadata_dict = {mm_data.mm_hash: mm_data for mm_data in metadata.mm_datas}
# Check all required identifiers exist in metadata; and no extra # Check all required identifiers exist in metadata; and no extra
# In ECSharedStorageConnector format # In ECExampleConnector format
# NOTE: even having same identifier, the mm_features can be different # NOTE: even having same identifier, the mm_features can be different
# since their mm_position can be in different offsets, etc # since their mm_position can be in different offsets, etc
identifiers_dict = {f.identifier for f in mm_features_list} identifiers_dict = {f.identifier for f in mm_features_list}
......
...@@ -108,7 +108,7 @@ def create_scheduler( ...@@ -108,7 +108,7 @@ def create_scheduler(
) )
elif use_kv_connector: elif use_kv_connector:
kv_transfer_config = KVTransferConfig( kv_transfer_config = KVTransferConfig(
kv_connector="SharedStorageConnector", kv_connector="ExampleConnector",
kv_role="kv_both", kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": "local_storage"}, kv_connector_extra_config={"shared_storage_path": "local_storage"},
) )
...@@ -121,7 +121,7 @@ def create_scheduler( ...@@ -121,7 +121,7 @@ def create_scheduler(
ec_transfer_config = ( ec_transfer_config = (
ECTransferConfig( ECTransferConfig(
ec_connector="ECSharedStorageConnector", ec_connector="ECExampleConnector",
ec_role=ec_role, ec_role=ec_role,
ec_connector_extra_config={"shared_storage_path": "/tmp/ec_test"}, ec_connector_extra_config={"shared_storage_path": "/tmp/ec_test"},
) )
......
...@@ -161,10 +161,10 @@ class TestCudagraphDispatcher: ...@@ -161,10 +161,10 @@ class TestCudagraphDispatcher:
assert rt_mode == CUDAGraphMode.NONE assert rt_mode == CUDAGraphMode.NONE
assert key == BatchDescriptor(num_tokens=15) assert key == BatchDescriptor(num_tokens=15)
# 4. Cascade attention should have a fall back mode # 4. disable_full should have a fall back mode (e.g., cascade attention)
desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False) desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False)
rt_mode, key = dispatcher.dispatch( rt_mode, key = dispatcher.dispatch(
num_tokens=8, uniform_decode=False, has_lora=False, use_cascade_attn=True num_tokens=8, uniform_decode=False, has_lora=False, disable_full=True
) )
if "PIECEWISE" in cudagraph_mode_str: # string contains check if "PIECEWISE" in cudagraph_mode_str: # string contains check
assert rt_mode == CUDAGraphMode.PIECEWISE assert rt_mode == CUDAGraphMode.PIECEWISE
......
...@@ -100,32 +100,20 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte ...@@ -100,32 +100,20 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
# test cudagraph_mode with different compilation mode. # test cudagraph_mode with different compilation mode.
# (backend_name, cudagraph_mode, compilation_mode, supported) # (backend_name, cudagraph_mode, compilation_mode, supported)
if current_platform.is_rocm(): attn_backend = "RocmAttn" if current_platform.is_rocm() else "FA2"
combo_cases_2 = [
("RocmAttn", "FULL", CompilationMode.NONE, True), combo_cases_2 = [
("RocmAttn", "FULL", CompilationMode.VLLM_COMPILE, True), (attn_backend, "FULL", CompilationMode.NONE, True),
("RocmAttn", "PIECEWISE", CompilationMode.NONE, False), (attn_backend, "FULL", CompilationMode.VLLM_COMPILE, True),
("RocmAttn", "PIECEWISE", CompilationMode.VLLM_COMPILE, True), (attn_backend, "PIECEWISE", CompilationMode.NONE, True),
("RocmAttn", "FULL_AND_PIECEWISE", CompilationMode.NONE, False), (attn_backend, "PIECEWISE", CompilationMode.VLLM_COMPILE, True),
("RocmAttn", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True), (attn_backend, "FULL_AND_PIECEWISE", CompilationMode.NONE, True),
("RocmAttn", "FULL_DECODE_ONLY", CompilationMode.NONE, True), (attn_backend, "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True),
("RocmAttn", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True), (attn_backend, "FULL_DECODE_ONLY", CompilationMode.NONE, True),
("RocmAttn", "NONE", CompilationMode.NONE, True), (attn_backend, "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True),
("RocmAttn", "NONE", CompilationMode.VLLM_COMPILE, True), (attn_backend, "NONE", CompilationMode.NONE, True),
] (attn_backend, "NONE", CompilationMode.VLLM_COMPILE, True),
else: ]
combo_cases_2 = [
("FA2", "FULL", CompilationMode.NONE, True),
("FA2", "FULL", CompilationMode.VLLM_COMPILE, True),
("FA2", "PIECEWISE", CompilationMode.NONE, True),
("FA2", "PIECEWISE", CompilationMode.VLLM_COMPILE, True),
("FA2", "FULL_AND_PIECEWISE", CompilationMode.NONE, True),
("FA2", "FULL_AND_PIECEWISE", CompilationMode.VLLM_COMPILE, True),
("FA2", "FULL_DECODE_ONLY", CompilationMode.NONE, True),
("FA2", "FULL_DECODE_ONLY", CompilationMode.VLLM_COMPILE, True),
("FA2", "NONE", CompilationMode.NONE, True),
("FA2", "NONE", CompilationMode.VLLM_COMPILE, True),
]
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
...@@ -10,6 +10,7 @@ from utils import ( ...@@ -10,6 +10,7 @@ from utils import (
BACKENDS, BACKENDS,
_extract_step_logprobs, _extract_step_logprobs,
_random_prompt, _random_prompt,
is_device_capability_below_90,
resolve_model_name, resolve_model_name,
skip_unsupported, skip_unsupported,
) )
...@@ -17,6 +18,8 @@ from utils import ( ...@@ -17,6 +18,8 @@ from utils import (
import vllm.model_executor.layers.batch_invariant as batch_invariant import vllm.model_executor.layers.batch_invariant as batch_invariant
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
IS_DEVICE_CAPABILITY_BELOW_90 = is_device_capability_below_90()
@skip_unsupported @skip_unsupported
@pytest.mark.timeout(1000) @pytest.mark.timeout(1000)
...@@ -185,11 +188,12 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( ...@@ -185,11 +188,12 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
llm = LLM( llm = LLM(
model=model_name, model=model_name,
tensor_parallel_size=tp_size, tensor_parallel_size=tp_size,
enable_prefix_caching=False, # enable_prefix_caching=False,
max_num_seqs=32, max_num_seqs=32,
max_model_len=8192, max_model_len=8192,
dtype="bfloat16", # not everything is supported dtype="bfloat16", # not everything is supported
gpu_memory_utilization=0.9, gpu_memory_utilization=0.9,
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
) )
# Use more realistic prompts for better token generation # Use more realistic prompts for better token generation
...@@ -394,6 +398,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch): ...@@ -394,6 +398,7 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
max_model_len=2048, max_model_len=2048,
dtype="bfloat16", dtype="bfloat16",
enable_prefix_caching=False, enable_prefix_caching=False,
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
) )
prompt = "the capital of france is" prompt = "the capital of france is"
...@@ -457,10 +462,10 @@ def test_logprobs_without_batch_invariance_should_fail( ...@@ -457,10 +462,10 @@ def test_logprobs_without_batch_invariance_should_fail(
llm = LLM( llm = LLM(
model=model_name, model=model_name,
tensor_parallel_size=tp_size, tensor_parallel_size=tp_size,
enable_prefix_caching=False,
max_num_seqs=32, max_num_seqs=32,
max_model_len=8192, max_model_len=8192,
dtype="bfloat16", dtype="bfloat16",
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
) )
# build ragged prompts to change shapes significantly across BS=1 vs BS=N # build ragged prompts to change shapes significantly across BS=1 vs BS=N
...@@ -681,10 +686,10 @@ def test_decode_logprobs_match_prefill_logprobs( ...@@ -681,10 +686,10 @@ def test_decode_logprobs_match_prefill_logprobs(
llm = LLM( llm = LLM(
model=model_name, model=model_name,
tensor_parallel_size=tp_size, tensor_parallel_size=tp_size,
enable_prefix_caching=False,
max_num_seqs=32, max_num_seqs=32,
max_model_len=8192, max_model_len=8192,
dtype="bfloat16", dtype="bfloat16",
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
) )
# Use a few test prompts # Use a few test prompts
...@@ -929,6 +934,7 @@ def LLM_with_max_seqs( ...@@ -929,6 +934,7 @@ def LLM_with_max_seqs(
dtype="bfloat16", dtype="bfloat16",
tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")), tensor_parallel_size=int(os.getenv("VLLM_TP_SIZE", "1")),
enable_prefix_caching=False, enable_prefix_caching=False,
enforce_eager=IS_DEVICE_CAPABILITY_BELOW_90,
# Enable for MOE models # Enable for MOE models
# enable_expert_parallel=True, # enable_expert_parallel=True,
) )
...@@ -153,7 +153,10 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN( ...@@ -153,7 +153,10 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
} }
tp_size = os.getenv("VLLM_TP_SIZE", "1") tp_size = os.getenv("VLLM_TP_SIZE", "1")
server_args: list[str] = [] server_args: list[str] = [
"--max-model-len=8192",
"--max-num-seqs=32",
]
if tp_size: if tp_size:
server_args += ["-tp", tp_size] server_args += ["-tp", tp_size]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment