Unverified commit dd7ca006, authored by Zhiqiang Xie, committed by GitHub

Interface change for kvcache io to support page first layout (#8318)

parent 9305ea6c
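The host KV buffer gains a second memory layout in this commit. A minimal sketch of the two layouts (toy dimensions, illustrative only; the real shapes come from init_kv_buffer in memory_pool_host.py below):

import torch

layer_num, num_tokens, head_num, head_dim = 4, 32, 2, 8  # toy sizes

# layer_first: all tokens of a layer are contiguous, which suits per-layer IO.
layer_first = torch.empty(2, layer_num, num_tokens, head_num, head_dim)

# page_first: all layers of a token are contiguous, so one page (every layer's
# K and V for its tokens) can be moved as a single contiguous region.
page_first = torch.empty(2, num_tokens, layer_num, head_num, head_dim)

# K of layer 1 for tokens 8..15 under each layout:
k_lf = layer_first[0, 1, 8:16]
k_pf = page_first[0, 8:16, 1]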
@@ -231,16 +231,7 @@ class HiCacheController:
         self.mem_pool_host = mem_pool_host
         self.write_policy = write_policy
         self.page_size = page_size
-        # using kernel for small page KV cache transfer and DMA for large pages
-        if not io_backend:
-            IO_BACKEND_PAGE_SIZE_THRESHOLD = 64
-            self.io_backend = (
-                "direct"
-                if self.page_size >= IO_BACKEND_PAGE_SIZE_THRESHOLD
-                else "kernel"
-            )
-        else:
-            self.io_backend = io_backend
+        self.io_backend = io_backend

         self.enable_storage = False
         # todo: move backend initialization to storage backend module
@@ -447,11 +438,8 @@ class HiCacheController:
            host_indices, device_indices = self.move_indices(
                operation.host_indices, operation.device_indices
            )
-            self.mem_pool_device.backup_to_host_all_layer(
-                self.mem_pool_host,
-                host_indices,
-                device_indices,
-                self.io_backend,
-            )
+            self.mem_pool_host.backup_from_device_all_layer(
+                self.mem_pool_device, host_indices, device_indices, self.io_backend
+            )
            self.write_stream.synchronize()
            self.mem_pool_host.complete_io(operation.host_indices)
@@ -491,8 +479,8 @@ class HiCacheController:
                batch_operation.host_indices, batch_operation.device_indices
            )
            for i in range(self.mem_pool_host.layer_num):
-                self.mem_pool_device.load_from_host_per_layer(
-                    self.mem_pool_host,
+                self.mem_pool_host.load_to_device_per_layer(
+                    self.mem_pool_device,
                    host_indices,
                    device_indices,
                    i,
......
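The hunks above flip the direction of the transfer interface: the host pool, which now owns the layout information, drives IO instead of the device pool. A hedged sketch of the resulting controller-side call pattern (variable names hypothetical; signatures match the abstract methods added to memory_pool_host.py below):

# host_pool: a HostKVCache subclass; device_pool: its paired KVCache;
# the index tensors map token slots between the two pools.
for layer_id in range(host_pool.layer_num):
    host_pool.load_to_device_per_layer(
        device_pool, host_indices, device_indices, layer_id, io_backend
    )
host_pool.backup_from_device_all_layer(
    device_pool, host_indices, device_indices, io_backend
)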
@@ -588,6 +588,7 @@ class Scheduler(
                == "fa3"  # hot fix for incompatibility
                else server_args.hicache_io_backend
            ),
+            hicache_mem_layout=server_args.hicache_mem_layout,
            hicache_storage_backend=server_args.hicache_storage_backend,
        )
        self.tp_worker.register_hicache_layer_transfer_counter(
......
@@ -35,16 +35,33 @@ class HiRadixCache(RadixCache):
        hicache_size: int,
        hicache_write_policy: str,
        hicache_io_backend: str,
+        hicache_mem_layout: str,
        hicache_storage_backend: Optional[str] = None,
    ):
+        if hicache_io_backend == "direct":
+            if hicache_mem_layout == "page_first":
+                hicache_mem_layout = "layer_first"
+                logger.warning(
+                    "Page first layout is not supported with direct IO backend, switching to layer first layout"
+                )
+
        self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
        if isinstance(self.kv_cache, MHATokenToKVPool):
            self.token_to_kv_pool_host = MHATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio, hicache_size, page_size
+                self.kv_cache,
+                hicache_ratio,
+                hicache_size,
+                page_size,
+                hicache_mem_layout,
            )
        elif isinstance(self.kv_cache, MLATokenToKVPool):
            self.token_to_kv_pool_host = MLATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio, hicache_size, page_size
+                self.kv_cache,
+                hicache_ratio,
+                hicache_size,
+                page_size,
+                hicache_mem_layout,
            )
        else:
            raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
......
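The guard above encodes a compatibility rule: the direct (DMA) backend only understands layer_first, while the kernel backend supports both layouts. The same rule restated as a standalone helper (hypothetical, not part of the diff):

import logging

logger = logging.getLogger(__name__)

def resolve_mem_layout(io_backend: str, mem_layout: str) -> str:
    # Mirrors HiRadixCache: fall back with a warning instead of raising.
    if io_backend == "direct" and mem_layout == "page_first":
        logger.warning("page_first unsupported with direct IO; using layer_first")
        return "layer_first"
    return mem_layout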
@@ -31,21 +31,17 @@ from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
-import torch.distributed as dist
 import triton
 import triton.language as tl

 from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
+from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2

 logger = logging.getLogger(__name__)

 GB = 1024 * 1024 * 1024
 _is_cuda = is_cuda()
-_is_npu = is_npu()
-if not _is_npu:
-    from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla


 class ReqToTokenPool:
@@ -153,18 +149,6 @@ class KVCache(abc.ABC):
    ) -> None:
        raise NotImplementedError()

-    @abc.abstractmethod
-    def load_from_host_per_layer(
-        self, host_pool, host_indices, device_indices, layer_id, io_backend
-    ):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        raise NotImplementedError()
-
    def register_layer_transfer_counter(self, layer_transfer_counter):
        self.layer_transfer_counter = layer_transfer_counter
@@ -253,12 +237,18 @@ class MHATokenToKVPool(KVCache):
            )
            for _ in range(self.layer_num)
        ]
-        self.token_stride = self.head_num * self.head_dim
-        self.data_ptrs = torch.tensor(
-            [x.data_ptr() for x in self.k_buffer + self.v_buffer],
+        self.k_data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.k_buffer],
+            dtype=torch.uint64,
+            device=self.device,
+        )
+        self.v_data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.v_buffer],
            dtype=torch.uint64,
            device=self.device,
        )
+        self.data_ptrs = torch.cat([self.k_data_ptrs, self.v_data_ptrs], dim=0)
        self.data_strides = torch.tensor(
            [
                np.prod(x.shape[1:]) * x.dtype.itemsize
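The per-layer base addresses are materialized once as uint64 tensors so the new all-layer transfer kernels can walk every layer in a single launch. A toy illustration of such a pointer table (CPU tensors for portability; the real pool keeps the table on the GPU device):

import torch

k_buffer = [torch.zeros(16, 2, 8) for _ in range(4)]  # 4 toy KV layers
k_data_ptrs = torch.tensor(
    [layer.data_ptr() for layer in k_buffer], dtype=torch.uint64
)
# Each entry is one layer's base address; combined with a per-token item size,
# a kernel can address any (layer, token) pair without per-layer launches.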
@@ -347,47 +337,6 @@ class MHATokenToKVPool(KVCache):
        self.v_buffer[layer_id][chunk_indices] = v_chunk
        torch.cuda.synchronize()

-    def load_from_host_per_layer(
-        self,
-        host_pool,
-        host_indices,
-        device_indices,
-        layer_id,
-        io_backend,
-    ):
-        transfer_kv_per_layer(
-            src_k=host_pool.k_buffer[layer_id],
-            dst_k=self.k_buffer[layer_id],
-            src_v=host_pool.v_buffer[layer_id],
-            dst_v=self.v_buffer[layer_id],
-            src_indices=host_indices,
-            dst_indices=device_indices,
-            io_backend=io_backend,
-            page_size=self.page_size,
-            item_size=self.token_stride,
-        )
-
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        # todo: specialized all layer kernels for the layer-non-contiguous memory pool
-        for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
-            if layer_id - self.start_layer >= len(host_pool.k_buffer):
-                raise ValueError(
-                    f"Layer ID {layer_id} exceeds the number of layers in host pool."
-                )
-            transfer_kv_per_layer(
-                src_k=self.k_buffer[layer_id],
-                dst_k=host_pool.k_buffer[layer_id],
-                src_v=self.v_buffer[layer_id],
-                dst_v=host_pool.v_buffer[layer_id],
-                src_indices=device_indices,
-                dst_indices=host_indices,
-                io_backend=io_backend,
-                page_size=self.page_size,
-                item_size=self.token_stride,
-            )
-
    def _get_key_buffer(self, layer_id: int):
        # for internal use of referencing
        if self.store_dtype != self.dtype:
@@ -602,16 +551,6 @@ class SWAKVPool(KVCache):
                layer_id_override=layer_id_pool,
            )

-    def load_from_host_per_layer(
-        self, host_pool, host_indices, device_indices, layer_id, io_backend
-    ):
-        raise NotImplementedError("HiCache not supported for SWAKVPool.")
-
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        raise NotImplementedError("HiCache not supported for SWAKVPool.")
-

 class AscendTokenToKVPool(MHATokenToKVPool):
@@ -823,7 +762,11 @@ class MLATokenToKVPool(KVCache):
            for _ in range(layer_num)
        ]
-        self.token_stride = kv_lora_rank + qk_rope_head_dim
+        self.data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.kv_buffer],
+            dtype=torch.uint64,
+            device=self.device,
+        )
        self.layer_transfer_counter = None

        kv_size = self.get_kv_size_bytes()
@@ -909,38 +852,6 @@ class MLATokenToKVPool(KVCache):
            self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
        )

-    def load_from_host_per_layer(
-        self, host_pool, host_indices, device_indices, layer_id, io_backend
-    ):
-        transfer_kv_per_layer_mla(
-            src=host_pool.kv_buffer[layer_id],
-            dst=self.kv_buffer[layer_id],
-            src_indices=host_indices,
-            dst_indices=device_indices,
-            io_backend=io_backend,
-            page_size=self.page_size,
-            item_size=self.token_stride,
-        )
-
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        # todo: specialized all layer kernels for the layer-non-contiguous memory pool
-        for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
-            if layer_id - self.start_layer >= len(host_pool.kv_buffer):
-                raise ValueError(
-                    f"Layer ID {layer_id} exceeds the number of layers in host pool."
-                )
-            transfer_kv_per_layer_mla(
-                src=self.kv_buffer[layer_id],
-                dst=host_pool.kv_buffer[layer_id],
-                src_indices=device_indices,
-                dst_indices=host_indices,
-                io_backend=io_backend,
-                page_size=self.page_size,
-                item_size=self.token_stride,
-            )
-
    def get_cpu_copy(self, indices):
        torch.cuda.synchronize()
        kv_cache_cpu = []
@@ -1131,20 +1042,6 @@ class DoubleSparseTokenToKVPool(KVCache):
            self.v_buffer[layer_id - self.start_layer][loc] = cache_v
            self.label_buffer[layer_id - self.start_layer][loc] = cache_label

-    def load_from_host_per_layer(
-        self, host_pool, host_indices, device_indices, layer_id, io_backend
-    ):
-        raise NotImplementedError(
-            "HiCache not supported for DoubleSparseTokenToKVPool."
-        )
-
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        raise NotImplementedError(
-            "HiCache not supported for DoubleSparseTokenToKVPool."
-        )
-

 @triton.jit
 def copy_all_layer_kv_cache(
......
@@ -8,6 +8,21 @@ import psutil
 import torch

 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
+from sglang.srt.utils import is_npu
+
+_is_npu = is_npu()
+if not _is_npu:
+    from sgl_kernel.kvcacheio import (
+        transfer_kv_all_layer,
+        transfer_kv_all_layer_lf_pf,
+        transfer_kv_all_layer_mla,
+        transfer_kv_all_layer_mla_lf_pf,
+        transfer_kv_direct,
+        transfer_kv_per_layer,
+        transfer_kv_per_layer_mla,
+        transfer_kv_per_layer_mla_pf_lf,
+        transfer_kv_per_layer_pf_lf,
+    )

 logger = logging.getLogger(__name__)
@@ -42,15 +57,18 @@ class HostKVCache(abc.ABC):
        device_pool: KVCache,
        host_to_device_ratio: float,
        host_size: int,
+        page_size: int,
+        layout: str,
        pin_memory: bool,
        device: str,
-        page_size: int,
    ):
        self.device_pool = device_pool
-        self.dtype = device_pool.store_dtype
+        self.page_size = page_size
+        self.layout = layout
        self.pin_memory = pin_memory
        self.device = device
-        self.page_size = page_size
+        self.dtype = device_pool.store_dtype

        self.size_per_token = self.get_size_per_token()
        if host_size > 0:
            self.size = int(host_size * 1e9 // self.size_per_token)
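For reference, the sizing rule above interprets host_size in GB and divides by the per-token byte footprint. A worked example with illustrative MHA numbers (not taken from the diff; size_per_token mirrors the MHA formula below):

head_num, head_dim, layer_num, itemsize = 8, 128, 32, 2  # toy fp16 model
size_per_token = head_dim * head_num * layer_num * itemsize * 2  # K and V
host_size_gb = 64
num_host_tokens = int(host_size_gb * 1e9 // size_per_token)
print(size_per_token, num_host_tokens)  # 131072 bytes/token -> 488281 tokens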
@@ -98,6 +116,24 @@ class HostKVCache(abc.ABC):
    def init_kv_buffer(self):
        raise NotImplementedError()

+    @abc.abstractmethod
+    def load_to_device_per_layer(
+        self, device_pool, host_indices, device_indices, layer_id, io_backend
+    ) -> None:
+        """
+        Load KV data from the host memory pool to the device memory pool for a specific layer.
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def backup_from_device_all_layer(
+        self, device_pool, host_indices, device_indices, io_backend
+    ) -> None:
+        """
+        Backup KV data from the device memory pool to the host memory pool for all layers.
+        """
+        raise NotImplementedError()
+
    @abc.abstractmethod
    def get_flat_data_page(self, index) -> torch.Tensor:
        """
@@ -238,11 +274,30 @@ class MHATokenToKVPoolHost(HostKVCache):
        host_to_device_ratio: float,
        host_size: int,
        page_size: int,
+        layout: str,
        pin_memory: bool = True,
        device: str = "cpu",
    ):
        super().__init__(
-            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
+            device_pool,
+            host_to_device_ratio,
+            host_size,
+            page_size,
+            layout,
+            pin_memory,
+            device,
        )
+        self.k_data_refs = [self.k_buffer[i] for i in range(self.layer_num)]
+        self.v_data_refs = [self.v_buffer[i] for i in range(self.layer_num)]
+        self.k_data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.k_data_refs],
+            dtype=torch.uint64,
+            device=self.device_pool.device,
+        )
+        self.v_data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.v_data_refs],
+            dtype=torch.uint64,
+            device=self.device_pool.device,
+        )

    def get_size_per_token(self):
@@ -253,16 +308,128 @@ class MHATokenToKVPoolHost(HostKVCache):
        return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2

    def init_kv_buffer(self):
+        if self.layout == "layer_first":
+            dims = (2, self.layer_num, self.size, self.head_num, self.head_dim)
+        elif self.layout == "page_first":
+            dims = (2, self.size, self.layer_num, self.head_num, self.head_dim)
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+        self.token_stride_size = self.head_num * self.head_dim * self.dtype.itemsize
+        self.layout_dim = self.token_stride_size * self.layer_num
+
        return torch.empty(
-            (2, self.layer_num, self.size, self.head_num, self.head_dim),
+            dims,
            dtype=self.dtype,
            device=self.device,
            pin_memory=self.pin_memory,
        )

-    # todo, page first memory layout
+    @property
+    def k_buffer(self):
+        return self.kv_buffer[0]
+
+    @property
+    def v_buffer(self):
+        return self.kv_buffer[1]
+
+    def load_to_device_per_layer(
+        self,
+        device_pool,
+        host_indices,
+        device_indices,
+        layer_id,
+        io_backend,
+    ):
+        if io_backend == "kernel":
+            if self.layout == "layer_first":
+                transfer_kv_per_layer(
+                    src_k=self.k_buffer[layer_id],
+                    dst_k=device_pool.k_buffer[layer_id],
+                    src_v=self.v_buffer[layer_id],
+                    dst_v=device_pool.v_buffer[layer_id],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    item_size=self.token_stride_size,
+                )
+            elif self.layout == "page_first":
+                transfer_kv_per_layer_pf_lf(
+                    src_k=self.k_buffer,
+                    dst_k=device_pool.k_buffer[layer_id],
+                    src_v=self.v_buffer,
+                    dst_v=device_pool.v_buffer[layer_id],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    item_size=self.token_stride_size,
+                    src_layout_dim=self.layout_dim,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
+        elif io_backend == "direct":
+            assert (
+                self.layout == "layer_first"
+            ), f"Direct IO backend only supports layer_first layout."
+            transfer_kv_direct(
+                src_layers=[self.k_buffer[layer_id], self.v_buffer[layer_id]],
+                dst_layers=[
+                    device_pool.k_buffer[layer_id],
+                    device_pool.v_buffer[layer_id],
+                ],
+                src_indices=host_indices,
+                dst_indices=device_indices,
+                page_size=self.page_size,
+            )
+        else:
+            raise ValueError(f"Unsupported IO backend: {io_backend}")
+
+    def backup_from_device_all_layer(
+        self, device_pool, host_indices, device_indices, io_backend
+    ):
+        if io_backend == "kernel":
+            if self.layout == "layer_first":
+                transfer_kv_all_layer(
+                    src_k_layers=device_pool.k_data_ptrs,
+                    dst_k_layers=self.k_data_ptrs,
+                    src_v_layers=device_pool.v_data_ptrs,
+                    dst_v_layers=self.v_data_ptrs,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    item_size=self.token_stride_size,
+                    num_layers=self.layer_num,
+                )
+            elif self.layout == "page_first":
+                transfer_kv_all_layer_lf_pf(
+                    src_k_layers=device_pool.k_data_ptrs,
+                    dst_k=self.k_buffer,
+                    src_v_layers=device_pool.v_data_ptrs,
+                    dst_v=self.v_buffer,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    item_size=self.token_stride_size,
+                    dst_layout_dim=self.layout_dim,
+                    num_layers=self.layer_num,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
+        elif io_backend == "direct":
+            assert (
+                self.layout == "layer_first"
+            ), f"Direct IO backend only supports layer_first layout."
+            transfer_kv_direct(
+                src_layers=device_pool.k_buffer + device_pool.v_buffer,
+                dst_layers=self.k_data_refs + self.v_data_refs,
+                src_indices=device_indices,
+                dst_indices=host_indices,
+                page_size=self.page_size,
+            )
+        else:
+            raise ValueError(f"Unsupported IO backend: {io_backend}")

    def get_flat_data_page(self, index) -> torch.Tensor:
-        return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
+        if self.layout == "layer_first":
+            return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
+        elif self.layout == "page_first":
+            return self.kv_buffer[:, index : index + self.page_size, :, :, :].flatten()
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
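token_stride_size and layout_dim computed in init_kv_buffer are the two strides the page-first kernels need: bytes per token within one layer, and bytes per token across all layers. A sketch of the byte-offset arithmetic they imply (illustrative numbers; offsets are within the K half of the buffer):

head_num, head_dim, layer_num, itemsize = 8, 128, 32, 2
token_stride_size = head_num * head_dim * itemsize  # bytes per token, one layer
layout_dim = token_stride_size * layer_num          # bytes per token, all layers

def page_first_offset(token_idx: int, layer_id: int) -> int:
    # buffer dims: (2, size, layer_num, head_num, head_dim)
    return token_idx * layout_dim + layer_id * token_stride_size

def layer_first_offset(token_idx: int, layer_id: int, size: int) -> int:
    # buffer dims: (2, layer_num, size, head_num, head_dim)
    return (layer_id * size + token_idx) * token_stride_size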
    def get_dummy_flat_data_page(self) -> torch.Tensor:
        return torch.zeros(
@@ -273,13 +440,24 @@ class MHATokenToKVPoolHost(HostKVCache):
        ).flatten()

    def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
-        self.kv_buffer[:, :, index : index + self.page_size, :, :] = data_page.reshape(
-            2,
-            self.layer_num,
-            self.page_size,
-            self.head_num,
-            self.head_dim,
-        )
+        if self.layout == "layer_first":
+            self.kv_buffer[:, :, index : index + self.page_size, :, :] = (
+                data_page.reshape(
+                    2,
+                    self.layer_num,
+                    self.page_size,
+                    self.head_num,
+                    self.head_dim,
+                )
+            )
+        elif self.layout == "page_first":
+            self.kv_buffer[:, index : index + self.page_size, :, :, :] = (
+                data_page.reshape(
+                    2, self.page_size, self.layer_num, self.head_num, self.head_dim
+                )
+            )
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")

    def get_buffer_meta(self, keys, indices):
        ptr_list = []
@@ -318,14 +496,6 @@ class MHATokenToKVPoolHost(HostKVCache):
        element_size_list = [element_size] * len(key_list)
        return key_list, ptr_list, element_size_list

-    @property
-    def k_buffer(self):
-        return self.kv_buffer[0]
-
-    @property
-    def v_buffer(self):
-        return self.kv_buffer[1]
-

 class MLATokenToKVPoolHost(HostKVCache):
    device_pool: MLATokenToKVPool
@@ -336,11 +506,24 @@ class MLATokenToKVPoolHost(HostKVCache):
        host_to_device_ratio: float,
        host_size: int,
        page_size: int,
+        layout: str,
        pin_memory: bool = True,
        device: str = "cpu",
    ):
        super().__init__(
-            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
+            device_pool,
+            host_to_device_ratio,
+            host_size,
+            page_size,
+            layout,
+            pin_memory,
+            device,
        )
+        self.data_refs = [self.kv_buffer[i] for i in range(self.layer_num)]
+        self.data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.data_refs],
+            dtype=torch.uint64,
+            device=self.device_pool.device,
+        )

    def get_size_per_token(self):
@@ -356,20 +539,115 @@ class MLATokenToKVPoolHost(HostKVCache):
        )

    def init_kv_buffer(self):
-        return torch.empty(
-            (
+        if self.layout == "layer_first":
+            dims = (
                self.layer_num,
                self.size,
                1,
                self.kv_lora_rank + self.qk_rope_head_dim,
-            ),
+            )
+        elif self.layout == "page_first":
+            dims = (
+                self.size,
+                self.layer_num,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+        self.token_stride_size = (
+            self.kv_lora_rank + self.qk_rope_head_dim
+        ) * self.dtype.itemsize
+        self.layout_dim = self.token_stride_size * self.layer_num
+
+        return torch.empty(
+            dims,
            dtype=self.dtype,
            device=self.device,
            pin_memory=self.pin_memory,
        )
+    def load_to_device_per_layer(
+        self, device_pool, host_indices, device_indices, layer_id, io_backend
+    ):
+        if io_backend == "kernel":
+            if self.layout == "layer_first":
+                transfer_kv_per_layer_mla(
+                    src=self.kv_buffer[layer_id],
+                    dst=device_pool.kv_buffer[layer_id],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    item_size=self.token_stride_size,
+                )
+            elif self.layout == "page_first":
+                transfer_kv_per_layer_mla_pf_lf(
+                    src=self.kv_buffer,
+                    dst=device_pool.kv_buffer[layer_id],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    item_size=self.token_stride_size,
+                    src_layout_dim=self.layout_dim,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
+        elif io_backend == "direct":
+            assert (
+                self.layout == "layer_first"
+            ), f"Direct IO backend only supports layer_first layout."
+            transfer_kv_direct(
+                src_layers=[self.kv_buffer[layer_id]],
+                dst_layers=[device_pool.kv_buffer[layer_id]],
+                src_indices=host_indices,
+                dst_indices=device_indices,
+                page_size=self.page_size,
+            )
+
+    def backup_from_device_all_layer(
+        self, device_pool, host_indices, device_indices, io_backend
+    ):
+        if io_backend == "kernel":
+            if self.layout == "layer_first":
+                transfer_kv_all_layer_mla(
+                    src_layers=device_pool.data_ptrs,
+                    dst_layers=self.data_ptrs,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    item_size=self.token_stride_size,
+                    num_layers=self.layer_num,
+                )
+            elif self.layout == "page_first":
+                transfer_kv_all_layer_mla_lf_pf(
+                    src_layers=device_pool.data_ptrs,
+                    dst_k=self.kv_buffer,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    item_size=self.token_stride_size,
+                    dst_layout_dim=self.layout_dim,
+                    num_layers=self.layer_num,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
+        elif io_backend == "direct":
+            assert (
+                self.layout == "layer_first"
+            ), f"Direct IO backend only supports layer_first layout."
+            transfer_kv_direct(
+                src_layers=device_pool.kv_buffer,
+                dst_layers=self.data_refs,
+                src_indices=device_indices,
+                dst_indices=host_indices,
+                page_size=self.page_size,
+            )
+        else:
+            raise ValueError(f"Unsupported IO backend: {io_backend}")

    def get_flat_data_page(self, index) -> torch.Tensor:
-        return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
+        if self.layout == "layer_first":
+            return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
+        elif self.layout == "page_first":
+            return self.kv_buffer[index : index + self.page_size, :, :, :].flatten()
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
    def get_dummy_flat_data_page(self) -> torch.Tensor:
        return torch.zeros(
@@ -385,12 +663,22 @@ class MLATokenToKVPoolHost(HostKVCache):
        ).flatten()

    def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
-        self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
-            self.layer_num,
-            self.page_size,
-            1,
-            self.kv_lora_rank + self.qk_rope_head_dim,
-        )
+        if self.layout == "layer_first":
+            self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
+                self.layer_num,
+                self.page_size,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
+        elif self.layout == "page_first":
+            self.kv_buffer[index : index + self.page_size, :, :, :] = data_page.reshape(
+                self.page_size,
+                self.layer_num,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")

    def get_buffer_meta(self, keys, indices):
        ptr_list = []
......
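The MLA host pool repeats the same pattern over a single fused KV buffer, with a per-token stride of (kv_lora_rank + qk_rope_head_dim) * itemsize. A worked example with DeepSeek-style numbers (illustrative, not taken from the diff):

kv_lora_rank, qk_rope_head_dim, layer_num, itemsize = 512, 64, 61, 2
token_stride_size = (kv_lora_rank + qk_rope_head_dim) * itemsize  # 1152 bytes
layout_dim = token_stride_size * layer_num                        # 70272 bytes
# Under page_first, a page of P tokens packs all layers contiguously and can
# be transferred as one contiguous region of P * layout_dim bytes.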
@@ -198,7 +198,8 @@ class ServerArgs:
    hicache_ratio: float = 2.0
    hicache_size: int = 0
    hicache_write_policy: str = "write_through_selective"
-    hicache_io_backend: str = ""
+    hicache_io_backend: str = "kernel"
+    hicache_mem_layout: str = "layer_first"
    hicache_storage_backend: Optional[str] = None

    # Double Sparsity
@@ -1487,6 +1488,14 @@ class ServerArgs:
            default=ServerArgs.hicache_io_backend,
            help="The IO backend for KV cache transfer between CPU and GPU",
        )
+        parser.add_argument(
+            "--hicache-mem-layout",
+            type=str,
+            choices=["layer_first", "page_first"],
+            default=ServerArgs.hicache_mem_layout,
+            help="The layout of host memory pool for hierarchical cache.",
+        )
        parser.add_argument(
            "--hicache-storage-backend",
            type=str,
......
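A hedged usage sketch tying the new knobs together (flag names, choices, and defaults from this diff; the model path and other flags are illustrative). With --hicache-io-backend direct, a page_first request would be downgraded to layer_first by HiRadixCache as shown above:

# python -m sglang.launch_server --model-path <model> \
#     --enable-hierarchical-cache \
#     --hicache-io-backend kernel \
#     --hicache-mem-layout page_first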