Unverified Commit 46ecc579 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[BugFix] Fix tpu_model_runner block_id concatenation (#19228)


Signed-off-by: default avatarNick Hill <nhill@redhat.com>
parent b6a3a9f7
...@@ -226,7 +226,7 @@ def test_update_states_request_resumed(model_runner): ...@@ -226,7 +226,7 @@ def test_update_states_request_resumed(model_runner):
req_id=req_id, req_id=req_id,
resumed_from_preemption=False, resumed_from_preemption=False,
new_token_ids=[], new_token_ids=[],
new_block_ids=[], new_block_ids=[[]],
num_computed_tokens=0, num_computed_tokens=0,
) )
......
...@@ -460,8 +460,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -460,8 +460,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Update the block IDs. # Update the block IDs.
if not req_data.resumed_from_preemption: if not req_data.resumed_from_preemption:
# Append the new blocks to the existing block IDs. # Append the new blocks to the existing block IDs.
for i in range(len(self.kv_cache_config.kv_cache_groups)): for block_ids, new_block_ids in zip( # type: ignore[call-overload]
req_state.block_ids[i].extend(req_data.new_block_ids[i]) req_state.block_ids,
req_data.new_block_ids,
strict=True):
block_ids.extend(new_block_ids)
else: else:
# The request is resumed from preemption. # The request is resumed from preemption.
# Replace the existing block IDs with the new ones. # Replace the existing block IDs with the new ones.
......
...@@ -413,7 +413,11 @@ class TPUModelRunner(LoRAModelRunnerMixin): ...@@ -413,7 +413,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
req_state.num_computed_tokens = req_data.num_computed_tokens req_state.num_computed_tokens = req_data.num_computed_tokens
if not req_data.resumed_from_preemption: if not req_data.resumed_from_preemption:
# Append the new blocks to the existing block IDs. # Append the new blocks to the existing block IDs.
req_state.block_ids.extend(req_data.new_block_ids) for block_ids, new_block_ids in zip( # type: ignore[call-overload]
req_state.block_ids,
req_data.new_block_ids,
strict=True):
block_ids.extend(new_block_ids)
else: else:
# The request is resumed from preemption. # The request is resumed from preemption.
# Replace the existing block IDs with the new ones. # Replace the existing block IDs with the new ones.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment