Unverified commit 66aa4c0b, authored by Yuan Tang, committed by GitHub
Browse files

[Feature] Add middleware to log API Server responses (#15593)


Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
parent 24718153
...@@ -24,6 +24,7 @@ from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request ...@@ -24,6 +24,7 @@ from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse from fastapi.responses import JSONResponse, Response, StreamingResponse
from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import State from starlette.datastructures import State
from starlette.routing import Mount from starlette.routing import Mount
from typing_extensions import assert_never from typing_extensions import assert_never
...@@ -846,6 +847,21 @@ def build_app(args: Namespace) -> FastAPI: ...@@ -846,6 +847,21 @@ def build_app(args: Namespace) -> FastAPI:
response.headers["X-Request-Id"] = request_id response.headers["X-Request-Id"] = request_id
return response return response
if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
    logger.warning("CAUTION: Enabling log response in the API Server. "
                   "This can include sensitive information and should be "
                   "avoided in production.")

    @app.middleware("http")
    async def log_response(request: Request, call_next):
        """Log the body of each HTTP response, then pass it on unchanged.

        Only registered when VLLM_DEBUG_LOG_API_SERVER_RESPONSE is set.

        Args:
            request: The incoming HTTP request.
            call_next: Callable that forwards the request down the stack
                and returns the response.

        Returns:
            The same response object, with its body iterator rebuilt
            from the buffered chunks so it can still be sent.
        """
        response = await call_next(request)
        # Iterating body_iterator consumes it, so buffer every chunk here
        # and re-wrap the buffered chunks below.
        # NOTE(review): this waits for the whole body before returning, so
        # streamed responses are no longer delivered incrementally while
        # this debug middleware is active.
        response_body = [
            section async for section in response.body_iterator
        ]
        response.body_iterator = iterate_in_threadpool(iter(response_body))
        # Join all chunks instead of indexing response_body[0]: indexing
        # raised IndexError when the iterator yielded nothing and logged
        # only the first chunk otherwise. errors="replace" keeps a
        # non-UTF-8 body from failing the request just because of logging.
        body_text = b"".join(response_body).decode(errors="replace")
        logger.info("response_body={%s}", body_text)
        return response
for middleware in args.middleware: for middleware in args.middleware:
module_path, object_name = middleware.rsplit(".", 1) module_path, object_name = middleware.rsplit(".", 1)
imported = getattr(importlib.import_module(module_path), object_name) imported = getattr(importlib.import_module(module_path), object_name)
......
...@@ -270,6 +270,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -270,6 +270,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_API_KEY": "VLLM_API_KEY":
lambda: os.environ.get("VLLM_API_KEY", None), lambda: os.environ.get("VLLM_API_KEY", None),
# Whether to log responses from API Server for debugging
"VLLM_DEBUG_LOG_API_SERVER_RESPONSE":
lambda: os.environ.get("VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False").
lower() == "true",
# S3 access information, used for tensorizer to load model from S3 # S3 access information, used for tensorizer to load model from S3
"S3_ACCESS_KEY_ID": "S3_ACCESS_KEY_ID":
lambda: os.environ.get("S3_ACCESS_KEY_ID", None), lambda: os.environ.get("S3_ACCESS_KEY_ID", None),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment