Unverified Commit d9a69029 authored by Lianmin Zheng, committed by GitHub

Fix bench latency (#607)

parent ad872feb
@@ -30,8 +30,10 @@ import argparse
 import dataclasses
 import logging
 import multiprocessing
+import os
 import time
 import numpy as np
 import torch
 import torch.distributed as dist
@@ -70,6 +72,7 @@ class BenchArgs:
 def load_model(server_args, tp_rank):
     suppress_other_loggers()
+    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
     model_config = ModelConfig(path=server_args.model_path)
     model_runner = ModelRunner(
@@ -81,7 +84,7 @@ def load_model(server_args, tp_rank):
         nccl_port=28888,
         server_args=server_args,
     )
-    print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
+    rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
     tokenizer = get_tokenizer(
         server_args.tokenizer_path,
         tokenizer_mode=server_args.tokenizer_mode,
@@ -201,7 +204,7 @@ def correctness_test(
     # Print
     for i in range(len(reqs)):
-        print(tokenizer.decode(output_ids[i]))
+        rank_print(tokenizer.decode(output_ids[i]))


 def latency_test(
@@ -213,7 +216,7 @@ def latency_test(
     # Load the model
     model_runner, tokenizer = load_model(server_args, tp_rank)
-    print(
+    rank_print(
         f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
     )
@@ -299,6 +302,8 @@ def main(server_args, bench_args):
     for proc in workers:
         proc.join()
+
+    proc.terminate()


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
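The change centers on a rank-aware print helper: with tensor parallelism, every worker process would otherwise print the same lines, so output is restricted to TP rank 0 and the other ranks get a no-op. Below is a minimal sketch of that pattern in a standalone multiprocessing setting; the worker function and its arguments are illustrative placeholders, not code from this diff.

```python
import multiprocessing


def worker(tp_rank: int, tp_size: int):
    # Same helper as in the diff: rank 0 keeps the real print,
    # all other ranks get a no-op so logs are not duplicated.
    rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
    rank_print(f"starting benchmark with tp_size={tp_size}")
    # ... run the per-rank benchmark work here (placeholder) ...


if __name__ == "__main__":
    tp_size = 2  # illustrative value
    workers = []
    for tp_rank in range(tp_size):
        proc = multiprocessing.Process(target=worker, args=(tp_rank, tp_size))
        proc.start()
        workers.append(proc)

    for proc in workers:
        proc.join()

    # Mirrors the cleanup added in this commit: terminate() on an already
    # joined process is a safe no-op, so this only matters if a rank hangs.
    proc.terminate()
```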