Unverified Commit 3d40794f authored by Zhiqiang Xie, committed by GitHub

[HiCache] Cleaning the deprecated host memory state (#10778)

parent c1f39013
@@ -462,7 +462,6 @@ class HiCacheController:
         host_indices = self.mem_pool_host.alloc(len(device_indices))
         if host_indices is None:
             return None
-        self.mem_pool_host.protect_write(host_indices)
         self.write_queue.append(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
@@ -486,7 +485,6 @@ class HiCacheController:
                 self.mem_pool_host.backup_from_device_all_layer(
                     self.mem_pool_device, host_indices, device_indices, self.io_backend
                 )
-                self.mem_pool_host.complete_io(op.host_indices)
                 finish_event.record()
                 # NOTE: We must save the host indices and device indices here,
                 # this is because we need to guarantee that these tensors are
@@ -510,7 +508,6 @@ class HiCacheController:
         device_indices = self.mem_pool_device_allocator.alloc(len(host_indices))
         if device_indices is None:
             return None
-        self.mem_pool_host.protect_load(host_indices)
         self.load_queue.append(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
@@ -555,7 +552,6 @@ class HiCacheController:
                     self.io_backend,
                 )
                 producer_event.complete(i)
-            self.mem_pool_host.complete_io(op.host_indices)
             # NOTE: We must save the host indices and device indices here,
             # this is because we need to guarantee that these tensors are
             # still alive when the load stream is executing.
@@ -573,29 +569,16 @@ class HiCacheController:
             )
             return producer_id
 
-    def evict_device(
-        self, device_indices: torch.Tensor, host_indices: torch.Tensor
-    ) -> int:
-        if self.mem_pool_host.is_synced(host_indices):
-            self.mem_pool_device_allocator.free(device_indices)
-            self.mem_pool_host.update_backup(host_indices)
-            return len(device_indices)
-        else:
-            raise ValueError(
-                f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
-            )
+    def evict_device(self, device_indices: torch.Tensor) -> int:
+        self.mem_pool_device_allocator.free(device_indices)
+        return len(device_indices)
 
     def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
         if not backup_only:
             raise ValueError("Other eviction policies are not supported yet.")
-        if self.mem_pool_host.is_backup(host_indices):
-            self.mem_pool_host.free(host_indices)
-            return len(host_indices)
-        else:
-            raise ValueError(
-                f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
-            )
+        self.mem_pool_host.free(host_indices)
+        return len(host_indices)
 
     def prefetch(
         self,
...
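With protect_write, protect_load, and complete_io gone, the controller's write and load entry points reduce to allocate-or-bail-then-enqueue. Below is a minimal, self-contained sketch of that pattern under toy assumptions; the list-based free list, host_alloc, and tuple queue entries are illustrative stand-ins, not the real HiCacheController API.

from collections import deque
from typing import List, Optional

free_slots: List[int] = list(range(16))  # toy host free list, not the real pool
write_queue: deque = deque()


def host_alloc(need_size: int) -> Optional[List[int]]:
    # Toy allocator: pop slots from the free list, or signal failure with None.
    global free_slots
    if need_size > len(free_slots):
        return None
    selected, free_slots = free_slots[:need_size], free_slots[need_size:]
    return selected


def write_backup(device_indices: List[int], node_id: int) -> Optional[List[int]]:
    # Mirrors the slimmed-down control flow: alloc -> bail on failure -> enqueue.
    # No per-slot state transition (protect_write / complete_io) is recorded.
    host_indices = host_alloc(len(device_indices))
    if host_indices is None:
        return None
    write_queue.append((host_indices, device_indices, node_id))
    return host_indices


print(write_backup([3, 4, 5], node_id=1))  # -> [0, 1, 2]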
@@ -305,7 +305,7 @@ class HiRadixCache(RadixCache):
 
     def _evict_backuped(self, node: TreeNode):
         # evict a node already written to host
-        num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
+        num_evicted = self.cache_controller.evict_device(node.value)
         assert num_evicted > 0
         self.evictable_size_ -= num_evicted
         node.value = None
@@ -576,8 +576,6 @@ class HiRadixCache(RadixCache):
                 written_indices,
                 hash_value[: min_completed_tokens // self.page_size],
             )
-            if len(written_indices):
-                self.cache_controller.mem_pool_host.update_prefetch(written_indices)
 
         self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
         self.cache_controller.append_host_mem_release(
@@ -775,7 +773,6 @@ class HiRadixCache(RadixCache):
                 # change the reference if the node is evicted
                 # this often happens in the case of KV cache recomputation
                 node.value = value[:prefix_len]
-                self.token_to_kv_pool_host.update_synced(node.host_value)
                 self.evictable_size_ += len(node.value)
             else:
                 self._inc_hit_count(node, chunked)
@@ -785,7 +782,6 @@ class HiRadixCache(RadixCache):
             new_node = self._split_node(node.key, node, prefix_len)
             if new_node.evicted:
                 new_node.value = value[:prefix_len]
-                self.token_to_kv_pool_host.update_synced(new_node.host_value)
                 self.evictable_size_ += len(new_node.value)
             else:
                 self._inc_hit_count(new_node, chunked)
...
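On the HiRadixCache side, the only behavioral change is that evicting a backed-up node no longer passes node.host_value and no longer depends on host-slot state checks; freeing device memory and releasing the host backup are now two independent calls. A small illustrative stub (StubController is hypothetical, standing in for the new signatures):

class StubController:
    # Hypothetical stand-in for HiCacheController after this change.
    def evict_device(self, device_indices) -> int:
        return len(device_indices)

    def evict_host(self, host_indices, backup_only: bool = True) -> int:
        if not backup_only:
            raise ValueError("Other eviction policies are not supported yet.")
        return len(host_indices)


ctrl = StubController()
assert ctrl.evict_device([10, 11, 12]) == 3   # node.host_value no longer needed
assert ctrl.evict_host([100, 101, 102]) == 3  # host copy freed independently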
@@ -31,27 +31,13 @@ if not (_is_npu or _is_xpu):
 logger = logging.getLogger(__name__)
 
 
-class MemoryStateInt(IntEnum):
-    IDLE = 0
-    RESERVED = 1
-    PROTECTED = 2
-    SYNCED = 3
-    BACKUP = 4
-
-
-def synchronized(debug_only=False):
-    def _decorator(func):
-        @wraps(func)
-        def wrapper(self, *args, **kwargs):
-            if (not debug_only) or self.debug:
-                with self.lock:
-                    return func(self, *args, **kwargs)
-            else:
-                return True
-
-        return wrapper
-
-    return _decorator
+def synchronized(func):
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        with self.lock:
+            return func(self, *args, **kwargs)
+
+    return wrapper
 
 
 class HostKVCache(abc.ABC):
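The decorator no longer takes a debug_only flag, so call sites change from @synchronized() to @synchronized, and every decorated method now takes self.lock unconditionally. A minimal runnable sketch of the new form; the Counter class is illustrative only, not part of the codebase.

import threading
from functools import wraps


def synchronized(func):
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        with self.lock:
            return func(self, *args, **kwargs)

    return wrapper


class Counter:
    def __init__(self):
        self.lock = threading.RLock()
        self.value = 0

    @synchronized  # note: no parentheses any more
    def bump(self) -> int:
        self.value += 1
        return self.value


c = Counter()
print(c.bump())  # 1, executed while holding c.lock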
@@ -110,7 +96,6 @@ class HostKVCache(abc.ABC):
         # A lock for synchronized operations on memory allocation and state transitions.
         self.lock = threading.RLock()
-        self.debug = logger.isEnabledFor(logging.DEBUG)
         self.clear()
 
     @abc.abstractmethod
@@ -161,7 +146,7 @@ class HostKVCache(abc.ABC):
         """
         raise NotImplementedError()
 
-    @synchronized()
+    @synchronized
     def clear(self):
         # Initialize memory states and tracking structures.
         self.mem_state = torch.zeros(
@@ -172,7 +157,7 @@ class HostKVCache(abc.ABC):
     def available_size(self):
         return len(self.free_slots)
 
-    @synchronized()
+    @synchronized
    def alloc(self, need_size: int) -> Optional[torch.Tensor]:
         assert (
             need_size % self.page_size == 0
@@ -183,92 +168,13 @@ class HostKVCache(abc.ABC):
         select_index = self.free_slots[:need_size]
         self.free_slots = self.free_slots[need_size:]
 
-        if self.debug:
-            self.mem_state[select_index] = MemoryStateInt.RESERVED
-
         return select_index
 
-    @synchronized()
+    @synchronized
     def free(self, indices: torch.Tensor) -> int:
         self.free_slots = torch.cat([self.free_slots, indices])
-        if self.debug:
-            self.mem_state[indices] = MemoryStateInt.IDLE
         return len(indices)
 
-    @synchronized(debug_only=True)
-    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
-        assert len(indices) > 0, "The indices should not be empty"
-        states = self.mem_state[indices]
-        assert (
-            states == states[0]
-        ).all(), "The memory slots should have the same state {}".format(states)
-        return MemoryStateInt(states[0].item())
-
-    @synchronized(debug_only=True)
-    def is_reserved(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.RESERVED
-
-    @synchronized(debug_only=True)
-    def is_protected(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def is_synced(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.SYNCED
-
-    @synchronized(debug_only=True)
-    def is_backup(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.BACKUP
-
-    @synchronized(debug_only=True)
-    def update_backup(self, indices: torch.Tensor):
-        if not self.is_synced(indices):
-            raise ValueError(
-                f"The host memory slots should be in SYNCED state before turning into BACKUP. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.BACKUP
-
-    @synchronized(debug_only=True)
-    def update_prefetch(self, indices: torch.Tensor):
-        if not self.is_reserved(indices):
-            raise ValueError(
-                f"The host memory slots should be in RESERVED state before turning into BACKUP. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.BACKUP
-
-    @synchronized(debug_only=True)
-    def update_synced(self, indices: torch.Tensor):
-        self.mem_state[indices] = MemoryStateInt.SYNCED
-
-    @synchronized(debug_only=True)
-    def protect_write(self, indices: torch.Tensor):
-        if not self.is_reserved(indices):
-            raise ValueError(
-                f"The host memory slots should be RESERVED before write operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def protect_load(self, indices: torch.Tensor):
-        if not self.is_backup(indices):
-            raise ValueError(
-                f"The host memory slots should be in BACKUP state before load operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def complete_io(self, indices: torch.Tensor):
-        if not self.is_protected(indices):
-            raise ValueError(
-                f"The host memory slots should be PROTECTED during I/O operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.SYNCED
-
 
 class MHATokenToKVPoolHost(HostKVCache):
     device_pool: MHATokenToKVPool
...
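With MemoryStateInt and the debug-only state machine removed, host allocation and release reduce to lock-protected free-list bookkeeping. The following is a self-contained sketch under that assumption; ToyHostPool is illustrative only, not the real HostKVCache.

import threading
from typing import Optional

import torch


class ToyHostPool:
    # Illustrative only: a lock-protected free list of host slot indices,
    # with no per-slot MemoryStateInt bookkeeping.
    def __init__(self, size: int, page_size: int = 1):
        self.page_size = page_size
        self.lock = threading.RLock()
        self.free_slots = torch.arange(size, dtype=torch.int64)

    def available_size(self) -> int:
        return len(self.free_slots)

    def alloc(self, need_size: int) -> Optional[torch.Tensor]:
        assert need_size % self.page_size == 0, "allocation must be page-aligned"
        with self.lock:
            if need_size > len(self.free_slots):
                return None
            select_index = self.free_slots[:need_size]
            self.free_slots = self.free_slots[need_size:]
            return select_index

    def free(self, indices: torch.Tensor) -> int:
        with self.lock:
            self.free_slots = torch.cat([self.free_slots, indices])
            return len(indices)


pool = ToyHostPool(size=8)
idx = pool.alloc(4)
print(pool.available_size())  # 4
pool.free(idx)
print(pool.available_size())  # 8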