Unverified commit 61397891, authored by Nick Hill and committed by GitHub
Browse files

[Minor] Some code simplification in `scheduler.py` (#33597)


Signed-off-by: Nick Hill <nickhill123@gmail.com>
parent ef248ff7
...@@ -440,17 +440,13 @@ class Scheduler(SchedulerInterface): ...@@ -440,17 +440,13 @@ class Scheduler(SchedulerInterface):
) )
self.running.remove(preempted_req) self.running.remove(preempted_req)
if preempted_req in scheduled_running_reqs: if preempted_req in scheduled_running_reqs:
preempted_req_id = preempted_req.request_id
scheduled_running_reqs.remove(preempted_req) scheduled_running_reqs.remove(preempted_req)
token_budget += num_scheduled_tokens[ token_budget += num_scheduled_tokens.pop(preempted_req_id)
preempted_req.request_id req_to_new_blocks.pop(preempted_req_id)
] scheduled_spec_decode_tokens.pop(preempted_req_id, None)
req_to_new_blocks.pop(preempted_req.request_id)
num_scheduled_tokens.pop(preempted_req.request_id)
scheduled_spec_decode_tokens.pop(
preempted_req.request_id, None
)
preempted_encoder_inputs = scheduled_encoder_inputs.pop( preempted_encoder_inputs = scheduled_encoder_inputs.pop(
preempted_req.request_id, None preempted_req_id, None
) )
if preempted_encoder_inputs: if preempted_encoder_inputs:
# Restore encoder compute budget if the preempted # Restore encoder compute budget if the preempted
...@@ -476,8 +472,9 @@ class Scheduler(SchedulerInterface): ...@@ -476,8 +472,9 @@ class Scheduler(SchedulerInterface):
# Schedule the request. # Schedule the request.
scheduled_running_reqs.append(request) scheduled_running_reqs.append(request)
req_to_new_blocks[request.request_id] = new_blocks request_id = request.request_id
num_scheduled_tokens[request.request_id] = num_new_tokens req_to_new_blocks[request_id] = new_blocks
num_scheduled_tokens[request_id] = num_new_tokens
token_budget -= num_new_tokens token_budget -= num_new_tokens
req_index += 1 req_index += 1
...@@ -492,18 +489,14 @@ class Scheduler(SchedulerInterface): ...@@ -492,18 +489,14 @@ class Scheduler(SchedulerInterface):
if num_scheduled_spec_tokens > 0: if num_scheduled_spec_tokens > 0:
# Trim spec_token_ids list to num_scheduled_spec_tokens. # Trim spec_token_ids list to num_scheduled_spec_tokens.
del request.spec_token_ids[num_scheduled_spec_tokens:] del request.spec_token_ids[num_scheduled_spec_tokens:]
scheduled_spec_decode_tokens[request.request_id] = ( scheduled_spec_decode_tokens[request_id] = request.spec_token_ids
request.spec_token_ids
)
# New spec tokens will be set in `update_draft_token_ids` before the # New spec tokens will be set in `update_draft_token_ids` before the
# next step when applicable. # next step when applicable.
request.spec_token_ids = [] request.spec_token_ids = []
# Encoder-related. # Encoder-related.
if encoder_inputs_to_schedule: if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = ( scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule
encoder_inputs_to_schedule
)
# Allocate the encoder cache. # Allocate the encoder cache.
for i in encoder_inputs_to_schedule: for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i) self.encoder_cache_manager.allocate(request, i)
...@@ -535,6 +528,7 @@ class Scheduler(SchedulerInterface): ...@@ -535,6 +528,7 @@ class Scheduler(SchedulerInterface):
break break
request = self.waiting.peek_request() request = self.waiting.peek_request()
request_id = request.request_id
# KVTransfer: skip request if still waiting for remote kvs. # KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
...@@ -549,7 +543,7 @@ class Scheduler(SchedulerInterface): ...@@ -549,7 +543,7 @@ class Scheduler(SchedulerInterface):
else: else:
logger.debug( logger.debug(
"%s is still in WAITING_FOR_REMOTE_KVS state.", "%s is still in WAITING_FOR_REMOTE_KVS state.",
request.request_id, request_id,
) )
self.waiting.pop_request() self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request) skipped_waiting_requests.prepend_request(request)
...@@ -729,7 +723,7 @@ class Scheduler(SchedulerInterface): ...@@ -729,7 +723,7 @@ class Scheduler(SchedulerInterface):
if self.connector is not None: if self.connector is not None:
self.connector.update_state_after_alloc( self.connector.update_state_after_alloc(
request, request,
self.kv_cache_manager.get_blocks(request.request_id), self.kv_cache_manager.get_blocks(request_id),
num_external_computed_tokens, num_external_computed_tokens,
) )
...@@ -759,10 +753,10 @@ class Scheduler(SchedulerInterface): ...@@ -759,10 +753,10 @@ class Scheduler(SchedulerInterface):
if self.lora_config and request.lora_request: if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id) scheduled_loras.add(request.lora_request.lora_int_id)
req_to_new_blocks[request.request_id] = ( req_to_new_blocks[request_id] = self.kv_cache_manager.get_blocks(
self.kv_cache_manager.get_blocks(request.request_id) request_id
) )
num_scheduled_tokens[request.request_id] = num_new_tokens num_scheduled_tokens[request_id] = num_new_tokens
token_budget -= num_new_tokens token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens request.num_computed_tokens = num_computed_tokens
...@@ -771,9 +765,7 @@ class Scheduler(SchedulerInterface): ...@@ -771,9 +765,7 @@ class Scheduler(SchedulerInterface):
request.num_cached_tokens = num_computed_tokens request.num_cached_tokens = num_computed_tokens
# Encoder-related. # Encoder-related.
if encoder_inputs_to_schedule: if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = ( scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule
encoder_inputs_to_schedule
)
# Allocate the encoder cache. # Allocate the encoder cache.
for i in encoder_inputs_to_schedule: for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i) self.encoder_cache_manager.allocate(request, i)
...@@ -806,11 +798,9 @@ class Scheduler(SchedulerInterface): ...@@ -806,11 +798,9 @@ class Scheduler(SchedulerInterface):
num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups) num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"): with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"):
if self.running: if self.running:
any_request = self.running[0] any_request_id = self.running[0].request_id
num_common_prefix_blocks = ( num_common_prefix_blocks = (
self.kv_cache_manager.get_num_common_prefix_blocks( self.kv_cache_manager.get_num_common_prefix_blocks(any_request_id)
any_request.request_id
)
) )
# Construct the scheduler output. # Construct the scheduler output.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment