Unverified Commit dbe17293 authored by Henry Hyeonmok Ko, committed by GitHub

Merged three native APIs into one: get_server_info (#2152)

parent 84a1698d
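For orientation, here is a minimal frontend-level sketch of the rename (a hedged example, not part of the commit: the port mirrors the notebook below, and a running server is assumed):

```python
import sglang as sgl

# Point the frontend at a running server (port assumed, matching the docs below).
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30010"))

# The renamed helper replaces sgl.get_server_args(); per this commit, server
# args are flattened into the top level of the returned dict.
info = sgl.get_server_info()
print(info["tokenizer_path"])
```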
......@@ -113,7 +113,7 @@ def main(args):
# Compute accuracy
tokenizer = get_tokenizer(
global_config.default_backend.get_server_args()["tokenizer_path"]
global_config.default_backend.get_server_info()["tokenizer_path"]
)
output_jsons = [state["json_output"] for state in states]
num_output_tokens = sum(len(tokenizer.encode(x)) for x in output_jsons)
......
......@@ -9,13 +9,11 @@
"Apart from the OpenAI compatible APIs, the SGLang Runtime also provides its native server APIs. We introduce these following APIs:\n",
"\n",
"- `/generate` (text generation model)\n",
"- `/get_server_args`\n",
"- `/get_model_info`\n",
"- `/get_server_info`\n",
"- `/health`\n",
"- `/health_generate`\n",
"- `/flush_cache`\n",
"- `/get_memory_pool_size`\n",
"- `/get_max_total_num_tokens`\n",
"- `/update_weights`\n",
"- `/encode`(embedding model)\n",
"- `/classify`(reward model)\n",
......@@ -75,26 +73,6 @@
"print_highlight(response.json())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Get Server Args\n",
"Get the arguments of a server."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"url = \"http://localhost:30010/get_server_args\"\n",
"\n",
"response = requests.get(url)\n",
"print_highlight(response.json())"
]
},
{
"cell_type": "markdown",
"metadata": {},
......@@ -127,9 +105,12 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Health Check\n",
"- `/health`: Check the health of the server.\n",
"- `/health_generate`: Check the health of the server by generating one token."
"## Get Server Info\n",
"Gets the server information including CLI arguments, token limits, and memory pool sizes.\n",
"- Note: `get_server_info` merges the following deprecated endpoints:\n",
" - `get_server_args`\n",
" - `get_memory_pool_size` \n",
" - `get_max_total_num_tokens`"
]
},
{
......@@ -138,19 +119,9 @@
"metadata": {},
"outputs": [],
"source": [
"url = \"http://localhost:30010/health_generate\"\n",
"# get_server_info\n",
"\n",
"response = requests.get(url)\n",
"print_highlight(response.text)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"url = \"http://localhost:30010/health\"\n",
"url = \"http://localhost:30010/get_server_info\"\n",
"\n",
"response = requests.get(url)\n",
"print_highlight(response.text)"
......@@ -160,9 +131,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Flush Cache\n",
"\n",
"Flush the radix cache. It will be automatically triggered when the model weights are updated by the `/update_weights` API."
"## Health Check\n",
"- `/health`: Check the health of the server.\n",
"- `/health_generate`: Check the health of the server by generating one token."
]
},
{
......@@ -171,32 +142,19 @@
"metadata": {},
"outputs": [],
"source": [
"# flush cache\n",
"\n",
"url = \"http://localhost:30010/flush_cache\"\n",
"url = \"http://localhost:30010/health_generate\"\n",
"\n",
"response = requests.post(url)\n",
"response = requests.get(url)\n",
"print_highlight(response.text)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Get Memory Pool Size\n",
"\n",
"Get the memory pool size in number of tokens.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# get_memory_pool_size\n",
"\n",
"url = \"http://localhost:30010/get_memory_pool_size\"\n",
"url = \"http://localhost:30010/health\"\n",
"\n",
"response = requests.get(url)\n",
"print_highlight(response.text)"
......@@ -206,9 +164,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Get Maximum Total Number of Tokens\n",
"## Flush Cache\n",
"\n",
"Exposes the maximum number of tokens SGLang can handle based on the current configuration."
"Flush the radix cache. It will be automatically triggered when the model weights are updated by the `/update_weights` API."
]
},
{
......@@ -217,11 +175,11 @@
"metadata": {},
"outputs": [],
"source": [
"# get_max_total_num_tokens\n",
"# flush cache\n",
"\n",
"url = \"http://localhost:30010/get_max_total_num_tokens\"\n",
"url = \"http://localhost:30010/flush_cache\"\n",
"\n",
"response = requests.get(url)\n",
"response = requests.post(url)\n",
"print_highlight(response.text)"
]
},
......
......@@ -11,7 +11,7 @@ from sglang.api import (
gen,
gen_int,
gen_string,
get_server_args,
get_server_info,
image,
select,
set_default_backend,
......@@ -41,7 +41,7 @@ __all__ = [
"gen",
"gen_int",
"gen_string",
"get_server_args",
"get_server_info",
"image",
"select",
"set_default_backend",
......
......@@ -65,7 +65,7 @@ def flush_cache(backend: Optional[BaseBackend] = None):
return backend.flush_cache()
def get_server_args(backend: Optional[BaseBackend] = None):
def get_server_info(backend: Optional[BaseBackend] = None):
backend = backend or global_config.default_backend
if backend is None:
return None
......@@ -73,7 +73,7 @@ def get_server_args(backend: Optional[BaseBackend] = None):
# If backend is Runtime
if hasattr(backend, "endpoint"):
backend = backend.endpoint
return backend.get_server_args()
return backend.get_server_info()
def gen(
......
......@@ -78,5 +78,5 @@ class BaseBackend:
def flush_cache(self):
pass
def get_server_args(self):
def get_server_info(self):
pass
......@@ -58,9 +58,9 @@ class RuntimeEndpoint(BaseBackend):
)
self._assert_success(res)
def get_server_args(self):
def get_server_info(self):
res = http_request(
self.base_url + "/get_server_args",
self.base_url + "/get_server_info",
api_key=self.api_key,
verify=self.verify,
)
......
......@@ -146,10 +146,15 @@ async def get_model_info():
return result
@app.get("/get_server_args")
async def get_server_args():
"""Get the server arguments."""
return dataclasses.asdict(tokenizer_manager.server_args)
@app.get("/get_server_info")
async def get_server_info():
try:
return await _get_server_info()
except Exception as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
@app.post("/flush_cache")
......@@ -185,30 +190,6 @@ async def stop_profile():
)
@app.get("/get_max_total_num_tokens")
async def get_max_total_num_tokens():
try:
return {"max_total_num_tokens": _get_max_total_num_tokens()}
except Exception as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
@app.api_route("/get_memory_pool_size", methods=["GET", "POST"])
async def get_memory_pool_size():
"""Get the memory pool size in number of tokens"""
try:
ret = await tokenizer_manager.get_memory_pool_size()
return ret
except Exception as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
@app.post("/update_weights")
@time_func_latency
async def update_weights(obj: UpdateWeightReqInput, request: Request):
......@@ -542,8 +523,12 @@ def launch_server(
t.join()
def _get_max_total_num_tokens():
return _max_total_num_tokens
async def _get_server_info():
return {
**dataclasses.asdict(tokenizer_manager.server_args), # server args
"memory_pool_size": await tokenizer_manager.get_memory_pool_size(), # memory pool size
"max_total_num_tokens": _max_total_num_tokens, # max total num tokens
}
def _set_envs_and_config(server_args: ServerArgs):
......@@ -787,14 +772,16 @@ class Runtime:
response = requests.post(self.url + "/encode", json=json_data)
return json.dumps(response.json())
def get_max_total_num_tokens(self):
response = requests.get(f"{self.url}/get_max_total_num_tokens")
if response.status_code == 200:
return response.json()["max_total_num_tokens"]
else:
raise RuntimeError(
f"Failed to get max tokens. {response.json()['error']['message']}"
)
async def get_server_info(self):
async with aiohttp.ClientSession() as session:
async with session.get(f"{self.url}/get_server_info") as response:
if response.status == 200:
return await response.json()
else:
error_data = await response.json()
raise RuntimeError(
f"Failed to get server info. {error_data['error']['message']}"
)
def __del__(self):
self.shutdown()
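Because `Runtime.get_server_info` is now a coroutine (unlike the synchronous `get_max_total_num_tokens` it replaces), callers need an event loop. A minimal usage sketch, with the model path as a placeholder assumption:

```python
import asyncio

import sglang as sgl

# Model path is a placeholder; any served model behaves the same way.
runtime = sgl.Runtime(model_path="meta-llama/Llama-3.1-8B-Instruct")

info = asyncio.run(runtime.get_server_info())
print(info["max_total_num_tokens"])

runtime.shutdown()
```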
......@@ -946,5 +933,5 @@ class Engine:
loop = asyncio.get_event_loop()
return loop.run_until_complete(encode_request(obj, None))
def get_max_total_num_tokens(self):
return _get_max_total_num_tokens()
async def get_server_info(self):
return await _get_server_info()
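`Engine.get_server_info` makes the same sync-to-async switch. A hedged sketch of updating a call site, mirroring the `run_until_complete` pattern the class itself uses (the import path points at the module patched in this diff; the model path is a placeholder):

```python
import asyncio

from sglang.srt.server import Engine  # module patched in this diff

engine = Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # placeholder model

# Before this commit: engine.get_max_total_num_tokens()
loop = asyncio.get_event_loop()
info = loop.run_until_complete(engine.get_server_info())
print(info["max_total_num_tokens"])
```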
......@@ -66,14 +66,14 @@ async fn health_generate(data: web::Data<AppState>) -> impl Responder {
forward_request(&data.client, worker_url, "/health_generate".to_string()).await
}
#[get("/get_server_args")]
async fn get_server_args(data: web::Data<AppState>) -> impl Responder {
#[get("/get_server_info")]
async fn get_server_info(data: web::Data<AppState>) -> impl Responder {
let worker_url = match data.router.get_first() {
Some(url) => url,
None => return HttpResponse::InternalServerError().finish(),
};
forward_request(&data.client, worker_url, "/get_server_args".to_string()).await
forward_request(&data.client, worker_url, "/get_server_info".to_string()).await
}
#[get("/v1/models")]
......@@ -153,7 +153,7 @@ pub async fn startup(
.service(get_model_info)
.service(health)
.service(health_generate)
.service(get_server_args)
.service(get_server_info)
})
.bind((host, port))?
.run()
......
......@@ -63,12 +63,13 @@ class TestDataParallelism(unittest.TestCase):
assert response.status_code == 200
def test_get_memory_pool_size(self):
response = requests.get(self.base_url + "/get_memory_pool_size")
# use `get_server_info` instead, since `get_memory_pool_size` has been merged into it
response = requests.get(self.base_url + "/get_server_info")
assert response.status_code == 200
time.sleep(5)
response = requests.get(self.base_url + "/get_memory_pool_size")
response = requests.get(self.base_url + "/get_server_info")
assert response.status_code == 200
......
......@@ -154,9 +154,18 @@ class TestSRTEndpoint(unittest.TestCase):
self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)
self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)
def test_get_memory_pool_size(self):
response = requests.post(self.base_url + "/get_memory_pool_size")
self.assertIsInstance(response.json(), int)
def test_get_server_info(self):
response = requests.get(self.base_url + "/get_server_info")
response_json = response.json()
max_total_num_tokens = response_json["max_total_num_tokens"]
self.assertIsInstance(max_total_num_tokens, int)
memory_pool_size = response_json["memory_pool_size"]
self.assertIsInstance(memory_pool_size, int)
attention_backend = response_json["attention_backend"]
self.assertIsInstance(attention_backend, str)
if __name__ == "__main__":
......