Unverified Commit ed0a3dd5 authored by Zaili Wang, committed by GitHub

Enhancements for bench_one_batch (#8703)


Co-authored-by: root <root@gnr630186.jf.intel.com>
parent 2e901e89
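The diff below adds a --prompt-filename option for supplying custom prompts, a --profile-record-shapes flag for the Torch profiler, and separate prefill/decode trace files saved through a new _save_profile_trace_results helper. The new _read_prompts_from_file helper reads the prompt file with readlines(), so the expected format is one prompt per line. A minimal sketch of preparing such a file, for illustration only (the file name prompts.txt and the example prompts are illustrative, not part of the patch):

# Sketch: build a prompt file for --prompt-filename.
# _read_prompts_from_file() reads the file with readlines(), so the format
# is simply one prompt per line. The file name below is a placeholder.
prompts = [
    "The capital of France is",
    "Today is a sunny day and I like",
]
with open("prompts.txt", "w") as f:
    f.write("\n".join(prompts) + "\n")

The resulting file can then be passed via --prompt-filename prompts.txt, optionally together with --profile and --profile-record-shapes to record tensor shapes in the separate prefill and decode traces.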
@@ -43,6 +43,7 @@ I'm going to the park
"""
import argparse
import copy
import dataclasses
import itertools
import json
@@ -84,12 +85,14 @@ class BenchArgs:
batch_size: Tuple[int] = (1,)
input_len: Tuple[int] = (1024,)
output_len: Tuple[int] = (16,)
prompt_filename: str = ""
result_filename: str = "result.jsonl"
correctness_test: bool = False
# This is only used for correctness test
cut_len: int = 4
log_decode_step: int = 0
profile: bool = False
profile_record_shapes: bool = False
profile_filename_prefix: str = "profile"
@staticmethod
@@ -104,6 +107,9 @@ class BenchArgs:
parser.add_argument(
"--output-len", type=int, nargs="+", default=BenchArgs.output_len
)
parser.add_argument(
"--prompt-filename", type=str, default=BenchArgs.prompt_filename
)
parser.add_argument(
"--result-filename", type=str, default=BenchArgs.result_filename
)
@@ -118,6 +124,11 @@ class BenchArgs:
parser.add_argument(
"--profile", action="store_true", help="Use Torch Profiler."
)
parser.add_argument(
"--profile-record-shapes",
action="store_true",
help="Record tensor shapes in profiling results.",
)
parser.add_argument(
"--profile-filename-prefix",
type=str,
@@ -165,12 +176,16 @@ def load_model(server_args, port_args, tp_rank):
return model_runner, tokenizer
def prepare_inputs_for_correctness_test(bench_args, tokenizer):
prompts = [
"The capital of France is",
"The capital of the United Kindom is",
"Today is a sunny day and I like",
]
def prepare_inputs_for_correctness_test(bench_args, tokenizer, custom_prompts):
prompts = (
custom_prompts
if custom_prompts
else [
"The capital of France is",
"The capital of the United Kindom is",
"Today is a sunny day and I like",
]
)
input_ids = [tokenizer.encode(p) for p in prompts]
sampling_params = SamplingParams(
temperature=0,
@@ -211,8 +226,14 @@ def prepare_extend_inputs_for_correctness_test(
return reqs
def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
input_ids = np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
def prepare_synthetic_inputs_for_latency_test(
batch_size, input_len, custom_inputs=None
):
input_ids = (
custom_inputs
if custom_inputs
else np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
)
sampling_params = SamplingParams(
temperature=0,
max_new_tokens=BenchArgs.output_len,
@@ -284,6 +305,30 @@ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
)
def _read_prompts_from_file(prompt_file, rank_print):
"""Read custom prompts from the file specified by `--prompt-filename`."""
if not prompt_file:
return []
if not os.path.exists(prompt_file):
rank_print(
f"Custom prompt file {prompt_file} not found. Using default inputs..."
)
return []
with open(prompt_file, "r") as pf:
return pf.readlines()
def _save_profile_trace_results(profiler, filename):
parent_dir = os.path.dirname(os.path.abspath(filename))
os.makedirs(parent_dir, exist_ok=True)
profiler.export_chrome_trace(filename)
print(
profiler.key_averages(group_by_input_shape=True).table(
sort_by="self_cpu_time_total"
)
)
def correctness_test(
server_args,
port_args,
@@ -298,7 +343,10 @@ def correctness_test(
model_runner, tokenizer = load_model(server_args, port_args, tp_rank)
# Prepare inputs
input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
custom_prompts = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
input_ids, reqs = prepare_inputs_for_correctness_test(
bench_args, tokenizer, custom_prompts
)
rank_print(f"\n{input_ids=}\n")
if bench_args.cut_len > 0:
@@ -344,6 +392,7 @@ def latency_test_run_once(
device,
log_decode_step,
profile,
profile_record_shapes,
profile_filename_prefix,
):
max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
@@ -374,6 +423,7 @@ def latency_test_run_once(
torch.profiler.ProfilerActivity.CUDA,
],
with_stack=True,
record_shapes=profile_record_shapes,
)
profiler.start()
@@ -391,10 +441,30 @@ def latency_test_run_once(
measurement_results["prefill_latency"] = prefill_latency
measurement_results["prefill_throughput"] = throughput
if profile:
profiler.stop()
profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
_save_profile_trace_results(profiler, profile_filename)
rank_print(
f"torch profiler chrome trace for prefill saved to {profile_filename}"
)
# Decode
decode_latencies = []
for i in range(output_len - 1):
synchronize(device)
if profile and i == output_len / 2:
profiler = None
profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
with_stack=True,
record_shapes=profile_record_shapes,
)
profiler.start()
tic = time.perf_counter()
next_token_ids, _ = decode(next_token_ids, batch, model_runner)
synchronize(device)
@@ -407,13 +477,13 @@ def latency_test_run_once(
f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
)
if profile:
profiler.stop()
profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}.trace.json.gz"
parent_dir = os.path.dirname(os.path.abspath(profile_filename))
os.makedirs(parent_dir, exist_ok=True)
profiler.export_chrome_trace(profile_filename)
rank_print(f"torch profiler chrome trace saved to {profile_filename}")
if profile and i == output_len / 2:
profiler.stop()
profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
_save_profile_trace_results(profiler, profile_filename)
rank_print(
f"torch profiler chrome trace for decoding 1 token saved to {profile_filename}"
)
# Record decode timing from 2nd output
if output_len > 1:
@@ -469,17 +539,42 @@ def latency_test(
server_args.device,
log_decode_step=0,
profile=False,
profile_record_shapes=False,
profile_filename_prefix="", # not used
)
rank_print("Benchmark ...")
custom_inputs = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
custom_inputs = [tokenizer.encode(p.strip()) for p in custom_inputs]
custom_input_len = len(custom_inputs)
# Run the sweep
result_list = []
for bs, il, ol in itertools.product(
bench_args.batch_size, bench_args.input_len, bench_args.output_len
):
reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
bs_aligned_inputs = []
if custom_inputs:
if custom_input_len == bs:
bs_aligned_inputs = custom_inputs
elif custom_input_len > bs:
rank_print(
f"Custom input size ({custom_input_len}) is larger than batch_size ({bs}). "
f"Using the first {bs} prompts."
)
bs_aligned_inputs = copy.deepcopy(custom_inputs[:bs])
else:
rank_print(
f"Custom input size ({custom_input_len}) is smaller than batch_size ({bs}). "
f"Pad to the desired batch_size with the last prompt."
)
bs_aligned_inputs = copy.deepcopy(custom_inputs)
bs_aligned_inputs.extend(
[bs_aligned_inputs[-1]] * (bs - custom_input_len)
)
reqs = prepare_synthetic_inputs_for_latency_test(bs, il, bs_aligned_inputs)
ret = latency_test_run_once(
bench_args.run_name,
model_runner,
@@ -491,6 +586,7 @@ def latency_test(
server_args.device,
bench_args.log_decode_step,
bench_args.profile if tp_rank == 0 else None,
bench_args.profile_record_shapes if tp_rank == 0 else None,
bench_args.profile_filename_prefix,
)
if ret is not None:
......
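For reference, the batch-size alignment that latency_test applies to the custom prompts (truncate when there are more prompts than the batch size, pad with the last prompt when there are fewer) can be summarized as the sketch below. align_to_batch_size is a hypothetical helper name; in the patch this logic is inlined in latency_test().

import copy

# Sketch (paraphrasing the hunks above): align custom prompts to batch size bs.
def align_to_batch_size(custom_inputs, bs):
    n = len(custom_inputs)
    if n == bs:
        return custom_inputs                      # exactly bs prompts: use as-is
    if n > bs:
        return copy.deepcopy(custom_inputs[:bs])  # too many: keep the first bs
    padded = copy.deepcopy(custom_inputs)         # too few: repeat the last prompt
    padded.extend([padded[-1]] * (bs - n))
    return padded

# Example: two encoded prompts padded up to a batch of four.
print(align_to_batch_size([[1, 2], [3, 4, 5]], 4))
# -> [[1, 2], [3, 4, 5], [3, 4, 5], [3, 4, 5]]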