Commit a7ee77ea authored by Vijay Korthikanti

flag for data parallel random initialization

parent fd8dd9c0
@@ -518,6 +518,9 @@ def _add_initialization_args(parser):
     group.add_argument('--seed', type=int, default=1234,
                        help='Random seed used for python, numpy, '
                        'pytorch, and cuda.')
+    group.add_argument('--data-parallel-random-init', action='store_true',
+                       help='Enable random initialization of params '
+                       'across data parallel ranks')
     group.add_argument('--init-method-std', type=float, default=0.02,
                        help='Standard deviation of the zero mean normal '
                        'distribution used for weight initialization.')
...
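The new option is a plain argparse store_true flag. A minimal standalone sketch (not part of this commit) of how such a flag behaves: it defaults to False and only becomes True when passed explicitly.

```python
# Sketch (illustrative only): behavior of the new store_true flag.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='initialization')
group.add_argument('--data-parallel-random-init', action='store_true',
                   help='Enable random initialization of params '
                   'across data parallel ranks')

# Omitted -> False; present -> True.
assert parser.parse_args([]).data_parallel_random_init is False
assert parser.parse_args(['--data-parallel-random-init']).data_parallel_random_init is True
```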
@@ -142,6 +142,7 @@ def read_metadata(tracker_filename):
 def get_rng_state():
     """ collect rng state across data parallel ranks """
+    args = get_args()
     rng_state = {
         'random_rng_state': random.getstate(),
         'np_rng_state': np.random.get_state(),
@@ -151,7 +152,8 @@ def get_rng_state():
     rng_state_list = None
     if torch.distributed.is_initialized() and \
-            mpu.get_data_parallel_world_size() > 1:
+            mpu.get_data_parallel_world_size() > 1 and \
+            args.data_parallel_random_init:
         if mpu.get_data_parallel_rank() == 0:
             rng_state_list = \
                 [None for i in range(mpu.get_data_parallel_world_size())]
@@ -407,7 +409,10 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     try:
         if 'rng_state' in state_dict:
             # access rng_state for data parallel rank
-            rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()]
+            if args.data_parallel_random_init:
+                rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()]
+            else:
+                rng_state = state_dict['rng_state'][0]
             random.setstate(rng_state['random_rng_state'])
             np.random.set_state(rng_state['np_rng_state'])
             torch.set_rng_state(rng_state['torch_rng_state'])
...
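The restore-side branch above can be read as a simple selection rule: with the flag on, each data parallel rank restores the rng state it saved; with it off, every rank restores the single saved state. A hedged standalone sketch (the helper name and arguments are illustrative, not Megatron's API):

```python
# Simplified sketch of the rng-state selection in load_checkpoint.
# 'data_parallel_rank' stands in for mpu.get_data_parallel_rank().
import random

def select_rng_state(rng_state_list, data_parallel_rank,
                     data_parallel_random_init):
    if data_parallel_random_init:
        # every data parallel rank saved its own state; restore per rank
        return rng_state_list[data_parallel_rank]
    # only a single state was saved; every rank restores the same one
    return rng_state_list[0]

saved = [{'random_rng_state': random.getstate()}]   # single-entry list
state = select_rng_state(saved, data_parallel_rank=0,
                         data_parallel_random_init=False)
random.setstate(state['random_rng_state'])
```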
@@ -62,7 +62,7 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
         # Random seeds for reproducibility.
         if args.rank == 0:
             print('> setting random seeds to {} ...'.format(args.seed))
-        _set_random_seed(args.seed)
+        _set_random_seed(args.seed, args.data_parallel_random_init)

     # Set pytorch JIT layer fusion options.
     _set_jit_fusion_options()
@@ -203,11 +203,14 @@ def _init_autoresume():
         torch.distributed.barrier()


-def _set_random_seed(seed_):
+def _set_random_seed(seed_, data_parallel_random_init=False):
     """Set random seed for reproducability."""
     if seed_ is not None and seed_ > 0:
-        # Ensure that different pipeline MP stages and different data parallel ranks get different seeds.
-        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank()) + (10 * mpu.get_data_parallel_rank())
+        # Ensure that different pipeline MP stages get different seeds.
+        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
+        # Ensure different data parallel ranks get different seeds
+        if data_parallel_random_init:
+            seed = seed + (10 * mpu.get_data_parallel_rank())
         random.seed(seed)
         np.random.seed(seed)
         torch.manual_seed(seed)
...
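The seed arithmetic in _set_random_seed can be illustrated with a small standalone sketch (the rank arguments are hard-coded stand-ins for the mpu calls): the pipeline-parallel offset is always applied, while the data-parallel offset is now opt-in.

```python
# Standalone sketch of the seed offsets in _set_random_seed.
# pipeline_rank / data_parallel_rank stand in for
# mpu.get_pipeline_model_parallel_rank() / mpu.get_data_parallel_rank().
def compute_seed(base_seed, pipeline_rank, data_parallel_rank,
                 data_parallel_random_init=False):
    # Different pipeline MP stages always get different seeds.
    seed = base_seed + (100 * pipeline_rank)
    # Different data parallel ranks get different seeds only if requested.
    if data_parallel_random_init:
        seed = seed + (10 * data_parallel_rank)
    return seed

# Flag off: all data parallel ranks of a pipeline stage share one seed.
assert compute_seed(1234, 1, 3) == 1334
# Flag on: each data parallel rank gets its own seed.
assert compute_seed(1234, 1, 3, data_parallel_random_init=True) == 1364
```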
@@ -285,8 +285,10 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
                           args.accumulate_allreduce_grads_in_fp32,
                           args.use_contiguous_buffers_in_local_ddp)
                  for model_module in model]
-        for model_module in model:
-            model_module.broadcast_params()
+        # broad cast params from data parallel src rank to other data parallel ranks
+        if args.data_parallel_random_init:
+            for model_module in model:
+                model_module.broadcast_params()
     else:
         raise NotImplementedError('Unknown DDP implementation specified: '
                                   '{}. Exiting.'.format(args.DDP_impl))
...
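broadcast_params() is now invoked only when the flag is set, i.e. only when the per-rank seeds make the initial parameters differ across data parallel ranks. For reference, a hedged sketch of what such a broadcast typically looks like; this is an assumption about broadcast_params(), not code from this commit, and the helper name, dp_src_rank, and dp_group are illustrative.

```python
# Sketch (assumption, not this commit's code): broadcast each parameter
# from the data parallel source rank so all replicas start from
# identical weights. Requires an initialized torch.distributed
# process group.
import torch.distributed as dist

def broadcast_params_from_dp_src(module, dp_src_rank, dp_group):
    for param in module.parameters():
        dist.broadcast(param.data, src=dp_src_rank, group=dp_group)
```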