Unverified Commit 273b2834 authored by Xinyuan Tong, committed by GitHub

[Minor] Refactors KV memory pool (#9842)

parent f84db115
@@ -130,6 +130,29 @@ class KVCache(abc.ABC):
         # used for chunked cpu-offloading
         self.cpu_offloading_chunk_size = 8192
 
+        # default state for optional layer-wise transfer control
+        self.layer_transfer_counter = None
+
+    def _finalize_allocation_log(self, num_tokens: int):
+        """Common logging and mem_usage computation for KV cache allocation.
+        Supports both tuple (K, V) size returns and single KV size returns.
+        """
+        kv_size_bytes = self.get_kv_size_bytes()
+        if isinstance(kv_size_bytes, tuple):
+            k_size, v_size = kv_size_bytes
+            k_size_GB = k_size / GB
+            v_size_GB = v_size / GB
+            logger.info(
+                f"KV Cache is allocated. #tokens: {num_tokens}, K size: {k_size_GB:.2f} GB, V size: {v_size_GB:.2f} GB"
+            )
+            self.mem_usage = k_size_GB + v_size_GB
+        else:
+            kv_size_GB = kv_size_bytes / GB
+            logger.info(
+                f"KV Cache is allocated. #tokens: {num_tokens}, KV size: {kv_size_GB:.2f} GB"
+            )
+            self.mem_usage = kv_size_GB
+
     @abc.abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
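
The new helper centralizes the allocation log and `mem_usage` bookkeeping that the individual pools previously duplicated. As a standalone illustration (a simplified free function rather than the actual `KVCache` method, with a local `logger` and `GB` constant), the same tuple-vs-single dispatch looks like this:

```python
import logging

logger = logging.getLogger(__name__)
GB = 1 << 30


# Simplified stand-in for the new KVCache._finalize_allocation_log.
def finalize_allocation_log(kv_size_bytes, num_tokens: int) -> float:
    if isinstance(kv_size_bytes, tuple):
        # MHA-style pools report separate K and V byte counts.
        k_size, v_size = kv_size_bytes
        k_size_GB, v_size_GB = k_size / GB, v_size / GB
        logger.info(
            f"KV Cache is allocated. #tokens: {num_tokens}, "
            f"K size: {k_size_GB:.2f} GB, V size: {v_size_GB:.2f} GB"
        )
        return k_size_GB + v_size_GB
    # MLA-style pools report one combined KV byte count.
    kv_size_GB = kv_size_bytes / GB
    logger.info(
        f"KV Cache is allocated. #tokens: {num_tokens}, KV size: {kv_size_GB:.2f} GB"
    )
    return kv_size_GB


# Both call shapes produce the mem_usage figure (in GB) that the pools store.
mem_usage_mha = finalize_allocation_log((2 * GB, 2 * GB), num_tokens=8192)  # 4.0
mem_usage_mla = finalize_allocation_log(3 * GB, num_tokens=8192)  # 3.0
```
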
@@ -205,15 +228,9 @@ class MHATokenToKVPool(KVCache):
         self._create_buffers()
 
-        self.layer_transfer_counter = None
-
         self.device_module = torch.get_device_module(self.device)
         self.alt_stream = self.device_module.Stream() if _is_cuda else None
 
-        k_size, v_size = self.get_kv_size_bytes()
-        logger.info(
-            f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
-        )
-        self.mem_usage = (k_size + v_size) / GB
+        self._finalize_allocation_log(size)
 
     def _create_buffers(self):
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
@@ -427,43 +444,30 @@ class SWAKVPool(KVCache):
         self,
         size: int,
         size_swa: int,
-        dtype: torch.dtype,
-        head_num: int,
-        head_dim: int,
         swa_attention_layer_ids: List[int],
         full_attention_layer_ids: List[int],
         enable_kvcache_transpose: bool,
-        device: str,
+        token_to_kv_pool_class: KVCache = MHATokenToKVPool,
+        **kwargs,
     ):
         self.size = size
         self.size_swa = size_swa
-        self.dtype = dtype
-        self.device = device
         self.swa_layer_nums = len(swa_attention_layer_ids)
         self.full_layer_nums = len(full_attention_layer_ids)
-        self.page_size = 1
+        kwargs["page_size"] = 1
+        kwargs["enable_memory_saver"] = False
 
         # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
         assert not enable_kvcache_transpose
-        TokenToKVPoolClass = MHATokenToKVPool
-        self.swa_kv_pool = TokenToKVPoolClass(
+
+        self.swa_kv_pool = token_to_kv_pool_class(
             size=size_swa,
-            page_size=self.page_size,
-            dtype=dtype,
-            head_num=head_num,
-            head_dim=head_dim,
             layer_num=self.swa_layer_nums,
-            device=device,
-            enable_memory_saver=False,
+            **kwargs,
         )
-        self.full_kv_pool = TokenToKVPoolClass(
+        self.full_kv_pool = token_to_kv_pool_class(
            size=size,
-            page_size=self.page_size,
-            dtype=dtype,
-            head_num=head_num,
-            head_dim=head_dim,
             layer_num=self.full_layer_nums,
-            device=device,
-            enable_memory_saver=False,
+            **kwargs,
         )
         self.layers_mapping: Dict[int, Tuple[int, bool]] = {}
         for full_attn_layer_id, global_layer_id in enumerate(full_attention_layer_ids):
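
The constructor now forwards anything it does not consume to `token_to_kv_pool_class`, so per-pool parameters such as dtype, head counts, and device belong to the underlying pool rather than to `SWAKVPool` itself. A rough usage sketch (assuming this module's classes are imported and a CUDA device is available; the numeric values are placeholders, and the forwarded keyword names must match the chosen pool class's constructor):

```python
pool = SWAKVPool(
    size=1024,
    size_swa=256,
    swa_attention_layer_ids=[0, 1, 2],
    full_attention_layer_ids=[3],
    enable_kvcache_transpose=False,
    token_to_kv_pool_class=MHATokenToKVPool,  # the default
    # The remaining kwargs are not consumed by SWAKVPool; they are forwarded
    # to both underlying pools, with page_size=1 and enable_memory_saver=False
    # injected internally.
    dtype=torch.float16,
    head_num=8,
    head_dim=128,
    device="cuda",
)
```
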
@@ -768,13 +772,7 @@ class MLATokenToKVPool(KVCache):
             dtype=torch.uint64,
             device=self.device,
         )
-        self.layer_transfer_counter = None
-
-        kv_size = self.get_kv_size_bytes()
-        logger.info(
-            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
-        )
-        self.mem_usage = kv_size / GB
+        self._finalize_allocation_log(size)
 
     def get_kv_size_bytes(self):
         assert hasattr(self, "kv_buffer")
@@ -936,13 +934,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
             device=self.device,
         )
-        self.layer_transfer_counter = None
-
-        kv_size = self.get_kv_size_bytes()
-        logger.info(
-            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
-        )
-        self.mem_usage = kv_size / GB
+        self._finalize_allocation_log(size)
 
     def get_kv_size_bytes(self):
         assert hasattr(self, "k_buffer")
...
@@ -31,16 +31,18 @@ class TestSWA(unittest.TestCase):
             i for i in range(num_layers) if i not in full_attention_layer_ids_set
         ]
         pool = SWAKVPool(
-            size,
-            size_swa,
-            dtype,
-            num_head,
-            head_dim,
-            swa_attention_layer_ids,
-            full_attention_layer_ids,
-            device,
+            size=size,
+            size_swa=size_swa,
+            dtype=dtype,
+            num_head=num_head,
+            head_dim=head_dim,
+            swa_attention_layer_ids=swa_attention_layer_ids,
+            full_attention_layer_ids=full_attention_layer_ids,
+            device=device,
         )
-        alloc = SWATokenToKVPoolAllocator(size, size_swa, dtype, device, pool)
+        alloc = SWATokenToKVPoolAllocator(
+            size=size, size_swa=size_swa, dtype=dtype, device=device, kvcache=pool
+        )
         assert alloc.available_size() == size + size_swa
         index = alloc.alloc(1)
         assert alloc.available_size() == size_swa + size_swa - 2
@@ -75,18 +77,22 @@ class TestSWA(unittest.TestCase):
         )
         # setup kv pool
         kv_pool = SWAKVPool(
-            kv_size,
-            kv_size_swa,
-            dtype,
-            num_head,
-            head_dim,
-            swa_attention_layer_ids,
-            full_attention_layer_ids,
-            device,
+            size=kv_size,
+            size_swa=kv_size_swa,
+            dtype=dtype,
+            num_head=num_head,
+            head_dim=head_dim,
+            swa_attention_layer_ids=swa_attention_layer_ids,
+            full_attention_layer_ids=full_attention_layer_ids,
+            device=device,
         )
         # setup token to kv pool allocator
         allocator = SWATokenToKVPoolAllocator(
-            kv_size, kv_size_swa, dtype, device, kv_pool
+            size=kv_size,
+            size_swa=kv_size_swa,
+            dtype=dtype,
+            device=device,
+            kvcache=kv_pool,
         )
         # setup radix cache
         tree = SWARadixCache(
...