Commit 1484d67d authored by thomwolf

[LARGE] updating all tests and API

parent 4f8b5f68
@@ -41,6 +41,12 @@ class PretrainedConfig(object):
    """
    pretrained_config_archive_map = {}

+   def __init__(self, **kwargs):
+       self.finetuning_task = kwargs.pop('finetuning_task', None)
+       self.num_labels = kwargs.pop('num_labels', 2)
+       self.output_attentions = kwargs.pop('output_attentions', False)
+       self.output_hidden_states = kwargs.pop('output_hidden_states', False)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
@@ -114,6 +120,9 @@ class PretrainedConfig(object):
            text = reader.read()
        return cls.from_dict(json.loads(text))

+   def __eq__(self, other):
+       return self.__dict__ == other.__dict__

    def __repr__(self):
        return str(self.to_json_string())
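The two additions above mean shared options now flow through `PretrainedConfig.__init__` and configs can be compared directly. A minimal sketch of how this is expected to behave, assuming the config classes are importable as below (the import path is an assumption):

```
import json
# Import path is an assumption; adjust to where BertConfig lives at this commit.
from pytorch_transformers.modeling_bert import BertConfig

# Shared kwargs (finetuning_task, num_labels, output_attentions, output_hidden_states)
# are popped by PretrainedConfig.__init__ and become config attributes.
config = BertConfig(num_labels=3, output_attentions=True)
print(config.num_labels, config.output_attentions)  # 3 True

# The new __eq__ compares __dict__, so a JSON round-trip should compare equal.
roundtrip = BertConfig.from_dict(json.loads(config.to_json_string()))
print(roundtrip == config)
```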
@@ -133,12 +142,11 @@ class PretrainedConfig(object):

class PreTrainedModel(nn.Module):
-   """ An abstract class to handle weights initialization and
+   """ An abstract class to handle storing model config and
        a simple interface for dowloading and loading pretrained models.
    """
    config_class = PretrainedConfig
    pretrained_model_archive_map = {}
-   pretrained_config_archive_map = {}
    load_tf_weights = lambda model, config, path: None
    base_model_prefix = ""
@@ -151,8 +159,16 @@ class PreTrainedModel(nn.Module):
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
+       # Save config in model
        self.config = config

+   def prune_heads(self, heads_to_prune):
+       """ Prunes heads of the base model.
+           heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+       """
+       model_to_prune = getattr(self, self.base_model_prefix, self)  # get the base model if needed
+       model_to_prune._prune_heads(heads_to_prune)
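With `prune_heads` on the base class, any task model can prune attention heads without knowing where its backbone lives; the lookup goes through `base_model_prefix`. A sketch under the assumption that the BERT classes are importable as shown:

```
# Import path is an assumption; adjust to where the BERT classes live at this commit.
from pytorch_transformers.modeling_bert import BertConfig, BertForSequenceClassification

model = BertForSequenceClassification(BertConfig(num_labels=2))
# base_model_prefix is "bert", so this resolves to model.bert._prune_heads(...)
model.prune_heads({0: [0, 2], 3: [5]})  # {layer_num: [head indices to prune]}
```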
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        """
@@ -175,24 +191,22 @@ class PreTrainedModel(nn.Module):
            *inputs, **kwargs: additional input for the specific XLNet class
                (ex: num_labels for XLNetForSequenceClassification)
        """
-       state_dict = kwargs.get('state_dict', None)
-       kwargs.pop('state_dict', None)
-       cache_dir = kwargs.get('cache_dir', None)
-       kwargs.pop('cache_dir', None)
-       from_tf = kwargs.get('from_tf', False)
-       kwargs.pop('from_tf', None)
+       state_dict = kwargs.pop('state_dict', None)
+       cache_dir = kwargs.pop('cache_dir', None)
+       from_tf = kwargs.pop('from_tf', None)

+       # Load config
+       config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

+       # Load model
        if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
            archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
-           config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
        else:
            if from_tf:
                # Directly load from a TensorFlow checkpoint
                archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
-               config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
            else:
                archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
-               config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
@@ -210,47 +224,15 @@ class PreTrainedModel(nn.Module):
                    ', '.join(cls.pretrained_model_archive_map.keys()),
                    archive_file))
            return None
-       try:
-           resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
-       except EnvironmentError:
-           if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
-               logger.error(
-                   "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
-                       config_file))
-           else:
-               logger.error(
-                   "Model name '{}' was not found in model name list ({}). "
-                   "We assumed '{}' was a path or url but couldn't find any file "
-                   "associated to this path or url.".format(
-                       pretrained_model_name_or_path,
-                       ', '.join(cls.pretrained_config_archive_map.keys()),
-                       config_file))
-           return None
-       if resolved_archive_file == archive_file and resolved_config_file == config_file:
+       if resolved_archive_file == archive_file:
            logger.info("loading weights file {}".format(archive_file))
-           logger.info("loading configuration file {}".format(config_file))
        else:
            logger.info("loading weights file {} from cache at {}".format(
                archive_file, resolved_archive_file))
-           logger.info("loading configuration file {} from cache at {}".format(
-               config_file, resolved_config_file))
-       # Load config
-       config = cls.config_class.from_json_file(resolved_config_file)
-       # Update config with kwargs if needed
-       to_remove = []
-       for key, value in kwargs.items():
-           if hasattr(config, key):
-               setattr(config, key, value)
-               to_remove.append(key)
-       for key in to_remove:
-           kwargs.pop(key, None)
-       logger.info("Model config {}".format(config))

        # Instantiate model.
-       model = cls(config, *inputs, **kwargs)
+       model = cls(config)

        if state_dict is None and not from_tf:
            state_dict = torch.load(resolved_archive_file, map_location='cpu')
        if from_tf:
@@ -275,7 +257,7 @@ class PreTrainedModel(nn.Module):
                if child is not None:
                    load(child, prefix + name + '.')

-       # Be able to load base models as well as derived models (with heads)
+       # Make sure we are able to load base models as well as derived models (with heads)
        start_prefix = ''
        model_to_load = model
        if not hasattr(model, cls.base_model_prefix) and any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
...
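Taken together, the reworked loading path above resolves only the weights archive itself and delegates configuration entirely to `cls.config_class.from_pretrained`, with leftover kwargs handed to the config rather than to the model constructor. A hedged sketch of the intended call pattern (import path assumed; whether every kwarg is applied to the config at this exact commit is an assumption):

```
# Import path is an assumption; adjust to where the BERT classes live at this commit.
from pytorch_transformers.modeling_bert import BertForSequenceClassification

# Weights are fetched (or read from the cache); the matching BertConfig is loaded by
# BertConfig.from_pretrained, which is also expected to consume kwargs such as num_labels.
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=3)
```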
@@ -155,7 +155,7 @@ class BertConfig(PretrainedConfig):
    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
-                vocab_size_or_config_json_file,
+                vocab_size_or_config_json_file=30522,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
@@ -167,7 +167,7 @@ class BertConfig(PretrainedConfig):
                 type_vocab_size=2,
                 initializer_range=0.02,
                 layer_norm_eps=1e-12,
-                finetuning_task=None):
+                **kwargs):
        """Constructs BertConfig.

        Args:
@@ -192,8 +192,8 @@ class BertConfig(PretrainedConfig):
            initializer_range: The sttdev of the truncated_normal_initializer for
                initializing all weight matrices.
            layer_norm_eps: The epsilon used by LayerNorm.
-           finetuning_task: name of the glue task on which the model was fine-tuned if any
        """
+       super(BertConfig, self).__init__(**kwargs)
        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                        and isinstance(vocab_size_or_config_json_file, unicode)):
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
@@ -213,7 +213,6 @@ class BertConfig(PretrainedConfig):
            self.type_vocab_size = type_vocab_size
            self.initializer_range = initializer_range
            self.layer_norm_eps = layer_norm_eps
-           self.finetuning_task = finetuning_task
        else:
            raise ValueError("First argument must be either a vocabulary size (int)"
                             "or the path to a pretrained model config file (str)")
@@ -270,13 +269,13 @@ class BertEmbeddings(nn.Module):

class BertSelfAttention(nn.Module):
-   def __init__(self, config, output_attentions=False):
+   def __init__(self, config):
        super(BertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
-       self.output_attentions = output_attentions
+       self.output_attentions = config.output_attentions

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
@@ -344,10 +343,9 @@ class BertSelfOutput(nn.Module):

class BertAttention(nn.Module):
-   def __init__(self, config, output_attentions=False):
+   def __init__(self, config):
        super(BertAttention, self).__init__()
-       self.output_attentions = output_attentions
-       self.self = BertSelfAttention(config, output_attentions=output_attentions)
+       self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def prune_heads(self, heads):
@@ -404,10 +402,9 @@ class BertOutput(nn.Module):

class BertLayer(nn.Module):
-   def __init__(self, config, output_attentions=False):
+   def __init__(self, config):
        super(BertLayer, self).__init__()
-       self.output_attentions = output_attentions
-       self.attention = BertAttention(config, output_attentions=output_attentions)
+       self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)
@@ -421,11 +418,11 @@ class BertLayer(nn.Module):

class BertEncoder(nn.Module):
-   def __init__(self, config, output_attentions=False, output_hidden_states=False):
+   def __init__(self, config):
        super(BertEncoder, self).__init__()
-       self.output_attentions = output_attentions
-       self.output_hidden_states = output_hidden_states
-       layer = BertLayer(config, output_attentions=output_attentions)
+       self.output_attentions = config.output_attentions
+       self.output_hidden_states = config.output_hidden_states
+       layer = BertLayer(config)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask, head_mask=None):
@@ -546,9 +543,6 @@ class BertPreTrainedModel(PreTrainedModel):
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "bert"

-   def __init__(self, *inputs, **kwargs):
-       super(BertPreTrainedModel, self).__init__(*inputs, **kwargs)

    def init_weights(self, module):
        """ Initialize the weights.
        """
@@ -612,19 +606,19 @@ class BertModel(BertPreTrainedModel):
        all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
        ```
    """
-   def __init__(self, config, output_attentions=False, output_hidden_states=False):
+   def __init__(self, config):
        super(BertModel, self).__init__(config)
-       self.output_attentions = output_attentions
-       self.output_hidden_states = output_hidden_states

        self.embeddings = BertEmbeddings(config)
-       self.encoder = BertEncoder(config, output_attentions=output_attentions,
-                                  output_hidden_states=output_hidden_states)
+       self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

        self.apply(self.init_weights)

-   def prune_heads(self, heads_to_prune):
+   def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+           See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)
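Since `output_attentions` and `output_hidden_states` are now read from the config, a BERT model is built from the config alone. A small sketch (import path assumed; the exact forward output ordering follows the library's convention rather than this hunk):

```
import torch
# Import path is an assumption; adjust to where the BERT classes live at this commit.
from pytorch_transformers.modeling_bert import BertConfig, BertModel

config = BertConfig(output_attentions=True, output_hidden_states=True)
model = BertModel(config)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8))
with torch.no_grad():
    outputs = model(input_ids)
# outputs is a list; when the config flags are set, the trailing entries carry
# all hidden states and all attention maps.
```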
@@ -730,14 +724,12 @@ class BertForPreTraining(BertPreTrainedModel):
        masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
        ```
    """
-   def __init__(self, config, output_attentions=False, output_hidden_states=False):
+   def __init__(self, config):
        super(BertForPreTraining, self).__init__(config)
-       self.output_attentions = output_attentions
-       self.output_hidden_states = output_hidden_states

-       self.bert = BertModel(config, output_attentions=output_attentions,
-                             output_hidden_states=output_hidden_states)
+       self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)

        self.apply(self.init_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
@@ -809,13 +801,12 @@ class BertForMaskedLM(BertPreTrainedModel):
        masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
        ```
    """
-   def __init__(self, config, output_attentions=False, output_hidden_states=False):
+   def __init__(self, config):
        super(BertForMaskedLM, self).__init__(config)
-       self.output_attentions = output_attentions
-       self.output_hidden_states = output_hidden_states

-       self.bert = BertModel(config, output_attentions=output_attentions)
+       self.bert = BertModel(config)
        self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)

        self.apply(self.init_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None):
@@ -880,12 +871,10 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
        seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
        ```
    """
-   def __init__(self, config, output_attentions=False, output_hidden_states=False):
+   def __init__(self, config):
        super(BertForNextSentencePrediction, self).__init__(config)
-       self.output_attentions = output_attentions
-       self.output_hidden_states = output_hidden_states

-       self.bert = BertModel(config, output_attentions=output_attentions)
+       self.bert = BertModel(config)
        self.cls = BertOnlyNSPHead(config)

        self.apply(self.init_weights)
@@ -954,15 +943,13 @@ class BertForSequenceClassification(BertPreTrainedModel):
        logits = model(input_ids, token_type_ids, input_mask)
        ```
    """
-   def __init__(self, config, num_labels=2, output_attentions=False, output_hidden_states=False):
+   def __init__(self, config):
        super(BertForSequenceClassification, self).__init__(config)
-       self.output_attentions = output_attentions
-       self.output_hidden_states = output_hidden_states
-       self.num_labels = num_labels
+       self.num_labels = config.num_labels

-       self.bert = BertModel(config, output_attentions=output_attentions)
+       self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
-       self.classifier = nn.Linear(config.hidden_size, num_labels)
+       self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        self.apply(self.init_weights)
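The classifier width now follows `config.num_labels` instead of a `num_labels` constructor argument; a quick sketch (import path assumed):

```
from pytorch_transformers.modeling_bert import BertConfig, BertForSequenceClassification  # import path assumed

config = BertConfig(num_labels=5)
model = BertForSequenceClassification(config)
print(model.classifier.out_features)  # 5
```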
@@ -997,7 +984,6 @@ class BertForMultipleChoice(BertPreTrainedModel):
        `config`: a BertConfig class instance with the configuration to build a new model
        `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
        `output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False
-       `num_choices`: the number of classes for the classifier. Default = 2.

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
@@ -1030,25 +1016,23 @@ class BertForMultipleChoice(BertPreTrainedModel):
        config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
            num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

-       num_choices = 2
-       model = BertForMultipleChoice(config, num_choices)
+       model = BertForMultipleChoice(config)
        logits = model(input_ids, token_type_ids, input_mask)
        ```
    """
-   def __init__(self, config, num_choices=2, output_attentions=False, output_hidden_states=False):
+   def __init__(self, config):
        super(BertForMultipleChoice, self).__init__(config)
-       self.output_attentions = output_attentions
-       self.output_hidden_states = output_hidden_states
-       self.num_choices = num_choices

-       self.bert = BertModel(config, output_attentions=output_attentions)
+       self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)

        self.apply(self.init_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
+       """ Input shapes should be [bsz, num choices, seq length] """
+       num_choices = input_ids.shape[1]

        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
@@ -1057,7 +1041,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
-       reshaped_logits = logits.view(-1, self.num_choices)
+       reshaped_logits = logits.view(-1, num_choices)

        outputs = [reshaped_logits] + outputs[2:]  # add hidden states and attention if they are here
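`num_choices` is no longer a constructor argument; it is read from the second dimension of `input_ids` at forward time, so the same module handles any number of choices. Sketch (import path assumed):

```
import torch
from pytorch_transformers.modeling_bert import BertConfig, BertForMultipleChoice  # import path assumed

model = BertForMultipleChoice(BertConfig())
input_ids = torch.randint(0, model.config.vocab_size, (2, 4, 16))  # (bsz, num_choices, seq_len)
outputs = model(input_ids)
print(outputs[0].shape)  # reshaped_logits: torch.Size([2, 4])
```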
@@ -1118,15 +1102,13 @@ class BertForTokenClassification(BertPreTrainedModel):
        logits = model(input_ids, token_type_ids, input_mask)
        ```
    """
-   def __init__(self, config, num_labels=2, output_attentions=False, output_hidden_states=False):
+   def __init__(self, config):
        super(BertForTokenClassification, self).__init__(config)
-       self.output_attentions = output_attentions
-       self.output_hidden_states = output_hidden_states
-       self.num_labels = num_labels
+       self.num_labels = config.num_labels

-       self.bert = BertModel(config, output_attentions=output_attentions)
+       self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
-       self.classifier = nn.Linear(config.hidden_size, num_labels)
+       self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.apply(self.init_weights)
@@ -1204,12 +1186,12 @@ class BertForQuestionAnswering(BertPreTrainedModel):
        start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
        ```
    """
-   def __init__(self, config, output_attentions=False, output_hidden_states=False):
+   def __init__(self, config):
        super(BertForQuestionAnswering, self).__init__(config)
-       self.output_attentions = output_attentions
-       self.output_hidden_states = output_hidden_states
+       self.num_labels = config.num_labels

-       self.bert = BertModel(config, output_attentions=output_attentions)
-       self.qa_outputs = nn.Linear(config.hidden_size, 2)
+       self.bert = BertModel(config)
+       self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.apply(self.init_weights)
...
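Same pattern for question answering: the span head is sized by `config.num_labels`, whose default of 2 reproduces the usual start/end logits. Sketch (import path assumed):

```
from pytorch_transformers.modeling_bert import BertConfig, BertForQuestionAnswering  # import path assumed

model = BertForQuestionAnswering(BertConfig())  # num_labels defaults to 2
print(model.qa_outputs.out_features)  # 2 -> start and end logits
```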
@@ -119,7 +119,8 @@ class GPT2Config(PretrainedConfig):
                 attn_pdrop=0.1,
                 layer_norm_epsilon=1e-5,
                 initializer_range=0.02,
-                predict_special_tokens=True
+                predict_special_tokens=True,
+                **kwargs
                 ):
        """Constructs GPT2Config.
@@ -142,6 +143,8 @@ class GPT2Config(PretrainedConfig):
                initializing all weight matrices.
            predict_special_tokens: should we predict special tokens (when the model has a LM head)
        """
+       super(GPT2Config, self).__init__(**kwargs)
        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                        and isinstance(vocab_size_or_config_json_file, unicode)):
            with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
@@ -174,8 +177,10 @@

class Attention(nn.Module):
-   def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False):
+   def __init__(self, nx, n_ctx, config, scale=False):
        super(Attention, self).__init__()
+       self.output_attentions = config.output_attentions

        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
@@ -184,10 +189,6 @@ class Attention(nn.Module):
        self.split_size = n_state
        self.scale = scale

-       self.output_attentions = output_attentions
-       self.keep_multihead_output = keep_multihead_output
-       self.multihead_output = None

        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
@@ -224,9 +225,10 @@
        if head_mask is not None:
            w = w * head_mask

+       outputs = [torch.matmul(w, v)]
        if self.output_attentions:
-           return w, torch.matmul(w, v)
-       return torch.matmul(w, v)
+           outputs.append(w)
+       return outputs

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
@@ -253,19 +255,15 @@
            value = torch.cat((past_value, value), dim=-2)
        present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking

-       a = self._attn(query, key, value, head_mask)
-       if self.keep_multihead_output:
-           self.multihead_output = a
-           self.multihead_output.retain_grad()
-       if self.output_attentions:
-           attentions, a = a
+       attn_outputs = self._attn(query, key, value, head_mask)
+       a = attn_outputs[0]

        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a)
-       if self.output_attentions:
-           return attentions, a, present
-       return a, present
+
+       outputs = [a, present] + attn_outputs[1:]
+       return outputs  # a, present, (attentions)


class MLP(nn.Module):
@@ -284,27 +282,24 @@ class MLP(nn.Module):

class Block(nn.Module):
-   def __init__(self, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False):
+   def __init__(self, n_ctx, config, scale=False):
        super(Block, self).__init__()
        nx = config.n_embd
-       self.output_attentions = output_attentions
        self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
-       self.attn = Attention(nx, n_ctx, config, scale, output_attentions, keep_multihead_output)
+       self.attn = Attention(nx, n_ctx, config, scale)
        self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * nx, config)

    def forward(self, x, layer_past=None, head_mask=None):
        output_attn = self.attn(self.ln_1(x), layer_past=layer_past, head_mask=head_mask)
-       if self.output_attentions:
-           attentions, a, present = output_attn
-       else:
-           a, present = output_attn
+       a = output_attn[0]  # output_attn: a, present, (attentions)
+
        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m
-       if self.output_attentions:
-           return attentions, x, present
-       return x, present
+
+       outputs = [x] + output_attn[1:]
+       return outputs  # x, present, (attentions)


class GPT2LMHead(nn.Module):
@@ -342,12 +337,17 @@ class GPT2MultipleChoiceHead(nn.Module):
        nn.init.normal_(self.linear.weight, std=0.02)
        nn.init.normal_(self.linear.bias, 0)

-   def forward(self, hidden_states, mc_token_ids):
-       # Classification logits
-       # hidden_state (bsz, num_choices, seq_length, hidden_size)
-       # mc_token_ids (bsz, num_choices)
-       mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))
-       # (bsz, num_choices, 1, hidden_size)
+   def forward(self, hidden_states, mc_token_ids=None):
+       """ Extract classification token hidden state and project it using self.linear
+           hidden_state: shape (bsz, num_choices, seq_length, hidden_size)
+           mc_token_ids: [optional] index of the classification token, shape (bsz, num_choices)
+               if mc_token_ids=None we take the last token of the sequence as classification token
+       """
+       if mc_token_ids is None:
+           mc_token_ids = torch.full_like(hidden_states[:, :, :1, :], hidden_states.shape[2] - 1, dtype=torch.long)
+       else:
+           mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))
+       # mc_token_ids has shape (bsz, num_choices, 1, hidden_size)
        multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)
        # (bsz, num_choices, hidden_size)
        multiple_choice_h = self.dropout(multiple_choice_h.transpose(1, 2)).transpose(1, 2)
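`mc_token_ids` is now optional: when it is omitted the head builds an index tensor pointing at the last position of every sequence and classifies from that token. A hedged sketch calling the head directly (import path and GPT2Config defaults are assumptions):

```
import torch
from pytorch_transformers.modeling_gpt2 import GPT2Config, GPT2MultipleChoiceHead  # import path assumed

config = GPT2Config()  # assumes the usual GPT-2 defaults
head = GPT2MultipleChoiceHead(config)

hidden_states = torch.randn(2, 3, 10, config.n_embd)  # (bsz, num_choices, seq_len, hidden)
mc_logits = head(hidden_states)  # equivalent to passing mc_token_ids filled with index 9
```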
@@ -362,13 +362,9 @@ class GPT2PreTrainedModel(PreTrainedModel):
    """
    config_class = GPT2Config
    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
-   pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_gpt2
    base_model_prefix = "transformer"

-   def __init__(self, *inputs, **kwargs):
-       super(GPT2PreTrainedModel, self).__init__(*inputs, **kwargs)

    def init_weights(self, module):
        """ Initialize the weights.
        """
@@ -403,126 +399,9 @@ class GPT2PreTrainedModel:
            state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models
            *inputs, **kwargs: additional input for the specific GPT2 class
        """
-       # state_dict = kwargs.get('state_dict', None)
-       # kwargs.pop('state_dict', None)
-       # cache_dir = kwargs.get('cache_dir', None)
-       # kwargs.pop('cache_dir', None)
-       # from_tf = kwargs.get('from_tf', False)
-       # kwargs.pop('from_tf', None)
-       num_special_tokens = kwargs.get('num_special_tokens', None)
-       kwargs.pop('num_special_tokens', None)
-       # if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
-       #     archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
-       #     config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
-       # else:
-       #     archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
-       #     config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
-       # # redirect to the cache, if necessary
-       # try:
-       #     resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
-       # except EnvironmentError:
-       #     if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
-       #         logger.error(
-       #             "Couldn't reach server at '{}' to download pretrained weights.".format(
-       #                 archive_file))
-       #     else:
-       #         logger.error(
-       #             "Model name '{}' was not found in model name list ({}). "
-       #             "We assumed '{}' was a path or url but couldn't find file {} "
-       #             "at this path or url.".format(
-       #                 pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
-       #                 archive_file
-       #             )
-       #         )
-       #     return None
-       # try:
-       #     resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
-       # except EnvironmentError:
-       #     if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
-       #         logger.error(
-       #             "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
-       #                 config_file))
-       #     else:
-       #         logger.error(
-       #             "Model name '{}' was not found in model name list ({}). "
-       #             "We assumed '{}' was a path or url but couldn't find file {} "
-       #             "at this path or url.".format(
-       #                 pretrained_model_name_or_path, ", ".join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
-       #                 config_file
-       #             )
-       #         )
-       #     return None
-       # if resolved_archive_file == archive_file and resolved_config_file == config_file:
-       #     logger.info("loading weights file {}".format(archive_file))
-       #     logger.info("loading configuration file {}".format(config_file))
-       # else:
-       #     logger.info("loading weights file {} from cache at {}".format(
-       #         archive_file, resolved_archive_file))
-       #     logger.info("loading configuration file {} from cache at {}".format(
-       #         config_file, resolved_config_file))
-       # # Load config
-       # config = GPT2Config.from_json_file(resolved_config_file)
-       # logger.info("Model config {}".format(config))
-       # # Instantiate model.
-       # model = cls(config, *inputs, **kwargs)
-       # if state_dict is None and not from_tf:
-       #     state_dict = torch.load(resolved_archive_file, map_location='cpu')
-       # if from_tf:
-       #     # Directly load from a TensorFlow checkpoint (stored as NumPy array)
-       #     return load_tf_weights_in_gpt2(model, resolved_archive_file)
-       # old_keys = []
-       # new_keys = []
-       # for key in state_dict.keys():
-       #     new_key = None
-       #     if key.endswith(".g"):
-       #         new_key = key[:-2] + ".weight"
-       #     elif key.endswith(".b"):
-       #         new_key = key[:-2] + ".bias"
-       #     elif key.endswith(".w"):
-       #         new_key = key[:-2] + ".weight"
-       #     if new_key:
-       #         old_keys.append(key)
-       #         new_keys.append(new_key)
-       # for old_key, new_key in zip(old_keys, new_keys):
-       #     state_dict[new_key] = state_dict.pop(old_key)
-       # missing_keys = []
-       # unexpected_keys = []
-       # error_msgs = []
-       # # copy state_dict so _load_from_state_dict can modify it
-       # metadata = getattr(state_dict, "_metadata", None)
-       # state_dict = state_dict.copy()
-       # if metadata is not None:
-       #     state_dict._metadata = metadata
-       # def load(module, prefix=""):
-       #     local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
-       #     module._load_from_state_dict(
-       #         state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
-       #     )
-       #     for name, child in module._modules.items():
-       #         if child is not None:
-       #             load(child, prefix + name + ".")
-       # start_model = model
-       # if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()):
-       #     start_model = model.transformer
-       # load(start_model, prefix="")
-       # if len(missing_keys) > 0:
-       #     logger.info(
-       #         "Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)
-       #     )
-       # if len(unexpected_keys) > 0:
-       #     logger.info(
-       #         "Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys)
-       #     )
-       # if len(error_msgs) > 0:
-       #     raise RuntimeError(
-       #         "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
-       #     )
+       num_special_tokens = kwargs.pop('num_special_tokens', None)
+
+       model = PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs)

        # Add additional embeddings for special tokens if needed
        # This step also make sure we are still sharing the output and input embeddings after loading weights
@@ -553,8 +432,6 @@ class GPT2Model(GPT2PreTrainedModel):
    Params:
        `config`: a GPT2Config class instance with the configuration to build a new model
        `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
-       `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
-           This can be used to compute head importance metrics. Default: False

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
@@ -591,14 +468,15 @@
        ```
    """
-   def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+   def __init__(self, config):
        super(GPT2Model, self).__init__(config)
-       self.output_attentions = output_attentions
+       self.output_hidden_states = config.output_hidden_states
+       self.output_attentions = config.output_attentions

        self.wte = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
-       block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions,
-                     keep_multihead_output=keep_multihead_output)
+       block = Block(config.n_ctx, config, scale=True)
        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
        self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
@@ -618,19 +496,13 @@
        # Copy word embeddings from the previous weights
        self.wte.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :]

-   def prune_heads(self, heads_to_prune):
+   def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)

-   def get_multihead_outputs(self):
-       """ Gather all multi-head outputs.
-           Return: list (layers) of multihead module outputs with gradients
-       """
-       return [h.attn.multihead_output for h in self.h]

    def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None, head_mask=None):
        if past is None:
            past_length = 0
@@ -675,20 +547,32 @@
        all_attentions = []
        all_hidden_states = []
        for i, (block, layer_past) in enumerate(zip(self.h, past)):
-           all_hidden_states.append(hidden_states.view(*output_shape))
+           if self.output_hidden_states:
+               all_hidden_states.append(hidden_states.view(*output_shape))
+
            outputs = block(hidden_states, layer_past, head_mask[i])
-           if self.output_attentions:
-               attentions, hidden_states, present = outputs
-               all_attentions.append(attentions)
-           else:
-               hidden_states, present = outputs
+           hidden_states, present = outputs[:2]
            presents.append(present)
+           if self.output_attentions:
+               all_attentions.append(outputs[2])
+
        hidden_states = self.ln_f(hidden_states)
-       all_hidden_states.append(hidden_states.view(*output_shape))
+       hidden_states = hidden_states.view(*output_shape)
+       # Add last hidden state
+       if self.output_hidden_states:
+           all_hidden_states.append(hidden_states)
+
+       outputs = [hidden_states, presents]
+       if self.output_hidden_states:
+           outputs.append(all_hidden_states)
        if self.output_attentions:
-           return all_attentions, all_hidden_states, presents
-       return all_hidden_states, presents
+           # let the number of heads free (-1) so we can extract attention even after head pruning
+           attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
+           all_attentions = list(t.view(*attention_output_shape) for t in all_attentions)
+           outputs.append(all_attentions)
+       return outputs  # last hidden state, presents, (all hidden_states), (attentions)
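The forward pass now returns a single flat list whose tail depends on the config flags, exactly as the trailing comment documents. A sketch of unpacking it (import path assumed):

```
import torch
from pytorch_transformers.modeling_gpt2 import GPT2Config, GPT2Model  # import path assumed

config = GPT2Config(output_hidden_states=True, output_attentions=True)
model = GPT2Model(config)

input_ids = torch.randint(0, config.vocab_size, (1, 6))
outputs = model(input_ids)

hidden_states, presents = outputs[0], outputs[1]
all_hidden_states = outputs[2]  # present because output_hidden_states=True
all_attentions = outputs[3]     # present because output_attentions=True
```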

class GPT2LMHeadModel(GPT2PreTrainedModel):
@@ -740,10 +624,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
        ```
    """
-   def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+   def __init__(self, config):
        super(GPT2LMHeadModel, self).__init__(config)
-       self.transformer = GPT2Model(config, output_attentions=output_attentions,
-                                    keep_multihead_output=keep_multihead_output)
+       self.transformer = GPT2Model(config)
        self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
        self.apply(self.init_weights)
@@ -756,14 +639,12 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
        self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)

    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None, head_mask=None):
-       transformer_output = self.transformer(input_ids, position_ids, token_type_ids, past, head_mask)
-       if self.transformer.output_attentions:
-           all_attentions, hidden_states, presents = transformer_output
-       else:
-           hidden_states, presents = transformer_output
-       hidden_states = hidden_states[-1]
+       transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, past, head_mask)
+       hidden_states = transformer_outputs[0]
+
        lm_logits = self.lm_head(hidden_states)
+
+       outputs = [lm_logits] + transformer_outputs[1:]
        if lm_labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
@@ -772,10 +653,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
-           return loss
-       if self.transformer.output_attentions:
-           return all_attentions, lm_logits, presents
-       return lm_logits, presents
+           outputs = [loss] + outputs
+
+       return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)


class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
@@ -832,12 +712,12 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
        ```
    """
-   def __init__(self, config, output_attentions=False, keep_multihead_output=False):
+   def __init__(self, config):
        super(GPT2DoubleHeadsModel, self).__init__(config)
-       self.transformer = GPT2Model(config, output_attentions=output_attentions,
-                                    keep_multihead_output=keep_multihead_output)
+       self.transformer = GPT2Model(config)
        self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
        self.multiple_choice_head = GPT2MultipleChoiceHead(config)

        self.apply(self.init_weights)

    def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
@@ -848,28 +728,26 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
        self.transformer.set_num_special_tokens(num_special_tokens)
        self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)

-   def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None,
+   def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
                position_ids=None, past=None, head_mask=None):
-       transformer_output = self.transformer(input_ids, position_ids, token_type_ids, past, head_mask)
-       if self.transformer.output_attentions:
-           all_attentions, hidden_states, presents = transformer_output
-       else:
-           hidden_states, presents = transformer_output
-       hidden_states = hidden_states[-1]
+       transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, past, head_mask)
+       hidden_states = transformer_outputs[0]
+
        lm_logits = self.lm_head(hidden_states)
        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
-       losses = []
+
+       outputs = [lm_logits, mc_logits] + transformer_outputs[1:]
+       if mc_labels is not None:
+           loss_fct = CrossEntropyLoss()
+           loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)),
+                           mc_labels.view(-1))
+           outputs = [loss] + outputs
        if lm_labels is not None:
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss(ignore_index=-1)
-           losses.append(loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)))
-       if mc_labels is not None:
-           loss_fct = CrossEntropyLoss()
-           losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
-       if losses:
-           return losses
-       if self.transformer.output_attentions:
-           return all_attentions, lm_logits, mc_logits, presents
-       return lm_logits, mc_logits, presents
+           loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
+                           shift_labels.view(-1))
+           outputs = [loss] + outputs
+
+       return outputs  # (lm loss), (mc loss), lm logits, mc logits, presents, (all hidden_states), (attentions)
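The double-heads model follows the same convention: losses are prepended only when the corresponding labels are passed. Sketch (import path assumed):

```
import torch
from pytorch_transformers.modeling_gpt2 import GPT2Config, GPT2DoubleHeadsModel  # import path assumed

config = GPT2Config()
model = GPT2DoubleHeadsModel(config)
input_ids = torch.randint(0, config.vocab_size, (1, 2, 12))  # (bsz, num_choices, seq_len)

outputs = model(input_ids)  # mc_token_ids defaults to the last token of each choice
lm_logits, mc_logits, presents = outputs[0], outputs[1], outputs[2]

outputs = model(input_ids, lm_labels=input_ids.clone())
lm_loss = outputs[0]  # (lm loss), lm_logits, mc_logits, presents, ...
```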
@@ -147,7 +147,8 @@ class OpenAIGPTConfig(PretrainedConfig):
                 attn_pdrop=0.1,
                 layer_norm_epsilon=1e-5,
                 initializer_range=0.02,
-                predict_special_tokens=True
+                predict_special_tokens=True,
+                **kwargs
                 ):
        """Constructs OpenAIGPTConfig.
@@ -172,6 +173,8 @@ class OpenAIGPTConfig(PretrainedConfig):
                initializing all weight matrices.
            predict_special_tokens: should we predict special tokens (when the model has a LM head)
        """
+       super(OpenAIGPTConfig, self).__init__(**kwargs)
        if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
                        and isinstance(vocab_size_or_config_json_file, unicode)):
            with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
@@ -205,7 +208,7 @@ class OpenAIGPTConfig(PretrainedConfig):

class Attention(nn.Module):
-   def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False):
+   def __init__(self, nx, n_ctx, config, scale=False):
        super(Attention, self).__init__()
        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
@@ -215,9 +218,7 @@ class Attention(nn.Module):
        self.split_size = n_state
        self.scale = scale
-       self.output_attentions = output_attentions
-       self.keep_multihead_output = keep_multihead_output
-       self.multihead_output = None
+       self.output_attentions = config.output_attentions

        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)
@@ -256,9 +257,10 @@ class Attention(nn.Module):
        if head_mask is not None:
            w = w * head_mask

+       outputs = [torch.matmul(w, v)]
        if self.output_attentions:
-           return w, torch.matmul(w, v)
-       return torch.matmul(w, v)
+           outputs.append(w)
+       return outputs

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
@@ -280,19 +282,15 @@ class Attention(nn.Module):
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)

-       a = self._attn(query, key, value, head_mask)
-       if self.keep_multihead_output:
-           self.multihead_output = a
-           self.multihead_output.retain_grad()
-       if self.output_attentions:
-           attentions, a = a
+       attn_outputs = self._attn(query, key, value, head_mask)
+       a = attn_outputs[0]

        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a)
-       if self.output_attentions:
-           return attentions, a
-       return a
+
+       outputs = [a] + attn_outputs[1:]
+       return outputs  # a, (attentions)


class MLP(nn.Module):
@@ -311,25 +309,24 @@ class MLP(nn.Module):

class Block(nn.Module):
-   def __init__(self, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False):
+   def __init__(self, n_ctx, config, scale=False):
        super(Block, self).__init__()
        nx = config.n_embd
-       self.output_attentions = output_attentions
-       self.attn = Attention(nx, n_ctx, config, scale, output_attentions, keep_multihead_output)
+       self.attn = Attention(nx, n_ctx, config, scale)
        self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * nx, config)
        self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)

    def forward(self, x, head_mask=None):
-       a = self.attn(x, head_mask=head_mask)
-       if self.output_attentions:
-           attentions, a = a
+       attn_outputs = self.attn(x, head_mask=head_mask)
+       a = attn_outputs[0]
+
        n = self.ln_1(x + a)
        m = self.mlp(n)
        h = self.ln_2(n + m)
-       if self.output_attentions:
-           return attentions, h
-       return h
+
+       outputs = [h] + attn_outputs[1:]
+       return outputs


class OpenAIGPTLMHead(nn.Module):
...@@ -368,11 +365,16 @@ class OpenAIGPTMultipleChoiceHead(nn.Module): ...@@ -368,11 +365,16 @@ class OpenAIGPTMultipleChoiceHead(nn.Module):
nn.init.normal_(self.linear.weight, std=0.02) nn.init.normal_(self.linear.weight, std=0.02)
nn.init.normal_(self.linear.bias, 0) nn.init.normal_(self.linear.bias, 0)
def forward(self, hidden_states, mc_token_ids): def forward(self, hidden_states, mc_token_ids=None):
# Classification logits """ Extract classification token hidden state and project it using self.linear
# hidden_state (bsz, num_choices, seq_length, hidden_size) hidden_state: hidden state of shape (bsz, num_choices, seq_length, hidden_size)
# mc_token_ids (bsz, num_choices) mc_token_ids: [optional] index of the classification token, shape (bsz, num_choices)
mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1)) if mc_token_ids=None we take the last token of the sequence as classification token
"""
if mc_token_ids is None:
mc_token_ids = torch.full_like(hidden_states[:, :, :1, :], hidden_states.shape[2] - 1, dtype=torch.long)
else:
mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))
# (bsz, num_choices, 1, hidden_size) # (bsz, num_choices, 1, hidden_size)
multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2) multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)
# (bsz, num_choices, hidden_size) # (bsz, num_choices, hidden_size)
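The classification-token extraction above boils down to a `gather` along the sequence dimension. A small standalone shape check (toy sizes, random tensors):
```
import torch

bsz, num_choices, seq_length, hidden_size = 2, 3, 5, 7
hidden_states = torch.randn(bsz, num_choices, seq_length, hidden_size)

# pick the classification token for each choice (here: the last token everywhere)
mc_token_ids = torch.full((bsz, num_choices), seq_length - 1, dtype=torch.long)
index = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_size)

multiple_choice_h = hidden_states.gather(2, index).squeeze(2)
assert multiple_choice_h.shape == (bsz, num_choices, hidden_size)
```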
...@@ -388,13 +390,9 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel): ...@@ -388,13 +390,9 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
""" """
config_class = OpenAIGPTConfig config_class = OpenAIGPTConfig
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_openai_gpt load_tf_weights = load_tf_weights_in_openai_gpt
base_model_prefix = "transformer" base_model_prefix = "transformer"
def __init__(self, *inputs, **kwargs):
super(OpenAIGPTPreTrainedModel, self).__init__(*inputs, **kwargs)
def init_weights(self, module): def init_weights(self, module):
""" Initialize the weights. """ Initialize the weights.
""" """
...@@ -495,14 +493,15 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): ...@@ -495,14 +493,15 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
``` ```
""" """
def __init__(self, config, output_attentions=False, keep_multihead_output=False): def __init__(self, config):
super(OpenAIGPTModel, self).__init__(config) super(OpenAIGPTModel, self).__init__(config)
self.output_attentions = output_attentions self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.tokens_embed = nn.Embedding(config.total_tokens_embeddings, config.n_embd) self.tokens_embed = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
self.positions_embed = nn.Embedding(config.n_positions, config.n_embd) self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
self.drop = nn.Dropout(config.embd_pdrop) self.drop = nn.Dropout(config.embd_pdrop)
block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions, block = Block(config.n_ctx, config, scale=True)
keep_multihead_output=keep_multihead_output)
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)]) self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
self.apply(self.init_weights) self.apply(self.init_weights)
...@@ -521,19 +520,13 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): ...@@ -521,19 +520,13 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
# Copy word embeddings from the previous weights # Copy word embeddings from the previous weights
self.tokens_embed.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :] self.tokens_embed.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :]
def prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model. """ Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
""" """
for layer, heads in heads_to_prune.items(): for layer, heads in heads_to_prune.items():
self.h[layer].attn.prune_heads(heads) self.h[layer].attn.prune_heads(heads)
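The expected structure of `heads_to_prune` is a mapping from layer index to the head indices to drop. A sketch only, assuming `model` is an already-instantiated OpenAIGPTModel (not constructed here, so the call is left commented out):
```
heads_to_prune = {
    0: [0, 2],   # prune heads 0 and 2 of the first block
    3: [1],      # prune head 1 of the fourth block
}
# model._prune_heads(heads_to_prune)  # each entry dispatches to self.h[layer].attn.prune_heads(heads)
```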
def get_multihead_outputs(self):
""" Gather all multi-head outputs.
Return: list (layers) of multihead module outputs with gradients
"""
return [h.attn.multihead_output for h in self.h]
def forward(self, input_ids, position_ids=None, token_type_ids=None, head_mask=None): def forward(self, input_ids, position_ids=None, token_type_ids=None, head_mask=None):
if position_ids is None: if position_ids is None:
# This was used when we had a single embedding matrix for position and token embeddings # This was used when we had a single embedding matrix for position and token embeddings
...@@ -574,19 +567,26 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): ...@@ -574,19 +567,26 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
output_shape = input_shape + (hidden_states.size(-1),) output_shape = input_shape + (hidden_states.size(-1),)
all_attentions = [] all_attentions = []
all_hidden_states = [hidden_states.view(*output_shape)] all_hidden_states = []
for i, block in enumerate(self.h): for i, block in enumerate(self.h):
if self.output_hidden_states:
all_hidden_states.append(hidden_states.view(*output_shape))
outputs = block(hidden_states, head_mask[i]) outputs = block(hidden_states, head_mask[i])
hidden_states = outputs[0]
if self.output_attentions: if self.output_attentions:
attentions, hidden_states = outputs all_attentions.append(outputs[1])
all_attentions.append(attentions)
else: # Add last layer
hidden_states = outputs if self.output_hidden_states:
all_hidden_states.append(hidden_states.view(*output_shape)) all_hidden_states.append(hidden_states.view(*output_shape))
outputs = [hidden_states.view(*output_shape)]
if self.output_hidden_states:
outputs.append(all_hidden_states)
if self.output_attentions: if self.output_attentions:
return all_attentions, all_hidden_states outputs.append(all_attentions)
return all_hidden_states return outputs # last hidden state, (all hidden states), (all attentions)
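Since the number of returned elements now depends on the config flags, unpacking follows the order given in the trailing comment: last hidden state, then optional hidden states, then optional attentions. A hedged helper illustrating that order (not part of the library, just a sketch):
```
def unpack_gpt_outputs(outputs, output_hidden_states, output_attentions):
    # outputs is assumed to be the list returned by OpenAIGPTModel.forward
    last_hidden_state = outputs[0]
    idx = 1
    all_hidden_states = None
    all_attentions = None
    if output_hidden_states:
        all_hidden_states = outputs[idx]
        idx += 1
    if output_attentions:
        all_attentions = outputs[idx]
    return last_hidden_state, all_hidden_states, all_attentions
```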
class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
...@@ -650,10 +650,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): ...@@ -650,10 +650,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
``` ```
""" """
def __init__(self, config, output_attentions=False, keep_multihead_output=False): def __init__(self, config):
super(OpenAIGPTLMHeadModel, self).__init__(config) super(OpenAIGPTLMHeadModel, self).__init__(config)
self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions, self.transformer = OpenAIGPTModel(config)
keep_multihead_output=keep_multihead_output)
self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config) self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
self.apply(self.init_weights) self.apply(self.init_weights)
...@@ -666,12 +665,11 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): ...@@ -666,12 +665,11 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight, predict_special_tokens=predict_special_tokens) self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight, predict_special_tokens=predict_special_tokens)
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, head_mask=None): def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, head_mask=None):
hidden_states = self.transformer(input_ids, position_ids, token_type_ids, head_mask) transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
if self.transformer.output_attentions: hidden_states = transformer_outputs[0]
all_attentions, hidden_states = hidden_states
hidden_states = hidden_states[-1]
lm_logits = self.lm_head(hidden_states) lm_logits = self.lm_head(hidden_states)
outputs = [lm_logits] + transformer_outputs[1:]
if lm_labels is not None: if lm_labels is not None:
# Shift so that tokens < n predict n # Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous() shift_logits = lm_logits[..., :-1, :].contiguous()
...@@ -680,10 +678,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): ...@@ -680,10 +678,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
loss_fct = CrossEntropyLoss(ignore_index=-1) loss_fct = CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1)) shift_labels.view(-1))
return loss outputs = [loss] + outputs
if self.transformer.output_attentions:
return all_attentions, lm_logits return outputs # (loss), lm_logits, (all hidden states), (all attentions)
return lm_logits
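The language-modeling loss above relies on shifting logits and labels by one position so that position i predicts token i+1. A tiny, self-contained reproduction of just that step (random tensors, toy sizes):
```
import torch
from torch.nn import CrossEntropyLoss

batch, seq_len, vocab = 1, 5, 11
lm_logits = torch.randn(batch, seq_len, vocab)
lm_labels = torch.randint(0, vocab, (batch, seq_len))

shift_logits = lm_logits[..., :-1, :].contiguous()   # drop the last position
shift_labels = lm_labels[..., 1:].contiguous()       # drop the first label
loss_fct = CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
```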
class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
...@@ -752,10 +749,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): ...@@ -752,10 +749,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
``` ```
""" """
def __init__(self, config, output_attentions=False, keep_multihead_output=False): def __init__(self, config):
super(OpenAIGPTDoubleHeadsModel, self).__init__(config) super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions, self.transformer = OpenAIGPTModel(config)
keep_multihead_output=keep_multihead_output)
self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config) self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config) self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config)
self.apply(self.init_weights) self.apply(self.init_weights)
...@@ -768,26 +764,26 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): ...@@ -768,26 +764,26 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
self.transformer.set_num_special_tokens(num_special_tokens) self.transformer.set_num_special_tokens(num_special_tokens)
self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight, predict_special_tokens=predict_special_tokens) self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight, predict_special_tokens=predict_special_tokens)
def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
position_ids=None, head_mask=None): position_ids=None, head_mask=None):
hidden_states = self.transformer(input_ids, position_ids, token_type_ids, head_mask) transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
if self.transformer.output_attentions: hidden_states = transformer_outputs[0]
all_attentions, hidden_states = hidden_states
hidden_states = hidden_states[-1]
lm_logits = self.lm_head(hidden_states) lm_logits = self.lm_head(hidden_states)
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids) mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
losses = []
outputs = [lm_logits, mc_logits] + transformer_outputs[1:]
if mc_labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)),
mc_labels.view(-1))
outputs = [loss] + outputs
if lm_labels is not None: if lm_labels is not None:
shift_logits = lm_logits[..., :-1, :].contiguous() shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = lm_labels[..., 1:].contiguous() shift_labels = lm_labels[..., 1:].contiguous()
loss_fct = CrossEntropyLoss(ignore_index=-1) loss_fct = CrossEntropyLoss(ignore_index=-1)
losses.append(loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))) loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
if mc_labels is not None: shift_labels.view(-1))
loss_fct = CrossEntropyLoss() outputs = [loss] + outputs
losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
if losses: return outputs # (lm loss), (mc loss), lm logits, mc logits, (all hidden_states), (attentions)
return losses
if self.transformer.output_attentions:
return all_attentions, lm_logits, mc_logits
return lm_logits, mc_logits
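The multiple-choice loss added above is a standard cross-entropy over per-choice logits, where the label gives the index of the correct choice for each example. A minimal standalone illustration (toy sizes, random logits):
```
import torch
from torch.nn import CrossEntropyLoss

bsz, num_choices = 2, 4
mc_logits = torch.randn(bsz, num_choices)   # one logit per choice
mc_labels = torch.tensor([1, 3])            # index of the correct choice per example

loss_fct = CrossEntropyLoss()
mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
```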
...@@ -209,7 +209,8 @@ class TransfoXLConfig(PretrainedConfig): ...@@ -209,7 +209,8 @@ class TransfoXLConfig(PretrainedConfig):
init="normal", init="normal",
init_range=0.01, init_range=0.01,
proj_init_std=0.01, proj_init_std=0.01,
init_std=0.02): init_std=0.02,
**kwargs):
"""Constructs TransfoXLConfig. """Constructs TransfoXLConfig.
Args: Args:
...@@ -244,6 +245,8 @@ class TransfoXLConfig(PretrainedConfig): ...@@ -244,6 +245,8 @@ class TransfoXLConfig(PretrainedConfig):
proj_init_std: parameters initialized by N(0, proj_init_std) proj_init_std: parameters initialized by N(0, proj_init_std)
init_std: parameters initialized by N(0, init_std) init_std: parameters initialized by N(0, init_std)
""" """
super(TransfoXLConfig, self).__init__(**kwargs)
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)): and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
...@@ -287,6 +290,7 @@ class TransfoXLConfig(PretrainedConfig): ...@@ -287,6 +290,7 @@ class TransfoXLConfig(PretrainedConfig):
"or the path to a pretrained model config file (str)") "or the path to a pretrained model config file (str)")
class PositionalEmbedding(nn.Module): class PositionalEmbedding(nn.Module):
def __init__(self, demb): def __init__(self, demb):
super(PositionalEmbedding, self).__init__() super(PositionalEmbedding, self).__init__()
...@@ -306,6 +310,7 @@ class PositionalEmbedding(nn.Module): ...@@ -306,6 +310,7 @@ class PositionalEmbedding(nn.Module):
return pos_emb[:,None,:] return pos_emb[:,None,:]
class PositionwiseFF(nn.Module): class PositionwiseFF(nn.Module):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False): def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
super(PositionwiseFF, self).__init__() super(PositionwiseFF, self).__init__()
...@@ -341,11 +346,14 @@ class PositionwiseFF(nn.Module): ...@@ -341,11 +346,14 @@ class PositionwiseFF(nn.Module):
return output return output
class MultiHeadAttn(nn.Module): class MultiHeadAttn(nn.Module):
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0, def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
pre_lnorm=False, r_r_bias=None, r_w_bias=None): pre_lnorm=False, r_r_bias=None, r_w_bias=None, output_attentions=False):
super(MultiHeadAttn, self).__init__() super(MultiHeadAttn, self).__init__()
self.output_attentions = output_attentions
self.n_head = n_head self.n_head = n_head
self.d_model = d_model self.d_model = d_model
self.d_head = d_head self.d_head = d_head
...@@ -371,7 +379,7 @@ class MultiHeadAttn(nn.Module): ...@@ -371,7 +379,7 @@ class MultiHeadAttn(nn.Module):
self.r_r_bias = r_r_bias self.r_r_bias = r_r_bias
self.r_w_bias = r_w_bias self.r_w_bias = r_w_bias
def forward(self, h, attn_mask=None, mems=None): def forward(self, h, attn_mask=None, mems=None, head_mask=None):
##### multihead attention ##### multihead attention
# [hlen x bsz x n_head x d_head] # [hlen x bsz x n_head x d_head]
...@@ -404,6 +412,10 @@ class MultiHeadAttn(nn.Module): ...@@ -404,6 +412,10 @@ class MultiHeadAttn(nn.Module):
attn_prob = F.softmax(attn_score, dim=1) attn_prob = F.softmax(attn_score, dim=1)
attn_prob = self.dropatt(attn_prob) attn_prob = self.dropatt(attn_prob)
# Mask heads if we want to
if head_mask is not None:
attn_prob = attn_prob * head_mask
# [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head] # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v)) attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
attn_vec = attn_vec.contiguous().view( attn_vec = attn_vec.contiguous().view(
...@@ -415,19 +427,23 @@ class MultiHeadAttn(nn.Module): ...@@ -415,19 +427,23 @@ class MultiHeadAttn(nn.Module):
if self.pre_lnorm: if self.pre_lnorm:
##### residual connection ##### residual connection
output = h + attn_out outputs = [h + attn_out]
else: else:
##### residual connection + layer normalization ##### residual connection + layer normalization
output = self.layer_norm(h + attn_out) outputs = [self.layer_norm(h + attn_out)]
return output if self.output_attentions:
outputs.append(attn_prob)
return outputs
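The optional `head_mask` multiplied into `attn_prob` above is expected to broadcast against attention probabilities of shape [qlen, klen, bsz, n_head]; the per-layer expansion happens further down in `TransfoXLModel._forward`. A standalone shape sketch (made-up sizes):
```
import torch

n_layer, n_head = 6, 8
head_mask = torch.ones(n_head)
head_mask[3] = 0.0                       # drop head 3 in every layer

if head_mask.dim() == 1:
    head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
    head_mask = head_mask.expand(n_layer, -1, -1, -1, -1)

# one broadcastable mask per layer, multiplying attn_prob of shape [qlen, klen, bsz, n_head]
assert head_mask.shape == (n_layer, 1, 1, 1, n_head)
```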
class RelMultiHeadAttn(nn.Module): class RelMultiHeadAttn(nn.Module):
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0, def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False, tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
r_r_bias=None, r_w_bias=None): r_r_bias=None, r_w_bias=None, output_attentions=False):
super(RelMultiHeadAttn, self).__init__() super(RelMultiHeadAttn, self).__init__()
self.output_attentions = output_attentions
self.n_head = n_head self.n_head = n_head
self.d_model = d_model self.d_model = d_model
self.d_head = d_head self.d_head = d_head
...@@ -506,7 +522,7 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn): ...@@ -506,7 +522,7 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False) self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)
def forward(self, w, r, attn_mask=None, mems=None): def forward(self, w, r, attn_mask=None, mems=None, head_mask=None):
qlen, rlen, bsz = w.size(0), r.size(0), w.size(1) qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)
if mems is not None: if mems is not None:
...@@ -561,6 +577,10 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn): ...@@ -561,6 +577,10 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
attn_prob = F.softmax(attn_score, dim=1) attn_prob = F.softmax(attn_score, dim=1)
attn_prob = self.dropatt(attn_prob) attn_prob = self.dropatt(attn_prob)
# Mask heads if we want to
if head_mask is not None:
attn_prob = attn_prob * head_mask
#### compute attention vector #### compute attention vector
attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v)) attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))
...@@ -574,18 +594,21 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn): ...@@ -574,18 +594,21 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
if self.pre_lnorm: if self.pre_lnorm:
##### residual connection ##### residual connection
output = w + attn_out outputs = [w + attn_out]
else: else:
##### residual connection + layer normalization ##### residual connection + layer normalization
output = self.layer_norm(w + attn_out) outputs = [self.layer_norm(w + attn_out)]
return output if self.output_attentions:
outputs.append(attn_prob)
return outputs
class RelLearnableMultiHeadAttn(RelMultiHeadAttn): class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs) super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None): def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None, head_mask=None):
# r_emb: [klen, n_head, d_head], used for term B # r_emb: [klen, n_head, d_head], used for term B
# r_w_bias: [n_head, d_head], used for term C # r_w_bias: [n_head, d_head], used for term C
# r_bias: [klen, n_head], used for term D # r_bias: [klen, n_head], used for term D
...@@ -646,6 +669,9 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn): ...@@ -646,6 +669,9 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
attn_prob = F.softmax(attn_score, dim=1) attn_prob = F.softmax(attn_score, dim=1)
attn_prob = self.dropatt(attn_prob) attn_prob = self.dropatt(attn_prob)
if head_mask is not None:
attn_prob = attn_prob * head_mask
#### compute attention vector #### compute attention vector
attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v)) attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))
...@@ -659,12 +685,17 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn): ...@@ -659,12 +685,17 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
if self.pre_lnorm: if self.pre_lnorm:
##### residual connection ##### residual connection
output = w + attn_out outputs = [w + attn_out]
else: else:
##### residual connection + layer normalization ##### residual connection + layer normalization
output = self.layer_norm(w + attn_out) outputs = [self.layer_norm(w + attn_out)]
if self.output_attentions:
outputs.append(attn_prob)
return outputs
return output
class DecoderLayer(nn.Module): class DecoderLayer(nn.Module):
def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs): def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
...@@ -674,13 +705,15 @@ class DecoderLayer(nn.Module): ...@@ -674,13 +705,15 @@ class DecoderLayer(nn.Module):
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm')) pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, dec_attn_mask=None, mems=None): def forward(self, dec_inp, dec_attn_mask=None, mems=None, head_mask=None):
output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask, attn_outputs = self.dec_attn(dec_inp, attn_mask=dec_attn_mask,
mems=mems) mems=mems, head_mask=head_mask)
output = self.pos_ff(output) ff_output = self.pos_ff(attn_outputs[0])
return output outputs = [ff_output] + attn_outputs[1:]
return outputs
class RelLearnableDecoderLayer(nn.Module): class RelLearnableDecoderLayer(nn.Module):
def __init__(self, n_head, d_model, d_head, d_inner, dropout, def __init__(self, n_head, d_model, d_head, d_inner, dropout,
...@@ -692,14 +725,16 @@ class RelLearnableDecoderLayer(nn.Module): ...@@ -692,14 +725,16 @@ class RelLearnableDecoderLayer(nn.Module):
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm')) pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None): def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None, head_mask=None):
output = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias, attn_outputs = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
attn_mask=dec_attn_mask, attn_mask=dec_attn_mask,
mems=mems) mems=mems, head_mask=head_mask)
output = self.pos_ff(output) ff_output = self.pos_ff(attn_outputs[0])
return output outputs = [ff_output] + attn_outputs[1:]
return outputs
class RelPartialLearnableDecoderLayer(nn.Module): class RelPartialLearnableDecoderLayer(nn.Module):
def __init__(self, n_head, d_model, d_head, d_inner, dropout, def __init__(self, n_head, d_model, d_head, d_inner, dropout,
...@@ -711,14 +746,17 @@ class RelPartialLearnableDecoderLayer(nn.Module): ...@@ -711,14 +746,17 @@ class RelPartialLearnableDecoderLayer(nn.Module):
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm')) pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, r, dec_attn_mask=None, mems=None): def forward(self, dec_inp, r, dec_attn_mask=None, mems=None, head_mask=None):
output = self.dec_attn(dec_inp, r, attn_outputs = self.dec_attn(dec_inp, r,
attn_mask=dec_attn_mask, attn_mask=dec_attn_mask,
mems=mems) mems=mems, head_mask=head_mask)
output = self.pos_ff(output) ff_output = self.pos_ff(attn_outputs[0])
outputs = [ff_output] + attn_outputs[1:]
return outputs
return output
class AdaptiveEmbedding(nn.Module): class AdaptiveEmbedding(nn.Module):
...@@ -791,13 +829,9 @@ class TransfoXLPreTrainedModel(PreTrainedModel): ...@@ -791,13 +829,9 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
""" """
config_class = TransfoXLConfig config_class = TransfoXLConfig
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_transfo_xl load_tf_weights = load_tf_weights_in_transfo_xl
base_model_prefix = "transformer" base_model_prefix = "transformer"
def __init__(self, *inputs, **kwargs):
super(TransfoXLPreTrainedModel, self).__init__(*inputs, **kwargs)
def _init_weight(self, weight): def _init_weight(self, weight):
if self.config.init == 'uniform': if self.config.init == 'uniform':
nn.init.uniform_(weight, -self.config.init_range, self.config.init_range) nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)
...@@ -894,6 +928,9 @@ class TransfoXLModel(TransfoXLPreTrainedModel): ...@@ -894,6 +928,9 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
""" """
def __init__(self, config): def __init__(self, config):
super(TransfoXLModel, self).__init__(config) super(TransfoXLModel, self).__init__(config)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.n_token = config.n_token self.n_token = config.n_token
self.d_embed = config.d_embed self.d_embed = config.d_embed
...@@ -928,7 +965,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel): ...@@ -928,7 +965,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len, tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len,
dropatt=config.dropatt, pre_lnorm=config.pre_lnorm, dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
r_w_bias=None if config.untie_r else self.r_w_bias, r_w_bias=None if config.untie_r else self.r_w_bias,
r_r_bias=None if config.untie_r else self.r_r_bias) r_r_bias=None if config.untie_r else self.r_r_bias,
output_attentions=self.output_attentions)
) )
elif config.attn_type == 1: # learnable embeddings elif config.attn_type == 1: # learnable embeddings
for i in range(config.n_layer): for i in range(config.n_layer):
...@@ -938,7 +976,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel): ...@@ -938,7 +976,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len, tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len,
dropatt=config.dropatt, pre_lnorm=config.pre_lnorm, dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
r_w_bias=None if config.untie_r else self.r_w_bias, r_w_bias=None if config.untie_r else self.r_w_bias,
r_r_bias=None if config.untie_r else self.r_r_bias) r_r_bias=None if config.untie_r else self.r_r_bias,
output_attentions=self.output_attentions)
) )
elif config.attn_type in [2, 3]: # absolute embeddings elif config.attn_type in [2, 3]: # absolute embeddings
for i in range(config.n_layer): for i in range(config.n_layer):
...@@ -947,7 +986,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel): ...@@ -947,7 +986,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout, config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
dropatt=config.dropatt, pre_lnorm=config.pre_lnorm, dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
r_w_bias=None if config.untie_r else self.r_w_bias, r_w_bias=None if config.untie_r else self.r_w_bias,
r_r_bias=None if config.untie_r else self.r_r_bias) r_r_bias=None if config.untie_r else self.r_r_bias,
output_attentions=self.output_attentions)
) )
self.same_length = config.same_length self.same_length = config.same_length
...@@ -965,17 +1005,21 @@ class TransfoXLModel(TransfoXLPreTrainedModel): ...@@ -965,17 +1005,21 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
elif self.attn_type == 3: # absolute deeper SA elif self.attn_type == 3: # absolute deeper SA
self.r_emb = nn.Parameter(torch.Tensor( self.r_emb = nn.Parameter(torch.Tensor(
self.n_layer, self.max_klen, self.n_head, self.d_head)) self.n_layer, self.max_klen, self.n_head, self.d_head))
self.apply(self.init_weights) self.apply(self.init_weights)
def backward_compatible(self): def backward_compatible(self):
self.sample_softmax = -1 self.sample_softmax = -1
def reset_length(self, tgt_len, ext_len, mem_len): def reset_length(self, tgt_len, ext_len, mem_len):
self.tgt_len = tgt_len self.tgt_len = tgt_len
self.mem_len = mem_len self.mem_len = mem_len
self.ext_len = ext_len self.ext_len = ext_len
def _prune_heads(self, heads):
logger.info("Head pruning is not implemented for Transformer-XL model")
pass
def init_mems(self, data): def init_mems(self, data):
if self.mem_len > 0: if self.mem_len > 0:
mems = [] mems = []
...@@ -1012,9 +1056,24 @@ class TransfoXLModel(TransfoXLPreTrainedModel): ...@@ -1012,9 +1056,24 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
return new_mems return new_mems
def _forward(self, dec_inp, mems=None): def _forward(self, dec_inp, mems=None, head_mask=None):
qlen, bsz = dec_inp.size() qlen, bsz = dec_inp.size()
# Prepare head mask if needed
# 1.0 in head_mask indicates we keep the head
# attn_prob has shape qlen x klen x bsz x n_head
# input head_mask has shape [n_head] or [n_layer x n_head] (a head_mask for each layer)
# and head_mask is converted to a shape that broadcasts to [n_layer x qlen x klen x bsz x n_head]
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
else:
head_mask = [None] * self.n_layer
word_emb = self.word_emb(dec_inp) word_emb = self.word_emb(dec_inp)
mlen = mems[0].size(0) if mems is not None else 0 mlen = mems[0].size(0) if mems is not None else 0
...@@ -1033,6 +1092,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel): ...@@ -1033,6 +1092,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None] word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]
hids = [] hids = []
attentions = []
if self.attn_type == 0: # default if self.attn_type == 0: # default
pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device, pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
dtype=word_emb.dtype) dtype=word_emb.dtype)
...@@ -1046,7 +1106,11 @@ class TransfoXLModel(TransfoXLPreTrainedModel): ...@@ -1046,7 +1106,11 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hids.append(core_out) hids.append(core_out)
mems_i = None if mems is None else mems[i] mems_i = None if mems is None else mems[i]
core_out = layer(core_out, pos_emb, dec_attn_mask=dec_attn_mask, mems=mems_i) layer_outputs = layer(core_out, pos_emb, dec_attn_mask=dec_attn_mask,
mems=mems_i, head_mask=head_mask[i])
core_out = layer_outputs[0]
if self.output_attentions:
attentions.append(layer_outputs[1])
elif self.attn_type == 1: # learnable elif self.attn_type == 1: # learnable
core_out = self.drop(word_emb) core_out = self.drop(word_emb)
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
...@@ -1058,8 +1122,12 @@ class TransfoXLModel(TransfoXLPreTrainedModel): ...@@ -1058,8 +1122,12 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
r_emb, r_bias = self.r_emb[i], self.r_bias[i] r_emb, r_bias = self.r_emb[i], self.r_bias[i]
mems_i = None if mems is None else mems[i] mems_i = None if mems is None else mems[i]
core_out = layer(core_out, r_emb, self.r_w_bias[i], layer_outputs = layer(core_out, r_emb, self.r_w_bias[i],
r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i) r_bias, dec_attn_mask=dec_attn_mask,
mems=mems_i, head_mask=head_mask[i])
core_out = layer_outputs[0]
if self.output_attentions:
attentions.append(layer_outputs[1])
elif self.attn_type == 2: # absolute elif self.attn_type == 2: # absolute
pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device, pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
dtype=word_emb.dtype) dtype=word_emb.dtype)
...@@ -1074,8 +1142,11 @@ class TransfoXLModel(TransfoXLPreTrainedModel): ...@@ -1074,8 +1142,11 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
mems_i = None if mems is None else mems[i] mems_i = None if mems is None else mems[i]
if mems_i is not None and i == 0: if mems_i is not None and i == 0:
mems_i += pos_emb[:mlen] mems_i += pos_emb[:mlen]
core_out = layer(core_out, dec_attn_mask=dec_attn_mask, layer_outputs = layer(core_out, dec_attn_mask=dec_attn_mask,
mems=mems_i) mems=mems_i, head_mask=head_mask[i])
core_out = layer_outputs[0]
if self.output_attentions:
attentions.append(layer_outputs[1])
elif self.attn_type == 3: elif self.attn_type == 3:
core_out = self.drop(word_emb) core_out = self.drop(word_emb)
...@@ -1093,16 +1164,30 @@ class TransfoXLModel(TransfoXLPreTrainedModel): ...@@ -1093,16 +1164,30 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
mems_i += cur_emb.view(mlen, 1, -1) mems_i += cur_emb.view(mlen, 1, -1)
core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1) core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)
core_out = layer(core_out, dec_attn_mask=dec_attn_mask, layer_outputs = layer(core_out, dec_attn_mask=dec_attn_mask,
mems=mems_i) mems=mems_i, head_mask=head_mask[i])
core_out = layer_outputs[0]
if self.output_attentions:
attentions.append(layer_outputs[1])
core_out = self.drop(core_out) core_out = self.drop(core_out)
new_mems = self._update_mems(hids, mems, mlen, qlen) new_mems = self._update_mems(hids, mems, mlen, qlen)
return core_out, new_mems # We transpose back here to shape [bsz, len, hidden_dim]
outputs = [core_out.transpose(0, 1).contiguous(), new_mems]
def forward(self, input_ids, mems=None): if self.output_hidden_states:
# Add last layer and transpose to library standard shape [bsz, len, hidden_dim]
hids.append(core_out)
hids = list(t.transpose(0, 1).contiguous() for t in hids)
outputs.append(hids)
if self.output_attentions:
# Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
attentions = list(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
outputs.append(attentions)
return outputs # last hidden state, new_mems, (all hidden states), (all attentions)
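The trailing transposes convert the time-major internal tensors to the library-standard batch-major shapes. A standalone shape check of just those two conversions (random tensors, toy sizes):
```
import torch

qlen, klen, bsz, d_model, n_head = 5, 7, 2, 16, 4
core_out = torch.randn(qlen, bsz, d_model)          # time-major hidden state
attn_prob = torch.randn(qlen, klen, bsz, n_head)    # time-major attention probs

last_hidden = core_out.transpose(0, 1).contiguous()      # [bsz, qlen, d_model]
attentions = attn_prob.permute(2, 3, 0, 1).contiguous()  # [bsz, n_head, qlen, klen]

assert last_hidden.shape == (bsz, qlen, d_model)
assert attentions.shape == (bsz, n_head, qlen, klen)
```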
def forward(self, input_ids, mems=None, head_mask=None):
""" Params: """ Params:
input_ids :: [bsz, len] input_ids :: [bsz, len]
mems :: optional mems from previous forward passes (or init_mems) mems :: optional mems from previous forward passes (or init_mems)
...@@ -1122,11 +1207,9 @@ class TransfoXLModel(TransfoXLPreTrainedModel): ...@@ -1122,11 +1207,9 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
if mems is None: if mems is None:
mems = self.init_mems(input_ids) mems = self.init_mems(input_ids)
last_hidden, new_mems = self._forward(input_ids, mems=mems) outputs = self._forward(input_ids, mems=mems, head_mask=head_mask)
# We transpose back here to shape [bsz, len, hidden_dim] return outputs # last hidden state, new_mems, (all hidden states), (all attentions)
last_hidden = last_hidden.transpose(0, 1).contiguous()
return (last_hidden, new_mems)
class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
...@@ -1218,7 +1301,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): ...@@ -1218,7 +1301,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
def init_mems(self, data): def init_mems(self, data):
return self.transformer.init_mems(data) return self.transformer.init_mems(data)
def forward(self, input_ids, labels=None, mems=None): def forward(self, input_ids, labels=None, mems=None, head_mask=None):
""" Params: """ Params:
input_ids :: [bsz, len] input_ids :: [bsz, len]
labels :: [bsz, len] labels :: [bsz, len]
...@@ -1235,19 +1318,26 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): ...@@ -1235,19 +1318,26 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
bsz = input_ids.size(0) bsz = input_ids.size(0)
tgt_len = input_ids.size(1) tgt_len = input_ids.size(1)
last_hidden, new_mems = self.transformer(input_ids, mems) transformer_outputs = self.transformer(input_ids, mems, head_mask)
last_hidden = transformer_outputs[0]
pred_hid = last_hidden[:, -tgt_len:] pred_hid = last_hidden[:, -tgt_len:]
outputs = transformer_outputs[1:]
if self.sample_softmax > 0 and self.training: if self.sample_softmax > 0 and self.training:
assert self.config.tie_weight assert self.config.tie_weight
logit = sample_logits(self.transformer.word_emb, self.out_layer.bias, labels, pred_hid, self.sampler) logit = sample_logits(self.transformer.word_emb, self.out_layer.bias, labels, pred_hid, self.sampler)
softmax_output = -F.log_softmax(logit, -1)[:, :, 0] softmax_output = -F.log_softmax(logit, -1)[:, :, 0]
outputs = [softmax_output] + outputs
if labels is not None:
# TODO: This is not implemented
raise NotImplementedError
else: else:
softmax_output = self.crit(pred_hid.view(-1, pred_hid.size(-1)), labels) softmax_output = self.crit(pred_hid.view(-1, pred_hid.size(-1)), labels)
if labels is None: if labels is None:
softmax_output = softmax_output.view(bsz, tgt_len, -1) softmax_output = softmax_output.view(bsz, tgt_len, -1)
outputs = [softmax_output] + outputs
else: else:
softmax_output = softmax_output.view(bsz, tgt_len) softmax_output = softmax_output.view(bsz, tgt_len)
outputs = [softmax_output, None] + outputs
# We transpose back return outputs # (loss), logits (or None when labels are provided, to speed up adaptive softmax), new_mems, (all hidden states), (all attentions)
return (softmax_output, new_mems)
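Because the leading elements now depend on whether labels were passed, a small helper can make the unpacking explicit. This sketch follows the ordering in the return comment and ignores the sampled-softmax training path (it is not part of the library):
```
def unpack_transfo_xl_lm_outputs(outputs, has_labels):
    # outputs is assumed to come from TransfoXLLMHeadModel.forward
    if has_labels:
        loss, logits = outputs[0], outputs[1]   # logits is None with adaptive softmax
        new_mems = outputs[2]
    else:
        loss, logits = None, outputs[0]
        new_mems = outputs[1]
    return loss, logits, new_mems
```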
...@@ -73,6 +73,7 @@ class XLMConfig(PretrainedConfig): ...@@ -73,6 +73,7 @@ class XLMConfig(PretrainedConfig):
def __init__(self, def __init__(self,
vocab_size_or_config_json_file, vocab_size_or_config_json_file,
causal=True,
d_model=1024, d_model=1024,
n_layer=24, n_layer=24,
n_head=16, n_head=16,
...@@ -145,6 +146,7 @@ class XLMConfig(PretrainedConfig): ...@@ -145,6 +146,7 @@ class XLMConfig(PretrainedConfig):
self.__dict__[key] = value self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int): elif isinstance(vocab_size_or_config_json_file, int):
self.n_token = vocab_size_or_config_json_file self.n_token = vocab_size_or_config_json_file
self.causal = causal
self.d_model = d_model self.d_model = d_model
self.n_layer = n_layer self.n_layer = n_layer
self.n_head = n_head self.n_head = n_head
...@@ -396,7 +398,6 @@ class XLMPreTrainedModel(PreTrainedModel): ...@@ -396,7 +398,6 @@ class XLMPreTrainedModel(PreTrainedModel):
""" """
config_class = XLMConfig config_class = XLMConfig
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
load_tf_weights = None load_tf_weights = None
base_model_prefix = "xlm" base_model_prefix = "xlm"
...@@ -429,7 +430,7 @@ class XLMModel(XLMPreTrainedModel): ...@@ -429,7 +430,7 @@ class XLMModel(XLMPreTrainedModel):
'hidden_dim', 'dropout', 'attention_dropout', 'asm', 'hidden_dim', 'dropout', 'attention_dropout', 'asm',
'asm_cutoffs', 'asm_div_value'] 'asm_cutoffs', 'asm_div_value']
def __init__(self, params, output_attentions=False, keep_multihead_output=False): #, dico, is_encoder, with_output): def __init__(self, params, output_attentions=False, output_hidden_states=False): #, dico, is_encoder, with_output):
""" XLM model from: "Cross-lingual Language Model Pretraining" by Guillaume Lample, Alexis Conneau """ XLM model from: "Cross-lingual Language Model Pretraining" by Guillaume Lample, Alexis Conneau
Paper: https://arxiv.org/abs/1901.07291 Paper: https://arxiv.org/abs/1901.07291
Original code: https://github.com/facebookresearch/XLM Original code: https://github.com/facebookresearch/XLM
...@@ -483,11 +484,13 @@ class XLMModel(XLMPreTrainedModel): ...@@ -483,11 +484,13 @@ class XLMModel(XLMPreTrainedModel):
""" """
super(XLMModel, self).__init__(params) super(XLMModel, self).__init__(params)
self.output_attentions = output_attentions self.output_attentions = output_attentions
self.output_hidden_states = output_hidden_states
# encoder / decoder, output layer # encoder / decoder, output layer
# self.is_encoder = is_encoder # self.is_encoder = is_encoder
# self.is_decoder = not is_encoder # self.is_decoder = not is_encoder
# self.with_output = with_output # self.with_output = with_output
self.causal = params.causal
# dictionary / languages # dictionary / languages
self.n_langs = params.n_langs self.n_langs = params.n_langs
...@@ -536,63 +539,45 @@ class XLMModel(XLMPreTrainedModel): ...@@ -536,63 +539,45 @@ class XLMModel(XLMPreTrainedModel):
self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, dropout=self.dropout, gelu_activation=params.gelu_activation)) self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, dropout=self.dropout, gelu_activation=params.gelu_activation))
self.layer_norm2.append(nn.LayerNorm(self.dim, eps=1e-12)) self.layer_norm2.append(nn.LayerNorm(self.dim, eps=1e-12))
# output layer def forward(self, x, lengths, positions=None, langs=None, cache=None, head_mask=None): # src_enc=None, src_len=None,
# if self.with_output:
# self.pred_layer = PredLayer(params)
# if params.share_inout_emb:
# self.pred_layer.proj.weight = self.embeddings.weight
# def forward(self, mode, **kwargs):
# """
# Forward function with different forward modes.
# ### Small hack to handle PyTorch distributed.
# """
# if mode == 'fwd':
# return self.fwd(**kwargs)
# elif mode == 'predict':
# return self.predict(**kwargs)
# else:
# raise Exception("Unknown mode: %s" % mode)
def forward(self, x, lengths, causal, src_enc=None, src_len=None, positions=None, langs=None, cache=None):
""" """
Inputs: Inputs:
`x` LongTensor(slen, bs), containing word indices `x` LongTensor(bs, slen), containing word indices
`lengths` LongTensor(bs), containing the length of each sentence `lengths` LongTensor(bs), containing the length of each sentence
`causal` Boolean, if True, the attention is only done over previous hidden states `causal` Boolean, if True, the attention is only done over previous hidden states
`positions` LongTensor(slen, bs), containing word positions `positions` LongTensor(bs, slen), containing word positions
`langs` LongTensor(slen, bs), containing language IDs `langs` LongTensor(bs, slen), containing language IDs
""" """
# lengths = (x != self.pad_index).float().sum(dim=1) # lengths = (x != self.pad_index).float().sum(dim=1)
# mask = x != self.pad_index # mask = x != self.pad_index
# check inputs # check inputs
slen, bs = x.size() bs, slen = x.size()
assert lengths.size(0) == bs assert lengths.size(0) == bs
assert lengths.max().item() <= slen assert lengths.max().item() <= slen
x = x.transpose(0, 1) # batch size as dimension 0 # x = x.transpose(0, 1) # batch size as dimension 0
assert (src_enc is None) == (src_len is None) # assert (src_enc is None) == (src_len is None)
if src_enc is not None: # if src_enc is not None:
assert self.is_decoder # assert self.is_decoder
assert src_enc.size(0) == bs # assert src_enc.size(0) == bs
# generate masks # generate masks
mask, attn_mask = get_masks(slen, lengths, causal) mask, attn_mask = get_masks(slen, lengths, self.causal)
if self.is_decoder and src_enc is not None: # if self.is_decoder and src_enc is not None:
src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None] # src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
# positions # positions
if positions is None: if positions is None:
positions = x.new(slen).long() positions = x.new(slen).long()
positions = torch.arange(slen, out=positions).unsqueeze(0) positions = torch.arange(slen, out=positions).unsqueeze(0)
else: else:
assert positions.size() == (slen, bs) assert positions.size() == (bs, slen) # (slen, bs)
positions = positions.transpose(0, 1) # positions = positions.transpose(0, 1)
# langs # langs
if langs is not None: if langs is not None:
assert langs.size() == (slen, bs) assert langs.size() == (bs, slen) # (slen, bs)
langs = langs.transpose(0, 1) # langs = langs.transpose(0, 1)
# do not recompute cached elements # do not recompute cached elements
if cache is not None: if cache is not None:
...@@ -614,620 +599,50 @@ class XLMModel(XLMPreTrainedModel): ...@@ -614,620 +599,50 @@ class XLMModel(XLMPreTrainedModel):
tensor *= mask.unsqueeze(-1).to(tensor.dtype) tensor *= mask.unsqueeze(-1).to(tensor.dtype)
# transformer layers # transformer layers
hidden_states = []
attentions = []
for i in range(self.n_layers): for i in range(self.n_layers):
if self.output_hidden_states:
hidden_states.append(tensor)
# self attention # self attention
attn = self.attentions[i](tensor, attn_mask, cache=cache) attn_outputs = self.attentions[i](tensor, attn_mask, cache=cache, head_mask=head_mask[i])
attn = attn_outputs[0]
if self.output_attentions:
attentions.append(attn_outputs[1])
attn = F.dropout(attn, p=self.dropout, training=self.training) attn = F.dropout(attn, p=self.dropout, training=self.training)
tensor = tensor + attn tensor = tensor + attn
tensor = self.layer_norm1[i](tensor) tensor = self.layer_norm1[i](tensor)
# encoder attention (for decoder only) # encoder attention (for decoder only)
if self.is_decoder and src_enc is not None: # if self.is_decoder and src_enc is not None:
attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache) # attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache)
attn = F.dropout(attn, p=self.dropout, training=self.training) # attn = F.dropout(attn, p=self.dropout, training=self.training)
tensor = tensor + attn # tensor = tensor + attn
tensor = self.layer_norm15[i](tensor) # tensor = self.layer_norm15[i](tensor)
# FFN # FFN
tensor = tensor + self.ffns[i](tensor) tensor = tensor + self.ffns[i](tensor)
tensor = self.layer_norm2[i](tensor) tensor = self.layer_norm2[i](tensor)
tensor *= mask.unsqueeze(-1).to(tensor.dtype) tensor *= mask.unsqueeze(-1).to(tensor.dtype)
# Add last hidden state
if self.output_hidden_states:
hidden_states.append(tensor)
# update cache length # update cache length
if cache is not None: if cache is not None:
cache['slen'] += tensor.size(1) cache['slen'] += tensor.size(1)
# move back sequence length to dimension 0 # move back sequence length to dimension 0
tensor = tensor.transpose(0, 1) # tensor = tensor.transpose(0, 1)
return tensor
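The rewritten forward pass expects batch-first inputs of shape (bs, slen) together with per-sentence lengths. A sketch of building such inputs (toy sizes; the final model call is commented out since no model is constructed here):
```
import torch

bs, slen = 2, 6
x = torch.randint(0, 100, (bs, slen))            # word indices, shape (bs, slen)
lengths = torch.tensor([6, 4])                   # number of real tokens per sentence, max <= slen
positions = torch.arange(slen).unsqueeze(0).expand(bs, slen)  # optional, (bs, slen)
langs = torch.zeros(bs, slen, dtype=torch.long)  # optional language IDs, (bs, slen)
# outputs = model(x, lengths, positions=positions, langs=langs)  # assumes an XLMModel instance
```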
def predict(self, tensor, pred_mask, y, get_scores):
"""
Given the last hidden state, compute word scores and/or the loss.
`pred_mask` is a ByteTensor of shape (slen, bs), filled with 1 when
we need to predict a word
`y` is a LongTensor of shape (pred_mask.sum(),)
`get_scores` is a boolean specifying whether we need to return scores
"""
masked_tensor = tensor[pred_mask.unsqueeze(-1).expand_as(tensor)].view(-1, self.dim)
scores, loss = self.pred_layer(masked_tensor, y, get_scores)
return scores, loss
def generate(self, src_enc, src_len, tgt_lang_id, max_len=200, sample_temperature=None):
"""
Decode a sentence given initial start.
`x`:
- LongTensor(bs, slen)
<EOS> W1 W2 W3 <EOS> <PAD>
<EOS> W1 W2 W3 W4 <EOS>
`lengths`:
- LongTensor(bs) [5, 6]
`positions`:
- False, for regular "arange" positions (LM)
- True, to reset positions from the new generation (MT)
`langs`:
- must be None if the model only supports one language
- lang_id if only one language is involved (LM)
- (lang_id1, lang_id2) if two languages are involved (MT)
"""
# input batch
bs = len(src_len)
assert src_enc.size(0) == bs
# generated sentences
generated = src_len.new(max_len, bs) # upcoming output
generated.fill_(self.pad_index) # fill upcoming output with <PAD>
generated[0].fill_(self.eos_index) # we use <EOS> for <BOS> everywhere
# positions
positions = src_len.new(max_len).long()
positions = torch.arange(max_len, out=positions).unsqueeze(1).expand(max_len, bs)
# language IDs
langs = src_len.new(max_len).long().fill_(tgt_lang_id)
langs = langs.unsqueeze(1).expand(max_len, bs)
# current position / max lengths / length of generated sentences / unfinished sentences
cur_len = 1
gen_len = src_len.clone().fill_(1)
unfinished_sents = src_len.clone().fill_(1)
# cache compute states
cache = {'slen': 0}
while cur_len < max_len:
# compute word scores
tensor = self.forward(
'fwd',
x=generated[:cur_len],
lengths=gen_len,
positions=positions[:cur_len],
langs=langs[:cur_len],
causal=True,
src_enc=src_enc,
src_len=src_len,
cache=cache
)
assert tensor.size() == (1, bs, self.dim)
tensor = tensor.data[-1, :, :] # (bs, dim)
scores = self.pred_layer.get_scores(tensor) # (bs, n_words)
# select next words: sample or greedy
if sample_temperature is None:
next_words = torch.topk(scores, 1)[1].squeeze(1)
else:
next_words = torch.multinomial(F.softmax(scores / sample_temperature, dim=1), 1).squeeze(1)
assert next_words.size() == (bs,)
# update generations / lengths / finished sentences / current length
generated[cur_len] = next_words * unfinished_sents + self.pad_index * (1 - unfinished_sents)
gen_len.add_(unfinished_sents)
unfinished_sents.mul_(next_words.ne(self.eos_index).long())
cur_len = cur_len + 1
# stop when there is a </s> in each sentence, or if we exceed the maximum length
if unfinished_sents.max() == 0:
break
# add <EOS> to unfinished sentences
if cur_len == max_len:
generated[-1].masked_fill_(unfinished_sents.byte(), self.eos_index)
# sanity check
assert (generated == self.eos_index).sum() == 2 * bs
return generated[:cur_len], gen_len
def generate_beam(self, src_enc, src_len, tgt_lang_id, beam_size, length_penalty, early_stopping, max_len=200):
"""
Decode a sentence given initial start.
`x`:
- LongTensor(bs, slen)
<EOS> W1 W2 W3 <EOS> <PAD>
<EOS> W1 W2 W3 W4 <EOS>
`lengths`:
- LongTensor(bs) [5, 6]
`positions`:
- False, for regular "arange" positions (LM)
- True, to reset positions from the new generation (MT)
`langs`:
- must be None if the model only supports one language
- lang_id if only one language is involved (LM)
- (lang_id1, lang_id2) if two languages are involved (MT)
"""
# check inputs
assert src_enc.size(0) == src_len.size(0)
assert beam_size >= 1
# batch size / number of words
bs = len(src_len)
n_words = self.n_words
# expand to beam size the source latent representations / source lengths
src_enc = src_enc.unsqueeze(1).expand((bs, beam_size) + src_enc.shape[1:]).contiguous().view((bs * beam_size,) + src_enc.shape[1:])
src_len = src_len.unsqueeze(1).expand(bs, beam_size).contiguous().view(-1)
# generated sentences (batch with beam current hypotheses)
generated = src_len.new(max_len, bs * beam_size) # upcoming output
generated.fill_(self.pad_index) # fill upcoming output with <PAD>
generated[0].fill_(self.eos_index) # we use <EOS> for <BOS> everywhere
# generated hypotheses
generated_hyps = [BeamHypotheses(beam_size, max_len, length_penalty, early_stopping) for _ in range(bs)]
# positions
positions = src_len.new(max_len).long()
positions = torch.arange(max_len, out=positions).unsqueeze(1).expand_as(generated)
# language IDs
langs = positions.clone().fill_(tgt_lang_id)
# scores for each sentence in the beam
beam_scores = src_enc.new(bs, beam_size).fill_(0)
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view(-1)
# current position
cur_len = 1
# cache compute states
cache = {'slen': 0}
# done sentences
done = [False for _ in range(bs)]
while cur_len < max_len:
# compute word scores
tensor = self.forward(
'fwd',
x=generated[:cur_len],
lengths=src_len.new(bs * beam_size).fill_(cur_len),
positions=positions[:cur_len],
langs=langs[:cur_len],
causal=True,
src_enc=src_enc,
src_len=src_len,
cache=cache
)
assert tensor.size() == (1, bs * beam_size, self.dim)
tensor = tensor.data[-1, :, :] # (bs * beam_size, dim)
scores = self.pred_layer.get_scores(tensor) # (bs * beam_size, n_words)
scores = F.log_softmax(scores, dim=-1) # (bs * beam_size, n_words)
assert scores.size() == (bs * beam_size, n_words)
# select next words with scores
_scores = scores + beam_scores[:, None].expand_as(scores) # (bs * beam_size, n_words)
_scores = _scores.view(bs, beam_size * n_words) # (bs, beam_size * n_words)
next_scores, next_words = torch.topk(_scores, 2 * beam_size, dim=1, largest=True, sorted=True)
assert next_scores.size() == next_words.size() == (bs, 2 * beam_size)
# next batch beam content
# list of (bs * beam_size) tuple(next hypothesis score, next word, current position in the batch)
next_batch_beam = []
# for each sentence
for sent_id in range(bs):
# if we are done with this sentence
done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item())
if done[sent_id]:
next_batch_beam.extend([(0, self.pad_index, 0)] * beam_size) # pad the batch
continue
# next sentence beam content
next_sent_beam = []
# next words for this sentence
for idx, value in zip(next_words[sent_id], next_scores[sent_id]):
# get beam and word IDs
beam_id = idx // n_words
word_id = idx % n_words
# end of sentence, or next word
if word_id == self.eos_index or cur_len + 1 == max_len:
generated_hyps[sent_id].add(generated[:cur_len, sent_id * beam_size + beam_id].clone(), value.item())
else:
next_sent_beam.append((value, word_id, sent_id * beam_size + beam_id))
# the beam for next step is full
if len(next_sent_beam) == beam_size:
break
# update next beam content
assert len(next_sent_beam) == 0 if cur_len + 1 == max_len else beam_size
if len(next_sent_beam) == 0:
next_sent_beam = [(0, self.pad_index, 0)] * beam_size # pad the batch
next_batch_beam.extend(next_sent_beam)
assert len(next_batch_beam) == beam_size * (sent_id + 1)
# sanity check / prepare next batch
assert len(next_batch_beam) == bs * beam_size
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
beam_words = generated.new([x[1] for x in next_batch_beam])
beam_idx = src_len.new([x[2] for x in next_batch_beam])
# re-order batch and internal states
generated = generated[:, beam_idx]
generated[cur_len] = beam_words
for k in cache.keys():
if k != 'slen':
cache[k] = (cache[k][0][beam_idx], cache[k][1][beam_idx])
# update current length
cur_len = cur_len + 1
# stop when we are done with each sentence
if all(done):
break
# visualize hypotheses
# print([len(x) for x in generated_hyps], cur_len)
# globals().update( locals() );
# !import code; code.interact(local=vars())
# for ii in range(bs):
# for ss, ww in sorted(generated_hyps[ii].hyp, key=lambda x: x[0], reverse=True):
# print("%.3f " % ss + " ".join(self.dico[x] for x in ww.tolist()))
# print("")
# select the best hypotheses
tgt_len = src_len.new(bs)
best = []
for i, hypotheses in enumerate(generated_hyps):
best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
tgt_len[i] = len(best_hyp) + 1 # +1 for the <EOS> symbol
best.append(best_hyp)
# generate target batch
decoded = src_len.new(tgt_len.max().item(), bs).fill_(self.pad_index)
for i, hypo in enumerate(best):
decoded[:tgt_len[i] - 1, i] = hypo
decoded[tgt_len[i] - 1, i] = self.eos_index
# sanity check
assert (decoded == self.eos_index).sum() == 2 * bs
return decoded, tgt_len
class XLMModel(XLMPreTrainedModel):
def __init__(self, config, output_attentions=False, output_hidden_states=False):
super(XLMModel, self).__init__(config)
self.output_attentions = output_attentions
self.output_hidden_states = output_hidden_states
self.mem_len = config.mem_len
self.reuse_len = config.reuse_len
self.d_model = config.d_model
self.same_length = config.same_length
self.attn_type = config.attn_type
self.bi_data = config.bi_data
self.clamp_len = config.clamp_len
self.word_embedding = nn.Embedding(config.n_token, config.d_model)
self.mask_emb = nn.Parameter(torch.Tensor(1, 1, config.d_model))
layer = XLMLayer(config, output_attentions=output_attentions)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layer)])
self.dropout = nn.Dropout(config.dropout)
def prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
"""
for layer, heads in heads_to_prune.items():
self.layer[layer].attention.prune_heads(heads)
def get_multihead_outputs(self):
""" Gather all multi-head outputs.
Return: list (layers) of multihead module outputs with gradients
"""
return [layer.attention.self.multihead_output for layer in self.layer]
def create_mask(self, qlen, mlen):
""" create causal attention mask.
float mask where 1.0 indicates masked, 0.0 indicates not-masked.
same_length=False: same_length=True:
<mlen > < qlen > <mlen > < qlen >
^ [0 0 0 0 0 1 1 1 1] [0 0 0 0 0 1 1 1 1]
[0 0 0 0 0 0 1 1 1] [1 0 0 0 0 0 1 1 1]
qlen [0 0 0 0 0 0 0 1 1] [1 1 0 0 0 0 0 1 1]
[0 0 0 0 0 0 0 0 1] [1 1 1 0 0 0 0 0 1]
v [0 0 0 0 0 0 0 0 0] [1 1 1 1 0 0 0 0 0]
"""
attn_mask = torch.ones([qlen, qlen])
mask_up = torch.triu(attn_mask, diagonal=1)
attn_mask_pad = torch.zeros([qlen, mlen])
ret = torch.cat([attn_mask_pad, mask_up], dim=1)
if self.same_length:
mask_lo = torch.tril(attn_mask, diagonal=-1)
ret = torch.cat([ret[:, :qlen] + mask_lo, ret[:, qlen:]], dim=1)
ret = ret.to(next(self.parameters()))
return ret
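As a sanity check on the mask layout documented in the docstring above, here is a minimal standalone sketch; the helper name `causal_mask` and the toy sizes (qlen=4, mlen=2) are illustrative only, not part of the library:

```python
import torch

def causal_mask(qlen, mlen, same_length=False):
    # 1.0 marks positions a query must NOT attend to, 0.0 marks visible positions
    attn_mask = torch.ones(qlen, qlen)
    mask_up = torch.triu(attn_mask, diagonal=1)       # future tokens inside the query block
    attn_mask_pad = torch.zeros(qlen, mlen)           # the memory is always visible...
    ret = torch.cat([attn_mask_pad, mask_up], dim=1)  # [qlen, mlen + qlen]
    if same_length:
        # ...unless same_length: also hide the oldest positions so every row sees qlen tokens
        mask_lo = torch.tril(attn_mask, diagonal=-1)
        ret = torch.cat([ret[:, :qlen] + mask_lo, ret[:, qlen:]], dim=1)
    return ret

print(causal_mask(4, 2, same_length=False))
# tensor([[0., 0., 0., 1., 1., 1.],
#         [0., 0., 0., 0., 1., 1.],
#         [0., 0., 0., 0., 0., 1.],
#         [0., 0., 0., 0., 0., 0.]])
print(causal_mask(4, 2, same_length=True))
# tensor([[0., 0., 0., 1., 1., 1.],
#         [1., 0., 0., 0., 1., 1.],
#         [1., 1., 0., 0., 0., 1.],
#         [1., 1., 1., 0., 0., 0.]])
```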
def cache_mem(self, curr_out, prev_mem):
"""cache hidden states into memory."""
if self.mem_len is None or self.mem_len == 0:
return None
else:
if self.reuse_len is not None and self.reuse_len > 0:
curr_out = curr_out[:self.reuse_len]
if prev_mem is None:
new_mem = curr_out[-self.mem_len:]
else:
new_mem = torch.cat([prev_mem, curr_out], dim=0)[-self.mem_len:]
return new_mem.detach()
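The caching rule above is a sliding window over hidden states; a minimal standalone trace, assuming mem_len=3, no reuse_len, and toy tensors of shape [curr_len=2, bsz=1, d_model=4]:

```python
import torch

mem_len = 3
prev_mem = None
for step in range(4):
    curr_out = torch.full((2, 1, 4), float(step))           # hidden states of the current segment
    if prev_mem is None:
        new_mem = curr_out[-mem_len:]
    else:
        new_mem = torch.cat([prev_mem, curr_out], dim=0)[-mem_len:]
    prev_mem = new_mem.detach()                              # memory stays outside the autograd graph

print(prev_mem.shape)     # torch.Size([3, 1, 4])
print(prev_mem[:, 0, 0])  # tensor([2., 3., 3.]) -- only the most recent mem_len positions survive
```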
@staticmethod
def positional_embedding(pos_seq, inv_freq, bsz=None):
sinusoid_inp = torch.einsum('i,d->id', pos_seq, inv_freq)
pos_emb = torch.cat([torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)], dim=-1)
pos_emb = pos_emb[:, None, :]
if bsz is not None:
pos_emb = pos_emb.expand(-1, bsz, -1)
return pos_emb
def relative_positional_encoding(self, qlen, klen, bsz=None):
"""create relative positional encoding."""
freq_seq = torch.arange(0, self.d_model, 2.0, dtype=torch.float)
inv_freq = 1 / (10000 ** (freq_seq / self.d_model))
if self.attn_type == 'bi':
# beg, end = klen - 1, -qlen
beg, end = klen, -qlen
elif self.attn_type == 'uni':
# beg, end = klen - 1, -1
beg, end = klen, -1
else:
raise ValueError('Unknown `attn_type` {}.'.format(self.attn_type))
if self.bi_data:
fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.float)
bwd_pos_seq = torch.arange(-beg, -end, 1.0, dtype=torch.float)
if self.clamp_len > 0:
fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
bwd_pos_seq = bwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
if bsz is not None:
fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz//2)
bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz//2)
else:
fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq)
bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq)
pos_emb = torch.cat([fwd_pos_emb, bwd_pos_emb], dim=1)
else:
fwd_pos_seq = torch.arange(beg, end, -1.0)
if self.clamp_len > 0:
fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
pos_emb = pos_emb.to(next(self.parameters()))
return pos_emb
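Stripped of the model plumbing, the two helpers above reduce to sinusoidal embeddings over a descending relative-position sequence; a minimal shape check under assumed sizes (d_model=8, qlen=4, mlen=2, 'bi' attention, bi_data=False):

```python
import torch

d_model, qlen, mlen, bsz = 8, 4, 2, 2
klen = qlen + mlen

freq_seq = torch.arange(0, d_model, 2.0)
inv_freq = 1 / (10000 ** (freq_seq / d_model))             # [d_model / 2]
pos_seq = torch.arange(klen, -qlen, -1.0)                   # klen, klen-1, ..., -qlen+1

sinusoid_inp = torch.einsum('i,d->id', pos_seq, inv_freq)   # [klen + qlen, d_model / 2]
pos_emb = torch.cat([torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)], dim=-1)
pos_emb = pos_emb[:, None, :].expand(-1, bsz, -1)           # broadcast over the batch
print(pos_emb.shape)                                        # torch.Size([10, 2, 8])
```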
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None, head_mask=None):
"""
Args:
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
input_mask: [optional] float32 Tensor in shape [bsz, len], the input mask.
0 for real tokens and 1 for padding.
attention_mask: [optional] float32 Tensor, SAME FUNCTION as `input_mask`
but with 1 for real tokens and 0 for padding.
Added for easy compatibility with the XLM model (which uses this negative masking).
You can only use one of `input_mask` and `attention_mask`
mems: [optional] a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
from previous batches. The length of the list equals n_layer.
If None, no memory is used.
perm_mask: [optional] float32 Tensor in shape [bsz, len, len].
If perm_mask[k, i, j] = 0, i attends to j in batch k;
if perm_mask[k, i, j] = 1, i does not attend to j in batch k.
If None, each position attends to all the others.
target_mapping: [optional] float32 Tensor in shape [bsz, num_predict, len].
If target_mapping[k, i, j] = 1, the i-th prediction in batch k is
on the j-th token.
Only used during pretraining for partial prediction.
Set to None during finetuning.
inp_q: [optional] float32 Tensor in shape [bsz, len].
1 for tokens with losses and 0 for tokens without losses.
Only used during pretraining for two-stream attention.
Set to None during finetuning.
mem_len: int, the number of tokens to cache.
reuse_len: int, the number of tokens in the current batch to be cached
and reused in the future.
bi_data: bool, whether to use bidirectional input pipeline.
Usually set to True during pretraining and False during finetuning.
clamp_len: int, clamp all relative distances larger than clamp_len.
-1 means no clamping.
same_length: bool, whether to use the same attention length for each token.
summary_type: str, "last", "first", "mean", or "attn". The method
to pool the input to get a vector representation.
"""
# the original code for XLM uses shapes [len, bsz] with the batch dimension at the end
# but we want a unified interface in the library with the batch size on the first dimension
# so we move here the first dimension (batch) to the end
inp_k = inp_k.transpose(0, 1).contiguous()
token_type_ids = token_type_ids.transpose(0, 1).contiguous() if token_type_ids is not None else None
input_mask = input_mask.transpose(0, 1).contiguous() if input_mask is not None else None
attention_mask = attention_mask.transpose(0, 1).contiguous() if attention_mask is not None else None
perm_mask = perm_mask.permute(1, 2, 0).contiguous() if perm_mask is not None else None
target_mapping = target_mapping.permute(1, 2, 0).contiguous() if target_mapping is not None else None
inp_q = inp_q.transpose(0, 1).contiguous() if inp_q is not None else None
qlen, bsz = inp_k.shape[0], inp_k.shape[1]
mlen = mems[0].shape[0] if mems is not None else 0
klen = mlen + qlen
dtype_float = next(self.parameters()).dtype
device = next(self.parameters()).device
##### Attention mask
# causal attention mask
if self.attn_type == 'uni':
attn_mask = self.create_mask(qlen, mlen)
attn_mask = attn_mask[:, :, None, None]
elif self.attn_type == 'bi':
attn_mask = None
else:
raise ValueError('Unsupported attention type: {}'.format(self.attn_type))
# data mask: input mask & perm mask
assert input_mask is None or attention_mask is None, ("You can only use one of input_mask (uses 1 for padding) "
"or attention_mask (uses 0 for padding, added for compatibility with XLM). Please choose one.")
if input_mask is None and attention_mask is not None:
input_mask = 1.0 - attention_mask
if input_mask is not None and perm_mask is not None:
data_mask = input_mask[None] + perm_mask
elif input_mask is not None and perm_mask is None:
data_mask = input_mask[None]
elif input_mask is None and perm_mask is not None:
data_mask = perm_mask
else:
data_mask = None
if data_mask is not None:
# all mems can be attended to
mems_mask = torch.zeros([data_mask.shape[0], mlen, bsz]).to(data_mask)
data_mask = torch.cat([mems_mask, data_mask], dim=1)
if attn_mask is None:
attn_mask = data_mask[:, :, :, None]
else:
attn_mask += data_mask[:, :, :, None]
if attn_mask is not None:
attn_mask = (attn_mask > 0).to(dtype_float)
if attn_mask is not None:
non_tgt_mask = -torch.eye(qlen).to(attn_mask)
non_tgt_mask = torch.cat([torch.zeros([qlen, mlen]).to(attn_mask), non_tgt_mask], dim=-1)
non_tgt_mask = ((attn_mask + non_tgt_mask[:, :, None, None]) > 0).to(attn_mask)
else:
non_tgt_mask = None
##### Word embeddings and prepare h & g hidden states
word_emb_k = self.word_embedding(inp_k)
output_h = self.dropout(word_emb_k)
if inp_q is not None:
if target_mapping is not None:
word_emb_q = self.mask_emb.expand(target_mapping.shape[0], bsz, -1)
else:
inp_q_ext = inp_q[:, :, None]
word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
output_g = self.dropout(word_emb_q)
else:
output_g = None
##### Segment embedding
if token_type_ids is not None:
# Convert `token_type_ids` to one-hot `seg_mat`
mem_pad = torch.zeros([mlen, bsz], dtype=torch.long, device=device)
cat_ids = torch.cat([mem_pad, token_type_ids], dim=0)
# `1` indicates not in the same segment [qlen x klen x bsz]
seg_mat = (token_type_ids[:, None] != cat_ids[None, :]).long()
seg_mat = F.one_hot(seg_mat, num_classes=2).to(dtype_float)
else:
seg_mat = None
##### Positional encoding
pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
pos_emb = self.dropout(pos_emb)
##### Head mask if needed (for bertology/pruning)
# 1.0 in head_mask indicates we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [n_layer x num_heads]
# and head_mask is converted to shape [n_layer x batch x num_heads x seq_length x seq_length]
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
else:
head_mask = [None] * self.config.n_layer
new_mems = []
if mems is None:
mems = [None] * len(self.layer)
hidden_states = []
attentions = []
for i, layer_module in enumerate(self.layer):
# cache new mems
new_mems.append(self.cache_mem(output_h, mems[i]))
# Save hidden_states
if output_g is None:
hidden_states.append(output_h)
else:
hidden_states.append((output_h, output_g))
output_h, output_g = layer_module(output_h, output_g,
attn_mask_h=non_tgt_mask, attn_mask_g=attn_mask,
r=pos_emb, seg_mat=seg_mat,
mems=mems[i], target_mapping=target_mapping,
head_mask=head_mask[i])
# Save last hidden_state
if output_g is None:
hidden_states.append(output_h)
else:
hidden_states.append((output_h, output_g))
# Select the right output and add dropout
output = self.dropout(output_g if output_g is not None else output_h)
# We transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
output = output.permute(1, 0, 2).contiguous()
if output_g is None:
hidden_states = [hs.permute(1, 0, 2).contiguous() for hs in hidden_states]
else:
hidden_states = [h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs]
# Build the list of outputs
outputs = [output, new_mems]
if self.output_attentions:
outputs.append(attentions)
if self.output_hidden_states: if self.output_hidden_states:
outputs.append(hidden_states) outputs.append(hidden_states)
if self.output_attentions:
return outputs outputs.append(attentions)
return outputs # outputs, (hidden_states), (attentions)
class XLMPredLayer(nn.Module): class XLMPredLayer(nn.Module):
...@@ -1275,63 +690,59 @@ class XLMPredLayer(nn.Module): ...@@ -1275,63 +690,59 @@ class XLMPredLayer(nn.Module):
return self.proj.log_prob(x) if self.asm else self.proj(x) return self.proj.log_prob(x) if self.asm else self.proj(x)
class XLMLMHeadModel(XLMPreTrainedModel):
"""XLM model ("XLM: Generalized Autoregressive Pretraining for Language Understanding").
Params: class XLMWithLMHeadModel(XLMPreTrainedModel):
`config`: a XLMConfig class instance with the configuration to build a new model """ XLM model from: "Cross-lingual Language Model Pretraining" by Guillaume Lample, Alexis Conneau
`output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False Paper: https://arxiv.org/abs/1901.07291
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient. Original code: https://github.com/facebookresearch/XLM
This can be used to compute head importance metrics. Default: False
Inputs: Params:
inp_k: int32 Tensor in shape [bsz, len], the input token IDs. `config`: a XLMConfig class instance with the configuration to build a new model
token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs. `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
attention_mask: [optional] float32 Tensor in shape [bsz, len], the input mask. `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
0 for real tokens and 1 for padding. This can be used to compute head importance metrics. Default: False
mems: [optional] a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
from previous batches. The length of the list equals n_layer.
If None, no memory is used.
perm_mask: [optional] float32 Tensor in shape [bsz, len, len].
If perm_mask[k, i, j] = 0, i attend to j in batch k;
if perm_mask[k, i, j] = 1, i does not attend to j in batch k.
If None, each position attends to all the others.
target_mapping: [optional] float32 Tensor in shape [bsz, num_predict, len].
If target_mapping[k, i, j] = 1, the i-th predict in batch k is
on the j-th token.
Only used during pretraining for partial prediction.
Set to None during finetuning.
inp_q: [optional] float32 Tensor in shape [bsz, len].
1 for tokens with losses and 0 for tokens without losses.
Only used during pretraining for two-stream attention.
Set to None during finetuning.
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
`run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see XLM paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
`head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
Outputs: Tuple of (encoded_layers, pooled_output)
`encoded_layers`: controled by `output_all_encoded_layers` argument:
- `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
of each attention block (i.e. 12 full sequences for XLM-base, 24 for XLM-large), each
encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, d_model],
- `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
to the last attention block of shape [batch_size, sequence_length, d_model],
`pooled_output`: a torch.FloatTensor of size [batch_size, d_model] which is the output of a
classifier pretrained on top of the hidden state associated to the first character of the
input (`CLS`) to train on the Next-Sentence task (see XLM's paper).
Example usage: Outputs: Tuple of (encoded_layers, pooled_output)
```python `encoded_layers`: controled by `output_all_encoded_layers` argument:
# Already been converted into WordPiece token ids - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) of each attention block (i.e. 12 full sequences for XLM-base, 24 for XLM-large), each
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
to the last attention block of shape [batch_size, sequence_length, hidden_size],
`pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
classifier pretrained on top of the hidden state associated to the first character of the
input (`CLS`) to train on the Next-Sentence task (see XLM's paper).
config = modeling.XLMConfig(vocab_size_or_config_json_file=32000, d_model=768, Example usage:
n_layer=12, num_attention_heads=12, intermediate_size=3072) ```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
model = modeling.XLMModel(config=config) config = modeling.XLMConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
```
""" model = modeling.XLMModel(config=config)
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, output_attentions=False, output_hidden_states=False): def __init__(self, config, output_attentions=False, output_hidden_states=False):
super(XLMLMHeadModel, self).__init__(config) super(XLMLMHeadModel, self).__init__(config)
self.output_attentions = output_attentions self.output_attentions = output_attentions
...@@ -1341,9 +752,7 @@ class XLMLMHeadModel(XLMPreTrainedModel): ...@@ -1341,9 +752,7 @@ class XLMLMHeadModel(XLMPreTrainedModel):
self.same_length = config.same_length self.same_length = config.same_length
self.transformer = XLMModel(config, output_attentions=output_attentions, output_hidden_states=output_hidden_states) self.transformer = XLMModel(config, output_attentions=output_attentions, output_hidden_states=output_hidden_states)
self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True) self.pred_layer = XLMPredLayer(config)
# Tie weights
self.apply(self.init_weights) self.apply(self.init_weights)
self.tie_weights() self.tie_weights()
...@@ -1351,10 +760,9 @@ class XLMLMHeadModel(XLMPreTrainedModel): ...@@ -1351,10 +760,9 @@ class XLMLMHeadModel(XLMPreTrainedModel):
def tie_weights(self): def tie_weights(self):
""" Make sure we are sharing the embeddings """ Make sure we are sharing the embeddings
""" """
self.lm_loss.weight = self.transformer.word_embedding.weight self.pred_layer.proj.weight = self.transformer.embeddings.weight
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None, def forward(self, x, lengths, positions=None, langs=None, cache=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
labels=None, head_mask=None): labels=None, head_mask=None):
""" """
Args: Args:
...@@ -1382,11 +790,10 @@ class XLMLMHeadModel(XLMPreTrainedModel): ...@@ -1382,11 +790,10 @@ class XLMLMHeadModel(XLMPreTrainedModel):
summary_type: str, "last", "first", "mean", or "attn". The method summary_type: str, "last", "first", "mean", or "attn". The method
to pool the input to get a vector representation. to pool the input to get a vector representation.
""" """
transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask, transformer_outputs = self.transformer(x, lengths, positions=positions, langs=langs, cache=cache, head_mask=head_mask)
mems, perm_mask, target_mapping, inp_q, head_mask)
output = transformer_outputs[0] output = transformer_outputs[0]
logits = self.lm_loss(output) logits = self.pred_layer(output, labels)
outputs = transformer_outputs[1:] # Keep new_mems and attention/hidden states if they are here outputs = transformer_outputs[1:] # Keep new_mems and attention/hidden states if they are here
......
...@@ -198,7 +198,7 @@ class XLNetConfig(PretrainedConfig): ...@@ -198,7 +198,7 @@ class XLNetConfig(PretrainedConfig):
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self, def __init__(self,
vocab_size_or_config_json_file, vocab_size_or_config_json_file=32000,
d_model=1024, d_model=1024,
n_layer=24, n_layer=24,
n_head=16, n_head=16,
...@@ -221,7 +221,12 @@ class XLNetConfig(PretrainedConfig): ...@@ -221,7 +221,12 @@ class XLNetConfig(PretrainedConfig):
bi_data=False, bi_data=False,
clamp_len=-1, clamp_len=-1,
same_length=False, same_length=False,
finetuning_task=None):
finetuning_task=None,
num_labels=2,
summary_type="last",
use_proj=True,
**kwargs):
"""Constructs XLNetConfig. """Constructs XLNetConfig.
Args: Args:
...@@ -265,6 +270,8 @@ class XLNetConfig(PretrainedConfig): ...@@ -265,6 +270,8 @@ class XLNetConfig(PretrainedConfig):
same_length: bool, whether to use the same attention length for each token. same_length: bool, whether to use the same attention length for each token.
finetuning_task: name of the glue task on which the model was fine-tuned if any finetuning_task: name of the glue task on which the model was fine-tuned if any
""" """
super(XLNetConfig, self).__init__(**kwargs)
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)): and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
...@@ -297,7 +304,11 @@ class XLNetConfig(PretrainedConfig): ...@@ -297,7 +304,11 @@ class XLNetConfig(PretrainedConfig):
self.bi_data = bi_data self.bi_data = bi_data
self.clamp_len = clamp_len self.clamp_len = clamp_len
self.same_length = same_length self.same_length = same_length
self.finetuning_task = finetuning_task self.finetuning_task = finetuning_task
self.num_labels = num_labels
self.summary_type = summary_type
self.use_proj = use_proj
else: else:
raise ValueError("First argument must be either a vocabulary size (int)" raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)") "or the path to a pretrained model config file (str)")
...@@ -323,9 +334,10 @@ except ImportError: ...@@ -323,9 +334,10 @@ except ImportError:
return self.weight * x + self.bias return self.weight * x + self.bias
class XLNetRelativeAttention(nn.Module): class XLNetRelativeAttention(nn.Module):
def __init__(self, config, output_attentions=False): def __init__(self, config):
super(XLNetRelativeAttention, self).__init__() super(XLNetRelativeAttention, self).__init__()
self.output_attentions = output_attentions self.output_attentions = config.output_attentions
if config.d_model % config.n_head != 0: if config.d_model % config.n_head != 0:
raise ValueError( raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention " "The hidden size (%d) is not a multiple of the number of attention "
...@@ -533,10 +545,9 @@ class XLNetFeedForward(nn.Module): ...@@ -533,10 +545,9 @@ class XLNetFeedForward(nn.Module):
return output return output
class XLNetLayer(nn.Module): class XLNetLayer(nn.Module):
def __init__(self, config, output_attentions=False, ): def __init__(self, config):
super(XLNetLayer, self).__init__() super(XLNetLayer, self).__init__()
self.output_attentions = output_attentions self.rel_attn = XLNetRelativeAttention(config)
self.rel_attn = XLNetRelativeAttention(config, output_attentions=output_attentions)
self.ff = XLNetFeedForward(config) self.ff = XLNetFeedForward(config)
self.dropout = nn.Dropout(config.dropout) self.dropout = nn.Dropout(config.dropout)
...@@ -562,7 +573,6 @@ class XLNetPreTrainedModel(PreTrainedModel): ...@@ -562,7 +573,6 @@ class XLNetPreTrainedModel(PreTrainedModel):
""" """
config_class = XLNetConfig config_class = XLNetConfig
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_xlnet load_tf_weights = load_tf_weights_in_xlnet
base_model_prefix = "transformer" base_model_prefix = "transformer"
...@@ -589,10 +599,10 @@ class XLNetPreTrainedModel(PreTrainedModel): ...@@ -589,10 +599,10 @@ class XLNetPreTrainedModel(PreTrainedModel):
class XLNetModel(XLNetPreTrainedModel): class XLNetModel(XLNetPreTrainedModel):
def __init__(self, config, output_attentions=False, output_hidden_states=False): def __init__(self, config):
super(XLNetModel, self).__init__(config) super(XLNetModel, self).__init__(config)
self.output_attentions = output_attentions self.output_attentions = config.output_attentions
self.output_hidden_states = output_hidden_states self.output_hidden_states = config.output_hidden_states
self.mem_len = config.mem_len self.mem_len = config.mem_len
self.reuse_len = config.reuse_len self.reuse_len = config.reuse_len
...@@ -601,25 +611,17 @@ class XLNetModel(XLNetPreTrainedModel): ...@@ -601,25 +611,17 @@ class XLNetModel(XLNetPreTrainedModel):
self.attn_type = config.attn_type self.attn_type = config.attn_type
self.bi_data = config.bi_data self.bi_data = config.bi_data
self.clamp_len = config.clamp_len self.clamp_len = config.clamp_len
self.n_layer = config.n_layer
self.word_embedding = nn.Embedding(config.n_token, config.d_model) self.word_embedding = nn.Embedding(config.n_token, config.d_model)
self.mask_emb = nn.Parameter(torch.Tensor(1, 1, config.d_model)) self.mask_emb = nn.Parameter(torch.Tensor(1, 1, config.d_model))
layer = XLNetLayer(config, output_attentions=output_attentions) layer = XLNetLayer(config)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layer)]) self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layer)])
self.dropout = nn.Dropout(config.dropout) self.dropout = nn.Dropout(config.dropout)
def prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model. logger.info("Head pruning is not implemented for XLNet")
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} pass
"""
for layer, heads in heads_to_prune.items():
self.layer[layer].attention.prune_heads(heads)
def get_multihead_outputs(self):
""" Gather all multi-head outputs.
Return: list (layers) of multihead module outputs with gradients
"""
return [layer.attention.self.multihead_output for layer in self.layer]
def create_mask(self, qlen, mlen): def create_mask(self, qlen, mlen):
""" create causal attention mask. """ create causal attention mask.
...@@ -708,11 +710,11 @@ class XLNetModel(XLNetPreTrainedModel): ...@@ -708,11 +710,11 @@ class XLNetModel(XLNetPreTrainedModel):
pos_emb = pos_emb.to(next(self.parameters())) pos_emb = pos_emb.to(next(self.parameters()))
return pos_emb return pos_emb
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None, def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None, head_mask=None): mems=None, perm_mask=None, target_mapping=None, inp_q=None, head_mask=None):
""" """
Args: Args:
inp_k: int32 Tensor in shape [bsz, len], the input token IDs. input_ids: int32 Tensor in shape [bsz, len], the input token IDs.
token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs. token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
input_mask: [optional] float32 Tensor in shape [bsz, len], the input mask. input_mask: [optional] float32 Tensor in shape [bsz, len], the input mask.
0 for real tokens and 1 for padding. 0 for real tokens and 1 for padding.
...@@ -751,7 +753,7 @@ class XLNetModel(XLNetPreTrainedModel): ...@@ -751,7 +753,7 @@ class XLNetModel(XLNetPreTrainedModel):
# the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end # the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
# but we want a unified interface in the library with the batch size on the first dimension # but we want a unified interface in the library with the batch size on the first dimension
# so we move here the first dimension (batch) to the end # so we move here the first dimension (batch) to the end
inp_k = inp_k.transpose(0, 1).contiguous() input_ids = input_ids.transpose(0, 1).contiguous()
token_type_ids = token_type_ids.transpose(0, 1).contiguous() if token_type_ids is not None else None token_type_ids = token_type_ids.transpose(0, 1).contiguous() if token_type_ids is not None else None
input_mask = input_mask.transpose(0, 1).contiguous() if input_mask is not None else None input_mask = input_mask.transpose(0, 1).contiguous() if input_mask is not None else None
attention_mask = attention_mask.transpose(0, 1).contiguous() if attention_mask is not None else None attention_mask = attention_mask.transpose(0, 1).contiguous() if attention_mask is not None else None
...@@ -759,7 +761,7 @@ class XLNetModel(XLNetPreTrainedModel): ...@@ -759,7 +761,7 @@ class XLNetModel(XLNetPreTrainedModel):
target_mapping = target_mapping.permute(1, 2, 0).contiguous() if target_mapping is not None else None target_mapping = target_mapping.permute(1, 2, 0).contiguous() if target_mapping is not None else None
inp_q = inp_q.transpose(0, 1).contiguous() if inp_q is not None else None inp_q = inp_q.transpose(0, 1).contiguous() if inp_q is not None else None
qlen, bsz = inp_k.shape[0], inp_k.shape[1] qlen, bsz = input_ids.shape[0], input_ids.shape[1]
mlen = mems[0].shape[0] if mems is not None else 0 mlen = mems[0].shape[0] if mems is not None else 0
klen = mlen + qlen klen = mlen + qlen
...@@ -810,7 +812,7 @@ class XLNetModel(XLNetPreTrainedModel): ...@@ -810,7 +812,7 @@ class XLNetModel(XLNetPreTrainedModel):
non_tgt_mask = None non_tgt_mask = None
##### Word embeddings and prepare h & g hidden states ##### Word embeddings and prepare h & g hidden states
word_emb_k = self.word_embedding(inp_k) word_emb_k = self.word_embedding(input_ids)
output_h = self.dropout(word_emb_k) output_h = self.dropout(word_emb_k)
if inp_q is not None: if inp_q is not None:
if target_mapping is not None: if target_mapping is not None:
...@@ -838,20 +840,20 @@ class XLNetModel(XLNetPreTrainedModel): ...@@ -838,20 +840,20 @@ class XLNetModel(XLNetPreTrainedModel):
pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz) pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
pos_emb = self.dropout(pos_emb) pos_emb = self.dropout(pos_emb)
##### Head mask if needed (for bertology/pruning) # Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head # 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N # attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [n_layer x num_heads] # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
# and head_mask is converted to shape [n_layer x batch x num_heads x seq_length x seq_length] # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
if head_mask is not None: if head_mask is not None:
if head_mask.dim() == 1: if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1) head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
elif head_mask.dim() == 2: elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
else: else:
head_mask = [None] * self.config.n_layer head_mask = [None] * self.n_layer
new_mems = [] new_mems = []
if mems is None: if mems is None:
...@@ -870,7 +872,7 @@ class XLNetModel(XLNetPreTrainedModel): ...@@ -870,7 +872,7 @@ class XLNetModel(XLNetPreTrainedModel):
head_mask=head_mask[i]) head_mask=head_mask[i])
output_h, output_g = outputs[:2] output_h, output_g = outputs[:2]
if self.output_attentions: if self.output_attentions:
attentions.append(outputs[2:]) attentions.append(outputs[2])
# Add last hidden state # Add last hidden state
if self.output_hidden_states: if self.output_hidden_states:
...@@ -887,6 +889,7 @@ class XLNetModel(XLNetPreTrainedModel): ...@@ -887,6 +889,7 @@ class XLNetModel(XLNetPreTrainedModel):
hidden_states = [hs.permute(1, 0, 2).contiguous() for hs in hidden_states] hidden_states = [hs.permute(1, 0, 2).contiguous() for hs in hidden_states]
outputs.append(hidden_states) outputs.append(hidden_states)
if self.output_attentions: if self.output_attentions:
attentions = list(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
outputs.append(attentions) outputs.append(attentions)
return outputs # outputs, new_mems, (hidden_states), (attentions) return outputs # outputs, new_mems, (hidden_states), (attentions)
...@@ -902,7 +905,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): ...@@ -902,7 +905,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
This can be used to compute head importance metrics. Default: False This can be used to compute head importance metrics. Default: False
Inputs: Inputs:
inp_k: int32 Tensor in shape [bsz, len], the input token IDs. input_ids: int32 Tensor in shape [bsz, len], the input token IDs.
token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs. token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
input_mask: [optional] float32 Tensor in shape [bsz, len], the input mask. input_mask: [optional] float32 Tensor in shape [bsz, len], the input mask.
0 for real tokens and 1 for padding. 0 for real tokens and 1 for padding.
...@@ -953,16 +956,12 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): ...@@ -953,16 +956,12 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
``` ```
""" """
def __init__(self, config, output_attentions=False, output_hidden_states=False): def __init__(self, config):
super(XLNetLMHeadModel, self).__init__(config) super(XLNetLMHeadModel, self).__init__(config)
self.output_attentions = output_attentions
self.output_hidden_states = output_hidden_states
self.attn_type = config.attn_type self.attn_type = config.attn_type
self.same_length = config.same_length self.same_length = config.same_length
self.transformer = XLNetModel(config, output_attentions=output_attentions, self.transformer = XLNetModel(config)
output_hidden_states=output_hidden_states)
self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True) self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)
# Tie weights # Tie weights
...@@ -975,12 +974,12 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): ...@@ -975,12 +974,12 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
""" """
self.lm_loss.weight = self.transformer.word_embedding.weight self.lm_loss.weight = self.transformer.word_embedding.weight
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None, def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None, mems=None, perm_mask=None, target_mapping=None, inp_q=None,
labels=None, head_mask=None): labels=None, head_mask=None):
""" """
Args: Args:
inp_k: int32 Tensor in shape [bsz, len], the input token IDs. input_ids: int32 Tensor in shape [bsz, len], the input token IDs.
token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs. token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
input_mask: float32 Tensor in shape [bsz, len], the input mask. input_mask: float32 Tensor in shape [bsz, len], the input mask.
0 for real tokens and 1 for padding. 0 for real tokens and 1 for padding.
...@@ -1008,7 +1007,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): ...@@ -1008,7 +1007,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
summary_type: str, "last", "first", "mean", or "attn". The method summary_type: str, "last", "first", "mean", or "attn". The method
to pool the input to get a vector representation. to pool the input to get a vector representation.
""" """
transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask, transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask,
mems, perm_mask, target_mapping, inp_q, head_mask) mems, perm_mask, target_mapping, inp_q, head_mask)
logits = self.lm_loss(transformer_outputs[0]) logits = self.lm_loss(transformer_outputs[0])
...@@ -1025,14 +1024,14 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): ...@@ -1025,14 +1024,14 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
return outputs # return (loss), logits, (mems), (hidden states), (attentions) return outputs # return (loss), logits, (mems), (hidden states), (attentions)
class XLNetSequenceSummary(nn.Module): class XLNetSequenceSummary(nn.Module):
def __init__(self, config, summary_type="last", use_proj=True): def __init__(self, config):
super(XLNetSequenceSummary, self).__init__() super(XLNetSequenceSummary, self).__init__()
self.summary_type = summary_type self.summary_type = config.summary_type
if use_proj: if config.use_proj:
self.summary = nn.Linear(config.d_model, config.d_model) self.summary = nn.Linear(config.d_model, config.d_model)
else: else:
self.summary = None self.summary = None
if summary_type == 'attn': if config.summary_type == 'attn':
# We should use a standard multi-head attention module with absolute positional embedding for that. # We should use a standard multi-head attention module with absolute positional embedding for that.
# Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
# We can probably just use the multi-head attention module of PyTorch >=1.1.0 # We can probably just use the multi-head attention module of PyTorch >=1.1.0
...@@ -1069,7 +1068,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel): ...@@ -1069,7 +1068,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
to pool the input to get a vector representation. Default: last to pool the input to get a vector representation. Default: last
Inputs: Inputs:
inp_k: int32 Tensor in shape [bsz, len], the input token IDs. input_ids: int32 Tensor in shape [bsz, len], the input token IDs.
token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs. token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
input_mask: float32 Tensor in shape [bsz, len], the input mask. input_mask: float32 Tensor in shape [bsz, len], the input mask.
0 for real tokens and 1 for padding. 0 for real tokens and 1 for padding.
...@@ -1121,30 +1120,21 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel): ...@@ -1121,30 +1120,21 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
``` ```
""" """
def __init__(self, config, summary_type="last", use_proj=True, num_labels=2, def __init__(self, config):
output_attentions=False, output_hidden_states=False):
super(XLNetForSequenceClassification, self).__init__(config) super(XLNetForSequenceClassification, self).__init__(config)
self.output_attentions = output_attentions
self.output_hidden_states = output_hidden_states
self.attn_type = config.attn_type self.transformer = XLNetModel(config)
self.same_length = config.same_length self.sequence_summary = XLNetSequenceSummary(config)
self.summary_type = summary_type self.logits_proj = nn.Linear(config.d_model, config.num_labels)
self.num_labels = num_labels
self.transformer = XLNetModel(config, output_attentions=output_attentions,
output_hidden_states=output_hidden_states)
self.sequence_summary = XLNetSequenceSummary(config, summary_type=summary_type, use_proj=use_proj)
self.logits_proj = nn.Linear(config.d_model, num_labels)
self.apply(self.init_weights) self.apply(self.init_weights)
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None, def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None, mems=None, perm_mask=None, target_mapping=None, inp_q=None,
labels=None, head_mask=None): labels=None, head_mask=None):
""" """
Args: Args:
inp_k: int32 Tensor in shape [bsz, len], the input token IDs. input_ids: int32 Tensor in shape [bsz, len], the input token IDs.
token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs. token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
input_mask: float32 Tensor in shape [bsz, len], the input mask. input_mask: float32 Tensor in shape [bsz, len], the input mask.
0 for real tokens and 1 for padding. 0 for real tokens and 1 for padding.
...@@ -1169,7 +1159,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel): ...@@ -1169,7 +1159,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
Only used during pretraining for two-stream attention. Only used during pretraining for two-stream attention.
Set to None during finetuning. Set to None during finetuning.
""" """
transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask, transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask,
mems, perm_mask, target_mapping, inp_q, head_mask) mems, perm_mask, target_mapping, inp_q, head_mask)
output = transformer_outputs[0] output = transformer_outputs[0]
...@@ -1247,20 +1237,18 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel): ...@@ -1247,20 +1237,18 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
start_logits, end_logits = model(input_ids, token_type_ids, input_mask) start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
``` ```
""" """
def __init__(self, config, output_attentions=False, output_hidden_states=False): def __init__(self, config):
super(XLNetForQuestionAnswering, self).__init__(config) super(XLNetForQuestionAnswering, self).__init__(config)
self.output_attentions = output_attentions
self.output_hidden_states = output_hidden_states
self.transformer = XLNetModel(config, output_attentions=output_attentions, self.transformer = XLNetModel(config)
output_hidden_states=output_hidden_states) self.qa_outputs = nn.Linear(config.d_model, config.num_labels)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
self.apply(self.init_weights) self.apply(self.init_weights)
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None, def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None, mems=None, perm_mask=None, target_mapping=None, inp_q=None,
start_positions=None, end_positions=None, head_mask=None): start_positions=None, end_positions=None, head_mask=None):
transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask, transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask,
mems, perm_mask, target_mapping, inp_q, head_mask) mems, perm_mask, target_mapping, inp_q, head_mask)
logits = self.qa_outputs(transformer_outputs[0]) logits = self.qa_outputs(transformer_outputs[0])
......
# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import shutil
import json
import random
import torch
def create_and_check_for_headmasking(tester, model_classes, config, inputs_dict):
for model_class in model_classes:
config.output_hidden_states = True
model = model_class(config=config)
model.eval()
head_mask = torch.zeros(tester.num_hidden_layers, tester.num_attention_heads)
# Set requires_grad after preparing the tensor to avoid the "leaf variable has been moved into the graph interior" error (see the standalone sketch after this function)
head_mask.requires_grad_(requires_grad=True)
outputs = model(**inputs_dict, head_mask=head_mask)
# Compute some gradients
output = sum(t.sum() for t in outputs[0])
output = output.sum()
output.backward()
multihead_outputs = head_mask.grad
tester.parent.assertEqual(len(multihead_outputs), tester.num_hidden_layers)
# self.parent.assertListEqual(
# list(multihead_outputs[0].size()),
# [self.batch_size, self.num_attention_heads,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertEqual(
# len(multihead_outputs[0][:, 1:(self.num_attention_heads-1), :, :].nonzero()),
# 0)
# self.parent.assertEqual(
# len(multihead_outputs[0][:, 0, :, :].nonzero()),
# self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
# self.parent.assertEqual(
# len(multihead_outputs[0][:, self.num_attention_heads-1, :, :].nonzero()),
# self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
# self.parent.assertListEqual(
# list(multihead_outputs[1].size()),
# [self.batch_size, self.num_attention_heads,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertEqual(
# len(multihead_outputs[1].nonzero()),
# multihead_outputs[1].numel())
# self.parent.assertListEqual(
# list(multihead_outputs[-1].size()),
# [self.batch_size, self.num_attention_heads,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertEqual(
# len(multihead_outputs[-1][:, 1:, :, :].nonzero()),
# 0)
# self.parent.assertEqual(
# len(multihead_outputs[-1][:, 0, :, :].nonzero()),
# self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
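The requires_grad ordering noted in the comment inside the test above comes from a general autograd rule; a minimal sketch of the pitfall (not part of the test suite, and the exact error message depends on the PyTorch version):

```python
import torch

bad = torch.zeros(2, 4, requires_grad=True)
try:
    bad[0, 0] = 1.0                     # in-place write on a leaf tensor that already requires grad
    (bad * 3.0).sum().backward()
except RuntimeError as err:             # autograd rejects this; message varies across versions
    print("autograd error:", err)

good = torch.zeros(2, 4)                # prepare / fill the tensor first ...
good[0, 0] = 1.0
good.requires_grad_(True)               # ... then enable gradients, as the test does for head_mask
(good * 3.0).sum().backward()
print(good.grad[0, 0].item())           # 3.0
```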
def create_and_check_for_head_pruning(tester, model_classes, config, inputs_dict):
for model_class in model_classes:
model = model_class(config=config)
model.eval()
heads_to_prune = {0: list(range(1, tester.num_attention_heads)),
-1: [0]}
model.prune_heads(heads_to_prune)
outputs = model(**inputs_dict)
# output = sum(t.sum() for t in outputs[0])
# output = output.sum()
# output.backward()
# multihead_outputs = bert_model.get_multihead_outputs()
# self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
# self.parent.assertListEqual(
# list(multihead_outputs[0].size()),
# [self.batch_size, 1,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertListEqual(
# list(multihead_outputs[1].size()),
# [self.batch_size, self.num_attention_heads,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertListEqual(
# list(multihead_outputs[-1].size()),
# [self.batch_size, self.num_attention_heads-1,
# self.seq_length, self.hidden_size // self.num_attention_heads])
def create_and_check_for_attentions(tester, model_classes, config, inputs_dict):
for model_class in model_classes:
config.output_attentions = True
config.output_hidden_states = False
model = model_class(config)
model.eval()
outputs = model(**inputs_dict)
attentions = outputs[-1]
tester.parent.assertEqual(model.config.output_attentions, True)
tester.parent.assertEqual(model.config.output_hidden_states, False)
tester.parent.assertEqual(len(attentions), tester.num_hidden_layers)
tester.parent.assertListEqual(
list(attentions[0].shape[-3:]),
[tester.num_attention_heads,
tester.seq_length,
tester.key_len if hasattr(tester, 'key_len') else tester.seq_length])
out_len = len(outputs)
# Check attention is always last and order is fine
config.output_attentions = True
config.output_hidden_states = True
model = model_class(config)
model.eval()
outputs = model(**inputs_dict)
tester.parent.assertEqual(out_len+1, len(outputs))
tester.parent.assertEqual(model.config.output_attentions, True)
tester.parent.assertEqual(model.config.output_hidden_states, True)
attentions = outputs[-1]
tester.parent.assertEqual(len(attentions), tester.num_hidden_layers)
tester.parent.assertListEqual(
list(attentions[0].shape[-3:]),
[tester.num_attention_heads,
tester.seq_length,
tester.key_len if hasattr(tester, 'key_len') else tester.seq_length])
def create_and_check_for_hidden_states(tester, model_classes, config, inputs_dict):
for model_class in model_classes:
config.output_hidden_states = True
config.output_attentions = False
model = model_class(config)
model.eval()
outputs = model(**inputs_dict)
hidden_states = outputs[-1]
tester.parent.assertEqual(model.config.output_attentions, False)
tester.parent.assertEqual(model.config.output_hidden_states, True)
tester.parent.assertEqual(len(hidden_states), tester.num_hidden_layers + 1)
tester.parent.assertListEqual(
list(hidden_states[0].shape[-2:]),
[tester.seq_length, tester.hidden_size])
def create_and_check_commons(tester, config, inputs_dict):
create_and_check_for_attentions(tester, tester.all_model_classes, config, inputs_dict)
create_and_check_for_headmasking(tester, tester.all_model_classes, config, inputs_dict)
create_and_check_for_head_pruning(tester, tester.all_model_classes, config, inputs_dict)
create_and_check_for_hidden_states(tester, tester.all_model_classes, config, inputs_dict)
def ids_tensor(shape, vocab_size, rng=None, name=None):
"""Creates a random int32 tensor of the shape within the vocab size."""
if rng is None:
rng = random.Random()
total_dims = 1
for dim in shape:
total_dims *= dim
values = []
for _ in range(total_dims):
values.append(rng.randint(0, vocab_size - 1))
return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
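For instance, the helper above can fabricate a batch of fake token ids for the testers (a usage sketch; the sizes mirror the [batch, n_choices, seq_length] layout used by GPTModelTester below):

```python
fake_input_ids = ids_tensor([2, 3, 7], vocab_size=99)
print(fake_input_ids.shape, fake_input_ids.dtype)    # torch.Size([2, 3, 7]) torch.int64
print(int(fake_input_ids.max()) < 99)                # True -- values stay inside the vocab
```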
class ConfigTester(object):
def __init__(self, parent, config_class=None, **kwargs):
self.parent = parent
self.config_class = config_class
self.inputs_dict = kwargs
def create_and_test_config_to_json_string(self):
config = self.config_class(**self.inputs_dict)
obj = json.loads(config.to_json_string())
for key, value in self.inputs_dict.items():
self.parent.assertEqual(obj[key], value)
def create_and_test_config_to_json_file(self):
config_first = self.config_class(**self.inputs_dict)
json_file_path = "/tmp/config.json"
config_first.to_json_file(json_file_path)
config_second = self.config_class.from_json_file(json_file_path)
os.remove(json_file_path)
self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())
def run_common_tests(self):
self.create_and_test_config_to_json_string()
self.create_and_test_config_to_json_file()
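In isolation, the JSON round-trip exercised by ConfigTester looks like this; a sketch assuming BertConfig (imported elsewhere in these tests) and the to_json_file / from_json_file / to_dict helpers the tester relies on:

```python
from pytorch_pretrained_bert.modeling import BertConfig

config_first = BertConfig(vocab_size_or_config_json_file=99, hidden_size=32)
config_first.to_json_file("/tmp/config_roundtrip.json")
config_second = BertConfig.from_json_file("/tmp/config_roundtrip.json")
assert config_second.to_dict() == config_first.to_dict()
```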
class GPTModelTester(object):
def __init__(self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_position_ids=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
n_special=1,
n_positions=33,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
n_choices=3,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
scope=None,
config_class=None,
base_model_class=None,
lm_head_model_class=None,
double_head_model_class=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_position_ids = use_position_ids
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.n_special = n_special
self.n_positions = n_positions
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.n_choices = n_choices
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.scope = scope
self.config_class = config_class
self.base_model_class = base_model_class
self.lm_head_model_class = lm_head_model_class
self.double_head_model_class = double_head_model_class
self.all_model_classes = (base_model_class, lm_head_model_class, double_head_model_class)
def prepare_config_and_inputs(self):
total_num_tokens = self.vocab_size + self.n_special
input_ids = ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_num_tokens)
position_ids = None
if self.use_position_ids:
position_ids = ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.n_positions)
token_type_ids = None
if self.use_token_type_ids:
total_voc = self.vocab_size
token_type_ids = ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_voc)
mc_labels = None
lm_labels = None
mc_token_ids = None
if self.use_labels:
mc_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
lm_labels = ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.num_labels)
mc_token_ids = ids_tensor([self.batch_size, self.n_choices], self.seq_length)
config = self.config_class(
vocab_size_or_config_json_file=self.vocab_size,
n_special=self.n_special,
n_positions=self.n_positions,
n_embd=self.hidden_size,
n_layer=self.num_hidden_layers,
n_head=self.num_attention_heads,
initializer_range=self.initializer_range)
return (config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids)
def create_and_check_base_model(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
model = self.base_model_class(config)
model.eval()
outputs = model(input_ids, position_ids, token_type_ids)
hidden_state = outputs[0]
self.parent.assertListEqual(
list(hidden_state.size()),
[self.batch_size, self.n_choices, self.seq_length, self.hidden_size])
def create_and_check_lm_head(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
model = self.lm_head_model_class(config)
model.eval()
outputs = model(input_ids, position_ids, token_type_ids, lm_labels)
loss, lm_logits = outputs[:2]
total_voc = self.n_special + self.vocab_size
self.parent.assertListEqual(
list(lm_logits.size()),
[self.batch_size, self.n_choices, self.seq_length, total_voc])
self.parent.assertListEqual(
list(loss.size()),
[])
def create_and_check_presents(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
for model_class in self.all_model_classes:
model = model_class(config)
model.eval()
outputs = model(input_ids)
presents = outputs[-1]
self.parent.assertEqual(self.num_hidden_layers, len(presents))
self.parent.assertListEqual(
list(presents[0].size()),
[2, self.batch_size * self.n_choices, self.num_attention_heads,
self.seq_length, self.hidden_size // self.num_attention_heads])
def create_and_check_double_heads(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
model = self.double_head_model_class(config)
model.eval()
outputs = model(input_ids, mc_token_ids, lm_labels=lm_labels, mc_labels=mc_labels,
token_type_ids=token_type_ids, position_ids=position_ids)
lm_loss, mc_loss, lm_logits, mc_logits = outputs[:4]
loss = [lm_loss, mc_loss]
total_voc = self.n_special + self.vocab_size
self.parent.assertListEqual(
list(lm_logits.size()),
[self.batch_size, self.n_choices, self.seq_length, total_voc])
self.parent.assertListEqual(
list(mc_logits.size()),
[self.batch_size, self.n_choices])
self.parent.assertListEqual(
[list(l.size()) for l in loss],
[[], []])
def create_and_check_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_pretrained_bert_test/"
for model_name in list(self.base_model_class.PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = self.base_model_class.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.parent.assertIsNotNone(model)
def create_and_check_commons(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
inputs_dict = {'input_ids': input_ids}
create_and_check_commons(self, config, inputs_dict)
def run_common_tests(self, test_presents=False):
config_and_inputs = self.prepare_config_and_inputs()
self.create_and_check_base_model(*config_and_inputs)
config_and_inputs = self.prepare_config_and_inputs()
self.create_and_check_lm_head(*config_and_inputs)
config_and_inputs = self.prepare_config_and_inputs()
self.create_and_check_double_heads(*config_and_inputs)
if test_presents:
config_and_inputs = self.prepare_config_and_inputs()
self.create_and_check_presents(*config_and_inputs)
config_and_inputs = self.prepare_config_and_inputs()
self.create_and_check_commons(*config_and_inputs)
def run_slow_tests(self):
self.create_and_check_model_from_pretrained()
# coding=utf-8
# Copyright 2018 HuggingFace Inc..
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import unittest
import json
import random
import shutil
import pytest
import torch
from pytorch_pretrained_bert import PretrainedConfig, PreTrainedModel
from pytorch_pretrained_bert.modeling import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP, PRETRAINED_CONFIG_ARCHIVE_MAP
class ModelUtilsTest(unittest.TestCase):
def test_model_from_pretrained(self):
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
config = BertConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
self.assertIsInstance(config, PretrainedConfig)
model = BertModel.from_pretrained(model_name)
self.assertIsNotNone(model)
self.assertIsInstance(model, PreTrainedModel)
config = BertConfig.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
self.assertEqual(model.config.output_attentions, True)
self.assertEqual(model.config.output_hidden_states, True)
self.assertEqual(model.config, config)
if __name__ == "__main__":
unittest.main()
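For context, a usage sketch of the behaviour this test pins down: configuration kwargs passed to from_pretrained (here output_attentions) end up on model.config. Where the extra tensors sit in the returned tuple is an assumption of this sketch, not something the test asserts:

import torch
from pytorch_pretrained_bert import BertModel

model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
model.eval()
input_ids = torch.tensor([[101, 7592, 2088, 102]])  # toy ids; a real script would use a tokenizer
with torch.no_grad():
    outputs = model(input_ids)
attentions = outputs[-1]  # assumed: per-layer attention tensors appended at the end
print(len(attentions), attentions[0].shape)  # one [batch, heads, seq, seq] tensor per layer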
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import unittest
import json
import random
import shutil
import pytest
import torch
from pytorch_pretrained_bert import (GPT2Config, GPT2Model,
GPT2LMHeadModel, GPT2DoubleHeadsModel)
from .model_tests_commons import ConfigTester, GPTModelTester
class GPT2ModelTest(unittest.TestCase):
def test_config(self):
config_tester = ConfigTester(self, config_class=GPT2Config, n_embd=37)
config_tester.run_common_tests()
def test_model(self):
model_tester = GPTModelTester(self, config_class=GPT2Config, base_model_class=GPT2Model,
lm_head_model_class=GPT2LMHeadModel,
double_head_model_class=GPT2DoubleHeadsModel)
model_tester.run_common_tests(test_presents=True)
@pytest.mark.slow
def test_pretrained(self):
model_tester = GPTModelTester(self, config_class=GPT2Config, base_model_class=GPT2Model,
lm_head_model_class=GPT2LMHeadModel,
double_head_model_class=GPT2DoubleHeadsModel)
model_tester.run_slow_tests()
if __name__ == "__main__":
unittest.main()
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import unittest
import json
import random
import shutil
import pytest
import torch
from pytorch_pretrained_bert import (OpenAIGPTConfig, OpenAIGPTModel,
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
from .model_tests_commons import ConfigTester, GPTModelTester
class OpenAIModelTest(unittest.TestCase):
def test_config(self):
config_tester = ConfigTester(self, config_class=OpenAIGPTConfig, n_embd=37)
config_tester.run_common_tests()
def test_model(self):
model_tester = GPTModelTester(self, config_class=OpenAIGPTConfig, base_model_class=OpenAIGPTModel,
lm_head_model_class=OpenAIGPTLMHeadModel,
double_head_model_class=OpenAIGPTDoubleHeadsModel)
model_tester.run_common_tests(test_presents=False)
@pytest.mark.slow
def test_pretrained(self):
model_tester = GPTModelTester(self, config_class=OpenAIGPTConfig, base_model_class=OpenAIGPTModel,
lm_head_model_class=OpenAIGPTLMHeadModel,
double_head_model_class=OpenAIGPTDoubleHeadsModel)
model_tester.run_slow_tests()
if __name__ == "__main__":
unittest.main()
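The test modules above, and the diffs that follow, import ids_tensor from model_tests_commons. Its definition is not part of this excerpt, but the per-class implementation deleted from the BERT, Transformer-XL and XLNet testers below suggests the shared helper is essentially:

import random
import torch

def ids_tensor(shape, vocab_size, rng=None, name=None):
    """Creates a random int32 tensor of the shape within the vocab size."""
    if rng is None:
        rng = random.Random()
    total_dims = 1
    for dim in shape:
        total_dims *= dim
    values = [rng.randint(0, vocab_size - 1) for _ in range(total_dims)]
    return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()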
...@@ -31,6 +31,8 @@ from pytorch_pretrained_bert import (BertConfig, BertModel, BertForMaskedLM, ...@@ -31,6 +31,8 @@ from pytorch_pretrained_bert import (BertConfig, BertModel, BertForMaskedLM,
BertForTokenClassification, BertForMultipleChoice) BertForTokenClassification, BertForMultipleChoice)
from pytorch_pretrained_bert.modeling import PRETRAINED_MODEL_ARCHIVE_MAP from pytorch_pretrained_bert.modeling import PRETRAINED_MODEL_ARCHIVE_MAP
from .model_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
class BertModelTest(unittest.TestCase): class BertModelTest(unittest.TestCase):
class BertModelTester(object): class BertModelTester(object):
...@@ -57,7 +59,11 @@ class BertModelTest(unittest.TestCase): ...@@ -57,7 +59,11 @@ class BertModelTest(unittest.TestCase):
initializer_range=0.02, initializer_range=0.02,
num_labels=3, num_labels=3,
num_choices=4, num_choices=4,
scope=None): scope=None,
all_model_classes = (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
BertForTokenClassification),
):
self.parent = parent self.parent = parent
self.batch_size = batch_size self.batch_size = batch_size
self.seq_length = seq_length self.seq_length = seq_length
...@@ -80,25 +86,26 @@ class BertModelTest(unittest.TestCase): ...@@ -80,25 +86,26 @@ class BertModelTest(unittest.TestCase):
self.num_labels = num_labels self.num_labels = num_labels
self.num_choices = num_choices self.num_choices = num_choices
self.scope = scope self.scope = scope
self.all_model_classes = all_model_classes
def prepare_config_and_inputs(self): def prepare_config_and_inputs(self):
input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size) input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None input_mask = None
if self.use_input_mask: if self.use_input_mask:
input_mask = BertModelTest.ids_tensor([self.batch_size, self.seq_length], vocab_size=2) input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
token_type_ids = None token_type_ids = None
if self.use_token_type_ids: if self.use_token_type_ids:
token_type_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
sequence_labels = None sequence_labels = None
token_labels = None token_labels = None
choice_labels = None choice_labels = None
if self.use_labels: if self.use_labels:
sequence_labels = BertModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size) sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.num_labels) token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = BertModelTest.ids_tensor([self.batch_size], self.num_choices) choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = BertConfig( config = BertConfig(
vocab_size_or_config_json_file=self.vocab_size, vocab_size_or_config_json_file=self.vocab_size,
...@@ -120,136 +127,117 @@ class BertModelTest(unittest.TestCase): ...@@ -120,136 +127,117 @@ class BertModelTest(unittest.TestCase):
list(result["loss"].size()), list(result["loss"].size()),
[]) [])
def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): def create_and_check_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertModel(config=config) model = BertModel(config=config)
model.eval() model.eval()
sequence_output, pooled_output = model(input_ids, token_type_ids, input_mask) sequence_output, pooled_output = model(input_ids, token_type_ids, input_mask)
model = BertModel(config=config, output_hidden_states=True) result = {
model.eval()
_, _, all_encoder_layers = model(input_ids, token_type_ids, input_mask)
outputs = {
"sequence_output": sequence_output, "sequence_output": sequence_output,
"pooled_output": pooled_output, "pooled_output": pooled_output,
"all_encoder_layers": all_encoder_layers,
} }
return outputs
def check_bert_model_output(self, result):
self.parent.assertListEqual(
[size for layer in result["all_encoder_layers"] for size in layer.size()],
[self.batch_size, self.seq_length, self.hidden_size] * (self.num_hidden_layers + 1))
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["sequence_output"].size()), list(result["sequence_output"].size()),
[self.batch_size, self.seq_length, self.hidden_size]) [self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size]) self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): def create_and_check_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForMaskedLM(config=config) model = BertForMaskedLM(config=config)
model.eval() model.eval()
loss, prediction_scores = model(input_ids, token_type_ids, input_mask, token_labels) loss, prediction_scores = model(input_ids, token_type_ids, input_mask, token_labels)
outputs = { result = {
"loss": loss, "loss": loss,
"prediction_scores": prediction_scores, "prediction_scores": prediction_scores,
} }
return outputs
def check_bert_for_masked_lm_output(self, result):
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["prediction_scores"].size()), list(result["prediction_scores"].size()),
[self.batch_size, self.seq_length, self.vocab_size]) [self.batch_size, self.seq_length, self.vocab_size])
self.check_loss_output(result)
def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): def create_and_check_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForNextSentencePrediction(config=config) model = BertForNextSentencePrediction(config=config)
model.eval() model.eval()
loss, seq_relationship_score = model(input_ids, token_type_ids, input_mask, sequence_labels) loss, seq_relationship_score = model(input_ids, token_type_ids, input_mask, sequence_labels)
outputs = { result = {
"loss": loss, "loss": loss,
"seq_relationship_score": seq_relationship_score, "seq_relationship_score": seq_relationship_score,
} }
return outputs
def check_bert_for_next_sequence_prediction_output(self, result):
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["seq_relationship_score"].size()), list(result["seq_relationship_score"].size()),
[self.batch_size, 2]) [self.batch_size, 2])
self.check_loss_output(result)
def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): def create_and_check_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForPreTraining(config=config) model = BertForPreTraining(config=config)
model.eval() model.eval()
loss, prediction_scores, seq_relationship_score = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels) loss, prediction_scores, seq_relationship_score = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels)
outputs = { result = {
"loss": loss, "loss": loss,
"prediction_scores": prediction_scores, "prediction_scores": prediction_scores,
"seq_relationship_score": seq_relationship_score, "seq_relationship_score": seq_relationship_score,
} }
return outputs
def check_bert_for_pretraining_output(self, result):
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["prediction_scores"].size()), list(result["prediction_scores"].size()),
[self.batch_size, self.seq_length, self.vocab_size]) [self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["seq_relationship_score"].size()), list(result["seq_relationship_score"].size()),
[self.batch_size, 2]) [self.batch_size, 2])
self.check_loss_output(result)
def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): def create_and_check_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForQuestionAnswering(config=config) model = BertForQuestionAnswering(config=config)
model.eval() model.eval()
loss, start_logits, end_logits = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels) loss, start_logits, end_logits = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels)
outputs = { result = {
"loss": loss, "loss": loss,
"start_logits": start_logits, "start_logits": start_logits,
"end_logits": end_logits, "end_logits": end_logits,
} }
return outputs
def check_bert_for_question_answering_output(self, result):
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["start_logits"].size()), list(result["start_logits"].size()),
[self.batch_size, self.seq_length]) [self.batch_size, self.seq_length])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["end_logits"].size()), list(result["end_logits"].size()),
[self.batch_size, self.seq_length]) [self.batch_size, self.seq_length])
self.check_loss_output(result)
def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): def create_and_check_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForSequenceClassification(config=config, num_labels=self.num_labels) config.num_labels = self.num_labels
model = BertForSequenceClassification(config)
model.eval() model.eval()
loss, logits = model(input_ids, token_type_ids, input_mask, sequence_labels) loss, logits = model(input_ids, token_type_ids, input_mask, sequence_labels)
outputs = { result = {
"loss": loss, "loss": loss,
"logits": logits, "logits": logits,
} }
return outputs
def check_bert_for_sequence_classification_output(self, result):
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["logits"].size()), list(result["logits"].size()),
[self.batch_size, self.num_labels]) [self.batch_size, self.num_labels])
self.check_loss_output(result)
def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): def create_and_check_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForTokenClassification(config=config, num_labels=self.num_labels) config.num_labels = self.num_labels
model = BertForTokenClassification(config=config)
model.eval() model.eval()
loss, logits = model(input_ids, token_type_ids, input_mask, token_labels) loss, logits = model(input_ids, token_type_ids, input_mask, token_labels)
outputs = { result = {
"loss": loss, "loss": loss,
"logits": logits, "logits": logits,
} }
return outputs
def check_bert_for_token_classification_output(self, result):
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["logits"].size()), list(result["logits"].size()),
[self.batch_size, self.seq_length, self.num_labels]) [self.batch_size, self.seq_length, self.num_labels])
self.check_loss_output(result)
def create_bert_for_multiple_choice(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): def create_and_check_bert_for_multiple_choice(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForMultipleChoice(config=config, num_choices=self.num_choices) config.num_choices = self.num_choices
model = BertForMultipleChoice(config=config)
model.eval() model.eval()
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
...@@ -258,148 +246,26 @@ class BertModelTest(unittest.TestCase): ...@@ -258,148 +246,26 @@ class BertModelTest(unittest.TestCase):
multiple_choice_token_type_ids, multiple_choice_token_type_ids,
multiple_choice_input_mask, multiple_choice_input_mask,
choice_labels) choice_labels)
outputs = { result = {
"loss": loss, "loss": loss,
"logits": logits, "logits": logits,
} }
return outputs
def check_bert_for_multiple_choice(self, result):
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["logits"].size()), list(result["logits"].size()),
[self.batch_size, self.num_choices]) [self.batch_size, self.num_choices])
self.check_loss_output(result)
def create_and_check_bert_for_attentions(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): def create_and_check_bert_commons(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction, inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, create_and_check_commons(self, config, inputs_dict)
BertForTokenClassification):
if model_class in [BertForSequenceClassification,
BertForTokenClassification]:
model = model_class(config=config, num_labels=self.num_labels, output_attentions=True)
else:
model = model_class(config=config, output_attentions=True)
model.eval()
outputs = model(input_ids, token_type_ids, input_mask)
attentions = outputs[-1]
self.parent.assertEqual(len(attentions), self.num_hidden_layers)
self.parent.assertListEqual(
list(attentions[0].size()),
[self.batch_size, self.num_attention_heads, self.seq_length, self.seq_length])
def create_and_check_bert_for_headmasking(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
BertForTokenClassification):
if model_class in [BertForSequenceClassification,
BertForTokenClassification]:
model = model_class(config=config,
num_labels=self.num_labels)
else:
model = model_class(config=config)
model.eval()
head_mask = torch.ones(self.num_hidden_layers, self.num_attention_heads).to(input_ids.device)
head_mask[0, 1:-1] = 0.0 # Mask all but the first and last heads on the first layer
head_mask[-1, 1:] = 0.0 # Mask all but the first head on the last layer
# Set that after having prepared the tensor to avoid error (leaf variable has been moved into the graph interior)
head_mask.requires_grad_(requires_grad=True)
outputs = model(input_ids, token_type_ids, input_mask, head_mask=head_mask)
# Compute some gradients
output = sum(t.sum() for t in outputs[0])
output = output.sum()
output.backward()
multihead_outputs = head_mask.grad
self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
# self.parent.assertListEqual(
# list(multihead_outputs[0].size()),
# [self.batch_size, self.num_attention_heads,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertEqual(
# len(multihead_outputs[0][:, 1:(self.num_attention_heads-1), :, :].nonzero()),
# 0)
# self.parent.assertEqual(
# len(multihead_outputs[0][:, 0, :, :].nonzero()),
# self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
# self.parent.assertEqual(
# len(multihead_outputs[0][:, self.num_attention_heads-1, :, :].nonzero()),
# self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
# self.parent.assertListEqual(
# list(multihead_outputs[1].size()),
# [self.batch_size, self.num_attention_heads,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertEqual(
# len(multihead_outputs[1].nonzero()),
# multihead_outputs[1].numel())
# self.parent.assertListEqual(
# list(multihead_outputs[-1].size()),
# [self.batch_size, self.num_attention_heads,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertEqual(
# len(multihead_outputs[-1][:, 1:, :, :].nonzero()),
# 0)
# self.parent.assertEqual(
# len(multihead_outputs[-1][:, 0, :, :].nonzero()),
# self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
def create_and_check_bert_for_head_pruning(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
BertForTokenClassification):
if model_class in [BertForSequenceClassification,
BertForTokenClassification]:
model = model_class(config=config,
num_labels=self.num_labels)
else:
model = model_class(config=config)
model.eval()
bert_model = model if isinstance(model, BertModel) else model.bert
heads_to_prune = {0: list(range(1, self.num_attention_heads)),
-1: [0]}
bert_model.prune_heads(heads_to_prune)
outputs = model(input_ids, token_type_ids, input_mask)
# output = sum(t.sum() for t in outputs[0])
# output = output.sum()
# output.backward()
# multihead_outputs = bert_model.get_multihead_outputs()
# self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
# self.parent.assertListEqual(
# list(multihead_outputs[0].size()),
# [self.batch_size, 1,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertListEqual(
# list(multihead_outputs[1].size()),
# [self.batch_size, self.num_attention_heads,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertListEqual(
# list(multihead_outputs[-1].size()),
# [self.batch_size, self.num_attention_heads-1,
# self.seq_length, self.hidden_size // self.num_attention_heads])
def test_default(self): def test_default(self):
self.run_tester(BertModelTest.BertModelTester(self)) self.run_tester(BertModelTest.BertModelTester(self))
def test_config_to_json_string(self): def test_config(self):
config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=37) config_tester = ConfigTester(self, config_class=BertConfig, hidden_size=37)
obj = json.loads(config.to_json_string()) config_tester.run_common_tests()
self.assertEqual(obj["vocab_size"], 99)
self.assertEqual(obj["hidden_size"], 37)
def test_config_to_json_file(self):
config_first = BertConfig(vocab_size_or_config_json_file=99, hidden_size=37)
json_file_path = "/tmp/config.json"
config_first.to_json_file(json_file_path)
config_second = BertConfig.from_json_file(json_file_path)
os.remove(json_file_path)
self.assertEqual(config_second.to_dict(), config_first.to_dict())
@pytest.mark.slow @pytest.mark.slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
...@@ -411,57 +277,31 @@ class BertModelTest(unittest.TestCase): ...@@ -411,57 +277,31 @@ class BertModelTest(unittest.TestCase):
def run_tester(self, tester): def run_tester(self, tester):
config_and_inputs = tester.prepare_config_and_inputs() config_and_inputs = tester.prepare_config_and_inputs()
output_result = tester.create_bert_model(*config_and_inputs) tester.create_and_check_bert_model(*config_and_inputs)
tester.check_bert_model_output(output_result)
output_result = tester.create_bert_for_masked_lm(*config_and_inputs) config_and_inputs = tester.prepare_config_and_inputs()
tester.check_bert_for_masked_lm_output(output_result) tester.create_and_check_bert_for_masked_lm(*config_and_inputs)
tester.check_loss_output(output_result)
output_result = tester.create_bert_for_next_sequence_prediction(*config_and_inputs)
tester.check_bert_for_next_sequence_prediction_output(output_result)
tester.check_loss_output(output_result)
output_result = tester.create_bert_for_pretraining(*config_and_inputs)
tester.check_bert_for_pretraining_output(output_result)
tester.check_loss_output(output_result)
output_result = tester.create_bert_for_question_answering(*config_and_inputs)
tester.check_bert_for_question_answering_output(output_result)
tester.check_loss_output(output_result)
output_result = tester.create_bert_for_sequence_classification(*config_and_inputs)
tester.check_bert_for_sequence_classification_output(output_result)
tester.check_loss_output(output_result)
output_result = tester.create_bert_for_token_classification(*config_and_inputs)
tester.check_bert_for_token_classification_output(output_result)
tester.check_loss_output(output_result)
output_result = tester.create_bert_for_multiple_choice(*config_and_inputs) config_and_inputs = tester.prepare_config_and_inputs()
tester.check_bert_for_multiple_choice(output_result) tester.create_and_check_bert_for_multiple_choice(*config_and_inputs)
tester.check_loss_output(output_result)
tester.create_and_check_bert_for_attentions(*config_and_inputs) config_and_inputs = tester.prepare_config_and_inputs()
tester.create_and_check_bert_for_headmasking(*config_and_inputs) tester.create_and_check_bert_for_next_sequence_prediction(*config_and_inputs)
tester.create_and_check_bert_for_head_pruning(*config_and_inputs)
@classmethod config_and_inputs = tester.prepare_config_and_inputs()
def ids_tensor(cls, shape, vocab_size, rng=None, name=None): tester.create_and_check_bert_for_pretraining(*config_and_inputs)
"""Creates a random int32 tensor of the shape within the vocab size."""
if rng is None:
rng = random.Random()
total_dims = 1 config_and_inputs = tester.prepare_config_and_inputs()
for dim in shape: tester.create_and_check_bert_for_question_answering(*config_and_inputs)
total_dims *= dim
values = [] config_and_inputs = tester.prepare_config_and_inputs()
for _ in range(total_dims): tester.create_and_check_bert_for_sequence_classification(*config_and_inputs)
values.append(rng.randint(0, vocab_size - 1))
return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous() config_and_inputs = tester.prepare_config_and_inputs()
tester.create_and_check_bert_for_token_classification(*config_and_inputs)
config_and_inputs = tester.prepare_config_and_inputs()
tester.create_and_check_bert_commons(*config_and_inputs)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
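ConfigTester, imported from model_tests_commons, replaces the per-model test_config_to_json_string / test_config_to_json_file methods removed in this diff (and in the Transformer-XL and XLNet diffs below). A sketch of what its common tests presumably cover, rebuilt from those removed methods; the class layout and method names are assumptions:

import json
import os
import tempfile

class ConfigTester(object):
    def __init__(self, parent, config_class=None, **kwargs):
        self.parent = parent
        self.config_class = config_class
        self.inputs_dict = kwargs  # e.g. hidden_size=37 or n_embd=37

    def create_and_test_config_to_json_string(self):
        # Round-trip through to_json_string and check the overridden values survive.
        config = self.config_class(vocab_size_or_config_json_file=99, **self.inputs_dict)
        obj = json.loads(config.to_json_string())
        for key, value in self.inputs_dict.items():
            self.parent.assertEqual(obj[key], value)

    def create_and_test_config_to_json_file(self):
        # Round-trip through to_json_file / from_json_file and compare the dicts.
        config_first = self.config_class(vocab_size_or_config_json_file=99, **self.inputs_dict)
        json_file_path = os.path.join(tempfile.gettempdir(), "config.json")
        config_first.to_json_file(json_file_path)
        config_second = self.config_class.from_json_file(json_file_path)
        os.remove(json_file_path)
        self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())

    def run_common_tests(self):
        self.create_and_test_config_to_json_string()
        self.create_and_test_config_to_json_file()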
...@@ -28,6 +28,8 @@ import torch ...@@ -28,6 +28,8 @@ import torch
from pytorch_pretrained_bert import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel) from pytorch_pretrained_bert import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
from pytorch_pretrained_bert.modeling_transfo_xl import PRETRAINED_MODEL_ARCHIVE_MAP from pytorch_pretrained_bert.modeling_transfo_xl import PRETRAINED_MODEL_ARCHIVE_MAP
from .model_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
class TransfoXLModelTest(unittest.TestCase): class TransfoXLModelTest(unittest.TestCase):
class TransfoXLModelTester(object): class TransfoXLModelTester(object):
...@@ -41,54 +43,58 @@ class TransfoXLModelTest(unittest.TestCase): ...@@ -41,54 +43,58 @@ class TransfoXLModelTest(unittest.TestCase):
use_labels=True, use_labels=True,
vocab_size=99, vocab_size=99,
cutoffs=[10, 50, 80], cutoffs=[10, 50, 80],
d_model=32, hidden_size=32,
d_embed=32, d_embed=32,
n_head=4, num_attention_heads=4,
d_head=8, d_head=8,
d_inner=128, d_inner=128,
div_val=2, div_val=2,
n_layer=5, num_hidden_layers=5,
scope=None, scope=None,
seed=1): seed=1,
all_model_classes=(TransfoXLModel, TransfoXLLMHeadModel),
):
self.parent = parent self.parent = parent
self.batch_size = batch_size self.batch_size = batch_size
self.seq_length = seq_length self.seq_length = seq_length
self.mem_len = mem_len self.mem_len = mem_len
self.key_len = seq_length + mem_len
self.clamp_len = clamp_len self.clamp_len = clamp_len
self.is_training = is_training self.is_training = is_training
self.use_labels = use_labels self.use_labels = use_labels
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.cutoffs = cutoffs self.cutoffs = cutoffs
self.d_model = d_model self.hidden_size = hidden_size
self.d_embed = d_embed self.d_embed = d_embed
self.n_head = n_head self.num_attention_heads = num_attention_heads
self.d_head = d_head self.d_head = d_head
self.d_inner = d_inner self.d_inner = d_inner
self.div_val = div_val self.div_val = div_val
self.n_layer = n_layer self.num_hidden_layers = num_hidden_layers
self.scope = scope self.scope = scope
self.seed = seed self.seed = seed
self.all_model_classes = all_model_classes
def prepare_config_and_inputs(self): def prepare_config_and_inputs(self):
input_ids_1 = TransfoXLModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size) input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_ids_2 = TransfoXLModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size) input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
lm_labels = None lm_labels = None
if self.use_labels: if self.use_labels:
lm_labels = TransfoXLModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size) lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
config = TransfoXLConfig( config = TransfoXLConfig(
vocab_size_or_config_json_file=self.vocab_size, vocab_size_or_config_json_file=self.vocab_size,
mem_len=self.mem_len, mem_len=self.mem_len,
clamp_len=self.clamp_len, clamp_len=self.clamp_len,
cutoffs=self.cutoffs, cutoffs=self.cutoffs,
d_model=self.d_model, d_model=self.hidden_size,
d_embed=self.d_embed, d_embed=self.d_embed,
n_head=self.n_head, n_head=self.num_attention_heads,
d_head=self.d_head, d_head=self.d_head,
d_inner=self.d_inner, d_inner=self.d_inner,
div_val=self.div_val, div_val=self.div_val,
n_layer=self.n_layer) n_layer=self.num_hidden_layers)
return (config, input_ids_1, input_ids_2, lm_labels) return (config, input_ids_1, input_ids_2, lm_labels)
...@@ -113,37 +119,34 @@ class TransfoXLModelTest(unittest.TestCase): ...@@ -113,37 +119,34 @@ class TransfoXLModelTest(unittest.TestCase):
def check_transfo_xl_model_output(self, result): def check_transfo_xl_model_output(self, result):
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["hidden_states_1"].size()), list(result["hidden_states_1"].size()),
[self.batch_size, self.seq_length, self.d_model]) [self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["hidden_states_2"].size()), list(result["hidden_states_2"].size()),
[self.batch_size, self.seq_length, self.d_model]) [self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]), list(list(mem.size()) for mem in result["mems_1"]),
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer) [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_2"]), list(list(mem.size()) for mem in result["mems_2"]),
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer) [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels): def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels):
model = TransfoXLLMHeadModel(config) model = TransfoXLLMHeadModel(config)
model.eval() model.eval()
loss_1, mems_1a = model(input_ids_1, labels=lm_labels) lm_logits_1, mems_1 = model(input_ids_1)
lm_logits_1, mems_1b = model(input_ids_1) loss_1, _, mems_1 = model(input_ids_1, labels=lm_labels)
lm_logits_2, mems_2 = model(input_ids_2, mems=mems_1)
loss_2, mems_2a = model(input_ids_2, labels=lm_labels, mems=mems_1a) loss_2, _, mems_2 = model(input_ids_2, labels=lm_labels, mems=mems_1)
lm_logits_2, mems_2b = model(input_ids_2, mems=mems_1b)
outputs = { outputs = {
"loss_1": loss_1, "loss_1": loss_1,
"mems_1a": mems_1a, "mems_1": mems_1,
"lm_logits_1": lm_logits_1, "lm_logits_1": lm_logits_1,
"mems_1b": mems_1b,
"loss_2": loss_2, "loss_2": loss_2,
"mems_2a": mems_2a, "mems_2": mems_2,
"lm_logits_2": lm_logits_2, "lm_logits_2": lm_logits_2,
"mems_2b": mems_2b,
} }
return outputs return outputs
...@@ -155,14 +158,8 @@ class TransfoXLModelTest(unittest.TestCase): ...@@ -155,14 +158,8 @@ class TransfoXLModelTest(unittest.TestCase):
list(result["lm_logits_1"].size()), list(result["lm_logits_1"].size()),
[self.batch_size, self.seq_length, self.vocab_size]) [self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1a"]), list(list(mem.size()) for mem in result["mems_1"]),
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer) [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1b"]),
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
self.parent.assertListEqual(
list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1a"]),
list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1b"]))
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["loss_2"].size()), list(result["loss_2"].size()),
...@@ -171,31 +168,19 @@ class TransfoXLModelTest(unittest.TestCase): ...@@ -171,31 +168,19 @@ class TransfoXLModelTest(unittest.TestCase):
list(result["lm_logits_2"].size()), list(result["lm_logits_2"].size()),
[self.batch_size, self.seq_length, self.vocab_size]) [self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_2a"]), list(list(mem.size()) for mem in result["mems_2"]),
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer) [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_2b"]), def create_and_check_transfo_xl_commons(self, config, input_ids_1, input_ids_2, lm_labels):
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer) inputs_dict = {'input_ids': input_ids_1}
self.parent.assertListEqual( create_and_check_commons(self, config, inputs_dict)
list(mem[~torch.isnan(mem)].sum() for mem in result["mems_2a"]),
list(mem[~torch.isnan(mem)].sum() for mem in result["mems_2b"]))
def test_default(self): def test_default(self):
self.run_tester(TransfoXLModelTest.TransfoXLModelTester(self)) self.run_tester(TransfoXLModelTest.TransfoXLModelTester(self))
def test_config_to_json_string(self): def test_config(self):
config = TransfoXLConfig(vocab_size_or_config_json_file=96, d_embed=37) config_tester = ConfigTester(self, config_class=TransfoXLConfig, d_embed=37)
obj = json.loads(config.to_json_string()) config_tester.run_common_tests()
self.assertEqual(obj["n_token"], 96)
self.assertEqual(obj["d_embed"], 37)
def test_config_to_json_file(self):
config_first = TransfoXLConfig(vocab_size_or_config_json_file=96, d_embed=37)
json_file_path = "/tmp/config.json"
config_first.to_json_file(json_file_path)
config_second = TransfoXLConfig.from_json_file(json_file_path)
os.remove(json_file_path)
self.assertEqual(config_second.to_dict(), config_first.to_dict())
@pytest.mark.slow @pytest.mark.slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
...@@ -209,28 +194,18 @@ class TransfoXLModelTest(unittest.TestCase): ...@@ -209,28 +194,18 @@ class TransfoXLModelTest(unittest.TestCase):
config_and_inputs = tester.prepare_config_and_inputs() config_and_inputs = tester.prepare_config_and_inputs()
tester.set_seed() tester.set_seed()
config_and_inputs = tester.prepare_config_and_inputs()
output_result = tester.create_transfo_xl_model(*config_and_inputs) output_result = tester.create_transfo_xl_model(*config_and_inputs)
tester.check_transfo_xl_model_output(output_result) tester.check_transfo_xl_model_output(output_result)
tester.set_seed() tester.set_seed()
config_and_inputs = tester.prepare_config_and_inputs()
output_result = tester.create_transfo_xl_lm_head(*config_and_inputs) output_result = tester.create_transfo_xl_lm_head(*config_and_inputs)
tester.check_transfo_xl_lm_head_output(output_result) tester.check_transfo_xl_lm_head_output(output_result)
@classmethod tester.set_seed()
def ids_tensor(cls, shape, vocab_size, rng=None, name=None): config_and_inputs = tester.prepare_config_and_inputs()
"""Creates a random int32 tensor of the shape within the vocab size.""" tester.create_and_check_transfo_xl_commons(*config_and_inputs)
if rng is None:
rng = random.Random()
total_dims = 1
for dim in shape:
total_dims *= dim
values = []
for _ in range(total_dims):
values.append(rng.randint(0, vocab_size - 1))
return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
if __name__ == "__main__": if __name__ == "__main__":
......
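The Transformer-XL LM-head changes above settle on (lm_logits, mems) as the unlabelled output and thread mems between segments. A hedged sketch of that call pattern; the config values mirror the tester above, everything else is illustrative:

import torch
from pytorch_pretrained_bert import TransfoXLConfig, TransfoXLLMHeadModel

config = TransfoXLConfig(vocab_size_or_config_json_file=99, cutoffs=[10, 50, 80],
                         d_model=32, d_embed=32, n_head=4, d_head=8,
                         d_inner=128, div_val=2, n_layer=5, mem_len=30)
model = TransfoXLLMHeadModel(config)
model.eval()
segment_1 = torch.randint(0, 99, (2, 10), dtype=torch.long)
segment_2 = torch.randint(0, 99, (2, 10), dtype=torch.long)
with torch.no_grad():
    lm_logits_1, mems = model(segment_1)             # first segment, fresh memory
    lm_logits_2, mems = model(segment_2, mems=mems)  # second segment reuses the returned mems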
...@@ -25,9 +25,11 @@ import pytest ...@@ -25,9 +25,11 @@ import pytest
import torch import torch
from pytorch_pretrained_bert import (XLNetConfig, XLNetModel, XLNetLMHeadModel) from pytorch_pretrained_bert import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering)
from pytorch_pretrained_bert.modeling_xlnet import PRETRAINED_MODEL_ARCHIVE_MAP from pytorch_pretrained_bert.modeling_xlnet import PRETRAINED_MODEL_ARCHIVE_MAP
from .model_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
class XLNetModelTest(unittest.TestCase): class XLNetModelTest(unittest.TestCase):
class XLNetModelTester(object): class XLNetModelTester(object):
...@@ -42,43 +44,48 @@ class XLNetModelTest(unittest.TestCase): ...@@ -42,43 +44,48 @@ class XLNetModelTest(unittest.TestCase):
use_labels=True, use_labels=True,
vocab_size=99, vocab_size=99,
cutoffs=[10, 50, 80], cutoffs=[10, 50, 80],
d_model=32, hidden_size=32,
n_head=4, num_attention_heads=4,
d_inner=128, d_inner=128,
n_layer=5, num_hidden_layers=5,
max_position_embeddings=10, max_position_embeddings=10,
untie_r=True, untie_r=True,
bi_data=False, bi_data=False,
same_length=False, same_length=False,
seed=1, seed=1,
type_vocab_size=2): type_vocab_size=2,
all_model_classes=(XLNetModel, XLNetLMHeadModel,
XLNetForSequenceClassification, XLNetForQuestionAnswering),
):
self.parent = parent self.parent = parent
self.batch_size = batch_size self.batch_size = batch_size
self.seq_length = seq_length self.seq_length = seq_length
self.mem_len = mem_len self.mem_len = mem_len
# self.key_len = seq_length + mem_len
self.clamp_len = clamp_len self.clamp_len = clamp_len
self.reuse_len = reuse_len self.reuse_len = reuse_len
self.is_training = is_training self.is_training = is_training
self.use_labels = use_labels self.use_labels = use_labels
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.cutoffs = cutoffs self.cutoffs = cutoffs
self.d_model = d_model self.hidden_size = hidden_size
self.n_head = n_head self.num_attention_heads = num_attention_heads
self.d_inner = d_inner self.d_inner = d_inner
self.n_layer = n_layer self.num_hidden_layers = num_hidden_layers
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.bi_data = bi_data self.bi_data = bi_data
self.untie_r = untie_r self.untie_r = untie_r
self.same_length = same_length self.same_length = same_length
self.seed = seed self.seed = seed
self.type_vocab_size = type_vocab_size self.type_vocab_size = type_vocab_size
self.all_model_classes = all_model_classes
def prepare_config_and_inputs(self): def prepare_config_and_inputs(self):
input_ids_1 = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size) input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_ids_2 = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size) input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
segment_ids = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) segment_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
input_ids_q = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size) input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
perm_mask = torch.zeros(self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float) perm_mask = torch.zeros(self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float)
perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
target_mapping = torch.zeros(self.batch_size, 1, self.seq_length + 1, dtype=torch.float) target_mapping = torch.zeros(self.batch_size, 1, self.seq_length + 1, dtype=torch.float)
...@@ -89,8 +96,8 @@ class XLNetModelTest(unittest.TestCase): ...@@ -89,8 +96,8 @@ class XLNetModelTest(unittest.TestCase):
# token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs. # token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
# input_mask: float32 Tensor in shape [bsz, len], the input mask. # input_mask: float32 Tensor in shape [bsz, len], the input mask.
# 0 for real tokens and 1 for padding. # 0 for real tokens and 1 for padding.
# mems: a list of float32 Tensors in shape [bsz, mem_len, d_model], memory # mems: a list of float32 Tensors in shape [bsz, mem_len, hidden_size], memory
# from previous batches. The length of the list equals n_layer. # from previous batches. The length of the list equals num_hidden_layers.
# If None, no memory is used. # If None, no memory is used.
# perm_mask: float32 Tensor in shape [bsz, len, len]. # perm_mask: float32 Tensor in shape [bsz, len, len].
# If perm_mask[k, i, j] = 0, i attend to j in batch k; # If perm_mask[k, i, j] = 0, i attend to j in batch k;
...@@ -108,14 +115,14 @@ class XLNetModelTest(unittest.TestCase): ...@@ -108,14 +115,14 @@ class XLNetModelTest(unittest.TestCase):
lm_labels = None lm_labels = None
if self.use_labels: if self.use_labels:
lm_labels = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size) lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
config = XLNetConfig( config = XLNetConfig(
vocab_size_or_config_json_file=self.vocab_size, vocab_size_or_config_json_file=self.vocab_size,
d_model=self.d_model, d_model=self.hidden_size,
n_head=self.n_head, n_head=self.num_attention_heads,
d_inner=self.d_inner, d_inner=self.d_inner,
n_layer=self.n_layer, n_layer=self.num_hidden_layers,
untie_r=self.untie_r, untie_r=self.untie_r,
max_position_embeddings=self.max_position_embeddings, max_position_embeddings=self.max_position_embeddings,
mem_len=self.mem_len, mem_len=self.mem_len,
...@@ -159,7 +166,7 @@ class XLNetModelTest(unittest.TestCase): ...@@ -159,7 +166,7 @@ class XLNetModelTest(unittest.TestCase):
[self.batch_size, self.seq_length, self.vocab_size]) [self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]), list(list(mem.size()) for mem in result["mems_1"]),
[[self.seq_length, self.batch_size, self.d_model]] * self.n_layer) [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["loss_2"].size()), list(result["loss_2"].size()),
...@@ -169,24 +176,18 @@ class XLNetModelTest(unittest.TestCase): ...@@ -169,24 +176,18 @@ class XLNetModelTest(unittest.TestCase):
[self.batch_size, self.seq_length, self.vocab_size]) [self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertListEqual( self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_2"]), list(list(mem.size()) for mem in result["mems_2"]),
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer) [[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_and_check_xlnet_commons(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels):
inputs_dict = {'input_ids': input_ids_1}
create_and_check_commons(self, config, inputs_dict)
def test_default(self): def test_default(self):
self.run_tester(XLNetModelTest.XLNetModelTester(self)) self.run_tester(XLNetModelTest.XLNetModelTester(self))
def test_config_to_json_string(self): def test_config(self):
config = XLNetConfig(vocab_size_or_config_json_file=96, d_model=16*4) config_tester = ConfigTester(self, config_class=XLNetConfig, d_inner=37)
obj = json.loads(config.to_json_string()) config_tester.run_common_tests()
self.assertEqual(obj["n_token"], 96)
self.assertEqual(obj["d_model"], 16*4)
def test_config_to_json_file(self):
config_first = XLNetConfig(vocab_size_or_config_json_file=96, d_model=16*4)
json_file_path = "/tmp/config.json"
config_first.to_json_file(json_file_path)
config_second = XLNetConfig.from_json_file(json_file_path)
os.remove(json_file_path)
self.assertEqual(config_second.to_dict(), config_first.to_dict())
@pytest.mark.slow @pytest.mark.slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
...@@ -197,27 +198,14 @@ class XLNetModelTest(unittest.TestCase): ...@@ -197,27 +198,14 @@ class XLNetModelTest(unittest.TestCase):
self.assertIsNotNone(model) self.assertIsNotNone(model)
def run_tester(self, tester): def run_tester(self, tester):
config_and_inputs = tester.prepare_config_and_inputs()
tester.set_seed() tester.set_seed()
config_and_inputs = tester.prepare_config_and_inputs()
output_result = tester.create_transfo_xl_lm_head(*config_and_inputs) output_result = tester.create_transfo_xl_lm_head(*config_and_inputs)
tester.check_transfo_xl_lm_head_output(output_result) tester.check_transfo_xl_lm_head_output(output_result)
@classmethod tester.set_seed()
def ids_tensor(cls, shape, vocab_size, rng=None, name=None): config_and_inputs = tester.prepare_config_and_inputs()
"""Creates a random int32 tensor of the shape within the vocab size.""" tester.create_and_check_xlnet_commons(*config_and_inputs)
if rng is None:
rng = random.Random()
total_dims = 1
for dim in shape:
total_dims *= dim
values = []
for _ in range(total_dims):
values.append(rng.randint(0, vocab_size - 1))
return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
@classmethod @classmethod
def mask_tensor(cls, shape, vocab_size, rng=None, name=None): def mask_tensor(cls, shape, vocab_size, rng=None, name=None):
......