Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhaoyu6
sglang
Commits
d5ae2eba
"vscode:/vscode.git/clone" did not exist on "81eba3912f2453debda16d0c73e94940b2cc30f0"
Unverified
Commit
d5ae2eba
authored
Mar 11, 2024
by
Alessio Dalla Piazza
Committed by
GitHub
Mar 11, 2024
Browse files
Add Support for API Key Authentication (#230)
parent
1b355479
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
63 additions
and
18 deletions
+63
-18
python/sglang/backend/runtime_endpoint.py
python/sglang/backend/runtime_endpoint.py
+11
-1
python/sglang/srt/server.py
python/sglang/srt/server.py
+32
-3
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+7
-0
python/sglang/utils.py
python/sglang/utils.py
+13
-14
No files found.
python/sglang/backend/runtime_endpoint.py
View file @
d5ae2eba
...
@@ -12,17 +12,19 @@ from sglang.utils import encode_image_base64, find_printable_text, http_request
...
@@ -12,17 +12,19 @@ from sglang.utils import encode_image_base64, find_printable_text, http_request
class
RuntimeEndpoint
(
BaseBackend
):
class
RuntimeEndpoint
(
BaseBackend
):
def
__init__
(
self
,
base_url
,
auth_token
=
None
,
verify
=
None
):
def
__init__
(
self
,
base_url
,
auth_token
=
None
,
api_key
=
None
,
verify
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
support_concate_and_append
=
True
self
.
support_concate_and_append
=
True
self
.
base_url
=
base_url
self
.
base_url
=
base_url
self
.
auth_token
=
auth_token
self
.
auth_token
=
auth_token
self
.
api_key
=
api_key
self
.
verify
=
verify
self
.
verify
=
verify
res
=
http_request
(
res
=
http_request
(
self
.
base_url
+
"/get_model_info"
,
self
.
base_url
+
"/get_model_info"
,
auth_token
=
self
.
auth_token
,
auth_token
=
self
.
auth_token
,
api_key
=
self
.
api_key
,
verify
=
self
.
verify
,
verify
=
self
.
verify
,
)
)
assert
res
.
status_code
==
200
assert
res
.
status_code
==
200
...
@@ -59,6 +61,7 @@ class RuntimeEndpoint(BaseBackend):
...
@@ -59,6 +61,7 @@ class RuntimeEndpoint(BaseBackend):
self
.
base_url
+
"/generate"
,
self
.
base_url
+
"/generate"
,
json
=
{
"text"
:
prefix_str
,
"sampling_params"
:
{
"max_new_tokens"
:
0
}},
json
=
{
"text"
:
prefix_str
,
"sampling_params"
:
{
"max_new_tokens"
:
0
}},
auth_token
=
self
.
auth_token
,
auth_token
=
self
.
auth_token
,
api_key
=
self
.
api_key
verify
=
self
.
verify
,
verify
=
self
.
verify
,
)
)
assert
res
.
status_code
==
200
assert
res
.
status_code
==
200
...
@@ -68,6 +71,7 @@ class RuntimeEndpoint(BaseBackend):
...
@@ -68,6 +71,7 @@ class RuntimeEndpoint(BaseBackend):
self
.
base_url
+
"/generate"
,
self
.
base_url
+
"/generate"
,
json
=
{
"text"
:
s
.
text_
,
"sampling_params"
:
{
"max_new_tokens"
:
0
}},
json
=
{
"text"
:
s
.
text_
,
"sampling_params"
:
{
"max_new_tokens"
:
0
}},
auth_token
=
self
.
auth_token
,
auth_token
=
self
.
auth_token
,
api_key
=
self
.
api_key
verify
=
self
.
verify
,
verify
=
self
.
verify
,
)
)
assert
res
.
status_code
==
200
assert
res
.
status_code
==
200
...
@@ -79,6 +83,7 @@ class RuntimeEndpoint(BaseBackend):
...
@@ -79,6 +83,7 @@ class RuntimeEndpoint(BaseBackend):
self
.
base_url
+
"/generate"
,
self
.
base_url
+
"/generate"
,
json
=
data
,
json
=
data
,
auth_token
=
self
.
auth_token
,
auth_token
=
self
.
auth_token
,
api_key
=
self
.
api_key
,
verify
=
self
.
verify
,
verify
=
self
.
verify
,
)
)
assert
res
.
status_code
==
200
assert
res
.
status_code
==
200
...
@@ -114,6 +119,7 @@ class RuntimeEndpoint(BaseBackend):
...
@@ -114,6 +119,7 @@ class RuntimeEndpoint(BaseBackend):
self
.
base_url
+
"/generate"
,
self
.
base_url
+
"/generate"
,
json
=
data
,
json
=
data
,
auth_token
=
self
.
auth_token
,
auth_token
=
self
.
auth_token
,
api_key
=
self
.
api_key
,
verify
=
self
.
verify
,
verify
=
self
.
verify
,
)
)
obj
=
res
.
json
()
obj
=
res
.
json
()
...
@@ -153,6 +159,7 @@ class RuntimeEndpoint(BaseBackend):
...
@@ -153,6 +159,7 @@ class RuntimeEndpoint(BaseBackend):
json
=
data
,
json
=
data
,
stream
=
True
,
stream
=
True
,
auth_token
=
self
.
auth_token
,
auth_token
=
self
.
auth_token
,
api_key
=
self
.
api_key
verify
=
self
.
verify
,
verify
=
self
.
verify
,
)
)
pos
=
0
pos
=
0
...
@@ -188,6 +195,7 @@ class RuntimeEndpoint(BaseBackend):
...
@@ -188,6 +195,7 @@ class RuntimeEndpoint(BaseBackend):
self
.
base_url
+
"/generate"
,
self
.
base_url
+
"/generate"
,
json
=
data
,
json
=
data
,
auth_token
=
self
.
auth_token
,
auth_token
=
self
.
auth_token
,
api_key
=
self
.
api_key
,
verify
=
self
.
verify
,
verify
=
self
.
verify
,
)
)
assert
res
.
status_code
==
200
assert
res
.
status_code
==
200
...
@@ -205,6 +213,7 @@ class RuntimeEndpoint(BaseBackend):
...
@@ -205,6 +213,7 @@ class RuntimeEndpoint(BaseBackend):
self
.
base_url
+
"/generate"
,
self
.
base_url
+
"/generate"
,
json
=
data
,
json
=
data
,
auth_token
=
self
.
auth_token
,
auth_token
=
self
.
auth_token
,
api_key
=
self
.
api_key
,
verify
=
self
.
verify
,
verify
=
self
.
verify
,
)
)
assert
res
.
status_code
==
200
assert
res
.
status_code
==
200
...
@@ -222,6 +231,7 @@ class RuntimeEndpoint(BaseBackend):
...
@@ -222,6 +231,7 @@ class RuntimeEndpoint(BaseBackend):
self
.
base_url
+
"/concate_and_append_request"
,
self
.
base_url
+
"/concate_and_append_request"
,
json
=
{
"src_rids"
:
src_rids
,
"dst_rid"
:
dst_rid
},
json
=
{
"src_rids"
:
src_rids
,
"dst_rid"
:
dst_rid
},
auth_token
=
self
.
auth_token
,
auth_token
=
self
.
auth_token
,
api_key
=
self
.
api_key
,
verify
=
self
.
verify
,
verify
=
self
.
verify
,
)
)
assert
res
.
status_code
==
200
assert
res
.
status_code
==
200
...
...
python/sglang/srt/server.py
View file @
d5ae2eba
...
@@ -20,6 +20,8 @@ import requests
...
@@ -20,6 +20,8 @@ import requests
import
uvicorn
import
uvicorn
import
uvloop
import
uvloop
from
fastapi
import
FastAPI
,
HTTPException
,
Request
from
fastapi
import
FastAPI
,
HTTPException
,
Request
from
starlette.middleware.base
import
BaseHTTPMiddleware
from
starlette.responses
import
JSONResponse
from
fastapi.responses
import
Response
,
StreamingResponse
from
fastapi.responses
import
Response
,
StreamingResponse
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
from
sglang.backend.runtime_endpoint
import
RuntimeEndpoint
from
sglang.backend.runtime_endpoint
import
RuntimeEndpoint
...
@@ -57,6 +59,23 @@ from sglang.srt.utils import handle_port_init
...
@@ -57,6 +59,23 @@ from sglang.srt.utils import handle_port_init
asyncio
.
set_event_loop_policy
(
uvloop
.
EventLoopPolicy
())
asyncio
.
set_event_loop_policy
(
uvloop
.
EventLoopPolicy
())
API_KEY_HEADER_NAME
=
"X-API-Key"
class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
    """HTTP middleware that rejects requests lacking the configured API key.

    Each incoming request must carry the key in the header named by
    ``API_KEY_HEADER_NAME``. Requests with a missing, empty, or mismatching
    header receive a 403 JSON response and are never forwarded to the
    wrapped application.

    Args:
        app: The ASGI application being wrapped.
        api_key: The secret value that clients must present.
    """

    def __init__(self, app, api_key: str):
        super().__init__(app)
        self.api_key = api_key

    async def dispatch(self, request: Request, call_next):
        """Validate the API key header, then forward the request.

        Args:
            request: The incoming request whose headers are inspected.
            call_next: Callable that passes the request on to the next handler.

        Returns:
            A 403 ``JSONResponse`` with ``{"detail": "Invalid API Key"}`` when
            the header is absent, empty, or wrong; otherwise whatever
            ``call_next`` returns.
        """
        # Imported at function scope so this block does not depend on the
        # module's import section; the lookup is a cached dict hit after the
        # first call.
        import secrets

        # extract API key from the request headers
        api_key_header = request.headers.get(API_KEY_HEADER_NAME)

        # Compare in constant time so response latency does not reveal how
        # many leading characters of a guessed key were correct. Both sides
        # are encoded to bytes because compare_digest only accepts str
        # arguments that are pure ASCII.
        if not api_key_header or not secrets.compare_digest(
            api_key_header.encode("utf-8", "surrogatepass"),
            self.api_key.encode("utf-8", "surrogatepass"),
        ):
            return JSONResponse(
                status_code=403,
                content={"detail": "Invalid API Key"},
            )

        return await call_next(request)
app
=
FastAPI
()
app
=
FastAPI
()
tokenizer_manager
=
None
tokenizer_manager
=
None
...
@@ -482,6 +501,9 @@ def launch_server(server_args, pipe_finish_writer):
...
@@ -482,6 +501,9 @@ def launch_server(server_args, pipe_finish_writer):
assert
proc_router
.
is_alive
()
and
proc_detoken
.
is_alive
()
assert
proc_router
.
is_alive
()
and
proc_detoken
.
is_alive
()
if
server_args
.
api_key
and
server_args
.
api_key
!=
""
:
app
.
add_middleware
(
APIKeyValidatorMiddleware
,
api_key
=
server_args
.
api_key
)
def
_launch_server
():
def
_launch_server
():
uvicorn
.
run
(
uvicorn
.
run
(
app
,
app
,
...
@@ -493,11 +515,15 @@ def launch_server(server_args, pipe_finish_writer):
...
@@ -493,11 +515,15 @@ def launch_server(server_args, pipe_finish_writer):
)
)
def
_wait_and_warmup
():
def
_wait_and_warmup
():
headers
=
{}
url
=
server_args
.
url
()
url
=
server_args
.
url
()
for
_
in
range
(
60
):
if
server_args
.
api_key
and
server_args
.
api_key
!=
""
:
time
.
sleep
(
1
)
headers
[
API_KEY_HEADER_NAME
]
=
server_args
.
api_key
for
_
in
range
(
120
):
time
.
sleep
(
0.5
)
try
:
try
:
requests
.
get
(
url
+
"/get_model_info"
,
timeout
=
5
)
requests
.
get
(
url
+
"/get_model_info"
,
timeout
=
5
,
headers
=
headers
)
break
break
except
requests
.
exceptions
.
RequestException
as
e
:
except
requests
.
exceptions
.
RequestException
as
e
:
pass
pass
...
@@ -520,6 +546,7 @@ def launch_server(server_args, pipe_finish_writer):
...
@@ -520,6 +546,7 @@ def launch_server(server_args, pipe_finish_writer):
"max_new_tokens"
:
16
,
"max_new_tokens"
:
16
,
},
},
},
},
headers
=
headers
,
timeout
=
60
,
timeout
=
60
,
)
)
# print(f"Warmup done. model response: {res.json()['text']}")
# print(f"Warmup done. model response: {res.json()['text']}")
...
@@ -558,6 +585,7 @@ class Runtime:
...
@@ -558,6 +585,7 @@ class Runtime:
attention_reduce_in_fp32
:
bool
=
False
,
attention_reduce_in_fp32
:
bool
=
False
,
random_seed
:
int
=
42
,
random_seed
:
int
=
42
,
log_level
:
str
=
"error"
,
log_level
:
str
=
"error"
,
api_key
:
str
=
""
,
port
:
Optional
[
int
]
=
None
,
port
:
Optional
[
int
]
=
None
,
additional_ports
:
Optional
[
Union
[
List
[
int
],
int
]]
=
None
,
additional_ports
:
Optional
[
Union
[
List
[
int
],
int
]]
=
None
,
):
):
...
@@ -580,6 +608,7 @@ class Runtime:
...
@@ -580,6 +608,7 @@ class Runtime:
attention_reduce_in_fp32
=
attention_reduce_in_fp32
,
attention_reduce_in_fp32
=
attention_reduce_in_fp32
,
random_seed
=
random_seed
,
random_seed
=
random_seed
,
log_level
=
log_level
,
log_level
=
log_level
,
api_key
=
api_key
,
)
)
self
.
url
=
self
.
server_args
.
url
()
self
.
url
=
self
.
server_args
.
url
()
...
...
python/sglang/srt/server_args.py
View file @
d5ae2eba
...
@@ -32,6 +32,7 @@ class ServerArgs:
...
@@ -32,6 +32,7 @@ class ServerArgs:
enable_flashinfer
:
bool
=
False
enable_flashinfer
:
bool
=
False
disable_regex_jump_forward
:
bool
=
False
disable_regex_jump_forward
:
bool
=
False
disable_disk_cache
:
bool
=
False
disable_disk_cache
:
bool
=
False
api_key
:
str
=
""
def
__post_init__
(
self
):
def
__post_init__
(
self
):
if
self
.
tokenizer_path
is
None
:
if
self
.
tokenizer_path
is
None
:
...
@@ -201,6 +202,12 @@ class ServerArgs:
...
@@ -201,6 +202,12 @@ class ServerArgs:
action
=
"store_true"
,
action
=
"store_true"
,
help
=
"Disable disk cache to avoid possible crashes related to file system or high concurrency."
,
help
=
"Disable disk cache to avoid possible crashes related to file system or high concurrency."
,
)
)
parser
.
add_argument
(
"--api-key"
,
type
=
str
,
default
=
ServerArgs
.
api_key
,
help
=
"Set API Key"
,
)
@
classmethod
@
classmethod
def
from_cli_args
(
cls
,
args
:
argparse
.
Namespace
):
def
from_cli_args
(
cls
,
args
:
argparse
.
Namespace
):
...
...
python/sglang/utils.py
View file @
d5ae2eba
...
@@ -88,23 +88,22 @@ class HttpResponse:
...
@@ -88,23 +88,22 @@ class HttpResponse:
return
self
.
resp
.
status
return
self
.
resp
.
status
def
http_request
(
url
,
json
=
None
,
stream
=
False
,
auth_token
=
None
,
verify
=
None
):
def
http_request
(
url
,
json
=
None
,
stream
=
False
,
auth_token
=
None
,
api_key
=
None
,
verify
=
None
):
"""A faster version of requests.post with low-level urllib API."""
"""A faster version of requests.post with low-level urllib API."""
headers
=
{
"Content-Type"
:
"application/json; charset=utf-8"
}
# add the Authorization header if an auth token is provided
if
auth_token
is
not
None
:
headers
[
"Authorization"
]
=
f
"Bearer
{
auth_token
}
"
# add the API Key header if an API key is provided
if
api_key
is
not
None
:
headers
[
"X-API-Key"
]
=
api_key
if
stream
:
if
stream
:
if
auth_token
is
None
:
return
requests
.
post
(
url
,
json
=
json
,
stream
=
True
,
headers
=
headers
)
return
requests
.
post
(
url
,
json
=
json
,
stream
=
True
,
verify
=
verify
)
headers
=
{
"Content-Type"
:
"application/json"
,
"Authentication"
:
f
"Bearer
{
auth_token
}
"
,
}
return
requests
.
post
(
url
,
json
=
json
,
stream
=
True
,
headers
=
headers
,
verify
=
verify
)
else
:
else
:
req
=
urllib
.
request
.
Request
(
url
)
req
=
urllib
.
request
.
Request
(
url
,
headers
=
headers
)
req
.
add_header
(
"Content-Type"
,
"application/json; charset=utf-8"
)
if
auth_token
is
not
None
:
req
.
add_header
(
"Authentication"
,
f
"Bearer
{
auth_token
}
"
)
if
json
is
None
:
if
json
is
None
:
data
=
None
data
=
None
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment