"examples/vscode:/vscode.git/clone" did not exist on "a7d73cfdd497d7bf6c9336452decacf540c46e20"
Commit f71b1bb0 authored by Bilal Khan, committed by Lysandre Debut

Save optimizer state, scheduler state and current epoch

parent 0cb16386
@@ -224,7 +224,7 @@ def train(args, train_dataset, model, tokenizer):
model.zero_grad()
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
set_seed(args) # Added here for reproducibility (even between python 2 and 3)
-    for _ in train_iterator:
+    for epoch in train_iterator:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
for step, batch in enumerate(epoch_iterator):
inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
@@ -279,6 +279,10 @@ def train(args, train_dataset, model, tokenizer):
_rotate_checkpoints(args, checkpoint_prefix)
+    torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt'))
+    torch.save(scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt'))
+    torch.save(epoch, os.path.join(output_dir, 'training_state.pt'))
if args.max_steps > 0 and global_step > args.max_steps:
epoch_iterator.close()
break
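
For reference, a minimal resume sketch (not part of this commit) showing how the three files saved above could be restored when training restarts from a checkpoint directory. The maybe_resume helper name and its structure are assumptions for illustration; only the file names and the fact that epoch is saved as a plain value come from the diff.

import os
import torch

def maybe_resume(output_dir, optimizer, scheduler):
    # Hypothetical helper: reload optimizer/scheduler state and the last
    # completed epoch written by the checkpointing code above.
    start_epoch = 0
    optimizer_path = os.path.join(output_dir, 'optimizer.pt')
    scheduler_path = os.path.join(output_dir, 'scheduler.pt')
    state_path = os.path.join(output_dir, 'training_state.pt')
    if os.path.isfile(optimizer_path):
        optimizer.load_state_dict(torch.load(optimizer_path))
    if os.path.isfile(scheduler_path):
        scheduler.load_state_dict(torch.load(scheduler_path))
    if os.path.isfile(state_path):
        # torch.save(epoch, ...) stores the epoch counter directly,
        # so training would resume from the following epoch.
        start_epoch = torch.load(state_path) + 1
    return start_epoch

Such a helper would typically be called once before entering the epoch loop, with its return value used as the starting index of train_iterator.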