Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
38d80967
Commit
38d80967
authored
Sep 12, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.10.2rc2' into v0.10.2rc2-ori
parents
33650733
880c741b
Changes
544
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
503 additions
and
381 deletions
+503
-381
tests/transformers_utils/__init__.py
tests/transformers_utils/__init__.py
+0
-0
tests/transformers_utils/test_config_parser_registry.py
tests/transformers_utils/test_config_parser_registry.py
+37
-0
tests/utils.py
tests/utils.py
+101
-28
tests/utils_/test_utils.py
tests/utils_/test_utils.py
+9
-11
tests/v1/attention/test_attention_backends.py
tests/v1/attention/test_attention_backends.py
+0
-16
tests/v1/attention/test_chunked_local_attention.py
tests/v1/attention/test_chunked_local_attention.py
+1
-1
tests/v1/attention/test_mla_backends.py
tests/v1/attention/test_mla_backends.py
+101
-106
tests/v1/attention/utils.py
tests/v1/attention/utils.py
+2
-0
tests/v1/core/test_kv_cache_utils.py
tests/v1/core/test_kv_cache_utils.py
+28
-39
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+122
-103
tests/v1/core/test_single_type_kv_cache_manager.py
tests/v1/core/test_single_type_kv_cache_manager.py
+8
-8
tests/v1/core/utils.py
tests/v1/core/utils.py
+3
-2
tests/v1/cudagraph/test_cudagraph_mode.py
tests/v1/cudagraph/test_cudagraph_mode.py
+10
-0
tests/v1/e2e/test_spec_decode.py
tests/v1/e2e/test_spec_decode.py
+28
-35
tests/v1/engine/test_async_llm.py
tests/v1/engine/test_async_llm.py
+3
-2
tests/v1/engine/test_engine_args.py
tests/v1/engine/test_engine_args.py
+7
-6
tests/v1/engine/test_processor_multi_modal_uuids.py
tests/v1/engine/test_processor_multi_modal_uuids.py
+6
-6
tests/v1/entrypoints/llm/test_struct_output_generate.py
tests/v1/entrypoints/llm/test_struct_output_generate.py
+5
-4
tests/v1/entrypoints/openai/responses/test_basic.py
tests/v1/entrypoints/openai/responses/test_basic.py
+16
-0
tests/v1/entrypoints/openai/responses/test_image.py
tests/v1/entrypoints/openai/responses/test_image.py
+16
-14
No files found.
Too many changes to show.
To preserve performance only
544 of 544+
files are displayed.
Plain diff
Email patch
tests/transformers_utils/__init__.py
0 → 100644
View file @
38d80967
tests/transformers_utils/test_config_parser_registry.py
0 → 100644
View file @
38d80967
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
pathlib
import
Path
from
typing
import
Optional
,
Union
import
pytest
from
transformers
import
PretrainedConfig
from
vllm.transformers_utils.config
import
(
get_config_parser
,
register_config_parser
)
from
vllm.transformers_utils.config_parser_base
import
ConfigParserBase
@
register_config_parser
(
"custom_config_parser"
)
class
CustomConfigParser
(
ConfigParserBase
):
def
parse
(
self
,
model
:
Union
[
str
,
Path
],
trust_remote_code
:
bool
,
revision
:
Optional
[
str
]
=
None
,
code_revision
:
Optional
[
str
]
=
None
,
**
kwargs
)
->
tuple
[
dict
,
PretrainedConfig
]:
raise
NotImplementedError
def
test_register_config_parser
():
assert
isinstance
(
get_config_parser
(
"custom_config_parser"
),
CustomConfigParser
)
def
test_invalid_config_parser
():
with
pytest
.
raises
(
ValueError
):
@
register_config_parser
(
"invalid_config_parser"
)
class
InvalidConfigParser
:
pass
tests/utils.py
View file @
38d80967
...
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
import
contextlib
import
copy
import
functools
import
importlib
...
...
@@ -13,10 +14,11 @@ import sys
import
tempfile
import
time
import
warnings
from
contextlib
import
contextmanager
,
suppress
from
contextlib
import
ExitStack
,
contextmanager
,
suppress
from
multiprocessing
import
Process
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Literal
,
Optional
,
Union
from
unittest.mock
import
patch
import
cloudpickle
import
httpx
...
...
@@ -799,43 +801,106 @@ _P = ParamSpec("_P")
def
fork_new_process_for_each_test
(
f
:
Callable
[
_P
,
None
])
->
Callable
[
_P
,
None
]:
f
unc
:
Callable
[
_P
,
None
])
->
Callable
[
_P
,
None
]:
"""Decorator to fork a new process for each test function.
See https://github.com/vllm-project/vllm/issues/7053 for more details.
"""
@
functools
.
wraps
(
f
)
@
functools
.
wraps
(
f
unc
)
def
wrapper
(
*
args
:
_P
.
args
,
**
kwargs
:
_P
.
kwargs
)
->
None
:
# Make the process the leader of its own process group
# to avoid sending SIGTERM to the parent process
os
.
setpgrp
()
from
_pytest.outcomes
import
Skipped
pid
=
os
.
fork
()
print
(
f
"Fork a new process to run a test
{
pid
}
"
)
if
pid
==
0
:
try
:
f
(
*
args
,
**
kwargs
)
except
Skipped
as
e
:
# convert Skipped to exit code 0
print
(
str
(
e
))
os
.
_exit
(
0
)
except
Exception
:
import
traceback
traceback
.
print_exc
()
os
.
_exit
(
1
)
# Create a unique temporary file to store exception info from child
# process. Use test function name and process ID to avoid collisions.
with
tempfile
.
NamedTemporaryFile
(
delete
=
False
,
mode
=
'w+b'
,
prefix
=
f
"vllm_test_
{
func
.
__name__
}
_
{
os
.
getpid
()
}
_"
,
suffix
=
".exc"
)
as
exc_file
,
ExitStack
()
as
delete_after
:
exc_file_path
=
exc_file
.
name
delete_after
.
callback
(
os
.
remove
,
exc_file_path
)
pid
=
os
.
fork
()
print
(
f
"Fork a new process to run a test
{
pid
}
"
)
if
pid
==
0
:
# Parent process responsible for deleting, don't delete
# in child.
delete_after
.
pop_all
()
try
:
func
(
*
args
,
**
kwargs
)
except
Skipped
as
e
:
# convert Skipped to exit code 0
print
(
str
(
e
))
os
.
_exit
(
0
)
except
Exception
as
e
:
import
traceback
tb_string
=
traceback
.
format_exc
()
# Try to serialize the exception object first
exc_to_serialize
:
dict
[
str
,
Any
]
try
:
# First, try to pickle the actual exception with
# its traceback.
exc_to_serialize
=
{
'pickled_exception'
:
e
}
# Test if it can be pickled
cloudpickle
.
dumps
(
exc_to_serialize
)
except
(
Exception
,
KeyboardInterrupt
):
# Fall back to string-based approach.
exc_to_serialize
=
{
'exception_type'
:
type
(
e
).
__name__
,
'exception_msg'
:
str
(
e
),
'traceback'
:
tb_string
,
}
try
:
with
open
(
exc_file_path
,
'wb'
)
as
f
:
cloudpickle
.
dump
(
exc_to_serialize
,
f
)
except
Exception
:
# Fallback: just print the traceback.
print
(
tb_string
)
os
.
_exit
(
1
)
else
:
os
.
_exit
(
0
)
else
:
os
.
_exit
(
0
)
else
:
pgid
=
os
.
getpgid
(
pid
)
_pid
,
_exitcode
=
os
.
waitpid
(
pid
,
0
)
# ignore SIGTERM signal itself
old_signal_handler
=
signal
.
signal
(
signal
.
SIGTERM
,
signal
.
SIG_IGN
)
# kill all child processes
os
.
killpg
(
pgid
,
signal
.
SIGTERM
)
# restore the signal handler
signal
.
signal
(
signal
.
SIGTERM
,
old_signal_handler
)
assert
_exitcode
==
0
,
(
f
"function
{
f
}
failed when called with"
f
" args
{
args
}
and kwargs
{
kwargs
}
"
)
pgid
=
os
.
getpgid
(
pid
)
_pid
,
_exitcode
=
os
.
waitpid
(
pid
,
0
)
# ignore SIGTERM signal itself
old_signal_handler
=
signal
.
signal
(
signal
.
SIGTERM
,
signal
.
SIG_IGN
)
# kill all child processes
os
.
killpg
(
pgid
,
signal
.
SIGTERM
)
# restore the signal handler
signal
.
signal
(
signal
.
SIGTERM
,
old_signal_handler
)
if
_exitcode
!=
0
:
# Try to read the exception from the child process
exc_info
=
{}
if
os
.
path
.
exists
(
exc_file_path
):
with
contextlib
.
suppress
(
Exception
),
\
open
(
exc_file_path
,
'rb'
)
as
f
:
exc_info
=
cloudpickle
.
load
(
f
)
if
(
original_exception
:
=
exc_info
.
get
(
'pickled_exception'
))
is
not
None
:
# Re-raise the actual exception object if it was
# successfully pickled.
assert
isinstance
(
original_exception
,
Exception
)
raise
original_exception
if
(
original_tb
:
=
exc_info
.
get
(
"traceback"
))
is
not
None
:
# Use string-based traceback for fallback case
raise
AssertionError
(
f
"Test
{
func
.
__name__
}
failed when called with"
f
" args
{
args
}
and kwargs
{
kwargs
}
"
f
" (exit code:
{
_exitcode
}
):
\n
{
original_tb
}
"
)
from
None
# Fallback to the original generic error
raise
AssertionError
(
f
"function
{
func
.
__name__
}
failed when called with"
f
" args
{
args
}
and kwargs
{
kwargs
}
"
f
" (exit code:
{
_exitcode
}
)"
)
from
None
return
wrapper
...
...
@@ -1077,3 +1142,11 @@ def get_attn_backend_list_based_on_platform() -> list[str]:
return
attn_backend_list
else
:
raise
ValueError
(
"Unsupported platform"
)
@
contextmanager
def
override_cutlass_fp8_supported
(
value
:
bool
):
with
patch
(
"vllm.model_executor.layers.quantization.utils.w8a8_utils.cutlass_fp8_supported"
,
return_value
=
value
):
yield
tests/utils_/test_utils.py
View file @
38d80967
...
...
@@ -835,22 +835,20 @@ def test_model_specification(parser_with_config, cli_config_file,
@
pytest
.
mark
.
parametrize
(
"input"
,
[(),
(
"abc"
,
),
(
None
,
),
(
None
,
bool
,
[
1
,
2
,
3
])])
@
pytest
.
mark
.
parametrize
(
"output"
,
[
0
,
1
,
2
])
def
test_sha256
(
input
:
tuple
,
output
:
int
):
hash
=
sha256
(
input
)
assert
hash
is
not
None
assert
isinstance
(
hash
,
int
)
assert
hash
!=
0
def
test_sha256
(
input
:
tuple
):
digest
=
sha256
(
input
)
assert
digest
is
not
None
assert
isinstance
(
digest
,
bytes
)
assert
digest
!=
b
""
bytes
=
pickle
.
dumps
(
input
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
assert
hash
==
int
.
from_bytes
(
hashlib
.
sha256
(
bytes
).
digest
(),
byteorder
=
"big"
)
input_bytes
=
pickle
.
dumps
(
input
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
assert
digest
==
hashlib
.
sha256
(
input_bytes
).
digest
()
# hashing again, returns the same value
assert
hash
==
sha256
(
input
)
assert
digest
==
sha256
(
input
)
# hashing different input, returns different value
assert
hash
!=
sha256
(
input
+
(
1
,
))
assert
digest
!=
sha256
(
input
+
(
1
,
))
@
pytest
.
mark
.
parametrize
(
...
...
tests/v1/attention/test_attention_backends.py
View file @
38d80967
...
...
@@ -70,22 +70,6 @@ BATCH_SPECS = {
}
def
create_dummy_kv_cache
(
kv_cache_spec
:
FullAttentionSpec
,
device
:
torch
.
device
,
num_blocks
:
int
=
100
)
->
torch
.
Tensor
:
"""Create a dummy KV cache tensor for testing."""
kv_cache
=
torch
.
randn
(
2
,
# K and V
num_blocks
,
kv_cache_spec
.
block_size
,
kv_cache_spec
.
num_kv_heads
,
kv_cache_spec
.
head_size
,
dtype
=
_convert_dtype_to_torch
(
kv_cache_spec
.
dtype
),
device
=
device
,
)
return
kv_cache
def
create_and_prepopulate_kv_cache
(
k_contexts
:
list
[
torch
.
Tensor
],
v_contexts
:
list
[
torch
.
Tensor
],
...
...
tests/v1/attention/test_chunked_local_attention.py
View file @
38d80967
...
...
@@ -160,7 +160,7 @@ def test_local_attention_virtual_batches(test_data: LocalAttentionTestData):
# Use torch.arange instead of torch.randint so we can assert on
# block table tensor values. The block table will have shape
# (num_batches, cdiv(max_seq_len, block_size)) and the values will be
# aranged from 0 to cdiv(max_seq_len, block_size)-1
# ar
r
anged from 0 to cdiv(max_seq_len, block_size)-1
arange_block_indices
=
True
,
)
...
...
tests/v1/attention/test_mla_backends.py
View file @
38d80967
...
...
@@ -15,7 +15,7 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from
vllm.v1.kv_cache_interface
import
FullAttentionSpec
BACKENDS_TO_TEST
=
[
_Backend
.
CUTLASS_MLA
,
_Backend
.
FLASHMLA_VLLM_V1
,
_Backend
.
CUTLASS_MLA
,
_Backend
.
FLASHMLA_VLLM_V1
,
_Backend
.
FLASH_ATTN_MLA
,
_Backend
.
TRITON_MLA_VLLM_V1
]
...
...
@@ -69,25 +69,10 @@ BATCH_SPECS = {
}
def
create_dummy_kv_cache
(
kv_cache_spec
:
FullAttentionSpec
,
device
:
torch
.
device
,
num_blocks
:
int
=
100
)
->
torch
.
Tensor
:
"""Create a dummy KV cache tensor for testing."""
kv_cache
=
torch
.
randn
(
num_blocks
,
kv_cache_spec
.
block_size
,
kv_cache_spec
.
head_size
,
# latent dimension
dtype
=
_convert_dtype_to_torch
(
kv_cache_spec
.
dtype
),
device
=
device
,
)
return
kv_cache
def
create_and_prepopulate_kv_cache
(
kv_c_contexts
:
list
[
torch
.
Tensor
],
k_pe_contexts
:
list
[
torch
.
Tensor
],
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
...
...
@@ -101,7 +86,6 @@ def create_and_prepopulate_kv_cache(
k_pe_contexts: List of key positional embedding context tensors
for each sequence
block_size: Size of each block
num_kv_heads: Number of KV heads (should be 1 for MLA)
head_size: Size of each head (latent dimension)
dtype: Data type for the cache
device: Device to create the cache on
...
...
@@ -299,8 +283,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
query_lens
=
batch_spec
.
query_lens
num_q_heads
=
vllm_config
.
model_config
.
get_num_attention_heads
(
vllm_config
.
parallel_config
)
num_kv_heads
=
vllm_config
.
model_config
.
get_num_kv_heads
(
vllm_config
.
parallel_config
)
head_size
=
vllm_config
.
model_config
.
get_head_size
()
dtype
=
_convert_dtype_to_torch
(
vllm_config
.
model_config
.
dtype
)
block_size
=
vllm_config
.
cache_config
.
block_size
...
...
@@ -315,7 +297,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
# 2. Generate data and compute SDPA reference output for MLA
all_q_vllm
,
all_kv_c_vllm
,
all_k_pe_vllm
=
[],
[],
[]
all_sdpa_outputs
=
[]
all_sdpa_outputs
:
list
[
list
[
torch
.
Tensor
]]
=
[]
kv_c_contexts
,
k_pe_contexts
=
[],
[]
# Create shared MLA weight matrices for consistency across all sequences
...
...
@@ -331,6 +313,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
device
=
device
)
kv_b_proj_weight
=
torch
.
cat
([
W_UK
,
W_UV
],
dim
=-
1
)
for
i
,
backend
in
enumerate
(
BACKENDS_TO_TEST
):
all_sdpa_outputs
.
append
([])
for
i
in
range
(
batch_size
):
s_len
=
seq_lens
[
i
]
q_len
=
query_lens
[
i
]
...
...
@@ -358,85 +343,93 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
dtype
=
dtype
,
device
=
device
)
# Determine if this is decode (single token)
# or prefill (multiple tokens)
is_decode
=
q_len
==
1
# Determine if this is decode or prefill
is_decode
=
[]
for
i
,
backend
in
enumerate
(
BACKENDS_TO_TEST
):
builder_cls
,
_
=
get_attention_backend
(
backend
)
is_decode
.
append
(
q_len
<=
builder_cls
.
reorder_batch_threshold
)
# Split q into nope and rope components
q_nope
,
q_pe
=
q_c
.
split
([
qk_nope_head_dim
,
qk_rope_head_dim
],
dim
=-
1
)
if
is_decode
:
# Decode path: MQA-style attention in latent space
# Transform q_nope to latent space: q_nope @ W_UK
# q_nope: [1, num_heads, qk_nope_head_dim]
# W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim]
ql_nope
=
torch
.
einsum
(
"qnh,lnh->qnl"
,
q_nope
,
W_UK
)
# [1, num_heads, kv_lora_rank]
# Build MQA attention inputs
# Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim]
q_mqa
=
torch
.
cat
([
ql_nope
,
q_pe
],
dim
=-
1
)
# K: [s_len, kv_lora_rank + qk_rope_head_dim]
# (broadcasted to all heads)
k_mqa
=
torch
.
cat
([
kv_c_full
,
k_pe_full
.
squeeze
(
1
)],
dim
=-
1
)
k_mqa
=
k_mqa
.
unsqueeze
(
1
).
expand
(
-
1
,
num_q_heads
,
-
1
)
# V: [s_len, kv_lora_rank] (broadcasted to all heads)
v_mqa
=
kv_c_full
.
unsqueeze
(
1
).
expand
(
-
1
,
num_q_heads
,
-
1
)
# SDPA expects (N, H, L, D)
q_sdpa_in
=
q_mqa
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
k_sdpa_in
=
k_mqa
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
v_sdpa_in
=
v_mqa
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
sdpa_out_i
=
torch
.
nn
.
functional
.
scaled_dot_product_attention
(
q_sdpa_in
,
k_sdpa_in
,
v_sdpa_in
,
is_causal
=
False
,
scale
=
scale
)
sdpa_out_i
=
sdpa_out_i
.
transpose
(
1
,
2
).
squeeze
(
0
)
# [1, num_heads, kv_lora_rank]
# Project back to output space: sdpa_out @ W_UV
sdpa_out_i
=
torch
.
einsum
(
"qnl,lnv->qnv"
,
sdpa_out_i
,
W_UV
)
sdpa_out_i
=
sdpa_out_i
.
flatten
(
start_dim
=-
2
)
else
:
# Prefill path: MHA-style attention with full sequence
# Apply kv_b_proj to the full kv_c tensor
kv_nope_full
=
torch
.
einsum
(
"sl,lnh->snh"
,
kv_c_full
,
kv_b_proj_weight
)
k_nope_full
,
v_full
=
kv_nope_full
.
split
(
[
qk_nope_head_dim
,
v_head_dim
],
dim
=-
1
)
# Build attention inputs for full sequence
q_mha
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
# [q_len, num_heads, total_dim]
k_pe_full_expanded
=
k_pe_full
.
expand
(
-
1
,
num_q_heads
,
-
1
)
k_full
=
torch
.
cat
([
k_nope_full
,
k_pe_full_expanded
],
dim
=-
1
)
# Create custom attention mask:
# - Query tokens can attend to all context tokens
# - Query tokens can only attend to query tokens up to their pos
attn_mask
=
torch
.
ones
(
q_len
,
s_len
,
dtype
=
torch
.
bool
,
device
=
device
)
# Apply causal mask only to the query portion (context_len onwards)
causal_mask
=
torch
.
tril
(
torch
.
ones
(
q_len
,
q_len
,
device
=
device
))
attn_mask
[:,
context_len
:]
=
causal_mask
# SDPA expects (N, H, L, D)
q_sdpa_in
=
q_mha
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
k_sdpa_in
=
k_full
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
v_sdpa_in
=
v_full
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
# Single attention call with custom mask
sdpa_out_i
=
torch
.
nn
.
functional
.
scaled_dot_product_attention
(
q_sdpa_in
,
k_sdpa_in
,
v_sdpa_in
,
attn_mask
=
attn_mask
,
scale
=
scale
)
sdpa_out_i
=
sdpa_out_i
.
transpose
(
1
,
2
).
squeeze
(
0
)
sdpa_out_i
=
sdpa_out_i
.
flatten
(
start_dim
=-
2
)
all_sdpa_outputs
.
append
(
sdpa_out_i
)
#######################################################
# Decode path: MQA-style attention in latent space
# Transform q_nope to latent space: q_nope @ W_UK
# q_nope: [1, num_heads, qk_nope_head_dim]
# W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim]
ql_nope
=
torch
.
einsum
(
"qnh,lnh->qnl"
,
q_nope
,
W_UK
)
# [1, num_heads, kv_lora_rank]
# Build MQA attention inputs
# Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim]
q_mqa
=
torch
.
cat
([
ql_nope
,
q_pe
],
dim
=-
1
)
# K: [s_len, kv_lora_rank + qk_rope_head_dim]
# (broadcasted to all heads)
k_mqa
=
torch
.
cat
([
kv_c_full
,
k_pe_full
.
squeeze
(
1
)],
dim
=-
1
)
k_mqa
=
k_mqa
.
unsqueeze
(
1
).
expand
(
-
1
,
num_q_heads
,
-
1
)
# V: [s_len, kv_lora_rank] (broadcasted to all heads)
v_mqa
=
kv_c_full
.
unsqueeze
(
1
).
expand
(
-
1
,
num_q_heads
,
-
1
)
# Create custom attention mask for decode path:
# - Query tokens can attend to all context tokens
# - Query tokens can only attend to query tokens up to their position
attn_mask
=
torch
.
ones
(
q_len
,
s_len
,
dtype
=
torch
.
bool
,
device
=
device
)
# Apply causal mask only to the query portion (context_len onwards)
causal_mask
=
torch
.
tril
(
torch
.
ones
(
q_len
,
q_len
,
device
=
device
))
attn_mask
[:,
context_len
:]
=
causal_mask
# SDPA expects (N, H, L, D)
q_sdpa_in
=
q_mqa
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
k_sdpa_in
=
k_mqa
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
v_sdpa_in
=
v_mqa
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
sdpa_out_i_decode
=
torch
.
nn
.
functional
.
scaled_dot_product_attention
(
q_sdpa_in
,
k_sdpa_in
,
v_sdpa_in
,
attn_mask
=
attn_mask
,
scale
=
scale
)
sdpa_out_i_decode
=
sdpa_out_i_decode
.
transpose
(
1
,
2
).
squeeze
(
0
)
# [1, num_heads, kv_lora_rank]
# Project back to output space: sdpa_out @ W_UV
sdpa_out_i_decode
=
torch
.
einsum
(
"qnl,lnv->qnv"
,
sdpa_out_i_decode
,
W_UV
)
sdpa_out_i_decode
=
sdpa_out_i_decode
.
flatten
(
start_dim
=-
2
)
#######################################################
# Prefill path: MHA-style attention with full sequence
# Apply kv_b_proj to the full kv_c tensor
kv_nope_full
=
torch
.
einsum
(
"sl,lnh->snh"
,
kv_c_full
,
kv_b_proj_weight
)
k_nope_full
,
v_full
=
kv_nope_full
.
split
(
[
qk_nope_head_dim
,
v_head_dim
],
dim
=-
1
)
# Build attention inputs for full sequence
q_mha
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
# [q_len, num_heads, total_dim]
k_pe_full_expanded
=
k_pe_full
.
expand
(
-
1
,
num_q_heads
,
-
1
)
k_full
=
torch
.
cat
([
k_nope_full
,
k_pe_full_expanded
],
dim
=-
1
)
# Create custom attention mask:
# - Query tokens can attend to all context tokens
# - Query tokens can only attend to query tokens up to their pos
attn_mask
=
torch
.
ones
(
q_len
,
s_len
,
dtype
=
torch
.
bool
,
device
=
device
)
# Apply causal mask only to the query portion (context_len onwards)
causal_mask
=
torch
.
tril
(
torch
.
ones
(
q_len
,
q_len
,
device
=
device
))
attn_mask
[:,
context_len
:]
=
causal_mask
# SDPA expects (N, H, L, D)
q_sdpa_in
=
q_mha
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
k_sdpa_in
=
k_full
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
v_sdpa_in
=
v_full
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
# Single attention call with custom mask
sdpa_out_i_prefill
=
torch
.
nn
.
functional
.
scaled_dot_product_attention
(
q_sdpa_in
,
k_sdpa_in
,
v_sdpa_in
,
attn_mask
=
attn_mask
,
scale
=
scale
)
sdpa_out_i_prefill
=
sdpa_out_i_prefill
.
transpose
(
1
,
2
).
squeeze
(
0
)
sdpa_out_i_prefill
=
sdpa_out_i_prefill
.
flatten
(
start_dim
=-
2
)
for
i
,
backend
in
enumerate
(
BACKENDS_TO_TEST
):
if
is_decode
[
i
]:
all_sdpa_outputs
[
i
].
append
(
sdpa_out_i_decode
)
else
:
all_sdpa_outputs
[
i
].
append
(
sdpa_out_i_prefill
)
# Inputs for vLLM MLA backends are just the new tokens
all_q_vllm
.
append
(
q_c
)
...
...
@@ -451,7 +444,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
query_vllm
=
torch
.
cat
(
all_q_vllm
,
dim
=
0
)
kv_c_vllm
=
torch
.
cat
(
all_kv_c_vllm
,
dim
=
0
)
k_pe_vllm
=
torch
.
cat
(
all_k_pe_vllm
,
dim
=
0
)
sdpa_output
=
torch
.
cat
(
all_sdpa_outputs
,
dim
=
0
)
sdpa_outputs
=
[]
for
i
,
backend
in
enumerate
(
BACKENDS_TO_TEST
):
sdpa_outputs
.
append
(
torch
.
cat
(
all_sdpa_outputs
[
i
],
dim
=
0
))
# Create mock kv_b_proj using the same weights as reference implementation
from
vllm.model_executor.layers.linear
import
ColumnParallelLinear
...
...
@@ -477,7 +472,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
kv_c_contexts
=
kv_c_contexts
,
k_pe_contexts
=
k_pe_contexts
,
block_size
=
block_size
,
num_kv_heads
=
num_kv_heads
,
head_size
=
head_size
,
dtype
=
dtype
,
device
=
device
,
...
...
@@ -486,7 +480,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
randomize_blocks
=
True
)
# 4. Run vLLM backends and compare
for
backend_name
in
BACKENDS_TO_TEST
:
for
i
,
backend_name
in
enumerate
(
BACKENDS_TO_TEST
)
:
backend_output
=
run_attention_backend
(
backend_name
,
kv_cache_spec
,
[
"placeholder"
],
vllm_config
,
device
,
common_attn_metadata
,
query_vllm
,
kv_c_vllm
,
k_pe_vllm
,
kv_cache
,
...
...
@@ -494,12 +488,12 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
mock_kv_b_proj
)
# Check shape and dtype consistency
assert
backend_output
.
shape
==
sdpa_output
.
shape
,
(
assert
backend_output
.
shape
==
sdpa_output
s
[
i
]
.
shape
,
(
f
"[
{
backend_name
}
] shape
{
backend_output
.
shape
}
!= "
f
"SDPA shape
{
sdpa_output
.
shape
}
"
)
assert
backend_output
.
dtype
==
sdpa_output
.
dtype
,
(
f
"SDPA shape
{
sdpa_output
s
[
i
]
.
shape
}
"
)
assert
backend_output
.
dtype
==
sdpa_output
s
[
i
]
.
dtype
,
(
f
"[
{
backend_name
}
] dtype
{
backend_output
.
dtype
}
!= "
f
"SDPA dtype
{
sdpa_output
.
dtype
}
"
)
f
"SDPA dtype
{
sdpa_output
s
[
i
]
.
dtype
}
"
)
assert
torch
.
isfinite
(
backend_output
).
all
(),
(
f
"[
{
backend_name
}
] produced non-finite values"
)
...
...
@@ -508,12 +502,13 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
rtol
=
1e-2
atol
=
5e-1
max_diff
=
torch
.
max
(
torch
.
abs
(
backend_output
-
sdpa_output
)).
item
()
max_diff
=
torch
.
max
(
torch
.
abs
(
backend_output
-
sdpa_outputs
[
i
])).
item
()
max_rel_diff
=
torch
.
max
(
torch
.
abs
(
backend_output
-
sdpa_output
)
/
torch
.
abs
(
sdpa_output
)).
item
()
torch
.
abs
(
backend_output
-
sdpa_output
s
[
i
]
)
/
torch
.
abs
(
sdpa_output
s
[
i
]
)).
item
()
all_close
=
torch
.
allclose
(
backend_output
,
sdpa_output
,
sdpa_output
s
[
i
]
,
rtol
=
rtol
,
atol
=
atol
)
...
...
tests/v1/attention/utils.py
View file @
38d80967
...
...
@@ -139,6 +139,8 @@ def get_attention_backend(backend_name: _Backend):
"vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
,
_Backend
.
FLASHMLA_VLLM_V1
:
"vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
,
_Backend
.
FLASH_ATTN_MLA
:
"vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
,
_Backend
.
TRITON_MLA_VLLM_V1
:
"vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
,
}
...
...
tests/v1/core/test_kv_cache_utils.py
View file @
38d80967
...
...
@@ -6,20 +6,22 @@ from typing import Callable, Optional
import
pytest
import
torch
import
vllm.v1.core.kv_cache_utils
as
kv_cache_utils
from
vllm.config
import
ModelConfig
,
SchedulerConfig
,
VllmConfig
from
vllm.multimodal.inputs
import
(
MultiModalFeatureSpec
,
MultiModalKwargsItem
,
PlaceholderRange
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
GiB_bytes
,
sha256
,
sha256_cbor
_64bit
from
vllm.utils
import
GiB_bytes
,
sha256
,
sha256_cbor
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
# disable yapf here as it formats differently than isort such that both fail
# yapf: disable
from
vllm.v1.core.kv_cache_utils
import
(
FreeKVCacheBlockQueue
,
KVCacheBlock
,
PrefixCachingMetrics
,
BlockHash
,
FreeKVCacheBlockQueue
,
KVCacheBlock
,
PrefixCachingMetrics
,
estimate_max_model_len
,
generate_block_hash_extra_keys
,
get_kv_cache_config
,
get_max_concurrency_for_kv_cache_config
,
get_request_block_hasher
,
hash_block_tokens
,
init_none_hash
,
is_kv_cache_type_uniform
,
unify_kv_cache_configs
)
is_kv_cache_type_uniform
,
make_block_hash_with_group_id
,
unify_kv_cache_configs
)
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
,
KVCacheTensor
,
SlidingWindowSpec
)
...
...
@@ -88,7 +90,7 @@ def new_sliding_window_spec(block_size=16,
sliding_window
=
sliding_window
)
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
_64bit
,
hash
])
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
])
def
test_none_hash
(
monkeypatch
,
hash_fn
):
import
vllm.v1.core.kv_cache_utils
...
...
@@ -98,8 +100,8 @@ def test_none_hash(monkeypatch, hash_fn):
reloaded_kv_cache_utils
=
importlib
.
reload
(
vllm
.
v1
.
core
.
kv_cache_utils
)
reloaded_kv_cache_utils
.
init_none_hash
(
hash_fn
)
assert
reloaded_kv_cache_utils
.
NONE_HASH
is
not
None
assert
isinstance
(
reloaded_kv_cache_utils
.
NONE_HASH
,
int
)
assert
reloaded_kv_cache_utils
.
NONE_HASH
!=
0
assert
isinstance
(
reloaded_kv_cache_utils
.
NONE_HASH
,
bytes
)
assert
reloaded_kv_cache_utils
.
NONE_HASH
!=
b
""
# case 2: PYTHONHASHSEED is set, use the seed and hash_fn
with
monkeypatch
.
context
()
as
m
:
...
...
@@ -107,12 +109,11 @@ def test_none_hash(monkeypatch, hash_fn):
reloaded_kv_cache_utils
=
importlib
.
reload
(
vllm
.
v1
.
core
.
kv_cache_utils
)
reloaded_kv_cache_utils
.
init_none_hash
(
hash_fn
)
assert
reloaded_kv_cache_utils
.
NONE_HASH
is
not
None
assert
isinstance
(
reloaded_kv_cache_utils
.
NONE_HASH
,
int
)
assert
isinstance
(
reloaded_kv_cache_utils
.
NONE_HASH
,
bytes
)
assert
hash_fn
(
'python hash seed'
)
==
reloaded_kv_cache_utils
.
NONE_HASH
def
test_kv_cache_block
():
import
vllm.v1.core.kv_cache_utils
# Test KVCacheBlock initialization
block
=
KVCacheBlock
(
block_id
=
0
)
...
...
@@ -127,8 +128,7 @@ def test_kv_cache_block():
assert
block
.
ref_cnt
==
0
# Test block hash setting and resetting
block_hash
=
vllm
.
v1
.
core
.
kv_cache_utils
.
BlockHash
(
hash_value
=
123
,
token_ids
=
(
1
,
2
,
3
))
block_hash
=
make_block_hash_with_group_id
(
BlockHash
(
b
"abc"
),
0
)
block
.
block_hash
=
block_hash
assert
block
.
block_hash
==
block_hash
...
...
@@ -247,7 +247,7 @@ def test_free_kv_cache_block_queue_append_n():
def
test_free_kv_cache_block_queue_popleft_n
():
blocks
=
[
KVCacheBlock
(
block_id
=
i
)
for
i
in
range
(
6
)]
# Create a empty FreeKVCacheBlockQueue with these blocks
# Create a
n
empty FreeKVCacheBlockQueue with these blocks
queue
=
FreeKVCacheBlockQueue
(
[
blocks
[
1
],
blocks
[
3
],
blocks
[
5
],
blocks
[
4
],
blocks
[
0
],
blocks
[
2
]])
assert
queue
.
num_free_blocks
==
6
...
...
@@ -407,27 +407,23 @@ def test_generate_block_hash_extra_keys_cache_salt():
assert
next_mm_idx
==
1
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
_64bit
,
hash
])
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
])
def
test_hash_block_tokens
(
hash_fn
):
import
vllm.v1.core.kv_cache_utils
init_none_hash
(
hash_fn
)
parent_block_hash
=
123
parent_block_hash
=
BlockHash
(
b
"
123
"
)
curr_block_token_ids
=
(
1
,
2
,
3
)
extra_keys
=
(
"key1"
,
"key2"
)
block_hash
=
hash_block_tokens
(
hash_fn
,
parent_block_hash
,
curr_block_token_ids
,
extra_keys
)
assert
isinstance
(
block_hash
,
vllm
.
v1
.
core
.
kv_cache_utils
.
BlockHash
)
assert
block_hash
.
hash_value
==
hash_fn
(
(
parent_block_hash
,
curr_block_token_ids
,
extra_keys
))
assert
block_hash
.
token_ids
==
curr_block_token_ids
assert
block_hash
.
extra_keys
==
extra_keys
expected
=
hash_fn
((
parent_block_hash
,
curr_block_token_ids
,
extra_keys
))
assert
block_hash
==
expected
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
_64bit
,
hash
])
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
])
def
test_request_block_hasher
(
hash_fn
):
import
vllm.v1.core.kv_cache_utils
init_none_hash
(
hash_fn
)
kv_cache_utils
.
init_none_hash
(
hash_fn
)
request
=
make_request
(
request_id
=
"0"
,
prompt_token_ids
=
[
_
for
_
in
range
(
6
)],
...
...
@@ -442,19 +438,13 @@ def test_request_block_hasher(hash_fn):
block_hashes
=
request
.
block_hashes
assert
len
(
block_hashes
)
==
2
assert
isinstance
(
block_hashes
[
0
],
vllm
.
v1
.
core
.
kv_cache_utils
.
BlockHash
)
assert
isinstance
(
block_hashes
[
1
],
vllm
.
v1
.
core
.
kv_cache_utils
.
BlockHash
)
# Check the first block
assert
block_hashes
[
0
].
token_ids
==
(
0
,
1
,
2
)
assert
block_hashes
[
0
].
extra_keys
==
(
"hash1"
,
)
assert
block_hashes
[
0
]
==
hash_fn
(
(
kv_cache_utils
.
NONE_HASH
,
(
0
,
1
,
2
),
(
"hash1"
,
)))
assert
block_hashes
[
1
]
==
hash_fn
(
(
block_hashes
[
0
],
(
3
,
4
,
5
),
(
"hash2"
,
)))
# Check the second block
assert
block_hashes
[
1
].
token_ids
==
(
3
,
4
,
5
)
assert
block_hashes
[
1
].
extra_keys
==
(
"hash2"
,
)
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor_64bit
,
hash
])
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
])
def
test_hash_tokens_different_mm_input
(
hash_fn
):
init_none_hash
(
hash_fn
)
...
...
@@ -484,9 +474,9 @@ def test_hash_tokens_different_mm_input(hash_fn):
assert
block_hashes1
[
1
]
!=
block_hashes2
[
1
]
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
_64bit
,
hash
])
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
])
def
test_hash_request_tokens_no_mm_inputs
(
hash_fn
):
init_none_hash
(
hash_fn
)
kv_cache_utils
.
init_none_hash
(
hash_fn
)
request
=
make_request
(
request_id
=
"0"
,
...
...
@@ -500,10 +490,9 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn):
block_hashes
=
request
.
block_hashes
assert
len
(
block_hashes
)
==
2
assert
block_hashes
[
0
].
token_ids
==
(
0
,
1
,
2
)
assert
block_hashes
[
0
].
extra_keys
is
None
assert
block_hashes
[
1
].
token_ids
==
(
3
,
4
,
5
)
assert
block_hashes
[
1
].
extra_keys
is
None
assert
block_hashes
[
0
]
==
hash_fn
(
(
kv_cache_utils
.
NONE_HASH
,
(
0
,
1
,
2
),
None
))
assert
block_hashes
[
1
]
==
hash_fn
((
block_hashes
[
0
],
(
3
,
4
,
5
),
None
))
def
test_metrics
():
...
...
tests/v1/core/test_prefix_caching.py
View file @
38d80967
...
...
@@ -8,17 +8,19 @@ from typing import Callable, Optional
import
pytest
import
torch
import
vllm.v1.core.kv_cache_utils
as
kv_cache_utils
from
vllm.distributed.kv_events
import
AllBlocksCleared
,
BlockRemoved
from
vllm.multimodal.inputs
import
(
MultiModalFeatureSpec
,
MultiModalKwargsItem
,
PlaceholderRange
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
sha256
,
sha256_cbor
_64bit
from
vllm.utils
import
sha256
,
sha256_cbor
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
,
Request
from
vllm.v1.core.kv_cache_utils
import
(
BlockHash
,
BlockHashWithGroupId
,
KVCacheBlock
,
from
vllm.v1.core.kv_cache_utils
import
(
BlockHash
,
KVCacheBlock
,
get_block_hash
,
get_group_id
,
get_request_block_hasher
,
hash_block_tokens
,
init_none_hash
)
hash_block_tokens
,
init_none_hash
,
make_block_hash_with_group_id
)
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
,
SlidingWindowSpec
)
...
...
@@ -101,8 +103,10 @@ def make_kv_cache_config_hybrid_model(block_size: int,
)
@
pytest
.
mark
.
parametrize
(
"hash_algo"
,
[
"sha256"
,
"sha256_cbor_64bit"
,
"hash"
])
def
test_prefill
(
hash_algo
):
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
])
def
test_prefill
(
hash_fn
):
init_none_hash
(
hash_fn
)
block_size
=
16
manager
=
KVCacheManager
(
make_kv_cache_config
(
block_size
,
11
),
...
...
@@ -110,10 +114,6 @@ def test_prefill(hash_algo):
enable_caching
=
True
,
)
# choose the hash function according to the parameter
hash_fn
=
(
sha256_cbor_64bit
if
hash_algo
==
"sha256_cbor_64bit"
else
sha256
if
hash_algo
==
"sha256"
else
hash
)
# Complete 3 blocks (48 tokens)
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
...
...
@@ -137,10 +137,12 @@ def test_prefill(hash_algo):
block_tokens
=
tuple
(
all_token_ids
[(
block_id
-
1
)
*
16
:
block_id
*
16
])
block_hash
=
hash_block_tokens
(
hash_fn
,
parent_block_hash
,
block_tokens
)
assert
manager
.
block_pool
.
blocks
[
block_id
].
block_hash
.
block_hash
==
block_hash
blk_hash
=
manager
.
block_pool
.
blocks
[
block_id
].
block_hash
assert
blk_hash
is
not
None
assert
get_block_hash
(
blk_hash
)
==
block_hash
assert
get_group_id
(
blk_hash
)
==
0
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
parent_block_hash
=
block_hash
.
hash_value
parent_block_hash
=
block_hash
# Check partial block metadata
for
block_id
in
(
4
,
):
...
...
@@ -233,7 +235,7 @@ def test_prefill_hybrid_model():
enable_caching
=
True
,
)
hash_fn
=
ha
sh
hash_fn
=
sh
a256
# Complete 3 blocks (48 tokens)
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
block_size
)]
...
...
@@ -260,11 +262,13 @@ def test_prefill_hybrid_model():
block_tokens
=
tuple
(
all_token_ids
[(
length
-
1
)
*
16
:
length
*
16
])
block_hash
=
hash_block_tokens
(
hash_fn
,
parent_block_hash
,
block_tokens
)
for
block_id
in
block_ids
:
assert
manager
.
block_pool
.
blocks
[
block_id
].
block_hash
.
block_hash
==
block_hash
for
group_id
,
block_id
in
enumerate
(
block_ids
):
blk_hash
=
manager
.
block_pool
.
blocks
[
block_id
].
block_hash
assert
blk_hash
is
not
None
assert
get_block_hash
(
blk_hash
)
==
block_hash
assert
get_group_id
(
blk_hash
)
==
group_id
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
parent_block_hash
=
block_hash
.
hash_value
parent_block_hash
=
block_hash
# Check partial block metadata
for
block_id
in
(
4
,
8
,
12
):
...
...
@@ -298,11 +302,10 @@ def test_prefill_hybrid_model():
cached_block_hash_to_block_bak
=
copy
.
copy
(
manager
.
block_pool
.
cached_block_hash_to_block
)
def
test_partial_request_hit
(
request_id
:
str
,
hash_to_evict
:
list
[
BlockHashWithGroupId
],
def
test_partial_request_hit
(
request_id
:
str
,
hash_to_evict
:
list
[
bytes
],
expect_hit_length
:
int
):
req
=
make_request
(
request_id
,
common_token_ids
+
unique_token_ids
,
block_size
,
ha
sh
)
block_size
,
sh
a256
)
for
hash_with_group_id
in
hash_to_evict
:
manager
.
block_pool
.
cached_block_hash_to_block
.
pop
(
hash_with_group_id
)
...
...
@@ -319,33 +322,32 @@ def test_prefill_hybrid_model():
# Evict the blocks outside sliding window, does not affect the hit length.
test_partial_request_hit
(
"2"
,
[
B
lock
H
ash
W
ith
G
roup
I
d
(
block_hashes
[
0
],
1
),
B
lock
H
ash
W
ith
G
roup
I
d
(
block_hashes
[
0
],
2
)
make_b
lock
_h
ash
_w
ith
_g
roup
_i
d
(
block_hashes
[
0
],
1
),
make_b
lock
_h
ash
_w
ith
_g
roup
_i
d
(
block_hashes
[
0
],
2
)
],
3
)
# Evict the first block of full attention, makes total cache miss.
test_partial_request_hit
(
"3"
,
[
BlockHashWithGroupId
(
block_hashes
[
0
],
0
),
],
0
)
test_partial_request_hit
(
"3"
,
[
make_block_hash_with_group_id
(
block_hashes
[
0
],
0
)],
0
)
# Evict the last block of all layers, reduces the hit length to 2.
test_partial_request_hit
(
"4"
,
[
B
lock
H
ash
W
ith
G
roup
I
d
(
block_hashes
[
2
],
0
),
B
lock
H
ash
W
ith
G
roup
I
d
(
block_hashes
[
2
],
1
),
B
lock
H
ash
W
ith
G
roup
I
d
(
block_hashes
[
2
],
2
),
make_b
lock
_h
ash
_w
ith
_g
roup
_i
d
(
block_hashes
[
2
],
0
),
make_b
lock
_h
ash
_w
ith
_g
roup
_i
d
(
block_hashes
[
2
],
1
),
make_b
lock
_h
ash
_w
ith
_g
roup
_i
d
(
block_hashes
[
2
],
2
),
],
2
)
# Evict the last block of full attention, reduces the hit length to 2.
test_partial_request_hit
(
"5"
,
[
BlockHashWithGroupId
(
block_hashes
[
2
],
0
)],
2
)
test_partial_request_hit
(
"5"
,
[
make_block_hash_with_group_id
(
block_hashes
[
2
],
0
)],
2
)
# Evict the last block of sliding window, reduces the hit length to 2.
test_partial_request_hit
(
"6"
,
[
BlockHashWithGroupId
(
block_hashes
[
2
],
1
)],
2
)
test_partial_request_hit
(
"6"
,
[
make_block_hash_with_group_id
(
block_hashes
[
2
],
1
)],
2
)
# Evict the last block of sliding window, reduces the hit length to 2.
test_partial_request_hit
(
"7"
,
[
BlockHashWithGroupId
(
block_hashes
[
2
],
2
)],
2
)
test_partial_request_hit
(
"7"
,
[
make_block_hash_with_group_id
(
block_hashes
[
2
],
2
)],
2
)
# Evict different set of blocks for full attention and sliding window makes
# total cache miss.
...
...
@@ -353,9 +355,9 @@ def test_prefill_hybrid_model():
# The cache hit length of sliding window is 2 * block_size.
# Then it is cache miss as the two type of layers have different hit length.
test_partial_request_hit
(
"8"
,
[
B
lock
H
ash
W
ith
G
roup
I
d
(
block_hashes
[
2
],
0
),
B
lock
H
ash
W
ith
G
roup
I
d
(
block_hashes
[
0
],
1
),
B
lock
H
ash
W
ith
G
roup
I
d
(
block_hashes
[
0
],
2
),
make_b
lock
_h
ash
_w
ith
_g
roup
_i
d
(
block_hashes
[
2
],
0
),
make_b
lock
_h
ash
_w
ith
_g
roup
_i
d
(
block_hashes
[
0
],
1
),
make_b
lock
_h
ash
_w
ith
_g
roup
_i
d
(
block_hashes
[
0
],
2
),
],
0
)
...
...
@@ -372,8 +374,8 @@ def test_prefill_plp():
max_model_len
=
8192
,
enable_caching
=
True
,
)
# the default hash function is
ha
sh
hash_fn
=
ha
sh
# the default hash function is sh
a256
hash_fn
=
sh
a256
# Complete 3 blocks (48 tokens)
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
...
...
@@ -404,10 +406,12 @@ def test_prefill_plp():
block_tokens
=
tuple
(
all_token_ids
[(
block_id
-
1
)
*
16
:
block_id
*
16
])
block_hash
=
hash_block_tokens
(
hash_fn
,
parent_block_hash
,
block_tokens
)
assert
manager
.
block_pool
.
blocks
[
block_id
].
block_hash
.
block_hash
==
block_hash
blk_hash
=
(
manager
.
block_pool
.
blocks
[
block_id
].
block_hash
)
assert
blk_hash
is
not
None
assert
get_block_hash
(
blk_hash
)
==
block_hash
assert
get_group_id
(
blk_hash
)
==
0
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
parent_block_hash
=
block_hash
.
hash_value
parent_block_hash
=
block_hash
# Check partial block metadata
for
block_id
in
(
4
,
):
...
...
@@ -493,7 +497,7 @@ def test_decode():
# Incomplete 1 block (7 tokens)
unique_token_ids
=
[
3
]
*
7
req0
=
make_request
(
"0"
,
common_token_ids
+
unique_token_ids
,
block_size
,
ha
sh
)
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
...
...
@@ -538,7 +542,7 @@ def test_evict():
)
last_token_id
=
5
*
16
+
7
req0
=
make_request
(
"0"
,
list
(
range
(
last_token_id
)),
block_size
,
ha
sh
)
req0
=
make_request
(
"0"
,
list
(
range
(
last_token_id
)),
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
...
...
@@ -550,7 +554,7 @@ def test_evict():
# 3 blocks.
req1
=
make_request
(
"1"
,
list
(
range
(
last_token_id
,
last_token_id
+
3
*
16
)),
block_size
,
ha
sh
)
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
...
...
@@ -572,7 +576,7 @@ def test_evict():
]
==
[
10
,
6
,
5
,
4
,
3
,
2
,
1
,
9
,
8
,
7
]
# Touch the first 2 blocks.
req2
=
make_request
(
"2"
,
list
(
range
(
2
*
16
+
3
)),
block_size
,
ha
sh
)
req2
=
make_request
(
"2"
,
list
(
range
(
2
*
16
+
3
)),
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
computed_blocks
.
get_block_ids
()
==
([
1
,
2
],
)
assert
num_computed_tokens
==
2
*
16
...
...
@@ -597,7 +601,7 @@ def test_hash_block_correct_reuse():
# Allocate 1 block and cache it.
num_tokens
=
block_size
*
1
req
=
make_request
(
"0"
,
list
(
range
(
num_tokens
)),
block_size
,
ha
sh
)
req
=
make_request
(
"0"
,
list
(
range
(
num_tokens
)),
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
...
...
@@ -611,7 +615,7 @@ def test_hash_block_correct_reuse():
# Allocate a new block that's not full, make sure hash info on the
# block is cleared.
req
=
make_request
(
"1"
,
list
(
range
(
num_tokens
-
1
)),
block_size
,
ha
sh
)
req
=
make_request
(
"1"
,
list
(
range
(
num_tokens
-
1
)),
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
...
...
@@ -638,7 +642,7 @@ def test_computed_blocks_not_evicted():
# Allocate a block and cache it.
num_tokens
=
block_size
*
1
req0
=
make_request
(
"0"
,
list
(
range
(
num_tokens
)),
block_size
,
ha
sh
)
req0
=
make_request
(
"0"
,
list
(
range
(
num_tokens
)),
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
...
...
@@ -650,7 +654,7 @@ def test_computed_blocks_not_evicted():
# Allocate another block.
req1
=
make_request
(
"1"
,
list
(
range
(
num_tokens
,
num_tokens
*
2
)),
block_size
,
ha
sh
)
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
...
...
@@ -666,7 +670,7 @@ def test_computed_blocks_not_evicted():
# Now if we have a cache hit on the first block, we should evict the second
# cached block rather than the first one.
req2
=
make_request
(
"2"
,
list
(
range
(
num_tokens
*
2
)),
block_size
,
ha
sh
)
req2
=
make_request
(
"2"
,
list
(
range
(
num_tokens
*
2
)),
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
len
(
computed_blocks
.
blocks
[
0
])
==
1
assert
computed_blocks
.
blocks
[
0
][
0
].
block_id
==
1
...
...
@@ -691,7 +695,7 @@ def test_basic_prefix_caching_disabled():
)
req1
=
make_request
(
"1"
,
list
(
range
(
10
)),
block_size
,
ha
sh
)
# 2 blocks and some more
sh
a256
)
# 2 blocks and some more
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
not
computed_blocks
.
blocks
[
0
]
...
...
@@ -706,7 +710,7 @@ def test_basic_prefix_caching_disabled():
# No caching.
req2
=
make_request
(
"2"
,
list
(
range
(
16
)),
block_size
,
ha
sh
)
# shared prefix
sh
a256
)
# shared prefix
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
...
...
@@ -716,7 +720,7 @@ def test_basic_prefix_caching_disabled():
assert
len
(
blocks
.
blocks
[
0
])
==
4
# New requests should not have any blocks.
req3
=
make_request
(
"3"
,
list
(
range
(
4
)),
block_size
,
ha
sh
)
req3
=
make_request
(
"3"
,
list
(
range
(
4
)),
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req3
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
...
...
@@ -726,7 +730,7 @@ def test_basic_prefix_caching_disabled():
assert
not
blocks
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
_64bit
,
hash
])
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
])
def
test_cache_blocks
(
hash_fn
):
"""
This is a unit test that tests the correctness of the _cache_full_blocks
...
...
@@ -787,7 +791,7 @@ def test_cache_blocks_multi_group():
# Block 1/5: [4, 5, 6, 7]
# Block 2/6: [8, 9, 10, 11]
# Block 3/7: [12, 13]
req
=
make_request
(
"0"
,
list
(
range
(
14
)),
block_size
,
ha
sh
)
req
=
make_request
(
"0"
,
list
(
range
(
14
)),
block_size
,
sh
a256
)
# Cache the blocks for group 0.
blocks
=
[
KVCacheBlock
(
block_id
=
i
)
for
i
in
range
(
2
)]
...
...
@@ -845,6 +849,8 @@ def test_mm_prefix_caching():
"""
This tests that the multi-modal prefix caching is correct.
"""
kv_cache_utils
.
init_none_hash
(
sha256
)
block_size
=
16
manager
=
KVCacheManager
(
make_kv_cache_config
(
block_size
,
11
),
...
...
@@ -874,23 +880,30 @@ def test_mm_prefix_caching():
req0
=
make_request
(
"0"
,
all_token_ids
,
block_size
,
ha
sh
,
sh
a256
,
mm_positions
=
mm_positions
,
mm_hashes
=
mm_hashes
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
# Completed block should have hashes
with extra keys.
# Completed block should have hashes
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
block_hashes
=
req0
.
block_hashes
assert
len
(
block_hashes
)
==
3
assert
block_hashes
[
0
].
extra_keys
==
(
"aaa"
,
)
assert
block_hashes
[
1
].
extra_keys
==
(
"aaa"
,
"bbb"
)
assert
block_hashes
[
2
].
extra_keys
==
(
"bbb"
,
)
assert
block_hashes
[
0
]
==
sha256
(
(
kv_cache_utils
.
NONE_HASH
,
tuple
(
all_token_ids
[:
block_size
]),
(
"aaa"
,
)))
assert
block_hashes
[
1
]
==
sha256
(
(
block_hashes
[
0
],
tuple
(
all_token_ids
[
block_size
:
block_size
*
2
]),
(
"aaa"
,
"bbb"
)))
assert
block_hashes
[
2
]
==
sha256
(
(
block_hashes
[
1
],
tuple
(
all_token_ids
[
block_size
*
2
:
block_size
*
3
]),
(
"bbb"
,
)))
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
assert
blocks
is
not
None
assert
blocks
.
get_block_ids
()
==
([
1
,
2
,
3
,
4
],
)
req0
.
num_computed_tokens
=
59
...
...
@@ -901,10 +914,10 @@ def test_mm_prefix_caching():
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
assert
new_blocks
is
not
None
and
len
(
new_blocks
.
blocks
[
0
])
==
0
# The just completed block should have hashes with extra keys.
assert
len
(
block_hashes
)
==
4
assert
block_hashes
[
3
].
extra_keys
==
(
"ccc"
,
)
assert
block_hashes
[
3
]
==
sha256
(
(
block_hashes
[
2
],
tuple
(
all_token_ids
[
3
*
block_size
:]
+
[
8
]
*
5
),
(
"ccc"
,
)))
# Cache hit.
unique_token_ids
=
[
-
1
]
*
7
+
[
200
]
*
5
...
...
@@ -916,7 +929,7 @@ def test_mm_prefix_caching():
req1
=
make_request
(
"1"
,
all_token_ids
,
block_size
,
ha
sh
,
sh
a256
,
mm_positions
=
mm_positions
,
mm_hashes
=
mm_hashes
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
...
...
@@ -929,6 +942,8 @@ def test_cache_key_salting():
This tests that cache salts are applied during hashing and the cache
is separated cache as expected.
"""
kv_cache_utils
.
init_none_hash
(
sha256
)
block_size
=
16
manager
=
KVCacheManager
(
make_kv_cache_config
(
block_size
,
11
),
...
...
@@ -939,21 +954,26 @@ def test_cache_key_salting():
# 3 complete blocks and an incomplete block with 11 tokens.
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
block_size
)]
token_ids
=
common_token_ids
+
[
3
]
*
11
req0
=
make_request
(
"0"
,
token_ids
,
block_size
,
ha
sh
,
cache_salt
=
"salt1"
)
req0
=
make_request
(
"0"
,
token_ids
,
block_size
,
sh
a256
,
cache_salt
=
"salt1"
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
# Completed block should have hashes
with extra keys.
# Completed block should have hashes
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
block_hashes
=
req0
.
block_hashes
assert
len
(
block_hashes
)
==
3
assert
block_hashes
[
0
].
extra_keys
==
(
"salt1"
,
)
assert
block_hashes
[
1
].
extra_keys
is
None
assert
block_hashes
[
2
].
extra_keys
is
None
assert
block_hashes
[
0
]
==
sha256
(
(
kv_cache_utils
.
NONE_HASH
,
tuple
(
token_ids
[:
block_size
]),
(
"salt1"
,
)))
assert
block_hashes
[
1
]
==
sha256
(
(
block_hashes
[
0
],
tuple
(
token_ids
[
block_size
:
block_size
*
2
]),
None
))
assert
block_hashes
[
2
]
==
sha256
(
(
block_hashes
[
1
],
tuple
(
token_ids
[
block_size
*
2
:
block_size
*
3
]),
None
))
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
assert
blocks
is
not
None
assert
blocks
.
get_block_ids
()
==
([
1
,
2
,
3
,
4
],
)
req0
.
num_computed_tokens
=
59
...
...
@@ -964,14 +984,13 @@ def test_cache_key_salting():
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
assert
new_blocks
is
not
None
and
len
(
new_blocks
.
blocks
[
0
])
==
0
# Now one more block that should not have extra keys.
assert
len
(
block_hashes
)
==
4
assert
block_hashes
[
3
].
extra_keys
is
None
assert
block_hashes
[
3
]
==
sha256
(
(
block_hashes
[
2
],
tuple
(
token_ids
[
3
*
block_size
:]
+
[
8
]
*
5
),
None
))
# Test cache hit with a new request that has the same salt.
token_ids
=
common_token_ids
+
[
4
]
*
11
req1
=
make_request
(
"1"
,
token_ids
,
block_size
,
ha
sh
,
cache_salt
=
"salt1"
)
req1
=
make_request
(
"1"
,
token_ids
,
block_size
,
sh
a256
,
cache_salt
=
"salt1"
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
# Should match only a prefix of 3 blocks.
assert
len
(
computed_blocks
.
blocks
[
0
])
==
3
...
...
@@ -979,13 +998,19 @@ def test_cache_key_salting():
# Test cache miss with same content but different salt.
token_ids
=
common_token_ids
+
[
4
]
*
11
req2
=
make_request
(
"2"
,
token_ids
,
block_size
,
ha
sh
,
cache_salt
=
"salt2"
)
req2
=
make_request
(
"2"
,
token_ids
,
block_size
,
sh
a256
,
cache_salt
=
"salt2"
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
len
(
computed_blocks
.
blocks
[
0
])
==
0
assert
num_computed_tokens
==
0
block_hashes
=
req2
.
block_hashes
assert
len
(
block_hashes
)
==
3
assert
block_hashes
[
0
].
extra_keys
==
(
"salt2"
,
)
assert
block_hashes
[
0
]
==
sha256
(
(
kv_cache_utils
.
NONE_HASH
,
tuple
(
token_ids
[:
block_size
]),
(
"salt2"
,
)))
assert
block_hashes
[
1
]
==
sha256
(
(
block_hashes
[
0
],
tuple
(
token_ids
[
block_size
:
block_size
*
2
]),
None
))
assert
block_hashes
[
2
]
==
sha256
(
(
block_hashes
[
1
],
tuple
(
token_ids
[
block_size
*
2
:
block_size
*
3
]),
None
))
def
test_prefill_not_enough_free_blocks_with_computed_blocks
():
...
...
@@ -1004,7 +1029,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# Complete 3 blocks (48 tokens)
# | Common-0 | Common-1 | Common-2 | ... |
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
req0
=
make_request
(
"0"
,
common_token_ids
,
block_size
,
ha
sh
)
req0
=
make_request
(
"0"
,
common_token_ids
,
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
...
...
@@ -1015,7 +1040,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
req0
.
request_id
]
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
req1
=
make_request
(
"1"
,
common_token_ids
*
2
,
block_size
,
ha
sh
)
req1
=
make_request
(
"1"
,
common_token_ids
*
2
,
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
computed_blocks
.
blocks
[
0
]
==
block_part0
assert
num_computed_tokens
==
3
*
16
...
...
@@ -1032,7 +1057,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
# | Req1-5(F)| Req2-0 | Req2-1 | ... |
req2
=
make_request
(
"2"
,
[
7
]
*
block_size
*
2
,
block_size
,
ha
sh
)
req2
=
make_request
(
"2"
,
[
7
]
*
block_size
*
2
,
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
...
...
@@ -1044,7 +1069,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# but it cannot be allocated due to insufficient free blocks (2).
# In this case, the ref_cnt of the computed blocks should not be changed.
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
5
req3
=
make_request
(
"3"
,
common_token_ids
*
3
,
block_size
,
ha
sh
)
req3
=
make_request
(
"3"
,
common_token_ids
*
3
,
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req3
)
assert
computed_blocks
.
blocks
[
0
]
==
block_part1
assert
num_computed_tokens
==
6
*
16
...
...
@@ -1069,13 +1094,13 @@ def test_reset_prefix_cache():
full_block_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
unique_token_ids
=
[
3
]
*
7
all_token_ids
=
full_block_token_ids
+
unique_token_ids
req0
=
make_request
(
"0"
,
all_token_ids
,
block_size
,
ha
sh
)
req0
=
make_request
(
"0"
,
all_token_ids
,
block_size
,
sh
a256
)
blocks
=
manager
.
allocate_slots
(
req0
,
55
)
assert
blocks
.
get_block_ids
()
==
([
1
,
2
,
3
,
4
],
)
unique_token_ids
=
[
4
]
*
7
all_token_ids
=
full_block_token_ids
+
unique_token_ids
req1
=
make_request
(
"1"
,
all_token_ids
,
block_size
,
ha
sh
)
req1
=
make_request
(
"1"
,
all_token_ids
,
block_size
,
sh
a256
)
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
req1
.
block_hashes
)
==
3
assert
len
(
computed_blocks
.
blocks
[
0
])
==
3
...
...
@@ -1109,7 +1134,7 @@ def test_prefix_cache_stats_disabled():
assert
manager
.
prefix_cache_stats
is
None
# Call all functions that check whether log_stats is disabled.
req
=
make_request
(
"0"
,
list
(
range
(
16
)),
block_size
,
ha
sh
)
req
=
make_request
(
"0"
,
list
(
range
(
16
)),
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
...
...
@@ -1124,15 +1149,9 @@ def test_prefix_cache_stats_disabled():
def
test_maybe_evict_cached_block
():
pool
=
BlockPool
(
num_gpu_blocks
=
4
,
enable_caching
=
True
)
block_hash0
=
BlockHashWithGroupId
(
block_hash
=
BlockHash
(
hash_value
=
10
,
token_ids
=
(
100
,
)),
group_id
=
1000
)
block_hash1
=
BlockHashWithGroupId
(
block_hash
=
BlockHash
(
hash_value
=
20
,
token_ids
=
(
200
,
)),
group_id
=
2000
)
block_hash2
=
BlockHashWithGroupId
(
block_hash
=
BlockHash
(
hash_value
=
30
,
token_ids
=
(
300
,
)),
group_id
=
3000
)
block_hash0
=
make_block_hash_with_group_id
(
BlockHash
(
b
"10"
),
1000
)
block_hash1
=
make_block_hash_with_group_id
(
BlockHash
(
b
"20"
),
2000
)
block_hash2
=
make_block_hash_with_group_id
(
BlockHash
(
b
"30"
),
3000
)
block_hashes
=
[
block_hash0
,
block_hash1
,
...
...
@@ -1206,7 +1225,7 @@ def test_kv_cache_events(blocks_to_cache: int):
)
num_tokens
=
block_size
*
blocks_to_cache
req0
=
make_request
(
"0"
,
list
(
range
(
num_tokens
)),
block_size
,
ha
sh
)
req0
=
make_request
(
"0"
,
list
(
range
(
num_tokens
)),
block_size
,
sh
a256
)
_
=
manager
.
allocate_slots
(
req0
,
num_tokens
)
events
=
manager
.
take_events
()
...
...
@@ -1222,7 +1241,7 @@ def test_kv_cache_events(blocks_to_cache: int):
# Should see block_to_cache number of removed block events and a new block
# stored event
manager
.
free
(
req0
)
req1
=
make_request
(
"1"
,
list
(
range
(
num_tokens
)),
block_size
,
ha
sh
)
req1
=
make_request
(
"1"
,
list
(
range
(
num_tokens
)),
block_size
,
sh
a256
)
_
=
manager
.
allocate_slots
(
req1
,
num_tokens
)
events
=
manager
.
take_events
()
...
...
@@ -1256,7 +1275,7 @@ def test_eagle_enabled_removes_last_block():
# Request with 3 full blocks (48 tokens)
token_ids
=
[
0
]
*
(
3
*
block_size
)
req
=
make_request
(
"divisible_request"
,
token_ids
,
block_size
,
ha
sh
)
req
=
make_request
(
"divisible_request"
,
token_ids
,
block_size
,
sh
a256
)
# Prime the cache
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req
)
...
...
@@ -1266,7 +1285,7 @@ def test_eagle_enabled_removes_last_block():
manager
.
free
(
req
)
# New request with same tokens + Eagle enabled
req_eagle
=
make_request
(
"eagle_divisible"
,
token_ids
,
block_size
,
ha
sh
)
req_eagle
=
make_request
(
"eagle_divisible"
,
token_ids
,
block_size
,
sh
a256
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_eagle
)
# Should retain 1 block:
...
...
@@ -1287,7 +1306,7 @@ def test_eagle_with_partial_blocks():
)
# 2 full blocks + 5 tokens (non-divisible length)
token_ids
=
[
0
]
*
(
2
*
block_size
+
5
)
req
=
make_request
(
"partial_block_test"
,
token_ids
,
block_size
,
ha
sh
)
req
=
make_request
(
"partial_block_test"
,
token_ids
,
block_size
,
sh
a256
)
# Prime the cache
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req
)
...
...
@@ -1297,7 +1316,7 @@ def test_eagle_with_partial_blocks():
manager
.
free
(
req
)
# New request with Eagle enabled
req_eagle
=
make_request
(
"partial_eagle"
,
token_ids
,
block_size
,
ha
sh
)
req_eagle
=
make_request
(
"partial_eagle"
,
token_ids
,
block_size
,
sh
a256
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_eagle
)
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
assert
len
(
computed_blocks
.
blocks
[
0
])
==
1
...
...
@@ -1328,7 +1347,7 @@ def test_eagle_with_sliding_window():
# 2 full blocks + 5 tokens (non-divisible length)
token_ids
=
[
0
]
*
(
2
*
block_size
+
5
)
req
=
make_request
(
"partial_block_test"
,
token_ids
,
block_size
,
ha
sh
)
req
=
make_request
(
"partial_block_test"
,
token_ids
,
block_size
,
sh
a256
)
# Prime the cache
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req
)
...
...
@@ -1341,7 +1360,7 @@ def test_eagle_with_sliding_window():
manager
.
free
(
req
)
# New request with Eagle enabled
req_eagle
=
make_request
(
"partial_eagle"
,
token_ids
,
block_size
,
ha
sh
)
req_eagle
=
make_request
(
"partial_eagle"
,
token_ids
,
block_size
,
sh
a256
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_eagle
)
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
assert
len
(
computed_blocks
.
blocks
[
0
])
==
1
...
...
@@ -1351,11 +1370,11 @@ def test_eagle_with_sliding_window():
assert
manager
.
block_pool
.
get_cached_block
(
block_hash_first_block
,
kv_cache_group_ids
=
[
0
])
is
not
None
manager
.
block_pool
.
cached_block_hash_to_block
.
pop
(
B
lock
H
ash
W
ith
G
roup
I
d
(
block_hash_first_block
,
0
))
make_b
lock
_h
ash
_w
ith
_g
roup
_i
d
(
block_hash_first_block
,
0
))
# New request
req_after_evict
=
make_request
(
"partial_eagle_after_evict"
,
token_ids
,
block_size
,
ha
sh
)
block_size
,
sh
a256
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_after_evict
)
# Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is
# not considered. But after dropping the last matched block due to eagle,
...
...
tests/v1/core/test_single_type_kv_cache_manager.py
View file @
38d80967
...
...
@@ -6,8 +6,8 @@ import random
import
torch
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.kv_cache_utils
import
(
BlockHash
,
BlockHashWithGroupId
,
KVCacheBlock
)
from
vllm.v1.core.kv_cache_utils
import
(
BlockHash
,
KVCacheBlock
,
make_block_hash_with_group_id
)
from
vllm.v1.core.single_type_kv_cache_manager
import
(
ChunkedLocalAttentionManager
,
SlidingWindowManager
)
from
vllm.v1.kv_cache_interface
import
(
ChunkedLocalAttentionSpec
,
...
...
@@ -44,7 +44,7 @@ def test_chunked_local_attention_possible_cached_prefix():
def
run_one_case
(
block_is_cached
,
tail_token
,
expect_length
):
block_hash_list
=
[
BlockHash
(
i
,
())
for
i
in
range
(
len
(
block_is_cached
))
BlockHash
(
str
(
i
).
encode
())
for
i
in
range
(
len
(
block_is_cached
))
]
block_pool
.
cached_block_hash_to_block
.
clear
()
...
...
@@ -53,8 +53,8 @@ def test_chunked_local_attention_possible_cached_prefix():
for
i
,
(
block_hash
,
is_cached
)
in
enumerate
(
zip
(
block_hash_list
,
block_is_cached
)):
if
is_cached
:
block_pool
.
cached_block_hash_to_block
[
BlockHashWithGroupId
(
block_hash
,
0
)]
=
{
block_pool
.
cached_block_hash_to_block
[
make_block_hash_with_group_id
(
block_hash
,
0
)]
=
{
i
:
block_pool
.
blocks
[
i
+
10
],
}
...
...
@@ -109,7 +109,7 @@ def test_sliding_window_possible_cached_prefix():
def
run_one_case
(
block_is_cached
,
expect_length
):
block_hash_list
=
[
BlockHash
(
i
,
())
for
i
in
range
(
len
(
block_is_cached
))
BlockHash
(
str
(
i
).
encode
())
for
i
in
range
(
len
(
block_is_cached
))
]
block_pool
.
cached_block_hash_to_block
.
clear
()
...
...
@@ -118,8 +118,8 @@ def test_sliding_window_possible_cached_prefix():
for
i
,
(
block_hash
,
is_cached
)
in
enumerate
(
zip
(
block_hash_list
,
block_is_cached
)):
if
is_cached
:
block_pool
.
cached_block_hash_to_block
[
BlockHashWithGroupId
(
block_hash
,
0
)]
=
{
block_pool
.
cached_block_hash_to_block
[
make_block_hash_with_group_id
(
block_hash
,
0
)]
=
{
i
:
block_pool
.
blocks
[
i
+
10
],
}
...
...
tests/v1/core/utils.py
View file @
38d80967
...
...
@@ -9,6 +9,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
from
vllm.multimodal.inputs
import
(
MultiModalFeatureSpec
,
MultiModalKwargsItem
,
PlaceholderRange
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
sha256
from
vllm.v1.core.kv_cache_utils
import
(
get_request_block_hasher
,
init_none_hash
)
from
vllm.v1.core.sched.async_scheduler
import
AsyncScheduler
...
...
@@ -130,10 +131,10 @@ def create_requests(
)
->
list
[
Request
]:
global
_none_hash_initialized
if
not
_none_hash_initialized
:
init_none_hash
(
ha
sh
)
init_none_hash
(
sh
a256
)
_none_hash_initialized
=
True
block_hasher
=
get_request_block_hasher
(
block_size
,
ha
sh
)
block_hasher
=
get_request_block_hasher
(
block_size
,
sh
a256
)
sampling_params
=
SamplingParams
(
ignore_eos
=
False
,
max_tokens
=
max_tokens
,
stop_token_ids
=
stop_token_ids
,
...
...
tests/v1/cudagraph/test_cudagraph_mode.py
View file @
38d80967
...
...
@@ -62,6 +62,16 @@ backend_configs = {
"cudagraph_mode"
:
"FULL_AND_PIECEWISE"
,
},
specific_gpu_arch
=
(
9
,
0
)),
# FlashAttention MLA on Hopper
"FlashAttentionMLA"
:
BackendConfig
(
name
=
"FlashAttentionMLA"
,
env_vars
=
{
"VLLM_ATTENTION_BACKEND"
:
"FLASH_ATTN_MLA"
,
},
comp_config
=
{
"cudagraph_mode"
:
"FULL_DECODE_ONLY"
,
},
specific_gpu_arch
=
(
9
,
0
)),
# FA2
"FA2"
:
BackendConfig
(
name
=
"FA2"
,
...
...
tests/v1/e2e/test_spec_decode.py
View file @
38d80967
...
...
@@ -83,7 +83,7 @@ def test_ngram_correctness(
model_name
:
str
,
):
'''
Compare the outputs of a original LLM and a speculative LLM
Compare the outputs of a
n
original LLM and a speculative LLM
should be the same when using ngram speculative decoding.
'''
with
monkeypatch
.
context
()
as
m
:
...
...
@@ -117,45 +117,38 @@ def test_ngram_correctness(
print
(
f
"ref_output:
{
ref_output
.
outputs
[
0
].
text
}
"
)
print
(
f
"spec_output:
{
spec_output
.
outputs
[
0
].
text
}
"
)
# Heuristic: expect at least
70
% of the prompts to match exactly
# Heuristic: expect at least
68
% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert
matches
>
int
(
0.
7
*
len
(
ref_outputs
))
assert
matches
>
=
int
(
0.
68
*
len
(
ref_outputs
))
del
spec_llm
torch
.
cuda
.
empty_cache
()
cleanup_dist_env_and_memory
()
@
pytest
.
mark
.
parametrize
(
[
"model_setup"
,
"mm_enabled"
],
[
# TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
# (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
((
"eagle"
,
"meta-llama/Llama-3.1-8B-Instruct"
,
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
,
1
),
False
),
((
"eagle3"
,
"meta-llama/Llama-3.1-8B-Instruct"
,
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
,
1
),
False
),
pytest
.
param
(
(
"eagle"
,
"meta-llama/Llama-4-Scout-17B-16E-Instruct"
,
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct"
,
4
),
False
,
marks
=
pytest
.
mark
.
skip
(
reason
=
"Skipping due to CI OOM issues"
)),
pytest
.
param
(
(
"eagle"
,
"meta-llama/Llama-4-Scout-17B-16E-Instruct"
,
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct"
,
4
),
True
,
marks
=
pytest
.
mark
.
skip
(
reason
=
"Skipping due to CI OOM issues"
)),
((
"eagle"
,
"eagle618/deepseek-v3-random"
,
"eagle618/eagle-deepseek-v3-random"
,
1
),
False
),
],
ids
=
[
# TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
# "qwen3_eagle3",
"llama3_eagle"
,
"llama3_eagle3"
,
"llama4_eagle"
,
"llama4_eagle_mm"
,
"deepseek_eagle"
])
@
pytest
.
mark
.
parametrize
([
"model_setup"
,
"mm_enabled"
],
[
((
"eagle3"
,
"Qwen/Qwen3-8B"
,
"AngelSlim/Qwen3-8B_eagle3"
,
1
),
False
),
((
"eagle"
,
"meta-llama/Llama-3.1-8B-Instruct"
,
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
,
1
),
False
),
((
"eagle3"
,
"meta-llama/Llama-3.1-8B-Instruct"
,
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
,
1
),
False
),
pytest
.
param
(
(
"eagle"
,
"meta-llama/Llama-4-Scout-17B-16E-Instruct"
,
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct"
,
4
),
False
,
marks
=
pytest
.
mark
.
skip
(
reason
=
"Skipping due to CI OOM issues"
)),
pytest
.
param
(
(
"eagle"
,
"meta-llama/Llama-4-Scout-17B-16E-Instruct"
,
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct"
,
4
),
True
,
marks
=
pytest
.
mark
.
skip
(
reason
=
"Skipping due to CI OOM issues"
)),
((
"eagle"
,
"eagle618/deepseek-v3-random"
,
"eagle618/eagle-deepseek-v3-random"
,
1
),
False
),
],
ids
=
[
"qwen3_eagle3"
,
"llama3_eagle"
,
"llama3_eagle3"
,
"llama4_eagle"
,
"llama4_eagle_mm"
,
"deepseek_eagle"
])
@
pytest
.
mark
.
parametrize
(
"attn_backend"
,
get_attn_backend_list_based_on_platform
())
def
test_eagle_correctness
(
...
...
@@ -169,7 +162,7 @@ def test_eagle_correctness(
# TODO: Fix this flaky test
pytest
.
skip
(
"TREE_ATTN is flaky in the test disable for now until it can be "
"reolved (see https://github.com/vllm-project/vllm/issues/22922)"
)
"re
s
olved (see https://github.com/vllm-project/vllm/issues/22922)"
)
# Generate test prompts inside the function instead of using fixture
test_prompts
=
get_test_prompts
(
mm_enabled
)
...
...
tests/v1/engine/test_async_llm.py
View file @
38d80967
...
...
@@ -393,7 +393,7 @@ class MockLoggingStatLogger(LoggingStatLogger):
async
def
test_customize_loggers
(
monkeypatch
):
"""Test that we can customize the loggers.
If a customized logger is provided at the init, it should
be
used directly
.
be
added to the default loggers
.
"""
with
monkeypatch
.
context
()
as
m
,
ExitStack
()
as
after
:
...
...
@@ -410,7 +410,8 @@ async def test_customize_loggers(monkeypatch):
stat_loggers
=
engine
.
logger_manager
.
per_engine_logger_dict
assert
len
(
stat_loggers
)
==
1
assert
len
(
stat_loggers
[
0
])
==
1
assert
len
(
stat_loggers
[
0
])
==
2
# LoggingStatLogger + MockLoggingStatLogger
stat_loggers
[
0
][
0
].
log
.
assert_called_once
()
...
...
tests/v1/engine/test_engine_args.py
View file @
38d80967
...
...
@@ -36,18 +36,19 @@ def test_prefix_caching_from_cli():
assert
vllm_config
.
cache_config
.
enable_prefix_caching
# default hash algorithm is "builtin"
assert
vllm_config
.
cache_config
.
prefix_caching_hash_algo
==
"builtin"
assert
vllm_config
.
cache_config
.
prefix_caching_hash_algo
==
"sha256"
# set hash algorithm to sha256_cbor
args
=
parser
.
parse_args
([
"--prefix-caching-hash-algo"
,
"sha256_cbor"
])
vllm_config
=
EngineArgs
.
from_cli_args
(
args
=
args
).
create_engine_config
()
assert
vllm_config
.
cache_config
.
prefix_caching_hash_algo
==
\
"sha256_cbor"
# set hash algorithm to sha256
args
=
parser
.
parse_args
([
"--prefix-caching-hash-algo"
,
"sha256"
])
vllm_config
=
EngineArgs
.
from_cli_args
(
args
=
args
).
create_engine_config
()
assert
vllm_config
.
cache_config
.
prefix_caching_hash_algo
==
"sha256"
# set hash algorithm to builtin
args
=
parser
.
parse_args
([
"--prefix-caching-hash-algo"
,
"builtin"
])
vllm_config
=
EngineArgs
.
from_cli_args
(
args
=
args
).
create_engine_config
()
assert
vllm_config
.
cache_config
.
prefix_caching_hash_algo
==
"builtin"
# an invalid hash algorithm raises an error
parser
.
exit_on_error
=
False
with
pytest
.
raises
(
ArgumentError
):
...
...
tests/v1/engine/test_processor_multi_modal_uuids.py
View file @
38d80967
...
...
@@ -152,8 +152,8 @@ def test_multi_modal_uuids_accepts_none_and_passes_through(
*
,
tokenization_kwargs
=
None
,
lora_request
=
None
,
mm_
hash_overr
id
e
s
=
None
):
captured
[
"mm_
hash_overr
id
e
s"
]
=
mm_
hash_overr
id
e
s
mm_
uu
ids
=
None
):
captured
[
"mm_
uu
ids"
]
=
mm_
uu
ids
# Minimal processed inputs for decoder-only flow
return
{
"type"
:
"token"
,
"prompt_token_ids"
:
[
1
]}
...
...
@@ -180,7 +180,7 @@ def test_multi_modal_uuids_accepts_none_and_passes_through(
params
=
SamplingParams
(),
)
assert
captured
[
"mm_
hash_overr
id
e
s"
]
==
mm_uuids
assert
captured
[
"mm_
uu
ids"
]
==
mm_uuids
def
test_multi_modal_uuids_ignored_when_caching_disabled
(
monkeypatch
):
...
...
@@ -196,8 +196,8 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
*
,
tokenization_kwargs
=
None
,
lora_request
=
None
,
mm_
hash_overr
id
e
s
=
None
):
captured
[
"mm_
hash_overr
id
e
s"
]
=
mm_
hash_overr
id
e
s
mm_
uu
ids
=
None
):
captured
[
"mm_
uu
ids"
]
=
mm_
uu
ids
return
{
"type"
:
"token"
,
"prompt_token_ids"
:
[
1
]}
monkeypatch
.
setattr
(
processor
.
input_preprocessor
,
...
...
@@ -223,7 +223,7 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
)
# Expect request-id-based overrides are passed through
assert
captured
[
"mm_
hash_overr
id
e
s"
]
==
{
assert
captured
[
"mm_
uu
ids"
]
==
{
"image"
:
[
f
"
{
request_id
}
-image-0"
,
f
"
{
request_id
}
-image-1"
],
"video"
:
[
f
"
{
request_id
}
-video-0"
],
}
tests/v1/entrypoints/llm/test_struct_output_generate.py
View file @
38d80967
...
...
@@ -46,12 +46,12 @@ PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
(
"mistralai/Ministral-8B-Instruct-2410"
,
"xgrammar"
,
"mistral"
,
None
),
(
"Qwen/Qwen2.5-1.5B-Instruct"
,
"xgrammar"
,
"auto"
,
None
),
(
"Qwen/Qwen2.5-1.5B-Instruct"
,
"lm-format-enforcer"
,
"auto"
,
None
),
(
"mistralai/Ministral-8B-Instruct-2410"
,
"outlines"
,
"auto"
,
None
),
(
"mistralai/Ministral-8B-Instruct-2410"
,
"outlines"
,
"mistral"
,
None
),
#FIXME: This tests are flaky on CI thus disabled. Tracking in Issue #24402
# ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None),
# ("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None),
#("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"),
(
"mistralai/Ministral-8B-Instruct-2410"
,
"outlines"
,
"auto"
,
NGRAM_SPEC_CONFIG
),
#FIXME: This test is flaky on CI thus disabled
#("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"),
(
"mistralai/Ministral-8B-Instruct-2410"
,
"guidance"
,
"auto"
,
NGRAM_SPEC_CONFIG
),
(
"Qwen/Qwen2.5-1.5B-Instruct"
,
"xgrammar"
,
"auto"
,
NGRAM_SPEC_CONFIG
),
...
...
@@ -122,6 +122,7 @@ def test_structured_output(
guided_decoding_backend
=
guided_decoding_backend
,
guided_decoding_disable_any_whitespace
=
(
guided_decoding_backend
in
{
"xgrammar"
,
"guidance"
}),
seed
=
120
,
tokenizer_mode
=
tokenizer_mode
,
speculative_config
=
speculative_config
)
...
...
tests/v1/entrypoints/openai/responses/test_basic.py
View file @
38d80967
...
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
openai
# use the official client for correctness check
import
openai.types.responses
as
openai_responses_types
import
pytest
...
...
@@ -86,3 +87,18 @@ async def test_logprobs(client: openai.AsyncOpenAI):
outputs
=
response
.
output
assert
outputs
[
-
1
].
content
[
-
1
].
logprobs
assert
len
(
outputs
[
-
1
].
content
[
-
1
].
logprobs
[
0
].
top_logprobs
)
==
5
@
pytest
.
mark
.
asyncio
async
def
test_streaming
(
client
:
openai
.
AsyncOpenAI
):
stream
=
await
client
.
responses
.
create
(
input
=
"What is 13 * 24?"
,
stream
=
True
,
)
events
=
[
event
async
for
event
in
stream
]
assert
isinstance
(
events
[
0
],
openai_responses_types
.
ResponseCreatedEvent
)
assert
any
(
isinstance
(
event
,
openai_responses_types
.
ResponseTextDeltaEvent
)
for
event
in
events
)
assert
isinstance
(
events
[
-
1
],
openai_responses_types
.
ResponseCompletedEvent
)
tests/v1/entrypoints/openai/responses/test_image.py
View file @
38d80967
...
...
@@ -8,17 +8,17 @@ import pytest
import
pytest_asyncio
from
tests.utils
import
RemoteOpenAIServer
from
vllm.multimodal.utils
import
encode_image_base64
,
fetch_image
from
vllm.multimodal.utils
import
encode_image_base64
# Use a small vision model for testing
MODEL_NAME
=
"Qwen/Qwen2.5-VL-3B-Instruct"
MAXIMUM_IMAGES
=
2
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_
URL
S
=
[
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
,
"https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png"
,
"https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png"
,
"https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png"
,
TEST_IMAGE_
ASSET
S
=
[
"2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
,
#
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
"Grayscale_8bits_palette_sample_image.png"
,
#
"https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png",
"1280px-Venn_diagram_rgb.svg.png"
,
#
"https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png",
"RGBA_comp.png"
,
#
"https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
]
...
...
@@ -52,16 +52,17 @@ async def client(image_server):
@
pytest
.
fixture
(
scope
=
"session"
)
def
base64_encoded_image
()
->
dict
[
str
,
str
]:
def
base64_encoded_image
(
local_asset_server
)
->
dict
[
str
,
str
]:
return
{
image_url
:
encode_image_base64
(
fetch_image
(
image_url
))
for
image_url
in
TEST_IMAGE_URLS
image_url
:
encode_image_base64
(
local_asset_server
.
get_image_asset
(
image_url
))
for
image_url
in
TEST_IMAGE_ASSETS
}
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
@
pytest
.
mark
.
parametrize
(
"image_url"
,
TEST_IMAGE_
URLS
)
@
pytest
.
mark
.
parametrize
(
"image_url"
,
TEST_IMAGE_
ASSETS
,
indirect
=
True
)
async
def
test_single_chat_session_image
(
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
,
image_url
:
str
):
content_text
=
"What's in this image?"
...
...
@@ -91,11 +92,11 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
@
pytest
.
mark
.
parametrize
(
"image_url"
,
TEST_IMAGE_
URL
S
)
@
pytest
.
mark
.
parametrize
(
"
raw_
image_url"
,
TEST_IMAGE_
ASSET
S
)
async
def
test_single_chat_session_image_base64encoded
(
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
,
image_url
:
str
,
raw_
image_url
:
str
,
base64_encoded_image
:
dict
[
str
,
str
],
):
content_text
=
"What's in this image?"
...
...
@@ -106,7 +107,7 @@ async def test_single_chat_session_image_base64encoded(
{
"type"
:
"input_image"
,
"image_url"
:
f
"data:image/jpeg;base64,
{
base64_encoded_image
[
image_url
]
}
"
,
f
"data:image/jpeg;base64,
{
base64_encoded_image
[
raw_
image_url
]
}
"
,
"detail"
:
"auto"
,
},
{
...
...
@@ -127,7 +128,8 @@ async def test_single_chat_session_image_base64encoded(
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
@
pytest
.
mark
.
parametrize
(
"image_urls"
,
[
TEST_IMAGE_URLS
[:
i
]
for
i
in
range
(
2
,
len
(
TEST_IMAGE_URLS
))])
[
TEST_IMAGE_ASSETS
[:
i
]
for
i
in
range
(
2
,
len
(
TEST_IMAGE_ASSETS
))],
indirect
=
True
)
async
def
test_multi_image_input
(
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
,
image_urls
:
list
[
str
]):
messages
=
[{
...
...
Prev
1
…
13
14
15
16
17
18
19
20
21
…
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment