Unverified Commit d5ae2eba authored by Alessio Dalla Piazza's avatar Alessio Dalla Piazza Committed by GitHub
Browse files

Add Support for API Key Authentication (#230)

parent 1b355479
......@@ -12,17 +12,19 @@ from sglang.utils import encode_image_base64, find_printable_text, http_request
class RuntimeEndpoint(BaseBackend):
def __init__(self, base_url, auth_token=None, verify=None):
def __init__(self, base_url, auth_token=None, api_key=None, verify=None):
super().__init__()
self.support_concate_and_append = True
self.base_url = base_url
self.auth_token = auth_token
self.api_key = api_key
self.verify = verify
res = http_request(
self.base_url + "/get_model_info",
auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify,
)
assert res.status_code == 200
......@@ -59,6 +61,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/generate",
json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify,
)
assert res.status_code == 200
......@@ -68,6 +71,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/generate",
json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify,
)
assert res.status_code == 200
......@@ -79,6 +83,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/generate",
json=data,
auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify,
)
assert res.status_code == 200
......@@ -114,6 +119,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/generate",
json=data,
auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify,
)
obj = res.json()
......@@ -153,6 +159,7 @@ class RuntimeEndpoint(BaseBackend):
json=data,
stream=True,
auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify,
)
pos = 0
......@@ -188,6 +195,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/generate",
json=data,
auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify,
)
assert res.status_code == 200
......@@ -205,6 +213,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/generate",
json=data,
auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify,
)
assert res.status_code == 200
......@@ -222,6 +231,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/concate_and_append_request",
json={"src_rids": src_rids, "dst_rid": dst_rid},
auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify,
)
assert res.status_code == 200
......
......@@ -20,6 +20,8 @@ import requests
import uvicorn
import uvloop
from fastapi import FastAPI, HTTPException, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel
from sglang.backend.runtime_endpoint import RuntimeEndpoint
......@@ -57,6 +59,23 @@ from sglang.srt.utils import handle_port_init
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
API_KEY_HEADER_NAME = "X-API-Key"
class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
    """HTTP middleware that rejects requests lacking the configured API key.

    Every incoming request must carry the expected key in the header named
    by ``API_KEY_HEADER_NAME``. Requests with a missing, empty, or
    mismatching header are answered with a 403 JSON response and never
    reach the wrapped application.

    Args:
        app: The ASGI application being wrapped.
        api_key: The secret value that clients must present.
    """

    def __init__(self, app, api_key: str):
        super().__init__(app)
        self.api_key = api_key

    async def dispatch(self, request: Request, call_next):
        """Validate the API key header, then forward the request.

        Args:
            request: The incoming HTTP request.
            call_next: Callable that passes the request on to the next
                handler and returns its response.

        Returns:
            A 403 ``JSONResponse`` with ``{"detail": "Invalid API Key"}``
            when the key is absent or wrong; otherwise the downstream
            response, unchanged.
        """
        # Local import keeps this block self-contained with respect to the
        # file's import section.
        import hmac

        # Extract the API key from the request headers.
        api_key_header = request.headers.get(API_KEY_HEADER_NAME)

        # Compare in constant time so response latency does not leak how
        # many leading characters of a guessed key were correct. The values
        # are encoded to bytes because compare_digest only accepts str
        # arguments when they are pure ASCII.
        if not api_key_header or not hmac.compare_digest(
            api_key_header.encode("utf-8"), self.api_key.encode("utf-8")
        ):
            return JSONResponse(
                status_code=403,
                content={"detail": "Invalid API Key"},
            )

        return await call_next(request)
app = FastAPI()
tokenizer_manager = None
......@@ -482,6 +501,9 @@ def launch_server(server_args, pipe_finish_writer):
assert proc_router.is_alive() and proc_detoken.is_alive()
if server_args.api_key and server_args.api_key != "":
app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
def _launch_server():
uvicorn.run(
app,
......@@ -493,11 +515,15 @@ def launch_server(server_args, pipe_finish_writer):
)
def _wait_and_warmup():
headers = {}
url = server_args.url()
for _ in range(60):
time.sleep(1)
if server_args.api_key and server_args.api_key != "":
headers[API_KEY_HEADER_NAME] = server_args.api_key
for _ in range(120):
time.sleep(0.5)
try:
requests.get(url + "/get_model_info", timeout=5)
requests.get(url + "/get_model_info", timeout=5, headers=headers)
break
except requests.exceptions.RequestException as e:
pass
......@@ -520,6 +546,7 @@ def launch_server(server_args, pipe_finish_writer):
"max_new_tokens": 16,
},
},
headers=headers,
timeout=60,
)
# print(f"Warmup done. model response: {res.json()['text']}")
......@@ -558,6 +585,7 @@ class Runtime:
attention_reduce_in_fp32: bool = False,
random_seed: int = 42,
log_level: str = "error",
api_key: str = "",
port: Optional[int] = None,
additional_ports: Optional[Union[List[int], int]] = None,
):
......@@ -580,6 +608,7 @@ class Runtime:
attention_reduce_in_fp32=attention_reduce_in_fp32,
random_seed=random_seed,
log_level=log_level,
api_key=api_key,
)
self.url = self.server_args.url()
......
......@@ -32,6 +32,7 @@ class ServerArgs:
enable_flashinfer: bool = False
disable_regex_jump_forward: bool = False
disable_disk_cache: bool = False
api_key: str = ""
def __post_init__(self):
if self.tokenizer_path is None:
......@@ -201,6 +202,12 @@ class ServerArgs:
action="store_true",
help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
)
parser.add_argument(
"--api-key",
type=str,
default=ServerArgs.api_key,
help="Set API Key",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
......
......@@ -88,23 +88,22 @@ class HttpResponse:
return self.resp.status
def http_request(url, json=None, stream=False, auth_token=None, verify=None):
def http_request(url, json=None, stream=False, auth_token=None, api_key=None, verify=None):
"""A faster version of requests.post with low-level urllib API."""
headers = {"Content-Type": "application/json; charset=utf-8"}
# add the Authorization header if an auth token is provided
if auth_token is not None:
headers["Authorization"] = f"Bearer {auth_token}"
# add the API Key header if an API key is provided
if api_key is not None:
headers["X-API-Key"] = api_key
if stream:
if auth_token is None:
return requests.post(url, json=json, stream=True, verify=verify)
headers = {
"Content-Type": "application/json",
"Authentication": f"Bearer {auth_token}",
}
return requests.post(
url, json=json, stream=True, headers=headers, verify=verify
)
return requests.post(url, json=json, stream=True, headers=headers)
else:
req = urllib.request.Request(url)
req.add_header("Content-Type", "application/json; charset=utf-8")
if auth_token is not None:
req.add_header("Authentication", f"Bearer {auth_token}")
req = urllib.request.Request(url, headers=headers)
if json is None:
data = None
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment