Unverified commit 9a1f16da, authored by Woosuk Kwon, committed by GitHub
Browse files

[Model Runner V2] Refactor `update_states` (#32562)


Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent bb1848cd
...@@ -425,7 +425,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -425,7 +425,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self._dummy_run(self.max_num_tokens, skip_attn=False) self._dummy_run(self.max_num_tokens, skip_attn=False)
torch.cuda.synchronize() torch.cuda.synchronize()
def update_states(self, scheduler_output: SchedulerOutput) -> None: def finish_requests(self, scheduler_output: SchedulerOutput) -> None:
if scheduler_output.preempted_req_ids is not None: if scheduler_output.preempted_req_ids is not None:
for req_id in scheduler_output.preempted_req_ids: for req_id in scheduler_output.preempted_req_ids:
self.req_states.remove_request(req_id) self.req_states.remove_request(req_id)
...@@ -436,11 +436,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -436,11 +436,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if self.supports_mm_inputs: if self.supports_mm_inputs:
self.encoder_runner.remove_request(req_id) self.encoder_runner.remove_request(req_id)
def free_states(self, scheduler_output: SchedulerOutput) -> None:
if self.supports_mm_inputs: if self.supports_mm_inputs:
for mm_hash in scheduler_output.free_encoder_mm_hashes: for mm_hash in scheduler_output.free_encoder_mm_hashes:
self.encoder_runner.free_encoder_cache(mm_hash) self.encoder_runner.free_encoder_cache(mm_hash)
# Add new requests. def add_requests(self, scheduler_output: SchedulerOutput) -> None:
for new_req_data in scheduler_output.scheduled_new_reqs: for new_req_data in scheduler_output.scheduled_new_reqs:
assert new_req_data.prompt_token_ids is not None assert new_req_data.prompt_token_ids is not None
assert new_req_data.prefill_token_ids is not None assert new_req_data.prefill_token_ids is not None
...@@ -476,6 +477,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -476,6 +477,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_index, prompt_len, new_req_data.sampling_params req_index, prompt_len, new_req_data.sampling_params
) )
if scheduler_output.scheduled_new_reqs:
self.req_states.apply_staged_writes()
self.sampler.apply_staged_writes(
self.req_states.prefill_token_ids.gpu,
self.req_states.prefill_len.np,
self.req_states.prompt_len,
)
if self.uses_mrope:
self.mrope_states.apply_staged_writes()
def update_requests(self, scheduler_output: SchedulerOutput) -> None:
# Add new blocks for the existing requests. # Add new blocks for the existing requests.
cached_reqs = scheduler_output.scheduled_cached_reqs cached_reqs = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(cached_reqs.req_ids): for i, req_id in enumerate(cached_reqs.req_ids):
...@@ -486,16 +498,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -486,16 +498,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_index, req_new_block_ids, overwrite=False req_index, req_new_block_ids, overwrite=False
) )
self.req_states.apply_staged_writes()
self.block_tables.apply_staged_writes()
self.sampler.apply_staged_writes(
self.req_states.prefill_token_ids.gpu,
self.req_states.prefill_len.np,
self.req_states.prompt_len,
)
if self.uses_mrope:
self.mrope_states.apply_staged_writes()
def prepare_inputs( def prepare_inputs(
self, self,
scheduler_output: SchedulerOutput, scheduler_output: SchedulerOutput,
...@@ -951,15 +953,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -951,15 +953,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dummy_run: bool = False, dummy_run: bool = False,
) -> ModelRunnerOutput | None: ) -> ModelRunnerOutput | None:
assert intermediate_tensors is None assert intermediate_tensors is None
if scheduler_output.total_num_scheduled_tokens == 0 and not dummy_run: if not dummy_run:
# No need to run the model. # Update the request states.
self.update_states(scheduler_output) self.finish_requests(scheduler_output)
return EMPTY_MODEL_RUNNER_OUTPUT self.free_states(scheduler_output)
self.add_requests(scheduler_output)
self.update_requests(scheduler_output)
self.block_tables.apply_staged_writes()
if scheduler_output.total_num_scheduled_tokens == 0:
# No need to run the model.
return EMPTY_MODEL_RUNNER_OUTPUT
cudagraph_mode, num_tokens_after_padding, num_tokens_across_dp = ( cudagraph_mode, num_tokens_after_padding, num_tokens_across_dp = (
self.get_cudagraph_and_dp_padding(scheduler_output) self.get_cudagraph_and_dp_padding(scheduler_output)
) )
self.update_states(scheduler_output)
if num_tokens_after_padding == 0: if num_tokens_after_padding == 0:
# All DP ranks have zero tokens to run. # All DP ranks have zero tokens to run.
return EMPTY_MODEL_RUNNER_OUTPUT return EMPTY_MODEL_RUNNER_OUTPUT
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment