Commit d2b52805 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.2rc1' into v0.10.2rc1-ori

parents 9a521c23 5438967f
......@@ -2,12 +2,17 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Generator
from typing import Optional
import pytest
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage, FunctionCall,
ToolCall)
from vllm.entrypoints.openai.tool_parsers import xLAMToolParser
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.transformers_utils.detokenizer import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
# Use a common model that is likely to be available
MODEL = "Salesforce/Llama-xLAM-2-8B-fc-r"
......@@ -36,6 +41,56 @@ def assert_tool_calls(actual_tool_calls: list[ToolCall],
assert actual_tool_call.function == expected_tool_call.function
def stream_delta_message_generator(
xlam_tool_parser: xLAMToolParser,
xlam_tokenizer: AnyTokenizer,
model_output: str,
request: Optional[ChatCompletionRequest] = None,
) -> Generator[DeltaMessage, None, None]:
all_token_ids = xlam_tokenizer.encode(model_output,
add_special_tokens=False)
previous_text = ""
previous_tokens = None
prefix_offset = 0
read_offset = 0
for i, delta_token in enumerate(all_token_ids):
delta_token_ids = [delta_token]
previous_token_ids = all_token_ids[:i]
current_token_ids = all_token_ids[:i + 1]
(new_tokens, delta_text, new_prefix_offset,
new_read_offset) = (detokenize_incrementally(
tokenizer=xlam_tokenizer,
all_input_ids=current_token_ids,
prev_tokens=previous_tokens,
prefix_offset=prefix_offset,
read_offset=read_offset,
skip_special_tokens=False,
spaces_between_special_tokens=True,
))
current_text = previous_text + delta_text
delta_message = xlam_tool_parser.extract_tool_calls_streaming(
previous_text,
current_text,
delta_text,
previous_token_ids,
current_token_ids,
delta_token_ids,
request=request,
)
if delta_message:
yield delta_message
previous_text = current_text
previous_tokens = (previous_tokens +
new_tokens if previous_tokens else new_tokens)
prefix_offset = new_prefix_offset
read_offset = new_read_offset
def test_extract_tool_calls_no_tools(xlam_tool_parser):
model_output = "This is a test"
extracted_tool_calls = xlam_tool_parser.extract_tool_calls(
......@@ -51,6 +106,7 @@ def test_extract_tool_calls_no_tools(xlam_tool_parser):
"single_tool_with_think_tag",
"single_tool_with_json_code_block",
"single_tool_with_tool_calls_tag",
"single_tool_with_tool_call_xml_tags",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
......@@ -118,6 +174,20 @@ def test_extract_tool_calls_no_tools(xlam_tool_parser):
],
"I'll check the weather for you.",
),
(
"""I'll help you check the weather.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""", # noqa: E501
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
))
],
"I'll help you check the weather.",
),
],
)
def test_extract_tool_calls(xlam_tool_parser, model_output,
......@@ -245,3 +315,147 @@ def test_streaming_with_list_structure(xlam_tool_parser):
assert hasattr(result, "tool_calls")
assert len(result.tool_calls) == 1
assert result.tool_calls[0].function.name == "get_current_weather"
@pytest.mark.parametrize(
ids=[
"parallel_tool_calls",
"single_tool_with_think_tag",
"single_tool_with_json_code_block",
"single_tool_with_tool_calls_tag",
"single_tool_with_tool_call_xml_tags",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
(
"""[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""", # noqa: E501
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
)),
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Orlando",
"state": "FL",
"unit": "fahrenheit",
}),
)),
],
"",
),
(
"""<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
))
],
"<think>I'll help you with that.</think>",
),
(
"""```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""", # noqa: E501
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
))
],
"",
),
(
"""[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
))
],
"",
),
(
"""I can help with that.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""", # noqa: E501
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
))
],
"I can help with that.",
),
],
)
def test_extract_tool_calls_streaming_incremental(
xlam_tool_parser,
xlam_tokenizer,
model_output,
expected_tool_calls,
expected_content,
):
"""Verify the XLAM Parser streaming behavior by verifying each chunk is as expected.""" # noqa: E501
request = ChatCompletionRequest(model=MODEL, messages=[], tools=[])
chunks = []
for delta_message in stream_delta_message_generator(
xlam_tool_parser, xlam_tokenizer, model_output, request):
chunks.append(delta_message)
# Should have multiple chunks
assert len(chunks) >= 3
# Should have a chunk with tool header (id, name, type) for the first tool call # noqa: E501
header_found = False
expected_first_tool = expected_tool_calls[0]
for chunk in chunks:
if chunk.tool_calls and chunk.tool_calls[0].id:
header_found = True
assert (chunk.tool_calls[0].function.name ==
expected_first_tool.function.name)
assert chunk.tool_calls[0].type == "function"
# Arguments may be empty initially or None
if chunk.tool_calls[0].function.arguments is not None:
# If present, should be empty string initially
assert chunk.tool_calls[0].function.arguments == ""
break
assert header_found
# Should have chunks with incremental arguments
arg_chunks = []
for chunk in chunks:
if (chunk.tool_calls and chunk.tool_calls[0].function.arguments
and chunk.tool_calls[0].function.arguments != ""
and chunk.tool_calls[0].index ==
0 # Only collect arguments from the first tool call
):
arg_chunks.append(chunk.tool_calls[0].function.arguments)
# Arguments should be streamed incrementally
assert len(arg_chunks) > 1
# Concatenated arguments should form valid JSON for the first tool call
full_args = "".join(arg_chunks)
parsed_args = json.loads(full_args)
expected_args = json.loads(expected_first_tool.function.arguments)
assert parsed_args == expected_args
......@@ -5,6 +5,7 @@ import asyncio
import copy
import functools
import importlib
import json
import os
import signal
import subprocess
......@@ -101,7 +102,8 @@ class RemoteOpenAIServer:
env_dict: Optional[dict[str, str]] = None,
seed: Optional[int] = 0,
auto_port: bool = True,
max_wait_seconds: Optional[float] = None) -> None:
max_wait_seconds: Optional[float] = None,
override_hf_configs: Optional[dict[str, Any]] = None) -> None:
if auto_port:
if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
raise ValueError("You have manually specified the port "
......@@ -120,6 +122,12 @@ class RemoteOpenAIServer:
vllm_serve_args = vllm_serve_args + ["--seed", str(seed)]
if override_hf_configs is not None:
vllm_serve_args = vllm_serve_args + [
"--hf-overrides",
json.dumps(override_hf_configs)
]
parser = FlexibleArgumentParser(
description="vLLM's remote OpenAI server.")
subparsers = parser.add_subparsers(required=False, dest="subparser")
......@@ -688,9 +696,12 @@ def multi_process_parallel(
os.environ["RAY_RUNTIME_ENV_IGNORE_GITIGNORE"] = "1"
ray.init(
runtime_env={
"working_dir": VLLM_PATH,
"excludes":
["build", ".git", "cmake-build-*", "shellcheck", "dist"]
"working_dir":
VLLM_PATH,
"excludes": [
"build", ".git", "cmake-build-*", "shellcheck", "dist",
"ep_kernels_workspace"
]
})
distributed_init_port = get_open_port()
......
......@@ -5,13 +5,17 @@
import asyncio
import hashlib
import json
import os
import pickle
import socket
import tempfile
from collections.abc import AsyncIterator
from pathlib import Path
from unittest.mock import patch
import pytest
import torch
import yaml
import zmq
from transformers import AutoTokenizer
from vllm_test_utils.monitor import monitor
......@@ -375,9 +379,9 @@ def test_duplicate_dict_args(caplog_vllm, parser):
def test_supports_kw(callable,kw_name,requires_kw_only,
allow_var_kwargs,is_supported):
assert supports_kw(
callable=callable,
kw_name=kw_name,
requires_kw_only=requires_kw_only,
callable=callable,
kw_name=kw_name,
requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs
) == is_supported
......@@ -944,6 +948,36 @@ def test_join_host_port():
assert join_host_port("::1", 5555) == "[::1]:5555"
def test_json_count_leaves():
"""Test json_count_leaves function from jsontree utility."""
from vllm.utils.jsontree import json_count_leaves
# Single leaf values
assert json_count_leaves(42) == 1
assert json_count_leaves("hello") == 1
assert json_count_leaves(None) == 1
# Empty containers
assert json_count_leaves([]) == 0
assert json_count_leaves({}) == 0
assert json_count_leaves(()) == 0
# Flat structures
assert json_count_leaves([1, 2, 3]) == 3
assert json_count_leaves({"a": 1, "b": 2}) == 2
assert json_count_leaves((1, 2, 3)) == 3
# Nested structures
nested_dict = {"a": 1, "b": {"c": 2, "d": 3}}
assert json_count_leaves(nested_dict) == 3
nested_list = [1, [2, 3], 4]
assert json_count_leaves(nested_list) == 4
mixed_nested = {"list": [1, 2], "dict": {"x": 3}, "value": 4}
assert json_count_leaves(mixed_nested) == 4
def test_convert_ids_list_to_tokens():
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
token_ids = tokenizer.encode("Hello, world!")
......@@ -991,3 +1025,40 @@ def test_current_stream_multithread():
child_thread.join(timeout=5)
if child_thread.is_alive():
pytest.fail("Child thread failed to exit properly")
def test_load_config_file(tmp_path):
# Define the configuration data
config_data = {
"enable-logging": True,
"list-arg": ["item1", "item2"],
"port": 12323,
"tensor-parallel-size": 4
}
# Write the configuration data to a temporary YAML file
config_file_path = tmp_path / "config.yaml"
with open(config_file_path, "w") as config_file:
yaml.dump(config_data, config_file)
# Initialize the parser
parser = FlexibleArgumentParser()
# Call the function with the temporary file path
processed_args = parser.load_config_file(str(config_file_path))
# Expected output
expected_args = [
"--enable-logging",
"--list-arg",
"item1",
"item2",
"--port",
"12323",
"--tensor-parallel-size",
"4",
]
# Assert that the processed arguments match the expected output
assert processed_args == expected_args
os.remove(str(config_file_path))
......@@ -10,14 +10,15 @@ from tests.v1.attention.utils import (BatchSpec, _Backend,
create_standard_kv_cache_spec,
create_vllm_config,
get_attention_backend)
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
set_kv_cache_layout)
from vllm.v1.kv_cache_interface import FullAttentionSpec
BACKENDS_TO_TEST = [
_Backend.FLASH_ATTN_VLLM_V1, _Backend.FLASHINFER_VLLM_V1,
_Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1, _Backend.TREE_ATTN
_Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1, _Backend.TREE_ATTN,
"FLEX_ATTENTION_SLOW"
]
# Remove flashinfer from the list if it's not available
......@@ -97,7 +98,7 @@ def create_and_prepopulate_kv_cache(
common_attn_metadata: CommonAttentionMetadata,
randomize_blocks: bool = True) -> torch.Tensor:
"""Create and prepopulate a KV cache with context data.
Args:
k_contexts: List of key context tensors for each sequence
v_contexts: List of value context tensors for each sequence
......@@ -109,9 +110,9 @@ def create_and_prepopulate_kv_cache(
device: Device to create the cache on
num_blocks: Total number of blocks in the cache
block_table: Block table tensor to populate
randomize_blocks: Whether to randomly permute blocks
randomize_blocks: Whether to randomly permute blocks
or use sequential order
Returns:
Tuple of (kv_cache, updated_block_table)
"""
......@@ -150,15 +151,15 @@ def create_and_prepopulate_kv_cache(
# Permute the context blocks (excluding block 0 which is null)
if randomize_blocks:
perm = torch.randperm(
blocks_end - 1) + 1 # Random permutation starting from block 1
# Random permutation starting from block 1
perm = torch.randperm(blocks_end - 1) + 1
else:
perm = torch.arange(
1, blocks_end) # Sequential order starting from block 1
# Sequential order starting from block 1
perm = torch.arange(1, blocks_end)
inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device)
inv_perm[1:] = torch.argsort(
perm) + 1 # Add 1 to account for starting from block 1
# Add 1 to account for starting from block 1
inv_perm[1:] = torch.argsort(perm) + 1
kv_cache[:, 1:blocks_end, ...] = kv_cache[:, perm, ...]
# Construct the right block table
......@@ -206,10 +207,18 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
kv_cache: torch.Tensor) -> torch.Tensor:
"""Run attention computation using the specified backend's AttentionImpl."""
builder_cls, impl_cls = get_attention_backend(backend)
# Handle special case for FLEX_ATTENTION_SLOW
actual_backend = backend
use_direct_block_mask = is_torch_equal_or_newer("2.9.0.dev0")
if backend == "FLEX_ATTENTION_SLOW":
actual_backend = _Backend.FLEX_ATTENTION
use_direct_block_mask = False
builder_cls, impl_cls = get_attention_backend(actual_backend)
# Mock flashinfer's get_per_layer_parameters if needed
if backend == _Backend.FLASHINFER_VLLM_V1:
if actual_backend == _Backend.FLASHINFER_VLLM_V1:
import unittest.mock
from vllm.v1.attention.backends.utils import PerLayerParameters
......@@ -239,6 +248,8 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
else:
# Build metadata
builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
if actual_backend == _Backend.FLEX_ATTENTION:
builder.direct_build = use_direct_block_mask
attn_metadata = builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
......@@ -281,7 +292,8 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
@pytest.mark.parametrize("batch_spec_name", [
"small_decode", "small_prefill", "mixed_small", "medium_decode",
"medium_prefill", "mixed_medium"
"medium_prefill", "mixed_medium", "large_decode", "large_prefill",
"single_decode", "single_prefill"
])
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
def test_backend_correctness(batch_spec_name: str, model: str):
......@@ -302,7 +314,8 @@ def test_backend_correctness(batch_spec_name: str, model: str):
"""
batch_spec = BATCH_SPECS[batch_spec_name]
vllm_config = create_vllm_config(model_name=model,
max_model_len=max(batch_spec.seq_lens))
max_model_len=max(batch_spec.seq_lens),
num_gpu_blocks=8192)
device = torch.device("cuda:0")
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
......@@ -451,11 +464,6 @@ def test_backend_correctness(batch_spec_name: str, model: str):
rtol = 1e-2
atol = 5e-3
if backend_name == _Backend.FLEX_ATTENTION:
atol = 5e-1 # TODO: figure out why flex_attention has such large
# numerical differences for medium_decode, medium_prefill,
# mixed_medium
max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item()
max_rel_diff = torch.max(
torch.abs(backend_output - sdpa_output) /
......@@ -465,12 +473,6 @@ def test_backend_correctness(batch_spec_name: str, model: str):
rtol=rtol,
atol=atol)
if not all_close:
print(f"[{backend_name}] output differs from SDPA baseline. "
f"Max diff: {max_diff:.6f} (rel: {max_rel_diff:.6f})")
print(f"[{backend_name}] output: {backend_output}")
print(f"[{backend_name}] SDPA baseline: {sdpa_output}")
assert all_close, (
f"[{backend_name}] output differs from SDPA baseline. "
f"Max diff: {max_diff:.6f} (rel: {max_rel_diff:.6f})")
f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})")
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for mamba attention backend selectors."""
from types import SimpleNamespace
import pytest
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.short_conv import ShortConv
from vllm.model_executor.models.minimax_text_01 import (
MiniMaxText01LinearAttention)
from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.short_conv_attn import (
ShortConvAttentionBackend)
@pytest.mark.parametrize(
"layer_class, init_kwargs, expected_backend, expected_mamba_type", [
(
MambaMixer,
dict(
hidden_size=128,
ssm_state_size=16,
conv_kernel_size=4,
intermediate_size=256,
time_step_rank=8,
use_conv_bias=True,
use_bias=False,
use_rms_norm=True,
),
Mamba1AttentionBackend,
"mamba1",
),
(
MambaMixer2,
dict(
hidden_size=128,
ssm_state_size=16,
conv_kernel_size=4,
intermediate_size=256,
use_conv_bias=True,
use_bias=False,
n_groups=1,
num_heads=8,
head_dim=32,
),
Mamba2AttentionBackend,
"mamba2",
),
(
MiniMaxText01LinearAttention,
dict(
hidden_size=128,
hidden_inner_size=256,
num_heads=8,
head_dim=32,
max_position=2048,
block_size=64,
num_hidden_layer=12,
layer_idx=0,
linear_layer_idx=0,
),
LinearAttentionBackend,
"linear_attention",
),
(
ShortConv,
dict(
config=SimpleNamespace(conv_L_cache=32, conv_bias=True),
dim=128,
layer_idx=0,
),
ShortConvAttentionBackend,
"short_conv",
),
])
def test_mamba_layers_get_attn_backend(dist_init, layer_class, init_kwargs,
expected_backend, expected_mamba_type):
"""Test that Mamba-like layers return the correct attention backend."""
layer = layer_class(**init_kwargs)
backend_class = layer.get_attn_backend()
assert backend_class is expected_backend
assert layer.mamba_type == expected_mamba_type
@pytest.mark.parametrize("layer_class,expected_backend,expected_mamba_type", [
(MambaMixer, Mamba1AttentionBackend, "mamba1"),
(MambaMixer2, Mamba2AttentionBackend, "mamba2"),
(MiniMaxText01LinearAttention, LinearAttentionBackend, "linear_attention"),
(ShortConv, ShortConvAttentionBackend, "short_conv"),
])
def test_mamba_layers_have_unified_interface(layer_class, expected_backend,
expected_mamba_type):
"""Test that all Mamba layers have the unified get_attn_backend
interface."""
assert hasattr(layer_class, 'get_attn_backend'), (
f"{layer_class.__name__} should have get_attn_backend method")
assert hasattr(layer_class, 'mamba_type'), (
f"{layer_class.__name__} should have mamba_type property")
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for mamba attention backend selectors."""
import pytest
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
@pytest.mark.parametrize(argnames=["mamba_type", "expected_backend"],
argvalues=[("mamba2", Mamba2AttentionBackend)])
def test_get_mamba_attn_backend_mamba2(mamba_type, expected_backend):
backend_class = get_mamba_attn_backend(mamba_type)
assert backend_class is expected_backend
def test_get_mamba_attn_backend_unsupported():
unsupported_types = ["mamba", ""]
for mamba_type in unsupported_types:
err_message = f"Mamba Attention type {mamba_type} is not supported yet."
with pytest.raises(NotImplementedError, match=err_message):
get_mamba_attn_backend(mamba_type)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for v1 MLA backends without GPUModelRunner dependency."""
import pytest
import torch
from tests.v1.attention.utils import (BatchSpec, _Backend,
create_common_attn_metadata,
create_standard_kv_cache_spec,
create_vllm_config,
get_attention_backend)
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec
BACKENDS_TO_TEST = [
_Backend.CUTLASS_MLA, _Backend.FLASHMLA_VLLM_V1,
_Backend.TRITON_MLA_VLLM_V1
]
# Remove CUTLASS_MLA from the list if not using sm100
if not torch.cuda.is_available() or torch.cuda.get_device_properties(
0).major < 10:
BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA)
torch.manual_seed(42)
def _convert_dtype_to_torch(dtype):
"""Convert ModelDType to torch.dtype."""
if isinstance(dtype, str):
if dtype == "auto":
return torch.float16 # Default dtype for testing
elif dtype in STR_DTYPE_TO_TORCH_DTYPE:
return STR_DTYPE_TO_TORCH_DTYPE[dtype]
else:
raise ValueError(f"Unknown dtype: {dtype}")
elif isinstance(dtype, torch.dtype):
return dtype
else:
raise ValueError(f"Unknown dtype: {dtype}")
# Define common batch configurations
BATCH_SPECS = {
"small_decode":
BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]),
"small_prefill":
BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]),
"mixed_small":
BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]),
"medium_decode":
BatchSpec(seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024],
query_lens=[1, 1, 1, 1, 1, 1, 1, 1]),
"medium_prefill":
BatchSpec(seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16]),
"mixed_medium":
BatchSpec(seq_lens=[512, 1024, 2048, 512, 1024, 2048],
query_lens=[1, 1, 1, 7, 7, 7]),
"large_decode":
BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32),
"large_prefill":
BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8),
"single_decode":
BatchSpec(seq_lens=[1024], query_lens=[1]),
"single_prefill":
BatchSpec(seq_lens=[1024], query_lens=[64]),
}
def create_dummy_kv_cache(kv_cache_spec: FullAttentionSpec,
device: torch.device,
num_blocks: int = 100) -> torch.Tensor:
"""Create a dummy KV cache tensor for testing."""
kv_cache = torch.randn(
num_blocks,
kv_cache_spec.block_size,
kv_cache_spec.head_size, # latent dimension
dtype=_convert_dtype_to_torch(kv_cache_spec.dtype),
device=device,
)
return kv_cache
def create_and_prepopulate_kv_cache(
kv_c_contexts: list[torch.Tensor],
k_pe_contexts: list[torch.Tensor],
block_size: int,
num_kv_heads: int,
head_size: int,
dtype: torch.dtype,
device: torch.device,
num_blocks: int,
common_attn_metadata: CommonAttentionMetadata,
randomize_blocks: bool = True) -> torch.Tensor:
"""Create and prepopulate an MLA KV cache with context data.
Args:
kv_c_contexts: List of latent KV context tensors for each sequence
k_pe_contexts: List of key positional embedding context tensors
for each sequence
block_size: Size of each block
num_kv_heads: Number of KV heads (should be 1 for MLA)
head_size: Size of each head (latent dimension)
dtype: Data type for the cache
device: Device to create the cache on
num_blocks: Total number of blocks in the cache
common_attn_metadata: Common attention metadata
randomize_blocks: Whether to randomly permute blocks
or use sequential order
Returns:
MLA KV cache tensor
"""
batch_size = len(kv_c_contexts)
seq_lens = common_attn_metadata.seq_lens_cpu
query_lens = common_attn_metadata.query_start_loc_cpu[
1:] - common_attn_metadata.query_start_loc_cpu[:-1]
context_lens = common_attn_metadata.num_computed_tokens_cpu
block_table = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
# Create MLA KV cache: (num_blocks, block_size, head_size)
kv_cache = torch.empty(num_blocks,
block_size,
head_size,
dtype=dtype,
device=device)
kv_cache_flat = kv_cache.view(-1, head_size)
# Populate the cache with the context tokens
# Start from block_id=1 since block_id=0 is considered the null block
start_block_idx = 1
for i in range(batch_size):
kv_c_context, k_pe_context = kv_c_contexts[i], k_pe_contexts[i]
kv_context = torch.cat([kv_c_context, k_pe_context.squeeze(1)], dim=-1)
start = start_block_idx * block_size
end = start + kv_context.shape[0]
kv_cache_flat[start:end, ...] = kv_context
# Stay block aligned and allocate enough blocks for the new tokens
start_block_idx += cdiv(int(seq_lens[i]), block_size)
blocks_end = start_block_idx
# Permute the context blocks (excluding block 0 which is null)
if randomize_blocks:
perm = torch.randperm(
blocks_end - 1) + 1 # Random permutation starting from block 1
else:
perm = torch.arange(
1, blocks_end) # Sequential order starting from block 1
inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device)
inv_perm[1:] = torch.argsort(
perm) + 1 # Add 1 to account for starting from block 1
kv_cache[1:blocks_end, ...] = kv_cache[perm, ...]
# Construct the right block table
# Start from block_id=1 since block_id=0 is considered the null block
start_block_idx = 1
for i in range(batch_size):
num_blocks_for_seq = cdiv(int(seq_lens[i]), block_size)
start = start_block_idx
end = start + num_blocks_for_seq
block_table[i, :num_blocks_for_seq] = inv_perm[start:end]
start_block_idx += num_blocks_for_seq
# Create a realistic slot mapping that corresponds to the block table
for i in range(batch_size):
token_offsets = torch.arange(int(query_lens[i])) + int(context_lens[i])
block_indices = token_offsets // block_size
token_inter_block_offsets = token_offsets % block_size
start = common_attn_metadata.query_start_loc_cpu[i]
end = common_attn_metadata.query_start_loc_cpu[i + 1]
slot_mapping[start:end] = block_table[
i,
block_indices] * block_size + token_inter_block_offsets.to(device)
return kv_cache
class MockAttentionLayer:
"""A mock attention layer for testing."""
def __init__(self, device: torch.device):
self._q_scale = torch.tensor(1.0, device=device)
self._k_scale = torch.tensor(1.0, device=device)
self._v_scale = torch.tensor(1.0, device=device)
def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
layer_names: list[str], vllm_config,
device: torch.device,
common_attn_metadata: CommonAttentionMetadata,
query: torch.Tensor, kv_c: torch.Tensor,
k_pe: torch.Tensor, kv_cache: torch.Tensor,
kv_lora_rank: int, qk_nope_head_dim: int,
qk_rope_head_dim: int, v_head_dim: int,
mock_kv_b_proj) -> torch.Tensor:
"""Run attention computation using the specified backend's AttentionImpl."""
builder_cls, impl_cls = get_attention_backend(backend)
# Build metadata
builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
attn_metadata = builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
# Instantiate MLA implementation
num_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config)
num_kv_heads = vllm_config.model_config.get_num_kv_heads(
vllm_config.parallel_config)
head_size = vllm_config.model_config.get_head_size()
scale = 1.0 / (head_size**0.5)
impl = impl_cls(
num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=num_kv_heads,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype="auto",
logits_soft_cap=None,
attn_type="decoder",
kv_sharing_target_layer_name=None,
q_lora_rank=None,
kv_lora_rank=kv_lora_rank,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
v_head_dim=v_head_dim,
kv_b_proj=mock_kv_b_proj,
)
# Process weights to create W_UK_T and W_UV attributes needed by MLA
act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
impl.process_weights_after_loading(act_dtype)
# Create mock layer and output buffer
mock_layer = MockAttentionLayer(device)
num_tokens = query.shape[0]
output = torch.empty(num_tokens,
num_heads * v_head_dim,
dtype=query.dtype,
device=query.device)
# Run forward pass
# NOTE: The query, key, and value are already shaped correctly
# in the calling test function.
output = impl.forward(mock_layer,
query,
kv_c,
k_pe,
kv_cache,
attn_metadata,
output=output)
return output
@pytest.mark.parametrize("batch_spec_name", [
"small_decode", "small_prefill", "mixed_small", "medium_decode",
"medium_prefill", "mixed_medium", "large_decode", "large_prefill",
"single_decode", "single_prefill"
])
@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-V2-Lite-Chat"])
def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
"""
Test that all backends produce similar outputs to a reference implementation
using torch.nn.functional.scaled_dot_product_attention.
This test works by:
1. Generating a batch of sequences with specified context and query lengths.
2. Computing a ground-truth attention output using torch.sdpa on
contiguous Q, K, and V tensors.
3. Simulating vLLM's paged KV cache: It takes the context portion of the
K/V tensors and manually places them into a paged buffer according to
the test's (randomly generated) block table.
4. Running each vLLM attention backend with the new queries and the
simulated paged KV cache.
5. Comparing the vLLM backend's output to the ground-truth SDPA output.
"""
batch_spec = BATCH_SPECS[batch_spec_name]
vllm_config = create_vllm_config(model_name=model,
max_model_len=max(batch_spec.seq_lens),
num_gpu_blocks=2048)
device = torch.device("cuda:0")
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
# 1. Setup
batch_size = batch_spec.batch_size
seq_lens = batch_spec.seq_lens
query_lens = batch_spec.query_lens
num_q_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config)
num_kv_heads = vllm_config.model_config.get_num_kv_heads(
vllm_config.parallel_config)
head_size = vllm_config.model_config.get_head_size()
dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
block_size = vllm_config.cache_config.block_size
kv_lora_rank = 512
qk_rope_head_dim = 64
qk_nope_head_dim = 128
v_head_dim = 128
total_head_size = kv_lora_rank + qk_rope_head_dim
assert kv_lora_rank + qk_rope_head_dim == head_size, \
f"MLA dimensions don't match: {total_head_size} != {head_size}"
scale = 1.0 / (total_head_size**0.5)
# 2. Generate data and compute SDPA reference output for MLA
all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], []
all_sdpa_outputs = []
kv_c_contexts, k_pe_contexts = [], []
# Create shared MLA weight matrices for consistency across all sequences
W_UK = torch.randn(kv_lora_rank,
num_q_heads,
qk_nope_head_dim,
dtype=dtype,
device=device)
W_UV = torch.randn(kv_lora_rank,
num_q_heads,
v_head_dim,
dtype=dtype,
device=device)
kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
for i in range(batch_size):
s_len = seq_lens[i]
q_len = query_lens[i]
context_len = s_len - q_len
# Generate MLA tensors
# Q has both nope and rope components:
# [q_len, num_heads, qk_nope_head_dim + qk_rope_head_dim]
q_c = torch.randn(q_len,
num_q_heads,
qk_nope_head_dim + qk_rope_head_dim,
dtype=dtype,
device=device)
# KV_C (latent K/V): [s_len, kv_lora_rank]
kv_c_full = torch.randn(s_len,
kv_lora_rank,
dtype=dtype,
device=device)
# K_PE (rope component): [s_len, 1, qk_rope_head_dim]
k_pe_full = torch.randn(s_len,
1,
qk_rope_head_dim,
dtype=dtype,
device=device)
# Determine if this is decode (single token)
# or prefill (multiple tokens)
is_decode = q_len == 1
# Split q into nope and rope components
q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
if is_decode:
# Decode path: MQA-style attention in latent space
# Transform q_nope to latent space: q_nope @ W_UK
# q_nope: [1, num_heads, qk_nope_head_dim]
# W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim]
ql_nope = torch.einsum("qnh,lnh->qnl", q_nope,
W_UK) # [1, num_heads, kv_lora_rank]
# Build MQA attention inputs
# Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim]
q_mqa = torch.cat([ql_nope, q_pe], dim=-1)
# K: [s_len, kv_lora_rank + qk_rope_head_dim]
# (broadcasted to all heads)
k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1)
k_mqa = k_mqa.unsqueeze(1).expand(-1, num_q_heads, -1)
# V: [s_len, kv_lora_rank] (broadcasted to all heads)
v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_q_heads, -1)
# SDPA expects (N, H, L, D)
q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2)
k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2)
v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)
sdpa_out_i = torch.nn.functional.scaled_dot_product_attention(
q_sdpa_in, k_sdpa_in, v_sdpa_in, is_causal=False, scale=scale)
sdpa_out_i = sdpa_out_i.transpose(1, 2).squeeze(
0) # [1, num_heads, kv_lora_rank]
# Project back to output space: sdpa_out @ W_UV
sdpa_out_i = torch.einsum("qnl,lnv->qnv", sdpa_out_i, W_UV)
sdpa_out_i = sdpa_out_i.flatten(start_dim=-2)
else:
# Prefill path: MHA-style attention with full sequence
# Apply kv_b_proj to the full kv_c tensor
kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full,
kv_b_proj_weight)
k_nope_full, v_full = kv_nope_full.split(
[qk_nope_head_dim, v_head_dim], dim=-1)
# Build attention inputs for full sequence
q_mha = torch.cat([q_nope, q_pe],
dim=-1) # [q_len, num_heads, total_dim]
k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1)
k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1)
# Create custom attention mask:
# - Query tokens can attend to all context tokens
# - Query tokens can only attend to query tokens up to their pos
attn_mask = torch.ones(q_len,
s_len,
dtype=torch.bool,
device=device)
# Apply causal mask only to the query portion (context_len onwards)
causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
attn_mask[:, context_len:] = causal_mask
# SDPA expects (N, H, L, D)
q_sdpa_in = q_mha.unsqueeze(0).transpose(1, 2)
k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2)
v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2)
# Single attention call with custom mask
sdpa_out_i = torch.nn.functional.scaled_dot_product_attention(
q_sdpa_in,
k_sdpa_in,
v_sdpa_in,
attn_mask=attn_mask,
scale=scale)
sdpa_out_i = sdpa_out_i.transpose(1, 2).squeeze(0)
sdpa_out_i = sdpa_out_i.flatten(start_dim=-2)
all_sdpa_outputs.append(sdpa_out_i)
# Inputs for vLLM MLA backends are just the new tokens
all_q_vllm.append(q_c)
all_kv_c_vllm.append(kv_c_full[context_len:]) # New kv_c tokens
all_k_pe_vllm.append(k_pe_full[context_len:]) # New k_pe tokens
# Contextual K/V data used to populate the paged cache (MLA format)
kv_c_contexts.append(kv_c_full[:context_len])
k_pe_contexts.append(k_pe_full[:context_len])
# Concatenate all sequences (no reordering needed)
query_vllm = torch.cat(all_q_vllm, dim=0)
kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0)
sdpa_output = torch.cat(all_sdpa_outputs, dim=0)
# Create mock kv_b_proj using the same weights as reference implementation
from vllm.model_executor.layers.linear import ColumnParallelLinear
mock_kv_b_proj = ColumnParallelLinear(input_size=kv_lora_rank,
output_size=num_q_heads *
(qk_nope_head_dim + v_head_dim),
bias=False).to(device=device,
dtype=dtype)
# Set the mock weights to match our reference implementation
# Reshape W_UK and W_UV to match the expected kv_b_proj format
# [kv_lora_rank, num_heads, qk_nope_head_dim + v_head_dim]
kv_b_proj_weight = kv_b_proj_weight.view(
kv_lora_rank, num_q_heads * (qk_nope_head_dim + v_head_dim))
mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T)
# Create metadata using original batch spec
common_attn_metadata = create_common_attn_metadata(
batch_spec, vllm_config.cache_config.block_size, device)
# 3. Simulate Paged KV Cache and a realistic slot_mapping
kv_cache = create_and_prepopulate_kv_cache(
kv_c_contexts=kv_c_contexts,
k_pe_contexts=k_pe_contexts,
block_size=block_size,
num_kv_heads=num_kv_heads,
head_size=head_size,
dtype=dtype,
device=device,
num_blocks=vllm_config.cache_config.num_gpu_blocks,
common_attn_metadata=common_attn_metadata,
randomize_blocks=True)
# 4. Run vLLM backends and compare
for backend_name in BACKENDS_TO_TEST:
backend_output = run_attention_backend(
backend_name, kv_cache_spec, ["placeholder"], vllm_config, device,
common_attn_metadata, query_vllm, kv_c_vllm, k_pe_vllm, kv_cache,
kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim,
mock_kv_b_proj)
# Check shape and dtype consistency
assert backend_output.shape == sdpa_output.shape, (
f"[{backend_name}] shape {backend_output.shape} != "
f"SDPA shape {sdpa_output.shape}")
assert backend_output.dtype == sdpa_output.dtype, (
f"[{backend_name}] dtype {backend_output.dtype} != "
f"SDPA dtype {sdpa_output.dtype}")
assert torch.isfinite(backend_output).all(), (
f"[{backend_name}] produced non-finite values")
# Check numerical similarity
rtol = 1e-2
atol = 5e-1
max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item()
max_rel_diff = torch.max(
torch.abs(backend_output - sdpa_output) /
torch.abs(sdpa_output)).item()
all_close = torch.allclose(backend_output,
sdpa_output,
rtol=rtol,
atol=atol)
assert all_close, (
f"[{backend_name}] output differs from SDPA baseline. "
f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})")
......@@ -58,6 +58,7 @@ def create_common_attn_metadata(
dtype=torch.int32,
device=device)
seq_lens_cpu = seq_lens.cpu()
max_seq_len = int(seq_lens_cpu.max())
# Create computed tokens (context length for each sequence)
context_lens = [
......@@ -101,6 +102,7 @@ def create_common_attn_metadata(
num_reqs=batch_spec.batch_size,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
block_table_tensor=block_table_tensor,
slot_mapping=slot_mapping,
causal=True,
......@@ -133,6 +135,12 @@ def get_attention_backend(backend_name: _Backend):
"vllm.v1.attention.backends.tree_attn.TreeAttentionBackend",
_Backend.XFORMERS_VLLM_V1:
"vllm.v1.attention.backends.xformers.XFormersAttentionBackend",
_Backend.CUTLASS_MLA:
"vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend",
_Backend.FLASHMLA_VLLM_V1:
"vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",
_Backend.TRITON_MLA_VLLM_V1:
"vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",
}
if backend_name not in backend_map:
......@@ -165,9 +173,11 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B",
tensor_parallel_size: int = 1,
max_model_len: int = 1024,
dtype: Union[ModelDType, torch.dtype] = "auto",
num_gpu_blocks: int = 1000,
block_size: int = 16,
max_num_seqs: int = 256,
max_num_batched_tokens: int = 8192,
enable_chunked_prefill: bool = True,
add_mock_model_methods: bool = True) -> VllmConfig:
"""Create a VllmConfig for testing with reasonable defaults."""
......@@ -187,7 +197,7 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B",
)
# Set cache blocks for testing
# (these may be set during initialization normally)
cache_config.num_gpu_blocks = 1000
cache_config.num_gpu_blocks = num_gpu_blocks
cache_config.num_cpu_blocks = 0
parallel_config = ParallelConfig(
......@@ -196,6 +206,7 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B",
scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=enable_chunked_prefill,
)
device_config = DeviceConfig()
......
......@@ -22,7 +22,6 @@ def _make_model_runner_output(
for i, req_id in enumerate(req_ids)
},
sampled_token_ids=[[i] for i in range(len(req_ids))],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
# ------------------ Mock Classes ------------------ #
class MockRequest:
def __init__(self, request_id, mm_hashes, token_counts):
self.request_id = request_id
self.mm_hashes = mm_hashes
self._token_counts = token_counts
def get_num_encoder_tokens(self, input_id: int) -> int:
return self._token_counts[input_id]
# ------------------ Unit Tests ------------------ #
def test_basic_allocate_and_reuse():
cache = EncoderCacheManager(cache_size=10)
req = MockRequest("r1", ["imgA"], [4])
assert not cache.check_and_update_cache(req, 0)
assert cache.can_allocate(req, 0, int(1e9), 0)
cache.allocate(req, 0)
assert cache.check_and_update_cache(req, 0)
assert "r1" in cache.cached["imgA"]
assert cache.num_free_slots == 6
# Free twice to bring refcount to 0.
cache.free_encoder_input(req, 0)
cache.free_encoder_input(req, 0)
assert not cache.cached["imgA"]
assert "imgA" in cache.freeable
assert cache.num_freeable_slots == 10
assert cache.num_free_slots == 6
def test_freeing_decreases_refcount_and_moves_to_freeable():
manager = EncoderCacheManager(cache_size=10)
req = MockRequest("req2", ["img3"], [5])
assert manager.can_allocate(req, 0, int(1e9), 0)
manager.allocate(req, 0)
assert len(manager.cached["img3"]) == 1
manager.free_encoder_input(req, 0)
assert not manager.cached["img3"]
assert "img3" in manager.freeable
assert manager.num_freeable_slots == 10
def test_free_request_frees_all_inputs():
manager = EncoderCacheManager(cache_size=10)
req = MockRequest("req3", ["a", "b"], [2, 3])
assert manager.can_allocate(req, 0, int(1e9), 0)
manager.allocate(req, 0)
assert manager.can_allocate(req, 1, int(1e9), 0)
manager.allocate(req, 1)
assert len(manager.cached["a"]) == 1
assert len(manager.cached["b"]) == 1
manager.free(req)
assert not manager.cached["a"]
assert not manager.cached["b"]
assert "a" in manager.freeable
assert "b" in manager.freeable
assert manager.num_freeable_slots == 10
def test_eviction_when_cache_is_full():
manager = EncoderCacheManager(cache_size=10)
req1 = MockRequest("req1", ["x"], [6])
req2 = MockRequest("req2", ["y"], [5])
assert manager.can_allocate(req1, 0, int(1e9), 0)
manager.allocate(req1, 0)
manager.free_encoder_input(req1, 0)
assert manager.can_allocate(req2, 0, int(1e9), 0)
manager.allocate(req2, 0)
# 'x' should have been evicted.
assert "x" not in manager.cached
assert "x" in manager.get_freed_mm_hashes()
def test_get_cached_input_ids():
manager = EncoderCacheManager(cache_size=10)
req = MockRequest("reqX", ["m", "n", "o"], [2, 4, 3])
assert manager.can_allocate(req, 0, int(1e9), 0)
manager.allocate(req, 0)
assert manager.can_allocate(req, 2, int(1e9), 0)
manager.allocate(req, 2)
cached_ids = manager.get_cached_input_ids(req)
assert cached_ids == {0, 2}
def test_has_cache_restores_from_freeable():
manager = EncoderCacheManager(cache_size=10)
req = MockRequest("reqY", ["imgZ"], [4])
assert manager.can_allocate(req, 0, int(1e9), 0)
manager.allocate(req, 0)
manager.free_encoder_input(req, 0)
# Should restore from freeable.
assert manager.check_and_update_cache(req, 0)
assert len(manager.cached["imgZ"]) == 1
assert "imgZ" not in manager.freeable
assert manager.num_freeable_slots == 6
def test_get_freed_mm_hashes_clears_freed_list():
manager = EncoderCacheManager(cache_size=10)
req1 = MockRequest("reqA", ["a"], [5])
req2 = MockRequest("reqB", ["b"], [6])
assert manager.can_allocate(req1, 0, int(1e9), 0)
manager.allocate(req1, 0)
manager.free_encoder_input(req1, 0)
# Should trigger eviction of 'a'.
assert manager.can_allocate(req2, 0, int(1e9), 0)
manager.allocate(req2, 0)
freed = manager.get_freed_mm_hashes()
assert "a" in freed
assert manager.get_freed_mm_hashes() == []
def test_schedule_request_multi_images_respect_space_limit():
manager = EncoderCacheManager(cache_size=10)
req = MockRequest("reqA", ["a", "b"], [5, 6])
compute_budget = 100
num_tokens_to_schedule = 0
assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule)
num_tokens_to_schedule += req.get_num_encoder_tokens(0)
compute_budget -= req.get_num_encoder_tokens(0)
assert not manager.can_allocate(req, 1, compute_budget,
num_tokens_to_schedule)
def test_schedule_request_multi_images_respect_compute_limit():
manager = EncoderCacheManager(cache_size=100)
req = MockRequest("reqA", ["a", "b"], [5, 6])
compute_budget = 10
num_tokens_to_schedule = 0
assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule)
num_tokens_to_schedule += req.get_num_encoder_tokens(0)
compute_budget -= req.get_num_encoder_tokens(0)
assert not manager.can_allocate(req, 1, compute_budget,
num_tokens_to_schedule)
......@@ -7,7 +7,8 @@ import pytest
import torch
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import SamplingParams
from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit
from vllm.v1.core.kv_cache_manager import KVCacheManager
......@@ -37,17 +38,20 @@ def make_request(
mm_hashes: Optional[list[str]] = None,
cache_salt: Optional[str] = None,
):
if mm_positions is None:
mm_kwargs = None
else:
mm_item = MultiModalKwargsItem.dummy("dummy_m")
mm_kwargs = [mm_item] * len(mm_positions)
mm_features = []
if mm_positions is not None:
for j, position in enumerate(mm_positions):
identifier = mm_hashes[j] if mm_hashes else f"hash_{j}"
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"),
mm_position=position,
identifier=identifier,
modality="image")
mm_features.append(mm_feature)
return Request(request_id=request_id,
prompt_token_ids=prompt_token_ids,
multi_modal_kwargs=mm_kwargs,
multi_modal_hashes=mm_hashes,
multi_modal_placeholders=mm_positions,
mm_features=mm_features if mm_features else None,
sampling_params=SamplingParams(max_tokens=17),
pooling_params=None,
eos_token_id=100,
......@@ -597,8 +601,14 @@ def test_unify_kv_cache_configs():
]
unify_kv_cache_configs(need_sort_kv_cache_config)
assert need_sort_kv_cache_config[0].num_blocks == 10
assert need_sort_kv_cache_config[1].num_blocks == 10
sorted_kv_cache_groups = [
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"], new_kv_cache_spec(num_kv_heads=4)),
]
assert (
need_sort_kv_cache_config[0].kv_cache_groups == sorted_kv_cache_groups)
assert (
need_sort_kv_cache_config[1].kv_cache_groups == sorted_kv_cache_groups)
diff_kv_cache_config = [
KVCacheConfig(
......
......@@ -9,7 +9,8 @@ import pytest
import torch
from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import SamplingParams
from vllm.utils import sha256, sha256_cbor_64bit
from vllm.v1.core.block_pool import BlockPool
......@@ -32,17 +33,20 @@ def make_request(
prompt_logprobs: Optional[int] = None,
cache_salt: Optional[str] = None,
):
if mm_positions is None:
mm_kwargs = None
else:
mm_item = MultiModalKwargsItem.dummy("dummy_m")
mm_kwargs = [mm_item] * len(mm_positions)
mm_features = []
if mm_positions is not None:
for j, position in enumerate(mm_positions):
identifier = mm_hashes[j] if mm_hashes else f"hash_{j}"
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"),
mm_position=position,
identifier=identifier,
modality="image")
mm_features.append(mm_feature)
return Request(request_id=request_id,
prompt_token_ids=prompt_token_ids,
multi_modal_kwargs=mm_kwargs,
multi_modal_hashes=mm_hashes,
multi_modal_placeholders=mm_positions,
mm_features=mm_features if mm_features else None,
sampling_params=SamplingParams(
max_tokens=17, prompt_logprobs=prompt_logprobs),
pooling_params=None,
......
......@@ -8,13 +8,14 @@ import torch
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.structured_output.request import StructuredOutputRequest
......@@ -158,7 +159,6 @@ def test_schedule_partial_requests():
# Only the first request has a sampled token id because
# the rest requests are still being prefilled.
sampled_token_ids=[[0], [], []],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
......@@ -209,7 +209,6 @@ def test_no_mm_input_chunking():
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
sampled_token_ids=[[] for _ in range(len(requests))],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
......@@ -273,7 +272,6 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
sampled_token_ids=[[] for _ in range(len(requests))],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
......@@ -298,7 +296,6 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
......@@ -342,7 +339,7 @@ def test_stop_via_update_from_output():
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None)
......@@ -355,7 +352,6 @@ def test_stop_via_update_from_output():
sampled_token_ids=[[EOS_TOKEN_ID],
[10,
11]], # First request hits EOS, second continues
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
......@@ -396,7 +392,7 @@ def test_stop_via_update_from_output():
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
......@@ -409,7 +405,6 @@ def test_stop_via_update_from_output():
},
sampled_token_ids=[[10, 42, 12],
[13, 14]], # First request hits stop token
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
......@@ -449,7 +444,7 @@ def test_stop_via_update_from_output():
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
......@@ -462,7 +457,6 @@ def test_stop_via_update_from_output():
},
sampled_token_ids=[[10, 11, 12],
[13]], # First request exceeds max_tokens
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
......@@ -497,7 +491,7 @@ def test_stop_via_update_from_output():
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None)
......@@ -505,7 +499,6 @@ def test_stop_via_update_from_output():
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
......@@ -554,7 +547,6 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[0]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
......@@ -572,7 +564,6 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
req_ids=[requests[1].request_id],
req_id_to_index={requests[1].request_id: 0},
sampled_token_ids=[[0]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
......@@ -608,7 +599,6 @@ def test_preempt_during_execution():
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[0]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
......@@ -626,7 +616,6 @@ def test_preempt_during_execution():
req_ids=[requests[1].request_id],
req_id_to_index={requests[1].request_id: 0},
sampled_token_ids=[[42]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
......@@ -682,13 +671,14 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[0] for _ in range(len(requests))],
spec_token_ids=spec_tokens,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
engine_core_outputs = scheduler.update_from_output(output,
model_runner_output)
draft_token_ids = DraftTokenIds(req_ids, spec_tokens)
scheduler.update_draft_token_ids(draft_token_ids)
for i in range(len(requests)):
running_req = scheduler.running[i]
......@@ -722,7 +712,6 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=output_tokens,
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
......@@ -851,7 +840,6 @@ def test_kv_connector_basic():
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[1000]] * len(req_ids),
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
......@@ -898,7 +886,6 @@ def test_kv_connector_basic():
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[1000]] * len(req_ids),
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
......@@ -966,7 +953,6 @@ def test_kv_connector_unable_to_allocate():
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[1000]] * len(req_ids),
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
......@@ -1048,7 +1034,6 @@ def test_kv_connector_handles_preemption():
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[1000]] * len(req_ids),
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
......@@ -1142,7 +1127,6 @@ def make_output(scheduler: Scheduler):
for i, req in enumerate(scheduler.running)
},
sampled_token_ids=[[1000]] * len(scheduler.running),
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
......@@ -1310,7 +1294,8 @@ def create_requests_with_priority(
mm_positions: Optional[list[list[PlaceholderRange]]] = None,
max_tokens: int = 16,
stop_token_ids: Optional[list[int]] = None,
prompt_logprobs: Optional[int] = None):
prompt_logprobs: Optional[int] = None,
starting_idx: int = 0):
"""Create requests with specified priorities and arrival times."""
assert len(priorities) == num_requests
if arrival_times is not None:
......@@ -1324,21 +1309,24 @@ def create_requests_with_priority(
prompt_logprobs=prompt_logprobs)
requests = []
for i in range(num_requests):
mm_features = []
if mm_positions is not None:
mm_position = mm_positions[i]
mm_item = MultiModalKwargsItem.dummy("dummy_m")
mm_kwargs = [mm_item] * len(mm_position)
else:
mm_position = None
mm_kwargs = None
for j, position in enumerate(mm_position):
identifier = f"hash{i}_{j}"
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"),
mm_position=position,
identifier=identifier,
modality="image")
mm_features.append(mm_feature)
request = Request(
request_id=f"{i}",
prompt_token_ids=[i] * num_tokens,
request_id=f"{i + starting_idx}",
prompt_token_ids=[i + starting_idx] * num_tokens,
sampling_params=sampling_params,
pooling_params=None,
multi_modal_kwargs=mm_kwargs,
multi_modal_placeholders=mm_position,
multi_modal_hashes=None,
mm_features=mm_features if mm_features else None,
eos_token_id=EOS_TOKEN_ID,
arrival_time=arrival_times[i],
priority=priorities[i],
......@@ -1468,7 +1456,6 @@ def test_priority_scheduling_preemption():
for i, req in enumerate(low_priority_requests)
},
sampled_token_ids=[[100] for _ in low_priority_requests],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
......@@ -1541,7 +1528,6 @@ def test_priority_scheduling_no_preemption_when_space_available():
for i, req in enumerate(low_priority_requests)
},
sampled_token_ids=[[100] for _ in low_priority_requests],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
......@@ -1783,7 +1769,6 @@ def test_priority_scheduling_heap_property():
req_ids=[req.req_id],
req_id_to_index={req.req_id: 0},
sampled_token_ids=[[100]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
......@@ -1820,9 +1805,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
request = Request(
request_id="0",
prompt_token_ids=[0, 1],
multi_modal_kwargs=None,
multi_modal_hashes=None,
multi_modal_placeholders=None,
mm_features=None,
sampling_params=sampling_params,
pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
......@@ -1833,3 +1816,87 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
assert len(output.scheduled_new_reqs) == 0
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1
def test_priority_scheduling_preemption_when_out_of_kv():
"""Test that priority scheduling preempts lower priority requests
when out of KV cache space."""
# Create scheduler with very limited memory to force preemption
scheduler = create_scheduler_with_priority(
max_num_seqs=2, # Allow multiple requests
max_num_batched_tokens=200,
num_blocks=5, # Can hold 64 tokens (first block is null)
block_size=16, # Standard block size
)
# Create a request and schedule it
request_low = create_requests_with_priority(
num_requests=1,
priorities=[1],
arrival_times=[0.0],
num_tokens=30,
starting_idx=0,
)[0]
scheduler.add_request(request_low)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 1
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 1
# Simulate model execution
model_output = ModelRunnerOutput(
req_ids=[request_low.request_id],
req_id_to_index={request_low.request_id: 0},
sampled_token_ids=[[100]],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_output)
# Create a high priority request and schedule it
request_high = create_requests_with_priority(
num_requests=1,
priorities=[0],
arrival_times=[1.0],
num_tokens=32,
starting_idx=1,
)[0]
scheduler.add_request(request_high)
output = scheduler.schedule()
# KV cache should be full at this point
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == 0
assert len(output.scheduled_new_reqs) == 1
assert output.scheduled_cached_reqs.num_reqs == 1
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 2
# Simulate model execution
requests = [request_low, request_high]
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[100] for _ in requests],
# spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, model_output)
# Schedule again - this should trigger preemption
# req_low needs 32 tokens = 2 blocks
# req_high needs 33 tokens = 3 blocks
# so doesn't fit in 4 blocks.
output = scheduler.schedule()
# Should have preempted req_low
assert len(output.scheduled_new_reqs) == 0
assert output.scheduled_cached_reqs.num_reqs == 1
assert output.scheduled_cached_reqs.req_ids[0] == request_high.request_id
assert len(scheduler.waiting) == 1
assert len(scheduler.running) == 1
\ No newline at end of file
......@@ -6,7 +6,8 @@ import torch
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
init_none_hash)
......@@ -139,15 +140,20 @@ def create_requests(
prompt_logprobs=prompt_logprobs)
requests = []
for i in range(num_requests):
mm_features = []
if mm_positions is not None:
mm_position = mm_positions[i]
mm_item = MultiModalKwargsItem.dummy("dummy_m")
mm_kwargs = [mm_item] * len(mm_position)
mm_hashes = ["hash"] * len(mm_position)
else:
mm_position = None
mm_kwargs = None
mm_hashes = None
for j, position in enumerate(mm_position):
# Dummy hash for each mm item should be unique
# since encoder cache tracks entries by hash
identifier = f"hash{i}_{j}"
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"),
mm_position=position,
identifier=identifier,
modality="image")
mm_features.append(mm_feature)
prompt_token_ids = ([0] * num_tokens if same_prompt else [i] *
num_tokens)
request = Request(
......@@ -155,9 +161,7 @@ def create_requests(
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
pooling_params=None,
multi_modal_kwargs=mm_kwargs,
multi_modal_placeholders=mm_position,
multi_modal_hashes=mm_hashes,
mm_features=mm_features if mm_features else None,
eos_token_id=EOS_TOKEN_ID,
block_hasher=block_hasher,
)
......
......@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
from typing import Optional, Union
import pytest
import torch
......@@ -10,12 +9,6 @@ import torch
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.forward_context import get_forward_context
from vllm.model_executor.models.gemma3n_mm import (
Gemma3nForConditionalGeneration)
from vllm.model_executor.models.registry import ModelRegistry
from vllm.model_executor.models.utils import extract_layer_index
from vllm.sequence import IntermediateTensors
from ...utils import fork_new_process_for_each_test
......@@ -23,54 +16,6 @@ from ...utils import fork_new_process_for_each_test
SEED = 42
class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration):
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = super().forward(input_ids, positions,
intermediate_tensors, inputs_embeds,
**kwargs)
attn_metadata = get_forward_context().attn_metadata
# attn_metadata is None during dummy runs
if (attn_metadata is not None
and self.language_model.cache_config.kv_sharing_fast_prefill):
assert isinstance(attn_metadata, dict) # true in V1
# Gemma3n-E2B has 30 layers, with last 20 layers being
# cross-decoder layers. Check attention metadata is correct
for layer_name, metadata in attn_metadata.items():
layer_idx = extract_layer_index(layer_name)
if layer_idx >= 20:
assert hasattr(metadata, 'logits_indices_padded')
assert hasattr(metadata, 'num_logits_indices')
else:
assert not hasattr(metadata, 'logits_indices_padded')
assert not hasattr(metadata, 'num_logits_indices')
# Last layer will be a KV sharing layer
layer_attn_metadata = attn_metadata[
self.language_model.model.layers[-1].self_attn.attn.layer_name]
logits_indices_padded = (layer_attn_metadata.logits_indices_padded)
assert logits_indices_padded is not None
num_logits_indices = layer_attn_metadata.num_logits_indices
assert num_logits_indices > 0
# Reset hidden states to random values and
# only set logits at logits_indices to valid values
# Because logits_indices are the only positions that are used
# for output token sampling, this still produces same outputs
logits_hs = hidden_states[logits_indices_padded]
hidden_states = torch.randn_like(hidden_states)
gen_indices = logits_indices_padded[:num_logits_indices]
hidden_states[gen_indices] = logits_hs[:num_logits_indices]
return hidden_states
@pytest.fixture
def test_prompts():
"""
......@@ -119,13 +64,12 @@ def cleanup(llm: LLM, compilation_config: CompilationConfig):
@fork_new_process_for_each_test
@pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.skip(reason="Disable until Gemma3n supports fast prefill")
def test_kv_sharing_fast_prefill(
monkeypatch: pytest.MonkeyPatch,
enforce_eager: bool,
test_prompts: list[str],
):
ModelRegistry.register_model("Gemma3nForConditionalGeneration",
TestGemma3nForConditionalGeneration)
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
compilation_config = CompilationConfig(
# This allows vLLM compilation backend to handle allocating and
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Comprehensive end-to-end tests for `min_tokens` in the V1 engine.
Addresses #21950: verify and add CI coverage.
Covers:
1) Basic functionality
2) Stop strings with `min_tokens` (bug #21987; fix in PR #22014)
3) EOS behavior with `min_tokens` (potential logits-processor bug)
4) Edge cases (min_tokens == max_tokens, min_tokens == 0)
5) Multiple stop conditions
"""
import os
from typing import Optional, Union
import pytest
from vllm import LLM, SamplingParams
from vllm.outputs import RequestOutput
# Test configuration
TEST_MODEL = "facebook/opt-125m" # Small model for fast CI execution
GREEDY = 0.0 # Deterministic generation for consistent testing
class MinTokensTestCase:
"""Data class for min_tokens test scenarios"""
def __init__(
self,
name: str,
min_tokens: int,
max_tokens: int,
stop: Optional[Union[str, list[str]]] = None,
expected_min_len: Optional[int] = None,
expected_exact_len: Optional[int] = None,
):
self.name = name
self.min_tokens = min_tokens
self.max_tokens = max_tokens
self.stop = stop
self.expected_min_len = expected_min_len or min_tokens
self.expected_exact_len = expected_exact_len
def __str__(self):
return (f"{self.name}: min={self.min_tokens}, "
f"max={self.max_tokens}, stop={self.stop}")
# Test scenarios covering all critical cases
MIN_TOKENS_TEST_CASES = [
# === BASIC FUNCTIONALITY (should work) ===
MinTokensTestCase(name="basic_min_tokens_no_stop",
min_tokens=8,
max_tokens=20,
stop=None,
expected_min_len=8),
MinTokensTestCase(name="min_tokens_zero",
min_tokens=0,
max_tokens=10,
stop=None,
expected_min_len=0),
MinTokensTestCase(name="min_equals_max_no_stop",
min_tokens=15,
max_tokens=15,
stop=None,
expected_exact_len=15),
# === STOP STRINGS WITH MIN_TOKENS ===
# These tests expose the detokenizer bug where stop strings
# bypass min_tokens
# Using mathematically guaranteed approach with wide stop nets
pytest.param(
MinTokensTestCase(
name="min_tokens_with_comprehensive_stops",
min_tokens=5,
max_tokens=20,
stop=[
"a",
"e",
"i",
"o",
"u",
"t",
"n",
"s",
"r",
"l",
" ",
],
expected_min_len=5,
),
marks=pytest.mark.xfail(
reason=("Known bug #21987: stop strings bypass min_tokens "
"(fixed by PR #22014)"),
strict=False),
id="min_tokens_with_comprehensive_stops",
),
pytest.param(
MinTokensTestCase(
name="min_tokens_with_simple_char_stop",
min_tokens=3,
max_tokens=15,
stop=["e", "a", " "],
expected_min_len=3,
),
marks=pytest.mark.xfail(
reason=("Known bug #21987: stop strings bypass min_tokens "
"(fixed by PR #22014)"),
strict=False),
id="min_tokens_with_simple_char_stop",
),
# === EOS TOKEN WITH MIN_TOKENS (potential LogitsProcessor bug) ===
# These test the MinTokensLogitsProcessor handling of EOS tokens
pytest.param(
MinTokensTestCase(
name="min_equals_max_eos_only",
min_tokens=20,
max_tokens=20,
stop=None, # Relies on default EOS token behavior
expected_exact_len=20,
),
marks=pytest.mark.xfail(
reason=
("Potential logits-processor bug: EOS tokens may bypass min_tokens"
),
strict=False,
),
id="min_equals_max_eos_only",
),
# === EDGE CASES ===
MinTokensTestCase(name="large_min_tokens",
min_tokens=50,
max_tokens=60,
stop=None,
expected_min_len=50),
MinTokensTestCase(
name="min_tokens_with_empty_stop_list",
min_tokens=5,
max_tokens=15,
stop=[], # Empty stop list
expected_min_len=5),
]
@pytest.fixture(scope="module")
def llm_v1():
"""Create V1 LLM instance for testing"""
# Ensure V1 engine is used
os.environ["VLLM_USE_V1"] = "1"
llm = LLM(
model=TEST_MODEL,
tensor_parallel_size=1,
max_model_len=1024, # Small context for fast testing
enforce_eager=True, # Avoid graph compilation overhead
)
return llm
def get_token_count(output: RequestOutput) -> int:
"""Extract token count from LLM output"""
if not output.outputs:
return 0
return len(output.outputs[0].token_ids)
def assert_min_tokens_satisfied(output: RequestOutput,
test_case: MinTokensTestCase) -> None:
"""Assert that min_tokens requirement is satisfied"""
token_count = get_token_count(output)
stop_reason = (output.outputs[0].stop_reason
if output.outputs else "no output")
if test_case.expected_exact_len is not None:
# Exact length requirement
assert token_count == test_case.expected_exact_len, (
f"Expected exactly {test_case.expected_exact_len} tokens, "
f"got {token_count} tokens. "
f"Stop reason: {stop_reason}")
else:
# Minimum length requirement
assert token_count >= (test_case.expected_min_len or 0), (
f"Expected at least {test_case.expected_min_len} tokens, "
f"got {token_count} tokens. "
f"Stop reason: {stop_reason}")
@pytest.mark.parametrize(
"test_case",
MIN_TOKENS_TEST_CASES,
ids=lambda tc: tc.name,
)
def test_min_tokens_comprehensive(llm_v1: LLM, test_case: MinTokensTestCase):
"""
Comprehensive test for min_tokens functionality in V1 engine.
This test covers all critical scenarios for min_tokens:
- Basic functionality (should work)
- Stop strings with min_tokens (known bug)
- EOS tokens with min_tokens (potential bug)
- Edge cases
Args:
llm_v1: V1 LLM instance
test_case: Test scenario parameters
"""
# Known failing cases are handled via param-level xfail marks above.
# Create sampling parameters
sampling_params = SamplingParams(
min_tokens=test_case.min_tokens,
max_tokens=test_case.max_tokens,
stop=test_case.stop,
temperature=GREEDY,
include_stop_str_in_output=True # Include stop strings for debugging
)
# Use simple prompt. Comprehensive stop lists should catch any generation
prompt = "Hello"
# Generate output
outputs = llm_v1.generate([prompt], sampling_params)
assert len(outputs) == 1, "Expected exactly one output"
output = outputs[0]
# Debug information
token_count = get_token_count(output)
generated_text = output.outputs[0].text if output.outputs else ""
stop_reason = output.outputs[0].stop_reason if output.outputs else "unknown"
print(f"\nTest: {test_case.name}")
print(f"Generated {token_count} tokens")
print(f"Stop reason: {stop_reason}")
print(f"Generated text: {repr(generated_text)}")
print(f"Expected min: {test_case.expected_min_len}")
if test_case.expected_exact_len:
print(f"Expected exact: {test_case.expected_exact_len}")
# Validate min_tokens requirement
assert_min_tokens_satisfied(output, test_case)
def test_min_tokens_basic_functionality(llm_v1: LLM):
"""
Test basic min_tokens functionality without stop conditions.
This is a baseline test that should always pass and validates
that min_tokens works correctly in the simple case.
"""
sampling_params = SamplingParams(min_tokens=10,
max_tokens=20,
temperature=GREEDY)
prompt = "Once upon a time"
outputs = llm_v1.generate([prompt], sampling_params)
assert len(outputs) == 1
token_count = get_token_count(outputs[0])
assert token_count >= 10, f"Expected at least 10 tokens, got {token_count}"
assert token_count <= 20, f"Expected at most 20 tokens, got {token_count}"
@pytest.mark.xfail(
reason=("Known bug #21987: stop strings bypass min_tokens "
"(fixed by PR #22014)"),
strict=False,
)
def test_min_tokens_stop_strings_bug(llm_v1: LLM):
"""
Test the specific bug where stop strings bypass min_tokens.
This test specifically reproduces the bug Calvin is fixing in PR #22014.
It should fail until that fix is merged.
Strategy: Use guaranteed stop characters that will appear
in any generated text.
"""
# If the bug is fixed upstream, this test will XPASS
sampling_params = SamplingParams(
min_tokens=15,
max_tokens=50,
# Common letter; likely appears early
stop=["e"],
temperature=GREEDY,
include_stop_str_in_output=True)
# Simple prompt that will generate text containing "e"
prompt = "The quick brown fox"
outputs = llm_v1.generate([prompt], sampling_params)
assert len(outputs) == 1
token_count = get_token_count(outputs[0])
generated_text = outputs[0].outputs[0].text if outputs[0].outputs else ""
# Debug info to understand what happened
print(f"Generated text: {repr(generated_text)}")
print(f"Token count: {token_count}")
print(f"Contains 'e': {'e' in generated_text}")
# This assertion should fail due to the bug - if stop string is found early,
# the model should still continue generating until min_tokens is reached
stop_reason = (outputs[0].outputs[0].stop_reason
if outputs[0].outputs else "no output")
assert token_count >= 15, ("Bug confirmed: "
f"{token_count} tokens < min_tokens=15. "
f"Reason: {stop_reason}. "
f"Text: {repr(generated_text)}")
@pytest.mark.xfail(
reason=("Known bug #21987: stop strings bypass min_tokens "
"(fixed by PR #22014)"),
strict=False,
)
def test_min_tokens_stop_strings_guaranteed_early_trigger(llm_v1: LLM):
"""
Guaranteed test for stop strings bypassing min_tokens bug.
Strategy: Use very low temperature and multiple common stop strings
to virtually guarantee early detection, combined with long min_tokens
to ensure the bug is exposed regardless of model behavior.
"""
# If the bug is fixed upstream, this test will XPASS
sampling_params = SamplingParams(
min_tokens=50, # Set high min_tokens to ensure bug detection
max_tokens=200,
# Use multiple very common patterns - at least one will appear
stop=["e", "a", "i", "o", "u", " ", "t", "n", "s", "r"],
temperature=GREEDY,
include_stop_str_in_output=True)
# Simple prompt that will generate some text
prompt = "The cat"
outputs = llm_v1.generate([prompt], sampling_params)
assert len(outputs) == 1
token_count = get_token_count(outputs[0])
generated_text = outputs[0].outputs[0].text if outputs[0].outputs else ""
stop_reason = (outputs[0].outputs[0].stop_reason
if outputs[0].outputs else "unknown")
print(f"Generated text: {repr(generated_text)}")
print(f"Token count: {token_count}")
print(f"Stop reason: {stop_reason}")
# With the bug, this will fail because ANY of the common characters
# will trigger early termination before min_tokens=50 is reached
# It's virtually impossible to generate 50 tokens without hitting
# at least one of: e, a, i, o, u, space, t, n, s, r
finish_reason = (outputs[0].outputs[0].finish_reason
if outputs[0].outputs else "unknown")
print(f"Finish reason: {finish_reason}")
if finish_reason == "stop":
assert token_count >= 50, ("Bug confirmed: "
f"{token_count} tokens < min_tokens=50. "
f"Reason: {finish_reason}. "
f"Text: {repr(generated_text)}")
@pytest.mark.xfail(
reason=(
"Potential logits-processor bug: EOS tokens may bypass min_tokens"),
strict=False,
)
def test_min_tokens_eos_behavior(llm_v1: LLM):
"""
Verify EOS handling with and without min_tokens.
- Without min_tokens: expect early EOS -> finish_reason == "stop",
stop_reason is None, and generated tokens < max_tokens (25).
- With min_tokens: EOS should be blocked until min_tokens is reached
(finish_reason == "length"); verify that eos_token_id does not appear
in generated token_ids.
"""
# tokenizer + eos id
tokenizer = llm_v1.get_tokenizer()
eos_token_id = tokenizer.eos_token_id
prompt = "Give a file extension."
max_toks = 32
# Case 1: WITHOUT min_tokens
sp_no_min = SamplingParams(
max_tokens=max_toks,
temperature=GREEDY,
)
out_no_min = llm_v1.generate([prompt], sp_no_min)
assert len(out_no_min) == 1
choice_no_min = out_no_min[0].outputs[0]
ids_no_min = choice_no_min.token_ids or []
finish_no_min = choice_no_min.finish_reason
stop_no_min = choice_no_min.stop_reason
print("[no-min] tokens=", len(ids_no_min), " finish=", finish_no_min,
" stop_reason=", stop_no_min)
assert finish_no_min == "stop", (
f"Expected finish_reason 'stop' without min_tokens, got {finish_no_min}"
)
assert stop_no_min is None, (
"For EOS-based stop (no user stop strings), stop_reason should be None."
)
assert len(ids_no_min) < max_toks, (
f"Expected early EOS with < {max_toks} tokens, got {len(ids_no_min)}")
# Case 2: WITH min_tokens
sp_with_min = SamplingParams(
min_tokens=max_toks,
max_tokens=max_toks,
temperature=GREEDY,
)
out_with_min = llm_v1.generate([prompt], sp_with_min)
assert len(out_with_min) == 1
choice_with_min = out_with_min[0].outputs[0]
ids_with_min = choice_with_min.token_ids or []
finish_with_min = choice_with_min.finish_reason
stop_with_min = choice_with_min.stop_reason
print("[with-min] tokens=", len(ids_with_min), " finish=", finish_with_min,
" stop_reason=", stop_with_min)
# Exact length reached; EOS should have been blocked
assert len(ids_with_min) == max_toks, (
f"Expected exactly {max_toks} tokens with min_tokens; "
f"got {len(ids_with_min)}")
assert finish_with_min == "length", (
f"Expected finish_reason 'length'; got {finish_with_min}")
assert eos_token_id not in ids_with_min, (
"EOS token id should not appear when min_tokens prevents early EOS.")
def test_min_tokens_validation():
"""
Test that SamplingParams correctly validates min_tokens parameters.
This tests the parameter validation logic in SamplingParams.
"""
# Valid cases
SamplingParams(min_tokens=0, max_tokens=10)
SamplingParams(min_tokens=5, max_tokens=10)
SamplingParams(min_tokens=10, max_tokens=10)
# Invalid cases
with pytest.raises(
ValueError,
match="min_tokens must be greater than or equal to 0",
):
SamplingParams(min_tokens=-1, max_tokens=10)
with pytest.raises(
ValueError,
match="min_tokens must be less than or equal to max_tokens",
):
SamplingParams(min_tokens=15, max_tokens=10)
if __name__ == "__main__":
"""
Run tests locally for development.
Usage:
cd vllm/
VLLM_USE_V1=1 python -m pytest tests/v1/e2e/test_min_tokens.py -v
"""
pytest.main([__file__, "-v"])
......@@ -144,6 +144,8 @@ def test_ngram_correctness(
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
True,
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
(("eagle", "eagle618/deepseek-v3-random",
"eagle618/eagle-deepseek-v3-random", 1), False),
],
ids=[
# TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
......@@ -151,7 +153,8 @@ def test_ngram_correctness(
"llama3_eagle",
"llama3_eagle3",
"llama4_eagle",
"llama4_eagle_mm"
"llama4_eagle_mm",
"deepseek_eagle"
])
@pytest.mark.parametrize("attn_backend",
get_attn_backend_list_based_on_platform())
......@@ -177,6 +180,7 @@ def test_eagle_correctness(
'''
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
m.setenv("VLLM_MLA_DISABLE", "1")
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
if (attn_backend == "TRITON_ATTN_VLLM_V1"
......
......@@ -35,9 +35,7 @@ def make_request() -> EngineCoreRequest:
return EngineCoreRequest(
request_id=str(uuid.uuid4()),
prompt_token_ids=PROMPT_TOKENS,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
mm_features=None,
sampling_params=SamplingParams(),
pooling_params=None,
eos_token_id=None,
......@@ -308,17 +306,17 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
# Schedule Batch 1: (10, req0)
assert engine_core.step_with_batch_queue()[0] is None
assert engine_core.batch_queue.qsize() == 1
scheduler_output = engine_core.batch_queue.queue[-1][1]
assert len(engine_core.batch_queue) == 1
scheduler_output = engine_core.batch_queue[-1][1]
assert scheduler_output.num_scheduled_tokens["0"] == 10
# num_computed_tokens should have been updated immediately.
assert engine_core.scheduler.requests[
req0.request_id].num_computed_tokens == 10
# Schedule Batch 2: (2, req0), (8, req1)
assert engine_core.step_with_batch_queue()[0] is None
assert engine_core.batch_queue.qsize() == 2
scheduler_output = engine_core.batch_queue.queue[-1][1]
assert engine_core.step_with_batch_queue()[0] == {}
assert len(engine_core.batch_queue) == 1
scheduler_output = engine_core.batch_queue[-1][1]
assert scheduler_output.num_scheduled_tokens["0"] == 2
assert scheduler_output.num_scheduled_tokens["1"] == 8
# num_computed_tokens should have been updated immediately.
......@@ -327,42 +325,32 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
assert engine_core.scheduler.get_num_unfinished_requests() == 2
# Batch queue is full. Finish Batch 1.
engine_core.step_with_batch_queue()
# Schedule Batch 3: (4, req1). Note that req0 cannot be scheduled
# Finish Batch 1 and schedule Batch 3: (4, req1).
# Note that req0 cannot be scheduled
# because it is in the decoding stage now.
engine_core.step_with_batch_queue()
assert engine_core.batch_queue.qsize() == 2
scheduler_output = engine_core.batch_queue.queue[-1][1]
assert len(engine_core.batch_queue) == 1
scheduler_output = engine_core.batch_queue[-1][1]
assert scheduler_output.num_scheduled_tokens["1"] == 4
# Batch queue is full. Finish Batch 2. Get first token of req0.
# Finish Batch 2. Get first token of req0.
# Schedule Batch 4: (1, req0).
output = engine_core.step_with_batch_queue()[0].get(0)
assert output is not None
assert len(output.outputs) == 1
assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13
# Schedule Batch 4: (1, req0).
engine_core.step_with_batch_queue()
assert engine_core.batch_queue.qsize() == 2
scheduler_output = engine_core.batch_queue.queue[-1][1]
scheduler_output = engine_core.batch_queue[-1][1]
assert scheduler_output.num_scheduled_tokens["0"] == 1
# Batch queue is full. Finish Batch 3. Get first token of req1.
# Finish Batch 3. Get first token of req1. Schedule Batch 5: (1, req1).
output = engine_core.step_with_batch_queue()[0].get(0)
assert output is not None
assert len(output.outputs) == 1
assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13
# Schedule Batch 5: (1, req1).
engine_core.step_with_batch_queue()
assert engine_core.batch_queue.qsize() == 2
scheduler_output = engine_core.batch_queue.queue[-1][1]
scheduler_output = engine_core.batch_queue[-1][1]
assert scheduler_output.num_scheduled_tokens["1"] == 1
# Loop until req0 is finished.
step = 0
req_id = 0
expected_num_tokens = [
engine_core.scheduler.requests["0"].num_tokens + 1,
......@@ -370,19 +358,14 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
]
while engine_core.scheduler.get_num_unfinished_requests() == 2:
output = engine_core.step_with_batch_queue()[0]
if step % 2 == 0:
# Even steps consumes an output.
assert output is not None
assert len(output[0].outputs) == 1
if req_id in engine_core.scheduler.requests:
assert engine_core.scheduler.requests[
req_id].num_tokens == expected_num_tokens[req_id]
expected_num_tokens[req_id] += 1
req_id = (req_id + 1) % 2
else:
# Odd steps schedules a new batch.
assert output is None
step += 1
# Every step consumes an output.
assert output is not None
assert len(output[0].outputs) == 1
if req_id in engine_core.scheduler.requests:
assert engine_core.scheduler.requests[
req_id].num_tokens == expected_num_tokens[req_id]
expected_num_tokens[req_id] += 1
req_id = (req_id + 1) % 2
@multi_gpu_test(num_gpus=2)
......
......@@ -52,9 +52,7 @@ def make_request(
return EngineCoreRequest(
request_id=str(uuid.uuid4()),
prompt_token_ids=prompt_tokens_ids,
mm_kwargs=None,
mm_hashes=None,
mm_placeholders=None,
mm_features=None,
sampling_params=params,
pooling_params=None,
eos_token_id=None,
......
......@@ -26,16 +26,14 @@ def test_fast_inc_detok_invalid_utf8_err_case():
prompt_token_ids = [107, 4606, 236787, 107]
params = SamplingParams(skip_special_tokens=True)
request = EngineCoreRequest(
"test",
prompt_token_ids,
None,
None,
None,
params,
None,
None,
0.0,
None,
request_id="test",
prompt_token_ids=prompt_token_ids,
mm_features=None,
sampling_params=params,
pooling_params=None,
eos_token_id=None,
arrival_time=0.0,
lora_request=None,
cache_salt=None,
data_parallel_rank=None,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment