Unverified Commit df7c4c19 authored by Ying Sheng's avatar Ying Sheng Committed by GitHub
Browse files

Fix trt benchmark (#697)

parent c3f1aac8
......@@ -78,6 +78,8 @@ async def async_request_trt_llm(
"top_p": 1.0,
"max_tokens": request_func_input.output_len,
"stream": True,
"min_length": request_func_input.output_len,
"end_id": 1048576,
}
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
......@@ -111,6 +113,7 @@ async def async_request_trt_llm(
output.latency = most_recent_timestamp - st
output.success = True
output.output_len = request_func_input.output_len
else:
output.error = response.reason or ""
......@@ -244,9 +247,11 @@ class BenchmarkMetrics:
completed: int
total_input: int
total_output: int
total_output_retokenized: int
request_throughput: float
input_throughput: float
output_throughput: float
output_throughput_retokenized: float
mean_ttft_ms: float
median_ttft_ms: float
std_ttft_ms: float
......@@ -455,7 +460,8 @@ def calculate_metrics(
tokenizer: PreTrainedTokenizerBase,
backend: str,
) -> Tuple[BenchmarkMetrics, List[int]]:
actual_output_lens: List[int] = []
output_lens: List[int] = []
retokenized_output_lens: List[int] = []
total_input = 0
completed = 0
itls: List[float] = []
......@@ -463,17 +469,12 @@ def calculate_metrics(
ttfts: List[float] = []
for i in range(len(outputs)):
if outputs[i].success:
# We use the tokenizer solely to count output tokens for the TensorRT LLM backend,
# as it lacks `ignore_eos` support.
if backend == "trt":
output_len = len(
tokenizer(
outputs[i].generated_text, add_special_tokens=False
).input_ids
)
else:
output_len = outputs[i].output_len
actual_output_lens.append(output_len)
output_len = outputs[i].output_len
output_lens.append(output_len)
retokenized_output_len = len(
tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids
)
retokenized_output_lens.append(retokenized_output_len)
total_input += input_requests[i][1]
if output_len > 1:
tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))
......@@ -481,7 +482,8 @@ def calculate_metrics(
ttfts.append(outputs[i].ttft)
completed += 1
else:
actual_output_lens.append(0)
output_lens.append(0)
retokenized_output_lens.append(0)
if completed == 0:
warnings.warn(
......@@ -492,10 +494,12 @@ def calculate_metrics(
metrics = BenchmarkMetrics(
completed=completed,
total_input=total_input,
total_output=sum(actual_output_lens),
total_output=sum(output_lens),
total_output_retokenized=sum(retokenized_output_lens),
request_throughput=completed / dur_s,
input_throughput=total_input / dur_s,
output_throughput=sum(actual_output_lens) / dur_s,
output_throughput=sum(output_lens) / dur_s,
output_throughput_retokenized=sum(retokenized_output_lens) / dur_s,
mean_ttft_ms=np.mean(ttfts or 0)
* 1000, # ttfts is empty if streaming is not supported by backend
median_ttft_ms=np.median(ttfts or 0) * 1000,
......@@ -511,7 +515,7 @@ def calculate_metrics(
p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
)
return metrics, actual_output_lens
return metrics, output_lens
async def benchmark(
......@@ -572,7 +576,7 @@ async def benchmark(
benchmark_duration = time.perf_counter() - benchmark_start_time
metrics, actual_output_lens = calculate_metrics(
metrics, output_lens = calculate_metrics(
input_requests=input_requests,
outputs=outputs,
dur_s=benchmark_duration,
......@@ -587,6 +591,11 @@ async def benchmark(
print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
print(
"{:<40} {:<10}".format(
"Total generated tokens (retokenized):", metrics.total_output_retokenized
)
)
print(
"{:<40} {:<10.2f}".format(
"Request throughput (req/s):", metrics.request_throughput
......@@ -629,6 +638,7 @@ async def benchmark(
"request_rate": request_rate,
"total_input": metrics.total_input,
"total_output": metrics.total_output,
"total_output_retokenized": metrics.total_output_retokenized,
"median_ttft": metrics.median_ttft_ms,
"median_itl": metrics.mean_itl_ms,
"output_token_throughput": metrics.output_throughput,
......@@ -661,6 +671,7 @@ async def benchmark(
"completed": metrics.completed,
"total_input_tokens": metrics.total_input,
"total_output_tokens": metrics.total_output,
"total_output_tokens_retokenized": metrics.total_output_retokenized,
"request_throughput": metrics.request_throughput,
"input_throughput": metrics.input_throughput,
"output_throughput": metrics.output_throughput,
......@@ -677,7 +688,7 @@ async def benchmark(
"std_itl_ms": metrics.std_itl_ms,
"p99_itl_ms": metrics.p99_itl_ms,
"input_lens": [output.prompt_len for output in outputs],
"output_lens": actual_output_lens,
"output_lens": output_lens,
"ttfts": [output.ttft for output in outputs],
"itls": [output.itl for output in outputs],
"generated_texts": [output.generated_text for output in outputs],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment