Unverified Commit 98b44e9e authored by datdo-msft's avatar datdo-msft Committed by GitHub
Browse files

[PD] Propagate internal server errors from aborted requests to clients instead...

[PD] Propagate internal server errors from aborted requests to clients instead of blindly returning 200's (#8936)
parent 6805f6da
...@@ -259,7 +259,7 @@ class DecodePreallocQueue: ...@@ -259,7 +259,7 @@ class DecodePreallocQueue:
if len(req.origin_input_ids) > self.max_total_num_tokens: if len(req.origin_input_ids) > self.max_total_num_tokens:
message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}" message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}"
logger.error(message) logger.error(message)
prepare_abort(req, message) prepare_abort(req, message, status_code=HTTPStatus.BAD_REQUEST)
self.scheduler.stream_output([req], req.return_logprob) self.scheduler.stream_output([req], req.return_logprob)
return True return True
return False return False
......
...@@ -178,7 +178,7 @@ class PrefillBootstrapQueue: ...@@ -178,7 +178,7 @@ class PrefillBootstrapQueue:
if len(req.origin_input_ids) > self.max_total_num_tokens: if len(req.origin_input_ids) > self.max_total_num_tokens:
message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}" message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}"
logger.error(message) logger.error(message)
prepare_abort(req, message) prepare_abort(req, message, status_code=HTTPStatus.BAD_REQUEST)
self.scheduler.stream_output([req], req.return_logprob) self.scheduler.stream_output([req], req.return_logprob)
return True return True
return False return False
......
...@@ -1141,7 +1141,7 @@ class Scheduler( ...@@ -1141,7 +1141,7 @@ class Scheduler(
f"boostrap room id. {req.rid=}" f"boostrap room id. {req.rid=}"
) )
logger.error(error_msg) logger.error(error_msg)
prepare_abort(req, error_msg) prepare_abort(req, error_msg, status_code=HTTPStatus.BAD_REQUEST)
self.stream_output([req], req.return_logprob) self.stream_output([req], req.return_logprob)
return return
......
...@@ -782,15 +782,17 @@ class TokenizerManager: ...@@ -782,15 +782,17 @@ class TokenizerManager:
): ):
raise ValueError(finish_reason["message"]) raise ValueError(finish_reason["message"])
if ( if finish_reason.get("type") == "abort" and finish_reason.get(
finish_reason.get("type") == "abort" "status_code"
and finish_reason.get("status_code") ) in (
== HTTPStatus.SERVICE_UNAVAILABLE HTTPStatus.SERVICE_UNAVAILABLE,
HTTPStatus.INTERNAL_SERVER_ERROR,
): ):
# This is an abort request initiated by scheduler. # This is an abort request initiated by scheduler.
# Delete the key to prevent resending abort request to the scheduler and # Delete the key to prevent resending abort request to the scheduler and
# to ensure aborted request state is cleaned up. # to ensure aborted request state is cleaned up.
del self.rid_to_state[state.obj.rid] if state.obj.rid in self.rid_to_state:
del self.rid_to_state[state.obj.rid]
# Mark ongoing LoRA request as finished. # Mark ongoing LoRA request as finished.
if self.server_args.enable_lora and state.obj.lora_path: if self.server_args.enable_lora and state.obj.lora_path:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment