Unverified commit 2b302b93 authored by Lianmin Zheng, committed by GitHub

Fix the port_args in bench_latency (#1597)

parent 68f8b60d
@@ -123,11 +123,10 @@ class BenchArgs:
     )


-def load_model(server_args, tp_rank):
+def load_model(server_args, port_args, tp_rank):
     suppress_other_loggers()
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
-    port_args = PortArgs.init_new(server_args)
     model_config = ModelConfig(
         server_args.model_path,
         server_args.trust_remote_code,
@@ -248,13 +247,14 @@ def decode(input_token_ids, batch, model_runner):
 @torch.inference_mode()
 def correctness_test(
     server_args,
+    port_args,
     bench_args,
     tp_rank,
 ):
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

     # Load the model
-    model_runner, tokenizer = load_model(server_args, tp_rank)
+    model_runner, tokenizer = load_model(server_args, port_args, tp_rank)

     # Prepare inputs
     input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
@@ -362,6 +362,7 @@ def latency_test_run_once(
 def latency_test(
     server_args,
+    port_args,
     bench_args,
     tp_rank,
 ):
@@ -369,7 +370,7 @@ def latency_test(
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

     # Load the model
-    model_runner, tokenizer = load_model(server_args, tp_rank)
+    model_runner, tokenizer = load_model(server_args, port_args, tp_rank)

     # Prepare inputs for warm up
     reqs = prepare_synthetic_inputs_for_latency_test(
@@ -487,8 +488,10 @@ def main(server_args, bench_args):
             "provide --result-filename for plotting the results"
         )

+    port_args = PortArgs.init_new(server_args)
+
     if server_args.tp_size == 1:
-        work_func(server_args, bench_args, 0)
+        work_func(server_args, port_args, bench_args, 0)
     else:
         workers = []
         for tp_rank in range(server_args.tp_size):
@@ -496,6 +499,7 @@ def main(server_args, bench_args):
                 target=work_func,
                 args=(
                     server_args,
+                    port_args,
                     bench_args,
                     tp_rank,
                 ),
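
The gist of the fix: port allocation moves out of load_model and into main, so PortArgs.init_new runs exactly once and every tensor-parallel worker receives the same port_args instead of re-deriving its own. Below is a minimal, self-contained sketch of that pattern; the PortArgs stub, its base_port field, and the placeholder server/bench arguments are illustrative stand-ins, not sglang's actual API.

import multiprocessing as mp
from dataclasses import dataclass


@dataclass
class PortArgs:
    # Stand-in for sglang's PortArgs; the real class carries more fields.
    nccl_port: int

    @staticmethod
    def init_new(base_port):
        return PortArgs(nccl_port=base_port)


def work_func(server_args, port_args, bench_args, tp_rank):
    # Every rank receives the same port_args object, so all workers
    # rendezvous on the same port instead of each picking its own.
    print(f"rank {tp_rank} using nccl_port={port_args.nccl_port}")


def main(tp_size=2):
    server_args, bench_args = None, None  # placeholders for the real args
    # Build the shared port configuration exactly once in the parent ...
    port_args = PortArgs.init_new(base_port=29500)

    if tp_size == 1:
        work_func(server_args, port_args, bench_args, 0)
        return

    # ... and hand it to every tensor-parallel worker process.
    procs = [
        mp.Process(target=work_func, args=(server_args, port_args, bench_args, rank))
        for rank in range(tp_size)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()


if __name__ == "__main__":
    main()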