"...resnet50_tensorflow.git" did not exist on "331e137215180ef1a747ac377b18b4cb4abedf74"
Commit f71b1bb0 authored by Bilal Khan's avatar Bilal Khan Committed by Lysandre Debut
Browse files

Save optimizer state, scheduler state and current epoch

parent 0cb16386
...@@ -224,7 +224,7 @@ def train(args, train_dataset, model, tokenizer): ...@@ -224,7 +224,7 @@ def train(args, train_dataset, model, tokenizer):
model.zero_grad() model.zero_grad()
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]) train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
set_seed(args) # Added here for reproducibility (even between python 2 and 3) set_seed(args) # Added here for reproducibility (even between python 2 and 3)
for _ in train_iterator: for epoch in train_iterator:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
for step, batch in enumerate(epoch_iterator): for step, batch in enumerate(epoch_iterator):
inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch) inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
...@@ -279,6 +279,10 @@ def train(args, train_dataset, model, tokenizer): ...@@ -279,6 +279,10 @@ def train(args, train_dataset, model, tokenizer):
_rotate_checkpoints(args, checkpoint_prefix) _rotate_checkpoints(args, checkpoint_prefix)
torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt'))
torch.save(scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt'))
torch.save(epoch, os.path.join(output_dir, 'training_state.pt'))
if args.max_steps > 0 and global_step > args.max_steps: if args.max_steps > 0 and global_step > args.max_steps:
epoch_iterator.close() epoch_iterator.close()
break break
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment