Unverified Commit e6b7053b authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Fix a bug in abort & Improve docstrings for abort (#6931)

parent 5f91c825
...@@ -2041,10 +2041,23 @@ class Scheduler( ...@@ -2041,10 +2041,23 @@ class Scheduler(
# Sort in reverse order to avoid index issues when deleting # Sort in reverse order to avoid index issues when deleting
for i in reversed(to_del): for i in reversed(to_del):
# Abort method 1: directly pop from the queue
# This only works for requests that have not started anything.
# We still need to send something back to TokenizerManager to clean up the state.
req = self.waiting_queue.pop(i) req = self.waiting_queue.pop(i)
self.send_to_tokenizer.send_pyobj(AbortReq(req.rid)) self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
logger.debug(f"Abort queued request. {req.rid=}") logger.debug(f"Abort queued request. {req.rid=}")
# Delete the requests in the grammar queue
for req in self.grammar_queue:
# Abort method 2: call `set_finish_with_abort`
# The request will still run one prefill forward pass.
# In this case, we change the input_ids to be only one token to make this prefill cheap.
if req.rid.startswith(recv_req.rid):
logger.debug(f"Abort grammar queue request. {req.rid=}")
req.grammar.cancel()
req.set_finish_with_abort("Aborted by AbortReq.")
# Delete requests in the running batch # Delete requests in the running batch
if self.cur_batch is self.running_batch or self.cur_batch is None: if self.cur_batch is self.running_batch or self.cur_batch is None:
reqs = self.running_batch.reqs reqs = self.running_batch.reqs
...@@ -2053,17 +2066,12 @@ class Scheduler( ...@@ -2053,17 +2066,12 @@ class Scheduler(
for req in reqs: for req in reqs:
if req.rid.startswith(recv_req.rid) and not req.finished(): if req.rid.startswith(recv_req.rid) and not req.finished():
# Abort method 3: set `to_abort=True`
# The request will still run one decode forward pass.
# Then we reuse all existing code to clean up the KV cache allocation.
logger.debug(f"Abort running request. {req.rid=}") logger.debug(f"Abort running request. {req.rid=}")
# We must use to_abort because it is in a running batch
req.to_abort = True req.to_abort = True
# Delete the requests in the grammar queue
for req in self.grammar_queue:
if req.rid.startswith(recv_req.rid):
logger.debug(f"Abort grammar queue request. {req.rid=}")
req.grammar.cancel()
req.set_finish_with_abort("Aborted by AbortReq.")
def _pause_engine(self) -> Tuple[List[Req], int]: def _pause_engine(self) -> Tuple[List[Req], int]:
raise NotImplementedError() raise NotImplementedError()
......
...@@ -1419,7 +1419,7 @@ class TokenizerManager: ...@@ -1419,7 +1419,7 @@ class TokenizerManager:
asyncio.create_task(asyncio.to_thread(background_task)) asyncio.create_task(asyncio.to_thread(background_task))
def _handle_abort_req(self, recv_obj): def _handle_abort_req(self, recv_obj):
self.rid_to_state.pop(recv_obj.rid) self.rid_to_state.pop(recv_obj.rid, None)
def _handle_open_session_req_output(self, recv_obj): def _handle_open_session_req_output(self, recv_obj):
self.session_futures[recv_obj.session_id].set_result( self.session_futures[recv_obj.session_id].set_result(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment