Unverified Commit b4a01aaf authored by Yihua Cheng's avatar Yihua Cheng Committed by GitHub
Browse files

[KV Connector] More async support for `get_num_new_matched_tokens` (#23620)


Signed-off-by: default avatarApostaC <yihua98@uchicago.edu>
parent 83dd28aa
...@@ -243,7 +243,7 @@ class KVConnectorBase_V1(ABC): ...@@ -243,7 +243,7 @@ class KVConnectorBase_V1(ABC):
self, self,
request: "Request", request: "Request",
num_computed_tokens: int, num_computed_tokens: int,
) -> tuple[int, bool]: ) -> tuple[Optional[int], bool]:
""" """
Get number of new tokens that can be loaded from the Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens. external KV cache beyond the num_computed_tokens.
...@@ -255,8 +255,11 @@ class KVConnectorBase_V1(ABC): ...@@ -255,8 +255,11 @@ class KVConnectorBase_V1(ABC):
Returns: Returns:
A tuple with the following elements: A tuple with the following elements:
- The number of tokens that can be loaded from the - An optional number of tokens that can be loaded from the
external KV cache beyond what is already computed. external KV cache beyond what is already computed.
If None, it means that the connector needs more time to
determine the number of matched tokens, and the scheduler
should query for this request again later.
- `True` if external KV cache tokens will be loaded - `True` if external KV cache tokens will be loaded
asynchronously (between scheduler steps). Must be asynchronously (between scheduler steps). Must be
'False' if the first element is 0. 'False' if the first element is 0.
......
...@@ -110,7 +110,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1): ...@@ -110,7 +110,7 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
self, self,
request: "Request", request: "Request",
num_computed_tokens: int, num_computed_tokens: int,
) -> tuple[int, bool]: ) -> tuple[Optional[int], bool]:
""" """
Get number of new tokens that can be loaded from the Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens. external KV cache beyond the num_computed_tokens.
......
...@@ -143,11 +143,15 @@ class MultiConnector(KVConnectorBase_V1): ...@@ -143,11 +143,15 @@ class MultiConnector(KVConnectorBase_V1):
self, self,
request: "Request", request: "Request",
num_computed_tokens: int, num_computed_tokens: int,
) -> tuple[int, bool]: ) -> tuple[Optional[int], bool]:
to_return = (0, False) to_return = (0, False)
for i, c in enumerate(self._connectors): for i, c in enumerate(self._connectors):
toks, load_async = c.get_num_new_matched_tokens( toks, load_async = c.get_num_new_matched_tokens(
request, num_computed_tokens) request, num_computed_tokens)
# If there is a connector still looking up the matches,
# we return None to indicate that we are not done yet.
if toks is None:
return (None, False)
# The first connector that has new matched tokens will be assigned # The first connector that has new matched tokens will be assigned
# to this request. # to this request.
if to_return[0] == 0 and toks > 0: if to_return[0] == 0 and toks > 0:
......
...@@ -162,7 +162,7 @@ class NixlConnector(KVConnectorBase_V1): ...@@ -162,7 +162,7 @@ class NixlConnector(KVConnectorBase_V1):
def get_num_new_matched_tokens( def get_num_new_matched_tokens(
self, request: "Request", self, request: "Request",
num_computed_tokens: int) -> tuple[int, bool]: num_computed_tokens: int) -> tuple[Optional[int], bool]:
assert self.connector_scheduler is not None assert self.connector_scheduler is not None
return self.connector_scheduler.get_num_new_matched_tokens( return self.connector_scheduler.get_num_new_matched_tokens(
request, num_computed_tokens) request, num_computed_tokens)
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import hashlib import hashlib
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Optional
import safetensors import safetensors
import torch import torch
...@@ -238,7 +238,7 @@ class SharedStorageConnector(KVConnectorBase_V1): ...@@ -238,7 +238,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
self, self,
request: "Request", request: "Request",
num_computed_tokens: int, num_computed_tokens: int,
) -> tuple[int, bool]: ) -> tuple[Optional[int], bool]:
""" """
Get number of new tokens that can be loaded from the Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens. external KV cache beyond the num_computed_tokens.
......
...@@ -387,6 +387,14 @@ class Scheduler(SchedulerInterface): ...@@ -387,6 +387,14 @@ class Scheduler(SchedulerInterface):
self.connector.get_num_new_matched_tokens( self.connector.get_num_new_matched_tokens(
request, num_new_local_computed_tokens)) request, num_new_local_computed_tokens))
if num_external_computed_tokens is None:
# The request cannot be scheduled because
# the KVConnector couldn't determine
# the number of matched tokens.
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# Total computed tokens (local + external). # Total computed tokens (local + external).
num_computed_tokens = (num_new_local_computed_tokens + num_computed_tokens = (num_new_local_computed_tokens +
num_external_computed_tokens) num_external_computed_tokens)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment