Commit b037a69e authored by Vijay Korthikanti's avatar Vijay Korthikanti
Browse files

using all_gather instead of gather (nccl does not support gather)

parent a7ee77ea
...@@ -154,13 +154,11 @@ def get_rng_state(): ...@@ -154,13 +154,11 @@ def get_rng_state():
if torch.distributed.is_initialized() and \ if torch.distributed.is_initialized() and \
mpu.get_data_parallel_world_size() > 1 and \ mpu.get_data_parallel_world_size() > 1 and \
args.data_parallel_random_init: args.data_parallel_random_init:
if mpu.get_data_parallel_rank() == 0: rng_state_list = \
rng_state_list = \ [None for i in range(mpu.get_data_parallel_world_size())]
[None for i in range(mpu.get_data_parallel_world_size())] torch.distributed.all_gather_object(
torch.distributed.gather_object(
rng_state,
rng_state_list, rng_state_list,
dst=mpu.get_data_parallel_src_rank(), rng_state,
group=mpu.get_data_parallel_group()) group=mpu.get_data_parallel_group())
else: else:
rng_state_list = [rng_state] rng_state_list = [rng_state]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment