Unverified Commit 37dfa600 authored by Vaibhav Jain's avatar Vaibhav Jain Committed by GitHub
Browse files

[Bugfix] Missing Content Type returns 500 Internal Server Error (#13193)

parent 1bc3b5e7
...@@ -156,3 +156,19 @@ async def test_request_cancellation(server: RemoteOpenAIServer): ...@@ -156,3 +156,19 @@ async def test_request_cancellation(server: RemoteOpenAIServer):
max_tokens=10) max_tokens=10)
assert len(response.choices) == 1 assert len(response.choices) == 1
@pytest.mark.asyncio
async def test_request_wrong_content_type(server: RemoteOpenAIServer):
chat_input = [{"role": "user", "content": "Write a long story"}]
client = server.get_async_client()
with pytest.raises(openai.APIStatusError):
await client.chat.completions.create(
messages=chat_input,
model=MODEL_NAME,
max_tokens=10000,
extra_headers={
"Content-Type": "application/x-www-form-urlencoded"
})
...@@ -19,7 +19,7 @@ from http import HTTPStatus ...@@ -19,7 +19,7 @@ from http import HTTPStatus
from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union
import uvloop import uvloop
from fastapi import APIRouter, FastAPI, HTTPException, Request from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse from fastapi.responses import JSONResponse, Response, StreamingResponse
...@@ -252,6 +252,15 @@ async def build_async_engine_client_from_engine_args( ...@@ -252,6 +252,15 @@ async def build_async_engine_client_from_engine_args(
multiprocess.mark_process_dead(engine_process.pid) multiprocess.mark_process_dead(engine_process.pid)
async def validate_json_request(raw_request: Request):
content_type = raw_request.headers.get("content-type", "").lower()
if content_type != "application/json":
raise HTTPException(
status_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
detail="Unsupported Media Type: Only 'application/json' is allowed"
)
router = APIRouter() router = APIRouter()
...@@ -335,7 +344,7 @@ async def ping(raw_request: Request) -> Response: ...@@ -335,7 +344,7 @@ async def ping(raw_request: Request) -> Response:
return await health(raw_request) return await health(raw_request)
@router.post("/tokenize") @router.post("/tokenize", dependencies=[Depends(validate_json_request)])
@with_cancellation @with_cancellation
async def tokenize(request: TokenizeRequest, raw_request: Request): async def tokenize(request: TokenizeRequest, raw_request: Request):
handler = tokenization(raw_request) handler = tokenization(raw_request)
...@@ -350,7 +359,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request): ...@@ -350,7 +359,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
assert_never(generator) assert_never(generator)
@router.post("/detokenize") @router.post("/detokenize", dependencies=[Depends(validate_json_request)])
@with_cancellation @with_cancellation
async def detokenize(request: DetokenizeRequest, raw_request: Request): async def detokenize(request: DetokenizeRequest, raw_request: Request):
handler = tokenization(raw_request) handler = tokenization(raw_request)
...@@ -379,7 +388,8 @@ async def show_version(): ...@@ -379,7 +388,8 @@ async def show_version():
return JSONResponse(content=ver) return JSONResponse(content=ver)
@router.post("/v1/chat/completions") @router.post("/v1/chat/completions",
dependencies=[Depends(validate_json_request)])
@with_cancellation @with_cancellation
async def create_chat_completion(request: ChatCompletionRequest, async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request): raw_request: Request):
...@@ -400,7 +410,7 @@ async def create_chat_completion(request: ChatCompletionRequest, ...@@ -400,7 +410,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
return StreamingResponse(content=generator, media_type="text/event-stream") return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post("/v1/completions") @router.post("/v1/completions", dependencies=[Depends(validate_json_request)])
@with_cancellation @with_cancellation
async def create_completion(request: CompletionRequest, raw_request: Request): async def create_completion(request: CompletionRequest, raw_request: Request):
handler = completion(raw_request) handler = completion(raw_request)
...@@ -418,7 +428,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request): ...@@ -418,7 +428,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
return StreamingResponse(content=generator, media_type="text/event-stream") return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post("/v1/embeddings") @router.post("/v1/embeddings", dependencies=[Depends(validate_json_request)])
@with_cancellation @with_cancellation
async def create_embedding(request: EmbeddingRequest, raw_request: Request): async def create_embedding(request: EmbeddingRequest, raw_request: Request):
handler = embedding(raw_request) handler = embedding(raw_request)
...@@ -464,7 +474,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): ...@@ -464,7 +474,7 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
assert_never(generator) assert_never(generator)
@router.post("/pooling") @router.post("/pooling", dependencies=[Depends(validate_json_request)])
@with_cancellation @with_cancellation
async def create_pooling(request: PoolingRequest, raw_request: Request): async def create_pooling(request: PoolingRequest, raw_request: Request):
handler = pooling(raw_request) handler = pooling(raw_request)
...@@ -482,7 +492,7 @@ async def create_pooling(request: PoolingRequest, raw_request: Request): ...@@ -482,7 +492,7 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
assert_never(generator) assert_never(generator)
@router.post("/score") @router.post("/score", dependencies=[Depends(validate_json_request)])
@with_cancellation @with_cancellation
async def create_score(request: ScoreRequest, raw_request: Request): async def create_score(request: ScoreRequest, raw_request: Request):
handler = score(raw_request) handler = score(raw_request)
...@@ -500,7 +510,7 @@ async def create_score(request: ScoreRequest, raw_request: Request): ...@@ -500,7 +510,7 @@ async def create_score(request: ScoreRequest, raw_request: Request):
assert_never(generator) assert_never(generator)
@router.post("/v1/score") @router.post("/v1/score", dependencies=[Depends(validate_json_request)])
@with_cancellation @with_cancellation
async def create_score_v1(request: ScoreRequest, raw_request: Request): async def create_score_v1(request: ScoreRequest, raw_request: Request):
logger.warning( logger.warning(
...@@ -510,7 +520,7 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request): ...@@ -510,7 +520,7 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
return await create_score(request, raw_request) return await create_score(request, raw_request)
@router.post("/rerank") @router.post("/rerank", dependencies=[Depends(validate_json_request)])
@with_cancellation @with_cancellation
async def do_rerank(request: RerankRequest, raw_request: Request): async def do_rerank(request: RerankRequest, raw_request: Request):
handler = rerank(raw_request) handler = rerank(raw_request)
...@@ -527,7 +537,7 @@ async def do_rerank(request: RerankRequest, raw_request: Request): ...@@ -527,7 +537,7 @@ async def do_rerank(request: RerankRequest, raw_request: Request):
assert_never(generator) assert_never(generator)
@router.post("/v1/rerank") @router.post("/v1/rerank", dependencies=[Depends(validate_json_request)])
@with_cancellation @with_cancellation
async def do_rerank_v1(request: RerankRequest, raw_request: Request): async def do_rerank_v1(request: RerankRequest, raw_request: Request):
logger.warning_once( logger.warning_once(
...@@ -538,7 +548,7 @@ async def do_rerank_v1(request: RerankRequest, raw_request: Request): ...@@ -538,7 +548,7 @@ async def do_rerank_v1(request: RerankRequest, raw_request: Request):
return await do_rerank(request, raw_request) return await do_rerank(request, raw_request)
@router.post("/v2/rerank") @router.post("/v2/rerank", dependencies=[Depends(validate_json_request)])
@with_cancellation @with_cancellation
async def do_rerank_v2(request: RerankRequest, raw_request: Request): async def do_rerank_v2(request: RerankRequest, raw_request: Request):
return await do_rerank(request, raw_request) return await do_rerank(request, raw_request)
...@@ -582,7 +592,7 @@ if envs.VLLM_SERVER_DEV_MODE: ...@@ -582,7 +592,7 @@ if envs.VLLM_SERVER_DEV_MODE:
return Response(status_code=200) return Response(status_code=200)
@router.post("/invocations") @router.post("/invocations", dependencies=[Depends(validate_json_request)])
async def invocations(raw_request: Request): async def invocations(raw_request: Request):
""" """
For SageMaker, routes requests to other handlers based on model `task`. For SageMaker, routes requests to other handlers based on model `task`.
...@@ -632,7 +642,8 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: ...@@ -632,7 +642,8 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
"Lora dynamic loading & unloading is enabled in the API server. " "Lora dynamic loading & unloading is enabled in the API server. "
"This should ONLY be used for local development!") "This should ONLY be used for local development!")
@router.post("/v1/load_lora_adapter") @router.post("/v1/load_lora_adapter",
dependencies=[Depends(validate_json_request)])
async def load_lora_adapter(request: LoadLoraAdapterRequest, async def load_lora_adapter(request: LoadLoraAdapterRequest,
raw_request: Request): raw_request: Request):
handler = models(raw_request) handler = models(raw_request)
...@@ -643,7 +654,8 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: ...@@ -643,7 +654,8 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
return Response(status_code=200, content=response) return Response(status_code=200, content=response)
@router.post("/v1/unload_lora_adapter") @router.post("/v1/unload_lora_adapter",
dependencies=[Depends(validate_json_request)])
async def unload_lora_adapter(request: UnloadLoraAdapterRequest, async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
raw_request: Request): raw_request: Request):
handler = models(raw_request) handler = models(raw_request)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment