Unverified Commit 46ecc579 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[BugFix] Fix tpu_model_runner block_id concatenation (#19228)


Signed-off-by: default avatarNick Hill <nhill@redhat.com>
parent b6a3a9f7
......@@ -226,7 +226,7 @@ def test_update_states_request_resumed(model_runner):
req_id=req_id,
resumed_from_preemption=False,
new_token_ids=[],
new_block_ids=[],
new_block_ids=[[]],
num_computed_tokens=0,
)
......
......@@ -460,8 +460,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Update the block IDs.
if not req_data.resumed_from_preemption:
# Append the new blocks to the existing block IDs.
for i in range(len(self.kv_cache_config.kv_cache_groups)):
req_state.block_ids[i].extend(req_data.new_block_ids[i])
for block_ids, new_block_ids in zip( # type: ignore[call-overload]
req_state.block_ids,
req_data.new_block_ids,
strict=True):
block_ids.extend(new_block_ids)
else:
# The request is resumed from preemption.
# Replace the existing block IDs with the new ones.
......
......@@ -413,7 +413,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
req_state.num_computed_tokens = req_data.num_computed_tokens
if not req_data.resumed_from_preemption:
# Append the new blocks to the existing block IDs.
req_state.block_ids.extend(req_data.new_block_ids)
for block_ids, new_block_ids in zip( # type: ignore[call-overload]
req_state.block_ids,
req_data.new_block_ids,
strict=True):
block_ids.extend(new_block_ids)
else:
# The request is resumed from preemption.
# Replace the existing block IDs with the new ones.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment