"tools/git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "49a2bc85952b3fcb0651d78cfea0c307ce8d65c6"
Unverified Commit 14cbe42f authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Refactor abortion in event loop (#12312)

parent 685c0645
...@@ -505,16 +505,15 @@ class Req: ...@@ -505,16 +505,15 @@ class Req:
# Check finish # Check finish
self.tokenizer = None self.tokenizer = None
self.finished_reason = None self.finished_reason: Optional[BaseFinishReason] = None
# finished position (in output_ids), used when checking stop conditions with speculative decoding # finished position (in output_ids), used when checking stop conditions with speculative decoding
self.finished_len = None self.finished_len = None
# Whether this request has finished output # Whether this request has finished output
self.finished_output = None self.finished_output = None
# If we want to abort the request in the middle of the event loop, set this to true # If we want to abort the request in the middle of the event loop,
# set to_finish instead of directly setting finished_reason.
# Note: We should never set finished_reason in the middle, the req will get filtered and never respond # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
self.to_abort = False self.to_finish: Optional[BaseFinishReason] = None
# This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
self.to_abort_message: str = None
self.stream = stream self.stream = stream
self.eos_token_ids = eos_token_ids self.eos_token_ids = eos_token_ids
self.vocab_size = vocab_size self.vocab_size = vocab_size
...@@ -866,10 +865,9 @@ class Req: ...@@ -866,10 +865,9 @@ class Req:
if self.finished(): if self.finished():
return return
if self.to_abort: if self.to_finish:
self.finished_reason = FINISH_ABORT( self.finished_reason = self.to_finish
message=self.to_abort_message, self.to_finish = None
)
return return
if len(self.output_ids) >= self.sampling_params.max_new_tokens: if len(self.output_ids) >= self.sampling_params.max_new_tokens:
...@@ -945,7 +943,7 @@ class Req: ...@@ -945,7 +943,7 @@ class Req:
self.grammar = None self.grammar = None
self.origin_input_ids = [0] # set it to one token to skip the long prefill self.origin_input_ids = [0] # set it to one token to skip the long prefill
self.return_logprob = False self.return_logprob = False
self.finished_reason = FINISH_ABORT( self.to_finish = FINISH_ABORT(
error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError" error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
) )
...@@ -1509,7 +1507,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ...@@ -1509,7 +1507,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
) # avoid zero division ) # avoid zero division
new_estimate_ratio = min(1.0, new_estimate_ratio) new_estimate_ratio = min(1.0, new_estimate_ratio)
return retracted_reqs, new_estimate_ratio, [] return retracted_reqs, new_estimate_ratio
def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs): def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
req = self.reqs[idx] req = self.reqs[idx]
......
...@@ -1817,20 +1817,13 @@ class Scheduler( ...@@ -1817,20 +1817,13 @@ class Scheduler(
TEST_RETRACT and self.forward_ct % TEST_RETRACT_INTERVAL == 0 TEST_RETRACT and self.forward_ct % TEST_RETRACT_INTERVAL == 0
): ):
old_ratio = self.new_token_ratio old_ratio = self.new_token_ratio
retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode( retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
self.server_args
)
self.num_retracted_reqs = len(retracted_reqs) self.num_retracted_reqs = len(retracted_reqs)
self.new_token_ratio = new_token_ratio self.new_token_ratio = new_token_ratio
for req in reqs_to_abort:
self.send_to_tokenizer.send_output(
AbortReq(abort_reason=req.to_abort_message, rid=req.rid), req
)
logger.info( logger.info(
"KV cache pool is full. Retract requests. " "KV cache pool is full. Retract requests. "
f"#retracted_reqs: {len(retracted_reqs)}, " f"#retracted_reqs: {len(retracted_reqs)}, "
f"#aborted_retracted_reqs: {len(reqs_to_abort)}, "
f"#new_token_ratio: {old_ratio:.4f} -> {new_token_ratio:.4f}" f"#new_token_ratio: {old_ratio:.4f} -> {new_token_ratio:.4f}"
) )
...@@ -2534,11 +2527,11 @@ class Scheduler( ...@@ -2534,11 +2527,11 @@ class Scheduler(
if not req.finished() and ( if not req.finished() and (
recv_req.abort_all or req.rid.startswith(recv_req.rid) recv_req.abort_all or req.rid.startswith(recv_req.rid)
): ):
# Abort method 3: set `to_abort=True` # Abort method 3: set `to_finish`
# The request will still run one decode forward pass. # The request will still run one decode forward pass.
# Then we reuse all existing code to clean up the KV cache allocation. # Then we reuse all existing code to clean up the KV cache allocation.
logger.debug(f"Abort running request. {req.rid=}") logger.debug(f"Abort running request. {req.rid=}")
req.to_abort = True req.to_finish = FINISH_ABORT()
def _pause_engine(self) -> Tuple[List[Req], int]: def _pause_engine(self) -> Tuple[List[Req], int]:
raise NotImplementedError() raise NotImplementedError()
......
...@@ -789,7 +789,7 @@ class SchedulerOutputProcessorMixin: ...@@ -789,7 +789,7 @@ class SchedulerOutputProcessorMixin:
continue continue
# Multimodal partial stream chunks break the detokenizer, so drop aborted requests here. # Multimodal partial stream chunks break the detokenizer, so drop aborted requests here.
if self.model_config.is_multimodal_gen and req.to_abort: if self.model_config.is_multimodal_gen and req.to_finish:
continue continue
if req.finished(): if req.finished():
......
...@@ -15,11 +15,11 @@ import uuid ...@@ -15,11 +15,11 @@ import uuid
from typing import Dict, Optional from typing import Dict, Optional
from sglang.srt.managers.io_struct import TokenizedGenerateReqInput from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
from sglang.srt.managers.schedule_batch import Req from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
class SessionReqNode: class SessionReqNode:
def __init__(self, req, parent=None, childs=None): def __init__(self, req: Req, parent=None, childs=None):
self.req = req self.req = req
self.parent = parent self.parent = parent
if parent is not None: if parent is not None:
...@@ -36,12 +36,12 @@ class SessionReqNode: ...@@ -36,12 +36,12 @@ class SessionReqNode:
req_node.clear(req_dict) req_node.clear(req_dict)
if self.req.finished_reason is None: if self.req.finished_reason is None:
self.req.to_abort = True self.req.to_finish = FINISH_ABORT()
del req_dict[self.req.rid] del req_dict[self.req.rid]
def abort(self): def abort(self):
if self.req.finished_reason is None: if self.req.finished_reason is None:
self.req.to_abort = True self.req.to_finish = FINISH_ABORT()
def __str__(self): def __str__(self):
return self._str_helper(self.req.rid) return self._str_helper(self.req.rid)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment