Commit 0a4fb0da authored by chrislarson1

Merge remote-tracking branch 'upstream/master' into convert-back-to-tf

merging in latest changes from upstream
parents 314bc6bb 3763f894
......@@ -25,9 +25,6 @@ import copy
import json
import math
import logging
import tarfile
import tempfile
import shutil
import collections
import sys
from io import open
......@@ -888,8 +885,7 @@ class TransfoXLPreTrainedModel(nn.Module):
pass
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
from_tf=False, *inputs, **kwargs):
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
"""
Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a PyTorch state dict.
Download and cache the pre-trained model file if needed.
......@@ -897,19 +893,25 @@ class TransfoXLPreTrainedModel(nn.Module):
Params:
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `transfo-xl`
. `transfo-xl-wt103`
- a path or url to a pretrained model archive containing:
. `transfo_xl_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `transfo_xl_config.json` a configuration file for the model
. `model.chkpt` a TensorFlow checkpoint
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models
*inputs, **kwargs: additional input for the specific Bert class
(ex: num_labels for BertForSequenceClassification)
*inputs, **kwargs: additional input for the specific TransformerXL class
"""
state_dict = kwargs.get('state_dict', None)
kwargs.pop('state_dict', None)
cache_dir = kwargs.get('cache_dir', None)
kwargs.pop('cache_dir', None)
from_tf = kwargs.get('from_tf', False)
kwargs.pop('from_tf', None)
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
......@@ -919,16 +921,37 @@ class TransfoXLPreTrainedModel(nn.Module):
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download pretrained weights.".format(
archive_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find file {} "
"at this path or url.".format(
pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
archive_file
)
)
return None
try:
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
except EnvironmentError:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} and {} "
"at this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
pretrained_model_name_or_path,
archive_file, config_file))
if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download pretrained model configuration file.".format(
config_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find file {} "
"at this path or url.".format(
pretrained_model_name_or_path, ", ".join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
config_file
)
)
return None
if resolved_archive_file == archive_file and resolved_config_file == config_file:
logger.info("loading weights file {}".format(archive_file))
......
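For reference, a minimal usage sketch of the reworked signature: `state_dict`, `cache_dir`, and `from_tf` are now pulled out of `**kwargs`, so any remaining arguments flow through to the model constructor. The model name is one from the archive map above; the cache path is purely illustrative.

```python
# Minimal sketch, assuming pytorch_pretrained_bert is installed and the
# 'transfo-xl-wt103' weights are reachable; the cache path is an example.
from pytorch_pretrained_bert import TransfoXLModel

model = TransfoXLModel.from_pretrained(
    'transfo-xl-wt103',           # name from PRETRAINED_MODEL_ARCHIVE_MAP
    cache_dir='/tmp/transfo_xl',  # now consumed from **kwargs, not a named parameter
)
if model is None:                 # the loader logs an error and returns None on failure
    raise RuntimeError("Download failed or model name not recognized")
```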
......@@ -114,10 +114,10 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
logit = self._compute_logit(hidden, self.out_layers[0].weight,
self.out_layers[0].bias, self.out_projs[0])
if target is not None:
output = -F.log_softmax(logit, dim=-1) \
out = -F.log_softmax(logit, dim=-1) \
.gather(1, target.unsqueeze(1)).squeeze(1)
else:
output = F.log_softmax(logit, dim=-1)
out = F.log_softmax(logit, dim=-1)
else:
# construct weights and biases
weights, biases = [], []
......
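The renamed `out` branch above is the usual gather trick for per-token negative log-likelihood. A small self-check of the equivalence, with arbitrary toy shapes:

```python
# The gather selects, for each row i, -log_softmax(logit)[i, target[i]],
# which is exactly the unreduced cross-entropy.
import torch
import torch.nn.functional as F

logit = torch.randn(4, 10)                # 4 positions, vocab size 10
target = torch.randint(0, 10, (4,))

nll = -F.log_softmax(logit, dim=-1).gather(1, target.unsqueeze(1)).squeeze(1)
assert torch.allclose(nll, F.cross_entropy(logit, target, reduction='none'))
```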
......@@ -34,6 +34,9 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = {
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
'bert-base-uncased': 512,
......@@ -43,6 +46,9 @@ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
'bert-base-multilingual-uncased': 512,
'bert-base-multilingual-cased': 512,
'bert-base-chinese': 512,
'bert-base-german-cased': 512,
'bert-large-uncased-whole-word-masking': 512,
'bert-large-cased-whole-word-masking': 512,
}
VOCAB_NAME = 'vocab.txt'
......@@ -175,13 +181,18 @@ class BertTokenizer(object):
try:
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
except EnvironmentError:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
vocab_file))
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download vocabulary.".format(
vocab_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
vocab_file))
return None
if resolved_vocab_file == vocab_file:
logger.info("loading vocabulary file {}".format(vocab_file))
......
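Because all of these loaders log and return None rather than raise when a file cannot be resolved, calling code should check the result. A hedged sketch, using 'bert-base-german-cased' from the vocabulary map added above:

```python
# Sketch only: from_pretrained returns None (after logging an error) when the
# vocabulary cannot be downloaded or found, so guard against it explicitly.
from pytorch_pretrained_bert import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-german-cased')
if tokenizer is None:
    raise RuntimeError("Could not resolve vocabulary; check the model name or network")
```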
......@@ -37,9 +37,11 @@ logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
}
PRETRAINED_MERGES_ARCHIVE_MAP = {
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
'gpt2': 1024,
......@@ -91,7 +93,7 @@ class GPT2Tokenizer(object):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
"""
Instantiate a PreTrainedBertModel from a pre-trained model file.
Instantiate a GPT2Tokenizer from a pre-trained model file.
Download and cache the pre-trained model file if needed.
"""
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
......@@ -111,14 +113,19 @@ class GPT2Tokenizer(object):
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
except EnvironmentError:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} and {} "
"at this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
pretrained_model_name_or_path,
vocab_file, merges_file))
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download vocabulary.".format(
vocab_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} and {} "
"at this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
pretrained_model_name_or_path,
vocab_file, merges_file))
return None
if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
logger.info("loading vocabulary file {}".format(vocab_file))
......@@ -263,9 +270,14 @@ class GPT2Tokenizer(object):
def encode(self, text):
return self.convert_tokens_to_ids(self.tokenize(text))
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
def decode(self, tokens, skip_special_tokens=False, clean_up_tokenization_spaces=True):
text = ''.join(self.convert_ids_to_tokens(tokens, skip_special_tokens=skip_special_tokens))
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
if clean_up_tokenization_spaces:
text = text.replace('<unk>', '')
text = text.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
return text
def save_vocabulary(self, vocab_path):
......
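The new `decode` signature mirrors the OpenAI GPT tokenizer below: ids go back through `convert_ids_to_tokens`, the byte-level text is reassembled, and `clean_up_tokenization_spaces` reattaches punctuation. A hedged round-trip sketch, assuming the 'gpt2' vocabulary downloads on first use:

```python
# Byte-level BPE is lossless, so decoding without cleanup reproduces the
# input exactly; with cleanup, spaces before punctuation are removed.
from pytorch_pretrained_bert import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
ids = tokenizer.encode("Hello , world !")
print(tokenizer.decode(ids))                                      # "Hello, world!"
print(tokenizer.decode(ids, clean_up_tokenization_spaces=False))  # "Hello , world !"
```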
......@@ -101,14 +101,19 @@ class OpenAIGPTTokenizer(object):
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
except EnvironmentError:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} and {} "
"at this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
pretrained_model_name_or_path,
vocab_file, merges_file))
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download vocabulary.".format(
vocab_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} and {} "
"at this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
pretrained_model_name_or_path,
vocab_file, merges_file))
return None
if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
logger.info("loading vocabulary file {}".format(vocab_file))
......@@ -272,7 +277,7 @@ class OpenAIGPTTokenizer(object):
out_string = ''.join(tokens).replace('</w>', ' ').strip()
if clean_up_tokenization_spaces:
out_string = out_string.replace('<unk>', '')
out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',').replace(' ,', ','
out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
return out_string
......
......@@ -71,14 +71,19 @@ class TransfoXLTokenizer(object):
try:
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
except EnvironmentError:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} "
"at this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
pretrained_model_name_or_path,
vocab_file))
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
logger.error(
"Couldn't reach server at '{}' to download vocabulary.".format(
vocab_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} "
"at this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
pretrained_model_name_or_path,
vocab_file))
return None
if resolved_vocab_file == vocab_file:
logger.info("loading vocabulary file {}".format(vocab_file))
......
......@@ -41,6 +41,7 @@ class GPT2ModelTest(unittest.TestCase):
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
n_special=1,
n_positions=33,
n_embd=32,
n_layer=5,
......@@ -58,6 +59,7 @@ class GPT2ModelTest(unittest.TestCase):
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.n_special = n_special
self.n_positions = n_positions
self.n_embd = n_embd
self.n_layer = n_layer
......@@ -69,7 +71,8 @@ class GPT2ModelTest(unittest.TestCase):
self.scope = scope
def prepare_config_and_inputs(self):
input_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.vocab_size)
total_num_tokens = self.vocab_size + self.n_special
input_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_num_tokens)
position_ids = None
if self.use_position_ids:
......@@ -90,6 +93,7 @@ class GPT2ModelTest(unittest.TestCase):
config = GPT2Config(
vocab_size_or_config_json_file=self.vocab_size,
n_special=self.n_special,
n_positions=self.n_positions,
n_embd=self.n_embd,
n_layer=self.n_layer,
......@@ -111,8 +115,9 @@ class GPT2ModelTest(unittest.TestCase):
return outputs
def check_gpt2_model_output(self, result):
self.parent.assertEqual(len(result["hidden_states"]), self.n_layer + 1)
self.parent.assertListEqual(
list(result["hidden_states"].size()),
list(result["hidden_states"][0].size()),
[self.batch_size, self.n_choices, self.seq_length, self.n_embd])
......@@ -129,11 +134,29 @@ class GPT2ModelTest(unittest.TestCase):
}
return outputs
def create_gpt2_lm_head_with_output_attention(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
model = GPT2LMHeadModel(config, output_attentions=True)
model.eval()
loss = model(input_ids, position_ids, token_type_ids, lm_labels)
attentions, lm_logits, presents = model(input_ids, position_ids, token_type_ids)
outputs = {
"loss": loss,
"lm_logits": lm_logits,
"presents": presents,
"attentions": attentions,
}
return outputs
def check_gpt2_lm_head_output(self, result):
total_voc = self.vocab_size
total_voc = self.n_special + self.vocab_size
self.parent.assertListEqual(
list(result["lm_logits"].size()),
[self.batch_size, self.n_choices, self.seq_length, total_voc])
self.parent.assertEqual(self.n_layer, len(result["presents"]))
self.parent.assertListEqual(
list(result["presents"][0].size()),
[2, self.batch_size * self.n_choices, self.n_head, self.seq_length, self.n_embd // self.n_head])
def check_gpt2_lm_head_loss_output(self, result):
self.parent.assertListEqual(
......@@ -156,8 +179,25 @@ class GPT2ModelTest(unittest.TestCase):
}
return outputs
def create_gpt2_double_heads_with_output_attention(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
model = GPT2DoubleHeadsModel(config, output_attentions=True)
model.eval()
loss = model(input_ids, mc_token_ids,
lm_labels=lm_labels, mc_labels=mc_labels,
token_type_ids=token_type_ids, position_ids=position_ids)
attentions, lm_logits, mc_logits, presents = model(input_ids, mc_token_ids, position_ids=position_ids, token_type_ids=token_type_ids)
outputs = {
"loss": loss,
"lm_logits": lm_logits,
"mc_logits": mc_logits,
"presents": presents,
"attentions": attentions,
}
return outputs
def check_gpt2_double_heads_output(self, result):
total_voc = self.vocab_size
total_voc = self.n_special + self.vocab_size
self.parent.assertListEqual(
list(result["lm_logits"].size()),
[self.batch_size, self.n_choices, self.seq_length, total_voc])
......@@ -170,6 +210,98 @@ class GPT2ModelTest(unittest.TestCase):
[list(l.size()) for l in result["loss"]],
[[], []])
def create_and_check_gpt2_for_headmasking(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
for model_class in (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel):
model = model_class(config=config, keep_multihead_output=True)
model.eval()
head_mask = torch.zeros(self.n_layer, self.n_head).to(input_ids.device)
head_mask[0, 1:-1] = 1.0 # Mask all but the first and last heads on the first layer
head_mask[-1, 1:] = 1.0 # Mask all but the first head on the last layer
if isinstance(model, GPT2DoubleHeadsModel):
output = model(input_ids, mc_token_ids, head_mask=head_mask)
else:
output = model(input_ids, head_mask=head_mask)
if isinstance(model, GPT2Model):
output = sum(t.sum() for t in output[0])
elif isinstance(output, (list, tuple)):
output = sum(t.sum() for t in output[:-1])
output = output.sum()
output.backward()
multihead_outputs = (model if isinstance(model, GPT2Model) else model.transformer).get_multihead_outputs()
self.parent.assertEqual(len(multihead_outputs), self.n_layer)
self.parent.assertListEqual(
list(multihead_outputs[0].size()),
[self.batch_size * self.n_choices, self.n_head,
self.seq_length, self.n_embd // self.n_head])
self.parent.assertEqual(
len(multihead_outputs[0][:, 1:(self.n_head-1), :, :].nonzero()),
0)
self.parent.assertEqual(
len(multihead_outputs[0][:, 0, :, :].nonzero()),
self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
self.parent.assertEqual(
len(multihead_outputs[0][:, self.n_head-1, :, :].nonzero()),
self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
self.parent.assertListEqual(
list(multihead_outputs[1].size()),
[self.batch_size * self.n_choices, self.n_head,
self.seq_length, self.n_embd // self.n_head])
self.parent.assertEqual(
len(multihead_outputs[1].nonzero()),
multihead_outputs[1].numel())
self.parent.assertListEqual(
list(multihead_outputs[-1].size()),
[self.batch_size * self.n_choices, self.n_head,
self.seq_length, self.n_embd // self.n_head])
self.parent.assertEqual(
len(multihead_outputs[-1][:, 1:, :, :].nonzero()),
0)
self.parent.assertEqual(
len(multihead_outputs[-1][:, 0, :, :].nonzero()),
self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
def create_and_check_gpt2_for_head_pruning(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
for model_class in (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel):
model = model_class(config=config, keep_multihead_output=True)
model.eval()
transformer = model if isinstance(model, GPT2Model) else model.transformer
heads_to_prune = {0: list(range(1, self.n_head)),
-1: [0]}
transformer.prune_heads(heads_to_prune)
if isinstance(model, GPT2DoubleHeadsModel):
output = model(input_ids, mc_token_ids)
else:
output = model(input_ids)
if isinstance(model, GPT2Model):
output = sum(t.sum() for t in output[0])
elif isinstance(output, (list, tuple)):
output = sum(t.sum() for t in output[:-1])
output = output.sum()
output.backward()
multihead_outputs = transformer.get_multihead_outputs()
self.parent.assertEqual(len(multihead_outputs), self.n_layer)
self.parent.assertListEqual(
list(multihead_outputs[0].size()),
[self.batch_size * self.n_choices, 1,
self.seq_length, self.n_embd // self.n_head])
self.parent.assertListEqual(
list(multihead_outputs[1].size()),
[self.batch_size * self.n_choices, self.n_head,
self.seq_length, self.n_embd // self.n_head])
self.parent.assertListEqual(
list(multihead_outputs[-1].size()),
[self.batch_size * self.n_choices, self.n_head-1,
self.seq_length, self.n_embd // self.n_head])
def test_default(self):
self.run_tester(GPT2ModelTest.GPT2ModelTester(self))
......@@ -208,6 +340,9 @@ class GPT2ModelTest(unittest.TestCase):
tester.check_gpt2_double_heads_output(output_result)
tester.check_gpt2_double_heads_loss_output(output_result)
tester.create_and_check_gpt2_for_headmasking(*config_and_inputs)
tester.create_and_check_gpt2_for_head_pruning(*config_and_inputs)
@classmethod
def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
"""Creates a random int32 tensor of the shape within the vocab size."""
......
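The pruning tests above exercise the `prune_heads` API end to end. Outside the test harness, a call looks like the sketch below — the head count of 12 and the output layout are assumptions about the small 'gpt2' checkpoint on this branch:

```python
# Hedged sketch: prune every head but head 0 on the first layer, and head 0
# on the last layer, matching the heads_to_prune dict used in the tests.
import torch
from pytorch_pretrained_bert import GPT2Model

model = GPT2Model.from_pretrained('gpt2')
model.eval()
model.prune_heads({0: list(range(1, 12)),  # gpt2-small has 12 heads (assumption)
                   -1: [0]})               # negative keys index from the last layer
input_ids = torch.tensor([[31373, 995]])   # arbitrary token ids
with torch.no_grad():
    outputs = model(input_ids)             # hidden states and presents, per this branch
```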
......@@ -125,8 +125,9 @@ class OpenAIGPTModelTest(unittest.TestCase):
return outputs
def check_openai_model_output(self, result):
self.parent.assertEqual(len(result["hidden_states"]), self.n_layer + 1)
self.parent.assertListEqual(
list(result["hidden_states"].size()),
list(result["hidden_states"][0].size()),
[self.batch_size, self.n_choices, self.seq_length, self.n_embd])
......@@ -182,6 +183,99 @@ class OpenAIGPTModelTest(unittest.TestCase):
[list(l.size()) for l in result["loss"]],
[[], []])
def create_and_check_openai_for_headmasking(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
for model_class in (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel):
model = model_class(config=config, keep_multihead_output=True)
model.eval()
head_mask = torch.zeros(self.n_layer, self.n_head).to(input_ids.device)
head_mask[0, 1:-1] = 1.0 # Mask all but the first and last heads on the first layer
head_mask[-1, 1:] = 1.0 # Mask all but the first head on the last layer
if isinstance(model, OpenAIGPTDoubleHeadsModel):
output = model(input_ids, mc_token_ids, head_mask=head_mask)
else:
output = model(input_ids, head_mask=head_mask)
if isinstance(model, OpenAIGPTModel):
output = sum(t.sum() for t in output[0])
elif isinstance(output, (list, tuple)):
output = sum(t.sum() for t in output)
output = output.sum()
output.backward()
multihead_outputs = (model if isinstance(model, OpenAIGPTModel) else model.transformer).get_multihead_outputs()
self.parent.assertEqual(len(multihead_outputs), self.n_layer)
self.parent.assertListEqual(
list(multihead_outputs[0].size()),
[self.batch_size * self.n_choices, self.n_head,
self.seq_length, self.n_embd // self.n_head])
self.parent.assertEqual(
len(multihead_outputs[0][:, 1:(self.n_head-1), :, :].nonzero()),
0)
self.parent.assertEqual(
len(multihead_outputs[0][:, 0, :, :].nonzero()),
self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
self.parent.assertEqual(
len(multihead_outputs[0][:, self.n_head-1, :, :].nonzero()),
self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
self.parent.assertListEqual(
list(multihead_outputs[1].size()),
[self.batch_size * self.n_choices, self.n_head,
self.seq_length, self.n_embd // self.n_head])
self.parent.assertEqual(
len(multihead_outputs[1].nonzero()),
multihead_outputs[1].numel())
self.parent.assertListEqual(
list(multihead_outputs[-1].size()),
[self.batch_size * self.n_choices, self.n_head,
self.seq_length, self.n_embd // self.n_head])
self.parent.assertEqual(
len(multihead_outputs[-1][:, 1:, :, :].nonzero()),
0)
self.parent.assertEqual(
len(multihead_outputs[-1][:, 0, :, :].nonzero()),
self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
def create_and_check_openai_for_head_pruning(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
for model_class in (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel):
model = model_class(config=config, keep_multihead_output=True)
model.eval()
transformer = model if isinstance(model, OpenAIGPTModel) else model.transformer
heads_to_prune = {0: list(range(1, self.n_head)),
-1: [0]}
transformer.prune_heads(heads_to_prune)
if isinstance(model, OpenAIGPTDoubleHeadsModel):
output = model(input_ids, mc_token_ids)
else:
output = model(input_ids)
if isinstance(model, OpenAIGPTModel):
output = sum(t.sum() for t in output[0])
elif isinstance(output, (list, tuple)):
output = sum(t.sum() for t in output)
output = output.sum()
output.backward()
multihead_outputs = transformer.get_multihead_outputs()
self.parent.assertEqual(len(multihead_outputs), self.n_layer)
self.parent.assertListEqual(
list(multihead_outputs[0].size()),
[self.batch_size * self.n_choices, 1,
self.seq_length, self.n_embd // self.n_head])
self.parent.assertListEqual(
list(multihead_outputs[1].size()),
[self.batch_size * self.n_choices, self.n_head,
self.seq_length, self.n_embd // self.n_head])
self.parent.assertListEqual(
list(multihead_outputs[-1].size()),
[self.batch_size * self.n_choices, self.n_head-1,
self.seq_length, self.n_embd // self.n_head])
def test_default(self):
self.run_tester(OpenAIGPTModelTest.OpenAIGPTModelTester(self))
......@@ -220,6 +314,9 @@ class OpenAIGPTModelTest(unittest.TestCase):
tester.check_openai_double_heads_output(output_result)
tester.check_openai_double_heads_loss_output(output_result)
tester.create_and_check_openai_for_headmasking(*config_and_inputs)
tester.create_and_check_openai_for_head_pruning(*config_and_inputs)
@classmethod
def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
"""Creates a random int32 tensor of the shape within the vocab size."""
......
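Note the head-mask convention these tests encode: the mask has shape [n_layer, n_head] and, on this branch, a value of 1.0 marks a head to be masked (its output is zeroed) while 0.0 leaves it untouched — later releases flipped this to a keep-mask. A sketch mirroring the test setup:

```python
# Sketch only: values follow the convention exercised in the tests above,
# where 1.0 in head_mask zeroes a head's output on this branch.
import torch

n_layer, n_head = 12, 12        # small GPT-era config sizes (assumption)
head_mask = torch.zeros(n_layer, n_head)
head_mask[0, 1:-1] = 1.0        # mask all but the first and last heads on layer 0
head_mask[-1, 1:] = 1.0         # mask all but the first head on the last layer
# output = model(input_ids, head_mask=head_mask)   # model construction elided
```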
......@@ -28,7 +28,7 @@ import torch
from pytorch_pretrained_bert import (BertConfig, BertModel, BertForMaskedLM,
BertForNextSentencePrediction, BertForPreTraining,
BertForQuestionAnswering, BertForSequenceClassification,
BertForTokenClassification)
BertForTokenClassification, BertForMultipleChoice)
from pytorch_pretrained_bert.modeling import PRETRAINED_MODEL_ARCHIVE_MAP
......@@ -56,6 +56,7 @@ class BertModelTest(unittest.TestCase):
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
num_choices=4,
scope=None):
self.parent = parent
self.batch_size = batch_size
......@@ -77,6 +78,7 @@ class BertModelTest(unittest.TestCase):
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
def prepare_config_and_inputs(self):
......@@ -92,9 +94,11 @@ class BertModelTest(unittest.TestCase):
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = BertModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = BertModelTest.ids_tensor([self.batch_size], self.num_choices)
config = BertConfig(
vocab_size_or_config_json_file=self.vocab_size,
......@@ -109,14 +113,14 @@ class BertModelTest(unittest.TestCase):
type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def check_loss_output(self, result):
self.parent.assertListEqual(
list(result["loss"].size()),
[])
def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertModel(config=config)
model.eval()
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
......@@ -137,7 +141,7 @@ class BertModelTest(unittest.TestCase):
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForMaskedLM(config=config)
model.eval()
loss = model(input_ids, token_type_ids, input_mask, token_labels)
......@@ -153,7 +157,7 @@ class BertModelTest(unittest.TestCase):
list(result["prediction_scores"].size()),
[self.batch_size, self.seq_length, self.vocab_size])
def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForNextSentencePrediction(config=config)
model.eval()
loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
......@@ -170,7 +174,7 @@ class BertModelTest(unittest.TestCase):
[self.batch_size, 2])
def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForPreTraining(config=config)
model.eval()
loss = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels)
......@@ -191,7 +195,7 @@ class BertModelTest(unittest.TestCase):
[self.batch_size, 2])
def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForQuestionAnswering(config=config)
model.eval()
loss = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels)
......@@ -212,7 +216,7 @@ class BertModelTest(unittest.TestCase):
[self.batch_size, self.seq_length])
def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForSequenceClassification(config=config, num_labels=self.num_labels)
model.eval()
loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
......@@ -229,7 +233,7 @@ class BertModelTest(unittest.TestCase):
[self.batch_size, self.num_labels])
def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForTokenClassification(config=config, num_labels=self.num_labels)
model.eval()
loss = model(input_ids, token_type_ids, input_mask, token_labels)
......@@ -246,6 +250,150 @@ class BertModelTest(unittest.TestCase):
[self.batch_size, self.seq_length, self.num_labels])
def create_bert_for_multiple_choice(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForMultipleChoice(config=config, num_choices=self.num_choices)
model.eval()
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
loss = model(multiple_choice_inputs_ids,
multiple_choice_token_type_ids,
multiple_choice_input_mask,
choice_labels)
logits = model(multiple_choice_inputs_ids,
multiple_choice_token_type_ids,
multiple_choice_input_mask)
outputs = {
"loss": loss,
"logits": logits,
}
return outputs
def check_bert_for_multiple_choice(self, result):
self.parent.assertListEqual(
list(result["logits"].size()),
[self.batch_size, self.num_choices])
def create_and_check_bert_for_attentions(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
BertForTokenClassification):
if model_class in [BertForSequenceClassification,
BertForTokenClassification]:
model = model_class(config=config, num_labels=self.num_labels, output_attentions=True)
else:
model = model_class(config=config, output_attentions=True)
model.eval()
output = model(input_ids, token_type_ids, input_mask)
attentions = output[0]
self.parent.assertEqual(len(attentions), self.num_hidden_layers)
self.parent.assertListEqual(
list(attentions[0].size()),
[self.batch_size, self.num_attention_heads, self.seq_length, self.seq_length])
def create_and_check_bert_for_headmasking(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
BertForTokenClassification):
if model_class in [BertForSequenceClassification,
BertForTokenClassification]:
model = model_class(config=config,
num_labels=self.num_labels,
keep_multihead_output=True)
else:
model = model_class(config=config, keep_multihead_output=True)
model.eval()
head_mask = torch.zeros(self.num_hidden_layers, self.num_attention_heads).to(input_ids.device)
head_mask[0, 1:-1] = 1.0 # Mask all but the first and last heads on the first layer
head_mask[-1, 1:] = 1.0 # Mask all but the first head on the last layer
output = model(input_ids, token_type_ids, input_mask, head_mask=head_mask)
if isinstance(model, BertModel):
output = sum(t.sum() for t in output[0])
elif isinstance(output, (list, tuple)):
output = sum(t.sum() for t in output)
output = output.sum()
output.backward()
multihead_outputs = (model if isinstance(model, BertModel) else model.bert).get_multihead_outputs()
self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
self.parent.assertListEqual(
list(multihead_outputs[0].size()),
[self.batch_size, self.num_attention_heads,
self.seq_length, self.hidden_size // self.num_attention_heads])
self.parent.assertEqual(
len(multihead_outputs[0][:, 1:(self.num_attention_heads-1), :, :].nonzero()),
0)
self.parent.assertEqual(
len(multihead_outputs[0][:, 0, :, :].nonzero()),
self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
self.parent.assertEqual(
len(multihead_outputs[0][:, self.num_attention_heads-1, :, :].nonzero()),
self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
self.parent.assertListEqual(
list(multihead_outputs[1].size()),
[self.batch_size, self.num_attention_heads,
self.seq_length, self.hidden_size // self.num_attention_heads])
self.parent.assertEqual(
len(multihead_outputs[1].nonzero()),
multihead_outputs[1].numel())
self.parent.assertListEqual(
list(multihead_outputs[-1].size()),
[self.batch_size, self.num_attention_heads,
self.seq_length, self.hidden_size // self.num_attention_heads])
self.parent.assertEqual(
len(multihead_outputs[-1][:, 1:, :, :].nonzero()),
0)
self.parent.assertEqual(
len(multihead_outputs[-1][:, 0, :, :].nonzero()),
self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
def create_and_check_bert_for_head_pruning(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
BertForTokenClassification):
if model_class in [BertForSequenceClassification,
BertForTokenClassification]:
model = model_class(config=config,
num_labels=self.num_labels,
keep_multihead_output=True)
else:
model = model_class(config=config, keep_multihead_output=True)
model.eval()
bert_model = model if isinstance(model, BertModel) else model.bert
heads_to_prune = {0: list(range(1, self.num_attention_heads)),
-1: [0]}
bert_model.prune_heads(heads_to_prune)
output = model(input_ids, token_type_ids, input_mask)
if isinstance(model, BertModel):
output = sum(t.sum() for t in output[0])
elif isinstance(output, (list, tuple)):
output = sum(t.sum() for t in output)
output = output.sum()
output.backward()
multihead_outputs = bert_model.get_multihead_outputs()
self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
self.parent.assertListEqual(
list(multihead_outputs[0].size()),
[self.batch_size, 1,
self.seq_length, self.hidden_size // self.num_attention_heads])
self.parent.assertListEqual(
list(multihead_outputs[1].size()),
[self.batch_size, self.num_attention_heads,
self.seq_length, self.hidden_size // self.num_attention_heads])
self.parent.assertListEqual(
list(multihead_outputs[-1].size()),
[self.batch_size, self.num_attention_heads-1,
self.seq_length, self.hidden_size // self.num_attention_heads])
def test_default(self):
self.run_tester(BertModelTest.BertModelTester(self))
......@@ -300,6 +448,14 @@ class BertModelTest(unittest.TestCase):
tester.check_bert_for_token_classification_output(output_result)
tester.check_loss_output(output_result)
output_result = tester.create_bert_for_multiple_choice(*config_and_inputs)
tester.check_bert_for_multiple_choice(output_result)
tester.check_loss_output(output_result)
tester.create_and_check_bert_for_attentions(*config_and_inputs)
tester.create_and_check_bert_for_headmasking(*config_and_inputs)
tester.create_and_check_bert_for_head_pruning(*config_and_inputs)
@classmethod
def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
"""Creates a random int32 tensor of the shape within the vocab size."""
......
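For completeness, the input contract the new `BertForMultipleChoice` tests pin down: every input tensor is [batch, num_choices, seq_len], and the classification logits come back as [batch, num_choices]. A hedged standalone sketch with toy dimensions:

```python
# Sketch with toy sizes; BertForMultipleChoice is imported in the test file
# above, and tensors are passed positionally as in the tests.
import torch
from pytorch_pretrained_bert.modeling import BertConfig, BertForMultipleChoice

config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=32,
                    num_hidden_layers=2, num_attention_heads=4,
                    intermediate_size=37)
model = BertForMultipleChoice(config, num_choices=4)
model.eval()

batch, num_choices, seq_len = 2, 4, 7
input_ids = torch.randint(0, 99, (batch, num_choices, seq_len))
token_type_ids = torch.zeros_like(input_ids)
input_mask = torch.ones_like(input_ids)

logits = model(input_ids, token_type_ids, input_mask)  # -> [2, 4]
```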