"vscode:/vscode.git/clone" did not exist on "504ac53d18fc057d2a98741fa27d89df9054422d"
Unverified Commit 612e7729 authored by Or Ozeri's avatar Or Ozeri Committed by GitHub
Browse files

[KVConnector] Scheduler: Fix num_computed_tokens after async KV load (#34616)


Signed-off-by: default avatarOr Ozeri <oro@il.ibm.com>
parent ecde7af9
......@@ -121,7 +121,7 @@ def test_error_propagation_async_load(fail_scheduler: Scheduler):
assert len(fail_scheduler.waiting) == 1
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert request.num_computed_tokens == 0
assert request.num_computed_tokens == num_external_computed_tokens
(req_block_ids,) = fail_scheduler.kv_cache_manager.get_block_ids(request.request_id)
invalid_block_ids = {req_block_ids[invalid_block_idx]}
......
......@@ -339,7 +339,7 @@ def test_async_recompute_blocks_not_cached_when_invalid(
# request should be waiting for remote KVs
assert len(recompute_scheduler.waiting) == 1
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert request.num_computed_tokens == 0
assert request.num_computed_tokens == num_external_computed_tokens
# get the allocated block IDs
(req_block_ids,) = recompute_scheduler.kv_cache_manager.get_block_ids(
......
......@@ -78,7 +78,7 @@ def test_async_load_failure(
assert len(scheduler.waiting) == 3
for request in scheduler.waiting:
assert request.num_computed_tokens == 0
assert request.num_computed_tokens == num_external_computed_tokens
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert scheduler.connector.get_num_new_matched_tokens.call_count == 3
......@@ -103,7 +103,7 @@ def test_async_load_failure(
min_invalid_block_idx * scheduler.block_size
)
else:
assert request.num_computed_tokens == 0
assert request.num_computed_tokens == num_external_computed_tokens
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert scheduler.failed_recving_kv_req_ids == {request2.request_id}
assert scheduler.connector.get_num_new_matched_tokens.call_count == 3
......@@ -305,7 +305,7 @@ def test_async_progressive_load_failure(
assert len(scheduler.waiting) == 1
assert scheduler.waiting.peek_request().request_id == request.request_id
assert request.num_computed_tokens == 0
assert request.num_computed_tokens == num_external_computed_tokens
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert scheduler.connector.get_num_new_matched_tokens.call_count == 1
......
......@@ -57,7 +57,7 @@ def test_basic_lifecycle():
assert len(scheduler.waiting) == 1
assert request in scheduler.waiting
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert request.num_computed_tokens == 0
assert request.num_computed_tokens == NUM_TOKENS
# ... but should have (uncached) blocks allocated to it.
block_pool = scheduler.kv_cache_manager.block_pool
......
......@@ -638,6 +638,7 @@ class Scheduler(SchedulerInterface):
num_computed_tokens = (
num_new_local_computed_tokens + num_external_computed_tokens
)
assert num_computed_tokens <= request.num_tokens
else:
# KVTransfer: WAITING reqs have num_computed_tokens > 0
# after async KV recvs are completed.
......@@ -773,6 +774,20 @@ class Scheduler(SchedulerInterface):
# into the WAITING_FOR_REMOTE_KV state.
skipped_waiting_requests.prepend_request(request)
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
# Set num_computed_tokens even though KVs are not yet loaded.
# request.num_computed_tokens will not be used anywhere until
# the request finished the KV transfer.
#
# If a transfer error is reported by the connector,
# request.num_computed_tokens will be re-set accordingly in
# _update_requests_with_invalid_blocks.
#
# When the transfer is finished, either successfully or not,
# request.num_computed_tokens will correctly reflect the number
# of computed tokens.
# _update_waiting_for_remote_kv will then cache
# only the successfully loaded tokens.
request.num_computed_tokens = num_computed_tokens
continue
self.running.append(request)
......@@ -1994,17 +2009,17 @@ class Scheduler(SchedulerInterface):
self.failed_recving_kv_req_ids.remove(request.request_id)
else:
# Now that the blocks are ready, actually cache them.
(block_ids,) = self.kv_cache_manager.get_block_ids(request.request_id)
num_computed_tokens = len(block_ids) * self.block_size
# Handle the case where num request tokens less than one block.
num_computed_tokens = min(num_computed_tokens, request.num_tokens)
if num_computed_tokens == request.num_tokens:
num_computed_tokens -= 1
# This will cache the blocks iff caching is enabled.
self.kv_cache_manager.cache_blocks(request, num_computed_tokens)
self.kv_cache_manager.cache_blocks(request, request.num_computed_tokens)
# Update the request state for scheduling.
request.num_computed_tokens = num_computed_tokens
# on a full prompt hit, we need to re-compute the last token
# in order to be able to sample the next token
if request.num_computed_tokens == request.num_tokens:
request.num_computed_tokens = request.num_tokens - 1
# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = request.num_computed_tokens
# Return that we are ready.
self.finished_recving_kv_req_ids.remove(request.request_id)
......@@ -2084,13 +2099,8 @@ class Scheduler(SchedulerInterface):
# We iterate only over blocks that may contain externally computed
# tokens
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
# Async loading. If num_computed_tokens is set it implies we
# already processed some block failures for it in a prior step
req_num_computed_tokens = (
request.num_computed_tokens
if req_id in self.failed_recving_kv_req_ids
else len(req_block_ids) * self.block_size
)
# Async loading. num_computed_tokens does not include new tokens
req_num_computed_tokens = request.num_computed_tokens
else:
# Sync loading. num_computed_tokens includes new tokens
req_num_computed_tokens = request.num_cached_tokens
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment