Unverified commit 5b7fb6a4, authored by Thomas Wolf and committed by GitHub

Merge pull request #2134 from bkkaggle/saving-and-resuming

Closes #1960: add saving and resuming functionality for the remaining examples
parents 6f68d559 b03872aa
@@ -112,6 +112,13 @@ def train(args, train_dataset, model, tokenizer):
     optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
     scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
+
+    # Check if saved optimizer or scheduler states exist
+    if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')) and os.path.isfile(os.path.join(args.model_name_or_path, 'scheduler.pt')):
+        # Load in optimizer and scheduler states
+        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'optimizer.pt')))
+        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'scheduler.pt')))
+
     if args.fp16:
         try:
             from apex import amp
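The hunk above introduces the save/restore pattern the rest of the PR builds on: optimizer and scheduler states are serialized next to the model checkpoint and restored before training resumes. A minimal standalone sketch of that round trip, assuming plain PyTorch only (torch.optim.AdamW and StepLR stand in for the examples' AdamW plus linear-warmup schedule; the toy model, temporary directory, and step counts are made up, and only the optimizer.pt / scheduler.pt file names follow the diff):

import os
import tempfile

import torch
from torch import nn

# A toy model so the optimizer has parameters to track.
model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-8)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# Take a few steps so the saved states are non-trivial.
for _ in range(3):
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()

checkpoint_dir = tempfile.mkdtemp()
torch.save(optimizer.state_dict(), os.path.join(checkpoint_dir, "optimizer.pt"))
torch.save(scheduler.state_dict(), os.path.join(checkpoint_dir, "scheduler.pt"))

# Later (or in a new process): rebuild the objects, then restore their states.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-8)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
optimizer.load_state_dict(torch.load(os.path.join(checkpoint_dir, "optimizer.pt")))
scheduler.load_state_dict(torch.load(os.path.join(checkpoint_dir, "scheduler.pt")))

print(scheduler.last_epoch)  # 3: the scheduler picks up where it left off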
@@ -140,13 +147,33 @@ def train(args, train_dataset, model, tokenizer):
     logger.info("  Total optimization steps = %d", t_total)
 
     global_step = 0
+    epochs_trained = 0
+    steps_trained_in_current_epoch = 0
+    # Check if continuing training from a checkpoint
+    if os.path.exists(args.model_name_or_path):
+        # Set global_step to the global_step of the last saved checkpoint from the model path
+        global_step = int(args.model_name_or_path.split('-')[-1].split('/')[0])
+        epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
+        steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
+
+        logger.info("  Continuing training from checkpoint, will skip to saved global_step")
+        logger.info("  Continuing training from epoch %d", epochs_trained)
+        logger.info("  Continuing training from global step %d", global_step)
+        logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
+
     tr_loss, logging_loss = 0.0, 0.0
     model.zero_grad()
-    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
+    train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
     set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
     for _ in train_iterator:
         epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
         for step, batch in enumerate(epoch_iterator):
+            # Skip past any already trained steps if resuming training
+            if steps_trained_in_current_epoch > 0:
+                steps_trained_in_current_epoch -= 1
+                continue
+
             model.train()
             batch = tuple(t.to(args.device) for t in batch)
             inputs = {'input_ids': batch[0],
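The resume bookkeeping added in this hunk depends only on the checkpoint directory name and two integer divisions. A standalone sketch of the same arithmetic, with made-up numbers for the path, dataloader length, and accumulation factor (none of these values come from the examples):

# Hypothetical values for illustration.
model_name_or_path = "output/checkpoint-1500"   # as produced by the examples' save loop
num_batches_per_epoch = 400                     # len(train_dataloader)
gradient_accumulation_steps = 2

# Same parsing as the diff: take the number after the last '-', drop anything after a '/'.
global_step = int(model_name_or_path.split('-')[-1].split('/')[0])

# One optimization step consumes gradient_accumulation_steps batches.
steps_per_epoch = num_batches_per_epoch // gradient_accumulation_steps   # 200
epochs_trained = global_step // steps_per_epoch                          # 7
steps_trained_in_current_epoch = global_step % steps_per_epoch           # 100

print(global_step, epochs_trained, steps_trained_in_current_epoch)       # 1500 7 100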
@@ -205,9 +232,15 @@ def train(args, train_dataset, model, tokenizer):
                         os.makedirs(output_dir)
                     model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
                     model_to_save.save_pretrained(output_dir)
+                    tokenizer.save_pretrained(output_dir)
+
                     torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                     logger.info("Saving model checkpoint to %s", output_dir)
 
+                    torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt'))
+                    torch.save(scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt'))
+                    logger.info("Saving optimizer and scheduler states to %s", output_dir)
+
             if args.max_steps > 0 and global_step > args.max_steps:
                 epoch_iterator.close()
                 break
......
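With the additions above, every checkpoint-<global_step> directory written during training should carry not just the model weights but everything needed to resume: tokenizer files, the serialized training arguments, and the optimizer and scheduler states. A small sketch of a sanity check one might run over such a directory (looks_resumable is a hypothetical helper, not part of the examples; the file names follow the diff, while save_pretrained's own output files vary by model and are not asserted here):

import os

def looks_resumable(checkpoint_dir):
    """Return True if a checkpoint directory carries the extra state saved by this PR."""
    expected = ["training_args.bin", "optimizer.pt", "scheduler.pt"]
    return all(os.path.isfile(os.path.join(checkpoint_dir, name)) for name in expected)

# Hypothetical usage: point the training script's --model_name_or_path at such a directory
# (e.g. output/checkpoint-1500) so the loading and resuming branches in the hunks above fire.
print(looks_resumable("output/checkpoint-1500"))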
@@ -88,6 +88,13 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
     ]
     optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
     scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
+
+    # Check if saved optimizer or scheduler states exist
+    if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')) and os.path.isfile(os.path.join(args.model_name_or_path, 'scheduler.pt')):
+        # Load in optimizer and scheduler states
+        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'optimizer.pt')))
+        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'scheduler.pt')))
+
     if args.fp16:
         try:
             from apex import amp
@@ -117,13 +124,33 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
     logger.info("  Total optimization steps = %d", t_total)
 
     global_step = 0
+    epochs_trained = 0
+    steps_trained_in_current_epoch = 0
+    # Check if continuing training from a checkpoint
+    if os.path.exists(args.model_name_or_path):
+        # Set global_step to the global_step of the last saved checkpoint from the model path
+        global_step = int(args.model_name_or_path.split('-')[-1].split('/')[0])
+        epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
+        steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
+
+        logger.info("  Continuing training from checkpoint, will skip to saved global_step")
+        logger.info("  Continuing training from epoch %d", epochs_trained)
+        logger.info("  Continuing training from global step %d", global_step)
+        logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
+
     tr_loss, logging_loss = 0.0, 0.0
     model.zero_grad()
-    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
+    train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
     set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
     for _ in train_iterator:
         epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
         for step, batch in enumerate(epoch_iterator):
+            # Skip past any already trained steps if resuming training
+            if steps_trained_in_current_epoch > 0:
+                steps_trained_in_current_epoch -= 1
+                continue
+
             model.train()
             batch = tuple(t.to(args.device) for t in batch)
             inputs = {"input_ids": batch[0],
@@ -175,9 +202,15 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
                         os.makedirs(output_dir)
                     model_to_save = model.module if hasattr(model, "module") else model  # Take care of distributed/parallel training
                     model_to_save.save_pretrained(output_dir)
+                    tokenizer.save_pretrained(output_dir)
+
                     torch.save(args, os.path.join(output_dir, "training_args.bin"))
                     logger.info("Saving model checkpoint to %s", output_dir)
 
+                    torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt'))
+                    torch.save(scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt'))
+                    logger.info("Saving optimizer and scheduler states to %s", output_dir)
+
             if args.max_steps > 0 and global_step > args.max_steps:
                 epoch_iterator.close()
                 break
......
@@ -92,6 +92,13 @@ def train(args, train_dataset, model, tokenizer):
     ]
     optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
+
+    # Check if saved optimizer or scheduler states exist
+    if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')) and os.path.isfile(os.path.join(args.model_name_or_path, 'scheduler.pt')):
+        # Load in optimizer and scheduler states
+        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'optimizer.pt')))
+        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'scheduler.pt')))
+
     if args.fp16:
         try:
             from apex import amp
@@ -120,13 +127,32 @@ def train(args, train_dataset, model, tokenizer):
     logger.info("  Total optimization steps = %d", t_total)
 
     global_step = 0
+    epochs_trained = 0
+    steps_trained_in_current_epoch = 0
+    # Check if continuing training from a checkpoint
+    if os.path.exists(args.model_name_or_path):
+        # Set global_step to the global_step of the last saved checkpoint from the model path
+        global_step = int(args.model_name_or_path.split('-')[-1].split('/')[0])
+        epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
+        steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
+
+        logger.info("  Continuing training from checkpoint, will skip to saved global_step")
+        logger.info("  Continuing training from epoch %d", epochs_trained)
+        logger.info("  Continuing training from global step %d", global_step)
+        logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
+
     tr_loss, logging_loss = 0.0, 0.0
     model.zero_grad()
-    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
+    train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
     set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
     for _ in train_iterator:
         epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
+            # Skip past any already trained steps if resuming training
+            if steps_trained_in_current_epoch > 0:
+                steps_trained_in_current_epoch -= 1
+                continue
+
             model.train()
             batch = tuple(t.to(args.device) for t in batch)
             inputs = {'input_ids': batch[0],
@@ -177,9 +203,15 @@ def train(args, train_dataset, model, tokenizer):
                         os.makedirs(output_dir)
                     model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
                     model_to_save.save_pretrained(output_dir)
+                    tokenizer.save_pretrained(output_dir)
+
                     torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                     logger.info("Saving model checkpoint to %s", output_dir)
 
+                    torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt'))
+                    torch.save(scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt'))
+                    logger.info("Saving optimizer and scheduler states to %s", output_dir)
+
             if args.max_steps > 0 and global_step > args.max_steps:
                 epoch_iterator.close()
                 break
......
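The same three edits are repeated across the example scripts, and the training-loop side of the resume is just the counter-driven skip shown in each middle hunk: completed epochs are dropped by starting trange at epochs_trained, and the partially completed epoch fast-forwards through batches until steps_trained_in_current_epoch reaches zero. A compact standalone sketch of that control flow with toy numbers (no model or data loading involved; the counts are made up):

from tqdm import trange, tqdm

num_train_epochs = 3
batches_per_epoch = 5

# Pretend we are resuming after 7 batches: one full epoch plus 2 batches.
epochs_trained = 1
steps_trained_in_current_epoch = 2

trained = []
for epoch in trange(epochs_trained, num_train_epochs, desc="Epoch"):
    for step in tqdm(range(batches_per_epoch), desc="Iteration"):
        # Skip past any already trained steps if resuming training.
        if steps_trained_in_current_epoch > 0:
            steps_trained_in_current_epoch -= 1
            continue
        trained.append((epoch, step))

# Epoch 0 is never re-run; epoch 1 resumes at its third batch.
assert trained[0] == (1, 2)
assert len(trained) == (num_train_epochs - epochs_trained) * batches_per_epoch - 2

Note that this fast-forward (like the one in the examples) replays the schedule rather than restoring any RNG or dataloader state, so a resumed run matches an uninterrupted one in step count but is not guaranteed to be bit-identical.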