Unverified Commit 51adca74 authored by Or Ozeri's avatar Or Ozeri Committed by GitHub
Browse files

[kv_offload+HMA][9/N]: Support lookup with multiple KV groups (#39401)


Signed-off-by: default avatarOr Ozeri <oro@il.ibm.com>
parent e8eb0490
...@@ -120,6 +120,13 @@ class OffloadingConnectorScheduler: ...@@ -120,6 +120,13 @@ class OffloadingConnectorScheduler:
self.config = SchedulerOffloadConfig.from_spec(spec) self.config = SchedulerOffloadConfig.from_spec(spec)
self.manager: OffloadingManager = spec.get_manager() self.manager: OffloadingManager = spec.get_manager()
attention_groups: list[int] = []
for idx, _ in enumerate(spec.kv_cache_config.kv_cache_groups):
# currently treat all groups as full attention
attention_groups.append(idx)
self.lookup_groups = attention_groups
self._req_status: dict[ReqId, RequestOffloadState] = {} self._req_status: dict[ReqId, RequestOffloadState] = {}
# requests to load for the current scheduler step # requests to load for the current scheduler step
self._reqs_to_load: dict[ReqId, TransferSpec] = {} self._reqs_to_load: dict[ReqId, TransferSpec] = {}
...@@ -204,65 +211,89 @@ class OffloadingConnectorScheduler: ...@@ -204,65 +211,89 @@ class OffloadingConnectorScheduler:
group_state.block_ids.clear() group_state.block_ids.clear()
else: else:
req_status = RequestOffloadState(config=self.config, req=request) req_status = RequestOffloadState(config=self.config, req=request)
req_status.update_offload_keys()
self._req_status[request.request_id] = req_status self._req_status[request.request_id] = req_status
req_status.update_offload_keys()
req_status.num_locally_computed_tokens = num_computed_tokens req_status.num_locally_computed_tokens = num_computed_tokens
# Below assertions will be removed once this function supports HMA for gs in req_status.group_states:
assert len(self.config.kv_group_configs) == 1 self.manager.touch(gs.offload_keys)
assert len(req_status.group_states) == 1
group_config = self.config.kv_group_configs[0]
group_state = req_status.group_states[0]
num_blocks = request.num_tokens // group_config.offloaded_block_size
assert len(request.block_hashes) // self.config.block_size_factor == num_blocks # Start with the full request size as the maximum loadable
offload_keys = group_state.offload_keys max_hit_size_tokens: int = req_status.req.num_tokens
num_hit_tokens: int = 0
defer_lookup = False
delay_request = False
for group_idx in self.lookup_groups:
group_config: GroupOffloadConfig = self.config.kv_group_configs[group_idx]
offloaded_block_size = group_config.offloaded_block_size
offload_keys = req_status.group_states[group_idx].offload_keys
self.manager.touch(offload_keys) num_blocks = max_hit_size_tokens // offloaded_block_size
assert len(offload_keys) >= num_blocks
full_block_tokens = group_config.offloaded_block_size * num_blocks # Constrain to block-aligned boundary for this group
if full_block_tokens - num_computed_tokens < group_config.offloaded_block_size: max_hit_size_tokens = num_blocks * offloaded_block_size
# we can load less than a block, skip num_hit_tokens = max_hit_size_tokens - num_computed_tokens
if num_hit_tokens < offloaded_block_size:
# we can only load less than a block, better skip
return 0, False return 0, False
start_block_idx = num_computed_tokens // group_config.offloaded_block_size start_block_idx = num_computed_tokens // offloaded_block_size
# Full attention relays on all previous KV cache blocks. offload_keys = offload_keys[start_block_idx:num_blocks]
# Full attention relies on all previous KV cache blocks.
# Thus, we search for a maximal prefix of KV cache which are all cached. # Thus, we search for a maximal prefix of KV cache which are all cached.
hits = self._maximal_prefix_lookup( block_hits = self._maximal_prefix_lookup(
offload_keys[start_block_idx:], req_status.req_context offload_keys, req_status.req_context
) )
if hits is None: if block_hits == 0:
# indicates a lookup that should be tried later
return None, False
if hits == 0:
return 0, False return 0, False
num_hit_tokens = ( if block_hits is None:
group_config.offloaded_block_size * (start_block_idx + hits) defer_lookup = True
- num_computed_tokens else:
) # Further constrain based on what's actually available by backend
logger.debug( max_hit_size_tokens = offloaded_block_size * (
"Request %s hit %s offloaded tokens after %s GPU hit tokens", start_block_idx + block_hits
request.request_id,
num_hit_tokens,
num_computed_tokens,
) )
if num_hit_tokens < group_config.offloaded_block_size:
num_hit_tokens = max_hit_size_tokens - num_computed_tokens
if num_hit_tokens < offloaded_block_size:
# we can only load less than a block, better skip
return 0, False return 0, False
if self._blocks_being_loaded and any( if (
block_hits
and self._blocks_being_loaded
and any(
key in self._blocks_being_loaded key in self._blocks_being_loaded
for key in offload_keys[start_block_idx : start_block_idx + hits] for key in offload_keys[:block_hits]
)
): ):
# hit blocks are being loaded, delay request # hit blocks are being loaded, delay request
delay_request = True
if defer_lookup:
logger.debug(
"Offloading manager delayed request %s as backend requested",
req_status.req.request_id,
)
return None, False
if delay_request:
logger.debug( logger.debug(
"Delaying request %s since some of its blocks are already being loaded", "Delaying request %s since some of its blocks are already being loaded",
request.request_id, req_status.req.request_id,
) )
return None, False return None, False
logger.debug(
"Request %s hit %s offloaded tokens after %s GPU hit tokens",
request.request_id,
num_hit_tokens,
num_computed_tokens,
)
return num_hit_tokens, True return num_hit_tokens, True
def update_state_after_alloc( def update_state_after_alloc(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment