Unverified Commit afd411d0 authored by min-xu-et, committed by GitHub

enhance latency test - part 2 (#915)

parent e1eae1fd
@@ -220,6 +220,68 @@ def correctness_test(
         rank_print(tokenizer.decode(output_ids[i]))


+@torch.inference_mode()
+def latency_test_run_once(
+    model_runner, rank_print, reqs, batch_size, input_len, output_len
+):
+    # Clear the pools.
+    model_runner.req_to_token_pool.clear()
+    model_runner.token_to_kv_pool.clear()
+
+    measurement_results = {
+        "run_name": "before",
+        "batch_size": batch_size,
+        "input_len": input_len,
+        "output_len": output_len,
+    }
+
+    tot_latency = 0
+
+    # Prefill
+    torch.cuda.synchronize()
+    tic = time.time()
+    next_token_ids, _, batch = extend(reqs, model_runner)
+    torch.cuda.synchronize()
+    prefill_latency = time.time() - tic
+    tot_latency += prefill_latency
+    throughput = input_len * batch_size / prefill_latency
+    rank_print(
+        f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+    )
+    measurement_results["prefill_latency"] = prefill_latency
+    measurement_results["prefill_throughput"] = throughput
+
+    # Decode
+    for i in range(output_len):
+        torch.cuda.synchronize()
+        tic = time.time()
+        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+        torch.cuda.synchronize()
+        latency = time.time() - tic
+        tot_latency += latency
+        throughput = batch_size / latency
+        if i < 5:
+            rank_print(
+                f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+            )
+    avg_decode_latency = (tot_latency - prefill_latency) / output_len
+    avg_decode_throughput = batch_size / avg_decode_latency
+    rank_print(
+        f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
+    )
+    measurement_results["avg_decode_latency"] = avg_decode_latency
+    measurement_results["avg_decode_throughput"] = avg_decode_throughput
+
+    throughput = (input_len + output_len) * batch_size / tot_latency
+    rank_print(
+        f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
+    )
+    measurement_results["total_latency"] = tot_latency
+    measurement_results["total_throughput"] = throughput
+    return measurement_results
+
+
 def latency_test(
     server_args,
     bench_args,
@@ -241,72 +303,23 @@ def latency_test(
         bench_args.batch_size, bench_args.input_len
     )

-    def clear():
-        model_runner.req_to_token_pool.clear()
-        model_runner.token_to_kv_pool.clear()
-
-    @torch.inference_mode()
-    def run_once(output_len):
-        measurement_results = {
-            "batch_size": bench_args.batch_size,
-            "output_len": output_len,
-        }
-
-        # Prefill
-        torch.cuda.synchronize()
-        tot_latency = 0
-        tic = time.time()
-        next_token_ids, _, batch = extend(reqs, model_runner)
-        torch.cuda.synchronize()
-        prefill_latency = time.time() - tic
-        tot_latency += prefill_latency
-        throughput = bench_args.input_len * bench_args.batch_size / prefill_latency
-        rank_print(
-            f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
-        )
-        measurement_results["prefill_latency"] = prefill_latency
-        measurement_results["prefill_throughput"] = throughput
-
-        # Decode
-        for i in range(output_len):
-            torch.cuda.synchronize()
-            tic = time.time()
-            next_token_ids, _ = decode(next_token_ids, batch, model_runner)
-            torch.cuda.synchronize()
-            latency = time.time() - tic
-            tot_latency += latency
-            throughput = bench_args.batch_size / latency
-            if i < 5:
-                rank_print(
-                    f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
-                )
-        avg_decode_latency = (tot_latency - prefill_latency) / output_len
-        avg_decode_throughput = bench_args.batch_size / avg_decode_latency
-        rank_print(
-            f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
-        )
-        measurement_results["avg_decode_latency"] = avg_decode_latency
-        measurement_results["avg_decode_throughput"] = avg_decode_throughput
-
-        throughput = (
-            (bench_args.input_len + bench_args.output_len)
-            * bench_args.batch_size
-            / tot_latency
-        )
-        rank_print(
-            f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
-        )
-        measurement_results["total_latency"] = tot_latency
-        measurement_results["total_throughput"] = throughput
-        return measurement_results
-
     # Warm up
-    run_once(4)
-    clear()
+    latency_test_run_once(
+        model_runner, rank_print, reqs, bench_args.batch_size, bench_args.input_len, 4
+    )

     # Run again
     result_list = []
-    result_list.append(run_once(bench_args.output_len))
+    result_list.append(
+        latency_test_run_once(
+            model_runner,
+            rank_print,
+            reqs,
+            bench_args.batch_size,
+            bench_args.input_len,
+            bench_args.output_len,
+        )
+    )

     # Write results in jsonlines format.
     if bench_args.result_filename:
......
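
The hunk ends just before the collapsed block that writes result_list when bench_args.result_filename is set. That code is not shown on this page, but since JSON Lines is simply one JSON object per line, the measurement dicts returned by latency_test_run_once could be persisted with a sketch like the one below; the write_results_jsonlines helper and the example filename are illustrative and are not the code elided by the "......" above.

import json


def write_results_jsonlines(result_list, filename):
    # One measurement dict per line, appended so repeated benchmark runs accumulate.
    with open(filename, "a") as fout:
        for record in result_list:
            fout.write(json.dumps(record) + "\n")


# Example with the keys produced by latency_test_run_once (values are made up).
write_results_jsonlines(
    [
        {
            "run_name": "before",
            "batch_size": 1,
            "input_len": 128,
            "output_len": 8,
            "prefill_latency": 0.042,
            "total_latency": 0.31,
        }
    ],
    "latency_results.jsonl",
)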
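
Both phases in latency_test_run_once are timed the same way: torch.cuda.synchronize() before starting and before stopping the wall clock, because CUDA kernels launch asynchronously and time.time() around an unsynchronized call would mostly measure launch overhead. A minimal sketch of that pattern as a reusable helper follows; the gpu_timer name and the context-manager form are illustrative and not part of this commit.

import time
from contextlib import contextmanager

import torch


@contextmanager
def gpu_timer(results, key):
    # Drain previously launched kernels so the start timestamp is clean.
    torch.cuda.synchronize()
    tic = time.time()
    yield
    # Wait for the timed kernels to finish before reading the clock.
    torch.cuda.synchronize()
    results[key] = time.time() - tic


# Example: time one matmul the same way the benchmark times prefill and decode steps.
results = {}
x = torch.randn(1024, 1024, device="cuda")
with gpu_timer(results, "matmul_latency"):
    x @ x
print(f"matmul latency: {results['matmul_latency']:6.5f} s")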