Unverified Commit 91e5dbf5 authored by bjmsong, committed by GitHub

add profile in offline benchmark & update doc (#2123)


Co-authored-by: root <bjmsong@126.com>
parent dd5eba4c
@@ -56,3 +56,22 @@ with nvtx.annotate("description", color="color"):
 
 ## Other tips
 1. You can benchmark a model using dummy weights by only providing the config.json file. This allows for quick testing of model variants without training. To do so, add `--load-format dummy` to the above commands and then you only need a correct `config.json` under the checkpoint folder.
+
+## Profile with PyTorch Profiler
+- To profile a server:
+```bash
+# set trace path
+export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log
+
+# start server
+python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct
+
+# send profiling requests from the client
+python -m sglang.bench_serving --backend sglang --model-path meta-llama/Llama-3.1-8B-Instruct --num-prompts 10 --profile
+```
+Traces can be visualized using https://ui.perfetto.dev/.
+- To profile offline:
+```bash
+export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log
+python -m sglang.bench_offline_throughput --model-path meta-llama/Llama-3.1-8B-Instruct --dataset-name random --num-prompts 10 --profile --mem-frac=0.8
+```
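The documentation above only sets `SGLANG_TORCH_PROFILER_DIR` and passes `--profile`; the profiler session itself runs on the server side and is not shown in this diff. As a rough, hedged sketch of the pattern such an integration typically follows (generic `torch.profiler` usage, not SGLang's actual implementation), a Chrome-format trace that Perfetto can open might be produced like this:

```python
# Hedged sketch: generic torch.profiler usage that yields a Chrome-format
# trace viewable at https://ui.perfetto.dev/. This is NOT SGLang's server-side
# code; only the trace-directory convention mirrors SGLANG_TORCH_PROFILER_DIR.
import os
import time

import torch
from torch.profiler import ProfilerActivity, profile

trace_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp/profile_log")
os.makedirs(trace_dir, exist_ok=True)

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, with_stack=True) as prof:
    # Stand-in for the region being profiled (e.g. a model forward pass).
    x = torch.randn(1024, 1024)
    y = x @ x
    y.sum().item()

# Write the trace where the monitoring side expects to find it.
prof.export_chrome_trace(os.path.join(trace_dir, f"trace_{int(time.time())}.json"))
```

A trace file appearing under the trace directory is what the `monitor_trace_file` helper added below polls for.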
@@ -14,6 +14,7 @@ import argparse
 import dataclasses
 import json
 import logging
+import os
 import random
 import time
 from typing import Dict, List, Optional, Tuple
@@ -27,7 +28,7 @@ from sglang.bench_serving import (
     sample_random_requests,
     set_ulimit,
 )
-from sglang.srt.server import Runtime
+from sglang.srt.server import Runtime, start_profile, stop_profile
 from sglang.srt.server_args import ServerArgs
@@ -52,6 +53,7 @@ class BenchArgs:
     seed: int = 1
     skip_warmup: bool = False
     do_not_exit: bool = False
+    profile: bool = False
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -156,6 +158,12 @@ class BenchArgs:
             action="store_true",
             help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
         )
+        parser.add_argument(
+            "--profile",
+            action="store_true",
+            help="Use Torch Profiler. The endpoint must be launched with "
+            "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
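For context, `BenchArgs` follows a common dataclass-plus-argparse pattern: every field carries a default, `add_cli_args` mirrors it as a CLI flag, and `from_cli_args` copies the parsed namespace back into the dataclass. A minimal, self-contained sketch of that pattern (the field names besides `profile` are illustrative, not the real `BenchArgs` fields):

```python
# Hedged sketch of the dataclass-plus-argparse pattern used by BenchArgs.
# Field names here are illustrative; only --profile comes from this commit.
import argparse
import dataclasses


@dataclasses.dataclass
class DemoArgs:
    num_prompts: int = 10
    profile: bool = False

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        parser.add_argument("--num-prompts", type=int, default=DemoArgs.num_prompts)
        # store_true flags default to False, matching the dataclass default.
        parser.add_argument("--profile", action="store_true")

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        # Copy every dataclass field back out of the parsed namespace.
        attrs = [f.name for f in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})


parser = argparse.ArgumentParser()
DemoArgs.add_cli_args(parser)
print(DemoArgs.from_cli_args(parser.parse_args(["--profile"])))
# -> DemoArgs(num_prompts=10, profile=True)
```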
@@ -169,6 +177,7 @@ def throughput_test_once(
     reqs: List[Tuple[str, int, int]],
     ignore_eos: bool,
     extra_request_body: Dict,
+    profile: bool,
 ):
     measurement_results = {
         "backend": backend_name,
@@ -194,7 +203,15 @@ def throughput_test_once(
     ]
 
     st = time.perf_counter()
+    if profile:
+        start_profile()
+
     gen_out = backend.generate(prompt=prompt, sampling_params=sampling_params)
+
+    if profile:
+        stop_profile()
+        monitor_trace_file(os.getenv("SGLANG_TORCH_PROFILER_DIR"))
+
     latency = time.perf_counter() - st
 
     if backend_name == "runtime":
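The hunk above brackets the single `backend.generate` call with `start_profile()` / `stop_profile()` guarded by the flag. An alternative, not what this commit does, is to fold the bracketing into a small context manager so the timed region stays flat; a hedged sketch:

```python
# Hedged alternative sketch (not part of this PR): express the conditional
# start/stop bracketing as a context manager.
from contextlib import contextmanager


@contextmanager
def maybe_profile(enabled: bool, start, stop):
    """Call start()/stop() around the block only when enabled is True."""
    if enabled:
        start()
    try:
        yield
    finally:
        if enabled:
            stop()


# Usage, assuming the start_profile/stop_profile imported in this file:
# with maybe_profile(profile, start_profile, stop_profile):
#     gen_out = backend.generate(prompt=prompt, sampling_params=sampling_params)
```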
@@ -221,6 +238,41 @@ def throughput_test_once(
     return measurement_results
 
 
+def monitor_trace_file(directory, interval=1):
+    """Block until a newly created trace file in `directory` stops growing,
+    i.e. the profiler has finished writing it to disk."""
+    print(f"Monitoring {directory} for new trace files...")
+
+    known_files = set(os.listdir(directory))
+
+    while True:
+        flag = False
+        time.sleep(interval)
+        current_files = set(os.listdir(directory))
+        new_files = current_files - known_files
+        for new_file in new_files:
+            new_file_path = os.path.join(directory, new_file)
+            print(f"New file detected: {new_file}")
+            previous_size = 0
+            while True:
+                try:
+                    current_size = os.path.getsize(new_file_path)
+                except FileNotFoundError:
+                    print(f"File {new_file} is no longer accessible.")
+                    break
+
+                # The trace is considered complete once its size stops growing
+                # between two consecutive polls.
+                if current_size > previous_size:
+                    previous_size = current_size
+                else:
+                    flag = True
+                    break
+
+                time.sleep(interval)
+        if flag:
+            break
+
+
 def throughput_test(
     server_args: ServerArgs,
     bench_args: BenchArgs,
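`monitor_trace_file` keeps the benchmark process alive until the profiler's trace file appears and stops growing, presumably because the server writes the trace asynchronously after `stop_profile()` returns. One possible hardening, not part of this commit, is to bound the wait with a timeout so a missing trace cannot hang the run; an illustrative variant:

```python
# Hedged variant (illustrative only): the same "wait until the trace file
# stops growing" idea, but with an overall timeout.
import os
import time


def wait_for_stable_file(directory: str, interval: float = 1.0, timeout: float = 300.0):
    """Return the path of the first new file in `directory` whose size stops
    changing between two polls, or None if `timeout` seconds elapse first."""
    known = set(os.listdir(directory))
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        time.sleep(interval)
        for name in set(os.listdir(directory)) - known:
            path = os.path.join(directory, name)
            prev = -1
            while time.monotonic() < deadline:
                try:
                    size = os.path.getsize(path)
                except FileNotFoundError:
                    break  # file vanished; keep scanning the directory
                if size == prev:
                    return path
                prev = size
                time.sleep(interval)
    return None
```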
@@ -268,6 +320,7 @@ def throughput_test(
             reqs=warmup_requests,
             ignore_eos=not bench_args.disable_ignore_eos,
             extra_request_body=extra_request_body,
+            profile=False,
         )
 
     logging.info("\nBenchmark...")
@@ -277,6 +330,7 @@ def throughput_test(
         reqs=input_requests,
         ignore_eos=not bench_args.disable_ignore_eos,
         extra_request_body=extra_request_body,
+        profile=bench_args.profile,
     )
 
     if bench_args.result_filename:
@@ -169,9 +169,19 @@ async def flush_cache():
     )
 
 
+def start_profile():
+    """Start profiling."""
+    tokenizer_manager.start_profile()
+
+
+def stop_profile():
+    """Stop profiling."""
+    tokenizer_manager.stop_profile()
+
+
 @app.get("/start_profile")
 @app.post("/start_profile")
-async def start_profile():
+async def start_profile_async():
     """Start profiling."""
     tokenizer_manager.start_profile()
     return Response(
@@ -182,7 +192,7 @@ async def start_profile():
 @app.get("/stop_profile")
 @app.post("/stop_profile")
-async def stop_profile():
+async def stop_profile_async():
     """Stop profiling."""
     tokenizer_manager.stop_profile()
     return Response(
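The renamed async handlers above keep the `/start_profile` and `/stop_profile` HTTP routes intact while exposing plain `start_profile()` / `stop_profile()` functions for in-process callers such as `bench_offline_throughput`. For the HTTP path, a hedged sketch of driving the endpoints by hand (the base URL is an assumption matching a locally launched server, not something this diff specifies):

```python
# Hedged sketch: driving the profiling endpoints shown above manually.
# The base URL assumes a locally launched server; adjust host/port to match
# how sglang.launch_server was started. SGLANG_TORCH_PROFILER_DIR must be set
# in the server's environment for traces to be written.
import requests

BASE_URL = "http://localhost:30000"  # assumed local server address

requests.post(f"{BASE_URL}/start_profile").raise_for_status()
# ... send the requests you want captured ...
requests.post(f"{BASE_URL}/stop_profile").raise_for_status()
# The trace ends up under SGLANG_TORCH_PROFILER_DIR on the server side.
```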