Unverified Commit ff276fc0, authored by Thomas Wolf, committed by GitHub

Merge branch 'master' into finish_torchhub_interfaces

parents 312fdd77 a64736dc
@@ -7,9 +7,11 @@ jobs:
     steps:
       - checkout
      - run: sudo pip install --progress-bar off .
-      - run: sudo pip install pytest ftfy spacy
+      - run: sudo pip install pytest codecov pytest-cov
+      - run: sudo pip install spacy ftfy==4.4.3
       - run: sudo python -m spacy download en
-      - run: python -m pytest -sv tests/ --runslow
+      - run: python -m pytest -sv tests/ --runslow --cov
+      - run: codecov
   build_py2:
     working_directory: ~/pytorch-pretrained-BERT
     docker:
@@ -17,10 +19,11 @@ jobs:
     steps:
       - checkout
       - run: sudo pip install --progress-bar off .
-      - run: sudo pip install pytest spacy
-      - run: sudo pip install ftfy==4.4.3
+      - run: sudo pip install pytest codecov pytest-cov
+      - run: sudo pip install spacy ftfy==4.4.3
       - run: sudo python -m spacy download en
-      - run: python -m pytest -sv tests/ --runslow
+      - run: python -m pytest -sv tests/ --runslow --cov
+      - run: codecov
 workflows:
   version: 2
   build_and_test:
...
+[run]
+source=pytorch_pretrained_bert
+[report]
+exclude_lines =
+    pragma: no cover
+    raise
+    except
+    register_parameter
\ No newline at end of file
@@ -1033,7 +1033,7 @@ An overview of the implemented schedules:
 |-|-|
 | [Training large models: introduction, tools and examples](#Training-large-models-introduction,-tools-and-examples) | How to use gradient-accumulation, multi-gpu training, distributed training, optimize on CPU and 16-bits training to train Bert models |
 | [Fine-tuning with BERT: running the examples](#Fine-tuning-with-BERT-running-the-examples) | Running the examples in [`./examples`](./examples/): `extract_classif.py`, `run_classifier.py`, `run_squad.py` and `run_lm_finetuning.py` |
-| [Fine-tuning with OpenAI GPT, Transformer-XL and GPT-2](#Fine-tuning-with-OpenAI-GPT-Transformer-XL-and-GPT-2) | Running the examples in [`./examples`](./examples/): `run_openai_gpt.py`, `run_transfo_xl.py` and `run_gpt2.py` |
+| [Fine-tuning with OpenAI GPT, Transformer-XL and GPT-2](#openai-gpt-transformer-xl-and-gpt-2-running-the-examples) | Running the examples in [`./examples`](./examples/): `run_openai_gpt.py`, `run_transfo_xl.py` and `run_gpt2.py` |
 | [Fine-tuning BERT-large on GPUs](#Fine-tuning-BERT-large-on-GPUs) | How to fine tune `BERT large`|

 ### Training large models: introduction, tools and examples
...
@@ -4,11 +4,11 @@ from tqdm import tqdm, trange
 from tempfile import TemporaryDirectory
 import shelve

-from random import random, randrange, randint, shuffle, choice, sample
+from random import random, randrange, randint, shuffle, choice
 from pytorch_pretrained_bert.tokenization import BertTokenizer
 import numpy as np
 import json
+import collections

 class DocumentDatabase:
     def __init__(self, reduce_memory=False):
@@ -98,42 +98,77 @@ def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
         else:
             trunc_tokens.pop()

+MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
+                                          ["index", "label"])

-def create_masked_lm_predictions(tokens, masked_lm_prob, max_predictions_per_seq, vocab_list):
+def create_masked_lm_predictions(tokens, masked_lm_prob, max_predictions_per_seq, whole_word_mask, vocab_list):
     """Creates the predictions for the masked LM objective. This is mostly copied from the Google BERT repo, but
     with several refactors to clean it up and remove a lot of unnecessary variables."""
     cand_indices = []
     for (i, token) in enumerate(tokens):
         if token == "[CLS]" or token == "[SEP]":
             continue
-        cand_indices.append(i)
+        # Whole Word Masking means that if we mask all of the wordpieces
+        # corresponding to an original word. When a word has been split into
+        # WordPieces, the first token does not have any marker and any subsequence
+        # tokens are prefixed with ##. So whenever we see the ## token, we
+        # append it to the previous set of word indexes.
+        #
+        # Note that Whole Word Masking does *not* change the training code
+        # at all -- we still predict each WordPiece independently, softmaxed
+        # over the entire vocabulary.
+        if (whole_word_mask and len(cand_indices) >= 1 and token.startswith("##")):
+            cand_indices[-1].append(i)
+        else:
+            cand_indices.append([i])

     num_to_mask = min(max_predictions_per_seq,
                       max(1, int(round(len(tokens) * masked_lm_prob))))
     shuffle(cand_indices)
-    mask_indices = sorted(sample(cand_indices, num_to_mask))
-    masked_token_labels = []
-    for index in mask_indices:
-        # 80% of the time, replace with [MASK]
-        if random() < 0.8:
-            masked_token = "[MASK]"
-        else:
-            # 10% of the time, keep original
-            if random() < 0.5:
-                masked_token = tokens[index]
-            # 10% of the time, replace with random word
-            else:
-                masked_token = choice(vocab_list)
-        masked_token_labels.append(tokens[index])
-        # Once we've saved the true label for that token, we can overwrite it with the masked version
-        tokens[index] = masked_token
+    masked_lms = []
+    covered_indexes = set()
+    for index_set in cand_indices:
+        if len(masked_lms) >= num_to_mask:
+            break
+        # If adding a whole-word mask would exceed the maximum number of
+        # predictions, then just skip this candidate.
+        if len(masked_lms) + len(index_set) > num_to_mask:
+            continue
+        is_any_index_covered = False
+        for index in index_set:
+            if index in covered_indexes:
+                is_any_index_covered = True
+                break
+        if is_any_index_covered:
+            continue
+        for index in index_set:
+            covered_indexes.add(index)
+
+            masked_token = None
+            # 80% of the time, replace with [MASK]
+            if random() < 0.8:
+                masked_token = "[MASK]"
+            else:
+                # 10% of the time, keep original
+                if random() < 0.5:
+                    masked_token = tokens[index]
+                # 10% of the time, replace with random word
+                else:
+                    masked_token = choice(vocab_list)
+            masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
+            tokens[index] = masked_token
+
+    assert len(masked_lms) <= num_to_mask
+    masked_lms = sorted(masked_lms, key=lambda x: x.index)
+    mask_indices = [p.index for p in masked_lms]
+    masked_token_labels = [p.label for p in masked_lms]

     return tokens, mask_indices, masked_token_labels

 def create_instances_from_document(
         doc_database, doc_idx, max_seq_length, short_seq_prob,
-        masked_lm_prob, max_predictions_per_seq, vocab_list):
+        masked_lm_prob, max_predictions_per_seq, whole_word_mask, vocab_list):
     """This code is mostly a duplicate of the equivalent function from Google BERT's repo.
     However, we make some changes and improvements. Sampling is improved and no longer requires a loop in this function.
     Also, documents are sampled proportionally to the number of sentences they contain, which means each sentence
@@ -213,7 +248,7 @@ def create_instances_from_document(
     segment_ids = [0 for _ in range(len(tokens_a) + 2)] + [1 for _ in range(len(tokens_b) + 1)]

     tokens, masked_lm_positions, masked_lm_labels = create_masked_lm_predictions(
-        tokens, masked_lm_prob, max_predictions_per_seq, vocab_list)
+        tokens, masked_lm_prob, max_predictions_per_seq, whole_word_mask, vocab_list)

     instance = {
         "tokens": tokens,
@@ -235,9 +270,10 @@ def main():
     parser.add_argument("--output_dir", type=Path, required=True)
     parser.add_argument("--bert_model", type=str, required=True,
                         choices=["bert-base-uncased", "bert-large-uncased", "bert-base-cased",
-                                 "bert-base-multilingual", "bert-base-chinese"])
+                                 "bert-base-multilingual-uncased", "bert-base-chinese", "bert-base-multilingual-cased"])
     parser.add_argument("--do_lower_case", action="store_true")
+    parser.add_argument("--do_whole_word_mask", action="store_true",
+                        help="Whether to use whole word masking rather than per-WordPiece masking.")

     parser.add_argument("--reduce_memory", action="store_true",
                         help="Reduce memory usage for large datasets by keeping data on disc rather than in memory")
@@ -284,7 +320,7 @@ def main():
                 doc_instances = create_instances_from_document(
                     docs, doc_idx, max_seq_length=args.max_seq_len, short_seq_prob=args.short_seq_prob,
                     masked_lm_prob=args.masked_lm_prob, max_predictions_per_seq=args.max_predictions_per_seq,
-                    vocab_list=vocab_list)
+                    whole_word_mask=args.do_whole_word_mask, vocab_list=vocab_list)
                 doc_instances = [json.dumps(instance) for instance in doc_instances]
                 for instance in doc_instances:
                     epoch_file.write(instance + '\n')
...
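The whole-word-masking change above only affects how masking candidates are grouped before sampling: WordPiece continuations (tokens starting with `##`) are folded into the preceding word, so a whole word is masked or left alone as a unit. A minimal standalone sketch of that grouping step (illustrative names, not code from the script):

```python
# Sketch of the candidate-grouping step added above; tokens and names are illustrative.
from random import shuffle

def whole_word_candidates(tokens):
    """Group token positions so a word split into WordPieces is treated as one masking unit."""
    cand_indices = []
    for i, token in enumerate(tokens):
        if token in ("[CLS]", "[SEP]"):
            continue
        if cand_indices and token.startswith("##"):
            cand_indices[-1].append(i)   # continuation piece: extend the previous word
        else:
            cand_indices.append([i])     # a new word starts a new candidate
    return cand_indices

tokens = ["[CLS]", "jim", "hen", "##son", "was", "a", "puppet", "##eer", "[SEP]"]
cands = whole_word_candidates(tokens)
print(cands)   # [[1], [2, 3], [4], [5], [6, 7]] -- "henson" and "puppeteer" mask as wholes
shuffle(cands) # the script then samples whole candidates until the masking budget is reached
```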
@@ -25,6 +25,7 @@ import random
 import sys

 import numpy as np
+import math
 import torch
 from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                               TensorDataset)
@@ -735,15 +736,6 @@ def main():
     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

-    train_examples = None
-    num_train_optimization_steps = None
-    if args.do_train:
-        train_examples = processor.get_train_examples(args.data_dir)
-        num_train_optimization_steps = int(
-            len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
-        if args.local_rank != -1:
-            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
-
     # Prepare model
     cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank))
     model = BertForSequenceClassification.from_pretrained(args.bert_model,
@@ -762,8 +754,35 @@ def main():
     elif n_gpu > 1:
         model = torch.nn.DataParallel(model)

-    # Prepare optimizer
     if args.do_train:
+
+        # Prepare data loader
+
+        train_examples = processor.get_train_examples(args.data_dir)
+        train_features = convert_examples_to_features(
+            train_examples, label_list, args.max_seq_length, tokenizer, output_mode)
+        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
+        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
+        all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
+
+        if output_mode == "classification":
+            all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
+        elif output_mode == "regression":
+            all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.float)
+
+        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
+        if args.local_rank == -1:
+            train_sampler = RandomSampler(train_data)
+        else:
+            train_sampler = DistributedSampler(train_data)
+        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
+
+        num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
+        if args.local_rank != -1:
+            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
+
+        # Prepare optimizer
+
         param_optimizer = list(model.named_parameters())
         no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
         optimizer_grouped_parameters = [
@@ -794,31 +813,14 @@ def main():
                                  warmup=args.warmup_proportion,
                                  t_total=num_train_optimization_steps)

         global_step = 0
         nb_tr_steps = 0
         tr_loss = 0
-    if args.do_train:
-        train_features = convert_examples_to_features(
-            train_examples, label_list, args.max_seq_length, tokenizer, output_mode)
         logger.info("***** Running training *****")
         logger.info("  Num examples = %d", len(train_examples))
         logger.info("  Batch size = %d", args.train_batch_size)
         logger.info("  Num steps = %d", num_train_optimization_steps)
-        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
-        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
-        all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
-        if output_mode == "classification":
-            all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
-        elif output_mode == "regression":
-            all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.float)
-        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
-        if args.local_rank == -1:
-            train_sampler = RandomSampler(train_data)
-        else:
-            train_sampler = DistributedSampler(train_data)
-        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
+
         model.train()
         for _ in trange(int(args.num_train_epochs), desc="Epoch"):
...
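The classifier refactor above computes the optimizer's `t_total` after the DataLoader exists, so the step count comes from `len(train_dataloader)` (which already rounds up for the final partial batch) rather than from `len(train_examples) / batch_size`. A rough sketch of that arithmetic with made-up numbers:

```python
# Illustrative sketch of the t_total computation used above; all numbers are made up.
import torch
from torch.utils.data import DataLoader, TensorDataset

batch_size, grad_accum_steps, num_epochs, world_size = 32, 2, 3, 1
dataset = TensorDataset(torch.zeros(1000, 8))          # 1000 toy examples
loader = DataLoader(dataset, batch_size=batch_size)

# len(loader) == ceil(1000 / 32) == 32, so the final partial batch is counted.
steps = len(loader) // grad_accum_steps * num_epochs    # 32 // 2 * 3 == 48
if world_size > 1:
    # under torch.distributed each process only sees 1/world_size of the batches
    steps = steps // world_size
print(steps)   # 48 optimizer updates
```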
@@ -190,7 +190,7 @@ def main():
         {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
         {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
         ]
-    num_train_optimization_steps = len(train_data) * args.num_train_epochs // args.train_batch_size
+    num_train_optimization_steps = len(train_dataloader) * args.num_train_epochs
     optimizer = OpenAIAdam(optimizer_grouped_parameters,
                            lr=args.learning_rate,
                            warmup=args.warmup_proportion,
...
@@ -617,7 +617,7 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
             all_predictions[example.qas_id] = ""
         else:
             all_predictions[example.qas_id] = best_non_null_entry.text
         all_nbest_json[example.qas_id] = nbest_json

     with open(output_prediction_file, "w") as writer:
         writer.write(json.dumps(all_predictions, indent=4) + "\n")
@@ -894,16 +894,6 @@ def main():
     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

-    train_examples = None
-    num_train_optimization_steps = None
-    if args.do_train:
-        train_examples = read_squad_examples(
-            input_file=args.train_file, is_training=True, version_2_with_negative=args.version_2_with_negative)
-        num_train_optimization_steps = int(
-            len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
-        if args.local_rank != -1:
-            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
-
     # Prepare model
     model = BertForQuestionAnswering.from_pretrained(args.bert_model,
                 cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank)))
@@ -921,8 +911,47 @@ def main():
     elif n_gpu > 1:
         model = torch.nn.DataParallel(model)

-    # Prepare optimizer
     if args.do_train:
+
+        # Prepare data loader
+
+        train_examples = read_squad_examples(
+            input_file=args.train_file, is_training=True, version_2_with_negative=args.version_2_with_negative)
+        cached_train_features_file = args.train_file+'_{0}_{1}_{2}_{3}'.format(
+            list(filter(None, args.bert_model.split('/'))).pop(), str(args.max_seq_length), str(args.doc_stride), str(args.max_query_length))
+        try:
+            with open(cached_train_features_file, "rb") as reader:
+                train_features = pickle.load(reader)
+        except:
+            train_features = convert_examples_to_features(
+                examples=train_examples,
+                tokenizer=tokenizer,
+                max_seq_length=args.max_seq_length,
+                doc_stride=args.doc_stride,
+                max_query_length=args.max_query_length,
+                is_training=True)
+            if args.local_rank == -1 or torch.distributed.get_rank() == 0:
+                logger.info("  Saving train features into cached file %s", cached_train_features_file)
+                with open(cached_train_features_file, "wb") as writer:
+                    pickle.dump(train_features, writer)
+
+        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
+        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
+        all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
+        all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
+        all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)
+        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
+                                   all_start_positions, all_end_positions)
+        if args.local_rank == -1:
+            train_sampler = RandomSampler(train_data)
+        else:
+            train_sampler = DistributedSampler(train_data)
+        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
+
+        num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
+        if args.local_rank != -1:
+            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
+
+        # Prepare optimizer
+
         param_optimizer = list(model.named_parameters())

         # hack to remove pooler, which is not used
@@ -958,43 +987,13 @@ def main():
                                  warmup=args.warmup_proportion,
                                  t_total=num_train_optimization_steps)

         global_step = 0
-    if args.do_train:
-        cached_train_features_file = args.train_file+'_{0}_{1}_{2}_{3}'.format(
-            list(filter(None, args.bert_model.split('/'))).pop(), str(args.max_seq_length), str(args.doc_stride), str(args.max_query_length))
-        train_features = None
-        try:
-            with open(cached_train_features_file, "rb") as reader:
-                train_features = pickle.load(reader)
-        except:
-            train_features = convert_examples_to_features(
-                examples=train_examples,
-                tokenizer=tokenizer,
-                max_seq_length=args.max_seq_length,
-                doc_stride=args.doc_stride,
-                max_query_length=args.max_query_length,
-                is_training=True)
-            if args.local_rank == -1 or torch.distributed.get_rank() == 0:
-                logger.info("  Saving train features into cached file %s", cached_train_features_file)
-                with open(cached_train_features_file, "wb") as writer:
-                    pickle.dump(train_features, writer)
         logger.info("***** Running training *****")
         logger.info("  Num orig examples = %d", len(train_examples))
         logger.info("  Num split examples = %d", len(train_features))
         logger.info("  Batch size = %d", args.train_batch_size)
         logger.info("  Num steps = %d", num_train_optimization_steps)
-        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
-        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
-        all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
-        all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
-        all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)
-        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
-                                   all_start_positions, all_end_positions)
-        if args.local_rank == -1:
-            train_sampler = RandomSampler(train_data)
-        else:
-            train_sampler = DistributedSampler(train_data)
-        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
+
         model.train()
         for _ in trange(int(args.num_train_epochs), desc="Epoch"):
...
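The SQuAD refactor above also moves the pickled feature cache into the `do_train` block: features are loaded from a file whose name encodes the settings that affect them, and rebuilt (and re-saved by the main process) when loading fails. A hedged sketch of that pattern with illustrative names:

```python
# Sketch of the load-or-rebuild caching pattern; paths and the builder are illustrative.
# (The script itself uses a bare `except:` and only writes the cache on rank 0.)
import pickle

def load_or_build_features(cache_path, build_fn):
    try:
        with open(cache_path, "rb") as reader:
            return pickle.load(reader)
    except (OSError, EOFError, pickle.UnpicklingError):
        features = build_fn()
        with open(cache_path, "wb") as writer:
            pickle.dump(features, writer)
        return features

# cache key built from the knobs that change the features, mirroring the script's format string
cache_path = "train-v1.1.json_{}_{}_{}_{}".format("bert-base-uncased", 384, 128, 64)
features = load_or_build_features(cache_path, lambda: ["feature"] * 3)
```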
@@ -358,15 +358,6 @@ def main():
     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

-    train_examples = None
-    num_train_optimization_steps = None
-    if args.do_train:
-        train_examples = read_swag_examples(os.path.join(args.data_dir, 'train.csv'), is_training = True)
-        num_train_optimization_steps = int(
-            len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
-        if args.local_rank != -1:
-            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
-
     # Prepare model
     model = BertForMultipleChoice.from_pretrained(args.bert_model,
         cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank)),
@@ -384,13 +375,35 @@ def main():
     elif n_gpu > 1:
         model = torch.nn.DataParallel(model)

-    # Prepare optimizer
     if args.do_train:
+
+        # Prepare data loader
+
+        train_examples = read_swag_examples(os.path.join(args.data_dir, 'train.csv'), is_training = True)
+        train_features = convert_examples_to_features(
+            train_examples, tokenizer, args.max_seq_length, True)
+        all_input_ids = torch.tensor(select_field(train_features, 'input_ids'), dtype=torch.long)
+        all_input_mask = torch.tensor(select_field(train_features, 'input_mask'), dtype=torch.long)
+        all_segment_ids = torch.tensor(select_field(train_features, 'segment_ids'), dtype=torch.long)
+        all_label = torch.tensor([f.label for f in train_features], dtype=torch.long)
+        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label)
+        if args.local_rank == -1:
+            train_sampler = RandomSampler(train_data)
+        else:
+            train_sampler = DistributedSampler(train_data)
+        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
+
+        num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
+        if args.local_rank != -1:
+            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
+
+        # Prepare optimizer
+
         param_optimizer = list(model.named_parameters())

         # hack to remove pooler, which is not used
         # thus it produce None grad that break apex
-        param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]
+        param_optimizer = [n for n in param_optimizer]

         no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
         optimizer_grouped_parameters = [
@@ -420,24 +433,12 @@ def main():
                              warmup=args.warmup_proportion,
                              t_total=num_train_optimization_steps)

         global_step = 0
-    if args.do_train:
-        train_features = convert_examples_to_features(
-            train_examples, tokenizer, args.max_seq_length, True)
         logger.info("***** Running training *****")
         logger.info("  Num examples = %d", len(train_examples))
         logger.info("  Batch size = %d", args.train_batch_size)
         logger.info("  Num steps = %d", num_train_optimization_steps)
-        all_input_ids = torch.tensor(select_field(train_features, 'input_ids'), dtype=torch.long)
-        all_input_mask = torch.tensor(select_field(train_features, 'input_mask'), dtype=torch.long)
-        all_segment_ids = torch.tensor(select_field(train_features, 'segment_ids'), dtype=torch.long)
-        all_label = torch.tensor([f.label for f in train_features], dtype=torch.long)
-        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label)
-        if args.local_rank == -1:
-            train_sampler = RandomSampler(train_data)
-        else:
-            train_sampler = DistributedSampler(train_data)
-        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
+
         model.train()
         for _ in trange(int(args.num_train_epochs), desc="Epoch"):
...
@@ -82,7 +82,7 @@ def bertTokenizer(*args, **kwargs):
     Example:
         >>> sentence = 'Hello, World!'
-        >>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False, force_reload=False)
+        >>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False)
         >>> toks = tokenizer.tokenize(sentence)
         ['Hello', '##,', 'World', '##!']
         >>> ids = tokenizer.convert_tokens_to_ids(toks)
@@ -101,19 +101,16 @@ def bertModel(*args, **kwargs):
     Example:
         # Load the tokenizer
-        >>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False, force_reload=False)
+        >>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False)
         # Prepare tokenized input
         >>> text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
         >>> tokenized_text = tokenizer.tokenize(text)
-        ['[CLS]', 'Who', 'was', 'Jim', 'He', '##nson', '?', '[SEP]', 'Jim', 'He', '##nson', 'was', 'a', 'puppet', '##eer', '[SEP]']
         >>> indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
         >>> segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
         >>> tokens_tensor = torch.tensor([indexed_tokens])
-        tensor([[101, 2627, 1108, 3104, 1124, 15703, 136, 102, 3104, 1124, 15703, 1108, 170, 16797, 8284, 102]])
         >>> segments_tensors = torch.tensor([segments_ids])
-        tensor([[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]])
         # Load bertModel
-        >>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertModel', 'bert-base-cased', force_reload=False)
+        >>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertModel', 'bert-base-cased')
         >>> model.eval()
         # Predict hidden states features for each layer
         >>> with torch.no_grad():
@@ -129,6 +126,23 @@ def bertForNextSentencePrediction(*args, **kwargs):
     BERT model with next sentence prediction head.
     This module comprises the BERT model followed by the next sentence
     classification head.
+
+    Example:
+        # Load the tokenizer
+        >>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False)
+        # Prepare tokenized input
+        >>> text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
+        >>> tokenized_text = tokenizer.tokenize(text)
+        >>> indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
+        >>> segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
+        >>> tokens_tensor = torch.tensor([indexed_tokens])
+        >>> segments_tensors = torch.tensor([segments_ids])
+        # Load bertForNextSentencePrediction
+        >>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertForNextSentencePrediction', 'bert-base-cased')
+        >>> model.eval()
+        # Predict the next sentence classification logits
+        >>> with torch.no_grad():
+                next_sent_classif_logits = model(tokens_tensor, segments_tensors)
     """
     model = BertForNextSentencePrediction.from_pretrained(*args, **kwargs)
     return model
@@ -141,6 +155,19 @@ def bertForPreTraining(*args, **kwargs):
     This module comprises the BERT model followed by the two pre-training heads
         - the masked language modeling head, and
         - the next sentence classification head.
+
+    Example:
+        # Load the tokenizer
+        >>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False)
+        # Prepare tokenized input
+        >>> text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
+        >>> tokenized_text = tokenizer.tokenize(text)
+        >>> segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
+        >>> tokens_tensor = torch.tensor([indexed_tokens])
+        >>> segments_tensors = torch.tensor([segments_ids])
+        # Load bertForPreTraining
+        >>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertForPreTraining', 'bert-base-cased')
+        >>> masked_lm_logits_scores, seq_relationship_logits = model(tokens_tensor, segments_tensors)
     """
     model = BertForPreTraining.from_pretrained(*args, **kwargs)
     return model
@@ -154,19 +181,18 @@ def bertForMaskedLM(*args, **kwargs):
     Example:
         # Load the tokenizer
-        >>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False, force_reload=False)
+        >>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False)
         # Prepare tokenized input
         >>> text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
         >>> tokenized_text = tokenizer.tokenize(text)
         >>> masked_index = 8
         >>> tokenized_text[masked_index] = '[MASK]'
-        ['[CLS]', 'who', 'was', 'jim', 'henson', '?', '[SEP]', 'jim', '[MASK]', 'was', 'a', 'puppet', '##eer', '[SEP]']
         >>> indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
         >>> segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
         >>> tokens_tensor = torch.tensor([indexed_tokens])
         >>> segments_tensors = torch.tensor([segments_ids])
         # Load bertForMaskedLM
-        >>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertForMaskedLM', 'bert-base-cased', force_reload=False)
+        >>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertForMaskedLM', 'bert-base-cased')
         >>> model.eval()
         # Predict all tokens
         >>> with torch.no_grad():
@@ -184,7 +210,8 @@ def bertForSequenceClassification(*args, **kwargs):
     """
     BertForSequenceClassification is a fine-tuning model that includes
     BertModel and a sequence-level (sequence or pair of sequences) classifier
-    on top of the BertModel.
+    on top of the BertModel. Note that the classification head is only initialized
+    and has to be trained.

     The sequence-level classifier is a linear layer that takes as input the
     last hidden state of the first character in the input sequence
@@ -194,7 +221,24 @@ def bertForSequenceClassification(*args, **kwargs):
         num_labels: the number (>=2) of classes for the classifier.

     Example:
-        >>> torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertForSequenceClassification', 'bert-base-cased', num_labels=2, force_reload=True)
+        # Load the tokenizer
+        >>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False)
+        # Prepare tokenized input
+        >>> text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
+        >>> tokenized_text = tokenizer.tokenize(text)
+        >>> indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
+        >>> segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
+        >>> tokens_tensor = torch.tensor([indexed_tokens])
+        >>> segments_tensors = torch.tensor([segments_ids])
+        # Load bertForSequenceClassification
+        >>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertForSequenceClassification', 'bert-base-cased', num_labels=2)
+        >>> model.eval()
+        # Predict the sequence classification logits
+        >>> with torch.no_grad():
+                seq_classif_logits = model(tokens_tensor, segments_tensors)
+        # Or get the sequence classification loss
+        >>> labels = torch.tensor([1])
+        >>> seq_classif_loss = model(tokens_tensor, segments_tensors, labels=labels) # set model.train() before if training this loss
     """
     model = BertForSequenceClassification.from_pretrained(*args, **kwargs)
     return model
@@ -204,13 +248,31 @@ def bertForSequenceClassification(*args, **kwargs):
 def bertForMultipleChoice(*args, **kwargs):
     """
     BertForMultipleChoice is a fine-tuning model that includes BertModel and a
-    linear layer on top of the BertModel.
+    linear layer on top of the BertModel. Note that the multiple choice head is
+    only initialized and has to be trained.

     Args:
         num_choices: the number (>=2) of classes for the classifier.

     Example:
-        >>> torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertForMultipleChoice', 'bert-base-cased', num_choices=2, force_reload=True)
+        # Load the tokenizer
+        >>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False)
+        # Prepare tokenized input
+        >>> text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
+        >>> tokenized_text = tokenizer.tokenize(text)
+        >>> indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
+        >>> segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
+        >>> tokens_tensor = torch.tensor([indexed_tokens, indexed_tokens]).unsqueeze(0)
+        >>> segments_tensors = torch.tensor([segments_ids, segments_ids]).unsqueeze(0)
+        # Load bertForMultipleChoice
+        >>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertForMultipleChoice', 'bert-base-cased', num_choices=2)
+        >>> model.eval()
+        # Predict the multiple choice logits
+        >>> with torch.no_grad():
+                multiple_choice_logits = model(tokens_tensor, segments_tensors)
+        # Or get the multiple choice loss
+        >>> labels = torch.tensor([1])
+        >>> multiple_choice_loss = model(tokens_tensor, segments_tensors, labels=labels) # set model.train() before if training this loss
     """
     model = BertForMultipleChoice.from_pretrained(*args, **kwargs)
     return model
@@ -221,7 +283,29 @@ def bertForQuestionAnswering(*args, **kwargs):
     """
     BertForQuestionAnswering is a fine-tuning model that includes BertModel
     with a token-level classifiers on top of the full sequence of last hidden
-    states.
+    states. Note that the classification head is only initialized
+    and has to be trained.
+
+    Example:
+        # Load the tokenizer
+        >>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False)
+        # Prepare tokenized input
+        >>> text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
+        >>> tokenized_text = tokenizer.tokenize(text)
+        >>> indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
+        >>> segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
+        >>> tokens_tensor = torch.tensor([indexed_tokens])
+        >>> segments_tensors = torch.tensor([segments_ids])
+        # Load bertForQuestionAnswering
+        >>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertForQuestionAnswering', 'bert-base-cased')
+        >>> model.eval()
+        # Predict the start and end positions logits
+        >>> with torch.no_grad():
+                start_logits, end_logits = model(tokens_tensor, segments_tensors)
+        # Or get the total loss which is the sum of the CrossEntropy loss for the start and end token positions
+        >>> start_positions, end_positions = torch.tensor([12]), torch.tensor([14])
+        # set model.train() before if training this loss
+        >>> multiple_choice_loss = model(tokens_tensor, segments_tensors, start_positions=start_positions, end_positions=end_positions)
     """
     model = BertForQuestionAnswering.from_pretrained(*args, **kwargs)
     return model
@@ -231,7 +315,8 @@ def bertForQuestionAnswering(*args, **kwargs):
 def bertForTokenClassification(*args, **kwargs):
     """
     BertForTokenClassification is a fine-tuning model that includes BertModel
-    and a token-level classifier on top of the BertModel.
+    and a token-level classifier on top of the BertModel. Note that the classification
+    head is only initialized and has to be trained.

     The token-level classifier is a linear layer that takes as input the last
     hidden state of the sequence.
@@ -240,7 +325,24 @@ def bertForTokenClassification(*args, **kwargs):
         num_labels: the number (>=2) of classes for the classifier.

     Example:
-        >>> torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertForTokenClassification', 'bert-base-cased', num_labels=2, force_reload=True)
+        # Load the tokenizer
+        >>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False)
+        # Prepare tokenized input
+        >>> text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
+        >>> tokenized_text = tokenizer.tokenize(text)
+        >>> indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
+        >>> segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
+        >>> tokens_tensor = torch.tensor([indexed_tokens])
+        >>> segments_tensors = torch.tensor([segments_ids])
+        # Load bertForTokenClassification
+        >>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertForTokenClassification', 'bert-base-cased', num_labels=2)
+        >>> model.eval()
+        # Predict the token classification logits
+        >>> with torch.no_grad():
+                classif_logits = model(tokens_tensor, segments_tensors)
+        # Or get the token classification loss
+        >>> labels = torch.tensor([[0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0]])
+        >>> classif_loss = model(tokens_tensor, segments_tensors, labels=labels) # set model.train() before if training this loss
     """
     model = BertForTokenClassification.from_pretrained(*args, **kwargs)
     return model
...
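One detail worth spelling out from the new `bertForMultipleChoice` example above: multiple-choice inputs are 3-D, one token sequence per choice, which is exactly what the `unsqueeze(0)` calls build. A tiny illustration with toy ids (not taken from hubconf.py):

```python
# Toy illustration of the (batch, num_choices, seq_len) input shape used above.
import torch

choice_0 = [101, 2627, 1108, 102]                     # made-up token ids
choice_1 = [101, 2627, 1110, 102]
tokens_tensor = torch.tensor([choice_0, choice_1]).unsqueeze(0)
print(tokens_tensor.shape)                            # torch.Size([1, 2, 4])
```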
@@ -278,12 +278,13 @@ class BertEmbeddings(nn.Module):
 class BertSelfAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertSelfAttention, self).__init__()
         if config.hidden_size % config.num_attention_heads != 0:
             raise ValueError(
                 "The hidden size (%d) is not a multiple of the number of attention "
                 "heads (%d)" % (config.hidden_size, config.num_attention_heads))
+        self.output_attentions = output_attentions
         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
         self.all_head_size = self.num_attention_heads * self.attention_head_size
@@ -325,6 +326,8 @@ class BertSelfAttention(nn.Module):
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         context_layer = context_layer.view(*new_context_layer_shape)
+        if self.output_attentions:
+            return attention_probs, context_layer
         return context_layer
@@ -343,14 +346,19 @@ class BertSelfOutput(nn.Module):
 class BertAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertAttention, self).__init__()
-        self.self = BertSelfAttention(config)
+        self.output_attentions = output_attentions
+        self.self = BertSelfAttention(config, output_attentions=output_attentions)
         self.output = BertSelfOutput(config)

     def forward(self, input_tensor, attention_mask):
         self_output = self.self(input_tensor, attention_mask)
+        if self.output_attentions:
+            attentions, self_output = self_output
         attention_output = self.output(self_output, input_tensor)
+        if self.output_attentions:
+            return attentions, attention_output
         return attention_output
@@ -384,33 +392,45 @@ class BertOutput(nn.Module):
 class BertLayer(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertLayer, self).__init__()
-        self.attention = BertAttention(config)
+        self.output_attentions = output_attentions
+        self.attention = BertAttention(config, output_attentions=output_attentions)
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)

     def forward(self, hidden_states, attention_mask):
         attention_output = self.attention(hidden_states, attention_mask)
+        if self.output_attentions:
+            attentions, attention_output = attention_output
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)
+        if self.output_attentions:
+            return attentions, layer_output
         return layer_output


 class BertEncoder(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertEncoder, self).__init__()
-        layer = BertLayer(config)
+        self.output_attentions = output_attentions
+        layer = BertLayer(config, output_attentions=output_attentions)
         self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])

     def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
         all_encoder_layers = []
+        all_attentions = []
         for layer_module in self.layer:
             hidden_states = layer_module(hidden_states, attention_mask)
+            if self.output_attentions:
+                attentions, hidden_states = hidden_states
+                all_attentions.append(attentions)
             if output_all_encoded_layers:
                 all_encoder_layers.append(hidden_states)
         if not output_all_encoded_layers:
             all_encoder_layers.append(hidden_states)
+        if self.output_attentions:
+            return all_attentions, all_encoder_layers
         return all_encoder_layers
@@ -702,10 +722,11 @@ class BertModel(BertPreTrainedModel):
     all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertModel, self).__init__(config)
+        self.output_attentions = output_attentions
         self.embeddings = BertEmbeddings(config)
-        self.encoder = BertEncoder(config)
+        self.encoder = BertEncoder(config, output_attentions=output_attentions)
         self.pooler = BertPooler(config)
         self.apply(self.init_bert_weights)
@@ -734,10 +755,14 @@ class BertModel(BertPreTrainedModel):
         encoded_layers = self.encoder(embedding_output,
                                       extended_attention_mask,
                                       output_all_encoded_layers=output_all_encoded_layers)
+        if self.output_attentions:
+            all_attentions, encoded_layers = encoded_layers
         sequence_output = encoded_layers[-1]
         pooled_output = self.pooler(sequence_output)
         if not output_all_encoded_layers:
             encoded_layers = encoded_layers[-1]
+        if self.output_attentions:
+            return all_attentions, encoded_layers, pooled_output
         return encoded_layers, pooled_output
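Taken together, the changes above thread an `output_attentions` flag from `BertSelfAttention` up through `BertEncoder` to `BertModel`; when it is set, the forward pass prepends a list of per-layer attention probabilities to the usual return values. A hedged usage sketch (assuming `from_pretrained` forwards extra keyword arguments to the constructor, as it does elsewhere in this library):

```python
# Usage sketch for the output_attentions flag added above; not part of modeling.py.
import torch
from pytorch_pretrained_bert import BertModel

model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
model.eval()
input_ids = torch.tensor([[101, 7592, 2088, 102]])    # toy token ids
with torch.no_grad():
    all_attentions, sequence_output, pooled_output = model(
        input_ids, output_all_encoded_layers=False)
# one tensor per layer, each shaped (batch, num_heads, seq_len, seq_len)
print(len(all_attentions), all_attentions[0].shape)
```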
@@ -791,15 +816,20 @@ class BertForPreTraining(BertPreTrainedModel):
     masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertForPreTraining, self).__init__(config)
-        self.bert = BertModel(config)
+        self.output_attentions = output_attentions
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
         self.apply(self.init_bert_weights)

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None):
-        sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
+        outputs = self.bert(input_ids, token_type_ids, attention_mask,
                             output_all_encoded_layers=False)
+        if self.output_attentions:
+            all_attentions, sequence_output, pooled_output = outputs
+        else:
+            sequence_output, pooled_output = outputs
         prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

         if masked_lm_labels is not None and next_sentence_label is not None:
@@ -808,8 +838,9 @@ class BertForPreTraining(BertPreTrainedModel):
             next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
             total_loss = masked_lm_loss + next_sentence_loss
             return total_loss
-        else:
-            return prediction_scores, seq_relationship_score
+        elif self.output_attentions:
+            return all_attentions, prediction_scores, seq_relationship_score
+        return prediction_scores, seq_relationship_score


 class BertForMaskedLM(BertPreTrainedModel):
@@ -854,23 +885,29 @@ class BertForMaskedLM(BertPreTrainedModel):
     masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertForMaskedLM, self).__init__(config)
-        self.bert = BertModel(config)
+        self.output_attentions = output_attentions
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
         self.apply(self.init_bert_weights)

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None):
-        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask,
+        outputs = self.bert(input_ids, token_type_ids, attention_mask,
                             output_all_encoded_layers=False)
+        if self.output_attentions:
+            all_attentions, sequence_output, _ = outputs
+        else:
+            sequence_output, _ = outputs
         prediction_scores = self.cls(sequence_output)

         if masked_lm_labels is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
             return masked_lm_loss
-        else:
-            return prediction_scores
+        elif self.output_attentions:
+            return all_attentions, prediction_scores
+        return prediction_scores


 class BertForNextSentencePrediction(BertPreTrainedModel):
@@ -916,23 +953,29 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
     seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config):
+    def __init__(self, config, output_attentions=False):
         super(BertForNextSentencePrediction, self).__init__(config)
-        self.bert = BertModel(config)
+        self.output_attentions = output_attentions
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.cls = BertOnlyNSPHead(config)
         self.apply(self.init_bert_weights)

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None):
-        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
+        outputs = self.bert(input_ids, token_type_ids, attention_mask,
                             output_all_encoded_layers=False)
-        seq_relationship_score = self.cls( pooled_output)
+        if self.output_attentions:
+            all_attentions, _, pooled_output = outputs
+        else:
+            _, pooled_output = outputs
+        seq_relationship_score = self.cls(pooled_output)

         if next_sentence_label is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
             return next_sentence_loss
-        else:
-            return seq_relationship_score
+        elif self.output_attentions:
+            return all_attentions, seq_relationship_score
+        return seq_relationship_score


 class BertForSequenceClassification(BertPreTrainedModel):
@@ -980,16 +1023,21 @@ class BertForSequenceClassification(BertPreTrainedModel):
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_labels=2):
+    def __init__(self, config, num_labels=2, output_attentions=False):
         super(BertForSequenceClassification, self).__init__(config)
+        self.output_attentions = output_attentions
         self.num_labels = num_labels
-        self.bert = BertModel(config)
+        self.bert = BertModel(config, output_attentions=output_attentions)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.classifier = nn.Linear(config.hidden_size, num_labels)
         self.apply(self.init_bert_weights)

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
-        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+        outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
if self.output_attentions:
all_attentions, _, pooled_output = outputs
else:
_, pooled_output = outputs
pooled_output = self.dropout(pooled_output) pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output) logits = self.classifier(pooled_output)
...@@ -997,8 +1045,9 @@ class BertForSequenceClassification(BertPreTrainedModel): ...@@ -997,8 +1045,9 @@ class BertForSequenceClassification(BertPreTrainedModel):
loss_fct = CrossEntropyLoss() loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
return loss return loss
else: elif self.output_attentions:
return logits return all_attentions, logits
return logits
class BertForMultipleChoice(BertPreTrainedModel): class BertForMultipleChoice(BertPreTrainedModel):
...@@ -1045,10 +1094,11 @@ class BertForMultipleChoice(BertPreTrainedModel): ...@@ -1045,10 +1094,11 @@ class BertForMultipleChoice(BertPreTrainedModel):
logits = model(input_ids, token_type_ids, input_mask) logits = model(input_ids, token_type_ids, input_mask)
``` ```
""" """
def __init__(self, config, num_choices=2): def __init__(self, config, num_choices=2, output_attentions=False):
super(BertForMultipleChoice, self).__init__(config) super(BertForMultipleChoice, self).__init__(config)
self.output_attentions = output_attentions
self.num_choices = num_choices self.num_choices = num_choices
self.bert = BertModel(config) self.bert = BertModel(config, output_attentions=output_attentions)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, 1) self.classifier = nn.Linear(config.hidden_size, 1)
self.apply(self.init_bert_weights) self.apply(self.init_bert_weights)
...@@ -1057,7 +1107,11 @@ class BertForMultipleChoice(BertPreTrainedModel): ...@@ -1057,7 +1107,11 @@ class BertForMultipleChoice(BertPreTrainedModel):
flat_input_ids = input_ids.view(-1, input_ids.size(-1)) flat_input_ids = input_ids.view(-1, input_ids.size(-1))
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
_, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False) outputs = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)
if self.output_attentions:
all_attentions, _, pooled_output = outputs
else:
_, pooled_output = outputs
pooled_output = self.dropout(pooled_output) pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output) logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, self.num_choices) reshaped_logits = logits.view(-1, self.num_choices)
...@@ -1066,8 +1120,9 @@ class BertForMultipleChoice(BertPreTrainedModel): ...@@ -1066,8 +1120,9 @@ class BertForMultipleChoice(BertPreTrainedModel):
loss_fct = CrossEntropyLoss() loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels) loss = loss_fct(reshaped_logits, labels)
return loss return loss
else: elif self.output_attentions:
return reshaped_logits return all_attentions, reshaped_logits
return reshaped_logits
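The multiple-choice head keeps its 3-D input convention (`batch_size, num_choices, seq_length`) and flattens internally before calling BERT; with `output_attentions=True` the attention maps come back ahead of the logits. A short sketch with illustrative values:
```python
import torch
from pytorch_pretrained_bert import BertConfig, BertForMultipleChoice

config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                    num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

model = BertForMultipleChoice(config, num_choices=2, output_attentions=True)
model.eval()

# shape (batch_size=1, num_choices=2, seq_length=3); forward flattens to (2, 3) internally
input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]]])
token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]]])

all_attentions, logits = model(input_ids, token_type_ids)               # logits has shape (1, 2)
loss = model(input_ids, token_type_ids, labels=torch.LongTensor([0]))   # scalar loss when labels are given
```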
class BertForTokenClassification(BertPreTrainedModel): class BertForTokenClassification(BertPreTrainedModel):
...@@ -1115,16 +1170,21 @@ class BertForTokenClassification(BertPreTrainedModel): ...@@ -1115,16 +1170,21 @@ class BertForTokenClassification(BertPreTrainedModel):
logits = model(input_ids, token_type_ids, input_mask) logits = model(input_ids, token_type_ids, input_mask)
``` ```
""" """
def __init__(self, config, num_labels=2): def __init__(self, config, num_labels=2, output_attentions=False):
super(BertForTokenClassification, self).__init__(config) super(BertForTokenClassification, self).__init__(config)
self.output_attentions = output_attentions
self.num_labels = num_labels self.num_labels = num_labels
self.bert = BertModel(config) self.bert = BertModel(config, output_attentions=output_attentions)
self.dropout = nn.Dropout(config.hidden_dropout_prob) self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, num_labels) self.classifier = nn.Linear(config.hidden_size, num_labels)
self.apply(self.init_bert_weights) self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
if self.output_attentions:
all_attentions, sequence_output, _ = outputs
else:
sequence_output, _ = outputs
sequence_output = self.dropout(sequence_output) sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output) logits = self.classifier(sequence_output)
...@@ -1139,8 +1199,9 @@ class BertForTokenClassification(BertPreTrainedModel): ...@@ -1139,8 +1199,9 @@ class BertForTokenClassification(BertPreTrainedModel):
else: else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
return loss return loss
else: elif self.output_attentions:
return logits return all_attentions, logits
return logits
class BertForQuestionAnswering(BertPreTrainedModel): class BertForQuestionAnswering(BertPreTrainedModel):
...@@ -1190,16 +1251,19 @@ class BertForQuestionAnswering(BertPreTrainedModel): ...@@ -1190,16 +1251,19 @@ class BertForQuestionAnswering(BertPreTrainedModel):
start_logits, end_logits = model(input_ids, token_type_ids, input_mask) start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
``` ```
""" """
def __init__(self, config): def __init__(self, config, output_attentions=False):
super(BertForQuestionAnswering, self).__init__(config) super(BertForQuestionAnswering, self).__init__(config)
self.bert = BertModel(config) self.output_attentions = output_attentions
# TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version self.bert = BertModel(config, output_attentions=output_attentions)
# self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.qa_outputs = nn.Linear(config.hidden_size, 2) self.qa_outputs = nn.Linear(config.hidden_size, 2)
self.apply(self.init_bert_weights) self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None): def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None):
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
if self.output_attentions:
all_attentions, sequence_output, _ = outputs
else:
sequence_output, _ = outputs
logits = self.qa_outputs(sequence_output) logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1) start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1)
...@@ -1221,5 +1285,6 @@ class BertForQuestionAnswering(BertPreTrainedModel): ...@@ -1221,5 +1285,6 @@ class BertForQuestionAnswering(BertPreTrainedModel):
end_loss = loss_fct(end_logits, end_positions) end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2 total_loss = (start_loss + end_loss) / 2
return total_loss return total_loss
else: elif self.output_attentions:
return start_logits, end_logits return all_attentions, start_logits, end_logits
return start_logits, end_logits
...@@ -39,8 +39,10 @@ from .modeling import BertLayerNorm as LayerNorm ...@@ -39,8 +39,10 @@ from .modeling import BertLayerNorm as LayerNorm
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin"} PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json"} "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin"}
PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json"}
def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path): def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path):
""" Load tf checkpoints in a pytorch model """ Load tf checkpoints in a pytorch model
...@@ -107,18 +109,24 @@ class GPT2Config(object): ...@@ -107,18 +109,24 @@ class GPT2Config(object):
def __init__( def __init__(
self, self,
vocab_size_or_config_json_file=50257, vocab_size_or_config_json_file=50257,
n_special=0,
n_positions=1024, n_positions=1024,
n_ctx=1024, n_ctx=1024,
n_embd=768, n_embd=768,
n_layer=12, n_layer=12,
n_head=12, n_head=12,
resid_pdrop=0.1,
embd_pdrop=0.1,
attn_pdrop=0.1,
layer_norm_epsilon=1e-5, layer_norm_epsilon=1e-5,
initializer_range=0.02, initializer_range=0.02,
predict_special_tokens=True
): ):
"""Constructs GPT2Config. """Constructs GPT2Config.
Args: Args:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file. vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file.
n_special: The number of special tokens to learn during fine-tuning ('[SEP]', '[CLS]', ...)
n_positions: Number of positional embeddings. n_positions: Number of positional embeddings.
n_ctx: Size of the causal mask (usually same as n_positions). n_ctx: Size of the causal mask (usually same as n_positions).
n_embd: Dimensionality of the embeddings and hidden states. n_embd: Dimensionality of the embeddings and hidden states.
...@@ -126,8 +134,14 @@ class GPT2Config(object): ...@@ -126,8 +134,14 @@ class GPT2Config(object):
n_head: Number of attention heads for each attention layer in n_head: Number of attention heads for each attention layer in
the Transformer encoder. the Transformer encoder.
layer_norm_epsilon: epsilon to use in the layer norm layers layer_norm_epsilon: epsilon to use in the layer norm layers
resid_pdrop: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
attn_pdrop: The dropout ratio for the attention
probabilities.
embd_pdrop: The dropout ratio for the embeddings.
initializer_range: The stddev of the truncated_normal_initializer for initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices. initializing all weight matrices.
predict_special_tokens: whether to predict special tokens (when the model has an LM head)
""" """
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)): and isinstance(vocab_size_or_config_json_file, unicode)):
...@@ -137,19 +151,28 @@ class GPT2Config(object): ...@@ -137,19 +151,28 @@ class GPT2Config(object):
self.__dict__[key] = value self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int): elif isinstance(vocab_size_or_config_json_file, int):
self.vocab_size = vocab_size_or_config_json_file self.vocab_size = vocab_size_or_config_json_file
self.n_special = n_special
self.n_ctx = n_ctx self.n_ctx = n_ctx
self.n_positions = n_positions self.n_positions = n_positions
self.n_embd = n_embd self.n_embd = n_embd
self.n_layer = n_layer self.n_layer = n_layer
self.n_head = n_head self.n_head = n_head
self.resid_pdrop = resid_pdrop
self.embd_pdrop = embd_pdrop
self.attn_pdrop = attn_pdrop
self.layer_norm_epsilon = layer_norm_epsilon self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range self.initializer_range = initializer_range
self.predict_special_tokens = predict_special_tokens
else: else:
raise ValueError( raise ValueError(
"First argument must be either a vocabulary size (int)" "First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)" "or the path to a pretrained model config file (str)"
) )
@property
def total_tokens_embeddings(self):
return self.vocab_size + self.n_special
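The GPT-2 configuration now carries the same special-token, dropout and `predict_special_tokens` fields as the OpenAI GPT one, plus the derived `total_tokens_embeddings` property. A minimal sketch with illustrative values:
```python
from pytorch_pretrained_bert.modeling_gpt2 import GPT2Config

config = GPT2Config(vocab_size_or_config_json_file=50257,
                    n_special=2,                 # e.g. two fine-tuning tokens such as [SEP]/[CLS]
                    n_positions=1024, n_ctx=1024,
                    n_embd=768, n_layer=12, n_head=12,
                    resid_pdrop=0.1, embd_pdrop=0.1, attn_pdrop=0.1,
                    predict_special_tokens=True)

# the embedding matrix holds the word rows plus the special-token rows
assert config.total_tokens_embeddings == 50257 + 2
```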
@classmethod @classmethod
def from_dict(cls, json_object): def from_dict(cls, json_object):
"""Constructs a `GPT2Config` from a Python dictionary of parameters.""" """Constructs a `GPT2Config` from a Python dictionary of parameters."""
...@@ -200,7 +223,7 @@ class Conv1D(nn.Module): ...@@ -200,7 +223,7 @@ class Conv1D(nn.Module):
class Attention(nn.Module): class Attention(nn.Module):
def __init__(self, nx, n_ctx, config, scale=False): def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False):
super(Attention, self).__init__() super(Attention, self).__init__()
n_state = nx # in Attention: n_state=768 (nx=n_embd) n_state = nx # in Attention: n_state=768 (nx=n_embd)
# [switch nx => n_state from Block to Attention to keep identical to TF implem] # [switch nx => n_state from Block to Attention to keep identical to TF implem]
...@@ -209,8 +232,11 @@ class Attention(nn.Module): ...@@ -209,8 +232,11 @@ class Attention(nn.Module):
self.n_head = config.n_head self.n_head = config.n_head
self.split_size = n_state self.split_size = n_state
self.scale = scale self.scale = scale
self.output_attentions = output_attentions
self.c_attn = Conv1D(n_state * 3, nx) self.c_attn = Conv1D(n_state * 3, nx)
self.c_proj = Conv1D(n_state, nx) self.c_proj = Conv1D(n_state, nx)
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
def _attn(self, q, k, v): def _attn(self, q, k, v):
w = torch.matmul(q, k) w = torch.matmul(q, k)
...@@ -221,6 +247,9 @@ class Attention(nn.Module): ...@@ -221,6 +247,9 @@ class Attention(nn.Module):
w = w * b - 1e4 * (1 - b) w = w * b - 1e4 * (1 - b)
w = nn.Softmax(dim=-1)(w) w = nn.Softmax(dim=-1)(w)
w = self.attn_dropout(w)
if self.output_attentions:
return w, torch.matmul(w, v)
return torch.matmul(w, v) return torch.matmul(w, v)
def merge_heads(self, x): def merge_heads(self, x):
...@@ -248,8 +277,13 @@ class Attention(nn.Module): ...@@ -248,8 +277,13 @@ class Attention(nn.Module):
value = torch.cat((past_value, value), dim=-2) value = torch.cat((past_value, value), dim=-2)
present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
a = self._attn(query, key, value) a = self._attn(query, key, value)
if self.output_attentions:
attentions, a = a
a = self.merge_heads(a) a = self.merge_heads(a)
a = self.c_proj(a) a = self.c_proj(a)
a = self.resid_dropout(a)
if self.output_attentions:
return attentions, a, present
return a, present return a, present
...@@ -260,27 +294,35 @@ class MLP(nn.Module): ...@@ -260,27 +294,35 @@ class MLP(nn.Module):
self.c_fc = Conv1D(n_state, nx) self.c_fc = Conv1D(n_state, nx)
self.c_proj = Conv1D(nx, n_state) self.c_proj = Conv1D(nx, n_state)
self.act = gelu self.act = gelu
self.dropout = nn.Dropout(config.resid_pdrop)
def forward(self, x): def forward(self, x):
h = self.act(self.c_fc(x)) h = self.act(self.c_fc(x))
h2 = self.c_proj(h) h2 = self.c_proj(h)
return h2 return self.dropout(h2)
class Block(nn.Module): class Block(nn.Module):
def __init__(self, n_ctx, config, scale=False): def __init__(self, n_ctx, config, scale=False, output_attentions=False):
super(Block, self).__init__() super(Block, self).__init__()
nx = config.n_embd nx = config.n_embd
self.output_attentions = output_attentions
self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon) self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
self.attn = Attention(nx, n_ctx, config, scale) self.attn = Attention(nx, n_ctx, config, scale, output_attentions)
self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon) self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
self.mlp = MLP(4 * nx, config) self.mlp = MLP(4 * nx, config)
def forward(self, x, layer_past=None): def forward(self, x, layer_past=None):
a, present = self.attn(self.ln_1(x), layer_past=layer_past) output_attn = self.attn(self.ln_1(x), layer_past=layer_past)
if self.output_attentions:
attentions, a, present = output_attn
else:
a, present = output_attn
x = x + a x = x + a
m = self.mlp(self.ln_2(x)) m = self.mlp(self.ln_2(x))
x = x + m x = x + m
if self.output_attentions:
return attentions, x, present
return x, present return x, present
...@@ -290,17 +332,20 @@ class GPT2LMHead(nn.Module): ...@@ -290,17 +332,20 @@ class GPT2LMHead(nn.Module):
def __init__(self, model_embeddings_weights, config): def __init__(self, model_embeddings_weights, config):
super(GPT2LMHead, self).__init__() super(GPT2LMHead, self).__init__()
self.n_embd = config.n_embd self.n_embd = config.n_embd
self.set_embeddings_weights(model_embeddings_weights) self.vocab_size = config.vocab_size
self.predict_special_tokens = config.predict_special_tokens
def set_embeddings_weights(self, model_embeddings_weights):
embed_shape = model_embeddings_weights.shape embed_shape = model_embeddings_weights.shape
self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False) self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
self.set_embeddings_weights(model_embeddings_weights)
def set_embeddings_weights(self, model_embeddings_weights, predict_special_tokens=True):
self.predict_special_tokens = predict_special_tokens
self.decoder.weight = model_embeddings_weights # Tied weights self.decoder.weight = model_embeddings_weights # Tied weights
def forward(self, hidden_state): def forward(self, hidden_state):
# Truncated Language modeling logits (we remove the last token)
# h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
lm_logits = self.decoder(hidden_state) lm_logits = self.decoder(hidden_state)
if not self.predict_special_tokens:
lm_logits = lm_logits[..., :self.vocab_size]
return lm_logits return lm_logits
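When `predict_special_tokens` is turned off, the tied LM head slices its logits back to the plain word vocabulary so the special rows added for fine-tuning are never scored. A rough sketch with a hypothetical, freshly initialized embedding (shapes in the comments assume the values below):
```python
import torch
from pytorch_pretrained_bert.modeling_gpt2 import GPT2Config, GPT2LMHead

config = GPT2Config(vocab_size_or_config_json_file=50257, n_special=2)
embeddings = torch.nn.Embedding(config.total_tokens_embeddings, config.n_embd)  # 50259 x 768

head = GPT2LMHead(embeddings.weight, config)
hidden = torch.rand(1, 5, config.n_embd)
print(head(hidden).shape)      # torch.Size([1, 5, 50259]): words + special tokens

head.set_embeddings_weights(embeddings.weight, predict_special_tokens=False)
print(head(hidden).shape)      # torch.Size([1, 5, 50257]): special-token logits sliced off
```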
...@@ -310,6 +355,7 @@ class GPT2MultipleChoiceHead(nn.Module): ...@@ -310,6 +355,7 @@ class GPT2MultipleChoiceHead(nn.Module):
def __init__(self, config): def __init__(self, config):
super(GPT2MultipleChoiceHead, self).__init__() super(GPT2MultipleChoiceHead, self).__init__()
self.n_embd = config.n_embd self.n_embd = config.n_embd
self.dropout = nn.Dropout2d(config.resid_pdrop) # To reproduce the noise_shape parameter of TF implementation
self.linear = nn.Linear(config.n_embd, 1) self.linear = nn.Linear(config.n_embd, 1)
nn.init.normal_(self.linear.weight, std=0.02) nn.init.normal_(self.linear.weight, std=0.02)
...@@ -323,6 +369,7 @@ class GPT2MultipleChoiceHead(nn.Module): ...@@ -323,6 +369,7 @@ class GPT2MultipleChoiceHead(nn.Module):
# (bsz, num_choices, 1, hidden_size) # (bsz, num_choices, 1, hidden_size)
multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2) multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)
# (bsz, num_choices, hidden_size) # (bsz, num_choices, hidden_size)
multiple_choice_h = self.dropout(multiple_choice_h.transpose(1, 2)).transpose(1, 2)
multiple_choice_logits = self.linear(multiple_choice_h).squeeze(-1) multiple_choice_logits = self.linear(multiple_choice_h).squeeze(-1)
# (bsz, num_choices) # (bsz, num_choices)
return multiple_choice_logits return multiple_choice_logits
...@@ -345,9 +392,6 @@ class GPT2PreTrainedModel(nn.Module): ...@@ -345,9 +392,6 @@ class GPT2PreTrainedModel(nn.Module):
) )
self.config = config self.config = config
def set_tied(self):
pass
def init_weights(self, module): def init_weights(self, module):
""" Initialize the weights. """ Initialize the weights.
""" """
...@@ -480,14 +524,32 @@ class GPT2PreTrainedModel(nn.Module): ...@@ -480,14 +524,32 @@ class GPT2PreTrainedModel(nn.Module):
"Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)) "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
) )
# Make sure we are still sharing the output and input embeddings after loading weights # Add additional embeddings for special tokens if needed
model.set_tied() # This step also makes sure we are still sharing the output and input embeddings after loading weights
model.set_num_special_tokens(num_special_tokens if num_special_tokens is not None else config.n_special)
return model return model
class GPT2Model(GPT2PreTrainedModel): class GPT2Model(GPT2PreTrainedModel):
"""OpenAI GPT-2 model ("Language Models are Unsupervised Multitask Learners"). """OpenAI GPT-2 model ("Language Models are Unsupervised Multitask Learners").
GPT-2 uses a single embedding matrix to store the word and special token embeddings.
Special token embeddings are additional embeddings that are not pre-trained: [SEP], [CLS]...
Special tokens need to be trained during fine-tuning if you use them.
The number of special embeddings can be controlled using the `set_num_special_tokens(num_special_tokens)` function.
The embeddings are ordered as follows in the token embedding matrix:
[0, ----------------------
... -> word embeddings
config.vocab_size - 1, ______________________
config.vocab_size,
... -> special embeddings
config.vocab_size + config.n_special - 1] ______________________
where total_tokens_embeddings can be obtained as config.total_tokens_embeddings and is:
total_tokens_embeddings = config.vocab_size + config.n_special
You should use the associated indices to index the embeddings.
Params: Params:
config: a GPT2Config class instance with the configuration to build a new model config: a GPT2Config class instance with the configuration to build a new model
...@@ -524,16 +586,32 @@ class GPT2Model(GPT2PreTrainedModel): ...@@ -524,16 +586,32 @@ class GPT2Model(GPT2PreTrainedModel):
``` ```
""" """
def __init__(self, config): def __init__(self, config, output_attentions=False):
super(GPT2Model, self).__init__(config) super(GPT2Model, self).__init__(config)
self.wte = nn.Embedding(config.vocab_size, config.n_embd) self.output_attentions = output_attentions
self.wte = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
self.wpe = nn.Embedding(config.n_positions, config.n_embd) self.wpe = nn.Embedding(config.n_positions, config.n_embd)
block = Block(config.n_ctx, config, scale=True) self.drop = nn.Dropout(config.embd_pdrop)
block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions)
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)]) self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.apply(self.init_weights) self.apply(self.init_weights)
def set_num_special_tokens(self, num_special_tokens):
" Update input embeddings with new embedding matrice if needed "
if self.config.n_special == num_special_tokens:
return
# Update config
self.config.n_special = num_special_tokens
# Build new embeddings and initialize all new embeddings (in particular the special tokens)
old_embed = self.wte
self.wte = nn.Embedding(self.config.total_tokens_embeddings, self.config.n_embd)
self.wte.to(old_embed.weight.device)
self.init_weights(self.wte)
# Copy word embeddings from the previous weights
self.wte.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :]
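A brief sketch of how the resize is meant to be called after building (or loading) a model: the word rows are copied over and only the new special rows start from fresh random weights.
```python
from pytorch_pretrained_bert.modeling_gpt2 import GPT2Config, GPT2Model

config = GPT2Config(vocab_size_or_config_json_file=50257)   # n_special defaults to 0
model = GPT2Model(config)

model.set_num_special_tokens(2)   # reserve two rows for fine-tuning tokens
assert model.config.n_special == 2
assert model.wte.weight.size(0) == config.total_tokens_embeddings   # 50257 + 2
```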
def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None): def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None):
if past is None: if past is None:
past_length = 0 past_length = 0
...@@ -556,12 +634,21 @@ class GPT2Model(GPT2PreTrainedModel): ...@@ -556,12 +634,21 @@ class GPT2Model(GPT2PreTrainedModel):
else: else:
token_type_embeds = 0 token_type_embeds = 0
hidden_states = inputs_embeds + position_embeds + token_type_embeds hidden_states = inputs_embeds + position_embeds + token_type_embeds
hidden_states = self.drop(hidden_states)
presents = [] presents = []
all_attentions = []
for block, layer_past in zip(self.h, past): for block, layer_past in zip(self.h, past):
hidden_states, present = block(hidden_states, layer_past) if self.output_attentions:
attentions, hidden_states, present = block(hidden_states, layer_past)
all_attentions.append(attentions)
else:
hidden_states, present = block(hidden_states, layer_past)
presents.append(present) presents.append(present)
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),) output_shape = input_shape + (hidden_states.size(-1),)
if self.output_attentions:
return all_attentions, hidden_states.view(*output_shape), presents
return hidden_states.view(*output_shape), presents return hidden_states.view(*output_shape), presents
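With `output_attentions=True` the base model prepends the per-layer attention weights to the usual hidden states and `presents` cache. A minimal sketch with a randomly initialized model and illustrative token ids:
```python
import torch
from pytorch_pretrained_bert.modeling_gpt2 import GPT2Config, GPT2Model

config = GPT2Config(vocab_size_or_config_json_file=50257)
model = GPT2Model(config, output_attentions=True)
model.eval()   # the newly added embedding/attention dropouts are disabled in eval mode

input_ids = torch.LongTensor([[464, 3290, 318, 845]])
all_attentions, hidden_states, presents = model(input_ids)

assert len(all_attentions) == config.n_layer   # one attention tensor per block
assert len(presents) == config.n_layer         # cached keys/values for incremental decoding
```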
...@@ -609,30 +696,38 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): ...@@ -609,30 +696,38 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
``` ```
""" """
def __init__(self, config): def __init__(self, config, output_attentions=False):
super(GPT2LMHeadModel, self).__init__(config) super(GPT2LMHeadModel, self).__init__(config)
self.transformer = GPT2Model(config) self.transformer = GPT2Model(config, output_attentions=output_attentions)
self.lm_head = GPT2LMHead(self.transformer.wte.weight, config) self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
self.apply(self.init_weights) self.apply(self.init_weights)
def set_tied(self): def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
""" Make sure we are sharing the embeddings """ Update input and output embeddings with new embedding matrix
Make sure we are sharing the embeddings
""" """
self.lm_head.set_embeddings_weights(self.transformer.wte.weight) self.config.predict_special_tokens = self.transformer.config.predict_special_tokens = predict_special_tokens
self.transformer.set_num_special_tokens(num_special_tokens)
self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None): def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past) transformer_output = self.transformer(input_ids, position_ids, token_type_ids, past)
if self.transformer.output_attentions:
all_attentions, hidden_states, presents = transformer_output
else:
hidden_states, presents = transformer_output
lm_logits = self.lm_head(hidden_states) lm_logits = self.lm_head(hidden_states)
if lm_labels is not None: if lm_labels is not None:
# Shift so that tokens < n predict n # Shift so that tokens < n predict n
shift_logits = lm_logits[:, :-1].contiguous() shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = lm_labels[:, 1:].contiguous() shift_labels = lm_labels[..., 1:].contiguous()
# Flatten the tokens # Flatten the tokens
loss_fct = CrossEntropyLoss(ignore_index=-1) loss_fct = CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1)) shift_labels.view(-1))
return loss return loss
if self.transformer.output_attentions:
return all_attentions, lm_logits, presents
return lm_logits, presents return lm_logits, presents
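The label shift now uses ellipsis indexing (`[..., :-1, :]`) so it also works when the inputs carry an extra choices dimension. A hedged sketch of the two return modes with a randomly initialized model:
```python
import torch
from pytorch_pretrained_bert.modeling_gpt2 import GPT2Config, GPT2LMHeadModel

config = GPT2Config(vocab_size_or_config_json_file=50257)
model = GPT2LMHeadModel(config)
model.eval()

input_ids = torch.LongTensor([[464, 3290, 318, 845, 13]])

# training-style call: the labels are the inputs themselves, shifted internally by one position
loss = model(input_ids, lm_labels=input_ids)

# inference-style call: logits for every position plus the cached key/values
lm_logits, presents = model(input_ids)
next_token = torch.argmax(lm_logits[0, -1])   # greedy choice of the next token id
```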
...@@ -685,32 +780,40 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): ...@@ -685,32 +780,40 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
``` ```
""" """
def __init__(self, config): def __init__(self, config, output_attentions=False):
super(GPT2DoubleHeadsModel, self).__init__(config) super(GPT2DoubleHeadsModel, self).__init__(config)
self.transformer = GPT2Model(config) self.transformer = GPT2Model(config, output_attentions=output_attentions)
self.lm_head = GPT2LMHead(self.transformer.wte.weight, config) self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
self.multiple_choice_head = GPT2MultipleChoiceHead(config) self.multiple_choice_head = GPT2MultipleChoiceHead(config)
self.apply(self.init_weights) self.apply(self.init_weights)
def set_tied(self): def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
""" Make sure we are sharing the embeddings """ Update input and output embeddings with new embedding matrix
Make sure we are sharing the embeddings
""" """
self.lm_head.set_embeddings_weights(self.transformer.wte.weight) self.config.predict_special_tokens = self.transformer.config.predict_special_tokens = predict_special_tokens
self.transformer.set_num_special_tokens(num_special_tokens)
self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)
def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, past=None): def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None, past=None):
hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past) transformer_output = self.transformer(input_ids, position_ids, token_type_ids, past)
if self.transformer.output_attentions:
all_attentions, hidden_states, presents = transformer_output
else:
hidden_states, presents = transformer_output
lm_logits = self.lm_head(hidden_states) lm_logits = self.lm_head(hidden_states)
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids) mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
losses = [] losses = []
if lm_labels is not None: if lm_labels is not None:
shift_logits = lm_logits[:, :-1].contiguous() shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = lm_labels[:, 1:].contiguous() shift_labels = lm_labels[..., 1:].contiguous()
loss_fct = CrossEntropyLoss(ignore_index=-1) loss_fct = CrossEntropyLoss(ignore_index=-1)
losses.append(loss_fct(shift_logits.view(-1, losses.append(loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)))
shift_logits.size(-1)), shift_labels.view(-1)))
if mc_labels is not None: if mc_labels is not None:
loss_fct = CrossEntropyLoss() loss_fct = CrossEntropyLoss()
losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))) losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
if losses: if losses:
return losses return losses
if self.transformer.output_attentions:
return all_attentions, lm_logits, mc_logits, presents
return lm_logits, mc_logits, presents return lm_logits, mc_logits, presents
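The double-heads model follows the same pattern: with attentions enabled they are simply prepended to the previous `(lm_logits, mc_logits, presents)` tuple. A sketch with illustrative token ids and classification-token positions:
```python
import torch
from pytorch_pretrained_bert.modeling_gpt2 import GPT2Config, GPT2DoubleHeadsModel

config = GPT2Config(vocab_size_or_config_json_file=50257)
model = GPT2DoubleHeadsModel(config, output_attentions=True)
model.eval()

# shape (batch_size=1, num_choices=2, seq_length=3)
input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]]])
mc_token_ids = torch.LongTensor([[2, 2]])   # position of the classification token in each choice

all_attentions, lm_logits, mc_logits, presents = model(input_ids, mc_token_ids)
```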
...@@ -143,6 +143,7 @@ class OpenAIGPTConfig(object): ...@@ -143,6 +143,7 @@ class OpenAIGPTConfig(object):
attn_pdrop=0.1, attn_pdrop=0.1,
layer_norm_epsilon=1e-5, layer_norm_epsilon=1e-5,
initializer_range=0.02, initializer_range=0.02,
predict_special_tokens=True
): ):
"""Constructs OpenAIGPTConfig. """Constructs OpenAIGPTConfig.
...@@ -165,6 +166,7 @@ class OpenAIGPTConfig(object): ...@@ -165,6 +166,7 @@ class OpenAIGPTConfig(object):
layer_norm_epsilon: epsilon to use in the layer norm layers layer_norm_epsilon: epsilon to use in the layer norm layers
initializer_range: The stddev of the truncated_normal_initializer for initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices. initializing all weight matrices.
predict_special_tokens: whether to predict special tokens (when the model has an LM head)
""" """
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)): and isinstance(vocab_size_or_config_json_file, unicode)):
...@@ -186,6 +188,7 @@ class OpenAIGPTConfig(object): ...@@ -186,6 +188,7 @@ class OpenAIGPTConfig(object):
self.attn_pdrop = attn_pdrop self.attn_pdrop = attn_pdrop
self.layer_norm_epsilon = layer_norm_epsilon self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range self.initializer_range = initializer_range
self.predict_special_tokens = predict_special_tokens
else: else:
raise ValueError( raise ValueError(
"First argument must be either a vocabulary size (int)" "First argument must be either a vocabulary size (int)"
...@@ -253,7 +256,7 @@ class Conv1D(nn.Module): ...@@ -253,7 +256,7 @@ class Conv1D(nn.Module):
class Attention(nn.Module): class Attention(nn.Module):
def __init__(self, nx, n_ctx, config, scale=False): def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False):
super(Attention, self).__init__() super(Attention, self).__init__()
n_state = nx # in Attention: n_state=768 (nx=n_embd) n_state = nx # in Attention: n_state=768 (nx=n_embd)
# [switch nx => n_state from Block to Attention to keep identical to TF implem] # [switch nx => n_state from Block to Attention to keep identical to TF implem]
...@@ -262,6 +265,7 @@ class Attention(nn.Module): ...@@ -262,6 +265,7 @@ class Attention(nn.Module):
self.n_head = config.n_head self.n_head = config.n_head
self.split_size = n_state self.split_size = n_state
self.scale = scale self.scale = scale
self.output_attentions = output_attentions
self.c_attn = Conv1D(n_state * 3, 1, nx) self.c_attn = Conv1D(n_state * 3, 1, nx)
self.c_proj = Conv1D(n_state, 1, nx) self.c_proj = Conv1D(n_state, 1, nx)
self.attn_dropout = nn.Dropout(config.attn_pdrop) self.attn_dropout = nn.Dropout(config.attn_pdrop)
...@@ -278,6 +282,8 @@ class Attention(nn.Module): ...@@ -278,6 +282,8 @@ class Attention(nn.Module):
w = nn.Softmax(dim=-1)(w) w = nn.Softmax(dim=-1)(w)
w = self.attn_dropout(w) w = self.attn_dropout(w)
if self.output_attentions:
return w, torch.matmul(w, v)
return torch.matmul(w, v) return torch.matmul(w, v)
def merge_heads(self, x): def merge_heads(self, x):
...@@ -300,9 +306,13 @@ class Attention(nn.Module): ...@@ -300,9 +306,13 @@ class Attention(nn.Module):
key = self.split_heads(key, k=True) key = self.split_heads(key, k=True)
value = self.split_heads(value) value = self.split_heads(value)
a = self._attn(query, key, value) a = self._attn(query, key, value)
if self.output_attentions:
attentions, a = a
a = self.merge_heads(a) a = self.merge_heads(a)
a = self.c_proj(a) a = self.c_proj(a)
a = self.resid_dropout(a) a = self.resid_dropout(a)
if self.output_attentions:
return attentions, a
return a return a
...@@ -322,19 +332,24 @@ class MLP(nn.Module): ...@@ -322,19 +332,24 @@ class MLP(nn.Module):
class Block(nn.Module): class Block(nn.Module):
def __init__(self, n_ctx, config, scale=False): def __init__(self, n_ctx, config, scale=False, output_attentions=False):
super(Block, self).__init__() super(Block, self).__init__()
nx = config.n_embd nx = config.n_embd
self.attn = Attention(nx, n_ctx, config, scale) self.output_attentions = output_attentions
self.attn = Attention(nx, n_ctx, config, scale, output_attentions)
self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon) self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
self.mlp = MLP(4 * nx, config) self.mlp = MLP(4 * nx, config)
self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon) self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
def forward(self, x): def forward(self, x):
a = self.attn(x) a = self.attn(x)
if self.output_attentions:
attentions, a = a
n = self.ln_1(x + a) n = self.ln_1(x + a)
m = self.mlp(n) m = self.mlp(n)
h = self.ln_2(n + m) h = self.ln_2(n + m)
if self.output_attentions:
return attentions, h
return h return h
...@@ -344,17 +359,21 @@ class OpenAIGPTLMHead(nn.Module): ...@@ -344,17 +359,21 @@ class OpenAIGPTLMHead(nn.Module):
def __init__(self, model_embeddings_weights, config): def __init__(self, model_embeddings_weights, config):
super(OpenAIGPTLMHead, self).__init__() super(OpenAIGPTLMHead, self).__init__()
self.n_embd = config.n_embd self.n_embd = config.n_embd
self.vocab_size = config.vocab_size
self.predict_special_tokens = config.predict_special_tokens
embed_shape = model_embeddings_weights.shape
self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
self.set_embeddings_weights(model_embeddings_weights) self.set_embeddings_weights(model_embeddings_weights)
def set_embeddings_weights(self, model_embeddings_weights): def set_embeddings_weights(self, model_embeddings_weights, predict_special_tokens=True):
self.predict_special_tokens = predict_special_tokens
embed_shape = model_embeddings_weights.shape embed_shape = model_embeddings_weights.shape
self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
self.decoder.weight = model_embeddings_weights # Tied weights self.decoder.weight = model_embeddings_weights # Tied weights
def forward(self, hidden_state): def forward(self, hidden_state):
# Truncated Language modeling logits (we remove the last token)
# h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
lm_logits = self.decoder(hidden_state) lm_logits = self.decoder(hidden_state)
if not self.predict_special_tokens:
lm_logits = lm_logits[..., :self.vocab_size]
return lm_logits return lm_logits
...@@ -364,7 +383,6 @@ class OpenAIGPTMultipleChoiceHead(nn.Module): ...@@ -364,7 +383,6 @@ class OpenAIGPTMultipleChoiceHead(nn.Module):
def __init__(self, config): def __init__(self, config):
super(OpenAIGPTMultipleChoiceHead, self).__init__() super(OpenAIGPTMultipleChoiceHead, self).__init__()
self.n_embd = config.n_embd self.n_embd = config.n_embd
# self.multiple_choice_token = multiple_choice_token
self.dropout = nn.Dropout2d(config.resid_pdrop) # To reproduce the noise_shape parameter of TF implementation self.dropout = nn.Dropout2d(config.resid_pdrop) # To reproduce the noise_shape parameter of TF implementation
self.linear = nn.Linear(config.n_embd, 1) self.linear = nn.Linear(config.n_embd, 1)
...@@ -415,9 +433,6 @@ class OpenAIGPTPreTrainedModel(nn.Module): ...@@ -415,9 +433,6 @@ class OpenAIGPTPreTrainedModel(nn.Module):
if isinstance(module, nn.Linear) and module.bias is not None: if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_() module.bias.data.zero_()
def set_num_special_tokens(self, num_special_tokens):
pass
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, num_special_tokens=None, *inputs, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path, num_special_tokens=None, *inputs, **kwargs):
""" """
...@@ -594,17 +609,16 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): ...@@ -594,17 +609,16 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
``` ```
""" """
def __init__(self, config): def __init__(self, config, output_attentions=False):
super(OpenAIGPTModel, self).__init__(config) super(OpenAIGPTModel, self).__init__(config)
num_tokens = config.vocab_size + config.n_special self.output_attentions = output_attentions
self.tokens_embed = nn.Embedding(num_tokens, config.n_embd) self.tokens_embed = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
self.positions_embed = nn.Embedding(config.n_positions, config.n_embd) self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
self.drop = nn.Dropout(config.embd_pdrop) self.drop = nn.Dropout(config.embd_pdrop)
block = Block(config.n_ctx, config, scale=True) block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions)
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)]) self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
self.apply(self.init_weights) self.apply(self.init_weights)
# nn.init.normal_(self.embed.weight, std=0.02)
def set_num_special_tokens(self, num_special_tokens): def set_num_special_tokens(self, num_special_tokens):
" Update input embeddings with new embedding matrix if needed " " Update input embeddings with new embedding matrix if needed "
...@@ -640,12 +654,19 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): ...@@ -640,12 +654,19 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
token_type_embeds = self.tokens_embed(token_type_ids) token_type_embeds = self.tokens_embed(token_type_ids)
else: else:
token_type_embeds = 0 token_type_embeds = 0
# Add the position information to the input embeddings
# h = e.sum(dim=2)
hidden_states = inputs_embeds + position_embeds + token_type_embeds hidden_states = inputs_embeds + position_embeds + token_type_embeds
hidden_states = self.drop(hidden_states)
all_attentions = []
for block in self.h: for block in self.h:
hidden_states = block(hidden_states) if self.output_attentions:
attentions, hidden_states = block(hidden_states)
all_attentions.append(attentions)
else:
hidden_states = block(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),) output_shape = input_shape + (hidden_states.size(-1),)
if self.output_attentions:
return all_attentions, hidden_states.view(*output_shape)
return hidden_states.view(*output_shape) return hidden_states.view(*output_shape)
...@@ -705,21 +726,24 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): ...@@ -705,21 +726,24 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
``` ```
""" """
def __init__(self, config): def __init__(self, config, output_attentions=False):
super(OpenAIGPTLMHeadModel, self).__init__(config) super(OpenAIGPTLMHeadModel, self).__init__(config)
self.transformer = OpenAIGPTModel(config) self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions)
self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config) self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
self.apply(self.init_weights) self.apply(self.init_weights)
def set_num_special_tokens(self, num_special_tokens): def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
""" Update input and output embeddings with new embedding matrix """ Update input and output embeddings with new embedding matrix
Make sure we are sharing the embeddings Make sure we are sharing the embeddings
""" """
self.config.predict_special_tokens = self.transformer.config.predict_special_tokens = predict_special_tokens
self.transformer.set_num_special_tokens(num_special_tokens) self.transformer.set_num_special_tokens(num_special_tokens)
self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight) self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight, predict_special_tokens=predict_special_tokens)
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None): def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None):
hidden_states = self.transformer(input_ids, position_ids, token_type_ids) hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
if self.transformer.output_attentions:
all_attentions, hidden_states = hidden_states
lm_logits = self.lm_head(hidden_states) lm_logits = self.lm_head(hidden_states)
if lm_labels is not None: if lm_labels is not None:
# Shift so that tokens < n predict n # Shift so that tokens < n predict n
...@@ -730,6 +754,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): ...@@ -730,6 +754,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1)) shift_labels.view(-1))
return loss return loss
if self.transformer.output_attentions:
return all_attentions, lm_logits
return lm_logits return lm_logits
...@@ -794,22 +820,25 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): ...@@ -794,22 +820,25 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
``` ```
""" """
def __init__(self, config): def __init__(self, config, output_attentions=False):
super(OpenAIGPTDoubleHeadsModel, self).__init__(config) super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
self.transformer = OpenAIGPTModel(config) self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions)
self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config) self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config) self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config)
self.apply(self.init_weights) self.apply(self.init_weights)
def set_num_special_tokens(self, num_special_tokens): def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
""" Update input and output embeddings with new embedding matrix """ Update input and output embeddings with new embedding matrix
Make sure we are sharing the embeddings Make sure we are sharing the embeddings
""" """
self.config.predict_special_tokens = self.transformer.config.predict_special_tokens = predict_special_tokens
self.transformer.set_num_special_tokens(num_special_tokens) self.transformer.set_num_special_tokens(num_special_tokens)
self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight) self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight, predict_special_tokens=predict_special_tokens)
def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None): def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None):
hidden_states = self.transformer(input_ids, position_ids, token_type_ids) hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
if self.transformer.output_attentions:
all_attentions, hidden_states = hidden_states
lm_logits = self.lm_head(hidden_states) lm_logits = self.lm_head(hidden_states)
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids) mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
losses = [] losses = []
...@@ -823,4 +852,6 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): ...@@ -823,4 +852,6 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))) losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
if losses: if losses:
return losses return losses
if self.transformer.output_attentions:
return all_attentions, lm_logits, mc_logits
return lm_logits, mc_logits return lm_logits, mc_logits
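The OpenAI GPT heads get the same treatment; `set_num_special_tokens` now also controls whether the tied LM head scores the special rows. A brief sketch with illustrative token ids:
```python
import torch
from pytorch_pretrained_bert.modeling_openai import OpenAIGPTConfig, OpenAIGPTLMHeadModel

config = OpenAIGPTConfig(vocab_size_or_config_json_file=40478)
model = OpenAIGPTLMHeadModel(config)
model.eval()

# add one [CLS]-style token but keep it out of the LM softmax
model.set_num_special_tokens(1, predict_special_tokens=False)

input_ids = torch.LongTensor([[500, 1000, 1500]])
lm_logits = model(input_ids)
assert lm_logits.size(-1) == config.vocab_size   # the special-token column is sliced off
```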
...@@ -37,9 +37,11 @@ logger = logging.getLogger(__name__) ...@@ -37,9 +37,11 @@ logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = { PRETRAINED_VOCAB_ARCHIVE_MAP = {
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
} }
PRETRAINED_MERGES_ARCHIVE_MAP = { PRETRAINED_MERGES_ARCHIVE_MAP = {
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
} }
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
'gpt2': 1024, 'gpt2': 1024,
...@@ -263,9 +265,14 @@ class GPT2Tokenizer(object): ...@@ -263,9 +265,14 @@ class GPT2Tokenizer(object):
def encode(self, text): def encode(self, text):
return self.convert_tokens_to_ids(self.tokenize(text)) return self.convert_tokens_to_ids(self.tokenize(text))
def decode(self, tokens): def decode(self, tokens, skip_special_tokens=False, clean_up_tokenization_spaces=True):
text = ''.join([self.decoder[token] for token in tokens]) text = ''.join(self.convert_ids_to_tokens(tokens, skip_special_tokens=skip_special_tokens))
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
if clean_up_tokenization_spaces:
text = text.replace('<unk>', '')
text = text.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
return text return text
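The GPT-2 tokenizer's `decode` now mirrors the OpenAI GPT tokenizer, with optional special-token skipping and whitespace clean-up. A sketch that assumes the `gpt2` vocabulary files can be downloaded or are already cached locally:
```python
from pytorch_pretrained_bert.tokenization_gpt2 import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')   # fetches gpt2-vocab.json and gpt2-merges.txt

ids = tokenizer.encode("The dog is cute.")
text = tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
print(text)   # expected: "The dog is cute."
```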
def save_vocabulary(self, vocab_path): def save_vocabulary(self, vocab_path):
......
...@@ -272,7 +272,7 @@ class OpenAIGPTTokenizer(object): ...@@ -272,7 +272,7 @@ class OpenAIGPTTokenizer(object):
out_string = ''.join(tokens).replace('</w>', ' ').strip() out_string = ''.join(tokens).replace('</w>', ' ').strip()
if clean_up_tokenization_spaces: if clean_up_tokenization_spaces:
out_string = out_string.replace('<unk>', '') out_string = out_string.replace('<unk>', '')
out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',').replace(' ,', ',' out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't" ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re") ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
return out_string return out_string
......
...@@ -41,6 +41,7 @@ class GPT2ModelTest(unittest.TestCase): ...@@ -41,6 +41,7 @@ class GPT2ModelTest(unittest.TestCase):
use_token_type_ids=True, use_token_type_ids=True,
use_labels=True, use_labels=True,
vocab_size=99, vocab_size=99,
n_special=1,
n_positions=33, n_positions=33,
n_embd=32, n_embd=32,
n_layer=5, n_layer=5,
...@@ -58,6 +59,7 @@ class GPT2ModelTest(unittest.TestCase): ...@@ -58,6 +59,7 @@ class GPT2ModelTest(unittest.TestCase):
self.use_token_type_ids = use_token_type_ids self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels self.use_labels = use_labels
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.n_special = n_special
self.n_positions = n_positions self.n_positions = n_positions
self.n_embd = n_embd self.n_embd = n_embd
self.n_layer = n_layer self.n_layer = n_layer
...@@ -69,7 +71,8 @@ class GPT2ModelTest(unittest.TestCase): ...@@ -69,7 +71,8 @@ class GPT2ModelTest(unittest.TestCase):
self.scope = scope self.scope = scope
def prepare_config_and_inputs(self): def prepare_config_and_inputs(self):
input_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.vocab_size) total_num_tokens = self.vocab_size + self.n_special
input_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_num_tokens)
position_ids = None position_ids = None
if self.use_position_ids: if self.use_position_ids:
...@@ -90,6 +93,7 @@ class GPT2ModelTest(unittest.TestCase): ...@@ -90,6 +93,7 @@ class GPT2ModelTest(unittest.TestCase):
config = GPT2Config( config = GPT2Config(
vocab_size_or_config_json_file=self.vocab_size, vocab_size_or_config_json_file=self.vocab_size,
n_special=self.n_special,
n_positions=self.n_positions, n_positions=self.n_positions,
n_embd=self.n_embd, n_embd=self.n_embd,
n_layer=self.n_layer, n_layer=self.n_layer,
...@@ -129,11 +133,29 @@ class GPT2ModelTest(unittest.TestCase): ...@@ -129,11 +133,29 @@ class GPT2ModelTest(unittest.TestCase):
} }
return outputs return outputs
def create_gpt2_lm_head_with_output_attention(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
model = GPT2LMHeadModel(config, output_attentions=True)
model.eval()
loss = model(input_ids, position_ids, token_type_ids, lm_labels)
attentions, lm_logits, presents = model(input_ids, position_ids, token_type_ids)
outputs = {
"loss": loss,
"lm_logits": lm_logits,
"presents": presents,
"attentions": attentions,
}
return outputs
def check_gpt2_lm_head_output(self, result): def check_gpt2_lm_head_output(self, result):
total_voc = self.vocab_size total_voc = self.n_special + self.vocab_size
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["lm_logits"].size()), list(result["lm_logits"].size()),
[self.batch_size, self.n_choices, self.seq_length, total_voc]) [self.batch_size, self.n_choices, self.seq_length, total_voc])
self.parent.assertEqual(self.n_layer, len(result["presents"]))
self.parent.assertListEqual(
list(result["presents"][0].size()),
[2, self.batch_size * self.n_choices, self.n_head, self.seq_length, self.n_embd // self.n_head])
def check_gpt2_lm_head_loss_output(self, result): def check_gpt2_lm_head_loss_output(self, result):
self.parent.assertListEqual( self.parent.assertListEqual(
...@@ -156,8 +178,25 @@ class GPT2ModelTest(unittest.TestCase): ...@@ -156,8 +178,25 @@ class GPT2ModelTest(unittest.TestCase):
} }
return outputs return outputs
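# Same pattern for the double-heads model: with output_attentions=True the
# no-label forward pass unpacks as (attentions, lm_logits, mc_logits, presents),
# while the call with lm_labels/mc_labels still yields the loss.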
def create_gpt2_double_heads_with_output_attention(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
model = GPT2DoubleHeadsModel(config, output_attentions=True)
model.eval()
loss = model(input_ids, mc_token_ids,
lm_labels=lm_labels, mc_labels=mc_labels,
token_type_ids=token_type_ids, position_ids=position_ids)
attentions, lm_logits, mc_logits, presents = model(input_ids, mc_token_ids, position_ids=position_ids, token_type_ids=token_type_ids)
outputs = {
"loss": loss,
"lm_logits": lm_logits,
"mc_logits": mc_logits,
"presents": presents,
"attentions": attentions,
}
return outputs
def check_gpt2_double_heads_output(self, result): def check_gpt2_double_heads_output(self, result):
total_voc = self.vocab_size total_voc = self.n_special + self.vocab_size
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["lm_logits"].size()), list(result["lm_logits"].size()),
[self.batch_size, self.n_choices, self.seq_length, total_voc]) [self.batch_size, self.n_choices, self.seq_length, total_voc])
......
...@@ -28,7 +28,7 @@ import torch ...@@ -28,7 +28,7 @@ import torch
from pytorch_pretrained_bert import (BertConfig, BertModel, BertForMaskedLM, from pytorch_pretrained_bert import (BertConfig, BertModel, BertForMaskedLM,
BertForNextSentencePrediction, BertForPreTraining, BertForNextSentencePrediction, BertForPreTraining,
BertForQuestionAnswering, BertForSequenceClassification, BertForQuestionAnswering, BertForSequenceClassification,
BertForTokenClassification) BertForTokenClassification, BertForMultipleChoice)
from pytorch_pretrained_bert.modeling import PRETRAINED_MODEL_ARCHIVE_MAP from pytorch_pretrained_bert.modeling import PRETRAINED_MODEL_ARCHIVE_MAP
...@@ -56,6 +56,7 @@ class BertModelTest(unittest.TestCase): ...@@ -56,6 +56,7 @@ class BertModelTest(unittest.TestCase):
type_sequence_label_size=2, type_sequence_label_size=2,
initializer_range=0.02, initializer_range=0.02,
num_labels=3, num_labels=3,
num_choices=4,
scope=None): scope=None):
self.parent = parent self.parent = parent
self.batch_size = batch_size self.batch_size = batch_size
...@@ -77,6 +78,7 @@ class BertModelTest(unittest.TestCase): ...@@ -77,6 +78,7 @@ class BertModelTest(unittest.TestCase):
self.type_sequence_label_size = type_sequence_label_size self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range self.initializer_range = initializer_range
self.num_labels = num_labels self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope self.scope = scope
def prepare_config_and_inputs(self): def prepare_config_and_inputs(self):
...@@ -92,9 +94,11 @@ class BertModelTest(unittest.TestCase): ...@@ -92,9 +94,11 @@ class BertModelTest(unittest.TestCase):
sequence_labels = None sequence_labels = None
token_labels = None token_labels = None
choice_labels = None
if self.use_labels: if self.use_labels:
sequence_labels = BertModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size) sequence_labels = BertModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.num_labels) token_labels = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = BertModelTest.ids_tensor([self.batch_size], self.num_choices)
config = BertConfig( config = BertConfig(
vocab_size_or_config_json_file=self.vocab_size, vocab_size_or_config_json_file=self.vocab_size,
...@@ -109,14 +113,14 @@ class BertModelTest(unittest.TestCase): ...@@ -109,14 +113,14 @@ class BertModelTest(unittest.TestCase):
type_vocab_size=self.type_vocab_size, type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range) initializer_range=self.initializer_range)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def check_loss_output(self, result): def check_loss_output(self, result):
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["loss"].size()), list(result["loss"].size()),
[]) [])
def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels): def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertModel(config=config) model = BertModel(config=config)
model.eval() model.eval()
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
...@@ -137,7 +141,7 @@ class BertModelTest(unittest.TestCase): ...@@ -137,7 +141,7 @@ class BertModelTest(unittest.TestCase):
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size]) self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels): def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForMaskedLM(config=config) model = BertForMaskedLM(config=config)
model.eval() model.eval()
loss = model(input_ids, token_type_ids, input_mask, token_labels) loss = model(input_ids, token_type_ids, input_mask, token_labels)
...@@ -153,7 +157,7 @@ class BertModelTest(unittest.TestCase): ...@@ -153,7 +157,7 @@ class BertModelTest(unittest.TestCase):
list(result["prediction_scores"].size()), list(result["prediction_scores"].size()),
[self.batch_size, self.seq_length, self.vocab_size]) [self.batch_size, self.seq_length, self.vocab_size])
def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels): def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForNextSentencePrediction(config=config) model = BertForNextSentencePrediction(config=config)
model.eval() model.eval()
loss = model(input_ids, token_type_ids, input_mask, sequence_labels) loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
...@@ -170,7 +174,7 @@ class BertModelTest(unittest.TestCase): ...@@ -170,7 +174,7 @@ class BertModelTest(unittest.TestCase):
[self.batch_size, 2]) [self.batch_size, 2])
def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels): def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForPreTraining(config=config) model = BertForPreTraining(config=config)
model.eval() model.eval()
loss = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels) loss = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels)
...@@ -191,7 +195,7 @@ class BertModelTest(unittest.TestCase): ...@@ -191,7 +195,7 @@ class BertModelTest(unittest.TestCase):
[self.batch_size, 2]) [self.batch_size, 2])
def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels): def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForQuestionAnswering(config=config) model = BertForQuestionAnswering(config=config)
model.eval() model.eval()
loss = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels) loss = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels)
...@@ -212,7 +216,7 @@ class BertModelTest(unittest.TestCase): ...@@ -212,7 +216,7 @@ class BertModelTest(unittest.TestCase):
[self.batch_size, self.seq_length]) [self.batch_size, self.seq_length])
def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels): def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForSequenceClassification(config=config, num_labels=self.num_labels) model = BertForSequenceClassification(config=config, num_labels=self.num_labels)
model.eval() model.eval()
loss = model(input_ids, token_type_ids, input_mask, sequence_labels) loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
...@@ -229,7 +233,7 @@ class BertModelTest(unittest.TestCase): ...@@ -229,7 +233,7 @@ class BertModelTest(unittest.TestCase):
[self.batch_size, self.num_labels]) [self.batch_size, self.num_labels])
def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels): def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForTokenClassification(config=config, num_labels=self.num_labels) model = BertForTokenClassification(config=config, num_labels=self.num_labels)
model.eval() model.eval()
loss = model(input_ids, token_type_ids, input_mask, token_labels) loss = model(input_ids, token_type_ids, input_mask, token_labels)
...@@ -246,6 +250,49 @@ class BertModelTest(unittest.TestCase): ...@@ -246,6 +250,49 @@ class BertModelTest(unittest.TestCase):
[self.batch_size, self.seq_length, self.num_labels]) [self.batch_size, self.seq_length, self.num_labels])
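# New helper for the BertForMultipleChoice head added in this change: the flat
# (batch, seq_len) inputs are expanded to (batch, num_choices, seq_len) with
# unsqueeze/expand, and the loss call consumes the (batch,)-sized choice_labels.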
def create_bert_for_multiple_choice(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForMultipleChoice(config=config, num_choices=self.num_choices)
model.eval()
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
loss = model(multiple_choice_inputs_ids,
multiple_choice_token_type_ids,
multiple_choice_input_mask,
choice_labels)
logits = model(multiple_choice_inputs_ids,
multiple_choice_token_type_ids,
multiple_choice_input_mask)
outputs = {
"loss": loss,
"logits": logits,
}
return outputs
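# Illustrative sketch (not part of the test file above): minimal standalone use
# of the new BertForMultipleChoice head, following the call pattern exercised
# here -- inputs of shape (batch, num_choices, seq_len) and a (batch,)-sized
# label tensor. The toy config values are arbitrary.
import torch
from pytorch_pretrained_bert import BertConfig, BertForMultipleChoice

config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=32,
                    num_hidden_layers=2, num_attention_heads=4,
                    intermediate_size=37)
model = BertForMultipleChoice(config, num_choices=4)
model.eval()
input_ids = torch.randint(0, 99, (2, 4, 7))        # (batch, num_choices, seq_len)
token_type_ids = torch.zeros_like(input_ids)
attention_mask = torch.ones_like(input_ids)
choice_labels = torch.randint(0, 4, (2,))
with torch.no_grad():
    loss = model(input_ids, token_type_ids, attention_mask, choice_labels)
    logits = model(input_ids, token_type_ids, attention_mask)  # (batch, num_choices)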
def check_bert_for_multiple_choice(self, result):
self.parent.assertListEqual(
list(result["logits"].size()),
[self.batch_size, self.num_choices])
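# New combined create-and-check helper: each listed BERT head is rebuilt with
# output_attentions=True and the first element of its output is expected to be
# a list of num_hidden_layers attention maps, each of shape
# (batch, num_attention_heads, seq_len, seq_len).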
def create_and_check_bert_for_attentions(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
BertForTokenClassification):
if model_class in [BertForSequenceClassification,
BertForTokenClassification]:
model = model_class(config=config, num_labels=self.num_labels, output_attentions=True)
else:
model = model_class(config=config, output_attentions=True)
model.eval()
output = model(input_ids, token_type_ids, input_mask)
attentions = output[0]
self.parent.assertEqual(len(attentions), self.num_hidden_layers)
self.parent.assertListEqual(
list(attentions[0].size()),
[self.batch_size, self.num_attention_heads, self.seq_length, self.seq_length])
def test_default(self): def test_default(self):
self.run_tester(BertModelTest.BertModelTester(self)) self.run_tester(BertModelTest.BertModelTester(self))
...@@ -300,6 +347,12 @@ class BertModelTest(unittest.TestCase): ...@@ -300,6 +347,12 @@ class BertModelTest(unittest.TestCase):
tester.check_bert_for_token_classification_output(output_result) tester.check_bert_for_token_classification_output(output_result)
tester.check_loss_output(output_result) tester.check_loss_output(output_result)
output_result = tester.create_bert_for_multiple_choice(*config_and_inputs)
tester.check_bert_for_multiple_choice(output_result)
tester.check_loss_output(output_result)
tester.create_and_check_bert_for_attentions(*config_and_inputs)
@classmethod @classmethod
def ids_tensor(cls, shape, vocab_size, rng=None, name=None): def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
"""Creates a random int32 tensor of the shape within the vocab size.""" """Creates a random int32 tensor of the shape within the vocab size."""
......