Unverified commit 33467c05 authored by Shisong Ma, committed by GitHub

[BUG FIX] add failure check so a failed get does not block wait_complete (#9971)


Co-authored-by: mashisong <mashisong@bytedance.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
parent b0fcbb74
@@ -207,26 +207,25 @@ class PrefetchOperation(StorageOperation):
     ):
         self.request_id = request_id
-        self._done_flag = False
         self._lock = threading.Lock()
+        self._terminated_flag = False
         self.start_time = time.monotonic()

         super().__init__(host_indices, token_ids, last_hash)

     def increment(self, num_tokens: int):
         with self._lock:
-            if self._done_flag:
+            if self._terminated_flag:
                 return False
             self.completed_tokens += num_tokens
             return True

-    def mark_done(self):
+    def mark_terminate(self):
         with self._lock:
-            self._done_flag = True
+            self._terminated_flag = True

-    def is_done(self) -> bool:
-        return self._done_flag
+    def is_terminated(self) -> bool:
+        return self._terminated_flag


 class HiCacheController:
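A minimal, self-contained sketch of the terminate/increment handshake this hunk introduces (the class name _PrefetchOpSketch is hypothetical, not the actual SGLang class): once mark_terminate() is called, further increment() calls are rejected, so worker threads stop accounting tokens for a prefetch that has failed or been cancelled.

import threading

class _PrefetchOpSketch:
    def __init__(self):
        self.completed_tokens = 0
        self._terminated_flag = False
        self._lock = threading.Lock()

    def increment(self, num_tokens: int) -> bool:
        with self._lock:
            if self._terminated_flag:
                return False  # reject updates once the operation is terminated
            self.completed_tokens += num_tokens
            return True

    def mark_terminate(self):
        with self._lock:
            self._terminated_flag = True

    def is_terminated(self) -> bool:
        return self._terminated_flag

op = _PrefetchOpSketch()
assert op.increment(64)           # accepted while the operation is live
op.mark_terminate()
assert not op.increment(64)       # rejected after termination
assert op.completed_tokens == 64  # the late batch was never counted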
@@ -628,7 +627,7 @@ class HiCacheController:
         return operation

     def terminate_prefetch(self, operation):
-        operation.mark_done()
+        operation.mark_terminate()
         return operation.completed_tokens, operation.hash_value

     def append_host_mem_release(self, host_indices: torch.Tensor):
@@ -709,6 +708,7 @@ class HiCacheController:
                 operation.completed_tokens
                 != prev_completed_tokens + len(batch_hashes) * self.page_size
             ):
+                operation.mark_terminate()
                 break  # Some operations fail or operation terminated by controller
         # release pre-allocated memory
         self.append_host_mem_release(
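A hedged sketch of the failure path this hunk fixes, using illustrative names (_Op, fetch, worker) rather than the real controller APIs: when a storage get returns fewer tokens than requested, the worker now marks the operation terminated, so a wait_complete-style poller that waits on termination returns promptly instead of blocking forever.

import threading

class _Op:
    def __init__(self):
        self.completed_tokens = 0
        self.terminated = threading.Event()

def fetch(batch, page_size):
    # stub backend get: pretend the last page of this batch failed to load
    return (len(batch) - 1) * page_size

def worker(op, batches, page_size):
    for batch in batches:
        got = fetch(batch, page_size)
        op.completed_tokens += got
        if got != len(batch) * page_size:  # partial or failed get
            op.terminated.set()            # the fix: unblock any waiter
            return
    op.terminated.set()                    # normal completion

op, page_size = _Op(), 16
t = threading.Thread(target=worker, args=(op, [["hash0", "hash1"]], page_size))
t.start()
op.terminated.wait(timeout=5)              # returns promptly instead of hanging
t.join()
print("completed tokens:", op.completed_tokens)  # 16 of the requested 32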
@@ -482,15 +482,22 @@ class HiRadixCache(RadixCache):
             # unknown prefetch stop policy, just return True
             return True

+        operation_terminated = operation.is_terminated()
         if self.tp_world_size > 1:
-            can_terminate = torch.tensor(can_terminate, dtype=torch.int)
+            states = torch.tensor(
+                [1 - int(can_terminate), int(operation_terminated)],
+                dtype=torch.int,
+            )
             torch.distributed.all_reduce(
-                can_terminate,
-                op=torch.distributed.ReduceOp.MIN,
+                states,
+                op=torch.distributed.ReduceOp.MAX,
                 group=self.tp_group,
             )
-            can_terminate = bool(can_terminate.item())
+            can_terminate = states[0].item() == 0
+            operation_terminated = states[1].item() == 1
+        # the operation should be terminated if it is already terminated on any TP worker
+        # or it meets the termination condition on all TP workers
+        can_terminate = can_terminate or operation_terminated
         return can_terminate

     def check_prefetch_progress(self, req_id: str) -> bool:
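A minimal sketch of the reduction logic above, with the single ReduceOp.MAX all_reduce modeled as an elementwise max over per-rank values: encoding 1 - can_terminate turns the MAX into a logical AND (every rank must meet the stop condition), while operation_terminated stays a logical OR (termination on any rank wins).

def reduce_states(per_rank_can_terminate, per_rank_operation_terminated):
    states = [
        max(1 - int(x) for x in per_rank_can_terminate),     # MAX of (1 - x) == NOT all(x)
        max(int(y) for y in per_rank_operation_terminated),  # MAX of y == any(y)
    ]
    can_terminate = states[0] == 0           # every rank met the stop condition
    operation_terminated = states[1] == 1    # some rank already terminated it
    return can_terminate or operation_terminated

assert reduce_states([True, True], [False, False])       # all ranks agree to stop
assert not reduce_states([True, False], [False, False])  # one rank is not ready
assert reduce_states([True, False], [False, True])       # termination anywhere wins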
@@ -517,7 +524,7 @@ class HiRadixCache(RadixCache):
         logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")

         min_completed_tokens = completed_tokens
-        if self.tp_world_size > 1 and self.prefetch_stop_policy != "wait_complete":
+        if self.tp_world_size > 1:
             # synchrnoize TP workers to make the same update to hiradix cache
             completed_tokens_tensor = torch.tensor(
                 min_completed_tokens, dtype=torch.int
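A hedged sketch (plain min() standing in for the torch.distributed ReduceOp.MIN all_reduce) of why the != "wait_complete" guard is dropped: with the failure path above, TP ranks can report different completed token counts even under wait_complete, so the MIN reconciliation must always run before every rank makes the same update to the radix cache.

def reconcile_completed_tokens(per_rank_completed_tokens):
    # previously skipped when prefetch_stop_policy == "wait_complete",
    # which let ranks commit different prefix lengths after a failed get
    return min(per_rank_completed_tokens)

assert reconcile_completed_tokens([1024, 768, 1024]) == 768  # every rank uses 768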