Commit 1dbbafd3, authored by Yifan Qiao, committed by khluu
Browse files

[Feat][v1] Simple yet General CPU KV Cache Offloading (#37160)


Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
(cherry picked from commit 91e4521f)
parent 0ee3b7fc
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Integration tests for SimpleCPUOffloadConnector with real models."""
import time
import pytest
from vllm import LLM, SamplingParams, TokensPrompt
from vllm.config import KVTransferConfig
from vllm.platforms import current_platform
# Skip the entire module (not just individual tests) when the current
# platform is not CUDA.
if not current_platform.is_cuda():
    pytest.skip("Requires CUDA", allow_module_level=True)

# Small models for default CI / local runs (accuracy only).
# Used to parametrize the accuracy tests below.
SMALL_MODELS = [
    "meta-llama/Llama-3.2-1B-Instruct",
    "google/gemma-3-1b-it",
]

# Large models for optional perf runs only (slow to load and execute).
# Used to parametrize the latency tests below.
PERF_MODELS = [
    "meta-llama/Llama-3.1-8B",
    "openai/gpt-oss-20b",
]
def _make_llm(model: str, lazy: bool, cpu_bytes_to_use: int) -> LLM:
    """Create an ``LLM`` wired to ``SimpleCPUOffloadConnector``.

    Args:
        model: Model identifier handed to ``LLM``.
        lazy: Forwarded as the connector's ``lazy_offload`` option.
        cpu_bytes_to_use: Forwarded as the connector's ``cpu_bytes_to_use``
            option.

    Returns:
        An ``LLM`` with prefix caching enabled, the hybrid KV cache manager
        left enabled and a 40 GiB GPU KV cache budget.
    """
    connector_options = {
        "cpu_bytes_to_use": cpu_bytes_to_use,
        "lazy_offload": lazy,
    }
    offload_config = KVTransferConfig(
        kv_connector="SimpleCPUOffloadConnector",
        kv_role="kv_both",
        kv_connector_extra_config=connector_options,
    )
    gpu_kv_cache_bytes = 40 << 30  # 40 GiB
    return LLM(
        model=model,
        kv_cache_memory_bytes=gpu_kv_cache_bytes,
        disable_hybrid_kv_cache_manager=False,
        enable_prefix_caching=True,
        kv_transfer_config=offload_config,
    )
def _flush_gpu_cache(llm: LLM, sampling_params: SamplingParams, seed: int = 0):
    """Run filler requests that together exceed the GPU KV cache capacity.

    The intent is to push all previously cached blocks through the free
    queue so that the lazy cursor offloads them to CPU before eviction.

    Args:
        llm: Engine to send the filler requests to.
        sampling_params: Sampling parameters used for every filler request.
        seed: Shifts the filler token ids so that separate flushes do not
            reuse each other's prompts.
    """
    cache_cfg = llm.llm_engine.vllm_config.cache_config
    gpu_blocks = cache_cfg.num_gpu_blocks
    tokens_per_block = cache_cfg.block_size

    # Overshoot GPU capacity by 1.5x to give the lazy cursor enough
    # scheduling steps to walk past all target blocks near the tail of the
    # free queue.
    total_tokens_needed = int(gpu_blocks * tokens_per_block * 1.5)

    # Split into multiple requests to stay under max_model_len.
    max_tokens_per_req = 4096
    num_fillers = -(-total_tokens_needed // max_tokens_per_req)  # ceil div
    batch_size = 10

    # Each filler repeats one token id that no other filler of this flush
    # uses, so fillers share no prefix with one another.
    first_token_id = seed * num_fillers + 1
    for batch_start in range(0, num_fillers, batch_size):
        batch_stop = min(batch_start + batch_size, num_fillers)
        batch = [
            TokensPrompt(
                prompt_token_ids=[first_token_id + k] * max_tokens_per_req
            )
            for k in range(batch_start, batch_stop)
        ]
        llm.generate(batch, sampling_params, use_tqdm=False)
def _accuracy_test(llm: LLM, lazy: bool = False):
    """Check that generating from CPU-loaded KV reproduces the cold output.

    Args:
        llm: Engine configured with the CPU offload connector.
        lazy: When true, flush the GPU cache before each retry so that the
            lazy offload path is exercised.

    Raises:
        AssertionError: If fewer than half of the retries match the cold run.
    """
    greedy = SamplingParams(max_tokens=1, temperature=0)
    prompt = "hi " * 2000 + "Let's count to ten. One, two, three, "

    # Cold run: populates the GPU cache and triggers the CPU offload.
    baseline = llm.generate(prompt, greedy, use_tqdm=False)[0]
    expected = baseline.outputs[0].text

    attempts = 10
    matches = 0
    for attempt in range(attempts):
        if lazy:
            _flush_gpu_cache(llm, greedy, seed=attempt)
        time.sleep(2)  # let engine core drain pending transfers

        # Drop the GPU prefix cache so the next run must load from CPU.
        reset_ok = llm.reset_prefix_cache()
        if not reset_ok:
            print(f"GPU prefix cache reset failed for iteration {attempt}")

        rerun = llm.generate(prompt, greedy, use_tqdm=False)[0]
        if rerun.outputs[0].text == expected:
            matches += 1

    assert matches >= 0.5 * attempts, (
        f"Accuracy too low: {matches}/{attempts} matched '{expected}'"
    )
def _latency_test(llm: LLM, lazy: bool = False):
    """Verify CPU cache hit is faster than cold compute.

    Each iteration times a cold prefill of a fresh 10001-token prompt, makes
    sure the prompt's KV has had a chance to reach the CPU cache, resets the
    GPU prefix cache, and times the same prompt again.

    Args:
        llm: Engine configured with the CPU offload connector.
        lazy: When true, flush the GPU cache to trigger the lazy offload;
            otherwise issue one extra GPU-hit request.

    Raises:
        AssertionError: If the CPU-hit run is faster in fewer than 80% of
            the iterations.
    """
    sampling_params = SamplingParams(max_tokens=1, seed=42)
    prompt_token_ids = [0] * 10001
    num_times_cpu_better = 0
    num_tests = 10

    for i in range(num_tests):
        # Changing the first token makes the prompt differ from every
        # earlier iteration from position 0 onwards.
        prompt_token_ids[0] = i
        prompts = [TokensPrompt(prompt_token_ids=prompt_token_ids)]

        # Cold
        time.sleep(2)  # let engine core drain pending transfers
        if not llm.reset_prefix_cache():
            print(f"GPU prefix cache reset failed for iteration {i}")
        # perf_counter() is monotonic; time.time() can jump when the system
        # clock is adjusted, which would corrupt the comparison below.
        start = time.perf_counter()
        llm.generate(prompts, sampling_params, use_tqdm=False)
        cold_time = time.perf_counter() - start

        if lazy:
            _flush_gpu_cache(llm, sampling_params, seed=i)
        else:
            # Eager mode: GPU hit ensures store completion is processed.
            llm.generate(prompts, sampling_params, use_tqdm=False)

        time.sleep(2)  # let engine core drain pending transfers
        if not llm.reset_prefix_cache():
            print(f"GPU prefix cache reset failed for iteration {i}")

        # CPU hit
        start = time.perf_counter()
        llm.generate(prompts, sampling_params, use_tqdm=False)
        cpu_time = time.perf_counter() - start

        if cpu_time < cold_time:
            num_times_cpu_better += 1

    assert num_times_cpu_better >= 0.8 * num_tests, (
        f"CPU hit only faster {num_times_cpu_better}/{num_tests} times"
    )
@pytest.mark.optional
@pytest.mark.slow_test
@pytest.mark.parametrize("model", SMALL_MODELS)
def test_simple_cpu_offload_accuracy(model: str):
    """Eager mode: offload to CPU, reset GPU, reload; output must match."""
    one_gib = 1 << 30
    llm = _make_llm(model, lazy=False, cpu_bytes_to_use=one_gib)  # 1GB
    try:
        _accuracy_test(llm, lazy=False)
    finally:
        # Drop this test's reference to the engine whether or not it passed.
        del llm
@pytest.mark.optional
@pytest.mark.slow_test
@pytest.mark.parametrize("model", PERF_MODELS)
def test_simple_cpu_offload_perf_latency(model: str):
    """Eager mode: a CPU KV hit should beat cold prefill (large models only)."""
    ten_gib = 10 << 30
    llm = _make_llm(model, lazy=False, cpu_bytes_to_use=ten_gib)  # 10GB
    try:
        _latency_test(llm, lazy=False)
    finally:
        # Drop this test's reference to the engine whether or not it passed.
        del llm
@pytest.mark.optional
@pytest.mark.slow_test
@pytest.mark.parametrize("model", SMALL_MODELS)
def test_simple_cpu_offload_accuracy_lazy(model: str):
    """Lazy mode: flushing the GPU cache triggers offload; then verify the hit."""
    # CPU must be larger than GPU KV cache to avoid evicting offloaded blocks.
    eighty_gib = 80 << 30
    llm = _make_llm(model, lazy=True, cpu_bytes_to_use=eighty_gib)  # 80GB
    try:
        _accuracy_test(llm, lazy=True)
    finally:
        # Drop this test's reference to the engine whether or not it passed.
        del llm
@pytest.mark.optional
@pytest.mark.slow_test
@pytest.mark.parametrize("model", PERF_MODELS)
def test_simple_cpu_offload_perf_latency_lazy(model: str):
    """Lazy mode: a CPU KV hit should beat cold prefill (large models only)."""
    # CPU must be larger than GPU KV cache to avoid evicting offloaded blocks.
    eighty_gib = 80 << 30
    llm = _make_llm(model, lazy=True, cpu_bytes_to_use=eighty_gib)  # 80GB
    try:
        _latency_test(llm, lazy=True)
    finally:
        # Drop this test's reference to the engine whether or not it passed.
        del llm
This diff is collapsed.
......@@ -657,7 +657,11 @@ class VllmConfig:
)
if kv_offloading_backend == "native":
self.kv_transfer_config.kv_connector = "OffloadingConnector"
if envs.VLLM_USE_SIMPLE_KV_OFFLOAD:
config_connector = "SimpleCPUOffloadConnector"
else:
config_connector = "OffloadingConnector"
self.kv_transfer_config.kv_connector = config_connector
self.kv_transfer_config.kv_connector_extra_config.update(
{"cpu_bytes_to_use": kv_offloading_size * (1 << 30)}
)
......
......@@ -202,6 +202,7 @@ KVConnectorFactory.register_connector(
"vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector",
"DecodeBenchConnector",
)
KVConnectorFactory.register_connector(
"MooncakeConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake.mooncake_connector",
......@@ -213,3 +214,9 @@ KVConnectorFactory.register_connector(
"vllm.distributed.kv_transfer.kv_connector.v1.flexkv_connector",
"FlexKVConnectorV1",
)
KVConnectorFactory.register_connector(
"SimpleCPUOffloadConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.simple_cpu_offload_connector",
"SimpleCPUOffloadConnector",
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""SimpleCPUOffloadConnector: minimal CPU KV cache offloading."""
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any
import torch
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
SupportsHMA,
)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
from vllm.v1.simple_kv_offload.manager import (
SimpleCPUOffloadScheduler,
)
from vllm.v1.simple_kv_offload.metadata import (
SimpleCPUOffloadMetadata,
)
from vllm.v1.simple_kv_offload.worker import (
SimpleCPUOffloadWorker,
)
if TYPE_CHECKING:
from vllm.forward_context import ForwardContext
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
logger = init_logger(__name__)
# Default CPU capacity: 8 GB
DEFAULT_CPU_CAPACITY_BYTES = 8 * (1024**3)
class SimpleCPUOffloadConnector(KVConnectorBase_V1, SupportsHMA):
    """CPU KV cache offloading with custom kernel transfers and BlockPool LRU.

    The connector itself is a thin facade: depending on ``role`` it owns
    either a ``SimpleCPUOffloadScheduler`` or a ``SimpleCPUOffloadWorker``
    and forwards each connector API call to it. When prefix caching is
    disabled neither delegate is created and every method falls back to a
    neutral result.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        role: KVConnectorRole,
        kv_cache_config: "KVCacheConfig | None" = None,
    ):
        """Read the connector options and build the role-specific delegate.

        Options read from ``kv_connector_extra_config``:
            cpu_bytes_to_use: server-wide CPU capacity, split evenly across
                ``world_size`` ranks (default ``DEFAULT_CPU_CAPACITY_BYTES``).
            cpu_bytes_to_use_per_rank: explicit per-rank capacity; wins over
                the value derived from ``cpu_bytes_to_use``.
            lazy_offload: selects lazy instead of eager mode (default False).
        """
        super().__init__(vllm_config, role, kv_cache_config)

        gib = 1024**3
        prefix_caching_on = vllm_config.cache_config.enable_prefix_caching
        options = self._kv_transfer_config.kv_connector_extra_config or {}

        total_cpu_bytes = int(
            options.get("cpu_bytes_to_use", DEFAULT_CPU_CAPACITY_BYTES)
        )
        # cpu_bytes_to_use is server-wide for compatibility;
        # cpu_bytes_to_use_per_rank overrides for per-rank capacity.
        world_size = vllm_config.parallel_config.world_size
        per_rank_bytes = total_cpu_bytes // world_size
        if "cpu_bytes_to_use_per_rank" in options:
            requested = int(options["cpu_bytes_to_use_per_rank"])
            if requested != per_rank_bytes:
                logger.warning(
                    "cpu_bytes_to_use_per_rank (%.2f GB) != "
                    "cpu_bytes_to_use/world_size (%.2f GB). Using per-rank value.",
                    requested / gib,
                    per_rank_bytes / gib,
                )
            per_rank_bytes = requested

        use_lazy = bool(options.get("lazy_offload", False))

        # Both delegates start out absent; at most one is created below.
        self.scheduler_manager: SimpleCPUOffloadScheduler | None = None
        self.worker_handler: SimpleCPUOffloadWorker | None = None

        if not prefix_caching_on:
            logger.warning(
                "Detected prefix caching disabled, disabling CPU offload "
                "since it requires prefix caching."
            )
            return

        logger.info(
            "SimpleCPUOffloadConnector: role=%s, "
            "per_rank=%.2f GB, world_size=%d, mode=%s",
            role.name,
            per_rank_bytes / gib,
            world_size,
            "lazy" if use_lazy else "eager",
        )

        if role == KVConnectorRole.SCHEDULER:
            self.scheduler_manager = SimpleCPUOffloadScheduler(
                vllm_config,
                kv_cache_config,
                per_rank_bytes,
                lazy_offload=use_lazy,
            )
        elif role == KVConnectorRole.WORKER:
            self.worker_handler = SimpleCPUOffloadWorker(
                vllm_config, kv_cache_config, per_rank_bytes
            )

    # --- Worker-side methods ---

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]) -> None:
        """Forward the GPU KV caches to the worker handler, if any."""
        handler = self.worker_handler
        if handler is None:
            return
        handler.register_kv_caches(kv_caches)

    def bind_connector_metadata(
        self,
        connector_metadata: KVConnectorMetadata,
    ) -> None:
        """Bind this step's metadata on the base class and the worker handler."""
        super().bind_connector_metadata(connector_metadata)
        handler = self.worker_handler
        if handler is None:
            return
        assert isinstance(connector_metadata, SimpleCPUOffloadMetadata)
        handler.bind_connector_metadata(connector_metadata)

    def clear_connector_metadata(self) -> None:
        """Clear this step's metadata on the base class and the worker handler."""
        super().clear_connector_metadata()
        handler = self.worker_handler
        if handler is None:
            return
        handler.clear_connector_metadata()

    def handle_preemptions(self, kv_connector_metadata: KVConnectorMetadata) -> None:
        """Let the worker handler react to preemptions described in the metadata."""
        handler = self.worker_handler
        if handler is None:
            return
        assert isinstance(kv_connector_metadata, SimpleCPUOffloadMetadata)
        handler.handle_preemptions(kv_connector_metadata)

    def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
        """No-op: loads are launched in get_finished() after model execution."""

    def wait_for_layer_load(self, layer_name: str) -> None:
        """No-op: loads are always asynchronous and deferred to get_finished()."""

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: torch.Tensor,
        attn_metadata: "AttentionMetadata",
        **kwargs: Any,
    ) -> None:
        """No-op: saves are always asynchronous and deferred to get_finished()."""

    def wait_for_save(self) -> None:
        """No-op: all stores are driven by get_finished(), nothing to wait on."""

    def get_finished(
        self,
        finished_req_ids: set[str],
    ) -> tuple[set[str] | None, set[str] | None]:
        """Return the worker handler's result, or ``(None, None)`` without one."""
        handler = self.worker_handler
        if handler is None:
            return None, None
        return handler.get_finished(finished_req_ids)

    def build_connector_worker_meta(self):
        """Return the worker handler's worker metadata, or ``None`` without one."""
        handler = self.worker_handler
        if handler is None:
            return None
        return handler.build_connector_worker_meta()

    # --- Scheduler-side methods ---

    # NOTE: New API only for SimpleCPUOffloadConnector.
    def bind_gpu_block_pool(self, gpu_block_pool: "BlockPool") -> None:
        """Hand the GPU block pool to the scheduler manager, if any."""
        manager = self.scheduler_manager
        if manager is None:
            return
        manager.bind_gpu_block_pool(gpu_block_pool)

    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[int | None, bool]:
        """Return the manager's answer, or ``(0, False)`` without a manager."""
        manager = self.scheduler_manager
        if manager is None:
            return 0, False
        return manager.get_num_new_matched_tokens(request, num_computed_tokens)

    def update_state_after_alloc(
        self,
        request: "Request",
        blocks: "KVCacheBlocks",
        num_external_tokens: int,
    ) -> None:
        """Forward the allocation result to the scheduler manager, if any."""
        manager = self.scheduler_manager
        if manager is None:
            return
        manager.update_state_after_alloc(request, blocks, num_external_tokens)

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Return the manager's metadata, or a default-constructed one without it."""
        manager = self.scheduler_manager
        if manager is None:
            return SimpleCPUOffloadMetadata()
        return manager.build_connector_meta(scheduler_output)

    def update_connector_output(
        self,
        connector_output: KVConnectorOutput,
    ) -> None:
        """Forward the connector output to the scheduler manager, if any."""
        manager = self.scheduler_manager
        if manager is None:
            return
        manager.update_connector_output(connector_output)

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """Return the manager's answer, or ``(False, None)`` without a manager."""
        manager = self.scheduler_manager
        if manager is None:
            return False, None
        return manager.request_finished(request, block_ids)

    def request_finished_all_groups(
        self,
        request: "Request",
        block_ids: tuple[list[int], ...],
    ) -> tuple[bool, dict[str, Any] | None]:
        """Return the manager's answer, or ``(False, None)`` without a manager."""
        manager = self.scheduler_manager
        if manager is None:
            return False, None
        return manager.request_finished_all_groups(request, block_ids)

    # NOTE: New API only for SimpleCPUOffloadConnector.
    def has_pending_transfers(self) -> bool:
        """Report the manager's ``has_pending_stores()``; False without a manager."""
        manager = self.scheduler_manager
        if manager is None:
            return False
        return manager.has_pending_stores()

    def take_events(self) -> Iterable[KVCacheEvent]:
        """Return the manager's KV cache events, or an empty list without one."""
        manager = self.scheduler_manager
        if manager is None:
            return []
        return manager.take_events()

    def reset_cache(self) -> bool | None:
        """Always raise: resetting the connector cache is not supported.

        Raises:
            NotImplementedError: Unconditionally.
        """
        raise NotImplementedError(
            "SimpleCPUOffloadConnector does not support reset_cache(). "
            "reset_prefix_cache() requires synchronizing all pending "
            "CPU offload transfers before clearing GPU prefix cache blocks, "
            "which is not yet implemented."
        )
......@@ -1662,6 +1662,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_XPU_ENABLE_XPU_GRAPH": lambda: bool(
int(os.getenv("VLLM_XPU_ENABLE_XPU_GRAPH", "0"))
),
# Enable simple KV offload.
"VLLM_USE_SIMPLE_KV_OFFLOAD": lambda: bool(
int(os.getenv("VLLM_USE_SIMPLE_KV_OFFLOAD", "0"))
),
}
......
......@@ -234,6 +234,13 @@ class Scheduler(SchedulerInterface):
hash_block_size=self.block_size,
metrics_collector=self.kv_metrics_collector,
)
# Bind GPU block pool to the KV connector. This must happen after
# kv_cache_manager is constructed so block_pool is available.
if self.connector is not None and hasattr(
self.connector, "bind_gpu_block_pool"
):
self.connector.bind_gpu_block_pool(self.kv_cache_manager.block_pool)
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
self.scheduler_reserve_full_isl = (
......
......@@ -281,8 +281,16 @@ class PromptTokenStats:
self.computed += prompt_len - num_cached_tokens
self.external_kv_transfer += num_external_computed_tokens
self.local_cache_hit += (
num_cached_tokens + recomputed - num_external_computed_tokens
# FIXME(yifan): local_cache_hit can go negative after preemption.
# num_cached_tokens is a one-time snapshot from first scheduling and
# is never reset on preemption, while num_external_computed_tokens is
# overwritten on re-scheduling. If CPU offload finds more tokens on
# the second pass than the original total, the subtraction underflows.
# A fundamental fix is to track the first-time num_external_computed_tokens
# as a separate metric rather than reusing num_external_computed_tokens
# for metric directly.
self.local_cache_hit += max(
0, (num_cached_tokens + recomputed - num_external_computed_tokens)
)
self.cached_tokens += num_cached_tokens
self.recomputed_tokens += recomputed
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""DMA copy backend for GPU<->CPU block transfers."""
from __future__ import annotations
import queue
import threading
import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.simple_kv_offload.cuda_mem_ops import (
BatchMemcpyParams,
build_params,
copy_blocks,
)
logger = init_logger(__name__)
class DmaCopyBackend:
    """cuMemcpyBatchAsync copy backend (background thread).

    Copy requests are placed on a queue by ``launch_copy`` and executed in
    FIFO order by one daemon thread. After each copy is issued the thread
    records an event on the matching stream and appends
    ``(event_idx, event)`` to the list supplied by the caller.
    """

    def __init__(self) -> None:
        # Everything is created in init(); until then the backend is inert.
        self._store_params: BatchMemcpyParams | None = None
        self._load_params: BatchMemcpyParams | None = None
        self._load_stream: torch.cuda.Stream | None = None
        self._store_stream: torch.cuda.Stream | None = None
        self._queue: queue.SimpleQueue | None = None
        self._thread: threading.Thread | None = None
        self._shutdown: bool = False

    def init(
        self,
        gpu_caches: dict[str, torch.Tensor],
        cpu_caches: dict[str, torch.Tensor],
        device: torch.device,
        load_stream: torch.cuda.Stream,
        store_stream: torch.cuda.Stream,
    ) -> None:
        """Precompute copy parameters and start the copy thread.

        Args:
            gpu_caches: GPU-side cache tensors, keyed by name.
            cpu_caches: CPU-side cache tensors with the same keys.
            device: Device the copy thread selects before looping.
            load_stream: Stream used for CPU->GPU copies.
            store_stream: Stream used for GPU->CPU copies.
        """
        self._load_stream = load_stream
        self._store_stream = store_stream

        # Stores read from GPU and write to CPU; loads go the other way.
        self._store_params = build_params(gpu_caches, cpu_caches, store_stream)
        self._load_params = build_params(cpu_caches, gpu_caches, load_stream)

        work_queue: queue.SimpleQueue = queue.SimpleQueue()
        worker = threading.Thread(
            target=self._copy_loop,
            args=(work_queue, device, load_stream, store_stream),
            daemon=True,
        )
        self._queue = work_queue
        self._thread = worker
        worker.start()

    def launch_copy(
        self,
        src_blocks: list[int],
        dst_blocks: list[int],
        is_store: bool,
        event_idx: int,
        events_list: list[tuple[int, torch.Event]],
    ) -> None:
        """Queue one batch copy for the background thread.

        Args:
            src_blocks: Source block ids.
            dst_blocks: Destination block ids.
            is_store: True for GPU->CPU, False for CPU->GPU.
            event_idx: Index stored alongside the completion event.
            events_list: List the copy thread appends ``(event_idx, event)``
                to once the copy has been issued.
        """
        if is_store:
            params = self._store_params
        else:
            params = self._load_params
        assert params is not None and self._queue is not None
        job = (src_blocks, dst_blocks, params, is_store, event_idx, events_list)
        self._queue.put(job)

    def shutdown(self) -> None:
        """Stop the copy thread; repeated calls do nothing."""
        if self._shutdown:
            return
        self._shutdown = True
        if self._queue is not None:
            # None is the sentinel that makes _copy_loop return.
            self._queue.put(None)
        if self._thread is not None:
            self._thread.join(timeout=5.0)

    @staticmethod
    def _copy_loop(
        q: queue.SimpleQueue,
        device: torch.device,
        load_stream: torch.cuda.Stream,
        store_stream: torch.cuda.Stream,
    ) -> None:
        """Thread body: execute queued copies until the ``None`` sentinel."""
        current_platform.set_device(device)
        # iter(callable, sentinel) keeps calling q.get() until it yields None.
        for job in iter(q.get, None):
            src_blocks, dst_blocks, params, is_store, event_idx, events_list = job
            copy_blocks(src_blocks, dst_blocks, params)
            done = torch.Event()
            done.record(store_stream if is_store else load_stream)
            events_list.append((event_idx, done))
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Low-level CUDA memory helpers: pinning and batch DMA transfers."""
import ctypes
from typing import Any, NamedTuple
import numpy as np
import torch
from vllm.logger import init_logger
logger = init_logger(__name__)
def pin_tensor(tensor: torch.Tensor) -> None:
    """Page-lock an existing CPU tensor in place through cudaHostRegister.

    Registering the already-allocated buffer avoids PyTorch's
    CUDACachingHostAllocator, which rounds every ``pin_memory=True``
    allocation up to the next power of 2 (e.g. 100 GB becomes 128 GB).

    Args:
        tensor: CPU tensor whose ``nbytes`` bytes starting at ``data_ptr()``
            are registered.

    Raises:
        RuntimeError: If cudaHostRegister reports a non-zero status.
    """
    cudart = torch.cuda.cudart()
    status = cudart.cudaHostRegister(tensor.data_ptr(), tensor.nbytes, 0)
    if status.value != 0:
        raise RuntimeError(f"cudaHostRegister failed: {status}")
class _CUmemLocation(ctypes.Structure):
    """ctypes struct with an unsigned ``type`` and a signed ``id`` field.

    Embedded twice in ``_CUmemcpyAttributes`` as the source and destination
    location hints. NOTE(review): presumably mirrors the CUDA driver's
    ``CUmemLocation`` -- confirm against the driver headers.
    """

    # Field order and C types define the binary layout; do not reorder.
    _fields_ = [("type", ctypes.c_uint), ("id", ctypes.c_int)]
class _CUmemcpyAttributes(ctypes.Structure):
    """ctypes struct passed by address as the ``attrs`` argument of the copy.

    ``build_params`` creates one instance with ``srcAccessOrder=3`` and
    leaves the other fields zero-initialized. NOTE(review): presumably
    mirrors the CUDA driver's ``CUmemcpyAttributes`` -- confirm against the
    driver headers.
    """

    # Field order and C types define the binary layout; do not reorder.
    _fields_ = [
        ("srcAccessOrder", ctypes.c_uint),
        ("srcLocHint", _CUmemLocation),
        ("dstLocHint", _CUmemLocation),
        ("flags", ctypes.c_uint),
    ]
# C prototype used to call the resolved cuMemcpyBatchAsync entry point.
# The argument order below follows the call made in copy_blocks():
# dst pointer array, src pointer array, size array, element count,
# attributes, attribute-index array, number of attributes, fail index,
# stream handle.
_BATCH_MEMCPY_FUNC_TYPE = ctypes.CFUNCTYPE(
    ctypes.c_uint,  # CUresult
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_size_t,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_size_t,
    ctypes.c_void_p,
    ctypes.c_void_p,
)

# Resolved lazily on first use.
# build_params() fills this in; copy_blocks() calls it.
_batch_memcpy_fn: Any = None
def _resolve_batch_memcpy():
    """Resolve cuMemcpyBatchAsync via cuGetProcAddress (one-time).

    Returns:
        A ``_BATCH_MEMCPY_FUNC_TYPE`` callable bound to the driver entry
        point.

    Raises:
        RuntimeError: If the lookup reports an error, or if it reports
            success but yields a NULL address (the driver does not export
            the symbol at the requested version).
    """
    # Imported here so that merely importing this module does not require
    # the CUDA Python bindings.
    from cuda.bindings import driver as drv

    # 12080 is the CUDA version argument requested for the symbol.
    err, ptr, _ = drv.cuGetProcAddress(b"cuMemcpyBatchAsync", 12080, 0)
    if err != drv.CUresult.CUDA_SUCCESS:
        raise RuntimeError(f"cuGetProcAddress(cuMemcpyBatchAsync) failed: {err}")
    # A successful status does not guarantee a usable address; wrapping a
    # NULL pointer would only fail later, at the first copy.
    if not ptr:
        raise RuntimeError(
            "cuGetProcAddress(cuMemcpyBatchAsync) returned a NULL address; "
            "the installed CUDA driver does not provide cuMemcpyBatchAsync"
        )
    return _BATCH_MEMCPY_FUNC_TYPE(ptr)
class BatchMemcpyParams(NamedTuple):
    """Per-direction arguments for ``copy_blocks``, built once by ``build_params``.

    The arrays hold one entry per cache tensor, in the iteration order of
    the cache dictionaries given to ``build_params``.
    """

    src_bases: np.ndarray  # [num_layers] uint64 — data_ptr per layer
    dst_bases: np.ndarray  # [num_layers] uint64
    bpb: np.ndarray  # [num_layers] uint64 — bytes per block
    num_layers: int
    # Passed by address on every copy_blocks() call.
    attrs: _CUmemcpyAttributes
    # Passed by reference as the attribute-index array (a single entry).
    attrs_idx: ctypes.c_size_t
    # NOTE: cuMemcpyBatchAsync_v2() removed fail_idx field, but we use
    # cuMemcpyBatchAsync() with fail_idx for backward compatibility
    fail_idx: ctypes.c_size_t
    stream_handle: int  # raw cudaStream_t / CUstream
def build_params(
    src_caches: dict[str, torch.Tensor],
    dst_caches: dict[str, torch.Tensor],
    stream: torch.cuda.Stream,
) -> BatchMemcpyParams:
    """Precompute the per-tensor arrays needed for one copy direction.

    Also resolves the driver entry point on first use.

    Args:
        src_caches: Source tensors keyed by name; dim 0 indexes blocks.
        dst_caches: Destination tensors with the same keys in the same order.
        stream: CUDA stream whose raw handle is stored for the copies.

    Returns:
        A ``BatchMemcpyParams`` holding base pointers and bytes-per-block for
        every tensor pair.
    """
    global _batch_memcpy_fn
    if _batch_memcpy_fn is None:
        _batch_memcpy_fn = _resolve_batch_memcpy()

    assert list(src_caches.keys()) == list(dst_caches.keys())

    def _block_nbytes(t: torch.Tensor) -> int:
        # Distance in bytes between consecutive entries along dim 0.
        return t.stride(0) * t.element_size()

    src_ptrs: list[int] = []
    dst_ptrs: list[int] = []
    block_sizes: list[int] = []
    pairs = list(zip(src_caches.values(), dst_caches.values()))
    for src, dst in pairs:
        nbytes = _block_nbytes(src)
        # Both sides must agree on the block size, since one size is used
        # for both addresses of a copy.
        assert nbytes == _block_nbytes(dst)
        src_ptrs.append(src.data_ptr())
        dst_ptrs.append(dst.data_ptr())
        block_sizes.append(nbytes)

    # Refer to https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g6f1ff58e3065df3eb4b573dba77ad31f for details.  # noqa: E501
    copy_attrs = _CUmemcpyAttributes(srcAccessOrder=3)  # ANY

    return BatchMemcpyParams(
        src_bases=np.array(src_ptrs, dtype=np.uint64),
        dst_bases=np.array(dst_ptrs, dtype=np.uint64),
        bpb=np.array(block_sizes, dtype=np.uint64),
        num_layers=len(pairs),
        attrs=copy_attrs,
        attrs_idx=ctypes.c_size_t(0),
        fail_idx=ctypes.c_size_t(0),
        stream_handle=stream.cuda_stream,
    )
def copy_blocks(
    src_block_ids: list[int],
    dst_block_ids: list[int],
    params: BatchMemcpyParams,
) -> None:
    """Copy blocks via cuMemcpyBatchAsync.

    For every tensor described by ``params`` and every ``(src, dst)`` block
    pair, one copy of that tensor's bytes-per-block is submitted in a single
    batched call on ``params.stream_handle``.

    Args:
        src_block_ids: Block indices to read from.
        dst_block_ids: Block indices to write to, paired with
            ``src_block_ids`` by position.
        params: Precomputed arguments from ``build_params``.

    Raises:
        ValueError: If the two id lists differ in length.
        RuntimeError: If the driver call returns a non-zero status.
    """
    n = len(src_block_ids)
    # The element count handed to the driver is derived from the source
    # list only; a shorter destination list would make the driver read past
    # the end of the destination address array.
    if n != len(dst_block_ids):
        raise ValueError(
            "src_block_ids and dst_block_ids must have the same length, "
            f"got {n} and {len(dst_block_ids)}"
        )
    if n == 0:
        return

    src_ids = np.array(src_block_ids, dtype=np.uint64)
    dst_ids = np.array(dst_block_ids, dtype=np.uint64)

    # Broadcast to [num_layers, n] address matrices, then flatten so that
    # all blocks of one tensor are contiguous in the argument arrays.
    src_all = (
        params.src_bases[:, None] + src_ids[None, :] * params.bpb[:, None]
    ).ravel()
    dst_all = (
        params.dst_bases[:, None] + dst_ids[None, :] * params.bpb[:, None]
    ).ravel()
    # Same layout as the address arrays: each tensor's size repeated n times.
    sz_all = np.repeat(params.bpb, n)
    total = n * params.num_layers

    err = _batch_memcpy_fn(
        dst_all.ctypes.data,
        src_all.ctypes.data,
        sz_all.ctypes.data,
        total,
        ctypes.addressof(params.attrs),
        ctypes.byref(params.attrs_idx),
        1,  # one attribute entry
        ctypes.byref(params.fail_idx),
        params.stream_handle,
    )
    if err != 0:
        raise RuntimeError(
            f"cuMemcpyBatchAsync failed: err={err} failIdx={params.fail_idx.value}"
        )
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Metadata for SimpleCPUOffloadConnector."""
from dataclasses import dataclass, field
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata,
KVConnectorWorkerMetadata,
)
INVALID_JOB_ID = -1
@dataclass
class SimpleCPUOffloadMetadata(KVConnectorMetadata):
    """
    Metadata passed from scheduler to worker for CPU offload operations.

    The worker receives flat block lists keyed by a monotonic event_idx.
    Job->req_id translation is handled by the scheduler-side manager
    (via inverse maps), so the worker never knows about request identities.

    A default-constructed instance describes a step with nothing to load,
    nothing to store and no flush.
    """

    # Load event per step. INVALID_JOB_ID means no blocks to load this step.
    load_event: int = INVALID_JOB_ID
    # Destination (GPU) and source (CPU) block ids of this step's load.
    load_gpu_blocks: list[int] = field(default_factory=list)
    load_cpu_blocks: list[int] = field(default_factory=list)
    # Reverse map: load_event->req_ids, for tracking requests with finished load events
    load_event_to_reqs: dict[int, list[str]] = field(default_factory=dict)

    # Store event per step. INVALID_JOB_ID means no blocks to store this step.
    store_event: int = INVALID_JOB_ID
    # Source (GPU) and destination (CPU) block ids of this step's store.
    store_gpu_blocks: list[int] = field(default_factory=list)
    store_cpu_blocks: list[int] = field(default_factory=list)

    # Whether any requests were preempted this step and need flush pending transfers.
    need_flush: bool = False
@dataclass
class SimpleCPUOffloadWorkerMetadata(KVConnectorWorkerMetadata):
    """Worker -> Scheduler metadata for completed store events.

    Each worker reports {event_idx: 1} for newly completed stores.
    ``aggregate()`` sums counts across workers within a step.
    The scheduler-side manager accumulates across steps and processes
    a store completion only when count reaches ``world_size``.
    """

    # event_idx -> number of workers that reported the store as completed.
    completed_store_events: dict[int, int]

    def aggregate(
        self, other: "KVConnectorWorkerMetadata"
    ) -> "KVConnectorWorkerMetadata":
        """Return a new metadata object with per-event counts summed.

        Neither ``self`` nor ``other`` is modified.
        """
        assert isinstance(other, SimpleCPUOffloadWorkerMetadata)
        mine = self.completed_store_events
        theirs = other.completed_store_events
        # Keys only in `mine` are kept as-is; keys in `theirs` get the sum.
        summed = mine | {
            event_idx: mine.get(event_idx, 0) + count
            for event_idx, count in theirs.items()
        }
        return SimpleCPUOffloadWorkerMetadata(completed_store_events=summed)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Worker-side handler for SimpleCPUOffloadConnector."""
from typing import TYPE_CHECKING
import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.simple_kv_offload.copy_backend import DmaCopyBackend
from vllm.v1.simple_kv_offload.cuda_mem_ops import pin_tensor
from vllm.v1.simple_kv_offload.metadata import (
SimpleCPUOffloadMetadata,
SimpleCPUOffloadWorkerMetadata,
)
if TYPE_CHECKING:
from vllm.v1.kv_cache_interface import KVCacheConfig
logger = init_logger(__name__)
class SimpleCPUOffloadWorker:
    """Worker-side handler for CPU offloading transfers.

    Owns the CPU mirror of the GPU KV caches, submits load/store copies to a
    ``DmaCopyBackend`` and tracks their completion through per-stream event
    lists and high-water marks.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        kv_cache_config: "KVCacheConfig | None",
        cpu_capacity_bytes: int,
    ):
        """Store the configuration; caches and streams are created later.

        Args:
            vllm_config: Engine configuration (stored, not read here).
            kv_cache_config: KV cache configuration; its ``num_blocks`` is
                read in ``register_kv_caches``.
            cpu_capacity_bytes: CPU byte budget for this worker's cache.
        """
        self.vllm_config = vllm_config
        self.kv_cache_config = kv_cache_config
        self.cpu_capacity_bytes = cpu_capacity_bytes

        self.gpu_kv_caches: dict[str, torch.Tensor] | None = None
        self.cpu_kv_caches: dict[str, torch.Tensor] | None = None
        self.device: torch.device | None = None
        self.num_cpu_blocks: int = 0

        # CUDA streams for the async transfers
        self.load_stream: torch.cuda.Stream | None = None
        self.store_stream: torch.cuda.Stream | None = None

        self._backend = DmaCopyBackend()

        # Ordered (event_idx, Event). Entries are appended by the copy
        # backend after it has issued the copy and removed only by this
        # class, always from the front.
        self._load_events: list[tuple[int, torch.Event]] = []
        self._store_events: list[tuple[int, torch.Event]] = []
        # High-water marks: highest event_idx completed per stream.
        # When the event list is empty, the hwm covers all prior events.
        self._load_hwm: int = -1
        self._store_hwm: int = -1

        # Metadata for the current step
        self._connector_metadata: SimpleCPUOffloadMetadata | None = None

        # Pending event index sets, populated in bind_connector_metadata
        self._pending_load_event_indices: set[int] = set()
        self._pending_store_event_indices: set[int] = set()
        # Completed store events to report via build_connector_worker_meta
        self._completed_store_events: dict[int, int] = {}

    def register_kv_caches(
        self,
        kv_caches: dict[str, torch.Tensor],
    ) -> None:
        """Register GPU KV caches and allocate pinned CPU tensors.

        The worker will infer the underlying raw storage from the kv_caches.

        Args:
            kv_caches: Per-layer GPU KV caches. Values are either a single
                tensor (attention layers) or a list of tensors (Mamba layers
                in hybrid models). All values are included for offloading
                by resolving to their underlying raw storage.
        """
        if not kv_caches:
            logger.warning("No KV caches to offload.")
            return

        # Resolve each entry to a representative tensor for storage
        # deduplication. For attention layers the value is already a tensor;
        # for Mamba layers it is a list of tensors that all share the same
        # underlying raw storage, so we take the first one.
        def _repr_tensor(v: torch.Tensor | list[torch.Tensor]) -> torch.Tensor:
            assert isinstance(v, torch.Tensor | list)
            return v if isinstance(v, torch.Tensor) else v[0]

        any_tensor = _repr_tensor(next(iter(kv_caches.values())))
        self.device = any_tensor.device

        assert self.kv_cache_config is not None
        num_blocks = self.kv_cache_config.num_blocks

        # Deduplicate: multiple layers may share the same backing storage.
        seen_ptrs: dict[int, tuple[str, torch.Tensor]] = {}
        for name, value in kv_caches.items():
            tensor = _repr_tensor(value)
            ptr = tensor.untyped_storage().data_ptr()
            if ptr not in seen_ptrs:
                seen_ptrs[ptr] = (name, tensor)

        # Build [num_blocks, block_bytes] int8 views from each unique
        # storage so that stride(0) gives block_bytes for the copy op.
        #
        # The physical layout varies across attention backends:
        #   FlashAttn/ROCm: (2, num_blocks, ...) -> K/V outermost, 2 segments
        #   FlashInfer/MLA: (num_blocks, ...) -> blocks outermost, 1 segment
        # We derive page_size_bytes = storage.nbytes() // num_blocks, then
        # classify dims: any dim whose byte-stride exceeds page_size_bytes
        # must be an outer segment dim (e.g. the K/V dim of size 2). A less
        # hacky way is to update the interface with the layout.
        unique_gpu_caches: dict[str, torch.Tensor] = {}
        for name, tensor in seen_ptrs.values():
            storage = tensor.untyped_storage()
            raw = torch.empty(0, dtype=torch.int8, device=self.device).set_(
                storage, 0, (storage.nbytes(),)
            )
            el = tensor.element_size()
            page_size_bytes = storage.nbytes() // num_blocks
            outer_dims = [
                d for d in range(tensor.ndim) if tensor.stride(d) * el > page_size_bytes
            ]
            if not outer_dims:
                unique_gpu_caches[name] = raw.view(num_blocks, -1)
            else:
                # Split along the first outer dim; each segment becomes its
                # own [num_blocks, -1] view named "<name>.<idx>".
                seg_stride = tensor.stride(outer_dims[0]) * el
                for idx in range(tensor.shape[outer_dims[0]]):
                    offset = idx * seg_stride
                    chunk = raw[offset : offset + seg_stride]
                    unique_gpu_caches[f"{name}.{idx}"] = chunk.view(num_blocks, -1)

        # Compute per-tensor bytes_per_block. Tensors may have different
        # page_size_bytes (e.g., UniformTypeKVCacheSpecs with varying head_size).
        per_tensor_bpb = [
            t.stride(0) * t.element_size() for t in unique_gpu_caches.values()
        ]
        total_bytes_per_block = sum(per_tensor_bpb)

        # At least one CPU block is always allocated, even when the budget
        # is smaller than one block.
        self.num_cpu_blocks = max(1, self.cpu_capacity_bytes // total_bytes_per_block)

        logger.info(
            "SimpleCPUOffloadWorker: %d unique GPU KV tensors, "
            "allocating %d CPU blocks (%.2f GB)",
            len(unique_gpu_caches),
            self.num_cpu_blocks,
            (self.num_cpu_blocks * total_bytes_per_block) / (1024**3),
        )

        pin_memory = is_pin_memory_available()
        if not pin_memory:
            logger.warning(
                "Pinned memory not available. CPU offload performance may be degraded."
            )

        self.gpu_kv_caches = unique_gpu_caches
        self.cpu_kv_caches = {}
        for name, gpu_tensor in unique_gpu_caches.items():
            cpu_shape = (self.num_cpu_blocks,) + gpu_tensor.shape[1:]
            # Allocate non-pinned first, then pin via cudaHostRegister to
            # bypass PyTorch's CUDACachingHostAllocator which rounds up to
            # the next power of 2 (e.g. 100 GB -> 128 GB).
            tensor = torch.zeros(cpu_shape, dtype=gpu_tensor.dtype, device="cpu")
            if pin_memory:
                pin_tensor(tensor)
            self.cpu_kv_caches[name] = tensor

        # Use lowest priority so KV cache I/O yields to compute streams.
        low_pri, _ = torch.cuda.Stream.priority_range()
        self.load_stream = torch.cuda.Stream(priority=low_pri)
        self.store_stream = torch.cuda.Stream(priority=low_pri)

        # Initialize copy backend with caches and streams.
        self._backend.init(
            self.gpu_kv_caches,
            self.cpu_kv_caches,
            self.device,
            self.load_stream,
            self.store_stream,
        )

    def bind_connector_metadata(self, metadata: SimpleCPUOffloadMetadata) -> None:
        """Remember this step's metadata and mark its events as pending.

        Negative event indices (``INVALID_JOB_ID``) mean "no transfer this
        step" and are not tracked.
        """
        self._connector_metadata = metadata
        if metadata.load_event >= 0:
            self._pending_load_event_indices.add(metadata.load_event)
        if metadata.store_event >= 0:
            self._pending_store_event_indices.add(metadata.store_event)

    def clear_connector_metadata(self) -> None:
        """Forget this step's metadata; pending event indices are kept."""
        self._connector_metadata = None

    def start_load_kv(self) -> None:
        """No-op; transfers are launched from ``get_finished``."""
        # NOTE: we defer launching both load and store to get_finished(),
        # which runs after model execution. This hides the CPU-side
        # block copy op overhead (~5ms) behind GPU compute.
        pass

    def wait_for_save(self) -> None:
        """No-op; stores are reported through ``build_connector_worker_meta``."""
        pass

    def get_finished(
        self,
        finished_req_ids: set[str],
    ) -> tuple[set[str] | None, set[str] | None]:
        """Submit transfers and report completed events to the scheduler.

        Called after model execution. The manager only schedules stores for
        blocks whose KV data is confirmed computed, so we launch both loads
        and stores immediately — no deferral or cross-stream sync needed.

        Args:
            finished_req_ids: Not used by this implementation.

        Returns:
            tuple of (finished_sending, finished_recving).
            - finished_sending: always None (stores use worker metadata).
            - finished_recving: req_ids whose loads have completed.
        """
        # (1) Submit transfers
        metadata = self._connector_metadata
        if metadata is not None:
            # Launch loads (CPU->GPU).
            if metadata.load_cpu_blocks:
                self._backend.launch_copy(
                    metadata.load_cpu_blocks,
                    metadata.load_gpu_blocks,
                    is_store=False,
                    event_idx=metadata.load_event,
                    events_list=self._load_events,
                )
            # Launch stores (GPU->CPU).
            if metadata.store_gpu_blocks:
                self._backend.launch_copy(
                    metadata.store_gpu_blocks,
                    metadata.store_cpu_blocks,
                    is_store=True,
                    event_idx=metadata.store_event,
                    events_list=self._store_events,
                )

        # (2) Track completed transfer events
        finished_recving: set[str] = set()

        if self._pending_load_event_indices:
            load_wm = self._poll_stream_events(is_store=False)
            # Snapshot the completed indices first; the set is mutated below.
            for j in [j for j in self._pending_load_event_indices if j <= load_wm]:
                self._pending_load_event_indices.discard(j)
                # NOTE(review): the req_ids are looked up in the *current*
                # step's metadata, so this relies on the scheduler listing
                # every still-pending load event in each step's
                # load_event_to_reqs -- confirm against the manager.
                req_ids = (
                    metadata.load_event_to_reqs.get(j) if metadata is not None else None
                )
                if req_ids:
                    finished_recving.update(req_ids)

        if self._pending_store_event_indices:
            store_wm = self._poll_stream_events(is_store=True)
            for j in [j for j in self._pending_store_event_indices if j <= store_wm]:
                self._pending_store_event_indices.discard(j)
                self._completed_store_events[j] = 1

        return None, finished_recving or None

    def build_connector_worker_meta(self) -> SimpleCPUOffloadWorkerMetadata | None:
        """Return completed store events since the last call.

        Returns ``None`` when nothing completed; otherwise hands over the
        accumulated dict and starts a fresh one.
        """
        if not self._completed_store_events:
            return None
        meta = SimpleCPUOffloadWorkerMetadata(
            completed_store_events=self._completed_store_events,
        )
        self._completed_store_events = {}
        return meta

    def handle_preemptions(
        self, kv_connector_metadata: SimpleCPUOffloadMetadata
    ) -> None:
        """Sync all in-flight transfers before preempted blocks are reused."""
        if not kv_connector_metadata.need_flush:
            return
        self._flush_and_sync_all()

    def _flush_and_sync_all(self) -> None:
        """Synchronize all in-flight transfer events.

        Blocks until every event currently in the load and store lists has
        completed, advancing the high-water marks as it goes.

        NOTE(review): only copies whose event has already been appended are
        covered; a copy still waiting in the backend's queue has no event
        yet -- confirm whether callers need those drained as well.
        """
        self._load_hwm = self._drain_events(self._load_events, self._load_hwm)
        self._store_hwm = self._drain_events(self._store_events, self._store_hwm)

    @staticmethod
    def _drain_events(events: list[tuple[int, torch.Event]], hwm: int) -> int:
        """Synchronize and remove events front-to-back; return the new hwm.

        The copy backend appends to ``events`` from another thread.
        Iterating and then calling ``clear()`` could discard an entry that
        was appended after the loop ended but before the clear, leaving its
        index permanently pending. Removing exactly the entry that was just
        synchronized avoids that.
        """
        while events:
            event_idx, event = events[0]
            event.synchronize()
            # Only this class removes entries, so index 0 is still the
            # entry that was synchronized above.
            events.pop(0)
            hwm = event_idx
        return hwm

    def _poll_stream_events(self, is_store: bool) -> int:
        """Non-blocking poll for completed events and return the high-water mark.

        Events are consumed strictly in list order; polling stops at the
        first event that has not completed yet.
        """
        events = self._store_events if is_store else self._load_events
        hwm = self._store_hwm if is_store else self._load_hwm
        while events:
            event_idx, event = events[0]
            if not event.query():
                break
            hwm = event_idx
            events.pop(0)
        if is_store:
            self._store_hwm = hwm
        else:
            self._load_hwm = hwm
        return hwm
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment