Unverified Commit ed0a3dd5 authored by Zaili Wang, committed by GitHub

Enhancements for bench_one_batch (#8703)


Co-authored-by: root <root@gnr630186.jf.intel.com>
parent 2e901e89
@@ -43,6 +43,7 @@ I'm going to the park
 """
 
 import argparse
+import copy
 import dataclasses
 import itertools
 import json
@@ -84,12 +85,14 @@ class BenchArgs:
     batch_size: Tuple[int] = (1,)
     input_len: Tuple[int] = (1024,)
     output_len: Tuple[int] = (16,)
+    prompt_filename: str = ""
     result_filename: str = "result.jsonl"
     correctness_test: bool = False
     # This is only used for correctness test
     cut_len: int = 4
     log_decode_step: int = 0
     profile: bool = False
+    profile_record_shapes: bool = False
     profile_filename_prefix: str = "profile"
 
     @staticmethod
@@ -104,6 +107,9 @@ class BenchArgs:
         parser.add_argument(
             "--output-len", type=int, nargs="+", default=BenchArgs.output_len
         )
+        parser.add_argument(
+            "--prompt-filename", type=str, default=BenchArgs.prompt_filename
+        )
         parser.add_argument(
             "--result-filename", type=str, default=BenchArgs.result_filename
         )
@@ -118,6 +124,11 @@ class BenchArgs:
         parser.add_argument(
             "--profile", action="store_true", help="Use Torch Profiler."
         )
+        parser.add_argument(
+            "--profile-record-shapes",
+            action="store_true",
+            help="Record tensor shapes in profiling results.",
+        )
         parser.add_argument(
             "--profile-filename-prefix",
             type=str,
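The two new options are independent: --prompt-filename feeds real prompts into both the correctness and latency tests, while --profile-record-shapes is only consulted when --profile is set and is forwarded to torch.profiler.profile(). A minimal standalone sketch of the parsing (not part of the commit; the prompts.txt path is illustrative):

import argparse

# Hypothetical stand-in for BenchArgs.add_cli_args, mirroring the flags added above.
parser = argparse.ArgumentParser()
parser.add_argument("--prompt-filename", type=str, default="")
parser.add_argument("--profile", action="store_true")
parser.add_argument("--profile-record-shapes", action="store_true")

args = parser.parse_args(
    ["--prompt-filename", "prompts.txt", "--profile", "--profile-record-shapes"]
)
print(args.prompt_filename, args.profile, args.profile_record_shapes)  # prompts.txt True True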
@@ -165,12 +176,16 @@ def load_model(server_args, port_args, tp_rank):
     return model_runner, tokenizer
 
 
-def prepare_inputs_for_correctness_test(bench_args, tokenizer):
-    prompts = [
-        "The capital of France is",
-        "The capital of the United Kindom is",
-        "Today is a sunny day and I like",
-    ]
+def prepare_inputs_for_correctness_test(bench_args, tokenizer, custom_prompts):
+    prompts = (
+        custom_prompts
+        if custom_prompts
+        else [
+            "The capital of France is",
+            "The capital of the United Kindom is",
+            "Today is a sunny day and I like",
+        ]
+    )
     input_ids = [tokenizer.encode(p) for p in prompts]
     sampling_params = SamplingParams(
         temperature=0,
@@ -211,8 +226,14 @@ def prepare_extend_inputs_for_correctness_test(
     return reqs
 
 
-def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
-    input_ids = np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
+def prepare_synthetic_inputs_for_latency_test(
+    batch_size, input_len, custom_inputs=None
+):
+    input_ids = (
+        custom_inputs
+        if custom_inputs
+        else np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
+    )
     sampling_params = SamplingParams(
         temperature=0,
         max_new_tokens=BenchArgs.output_len,
@@ -284,6 +305,30 @@ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
     )
 
 
+def _read_prompts_from_file(prompt_file, rank_print):
+    """Read custom prompts from the file specified by `--prompt-filename`."""
+    if not prompt_file:
+        return []
+    if not os.path.exists(prompt_file):
+        rank_print(
+            f"Custom prompt file {prompt_file} not found. Using default inputs..."
+        )
+        return []
+    with open(prompt_file, "r") as pf:
+        return pf.readlines()
+
+
+def _save_profile_trace_results(profiler, filename):
+    parent_dir = os.path.dirname(os.path.abspath(filename))
+    os.makedirs(parent_dir, exist_ok=True)
+    profiler.export_chrome_trace(filename)
+    print(
+        profiler.key_averages(group_by_input_shape=True).table(
+            sort_by="self_cpu_time_total"
+        )
+    )
+
+
 def correctness_test(
     server_args,
     port_args,
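_read_prompts_from_file returns the raw lines from readlines(), so the prompt file is plain text with one prompt per line (the latency path strips the trailing newline before encoding, while the correctness path encodes each line as read). A small sketch of producing such a file (the prompts.txt name is illustrative, not part of the commit):

# One prompt per line; the trailing newline on each line is harmless.
prompts = [
    "The capital of France is",
    "Today is a sunny day and I like",
]
with open("prompts.txt", "w") as f:
    f.write("\n".join(prompts) + "\n")

The traces written through _save_profile_trace_results are Chrome trace files (export_chrome_trace compresses when the filename ends in .gz), so they can be inspected in Perfetto or chrome://tracing, and the key_averages table grouped by input shape is printed to stdout.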
@@ -298,7 +343,10 @@ def correctness_test(
     model_runner, tokenizer = load_model(server_args, port_args, tp_rank)
 
     # Prepare inputs
-    input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
+    custom_prompts = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
+    input_ids, reqs = prepare_inputs_for_correctness_test(
+        bench_args, tokenizer, custom_prompts
+    )
     rank_print(f"\n{input_ids=}\n")
 
     if bench_args.cut_len > 0:
@@ -344,6 +392,7 @@ def latency_test_run_once(
     device,
     log_decode_step,
     profile,
+    profile_record_shapes,
     profile_filename_prefix,
 ):
     max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
@@ -374,6 +423,7 @@ def latency_test_run_once(
                 torch.profiler.ProfilerActivity.CUDA,
             ],
             with_stack=True,
+            record_shapes=profile_record_shapes,
         )
         profiler.start()
 
@@ -391,10 +441,30 @@ def latency_test_run_once(
     measurement_results["prefill_latency"] = prefill_latency
     measurement_results["prefill_throughput"] = throughput
 
+    if profile:
+        profiler.stop()
+        profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
+        _save_profile_trace_results(profiler, profile_filename)
+        rank_print(
+            f"torch profiler chrome trace for prefill saved to {profile_filename}"
+        )
+
     # Decode
     decode_latencies = []
     for i in range(output_len - 1):
         synchronize(device)
+        if profile and i == output_len / 2:
+            profiler = None
+            profiler = torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+                with_stack=True,
+                record_shapes=profile_record_shapes,
+            )
+            profiler.start()
+
         tic = time.perf_counter()
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
         synchronize(device)
@@ -407,13 +477,13 @@ def latency_test_run_once(
                 f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
             )
 
-    if profile:
-        profiler.stop()
-        profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}.trace.json.gz"
-        parent_dir = os.path.dirname(os.path.abspath(profile_filename))
-        os.makedirs(parent_dir, exist_ok=True)
-        profiler.export_chrome_trace(profile_filename)
-        rank_print(f"torch profiler chrome trace saved to {profile_filename}")
+        if profile and i == output_len / 2:
+            profiler.stop()
+            profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
+            _save_profile_trace_results(profiler, profile_filename)
+            rank_print(
+                f"torch profiler chrome trace for decoding 1 token saved to {profile_filename}"
+            )
 
     # Record decode timing from 2nd output
     if output_len > 1:
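Profiling is now split into two traces: the first profiler is stopped right after prefill and written to the *_prefill trace, and a fresh profiler is started at the top of decode iteration i == output_len / 2 and stopped at the bottom of that same iteration, so the *_decode trace covers exactly one mid-run decode step. A standalone illustration with an assumed output_len:

output_len = 16
# Decode-loop iterations that satisfy the condition used above.
profiled_steps = [i for i in range(output_len - 1) if i == output_len / 2]
print(profiled_steps)  # [8] -> only decode step 8 is traced

(With an odd output_len the equality never holds, so no decode trace would be written.)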
@@ -469,17 +539,42 @@ def latency_test(
         server_args.device,
         log_decode_step=0,
         profile=False,
+        profile_record_shapes=False,
         profile_filename_prefix="",  # not used
     )
 
     rank_print("Benchmark ...")
 
+    custom_inputs = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
+    custom_inputs = [tokenizer.encode(p.strip()) for p in custom_inputs]
+    custom_input_len = len(custom_inputs)
+
     # Run the sweep
     result_list = []
     for bs, il, ol in itertools.product(
         bench_args.batch_size, bench_args.input_len, bench_args.output_len
     ):
-        reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
+        bs_aligned_inputs = []
+        if custom_inputs:
+            if custom_input_len == bs:
+                bs_aligned_inputs = custom_inputs
+            elif custom_input_len > bs:
+                rank_print(
+                    f"Custom input size ({custom_input_len}) is larger than batch_size ({bs}). "
+                    f"Using the first {bs} prompts."
+                )
+                bs_aligned_inputs = copy.deepcopy(custom_inputs[:bs])
+            else:
+                rank_print(
+                    f"Custom input size ({custom_input_len}) is smaller than batch_size ({bs}). "
+                    f"Pad to the desired batch_size with the last prompt."
+                )
+                bs_aligned_inputs = copy.deepcopy(custom_inputs)
+                bs_aligned_inputs.extend(
+                    [bs_aligned_inputs[-1]] * (bs - custom_input_len)
+                )
+
+        reqs = prepare_synthetic_inputs_for_latency_test(bs, il, bs_aligned_inputs)
         ret = latency_test_run_once(
             bench_args.run_name,
             model_runner,
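For illustration, a standalone sketch (token ids are made up) of how the alignment above maps one set of custom prompts onto different batch sizes, truncating when there are too many prompts and padding with the last prompt when there are too few:

import copy

custom_inputs = [[101, 102], [103, 104], [105, 106]]  # 3 tokenized prompts
for bs in (2, 3, 5):
    n = len(custom_inputs)
    if n == bs:
        aligned = custom_inputs
    elif n > bs:
        aligned = copy.deepcopy(custom_inputs[:bs])  # keep the first bs prompts
    else:
        aligned = copy.deepcopy(custom_inputs)
        aligned.extend([aligned[-1]] * (bs - n))  # repeat the last prompt
    print(bs, len(aligned))  # 2 2 / 3 3 / 5 5

Note that when custom prompts are supplied, the swept --input-len no longer determines prompt length; the encoded prompts are passed through prepare_synthetic_inputs_for_latency_test unchanged.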
@@ -491,6 +586,7 @@ def latency_test(
             server_args.device,
             bench_args.log_decode_step,
             bench_args.profile if tp_rank == 0 else None,
+            bench_args.profile_record_shapes if tp_rank == 0 else None,
             bench_args.profile_filename_prefix,
         )
         if ret is not None:
...