Unverified commit b4a01aaf, authored by Yihua Cheng and committed by GitHub
Browse files

[KV Connector] More async support for `get_num_new_matched_tokens` (#23620)


Signed-off-by: ApostaC <yihua98@uchicago.edu>
parent 83dd28aa
......@@ -243,7 +243,7 @@ class KVConnectorBase_V1(ABC):
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int, bool]:
) -> tuple[Optional[int], bool]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
......@@ -255,8 +255,11 @@ class KVConnectorBase_V1(ABC):
Returns:
A tuple with the following elements:
- The number of tokens that can be loaded from the
external KV cache beyond what is already computed.
- An optional number of tokens that can be loaded from the
external KV cache beyond what is already computed.
If None, it means that the connector needs more time to
determine the number of matched tokens, and the scheduler
should query for this request again later.
- `True` if external KV cache tokens will be loaded
asynchronously (between scheduler steps). Must be
'False' if the first element is 0.
......
......@@ -110,7 +110,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int, bool]:
) -> tuple[Optional[int], bool]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
......
......@@ -143,11 +143,15 @@ class MultiConnector(KVConnectorBase_V1):
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int, bool]:
) -> tuple[Optional[int], bool]:
to_return = (0, False)
for i, c in enumerate(self._connectors):
toks, load_async = c.get_num_new_matched_tokens(
request, num_computed_tokens)
# If there is a connector still looking up the matches,
# we return None to indicate that we are not done yet.
if toks is None:
return (None, False)
# The first connector that has new matched tokens will be assigned
# to this request.
if to_return[0] == 0 and toks > 0:
......
......@@ -162,7 +162,7 @@ class NixlConnector(KVConnectorBase_V1):
def get_num_new_matched_tokens(
        self, request: "Request",
        num_computed_tokens: int) -> tuple[Optional[int], bool]:
    """Ask the scheduler-side connector how many extra tokens can be loaded.

    This method is a thin pass-through: it forwards both arguments
    unchanged to ``self.connector_scheduler.get_num_new_matched_tokens``
    and returns that result as-is.

    Args:
        request: The request being considered for scheduling.
        num_computed_tokens: The number of tokens already computed for
            this request.

    Returns:
        The ``(num_tokens, load_async)`` tuple produced by the delegate.
        Per the annotated return type, the first element may be ``None``.

    Raises:
        AssertionError: If ``self.connector_scheduler`` is ``None``.
    """
    # Internal invariant rather than input validation: the delegate must
    # exist before this method is called.
    # NOTE(review): presumably it is only populated when the connector
    # runs in the scheduler role -- confirm against the constructor.
    assert self.connector_scheduler is not None
    return self.connector_scheduler.get_num_new_matched_tokens(
        request, num_computed_tokens)
......
......@@ -3,7 +3,7 @@
import hashlib
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
import safetensors
import torch
......@@ -238,7 +238,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
self,
request: "Request",
num_computed_tokens: int,
) -> tuple[int, bool]:
) -> tuple[Optional[int], bool]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
......
......@@ -387,6 +387,14 @@ class Scheduler(SchedulerInterface):
self.connector.get_num_new_matched_tokens(
request, num_new_local_computed_tokens))
if num_external_computed_tokens is None:
# The request cannot be scheduled because
# the KVConnector couldn't determine
# the number of matched tokens.
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# Total computed tokens (local + external).
num_computed_tokens = (num_new_local_computed_tokens +
num_external_computed_tokens)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment