Unverified Commit bbd72bfc authored by 科英's avatar 科英 Committed by GitHub
Browse files

Add the ability to enable and disable the Profiler via HTTP API. (#1626)

parent b503881b
...@@ -20,6 +20,7 @@ processes (TokenizerManager, DetokenizerManager, Controller). ...@@ -20,6 +20,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
import uuid import uuid
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from sglang.srt.managers.schedule_batch import BaseFinishReason from sglang.srt.managers.schedule_batch import BaseFinishReason
...@@ -343,3 +344,8 @@ class UpdateWeightReqOutput: ...@@ -343,3 +344,8 @@ class UpdateWeightReqOutput:
class AbortReq: class AbortReq:
# The request id # The request id
rid: str rid: str
# Control message sent from the TokenizerManager to the Scheduler to
# toggle torch profiling. Functional Enum API: members get values 1 and 2,
# exactly as the explicit class-body assignments would.
ProfileReq = Enum("ProfileReq", ["START_PROFILE", "STOP_PROFILE"])
...@@ -37,6 +37,7 @@ from sglang.srt.managers.io_struct import ( ...@@ -37,6 +37,7 @@ from sglang.srt.managers.io_struct import (
BatchEmbeddingOut, BatchEmbeddingOut,
BatchTokenIDOut, BatchTokenIDOut,
FlushCacheReq, FlushCacheReq,
ProfileReq,
TokenizedEmbeddingReqInput, TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput, TokenizedGenerateReqInput,
TokenizedRewardReqInput, TokenizedRewardReqInput,
...@@ -229,6 +230,22 @@ class Scheduler: ...@@ -229,6 +230,22 @@ class Scheduler:
self.new_token_ratio_decay = global_config.new_token_ratio_decay self.new_token_ratio_decay = global_config.new_token_ratio_decay
self.batch_is_full = False self.batch_is_full = False
if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
self.profiler = None
else:
self.torch_profiler_trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR")
logger.info(
"Profiling enabled. Traces will be saved to: %s",
self.torch_profiler_trace_dir,
)
self.profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
with_stack=True,
)
@torch.inference_mode() @torch.inference_mode()
def event_loop(self): def event_loop(self):
while True: while True:
...@@ -271,6 +288,11 @@ class Scheduler: ...@@ -271,6 +288,11 @@ class Scheduler:
elif isinstance(recv_req, UpdateWeightReqInput): elif isinstance(recv_req, UpdateWeightReqInput):
success, message = self.update_weights(recv_req) success, message = self.update_weights(recv_req)
self.out_pyobjs.append(UpdateWeightReqOutput(success, message)) self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
elif isinstance(recv_req, ProfileReq):
if recv_req == ProfileReq.START_PROFILE:
self.start_profile()
else:
self.stop_profile()
else: else:
raise ValueError(f"Invalid request: {recv_req}") raise ValueError(f"Invalid request: {recv_req}")
...@@ -1000,6 +1022,20 @@ class Scheduler: ...@@ -1000,6 +1022,20 @@ class Scheduler:
logger.error(message) logger.error(message)
return success, message return success, message
def start_profile(self) -> None:
    """Begin recording a torch profiler trace.

    Raises:
        RuntimeError: if profiling was not enabled at startup
            (i.e. SGLANG_TORCH_PROFILER_DIR was not set, so
            self.profiler is None).
    """
    profiler = self.profiler
    if profiler is None:
        raise RuntimeError("Profiler is not enabled.")
    profiler.start()
def stop_profile(self) -> None:
    """Stop the torch profiler and export the collected trace.

    The trace is written as a gzipped Chrome trace (torch gzips
    automatically for a ``.gz`` suffix) named by the current Unix
    timestamp, under SGLANG_TORCH_PROFILER_DIR.

    Raises:
        RuntimeError: if profiling was not enabled at startup
            (self.profiler is None).
    """
    if self.profiler is None:
        raise RuntimeError("Profiler is not enabled.")
    self.profiler.stop()
    # os.path.join instead of string concatenation: handles a trailing
    # slash in the configured directory cleanly.
    trace_path = os.path.join(
        self.torch_profiler_trace_dir, str(time.time()) + ".trace.json.gz"
    )
    self.profiler.export_chrome_trace(trace_path)
    logger.info("Profiler is done")
def run_scheduler_process( def run_scheduler_process(
server_args: ServerArgs, server_args: ServerArgs,
......
...@@ -46,6 +46,7 @@ from sglang.srt.managers.io_struct import ( ...@@ -46,6 +46,7 @@ from sglang.srt.managers.io_struct import (
EmbeddingReqInput, EmbeddingReqInput,
FlushCacheReq, FlushCacheReq,
GenerateReqInput, GenerateReqInput,
ProfileReq,
RewardReqInput, RewardReqInput,
TokenizedEmbeddingReqInput, TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput, TokenizedGenerateReqInput,
...@@ -512,6 +513,14 @@ class TokenizerManager: ...@@ -512,6 +513,14 @@ class TokenizerManager:
req = AbortReq(rid) req = AbortReq(rid)
self.send_to_scheduler.send_pyobj(req) self.send_to_scheduler.send_pyobj(req)
def start_profile(self):
    """Ask the scheduler process to begin torch profiling."""
    self.send_to_scheduler.send_pyobj(ProfileReq.START_PROFILE)
def stop_profile(self):
    """Ask the scheduler process to stop profiling and export the trace."""
    self.send_to_scheduler.send_pyobj(ProfileReq.STOP_PROFILE)
async def update_weights( async def update_weights(
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
): ):
......
...@@ -145,6 +145,28 @@ async def flush_cache(): ...@@ -145,6 +145,28 @@ async def flush_cache():
) )
@app.get("/start_profile")
@app.post("/start_profile")
async def start_profile():
    """Start profiling."""
    # Fire-and-forget: the request is forwarded to the scheduler process
    # and we acknowledge immediately.
    tokenizer_manager.start_profile()
    return Response(content="Start profiling.\n", status_code=200)
@app.get("/stop_profile")
@app.post("/stop_profile")
async def stop_profile():
    """Stop profiling."""
    # Trace export happens asynchronously in the scheduler process, hence
    # the "take some time" note in the acknowledgement.
    tokenizer_manager.stop_profile()
    return Response(content="Stop profiling. This will take some time.\n", status_code=200)
@app.post("/update_weights") @app.post("/update_weights")
async def update_weights(obj: UpdateWeightReqInput, request: Request): async def update_weights(obj: UpdateWeightReqInput, request: Request):
"""Update the weights inplace without re-launching the server.""" """Update the weights inplace without re-launching the server."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment