# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Scheduler-side manager for SimpleCPUOffloadConnector.""" import contextlib from collections.abc import Iterable from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any from vllm.config import VllmConfig from vllm.distributed.kv_events import KVCacheEvent from vllm.distributed.kv_transfer.kv_connector.utils import yield_req_data from vllm.logger import init_logger from vllm.utils.math_utils import cdiv from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_coordinator import ( KVCacheCoordinator, get_kv_cache_coordinator, ) from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import ( FullAttentionSpec, MambaSpec, SlidingWindowSpec, ) from vllm.v1.outputs import KVConnectorOutput from vllm.v1.simple_kv_offload.metadata import ( SimpleCPUOffloadMetadata, SimpleCPUOffloadWorkerMetadata, ) if TYPE_CHECKING: from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_utils import KVCacheBlock from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request logger = init_logger(__name__) @dataclass class TransferMeta: gpu_block_ids: list[int] cpu_block_ids: list[int] @dataclass class LoadRequestState: request: "Request" transfer_meta: TransferMeta load_event: int | None = None finished: bool = False # NOTE: This per-request state is only used in eager mode. @dataclass class StoreRequestState: request: "Request" # Accumulated block IDs from scheduler_output via yield_req_data. block_ids: tuple[list[int], ...] # Per-group cursors tracking how many blocks have been stored/skipped. num_stored_blocks: list[int] store_events: set[int] = field(default_factory=set) finished: bool = False class SimpleCPUOffloadScheduler: """Scheduler-side manager for CPU offloading.""" def __init__( self, vllm_config: VllmConfig, kv_cache_config: "KVCacheConfig | None", cpu_capacity_bytes: int, lazy_offload: bool = False, ): self.vllm_config = vllm_config self.kv_cache_config = kv_cache_config self.enable_kv_cache_events = ( vllm_config.kv_events_config is not None and vllm_config.kv_events_config.enable_kv_cache_events ) # NOTE: We use the same block size for both GPU and CPU. self.block_size = vllm_config.cache_config.block_size # Derive a CPU KVCacheConfig from the GPU config and build a coordinator assert kv_cache_config is not None self.cpu_kv_cache_config = self._derive_cpu_config( kv_cache_config, cpu_capacity_bytes ) self.num_cpu_blocks = self.cpu_kv_cache_config.num_blocks # Find the full attention kv group for prefix cache matching. self.fa_gidx = -1 for g_idx, g in enumerate(self.cpu_kv_cache_config.kv_cache_groups): if isinstance(g.kv_cache_spec, FullAttentionSpec): self.fa_gidx = g_idx break assert 0 <= self.fa_gidx < len(self.cpu_kv_cache_config.kv_cache_groups) logger.info( "SimpleCPUOffloadScheduler: Allocating %d CPU blocks (%.2f GB, mode=%s)", self.num_cpu_blocks, cpu_capacity_bytes / (1024**3), "lazy" if lazy_offload else "eager", ) # TODO (yifan): maybe need to enable kv_cache_events and metrics_collector here. dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size assert dcp_world_size == 1 and pcp_world_size == 1 self.cpu_coordinator: KVCacheCoordinator = get_kv_cache_coordinator( kv_cache_config=self.cpu_kv_cache_config, max_model_len=vllm_config.model_config.max_model_len, use_eagle=False, enable_caching=True, enable_kv_cache_events=self.enable_kv_cache_events, dcp_world_size=dcp_world_size, pcp_world_size=pcp_world_size, hash_block_size=self.block_size, ) self.cpu_block_pool: BlockPool = self.cpu_coordinator.block_pool # GPU block pool reference - bound after scheduler builds kv_cache_manager self._gpu_block_pool: BlockPool | None = None # Load metadata self._reqs_to_load: dict[str, LoadRequestState] = {} # Inverse map: load_event_idx -> req_ids. Keyed by load_event_idx because # the worker reports completions by event index, not request id. self._load_event_to_reqs: dict[int, list[str]] = {} # Store metadata self._lazy_mode = lazy_offload # Lazy mode: use a cursor to track the last scanned block in the GPU free queue. self._cursor: KVCacheBlock | None = None if self._lazy_mode: self._target_free = self._estimate_lazy_target_blocks( kv_cache_config, vllm_config.scheduler_config.max_num_batched_tokens, ) else: self._target_free = 0 self._store_event_to_blocks: dict[int, TransferMeta] = {} # Eager mode only self._reqs_to_store: dict[str, StoreRequestState] = {} self._store_event_to_reqs: dict[int, list[str]] = {} # Event counters self._load_event_counter: int = 0 self._store_event_counter: int = 0 # For TP/PP: track partial store completions across steps. # Events must be reported by all world_size workers before considered complete. self._expected_worker_count = vllm_config.parallel_config.world_size self._store_event_pending_counts: dict[int, int] = {} @staticmethod def _derive_cpu_config( gpu_config: "KVCacheConfig", cpu_capacity_bytes: int ) -> "KVCacheConfig": """Derive a CPU KVCacheConfig from the GPU config. Same kv_cache_groups, num_blocks scaled by CPU/GPU memory ratio.""" # Import here to avoid potential circular imports from vllm.v1.kv_cache_interface import KVCacheConfig as KVCacheConfigCls from vllm.v1.kv_cache_interface import KVCacheTensor assert len(gpu_config.kv_cache_tensors) > 0 gpu_total_bytes = sum(t.size for t in gpu_config.kv_cache_tensors) num_gpu_blocks = gpu_config.num_blocks num_cpu_blocks = max(1, num_gpu_blocks * cpu_capacity_bytes // gpu_total_bytes) # Create CPU kv_cache_tensors mirroring GPU by scaling size proportionally. cpu_tensors = [ KVCacheTensor( size=t.size // num_gpu_blocks * num_cpu_blocks, shared_by=list(t.shared_by), ) for t in gpu_config.kv_cache_tensors ] return KVCacheConfigCls( num_blocks=num_cpu_blocks, kv_cache_tensors=cpu_tensors, kv_cache_groups=gpu_config.kv_cache_groups, ) @staticmethod def _estimate_lazy_target_blocks( kv_cache_config: "KVCacheConfig", max_num_batched_tokens: int ) -> int: """GPU blocks to keep available (free/offloaded) per step in lazy mode.""" WATERMARK_RATIO = 1.0 # Reserve larger space to avoid running out of GPU blocks target = 0 for g in kv_cache_config.kv_cache_groups: spec = g.kv_cache_spec if isinstance(spec, MambaSpec): target += 2 elif isinstance(spec, SlidingWindowSpec): target += cdiv(spec.sliding_window, spec.block_size) + 1 else: target += cdiv(max_num_batched_tokens, spec.block_size) return int(target * (1 + WATERMARK_RATIO)) def bind_gpu_block_pool(self, gpu_block_pool: BlockPool) -> None: """Bind GPU block pool so that we can touch blocks during stores. Called by Scheduler after kv_cache_manager is ready.""" self._gpu_block_pool = gpu_block_pool def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int ) -> tuple[int | None, bool]: """Return (num_new_tokens, is_async) from consecutive CPU cache hits.""" skipped = num_computed_tokens // self.block_size remaining_hashes = request.block_hashes[skipped:] if not remaining_hashes: return 0, False # Must recompute at least the last token, matching the logic in # kv_cache_manager.get_computed_blocks(). max_hit_len = request.num_tokens - 1 - num_computed_tokens if max_hit_len <= 0: return 0, False _, hit_length = self.cpu_coordinator.find_longest_cache_hit( remaining_hashes, max_hit_len ) if hit_length > 0: return hit_length, True return 0, False # TODO(yifan): this API now only matches the suffix part of the prefix cache. A more # general API should scan blocks in both GPU and CPU block pool in a single pass. def update_state_after_alloc( self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int, ) -> None: req_id = request.request_id block_ids_by_group = blocks.get_block_ids() num_groups = len(block_ids_by_group) # Store tracking (eager mode only). Register the request; # block IDs are accumulated from scheduler_output in # _prepare_eager_store_specs via yield_req_data. if not self._lazy_mode and req_id not in self._reqs_to_store: self._reqs_to_store[req_id] = StoreRequestState( request=request, block_ids=tuple([] for _ in range(num_groups)), num_stored_blocks=[0] * num_groups, ) if num_external_tokens == 0: return num_blocks_to_load = num_external_tokens // self.block_size assert num_blocks_to_load > 0 skipped = sum(blk.block_hash is not None for blk in blocks.blocks[self.fa_gidx]) num_computed_tokens = skipped * self.block_size hashes_to_load = request.block_hashes[skipped : skipped + num_blocks_to_load] # Find CPU cached blocks across all groups. max_hit_len = len(hashes_to_load) * self.block_size cpu_hit_blocks, hit_length = self.cpu_coordinator.find_longest_cache_hit( hashes_to_load, max_hit_len ) assert hit_length == num_external_tokens, ( f"Expected {num_external_tokens} hit tokens, got {hit_length}" ) # Build transfer pairs across all groups. total_computed_tokens = num_computed_tokens + num_external_tokens kv_cache_groups = self.cpu_kv_cache_config.kv_cache_groups gpu_block_ids: list[int] = [] cpu_block_ids: list[int] = [] cpu_blocks_to_touch: list[KVCacheBlock] = [] for g in range(num_groups): cpu_blocks_g = cpu_hit_blocks[g] n_ext_g = len(cpu_blocks_g) if n_ext_g == 0: continue # Number of blocks in the computed range for this group. g_block_size = kv_cache_groups[g].kv_cache_spec.block_size n_computed_g = cdiv(total_computed_tokens, g_block_size) # Back-trace: ext blocks sit at the tail of the computed range. gpu_ext_start = n_computed_g - n_ext_g group_gpu_ids = block_ids_by_group[g] for i, cpu_blk in enumerate(cpu_blocks_g): # Skip null blocks (e.g. sliding window or mamba padding). if cpu_blk.is_null: continue gpu_block_ids.append(group_gpu_ids[gpu_ext_start + i]) cpu_block_ids.append(cpu_blk.block_id) cpu_blocks_to_touch.append(cpu_blk) # Touch CPU blocks to prevent eviction during async load. self.cpu_block_pool.touch(cpu_blocks_to_touch) # Touch GPU blocks to prevent freeing during async load assert self._gpu_block_pool is not None self._gpu_block_pool.touch( [self._gpu_block_pool.blocks[bid] for bid in gpu_block_ids] ) assert self._reqs_to_load.get(req_id) is None self._reqs_to_load[req_id] = LoadRequestState( request=request, transfer_meta=TransferMeta(gpu_block_ids, cpu_block_ids) ) def build_connector_meta( self, scheduler_output: SchedulerOutput, ) -> SimpleCPUOffloadMetadata: # --- Stores --- store_event = -1 store_gpu, store_cpu, store_req_ids = self.prepare_store_specs(scheduler_output) if store_gpu: store_event = self._store_event_counter self._store_event_counter += 1 self._store_event_to_blocks[store_event] = TransferMeta( store_gpu, store_cpu ) if store_req_ids: # For eager mode only, track req->blocks mapping self._store_event_to_reqs[store_event] = store_req_ids for req_id in store_req_ids: store_state = self._reqs_to_store.get(req_id) if store_state is not None: store_state.store_events.add(store_event) # --- Loads --- load_event = -1 load_gpu: list[int] = [] load_cpu: list[int] = [] load_req_ids: list[str] = [] for req_id, load_state in self._reqs_to_load.items(): if load_state.load_event is not None: continue assert load_state.transfer_meta is not None load_gpu.extend(load_state.transfer_meta.gpu_block_ids) load_cpu.extend(load_state.transfer_meta.cpu_block_ids) load_req_ids.append(req_id) if load_req_ids: load_event = self._load_event_counter self._load_event_counter += 1 for req_id in load_req_ids: self._reqs_to_load[req_id].load_event = load_event self._load_event_to_reqs[load_event] = load_req_ids result = SimpleCPUOffloadMetadata( load_event=load_event, load_gpu_blocks=load_gpu, load_cpu_blocks=load_cpu, load_event_to_reqs=self._load_event_to_reqs, store_event=store_event, store_gpu_blocks=store_gpu, store_cpu_blocks=store_cpu, need_flush=bool(scheduler_output.preempted_req_ids), ) return result def prepare_store_specs( self, scheduler_output: SchedulerOutput ) -> tuple[list[int], list[int], list[str]]: """Prepare store specs for the store event.""" if self._lazy_mode: return self._prepare_lazy_store_specs() else: return self._prepare_eager_store_specs(scheduler_output) def _prepare_lazy_store_specs( self, ) -> tuple[list[int], list[int], list[str]]: """Single-pass cursor walk: offload cached GPU blocks near eviction. Walks the GPU free queue from the cursor, counting blocks that are free-or-offloaded (safe for the allocator to evict). Stops when target_free blocks are covered or CPU capacity is reached. """ gpu_pool = self._gpu_block_pool if gpu_pool is None or self._target_free <= 0: return [], [], [] free_queue = gpu_pool.free_block_queue cpu_pool = self.cpu_block_pool num_cpu_free = cpu_pool.get_num_free_blocks() # Validate cursor: stale if block was removed from free queue. if self._cursor is not None and self._cursor.ref_cnt > 0: self._cursor = None # Determine start node. if self._cursor is None: node = free_queue.fake_free_list_head.next_free_block else: node = self._cursor.next_free_block tail = free_queue.fake_free_list_tail gpu_ids: list[int] = [] block_hashes: list[bytes] = [] covered = 0 last_visited = self._cursor while ( node is not None and node is not tail and covered < self._target_free and len(gpu_ids) < num_cpu_free ): last_visited = node bhash = node.block_hash if ( bhash is not None and not node.is_null and cpu_pool.cached_block_hash_to_block.get_one_block(bhash) is None ): gpu_ids.append(node.block_id) block_hashes.append(bhash) covered += 1 node = node.next_free_block self._cursor = last_visited # Batch-allocate CPU blocks and stamp hashes. if gpu_ids: cpu_blocks = cpu_pool.get_new_blocks(len(gpu_ids)) cpu_ids = [blk.block_id for blk in cpu_blocks] for cpu_blk, bhash in zip(cpu_blocks, block_hashes): # type: ignore[assignment] cpu_blk._block_hash = bhash # type: ignore[assignment] # Touch GPU blocks to prevent eviction during async copy. gpu_pool.touch([gpu_pool.blocks[bid] for bid in gpu_ids]) else: cpu_ids = [] return gpu_ids, cpu_ids, [] def _prepare_eager_store_specs( self, scheduler_output: SchedulerOutput ) -> tuple[list[int], list[int], list[str]]: """Identify newly computed blocks to offload from scheduler requests. Only considers blocks whose KV data has been **confirmed computed** by the GPU. This means blocks from the current step are NOT stored until the next step. If a request finishes in the same step as its last full block, that block may be missed. (TODO: flush on finish.) Returns: (gpu_block_ids, cpu_block_ids, req_ids) for the store event. """ merged_gpu_block_ids: list[int] = [] merged_cpu_block_ids: list[int] = [] req_ids: list[str] = [] gpu_block_pool = self._gpu_block_pool if gpu_block_pool is None: return [], [], [] cpu_block_pool = self.cpu_block_pool num_free = cpu_block_pool.get_num_free_blocks() kv_cache_groups = self.cpu_kv_cache_config.kv_cache_groups num_groups = len(kv_cache_groups) gpu_blocks_this_step: set[int] = set() for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output): state = self._reqs_to_store.get(req_id) if state is None or state.finished: continue # Accumulate new block IDs. if preempted: state.block_ids = tuple([] for _ in range(num_groups)) state.num_stored_blocks = [0] * num_groups if new_block_id_groups: for g in range(min(num_groups, len(new_block_id_groups))): if new_block_id_groups[g] is not None: state.block_ids[g].extend(new_block_id_groups[g]) num_new_tokens = scheduler_output.num_scheduled_tokens.get(req_id, 0) if num_new_tokens == 0: continue block_ids_by_group = state.block_ids if not block_ids_by_group: continue # --- Phase 1: Scan blocks, classify as cached vs to-store --- gpu_block_ids: list[int] = [] block_hashes_to_store: list[bytes] = [] advanced_per_group: list[int] = [0] * num_groups out_of_space = False # Confirmed tokens: KV data written and visible to all streams. req = state.request confirmed_tokens = req.num_computed_tokens - req.num_output_placeholders for g in range(num_groups): # FIXME (yifan): handle CPU cache eviction, where # num_stored_blocks can be stale and omit evicted blocks in # the middle of the request. already_stored_g = state.num_stored_blocks[g] group_gpu_ids = block_ids_by_group[g] # Cap to blocks with confirmed KV data. g_block_size = kv_cache_groups[g].kv_cache_spec.block_size ready_blocks_g = confirmed_tokens // g_block_size scannable = group_gpu_ids[already_stored_g:ready_blocks_g] for gpu_block_id in scannable: gpu_block = gpu_block_pool.blocks[gpu_block_id] if gpu_block.is_null: advanced_per_group[g] += 1 continue bhash_with_group = gpu_block.block_hash if bhash_with_group is None: break # Check if this group's data is already scheduled for store # in this step or already cached in CPU. if ( gpu_block_id in gpu_blocks_this_step or cpu_block_pool.cached_block_hash_to_block.get_one_block( bhash_with_group ) is not None ): advanced_per_group[g] += 1 continue if num_free <= 0: out_of_space = True break num_free -= 1 gpu_block_ids.append(gpu_block_id) block_hashes_to_store.append(bhash_with_group) advanced_per_group[g] += 1 if out_of_space: break # --- Phase 2: Batch allocate CPU blocks and stamp hashes --- n_to_alloc = len(gpu_block_ids) if n_to_alloc > 0: cpu_blocks_alloc = cpu_block_pool.get_new_blocks(n_to_alloc) cpu_block_ids = [blk.block_id for blk in cpu_blocks_alloc] for cpu_blk, bhash in zip(cpu_blocks_alloc, block_hashes_to_store): cpu_blk._block_hash = bhash # type: ignore[assignment] else: cpu_block_ids = [] if cpu_block_ids: req_ids.append(req_id) merged_gpu_block_ids.extend(gpu_block_ids) merged_cpu_block_ids.extend(cpu_block_ids) gpu_blocks_this_step.update(gpu_block_ids) # Touch GPU blocks to prevent freeing during async copy gpu_block_pool.touch( [gpu_block_pool.blocks[bid] for bid in gpu_block_ids] ) logger.debug( "Request %s: Scheduling store of %d blocks to CPU (%d groups)", req_id, len(cpu_block_ids), num_groups, ) # Advance per-group cursors (includes cached hits + newly stored) for g in range(num_groups): state.num_stored_blocks[g] += advanced_per_group[g] return merged_gpu_block_ids, merged_cpu_block_ids, req_ids def update_connector_output(self, connector_output: KVConnectorOutput) -> None: """Handle async transfer completions from worker. Load completions arrive via finished_recving (real req_ids). Store completions arrive via kv_connector_worker_meta as per-event worker counts. We accumulate across steps and process a store event only when all workers have reported completion. """ # --- Load completions --- for req_id in list(connector_output.finished_recving or []): self._cleanup_load_request(req_id) # --- Store completions --- meta = connector_output.kv_connector_worker_meta if not isinstance(meta, SimpleCPUOffloadWorkerMetadata): return for event_idx, count in meta.completed_store_events.items(): total = self._store_event_pending_counts.get(event_idx, 0) + count if total >= self._expected_worker_count: self._store_event_pending_counts.pop(event_idx, None) self._process_store_event(event_idx) else: self._store_event_pending_counts[event_idx] = total def _process_store_event(self, event_idx: int) -> None: """Process a fully-completed store event.""" transfer = self._store_event_to_blocks.pop(event_idx) self._process_store_completion(transfer.gpu_block_ids, transfer.cpu_block_ids) logger.debug( "Store event %d completed: cached %d blocks to CPU", event_idx, len(transfer.cpu_block_ids), ) # Eager only: update per-req state if not self._lazy_mode: for req_id in self._store_event_to_reqs.pop(event_idx, []): state = self._reqs_to_store.get(req_id) if state is None: continue state.store_events.discard(event_idx) if state.finished and not state.store_events: self._cleanup_store_request(req_id) def _process_store_completion( self, gpu_block_ids: list[int], cpu_block_ids: list[int] ) -> None: """Cache CPU blocks per-group and release GPU refs. Block hashes were stamped on CPU blocks at allocation time (in ``_prepare_*_store_specs``). Here we just register them in the cache map so they become discoverable by the load path. """ assert len(cpu_block_ids) == len(gpu_block_ids) cpu_blocks = [self.cpu_block_pool.blocks[bid] for bid in cpu_block_ids] for cpu_block in cpu_blocks: bhash = cpu_block.block_hash assert bhash is not None self.cpu_block_pool.cached_block_hash_to_block.insert(bhash, cpu_block) # Free CPU and GPU blocks' ref counts to turn them into prefix cache self.cpu_block_pool.free_blocks(cpu_blocks) assert self._gpu_block_pool is not None self._gpu_block_pool.free_blocks( self._gpu_block_pool.blocks[bid] for bid in gpu_block_ids ) def has_pending_stores(self) -> bool: """Return True if there are in-flight store transfers.""" return bool(self._store_event_to_blocks) def request_finished( self, request: "Request", block_ids: list[int], ) -> tuple[bool, dict[str, Any] | None]: """Always returns (False, None). GPU blocks are protected by ref_cnt, so the scheduler can free blocks immediately.""" req_id = request.request_id # Handle load: defer cleanup if load is in-flight load_state = self._reqs_to_load.get(req_id) if load_state is not None: if load_state.load_event is not None: load_state.finished = True # Defer: load in-flight else: self._cleanup_load_request(req_id) # Handle store (eager mode only): defer cleanup if stores in-flight if not self._lazy_mode: store_state = self._reqs_to_store.get(req_id) if store_state is not None: if store_state.store_events: store_state.finished = True # Defer: stores in-flight else: self._cleanup_store_request(req_id) return False, None def request_finished_all_groups( self, request: "Request", block_ids: tuple[list[int], ...], ) -> tuple[bool, dict[str, Any] | None]: return self.request_finished(request, block_ids=[]) def _cleanup_load_request(self, req_id: str) -> None: """Release all load resources for a request. Shared between request_finished() and update_connector_output() paths. Removes the request from _reqs_to_load, cleans up event mappings, and frees CPU/GPU touch refs. """ state = self._reqs_to_load.pop(req_id, None) if state is None: return # Remove from load event mapping (only this req, not whole event) if state.load_event is not None: reqs = self._load_event_to_reqs.get(state.load_event) if reqs is not None: with contextlib.suppress(ValueError): reqs.remove(req_id) if not reqs: self._load_event_to_reqs.pop(state.load_event, None) if state.transfer_meta is not None: # Free CPU touch refs self.cpu_block_pool.free_blocks( self.cpu_block_pool.blocks[bid] for bid in state.transfer_meta.cpu_block_ids ) # Free GPU touch refs assert self._gpu_block_pool is not None self._gpu_block_pool.free_blocks( self._gpu_block_pool.blocks[bid] for bid in state.transfer_meta.gpu_block_ids ) def _cleanup_store_request(self, req_id: str) -> None: """Release store metadata for a request. Metadata-only cleanup but no block freeing. Job completion handles block caching and GPU ref freeing via _process_store_completion(). """ state = self._reqs_to_store.pop(req_id, None) if state is None: return for event_idx in list(state.store_events): if (reqs := self._store_event_to_reqs.get(event_idx)) is not None: with contextlib.suppress(ValueError): reqs.remove(req_id) if not reqs: self._store_event_to_reqs.pop(event_idx, None) state.store_events.clear() def take_events(self) -> Iterable[KVCacheEvent]: return self.cpu_block_pool.take_events()