Unverified Commit 0e0ec702 authored by Lu Changqi, committed by GitHub

Hierarchical Caching supports MLA (#4009)


Signed-off-by: Changqi Lu <luchangqi.123@bytedance.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
parent bb378556
@@ -22,10 +22,7 @@ from typing import List, Optional

 import torch

-from sglang.srt.mem_cache.memory_pool import (
-    MHATokenToKVPoolHost,
-    TokenToKVPoolAllocator,
-)
+from sglang.srt.mem_cache.memory_pool import HostKVCache, TokenToKVPoolAllocator

 logger = logging.getLogger(__name__)

@@ -151,7 +148,7 @@ class HiCacheController:
     def __init__(
         self,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
-        mem_pool_host: MHATokenToKVPoolHost,
+        mem_pool_host: HostKVCache,
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
     ):
...
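The controller-side change is narrow: HiCacheController now annotates mem_pool_host with the abstract HostKVCache instead of the MHA-specific host pool, so the same write-through and load paths serve both attention layouts. Below is a minimal, self-contained sketch of that shape-agnostic usage; the class and function names are illustrative stand-ins, not sglang identifiers.

    # Sketch only: mirrors the idea of typing controller code against an abstract host cache.
    # HostCacheLike, ToyMHAHost, ToyMLAHost and backup_tokens are hypothetical names.
    import abc

    import torch


    class HostCacheLike(abc.ABC):
        @abc.abstractmethod
        def assign_flat_data(self, indices, flat_data): ...

        @abc.abstractmethod
        def get_flat_data(self, indices): ...


    class ToyMHAHost(HostCacheLike):
        def __init__(self):
            # (2, layer_num, size, head_num, head_dim): K and V stacked on dim 0
            self.kv_buffer = torch.zeros(2, 2, 8, 4, 16)

        def assign_flat_data(self, indices, flat_data):
            self.kv_buffer[:, :, indices] = flat_data

        def get_flat_data(self, indices):
            return self.kv_buffer[:, :, indices]


    class ToyMLAHost(HostCacheLike):
        def __init__(self):
            # (layer_num, size, 1, kv_lora_rank + qk_rope_head_dim): one latent vector per token
            self.kv_buffer = torch.zeros(2, 8, 1, 24)

        def assign_flat_data(self, indices, flat_data):
            self.kv_buffer[:, indices] = flat_data

        def get_flat_data(self, indices):
            return self.kv_buffer[:, indices]


    def backup_tokens(host: HostCacheLike, indices, flat_data):
        # A controller-style caller only touches the shared interface,
        # so it works unchanged for either layout.
        host.assign_flat_data(indices, flat_data)
        return host.get_flat_data(indices)


    idx = torch.tensor([1, 3])
    print(backup_tokens(ToyMHAHost(), idx, torch.ones(2, 2, 2, 4, 16)).shape)
    print(backup_tokens(ToyMLAHost(), idx, torch.ones(2, 2, 1, 24)).shape)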
@@ -8,7 +8,10 @@ import torch

 from sglang.srt.managers.cache_controller import HiCacheController
 from sglang.srt.mem_cache.memory_pool import (
+    MHATokenToKVPool,
     MHATokenToKVPoolHost,
+    MLATokenToKVPool,
+    MLATokenToKVPoolHost,
     ReqToTokenPool,
     TokenToKVPoolAllocator,
 )

@@ -31,9 +34,14 @@ class HiRadixCache(RadixCache):
             raise ValueError(
                 "Page size larger than 1 is not yet supported in HiRadixCache."
             )
-        self.token_to_kv_pool_host = MHATokenToKVPoolHost(
-            token_to_kv_pool_allocator.get_kvcache()
-        )
+        self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
+        if isinstance(self.kv_cache, MHATokenToKVPool):
+            self.token_to_kv_pool_host = MHATokenToKVPoolHost(self.kv_cache)
+        elif isinstance(self.kv_cache, MLATokenToKVPool):
+            self.token_to_kv_pool_host = MLATokenToKVPoolHost(self.kv_cache)
+        else:
+            raise ValueError(f"Only MHA and MLA supports swap kv_cache to host.")
        self.tp_group = tp_cache_group
        self.page_size = page_size

@@ -317,13 +325,11 @@ class HiRadixCache(RadixCache):
            prefix_len = _key_match(child.key, key)
            if prefix_len < len(child.key):
                new_node = self._split_node(child.key, child, prefix_len)
-                self.inc_hit_count(new_node)
                if not new_node.evicted:
                    value.append(new_node.value)
                node = new_node
                break
            else:
-                self.inc_hit_count(child)
                if not child.evicted:
                    value.append(child.value)
                node = child
...
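In HiRadixCache, the host pool is now chosen by inspecting the device KV cache returned by the allocator; anything other than an MHA or MLA pool is rejected (DoubleSparseTokenToKVPool, for example, only gains no-op stubs for the new interface later in this patch). A small standalone sketch of that selection logic, using stub classes in place of the real pools:

    # Sketch only: select_host_pool and the Stub* classes are illustrative names,
    # mirroring the isinstance dispatch added to HiRadixCache.__init__.
    class StubMHAPool: ...
    class StubMLAPool: ...


    class StubMHAHost:
        def __init__(self, device_pool): self.device_pool = device_pool


    class StubMLAHost:
        def __init__(self, device_pool): self.device_pool = device_pool


    def select_host_pool(kv_cache):
        if isinstance(kv_cache, StubMHAPool):
            return StubMHAHost(kv_cache)
        if isinstance(kv_cache, StubMLAPool):
            return StubMLAHost(kv_cache)
        raise ValueError("Only MHA and MLA KV caches can be swapped to host memory.")


    print(type(select_host_pool(StubMHAPool())).__name__)  # StubMHAHost
    print(type(select_host_pool(StubMLAPool())).__name__)  # StubMLAHost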
@@ -115,6 +115,21 @@ class KVCache(abc.ABC):
     ) -> None:
         raise NotImplementedError()

+    @abc.abstractmethod
+    def get_flat_data(self, indices):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def transfer(self, indices, flat_data):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def transfer_per_layer(self, indices, flat_data, layer_id):
+        raise NotImplementedError()
+
+    def register_layer_transfer_counter(self, layer_transfer_counter):
+        self.layer_transfer_counter = layer_transfer_counter
+

 class TokenToKVPoolAllocator:
     """An allocator managing the indices to kv cache data."""
@@ -275,9 +290,6 @@ class MHATokenToKVPool(KVCache):
             self.k_buffer[i][indices] = k_data[i]
             self.v_buffer[i][indices] = v_data[i]

-    def register_layer_transfer_counter(self, layer_transfer_counter):
-        self.layer_transfer_counter = layer_transfer_counter
-
     def transfer_per_layer(self, indices, flat_data, layer_id):
         # transfer prepared data from host to device
         flat_data = flat_data.to(device=self.device, non_blocking=False)

@@ -388,6 +400,8 @@ class MLATokenToKVPool(KVCache):
         else:
             self.store_dtype = dtype
         self.kv_lora_rank = kv_lora_rank
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.layer_num = layer_num

         memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver

@@ -404,12 +418,20 @@ class MLATokenToKVPool(KVCache):
             for _ in range(layer_num)
         ]

+        self.layer_transfer_counter = None
+
     def get_key_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id)
+
         if self.store_dtype != self.dtype:
             return self.kv_buffer[layer_id].view(self.dtype)
         return self.kv_buffer[layer_id]

     def get_value_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id)
+
         if self.store_dtype != self.dtype:
             return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype)
         return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
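MLATokenToKVPool now keeps a layer_transfer_counter and, when one is registered, get_key_buffer/get_value_buffer block until the host-to-device copy for that layer has completed; this is what lets per-layer prefetch overlap with forward compute. Below is a toy, self-contained sketch of such a counter. The LayerDoneCounter name and its internals are assumptions for illustration; only the wait_until() call mirrors the interface used above.

    # Sketch only: a per-layer completion counter. The real counter object lives in
    # sglang's cache controller; here we only imitate the wait_until() contract.
    import threading


    class LayerDoneCounter:
        def __init__(self, num_layers: int):
            self.events = [threading.Event() for _ in range(num_layers)]

        def complete(self, layer_id: int) -> None:
            # Called by the loading thread once layer `layer_id` has been copied in.
            self.events[layer_id].set()

        def wait_until(self, layer_id: int) -> None:
            # Called on the compute path before reading layer `layer_id`'s KV buffer.
            self.events[layer_id].wait()


    counter = LayerDoneCounter(num_layers=2)


    def loader():
        for layer_id in range(2):
            # ... copy this layer's flat_data from host to device ...
            counter.complete(layer_id)


    threading.Thread(target=loader).start()
    counter.wait_until(1)  # compute for layer 1 proceeds only after its KV is resident
    print("layer 1 KV ready")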
@@ -432,6 +454,22 @@ class MLATokenToKVPool(KVCache):
         else:
             self.kv_buffer[layer_id][loc] = cache_k

+    def get_flat_data(self, indices):
+        # prepare a large chunk of contiguous data for efficient transfer
+        return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)])
+
+    @debug_timing
+    def transfer(self, indices, flat_data):
+        # transfer prepared data from host to device
+        flat_data = flat_data.to(device=self.device, non_blocking=False)
+        for i in range(self.layer_num):
+            self.kv_buffer[i][indices] = flat_data[i]
+
+    def transfer_per_layer(self, indices, flat_data, layer_id):
+        # transfer prepared data from host to device
+        flat_data = flat_data.to(device=self.device, non_blocking=False)
+        self.kv_buffer[layer_id][indices] = flat_data
+

 class DoubleSparseTokenToKVPool(KVCache):
     def __init__(
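The MLA pool's host-transfer helpers operate on a single latent buffer per layer, each token entry of shape (1, kv_lora_rank + qk_rope_head_dim), rather than the separate K/V buffers used by MHA. A small torch-only round trip that mirrors those shapes (toy dimensions, no sglang imports):

    # Sketch only: toy dimensions, mirroring MLATokenToKVPool.get_flat_data/transfer.
    import torch

    layer_num, pool_size, kv_lora_rank, qk_rope_head_dim = 2, 16, 8, 4
    kv_buffer = [
        torch.zeros(pool_size, 1, kv_lora_rank + qk_rope_head_dim) for _ in range(layer_num)
    ]

    indices = torch.tensor([3, 5, 7])

    # get_flat_data: stack the selected tokens across layers into one contiguous chunk,
    # shape (layer_num, num_tokens, 1, kv_lora_rank + qk_rope_head_dim).
    flat_data = torch.stack([kv_buffer[i][indices] for i in range(layer_num)])
    print(flat_data.shape)  # torch.Size([2, 3, 1, 12])

    # transfer: scatter the chunk back into the per-layer buffers
    # (host -> device in sglang; everything stays on CPU here).
    for i in range(layer_num):
        kv_buffer[i][indices] = flat_data[i]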
@@ -508,6 +546,15 @@ class DoubleSparseTokenToKVPool(KVCache):
         self.v_buffer[layer_id][loc] = cache_v
         self.label_buffer[layer_id][loc] = cache_label

+    def get_flat_data(self, indices):
+        pass
+
+    def transfer(self, indices, flat_data):
+        pass
+
+    def transfer_per_layer(self, indices, flat_data, layer_id):
+        pass
+

 class MemoryStateInt(IntEnum):
     IDLE = 0
@@ -526,7 +573,7 @@ def synchronized(func):
     return wrapper


-class MHATokenToKVPoolHost:
+class HostKVCache(abc.ABC):

     def __init__(
         self,

@@ -547,12 +594,7 @@ class MHATokenToKVPoolHost:
         self.size = int(device_pool.size * host_to_device_ratio)
         self.dtype = device_pool.store_dtype

-        self.head_num = device_pool.head_num
-        self.head_dim = device_pool.head_dim
-        self.layer_num = device_pool.layer_num
-        self.size_per_token = (
-            self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
-        )
+        self.size_per_token = self.get_size_per_token()

         # Verify there is enough available host memory.
         host_mem = psutil.virtual_memory()

@@ -571,12 +613,7 @@ class MHATokenToKVPoolHost:
             f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
         )

-        self.kv_buffer = torch.zeros(
-            (2, self.layer_num, self.size, self.head_num, self.head_dim),
-            dtype=self.dtype,
-            device=self.device,
-            pin_memory=self.pin_memory,
-        )
+        self.kv_buffer = self.init_kv_buffer()

         # Initialize memory states and tracking structures.
         self.mem_state = torch.zeros(

@@ -588,21 +625,29 @@ class MHATokenToKVPoolHost:
         # A lock for synchronized operations on memory allocation and state transitions.
         self.lock = threading.RLock()

+    @abc.abstractmethod
+    def get_size_per_token(self):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def init_kv_buffer(self):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def transfer(self, indices, flat_data):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
     def get_flat_data(self, indices):
-        return self.kv_buffer[:, :, indices]
+        raise NotImplementedError()

+    @abc.abstractmethod
     def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[:, layer_id, indices]
+        raise NotImplementedError()

+    @abc.abstractmethod
     def assign_flat_data(self, indices, flat_data):
-        self.kv_buffer[:, :, indices] = flat_data
+        raise NotImplementedError()

-    @debug_timing
-    def transfer(self, indices, flat_data):
-        # backup prepared data from device to host
-        self.kv_buffer[:, :, indices] = flat_data.to(
-            device=self.device, non_blocking=False
-        )

     @synchronized
     def clear(self):
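The HostKVCache base constructor keeps the original sizing logic: the host pool holds device_pool.size * host_to_device_ratio tokens, computes requested bytes from the subclass's get_size_per_token(), and checks it against the RAM reported by psutil before allocating via init_kv_buffer(). A worked version of that arithmetic with made-up numbers (all figures below are illustrative assumptions, not taken from the patch):

    # Sketch only: the arithmetic behind the host-memory check, with assumed numbers.
    device_pool_size = 500_000                 # tokens resident on the GPU (assumed)
    host_to_device_ratio = 4.0                 # MLA host pool default in this patch
    size_per_token = (512 + 64) * 1 * 2        # (kv_lora_rank + qk_rope_head_dim) * itemsize, assumed dims, fp16

    host_pool_size = int(device_pool_size * host_to_device_ratio)
    requested_bytes = host_pool_size * size_per_token

    available_bytes = 64e9                     # stand-in for psutil.virtual_memory().available
    assert requested_bytes < available_bytes
    print(f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache.")
    # -> Allocating 2.30 GB host memory for hierarchical KV cache.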
@@ -694,3 +739,92 @@ class MHATokenToKVPoolHost:
         self.free_slots = torch.concat([self.free_slots, indices])
         self.can_use_mem_size += len(indices)
         return len(indices)
+
+
+class MHATokenToKVPoolHost(HostKVCache):
+    def __init__(
+        self,
+        device_pool: MHATokenToKVPool,
+        host_to_device_ratio: float = 3.0,
+        pin_memory: bool = False,  # no need to use pin memory with the double buffering
+        device: str = "cpu",
+    ):
+        super().__init__(device_pool, host_to_device_ratio, pin_memory, device)
+
+    def get_size_per_token(self):
+        self.head_num = self.device_pool.head_num
+        self.head_dim = self.device_pool.head_dim
+        self.layer_num = self.device_pool.layer_num
+
+        return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
+
+    def init_kv_buffer(self):
+        return torch.empty(
+            (2, self.layer_num, self.size, self.head_num, self.head_dim),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        )
+
+    @debug_timing
+    def transfer(self, indices, flat_data):
+        # backup prepared data from device to host
+        self.kv_buffer[:, :, indices] = flat_data.to(
+            device=self.device, non_blocking=False
+        )
+
+    def get_flat_data(self, indices):
+        return self.kv_buffer[:, :, indices]
+
+    def get_flat_data_by_layer(self, indices, layer_id):
+        return self.kv_buffer[:, layer_id, indices]
+
+    def assign_flat_data(self, indices, flat_data):
+        self.kv_buffer[:, :, indices] = flat_data
+
+
+class MLATokenToKVPoolHost(HostKVCache):
+    def __init__(
+        self,
+        device_pool: MLATokenToKVPool,
+        host_to_device_ratio: float = 4.0,
+        pin_memory: bool = False,  # no need to use pin memory with the double buffering
+        device: str = "cpu",
+    ):
+        super().__init__(device_pool, host_to_device_ratio, pin_memory, device)
+
+    def get_size_per_token(self):
+        self.kv_lora_rank = self.device_pool.kv_lora_rank
+        self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
+        self.layer_num = self.device_pool.layer_num
+
+        return (self.kv_lora_rank + self.qk_rope_head_dim) * 1 * self.dtype.itemsize
+
+    def init_kv_buffer(self):
+        return torch.empty(
+            (
+                self.layer_num,
+                self.size,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            ),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        )
+
+    @debug_timing
+    def transfer(self, indices, flat_data):
+        # backup prepared data from device to host
+        self.kv_buffer[:, indices] = flat_data.to(
+            device=self.device, non_blocking=False
+        )
+
+    def get_flat_data(self, indices):
+        return self.kv_buffer[:, indices]
+
+    def get_flat_data_by_layer(self, indices, layer_id):
+        return self.kv_buffer[layer_id, indices]
+
+    def assign_flat_data(self, indices, flat_data):
+        self.kv_buffer[:, indices] = flat_data
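The two host pools differ mainly in buffer layout and in the value returned by get_size_per_token(): the MHA host buffer stacks K and V for every head and layer, while the MLA host buffer keeps a single compressed latent vector per token per layer. Plugging representative dimensions into the two formulas above gives a feel for the difference; the concrete dimensions below are illustrative assumptions, not values from this commit.

    # Sketch only: evaluate the two get_size_per_token formulas with assumed dimensions.
    itemsize = 2  # fp16/bf16

    # MHA formula: head_dim * head_num * layer_num * itemsize * 2 (K and V)
    head_num, head_dim, layer_num = 8, 128, 32
    mha_size_per_token = head_dim * head_num * layer_num * itemsize * 2
    print(mha_size_per_token)  # 131072 -> the MHA formula's value for these dims

    # MLA formula: (kv_lora_rank + qk_rope_head_dim) * 1 * itemsize
    kv_lora_rank, qk_rope_head_dim = 512, 64
    mla_size_per_token = (kv_lora_rank + qk_rope_head_dim) * 1 * itemsize
    print(mla_size_per_token)  # 1152 -> the MLA formula's value for these dims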
import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MLA_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestHierarchicalMLA(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=["--trust-remote-code", "--enable-hierarchical-cache"],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_mmlu(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        metrics = run_eval(args)
        self.assertGreater(metrics["score"], 0.5)

    def test_mgsm_en(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mgsm_en",
            num_examples=None,
            num_threads=1024,
        )
        metrics = run_eval(args)
        self.assertGreater(metrics["score"], 0.8)


if __name__ == "__main__":
    unittest.main()