"docs/vscode:/vscode.git/clone" did not exist on "533db0935da051ac793e8b22afbcb9ae9fa4255b"
Unverified commit 45bfa49c, authored by Mickaël Seznec, committed by GitHub
Browse files

[Tests] fix initialization of kv hash in tests (#24273)


Signed-off-by: Mickael Seznec <mickael@mistral.ai>
parent fd2f1054
...@@ -30,6 +30,16 @@ from vllm.v1.request import Request ...@@ -30,6 +30,16 @@ from vllm.v1.request import Request
# yapf: enable # yapf: enable
@pytest.fixture(autouse=True)
def _auto_init_hash_fn(request):
    """Initialize the KV-cache "none hash" before every test in this module.

    Runs automatically (``autouse=True``). If the requesting test declares a
    ``hash_fn`` fixture/parameter, that callable is used for the
    initialization; otherwise ``sha256`` is used.

    Args:
        request: The pytest ``FixtureRequest`` for the test being set up.
    """
    hash_fn: Callable
    if "hash_fn" in request.fixturenames:
        # Resolve the parametrized callable itself. Wrapping this lookup in
        # init_none_hash(...) would store that call's return value -- not a
        # hash function -- and pass it to the initialization below.
        hash_fn = request.getfixturevalue("hash_fn")
    else:
        hash_fn = sha256
    # Single initialization point, with the callable chosen above.
    init_none_hash(hash_fn)
def make_request( def make_request(
request_id: str, request_id: str,
prompt_token_ids: list[int], prompt_token_ids: list[int],
...@@ -424,7 +434,6 @@ def test_generate_block_hash_extra_keys_cache_salt(): ...@@ -424,7 +434,6 @@ def test_generate_block_hash_extra_keys_cache_salt():
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_hash_block_tokens(hash_fn): def test_hash_block_tokens(hash_fn):
init_none_hash(hash_fn)
parent_block_hash = BlockHash(b"123") parent_block_hash = BlockHash(b"123")
curr_block_token_ids = (1, 2, 3) curr_block_token_ids = (1, 2, 3)
extra_keys = ("key1", "key2") extra_keys = ("key1", "key2")
...@@ -437,8 +446,6 @@ def test_hash_block_tokens(hash_fn): ...@@ -437,8 +446,6 @@ def test_hash_block_tokens(hash_fn):
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_request_block_hasher(hash_fn): def test_request_block_hasher(hash_fn):
kv_cache_utils.init_none_hash(hash_fn)
request = make_request( request = make_request(
request_id="0", request_id="0",
prompt_token_ids=[_ for _ in range(6)], prompt_token_ids=[_ for _ in range(6)],
...@@ -461,8 +468,6 @@ def test_request_block_hasher(hash_fn): ...@@ -461,8 +468,6 @@ def test_request_block_hasher(hash_fn):
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_hash_tokens_different_mm_input(hash_fn): def test_hash_tokens_different_mm_input(hash_fn):
init_none_hash(hash_fn)
request1 = make_request( request1 = make_request(
request_id="0", request_id="0",
prompt_token_ids=[_ for _ in range(6)], prompt_token_ids=[_ for _ in range(6)],
...@@ -491,8 +496,6 @@ def test_hash_tokens_different_mm_input(hash_fn): ...@@ -491,8 +496,6 @@ def test_hash_tokens_different_mm_input(hash_fn):
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_hash_request_tokens_no_mm_inputs(hash_fn): def test_hash_request_tokens_no_mm_inputs(hash_fn):
kv_cache_utils.init_none_hash(hash_fn)
request = make_request( request = make_request(
request_id="0", request_id="0",
prompt_token_ids=[_ for _ in range(6)], prompt_token_ids=[_ for _ in range(6)],
......
...@@ -25,6 +25,16 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, ...@@ -25,6 +25,16 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, SlidingWindowSpec) KVCacheGroupSpec, SlidingWindowSpec)
@pytest.fixture(autouse=True)
def _auto_init_hash_fn(request):
    """Initialize the KV-cache "none hash" before every test in this module.

    Runs automatically (``autouse=True``). If the requesting test declares a
    ``hash_fn`` fixture/parameter, that callable is used for the
    initialization; otherwise ``sha256`` is used.

    Args:
        request: The pytest ``FixtureRequest`` for the test being set up.
    """
    hash_fn: Callable
    if "hash_fn" in request.fixturenames:
        # Resolve the parametrized callable itself. Wrapping this lookup in
        # init_none_hash(...) would store that call's return value -- not a
        # hash function -- and pass it to the initialization below.
        hash_fn = request.getfixturevalue("hash_fn")
    else:
        hash_fn = sha256
    # Single initialization point, with the callable chosen above.
    init_none_hash(hash_fn)
def make_request( def make_request(
request_id: str, request_id: str,
prompt_token_ids: list[int], prompt_token_ids: list[int],
...@@ -105,7 +115,6 @@ def make_kv_cache_config_hybrid_model(block_size: int, ...@@ -105,7 +115,6 @@ def make_kv_cache_config_hybrid_model(block_size: int,
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor]) @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_prefill(hash_fn): def test_prefill(hash_fn):
init_none_hash(hash_fn)
block_size = 16 block_size = 16
manager = KVCacheManager( manager = KVCacheManager(
...@@ -736,7 +745,6 @@ def test_cache_blocks(hash_fn): ...@@ -736,7 +745,6 @@ def test_cache_blocks(hash_fn):
This is a unit test that tests the correctness of the _cache_full_blocks This is a unit test that tests the correctness of the _cache_full_blocks
function of KVCacheManager. function of KVCacheManager.
""" """
init_none_hash(hash_fn)
block_size = 4 block_size = 4
block_pool = BlockPool( block_pool = BlockPool(
...@@ -849,7 +857,6 @@ def test_mm_prefix_caching(): ...@@ -849,7 +857,6 @@ def test_mm_prefix_caching():
""" """
This tests that the multi-modal prefix caching is correct. This tests that the multi-modal prefix caching is correct.
""" """
kv_cache_utils.init_none_hash(sha256)
block_size = 16 block_size = 16
manager = KVCacheManager( manager = KVCacheManager(
...@@ -942,8 +949,6 @@ def test_cache_key_salting(): ...@@ -942,8 +949,6 @@ def test_cache_key_salting():
This tests that cache salts are applied during hashing and the cache This tests that cache salts are applied during hashing and the cache
is separated cache as expected. is separated cache as expected.
""" """
kv_cache_utils.init_none_hash(sha256)
block_size = 16 block_size = 16
manager = KVCacheManager( manager = KVCacheManager(
make_kv_cache_config(block_size, 11), make_kv_cache_config(block_size, 11),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment