Unverified Commit 02684144 authored by aisensiy's avatar aisensiy Committed by GitHub
Browse files

Support CORS for openai api server (#481)

* Support CORS for openai api server

* Remove unnecessary var

* Add CORS support follow the same style with vllm
parent b58a9dff
...@@ -3,11 +3,12 @@ import json ...@@ -3,11 +3,12 @@ import json
import os import os
import time import time
from http import HTTPStatus from http import HTTPStatus
from typing import AsyncGenerator, Optional from typing import AsyncGenerator, List, Optional
import fire import fire
import uvicorn import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from lmdeploy.serve.async_engine import AsyncEngine from lmdeploy.serve.async_engine import AsyncEngine
...@@ -321,7 +322,11 @@ def main(model_path: str, ...@@ -321,7 +322,11 @@ def main(model_path: str,
server_name: str = 'localhost', server_name: str = 'localhost',
server_port: int = 23333, server_port: int = 23333,
instance_num: int = 32, instance_num: int = 32,
tp: int = 1): tp: int = 1,
allow_origins: List[str] = ['*'],
allow_credentials: bool = True,
allow_methods: List[str] = ['*'],
allow_headers: List[str] = ['*']):
"""An example to perform model inference through the command line """An example to perform model inference through the command line
interface. interface.
...@@ -331,7 +336,20 @@ def main(model_path: str, ...@@ -331,7 +336,20 @@ def main(model_path: str,
server_port (int): server port server_port (int): server port
instance_num (int): number of instances of turbomind model instance_num (int): number of instances of turbomind model
tp (int): tensor parallel tp (int): tensor parallel
allow_origins (List[str]): a list of allowed origins for CORS
allow_credentials (bool): whether to allow credentials for CORS
allow_methods (List[str]): a list of allowed HTTP methods for CORS
allow_headers (List[str]): a list of allowed HTTP headers for CORS
""" """
if allow_origins:
app.add_middleware(
CORSMiddleware,
allow_origins=allow_origins,
allow_credentials=allow_credentials,
allow_methods=allow_methods,
allow_headers=allow_headers,
)
VariableInterface.async_engine = AsyncEngine(model_path=model_path, VariableInterface.async_engine = AsyncEngine(model_path=model_path,
instance_num=instance_num, instance_num=instance_num,
tp=tp) tp=tp)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment