Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
38d80967
Commit
38d80967
authored
Sep 12, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.10.2rc2' into v0.10.2rc2-ori
parents
33650733
880c741b
Changes
544
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
503 additions
and
381 deletions
+503
-381
tests/transformers_utils/__init__.py
tests/transformers_utils/__init__.py
+0
-0
tests/transformers_utils/test_config_parser_registry.py
tests/transformers_utils/test_config_parser_registry.py
+37
-0
tests/utils.py
tests/utils.py
+101
-28
tests/utils_/test_utils.py
tests/utils_/test_utils.py
+9
-11
tests/v1/attention/test_attention_backends.py
tests/v1/attention/test_attention_backends.py
+0
-16
tests/v1/attention/test_chunked_local_attention.py
tests/v1/attention/test_chunked_local_attention.py
+1
-1
tests/v1/attention/test_mla_backends.py
tests/v1/attention/test_mla_backends.py
+101
-106
tests/v1/attention/utils.py
tests/v1/attention/utils.py
+2
-0
tests/v1/core/test_kv_cache_utils.py
tests/v1/core/test_kv_cache_utils.py
+28
-39
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+122
-103
tests/v1/core/test_single_type_kv_cache_manager.py
tests/v1/core/test_single_type_kv_cache_manager.py
+8
-8
tests/v1/core/utils.py
tests/v1/core/utils.py
+3
-2
tests/v1/cudagraph/test_cudagraph_mode.py
tests/v1/cudagraph/test_cudagraph_mode.py
+10
-0
tests/v1/e2e/test_spec_decode.py
tests/v1/e2e/test_spec_decode.py
+28
-35
tests/v1/engine/test_async_llm.py
tests/v1/engine/test_async_llm.py
+3
-2
tests/v1/engine/test_engine_args.py
tests/v1/engine/test_engine_args.py
+7
-6
tests/v1/engine/test_processor_multi_modal_uuids.py
tests/v1/engine/test_processor_multi_modal_uuids.py
+6
-6
tests/v1/entrypoints/llm/test_struct_output_generate.py
tests/v1/entrypoints/llm/test_struct_output_generate.py
+5
-4
tests/v1/entrypoints/openai/responses/test_basic.py
tests/v1/entrypoints/openai/responses/test_basic.py
+16
-0
tests/v1/entrypoints/openai/responses/test_image.py
tests/v1/entrypoints/openai/responses/test_image.py
+16
-14
No files found.
Too many changes to show.
To preserve performance only
544 of 544+
files are displayed.
Plain diff
Email patch
tests/transformers_utils/__init__.py
0 → 100644
View file @
38d80967
tests/transformers_utils/test_config_parser_registry.py
0 → 100644
View file @
38d80967
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
pathlib
import
Path
from
typing
import
Optional
,
Union
import
pytest
from
transformers
import
PretrainedConfig
from
vllm.transformers_utils.config
import
(
get_config_parser
,
register_config_parser
)
from
vllm.transformers_utils.config_parser_base
import
ConfigParserBase
@
register_config_parser
(
"custom_config_parser"
)
class
CustomConfigParser
(
ConfigParserBase
):
def
parse
(
self
,
model
:
Union
[
str
,
Path
],
trust_remote_code
:
bool
,
revision
:
Optional
[
str
]
=
None
,
code_revision
:
Optional
[
str
]
=
None
,
**
kwargs
)
->
tuple
[
dict
,
PretrainedConfig
]:
raise
NotImplementedError
def
test_register_config_parser
():
assert
isinstance
(
get_config_parser
(
"custom_config_parser"
),
CustomConfigParser
)
def
test_invalid_config_parser
():
with
pytest
.
raises
(
ValueError
):
@
register_config_parser
(
"invalid_config_parser"
)
class
InvalidConfigParser
:
pass
tests/utils.py
View file @
38d80967
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
import
asyncio
import
contextlib
import
copy
import
copy
import
functools
import
functools
import
importlib
import
importlib
...
@@ -13,10 +14,11 @@ import sys
...
@@ -13,10 +14,11 @@ import sys
import
tempfile
import
tempfile
import
time
import
time
import
warnings
import
warnings
from
contextlib
import
contextmanager
,
suppress
from
contextlib
import
ExitStack
,
contextmanager
,
suppress
from
multiprocessing
import
Process
from
multiprocessing
import
Process
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Literal
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Literal
,
Optional
,
Union
from
unittest.mock
import
patch
import
cloudpickle
import
cloudpickle
import
httpx
import
httpx
...
@@ -799,43 +801,106 @@ _P = ParamSpec("_P")
...
@@ -799,43 +801,106 @@ _P = ParamSpec("_P")
def
fork_new_process_for_each_test
(
def
fork_new_process_for_each_test
(
f
:
Callable
[
_P
,
None
])
->
Callable
[
_P
,
None
]:
f
unc
:
Callable
[
_P
,
None
])
->
Callable
[
_P
,
None
]:
"""Decorator to fork a new process for each test function.
"""Decorator to fork a new process for each test function.
See https://github.com/vllm-project/vllm/issues/7053 for more details.
See https://github.com/vllm-project/vllm/issues/7053 for more details.
"""
"""
@
functools
.
wraps
(
f
)
@
functools
.
wraps
(
f
unc
)
def
wrapper
(
*
args
:
_P
.
args
,
**
kwargs
:
_P
.
kwargs
)
->
None
:
def
wrapper
(
*
args
:
_P
.
args
,
**
kwargs
:
_P
.
kwargs
)
->
None
:
# Make the process the leader of its own process group
# Make the process the leader of its own process group
# to avoid sending SIGTERM to the parent process
# to avoid sending SIGTERM to the parent process
os
.
setpgrp
()
os
.
setpgrp
()
from
_pytest.outcomes
import
Skipped
from
_pytest.outcomes
import
Skipped
pid
=
os
.
fork
()
print
(
f
"Fork a new process to run a test
{
pid
}
"
)
# Create a unique temporary file to store exception info from child
if
pid
==
0
:
# process. Use test function name and process ID to avoid collisions.
try
:
with
tempfile
.
NamedTemporaryFile
(
f
(
*
args
,
**
kwargs
)
delete
=
False
,
except
Skipped
as
e
:
mode
=
'w+b'
,
# convert Skipped to exit code 0
prefix
=
f
"vllm_test_
{
func
.
__name__
}
_
{
os
.
getpid
()
}
_"
,
print
(
str
(
e
))
suffix
=
".exc"
)
as
exc_file
,
ExitStack
()
as
delete_after
:
os
.
_exit
(
0
)
exc_file_path
=
exc_file
.
name
except
Exception
:
delete_after
.
callback
(
os
.
remove
,
exc_file_path
)
import
traceback
traceback
.
print_exc
()
pid
=
os
.
fork
()
os
.
_exit
(
1
)
print
(
f
"Fork a new process to run a test
{
pid
}
"
)
if
pid
==
0
:
# Parent process responsible for deleting, don't delete
# in child.
delete_after
.
pop_all
()
try
:
func
(
*
args
,
**
kwargs
)
except
Skipped
as
e
:
# convert Skipped to exit code 0
print
(
str
(
e
))
os
.
_exit
(
0
)
except
Exception
as
e
:
import
traceback
tb_string
=
traceback
.
format_exc
()
# Try to serialize the exception object first
exc_to_serialize
:
dict
[
str
,
Any
]
try
:
# First, try to pickle the actual exception with
# its traceback.
exc_to_serialize
=
{
'pickled_exception'
:
e
}
# Test if it can be pickled
cloudpickle
.
dumps
(
exc_to_serialize
)
except
(
Exception
,
KeyboardInterrupt
):
# Fall back to string-based approach.
exc_to_serialize
=
{
'exception_type'
:
type
(
e
).
__name__
,
'exception_msg'
:
str
(
e
),
'traceback'
:
tb_string
,
}
try
:
with
open
(
exc_file_path
,
'wb'
)
as
f
:
cloudpickle
.
dump
(
exc_to_serialize
,
f
)
except
Exception
:
# Fallback: just print the traceback.
print
(
tb_string
)
os
.
_exit
(
1
)
else
:
os
.
_exit
(
0
)
else
:
else
:
os
.
_exit
(
0
)
pgid
=
os
.
getpgid
(
pid
)
else
:
_pid
,
_exitcode
=
os
.
waitpid
(
pid
,
0
)
pgid
=
os
.
getpgid
(
pid
)
# ignore SIGTERM signal itself
_pid
,
_exitcode
=
os
.
waitpid
(
pid
,
0
)
old_signal_handler
=
signal
.
signal
(
signal
.
SIGTERM
,
# ignore SIGTERM signal itself
signal
.
SIG_IGN
)
old_signal_handler
=
signal
.
signal
(
signal
.
SIGTERM
,
signal
.
SIG_IGN
)
# kill all child processes
# kill all child processes
os
.
killpg
(
pgid
,
signal
.
SIGTERM
)
os
.
killpg
(
pgid
,
signal
.
SIGTERM
)
# restore the signal handler
# restore the signal handler
signal
.
signal
(
signal
.
SIGTERM
,
old_signal_handler
)
signal
.
signal
(
signal
.
SIGTERM
,
old_signal_handler
)
if
_exitcode
!=
0
:
assert
_exitcode
==
0
,
(
f
"function
{
f
}
failed when called with"
# Try to read the exception from the child process
f
" args
{
args
}
and kwargs
{
kwargs
}
"
)
exc_info
=
{}
if
os
.
path
.
exists
(
exc_file_path
):
with
contextlib
.
suppress
(
Exception
),
\
open
(
exc_file_path
,
'rb'
)
as
f
:
exc_info
=
cloudpickle
.
load
(
f
)
if
(
original_exception
:
=
exc_info
.
get
(
'pickled_exception'
))
is
not
None
:
# Re-raise the actual exception object if it was
# successfully pickled.
assert
isinstance
(
original_exception
,
Exception
)
raise
original_exception
if
(
original_tb
:
=
exc_info
.
get
(
"traceback"
))
is
not
None
:
# Use string-based traceback for fallback case
raise
AssertionError
(
f
"Test
{
func
.
__name__
}
failed when called with"
f
" args
{
args
}
and kwargs
{
kwargs
}
"
f
" (exit code:
{
_exitcode
}
):
\n
{
original_tb
}
"
)
from
None
# Fallback to the original generic error
raise
AssertionError
(
f
"function
{
func
.
__name__
}
failed when called with"
f
" args
{
args
}
and kwargs
{
kwargs
}
"
f
" (exit code:
{
_exitcode
}
)"
)
from
None
return
wrapper
return
wrapper
...
@@ -1077,3 +1142,11 @@ def get_attn_backend_list_based_on_platform() -> list[str]:
...
@@ -1077,3 +1142,11 @@ def get_attn_backend_list_based_on_platform() -> list[str]:
return
attn_backend_list
return
attn_backend_list
else
:
else
:
raise
ValueError
(
"Unsupported platform"
)
raise
ValueError
(
"Unsupported platform"
)
@
contextmanager
def
override_cutlass_fp8_supported
(
value
:
bool
):
with
patch
(
"vllm.model_executor.layers.quantization.utils.w8a8_utils.cutlass_fp8_supported"
,
return_value
=
value
):
yield
tests/utils_/test_utils.py
View file @
38d80967
...
@@ -835,22 +835,20 @@ def test_model_specification(parser_with_config, cli_config_file,
...
@@ -835,22 +835,20 @@ def test_model_specification(parser_with_config, cli_config_file,
@
pytest
.
mark
.
parametrize
(
"input"
,
[(),
(
"abc"
,
),
(
None
,
),
@
pytest
.
mark
.
parametrize
(
"input"
,
[(),
(
"abc"
,
),
(
None
,
),
(
None
,
bool
,
[
1
,
2
,
3
])])
(
None
,
bool
,
[
1
,
2
,
3
])])
@
pytest
.
mark
.
parametrize
(
"output"
,
[
0
,
1
,
2
])
def
test_sha256
(
input
:
tuple
):
def
test_sha256
(
input
:
tuple
,
output
:
int
):
digest
=
sha256
(
input
)
hash
=
sha256
(
input
)
assert
digest
is
not
None
assert
hash
is
not
None
assert
isinstance
(
digest
,
bytes
)
assert
isinstance
(
hash
,
int
)
assert
digest
!=
b
""
assert
hash
!=
0
bytes
=
pickle
.
dumps
(
input
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
input_bytes
=
pickle
.
dumps
(
input
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
assert
hash
==
int
.
from_bytes
(
hashlib
.
sha256
(
bytes
).
digest
(),
assert
digest
==
hashlib
.
sha256
(
input_bytes
).
digest
()
byteorder
=
"big"
)
# hashing again, returns the same value
# hashing again, returns the same value
assert
hash
==
sha256
(
input
)
assert
digest
==
sha256
(
input
)
# hashing different input, returns different value
# hashing different input, returns different value
assert
hash
!=
sha256
(
input
+
(
1
,
))
assert
digest
!=
sha256
(
input
+
(
1
,
))
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
...
tests/v1/attention/test_attention_backends.py
View file @
38d80967
...
@@ -70,22 +70,6 @@ BATCH_SPECS = {
...
@@ -70,22 +70,6 @@ BATCH_SPECS = {
}
}
def
create_dummy_kv_cache
(
kv_cache_spec
:
FullAttentionSpec
,
device
:
torch
.
device
,
num_blocks
:
int
=
100
)
->
torch
.
Tensor
:
"""Create a dummy KV cache tensor for testing."""
kv_cache
=
torch
.
randn
(
2
,
# K and V
num_blocks
,
kv_cache_spec
.
block_size
,
kv_cache_spec
.
num_kv_heads
,
kv_cache_spec
.
head_size
,
dtype
=
_convert_dtype_to_torch
(
kv_cache_spec
.
dtype
),
device
=
device
,
)
return
kv_cache
def
create_and_prepopulate_kv_cache
(
def
create_and_prepopulate_kv_cache
(
k_contexts
:
list
[
torch
.
Tensor
],
k_contexts
:
list
[
torch
.
Tensor
],
v_contexts
:
list
[
torch
.
Tensor
],
v_contexts
:
list
[
torch
.
Tensor
],
...
...
tests/v1/attention/test_chunked_local_attention.py
View file @
38d80967
...
@@ -160,7 +160,7 @@ def test_local_attention_virtual_batches(test_data: LocalAttentionTestData):
...
@@ -160,7 +160,7 @@ def test_local_attention_virtual_batches(test_data: LocalAttentionTestData):
# Use torch.arange instead of torch.randint so we can assert on
# Use torch.arange instead of torch.randint so we can assert on
# block table tensor values. The block table will have shape
# block table tensor values. The block table will have shape
# (num_batches, cdiv(max_seq_len, block_size)) and the values will be
# (num_batches, cdiv(max_seq_len, block_size)) and the values will be
# aranged from 0 to cdiv(max_seq_len, block_size)-1
# ar
r
anged from 0 to cdiv(max_seq_len, block_size)-1
arange_block_indices
=
True
,
arange_block_indices
=
True
,
)
)
...
...
tests/v1/attention/test_mla_backends.py
View file @
38d80967
...
@@ -15,7 +15,7 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata
...
@@ -15,7 +15,7 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from
vllm.v1.kv_cache_interface
import
FullAttentionSpec
from
vllm.v1.kv_cache_interface
import
FullAttentionSpec
BACKENDS_TO_TEST
=
[
BACKENDS_TO_TEST
=
[
_Backend
.
CUTLASS_MLA
,
_Backend
.
FLASHMLA_VLLM_V1
,
_Backend
.
CUTLASS_MLA
,
_Backend
.
FLASHMLA_VLLM_V1
,
_Backend
.
FLASH_ATTN_MLA
,
_Backend
.
TRITON_MLA_VLLM_V1
_Backend
.
TRITON_MLA_VLLM_V1
]
]
...
@@ -69,25 +69,10 @@ BATCH_SPECS = {
...
@@ -69,25 +69,10 @@ BATCH_SPECS = {
}
}
def
create_dummy_kv_cache
(
kv_cache_spec
:
FullAttentionSpec
,
device
:
torch
.
device
,
num_blocks
:
int
=
100
)
->
torch
.
Tensor
:
"""Create a dummy KV cache tensor for testing."""
kv_cache
=
torch
.
randn
(
num_blocks
,
kv_cache_spec
.
block_size
,
kv_cache_spec
.
head_size
,
# latent dimension
dtype
=
_convert_dtype_to_torch
(
kv_cache_spec
.
dtype
),
device
=
device
,
)
return
kv_cache
def
create_and_prepopulate_kv_cache
(
def
create_and_prepopulate_kv_cache
(
kv_c_contexts
:
list
[
torch
.
Tensor
],
kv_c_contexts
:
list
[
torch
.
Tensor
],
k_pe_contexts
:
list
[
torch
.
Tensor
],
k_pe_contexts
:
list
[
torch
.
Tensor
],
block_size
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
head_size
:
int
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
device
:
torch
.
device
,
...
@@ -101,7 +86,6 @@ def create_and_prepopulate_kv_cache(
...
@@ -101,7 +86,6 @@ def create_and_prepopulate_kv_cache(
k_pe_contexts: List of key positional embedding context tensors
k_pe_contexts: List of key positional embedding context tensors
for each sequence
for each sequence
block_size: Size of each block
block_size: Size of each block
num_kv_heads: Number of KV heads (should be 1 for MLA)
head_size: Size of each head (latent dimension)
head_size: Size of each head (latent dimension)
dtype: Data type for the cache
dtype: Data type for the cache
device: Device to create the cache on
device: Device to create the cache on
...
@@ -299,8 +283,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
...
@@ -299,8 +283,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
query_lens
=
batch_spec
.
query_lens
query_lens
=
batch_spec
.
query_lens
num_q_heads
=
vllm_config
.
model_config
.
get_num_attention_heads
(
num_q_heads
=
vllm_config
.
model_config
.
get_num_attention_heads
(
vllm_config
.
parallel_config
)
vllm_config
.
parallel_config
)
num_kv_heads
=
vllm_config
.
model_config
.
get_num_kv_heads
(
vllm_config
.
parallel_config
)
head_size
=
vllm_config
.
model_config
.
get_head_size
()
head_size
=
vllm_config
.
model_config
.
get_head_size
()
dtype
=
_convert_dtype_to_torch
(
vllm_config
.
model_config
.
dtype
)
dtype
=
_convert_dtype_to_torch
(
vllm_config
.
model_config
.
dtype
)
block_size
=
vllm_config
.
cache_config
.
block_size
block_size
=
vllm_config
.
cache_config
.
block_size
...
@@ -315,7 +297,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
...
@@ -315,7 +297,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
# 2. Generate data and compute SDPA reference output for MLA
# 2. Generate data and compute SDPA reference output for MLA
all_q_vllm
,
all_kv_c_vllm
,
all_k_pe_vllm
=
[],
[],
[]
all_q_vllm
,
all_kv_c_vllm
,
all_k_pe_vllm
=
[],
[],
[]
all_sdpa_outputs
=
[]
all_sdpa_outputs
:
list
[
list
[
torch
.
Tensor
]]
=
[]
kv_c_contexts
,
k_pe_contexts
=
[],
[]
kv_c_contexts
,
k_pe_contexts
=
[],
[]
# Create shared MLA weight matrices for consistency across all sequences
# Create shared MLA weight matrices for consistency across all sequences
...
@@ -331,6 +313,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
...
@@ -331,6 +313,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
device
=
device
)
device
=
device
)
kv_b_proj_weight
=
torch
.
cat
([
W_UK
,
W_UV
],
dim
=-
1
)
kv_b_proj_weight
=
torch
.
cat
([
W_UK
,
W_UV
],
dim
=-
1
)
for
i
,
backend
in
enumerate
(
BACKENDS_TO_TEST
):
all_sdpa_outputs
.
append
([])
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
s_len
=
seq_lens
[
i
]
s_len
=
seq_lens
[
i
]
q_len
=
query_lens
[
i
]
q_len
=
query_lens
[
i
]
...
@@ -358,85 +343,93 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
...
@@ -358,85 +343,93 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
dtype
=
dtype
,
dtype
=
dtype
,
device
=
device
)
device
=
device
)
# Determine if this is decode (single token)
# Determine if this is decode or prefill
# or prefill (multiple tokens)
is_decode
=
[]
is_decode
=
q_len
==
1
for
i
,
backend
in
enumerate
(
BACKENDS_TO_TEST
):
builder_cls
,
_
=
get_attention_backend
(
backend
)
is_decode
.
append
(
q_len
<=
builder_cls
.
reorder_batch_threshold
)
# Split q into nope and rope components
# Split q into nope and rope components
q_nope
,
q_pe
=
q_c
.
split
([
qk_nope_head_dim
,
qk_rope_head_dim
],
dim
=-
1
)
q_nope
,
q_pe
=
q_c
.
split
([
qk_nope_head_dim
,
qk_rope_head_dim
],
dim
=-
1
)
if
is_decode
:
#######################################################
# Decode path: MQA-style attention in latent space
# Decode path: MQA-style attention in latent space
# Transform q_nope to latent space: q_nope @ W_UK
# Transform q_nope to latent space: q_nope @ W_UK
# q_nope: [1, num_heads, qk_nope_head_dim]
# q_nope: [1, num_heads, qk_nope_head_dim]
# W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim]
# W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim]
ql_nope
=
torch
.
einsum
(
"qnh,lnh->qnl"
,
q_nope
,
ql_nope
=
torch
.
einsum
(
"qnh,lnh->qnl"
,
q_nope
,
W_UK
)
# [1, num_heads, kv_lora_rank]
W_UK
)
# [1, num_heads, kv_lora_rank]
# Build MQA attention inputs
# Build MQA attention inputs
# Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim]
# Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim]
q_mqa
=
torch
.
cat
([
ql_nope
,
q_pe
],
dim
=-
1
)
q_mqa
=
torch
.
cat
([
ql_nope
,
q_pe
],
dim
=-
1
)
# K: [s_len, kv_lora_rank + qk_rope_head_dim]
# K: [s_len, kv_lora_rank + qk_rope_head_dim]
# (broadcasted to all heads)
# (broadcasted to all heads)
k_mqa
=
torch
.
cat
([
kv_c_full
,
k_pe_full
.
squeeze
(
1
)],
dim
=-
1
)
k_mqa
=
torch
.
cat
([
kv_c_full
,
k_pe_full
.
squeeze
(
1
)],
dim
=-
1
)
k_mqa
=
k_mqa
.
unsqueeze
(
1
).
expand
(
-
1
,
num_q_heads
,
-
1
)
k_mqa
=
k_mqa
.
unsqueeze
(
1
).
expand
(
-
1
,
num_q_heads
,
-
1
)
# V: [s_len, kv_lora_rank] (broadcasted to all heads)
# V: [s_len, kv_lora_rank] (broadcasted to all heads)
v_mqa
=
kv_c_full
.
unsqueeze
(
1
).
expand
(
-
1
,
num_q_heads
,
-
1
)
v_mqa
=
kv_c_full
.
unsqueeze
(
1
).
expand
(
-
1
,
num_q_heads
,
-
1
)
# SDPA expects (N, H, L, D)
# Create custom attention mask for decode path:
q_sdpa_in
=
q_mqa
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
# - Query tokens can attend to all context tokens
k_sdpa_in
=
k_mqa
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
# - Query tokens can only attend to query tokens up to their position
v_sdpa_in
=
v_mqa
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
attn_mask
=
torch
.
ones
(
q_len
,
s_len
,
dtype
=
torch
.
bool
,
device
=
device
)
# Apply causal mask only to the query portion (context_len onwards)
sdpa_out_i
=
torch
.
nn
.
functional
.
scaled_dot_product_attention
(
causal_mask
=
torch
.
tril
(
torch
.
ones
(
q_len
,
q_len
,
device
=
device
))
q_sdpa_in
,
k_sdpa_in
,
v_sdpa_in
,
is_causal
=
False
,
scale
=
scale
)
attn_mask
[:,
context_len
:]
=
causal_mask
sdpa_out_i
=
sdpa_out_i
.
transpose
(
1
,
2
).
squeeze
(
0
)
# [1, num_heads, kv_lora_rank]
# SDPA expects (N, H, L, D)
q_sdpa_in
=
q_mqa
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
# Project back to output space: sdpa_out @ W_UV
k_sdpa_in
=
k_mqa
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
sdpa_out_i
=
torch
.
einsum
(
"qnl,lnv->qnv"
,
sdpa_out_i
,
W_UV
)
v_sdpa_in
=
v_mqa
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
sdpa_out_i
=
sdpa_out_i
.
flatten
(
start_dim
=-
2
)
else
:
sdpa_out_i_decode
=
torch
.
nn
.
functional
.
scaled_dot_product_attention
(
# Prefill path: MHA-style attention with full sequence
q_sdpa_in
,
k_sdpa_in
,
v_sdpa_in
,
attn_mask
=
attn_mask
,
scale
=
scale
)
# Apply kv_b_proj to the full kv_c tensor
sdpa_out_i_decode
=
sdpa_out_i_decode
.
transpose
(
1
,
2
).
squeeze
(
kv_nope_full
=
torch
.
einsum
(
"sl,lnh->snh"
,
kv_c_full
,
0
)
# [1, num_heads, kv_lora_rank]
kv_b_proj_weight
)
k_nope_full
,
v_full
=
kv_nope_full
.
split
(
# Project back to output space: sdpa_out @ W_UV
[
qk_nope_head_dim
,
v_head_dim
],
dim
=-
1
)
sdpa_out_i_decode
=
torch
.
einsum
(
"qnl,lnv->qnv"
,
sdpa_out_i_decode
,
W_UV
)
# Build attention inputs for full sequence
sdpa_out_i_decode
=
sdpa_out_i_decode
.
flatten
(
start_dim
=-
2
)
q_mha
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
# [q_len, num_heads, total_dim]
#######################################################
k_pe_full_expanded
=
k_pe_full
.
expand
(
-
1
,
num_q_heads
,
-
1
)
# Prefill path: MHA-style attention with full sequence
k_full
=
torch
.
cat
([
k_nope_full
,
k_pe_full_expanded
],
dim
=-
1
)
# Apply kv_b_proj to the full kv_c tensor
kv_nope_full
=
torch
.
einsum
(
"sl,lnh->snh"
,
kv_c_full
,
kv_b_proj_weight
)
# Create custom attention mask:
k_nope_full
,
v_full
=
kv_nope_full
.
split
(
# - Query tokens can attend to all context tokens
[
qk_nope_head_dim
,
v_head_dim
],
dim
=-
1
)
# - Query tokens can only attend to query tokens up to their pos
attn_mask
=
torch
.
ones
(
q_len
,
# Build attention inputs for full sequence
s_len
,
q_mha
=
torch
.
cat
([
q_nope
,
q_pe
],
dtype
=
torch
.
bool
,
dim
=-
1
)
# [q_len, num_heads, total_dim]
device
=
device
)
k_pe_full_expanded
=
k_pe_full
.
expand
(
-
1
,
num_q_heads
,
-
1
)
# Apply causal mask only to the query portion (context_len onwards)
k_full
=
torch
.
cat
([
k_nope_full
,
k_pe_full_expanded
],
dim
=-
1
)
causal_mask
=
torch
.
tril
(
torch
.
ones
(
q_len
,
q_len
,
device
=
device
))
attn_mask
[:,
context_len
:]
=
causal_mask
# Create custom attention mask:
# - Query tokens can attend to all context tokens
# SDPA expects (N, H, L, D)
# - Query tokens can only attend to query tokens up to their pos
q_sdpa_in
=
q_mha
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
attn_mask
=
torch
.
ones
(
q_len
,
s_len
,
dtype
=
torch
.
bool
,
device
=
device
)
k_sdpa_in
=
k_full
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
# Apply causal mask only to the query portion (context_len onwards)
v_sdpa_in
=
v_full
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
causal_mask
=
torch
.
tril
(
torch
.
ones
(
q_len
,
q_len
,
device
=
device
))
attn_mask
[:,
context_len
:]
=
causal_mask
# Single attention call with custom mask
sdpa_out_i
=
torch
.
nn
.
functional
.
scaled_dot_product_attention
(
# SDPA expects (N, H, L, D)
q_sdpa_in
,
q_sdpa_in
=
q_mha
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
k_sdpa_in
,
k_sdpa_in
=
k_full
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
v_sdpa_in
,
v_sdpa_in
=
v_full
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
attn_mask
=
attn_mask
,
scale
=
scale
)
# Single attention call with custom mask
sdpa_out_i
=
sdpa_out_i
.
transpose
(
1
,
2
).
squeeze
(
0
)
sdpa_out_i_prefill
=
torch
.
nn
.
functional
.
scaled_dot_product_attention
(
sdpa_out_i
=
sdpa_out_i
.
flatten
(
start_dim
=-
2
)
q_sdpa_in
,
k_sdpa_in
,
v_sdpa_in
,
attn_mask
=
attn_mask
,
scale
=
scale
)
sdpa_out_i_prefill
=
sdpa_out_i_prefill
.
transpose
(
1
,
2
).
squeeze
(
0
)
all_sdpa_outputs
.
append
(
sdpa_out_i
)
sdpa_out_i_prefill
=
sdpa_out_i_prefill
.
flatten
(
start_dim
=-
2
)
for
i
,
backend
in
enumerate
(
BACKENDS_TO_TEST
):
if
is_decode
[
i
]:
all_sdpa_outputs
[
i
].
append
(
sdpa_out_i_decode
)
else
:
all_sdpa_outputs
[
i
].
append
(
sdpa_out_i_prefill
)
# Inputs for vLLM MLA backends are just the new tokens
# Inputs for vLLM MLA backends are just the new tokens
all_q_vllm
.
append
(
q_c
)
all_q_vllm
.
append
(
q_c
)
...
@@ -451,7 +444,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
...
@@ -451,7 +444,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
query_vllm
=
torch
.
cat
(
all_q_vllm
,
dim
=
0
)
query_vllm
=
torch
.
cat
(
all_q_vllm
,
dim
=
0
)
kv_c_vllm
=
torch
.
cat
(
all_kv_c_vllm
,
dim
=
0
)
kv_c_vllm
=
torch
.
cat
(
all_kv_c_vllm
,
dim
=
0
)
k_pe_vllm
=
torch
.
cat
(
all_k_pe_vllm
,
dim
=
0
)
k_pe_vllm
=
torch
.
cat
(
all_k_pe_vllm
,
dim
=
0
)
sdpa_output
=
torch
.
cat
(
all_sdpa_outputs
,
dim
=
0
)
sdpa_outputs
=
[]
for
i
,
backend
in
enumerate
(
BACKENDS_TO_TEST
):
sdpa_outputs
.
append
(
torch
.
cat
(
all_sdpa_outputs
[
i
],
dim
=
0
))
# Create mock kv_b_proj using the same weights as reference implementation
# Create mock kv_b_proj using the same weights as reference implementation
from
vllm.model_executor.layers.linear
import
ColumnParallelLinear
from
vllm.model_executor.layers.linear
import
ColumnParallelLinear
...
@@ -477,7 +472,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
...
@@ -477,7 +472,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
kv_c_contexts
=
kv_c_contexts
,
kv_c_contexts
=
kv_c_contexts
,
k_pe_contexts
=
k_pe_contexts
,
k_pe_contexts
=
k_pe_contexts
,
block_size
=
block_size
,
block_size
=
block_size
,
num_kv_heads
=
num_kv_heads
,
head_size
=
head_size
,
head_size
=
head_size
,
dtype
=
dtype
,
dtype
=
dtype
,
device
=
device
,
device
=
device
,
...
@@ -486,7 +480,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
...
@@ -486,7 +480,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
randomize_blocks
=
True
)
randomize_blocks
=
True
)
# 4. Run vLLM backends and compare
# 4. Run vLLM backends and compare
for
backend_name
in
BACKENDS_TO_TEST
:
for
i
,
backend_name
in
enumerate
(
BACKENDS_TO_TEST
)
:
backend_output
=
run_attention_backend
(
backend_output
=
run_attention_backend
(
backend_name
,
kv_cache_spec
,
[
"placeholder"
],
vllm_config
,
device
,
backend_name
,
kv_cache_spec
,
[
"placeholder"
],
vllm_config
,
device
,
common_attn_metadata
,
query_vllm
,
kv_c_vllm
,
k_pe_vllm
,
kv_cache
,
common_attn_metadata
,
query_vllm
,
kv_c_vllm
,
k_pe_vllm
,
kv_cache
,
...
@@ -494,12 +488,12 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
...
@@ -494,12 +488,12 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
mock_kv_b_proj
)
mock_kv_b_proj
)
# Check shape and dtype consistency
# Check shape and dtype consistency
assert
backend_output
.
shape
==
sdpa_output
.
shape
,
(
assert
backend_output
.
shape
==
sdpa_output
s
[
i
]
.
shape
,
(
f
"[
{
backend_name
}
] shape
{
backend_output
.
shape
}
!= "
f
"[
{
backend_name
}
] shape
{
backend_output
.
shape
}
!= "
f
"SDPA shape
{
sdpa_output
.
shape
}
"
)
f
"SDPA shape
{
sdpa_output
s
[
i
]
.
shape
}
"
)
assert
backend_output
.
dtype
==
sdpa_output
.
dtype
,
(
assert
backend_output
.
dtype
==
sdpa_output
s
[
i
]
.
dtype
,
(
f
"[
{
backend_name
}
] dtype
{
backend_output
.
dtype
}
!= "
f
"[
{
backend_name
}
] dtype
{
backend_output
.
dtype
}
!= "
f
"SDPA dtype
{
sdpa_output
.
dtype
}
"
)
f
"SDPA dtype
{
sdpa_output
s
[
i
]
.
dtype
}
"
)
assert
torch
.
isfinite
(
backend_output
).
all
(),
(
assert
torch
.
isfinite
(
backend_output
).
all
(),
(
f
"[
{
backend_name
}
] produced non-finite values"
)
f
"[
{
backend_name
}
] produced non-finite values"
)
...
@@ -508,12 +502,13 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
...
@@ -508,12 +502,13 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
rtol
=
1e-2
rtol
=
1e-2
atol
=
5e-1
atol
=
5e-1
max_diff
=
torch
.
max
(
torch
.
abs
(
backend_output
-
sdpa_output
)).
item
()
max_diff
=
torch
.
max
(
torch
.
abs
(
backend_output
-
sdpa_outputs
[
i
])).
item
()
max_rel_diff
=
torch
.
max
(
max_rel_diff
=
torch
.
max
(
torch
.
abs
(
backend_output
-
sdpa_output
)
/
torch
.
abs
(
backend_output
-
sdpa_output
s
[
i
]
)
/
torch
.
abs
(
sdpa_output
)).
item
()
torch
.
abs
(
sdpa_output
s
[
i
]
)).
item
()
all_close
=
torch
.
allclose
(
backend_output
,
all_close
=
torch
.
allclose
(
backend_output
,
sdpa_output
,
sdpa_output
s
[
i
]
,
rtol
=
rtol
,
rtol
=
rtol
,
atol
=
atol
)
atol
=
atol
)
...
...
tests/v1/attention/utils.py
View file @
38d80967
...
@@ -139,6 +139,8 @@ def get_attention_backend(backend_name: _Backend):
...
@@ -139,6 +139,8 @@ def get_attention_backend(backend_name: _Backend):
"vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
,
"vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
,
_Backend
.
FLASHMLA_VLLM_V1
:
_Backend
.
FLASHMLA_VLLM_V1
:
"vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
,
"vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
,
_Backend
.
FLASH_ATTN_MLA
:
"vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
,
_Backend
.
TRITON_MLA_VLLM_V1
:
_Backend
.
TRITON_MLA_VLLM_V1
:
"vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
,
"vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
,
}
}
...
...
tests/v1/core/test_kv_cache_utils.py
View file @
38d80967
...
@@ -6,20 +6,22 @@ from typing import Callable, Optional
...
@@ -6,20 +6,22 @@ from typing import Callable, Optional
import
pytest
import
pytest
import
torch
import
torch
import
vllm.v1.core.kv_cache_utils
as
kv_cache_utils
from
vllm.config
import
ModelConfig
,
SchedulerConfig
,
VllmConfig
from
vllm.config
import
ModelConfig
,
SchedulerConfig
,
VllmConfig
from
vllm.multimodal.inputs
import
(
MultiModalFeatureSpec
,
from
vllm.multimodal.inputs
import
(
MultiModalFeatureSpec
,
MultiModalKwargsItem
,
PlaceholderRange
)
MultiModalKwargsItem
,
PlaceholderRange
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
GiB_bytes
,
sha256
,
sha256_cbor
_64bit
from
vllm.utils
import
GiB_bytes
,
sha256
,
sha256_cbor
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
# disable yapf here as it formats differently than isort such that both fail
# disable yapf here as it formats differently than isort such that both fail
# yapf: disable
# yapf: disable
from
vllm.v1.core.kv_cache_utils
import
(
from
vllm.v1.core.kv_cache_utils
import
(
FreeKVCacheBlockQueue
,
KVCacheBlock
,
PrefixCachingMetrics
,
BlockHash
,
FreeKVCacheBlockQueue
,
KVCacheBlock
,
PrefixCachingMetrics
,
estimate_max_model_len
,
generate_block_hash_extra_keys
,
estimate_max_model_len
,
generate_block_hash_extra_keys
,
get_kv_cache_config
,
get_max_concurrency_for_kv_cache_config
,
get_kv_cache_config
,
get_max_concurrency_for_kv_cache_config
,
get_request_block_hasher
,
hash_block_tokens
,
init_none_hash
,
get_request_block_hasher
,
hash_block_tokens
,
init_none_hash
,
is_kv_cache_type_uniform
,
unify_kv_cache_configs
)
is_kv_cache_type_uniform
,
make_block_hash_with_group_id
,
unify_kv_cache_configs
)
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
,
KVCacheTensor
,
KVCacheGroupSpec
,
KVCacheTensor
,
SlidingWindowSpec
)
SlidingWindowSpec
)
...
@@ -88,7 +90,7 @@ def new_sliding_window_spec(block_size=16,
...
@@ -88,7 +90,7 @@ def new_sliding_window_spec(block_size=16,
sliding_window
=
sliding_window
)
sliding_window
=
sliding_window
)
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
_64bit
,
hash
])
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
])
def
test_none_hash
(
monkeypatch
,
hash_fn
):
def
test_none_hash
(
monkeypatch
,
hash_fn
):
import
vllm.v1.core.kv_cache_utils
import
vllm.v1.core.kv_cache_utils
...
@@ -98,8 +100,8 @@ def test_none_hash(monkeypatch, hash_fn):
...
@@ -98,8 +100,8 @@ def test_none_hash(monkeypatch, hash_fn):
reloaded_kv_cache_utils
=
importlib
.
reload
(
vllm
.
v1
.
core
.
kv_cache_utils
)
reloaded_kv_cache_utils
=
importlib
.
reload
(
vllm
.
v1
.
core
.
kv_cache_utils
)
reloaded_kv_cache_utils
.
init_none_hash
(
hash_fn
)
reloaded_kv_cache_utils
.
init_none_hash
(
hash_fn
)
assert
reloaded_kv_cache_utils
.
NONE_HASH
is
not
None
assert
reloaded_kv_cache_utils
.
NONE_HASH
is
not
None
assert
isinstance
(
reloaded_kv_cache_utils
.
NONE_HASH
,
int
)
assert
isinstance
(
reloaded_kv_cache_utils
.
NONE_HASH
,
bytes
)
assert
reloaded_kv_cache_utils
.
NONE_HASH
!=
0
assert
reloaded_kv_cache_utils
.
NONE_HASH
!=
b
""
# case 2: PYTHONHASHSEED is set, use the seed and hash_fn
# case 2: PYTHONHASHSEED is set, use the seed and hash_fn
with
monkeypatch
.
context
()
as
m
:
with
monkeypatch
.
context
()
as
m
:
...
@@ -107,12 +109,11 @@ def test_none_hash(monkeypatch, hash_fn):
...
@@ -107,12 +109,11 @@ def test_none_hash(monkeypatch, hash_fn):
reloaded_kv_cache_utils
=
importlib
.
reload
(
vllm
.
v1
.
core
.
kv_cache_utils
)
reloaded_kv_cache_utils
=
importlib
.
reload
(
vllm
.
v1
.
core
.
kv_cache_utils
)
reloaded_kv_cache_utils
.
init_none_hash
(
hash_fn
)
reloaded_kv_cache_utils
.
init_none_hash
(
hash_fn
)
assert
reloaded_kv_cache_utils
.
NONE_HASH
is
not
None
assert
reloaded_kv_cache_utils
.
NONE_HASH
is
not
None
assert
isinstance
(
reloaded_kv_cache_utils
.
NONE_HASH
,
int
)
assert
isinstance
(
reloaded_kv_cache_utils
.
NONE_HASH
,
bytes
)
assert
hash_fn
(
'python hash seed'
)
==
reloaded_kv_cache_utils
.
NONE_HASH
assert
hash_fn
(
'python hash seed'
)
==
reloaded_kv_cache_utils
.
NONE_HASH
def
test_kv_cache_block
():
def
test_kv_cache_block
():
import
vllm.v1.core.kv_cache_utils
# Test KVCacheBlock initialization
# Test KVCacheBlock initialization
block
=
KVCacheBlock
(
block_id
=
0
)
block
=
KVCacheBlock
(
block_id
=
0
)
...
@@ -127,8 +128,7 @@ def test_kv_cache_block():
...
@@ -127,8 +128,7 @@ def test_kv_cache_block():
assert
block
.
ref_cnt
==
0
assert
block
.
ref_cnt
==
0
# Test block hash setting and resetting
# Test block hash setting and resetting
block_hash
=
vllm
.
v1
.
core
.
kv_cache_utils
.
BlockHash
(
hash_value
=
123
,
block_hash
=
make_block_hash_with_group_id
(
BlockHash
(
b
"abc"
),
0
)
token_ids
=
(
1
,
2
,
3
))
block
.
block_hash
=
block_hash
block
.
block_hash
=
block_hash
assert
block
.
block_hash
==
block_hash
assert
block
.
block_hash
==
block_hash
...
@@ -247,7 +247,7 @@ def test_free_kv_cache_block_queue_append_n():
...
@@ -247,7 +247,7 @@ def test_free_kv_cache_block_queue_append_n():
def
test_free_kv_cache_block_queue_popleft_n
():
def
test_free_kv_cache_block_queue_popleft_n
():
blocks
=
[
KVCacheBlock
(
block_id
=
i
)
for
i
in
range
(
6
)]
blocks
=
[
KVCacheBlock
(
block_id
=
i
)
for
i
in
range
(
6
)]
# Create a empty FreeKVCacheBlockQueue with these blocks
# Create a
n
empty FreeKVCacheBlockQueue with these blocks
queue
=
FreeKVCacheBlockQueue
(
queue
=
FreeKVCacheBlockQueue
(
[
blocks
[
1
],
blocks
[
3
],
blocks
[
5
],
blocks
[
4
],
blocks
[
0
],
blocks
[
2
]])
[
blocks
[
1
],
blocks
[
3
],
blocks
[
5
],
blocks
[
4
],
blocks
[
0
],
blocks
[
2
]])
assert
queue
.
num_free_blocks
==
6
assert
queue
.
num_free_blocks
==
6
...
@@ -407,27 +407,23 @@ def test_generate_block_hash_extra_keys_cache_salt():
...
@@ -407,27 +407,23 @@ def test_generate_block_hash_extra_keys_cache_salt():
assert
next_mm_idx
==
1
assert
next_mm_idx
==
1
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
_64bit
,
hash
])
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
])
def
test_hash_block_tokens
(
hash_fn
):
def
test_hash_block_tokens
(
hash_fn
):
import
vllm.v1.core.kv_cache_utils
init_none_hash
(
hash_fn
)
init_none_hash
(
hash_fn
)
parent_block_hash
=
123
parent_block_hash
=
BlockHash
(
b
"
123
"
)
curr_block_token_ids
=
(
1
,
2
,
3
)
curr_block_token_ids
=
(
1
,
2
,
3
)
extra_keys
=
(
"key1"
,
"key2"
)
extra_keys
=
(
"key1"
,
"key2"
)
block_hash
=
hash_block_tokens
(
hash_fn
,
parent_block_hash
,
block_hash
=
hash_block_tokens
(
hash_fn
,
parent_block_hash
,
curr_block_token_ids
,
extra_keys
)
curr_block_token_ids
,
extra_keys
)
assert
isinstance
(
block_hash
,
vllm
.
v1
.
core
.
kv_cache_utils
.
BlockHash
)
expected
=
hash_fn
((
parent_block_hash
,
curr_block_token_ids
,
extra_keys
))
assert
block_hash
.
hash_value
==
hash_fn
(
assert
block_hash
==
expected
(
parent_block_hash
,
curr_block_token_ids
,
extra_keys
))
assert
block_hash
.
token_ids
==
curr_block_token_ids
assert
block_hash
.
extra_keys
==
extra_keys
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
_64bit
,
hash
])
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
])
def
test_request_block_hasher
(
hash_fn
):
def
test_request_block_hasher
(
hash_fn
):
import
vllm.v1.core.kv_cache_utils
kv_cache_utils
.
init_none_hash
(
hash_fn
)
init_none_hash
(
hash_fn
)
request
=
make_request
(
request
=
make_request
(
request_id
=
"0"
,
request_id
=
"0"
,
prompt_token_ids
=
[
_
for
_
in
range
(
6
)],
prompt_token_ids
=
[
_
for
_
in
range
(
6
)],
...
@@ -442,19 +438,13 @@ def test_request_block_hasher(hash_fn):
...
@@ -442,19 +438,13 @@ def test_request_block_hasher(hash_fn):
block_hashes
=
request
.
block_hashes
block_hashes
=
request
.
block_hashes
assert
len
(
block_hashes
)
==
2
assert
len
(
block_hashes
)
==
2
assert
isinstance
(
block_hashes
[
0
],
vllm
.
v1
.
core
.
kv_cache_utils
.
BlockHash
)
assert
block_hashes
[
0
]
==
hash_fn
(
assert
isinstance
(
block_hashes
[
1
],
vllm
.
v1
.
core
.
kv_cache_utils
.
BlockHash
)
(
kv_cache_utils
.
NONE_HASH
,
(
0
,
1
,
2
),
(
"hash1"
,
)))
assert
block_hashes
[
1
]
==
hash_fn
(
# Check the first block
(
block_hashes
[
0
],
(
3
,
4
,
5
),
(
"hash2"
,
)))
assert
block_hashes
[
0
].
token_ids
==
(
0
,
1
,
2
)
assert
block_hashes
[
0
].
extra_keys
==
(
"hash1"
,
)
# Check the second block
assert
block_hashes
[
1
].
token_ids
==
(
3
,
4
,
5
)
assert
block_hashes
[
1
].
extra_keys
==
(
"hash2"
,
)
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
])
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor_64bit
,
hash
])
def
test_hash_tokens_different_mm_input
(
hash_fn
):
def
test_hash_tokens_different_mm_input
(
hash_fn
):
init_none_hash
(
hash_fn
)
init_none_hash
(
hash_fn
)
...
@@ -484,9 +474,9 @@ def test_hash_tokens_different_mm_input(hash_fn):
...
@@ -484,9 +474,9 @@ def test_hash_tokens_different_mm_input(hash_fn):
assert
block_hashes1
[
1
]
!=
block_hashes2
[
1
]
assert
block_hashes1
[
1
]
!=
block_hashes2
[
1
]
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
_64bit
,
hash
])
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
])
def
test_hash_request_tokens_no_mm_inputs
(
hash_fn
):
def
test_hash_request_tokens_no_mm_inputs
(
hash_fn
):
init_none_hash
(
hash_fn
)
kv_cache_utils
.
init_none_hash
(
hash_fn
)
request
=
make_request
(
request
=
make_request
(
request_id
=
"0"
,
request_id
=
"0"
,
...
@@ -500,10 +490,9 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn):
...
@@ -500,10 +490,9 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn):
block_hashes
=
request
.
block_hashes
block_hashes
=
request
.
block_hashes
assert
len
(
block_hashes
)
==
2
assert
len
(
block_hashes
)
==
2
assert
block_hashes
[
0
].
token_ids
==
(
0
,
1
,
2
)
assert
block_hashes
[
0
]
==
hash_fn
(
assert
block_hashes
[
0
].
extra_keys
is
None
(
kv_cache_utils
.
NONE_HASH
,
(
0
,
1
,
2
),
None
))
assert
block_hashes
[
1
].
token_ids
==
(
3
,
4
,
5
)
assert
block_hashes
[
1
]
==
hash_fn
((
block_hashes
[
0
],
(
3
,
4
,
5
),
None
))
assert
block_hashes
[
1
].
extra_keys
is
None
def
test_metrics
():
def
test_metrics
():
...
...
tests/v1/core/test_prefix_caching.py
View file @
38d80967
...
@@ -8,17 +8,19 @@ from typing import Callable, Optional
...
@@ -8,17 +8,19 @@ from typing import Callable, Optional
import
pytest
import
pytest
import
torch
import
torch
import
vllm.v1.core.kv_cache_utils
as
kv_cache_utils
from
vllm.distributed.kv_events
import
AllBlocksCleared
,
BlockRemoved
from
vllm.distributed.kv_events
import
AllBlocksCleared
,
BlockRemoved
from
vllm.multimodal.inputs
import
(
MultiModalFeatureSpec
,
from
vllm.multimodal.inputs
import
(
MultiModalFeatureSpec
,
MultiModalKwargsItem
,
PlaceholderRange
)
MultiModalKwargsItem
,
PlaceholderRange
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
sha256
,
sha256_cbor
_64bit
from
vllm.utils
import
sha256
,
sha256_cbor
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
,
Request
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
,
Request
from
vllm.v1.core.kv_cache_utils
import
(
BlockHash
,
BlockHashWithGroupId
,
from
vllm.v1.core.kv_cache_utils
import
(
BlockHash
,
KVCacheBlock
,
KVCacheBlock
,
get_block_hash
,
get_group_id
,
get_request_block_hasher
,
get_request_block_hasher
,
hash_block_tokens
,
init_none_hash
)
hash_block_tokens
,
init_none_hash
,
make_block_hash_with_group_id
)
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
,
SlidingWindowSpec
)
KVCacheGroupSpec
,
SlidingWindowSpec
)
...
@@ -101,8 +103,10 @@ def make_kv_cache_config_hybrid_model(block_size: int,
...
@@ -101,8 +103,10 @@ def make_kv_cache_config_hybrid_model(block_size: int,
)
)
@
pytest
.
mark
.
parametrize
(
"hash_algo"
,
[
"sha256"
,
"sha256_cbor_64bit"
,
"hash"
])
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
])
def
test_prefill
(
hash_algo
):
def
test_prefill
(
hash_fn
):
init_none_hash
(
hash_fn
)
block_size
=
16
block_size
=
16
manager
=
KVCacheManager
(
manager
=
KVCacheManager
(
make_kv_cache_config
(
block_size
,
11
),
make_kv_cache_config
(
block_size
,
11
),
...
@@ -110,10 +114,6 @@ def test_prefill(hash_algo):
...
@@ -110,10 +114,6 @@ def test_prefill(hash_algo):
enable_caching
=
True
,
enable_caching
=
True
,
)
)
# choose the hash function according to the parameter
hash_fn
=
(
sha256_cbor_64bit
if
hash_algo
==
"sha256_cbor_64bit"
else
sha256
if
hash_algo
==
"sha256"
else
hash
)
# Complete 3 blocks (48 tokens)
# Complete 3 blocks (48 tokens)
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
...
@@ -137,10 +137,12 @@ def test_prefill(hash_algo):
...
@@ -137,10 +137,12 @@ def test_prefill(hash_algo):
block_tokens
=
tuple
(
all_token_ids
[(
block_id
-
1
)
*
16
:
block_id
*
16
])
block_tokens
=
tuple
(
all_token_ids
[(
block_id
-
1
)
*
16
:
block_id
*
16
])
block_hash
=
hash_block_tokens
(
hash_fn
,
parent_block_hash
,
block_hash
=
hash_block_tokens
(
hash_fn
,
parent_block_hash
,
block_tokens
)
block_tokens
)
assert
manager
.
block_pool
.
blocks
[
blk_hash
=
manager
.
block_pool
.
blocks
[
block_id
].
block_hash
block_id
].
block_hash
.
block_hash
==
block_hash
assert
blk_hash
is
not
None
assert
get_block_hash
(
blk_hash
)
==
block_hash
assert
get_group_id
(
blk_hash
)
==
0
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
parent_block_hash
=
block_hash
.
hash_value
parent_block_hash
=
block_hash
# Check partial block metadata
# Check partial block metadata
for
block_id
in
(
4
,
):
for
block_id
in
(
4
,
):
...
@@ -233,7 +235,7 @@ def test_prefill_hybrid_model():
...
@@ -233,7 +235,7 @@ def test_prefill_hybrid_model():
enable_caching
=
True
,
enable_caching
=
True
,
)
)
hash_fn
=
ha
sh
hash_fn
=
sh
a256
# Complete 3 blocks (48 tokens)
# Complete 3 blocks (48 tokens)
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
block_size
)]
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
block_size
)]
...
@@ -260,11 +262,13 @@ def test_prefill_hybrid_model():
...
@@ -260,11 +262,13 @@ def test_prefill_hybrid_model():
block_tokens
=
tuple
(
all_token_ids
[(
length
-
1
)
*
16
:
length
*
16
])
block_tokens
=
tuple
(
all_token_ids
[(
length
-
1
)
*
16
:
length
*
16
])
block_hash
=
hash_block_tokens
(
hash_fn
,
parent_block_hash
,
block_hash
=
hash_block_tokens
(
hash_fn
,
parent_block_hash
,
block_tokens
)
block_tokens
)
for
block_id
in
block_ids
:
for
group_id
,
block_id
in
enumerate
(
block_ids
):
assert
manager
.
block_pool
.
blocks
[
blk_hash
=
manager
.
block_pool
.
blocks
[
block_id
].
block_hash
block_id
].
block_hash
.
block_hash
==
block_hash
assert
blk_hash
is
not
None
assert
get_block_hash
(
blk_hash
)
==
block_hash
assert
get_group_id
(
blk_hash
)
==
group_id
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
parent_block_hash
=
block_hash
.
hash_value
parent_block_hash
=
block_hash
# Check partial block metadata
# Check partial block metadata
for
block_id
in
(
4
,
8
,
12
):
for
block_id
in
(
4
,
8
,
12
):
...
@@ -298,11 +302,10 @@ def test_prefill_hybrid_model():
...
@@ -298,11 +302,10 @@ def test_prefill_hybrid_model():
cached_block_hash_to_block_bak
=
copy
.
copy
(
cached_block_hash_to_block_bak
=
copy
.
copy
(
manager
.
block_pool
.
cached_block_hash_to_block
)
manager
.
block_pool
.
cached_block_hash_to_block
)
def
test_partial_request_hit
(
request_id
:
str
,
def
test_partial_request_hit
(
request_id
:
str
,
hash_to_evict
:
list
[
bytes
],
hash_to_evict
:
list
[
BlockHashWithGroupId
],
expect_hit_length
:
int
):
expect_hit_length
:
int
):
req
=
make_request
(
request_id
,
common_token_ids
+
unique_token_ids
,
req
=
make_request
(
request_id
,
common_token_ids
+
unique_token_ids
,
block_size
,
ha
sh
)
block_size
,
sh
a256
)
for
hash_with_group_id
in
hash_to_evict
:
for
hash_with_group_id
in
hash_to_evict
:
manager
.
block_pool
.
cached_block_hash_to_block
.
pop
(
manager
.
block_pool
.
cached_block_hash_to_block
.
pop
(
hash_with_group_id
)
hash_with_group_id
)
...
@@ -319,33 +322,32 @@ def test_prefill_hybrid_model():
...
@@ -319,33 +322,32 @@ def test_prefill_hybrid_model():
# Evict the blocks outside sliding window, does not affect the hit length.
# Evict the blocks outside sliding window, does not affect the hit length.
test_partial_request_hit
(
"2"
,
[
test_partial_request_hit
(
"2"
,
[
B
lock
H
ash
W
ith
G
roup
I
d
(
block_hashes
[
0
],
1
),
make_b
lock
_h
ash
_w
ith
_g
roup
_i
d
(
block_hashes
[
0
],
1
),
B
lock
H
ash
W
ith
G
roup
I
d
(
block_hashes
[
0
],
2
)
make_b
lock
_h
ash
_w
ith
_g
roup
_i
d
(
block_hashes
[
0
],
2
)
],
3
)
],
3
)
# Evict the first block of full attention, makes total cache miss.
# Evict the first block of full attention, makes total cache miss.
test_partial_request_hit
(
"3"
,
[
test_partial_request_hit
(
BlockHashWithGroupId
(
block_hashes
[
0
],
0
),
"3"
,
[
make_block_hash_with_group_id
(
block_hashes
[
0
],
0
)],
0
)
],
0
)
# Evict the last block of all layers, reduces the hit length to 2.
# Evict the last block of all layers, reduces the hit length to 2.
test_partial_request_hit
(
"4"
,
[
test_partial_request_hit
(
"4"
,
[
B
lock
H
ash
W
ith
G
roup
I
d
(
block_hashes
[
2
],
0
),
make_b
lock
_h
ash
_w
ith
_g
roup
_i
d
(
block_hashes
[
2
],
0
),
B
lock
H
ash
W
ith
G
roup
I
d
(
block_hashes
[
2
],
1
),
make_b
lock
_h
ash
_w
ith
_g
roup
_i
d
(
block_hashes
[
2
],
1
),
B
lock
H
ash
W
ith
G
roup
I
d
(
block_hashes
[
2
],
2
),
make_b
lock
_h
ash
_w
ith
_g
roup
_i
d
(
block_hashes
[
2
],
2
),
],
2
)
],
2
)
# Evict the last block of full attention, reduces the hit length to 2.
# Evict the last block of full attention, reduces the hit length to 2.
test_partial_request_hit
(
"5"
,
[
BlockHashWithGroupId
(
block_hashes
[
2
],
0
)],
test_partial_request_hit
(
2
)
"5"
,
[
make_block_hash_with_group_id
(
block_hashes
[
2
],
0
)],
2
)
# Evict the last block of sliding window, reduces the hit length to 2.
# Evict the last block of sliding window, reduces the hit length to 2.
test_partial_request_hit
(
"6"
,
[
BlockHashWithGroupId
(
block_hashes
[
2
],
1
)],
test_partial_request_hit
(
2
)
"6"
,
[
make_block_hash_with_group_id
(
block_hashes
[
2
],
1
)],
2
)
# Evict the last block of sliding window, reduces the hit length to 2.
# Evict the last block of sliding window, reduces the hit length to 2.
test_partial_request_hit
(
"7"
,
[
BlockHashWithGroupId
(
block_hashes
[
2
],
2
)],
test_partial_request_hit
(
2
)
"7"
,
[
make_block_hash_with_group_id
(
block_hashes
[
2
],
2
)],
2
)
# Evict different set of blocks for full attention and sliding window makes
# Evict different set of blocks for full attention and sliding window makes
# total cache miss.
# total cache miss.
...
@@ -353,9 +355,9 @@ def test_prefill_hybrid_model():
...
@@ -353,9 +355,9 @@ def test_prefill_hybrid_model():
# The cache hit length of sliding window is 2 * block_size.
# The cache hit length of sliding window is 2 * block_size.
# Then it is cache miss as the two type of layers have different hit length.
# Then it is cache miss as the two type of layers have different hit length.
test_partial_request_hit
(
"8"
,
[
test_partial_request_hit
(
"8"
,
[
B
lock
H
ash
W
ith
G
roup
I
d
(
block_hashes
[
2
],
0
),
make_b
lock
_h
ash
_w
ith
_g
roup
_i
d
(
block_hashes
[
2
],
0
),
B
lock
H
ash
W
ith
G
roup
I
d
(
block_hashes
[
0
],
1
),
make_b
lock
_h
ash
_w
ith
_g
roup
_i
d
(
block_hashes
[
0
],
1
),
B
lock
H
ash
W
ith
G
roup
I
d
(
block_hashes
[
0
],
2
),
make_b
lock
_h
ash
_w
ith
_g
roup
_i
d
(
block_hashes
[
0
],
2
),
],
0
)
],
0
)
...
@@ -372,8 +374,8 @@ def test_prefill_plp():
...
@@ -372,8 +374,8 @@ def test_prefill_plp():
max_model_len
=
8192
,
max_model_len
=
8192
,
enable_caching
=
True
,
enable_caching
=
True
,
)
)
# the default hash function is
ha
sh
# the default hash function is sh
a256
hash_fn
=
ha
sh
hash_fn
=
sh
a256
# Complete 3 blocks (48 tokens)
# Complete 3 blocks (48 tokens)
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
...
@@ -404,10 +406,12 @@ def test_prefill_plp():
...
@@ -404,10 +406,12 @@ def test_prefill_plp():
block_tokens
=
tuple
(
all_token_ids
[(
block_id
-
1
)
*
16
:
block_id
*
16
])
block_tokens
=
tuple
(
all_token_ids
[(
block_id
-
1
)
*
16
:
block_id
*
16
])
block_hash
=
hash_block_tokens
(
hash_fn
,
parent_block_hash
,
block_hash
=
hash_block_tokens
(
hash_fn
,
parent_block_hash
,
block_tokens
)
block_tokens
)
assert
manager
.
block_pool
.
blocks
[
blk_hash
=
(
manager
.
block_pool
.
blocks
[
block_id
].
block_hash
)
block_id
].
block_hash
.
block_hash
==
block_hash
assert
blk_hash
is
not
None
assert
get_block_hash
(
blk_hash
)
==
block_hash
assert
get_group_id
(
blk_hash
)
==
0
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
parent_block_hash
=
block_hash
.
hash_value
parent_block_hash
=
block_hash
# Check partial block metadata
# Check partial block metadata
for
block_id
in
(
4
,
):
for
block_id
in
(
4
,
):
...
@@ -493,7 +497,7 @@ def test_decode():
...
@@ -493,7 +497,7 @@ def test_decode():
# Incomplete 1 block (7 tokens)
# Incomplete 1 block (7 tokens)
unique_token_ids
=
[
3
]
*
7
unique_token_ids
=
[
3
]
*
7
req0
=
make_request
(
"0"
,
common_token_ids
+
unique_token_ids
,
block_size
,
req0
=
make_request
(
"0"
,
common_token_ids
+
unique_token_ids
,
block_size
,
ha
sh
)
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -538,7 +542,7 @@ def test_evict():
...
@@ -538,7 +542,7 @@ def test_evict():
)
)
last_token_id
=
5
*
16
+
7
last_token_id
=
5
*
16
+
7
req0
=
make_request
(
"0"
,
list
(
range
(
last_token_id
)),
block_size
,
ha
sh
)
req0
=
make_request
(
"0"
,
list
(
range
(
last_token_id
)),
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -550,7 +554,7 @@ def test_evict():
...
@@ -550,7 +554,7 @@ def test_evict():
# 3 blocks.
# 3 blocks.
req1
=
make_request
(
"1"
,
list
(
range
(
last_token_id
,
req1
=
make_request
(
"1"
,
list
(
range
(
last_token_id
,
last_token_id
+
3
*
16
)),
block_size
,
last_token_id
+
3
*
16
)),
block_size
,
ha
sh
)
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -572,7 +576,7 @@ def test_evict():
...
@@ -572,7 +576,7 @@ def test_evict():
]
==
[
10
,
6
,
5
,
4
,
3
,
2
,
1
,
9
,
8
,
7
]
]
==
[
10
,
6
,
5
,
4
,
3
,
2
,
1
,
9
,
8
,
7
]
# Touch the first 2 blocks.
# Touch the first 2 blocks.
req2
=
make_request
(
"2"
,
list
(
range
(
2
*
16
+
3
)),
block_size
,
ha
sh
)
req2
=
make_request
(
"2"
,
list
(
range
(
2
*
16
+
3
)),
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
computed_blocks
.
get_block_ids
()
==
([
1
,
2
],
)
assert
computed_blocks
.
get_block_ids
()
==
([
1
,
2
],
)
assert
num_computed_tokens
==
2
*
16
assert
num_computed_tokens
==
2
*
16
...
@@ -597,7 +601,7 @@ def test_hash_block_correct_reuse():
...
@@ -597,7 +601,7 @@ def test_hash_block_correct_reuse():
# Allocate 1 block and cache it.
# Allocate 1 block and cache it.
num_tokens
=
block_size
*
1
num_tokens
=
block_size
*
1
req
=
make_request
(
"0"
,
list
(
range
(
num_tokens
)),
block_size
,
ha
sh
)
req
=
make_request
(
"0"
,
list
(
range
(
num_tokens
)),
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -611,7 +615,7 @@ def test_hash_block_correct_reuse():
...
@@ -611,7 +615,7 @@ def test_hash_block_correct_reuse():
# Allocate a new block that's not full, make sure hash info on the
# Allocate a new block that's not full, make sure hash info on the
# block is cleared.
# block is cleared.
req
=
make_request
(
"1"
,
list
(
range
(
num_tokens
-
1
)),
block_size
,
ha
sh
)
req
=
make_request
(
"1"
,
list
(
range
(
num_tokens
-
1
)),
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -638,7 +642,7 @@ def test_computed_blocks_not_evicted():
...
@@ -638,7 +642,7 @@ def test_computed_blocks_not_evicted():
# Allocate a block and cache it.
# Allocate a block and cache it.
num_tokens
=
block_size
*
1
num_tokens
=
block_size
*
1
req0
=
make_request
(
"0"
,
list
(
range
(
num_tokens
)),
block_size
,
ha
sh
)
req0
=
make_request
(
"0"
,
list
(
range
(
num_tokens
)),
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -650,7 +654,7 @@ def test_computed_blocks_not_evicted():
...
@@ -650,7 +654,7 @@ def test_computed_blocks_not_evicted():
# Allocate another block.
# Allocate another block.
req1
=
make_request
(
"1"
,
list
(
range
(
num_tokens
,
num_tokens
*
2
)),
req1
=
make_request
(
"1"
,
list
(
range
(
num_tokens
,
num_tokens
*
2
)),
block_size
,
ha
sh
)
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -666,7 +670,7 @@ def test_computed_blocks_not_evicted():
...
@@ -666,7 +670,7 @@ def test_computed_blocks_not_evicted():
# Now if we have a cache hit on the first block, we should evict the second
# Now if we have a cache hit on the first block, we should evict the second
# cached block rather than the first one.
# cached block rather than the first one.
req2
=
make_request
(
"2"
,
list
(
range
(
num_tokens
*
2
)),
block_size
,
ha
sh
)
req2
=
make_request
(
"2"
,
list
(
range
(
num_tokens
*
2
)),
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
len
(
computed_blocks
.
blocks
[
0
])
==
1
assert
len
(
computed_blocks
.
blocks
[
0
])
==
1
assert
computed_blocks
.
blocks
[
0
][
0
].
block_id
==
1
assert
computed_blocks
.
blocks
[
0
][
0
].
block_id
==
1
...
@@ -691,7 +695,7 @@ def test_basic_prefix_caching_disabled():
...
@@ -691,7 +695,7 @@ def test_basic_prefix_caching_disabled():
)
)
req1
=
make_request
(
"1"
,
list
(
range
(
10
)),
block_size
,
req1
=
make_request
(
"1"
,
list
(
range
(
10
)),
block_size
,
ha
sh
)
# 2 blocks and some more
sh
a256
)
# 2 blocks and some more
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
...
@@ -706,7 +710,7 @@ def test_basic_prefix_caching_disabled():
...
@@ -706,7 +710,7 @@ def test_basic_prefix_caching_disabled():
# No caching.
# No caching.
req2
=
make_request
(
"2"
,
list
(
range
(
16
)),
block_size
,
req2
=
make_request
(
"2"
,
list
(
range
(
16
)),
block_size
,
ha
sh
)
# shared prefix
sh
a256
)
# shared prefix
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -716,7 +720,7 @@ def test_basic_prefix_caching_disabled():
...
@@ -716,7 +720,7 @@ def test_basic_prefix_caching_disabled():
assert
len
(
blocks
.
blocks
[
0
])
==
4
assert
len
(
blocks
.
blocks
[
0
])
==
4
# New requests should not have any blocks.
# New requests should not have any blocks.
req3
=
make_request
(
"3"
,
list
(
range
(
4
)),
block_size
,
ha
sh
)
req3
=
make_request
(
"3"
,
list
(
range
(
4
)),
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req3
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req3
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -726,7 +730,7 @@ def test_basic_prefix_caching_disabled():
...
@@ -726,7 +730,7 @@ def test_basic_prefix_caching_disabled():
assert
not
blocks
assert
not
blocks
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
_64bit
,
hash
])
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
])
def
test_cache_blocks
(
hash_fn
):
def
test_cache_blocks
(
hash_fn
):
"""
"""
This is a unit test that tests the correctness of the _cache_full_blocks
This is a unit test that tests the correctness of the _cache_full_blocks
...
@@ -787,7 +791,7 @@ def test_cache_blocks_multi_group():
...
@@ -787,7 +791,7 @@ def test_cache_blocks_multi_group():
# Block 1/5: [4, 5, 6, 7]
# Block 1/5: [4, 5, 6, 7]
# Block 2/6: [8, 9, 10, 11]
# Block 2/6: [8, 9, 10, 11]
# Block 3/7: [12, 13]
# Block 3/7: [12, 13]
req
=
make_request
(
"0"
,
list
(
range
(
14
)),
block_size
,
ha
sh
)
req
=
make_request
(
"0"
,
list
(
range
(
14
)),
block_size
,
sh
a256
)
# Cache the blocks for group 0.
# Cache the blocks for group 0.
blocks
=
[
KVCacheBlock
(
block_id
=
i
)
for
i
in
range
(
2
)]
blocks
=
[
KVCacheBlock
(
block_id
=
i
)
for
i
in
range
(
2
)]
...
@@ -845,6 +849,8 @@ def test_mm_prefix_caching():
...
@@ -845,6 +849,8 @@ def test_mm_prefix_caching():
"""
"""
This tests that the multi-modal prefix caching is correct.
This tests that the multi-modal prefix caching is correct.
"""
"""
kv_cache_utils
.
init_none_hash
(
sha256
)
block_size
=
16
block_size
=
16
manager
=
KVCacheManager
(
manager
=
KVCacheManager
(
make_kv_cache_config
(
block_size
,
11
),
make_kv_cache_config
(
block_size
,
11
),
...
@@ -874,23 +880,30 @@ def test_mm_prefix_caching():
...
@@ -874,23 +880,30 @@ def test_mm_prefix_caching():
req0
=
make_request
(
"0"
,
req0
=
make_request
(
"0"
,
all_token_ids
,
all_token_ids
,
block_size
,
block_size
,
ha
sh
,
sh
a256
,
mm_positions
=
mm_positions
,
mm_positions
=
mm_positions
,
mm_hashes
=
mm_hashes
)
mm_hashes
=
mm_hashes
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
# Completed block should have hashes
with extra keys.
# Completed block should have hashes
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
block_hashes
=
req0
.
block_hashes
block_hashes
=
req0
.
block_hashes
assert
len
(
block_hashes
)
==
3
assert
len
(
block_hashes
)
==
3
assert
block_hashes
[
0
].
extra_keys
==
(
"aaa"
,
)
assert
block_hashes
[
0
]
==
sha256
(
assert
block_hashes
[
1
].
extra_keys
==
(
"aaa"
,
"bbb"
)
(
kv_cache_utils
.
NONE_HASH
,
tuple
(
all_token_ids
[:
block_size
]),
assert
block_hashes
[
2
].
extra_keys
==
(
"bbb"
,
)
(
"aaa"
,
)))
assert
block_hashes
[
1
]
==
sha256
(
(
block_hashes
[
0
],
tuple
(
all_token_ids
[
block_size
:
block_size
*
2
]),
(
"aaa"
,
"bbb"
)))
assert
block_hashes
[
2
]
==
sha256
(
(
block_hashes
[
1
],
tuple
(
all_token_ids
[
block_size
*
2
:
block_size
*
3
]),
(
"bbb"
,
)))
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
is
not
None
assert
blocks
.
get_block_ids
()
==
([
1
,
2
,
3
,
4
],
)
assert
blocks
.
get_block_ids
()
==
([
1
,
2
,
3
,
4
],
)
req0
.
num_computed_tokens
=
59
req0
.
num_computed_tokens
=
59
...
@@ -901,10 +914,10 @@ def test_mm_prefix_caching():
...
@@ -901,10 +914,10 @@ def test_mm_prefix_caching():
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
computed_blocks
)
assert
new_blocks
is
not
None
and
len
(
new_blocks
.
blocks
[
0
])
==
0
assert
new_blocks
is
not
None
and
len
(
new_blocks
.
blocks
[
0
])
==
0
# The just completed block should have hashes with extra keys.
assert
len
(
block_hashes
)
==
4
assert
len
(
block_hashes
)
==
4
assert
block_hashes
[
3
].
extra_keys
==
(
"ccc"
,
)
assert
block_hashes
[
3
]
==
sha256
(
(
block_hashes
[
2
],
tuple
(
all_token_ids
[
3
*
block_size
:]
+
[
8
]
*
5
),
(
"ccc"
,
)))
# Cache hit.
# Cache hit.
unique_token_ids
=
[
-
1
]
*
7
+
[
200
]
*
5
unique_token_ids
=
[
-
1
]
*
7
+
[
200
]
*
5
...
@@ -916,7 +929,7 @@ def test_mm_prefix_caching():
...
@@ -916,7 +929,7 @@ def test_mm_prefix_caching():
req1
=
make_request
(
"1"
,
req1
=
make_request
(
"1"
,
all_token_ids
,
all_token_ids
,
block_size
,
block_size
,
ha
sh
,
sh
a256
,
mm_positions
=
mm_positions
,
mm_positions
=
mm_positions
,
mm_hashes
=
mm_hashes
)
mm_hashes
=
mm_hashes
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
...
@@ -929,6 +942,8 @@ def test_cache_key_salting():
...
@@ -929,6 +942,8 @@ def test_cache_key_salting():
This tests that cache salts are applied during hashing and the cache
This tests that cache salts are applied during hashing and the cache
is separated cache as expected.
is separated cache as expected.
"""
"""
kv_cache_utils
.
init_none_hash
(
sha256
)
block_size
=
16
block_size
=
16
manager
=
KVCacheManager
(
manager
=
KVCacheManager
(
make_kv_cache_config
(
block_size
,
11
),
make_kv_cache_config
(
block_size
,
11
),
...
@@ -939,21 +954,26 @@ def test_cache_key_salting():
...
@@ -939,21 +954,26 @@ def test_cache_key_salting():
# 3 complete blocks and an incomplete block with 11 tokens.
# 3 complete blocks and an incomplete block with 11 tokens.
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
block_size
)]
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
block_size
)]
token_ids
=
common_token_ids
+
[
3
]
*
11
token_ids
=
common_token_ids
+
[
3
]
*
11
req0
=
make_request
(
"0"
,
token_ids
,
block_size
,
ha
sh
,
cache_salt
=
"salt1"
)
req0
=
make_request
(
"0"
,
token_ids
,
block_size
,
sh
a256
,
cache_salt
=
"salt1"
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
# Completed block should have hashes
with extra keys.
# Completed block should have hashes
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
block_hashes
=
req0
.
block_hashes
block_hashes
=
req0
.
block_hashes
assert
len
(
block_hashes
)
==
3
assert
len
(
block_hashes
)
==
3
assert
block_hashes
[
0
].
extra_keys
==
(
"salt1"
,
)
assert
block_hashes
[
0
]
==
sha256
(
assert
block_hashes
[
1
].
extra_keys
is
None
(
kv_cache_utils
.
NONE_HASH
,
tuple
(
token_ids
[:
block_size
]),
(
"salt1"
,
)))
assert
block_hashes
[
2
].
extra_keys
is
None
assert
block_hashes
[
1
]
==
sha256
(
(
block_hashes
[
0
],
tuple
(
token_ids
[
block_size
:
block_size
*
2
]),
None
))
assert
block_hashes
[
2
]
==
sha256
(
(
block_hashes
[
1
],
tuple
(
token_ids
[
block_size
*
2
:
block_size
*
3
]),
None
))
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
is
not
None
assert
blocks
.
get_block_ids
()
==
([
1
,
2
,
3
,
4
],
)
assert
blocks
.
get_block_ids
()
==
([
1
,
2
,
3
,
4
],
)
req0
.
num_computed_tokens
=
59
req0
.
num_computed_tokens
=
59
...
@@ -964,14 +984,13 @@ def test_cache_key_salting():
...
@@ -964,14 +984,13 @@ def test_cache_key_salting():
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
computed_blocks
)
assert
new_blocks
is
not
None
and
len
(
new_blocks
.
blocks
[
0
])
==
0
assert
new_blocks
is
not
None
and
len
(
new_blocks
.
blocks
[
0
])
==
0
# Now one more block that should not have extra keys.
assert
len
(
block_hashes
)
==
4
assert
len
(
block_hashes
)
==
4
assert
block_hashes
[
3
].
extra_keys
is
None
assert
block_hashes
[
3
]
==
sha256
(
(
block_hashes
[
2
],
tuple
(
token_ids
[
3
*
block_size
:]
+
[
8
]
*
5
),
None
))
# Test cache hit with a new request that has the same salt.
# Test cache hit with a new request that has the same salt.
token_ids
=
common_token_ids
+
[
4
]
*
11
token_ids
=
common_token_ids
+
[
4
]
*
11
req1
=
make_request
(
"1"
,
token_ids
,
block_size
,
ha
sh
,
cache_salt
=
"salt1"
)
req1
=
make_request
(
"1"
,
token_ids
,
block_size
,
sh
a256
,
cache_salt
=
"salt1"
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
# Should match only a prefix of 3 blocks.
# Should match only a prefix of 3 blocks.
assert
len
(
computed_blocks
.
blocks
[
0
])
==
3
assert
len
(
computed_blocks
.
blocks
[
0
])
==
3
...
@@ -979,13 +998,19 @@ def test_cache_key_salting():
...
@@ -979,13 +998,19 @@ def test_cache_key_salting():
# Test cache miss with same content but different salt.
# Test cache miss with same content but different salt.
token_ids
=
common_token_ids
+
[
4
]
*
11
token_ids
=
common_token_ids
+
[
4
]
*
11
req2
=
make_request
(
"2"
,
token_ids
,
block_size
,
ha
sh
,
cache_salt
=
"salt2"
)
req2
=
make_request
(
"2"
,
token_ids
,
block_size
,
sh
a256
,
cache_salt
=
"salt2"
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
len
(
computed_blocks
.
blocks
[
0
])
==
0
assert
len
(
computed_blocks
.
blocks
[
0
])
==
0
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
block_hashes
=
req2
.
block_hashes
block_hashes
=
req2
.
block_hashes
assert
len
(
block_hashes
)
==
3
assert
len
(
block_hashes
)
==
3
assert
block_hashes
[
0
].
extra_keys
==
(
"salt2"
,
)
assert
block_hashes
[
0
]
==
sha256
(
(
kv_cache_utils
.
NONE_HASH
,
tuple
(
token_ids
[:
block_size
]),
(
"salt2"
,
)))
assert
block_hashes
[
1
]
==
sha256
(
(
block_hashes
[
0
],
tuple
(
token_ids
[
block_size
:
block_size
*
2
]),
None
))
assert
block_hashes
[
2
]
==
sha256
(
(
block_hashes
[
1
],
tuple
(
token_ids
[
block_size
*
2
:
block_size
*
3
]),
None
))
def
test_prefill_not_enough_free_blocks_with_computed_blocks
():
def
test_prefill_not_enough_free_blocks_with_computed_blocks
():
...
@@ -1004,7 +1029,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
...
@@ -1004,7 +1029,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# Complete 3 blocks (48 tokens)
# Complete 3 blocks (48 tokens)
# | Common-0 | Common-1 | Common-2 | ... |
# | Common-0 | Common-1 | Common-2 | ... |
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
req0
=
make_request
(
"0"
,
common_token_ids
,
block_size
,
ha
sh
)
req0
=
make_request
(
"0"
,
common_token_ids
,
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -1015,7 +1040,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
...
@@ -1015,7 +1040,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
req0
.
request_id
]
req0
.
request_id
]
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
req1
=
make_request
(
"1"
,
common_token_ids
*
2
,
block_size
,
ha
sh
)
req1
=
make_request
(
"1"
,
common_token_ids
*
2
,
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
computed_blocks
.
blocks
[
0
]
==
block_part0
assert
computed_blocks
.
blocks
[
0
]
==
block_part0
assert
num_computed_tokens
==
3
*
16
assert
num_computed_tokens
==
3
*
16
...
@@ -1032,7 +1057,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
...
@@ -1032,7 +1057,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
# | Req1-5(F)| Req2-0 | Req2-1 | ... |
# | Req1-5(F)| Req2-0 | Req2-1 | ... |
req2
=
make_request
(
"2"
,
[
7
]
*
block_size
*
2
,
block_size
,
ha
sh
)
req2
=
make_request
(
"2"
,
[
7
]
*
block_size
*
2
,
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -1044,7 +1069,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
...
@@ -1044,7 +1069,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# but it cannot be allocated due to insufficient free blocks (2).
# but it cannot be allocated due to insufficient free blocks (2).
# In this case, the ref_cnt of the computed blocks should not be changed.
# In this case, the ref_cnt of the computed blocks should not be changed.
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
5
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
5
req3
=
make_request
(
"3"
,
common_token_ids
*
3
,
block_size
,
ha
sh
)
req3
=
make_request
(
"3"
,
common_token_ids
*
3
,
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req3
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req3
)
assert
computed_blocks
.
blocks
[
0
]
==
block_part1
assert
computed_blocks
.
blocks
[
0
]
==
block_part1
assert
num_computed_tokens
==
6
*
16
assert
num_computed_tokens
==
6
*
16
...
@@ -1069,13 +1094,13 @@ def test_reset_prefix_cache():
...
@@ -1069,13 +1094,13 @@ def test_reset_prefix_cache():
full_block_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
full_block_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
unique_token_ids
=
[
3
]
*
7
unique_token_ids
=
[
3
]
*
7
all_token_ids
=
full_block_token_ids
+
unique_token_ids
all_token_ids
=
full_block_token_ids
+
unique_token_ids
req0
=
make_request
(
"0"
,
all_token_ids
,
block_size
,
ha
sh
)
req0
=
make_request
(
"0"
,
all_token_ids
,
block_size
,
sh
a256
)
blocks
=
manager
.
allocate_slots
(
req0
,
55
)
blocks
=
manager
.
allocate_slots
(
req0
,
55
)
assert
blocks
.
get_block_ids
()
==
([
1
,
2
,
3
,
4
],
)
assert
blocks
.
get_block_ids
()
==
([
1
,
2
,
3
,
4
],
)
unique_token_ids
=
[
4
]
*
7
unique_token_ids
=
[
4
]
*
7
all_token_ids
=
full_block_token_ids
+
unique_token_ids
all_token_ids
=
full_block_token_ids
+
unique_token_ids
req1
=
make_request
(
"1"
,
all_token_ids
,
block_size
,
ha
sh
)
req1
=
make_request
(
"1"
,
all_token_ids
,
block_size
,
sh
a256
)
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
req1
.
block_hashes
)
==
3
assert
len
(
req1
.
block_hashes
)
==
3
assert
len
(
computed_blocks
.
blocks
[
0
])
==
3
assert
len
(
computed_blocks
.
blocks
[
0
])
==
3
...
@@ -1109,7 +1134,7 @@ def test_prefix_cache_stats_disabled():
...
@@ -1109,7 +1134,7 @@ def test_prefix_cache_stats_disabled():
assert
manager
.
prefix_cache_stats
is
None
assert
manager
.
prefix_cache_stats
is
None
# Call all functions that check whether log_stats is disabled.
# Call all functions that check whether log_stats is disabled.
req
=
make_request
(
"0"
,
list
(
range
(
16
)),
block_size
,
ha
sh
)
req
=
make_request
(
"0"
,
list
(
range
(
16
)),
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -1124,15 +1149,9 @@ def test_prefix_cache_stats_disabled():
...
@@ -1124,15 +1149,9 @@ def test_prefix_cache_stats_disabled():
def
test_maybe_evict_cached_block
():
def
test_maybe_evict_cached_block
():
pool
=
BlockPool
(
num_gpu_blocks
=
4
,
enable_caching
=
True
)
pool
=
BlockPool
(
num_gpu_blocks
=
4
,
enable_caching
=
True
)
block_hash0
=
BlockHashWithGroupId
(
block_hash
=
BlockHash
(
hash_value
=
10
,
block_hash0
=
make_block_hash_with_group_id
(
BlockHash
(
b
"10"
),
1000
)
token_ids
=
(
100
,
)),
block_hash1
=
make_block_hash_with_group_id
(
BlockHash
(
b
"20"
),
2000
)
group_id
=
1000
)
block_hash2
=
make_block_hash_with_group_id
(
BlockHash
(
b
"30"
),
3000
)
block_hash1
=
BlockHashWithGroupId
(
block_hash
=
BlockHash
(
hash_value
=
20
,
token_ids
=
(
200
,
)),
group_id
=
2000
)
block_hash2
=
BlockHashWithGroupId
(
block_hash
=
BlockHash
(
hash_value
=
30
,
token_ids
=
(
300
,
)),
group_id
=
3000
)
block_hashes
=
[
block_hashes
=
[
block_hash0
,
block_hash0
,
block_hash1
,
block_hash1
,
...
@@ -1206,7 +1225,7 @@ def test_kv_cache_events(blocks_to_cache: int):
...
@@ -1206,7 +1225,7 @@ def test_kv_cache_events(blocks_to_cache: int):
)
)
num_tokens
=
block_size
*
blocks_to_cache
num_tokens
=
block_size
*
blocks_to_cache
req0
=
make_request
(
"0"
,
list
(
range
(
num_tokens
)),
block_size
,
ha
sh
)
req0
=
make_request
(
"0"
,
list
(
range
(
num_tokens
)),
block_size
,
sh
a256
)
_
=
manager
.
allocate_slots
(
req0
,
num_tokens
)
_
=
manager
.
allocate_slots
(
req0
,
num_tokens
)
events
=
manager
.
take_events
()
events
=
manager
.
take_events
()
...
@@ -1222,7 +1241,7 @@ def test_kv_cache_events(blocks_to_cache: int):
...
@@ -1222,7 +1241,7 @@ def test_kv_cache_events(blocks_to_cache: int):
# Should see block_to_cache number of removed block events and a new block
# Should see block_to_cache number of removed block events and a new block
# stored event
# stored event
manager
.
free
(
req0
)
manager
.
free
(
req0
)
req1
=
make_request
(
"1"
,
list
(
range
(
num_tokens
)),
block_size
,
ha
sh
)
req1
=
make_request
(
"1"
,
list
(
range
(
num_tokens
)),
block_size
,
sh
a256
)
_
=
manager
.
allocate_slots
(
req1
,
num_tokens
)
_
=
manager
.
allocate_slots
(
req1
,
num_tokens
)
events
=
manager
.
take_events
()
events
=
manager
.
take_events
()
...
@@ -1256,7 +1275,7 @@ def test_eagle_enabled_removes_last_block():
...
@@ -1256,7 +1275,7 @@ def test_eagle_enabled_removes_last_block():
# Request with 3 full blocks (48 tokens)
# Request with 3 full blocks (48 tokens)
token_ids
=
[
0
]
*
(
3
*
block_size
)
token_ids
=
[
0
]
*
(
3
*
block_size
)
req
=
make_request
(
"divisible_request"
,
token_ids
,
block_size
,
ha
sh
)
req
=
make_request
(
"divisible_request"
,
token_ids
,
block_size
,
sh
a256
)
# Prime the cache
# Prime the cache
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req
)
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req
)
...
@@ -1266,7 +1285,7 @@ def test_eagle_enabled_removes_last_block():
...
@@ -1266,7 +1285,7 @@ def test_eagle_enabled_removes_last_block():
manager
.
free
(
req
)
manager
.
free
(
req
)
# New request with same tokens + Eagle enabled
# New request with same tokens + Eagle enabled
req_eagle
=
make_request
(
"eagle_divisible"
,
token_ids
,
block_size
,
ha
sh
)
req_eagle
=
make_request
(
"eagle_divisible"
,
token_ids
,
block_size
,
sh
a256
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_eagle
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_eagle
)
# Should retain 1 block:
# Should retain 1 block:
...
@@ -1287,7 +1306,7 @@ def test_eagle_with_partial_blocks():
...
@@ -1287,7 +1306,7 @@ def test_eagle_with_partial_blocks():
)
)
# 2 full blocks + 5 tokens (non-divisible length)
# 2 full blocks + 5 tokens (non-divisible length)
token_ids
=
[
0
]
*
(
2
*
block_size
+
5
)
token_ids
=
[
0
]
*
(
2
*
block_size
+
5
)
req
=
make_request
(
"partial_block_test"
,
token_ids
,
block_size
,
ha
sh
)
req
=
make_request
(
"partial_block_test"
,
token_ids
,
block_size
,
sh
a256
)
# Prime the cache
# Prime the cache
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req
)
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req
)
...
@@ -1297,7 +1316,7 @@ def test_eagle_with_partial_blocks():
...
@@ -1297,7 +1316,7 @@ def test_eagle_with_partial_blocks():
manager
.
free
(
req
)
manager
.
free
(
req
)
# New request with Eagle enabled
# New request with Eagle enabled
req_eagle
=
make_request
(
"partial_eagle"
,
token_ids
,
block_size
,
ha
sh
)
req_eagle
=
make_request
(
"partial_eagle"
,
token_ids
,
block_size
,
sh
a256
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_eagle
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_eagle
)
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
assert
len
(
computed_blocks
.
blocks
[
0
])
==
1
assert
len
(
computed_blocks
.
blocks
[
0
])
==
1
...
@@ -1328,7 +1347,7 @@ def test_eagle_with_sliding_window():
...
@@ -1328,7 +1347,7 @@ def test_eagle_with_sliding_window():
# 2 full blocks + 5 tokens (non-divisible length)
# 2 full blocks + 5 tokens (non-divisible length)
token_ids
=
[
0
]
*
(
2
*
block_size
+
5
)
token_ids
=
[
0
]
*
(
2
*
block_size
+
5
)
req
=
make_request
(
"partial_block_test"
,
token_ids
,
block_size
,
ha
sh
)
req
=
make_request
(
"partial_block_test"
,
token_ids
,
block_size
,
sh
a256
)
# Prime the cache
# Prime the cache
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req
)
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req
)
...
@@ -1341,7 +1360,7 @@ def test_eagle_with_sliding_window():
...
@@ -1341,7 +1360,7 @@ def test_eagle_with_sliding_window():
manager
.
free
(
req
)
manager
.
free
(
req
)
# New request with Eagle enabled
# New request with Eagle enabled
req_eagle
=
make_request
(
"partial_eagle"
,
token_ids
,
block_size
,
ha
sh
)
req_eagle
=
make_request
(
"partial_eagle"
,
token_ids
,
block_size
,
sh
a256
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_eagle
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_eagle
)
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
assert
len
(
computed_blocks
.
blocks
[
0
])
==
1
assert
len
(
computed_blocks
.
blocks
[
0
])
==
1
...
@@ -1351,11 +1370,11 @@ def test_eagle_with_sliding_window():
...
@@ -1351,11 +1370,11 @@ def test_eagle_with_sliding_window():
assert
manager
.
block_pool
.
get_cached_block
(
assert
manager
.
block_pool
.
get_cached_block
(
block_hash_first_block
,
kv_cache_group_ids
=
[
0
])
is
not
None
block_hash_first_block
,
kv_cache_group_ids
=
[
0
])
is
not
None
manager
.
block_pool
.
cached_block_hash_to_block
.
pop
(
manager
.
block_pool
.
cached_block_hash_to_block
.
pop
(
B
lock
H
ash
W
ith
G
roup
I
d
(
block_hash_first_block
,
0
))
make_b
lock
_h
ash
_w
ith
_g
roup
_i
d
(
block_hash_first_block
,
0
))
# New request
# New request
req_after_evict
=
make_request
(
"partial_eagle_after_evict"
,
token_ids
,
req_after_evict
=
make_request
(
"partial_eagle_after_evict"
,
token_ids
,
block_size
,
ha
sh
)
block_size
,
sh
a256
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_after_evict
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_after_evict
)
# Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is
# Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is
# not considered. But after dropping the last matched block due to eagle,
# not considered. But after dropping the last matched block due to eagle,
...
...
tests/v1/core/test_single_type_kv_cache_manager.py
View file @
38d80967
...
@@ -6,8 +6,8 @@ import random
...
@@ -6,8 +6,8 @@ import random
import
torch
import
torch
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.kv_cache_utils
import
(
BlockHash
,
BlockHashWithGroupId
,
from
vllm.v1.core.kv_cache_utils
import
(
BlockHash
,
KVCacheBlock
,
KVCacheBlock
)
make_block_hash_with_group_id
)
from
vllm.v1.core.single_type_kv_cache_manager
import
(
from
vllm.v1.core.single_type_kv_cache_manager
import
(
ChunkedLocalAttentionManager
,
SlidingWindowManager
)
ChunkedLocalAttentionManager
,
SlidingWindowManager
)
from
vllm.v1.kv_cache_interface
import
(
ChunkedLocalAttentionSpec
,
from
vllm.v1.kv_cache_interface
import
(
ChunkedLocalAttentionSpec
,
...
@@ -44,7 +44,7 @@ def test_chunked_local_attention_possible_cached_prefix():
...
@@ -44,7 +44,7 @@ def test_chunked_local_attention_possible_cached_prefix():
def
run_one_case
(
block_is_cached
,
tail_token
,
expect_length
):
def
run_one_case
(
block_is_cached
,
tail_token
,
expect_length
):
block_hash_list
=
[
block_hash_list
=
[
BlockHash
(
i
,
())
for
i
in
range
(
len
(
block_is_cached
))
BlockHash
(
str
(
i
).
encode
())
for
i
in
range
(
len
(
block_is_cached
))
]
]
block_pool
.
cached_block_hash_to_block
.
clear
()
block_pool
.
cached_block_hash_to_block
.
clear
()
...
@@ -53,8 +53,8 @@ def test_chunked_local_attention_possible_cached_prefix():
...
@@ -53,8 +53,8 @@ def test_chunked_local_attention_possible_cached_prefix():
for
i
,
(
block_hash
,
for
i
,
(
block_hash
,
is_cached
)
in
enumerate
(
zip
(
block_hash_list
,
block_is_cached
)):
is_cached
)
in
enumerate
(
zip
(
block_hash_list
,
block_is_cached
)):
if
is_cached
:
if
is_cached
:
block_pool
.
cached_block_hash_to_block
[
BlockHashWithGroupId
(
block_pool
.
cached_block_hash_to_block
[
block_hash
,
0
)]
=
{
make_block_hash_with_group_id
(
block_hash
,
0
)]
=
{
i
:
block_pool
.
blocks
[
i
+
10
],
i
:
block_pool
.
blocks
[
i
+
10
],
}
}
...
@@ -109,7 +109,7 @@ def test_sliding_window_possible_cached_prefix():
...
@@ -109,7 +109,7 @@ def test_sliding_window_possible_cached_prefix():
def
run_one_case
(
block_is_cached
,
expect_length
):
def
run_one_case
(
block_is_cached
,
expect_length
):
block_hash_list
=
[
block_hash_list
=
[
BlockHash
(
i
,
())
for
i
in
range
(
len
(
block_is_cached
))
BlockHash
(
str
(
i
).
encode
())
for
i
in
range
(
len
(
block_is_cached
))
]
]
block_pool
.
cached_block_hash_to_block
.
clear
()
block_pool
.
cached_block_hash_to_block
.
clear
()
...
@@ -118,8 +118,8 @@ def test_sliding_window_possible_cached_prefix():
...
@@ -118,8 +118,8 @@ def test_sliding_window_possible_cached_prefix():
for
i
,
(
block_hash
,
for
i
,
(
block_hash
,
is_cached
)
in
enumerate
(
zip
(
block_hash_list
,
block_is_cached
)):
is_cached
)
in
enumerate
(
zip
(
block_hash_list
,
block_is_cached
)):
if
is_cached
:
if
is_cached
:
block_pool
.
cached_block_hash_to_block
[
BlockHashWithGroupId
(
block_pool
.
cached_block_hash_to_block
[
block_hash
,
0
)]
=
{
make_block_hash_with_group_id
(
block_hash
,
0
)]
=
{
i
:
block_pool
.
blocks
[
i
+
10
],
i
:
block_pool
.
blocks
[
i
+
10
],
}
}
...
...
tests/v1/core/utils.py
View file @
38d80967
...
@@ -9,6 +9,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
...
@@ -9,6 +9,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
from
vllm.multimodal.inputs
import
(
MultiModalFeatureSpec
,
from
vllm.multimodal.inputs
import
(
MultiModalFeatureSpec
,
MultiModalKwargsItem
,
PlaceholderRange
)
MultiModalKwargsItem
,
PlaceholderRange
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
sha256
from
vllm.v1.core.kv_cache_utils
import
(
get_request_block_hasher
,
from
vllm.v1.core.kv_cache_utils
import
(
get_request_block_hasher
,
init_none_hash
)
init_none_hash
)
from
vllm.v1.core.sched.async_scheduler
import
AsyncScheduler
from
vllm.v1.core.sched.async_scheduler
import
AsyncScheduler
...
@@ -130,10 +131,10 @@ def create_requests(
...
@@ -130,10 +131,10 @@ def create_requests(
)
->
list
[
Request
]:
)
->
list
[
Request
]:
global
_none_hash_initialized
global
_none_hash_initialized
if
not
_none_hash_initialized
:
if
not
_none_hash_initialized
:
init_none_hash
(
ha
sh
)
init_none_hash
(
sh
a256
)
_none_hash_initialized
=
True
_none_hash_initialized
=
True
block_hasher
=
get_request_block_hasher
(
block_size
,
ha
sh
)
block_hasher
=
get_request_block_hasher
(
block_size
,
sh
a256
)
sampling_params
=
SamplingParams
(
ignore_eos
=
False
,
sampling_params
=
SamplingParams
(
ignore_eos
=
False
,
max_tokens
=
max_tokens
,
max_tokens
=
max_tokens
,
stop_token_ids
=
stop_token_ids
,
stop_token_ids
=
stop_token_ids
,
...
...
tests/v1/cudagraph/test_cudagraph_mode.py
View file @
38d80967
...
@@ -62,6 +62,16 @@ backend_configs = {
...
@@ -62,6 +62,16 @@ backend_configs = {
"cudagraph_mode"
:
"FULL_AND_PIECEWISE"
,
"cudagraph_mode"
:
"FULL_AND_PIECEWISE"
,
},
},
specific_gpu_arch
=
(
9
,
0
)),
specific_gpu_arch
=
(
9
,
0
)),
# FlashAttention MLA on Hopper
"FlashAttentionMLA"
:
BackendConfig
(
name
=
"FlashAttentionMLA"
,
env_vars
=
{
"VLLM_ATTENTION_BACKEND"
:
"FLASH_ATTN_MLA"
,
},
comp_config
=
{
"cudagraph_mode"
:
"FULL_DECODE_ONLY"
,
},
specific_gpu_arch
=
(
9
,
0
)),
# FA2
# FA2
"FA2"
:
"FA2"
:
BackendConfig
(
name
=
"FA2"
,
BackendConfig
(
name
=
"FA2"
,
...
...
tests/v1/e2e/test_spec_decode.py
View file @
38d80967
...
@@ -83,7 +83,7 @@ def test_ngram_correctness(
...
@@ -83,7 +83,7 @@ def test_ngram_correctness(
model_name
:
str
,
model_name
:
str
,
):
):
'''
'''
Compare the outputs of a original LLM and a speculative LLM
Compare the outputs of a
n
original LLM and a speculative LLM
should be the same when using ngram speculative decoding.
should be the same when using ngram speculative decoding.
'''
'''
with
monkeypatch
.
context
()
as
m
:
with
monkeypatch
.
context
()
as
m
:
...
@@ -117,45 +117,38 @@ def test_ngram_correctness(
...
@@ -117,45 +117,38 @@ def test_ngram_correctness(
print
(
f
"ref_output:
{
ref_output
.
outputs
[
0
].
text
}
"
)
print
(
f
"ref_output:
{
ref_output
.
outputs
[
0
].
text
}
"
)
print
(
f
"spec_output:
{
spec_output
.
outputs
[
0
].
text
}
"
)
print
(
f
"spec_output:
{
spec_output
.
outputs
[
0
].
text
}
"
)
# Heuristic: expect at least
70
% of the prompts to match exactly
# Heuristic: expect at least
68
% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
# Upon failure, inspect the outputs to check for inaccuracy.
assert
matches
>
int
(
0.
7
*
len
(
ref_outputs
))
assert
matches
>
=
int
(
0.
68
*
len
(
ref_outputs
))
del
spec_llm
del
spec_llm
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
cleanup_dist_env_and_memory
()
cleanup_dist_env_and_memory
()
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
([
"model_setup"
,
"mm_enabled"
],
[
[
"model_setup"
,
"mm_enabled"
],
((
"eagle3"
,
"Qwen/Qwen3-8B"
,
"AngelSlim/Qwen3-8B_eagle3"
,
1
),
False
),
[
((
"eagle"
,
"meta-llama/Llama-3.1-8B-Instruct"
,
# TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
,
1
),
False
),
# (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
((
"eagle3"
,
"meta-llama/Llama-3.1-8B-Instruct"
,
((
"eagle"
,
"meta-llama/Llama-3.1-8B-Instruct"
,
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
,
1
),
False
),
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
,
1
),
False
),
pytest
.
param
(
((
"eagle3"
,
"meta-llama/Llama-3.1-8B-Instruct"
,
(
"eagle"
,
"meta-llama/Llama-4-Scout-17B-16E-Instruct"
,
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
,
1
),
False
),
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct"
,
4
),
pytest
.
param
(
False
,
(
"eagle"
,
"meta-llama/Llama-4-Scout-17B-16E-Instruct"
,
marks
=
pytest
.
mark
.
skip
(
reason
=
"Skipping due to CI OOM issues"
)),
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct"
,
4
),
pytest
.
param
(
False
,
(
"eagle"
,
"meta-llama/Llama-4-Scout-17B-16E-Instruct"
,
marks
=
pytest
.
mark
.
skip
(
reason
=
"Skipping due to CI OOM issues"
)),
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct"
,
4
),
pytest
.
param
(
True
,
(
"eagle"
,
"meta-llama/Llama-4-Scout-17B-16E-Instruct"
,
marks
=
pytest
.
mark
.
skip
(
reason
=
"Skipping due to CI OOM issues"
)),
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct"
,
4
),
((
"eagle"
,
"eagle618/deepseek-v3-random"
,
True
,
"eagle618/eagle-deepseek-v3-random"
,
1
),
False
),
marks
=
pytest
.
mark
.
skip
(
reason
=
"Skipping due to CI OOM issues"
)),
],
((
"eagle"
,
"eagle618/deepseek-v3-random"
,
ids
=
[
"eagle618/eagle-deepseek-v3-random"
,
1
),
False
),
"qwen3_eagle3"
,
"llama3_eagle"
,
"llama3_eagle3"
,
],
"llama4_eagle"
,
"llama4_eagle_mm"
,
ids
=
[
"deepseek_eagle"
# TODO: Re-enable this once tests/models/test_initialization.py is fixed, see PR #22333 #22611 # noqa: E501
])
# "qwen3_eagle3",
"llama3_eagle"
,
"llama3_eagle3"
,
"llama4_eagle"
,
"llama4_eagle_mm"
,
"deepseek_eagle"
])
@
pytest
.
mark
.
parametrize
(
"attn_backend"
,
@
pytest
.
mark
.
parametrize
(
"attn_backend"
,
get_attn_backend_list_based_on_platform
())
get_attn_backend_list_based_on_platform
())
def
test_eagle_correctness
(
def
test_eagle_correctness
(
...
@@ -169,7 +162,7 @@ def test_eagle_correctness(
...
@@ -169,7 +162,7 @@ def test_eagle_correctness(
# TODO: Fix this flaky test
# TODO: Fix this flaky test
pytest
.
skip
(
pytest
.
skip
(
"TREE_ATTN is flaky in the test disable for now until it can be "
"TREE_ATTN is flaky in the test disable for now until it can be "
"reolved (see https://github.com/vllm-project/vllm/issues/22922)"
)
"re
s
olved (see https://github.com/vllm-project/vllm/issues/22922)"
)
# Generate test prompts inside the function instead of using fixture
# Generate test prompts inside the function instead of using fixture
test_prompts
=
get_test_prompts
(
mm_enabled
)
test_prompts
=
get_test_prompts
(
mm_enabled
)
...
...
tests/v1/engine/test_async_llm.py
View file @
38d80967
...
@@ -393,7 +393,7 @@ class MockLoggingStatLogger(LoggingStatLogger):
...
@@ -393,7 +393,7 @@ class MockLoggingStatLogger(LoggingStatLogger):
async
def
test_customize_loggers
(
monkeypatch
):
async
def
test_customize_loggers
(
monkeypatch
):
"""Test that we can customize the loggers.
"""Test that we can customize the loggers.
If a customized logger is provided at the init, it should
If a customized logger is provided at the init, it should
be
used directly
.
be
added to the default loggers
.
"""
"""
with
monkeypatch
.
context
()
as
m
,
ExitStack
()
as
after
:
with
monkeypatch
.
context
()
as
m
,
ExitStack
()
as
after
:
...
@@ -410,7 +410,8 @@ async def test_customize_loggers(monkeypatch):
...
@@ -410,7 +410,8 @@ async def test_customize_loggers(monkeypatch):
stat_loggers
=
engine
.
logger_manager
.
per_engine_logger_dict
stat_loggers
=
engine
.
logger_manager
.
per_engine_logger_dict
assert
len
(
stat_loggers
)
==
1
assert
len
(
stat_loggers
)
==
1
assert
len
(
stat_loggers
[
0
])
==
1
assert
len
(
stat_loggers
[
0
])
==
2
# LoggingStatLogger + MockLoggingStatLogger
stat_loggers
[
0
][
0
].
log
.
assert_called_once
()
stat_loggers
[
0
][
0
].
log
.
assert_called_once
()
...
...
tests/v1/engine/test_engine_args.py
View file @
38d80967
...
@@ -36,18 +36,19 @@ def test_prefix_caching_from_cli():
...
@@ -36,18 +36,19 @@ def test_prefix_caching_from_cli():
assert
vllm_config
.
cache_config
.
enable_prefix_caching
assert
vllm_config
.
cache_config
.
enable_prefix_caching
# default hash algorithm is "builtin"
# default hash algorithm is "builtin"
assert
vllm_config
.
cache_config
.
prefix_caching_hash_algo
==
"builtin"
assert
vllm_config
.
cache_config
.
prefix_caching_hash_algo
==
"sha256"
# set hash algorithm to sha256_cbor
args
=
parser
.
parse_args
([
"--prefix-caching-hash-algo"
,
"sha256_cbor"
])
vllm_config
=
EngineArgs
.
from_cli_args
(
args
=
args
).
create_engine_config
()
assert
vllm_config
.
cache_config
.
prefix_caching_hash_algo
==
\
"sha256_cbor"
# set hash algorithm to sha256
# set hash algorithm to sha256
args
=
parser
.
parse_args
([
"--prefix-caching-hash-algo"
,
"sha256"
])
args
=
parser
.
parse_args
([
"--prefix-caching-hash-algo"
,
"sha256"
])
vllm_config
=
EngineArgs
.
from_cli_args
(
args
=
args
).
create_engine_config
()
vllm_config
=
EngineArgs
.
from_cli_args
(
args
=
args
).
create_engine_config
()
assert
vllm_config
.
cache_config
.
prefix_caching_hash_algo
==
"sha256"
assert
vllm_config
.
cache_config
.
prefix_caching_hash_algo
==
"sha256"
# set hash algorithm to builtin
args
=
parser
.
parse_args
([
"--prefix-caching-hash-algo"
,
"builtin"
])
vllm_config
=
EngineArgs
.
from_cli_args
(
args
=
args
).
create_engine_config
()
assert
vllm_config
.
cache_config
.
prefix_caching_hash_algo
==
"builtin"
# an invalid hash algorithm raises an error
# an invalid hash algorithm raises an error
parser
.
exit_on_error
=
False
parser
.
exit_on_error
=
False
with
pytest
.
raises
(
ArgumentError
):
with
pytest
.
raises
(
ArgumentError
):
...
...
tests/v1/engine/test_processor_multi_modal_uuids.py
View file @
38d80967
...
@@ -152,8 +152,8 @@ def test_multi_modal_uuids_accepts_none_and_passes_through(
...
@@ -152,8 +152,8 @@ def test_multi_modal_uuids_accepts_none_and_passes_through(
*
,
*
,
tokenization_kwargs
=
None
,
tokenization_kwargs
=
None
,
lora_request
=
None
,
lora_request
=
None
,
mm_
hash_overr
id
e
s
=
None
):
mm_
uu
ids
=
None
):
captured
[
"mm_
hash_overr
id
e
s"
]
=
mm_
hash_overr
id
e
s
captured
[
"mm_
uu
ids"
]
=
mm_
uu
ids
# Minimal processed inputs for decoder-only flow
# Minimal processed inputs for decoder-only flow
return
{
"type"
:
"token"
,
"prompt_token_ids"
:
[
1
]}
return
{
"type"
:
"token"
,
"prompt_token_ids"
:
[
1
]}
...
@@ -180,7 +180,7 @@ def test_multi_modal_uuids_accepts_none_and_passes_through(
...
@@ -180,7 +180,7 @@ def test_multi_modal_uuids_accepts_none_and_passes_through(
params
=
SamplingParams
(),
params
=
SamplingParams
(),
)
)
assert
captured
[
"mm_
hash_overr
id
e
s"
]
==
mm_uuids
assert
captured
[
"mm_
uu
ids"
]
==
mm_uuids
def
test_multi_modal_uuids_ignored_when_caching_disabled
(
monkeypatch
):
def
test_multi_modal_uuids_ignored_when_caching_disabled
(
monkeypatch
):
...
@@ -196,8 +196,8 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
...
@@ -196,8 +196,8 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
*
,
*
,
tokenization_kwargs
=
None
,
tokenization_kwargs
=
None
,
lora_request
=
None
,
lora_request
=
None
,
mm_
hash_overr
id
e
s
=
None
):
mm_
uu
ids
=
None
):
captured
[
"mm_
hash_overr
id
e
s"
]
=
mm_
hash_overr
id
e
s
captured
[
"mm_
uu
ids"
]
=
mm_
uu
ids
return
{
"type"
:
"token"
,
"prompt_token_ids"
:
[
1
]}
return
{
"type"
:
"token"
,
"prompt_token_ids"
:
[
1
]}
monkeypatch
.
setattr
(
processor
.
input_preprocessor
,
monkeypatch
.
setattr
(
processor
.
input_preprocessor
,
...
@@ -223,7 +223,7 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
...
@@ -223,7 +223,7 @@ def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
)
)
# Expect request-id-based overrides are passed through
# Expect request-id-based overrides are passed through
assert
captured
[
"mm_
hash_overr
id
e
s"
]
==
{
assert
captured
[
"mm_
uu
ids"
]
==
{
"image"
:
[
f
"
{
request_id
}
-image-0"
,
f
"
{
request_id
}
-image-1"
],
"image"
:
[
f
"
{
request_id
}
-image-0"
,
f
"
{
request_id
}
-image-1"
],
"video"
:
[
f
"
{
request_id
}
-video-0"
],
"video"
:
[
f
"
{
request_id
}
-video-0"
],
}
}
tests/v1/entrypoints/llm/test_struct_output_generate.py
View file @
38d80967
...
@@ -46,12 +46,12 @@ PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
...
@@ -46,12 +46,12 @@ PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
(
"mistralai/Ministral-8B-Instruct-2410"
,
"xgrammar"
,
"mistral"
,
None
),
(
"mistralai/Ministral-8B-Instruct-2410"
,
"xgrammar"
,
"mistral"
,
None
),
(
"Qwen/Qwen2.5-1.5B-Instruct"
,
"xgrammar"
,
"auto"
,
None
),
(
"Qwen/Qwen2.5-1.5B-Instruct"
,
"xgrammar"
,
"auto"
,
None
),
(
"Qwen/Qwen2.5-1.5B-Instruct"
,
"lm-format-enforcer"
,
"auto"
,
None
),
(
"Qwen/Qwen2.5-1.5B-Instruct"
,
"lm-format-enforcer"
,
"auto"
,
None
),
(
"mistralai/Ministral-8B-Instruct-2410"
,
"outlines"
,
"auto"
,
None
),
#FIXME: This tests are flaky on CI thus disabled. Tracking in Issue #24402
(
"mistralai/Ministral-8B-Instruct-2410"
,
"outlines"
,
"mistral"
,
None
),
# ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None),
# ("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None),
#("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"),
(
"mistralai/Ministral-8B-Instruct-2410"
,
"outlines"
,
"auto"
,
(
"mistralai/Ministral-8B-Instruct-2410"
,
"outlines"
,
"auto"
,
NGRAM_SPEC_CONFIG
),
NGRAM_SPEC_CONFIG
),
#FIXME: This test is flaky on CI thus disabled
#("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"),
(
"mistralai/Ministral-8B-Instruct-2410"
,
"guidance"
,
"auto"
,
(
"mistralai/Ministral-8B-Instruct-2410"
,
"guidance"
,
"auto"
,
NGRAM_SPEC_CONFIG
),
NGRAM_SPEC_CONFIG
),
(
"Qwen/Qwen2.5-1.5B-Instruct"
,
"xgrammar"
,
"auto"
,
NGRAM_SPEC_CONFIG
),
(
"Qwen/Qwen2.5-1.5B-Instruct"
,
"xgrammar"
,
"auto"
,
NGRAM_SPEC_CONFIG
),
...
@@ -122,6 +122,7 @@ def test_structured_output(
...
@@ -122,6 +122,7 @@ def test_structured_output(
guided_decoding_backend
=
guided_decoding_backend
,
guided_decoding_backend
=
guided_decoding_backend
,
guided_decoding_disable_any_whitespace
=
(
guided_decoding_backend
guided_decoding_disable_any_whitespace
=
(
guided_decoding_backend
in
{
"xgrammar"
,
"guidance"
}),
in
{
"xgrammar"
,
"guidance"
}),
seed
=
120
,
tokenizer_mode
=
tokenizer_mode
,
tokenizer_mode
=
tokenizer_mode
,
speculative_config
=
speculative_config
)
speculative_config
=
speculative_config
)
...
...
tests/v1/entrypoints/openai/responses/test_basic.py
View file @
38d80967
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
openai
# use the official client for correctness check
import
openai
# use the official client for correctness check
import
openai.types.responses
as
openai_responses_types
import
pytest
import
pytest
...
@@ -86,3 +87,18 @@ async def test_logprobs(client: openai.AsyncOpenAI):
...
@@ -86,3 +87,18 @@ async def test_logprobs(client: openai.AsyncOpenAI):
outputs
=
response
.
output
outputs
=
response
.
output
assert
outputs
[
-
1
].
content
[
-
1
].
logprobs
assert
outputs
[
-
1
].
content
[
-
1
].
logprobs
assert
len
(
outputs
[
-
1
].
content
[
-
1
].
logprobs
[
0
].
top_logprobs
)
==
5
assert
len
(
outputs
[
-
1
].
content
[
-
1
].
logprobs
[
0
].
top_logprobs
)
==
5
@
pytest
.
mark
.
asyncio
async
def
test_streaming
(
client
:
openai
.
AsyncOpenAI
):
stream
=
await
client
.
responses
.
create
(
input
=
"What is 13 * 24?"
,
stream
=
True
,
)
events
=
[
event
async
for
event
in
stream
]
assert
isinstance
(
events
[
0
],
openai_responses_types
.
ResponseCreatedEvent
)
assert
any
(
isinstance
(
event
,
openai_responses_types
.
ResponseTextDeltaEvent
)
for
event
in
events
)
assert
isinstance
(
events
[
-
1
],
openai_responses_types
.
ResponseCompletedEvent
)
tests/v1/entrypoints/openai/responses/test_image.py
View file @
38d80967
...
@@ -8,17 +8,17 @@ import pytest
...
@@ -8,17 +8,17 @@ import pytest
import
pytest_asyncio
import
pytest_asyncio
from
tests.utils
import
RemoteOpenAIServer
from
tests.utils
import
RemoteOpenAIServer
from
vllm.multimodal.utils
import
encode_image_base64
,
fetch_image
from
vllm.multimodal.utils
import
encode_image_base64
# Use a small vision model for testing
# Use a small vision model for testing
MODEL_NAME
=
"Qwen/Qwen2.5-VL-3B-Instruct"
MODEL_NAME
=
"Qwen/Qwen2.5-VL-3B-Instruct"
MAXIMUM_IMAGES
=
2
MAXIMUM_IMAGES
=
2
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_
URL
S
=
[
TEST_IMAGE_
ASSET
S
=
[
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
,
"2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
,
#
"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
"https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png"
,
"Grayscale_8bits_palette_sample_image.png"
,
#
"https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png",
"https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png"
,
"1280px-Venn_diagram_rgb.svg.png"
,
#
"https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png",
"https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png"
,
"RGBA_comp.png"
,
#
"https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
]
]
...
@@ -52,16 +52,17 @@ async def client(image_server):
...
@@ -52,16 +52,17 @@ async def client(image_server):
@
pytest
.
fixture
(
scope
=
"session"
)
@
pytest
.
fixture
(
scope
=
"session"
)
def
base64_encoded_image
()
->
dict
[
str
,
str
]:
def
base64_encoded_image
(
local_asset_server
)
->
dict
[
str
,
str
]:
return
{
return
{
image_url
:
encode_image_base64
(
fetch_image
(
image_url
))
image_url
:
for
image_url
in
TEST_IMAGE_URLS
encode_image_base64
(
local_asset_server
.
get_image_asset
(
image_url
))
for
image_url
in
TEST_IMAGE_ASSETS
}
}
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
@
pytest
.
mark
.
parametrize
(
"image_url"
,
TEST_IMAGE_
URLS
)
@
pytest
.
mark
.
parametrize
(
"image_url"
,
TEST_IMAGE_
ASSETS
,
indirect
=
True
)
async
def
test_single_chat_session_image
(
client
:
openai
.
AsyncOpenAI
,
async
def
test_single_chat_session_image
(
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
,
image_url
:
str
):
model_name
:
str
,
image_url
:
str
):
content_text
=
"What's in this image?"
content_text
=
"What's in this image?"
...
@@ -91,11 +92,11 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
...
@@ -91,11 +92,11 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
@
pytest
.
mark
.
parametrize
(
"image_url"
,
TEST_IMAGE_
URL
S
)
@
pytest
.
mark
.
parametrize
(
"
raw_
image_url"
,
TEST_IMAGE_
ASSET
S
)
async
def
test_single_chat_session_image_base64encoded
(
async
def
test_single_chat_session_image_base64encoded
(
client
:
openai
.
AsyncOpenAI
,
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
,
model_name
:
str
,
image_url
:
str
,
raw_
image_url
:
str
,
base64_encoded_image
:
dict
[
str
,
str
],
base64_encoded_image
:
dict
[
str
,
str
],
):
):
content_text
=
"What's in this image?"
content_text
=
"What's in this image?"
...
@@ -106,7 +107,7 @@ async def test_single_chat_session_image_base64encoded(
...
@@ -106,7 +107,7 @@ async def test_single_chat_session_image_base64encoded(
{
{
"type"
:
"input_image"
,
"type"
:
"input_image"
,
"image_url"
:
"image_url"
:
f
"data:image/jpeg;base64,
{
base64_encoded_image
[
image_url
]
}
"
,
f
"data:image/jpeg;base64,
{
base64_encoded_image
[
raw_
image_url
]
}
"
,
"detail"
:
"auto"
,
"detail"
:
"auto"
,
},
},
{
{
...
@@ -127,7 +128,8 @@ async def test_single_chat_session_image_base64encoded(
...
@@ -127,7 +128,8 @@ async def test_single_chat_session_image_base64encoded(
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
MODEL_NAME
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"image_urls"
,
"image_urls"
,
[
TEST_IMAGE_URLS
[:
i
]
for
i
in
range
(
2
,
len
(
TEST_IMAGE_URLS
))])
[
TEST_IMAGE_ASSETS
[:
i
]
for
i
in
range
(
2
,
len
(
TEST_IMAGE_ASSETS
))],
indirect
=
True
)
async
def
test_multi_image_input
(
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
,
async
def
test_multi_image_input
(
client
:
openai
.
AsyncOpenAI
,
model_name
:
str
,
image_urls
:
list
[
str
]):
image_urls
:
list
[
str
]):
messages
=
[{
messages
=
[{
...
...
Prev
1
…
13
14
15
16
17
18
19
20
21
…
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment