"sgl-kernel/python/vscode:/vscode.git/clone" did not exist on "532f998b0f894268b69b7310bf06349e26b8543c"
Unverified commit 4d253057, authored by Zhiqiang Xie, committed by GitHub

Move mem_state update into debug mode (#4525)

parent 11577ced
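
For context on the diff below: outside of debug mode, the per-slot mem_state bookkeeping and its consistency checks are pure overhead on the allocation hot path, so this commit gates them behind self.debug. A minimal, self-contained sketch of the decorator pattern the commit introduces (the DummyCache class and its methods are illustrative stand-ins, not part of the actual change):

import logging
import threading
from functools import wraps

logger = logging.getLogger(__name__)


def synchronized(debug_only=False):
    # Same shape as the decorator added in this commit: methods run
    # under self.lock, and debug_only methods short-circuit to True
    # when debug logging is disabled.
    def _decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            if (not debug_only) or self.debug:
                with self.lock:
                    return func(self, *args, **kwargs)
            else:
                return True

        return wrapper

    return _decorator


class DummyCache:  # illustrative stand-in for HostKVCache
    def __init__(self):
        self.lock = threading.RLock()
        self.debug = logger.isEnabledFor(logging.DEBUG)

    @synchronized()
    def alloc(self):
        return "always runs, under the lock"

    @synchronized(debug_only=True)
    def is_reserved(self):
        return "runs only when DEBUG logging is enabled"


cache = DummyCache()
print(cache.alloc())        # always executes
print(cache.is_reserved())  # prints True (short-circuited) unless DEBUG is on

Note that skipped debug-only methods return True rather than raising, so callers that assert on them keep passing in production.
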
@@ -580,13 +580,20 @@ class MemoryStateInt(IntEnum):
     BACKUP = 4
 
 
-def synchronized(func):
-    @wraps(func)
-    def wrapper(self, *args, **kwargs):
-        with self.lock:
-            return func(self, *args, **kwargs)
-
-    return wrapper
+def synchronized(debug_only=False):
+    def _decorator(func):
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            if (not debug_only) or self.debug:
+                with self.lock:
+                    return func(self, *args, **kwargs)
+            else:
+                return True
+
+        return wrapper
+
+    return _decorator
 
 
 class HostKVCache(abc.ABC):
@@ -631,13 +638,9 @@ class HostKVCache(abc.ABC):
         self.kv_buffer = self.init_kv_buffer()
 
-        # Initialize memory states and tracking structures.
-        self.mem_state = torch.zeros(
-            (self.size,), dtype=torch.uint8, device=self.device
-        )
         # A lock for synchronized operations on memory allocation and state transitions.
         self.lock = threading.RLock()
         self.debug = logger.isEnabledFor(logging.DEBUG)
         self.clear()
 
     @abc.abstractmethod
@@ -664,97 +667,102 @@ class HostKVCache(abc.ABC):
     def assign_flat_data(self, indices, flat_data):
         raise NotImplementedError()
 
-    @synchronized
+    @synchronized()
     def clear(self):
-        self.mem_state.fill_(0)
-        self.can_use_mem_size = self.size
+        # Initialize memory states and tracking structures.
+        self.mem_state = torch.zeros(
+            (self.size,), dtype=torch.uint8, device=self.device
+        )
         self.free_slots = torch.arange(self.size, dtype=torch.int64)
 
-    @synchronized
-    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
-        assert len(indices) > 0, "The indices should not be empty"
-        states = self.mem_state[indices]
-        assert (
-            states == states[0]
-        ).all(), "The memory slots should have the same state {}".format(states)
-        return MemoryStateInt(states[0].item())
+    def available_size(self):
+        return len(self.free_slots)
 
-    @synchronized
+    @synchronized()
     def alloc(self, need_size: int) -> torch.Tensor:
-        if need_size > self.can_use_mem_size:
+        if need_size > self.available_size():
             return None
 
         # todo: de-fragmentation
         select_index = self.free_slots[:need_size]
         self.free_slots = self.free_slots[need_size:]
 
-        self.mem_state[select_index] = MemoryStateInt.RESERVED
-        self.can_use_mem_size -= need_size
+        if self.debug:
+            self.mem_state[select_index] = MemoryStateInt.RESERVED
 
         return select_index
 
-    @synchronized
+    @synchronized()
+    def free(self, indices: torch.Tensor) -> int:
+        self.free_slots = torch.cat([self.free_slots, indices])
+        if self.debug:
+            self.mem_state[indices] = MemoryStateInt.IDLE
+        return len(indices)
+
+    @synchronized(debug_only=True)
+    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
+        assert len(indices) > 0, "The indices should not be empty"
+        states = self.mem_state[indices]
+        assert (
+            states == states[0]
+        ).all(), "The memory slots should have the same state {}".format(states)
+        return MemoryStateInt(states[0].item())
+
+    @synchronized(debug_only=True)
     def is_reserved(self, indices: torch.Tensor) -> bool:
         return self.get_state(indices) == MemoryStateInt.RESERVED
 
-    @synchronized
+    @synchronized(debug_only=True)
     def is_protected(self, indices: torch.Tensor) -> bool:
         return self.get_state(indices) == MemoryStateInt.PROTECTED
 
-    @synchronized
+    @synchronized(debug_only=True)
     def is_synced(self, indices: torch.Tensor) -> bool:
         return self.get_state(indices) == MemoryStateInt.SYNCED
 
-    @synchronized
+    @synchronized(debug_only=True)
     def is_backup(self, indices: torch.Tensor) -> bool:
         return self.get_state(indices) == MemoryStateInt.BACKUP
 
-    @synchronized
+    @synchronized(debug_only=True)
     def update_backup(self, indices: torch.Tensor):
-        assert self.is_synced(indices), (
-            f"The host memory slots should be in SYNCED state before turning into BACKUP. "
-            f"Current state: {self.get_state(indices)}"
-        )
+        if not self.is_synced(indices):
+            raise ValueError(
+                f"The host memory slots should be in SYNCED state before turning into BACKUP. "
+                f"Current state: {self.get_state(indices)}"
+            )
         self.mem_state[indices] = MemoryStateInt.BACKUP
 
-    @synchronized
+    @synchronized(debug_only=True)
     def update_synced(self, indices: torch.Tensor):
         self.mem_state[indices] = MemoryStateInt.SYNCED
 
-    @synchronized
+    @synchronized(debug_only=True)
     def protect_write(self, indices: torch.Tensor):
-        assert self.is_reserved(indices), (
-            f"The host memory slots should be RESERVED before write operations. "
-            f"Current state: {self.get_state(indices)}"
-        )
+        if not self.is_reserved(indices):
+            raise ValueError(
+                f"The host memory slots should be RESERVED before write operations. "
+                f"Current state: {self.get_state(indices)}"
+            )
         self.mem_state[indices] = MemoryStateInt.PROTECTED
 
-    @synchronized
+    @synchronized(debug_only=True)
    def protect_load(self, indices: torch.Tensor):
-        assert self.is_backup(indices), (
-            f"The host memory slots should be in BACKUP state before load operations. "
-            f"Current state: {self.get_state(indices)}"
-        )
+        if not self.is_backup(indices):
+            raise ValueError(
+                f"The host memory slots should be in BACKUP state before load operations. "
+                f"Current state: {self.get_state(indices)}"
+            )
         self.mem_state[indices] = MemoryStateInt.PROTECTED
 
-    @synchronized
+    @synchronized(debug_only=True)
     def complete_io(self, indices: torch.Tensor):
-        assert self.is_protected(indices), (
-            f"The host memory slots should be PROTECTED during I/O operations. "
-            f"Current state: {self.get_state(indices)}"
-        )
+        if not self.is_protected(indices):
+            raise ValueError(
+                f"The host memory slots should be PROTECTED during I/O operations. "
+                f"Current state: {self.get_state(indices)}"
+            )
         self.mem_state[indices] = MemoryStateInt.SYNCED
 
-    def available_size(self):
-        return len(self.free_slots)
-
-    @synchronized
-    def free(self, indices: torch.Tensor) -> int:
-        self.mem_state[indices] = MemoryStateInt.IDLE
-        self.free_slots = torch.cat([self.free_slots, indices])
-        self.can_use_mem_size += len(indices)
-        return len(indices)
-
 
 class MHATokenToKVPoolHost(HostKVCache):
     def __init__(
......
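
One practical consequence worth noting (an observation, not part of the diff): self.debug is captured from the logger's effective level in __init__, so the state machine is re-enabled purely through logging configuration, for example (hypothetical usage):

import logging

# Raising the log level to DEBUG before the host cache is constructed
# turns the mem_state bookkeeping and its checks back on.
logging.basicConfig(level=logging.DEBUG)

This must happen before construction, since the flag is read once via logger.isEnabledFor(logging.DEBUG).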