Unverified commit 3bde1010, authored by Byron Hsu and committed by GitHub
Browse files

[PD] Abort request if transfer fails (#6504)

parent 75135580
...@@ -41,6 +41,7 @@ from sglang.srt.disaggregation.utils import ( ...@@ -41,6 +41,7 @@ from sglang.srt.disaggregation.utils import (
is_mla_backend, is_mla_backend,
kv_to_page_indices, kv_to_page_indices,
poll_and_all_reduce, poll_and_all_reduce,
prepare_abort,
) )
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
...@@ -178,7 +179,17 @@ class DecodePreallocQueue: ...@@ -178,7 +179,17 @@ class DecodePreallocQueue:
elif poll == KVPoll.WaitingForInput: elif poll == KVPoll.WaitingForInput:
decode_req.waiting_for_input = True decode_req.waiting_for_input = True
elif poll == KVPoll.Failed: elif poll == KVPoll.Failed:
raise Exception("Handshake failed") error_message = f"Decode handshake failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
try:
decode_req.kv_receiver.failure_exception()
except Exception as e:
error_message += f" with exception {e}"
logger.error(error_message)
prepare_abort(
decode_req.req,
error_message,
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
)
def pop_preallocated(self) -> List[DecodeRequest]: def pop_preallocated(self) -> List[DecodeRequest]:
"""Pop the preallocated requests from the pending queue (FIFO).""" """Pop the preallocated requests from the pending queue (FIFO)."""
...@@ -333,7 +344,24 @@ class DecodeTransferQueue: ...@@ -333,7 +344,24 @@ class DecodeTransferQueue:
indices_to_remove = set() indices_to_remove = set()
for i, (decode_req, poll) in enumerate(zip(self.queue, polls)): for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
if poll == KVPoll.Failed: if poll == KVPoll.Failed:
raise Exception("Transfer failed") error_message = f"Decode transfer failed for request rank={self.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
try:
decode_req.kv_receiver.failure_exception()
except Exception as e:
error_message += f" with exception {e}"
logger.error(error_message)
prepare_abort(
decode_req.req,
error_message,
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
)
self.scheduler.stream_output(
[decode_req.req], decode_req.req.return_logprob
)
# unlock the kv cache or it will have memory leak
self.tree_cache.cache_finished_req(decode_req.req)
indices_to_remove.add(i)
continue
elif poll == KVPoll.Success: elif poll == KVPoll.Success:
# pop and push it to waiting queue # pop and push it to waiting queue
idx = decode_req.metadata_buffer_index idx = decode_req.metadata_buffer_index
......
...@@ -496,6 +496,7 @@ class MooncakeKVSender(BaseKVSender): ...@@ -496,6 +496,7 @@ class MooncakeKVSender(BaseKVSender):
return self.kv_mgr.check_status(self.bootstrap_room) return self.kv_mgr.check_status(self.bootstrap_room)
def failure_exception(self): def failure_exception(self):
# TODO: raise a real exception
raise Exception("Fake KVSender Exception") raise Exception("Fake KVSender Exception")
...@@ -723,6 +724,7 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -723,6 +724,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
return self.kv_mgr.check_status(self.bootstrap_room) return self.kv_mgr.check_status(self.bootstrap_room)
def failure_exception(self): def failure_exception(self):
# TODO: raise a real exception
raise Exception("Fake KVReceiver Exception") raise Exception("Fake KVReceiver Exception")
......
...@@ -38,6 +38,7 @@ from sglang.srt.disaggregation.utils import ( ...@@ -38,6 +38,7 @@ from sglang.srt.disaggregation.utils import (
kv_to_page_indices, kv_to_page_indices,
kv_to_page_num, kv_to_page_num,
poll_and_all_reduce, poll_and_all_reduce,
prepare_abort,
) )
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
...@@ -157,7 +158,18 @@ class PrefillBootstrapQueue: ...@@ -157,7 +158,18 @@ class PrefillBootstrapQueue:
if poll == KVPoll.Bootstrapping: if poll == KVPoll.Bootstrapping:
continue continue
elif poll == KVPoll.Failed: elif poll == KVPoll.Failed:
raise Exception("Bootstrap failed") error_message = f"Prefill bootstrap failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}"
try:
req.disagg_kv_sender.failure_exception()
except Exception as e:
error_message += f" with exception {e}"
logger.error(error_message)
prepare_abort(
req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
)
self.scheduler.stream_output([req], req.return_logprob)
indices_to_remove.add(i)
continue
# KV.WaitingForInput # KV.WaitingForInput
num_kv_indices = len(req.origin_input_ids) num_kv_indices = len(req.origin_input_ids)
...@@ -335,7 +347,17 @@ class SchedulerDisaggregationPrefillMixin: ...@@ -335,7 +347,17 @@ class SchedulerDisaggregationPrefillMixin:
# FIXME: clean up req's data in transfer engine # FIXME: clean up req's data in transfer engine
done_reqs.append(req) done_reqs.append(req)
elif poll == KVPoll.Failed: elif poll == KVPoll.Failed:
raise Exception("Transferring failed") error_message = f"Prefill transfer failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}"
try:
req.disagg_kv_sender.failure_exception()
except Exception as e:
error_message += f" with exception {e}"
logger.warning(error_message)
self.tree_cache.cache_finished_req(req) # unlock the tree
prepare_abort(
req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
)
done_reqs.append(req)
for req in done_reqs: for req in done_reqs:
self.disagg_prefill_bootstrap_queue.req_to_metadata_buffer_idx_allocator.free( self.disagg_prefill_bootstrap_queue.req_to_metadata_buffer_idx_allocator.free(
......
...@@ -167,3 +167,18 @@ def is_mla_backend(target_kv_pool) -> bool: ...@@ -167,3 +167,18 @@ def is_mla_backend(target_kv_pool) -> bool:
from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
return isinstance(target_kv_pool, MLATokenToKVPool) return isinstance(target_kv_pool, MLATokenToKVPool)
def prepare_abort(req: Req, error_message: str, status_code=None):
    """Mark ``req`` as aborted and reset its input-logprob fields.

    ``req.finished_reason`` is set to
    ``FINISH_ABORT(error_message, status_code)``. If ``req.return_logprob``
    is truthy, each of the six ``input_*_logprobs_*`` attributes is replaced
    with a new empty list.

    This function only mutates ``req``; it does not stream anything itself.

    Args:
        req: The request to mark as aborted.
        error_message: Message forwarded to ``FINISH_ABORT``.
        status_code: Optional status code forwarded to ``FINISH_ABORT``.
    """
    # Function-scope import, kept local as in the original
    # (presumably to avoid a circular import -- TODO confirm).
    from sglang.srt.managers.schedule_batch import FINISH_ABORT

    req.finished_reason = FINISH_ABORT(error_message, status_code)

    if not req.return_logprob:
        return

    # Each attribute gets its own fresh list (no sharing between fields).
    logprob_fields = (
        "input_token_logprobs_val",
        "input_token_logprobs_idx",
        "input_top_logprobs_val",
        "input_top_logprobs_idx",
        "input_token_ids_logprobs_val",
        "input_token_ids_logprobs_idx",
    )
    for field_name in logprob_fields:
        setattr(req, field_name, [])
...@@ -50,6 +50,7 @@ from sglang.srt.disaggregation.utils import ( ...@@ -50,6 +50,7 @@ from sglang.srt.disaggregation.utils import (
DisaggregationMode, DisaggregationMode,
ReqToMetadataIdxAllocator, ReqToMetadataIdxAllocator,
TransferBackend, TransferBackend,
prepare_abort,
) )
from sglang.srt.distributed import get_pp_group, get_world_group from sglang.srt.distributed import get_pp_group, get_world_group
from sglang.srt.hf_transformers_utils import ( from sglang.srt.hf_transformers_utils import (
...@@ -935,6 +936,18 @@ class Scheduler( ...@@ -935,6 +936,18 @@ class Scheduler(
) )
req.tokenizer = self.tokenizer req.tokenizer = self.tokenizer
if self.disaggregation_mode != DisaggregationMode.NULL:
# Invalid request for disaggregated mode
if recv_req.bootstrap_room is None:
error_message = (
f"Invalid request: Disaggregated request received without "
f"boostrap room id. {req.rid=}"
)
logger.error(error_message)
prepare_abort(req, error_message)
self.stream_output([req], req.return_logprob)
return
if ( if (
recv_req.session_params is not None recv_req.session_params is not None
and recv_req.session_params.id is not None and recv_req.session_params.id is not None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment