Unverified commit 405f26b0, authored by Srinivas Billa, committed by GitHub
Browse files

Add Auth Token to RuntimeEndpoint (#162)

parent b1a3a454
......@@ -12,13 +12,14 @@ from sglang.utils import encode_image_base64, find_printable_text, http_request
class RuntimeEndpoint(BaseBackend):
def __init__(self, base_url):
def __init__(self, base_url, auth_token=None):
super().__init__()
self.support_concate_and_append = True
self.base_url = base_url
self.auth_token = auth_token
res = http_request(self.base_url + "/get_model_info")
res = http_request(self.base_url + "/get_model_info", auth_token=self.auth_token)
assert res.status_code == 200
self.model_info = res.json()
......@@ -36,6 +37,7 @@ class RuntimeEndpoint(BaseBackend):
res = http_request(
self.base_url + "/generate",
json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
auth_token=self.auth_token
)
assert res.status_code == 200
......@@ -43,13 +45,14 @@ class RuntimeEndpoint(BaseBackend):
res = http_request(
self.base_url + "/generate",
json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
auth_token=self.auth_token
)
assert res.status_code == 200
def fill_image(self, s: StreamExecutor):
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
self._add_images(s, data)
res = http_request(self.base_url + "/generate", json=data)
res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
assert res.status_code == 200
def generate(
......@@ -79,7 +82,7 @@ class RuntimeEndpoint(BaseBackend):
self._add_images(s, data)
res = http_request(self.base_url + "/generate", json=data)
res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
obj = res.json()
comp = obj["text"]
return comp, obj["meta_info"]
......@@ -112,7 +115,7 @@ class RuntimeEndpoint(BaseBackend):
data["stream"] = True
self._add_images(s, data)
response = http_request(self.base_url + "/generate", json=data, stream=True)
response = http_request(self.base_url + "/generate", json=data, stream=True, auth_token=self.auth_token)
pos = 0
incomplete_text = ""
......@@ -142,7 +145,7 @@ class RuntimeEndpoint(BaseBackend):
# Cache common prefix
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
self._add_images(s, data)
res = http_request(self.base_url + "/generate", json=data)
res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
assert res.status_code == 200
prompt_len = res.json()["meta_info"]["prompt_tokens"]
......@@ -154,7 +157,7 @@ class RuntimeEndpoint(BaseBackend):
"logprob_start_len": max(prompt_len - 2, 0),
}
self._add_images(s, data)
res = http_request(self.base_url + "/generate", json=data)
res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
assert res.status_code == 200
obj = res.json()
normalized_prompt_logprob = [
......@@ -169,6 +172,7 @@ class RuntimeEndpoint(BaseBackend):
res = http_request(
self.base_url + "/concate_and_append_request",
json={"src_rids": src_rids, "dst_rid": dst_rid},
auth_token=self.auth_token
)
assert res.status_code == 200
......
......@@ -88,13 +88,18 @@ class HttpResponse:
return self.resp.status
def http_request(url, json=None, stream=False):
def http_request(url, json=None, stream=False, auth_token=None):
"""A faster version of requests.post with low-level urllib API."""
if stream:
return requests.post(url, json=json, stream=True)
headers = {
"Content-Type": "application/json",
"Authentication": f"Bearer {auth_token}"
}
return requests.post(url, json=json, stream=True, headers=headers)
else:
req = urllib.request.Request(url)
req.add_header("Content-Type", "application/json; charset=utf-8")
req.add_header("Authentication", f"Bearer {auth_token}")
if json is None:
data = None
else:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment