Unverified commit c35cd1f8, authored by Henry Hyeonmok Ko, committed by GitHub

Expose max total num tokens from Runtime & Engine API (#2092)

parent 72f87b72
@@ -15,6 +15,7 @@
     "- `/health_generate`\n",
     "- `/flush_cache`\n",
     "- `/get_memory_pool_size`\n",
+    "- `/get_max_total_num_tokens`\n",
     "- `/update_weights`\n",
     "- `/encode`(embedding model)\n",
     "- `/classify`(reward model)\n",
@@ -201,6 +202,29 @@
     "print_highlight(response.text)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Get Maximum Total Number of Tokens\n",
+    "\n",
+    "Exposes the maximum number of tokens SGLang can handle based on the current configuration."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# get_max_total_num_tokens\n",
+    "\n",
+    "url = \"http://localhost:30010/get_max_total_num_tokens\"\n",
+    "\n",
+    "response = requests.get(url)\n",
+    "print_highlight(response.text)"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
...
@@ -167,9 +167,12 @@ class DataParallelController:
             self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
         )

-        # Wait for model to finish loading
+        # Wait for model to finish loading and get max token nums
+        scheduler_info = []
         for i in range(len(scheduler_pipe_readers)):
-            scheduler_pipe_readers[i].recv()
+            scheduler_info.append(scheduler_pipe_readers[i].recv())
+
+        self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]

         return send_to
@@ -191,7 +194,10 @@ class DataParallelController:
         send_to = get_zmq_socket(
             self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
         )
-        reader.recv()
+
+        scheduler_info = reader.recv()
+        self.max_total_num_tokens = scheduler_info["max_total_num_tokens"]

         return send_to

     def round_robin_scheduler(self, req):
@@ -233,7 +239,9 @@ def run_data_parallel_controller_process(
     try:
         controller = DataParallelController(server_args, port_args)
-        pipe_writer.send("ready")
+        pipe_writer.send(
+            {"status": "ready", "max_total_num_tokens": controller.max_total_num_tokens}
+        )
         controller.event_loop()
     except Exception:
         msg = get_exception_traceback()
...
@@ -1400,7 +1400,9 @@ def run_scheduler_process(
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
-        pipe_writer.send("ready")
+        pipe_writer.send(
+            {"status": "ready", "max_total_num_tokens": scheduler.max_total_num_tokens}
+        )
         if scheduler.enable_overlap:
             scheduler.event_loop_overlap()
         else:
...
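Both controller hunks above rely on the same handshake: each scheduler process sends a dict over a pipe once its model has loaded, and the parent reads one dict per scheduler. A stripped-down sketch of that pattern, where `worker` and the value 4096 are placeholders rather than SGLang code:

```python
# Hedged sketch of the pipe handshake used above. worker() stands in for
# run_scheduler_process, and 4096 is a placeholder token budget.
import multiprocessing as mp


def worker(pipe_writer):
    max_total_num_tokens = 4096  # in SGLang this comes from memory pool sizing
    pipe_writer.send({"status": "ready", "max_total_num_tokens": max_total_num_tokens})


if __name__ == "__main__":
    reader, writer = mp.Pipe(duplex=False)
    proc = mp.Process(target=worker, args=(writer,))
    proc.start()
    info = reader.recv()  # blocks until the child reports readiness
    assert info["status"] == "ready"
    print(info["max_total_num_tokens"])  # -> 4096
    proc.join()
```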
@@ -102,6 +102,7 @@ app.add_middleware(
 )

 tokenizer_manager: TokenizerManager = None
+_max_total_num_tokens = None

 ##### Native API endpoints #####
@@ -184,6 +185,17 @@ async def stop_profile():
     )


+@app.get("/get_max_total_num_tokens")
+async def get_max_total_num_tokens():
+    try:
+        return {"max_total_num_tokens": _get_max_total_num_tokens()}
+    except Exception as e:
+        return ORJSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )
+
+
 @app.api_route("/get_memory_pool_size", methods=["GET", "POST"])
 async def get_memory_pool_size():
     """Get the memory pool size in number of tokens"""
@@ -390,6 +402,7 @@ def launch_engine(
     """
     global tokenizer_manager
+    global _max_total_num_tokens

     # Configure global environment
     configure_logger(server_args)
@@ -455,9 +468,20 @@ def launch_engine(
     if server_args.chat_template:
         load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)

-    # Wait for model to finish loading
+    # Wait for model to finish loading & get max token nums
+    scheduler_info = []
     for i in range(len(scheduler_pipe_readers)):
-        scheduler_pipe_readers[i].recv()
+        data = scheduler_pipe_readers[i].recv()
+        if data["status"] != "ready":
+            self.shutdown()
+            raise RuntimeError(
+                "Initialization failed. Please see the error messages above."
+            )
+        scheduler_info.append(data)
+
+    # Assume all schedulers have same max_total_num_tokens
+    _max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]


 def launch_server(
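The loop above takes `max_total_num_tokens` from the first scheduler and assumes the rest agree. A hedged sketch that makes that assumption explicit; the function name is hypothetical and `scheduler_pipe_readers` is any list of `multiprocessing` connections:

```python
# Hypothetical helper mirroring the aggregation step above; it verifies the
# "all schedulers have same max_total_num_tokens" assumption instead of
# silently trusting the first reader.
def collect_max_total_num_tokens(scheduler_pipe_readers):
    infos = [reader.recv() for reader in scheduler_pipe_readers]
    if any(info["status"] != "ready" for info in infos):
        raise RuntimeError("Initialization failed. Please see the error messages above.")
    values = {info["max_total_num_tokens"] for info in infos}
    if len(values) != 1:
        raise RuntimeError(f"Schedulers disagree on max_total_num_tokens: {values}")
    return values.pop()
```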
@@ -518,6 +542,10 @@ def launch_server(
     t.join()


+def _get_max_total_num_tokens():
+    return _max_total_num_tokens
+
+
 def _set_envs_and_config(server_args: ServerArgs):
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -759,6 +787,15 @@ class Runtime:
         response = requests.post(self.url + "/encode", json=json_data)
         return json.dumps(response.json())

+    def get_max_total_num_tokens(self):
+        response = requests.get(f"{self.url}/get_max_total_num_tokens")
+        if response.status_code == 200:
+            return response.json()["max_total_num_tokens"]
+        else:
+            raise RuntimeError(
+                f"Failed to get max tokens. {response.json()['error']['message']}"
+            )
+
     def __del__(self):
         self.shutdown()
@@ -908,3 +945,6 @@ class Engine:
         # get the current event loop
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(encode_request(obj, None))
+
+    def get_max_total_num_tokens(self):
+        return _get_max_total_num_tokens()
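End to end, the new value is reachable from both entry points: `Runtime.get_max_total_num_tokens()` goes over HTTP to the endpoint above, while `Engine.get_max_total_num_tokens()` reads the module-level value in-process. A hedged usage sketch; the model path is an arbitrary example, and it assumes sglang is installed with a GPU available:

```python
# Hedged usage sketch; the model path is an arbitrary example, not from the diff.
import sglang as sgl

runtime = sgl.Runtime(model_path="meta-llama/Llama-3.1-8B-Instruct")
print(runtime.get_max_total_num_tokens())  # queries /get_max_total_num_tokens
runtime.shutdown()
```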