Unverified Commit dbe17293 authored by Henry Hyeonmok Ko, committed by GitHub

Merged three native APIs into one: get_server_info (#2152)

parent 84a1698d
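This commit collapses three native endpoints (`/get_server_args`, `/get_memory_pool_size`, `/get_max_total_num_tokens`) into a single `/get_server_info`. A minimal before/after sketch of the client-side migration (the port matches the docs notebook below; values are illustrative):

```python
import requests

base = "http://localhost:30010"

# Before this commit: three separate calls
# server_args = requests.get(f"{base}/get_server_args").json()
# pool_size = requests.get(f"{base}/get_memory_pool_size").json()
# max_tokens = requests.get(f"{base}/get_max_total_num_tokens").json()["max_total_num_tokens"]

# After this commit: one call returns all three pieces of information
info = requests.get(f"{base}/get_server_info").json()
print(info["memory_pool_size"], info["max_total_num_tokens"])
```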
@@ -113,7 +113,7 @@ def main(args):
     # Compute accuracy
     tokenizer = get_tokenizer(
-        global_config.default_backend.get_server_args()["tokenizer_path"]
+        global_config.default_backend.get_server_info()["tokenizer_path"]
     )
     output_jsons = [state["json_output"] for state in states]
     num_output_tokens = sum(len(tokenizer.encode(x)) for x in output_jsons)
...
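The benchmark keeps working after the rename because `_get_server_info` (later in this diff) spreads the `ServerArgs` fields at the top level of the returned dict, so the `["tokenizer_path"]` lookup is unchanged. A hedged sketch of the pattern, with the `get_tokenizer` import path being an assumption:

```python
from sglang.srt.hf_transformers_utils import get_tokenizer  # import path assumed

# get_server_info() returns the server args merged with the new fields
info = global_config.default_backend.get_server_info()
tokenizer = get_tokenizer(info["tokenizer_path"])
```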
@@ -9,13 +9,11 @@
 "Apart from the OpenAI compatible APIs, the SGLang Runtime also provides its native server APIs. We introduce the following APIs:\n",
 "\n",
 "- `/generate` (text generation model)\n",
-"- `/get_server_args`\n",
 "- `/get_model_info`\n",
+"- `/get_server_info`\n",
 "- `/health`\n",
 "- `/health_generate`\n",
 "- `/flush_cache`\n",
-"- `/get_memory_pool_size`\n",
-"- `/get_max_total_num_tokens`\n",
 "- `/update_weights`\n",
 "- `/encode` (embedding model)\n",
 "- `/classify` (reward model)\n",
@@ -75,26 +73,6 @@
 "print_highlight(response.json())"
 ]
 },
-{
-"cell_type": "markdown",
-"metadata": {},
-"source": [
-"## Get Server Args\n",
-"Get the arguments of a server."
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": [
-"url = \"http://localhost:30010/get_server_args\"\n",
-"\n",
-"response = requests.get(url)\n",
-"print_highlight(response.json())"
-]
-},
 {
 "cell_type": "markdown",
 "metadata": {},
@@ -127,9 +105,12 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"## Health Check\n",
-"- `/health`: Check the health of the server.\n",
-"- `/health_generate`: Check the health of the server by generating one token."
+"## Get Server Info\n",
+"Gets the server information, including CLI arguments, token limits, and memory pool sizes.\n",
+"- Note: `get_server_info` merges the following deprecated endpoints:\n",
+"  - `get_server_args`\n",
+"  - `get_memory_pool_size`\n",
+"  - `get_max_total_num_tokens`"
 ]
 },
 {
@@ -138,19 +119,9 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"url = \"http://localhost:30010/health_generate\"\n",
-"\n",
-"response = requests.get(url)\n",
-"print_highlight(response.text)"
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": [
-"url = \"http://localhost:30010/health\"\n",
+"# get_server_info\n",
+"\n",
+"url = \"http://localhost:30010/get_server_info\"\n",
 "\n",
 "response = requests.get(url)\n",
 "print_highlight(response.text)"
@@ -160,9 +131,9 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"## Flush Cache\n",
-"\n",
-"Flush the radix cache. It will be automatically triggered when the model weights are updated by the `/update_weights` API."
+"## Health Check\n",
+"- `/health`: Check the health of the server.\n",
+"- `/health_generate`: Check the health of the server by generating one token."
 ]
 },
 {
@@ -171,32 +142,19 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"# flush cache\n",
-"\n",
-"url = \"http://localhost:30010/flush_cache\"\n",
+"url = \"http://localhost:30010/health_generate\"\n",
 "\n",
-"response = requests.post(url)\n",
+"response = requests.get(url)\n",
 "print_highlight(response.text)"
 ]
 },
-{
-"cell_type": "markdown",
-"metadata": {},
-"source": [
-"## Get Memory Pool Size\n",
-"\n",
-"Get the memory pool size in number of tokens.\n"
-]
-},
 {
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
-"# get_memory_pool_size\n",
-"\n",
-"url = \"http://localhost:30010/get_memory_pool_size\"\n",
+"url = \"http://localhost:30010/health\"\n",
 "\n",
 "response = requests.get(url)\n",
 "print_highlight(response.text)"
@@ -206,9 +164,9 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"## Get Maximum Total Number of Tokens\n",
+"## Flush Cache\n",
 "\n",
-"Exposes the maximum number of tokens SGLang can handle based on the current configuration."
+"Flush the radix cache. It will be automatically triggered when the model weights are updated by the `/update_weights` API."
 ]
 },
 {
@@ -217,11 +175,11 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"# get_max_total_num_tokens\n",
+"# flush cache\n",
 "\n",
-"url = \"http://localhost:30010/get_max_total_num_tokens\"\n",
+"url = \"http://localhost:30010/flush_cache\"\n",
 "\n",
-"response = requests.get(url)\n",
+"response = requests.post(url)\n",
 "print_highlight(response.text)"
 ]
 },
...
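Based on `_get_server_info` further down in this diff, the response is a flat JSON object: every `ServerArgs` field at the top level plus the two merged keys. A sketch of the shape (field values are made up):

```python
# Illustrative /get_server_info payload; only the keys shown in this diff are certain
info = {
    # ...every ServerArgs field, e.g.:
    "tokenizer_path": "meta-llama/Llama-3.1-8B-Instruct",  # illustrative value
    "attention_backend": "flashinfer",                     # asserted as a str in the test below
    # ...plus the two merged fields:
    "memory_pool_size": 123456,      # was /get_memory_pool_size
    "max_total_num_tokens": 654321,  # was /get_max_total_num_tokens
}
```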
@@ -11,7 +11,7 @@ from sglang.api import (
     gen,
     gen_int,
     gen_string,
-    get_server_args,
+    get_server_info,
     image,
     select,
     set_default_backend,
@@ -41,7 +41,7 @@ __all__ = [
     "gen",
     "gen_int",
     "gen_string",
-    "get_server_args",
+    "get_server_info",
     "image",
     "select",
     "set_default_backend",
...
@@ -65,7 +65,7 @@ def flush_cache(backend: Optional[BaseBackend] = None):
     return backend.flush_cache()


-def get_server_args(backend: Optional[BaseBackend] = None):
+def get_server_info(backend: Optional[BaseBackend] = None):
     backend = backend or global_config.default_backend
     if backend is None:
         return None
@@ -73,7 +73,7 @@ def get_server_args(backend: Optional[BaseBackend] = None):
     # If backend is Runtime
     if hasattr(backend, "endpoint"):
         backend = backend.endpoint
-    return backend.get_server_args()
+    return backend.get_server_info()


 def gen(
...
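A usage sketch of the renamed frontend helper, assuming a running server and the `RuntimeEndpoint` backend shown later in this diff:

```python
import sglang as sgl

sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30010"))
info = sgl.get_server_info()  # backend argument omitted, so the default backend is used
print(info["max_total_num_tokens"])
```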
@@ -78,5 +78,5 @@ class BaseBackend:
     def flush_cache(self):
         pass

-    def get_server_args(self):
+    def get_server_info(self):
         pass
...
@@ -58,9 +58,9 @@ class RuntimeEndpoint(BaseBackend):
         )
         self._assert_success(res)

-    def get_server_args(self):
+    def get_server_info(self):
         res = http_request(
-            self.base_url + "/get_server_args",
+            self.base_url + "/get_server_info",
             api_key=self.api_key,
             verify=self.verify,
         )
...
@@ -146,10 +146,15 @@ async def get_model_info():
     return result


-@app.get("/get_server_args")
-async def get_server_args():
-    """Get the server arguments."""
-    return dataclasses.asdict(tokenizer_manager.server_args)
+@app.get("/get_server_info")
+async def get_server_info():
+    try:
+        return await _get_server_info()
+    except Exception as e:
+        return ORJSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )


 @app.post("/flush_cache")
@@ -185,30 +190,6 @@ async def stop_profile():
     )


-@app.get("/get_max_total_num_tokens")
-async def get_max_total_num_tokens():
-    try:
-        return {"max_total_num_tokens": _get_max_total_num_tokens()}
-    except Exception as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
-
-
-@app.api_route("/get_memory_pool_size", methods=["GET", "POST"])
-async def get_memory_pool_size():
-    """Get the memory pool size in number of tokens"""
-    try:
-        ret = await tokenizer_manager.get_memory_pool_size()
-        return ret
-    except Exception as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
-
-
 @app.post("/update_weights")
 @time_func_latency
 async def update_weights(obj: UpdateWeightReqInput, request: Request):
@@ -542,8 +523,12 @@ def launch_server(
         t.join()


-def _get_max_total_num_tokens():
-    return _max_total_num_tokens
+async def _get_server_info():
+    return {
+        **dataclasses.asdict(tokenizer_manager.server_args),  # server args
+        "memory_pool_size": await tokenizer_manager.get_memory_pool_size(),  # memory pool size
+        "max_total_num_tokens": _max_total_num_tokens,  # max total num tokens
+    }


 def _set_envs_and_config(server_args: ServerArgs):
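Note the design: instead of keeping three handlers, the new helper builds one dict by spreading the server args and appending the two merged values, and the endpoint surfaces any failure as a 400 with an error envelope. A client-side sketch of handling both outcomes:

```python
import requests

resp = requests.get("http://localhost:30010/get_server_info")
if resp.status_code == 200:
    info = resp.json()
else:
    # error envelope produced by the ORJSONResponse branch above
    raise RuntimeError(resp.json()["error"]["message"])
```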
@@ -787,14 +772,16 @@ class Runtime:
         response = requests.post(self.url + "/encode", json=json_data)
         return json.dumps(response.json())

-    def get_max_total_num_tokens(self):
-        response = requests.get(f"{self.url}/get_max_total_num_tokens")
-        if response.status_code == 200:
-            return response.json()["max_total_num_tokens"]
-        else:
-            raise RuntimeError(
-                f"Failed to get max tokens. {response.json()['error']['message']}"
-            )
+    async def get_server_info(self):
+        async with aiohttp.ClientSession() as session:
+            async with session.get(f"{self.url}/get_server_info") as response:
+                if response.status == 200:
+                    return await response.json()
+                else:
+                    error_data = await response.json()
+                    raise RuntimeError(
+                        f"Failed to get server info. {error_data['error']['message']}"
+                    )

     def __del__(self):
         self.shutdown()
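Since `Runtime.get_server_info` is now a coroutine (aiohttp instead of requests), synchronous callers need an event loop. A minimal sketch, with the model path being illustrative:

```python
import asyncio
import sglang as sgl

runtime = sgl.Runtime(model_path="meta-llama/Llama-3.1-8B-Instruct")  # illustrative model
info = asyncio.run(runtime.get_server_info())
print(info["memory_pool_size"])
runtime.shutdown()
```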
@@ -946,5 +933,5 @@ class Engine:
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(encode_request(obj, None))

-    def get_max_total_num_tokens(self):
-        return _get_max_total_num_tokens()
+    async def get_server_info(self):
+        return await _get_server_info()
...
@@ -66,14 +66,14 @@ async fn health_generate(data: web::Data<AppState>) -> impl Responder {
     forward_request(&data.client, worker_url, "/health_generate".to_string()).await
 }

-#[get("/get_server_args")]
-async fn get_server_args(data: web::Data<AppState>) -> impl Responder {
+#[get("/get_server_info")]
+async fn get_server_info(data: web::Data<AppState>) -> impl Responder {
     let worker_url = match data.router.get_first() {
         Some(url) => url,
         None => return HttpResponse::InternalServerError().finish(),
     };

-    forward_request(&data.client, worker_url, "/get_server_args".to_string()).await
+    forward_request(&data.client, worker_url, "/get_server_info".to_string()).await
 }

 #[get("/v1/models")]
@@ -153,7 +153,7 @@ pub async fn startup(
             .service(get_model_info)
             .service(health)
             .service(health_generate)
-            .service(get_server_args)
+            .service(get_server_info)
     })
     .bind((host, port))?
     .run()
...
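The router change is a pure rename: `/get_server_info` is forwarded to the first worker, so the same client code works through the router. A sketch, with the router port being an assumption:

```python
import requests

# Query the sglang-router, which forwards to a worker's /get_server_info
info = requests.get("http://localhost:30000/get_server_info").json()  # port assumed
print(info["max_total_num_tokens"])
```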
@@ -63,12 +63,13 @@ class TestDataParallelism(unittest.TestCase):
         assert response.status_code == 200

     def test_get_memory_pool_size(self):
-        response = requests.get(self.base_url + "/get_memory_pool_size")
+        # use `get_server_info` instead since `get_memory_pool_size` is merged into `get_server_info`
+        response = requests.get(self.base_url + "/get_server_info")
         assert response.status_code == 200

         time.sleep(5)

-        response = requests.get(self.base_url + "/get_memory_pool_size")
+        response = requests.get(self.base_url + "/get_server_info")
         assert response.status_code == 200
...
@@ -154,9 +154,18 @@ class TestSRTEndpoint(unittest.TestCase):
         self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)
         self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)

-    def test_get_memory_pool_size(self):
-        response = requests.post(self.base_url + "/get_memory_pool_size")
-        self.assertIsInstance(response.json(), int)
+    def test_get_server_info(self):
+        response = requests.get(self.base_url + "/get_server_info")
+        response_json = response.json()
+
+        max_total_num_tokens = response_json["max_total_num_tokens"]
+        self.assertIsInstance(max_total_num_tokens, int)
+
+        memory_pool_size = response_json["memory_pool_size"]
+        self.assertIsInstance(memory_pool_size, int)
+
+        attention_backend = response_json["attention_backend"]
+        self.assertIsInstance(attention_backend, str)


 if __name__ == "__main__":
...