Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
405f26b0
"git@developer.sourcefind.cn:change/sglang.git" did not exist on "569b032c580ac31bf0ccd0dcdad3969f37b3a590"
Unverified
Commit
405f26b0
authored
Feb 08, 2024
by
Srinivas Billa
Committed by
GitHub
Feb 07, 2024
Browse files
Add Auth Token to RuntimeEndpoint (#162)
parent
b1a3a454
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
9 deletions
+18
-9
python/sglang/backend/runtime_endpoint.py
python/sglang/backend/runtime_endpoint.py
+11
-7
python/sglang/utils.py
python/sglang/utils.py
+7
-2
No files found.
python/sglang/backend/runtime_endpoint.py
View file @
405f26b0
...
@@ -12,13 +12,14 @@ from sglang.utils import encode_image_base64, find_printable_text, http_request
...
@@ -12,13 +12,14 @@ from sglang.utils import encode_image_base64, find_printable_text, http_request
class
RuntimeEndpoint
(
BaseBackend
):
class
RuntimeEndpoint
(
BaseBackend
):
def
__init__
(
self
,
base_url
):
def
__init__
(
self
,
base_url
,
auth_token
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
support_concate_and_append
=
True
self
.
support_concate_and_append
=
True
self
.
base_url
=
base_url
self
.
base_url
=
base_url
self
.
auth_token
=
auth_token
res
=
http_request
(
self
.
base_url
+
"/get_model_info"
)
res
=
http_request
(
self
.
base_url
+
"/get_model_info"
,
auth_token
=
self
.
auth_token
)
assert
res
.
status_code
==
200
assert
res
.
status_code
==
200
self
.
model_info
=
res
.
json
()
self
.
model_info
=
res
.
json
()
...
@@ -36,6 +37,7 @@ class RuntimeEndpoint(BaseBackend):
...
@@ -36,6 +37,7 @@ class RuntimeEndpoint(BaseBackend):
res
=
http_request
(
res
=
http_request
(
self
.
base_url
+
"/generate"
,
self
.
base_url
+
"/generate"
,
json
=
{
"text"
:
prefix_str
,
"sampling_params"
:
{
"max_new_tokens"
:
0
}},
json
=
{
"text"
:
prefix_str
,
"sampling_params"
:
{
"max_new_tokens"
:
0
}},
auth_token
=
self
.
auth_token
)
)
assert
res
.
status_code
==
200
assert
res
.
status_code
==
200
...
@@ -43,13 +45,14 @@ class RuntimeEndpoint(BaseBackend):
...
@@ -43,13 +45,14 @@ class RuntimeEndpoint(BaseBackend):
res
=
http_request
(
res
=
http_request
(
self
.
base_url
+
"/generate"
,
self
.
base_url
+
"/generate"
,
json
=
{
"text"
:
s
.
text_
,
"sampling_params"
:
{
"max_new_tokens"
:
0
}},
json
=
{
"text"
:
s
.
text_
,
"sampling_params"
:
{
"max_new_tokens"
:
0
}},
auth_token
=
self
.
auth_token
)
)
assert
res
.
status_code
==
200
assert
res
.
status_code
==
200
def
fill_image
(
self
,
s
:
StreamExecutor
):
def
fill_image
(
self
,
s
:
StreamExecutor
):
data
=
{
"text"
:
s
.
text_
,
"sampling_params"
:
{
"max_new_tokens"
:
0
}}
data
=
{
"text"
:
s
.
text_
,
"sampling_params"
:
{
"max_new_tokens"
:
0
}}
self
.
_add_images
(
s
,
data
)
self
.
_add_images
(
s
,
data
)
res
=
http_request
(
self
.
base_url
+
"/generate"
,
json
=
data
)
res
=
http_request
(
self
.
base_url
+
"/generate"
,
json
=
data
,
auth_token
=
self
.
auth_token
)
assert
res
.
status_code
==
200
assert
res
.
status_code
==
200
def
generate
(
def
generate
(
...
@@ -79,7 +82,7 @@ class RuntimeEndpoint(BaseBackend):
...
@@ -79,7 +82,7 @@ class RuntimeEndpoint(BaseBackend):
self
.
_add_images
(
s
,
data
)
self
.
_add_images
(
s
,
data
)
res
=
http_request
(
self
.
base_url
+
"/generate"
,
json
=
data
)
res
=
http_request
(
self
.
base_url
+
"/generate"
,
json
=
data
,
auth_token
=
self
.
auth_token
)
obj
=
res
.
json
()
obj
=
res
.
json
()
comp
=
obj
[
"text"
]
comp
=
obj
[
"text"
]
return
comp
,
obj
[
"meta_info"
]
return
comp
,
obj
[
"meta_info"
]
...
@@ -112,7 +115,7 @@ class RuntimeEndpoint(BaseBackend):
...
@@ -112,7 +115,7 @@ class RuntimeEndpoint(BaseBackend):
data
[
"stream"
]
=
True
data
[
"stream"
]
=
True
self
.
_add_images
(
s
,
data
)
self
.
_add_images
(
s
,
data
)
response
=
http_request
(
self
.
base_url
+
"/generate"
,
json
=
data
,
stream
=
True
)
response
=
http_request
(
self
.
base_url
+
"/generate"
,
json
=
data
,
stream
=
True
,
auth_token
=
self
.
auth_token
)
pos
=
0
pos
=
0
incomplete_text
=
""
incomplete_text
=
""
...
@@ -142,7 +145,7 @@ class RuntimeEndpoint(BaseBackend):
...
@@ -142,7 +145,7 @@ class RuntimeEndpoint(BaseBackend):
# Cache common prefix
# Cache common prefix
data
=
{
"text"
:
s
.
text_
,
"sampling_params"
:
{
"max_new_tokens"
:
0
}}
data
=
{
"text"
:
s
.
text_
,
"sampling_params"
:
{
"max_new_tokens"
:
0
}}
self
.
_add_images
(
s
,
data
)
self
.
_add_images
(
s
,
data
)
res
=
http_request
(
self
.
base_url
+
"/generate"
,
json
=
data
)
res
=
http_request
(
self
.
base_url
+
"/generate"
,
json
=
data
,
auth_token
=
self
.
auth_token
)
assert
res
.
status_code
==
200
assert
res
.
status_code
==
200
prompt_len
=
res
.
json
()[
"meta_info"
][
"prompt_tokens"
]
prompt_len
=
res
.
json
()[
"meta_info"
][
"prompt_tokens"
]
...
@@ -154,7 +157,7 @@ class RuntimeEndpoint(BaseBackend):
...
@@ -154,7 +157,7 @@ class RuntimeEndpoint(BaseBackend):
"logprob_start_len"
:
max
(
prompt_len
-
2
,
0
),
"logprob_start_len"
:
max
(
prompt_len
-
2
,
0
),
}
}
self
.
_add_images
(
s
,
data
)
self
.
_add_images
(
s
,
data
)
res
=
http_request
(
self
.
base_url
+
"/generate"
,
json
=
data
)
res
=
http_request
(
self
.
base_url
+
"/generate"
,
json
=
data
,
auth_token
=
self
.
auth_token
)
assert
res
.
status_code
==
200
assert
res
.
status_code
==
200
obj
=
res
.
json
()
obj
=
res
.
json
()
normalized_prompt_logprob
=
[
normalized_prompt_logprob
=
[
...
@@ -169,6 +172,7 @@ class RuntimeEndpoint(BaseBackend):
...
@@ -169,6 +172,7 @@ class RuntimeEndpoint(BaseBackend):
res
=
http_request
(
res
=
http_request
(
self
.
base_url
+
"/concate_and_append_request"
,
self
.
base_url
+
"/concate_and_append_request"
,
json
=
{
"src_rids"
:
src_rids
,
"dst_rid"
:
dst_rid
},
json
=
{
"src_rids"
:
src_rids
,
"dst_rid"
:
dst_rid
},
auth_token
=
self
.
auth_token
)
)
assert
res
.
status_code
==
200
assert
res
.
status_code
==
200
...
...
python/sglang/utils.py
View file @
405f26b0
...
@@ -88,13 +88,18 @@ class HttpResponse:
...
@@ -88,13 +88,18 @@ class HttpResponse:
return
self
.
resp
.
status
return
self
.
resp
.
status
def
http_request
(
url
,
json
=
None
,
stream
=
False
):
def
http_request
(
url
,
json
=
None
,
stream
=
False
,
auth_token
=
None
):
"""A faster version of requests.post with low-level urllib API."""
"""A faster version of requests.post with low-level urllib API."""
if
stream
:
if
stream
:
return
requests
.
post
(
url
,
json
=
json
,
stream
=
True
)
headers
=
{
"Content-Type"
:
"application/json"
,
"Authentication"
:
f
"Bearer
{
auth_token
}
"
}
return
requests
.
post
(
url
,
json
=
json
,
stream
=
True
,
headers
=
headers
)
else
:
else
:
req
=
urllib
.
request
.
Request
(
url
)
req
=
urllib
.
request
.
Request
(
url
)
req
.
add_header
(
"Content-Type"
,
"application/json; charset=utf-8"
)
req
.
add_header
(
"Content-Type"
,
"application/json; charset=utf-8"
)
req
.
add_header
(
"Authentication"
,
f
"Bearer
{
auth_token
}
"
)
if
json
is
None
:
if
json
is
None
:
data
=
None
data
=
None
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment