Unverified commit 14b4326b, authored by Or Ozeri, committed by GitHub
Browse files

v1: Support KV events from connectors (#19737)


Signed-off-by: Or Ozeri <oro@il.ibm.com>
parent 752d2e1c
......@@ -27,10 +27,12 @@ class BlockStored(KVCacheEvent):
token_ids: list[int]
block_size: int
lora_id: Optional[int]
medium: Optional[str]
class BlockRemoved(KVCacheEvent):
    """KV cache event that carries the hashes of removed blocks."""

    # Hashes of the blocks this event refers to, one int per block.
    block_hashes: list[int]
    # Label of the storage medium the blocks were removed from; may be None.
    # NOTE(review): the set of valid labels is not visible here - confirm
    # against the publisher side.
    medium: Optional[str]
class AllBlocksCleared(KVCacheEvent):
......
......@@ -40,16 +40,21 @@ class KVCacheEvent(
"""Base class for all KV cache-related events"""
MEDIUM_GPU = "GPU"
class BlockStored(KVCacheEvent):
    """KV cache event that describes newly stored blocks."""

    # Hashes of the stored blocks, one int per block.
    block_hashes: list[int]
    # Hash of the parent block, or None.
    # NOTE(review): presumably None for the first block of a sequence -
    # confirm against the producer.
    parent_block_hash: Optional[int]
    # Token ids associated with the stored blocks.
    token_ids: list[int]
    # Block size the producer used for these blocks.
    block_size: int
    # Id of the request's LoRA request; BlockPool passes None when the
    # request has no LoRA request attached.
    lora_id: Optional[int]
    # Label of the storage medium holding the blocks; BlockPool passes
    # MEDIUM_GPU. May be None.
    medium: Optional[str]
class BlockRemoved(KVCacheEvent):
    """KV cache event that carries the hashes of removed blocks."""

    # Hash values of the removed blocks; BlockPool fills this with a
    # single-element list taken from block_hash.get_hash_value().
    block_hashes: list[int]
    # Label of the storage medium the blocks were removed from; BlockPool
    # passes MEDIUM_GPU. May be None.
    medium: Optional[str]
class AllBlocksCleared(KVCacheEvent):
......
......@@ -19,6 +19,8 @@ The class provides the following primitives:
Returns whether KV cache should be freed now or will be
freed asynchronously and optionally returns KV transfer
params.
take_events() - returns new KV events that were collected
by the connector since the last call.
Worker-side: runs in each worker, loads/saves KV cache to/from
the Connector based on the metadata.
......@@ -34,6 +36,7 @@ The class provides the following primitives:
import enum
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional
import torch
......@@ -45,6 +48,7 @@ from vllm.v1.outputs import KVConnectorOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
......@@ -313,6 +317,15 @@ class KVConnectorBase_V1(ABC):
"""
return False, None
def take_events(self) -> Iterable["KVCacheEvent"]:
    """
    Take the KV cache events from the connector.

    This default implementation reports no events: it returns an empty
    tuple rather than acting as a generator, so the result is falsy and
    callers may cheaply skip it.

    Returns:
        An iterable of the new KV cache events since the last call
        (always empty in this base implementation).
    """
    return ()
@classmethod
def get_required_kvcache_layout(
cls, vllm_config: "VllmConfig") -> Optional[str]:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from collections.abc import Iterable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
import torch
from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
......@@ -208,6 +210,10 @@ class MultiConnector(KVConnectorBase_V1):
return async_saves > 0, kv_txfer_params
def take_events(self) -> Iterable[KVCacheEvent]:
    """Stream the KV cache events of every wrapped connector.

    The connectors in ``self._connectors`` are drained one after another,
    in list order. Because this is a generator, each connector's
    ``take_events()`` is only invoked once iteration reaches it.

    Yields:
        Each event produced by each connector's ``take_events()``.
    """
    for connector in self._connectors:
        for event in connector.take_events():
            yield event
@classmethod
def get_required_kvcache_layout(
cls, vllm_config: "VllmConfig") -> Optional[str]:
......
......@@ -4,8 +4,9 @@ from collections import defaultdict
from collections.abc import Iterable
from typing import Optional
from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved,
BlockStored, KVCacheEvent)
from vllm.distributed.kv_events import (MEDIUM_GPU, AllBlocksCleared,
BlockRemoved, BlockStored,
KVCacheEvent)
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
FreeKVCacheBlockQueue, KVCacheBlock)
......@@ -156,6 +157,7 @@ class BlockPool:
block_size=block_size,
lora_id=request.lora_request.id
if request.lora_request else None,
medium=MEDIUM_GPU,
))
def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
......@@ -218,7 +220,8 @@ class BlockPool:
# we disable hybrid kv cache manager when kv cache event is
# enabled, so there is only one group.
self.kv_event_queue.append(
BlockRemoved(block_hashes=[block_hash.get_hash_value()]))
BlockRemoved(block_hashes=[block_hash.get_hash_value()],
medium=MEDIUM_GPU))
return True
def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None:
......
......@@ -589,7 +589,19 @@ class Scheduler(SchedulerInterface):
meta = self.connector.build_connector_meta(scheduler_output)
scheduler_output.kv_connector_metadata = meta
# collect KV cache events from KV cache manager
events = self.kv_cache_manager.take_events()
# collect KV cache events from connector
if self.connector is not None:
connector_events = self.connector.take_events()
if connector_events:
if events is None:
events = list(connector_events)
else:
events.extend(connector_events)
# publish collected KV cache events
if events:
batch = KVEventBatch(ts=time.time(), events=events)
self.kv_event_publisher.publish(batch)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment