Unverified commit bbd72bfc, authored by 科英, committed by GitHub
Browse files

Add the ability to enable and disable the Profiler via HTTP API. (#1626)

parent b503881b
......@@ -20,6 +20,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
import uuid
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Union
from sglang.srt.managers.schedule_batch import BaseFinishReason
......@@ -343,3 +344,8 @@ class UpdateWeightReqOutput:
class AbortReq:
    # The id of the in-flight request to abort, assigned when the request
    # entered the system (see the surrounding io_struct request classes).
    rid: str
class ProfileReq(Enum):
    """Profiler control message sent from the TokenizerManager to the Scheduler.

    Dispatched via ``send_to_scheduler.send_pyobj`` when the ``/start_profile``
    or ``/stop_profile`` HTTP endpoints are hit.
    """

    START_PROFILE = 1
    STOP_PROFILE = 2
......@@ -37,6 +37,7 @@ from sglang.srt.managers.io_struct import (
BatchEmbeddingOut,
BatchTokenIDOut,
FlushCacheReq,
ProfileReq,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
TokenizedRewardReqInput,
......@@ -229,6 +230,22 @@ class Scheduler:
self.new_token_ratio_decay = global_config.new_token_ratio_decay
self.batch_is_full = False
if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
self.profiler = None
else:
self.torch_profiler_trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
logger.info(
"Profiling enabled. Traces will be saved to: %s",
self.torch_profiler_trace_dir,
)
self.profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
with_stack=True,
)
@torch.inference_mode()
def event_loop(self):
while True:
......@@ -271,6 +288,11 @@ class Scheduler:
elif isinstance(recv_req, UpdateWeightReqInput):
success, message = self.update_weights(recv_req)
self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
elif isinstance(recv_req, ProfileReq):
if recv_req == ProfileReq.START_PROFILE:
self.start_profile()
else:
self.stop_profile()
else:
raise ValueError(f"Invalid request: {recv_req}")
......@@ -1000,6 +1022,20 @@ class Scheduler:
logger.error(message)
return success, message
def start_profile(self) -> None:
    """Begin collecting a torch profiler trace.

    Raises:
        RuntimeError: if ``SGLANG_TORCH_PROFILER_DIR`` was not set at startup,
            in which case no profiler instance was ever created.
    """
    profiler = self.profiler
    if profiler is None:
        raise RuntimeError("Profiler is not enabled.")
    profiler.start()
def stop_profile(self) -> None:
    """Stop the torch profiler and export the collected trace.

    The trace is written into ``self.torch_profiler_trace_dir`` (configured via
    the ``SGLANG_TORCH_PROFILER_DIR`` environment variable) as a
    gzip-compressed chrome trace named after the current timestamp, so
    successive profiling runs do not overwrite each other.

    Raises:
        RuntimeError: if profiling was not enabled at startup.
    """
    if self.profiler is None:
        raise RuntimeError("Profiler is not enabled.")
    self.profiler.stop()
    # os.path.join is robust to a trailing slash in the configured directory,
    # unlike the previous manual "+" concatenation of path segments.
    trace_file = os.path.join(
        self.torch_profiler_trace_dir, str(time.time()) + ".trace.json.gz"
    )
    self.profiler.export_chrome_trace(trace_file)
    logger.info("Profiler is done")
def run_scheduler_process(
server_args: ServerArgs,
......
......@@ -46,6 +46,7 @@ from sglang.srt.managers.io_struct import (
EmbeddingReqInput,
FlushCacheReq,
GenerateReqInput,
ProfileReq,
RewardReqInput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
......@@ -512,6 +513,14 @@ class TokenizerManager:
req = AbortReq(rid)
self.send_to_scheduler.send_pyobj(req)
def start_profile(self):
    """Ask the scheduler process (over ZMQ) to start the torch profiler."""
    self.send_to_scheduler.send_pyobj(ProfileReq.START_PROFILE)
def stop_profile(self):
    """Ask the scheduler process (over ZMQ) to stop the torch profiler."""
    self.send_to_scheduler.send_pyobj(ProfileReq.STOP_PROFILE)
async def update_weights(
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
):
......
......@@ -145,6 +145,28 @@ async def flush_cache():
)
@app.get("/start_profile")
@app.post("/start_profile")
async def start_profile():
    """Begin torch profiling in the scheduler process.

    Registered on both GET and POST for convenience; the work is delegated
    to the TokenizerManager, which forwards the request over ZMQ.
    """
    tokenizer_manager.start_profile()
    return Response(content="Start profiling.\n", status_code=200)
@app.get("/stop_profile")
@app.post("/stop_profile")
async def stop_profile():
    """Stop torch profiling in the scheduler process.

    Registered on both GET and POST for convenience; the scheduler exports
    the trace asynchronously, which is why the response warns of a delay.
    """
    tokenizer_manager.stop_profile()
    return Response(content="Stop profiling. This will take some time.\n", status_code=200)
@app.post("/update_weights")
async def update_weights(obj: UpdateWeightReqInput, request: Request):
"""Update the weights inplace without re-launching the server."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment