Unverified Commit 9416ee60 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Reserved abortion API when retracting (#12425)

parent d4a09ec9
...@@ -1230,7 +1230,7 @@ class AbortReq(BaseReq): ...@@ -1230,7 +1230,7 @@ class AbortReq(BaseReq):
abort_all: bool = False abort_all: bool = False
# The finished reason data # The finished reason data
finished_reason: Optional[Dict[str, Any]] = None finished_reason: Optional[Dict[str, Any]] = None
abort_reason: Optional[str] = None abort_message: Optional[str] = None
def __post_init__(self): def __post_init__(self):
# FIXME: This is a hack to keep the same with the old code # FIXME: This is a hack to keep the same with the old code
......
...@@ -1445,7 +1445,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ...@@ -1445,7 +1445,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
evict_from_tree_cache(self.tree_cache, num_tokens) evict_from_tree_cache(self.tree_cache, num_tokens)
return self._is_available_size_sufficient(num_tokens) return self._is_available_size_sufficient(num_tokens)
def retract_decode(self, server_args: ServerArgs): def retract_decode(
self, server_args: ServerArgs
) -> Tuple[List[Req], float, List[Req]]:
"""Retract the decoding requests when there is not enough memory.""" """Retract the decoding requests when there is not enough memory."""
sorted_indices = list(range(len(self.reqs))) sorted_indices = list(range(len(self.reqs)))
...@@ -1513,7 +1515,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ...@@ -1513,7 +1515,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
) # avoid zero division ) # avoid zero division
new_estimate_ratio = min(1.0, new_estimate_ratio) new_estimate_ratio = min(1.0, new_estimate_ratio)
return retracted_reqs, new_estimate_ratio return retracted_reqs, new_estimate_ratio, []
def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs): def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
req = self.reqs[idx] req = self.reqs[idx]
......
...@@ -1902,9 +1902,16 @@ class Scheduler( ...@@ -1902,9 +1902,16 @@ class Scheduler(
TEST_RETRACT and self.forward_ct % TEST_RETRACT_INTERVAL == 0 TEST_RETRACT and self.forward_ct % TEST_RETRACT_INTERVAL == 0
): ):
old_ratio = self.new_token_ratio old_ratio = self.new_token_ratio
retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args) retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode(
self.server_args
)
self.num_retracted_reqs = len(retracted_reqs) self.num_retracted_reqs = len(retracted_reqs)
self.new_token_ratio = new_token_ratio self.new_token_ratio = new_token_ratio
for req in reqs_to_abort:
abort_reason: FINISH_ABORT = req.to_finish
self.send_to_tokenizer.send_output(
AbortReq(abort_message=abort_reason.message, rid=req.rid), req
)
logger.info( logger.info(
"KV cache pool is full. Retract requests. " "KV cache pool is full. Retract requests. "
......
...@@ -1781,7 +1781,7 @@ class TokenizerManager(TokenizerCommunicatorMixin): ...@@ -1781,7 +1781,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
state = self.rid_to_state[recv_obj.rid] state = self.rid_to_state[recv_obj.rid]
state.finished = True state.finished = True
abort_message = recv_obj.abort_reason or "Abort in waiting queue" abort_message = recv_obj.abort_message or "Abort in waiting queue"
finish_reason = { finish_reason = {
"type": "abort", "type": "abort",
"message": abort_message, "message": abort_message,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment