Unverified commit 51caee74, authored by Zhiqiang Xie, committed by GitHub

Host memory pool for hierarchical caching (#2771)

parent 58f9060e
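
This change adds flat-data copy helpers to MHATokenToKVPool, a CPU-resident MLATokenToKVPoolHost pool whose slots move through a small state machine (IDLE, RESERVED, PROTECTED, SYNCED, BACKUP), and a debug_timing decorator in sglang.srt.utils. As a back-of-the-envelope sizing aid, the sketch below applies the same per-token formula used by the new host pool; the model shapes and pool sizes are hypothetical, not taken from this commit.

# Rough sizing of the host KV buffer, mirroring the formula in MLATokenToKVPoolHost.__init__.
# All numbers below are assumed for illustration; the constructor additionally refuses the
# allocation unless it fits in available host RAM minus a 10 GB reserve.
head_num, head_dim, layer_num = 32, 128, 32  # assumed Llama-7B-like shapes
itemsize = 2                                 # e.g. float16 / bfloat16
device_tokens = 100_000                      # assumed device pool size
host_to_device_ratio = 2.0

size_per_token = head_dim * head_num * layer_num * itemsize * 2  # K and V
host_tokens = int(device_tokens * host_to_device_ratio)
host_bytes = host_tokens * size_per_token

print(f"{size_per_token / 2**10:.0f} KiB per token, "
      f"{host_bytes / 2**30:.1f} GiB host buffer")  # 512 KiB per token, 97.7 GiB
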
@@ -22,12 +22,16 @@ BaseTokenToKVPool maps a token location to its KV cache data.
"""

import logging
import threading
from enum import IntEnum
from functools import wraps
from typing import List, Tuple, Union
import psutil
import torch

from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import debug_timing, get_compiler_backend

logger = logging.getLogger(__name__)
@@ -213,6 +217,26 @@ class MHATokenToKVPool(BaseTokenToKVPool):
        del self.k_buffer
        del self.v_buffer
    # TODO: different memory layout
    def get_flat_data(self, indices):
        # prepare a large chunk of contiguous data for efficient transfer
        flatten = torch.stack(
            [
                torch.stack([self.k_buffer[i][indices] for i in range(self.layer_num)]),
                torch.stack([self.v_buffer[i][indices] for i in range(self.layer_num)]),
            ]
        )
        return flatten

    @debug_timing
    def transfer(self, indices, flat_data):
        # transfer prepared data from host to device
        flat_data = flat_data.to(device=self.device, non_blocking=False)
        k_data, v_data = flat_data[0], flat_data[1]
        for i in range(self.layer_num):
            self.k_buffer[i][indices] = k_data[i]
            self.v_buffer[i][indices] = v_data[i]
    def get_key_buffer(self, layer_id: int):
        if self.store_dtype != self.dtype:
            return self.k_buffer[layer_id].view(self.dtype)

@@ -361,3 +385,184 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
        self.k_buffer[layer_id][loc] = cache_k
        self.v_buffer[layer_id][loc] = cache_v
        self.label_buffer[layer_id][loc] = cache_label
class MemoryStateInt(IntEnum):
    IDLE = 0  # slot is free
    RESERVED = 1  # slot is allocated but holds no valid data yet
    PROTECTED = 2  # an I/O operation on the slot is in flight
    SYNCED = 3  # host copy is in sync with the device copy
    BACKUP = 4  # host copy serves as the backup copy
def synchronized(func):
    # serialize access to the pool's allocation metadata and slot states
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        with self.lock:
            return func(self, *args, **kwargs)

    return wrapper
class MLATokenToKVPoolHost:
    def __init__(
        self,
        device_pool: MHATokenToKVPool,
        host_to_device_ratio: float = 2.0,
        pin_memory: bool = False,  # no need to use pin memory with the double buffering
        device: str = "cpu",
    ):
        assert (
            host_to_device_ratio >= 1
        ), "The host memory should be at least as large as the device memory with the current protocol"
        # TODO: other ways of configuring the size

        self.device_pool = device_pool
        self.host_to_device_ratio = host_to_device_ratio
        self.pin_memory = pin_memory
        self.device = device

        self.size = int(device_pool.size * host_to_device_ratio)
        self.dtype = device_pool.store_dtype
        self.head_num = device_pool.head_num
        self.head_dim = device_pool.head_dim
        self.layer_num = device_pool.layer_num
        self.size_per_token = (
            self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
        )

        # Verify there is enough available host memory.
        host_mem = psutil.virtual_memory()
        requested_bytes = self.size * self.size_per_token
        # preserve at least 10GB for other usage
        ten_gb = 10 * (1024**3)
        if requested_bytes > host_mem.available - ten_gb:
            raise ValueError(
                f"Not enough host memory available. Requesting "
                f"{requested_bytes / 1e9:.2f} GB but only have "
                f"{host_mem.available / 1e9:.2f} GB free. Please reduce the "
                f"size of the hierarchical cache."
            )
        else:
            logger.info(
                f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
            )

        self.kv_buffer = torch.empty(
            (2, self.layer_num, self.size, self.head_num, self.head_dim),
            dtype=self.dtype,
            device=self.device,
            pin_memory=self.pin_memory,
        )

        # Initialize memory states and tracking structures.
        self.mem_state = torch.zeros(
            (self.size,), dtype=torch.uint8, device=self.device
        )
        self.free_slots = torch.arange(self.size, dtype=torch.int32)
        self.can_use_mem_size = self.size

        # A lock for synchronized operations on memory allocation and state transitions.
        self.lock = threading.RLock()
    def get_flat_data(self, indices):
        return self.kv_buffer[:, :, indices]

    @debug_timing
    def transfer(self, indices, flat_data):
        # backup prepared data from device to host
        self.kv_buffer[:, :, indices] = flat_data.to(
            device=self.device, non_blocking=False
        )
    @synchronized
    def clear(self):
        self.mem_state.fill_(0)
        self.can_use_mem_size = self.size
        self.free_slots = torch.arange(self.size, dtype=torch.int32)

    @synchronized
    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
        assert len(indices) > 0, "The indices should not be empty"
        states = self.mem_state[indices]
        assert (
            states == states[0]
        ).all(), f"The memory slots should all have the same state, but got {states}"
        return MemoryStateInt(states[0].item())

    @synchronized
    def alloc(self, need_size: int) -> torch.Tensor:
        if need_size > self.can_use_mem_size:
            return None

        # TODO: de-fragmentation
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]

        self.mem_state[select_index] = MemoryStateInt.RESERVED
        self.can_use_mem_size -= need_size

        return select_index

    @synchronized
    def is_reserved(self, indices: torch.Tensor) -> bool:
        return self.get_state(indices) == MemoryStateInt.RESERVED

    @synchronized
    def is_protected(self, indices: torch.Tensor) -> bool:
        return self.get_state(indices) == MemoryStateInt.PROTECTED

    @synchronized
    def is_synced(self, indices: torch.Tensor) -> bool:
        return self.get_state(indices) == MemoryStateInt.SYNCED

    @synchronized
    def is_backup(self, indices: torch.Tensor) -> bool:
        return self.get_state(indices) == MemoryStateInt.BACKUP

    @synchronized
    def update_backup(self, indices: torch.Tensor):
        assert self.is_synced(indices), (
            f"The host memory slots should be in SYNCED state before turning into BACKUP. "
            f"Current state: {self.get_state(indices)}"
        )
        self.mem_state[indices] = MemoryStateInt.BACKUP

    @synchronized
    def update_synced(self, indices: torch.Tensor):
        self.mem_state[indices] = MemoryStateInt.SYNCED

    @synchronized
    def protect_write(self, indices: torch.Tensor):
        assert self.is_reserved(indices), (
            f"The host memory slots should be RESERVED before write operations. "
            f"Current state: {self.get_state(indices)}"
        )
        self.mem_state[indices] = MemoryStateInt.PROTECTED

    @synchronized
    def protect_load(self, indices: torch.Tensor):
        assert self.is_backup(indices), (
            f"The host memory slots should be in BACKUP state before load operations. "
            f"Current state: {self.get_state(indices)}"
        )
        self.mem_state[indices] = MemoryStateInt.PROTECTED

    @synchronized
    def complete_io(self, indices: torch.Tensor):
        assert self.is_protected(indices), (
            f"The host memory slots should be PROTECTED during I/O operations. "
            f"Current state: {self.get_state(indices)}"
        )
        self.mem_state[indices] = MemoryStateInt.SYNCED

    def available_size(self):
        return len(self.free_slots)

    @synchronized
    def free(self, indices: torch.Tensor) -> int:
        self.mem_state[indices] = MemoryStateInt.IDLE
        self.free_slots = torch.concat([self.free_slots, indices])
        self.can_use_mem_size += len(indices)
        return len(indices)
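
Host slots move through a small state machine: alloc marks them RESERVED, protect_write and protect_load mark them PROTECTED while a copy is in flight, complete_io returns them to SYNCED, update_backup marks them BACKUP once the host copy can stand on its own, and free returns them to IDLE. A minimal sketch of one write-back/load-back cycle driven by a caller is shown below; the pool objects, index tensor, and the offload_and_reload helper are illustrative assumptions, not code from this commit.

# Illustrative sketch only: one offload/load cycle through the host pool state machine.
# `device_pool`, `host_pool`, and `device_indices` are assumed to exist already.
def offload_and_reload(device_pool, host_pool, device_indices):
    # reserve host slots for the tokens being written back
    host_indices = host_pool.alloc(len(device_indices))
    if host_indices is None:
        return None  # not enough free host slots

    # device -> host write-back: RESERVED -> PROTECTED -> SYNCED -> BACKUP
    host_pool.protect_write(host_indices)
    host_pool.transfer(host_indices, device_pool.get_flat_data(device_indices))
    host_pool.complete_io(host_indices)
    host_pool.update_backup(host_indices)

    # host -> device load-back: BACKUP -> PROTECTED -> SYNCED
    host_pool.protect_load(host_indices)
    device_pool.transfer(device_indices, host_pool.get_flat_data(host_indices))
    host_pool.complete_io(host_indices)

    # release the host slots once they are no longer needed
    host_pool.free(host_indices)
    return host_indices
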
@@ -1349,3 +1349,27 @@ class MultiprocessingSerializer:
    @staticmethod
    def deserialize(data):
        return ForkingPickler.loads(data)
def debug_timing(func):
    # TODO: replace with a more organized instrumentation
    def wrapper(*args, **kwargs):
        if logger.isEnabledFor(logging.DEBUG):
            tic = torch.cuda.Event(enable_timing=True)
            toc = torch.cuda.Event(enable_timing=True)
            tic.record()
            result = func(*args, **kwargs)
            toc.record()
            torch.cuda.synchronize()  # Ensure all CUDA operations are complete
            elapsed = tic.elapsed_time(toc)
            indices = kwargs.get("indices", args[1] if len(args) > 1 else None)
            num_tokens = len(indices) if indices is not None else 0
            throughput = num_tokens / elapsed * 1000 if elapsed > 0 else 0
            logger.debug(
                f"Transfer time: {elapsed} ms, throughput: {throughput} tokens/s"
            )
            return result
        else:
            return func(*args, **kwargs)

    return wrapper
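
debug_timing only measures when DEBUG logging is enabled, timing the wrapped call with CUDA events and inferring the token count from the indices argument (passed by keyword, or as the second positional argument after self). A hypothetical usage sketch, assuming a CUDA device is available; ToyPool is an illustrative class, not part of this commit:

import logging
import torch

from sglang.srt.utils import debug_timing

logging.basicConfig(level=logging.DEBUG)

class ToyPool:
    def __init__(self, size: int, dim: int):
        self.buffer = torch.zeros((size, dim), device="cuda")

    @debug_timing
    def transfer(self, indices, flat_data):
        # `indices` is the second positional argument, so the decorator can
        # derive the number of tokens moved for the throughput log line
        self.buffer[indices] = flat_data.to(self.buffer.device, non_blocking=False)

pool = ToyPool(16, 8)
pool.transfer(torch.arange(4), torch.ones(4, 8))
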