Commit a7ee77ea authored by Vijay Korthikanti

flag for data parallel random initialization

parent fd8dd9c0
@@ -518,6 +518,9 @@ def _add_initialization_args(parser):
     group.add_argument('--seed', type=int, default=1234,
                        help='Random seed used for python, numpy, '
                        'pytorch, and cuda.')
+    group.add_argument('--data-parallel-random-init', action='store_true',
+                       help='Enable random initialization of params '
+                       'across data parallel ranks')
     group.add_argument('--init-method-std', type=float, default=0.02,
                        help='Standard deviation of the zero mean normal '
                        'distribution used for weight initialization.')
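For reference, a minimal standalone sketch (plain argparse, not the Megatron parser itself) of how a store_true flag like this behaves: it defaults to False and only flips to True when --data-parallel-random-init is passed explicitly.

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='initialization')
group.add_argument('--data-parallel-random-init', action='store_true',
                   help='Enable random initialization of params '
                        'across data parallel ranks')

# Omitting the flag leaves the attribute False, so existing configurations keep
# the old behavior; passing it turns the new code paths on.
print(parser.parse_args([]).data_parallel_random_init)                                # False
print(parser.parse_args(['--data-parallel-random-init']).data_parallel_random_init)   # True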
@@ -142,6 +142,7 @@ def read_metadata(tracker_filename):
 def get_rng_state():
     """ collect rng state across data parallel ranks """
+    args = get_args()
     rng_state = {
         'random_rng_state': random.getstate(),
         'np_rng_state': np.random.get_state(),
@@ -151,7 +152,8 @@ def get_rng_state():
     rng_state_list = None
     if torch.distributed.is_initialized() and \
-            mpu.get_data_parallel_world_size() > 1:
+            mpu.get_data_parallel_world_size() > 1 and \
+            args.data_parallel_random_init:
         if mpu.get_data_parallel_rank() == 0:
             rng_state_list = \
                 [None for i in range(mpu.get_data_parallel_world_size())]
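The gather of per-rank rng states is now gated on the new flag as well. A tiny sketch of just that condition, with the torch.distributed and mpu queries replaced by plain arguments (names here are illustrative, not Megatron APIs):

def should_gather_rng_states(distributed_initialized,
                             data_parallel_world_size,
                             data_parallel_random_init):
    # rng states are collected across data parallel ranks only when training is
    # distributed, there is more than one data parallel rank, and the flag asked
    # for rank-dependent random initialization.
    return (distributed_initialized
            and data_parallel_world_size > 1
            and data_parallel_random_init)

print(should_gather_rng_states(True, 8, False))   # False: default keeps the old single-state behavior
print(should_gather_rng_states(True, 8, True))    # True: per-rank states are gathered on data parallel rank 0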
@@ -407,7 +409,10 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
         try:
             if 'rng_state' in state_dict:
                 # access rng_state for data parallel rank
-                rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()]
+                if args.data_parallel_random_init:
+                    rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()]
+                else:
+                    rng_state = state_dict['rng_state'][0]
                 random.setstate(rng_state['random_rng_state'])
                 np.random.set_state(rng_state['np_rng_state'])
                 torch.set_rng_state(rng_state['torch_rng_state'])
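At load time the indexing mirrors the save side: with the flag on, each data parallel rank restores its own entry; with it off, every rank restores the single entry at index 0. A toy illustration (dummy state entries, not Megatron code):

def pick_rng_state(state_dict, data_parallel_rank, data_parallel_random_init):
    if data_parallel_random_init:
        # each data parallel rank restores its own generator state
        return state_dict['rng_state'][data_parallel_rank]
    # all data parallel ranks share the state stored at index 0
    return state_dict['rng_state'][0]

state_dict = {'rng_state': [{'seed': 1234}, {'seed': 1244}]}   # dummy per-rank entries
print(pick_rng_state(state_dict, data_parallel_rank=1, data_parallel_random_init=True))    # {'seed': 1244}
print(pick_rng_state(state_dict, data_parallel_rank=1, data_parallel_random_init=False))   # {'seed': 1234}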
@@ -62,7 +62,7 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
         # Random seeds for reproducibility.
         if args.rank == 0:
             print('> setting random seeds to {} ...'.format(args.seed))
-        _set_random_seed(args.seed)
+        _set_random_seed(args.seed, args.data_parallel_random_init)
 
     # Set pytorch JIT layer fusion options.
     _set_jit_fusion_options()
@@ -203,11 +203,14 @@ def _init_autoresume():
         torch.distributed.barrier()
 
 
-def _set_random_seed(seed_):
+def _set_random_seed(seed_, data_parallel_random_init=False):
     """Set random seed for reproducibility."""
     if seed_ is not None and seed_ > 0:
-        # Ensure that different pipeline MP stages and different data parallel ranks get different seeds.
-        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank()) + (10 * mpu.get_data_parallel_rank())
+        # Ensure that different pipeline MP stages get different seeds.
+        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
+        # Ensure different data parallel ranks get different seeds.
+        if data_parallel_random_init:
+            seed = seed + (10 * mpu.get_data_parallel_rank())
         random.seed(seed)
         np.random.seed(seed)
         torch.manual_seed(seed)
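The effect on the seeds themselves is easiest to see in isolation. A small standalone sketch of the same arithmetic, with the pipeline and data parallel ranks passed in directly instead of queried from mpu:

def compute_seed(base_seed, pipeline_rank, data_parallel_rank,
                 data_parallel_random_init=False):
    # Different pipeline MP stages always get different seeds.
    seed = base_seed + 100 * pipeline_rank
    # Data parallel ranks only diverge when the flag is set.
    if data_parallel_random_init:
        seed += 10 * data_parallel_rank
    return seed

# Default: all data parallel ranks of a pipeline stage share one seed.
print([compute_seed(1234, 0, dp) for dp in range(4)])          # [1234, 1234, 1234, 1234]
# With --data-parallel-random-init: each data parallel rank is offset by 10.
print([compute_seed(1234, 0, dp, True) for dp in range(4)])    # [1234, 1244, 1254, 1264]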
@@ -285,8 +285,10 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
                               args.accumulate_allreduce_grads_in_fp32,
                               args.use_contiguous_buffers_in_local_ddp)
                      for model_module in model]
-            for model_module in model:
-                model_module.broadcast_params()
+            # broadcast params from data parallel src rank to other data parallel ranks
+            if args.data_parallel_random_init:
+                for model_module in model:
+                    model_module.broadcast_params()
         else:
             raise NotImplementedError('Unknown DDP implementation specified: '
                                       '{}. Exiting.'.format(args.DDP_impl))
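The broadcast is only needed when the flag is set, since that is the only case where the data parallel replicas start from different weights; with the flag off they are already initialized identically. A toy single-process analogue (two nn.Linear modules stand in for two data parallel ranks; the real broadcast_params presumably goes through torch.distributed from the data parallel source rank, per the comment above):

import torch
import torch.nn as nn

# Two "replicas" built from different seeds, as data parallel ranks would be
# under --data-parallel-random-init.
torch.manual_seed(1234)
rank0_model = nn.Linear(4, 4)
torch.manual_seed(1244)
rank1_model = nn.Linear(4, 4)
print(torch.equal(rank0_model.weight, rank1_model.weight))   # False: weights differ

# Simulate the broadcast from the source rank: copy rank 0's parameters into
# rank 1 so the replicas agree before training starts.
with torch.no_grad():
    for p_src, p_dst in zip(rank0_model.parameters(), rank1_model.parameters()):
        p_dst.copy_(p_src)
print(torch.equal(rank0_model.weight, rank1_model.weight))   # True: replicas now match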