"src/vscode:/vscode.git/clone" did not exist on "69e72b1dd113927ed638f26e82738e9735385edc"
Unverified commit e00e5385, authored by Yun Dai, committed by GitHub

add profiling to bench_one_batch script (#2821)

parent a2f602b5
@@ -9,7 +9,8 @@ It accepts server arguments (the same as launch_server.py) and benchmark arguments
 python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
 ## sweep through multiple data points and store (append) the results in a jsonl file:
 python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run
+## run with profiling:
+python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --profile

 # Usage (correctness test):
 python -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
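The emitted traces are gzipped Chrome trace files, viewable in chrome://tracing or https://ui.perfetto.dev. As a quick sanity check they can also be inspected from Python; the file name below is hypothetical, assembled from the pattern documented in --profile-filename-prefix:

```python
import gzip
import json

# Hypothetical file name following the documented pattern:
# [profile_filename_prefix]_batch[batch_size]_input[input_len]_output[output_len].trace.json.gz
trace_path = "profile_batch1_input256_output32.trace.json.gz"

with gzip.open(trace_path, "rt") as f:
    trace = json.load(f)

# Chrome traces keep their events under "traceEvents" (or as a bare list).
events = trace["traceEvents"] if isinstance(trace, dict) else trace
print(f"{len(events)} events; first event keys: {sorted(events[0])}")
```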
@@ -77,6 +78,8 @@ class BenchArgs:
     correctness_test: bool = False
     # This is only used for correctness test
     cut_len: int = 4
+    profile: bool = False
+    profile_filename_prefix: str = "profile"

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -95,6 +98,19 @@ class BenchArgs:
         )
         parser.add_argument("--correctness-test", action="store_true")
         parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
+        parser.add_argument(
+            "--profile",
+            action="store_true",
+            help="Profile the prefill and decode passes with the Torch Profiler "
+            "and export one Chrome trace per run.",
+        )
+        parser.add_argument(
+            "--profile-filename-prefix",
+            type=str,
+            default=BenchArgs.profile_filename_prefix,
+            help="Prefix of the profiling file names. The full profiling result file(s) will be "
+            '"[profile_filename_prefix]_batch[batch_size]_input[input_len]_output[output_len].trace.json.gz"',
+        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
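`from_cli_args` typically mirrors `add_cli_args` by reading each dataclass field back off the parsed namespace. A minimal, self-contained sketch of that round trip (an assumed pattern, not the repository's exact code):

```python
import argparse
import dataclasses


@dataclasses.dataclass
class BenchArgsSketch:
    profile: bool = False
    profile_filename_prefix: str = "profile"

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        # store_true yields False unless the flag is given, matching the field default.
        parser.add_argument("--profile", action="store_true")
        parser.add_argument(
            "--profile-filename-prefix",
            type=str,
            default=BenchArgsSketch.profile_filename_prefix,
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        # Pick every dataclass field off the namespace by name
        # (argparse maps --profile-filename-prefix to profile_filename_prefix).
        names = [f.name for f in dataclasses.fields(cls)]
        return cls(**{name: getattr(args, name) for name in names})


parser = argparse.ArgumentParser()
BenchArgsSketch.add_cli_args(parser)
print(BenchArgsSketch.from_cli_args(parser.parse_args(["--profile"])))
```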
@@ -286,7 +302,16 @@ def synchronize(device):
 def latency_test_run_once(
-    run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len, device
+    run_name,
+    model_runner,
+    rank_print,
+    reqs,
+    batch_size,
+    input_len,
+    output_len,
+    device,
+    profile,
+    profile_filename_prefix,
 ):
     max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
     if batch_size > max_batch_size:
@@ -308,6 +333,17 @@ def latency_test_run_once(
     tot_latency = 0

+    profiler = None
+    if profile:
+        profiler = torch.profiler.profile(
+            activities=[
+                torch.profiler.ProfilerActivity.CPU,
+                torch.profiler.ProfilerActivity.CUDA,
+            ],
+            with_stack=True,
+        )
+        profiler.start()
+
     # Prefill
     synchronize(device)
     tic = time.time()
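The explicit start()/stop() API is used here instead of the usual `with torch.profiler.profile(...)` context manager because the profiled region spans the separately timed prefill and decode phases. A self-contained sketch of the same pattern, with the CUDA activity made conditional so it also runs on CPU-only machines:

```python
import torch

activities = [torch.profiler.ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(torch.profiler.ProfilerActivity.CUDA)

# with_stack=True records Python source stacks per op, at some extra overhead.
profiler = torch.profiler.profile(activities=activities, with_stack=True)
profiler.start()

x = torch.randn(1024, 1024)
y = x @ x  # stand-in for the prefill/decode work

profiler.stop()
# export_chrome_trace gzips its output when the path ends in .gz.
profiler.export_chrome_trace("example.trace.json.gz")
```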
@@ -338,6 +374,13 @@ def latency_test_run_once(
                 f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
             )

+    if profile:
+        profiler.stop()
+        profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}.trace.json.gz"
+        parent_dir = os.path.dirname(os.path.abspath(profile_filename))
+        os.makedirs(parent_dir, exist_ok=True)
+        profiler.export_chrome_trace(profile_filename)
+
     # Record decode timing from 2nd output
     if output_len > 1:
         med_decode_latency = np.median(decode_latencies)
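Calling `os.makedirs` on the dirname of the absolute path lets the prefix carry a directory component (e.g. `traces/run1`) that may not exist yet. Beyond the Chrome trace export, the same profiler object can also print an operator-level summary before it is discarded; a small self-contained sketch (not part of this patch):

```python
import torch

profiler = torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU],
)
profiler.start()
torch.randn(512, 512) @ torch.randn(512, 512)
profiler.stop()

# Aggregate per-operator stats; use sort_by="cuda_time_total" on GPU runs.
print(profiler.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```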
@@ -386,6 +429,8 @@ def latency_test(
         bench_args.input_len[0],
         8,  # shorter decoding to speed up the warmup
         server_args.device,
+        profile=False,
+        profile_filename_prefix="",  # not used
     )

     rank_print("Benchmark ...")
@@ -405,6 +450,8 @@ def latency_test(
             il,
             ol,
             server_args.device,
+            bench_args.profile,
+            bench_args.profile_filename_prefix,
         )
         if ret is not None:
             result_list.append(ret)
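At the call sites, the warmup pass hard-codes profile=False so warmup never pollutes the traces, while the benchmark loop forwards the user's settings; each (batch_size, input_len, output_len) combination therefore gets its own trace file. A minimal sketch of that sweep shape (hypothetical stand-in names; the real loop lives in latency_test):

```python
import itertools

batch_sizes, input_lens, output_lens = [1, 12, 14], [256, 512], [32]


def run_once(bs, il, ol, profile, profile_filename_prefix):
    # Stand-in for latency_test_run_once: with profile=True, each combination
    # would export {prefix}_batch{bs}_input{il}_output{ol}.trace.json.gz.
    print(bs, il, ol, profile, profile_filename_prefix)


# Warmup: never profiled, shorter decode.
run_once(batch_sizes[0], input_lens[0], 8, profile=False, profile_filename_prefix="")

# Sweep: one run (and, with --profile, one trace) per combination.
for bs, il, ol in itertools.product(batch_sizes, input_lens, output_lens):
    run_once(bs, il, ol, profile=True, profile_filename_prefix="profile")
```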