Commit 0a4fb0da authored by chrislarson1

Merge remote-tracking branch 'upstream/master' into convert-back-to-tf

merging in latest changes from upstream
parents 314bc6bb 3763f894
@@ -25,9 +25,6 @@ import copy
 import json
 import math
 import logging
-import tarfile
-import tempfile
-import shutil
 import collections
 import sys
 from io import open
@@ -888,8 +885,7 @@ class TransfoXLPreTrainedModel(nn.Module):
         pass
     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
-                        from_tf=False, *inputs, **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
         """
         Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict.
         Download and cache the pre-trained model file if needed.
@@ -897,19 +893,25 @@ class TransfoXLPreTrainedModel(nn.Module):
         Params:
             pretrained_model_name_or_path: either:
                 - a str with the name of a pre-trained model to load selected in the list of:
-                    . `transfo-xl`
+                    . `transfo-xl-wt103`
                 - a path or url to a pretrained model archive containing:
                     . `transfo_xl_config.json` a configuration file for the model
                     . `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance
                 - a path or url to a pretrained model archive containing:
-                    . `bert_config.json` a configuration file for the model
+                    . `transfo_xl_config.json` a configuration file for the model
                     . `model.chkpt` a TensorFlow checkpoint
             from_tf: should we load the weights from a locally saved TensorFlow checkpoint
             cache_dir: an optional path to a folder in which the pre-trained models will be cached.
             state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
-            *inputs, **kwargs: additional input for the specific Bert class
-                (ex: num_labels for BertForSequenceClassification)
+            *inputs, **kwargs: additional input for the specific TransformerXL class
         """
+        state_dict = kwargs.get('state_dict', None)
+        kwargs.pop('state_dict', None)
+        cache_dir = kwargs.get('cache_dir', None)
+        kwargs.pop('cache_dir', None)
+        from_tf = kwargs.get('from_tf', False)
+        kwargs.pop('from_tf', None)
         if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
             archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
             config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
@@ -919,16 +921,37 @@ class TransfoXLPreTrainedModel(nn.Module):
         # redirect to the cache, if necessary
         try:
             resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
+        except EnvironmentError:
+            if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
+                logger.error(
+                    "Couldn't reach server at '{}' to download pretrained weights.".format(
+                        archive_file))
+            else:
+                logger.error(
+                    "Model name '{}' was not found in model name list ({}). "
+                    "We assumed '{}' was a path or url but couldn't find file {} "
+                    "at this path or url.".format(
+                        pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
+                        archive_file
+                    )
+                )
+            return None
+        try:
             resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
         except EnvironmentError:
-            logger.error(
-                "Model name '{}' was not found in model name list ({}). "
-                "We assumed '{}' was a path or url but couldn't find files {} and {} "
-                "at this path or url.".format(
-                    pretrained_model_name_or_path,
-                    ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
-                    pretrained_model_name_or_path,
-                    archive_file, config_file))
+            if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
+                logger.error(
+                    "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
+                        config_file))
+            else:
+                logger.error(
+                    "Model name '{}' was not found in model name list ({}). "
+                    "We assumed '{}' was a path or url but couldn't find file {} "
+                    "at this path or url.".format(
+                        pretrained_model_name_or_path, ", ".join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
+                        config_file
+                    )
+                )
             return None
         if resolved_archive_file == archive_file and resolved_config_file == config_file:
             logger.info("loading weights file {}".format(archive_file))
...
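The new `from_pretrained` signature above drops the explicit `state_dict`/`cache_dir`/`from_tf` keywords in favour of `*inputs, **kwargs` plus manual pops, since Python 2 does not allow keyword-only arguments after `*args`. A minimal sketch of the pattern, with illustrative standalone names rather than the library's:

```python
# Minimal sketch of the kwargs-popping pattern used above: Python 2 has no
# keyword-only arguments, so options arriving after *inputs are pulled out of
# **kwargs by hand and the remainder is forwarded untouched.
def from_pretrained(name_or_path, *inputs, **kwargs):
    # kwargs.pop(key, default) reads and removes the key in one call, so the
    # two-step get/pop in the diff can be collapsed to a single pop.
    state_dict = kwargs.pop('state_dict', None)
    cache_dir = kwargs.pop('cache_dir', None)
    from_tf = kwargs.pop('from_tf', False)
    # whatever is left in kwargs is meant for the model constructor
    return name_or_path, state_dict, cache_dir, from_tf, inputs, kwargs

# usage: known keys are consumed, unknown keys pass through
print(from_pretrained('transfo-xl-wt103', cache_dir='/tmp/cache', dropout=0.1))
```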
@@ -114,10 +114,10 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
             logit = self._compute_logit(hidden, self.out_layers[0].weight,
                                         self.out_layers[0].bias, self.out_projs[0])
             if target is not None:
-                output = -F.log_softmax(logit, dim=-1) \
+                out = -F.log_softmax(logit, dim=-1) \
                     .gather(1, target.unsqueeze(1)).squeeze(1)
             else:
-                output = F.log_softmax(logit, dim=-1)
+                out = F.log_softmax(logit, dim=-1)
         else:
             # construct weights and biases
             weights, biases = [], []
...
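For context on the renamed `out` tensor: when a target is given, the branch gathers each row's target column out of the log-softmax and negates it, i.e. a per-token negative log-likelihood. A self-contained sketch with assumed toy shapes:

```python
# Self-contained sketch of the gather-based loss in the hunk above: for each
# row, pick the log-probability of its target id and negate it. This equals
# F.nll_loss(log_probs, target, reduction='none').
import torch
import torch.nn.functional as F

logit = torch.randn(4, 10)           # [batch, vocab]
target = torch.randint(0, 10, (4,))  # [batch]

out = -F.log_softmax(logit, dim=-1) \
    .gather(1, target.unsqueeze(1)).squeeze(1)  # [batch]

assert torch.allclose(
    out, F.nll_loss(F.log_softmax(logit, dim=-1), target, reduction='none'))
```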
@@ -34,6 +34,9 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = {
     'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt",
     'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
     'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
+    'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt",
+    'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt",
+    'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt",
 }
 PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
     'bert-base-uncased': 512,
@@ -43,6 +46,9 @@ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
     'bert-base-multilingual-uncased': 512,
     'bert-base-multilingual-cased': 512,
     'bert-base-chinese': 512,
+    'bert-base-german-cased': 512,
+    'bert-large-uncased-whole-word-masking': 512,
+    'bert-large-cased-whole-word-masking': 512,
 }
 VOCAB_NAME = 'vocab.txt'
@@ -175,13 +181,18 @@ class BertTokenizer(object):
         try:
             resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
         except EnvironmentError:
-            logger.error(
-                "Model name '{}' was not found in model name list ({}). "
-                "We assumed '{}' was a path or url but couldn't find any file "
-                "associated to this path or url.".format(
-                    pretrained_model_name_or_path,
-                    ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
-                    vocab_file))
+            if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
+                logger.error(
+                    "Couldn't reach server at '{}' to download vocabulary.".format(
+                        vocab_file))
+            else:
+                logger.error(
+                    "Model name '{}' was not found in model name list ({}). "
+                    "We assumed '{}' was a path or url but couldn't find any file "
+                    "associated to this path or url.".format(
+                        pretrained_model_name_or_path,
+                        ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
+                        vocab_file))
             return None
         if resolved_vocab_file == vocab_file:
             logger.info("loading vocabulary file {}".format(vocab_file))
...
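This hunk shows the change this merge applies to every loader in the tree: the old catch-all `EnvironmentError` message is split so that a known shortcut name reports a network failure, while an unknown name reports a bad path or url. A hedged sketch of the branching, with stand-in names (not the library's):

```python
# Sketch of the error-reporting split applied throughout this merge: a known
# model shortcut that fails to resolve implies a network problem, while an
# unknown name was presumably a bad path or url. All names are stand-ins.
import logging

logger = logging.getLogger(__name__)
KNOWN_SHORTCUTS = {'bert-base-uncased': 'https://example.com/vocab.txt'}  # stand-in map

def resolve(name_or_path, fetch):
    url_or_path = KNOWN_SHORTCUTS.get(name_or_path, name_or_path)
    try:
        return fetch(url_or_path)
    except EnvironmentError:
        if name_or_path in KNOWN_SHORTCUTS:
            logger.error("Couldn't reach server at '%s' to download vocabulary.", url_or_path)
        else:
            logger.error("Model name '%s' not in shortcut list (%s); assumed it was a "
                         "path or url but found nothing there.",
                         name_or_path, ', '.join(KNOWN_SHORTCUTS))
        return None
```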
@@ -37,9 +37,11 @@ logger = logging.getLogger(__name__)
 PRETRAINED_VOCAB_ARCHIVE_MAP = {
     'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
+    'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
 }
 PRETRAINED_MERGES_ARCHIVE_MAP = {
     'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
+    'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
 }
 PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
     'gpt2': 1024,
@@ -91,7 +93,7 @@ class GPT2Tokenizer(object):
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
         """
-        Instantiate a PreTrainedBertModel from a pre-trained model file.
+        Instantiate a GPT2Tokenizer from a pre-trained model file.
         Download and cache the pre-trained model file if needed.
         """
         if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
@@ -111,14 +113,19 @@ class GPT2Tokenizer(object):
             resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
             resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
         except EnvironmentError:
-            logger.error(
-                "Model name '{}' was not found in model name list ({}). "
-                "We assumed '{}' was a path or url but couldn't find files {} and {} "
-                "at this path or url.".format(
-                    pretrained_model_name_or_path,
-                    ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
-                    pretrained_model_name_or_path,
-                    vocab_file, merges_file))
+            if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
+                logger.error(
+                    "Couldn't reach server at '{}' to download vocabulary.".format(
+                        vocab_file))
+            else:
+                logger.error(
+                    "Model name '{}' was not found in model name list ({}). "
+                    "We assumed '{}' was a path or url but couldn't find files {} and {} "
+                    "at this path or url.".format(
+                        pretrained_model_name_or_path,
+                        ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
+                        pretrained_model_name_or_path,
+                        vocab_file, merges_file))
             return None
         if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
             logger.info("loading vocabulary file {}".format(vocab_file))
@@ -263,9 +270,14 @@ class GPT2Tokenizer(object):
     def encode(self, text):
         return self.convert_tokens_to_ids(self.tokenize(text))
-    def decode(self, tokens):
-        text = ''.join([self.decoder[token] for token in tokens])
+    def decode(self, tokens, skip_special_tokens=False, clean_up_tokenization_spaces=True):
+        text = ''.join(self.convert_ids_to_tokens(tokens, skip_special_tokens=skip_special_tokens))
         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
+        if clean_up_tokenization_spaces:
+            text = text.replace('<unk>', '')
+            text = text.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
+                    ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
+                    ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
         return text
     def save_vocabulary(self, vocab_path):
...
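On the reworked `decode` above: GPT-2 token strings are unicode stand-ins for raw bytes, so decoding joins the token strings, maps each character back to a byte via `byte_decoder`, then UTF-8-decodes; the new `clean_up_tokenization_spaces` flag additionally un-spaces punctuation. A toy round-trip under that scheme (the identity byte table below is a stand-in for the tokenizer's real one):

```python
# Toy round-trip of the byte-level scheme used by GPT2Tokenizer.decode above:
# token strings are sequences of printable unicode stand-ins for bytes, so
# decoding is join -> map chars back to bytes -> utf-8 decode. The identity
# table here is a stand-in; the real byte_encoder remaps unprintable bytes.
byte_encoder = {b: chr(b) for b in range(256)}
byte_decoder = {c: b for b, c in byte_encoder.items()}

tokens = ['Hel', 'lo', ' wor', 'ld', ' !']            # pretend BPE output
text = ''.join(tokens)
text = bytearray([byte_decoder[c] for c in text]).decode('utf-8', errors='replace')
text = text.replace(' !', '!')                        # the punctuation clean-up step
print(text)  # -> "Hello world!"
```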
@@ -101,14 +101,19 @@ class OpenAIGPTTokenizer(object):
             resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
             resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
         except EnvironmentError:
-            logger.error(
-                "Model name '{}' was not found in model name list ({}). "
-                "We assumed '{}' was a path or url but couldn't find files {} and {} "
-                "at this path or url.".format(
-                    pretrained_model_name_or_path,
-                    ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
-                    pretrained_model_name_or_path,
-                    vocab_file, merges_file))
+            if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
+                logger.error(
+                    "Couldn't reach server at '{}' to download vocabulary.".format(
+                        vocab_file))
+            else:
+                logger.error(
+                    "Model name '{}' was not found in model name list ({}). "
+                    "We assumed '{}' was a path or url but couldn't find files {} and {} "
+                    "at this path or url.".format(
+                        pretrained_model_name_or_path,
+                        ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
+                        pretrained_model_name_or_path,
+                        vocab_file, merges_file))
             return None
         if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
             logger.info("loading vocabulary file {}".format(vocab_file))
@@ -272,7 +277,7 @@ class OpenAIGPTTokenizer(object):
         out_string = ''.join(tokens).replace('</w>', ' ').strip()
         if clean_up_tokenization_spaces:
             out_string = out_string.replace('<unk>', '')
-            out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',').replace(' ,', ','
+            out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
                 ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
                 ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
         return out_string
...
@@ -71,14 +71,19 @@ class TransfoXLTokenizer(object):
         try:
             resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
         except EnvironmentError:
-            logger.error(
-                "Model name '{}' was not found in model name list ({}). "
-                "We assumed '{}' was a path or url but couldn't find files {} "
-                "at this path or url.".format(
-                    pretrained_model_name_or_path,
-                    ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
-                    pretrained_model_name_or_path,
-                    vocab_file))
+            if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
+                logger.error(
+                    "Couldn't reach server at '{}' to download vocabulary.".format(
+                        vocab_file))
+            else:
+                logger.error(
+                    "Model name '{}' was not found in model name list ({}). "
+                    "We assumed '{}' was a path or url but couldn't find files {} "
+                    "at this path or url.".format(
+                        pretrained_model_name_or_path,
+                        ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
+                        pretrained_model_name_or_path,
+                        vocab_file))
             return None
         if resolved_vocab_file == vocab_file:
             logger.info("loading vocabulary file {}".format(vocab_file))
...
@@ -41,6 +41,7 @@ class GPT2ModelTest(unittest.TestCase):
                  use_token_type_ids=True,
                  use_labels=True,
                  vocab_size=99,
+                 n_special=1,
                  n_positions=33,
                  n_embd=32,
                  n_layer=5,
@@ -58,6 +59,7 @@ class GPT2ModelTest(unittest.TestCase):
         self.use_token_type_ids = use_token_type_ids
         self.use_labels = use_labels
         self.vocab_size = vocab_size
+        self.n_special = n_special
         self.n_positions = n_positions
         self.n_embd = n_embd
         self.n_layer = n_layer
@@ -69,7 +71,8 @@ class GPT2ModelTest(unittest.TestCase):
         self.scope = scope
     def prepare_config_and_inputs(self):
-        input_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.vocab_size)
+        total_num_tokens = self.vocab_size + self.n_special
+        input_ids = GPT2ModelTest.ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_num_tokens)
         position_ids = None
         if self.use_position_ids:
@@ -90,6 +93,7 @@ class GPT2ModelTest(unittest.TestCase):
         config = GPT2Config(
             vocab_size_or_config_json_file=self.vocab_size,
+            n_special=self.n_special,
             n_positions=self.n_positions,
             n_embd=self.n_embd,
             n_layer=self.n_layer,
@@ -111,8 +115,9 @@ class GPT2ModelTest(unittest.TestCase):
         return outputs
     def check_gpt2_model_output(self, result):
+        self.parent.assertEqual(len(result["hidden_states"]), self.n_layer + 1)
         self.parent.assertListEqual(
-            list(result["hidden_states"].size()),
+            list(result["hidden_states"][0].size()),
             [self.batch_size, self.n_choices, self.seq_length, self.n_embd])
@@ -129,11 +134,29 @@ class GPT2ModelTest(unittest.TestCase):
         }
         return outputs
+    def create_gpt2_lm_head_with_output_attention(self, config, input_ids, token_type_ids, position_ids,
+                                                  mc_labels, lm_labels, mc_token_ids):
+        model = GPT2LMHeadModel(config, output_attentions=True)
+        model.eval()
+        loss = model(input_ids, position_ids, token_type_ids, lm_labels)
+        attentions, lm_logits, presents = model(input_ids, position_ids, token_type_ids)
+        outputs = {
+            "loss": loss,
+            "lm_logits": lm_logits,
+            "presents": presents,
+            "attentions": attentions,
+        }
+        return outputs
     def check_gpt2_lm_head_output(self, result):
-        total_voc = self.vocab_size
+        total_voc = self.n_special + self.vocab_size
         self.parent.assertListEqual(
             list(result["lm_logits"].size()),
             [self.batch_size, self.n_choices, self.seq_length, total_voc])
+        self.parent.assertEqual(self.n_layer, len(result["presents"]))
+        self.parent.assertListEqual(
+            list(result["presents"][0].size()),
+            [2, self.batch_size * self.n_choices, self.n_head, self.seq_length, self.n_embd // self.n_head])
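The new `presents` assertions pin down the per-layer cache layout, `[2, batch, n_head, seq_len, head_dim]` with keys and values stacked. The reason the model returns them is incremental decoding: fed back as `past`, each step only attends over the one new token. A hedged sketch of that loop, assuming the `model(input_ids, past=...) -> (lm_logits, presents)` convention these tests suggest; treat it as pseudocode if the local signature differs:

```python
# Sketch of why `presents` is worth testing: it is the key/value cache for
# incremental decoding. Assumes this codebase's GPT2LMHeadModel forward,
# i.e. model(input_ids, past=...) -> (lm_logits, presents).
import torch

def greedy_steps(model, input_ids, n_steps):
    past = None
    for _ in range(n_steps):
        lm_logits, past = model(input_ids, past=past)   # past: list of [2, b, h, t, d]
        next_token = lm_logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = next_token                          # only the new token is fed next
    return past
```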
     def check_gpt2_lm_head_loss_output(self, result):
         self.parent.assertListEqual(
@@ -156,8 +179,25 @@ class GPT2ModelTest(unittest.TestCase):
         }
         return outputs
+    def create_gpt2_double_heads_with_output_attention(self, config, input_ids, token_type_ids, position_ids,
+                                                       mc_labels, lm_labels, mc_token_ids):
+        model = GPT2DoubleHeadsModel(config, output_attentions=True)
+        model.eval()
+        loss = model(input_ids, mc_token_ids,
+                     lm_labels=lm_labels, mc_labels=mc_labels,
+                     token_type_ids=token_type_ids, position_ids=position_ids)
+        attentions, lm_logits, mc_logits, presents = model(input_ids, mc_token_ids, position_ids=position_ids, token_type_ids=token_type_ids)
+        outputs = {
+            "loss": loss,
+            "lm_logits": lm_logits,
+            "mc_logits": mc_logits,
+            "presents": presents,
+            "attentions": attentions,
+        }
+        return outputs
     def check_gpt2_double_heads_output(self, result):
-        total_voc = self.vocab_size
+        total_voc = self.n_special + self.vocab_size
         self.parent.assertListEqual(
             list(result["lm_logits"].size()),
             [self.batch_size, self.n_choices, self.seq_length, total_voc])
@@ -170,6 +210,98 @@ class GPT2ModelTest(unittest.TestCase):
             [list(l.size()) for l in result["loss"]],
             [[], []])
+    def create_and_check_gpt2_for_headmasking(self, config, input_ids, token_type_ids, position_ids,
+                                              mc_labels, lm_labels, mc_token_ids):
+        for model_class in (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel):
+            model = model_class(config=config, keep_multihead_output=True)
+            model.eval()
+            head_mask = torch.zeros(self.n_layer, self.n_head).to(input_ids.device)
+            head_mask[0, 1:-1] = 1.0  # Mask all but the first and last heads on the first layer
+            head_mask[-1, 1:] = 1.0  # Mask all but the first head on the last layer
+            if isinstance(model, GPT2DoubleHeadsModel):
+                output = model(input_ids, mc_token_ids, head_mask=head_mask)
+            else:
+                output = model(input_ids, head_mask=head_mask)
+            if isinstance(model, GPT2Model):
+                output = sum(t.sum() for t in output[0])
+            elif isinstance(output, (list, tuple)):
+                output = sum(t.sum() for t in output[:-1])
+            output = output.sum()
+            output.backward()
+            multihead_outputs = (model if isinstance(model, GPT2Model) else model.transformer).get_multihead_outputs()
+            self.parent.assertEqual(len(multihead_outputs), self.n_layer)
+            self.parent.assertListEqual(
+                list(multihead_outputs[0].size()),
+                [self.batch_size * self.n_choices, self.n_head,
+                 self.seq_length, self.n_embd // self.n_head])
+            self.parent.assertEqual(
+                len(multihead_outputs[0][:, 1:(self.n_head-1), :, :].nonzero()),
+                0)
+            self.parent.assertEqual(
+                len(multihead_outputs[0][:, 0, :, :].nonzero()),
+                self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
+            self.parent.assertEqual(
+                len(multihead_outputs[0][:, self.n_head-1, :, :].nonzero()),
+                self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
+            self.parent.assertListEqual(
+                list(multihead_outputs[1].size()),
+                [self.batch_size * self.n_choices, self.n_head,
+                 self.seq_length, self.n_embd // self.n_head])
+            self.parent.assertEqual(
+                len(multihead_outputs[1].nonzero()),
+                multihead_outputs[1].numel())
+            self.parent.assertListEqual(
+                list(multihead_outputs[-1].size()),
+                [self.batch_size * self.n_choices, self.n_head,
+                 self.seq_length, self.n_embd // self.n_head])
+            self.parent.assertEqual(
+                len(multihead_outputs[-1][:, 1:, :, :].nonzero()),
+                0)
+            self.parent.assertEqual(
+                len(multihead_outputs[-1][:, 0, :, :].nonzero()),
+                self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
+    def create_and_check_gpt2_for_head_pruning(self, config, input_ids, token_type_ids, position_ids,
+                                               mc_labels, lm_labels, mc_token_ids):
+        for model_class in (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel):
+            model = model_class(config=config, keep_multihead_output=True)
+            model.eval()
+            transformer = model if isinstance(model, GPT2Model) else model.transformer
+            heads_to_prune = {0: list(range(1, self.n_head)),
+                              -1: [0]}
+            transformer.prune_heads(heads_to_prune)
+            if isinstance(model, GPT2DoubleHeadsModel):
+                output = model(input_ids, mc_token_ids)
+            else:
+                output = model(input_ids)
+            if isinstance(model, GPT2Model):
+                output = sum(t.sum() for t in output[0])
+            elif isinstance(output, (list, tuple)):
+                output = sum(t.sum() for t in output[:-1])
+            output = output.sum()
+            output.backward()
+            multihead_outputs = transformer.get_multihead_outputs()
+            self.parent.assertEqual(len(multihead_outputs), self.n_layer)
+            self.parent.assertListEqual(
+                list(multihead_outputs[0].size()),
+                [self.batch_size * self.n_choices, 1,
+                 self.seq_length, self.n_embd // self.n_head])
+            self.parent.assertListEqual(
+                list(multihead_outputs[1].size()),
+                [self.batch_size * self.n_choices, self.n_head,
+                 self.seq_length, self.n_embd // self.n_head])
+            self.parent.assertListEqual(
+                list(multihead_outputs[-1].size()),
+                [self.batch_size * self.n_choices, self.n_head-1,
+                 self.seq_length, self.n_embd // self.n_head])
     def test_default(self):
         self.run_tester(GPT2ModelTest.GPT2ModelTester(self))
@@ -208,6 +340,9 @@ class GPT2ModelTest(unittest.TestCase):
         tester.check_gpt2_double_heads_output(output_result)
         tester.check_gpt2_double_heads_loss_output(output_result)
+        tester.create_and_check_gpt2_for_headmasking(*config_and_inputs)
+        tester.create_and_check_gpt2_for_head_pruning(*config_and_inputs)
     @classmethod
     def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
         """Creates a random int32 tensor of the shape within the vocab size."""
...
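The two helpers above exercise the head-masking and head-pruning APIs this merge adds (mirrored below for OpenAI GPT and BERT): a `[n_layer, n_head]` mask silences heads at run time, `prune_heads` removes their parameters outright, and `keep_multihead_output=True` keeps per-head activations around so the tests can verify which heads contributed. A hedged usage sketch; note that in these tests a 1.0 entry appears to mark a head to silence, so verify the convention against the local implementation:

```python
# Hedged usage sketch of the APIs exercised by the new tests. Assumes the
# constructor flag and methods shown in the diff (keep_multihead_output,
# head_mask=..., prune_heads) and the import path these tests use.
import torch
from pytorch_pretrained_bert import GPT2Config, GPT2Model

config = GPT2Config(vocab_size_or_config_json_file=99, n_positions=33,
                    n_embd=32, n_layer=5, n_head=4)
model = GPT2Model(config, keep_multihead_output=True)
model.eval()

input_ids = torch.randint(0, 99, (13, 7))
head_mask = torch.zeros(config.n_layer, config.n_head)
head_mask[0, 1:-1] = 1.0                    # silence the middle heads of layer 0
hidden_states, presents = model(input_ids, head_mask=head_mask)

model.prune_heads({0: [1, 2]})              # remove those heads' parameters outright
hidden_states, presents = model(input_ids)  # smaller attention layers, same output shape
```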
@@ -125,8 +125,9 @@ class OpenAIGPTModelTest(unittest.TestCase):
         return outputs
     def check_openai_model_output(self, result):
+        self.parent.assertEqual(len(result["hidden_states"]), self.n_layer + 1)
         self.parent.assertListEqual(
-            list(result["hidden_states"].size()),
+            list(result["hidden_states"][0].size()),
             [self.batch_size, self.n_choices, self.seq_length, self.n_embd])
@@ -182,6 +183,99 @@ class OpenAIGPTModelTest(unittest.TestCase):
             [list(l.size()) for l in result["loss"]],
             [[], []])
+    def create_and_check_openai_for_headmasking(self, config, input_ids, token_type_ids, position_ids,
+                                                mc_labels, lm_labels, mc_token_ids):
+        for model_class in (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel):
+            model = model_class(config=config, keep_multihead_output=True)
+            model.eval()
+            head_mask = torch.zeros(self.n_layer, self.n_head).to(input_ids.device)
+            head_mask[0, 1:-1] = 1.0  # Mask all but the first and last heads on the first layer
+            head_mask[-1, 1:] = 1.0  # Mask all but the first head on the last layer
+            if isinstance(model, OpenAIGPTDoubleHeadsModel):
+                output = model(input_ids, mc_token_ids, head_mask=head_mask)
+            else:
+                output = model(input_ids, head_mask=head_mask)
+            if isinstance(model, OpenAIGPTModel):
+                output = sum(t.sum() for t in output[0])
+            elif isinstance(output, (list, tuple)):
+                output = sum(t.sum() for t in output)
+            output = output.sum()
+            output.backward()
+            multihead_outputs = (model if isinstance(model, OpenAIGPTModel) else model.transformer).get_multihead_outputs()
+            self.parent.assertEqual(len(multihead_outputs), self.n_layer)
+            self.parent.assertListEqual(
+                list(multihead_outputs[0].size()),
+                [self.batch_size * self.n_choices, self.n_head,
+                 self.seq_length, self.n_embd // self.n_head])
+            self.parent.assertEqual(
+                len(multihead_outputs[0][:, 1:(self.n_head-1), :, :].nonzero()),
+                0)
+            self.parent.assertEqual(
+                len(multihead_outputs[0][:, 0, :, :].nonzero()),
+                self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
+            self.parent.assertEqual(
+                len(multihead_outputs[0][:, self.n_head-1, :, :].nonzero()),
+                self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
+            self.parent.assertListEqual(
+                list(multihead_outputs[1].size()),
+                [self.batch_size * self.n_choices, self.n_head,
+                 self.seq_length, self.n_embd // self.n_head])
+            self.parent.assertEqual(
+                len(multihead_outputs[1].nonzero()),
+                multihead_outputs[1].numel())
+            self.parent.assertListEqual(
+                list(multihead_outputs[-1].size()),
+                [self.batch_size * self.n_choices, self.n_head,
+                 self.seq_length, self.n_embd // self.n_head])
+            self.parent.assertEqual(
+                len(multihead_outputs[-1][:, 1:, :, :].nonzero()),
+                0)
+            self.parent.assertEqual(
+                len(multihead_outputs[-1][:, 0, :, :].nonzero()),
+                self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
+    def create_and_check_openai_for_head_pruning(self, config, input_ids, token_type_ids, position_ids,
+                                                 mc_labels, lm_labels, mc_token_ids):
+        for model_class in (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel):
+            model = model_class(config=config, keep_multihead_output=True)
+            model.eval()
+            transformer = model if isinstance(model, OpenAIGPTModel) else model.transformer
+            heads_to_prune = {0: list(range(1, self.n_head)),
+                              -1: [0]}
+            transformer.prune_heads(heads_to_prune)
+            if isinstance(model, OpenAIGPTDoubleHeadsModel):
+                output = model(input_ids, mc_token_ids)
+            else:
+                output = model(input_ids)
+            if isinstance(model, OpenAIGPTModel):
+                output = sum(t.sum() for t in output[0])
+            elif isinstance(output, (list, tuple)):
+                output = sum(t.sum() for t in output)
+            output = output.sum()
+            output.backward()
+            multihead_outputs = transformer.get_multihead_outputs()
+            self.parent.assertEqual(len(multihead_outputs), self.n_layer)
+            self.parent.assertListEqual(
+                list(multihead_outputs[0].size()),
+                [self.batch_size * self.n_choices, 1,
+                 self.seq_length, self.n_embd // self.n_head])
+            self.parent.assertListEqual(
+                list(multihead_outputs[1].size()),
+                [self.batch_size * self.n_choices, self.n_head,
+                 self.seq_length, self.n_embd // self.n_head])
+            self.parent.assertListEqual(
+                list(multihead_outputs[-1].size()),
+                [self.batch_size * self.n_choices, self.n_head-1,
+                 self.seq_length, self.n_embd // self.n_head])
     def test_default(self):
         self.run_tester(OpenAIGPTModelTest.OpenAIGPTModelTester(self))
@@ -220,6 +314,9 @@ class OpenAIGPTModelTest(unittest.TestCase):
         tester.check_openai_double_heads_output(output_result)
         tester.check_openai_double_heads_loss_output(output_result)
+        tester.create_and_check_openai_for_headmasking(*config_and_inputs)
+        tester.create_and_check_openai_for_head_pruning(*config_and_inputs)
     @classmethod
     def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
         """Creates a random int32 tensor of the shape within the vocab size."""
...
@@ -28,7 +28,7 @@ import torch
 from pytorch_pretrained_bert import (BertConfig, BertModel, BertForMaskedLM,
                                      BertForNextSentencePrediction, BertForPreTraining,
                                      BertForQuestionAnswering, BertForSequenceClassification,
-                                     BertForTokenClassification)
+                                     BertForTokenClassification, BertForMultipleChoice)
 from pytorch_pretrained_bert.modeling import PRETRAINED_MODEL_ARCHIVE_MAP
@@ -56,6 +56,7 @@ class BertModelTest(unittest.TestCase):
                  type_sequence_label_size=2,
                  initializer_range=0.02,
                  num_labels=3,
+                 num_choices=4,
                  scope=None):
         self.parent = parent
         self.batch_size = batch_size
@@ -77,6 +78,7 @@ class BertModelTest(unittest.TestCase):
         self.type_sequence_label_size = type_sequence_label_size
         self.initializer_range = initializer_range
         self.num_labels = num_labels
+        self.num_choices = num_choices
         self.scope = scope
     def prepare_config_and_inputs(self):
@@ -92,9 +94,11 @@ class BertModelTest(unittest.TestCase):
         sequence_labels = None
         token_labels = None
+        choice_labels = None
         if self.use_labels:
             sequence_labels = BertModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
             token_labels = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.num_labels)
+            choice_labels = BertModelTest.ids_tensor([self.batch_size], self.num_choices)
         config = BertConfig(
             vocab_size_or_config_json_file=self.vocab_size,
@@ -109,14 +113,14 @@ class BertModelTest(unittest.TestCase):
             type_vocab_size=self.type_vocab_size,
             initializer_range=self.initializer_range)
-        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels
+        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
     def check_loss_output(self, result):
         self.parent.assertListEqual(
             list(result["loss"].size()),
             [])
-    def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+    def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
         model = BertModel(config=config)
         model.eval()
         all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
@@ -137,7 +141,7 @@ class BertModelTest(unittest.TestCase):
         self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
-    def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+    def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
         model = BertForMaskedLM(config=config)
         model.eval()
         loss = model(input_ids, token_type_ids, input_mask, token_labels)
@@ -153,7 +157,7 @@ class BertModelTest(unittest.TestCase):
             list(result["prediction_scores"].size()),
             [self.batch_size, self.seq_length, self.vocab_size])
-    def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+    def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
         model = BertForNextSentencePrediction(config=config)
         model.eval()
         loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
@@ -170,7 +174,7 @@ class BertModelTest(unittest.TestCase):
             [self.batch_size, 2])
-    def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+    def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
         model = BertForPreTraining(config=config)
         model.eval()
         loss = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels)
@@ -191,7 +195,7 @@ class BertModelTest(unittest.TestCase):
             [self.batch_size, 2])
-    def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+    def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
         model = BertForQuestionAnswering(config=config)
         model.eval()
         loss = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels)
@@ -212,7 +216,7 @@ class BertModelTest(unittest.TestCase):
             [self.batch_size, self.seq_length])
-    def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+    def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
         model = BertForSequenceClassification(config=config, num_labels=self.num_labels)
         model.eval()
         loss = model(input_ids, token_type_ids, input_mask, sequence_labels)
@@ -229,7 +233,7 @@ class BertModelTest(unittest.TestCase):
             [self.batch_size, self.num_labels])
-    def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels):
+    def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
         model = BertForTokenClassification(config=config, num_labels=self.num_labels)
         model.eval()
         loss = model(input_ids, token_type_ids, input_mask, token_labels)
@@ -246,6 +250,150 @@ class BertModelTest(unittest.TestCase):
             [self.batch_size, self.seq_length, self.num_labels])
+    def create_bert_for_multiple_choice(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+        model = BertForMultipleChoice(config=config, num_choices=self.num_choices)
+        model.eval()
+        multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+        multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+        multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+        loss = model(multiple_choice_inputs_ids,
+                     multiple_choice_token_type_ids,
+                     multiple_choice_input_mask,
+                     choice_labels)
+        logits = model(multiple_choice_inputs_ids,
+                       multiple_choice_token_type_ids,
+                       multiple_choice_input_mask)
+        outputs = {
+            "loss": loss,
+            "logits": logits,
+        }
+        return outputs
+    def check_bert_for_multiple_choice(self, result):
+        self.parent.assertListEqual(
+            list(result["logits"].size()),
+            [self.batch_size, self.num_choices])
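`BertForMultipleChoice` scores every candidate with the same encoder, so the helper above builds `[batch, num_choices, seq]` inputs by repeating each `[batch, seq]` tensor across a new choices dimension. The shape mechanics in isolation:

```python
# Shape mechanics behind the multiple-choice inputs built above: repeat each
# [batch, seq] tensor across a num_choices dimension; expand() is copy-free
# until .contiguous() materializes the result.
import torch

batch_size, seq_length, num_choices = 13, 7, 4
input_ids = torch.randint(0, 99, (batch_size, seq_length))

mc_input_ids = input_ids.unsqueeze(1).expand(-1, num_choices, -1).contiguous()
print(mc_input_ids.shape)  # torch.Size([13, 4, 7])
# the model flattens this to [batch * num_choices, seq], encodes each choice,
# and produces one score per choice -> logits of shape [batch, num_choices]
```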
+    def create_and_check_bert_for_attentions(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+        for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
+                            BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
+                            BertForTokenClassification):
+            if model_class in [BertForSequenceClassification,
+                               BertForTokenClassification]:
+                model = model_class(config=config, num_labels=self.num_labels, output_attentions=True)
+            else:
+                model = model_class(config=config, output_attentions=True)
+            model.eval()
+            output = model(input_ids, token_type_ids, input_mask)
+            attentions = output[0]
+            self.parent.assertEqual(len(attentions), self.num_hidden_layers)
+            self.parent.assertListEqual(
+                list(attentions[0].size()),
+                [self.batch_size, self.num_attention_heads, self.seq_length, self.seq_length])
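Per the assertions above, `output_attentions=True` prepends a list of per-layer attention maps, one `[batch, heads, seq, seq]` tensor per layer, to the model outputs. A hedged sketch of reading them, assuming this codebase's constructor flag and output ordering:

```python
# Hedged sketch of reading per-layer attention maps, mirroring the test above.
# Assumes output_attentions=True prepends the attentions list to the outputs.
import torch
from pytorch_pretrained_bert import BertConfig, BertModel

config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=32,
                    num_hidden_layers=5, num_attention_heads=4, intermediate_size=37)
model = BertModel(config, output_attentions=True)
model.eval()

input_ids = torch.randint(0, 99, (13, 7))
output = model(input_ids)
attentions = output[0]      # list: one tensor per layer
print(len(attentions))      # 5
print(attentions[0].shape)  # [13, 4, 7, 7] = [batch, heads, seq, seq]
```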
+    def create_and_check_bert_for_headmasking(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+        for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
+                            BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
+                            BertForTokenClassification):
+            if model_class in [BertForSequenceClassification,
+                               BertForTokenClassification]:
+                model = model_class(config=config,
+                                    num_labels=self.num_labels,
+                                    keep_multihead_output=True)
+            else:
+                model = model_class(config=config, keep_multihead_output=True)
+            model.eval()
+            head_mask = torch.zeros(self.num_hidden_layers, self.num_attention_heads).to(input_ids.device)
+            head_mask[0, 1:-1] = 1.0  # Mask all but the first and last heads on the first layer
+            head_mask[-1, 1:] = 1.0  # Mask all but the first head on the last layer
+            output = model(input_ids, token_type_ids, input_mask, head_mask=head_mask)
+            if isinstance(model, BertModel):
+                output = sum(t.sum() for t in output[0])
+            elif isinstance(output, (list, tuple)):
+                output = sum(t.sum() for t in output)
+            output = output.sum()
+            output.backward()
+            multihead_outputs = (model if isinstance(model, BertModel) else model.bert).get_multihead_outputs()
+            self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
+            self.parent.assertListEqual(
+                list(multihead_outputs[0].size()),
+                [self.batch_size, self.num_attention_heads,
+                 self.seq_length, self.hidden_size // self.num_attention_heads])
+            self.parent.assertEqual(
+                len(multihead_outputs[0][:, 1:(self.num_attention_heads-1), :, :].nonzero()),
+                0)
+            self.parent.assertEqual(
+                len(multihead_outputs[0][:, 0, :, :].nonzero()),
+                self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
+            self.parent.assertEqual(
+                len(multihead_outputs[0][:, self.num_attention_heads-1, :, :].nonzero()),
+                self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
+            self.parent.assertListEqual(
+                list(multihead_outputs[1].size()),
+                [self.batch_size, self.num_attention_heads,
+                 self.seq_length, self.hidden_size // self.num_attention_heads])
+            self.parent.assertEqual(
+                len(multihead_outputs[1].nonzero()),
+                multihead_outputs[1].numel())
+            self.parent.assertListEqual(
+                list(multihead_outputs[-1].size()),
+                [self.batch_size, self.num_attention_heads,
+                 self.seq_length, self.hidden_size // self.num_attention_heads])
+            self.parent.assertEqual(
+                len(multihead_outputs[-1][:, 1:, :, :].nonzero()),
+                0)
+            self.parent.assertEqual(
+                len(multihead_outputs[-1][:, 0, :, :].nonzero()),
+                self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
+    def create_and_check_bert_for_head_pruning(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+        for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
+                            BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
+                            BertForTokenClassification):
+            if model_class in [BertForSequenceClassification,
+                               BertForTokenClassification]:
+                model = model_class(config=config,
+                                    num_labels=self.num_labels,
+                                    keep_multihead_output=True)
+            else:
+                model = model_class(config=config, keep_multihead_output=True)
+            model.eval()
+            bert_model = model if isinstance(model, BertModel) else model.bert
+            heads_to_prune = {0: list(range(1, self.num_attention_heads)),
+                              -1: [0]}
+            bert_model.prune_heads(heads_to_prune)
+            output = model(input_ids, token_type_ids, input_mask)
+            if isinstance(model, BertModel):
+                output = sum(t.sum() for t in output[0])
+            elif isinstance(output, (list, tuple)):
+                output = sum(t.sum() for t in output)
+            output = output.sum()
+            output.backward()
+            multihead_outputs = bert_model.get_multihead_outputs()
+            self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
+            self.parent.assertListEqual(
+                list(multihead_outputs[0].size()),
+                [self.batch_size, 1,
+                 self.seq_length, self.hidden_size // self.num_attention_heads])
+            self.parent.assertListEqual(
+                list(multihead_outputs[1].size()),
+                [self.batch_size, self.num_attention_heads,
+                 self.seq_length, self.hidden_size // self.num_attention_heads])
+            self.parent.assertListEqual(
+                list(multihead_outputs[-1].size()),
+                [self.batch_size, self.num_attention_heads-1,
+                 self.seq_length, self.hidden_size // self.num_attention_heads])
     def test_default(self):
         self.run_tester(BertModelTest.BertModelTester(self))
@@ -300,6 +448,14 @@ class BertModelTest(unittest.TestCase):
         tester.check_bert_for_token_classification_output(output_result)
         tester.check_loss_output(output_result)
+        output_result = tester.create_bert_for_multiple_choice(*config_and_inputs)
+        tester.check_bert_for_multiple_choice(output_result)
+        tester.check_loss_output(output_result)
+        tester.create_and_check_bert_for_attentions(*config_and_inputs)
+        tester.create_and_check_bert_for_headmasking(*config_and_inputs)
+        tester.create_and_check_bert_for_head_pruning(*config_and_inputs)
     @classmethod
     def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
         """Creates a random int32 tensor of the shape within the vocab size."""
...