Unverified commit b0facb33, authored by Michael Feil and committed by GitHub
Browse files

add orjson for jsonresponse (#1688)

parent ecb8bad2
...@@ -21,7 +21,7 @@ dependencies = [ ...@@ -21,7 +21,7 @@ dependencies = [
[project.optional-dependencies] [project.optional-dependencies]
runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular", runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
"packaging", "pillow", "psutil", "pydantic", "python-multipart", "orjson", "packaging", "pillow", "psutil", "pydantic", "python-multipart",
"torchao", "uvicorn", "uvloop", "zmq", "torchao", "uvicorn", "uvloop", "zmq",
"outlines>=0.0.44", "modelscope"] "outlines>=0.0.44", "modelscope"]
# xpu is not enabled in public vllm and torch whl, # xpu is not enabled in public vllm and torch whl,
......
...@@ -25,7 +25,7 @@ from http import HTTPStatus ...@@ -25,7 +25,7 @@ from http import HTTPStatus
from typing import Dict, List from typing import Dict, List
from fastapi import HTTPException, Request, UploadFile from fastapi import HTTPException, Request, UploadFile
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import ORJSONResponse, StreamingResponse
from pydantic import ValidationError from pydantic import ValidationError
try: try:
...@@ -101,7 +101,7 @@ def create_error_response( ...@@ -101,7 +101,7 @@ def create_error_response(
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
): ):
error = ErrorResponse(message=message, type=err_type, code=status_code.value) error = ErrorResponse(message=message, type=err_type, code=status_code.value)
return JSONResponse(content=error.model_dump(), status_code=error.code) return ORJSONResponse(content=error.model_dump(), status_code=error.code)
def create_streaming_error_response( def create_streaming_error_response(
......
...@@ -40,7 +40,7 @@ import uvicorn ...@@ -40,7 +40,7 @@ import uvicorn
import uvloop import uvloop
from fastapi import FastAPI, File, Form, Request, UploadFile from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.hf_transformers_utils import get_tokenizer
...@@ -176,12 +176,12 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request): ...@@ -176,12 +176,12 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request):
success, message = await tokenizer_manager.update_weights(obj, request) success, message = await tokenizer_manager.update_weights(obj, request)
content = {"success": success, "message": message} content = {"success": success, "message": message}
if success: if success:
return JSONResponse( return ORJSONResponse(
content, content,
status_code=HTTPStatus.OK, status_code=HTTPStatus.OK,
) )
else: else:
return JSONResponse( return ORJSONResponse(
content, content,
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
) )
...@@ -211,7 +211,7 @@ async def generate_request(obj: GenerateReqInput, request: Request): ...@@ -211,7 +211,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
ret = await tokenizer_manager.generate_request(obj, request).__anext__() ret = await tokenizer_manager.generate_request(obj, request).__anext__()
return ret return ret
except ValueError as e: except ValueError as e:
return JSONResponse( return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
) )
...@@ -226,7 +226,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request): ...@@ -226,7 +226,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
ret = await tokenizer_manager.generate_request(obj, request).__anext__() ret = await tokenizer_manager.generate_request(obj, request).__anext__()
return ret return ret
except ValueError as e: except ValueError as e:
return JSONResponse( return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
) )
...@@ -241,7 +241,7 @@ async def judge_request(obj: RewardReqInput, request: Request): ...@@ -241,7 +241,7 @@ async def judge_request(obj: RewardReqInput, request: Request):
ret = await tokenizer_manager.generate_request(obj, request).__anext__() ret = await tokenizer_manager.generate_request(obj, request).__anext__()
return ret return ret
except ValueError as e: except ValueError as e:
return JSONResponse( return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
) )
......
...@@ -35,7 +35,7 @@ import psutil ...@@ -35,7 +35,7 @@ import psutil
import requests import requests
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from fastapi.responses import JSONResponse from fastapi.responses import ORJSONResponse
from packaging import version as pkg_version from packaging import version as pkg_version
from torch import nn from torch import nn
from torch.profiler import ProfilerActivity, profile, record_function from torch.profiler import ProfilerActivity, profile, record_function
...@@ -566,7 +566,7 @@ def add_api_key_middleware(app, api_key: str): ...@@ -566,7 +566,7 @@ def add_api_key_middleware(app, api_key: str):
if request.url.path.startswith("/health"): if request.url.path.startswith("/health"):
return await call_next(request) return await call_next(request)
if request.headers.get("Authorization") != "Bearer " + api_key: if request.headers.get("Authorization") != "Bearer " + api_key:
return JSONResponse(content={"error": "Unauthorized"}, status_code=401) return ORJSONResponse(content={"error": "Unauthorized"}, status_code=401)
return await call_next(request) return await call_next(request)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment