Commit 798b6a64 authored by Vijay Korthikanti
Browse files

address review comments

parent b037a69e
...@@ -65,6 +65,8 @@ def check_checkpoint_args(checkpoint_args): ...@@ -65,6 +65,8 @@ def check_checkpoint_args(checkpoint_args):
_compare('make_vocab_size_divisible_by') _compare('make_vocab_size_divisible_by')
_compare('padded_vocab_size') _compare('padded_vocab_size')
_compare('tokenizer_type') _compare('tokenizer_type')
if args.data_parallel_random_init:
_compare('data_parallel_random_init')
if get_checkpoint_version() < 3.0: if get_checkpoint_version() < 3.0:
_compare('tensor_model_parallel_size', _compare('tensor_model_parallel_size',
old_arg_name='model_parallel_size') old_arg_name='model_parallel_size')
...@@ -72,7 +74,6 @@ def check_checkpoint_args(checkpoint_args): ...@@ -72,7 +74,6 @@ def check_checkpoint_args(checkpoint_args):
_compare('tensor_model_parallel_size') _compare('tensor_model_parallel_size')
_compare('pipeline_model_parallel_size') _compare('pipeline_model_parallel_size')
def ensure_directory_exists(filename): def ensure_directory_exists(filename):
"""Build filename's path if it does not already exists.""" """Build filename's path if it does not already exists."""
dirname = os.path.dirname(filename) dirname = os.path.dirname(filename)
...@@ -408,6 +409,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True ...@@ -408,6 +409,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
if 'rng_state' in state_dict: if 'rng_state' in state_dict:
# access rng_state for data parallel rank # access rng_state for data parallel rank
if args.data_parallel_random_init: if args.data_parallel_random_init:
rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()] rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()]
else: else:
rng_state = state_dict['rng_state'][0] rng_state = state_dict['rng_state'][0]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment