Unverified Commit 9a9f48df authored by David Ben-David's avatar David Ben-David Committed by GitHub
Browse files

[V1] [P/D] Add Support for KV Load Failure Recovery (#19330)


Signed-off-by: default avatarDavid Ben-David <davidb@pliops.com>
Co-authored-by: default avatarDavid Ben-David <davidb@pliops.com>
parent 67f3fb08
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, NamedTuple, Optional, Union
import torch
......@@ -87,10 +87,13 @@ class KVConnectorOutput:
finished_sending: Optional[set[str]] = None
finished_recving: Optional[set[str]] = None
kv_connector_stats: Optional["KVConnectorStats"] = None
# IDs of externally computed KV blocks that failed to load.
# Requests referencing these blocks should be rescheduled to recompute them.
invalid_block_ids: set[int] = field(default_factory=set)
def is_empty(self):
return (not self.finished_sending and not self.finished_recving
and not self.kv_connector_stats)
and not self.kv_connector_stats and not self.invalid_block_ids)
# ModelRunnerOutput is serialized and sent to the scheduler process.
......
......@@ -634,8 +634,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_computed_tokens = req_data.num_computed_tokens[i]
new_block_ids = req_data.new_block_ids[i]
resumed_from_preemption = req_data.resumed_from_preemption[i]
num_output_tokens = req_data.num_output_tokens[i]
# Update the cached states.
req_state.num_computed_tokens = num_computed_tokens
if not is_last_rank:
......@@ -653,6 +655,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
elif num_new_tokens > 0:
req_state.output_token_ids.extend(
new_token_ids[-num_new_tokens:])
elif num_output_tokens < len(req_state.output_token_ids):
# Some output tokens were discarded due to a sync-KV-load
# failure. Align the cached state.
del req_state.output_token_ids[num_output_tokens:]
req_index = self.input_batch.req_id_to_index.get(req_id)
if req_index is not None:
old_end_idx = self.input_batch.num_tokens_no_spec[
req_index]
end_idx = self.input_batch.num_prompt_tokens[
req_index] + num_output_tokens
self.input_batch.num_tokens[req_index] = end_idx
self.input_batch.num_tokens_no_spec[req_index] = end_idx
self.input_batch.is_token_ids[req_index,
end_idx:old_end_idx] = False
# Update the block IDs.
if not resumed_from_preemption:
......
......@@ -464,8 +464,7 @@ class Worker(WorkerBase):
# In case of PP with kv transfer, we need to pass through the
# kv_connector_output
if (not kv_connector_output.finished_sending
and not kv_connector_output.finished_recving):
if kv_connector_output.is_empty():
return EMPTY_MODEL_RUNNER_OUTPUT
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
......
......@@ -75,8 +75,7 @@ class KVConnectorModelRunnerMixin:
scheduler_output, wait_for_save=False) as kv_connector_output:
pass
if (not kv_connector_output.finished_sending
and not kv_connector_output.finished_recving):
if kv_connector_output.is_empty():
return EMPTY_MODEL_RUNNER_OUTPUT
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
......@@ -120,6 +119,8 @@ class KVConnectorModelRunnerMixin:
output.finished_sending, output.finished_recving = (
kv_connector.get_finished(scheduler_output.finished_req_ids))
output.invalid_block_ids = (
kv_connector.get_block_ids_with_load_errors())
output.kv_connector_stats = KVConnectorModelRunnerMixin.\
get_kv_connector_stats()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment