Unverified Commit 612e7729 authored by Or Ozeri's avatar Or Ozeri Committed by GitHub
Browse files

[KVConnector] Scheduler: Fix num_computed_tokens after async KV load (#34616)


Signed-off-by: default avatarOr Ozeri <oro@il.ibm.com>
parent ecde7af9
...@@ -121,7 +121,7 @@ def test_error_propagation_async_load(fail_scheduler: Scheduler): ...@@ -121,7 +121,7 @@ def test_error_propagation_async_load(fail_scheduler: Scheduler):
assert len(fail_scheduler.waiting) == 1 assert len(fail_scheduler.waiting) == 1
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert request.num_computed_tokens == 0 assert request.num_computed_tokens == num_external_computed_tokens
(req_block_ids,) = fail_scheduler.kv_cache_manager.get_block_ids(request.request_id) (req_block_ids,) = fail_scheduler.kv_cache_manager.get_block_ids(request.request_id)
invalid_block_ids = {req_block_ids[invalid_block_idx]} invalid_block_ids = {req_block_ids[invalid_block_idx]}
......
...@@ -339,7 +339,7 @@ def test_async_recompute_blocks_not_cached_when_invalid( ...@@ -339,7 +339,7 @@ def test_async_recompute_blocks_not_cached_when_invalid(
# request should be waiting for remote KVs # request should be waiting for remote KVs
assert len(recompute_scheduler.waiting) == 1 assert len(recompute_scheduler.waiting) == 1
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert request.num_computed_tokens == 0 assert request.num_computed_tokens == num_external_computed_tokens
# get the allocated block IDs # get the allocated block IDs
(req_block_ids,) = recompute_scheduler.kv_cache_manager.get_block_ids( (req_block_ids,) = recompute_scheduler.kv_cache_manager.get_block_ids(
......
...@@ -78,7 +78,7 @@ def test_async_load_failure( ...@@ -78,7 +78,7 @@ def test_async_load_failure(
assert len(scheduler.waiting) == 3 assert len(scheduler.waiting) == 3
for request in scheduler.waiting: for request in scheduler.waiting:
assert request.num_computed_tokens == 0 assert request.num_computed_tokens == num_external_computed_tokens
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert scheduler.connector.get_num_new_matched_tokens.call_count == 3 assert scheduler.connector.get_num_new_matched_tokens.call_count == 3
...@@ -103,7 +103,7 @@ def test_async_load_failure( ...@@ -103,7 +103,7 @@ def test_async_load_failure(
min_invalid_block_idx * scheduler.block_size min_invalid_block_idx * scheduler.block_size
) )
else: else:
assert request.num_computed_tokens == 0 assert request.num_computed_tokens == num_external_computed_tokens
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert scheduler.failed_recving_kv_req_ids == {request2.request_id} assert scheduler.failed_recving_kv_req_ids == {request2.request_id}
assert scheduler.connector.get_num_new_matched_tokens.call_count == 3 assert scheduler.connector.get_num_new_matched_tokens.call_count == 3
...@@ -305,7 +305,7 @@ def test_async_progressive_load_failure( ...@@ -305,7 +305,7 @@ def test_async_progressive_load_failure(
assert len(scheduler.waiting) == 1 assert len(scheduler.waiting) == 1
assert scheduler.waiting.peek_request().request_id == request.request_id assert scheduler.waiting.peek_request().request_id == request.request_id
assert request.num_computed_tokens == 0 assert request.num_computed_tokens == num_external_computed_tokens
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert scheduler.connector.get_num_new_matched_tokens.call_count == 1 assert scheduler.connector.get_num_new_matched_tokens.call_count == 1
......
...@@ -57,7 +57,7 @@ def test_basic_lifecycle(): ...@@ -57,7 +57,7 @@ def test_basic_lifecycle():
assert len(scheduler.waiting) == 1 assert len(scheduler.waiting) == 1
assert request in scheduler.waiting assert request in scheduler.waiting
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert request.num_computed_tokens == 0 assert request.num_computed_tokens == NUM_TOKENS
# ... but should have (uncached) blocks allocated to it. # ... but should have (uncached) blocks allocated to it.
block_pool = scheduler.kv_cache_manager.block_pool block_pool = scheduler.kv_cache_manager.block_pool
......
...@@ -638,6 +638,7 @@ class Scheduler(SchedulerInterface): ...@@ -638,6 +638,7 @@ class Scheduler(SchedulerInterface):
num_computed_tokens = ( num_computed_tokens = (
num_new_local_computed_tokens + num_external_computed_tokens num_new_local_computed_tokens + num_external_computed_tokens
) )
assert num_computed_tokens <= request.num_tokens
else: else:
# KVTransfer: WAITING reqs have num_computed_tokens > 0 # KVTransfer: WAITING reqs have num_computed_tokens > 0
# after async KV recvs are completed. # after async KV recvs are completed.
...@@ -773,6 +774,20 @@ class Scheduler(SchedulerInterface): ...@@ -773,6 +774,20 @@ class Scheduler(SchedulerInterface):
# into the WAITING_FOR_REMOTE_KV state. # into the WAITING_FOR_REMOTE_KV state.
skipped_waiting_requests.prepend_request(request) skipped_waiting_requests.prepend_request(request)
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
# Set num_computed_tokens even though KVs are not yet loaded.
# request.num_computed_tokens will not be used anywhere until
# the request finished the KV transfer.
#
# If a transfer error is reported by the connector,
# request.num_computed_tokens will be re-set accordingly in
# _update_requests_with_invalid_blocks.
#
# When the transfer is finished, either successfully or not,
# request.num_computed_tokens will correctly reflect the number
# of computed tokens.
# _update_waiting_for_remote_kv will then cache
# only the successfully loaded tokens.
request.num_computed_tokens = num_computed_tokens
continue continue
self.running.append(request) self.running.append(request)
...@@ -1994,17 +2009,17 @@ class Scheduler(SchedulerInterface): ...@@ -1994,17 +2009,17 @@ class Scheduler(SchedulerInterface):
self.failed_recving_kv_req_ids.remove(request.request_id) self.failed_recving_kv_req_ids.remove(request.request_id)
else: else:
# Now that the blocks are ready, actually cache them. # Now that the blocks are ready, actually cache them.
(block_ids,) = self.kv_cache_manager.get_block_ids(request.request_id)
num_computed_tokens = len(block_ids) * self.block_size
# Handle the case where num request tokens less than one block.
num_computed_tokens = min(num_computed_tokens, request.num_tokens)
if num_computed_tokens == request.num_tokens:
num_computed_tokens -= 1
# This will cache the blocks iff caching is enabled. # This will cache the blocks iff caching is enabled.
self.kv_cache_manager.cache_blocks(request, num_computed_tokens) self.kv_cache_manager.cache_blocks(request, request.num_computed_tokens)
# Update the request state for scheduling. # on a full prompt hit, we need to re-compute the last token
request.num_computed_tokens = num_computed_tokens # in order to be able to sample the next token
if request.num_computed_tokens == request.num_tokens:
request.num_computed_tokens = request.num_tokens - 1
# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = request.num_computed_tokens
# Return that we are ready. # Return that we are ready.
self.finished_recving_kv_req_ids.remove(request.request_id) self.finished_recving_kv_req_ids.remove(request.request_id)
...@@ -2084,13 +2099,8 @@ class Scheduler(SchedulerInterface): ...@@ -2084,13 +2099,8 @@ class Scheduler(SchedulerInterface):
# We iterate only over blocks that may contain externally computed # We iterate only over blocks that may contain externally computed
# tokens # tokens
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
# Async loading. If num_computed_tokens is set it implies we # Async loading. num_computed_tokens does not include new tokens
# already processed some block failures for it in a prior step req_num_computed_tokens = request.num_computed_tokens
req_num_computed_tokens = (
request.num_computed_tokens
if req_id in self.failed_recving_kv_req_ids
else len(req_block_ids) * self.block_size
)
else: else:
# Sync loading. num_computed_tokens includes new tokens # Sync loading. num_computed_tokens includes new tokens
req_num_computed_tokens = request.num_cached_tokens req_num_computed_tokens = request.num_cached_tokens
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment