Unverified Commit a023856b authored by Lianmin Zheng, committed by GitHub

Move host memory pools into a separate file (#7200)

parent db0cc57e
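
For downstream code, the effect of this commit is purely an import-path change: the host-side pool classes now come from sglang.srt.mem_cache.memory_pool_host, while the device-side pools stay in sglang.srt.mem_cache.memory_pool. A minimal before/after sketch, mirroring the import updates in the hunks below:

# Before #7200: host pools were exported from memory_pool.py
# from sglang.srt.mem_cache.memory_pool import HostKVCache, MHATokenToKVPoolHost, MLATokenToKVPoolHost

# After #7200: device pools stay put, host pools move to memory_pool_host.py
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool, TokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool_host import (
    HostKVCache,
    MHATokenToKVPoolHost,
    MLATokenToKVPoolHost,
)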
......@@ -22,7 +22,8 @@ from typing import List, Optional
import torch
from sglang.srt.mem_cache.memory_pool import HostKVCache, TokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
logger = logging.getLogger(__name__)
......
......@@ -9,12 +9,14 @@ import torch
from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPool,
MHATokenToKVPoolHost,
MLATokenToKVPool,
MLATokenToKVPoolHost,
ReqToTokenPool,
TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.memory_pool_host import (
MHATokenToKVPoolHost,
MLATokenToKVPoolHost,
)
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
logger = logging.getLogger(__name__)
......
......@@ -26,24 +26,15 @@ KVCache actually holds the physical kv cache.
import abc
import logging
import threading
from enum import IntEnum
from functools import wraps
from typing import List, Optional, Tuple, Union
import numpy as np
import psutil
import torch
import triton
import triton.language as tl
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import (
debug_timing,
get_compiler_backend,
is_cuda,
next_power_of_2,
)
from sglang.srt.utils import debug_timing, get_compiler_backend, is_cuda
logger = logging.getLogger(__name__)
......@@ -772,370 +763,3 @@ class DoubleSparseTokenToKVPool(KVCache):
def transfer_per_layer(self, indices, flat_data, layer_id):
pass
class MemoryStateInt(IntEnum):
IDLE = 0
RESERVED = 1
PROTECTED = 2
SYNCED = 3
BACKUP = 4
def synchronized(debug_only=False):
def _decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
            if (not debug_only) or self.debug:
                # Hold the pool lock so allocation and state transitions stay consistent.
                with self.lock:
                    return func(self, *args, **kwargs)
            else:
                return True
return wrapper
return _decorator
class HostKVCache(abc.ABC):
def __init__(
self,
device_pool: KVCache,
host_to_device_ratio: float,
host_size: int,
pin_memory: bool,
device: str,
page_size: int,
):
self.device_pool = device_pool
self.dtype = device_pool.store_dtype
self.pin_memory = pin_memory
self.device = device
self.page_size = page_size
self.size_per_token = self.get_size_per_token()
if host_size > 0:
self.size = int(host_size * 1e9 // self.size_per_token)
else:
self.size = int(device_pool.size * host_to_device_ratio)
# Align the host memory pool size to the page size
self.size = self.size - (self.size % self.page_size)
self.start_layer = device_pool.start_layer
self.end_layer = device_pool.end_layer
assert (
self.size > device_pool.size
), "The host memory should be larger than the device memory with the current protocol"
# Verify there is enough available host memory.
host_mem = psutil.virtual_memory()
requested_bytes = self.size * self.size_per_token
# preserve at least 10GB for other usage
ten_gb = 10 * (1024**3)
if requested_bytes > host_mem.available - ten_gb:
raise ValueError(
f"Not enough host memory available. Requesting "
f"{requested_bytes / 1e9:.2f} GB but only have "
f"{host_mem.available / 1e9:.2f} GB free. Please reduce the "
f"size of the hierarchical cache."
)
else:
logger.info(
f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
)
self.kv_buffer = self.init_kv_buffer()
# A lock for synchronized operations on memory allocation and state transitions.
self.lock = threading.RLock()
self.debug = logger.isEnabledFor(logging.DEBUG)
self.clear()
@abc.abstractmethod
def get_size_per_token(self):
raise NotImplementedError()
@abc.abstractmethod
def init_kv_buffer(self):
raise NotImplementedError()
@abc.abstractmethod
def transfer(self, indices, flat_data):
raise NotImplementedError()
@abc.abstractmethod
def get_flat_data(self, indices):
raise NotImplementedError()
@abc.abstractmethod
def get_flat_data_by_layer(self, indices, layer_id):
raise NotImplementedError()
@abc.abstractmethod
def assign_flat_data(self, indices, flat_data):
raise NotImplementedError()
@synchronized()
def clear(self):
# Initialize memory states and tracking structures.
self.mem_state = torch.zeros(
(self.size,), dtype=torch.uint8, device=self.device
)
self.free_slots = torch.arange(self.size, dtype=torch.int64)
def available_size(self):
return len(self.free_slots)
@synchronized()
def alloc(self, need_size: int) -> torch.Tensor:
if need_size > self.available_size():
return None
select_index = self.free_slots[:need_size]
self.free_slots = self.free_slots[need_size:]
if self.debug:
self.mem_state[select_index] = MemoryStateInt.RESERVED
return select_index
@synchronized()
def free(self, indices: torch.Tensor) -> int:
self.free_slots = torch.cat([self.free_slots, indices])
if self.debug:
self.mem_state[indices] = MemoryStateInt.IDLE
return len(indices)
@synchronized(debug_only=True)
def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
assert len(indices) > 0, "The indices should not be empty"
states = self.mem_state[indices]
assert (
states == states[0]
).all(), "The memory slots should have the same state {}".format(states)
return MemoryStateInt(states[0].item())
@synchronized(debug_only=True)
def is_reserved(self, indices: torch.Tensor) -> bool:
return self.get_state(indices) == MemoryStateInt.RESERVED
@synchronized(debug_only=True)
def is_protected(self, indices: torch.Tensor) -> bool:
return self.get_state(indices) == MemoryStateInt.PROTECTED
@synchronized(debug_only=True)
def is_synced(self, indices: torch.Tensor) -> bool:
return self.get_state(indices) == MemoryStateInt.SYNCED
@synchronized(debug_only=True)
def is_backup(self, indices: torch.Tensor) -> bool:
return self.get_state(indices) == MemoryStateInt.BACKUP
@synchronized(debug_only=True)
def update_backup(self, indices: torch.Tensor):
if not self.is_synced(indices):
raise ValueError(
f"The host memory slots should be in SYNCED state before turning into BACKUP. "
f"Current state: {self.get_state(indices)}"
)
self.mem_state[indices] = MemoryStateInt.BACKUP
@synchronized(debug_only=True)
def update_synced(self, indices: torch.Tensor):
self.mem_state[indices] = MemoryStateInt.SYNCED
@synchronized(debug_only=True)
def protect_write(self, indices: torch.Tensor):
if not self.is_reserved(indices):
raise ValueError(
f"The host memory slots should be RESERVED before write operations. "
f"Current state: {self.get_state(indices)}"
)
self.mem_state[indices] = MemoryStateInt.PROTECTED
@synchronized(debug_only=True)
def protect_load(self, indices: torch.Tensor):
if not self.is_backup(indices):
raise ValueError(
f"The host memory slots should be in BACKUP state before load operations. "
f"Current state: {self.get_state(indices)}"
)
self.mem_state[indices] = MemoryStateInt.PROTECTED
@synchronized(debug_only=True)
def complete_io(self, indices: torch.Tensor):
if not self.is_protected(indices):
raise ValueError(
f"The host memory slots should be PROTECTED during I/O operations. "
f"Current state: {self.get_state(indices)}"
)
self.mem_state[indices] = MemoryStateInt.SYNCED
class MHATokenToKVPoolHost(HostKVCache):
device_pool: MHATokenToKVPool
def __init__(
self,
device_pool: MHATokenToKVPool,
host_to_device_ratio: float,
host_size: int,
page_size: int,
pin_memory: bool = True,
device: str = "cpu",
):
super().__init__(
device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
)
def get_size_per_token(self):
self.head_num = self.device_pool.head_num
self.head_dim = self.device_pool.head_dim
self.layer_num = self.device_pool.layer_num
return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
def init_kv_buffer(self):
return torch.empty(
(2, self.layer_num, self.size, self.head_num, self.head_dim),
dtype=self.dtype,
device=self.device,
pin_memory=self.pin_memory,
)
@debug_timing
def transfer(self, indices, flat_data):
# backup prepared data from device to host
self.kv_buffer[:, :, indices] = flat_data.to(
device=self.device, non_blocking=False
)
def get_flat_data(self, indices):
return self.kv_buffer[:, :, indices]
def get_flat_data_by_layer(self, indices, layer_id):
return self.kv_buffer[:, layer_id - self.start_layer, indices]
def assign_flat_data(self, indices, flat_data):
self.kv_buffer[:, :, indices] = flat_data
def write_page_all_layers(self, host_indices, device_indices, device_pool):
device_indices_cpu = device_indices[:: self.page_size].cpu()
for i in range(len(device_indices_cpu)):
h_index = host_indices[i * self.page_size]
d_index = device_indices_cpu[i]
for j in range(self.layer_num):
self.kv_buffer[0, j, h_index : h_index + self.page_size].copy_(
device_pool.k_buffer[j][d_index : d_index + self.page_size],
non_blocking=True,
)
self.kv_buffer[1, j, h_index : h_index + self.page_size].copy_(
device_pool.v_buffer[j][d_index : d_index + self.page_size],
non_blocking=True,
)
def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
device_indices_cpu = device_indices[:: self.page_size].cpu()
for i in range(len(device_indices_cpu)):
h_index = host_indices[i * self.page_size]
d_index = device_indices_cpu[i]
device_pool.k_buffer[layer_id - self.start_layer][
d_index : d_index + self.page_size
].copy_(
self.kv_buffer[
0, layer_id - self.start_layer, h_index : h_index + self.page_size
],
non_blocking=True,
)
device_pool.v_buffer[layer_id - self.start_layer][
d_index : d_index + self.page_size
].copy_(
self.kv_buffer[
1, layer_id - self.start_layer, h_index : h_index + self.page_size
],
non_blocking=True,
)
class MLATokenToKVPoolHost(HostKVCache):
device_pool: MLATokenToKVPool
def __init__(
self,
device_pool: MLATokenToKVPool,
host_to_device_ratio: float,
host_size: int,
page_size: int,
pin_memory: bool = True,
device: str = "cpu",
):
super().__init__(
device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
)
def get_size_per_token(self):
self.kv_lora_rank = self.device_pool.kv_lora_rank
self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
self.layer_num = self.device_pool.layer_num
return (
(self.kv_lora_rank + self.qk_rope_head_dim)
* 1
* self.dtype.itemsize
* self.layer_num
)
def init_kv_buffer(self):
return torch.empty(
(
self.layer_num,
self.size,
1,
self.kv_lora_rank + self.qk_rope_head_dim,
),
dtype=self.dtype,
device=self.device,
pin_memory=self.pin_memory,
)
@debug_timing
def transfer(self, indices, flat_data):
# backup prepared data from device to host
self.kv_buffer[:, indices] = flat_data.to(
device=self.device, non_blocking=False
)
def get_flat_data(self, indices):
return self.kv_buffer[:, indices]
def get_flat_data_by_layer(self, indices, layer_id):
return self.kv_buffer[layer_id - self.start_layer, indices]
def assign_flat_data(self, indices, flat_data):
self.kv_buffer[:, indices] = flat_data
def write_page_all_layers(self, host_indices, device_indices, device_pool):
device_indices_cpu = device_indices[:: self.page_size].cpu()
for i in range(len(device_indices_cpu)):
h_index = host_indices[i * self.page_size]
d_index = device_indices_cpu[i]
for j in range(self.layer_num):
self.kv_buffer[j, h_index : h_index + self.page_size].copy_(
device_pool.kv_buffer[j][d_index : d_index + self.page_size],
non_blocking=True,
)
def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
device_indices_cpu = device_indices[:: self.page_size].cpu()
for i in range(len(device_indices_cpu)):
h_index = host_indices[i * self.page_size]
d_index = device_indices_cpu[i]
device_pool.kv_buffer[layer_id - self.start_layer][
d_index : d_index + self.page_size
].copy_(
self.kv_buffer[
layer_id - self.start_layer, h_index : h_index + self.page_size
],
non_blocking=True,
)
...... (new file: memory_pool_host.py)
import abc
import logging
import threading
from enum import IntEnum
from functools import wraps
import psutil
import torch
from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
from sglang.srt.utils import debug_timing
logger = logging.getLogger(__name__)
class MemoryStateInt(IntEnum):
IDLE = 0
RESERVED = 1
PROTECTED = 2
SYNCED = 3
BACKUP = 4
def synchronized(debug_only=False):
def _decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
            if (not debug_only) or self.debug:
                # Hold the pool lock so allocation and state transitions stay consistent.
                with self.lock:
                    return func(self, *args, **kwargs)
            else:
                return True
return wrapper
return _decorator
class HostKVCache(abc.ABC):
def __init__(
self,
device_pool: KVCache,
host_to_device_ratio: float,
host_size: int,
pin_memory: bool,
device: str,
page_size: int,
):
self.device_pool = device_pool
self.dtype = device_pool.store_dtype
self.pin_memory = pin_memory
self.device = device
self.page_size = page_size
self.size_per_token = self.get_size_per_token()
if host_size > 0:
self.size = int(host_size * 1e9 // self.size_per_token)
else:
self.size = int(device_pool.size * host_to_device_ratio)
# Align the host memory pool size to the page size
self.size = self.size - (self.size % self.page_size)
self.start_layer = device_pool.start_layer
self.end_layer = device_pool.end_layer
assert (
self.size > device_pool.size
), "The host memory should be larger than the device memory with the current protocol"
# Verify there is enough available host memory.
host_mem = psutil.virtual_memory()
requested_bytes = self.size * self.size_per_token
# preserve at least 10GB for other usage
ten_gb = 10 * (1024**3)
if requested_bytes > host_mem.available - ten_gb:
raise ValueError(
f"Not enough host memory available. Requesting "
f"{requested_bytes / 1e9:.2f} GB but only have "
f"{host_mem.available / 1e9:.2f} GB free. Please reduce the "
f"size of the hierarchical cache."
)
else:
logger.info(
f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
)
self.kv_buffer = self.init_kv_buffer()
# A lock for synchronized operations on memory allocation and state transitions.
self.lock = threading.RLock()
self.debug = logger.isEnabledFor(logging.DEBUG)
self.clear()
@abc.abstractmethod
def get_size_per_token(self):
raise NotImplementedError()
@abc.abstractmethod
def init_kv_buffer(self):
raise NotImplementedError()
@abc.abstractmethod
def transfer(self, indices, flat_data):
raise NotImplementedError()
@abc.abstractmethod
def get_flat_data(self, indices):
raise NotImplementedError()
@abc.abstractmethod
def get_flat_data_by_layer(self, indices, layer_id):
raise NotImplementedError()
@abc.abstractmethod
def assign_flat_data(self, indices, flat_data):
raise NotImplementedError()
@synchronized()
def clear(self):
# Initialize memory states and tracking structures.
self.mem_state = torch.zeros(
(self.size,), dtype=torch.uint8, device=self.device
)
self.free_slots = torch.arange(self.size, dtype=torch.int64)
def available_size(self):
return len(self.free_slots)
@synchronized()
def alloc(self, need_size: int) -> torch.Tensor:
if need_size > self.available_size():
return None
select_index = self.free_slots[:need_size]
self.free_slots = self.free_slots[need_size:]
if self.debug:
self.mem_state[select_index] = MemoryStateInt.RESERVED
return select_index
@synchronized()
def free(self, indices: torch.Tensor) -> int:
self.free_slots = torch.cat([self.free_slots, indices])
if self.debug:
self.mem_state[indices] = MemoryStateInt.IDLE
return len(indices)
@synchronized(debug_only=True)
def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
assert len(indices) > 0, "The indices should not be empty"
states = self.mem_state[indices]
assert (
states == states[0]
).all(), "The memory slots should have the same state {}".format(states)
return MemoryStateInt(states[0].item())
@synchronized(debug_only=True)
def is_reserved(self, indices: torch.Tensor) -> bool:
return self.get_state(indices) == MemoryStateInt.RESERVED
@synchronized(debug_only=True)
def is_protected(self, indices: torch.Tensor) -> bool:
return self.get_state(indices) == MemoryStateInt.PROTECTED
@synchronized(debug_only=True)
def is_synced(self, indices: torch.Tensor) -> bool:
return self.get_state(indices) == MemoryStateInt.SYNCED
@synchronized(debug_only=True)
def is_backup(self, indices: torch.Tensor) -> bool:
return self.get_state(indices) == MemoryStateInt.BACKUP
@synchronized(debug_only=True)
def update_backup(self, indices: torch.Tensor):
if not self.is_synced(indices):
raise ValueError(
f"The host memory slots should be in SYNCED state before turning into BACKUP. "
f"Current state: {self.get_state(indices)}"
)
self.mem_state[indices] = MemoryStateInt.BACKUP
@synchronized(debug_only=True)
def update_synced(self, indices: torch.Tensor):
self.mem_state[indices] = MemoryStateInt.SYNCED
@synchronized(debug_only=True)
def protect_write(self, indices: torch.Tensor):
if not self.is_reserved(indices):
raise ValueError(
f"The host memory slots should be RESERVED before write operations. "
f"Current state: {self.get_state(indices)}"
)
self.mem_state[indices] = MemoryStateInt.PROTECTED
@synchronized(debug_only=True)
def protect_load(self, indices: torch.Tensor):
if not self.is_backup(indices):
raise ValueError(
f"The host memory slots should be in BACKUP state before load operations. "
f"Current state: {self.get_state(indices)}"
)
self.mem_state[indices] = MemoryStateInt.PROTECTED
@synchronized(debug_only=True)
def complete_io(self, indices: torch.Tensor):
if not self.is_protected(indices):
raise ValueError(
f"The host memory slots should be PROTECTED during I/O operations. "
f"Current state: {self.get_state(indices)}"
)
self.mem_state[indices] = MemoryStateInt.SYNCED
class MHATokenToKVPoolHost(HostKVCache):
device_pool: MHATokenToKVPool
def __init__(
self,
device_pool: MHATokenToKVPool,
host_to_device_ratio: float,
host_size: int,
page_size: int,
pin_memory: bool = True,
device: str = "cpu",
):
super().__init__(
device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
)
def get_size_per_token(self):
self.head_num = self.device_pool.head_num
self.head_dim = self.device_pool.head_dim
self.layer_num = self.device_pool.layer_num
return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
def init_kv_buffer(self):
return torch.empty(
(2, self.layer_num, self.size, self.head_num, self.head_dim),
dtype=self.dtype,
device=self.device,
pin_memory=self.pin_memory,
)
@debug_timing
def transfer(self, indices, flat_data):
# backup prepared data from device to host
self.kv_buffer[:, :, indices] = flat_data.to(
device=self.device, non_blocking=False
)
def get_flat_data(self, indices):
return self.kv_buffer[:, :, indices]
def get_flat_data_by_layer(self, indices, layer_id):
return self.kv_buffer[:, layer_id - self.start_layer, indices]
def assign_flat_data(self, indices, flat_data):
self.kv_buffer[:, :, indices] = flat_data
def write_page_all_layers(self, host_indices, device_indices, device_pool):
device_indices_cpu = device_indices[:: self.page_size].cpu()
for i in range(len(device_indices_cpu)):
h_index = host_indices[i * self.page_size]
d_index = device_indices_cpu[i]
for j in range(self.layer_num):
self.kv_buffer[0, j, h_index : h_index + self.page_size].copy_(
device_pool.k_buffer[j][d_index : d_index + self.page_size],
non_blocking=True,
)
self.kv_buffer[1, j, h_index : h_index + self.page_size].copy_(
device_pool.v_buffer[j][d_index : d_index + self.page_size],
non_blocking=True,
)
def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
device_indices_cpu = device_indices[:: self.page_size].cpu()
for i in range(len(device_indices_cpu)):
h_index = host_indices[i * self.page_size]
d_index = device_indices_cpu[i]
device_pool.k_buffer[layer_id - self.start_layer][
d_index : d_index + self.page_size
].copy_(
self.kv_buffer[
0, layer_id - self.start_layer, h_index : h_index + self.page_size
],
non_blocking=True,
)
device_pool.v_buffer[layer_id - self.start_layer][
d_index : d_index + self.page_size
].copy_(
self.kv_buffer[
1, layer_id - self.start_layer, h_index : h_index + self.page_size
],
non_blocking=True,
)
class MLATokenToKVPoolHost(HostKVCache):
device_pool: MLATokenToKVPool
def __init__(
self,
device_pool: MLATokenToKVPool,
host_to_device_ratio: float,
host_size: int,
page_size: int,
pin_memory: bool = True,
device: str = "cpu",
):
super().__init__(
device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
)
def get_size_per_token(self):
self.kv_lora_rank = self.device_pool.kv_lora_rank
self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
self.layer_num = self.device_pool.layer_num
return (
(self.kv_lora_rank + self.qk_rope_head_dim)
* 1
* self.dtype.itemsize
* self.layer_num
)
def init_kv_buffer(self):
return torch.empty(
(
self.layer_num,
self.size,
1,
self.kv_lora_rank + self.qk_rope_head_dim,
),
dtype=self.dtype,
device=self.device,
pin_memory=self.pin_memory,
)
@debug_timing
def transfer(self, indices, flat_data):
# backup prepared data from device to host
self.kv_buffer[:, indices] = flat_data.to(
device=self.device, non_blocking=False
)
def get_flat_data(self, indices):
return self.kv_buffer[:, indices]
def get_flat_data_by_layer(self, indices, layer_id):
return self.kv_buffer[layer_id - self.start_layer, indices]
def assign_flat_data(self, indices, flat_data):
self.kv_buffer[:, indices] = flat_data
def write_page_all_layers(self, host_indices, device_indices, device_pool):
device_indices_cpu = device_indices[:: self.page_size].cpu()
for i in range(len(device_indices_cpu)):
h_index = host_indices[i * self.page_size]
d_index = device_indices_cpu[i]
for j in range(self.layer_num):
self.kv_buffer[j, h_index : h_index + self.page_size].copy_(
device_pool.kv_buffer[j][d_index : d_index + self.page_size],
non_blocking=True,
)
def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
device_indices_cpu = device_indices[:: self.page_size].cpu()
for i in range(len(device_indices_cpu)):
h_index = host_indices[i * self.page_size]
d_index = device_indices_cpu[i]
device_pool.kv_buffer[layer_id - self.start_layer][
d_index : d_index + self.page_size
].copy_(
self.kv_buffer[
layer_id - self.start_layer, h_index : h_index + self.page_size
],
non_blocking=True,
)
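
A brief note on the classes added above: each host slot moves through the MemoryStateInt states IDLE → RESERVED → PROTECTED → SYNCED → BACKUP, and the state checks are only enforced when the logger is at DEBUG level (self.debug). A hypothetical caller-side walkthrough, inferred from the methods in this file; host_cache stands in for any HostKVCache subclass and is not code from this commit:

host_indices = host_cache.alloc(need_size=32)    # take slots off the free list; IDLE -> RESERVED (None if full)
if host_indices is not None:
    host_cache.protect_write(host_indices)       # RESERVED -> PROTECTED while a device-to-host copy is in flight
    # ... device-to-host copy, e.g. via transfer() or write_page_all_layers() ...
    host_cache.complete_io(host_indices)         # PROTECTED -> SYNCED
    host_cache.update_backup(host_indices)       # SYNCED -> BACKUP: the host copy can now serve reloads

    host_cache.protect_load(host_indices)        # BACKUP -> PROTECTED while a host-to-device copy is in flight
    # ... host-to-device copy, e.g. via load_page_per_layer() ...
    host_cache.complete_io(host_indices)         # PROTECTED -> SYNCED
    host_cache.free(host_indices)                # return slots to the free list (marked IDLE when debug tracking is on)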
......@@ -26,7 +26,7 @@ class TestHiCachePage(CustomTestCase):
"--page-size",
32,
"--hicache-write-policy",
"write-back",
"write_back",
],
)
......
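
For reference, a worked example of the sizing arithmetic in HostKVCache.__init__ above. The model dimensions are illustrative only, not taken from this commit:

# Illustrative numbers only: an MHA model with 8 KV heads, head_dim 128,
# 32 layers, 2-byte storage dtype, storing both K and V.
size_per_token = 128 * 8 * 32 * 2 * 2             # head_dim * head_num * layer_num * itemsize * 2 = 131072 bytes
host_size_gb = 50                                  # host pool size setting, in GB
size = int(host_size_gb * 1e9 // size_per_token)   # 381469 tokens
page_size = 32
size -= size % page_size                           # align down to whole pages -> 381440 tokens
# __init__ then asserts size > device_pool.size and that
# size * size_per_token (~50 GB here) leaves at least 10 GB of free host RAM.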