Unverified Commit 60cd878a authored by Or Ozeri's avatar Or Ozeri Committed by GitHub
Browse files

[kv_offload+HMA][11/N]: Support store with multiple KV groups (#39403)


Signed-off-by: default avatarOr Ozeri <oro@il.ibm.com>
parent 1e9f19ca
......@@ -112,6 +112,13 @@ class RequestOffloadState:
for group_state, new_blocks in zip(self.group_states, new_block_id_groups):
group_state.block_ids.extend(new_blocks)
def advance_stored_idx(self, num_offloadable_tokens: int) -> None:
for group_config, group_state in zip(
self.config.kv_group_configs, self.group_states
):
num_blocks = num_offloadable_tokens // group_config.offloaded_block_size
group_state.next_stored_block_idx = num_blocks
class OffloadingConnectorScheduler:
"""Implementation of Scheduler side methods"""
......@@ -367,16 +374,16 @@ class OffloadingConnectorScheduler:
if self._blocks_being_loaded is not None:
self._blocks_being_loaded.update(req_blocks_being_loaded)
def _get_reqs_to_store(self, scheduler_output: SchedulerOutput):
# Below assertion will be removed once this function supports HMA
assert len(self.config.kv_group_configs) == 1
group_config = self.config.kv_group_configs[0]
def _get_reqs_to_store(
self, scheduler_output: SchedulerOutput
) -> dict[ReqId, TransferSpec]:
block_size_factor = self.config.block_size_factor
reqs_to_store: dict[ReqId, TransferSpec] = {}
# iterate over both new and cached requests
for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output):
req_status = self._req_status[req_id]
req_status.update_offload_keys()
req = req_status.req
if preempted:
for group_state in req_status.group_states:
......@@ -385,68 +392,106 @@ class OffloadingConnectorScheduler:
if new_block_id_groups:
req_status.update_block_id_groups(new_block_id_groups)
# Below assertion will be removed once this function supports HMA
assert len(req_status.group_states) == 1
group_state = req_status.group_states[0]
block_ids = group_state.block_ids
req = req_status.req
new_tokens = scheduler_output.num_scheduled_tokens[req_id]
expected_tokens = req.num_computed_tokens + new_tokens
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_tokens_after_batch = req.num_computed_tokens + num_scheduled_tokens
# with async scheduling, some tokens may be missing
total_tokens = min(expected_tokens, req.num_tokens)
num_blocks = total_tokens // group_config.offloaded_block_size
start_block_idx = group_state.next_stored_block_idx
num_new_blocks = num_blocks - start_block_idx
num_offloadable_tokens = min(num_tokens_after_batch, req.num_tokens)
if num_new_blocks <= 0:
# Filter out blocks skipped due to sliding window attention / SSM
new_offload_keys: list[OffloadKey] = []
for group_config, group_state in zip(
self.config.kv_group_configs, req_status.group_states
):
num_blocks = num_offloadable_tokens // group_config.offloaded_block_size
start_block_idx = group_state.next_stored_block_idx
if num_blocks <= start_block_idx:
continue
offload_keys = group_state.offload_keys[start_block_idx:num_blocks]
# For each block to offload, take the last corresponding GPU block.
# e.g. if block size factor is 3 and GPU block IDs are
# 1 5 6 7 2 4 9 3 8 then we'll take blocks 6 4 8.
# We will use these GPU blocks to determine if the block needs
# offloading, or (if the GPU block ID is 0) this block should
# be skipped due to sliding window attention / SSM.
# We know that if a block is skipped, then all the previous blocks
# are skipped as well. This is why we take the last of each block.
offload_block_ids = group_state.block_ids[
start_block_idx * block_size_factor
+ block_size_factor
- 1 : num_blocks * block_size_factor : block_size_factor
]
assert len(offload_keys) == len(offload_block_ids)
num_gpu_blocks = num_blocks * self.config.block_size_factor
assert len(req.block_hashes) >= num_gpu_blocks
for offload_key, block_id in zip(offload_keys, offload_block_ids):
if block_id != 0:
new_offload_keys.append(offload_key)
if not new_offload_keys:
req_status.advance_stored_idx(num_offloadable_tokens)
continue
new_offload_keys = group_state.offload_keys[start_block_idx:num_blocks]
store_output = self.manager.prepare_store(
new_offload_keys, req_status.req_context
)
if store_output is None:
logger.warning(
"Request %s: cannot store %s blocks", req_id, num_new_blocks
)
logger.warning("Request %s: cannot store blocks", req_id)
continue
group_state.next_stored_block_idx = num_blocks
if not store_output.keys_to_store:
req_status.advance_stored_idx(num_offloadable_tokens)
continue
keys_to_store = set(store_output.keys_to_store)
self.manager.touch(group_state.offload_keys[:num_blocks])
for group_state in req_status.group_states:
self.manager.touch(group_state.offload_keys)
dst_spec = store_output.store_spec
keys_to_store = set(store_output.keys_to_store)
group_sizes: list[int] = []
block_indices: list[int] = []
src_block_ids: list[int] = []
for idx, key in enumerate(new_offload_keys):
if key not in keys_to_store:
for group_config, group_state in zip(
self.config.kv_group_configs, req_status.group_states
):
num_blocks = num_offloadable_tokens // group_config.offloaded_block_size
start_block_idx = group_state.next_stored_block_idx
block_ids = group_state.block_ids
num_group_blocks = 0
start_gpu_block_idx: int | None = None
for idx, offload_key in enumerate(
group_state.offload_keys[start_block_idx:num_blocks]
):
if offload_key not in keys_to_store:
continue
offloaded_block_idx = start_block_idx + idx
gpu_block_idx = offloaded_block_idx * self.config.block_size_factor
for i in range(self.config.block_size_factor):
src_block_ids.append(block_ids[gpu_block_idx + i])
gpu_block_idx = offloaded_block_idx * block_size_factor
num_group_blocks += block_size_factor
for i in range(block_size_factor):
block_id = block_ids[gpu_block_idx + i]
if block_id == 0:
# skipped blocks cannot appear after non-skipped blocks
assert start_gpu_block_idx is None
continue
elif start_gpu_block_idx is None:
start_gpu_block_idx = gpu_block_idx + i
src_block_ids.append(block_id)
group_sizes.append(num_group_blocks)
block_indices.append(start_gpu_block_idx or 0)
group_state.next_stored_block_idx = num_blocks
src_spec = GPULoadStoreSpec(
src_block_ids,
group_sizes=(len(src_block_ids),),
block_indices=(0,),
src_block_ids, group_sizes=group_sizes, block_indices=block_indices
)
dst_spec = store_output.store_spec
reqs_to_store[req_id] = (src_spec, dst_spec)
self._reqs_being_stored[req_id] |= keys_to_store
logger.debug(
"Request %s offloading %s blocks starting from block #%d",
"Request %s offloading %s blocks upto %d tokens",
req_id,
len(keys_to_store),
start_block_idx,
num_offloadable_tokens,
)
return reqs_to_store
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment