Unverified Commit d9a69029 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Fix bench latency (#607)

parent ad872feb
......@@ -30,8 +30,10 @@ import argparse
import dataclasses
import logging
import multiprocessing
import os
import time
import numpy as np
import torch
import torch.distributed as dist
......@@ -70,6 +72,7 @@ class BenchArgs:
def load_model(server_args, tp_rank):
suppress_other_loggers()
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
model_config = ModelConfig(path=server_args.model_path)
model_runner = ModelRunner(
......@@ -81,7 +84,7 @@ def load_model(server_args, tp_rank):
nccl_port=28888,
server_args=server_args,
)
print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
tokenizer = get_tokenizer(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
......@@ -201,7 +204,7 @@ def correctness_test(
# Print
for i in range(len(reqs)):
print(tokenizer.decode(output_ids[i]))
rank_print(tokenizer.decode(output_ids[i]))
def latency_test(
......@@ -213,7 +216,7 @@ def latency_test(
# Load the model
model_runner, tokenizer = load_model(server_args, tp_rank)
print(
rank_print(
f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
)
......@@ -299,6 +302,8 @@ def main(server_args, bench_args):
for proc in workers:
proc.join()
proc.terminate()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment