Unverified commit 848aae49, authored by Thomas Wolf and committed by GitHub

Merge branch 'master' into python_2

parents 448937c0 82291514
@@ -121,5 +121,5 @@ dmypy.json
 # TF code
 tensorflow_code
-# models
+# Models
 models
\ No newline at end of file
@@ -53,14 +53,14 @@ python -m pytest -sv tests/
 This package comprises the following classes that can be imported in Python and are detailed in the [Doc](#doc) section of this readme:
 - Eight PyTorch models (`torch.nn.Module`) for Bert with pre-trained weights (in the [`modeling.py`](./pytorch_pretrained_bert/modeling.py) file):
-  - [`BertModel`](./pytorch_pretrained_bert/modeling.py#L537) - raw BERT Transformer model (**fully pre-trained**),
+  - [`BertModel`](./pytorch_pretrained_bert/modeling.py#L556) - raw BERT Transformer model (**fully pre-trained**),
-  - [`BertForMaskedLM`](./pytorch_pretrained_bert/modeling.py#L691) - BERT Transformer with the pre-trained masked language modeling head on top (**fully pre-trained**),
+  - [`BertForMaskedLM`](./pytorch_pretrained_bert/modeling.py#L710) - BERT Transformer with the pre-trained masked language modeling head on top (**fully pre-trained**),
-  - [`BertForNextSentencePrediction`](./pytorch_pretrained_bert/modeling.py#L752) - BERT Transformer with the pre-trained next sentence prediction classifier on top (**fully pre-trained**),
+  - [`BertForNextSentencePrediction`](./pytorch_pretrained_bert/modeling.py#L771) - BERT Transformer with the pre-trained next sentence prediction classifier on top (**fully pre-trained**),
-  - [`BertForPreTraining`](./pytorch_pretrained_bert/modeling.py#L620) - BERT Transformer with masked language modeling head and next sentence prediction classifier on top (**fully pre-trained**),
+  - [`BertForPreTraining`](./pytorch_pretrained_bert/modeling.py#L639) - BERT Transformer with masked language modeling head and next sentence prediction classifier on top (**fully pre-trained**),
-  - [`BertForSequenceClassification`](./pytorch_pretrained_bert/modeling.py#L814) - BERT Transformer with a sequence classification head on top (BERT Transformer is **pre-trained**, the sequence classification head **is only initialized and has to be trained**),
+  - [`BertForSequenceClassification`](./pytorch_pretrained_bert/modeling.py#L833) - BERT Transformer with a sequence classification head on top (BERT Transformer is **pre-trained**, the sequence classification head **is only initialized and has to be trained**),
-  - [`BertForMultipleChoice`](./pytorch_pretrained_bert/modeling.py#L880) - BERT Transformer with a multiple choice head on top (used for tasks like Swag) (BERT Transformer is **pre-trained**, the multiple choice classification head **is only initialized and has to be trained**),
+  - [`BertForMultipleChoice`](./pytorch_pretrained_bert/modeling.py#L899) - BERT Transformer with a multiple choice head on top (used for tasks like Swag) (BERT Transformer is **pre-trained**, the multiple choice classification head **is only initialized and has to be trained**),
-  - [`BertForTokenClassification`](./pytorch_pretrained_bert/modeling.py#L949) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**),
+  - [`BertForTokenClassification`](./pytorch_pretrained_bert/modeling.py#L969) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**),
-  - [`BertForQuestionAnswering`](./pytorch_pretrained_bert/modeling.py#L1015) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**).
+  - [`BertForQuestionAnswering`](./pytorch_pretrained_bert/modeling.py#L1034) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**).
 - Three PyTorch models (`torch.nn.Module`) for OpenAI with pre-trained weights (in the [`modeling_openai.py`](./pytorch_pretrained_bert/modeling_openai.py) file):
   - [`OpenAIGPTModel`](./pytorch_pretrained_bert/modeling_openai.py#L537) - raw OpenAI GPT Transformer model (**fully pre-trained**),
@@ -94,7 +94,7 @@ The repository further comprises:
 - [`run_classifier.py`](./examples/run_classifier.py) - Show how to fine-tune an instance of `BertForSequenceClassification` on GLUE's MRPC task,
 - [`run_squad.py`](./examples/run_squad.py) - Show how to fine-tune an instance of `BertForQuestionAnswering` on SQuAD v1.0 task.
 - [`run_swag.py`](./examples/run_swag.py) - Show how to fine-tune an instance of `BertForMultipleChoice` on Swag task.
-- [`run_lm_finetuning`](./examples/run_lm_finetuning.py) - Show how to fine-tune an instance of `BertForPretraining' on a target text corpus.
+- [`run_lm_finetuning.py`](./examples/run_lm_finetuning.py) - Show how to fine-tune an instance of `BertForPreTraining` on a target text corpus.
 These examples are detailed in the [Examples](#examples) section of this readme.
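For orientation, here is a minimal usage sketch of the classes listed above (not part of this commit; it assumes the published `bert-base-uncased` weights and the package's `from_pretrained`/`tokenize`/`convert_tokens_to_ids` API):

```python
import torch
from pytorch_pretrained_bert import BertTokenizer, BertForSequenceClassification

# Load the pre-trained tokenizer and a sequence-classification model whose
# classification head still has to be fine-tuned (see run_classifier.py).
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
model.eval()

tokens = ['[CLS]'] + tokenizer.tokenize("a quick smoke test") + ['[SEP]']
input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])

with torch.no_grad():
    logits = model(input_ids)  # shape [1, num_labels]; the head is untrained, so values are arbitrary
print(logits)
```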
......
...@@ -34,8 +34,8 @@ from tqdm import tqdm, trange ...@@ -34,8 +34,8 @@ from tqdm import tqdm, trange
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForSequenceClassification from pytorch_pretrained_bert.modeling import BertForSequenceClassification
from pytorch_pretrained_bert.optimization import BertAdam
from pytorch_pretrained_bert.tokenization import BertTokenizer from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S', datefmt = '%m/%d/%Y %H:%M:%S',
...@@ -299,11 +299,6 @@ def accuracy(out, labels): ...@@ -299,11 +299,6 @@ def accuracy(out, labels):
outputs = np.argmax(out, axis=1) outputs = np.argmax(out, axis=1)
return np.sum(outputs == labels) return np.sum(outputs == labels)
def warmup_linear(x, warmup=0.002):
if x < warmup:
return x/warmup
return 1.0 - x
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -419,7 +414,7 @@ def main(): ...@@ -419,7 +414,7 @@ def main():
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
args.gradient_accumulation_steps)) args.gradient_accumulation_steps))
args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps) args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
random.seed(args.seed) random.seed(args.seed)
np.random.seed(args.seed) np.random.seed(args.seed)
...@@ -447,11 +442,13 @@ def main(): ...@@ -447,11 +442,13 @@ def main():
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
train_examples = None train_examples = None
num_train_steps = None num_train_optimization_steps = None
if args.do_train: if args.do_train:
train_examples = processor.get_train_examples(args.data_dir) train_examples = processor.get_train_examples(args.data_dir)
num_train_steps = int( num_train_optimization_steps = int(
len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs) len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
if args.local_rank != -1:
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
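The renamed `num_train_optimization_steps` is the value later passed to `BertAdam` as `t_total`. A small arithmetic sketch, with made-up numbers (not taken from any dataset in this repo), of how gradient accumulation and the distributed world size enter the count:

```python
# Hypothetical values, for illustration only.
num_examples = 10000
train_batch_size = 32            # per-forward batch size (already divided by accumulation steps)
gradient_accumulation_steps = 2
num_train_epochs = 3
world_size = 4                   # torch.distributed.get_world_size() when local_rank != -1

optimizer_steps_per_epoch = num_examples // train_batch_size // gradient_accumulation_steps
num_train_optimization_steps = optimizer_steps_per_epoch * num_train_epochs
num_train_optimization_steps //= world_size   # each process performs only its share of updates

print(optimizer_steps_per_epoch, num_train_optimization_steps)  # 156 117
```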
# Prepare model # Prepare model
model = BertForSequenceClassification.from_pretrained(args.bert_model, model = BertForSequenceClassification.from_pretrained(args.bert_model,
...@@ -477,9 +474,6 @@ def main(): ...@@ -477,9 +474,6 @@ def main():
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
] ]
t_total = num_train_steps
if args.local_rank != -1:
t_total = t_total // torch.distributed.get_world_size()
if args.fp16: if args.fp16:
try: try:
from apex.optimizers import FP16_Optimizer from apex.optimizers import FP16_Optimizer
...@@ -500,7 +494,7 @@ def main(): ...@@ -500,7 +494,7 @@ def main():
optimizer = BertAdam(optimizer_grouped_parameters, optimizer = BertAdam(optimizer_grouped_parameters,
lr=args.learning_rate, lr=args.learning_rate,
warmup=args.warmup_proportion, warmup=args.warmup_proportion,
t_total=t_total) t_total=num_train_optimization_steps)
global_step = 0 global_step = 0
nb_tr_steps = 0 nb_tr_steps = 0
...@@ -511,7 +505,7 @@ def main(): ...@@ -511,7 +505,7 @@ def main():
logger.info("***** Running training *****") logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_examples)) logger.info(" Num examples = %d", len(train_examples))
logger.info(" Batch size = %d", args.train_batch_size) logger.info(" Batch size = %d", args.train_batch_size)
logger.info(" Num steps = %d", num_train_steps) logger.info(" Num steps = %d", num_train_optimization_steps)
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long) all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long) all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long) all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
...@@ -545,10 +539,12 @@ def main(): ...@@ -545,10 +539,12 @@ def main():
nb_tr_examples += input_ids.size(0) nb_tr_examples += input_ids.size(0)
nb_tr_steps += 1 nb_tr_steps += 1
if (step + 1) % args.gradient_accumulation_steps == 0: if (step + 1) % args.gradient_accumulation_steps == 0:
# modify learning rate with special warm up BERT uses if args.fp16:
lr_this_step = args.learning_rate * warmup_linear(global_step/t_total, args.warmup_proportion) # modify learning rate with special warm up BERT uses
for param_group in optimizer.param_groups: # if args.fp16 is False, BertAdam is used that handles this automatically
param_group['lr'] = lr_this_step lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion)
for param_group in optimizer.param_groups:
param_group['lr'] = lr_this_step
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
global_step += 1 global_step += 1
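The new `if args.fp16:` guard above reflects the two optimizer paths: the default path uses `BertAdam`, which applies the warmup-linear schedule internally from its `warmup` and `t_total` arguments, while the fp16 path uses apex's `FusedAdam` wrapped in `FP16_Optimizer`, which does no scheduling of its own, so the learning rate must be set by hand at every optimization step. A sketch of that manual update, using the `warmup_linear` now imported from `pytorch_pretrained_bert.optimization`:

```python
from pytorch_pretrained_bert.optimization import warmup_linear

def set_fp16_warmup_lr(optimizer, global_step, num_train_optimization_steps,
                       learning_rate, warmup_proportion):
    # Only needed on the FusedAdam/FP16_Optimizer path; BertAdam schedules internally.
    lr_this_step = learning_rate * warmup_linear(
        global_step / num_train_optimization_steps, warmup_proportion)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr_this_step
    return lr_this_step
```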
......
...@@ -30,8 +30,11 @@ from torch.utils.data.distributed import DistributedSampler ...@@ -30,8 +30,11 @@ from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange from tqdm import tqdm, trange
from pytorch_pretrained_bert.modeling import BertForPreTraining from pytorch_pretrained_bert.modeling import BertForPreTraining
from pytorch_pretrained_bert.optimization import BertAdam
from pytorch_pretrained_bert.tokenization import BertTokenizer from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
from torch.utils.data import Dataset
import random
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S', datefmt='%m/%d/%Y %H:%M:%S',
...@@ -39,12 +42,6 @@ logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message ...@@ -39,12 +42,6 @@ logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def warmup_linear(x, warmup=0.002):
if x < warmup:
return x/warmup
return 1.0 - x
class BERTDataset(Dataset): class BERTDataset(Dataset):
def __init__(self, corpus_path, tokenizer, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True): def __init__(self, corpus_path, tokenizer, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True):
self.vocab = tokenizer.vocab self.vocab = tokenizer.vocab
...@@ -136,11 +133,11 @@ class BERTDataset(Dataset): ...@@ -136,11 +133,11 @@ class BERTDataset(Dataset):
# transform sample to features # transform sample to features
cur_features = convert_example_to_features(cur_example, self.seq_len, self.tokenizer) cur_features = convert_example_to_features(cur_example, self.seq_len, self.tokenizer)
cur_tensors = {"input_ids": torch.tensor(cur_features.input_ids), cur_tensors = (torch.tensor(cur_features.input_ids),
"input_mask": torch.tensor(cur_features.input_mask), torch.tensor(cur_features.input_mask),
"segment_ids": torch.tensor(cur_features.segment_ids), torch.tensor(cur_features.segment_ids),
"lm_label_ids": torch.tensor(cur_features.lm_label_ids), torch.tensor(cur_features.lm_label_ids),
"is_next": torch.tensor(cur_features.is_next)} torch.tensor(cur_features.is_next))
return cur_tensors return cur_tensors
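Returning a plain tuple from `__getitem__` (instead of a dict) lets `DataLoader`'s default collate function stack each position into a batch tensor, which is why the training loop below can simply do `tuple(t.to(device) for t in batch)` and unpack positionally. A toy sketch of the pattern (standalone, not using `BERTDataset`):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        # A tuple of tensors; default_collate stacks each position across the batch.
        input_ids = torch.randint(0, 100, (16,))
        input_mask = torch.ones(16, dtype=torch.long)
        label = torch.tensor(idx % 2)
        return input_ids, input_mask, label

device = torch.device('cpu')
for batch in DataLoader(ToyDataset(), batch_size=4):
    batch = tuple(t.to(device) for t in batch)   # mirrors the updated training loop
    input_ids, input_mask, label = batch         # fixed positional unpacking
    print(input_ids.shape, input_mask.shape, label.shape)
```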
...@@ -325,8 +322,8 @@ def convert_example_to_features(example, max_seq_length, tokenizer): ...@@ -325,8 +322,8 @@ def convert_example_to_features(example, max_seq_length, tokenizer):
# Account for [CLS], [SEP], [SEP] with "- 3" # Account for [CLS], [SEP], [SEP] with "- 3"
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
t1_random, t1_label = random_word(tokens_a, tokenizer) tokens_a, t1_label = random_word(tokens_a, tokenizer)
t2_random, t2_label = random_word(tokens_b, tokenizer) tokens_b, t2_label = random_word(tokens_b, tokenizer)
# concatenate lm labels and account for CLS, SEP, SEP # concatenate lm labels and account for CLS, SEP, SEP
lm_label_ids = ([-1] + t1_label + [-1] + t2_label + [-1]) lm_label_ids = ([-1] + t1_label + [-1] + t2_label + [-1])
...@@ -459,6 +456,9 @@ def main(): ...@@ -459,6 +456,9 @@ def main():
parser.add_argument("--on_memory", parser.add_argument("--on_memory",
action='store_true', action='store_true',
help="Whether to load train samples into memory or use disk") help="Whether to load train samples into memory or use disk")
parser.add_argument("--do_lower_case",
action='store_true',
help="Whether to lower case the input text. True for uncased models, False for cased models.")
parser.add_argument("--local_rank", parser.add_argument("--local_rank",
type=int, type=int,
default=-1, default=-1,
...@@ -498,7 +498,7 @@ def main(): ...@@ -498,7 +498,7 @@ def main():
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
args.gradient_accumulation_steps)) args.gradient_accumulation_steps))
args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps) args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
random.seed(args.seed) random.seed(args.seed)
np.random.seed(args.seed) np.random.seed(args.seed)
...@@ -517,13 +517,15 @@ def main(): ...@@ -517,13 +517,15 @@ def main():
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
#train_examples = None #train_examples = None
num_train_steps = None num_train_optimization_steps = None
if args.do_train: if args.do_train:
print("Loading Train Dataset", args.train_file) print("Loading Train Dataset", args.train_file)
train_dataset = BERTDataset(args.train_file, tokenizer, seq_len=args.max_seq_length, train_dataset = BERTDataset(args.train_file, tokenizer, seq_len=args.max_seq_length,
corpus_lines=None, on_memory=args.on_memory) corpus_lines=None, on_memory=args.on_memory)
num_train_steps = int( num_train_optimization_steps = int(
len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs) len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
if args.local_rank != -1:
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
# Prepare model # Prepare model
model = BertForPreTraining.from_pretrained(args.bert_model) model = BertForPreTraining.from_pretrained(args.bert_model)
...@@ -546,6 +548,7 @@ def main(): ...@@ -546,6 +548,7 @@ def main():
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
] ]
if args.fp16: if args.fp16:
try: try:
from apex.optimizers import FP16_Optimizer from apex.optimizers import FP16_Optimizer
...@@ -566,14 +569,14 @@ def main(): ...@@ -566,14 +569,14 @@ def main():
optimizer = BertAdam(optimizer_grouped_parameters, optimizer = BertAdam(optimizer_grouped_parameters,
lr=args.learning_rate, lr=args.learning_rate,
warmup=args.warmup_proportion, warmup=args.warmup_proportion,
t_total=num_train_steps) t_total=num_train_optimization_steps)
global_step = 0 global_step = 0
if args.do_train: if args.do_train:
logger.info("***** Running training *****") logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_dataset)) logger.info(" Num examples = %d", len(train_dataset))
logger.info(" Batch size = %d", args.train_batch_size) logger.info(" Batch size = %d", args.train_batch_size)
logger.info(" Num steps = %d", num_train_steps) logger.info(" Num steps = %d", num_train_optimization_steps)
if args.local_rank == -1: if args.local_rank == -1:
train_sampler = RandomSampler(train_dataset) train_sampler = RandomSampler(train_dataset)
...@@ -588,7 +591,7 @@ def main(): ...@@ -588,7 +591,7 @@ def main():
tr_loss = 0 tr_loss = 0
nb_tr_examples, nb_tr_steps = 0, 0 nb_tr_examples, nb_tr_steps = 0, 0
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
batch = tuple(t.to(device) for t in batch.values()) batch = tuple(t.to(device) for t in batch)
input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch
loss = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next) loss = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next)
if n_gpu > 1: if n_gpu > 1:
...@@ -603,20 +606,22 @@ def main(): ...@@ -603,20 +606,22 @@ def main():
nb_tr_examples += input_ids.size(0) nb_tr_examples += input_ids.size(0)
nb_tr_steps += 1 nb_tr_steps += 1
if (step + 1) % args.gradient_accumulation_steps == 0: if (step + 1) % args.gradient_accumulation_steps == 0:
# modify learning rate with special warm up BERT uses if args.fp16:
lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_steps, args.warmup_proportion) # modify learning rate with special warm up BERT uses
for param_group in optimizer.param_groups: # if args.fp16 is False, BertAdam is used that handles this automatically
param_group['lr'] = lr_this_step lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion)
for param_group in optimizer.param_groups:
param_group['lr'] = lr_this_step
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
global_step += 1 global_step += 1
# Save a trained model
logger.info("** ** * Saving fine - tuned model ** ** * ") logger.info("** ** * Saving fine - tuned model ** ** * ")
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin") output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
if n_gpu > 1: if args.do_train:
torch.save(model.module.bert.state_dict(), output_model_file) torch.save(model_to_save.state_dict(), output_model_file)
else:
torch.save(model.bert.state_dict(), output_model_file)
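The save path now writes the full `state_dict()` of the unwrapped model (`model_to_save` strips a `DataParallel`/`DistributedDataParallel` `.module` wrapper) rather than only the inner `.bert` weights, and the checkpoint is reloaded through the `state_dict` argument of `from_pretrained`, as run_squad.py does below. A short sketch of that round trip, with a hypothetical output directory:

```python
import os
import torch
from pytorch_pretrained_bert.modeling import BertForPreTraining

output_dir = "/tmp/finetuned-bert"                      # hypothetical path
os.makedirs(output_dir, exist_ok=True)
output_model_file = os.path.join(output_dir, "pytorch_model.bin")

model = BertForPreTraining.from_pretrained('bert-base-uncased')  # stands in for a fine-tuned model

# Save: unwrap any (Distributed)DataParallel container before taking state_dict().
model_to_save = model.module if hasattr(model, 'module') else model
torch.save(model_to_save.state_dict(), output_model_file)

# Reload: build the architecture from the named checkpoint, then override the weights.
state_dict = torch.load(output_model_file)
model = BertForPreTraining.from_pretrained('bert-base-uncased', state_dict=state_dict)
```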
def _truncate_seq_pair(tokens_a, tokens_b, max_length): def _truncate_seq_pair(tokens_a, tokens_b, max_length):
......
...@@ -36,7 +36,7 @@ from tqdm import tqdm, trange ...@@ -36,7 +36,7 @@ from tqdm import tqdm, trange
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
from pytorch_pretrained_bert.optimization import BertAdam from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
from pytorch_pretrained_bert.tokenization import (BasicTokenizer, from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
BertTokenizer, BertTokenizer,
whitespace_tokenize) whitespace_tokenize)
...@@ -53,7 +53,10 @@ logger = logging.getLogger(__name__) ...@@ -53,7 +53,10 @@ logger = logging.getLogger(__name__)
class SquadExample(object): class SquadExample(object):
"""A single training/test example for the Squad dataset.""" """
A single training/test example for the Squad dataset.
For examples without an answer, the start and end position are -1.
"""
def __init__(self, def __init__(self,
qas_id, qas_id,
...@@ -61,13 +64,15 @@ class SquadExample(object): ...@@ -61,13 +64,15 @@ class SquadExample(object):
doc_tokens, doc_tokens,
orig_answer_text=None, orig_answer_text=None,
start_position=None, start_position=None,
end_position=None): end_position=None,
is_impossible=None):
self.qas_id = qas_id self.qas_id = qas_id
self.question_text = question_text self.question_text = question_text
self.doc_tokens = doc_tokens self.doc_tokens = doc_tokens
self.orig_answer_text = orig_answer_text self.orig_answer_text = orig_answer_text
self.start_position = start_position self.start_position = start_position
self.end_position = end_position self.end_position = end_position
self.is_impossible = is_impossible
def __str__(self): def __str__(self):
return self.__repr__() return self.__repr__()
...@@ -82,6 +87,8 @@ class SquadExample(object): ...@@ -82,6 +87,8 @@ class SquadExample(object):
s += ", start_position: %d" % (self.start_position) s += ", start_position: %d" % (self.start_position)
if self.start_position: if self.start_position:
s += ", end_position: %d" % (self.end_position) s += ", end_position: %d" % (self.end_position)
if self.start_position:
s += ", is_impossible: %r" % (self.is_impossible)
return s return s
...@@ -99,7 +106,8 @@ class InputFeatures(object): ...@@ -99,7 +106,8 @@ class InputFeatures(object):
input_mask, input_mask,
segment_ids, segment_ids,
start_position=None, start_position=None,
end_position=None): end_position=None,
is_impossible=None):
self.unique_id = unique_id self.unique_id = unique_id
self.example_index = example_index self.example_index = example_index
self.doc_span_index = doc_span_index self.doc_span_index = doc_span_index
...@@ -111,9 +119,10 @@ class InputFeatures(object): ...@@ -111,9 +119,10 @@ class InputFeatures(object):
self.segment_ids = segment_ids self.segment_ids = segment_ids
self.start_position = start_position self.start_position = start_position
self.end_position = end_position self.end_position = end_position
self.is_impossible = is_impossible
def read_squad_examples(input_file, is_training): def read_squad_examples(input_file, is_training, version_2_with_negative):
"""Read a SQuAD json file into a list of SquadExample.""" """Read a SQuAD json file into a list of SquadExample."""
with open(input_file, "r", encoding='utf-8') as reader: with open(input_file, "r", encoding='utf-8') as reader:
input_data = json.load(reader)["data"] input_data = json.load(reader)["data"]
...@@ -147,29 +156,37 @@ def read_squad_examples(input_file, is_training): ...@@ -147,29 +156,37 @@ def read_squad_examples(input_file, is_training):
start_position = None start_position = None
end_position = None end_position = None
orig_answer_text = None orig_answer_text = None
is_impossible = False
if is_training: if is_training:
if len(qa["answers"]) != 1: if version_2_with_negative:
is_impossible = qa["is_impossible"]
if (len(qa["answers"]) != 1) and (not is_impossible):
raise ValueError( raise ValueError(
"For training, each question should have exactly 1 answer.") "For training, each question should have exactly 1 answer.")
answer = qa["answers"][0] if not is_impossible:
orig_answer_text = answer["text"] answer = qa["answers"][0]
answer_offset = answer["answer_start"] orig_answer_text = answer["text"]
answer_length = len(orig_answer_text) answer_offset = answer["answer_start"]
start_position = char_to_word_offset[answer_offset] answer_length = len(orig_answer_text)
end_position = char_to_word_offset[answer_offset + answer_length - 1] start_position = char_to_word_offset[answer_offset]
# Only add answers where the text can be exactly recovered from the end_position = char_to_word_offset[answer_offset + answer_length - 1]
# document. If this CAN'T happen it's likely due to weird Unicode # Only add answers where the text can be exactly recovered from the
# stuff so we will just skip the example. # document. If this CAN'T happen it's likely due to weird Unicode
# # stuff so we will just skip the example.
# Note that this means for training mode, every example is NOT #
# guaranteed to be preserved. # Note that this means for training mode, every example is NOT
actual_text = " ".join(doc_tokens[start_position:(end_position + 1)]) # guaranteed to be preserved.
cleaned_answer_text = " ".join( actual_text = " ".join(doc_tokens[start_position:(end_position + 1)])
whitespace_tokenize(orig_answer_text)) cleaned_answer_text = " ".join(
if actual_text.find(cleaned_answer_text) == -1: whitespace_tokenize(orig_answer_text))
logger.warning("Could not find answer: '%s' vs. '%s'", if actual_text.find(cleaned_answer_text) == -1:
logger.warning("Could not find answer: '%s' vs. '%s'",
actual_text, cleaned_answer_text) actual_text, cleaned_answer_text)
continue continue
else:
start_position = -1
end_position = -1
orig_answer_text = ""
example = SquadExample( example = SquadExample(
qas_id=qas_id, qas_id=qas_id,
...@@ -177,7 +194,8 @@ def read_squad_examples(input_file, is_training): ...@@ -177,7 +194,8 @@ def read_squad_examples(input_file, is_training):
doc_tokens=doc_tokens, doc_tokens=doc_tokens,
orig_answer_text=orig_answer_text, orig_answer_text=orig_answer_text,
start_position=start_position, start_position=start_position,
end_position=end_position) end_position=end_position,
is_impossible=is_impossible)
examples.append(example) examples.append(example)
return examples return examples
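With `--version_2_with_negative` (SQuAD 2.0), unanswerable questions carry `"is_impossible": true` and an empty answer list, and the reader above records them with an empty answer text and `start_position = end_position = -1`. A sketch of that branching on a made-up JSON fragment (field values are hypothetical):

```python
import json

qa_entries = json.loads("""
[
  {"id": "q1", "is_impossible": false,
   "answers": [{"text": "Denver Broncos", "answer_start": 177}]},
  {"id": "q2", "is_impossible": true, "answers": []}
]
""")

for qa in qa_entries:
    is_impossible = qa["is_impossible"]         # only read when version_2_with_negative is set
    if not is_impossible:
        answer = qa["answers"][0]
        orig_answer_text = answer["text"]
        answer_offset = answer["answer_start"]  # char offset; the real code maps it to word
                                                # positions via char_to_word_offset
    else:
        orig_answer_text = ""
        answer_offset = -1                      # unanswerable: positions become -1
    print(qa["id"], is_impossible, answer_offset, repr(orig_answer_text))
```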
...@@ -207,7 +225,10 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, ...@@ -207,7 +225,10 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
tok_start_position = None tok_start_position = None
tok_end_position = None tok_end_position = None
if is_training: if is_training and example.is_impossible:
tok_start_position = -1
tok_end_position = -1
if is_training and not example.is_impossible:
tok_start_position = orig_to_tok_index[example.start_position] tok_start_position = orig_to_tok_index[example.start_position]
if example.end_position < len(example.doc_tokens) - 1: if example.end_position < len(example.doc_tokens) - 1:
tok_end_position = orig_to_tok_index[example.end_position + 1] - 1 tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
...@@ -279,20 +300,25 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, ...@@ -279,20 +300,25 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
start_position = None start_position = None
end_position = None end_position = None
if is_training: if is_training and not example.is_impossible:
# For training, if our document chunk does not contain an annotation # For training, if our document chunk does not contain an annotation
# we throw it out, since there is nothing to predict. # we throw it out, since there is nothing to predict.
doc_start = doc_span.start doc_start = doc_span.start
doc_end = doc_span.start + doc_span.length - 1 doc_end = doc_span.start + doc_span.length - 1
if (example.start_position < doc_start or out_of_span = False
example.end_position < doc_start or if not (tok_start_position >= doc_start and
example.start_position > doc_end or example.end_position > doc_end): tok_end_position <= doc_end):
continue out_of_span = True
if out_of_span:
doc_offset = len(query_tokens) + 2 start_position = 0
start_position = tok_start_position - doc_start + doc_offset end_position = 0
end_position = tok_end_position - doc_start + doc_offset else:
doc_offset = len(query_tokens) + 2
start_position = tok_start_position - doc_start + doc_offset
end_position = tok_end_position - doc_start + doc_offset
if is_training and example.is_impossible:
start_position = 0
end_position = 0
if example_index < 20: if example_index < 20:
logger.info("*** Example ***") logger.info("*** Example ***")
logger.info("unique_id: %s" % (unique_id)) logger.info("unique_id: %s" % (unique_id))
...@@ -309,7 +335,9 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, ...@@ -309,7 +335,9 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
"input_mask: %s" % " ".join([str(x) for x in input_mask])) "input_mask: %s" % " ".join([str(x) for x in input_mask]))
logger.info( logger.info(
"segment_ids: %s" % " ".join([str(x) for x in segment_ids])) "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
if is_training: if is_training and example.is_impossible:
logger.info("impossible example")
if is_training and not example.is_impossible:
answer_text = " ".join(tokens[start_position:(end_position + 1)]) answer_text = " ".join(tokens[start_position:(end_position + 1)])
logger.info("start_position: %d" % (start_position)) logger.info("start_position: %d" % (start_position))
logger.info("end_position: %d" % (end_position)) logger.info("end_position: %d" % (end_position))
...@@ -328,7 +356,8 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, ...@@ -328,7 +356,8 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
input_mask=input_mask, input_mask=input_mask,
segment_ids=segment_ids, segment_ids=segment_ids,
start_position=start_position, start_position=start_position,
end_position=end_position)) end_position=end_position,
is_impossible=example.is_impossible))
unique_id += 1 unique_id += 1
return features return features
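A note on the `out_of_span` handling added earlier in this function: long documents are split into overlapping doc spans, and when the gold answer does not lie entirely inside the current span both positions are clamped to 0 (the `[CLS]` token), which is also where impossible examples point. A standalone sketch of that mapping with hypothetical token indices:

```python
import collections

DocSpan = collections.namedtuple("DocSpan", ["start", "length"])

def map_positions(doc_span, tok_start_position, tok_end_position, num_query_tokens):
    # Clamp to (0, 0) -- the [CLS] position -- when the answer is outside this window.
    doc_start = doc_span.start
    doc_end = doc_span.start + doc_span.length - 1
    out_of_span = not (tok_start_position >= doc_start and tok_end_position <= doc_end)
    if out_of_span:
        return 0, 0
    doc_offset = num_query_tokens + 2   # [CLS] + query tokens + [SEP] precede the doc tokens
    return (tok_start_position - doc_start + doc_offset,
            tok_end_position - doc_start + doc_offset)

# Hypothetical: a 384-token window starting at token 100, answer spanning tokens 130-134.
print(map_positions(DocSpan(start=100, length=384), 130, 134, num_query_tokens=12))  # (44, 48)
print(map_positions(DocSpan(start=500, length=384), 130, 134, num_query_tokens=12))  # (0, 0)
```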
...@@ -408,15 +437,15 @@ def _check_is_max_context(doc_spans, cur_span_index, position): ...@@ -408,15 +437,15 @@ def _check_is_max_context(doc_spans, cur_span_index, position):
return cur_span_index == best_span_index return cur_span_index == best_span_index
RawResult = collections.namedtuple("RawResult", RawResult = collections.namedtuple("RawResult",
["unique_id", "start_logits", "end_logits"]) ["unique_id", "start_logits", "end_logits"])
def write_predictions(all_examples, all_features, all_results, n_best_size, def write_predictions(all_examples, all_features, all_results, n_best_size,
max_answer_length, do_lower_case, output_prediction_file, max_answer_length, do_lower_case, output_prediction_file,
output_nbest_file, verbose_logging): output_nbest_file, output_null_log_odds_file, verbose_logging,
"""Write final predictions to the json file.""" version_2_with_negative, null_score_diff_threshold):
"""Write final predictions to the json file and log-odds of null if needed."""
logger.info("Writing predictions to: %s" % (output_prediction_file)) logger.info("Writing predictions to: %s" % (output_prediction_file))
logger.info("Writing nbest to: %s" % (output_nbest_file)) logger.info("Writing nbest to: %s" % (output_nbest_file))
...@@ -434,15 +463,29 @@ def write_predictions(all_examples, all_features, all_results, n_best_size, ...@@ -434,15 +463,29 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
all_predictions = collections.OrderedDict() all_predictions = collections.OrderedDict()
all_nbest_json = collections.OrderedDict() all_nbest_json = collections.OrderedDict()
scores_diff_json = collections.OrderedDict()
for (example_index, example) in enumerate(all_examples): for (example_index, example) in enumerate(all_examples):
features = example_index_to_features[example_index] features = example_index_to_features[example_index]
prelim_predictions = [] prelim_predictions = []
# keep track of the minimum score of null start+end of position 0
score_null = 1000000 # large and positive
min_null_feature_index = 0 # the paragraph slice with min mull score
null_start_logit = 0 # the start logit at the slice with min null score
null_end_logit = 0 # the end logit at the slice with min null score
for (feature_index, feature) in enumerate(features): for (feature_index, feature) in enumerate(features):
result = unique_id_to_result[feature.unique_id] result = unique_id_to_result[feature.unique_id]
start_indexes = _get_best_indexes(result.start_logits, n_best_size) start_indexes = _get_best_indexes(result.start_logits, n_best_size)
end_indexes = _get_best_indexes(result.end_logits, n_best_size) end_indexes = _get_best_indexes(result.end_logits, n_best_size)
# if we could have irrelevant answers, get the min score of irrelevant
if version_2_with_negative:
feature_null_score = result.start_logits[0] + result.end_logits[0]
if feature_null_score < score_null:
score_null = feature_null_score
min_null_feature_index = feature_index
null_start_logit = result.start_logits[0]
null_end_logit = result.end_logits[0]
for start_index in start_indexes: for start_index in start_indexes:
for end_index in end_indexes: for end_index in end_indexes:
# We could hypothetically create invalid predictions, e.g., predict # We could hypothetically create invalid predictions, e.g., predict
...@@ -470,7 +513,14 @@ def write_predictions(all_examples, all_features, all_results, n_best_size, ...@@ -470,7 +513,14 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
end_index=end_index, end_index=end_index,
start_logit=result.start_logits[start_index], start_logit=result.start_logits[start_index],
end_logit=result.end_logits[end_index])) end_logit=result.end_logits[end_index]))
if version_2_with_negative:
prelim_predictions.append(
_PrelimPrediction(
feature_index=min_null_feature_index,
start_index=0,
end_index=0,
start_logit=null_start_logit,
end_logit=null_end_logit))
prelim_predictions = sorted( prelim_predictions = sorted(
prelim_predictions, prelim_predictions,
key=lambda x: (x.start_logit + x.end_logit), key=lambda x: (x.start_logit + x.end_logit),
...@@ -485,33 +535,44 @@ def write_predictions(all_examples, all_features, all_results, n_best_size, ...@@ -485,33 +535,44 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
if len(nbest) >= n_best_size: if len(nbest) >= n_best_size:
break break
feature = features[pred.feature_index] feature = features[pred.feature_index]
if pred.start_index > 0: # this is a non-null prediction
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
orig_doc_start = feature.token_to_orig_map[pred.start_index]
orig_doc_end = feature.token_to_orig_map[pred.end_index]
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
tok_text = " ".join(tok_tokens)
# De-tokenize WordPieces that have been split off.
tok_text = tok_text.replace(" ##", "")
tok_text = tok_text.replace("##", "")
# Clean whitespace
tok_text = tok_text.strip()
tok_text = " ".join(tok_text.split())
orig_text = " ".join(orig_tokens)
final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
if final_text in seen_predictions:
continue
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)] seen_predictions[final_text] = True
orig_doc_start = feature.token_to_orig_map[pred.start_index] else:
orig_doc_end = feature.token_to_orig_map[pred.end_index] final_text = ""
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)] seen_predictions[final_text] = True
tok_text = " ".join(tok_tokens)
# De-tokenize WordPieces that have been split off.
tok_text = tok_text.replace(" ##", "")
tok_text = tok_text.replace("##", "")
# Clean whitespace
tok_text = tok_text.strip()
tok_text = " ".join(tok_text.split())
orig_text = " ".join(orig_tokens)
final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
if final_text in seen_predictions:
continue
seen_predictions[final_text] = True
nbest.append( nbest.append(
_NbestPrediction( _NbestPrediction(
text=final_text, text=final_text,
start_logit=pred.start_logit, start_logit=pred.start_logit,
end_logit=pred.end_logit)) end_logit=pred.end_logit))
# if we didn't include the empty option in the n-best, include it
if version_2_with_negative:
if "" not in seen_predictions:
nbest.append(
_NbestPrediction(
text="",
start_logit=null_start_logit,
end_logit=null_end_logit))
# In very rare edge cases we could have no valid predictions. So we # In very rare edge cases we could have no valid predictions. So we
# just create a nonce prediction in this case to avoid failure. # just create a nonce prediction in this case to avoid failure.
if not nbest: if not nbest:
...@@ -521,8 +582,12 @@ def write_predictions(all_examples, all_features, all_results, n_best_size, ...@@ -521,8 +582,12 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
assert len(nbest) >= 1 assert len(nbest) >= 1
total_scores = [] total_scores = []
best_non_null_entry = None
for entry in nbest: for entry in nbest:
total_scores.append(entry.start_logit + entry.end_logit) total_scores.append(entry.start_logit + entry.end_logit)
if not best_non_null_entry:
if entry.text:
best_non_null_entry = entry
probs = _compute_softmax(total_scores) probs = _compute_softmax(total_scores)
...@@ -537,8 +602,18 @@ def write_predictions(all_examples, all_features, all_results, n_best_size, ...@@ -537,8 +602,18 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
assert len(nbest_json) >= 1 assert len(nbest_json) >= 1
all_predictions[example.qas_id] = nbest_json[0]["text"] if not version_2_with_negative:
all_nbest_json[example.qas_id] = nbest_json all_predictions[example.qas_id] = nbest_json[0]["text"]
else:
# predict "" iff the null score - the score of best non-null > threshold
score_diff = score_null - best_non_null_entry.start_logit - (
best_non_null_entry.end_logit)
scores_diff_json[example.qas_id] = score_diff
if score_diff > null_score_diff_threshold:
all_predictions[example.qas_id] = ""
else:
all_predictions[example.qas_id] = best_non_null_entry.text
all_nbest_json[example.qas_id] = nbest_json
with open(output_prediction_file, "w") as writer: with open(output_prediction_file, "w") as writer:
writer.write(json.dumps(all_predictions, indent=4) + "\n") writer.write(json.dumps(all_predictions, indent=4) + "\n")
...@@ -546,6 +621,10 @@ def write_predictions(all_examples, all_features, all_results, n_best_size, ...@@ -546,6 +621,10 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
with open(output_nbest_file, "w") as writer: with open(output_nbest_file, "w") as writer:
writer.write(json.dumps(all_nbest_json, indent=4) + "\n") writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
if version_2_with_negative:
with open(output_null_log_odds_file, "w") as writer:
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
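The decision rule added here for SQuAD 2.0 is: predict the empty string iff `score_null - (best_non_null.start_logit + best_non_null.end_logit)` exceeds `--null_score_diff_threshold`, and always write that difference to `null_odds.json` so the threshold can be tuned afterwards. A compact restatement with hypothetical logits:

```python
def choose_prediction(score_null, best_start_logit, best_end_logit, best_text,
                      null_score_diff_threshold=0.0):
    # Predict "" iff the null score beats the best non-null span by more than the threshold.
    score_diff = score_null - best_start_logit - best_end_logit
    return ("" if score_diff > null_score_diff_threshold else best_text), score_diff

print(choose_prediction(-1.2, 2.5, 1.7, "Denver Broncos"))  # ('Denver Broncos', ~-5.4): span wins
print(choose_prediction(6.0, 1.0, 0.5, "Denver Broncos"))   # ('', 4.5): null wins
```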
def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
"""Project the tokenized prediction back to the original text.""" """Project the tokenized prediction back to the original text."""
...@@ -608,7 +687,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): ...@@ -608,7 +687,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
if len(orig_ns_text) != len(tok_ns_text): if len(orig_ns_text) != len(tok_ns_text):
if verbose_logging: if verbose_logging:
logger.info("Length not equal after stripping spaces: '%s' vs '%s'", logger.info("Length not equal after stripping spaces: '%s' vs '%s'",
orig_ns_text, tok_ns_text) orig_ns_text, tok_ns_text)
return orig_text return orig_text
# We then project the characters in `pred_text` back to `orig_text` using # We then project the characters in `pred_text` back to `orig_text` using
...@@ -677,11 +756,6 @@ def _compute_softmax(scores): ...@@ -677,11 +756,6 @@ def _compute_softmax(scores):
probs.append(score / total_sum) probs.append(score / total_sum)
return probs return probs
def warmup_linear(x, warmup=0.002):
if x < warmup:
return x/warmup
return 1.0 - x
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -713,7 +787,7 @@ def main(): ...@@ -713,7 +787,7 @@ def main():
parser.add_argument("--num_train_epochs", default=3.0, type=float, parser.add_argument("--num_train_epochs", default=3.0, type=float,
help="Total number of training epochs to perform.") help="Total number of training epochs to perform.")
parser.add_argument("--warmup_proportion", default=0.1, type=float, parser.add_argument("--warmup_proportion", default=0.1, type=float,
help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% " help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
"of training.") "of training.")
parser.add_argument("--n_best_size", default=20, type=int, parser.add_argument("--n_best_size", default=20, type=int,
help="The total number of n-best predictions to generate in the nbest_predictions.json " help="The total number of n-best predictions to generate in the nbest_predictions.json "
...@@ -750,7 +824,12 @@ def main(): ...@@ -750,7 +824,12 @@ def main():
help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
"0 (default value): dynamic loss scaling.\n" "0 (default value): dynamic loss scaling.\n"
"Positive power of 2: static loss scaling value.\n") "Positive power of 2: static loss scaling value.\n")
parser.add_argument('--version_2_with_negative',
action='store_true',
help='If true, the SQuAD examples contain some that do not have an answer.')
parser.add_argument('--null_score_diff_threshold',
type=float, default=0.0,
help="If null_score - best_non_null is greater than the threshold predict null.")
args = parser.parse_args() args = parser.parse_args()
if args.local_rank == -1 or args.no_cuda: if args.local_rank == -1 or args.no_cuda:
...@@ -769,7 +848,7 @@ def main(): ...@@ -769,7 +848,7 @@ def main():
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
args.gradient_accumulation_steps)) args.gradient_accumulation_steps))
args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps) args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
random.seed(args.seed) random.seed(args.seed)
np.random.seed(args.seed) np.random.seed(args.seed)
...@@ -789,7 +868,7 @@ def main(): ...@@ -789,7 +868,7 @@ def main():
raise ValueError( raise ValueError(
"If `do_predict` is True, then `predict_file` must be specified.") "If `do_predict` is True, then `predict_file` must be specified.")
if os.path.exists(args.output_dir) and os.listdir(args.output_dir): if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
raise ValueError("Output directory () already exists and is not empty.") raise ValueError("Output directory () already exists and is not empty.")
if not os.path.exists(args.output_dir): if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir) os.makedirs(args.output_dir)
...@@ -797,12 +876,14 @@ def main(): ...@@ -797,12 +876,14 @@ def main():
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
train_examples = None train_examples = None
num_train_steps = None num_train_optimization_steps = None
if args.do_train: if args.do_train:
train_examples = read_squad_examples( train_examples = read_squad_examples(
input_file=args.train_file, is_training=True) input_file=args.train_file, is_training=True, version_2_with_negative=args.version_2_with_negative)
num_train_steps = int( num_train_optimization_steps = int(
len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs) len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
if args.local_rank != -1:
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
# Prepare model # Prepare model
model = BertForQuestionAnswering.from_pretrained(args.bert_model, model = BertForQuestionAnswering.from_pretrained(args.bert_model,
...@@ -834,12 +915,9 @@ def main(): ...@@ -834,12 +915,9 @@ def main():
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
] ]
t_total = num_train_steps
if args.local_rank != -1:
t_total = t_total // torch.distributed.get_world_size()
if args.fp16: if args.fp16:
try: try:
from apex.optimizers import FP16_Optimizer from apex.optimizers import FP16_Optimizer
from apex.optimizers import FusedAdam from apex.optimizers import FusedAdam
except ImportError: except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
...@@ -856,7 +934,7 @@ def main(): ...@@ -856,7 +934,7 @@ def main():
optimizer = BertAdam(optimizer_grouped_parameters, optimizer = BertAdam(optimizer_grouped_parameters,
lr=args.learning_rate, lr=args.learning_rate,
warmup=args.warmup_proportion, warmup=args.warmup_proportion,
t_total=t_total) t_total=num_train_optimization_steps)
global_step = 0 global_step = 0
if args.do_train: if args.do_train:
...@@ -882,7 +960,7 @@ def main(): ...@@ -882,7 +960,7 @@ def main():
logger.info(" Num orig examples = %d", len(train_examples)) logger.info(" Num orig examples = %d", len(train_examples))
logger.info(" Num split examples = %d", len(train_features)) logger.info(" Num split examples = %d", len(train_features))
logger.info(" Batch size = %d", args.train_batch_size) logger.info(" Batch size = %d", args.train_batch_size)
logger.info(" Num steps = %d", num_train_steps) logger.info(" Num steps = %d", num_train_optimization_steps)
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long) all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long) all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long) all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
...@@ -913,10 +991,12 @@ def main(): ...@@ -913,10 +991,12 @@ def main():
else: else:
loss.backward() loss.backward()
if (step + 1) % args.gradient_accumulation_steps == 0: if (step + 1) % args.gradient_accumulation_steps == 0:
# modify learning rate with special warm up BERT uses if args.fp16:
lr_this_step = args.learning_rate * warmup_linear(global_step/t_total, args.warmup_proportion) # modify learning rate with special warm up BERT uses
for param_group in optimizer.param_groups: # if args.fp16 is False, BertAdam is used and handles this automatically
param_group['lr'] = lr_this_step lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion)
for param_group in optimizer.param_groups:
param_group['lr'] = lr_this_step
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
global_step += 1 global_step += 1
...@@ -924,16 +1004,19 @@ def main(): ...@@ -924,16 +1004,19 @@ def main():
# Save a trained model # Save a trained model
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin") output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
torch.save(model_to_save.state_dict(), output_model_file) if args.do_train:
torch.save(model_to_save.state_dict(), output_model_file)
# Load a trained model that you have fine-tuned
model_state_dict = torch.load(output_model_file)
model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict)
else:
model = BertForQuestionAnswering.from_pretrained(args.bert_model)
# Load a trained model that you have fine-tuned
model_state_dict = torch.load(output_model_file)
model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict)
model.to(device) model.to(device)
if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0): if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
eval_examples = read_squad_examples( eval_examples = read_squad_examples(
input_file=args.predict_file, is_training=False) input_file=args.predict_file, is_training=False, version_2_with_negative=args.version_2_with_negative)
eval_features = convert_examples_to_features( eval_features = convert_examples_to_features(
examples=eval_examples, examples=eval_examples,
tokenizer=tokenizer, tokenizer=tokenizer,
...@@ -977,10 +1060,12 @@ def main(): ...@@ -977,10 +1060,12 @@ def main():
end_logits=end_logits)) end_logits=end_logits))
output_prediction_file = os.path.join(args.output_dir, "predictions.json") output_prediction_file = os.path.join(args.output_dir, "predictions.json")
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json") output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
output_null_log_odds_file = os.path.join(args.output_dir, "null_odds.json")
write_predictions(eval_examples, eval_features, all_results, write_predictions(eval_examples, eval_features, all_results,
args.n_best_size, args.max_answer_length, args.n_best_size, args.max_answer_length,
args.do_lower_case, output_prediction_file, args.do_lower_case, output_prediction_file,
output_nbest_file, args.verbose_logging) output_nbest_file, output_null_log_odds_file, args.verbose_logging,
args.version_2_with_negative, args.null_score_diff_threshold)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -32,7 +32,7 @@ from tqdm import tqdm, trange ...@@ -32,7 +32,7 @@ from tqdm import tqdm, trange
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForMultipleChoice from pytorch_pretrained_bert.modeling import BertForMultipleChoice
from pytorch_pretrained_bert.optimization import BertAdam from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
from pytorch_pretrained_bert.tokenization import BertTokenizer from pytorch_pretrained_bert.tokenization import BertTokenizer
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
...@@ -240,11 +240,6 @@ def select_field(features, field): ...@@ -240,11 +240,6 @@ def select_field(features, field):
for feature in features for feature in features
] ]
def warmup_linear(x, warmup=0.002):
if x < warmup:
return x/warmup
return 1.0 - x
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -343,7 +338,7 @@ def main(): ...@@ -343,7 +338,7 @@ def main():
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
args.gradient_accumulation_steps)) args.gradient_accumulation_steps))
args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps) args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
random.seed(args.seed) random.seed(args.seed)
np.random.seed(args.seed) np.random.seed(args.seed)
...@@ -362,11 +357,13 @@ def main(): ...@@ -362,11 +357,13 @@ def main():
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
train_examples = None train_examples = None
num_train_steps = None num_train_optimization_steps = None
if args.do_train: if args.do_train:
train_examples = read_swag_examples(os.path.join(args.data_dir, 'train.csv'), is_training = True) train_examples = read_swag_examples(os.path.join(args.data_dir, 'train.csv'), is_training = True)
num_train_steps = int( num_train_optimization_steps = int(
len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs) len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
if args.local_rank != -1:
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
# Prepare model # Prepare model
model = BertForMultipleChoice.from_pretrained(args.bert_model, model = BertForMultipleChoice.from_pretrained(args.bert_model,
...@@ -397,9 +394,6 @@ def main(): ...@@ -397,9 +394,6 @@ def main():
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
] ]
t_total = num_train_steps
if args.local_rank != -1:
t_total = t_total // torch.distributed.get_world_size()
if args.fp16: if args.fp16:
try: try:
from apex.optimizers import FP16_Optimizer from apex.optimizers import FP16_Optimizer
...@@ -419,7 +413,7 @@ def main(): ...@@ -419,7 +413,7 @@ def main():
optimizer = BertAdam(optimizer_grouped_parameters, optimizer = BertAdam(optimizer_grouped_parameters,
lr=args.learning_rate, lr=args.learning_rate,
warmup=args.warmup_proportion, warmup=args.warmup_proportion,
t_total=t_total) t_total=num_train_optimization_steps)
global_step = 0 global_step = 0
if args.do_train: if args.do_train:
...@@ -428,7 +422,7 @@ def main(): ...@@ -428,7 +422,7 @@ def main():
logger.info("***** Running training *****") logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_examples)) logger.info(" Num examples = %d", len(train_examples))
logger.info(" Batch size = %d", args.train_batch_size) logger.info(" Batch size = %d", args.train_batch_size)
logger.info(" Num steps = %d", num_train_steps) logger.info(" Num steps = %d", num_train_optimization_steps)
all_input_ids = torch.tensor(select_field(train_features, 'input_ids'), dtype=torch.long) all_input_ids = torch.tensor(select_field(train_features, 'input_ids'), dtype=torch.long)
all_input_mask = torch.tensor(select_field(train_features, 'input_mask'), dtype=torch.long) all_input_mask = torch.tensor(select_field(train_features, 'input_mask'), dtype=torch.long)
all_segment_ids = torch.tensor(select_field(train_features, 'segment_ids'), dtype=torch.long) all_segment_ids = torch.tensor(select_field(train_features, 'segment_ids'), dtype=torch.long)
...@@ -465,10 +459,12 @@ def main(): ...@@ -465,10 +459,12 @@ def main():
else: else:
loss.backward() loss.backward()
if (step + 1) % args.gradient_accumulation_steps == 0: if (step + 1) % args.gradient_accumulation_steps == 0:
# modify learning rate with special warm up BERT uses if args.fp16:
lr_this_step = args.learning_rate * warmup_linear(global_step/t_total, args.warmup_proportion) # modify learning rate with special warm up BERT uses
for param_group in optimizer.param_groups: # if args.fp16 is False, BertAdam is used that handles this automatically
param_group['lr'] = lr_this_step lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion)
for param_group in optimizer.param_groups:
param_group['lr'] = lr_this_step
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
global_step += 1 global_step += 1
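The point of the change above is that the manual learning-rate update is now applied only under fp16: apex's `FP16_Optimizer` wraps a plain optimizer with no schedule of its own, whereas `BertAdam` applies the warmup/decay schedule internally from its `warmup` and `t_total` arguments, so adjusting `param_group['lr']` by hand in that case would double-apply the schedule. A self-contained sketch of the two paths with illustrative values:

```python
import torch
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear

# Dummy parameter standing in for the model; numbers are illustrative only.
param = torch.nn.Parameter(torch.zeros(10))
num_train_optimization_steps = 1000

# Non-fp16 path: BertAdam schedules the learning rate itself from warmup/t_total,
# so the training loop only calls optimizer.step().
optimizer = BertAdam([param], lr=5e-5, warmup=0.1, t_total=num_train_optimization_steps)

# fp16 path: the wrapped optimizer has no schedule, so the loop rescales the
# learning rate by hand at each accumulated step, as in the diff above.
global_step = 100
lr_this_step = 5e-5 * warmup_linear(global_step / num_train_optimization_steps, 0.1)
```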
......
...@@ -1067,7 +1067,7 @@ class BertForTokenClassification(BertPreTrainedModel): ...@@ -1067,7 +1067,7 @@ class BertForTokenClassification(BertPreTrainedModel):
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences. a batch has varying length sentences.
`labels`: labels for the classification output: torch.LongTensor of shape [batch_size] `labels`: labels for the classification output: torch.LongTensor of shape [batch_size, sequence_length]
with indices selected in [0, ..., num_labels]. with indices selected in [0, ..., num_labels].
Outputs: Outputs:
...@@ -1107,7 +1107,14 @@ class BertForTokenClassification(BertPreTrainedModel): ...@@ -1107,7 +1107,14 @@ class BertForTokenClassification(BertPreTrainedModel):
if labels is not None: if labels is not None:
loss_fct = CrossEntropyLoss() loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) # Only keep active parts of the loss
if attention_mask is not None:
active_loss = attention_mask.view(-1) == 1
active_logits = logits.view(-1, self.num_labels)[active_loss]
active_labels = labels.view(-1)[active_loss]
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
return loss return loss
else: else:
return logits return logits
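The new loss computation restricts the token-classification cross-entropy to real tokens: only positions where `attention_mask == 1` contribute, so padded positions no longer need meaningful labels. A minimal self-contained sketch of the masking; shapes and tensors are made up for illustration:

```python
import torch
from torch.nn import CrossEntropyLoss

# Illustrative shapes: batch of 2 sequences, length 4, 3 label classes.
num_labels = 3
logits = torch.randn(2, 4, num_labels)            # [batch_size, seq_len, num_labels]
labels = torch.randint(0, num_labels, (2, 4))     # [batch_size, seq_len]
attention_mask = torch.tensor([[1, 1, 1, 0],      # trailing positions are padding
                               [1, 1, 0, 0]])

loss_fct = CrossEntropyLoss()
# Keep only positions that are real tokens; padded positions are dropped from the loss.
active_loss = attention_mask.view(-1) == 1
active_logits = logits.view(-1, num_labels)[active_loss]
active_labels = labels.view(-1)[active_loss]
loss = loss_fct(active_logits, active_labels)
```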
......
...@@ -74,7 +74,8 @@ def whitespace_tokenize(text): ...@@ -74,7 +74,8 @@ def whitespace_tokenize(text):
class BertTokenizer(object): class BertTokenizer(object):
"""Runs end-to-end tokenization: punctuation splitting + wordpiece""" """Runs end-to-end tokenization: punctuation splitting + wordpiece"""
def __init__(self, vocab_file, do_lower_case=True, max_len=None): def __init__(self, vocab_file, do_lower_case=True, max_len=None,
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
if not os.path.isfile(vocab_file): if not os.path.isfile(vocab_file):
raise ValueError( raise ValueError(
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
...@@ -82,7 +83,8 @@ class BertTokenizer(object): ...@@ -82,7 +83,8 @@ class BertTokenizer(object):
self.vocab = load_vocab(vocab_file) self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict( self.ids_to_tokens = collections.OrderedDict(
[(ids, tok) for tok, ids in self.vocab.items()]) [(ids, tok) for tok, ids in self.vocab.items()])
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
never_split=never_split)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
self.max_len = max_len if max_len is not None else int(1e12) self.max_len = max_len if max_len is not None else int(1e12)
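With the new `never_split` argument (defaulting to the BERT special tokens), tokens such as `[CLS]`, `[SEP]` and `[MASK]` are protected from lower-casing and punctuation splitting even when `do_lower_case=True`. A hedged usage sketch; the exact wordpiece output depends on the downloaded vocabulary:

```python
from pytorch_pretrained_bert.tokenization import BertTokenizer

# 'bert-base-uncased' fetches the pretrained uncased vocabulary; the default
# never_split now keeps the bracketed special tokens intact.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
tokens = tokenizer.tokenize(
    "[CLS] Who was Jim Henson ? [SEP] Jim [MASK] was a puppeteer [SEP]")
# Expected: '[CLS]', '[SEP]' and '[MASK]' survive as single tokens instead of
# being split into '[', 'cls', ']' etc., while ordinary words are lower-cased.
print(tokens)
```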
...@@ -155,13 +157,16 @@ class BertTokenizer(object): ...@@ -155,13 +157,16 @@ class BertTokenizer(object):
class BasicTokenizer(object): class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.).""" """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def __init__(self, do_lower_case=True): def __init__(self,
do_lower_case=True,
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
"""Constructs a BasicTokenizer. """Constructs a BasicTokenizer.
Args: Args:
do_lower_case: Whether to lower case the input. do_lower_case: Whether to lower case the input.
""" """
self.do_lower_case = do_lower_case self.do_lower_case = do_lower_case
self.never_split = never_split
def tokenize(self, text): def tokenize(self, text):
"""Tokenizes a piece of text.""" """Tokenizes a piece of text."""
...@@ -176,7 +181,7 @@ class BasicTokenizer(object): ...@@ -176,7 +181,7 @@ class BasicTokenizer(object):
orig_tokens = whitespace_tokenize(text) orig_tokens = whitespace_tokenize(text)
split_tokens = [] split_tokens = []
for token in orig_tokens: for token in orig_tokens:
if self.do_lower_case: if self.do_lower_case and token not in self.never_split:
token = token.lower() token = token.lower()
token = self._run_strip_accents(token) token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token)) split_tokens.extend(self._run_split_on_punc(token))
...@@ -197,6 +202,8 @@ class BasicTokenizer(object): ...@@ -197,6 +202,8 @@ class BasicTokenizer(object):
def _run_split_on_punc(self, text): def _run_split_on_punc(self, text):
"""Splits punctuation on a piece of text.""" """Splits punctuation on a piece of text."""
if text in self.never_split:
return [text]
chars = list(text) chars = list(text)
i = 0 i = 0
start_new_word = True start_new_word = True
......
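At the `BasicTokenizer` level, the same guard appears twice: lower-casing is skipped for tokens in `never_split`, and `_run_split_on_punc` returns such tokens whole. A minimal sketch of the resulting behaviour:

```python
from pytorch_pretrained_bert.tokenization import BasicTokenizer

# The default never_split now covers the BERT special tokens.
basic = BasicTokenizer(do_lower_case=True)
print(basic.tokenize("[CLS] Hello World [MASK] [SEP]"))
# With the guard: ['[CLS]', 'hello', 'world', '[MASK]', '[SEP]']
# Before this change the special tokens were lower-cased and split on '[' and ']'.
```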