Unverified Commit 5b7fb6a4 authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #2134 from bkkaggle/saving-and-resuming

closes #1960 Add saving and resuming functionality for remaining examples
parents 6f68d559 b03872aa
......@@ -112,6 +112,13 @@ def train(args, train_dataset, model, tokenizer):
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
# Check if saved optimizer or scheduler states exist
if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')) and os.path.isfile(os.path.join(args.model_name_or_path, 'scheduler.pt')):
# Load in optimizer and scheduler states
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'optimizer.pt')))
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'scheduler.pt')))
if args.fp16:
try:
from apex import amp
......@@ -140,13 +147,33 @@ def train(args, train_dataset, model, tokenizer):
logger.info(" Total optimization steps = %d", t_total)
global_step = 0
epochs_trained = 0
steps_trained_in_current_epoch = 0
# Check if continuing training from a checkpoint
if os.path.exists(args.model_name_or_path):
# set global_step to gobal_step of last saved checkpoint from model path
global_step = int(args.model_name_or_path.split('-')[-1].split('/')[0])
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(" Continuing training from epoch %d", epochs_trained)
logger.info(" Continuing training from global step %d", global_step)
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
tr_loss, logging_loss = 0.0, 0.0
model.zero_grad()
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
set_seed(args) # Added here for reproductibility (even between python 2 and 3)
for _ in train_iterator:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
for step, batch in enumerate(epoch_iterator):
# Skip past any already trained steps if resuming training
if steps_trained_in_current_epoch > 0:
steps_trained_in_current_epoch -= 1
continue
model.train()
batch = tuple(t.to(args.device) for t in batch)
inputs = {'input_ids': batch[0],
......@@ -205,9 +232,15 @@ def train(args, train_dataset, model, tokenizer):
os.makedirs(output_dir)
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
logger.info("Saving model checkpoint to %s", output_dir)
torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt'))
torch.save(scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt'))
logger.info("Saving optimizer and scheduler states to %s", output_dir)
if args.max_steps > 0 and global_step > args.max_steps:
epoch_iterator.close()
break
......
......@@ -88,6 +88,13 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
# Check if saved optimizer or scheduler states exist
if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')) and os.path.isfile(os.path.join(args.model_name_or_path, 'scheduler.pt')):
# Load in optimizer and scheduler states
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'optimizer.pt')))
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'scheduler.pt')))
if args.fp16:
try:
from apex import amp
......@@ -117,13 +124,33 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
logger.info(" Total optimization steps = %d", t_total)
global_step = 0
epochs_trained = 0
steps_trained_in_current_epoch = 0
# Check if continuing training from a checkpoint
if os.path.exists(args.model_name_or_path):
# set global_step to gobal_step of last saved checkpoint from model path
global_step = int(args.model_name_or_path.split('-')[-1].split('/')[0])
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(" Continuing training from epoch %d", epochs_trained)
logger.info(" Continuing training from global step %d", global_step)
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
tr_loss, logging_loss = 0.0, 0.0
model.zero_grad()
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
set_seed(args) # Added here for reproductibility (even between python 2 and 3)
for _ in train_iterator:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
for step, batch in enumerate(epoch_iterator):
# Skip past any already trained steps if resuming training
if steps_trained_in_current_epoch > 0:
steps_trained_in_current_epoch -= 1
continue
model.train()
batch = tuple(t.to(args.device) for t in batch)
inputs = {"input_ids": batch[0],
......@@ -175,9 +202,15 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
os.makedirs(output_dir)
model_to_save = model.module if hasattr(model, "module") else model # Take care of distributed/parallel training
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
torch.save(args, os.path.join(output_dir, "training_args.bin"))
logger.info("Saving model checkpoint to %s", output_dir)
torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt'))
torch.save(scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt'))
logger.info("Saving optimizer and scheduler states to %s", output_dir)
if args.max_steps > 0 and global_step > args.max_steps:
epoch_iterator.close()
break
......
......@@ -27,7 +27,8 @@ import glob
import timeit
import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, TensorDataset)
from torch.utils.data import (
DataLoader, RandomSampler, SequentialSampler, TensorDataset)
from torch.utils.data.distributed import DistributedSampler
try:
......@@ -65,6 +66,7 @@ MODEL_CLASSES = {
'albert': (AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer),
}
def set_seed(args):
random.seed(args.seed)
np.random.seed(args.seed)
......@@ -72,40 +74,60 @@ def set_seed(args):
if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
def to_list(tensor):
return tensor.detach().cpu().tolist()
def train(args, train_dataset, model, tokenizer):
""" Train the model """
if args.local_rank in [-1, 0]:
tb_writer = SummaryWriter()
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
train_sampler = RandomSampler(
train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
train_dataloader = DataLoader(
train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
if args.max_steps > 0:
t_total = args.max_steps
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
args.num_train_epochs = args.max_steps // (
len(train_dataloader) // args.gradient_accumulation_steps) + 1
else:
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
t_total = len(
train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
{'params': [p for n, p in model.named_parameters() if not any(
nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
{'params': [p for n, p in model.named_parameters() if any(
nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
optimizer = AdamW(optimizer_grouped_parameters,
lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = get_linear_schedule_with_warmup(
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
# Check if saved optimizer or scheduler states exist
if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')) and os.path.isfile(os.path.join(args.model_name_or_path, 'scheduler.pt')):
# Load in optimizer and scheduler states
optimizer.load_state_dict(torch.load(
os.path.join(args.model_name_or_path, 'optimizer.pt')))
scheduler.load_state_dict(torch.load(
os.path.join(args.model_name_or_path, 'scheduler.pt')))
if args.fp16:
try:
from apex import amp
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
raise ImportError(
"Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
model, optimizer = amp.initialize(
model, optimizer, opt_level=args.fp16_opt_level)
# multi-gpu training (should be after apex fp16 initialization)
if args.n_gpu > 1:
......@@ -121,21 +143,50 @@ def train(args, train_dataset, model, tokenizer):
logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_dataset))
logger.info(" Num Epochs = %d", args.num_train_epochs)
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
logger.info(" Instantaneous batch size per GPU = %d",
args.per_gpu_train_batch_size)
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
logger.info(" Gradient Accumulation steps = %d",
args.gradient_accumulation_steps)
logger.info(" Total optimization steps = %d", t_total)
global_step = 1
epochs_trained = 0
steps_trained_in_current_epoch = 0
# Check if continuing training from a checkpoint
if os.path.exists(args.model_name_or_path):
# set global_step to gobal_step of last saved checkpoint from model path
global_step = int(args.model_name_or_path.split('-')[-1].split('/')[0])
epochs_trained = global_step // (len(train_dataloader) //
args.gradient_accumulation_steps)
steps_trained_in_current_epoch = global_step % (
len(train_dataloader) // args.gradient_accumulation_steps)
logger.info(
" Continuing training from checkpoint, will skip to saved global_step")
logger.info(" Continuing training from epoch %d", epochs_trained)
logger.info(" Continuing training from global step %d", global_step)
logger.info(" Will skip the first %d steps in the first epoch",
steps_trained_in_current_epoch)
tr_loss, logging_loss = 0.0, 0.0
model.zero_grad()
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
set_seed(args) # Added here for reproductibility (even between python 2 and 3)
train_iterator = trange(epochs_trained, int(
args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
# Added here for reproductibility (even between python 2 and 3)
set_seed(args)
for _ in train_iterator:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
epoch_iterator = tqdm(train_dataloader, desc="Iteration",
disable=args.local_rank not in [-1, 0])
for step, batch in enumerate(epoch_iterator):
# Skip past any already trained steps if resuming training
if steps_trained_in_current_epoch > 0:
steps_trained_in_current_epoch -= 1
continue
model.train()
batch = tuple(t.to(args.device) for t in batch)
......@@ -153,7 +204,8 @@ def train(args, train_dataset, model, tokenizer):
if args.version_2_with_negative:
inputs.update({'is_impossible': batch[7]})
outputs = model(**inputs)
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
# model outputs are always tuple in transformers (see doc)
loss = outputs[0]
if args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training
......@@ -169,9 +221,11 @@ def train(args, train_dataset, model, tokenizer):
tr_loss += loss.item()
if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
torch.nn.utils.clip_grad_norm_(
amp.master_params(optimizer), args.max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
torch.nn.utils.clip_grad_norm_(
model.parameters(), args.max_grad_norm)
optimizer.step()
scheduler.step() # Update learning rate schedule
......@@ -180,24 +234,41 @@ def train(args, train_dataset, model, tokenizer):
# Log metrics
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
# Only evaluate when single GPU otherwise metrics may not average well
if args.local_rank == -1 and args.evaluate_during_training:
results = evaluate(args, model, tokenizer)
for key, value in results.items():
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
tb_writer.add_scalar(
'eval_{}'.format(key), value, global_step)
tb_writer.add_scalar(
'lr', scheduler.get_lr()[0], global_step)
tb_writer.add_scalar(
'loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
logging_loss = tr_loss
# Save model checkpoint
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
output_dir = os.path.join(
args.output_dir, 'checkpoint-{}'.format(global_step))
if not os.path.exists(output_dir):
os.makedirs(output_dir)
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
# Take care of distributed/parallel training
model_to_save = model.module if hasattr(
model, 'module') else model
model_to_save.save_pretrained(output_dir)
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
tokenizer.save_pretrained(output_dir)
torch.save(args, os.path.join(
output_dir, 'training_args.bin'))
logger.info("Saving model checkpoint to %s", output_dir)
torch.save(optimizer.state_dict(), os.path.join(
output_dir, 'optimizer.pt'))
torch.save(scheduler.state_dict(), os.path.join(
output_dir, 'scheduler.pt'))
logger.info(
"Saving optimizer and scheduler states to %s", output_dir)
if args.max_steps > 0 and global_step > args.max_steps:
epoch_iterator.close()
break
......@@ -212,7 +283,8 @@ def train(args, train_dataset, model, tokenizer):
def evaluate(args, model, tokenizer, prefix=""):
dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)
dataset, examples, features = load_and_cache_examples(
args, tokenizer, evaluate=True, output_examples=True)
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
os.makedirs(args.output_dir)
......@@ -221,7 +293,8 @@ def evaluate(args, model, tokenizer, prefix=""):
# Note that DistributedSampler samples randomly
eval_sampler = SequentialSampler(dataset)
eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
eval_dataloader = DataLoader(
dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
# multi-gpu evaluate
if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
......@@ -284,21 +357,27 @@ def evaluate(args, model, tokenizer, prefix=""):
all_results.append(result)
evalTime = timeit.default_timer() - start_time
logger.info(" Evaluation done in total %f secs (%f sec per example)", evalTime, evalTime / len(dataset))
logger.info(" Evaluation done in total %f secs (%f sec per example)",
evalTime, evalTime / len(dataset))
# Compute predictions
output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))
output_prediction_file = os.path.join(
args.output_dir, "predictions_{}.json".format(prefix))
output_nbest_file = os.path.join(
args.output_dir, "nbest_predictions_{}.json".format(prefix))
if args.version_2_with_negative:
output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))
output_null_log_odds_file = os.path.join(
args.output_dir, "null_odds_{}.json".format(prefix))
else:
output_null_log_odds_file = None
# XLNet and XLM use a more complex post-processing procedure
if args.model_type in ['xlnet', 'xlm']:
start_n_top = model.config.start_n_top if hasattr(model, "config") else model.module.config.start_n_top
end_n_top = model.config.end_n_top if hasattr(model, "config") else model.module.config.end_n_top
start_n_top = model.config.start_n_top if hasattr(
model, "config") else model.module.config.start_n_top
end_n_top = model.config.end_n_top if hasattr(
model, "config") else model.module.config.end_n_top
predictions = compute_predictions_log_probs(examples, features, all_results, args.n_best_size,
args.max_answer_length, output_prediction_file,
......@@ -315,9 +394,11 @@ def evaluate(args, model, tokenizer, prefix=""):
results = squad_evaluate(examples, predictions)
return results
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
if args.local_rank not in [-1, 0] and not evaluate:
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
# Make sure only the first process in distributed training process the dataset, and the others will use the cache
torch.distributed.barrier()
# Load data features from cache or dataset file
input_dir = args.data_dir if args.data_dir else "."
......@@ -329,7 +410,8 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
# Init features and dataset from cache if it exists
if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
logger.info("Loading features from cached file %s", cached_features_file)
logger.info("Loading features from cached file %s",
cached_features_file)
features_and_dataset = torch.load(cached_features_file)
features, dataset = features_and_dataset["features"], features_and_dataset["dataset"]
else:
......@@ -339,16 +421,18 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
try:
import tensorflow_datasets as tfds
except ImportError:
raise ImportError("If not data_dir is specified, tensorflow_datasets needs to be installed.")
raise ImportError(
"If not data_dir is specified, tensorflow_datasets needs to be installed.")
if args.version_2_with_negative:
logger.warn("tensorflow_datasets does not handle version 2 of SQuAD.")
logger.warn(
"tensorflow_datasets does not handle version 2 of SQuAD.")
tfds_examples = tfds.load("squad")
examples = SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=evaluate)
examples = SquadV1Processor().get_examples_from_dataset(
tfds_examples, evaluate=evaluate)
else:
processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
if evaluate:
examples = processor.get_dev_examples(args.data_dir, filename=args.predict_file)
else:
......@@ -366,11 +450,14 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
)
if args.local_rank in [-1, 0]:
logger.info("Saving features into cached file %s", cached_features_file)
torch.save({"features": features, "dataset": dataset}, cached_features_file)
logger.info("Saving features into cached file %s",
cached_features_file)
torch.save({"features": features, "dataset": dataset},
cached_features_file)
if args.local_rank == 0 and not evaluate:
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
# Make sure only the first process in distributed training process the dataset, and the others will use the cache
torch.distributed.barrier()
if output_examples:
return dataset, examples, features
......@@ -380,7 +467,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
def main():
parser = argparse.ArgumentParser()
## Required parameters
# Required parameters
parser.add_argument("--model_type", default=None, type=str, required=True,
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
......@@ -388,7 +475,7 @@ def main():
parser.add_argument("--output_dir", default=None, type=str, required=True,
help="The output directory where the model checkpoints and predictions will be written.")
## Other parameters
# Other parameters
parser.add_argument("--data_dir", default=None, type=str,
help="The input data dir. Should contain the .json files for the task." +
"If no data dir or train/predict files are specified, will run with tensorflow_datasets.")
......@@ -485,19 +572,22 @@ def main():
args = parser.parse_args()
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
raise ValueError(
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
# Setup distant debugging if needed
if args.server_ip and args.server_port:
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
import ptvsd
print("Waiting for debugger attach")
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
ptvsd.enable_attach(
address=(args.server_ip, args.server_port), redirect_output=True)
ptvsd.wait_for_attach()
# Setup CUDA, GPU & distributed training
if args.local_rank == -1 or args.no_cuda:
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
device = torch.device(
"cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
args.n_gpu = torch.cuda.device_count()
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
torch.cuda.set_device(args.local_rank)
......@@ -507,9 +597,9 @@ def main():
args.device = device
# Setup logging
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
......@@ -518,7 +608,8 @@ def main():
# Load pretrained model and tokenizer
if args.local_rank not in [-1, 0]:
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
# Make sure only the first process in distributed training will download model & vocab
torch.distributed.barrier()
args.model_type = args.model_type.lower()
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
......@@ -528,12 +619,14 @@ def main():
do_lower_case=args.do_lower_case,
cache_dir=args.cache_dir if args.cache_dir else None)
model = model_class.from_pretrained(args.model_name_or_path,
from_tf=bool('.ckpt' in args.model_name_or_path),
from_tf=bool(
'.ckpt' in args.model_name_or_path),
config=config,
cache_dir=args.cache_dir if args.cache_dir else None)
if args.local_rank == 0:
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
# Make sure only the first process in distributed training will download model & vocab
torch.distributed.barrier()
model.to(args.device)
......@@ -547,14 +640,16 @@ def main():
import apex
apex.amp.register_half_function(torch, 'einsum')
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
raise ImportError(
"Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
# Training
if args.do_train:
train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)
train_dataset = load_and_cache_examples(
args, tokenizer, evaluate=False, output_examples=False)
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
logger.info(" global_step = %s, average loss = %s",
global_step, tr_loss)
# Save the trained model and the tokenizer
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
......@@ -565,7 +660,8 @@ def main():
logger.info("Saving model checkpoint to %s", args.output_dir)
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
# Take care of distributed/parallel training
model_to_save = model.module if hasattr(model, 'module') else model
model_to_save.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
......@@ -573,15 +669,15 @@ def main():
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
# Load a trained model and vocabulary that you have fine-tuned
model = model_class.from_pretrained(args.output_dir, force_download=True)
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
model = model_class.from_pretrained(
args.output_dir, force_download=True)
tokenizer = tokenizer_class.from_pretrained(
args.output_dir, do_lower_case=args.do_lower_case)
model.to(args.device)
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
results = {}
if args.do_eval and args.local_rank in [-1, 0]:
if args.do_train:
logger.info("Loading checkpoints saved during training for evaluation")
checkpoints = [args.output_dir]
......@@ -596,14 +692,17 @@ def main():
for checkpoint in checkpoints:
# Reload the model
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
model = model_class.from_pretrained(checkpoint, force_download=True)
global_step = checkpoint.split(
'-')[-1] if len(checkpoints) > 1 else ""
model = model_class.from_pretrained(
checkpoint, force_download=True)
model.to(args.device)
# Evaluate
result = evaluate(args, model, tokenizer, prefix=global_step)
result = dict((k + ('_{}'.format(global_step) if global_step else ''), v) for k, v in result.items())
result = dict((k + ('_{}'.format(global_step) if global_step else ''), v)
for k, v in result.items())
results.update(result)
logger.info("Results: {}".format(results))
......
......@@ -92,6 +92,13 @@ def train(args, train_dataset, model, tokenizer):
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
# Check if saved optimizer or scheduler states exist
if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')) and os.path.isfile(os.path.join(args.model_name_or_path, 'scheduler.pt')):
# Load in optimizer and scheduler states
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'optimizer.pt')))
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'scheduler.pt')))
if args.fp16:
try:
from apex import amp
......@@ -120,13 +127,32 @@ def train(args, train_dataset, model, tokenizer):
logger.info(" Total optimization steps = %d", t_total)
global_step = 0
epochs_trained = 0
steps_trained_in_current_epoch = 0
# Check if continuing training from a checkpoint
if os.path.exists(args.model_name_or_path):
# set global_step to gobal_step of last saved checkpoint from model path
global_step = int(args.model_name_or_path.split('-')[-1].split('/')[0])
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(" Continuing training from epoch %d", epochs_trained)
logger.info(" Continuing training from global step %d", global_step)
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
tr_loss, logging_loss = 0.0, 0.0
model.zero_grad()
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
set_seed(args) # Added here for reproductibility (even between python 2 and 3)
for _ in train_iterator:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
for step, batch in enumerate(epoch_iterator):
# Skip past any already trained steps if resuming training
if steps_trained_in_current_epoch > 0:
steps_trained_in_current_epoch -= 1
continue
model.train()
batch = tuple(t.to(args.device) for t in batch)
inputs = {'input_ids': batch[0],
......@@ -177,9 +203,15 @@ def train(args, train_dataset, model, tokenizer):
os.makedirs(output_dir)
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
logger.info("Saving model checkpoint to %s", output_dir)
torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt'))
torch.save(scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt'))
logger.info("Saving optimizer and scheduler states to %s", output_dir)
if args.max_steps > 0 and global_step > args.max_steps:
epoch_iterator.close()
break
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment