Commit 38d80967 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.2rc2' into v0.10.2rc2-ori

parents 33650733 880c741b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from pathlib import Path
from typing import Optional, Union
import pytest
from transformers import PretrainedConfig
from vllm.transformers_utils.config import (get_config_parser,
register_config_parser)
from vllm.transformers_utils.config_parser_base import ConfigParserBase
@register_config_parser("custom_config_parser")
class CustomConfigParser(ConfigParserBase):
def parse(self,
model: Union[str, Path],
trust_remote_code: bool,
revision: Optional[str] = None,
code_revision: Optional[str] = None,
**kwargs) -> tuple[dict, PretrainedConfig]:
raise NotImplementedError
def test_register_config_parser():
assert isinstance(get_config_parser("custom_config_parser"),
CustomConfigParser)
def test_invalid_config_parser():
with pytest.raises(ValueError):
@register_config_parser("invalid_config_parser")
class InvalidConfigParser:
pass
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import contextlib
import copy
import functools
import importlib
......@@ -13,10 +14,11 @@ import sys
import tempfile
import time
import warnings
from contextlib import contextmanager, suppress
from contextlib import ExitStack, contextmanager, suppress
from multiprocessing import Process
from pathlib import Path
from typing import Any, Callable, Literal, Optional, Union
from unittest.mock import patch
import cloudpickle
import httpx
......@@ -799,43 +801,106 @@ _P = ParamSpec("_P")
def fork_new_process_for_each_test(
f: Callable[_P, None]) -> Callable[_P, None]:
func: Callable[_P, None]) -> Callable[_P, None]:
"""Decorator to fork a new process for each test function.
See https://github.com/vllm-project/vllm/issues/7053 for more details.
"""
@functools.wraps(f)
@functools.wraps(func)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
# Make the process the leader of its own process group
# to avoid sending SIGTERM to the parent process
os.setpgrp()
from _pytest.outcomes import Skipped
pid = os.fork()
print(f"Fork a new process to run a test {pid}")
if pid == 0:
try:
f(*args, **kwargs)
except Skipped as e:
# convert Skipped to exit code 0
print(str(e))
os._exit(0)
except Exception:
import traceback
traceback.print_exc()
os._exit(1)
# Create a unique temporary file to store exception info from child
# process. Use test function name and process ID to avoid collisions.
with tempfile.NamedTemporaryFile(
delete=False,
mode='w+b',
prefix=f"vllm_test_{func.__name__}_{os.getpid()}_",
suffix=".exc") as exc_file, ExitStack() as delete_after:
exc_file_path = exc_file.name
delete_after.callback(os.remove, exc_file_path)
pid = os.fork()
print(f"Fork a new process to run a test {pid}")
if pid == 0:
# Parent process responsible for deleting, don't delete
# in child.
delete_after.pop_all()
try:
func(*args, **kwargs)
except Skipped as e:
# convert Skipped to exit code 0
print(str(e))
os._exit(0)
except Exception as e:
import traceback
tb_string = traceback.format_exc()
# Try to serialize the exception object first
exc_to_serialize: dict[str, Any]
try:
# First, try to pickle the actual exception with
# its traceback.
exc_to_serialize = {'pickled_exception': e}
# Test if it can be pickled
cloudpickle.dumps(exc_to_serialize)
except (Exception, KeyboardInterrupt):
# Fall back to string-based approach.
exc_to_serialize = {
'exception_type': type(e).__name__,
'exception_msg': str(e),
'traceback': tb_string,
}
try:
with open(exc_file_path, 'wb') as f:
cloudpickle.dump(exc_to_serialize, f)
except Exception:
# Fallback: just print the traceback.
print(tb_string)
os._exit(1)
else:
os._exit(0)
else:
os._exit(0)
else:
pgid = os.getpgid(pid)
_pid, _exitcode = os.waitpid(pid, 0)
# ignore SIGTERM signal itself
old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN)
# kill all child processes
os.killpg(pgid, signal.SIGTERM)
# restore the signal handler
signal.signal(signal.SIGTERM, old_signal_handler)
assert _exitcode == 0, (f"function {f} failed when called with"
f" args {args} and kwargs {kwargs}")
pgid = os.getpgid(pid)
_pid, _exitcode = os.waitpid(pid, 0)
# ignore SIGTERM signal itself
old_signal_handler = signal.signal(signal.SIGTERM,
signal.SIG_IGN)
# kill all child processes
os.killpg(pgid, signal.SIGTERM)
# restore the signal handler
signal.signal(signal.SIGTERM, old_signal_handler)
if _exitcode != 0:
# Try to read the exception from the child process
exc_info = {}
if os.path.exists(exc_file_path):
with contextlib.suppress(Exception), \
open(exc_file_path, 'rb') as f:
exc_info = cloudpickle.load(f)
if (original_exception :=
exc_info.get('pickled_exception')) is not None:
# Re-raise the actual exception object if it was
# successfully pickled.
assert isinstance(original_exception, Exception)
raise original_exception
if (original_tb := exc_info.get("traceback")) is not None:
# Use string-based traceback for fallback case
raise AssertionError(
f"Test {func.__name__} failed when called with"
f" args {args} and kwargs {kwargs}"
f" (exit code: {_exitcode}):\n{original_tb}"
) from None
# Fallback to the original generic error
raise AssertionError(
f"function {func.__name__} failed when called with"
f" args {args} and kwargs {kwargs}"
f" (exit code: {_exitcode})") from None
return wrapper
......@@ -1077,3 +1142,11 @@ def get_attn_backend_list_based_on_platform() -> list[str]:
return attn_backend_list
else:
raise ValueError("Unsupported platform")
@contextmanager
def override_cutlass_fp8_supported(value: bool):
with patch(
"vllm.model_executor.layers.quantization.utils.w8a8_utils.cutlass_fp8_supported",
return_value=value):
yield
......@@ -835,22 +835,20 @@ def test_model_specification(parser_with_config, cli_config_file,
@pytest.mark.parametrize("input", [(), ("abc", ), (None, ),
(None, bool, [1, 2, 3])])
@pytest.mark.parametrize("output", [0, 1, 2])
def test_sha256(input: tuple, output: int):
hash = sha256(input)
assert hash is not None
assert isinstance(hash, int)
assert hash != 0
def test_sha256(input: tuple):
digest = sha256(input)
assert digest is not None
assert isinstance(digest, bytes)
assert digest != b""
bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
assert hash == int.from_bytes(hashlib.sha256(bytes).digest(),
byteorder="big")
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
assert digest == hashlib.sha256(input_bytes).digest()
# hashing again, returns the same value
assert hash == sha256(input)
assert digest == sha256(input)
# hashing different input, returns different value
assert hash != sha256(input + (1, ))
assert digest != sha256(input + (1, ))
@pytest.mark.parametrize(
......
......@@ -70,22 +70,6 @@ BATCH_SPECS = {
}
def create_dummy_kv_cache(kv_cache_spec: FullAttentionSpec,
device: torch.device,
num_blocks: int = 100) -> torch.Tensor:
"""Create a dummy KV cache tensor for testing."""
kv_cache = torch.randn(
2, # K and V
num_blocks,
kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
dtype=_convert_dtype_to_torch(kv_cache_spec.dtype),
device=device,
)
return kv_cache
def create_and_prepopulate_kv_cache(
k_contexts: list[torch.Tensor],
v_contexts: list[torch.Tensor],
......
......@@ -160,7 +160,7 @@ def test_local_attention_virtual_batches(test_data: LocalAttentionTestData):
# Use torch.arange instead of torch.randint so we can assert on
# block table tensor values. The block table will have shape
# (num_batches, cdiv(max_seq_len, block_size)) and the values will be
# aranged from 0 to cdiv(max_seq_len, block_size)-1
# arranged from 0 to cdiv(max_seq_len, block_size)-1
arange_block_indices=True,
)
......
......@@ -15,7 +15,7 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec
BACKENDS_TO_TEST = [
_Backend.CUTLASS_MLA, _Backend.FLASHMLA_VLLM_V1,
_Backend.CUTLASS_MLA, _Backend.FLASHMLA_VLLM_V1, _Backend.FLASH_ATTN_MLA,
_Backend.TRITON_MLA_VLLM_V1
]
......@@ -69,25 +69,10 @@ BATCH_SPECS = {
}
def create_dummy_kv_cache(kv_cache_spec: FullAttentionSpec,
device: torch.device,
num_blocks: int = 100) -> torch.Tensor:
"""Create a dummy KV cache tensor for testing."""
kv_cache = torch.randn(
num_blocks,
kv_cache_spec.block_size,
kv_cache_spec.head_size, # latent dimension
dtype=_convert_dtype_to_torch(kv_cache_spec.dtype),
device=device,
)
return kv_cache
def create_and_prepopulate_kv_cache(
kv_c_contexts: list[torch.Tensor],
k_pe_contexts: list[torch.Tensor],
block_size: int,
num_kv_heads: int,
head_size: int,
dtype: torch.dtype,
device: torch.device,
......@@ -101,7 +86,6 @@ def create_and_prepopulate_kv_cache(
k_pe_contexts: List of key positional embedding context tensors
for each sequence
block_size: Size of each block
num_kv_heads: Number of KV heads (should be 1 for MLA)
head_size: Size of each head (latent dimension)
dtype: Data type for the cache
device: Device to create the cache on
......@@ -299,8 +283,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
query_lens = batch_spec.query_lens
num_q_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config)
num_kv_heads = vllm_config.model_config.get_num_kv_heads(
vllm_config.parallel_config)
head_size = vllm_config.model_config.get_head_size()
dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
block_size = vllm_config.cache_config.block_size
......@@ -315,7 +297,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
# 2. Generate data and compute SDPA reference output for MLA
all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], []
all_sdpa_outputs = []
all_sdpa_outputs: list[list[torch.Tensor]] = []
kv_c_contexts, k_pe_contexts = [], []
# Create shared MLA weight matrices for consistency across all sequences
......@@ -331,6 +313,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
device=device)
kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
for i, backend in enumerate(BACKENDS_TO_TEST):
all_sdpa_outputs.append([])
for i in range(batch_size):
s_len = seq_lens[i]
q_len = query_lens[i]
......@@ -358,85 +343,93 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
dtype=dtype,
device=device)
# Determine if this is decode (single token)
# or prefill (multiple tokens)
is_decode = q_len == 1
# Determine if this is decode or prefill
is_decode = []
for i, backend in enumerate(BACKENDS_TO_TEST):
builder_cls, _ = get_attention_backend(backend)
is_decode.append(q_len <= builder_cls.reorder_batch_threshold)
# Split q into nope and rope components
q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
if is_decode:
# Decode path: MQA-style attention in latent space
# Transform q_nope to latent space: q_nope @ W_UK
# q_nope: [1, num_heads, qk_nope_head_dim]
# W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim]
ql_nope = torch.einsum("qnh,lnh->qnl", q_nope,
W_UK) # [1, num_heads, kv_lora_rank]
# Build MQA attention inputs
# Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim]
q_mqa = torch.cat([ql_nope, q_pe], dim=-1)
# K: [s_len, kv_lora_rank + qk_rope_head_dim]
# (broadcasted to all heads)
k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1)
k_mqa = k_mqa.unsqueeze(1).expand(-1, num_q_heads, -1)
# V: [s_len, kv_lora_rank] (broadcasted to all heads)
v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_q_heads, -1)
# SDPA expects (N, H, L, D)
q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2)
k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2)
v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)
sdpa_out_i = torch.nn.functional.scaled_dot_product_attention(
q_sdpa_in, k_sdpa_in, v_sdpa_in, is_causal=False, scale=scale)
sdpa_out_i = sdpa_out_i.transpose(1, 2).squeeze(
0) # [1, num_heads, kv_lora_rank]
# Project back to output space: sdpa_out @ W_UV
sdpa_out_i = torch.einsum("qnl,lnv->qnv", sdpa_out_i, W_UV)
sdpa_out_i = sdpa_out_i.flatten(start_dim=-2)
else:
# Prefill path: MHA-style attention with full sequence
# Apply kv_b_proj to the full kv_c tensor
kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full,
kv_b_proj_weight)
k_nope_full, v_full = kv_nope_full.split(
[qk_nope_head_dim, v_head_dim], dim=-1)
# Build attention inputs for full sequence
q_mha = torch.cat([q_nope, q_pe],
dim=-1) # [q_len, num_heads, total_dim]
k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1)
k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1)
# Create custom attention mask:
# - Query tokens can attend to all context tokens
# - Query tokens can only attend to query tokens up to their pos
attn_mask = torch.ones(q_len,
s_len,
dtype=torch.bool,
device=device)
# Apply causal mask only to the query portion (context_len onwards)
causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
attn_mask[:, context_len:] = causal_mask
# SDPA expects (N, H, L, D)
q_sdpa_in = q_mha.unsqueeze(0).transpose(1, 2)
k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2)
v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2)
# Single attention call with custom mask
sdpa_out_i = torch.nn.functional.scaled_dot_product_attention(
q_sdpa_in,
k_sdpa_in,
v_sdpa_in,
attn_mask=attn_mask,
scale=scale)
sdpa_out_i = sdpa_out_i.transpose(1, 2).squeeze(0)
sdpa_out_i = sdpa_out_i.flatten(start_dim=-2)
all_sdpa_outputs.append(sdpa_out_i)
#######################################################
# Decode path: MQA-style attention in latent space
# Transform q_nope to latent space: q_nope @ W_UK
# q_nope: [1, num_heads, qk_nope_head_dim]
# W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim]
ql_nope = torch.einsum("qnh,lnh->qnl", q_nope,
W_UK) # [1, num_heads, kv_lora_rank]
# Build MQA attention inputs
# Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim]
q_mqa = torch.cat([ql_nope, q_pe], dim=-1)
# K: [s_len, kv_lora_rank + qk_rope_head_dim]
# (broadcasted to all heads)
k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1)
k_mqa = k_mqa.unsqueeze(1).expand(-1, num_q_heads, -1)
# V: [s_len, kv_lora_rank] (broadcasted to all heads)
v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_q_heads, -1)
# Create custom attention mask for decode path:
# - Query tokens can attend to all context tokens
# - Query tokens can only attend to query tokens up to their position
attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device)
# Apply causal mask only to the query portion (context_len onwards)
causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
attn_mask[:, context_len:] = causal_mask
# SDPA expects (N, H, L, D)
q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2)
k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2)
v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)
sdpa_out_i_decode = torch.nn.functional.scaled_dot_product_attention(
q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale)
sdpa_out_i_decode = sdpa_out_i_decode.transpose(1, 2).squeeze(
0) # [1, num_heads, kv_lora_rank]
# Project back to output space: sdpa_out @ W_UV
sdpa_out_i_decode = torch.einsum("qnl,lnv->qnv", sdpa_out_i_decode,
W_UV)
sdpa_out_i_decode = sdpa_out_i_decode.flatten(start_dim=-2)
#######################################################
# Prefill path: MHA-style attention with full sequence
# Apply kv_b_proj to the full kv_c tensor
kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full, kv_b_proj_weight)
k_nope_full, v_full = kv_nope_full.split(
[qk_nope_head_dim, v_head_dim], dim=-1)
# Build attention inputs for full sequence
q_mha = torch.cat([q_nope, q_pe],
dim=-1) # [q_len, num_heads, total_dim]
k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1)
k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1)
# Create custom attention mask:
# - Query tokens can attend to all context tokens
# - Query tokens can only attend to query tokens up to their pos
attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device)
# Apply causal mask only to the query portion (context_len onwards)
causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
attn_mask[:, context_len:] = causal_mask
# SDPA expects (N, H, L, D)
q_sdpa_in = q_mha.unsqueeze(0).transpose(1, 2)
k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2)
v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2)
# Single attention call with custom mask
sdpa_out_i_prefill = torch.nn.functional.scaled_dot_product_attention(
q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale)
sdpa_out_i_prefill = sdpa_out_i_prefill.transpose(1, 2).squeeze(0)
sdpa_out_i_prefill = sdpa_out_i_prefill.flatten(start_dim=-2)
for i, backend in enumerate(BACKENDS_TO_TEST):
if is_decode[i]:
all_sdpa_outputs[i].append(sdpa_out_i_decode)
else:
all_sdpa_outputs[i].append(sdpa_out_i_prefill)
# Inputs for vLLM MLA backends are just the new tokens
all_q_vllm.append(q_c)
......@@ -451,7 +444,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
query_vllm = torch.cat(all_q_vllm, dim=0)
kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0)
sdpa_output = torch.cat(all_sdpa_outputs, dim=0)
sdpa_outputs = []
for i, backend in enumerate(BACKENDS_TO_TEST):
sdpa_outputs.append(torch.cat(all_sdpa_outputs[i], dim=0))
# Create mock kv_b_proj using the same weights as reference implementation
from vllm.model_executor.layers.linear import ColumnParallelLinear
......@@ -477,7 +472,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
kv_c_contexts=kv_c_contexts,
k_pe_contexts=k_pe_contexts,
block_size=block_size,
num_kv_heads=num_kv_heads,
head_size=head_size,
dtype=dtype,
device=device,
......@@ -486,7 +480,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
randomize_blocks=True)
# 4. Run vLLM backends and compare
for backend_name in BACKENDS_TO_TEST:
for i, backend_name in enumerate(BACKENDS_TO_TEST):
backend_output = run_attention_backend(
backend_name, kv_cache_spec, ["placeholder"], vllm_config, device,
common_attn_metadata, query_vllm, kv_c_vllm, k_pe_vllm, kv_cache,
......@@ -494,12 +488,12 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
mock_kv_b_proj)
# Check shape and dtype consistency
assert backend_output.shape == sdpa_output.shape, (
assert backend_output.shape == sdpa_outputs[i].shape, (
f"[{backend_name}] shape {backend_output.shape} != "
f"SDPA shape {sdpa_output.shape}")
assert backend_output.dtype == sdpa_output.dtype, (
f"SDPA shape {sdpa_outputs[i].shape}")
assert backend_output.dtype == sdpa_outputs[i].dtype, (
f"[{backend_name}] dtype {backend_output.dtype} != "
f"SDPA dtype {sdpa_output.dtype}")
f"SDPA dtype {sdpa_outputs[i].dtype}")
assert torch.isfinite(backend_output).all(), (
f"[{backend_name}] produced non-finite values")
......@@ -508,12 +502,13 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
rtol = 1e-2
atol = 5e-1
max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item()
max_diff = torch.max(torch.abs(backend_output -
sdpa_outputs[i])).item()
max_rel_diff = torch.max(
torch.abs(backend_output - sdpa_output) /
torch.abs(sdpa_output)).item()
torch.abs(backend_output - sdpa_outputs[i]) /
torch.abs(sdpa_outputs[i])).item()
all_close = torch.allclose(backend_output,
sdpa_output,
sdpa_outputs[i],
rtol=rtol,
atol=atol)
......
......@@ -139,6 +139,8 @@ def get_attention_backend(backend_name: _Backend):
"vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend",
_Backend.FLASHMLA_VLLM_V1:
"vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",
_Backend.FLASH_ATTN_MLA:
"vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend",
_Backend.TRITON_MLA_VLLM_V1:
"vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",
}
......
......@@ -6,20 +6,22 @@ from typing import Callable, Optional
import pytest
import torch
import vllm.v1.core.kv_cache_utils as kv_cache_utils
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import SamplingParams
from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit
from vllm.utils import GiB_bytes, sha256, sha256_cbor
from vllm.v1.core.kv_cache_manager import KVCacheManager
# disable yapf here as it formats differently than isort such that both fail
# yapf: disable
from vllm.v1.core.kv_cache_utils import (
FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
BlockHash, FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
estimate_max_model_len, generate_block_hash_extra_keys,
get_kv_cache_config, get_max_concurrency_for_kv_cache_config,
get_request_block_hasher, hash_block_tokens, init_none_hash,
is_kv_cache_type_uniform, unify_kv_cache_configs)
is_kv_cache_type_uniform, make_block_hash_with_group_id,
unify_kv_cache_configs)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheTensor,
SlidingWindowSpec)
......@@ -88,7 +90,7 @@ def new_sliding_window_spec(block_size=16,
sliding_window=sliding_window)
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_none_hash(monkeypatch, hash_fn):
import vllm.v1.core.kv_cache_utils
......@@ -98,8 +100,8 @@ def test_none_hash(monkeypatch, hash_fn):
reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils)
reloaded_kv_cache_utils.init_none_hash(hash_fn)
assert reloaded_kv_cache_utils.NONE_HASH is not None
assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int)
assert reloaded_kv_cache_utils.NONE_HASH != 0
assert isinstance(reloaded_kv_cache_utils.NONE_HASH, bytes)
assert reloaded_kv_cache_utils.NONE_HASH != b""
# case 2: PYTHONHASHSEED is set, use the seed and hash_fn
with monkeypatch.context() as m:
......@@ -107,12 +109,11 @@ def test_none_hash(monkeypatch, hash_fn):
reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils)
reloaded_kv_cache_utils.init_none_hash(hash_fn)
assert reloaded_kv_cache_utils.NONE_HASH is not None
assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int)
assert isinstance(reloaded_kv_cache_utils.NONE_HASH, bytes)
assert hash_fn('python hash seed') == reloaded_kv_cache_utils.NONE_HASH
def test_kv_cache_block():
import vllm.v1.core.kv_cache_utils
# Test KVCacheBlock initialization
block = KVCacheBlock(block_id=0)
......@@ -127,8 +128,7 @@ def test_kv_cache_block():
assert block.ref_cnt == 0
# Test block hash setting and resetting
block_hash = vllm.v1.core.kv_cache_utils.BlockHash(hash_value=123,
token_ids=(1, 2, 3))
block_hash = make_block_hash_with_group_id(BlockHash(b"abc"), 0)
block.block_hash = block_hash
assert block.block_hash == block_hash
......@@ -247,7 +247,7 @@ def test_free_kv_cache_block_queue_append_n():
def test_free_kv_cache_block_queue_popleft_n():
blocks = [KVCacheBlock(block_id=i) for i in range(6)]
# Create a empty FreeKVCacheBlockQueue with these blocks
# Create an empty FreeKVCacheBlockQueue with these blocks
queue = FreeKVCacheBlockQueue(
[blocks[1], blocks[3], blocks[5], blocks[4], blocks[0], blocks[2]])
assert queue.num_free_blocks == 6
......@@ -407,27 +407,23 @@ def test_generate_block_hash_extra_keys_cache_salt():
assert next_mm_idx == 1
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_hash_block_tokens(hash_fn):
import vllm.v1.core.kv_cache_utils
init_none_hash(hash_fn)
parent_block_hash = 123
parent_block_hash = BlockHash(b"123")
curr_block_token_ids = (1, 2, 3)
extra_keys = ("key1", "key2")
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
curr_block_token_ids, extra_keys)
assert isinstance(block_hash, vllm.v1.core.kv_cache_utils.BlockHash)
assert block_hash.hash_value == hash_fn(
(parent_block_hash, curr_block_token_ids, extra_keys))
assert block_hash.token_ids == curr_block_token_ids
assert block_hash.extra_keys == extra_keys
expected = hash_fn((parent_block_hash, curr_block_token_ids, extra_keys))
assert block_hash == expected
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_request_block_hasher(hash_fn):
import vllm.v1.core.kv_cache_utils
init_none_hash(hash_fn)
kv_cache_utils.init_none_hash(hash_fn)
request = make_request(
request_id="0",
prompt_token_ids=[_ for _ in range(6)],
......@@ -442,19 +438,13 @@ def test_request_block_hasher(hash_fn):
block_hashes = request.block_hashes
assert len(block_hashes) == 2
assert isinstance(block_hashes[0], vllm.v1.core.kv_cache_utils.BlockHash)
assert isinstance(block_hashes[1], vllm.v1.core.kv_cache_utils.BlockHash)
# Check the first block
assert block_hashes[0].token_ids == (0, 1, 2)
assert block_hashes[0].extra_keys == ("hash1", )
assert block_hashes[0] == hash_fn(
(kv_cache_utils.NONE_HASH, (0, 1, 2), ("hash1", )))
assert block_hashes[1] == hash_fn(
(block_hashes[0], (3, 4, 5), ("hash2", )))
# Check the second block
assert block_hashes[1].token_ids == (3, 4, 5)
assert block_hashes[1].extra_keys == ("hash2", )
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_hash_tokens_different_mm_input(hash_fn):
init_none_hash(hash_fn)
......@@ -484,9 +474,9 @@ def test_hash_tokens_different_mm_input(hash_fn):
assert block_hashes1[1] != block_hashes2[1]
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_hash_request_tokens_no_mm_inputs(hash_fn):
init_none_hash(hash_fn)
kv_cache_utils.init_none_hash(hash_fn)
request = make_request(
request_id="0",
......@@ -500,10 +490,9 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn):
block_hashes = request.block_hashes
assert len(block_hashes) == 2
assert block_hashes[0].token_ids == (0, 1, 2)
assert block_hashes[0].extra_keys is None
assert block_hashes[1].token_ids == (3, 4, 5)
assert block_hashes[1].extra_keys is None
assert block_hashes[0] == hash_fn(
(kv_cache_utils.NONE_HASH, (0, 1, 2), None))
assert block_hashes[1] == hash_fn((block_hashes[0], (3, 4, 5), None))
def test_metrics():
......
......@@ -8,17 +8,19 @@ from typing import Callable, Optional
import pytest
import torch
import vllm.v1.core.kv_cache_utils as kv_cache_utils
from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import SamplingParams
from vllm.utils import sha256, sha256_cbor_64bit
from vllm.utils import sha256, sha256_cbor
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
KVCacheBlock,
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
get_block_hash, get_group_id,
get_request_block_hasher,
hash_block_tokens, init_none_hash)
hash_block_tokens, init_none_hash,
make_block_hash_with_group_id)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, SlidingWindowSpec)
......@@ -101,8 +103,10 @@ def make_kv_cache_config_hybrid_model(block_size: int,
)
@pytest.mark.parametrize("hash_algo", ["sha256", "sha256_cbor_64bit", "hash"])
def test_prefill(hash_algo):
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_prefill(hash_fn):
init_none_hash(hash_fn)
block_size = 16
manager = KVCacheManager(
make_kv_cache_config(block_size, 11),
......@@ -110,10 +114,6 @@ def test_prefill(hash_algo):
enable_caching=True,
)
# choose the hash function according to the parameter
hash_fn = (sha256_cbor_64bit if hash_algo == "sha256_cbor_64bit" else
sha256 if hash_algo == "sha256" else hash)
# Complete 3 blocks (48 tokens)
common_token_ids = [i for i in range(3) for _ in range(16)]
......@@ -137,10 +137,12 @@ def test_prefill(hash_algo):
block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
block_tokens)
assert manager.block_pool.blocks[
block_id].block_hash.block_hash == block_hash
blk_hash = manager.block_pool.blocks[block_id].block_hash
assert blk_hash is not None
assert get_block_hash(blk_hash) == block_hash
assert get_group_id(blk_hash) == 0
assert manager.block_pool.blocks[block_id].ref_cnt == 1
parent_block_hash = block_hash.hash_value
parent_block_hash = block_hash
# Check partial block metadata
for block_id in (4, ):
......@@ -233,7 +235,7 @@ def test_prefill_hybrid_model():
enable_caching=True,
)
hash_fn = hash
hash_fn = sha256
# Complete 3 blocks (48 tokens)
common_token_ids = [i for i in range(3) for _ in range(block_size)]
......@@ -260,11 +262,13 @@ def test_prefill_hybrid_model():
block_tokens = tuple(all_token_ids[(length - 1) * 16:length * 16])
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
block_tokens)
for block_id in block_ids:
assert manager.block_pool.blocks[
block_id].block_hash.block_hash == block_hash
for group_id, block_id in enumerate(block_ids):
blk_hash = manager.block_pool.blocks[block_id].block_hash
assert blk_hash is not None
assert get_block_hash(blk_hash) == block_hash
assert get_group_id(blk_hash) == group_id
assert manager.block_pool.blocks[block_id].ref_cnt == 1
parent_block_hash = block_hash.hash_value
parent_block_hash = block_hash
# Check partial block metadata
for block_id in (4, 8, 12):
......@@ -298,11 +302,10 @@ def test_prefill_hybrid_model():
cached_block_hash_to_block_bak = copy.copy(
manager.block_pool.cached_block_hash_to_block)
def test_partial_request_hit(request_id: str,
hash_to_evict: list[BlockHashWithGroupId],
def test_partial_request_hit(request_id: str, hash_to_evict: list[bytes],
expect_hit_length: int):
req = make_request(request_id, common_token_ids + unique_token_ids,
block_size, hash)
block_size, sha256)
for hash_with_group_id in hash_to_evict:
manager.block_pool.cached_block_hash_to_block.pop(
hash_with_group_id)
......@@ -319,33 +322,32 @@ def test_prefill_hybrid_model():
# Evict the blocks outside sliding window, does not affect the hit length.
test_partial_request_hit("2", [
BlockHashWithGroupId(block_hashes[0], 1),
BlockHashWithGroupId(block_hashes[0], 2)
make_block_hash_with_group_id(block_hashes[0], 1),
make_block_hash_with_group_id(block_hashes[0], 2)
], 3)
# Evict the first block of full attention, makes total cache miss.
test_partial_request_hit("3", [
BlockHashWithGroupId(block_hashes[0], 0),
], 0)
test_partial_request_hit(
"3", [make_block_hash_with_group_id(block_hashes[0], 0)], 0)
# Evict the last block of all layers, reduces the hit length to 2.
test_partial_request_hit("4", [
BlockHashWithGroupId(block_hashes[2], 0),
BlockHashWithGroupId(block_hashes[2], 1),
BlockHashWithGroupId(block_hashes[2], 2),
make_block_hash_with_group_id(block_hashes[2], 0),
make_block_hash_with_group_id(block_hashes[2], 1),
make_block_hash_with_group_id(block_hashes[2], 2),
], 2)
# Evict the last block of full attention, reduces the hit length to 2.
test_partial_request_hit("5", [BlockHashWithGroupId(block_hashes[2], 0)],
2)
test_partial_request_hit(
"5", [make_block_hash_with_group_id(block_hashes[2], 0)], 2)
# Evict the last block of sliding window, reduces the hit length to 2.
test_partial_request_hit("6", [BlockHashWithGroupId(block_hashes[2], 1)],
2)
test_partial_request_hit(
"6", [make_block_hash_with_group_id(block_hashes[2], 1)], 2)
# Evict the last block of sliding window, reduces the hit length to 2.
test_partial_request_hit("7", [BlockHashWithGroupId(block_hashes[2], 2)],
2)
test_partial_request_hit(
"7", [make_block_hash_with_group_id(block_hashes[2], 2)], 2)
# Evict different set of blocks for full attention and sliding window makes
# total cache miss.
......@@ -353,9 +355,9 @@ def test_prefill_hybrid_model():
# The cache hit length of sliding window is 2 * block_size.
# Then it is cache miss as the two type of layers have different hit length.
test_partial_request_hit("8", [
BlockHashWithGroupId(block_hashes[2], 0),
BlockHashWithGroupId(block_hashes[0], 1),
BlockHashWithGroupId(block_hashes[0], 2),
make_block_hash_with_group_id(block_hashes[2], 0),
make_block_hash_with_group_id(block_hashes[0], 1),
make_block_hash_with_group_id(block_hashes[0], 2),
], 0)
......@@ -372,8 +374,8 @@ def test_prefill_plp():
max_model_len=8192,
enable_caching=True,
)
# the default hash function is hash
hash_fn = hash
# the default hash function is sha256
hash_fn = sha256
# Complete 3 blocks (48 tokens)
common_token_ids = [i for i in range(3) for _ in range(16)]
......@@ -404,10 +406,12 @@ def test_prefill_plp():
block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
block_tokens)
assert manager.block_pool.blocks[
block_id].block_hash.block_hash == block_hash
blk_hash = (manager.block_pool.blocks[block_id].block_hash)
assert blk_hash is not None
assert get_block_hash(blk_hash) == block_hash
assert get_group_id(blk_hash) == 0
assert manager.block_pool.blocks[block_id].ref_cnt == 1
parent_block_hash = block_hash.hash_value
parent_block_hash = block_hash
# Check partial block metadata
for block_id in (4, ):
......@@ -493,7 +497,7 @@ def test_decode():
# Incomplete 1 block (7 tokens)
unique_token_ids = [3] * 7
req0 = make_request("0", common_token_ids + unique_token_ids, block_size,
hash)
sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
......@@ -538,7 +542,7 @@ def test_evict():
)
last_token_id = 5 * 16 + 7
req0 = make_request("0", list(range(last_token_id)), block_size, hash)
req0 = make_request("0", list(range(last_token_id)), block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
......@@ -550,7 +554,7 @@ def test_evict():
# 3 blocks.
req1 = make_request("1", list(range(last_token_id,
last_token_id + 3 * 16)), block_size,
hash)
sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
......@@ -572,7 +576,7 @@ def test_evict():
] == [10, 6, 5, 4, 3, 2, 1, 9, 8, 7]
# Touch the first 2 blocks.
req2 = make_request("2", list(range(2 * 16 + 3)), block_size, hash)
req2 = make_request("2", list(range(2 * 16 + 3)), block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert computed_blocks.get_block_ids() == ([1, 2], )
assert num_computed_tokens == 2 * 16
......@@ -597,7 +601,7 @@ def test_hash_block_correct_reuse():
# Allocate 1 block and cache it.
num_tokens = block_size * 1
req = make_request("0", list(range(num_tokens)), block_size, hash)
req = make_request("0", list(range(num_tokens)), block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
......@@ -611,7 +615,7 @@ def test_hash_block_correct_reuse():
# Allocate a new block that's not full, make sure hash info on the
# block is cleared.
req = make_request("1", list(range(num_tokens - 1)), block_size, hash)
req = make_request("1", list(range(num_tokens - 1)), block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
......@@ -638,7 +642,7 @@ def test_computed_blocks_not_evicted():
# Allocate a block and cache it.
num_tokens = block_size * 1
req0 = make_request("0", list(range(num_tokens)), block_size, hash)
req0 = make_request("0", list(range(num_tokens)), block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
......@@ -650,7 +654,7 @@ def test_computed_blocks_not_evicted():
# Allocate another block.
req1 = make_request("1", list(range(num_tokens, num_tokens * 2)),
block_size, hash)
block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
......@@ -666,7 +670,7 @@ def test_computed_blocks_not_evicted():
# Now if we have a cache hit on the first block, we should evict the second
# cached block rather than the first one.
req2 = make_request("2", list(range(num_tokens * 2)), block_size, hash)
req2 = make_request("2", list(range(num_tokens * 2)), block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(computed_blocks.blocks[0]) == 1
assert computed_blocks.blocks[0][0].block_id == 1
......@@ -691,7 +695,7 @@ def test_basic_prefix_caching_disabled():
)
req1 = make_request("1", list(range(10)), block_size,
hash) # 2 blocks and some more
sha256) # 2 blocks and some more
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks.blocks[0]
......@@ -706,7 +710,7 @@ def test_basic_prefix_caching_disabled():
# No caching.
req2 = make_request("2", list(range(16)), block_size,
hash) # shared prefix
sha256) # shared prefix
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
......@@ -716,7 +720,7 @@ def test_basic_prefix_caching_disabled():
assert len(blocks.blocks[0]) == 4
# New requests should not have any blocks.
req3 = make_request("3", list(range(4)), block_size, hash)
req3 = make_request("3", list(range(4)), block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
......@@ -726,7 +730,7 @@ def test_basic_prefix_caching_disabled():
assert not blocks
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_cache_blocks(hash_fn):
"""
This is a unit test that tests the correctness of the _cache_full_blocks
......@@ -787,7 +791,7 @@ def test_cache_blocks_multi_group():
# Block 1/5: [4, 5, 6, 7]
# Block 2/6: [8, 9, 10, 11]
# Block 3/7: [12, 13]
req = make_request("0", list(range(14)), block_size, hash)
req = make_request("0", list(range(14)), block_size, sha256)
# Cache the blocks for group 0.
blocks = [KVCacheBlock(block_id=i) for i in range(2)]
......@@ -845,6 +849,8 @@ def test_mm_prefix_caching():
"""
This tests that the multi-modal prefix caching is correct.
"""
kv_cache_utils.init_none_hash(sha256)
block_size = 16
manager = KVCacheManager(
make_kv_cache_config(block_size, 11),
......@@ -874,23 +880,30 @@ def test_mm_prefix_caching():
req0 = make_request("0",
all_token_ids,
block_size,
hash,
sha256,
mm_positions=mm_positions,
mm_hashes=mm_hashes)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
# Completed block should have hashes with extra keys.
# Completed block should have hashes
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
block_hashes = req0.block_hashes
assert len(block_hashes) == 3
assert block_hashes[0].extra_keys == ("aaa", )
assert block_hashes[1].extra_keys == ("aaa", "bbb")
assert block_hashes[2].extra_keys == ("bbb", )
assert block_hashes[0] == sha256(
(kv_cache_utils.NONE_HASH, tuple(all_token_ids[:block_size]),
("aaa", )))
assert block_hashes[1] == sha256(
(block_hashes[0], tuple(all_token_ids[block_size:block_size * 2]),
("aaa", "bbb")))
assert block_hashes[2] == sha256(
(block_hashes[1], tuple(all_token_ids[block_size * 2:block_size * 3]),
("bbb", )))
blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks is not None
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
req0.num_computed_tokens = 59
......@@ -901,10 +914,10 @@ def test_mm_prefix_caching():
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert new_blocks is not None and len(new_blocks.blocks[0]) == 0
# The just completed block should have hashes with extra keys.
assert len(block_hashes) == 4
assert block_hashes[3].extra_keys == ("ccc", )
assert block_hashes[3] == sha256(
(block_hashes[2], tuple(all_token_ids[3 * block_size:] + [8] * 5),
("ccc", )))
# Cache hit.
unique_token_ids = [-1] * 7 + [200] * 5
......@@ -916,7 +929,7 @@ def test_mm_prefix_caching():
req1 = make_request("1",
all_token_ids,
block_size,
hash,
sha256,
mm_positions=mm_positions,
mm_hashes=mm_hashes)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
......@@ -929,6 +942,8 @@ def test_cache_key_salting():
This tests that cache salts are applied during hashing and the cache
is separated cache as expected.
"""
kv_cache_utils.init_none_hash(sha256)
block_size = 16
manager = KVCacheManager(
make_kv_cache_config(block_size, 11),
......@@ -939,21 +954,26 @@ def test_cache_key_salting():
# 3 complete blocks and an incomplete block with 11 tokens.
common_token_ids = [i for i in range(3) for _ in range(block_size)]
token_ids = common_token_ids + [3] * 11
req0 = make_request("0", token_ids, block_size, hash, cache_salt="salt1")
req0 = make_request("0", token_ids, block_size, sha256, cache_salt="salt1")
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
# Completed block should have hashes with extra keys.
# Completed block should have hashes
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
block_hashes = req0.block_hashes
assert len(block_hashes) == 3
assert block_hashes[0].extra_keys == ("salt1", )
assert block_hashes[1].extra_keys is None
assert block_hashes[2].extra_keys is None
assert block_hashes[0] == sha256(
(kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt1", )))
assert block_hashes[1] == sha256(
(block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None))
assert block_hashes[2] == sha256(
(block_hashes[1], tuple(token_ids[block_size * 2:block_size * 3]),
None))
blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert blocks is not None
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
req0.num_computed_tokens = 59
......@@ -964,14 +984,13 @@ def test_cache_key_salting():
len(computed_blocks.blocks[0]) * 16,
computed_blocks)
assert new_blocks is not None and len(new_blocks.blocks[0]) == 0
# Now one more block that should not have extra keys.
assert len(block_hashes) == 4
assert block_hashes[3].extra_keys is None
assert block_hashes[3] == sha256(
(block_hashes[2], tuple(token_ids[3 * block_size:] + [8] * 5), None))
# Test cache hit with a new request that has the same salt.
token_ids = common_token_ids + [4] * 11
req1 = make_request("1", token_ids, block_size, hash, cache_salt="salt1")
req1 = make_request("1", token_ids, block_size, sha256, cache_salt="salt1")
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
# Should match only a prefix of 3 blocks.
assert len(computed_blocks.blocks[0]) == 3
......@@ -979,13 +998,19 @@ def test_cache_key_salting():
# Test cache miss with same content but different salt.
token_ids = common_token_ids + [4] * 11
req2 = make_request("2", token_ids, block_size, hash, cache_salt="salt2")
req2 = make_request("2", token_ids, block_size, sha256, cache_salt="salt2")
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(computed_blocks.blocks[0]) == 0
assert num_computed_tokens == 0
block_hashes = req2.block_hashes
assert len(block_hashes) == 3
assert block_hashes[0].extra_keys == ("salt2", )
assert block_hashes[0] == sha256(
(kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt2", )))
assert block_hashes[1] == sha256(
(block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None))
assert block_hashes[2] == sha256(
(block_hashes[1], tuple(token_ids[block_size * 2:block_size * 3]),
None))
def test_prefill_not_enough_free_blocks_with_computed_blocks():
......@@ -1004,7 +1029,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# Complete 3 blocks (48 tokens)
# | Common-0 | Common-1 | Common-2 | ... |
common_token_ids = [i for i in range(3) for _ in range(16)]
req0 = make_request("0", common_token_ids, block_size, hash)
req0 = make_request("0", common_token_ids, block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
......@@ -1015,7 +1040,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
req0.request_id]
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
req1 = make_request("1", common_token_ids * 2, block_size, hash)
req1 = make_request("1", common_token_ids * 2, block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert computed_blocks.blocks[0] == block_part0
assert num_computed_tokens == 3 * 16
......@@ -1032,7 +1057,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
# | Req1-5(F)| Req2-0 | Req2-1 | ... |
req2 = make_request("2", [7] * block_size * 2, block_size, hash)
req2 = make_request("2", [7] * block_size * 2, block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
......@@ -1044,7 +1069,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# but it cannot be allocated due to insufficient free blocks (2).
# In this case, the ref_cnt of the computed blocks should not be changed.
assert manager.block_pool.free_block_queue.num_free_blocks == 5
req3 = make_request("3", common_token_ids * 3, block_size, hash)
req3 = make_request("3", common_token_ids * 3, block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert computed_blocks.blocks[0] == block_part1
assert num_computed_tokens == 6 * 16
......@@ -1069,13 +1094,13 @@ def test_reset_prefix_cache():
full_block_token_ids = [i for i in range(3) for _ in range(16)]
unique_token_ids = [3] * 7
all_token_ids = full_block_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids, block_size, hash)
req0 = make_request("0", all_token_ids, block_size, sha256)
blocks = manager.allocate_slots(req0, 55)
assert blocks.get_block_ids() == ([1, 2, 3, 4], )
unique_token_ids = [4] * 7
all_token_ids = full_block_token_ids + unique_token_ids
req1 = make_request("1", all_token_ids, block_size, hash)
req1 = make_request("1", all_token_ids, block_size, sha256)
computed_blocks, _ = manager.get_computed_blocks(req1)
assert len(req1.block_hashes) == 3
assert len(computed_blocks.blocks[0]) == 3
......@@ -1109,7 +1134,7 @@ def test_prefix_cache_stats_disabled():
assert manager.prefix_cache_stats is None
# Call all functions that check whether log_stats is disabled.
req = make_request("0", list(range(16)), block_size, hash)
req = make_request("0", list(range(16)), block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0
......@@ -1124,15 +1149,9 @@ def test_prefix_cache_stats_disabled():
def test_maybe_evict_cached_block():
pool = BlockPool(num_gpu_blocks=4, enable_caching=True)
block_hash0 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=10,
token_ids=(100, )),
group_id=1000)
block_hash1 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=20,
token_ids=(200, )),
group_id=2000)
block_hash2 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=30,
token_ids=(300, )),
group_id=3000)
block_hash0 = make_block_hash_with_group_id(BlockHash(b"10"), 1000)
block_hash1 = make_block_hash_with_group_id(BlockHash(b"20"), 2000)
block_hash2 = make_block_hash_with_group_id(BlockHash(b"30"), 3000)
block_hashes = [
block_hash0,
block_hash1,
......@@ -1206,7 +1225,7 @@ def test_kv_cache_events(blocks_to_cache: int):
)
num_tokens = block_size * blocks_to_cache
req0 = make_request("0", list(range(num_tokens)), block_size, hash)
req0 = make_request("0", list(range(num_tokens)), block_size, sha256)
_ = manager.allocate_slots(req0, num_tokens)
events = manager.take_events()
......@@ -1222,7 +1241,7 @@ def test_kv_cache_events(blocks_to_cache: int):
# Should see block_to_cache number of removed block events and a new block
# stored event
manager.free(req0)
req1 = make_request("1", list(range(num_tokens)), block_size, hash)
req1 = make_request("1", list(range(num_tokens)), block_size, sha256)
_ = manager.allocate_slots(req1, num_tokens)
events = manager.take_events()
......@@ -1256,7 +1275,7 @@ def test_eagle_enabled_removes_last_block():
# Request with 3 full blocks (48 tokens)
token_ids = [0] * (3 * block_size)
req = make_request("divisible_request", token_ids, block_size, hash)
req = make_request("divisible_request", token_ids, block_size, sha256)
# Prime the cache
computed_blocks, _ = manager.get_computed_blocks(req)
......@@ -1266,7 +1285,7 @@ def test_eagle_enabled_removes_last_block():
manager.free(req)
# New request with same tokens + Eagle enabled
req_eagle = make_request("eagle_divisible", token_ids, block_size, hash)
req_eagle = make_request("eagle_divisible", token_ids, block_size, sha256)
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
# Should retain 1 block:
......@@ -1287,7 +1306,7 @@ def test_eagle_with_partial_blocks():
)
# 2 full blocks + 5 tokens (non-divisible length)
token_ids = [0] * (2 * block_size + 5)
req = make_request("partial_block_test", token_ids, block_size, hash)
req = make_request("partial_block_test", token_ids, block_size, sha256)
# Prime the cache
computed_blocks, _ = manager.get_computed_blocks(req)
......@@ -1297,7 +1316,7 @@ def test_eagle_with_partial_blocks():
manager.free(req)
# New request with Eagle enabled
req_eagle = make_request("partial_eagle", token_ids, block_size, hash)
req_eagle = make_request("partial_eagle", token_ids, block_size, sha256)
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
assert len(computed_blocks.blocks[0]) == 1
......@@ -1328,7 +1347,7 @@ def test_eagle_with_sliding_window():
# 2 full blocks + 5 tokens (non-divisible length)
token_ids = [0] * (2 * block_size + 5)
req = make_request("partial_block_test", token_ids, block_size, hash)
req = make_request("partial_block_test", token_ids, block_size, sha256)
# Prime the cache
computed_blocks, _ = manager.get_computed_blocks(req)
......@@ -1341,7 +1360,7 @@ def test_eagle_with_sliding_window():
manager.free(req)
# New request with Eagle enabled
req_eagle = make_request("partial_eagle", token_ids, block_size, hash)
req_eagle = make_request("partial_eagle", token_ids, block_size, sha256)
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
assert len(computed_blocks.blocks[0]) == 1
......@@ -1351,11 +1370,11 @@ def test_eagle_with_sliding_window():
assert manager.block_pool.get_cached_block(
block_hash_first_block, kv_cache_group_ids=[0]) is not None
manager.block_pool.cached_block_hash_to_block.pop(
BlockHashWithGroupId(block_hash_first_block, 0))
make_block_hash_with_group_id(block_hash_first_block, 0))
# New request
req_after_evict = make_request("partial_eagle_after_evict", token_ids,
block_size, hash)
block_size, sha256)
computed_blocks, num_tokens = manager.get_computed_blocks(req_after_evict)
# Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is
# not considered. But after dropping the last matched block due to eagle,
......
......@@ -6,8 +6,8 @@ import random
import torch
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
KVCacheBlock)
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
make_block_hash_with_group_id)
from vllm.v1.core.single_type_kv_cache_manager import (
ChunkedLocalAttentionManager, SlidingWindowManager)
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
......@@ -44,7 +44,7 @@ def test_chunked_local_attention_possible_cached_prefix():
def run_one_case(block_is_cached, tail_token, expect_length):
block_hash_list = [
BlockHash(i, ()) for i in range(len(block_is_cached))
BlockHash(str(i).encode()) for i in range(len(block_is_cached))
]
block_pool.cached_block_hash_to_block.clear()
......@@ -53,8 +53,8 @@ def test_chunked_local_attention_possible_cached_prefix():
for i, (block_hash,
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
if is_cached:
block_pool.cached_block_hash_to_block[BlockHashWithGroupId(
block_hash, 0)] = {
block_pool.cached_block_hash_to_block[
make_block_hash_with_group_id(block_hash, 0)] = {
i: block_pool.blocks[i + 10],
}
......@@ -109,7 +109,7 @@ def test_sliding_window_possible_cached_prefix():
def run_one_case(block_is_cached, expect_length):
block_hash_list = [
BlockHash(i, ()) for i in range(len(block_is_cached))
BlockHash(str(i).encode()) for i in range(len(block_is_cached))
]
block_pool.cached_block_hash_to_block.clear()
......@@ -118,8 +118,8 @@ def test_sliding_window_possible_cached_prefix():
for i, (block_hash,
is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
if is_cached:
block_pool.cached_block_hash_to_block[BlockHashWithGroupId(
block_hash, 0)] = {
block_pool.cached_block_hash_to_block[
make_block_hash_with_group_id(block_hash, 0)] = {
i: block_pool.blocks[i + 10],
}
......
......@@ -9,6 +9,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import SamplingParams
from vllm.utils import sha256
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
init_none_hash)
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
......@@ -130,10 +131,10 @@ def create_requests(
) -> list[Request]:
global _none_hash_initialized
if not _none_hash_initialized:
init_none_hash(hash)
init_none_hash(sha256)
_none_hash_initialized = True
block_hasher = get_request_block_hasher(block_size, hash)
block_hasher = get_request_block_hasher(block_size, sha256)
sampling_params = SamplingParams(ignore_eos=False,
max_tokens=max_tokens,
stop_token_ids=stop_token_ids,
......
......@@ -62,6 +62,16 @@ backend_configs = {
"cudagraph_mode": "FULL_AND_PIECEWISE",
},
specific_gpu_arch=(9, 0)),
# FlashAttention MLA on Hopper
"FlashAttentionMLA":
BackendConfig(name="FlashAttentionMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
},
comp_config={
"cudagraph_mode": "FULL_DECODE_ONLY",
},
specific_gpu_arch=(9, 0)),
# FA2
"FA2":
BackendConfig(name="FA2",
......
......@@ -83,7 +83,7 @@ def test_ngram_correctness(
model_name: str,
):
'''
Compare the outputs of a original LLM and a speculative LLM
Compare the outputs of an original LLM and a speculative LLM
should be the same when using ngram speculative decoding.
'''
with monkeypatch.context() as m:
......@@ -117,45 +117,38 @@ def test_ngram_correctness(
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")
# Heuristic: expect at least 70% of the prompts to match exactly
# Heuristic: expect at least 68% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.7 * len(ref_outputs))
assert matches >= int(0.68 * len(ref_outputs))
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
@pytest.mark.parametrize(
["model_setup", "mm_enabled"],
[
# TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
# (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
(("eagle", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
pytest.param(
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
False,
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
pytest.param(
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
True,
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
(("eagle", "eagle618/deepseek-v3-random",
"eagle618/eagle-deepseek-v3-random", 1), False),
],
ids=[
# TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
# "qwen3_eagle3",
"llama3_eagle",
"llama3_eagle3",
"llama4_eagle",
"llama4_eagle_mm",
"deepseek_eagle"
])
@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
(("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
(("eagle", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
pytest.param(
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
False,
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
pytest.param(
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
True,
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
(("eagle", "eagle618/deepseek-v3-random",
"eagle618/eagle-deepseek-v3-random", 1), False),
],
ids=[
"qwen3_eagle3", "llama3_eagle", "llama3_eagle3",
"llama4_eagle", "llama4_eagle_mm",
"deepseek_eagle"
])
@pytest.mark.parametrize("attn_backend",
get_attn_backend_list_based_on_platform())
def test_eagle_correctness(
......@@ -169,7 +162,7 @@ def test_eagle_correctness(
# TODO: Fix this flaky test
pytest.skip(
"TREE_ATTN is flaky in the test disable for now until it can be "
"reolved (see https://github.com/vllm-project/vllm/issues/22922)")
"resolved (see https://github.com/vllm-project/vllm/issues/22922)")
# Generate test prompts inside the function instead of using fixture
test_prompts = get_test_prompts(mm_enabled)
......
......@@ -393,7 +393,7 @@ class MockLoggingStatLogger(LoggingStatLogger):
async def test_customize_loggers(monkeypatch):
"""Test that we can customize the loggers.
If a customized logger is provided at the init, it should
be used directly.
be added to the default loggers.
"""
with monkeypatch.context() as m, ExitStack() as after:
......@@ -410,7 +410,8 @@ async def test_customize_loggers(monkeypatch):
stat_loggers = engine.logger_manager.per_engine_logger_dict
assert len(stat_loggers) == 1
assert len(stat_loggers[0]) == 1
assert len(
stat_loggers[0]) == 2 # LoggingStatLogger + MockLoggingStatLogger
stat_loggers[0][0].log.assert_called_once()
......
......@@ -36,18 +36,19 @@ def test_prefix_caching_from_cli():
assert vllm_config.cache_config.enable_prefix_caching
# default hash algorithm is "builtin"
assert vllm_config.cache_config.prefix_caching_hash_algo == "builtin"
assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256"
# set hash algorithm to sha256_cbor
args = parser.parse_args(["--prefix-caching-hash-algo", "sha256_cbor"])
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
assert vllm_config.cache_config.prefix_caching_hash_algo == \
"sha256_cbor"
# set hash algorithm to sha256
args = parser.parse_args(["--prefix-caching-hash-algo", "sha256"])
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256"
# set hash algorithm to builtin
args = parser.parse_args(["--prefix-caching-hash-algo", "builtin"])
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
assert vllm_config.cache_config.prefix_caching_hash_algo == "builtin"
# an invalid hash algorithm raises an error
parser.exit_on_error = False
with pytest.raises(ArgumentError):
......
......@@ -152,8 +152,8 @@ def test_multi_modal_uuids_accepts_none_and_passes_through(
*,
tokenization_kwargs=None,
lora_request=None,
mm_hash_overrides=None):
captured["mm_hash_overrides"] = mm_hash_overrides
mm_uuids=None):
captured["mm_uuids"] = mm_uuids
# Minimal processed inputs for decoder-only flow
return {"type": "token", "prompt_token_ids": [1]}
......@@ -180,7 +180,7 @@ def test_multi_modal_uuids_accepts_none_and_passes_through(
params=SamplingParams(),
)
assert captured["mm_hash_overrides"] == mm_uuids
assert captured["mm_uuids"] == mm_uuids
def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
......@@ -196,8 +196,8 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
*,
tokenization_kwargs=None,
lora_request=None,
mm_hash_overrides=None):
captured["mm_hash_overrides"] = mm_hash_overrides
mm_uuids=None):
captured["mm_uuids"] = mm_uuids
return {"type": "token", "prompt_token_ids": [1]}
monkeypatch.setattr(processor.input_preprocessor,
......@@ -223,7 +223,7 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
)
# Expect request-id-based overrides are passed through
assert captured["mm_hash_overrides"] == {
assert captured["mm_uuids"] == {
"image": [f"{request_id}-image-0", f"{request_id}-image-1"],
"video": [f"{request_id}-video-0"],
}
......@@ -46,12 +46,12 @@ PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None),
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None),
("Qwen/Qwen2.5-1.5B-Instruct", "lm-format-enforcer", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None),
#FIXME: This tests are flaky on CI thus disabled. Tracking in Issue #24402
# ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None),
# ("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None),
#("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"),
("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto",
NGRAM_SPEC_CONFIG),
#FIXME: This test is flaky on CI thus disabled
#("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"),
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto",
NGRAM_SPEC_CONFIG),
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", NGRAM_SPEC_CONFIG),
......@@ -122,6 +122,7 @@ def test_structured_output(
guided_decoding_backend=guided_decoding_backend,
guided_decoding_disable_any_whitespace=(guided_decoding_backend
in {"xgrammar", "guidance"}),
seed=120,
tokenizer_mode=tokenizer_mode,
speculative_config=speculative_config)
......
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import openai # use the official client for correctness check
import openai.types.responses as openai_responses_types
import pytest
......@@ -86,3 +87,18 @@ async def test_logprobs(client: openai.AsyncOpenAI):
outputs = response.output
assert outputs[-1].content[-1].logprobs
assert len(outputs[-1].content[-1].logprobs[0].top_logprobs) == 5
@pytest.mark.asyncio
async def test_streaming(client: openai.AsyncOpenAI):
stream = await client.responses.create(
input="What is 13 * 24?",
stream=True,
)
events = [event async for event in stream]
assert isinstance(events[0], openai_responses_types.ResponseCreatedEvent)
assert any(
isinstance(event, openai_responses_types.ResponseTextDeltaEvent)
for event in events)
assert isinstance(events[-1],
openai_responses_types.ResponseCompletedEvent)
......@@ -8,17 +8,17 @@ import pytest
import pytest_asyncio
from tests.utils import RemoteOpenAIServer
from vllm.multimodal.utils import encode_image_base64, fetch_image
from vllm.multimodal.utils import encode_image_base64
# Use a small vision model for testing
MODEL_NAME = "Qwen/Qwen2.5-VL-3B-Instruct"
MAXIMUM_IMAGES = 2
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS = [
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
"https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png",
"https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png",
"https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
TEST_IMAGE_ASSETS = [
"2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", # "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
"Grayscale_8bits_palette_sample_image.png", # "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png",
"1280px-Venn_diagram_rgb.svg.png", # "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png",
"RGBA_comp.png", # "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
]
......@@ -52,16 +52,17 @@ async def client(image_server):
@pytest.fixture(scope="session")
def base64_encoded_image() -> dict[str, str]:
def base64_encoded_image(local_asset_server) -> dict[str, str]:
return {
image_url: encode_image_base64(fetch_image(image_url))
for image_url in TEST_IMAGE_URLS
image_url:
encode_image_base64(local_asset_server.get_image_asset(image_url))
for image_url in TEST_IMAGE_ASSETS
}
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
@pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True)
async def test_single_chat_session_image(client: openai.AsyncOpenAI,
model_name: str, image_url: str):
content_text = "What's in this image?"
......@@ -91,11 +92,11 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
@pytest.mark.parametrize("raw_image_url", TEST_IMAGE_ASSETS)
async def test_single_chat_session_image_base64encoded(
client: openai.AsyncOpenAI,
model_name: str,
image_url: str,
raw_image_url: str,
base64_encoded_image: dict[str, str],
):
content_text = "What's in this image?"
......@@ -106,7 +107,7 @@ async def test_single_chat_session_image_base64encoded(
{
"type": "input_image",
"image_url":
f"data:image/jpeg;base64,{base64_encoded_image[image_url]}",
f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}",
"detail": "auto",
},
{
......@@ -127,7 +128,8 @@ async def test_single_chat_session_image_base64encoded(
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize(
"image_urls",
[TEST_IMAGE_URLS[:i] for i in range(2, len(TEST_IMAGE_URLS))])
[TEST_IMAGE_ASSETS[:i] for i in range(2, len(TEST_IMAGE_ASSETS))],
indirect=True)
async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str,
image_urls: list[str]):
messages = [{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment