Unverified commit df7c4c19, authored by Ying Sheng, committed by GitHub

Fix trt benchmark (#697)

parent c3f1aac8
@@ -78,6 +78,8 @@ async def async_request_trt_llm(
             "top_p": 1.0,
             "max_tokens": request_func_input.output_len,
             "stream": True,
+            "min_length": request_func_input.output_len,
+            "end_id": 1048576,
         }
         output = RequestFuncOutput()
         output.prompt_len = request_func_input.prompt_len
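The two fields added in this hunk are what actually fix the TRT-LLM run: the endpoint has no `ignore_eos` switch, so setting `min_length` equal to `max_tokens` and pointing `end_id` at an id outside the vocabulary (1048576) keeps the server generating until exactly `output_len` tokens have been produced, which makes the backends comparable. A minimal sketch of such a payload follows; the `text_input` and `temperature` fields are assumptions for illustration and are not part of this hunk.

```python
# Sketch: force a TRT-LLM endpoint to emit exactly `output_len` tokens.
# "text_input" and "temperature" are assumed fields, not shown in the hunk.
def build_trt_llm_payload(prompt: str, output_len: int) -> dict:
    return {
        "text_input": prompt,       # assumed: the prompt field of the endpoint
        "temperature": 0.0,         # assumed: greedy decoding for benchmarking
        "top_p": 1.0,
        "max_tokens": output_len,
        "stream": True,
        # No `ignore_eos` knob, so pin the length from both sides:
        "min_length": output_len,   # do not stop before output_len tokens
        "end_id": 1048576,          # id outside the vocab, so EOS never matches
    }
```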
@@ -111,6 +113,7 @@ async def async_request_trt_llm(
                     output.latency = most_recent_timestamp - st
                     output.success = True
+                    output.output_len = request_func_input.output_len
                 else:
                     output.error = response.reason or ""
@@ -244,9 +247,11 @@ class BenchmarkMetrics:
     completed: int
     total_input: int
     total_output: int
+    total_output_retokenized: int
     request_throughput: float
     input_throughput: float
     output_throughput: float
+    output_throughput_retokenized: float
     mean_ttft_ms: float
     median_ttft_ms: float
     std_ttft_ms: float
@@ -455,7 +460,8 @@ def calculate_metrics(
     tokenizer: PreTrainedTokenizerBase,
     backend: str,
 ) -> Tuple[BenchmarkMetrics, List[int]]:
-    actual_output_lens: List[int] = []
+    output_lens: List[int] = []
+    retokenized_output_lens: List[int] = []
     total_input = 0
     completed = 0
     itls: List[float] = []
@@ -463,17 +469,12 @@ def calculate_metrics(
     ttfts: List[float] = []
     for i in range(len(outputs)):
         if outputs[i].success:
-            # We use the tokenizer solely to count output tokens for the TensorRT LLM backend,
-            # as it lacks `ignore_eos` support.
-            if backend == "trt":
-                output_len = len(
-                    tokenizer(
-                        outputs[i].generated_text, add_special_tokens=False
-                    ).input_ids
-                )
-            else:
-                output_len = outputs[i].output_len
-            actual_output_lens.append(output_len)
+            output_len = outputs[i].output_len
+            output_lens.append(output_len)
+            retokenized_output_len = len(
+                tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids
+            )
+            retokenized_output_lens.append(retokenized_output_len)
             total_input += input_requests[i][1]
             if output_len > 1:
                 tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))
@@ -481,7 +482,8 @@
             ttfts.append(outputs[i].ttft)
             completed += 1
         else:
-            actual_output_lens.append(0)
+            output_lens.append(0)
+            retokenized_output_lens.append(0)

     if completed == 0:
         warnings.warn(
@@ -492,10 +494,12 @@ def calculate_metrics(
     metrics = BenchmarkMetrics(
         completed=completed,
         total_input=total_input,
-        total_output=sum(actual_output_lens),
+        total_output=sum(output_lens),
+        total_output_retokenized=sum(retokenized_output_lens),
         request_throughput=completed / dur_s,
         input_throughput=total_input / dur_s,
-        output_throughput=sum(actual_output_lens) / dur_s,
+        output_throughput=sum(output_lens) / dur_s,
+        output_throughput_retokenized=sum(retokenized_output_lens) / dur_s,
         mean_ttft_ms=np.mean(ttfts or 0)
         * 1000,  # ttfts is empty if streaming is not supported by backend
         median_ttft_ms=np.median(ttfts or 0) * 1000,
@@ -511,7 +515,7 @@ def calculate_metrics(
         p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
     )
-    return metrics, actual_output_lens
+    return metrics, output_lens


 async def benchmark(
...@@ -572,7 +576,7 @@ async def benchmark( ...@@ -572,7 +576,7 @@ async def benchmark(
benchmark_duration = time.perf_counter() - benchmark_start_time benchmark_duration = time.perf_counter() - benchmark_start_time
metrics, actual_output_lens = calculate_metrics( metrics, output_lens = calculate_metrics(
input_requests=input_requests, input_requests=input_requests,
outputs=outputs, outputs=outputs,
dur_s=benchmark_duration, dur_s=benchmark_duration,
@@ -587,6 +591,11 @@ async def benchmark(
     print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
     print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
     print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
+    print(
+        "{:<40} {:<10}".format(
+            "Total generated tokens (retokenized):", metrics.total_output_retokenized
+        )
+    )
     print(
         "{:<40} {:<10.2f}".format(
             "Request throughput (req/s):", metrics.request_throughput
@@ -629,6 +638,7 @@ async def benchmark(
             "request_rate": request_rate,
             "total_input": metrics.total_input,
             "total_output": metrics.total_output,
+            "total_output_retokenized": metrics.total_output_retokenized,
             "median_ttft": metrics.median_ttft_ms,
             "median_itl": metrics.mean_itl_ms,
             "output_token_throughput": metrics.output_throughput,
@@ -661,6 +671,7 @@ async def benchmark(
         "completed": metrics.completed,
         "total_input_tokens": metrics.total_input,
         "total_output_tokens": metrics.total_output,
+        "total_output_tokens_retokenized": metrics.total_output_retokenized,
         "request_throughput": metrics.request_throughput,
         "input_throughput": metrics.input_throughput,
         "output_throughput": metrics.output_throughput,
@@ -677,7 +688,7 @@ async def benchmark(
         "std_itl_ms": metrics.std_itl_ms,
         "p99_itl_ms": metrics.p99_itl_ms,
         "input_lens": [output.prompt_len for output in outputs],
-        "output_lens": actual_output_lens,
+        "output_lens": output_lens,
         "ttfts": [output.ttft for output in outputs],
         "itls": [output.itl for output in outputs],
         "generated_texts": [output.generated_text for output in outputs],