Commit a03fcf57 authored by Bilal Khan, committed by Lysandre Debut

Save tokenizer after each epoch to be able to resume training from a checkpoint

parent f71b1bb0
...
@@ -274,6 +274,8 @@ def train(args, train_dataset, model, tokenizer):
        os.makedirs(output_dir)
    model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
    model_to_save.save_pretrained(output_dir)
+   tokenizer.save_pretrained(output_dir)
    torch.save(args, os.path.join(output_dir, 'training_args.bin'))
    logger.info("Saving model checkpoint to %s", output_dir)
...
@@ -282,6 +284,7 @@ def train(args, train_dataset, model, tokenizer):
    torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt'))
    torch.save(scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt'))
    torch.save(epoch, os.path.join(output_dir, 'training_state.pt'))
+   logger.info("Saving training state to %s", output_dir)
    if args.max_steps > 0 and global_step > args.max_steps:
        epoch_iterator.close()
...
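
For context, the snippet below is a minimal sketch of how a checkpoint written by this code could be reloaded to resume training. The checkpoint path, the Auto* classes, and the optimizer/scheduler construction are illustrative assumptions and are not part of this commit.

# Hypothetical resume-from-checkpoint sketch; the directory name, model/tokenizer
# classes, and optimizer/scheduler setup are assumptions, not part of this commit.
import os
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from transformers import AutoModel, AutoTokenizer

checkpoint_dir = "output/checkpoint"  # hypothetical path to a saved checkpoint

# Model and tokenizer were written with save_pretrained(), so from_pretrained() restores both.
model = AutoModel.from_pretrained(checkpoint_dir)
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)

# Training arguments were saved with torch.save(args, 'training_args.bin').
args = torch.load(os.path.join(checkpoint_dir, "training_args.bin"))

# Optimizer and scheduler state dicts were saved as 'optimizer.pt' / 'scheduler.pt'.
optimizer = AdamW(model.parameters(), lr=5e-5)
optimizer.load_state_dict(torch.load(os.path.join(checkpoint_dir, "optimizer.pt")))
scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 1.0)
scheduler.load_state_dict(torch.load(os.path.join(checkpoint_dir, "scheduler.pt")))

# 'training_state.pt' holds the last finished epoch, so training can continue from the next one.
last_epoch = torch.load(os.path.join(checkpoint_dir, "training_state.pt"))
start_epoch = last_epoch + 1
print("Resuming training at epoch", start_epoch)

Saving the tokenizer alongside the model matters here because from_pretrained() needs the vocabulary files in the same checkpoint directory to rebuild exactly the tokenizer that training was using.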