Unverified Commit 82dfb12e authored by Zebing Lin's avatar Zebing Lin Committed by GitHub
Browse files

[Core] Use sha256 bytes instead of BlockHash to reduce GC overhead (#23673)


Signed-off-by: default avatarlinzebing <linzebing1995@gmail.com>
parent bba1042c
...@@ -6,6 +6,8 @@ import msgspec ...@@ -6,6 +6,8 @@ import msgspec
import zmq import zmq
from msgspec.msgpack import Decoder from msgspec.msgpack import Decoder
from vllm.v1.core.kv_cache_utils import BlockHash
# #
# Types copied from vllm.distributed.kv_events # Types copied from vllm.distributed.kv_events
...@@ -22,8 +24,8 @@ class KVCacheEvent( ...@@ -22,8 +24,8 @@ class KVCacheEvent(
class BlockStored(KVCacheEvent): class BlockStored(KVCacheEvent):
block_hashes: list[int] block_hashes: list[BlockHash]
parent_block_hash: Optional[int] parent_block_hash: Optional[BlockHash]
token_ids: list[int] token_ids: list[int]
block_size: int block_size: int
lora_id: Optional[int] lora_id: Optional[int]
...@@ -31,7 +33,7 @@ class BlockStored(KVCacheEvent): ...@@ -31,7 +33,7 @@ class BlockStored(KVCacheEvent):
class BlockRemoved(KVCacheEvent): class BlockRemoved(KVCacheEvent):
block_hashes: list[int] block_hashes: list[BlockHash]
medium: Optional[str] medium: Optional[str]
......
...@@ -835,22 +835,20 @@ def test_model_specification(parser_with_config, cli_config_file, ...@@ -835,22 +835,20 @@ def test_model_specification(parser_with_config, cli_config_file,
@pytest.mark.parametrize("input", [(), ("abc", ), (None, ), @pytest.mark.parametrize("input", [(), ("abc", ), (None, ),
(None, bool, [1, 2, 3])]) (None, bool, [1, 2, 3])])
@pytest.mark.parametrize("output", [0, 1, 2]) def test_sha256(input: tuple):
def test_sha256(input: tuple, output: int): digest = sha256(input)
hash = sha256(input) assert digest is not None
assert hash is not None assert isinstance(digest, bytes)
assert isinstance(hash, int) assert digest != b""
assert hash != 0
bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
assert hash == int.from_bytes(hashlib.sha256(bytes).digest(), assert digest == hashlib.sha256(input_bytes).digest()
byteorder="big")
# hashing again, returns the same value # hashing again, returns the same value
assert hash == sha256(input) assert digest == sha256(input)
# hashing different input, returns different value # hashing different input, returns different value
assert hash != sha256(input + (1, )) assert digest != sha256(input + (1, ))
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
...@@ -6,20 +6,22 @@ from typing import Callable, Optional ...@@ -6,20 +6,22 @@ from typing import Callable, Optional
import pytest import pytest
import torch import torch
import vllm.v1.core.kv_cache_utils as kv_cache_utils
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.multimodal.inputs import (MultiModalFeatureSpec, from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange) MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit from vllm.utils import GiB_bytes, sha256, sha256_cbor
from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.core.kv_cache_manager import KVCacheManager
# disable yapf here as it formats differently than isort such that both fail # disable yapf here as it formats differently than isort such that both fail
# yapf: disable # yapf: disable
from vllm.v1.core.kv_cache_utils import ( from vllm.v1.core.kv_cache_utils import (
FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics, BlockHash, FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
estimate_max_model_len, generate_block_hash_extra_keys, estimate_max_model_len, generate_block_hash_extra_keys,
get_kv_cache_config, get_max_concurrency_for_kv_cache_config, get_kv_cache_config, get_max_concurrency_for_kv_cache_config,
get_request_block_hasher, hash_block_tokens, init_none_hash, get_request_block_hasher, hash_block_tokens, init_none_hash,
is_kv_cache_type_uniform, unify_kv_cache_configs) is_kv_cache_type_uniform, make_block_hash_with_group_id,
unify_kv_cache_configs)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheTensor, KVCacheGroupSpec, KVCacheTensor,
SlidingWindowSpec) SlidingWindowSpec)
...@@ -88,7 +90,7 @@ def new_sliding_window_spec(block_size=16, ...@@ -88,7 +90,7 @@ def new_sliding_window_spec(block_size=16,
sliding_window=sliding_window) sliding_window=sliding_window)
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_none_hash(monkeypatch, hash_fn): def test_none_hash(monkeypatch, hash_fn):
import vllm.v1.core.kv_cache_utils import vllm.v1.core.kv_cache_utils
...@@ -98,8 +100,8 @@ def test_none_hash(monkeypatch, hash_fn): ...@@ -98,8 +100,8 @@ def test_none_hash(monkeypatch, hash_fn):
reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils) reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils)
reloaded_kv_cache_utils.init_none_hash(hash_fn) reloaded_kv_cache_utils.init_none_hash(hash_fn)
assert reloaded_kv_cache_utils.NONE_HASH is not None assert reloaded_kv_cache_utils.NONE_HASH is not None
assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int) assert isinstance(reloaded_kv_cache_utils.NONE_HASH, bytes)
assert reloaded_kv_cache_utils.NONE_HASH != 0 assert reloaded_kv_cache_utils.NONE_HASH != b""
# case 2: PYTHONHASHSEED is set, use the seed and hash_fn # case 2: PYTHONHASHSEED is set, use the seed and hash_fn
with monkeypatch.context() as m: with monkeypatch.context() as m:
...@@ -107,12 +109,11 @@ def test_none_hash(monkeypatch, hash_fn): ...@@ -107,12 +109,11 @@ def test_none_hash(monkeypatch, hash_fn):
reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils) reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils)
reloaded_kv_cache_utils.init_none_hash(hash_fn) reloaded_kv_cache_utils.init_none_hash(hash_fn)
assert reloaded_kv_cache_utils.NONE_HASH is not None assert reloaded_kv_cache_utils.NONE_HASH is not None
assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int) assert isinstance(reloaded_kv_cache_utils.NONE_HASH, bytes)
assert hash_fn('python hash seed') == reloaded_kv_cache_utils.NONE_HASH assert hash_fn('python hash seed') == reloaded_kv_cache_utils.NONE_HASH
def test_kv_cache_block(): def test_kv_cache_block():
import vllm.v1.core.kv_cache_utils
# Test KVCacheBlock initialization # Test KVCacheBlock initialization
block = KVCacheBlock(block_id=0) block = KVCacheBlock(block_id=0)
...@@ -127,8 +128,7 @@ def test_kv_cache_block(): ...@@ -127,8 +128,7 @@ def test_kv_cache_block():
assert block.ref_cnt == 0 assert block.ref_cnt == 0
# Test block hash setting and resetting # Test block hash setting and resetting
block_hash = vllm.v1.core.kv_cache_utils.BlockHash(hash_value=123, block_hash = make_block_hash_with_group_id(BlockHash(b"abc"), 0)
token_ids=(1, 2, 3))
block.block_hash = block_hash block.block_hash = block_hash
assert block.block_hash == block_hash assert block.block_hash == block_hash
...@@ -407,27 +407,23 @@ def test_generate_block_hash_extra_keys_cache_salt(): ...@@ -407,27 +407,23 @@ def test_generate_block_hash_extra_keys_cache_salt():
assert next_mm_idx == 1 assert next_mm_idx == 1
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_hash_block_tokens(hash_fn): def test_hash_block_tokens(hash_fn):
import vllm.v1.core.kv_cache_utils
init_none_hash(hash_fn) init_none_hash(hash_fn)
parent_block_hash = 123 parent_block_hash = BlockHash(b"123")
curr_block_token_ids = (1, 2, 3) curr_block_token_ids = (1, 2, 3)
extra_keys = ("key1", "key2") extra_keys = ("key1", "key2")
block_hash = hash_block_tokens(hash_fn, parent_block_hash, block_hash = hash_block_tokens(hash_fn, parent_block_hash,
curr_block_token_ids, extra_keys) curr_block_token_ids, extra_keys)
assert isinstance(block_hash, vllm.v1.core.kv_cache_utils.BlockHash) expected = hash_fn((parent_block_hash, curr_block_token_ids, extra_keys))
assert block_hash.hash_value == hash_fn( assert block_hash == expected
(parent_block_hash, curr_block_token_ids, extra_keys))
assert block_hash.token_ids == curr_block_token_ids
assert block_hash.extra_keys == extra_keys
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_request_block_hasher(hash_fn): def test_request_block_hasher(hash_fn):
import vllm.v1.core.kv_cache_utils kv_cache_utils.init_none_hash(hash_fn)
init_none_hash(hash_fn)
request = make_request( request = make_request(
request_id="0", request_id="0",
prompt_token_ids=[_ for _ in range(6)], prompt_token_ids=[_ for _ in range(6)],
...@@ -442,19 +438,13 @@ def test_request_block_hasher(hash_fn): ...@@ -442,19 +438,13 @@ def test_request_block_hasher(hash_fn):
block_hashes = request.block_hashes block_hashes = request.block_hashes
assert len(block_hashes) == 2 assert len(block_hashes) == 2
assert isinstance(block_hashes[0], vllm.v1.core.kv_cache_utils.BlockHash) assert block_hashes[0] == hash_fn(
assert isinstance(block_hashes[1], vllm.v1.core.kv_cache_utils.BlockHash) (kv_cache_utils.NONE_HASH, (0, 1, 2), ("hash1", )))
assert block_hashes[1] == hash_fn(
# Check the first block (block_hashes[0], (3, 4, 5), ("hash2", )))
assert block_hashes[0].token_ids == (0, 1, 2)
assert block_hashes[0].extra_keys == ("hash1", )
# Check the second block
assert block_hashes[1].token_ids == (3, 4, 5)
assert block_hashes[1].extra_keys == ("hash2", )
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
def test_hash_tokens_different_mm_input(hash_fn): def test_hash_tokens_different_mm_input(hash_fn):
init_none_hash(hash_fn) init_none_hash(hash_fn)
...@@ -484,9 +474,9 @@ def test_hash_tokens_different_mm_input(hash_fn): ...@@ -484,9 +474,9 @@ def test_hash_tokens_different_mm_input(hash_fn):
assert block_hashes1[1] != block_hashes2[1] assert block_hashes1[1] != block_hashes2[1]
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor])
def test_hash_request_tokens_no_mm_inputs(hash_fn): def test_hash_request_tokens_no_mm_inputs(hash_fn):
init_none_hash(hash_fn) kv_cache_utils.init_none_hash(hash_fn)
request = make_request( request = make_request(
request_id="0", request_id="0",
...@@ -500,10 +490,9 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn): ...@@ -500,10 +490,9 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn):
block_hashes = request.block_hashes block_hashes = request.block_hashes
assert len(block_hashes) == 2 assert len(block_hashes) == 2
assert block_hashes[0].token_ids == (0, 1, 2) assert block_hashes[0] == hash_fn(
assert block_hashes[0].extra_keys is None (kv_cache_utils.NONE_HASH, (0, 1, 2), None))
assert block_hashes[1].token_ids == (3, 4, 5) assert block_hashes[1] == hash_fn((block_hashes[0], (3, 4, 5), None))
assert block_hashes[1].extra_keys is None
def test_metrics(): def test_metrics():
......
This diff is collapsed.
...@@ -6,8 +6,8 @@ import random ...@@ -6,8 +6,8 @@ import random
import torch import torch
from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
KVCacheBlock) make_block_hash_with_group_id)
from vllm.v1.core.single_type_kv_cache_manager import ( from vllm.v1.core.single_type_kv_cache_manager import (
ChunkedLocalAttentionManager, SlidingWindowManager) ChunkedLocalAttentionManager, SlidingWindowManager)
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
...@@ -44,7 +44,7 @@ def test_chunked_local_attention_possible_cached_prefix(): ...@@ -44,7 +44,7 @@ def test_chunked_local_attention_possible_cached_prefix():
def run_one_case(block_is_cached, tail_token, expect_length): def run_one_case(block_is_cached, tail_token, expect_length):
block_hash_list = [ block_hash_list = [
BlockHash(i, ()) for i in range(len(block_is_cached)) BlockHash(str(i).encode()) for i in range(len(block_is_cached))
] ]
block_pool.cached_block_hash_to_block.clear() block_pool.cached_block_hash_to_block.clear()
...@@ -53,8 +53,8 @@ def test_chunked_local_attention_possible_cached_prefix(): ...@@ -53,8 +53,8 @@ def test_chunked_local_attention_possible_cached_prefix():
for i, (block_hash, for i, (block_hash,
is_cached) in enumerate(zip(block_hash_list, block_is_cached)): is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
if is_cached: if is_cached:
block_pool.cached_block_hash_to_block[BlockHashWithGroupId( block_pool.cached_block_hash_to_block[
block_hash, 0)] = { make_block_hash_with_group_id(block_hash, 0)] = {
i: block_pool.blocks[i + 10], i: block_pool.blocks[i + 10],
} }
...@@ -109,7 +109,7 @@ def test_sliding_window_possible_cached_prefix(): ...@@ -109,7 +109,7 @@ def test_sliding_window_possible_cached_prefix():
def run_one_case(block_is_cached, expect_length): def run_one_case(block_is_cached, expect_length):
block_hash_list = [ block_hash_list = [
BlockHash(i, ()) for i in range(len(block_is_cached)) BlockHash(str(i).encode()) for i in range(len(block_is_cached))
] ]
block_pool.cached_block_hash_to_block.clear() block_pool.cached_block_hash_to_block.clear()
...@@ -118,8 +118,8 @@ def test_sliding_window_possible_cached_prefix(): ...@@ -118,8 +118,8 @@ def test_sliding_window_possible_cached_prefix():
for i, (block_hash, for i, (block_hash,
is_cached) in enumerate(zip(block_hash_list, block_is_cached)): is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
if is_cached: if is_cached:
block_pool.cached_block_hash_to_block[BlockHashWithGroupId( block_pool.cached_block_hash_to_block[
block_hash, 0)] = { make_block_hash_with_group_id(block_hash, 0)] = {
i: block_pool.blocks[i + 10], i: block_pool.blocks[i + 10],
} }
......
...@@ -9,6 +9,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, ...@@ -9,6 +9,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
from vllm.multimodal.inputs import (MultiModalFeatureSpec, from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange) MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import sha256
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
init_none_hash) init_none_hash)
from vllm.v1.core.sched.async_scheduler import AsyncScheduler from vllm.v1.core.sched.async_scheduler import AsyncScheduler
...@@ -130,10 +131,10 @@ def create_requests( ...@@ -130,10 +131,10 @@ def create_requests(
) -> list[Request]: ) -> list[Request]:
global _none_hash_initialized global _none_hash_initialized
if not _none_hash_initialized: if not _none_hash_initialized:
init_none_hash(hash) init_none_hash(sha256)
_none_hash_initialized = True _none_hash_initialized = True
block_hasher = get_request_block_hasher(block_size, hash) block_hasher = get_request_block_hasher(block_size, sha256)
sampling_params = SamplingParams(ignore_eos=False, sampling_params = SamplingParams(ignore_eos=False,
max_tokens=max_tokens, max_tokens=max_tokens,
stop_token_ids=stop_token_ids, stop_token_ids=stop_token_ids,
......
...@@ -36,18 +36,19 @@ def test_prefix_caching_from_cli(): ...@@ -36,18 +36,19 @@ def test_prefix_caching_from_cli():
assert vllm_config.cache_config.enable_prefix_caching assert vllm_config.cache_config.enable_prefix_caching
# default hash algorithm is "builtin" # default hash algorithm is "builtin"
assert vllm_config.cache_config.prefix_caching_hash_algo == "builtin" assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256"
# set hash algorithm to sha256_cbor
args = parser.parse_args(["--prefix-caching-hash-algo", "sha256_cbor"])
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
assert vllm_config.cache_config.prefix_caching_hash_algo == \
"sha256_cbor"
# set hash algorithm to sha256 # set hash algorithm to sha256
args = parser.parse_args(["--prefix-caching-hash-algo", "sha256"]) args = parser.parse_args(["--prefix-caching-hash-algo", "sha256"])
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config() vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256" assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256"
# set hash algorithm to builtin
args = parser.parse_args(["--prefix-caching-hash-algo", "builtin"])
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
assert vllm_config.cache_config.prefix_caching_hash_algo == "builtin"
# an invalid hash algorithm raises an error # an invalid hash algorithm raises an error
parser.exit_on_error = False parser.exit_on_error = False
with pytest.raises(ArgumentError): with pytest.raises(ArgumentError):
......
...@@ -13,6 +13,7 @@ from vllm.distributed.kv_transfer.kv_connector.factory import ( ...@@ -13,6 +13,7 @@ from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory) KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa
SharedStorageConnector) SharedStorageConnector)
from vllm.utils import sha256
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher, from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
init_none_hash) init_none_hash)
...@@ -127,11 +128,11 @@ def create_request(request_id: int, ...@@ -127,11 +128,11 @@ def create_request(request_id: int,
use_all_1s_for_prompt_tokens: bool = False, use_all_1s_for_prompt_tokens: bool = False,
num_remote_blocks: int = 3, num_remote_blocks: int = 3,
block_size: int = 16, block_size: int = 16,
hash_fn: Callable = hash) -> Request: hash_fn: Callable = sha256) -> Request:
"""Make dummy request for testing.""" """Make dummy request for testing."""
global _none_hash_initialized global _none_hash_initialized
if not _none_hash_initialized: if not _none_hash_initialized:
init_none_hash(hash) init_none_hash(hash_fn)
_none_hash_initialized = True _none_hash_initialized = True
kv_transfer_params: Optional[dict[str, Any]] = None kv_transfer_params: Optional[dict[str, Any]] = None
......
...@@ -24,7 +24,7 @@ logger = init_logger(__name__) ...@@ -24,7 +24,7 @@ logger = init_logger(__name__)
BlockSize = Literal[1, 8, 16, 32, 64, 128] BlockSize = Literal[1, 8, 16, 32, 64, 128]
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"] CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
MambaDType = Literal["auto", "float32"] MambaDType = Literal["auto", "float32"]
PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"] PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"]
@config @config
...@@ -63,17 +63,12 @@ class CacheConfig: ...@@ -63,17 +63,12 @@ class CacheConfig:
"""Sliding window size for the KV cache. This is primarily set in """Sliding window size for the KV cache. This is primarily set in
`ModelConfig` and that value should be manually duplicated here.""" `ModelConfig` and that value should be manually duplicated here."""
enable_prefix_caching: Optional[bool] = None enable_prefix_caching: Optional[bool] = None
"""Whether to enable prefix caching. Disabled by default for V0. Enabled by """Whether to enable prefix caching. Enabled by default for V1."""
default for V1.""" prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256"
prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin"
"""Set the hash algorithm for prefix caching:\n """Set the hash algorithm for prefix caching:\n
- "builtin" is Python's built-in hash.\n - "sha256" uses Pickle for object serialization before hashing.\n
- "sha256" is collision resistant but with certain overheads. - "sha256_cbor" provides a reproducible, cross-language compatible hash. It
This option uses Pickle for object serialization before hashing.\n serializes objects using canonical CBOR and hashes them with SHA-256."""
- "sha256_cbor_64bit" provides a reproducible, cross-language compatible
hash. It serializes objects using canonical CBOR and hashes them with
SHA-256. The resulting hash consists of the lower 64 bits of the SHA-256
digest."""
cpu_offload_gb: float = 0 cpu_offload_gb: float = 0
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means """The space in GiB to offload to CPU, per GPU. Default is 0, which means
no offloading. Intuitively, this argument can be seen as a virtual way to no offloading. Intuitively, this argument can be seen as a virtual way to
......
...@@ -16,6 +16,7 @@ import zmq ...@@ -16,6 +16,7 @@ import zmq
from vllm.config.kv_events import KVEventsConfig from vllm.config.kv_events import KVEventsConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import ExternalBlockHash
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -44,8 +45,8 @@ MEDIUM_GPU = "GPU" ...@@ -44,8 +45,8 @@ MEDIUM_GPU = "GPU"
class BlockStored(KVCacheEvent): class BlockStored(KVCacheEvent):
block_hashes: list[int] block_hashes: list[ExternalBlockHash]
parent_block_hash: Optional[int] parent_block_hash: Optional[ExternalBlockHash]
token_ids: list[int] token_ids: list[int]
block_size: int block_size: int
lora_id: Optional[int] lora_id: Optional[int]
...@@ -53,7 +54,7 @@ class BlockStored(KVCacheEvent): ...@@ -53,7 +54,7 @@ class BlockStored(KVCacheEvent):
class BlockRemoved(KVCacheEvent): class BlockRemoved(KVCacheEvent):
block_hashes: list[int] block_hashes: list[ExternalBlockHash]
medium: Optional[str] medium: Optional[str]
......
...@@ -1592,20 +1592,12 @@ class EngineArgs: ...@@ -1592,20 +1592,12 @@ class EngineArgs:
"in low performance due to small KV cache size. Consider " "in low performance due to small KV cache size. Consider "
"setting --max-model-len to a smaller value.", max_model_len) "setting --max-model-len to a smaller value.", max_model_len)
# if using prefix caching, we must set a hash algo # Disable prefix caching for multimodal models for VLLM_V0.
if self.enable_prefix_caching: if self.enable_prefix_caching and model_config.is_multimodal_model:
# Disable prefix caching for multimodal models for VLLM_V0. logger.warning(
if model_config.is_multimodal_model: "--enable-prefix-caching is not supported for multimodal "
logger.warning( "models in V0 and has been disabled.")
"--enable-prefix-caching is not supported for multimodal " self.enable_prefix_caching = False
"models in V0 and has been disabled.")
self.enable_prefix_caching = False
# VLLM_V0 only supports builtin hash algo for prefix caching.
if self.prefix_caching_hash_algo == "sha256":
raise ValueError(
"sha256 is not supported for prefix caching in V0 engine. "
"Please use 'builtin'.")
# Set max_num_seqs to 256 for VLLM_V0. # Set max_num_seqs to 256 for VLLM_V0.
if self.max_num_seqs is None: if self.max_num_seqs is None:
......
...@@ -171,6 +171,7 @@ if TYPE_CHECKING: ...@@ -171,6 +171,7 @@ if TYPE_CHECKING:
VLLM_GPT_OSS_USE_CONTAINER_TOOL: bool = False VLLM_GPT_OSS_USE_CONTAINER_TOOL: bool = False
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True
def get_default_cache_root(): def get_default_cache_root():
...@@ -1215,6 +1216,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1215,6 +1216,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Add optional custom scopes for profiling, disable to avoid overheads # Add optional custom scopes for profiling, disable to avoid overheads
"VLLM_CUSTOM_SCOPES_FOR_PROFILING": "VLLM_CUSTOM_SCOPES_FOR_PROFILING":
lambda: bool(int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0"))), lambda: bool(int(os.getenv("VLLM_CUSTOM_SCOPES_FOR_PROFILING", "0"))),
# Represent block hashes in KV cache events as 64-bit integers instead of
# raw bytes. Defaults to True for backward compatibility.
"VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES":
lambda: bool(int(os.getenv("VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES", "1"))),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -3249,7 +3249,7 @@ def check_use_alibi(model_config: ModelConfig) -> bool: ...@@ -3249,7 +3249,7 @@ def check_use_alibi(model_config: ModelConfig) -> bool:
and getattr(cfg.attn_config, "alibi", False))))) and getattr(cfg.attn_config, "alibi", False)))))
def sha256(input) -> int: def sha256(input) -> bytes:
"""Hash any picklable Python object using SHA-256. """Hash any picklable Python object using SHA-256.
The input is serialized using pickle before hashing, which allows The input is serialized using pickle before hashing, which allows
...@@ -3260,16 +3260,15 @@ def sha256(input) -> int: ...@@ -3260,16 +3260,15 @@ def sha256(input) -> int:
input: Any picklable Python object. input: Any picklable Python object.
Returns: Returns:
An integer representing the SHA-256 hash of the serialized input. Bytes representing the SHA-256 hash of the serialized input.
""" """
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
return int.from_bytes(hashlib.sha256(input_bytes).digest(), return hashlib.sha256(input_bytes).digest()
byteorder="big")
def sha256_cbor_64bit(input) -> int: def sha256_cbor(input) -> bytes:
""" """
Hash objects using CBOR serialization and SHA-256, then truncate to 64bits. Hash objects using CBOR serialization and SHA-256.
This option is useful for non-Python-dependent serialization and hashing. This option is useful for non-Python-dependent serialization and hashing.
...@@ -3280,17 +3279,13 @@ def sha256_cbor_64bit(input) -> int: ...@@ -3280,17 +3279,13 @@ def sha256_cbor_64bit(input) -> int:
Custom classes must implement CBOR serialization methods. Custom classes must implement CBOR serialization methods.
Returns: Returns:
An integer in the range [0, 2^64-1] representing the lower 64 bits Bytes representing the SHA-256 hash of the CBOR serialized input.
of the SHA-256 hash of the CBOR serialized input.
""" """
input_bytes = cbor2.dumps(input, canonical=True) input_bytes = cbor2.dumps(input, canonical=True)
full_hash = int.from_bytes(hashlib.sha256(input_bytes).digest(), return hashlib.sha256(input_bytes).digest()
byteorder="big")
return full_hash & ((1 << 64) - 1)
def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], int]:
"""Get a hash function by name, or raise an error if """Get a hash function by name, or raise an error if
the function is not found. the function is not found.
Args: Args:
...@@ -3300,10 +3295,8 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], int]: ...@@ -3300,10 +3295,8 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], int]:
""" """
if hash_fn_name == "sha256": if hash_fn_name == "sha256":
return sha256 return sha256
if hash_fn_name == "sha256_cbor_64bit": if hash_fn_name == "sha256_cbor":
return sha256_cbor_64bit return sha256_cbor
if hash_fn_name == "builtin":
return hash
raise ValueError(f"Unsupported hash function: {hash_fn_name}") raise ValueError(f"Unsupported hash function: {hash_fn_name}")
......
...@@ -9,7 +9,11 @@ from vllm.distributed.kv_events import (MEDIUM_GPU, AllBlocksCleared, ...@@ -9,7 +9,11 @@ from vllm.distributed.kv_events import (MEDIUM_GPU, AllBlocksCleared,
KVCacheEvent) KVCacheEvent)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
FreeKVCacheBlockQueue, KVCacheBlock) ExternalBlockHash,
FreeKVCacheBlockQueue, KVCacheBlock,
get_block_hash,
make_block_hash_with_group_id,
maybe_convert_block_hash)
from vllm.v1.request import Request from vllm.v1.request import Request
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -84,8 +88,10 @@ class BlockPool: ...@@ -84,8 +88,10 @@ class BlockPool:
""" """
cached_blocks = [] cached_blocks = []
for group_id in kv_cache_group_ids: for group_id in kv_cache_group_ids:
block_hash_with_group_id = make_block_hash_with_group_id(
block_hash, group_id)
cached_blocks_one_group = self.cached_block_hash_to_block.get( cached_blocks_one_group = self.cached_block_hash_to_block.get(
BlockHashWithGroupId(block_hash, group_id)) block_hash_with_group_id)
if not cached_blocks_one_group: if not cached_blocks_one_group:
return None return None
first_block = next(iter(cached_blocks_one_group.values())) first_block = next(iter(cached_blocks_one_group.values()))
...@@ -124,28 +130,29 @@ class BlockPool: ...@@ -124,28 +130,29 @@ class BlockPool:
assert len(request.block_hashes) >= num_full_blocks assert len(request.block_hashes) >= num_full_blocks
new_block_hashes = request.block_hashes[num_cached_blocks:] new_block_hashes = request.block_hashes[num_cached_blocks:]
new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events new_hashes: Optional[list[ExternalBlockHash]] = (
else None) [] if self.enable_kv_cache_events else None)
for i, blk in enumerate(new_full_blocks): for i, blk in enumerate(new_full_blocks):
assert blk.block_hash is None assert blk.block_hash is None
block_hash = new_block_hashes[i] block_hash = new_block_hashes[i]
# Update and added the full block to the cache. # Update and added the full block to the cache.
block_hash_with_group_id = BlockHashWithGroupId( block_hash_with_group_id = make_block_hash_with_group_id(
block_hash, kv_cache_group_id) block_hash, kv_cache_group_id)
blk.block_hash = block_hash_with_group_id blk.block_hash = block_hash_with_group_id
self.cached_block_hash_to_block[block_hash_with_group_id][ self.cached_block_hash_to_block[block_hash_with_group_id][
blk.block_id] = blk blk.block_id] = blk
if new_hashes is not None: if new_hashes is not None:
new_hashes.append(block_hash.hash_value) new_hashes.append(maybe_convert_block_hash(block_hash))
if self.enable_kv_cache_events: if self.enable_kv_cache_events:
if num_cached_blocks == 0: if num_cached_blocks == 0:
parent_block_hash = None parent_block_hash: Optional[ExternalBlockHash] = None
else: else:
parent_block = blocks[num_cached_blocks - 1] parent_block = blocks[num_cached_blocks - 1]
assert parent_block.block_hash is not None assert parent_block.block_hash is not None
parent_block_hash = parent_block.block_hash.get_hash_value() parent_block_hash = maybe_convert_block_hash(
get_block_hash(parent_block.block_hash))
self.kv_event_queue.append( self.kv_event_queue.append(
BlockStored( BlockStored(
...@@ -220,7 +227,9 @@ class BlockPool: ...@@ -220,7 +227,9 @@ class BlockPool:
# we disable hybrid kv cache manager when kv cache event is # we disable hybrid kv cache manager when kv cache event is
# enabled, so there is only one group. # enabled, so there is only one group.
self.kv_event_queue.append( self.kv_event_queue.append(
BlockRemoved(block_hashes=[block_hash.get_hash_value()], BlockRemoved(block_hashes=[
maybe_convert_block_hash(get_block_hash(block_hash))
],
medium=MEDIUM_GPU)) medium=MEDIUM_GPU))
return True return True
......
...@@ -6,11 +6,12 @@ import os ...@@ -6,11 +6,12 @@ import os
from collections import defaultdict, deque from collections import defaultdict, deque
from collections.abc import Iterable, Sequence from collections.abc import Iterable, Sequence
from dataclasses import astuple, dataclass from dataclasses import astuple, dataclass
from typing import Any, Callable, NamedTuple, Optional from typing import Any, Callable, NewType, Optional, Union
from vllm import envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import GiB_bytes, cdiv, sha256_cbor_64bit from vllm.utils import GiB_bytes, cdiv, sha256_cbor
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
FullAttentionSpec, KVCacheConfig, FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec, KVCacheGroupSpec, KVCacheSpec,
...@@ -18,59 +19,78 @@ from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, ...@@ -18,59 +19,78 @@ from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request from vllm.v1.request import Request
logger = init_logger(__name__) # BlockHash represents the hash of a single KV-cache block used for
# prefix caching. Treating it as a distinct type from ``bytes`` helps
# catch accidental misuse when passing around raw byte strings.
BlockHash = NewType("BlockHash", bytes)
# ``BlockHashWithGroupId`` combines a ``BlockHash`` with its KV cache group ID.
# It is represented as raw bytes for compactness and efficiency. The helper
# functions below pack/unpack the ``BlockHash`` and group id into/from the key.
BlockHashWithGroupId = NewType("BlockHashWithGroupId", bytes)
# ExternalBlockHash is used for reproducible prefix-cache block hashing.
# It's a union of ``bytes`` and ``int`` to keep backward compatibility
# after we default block hashing to use sha256 bytes.
ExternalBlockHash = Union[bytes, int]
class BlockHash(NamedTuple): def make_block_hash_with_group_id(block_hash: BlockHash,
"""Hash value of a block (int), the token IDs in the block, and extra keys. group_id: int) -> BlockHashWithGroupId:
We keep a tuple of token IDs and extra keys to reduce the likelihood of """Pack a ``BlockHash`` and group id into a ``BlockHashWithGroupId``.
hash collisions when the hash value is the same. By using SHA256 however,
hash collisions are practically impossible. The group id is encoded using 4 bytes in big-endian order and appended to
the block hash bytes. This representation avoids creating tuples while
still allowing us to recover both components when needed.
""" """
# Hash value of the block in an integer. return BlockHashWithGroupId(block_hash +
hash_value: int group_id.to_bytes(4, "big", signed=False))
# Token IDs in the block.
token_ids: tuple[int, ...]
# Extra keys for the block. def get_block_hash(key: BlockHashWithGroupId) -> BlockHash:
extra_keys: Optional[Any] = None """Extract the ``BlockHash`` from a ``BlockHashWithGroupId``."""
return BlockHash(key[:-4])
class BlockHashWithGroupId(NamedTuple): def get_group_id(key: BlockHashWithGroupId) -> int:
# The hash value for the contents (e.g., token_ids) of a block without group """Extract the group id from a ``BlockHashWithGroupId``."""
# ID. The value is the same for blocks representing the same tokens but for return int.from_bytes(key[-4:], "big", signed=False)
# different groups.
block_hash: BlockHash
# The KV cache group ID.
group_id: int
def get_hash_value(self) -> int:
return self.block_hash.hash_value
def maybe_convert_block_hash(hash_bytes: BlockHash) -> ExternalBlockHash:
if not envs.VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES:
return hash_bytes
return int.from_bytes(hash_bytes, byteorder="big") & ((1 << 64) - 1)
logger = init_logger(__name__)
# The hash seed for the first block of any prefix block sequence. # The hash seed for the first block of any prefix block sequence.
# #
# We use a random value to avoid hash collisions or PYTHONHASHSEED environment # We use a random value to avoid hash collisions or PYTHONHASHSEED environment
# variable if set such that processes can share the seed if needed. # variable if set such that processes can share the seed if needed. This aligns
# This aligns with the behavior of Python's hash() function, which also uses # with the behavior of Python's hash() function, which also uses a random seed
# a random seed if PYTHONHASHSEED is not set. # if PYTHONHASHSEED is not set.
# #
# The function `init_none_hash` initializes this variable globally. # The function `init_none_hash` initializes this variable globally.
NONE_HASH: int NONE_HASH: BlockHash
def init_none_hash(hash_fn: Callable): def init_none_hash(hash_fn: Callable[[Any], bytes]):
global NONE_HASH global NONE_HASH
hash_seed = os.getenv("PYTHONHASHSEED") hash_seed = os.getenv("PYTHONHASHSEED")
if hash_seed is None and hash_fn is sha256_cbor_64bit: if hash_seed is None and hash_fn is sha256_cbor:
logger.warning( logger.warning(
"PYTHONHASHSEED is not set. This will lead to non-reproducible " "PYTHONHASHSEED is not set. This will lead to non-reproducible "
"block-hashes when using sha256_cbor_64bit as the hash function." "block-hashes when using sha256_cbor as the hash function."
"Consider setting PYTHONHASHSEED to a fixed value for " "Consider setting PYTHONHASHSEED to a fixed value for "
"reproducibility.") "reproducibility.")
NONE_HASH = (int.from_bytes(os.urandom(32), byteorder="big") if hash_seed is None:
if hash_seed is None else hash_fn(hash_seed)) NONE_HASH = BlockHash(os.urandom(32))
else:
NONE_HASH = BlockHash(hash_fn(hash_seed))
class PrefixCachingMetrics: class PrefixCachingMetrics:
...@@ -142,8 +162,8 @@ class KVCacheBlock: ...@@ -142,8 +162,8 @@ class KVCacheBlock:
block_id: int block_id: int
# Reference count. # Reference count.
ref_cnt: int = 0 ref_cnt: int = 0
# The hash of the block composed of (block hash, tuple of token IDs). # The hash key (block hash + group id) of the block, only available
# It is only available when the block is full. # when the block is full and cached.
_block_hash: Optional[BlockHashWithGroupId] = None _block_hash: Optional[BlockHashWithGroupId] = None
# Used to construct a doubly linked list for free blocks. # Used to construct a doubly linked list for free blocks.
...@@ -177,7 +197,7 @@ class KVCacheBlock: ...@@ -177,7 +197,7 @@ class KVCacheBlock:
if self.next_free_block else None) if self.next_free_block else None)
return (f"KVCacheBlock(block_id={self.block_id}, " return (f"KVCacheBlock(block_id={self.block_id}, "
f"ref_cnt={self.ref_cnt}, " f"ref_cnt={self.ref_cnt}, "
f"_block_hash={self._block_hash}, " f"_block_hash={self._block_hash!r}, "
f"prev_free_block={prev_block_id}, " f"prev_free_block={prev_block_id}, "
f"next_free_block={next_block_id})") f"next_free_block={next_block_id})")
...@@ -517,15 +537,14 @@ def generate_block_hash_extra_keys( ...@@ -517,15 +537,14 @@ def generate_block_hash_extra_keys(
def hash_block_tokens( def hash_block_tokens(
hash_function: Callable, hash_function: Callable[[Any], bytes],
parent_block_hash: Optional[int], parent_block_hash: Optional[BlockHash],
curr_block_token_ids: Sequence[int], curr_block_token_ids: Sequence[int],
extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHash: extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHash:
"""Computes a hash value corresponding to the contents of a block and """Computes a hash value corresponding to the contents of a block and
the contents of the preceding block(s). The hash value is used for the contents of the preceding block(s). The hash value is used for
prefix caching. We use LRU cache for this function to avoid recomputing prefix caching. We use LRU cache for this function to avoid recomputing
hash values for the same block contents. hash values for the same block contents.
Args: Args:
hash_function: The hash function used to compute block hash. hash_function: The hash function used to compute block hash.
parent_block_hash: The hash of the parent block. None parent_block_hash: The hash of the parent block. None
...@@ -533,7 +552,6 @@ def hash_block_tokens( ...@@ -533,7 +552,6 @@ def hash_block_tokens(
curr_block_token_ids: A list of token ids in the current curr_block_token_ids: A list of token ids in the current
block. The current block is assumed to be full. block. The current block is assumed to be full.
extra_keys: Extra keys for the block. extra_keys: Extra keys for the block.
Returns: Returns:
The hash value of the block and the token ids in the block. The hash value of the block and the token ids in the block.
The entire tuple is used as the hash key of the block. The entire tuple is used as the hash key of the block.
...@@ -544,26 +562,16 @@ def hash_block_tokens( ...@@ -544,26 +562,16 @@ def hash_block_tokens(
curr_block_token_ids_tuple = tuple(curr_block_token_ids) curr_block_token_ids_tuple = tuple(curr_block_token_ids)
return BlockHash( return BlockHash(
hash_function( hash_function(
(parent_block_hash, curr_block_token_ids_tuple, extra_keys)), (parent_block_hash, curr_block_token_ids_tuple, extra_keys)))
curr_block_token_ids_tuple, extra_keys)
def get_request_block_hasher( def get_request_block_hasher(
block_size: int, block_size: int,
caching_hash_fn: Callable[[Any], caching_hash_fn: Callable[[Any], bytes],
int]) -> Callable[[Request], list[BlockHash]]: ) -> Callable[[Request], list[BlockHash]]:
""" """
Returns a function which computes the list of un-computed block hashes Returns a function which computes the list of un-computed block hashes
of a request. of a request."""
Each request holds a list of its block hashes (request.block_hashes).
When a request is created, it calls the below function to compute
the hashes of all full blocks of the request's initial tokens.
The hashes are then stored in request.block_hashes.
Later, whenever new tokens are appended to the request, it calls
the below function again to compute any new full blocks of tokens.
The returned new hashes are appended to request.block_hashes.
"""
def request_block_hasher(request: Request) -> list[BlockHash]: def request_block_hasher(request: Request) -> list[BlockHash]:
start_token_idx = len(request.block_hashes) * block_size start_token_idx = len(request.block_hashes) * block_size
...@@ -577,8 +585,8 @@ def get_request_block_hasher( ...@@ -577,8 +585,8 @@ def get_request_block_hasher(
# last mm input. # last mm input.
curr_mm_idx = -1 curr_mm_idx = -1
prev_block_hash_value = request.block_hashes[-1].hash_value \ prev_block_hash_value = (request.block_hashes[-1]
if request.block_hashes else None if request.block_hashes else None)
new_block_hashes: list[BlockHash] = [] new_block_hashes: list[BlockHash] = []
while True: while True:
end_token_idx = start_token_idx + block_size end_token_idx = start_token_idx + block_size
...@@ -598,7 +606,7 @@ def get_request_block_hasher( ...@@ -598,7 +606,7 @@ def get_request_block_hasher(
new_block_hashes.append(block_hash) new_block_hashes.append(block_hash)
start_token_idx += block_size start_token_idx += block_size
prev_block_hash_value = block_hash.hash_value prev_block_hash_value = block_hash
return new_block_hashes return new_block_hashes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment