Unverified Commit 4d4cdb3f authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

Frontend: better error message handling for FINISH_ABORT in scheduler.py (#2956)

parent 2bd18e2d
...@@ -115,14 +115,18 @@ class FINISH_LENGTH(BaseFinishReason): ...@@ -115,14 +115,18 @@ class FINISH_LENGTH(BaseFinishReason):
class FINISH_ABORT(BaseFinishReason): class FINISH_ABORT(BaseFinishReason):
def __init__(self, message="Unknown error"): def __init__(self, message="Unknown error", status_code=None, err_type=None):
super().__init__(is_error=True) super().__init__(is_error=True)
self.message = message self.message = message
self.status_code = status_code
self.err_type = err_type
def to_json(self): def to_json(self):
return { return {
"type": "abort", "type": "abort",
"message": self.message, "message": self.message,
"status_code": self.status_code,
"err_type": self.err_type,
} }
......
...@@ -23,6 +23,7 @@ import warnings ...@@ -23,6 +23,7 @@ import warnings
from collections import deque from collections import deque
from concurrent import futures from concurrent import futures
from dataclasses import dataclass from dataclasses import dataclass
from http import HTTPStatus
from types import SimpleNamespace from types import SimpleNamespace
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
...@@ -672,15 +673,16 @@ class Scheduler: ...@@ -672,15 +673,16 @@ class Scheduler:
req.extend_image_inputs(image_inputs) req.extend_image_inputs(image_inputs)
if len(req.origin_input_ids) >= self.max_req_input_len: if len(req.origin_input_ids) >= self.max_req_input_len:
logger.error( error_msg = (
"Multimodal prompt is too long after expanding multimodal tokens. " "Multimodal prompt is too long after expanding multimodal tokens. "
f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}. " f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
) )
logger.error(error_msg)
req.origin_input_ids = [0] req.origin_input_ids = [0]
req.image_inputs = None req.image_inputs = None
req.sampling_params.max_new_tokens = 0 req.sampling_params.max_new_tokens = 0
req.finished_reason = FINISH_ABORT( req.finished_reason = FINISH_ABORT(
"Multimodal prompt is too long. Check server logs for details." error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
) )
self.waiting_queue.append(req) self.waiting_queue.append(req)
return return
......
...@@ -25,6 +25,7 @@ import threading ...@@ -25,6 +25,7 @@ import threading
import time import time
import uuid import uuid
from datetime import datetime from datetime import datetime
from http import HTTPStatus
from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
import fastapi import fastapi
...@@ -384,6 +385,16 @@ class TokenizerManager: ...@@ -384,6 +385,16 @@ class TokenizerManager:
msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}" msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}"
logger.info(msg) logger.info(msg)
del self.rid_to_state[obj.rid] del self.rid_to_state[obj.rid]
# Check if this was an abort/error created by scheduler
if isinstance(out["meta_info"].get("finish_reason"), dict):
finish_reason = out["meta_info"]["finish_reason"]
if (
finish_reason.get("type") == "abort"
and finish_reason.get("status_code") == HTTPStatus.BAD_REQUEST
):
raise ValueError(finish_reason["message"])
yield out yield out
break break
......
import logging import logging
from http import HTTPStatus
from typing import Optional from typing import Optional
from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
...@@ -35,7 +36,9 @@ def validate_input_length( ...@@ -35,7 +36,9 @@ def validate_input_length(
f"Use a shorter input or enable --allow-auto-truncate." f"Use a shorter input or enable --allow-auto-truncate."
) )
logger.error(error_msg) logger.error(error_msg)
req.finished_reason = FINISH_ABORT(error_msg) req.finished_reason = FINISH_ABORT(
error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
)
return error_msg return error_msg
return None return None
...@@ -392,34 +392,33 @@ class TestQWen2VLServerContextLengthIssue(unittest.TestCase): ...@@ -392,34 +392,33 @@ class TestQWen2VLServerContextLengthIssue(unittest.TestCase):
def test_chat_completion(self): def test_chat_completion(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url) client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.chat.completions.create( with self.assertRaises(openai.BadRequestError) as cm:
model="default", client.chat.completions.create(
messages=[ model="default",
{ messages=[
"role": "user", {
"content": [ "role": "user",
{ "content": [
"type": "image_url", {
"image_url": { "type": "image_url",
"url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true" "image_url": {
"url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
},
}, },
}, {
{ "type": "text",
"type": "text", "text": "Give a lengthy description of this picture",
"text": "Give a lengthy description of this picture", },
}, ],
], },
}, ],
], temperature=0,
temperature=0, )
)
assert response.choices[0].finish_reason == "abort" self.assertIn(
assert response.id "Multimodal prompt is too long after expanding multimodal tokens.",
assert response.created str(cm.exception),
assert response.usage.prompt_tokens > 0 )
assert response.usage.completion_tokens > 0
assert response.usage.total_tokens > 0
class TestMllamaServer(TestOpenAIVisionServer): class TestMllamaServer(TestOpenAIVisionServer):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment