Commit d5b526d5 authored by Deepak Narayanan
Browse files

Back compatibility of checkpoints: use `model_parallel_size` when checking for equality of args

parent 318d68c2
......@@ -41,11 +41,14 @@ def get_checkpoint_version():
def check_checkpoint_args(checkpoint_args):
"""Ensure fixed arguments for a model are the same for the input
arguments and the one retreived frm checkpoint."""
arguments and the one retrieved from checkpoint."""
args = get_args()
def _compare(arg_name):
checkpoint_value = getattr(checkpoint_args, arg_name)
def _compare(arg_name, old_arg_name=None):
if old_arg_name is not None:
checkpoint_value = getattr(checkpoint_args, old_arg_name)
else:
checkpoint_value = getattr(checkpoint_args, arg_name)
args_value = getattr(args, arg_name)
error_message = '{} value from checkpoint ({}) is not equal to the ' \
'input argument value ({}).'.format(
......@@ -59,7 +62,12 @@ def check_checkpoint_args(checkpoint_args):
_compare('make_vocab_size_divisible_by')
_compare('padded_vocab_size')
_compare('tokenizer_type')
_compare('tensor_model_parallel_size')
if get_checkpoint_version() < 3.0:
_compare('tensor_model_parallel_size',
old_arg_name='model_parallel_size')
if get_checkpoint_version() >= 3.0:
_compare('tensor_model_parallel_size')
_compare('pipeline_model_parallel_size')
def ensure_directory_exists(filename):
......@@ -107,7 +115,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
# Arguments, iteration, and model.
state_dict = {}
state_dict['args'] = args
state_dict['checkpoint_version'] = 2.0
state_dict['checkpoint_version'] = 3.0
state_dict['iteration'] = iteration
state_dict['model'] = model.state_dict_for_save_checkpoint()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment