Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
d5ae2eba
Unverified
Commit
d5ae2eba
authored
Mar 11, 2024
by
Alessio Dalla Piazza
Committed by
GitHub
Mar 11, 2024
Browse files
Add Support for API Key Authentication (#230)
parent
1b355479
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
63 additions
and
18 deletions
+63
-18
python/sglang/backend/runtime_endpoint.py
python/sglang/backend/runtime_endpoint.py
+11
-1
python/sglang/srt/server.py
python/sglang/srt/server.py
+32
-3
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+7
-0
python/sglang/utils.py
python/sglang/utils.py
+13
-14
No files found.
python/sglang/backend/runtime_endpoint.py
View file @
d5ae2eba
...
...
@@ -12,17 +12,19 @@ from sglang.utils import encode_image_base64, find_printable_text, http_request
class
RuntimeEndpoint
(
BaseBackend
):
def
__init__
(
self
,
base_url
,
auth_token
=
None
,
verify
=
None
):
def
__init__
(
self
,
base_url
,
auth_token
=
None
,
api_key
=
None
,
verify
=
None
):
super
().
__init__
()
self
.
support_concate_and_append
=
True
self
.
base_url
=
base_url
self
.
auth_token
=
auth_token
self
.
api_key
=
api_key
self
.
verify
=
verify
res
=
http_request
(
self
.
base_url
+
"/get_model_info"
,
auth_token
=
self
.
auth_token
,
api_key
=
self
.
api_key
,
verify
=
self
.
verify
,
)
assert
res
.
status_code
==
200
...
...
@@ -59,6 +61,7 @@ class RuntimeEndpoint(BaseBackend):
self
.
base_url
+
"/generate"
,
json
=
{
"text"
:
prefix_str
,
"sampling_params"
:
{
"max_new_tokens"
:
0
}},
auth_token
=
self
.
auth_token
,
api_key
=
self
.
api_key
verify
=
self
.
verify
,
)
assert
res
.
status_code
==
200
...
...
@@ -68,6 +71,7 @@ class RuntimeEndpoint(BaseBackend):
self
.
base_url
+
"/generate"
,
json
=
{
"text"
:
s
.
text_
,
"sampling_params"
:
{
"max_new_tokens"
:
0
}},
auth_token
=
self
.
auth_token
,
api_key
=
self
.
api_key
verify
=
self
.
verify
,
)
assert
res
.
status_code
==
200
...
...
@@ -79,6 +83,7 @@ class RuntimeEndpoint(BaseBackend):
self
.
base_url
+
"/generate"
,
json
=
data
,
auth_token
=
self
.
auth_token
,
api_key
=
self
.
api_key
,
verify
=
self
.
verify
,
)
assert
res
.
status_code
==
200
...
...
@@ -114,6 +119,7 @@ class RuntimeEndpoint(BaseBackend):
self
.
base_url
+
"/generate"
,
json
=
data
,
auth_token
=
self
.
auth_token
,
api_key
=
self
.
api_key
,
verify
=
self
.
verify
,
)
obj
=
res
.
json
()
...
...
@@ -153,6 +159,7 @@ class RuntimeEndpoint(BaseBackend):
json
=
data
,
stream
=
True
,
auth_token
=
self
.
auth_token
,
api_key
=
self
.
api_key
verify
=
self
.
verify
,
)
pos
=
0
...
...
@@ -188,6 +195,7 @@ class RuntimeEndpoint(BaseBackend):
self
.
base_url
+
"/generate"
,
json
=
data
,
auth_token
=
self
.
auth_token
,
api_key
=
self
.
api_key
,
verify
=
self
.
verify
,
)
assert
res
.
status_code
==
200
...
...
@@ -205,6 +213,7 @@ class RuntimeEndpoint(BaseBackend):
self
.
base_url
+
"/generate"
,
json
=
data
,
auth_token
=
self
.
auth_token
,
api_key
=
self
.
api_key
,
verify
=
self
.
verify
,
)
assert
res
.
status_code
==
200
...
...
@@ -222,6 +231,7 @@ class RuntimeEndpoint(BaseBackend):
self
.
base_url
+
"/concate_and_append_request"
,
json
=
{
"src_rids"
:
src_rids
,
"dst_rid"
:
dst_rid
},
auth_token
=
self
.
auth_token
,
api_key
=
self
.
api_key
,
verify
=
self
.
verify
,
)
assert
res
.
status_code
==
200
...
...
python/sglang/srt/server.py
View file @
d5ae2eba
...
...
@@ -20,6 +20,8 @@ import requests
import
uvicorn
import
uvloop
from
fastapi
import
FastAPI
,
HTTPException
,
Request
from
starlette.middleware.base
import
BaseHTTPMiddleware
from
starlette.responses
import
JSONResponse
from
fastapi.responses
import
Response
,
StreamingResponse
from
pydantic
import
BaseModel
from
sglang.backend.runtime_endpoint
import
RuntimeEndpoint
...
...
@@ -57,6 +59,23 @@ from sglang.srt.utils import handle_port_init
asyncio
.
set_event_loop_policy
(
uvloop
.
EventLoopPolicy
())
API_KEY_HEADER_NAME
=
"X-API-Key"
class
APIKeyValidatorMiddleware
(
BaseHTTPMiddleware
):
def
__init__
(
self
,
app
,
api_key
:
str
):
super
().
__init__
(
app
)
self
.
api_key
=
api_key
async
def
dispatch
(
self
,
request
:
Request
,
call_next
):
# extract API key from the request headers
api_key_header
=
request
.
headers
.
get
(
API_KEY_HEADER_NAME
)
if
not
api_key_header
or
api_key_header
!=
self
.
api_key
:
return
JSONResponse
(
status_code
=
403
,
content
=
{
"detail"
:
"Invalid API Key"
},
)
response
=
await
call_next
(
request
)
return
response
app
=
FastAPI
()
tokenizer_manager
=
None
...
...
@@ -482,6 +501,9 @@ def launch_server(server_args, pipe_finish_writer):
assert
proc_router
.
is_alive
()
and
proc_detoken
.
is_alive
()
if
server_args
.
api_key
and
server_args
.
api_key
!=
""
:
app
.
add_middleware
(
APIKeyValidatorMiddleware
,
api_key
=
server_args
.
api_key
)
def
_launch_server
():
uvicorn
.
run
(
app
,
...
...
@@ -493,11 +515,15 @@ def launch_server(server_args, pipe_finish_writer):
)
def
_wait_and_warmup
():
headers
=
{}
url
=
server_args
.
url
()
for
_
in
range
(
60
):
time
.
sleep
(
1
)
if
server_args
.
api_key
and
server_args
.
api_key
!=
""
:
headers
[
API_KEY_HEADER_NAME
]
=
server_args
.
api_key
for
_
in
range
(
120
):
time
.
sleep
(
0.5
)
try
:
requests
.
get
(
url
+
"/get_model_info"
,
timeout
=
5
)
requests
.
get
(
url
+
"/get_model_info"
,
timeout
=
5
,
headers
=
headers
)
break
except
requests
.
exceptions
.
RequestException
as
e
:
pass
...
...
@@ -520,6 +546,7 @@ def launch_server(server_args, pipe_finish_writer):
"max_new_tokens"
:
16
,
},
},
headers
=
headers
,
timeout
=
60
,
)
# print(f"Warmup done. model response: {res.json()['text']}")
...
...
@@ -558,6 +585,7 @@ class Runtime:
attention_reduce_in_fp32
:
bool
=
False
,
random_seed
:
int
=
42
,
log_level
:
str
=
"error"
,
api_key
:
str
=
""
,
port
:
Optional
[
int
]
=
None
,
additional_ports
:
Optional
[
Union
[
List
[
int
],
int
]]
=
None
,
):
...
...
@@ -580,6 +608,7 @@ class Runtime:
attention_reduce_in_fp32
=
attention_reduce_in_fp32
,
random_seed
=
random_seed
,
log_level
=
log_level
,
api_key
=
api_key
,
)
self
.
url
=
self
.
server_args
.
url
()
...
...
python/sglang/srt/server_args.py
View file @
d5ae2eba
...
...
@@ -32,6 +32,7 @@ class ServerArgs:
enable_flashinfer
:
bool
=
False
disable_regex_jump_forward
:
bool
=
False
disable_disk_cache
:
bool
=
False
api_key
:
str
=
""
def
__post_init__
(
self
):
if
self
.
tokenizer_path
is
None
:
...
...
@@ -201,6 +202,12 @@ class ServerArgs:
action
=
"store_true"
,
help
=
"Disable disk cache to avoid possible crashes related to file system or high concurrency."
,
)
parser
.
add_argument
(
"--api-key"
,
type
=
str
,
default
=
ServerArgs
.
api_key
,
help
=
"Set API Key"
,
)
@
classmethod
def
from_cli_args
(
cls
,
args
:
argparse
.
Namespace
):
...
...
python/sglang/utils.py
View file @
d5ae2eba
...
...
@@ -88,23 +88,22 @@ class HttpResponse:
return
self
.
resp
.
status
def
http_request
(
url
,
json
=
None
,
stream
=
False
,
auth_token
=
None
,
verify
=
None
):
def
http_request
(
url
,
json
=
None
,
stream
=
False
,
auth_token
=
None
,
api_key
=
None
,
verify
=
None
):
"""A faster version of requests.post with low-level urllib API."""
headers
=
{
"Content-Type"
:
"application/json; charset=utf-8"
}
# add the Authorization header if an auth token is provided
if
auth_token
is
not
None
:
headers
[
"Authorization"
]
=
f
"Bearer
{
auth_token
}
"
# add the API Key header if an API key is provided
if
api_key
is
not
None
:
headers
[
"X-API-Key"
]
=
api_key
if
stream
:
if
auth_token
is
None
:
return
requests
.
post
(
url
,
json
=
json
,
stream
=
True
,
verify
=
verify
)
headers
=
{
"Content-Type"
:
"application/json"
,
"Authentication"
:
f
"Bearer
{
auth_token
}
"
,
}
return
requests
.
post
(
url
,
json
=
json
,
stream
=
True
,
headers
=
headers
,
verify
=
verify
)
return
requests
.
post
(
url
,
json
=
json
,
stream
=
True
,
headers
=
headers
)
else
:
req
=
urllib
.
request
.
Request
(
url
)
req
.
add_header
(
"Content-Type"
,
"application/json; charset=utf-8"
)
if
auth_token
is
not
None
:
req
.
add_header
(
"Authentication"
,
f
"Bearer
{
auth_token
}
"
)
req
=
urllib
.
request
.
Request
(
url
,
headers
=
headers
)
if
json
is
None
:
data
=
None
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment