# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import numpy as np
import torch

from vllm.distributed import get_dcp_group, get_pcp_group
from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.cp_utils import get_total_cp_world_size

logger = init_logger(__name__)


class BlockTable:
    """Block table for a single KV cache group.

    Holds one row of block IDs per request row in a paired CPU/GPU buffer,
    a host-side count of how many entries of each row are in use, and the
    slot-mapping buffer that `compute_slot_mapping` fills on the GPU.
    """

    def __init__(
        self,
        block_size: int,
        max_num_reqs: int,
        max_num_blocks_per_req: int,
        max_num_batched_tokens: int,
        pin_memory: bool,
        device: torch.device,
        kernel_block_size: int,
        cp_kv_cache_interleave_size: int,
    ) -> None:
        """
        Args:
            block_size: Block size used for KV cache memory allocation
            max_num_reqs: Maximum number of concurrent requests supported.
            max_num_blocks_per_req: Maximum number of blocks per request,
                counted in allocation blocks. It is multiplied by the number
                of kernel blocks per allocation block before the table is
                sized.
            max_num_batched_tokens: Maximum number of tokens in a batch.
            pin_memory: Whether to pin memory for faster GPU transfers.
            device: Target device for the block table.
            kernel_block_size: The block_size of underlying attention kernel.
                Will be the same as `block_size` if `block_size` is supported
                by the attention kernel.
            cp_kv_cache_interleave_size: Number of consecutive token offsets
                that the slot-mapping kernel assigns to the same
                context-parallel rank before moving on to the next rank.

        Raises:
            ValueError: If `kernel_block_size` differs from `block_size` and
                does not divide it evenly.
        """
        self.max_num_reqs = max_num_reqs
        self.max_num_batched_tokens = max_num_batched_tokens
        self.pin_memory = pin_memory
        self.device = device

        if kernel_block_size == block_size:
            # Standard case: allocation and computation use same block size
            # No block splitting needed, direct mapping
            self.block_size = block_size
            self.blocks_per_kv_block = 1
            self.use_hybrid_blocks = False
        else:
            # Hybrid case: allocation block size differs from kernel block size
            # Memory blocks are subdivided to match kernel requirements
            # Example: 32-token memory blocks with 16-token kernel blocks
            # → Each memory block corresponds to 2 kernel blocks
            if block_size % kernel_block_size != 0:
                raise ValueError(
                    f"kernel_block_size {kernel_block_size} must divide "
                    f"kv_manager_block_size size {block_size} evenly"
                )

            # From here on `self.block_size` is the *kernel* block size.
            self.block_size = kernel_block_size
            self.blocks_per_kv_block = block_size // kernel_block_size
            self.use_hybrid_blocks = True

        # Row width of the table, in kernel blocks.
        self.max_num_blocks_per_req = max_num_blocks_per_req * self.blocks_per_kv_block

        self.block_table = self._make_buffer(
            self.max_num_reqs, self.max_num_blocks_per_req, dtype=torch.int32
        )
        # Number of valid entries in each row of `block_table` (host side only).
        self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

        self.slot_mapping = self._make_buffer(
            self.max_num_batched_tokens, dtype=torch.int64
        )

        if self.use_hybrid_blocks:
            # Precomputed [[0, 1, ..., blocks_per_kv_block - 1]] used by
            # `map_to_kernel_blocks` to expand each allocation block ID.
            self._kernel_block_arange = np.arange(0, self.blocks_per_kv_block).reshape(
                1, -1
            )
        else:
            self._kernel_block_arange = None

        try:
            self.pcp_world_size = get_pcp_group().world_size
            self.pcp_rank = get_pcp_group().rank_in_group
        except AssertionError:
            # PCP might not be initialized in testing
            self.pcp_world_size = 1
            self.pcp_rank = 0
        try:
            self.dcp_world_size = get_dcp_group().world_size
            self.dcp_rank = get_dcp_group().rank_in_group
        except AssertionError:
            # DCP might not be initialized in testing
            self.dcp_world_size = 1
            self.dcp_rank = 0
        self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size

    def append_row(
        self,
        block_ids: list[int],
        row_idx: int,
    ) -> None:
        """Append `block_ids` after the entries already recorded in a row.

        Does nothing when `block_ids` is empty. In hybrid mode the IDs are
        first expanded to kernel block IDs, so more than `len(block_ids)`
        entries are written.

        Args:
            block_ids: Allocation block IDs to append.
            row_idx: Row of the table to extend.
        """
        if not block_ids:
            return

        if self.use_hybrid_blocks:
            block_ids = self.map_to_kernel_blocks(
                np.array(block_ids), self.blocks_per_kv_block, self._kernel_block_arange
            )

        num_blocks = len(block_ids)
        start = self.num_blocks_per_row[row_idx]
        # NOTE(review): there is no explicit check that start + num_blocks
        # stays within max_num_blocks_per_req, and the count is bumped before
        # the write below - confirm callers never exceed the row width.
        self.num_blocks_per_row[row_idx] += num_blocks
        self.block_table.np[row_idx, start : start + num_blocks] = block_ids

    def add_row(self, block_ids: list[int], row_idx: int) -> None:
        """Overwrite a row from its start with `block_ids`.

        Only the row's count is reset before appending; entries beyond the
        new length are left as they were.
        """
        self.num_blocks_per_row[row_idx] = 0
        self.append_row(block_ids, row_idx)

    def clear_row(self, row_idx: int) -> None:
        """Zero the recorded entries of a row and reset its count."""
        num_blocks = self.num_blocks_per_row[row_idx]
        if num_blocks > 0:
            self.block_table.np[row_idx, :num_blocks] = 0
            self.num_blocks_per_row[row_idx] = 0

    def move_row(self, src: int, tgt: int) -> None:
        """Copy the recorded entries and count of row `src` into row `tgt`.

        Row `src` itself is not modified.
        """
        num_blocks = self.num_blocks_per_row[src]
        block_table_np = self.block_table.np
        block_table_np[tgt, :num_blocks] = block_table_np[src, :num_blocks]
        self.num_blocks_per_row[tgt] = num_blocks

    def swap_row(self, src: int, tgt: int) -> None:
        """Exchange the counts and the full contents of rows `src` and `tgt`."""
        # Indexing with a list on the right-hand side yields a copy, so both
        # rows are read before either one is overwritten.
        src_tgt, tgt_src = [src, tgt], [tgt, src]
        self.num_blocks_per_row[src_tgt] = self.num_blocks_per_row[tgt_src]
        self.block_table.np[src_tgt] = self.block_table.np[tgt_src]

    def compute_slot_mapping(
        self,
        num_reqs: int,
        query_start_loc: torch.Tensor,
        positions: torch.Tensor,
    ) -> None:
        """Fill `self.slot_mapping.gpu` for the current batch on the GPU.

        Launches `_compute_slot_mapping_kernel` with `num_reqs + 1` programs:
        one per request, plus a final one that pads the unused tail of the
        slot-mapping buffer with `PAD_SLOT_ID`.

        Args:
            num_reqs: Number of requests in the batch.
            query_start_loc: Per-request token offsets; the kernel reads
                entries `i` and `i + 1` for request `i`.
            positions: Token positions; its first dimension is taken as the
                number of tokens in the batch.
        """
        # The kernel reads the GPU copy of the block table - presumably
        # `commit_block_table` must have been called first; confirm with the
        # caller.
        num_tokens = positions.shape[0]
        total_cp_world_size = self.pcp_world_size * self.dcp_world_size
        # Flattened rank over the (pcp, dcp) grid, with dcp varying fastest.
        total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
        _compute_slot_mapping_kernel[(num_reqs + 1,)](
            num_tokens,
            self.max_num_batched_tokens,
            query_start_loc,
            positions,
            self.block_table.gpu,
            self.block_table.gpu.stride(0),
            self.block_size,
            self.slot_mapping.gpu,
            TOTAL_CP_WORLD_SIZE=total_cp_world_size,
            TOTAL_CP_RANK=total_cp_rank,
            CP_KV_CACHE_INTERLEAVE_SIZE=self.cp_kv_cache_interleave_size,
            PAD_ID=PAD_SLOT_ID,
            BLOCK_SIZE=1024,
        )

    def commit_block_table(self, num_reqs: int) -> None:
        """Copy the host block table to the GPU via `CpuGpuBuffer.copy_to_gpu`.

        Presumably only the first `num_reqs` rows are transferred - confirm
        against `CpuGpuBuffer`.
        """
        self.block_table.copy_to_gpu(num_reqs)

    def clear(self) -> None:
        """Zero the whole block table on both the GPU and the CPU."""
        # NOTE(review): `num_blocks_per_row` is not reset here - confirm that
        # callers always repopulate rows through `add_row` afterwards.
        self.block_table.gpu.fill_(0)
        self.block_table.cpu.fill_(0)

    @staticmethod
    def map_to_kernel_blocks(
        kv_manager_block_ids: np.ndarray,
        blocks_per_kv_block: int,
        kernel_block_arange: np.ndarray,
    ) -> np.ndarray:
        """Convert KV-manager block IDs to kernel block IDs.

        Each KV-manager block ID `b` expands to the consecutive kernel block
        IDs `b * blocks_per_kv_block + k` for every `k` in
        `kernel_block_arange`.

        Args:
            kv_manager_block_ids: Block IDs as issued by the KV manager.
            blocks_per_kv_block: Number of kernel blocks per KV-manager block.
            kernel_block_arange: Offsets added to each scaled ID; `__init__`
                builds it as `np.arange(blocks_per_kv_block).reshape(1, -1)`.
                Not used when `blocks_per_kv_block == 1`.

        Returns:
            A flat array of kernel block IDs. When `blocks_per_kv_block == 1`
            the input array itself is returned, not a copy.

        Example:
            # KV-manager block size: 32 tokens
            # Kernel block size: 16 tokens
            # blocks_per_kv_block = 2
            >>> BlockTable.map_to_kernel_blocks(
            ...     np.array([0, 1, 2]), 2, np.arange(2).reshape(1, -1)
            ... )
            array([0, 1, 2, 3, 4, 5])

            # Each kv_manager_block_id maps to 2 kernel block ids:
            # kv_manager_block_id 0 → kernel block id [0, 1]
            # kv_manager_block_id 1 → kernel block id [2, 3]
            # kv_manager_block_id 2 → kernel block id [4, 5]
        """
        if blocks_per_kv_block == 1:
            return kv_manager_block_ids

        # Column vector of scaled IDs plus the row vector of offsets
        # broadcasts to one row of kernel block IDs per KV-manager block.
        kernel_block_ids = (
            kv_manager_block_ids.reshape(-1, 1) * blocks_per_kv_block
            + kernel_block_arange
        )

        return kernel_block_ids.reshape(-1)

    def get_device_tensor(self, num_reqs: int) -> torch.Tensor:
        """Returns the first `num_reqs` rows of the device block table."""
        return self.block_table.gpu[:num_reqs]

    def get_cpu_tensor(self) -> torch.Tensor:
        """Returns the CPU tensor of the block table."""
        return self.block_table.cpu

    def get_numpy_array(self) -> np.ndarray:
        """Returns the numpy array of the block table."""
        return self.block_table.np

    def _make_buffer(
        self, *size: int | torch.SymInt, dtype: torch.dtype
    ) -> CpuGpuBuffer:
        """Create a `CpuGpuBuffer` of the given shape and dtype using this
        table's `device` and `pin_memory` settings."""
        return CpuGpuBuffer(
            *size, dtype=dtype, device=self.device, pin_memory=self.pin_memory
        )


class MultiGroupBlockTable:
    """The BlockTables for each KV cache group."""

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        pin_memory: bool,
        device: torch.device,
        block_sizes: list[int],
        kernel_block_sizes: list[int],
        max_num_blocks: list[int] | None = None,
        cp_kv_cache_interleave_size: int = 1,
    ) -> None:
        """Build one `BlockTable` per entry of `block_sizes`.

        Args:
            max_num_reqs: Maximum number of concurrent requests supported.
            max_model_len: Maximum sequence length; only used to derive
                `max_num_blocks` when that argument is None.
            max_num_batched_tokens: Maximum number of tokens in a batch.
            pin_memory: Whether to pin memory for faster GPU transfers.
            device: Target device for the block tables.
            block_sizes: Allocation block size of each KV cache group.
            kernel_block_sizes: Kernel block size of each KV cache group;
                must have the same length as `block_sizes`.
            max_num_blocks: Optional per-group maximum number of blocks per
                request; must have the same length as `block_sizes`.
            cp_kv_cache_interleave_size: Forwarded unchanged to every
                `BlockTable`.

        Raises:
            ValueError: If `kernel_block_sizes` or `max_num_blocks` does not
                have the same length as `block_sizes`.
        """
        if len(kernel_block_sizes) != len(block_sizes):
            raise ValueError(
                f"kernel_block_sizes length ({len(kernel_block_sizes)}) "
                f"must match block_sizes length ({len(block_sizes)})"
            )
        if max_num_blocks is None:
            # Note(hc): each context-parallel rank stores only about
            # (max_model_len // total_cp_world_size) tokens in its KV cache,
            # so the block size used to compute max_num_blocks_per_req
            # is multiplied by the total CP world size.
            total_cp_world_size = get_total_cp_world_size()
            max_num_blocks = [
                cdiv(max_model_len, block_size * total_cp_world_size)
                for block_size in block_sizes
            ]

        if len(max_num_blocks) != len(block_sizes):
            raise ValueError(
                f"max_num_blocks length ({len(max_num_blocks)}) "
                f"must match block_sizes length ({len(block_sizes)})"
            )

        self.block_tables = [
            BlockTable(
                block_size,
                max_num_reqs,
                max_num_blocks_per_req,
                max_num_batched_tokens,
                pin_memory,
                device,
                kernel_block_size,
                cp_kv_cache_interleave_size,
            )
            for block_size, kernel_block_size, max_num_blocks_per_req in zip(
                block_sizes, kernel_block_sizes, max_num_blocks
            )
        ]

    def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
        """Append `block_ids[i]` to row `row_idx` of the i-th block table."""
        for i, block_table in enumerate(self.block_tables):
            block_table.append_row(block_ids[i], row_idx)

    def add_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:
        """Overwrite row `row_idx` of the i-th block table with `block_ids[i]`."""
        for i, block_table in enumerate(self.block_tables):
            block_table.add_row(block_ids[i], row_idx)

    def clear_row(self, row_idx: int) -> None:
        """Clear row `row_idx` in every block table."""
        for block_table in self.block_tables:
            block_table.clear_row(row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        """Copy row `src` into row `tgt` in every block table."""
        for block_table in self.block_tables:
            block_table.move_row(src, tgt)

    def swap_row(self, src: int, tgt: int) -> None:
        """Swap rows `src` and `tgt` in every block table."""
        for block_table in self.block_tables:
            block_table.swap_row(src, tgt)

    def compute_slot_mapping(
        self,
        num_reqs: int,
        query_start_loc: torch.Tensor,
        positions: torch.Tensor,
    ) -> None:
        """Compute the slot mapping of every block table from the same
        `query_start_loc` and `positions`."""
        for block_table in self.block_tables:
            block_table.compute_slot_mapping(num_reqs, query_start_loc, positions)

    def commit_block_table(self, num_reqs: int) -> None:
        """Commit every block table to the GPU."""
        for block_table in self.block_tables:
            block_table.commit_block_table(num_reqs)

    def clear(self) -> None:
        """Zero every block table."""
        for block_table in self.block_tables:
            block_table.clear()

    def __getitem__(self, idx: int) -> "BlockTable":
        """Returns the BlockTable for the i-th KV cache group."""
        return self.block_tables[idx]


@triton.jit
def _compute_slot_mapping_kernel(
    num_tokens,
    max_num_tokens,
    query_start_loc_ptr,  # [num_reqs + 1], int32
    positions_ptr,  # [num_tokens], int64
    block_table_ptr,  # [max_num_reqs, max_num_blocks_per_req], int32 (flat)
    block_table_stride,  # max_num_blocks_per_req
    block_size,
    slot_mapping_ptr,  # [max_num_tokens], int64
    TOTAL_CP_WORLD_SIZE: tl.constexpr,
    TOTAL_CP_RANK: tl.constexpr,
    CP_KV_CACHE_INTERLEAVE_SIZE: tl.constexpr,
    PAD_ID: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Write the KV-cache slot of every token into `slot_mapping_ptr`.

    Program `i` (for every program except the last) handles the tokens in
    `[query_start_loc[i], query_start_loc[i + 1])`, `BLOCK_SIZE` at a time.
    The last program instead fills `slot_mapping[num_tokens:max_num_tokens]`
    with `PAD_ID`.

    A token whose position does not belong to `TOTAL_CP_RANK` gets `PAD_ID`;
    otherwise its slot is `block_number * block_size + local_offset`.
    """
    req_idx = tl.program_id(0)

    if req_idx == tl.num_programs(0) - 1:
        # Pad remaining slots for CUDA graph compatibility.
        for i in range(num_tokens, max_num_tokens, BLOCK_SIZE):
            offsets = i + tl.arange(0, BLOCK_SIZE)
            tl.store(
                slot_mapping_ptr + offsets,
                PAD_ID,
                mask=offsets < max_num_tokens,
            )
        return

    start_idx = tl.load(query_start_loc_ptr + req_idx).to(tl.int64)
    end_idx = tl.load(query_start_loc_ptr + req_idx + 1).to(tl.int64)

    # One block-table entry covers `block_size` slots on each of the
    # TOTAL_CP_WORLD_SIZE ranks, i.e. this many consecutive positions.
    virtual_block_size = block_size * TOTAL_CP_WORLD_SIZE
    row_offset = req_idx * block_table_stride
    for i in range(start_idx, end_idx, BLOCK_SIZE):
        offsets = i + tl.arange(0, BLOCK_SIZE)
        mask = offsets < end_idx
        pos = tl.load(positions_ptr + offsets, mask=mask, other=0)
        block_indices = pos // virtual_block_size
        # This load is unmasked: masked-out lanes have pos == 0 and therefore
        # read entry 0 of the row; their result is discarded by the masked
        # store below.
        block_numbers = tl.load(block_table_ptr + row_offset + block_indices).to(
            tl.int64
        )
        virtual_block_offsets = pos - block_indices * virtual_block_size
        # Offsets are dealt to ranks round-robin in runs of
        # CP_KV_CACHE_INTERLEAVE_SIZE tokens.
        is_local = (
            virtual_block_offsets // CP_KV_CACHE_INTERLEAVE_SIZE
        ) % TOTAL_CP_WORLD_SIZE == TOTAL_CP_RANK
        # Offset inside this rank's block: the number of complete rounds over
        # all ranks, times the run length, plus the position inside the run.
        local_block_offsets = (
            virtual_block_offsets // (TOTAL_CP_WORLD_SIZE * CP_KV_CACHE_INTERLEAVE_SIZE)
        ) * CP_KV_CACHE_INTERLEAVE_SIZE + (
            virtual_block_offsets % CP_KV_CACHE_INTERLEAVE_SIZE
        )
        slot_ids = block_numbers * block_size + local_block_offsets
        slot_ids = tl.where(is_local, slot_ids, PAD_ID)
        tl.store(slot_mapping_ptr + offsets, slot_ids, mask=mask)