Unverified Commit 02859973 authored by Or Ozeri's avatar Or Ozeri Committed by GitHub
Browse files

[BugFix] scheduler: Fix resuming of preempted requests after async load (#31583)


Signed-off-by: default avatarOr Ozeri <oro@il.ibm.com>
parent d1fd802f
...@@ -1261,10 +1261,11 @@ def test_kv_connector_unable_to_allocate(use_ec_connector, ec_role): ...@@ -1261,10 +1261,11 @@ def test_kv_connector_unable_to_allocate(use_ec_connector, ec_role):
assert len(scheduler.waiting) == 0 assert len(scheduler.waiting) == 0
@pytest.mark.parametrize("is_async", [False, True])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"use_ec_connector, ec_role", [(False, None), (True, "ec_consumer")] "use_ec_connector, ec_role", [(False, None), (True, "ec_consumer")]
) )
def test_kv_connector_handles_preemption(use_ec_connector, ec_role): def test_kv_connector_handles_preemption(is_async, use_ec_connector, ec_role):
""" """
Test whether scheduler with KVConnector is able to handle Test whether scheduler with KVConnector is able to handle
unable to allocate (run out of blocks in allocate_slots(). unable to allocate (run out of blocks in allocate_slots().
...@@ -1277,7 +1278,9 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role): ...@@ -1277,7 +1278,9 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role):
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE
scheduler = create_scheduler( scheduler = create_scheduler(
enable_prefix_caching=True, enable_prefix_caching=True,
use_kv_connector=mock_kv(matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=False), use_kv_connector=mock_kv(
matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=is_async
),
block_size=BLOCK_SIZE, block_size=BLOCK_SIZE,
num_blocks=NUM_BLOCKS, num_blocks=NUM_BLOCKS,
# encoder connector should not affect test results # encoder connector should not affect test results
...@@ -1315,6 +1318,12 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role): ...@@ -1315,6 +1318,12 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role):
# All can be scheduled - 1st token. # All can be scheduled - 1st token.
output = scheduler.schedule() output = scheduler.schedule()
if is_async:
assert len(scheduler.waiting) == 2
assert scheduler.running == []
_step_until_kv_transfer_finished(scheduler, req_ids)
output = scheduler.schedule()
_assert_right_scheduler_output( _assert_right_scheduler_output(
output, output,
# 2 remote kv cache hits. # 2 remote kv cache hits.
...@@ -1367,6 +1376,12 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role): ...@@ -1367,6 +1376,12 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role):
# Restarts the preempted request - generate 3rd token. # Restarts the preempted request - generate 3rd token.
# This will have a local and remote cache hit. # This will have a local and remote cache hit.
output = scheduler.schedule() output = scheduler.schedule()
if is_async:
waiting_req_ids = [req.request_id for req in scheduler.waiting]
assert len(waiting_req_ids) == 1
_step_until_kv_transfer_finished(scheduler, waiting_req_ids)
output = scheduler.schedule()
_assert_right_scheduler_output( _assert_right_scheduler_output(
output, output,
# 1 remote kv_cache hit! # 1 remote kv_cache hit!
...@@ -1377,6 +1392,8 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role): ...@@ -1377,6 +1392,8 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role):
) )
assert len(scheduler.running) == 1 assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 0 assert len(scheduler.waiting) == 0
assert output.scheduled_cached_reqs.num_reqs == 1
assert output.scheduled_new_reqs == []
_ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
assert len(scheduler.running) == 1 assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 0 assert len(scheduler.waiting) == 0
...@@ -1389,6 +1406,8 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role): ...@@ -1389,6 +1406,8 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role):
num_requests=0, num_requests=0,
expected_num_scheduled_tokens=1, expected_num_scheduled_tokens=1,
) )
assert output.scheduled_cached_reqs.num_reqs == 1
assert output.scheduled_new_reqs == []
assert len(scheduler.running) == 1 assert len(scheduler.running) == 1
_ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT) _ = scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
assert len(scheduler.running) == 0 assert len(scheduler.running) == 0
......
...@@ -445,7 +445,12 @@ class Scheduler(SchedulerInterface): ...@@ -445,7 +445,12 @@ class Scheduler(SchedulerInterface):
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request) is_ready = self._update_waiting_for_remote_kv(request)
if is_ready: if is_ready:
request.status = RequestStatus.WAITING if request.num_preemptions:
# We must be loading for a resumed preemption
# rather than a new request.
request.status = RequestStatus.PREEMPTED
else:
request.status = RequestStatus.WAITING
else: else:
logger.debug( logger.debug(
"%s is still in WAITING_FOR_REMOTE_KVS state.", "%s is still in WAITING_FOR_REMOTE_KVS state.",
......
...@@ -123,7 +123,7 @@ class Request: ...@@ -123,7 +123,7 @@ class Request:
# indicates that the output is corrupted # indicates that the output is corrupted
self.num_nans_in_logits = 0 self.num_nans_in_logits = 0
# The number of requests being preempted by the scheduler # The number of times this request has been preempted by the scheduler.
self.num_preemptions = 0 self.num_preemptions = 0
# The number of tokens that have been computed remotely. # The number of tokens that have been computed remotely.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment