Unverified Commit 5ef33ab2 authored by Or Ozeri's avatar Or Ozeri Committed by GitHub
Browse files

[kv_offload+HMA][10/N]: Support load with multiple KV groups (#39402)


Signed-off-by: default avatarOr Ozeri <oro@il.ibm.com>
parent 1c2c1eb8
......@@ -14,6 +14,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.offloading.common import (
ReqId,
)
from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_offload.abstract import (
......@@ -271,45 +272,66 @@ class OffloadingConnectorScheduler:
return
req_status = self._req_status[request.request_id]
block_groups = blocks.get_block_ids()
# Below assertions will be removed once this function supports HMA
assert len(self.config.kv_group_configs) == 1
assert len(req_status.group_states) == 1
assert len(block_groups) == 1
block_ids = block_groups[0]
group_config = self.config.kv_group_configs[0]
group_state = req_status.group_states[0]
num_computed_gpu_blocks = sum(
block.block_hash is not None for block in blocks.blocks[0]
)
num_computed_tokens = num_computed_gpu_blocks * group_config.gpu_block_size
full_block_tokens = num_computed_tokens + num_external_tokens
assert full_block_tokens % group_config.offloaded_block_size == 0
num_locally_computed_tokens = req_status.num_locally_computed_tokens
num_cached_tokens = num_locally_computed_tokens + num_external_tokens
keys_to_load: list[OffloadKey] = []
dst_block_ids: list[int] = []
# per group
group_sizes: list[int] = []
block_indices: list[int] = []
for group_config, group_state, group_blocks in zip(
self.config.kv_group_configs,
req_status.group_states,
blocks.blocks,
):
gpu_block_size = group_config.gpu_block_size
offloaded_block_size = group_config.offloaded_block_size
offload_keys = group_state.offload_keys
num_gpu_blocks = cdiv(num_cached_tokens, gpu_block_size)
assert len(group_blocks) >= num_gpu_blocks
num_locally_computed_gpu_blocks = num_gpu_blocks
# Skip null placeholder blocks (used for sliding window or mamba padding).
for i, block in enumerate(group_blocks[:num_gpu_blocks]):
if not block.is_null and block.block_hash is None:
num_locally_computed_gpu_blocks = i
break
assert (
num_locally_computed_tokens
<= num_locally_computed_gpu_blocks * gpu_block_size
)
num_pending_gpu_blocks = num_gpu_blocks - num_locally_computed_gpu_blocks
num_pending_gpu_blocks = len(block_ids) - num_computed_gpu_blocks
assert (
num_external_tokens == num_pending_gpu_blocks * group_config.gpu_block_size
)
num_blocks = cdiv(num_cached_tokens, offloaded_block_size)
assert len(offload_keys) >= num_blocks
if num_pending_gpu_blocks:
start_block_idx = (
num_locally_computed_gpu_blocks // self.config.block_size_factor
)
keys_to_load.extend(offload_keys[start_block_idx:num_blocks])
start_block_idx = num_computed_tokens // group_config.offloaded_block_size
num_blocks = full_block_tokens // group_config.offloaded_block_size
dst_block_ids.extend(
block.block_id
for block in group_blocks[
num_locally_computed_gpu_blocks:num_gpu_blocks
]
)
group_sizes.append(num_pending_gpu_blocks)
block_indices.append(num_locally_computed_gpu_blocks)
assert len(request.block_hashes) // self.config.block_size_factor >= num_blocks
offload_keys = group_state.offload_keys[start_block_idx:num_blocks]
group_state.next_stored_block_idx = num_blocks
src_spec = self.manager.prepare_load(offload_keys, req_status.req_context)
src_spec = self.manager.prepare_load(keys_to_load, req_status.req_context)
dst_spec = GPULoadStoreSpec(
block_ids[num_computed_gpu_blocks:],
group_sizes=(num_pending_gpu_blocks,),
block_indices=(num_computed_gpu_blocks,),
dst_block_ids, group_sizes=group_sizes, block_indices=block_indices
)
self._reqs_to_load[request.request_id] = (src_spec, dst_spec)
req_blocks_being_loaded = self._reqs_being_loaded[request.request_id]
req_blocks_being_loaded.update(offload_keys)
group_state.next_stored_block_idx = num_blocks
req_blocks_being_loaded.update(keys_to_load)
if self._blocks_being_loaded is not None:
self._blocks_being_loaded.update(req_blocks_being_loaded)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment