Unverified Commit 4d253057 authored by Zhiqiang Xie's avatar Zhiqiang Xie Committed by GitHub
Browse files

Move mem_state update into debug mode (#4525)

parent 11577ced
...@@ -580,14 +580,21 @@ class MemoryStateInt(IntEnum): ...@@ -580,14 +580,21 @@ class MemoryStateInt(IntEnum):
BACKUP = 4 BACKUP = 4
def synchronized(func): def synchronized(debug_only=False):
def _decorator(func):
@wraps(func) @wraps(func)
def wrapper(self, *args, **kwargs): def wrapper(self, *args, **kwargs):
if (not debug_only) or self.debug:
return func(self, *args, **kwargs)
with self.lock: with self.lock:
return func(self, *args, **kwargs) return func(self, *args, **kwargs)
else:
return True
return wrapper return wrapper
return _decorator
class HostKVCache(abc.ABC): class HostKVCache(abc.ABC):
...@@ -631,13 +638,9 @@ class HostKVCache(abc.ABC): ...@@ -631,13 +638,9 @@ class HostKVCache(abc.ABC):
self.kv_buffer = self.init_kv_buffer() self.kv_buffer = self.init_kv_buffer()
# Initialize memory states and tracking structures.
self.mem_state = torch.zeros(
(self.size,), dtype=torch.uint8, device=self.device
)
# A lock for synchronized operations on memory allocation and state transitions. # A lock for synchronized operations on memory allocation and state transitions.
self.lock = threading.RLock() self.lock = threading.RLock()
self.debug = logger.isEnabledFor(logging.DEBUG)
self.clear() self.clear()
@abc.abstractmethod
def assign_flat_data(self, indices, flat_data):
    # Write `flat_data` into the host buffer slots addressed by `indices`.
    # Concrete layout/copy semantics are defined by subclasses.
    raise NotImplementedError()
@synchronized()
def clear(self):
    """Reset the pool: all slots become free and state tracking is rebuilt."""
    # Initialize memory states and tracking structures.
    self.mem_state = torch.zeros(
        (self.size,), dtype=torch.uint8, device=self.device
    )
    self.free_slots = torch.arange(self.size, dtype=torch.int64)
def available_size(self):
    # Number of host slots currently free to allocate.
    # NOTE(review): not lock-protected in this version — confirm callers
    # tolerate a possibly-stale count.
    return len(self.free_slots)
@synchronized()
def alloc(self, need_size: int) -> torch.Tensor:
    """Reserve `need_size` slots; return their indices, or None if the
    pool does not have enough free slots."""
    if need_size > self.available_size():
        return None

    select_index = self.free_slots[:need_size]
    self.free_slots = self.free_slots[need_size:]

    if self.debug:
        # mem_state bookkeeping is only maintained in debug mode.
        self.mem_state[select_index] = MemoryStateInt.RESERVED

    return select_index
@synchronized()
def free(self, indices: torch.Tensor) -> int:
    """Return the given slots to the free pool; returns the count freed."""
    self.free_slots = torch.cat([self.free_slots, indices])

    if self.debug:
        # mem_state bookkeeping is only maintained in debug mode.
        self.mem_state[indices] = MemoryStateInt.IDLE

    return len(indices)
@synchronized(debug_only=True)
def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
    """Return the common MemoryStateInt of the addressed slots.

    Debug-mode only. Asserts that `indices` is non-empty and that every
    addressed slot currently holds the same state.
    """
    assert len(indices) > 0, "The indices should not be empty"
    states = self.mem_state[indices]
    assert (
        states == states[0]
    ).all(), "The memory slots should have the same state {}".format(states)
    return MemoryStateInt(states[0].item())
@synchronized(debug_only=True)
def is_reserved(self, indices: torch.Tensor) -> bool:
    """Debug-mode check: True iff every given slot is RESERVED."""
    state = self.get_state(indices)
    return state == MemoryStateInt.RESERVED
@synchronized(debug_only=True)
def is_protected(self, indices: torch.Tensor) -> bool:
    """Debug-mode check: True iff every given slot is PROTECTED."""
    state = self.get_state(indices)
    return state == MemoryStateInt.PROTECTED
@synchronized(debug_only=True)
def is_synced(self, indices: torch.Tensor) -> bool:
    """Debug-mode check: True iff every given slot is SYNCED."""
    state = self.get_state(indices)
    return state == MemoryStateInt.SYNCED
@synchronized(debug_only=True)
def is_backup(self, indices: torch.Tensor) -> bool:
    """Debug-mode check: True iff every given slot is BACKUP."""
    state = self.get_state(indices)
    return state == MemoryStateInt.BACKUP
@synchronized(debug_only=True)
def update_backup(self, indices: torch.Tensor):
    """Debug-mode transition SYNCED -> BACKUP; raises ValueError otherwise."""
    if not self.is_synced(indices):
        raise ValueError(
            f"The host memory slots should be in SYNCED state before turning into BACKUP. "
            f"Current state: {self.get_state(indices)}"
        )
    self.mem_state[indices] = MemoryStateInt.BACKUP
@synchronized(debug_only=True)
def update_synced(self, indices: torch.Tensor):
    # Debug-mode bookkeeping: mark slots SYNCED unconditionally
    # (no precondition check, unlike the other transitions).
    self.mem_state[indices] = MemoryStateInt.SYNCED
@synchronized(debug_only=True)
def protect_write(self, indices: torch.Tensor):
    """Debug-mode transition RESERVED -> PROTECTED before a write;
    raises ValueError if any slot is not RESERVED."""
    if not self.is_reserved(indices):
        raise ValueError(
            f"The host memory slots should be RESERVED before write operations. "
            f"Current state: {self.get_state(indices)}"
        )
    self.mem_state[indices] = MemoryStateInt.PROTECTED
@synchronized(debug_only=True)
def protect_load(self, indices: torch.Tensor):
    """Debug-mode transition BACKUP -> PROTECTED before a load;
    raises ValueError if any slot is not BACKUP."""
    if not self.is_backup(indices):
        raise ValueError(
            f"The host memory slots should be in BACKUP state before load operations. "
            f"Current state: {self.get_state(indices)}"
        )
    self.mem_state[indices] = MemoryStateInt.PROTECTED
@synchronized(debug_only=True)
def complete_io(self, indices: torch.Tensor):
    """Debug-mode transition PROTECTED -> SYNCED after an I/O completes;
    raises ValueError if any slot is not PROTECTED."""
    if not self.is_protected(indices):
        raise ValueError(
            f"The host memory slots should be PROTECTED during I/O operations. "
            f"Current state: {self.get_state(indices)}"
        )
    self.mem_state[indices] = MemoryStateInt.SYNCED
def available_size(self):
    # NOTE(review): pre-refactor duplicate retained by this diff view;
    # superseded by the version defined earlier in the class.
    return len(self.free_slots)

@synchronized
def free(self, indices: torch.Tensor) -> int:
    # NOTE(review): pre-refactor version; updates mem_state unconditionally
    # and maintains the now-removed can_use_mem_size counter.
    self.mem_state[indices] = MemoryStateInt.IDLE
    self.free_slots = torch.cat([self.free_slots, indices])
    self.can_use_mem_size += len(indices)
    return len(indices)
class MHATokenToKVPoolHost(HostKVCache): class MHATokenToKVPoolHost(HostKVCache):
def __init__( def __init__(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment