Unverified Commit 37066185 authored by Moritz Sanft's avatar Moritz Sanft Committed by GitHub
Browse files

[Frontend] Update OpenAI error response to upstream format (#22099)


Signed-off-by: default avatarMoritz Sanft <58110325+msanft@users.noreply.github.com>
parent cbc8457b
...@@ -121,8 +121,7 @@ def test_invalid_truncate_prompt_tokens_error(server: RemoteOpenAIServer, ...@@ -121,8 +121,7 @@ def test_invalid_truncate_prompt_tokens_error(server: RemoteOpenAIServer,
error = classification_response.json() error = classification_response.json()
assert classification_response.status_code == 400 assert classification_response.status_code == 400
assert error["object"] == "error" assert "truncate_prompt_tokens" in error["error"]["message"]
assert "truncate_prompt_tokens" in error["message"]
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
...@@ -137,7 +136,7 @@ def test_empty_input_error(server: RemoteOpenAIServer, model_name: str): ...@@ -137,7 +136,7 @@ def test_empty_input_error(server: RemoteOpenAIServer, model_name: str):
error = classification_response.json() error = classification_response.json()
assert classification_response.status_code == 400 assert classification_response.status_code == 400
assert error["object"] == "error" assert "error" in error
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
......
...@@ -160,8 +160,8 @@ async def test_serving_completion_resolver_not_found(mock_serving_setup, ...@@ -160,8 +160,8 @@ async def test_serving_completion_resolver_not_found(mock_serving_setup,
mock_engine.generate.assert_not_called() mock_engine.generate.assert_not_called()
assert isinstance(response, ErrorResponse) assert isinstance(response, ErrorResponse)
assert response.code == HTTPStatus.NOT_FOUND.value assert response.error.code == HTTPStatus.NOT_FOUND.value
assert non_existent_model in response.message assert non_existent_model in response.error.message
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -190,8 +190,8 @@ async def test_serving_completion_resolver_add_lora_fails( ...@@ -190,8 +190,8 @@ async def test_serving_completion_resolver_add_lora_fails(
# Assert the correct error response # Assert the correct error response
assert isinstance(response, ErrorResponse) assert isinstance(response, ErrorResponse)
assert response.code == HTTPStatus.BAD_REQUEST.value assert response.error.code == HTTPStatus.BAD_REQUEST.value
assert invalid_model in response.message assert invalid_model in response.error.message
@pytest.mark.asyncio @pytest.mark.asyncio
......
...@@ -66,8 +66,8 @@ async def test_load_lora_adapter_missing_fields(): ...@@ -66,8 +66,8 @@ async def test_load_lora_adapter_missing_fields():
request = LoadLoRAAdapterRequest(lora_name="", lora_path="") request = LoadLoRAAdapterRequest(lora_name="", lora_path="")
response = await serving_models.load_lora_adapter(request) response = await serving_models.load_lora_adapter(request)
assert isinstance(response, ErrorResponse) assert isinstance(response, ErrorResponse)
assert response.type == "InvalidUserInput" assert response.error.type == "InvalidUserInput"
assert response.code == HTTPStatus.BAD_REQUEST assert response.error.code == HTTPStatus.BAD_REQUEST
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -84,8 +84,8 @@ async def test_load_lora_adapter_duplicate(): ...@@ -84,8 +84,8 @@ async def test_load_lora_adapter_duplicate():
lora_path="/path/to/adapter1") lora_path="/path/to/adapter1")
response = await serving_models.load_lora_adapter(request) response = await serving_models.load_lora_adapter(request)
assert isinstance(response, ErrorResponse) assert isinstance(response, ErrorResponse)
assert response.type == "InvalidUserInput" assert response.error.type == "InvalidUserInput"
assert response.code == HTTPStatus.BAD_REQUEST assert response.error.code == HTTPStatus.BAD_REQUEST
assert len(serving_models.lora_requests) == 1 assert len(serving_models.lora_requests) == 1
...@@ -110,8 +110,8 @@ async def test_unload_lora_adapter_missing_fields(): ...@@ -110,8 +110,8 @@ async def test_unload_lora_adapter_missing_fields():
request = UnloadLoRAAdapterRequest(lora_name="", lora_int_id=None) request = UnloadLoRAAdapterRequest(lora_name="", lora_int_id=None)
response = await serving_models.unload_lora_adapter(request) response = await serving_models.unload_lora_adapter(request)
assert isinstance(response, ErrorResponse) assert isinstance(response, ErrorResponse)
assert response.type == "InvalidUserInput" assert response.error.type == "InvalidUserInput"
assert response.code == HTTPStatus.BAD_REQUEST assert response.error.code == HTTPStatus.BAD_REQUEST
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -120,5 +120,5 @@ async def test_unload_lora_adapter_not_found(): ...@@ -120,5 +120,5 @@ async def test_unload_lora_adapter_not_found():
request = UnloadLoRAAdapterRequest(lora_name="nonexistent_adapter") request = UnloadLoRAAdapterRequest(lora_name="nonexistent_adapter")
response = await serving_models.unload_lora_adapter(request) response = await serving_models.unload_lora_adapter(request)
assert isinstance(response, ErrorResponse) assert isinstance(response, ErrorResponse)
assert response.type == "NotFoundError" assert response.error.type == "NotFoundError"
assert response.code == HTTPStatus.NOT_FOUND assert response.error.code == HTTPStatus.NOT_FOUND
...@@ -116,8 +116,10 @@ async def test_non_asr_model(winning_call): ...@@ -116,8 +116,10 @@ async def test_non_asr_model(winning_call):
file=winning_call, file=winning_call,
language="en", language="en",
temperature=0.0) temperature=0.0)
assert res.code == 400 and not res.text err = res.error
assert res.message == "The model does not support Transcriptions API" assert err["code"] == 400 and not res.text
assert err[
"message"] == "The model does not support Transcriptions API"
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -133,12 +135,15 @@ async def test_completion_endpoints(): ...@@ -133,12 +135,15 @@ async def test_completion_endpoints():
"role": "system", "role": "system",
"content": "You are a helpful assistant." "content": "You are a helpful assistant."
}]) }])
assert res.code == 400 err = res.error
assert res.message == "The model does not support Chat Completions API" assert err["code"] == 400
assert err[
"message"] == "The model does not support Chat Completions API"
res = await client.completions.create(model=model_name, prompt="Hello") res = await client.completions.create(model=model_name, prompt="Hello")
assert res.code == 400 err = res.error
assert res.message == "The model does not support Completions API" assert err["code"] == 400
assert err["message"] == "The model does not support Completions API"
@pytest.mark.asyncio @pytest.mark.asyncio
......
...@@ -73,8 +73,9 @@ async def test_non_asr_model(foscolo): ...@@ -73,8 +73,9 @@ async def test_non_asr_model(foscolo):
res = await client.audio.translations.create(model=model_name, res = await client.audio.translations.create(model=model_name,
file=foscolo, file=foscolo,
temperature=0.0) temperature=0.0)
assert res.code == 400 and not res.text err = res.error
assert res.message == "The model does not support Translations API" assert err["code"] == 400 and not res.text
assert err["message"] == "The model does not support Translations API"
@pytest.mark.asyncio @pytest.mark.asyncio
......
...@@ -62,7 +62,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ...@@ -62,7 +62,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DetokenizeRequest, DetokenizeRequest,
DetokenizeResponse, DetokenizeResponse,
EmbeddingRequest, EmbeddingRequest,
EmbeddingResponse, ErrorResponse, EmbeddingResponse, ErrorInfo,
ErrorResponse,
LoadLoRAAdapterRequest, LoadLoRAAdapterRequest,
PoolingRequest, PoolingResponse, PoolingRequest, PoolingResponse,
RerankRequest, RerankResponse, RerankRequest, RerankResponse,
...@@ -506,7 +507,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request): ...@@ -506,7 +507,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(), return JSONResponse(content=generator.model_dump(),
status_code=generator.code) status_code=generator.error.code)
elif isinstance(generator, TokenizeResponse): elif isinstance(generator, TokenizeResponse):
return JSONResponse(content=generator.model_dump()) return JSONResponse(content=generator.model_dump())
...@@ -540,7 +541,7 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request): ...@@ -540,7 +541,7 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request):
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(), return JSONResponse(content=generator.model_dump(),
status_code=generator.code) status_code=generator.error.code)
elif isinstance(generator, DetokenizeResponse): elif isinstance(generator, DetokenizeResponse):
return JSONResponse(content=generator.model_dump()) return JSONResponse(content=generator.model_dump())
...@@ -556,7 +557,7 @@ def maybe_register_tokenizer_info_endpoint(args): ...@@ -556,7 +557,7 @@ def maybe_register_tokenizer_info_endpoint(args):
"""Get comprehensive tokenizer information.""" """Get comprehensive tokenizer information."""
result = await tokenization(raw_request).get_tokenizer_info() result = await tokenization(raw_request).get_tokenizer_info()
return JSONResponse(content=result.model_dump(), return JSONResponse(content=result.model_dump(),
status_code=result.code if isinstance( status_code=result.error.code if isinstance(
result, ErrorResponse) else 200) result, ErrorResponse) else 200)
...@@ -603,7 +604,7 @@ async def create_responses(request: ResponsesRequest, raw_request: Request): ...@@ -603,7 +604,7 @@ async def create_responses(request: ResponsesRequest, raw_request: Request):
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(), return JSONResponse(content=generator.model_dump(),
status_code=generator.code) status_code=generator.error.code)
elif isinstance(generator, ResponsesResponse): elif isinstance(generator, ResponsesResponse):
return JSONResponse(content=generator.model_dump()) return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream") return StreamingResponse(content=generator, media_type="text/event-stream")
...@@ -620,7 +621,7 @@ async def retrieve_responses(response_id: str, raw_request: Request): ...@@ -620,7 +621,7 @@ async def retrieve_responses(response_id: str, raw_request: Request):
if isinstance(response, ErrorResponse): if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(), return JSONResponse(content=response.model_dump(),
status_code=response.code) status_code=response.error.code)
return JSONResponse(content=response.model_dump()) return JSONResponse(content=response.model_dump())
...@@ -635,7 +636,7 @@ async def cancel_responses(response_id: str, raw_request: Request): ...@@ -635,7 +636,7 @@ async def cancel_responses(response_id: str, raw_request: Request):
if isinstance(response, ErrorResponse): if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(), return JSONResponse(content=response.model_dump(),
status_code=response.code) status_code=response.error.code)
return JSONResponse(content=response.model_dump()) return JSONResponse(content=response.model_dump())
...@@ -670,7 +671,7 @@ async def create_chat_completion(request: ChatCompletionRequest, ...@@ -670,7 +671,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(), return JSONResponse(content=generator.model_dump(),
status_code=generator.code) status_code=generator.error.code)
elif isinstance(generator, ChatCompletionResponse): elif isinstance(generator, ChatCompletionResponse):
return JSONResponse(content=generator.model_dump()) return JSONResponse(content=generator.model_dump())
...@@ -715,7 +716,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request): ...@@ -715,7 +716,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(), return JSONResponse(content=generator.model_dump(),
status_code=generator.code) status_code=generator.error.code)
elif isinstance(generator, CompletionResponse): elif isinstance(generator, CompletionResponse):
return JSONResponse(content=generator.model_dump()) return JSONResponse(content=generator.model_dump())
...@@ -744,7 +745,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): ...@@ -744,7 +745,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(), return JSONResponse(content=generator.model_dump(),
status_code=generator.code) status_code=generator.error.code)
elif isinstance(generator, EmbeddingResponse): elif isinstance(generator, EmbeddingResponse):
return JSONResponse(content=generator.model_dump()) return JSONResponse(content=generator.model_dump())
...@@ -772,7 +773,7 @@ async def create_pooling(request: PoolingRequest, raw_request: Request): ...@@ -772,7 +773,7 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
generator = await handler.create_pooling(request, raw_request) generator = await handler.create_pooling(request, raw_request)
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(), return JSONResponse(content=generator.model_dump(),
status_code=generator.code) status_code=generator.error.code)
elif isinstance(generator, PoolingResponse): elif isinstance(generator, PoolingResponse):
return JSONResponse(content=generator.model_dump()) return JSONResponse(content=generator.model_dump())
...@@ -792,7 +793,7 @@ async def create_classify(request: ClassificationRequest, ...@@ -792,7 +793,7 @@ async def create_classify(request: ClassificationRequest,
generator = await handler.create_classify(request, raw_request) generator = await handler.create_classify(request, raw_request)
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(), return JSONResponse(content=generator.model_dump(),
status_code=generator.code) status_code=generator.error.code)
elif isinstance(generator, ClassificationResponse): elif isinstance(generator, ClassificationResponse):
return JSONResponse(content=generator.model_dump()) return JSONResponse(content=generator.model_dump())
...@@ -821,7 +822,7 @@ async def create_score(request: ScoreRequest, raw_request: Request): ...@@ -821,7 +822,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):
generator = await handler.create_score(request, raw_request) generator = await handler.create_score(request, raw_request)
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(), return JSONResponse(content=generator.model_dump(),
status_code=generator.code) status_code=generator.error.code)
elif isinstance(generator, ScoreResponse): elif isinstance(generator, ScoreResponse):
return JSONResponse(content=generator.model_dump()) return JSONResponse(content=generator.model_dump())
...@@ -881,7 +882,7 @@ async def create_transcriptions(raw_request: Request, ...@@ -881,7 +882,7 @@ async def create_transcriptions(raw_request: Request,
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(), return JSONResponse(content=generator.model_dump(),
status_code=generator.code) status_code=generator.error.code)
elif isinstance(generator, TranscriptionResponse): elif isinstance(generator, TranscriptionResponse):
return JSONResponse(content=generator.model_dump()) return JSONResponse(content=generator.model_dump())
...@@ -922,7 +923,7 @@ async def create_translations(request: Annotated[TranslationRequest, ...@@ -922,7 +923,7 @@ async def create_translations(request: Annotated[TranslationRequest,
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(), return JSONResponse(content=generator.model_dump(),
status_code=generator.code) status_code=generator.error.code)
elif isinstance(generator, TranslationResponse): elif isinstance(generator, TranslationResponse):
return JSONResponse(content=generator.model_dump()) return JSONResponse(content=generator.model_dump())
...@@ -950,7 +951,7 @@ async def do_rerank(request: RerankRequest, raw_request: Request): ...@@ -950,7 +951,7 @@ async def do_rerank(request: RerankRequest, raw_request: Request):
generator = await handler.do_rerank(request, raw_request) generator = await handler.do_rerank(request, raw_request)
if isinstance(generator, ErrorResponse): if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(), return JSONResponse(content=generator.model_dump(),
status_code=generator.code) status_code=generator.error.code)
elif isinstance(generator, RerankResponse): elif isinstance(generator, RerankResponse):
return JSONResponse(content=generator.model_dump()) return JSONResponse(content=generator.model_dump())
...@@ -1175,7 +1176,7 @@ async def invocations(raw_request: Request): ...@@ -1175,7 +1176,7 @@ async def invocations(raw_request: Request):
msg = ("Cannot find suitable handler for request. " msg = ("Cannot find suitable handler for request. "
f"Expected one of: {type_names}") f"Expected one of: {type_names}")
res = base(raw_request).create_error_response(message=msg) res = base(raw_request).create_error_response(message=msg)
return JSONResponse(content=res.model_dump(), status_code=res.code) return JSONResponse(content=res.model_dump(), status_code=res.error.code)
if envs.VLLM_TORCH_PROFILER_DIR: if envs.VLLM_TORCH_PROFILER_DIR:
...@@ -1211,7 +1212,7 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: ...@@ -1211,7 +1212,7 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
response = await handler.load_lora_adapter(request) response = await handler.load_lora_adapter(request)
if isinstance(response, ErrorResponse): if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(), return JSONResponse(content=response.model_dump(),
status_code=response.code) status_code=response.error.code)
return Response(status_code=200, content=response) return Response(status_code=200, content=response)
...@@ -1223,7 +1224,7 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: ...@@ -1223,7 +1224,7 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
response = await handler.unload_lora_adapter(request) response = await handler.unload_lora_adapter(request)
if isinstance(response, ErrorResponse): if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(), return JSONResponse(content=response.model_dump(),
status_code=response.code) status_code=response.error.code)
return Response(status_code=200, content=response) return Response(status_code=200, content=response)
...@@ -1502,9 +1503,10 @@ def build_app(args: Namespace) -> FastAPI: ...@@ -1502,9 +1503,10 @@ def build_app(args: Namespace) -> FastAPI:
@app.exception_handler(HTTPException) @app.exception_handler(HTTPException)
async def http_exception_handler(_: Request, exc: HTTPException): async def http_exception_handler(_: Request, exc: HTTPException):
err = ErrorResponse(message=exc.detail, err = ErrorResponse(
error=ErrorInfo(message=exc.detail,
type=HTTPStatus(exc.status_code).phrase, type=HTTPStatus(exc.status_code).phrase,
code=exc.status_code) code=exc.status_code))
return JSONResponse(err.model_dump(), status_code=exc.status_code) return JSONResponse(err.model_dump(), status_code=exc.status_code)
@app.exception_handler(RequestValidationError) @app.exception_handler(RequestValidationError)
...@@ -1518,9 +1520,9 @@ def build_app(args: Namespace) -> FastAPI: ...@@ -1518,9 +1520,9 @@ def build_app(args: Namespace) -> FastAPI:
else: else:
message = exc_str message = exc_str
err = ErrorResponse(message=message, err = ErrorResponse(error=ErrorInfo(message=message,
type=HTTPStatus.BAD_REQUEST.phrase, type=HTTPStatus.BAD_REQUEST.phrase,
code=HTTPStatus.BAD_REQUEST) code=HTTPStatus.BAD_REQUEST))
return JSONResponse(err.model_dump(), return JSONResponse(err.model_dump(),
status_code=HTTPStatus.BAD_REQUEST) status_code=HTTPStatus.BAD_REQUEST)
......
...@@ -78,14 +78,17 @@ class OpenAIBaseModel(BaseModel): ...@@ -78,14 +78,17 @@ class OpenAIBaseModel(BaseModel):
return result return result
class ErrorResponse(OpenAIBaseModel): class ErrorInfo(OpenAIBaseModel):
object: str = "error"
message: str message: str
type: str type: str
param: Optional[str] = None param: Optional[str] = None
code: int code: int
class ErrorResponse(OpenAIBaseModel):
error: ErrorInfo
class ModelPermission(OpenAIBaseModel): class ModelPermission(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}") id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
object: str = "model_permission" object: str = "model_permission"
......
...@@ -302,7 +302,7 @@ async def run_request(serving_engine_func: Callable, ...@@ -302,7 +302,7 @@ async def run_request(serving_engine_func: Callable,
id=f"vllm-{random_uuid()}", id=f"vllm-{random_uuid()}",
custom_id=request.custom_id, custom_id=request.custom_id,
response=BatchResponseData( response=BatchResponseData(
status_code=response.code, status_code=response.error.code,
request_id=f"vllm-batch-{random_uuid()}"), request_id=f"vllm-batch-{random_uuid()}"),
error=response, error=response,
) )
......
...@@ -47,10 +47,10 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ...@@ -47,10 +47,10 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
EmbeddingChatRequest, EmbeddingChatRequest,
EmbeddingCompletionRequest, EmbeddingCompletionRequest,
EmbeddingRequest, EmbeddingRequest,
EmbeddingResponse, ErrorResponse, EmbeddingResponse, ErrorInfo,
PoolingResponse, RerankRequest, ErrorResponse, PoolingResponse,
ResponsesRequest, ScoreRequest, RerankRequest, ResponsesRequest,
ScoreResponse, ScoreRequest, ScoreResponse,
TokenizeChatRequest, TokenizeChatRequest,
TokenizeCompletionRequest, TokenizeCompletionRequest,
TokenizeResponse, TokenizeResponse,
...@@ -412,21 +412,18 @@ class OpenAIServing: ...@@ -412,21 +412,18 @@ class OpenAIServing:
message: str, message: str,
err_type: str = "BadRequestError", err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse: status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
return ErrorResponse(message=message, return ErrorResponse(error=ErrorInfo(
type=err_type, message=message, type=err_type, code=status_code.value))
code=status_code.value)
def create_streaming_error_response( def create_streaming_error_response(
self, self,
message: str, message: str,
err_type: str = "BadRequestError", err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str: status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
json_str = json.dumps({ json_str = json.dumps(
"error":
self.create_error_response(message=message, self.create_error_response(message=message,
err_type=err_type, err_type=err_type,
status_code=status_code).model_dump() status_code=status_code).model_dump())
})
return json_str return json_str
async def _check_model( async def _check_model(
...@@ -445,7 +442,7 @@ class OpenAIServing: ...@@ -445,7 +442,7 @@ class OpenAIServing:
if isinstance(load_result, LoRARequest): if isinstance(load_result, LoRARequest):
return None return None
if isinstance(load_result, ErrorResponse) and \ if isinstance(load_result, ErrorResponse) and \
load_result.code == HTTPStatus.BAD_REQUEST.value: load_result.error.code == HTTPStatus.BAD_REQUEST.value:
error_response = load_result error_response = load_result
return error_response or self.create_error_response( return error_response or self.create_error_response(
......
...@@ -9,7 +9,7 @@ from typing import Optional, Union ...@@ -9,7 +9,7 @@ from typing import Optional, Union
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.protocol import (ErrorResponse, from vllm.entrypoints.openai.protocol import (ErrorInfo, ErrorResponse,
LoadLoRAAdapterRequest, LoadLoRAAdapterRequest,
ModelCard, ModelList, ModelCard, ModelList,
ModelPermission, ModelPermission,
...@@ -82,7 +82,7 @@ class OpenAIServingModels: ...@@ -82,7 +82,7 @@ class OpenAIServingModels:
load_result = await self.load_lora_adapter( load_result = await self.load_lora_adapter(
request=load_request, base_model_name=lora.base_model_name) request=load_request, base_model_name=lora.base_model_name)
if isinstance(load_result, ErrorResponse): if isinstance(load_result, ErrorResponse):
raise ValueError(load_result.message) raise ValueError(load_result.error.message)
def is_base_model(self, model_name) -> bool: def is_base_model(self, model_name) -> bool:
return any(model.name == model_name for model in self.base_model_paths) return any(model.name == model_name for model in self.base_model_paths)
...@@ -284,6 +284,5 @@ def create_error_response( ...@@ -284,6 +284,5 @@ def create_error_response(
message: str, message: str,
err_type: str = "BadRequestError", err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse: status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
return ErrorResponse(message=message, return ErrorResponse(error=ErrorInfo(
type=err_type, message=message, type=err_type, code=status_code.value))
code=status_code.value)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment