Unverified Commit d5ae2eba authored by Alessio Dalla Piazza's avatar Alessio Dalla Piazza Committed by GitHub
Browse files

Add Support for API Key Authentication (#230)

parent 1b355479
...@@ -12,17 +12,19 @@ from sglang.utils import encode_image_base64, find_printable_text, http_request ...@@ -12,17 +12,19 @@ from sglang.utils import encode_image_base64, find_printable_text, http_request
class RuntimeEndpoint(BaseBackend): class RuntimeEndpoint(BaseBackend):
def __init__(self, base_url, auth_token=None, verify=None): def __init__(self, base_url, auth_token=None, api_key=None, verify=None):
super().__init__() super().__init__()
self.support_concate_and_append = True self.support_concate_and_append = True
self.base_url = base_url self.base_url = base_url
self.auth_token = auth_token self.auth_token = auth_token
self.api_key = api_key
self.verify = verify self.verify = verify
res = http_request( res = http_request(
self.base_url + "/get_model_info", self.base_url + "/get_model_info",
auth_token=self.auth_token, auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify, verify=self.verify,
) )
assert res.status_code == 200 assert res.status_code == 200
...@@ -59,6 +61,7 @@ class RuntimeEndpoint(BaseBackend): ...@@ -59,6 +61,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/generate", self.base_url + "/generate",
json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}}, json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
auth_token=self.auth_token, auth_token=self.auth_token,
api_key=self.api_key
verify=self.verify, verify=self.verify,
) )
assert res.status_code == 200 assert res.status_code == 200
...@@ -68,6 +71,7 @@ class RuntimeEndpoint(BaseBackend): ...@@ -68,6 +71,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/generate", self.base_url + "/generate",
json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}}, json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
auth_token=self.auth_token, auth_token=self.auth_token,
api_key=self.api_key
verify=self.verify, verify=self.verify,
) )
assert res.status_code == 200 assert res.status_code == 200
...@@ -79,6 +83,7 @@ class RuntimeEndpoint(BaseBackend): ...@@ -79,6 +83,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/generate", self.base_url + "/generate",
json=data, json=data,
auth_token=self.auth_token, auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify, verify=self.verify,
) )
assert res.status_code == 200 assert res.status_code == 200
...@@ -114,6 +119,7 @@ class RuntimeEndpoint(BaseBackend): ...@@ -114,6 +119,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/generate", self.base_url + "/generate",
json=data, json=data,
auth_token=self.auth_token, auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify, verify=self.verify,
) )
obj = res.json() obj = res.json()
...@@ -153,6 +159,7 @@ class RuntimeEndpoint(BaseBackend): ...@@ -153,6 +159,7 @@ class RuntimeEndpoint(BaseBackend):
json=data, json=data,
stream=True, stream=True,
auth_token=self.auth_token, auth_token=self.auth_token,
api_key=self.api_key
verify=self.verify, verify=self.verify,
) )
pos = 0 pos = 0
...@@ -188,6 +195,7 @@ class RuntimeEndpoint(BaseBackend): ...@@ -188,6 +195,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/generate", self.base_url + "/generate",
json=data, json=data,
auth_token=self.auth_token, auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify, verify=self.verify,
) )
assert res.status_code == 200 assert res.status_code == 200
...@@ -205,6 +213,7 @@ class RuntimeEndpoint(BaseBackend): ...@@ -205,6 +213,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/generate", self.base_url + "/generate",
json=data, json=data,
auth_token=self.auth_token, auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify, verify=self.verify,
) )
assert res.status_code == 200 assert res.status_code == 200
...@@ -222,6 +231,7 @@ class RuntimeEndpoint(BaseBackend): ...@@ -222,6 +231,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/concate_and_append_request", self.base_url + "/concate_and_append_request",
json={"src_rids": src_rids, "dst_rid": dst_rid}, json={"src_rids": src_rids, "dst_rid": dst_rid},
auth_token=self.auth_token, auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify, verify=self.verify,
) )
assert res.status_code == 200 assert res.status_code == 200
......
...@@ -20,6 +20,8 @@ import requests ...@@ -20,6 +20,8 @@ import requests
import uvicorn import uvicorn
import uvloop import uvloop
from fastapi import FastAPI, HTTPException, Request from fastapi import FastAPI, HTTPException, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from fastapi.responses import Response, StreamingResponse from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
from sglang.backend.runtime_endpoint import RuntimeEndpoint from sglang.backend.runtime_endpoint import RuntimeEndpoint
...@@ -57,6 +59,23 @@ from sglang.srt.utils import handle_port_init ...@@ -57,6 +59,23 @@ from sglang.srt.utils import handle_port_init
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
API_KEY_HEADER_NAME = "X-API-Key"
class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
def __init__(self, app, api_key: str):
super().__init__(app)
self.api_key = api_key
async def dispatch(self, request: Request, call_next):
# extract API key from the request headers
api_key_header = request.headers.get(API_KEY_HEADER_NAME)
if not api_key_header or api_key_header != self.api_key:
return JSONResponse(
status_code=403,
content={"detail": "Invalid API Key"},
)
response = await call_next(request)
return response
app = FastAPI() app = FastAPI()
tokenizer_manager = None tokenizer_manager = None
...@@ -482,6 +501,9 @@ def launch_server(server_args, pipe_finish_writer): ...@@ -482,6 +501,9 @@ def launch_server(server_args, pipe_finish_writer):
assert proc_router.is_alive() and proc_detoken.is_alive() assert proc_router.is_alive() and proc_detoken.is_alive()
if server_args.api_key and server_args.api_key != "":
app.add_middleware(APIKeyValidatorMiddleware, api_key=server_args.api_key)
def _launch_server(): def _launch_server():
uvicorn.run( uvicorn.run(
app, app,
...@@ -493,11 +515,15 @@ def launch_server(server_args, pipe_finish_writer): ...@@ -493,11 +515,15 @@ def launch_server(server_args, pipe_finish_writer):
) )
def _wait_and_warmup(): def _wait_and_warmup():
headers = {}
url = server_args.url() url = server_args.url()
for _ in range(60): if server_args.api_key and server_args.api_key != "":
time.sleep(1) headers[API_KEY_HEADER_NAME] = server_args.api_key
for _ in range(120):
time.sleep(0.5)
try: try:
requests.get(url + "/get_model_info", timeout=5) requests.get(url + "/get_model_info", timeout=5, headers=headers)
break break
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
pass pass
...@@ -520,6 +546,7 @@ def launch_server(server_args, pipe_finish_writer): ...@@ -520,6 +546,7 @@ def launch_server(server_args, pipe_finish_writer):
"max_new_tokens": 16, "max_new_tokens": 16,
}, },
}, },
headers=headers,
timeout=60, timeout=60,
) )
# print(f"Warmup done. model response: {res.json()['text']}") # print(f"Warmup done. model response: {res.json()['text']}")
...@@ -558,6 +585,7 @@ class Runtime: ...@@ -558,6 +585,7 @@ class Runtime:
attention_reduce_in_fp32: bool = False, attention_reduce_in_fp32: bool = False,
random_seed: int = 42, random_seed: int = 42,
log_level: str = "error", log_level: str = "error",
api_key: str = "",
port: Optional[int] = None, port: Optional[int] = None,
additional_ports: Optional[Union[List[int], int]] = None, additional_ports: Optional[Union[List[int], int]] = None,
): ):
...@@ -580,6 +608,7 @@ class Runtime: ...@@ -580,6 +608,7 @@ class Runtime:
attention_reduce_in_fp32=attention_reduce_in_fp32, attention_reduce_in_fp32=attention_reduce_in_fp32,
random_seed=random_seed, random_seed=random_seed,
log_level=log_level, log_level=log_level,
api_key=api_key,
) )
self.url = self.server_args.url() self.url = self.server_args.url()
......
...@@ -32,6 +32,7 @@ class ServerArgs: ...@@ -32,6 +32,7 @@ class ServerArgs:
enable_flashinfer: bool = False enable_flashinfer: bool = False
disable_regex_jump_forward: bool = False disable_regex_jump_forward: bool = False
disable_disk_cache: bool = False disable_disk_cache: bool = False
api_key: str = ""
def __post_init__(self): def __post_init__(self):
if self.tokenizer_path is None: if self.tokenizer_path is None:
...@@ -201,6 +202,12 @@ class ServerArgs: ...@@ -201,6 +202,12 @@ class ServerArgs:
action="store_true", action="store_true",
help="Disable disk cache to avoid possible crashes related to file system or high concurrency.", help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
) )
parser.add_argument(
"--api-key",
type=str,
default=ServerArgs.api_key,
help="Set API Key",
)
@classmethod @classmethod
def from_cli_args(cls, args: argparse.Namespace): def from_cli_args(cls, args: argparse.Namespace):
......
...@@ -88,23 +88,22 @@ class HttpResponse: ...@@ -88,23 +88,22 @@ class HttpResponse:
return self.resp.status return self.resp.status
def http_request(url, json=None, stream=False, auth_token=None, verify=None): def http_request(url, json=None, stream=False, auth_token=None, api_key=None, verify=None):
"""A faster version of requests.post with low-level urllib API.""" """A faster version of requests.post with low-level urllib API."""
headers = {"Content-Type": "application/json; charset=utf-8"}
# add the Authorization header if an auth token is provided
if auth_token is not None:
headers["Authorization"] = f"Bearer {auth_token}"
# add the API Key header if an API key is provided
if api_key is not None:
headers["X-API-Key"] = api_key
if stream: if stream:
if auth_token is None: return requests.post(url, json=json, stream=True, headers=headers)
return requests.post(url, json=json, stream=True, verify=verify)
headers = {
"Content-Type": "application/json",
"Authentication": f"Bearer {auth_token}",
}
return requests.post(
url, json=json, stream=True, headers=headers, verify=verify
)
else: else:
req = urllib.request.Request(url) req = urllib.request.Request(url, headers=headers)
req.add_header("Content-Type", "application/json; charset=utf-8")
if auth_token is not None:
req.add_header("Authentication", f"Bearer {auth_token}")
if json is None: if json is None:
data = None data = None
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment