Commit d5b526d5 authored by Deepak Narayanan's avatar Deepak Narayanan
Browse files

Back compatibility of checkpoints: use `model_parallel_size` when checking for equality of args

parent 318d68c2
...@@ -41,10 +41,13 @@ def get_checkpoint_version(): ...@@ -41,10 +41,13 @@ def get_checkpoint_version():
def check_checkpoint_args(checkpoint_args): def check_checkpoint_args(checkpoint_args):
"""Ensure fixed arguments for a model are the same for the input """Ensure fixed arguments for a model are the same for the input
arguments and the one retreived frm checkpoint.""" arguments and the one retrieved from checkpoint."""
args = get_args() args = get_args()
def _compare(arg_name): def _compare(arg_name, old_arg_name=None):
if old_arg_name is not None:
checkpoint_value = getattr(checkpoint_args, old_arg_name)
else:
checkpoint_value = getattr(checkpoint_args, arg_name) checkpoint_value = getattr(checkpoint_args, arg_name)
args_value = getattr(args, arg_name) args_value = getattr(args, arg_name)
error_message = '{} value from checkpoint ({}) is not equal to the ' \ error_message = '{} value from checkpoint ({}) is not equal to the ' \
...@@ -59,7 +62,12 @@ def check_checkpoint_args(checkpoint_args): ...@@ -59,7 +62,12 @@ def check_checkpoint_args(checkpoint_args):
_compare('make_vocab_size_divisible_by') _compare('make_vocab_size_divisible_by')
_compare('padded_vocab_size') _compare('padded_vocab_size')
_compare('tokenizer_type') _compare('tokenizer_type')
if get_checkpoint_version() < 3.0:
_compare('tensor_model_parallel_size',
old_arg_name='model_parallel_size')
if get_checkpoint_version() >= 3.0:
_compare('tensor_model_parallel_size') _compare('tensor_model_parallel_size')
_compare('pipeline_model_parallel_size')
def ensure_directory_exists(filename): def ensure_directory_exists(filename):
...@@ -107,7 +115,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler): ...@@ -107,7 +115,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
# Arguments, iteration, and model. # Arguments, iteration, and model.
state_dict = {} state_dict = {}
state_dict['args'] = args state_dict['args'] = args
state_dict['checkpoint_version'] = 2.0 state_dict['checkpoint_version'] = 3.0
state_dict['iteration'] = iteration state_dict['iteration'] = iteration
state_dict['model'] = model.state_dict_for_save_checkpoint() state_dict['model'] = model.state_dict_for_save_checkpoint()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment