Unverified commit 6c1a3f0c authored by lpc0220, committed by GitHub

enable cudaProfilerApi for one batch benchmarking (#11116)

parent 62377548
@@ -11,6 +11,11 @@ python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruc
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run
## run with profiling:
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --profile
## run with profiling to custom directory:
export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log
python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 --input-len 256 --profile
## run with CUDA profiler (nsys):
nsys profile --force-overwrite=true -o bench_one_batch python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 --input-len 256 --profile --profiler_activities CUDA_PROFILER
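## Since CUDA_PROFILER mode brackets the run with cudaProfilerStart/Stop, nsys can
## restrict its capture to that window (a possible variant; flag support depends on
## your nsys version):
nsys profile --capture-range=cudaProfilerApi --force-overwrite=true -o bench_one_batch python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 --input-len 256 --profile --profiler_activities CUDA_PROFILER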
# Usage (correctness test):
python -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
@@ -93,6 +98,68 @@ profile_activities = [torch.profiler.ProfilerActivity.CPU] + [
]
def start_profile(profiler_activities, profile_record_shapes=False, rank_print=print):
    """
    Start profiling based on profiler_activities.

    Returns the profiler object (or None).
    """
    if "CUDA_PROFILER" in profiler_activities:
        try:
            torch.cuda.cudart().cudaProfilerStart()
            rank_print("CUDA Profiler started (nsys will begin capturing)")
        except Exception as e:
            rank_print(f"Failed to start CUDA profiler: {e}")
        return None
    else:
        activities = []
        if "CPU" in profiler_activities:
            activities.append(torch.profiler.ProfilerActivity.CPU)
        if "GPU" in profiler_activities:
            activities.append(torch.profiler.ProfilerActivity.CUDA)
        if activities:
            profiler = torch.profiler.profile(
                activities=activities,
                with_stack=True,
                record_shapes=profile_record_shapes,
            )
            profiler.start()
            return profiler
        return None
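The raw runtime call above also has a documented wrapper in torch.cuda.profiler; a variant of this helper could bracket the region of interest as in the sketch below (illustrative only, not what this patch uses):

import torch

# Sketch: torch's wrapper around the CUDA profiler API, equivalent to
# calling cudaProfilerStart()/cudaProfilerStop() through cudart() directly.
torch.cuda.profiler.start()
# ... region of interest: kernels launched here are captured by nsys ...
torch.cuda.profiler.stop()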
def stop_profile(
    profiler,
    profiler_activities,
    rank_print=print,
    save_trace=False,
    trace_filename=None,
    stage=None,
):
    """
    Stop profiling based on profiler_activities.

    Optionally saves trace results and prints completion messages.
    """
    if "CUDA_PROFILER" in profiler_activities:
        try:
            torch.cuda.cudart().cudaProfilerStop()
            rank_print("CUDA Profiler stopped (nsys should dump traces)")
        except Exception as e:
            rank_print(f"Failed to stop CUDA profiler: {e}")
    elif profiler is not None:
        profiler.stop()

    if save_trace:
        if profiler is not None and trace_filename:
            _save_profile_trace_results(profiler, trace_filename)
            stage_desc = f"for {stage}" if stage else ""
            rank_print(
                f"torch profiler chrome trace {stage_desc} saved to {trace_filename}"
            )
        if "CUDA_PROFILER" in profiler_activities:
            rank_print(f"CUDA profiler trace for {stage} completed")
@dataclasses.dataclass
class BenchArgs:
    run_name: str = "default"
@@ -107,6 +174,8 @@ class BenchArgs:
    log_decode_step: int = 0
    profile: bool = False
    profile_record_shapes: bool = False
    profiler_activities: Tuple[str, ...] = ("CPU", "GPU")
    profile_stage: str = "all"
    profile_filename_prefix: str = "profile"

    @staticmethod
@@ -135,14 +204,27 @@ class BenchArgs:
            default=BenchArgs.log_decode_step,
            help="Log decode latency by step; default 0 disables it.",
        )
        parser.add_argument("--profile", action="store_true", help="Enable profiling.")
        parser.add_argument(
            "--profile-record-shapes",
            action="store_true",
            help="Record tensor shapes in profiling results.",
        )
        parser.add_argument(
            "--profiler_activities",
            type=str,
            nargs="+",
            default=["CPU", "GPU"],
            choices=["CPU", "GPU", "CUDA_PROFILER"],
            help="Profiler activities. CPU and GPU select the torch profiler; CUDA_PROFILER uses the CUDA profiler API (e.g. for nsys).",
        )
        parser.add_argument(
            "--profile-stage",
            type=str,
            default=BenchArgs.profile_stage,
            choices=["all", "prefill", "decode"],
            help="Which stage to profile: all, prefill, or decode only.",
        )
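As a quick illustration (hypothetical argv), the new flags parse as follows; note that argparse keeps the underscore spelling profiler_activities as the attribute name:

# Hypothetical parse of the new profiling flags.
args = parser.parse_args(
    ["--profile", "--profiler_activities", "CUDA_PROFILER", "--profile-stage", "decode"]
)
assert args.profiler_activities == ["CUDA_PROFILER"]
assert args.profile_stage == "decode"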
        parser.add_argument(
            "--profile-filename-prefix",
            type=str,
@@ -337,6 +419,18 @@ def _read_prompts_from_file(prompt_file, rank_print):
        return pf.readlines()


def _get_torch_profiler_output_dir():
    return os.environ.get("SGLANG_TORCH_PROFILER_DIR", "/tmp")


def _create_torch_profiler_filename(
    profile_filename_prefix, batch_size, input_len, output_len, stage
):
    output_dir = _get_torch_profiler_output_dir()
    filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_{stage}.trace.json.gz"
    return os.path.join(output_dir, filename)
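For example (hypothetical arguments, SGLANG_TORCH_PROFILER_DIR unset), the helper produces:

# Returns "/tmp/profile_batch1_input256_output32_prefill.trace.json.gz".
path = _create_torch_profiler_filename("profile", 1, 256, 32, "prefill")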
def _save_profile_trace_results(profiler, filename):
    parent_dir = os.path.dirname(os.path.abspath(filename))
    os.makedirs(parent_dir, exist_ok=True)
@@ -413,7 +507,10 @@
    log_decode_step,
    profile,
    profile_record_shapes,
    profiler_activities,
    profile_filename_prefix,
    profile_stage,
    tp_rank,
):
    max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
    if batch_size > max_batch_size:
@@ -422,7 +519,6 @@
        )
        return

    # Clear the pools.
    model_runner.req_to_token_pool.clear()
    model_runner.token_to_kv_pool_allocator.clear()
@@ -436,20 +532,33 @@
    tot_latency = 0

    profiler = None
    enable_profile_prefill = profile and profile_stage in ["all", "prefill"]
    if enable_profile_prefill:
        profiler = start_profile(
            profiler_activities,
            profile_record_shapes=profile_record_shapes,
            rank_print=rank_print,
        )
    # Prefill
    synchronize(device)
    tic = time.perf_counter()
    next_token_ids, _, batch = extend(reqs, model_runner)
    synchronize(device)
    prefill_latency = time.perf_counter() - tic

    if enable_profile_prefill:
        trace_filename = _create_torch_profiler_filename(
            profile_filename_prefix, batch_size, input_len, output_len, "prefill"
        )
        stop_profile(
            profiler,
            profiler_activities,
            rank_print=rank_print,
            save_trace=True,
            trace_filename=trace_filename,
            stage="prefill",
        )

    tot_latency += prefill_latency
    throughput = input_len * batch_size / prefill_latency
    rank_print(
@@ -458,29 +567,37 @@
measurement_results["prefill_latency"] = prefill_latency
measurement_results["prefill_throughput"] = throughput
if profile:
profiler.stop()
trace_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
_save_profile_trace_results(profiler, trace_filename)
rank_print(f"torch profiler chrome trace for prefill saved to {trace_filename}")
    # Decode
    decode_latencies = []
    profile_step_of_interest = output_len // 2
    enable_profile_decode = profile and profile_stage in ["all", "decode"]
    for i in range(output_len - 1):
        synchronize(device)
        profiler = None
        if enable_profile_decode and i == profile_step_of_interest:
            profiler = start_profile(
                profiler_activities,
                profile_record_shapes=profile_record_shapes,
                rank_print=rank_print,
            )

        tic = time.perf_counter()
        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
        synchronize(device)
        latency = time.perf_counter() - tic

        if enable_profile_decode and i == profile_step_of_interest:
            trace_filename = _create_torch_profiler_filename(
                profile_filename_prefix, batch_size, input_len, output_len, "decode"
            )
            stop_profile(
                profiler,
                profiler_activities,
                rank_print=rank_print,
                save_trace=True,
                trace_filename=trace_filename,
                stage="decode",
            )

        tot_latency += latency
        throughput = batch_size / latency
        decode_latencies.append(latency)
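A worked instance of the step selection above (hypothetical run): with --output-len 32 the loop executes 31 decode steps, and only step 32 // 2 == 16 is bracketed by the profiler, so the saved decode trace covers exactly one steady-state iteration.

# Hypothetical check of which decode step gets profiled for output_len = 32.
output_len = 32
profile_step_of_interest = output_len // 2  # == 16
assert profile_step_of_interest in range(output_len - 1)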
@@ -489,14 +606,6 @@
f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
)
if profile and i == output_len / 2:
profiler.stop()
trace_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
_save_profile_trace_results(profiler, trace_filename)
rank_print(
f"torch profiler chrome trace for decoding 1 token saved to {trace_filename}"
)
# Record decode timing from 2nd output
if output_len > 1:
med_decode_latency = np.median(decode_latencies)
@@ -557,7 +666,10 @@
        log_decode_step=0,
        profile=False,
        profile_record_shapes=False,
        profiler_activities=("CPU", "GPU"),
        profile_filename_prefix="",
        profile_stage="all",
        tp_rank=tp_rank,
    )

    rank_print("Benchmark ...")
@@ -604,7 +716,10 @@
            bench_args.log_decode_step,
            bench_args.profile if tp_rank == 0 else None,
            bench_args.profile_record_shapes if tp_rank == 0 else None,
            bench_args.profiler_activities,
            bench_args.profile_filename_prefix,
            bench_args.profile_stage,
            tp_rank,
        )
        if ret is not None:
            result_list.append(ret)