Unverified commit d198791f authored by zhyncs, committed by GitHub

misc: update output token logic (#695)

parent c07526e4
@@ -54,6 +54,7 @@ class RequestFuncOutput:
     itl: List[float] = field(default_factory=list)  # List of inter-token latencies
     prompt_len: int = 0
     error: str = ""
+    output_len: int = 0
 
 
 def remove_prefix(text: str, prefix: str) -> str:
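For context, a minimal sketch of what the full `RequestFuncOutput` dataclass might look like after this hunk. Only `itl`, `prompt_len`, `error`, and the new `output_len` appear in the diff; the other fields are assumptions inferred from the assignments in the hunks below.

```python
# Sketch only: fields not visible in the diff are assumptions inferred from
# the assignments made elsewhere in this commit.
from dataclasses import dataclass, field
from typing import List


@dataclass
class RequestFuncOutput:
    generated_text: str = ""  # assigned in async_request_openai_completions
    success: bool = False
    latency: float = 0.0
    ttft: float = 0.0  # assumed: time to first token, feeding the ttfts list
    itl: List[float] = field(default_factory=list)  # List of inter-token latencies
    prompt_len: int = 0
    error: str = ""
    output_len: int = 0  # new: requested output length, recorded per request
```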
@@ -189,6 +190,7 @@ async def async_request_openai_completions(
                     output.generated_text = generated_text
                     output.success = True
                     output.latency = latency
+                    output.output_len = request_func_input.output_len
                 else:
                     output.error = response.reason or ""
                     output.success = False
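The value recorded here is the requested completion length, not a measured one. That is a faithful count only if the request forces the server to generate exactly that many tokens. A hedged sketch of the payload this relies on, with key names assumed from vLLM-style OpenAI-completions benchmark clients rather than taken from this diff:

```python
# Sketch only: the exact payload keys are not visible in this diff.
payload = {
    "model": request_func_input.model,  # assumed attribute of the input object
    "prompt": request_func_input.prompt,  # assumed attribute
    "max_tokens": request_func_input.output_len,  # cap generation at the requested length
    "ignore_eos": True,  # run to max_tokens, so generated length == requested length
    "stream": True,
}
```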
@@ -451,6 +453,7 @@ def calculate_metrics(
     outputs: List[RequestFuncOutput],
     dur_s: float,
     tokenizer: PreTrainedTokenizerBase,
+    backend: str,
 ) -> Tuple[BenchmarkMetrics, List[int]]:
     actual_output_lens: List[int] = []
     total_input = 0
@@ -460,13 +463,16 @@
     ttfts: List[float] = []
     for i in range(len(outputs)):
         if outputs[i].success:
-            # We use the tokenizer to count the number of output tokens for all
-            # serving backends instead of looking at len(outputs[i].itl) since
-            # multiple output tokens may be bundled together
-            # Note : this may inflate the output token count slightly
-            output_len = len(
-                tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids
-            )
+            # We use the tokenizer solely to count output tokens for the TensorRT LLM backend,
+            # as it lacks `ignore_eos` support.
+            if backend == "trt":
+                output_len = len(
+                    tokenizer(
+                        outputs[i].generated_text, add_special_tokens=False
+                    ).input_ids
+                )
+            else:
+                output_len = outputs[i].output_len
             actual_output_lens.append(output_len)
             total_input += input_requests[i][1]
             if output_len > 1:
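Pulled out of the hunk above as a standalone sketch, the new counting rule: re-tokenize only for the TensorRT-LLM backend, which cannot honor `ignore_eos` and may therefore stop early, and trust the recorded request length everywhere else. The function name here is ours for illustration; the body mirrors the diff.

```python
from transformers import PreTrainedTokenizerBase


def count_output_tokens(output, tokenizer: PreTrainedTokenizerBase, backend: str) -> int:
    if backend == "trt":
        # TensorRT LLM lacks `ignore_eos`, so generation may stop before the
        # requested length; re-tokenizing the generated text is the only
        # reliable count (it may slightly overcount when tokens are bundled).
        return len(
            tokenizer(output.generated_text, add_special_tokens=False).input_ids
        )
    # Other backends honor `ignore_eos`, so the requested length recorded on
    # the output object is exact.
    return output.output_len
```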
@@ -571,9 +577,11 @@ async def benchmark(
         outputs=outputs,
         dur_s=benchmark_duration,
         tokenizer=tokenizer,
+        backend=backend,
     )
 
     print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
+    print("{:<40} {:<10}".format("Backend:", backend))
     print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
     print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
     print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))