Unverified Commit 51adca74 authored by Or Ozeri's avatar Or Ozeri Committed by GitHub
Browse files

[kv_offload+HMA][9/N]: Support lookup with multiple KV groups (#39401)


Signed-off-by: default avatarOr Ozeri <oro@il.ibm.com>
parent e8eb0490
......@@ -120,6 +120,13 @@ class OffloadingConnectorScheduler:
self.config = SchedulerOffloadConfig.from_spec(spec)
self.manager: OffloadingManager = spec.get_manager()
attention_groups: list[int] = []
for idx, _ in enumerate(spec.kv_cache_config.kv_cache_groups):
# currently treat all groups as full attention
attention_groups.append(idx)
self.lookup_groups = attention_groups
self._req_status: dict[ReqId, RequestOffloadState] = {}
# requests to load for the current scheduler step
self._reqs_to_load: dict[ReqId, TransferSpec] = {}
......@@ -204,64 +211,88 @@ class OffloadingConnectorScheduler:
group_state.block_ids.clear()
else:
req_status = RequestOffloadState(config=self.config, req=request)
req_status.update_offload_keys()
self._req_status[request.request_id] = req_status
req_status.update_offload_keys()
req_status.num_locally_computed_tokens = num_computed_tokens
# Below assertions will be removed once this function supports HMA
assert len(self.config.kv_group_configs) == 1
assert len(req_status.group_states) == 1
group_config = self.config.kv_group_configs[0]
group_state = req_status.group_states[0]
for gs in req_status.group_states:
self.manager.touch(gs.offload_keys)
num_blocks = request.num_tokens // group_config.offloaded_block_size
# Start with the full request size as the maximum loadable
max_hit_size_tokens: int = req_status.req.num_tokens
num_hit_tokens: int = 0
defer_lookup = False
delay_request = False
for group_idx in self.lookup_groups:
group_config: GroupOffloadConfig = self.config.kv_group_configs[group_idx]
offloaded_block_size = group_config.offloaded_block_size
offload_keys = req_status.group_states[group_idx].offload_keys
assert len(request.block_hashes) // self.config.block_size_factor == num_blocks
offload_keys = group_state.offload_keys
num_blocks = max_hit_size_tokens // offloaded_block_size
assert len(offload_keys) >= num_blocks
self.manager.touch(offload_keys)
# Constrain to block-aligned boundary for this group
max_hit_size_tokens = num_blocks * offloaded_block_size
num_hit_tokens = max_hit_size_tokens - num_computed_tokens
if num_hit_tokens < offloaded_block_size:
# we can only load less than a block, better skip
return 0, False
start_block_idx = num_computed_tokens // offloaded_block_size
offload_keys = offload_keys[start_block_idx:num_blocks]
# Full attention relies on all previous KV cache blocks.
# Thus, we search for a maximal prefix of KV cache which are all cached.
block_hits = self._maximal_prefix_lookup(
offload_keys, req_status.req_context
)
if block_hits == 0:
return 0, False
full_block_tokens = group_config.offloaded_block_size * num_blocks
if full_block_tokens - num_computed_tokens < group_config.offloaded_block_size:
# we can load less than a block, skip
return 0, False
if block_hits is None:
defer_lookup = True
else:
# Further constrain based on what's actually available by backend
max_hit_size_tokens = offloaded_block_size * (
start_block_idx + block_hits
)
start_block_idx = num_computed_tokens // group_config.offloaded_block_size
# Full attention relays on all previous KV cache blocks.
# Thus, we search for a maximal prefix of KV cache which are all cached.
hits = self._maximal_prefix_lookup(
offload_keys[start_block_idx:], req_status.req_context
)
if hits is None:
# indicates a lookup that should be tried later
num_hit_tokens = max_hit_size_tokens - num_computed_tokens
if num_hit_tokens < offloaded_block_size:
# we can only load less than a block, better skip
return 0, False
if (
block_hits
and self._blocks_being_loaded
and any(
key in self._blocks_being_loaded
for key in offload_keys[:block_hits]
)
):
# hit blocks are being loaded, delay request
delay_request = True
if defer_lookup:
logger.debug(
"Offloading manager delayed request %s as backend requested",
req_status.req.request_id,
)
return None, False
if delay_request:
logger.debug(
"Delaying request %s since some of its blocks are already being loaded",
req_status.req.request_id,
)
return None, False
if hits == 0:
return 0, False
num_hit_tokens = (
group_config.offloaded_block_size * (start_block_idx + hits)
- num_computed_tokens
)
logger.debug(
"Request %s hit %s offloaded tokens after %s GPU hit tokens",
request.request_id,
num_hit_tokens,
num_computed_tokens,
)
if num_hit_tokens < group_config.offloaded_block_size:
return 0, False
if self._blocks_being_loaded and any(
key in self._blocks_being_loaded
for key in offload_keys[start_block_idx : start_block_idx + hits]
):
# hit blocks are being loaded, delay request
logger.debug(
"Delaying request %s since some of its blocks are already being loaded",
request.request_id,
)
return None, False
return num_hit_tokens, True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment