Unverified Commit 0e0ec702 authored by Lu Changqi, committed by GitHub

Hierarchical Caching supports MLA (#4009)


Signed-off-by: Changqi Lu <luchangqi.123@bytedance.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
parent bb378556
......@@ -22,10 +22,7 @@ from typing import List, Optional
import torch
from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPoolHost,
TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.memory_pool import HostKVCache, TokenToKVPoolAllocator
logger = logging.getLogger(__name__)
......@@ -151,7 +148,7 @@ class HiCacheController:
def __init__(
self,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
mem_pool_host: MHATokenToKVPoolHost,
mem_pool_host: HostKVCache,
load_cache_event: threading.Event = None,
write_policy: str = "write_through_selective",
):
......
......@@ -8,7 +8,10 @@ import torch
from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPool,
MHATokenToKVPoolHost,
MLATokenToKVPool,
MLATokenToKVPoolHost,
ReqToTokenPool,
TokenToKVPoolAllocator,
)
......@@ -31,9 +34,14 @@ class HiRadixCache(RadixCache):
raise ValueError(
"Page size larger than 1 is not yet supported in HiRadixCache."
)
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
token_to_kv_pool_allocator.get_kvcache()
)
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
if isinstance(self.kv_cache, MHATokenToKVPool):
self.token_to_kv_pool_host = MHATokenToKVPoolHost(self.kv_cache)
elif isinstance(self.kv_cache, MLATokenToKVPool):
self.token_to_kv_pool_host = MLATokenToKVPoolHost(self.kv_cache)
else:
raise ValueError(f"Only MHA and MLA supports swap kv_cache to host.")
self.tp_group = tp_cache_group
self.page_size = page_size
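
The dispatch above is the core of the change: the host pool mirrors whichever device pool the allocator holds. A minimal standalone sketch of that selection, assuming the sglang imports shown in this diff (the wrapper function itself is hypothetical, not part of the commit):

from sglang.srt.mem_cache.memory_pool import (
    MHATokenToKVPool,
    MHATokenToKVPoolHost,
    MLATokenToKVPool,
    MLATokenToKVPoolHost,
)

def make_host_pool(kv_cache):
    # MHA keeps separate K/V buffers; MLA stores one compressed latent buffer per layer.
    if isinstance(kv_cache, MHATokenToKVPool):
        return MHATokenToKVPoolHost(kv_cache)
    if isinstance(kv_cache, MLATokenToKVPool):
        return MLATokenToKVPoolHost(kv_cache)
    raise ValueError("Only MHA and MLA support swapping kv_cache to host.")
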
......@@ -317,13 +325,11 @@ class HiRadixCache(RadixCache):
prefix_len = _key_match(child.key, key)
if prefix_len < len(child.key):
new_node = self._split_node(child.key, child, prefix_len)
self.inc_hit_count(new_node)
if not new_node.evicted:
value.append(new_node.value)
node = new_node
break
else:
self.inc_hit_count(child)
if not child.evicted:
value.append(child.value)
node = child
......
......@@ -115,6 +115,21 @@ class KVCache(abc.ABC):
) -> None:
raise NotImplementedError()
@abc.abstractmethod
def get_flat_data(self, indices):
raise NotImplementedError()
@abc.abstractmethod
def transfer(self, indices, flat_data):
raise NotImplementedError()
@abc.abstractmethod
def transfer_per_layer(self, indices, flat_data, layer_id):
raise NotImplementedError()
def register_layer_transfer_counter(self, layer_transfer_counter):
self.layer_transfer_counter = layer_transfer_counter
class TokenToKVPoolAllocator:
"""An allocator managing the indices to kv cache data."""
......@@ -275,9 +290,6 @@ class MHATokenToKVPool(KVCache):
self.k_buffer[i][indices] = k_data[i]
self.v_buffer[i][indices] = v_data[i]
def register_layer_transfer_counter(self, layer_transfer_counter):
self.layer_transfer_counter = layer_transfer_counter
def transfer_per_layer(self, indices, flat_data, layer_id):
# transfer prepared data from host to device
flat_data = flat_data.to(device=self.device, non_blocking=False)
......@@ -388,6 +400,8 @@ class MLATokenToKVPool(KVCache):
else:
self.store_dtype = dtype
self.kv_lora_rank = kv_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.layer_num = layer_num
memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=enable_memory_saver
......@@ -404,12 +418,20 @@ class MLATokenToKVPool(KVCache):
for _ in range(layer_num)
]
self.layer_transfer_counter = None
def get_key_buffer(self, layer_id: int):
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id)
if self.store_dtype != self.dtype:
return self.kv_buffer[layer_id].view(self.dtype)
return self.kv_buffer[layer_id]
def get_value_buffer(self, layer_id: int):
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id)
if self.store_dtype != self.dtype:
return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype)
return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
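
The getters now block on layer_transfer_counter, so attention for layer i only reads the buffer after that layer's KV data has landed on device. A hypothetical stand-in for such a counter (the real one lives in the cache controller and may differ), built on threading.Condition:

import threading

class ToyLayerCounter:
    def __init__(self):
        self._loaded = -1
        self._cond = threading.Condition()

    def increment(self):
        # Called by the loading thread after each layer finishes transferring.
        with self._cond:
            self._loaded += 1
            self._cond.notify_all()

    def wait_until(self, layer_id):
        # Called from get_key_buffer / get_value_buffer before reading a layer.
        with self._cond:
            self._cond.wait_for(lambda: self._loaded >= layer_id)
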
......@@ -432,6 +454,22 @@ class MLATokenToKVPool(KVCache):
else:
self.kv_buffer[layer_id][loc] = cache_k
def get_flat_data(self, indices):
# prepare a large chunk of contiguous data for efficient transfer
return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)])
@debug_timing
def transfer(self, indices, flat_data):
# transfer prepared data from host to device
flat_data = flat_data.to(device=self.device, non_blocking=False)
for i in range(self.layer_num):
self.kv_buffer[i][indices] = flat_data[i]
def transfer_per_layer(self, indices, flat_data, layer_id):
# transfer prepared data from host to device
flat_data = flat_data.to(device=self.device, non_blocking=False)
self.kv_buffer[layer_id][indices] = flat_data
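
For MLA, the flat layout produced by get_flat_data is (layer_num, num_tokens, 1, kv_lora_rank + qk_rope_head_dim), one stacked slice per layer rather than MHA's separate K/V planes. A small shape check with made-up dimensions:

import torch

layer_num, kv_lora_rank, qk_rope_head_dim = 2, 512, 64
kv_buffer = [
    torch.zeros(16, 1, kv_lora_rank + qk_rope_head_dim) for _ in range(layer_num)
]
indices = torch.tensor([3, 7, 9])
flat = torch.stack([kv_buffer[i][indices] for i in range(layer_num)])
assert flat.shape == (layer_num, 3, 1, kv_lora_rank + qk_rope_head_dim)
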
class DoubleSparseTokenToKVPool(KVCache):
def __init__(
......@@ -508,6 +546,15 @@ class DoubleSparseTokenToKVPool(KVCache):
self.v_buffer[layer_id][loc] = cache_v
self.label_buffer[layer_id][loc] = cache_label
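# Host offloading is not implemented for DoubleSparse; the methods below are
# no-op placeholders that only satisfy the new KVCache interface.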
def get_flat_data(self, indices):
pass
def transfer(self, indices, flat_data):
pass
def transfer_per_layer(self, indices, flat_data, layer_id):
pass
class MemoryStateInt(IntEnum):
IDLE = 0
......@@ -526,7 +573,7 @@ def synchronized(func):
return wrapper
class MHATokenToKVPoolHost:
class HostKVCache(abc.ABC):
def __init__(
self,
......@@ -547,12 +594,7 @@ class MHATokenToKVPoolHost:
self.size = int(device_pool.size * host_to_device_ratio)
self.dtype = device_pool.store_dtype
self.head_num = device_pool.head_num
self.head_dim = device_pool.head_dim
self.layer_num = device_pool.layer_num
self.size_per_token = (
self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
)
self.size_per_token = self.get_size_per_token()
# Verify there is enough available host memory.
host_mem = psutil.virtual_memory()
......@@ -571,12 +613,7 @@ class MHATokenToKVPoolHost:
f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
)
self.kv_buffer = torch.zeros(
(2, self.layer_num, self.size, self.head_num, self.head_dim),
dtype=self.dtype,
device=self.device,
pin_memory=self.pin_memory,
)
self.kv_buffer = self.init_kv_buffer()
# Initialize memory states and tracking structures.
self.mem_state = torch.zeros(
......@@ -588,21 +625,29 @@ class MHATokenToKVPoolHost:
# A lock for synchronized operations on memory allocation and state transitions.
self.lock = threading.RLock()
@abc.abstractmethod
def get_size_per_token(self):
raise NotImplementedError()
@abc.abstractmethod
def init_kv_buffer(self):
raise NotImplementedError()
@abc.abstractmethod
def transfer(self, indices, flat_data):
raise NotImplementedError()
@abc.abstractmethod
def get_flat_data(self, indices):
return self.kv_buffer[:, :, indices]
raise NotImplementedError()
@abc.abstractmethod
def get_flat_data_by_layer(self, indices, layer_id):
return self.kv_buffer[:, layer_id, indices]
raise NotImplementedError()
@abc.abstractmethod
def assign_flat_data(self, indices, flat_data):
self.kv_buffer[:, :, indices] = flat_data
@debug_timing
def transfer(self, indices, flat_data):
# backup prepared data from device to host
self.kv_buffer[:, :, indices] = flat_data.to(
device=self.device, non_blocking=False
)
raise NotImplementedError()
@synchronized
def clear(self):
......@@ -694,3 +739,92 @@ class MHATokenToKVPoolHost:
self.free_slots = torch.concat([self.free_slots, indices])
self.can_use_mem_size += len(indices)
return len(indices)
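
clear() and the allocation/free paths above run under the @synchronized decorator referenced earlier in this file. A minimal sketch of that pattern, assuming it simply serializes state-changing methods behind self.lock (the actual implementation may differ):

import functools

def synchronized(func):
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        with self.lock:
            return func(self, *args, **kwargs)
    return wrapper
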
class MHATokenToKVPoolHost(HostKVCache):
def __init__(
self,
device_pool: MHATokenToKVPool,
host_to_device_ratio: float = 3.0,
pin_memory: bool = False, # no need to use pin memory with the double buffering
device: str = "cpu",
):
super().__init__(device_pool, host_to_device_ratio, pin_memory, device)
def get_size_per_token(self):
self.head_num = self.device_pool.head_num
self.head_dim = self.device_pool.head_dim
self.layer_num = self.device_pool.layer_num
return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
def init_kv_buffer(self):
return torch.empty(
(2, self.layer_num, self.size, self.head_num, self.head_dim),
dtype=self.dtype,
device=self.device,
pin_memory=self.pin_memory,
)
@debug_timing
def transfer(self, indices, flat_data):
# backup prepared data from device to host
self.kv_buffer[:, :, indices] = flat_data.to(
device=self.device, non_blocking=False
)
def get_flat_data(self, indices):
return self.kv_buffer[:, :, indices]
def get_flat_data_by_layer(self, indices, layer_id):
return self.kv_buffer[:, layer_id, indices]
def assign_flat_data(self, indices, flat_data):
self.kv_buffer[:, :, indices] = flat_data
class MLATokenToKVPoolHost(HostKVCache):
def __init__(
self,
device_pool: MLATokenToKVPool,
host_to_device_ratio: float = 4.0,
pin_memory: bool = False, # no need to use pin memory with the double buffering
device: str = "cpu",
):
super().__init__(device_pool, host_to_device_ratio, pin_memory, device)
def get_size_per_token(self):
self.kv_lora_rank = self.device_pool.kv_lora_rank
self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
self.layer_num = self.device_pool.layer_num
return (
    (self.kv_lora_rank + self.qk_rope_head_dim)
    * 1
    * self.dtype.itemsize
    * self.layer_num
)
def init_kv_buffer(self):
return torch.empty(
(
self.layer_num,
self.size,
1,
self.kv_lora_rank + self.qk_rope_head_dim,
),
dtype=self.dtype,
device=self.device,
pin_memory=self.pin_memory,
)
@debug_timing
def transfer(self, indices, flat_data):
# backup prepared data from device to host
self.kv_buffer[:, indices] = flat_data.to(
device=self.device, non_blocking=False
)
def get_flat_data(self, indices):
return self.kv_buffer[:, indices]
def get_flat_data_by_layer(self, indices, layer_id):
return self.kv_buffer[layer_id, indices]
def assign_flat_data(self, indices, flat_data):
self.kv_buffer[:, indices] = flat_data
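
get_size_per_token drives the host-memory check in the base class, and it is where MLA's advantage shows up: one compressed latent vector per layer instead of full K/V heads. A back-of-the-envelope comparison with illustrative dimensions (not taken from this diff): 60 layers, a 2-byte dtype, 128 KV heads of dim 128 for MHA versus kv_lora_rank 512 and qk_rope_head_dim 64 for MLA:

itemsize, layer_num = 2, 60

mha_bytes_per_token = 128 * 128 * layer_num * itemsize * 2   # K and V planes
mla_bytes_per_token = (512 + 64) * 1 * layer_num * itemsize  # single latent buffer

print(f"MHA: {mha_bytes_per_token / 1024:.0f} KiB/token")
print(f"MLA: {mla_bytes_per_token / 1024:.0f} KiB/token")
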
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)
class TestHierarchicalMLA(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=["--trust-remote-code", "--enable-hierarchical-cache"],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_mmlu(self):
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mmlu",
num_examples=64,
num_threads=32,
)
metrics = run_eval(args)
self.assertGreater(metrics["score"], 0.5)
def test_mgsm_en(self):
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mgsm_en",
num_examples=None,
num_threads=1024,
)
metrics = run_eval(args)
self.assertGreater(metrics["score"], 0.8)
if __name__ == "__main__":
unittest.main()