Unverified commit 51adca74, authored by Or Ozeri, committed by GitHub
Browse files

[kv_offload+HMA][9/N]: Support lookup with multiple KV groups (#39401)


Signed-off-by: Or Ozeri <oro@il.ibm.com>
parent e8eb0490
...@@ -120,6 +120,13 @@ class OffloadingConnectorScheduler: ...@@ -120,6 +120,13 @@ class OffloadingConnectorScheduler:
self.config = SchedulerOffloadConfig.from_spec(spec) self.config = SchedulerOffloadConfig.from_spec(spec)
self.manager: OffloadingManager = spec.get_manager() self.manager: OffloadingManager = spec.get_manager()
attention_groups: list[int] = []
for idx, _ in enumerate(spec.kv_cache_config.kv_cache_groups):
# currently treat all groups as full attention
attention_groups.append(idx)
self.lookup_groups = attention_groups
self._req_status: dict[ReqId, RequestOffloadState] = {} self._req_status: dict[ReqId, RequestOffloadState] = {}
# requests to load for the current scheduler step # requests to load for the current scheduler step
self._reqs_to_load: dict[ReqId, TransferSpec] = {} self._reqs_to_load: dict[ReqId, TransferSpec] = {}
...@@ -204,64 +211,88 @@ class OffloadingConnectorScheduler: ...@@ -204,64 +211,88 @@ class OffloadingConnectorScheduler:
group_state.block_ids.clear() group_state.block_ids.clear()
else: else:
req_status = RequestOffloadState(config=self.config, req=request) req_status = RequestOffloadState(config=self.config, req=request)
req_status.update_offload_keys()
self._req_status[request.request_id] = req_status self._req_status[request.request_id] = req_status
req_status.update_offload_keys()
req_status.num_locally_computed_tokens = num_computed_tokens req_status.num_locally_computed_tokens = num_computed_tokens
# Below assertions will be removed once this function supports HMA for gs in req_status.group_states:
assert len(self.config.kv_group_configs) == 1 self.manager.touch(gs.offload_keys)
assert len(req_status.group_states) == 1
group_config = self.config.kv_group_configs[0]
group_state = req_status.group_states[0]
num_blocks = request.num_tokens // group_config.offloaded_block_size # Start with the full request size as the maximum loadable
max_hit_size_tokens: int = req_status.req.num_tokens
num_hit_tokens: int = 0
defer_lookup = False
delay_request = False
for group_idx in self.lookup_groups:
group_config: GroupOffloadConfig = self.config.kv_group_configs[group_idx]
offloaded_block_size = group_config.offloaded_block_size
offload_keys = req_status.group_states[group_idx].offload_keys
assert len(request.block_hashes) // self.config.block_size_factor == num_blocks num_blocks = max_hit_size_tokens // offloaded_block_size
offload_keys = group_state.offload_keys assert len(offload_keys) >= num_blocks
self.manager.touch(offload_keys) # Constrain to block-aligned boundary for this group
max_hit_size_tokens = num_blocks * offloaded_block_size
num_hit_tokens = max_hit_size_tokens - num_computed_tokens
if num_hit_tokens < offloaded_block_size:
# we can only load less than a block, better skip
return 0, False
start_block_idx = num_computed_tokens // offloaded_block_size
offload_keys = offload_keys[start_block_idx:num_blocks]
# Full attention relies on all previous KV cache blocks.
# Thus, we search for a maximal prefix of KV cache which are all cached.
block_hits = self._maximal_prefix_lookup(
offload_keys, req_status.req_context
)
if block_hits == 0:
return 0, False
full_block_tokens = group_config.offloaded_block_size * num_blocks if block_hits is None:
if full_block_tokens - num_computed_tokens < group_config.offloaded_block_size: defer_lookup = True
# we can load less than a block, skip else:
return 0, False # Further constrain based on what's actually available by backend
max_hit_size_tokens = offloaded_block_size * (
start_block_idx + block_hits
)
start_block_idx = num_computed_tokens // group_config.offloaded_block_size num_hit_tokens = max_hit_size_tokens - num_computed_tokens
# Full attention relays on all previous KV cache blocks. if num_hit_tokens < offloaded_block_size:
# Thus, we search for a maximal prefix of KV cache which are all cached. # we can only load less than a block, better skip
hits = self._maximal_prefix_lookup( return 0, False
offload_keys[start_block_idx:], req_status.req_context
) if (
if hits is None: block_hits
# indicates a lookup that should be tried later and self._blocks_being_loaded
and any(
key in self._blocks_being_loaded
for key in offload_keys[:block_hits]
)
):
# hit blocks are being loaded, delay request
delay_request = True
if defer_lookup:
logger.debug(
"Offloading manager delayed request %s as backend requested",
req_status.req.request_id,
)
return None, False
if delay_request:
logger.debug(
"Delaying request %s since some of its blocks are already being loaded",
req_status.req.request_id,
)
return None, False return None, False
if hits == 0:
return 0, False
num_hit_tokens = (
group_config.offloaded_block_size * (start_block_idx + hits)
- num_computed_tokens
)
logger.debug( logger.debug(
"Request %s hit %s offloaded tokens after %s GPU hit tokens", "Request %s hit %s offloaded tokens after %s GPU hit tokens",
request.request_id, request.request_id,
num_hit_tokens, num_hit_tokens,
num_computed_tokens, num_computed_tokens,
) )
if num_hit_tokens < group_config.offloaded_block_size:
return 0, False
if self._blocks_being_loaded and any(
key in self._blocks_being_loaded
for key in offload_keys[start_block_idx : start_block_idx + hits]
):
# hit blocks are being loaded, delay request
logger.debug(
"Delaying request %s since some of its blocks are already being loaded",
request.request_id,
)
return None, False
return num_hit_tokens, True return num_hit_tokens, True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment