Unverified commit 5ea5d221, authored by Liangsheng Yin and committed by GitHub

Fix CPU offloading for MLA memory pool (#7409)

parent fdfd5224
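
For context: the change moves the offload chunk size and the get_cpu_copy / load_cpu_copy hooks up into the base KVCache class, so MLATokenToKVPool can offer the same chunked CPU-offloading interface as MHATokenToKVPool. A minimal caller-side sketch of the intended round trip follows; the helper name and the way the pool and indices are obtained are hypothetical, not code from this commit.

# Hypothetical caller-side sketch (not part of this commit). `pool` is any
# KVCache subclass that implements the interface, e.g. MHATokenToKVPool or,
# after this fix, MLATokenToKVPool.
import torch

def offload_and_restore(pool, indices: torch.Tensor):
    # Copy the KV entries at `indices` to host memory, layer by layer,
    # in chunks of pool.cpu_offloading_chunk_size (8192) tokens.
    kv_cache_cpu = pool.get_cpu_copy(indices)

    # ... the GPU slots behind `indices` can now be reused ...

    # Later, write the saved entries back into the same slots.
    pool.load_cpu_copy(kv_cache_cpu, indices)
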
@@ -123,6 +123,9 @@ class KVCache(abc.ABC):
             enable=enable_memory_saver
         )
 
+        # used for chunked cpu-offloading
+        self.cpu_offloading_chunk_size = 8192
+
     @abc.abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
@@ -157,6 +160,12 @@ class KVCache(abc.ABC):
     def register_layer_transfer_counter(self, layer_transfer_counter):
         self.layer_transfer_counter = layer_transfer_counter
 
+    def get_cpu_copy(self, indices):
+        raise NotImplementedError()
+
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        raise NotImplementedError()
+
 
 class TokenToKVPoolAllocator:
     """An allocator managing the indices to kv cache data."""
@@ -280,8 +289,6 @@ class MHATokenToKVPool(KVCache):
         self._create_buffers()
 
-        # used for chunked cpu-offloading
-        self.chunk_size = 8192
-
         self.layer_transfer_counter = None
         self.device_module = torch.get_device_module(self.device)
         self.alt_stream = self.device_module.Stream() if _is_cuda else None
@@ -378,10 +385,11 @@ class MHATokenToKVPool(KVCache):
     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
         kv_cache_cpu = []
+        chunk_size = self.cpu_offloading_chunk_size
         for layer_id in range(self.layer_num):
             kv_cache_cpu.append([])
-            for i in range(0, len(indices), self.chunk_size):
-                chunk_indices = indices[i : i + self.chunk_size]
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
                 k_cpu = self.k_buffer[layer_id][chunk_indices].to(
                     "cpu", non_blocking=True
                 )
@@ -394,12 +402,13 @@ class MHATokenToKVPool(KVCache):
     def load_cpu_copy(self, kv_cache_cpu, indices):
         torch.cuda.synchronize()
+        chunk_size = self.cpu_offloading_chunk_size
         for layer_id in range(self.layer_num):
-            for i in range(0, len(indices), self.chunk_size):
-                chunk_indices = indices[i : i + self.chunk_size]
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
                 k_cpu, v_cpu = (
-                    kv_cache_cpu[layer_id][i // self.chunk_size][0],
-                    kv_cache_cpu[layer_id][i // self.chunk_size][1],
+                    kv_cache_cpu[layer_id][i // chunk_size][0],
+                    kv_cache_cpu[layer_id][i // chunk_size][1],
                 )
                 assert k_cpu.shape[0] == v_cpu.shape[0] == len(chunk_indices)
                 k_chunk = k_cpu.to(self.k_buffer[0].device, non_blocking=True)
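
The two MHATokenToKVPool hunks above only swap self.chunk_size for the new base-class attribute; the underlying pattern is unchanged: slice the index tensor into chunks of up to 8192 entries, issue non-blocking copies per chunk, and synchronize once all transfers are queued. A self-contained sketch of that pattern on a toy buffer, with made-up shapes and helper names, is:

# Standalone sketch of the chunked copy pattern (toy shapes and helper names,
# not the pool code itself).
import torch

def chunked_cpu_copy(buffer, indices, chunk_size=8192):
    chunks = []
    for i in range(0, len(indices), chunk_size):
        chunk_indices = indices[i : i + chunk_size]
        # Gather the selected rows and start a device-to-host copy.
        chunks.append(buffer[chunk_indices].to("cpu", non_blocking=True))
    if buffer.is_cuda:
        torch.cuda.synchronize()  # wait until every chunk has landed on the host
    return chunks

def chunked_load(buffer, chunks, indices, chunk_size=8192):
    for i in range(0, len(indices), chunk_size):
        chunk_indices = indices[i : i + chunk_size]
        chunk = chunks[i // chunk_size].to(buffer.device, non_blocking=True)
        buffer[chunk_indices] = chunk
    if buffer.is_cuda:
        torch.cuda.synchronize()

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    buf = torch.randn(20000, 8, 64, device=device)  # toy [tokens, heads, head_dim]
    idx = torch.arange(0, 20000, 2, device=device)  # offload every other token
    saved = chunked_cpu_copy(buf, idx)
    buf[idx] = 0.0  # pretend the slots were reused in the meantime
    chunked_load(buf, saved, idx)
    assert torch.equal(buf[idx].cpu(), torch.cat(saved))
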
@@ -724,6 +733,33 @@ class MLATokenToKVPool(KVCache):
         flat_data = flat_data.to(device=self.device, non_blocking=False)
         self.kv_buffer[layer_id - self.start_layer][indices] = flat_data
 
+    def get_cpu_copy(self, indices):
+        torch.cuda.synchronize()
+        kv_cache_cpu = []
+        chunk_size = self.cpu_offloading_chunk_size
+        for layer_id in range(self.layer_num):
+            kv_cache_cpu.append([])
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
+                kv_cpu = self.kv_buffer[layer_id][chunk_indices].to(
+                    "cpu", non_blocking=True
+                )
+                kv_cache_cpu[-1].append(kv_cpu)
+        torch.cuda.synchronize()
+        return kv_cache_cpu
+
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        torch.cuda.synchronize()
+        chunk_size = self.cpu_offloading_chunk_size
+        for layer_id in range(self.layer_num):
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
+                kv_cpu = kv_cache_cpu[layer_id][i // chunk_size]
+                assert kv_cpu.shape[0] == len(chunk_indices)
+                kv_chunk = kv_cpu.to(self.kv_buffer[0].device, non_blocking=True)
+                self.kv_buffer[layer_id][chunk_indices] = kv_chunk
+        torch.cuda.synchronize()
+
 
 class DoubleSparseTokenToKVPool(KVCache):
     def __init__(
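
One structural difference worth noting: the MLA pool keeps K and V fused in a single kv_buffer per layer, so each chunk in its CPU copy is one tensor, whereas the MHA pool returns a (k_cpu, v_cpu) pair per chunk. That is why load_cpu_copy cannot simply be inherited from the MHA implementation. Roughly, assuming the usual buffer layouts (the shape details below come from the surrounding library, not from this diff, and are only illustrative):

# Layout of the object produced by get_cpu_copy and consumed by load_cpu_copy
# (illustrative; head_num, head_dim, kv_lora_rank, qk_rope_head_dim come from
# the model config):
#
#   MHATokenToKVPool: kv_cache_cpu[layer_id][chunk_id] == (k_cpu, v_cpu)
#       k_cpu, v_cpu: [chunk_len, head_num, head_dim]
#
#   MLATokenToKVPool: kv_cache_cpu[layer_id][chunk_id] == kv_cpu
#       kv_cpu: [chunk_len, 1, kv_lora_rank + qk_rope_head_dim]
#
# with chunk_len <= cpu_offloading_chunk_size (8192) and one inner list per layer.
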