Unverified commit cfaf4668, authored by Or Ozeri, committed by GitHub
Browse files

[kv_offload+HMA][1/N]: Support multiple KV groups in OffloadingSpec (#36610)


Signed-off-by: Or Ozeri <oro@il.ibm.com>
parent 99a57bdf
...@@ -26,8 +26,13 @@ from vllm.v1.core.kv_cache_utils import ( ...@@ -26,8 +26,13 @@ from vllm.v1.core.kv_cache_utils import (
get_request_block_hasher, get_request_block_hasher,
init_none_hash, init_none_hash,
) )
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheGroupSpec,
)
from vllm.v1.kv_offload.abstract import ( from vllm.v1.kv_offload.abstract import (
LoadStoreSpec, LoadStoreSpec,
OffloadingEvent, OffloadingEvent,
...@@ -43,11 +48,11 @@ from vllm.v1.kv_offload.worker.worker import ( ...@@ -43,11 +48,11 @@ from vllm.v1.kv_offload.worker.worker import (
) )
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
from .utils import ( from .utils import (
EOS_TOKEN_ID, EOS_TOKEN_ID,
create_model_runner_output, create_model_runner_output,
create_scheduler,
create_vllm_config, create_vllm_config,
) )
...@@ -175,10 +180,37 @@ class RequestRunner: ...@@ -175,10 +180,37 @@ class RequestRunner:
}, },
) )
self.scheduler: Scheduler = create_scheduler( block_size = vllm_config.cache_config.block_size
vllm_config, num_blocks=num_gpu_blocks kv_cache_config = KVCacheConfig(
num_blocks=num_gpu_blocks,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer"],
FullAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
),
)
],
)
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
self.num_kv_groups = len(kv_cache_config.kv_cache_groups)
scheduler_cls = AsyncScheduler if async_scheduling else Scheduler
self.scheduler = scheduler_cls(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config),
block_size=block_size,
)
self.worker_connector = OffloadingConnector(
vllm_config, KVConnectorRole.WORKER, kv_cache_config
) )
self.worker_connector = OffloadingConnector(vllm_config, KVConnectorRole.WORKER)
# register worker kv_caches to enable OffloadingWorker creations # register worker kv_caches to enable OffloadingWorker creations
self.worker_connector.register_cross_layers_kv_cache( self.worker_connector.register_cross_layers_kv_cache(
......
...@@ -126,6 +126,7 @@ class OffloadingConnector(KVConnectorBase_V1): ...@@ -126,6 +126,7 @@ class OffloadingConnector(KVConnectorBase_V1):
): ):
super().__init__(vllm_config, role, kv_cache_config) super().__init__(vllm_config, role, kv_cache_config)
assert kv_cache_config is not None
spec = OffloadingSpecFactory.create_spec(vllm_config, kv_cache_config) spec = OffloadingSpecFactory.create_spec(vllm_config, kv_cache_config)
self.connector_scheduler: OffloadingConnectorScheduler | None = None self.connector_scheduler: OffloadingConnectorScheduler | None = None
...@@ -245,9 +246,10 @@ class OffloadingConnectorScheduler: ...@@ -245,9 +246,10 @@ class OffloadingConnectorScheduler:
"""Implementation of Scheduler side methods""" """Implementation of Scheduler side methods"""
def __init__(self, spec: OffloadingSpec): def __init__(self, spec: OffloadingSpec):
self.gpu_block_size = spec.gpu_block_size assert len(spec.gpu_block_size) == 1
self.offloaded_block_size = spec.offloaded_block_size self.gpu_block_size = spec.gpu_block_size[0]
self.block_size_factor = self.offloaded_block_size // self.gpu_block_size self.offloaded_block_size = self.gpu_block_size * spec.block_size_factor
self.block_size_factor = spec.block_size_factor
self.manager: OffloadingManager = spec.get_manager() self.manager: OffloadingManager = spec.get_manager()
self._requests: dict[ReqId, Request] = {} self._requests: dict[ReqId, Request] = {}
......
...@@ -42,10 +42,8 @@ class CPUOffloadingSpec(OffloadingSpec): ...@@ -42,10 +42,8 @@ class CPUOffloadingSpec(OffloadingSpec):
* len(kv_cache_config.kv_cache_tensors) * len(kv_cache_config.kv_cache_tensors)
* vllm_config.parallel_config.world_size * vllm_config.parallel_config.world_size
) )
kv_bytes_per_offloaded_block = kv_bytes_per_block * (
self.offloaded_block_size // self.gpu_block_size
)
kv_bytes_per_offloaded_block = kv_bytes_per_block * self.block_size_factor
self.num_blocks = ( self.num_blocks = (
int(cpu_bytes_to_use) // kv_bytes_per_offloaded_block int(cpu_bytes_to_use) // kv_bytes_per_offloaded_block
if kv_bytes_per_offloaded_block > 0 if kv_bytes_per_offloaded_block > 0
...@@ -67,8 +65,11 @@ class CPUOffloadingSpec(OffloadingSpec): ...@@ -67,8 +65,11 @@ class CPUOffloadingSpec(OffloadingSpec):
kv_events_config is not None and kv_events_config.enable_kv_cache_events kv_events_config is not None and kv_events_config.enable_kv_cache_events
) )
assert len(self.gpu_block_size) == 1
gpu_block_size = self.gpu_block_size[0]
offloaded_block_size = gpu_block_size * self.block_size_factor
backend = CPUBackend( backend = CPUBackend(
block_size=self.offloaded_block_size, num_blocks=self.num_blocks block_size=offloaded_block_size, num_blocks=self.num_blocks
) )
if self.eviction_policy == "lru": if self.eviction_policy == "lru":
...@@ -111,10 +112,13 @@ class CPUOffloadingSpec(OffloadingSpec): ...@@ -111,10 +112,13 @@ class CPUOffloadingSpec(OffloadingSpec):
"CPU Offloading is currently only supported on CUDA-alike GPUs" "CPU Offloading is currently only supported on CUDA-alike GPUs"
) )
assert len(self.gpu_block_size) == 1
gpu_block_size = self.gpu_block_size[0]
self._handlers = CpuGpuOffloadingHandlers( self._handlers = CpuGpuOffloadingHandlers(
attn_backends=attn_backends, attn_backends=attn_backends,
gpu_block_size=self.gpu_block_size, gpu_block_size=gpu_block_size,
cpu_block_size=self.offloaded_block_size, cpu_block_size=gpu_block_size * self.block_size_factor,
num_cpu_blocks=self.num_blocks, num_cpu_blocks=self.num_blocks,
gpu_caches=kv_caches, gpu_caches=kv_caches,
) )
......
...@@ -33,7 +33,7 @@ class OffloadingSpecFactory: ...@@ -33,7 +33,7 @@ class OffloadingSpecFactory:
def create_spec( def create_spec(
cls, cls,
config: "VllmConfig", config: "VllmConfig",
kv_cache_config: "KVCacheConfig | None", kv_cache_config: "KVCacheConfig",
) -> OffloadingSpec: ) -> OffloadingSpec:
kv_transfer_config = config.kv_transfer_config kv_transfer_config = config.kv_transfer_config
assert kv_transfer_config is not None assert kv_transfer_config is not None
......
...@@ -21,9 +21,7 @@ logger = init_logger(__name__) ...@@ -21,9 +21,7 @@ logger = init_logger(__name__)
class OffloadingSpec(ABC): class OffloadingSpec(ABC):
"""Spec for an offloading connector""" """Spec for an offloading connector"""
def __init__( def __init__(self, vllm_config: "VllmConfig", kv_cache_config: "KVCacheConfig"):
self, vllm_config: "VllmConfig", kv_cache_config: "KVCacheConfig | None"
):
logger.warning( logger.warning(
"Initializing OffloadingSpec. This API is experimental and " "Initializing OffloadingSpec. This API is experimental and "
"subject to change in the future as we iterate the design." "subject to change in the future as we iterate the design."
...@@ -35,12 +33,34 @@ class OffloadingSpec(ABC): ...@@ -35,12 +33,34 @@ class OffloadingSpec(ABC):
assert kv_transfer_config is not None assert kv_transfer_config is not None
self.extra_config = kv_transfer_config.kv_connector_extra_config self.extra_config = kv_transfer_config.kv_connector_extra_config
self.gpu_block_size = vllm_config.cache_config.block_size # block size used by vLLM for hashing request tokens for the sake
self.offloaded_block_size = int( # of enabling prefix caching
self.extra_config.get("block_size", self.gpu_block_size) self.hash_block_size = vllm_config.cache_config.block_size
# gpu block size per group
self.gpu_block_size: tuple[int, ...] = tuple(
kv_cache_group.kv_cache_spec.block_size
for kv_cache_group in kv_cache_config.kv_cache_groups
) )
assert self.offloaded_block_size % self.gpu_block_size == 0 for block_size in self.gpu_block_size:
assert block_size % self.hash_block_size == 0
# offloaded_block_size / gpu_block_size
self.block_size_factor: int = 1
offloaded_block_size = self.extra_config.get("block_size")
if offloaded_block_size is not None:
offloaded_block_size_int = int(offloaded_block_size)
gpu_block_sizes = set(self.gpu_block_size)
assert len(gpu_block_sizes) == 1, (
"If 'block_size' is specified in kv_connector_extra_config, "
"there must be at least one KV cache group, "
"and all groups must have the same block size."
)
gpu_block_size = gpu_block_sizes.pop()
assert offloaded_block_size_int % gpu_block_size == 0
self.block_size_factor = offloaded_block_size_int // gpu_block_size
@abstractmethod @abstractmethod
def get_manager(self) -> OffloadingManager: def get_manager(self) -> OffloadingManager:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment