Commit fdc05cd6 authored by Bilal Khan's avatar Bilal Khan
Browse files

Update run_squad to save optimizer and scheduler states, then resume training from a checkpoint

parent 854ec578
...@@ -27,7 +27,8 @@ import glob ...@@ -27,7 +27,8 @@ import glob
import timeit import timeit
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, TensorDataset) from torch.utils.data import (
DataLoader, RandomSampler, SequentialSampler, TensorDataset)
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
try: try:
...@@ -38,21 +39,21 @@ except: ...@@ -38,21 +39,21 @@ except:
from tqdm import tqdm, trange from tqdm import tqdm, trange
from transformers import (WEIGHTS_NAME, BertConfig, from transformers import (WEIGHTS_NAME, BertConfig,
BertForQuestionAnswering, BertTokenizer, BertForQuestionAnswering, BertTokenizer,
XLMConfig, XLMForQuestionAnswering, XLMConfig, XLMForQuestionAnswering,
XLMTokenizer, XLNetConfig, XLMTokenizer, XLNetConfig,
XLNetForQuestionAnswering, XLNetForQuestionAnswering,
XLNetTokenizer, XLNetTokenizer,
DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer, DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer,
AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer, AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer,
XLMConfig, XLMForQuestionAnswering, XLMTokenizer, XLMConfig, XLMForQuestionAnswering, XLMTokenizer,
) )
from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_examples_to_features from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_examples_to_features
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \ ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys())
for conf in (BertConfig, XLNetConfig, XLMConfig)), ()) for conf in (BertConfig, XLNetConfig, XLMConfig)), ())
MODEL_CLASSES = { MODEL_CLASSES = {
...@@ -64,6 +65,7 @@ MODEL_CLASSES = { ...@@ -64,6 +65,7 @@ MODEL_CLASSES = {
'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer) 'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer)
} }
def set_seed(args): def set_seed(args):
random.seed(args.seed) random.seed(args.seed)
np.random.seed(args.seed) np.random.seed(args.seed)
...@@ -71,40 +73,60 @@ def set_seed(args): ...@@ -71,40 +73,60 @@ def set_seed(args):
if args.n_gpu > 0: if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed) torch.cuda.manual_seed_all(args.seed)
def to_list(tensor): def to_list(tensor):
return tensor.detach().cpu().tolist() return tensor.detach().cpu().tolist()
def train(args, train_dataset, model, tokenizer): def train(args, train_dataset, model, tokenizer):
""" Train the model """ """ Train the model """
if args.local_rank in [-1, 0]: if args.local_rank in [-1, 0]:
tb_writer = SummaryWriter() tb_writer = SummaryWriter()
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) train_sampler = RandomSampler(
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
train_dataloader = DataLoader(
train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
if args.max_steps > 0: if args.max_steps > 0:
t_total = args.max_steps t_total = args.max_steps
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 args.num_train_epochs = args.max_steps // (
len(train_dataloader) // args.gradient_accumulation_steps) + 1
else: else:
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs t_total = len(
train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
# Prepare optimizer and schedule (linear warmup and decay) # Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'LayerNorm.weight'] no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [ optimizer_grouped_parameters = [
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, {'params': [p for n, p in model.named_parameters() if not any(
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
{'params': [p for n, p in model.named_parameters() if any(
nd in n for nd in no_decay)], 'weight_decay': 0.0}
] ]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) optimizer = AdamW(optimizer_grouped_parameters,
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = get_linear_schedule_with_warmup(
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
# Check if saved optimizer or scheduler states exist
if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')) and os.path.isfile(os.path.join(args.model_name_or_path, 'scheduler.pt')):
# Load in optimizer and scheduler states
optimizer.load_state_dict(torch.load(
os.path.join(args.model_name_or_path, 'optimizer.pt')))
scheduler.load_state_dict(torch.load(
os.path.join(args.model_name_or_path, 'scheduler.pt')))
if args.fp16: if args.fp16:
try: try:
from apex import amp from apex import amp
except ImportError: except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") raise ImportError(
"Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
model, optimizer = amp.initialize(
model, optimizer, opt_level=args.fp16_opt_level)
# multi-gpu training (should be after apex fp16 initialization) # multi-gpu training (should be after apex fp16 initialization)
if args.n_gpu > 1: if args.n_gpu > 1:
...@@ -120,21 +142,50 @@ def train(args, train_dataset, model, tokenizer): ...@@ -120,21 +142,50 @@ def train(args, train_dataset, model, tokenizer):
logger.info("***** Running training *****") logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_dataset)) logger.info(" Num examples = %d", len(train_dataset))
logger.info(" Num Epochs = %d", args.num_train_epochs) logger.info(" Num Epochs = %d", args.num_train_epochs)
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) logger.info(" Instantaneous batch size per GPU = %d",
args.per_gpu_train_batch_size)
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1)) args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) logger.info(" Gradient Accumulation steps = %d",
args.gradient_accumulation_steps)
logger.info(" Total optimization steps = %d", t_total) logger.info(" Total optimization steps = %d", t_total)
global_step = 1 global_step = 1
epochs_trained = 0
steps_trained_in_current_epoch = 0
# Check if continuing training from a checkpoint
if os.path.exists(args.model_name_or_path):
# set global_step to gobal_step of last saved checkpoint from model path
global_step = int(args.model_name_or_path.split('-')[-1].split('/')[0])
epochs_trained = global_step // (len(train_dataloader) //
args.gradient_accumulation_steps)
steps_trained_in_current_epoch = global_step % (
len(train_dataloader) // args.gradient_accumulation_steps)
logger.info(
" Continuing training from checkpoint, will skip to saved global_step")
logger.info(" Continuing training from epoch %d", epochs_trained)
logger.info(" Continuing training from global step %d", global_step)
logger.info(" Will skip the first %d steps in the first epoch",
steps_trained_in_current_epoch)
tr_loss, logging_loss = 0.0, 0.0 tr_loss, logging_loss = 0.0, 0.0
model.zero_grad() model.zero_grad()
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]) train_iterator = trange(epochs_trained, int(
set_seed(args) # Added here for reproductibility (even between python 2 and 3) args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
# Added here for reproductibility (even between python 2 and 3)
set_seed(args)
for _ in train_iterator: for _ in train_iterator:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) epoch_iterator = tqdm(train_dataloader, desc="Iteration",
disable=args.local_rank not in [-1, 0])
for step, batch in enumerate(epoch_iterator): for step, batch in enumerate(epoch_iterator):
# Skip past any already trained steps if resuming training
if steps_trained_in_current_epoch > 0:
steps_trained_in_current_epoch -= 1
continue
model.train() model.train()
batch = tuple(t.to(args.device) for t in batch) batch = tuple(t.to(args.device) for t in batch)
...@@ -152,10 +203,11 @@ def train(args, train_dataset, model, tokenizer): ...@@ -152,10 +203,11 @@ def train(args, train_dataset, model, tokenizer):
inputs.update({'cls_index': batch[5], 'p_mask': batch[6]}) inputs.update({'cls_index': batch[5], 'p_mask': batch[6]})
outputs = model(**inputs) outputs = model(**inputs)
loss = outputs[0] # model outputs are always tuple in transformers (see doc) # model outputs are always tuple in transformers (see doc)
loss = outputs[0]
if args.n_gpu > 1: if args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training loss = loss.mean() # mean() to average on multi-gpu parallel (not distributed) training
if args.gradient_accumulation_steps > 1: if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps loss = loss / args.gradient_accumulation_steps
...@@ -168,9 +220,11 @@ def train(args, train_dataset, model, tokenizer): ...@@ -168,9 +220,11 @@ def train(args, train_dataset, model, tokenizer):
tr_loss += loss.item() tr_loss += loss.item()
if (step + 1) % args.gradient_accumulation_steps == 0: if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16: if args.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) torch.nn.utils.clip_grad_norm_(
amp.master_params(optimizer), args.max_grad_norm)
else: else:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) torch.nn.utils.clip_grad_norm_(
model.parameters(), args.max_grad_norm)
optimizer.step() optimizer.step()
scheduler.step() # Update learning rate schedule scheduler.step() # Update learning rate schedule
...@@ -179,24 +233,41 @@ def train(args, train_dataset, model, tokenizer): ...@@ -179,24 +233,41 @@ def train(args, train_dataset, model, tokenizer):
# Log metrics # Log metrics
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well # Only evaluate when single GPU otherwise metrics may not average well
if args.local_rank == -1 and args.evaluate_during_training:
results = evaluate(args, model, tokenizer) results = evaluate(args, model, tokenizer)
for key, value in results.items(): for key, value in results.items():
tb_writer.add_scalar('eval_{}'.format(key), value, global_step) tb_writer.add_scalar(
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step) 'eval_{}'.format(key), value, global_step)
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step) tb_writer.add_scalar(
'lr', scheduler.get_lr()[0], global_step)
tb_writer.add_scalar(
'loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
logging_loss = tr_loss logging_loss = tr_loss
# Save model checkpoint # Save model checkpoint
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0: if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step)) output_dir = os.path.join(
args.output_dir, 'checkpoint-{}'.format(global_step))
if not os.path.exists(output_dir): if not os.path.exists(output_dir):
os.makedirs(output_dir) os.makedirs(output_dir)
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training # Take care of distributed/parallel training
model_to_save = model.module if hasattr(
model, 'module') else model
model_to_save.save_pretrained(output_dir) model_to_save.save_pretrained(output_dir)
torch.save(args, os.path.join(output_dir, 'training_args.bin')) tokenizer.save_pretrained(output_dir)
torch.save(args, os.path.join(
output_dir, 'training_args.bin'))
logger.info("Saving model checkpoint to %s", output_dir) logger.info("Saving model checkpoint to %s", output_dir)
torch.save(optimizer.state_dict(), os.path.join(
output_dir, 'optimizer.pt'))
torch.save(scheduler.state_dict(), os.path.join(
output_dir, 'scheduler.pt'))
logger.info(
"Saving optimizer and scheduler states to %s", output_dir)
if args.max_steps > 0 and global_step > args.max_steps: if args.max_steps > 0 and global_step > args.max_steps:
epoch_iterator.close() epoch_iterator.close()
break break
...@@ -211,7 +282,8 @@ def train(args, train_dataset, model, tokenizer): ...@@ -211,7 +282,8 @@ def train(args, train_dataset, model, tokenizer):
def evaluate(args, model, tokenizer, prefix=""): def evaluate(args, model, tokenizer, prefix=""):
dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True) dataset, examples, features = load_and_cache_examples(
args, tokenizer, evaluate=True, output_examples=True)
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]: if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
os.makedirs(args.output_dir) os.makedirs(args.output_dir)
...@@ -220,7 +292,8 @@ def evaluate(args, model, tokenizer, prefix=""): ...@@ -220,7 +292,8 @@ def evaluate(args, model, tokenizer, prefix=""):
# Note that DistributedSampler samples randomly # Note that DistributedSampler samples randomly
eval_sampler = SequentialSampler(dataset) eval_sampler = SequentialSampler(dataset)
eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) eval_dataloader = DataLoader(
dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
# multi-gpu evaluate # multi-gpu evaluate
if args.n_gpu > 1: if args.n_gpu > 1:
...@@ -243,12 +316,13 @@ def evaluate(args, model, tokenizer, prefix=""): ...@@ -243,12 +316,13 @@ def evaluate(args, model, tokenizer, prefix=""):
'input_ids': batch[0], 'input_ids': batch[0],
'attention_mask': batch[1] 'attention_mask': batch[1]
} }
if args.model_type != 'distilbert': if args.model_type != 'distilbert':
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2] # XLM don't use segment_ids # XLM don't use segment_ids
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]
example_indices = batch[3] example_indices = batch[3]
# XLNet and XLM use more arguments for their predictions # XLNet and XLM use more arguments for their predictions
if args.model_type in ['xlnet', 'xlm']: if args.model_type in ['xlnet', 'xlm']:
inputs.update({'cls_index': batch[4], 'p_mask': batch[5]}) inputs.update({'cls_index': batch[4], 'p_mask': batch[5]})
...@@ -271,9 +345,9 @@ def evaluate(args, model, tokenizer, prefix=""): ...@@ -271,9 +345,9 @@ def evaluate(args, model, tokenizer, prefix=""):
cls_logits = output[4] cls_logits = output[4]
result = SquadResult( result = SquadResult(
unique_id, start_logits, end_logits, unique_id, start_logits, end_logits,
start_top_index=start_top_index, start_top_index=start_top_index,
end_top_index=end_top_index, end_top_index=end_top_index,
cls_logits=cls_logits cls_logits=cls_logits
) )
...@@ -286,40 +360,48 @@ def evaluate(args, model, tokenizer, prefix=""): ...@@ -286,40 +360,48 @@ def evaluate(args, model, tokenizer, prefix=""):
all_results.append(result) all_results.append(result)
evalTime = timeit.default_timer() - start_time evalTime = timeit.default_timer() - start_time
logger.info(" Evaluation done in total %f secs (%f sec per example)", evalTime, evalTime / len(dataset)) logger.info(" Evaluation done in total %f secs (%f sec per example)",
evalTime, evalTime / len(dataset))
# Compute predictions # Compute predictions
output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix)) output_prediction_file = os.path.join(
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix)) args.output_dir, "predictions_{}.json".format(prefix))
output_nbest_file = os.path.join(
args.output_dir, "nbest_predictions_{}.json".format(prefix))
if args.version_2_with_negative: if args.version_2_with_negative:
output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix)) output_null_log_odds_file = os.path.join(
args.output_dir, "null_odds_{}.json".format(prefix))
else: else:
output_null_log_odds_file = None output_null_log_odds_file = None
# XLNet and XLM use a more complex post-processing procedure # XLNet and XLM use a more complex post-processing procedure
if args.model_type in ['xlnet', 'xlm']: if args.model_type in ['xlnet', 'xlm']:
start_n_top = model.config.start_n_top if hasattr(model, "config") else model.module.config.start_n_top start_n_top = model.config.start_n_top if hasattr(
end_n_top = model.config.end_n_top if hasattr(model, "config") else model.module.config.end_n_top model, "config") else model.module.config.start_n_top
end_n_top = model.config.end_n_top if hasattr(
model, "config") else model.module.config.end_n_top
predictions = compute_predictions_log_probs(examples, features, all_results, args.n_best_size, predictions = compute_predictions_log_probs(examples, features, all_results, args.n_best_size,
args.max_answer_length, output_prediction_file, args.max_answer_length, output_prediction_file,
output_nbest_file, output_null_log_odds_file, output_nbest_file, output_null_log_odds_file,
start_n_top, end_n_top, start_n_top, end_n_top,
args.version_2_with_negative, tokenizer, args.verbose_logging) args.version_2_with_negative, tokenizer, args.verbose_logging)
else: else:
predictions = compute_predictions_logits(examples, features, all_results, args.n_best_size, predictions = compute_predictions_logits(examples, features, all_results, args.n_best_size,
args.max_answer_length, args.do_lower_case, output_prediction_file, args.max_answer_length, args.do_lower_case, output_prediction_file,
output_nbest_file, output_null_log_odds_file, args.verbose_logging, output_nbest_file, output_null_log_odds_file, args.verbose_logging,
args.version_2_with_negative, args.null_score_diff_threshold) args.version_2_with_negative, args.null_score_diff_threshold)
# Compute the F1 and exact scores. # Compute the F1 and exact scores.
results = squad_evaluate(examples, predictions) results = squad_evaluate(examples, predictions)
return results return results
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False): def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
if args.local_rank not in [-1, 0] and not evaluate: if args.local_rank not in [-1, 0] and not evaluate:
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache # Make sure only the first process in distributed training process the dataset, and the others will use the cache
torch.distributed.barrier()
# Load data features from cache or dataset file # Load data features from cache or dataset file
input_dir = args.data_dir if args.data_dir else "." input_dir = args.data_dir if args.data_dir else "."
...@@ -331,7 +413,8 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal ...@@ -331,7 +413,8 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
# Init features and dataset from cache if it exists # Init features and dataset from cache if it exists
if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples: if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
logger.info("Loading features from cached file %s", cached_features_file) logger.info("Loading features from cached file %s",
cached_features_file)
features_and_dataset = torch.load(cached_features_file) features_and_dataset = torch.load(cached_features_file)
features, dataset = features_and_dataset["features"], features_and_dataset["dataset"] features, dataset = features_and_dataset["features"], features_and_dataset["dataset"]
else: else:
...@@ -341,18 +424,22 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal ...@@ -341,18 +424,22 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
try: try:
import tensorflow_datasets as tfds import tensorflow_datasets as tfds
except ImportError: except ImportError:
raise ImportError("If not data_dir is specified, tensorflow_datasets needs to be installed.") raise ImportError(
"If not data_dir is specified, tensorflow_datasets needs to be installed.")
if args.version_2_with_negative: if args.version_2_with_negative:
logger.warn("tensorflow_datasets does not handle version 2 of SQuAD.") logger.warn(
"tensorflow_datasets does not handle version 2 of SQuAD.")
tfds_examples = tfds.load("squad") tfds_examples = tfds.load("squad")
examples = SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=evaluate) examples = SquadV1Processor().get_examples_from_dataset(
tfds_examples, evaluate=evaluate)
else: else:
processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor() processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir) examples = processor.get_dev_examples(
args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
features, dataset = squad_convert_examples_to_features( features, dataset = squad_convert_examples_to_features(
examples=examples, examples=examples,
tokenizer=tokenizer, tokenizer=tokenizer,
max_seq_length=args.max_seq_length, max_seq_length=args.max_seq_length,
...@@ -363,11 +450,14 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal ...@@ -363,11 +450,14 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
) )
if args.local_rank in [-1, 0]: if args.local_rank in [-1, 0]:
logger.info("Saving features into cached file %s", cached_features_file) logger.info("Saving features into cached file %s",
torch.save({"features": features, "dataset": dataset}, cached_features_file) cached_features_file)
torch.save({"features": features, "dataset": dataset},
cached_features_file)
if args.local_rank == 0 and not evaluate: if args.local_rank == 0 and not evaluate:
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache # Make sure only the first process in distributed training process the dataset, and the others will use the cache
torch.distributed.barrier()
if output_examples: if output_examples:
return dataset, examples, features return dataset, examples, features
...@@ -377,7 +467,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal ...@@ -377,7 +467,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
## Required parameters # Required parameters
parser.add_argument("--model_type", default=None, type=str, required=True, parser.add_argument("--model_type", default=None, type=str, required=True,
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())) help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
parser.add_argument("--model_name_or_path", default=None, type=str, required=True, parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
...@@ -385,7 +475,7 @@ def main(): ...@@ -385,7 +475,7 @@ def main():
parser.add_argument("--output_dir", default=None, type=str, required=True, parser.add_argument("--output_dir", default=None, type=str, required=True,
help="The output directory where the model checkpoints and predictions will be written.") help="The output directory where the model checkpoints and predictions will be written.")
## Other parameters # Other parameters
parser.add_argument("--data_dir", default=None, type=str, parser.add_argument("--data_dir", default=None, type=str,
help="The input data dir. Should contain the .json files for the task. If not specified, will run with tensorflow_datasets.") help="The input data dir. Should contain the .json files for the task. If not specified, will run with tensorflow_datasets.")
parser.add_argument("--config_name", default="", type=str, parser.add_argument("--config_name", default="", type=str,
...@@ -468,8 +558,10 @@ def main(): ...@@ -468,8 +558,10 @@ def main():
parser.add_argument('--fp16_opt_level', type=str, default='O1', parser.add_argument('--fp16_opt_level', type=str, default='O1',
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html") "See details at https://nvidia.github.io/apex/amp.html")
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.") parser.add_argument('--server_ip', type=str, default='',
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.") help="Can be used for distant debugging.")
parser.add_argument('--server_port', type=str, default='',
help="Can be used for distant debugging.")
args = parser.parse_args() args = parser.parse_args()
args.predict_file = os.path.join(args.output_dir, 'predictions_{}_{}.txt'.format( args.predict_file = os.path.join(args.output_dir, 'predictions_{}_{}.txt'.format(
...@@ -478,19 +570,22 @@ def main(): ...@@ -478,19 +570,22 @@ def main():
) )
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir: if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir)) raise ValueError(
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
# Setup distant debugging if needed # Setup distant debugging if needed
if args.server_ip and args.server_port: if args.server_ip and args.server_port:
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
import ptvsd import ptvsd
print("Waiting for debugger attach") print("Waiting for debugger attach")
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) ptvsd.enable_attach(
address=(args.server_ip, args.server_port), redirect_output=True)
ptvsd.wait_for_attach() ptvsd.wait_for_attach()
# Setup CUDA, GPU & distributed training # Setup CUDA, GPU & distributed training
if args.local_rank == -1 or args.no_cuda: if args.local_rank == -1 or args.no_cuda:
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") device = torch.device(
"cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
args.n_gpu = torch.cuda.device_count() args.n_gpu = torch.cuda.device_count()
else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
torch.cuda.set_device(args.local_rank) torch.cuda.set_device(args.local_rank)
...@@ -500,18 +595,19 @@ def main(): ...@@ -500,18 +595,19 @@ def main():
args.device = device args.device = device
# Setup logging # Setup logging
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S', datefmt='%m/%d/%Y %H:%M:%S',
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN) level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16) args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
# Set seed # Set seed
set_seed(args) set_seed(args)
# Load pretrained model and tokenizer # Load pretrained model and tokenizer
if args.local_rank not in [-1, 0]: if args.local_rank not in [-1, 0]:
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab # Make sure only the first process in distributed training will download model & vocab
torch.distributed.barrier()
args.model_type = args.model_type.lower() args.model_type = args.model_type.lower()
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
...@@ -521,12 +617,14 @@ def main(): ...@@ -521,12 +617,14 @@ def main():
do_lower_case=args.do_lower_case, do_lower_case=args.do_lower_case,
cache_dir=args.cache_dir if args.cache_dir else None) cache_dir=args.cache_dir if args.cache_dir else None)
model = model_class.from_pretrained(args.model_name_or_path, model = model_class.from_pretrained(args.model_name_or_path,
from_tf=bool('.ckpt' in args.model_name_or_path), from_tf=bool(
'.ckpt' in args.model_name_or_path),
config=config, config=config,
cache_dir=args.cache_dir if args.cache_dir else None) cache_dir=args.cache_dir if args.cache_dir else None)
if args.local_rank == 0: if args.local_rank == 0:
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab # Make sure only the first process in distributed training will download model & vocab
torch.distributed.barrier()
model.to(args.device) model.to(args.device)
...@@ -540,14 +638,16 @@ def main(): ...@@ -540,14 +638,16 @@ def main():
import apex import apex
apex.amp.register_half_function(torch, 'einsum') apex.amp.register_half_function(torch, 'einsum')
except ImportError: except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") raise ImportError(
"Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
# Training # Training
if args.do_train: if args.do_train:
train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False) train_dataset = load_and_cache_examples(
args, tokenizer, evaluate=False, output_examples=False)
global_step, tr_loss = train(args, train_dataset, model, tokenizer) global_step, tr_loss = train(args, train_dataset, model, tokenizer)
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) logger.info(" global_step = %s, average loss = %s",
global_step, tr_loss)
# Save the trained model and the tokenizer # Save the trained model and the tokenizer
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
...@@ -558,7 +658,8 @@ def main(): ...@@ -558,7 +658,8 @@ def main():
logger.info("Saving model checkpoint to %s", args.output_dir) logger.info("Saving model checkpoint to %s", args.output_dir)
# Save a trained model, configuration and tokenizer using `save_pretrained()`. # Save a trained model, configuration and tokenizer using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()` # They can then be reloaded using `from_pretrained()`
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training # Take care of distributed/parallel training
model_to_save = model.module if hasattr(model, 'module') else model
model_to_save.save_pretrained(args.output_dir) model_to_save.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir)
...@@ -566,31 +667,37 @@ def main(): ...@@ -566,31 +667,37 @@ def main():
torch.save(args, os.path.join(args.output_dir, 'training_args.bin')) torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
# Load a trained model and vocabulary that you have fine-tuned # Load a trained model and vocabulary that you have fine-tuned
model = model_class.from_pretrained(args.output_dir, force_download=True) model = model_class.from_pretrained(
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) args.output_dir, force_download=True)
tokenizer = tokenizer_class.from_pretrained(
args.output_dir, do_lower_case=args.do_lower_case)
model.to(args.device) model.to(args.device)
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
results = {} results = {}
if args.do_eval and args.local_rank in [-1, 0]: if args.do_eval and args.local_rank in [-1, 0]:
checkpoints = [args.output_dir] checkpoints = [args.output_dir]
if args.eval_all_checkpoints: if args.eval_all_checkpoints:
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))) checkpoints = list(os.path.dirname(c) for c in sorted(
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
logging.getLogger("transformers.modeling_utils").setLevel(
logging.WARN) # Reduce model loading logs
logger.info("Evaluate the following checkpoints: %s", checkpoints) logger.info("Evaluate the following checkpoints: %s", checkpoints)
for checkpoint in checkpoints: for checkpoint in checkpoints:
# Reload the model # Reload the model
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else "" global_step = checkpoint.split(
model = model_class.from_pretrained(checkpoint, force_download=True) '-')[-1] if len(checkpoints) > 1 else ""
model = model_class.from_pretrained(
checkpoint, force_download=True)
model.to(args.device) model.to(args.device)
# Evaluate # Evaluate
result = evaluate(args, model, tokenizer, prefix=global_step) result = evaluate(args, model, tokenizer, prefix=global_step)
result = dict((k + ('_{}'.format(global_step) if global_step else ''), v) for k, v in result.items()) result = dict((k + ('_{}'.format(global_step) if global_step else ''), v)
for k, v in result.items())
results.update(result) results.update(result)
logger.info("Results: {}".format(results)) logger.info("Results: {}".format(results))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment