""" Copyright 2023-2024 SGLang Team Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ from __future__ import annotations from dataclasses import dataclass from sglang.srt.configs.mamba_utils import Mamba2CacheParams from sglang.srt.layers.attention.nsa import index_buf_accessor from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter """ Memory pool. SGLang has two levels of memory pool. ReqToTokenPool maps a request to its token locations. TokenToKVPoolAllocator manages the indices to kv cache data. KVCache actually holds the physical kv cache. """ import abc import logging from contextlib import nullcontext from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import numpy as np import torch import triton import triton.language as tl from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2 if TYPE_CHECKING: from sglang.srt.managers.cache_controller import LayerDoneCounter logger = logging.getLogger(__name__) GB = 1024 * 1024 * 1024 _is_cuda = is_cuda() _is_npu = is_npu() if _is_npu: import torch_npu def get_tensor_size_bytes(t: torch.Tensor): return np.prod(t.shape) * t.dtype.itemsize class ReqToTokenPool: """A memory pool that maps a request to its token locations.""" def __init__( self, size: int, max_context_len: int, device: str, enable_memory_saver: bool, ): memory_saver_adapter = TorchMemorySaverAdapter.create( enable=enable_memory_saver ) self.size = size self.max_context_len = max_context_len self.device = device with memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE): self.req_to_token = torch.zeros( (size, max_context_len), dtype=torch.int32, device=device ) self.free_slots = list(range(size)) def write(self, indices, values): self.req_to_token[indices] = values def available_size(self): return len(self.free_slots) def alloc(self, need_size: int) -> List[int]: if need_size > len(self.free_slots): return None select_index = self.free_slots[:need_size] self.free_slots = self.free_slots[need_size:] return select_index def free(self, free_index: Union[int, List[int]]): if isinstance(free_index, (int,)): self.free_slots.append(free_index) else: self.free_slots.extend(free_index) def clear(self): self.free_slots = list(range(self.size)) class MambaPool: @dataclass(frozen=True, kw_only=True) class State: conv: torch.Tensor temporal: torch.Tensor def at_layer_idx(self, layer: int): return type(self)(**{k: v[layer] for k, v in vars(self).items()}) def mem_usage_bytes(self): return sum(get_tensor_size_bytes(t) for t in vars(self).values()) @dataclass(frozen=True, kw_only=True) class SpeculativeState(State): intermediate_ssm: torch.Tensor intermediate_conv_window: torch.Tensor def __init__( self, *, size: int, cache_params: "Mamba2CacheParams", device: str, speculative_num_draft_tokens: Optional[int] = None, ): conv_state_shape = cache_params.shape.conv temporal_state_shape = cache_params.shape.temporal conv_dtype = cache_params.dtype.conv ssm_dtype = cache_params.dtype.temporal num_mamba_layers = len(cache_params.layers) # assume conv_state = (dim, state_len) assert conv_state_shape[0] > conv_state_shape[1] conv_state = torch.zeros( size=(num_mamba_layers, size + 1) + conv_state_shape, dtype=conv_dtype, device=device, ) temporal_state = torch.zeros( size=(num_mamba_layers, size + 1) + temporal_state_shape, dtype=ssm_dtype, device=device, ) if speculative_num_draft_tokens is not None: # Cache intermediate SSM states per draft token during target verify # Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V] intermediate_ssm_state_cache = torch.zeros( size=( num_mamba_layers, size + 1, speculative_num_draft_tokens, temporal_state_shape[0], temporal_state_shape[1], temporal_state_shape[2], ), dtype=ssm_dtype, device="cuda", ) # Cache intermediate conv windows (last K-1 inputs) per draft token during target verify # Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1] intermediate_conv_window_cache = torch.zeros( size=( num_mamba_layers, size + 1, speculative_num_draft_tokens, conv_state_shape[0], conv_state_shape[1], ), dtype=conv_dtype, device="cuda", ) self.mamba_cache = self.SpeculativeState( conv=conv_state, temporal=temporal_state, intermediate_ssm=intermediate_ssm_state_cache, intermediate_conv_window=intermediate_conv_window_cache, ) logger.info( f"Mamba Cache is allocated. " f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, " f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB " f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_state_cache) / GB:.2f}GB " f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB " ) else: self.mamba_cache = self.State(conv=conv_state, temporal=temporal_state) logger.info( f"Mamba Cache is allocated. " f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, " f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB " ) self.size = size self.free_slots = list(range(size)) self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB def get_speculative_mamba2_params_all_layers(self) -> SpeculativeState: assert isinstance(self.mamba_cache, self.SpeculativeState) return self.mamba_cache def mamba2_layer_cache(self, layer_id: int): return self.mamba_cache.at_layer_idx(layer_id) def available_size(self): return len(self.free_slots) def alloc(self, need_size: int) -> Optional[List[int]]: if need_size > len(self.free_slots): return None select_index = self.free_slots[:need_size] self.free_slots = self.free_slots[need_size:] return select_index def free(self, free_index: Union[int, List[int]]): if isinstance(free_index, (int,)): self.free_slots.append(free_index) else: self.free_slots.extend(free_index) self.mamba_cache.conv[:, free_index] = self.mamba_cache.temporal[ :, free_index ] = 0 def clear(self): self.free_slots = list(range(self.size)) class HybridReqToTokenPool(ReqToTokenPool): """A memory pool that maps a request to its token locations.""" def __init__( self, *, size: int, max_context_len: int, device: str, enable_memory_saver: bool, cache_params: "Mamba2CacheParams", speculative_num_draft_tokens: int = None, ): super().__init__( size=size, max_context_len=max_context_len, device=device, enable_memory_saver=enable_memory_saver, ) self.mamba_pool = MambaPool( size=size, cache_params=cache_params, device=device, speculative_num_draft_tokens=speculative_num_draft_tokens, ) self.mamba_map = {layer_id: i for i, layer_id in enumerate(cache_params.layers)} self.device = device self.req_index_to_mamba_index_mapping: torch.Tensor = torch.zeros( size, dtype=torch.int32, device=self.device ) self.rid_to_mamba_index_mapping: Dict[str, int] = {} self.mamba_index_to_rid_mapping: Dict[int, str] = {} # For chunk prefill req, we do not need to allocate mamba cache, # We could use allocated mamba cache instead. def alloc( self, need_size: int, reqs: Optional[List["Req"]] = None ) -> Optional[List[int]]: select_index = super().alloc(need_size) if select_index == None: return None mamba_index = [] for req in reqs: rid = req.rid if rid in self.rid_to_mamba_index_mapping: mid = self.rid_to_mamba_index_mapping[rid] elif (mid := self.mamba_pool.alloc(1)) is not None: mid = mid[0] self.rid_to_mamba_index_mapping[rid] = mid self.mamba_index_to_rid_mapping[mid] = rid mamba_index.append(mid) assert len(select_index) == len( mamba_index ), f"Not enough space for mamba cache, try to increase --max-mamba-cache-size." self.req_index_to_mamba_index_mapping[select_index] = torch.tensor( mamba_index, dtype=torch.int32, device=self.device ) return select_index def get_mamba_indices(self, req_indices: torch.Tensor) -> torch.Tensor: return self.req_index_to_mamba_index_mapping[req_indices] def mamba2_layer_cache(self, layer_id: int): assert layer_id in self.mamba_map return self.mamba_pool.mamba2_layer_cache(self.mamba_map[layer_id]) def get_speculative_mamba2_params_all_layers(self) -> MambaPool.SpeculativeState: return self.mamba_pool.get_speculative_mamba2_params_all_layers() # For chunk prefill, we can not free mamba cache, we need use it in the future def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True): super().free(free_index) if free_mamba_cache: mamba_index = self.req_index_to_mamba_index_mapping[free_index] mamba_index_list = mamba_index.tolist() if isinstance(mamba_index_list, int): mamba_index_list = [mamba_index_list] self.mamba_pool.free(mamba_index_list) for mid in mamba_index_list: rid = self.mamba_index_to_rid_mapping[mid] self.mamba_index_to_rid_mapping.pop(mid) self.rid_to_mamba_index_mapping.pop(rid) def clear(self): super().clear() self.mamba_pool.clear() class KVCache(abc.ABC): @abc.abstractmethod def __init__( self, size: int, page_size: int, dtype: torch.dtype, layer_num: int, device: str, enable_memory_saver: bool, start_layer: Optional[int] = None, end_layer: Optional[int] = None, ): self.size = size self.page_size = page_size self.dtype = dtype self.device = device if dtype in (torch.float8_e5m2, torch.float8_e4m3fn): # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2 self.store_dtype = torch.uint8 else: self.store_dtype = dtype self.layer_num = layer_num self.start_layer = start_layer or 0 self.end_layer = end_layer or layer_num - 1 self.memory_saver_adapter = TorchMemorySaverAdapter.create( enable=enable_memory_saver ) self.mem_usage = 0 # used for chunked cpu-offloading self.cpu_offloading_chunk_size = 8192 # default state for optional layer-wise transfer control self.layer_transfer_counter = None def _finalize_allocation_log(self, num_tokens: int): """Common logging and mem_usage computation for KV cache allocation. Supports both tuple (K, V) size returns and single KV size returns. """ kv_size_bytes = self.get_kv_size_bytes() if isinstance(kv_size_bytes, tuple): k_size, v_size = kv_size_bytes k_size_GB = k_size / GB v_size_GB = v_size / GB logger.info( f"KV Cache is allocated. #tokens: {num_tokens}, K size: {k_size_GB:.2f} GB, V size: {v_size_GB:.2f} GB" ) self.mem_usage = k_size_GB + v_size_GB else: kv_size_GB = kv_size_bytes / GB logger.info( f"KV Cache is allocated. #tokens: {num_tokens}, KV size: {kv_size_GB:.2f} GB" ) self.mem_usage = kv_size_GB @abc.abstractmethod def get_key_buffer(self, layer_id: int) -> torch.Tensor: raise NotImplementedError() @abc.abstractmethod def get_value_buffer(self, layer_id: int) -> torch.Tensor: raise NotImplementedError() @abc.abstractmethod def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError() @abc.abstractmethod def set_kv_buffer( self, layer: RadixAttention, loc: torch.Tensor, cache_k: torch.Tensor, cache_v: torch.Tensor, ) -> None: raise NotImplementedError() def register_layer_transfer_counter(self, layer_transfer_counter: LayerDoneCounter): self.layer_transfer_counter = layer_transfer_counter def get_cpu_copy(self, indices): raise NotImplementedError() def load_cpu_copy(self, kv_cache_cpu, indices): raise NotImplementedError() class MHATokenToKVPool(KVCache): def __init__( self, size: int, page_size: int, dtype: torch.dtype, head_num: int, head_dim: int, layer_num: int, device: str, enable_memory_saver: bool, start_layer: Optional[int] = None, end_layer: Optional[int] = None, enable_kv_cache_copy: bool = False, ): super().__init__( size, page_size, dtype, layer_num, device, enable_memory_saver, start_layer, end_layer, ) self.head_num = head_num self.head_dim = head_dim # for disagg with nvlink self.enable_custom_mem_pool = get_bool_env_var( "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false" ) if self.enable_custom_mem_pool: # TODO(shangming): abstract custom allocator class for more backends from mooncake.allocator import NVLinkAllocator allocator = NVLinkAllocator.get_allocator(self.device) self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator()) else: self.custom_mem_pool = None self._create_buffers() self.device_module = torch.get_device_module(self.device) self.alt_stream = self.device_module.Stream() if _is_cuda else None if enable_kv_cache_copy: self._init_kv_copy_and_warmup() else: self._kv_copy_config = None self._finalize_allocation_log(size) def _init_kv_copy_and_warmup(self): # Heuristics for KV copy tiling _KV_COPY_STRIDE_THRESHOLD_LARGE = 8192 _KV_COPY_STRIDE_THRESHOLD_MEDIUM = 4096 _KV_COPY_TILE_SIZE_LARGE = 512 _KV_COPY_TILE_SIZE_MEDIUM = 256 _KV_COPY_TILE_SIZE_SMALL = 128 _KV_COPY_NUM_WARPS_LARGE_TILE = 8 _KV_COPY_NUM_WARPS_SMALL_TILE = 4 stride_bytes = int(self.data_strides[0].item()) if stride_bytes >= _KV_COPY_STRIDE_THRESHOLD_LARGE: bytes_per_tile = _KV_COPY_TILE_SIZE_LARGE elif stride_bytes >= _KV_COPY_STRIDE_THRESHOLD_MEDIUM: bytes_per_tile = _KV_COPY_TILE_SIZE_MEDIUM else: bytes_per_tile = _KV_COPY_TILE_SIZE_SMALL self._kv_copy_config = { "bytes_per_tile": bytes_per_tile, "byte_tiles": (stride_bytes + bytes_per_tile - 1) // bytes_per_tile, "num_warps": ( _KV_COPY_NUM_WARPS_SMALL_TILE if bytes_per_tile <= _KV_COPY_TILE_SIZE_MEDIUM else _KV_COPY_NUM_WARPS_LARGE_TILE ), } dummy_loc = torch.zeros(1, dtype=torch.int32, device=self.device) grid = (self.data_ptrs.numel(), self._kv_copy_config["byte_tiles"]) copy_all_layer_kv_cache_tiled[grid]( self.data_ptrs, self.data_strides, dummy_loc, dummy_loc, 1, 1, BYTES_PER_TILE=self._kv_copy_config["bytes_per_tile"], num_warps=self._kv_copy_config["num_warps"], num_stages=2, ) def _create_buffers(self): with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE): with ( torch.cuda.use_mem_pool(self.custom_mem_pool) if self.enable_custom_mem_pool else nullcontext() ): # [size, head_num, head_dim] for each layer # The padded slot 0 is used for writing dummy outputs from padded tokens. self.k_buffer = [ torch.zeros( (self.size + self.page_size, self.head_num, self.head_dim), dtype=self.store_dtype, device=self.device, ) for _ in range(self.layer_num) ] self.v_buffer = [ torch.zeros( (self.size + self.page_size, self.head_num, self.head_dim), dtype=self.store_dtype, device=self.device, ) for _ in range(self.layer_num) ] self.k_data_ptrs = torch.tensor( [x.data_ptr() for x in self.k_buffer], dtype=torch.uint64, device=self.device, ) self.v_data_ptrs = torch.tensor( [x.data_ptr() for x in self.v_buffer], dtype=torch.uint64, device=self.device, ) self.data_ptrs = torch.cat([self.k_data_ptrs, self.v_data_ptrs], dim=0) self.data_strides = torch.tensor( [ np.prod(x.shape[1:]) * x.dtype.itemsize for x in self.k_buffer + self.v_buffer ], device=self.device, ) def _clear_buffers(self): del self.k_buffer del self.v_buffer def get_kv_size_bytes(self): assert hasattr(self, "k_buffer") assert hasattr(self, "v_buffer") k_size_bytes = 0 for k_cache in self.k_buffer: k_size_bytes += get_tensor_size_bytes(k_cache) v_size_bytes = 0 for v_cache in self.v_buffer: v_size_bytes += get_tensor_size_bytes(v_cache) return k_size_bytes, v_size_bytes # for disagg def get_contiguous_buf_infos(self): # layer_num x [seq_len, head_num, head_dim] # layer_num x [page_num, page_size, head_num, head_dim] kv_data_ptrs = [ self._get_key_buffer(i).data_ptr() for i in range(self.start_layer, self.start_layer + self.layer_num) ] + [ self._get_value_buffer(i).data_ptr() for i in range(self.start_layer, self.start_layer + self.layer_num) ] kv_data_lens = [ self._get_key_buffer(i).nbytes for i in range(self.start_layer, self.start_layer + self.layer_num) ] + [ self._get_value_buffer(i).nbytes for i in range(self.start_layer, self.start_layer + self.layer_num) ] kv_item_lens = [ self._get_key_buffer(i)[0].nbytes * self.page_size for i in range(self.start_layer, self.start_layer + self.layer_num) ] + [ self._get_value_buffer(i)[0].nbytes * self.page_size for i in range(self.start_layer, self.start_layer + self.layer_num) ] return kv_data_ptrs, kv_data_lens, kv_item_lens def maybe_get_custom_mem_pool(self): return self.custom_mem_pool def get_cpu_copy(self, indices): torch.cuda.synchronize() kv_cache_cpu = [] chunk_size = self.cpu_offloading_chunk_size for layer_id in range(self.layer_num): kv_cache_cpu.append([]) for i in range(0, len(indices), chunk_size): chunk_indices = indices[i : i + chunk_size] k_cpu = self.k_buffer[layer_id][chunk_indices].to( "cpu", non_blocking=True ) v_cpu = self.v_buffer[layer_id][chunk_indices].to( "cpu", non_blocking=True ) kv_cache_cpu[-1].append([k_cpu, v_cpu]) torch.cuda.synchronize() return kv_cache_cpu def load_cpu_copy(self, kv_cache_cpu, indices): torch.cuda.synchronize() chunk_size = self.cpu_offloading_chunk_size for layer_id in range(self.layer_num): for i in range(0, len(indices), chunk_size): chunk_indices = indices[i : i + chunk_size] k_cpu, v_cpu = ( kv_cache_cpu[layer_id][i // chunk_size][0], kv_cache_cpu[layer_id][i // chunk_size][1], ) assert k_cpu.shape[0] == v_cpu.shape[0] == len(chunk_indices) k_chunk = k_cpu.to(self.k_buffer[0].device, non_blocking=True) v_chunk = v_cpu.to(self.v_buffer[0].device, non_blocking=True) self.k_buffer[layer_id][chunk_indices] = k_chunk self.v_buffer[layer_id][chunk_indices] = v_chunk torch.cuda.synchronize() def _get_key_buffer(self, layer_id: int): # for internal use of referencing if self.store_dtype != self.dtype: return self.k_buffer[layer_id - self.start_layer].view(self.dtype) return self.k_buffer[layer_id - self.start_layer] def get_key_buffer(self, layer_id: int): # note: get_key_buffer is hooked with synchronization for layer-wise KV cache loading # it is supposed to be used only by attention backend not for information purpose # same applies to get_value_buffer and get_kv_buffer if self.layer_transfer_counter is not None: self.layer_transfer_counter.wait_until(layer_id - self.start_layer) return self._get_key_buffer(layer_id) def _get_value_buffer(self, layer_id: int): # for internal use of referencing if self.store_dtype != self.dtype: return self.v_buffer[layer_id - self.start_layer].view(self.dtype) return self.v_buffer[layer_id - self.start_layer] def get_value_buffer(self, layer_id: int): if self.layer_transfer_counter is not None: self.layer_transfer_counter.wait_until(layer_id - self.start_layer) return self._get_value_buffer(layer_id) def get_kv_buffer(self, layer_id: int): return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id) def set_kv_buffer( self, layer: RadixAttention, loc: torch.Tensor, cache_k: torch.Tensor, cache_v: torch.Tensor, k_scale: Optional[float] = None, v_scale: Optional[float] = None, layer_id_override: Optional[int] = None, ): from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode if layer_id_override is not None: layer_id = layer_id_override else: layer_id = layer.layer_id if cache_k.dtype != self.dtype: if k_scale is not None: cache_k.div_(k_scale) if v_scale is not None: cache_v.div_(v_scale) cache_k = cache_k.to(self.dtype) cache_v = cache_v.to(self.dtype) if self.store_dtype != self.dtype: cache_k = cache_k.view(self.store_dtype) cache_v = cache_v.view(self.store_dtype) if get_is_capture_mode() and self.alt_stream is not None: # Overlap the copy of K and V cache for small batch size current_stream = self.device_module.current_stream() self.alt_stream.wait_stream(current_stream) self.k_buffer[layer_id - self.start_layer][loc] = cache_k with self.device_module.stream(self.alt_stream): self.v_buffer[layer_id - self.start_layer][loc] = cache_v current_stream.wait_stream(self.alt_stream) else: self.k_buffer[layer_id - self.start_layer][loc] = cache_k self.v_buffer[layer_id - self.start_layer][loc] = cache_v def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor): N = tgt_loc.numel() if N == 0: return assert ( self._kv_copy_config is not None ), "KV copy not initialized. Set enable_kv_cache_copy=True in __init__" cfg = self._kv_copy_config N_upper = next_power_of_2(N) grid = (self.data_ptrs.numel(), cfg["byte_tiles"]) copy_all_layer_kv_cache_tiled[grid]( self.data_ptrs, self.data_strides, tgt_loc, src_loc, N, N_upper, BYTES_PER_TILE=cfg["bytes_per_tile"], num_warps=cfg["num_warps"], num_stages=2, ) class HybridLinearKVPool(KVCache): """KV cache with separate pools for full and linear attention layers.""" def __init__( self, size: int, dtype: torch.dtype, page_size: int, head_num: int, head_dim: int, full_attention_layer_ids: List[int], enable_kvcache_transpose: bool, device: str, ): self.size = size self.dtype = dtype self.device = device self.full_layer_nums = len(full_attention_layer_ids) self.page_size = page_size # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True assert not enable_kvcache_transpose if _is_npu: TokenToKVPoolClass = AscendTokenToKVPool else: TokenToKVPoolClass = MHATokenToKVPool self.full_kv_pool = TokenToKVPoolClass( size=size, page_size=self.page_size, dtype=dtype, head_num=head_num, head_dim=head_dim, layer_num=self.full_layer_nums, device=device, enable_memory_saver=False, ) self.full_attention_layer_id_mapping = { id: i for i, id in enumerate(full_attention_layer_ids) } k_size, v_size = self.get_kv_size_bytes() self.mem_usage = (k_size + v_size) / GB def get_kv_size_bytes(self): return self.full_kv_pool.get_kv_size_bytes() def get_contiguous_buf_infos(self): return self.full_kv_pool.get_contiguous_buf_infos() def _transfer_full_attention_id(self, layer_id: int): if layer_id not in self.full_attention_layer_id_mapping: raise ValueError( f"{layer_id=} not in full attention layers: {self.full_attention_layer_id_mapping.keys()}" ) return self.full_attention_layer_id_mapping[layer_id] def get_key_buffer(self, layer_id: int): layer_id = self._transfer_full_attention_id(layer_id) return self.full_kv_pool.get_key_buffer(layer_id) def get_value_buffer(self, layer_id: int): layer_id = self._transfer_full_attention_id(layer_id) return self.full_kv_pool.get_value_buffer(layer_id) def get_kv_buffer(self, layer_id: int): layer_id = self._transfer_full_attention_id(layer_id) return self.full_kv_pool.get_kv_buffer(layer_id) def set_kv_buffer( self, layer: RadixAttention, loc: torch.Tensor, cache_k: torch.Tensor, cache_v: torch.Tensor, k_scale: float = 1.0, v_scale: float = 1.0, ): layer_id = self._transfer_full_attention_id(layer.layer_id) self.full_kv_pool.set_kv_buffer( None, loc, cache_k, cache_v, k_scale, v_scale, layer_id_override=layer_id, ) def get_v_head_dim(self): return self.full_kv_pool.get_value_buffer(0).shape[-1] class SWAKVPool(KVCache): """KV cache with separate pools for full and SWA attention layers.""" def __init__( self, size: int, size_swa: int, dtype: torch.dtype, swa_attention_layer_ids: List[int], full_attention_layer_ids: List[int], enable_kvcache_transpose: bool, token_to_kv_pool_class: KVCache = MHATokenToKVPool, **kwargs, ): self.size = size self.size_swa = size_swa self.dtype = dtype self.swa_layer_nums = len(swa_attention_layer_ids) self.full_layer_nums = len(full_attention_layer_ids) kwargs["page_size"] = 1 kwargs["enable_memory_saver"] = False # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True assert not enable_kvcache_transpose self.swa_kv_pool = token_to_kv_pool_class( size=size_swa, dtype=dtype, layer_num=self.swa_layer_nums, **kwargs, ) self.full_kv_pool = token_to_kv_pool_class( size=size, dtype=dtype, layer_num=self.full_layer_nums, **kwargs, ) self.layers_mapping: Dict[int, Tuple[int, bool]] = {} for full_attn_layer_id, global_layer_id in enumerate(full_attention_layer_ids): self.layers_mapping[global_layer_id] = (full_attn_layer_id, False) for swa_layer_id, global_layer_id in enumerate(swa_attention_layer_ids): self.layers_mapping[global_layer_id] = (swa_layer_id, True) self.full_to_swa_index_mapping: Optional[torch.Tensor] = None k_size, v_size = self.get_kv_size_bytes() self.mem_usage = (k_size + v_size) / GB def get_kv_size_bytes(self): k_size, v_size = self.full_kv_pool.get_kv_size_bytes() k_size_swa, v_size_swa = self.swa_kv_pool.get_kv_size_bytes() return k_size + k_size_swa, v_size + v_size_swa def get_contiguous_buf_infos(self): full_kv_data_ptrs, full_kv_data_lens, full_kv_item_lens = ( self.full_kv_pool.get_contiguous_buf_infos() ) swa_kv_data_ptrs, swa_kv_data_lens, swa_kv_item_lens = ( self.swa_kv_pool.get_contiguous_buf_infos() ) kv_data_ptrs = full_kv_data_ptrs + swa_kv_data_ptrs kv_data_lens = full_kv_data_lens + swa_kv_data_lens kv_item_lens = full_kv_item_lens + swa_kv_item_lens return kv_data_ptrs, kv_data_lens, kv_item_lens def get_key_buffer(self, layer_id: int): layer_id_pool, is_swa = self.layers_mapping[layer_id] if is_swa: return self.swa_kv_pool.get_key_buffer(layer_id_pool) else: return self.full_kv_pool.get_key_buffer(layer_id_pool) def get_value_buffer(self, layer_id: int): layer_id_pool, is_swa = self.layers_mapping[layer_id] if is_swa: return self.swa_kv_pool.get_value_buffer(layer_id_pool) else: return self.full_kv_pool.get_value_buffer(layer_id_pool) def get_kv_buffer(self, layer_id: int): layer_id_pool, is_swa = self.layers_mapping[layer_id] if is_swa: return self.swa_kv_pool.get_kv_buffer(layer_id_pool) else: return self.full_kv_pool.get_kv_buffer(layer_id_pool) def translate_loc_from_full_to_swa(self, kv_indices: torch.Tensor): assert self.full_to_swa_index_mapping is not None return self.full_to_swa_index_mapping[kv_indices].to(torch.int32) def set_kv_buffer( self, layer: RadixAttention, loc: torch.Tensor, cache_k: torch.Tensor, cache_v: torch.Tensor, k_scale: float = 1.0, v_scale: float = 1.0, ): layer_id = layer.layer_id layer_id_pool, is_swa = self.layers_mapping[layer_id] if is_swa: if self.full_to_swa_index_mapping is not None: loc = self.translate_loc_from_full_to_swa(loc) self.swa_kv_pool.set_kv_buffer( None, loc, cache_k, cache_v, k_scale, v_scale, layer_id_override=layer_id_pool, ) else: self.full_kv_pool.set_kv_buffer( None, loc, cache_k, cache_v, k_scale, v_scale, layer_id_override=layer_id_pool, ) class AscendTokenToKVPool(MHATokenToKVPool): def _create_buffers(self): with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE): # [size, head_num, head_dim] for each layer # The padded slot 0 is used for writing dummy outputs from padded tokens. # Continuous memory improves the efficiency of Ascend`s transmission backend, # while other backends remain unchanged. self.kv_buffer = torch.zeros( ( 2, self.layer_num, self.size // self.page_size + 1, self.page_size, self.head_num, self.head_dim, ), dtype=self.store_dtype, device=self.device, ) self.k_buffer = self.kv_buffer[0] self.v_buffer = self.kv_buffer[1] # for disagg def get_contiguous_buf_infos(self): # layer_num x [seq_len, head_num, head_dim] # layer_num x [page_num, page_size, head_num, head_dim] kv_data_ptrs = [ self.get_key_buffer(i).data_ptr() for i in range(self.start_layer, self.start_layer + self.layer_num) ] + [ self.get_value_buffer(i).data_ptr() for i in range(self.start_layer, self.start_layer + self.layer_num) ] kv_data_lens = [ self.get_key_buffer(i).nbytes for i in range(self.start_layer, self.start_layer + self.layer_num) ] + [ self.get_value_buffer(i).nbytes for i in range(self.start_layer, self.start_layer + self.layer_num) ] kv_item_lens = [ self.get_key_buffer(i)[0].nbytes for i in range(self.start_layer, self.start_layer + self.layer_num) ] + [ self.get_value_buffer(i)[0].nbytes for i in range(self.start_layer, self.start_layer + self.layer_num) ] return kv_data_ptrs, kv_data_lens, kv_item_lens def set_kv_buffer( self, layer: RadixAttention, loc: torch.Tensor, cache_k: torch.Tensor, cache_v: torch.Tensor, k_scale: Optional[float] = None, v_scale: Optional[float] = None, layer_id_override: Optional[int] = None, ): if layer_id_override is not None: layer_id = layer_id_override else: layer_id = layer.layer_id if cache_k.dtype != self.dtype: if k_scale is not None: cache_k.div_(k_scale) if v_scale is not None: cache_v.div_(v_scale) cache_k = cache_k.to(self.dtype) cache_v = cache_v.to(self.dtype) if self.store_dtype != self.dtype: cache_k = cache_k.view(self.store_dtype) cache_v = cache_v.view(self.store_dtype) torch_npu._npu_reshape_and_cache( key=cache_k, value=cache_v, key_cache=self.k_buffer[layer_id].view( -1, self.page_size, self.head_num, self.head_dim ), value_cache=self.v_buffer[layer_id].view( -1, self.page_size, self.head_num, self.head_dim ), slot_indices=loc, ) @triton.jit def set_mla_kv_buffer_kernel( kv_buffer_ptr, cache_k_nope_ptr, cache_k_rope_ptr, loc_ptr, buffer_stride: tl.constexpr, nope_stride: tl.constexpr, rope_stride: tl.constexpr, nope_dim: tl.constexpr, rope_dim: tl.constexpr, BLOCK: tl.constexpr, ): pid_loc = tl.program_id(0) pid_blk = tl.program_id(1) base = pid_blk * BLOCK offs = base + tl.arange(0, BLOCK) total_dim = nope_dim + rope_dim mask = offs < total_dim loc = tl.load(loc_ptr + pid_loc) dst_ptr = kv_buffer_ptr + loc * buffer_stride + offs if base + BLOCK <= nope_dim: src = tl.load( cache_k_nope_ptr + pid_loc * nope_stride + offs, mask=mask, ) else: offs_rope = offs - nope_dim src = tl.load( cache_k_rope_ptr + pid_loc * rope_stride + offs_rope, mask=mask, ) tl.store(dst_ptr, src, mask=mask) def set_mla_kv_buffer_triton( kv_buffer: torch.Tensor, loc: torch.Tensor, cache_k_nope: torch.Tensor, cache_k_rope: torch.Tensor, ): nope_dim = cache_k_nope.shape[-1] rope_dim = cache_k_rope.shape[-1] total_dim = nope_dim + rope_dim BLOCK = 128 n_loc = loc.numel() grid = (n_loc, triton.cdiv(total_dim, BLOCK)) set_mla_kv_buffer_kernel[grid]( kv_buffer, cache_k_nope, cache_k_rope, loc, kv_buffer.stride(0), cache_k_nope.stride(0), cache_k_rope.stride(0), nope_dim, rope_dim, BLOCK=BLOCK, ) class MLATokenToKVPool(KVCache): def __init__( self, size: int, page_size: int, dtype: torch.dtype, kv_lora_rank: int, qk_rope_head_dim: int, layer_num: int, device: str, enable_memory_saver: bool, start_layer: Optional[int] = None, end_layer: Optional[int] = None, use_nsa: bool = False, override_kv_cache_dim: Optional[int] = None, ): super().__init__( size, page_size, dtype, layer_num, device, enable_memory_saver, start_layer, end_layer, ) self.kv_lora_rank = kv_lora_rank self.qk_rope_head_dim = qk_rope_head_dim self.use_nsa = use_nsa self.nsa_kv_cache_store_fp8 = use_nsa and dtype == torch.float8_e4m3fn # TODO do not hardcode self.kv_cache_dim = ( 656 if self.use_nsa and self.nsa_kv_cache_store_fp8 else (kv_lora_rank + qk_rope_head_dim) ) # for disagg with nvlink self.enable_custom_mem_pool = get_bool_env_var( "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false" ) if self.enable_custom_mem_pool: # TODO(shangming): abstract custom allocator class for more backends from mooncake.allocator import NVLinkAllocator allocator = NVLinkAllocator.get_allocator(self.device) self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator()) else: self.custom_mem_pool = None with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE): with ( torch.cuda.use_mem_pool(self.custom_mem_pool) if self.custom_mem_pool else nullcontext() ): # The padded slot 0 is used for writing dummy outputs from padded tokens. self.kv_buffer = [ torch.zeros( (size + page_size, 1, self.kv_cache_dim), dtype=self.store_dtype, device=device, ) for _ in range(layer_num) ] self.data_ptrs = torch.tensor( [x.data_ptr() for x in self.kv_buffer], dtype=torch.uint64, device=self.device, ) if not use_nsa: # NSA will allocate indexer KV cache later and then log the total size self._finalize_allocation_log(size) def get_kv_size_bytes(self): assert hasattr(self, "kv_buffer") kv_size_bytes = 0 for kv_cache in self.kv_buffer: kv_size_bytes += get_tensor_size_bytes(kv_cache) return kv_size_bytes # for disagg def get_contiguous_buf_infos(self): # MLA has only one kv_buffer, so only the information of this buffer needs to be returned. kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)] kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)] kv_item_lens = [ self.kv_buffer[i][0].nbytes * self.page_size for i in range(self.layer_num) ] return kv_data_ptrs, kv_data_lens, kv_item_lens def maybe_get_custom_mem_pool(self): return self.custom_mem_pool def get_key_buffer(self, layer_id: int): if self.layer_transfer_counter is not None: self.layer_transfer_counter.wait_until(layer_id - self.start_layer) if self.store_dtype != self.dtype: return self.kv_buffer[layer_id - self.start_layer].view(self.dtype) return self.kv_buffer[layer_id - self.start_layer] def get_value_buffer(self, layer_id: int): if self.layer_transfer_counter is not None: self.layer_transfer_counter.wait_until(layer_id - self.start_layer) if self.store_dtype != self.dtype: return self.kv_buffer[layer_id - self.start_layer][ ..., : self.kv_lora_rank ].view(self.dtype) return self.kv_buffer[layer_id - self.start_layer][..., : self.kv_lora_rank] def get_kv_buffer(self, layer_id: int): return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id) def set_kv_buffer( self, layer: RadixAttention, loc: torch.Tensor, cache_k: torch.Tensor, cache_v: torch.Tensor, ): layer_id = layer.layer_id assert not (self.use_nsa and self.nsa_kv_cache_store_fp8) if cache_k.dtype != self.dtype: cache_k = cache_k.to(self.dtype) if self.store_dtype != self.dtype: self.kv_buffer[layer_id - self.start_layer][loc] = cache_k.view( self.store_dtype ) else: self.kv_buffer[layer_id - self.start_layer][loc] = cache_k def set_mla_kv_buffer( self, layer: RadixAttention, loc: torch.Tensor, cache_k_nope: torch.Tensor, cache_k_rope: torch.Tensor, ): layer_id = layer.layer_id if self.use_nsa and self.nsa_kv_cache_store_fp8: # original cache_k: (num_tokens, num_heads 1, hidden 576); we unsqueeze the page_size=1 dim here # TODO no need to cat cache_k = torch.cat([cache_k_nope, cache_k_rope], dim=-1) cache_k = quantize_k_cache(cache_k.unsqueeze(1)).squeeze(1) cache_k = cache_k.view(self.store_dtype) self.kv_buffer[layer_id - self.start_layer][loc] = cache_k else: if cache_k_nope.dtype != self.dtype: cache_k_nope = cache_k_nope.to(self.dtype) cache_k_rope = cache_k_rope.to(self.dtype) if self.store_dtype != self.dtype: cache_k_nope = cache_k_nope.view(self.store_dtype) cache_k_rope = cache_k_rope.view(self.store_dtype) set_mla_kv_buffer_triton( self.kv_buffer[layer_id - self.start_layer], loc, cache_k_nope, cache_k_rope, ) def get_cpu_copy(self, indices): torch.cuda.synchronize() kv_cache_cpu = [] chunk_size = self.cpu_offloading_chunk_size for layer_id in range(self.layer_num): kv_cache_cpu.append([]) for i in range(0, len(indices), chunk_size): chunk_indices = indices[i : i + chunk_size] kv_cpu = self.kv_buffer[layer_id][chunk_indices].to( "cpu", non_blocking=True ) kv_cache_cpu[-1].append(kv_cpu) torch.cuda.synchronize() return kv_cache_cpu def load_cpu_copy(self, kv_cache_cpu, indices): torch.cuda.synchronize() chunk_size = self.cpu_offloading_chunk_size for layer_id in range(self.layer_num): for i in range(0, len(indices), chunk_size): chunk_indices = indices[i : i + chunk_size] kv_cpu = kv_cache_cpu[layer_id][i // chunk_size] assert kv_cpu.shape[0] == len(chunk_indices) kv_chunk = kv_cpu.to(self.kv_buffer[0].device, non_blocking=True) self.kv_buffer[layer_id][chunk_indices] = kv_chunk torch.cuda.synchronize() class NSATokenToKVPool(MLATokenToKVPool): quant_block_size = 128 index_k_with_scale_buffer_dtype = torch.uint8 def __init__( self, size: int, page_size: int, kv_lora_rank: int, dtype: torch.dtype, qk_rope_head_dim: int, layer_num: int, device: str, index_head_dim: int, enable_memory_saver: bool, start_layer: Optional[int] = None, end_layer: Optional[int] = None, ): super().__init__( size, page_size, dtype, kv_lora_rank, qk_rope_head_dim, layer_num, device, enable_memory_saver, start_layer, end_layer, use_nsa=True, ) # self.index_k_dtype = torch.float8_e4m3fn # self.index_k_scale_dtype = torch.float32 self.index_head_dim = index_head_dim # num head == 1 and head dim == 128 for index_k in NSA assert index_head_dim == 128 assert self.page_size == 64 self.index_k_with_scale_buffer = [ torch.zeros( # Layout: # ref: test_attention.py :: kv_cache_cast_to_fp8 # shape: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4) # data: for page i, # * buf[i, :page_size * head_dim] for fp8 data # * buf[i, page_size * head_dim:].view(float32) for scale ( (size + page_size + 1) // self.page_size, self.page_size * (index_head_dim + index_head_dim // self.quant_block_size * 4), ), dtype=self.index_k_with_scale_buffer_dtype, device=device, ) for _ in range(layer_num) ] self._finalize_allocation_log(size) def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor: if self.layer_transfer_counter is not None: self.layer_transfer_counter.wait_until(layer_id - self.start_layer) return self.index_k_with_scale_buffer[layer_id - self.start_layer] def get_index_k_continuous( self, layer_id: int, seq_len: int, page_indices: torch.Tensor, ): buf = self.index_k_with_scale_buffer[layer_id - self.start_layer] return index_buf_accessor.GetK.execute( self, buf, seq_len=seq_len, page_indices=page_indices ) def get_index_k_scale_continuous( self, layer_id: int, seq_len: int, page_indices: torch.Tensor, ): buf = self.index_k_with_scale_buffer[layer_id - self.start_layer] return index_buf_accessor.GetS.execute( self, buf, seq_len=seq_len, page_indices=page_indices ) # TODO rename later (currently use diff name to avoid confusion) def set_index_k_and_scale_buffer( self, layer_id: int, loc: torch.Tensor, index_k: torch.Tensor, index_k_scale: torch.Tensor, ) -> None: buf = self.index_k_with_scale_buffer[layer_id - self.start_layer] index_buf_accessor.SetKAndS.execute( pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale ) def get_kv_size_bytes(self): kv_size_bytes = super().get_kv_size_bytes() for index_k_cache in self.index_k_with_scale_buffer: kv_size_bytes += get_tensor_size_bytes(index_k_cache) return kv_size_bytes class AscendMLAPagedTokenToKVPool(MLATokenToKVPool): def __init__( self, size: int, page_size: int, dtype: torch.dtype, kv_lora_rank: int, qk_rope_head_dim: int, index_head_dim: Optional[int], layer_num: int, device: str, enable_memory_saver: bool, start_layer: Optional[int] = None, end_layer: Optional[int] = None, ): super(MLATokenToKVPool, self).__init__( size, page_size, dtype, layer_num, device, enable_memory_saver, start_layer, end_layer, ) self.kv_lora_rank = kv_lora_rank self.qk_rope_head_dim = qk_rope_head_dim self.index_head_dim = index_head_dim self.custom_mem_pool = None with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE): # The padded slot 0 is used for writing dummy outputs from padded tokens. self.k_buffer = torch.zeros( ( layer_num, self.size // self.page_size + 1, self.page_size, 1, self.kv_lora_rank, ), dtype=self.store_dtype, device=self.device, ) self.v_buffer = torch.zeros( ( layer_num, self.size // self.page_size + 1, self.page_size, 1, self.qk_rope_head_dim, ), dtype=self.store_dtype, device=self.device, ) if self.index_head_dim is not None: self.index_k_buffer = torch.zeros( ( layer_num, self.size // self.page_size + 1, self.page_size, 1, self.index_head_dim, ), dtype=self.store_dtype, device=self.device, ) self._finalize_allocation_log(size) def get_kv_size_bytes(self): assert hasattr(self, "k_buffer") assert hasattr(self, "v_buffer") kv_size_bytes = 0 for k_cache in self.k_buffer: kv_size_bytes += get_tensor_size_bytes(k_cache) for v_cache in self.v_buffer: kv_size_bytes += get_tensor_size_bytes(v_cache) if self.index_head_dim is not None: assert hasattr(self, "index_k_buffer") for index_k_cache in self.index_k_buffer: kv_size_bytes += get_tensor_size_bytes(index_k_cache) return kv_size_bytes def get_kv_buffer(self, layer_id: int): if self.layer_transfer_counter is not None: self.layer_transfer_counter.wait_until(layer_id - self.start_layer) return ( self.k_buffer[layer_id - self.start_layer], self.v_buffer[layer_id - self.start_layer], ) def get_key_buffer(self, layer_id: int): if self.layer_transfer_counter is not None: self.layer_transfer_counter.wait_until(layer_id - self.start_layer) if self.store_dtype != self.dtype: return self.k_buffer[layer_id - self.start_layer].view(self.dtype) return self.k_buffer[layer_id - self.start_layer] def get_value_buffer(self, layer_id: int): if self.layer_transfer_counter is not None: self.layer_transfer_counter.wait_until(layer_id - self.start_layer) if self.store_dtype != self.dtype: return self.v_buffer[layer_id - self.start_layer].view(self.dtype) return self.v_buffer[layer_id - self.start_layer] def get_index_k_buffer(self, layer_id: int): if self.layer_transfer_counter is not None: self.layer_transfer_counter.wait_until(layer_id - self.start_layer) if self.store_dtype != self.dtype: return self.index_k_buffer[layer_id - self.start_layer].view(self.dtype) return self.index_k_buffer[layer_id - self.start_layer] # for disagg def get_contiguous_buf_infos(self): # MLA has only one kv_buffer, so only the information of this buffer needs to be returned. kv_data_ptrs = [self.k_buffer[i].data_ptr() for i in range(self.layer_num)] + [ self.v_buffer[i].data_ptr() for i in range(self.layer_num) ] kv_data_lens = [self.k_buffer[i].nbytes for i in range(self.layer_num)] + [ self.v_buffer[i].nbytes for i in range(self.layer_num) ] kv_item_lens = [self.k_buffer[i][0].nbytes for i in range(self.layer_num)] + [ self.v_buffer[i][0].nbytes for i in range(self.layer_num) ] if self.index_head_dim is not None: kv_data_ptrs += [ self.index_k_buffer[i].data_ptr() for i in range(self.layer_num) ] kv_data_lens += [ self.index_k_buffer[i].nbytes for i in range(self.layer_num) ] kv_item_lens += [ self.index_k_buffer[i][0].nbytes for i in range(self.layer_num) ] return kv_data_ptrs, kv_data_lens, kv_item_lens def set_kv_buffer( self, layer: RadixAttention, loc: torch.Tensor, cache_k: torch.Tensor, cache_v: torch.Tensor, ): layer_id = layer.layer_id if cache_k.dtype != self.dtype: cache_k = cache_k.to(self.dtype) cache_v = cache_v.to(self.dtype) if self.store_dtype != self.dtype: cache_k = cache_k.view(self.store_dtype) cache_v = cache_v.view(self.store_dtype) if cache_v is None: cache_k, cache_v = cache_k.split( [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 ) torch_npu.npu_scatter_nd_update_( self.k_buffer[layer_id - self.start_layer].view(-1, 1, self.kv_lora_rank), loc.view(-1, 1), cache_k.view(-1, 1, self.kv_lora_rank), ) torch_npu.npu_scatter_nd_update_( self.v_buffer[layer_id - self.start_layer].view( -1, 1, self.qk_rope_head_dim ), loc.view(-1, 1), cache_v.view(-1, 1, self.qk_rope_head_dim), ) def set_index_k_buffer( self, layer_id: int, loc: torch.Tensor, index_k: torch.Tensor, ): if index_k.dtype != self.dtype: index_k = index_k.to(self.dtype) if self.store_dtype != self.dtype: index_k = index_k.view(self.store_dtype) torch_npu.npu_scatter_nd_update_( self.index_k_buffer[layer_id - self.start_layer].view( -1, 1, self.index_head_dim ), loc.view(-1, 1), index_k.view(-1, 1, self.index_head_dim), ) class DoubleSparseTokenToKVPool(KVCache): def __init__( self, size: int, page_size: int, dtype: torch.dtype, head_num: int, head_dim: int, layer_num: int, device: str, heavy_channel_num: int, enable_memory_saver: bool, start_layer: Optional[int] = None, end_layer: Optional[int] = None, ): super().__init__( size, page_size, dtype, layer_num, device, enable_memory_saver, start_layer, end_layer, ) with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE): # [size, head_num, head_dim] for each layer self.k_buffer = [ torch.zeros( (size + page_size, head_num, head_dim), dtype=dtype, device=device ) for _ in range(layer_num) ] self.v_buffer = [ torch.zeros( (size + page_size, head_num, head_dim), dtype=dtype, device=device ) for _ in range(layer_num) ] # [size, head_num, heavy_channel_num] for each layer self.label_buffer = [ torch.zeros( (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device ) for _ in range(layer_num) ] def get_key_buffer(self, layer_id: int): return self.k_buffer[layer_id - self.start_layer] def get_value_buffer(self, layer_id: int): return self.v_buffer[layer_id - self.start_layer] def get_label_buffer(self, layer_id: int): return self.label_buffer[layer_id - self.start_layer] def get_kv_buffer(self, layer_id: int): return ( self.k_buffer[layer_id - self.start_layer], self.v_buffer[layer_id - self.start_layer], ) def set_kv_buffer( self, layer: RadixAttention, loc: torch.Tensor, cache_k: torch.Tensor, cache_v: torch.Tensor, cache_label: torch.Tensor, ): # NOTE(Andy): ignore the dtype check layer_id = layer.layer_id self.k_buffer[layer_id - self.start_layer][loc] = cache_k self.v_buffer[layer_id - self.start_layer][loc] = cache_v self.label_buffer[layer_id - self.start_layer][loc] = cache_label @triton.jit def copy_all_layer_kv_cache_tiled( data_ptrs, strides, tgt_loc_ptr, src_loc_ptr, num_locs, num_locs_upper: tl.constexpr, BYTES_PER_TILE: tl.constexpr, ): """2D tiled kernel. Safe for in-place copy.""" bid = tl.program_id(0) tid = tl.program_id(1) stride = tl.load(strides + bid) base_ptr = tl.load(data_ptrs + bid) base_ptr = tl.cast(base_ptr, tl.pointer_type(tl.uint8)) byte_off = tid * BYTES_PER_TILE + tl.arange(0, BYTES_PER_TILE) mask_byte = byte_off < stride tl.multiple_of(byte_off, 16) loc_idx = tl.arange(0, num_locs_upper) mask_loc = loc_idx < num_locs src = tl.load(src_loc_ptr + loc_idx, mask=mask_loc, other=0) tgt = tl.load(tgt_loc_ptr + loc_idx, mask=mask_loc, other=0) src_ptr = base_ptr + src[:, None] * stride + byte_off[None, :] tgt_ptr = base_ptr + tgt[:, None] * stride + byte_off[None, :] mask = mask_loc[:, None] & mask_byte[None, :] vals = tl.load(src_ptr, mask=mask) tl.store(tgt_ptr, vals, mask=mask)