Commit e2a4d426 authored by mohammad's avatar mohammad
Browse files

found a bug in consumed tokens initialization

parent 75bd9b54
...@@ -220,8 +220,10 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'): ...@@ -220,8 +220,10 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
if 'args' in state_dict: if 'args' in state_dict:
checkpoint_args = state_dict['args'] checkpoint_args = state_dict['args']
check_checkpoint_args(checkpoint_args) check_checkpoint_args(checkpoint_args)
args.consumed_train_samples = getattr(args, 'consumed_train_samples', 0) args.consumed_train_samples = getattr(checkpoint_args,
args.consumed_valid_samples = getattr(args, 'consumed_valid_samples', 0) 'consumed_train_samples', 0)
args.consumed_valid_samples = getattr(checkpoint_args,
'consumed_valid_samples', 0)
else: else:
print_rank_0('could not find arguments in the checkpoint ...') print_rank_0('could not find arguments in the checkpoint ...')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment