Unverified Commit 2f2e0743 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Fix weight update for data parallelism (#2050)

parent 2ffe0a73
...@@ -231,15 +231,16 @@ def throughput_test( ...@@ -231,15 +231,16 @@ def throughput_test(
input_requests = get_dataset(bench_args, tokenizer) input_requests = get_dataset(bench_args, tokenizer)
warmup_requests = sample_random_requests( warmup_requests = sample_random_requests(
input_len=20, input_len=256,
output_len=4, output_len=16,
num_prompts=2, num_prompts=16,
range_ratio=0.8, range_ratio=0.8,
tokenizer=tokenizer, tokenizer=tokenizer,
dataset_path=bench_args.dataset_path, dataset_path=bench_args.dataset_path,
) )
# Warm up # Warm up
logging.info("\nWarmup...")
throughput_test_once( throughput_test_once(
backend_name=bench_args.backend, backend_name=bench_args.backend,
backend=backend, backend=backend,
...@@ -247,6 +248,7 @@ def throughput_test( ...@@ -247,6 +248,7 @@ def throughput_test(
ignore_eos=not bench_args.disable_ignore_eos, ignore_eos=not bench_args.disable_ignore_eos,
) )
logging.info("\nBenchmark...")
result = throughput_test_once( result = throughput_test_once(
backend_name=bench_args.backend, backend_name=bench_args.backend,
backend=backend, backend=backend,
......
...@@ -83,6 +83,7 @@ class DataParallelController: ...@@ -83,6 +83,7 @@ class DataParallelController:
self.workers = [] self.workers = []
for dp_rank in range(server_args.dp_size): for dp_rank in range(server_args.dp_size):
tmp_port_args = PortArgs.init_new(server_args) tmp_port_args = PortArgs.init_new(server_args)
tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name
send_to = self.launch_tensor_parallel_group( send_to = self.launch_tensor_parallel_group(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment