Commit 38d80967 authored by zhuwenwen
Browse files

Merge tag 'v0.10.2rc2' into v0.10.2rc2-ori

parents 33650733 880c741b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from pathlib import Path
from typing import Optional, Union
import pytest
from transformers import PretrainedConfig
from vllm.transformers_utils.config import (get_config_parser,
register_config_parser)
from vllm.transformers_utils.config_parser_base import ConfigParserBase
@register_config_parser("custom_config_parser")
class CustomConfigParser(ConfigParserBase):
    """Stub ``ConfigParserBase`` subclass used to exercise registration.

    The class is registered under the name ``"custom_config_parser"``;
    ``test_register_config_parser`` in this file looks it up by that name.
    """

    def parse(self,
              model: Union[str, Path],
              trust_remote_code: bool,
              revision: Optional[str] = None,
              code_revision: Optional[str] = None,
              **kwargs) -> tuple[dict, PretrainedConfig]:
        """Unimplemented on purpose.

        Only the signature is provided; calling this method always raises.

        Raises:
            NotImplementedError: unconditionally.
        """
        raise NotImplementedError
def test_register_config_parser():
    """Looking up ``"custom_config_parser"`` yields a ``CustomConfigParser``.

    ``get_config_parser`` is expected to return an instance (not the class),
    hence the ``isinstance`` check.
    """
    assert isinstance(get_config_parser("custom_config_parser"),
                      CustomConfigParser)
def test_invalid_config_parser():
    """Registering a class with no ``ConfigParserBase`` base must fail.

    The decorator is applied inside the ``pytest.raises`` block so that the
    ``ValueError`` is raised (and captured) at class-definition time.
    """
    with pytest.raises(ValueError):

        # Plain class, deliberately not derived from ConfigParserBase.
        @register_config_parser("invalid_config_parser")
        class InvalidConfigParser:
            pass
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio import asyncio
import contextlib
import copy import copy
import functools import functools
import importlib import importlib
...@@ -13,10 +14,11 @@ import sys ...@@ -13,10 +14,11 @@ import sys
import tempfile import tempfile
import time import time
import warnings import warnings
from contextlib import contextmanager, suppress from contextlib import ExitStack, contextmanager, suppress
from multiprocessing import Process from multiprocessing import Process
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Literal, Optional, Union from typing import Any, Callable, Literal, Optional, Union
from unittest.mock import patch
import cloudpickle import cloudpickle
import httpx import httpx
...@@ -799,43 +801,106 @@ _P = ParamSpec("_P") ...@@ -799,43 +801,106 @@ _P = ParamSpec("_P")
def fork_new_process_for_each_test( def fork_new_process_for_each_test(
f: Callable[_P, None]) -> Callable[_P, None]: func: Callable[_P, None]) -> Callable[_P, None]:
"""Decorator to fork a new process for each test function. """Decorator to fork a new process for each test function.
See https://github.com/vllm-project/vllm/issues/7053 for more details. See https://github.com/vllm-project/vllm/issues/7053 for more details.
""" """
@functools.wraps(f) @functools.wraps(func)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
# Make the process the leader of its own process group # Make the process the leader of its own process group
# to avoid sending SIGTERM to the parent process # to avoid sending SIGTERM to the parent process
os.setpgrp() os.setpgrp()
from _pytest.outcomes import Skipped from _pytest.outcomes import Skipped
pid = os.fork()
print(f"Fork a new process to run a test {pid}") # Create a unique temporary file to store exception info from child
if pid == 0: # process. Use test function name and process ID to avoid collisions.
try: with tempfile.NamedTemporaryFile(
f(*args, **kwargs) delete=False,
except Skipped as e: mode='w+b',
# convert Skipped to exit code 0 prefix=f"vllm_test_{func.__name__}_{os.getpid()}_",
print(str(e)) suffix=".exc") as exc_file, ExitStack() as delete_after:
os._exit(0) exc_file_path = exc_file.name
except Exception: delete_after.callback(os.remove, exc_file_path)
import traceback
traceback.print_exc() pid = os.fork()
os._exit(1) print(f"Fork a new process to run a test {pid}")
if pid == 0:
# Parent process responsible for deleting, don't delete
# in child.
delete_after.pop_all()
try:
func(*args, **kwargs)
except Skipped as e:
# convert Skipped to exit code 0
print(str(e))
os._exit(0)
except Exception as e:
import traceback
tb_string = traceback.format_exc()
# Try to serialize the exception object first
exc_to_serialize: dict[str, Any]
try:
# First, try to pickle the actual exception with
# its traceback.
exc_to_serialize = {'pickled_exception': e}
# Test if it can be pickled
cloudpickle.dumps(exc_to_serialize)
except (Exception, KeyboardInterrupt):
# Fall back to string-based approach.
exc_to_serialize = {
'exception_type': type(e).__name__,
'exception_msg': str(e),
'traceback': tb_string,
}
try:
with open(exc_file_path, 'wb') as f:
cloudpickle.dump(exc_to_serialize, f)
except Exception:
# Fallback: just print the traceback.
print(tb_string)
os._exit(1)
else:
os._exit(0)
else: else:
os._exit(0) pgid = os.getpgid(pid)
else: _pid, _exitcode = os.waitpid(pid, 0)
pgid = os.getpgid(pid) # ignore SIGTERM signal itself
_pid, _exitcode = os.waitpid(pid, 0) old_signal_handler = signal.signal(signal.SIGTERM,
# ignore SIGTERM signal itself signal.SIG_IGN)
old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN) # kill all child processes
# kill all child processes os.killpg(pgid, signal.SIGTERM)
os.killpg(pgid, signal.SIGTERM) # restore the signal handler
# restore the signal handler signal.signal(signal.SIGTERM, old_signal_handler)
signal.signal(signal.SIGTERM, old_signal_handler) if _exitcode != 0:
assert _exitcode == 0, (f"function {f} failed when called with" # Try to read the exception from the child process
f" args {args} and kwargs {kwargs}") exc_info = {}
if os.path.exists(exc_file_path):
with contextlib.suppress(Exception), \
open(exc_file_path, 'rb') as f:
exc_info = cloudpickle.load(f)
if (original_exception :=
exc_info.get('pickled_exception')) is not None:
# Re-raise the actual exception object if it was
# successfully pickled.
assert isinstance(original_exception, Exception)
raise original_exception
if (original_tb := exc_info.get("traceback")) is not None:
# Use string-based traceback for fallback case
raise AssertionError(
f"Test {func.__name__} failed when called with"
f" args {args} and kwargs {kwargs}"
f" (exit code: {_exitcode}):\n{original_tb}"
) from None
# Fallback to the original generic error
raise AssertionError(
f"function {func.__name__} failed when called with"
f" args {args} and kwargs {kwargs}"
f" (exit code: {_exitcode})") from None
return wrapper return wrapper
...@@ -1077,3 +1142,11 @@ def get_attn_backend_list_based_on_platform() -> list[str]: ...@@ -1077,3 +1142,11 @@ def get_attn_backend_list_based_on_platform() -> list[str]:
return attn_backend_list return attn_backend_list
else: else:
raise ValueError("Unsupported platform") raise ValueError("Unsupported platform")
@contextmanager
def override_cutlass_fp8_supported(value: bool):
    """Context manager that forces ``cutlass_fp8_supported`` to ``value``.

    While the ``with`` block is active, the ``cutlass_fp8_supported``
    attribute of ``vllm.model_executor.layers.quantization.utils.w8a8_utils``
    is replaced by a mock whose call result is ``value``. The patch is
    undone when the block exits.

    Args:
        value: The result the patched function should return.
    """
    with patch(
            "vllm.model_executor.layers.quantization.utils.w8a8_utils.cutlass_fp8_supported",
            return_value=value):
        yield
...@@ -835,22 +835,20 @@ def test_model_specification(parser_with_config, cli_config_file, ...@@ -835,22 +835,20 @@ def test_model_specification(parser_with_config, cli_config_file,
@pytest.mark.parametrize("input", [(), ("abc", ), (None, ), @pytest.mark.parametrize("input", [(), ("abc", ), (None, ),
(None, bool, [1, 2, 3])]) (None, bool, [1, 2, 3])])
@pytest.mark.parametrize("output", [0, 1, 2]) def test_sha256(input: tuple):
def test_sha256(input: tuple, output: int): digest = sha256(input)
hash = sha256(input) assert digest is not None
assert hash is not None assert isinstance(digest, bytes)
assert isinstance(hash, int) assert digest != b""
assert hash != 0
bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
assert hash == int.from_bytes(hashlib.sha256(bytes).digest(), assert digest == hashlib.sha256(input_bytes).digest()
byteorder="big")
# hashing again, returns the same value # hashing again, returns the same value
assert hash == sha256(input) assert digest == sha256(input)
# hashing different input, returns different value # hashing different input, returns different value
assert hash != sha256(input + (1, )) assert digest != sha256(input + (1, ))
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
...@@ -70,22 +70,6 @@ BATCH_SPECS = { ...@@ -70,22 +70,6 @@ BATCH_SPECS = {
} }
def create_dummy_kv_cache(kv_cache_spec: FullAttentionSpec,
                          device: torch.device,
                          num_blocks: int = 100) -> torch.Tensor:
    """Create a dummy KV cache tensor for testing.

    Args:
        kv_cache_spec: Supplies ``block_size``, ``num_kv_heads``,
            ``head_size`` and ``dtype`` for the cache.
        device: Device on which the tensor is allocated.
        num_blocks: Number of cache blocks to allocate.

    Returns:
        A ``torch.randn`` tensor of shape
        ``(2, num_blocks, block_size, num_kv_heads, head_size)``; the leading
        dimension of 2 holds K and V. The dtype is whatever
        ``_convert_dtype_to_torch`` returns for ``kv_cache_spec.dtype``.
    """
    return torch.randn(
        2,  # K and V
        num_blocks,
        kv_cache_spec.block_size,
        kv_cache_spec.num_kv_heads,
        kv_cache_spec.head_size,
        dtype=_convert_dtype_to_torch(kv_cache_spec.dtype),
        device=device,
    )
def create_and_prepopulate_kv_cache( def create_and_prepopulate_kv_cache(
k_contexts: list[torch.Tensor], k_contexts: list[torch.Tensor],
v_contexts: list[torch.Tensor], v_contexts: list[torch.Tensor],
......
...@@ -160,7 +160,7 @@ def test_local_attention_virtual_batches(test_data: LocalAttentionTestData): ...@@ -160,7 +160,7 @@ def test_local_attention_virtual_batches(test_data: LocalAttentionTestData):
# Use torch.arange instead of torch.randint so we can assert on # Use torch.arange instead of torch.randint so we can assert on
# block table tensor values. The block table will have shape # block table tensor values. The block table will have shape
# (num_batches, cdiv(max_seq_len, block_size)) and the values will be # (num_batches, cdiv(max_seq_len, block_size)) and the values will be
# aranged from 0 to cdiv(max_seq_len, block_size)-1 # arranged from 0 to cdiv(max_seq_len, block_size)-1
arange_block_indices=True, arange_block_indices=True,
) )
......
...@@ -15,7 +15,7 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata ...@@ -15,7 +15,7 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec from vllm.v1.kv_cache_interface import FullAttentionSpec
BACKENDS_TO_TEST = [ BACKENDS_TO_TEST = [
_Backend.CUTLASS_MLA, _Backend.FLASHMLA_VLLM_V1, _Backend.CUTLASS_MLA, _Backend.FLASHMLA_VLLM_V1, _Backend.FLASH_ATTN_MLA,
_Backend.TRITON_MLA_VLLM_V1 _Backend.TRITON_MLA_VLLM_V1
] ]
...@@ -69,25 +69,10 @@ BATCH_SPECS = { ...@@ -69,25 +69,10 @@ BATCH_SPECS = {
} }
def create_dummy_kv_cache(kv_cache_spec: FullAttentionSpec,
                          device: torch.device,
                          num_blocks: int = 100) -> torch.Tensor:
    """Create a dummy KV cache tensor for testing.

    Args:
        kv_cache_spec: Supplies ``block_size``, ``head_size`` and ``dtype``
            for the cache.
        device: Device on which the tensor is allocated.
        num_blocks: Number of cache blocks to allocate.

    Returns:
        A ``torch.randn`` tensor of shape
        ``(num_blocks, block_size, head_size)``, where ``head_size`` is used
        as the latent dimension. The dtype is whatever
        ``_convert_dtype_to_torch`` returns for ``kv_cache_spec.dtype``.
    """
    return torch.randn(
        num_blocks,
        kv_cache_spec.block_size,
        kv_cache_spec.head_size,  # latent dimension
        dtype=_convert_dtype_to_torch(kv_cache_spec.dtype),
        device=device,
    )
def create_and_prepopulate_kv_cache( def create_and_prepopulate_kv_cache(
kv_c_contexts: list[torch.Tensor], kv_c_contexts: list[torch.Tensor],
k_pe_contexts: list[torch.Tensor], k_pe_contexts: list[torch.Tensor],
block_size: int, block_size: int,
num_kv_heads: int,
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
...@@ -101,7 +86,6 @@ def create_and_prepopulate_kv_cache( ...@@ -101,7 +86,6 @@ def create_and_prepopulate_kv_cache(
k_pe_contexts: List of key positional embedding context tensors k_pe_contexts: List of key positional embedding context tensors
for each sequence for each sequence
block_size: Size of each block block_size: Size of each block
num_kv_heads: Number of KV heads (should be 1 for MLA)
head_size: Size of each head (latent dimension) head_size: Size of each head (latent dimension)
dtype: Data type for the cache dtype: Data type for the cache
device: Device to create the cache on device: Device to create the cache on
...@@ -299,8 +283,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): ...@@ -299,8 +283,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
query_lens = batch_spec.query_lens query_lens = batch_spec.query_lens
num_q_heads = vllm_config.model_config.get_num_attention_heads( num_q_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config) vllm_config.parallel_config)
num_kv_heads = vllm_config.model_config.get_num_kv_heads(
vllm_config.parallel_config)
head_size = vllm_config.model_config.get_head_size() head_size = vllm_config.model_config.get_head_size()
dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
block_size = vllm_config.cache_config.block_size block_size = vllm_config.cache_config.block_size
...@@ -315,7 +297,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): ...@@ -315,7 +297,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
# 2. Generate data and compute SDPA reference output for MLA # 2. Generate data and compute SDPA reference output for MLA
all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], [] all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], []
all_sdpa_outputs = [] all_sdpa_outputs: list[list[torch.Tensor]] = []
kv_c_contexts, k_pe_contexts = [], [] kv_c_contexts, k_pe_contexts = [], []
# Create shared MLA weight matrices for consistency across all sequences # Create shared MLA weight matrices for consistency across all sequences
...@@ -331,6 +313,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): ...@@ -331,6 +313,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
device=device) device=device)
kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1) kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
for i, backend in enumerate(BACKENDS_TO_TEST):
all_sdpa_outputs.append([])
for i in range(batch_size): for i in range(batch_size):
s_len = seq_lens[i] s_len = seq_lens[i]
q_len = query_lens[i] q_len = query_lens[i]
...@@ -358,85 +343,93 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): ...@@ -358,85 +343,93 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
dtype=dtype, dtype=dtype,
device=device) device=device)
# Determine if this is decode (single token) # Determine if this is decode or prefill
# or prefill (multiple tokens) is_decode = []
is_decode = q_len == 1 for i, backend in enumerate(BACKENDS_TO_TEST):
builder_cls, _ = get_attention_backend(backend)
is_decode.append(q_len <= builder_cls.reorder_batch_threshold)
# Split q into nope and rope components # Split q into nope and rope components
q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1) q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
if is_decode: #######################################################
# Decode path: MQA-style attention in latent space # Decode path: MQA-style attention in latent space
# Transform q_nope to latent space: q_nope @ W_UK # Transform q_nope to latent space: q_nope @ W_UK
# q_nope: [1, num_heads, qk_nope_head_dim] # q_nope: [1, num_heads, qk_nope_head_dim]
# W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim] # W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim]
ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, ql_nope = torch.einsum("qnh,lnh->qnl", q_nope,
W_UK) # [1, num_heads, kv_lora_rank] W_UK) # [1, num_heads, kv_lora_rank]
# Build MQA attention inputs # Build MQA attention inputs
# Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim] # Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim]
q_mqa = torch.cat([ql_nope, q_pe], dim=-1) q_mqa = torch.cat([ql_nope, q_pe], dim=-1)
# K: [s_len, kv_lora_rank + qk_rope_head_dim] # K: [s_len, kv_lora_rank + qk_rope_head_dim]
# (broadcasted to all heads) # (broadcasted to all heads)
k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1) k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1)
k_mqa = k_mqa.unsqueeze(1).expand(-1, num_q_heads, -1) k_mqa = k_mqa.unsqueeze(1).expand(-1, num_q_heads, -1)
# V: [s_len, kv_lora_rank] (broadcasted to all heads) # V: [s_len, kv_lora_rank] (broadcasted to all heads)
v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_q_heads, -1) v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_q_heads, -1)
# SDPA expects (N, H, L, D) # Create custom attention mask for decode path:
q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2) # - Query tokens can attend to all context tokens
k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2) # - Query tokens can only attend to query tokens up to their position
v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2) attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device)
# Apply causal mask only to the query portion (context_len onwards)
sdpa_out_i = torch.nn.functional.scaled_dot_product_attention( causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
q_sdpa_in, k_sdpa_in, v_sdpa_in, is_causal=False, scale=scale) attn_mask[:, context_len:] = causal_mask
sdpa_out_i = sdpa_out_i.transpose(1, 2).squeeze(
0) # [1, num_heads, kv_lora_rank] # SDPA expects (N, H, L, D)
q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2)
# Project back to output space: sdpa_out @ W_UV k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2)
sdpa_out_i = torch.einsum("qnl,lnv->qnv", sdpa_out_i, W_UV) v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)
sdpa_out_i = sdpa_out_i.flatten(start_dim=-2)
else: sdpa_out_i_decode = torch.nn.functional.scaled_dot_product_attention(
# Prefill path: MHA-style attention with full sequence q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale)
# Apply kv_b_proj to the full kv_c tensor sdpa_out_i_decode = sdpa_out_i_decode.transpose(1, 2).squeeze(
kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full, 0) # [1, num_heads, kv_lora_rank]
kv_b_proj_weight)
k_nope_full, v_full = kv_nope_full.split( # Project back to output space: sdpa_out @ W_UV
[qk_nope_head_dim, v_head_dim], dim=-1) sdpa_out_i_decode = torch.einsum("qnl,lnv->qnv", sdpa_out_i_decode,
W_UV)
# Build attention inputs for full sequence sdpa_out_i_decode = sdpa_out_i_decode.flatten(start_dim=-2)
q_mha = torch.cat([q_nope, q_pe],
dim=-1) # [q_len, num_heads, total_dim] #######################################################
k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1) # Prefill path: MHA-style attention with full sequence
k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1) # Apply kv_b_proj to the full kv_c tensor
kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full, kv_b_proj_weight)
# Create custom attention mask: k_nope_full, v_full = kv_nope_full.split(
# - Query tokens can attend to all context tokens [qk_nope_head_dim, v_head_dim], dim=-1)
# - Query tokens can only attend to query tokens up to their pos
attn_mask = torch.ones(q_len, # Build attention inputs for full sequence
s_len, q_mha = torch.cat([q_nope, q_pe],
dtype=torch.bool, dim=-1) # [q_len, num_heads, total_dim]
device=device) k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1)
# Apply causal mask only to the query portion (context_len onwards) k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1)
causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
attn_mask[:, context_len:] = causal_mask # Create custom attention mask:
# - Query tokens can attend to all context tokens
# SDPA expects (N, H, L, D) # - Query tokens can only attend to query tokens up to their pos
q_sdpa_in = q_mha.unsqueeze(0).transpose(1, 2) attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device)
k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2) # Apply causal mask only to the query portion (context_len onwards)
v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2) causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
attn_mask[:, context_len:] = causal_mask
# Single attention call with custom mask
sdpa_out_i = torch.nn.functional.scaled_dot_product_attention( # SDPA expects (N, H, L, D)
q_sdpa_in, q_sdpa_in = q_mha.unsqueeze(0).transpose(1, 2)
k_sdpa_in, k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2)
v_sdpa_in, v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2)
attn_mask=attn_mask,
scale=scale) # Single attention call with custom mask
sdpa_out_i = sdpa_out_i.transpose(1, 2).squeeze(0) sdpa_out_i_prefill = torch.nn.functional.scaled_dot_product_attention(
sdpa_out_i = sdpa_out_i.flatten(start_dim=-2) q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale)
sdpa_out_i_prefill = sdpa_out_i_prefill.transpose(1, 2).squeeze(0)
all_sdpa_outputs.append(sdpa_out_i) sdpa_out_i_prefill = sdpa_out_i_prefill.flatten(start_dim=-2)
for i, backend in enumerate(BACKENDS_TO_TEST):
if is_decode[i]:
all_sdpa_outputs[i].append(sdpa_out_i_decode)
else:
all_sdpa_outputs[i].append(sdpa_out_i_prefill)
# Inputs for vLLM MLA backends are just the new tokens # Inputs for vLLM MLA backends are just the new tokens
all_q_vllm.append(q_c) all_q_vllm.append(q_c)
...@@ -451,7 +444,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): ...@@ -451,7 +444,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
query_vllm = torch.cat(all_q_vllm, dim=0) query_vllm = torch.cat(all_q_vllm, dim=0)
kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0) kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0) k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0)
sdpa_output = torch.cat(all_sdpa_outputs, dim=0) sdpa_outputs = []
for i, backend in enumerate(BACKENDS_TO_TEST):
sdpa_outputs.append(torch.cat(all_sdpa_outputs[i], dim=0))
# Create mock kv_b_proj using the same weights as reference implementation # Create mock kv_b_proj using the same weights as reference implementation
from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear
...@@ -477,7 +472,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): ...@@ -477,7 +472,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
kv_c_contexts=kv_c_contexts, kv_c_contexts=kv_c_contexts,
k_pe_contexts=k_pe_contexts, k_pe_contexts=k_pe_contexts,
block_size=block_size, block_size=block_size,
num_kv_heads=num_kv_heads,
head_size=head_size, head_size=head_size,
dtype=dtype, dtype=dtype,
device=device, device=device,
...@@ -486,7 +480,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): ...@@ -486,7 +480,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
randomize_blocks=True) randomize_blocks=True)
# 4. Run vLLM backends and compare # 4. Run vLLM backends and compare
for backend_name in BACKENDS_TO_TEST: for i, backend_name in enumerate(BACKENDS_TO_TEST):
backend_output = run_attention_backend( backend_output = run_attention_backend(
backend_name, kv_cache_spec, ["placeholder"], vllm_config, device, backend_name, kv_cache_spec, ["placeholder"], vllm_config, device,
common_attn_metadata, query_vllm, kv_c_vllm, k_pe_vllm, kv_cache, common_attn_metadata, query_vllm, kv_c_vllm, k_pe_vllm, kv_cache,
...@@ -494,12 +488,12 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): ...@@ -494,12 +488,12 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
mock_kv_b_proj) mock_kv_b_proj)
# Check shape and dtype consistency # Check shape and dtype consistency
assert backend_output.shape == sdpa_output.shape, ( assert backend_output.shape == sdpa_outputs[i].shape, (
f"[{backend_name}] shape {backend_output.shape} != " f"[{backend_name}] shape {backend_output.shape} != "
f"SDPA shape {sdpa_output.shape}") f"SDPA shape {sdpa_outputs[i].shape}")
assert backend_output.dtype == sdpa_output.dtype, ( assert backend_output.dtype == sdpa_outputs[i].dtype, (
f"[{backend_name}] dtype {backend_output.dtype} != " f"[{backend_name}] dtype {backend_output.dtype} != "
f"SDPA dtype {sdpa_output.dtype}") f"SDPA dtype {sdpa_outputs[i].dtype}")
assert torch.isfinite(backend_output).all(), ( assert torch.isfinite(backend_output).all(), (
f"[{backend_name}] produced non-finite values") f"[{backend_name}] produced non-finite values")
...@@ -508,12 +502,13 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): ...@@ -508,12 +502,13 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
rtol = 1e-2 rtol = 1e-2
atol = 5e-1 atol = 5e-1
max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item() max_diff = torch.max(torch.abs(backend_output -
sdpa_outputs[i])).item()
max_rel_diff = torch.max( max_rel_diff = torch.max(
torch.abs(backend_output - sdpa_output) / torch.abs(backend_output - sdpa_outputs[i]) /
torch.abs(sdpa_output)).item() torch.abs(sdpa_outputs[i])).item()
all_close = torch.allclose(backend_output, all_close = torch.allclose(backend_output,
sdpa_output, sdpa_outputs[i],
rtol=rtol, rtol=rtol,
atol=atol) atol=atol)
......
...@@ -139,6 +139,8 @@ def get_attention_backend(backend_name: _Backend): ...@@ -139,6 +139,8 @@ def get_attention_backend(backend_name: _Backend):
"vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend", "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend",
_Backend.FLASHMLA_VLLM_V1: _Backend.FLASHMLA_VLLM_V1:
"vllm.v1.attention.backends.mla.flashmla.FlashMLABackend", "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",
_Backend.FLASH_ATTN_MLA:
"vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend",
_Backend.TRITON_MLA_VLLM_V1: _Backend.TRITON_MLA_VLLM_V1:
"vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend", "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",
} }
......
...@@ -6,20 +6,22 @@ from typing import Callable, Optional ...@@ -6,20 +6,22 @@ from typing import Callable, Optional
import pytest import pytest
import torch import torch
import vllm.v1.core.kv_cache_utils as kv_cache_utils
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.multimodal.inputs import (MultiModalFeatureSpec, from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange) MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit from vllm.utils import GiB_bytes, sha256, sha256_cbor
from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.core.kv_cache_manager import KVCacheManager
# disable yapf here as it formats differently than isort such that both fail # disable yapf here as it formats differently than isort such that both fail
# yapf: disable # yapf: disable
from vllm.v1.core.kv_cache_utils import ( from vllm.v1.core.kv_cache_utils import (
FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics, BlockHash, FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
estimate_max_model_len, generate_block_hash_extra_keys, estimate_max_model_len, generate_block_hash_extra_keys,
get_kv_cache_config, get_max_concurrency_for_kv_cache_config, get_kv_cache_config, get_max_concurrency_for_kv_cache_config,
get_request_block_hasher, hash_block_tokens, init_none_hash, get_request_block_hasher, hash_block_tokens, init_none_hash,
is_kv_cache_type_uniform, unify_kv_cache_configs) is_kv_cache_type_uniform, make_block_hash_with_group_id,
unify_kv_cache_configs)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheTensor, KVCacheGroupSpec, KVCacheTensor,
SlidingWindowSpec) SlidingWindowSpec)
...@@ -88,7 +90,7 @@ def new_sliding_window_spec(block_size=16, ...@@ -88,7 +90,7 @@ def new_sliding_window_spec(block_size=16,
sliding_window=sliding_window) sliding_window=sliding_window)
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_none_hash(monkeypatch, hash_fn): def test_none_hash(monkeypatch, hash_fn):
import vllm.v1.core.kv_cache_utils import vllm.v1.core.kv_cache_utils
...@@ -98,8 +100,8 @@ def test_none_hash(monkeypatch, hash_fn): ...@@ -98,8 +100,8 @@ def test_none_hash(monkeypatch, hash_fn):
reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils) reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils)
reloaded_kv_cache_utils.init_none_hash(hash_fn) reloaded_kv_cache_utils.init_none_hash(hash_fn)
assert reloaded_kv_cache_utils.NONE_HASH is not None assert reloaded_kv_cache_utils.NONE_HASH is not None
assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int) assert isinstance(reloaded_kv_cache_utils.NONE_HASH, bytes)
assert reloaded_kv_cache_utils.NONE_HASH != 0 assert reloaded_kv_cache_utils.NONE_HASH != b""
# case 2: PYTHONHASHSEED is set, use the seed and hash_fn # case 2: PYTHONHASHSEED is set, use the seed and hash_fn
with monkeypatch.context() as m: with monkeypatch.context() as m:
...@@ -107,12 +109,11 @@ def test_none_hash(monkeypatch, hash_fn): ...@@ -107,12 +109,11 @@ def test_none_hash(monkeypatch, hash_fn):
reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils) reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils)
reloaded_kv_cache_utils.init_none_hash(hash_fn) reloaded_kv_cache_utils.init_none_hash(hash_fn)
assert reloaded_kv_cache_utils.NONE_HASH is not None assert reloaded_kv_cache_utils.NONE_HASH is not None
assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int) assert isinstance(reloaded_kv_cache_utils.NONE_HASH, bytes)
assert hash_fn('python hash seed') == reloaded_kv_cache_utils.NONE_HASH assert hash_fn('python hash seed') == reloaded_kv_cache_utils.NONE_HASH
def test_kv_cache_block(): def test_kv_cache_block():
import vllm.v1.core.kv_cache_utils
# Test KVCacheBlock initialization # Test KVCacheBlock initialization
block = KVCacheBlock(block_id=0) block = KVCacheBlock(block_id=0)
...@@ -127,8 +128,7 @@ def test_kv_cache_block(): ...@@ -127,8 +128,7 @@ def test_kv_cache_block():
assert block.ref_cnt == 0 assert block.ref_cnt == 0
# Test block hash setting and resetting # Test block hash setting and resetting
block_hash = vllm.v1.core.kv_cache_utils.BlockHash(hash_value=123, block_hash = make_block_hash_with_group_id(BlockHash(b"abc"), 0)
token_ids=(1, 2, 3))
block.block_hash = block_hash block.block_hash = block_hash
assert block.block_hash == block_hash assert block.block_hash == block_hash
...@@ -247,7 +247,7 @@ def test_free_kv_cache_block_queue_append_n(): ...@@ -247,7 +247,7 @@ def test_free_kv_cache_block_queue_append_n():
def test_free_kv_cache_block_queue_popleft_n(): def test_free_kv_cache_block_queue_popleft_n():
blocks = [KVCacheBlock(block_id=i) for i in range(6)] blocks = [KVCacheBlock(block_id=i) for i in range(6)]
# Create a empty FreeKVCacheBlockQueue with these blocks # Create an empty FreeKVCacheBlockQueue with these blocks
queue = FreeKVCacheBlockQueue( queue = FreeKVCacheBlockQueue(
[blocks[1], blocks[3], blocks[5], blocks[4], blocks[0], blocks[2]]) [blocks[1], blocks[3], blocks[5], blocks[4], blocks[0], blocks[2]])
assert queue.num_free_blocks == 6 assert queue.num_free_blocks == 6
...@@ -407,27 +407,23 @@ def test_generate_block_hash_extra_keys_cache_salt(): ...@@ -407,27 +407,23 @@ def test_generate_block_hash_extra_keys_cache_salt():
assert next_mm_idx == 1 assert next_mm_idx == 1
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_hash_block_tokens(hash_fn): def test_hash_block_tokens(hash_fn):
import vllm.v1.core.kv_cache_utils
init_none_hash(hash_fn) init_none_hash(hash_fn)
parent_block_hash = 123 parent_block_hash = BlockHash(b"123")
curr_block_token_ids = (1, 2, 3) curr_block_token_ids = (1, 2, 3)
extra_keys = ("key1", "key2") extra_keys = ("key1", "key2")
block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_hash = hash_block_tokens(hash_fn, parent_block_hash,
curr_block_token_ids, extra_keys) curr_block_token_ids, extra_keys)
assert isinstance(block_hash, vllm.v1.core.kv_cache_utils.BlockHash) expected = hash_fn((parent_block_hash, curr_block_token_ids, extra_keys))
assert block_hash.hash_value == hash_fn( assert block_hash == expected
(parent_block_hash, curr_block_token_ids, extra_keys))
assert block_hash.token_ids == curr_block_token_ids
assert block_hash.extra_keys == extra_keys
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_request_block_hasher(hash_fn): def test_request_block_hasher(hash_fn):
import vllm.v1.core.kv_cache_utils kv_cache_utils.init_none_hash(hash_fn)
init_none_hash(hash_fn)
request = make_request( request = make_request(
request_id="0", request_id="0",
prompt_token_ids=[_ for _ in range(6)], prompt_token_ids=[_ for _ in range(6)],
...@@ -442,19 +438,13 @@ def test_request_block_hasher(hash_fn): ...@@ -442,19 +438,13 @@ def test_request_block_hasher(hash_fn):
block_hashes = request.block_hashes block_hashes = request.block_hashes
assert len(block_hashes) == 2 assert len(block_hashes) == 2
assert isinstance(block_hashes[0], vllm.v1.core.kv_cache_utils.BlockHash) assert block_hashes[0] == hash_fn(
assert isinstance(block_hashes[1], vllm.v1.core.kv_cache_utils.BlockHash) (kv_cache_utils.NONE_HASH, (0, 1, 2), ("hash1", )))
assert block_hashes[1] == hash_fn(
# Check the first block (block_hashes[0], (3, 4, 5), ("hash2", )))
assert block_hashes[0].token_ids == (0, 1, 2)
assert block_hashes[0].extra_keys == ("hash1", )
# Check the second block
assert block_hashes[1].token_ids == (3, 4, 5)
assert block_hashes[1].extra_keys == ("hash2", )
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
def test_hash_tokens_different_mm_input(hash_fn): def test_hash_tokens_different_mm_input(hash_fn):
init_none_hash(hash_fn) init_none_hash(hash_fn)
...@@ -484,9 +474,9 @@ def test_hash_tokens_different_mm_input(hash_fn): ...@@ -484,9 +474,9 @@ def test_hash_tokens_different_mm_input(hash_fn):
assert block_hashes1[1] != block_hashes2[1] assert block_hashes1[1] != block_hashes2[1]
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_hash_request_tokens_no_mm_inputs(hash_fn): def test_hash_request_tokens_no_mm_inputs(hash_fn):
init_none_hash(hash_fn) kv_cache_utils.init_none_hash(hash_fn)
request = make_request( request = make_request(
request_id="0", request_id="0",
...@@ -500,10 +490,9 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn): ...@@ -500,10 +490,9 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn):
block_hashes = request.block_hashes block_hashes = request.block_hashes
assert len(block_hashes) == 2 assert len(block_hashes) == 2
assert block_hashes[0].token_ids == (0, 1, 2) assert block_hashes[0] == hash_fn(
assert block_hashes[0].extra_keys is None (kv_cache_utils.NONE_HASH, (0, 1, 2), None))
assert block_hashes[1].token_ids == (3, 4, 5) assert block_hashes[1] == hash_fn((block_hashes[0], (3, 4, 5), None))
assert block_hashes[1].extra_keys is None
def test_metrics(): def test_metrics():
......
...@@ -8,17 +8,19 @@ from typing import Callable, Optional ...@@ -8,17 +8,19 @@ from typing import Callable, Optional
import pytest import pytest
import torch import torch
import vllm.v1.core.kv_cache_utils as kv_cache_utils
from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
from vllm.multimodal.inputs import (MultiModalFeatureSpec, from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange) MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import sha256, sha256_cbor_64bit from vllm.utils import sha256, sha256_cbor
from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
KVCacheBlock, get_block_hash, get_group_id,
get_request_block_hasher, get_request_block_hasher,
hash_block_tokens, init_none_hash) hash_block_tokens, init_none_hash,
make_block_hash_with_group_id)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, SlidingWindowSpec) KVCacheGroupSpec, SlidingWindowSpec)
...@@ -101,8 +103,10 @@ def make_kv_cache_config_hybrid_model(block_size: int, ...@@ -101,8 +103,10 @@ def make_kv_cache_config_hybrid_model(block_size: int,
) )
@pytest.mark.parametrize("hash_algo", ["sha256", "sha256_cbor_64bit", "hash"]) @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_prefill(hash_algo): def test_prefill(hash_fn):
init_none_hash(hash_fn)
block_size = 16 block_size = 16
manager = KVCacheManager( manager = KVCacheManager(
make_kv_cache_config(block_size, 11), make_kv_cache_config(block_size, 11),
...@@ -110,10 +114,6 @@ def test_prefill(hash_algo): ...@@ -110,10 +114,6 @@ def test_prefill(hash_algo):
enable_caching=True, enable_caching=True,
) )
# choose the hash function according to the parameter
hash_fn = (sha256_cbor_64bit if hash_algo == "sha256_cbor_64bit" else
sha256 if hash_algo == "sha256" else hash)
# Complete 3 blocks (48 tokens) # Complete 3 blocks (48 tokens)
common_token_ids = [i for i in range(3) for _ in range(16)] common_token_ids = [i for i in range(3) for _ in range(16)]
...@@ -137,10 +137,12 @@ def test_prefill(hash_algo): ...@@ -137,10 +137,12 @@ def test_prefill(hash_algo):
block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16]) block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_hash = hash_block_tokens(hash_fn, parent_block_hash,
block_tokens) block_tokens)
assert manager.block_pool.blocks[ blk_hash = manager.block_pool.blocks[block_id].block_hash
block_id].block_hash.block_hash == block_hash assert blk_hash is not None
assert get_block_hash(blk_hash) == block_hash
assert get_group_id(blk_hash) == 0
assert manager.block_pool.blocks[block_id].ref_cnt == 1 assert manager.block_pool.blocks[block_id].ref_cnt == 1
parent_block_hash = block_hash.hash_value parent_block_hash = block_hash
# Check partial block metadata # Check partial block metadata
for block_id in (4, ): for block_id in (4, ):
...@@ -233,7 +235,7 @@ def test_prefill_hybrid_model(): ...@@ -233,7 +235,7 @@ def test_prefill_hybrid_model():
enable_caching=True, enable_caching=True,
) )
hash_fn = hash hash_fn = sha256
# Complete 3 blocks (48 tokens) # Complete 3 blocks (48 tokens)
common_token_ids = [i for i in range(3) for _ in range(block_size)] common_token_ids = [i for i in range(3) for _ in range(block_size)]
...@@ -260,11 +262,13 @@ def test_prefill_hybrid_model(): ...@@ -260,11 +262,13 @@ def test_prefill_hybrid_model():
block_tokens = tuple(all_token_ids[(length - 1) * 16:length * 16]) block_tokens = tuple(all_token_ids[(length - 1) * 16:length * 16])
block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_hash = hash_block_tokens(hash_fn, parent_block_hash,
block_tokens) block_tokens)
for block_id in block_ids: for group_id, block_id in enumerate(block_ids):
assert manager.block_pool.blocks[ blk_hash = manager.block_pool.blocks[block_id].block_hash
block_id].block_hash.block_hash == block_hash assert blk_hash is not None
assert get_block_hash(blk_hash) == block_hash
assert get_group_id(blk_hash) == group_id
assert manager.block_pool.blocks[block_id].ref_cnt == 1 assert manager.block_pool.blocks[block_id].ref_cnt == 1
parent_block_hash = block_hash.hash_value parent_block_hash = block_hash
# Check partial block metadata # Check partial block metadata
for block_id in (4, 8, 12): for block_id in (4, 8, 12):
...@@ -298,11 +302,10 @@ def test_prefill_hybrid_model(): ...@@ -298,11 +302,10 @@ def test_prefill_hybrid_model():
cached_block_hash_to_block_bak = copy.copy( cached_block_hash_to_block_bak = copy.copy(
manager.block_pool.cached_block_hash_to_block) manager.block_pool.cached_block_hash_to_block)
def test_partial_request_hit(request_id: str, def test_partial_request_hit(request_id: str, hash_to_evict: list[bytes],
hash_to_evict: list[BlockHashWithGroupId],
expect_hit_length: int): expect_hit_length: int):
req = make_request(request_id, common_token_ids + unique_token_ids, req = make_request(request_id, common_token_ids + unique_token_ids,
block_size, hash) block_size, sha256)
for hash_with_group_id in hash_to_evict: for hash_with_group_id in hash_to_evict:
manager.block_pool.cached_block_hash_to_block.pop( manager.block_pool.cached_block_hash_to_block.pop(
hash_with_group_id) hash_with_group_id)
...@@ -319,33 +322,32 @@ def test_prefill_hybrid_model(): ...@@ -319,33 +322,32 @@ def test_prefill_hybrid_model():
# Evict the blocks outside sliding window, does not affect the hit length. # Evict the blocks outside sliding window, does not affect the hit length.
test_partial_request_hit("2", [ test_partial_request_hit("2", [
BlockHashWithGroupId(block_hashes[0], 1), make_block_hash_with_group_id(block_hashes[0], 1),
BlockHashWithGroupId(block_hashes[0], 2) make_block_hash_with_group_id(block_hashes[0], 2)
], 3) ], 3)
# Evict the first block of full attention, makes total cache miss. # Evict the first block of full attention, makes total cache miss.
test_partial_request_hit("3", [ test_partial_request_hit(
BlockHashWithGroupId(block_hashes[0], 0), "3", [make_block_hash_with_group_id(block_hashes[0], 0)], 0)
], 0)
# Evict the last block of all layers, reduces the hit length to 2. # Evict the last block of all layers, reduces the hit length to 2.
test_partial_request_hit("4", [ test_partial_request_hit("4", [
BlockHashWithGroupId(block_hashes[2], 0), make_block_hash_with_group_id(block_hashes[2], 0),
BlockHashWithGroupId(block_hashes[2], 1), make_block_hash_with_group_id(block_hashes[2], 1),
BlockHashWithGroupId(block_hashes[2], 2), make_block_hash_with_group_id(block_hashes[2], 2),
], 2) ], 2)
# Evict the last block of full attention, reduces the hit length to 2. # Evict the last block of full attention, reduces the hit length to 2.
test_partial_request_hit("5", [BlockHashWithGroupId(block_hashes[2], 0)], test_partial_request_hit(
2) "5", [make_block_hash_with_group_id(block_hashes[2], 0)], 2)
# Evict the last block of sliding window, reduces the hit length to 2. # Evict the last block of sliding window, reduces the hit length to 2.
test_partial_request_hit("6", [BlockHashWithGroupId(block_hashes[2], 1)], test_partial_request_hit(
2) "6", [make_block_hash_with_group_id(block_hashes[2], 1)], 2)
# Evict the last block of sliding window, reduces the hit length to 2. # Evict the last block of sliding window, reduces the hit length to 2.
test_partial_request_hit("7", [BlockHashWithGroupId(block_hashes[2], 2)], test_partial_request_hit(
2) "7", [make_block_hash_with_group_id(block_hashes[2], 2)], 2)
# Evict different set of blocks for full attention and sliding window makes # Evict different set of blocks for full attention and sliding window makes
# total cache miss. # total cache miss.
...@@ -353,9 +355,9 @@ def test_prefill_hybrid_model(): ...@@ -353,9 +355,9 @@ def test_prefill_hybrid_model():
# The cache hit length of sliding window is 2 * block_size. # The cache hit length of sliding window is 2 * block_size.
# Then it is cache miss as the two type of layers have different hit length. # Then it is cache miss as the two type of layers have different hit length.
test_partial_request_hit("8", [ test_partial_request_hit("8", [
BlockHashWithGroupId(block_hashes[2], 0), make_block_hash_with_group_id(block_hashes[2], 0),
BlockHashWithGroupId(block_hashes[0], 1), make_block_hash_with_group_id(block_hashes[0], 1),
BlockHashWithGroupId(block_hashes[0], 2), make_block_hash_with_group_id(block_hashes[0], 2),
], 0) ], 0)
...@@ -372,8 +374,8 @@ def test_prefill_plp(): ...@@ -372,8 +374,8 @@ def test_prefill_plp():
max_model_len=8192, max_model_len=8192,
enable_caching=True, enable_caching=True,
) )
# the default hash function is hash # the default hash function is sha256
hash_fn = hash hash_fn = sha256
# Complete 3 blocks (48 tokens) # Complete 3 blocks (48 tokens)
common_token_ids = [i for i in range(3) for _ in range(16)] common_token_ids = [i for i in range(3) for _ in range(16)]
...@@ -404,10 +406,12 @@ def test_prefill_plp(): ...@@ -404,10 +406,12 @@ def test_prefill_plp():
block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16]) block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_hash = hash_block_tokens(hash_fn, parent_block_hash,
block_tokens) block_tokens)
assert manager.block_pool.blocks[ blk_hash = (manager.block_pool.blocks[block_id].block_hash)
block_id].block_hash.block_hash == block_hash assert blk_hash is not None
assert get_block_hash(blk_hash) == block_hash
assert get_group_id(blk_hash) == 0
assert manager.block_pool.blocks[block_id].ref_cnt == 1 assert manager.block_pool.blocks[block_id].ref_cnt == 1
parent_block_hash = block_hash.hash_value parent_block_hash = block_hash
# Check partial block metadata # Check partial block metadata
for block_id in (4, ): for block_id in (4, ):
...@@ -493,7 +497,7 @@ def test_decode(): ...@@ -493,7 +497,7 @@ def test_decode():
# Incomplete 1 block (7 tokens) # Incomplete 1 block (7 tokens)
unique_token_ids = [3] * 7 unique_token_ids = [3] * 7
req0 = make_request("0", common_token_ids + unique_token_ids, block_size, req0 = make_request("0", common_token_ids + unique_token_ids, block_size,
hash) sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks[0] assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0 assert num_computed_tokens == 0
...@@ -538,7 +542,7 @@ def test_evict(): ...@@ -538,7 +542,7 @@ def test_evict():
) )
last_token_id = 5 * 16 + 7 last_token_id = 5 * 16 + 7
req0 = make_request("0", list(range(last_token_id)), block_size, hash) req0 = make_request("0", list(range(last_token_id)), block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks[0] assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0 assert num_computed_tokens == 0
...@@ -550,7 +554,7 @@ def test_evict(): ...@@ -550,7 +554,7 @@ def test_evict():
# 3 blocks. # 3 blocks.
req1 = make_request("1", list(range(last_token_id, req1 = make_request("1", list(range(last_token_id,
last_token_id + 3 * 16)), block_size, last_token_id + 3 * 16)), block_size,
hash) sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks.blocks[0] assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0 assert num_computed_tokens == 0
...@@ -572,7 +576,7 @@ def test_evict(): ...@@ -572,7 +576,7 @@ def test_evict():
] == [10, 6, 5, 4, 3, 2, 1, 9, 8, 7] ] == [10, 6, 5, 4, 3, 2, 1, 9, 8, 7]
# Touch the first 2 blocks. # Touch the first 2 blocks.
req2 = make_request("2", list(range(2 * 16 + 3)), block_size, hash) req2 = make_request("2", list(range(2 * 16 + 3)), block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert computed_blocks.get_block_ids() == ([1, 2], ) assert computed_blocks.get_block_ids() == ([1, 2], )
assert num_computed_tokens == 2 * 16 assert num_computed_tokens == 2 * 16
...@@ -597,7 +601,7 @@ def test_hash_block_correct_reuse(): ...@@ -597,7 +601,7 @@ def test_hash_block_correct_reuse():
# Allocate 1 block and cache it. # Allocate 1 block and cache it.
num_tokens = block_size * 1 num_tokens = block_size * 1
req = make_request("0", list(range(num_tokens)), block_size, hash) req = make_request("0", list(range(num_tokens)), block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks.blocks[0] assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0 assert num_computed_tokens == 0
...@@ -611,7 +615,7 @@ def test_hash_block_correct_reuse(): ...@@ -611,7 +615,7 @@ def test_hash_block_correct_reuse():
# Allocate a new block that's not full, make sure hash info on the # Allocate a new block that's not full, make sure hash info on the
# block is cleared. # block is cleared.
req = make_request("1", list(range(num_tokens - 1)), block_size, hash) req = make_request("1", list(range(num_tokens - 1)), block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks.blocks[0] assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0 assert num_computed_tokens == 0
...@@ -638,7 +642,7 @@ def test_computed_blocks_not_evicted(): ...@@ -638,7 +642,7 @@ def test_computed_blocks_not_evicted():
# Allocate a block and cache it. # Allocate a block and cache it.
num_tokens = block_size * 1 num_tokens = block_size * 1
req0 = make_request("0", list(range(num_tokens)), block_size, hash) req0 = make_request("0", list(range(num_tokens)), block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks[0] assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0 assert num_computed_tokens == 0
...@@ -650,7 +654,7 @@ def test_computed_blocks_not_evicted(): ...@@ -650,7 +654,7 @@ def test_computed_blocks_not_evicted():
# Allocate another block. # Allocate another block.
req1 = make_request("1", list(range(num_tokens, num_tokens * 2)), req1 = make_request("1", list(range(num_tokens, num_tokens * 2)),
block_size, hash) block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks.blocks[0] assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0 assert num_computed_tokens == 0
...@@ -666,7 +670,7 @@ def test_computed_blocks_not_evicted(): ...@@ -666,7 +670,7 @@ def test_computed_blocks_not_evicted():
# Now if we have a cache hit on the first block, we should evict the second # Now if we have a cache hit on the first block, we should evict the second
# cached block rather than the first one. # cached block rather than the first one.
req2 = make_request("2", list(range(num_tokens * 2)), block_size, hash) req2 = make_request("2", list(range(num_tokens * 2)), block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(computed_blocks.blocks[0]) == 1 assert len(computed_blocks.blocks[0]) == 1
assert computed_blocks.blocks[0][0].block_id == 1 assert computed_blocks.blocks[0][0].block_id == 1
...@@ -691,7 +695,7 @@ def test_basic_prefix_caching_disabled(): ...@@ -691,7 +695,7 @@ def test_basic_prefix_caching_disabled():
) )
req1 = make_request("1", list(range(10)), block_size, req1 = make_request("1", list(range(10)), block_size,
hash) # 2 blocks and some more sha256) # 2 blocks and some more
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert not computed_blocks.blocks[0] assert not computed_blocks.blocks[0]
...@@ -706,7 +710,7 @@ def test_basic_prefix_caching_disabled(): ...@@ -706,7 +710,7 @@ def test_basic_prefix_caching_disabled():
# No caching. # No caching.
req2 = make_request("2", list(range(16)), block_size, req2 = make_request("2", list(range(16)), block_size,
hash) # shared prefix sha256) # shared prefix
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert not computed_blocks.blocks[0] assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0 assert num_computed_tokens == 0
...@@ -716,7 +720,7 @@ def test_basic_prefix_caching_disabled(): ...@@ -716,7 +720,7 @@ def test_basic_prefix_caching_disabled():
assert len(blocks.blocks[0]) == 4 assert len(blocks.blocks[0]) == 4
# New requests should not have any blocks. # New requests should not have any blocks.
req3 = make_request("3", list(range(4)), block_size, hash) req3 = make_request("3", list(range(4)), block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert not computed_blocks.blocks[0] assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0 assert num_computed_tokens == 0
...@@ -726,7 +730,7 @@ def test_basic_prefix_caching_disabled(): ...@@ -726,7 +730,7 @@ def test_basic_prefix_caching_disabled():
assert not blocks assert not blocks
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_cache_blocks(hash_fn): def test_cache_blocks(hash_fn):
""" """
This is a unit test that tests the correctness of the _cache_full_blocks This is a unit test that tests the correctness of the _cache_full_blocks
...@@ -787,7 +791,7 @@ def test_cache_blocks_multi_group(): ...@@ -787,7 +791,7 @@ def test_cache_blocks_multi_group():
# Block 1/5: [4, 5, 6, 7] # Block 1/5: [4, 5, 6, 7]
# Block 2/6: [8, 9, 10, 11] # Block 2/6: [8, 9, 10, 11]
# Block 3/7: [12, 13] # Block 3/7: [12, 13]
req = make_request("0", list(range(14)), block_size, hash) req = make_request("0", list(range(14)), block_size, sha256)
# Cache the blocks for group 0. # Cache the blocks for group 0.
blocks = [KVCacheBlock(block_id=i) for i in range(2)] blocks = [KVCacheBlock(block_id=i) for i in range(2)]
...@@ -845,6 +849,8 @@ def test_mm_prefix_caching(): ...@@ -845,6 +849,8 @@ def test_mm_prefix_caching():
""" """
This tests that the multi-modal prefix caching is correct. This tests that the multi-modal prefix caching is correct.
""" """
kv_cache_utils.init_none_hash(sha256)
block_size = 16 block_size = 16
manager = KVCacheManager( manager = KVCacheManager(
make_kv_cache_config(block_size, 11), make_kv_cache_config(block_size, 11),
...@@ -874,23 +880,30 @@ def test_mm_prefix_caching(): ...@@ -874,23 +880,30 @@ def test_mm_prefix_caching():
req0 = make_request("0", req0 = make_request("0",
all_token_ids, all_token_ids,
block_size, block_size,
hash, sha256,
mm_positions=mm_positions, mm_positions=mm_positions,
mm_hashes=mm_hashes) mm_hashes=mm_hashes)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
# Completed block should have hashes with extra keys. # Completed block should have hashes
assert not computed_blocks.blocks[0] assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0 assert num_computed_tokens == 0
block_hashes = req0.block_hashes block_hashes = req0.block_hashes
assert len(block_hashes) == 3 assert len(block_hashes) == 3
assert block_hashes[0].extra_keys == ("aaa", ) assert block_hashes[0] == sha256(
assert block_hashes[1].extra_keys == ("aaa", "bbb") (kv_cache_utils.NONE_HASH, tuple(all_token_ids[:block_size]),
assert block_hashes[2].extra_keys == ("bbb", ) ("aaa", )))
assert block_hashes[1] == sha256(
(block_hashes[0], tuple(all_token_ids[block_size:block_size * 2]),
("aaa", "bbb")))
assert block_hashes[2] == sha256(
(block_hashes[1], tuple(all_token_ids[block_size * 2:block_size * 3]),
("bbb", )))
blocks = manager.allocate_slots(req0, 59, blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks[0]) * 16, len(computed_blocks.blocks[0]) * 16,
computed_blocks) computed_blocks)
assert blocks is not None
assert blocks.get_block_ids() == ([1, 2, 3, 4], ) assert blocks.get_block_ids() == ([1, 2, 3, 4], )
req0.num_computed_tokens = 59 req0.num_computed_tokens = 59
...@@ -901,10 +914,10 @@ def test_mm_prefix_caching(): ...@@ -901,10 +914,10 @@ def test_mm_prefix_caching():
len(computed_blocks.blocks[0]) * 16, len(computed_blocks.blocks[0]) * 16,
computed_blocks) computed_blocks)
assert new_blocks is not None and len(new_blocks.blocks[0]) == 0 assert new_blocks is not None and len(new_blocks.blocks[0]) == 0
# The just completed block should have hashes with extra keys.
assert len(block_hashes) == 4 assert len(block_hashes) == 4
assert block_hashes[3].extra_keys == ("ccc", ) assert block_hashes[3] == sha256(
(block_hashes[2], tuple(all_token_ids[3 * block_size:] + [8] * 5),
("ccc", )))
# Cache hit. # Cache hit.
unique_token_ids = [-1] * 7 + [200] * 5 unique_token_ids = [-1] * 7 + [200] * 5
...@@ -916,7 +929,7 @@ def test_mm_prefix_caching(): ...@@ -916,7 +929,7 @@ def test_mm_prefix_caching():
req1 = make_request("1", req1 = make_request("1",
all_token_ids, all_token_ids,
block_size, block_size,
hash, sha256,
mm_positions=mm_positions, mm_positions=mm_positions,
mm_hashes=mm_hashes) mm_hashes=mm_hashes)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
...@@ -929,6 +942,8 @@ def test_cache_key_salting(): ...@@ -929,6 +942,8 @@ def test_cache_key_salting():
This tests that cache salts are applied during hashing and the cache This tests that cache salts are applied during hashing and the cache
is separated cache as expected. is separated cache as expected.
""" """
kv_cache_utils.init_none_hash(sha256)
block_size = 16 block_size = 16
manager = KVCacheManager( manager = KVCacheManager(
make_kv_cache_config(block_size, 11), make_kv_cache_config(block_size, 11),
...@@ -939,21 +954,26 @@ def test_cache_key_salting(): ...@@ -939,21 +954,26 @@ def test_cache_key_salting():
# 3 complete blocks and an incomplete block with 11 tokens. # 3 complete blocks and an incomplete block with 11 tokens.
common_token_ids = [i for i in range(3) for _ in range(block_size)] common_token_ids = [i for i in range(3) for _ in range(block_size)]
token_ids = common_token_ids + [3] * 11 token_ids = common_token_ids + [3] * 11
req0 = make_request("0", token_ids, block_size, hash, cache_salt="salt1") req0 = make_request("0", token_ids, block_size, sha256, cache_salt="salt1")
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
# Completed block should have hashes with extra keys. # Completed block should have hashes
assert not computed_blocks.blocks[0] assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0 assert num_computed_tokens == 0
block_hashes = req0.block_hashes block_hashes = req0.block_hashes
assert len(block_hashes) == 3 assert len(block_hashes) == 3
assert block_hashes[0].extra_keys == ("salt1", ) assert block_hashes[0] == sha256(
assert block_hashes[1].extra_keys is None (kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt1", )))
assert block_hashes[2].extra_keys is None assert block_hashes[1] == sha256(
(block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None))
assert block_hashes[2] == sha256(
(block_hashes[1], tuple(token_ids[block_size * 2:block_size * 3]),
None))
blocks = manager.allocate_slots(req0, 59, blocks = manager.allocate_slots(req0, 59,
len(computed_blocks.blocks[0]) * 16, len(computed_blocks.blocks[0]) * 16,
computed_blocks) computed_blocks)
assert blocks is not None
assert blocks.get_block_ids() == ([1, 2, 3, 4], ) assert blocks.get_block_ids() == ([1, 2, 3, 4], )
req0.num_computed_tokens = 59 req0.num_computed_tokens = 59
...@@ -964,14 +984,13 @@ def test_cache_key_salting(): ...@@ -964,14 +984,13 @@ def test_cache_key_salting():
len(computed_blocks.blocks[0]) * 16, len(computed_blocks.blocks[0]) * 16,
computed_blocks) computed_blocks)
assert new_blocks is not None and len(new_blocks.blocks[0]) == 0 assert new_blocks is not None and len(new_blocks.blocks[0]) == 0
# Now one more block that should not have extra keys.
assert len(block_hashes) == 4 assert len(block_hashes) == 4
assert block_hashes[3].extra_keys is None assert block_hashes[3] == sha256(
(block_hashes[2], tuple(token_ids[3 * block_size:] + [8] * 5), None))
# Test cache hit with a new request that has the same salt. # Test cache hit with a new request that has the same salt.
token_ids = common_token_ids + [4] * 11 token_ids = common_token_ids + [4] * 11
req1 = make_request("1", token_ids, block_size, hash, cache_salt="salt1") req1 = make_request("1", token_ids, block_size, sha256, cache_salt="salt1")
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
# Should match only a prefix of 3 blocks. # Should match only a prefix of 3 blocks.
assert len(computed_blocks.blocks[0]) == 3 assert len(computed_blocks.blocks[0]) == 3
...@@ -979,13 +998,19 @@ def test_cache_key_salting(): ...@@ -979,13 +998,19 @@ def test_cache_key_salting():
# Test cache miss with same content but different salt. # Test cache miss with same content but different salt.
token_ids = common_token_ids + [4] * 11 token_ids = common_token_ids + [4] * 11
req2 = make_request("2", token_ids, block_size, hash, cache_salt="salt2") req2 = make_request("2", token_ids, block_size, sha256, cache_salt="salt2")
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert len(computed_blocks.blocks[0]) == 0 assert len(computed_blocks.blocks[0]) == 0
assert num_computed_tokens == 0 assert num_computed_tokens == 0
block_hashes = req2.block_hashes block_hashes = req2.block_hashes
assert len(block_hashes) == 3 assert len(block_hashes) == 3
assert block_hashes[0].extra_keys == ("salt2", ) assert block_hashes[0] == sha256(
(kv_cache_utils.NONE_HASH, tuple(token_ids[:block_size]), ("salt2", )))
assert block_hashes[1] == sha256(
(block_hashes[0], tuple(token_ids[block_size:block_size * 2]), None))
assert block_hashes[2] == sha256(
(block_hashes[1], tuple(token_ids[block_size * 2:block_size * 3]),
None))
def test_prefill_not_enough_free_blocks_with_computed_blocks(): def test_prefill_not_enough_free_blocks_with_computed_blocks():
...@@ -1004,7 +1029,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): ...@@ -1004,7 +1029,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# Complete 3 blocks (48 tokens) # Complete 3 blocks (48 tokens)
# | Common-0 | Common-1 | Common-2 | ... | # | Common-0 | Common-1 | Common-2 | ... |
common_token_ids = [i for i in range(3) for _ in range(16)] common_token_ids = [i for i in range(3) for _ in range(16)]
req0 = make_request("0", common_token_ids, block_size, hash) req0 = make_request("0", common_token_ids, block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
assert not computed_blocks.blocks[0] assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0 assert num_computed_tokens == 0
...@@ -1015,7 +1040,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): ...@@ -1015,7 +1040,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
req0.request_id] req0.request_id]
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... | # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
req1 = make_request("1", common_token_ids * 2, block_size, hash) req1 = make_request("1", common_token_ids * 2, block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
assert computed_blocks.blocks[0] == block_part0 assert computed_blocks.blocks[0] == block_part0
assert num_computed_tokens == 3 * 16 assert num_computed_tokens == 3 * 16
...@@ -1032,7 +1057,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): ...@@ -1032,7 +1057,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
# | Req1-5(F)| Req2-0 | Req2-1 | ... | # | Req1-5(F)| Req2-0 | Req2-1 | ... |
req2 = make_request("2", [7] * block_size * 2, block_size, hash) req2 = make_request("2", [7] * block_size * 2, block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
assert not computed_blocks.blocks[0] assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0 assert num_computed_tokens == 0
...@@ -1044,7 +1069,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): ...@@ -1044,7 +1069,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# but it cannot be allocated due to insufficient free blocks (2). # but it cannot be allocated due to insufficient free blocks (2).
# In this case, the ref_cnt of the computed blocks should not be changed. # In this case, the ref_cnt of the computed blocks should not be changed.
assert manager.block_pool.free_block_queue.num_free_blocks == 5 assert manager.block_pool.free_block_queue.num_free_blocks == 5
req3 = make_request("3", common_token_ids * 3, block_size, hash) req3 = make_request("3", common_token_ids * 3, block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
assert computed_blocks.blocks[0] == block_part1 assert computed_blocks.blocks[0] == block_part1
assert num_computed_tokens == 6 * 16 assert num_computed_tokens == 6 * 16
...@@ -1069,13 +1094,13 @@ def test_reset_prefix_cache(): ...@@ -1069,13 +1094,13 @@ def test_reset_prefix_cache():
full_block_token_ids = [i for i in range(3) for _ in range(16)] full_block_token_ids = [i for i in range(3) for _ in range(16)]
unique_token_ids = [3] * 7 unique_token_ids = [3] * 7
all_token_ids = full_block_token_ids + unique_token_ids all_token_ids = full_block_token_ids + unique_token_ids
req0 = make_request("0", all_token_ids, block_size, hash) req0 = make_request("0", all_token_ids, block_size, sha256)
blocks = manager.allocate_slots(req0, 55) blocks = manager.allocate_slots(req0, 55)
assert blocks.get_block_ids() == ([1, 2, 3, 4], ) assert blocks.get_block_ids() == ([1, 2, 3, 4], )
unique_token_ids = [4] * 7 unique_token_ids = [4] * 7
all_token_ids = full_block_token_ids + unique_token_ids all_token_ids = full_block_token_ids + unique_token_ids
req1 = make_request("1", all_token_ids, block_size, hash) req1 = make_request("1", all_token_ids, block_size, sha256)
computed_blocks, _ = manager.get_computed_blocks(req1) computed_blocks, _ = manager.get_computed_blocks(req1)
assert len(req1.block_hashes) == 3 assert len(req1.block_hashes) == 3
assert len(computed_blocks.blocks[0]) == 3 assert len(computed_blocks.blocks[0]) == 3
...@@ -1109,7 +1134,7 @@ def test_prefix_cache_stats_disabled(): ...@@ -1109,7 +1134,7 @@ def test_prefix_cache_stats_disabled():
assert manager.prefix_cache_stats is None assert manager.prefix_cache_stats is None
# Call all functions that check whether log_stats is disabled. # Call all functions that check whether log_stats is disabled.
req = make_request("0", list(range(16)), block_size, hash) req = make_request("0", list(range(16)), block_size, sha256)
computed_blocks, num_computed_tokens = manager.get_computed_blocks(req) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
assert not computed_blocks.blocks[0] assert not computed_blocks.blocks[0]
assert num_computed_tokens == 0 assert num_computed_tokens == 0
...@@ -1124,15 +1149,9 @@ def test_prefix_cache_stats_disabled(): ...@@ -1124,15 +1149,9 @@ def test_prefix_cache_stats_disabled():
def test_maybe_evict_cached_block(): def test_maybe_evict_cached_block():
pool = BlockPool(num_gpu_blocks=4, enable_caching=True) pool = BlockPool(num_gpu_blocks=4, enable_caching=True)
block_hash0 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=10, block_hash0 = make_block_hash_with_group_id(BlockHash(b"10"), 1000)
token_ids=(100, )), block_hash1 = make_block_hash_with_group_id(BlockHash(b"20"), 2000)
group_id=1000) block_hash2 = make_block_hash_with_group_id(BlockHash(b"30"), 3000)
block_hash1 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=20,
token_ids=(200, )),
group_id=2000)
block_hash2 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=30,
token_ids=(300, )),
group_id=3000)
block_hashes = [ block_hashes = [
block_hash0, block_hash0,
block_hash1, block_hash1,
...@@ -1206,7 +1225,7 @@ def test_kv_cache_events(blocks_to_cache: int): ...@@ -1206,7 +1225,7 @@ def test_kv_cache_events(blocks_to_cache: int):
) )
num_tokens = block_size * blocks_to_cache num_tokens = block_size * blocks_to_cache
req0 = make_request("0", list(range(num_tokens)), block_size, hash) req0 = make_request("0", list(range(num_tokens)), block_size, sha256)
_ = manager.allocate_slots(req0, num_tokens) _ = manager.allocate_slots(req0, num_tokens)
events = manager.take_events() events = manager.take_events()
...@@ -1222,7 +1241,7 @@ def test_kv_cache_events(blocks_to_cache: int): ...@@ -1222,7 +1241,7 @@ def test_kv_cache_events(blocks_to_cache: int):
# Should see block_to_cache number of removed block events and a new block # Should see block_to_cache number of removed block events and a new block
# stored event # stored event
manager.free(req0) manager.free(req0)
req1 = make_request("1", list(range(num_tokens)), block_size, hash) req1 = make_request("1", list(range(num_tokens)), block_size, sha256)
_ = manager.allocate_slots(req1, num_tokens) _ = manager.allocate_slots(req1, num_tokens)
events = manager.take_events() events = manager.take_events()
...@@ -1256,7 +1275,7 @@ def test_eagle_enabled_removes_last_block(): ...@@ -1256,7 +1275,7 @@ def test_eagle_enabled_removes_last_block():
# Request with 3 full blocks (48 tokens) # Request with 3 full blocks (48 tokens)
token_ids = [0] * (3 * block_size) token_ids = [0] * (3 * block_size)
req = make_request("divisible_request", token_ids, block_size, hash) req = make_request("divisible_request", token_ids, block_size, sha256)
# Prime the cache # Prime the cache
computed_blocks, _ = manager.get_computed_blocks(req) computed_blocks, _ = manager.get_computed_blocks(req)
...@@ -1266,7 +1285,7 @@ def test_eagle_enabled_removes_last_block(): ...@@ -1266,7 +1285,7 @@ def test_eagle_enabled_removes_last_block():
manager.free(req) manager.free(req)
# New request with same tokens + Eagle enabled # New request with same tokens + Eagle enabled
req_eagle = make_request("eagle_divisible", token_ids, block_size, hash) req_eagle = make_request("eagle_divisible", token_ids, block_size, sha256)
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
# Should retain 1 block: # Should retain 1 block:
...@@ -1287,7 +1306,7 @@ def test_eagle_with_partial_blocks(): ...@@ -1287,7 +1306,7 @@ def test_eagle_with_partial_blocks():
) )
# 2 full blocks + 5 tokens (non-divisible length) # 2 full blocks + 5 tokens (non-divisible length)
token_ids = [0] * (2 * block_size + 5) token_ids = [0] * (2 * block_size + 5)
req = make_request("partial_block_test", token_ids, block_size, hash) req = make_request("partial_block_test", token_ids, block_size, sha256)
# Prime the cache # Prime the cache
computed_blocks, _ = manager.get_computed_blocks(req) computed_blocks, _ = manager.get_computed_blocks(req)
...@@ -1297,7 +1316,7 @@ def test_eagle_with_partial_blocks(): ...@@ -1297,7 +1316,7 @@ def test_eagle_with_partial_blocks():
manager.free(req) manager.free(req)
# New request with Eagle enabled # New request with Eagle enabled
req_eagle = make_request("partial_eagle", token_ids, block_size, hash) req_eagle = make_request("partial_eagle", token_ids, block_size, sha256)
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining # Original match: 2 full blocks → Eagle removes 1 → 1 remaining
assert len(computed_blocks.blocks[0]) == 1 assert len(computed_blocks.blocks[0]) == 1
...@@ -1328,7 +1347,7 @@ def test_eagle_with_sliding_window(): ...@@ -1328,7 +1347,7 @@ def test_eagle_with_sliding_window():
# 2 full blocks + 5 tokens (non-divisible length) # 2 full blocks + 5 tokens (non-divisible length)
token_ids = [0] * (2 * block_size + 5) token_ids = [0] * (2 * block_size + 5)
req = make_request("partial_block_test", token_ids, block_size, hash) req = make_request("partial_block_test", token_ids, block_size, sha256)
# Prime the cache # Prime the cache
computed_blocks, _ = manager.get_computed_blocks(req) computed_blocks, _ = manager.get_computed_blocks(req)
...@@ -1341,7 +1360,7 @@ def test_eagle_with_sliding_window(): ...@@ -1341,7 +1360,7 @@ def test_eagle_with_sliding_window():
manager.free(req) manager.free(req)
# New request with Eagle enabled # New request with Eagle enabled
req_eagle = make_request("partial_eagle", token_ids, block_size, hash) req_eagle = make_request("partial_eagle", token_ids, block_size, sha256)
computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle) computed_blocks, num_tokens = manager.get_computed_blocks(req_eagle)
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining # Original match: 2 full blocks → Eagle removes 1 → 1 remaining
assert len(computed_blocks.blocks[0]) == 1 assert len(computed_blocks.blocks[0]) == 1
...@@ -1351,11 +1370,11 @@ def test_eagle_with_sliding_window(): ...@@ -1351,11 +1370,11 @@ def test_eagle_with_sliding_window():
assert manager.block_pool.get_cached_block( assert manager.block_pool.get_cached_block(
block_hash_first_block, kv_cache_group_ids=[0]) is not None block_hash_first_block, kv_cache_group_ids=[0]) is not None
manager.block_pool.cached_block_hash_to_block.pop( manager.block_pool.cached_block_hash_to_block.pop(
BlockHashWithGroupId(block_hash_first_block, 0)) make_block_hash_with_group_id(block_hash_first_block, 0))
# New request # New request
req_after_evict = make_request("partial_eagle_after_evict", token_ids, req_after_evict = make_request("partial_eagle_after_evict", token_ids,
block_size, hash) block_size, sha256)
computed_blocks, num_tokens = manager.get_computed_blocks(req_after_evict) computed_blocks, num_tokens = manager.get_computed_blocks(req_after_evict)
# Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is # Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is
# not considered. But after dropping the last matched block due to eagle, # not considered. But after dropping the last matched block due to eagle,
......
...@@ -6,8 +6,8 @@ import random ...@@ -6,8 +6,8 @@ import random
import torch import torch
from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
KVCacheBlock) make_block_hash_with_group_id)
from vllm.v1.core.single_type_kv_cache_manager import ( from vllm.v1.core.single_type_kv_cache_manager import (
ChunkedLocalAttentionManager, SlidingWindowManager) ChunkedLocalAttentionManager, SlidingWindowManager)
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
...@@ -44,7 +44,7 @@ def test_chunked_local_attention_possible_cached_prefix(): ...@@ -44,7 +44,7 @@ def test_chunked_local_attention_possible_cached_prefix():
def run_one_case(block_is_cached, tail_token, expect_length): def run_one_case(block_is_cached, tail_token, expect_length):
block_hash_list = [ block_hash_list = [
BlockHash(i, ()) for i in range(len(block_is_cached)) BlockHash(str(i).encode()) for i in range(len(block_is_cached))
] ]
block_pool.cached_block_hash_to_block.clear() block_pool.cached_block_hash_to_block.clear()
...@@ -53,8 +53,8 @@ def test_chunked_local_attention_possible_cached_prefix(): ...@@ -53,8 +53,8 @@ def test_chunked_local_attention_possible_cached_prefix():
for i, (block_hash, for i, (block_hash,
is_cached) in enumerate(zip(block_hash_list, block_is_cached)): is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
if is_cached: if is_cached:
block_pool.cached_block_hash_to_block[BlockHashWithGroupId( block_pool.cached_block_hash_to_block[
block_hash, 0)] = { make_block_hash_with_group_id(block_hash, 0)] = {
i: block_pool.blocks[i + 10], i: block_pool.blocks[i + 10],
} }
...@@ -109,7 +109,7 @@ def test_sliding_window_possible_cached_prefix(): ...@@ -109,7 +109,7 @@ def test_sliding_window_possible_cached_prefix():
def run_one_case(block_is_cached, expect_length): def run_one_case(block_is_cached, expect_length):
block_hash_list = [ block_hash_list = [
BlockHash(i, ()) for i in range(len(block_is_cached)) BlockHash(str(i).encode()) for i in range(len(block_is_cached))
] ]
block_pool.cached_block_hash_to_block.clear() block_pool.cached_block_hash_to_block.clear()
...@@ -118,8 +118,8 @@ def test_sliding_window_possible_cached_prefix(): ...@@ -118,8 +118,8 @@ def test_sliding_window_possible_cached_prefix():
for i, (block_hash, for i, (block_hash,
is_cached) in enumerate(zip(block_hash_list, block_is_cached)): is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
if is_cached: if is_cached:
block_pool.cached_block_hash_to_block[BlockHashWithGroupId( block_pool.cached_block_hash_to_block[
block_hash, 0)] = { make_block_hash_with_group_id(block_hash, 0)] = {
i: block_pool.blocks[i + 10], i: block_pool.blocks[i + 10],
} }
......
...@@ -9,6 +9,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, ...@@ -9,6 +9,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
from vllm.multimodal.inputs import (MultiModalFeatureSpec, from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange) MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import sha256
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
init_none_hash) init_none_hash)
from vllm.v1.core.sched.async_scheduler import AsyncScheduler from vllm.v1.core.sched.async_scheduler import AsyncScheduler
...@@ -130,10 +131,10 @@ def create_requests( ...@@ -130,10 +131,10 @@ def create_requests(
) -> list[Request]: ) -> list[Request]:
global _none_hash_initialized global _none_hash_initialized
if not _none_hash_initialized: if not _none_hash_initialized:
init_none_hash(hash) init_none_hash(sha256)
_none_hash_initialized = True _none_hash_initialized = True
block_hasher = get_request_block_hasher(block_size, hash) block_hasher = get_request_block_hasher(block_size, sha256)
sampling_params = SamplingParams(ignore_eos=False, sampling_params = SamplingParams(ignore_eos=False,
max_tokens=max_tokens, max_tokens=max_tokens,
stop_token_ids=stop_token_ids, stop_token_ids=stop_token_ids,
......
...@@ -62,6 +62,16 @@ backend_configs = { ...@@ -62,6 +62,16 @@ backend_configs = {
"cudagraph_mode": "FULL_AND_PIECEWISE", "cudagraph_mode": "FULL_AND_PIECEWISE",
}, },
specific_gpu_arch=(9, 0)), specific_gpu_arch=(9, 0)),
# FlashAttention MLA on Hopper
"FlashAttentionMLA":
BackendConfig(name="FlashAttentionMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
},
comp_config={
"cudagraph_mode": "FULL_DECODE_ONLY",
},
specific_gpu_arch=(9, 0)),
# FA2 # FA2
"FA2": "FA2":
BackendConfig(name="FA2", BackendConfig(name="FA2",
......
...@@ -83,7 +83,7 @@ def test_ngram_correctness( ...@@ -83,7 +83,7 @@ def test_ngram_correctness(
model_name: str, model_name: str,
): ):
''' '''
Compare the outputs of a original LLM and a speculative LLM Compare the outputs of an original LLM and a speculative LLM
should be the same when using ngram speculative decoding. should be the same when using ngram speculative decoding.
''' '''
with monkeypatch.context() as m: with monkeypatch.context() as m:
...@@ -117,45 +117,38 @@ def test_ngram_correctness( ...@@ -117,45 +117,38 @@ def test_ngram_correctness(
print(f"ref_output: {ref_output.outputs[0].text}") print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}") print(f"spec_output: {spec_output.outputs[0].text}")
# Heuristic: expect at least 70% of the prompts to match exactly # Heuristic: expect at least 68% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy. # Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.7 * len(ref_outputs)) assert matches >= int(0.68 * len(ref_outputs))
del spec_llm del spec_llm
torch.cuda.empty_cache() torch.cuda.empty_cache()
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
@pytest.mark.parametrize( @pytest.mark.parametrize(["model_setup", "mm_enabled"], [
["model_setup", "mm_enabled"], (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
[ (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
# TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501 "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
# (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False), (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
(("eagle", "meta-llama/Llama-3.1-8B-Instruct", "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False), pytest.param(
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct", ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False), "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
pytest.param( False,
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), pytest.param(
False, ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
pytest.param( True,
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), (("eagle", "eagle618/deepseek-v3-random",
True, "eagle618/eagle-deepseek-v3-random", 1), False),
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), ],
(("eagle", "eagle618/deepseek-v3-random", ids=[
"eagle618/eagle-deepseek-v3-random", 1), False), "qwen3_eagle3", "llama3_eagle", "llama3_eagle3",
], "llama4_eagle", "llama4_eagle_mm",
ids=[ "deepseek_eagle"
# TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501 ])
# "qwen3_eagle3",
"llama3_eagle",
"llama3_eagle3",
"llama4_eagle",
"llama4_eagle_mm",
"deepseek_eagle"
])
@pytest.mark.parametrize("attn_backend", @pytest.mark.parametrize("attn_backend",
get_attn_backend_list_based_on_platform()) get_attn_backend_list_based_on_platform())
def test_eagle_correctness( def test_eagle_correctness(
...@@ -169,7 +162,7 @@ def test_eagle_correctness( ...@@ -169,7 +162,7 @@ def test_eagle_correctness(
# TODO: Fix this flaky test # TODO: Fix this flaky test
pytest.skip( pytest.skip(
"TREE_ATTN is flaky in the test disable for now until it can be " "TREE_ATTN is flaky in the test disable for now until it can be "
"reolved (see https://github.com/vllm-project/vllm/issues/22922)") "resolved (see https://github.com/vllm-project/vllm/issues/22922)")
# Generate test prompts inside the function instead of using fixture # Generate test prompts inside the function instead of using fixture
test_prompts = get_test_prompts(mm_enabled) test_prompts = get_test_prompts(mm_enabled)
......
...@@ -393,7 +393,7 @@ class MockLoggingStatLogger(LoggingStatLogger): ...@@ -393,7 +393,7 @@ class MockLoggingStatLogger(LoggingStatLogger):
async def test_customize_loggers(monkeypatch): async def test_customize_loggers(monkeypatch):
"""Test that we can customize the loggers. """Test that we can customize the loggers.
If a customized logger is provided at the init, it should If a customized logger is provided at the init, it should
be used directly. be added to the default loggers.
""" """
with monkeypatch.context() as m, ExitStack() as after: with monkeypatch.context() as m, ExitStack() as after:
...@@ -410,7 +410,8 @@ async def test_customize_loggers(monkeypatch): ...@@ -410,7 +410,8 @@ async def test_customize_loggers(monkeypatch):
stat_loggers = engine.logger_manager.per_engine_logger_dict stat_loggers = engine.logger_manager.per_engine_logger_dict
assert len(stat_loggers) == 1 assert len(stat_loggers) == 1
assert len(stat_loggers[0]) == 1 assert len(
stat_loggers[0]) == 2 # LoggingStatLogger + MockLoggingStatLogger
stat_loggers[0][0].log.assert_called_once() stat_loggers[0][0].log.assert_called_once()
......
...@@ -36,18 +36,19 @@ def test_prefix_caching_from_cli(): ...@@ -36,18 +36,19 @@ def test_prefix_caching_from_cli():
assert vllm_config.cache_config.enable_prefix_caching assert vllm_config.cache_config.enable_prefix_caching
# default hash algorithm is "builtin" # default hash algorithm is "sha256"
assert vllm_config.cache_config.prefix_caching_hash_algo == "builtin" assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256"
# set hash algorithm to sha256_cbor
args = parser.parse_args(["--prefix-caching-hash-algo", "sha256_cbor"])
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
assert vllm_config.cache_config.prefix_caching_hash_algo == \
"sha256_cbor"
# set hash algorithm to sha256 # set hash algorithm to sha256
args = parser.parse_args(["--prefix-caching-hash-algo", "sha256"]) args = parser.parse_args(["--prefix-caching-hash-algo", "sha256"])
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config() vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256" assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256"
# set hash algorithm to builtin
args = parser.parse_args(["--prefix-caching-hash-algo", "builtin"])
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
assert vllm_config.cache_config.prefix_caching_hash_algo == "builtin"
# an invalid hash algorithm raises an error # an invalid hash algorithm raises an error
parser.exit_on_error = False parser.exit_on_error = False
with pytest.raises(ArgumentError): with pytest.raises(ArgumentError):
......
...@@ -152,8 +152,8 @@ def test_multi_modal_uuids_accepts_none_and_passes_through( ...@@ -152,8 +152,8 @@ def test_multi_modal_uuids_accepts_none_and_passes_through(
*, *,
tokenization_kwargs=None, tokenization_kwargs=None,
lora_request=None, lora_request=None,
mm_hash_overrides=None): mm_uuids=None):
captured["mm_hash_overrides"] = mm_hash_overrides captured["mm_uuids"] = mm_uuids
# Minimal processed inputs for decoder-only flow # Minimal processed inputs for decoder-only flow
return {"type": "token", "prompt_token_ids": [1]} return {"type": "token", "prompt_token_ids": [1]}
...@@ -180,7 +180,7 @@ def test_multi_modal_uuids_accepts_none_and_passes_through( ...@@ -180,7 +180,7 @@ def test_multi_modal_uuids_accepts_none_and_passes_through(
params=SamplingParams(), params=SamplingParams(),
) )
assert captured["mm_hash_overrides"] == mm_uuids assert captured["mm_uuids"] == mm_uuids
def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch): def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
...@@ -196,8 +196,8 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch): ...@@ -196,8 +196,8 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
*, *,
tokenization_kwargs=None, tokenization_kwargs=None,
lora_request=None, lora_request=None,
mm_hash_overrides=None): mm_uuids=None):
captured["mm_hash_overrides"] = mm_hash_overrides captured["mm_uuids"] = mm_uuids
return {"type": "token", "prompt_token_ids": [1]} return {"type": "token", "prompt_token_ids": [1]}
monkeypatch.setattr(processor.input_preprocessor, monkeypatch.setattr(processor.input_preprocessor,
...@@ -223,7 +223,7 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch): ...@@ -223,7 +223,7 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
) )
# Expect request-id-based overrides are passed through # Expect request-id-based overrides are passed through
assert captured["mm_hash_overrides"] == { assert captured["mm_uuids"] == {
"image": [f"{request_id}-image-0", f"{request_id}-image-1"], "image": [f"{request_id}-image-0", f"{request_id}-image-1"],
"video": [f"{request_id}-video-0"], "video": [f"{request_id}-video-0"],
} }
...@@ -46,12 +46,12 @@ PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [ ...@@ -46,12 +46,12 @@ PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None), ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None),
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None), ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None),
("Qwen/Qwen2.5-1.5B-Instruct", "lm-format-enforcer", "auto", None), ("Qwen/Qwen2.5-1.5B-Instruct", "lm-format-enforcer", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None), #FIXME: These tests are flaky on CI thus disabled. Tracking in Issue #24402
("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None), # ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None),
# ("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None),
#("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"),
("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto",
NGRAM_SPEC_CONFIG), NGRAM_SPEC_CONFIG),
#FIXME: This test is flaky on CI thus disabled
#("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"),
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto",
NGRAM_SPEC_CONFIG), NGRAM_SPEC_CONFIG),
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", NGRAM_SPEC_CONFIG), ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", NGRAM_SPEC_CONFIG),
...@@ -122,6 +122,7 @@ def test_structured_output( ...@@ -122,6 +122,7 @@ def test_structured_output(
guided_decoding_backend=guided_decoding_backend, guided_decoding_backend=guided_decoding_backend,
guided_decoding_disable_any_whitespace=(guided_decoding_backend guided_decoding_disable_any_whitespace=(guided_decoding_backend
in {"xgrammar", "guidance"}), in {"xgrammar", "guidance"}),
seed=120,
tokenizer_mode=tokenizer_mode, tokenizer_mode=tokenizer_mode,
speculative_config=speculative_config) speculative_config=speculative_config)
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import openai # use the official client for correctness check import openai # use the official client for correctness check
import openai.types.responses as openai_responses_types
import pytest import pytest
...@@ -86,3 +87,18 @@ async def test_logprobs(client: openai.AsyncOpenAI): ...@@ -86,3 +87,18 @@ async def test_logprobs(client: openai.AsyncOpenAI):
outputs = response.output outputs = response.output
assert outputs[-1].content[-1].logprobs assert outputs[-1].content[-1].logprobs
assert len(outputs[-1].content[-1].logprobs[0].top_logprobs) == 5 assert len(outputs[-1].content[-1].logprobs[0].top_logprobs) == 5
@pytest.mark.asyncio
async def test_streaming(client: openai.AsyncOpenAI):
    """Check the event sequence produced by a streamed Responses request.

    The stream must open with a ``ResponseCreatedEvent``, contain at least
    one ``ResponseTextDeltaEvent``, and close with a
    ``ResponseCompletedEvent``.
    """
    stream = await client.responses.create(
        input="What is 13 * 24?",
        stream=True,
    )

    # Drain the whole stream first so the ordering checks see every event.
    received = []
    async for item in stream:
        received.append(item)

    first, last = received[0], received[-1]
    assert isinstance(first, openai_responses_types.ResponseCreatedEvent)

    saw_text_delta = False
    for item in received:
        if isinstance(item, openai_responses_types.ResponseTextDeltaEvent):
            saw_text_delta = True
            break
    assert saw_text_delta

    assert isinstance(last, openai_responses_types.ResponseCompletedEvent)
...@@ -8,17 +8,17 @@ import pytest ...@@ -8,17 +8,17 @@ import pytest
import pytest_asyncio import pytest_asyncio
from tests.utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
from vllm.multimodal.utils import encode_image_base64, fetch_image from vllm.multimodal.utils import encode_image_base64
# Use a small vision model for testing # Use a small vision model for testing
MODEL_NAME = "Qwen/Qwen2.5-VL-3B-Instruct" MODEL_NAME = "Qwen/Qwen2.5-VL-3B-Instruct"
MAXIMUM_IMAGES = 2 MAXIMUM_IMAGES = 2
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS = [ TEST_IMAGE_ASSETS = [
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", "2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", # "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
"https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png", "Grayscale_8bits_palette_sample_image.png", # "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png",
"https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png", "1280px-Venn_diagram_rgb.svg.png", # "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png",
"https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", "RGBA_comp.png", # "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
] ]
...@@ -52,16 +52,17 @@ async def client(image_server): ...@@ -52,16 +52,17 @@ async def client(image_server):
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def base64_encoded_image() -> dict[str, str]: def base64_encoded_image(local_asset_server) -> dict[str, str]:
return { return {
image_url: encode_image_base64(fetch_image(image_url)) image_url:
for image_url in TEST_IMAGE_URLS encode_image_base64(local_asset_server.get_image_asset(image_url))
for image_url in TEST_IMAGE_ASSETS
} }
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) @pytest.mark.parametrize("image_url", TEST_IMAGE_ASSETS, indirect=True)
async def test_single_chat_session_image(client: openai.AsyncOpenAI, async def test_single_chat_session_image(client: openai.AsyncOpenAI,
model_name: str, image_url: str): model_name: str, image_url: str):
content_text = "What's in this image?" content_text = "What's in this image?"
...@@ -91,11 +92,11 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI, ...@@ -91,11 +92,11 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) @pytest.mark.parametrize("raw_image_url", TEST_IMAGE_ASSETS)
async def test_single_chat_session_image_base64encoded( async def test_single_chat_session_image_base64encoded(
client: openai.AsyncOpenAI, client: openai.AsyncOpenAI,
model_name: str, model_name: str,
image_url: str, raw_image_url: str,
base64_encoded_image: dict[str, str], base64_encoded_image: dict[str, str],
): ):
content_text = "What's in this image?" content_text = "What's in this image?"
...@@ -106,7 +107,7 @@ async def test_single_chat_session_image_base64encoded( ...@@ -106,7 +107,7 @@ async def test_single_chat_session_image_base64encoded(
{ {
"type": "input_image", "type": "input_image",
"image_url": "image_url":
f"data:image/jpeg;base64,{base64_encoded_image[image_url]}", f"data:image/jpeg;base64,{base64_encoded_image[raw_image_url]}",
"detail": "auto", "detail": "auto",
}, },
{ {
...@@ -127,7 +128,8 @@ async def test_single_chat_session_image_base64encoded( ...@@ -127,7 +128,8 @@ async def test_single_chat_session_image_base64encoded(
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"image_urls", "image_urls",
[TEST_IMAGE_URLS[:i] for i in range(2, len(TEST_IMAGE_URLS))]) [TEST_IMAGE_ASSETS[:i] for i in range(2, len(TEST_IMAGE_ASSETS))],
indirect=True)
async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str, async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str,
image_urls: list[str]): image_urls: list[str]):
messages = [{ messages = [{
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment