Commit cebd3b8b authored by mohammad

addressed Jared's comments

parent f0a445fa
@@ -89,8 +89,7 @@ def get_checkpoint_tracker_filename(checkpoints_path):
     return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')


-def save_checkpoint(iteration, model, optimizer, lr_scheduler,
-                    consumed_train_samples=None, consumed_valid_samples=None):
+def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     """Save a model checkpoint."""
     args = get_args()
@@ -104,10 +103,6 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler,
     state_dict['args'] = args
     state_dict['checkpoint_version'] = 2.0
     state_dict['iteration'] = iteration
-    if consumed_train_samples:
-        state_dict['consumed_train_samples'] = consumed_train_samples
-    if consumed_valid_samples:
-        state_dict['consumed_valid_samples'] = consumed_valid_samples
     state_dict['model'] = model.state_dict_for_save_checkpoint()

     # Optimizer stuff.
@@ -219,17 +214,14 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
                 checkpoint_name))
             sys.exit()

-    if 'consumed_train_samples' in state_dict:
-        assert args.consumed_train_samples == 0
-        args.consumed_train_samples = state_dict['consumed_train_samples']
-    if 'consumed_valid_samples' in state_dict:
-        assert args.consumed_valid_samples == 0
-        args.consumed_valid_samples = state_dict['consumed_valid_samples']
     # Check arguments.
+    assert args.consumed_train_samples == 0
+    assert args.consumed_valid_samples == 0
     if 'args' in state_dict:
         checkpoint_args = state_dict['args']
         check_checkpoint_args(checkpoint_args)
+        args.consumed_train_samples = getattr(checkpoint_args,
+                                              'consumed_train_samples', 0)
+        args.consumed_valid_samples = getattr(checkpoint_args,
+                                              'consumed_valid_samples', 0)
     else:
         print_rank_0('could not find arguments in the checkpoint ...')
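
Taken together, the checkpointing.py hunks move the sample counters out of dedicated state_dict keys and into the args namespace that is already serialized with every checkpoint; on load, getattr with a default of 0 keeps checkpoints written before these counters existed loadable. A minimal sketch of that round trip, assuming a bare argparse.Namespace for args and using pickle in place of torch.save/torch.load; the *_sketch functions and ckpt.pkl path are illustrative, not Megatron's API:

    import argparse
    import pickle

    def save_checkpoint_sketch(iteration, args, path='ckpt.pkl'):
        # The counters now travel inside the saved args namespace, so no
        # dedicated 'consumed_*' keys are written (pickle stands in for
        # torch.save here).
        state_dict = {'args': args, 'checkpoint_version': 2.0,
                      'iteration': iteration}
        with open(path, 'wb') as f:
            pickle.dump(state_dict, f)

    def load_checkpoint_sketch(args, path='ckpt.pkl'):
        with open(path, 'rb') as f:
            state_dict = pickle.load(f)
        # A resumed run must start with fresh counters.
        assert args.consumed_train_samples == 0
        assert args.consumed_valid_samples == 0
        if 'args' in state_dict:
            checkpoint_args = state_dict['args']
            # getattr with a default of 0 keeps older checkpoints loadable:
            # args saved before these counters existed simply yield 0.
            args.consumed_train_samples = getattr(
                checkpoint_args, 'consumed_train_samples', 0)
            args.consumed_valid_samples = getattr(
                checkpoint_args, 'consumed_valid_samples', 0)
        return state_dict['iteration']

    old_args = argparse.Namespace(consumed_train_samples=7,
                                  consumed_valid_samples=3)
    save_checkpoint_sketch(100, old_args)

    fresh = argparse.Namespace(consumed_train_samples=0,
                               consumed_valid_samples=0)
    assert load_checkpoint_sketch(fresh) == 100
    assert fresh.consumed_train_samples == 7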
@@ -104,9 +104,7 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
                                    iteration, False)

     if args.save and iteration != 0:
-        save_checkpoint(iteration, model, optimizer, lr_scheduler,
-                        consumed_train_samples=args.consumed_train_samples,
-                        consumed_valid_samples=args.consumed_valid_samples)
+        save_checkpoint(iteration, model, optimizer, lr_scheduler)

     if args.do_test:
         # Run on test data.
@@ -438,9 +436,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,

         # Checkpointing
         if args.save and args.save_interval and \
           iteration % args.save_interval == 0:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler,
-                            consumed_train_samples=args.consumed_train_samples,
-                            consumed_valid_samples=args.consumed_valid_samples)
+            save_checkpoint(iteration, model, optimizer, lr_scheduler)

         # Evaluation
         if args.eval_interval and iteration % args.eval_interval == 0 and \
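
The training.py hunks then reduce both call sites to the same four-argument form, since save_checkpoint can read the counters off the global args itself when it serializes them. A hypothetical loop skeleton under that assumption; save_checkpoint is stubbed out here, and the per-step counter update stands in for the real batch accounting:

    import argparse

    def save_checkpoint(iteration, model, optimizer, lr_scheduler):
        # Stub standing in for Megatron's save_checkpoint; the real one
        # reads args.consumed_train_samples itself rather than taking it
        # as a keyword argument.
        print(f'saved checkpoint at iteration {iteration}')

    def train_sketch(args, model=None, optimizer=None, lr_scheduler=None):
        for iteration in range(1, args.train_iters + 1):
            # A real train step would advance the counter by the number of
            # samples actually consumed this iteration.
            args.consumed_train_samples += args.global_batch_size
            # Periodic checkpointing: identical four-argument call
            # everywhere, since the counters ride along inside args.
            if args.save and args.save_interval and \
               iteration % args.save_interval == 0:
                save_checkpoint(iteration, model, optimizer, lr_scheduler)

    args = argparse.Namespace(train_iters=10, save=True, save_interval=5,
                              global_batch_size=32, consumed_train_samples=0)
    train_sketch(args)  # prints at iterations 5 and 10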