Unverified commit 2fc824b8, authored by Zhiqiang Xie, committed by GitHub

Kernels for efficient KV cache IO (#7313)

parent 253454de
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
-import concurrent.futures
 import logging
 import math
 import threading
@@ -169,12 +168,23 @@ class HiCacheController:
         page_size: int,
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
+        io_backend: str = "",
     ):
         self.mem_pool_device_allocator = token_to_kv_pool_allocator
         self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
         self.mem_pool_host = mem_pool_host
         self.write_policy = write_policy
         self.page_size = page_size
+        # using kernel for small page KV cache transfer and DMA for large pages
+        if not io_backend:
+            IO_BACKEND_PAGE_SIZE_THRESHOLD = 64
+            self.io_backend = (
+                "direct"
+                if self.page_size >= IO_BACKEND_PAGE_SIZE_THRESHOLD
+                else "kernel"
+            )
+        else:
+            self.io_backend = io_backend

         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
@@ -203,12 +213,7 @@ class HiCacheController:
         self.load_stream = torch.cuda.Stream()

         self.write_thread = threading.Thread(
-            target=(
-                self.write_thread_func_buffer
-                if self.page_size == 1
-                else self.write_thread_func_direct
-            ),
-            daemon=True,
+            target=self.write_thread_func_direct, daemon=True
         )
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
@@ -229,12 +234,7 @@ class HiCacheController:
         self.ack_load_queue.queue.clear()

         self.write_thread = threading.Thread(
-            target=(
-                self.write_thread_func_buffer
-                if self.page_size == 1
-                else self.write_thread_func_direct
-            ),
-            daemon=True,
+            target=self.write_thread_func_direct, daemon=True
         )
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
@@ -281,6 +281,15 @@ class HiCacheController:
         )
         return device_indices

+    def move_indices(self, host_indices, device_indices):
+        # move indices to GPU if using kernels, to host if using direct indexing
+        if self.io_backend == "kernel":
+            return host_indices.to(self.mem_pool_device.device), device_indices
+        elif self.io_backend == "direct":
+            return host_indices, device_indices.cpu()
+        else:
+            raise ValueError(f"Unsupported io backend: {self.io_backend}")
+
     def write_thread_func_direct(self):
         """
         Directly write through KV caches to host memory without buffering.
@@ -289,10 +298,14 @@ class HiCacheController:
         while not self.stop_event.is_set():
             try:
                 operation = self.write_queue.get(block=True, timeout=1)
-                self.mem_pool_host.write_page_all_layers(
-                    operation.host_indices,
-                    operation.device_indices,
-                    self.mem_pool_device,
+                host_indices, device_indices = self.move_indices(
+                    operation.host_indices, operation.device_indices
+                )
+                self.mem_pool_device.backup_to_host_all_layer(
+                    self.mem_pool_host,
+                    host_indices,
+                    device_indices,
+                    self.io_backend,
                 )
                 self.write_stream.synchronize()
                 self.mem_pool_host.complete_io(operation.host_indices)
@@ -304,27 +317,6 @@ class HiCacheController:
             except Exception as e:
                 logger.error(e)

-    def load_thread_func_direct(self):
-        """
-        Directly load KV caches from host memory to device memory without buffering.
-        """
-        torch.cuda.set_stream(self.load_stream)
-        while not self.stop_event.is_set():
-            try:
-                operation = self.load_queue.get(block=True, timeout=1)
-                operation.data = self.mem_pool_host.get_flat_data(
-                    operation.host_indices
-                )
-                self.mem_pool_device.transfer(operation.device_indices, operation.data)
-                self.mem_pool_host.complete_io(operation.host_indices)
-                for node_id in operation.node_ids:
-                    if node_id != 0:
-                        self.ack_load_queue.put(node_id)
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(e)
-
     def load_thread_func_layer_by_layer(self):
         """
         Load KV caches from host memory to device memory layer by layer.
@@ -349,22 +341,18 @@ class HiCacheController:
             # start layer-wise KV cache transfer from CPU to GPU
             self.layer_done_counter.reset()

+            host_indices, device_indices = self.move_indices(
+                batch_operation.host_indices, batch_operation.device_indices
+            )
             for i in range(self.mem_pool_host.layer_num):
-                if self.page_size == 1:
-                    flat_data = self.mem_pool_host.get_flat_data_by_layer(
-                        batch_operation.host_indices, i
-                    )
-                    self.mem_pool_device.transfer_per_layer(
-                        batch_operation.device_indices, flat_data, i
-                    )
-                else:
-                    self.mem_pool_host.load_page_per_layer(
-                        batch_operation.host_indices,
-                        batch_operation.device_indices,
-                        self.mem_pool_device,
-                        i,
-                    )
-                self.load_stream.synchronize()
+                self.mem_pool_device.load_from_host_per_layer(
+                    self.mem_pool_host,
+                    host_indices,
+                    device_indices,
+                    i,
+                    self.io_backend,
+                )
+                self.load_stream.synchronize()
                 self.layer_done_counter.increment()

             self.mem_pool_host.complete_io(batch_operation.host_indices)
@@ -372,148 +360,6 @@ class HiCacheController:
                 if node_id != 0:
                     self.ack_load_queue.put(node_id)

-    def write_aux_func(self, no_wait=False):
-        """
-        Auxiliary function to prepare the buffer for write operations.
-        """
-        torch.cuda.set_stream(self.write_stream)
-
-        def _to_op(op_):
-            assert op_.device_indices.is_cuda, "Device indices should be on GPU"
-            op_.data = self.mem_pool_device.get_flat_data(op_.device_indices).to(
-                self.mem_pool_host.device
-            )
-            self.write_buffer.put(op_)
-            return op_
-
-        buffer = None
-        while not self.stop_event.is_set():
-            try:
-                operation = self.write_queue.get(block=True, timeout=1)
-                factor = (
-                    len(operation.device_indices) // self.write_buffer.max_buffer_size
-                )
-
-                if factor >= 1:
-                    if buffer is not None:
-                        _to_op(buffer)
-                        buffer = None
-
-                    if factor < 2:
-                        _to_op(operation)
-                    else:
-                        split_ops = operation.split(factor)
-                        for op_ in split_ops:
-                            _to_op(op_)
-                    continue
-
-                if buffer is None:
-                    buffer = operation
-                else:
-                    buffer.merge(operation)
-                if (
-                    no_wait
-                    or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
-                    or self.write_queue.empty()
-                    or self.write_buffer.empty()
-                ):
-                    _to_op(buffer)
-                    buffer = None
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(e)
-
-    def load_aux_func(self):
-        """
-        Auxiliary function to prepare the buffer for load operations.
-        """
-
-        def _pin_op(op_, put=True):
-            op_.data = (
-                self.mem_pool_host.get_flat_data(op_.host_indices)
-                .contiguous()
-                .pin_memory()
-            )
-            if put:
-                self.load_buffer.put(op_)
-            return op_
-
-        buffer = None
-        while not self.stop_event.is_set():
-            try:
-                operation = self.load_queue.get(block=True, timeout=1)
-                factor = len(operation.host_indices) // self.load_buffer.max_buffer_size
-
-                if factor >= 1:
-                    if buffer is not None:
-                        _pin_op(buffer)
-                        buffer = None
-
-                    if factor < 2:
-                        _pin_op(operation)
-                    else:
-                        split_ops = operation.split(factor)
-                        split_args = [(op_, True) for op_ in split_ops[:-1]]
-                        split_args.append((split_ops[-1], False))
-                        # Spawn threads to pin each op concurrently
-                        with concurrent.futures.ThreadPoolExecutor() as executor:
-                            pinned_ops = list(
-                                executor.map(
-                                    lambda x: _pin_op(x[0], put=x[1]), split_args
-                                )
-                            )
-                        # preserve the order of last op to ensure correct ack
-                        self.load_buffer.put(pinned_ops[-1])
-                    continue
-
-                if buffer is None:
-                    buffer = operation
-                else:
-                    buffer.merge(operation)
-                if (
-                    len(buffer.host_indices) >= self.load_buffer.max_buffer_size
-                    or self.load_queue.empty()
-                    or self.load_buffer.empty()
-                ):
-                    _pin_op(buffer)
-                    buffer = None
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(e)
-
-    # todo (zhiqiang): double buffering to be deprecated
-    def write_thread_func_buffer(self):
-        aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
-        aux_thread.start()
-        while not self.stop_event.is_set():
-            operation = self.write_buffer.get()
-            if operation is None:
-                continue
-            self.mem_pool_host.assign_flat_data(operation.host_indices, operation.data)
-            self.mem_pool_host.complete_io(operation.host_indices)
-            for node_id in operation.node_ids:
-                if node_id != 0:
-                    self.ack_write_queue.put(node_id)
-        aux_thread.join()
-
-    def load_thread_func_buffer(self):
-        torch.cuda.set_stream(self.load_stream)
-        aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
-        aux_thread.start()
-        while not self.stop_event.is_set():
-            operation = self.load_buffer.get()
-            if operation is None:
-                continue
-            self.mem_pool_device.transfer(operation.device_indices, operation.data)
-            self.mem_pool_host.complete_io(operation.host_indices)
-            for node_id in operation.node_ids:
-                if node_id != 0:
-                    self.ack_load_queue.put(node_id)
-        aux_thread.join()
-
     def evict_device(
         self, device_indices: torch.Tensor, host_indices: torch.Tensor
     ) -> int:
...
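The controller changes above reduce to two small decisions: which IO backend to use when none is configured, and where the index tensors must live for each backend. Below is a minimal, self-contained sketch of that logic; pick_io_backend and the CPU-only demo values are illustrative stand-ins rather than part of the patch, and the real move_indices runs inside HiCacheController on the device pool's CUDA device.

import torch

IO_BACKEND_PAGE_SIZE_THRESHOLD = 64  # default threshold used by the controller above

def pick_io_backend(page_size: int) -> str:
    # large pages amortize page-wise DMA copies ("direct");
    # small pages favor the fused gather/scatter transfer kernels ("kernel")
    return "direct" if page_size >= IO_BACKEND_PAGE_SIZE_THRESHOLD else "kernel"

def move_indices(io_backend: str, host_indices: torch.Tensor,
                 device_indices: torch.Tensor, device: str = "cpu"):
    # "kernel": the transfer kernel dereferences both index tensors on the GPU
    # "direct": Python-level slicing of device buffers needs the device indices on the CPU
    if io_backend == "kernel":
        return host_indices.to(device), device_indices
    elif io_backend == "direct":
        return host_indices, device_indices.cpu()
    raise ValueError(f"Unsupported io backend: {io_backend}")

print(pick_io_backend(1), pick_io_backend(64))  # -> kernel direct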
@@ -591,6 +591,12 @@ class Scheduler(
                 hicache_ratio=server_args.hicache_ratio,
                 hicache_size=server_args.hicache_size,
                 hicache_write_policy=server_args.hicache_write_policy,
+                hicache_io_backend=(
+                    "direct"
+                    if server_args.attention_backend
+                    == "fa3"  # hot fix for incompatibility
+                    else server_args.hicache_io_backend
+                ),
             )
             self.tp_worker.register_hicache_layer_transfer_counter(
                 self.tree_cache.cache_controller.layer_done_counter
...
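The Scheduler wiring above also carries a temporary override: when the attention backend is fa3, the requested hicache_io_backend is ignored and "direct" is used (the diff only labels this a "hot fix for incompatibility"). A stand-alone restatement of that rule, with resolve_io_backend as a hypothetical helper name rather than anything in the patch:

def resolve_io_backend(attention_backend: str, hicache_io_backend: str) -> str:
    # per the diff comment, fa3 currently forces the "direct" backend
    if attention_backend == "fa3":
        return "direct"
    return hicache_io_backend

print(resolve_io_backend("fa3", "kernel"))         # -> direct
print(resolve_io_backend("flashinfer", "kernel"))  # -> kernel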
@@ -34,6 +34,7 @@ class HiRadixCache(RadixCache):
         hicache_ratio: float,
         hicache_size: int,
         hicache_write_policy: str,
+        hicache_io_backend: str,
     ):
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):
@@ -56,6 +57,7 @@ class HiRadixCache(RadixCache):
             page_size,
             load_cache_event=self.load_cache_event,
             write_policy=hicache_write_policy,
+            io_backend=hicache_io_backend,
         )

         # record the nodes with ongoing write through
...
@@ -34,10 +34,11 @@ import torch
 import torch.distributed as dist
 import triton
 import triton.language as tl
+from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla

 from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2
+from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2

 logger = logging.getLogger(__name__)
@@ -150,13 +151,16 @@ class KVCache(abc.ABC):
     ) -> None:
         raise NotImplementedError()

-    def get_flat_data(self, indices):
-        raise NotImplementedError()
-
-    def transfer(self, indices, flat_data):
-        raise NotImplementedError()
-
-    def transfer_per_layer(self, indices, flat_data, layer_id):
-        raise NotImplementedError()
+    @abc.abstractmethod
+    def load_from_host_per_layer(
+        self, host_pool, host_indices, device_indices, layer_id, io_backend
+    ):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def backup_to_host_all_layer(
+        self, host_pool, host_indices, device_indices, io_backend
+    ):
+        raise NotImplementedError()

     def register_layer_transfer_counter(self, layer_transfer_counter):
@@ -247,7 +251,7 @@ class MHATokenToKVPool(KVCache):
             )
             for _ in range(self.layer_num)
         ]
+        self.token_stride = self.head_num * self.head_dim
         self.data_ptrs = torch.tensor(
             [x.data_ptr() for x in self.k_buffer + self.v_buffer],
             dtype=torch.uint64,
@@ -281,24 +285,24 @@ class MHATokenToKVPool(KVCache):
         # layer_num x [seq_len, head_num, head_dim]
         # layer_num x [page_num, page_size, head_num, head_dim]
         kv_data_ptrs = [
-            self.get_key_buffer(i).data_ptr()
+            self._get_key_buffer(i).data_ptr()
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ] + [
-            self.get_value_buffer(i).data_ptr()
+            self._get_value_buffer(i).data_ptr()
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ]
         kv_data_lens = [
-            self.get_key_buffer(i).nbytes
+            self._get_key_buffer(i).nbytes
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ] + [
-            self.get_value_buffer(i).nbytes
+            self._get_value_buffer(i).nbytes
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ]
         kv_item_lens = [
-            self.get_key_buffer(i)[0].nbytes * self.page_size
+            self._get_key_buffer(i)[0].nbytes * self.page_size
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ] + [
-            self.get_value_buffer(i)[0].nbytes * self.page_size
+            self._get_value_buffer(i)[0].nbytes * self.page_size
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
@@ -341,49 +345,73 @@ class MHATokenToKVPool(KVCache):
             self.v_buffer[layer_id][chunk_indices] = v_chunk
         torch.cuda.synchronize()

-    # Todo: different memory layout
-    def get_flat_data(self, indices):
-        # prepare a large chunk of contiguous data for efficient transfer
-        flatten = torch.stack(
-            [
-                torch.stack([self.k_buffer[i][indices] for i in range(self.layer_num)]),
-                torch.stack([self.v_buffer[i][indices] for i in range(self.layer_num)]),
-            ]
-        )
-        return flatten
-
-    @debug_timing
-    def transfer(self, indices, flat_data):
-        # transfer prepared data from host to device
-        flat_data = flat_data.to(device=self.device, non_blocking=False)
-        k_data, v_data = flat_data[0], flat_data[1]
-        for i in range(self.layer_num):
-            self.k_buffer[i][indices] = k_data[i]
-            self.v_buffer[i][indices] = v_data[i]
-
-    def transfer_per_layer(self, indices, flat_data, layer_id):
-        # transfer prepared data from host to device
-        flat_data = flat_data.to(device=self.device, non_blocking=False)
-        k_data, v_data = flat_data[0], flat_data[1]
-        self.k_buffer[layer_id - self.start_layer][indices] = k_data
-        self.v_buffer[layer_id - self.start_layer][indices] = v_data
-
-    def get_key_buffer(self, layer_id: int):
-        if self.layer_transfer_counter is not None:
-            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
-        if self.store_dtype != self.dtype:
-            return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
-        return self.k_buffer[layer_id - self.start_layer]
-
-    def get_value_buffer(self, layer_id: int):
-        if self.layer_transfer_counter is not None:
-            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
-        if self.store_dtype != self.dtype:
-            return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
-        return self.v_buffer[layer_id - self.start_layer]
+    def load_from_host_per_layer(
+        self,
+        host_pool,
+        host_indices,
+        device_indices,
+        layer_id,
+        io_backend,
+    ):
+        transfer_kv_per_layer(
+            src_k=host_pool.k_buffer[layer_id],
+            dst_k=self.k_buffer[layer_id],
+            src_v=host_pool.v_buffer[layer_id],
+            dst_v=self.v_buffer[layer_id],
+            src_indices=host_indices,
+            dst_indices=device_indices,
+            io_backend=io_backend,
+            page_size=self.page_size,
+            item_size=self.token_stride,
+        )
+
+    def backup_to_host_all_layer(
+        self, host_pool, host_indices, device_indices, io_backend
+    ):
+        # todo: specialized all-layer kernels for the layer-non-contiguous memory pool
+        for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
+            if layer_id - self.start_layer >= len(host_pool.k_buffer):
+                raise ValueError(
+                    f"Layer ID {layer_id} exceeds the number of layers in host pool."
+                )
+            transfer_kv_per_layer(
+                src_k=self.k_buffer[layer_id],
+                dst_k=host_pool.k_buffer[layer_id],
+                src_v=self.v_buffer[layer_id],
+                dst_v=host_pool.v_buffer[layer_id],
+                src_indices=device_indices,
+                dst_indices=host_indices,
+                io_backend=io_backend,
+                page_size=self.page_size,
+                item_size=self.token_stride,
+            )
+
+    def _get_key_buffer(self, layer_id: int):
+        # internal accessor; skips layer-transfer synchronization
+        if self.store_dtype != self.dtype:
+            return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.k_buffer[layer_id - self.start_layer]
+
+    def get_key_buffer(self, layer_id: int):
+        # note: get_key_buffer is hooked with synchronization for layer-wise KV cache loading;
+        # it is supposed to be used only by the attention backend, not for inspection purposes.
+        # The same applies to get_value_buffer and get_kv_buffer.
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+        return self._get_key_buffer(layer_id)
+
+    def _get_value_buffer(self, layer_id: int):
+        # internal accessor; skips layer-transfer synchronization
+        if self.store_dtype != self.dtype:
+            return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.v_buffer[layer_id - self.start_layer]
+
+    def get_value_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+        return self._get_value_buffer(layer_id)

     def get_kv_buffer(self, layer_id: int):
         return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
@@ -761,6 +789,7 @@ class MLATokenToKVPool(KVCache):
             for _ in range(layer_num)
         ]
+        self.token_stride = kv_lora_rank + qk_rope_head_dim

         self.layer_transfer_counter = None
         kv_size = self.get_kv_size_bytes()
@@ -846,21 +875,37 @@ class MLATokenToKVPool(KVCache):
             self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
         )

-    def get_flat_data(self, indices):
-        # prepare a large chunk of contiguous data for efficient transfer
-        return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)])
-
-    @debug_timing
-    def transfer(self, indices, flat_data):
-        # transfer prepared data from host to device
-        flat_data = flat_data.to(device=self.device, non_blocking=False)
-        for i in range(self.layer_num):
-            self.kv_buffer[i][indices] = flat_data[i]
-
-    def transfer_per_layer(self, indices, flat_data, layer_id):
-        # transfer prepared data from host to device
-        flat_data = flat_data.to(device=self.device, non_blocking=False)
-        self.kv_buffer[layer_id - self.start_layer][indices] = flat_data
+    def load_from_host_per_layer(
+        self, host_pool, host_indices, device_indices, layer_id, io_backend
+    ):
+        transfer_kv_per_layer_mla(
+            src=host_pool.kv_buffer[layer_id],
+            dst=self.kv_buffer[layer_id],
+            src_indices=host_indices,
+            dst_indices=device_indices,
+            io_backend=io_backend,
+            page_size=self.page_size,
+            item_size=self.token_stride,
+        )
+
+    def backup_to_host_all_layer(
+        self, host_pool, host_indices, device_indices, io_backend
+    ):
+        # todo: specialized all-layer kernels for the layer-non-contiguous memory pool
+        for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
+            if layer_id - self.start_layer >= len(host_pool.kv_buffer):
+                raise ValueError(
+                    f"Layer ID {layer_id} exceeds the number of layers in host pool."
+                )
+            transfer_kv_per_layer_mla(
+                src=self.kv_buffer[layer_id],
+                dst=host_pool.kv_buffer[layer_id],
+                src_indices=device_indices,
+                dst_indices=host_indices,
+                io_backend=io_backend,
+                page_size=self.page_size,
+                item_size=self.token_stride,
+            )

     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
@@ -1046,14 +1091,19 @@ class DoubleSparseTokenToKVPool(KVCache):
         self.v_buffer[layer_id - self.start_layer][loc] = cache_v
         self.label_buffer[layer_id - self.start_layer][loc] = cache_label

-    def get_flat_data(self, indices):
-        pass
-
-    def transfer(self, indices, flat_data):
-        pass
-
-    def transfer_per_layer(self, indices, flat_data, layer_id):
-        pass
+    def load_from_host_per_layer(
+        self, host_pool, host_indices, device_indices, layer_id, io_backend
+    ):
+        raise NotImplementedError(
+            "HiCache not supported for DoubleSparseTokenToKVPool."
+        )
+
+    def backup_to_host_all_layer(
+        self, host_pool, host_indices, device_indices, io_backend
+    ):
+        raise NotImplementedError(
+            "HiCache not supported for DoubleSparseTokenToKVPool."
+        )


 @triton.jit
...
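The device pools above now delegate all host/device movement to transfer_kv_per_layer and transfer_kv_per_layer_mla from sgl_kernel.kvcacheio. As a rough mental model only (a CPU-only sketch, not the sgl_kernel implementation or its exact signature), the "direct" path amounts to page-granular slice copies between the per-layer source and destination buffers; transfer_kv_per_layer_reference below is a hypothetical helper written for illustration.

import torch

def transfer_kv_per_layer_reference(src_k, dst_k, src_v, dst_v,
                                     src_indices, dst_indices, page_size):
    # src_indices/dst_indices list token slots; every page_size consecutive slots
    # form one contiguous page, so it suffices to copy page by page.
    src_pages = src_indices[::page_size]
    dst_pages = dst_indices[::page_size]
    for s, d in zip(src_pages.tolist(), dst_pages.tolist()):
        dst_k[d : d + page_size].copy_(src_k[s : s + page_size], non_blocking=True)
        dst_v[d : d + page_size].copy_(src_v[s : s + page_size], non_blocking=True)

# Example: move two 4-token pages from a "host" layer buffer to a "device" layer buffer.
page_size, tokens, heads, head_dim = 4, 16, 2, 8
host_k, host_v = torch.randn(tokens, heads, head_dim), torch.randn(tokens, heads, head_dim)
dev_k, dev_v = torch.zeros_like(host_k), torch.zeros_like(host_v)
src = torch.arange(0, 8)    # host slots 0..7  (pages 0 and 1)
dst = torch.arange(8, 16)   # device slots 8..15 (pages 2 and 3)
transfer_kv_per_layer_reference(host_k, dev_k, host_v, dev_v, src, dst, page_size)
assert torch.equal(dev_k[8:16], host_k[0:8])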
@@ -8,7 +8,6 @@ import psutil
 import torch

 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
-from sglang.srt.utils import debug_timing

 logger = logging.getLogger(__name__)
@@ -99,22 +98,6 @@ class HostKVCache(abc.ABC):
     def init_kv_buffer(self):
         raise NotImplementedError()

-    @abc.abstractmethod
-    def transfer(self, indices, flat_data):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_flat_data(self, indices):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_flat_data_by_layer(self, indices, layer_id):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def assign_flat_data(self, indices, flat_data):
-        raise NotImplementedError()
-
     @synchronized()
     def clear(self):
         # Initialize memory states and tracking structures.
@@ -243,58 +226,13 @@ class MHATokenToKVPoolHost(HostKVCache):
             pin_memory=self.pin_memory,
         )

-    @debug_timing
-    def transfer(self, indices, flat_data):
-        # backup prepared data from device to host
-        self.kv_buffer[:, :, indices] = flat_data.to(
-            device=self.device, non_blocking=False
-        )
-
-    def get_flat_data(self, indices):
-        return self.kv_buffer[:, :, indices]
-
-    def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[:, layer_id - self.start_layer, indices]
-
-    def assign_flat_data(self, indices, flat_data):
-        self.kv_buffer[:, :, indices] = flat_data
-
-    def write_page_all_layers(self, host_indices, device_indices, device_pool):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            for j in range(self.layer_num):
-                self.kv_buffer[0, j, h_index : h_index + self.page_size].copy_(
-                    device_pool.k_buffer[j][d_index : d_index + self.page_size],
-                    non_blocking=True,
-                )
-                self.kv_buffer[1, j, h_index : h_index + self.page_size].copy_(
-                    device_pool.v_buffer[j][d_index : d_index + self.page_size],
-                    non_blocking=True,
-                )
-
-    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            device_pool.k_buffer[layer_id - self.start_layer][
-                d_index : d_index + self.page_size
-            ].copy_(
-                self.kv_buffer[
-                    0, layer_id - self.start_layer, h_index : h_index + self.page_size
-                ],
-                non_blocking=True,
-            )
-            device_pool.v_buffer[layer_id - self.start_layer][
-                d_index : d_index + self.page_size
-            ].copy_(
-                self.kv_buffer[
-                    1, layer_id - self.start_layer, h_index : h_index + self.page_size
-                ],
-                non_blocking=True,
-            )
+    @property
+    def k_buffer(self):
+        return self.kv_buffer[0]
+
+    @property
+    def v_buffer(self):
+        return self.kv_buffer[1]


 class MLATokenToKVPoolHost(HostKVCache):
@@ -337,44 +275,3 @@ class MLATokenToKVPoolHost(HostKVCache):
             device=self.device,
             pin_memory=self.pin_memory,
         )
-
-    @debug_timing
-    def transfer(self, indices, flat_data):
-        # backup prepared data from device to host
-        self.kv_buffer[:, indices] = flat_data.to(
-            device=self.device, non_blocking=False
-        )
-
-    def get_flat_data(self, indices):
-        return self.kv_buffer[:, indices]
-
-    def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[layer_id - self.start_layer, indices]
-
-    def assign_flat_data(self, indices, flat_data):
-        self.kv_buffer[:, indices] = flat_data
-
-    def write_page_all_layers(self, host_indices, device_indices, device_pool):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            for j in range(self.layer_num):
-                self.kv_buffer[j, h_index : h_index + self.page_size].copy_(
-                    device_pool.kv_buffer[j][d_index : d_index + self.page_size],
-                    non_blocking=True,
-                )
-
-    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            device_pool.kv_buffer[layer_id - self.start_layer][
-                d_index : d_index + self.page_size
-            ].copy_(
-                self.kv_buffer[
-                    layer_id - self.start_layer, h_index : h_index + self.page_size
-                ],
-                non_blocking=True,
-            )
@@ -196,11 +196,13 @@ class RadixCache(BasePrefixCache):
         if self.page_size != 1:
             page_aligned_len = len(kv_indices) // self.page_size * self.page_size
-            page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
+            page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
+                dtype=torch.int64, copy=True
+            )
             self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
         else:
             page_aligned_len = len(kv_indices)
-            page_aligned_kv_indices = kv_indices.clone()
+            page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)

         # Radix Cache takes one ref in memory pool
         new_prefix_len = self.insert(
@@ -226,10 +228,12 @@ class RadixCache(BasePrefixCache):
         if self.page_size != 1:
             page_aligned_len = len(kv_indices) // self.page_size * self.page_size
-            page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
+            page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
+                dtype=torch.int64, copy=True
+            )
         else:
             page_aligned_len = len(kv_indices)
-            page_aligned_kv_indices = kv_indices.clone()
+            page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
         page_aligned_token_ids = token_ids[:page_aligned_len]

         # Radix Cache takes one ref in memory pool
...
@@ -217,6 +217,7 @@ class ServerArgs:
     hicache_ratio: float = 2.0
     hicache_size: int = 0
     hicache_write_policy: str = "write_through_selective"
+    hicache_io_backend: str = ""
     flashinfer_mla_disable_ragged: bool = False
     disable_shared_experts_fusion: bool = False
     disable_chunked_prefix_cache: bool = False
@@ -1530,6 +1531,13 @@ class ServerArgs:
             default=ServerArgs.hicache_write_policy,
             help="The write policy of hierarchical cache.",
         )
+        parser.add_argument(
+            "--hicache-io-backend",
+            type=str,
+            choices=["direct", "kernel"],
+            default=ServerArgs.hicache_io_backend,
+            help="The IO backend for KV cache transfer between CPU and GPU",
+        )
         parser.add_argument(
             "--flashinfer-mla-disable-ragged",
             action="store_true",
...
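The new --hicache-io-backend flag defaults to an empty string, which lets HiCacheController auto-select a backend from the page size; passing "direct" or "kernel" pins the choice. A minimal argparse sketch of just this flag (stand-alone, not the full ServerArgs parser):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--hicache-io-backend",
    type=str,
    choices=["direct", "kernel"],
    default="",  # empty string -> controller picks by page size (>= 64 -> direct, else kernel)
    help="The IO backend for KV cache transfer between CPU and GPU",
)

print(parser.parse_args(["--hicache-io-backend", "kernel"]).hicache_io_backend)  # -> kernel
print(parser.parse_args([]).hicache_io_backend)  # -> "" (auto-select)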