Unverified Commit 60cd878a authored by Or Ozeri's avatar Or Ozeri Committed by GitHub
Browse files

[kv_offload+HMA][11/N]: Support store with multiple KV groups (#39403)


Signed-off-by: default avatarOr Ozeri <oro@il.ibm.com>
parent 1e9f19ca
...@@ -112,6 +112,13 @@ class RequestOffloadState: ...@@ -112,6 +112,13 @@ class RequestOffloadState:
for group_state, new_blocks in zip(self.group_states, new_block_id_groups): for group_state, new_blocks in zip(self.group_states, new_block_id_groups):
group_state.block_ids.extend(new_blocks) group_state.block_ids.extend(new_blocks)
def advance_stored_idx(self, num_offloadable_tokens: int) -> None:
for group_config, group_state in zip(
self.config.kv_group_configs, self.group_states
):
num_blocks = num_offloadable_tokens // group_config.offloaded_block_size
group_state.next_stored_block_idx = num_blocks
class OffloadingConnectorScheduler: class OffloadingConnectorScheduler:
"""Implementation of Scheduler side methods""" """Implementation of Scheduler side methods"""
...@@ -367,16 +374,16 @@ class OffloadingConnectorScheduler: ...@@ -367,16 +374,16 @@ class OffloadingConnectorScheduler:
if self._blocks_being_loaded is not None: if self._blocks_being_loaded is not None:
self._blocks_being_loaded.update(req_blocks_being_loaded) self._blocks_being_loaded.update(req_blocks_being_loaded)
def _get_reqs_to_store(self, scheduler_output: SchedulerOutput): def _get_reqs_to_store(
# Below assertion will be removed once this function supports HMA self, scheduler_output: SchedulerOutput
assert len(self.config.kv_group_configs) == 1 ) -> dict[ReqId, TransferSpec]:
group_config = self.config.kv_group_configs[0] block_size_factor = self.config.block_size_factor
reqs_to_store: dict[ReqId, TransferSpec] = {} reqs_to_store: dict[ReqId, TransferSpec] = {}
# iterate over both new and cached requests # iterate over both new and cached requests
for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output): for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output):
req_status = self._req_status[req_id] req_status = self._req_status[req_id]
req_status.update_offload_keys() req_status.update_offload_keys()
req = req_status.req
if preempted: if preempted:
for group_state in req_status.group_states: for group_state in req_status.group_states:
...@@ -385,68 +392,106 @@ class OffloadingConnectorScheduler: ...@@ -385,68 +392,106 @@ class OffloadingConnectorScheduler:
if new_block_id_groups: if new_block_id_groups:
req_status.update_block_id_groups(new_block_id_groups) req_status.update_block_id_groups(new_block_id_groups)
# Below assertion will be removed once this function supports HMA num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
assert len(req_status.group_states) == 1 num_tokens_after_batch = req.num_computed_tokens + num_scheduled_tokens
group_state = req_status.group_states[0]
block_ids = group_state.block_ids
req = req_status.req
new_tokens = scheduler_output.num_scheduled_tokens[req_id]
expected_tokens = req.num_computed_tokens + new_tokens
# with async scheduling, some tokens may be missing # with async scheduling, some tokens may be missing
total_tokens = min(expected_tokens, req.num_tokens) num_offloadable_tokens = min(num_tokens_after_batch, req.num_tokens)
num_blocks = total_tokens // group_config.offloaded_block_size
start_block_idx = group_state.next_stored_block_idx
num_new_blocks = num_blocks - start_block_idx
if num_new_blocks <= 0: # Filter out blocks skipped due to sliding window attention / SSM
new_offload_keys: list[OffloadKey] = []
for group_config, group_state in zip(
self.config.kv_group_configs, req_status.group_states
):
num_blocks = num_offloadable_tokens // group_config.offloaded_block_size
start_block_idx = group_state.next_stored_block_idx
if num_blocks <= start_block_idx:
continue continue
offload_keys = group_state.offload_keys[start_block_idx:num_blocks]
# For each block to offload, take the last corresponding GPU block.
# e.g. if block size factor is 3 and GPU block IDs are
# 1 5 6 7 2 4 9 3 8 then we'll take blocks 6 4 8.
# We will use these GPU blocks to determine if the block needs
# offloading, or (if the GPU block ID is 0) this block should
# be skipped due to sliding window attention / SSM.
# We know that if a block is skipped, then all the previous blocks
# are skipped as well. This is why we take the last of each block.
offload_block_ids = group_state.block_ids[
start_block_idx * block_size_factor
+ block_size_factor
- 1 : num_blocks * block_size_factor : block_size_factor
]
assert len(offload_keys) == len(offload_block_ids)
num_gpu_blocks = num_blocks * self.config.block_size_factor for offload_key, block_id in zip(offload_keys, offload_block_ids):
assert len(req.block_hashes) >= num_gpu_blocks if block_id != 0:
new_offload_keys.append(offload_key)
if not new_offload_keys:
req_status.advance_stored_idx(num_offloadable_tokens)
continue
new_offload_keys = group_state.offload_keys[start_block_idx:num_blocks]
store_output = self.manager.prepare_store( store_output = self.manager.prepare_store(
new_offload_keys, req_status.req_context new_offload_keys, req_status.req_context
) )
if store_output is None: if store_output is None:
logger.warning( logger.warning("Request %s: cannot store blocks", req_id)
"Request %s: cannot store %s blocks", req_id, num_new_blocks
)
continue continue
group_state.next_stored_block_idx = num_blocks
if not store_output.keys_to_store: if not store_output.keys_to_store:
req_status.advance_stored_idx(num_offloadable_tokens)
continue continue
keys_to_store = set(store_output.keys_to_store)
self.manager.touch(group_state.offload_keys[:num_blocks]) for group_state in req_status.group_states:
self.manager.touch(group_state.offload_keys)
dst_spec = store_output.store_spec keys_to_store = set(store_output.keys_to_store)
group_sizes: list[int] = []
block_indices: list[int] = []
src_block_ids: list[int] = [] src_block_ids: list[int] = []
for idx, key in enumerate(new_offload_keys): for group_config, group_state in zip(
if key not in keys_to_store: self.config.kv_group_configs, req_status.group_states
):
num_blocks = num_offloadable_tokens // group_config.offloaded_block_size
start_block_idx = group_state.next_stored_block_idx
block_ids = group_state.block_ids
num_group_blocks = 0
start_gpu_block_idx: int | None = None
for idx, offload_key in enumerate(
group_state.offload_keys[start_block_idx:num_blocks]
):
if offload_key not in keys_to_store:
continue continue
offloaded_block_idx = start_block_idx + idx offloaded_block_idx = start_block_idx + idx
gpu_block_idx = offloaded_block_idx * self.config.block_size_factor gpu_block_idx = offloaded_block_idx * block_size_factor
for i in range(self.config.block_size_factor): num_group_blocks += block_size_factor
src_block_ids.append(block_ids[gpu_block_idx + i]) for i in range(block_size_factor):
block_id = block_ids[gpu_block_idx + i]
if block_id == 0:
# skipped blocks cannot appear after non-skipped blocks
assert start_gpu_block_idx is None
continue
elif start_gpu_block_idx is None:
start_gpu_block_idx = gpu_block_idx + i
src_block_ids.append(block_id)
group_sizes.append(num_group_blocks)
block_indices.append(start_gpu_block_idx or 0)
group_state.next_stored_block_idx = num_blocks
src_spec = GPULoadStoreSpec( src_spec = GPULoadStoreSpec(
src_block_ids, src_block_ids, group_sizes=group_sizes, block_indices=block_indices
group_sizes=(len(src_block_ids),),
block_indices=(0,),
) )
dst_spec = store_output.store_spec
reqs_to_store[req_id] = (src_spec, dst_spec) reqs_to_store[req_id] = (src_spec, dst_spec)
self._reqs_being_stored[req_id] |= keys_to_store self._reqs_being_stored[req_id] |= keys_to_store
logger.debug( logger.debug(
"Request %s offloading %s blocks starting from block #%d", "Request %s offloading %s blocks upto %d tokens",
req_id, req_id,
len(keys_to_store), len(keys_to_store),
start_block_idx, num_offloadable_tokens,
) )
return reqs_to_store return reqs_to_store
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment