Commit b037a69e authored by Vijay Korthikanti's avatar Vijay Korthikanti
Browse files

using all_gather instead of gather (nccl does not support gather)

parent a7ee77ea
......@@ -154,13 +154,11 @@ def get_rng_state():
if torch.distributed.is_initialized() and \
mpu.get_data_parallel_world_size() > 1 and \
args.data_parallel_random_init:
if mpu.get_data_parallel_rank() == 0:
rng_state_list = \
[None for i in range(mpu.get_data_parallel_world_size())]
torch.distributed.gather_object(
rng_state,
rng_state_list = \
[None for i in range(mpu.get_data_parallel_world_size())]
torch.distributed.all_gather_object(
rng_state_list,
dst=mpu.get_data_parallel_src_rank(),
rng_state,
group=mpu.get_data_parallel_group())
else:
rng_state_list = [rng_state]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment