Unverified Commit c280066f authored by Or Ozeri's avatar Or Ozeri Committed by GitHub
Browse files

[v1] Move block_hashes from KVCacheManager to Request.block_hashes (#19728)


Signed-off-by: default avatarOr Ozeri <oro@il.ibm.com>
parent b9dc9d26
...@@ -7,6 +7,7 @@ import pytest ...@@ -7,6 +7,7 @@ import pytest
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import RequestStatus from vllm.v1.request import RequestStatus
from vllm.v1.utils import ConstantList
from .utils import create_requests, create_scheduler from .utils import create_requests, create_scheduler
...@@ -140,7 +141,8 @@ def test_prefix_caching_for_prefill_dedup(): ...@@ -140,7 +141,8 @@ def test_prefix_caching_for_prefill_dedup():
requests = create_requests(num_requests=5, requests = create_requests(num_requests=5,
num_tokens=num_prompt_tokens, num_tokens=num_prompt_tokens,
max_tokens=3, max_tokens=3,
same_prompt=True) same_prompt=True,
block_size=BLOCK_SIZE)
requests_copy = requests.copy() requests_copy = requests.copy()
# Two requests with the same prompt. # Two requests with the same prompt.
...@@ -188,7 +190,8 @@ def test_prefix_caching_for_multi_turn(): ...@@ -188,7 +190,8 @@ def test_prefix_caching_for_multi_turn():
block_size=BLOCK_SIZE) block_size=BLOCK_SIZE)
requests = create_requests(num_requests=5, requests = create_requests(num_requests=5,
num_tokens=num_prompt_tokens, num_tokens=num_prompt_tokens,
max_tokens=num_output_tokens) max_tokens=num_output_tokens,
block_size=BLOCK_SIZE)
for req in requests: for req in requests:
scheduler.add_request(req) scheduler.add_request(req)
...@@ -208,14 +211,19 @@ def test_prefix_caching_for_multi_turn(): ...@@ -208,14 +211,19 @@ def test_prefix_caching_for_multi_turn():
# Create next-turn requests whose prompts are the full output of the # Create next-turn requests whose prompts are the full output of the
# previous turn. # previous turn.
next_turn_requests = create_requests( next_turn_requests = create_requests(num_requests=5,
num_requests=5, num_tokens=num_prompt_tokens +
num_tokens=num_prompt_tokens + num_output_tokens, num_output_tokens,
max_tokens=num_output_tokens, max_tokens=num_output_tokens,
) block_size=BLOCK_SIZE)
for i, req in enumerate(next_turn_requests): for i, req in enumerate(next_turn_requests):
req.prompt_token_ids = (requests[i].prompt_token_ids + req.prompt_token_ids = (requests[i].prompt_token_ids +
list(requests[i].output_token_ids)) list(requests[i].output_token_ids))
req._all_token_ids = req.prompt_token_ids.copy()
req.all_token_ids = ConstantList(req._all_token_ids)
req.block_hashes = []
req.block_hashes = req.get_hash_new_full_blocks()
# Schedule the next-turn requests. # Schedule the next-turn requests.
for req in next_turn_requests: for req in next_turn_requests:
scheduler.add_request(req) scheduler.add_request(req)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib import importlib
from typing import Optional from typing import Callable, Optional
import pytest import pytest
import torch import torch
...@@ -19,7 +19,7 @@ from vllm.v1.core.kv_cache_utils import ( ...@@ -19,7 +19,7 @@ from vllm.v1.core.kv_cache_utils import (
FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics, FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
estimate_max_model_len, generate_block_hash_extra_keys, estimate_max_model_len, generate_block_hash_extra_keys,
get_kv_cache_config, get_max_concurrency_for_kv_cache_config, get_kv_cache_config, get_max_concurrency_for_kv_cache_config,
hash_block_tokens, hash_request_tokens, init_none_hash, get_request_block_hasher, hash_block_tokens, init_none_hash,
is_kv_cache_type_uniform, unify_kv_cache_configs) is_kv_cache_type_uniform, unify_kv_cache_configs)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheTensor, KVCacheGroupSpec, KVCacheTensor,
...@@ -33,6 +33,8 @@ from vllm.v1.request import Request ...@@ -33,6 +33,8 @@ from vllm.v1.request import Request
def make_request( def make_request(
request_id: str, request_id: str,
prompt_token_ids: list[int], prompt_token_ids: list[int],
block_size: int = 3,
hash_fn: Callable = hash,
mm_positions: Optional[list[PlaceholderRange]] = None, mm_positions: Optional[list[PlaceholderRange]] = None,
mm_hashes: Optional[list[str]] = None, mm_hashes: Optional[list[str]] = None,
cache_salt: Optional[str] = None, cache_salt: Optional[str] = None,
...@@ -49,18 +51,17 @@ def make_request( ...@@ -49,18 +51,17 @@ def make_request(
mm_item = MultiModalKwargsItem.from_elems([mm_elem]) mm_item = MultiModalKwargsItem.from_elems([mm_elem])
mm_kwargs = [mm_item] * len(mm_positions) mm_kwargs = [mm_item] * len(mm_positions)
return Request( return Request(request_id=request_id,
request_id=request_id, prompt_token_ids=prompt_token_ids,
prompt_token_ids=prompt_token_ids, multi_modal_kwargs=mm_kwargs,
multi_modal_kwargs=mm_kwargs, multi_modal_hashes=mm_hashes,
multi_modal_hashes=mm_hashes, multi_modal_placeholders=mm_positions,
multi_modal_placeholders=mm_positions, sampling_params=SamplingParams(max_tokens=17),
sampling_params=SamplingParams(max_tokens=17), pooling_params=None,
pooling_params=None, eos_token_id=100,
eos_token_id=100, lora_request=None,
lora_request=None, cache_salt=cache_salt,
cache_salt=cache_salt, block_hasher=get_request_block_hasher(block_size, hash_fn))
)
def new_kv_cache_spec(block_size=16, def new_kv_cache_spec(block_size=16,
...@@ -428,12 +429,14 @@ def test_hash_block_tokens(hash_fn): ...@@ -428,12 +429,14 @@ def test_hash_block_tokens(hash_fn):
@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) @pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash])
def test_hash_request_tokens(hash_fn): def test_request_block_hasher(hash_fn):
import vllm.v1.core.kv_cache_utils import vllm.v1.core.kv_cache_utils
init_none_hash(hash_fn) init_none_hash(hash_fn)
request = make_request( request = make_request(
request_id="0", request_id="0",
prompt_token_ids=[_ for _ in range(6)], prompt_token_ids=[_ for _ in range(6)],
block_size=3,
hash_fn=hash_fn,
mm_positions=[ mm_positions=[
PlaceholderRange(offset=0, length=3), PlaceholderRange(offset=0, length=3),
PlaceholderRange(offset=3, length=3), PlaceholderRange(offset=3, length=3),
...@@ -441,9 +444,7 @@ def test_hash_request_tokens(hash_fn): ...@@ -441,9 +444,7 @@ def test_hash_request_tokens(hash_fn):
mm_hashes=["hash1", "hash2"], mm_hashes=["hash1", "hash2"],
) )
block_size = 3 block_hashes = request.block_hashes
block_hashes = hash_request_tokens(hash_fn, block_size, request)
assert len(block_hashes) == 2 assert len(block_hashes) == 2
assert isinstance(block_hashes[0], vllm.v1.core.kv_cache_utils.BlockHash) assert isinstance(block_hashes[0], vllm.v1.core.kv_cache_utils.BlockHash)
assert isinstance(block_hashes[1], vllm.v1.core.kv_cache_utils.BlockHash) assert isinstance(block_hashes[1], vllm.v1.core.kv_cache_utils.BlockHash)
...@@ -464,6 +465,8 @@ def test_hash_tokens_different_mm_input(hash_fn): ...@@ -464,6 +465,8 @@ def test_hash_tokens_different_mm_input(hash_fn):
request1 = make_request( request1 = make_request(
request_id="0", request_id="0",
prompt_token_ids=[_ for _ in range(6)], prompt_token_ids=[_ for _ in range(6)],
block_size=3,
hash_fn=hash_fn,
mm_positions=[ mm_positions=[
PlaceholderRange(offset=0, length=3), PlaceholderRange(offset=0, length=3),
PlaceholderRange(offset=3, length=3), PlaceholderRange(offset=3, length=3),
...@@ -479,9 +482,8 @@ def test_hash_tokens_different_mm_input(hash_fn): ...@@ -479,9 +482,8 @@ def test_hash_tokens_different_mm_input(hash_fn):
], ],
mm_hashes=["hash3", "hash2"], mm_hashes=["hash3", "hash2"],
) )
block_size = 3 block_hashes1 = request1.block_hashes
block_hashes1 = hash_request_tokens(hash_fn, block_size, request1) block_hashes2 = request2.block_hashes
block_hashes2 = hash_request_tokens(hash_fn, block_size, request2)
assert block_hashes1[0] != block_hashes2[0] assert block_hashes1[0] != block_hashes2[0]
assert block_hashes1[1] != block_hashes2[1] assert block_hashes1[1] != block_hashes2[1]
...@@ -493,12 +495,13 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn): ...@@ -493,12 +495,13 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn):
request = make_request( request = make_request(
request_id="0", request_id="0",
prompt_token_ids=[_ for _ in range(6)], prompt_token_ids=[_ for _ in range(6)],
block_size=3,
hash_fn=hash_fn,
mm_positions=None, mm_positions=None,
mm_hashes=None, mm_hashes=None,
) )
block_size = 3 block_hashes = request.block_hashes
block_hashes = hash_request_tokens(hash_fn, block_size, request)
assert len(block_hashes) == 2 assert len(block_hashes) == 2
assert block_hashes[0].token_ids == (0, 1, 2) assert block_hashes[0].token_ids == (0, 1, 2)
...@@ -858,6 +861,7 @@ def test_allocate_with_lookahead(): ...@@ -858,6 +861,7 @@ def test_allocate_with_lookahead():
request = make_request( request = make_request(
request_id="0", request_id="0",
prompt_token_ids=[], prompt_token_ids=[],
block_size=block_size,
mm_positions=None, mm_positions=None,
mm_hashes=None, mm_hashes=None,
) )
......
This diff is collapsed.
...@@ -589,7 +589,7 @@ def test_preempt_during_execution(): ...@@ -589,7 +589,7 @@ def test_preempt_during_execution():
block_size=16, block_size=16,
num_blocks=11, num_blocks=11,
enable_prefix_caching=False) enable_prefix_caching=False)
requests = create_requests(num_requests=2, num_tokens=80) requests = create_requests(num_requests=2, num_tokens=80, block_size=16)
# Schedule the first request. # Schedule the first request.
scheduler.add_request(requests[0]) scheduler.add_request(requests[0])
...@@ -762,7 +762,7 @@ def _assert_right_scheduler_output( ...@@ -762,7 +762,7 @@ def _assert_right_scheduler_output(
def _assert_right_kv_cache_manager( def _assert_right_kv_cache_manager(
scheduler: Scheduler, scheduler: Scheduler,
req_ids: list[str], requests: list[Request],
num_tokens: int, num_tokens: int,
block_size: int, block_size: int,
num_requests: int, num_requests: int,
...@@ -772,12 +772,12 @@ def _assert_right_kv_cache_manager( ...@@ -772,12 +772,12 @@ def _assert_right_kv_cache_manager(
# Make sure the request stats are right. # Make sure the request stats are right.
EXPECTED_TOTAL_BLOCKS = num_tokens // block_size EXPECTED_TOTAL_BLOCKS = num_tokens // block_size
for req_id in req_ids: for req in requests:
blocks = (scheduler.kv_cache_manager.coordinator. blocks = (scheduler.kv_cache_manager.coordinator.
single_type_managers[0].req_to_blocks[req_id]) single_type_managers[0].req_to_blocks[req.request_id])
hashes = scheduler.kv_cache_manager.req_to_block_hashes[req_id] hashes = req.block_hashes
assert (scheduler.kv_cache_manager.coordinator.single_type_managers[0]. assert (scheduler.kv_cache_manager.coordinator.single_type_managers[0].
num_cached_block[req_id] == EXPECTED_TOTAL_BLOCKS) num_cached_block[req.request_id] == EXPECTED_TOTAL_BLOCKS)
assert len(blocks) == EXPECTED_TOTAL_BLOCKS assert len(blocks) == EXPECTED_TOTAL_BLOCKS
assert len(hashes) == EXPECTED_TOTAL_BLOCKS assert len(hashes) == EXPECTED_TOTAL_BLOCKS
...@@ -840,7 +840,8 @@ def test_kv_connector_basic(): ...@@ -840,7 +840,8 @@ def test_kv_connector_basic():
MAX_TOKENS = 3 MAX_TOKENS = 3
requests = create_requests(num_requests=NUM_REQUESTS, requests = create_requests(num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS, num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS) max_tokens=MAX_TOKENS,
block_size=BLOCK_SIZE)
req_ids = [] req_ids = []
req_to_index = {} req_to_index = {}
for i, request in enumerate(requests): for i, request in enumerate(requests):
...@@ -868,7 +869,7 @@ def test_kv_connector_basic(): ...@@ -868,7 +869,7 @@ def test_kv_connector_basic():
) )
# Ensure KVCacheManager is correct. # Ensure KVCacheManager is correct.
_assert_right_kv_cache_manager(scheduler, req_ids, NUM_TOKENS, BLOCK_SIZE, _assert_right_kv_cache_manager(scheduler, requests, NUM_TOKENS, BLOCK_SIZE,
NUM_REQUESTS, NUM_TOTAL_BLOCKS) NUM_REQUESTS, NUM_TOTAL_BLOCKS)
# Continue Generation until done. # Continue Generation until done.
...@@ -886,7 +887,8 @@ def test_kv_connector_basic(): ...@@ -886,7 +887,8 @@ def test_kv_connector_basic():
NUM_TOKENS = NUM_TOKENS_PREFIX * 2 NUM_TOKENS = NUM_TOKENS_PREFIX * 2
requests = create_requests(num_requests=NUM_REQUESTS, requests = create_requests(num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS, num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS) max_tokens=MAX_TOKENS,
block_size=BLOCK_SIZE)
req_ids = [] req_ids = []
req_to_index = {} req_to_index = {}
for i, request in enumerate(requests): for i, request in enumerate(requests):
...@@ -915,7 +917,7 @@ def test_kv_connector_basic(): ...@@ -915,7 +917,7 @@ def test_kv_connector_basic():
NUM_MATCHED_NEW_TOKENS)) NUM_MATCHED_NEW_TOKENS))
# Ensure KVCacheManager is correct. # Ensure KVCacheManager is correct.
_assert_right_kv_cache_manager(scheduler, req_ids, NUM_TOKENS, BLOCK_SIZE, _assert_right_kv_cache_manager(scheduler, requests, NUM_TOKENS, BLOCK_SIZE,
NUM_REQUESTS, NUM_TOTAL_BLOCKS) NUM_REQUESTS, NUM_TOTAL_BLOCKS)
# Continue Generation until done. # Continue Generation until done.
...@@ -953,7 +955,8 @@ def test_kv_connector_unable_to_allocate(): ...@@ -953,7 +955,8 @@ def test_kv_connector_unable_to_allocate():
MAX_TOKENS = 2 MAX_TOKENS = 2
requests = create_requests(num_requests=NUM_REQUESTS, requests = create_requests(num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS, num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS) max_tokens=MAX_TOKENS,
block_size=BLOCK_SIZE)
req_ids = [] req_ids = []
req_to_index = {} req_to_index = {}
for i, request in enumerate(requests): for i, request in enumerate(requests):
...@@ -1034,7 +1037,8 @@ def test_kv_connector_handles_preemption(): ...@@ -1034,7 +1037,8 @@ def test_kv_connector_handles_preemption():
MAX_TOKENS = BLOCK_SIZE * 2 MAX_TOKENS = BLOCK_SIZE * 2
requests = create_requests(num_requests=NUM_REQUESTS, requests = create_requests(num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS, num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS) max_tokens=MAX_TOKENS,
block_size=BLOCK_SIZE)
req_ids = [] req_ids = []
req_to_index = {} req_to_index = {}
for i, request in enumerate(requests): for i, request in enumerate(requests):
...@@ -1162,7 +1166,6 @@ def assert_scheduler_empty(scheduler: Scheduler): ...@@ -1162,7 +1166,6 @@ def assert_scheduler_empty(scheduler: Scheduler):
# KVCache Manager. # KVCache Manager.
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
req_to_blocks) == 0 req_to_blocks) == 0
assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
num_cached_block) == 0 num_cached_block) == 0
num_free_blocks = ( num_free_blocks = (
......
...@@ -17,7 +17,6 @@ from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, ...@@ -17,7 +17,6 @@ from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
def get_sliding_window_manager(sliding_window_spec, block_pool): def get_sliding_window_manager(sliding_window_spec, block_pool):
return SlidingWindowManager(sliding_window_spec, return SlidingWindowManager(sliding_window_spec,
block_pool, block_pool,
caching_hash_fn=lambda x: x,
kv_cache_group_id=0) kv_cache_group_id=0)
...@@ -25,7 +24,6 @@ def get_chunked_local_attention_manager(chunked_local_attention_spec, ...@@ -25,7 +24,6 @@ def get_chunked_local_attention_manager(chunked_local_attention_spec,
block_pool): block_pool):
return ChunkedLocalAttentionManager(chunked_local_attention_spec, return ChunkedLocalAttentionManager(chunked_local_attention_spec,
block_pool, block_pool,
caching_hash_fn=lambda x: x,
kv_cache_group_id=0) kv_cache_group_id=0)
......
...@@ -10,6 +10,8 @@ from vllm.multimodal.inputs import (MultiModalBatchedField, ...@@ -10,6 +10,8 @@ from vllm.multimodal.inputs import (MultiModalBatchedField,
MultiModalFieldElem, MultiModalKwargsItem, MultiModalFieldElem, MultiModalKwargsItem,
PlaceholderRange) PlaceholderRange)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
init_none_hash)
from vllm.v1.core.sched.async_scheduler import AsyncScheduler from vllm.v1.core.sched.async_scheduler import AsyncScheduler
from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
...@@ -114,6 +116,9 @@ def create_scheduler( ...@@ -114,6 +116,9 @@ def create_scheduler(
) )
_none_hash_initialized = False
def create_requests( def create_requests(
num_requests: int, num_requests: int,
num_tokens: int = 10, num_tokens: int = 10,
...@@ -122,7 +127,14 @@ def create_requests( ...@@ -122,7 +127,14 @@ def create_requests(
stop_token_ids: Optional[list[int]] = None, stop_token_ids: Optional[list[int]] = None,
prompt_logprobs: Optional[int] = None, prompt_logprobs: Optional[int] = None,
same_prompt: bool = False, same_prompt: bool = False,
block_size: int = 16,
) -> list[Request]: ) -> list[Request]:
global _none_hash_initialized
if not _none_hash_initialized:
init_none_hash(hash)
_none_hash_initialized = True
block_hasher = get_request_block_hasher(block_size, hash)
sampling_params = SamplingParams(ignore_eos=False, sampling_params = SamplingParams(ignore_eos=False,
max_tokens=max_tokens, max_tokens=max_tokens,
stop_token_ids=stop_token_ids, stop_token_ids=stop_token_ids,
...@@ -139,9 +151,11 @@ def create_requests( ...@@ -139,9 +151,11 @@ def create_requests(
) )
mm_item = MultiModalKwargsItem.from_elems([mm_elem]) mm_item = MultiModalKwargsItem.from_elems([mm_elem])
mm_kwargs = [mm_item] * len(mm_position) mm_kwargs = [mm_item] * len(mm_position)
mm_hashes = ["hash"] * len(mm_position)
else: else:
mm_position = None mm_position = None
mm_kwargs = None mm_kwargs = None
mm_hashes = None
prompt_token_ids = ([0] * num_tokens if same_prompt else [i] * prompt_token_ids = ([0] * num_tokens if same_prompt else [i] *
num_tokens) num_tokens)
request = Request( request = Request(
...@@ -151,8 +165,9 @@ def create_requests( ...@@ -151,8 +165,9 @@ def create_requests(
pooling_params=None, pooling_params=None,
multi_modal_kwargs=mm_kwargs, multi_modal_kwargs=mm_kwargs,
multi_modal_placeholders=mm_position, multi_modal_placeholders=mm_position,
multi_modal_hashes=None, multi_modal_hashes=mm_hashes,
eos_token_id=EOS_TOKEN_ID, eos_token_id=EOS_TOKEN_ID,
block_hasher=block_hasher,
) )
requests.append(request) requests.append(request)
return requests return requests
...@@ -147,6 +147,7 @@ def test_basic_interface(): ...@@ -147,6 +147,7 @@ def test_basic_interface():
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request = create_request(request_id=1, request = create_request(request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS, num_tokens=NUM_TOKENS,
do_remote_prefill=True) do_remote_prefill=True)
request_id = request.request_id request_id = request.request_id
...@@ -186,6 +187,7 @@ def test_prompt_less_than_block_size(): ...@@ -186,6 +187,7 @@ def test_prompt_less_than_block_size():
# Request will have 1 partial remote block. # Request will have 1 partial remote block.
request = create_request(request_id=1, request = create_request(request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS, num_tokens=NUM_TOKENS,
do_remote_prefill=True, do_remote_prefill=True,
num_remote_blocks=1) num_remote_blocks=1)
......
...@@ -21,6 +21,7 @@ def test_basic_lifecycle(): ...@@ -21,6 +21,7 @@ def test_basic_lifecycle():
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request = create_request(request_id=1, request = create_request(request_id=1,
block_size=BLOCK_SIZE,
max_tokens=1, max_tokens=1,
num_tokens=NUM_TOKENS, num_tokens=NUM_TOKENS,
do_remote_decode=True) do_remote_decode=True)
...@@ -103,8 +104,10 @@ def test_short_prompt_lifecycle(): ...@@ -103,8 +104,10 @@ def test_short_prompt_lifecycle():
scheduler = create_scheduler(vllm_config) scheduler = create_scheduler(vllm_config)
# Not enough tokens for full block. # Not enough tokens for full block.
NUM_TOKENS = vllm_config.cache_config.block_size // 2 BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_TOKENS = BLOCK_SIZE // 2
request = create_request(request_id=1, request = create_request(request_id=1,
block_size=BLOCK_SIZE,
max_tokens=1, max_tokens=1,
num_tokens=NUM_TOKENS, num_tokens=NUM_TOKENS,
do_remote_decode=True) do_remote_decode=True)
...@@ -148,7 +151,9 @@ def test_prefix_cache_lifecycle(): ...@@ -148,7 +151,9 @@ def test_prefix_cache_lifecycle():
NUM_EXTERNAL_FULL_BLOCKS = 3 NUM_EXTERNAL_FULL_BLOCKS = 3
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS) request_normal = create_request(request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS)
scheduler.add_request(request_normal) scheduler.add_request(request_normal)
scheduler_output = scheduler.schedule() scheduler_output = scheduler.schedule()
...@@ -166,6 +171,7 @@ def test_prefix_cache_lifecycle(): ...@@ -166,6 +171,7 @@ def test_prefix_cache_lifecycle():
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request_remote = create_request(request_id=1, request_remote = create_request(request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS, num_tokens=NUM_TOKENS,
do_remote_decode=True) do_remote_decode=True)
......
...@@ -23,6 +23,7 @@ def test_basic_lifecycle(): ...@@ -23,6 +23,7 @@ def test_basic_lifecycle():
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
request = create_request(request_id=1, request = create_request(request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS, num_tokens=NUM_TOKENS,
do_remote_prefill=True) do_remote_prefill=True)
...@@ -133,14 +134,17 @@ def test_interleaved_lifecycle(): ...@@ -133,14 +134,17 @@ def test_interleaved_lifecycle():
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5)) NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request_remote = create_request(request_id=1, request_remote = create_request(request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS, num_tokens=NUM_TOKENS,
do_remote_prefill=True) do_remote_prefill=True)
request_local_a = create_request( request_local_a = create_request(
request_id=2, request_id=2,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS, num_tokens=NUM_TOKENS,
) )
request_local_b = create_request( request_local_b = create_request(
request_id=3, request_id=3,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS, num_tokens=NUM_TOKENS,
) )
...@@ -236,6 +240,7 @@ def test_no_spurious_prefix_caching(): ...@@ -236,6 +240,7 @@ def test_no_spurious_prefix_caching():
# Both of these requests have prompts like [1,1,1,1,1, ...] # Both of these requests have prompts like [1,1,1,1,1, ...]
request_remote = create_request( request_remote = create_request(
request_id=1, request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS, num_tokens=NUM_TOKENS,
do_remote_prefill=True, do_remote_prefill=True,
use_all_1s_for_prompt_tokens=True, use_all_1s_for_prompt_tokens=True,
...@@ -243,6 +248,7 @@ def test_no_spurious_prefix_caching(): ...@@ -243,6 +248,7 @@ def test_no_spurious_prefix_caching():
request_local = create_request( request_local = create_request(
request_id=2, request_id=2,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS, num_tokens=NUM_TOKENS,
do_remote_prefill=False, do_remote_prefill=False,
use_all_1s_for_prompt_tokens=True, use_all_1s_for_prompt_tokens=True,
...@@ -292,6 +298,7 @@ def test_full_block_prompt(): ...@@ -292,6 +298,7 @@ def test_full_block_prompt():
NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS) NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS)
request = create_request(request_id=1, request = create_request(request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS, num_tokens=NUM_TOKENS,
do_remote_prefill=True) do_remote_prefill=True)
...@@ -364,8 +371,11 @@ def test_cannot_schedule_after_recv(): ...@@ -364,8 +371,11 @@ def test_cannot_schedule_after_recv():
NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS) NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
NUM_TOKENS_REMOTE = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS) NUM_TOKENS_REMOTE = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS_LOCAL) request_normal = create_request(request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS_LOCAL)
request_remote = create_request(request_id=2, request_remote = create_request(request_id=2,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS_REMOTE, num_tokens=NUM_TOKENS_REMOTE,
do_remote_prefill=True) do_remote_prefill=True)
...@@ -456,8 +466,11 @@ def test_cannot_recv(): ...@@ -456,8 +466,11 @@ def test_cannot_recv():
NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS) NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
NUM_TOKENS_REMOTE = int(BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5)) NUM_TOKENS_REMOTE = int(BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5))
request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS_LOCAL) request_normal = create_request(request_id=1,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS_LOCAL)
request_remote = create_request(request_id=2, request_remote = create_request(request_id=2,
block_size=BLOCK_SIZE,
num_tokens=NUM_TOKENS_REMOTE, num_tokens=NUM_TOKENS_REMOTE,
do_remote_prefill=True) do_remote_prefill=True)
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile import tempfile
from collections import defaultdict from collections import defaultdict
from typing import Any, Optional from typing import Any, Callable, Optional
import torch import torch
...@@ -14,6 +14,8 @@ from vllm.distributed.kv_transfer.kv_connector.factory import ( ...@@ -14,6 +14,8 @@ from vllm.distributed.kv_transfer.kv_connector.factory import (
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa
SharedStorageConnector) SharedStorageConnector)
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
init_none_hash)
from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec) KVCacheGroupSpec)
...@@ -40,7 +42,6 @@ def assert_scheduler_empty(scheduler: Scheduler): ...@@ -40,7 +42,6 @@ def assert_scheduler_empty(scheduler: Scheduler):
# KVCache Manager. # KVCache Manager.
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
req_to_blocks) == 0 req_to_blocks) == 0
assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0]. assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
num_cached_block) == 0 num_cached_block) == 0
num_free_blocks = ( num_free_blocks = (
...@@ -115,16 +116,23 @@ def create_scheduler( ...@@ -115,16 +116,23 @@ def create_scheduler(
) )
def create_request( _none_hash_initialized = False
request_id: int,
num_tokens: int = 10,
max_tokens: int = 16, def create_request(request_id: int,
do_remote_decode: bool = False, num_tokens: int = 10,
do_remote_prefill: bool = False, max_tokens: int = 16,
use_all_1s_for_prompt_tokens: bool = False, do_remote_decode: bool = False,
num_remote_blocks: int = 3, do_remote_prefill: bool = False,
) -> Request: use_all_1s_for_prompt_tokens: bool = False,
num_remote_blocks: int = 3,
block_size: int = 16,
hash_fn: Callable = hash) -> Request:
"""Make dummy request for testing.""" """Make dummy request for testing."""
global _none_hash_initialized
if not _none_hash_initialized:
init_none_hash(hash)
_none_hash_initialized = True
kv_transfer_params: Optional[dict[str, Any]] = None kv_transfer_params: Optional[dict[str, Any]] = None
...@@ -158,6 +166,7 @@ def create_request( ...@@ -158,6 +166,7 @@ def create_request(
multi_modal_placeholders=None, multi_modal_placeholders=None,
multi_modal_hashes=None, multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID, eos_token_id=EOS_TOKEN_ID,
block_hasher=get_request_block_hasher(block_size, hash_fn),
) )
req.kv_transfer_params = kv_transfer_params req.kv_transfer_params = kv_transfer_params
return req return req
......
...@@ -3243,6 +3243,24 @@ def sha256_cbor_64bit(input) -> int: ...@@ -3243,6 +3243,24 @@ def sha256_cbor_64bit(input) -> int:
return full_hash & ((1 << 64) - 1) return full_hash & ((1 << 64) - 1)
def get_hash_fn_by_name(hash_fn_name: str) -> Callable:
"""Get a hash function by name, or raise an error if
the function is not found.
Args:
hash_fn_name: Name of the hash function.
Returns:
A hash function.
"""
if hash_fn_name == "sha256":
return sha256
if hash_fn_name == "sha256_cbor_64bit":
return sha256_cbor_64bit
if hash_fn_name == "builtin":
return hash
raise ValueError(f"Unsupported hash function: {hash_fn_name}")
def is_torch_equal_or_newer(target: str) -> bool: def is_torch_equal_or_newer(target: str) -> bool:
"""Check if the installed torch version is >= the target version. """Check if the installed torch version is >= the target version.
......
...@@ -2,15 +2,13 @@ ...@@ -2,15 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable from collections.abc import Iterable
from typing import Callable, Optional from typing import Optional
from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved, from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved,
BlockStored, KVCacheEvent) BlockStored, KVCacheEvent)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
FreeKVCacheBlockQueue, KVCacheBlock, FreeKVCacheBlockQueue, KVCacheBlock)
generate_block_hash_extra_keys,
hash_block_tokens)
from vllm.v1.request import Request from vllm.v1.request import Request
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -97,84 +95,39 @@ class BlockPool: ...@@ -97,84 +95,39 @@ class BlockPool:
self, self,
request: Request, request: Request,
blocks: list[KVCacheBlock], blocks: list[KVCacheBlock],
block_hashes: list[BlockHash],
num_cached_blocks: int, num_cached_blocks: int,
num_full_blocks: int, num_full_blocks: int,
block_size: int, block_size: int,
kv_cache_group_id: int, kv_cache_group_id: int,
hash_fn: Callable,
) -> None: ) -> None:
"""Cache a list of full blocks for prefix caching. """Cache a list of full blocks for prefix caching.
This function takes a list of blocks that will have their block hash This function takes a list of blocks that will have their block hash
metadata to be updated and cached. Given a request, it computes the metadata to be updated and cached. Given a request, it updates the
block hashes for the blocks starting from `num_cached_blocks` to metadata for each block and caching it in the
`num_full_blocks`, updating the metadata for each block `cached_block_hash_to_block`.
and caching them in the `cached_block_hash_to_block`. The block hashes values are computed by the Request object immediately
when it is created and when new tokens are appended.
Args: Args:
request: The request to cache the blocks. request: The request to cache the blocks.
blocks: All blocks in the request. blocks: All blocks in the request.
block_hashes: Block hashes of the blocks in the request. Note that
this list may be shorter than the blocks list. In this case the
missed block hash will be computed in this function.
num_cached_blocks: The number of blocks that are already cached. num_cached_blocks: The number of blocks that are already cached.
num_full_blocks: The number of blocks that are full and should num_full_blocks: The number of blocks that are full and should
be cached after this function. be cached after this function.
block_size: Number of tokens in each block. block_size: Number of tokens in each block.
kv_cache_group_id: The id of the KV cache group. kv_cache_group_id: The id of the KV cache group.
hash_fn: The hash function to use for block hashes.
""" """
if num_cached_blocks == num_full_blocks: if num_cached_blocks == num_full_blocks:
return return
new_full_blocks = blocks[num_cached_blocks:num_full_blocks] new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
assert len(block_hashes) >= num_cached_blocks assert len(request.block_hashes) >= num_full_blocks
new_block_hashes = block_hashes[num_cached_blocks:] new_block_hashes = request.block_hashes[num_cached_blocks:]
# Update the new blocks with the block hashes through the chain.
if num_cached_blocks == 0:
prev_block_hash_value = None
else:
prev_block = blocks[num_cached_blocks - 1]
assert prev_block.block_hash is not None
prev_block_hash_value = prev_block.block_hash.get_hash_value()
parent_block_hash = prev_block_hash_value
new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events
else None) else None)
for i, blk in enumerate(new_full_blocks): for i, blk in enumerate(new_full_blocks):
assert blk.block_hash is None assert blk.block_hash is None
block_hash = new_block_hashes[i]
if i < len(new_block_hashes):
# The block hash may already be computed in
# "get_computed_blocks" if the tokens are not generated by
# this request (either the prompt tokens or the previously
# generated tokens with preemption), or by other
# single_type_managers with the same block_size.
# In this case we simply reuse the block hash.
block_hash = new_block_hashes[i]
else:
# Otherwise compute the block hash and cache it in the request
# in case it will be preempted in the future.
blk_idx = num_cached_blocks + i
start_token_idx = blk_idx * block_size
end_token_idx = (blk_idx + 1) * block_size
block_tokens = request.all_token_ids[
start_token_idx:end_token_idx]
assert len(block_tokens) == block_size, (
f"Expected {block_size} tokens, got "
f"{len(block_tokens)} at {blk_idx}th block for request "
f"{request.request_id}({request})")
# Generate extra keys for multi-modal inputs. Note that since
# we reach to this branch only when the block is completed with
# generated tokens, we only need to consider the last mm input.
extra_keys, _ = generate_block_hash_extra_keys(
request, start_token_idx, end_token_idx, -1)
# Compute the hash of the current block.
block_hash = hash_block_tokens(hash_fn, prev_block_hash_value,
block_tokens, extra_keys)
block_hashes.append(block_hash)
# Update and added the full block to the cache. # Update and added the full block to the cache.
block_hash_with_group_id = BlockHashWithGroupId( block_hash_with_group_id = BlockHashWithGroupId(
...@@ -184,9 +137,15 @@ class BlockPool: ...@@ -184,9 +137,15 @@ class BlockPool:
blk.block_id] = blk blk.block_id] = blk
if new_hashes is not None: if new_hashes is not None:
new_hashes.append(block_hash.hash_value) new_hashes.append(block_hash.hash_value)
prev_block_hash_value = block_hash.hash_value
if self.enable_kv_cache_events: if self.enable_kv_cache_events:
if num_cached_blocks == 0:
parent_block_hash = None
else:
parent_block = blocks[num_cached_blocks - 1]
assert parent_block.block_hash is not None
parent_block_hash = parent_block.block_hash.get_hash_value()
self.kv_event_queue.append( self.kv_event_queue.append(
BlockStored( BlockStored(
block_hashes=new_hashes, block_hashes=new_hashes,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Callable, Optional from typing import Optional
from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
...@@ -23,7 +23,6 @@ class KVCacheCoordinator(ABC): ...@@ -23,7 +23,6 @@ class KVCacheCoordinator(ABC):
max_model_len: int, max_model_len: int,
use_eagle: bool, use_eagle: bool,
enable_caching: bool, enable_caching: bool,
caching_hash_fn: Callable,
enable_kv_cache_events: bool, enable_kv_cache_events: bool,
): ):
self.kv_cache_config = kv_cache_config self.kv_cache_config = kv_cache_config
...@@ -40,7 +39,6 @@ class KVCacheCoordinator(ABC): ...@@ -40,7 +39,6 @@ class KVCacheCoordinator(ABC):
kv_cache_spec=kv_cache_group.kv_cache_spec, kv_cache_spec=kv_cache_group.kv_cache_spec,
block_pool=self.block_pool, block_pool=self.block_pool,
kv_cache_group_id=i, kv_cache_group_id=i,
caching_hash_fn=caching_hash_fn,
) for i, kv_cache_group in enumerate( ) for i, kv_cache_group in enumerate(
self.kv_cache_config.kv_cache_groups)) self.kv_cache_config.kv_cache_groups))
...@@ -99,19 +97,17 @@ class KVCacheCoordinator(ABC): ...@@ -99,19 +97,17 @@ class KVCacheCoordinator(ABC):
manager.allocate_new_blocks(request_id, num_tokens) manager.allocate_new_blocks(request_id, num_tokens)
for manager in self.single_type_managers) for manager in self.single_type_managers)
def cache_blocks(self, request: Request, block_hashes: list[BlockHash], def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
num_computed_tokens: int) -> None:
""" """
Cache the blocks for the request. Cache the blocks for the request.
Args: Args:
request: The request. request: The request.
block_hashes: The block hashes of the request.
num_tokens: The total number of tokens that need to be cached num_tokens: The total number of tokens that need to be cached
(including tokens that are already cached). (including tokens that are already cached).
""" """
for manager in self.single_type_managers: for manager in self.single_type_managers:
manager.cache_blocks(request, block_hashes, num_computed_tokens) manager.cache_blocks(request, num_computed_tokens)
def free(self, request_id: str) -> None: def free(self, request_id: str) -> None:
""" """
...@@ -184,10 +180,9 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator): ...@@ -184,10 +180,9 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
""" """
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
use_eagle: bool, caching_hash_fn: Callable, use_eagle: bool, enable_kv_cache_events: bool):
enable_kv_cache_events: bool):
super().__init__(kv_cache_config, max_model_len, use_eagle, False, super().__init__(kv_cache_config, max_model_len, use_eagle, False,
caching_hash_fn, enable_kv_cache_events) enable_kv_cache_events)
self.num_single_type_manager = len(self.single_type_managers) self.num_single_type_manager = len(self.single_type_managers)
def get_num_common_prefix_blocks(self, request_id: str, def get_num_common_prefix_blocks(self, request_id: str,
...@@ -213,10 +208,9 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator): ...@@ -213,10 +208,9 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
use_eagle: bool, enable_caching: bool, use_eagle: bool, enable_caching: bool,
caching_hash_fn: Callable, enable_kv_cache_events: bool): enable_kv_cache_events: bool):
super().__init__(kv_cache_config, max_model_len, use_eagle, super().__init__(kv_cache_config, max_model_len, use_eagle,
enable_caching, caching_hash_fn, enable_caching, enable_kv_cache_events)
enable_kv_cache_events)
self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[ self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[
0].kv_cache_spec 0].kv_cache_spec
self.block_size = self.kv_cache_spec.block_size self.block_size = self.kv_cache_spec.block_size
...@@ -250,10 +244,9 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): ...@@ -250,10 +244,9 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
use_eagle: bool, enable_caching: bool, use_eagle: bool, enable_caching: bool,
caching_hash_fn: Callable, enable_kv_cache_events: bool): enable_kv_cache_events: bool):
super().__init__(kv_cache_config, max_model_len, use_eagle, super().__init__(kv_cache_config, max_model_len, use_eagle,
enable_caching, caching_hash_fn, enable_caching, enable_kv_cache_events)
enable_kv_cache_events)
self.verify_and_split_kv_cache_groups() self.verify_and_split_kv_cache_groups()
def verify_and_split_kv_cache_groups(self) -> None: def verify_and_split_kv_cache_groups(self) -> None:
...@@ -386,17 +379,15 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): ...@@ -386,17 +379,15 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
def get_kv_cache_coordinator( def get_kv_cache_coordinator(
kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool,
enable_caching: bool, caching_hash_fn: Callable, enable_caching: bool,
enable_kv_cache_events: bool) -> KVCacheCoordinator: enable_kv_cache_events: bool) -> KVCacheCoordinator:
if not enable_caching: if not enable_caching:
return KVCacheCoordinatorNoPrefixCache(kv_cache_config, max_model_len, return KVCacheCoordinatorNoPrefixCache(kv_cache_config, max_model_len,
use_eagle, caching_hash_fn, use_eagle,
enable_kv_cache_events) enable_kv_cache_events)
if len(kv_cache_config.kv_cache_groups) == 1: if len(kv_cache_config.kv_cache_groups) == 1:
return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len, return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len,
use_eagle, enable_caching, use_eagle, enable_caching,
caching_hash_fn,
enable_kv_cache_events) enable_kv_cache_events)
return HybridKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle, return HybridKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle,
enable_caching, caching_hash_fn, enable_caching, enable_kv_cache_events)
enable_kv_cache_events)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
from vllm.distributed.kv_events import KVCacheEvent from vllm.distributed.kv_events import KVCacheEvent
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import sha256, sha256_cbor_64bit
from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock, from vllm.v1.core.kv_cache_utils import KVCacheBlock
hash_request_tokens, init_none_hash)
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request, RequestStatus
...@@ -71,23 +68,13 @@ class KVCacheManager: ...@@ -71,23 +68,13 @@ class KVCacheManager:
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
max_model_len: int, max_model_len: int,
enable_caching: bool = True, enable_caching: bool = True,
caching_hash_algo: str = "builtin",
use_eagle: bool = False, use_eagle: bool = False,
log_stats: bool = False, log_stats: bool = False,
enable_kv_cache_events: bool = False, enable_kv_cache_events: bool = False,
) -> None: ) -> None:
self.max_model_len = max_model_len self.max_model_len = max_model_len
if len(kv_cache_config.kv_cache_groups) == 0:
# Attention free models don't have kv cache,
# thus don't need prefix caching.
enable_caching = False
self.enable_caching = enable_caching self.enable_caching = enable_caching
self.caching_hash_fn = (
sha256_cbor_64bit if caching_hash_algo == "sha256_cbor_64bit" else
sha256 if caching_hash_algo == "sha256" else hash)
init_none_hash(self.caching_hash_fn)
self.use_eagle = use_eagle self.use_eagle = use_eagle
self.log_stats = log_stats self.log_stats = log_stats
# FIXME: make prefix cache stats conditional on log_stats # FIXME: make prefix cache stats conditional on log_stats
...@@ -107,19 +94,12 @@ class KVCacheManager: ...@@ -107,19 +94,12 @@ class KVCacheManager:
max_model_len=self.max_model_len, max_model_len=self.max_model_len,
use_eagle=self.use_eagle, use_eagle=self.use_eagle,
enable_caching=self.enable_caching, enable_caching=self.enable_caching,
caching_hash_fn=self.caching_hash_fn,
enable_kv_cache_events=enable_kv_cache_events, enable_kv_cache_events=enable_kv_cache_events,
) )
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups) self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
self.block_pool = self.coordinator.block_pool self.block_pool = self.coordinator.block_pool
self.kv_cache_config = kv_cache_config self.kv_cache_config = kv_cache_config
# Mapping from request ID to kv block hashes.
# This is to avoid recomputing the block hashes for each call of
# `get_computed_blocks` or `allocate_slots`.
self.req_to_block_hashes: defaultdict[
str, list[BlockHash]] = defaultdict(list)
@property @property
def usage(self) -> float: def usage(self) -> float:
"""Get the KV cache usage. """Get the KV cache usage.
...@@ -161,15 +141,6 @@ class KVCacheManager: ...@@ -161,15 +141,6 @@ class KVCacheManager:
and request.sampling_params.prompt_logprobs is not None)): and request.sampling_params.prompt_logprobs is not None)):
return self.create_empty_block_list(), 0 return self.create_empty_block_list(), 0
# The block hashes for the request may already be computed
# if the scheduler has tried to schedule the request before.
block_hashes = self.req_to_block_hashes[request.request_id]
if not block_hashes:
assert self.block_size is not None
block_hashes = hash_request_tokens(self.caching_hash_fn,
self.block_size, request)
self.req_to_block_hashes[request.request_id] = block_hashes
# NOTE: When all tokens hit the cache, we must recompute the last token # NOTE: When all tokens hit the cache, we must recompute the last token
# to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1. # to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
# This can trigger recomputation of an entire block, rather than just # This can trigger recomputation of an entire block, rather than just
...@@ -178,7 +149,7 @@ class KVCacheManager: ...@@ -178,7 +149,7 @@ class KVCacheManager:
# could slightly improve performance in the future. # could slightly improve performance in the future.
max_cache_hit_length = request.num_tokens - 1 max_cache_hit_length = request.num_tokens - 1
computed_blocks, num_new_computed_tokens = ( computed_blocks, num_new_computed_tokens = (
self.coordinator.find_longest_cache_hit(block_hashes, self.coordinator.find_longest_cache_hit(request.block_hashes,
max_cache_hit_length)) max_cache_hit_length))
if self.log_stats: if self.log_stats:
...@@ -296,11 +267,7 @@ class KVCacheManager: ...@@ -296,11 +267,7 @@ class KVCacheManager:
# at `request.num_tokens`, ensuring only "finalized" tokens are cached. # at `request.num_tokens`, ensuring only "finalized" tokens are cached.
num_tokens_to_cache = min(num_computed_tokens + num_new_tokens, num_tokens_to_cache = min(num_computed_tokens + num_new_tokens,
request.num_tokens) request.num_tokens)
self.coordinator.cache_blocks( self.coordinator.cache_blocks(request, num_tokens_to_cache)
request,
self.req_to_block_hashes[request.request_id],
num_tokens_to_cache,
)
return KVCacheBlocks(new_blocks) return KVCacheBlocks(new_blocks)
...@@ -373,14 +340,6 @@ class KVCacheManager: ...@@ -373,14 +340,6 @@ class KVCacheManager:
return self.coordinator.get_num_common_prefix_blocks( return self.coordinator.get_num_common_prefix_blocks(
request.request_id, num_running_requests) request.request_id, num_running_requests)
def free_block_hashes(self, request: Request) -> None:
"""Discard the block hashes for the request.
NOTE: Unlike `free`, this method should be called only when the request
is finished, not when it is preempted.
"""
self.req_to_block_hashes.pop(request.request_id, None)
def take_events(self) -> list[KVCacheEvent]: def take_events(self) -> list[KVCacheEvent]:
"""Take the KV cache events from the block pool. """Take the KV cache events from the block pool.
...@@ -397,9 +356,7 @@ class KVCacheManager: ...@@ -397,9 +356,7 @@ class KVCacheManager:
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
"""Cache the blocks for the request, if enabled.""" """Cache the blocks for the request, if enabled."""
if self.enable_caching: if self.enable_caching:
block_hashes = self.req_to_block_hashes[request.request_id] self.coordinator.cache_blocks(request, num_computed_tokens)
self.coordinator.cache_blocks(request, block_hashes,
num_computed_tokens)
def create_empty_block_list(self) -> KVCacheBlocks: def create_empty_block_list(self) -> KVCacheBlocks:
"""Creates a new KVCacheBlocks instance with no blocks.""" """Creates a new KVCacheBlocks instance with no blocks."""
......
...@@ -547,41 +547,61 @@ def hash_block_tokens( ...@@ -547,41 +547,61 @@ def hash_block_tokens(
curr_block_token_ids_tuple, extra_keys) curr_block_token_ids_tuple, extra_keys)
def hash_request_tokens(hash_function: Any, block_size: int, def get_request_block_hasher(
request: Request) -> list[BlockHash]: block_size: int,
"""Computes hash values of a chain of blocks given a sequence of caching_hash_fn: Callable[[Any],
token IDs. The hash value is used for prefix caching. int]) -> Callable[[Request], list[BlockHash]]:
"""
Returns a function which computes the list of un-computed block hashes
of a request.
Each request holds a list of its block hashes (request.block_hashes).
When a request is created, it calls the below function to compute
the hashes of all full blocks of the request's initial tokens.
The hashes are then stored in request.block_hashes.
Later, whenever new tokens are appended to the request, it calls
the below function again to compute any new full blocks of tokens.
The returned new hashes are appended to request.block_hashes.
"""
Args: def request_block_hasher(request: Request) -> list[BlockHash]:
block_size: The size of each block. start_token_idx = len(request.block_hashes) * block_size
request: The request object. num_tokens = request.num_tokens
curr_mm_idx = 0
if start_token_idx > 0:
# Set curr_mm_idx = -1 to indicate the last mm input.
# Note that since we reach to this branch only when the block is
# completed with generated tokens, we only need to consider the
# last mm input.
curr_mm_idx = -1
prev_block_hash_value = request.block_hashes[-1].hash_value \
if request.block_hashes else None
new_block_hashes: list[BlockHash] = []
while True:
end_token_idx = start_token_idx + block_size
if end_token_idx > num_tokens:
# We only hash full blocks
break
Returns: # MM and LoRA requests need extra keys for block-hash computation.
The list of computed hash values. extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
""" request, start_token_idx, end_token_idx, curr_mm_idx)
token_ids = request.all_token_ids
req_need_extra_keys = need_extra_keys(request) # Compute the hash of the current block
req_extra_keys = None block_tokens = request.all_token_ids[start_token_idx:end_token_idx]
curr_mm_idx = 0 block_hash = hash_block_tokens(caching_hash_fn,
prev_block_hash_value, block_tokens,
extra_keys)
ret = [] new_block_hashes.append(block_hash)
parent_block_hash_value = None start_token_idx += block_size
# Only full blocks will be hashed prev_block_hash_value = block_hash.hash_value
for start in range(0, len(token_ids) - block_size + 1, block_size):
end = start + block_size
block_token_ids = token_ids[start:end]
if req_need_extra_keys: return new_block_hashes
# MM and LoRA requests need extra keys for block-hash computation.
req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys( return request_block_hasher
request, start, end, curr_mm_idx)
block_hash = hash_block_tokens(hash_function, parent_block_hash_value,
block_token_ids, req_extra_keys)
ret.append(block_hash)
parent_block_hash_value = block_hash.hash_value
return ret
def max_memory_usage_bytes(vllm_config: VllmConfig, def max_memory_usage_bytes(vllm_config: VllmConfig,
......
...@@ -155,7 +155,6 @@ class Scheduler(SchedulerInterface): ...@@ -155,7 +155,6 @@ class Scheduler(SchedulerInterface):
kv_cache_config=kv_cache_config, kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len, max_model_len=self.max_model_len,
enable_caching=self.cache_config.enable_prefix_caching, enable_caching=self.cache_config.enable_prefix_caching,
caching_hash_algo=self.cache_config.prefix_caching_hash_algo,
use_eagle=self.use_eagle, use_eagle=self.use_eagle,
log_stats=self.log_stats, log_stats=self.log_stats,
enable_kv_cache_events=self.enable_kv_cache_events, enable_kv_cache_events=self.enable_kv_cache_events,
...@@ -1036,7 +1035,6 @@ class Scheduler(SchedulerInterface): ...@@ -1036,7 +1035,6 @@ class Scheduler(SchedulerInterface):
def _free_blocks(self, request: Request): def _free_blocks(self, request: Request):
assert request.is_finished() assert request.is_finished()
self.kv_cache_manager.free(request) self.kv_cache_manager.free(request)
self.kv_cache_manager.free_block_hashes(request)
del self.requests[request.request_id] del self.requests[request.request_id]
def get_num_unfinished_requests(self) -> int: def get_num_unfinished_requests(self) -> int:
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
import itertools import itertools
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
from typing import Callable
from vllm.utils import cdiv from vllm.utils import cdiv
from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.block_pool import BlockPool
...@@ -25,7 +24,6 @@ class SingleTypeKVCacheManager(ABC): ...@@ -25,7 +24,6 @@ class SingleTypeKVCacheManager(ABC):
kv_cache_spec: KVCacheSpec, kv_cache_spec: KVCacheSpec,
block_pool: BlockPool, block_pool: BlockPool,
kv_cache_group_id: int, kv_cache_group_id: int,
caching_hash_fn: Callable,
) -> None: ) -> None:
""" """
Initializes the SingleTypeKVCacheManager. Initializes the SingleTypeKVCacheManager.
...@@ -33,7 +31,6 @@ class SingleTypeKVCacheManager(ABC): ...@@ -33,7 +31,6 @@ class SingleTypeKVCacheManager(ABC):
kv_cache_spec: The kv_cache_spec for this manager. kv_cache_spec: The kv_cache_spec for this manager.
block_pool: The block pool. block_pool: The block pool.
kv_cache_group_id: The id of the kv cache group of this manager. kv_cache_group_id: The id of the kv cache group of this manager.
caching_hash_fn: The caching hash function.
""" """
self.block_size = kv_cache_spec.block_size self.block_size = kv_cache_spec.block_size
...@@ -52,7 +49,6 @@ class SingleTypeKVCacheManager(ABC): ...@@ -52,7 +49,6 @@ class SingleTypeKVCacheManager(ABC):
# data for reempted ones. # data for reempted ones.
self.num_cached_block: dict[str, int] = {} self.num_cached_block: dict[str, int] = {}
self.caching_hash_fn = caching_hash_fn
self.kv_cache_group_id = kv_cache_group_id self.kv_cache_group_id = kv_cache_group_id
self._null_block = block_pool.null_block self._null_block = block_pool.null_block
...@@ -130,14 +126,12 @@ class SingleTypeKVCacheManager(ABC): ...@@ -130,14 +126,12 @@ class SingleTypeKVCacheManager(ABC):
req_blocks.extend(new_blocks) req_blocks.extend(new_blocks)
return new_blocks return new_blocks
def cache_blocks(self, request: Request, block_hashes: list[BlockHash], def cache_blocks(self, request: Request, num_tokens: int) -> None:
num_tokens: int) -> None:
""" """
Cache the blocks for the request. Cache the blocks for the request.
Args: Args:
request: The request. request: The request.
block_hashes: The block hashes of the request.
num_tokens: The total number of tokens that need to be cached num_tokens: The total number of tokens that need to be cached
(including tokens that are already cached). (including tokens that are already cached).
""" """
...@@ -147,12 +141,10 @@ class SingleTypeKVCacheManager(ABC): ...@@ -147,12 +141,10 @@ class SingleTypeKVCacheManager(ABC):
self.block_pool.cache_full_blocks( self.block_pool.cache_full_blocks(
request=request, request=request,
blocks=self.req_to_blocks[request.request_id], blocks=self.req_to_blocks[request.request_id],
block_hashes=block_hashes,
num_cached_blocks=num_cached_blocks, num_cached_blocks=num_cached_blocks,
num_full_blocks=num_full_blocks, num_full_blocks=num_full_blocks,
block_size=self.block_size, block_size=self.block_size,
kv_cache_group_id=self.kv_cache_group_id, kv_cache_group_id=self.kv_cache_group_id,
hash_fn=self.caching_hash_fn,
) )
self.num_cached_block[request.request_id] = num_full_blocks self.num_cached_block[request.request_id] = num_full_blocks
......
...@@ -25,9 +25,11 @@ from vllm.multimodal import MULTIMODAL_REGISTRY ...@@ -25,9 +25,11 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.transformers_utils.config import ( from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value) maybe_register_config_serialize_by_value)
from vllm.utils import (decorate_logs, make_zmq_socket, from vllm.utils import (decorate_logs, get_hash_fn_by_name, make_zmq_socket,
resolve_obj_by_qualname, set_process_title) resolve_obj_by_qualname, set_process_title)
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config, from vllm.v1.core.kv_cache_utils import (BlockHash, get_kv_cache_config,
get_request_block_hasher,
init_none_hash,
unify_kv_cache_configs) unify_kv_cache_configs)
from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
...@@ -140,6 +142,19 @@ class EngineCore: ...@@ -140,6 +142,19 @@ class EngineCore:
self.batch_queue_size) self.batch_queue_size)
self.batch_queue = queue.Queue(self.batch_queue_size) self.batch_queue = queue.Queue(self.batch_queue_size)
self.request_block_hasher: Optional[Callable[[Request],
list[BlockHash]]] = None
if (self.vllm_config.cache_config.enable_prefix_caching
or self.scheduler.get_kv_connector() is not None):
block_size = vllm_config.cache_config.block_size
caching_hash_fn = get_hash_fn_by_name(
vllm_config.cache_config.prefix_caching_hash_algo)
init_none_hash(caching_hash_fn)
self.request_block_hasher = get_request_block_hasher(
block_size, caching_hash_fn)
def _initialize_kv_caches( def _initialize_kv_caches(
self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]: self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
start = time.time() start = time.time()
...@@ -417,7 +432,8 @@ class EngineCore: ...@@ -417,7 +432,8 @@ class EngineCore:
request.mm_kwargs = self.mm_input_cache_server.get_and_update( request.mm_kwargs = self.mm_input_cache_server.get_and_update(
request.mm_kwargs, request.mm_hashes) request.mm_kwargs, request.mm_hashes)
req = Request.from_engine_core_request(request) req = Request.from_engine_core_request(request,
self.request_block_hasher)
if req.use_structured_output: if req.use_structured_output:
# Note on thread safety: no race condition. # Note on thread safety: no race condition.
# `grammar_init` is only invoked in input processing thread. For # `grammar_init` is only invoked in input processing thread. For
......
...@@ -3,7 +3,8 @@ ...@@ -3,7 +3,8 @@
import enum import enum
import time import time
from typing import TYPE_CHECKING, Any, Optional, Union from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
...@@ -16,6 +17,7 @@ from vllm.v1.utils import ConstantList ...@@ -16,6 +17,7 @@ from vllm.v1.utils import ConstantList
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.v1.core.kv_cache_utils import BlockHash
class Request: class Request:
...@@ -36,6 +38,8 @@ class Request: ...@@ -36,6 +38,8 @@ class Request:
structured_output_request: Optional["StructuredOutputRequest"] = None, structured_output_request: Optional["StructuredOutputRequest"] = None,
cache_salt: Optional[str] = None, cache_salt: Optional[str] = None,
priority: int = 0, priority: int = 0,
block_hasher: Optional[Callable[["Request"],
list["BlockHash"]]] = None,
) -> None: ) -> None:
self.request_id = request_id self.request_id = request_id
self.client_index = client_index self.client_index = client_index
...@@ -108,8 +112,18 @@ class Request: ...@@ -108,8 +112,18 @@ class Request:
# indicates that the output is corrupted # indicates that the output is corrupted
self.num_nans_in_logits = 0 self.num_nans_in_logits = 0
self.block_hashes: list[BlockHash] = []
self.get_hash_new_full_blocks: Optional[Callable[
[], list[BlockHash]]] = None
if block_hasher is not None:
self.get_hash_new_full_blocks = partial(block_hasher, self)
self.block_hashes = self.get_hash_new_full_blocks()
@classmethod @classmethod
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": def from_engine_core_request(
cls, request: EngineCoreRequest,
block_hasher: Optional[Callable[["Request"], list["BlockHash"]]]
) -> "Request":
if request.mm_kwargs is not None: if request.mm_kwargs is not None:
assert is_list_of(request.mm_kwargs, MultiModalKwargsItem), ( assert is_list_of(request.mm_kwargs, MultiModalKwargsItem), (
"mm_kwargs was not updated in EngineCore.add_request") "mm_kwargs was not updated in EngineCore.add_request")
...@@ -131,6 +145,7 @@ class Request: ...@@ -131,6 +145,7 @@ class Request:
if request.sampling_params else None, if request.sampling_params else None,
cache_salt=request.cache_salt, cache_salt=request.cache_salt,
priority=request.priority, priority=request.priority,
block_hasher=block_hasher,
) )
def append_output_token_ids( def append_output_token_ids(
...@@ -144,6 +159,9 @@ class Request: ...@@ -144,6 +159,9 @@ class Request:
self._output_token_ids.extend(token_ids) self._output_token_ids.extend(token_ids)
self._all_token_ids.extend(token_ids) self._all_token_ids.extend(token_ids)
if self.get_hash_new_full_blocks is not None:
self.block_hashes.extend(self.get_hash_new_full_blocks())
@property @property
def is_output_corrupted(self) -> bool: def is_output_corrupted(self) -> bool:
return self.num_nans_in_logits > 0 return self.num_nans_in_logits > 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment