Unverified Commit 3d40794f authored by Zhiqiang Xie's avatar Zhiqiang Xie Committed by GitHub
Browse files

[HiCache] Cleaning the deprecated host memory state (#10778)

parent c1f39013
......@@ -462,7 +462,6 @@ class HiCacheController:
host_indices = self.mem_pool_host.alloc(len(device_indices))
if host_indices is None:
return None
self.mem_pool_host.protect_write(host_indices)
self.write_queue.append(
CacheOperation(host_indices, device_indices, node_id, priority)
)
......@@ -486,7 +485,6 @@ class HiCacheController:
self.mem_pool_host.backup_from_device_all_layer(
self.mem_pool_device, host_indices, device_indices, self.io_backend
)
self.mem_pool_host.complete_io(op.host_indices)
finish_event.record()
# NOTE: We must save the host indices and device indices here,
# this is because we need to guarantee that these tensors are
......@@ -510,7 +508,6 @@ class HiCacheController:
device_indices = self.mem_pool_device_allocator.alloc(len(host_indices))
if device_indices is None:
return None
self.mem_pool_host.protect_load(host_indices)
self.load_queue.append(
CacheOperation(host_indices, device_indices, node_id, priority)
)
......@@ -555,7 +552,6 @@ class HiCacheController:
self.io_backend,
)
producer_event.complete(i)
self.mem_pool_host.complete_io(op.host_indices)
# NOTE: We must save the host indices and device indices here,
# this is because we need to guarantee that these tensors are
# still alive when the load stream is executing.
......@@ -573,29 +569,16 @@ class HiCacheController:
)
return producer_id
def evict_device(
self, device_indices: torch.Tensor, host_indices: torch.Tensor
) -> int:
if self.mem_pool_host.is_synced(host_indices):
self.mem_pool_device_allocator.free(device_indices)
self.mem_pool_host.update_backup(host_indices)
return len(device_indices)
else:
raise ValueError(
f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
)
def evict_device(self, device_indices: torch.Tensor) -> int:
self.mem_pool_device_allocator.free(device_indices)
return len(device_indices)
def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
if not backup_only:
raise ValueError("Other eviction policies are not supported yet.")
if self.mem_pool_host.is_backup(host_indices):
self.mem_pool_host.free(host_indices)
return len(host_indices)
else:
raise ValueError(
f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
)
self.mem_pool_host.free(host_indices)
return len(host_indices)
def prefetch(
self,
......
......@@ -305,7 +305,7 @@ class HiRadixCache(RadixCache):
def _evict_backuped(self, node: TreeNode):
# evict a node already written to host
num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
num_evicted = self.cache_controller.evict_device(node.value)
assert num_evicted > 0
self.evictable_size_ -= num_evicted
node.value = None
......@@ -576,8 +576,6 @@ class HiRadixCache(RadixCache):
written_indices,
hash_value[: min_completed_tokens // self.page_size],
)
if len(written_indices):
self.cache_controller.mem_pool_host.update_prefetch(written_indices)
self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
self.cache_controller.append_host_mem_release(
......@@ -775,7 +773,6 @@ class HiRadixCache(RadixCache):
# change the reference if the node is evicted
# this often happens in the case of KV cache recomputation
node.value = value[:prefix_len]
self.token_to_kv_pool_host.update_synced(node.host_value)
self.evictable_size_ += len(node.value)
else:
self._inc_hit_count(node, chunked)
......@@ -785,7 +782,6 @@ class HiRadixCache(RadixCache):
new_node = self._split_node(node.key, node, prefix_len)
if new_node.evicted:
new_node.value = value[:prefix_len]
self.token_to_kv_pool_host.update_synced(new_node.host_value)
self.evictable_size_ += len(new_node.value)
else:
self._inc_hit_count(new_node, chunked)
......
......@@ -31,27 +31,13 @@ if not (_is_npu or _is_xpu):
logger = logging.getLogger(__name__)
class MemoryStateInt(IntEnum):
IDLE = 0
RESERVED = 1
PROTECTED = 2
SYNCED = 3
BACKUP = 4
def synchronized(debug_only=False):
def _decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if (not debug_only) or self.debug:
with self.lock:
return func(self, *args, **kwargs)
else:
return True
return wrapper
def synchronized(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
with self.lock:
return func(self, *args, **kwargs)
return _decorator
return wrapper
class HostKVCache(abc.ABC):
......@@ -110,7 +96,6 @@ class HostKVCache(abc.ABC):
# A lock for synchronized operations on memory allocation and state transitions.
self.lock = threading.RLock()
self.debug = logger.isEnabledFor(logging.DEBUG)
self.clear()
@abc.abstractmethod
......@@ -161,7 +146,7 @@ class HostKVCache(abc.ABC):
"""
raise NotImplementedError()
@synchronized()
@synchronized
def clear(self):
# Initialize memory states and tracking structures.
self.mem_state = torch.zeros(
......@@ -172,7 +157,7 @@ class HostKVCache(abc.ABC):
def available_size(self):
return len(self.free_slots)
@synchronized()
@synchronized
def alloc(self, need_size: int) -> Optional[torch.Tensor]:
assert (
need_size % self.page_size == 0
......@@ -183,92 +168,13 @@ class HostKVCache(abc.ABC):
select_index = self.free_slots[:need_size]
self.free_slots = self.free_slots[need_size:]
if self.debug:
self.mem_state[select_index] = MemoryStateInt.RESERVED
return select_index
@synchronized()
@synchronized
def free(self, indices: torch.Tensor) -> int:
self.free_slots = torch.cat([self.free_slots, indices])
if self.debug:
self.mem_state[indices] = MemoryStateInt.IDLE
return len(indices)
@synchronized(debug_only=True)
def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
assert len(indices) > 0, "The indices should not be empty"
states = self.mem_state[indices]
assert (
states == states[0]
).all(), "The memory slots should have the same state {}".format(states)
return MemoryStateInt(states[0].item())
@synchronized(debug_only=True)
def is_reserved(self, indices: torch.Tensor) -> bool:
return self.get_state(indices) == MemoryStateInt.RESERVED
@synchronized(debug_only=True)
def is_protected(self, indices: torch.Tensor) -> bool:
return self.get_state(indices) == MemoryStateInt.PROTECTED
@synchronized(debug_only=True)
def is_synced(self, indices: torch.Tensor) -> bool:
return self.get_state(indices) == MemoryStateInt.SYNCED
@synchronized(debug_only=True)
def is_backup(self, indices: torch.Tensor) -> bool:
return self.get_state(indices) == MemoryStateInt.BACKUP
@synchronized(debug_only=True)
def update_backup(self, indices: torch.Tensor):
if not self.is_synced(indices):
raise ValueError(
f"The host memory slots should be in SYNCED state before turning into BACKUP. "
f"Current state: {self.get_state(indices)}"
)
self.mem_state[indices] = MemoryStateInt.BACKUP
@synchronized(debug_only=True)
def update_prefetch(self, indices: torch.Tensor):
if not self.is_reserved(indices):
raise ValueError(
f"The host memory slots should be in RESERVED state before turning into BACKUP. "
f"Current state: {self.get_state(indices)}"
)
self.mem_state[indices] = MemoryStateInt.BACKUP
@synchronized(debug_only=True)
def update_synced(self, indices: torch.Tensor):
self.mem_state[indices] = MemoryStateInt.SYNCED
@synchronized(debug_only=True)
def protect_write(self, indices: torch.Tensor):
if not self.is_reserved(indices):
raise ValueError(
f"The host memory slots should be RESERVED before write operations. "
f"Current state: {self.get_state(indices)}"
)
self.mem_state[indices] = MemoryStateInt.PROTECTED
@synchronized(debug_only=True)
def protect_load(self, indices: torch.Tensor):
if not self.is_backup(indices):
raise ValueError(
f"The host memory slots should be in BACKUP state before load operations. "
f"Current state: {self.get_state(indices)}"
)
self.mem_state[indices] = MemoryStateInt.PROTECTED
@synchronized(debug_only=True)
def complete_io(self, indices: torch.Tensor):
if not self.is_protected(indices):
raise ValueError(
f"The host memory slots should be PROTECTED during I/O operations. "
f"Current state: {self.get_state(indices)}"
)
self.mem_state[indices] = MemoryStateInt.SYNCED
class MHATokenToKVPoolHost(HostKVCache):
device_pool: MHATokenToKVPool
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment