Commit 1484d67d authored by thomwolf

[LARGE] updating all tests and API

parent 4f8b5f68
......@@ -41,6 +41,12 @@ class PretrainedConfig(object):
"""
pretrained_config_archive_map = {}
def __init__(self, **kwargs):
self.finetuning_task = kwargs.pop('finetuning_task', None)
self.num_labels = kwargs.pop('num_labels', 2)
self.output_attentions = kwargs.pop('output_attentions', False)
self.output_hidden_states = kwargs.pop('output_hidden_states', False)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
"""
......@@ -114,6 +120,9 @@ class PretrainedConfig(object):
text = reader.read()
return cls.from_dict(json.loads(text))
def __eq__(self, other):
return self.__dict__ == other.__dict__
def __repr__(self):
return str(self.to_json_string())
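As an illustration of the shared configuration attributes introduced in this hunk, here is a minimal sketch; the import path is an assumption and may not match the package layout at this commit.
```python
# Sketch only: the package/import name is an assumption for this commit.
from pytorch_transformers import BertConfig

# Extra keyword arguments are popped by PretrainedConfig.__init__ and stored as
# attributes common to all model configs.
config = BertConfig(num_labels=3, output_attentions=True, output_hidden_states=True)
assert config.num_labels == 3 and config.output_attentions and config.output_hidden_states
```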
......@@ -133,12 +142,11 @@ class PretrainedConfig(object):
class PreTrainedModel(nn.Module):
""" An abstract class to handle weights initialization and
""" An abstract class to handle storing model config and
a simple interface for downloading and loading pretrained models.
"""
config_class = PretrainedConfig
pretrained_model_archive_map = {}
pretrained_config_archive_map = {}
load_tf_weights = lambda model, config, path: None
base_model_prefix = ""
......@@ -151,8 +159,16 @@ class PreTrainedModel(nn.Module):
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
# Save config in model
self.config = config
def prune_heads(self, heads_to_prune):
""" Prunes heads of the base model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
"""
model_to_prune = getattr(self, self.base_model_prefix, self) # get the base model if needed
model_to_prune._prune_heads(heads_to_prune)
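A short usage sketch of the new generic `prune_heads` entry point; the layer and head indices are illustrative and the import path is an assumption.
```python
from pytorch_transformers import BertConfig, BertModel  # import path is an assumption

config = BertConfig()          # defaults: 12 layers, 12 attention heads
model = BertModel(config)
# Prune heads 0 and 2 in layer 0 and head 1 in layer 11; the call is forwarded
# to the base model's _prune_heads via the base_model_prefix lookup above.
model.prune_heads({0: [0, 2], 11: [1]})
```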
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
"""
......@@ -175,24 +191,22 @@ class PreTrainedModel(nn.Module):
*inputs, **kwargs: additional input for the specific XLNet class
(ex: num_labels for XLNetForSequenceClassification)
"""
state_dict = kwargs.get('state_dict', None)
kwargs.pop('state_dict', None)
cache_dir = kwargs.get('cache_dir', None)
kwargs.pop('cache_dir', None)
from_tf = kwargs.get('from_tf', False)
kwargs.pop('from_tf', None)
state_dict = kwargs.pop('state_dict', None)
cache_dir = kwargs.pop('cache_dir', None)
from_tf = kwargs.pop('from_tf', None)
# Load config
config = cls.config_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
# Load model
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path]
else:
if from_tf:
# Directly load from a TensorFlow checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
else:
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
......@@ -210,47 +224,15 @@ class PreTrainedModel(nn.Module):
', '.join(cls.pretrained_model_archive_map.keys()),
archive_file))
return None
try:
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
except EnvironmentError:
if pretrained_model_name_or_path in cls.pretrained_config_archive_map:
logger.error(
"Couldn't reach server at '{}' to download pretrained model configuration file.".format(
config_file))
else:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(cls.pretrained_config_archive_map.keys()),
config_file))
return None
if resolved_archive_file == archive_file and resolved_config_file == config_file:
if resolved_archive_file == archive_file:
logger.info("loading weights file {}".format(archive_file))
logger.info("loading configuration file {}".format(config_file))
else:
logger.info("loading weights file {} from cache at {}".format(
archive_file, resolved_archive_file))
logger.info("loading configuration file {} from cache at {}".format(
config_file, resolved_config_file))
# Load config
config = cls.config_class.from_json_file(resolved_config_file)
# Update config with kwargs if needed
to_remove = []
for key, value in kwargs.items():
if hasattr(config, key):
setattr(config, key, value)
to_remove.append(key)
for key in to_remove:
kwargs.pop(key, None)
logger.info("Model config {}".format(config))
# Instantiate model.
model = cls(config, *inputs, **kwargs)
model = cls(config)
if state_dict is None and not from_tf:
state_dict = torch.load(resolved_archive_file, map_location='cpu')
if from_tf:
......@@ -275,7 +257,7 @@ class PreTrainedModel(nn.Module):
if child is not None:
load(child, prefix + name + '.')
# Be able to load base models as well as derived models (with heads)
# Make sure we are able to load base models as well as derived models (with heads)
start_prefix = ''
model_to_load = model
if not hasattr(model, cls.base_model_prefix) and any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
......
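Based on the kwargs handling in the loading hunks above, a hedged example of the intended API; 'bert-base-uncased' is a real shortcut name, but the exact override behaviour is inferred from the diff rather than documented here.
```python
from pytorch_transformers import BertForSequenceClassification  # import path is an assumption

# state_dict / cache_dir / from_tf are popped first; the remaining kwargs are
# forwarded to the config, so config attributes can be overridden at load time.
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased', num_labels=3, output_attentions=True)
```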
......@@ -155,7 +155,7 @@ class BertConfig(PretrainedConfig):
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
vocab_size_or_config_json_file,
vocab_size_or_config_json_file=30522,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
......@@ -167,7 +167,7 @@ class BertConfig(PretrainedConfig):
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
finetuning_task=None):
**kwargs):
"""Constructs BertConfig.
Args:
......@@ -192,8 +192,8 @@ class BertConfig(PretrainedConfig):
initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices.
layer_norm_eps: The epsilon used by LayerNorm.
finetuning_task: name of the glue task on which the model was fine-tuned if any
"""
super(BertConfig, self).__init__(**kwargs)
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
......@@ -213,7 +213,6 @@ class BertConfig(PretrainedConfig):
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.finetuning_task = finetuning_task
else:
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")
......@@ -270,13 +269,13 @@ class BertEmbeddings(nn.Module):
class BertSelfAttention(nn.Module):
def __init__(self, config, output_attentions=False):
def __init__(self, config):
super(BertSelfAttention, self).__init__()
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
self.output_attentions = output_attentions
self.output_attentions = config.output_attentions
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
......@@ -344,10 +343,9 @@ class BertSelfOutput(nn.Module):
class BertAttention(nn.Module):
def __init__(self, config, output_attentions=False):
def __init__(self, config):
super(BertAttention, self).__init__()
self.output_attentions = output_attentions
self.self = BertSelfAttention(config, output_attentions=output_attentions)
self.self = BertSelfAttention(config)
self.output = BertSelfOutput(config)
def prune_heads(self, heads):
......@@ -404,10 +402,9 @@ class BertOutput(nn.Module):
class BertLayer(nn.Module):
def __init__(self, config, output_attentions=False):
def __init__(self, config):
super(BertLayer, self).__init__()
self.output_attentions = output_attentions
self.attention = BertAttention(config, output_attentions=output_attentions)
self.attention = BertAttention(config)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
......@@ -421,11 +418,11 @@ class BertLayer(nn.Module):
class BertEncoder(nn.Module):
def __init__(self, config, output_attentions=False, output_hidden_states=False):
def __init__(self, config):
super(BertEncoder, self).__init__()
self.output_attentions = output_attentions
self.output_hidden_states = output_hidden_states
layer = BertLayer(config, output_attentions=output_attentions)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
layer = BertLayer(config)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, attention_mask, head_mask=None):
......@@ -546,9 +543,6 @@ class BertPreTrainedModel(PreTrainedModel):
load_tf_weights = load_tf_weights_in_bert
base_model_prefix = "bert"
def __init__(self, *inputs, **kwargs):
super(BertPreTrainedModel, self).__init__(*inputs, **kwargs)
def init_weights(self, module):
""" Initialize the weights.
"""
......@@ -612,19 +606,19 @@ class BertModel(BertPreTrainedModel):
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, output_attentions=False, output_hidden_states=False):
def __init__(self, config):
super(BertModel, self).__init__(config)
self.output_attentions = output_attentions
self.output_hidden_states = output_hidden_states
self.embeddings = BertEmbeddings(config)
self.encoder = BertEncoder(config, output_attentions=output_attentions,
output_hidden_states=output_hidden_states)
self.encoder = BertEncoder(config)
self.pooler = BertPooler(config)
self.apply(self.init_weights)
def prune_heads(self, heads_to_prune):
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
......@@ -730,14 +724,12 @@ class BertForPreTraining(BertPreTrainedModel):
masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, output_attentions=False, output_hidden_states=False):
def __init__(self, config):
super(BertForPreTraining, self).__init__(config)
self.output_attentions = output_attentions
self.output_hidden_states = output_hidden_states
self.bert = BertModel(config, output_attentions=output_attentions,
output_hidden_states=output_hidden_states)
self.bert = BertModel(config)
self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
self.apply(self.init_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
......@@ -809,13 +801,12 @@ class BertForMaskedLM(BertPreTrainedModel):
masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, output_attentions=False, output_hidden_states=False):
def __init__(self, config):
super(BertForMaskedLM, self).__init__(config)
self.output_attentions = output_attentions
self.output_hidden_states = output_hidden_states
self.bert = BertModel(config, output_attentions=output_attentions )
self.bert = BertModel(config)
self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
self.apply(self.init_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, head_mask=None):
......@@ -880,12 +871,10 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, output_attentions=False, output_hidden_states=False):
def __init__(self, config):
super(BertForNextSentencePrediction, self).__init__(config)
self.output_attentions = output_attentions
self.output_hidden_states = output_hidden_states
self.bert = BertModel(config, output_attentions=output_attentions)
self.bert = BertModel(config)
self.cls = BertOnlyNSPHead(config)
self.apply(self.init_weights)
......@@ -954,15 +943,13 @@ class BertForSequenceClassification(BertPreTrainedModel):
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, num_labels=2, output_attentions=False, output_hidden_states=False):
def __init__(self, config):
super(BertForSequenceClassification, self).__init__(config)
self.output_attentions = output_attentions
self.output_hidden_states = output_hidden_states
self.num_labels = num_labels
self.num_labels = config.num_labels
self.bert = BertModel(config, output_attentions=output_attentions)
self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, num_labels)
self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
self.apply(self.init_weights)
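A minimal sketch of the constructor change above: the label count now travels with the config instead of being a separate constructor argument (import path assumed).
```python
from pytorch_transformers import BertConfig, BertForSequenceClassification  # assumed import

config = BertConfig(num_labels=5)
model = BertForSequenceClassification(config)
assert model.classifier.out_features == 5   # classifier sized from config.num_labels
```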
......@@ -997,7 +984,6 @@ class BertForMultipleChoice(BertPreTrainedModel):
`config`: a BertConfig class instance with the configuration to build a new model
`output_attentions`: If True, also output attention weights computed by the model at each layer. Default: False
`output_hidden_states`: If True, also output hidden states computed by the model at each layer. Default: False
`num_choices`: the number of classes for the classifier. Default = 2.
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
......@@ -1030,25 +1016,23 @@ class BertForMultipleChoice(BertPreTrainedModel):
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
num_choices = 2
model = BertForMultipleChoice(config, num_choices)
model = BertForMultipleChoice(config)
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, num_choices=2, output_attentions=False, output_hidden_states=False):
def __init__(self, config):
super(BertForMultipleChoice, self).__init__(config)
self.output_attentions = output_attentions
self.output_hidden_states = output_hidden_states
self.num_choices = num_choices
self.bert = BertModel(config, output_attentions=output_attentions)
self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, 1)
self.apply(self.init_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, head_mask=None):
""" Input shapes should be [bsz, num choices, seq length] """
num_choices = input_ids.shape[1]
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
......@@ -1057,7 +1041,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, self.num_choices)
reshaped_logits = logits.view(-1, num_choices)
outputs = [reshaped_logits] + outputs[2:] # add hidden states and attention if they are here
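A hedged sketch of the new calling convention: `num_choices` is read from the second dimension of `input_ids` rather than fixed in the constructor (imports and the return layout are assumptions based on this hunk).
```python
import torch
from pytorch_transformers import BertConfig, BertForMultipleChoice  # assumed import

config = BertConfig()
model = BertForMultipleChoice(config)
input_ids = torch.zeros(8, 4, 128, dtype=torch.long)  # (bsz=8, num_choices=4, seq_len=128)
outputs = model(input_ids)            # no labels given, so no loss is prepended (assumption)
reshaped_logits = outputs[0]          # shape (8, 4)
```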
......@@ -1118,15 +1102,13 @@ class BertForTokenClassification(BertPreTrainedModel):
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, num_labels=2, output_attentions=False, output_hidden_states=False):
def __init__(self, config):
super(BertForTokenClassification, self).__init__(config)
self.output_attentions = output_attentions
self.output_hidden_states = output_hidden_states
self.num_labels = num_labels
self.num_labels = config.num_labels
self.bert = BertModel(config, output_attentions=output_attentions)
self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, num_labels)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.apply(self.init_weights)
......@@ -1204,12 +1186,12 @@ class BertForQuestionAnswering(BertPreTrainedModel):
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, output_attentions=False, output_hidden_states=False):
def __init__(self, config):
super(BertForQuestionAnswering, self).__init__(config)
self.output_attentions = output_attentions
self.output_hidden_states = output_hidden_states
self.bert = BertModel(config, output_attentions=output_attentions)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
self.num_labels = config.num_labels
self.bert = BertModel(config)
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
self.apply(self.init_weights)
......
......@@ -119,7 +119,8 @@ class GPT2Config(PretrainedConfig):
attn_pdrop=0.1,
layer_norm_epsilon=1e-5,
initializer_range=0.02,
predict_special_tokens=True
predict_special_tokens=True,
**kwargs
):
"""Constructs GPT2Config.
......@@ -142,6 +143,8 @@ class GPT2Config(PretrainedConfig):
initializing all weight matrices.
predict_special_tokens: should we predict special tokens (when the model has a LM head)
"""
super(GPT2Config, self).__init__(**kwargs)
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
......@@ -174,8 +177,10 @@ class GPT2Config(PretrainedConfig):
class Attention(nn.Module):
def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False):
def __init__(self, nx, n_ctx, config, scale=False):
super(Attention, self).__init__()
self.output_attentions = config.output_attentions
n_state = nx # in Attention: n_state=768 (nx=n_embd)
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
assert n_state % config.n_head == 0
......@@ -184,10 +189,6 @@ class Attention(nn.Module):
self.split_size = n_state
self.scale = scale
self.output_attentions = output_attentions
self.keep_multihead_output = keep_multihead_output
self.multihead_output = None
self.c_attn = Conv1D(n_state * 3, nx)
self.c_proj = Conv1D(n_state, nx)
self.attn_dropout = nn.Dropout(config.attn_pdrop)
......@@ -224,9 +225,10 @@ class Attention(nn.Module):
if head_mask is not None:
w = w * head_mask
outputs = [torch.matmul(w, v)]
if self.output_attentions:
return w, torch.matmul(w, v)
return torch.matmul(w, v)
outputs.append(w)
return outputs
def merge_heads(self, x):
x = x.permute(0, 2, 1, 3).contiguous()
......@@ -253,19 +255,15 @@ class Attention(nn.Module):
value = torch.cat((past_value, value), dim=-2)
present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
a = self._attn(query, key, value, head_mask)
if self.keep_multihead_output:
self.multihead_output = a
self.multihead_output.retain_grad()
attn_outputs = self._attn(query, key, value, head_mask)
a = attn_outputs[0]
if self.output_attentions:
attentions, a = a
a = self.merge_heads(a)
a = self.c_proj(a)
a = self.resid_dropout(a)
if self.output_attentions:
return attentions, a, present
return a, present
outputs = [a, present] + attn_outputs[1:]
return outputs # a, present, (attentions)
class MLP(nn.Module):
......@@ -284,27 +282,24 @@ class MLP(nn.Module):
class Block(nn.Module):
def __init__(self, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False):
def __init__(self, n_ctx, config, scale=False):
super(Block, self).__init__()
nx = config.n_embd
self.output_attentions = output_attentions
self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
self.attn = Attention(nx, n_ctx, config, scale, output_attentions, keep_multihead_output)
self.attn = Attention(nx, n_ctx, config, scale)
self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
self.mlp = MLP(4 * nx, config)
def forward(self, x, layer_past=None, head_mask=None):
output_attn = self.attn(self.ln_1(x), layer_past=layer_past, head_mask=head_mask)
if self.output_attentions:
attentions, a, present = output_attn
else:
a, present = output_attn
a = output_attn[0] # output_attn: a, present, (attentions)
x = x + a
m = self.mlp(self.ln_2(x))
x = x + m
if self.output_attentions:
return attentions, x, present
return x, present
outputs = [x] + output_attn[1:]
return outputs # x, present, (attentions)
class GPT2LMHead(nn.Module):
......@@ -342,12 +337,17 @@ class GPT2MultipleChoiceHead(nn.Module):
nn.init.normal_(self.linear.weight, std=0.02)
nn.init.normal_(self.linear.bias, 0)
def forward(self, hidden_states, mc_token_ids):
# Classification logits
# hidden_state (bsz, num_choices, seq_length, hidden_size)
# mc_token_ids (bsz, num_choices)
def forward(self, hidden_states, mc_token_ids=None):
""" Extract classification token hidden state and project it using self.linear
hidden_state: shape (bsz, num_choices, seq_length, hidden_size)
mc_token_ids: [optional] index of the classification token, shape (bsz, num_choices)
if mc_token_ids is None we take the last token of the sequence as the classification token
"""
if mc_token_ids is None:
mc_token_ids = torch.full_like(hidden_states[:, :, :1, :], hidden_states.shape[2] - 1, dtype=torch.long)
else:
mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))
# (bsz, num_choices, 1, hidden_size)
# mc_token_ids has shape (bsz, num_choices, 1, hidden_size)
multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)
# (bsz, num_choices, hidden_size)
multiple_choice_h = self.dropout(multiple_choice_h.transpose(1, 2)).transpose(1, 2)
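The gather above can be hard to read; here is a self-contained illustration of the indexing it performs (pure PyTorch, no library code; values are illustrative).
```python
import torch

bsz, num_choices, seq_len, hidden = 2, 3, 5, 4
hidden_states = torch.randn(bsz, num_choices, seq_len, hidden)
mc_token_ids = torch.tensor([[4, 2, 0], [1, 1, 3]])            # (bsz, num_choices)
index = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden)
picked = hidden_states.gather(2, index).squeeze(2)             # (bsz, num_choices, hidden)
assert torch.equal(picked[0, 0], hidden_states[0, 0, 4])       # token 4 of choice 0
```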
......@@ -362,13 +362,9 @@ class GPT2PreTrainedModel(PreTrainedModel):
"""
config_class = GPT2Config
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_gpt2
base_model_prefix = "transformer"
def __init__(self, *inputs, **kwargs):
super(GPT2PreTrainedModel, self).__init__(*inputs, **kwargs)
def init_weights(self, module):
""" Initialize the weights.
"""
......@@ -403,126 +399,9 @@ class GPT2PreTrainedModel(PreTrainedModel):
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models
*inputs, **kwargs: additional input for the specific GPT2 class
"""
# state_dict = kwargs.get('state_dict', None)
# kwargs.pop('state_dict', None)
# cache_dir = kwargs.get('cache_dir', None)
# kwargs.pop('cache_dir', None)
# from_tf = kwargs.get('from_tf', False)
# kwargs.pop('from_tf', None)
num_special_tokens = kwargs.get('num_special_tokens', None)
kwargs.pop('num_special_tokens', None)
# if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
# archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
# config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
# else:
# archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
# config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
# # redirect to the cache, if necessary
# try:
# resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
# except EnvironmentError:
# if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
# logger.error(
# "Couldn't reach server at '{}' to download pretrained weights.".format(
# archive_file))
# else:
# logger.error(
# "Model name '{}' was not found in model name list ({}). "
# "We assumed '{}' was a path or url but couldn't find file {} "
# "at this path or url.".format(
# pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
# archive_file
# )
# )
# return None
# try:
# resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
# except EnvironmentError:
# if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
# logger.error(
# "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
# config_file))
# else:
# logger.error(
# "Model name '{}' was not found in model name list ({}). "
# "We assumed '{}' was a path or url but couldn't find file {} "
# "at this path or url.".format(
# pretrained_model_name_or_path, ", ".join(PRETRAINED_CONFIG_ARCHIVE_MAP.keys()), pretrained_model_name_or_path,
# config_file
# )
# )
# return None
# if resolved_archive_file == archive_file and resolved_config_file == config_file:
# logger.info("loading weights file {}".format(archive_file))
# logger.info("loading configuration file {}".format(config_file))
# else:
# logger.info("loading weights file {} from cache at {}".format(
# archive_file, resolved_archive_file))
# logger.info("loading configuration file {} from cache at {}".format(
# config_file, resolved_config_file))
# # Load config
# config = GPT2Config.from_json_file(resolved_config_file)
# logger.info("Model config {}".format(config))
# # Instantiate model.
# model = cls(config, *inputs, **kwargs)
# if state_dict is None and not from_tf:
# state_dict = torch.load(resolved_archive_file, map_location='cpu')
# if from_tf:
# # Directly load from a TensorFlow checkpoint (stored as NumPy array)
# return load_tf_weights_in_gpt2(model, resolved_archive_file)
# old_keys = []
# new_keys = []
# for key in state_dict.keys():
# new_key = None
# if key.endswith(".g"):
# new_key = key[:-2] + ".weight"
# elif key.endswith(".b"):
# new_key = key[:-2] + ".bias"
# elif key.endswith(".w"):
# new_key = key[:-2] + ".weight"
# if new_key:
# old_keys.append(key)
# new_keys.append(new_key)
# for old_key, new_key in zip(old_keys, new_keys):
# state_dict[new_key] = state_dict.pop(old_key)
# missing_keys = []
# unexpected_keys = []
# error_msgs = []
# # copy state_dict so _load_from_state_dict can modify it
# metadata = getattr(state_dict, "_metadata", None)
# state_dict = state_dict.copy()
# if metadata is not None:
# state_dict._metadata = metadata
# def load(module, prefix=""):
# local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
# module._load_from_state_dict(
# state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
# )
# for name, child in module._modules.items():
# if child is not None:
# load(child, prefix + name + ".")
# start_model = model
# if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()):
# start_model = model.transformer
# load(start_model, prefix="")
# if len(missing_keys) > 0:
# logger.info(
# "Weights of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)
# )
# if len(unexpected_keys) > 0:
# logger.info(
# "Weights from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys)
# )
# if len(error_msgs) > 0:
# raise RuntimeError(
# "Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
# )
num_special_tokens = kwargs.pop('num_special_tokens', None)
model = PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs)
# Add additional embeddings for special tokens if needed
# This step also makes sure we are still sharing the output and input embeddings after loading weights
......@@ -553,8 +432,6 @@ class GPT2Model(GPT2PreTrainedModel):
Params:
`config`: a GPT2Config class instance with the configuration to build a new model
`output_attentions`: If True, also output attention weights computed by the model at each layer. Default: False
`keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
This can be used to compute head importance metrics. Default: False
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] (or more generally [d_1, ..., d_n, sequence_length]
......@@ -591,14 +468,15 @@ class GPT2Model(GPT2PreTrainedModel):
```
"""
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
def __init__(self, config):
super(GPT2Model, self).__init__(config)
self.output_attentions = output_attentions
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.wte = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
self.wpe = nn.Embedding(config.n_positions, config.n_embd)
self.drop = nn.Dropout(config.embd_pdrop)
block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output)
block = Block(config.n_ctx, config, scale=True)
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
......@@ -618,19 +496,13 @@ class GPT2Model(GPT2PreTrainedModel):
# Copy word embeddings from the previous weights
self.wte.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :]
def prune_heads(self, heads_to_prune):
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
"""
for layer, heads in heads_to_prune.items():
self.h[layer].attn.prune_heads(heads)
def get_multihead_outputs(self):
""" Gather all multi-head outputs.
Return: list (layers) of multihead module outputs with gradients
"""
return [h.attn.multihead_output for h in self.h]
def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None, head_mask=None):
if past is None:
past_length = 0
......@@ -675,20 +547,32 @@ class GPT2Model(GPT2PreTrainedModel):
all_attentions = []
all_hidden_states = []
for i, (block, layer_past) in enumerate(zip(self.h, past)):
if self.output_hidden_states:
all_hidden_states.append(hidden_states.view(*output_shape))
outputs = block(hidden_states, layer_past, head_mask[i])
if self.output_attentions:
attentions, hidden_states, present = outputs
all_attentions.append(attentions)
else:
hidden_states, present = outputs
hidden_states, present = outputs[:2]
presents.append(present)
if self.output_attentions:
all_attentions.append(outputs[2])
hidden_states = self.ln_f(hidden_states)
all_hidden_states.append(hidden_states.view(*output_shape))
hidden_states = hidden_states.view(*output_shape)
# Add last hidden state
if self.output_hidden_states:
all_hidden_states.append(hidden_states)
outputs = [hidden_states, presents]
if self.output_hidden_states:
outputs.append(all_hidden_states)
if self.output_attentions:
return all_attentions, all_hidden_states, presents
return all_hidden_states, presents
# leave the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
all_attentions = list(t.view(*attention_output_shape) for t in all_attentions)
outputs.append(all_attentions)
return outputs # last hidden state, presents, (all hidden_states), (attentions)
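A hedged sketch of consuming the new list-style return value; the optional entries only appear when the corresponding config flags are set (imports assumed).
```python
import torch
from pytorch_transformers import GPT2Config, GPT2Model  # assumed import

config = GPT2Config(vocab_size_or_config_json_file=50257,
                    output_hidden_states=True, output_attentions=True)
model = GPT2Model(config)
outputs = model(torch.tensor([[31, 51, 99]]))
last_hidden_state, presents = outputs[0], outputs[1]
all_hidden_states = outputs[2]   # only present because output_hidden_states=True
all_attentions = outputs[3]      # only present because output_attentions=True
```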
class GPT2LMHeadModel(GPT2PreTrainedModel):
......@@ -740,10 +624,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
```
"""
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
def __init__(self, config):
super(GPT2LMHeadModel, self).__init__(config)
self.transformer = GPT2Model(config, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output)
self.transformer = GPT2Model(config)
self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
self.apply(self.init_weights)
......@@ -756,14 +639,12 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None, head_mask=None):
transformer_output = self.transformer(input_ids, position_ids, token_type_ids, past, head_mask)
if self.transformer.output_attentions:
all_attentions, hidden_states, presents = transformer_output
else:
hidden_states, presents = transformer_output
hidden_states = hidden_states[-1]
transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, past, head_mask)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
outputs = [lm_logits] + transformer_outputs[1:]
if lm_labels is not None:
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
......@@ -772,10 +653,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
loss_fct = CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1))
return loss
if self.transformer.output_attentions:
return all_attentions, lm_logits, presents
return lm_logits, presents
outputs = [loss] + outputs
return outputs # (loss), lm_logits, presents, (all hidden_states), (attentions)
class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
......@@ -832,12 +712,12 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
```
"""
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
def __init__(self, config):
super(GPT2DoubleHeadsModel, self).__init__(config)
self.transformer = GPT2Model(config, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output)
self.transformer = GPT2Model(config)
self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
self.multiple_choice_head = GPT2MultipleChoiceHead(config)
self.apply(self.init_weights)
def set_num_special_tokens(self, num_special_tokens, predict_special_tokens=True):
......@@ -848,28 +728,26 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
self.transformer.set_num_special_tokens(num_special_tokens)
self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)
def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None,
def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
position_ids=None, past=None, head_mask=None):
transformer_output = self.transformer(input_ids, position_ids, token_type_ids, past, head_mask)
if self.transformer.output_attentions:
all_attentions, hidden_states, presents = transformer_output
else:
hidden_states, presents = transformer_output
hidden_states = hidden_states[-1]
transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, past, head_mask)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
losses = []
outputs = [lm_logits, mc_logits] + transformer_outputs[1:]
if mc_labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)),
mc_labels.view(-1))
outputs = [loss] + outputs
if lm_labels is not None:
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = lm_labels[..., 1:].contiguous()
loss_fct = CrossEntropyLoss(ignore_index=-1)
losses.append(loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)))
if mc_labels is not None:
loss_fct = CrossEntropyLoss()
losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
if losses:
return losses
if self.transformer.output_attentions:
return all_attentions, lm_logits, mc_logits, presents
return lm_logits, mc_logits, presents
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1))
outputs = [loss] + outputs
return outputs # (lm loss), (mc loss), lm logits, mc logits, presents, (all hidden_states), (attentions)
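A hedged sketch of the double-heads return order with both label sets supplied, matching the trailing comment above (imports and example tensors are assumptions).
```python
import torch
from pytorch_transformers import GPT2Config, GPT2DoubleHeadsModel  # assumed import

config = GPT2Config(vocab_size_or_config_json_file=50257)
model = GPT2DoubleHeadsModel(config)
input_ids = torch.tensor([[[31, 51, 99], [15, 5, 0]]])   # (bsz=1, num_choices=2, seq_len=3)
mc_token_ids = torch.tensor([[2, 2]])                    # classification token per choice
outputs = model(input_ids, mc_token_ids=mc_token_ids,
                lm_labels=input_ids, mc_labels=torch.tensor([0]))
lm_loss, mc_loss, lm_logits, mc_logits = outputs[0], outputs[1], outputs[2], outputs[3]
```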
......@@ -147,7 +147,8 @@ class OpenAIGPTConfig(PretrainedConfig):
attn_pdrop=0.1,
layer_norm_epsilon=1e-5,
initializer_range=0.02,
predict_special_tokens=True
predict_special_tokens=True,
**kwargs
):
"""Constructs OpenAIGPTConfig.
......@@ -172,6 +173,8 @@ class OpenAIGPTConfig(PretrainedConfig):
initializing all weight matrices.
predict_special_tokens: should we predict special tokens (when the model has a LM head)
"""
super(OpenAIGPTConfig, self).__init__(**kwargs)
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
......@@ -205,7 +208,7 @@ class OpenAIGPTConfig(PretrainedConfig):
class Attention(nn.Module):
def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False):
def __init__(self, nx, n_ctx, config, scale=False):
super(Attention, self).__init__()
n_state = nx # in Attention: n_state=768 (nx=n_embd)
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
......@@ -215,9 +218,7 @@ class Attention(nn.Module):
self.split_size = n_state
self.scale = scale
self.output_attentions = output_attentions
self.keep_multihead_output = keep_multihead_output
self.multihead_output = None
self.output_attentions = config.output_attentions
self.c_attn = Conv1D(n_state * 3, nx)
self.c_proj = Conv1D(n_state, nx)
......@@ -256,9 +257,10 @@ class Attention(nn.Module):
if head_mask is not None:
w = w * head_mask
outputs = [torch.matmul(w, v)]
if self.output_attentions:
return w, torch.matmul(w, v)
return torch.matmul(w, v)
outputs.append(w)
return outputs
def merge_heads(self, x):
x = x.permute(0, 2, 1, 3).contiguous()
......@@ -280,19 +282,15 @@ class Attention(nn.Module):
key = self.split_heads(key, k=True)
value = self.split_heads(value)
a = self._attn(query, key, value, head_mask)
if self.keep_multihead_output:
self.multihead_output = a
self.multihead_output.retain_grad()
attn_outputs = self._attn(query, key, value, head_mask)
a = attn_outputs[0]
if self.output_attentions:
attentions, a = a
a = self.merge_heads(a)
a = self.c_proj(a)
a = self.resid_dropout(a)
if self.output_attentions:
return attentions, a
return a
outputs = [a] + attn_outputs[1:]
return outputs # a, (attentions)
class MLP(nn.Module):
......@@ -311,25 +309,24 @@ class MLP(nn.Module):
class Block(nn.Module):
def __init__(self, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False):
def __init__(self, n_ctx, config, scale=False):
super(Block, self).__init__()
nx = config.n_embd
self.output_attentions = output_attentions
self.attn = Attention(nx, n_ctx, config, scale, output_attentions, keep_multihead_output)
self.attn = Attention(nx, n_ctx, config, scale)
self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
self.mlp = MLP(4 * nx, config)
self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
def forward(self, x, head_mask=None):
a = self.attn(x, head_mask=head_mask)
if self.output_attentions:
attentions, a = a
attn_outputs = self.attn(x, head_mask=head_mask)
a = attn_outputs[0]
n = self.ln_1(x + a)
m = self.mlp(n)
h = self.ln_2(n + m)
if self.output_attentions:
return attentions, h
return h
outputs = [h] + attn_outputs[1:]
return outputs
class OpenAIGPTLMHead(nn.Module):
......@@ -368,10 +365,15 @@ class OpenAIGPTMultipleChoiceHead(nn.Module):
nn.init.normal_(self.linear.weight, std=0.02)
nn.init.normal_(self.linear.bias, 0)
def forward(self, hidden_states, mc_token_ids):
# Classification logits
# hidden_state (bsz, num_choices, seq_length, hidden_size)
# mc_token_ids (bsz, num_choices)
def forward(self, hidden_states, mc_token_ids=None):
""" Extract classification token hidden state and project it using self.linear
hidden_state: hidden state of shape (bsz, num_choices, seq_length, hidden_size)
mc_token_ids: [optional] index of the classification token, shape (bsz, num_choices)
if mc_token_ids is None we take the last token of the sequence as the classification token
"""
if mc_token_ids is None:
mc_token_ids = torch.full_like(hidden_states[:, :, :1, :], hidden_states.shape[2] - 1, dtype=torch.long)
else:
mc_token_ids = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, hidden_states.size(-1))
# (bsz, num_choices, 1, hidden_size)
multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)
......@@ -388,13 +390,9 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
"""
config_class = OpenAIGPTConfig
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_openai_gpt
base_model_prefix = "transformer"
def __init__(self, *inputs, **kwargs):
super(OpenAIGPTPreTrainedModel, self).__init__(*inputs, **kwargs)
def init_weights(self, module):
""" Initialize the weights.
"""
......@@ -495,14 +493,15 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
```
"""
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
def __init__(self, config):
super(OpenAIGPTModel, self).__init__(config)
self.output_attentions = output_attentions
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.tokens_embed = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
self.drop = nn.Dropout(config.embd_pdrop)
block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output)
block = Block(config.n_ctx, config, scale=True)
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
self.apply(self.init_weights)
......@@ -521,19 +520,13 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
# Copy word embeddings from the previous weights
self.tokens_embed.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :]
def prune_heads(self, heads_to_prune):
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
"""
for layer, heads in heads_to_prune.items():
self.h[layer].attn.prune_heads(heads)
def get_multihead_outputs(self):
""" Gather all multi-head outputs.
Return: list (layers) of multihead module outputs with gradients
"""
return [h.attn.multihead_output for h in self.h]
def forward(self, input_ids, position_ids=None, token_type_ids=None, head_mask=None):
if position_ids is None:
# This was used when we had a single embedding matrix for position and token embeddings
......@@ -574,19 +567,26 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
output_shape = input_shape + (hidden_states.size(-1),)
all_attentions = []
all_hidden_states = [hidden_states.view(*output_shape)]
all_hidden_states = []
for i, block in enumerate(self.h):
if self.output_hidden_states:
all_hidden_states.append(hidden_states.view(*output_shape))
outputs = block(hidden_states, head_mask[i])
hidden_states = outputs[0]
if self.output_attentions:
attentions, hidden_states = outputs
all_attentions.append(attentions)
else:
hidden_states = outputs
all_attentions.append(outputs[1])
# Add last layer
if self.output_hidden_states:
all_hidden_states.append(hidden_states.view(*output_shape))
outputs = [hidden_states.view(*output_shape)]
if self.output_hidden_states:
outputs.append(all_hidden_states)
if self.output_attentions:
return all_attentions, all_hidden_states
return all_hidden_states
outputs.append(all_attentions)
return outputs # last hidden state, (all hidden states), (all attentions)
class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
......@@ -650,10 +650,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
```
"""
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
def __init__(self, config):
super(OpenAIGPTLMHeadModel, self).__init__(config)
self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output)
self.transformer = OpenAIGPTModel(config)
self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
self.apply(self.init_weights)
......@@ -666,12 +665,11 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight, predict_special_tokens=predict_special_tokens)
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, head_mask=None):
hidden_states = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
if self.transformer.output_attentions:
all_attentions, hidden_states = hidden_states
hidden_states = hidden_states[-1]
transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
outputs = [lm_logits] + transformer_outputs[1:]
if lm_labels is not None:
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
......@@ -680,10 +678,9 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
loss_fct = CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1))
return loss
if self.transformer.output_attentions:
return all_attentions, lm_logits
return lm_logits
outputs = [loss] + outputs
return outputs # (loss), lm_logits, (all hidden states), (all attentions)
class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
......@@ -752,10 +749,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
```
"""
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
def __init__(self, config):
super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output)
self.transformer = OpenAIGPTModel(config)
self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config)
self.apply(self.init_weights)
......@@ -768,26 +764,26 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
self.transformer.set_num_special_tokens(num_special_tokens)
self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight, predict_special_tokens=predict_special_tokens)
def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None,
def forward(self, input_ids, mc_token_ids=None, lm_labels=None, mc_labels=None, token_type_ids=None,
position_ids=None, head_mask=None):
hidden_states = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
if self.transformer.output_attentions:
all_attentions, hidden_states = hidden_states
hidden_states = hidden_states[-1]
transformer_outputs = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
losses = []
outputs = [lm_logits, mc_logits] + transformer_outputs[1:]
if mc_labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)),
mc_labels.view(-1))
outputs = [loss] + outputs
if lm_labels is not None:
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = lm_labels[..., 1:].contiguous()
loss_fct = CrossEntropyLoss(ignore_index=-1)
losses.append(loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)))
if mc_labels is not None:
loss_fct = CrossEntropyLoss()
losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
if losses:
return losses
if self.transformer.output_attentions:
return all_attentions, lm_logits, mc_logits
return lm_logits, mc_logits
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1))
outputs = [loss] + outputs
return outputs # (lm loss), (mc loss), lm logits, mc logits, (all hidden_states), (attentions)
......@@ -209,7 +209,8 @@ class TransfoXLConfig(PretrainedConfig):
init="normal",
init_range=0.01,
proj_init_std=0.01,
init_std=0.02):
init_std=0.02,
**kwargs):
"""Constructs TransfoXLConfig.
Args:
......@@ -244,6 +245,8 @@ class TransfoXLConfig(PretrainedConfig):
proj_init_std: parameters initialized by N(0, init_std)
init_std: parameters initialized by N(0, init_std)
"""
super(TransfoXLConfig, self).__init__(**kwargs)
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
......@@ -287,6 +290,7 @@ class TransfoXLConfig(PretrainedConfig):
"or the path to a pretrained model config file (str)")
class PositionalEmbedding(nn.Module):
def __init__(self, demb):
super(PositionalEmbedding, self).__init__()
......@@ -306,6 +310,7 @@ class PositionalEmbedding(nn.Module):
return pos_emb[:,None,:]
class PositionwiseFF(nn.Module):
def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
super(PositionwiseFF, self).__init__()
......@@ -341,11 +346,14 @@ class PositionwiseFF(nn.Module):
return output
class MultiHeadAttn(nn.Module):
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
pre_lnorm=False, r_r_bias=None, r_w_bias=None):
pre_lnorm=False, r_r_bias=None, r_w_bias=None, output_attentions=False):
super(MultiHeadAttn, self).__init__()
self.output_attentions = output_attentions
self.n_head = n_head
self.d_model = d_model
self.d_head = d_head
......@@ -371,7 +379,7 @@ class MultiHeadAttn(nn.Module):
self.r_r_bias = r_r_bias
self.r_w_bias = r_w_bias
def forward(self, h, attn_mask=None, mems=None):
def forward(self, h, attn_mask=None, mems=None, head_mask=None):
##### multihead attention
# [hlen x bsz x n_head x d_head]
......@@ -404,6 +412,10 @@ class MultiHeadAttn(nn.Module):
attn_prob = F.softmax(attn_score, dim=1)
attn_prob = self.dropatt(attn_prob)
# Mask heads if we want to
if head_mask is not None:
attn_prob = attn_prob * head_mask
# [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
attn_vec = attn_vec.contiguous().view(
......@@ -415,19 +427,23 @@ class MultiHeadAttn(nn.Module):
if self.pre_lnorm:
##### residual connection
output = h + attn_out
outputs = [h + attn_out]
else:
##### residual connection + layer normalization
output = self.layer_norm(h + attn_out)
outputs = [self.layer_norm(h + attn_out)]
return output
if self.output_attentions:
outputs.append(attn_prob)
return outputs
class RelMultiHeadAttn(nn.Module):
def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False,
r_r_bias=None, r_w_bias=None):
r_r_bias=None, r_w_bias=None, output_attentions=False):
super(RelMultiHeadAttn, self).__init__()
self.output_attentions = output_attentions
self.n_head = n_head
self.d_model = d_model
self.d_head = d_head
......@@ -506,7 +522,7 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)
def forward(self, w, r, attn_mask=None, mems=None):
def forward(self, w, r, attn_mask=None, mems=None, head_mask=None):
qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)
if mems is not None:
......@@ -561,6 +577,10 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
attn_prob = F.softmax(attn_score, dim=1)
attn_prob = self.dropatt(attn_prob)
# Mask heads if we want to
if head_mask is not None:
attn_prob = attn_prob * head_mask
#### compute attention vector
attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))
......@@ -574,18 +594,21 @@ class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
if self.pre_lnorm:
##### residual connection
output = w + attn_out
outputs = [w + attn_out]
else:
##### residual connection + layer normalization
output = self.layer_norm(w + attn_out)
outputs = [self.layer_norm(w + attn_out)]
return output
if self.output_attentions:
outputs.append(attn_prob)
return outputs
class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
def __init__(self, *args, **kwargs):
super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs)
def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None):
def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None, head_mask=None):
# r_emb: [klen, n_head, d_head], used for term B
# r_w_bias: [n_head, d_head], used for term C
# r_bias: [klen, n_head], used for term D
......@@ -646,6 +669,9 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
attn_prob = F.softmax(attn_score, dim=1)
attn_prob = self.dropatt(attn_prob)
if head_mask is not None:
attn_prob = attn_prob * head_mask
#### compute attention vector
attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))
......@@ -659,12 +685,17 @@ class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
if self.pre_lnorm:
##### residual connection
output = w + attn_out
outputs = [w + attn_out]
else:
##### residual connection + layer normalization
output = self.layer_norm(w + attn_out)
outputs = [self.layer_norm(w + attn_out)]
if self.output_attentions:
outputs.append(attn_prob)
return outputs
return output
class DecoderLayer(nn.Module):
def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
......@@ -674,13 +705,15 @@ class DecoderLayer(nn.Module):
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, dec_attn_mask=None, mems=None):
def forward(self, dec_inp, dec_attn_mask=None, mems=None, head_mask=None):
output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask,
mems=mems)
output = self.pos_ff(output)
attn_outputs = self.dec_attn(dec_inp, attn_mask=dec_attn_mask,
mems=mems, head_mask=head_mask)
ff_output = self.pos_ff(attn_outputs[0])
return output
outputs = [ff_output] + attn_outputs[1:]
return outputs
class RelLearnableDecoderLayer(nn.Module):
def __init__(self, n_head, d_model, d_head, d_inner, dropout,
......@@ -692,14 +725,16 @@ class RelLearnableDecoderLayer(nn.Module):
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None, head_mask=None):
output = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
attn_outputs = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
attn_mask=dec_attn_mask,
mems=mems)
output = self.pos_ff(output)
mems=mems, head_mask=head_mask)
ff_output = self.pos_ff(attn_outputs[0])
return output
outputs = [ff_output] + attn_outputs[1:]
return outputs
class RelPartialLearnableDecoderLayer(nn.Module):
def __init__(self, n_head, d_model, d_head, d_inner, dropout,
......@@ -711,14 +746,17 @@ class RelPartialLearnableDecoderLayer(nn.Module):
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))
def forward(self, dec_inp, r, dec_attn_mask=None, mems=None):
def forward(self, dec_inp, r, dec_attn_mask=None, mems=None, head_mask=None):
output = self.dec_attn(dec_inp, r,
attn_outputs = self.dec_attn(dec_inp, r,
attn_mask=dec_attn_mask,
mems=mems)
output = self.pos_ff(output)
mems=mems, head_mask=head_mask)
ff_output = self.pos_ff(attn_outputs[0])
outputs = [ff_output] + attn_outputs[1:]
return outputs
return output
class AdaptiveEmbedding(nn.Module):
......@@ -791,13 +829,9 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
"""
config_class = TransfoXLConfig
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_transfo_xl
base_model_prefix = "transformer"
def __init__(self, *inputs, **kwargs):
super(TransfoXLPreTrainedModel, self).__init__(*inputs, **kwargs)
def _init_weight(self, weight):
if self.config.init == 'uniform':
nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)
......@@ -894,6 +928,9 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
"""
def __init__(self, config):
super(TransfoXLModel, self).__init__(config)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.n_token = config.n_token
self.d_embed = config.d_embed
......@@ -928,7 +965,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len,
dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
r_w_bias=None if config.untie_r else self.r_w_bias,
r_r_bias=None if config.untie_r else self.r_r_bias)
r_r_bias=None if config.untie_r else self.r_r_bias,
output_attentions=self.output_attentions)
)
elif config.attn_type == 1: # learnable embeddings
for i in range(config.n_layer):
......@@ -938,7 +976,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
tgt_len=config.tgt_len, ext_len=config.ext_len, mem_len=config.mem_len,
dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
r_w_bias=None if config.untie_r else self.r_w_bias,
r_r_bias=None if config.untie_r else self.r_r_bias)
r_r_bias=None if config.untie_r else self.r_r_bias,
output_attentions=self.output_attentions)
)
elif config.attn_type in [2, 3]: # absolute embeddings
for i in range(config.n_layer):
......@@ -947,7 +986,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
config.n_head, config.d_model, config.d_head, config.d_inner, config.dropout,
dropatt=config.dropatt, pre_lnorm=config.pre_lnorm,
r_w_bias=None if config.untie_r else self.r_w_bias,
r_r_bias=None if config.untie_r else self.r_r_bias)
r_r_bias=None if config.untie_r else self.r_r_bias,
output_attentions=self.output_attentions)
)
self.same_length = config.same_length
......@@ -965,17 +1005,21 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
elif self.attn_type == 3: # absolute deeper SA
self.r_emb = nn.Parameter(torch.Tensor(
self.n_layer, self.max_klen, self.n_head, self.d_head))
self.apply(self.init_weights)
def backward_compatible(self):
self.sample_softmax = -1
def reset_length(self, tgt_len, ext_len, mem_len):
self.tgt_len = tgt_len
self.mem_len = mem_len
self.ext_len = ext_len
def _prune_heads(self, heads):
logger.info("Head pruning is not implemented for Transformer-XL model")
pass
def init_mems(self, data):
if self.mem_len > 0:
mems = []
......@@ -1012,9 +1056,24 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
return new_mems
def _forward(self, dec_inp, mems=None):
def _forward(self, dec_inp, mems=None, head_mask=None):
qlen, bsz = dec_inp.size()
# Prepare head mask if needed
# 1.0 in head_mask indicates we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
# and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed (fp16 compatibility)
else:
head_mask = [None] * self.n_layer
word_emb = self.word_emb(dec_inp)
mlen = mems[0].size(0) if mems is not None else 0
......@@ -1033,6 +1092,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]
hids = []
attentions = []
if self.attn_type == 0: # default
pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
dtype=word_emb.dtype)
......@@ -1046,7 +1106,11 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
for i, layer in enumerate(self.layers):
hids.append(core_out)
mems_i = None if mems is None else mems[i]
core_out = layer(core_out, pos_emb, dec_attn_mask=dec_attn_mask, mems=mems_i)
layer_outputs = layer(core_out, pos_emb, dec_attn_mask=dec_attn_mask,
mems=mems_i, head_mask=head_mask[i])
core_out = layer_outputs[0]
if self.output_attentions:
attentions.append(layer_outputs[1])
elif self.attn_type == 1: # learnable
core_out = self.drop(word_emb)
for i, layer in enumerate(self.layers):
......@@ -1058,8 +1122,12 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
r_emb, r_bias = self.r_emb[i], self.r_bias[i]
mems_i = None if mems is None else mems[i]
core_out = layer(core_out, r_emb, self.r_w_bias[i],
r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
layer_outputs = layer(core_out, r_emb, self.r_w_bias[i],
r_bias, dec_attn_mask=dec_attn_mask,
mems=mems_i, head_mask=head_mask[i])
core_out = layer_outputs[0]
if self.output_attentions:
attentions.append(layer_outputs[1])
elif self.attn_type == 2: # absolute
pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
dtype=word_emb.dtype)
......@@ -1074,8 +1142,11 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
mems_i = None if mems is None else mems[i]
if mems_i is not None and i == 0:
mems_i += pos_emb[:mlen]
core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
mems=mems_i)
layer_outputs = layer(core_out, dec_attn_mask=dec_attn_mask,
mems=mems_i, head_mask=head_mask[i])
core_out = layer_outputs[0]
if self.output_attentions:
attentions.append(layer_outputs[1])
elif self.attn_type == 3:
core_out = self.drop(word_emb)
......@@ -1093,16 +1164,30 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
mems_i += cur_emb.view(mlen, 1, -1)
core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)
core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
mems=mems_i)
layer_outputs = layer(core_out, dec_attn_mask=dec_attn_mask,
mems=mems_i, head_mask=head_mask[i])
core_out = layer_outputs[0]
if self.output_attentions:
attentions.append(layer_outputs[1])
core_out = self.drop(core_out)
new_mems = self._update_mems(hids, mems, mlen, qlen)
return core_out, new_mems
def forward(self, input_ids, mems=None):
# We transpose back here to shape [bsz, len, hidden_dim]
outputs = [core_out.transpose(0, 1).contiguous(), new_mems]
if self.output_hidden_states:
# Add last layer and transpose to library standard shape [bsz, len, hidden_dim]
hids.append(core_out)
hids = list(t.transpose(0, 1).contiguous() for t in hids)
outputs.append(hids)
if self.output_attentions:
# Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
attentions = list(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
outputs.append(attentions)
return outputs # last hidden state, new_mems, (all hidden states), (all attentions)
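The transposes above convert Transformer-XL's internal time-major tensors into the batch-first shapes the library exposes. A quick shape check with dummy tensors; the sizes below are made up for illustration only:
```python
import torch

qlen, klen, bsz, n_head, d_model = 5, 8, 2, 4, 16

core_out = torch.randn(qlen, bsz, d_model)          # internal layout of the last hidden state
attn_prob = torch.randn(qlen, klen, bsz, n_head)    # internal layout of one layer's attentions

print(core_out.transpose(0, 1).contiguous().shape)       # torch.Size([2, 5, 16]) -> [bsz, len, hidden_dim]
print(attn_prob.permute(2, 3, 0, 1).contiguous().shape)  # torch.Size([2, 4, 5, 8]) -> [bsz, n_heads, qlen, klen]
```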
def forward(self, input_ids, mems=None, head_mask=None):
""" Params:
input_ids :: [bsz, len]
mems :: optional mems from previous forward passes (or init_mems)
......@@ -1122,11 +1207,9 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
if mems is None:
mems = self.init_mems(input_ids)
last_hidden, new_mems = self._forward(input_ids, mems=mems)
outputs = self._forward(input_ids, mems=mems, head_mask=head_mask)
# We transpose back here to shape [bsz, len, hidden_dim]
last_hidden = last_hidden.transpose(0, 1).contiguous()
return (last_hidden, new_mems)
return outputs # last hidden state, new_mems, (all hidden states), (all attentions)
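A minimal sketch of the head-mask convention handled in `_forward` above: a 1-D mask of shape `[n_head]` (or a per-layer 2-D mask of shape `[n_layer, n_head]`) is reshaped so it broadcasts against attention probabilities laid out as `[qlen, klen, bsz, n_head]`; 1.0 keeps a head, 0.0 silences it. All sizes below are made up:
```python
import torch

n_layer, n_head, qlen, klen, bsz = 3, 4, 5, 8, 2

head_mask = torch.tensor([0., 1., 1., 1.])                  # [n_head]: silence head 0 in every layer
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
head_mask = head_mask.expand(n_layer, -1, -1, -1, -1)       # [n_layer, 1, 1, 1, n_head]

attn_prob = torch.rand(qlen, klen, bsz, n_head)             # one layer's attention probabilities
masked = attn_prob * head_mask[0]                           # broadcasts over qlen, klen and bsz
print(masked[..., 0].abs().max())                           # tensor(0.) -- head 0 contributes nothing
```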
class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
......@@ -1218,7 +1301,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
def init_mems(self, data):
return self.transformer.init_mems(data)
def forward(self, input_ids, labels=None, mems=None):
def forward(self, input_ids, labels=None, mems=None, head_mask=None):
""" Params:
input_ids :: [bsz, len]
labels :: [bsz, len]
......@@ -1235,19 +1318,26 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
bsz = input_ids.size(0)
tgt_len = input_ids.size(1)
last_hidden, new_mems = self.transformer(input_ids, mems)
transformer_outputs = self.transformer(input_ids, mems, head_mask)
last_hidden = transformer_outputs[0]
pred_hid = last_hidden[:, -tgt_len:]
outputs = transformer_outputs[1:]
if self.sample_softmax > 0 and self.training:
assert self.config.tie_weight
logit = sample_logits(self.transformer.word_emb, self.out_layer.bias, labels, pred_hid, self.sampler)
softmax_output = -F.log_softmax(logit, -1)[:, :, 0]
outputs = [softmax_output] + outputs
if labels is not None:
# TODO: This is not implemented
raise NotImplementedError
else:
softmax_output = self.crit(pred_hid.view(-1, pred_hid.size(-1)), labels)
if labels is None:
softmax_output = softmax_output.view(bsz, tgt_len, -1)
outputs = [softmax_output] + outputs
else:
softmax_output = softmax_output.view(bsz, tgt_len)
outputs = [softmax_output, None] + outputs
# We transpose back
return (softmax_output, new_mems)
return outputs # (loss), logits or None if labels is not None (speed up adaptive softmax), new_mems, (all hidden states), (all attentions)
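The branchy output assembly above is easy to misread, so here is a toy restatement (not the real model) of the two layouts TransfoXLLMHeadModel returns: without `labels` the first element is the per-token log-probabilities, with `labels` it is the per-token loss and the logits slot is `None` to spare the adaptive softmax:
```python
import torch

def assemble_lm_outputs(softmax_output, other_outputs, labels=None):
    """Mirrors only the list layout of the forward() above, nothing else."""
    if labels is None:
        return [softmax_output] + other_outputs       # [log_probs (bsz, tgt_len, n_token), new_mems, ...]
    return [softmax_output, None] + other_outputs     # [loss (bsz, tgt_len), None, new_mems, ...]

new_mems = [torch.zeros(3, 2, 8)]                     # stand-in memory list
eval_out = assemble_lm_outputs(torch.zeros(2, 7, 100), [new_mems])
train_out = assemble_lm_outputs(torch.zeros(2, 7), [new_mems], labels=torch.zeros(2, 7, dtype=torch.long))
assert eval_out[0].dim() == 3 and train_out[1] is None
```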
......@@ -73,6 +73,7 @@ class XLMConfig(PretrainedConfig):
def __init__(self,
vocab_size_or_config_json_file,
causal=True,
d_model=1024,
n_layer=24,
n_head=16,
......@@ -145,6 +146,7 @@ class XLMConfig(PretrainedConfig):
self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int):
self.n_token = vocab_size_or_config_json_file
self.causal = causal
self.d_model = d_model
self.n_layer = n_layer
self.n_head = n_head
......@@ -396,7 +398,6 @@ class XLMPreTrainedModel(PreTrainedModel):
"""
config_class = XLMConfig
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
load_tf_weights = None
base_model_prefix = "xlm"
......@@ -429,7 +430,7 @@ class XLMModel(XLMPreTrainedModel):
'hidden_dim', 'dropout', 'attention_dropout', 'asm',
'asm_cutoffs', 'asm_div_value']
def __init__(self, params, output_attentions=False, keep_multihead_output=False): #, dico, is_encoder, with_output):
def __init__(self, params, output_attentions=False, output_hidden_states=False): #, dico, is_encoder, with_output):
""" XLM model from: "Cross-lingual Language Model Pretraining" by Guillaume Lample, Alexis Conneau
Paper: https://arxiv.org/abs/1901.07291
Original code: https://github.com/facebookresearch/XLM
......@@ -483,11 +484,13 @@ class XLMModel(XLMPreTrainedModel):
"""
super(XLMModel, self).__init__(params)
self.output_attentions = output_attentions
self.output_hidden_states = output_hidden_states
# encoder / decoder, output layer
# self.is_encoder = is_encoder
# self.is_decoder = not is_encoder
# self.with_output = with_output
self.causal = params.causal
# dictionary / languages
self.n_langs = params.n_langs
......@@ -536,63 +539,45 @@ class XLMModel(XLMPreTrainedModel):
self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, dropout=self.dropout, gelu_activation=params.gelu_activation))
self.layer_norm2.append(nn.LayerNorm(self.dim, eps=1e-12))
# output layer
# if self.with_output:
# self.pred_layer = PredLayer(params)
# if params.share_inout_emb:
# self.pred_layer.proj.weight = self.embeddings.weight
# def forward(self, mode, **kwargs):
# """
# Forward function with different forward modes.
# ### Small hack to handle PyTorch distributed.
# """
# if mode == 'fwd':
# return self.fwd(**kwargs)
# elif mode == 'predict':
# return self.predict(**kwargs)
# else:
# raise Exception("Unknown mode: %s" % mode)
def forward(self, x, lengths, causal, src_enc=None, src_len=None, positions=None, langs=None, cache=None):
def forward(self, x, lengths, positions=None, langs=None, cache=None, head_mask=None): # src_enc=None, src_len=None,
"""
Inputs:
`x` LongTensor(slen, bs), containing word indices
`x` LongTensor(bs, slen), containing word indices
`lengths` LongTensor(bs), containing the length of each sentence
`causal` Boolean, if True, the attention is only done over previous hidden states
`positions` LongTensor(slen, bs), containing word positions
`langs` LongTensor(slen, bs), containing language IDs
`positions` LongTensor(bs, slen), containing word positions
`langs` LongTensor(bs, slen), containing language IDs
"""
# lengths = (x != self.pad_index).float().sum(dim=1)
# mask = x != self.pad_index
# check inputs
slen, bs = x.size()
bs, slen = x.size()
assert lengths.size(0) == bs
assert lengths.max().item() <= slen
x = x.transpose(0, 1) # batch size as dimension 0
assert (src_enc is None) == (src_len is None)
if src_enc is not None:
assert self.is_decoder
assert src_enc.size(0) == bs
# x = x.transpose(0, 1) # batch size as dimension 0
# assert (src_enc is None) == (src_len is None)
# if src_enc is not None:
# assert self.is_decoder
# assert src_enc.size(0) == bs
# generate masks
mask, attn_mask = get_masks(slen, lengths, causal)
if self.is_decoder and src_enc is not None:
src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
mask, attn_mask = get_masks(slen, lengths, self.causal)
# if self.is_decoder and src_enc is not None:
# src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
# positions
if positions is None:
positions = x.new(slen).long()
positions = torch.arange(slen, out=positions).unsqueeze(0)
else:
assert positions.size() == (slen, bs)
positions = positions.transpose(0, 1)
assert positions.size() == (bs, slen) # (slen, bs)
# positions = positions.transpose(0, 1)
# langs
if langs is not None:
assert langs.size() == (slen, bs)
langs = langs.transpose(0, 1)
assert langs.size() == (bs, slen) # (slen, bs)
# langs = langs.transpose(0, 1)
# do not recompute cached elements
if cache is not None:
......@@ -614,620 +599,50 @@ class XLMModel(XLMPreTrainedModel):
tensor *= mask.unsqueeze(-1).to(tensor.dtype)
# transformer layers
hidden_states = []
attentions = []
for i in range(self.n_layers):
if self.output_hidden_states:
hidden_states.append(tensor)
# self attention
attn = self.attentions[i](tensor, attn_mask, cache=cache)
attn_outputs = self.attentions[i](tensor, attn_mask, cache=cache, head_mask=head_mask[i])
attn = attn_outputs[0]
if self.output_attentions:
attentions.append(attn_outputs[1])
attn = F.dropout(attn, p=self.dropout, training=self.training)
tensor = tensor + attn
tensor = self.layer_norm1[i](tensor)
# encoder attention (for decoder only)
if self.is_decoder and src_enc is not None:
attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache)
attn = F.dropout(attn, p=self.dropout, training=self.training)
tensor = tensor + attn
tensor = self.layer_norm15[i](tensor)
# if self.is_decoder and src_enc is not None:
# attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache)
# attn = F.dropout(attn, p=self.dropout, training=self.training)
# tensor = tensor + attn
# tensor = self.layer_norm15[i](tensor)
# FFN
tensor = tensor + self.ffns[i](tensor)
tensor = self.layer_norm2[i](tensor)
tensor *= mask.unsqueeze(-1).to(tensor.dtype)
# Add last hidden state
if self.output_hidden_states:
hidden_states.append(tensor)
# update cache length
if cache is not None:
cache['slen'] += tensor.size(1)
# move back sequence length to dimension 0
tensor = tensor.transpose(0, 1)
return tensor
def predict(self, tensor, pred_mask, y, get_scores):
"""
Given the last hidden state, compute word scores and/or the loss.
`pred_mask` is a ByteTensor of shape (slen, bs), filled with 1 when
we need to predict a word
`y` is a LongTensor of shape (pred_mask.sum(),)
`get_scores` is a boolean specifying whether we need to return scores
"""
masked_tensor = tensor[pred_mask.unsqueeze(-1).expand_as(tensor)].view(-1, self.dim)
scores, loss = self.pred_layer(masked_tensor, y, get_scores)
return scores, loss
def generate(self, src_enc, src_len, tgt_lang_id, max_len=200, sample_temperature=None):
"""
Decode a sentence given initial start.
`x`:
- LongTensor(bs, slen)
<EOS> W1 W2 W3 <EOS> <PAD>
<EOS> W1 W2 W3 W4 <EOS>
`lengths`:
- LongTensor(bs) [5, 6]
`positions`:
- False, for regular "arange" positions (LM)
- True, to reset positions from the new generation (MT)
`langs`:
- must be None if the model only supports one language
- lang_id if only one language is involved (LM)
- (lang_id1, lang_id2) if two languages are involved (MT)
"""
# input batch
bs = len(src_len)
assert src_enc.size(0) == bs
# generated sentences
generated = src_len.new(max_len, bs) # upcoming output
generated.fill_(self.pad_index) # fill upcoming output with <PAD>
generated[0].fill_(self.eos_index) # we use <EOS> for <BOS> everywhere
# positions
positions = src_len.new(max_len).long()
positions = torch.arange(max_len, out=positions).unsqueeze(1).expand(max_len, bs)
# language IDs
langs = src_len.new(max_len).long().fill_(tgt_lang_id)
langs = langs.unsqueeze(1).expand(max_len, bs)
# current position / max lengths / length of generated sentences / unfinished sentences
cur_len = 1
gen_len = src_len.clone().fill_(1)
unfinished_sents = src_len.clone().fill_(1)
# cache compute states
cache = {'slen': 0}
while cur_len < max_len:
# compute word scores
tensor = self.forward(
'fwd',
x=generated[:cur_len],
lengths=gen_len,
positions=positions[:cur_len],
langs=langs[:cur_len],
causal=True,
src_enc=src_enc,
src_len=src_len,
cache=cache
)
assert tensor.size() == (1, bs, self.dim)
tensor = tensor.data[-1, :, :] # (bs, dim)
scores = self.pred_layer.get_scores(tensor) # (bs, n_words)
# select next words: sample or greedy
if sample_temperature is None:
next_words = torch.topk(scores, 1)[1].squeeze(1)
else:
next_words = torch.multinomial(F.softmax(scores / sample_temperature, dim=1), 1).squeeze(1)
assert next_words.size() == (bs,)
# update generations / lengths / finished sentences / current length
generated[cur_len] = next_words * unfinished_sents + self.pad_index * (1 - unfinished_sents)
gen_len.add_(unfinished_sents)
unfinished_sents.mul_(next_words.ne(self.eos_index).long())
cur_len = cur_len + 1
# stop when there is a </s> in each sentence, or if we exceed the maximum length
if unfinished_sents.max() == 0:
break
# add <EOS> to unfinished sentences
if cur_len == max_len:
generated[-1].masked_fill_(unfinished_sents.byte(), self.eos_index)
# sanity check
assert (generated == self.eos_index).sum() == 2 * bs
return generated[:cur_len], gen_len
def generate_beam(self, src_enc, src_len, tgt_lang_id, beam_size, length_penalty, early_stopping, max_len=200):
"""
Decode a sentence given initial start.
`x`:
- LongTensor(bs, slen)
<EOS> W1 W2 W3 <EOS> <PAD>
<EOS> W1 W2 W3 W4 <EOS>
`lengths`:
- LongTensor(bs) [5, 6]
`positions`:
- False, for regular "arange" positions (LM)
- True, to reset positions from the new generation (MT)
`langs`:
- must be None if the model only supports one language
- lang_id if only one language is involved (LM)
- (lang_id1, lang_id2) if two languages are involved (MT)
"""
# check inputs
assert src_enc.size(0) == src_len.size(0)
assert beam_size >= 1
# batch size / number of words
bs = len(src_len)
n_words = self.n_words
# tensor = tensor.transpose(0, 1)
# expand to beam size the source latent representations / source lengths
src_enc = src_enc.unsqueeze(1).expand((bs, beam_size) + src_enc.shape[1:]).contiguous().view((bs * beam_size,) + src_enc.shape[1:])
src_len = src_len.unsqueeze(1).expand(bs, beam_size).contiguous().view(-1)
# generated sentences (batch with beam current hypotheses)
generated = src_len.new(max_len, bs * beam_size) # upcoming output
generated.fill_(self.pad_index) # fill upcoming output with <PAD>
generated[0].fill_(self.eos_index) # we use <EOS> for <BOS> everywhere
# generated hypotheses
generated_hyps = [BeamHypotheses(beam_size, max_len, length_penalty, early_stopping) for _ in range(bs)]
# positions
positions = src_len.new(max_len).long()
positions = torch.arange(max_len, out=positions).unsqueeze(1).expand_as(generated)
# language IDs
langs = positions.clone().fill_(tgt_lang_id)
# scores for each sentence in the beam
beam_scores = src_enc.new(bs, beam_size).fill_(0)
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view(-1)
# current position
cur_len = 1
# cache compute states
cache = {'slen': 0}
# done sentences
done = [False for _ in range(bs)]
while cur_len < max_len:
# compute word scores
tensor = self.forward(
'fwd',
x=generated[:cur_len],
lengths=src_len.new(bs * beam_size).fill_(cur_len),
positions=positions[:cur_len],
langs=langs[:cur_len],
causal=True,
src_enc=src_enc,
src_len=src_len,
cache=cache
)
assert tensor.size() == (1, bs * beam_size, self.dim)
tensor = tensor.data[-1, :, :] # (bs * beam_size, dim)
scores = self.pred_layer.get_scores(tensor) # (bs * beam_size, n_words)
scores = F.log_softmax(scores, dim=-1) # (bs * beam_size, n_words)
assert scores.size() == (bs * beam_size, n_words)
# select next words with scores
_scores = scores + beam_scores[:, None].expand_as(scores) # (bs * beam_size, n_words)
_scores = _scores.view(bs, beam_size * n_words) # (bs, beam_size * n_words)
next_scores, next_words = torch.topk(_scores, 2 * beam_size, dim=1, largest=True, sorted=True)
assert next_scores.size() == next_words.size() == (bs, 2 * beam_size)
# next batch beam content
# list of (bs * beam_size) tuple(next hypothesis score, next word, current position in the batch)
next_batch_beam = []
# for each sentence
for sent_id in range(bs):
# if we are done with this sentence
done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item())
if done[sent_id]:
next_batch_beam.extend([(0, self.pad_index, 0)] * beam_size) # pad the batch
continue
# next sentence beam content
next_sent_beam = []
# next words for this sentence
for idx, value in zip(next_words[sent_id], next_scores[sent_id]):
# get beam and word IDs
beam_id = idx // n_words
word_id = idx % n_words
# end of sentence, or next word
if word_id == self.eos_index or cur_len + 1 == max_len:
generated_hyps[sent_id].add(generated[:cur_len, sent_id * beam_size + beam_id].clone(), value.item())
else:
next_sent_beam.append((value, word_id, sent_id * beam_size + beam_id))
# the beam for next step is full
if len(next_sent_beam) == beam_size:
break
# update next beam content
assert len(next_sent_beam) == 0 if cur_len + 1 == max_len else beam_size
if len(next_sent_beam) == 0:
next_sent_beam = [(0, self.pad_index, 0)] * beam_size # pad the batch
next_batch_beam.extend(next_sent_beam)
assert len(next_batch_beam) == beam_size * (sent_id + 1)
# sanity check / prepare next batch
assert len(next_batch_beam) == bs * beam_size
beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
beam_words = generated.new([x[1] for x in next_batch_beam])
beam_idx = src_len.new([x[2] for x in next_batch_beam])
# re-order batch and internal states
generated = generated[:, beam_idx]
generated[cur_len] = beam_words
for k in cache.keys():
if k != 'slen':
cache[k] = (cache[k][0][beam_idx], cache[k][1][beam_idx])
# update current length
cur_len = cur_len + 1
# stop when we are done with each sentence
if all(done):
break
# visualize hypotheses
# print([len(x) for x in generated_hyps], cur_len)
# globals().update( locals() );
# !import code; code.interact(local=vars())
# for ii in range(bs):
# for ss, ww in sorted(generated_hyps[ii].hyp, key=lambda x: x[0], reverse=True):
# print("%.3f " % ss + " ".join(self.dico[x] for x in ww.tolist()))
# print("")
# select the best hypotheses
tgt_len = src_len.new(bs)
best = []
for i, hypotheses in enumerate(generated_hyps):
best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
tgt_len[i] = len(best_hyp) + 1 # +1 for the <EOS> symbol
best.append(best_hyp)
# generate target batch
decoded = src_len.new(tgt_len.max().item(), bs).fill_(self.pad_index)
for i, hypo in enumerate(best):
decoded[:tgt_len[i] - 1, i] = hypo
decoded[tgt_len[i] - 1, i] = self.eos_index
# sanity check
assert (decoded == self.eos_index).sum() == 2 * bs
return decoded, tgt_len
class XLMModel(XLMPreTrainedModel):
def __init__(self, config, output_attentions=False, output_hidden_states=False):
super(XLMModel, self).__init__(config)
self.output_attentions = output_attentions
self.output_hidden_states = output_hidden_states
self.mem_len = config.mem_len
self.reuse_len = config.reuse_len
self.d_model = config.d_model
self.same_length = config.same_length
self.attn_type = config.attn_type
self.bi_data = config.bi_data
self.clamp_len = config.clamp_len
self.word_embedding = nn.Embedding(config.n_token, config.d_model)
self.mask_emb = nn.Parameter(torch.Tensor(1, 1, config.d_model))
layer = XLMLayer(config, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layer)])
self.dropout = nn.Dropout(config.dropout)
def prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
"""
for layer, heads in heads_to_prune.items():
self.layer[layer].attention.prune_heads(heads)
def get_multihead_outputs(self):
""" Gather all multi-head outputs.
Return: list (layers) of multihead module outputs with gradients
"""
return [layer.attention.self.multihead_output for layer in self.layer]
def create_mask(self, qlen, mlen):
""" create causal attention mask.
float mask where 1.0 indicates masked, 0.0 indicates not-masked.
same_length=False: same_length=True:
<mlen > < qlen > <mlen > < qlen >
^ [0 0 0 0 0 1 1 1 1] [0 0 0 0 0 1 1 1 1]
[0 0 0 0 0 0 1 1 1] [1 0 0 0 0 0 1 1 1]
qlen [0 0 0 0 0 0 0 1 1] [1 1 0 0 0 0 0 1 1]
[0 0 0 0 0 0 0 0 1] [1 1 1 0 0 0 0 0 1]
v [0 0 0 0 0 0 0 0 0] [1 1 1 1 0 0 0 0 0]
"""
attn_mask = torch.ones([qlen, qlen])
mask_up = torch.triu(attn_mask, diagonal=1)
attn_mask_pad = torch.zeros([qlen, mlen])
ret = torch.cat([attn_mask_pad, mask_up], dim=1)
if self.same_length:
mask_lo = torch.tril(attn_mask, diagonal=-1)
ret = torch.cat([ret[:, :qlen] + mask_lo, ret[:, qlen:]], dim=1)
ret = ret.to(next(self.parameters()))
return ret
def cache_mem(self, curr_out, prev_mem):
"""cache hidden states into memory."""
if self.mem_len is None or self.mem_len == 0:
return None
else:
if self.reuse_len is not None and self.reuse_len > 0:
curr_out = curr_out[:self.reuse_len]
if prev_mem is None:
new_mem = curr_out[-self.mem_len:]
else:
new_mem = torch.cat([prev_mem, curr_out], dim=0)[-self.mem_len:]
return new_mem.detach()
@staticmethod
def positional_embedding(pos_seq, inv_freq, bsz=None):
sinusoid_inp = torch.einsum('i,d->id', pos_seq, inv_freq)
pos_emb = torch.cat([torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)], dim=-1)
pos_emb = pos_emb[:, None, :]
if bsz is not None:
pos_emb = pos_emb.expand(-1, bsz, -1)
return pos_emb
def relative_positional_encoding(self, qlen, klen, bsz=None):
"""create relative positional encoding."""
freq_seq = torch.arange(0, self.d_model, 2.0, dtype=torch.float)
inv_freq = 1 / (10000 ** (freq_seq / self.d_model))
if self.attn_type == 'bi':
# beg, end = klen - 1, -qlen
beg, end = klen, -qlen
elif self.attn_type == 'uni':
# beg, end = klen - 1, -1
beg, end = klen, -1
else:
raise ValueError('Unknown `attn_type` {}.'.format(self.attn_type))
if self.bi_data:
fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.float)
bwd_pos_seq = torch.arange(-beg, -end, 1.0, dtype=torch.float)
if self.clamp_len > 0:
fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
bwd_pos_seq = bwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
if bsz is not None:
fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz//2)
bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz//2)
else:
fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq)
bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq)
pos_emb = torch.cat([fwd_pos_emb, bwd_pos_emb], dim=1)
else:
fwd_pos_seq = torch.arange(beg, end, -1.0)
if self.clamp_len > 0:
fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
pos_emb = pos_emb.to(next(self.parameters()))
return pos_emb
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None, head_mask=None):
"""
Args:
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
input_mask: [optional] float32 Tensor in shape [bsz, len], the input mask.
0 for real tokens and 1 for padding.
attention_mask: [optional] float32 Tensor, SAME FUNCTION as `input_mask`
but with 1 for real tokens and 0 for padding.
Added for easy compatibility with the XLM model (which uses this negative masking).
You can only use one of `input_mask` and `attention_mask`
mems: [optional] a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
from previous batches. The length of the list equals n_layer.
If None, no memory is used.
perm_mask: [optional] float32 Tensor in shape [bsz, len, len].
If perm_mask[k, i, j] = 0, i attend to j in batch k;
if perm_mask[k, i, j] = 1, i does not attend to j in batch k.
If None, each position attends to all the others.
target_mapping: [optional] float32 Tensor in shape [bsz, num_predict, len].
If target_mapping[k, i, j] = 1, the i-th predict in batch k is
on the j-th token.
Only used during pretraining for partial prediction.
Set to None during finetuning.
inp_q: [optional] float32 Tensor in shape [bsz, len].
1 for tokens with losses and 0 for tokens without losses.
Only used during pretraining for two-stream attention.
Set to None during finetuning.
mem_len: int, the number of tokens to cache.
reuse_len: int, the number of tokens in the current batch to be cached
and reused in the future.
bi_data: bool, whether to use bidirectional input pipeline.
Usually set to True during pretraining and False during finetuning.
clamp_len: int, clamp all relative distances larger than clamp_len.
-1 means no clamping.
same_length: bool, whether to use the same attention length for each token.
summary_type: str, "last", "first", "mean", or "attn". The method
to pool the input to get a vector representation.
"""
# the original code for XLM uses shapes [len, bsz] with the batch dimension at the end
# but we want a unified interface in the library with the batch size on the first dimension
# so we move here the first dimension (batch) to the end
inp_k = inp_k.transpose(0, 1).contiguous()
token_type_ids = token_type_ids.transpose(0, 1).contiguous() if token_type_ids is not None else None
input_mask = input_mask.transpose(0, 1).contiguous() if input_mask is not None else None
attention_mask = attention_mask.transpose(0, 1).contiguous() if attention_mask is not None else None
perm_mask = perm_mask.permute(1, 2, 0).contiguous() if perm_mask is not None else None
target_mapping = target_mapping.permute(1, 2, 0).contiguous() if target_mapping is not None else None
inp_q = inp_q.transpose(0, 1).contiguous() if inp_q is not None else None
qlen, bsz = inp_k.shape[0], inp_k.shape[1]
mlen = mems[0].shape[0] if mems is not None else 0
klen = mlen + qlen
dtype_float = next(self.parameters()).dtype
device = next(self.parameters()).device
##### Attention mask
# causal attention mask
if self.attn_type == 'uni':
attn_mask = self.create_mask(qlen, mlen)
attn_mask = attn_mask[:, :, None, None]
elif self.attn_type == 'bi':
attn_mask = None
else:
raise ValueError('Unsupported attention type: {}'.format(self.attn_type))
# data mask: input mask & perm mask
assert input_mask is None or attention_mask is None, ("You can only use one of input_mask (uses 1 for padding) "
    "or attention_mask (uses 0 for padding, added for compatibility with XLM). Please choose one.")
if input_mask is None and attention_mask is not None:
input_mask = 1.0 - attention_mask
if input_mask is not None and perm_mask is not None:
data_mask = input_mask[None] + perm_mask
elif input_mask is not None and perm_mask is None:
data_mask = input_mask[None]
elif input_mask is None and perm_mask is not None:
data_mask = perm_mask
else:
data_mask = None
if data_mask is not None:
# all mems can be attended to
mems_mask = torch.zeros([data_mask.shape[0], mlen, bsz]).to(data_mask)
data_mask = torch.cat([mems_mask, data_mask], dim=1)
if attn_mask is None:
attn_mask = data_mask[:, :, :, None]
else:
attn_mask += data_mask[:, :, :, None]
if attn_mask is not None:
attn_mask = (attn_mask > 0).to(dtype_float)
if attn_mask is not None:
non_tgt_mask = -torch.eye(qlen).to(attn_mask)
non_tgt_mask = torch.cat([torch.zeros([qlen, mlen]).to(attn_mask), non_tgt_mask], dim=-1)
non_tgt_mask = ((attn_mask + non_tgt_mask[:, :, None, None]) > 0).to(attn_mask)
else:
non_tgt_mask = None
##### Word embeddings and prepare h & g hidden states
word_emb_k = self.word_embedding(inp_k)
output_h = self.dropout(word_emb_k)
if inp_q is not None:
if target_mapping is not None:
word_emb_q = self.mask_emb.expand(target_mapping.shape[0], bsz, -1)
else:
inp_q_ext = inp_q[:, :, None]
word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
output_g = self.dropout(word_emb_q)
else:
output_g = None
##### Segment embedding
if token_type_ids is not None:
# Convert `token_type_ids` to one-hot `seg_mat`
mem_pad = torch.zeros([mlen, bsz], dtype=torch.long, device=device)
cat_ids = torch.cat([mem_pad, token_type_ids], dim=0)
# `1` indicates not in the same segment [qlen x klen x bsz]
seg_mat = (token_type_ids[:, None] != cat_ids[None, :]).long()
seg_mat = F.one_hot(seg_mat, num_classes=2).to(dtype_float)
else:
seg_mat = None
##### Positional encoding
pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
pos_emb = self.dropout(pos_emb)
##### Head mask if needed (for bertology/pruning)
# 1.0 in head_mask indicates we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [n_layer x num_heads]
# and head_mask is converted to shape [n_layer x batch x num_heads x seq_length x seq_length]
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed (fp16 compatibility)
else:
head_mask = [None] * self.config.n_layer
new_mems = []
if mems is None:
mems = [None] * len(self.layer)
hidden_states = []
attentions = []
for i, layer_module in enumerate(self.layer):
# cache new mems
new_mems.append(self.cache_mem(output_h, mems[i]))
# Save hidden_states
if output_g is None:
hidden_states.append(output_h)
else:
hidden_states.append((output_h, output_g))
output_h, output_g = layer_module(output_h, output_g,
attn_mask_h=non_tgt_mask, attn_mask_g=attn_mask,
r=pos_emb, seg_mat=seg_mat,
mems=mems[i], target_mapping=target_mapping,
head_mask=head_mask)
# Save last hidden_state
if output_g is None:
hidden_states.append(output_h)
else:
hidden_states.append((output_h, output_g))
# Select the right output and add dropout
output = self.dropout(output_g if output_g is not None else output_h)
# We transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
output = output.permute(1, 0, 2).contiguous()
if output_g is None:
hidden_states = [hs.permute(1, 0, 2).contiguous() for hs in hidden_states]
else:
hidden_states = [h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs]
# Build the list of outputs
outputs = [output, new_mems]
if self.output_attentions:
outputs.append(attentions)
outputs = [tensor]
if self.output_hidden_states:
outputs.append(hidden_states)
return outputs
if self.output_attentions:
outputs.append(attentions)
return outputs # outputs, (hidden_states), (attentions)
class XLMPredLayer(nn.Module):
......@@ -1275,8 +690,11 @@ class XLMPredLayer(nn.Module):
return self.proj.log_prob(x) if self.asm else self.proj(x)
class XLMLMHeadModel(XLMPreTrainedModel):
"""XLM model ("XLM: Generalized Autoregressive Pretraining for Language Understanding").
class XLMWithLMHeadModel(XLMPreTrainedModel):
""" XLM model from: "Cross-lingual Language Model Pretraining" by Guillaume Lample, Alexis Conneau
Paper: https://arxiv.org/abs/1901.07291
Original code: https://github.com/facebookresearch/XLM
Params:
`config`: a XLMConfig class instance with the configuration to build a new model
......@@ -1285,36 +703,29 @@ class XLMLMHeadModel(XLMPreTrainedModel):
This can be used to compute head importance metrics. Default: False
Inputs:
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
attention_mask: [optional] float32 Tensor in shape [bsz, len], the input mask.
0 for real tokens and 1 for padding.
mems: [optional] a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
from previous batches. The length of the list equals n_layer.
If None, no memory is used.
perm_mask: [optional] float32 Tensor in shape [bsz, len, len].
If perm_mask[k, i, j] = 0, i attend to j in batch k;
if perm_mask[k, i, j] = 1, i does not attend to j in batch k.
If None, each position attends to all the others.
target_mapping: [optional] float32 Tensor in shape [bsz, num_predict, len].
If target_mapping[k, i, j] = 1, the i-th predict in batch k is
on the j-th token.
Only used during pretraining for partial prediction.
Set to None during finetuning.
inp_q: [optional] float32 Tensor in shape [bsz, len].
1 for tokens with losses and 0 for tokens without losses.
Only used during pretraining for two-stream attention.
Set to None during finetuning.
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
`run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see XLM paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
`head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
It's a mask to be used to nullify some heads of the transformer: 1.0 => head is kept (not masked), 0.0 => head is masked.
Outputs: Tuple of (encoded_layers, pooled_output)
`encoded_layers`: controlled by `output_all_encoded_layers` argument:
- `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
of each attention block (i.e. 12 full sequences for XLM-base, 24 for XLM-large), each
encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, d_model],
encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
- `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
to the last attention block of shape [batch_size, sequence_length, d_model],
`pooled_output`: a torch.FloatTensor of size [batch_size, d_model] which is the output of a
to the last attention block of shape [batch_size, sequence_length, hidden_size],
`pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
classifier pretrained on top of the hidden state associated to the first character of the
input (`CLS`) to train on the Next-Sentence task (see XLM's paper).
......@@ -1325,8 +736,8 @@ class XLMLMHeadModel(XLMPreTrainedModel):
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = modeling.XLMConfig(vocab_size_or_config_json_file=32000, d_model=768,
n_layer=12, num_attention_heads=12, intermediate_size=3072)
config = modeling.XLMConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = modeling.XLMModel(config=config)
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
......@@ -1341,9 +752,7 @@ class XLMLMHeadModel(XLMPreTrainedModel):
self.same_length = config.same_length
self.transformer = XLMModel(config, output_attentions=output_attentions, output_hidden_states=output_hidden_states)
self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)
# Tie weights
self.pred_layer = XLMPredLayer(config)
self.apply(self.init_weights)
self.tie_weights()
......@@ -1351,10 +760,9 @@ class XLMLMHeadModel(XLMPreTrainedModel):
def tie_weights(self):
""" Make sure we are sharing the embeddings
"""
self.lm_loss.weight = self.transformer.word_embedding.weight
self.pred_layer.proj.weight = self.transformer.embeddings.weight
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
def forward(self, x, lengths, positions=None, langs=None, cache=None,
labels=None, head_mask=None):
"""
Args:
......@@ -1382,11 +790,10 @@ class XLMLMHeadModel(XLMPreTrainedModel):
summary_type: str, "last", "first", "mean", or "attn". The method
to pool the input to get a vector representation.
"""
transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
mems, perm_mask, target_mapping, inp_q, head_mask)
transformer_outputs = self.transformer(x, lengths, positions=positions, langs=langs, cache=cache, head_mask=head_mask)
output = transformer_outputs[0]
logits = self.lm_loss(output)
logits = self.pred_layer(output, labels)
outputs = transformer_outputs[1:] # Keep new_mems and attention/hidden states if they are here
......
......@@ -198,7 +198,7 @@ class XLNetConfig(PretrainedConfig):
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self,
vocab_size_or_config_json_file,
vocab_size_or_config_json_file=32000,
d_model=1024,
n_layer=24,
n_head=16,
......@@ -221,7 +221,12 @@ class XLNetConfig(PretrainedConfig):
bi_data=False,
clamp_len=-1,
same_length=False,
finetuning_task=None):
finetuning_task=None,
num_labels=2,
summary_type="last",
use_proj=True,
**kwargs):
"""Constructs XLNetConfig.
Args:
......@@ -265,6 +270,8 @@ class XLNetConfig(PretrainedConfig):
same_length: bool, whether to use the same attention length for each token.
finetuning_task: name of the glue task on which the model was fine-tuned if any
"""
super(XLNetConfig, self).__init__(**kwargs)
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
......@@ -297,7 +304,11 @@ class XLNetConfig(PretrainedConfig):
self.bi_data = bi_data
self.clamp_len = clamp_len
self.same_length = same_length
self.finetuning_task = finetuning_task
self.num_labels = num_labels
self.summary_type = summary_type
self.use_proj = use_proj
else:
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")
......@@ -323,9 +334,10 @@ except ImportError:
return self.weight * x + self.bias
class XLNetRelativeAttention(nn.Module):
def __init__(self, config, output_attentions=False):
def __init__(self, config):
super(XLNetRelativeAttention, self).__init__()
self.output_attentions = output_attentions
self.output_attentions = config.output_attentions
if config.d_model % config.n_head != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
......@@ -533,10 +545,9 @@ class XLNetFeedForward(nn.Module):
return output
class XLNetLayer(nn.Module):
def __init__(self, config, output_attentions=False, ):
def __init__(self, config):
super(XLNetLayer, self).__init__()
self.output_attentions = output_attentions
self.rel_attn = XLNetRelativeAttention(config, output_attentions=output_attentions)
self.rel_attn = XLNetRelativeAttention(config)
self.ff = XLNetFeedForward(config)
self.dropout = nn.Dropout(config.dropout)
......@@ -562,7 +573,6 @@ class XLNetPreTrainedModel(PreTrainedModel):
"""
config_class = XLNetConfig
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_xlnet
base_model_prefix = "transformer"
......@@ -589,10 +599,10 @@ class XLNetPreTrainedModel(PreTrainedModel):
class XLNetModel(XLNetPreTrainedModel):
def __init__(self, config, output_attentions=False, output_hidden_states=False):
def __init__(self, config):
super(XLNetModel, self).__init__(config)
self.output_attentions = output_attentions
self.output_hidden_states = output_hidden_states
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.mem_len = config.mem_len
self.reuse_len = config.reuse_len
......@@ -601,25 +611,17 @@ class XLNetModel(XLNetPreTrainedModel):
self.attn_type = config.attn_type
self.bi_data = config.bi_data
self.clamp_len = config.clamp_len
self.n_layer = config.n_layer
self.word_embedding = nn.Embedding(config.n_token, config.d_model)
self.mask_emb = nn.Parameter(torch.Tensor(1, 1, config.d_model))
layer = XLNetLayer(config, output_attentions=output_attentions)
layer = XLNetLayer(config)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.n_layer)])
self.dropout = nn.Dropout(config.dropout)
def prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
"""
for layer, heads in heads_to_prune.items():
self.layer[layer].attention.prune_heads(heads)
def get_multihead_outputs(self):
""" Gather all multi-head outputs.
Return: list (layers) of multihead module outputs with gradients
"""
return [layer.attention.self.multihead_output for layer in self.layer]
def _prune_heads(self, heads_to_prune):
logger.info("Head pruning is not implemented for XLNet")
pass
def create_mask(self, qlen, mlen):
""" create causal attention mask.
......@@ -708,11 +710,11 @@ class XLNetModel(XLNetPreTrainedModel):
pos_emb = pos_emb.to(next(self.parameters()))
return pos_emb
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None, head_mask=None):
"""
Args:
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
input_ids: int32 Tensor in shape [bsz, len], the input token IDs.
token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
input_mask: [optional] float32 Tensor in shape [bsz, len], the input mask.
0 for real tokens and 1 for padding.
......@@ -751,7 +753,7 @@ class XLNetModel(XLNetPreTrainedModel):
# the original code for XLNet uses shapes [len, bsz] with the batch dimension at the end
# but we want a unified interface in the library with the batch size on the first dimension
# so we move here the first dimension (batch) to the end
inp_k = inp_k.transpose(0, 1).contiguous()
input_ids = input_ids.transpose(0, 1).contiguous()
token_type_ids = token_type_ids.transpose(0, 1).contiguous() if token_type_ids is not None else None
input_mask = input_mask.transpose(0, 1).contiguous() if input_mask is not None else None
attention_mask = attention_mask.transpose(0, 1).contiguous() if attention_mask is not None else None
......@@ -759,7 +761,7 @@ class XLNetModel(XLNetPreTrainedModel):
target_mapping = target_mapping.permute(1, 2, 0).contiguous() if target_mapping is not None else None
inp_q = inp_q.transpose(0, 1).contiguous() if inp_q is not None else None
qlen, bsz = inp_k.shape[0], inp_k.shape[1]
qlen, bsz = input_ids.shape[0], input_ids.shape[1]
mlen = mems[0].shape[0] if mems is not None else 0
klen = mlen + qlen
......@@ -810,7 +812,7 @@ class XLNetModel(XLNetPreTrainedModel):
non_tgt_mask = None
##### Word embeddings and prepare h & g hidden states
word_emb_k = self.word_embedding(inp_k)
word_emb_k = self.word_embedding(input_ids)
output_h = self.dropout(word_emb_k)
if inp_q is not None:
if target_mapping is not None:
......@@ -838,20 +840,20 @@ class XLNetModel(XLNetPreTrainedModel):
pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
pos_emb = self.dropout(pos_emb)
##### Head mask if needed (for bertology/pruning)
# Prepare head mask if needed
# 1.0 in head_mask indicates we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [n_layer x num_heads]
# and head_mask is converted to shape [n_layer x batch x num_heads x seq_length x seq_length]
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer)
# and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head]
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1)
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed (fp16 compatibility)
else:
head_mask = [None] * self.config.n_layer
head_mask = [None] * self.n_layer
new_mems = []
if mems is None:
......@@ -870,7 +872,7 @@ class XLNetModel(XLNetPreTrainedModel):
head_mask=head_mask[i])
output_h, output_g = outputs[:2]
if self.output_attentions:
attentions.append(outputs[2:])
attentions.append(outputs[2])
# Add last hidden state
if self.output_hidden_states:
......@@ -887,6 +889,7 @@ class XLNetModel(XLNetPreTrainedModel):
hidden_states = [hs.permute(1, 0, 2).contiguous() for hs in hidden_states]
outputs.append(hidden_states)
if self.output_attentions:
attentions = list(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
outputs.append(attentions)
return outputs # outputs, new_mems, (hidden_states), (attentions)
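As a usage sketch for the `perm_mask` / `target_mapping` semantics documented in the forward() docstring above (perm_mask[k, i, j] = 1 blocks position i from attending to j; target_mapping[k, i, j] = 1 places the i-th prediction on token j), here is how the masks for "predict the last token" would look; sizes are made up and `model` is assumed to be an `XLNetModel`:
```python
import torch

bsz, seq_len = 2, 6

perm_mask = torch.zeros(bsz, seq_len, seq_len)
perm_mask[:, :, -1] = 1.0           # no position may attend to the last token (it is the target)

target_mapping = torch.zeros(bsz, 1, seq_len)
target_mapping[:, 0, -1] = 1.0      # the single prediction is placed on the last position

# outputs = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping)
```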
......@@ -902,7 +905,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
This can be used to compute head importance metrics. Default: False
Inputs:
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
input_ids: int32 Tensor in shape [bsz, len], the input token IDs.
token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
input_mask: [optional] float32 Tensor in shape [bsz, len], the input mask.
0 for real tokens and 1 for padding.
......@@ -953,16 +956,12 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, output_attentions=False, output_hidden_states=False):
def __init__(self, config):
super(XLNetLMHeadModel, self).__init__(config)
self.output_attentions = output_attentions
self.output_hidden_states = output_hidden_states
self.attn_type = config.attn_type
self.same_length = config.same_length
self.transformer = XLNetModel(config, output_attentions=output_attentions,
output_hidden_states=output_hidden_states)
self.transformer = XLNetModel(config)
self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)
# Tie weights
......@@ -975,12 +974,12 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
"""
self.lm_loss.weight = self.transformer.word_embedding.weight
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
labels=None, head_mask=None):
"""
Args:
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
input_ids: int32 Tensor in shape [bsz, len], the input token IDs.
token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
input_mask: float32 Tensor in shape [bsz, len], the input mask.
0 for real tokens and 1 for padding.
......@@ -1008,7 +1007,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
summary_type: str, "last", "first", "mean", or "attn". The method
to pool the input to get a vector representation.
"""
transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask,
mems, perm_mask, target_mapping, inp_q, head_mask)
logits = self.lm_loss(transformer_outputs[0])
......@@ -1025,14 +1024,14 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
return outputs # return (loss), logits, (mems), (hidden states), (attentions)
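The `tie_weights` method above shares one parameter between the input embedding and the output projection. A minimal illustration of what that sharing means, using plain `nn` modules with made-up sizes:
```python
import torch.nn as nn

n_token, d_model = 100, 16
word_embedding = nn.Embedding(n_token, d_model)
lm_loss = nn.Linear(d_model, n_token, bias=True)

lm_loss.weight = word_embedding.weight             # same Parameter object, as in tie_weights()
assert lm_loss.weight.data_ptr() == word_embedding.weight.data_ptr()
```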
class XLNetSequenceSummary(nn.Module):
def __init__(self, config, summary_type="last", use_proj=True):
def __init__(self, config):
super(XLNetSequenceSummary, self).__init__()
self.summary_type = summary_type
if use_proj:
self.summary_type = config.summary_type
if config.use_proj:
self.summary = nn.Linear(config.d_model, config.d_model)
else:
self.summary = None
if summary_type == 'attn':
if config.summary_type == 'attn':
# We should use a standard multi-head attention module with absolute positional embedding for that.
# Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
# We can probably just use the multi-head attention module of PyTorch >=1.1.0
......@@ -1069,7 +1068,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
to pool the input to get a vector representation. Default: last
Inputs:
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
input_ids: int32 Tensor in shape [bsz, len], the input token IDs.
token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
input_mask: float32 Tensor in shape [bsz, len], the input mask.
0 for real tokens and 1 for padding.
......@@ -1121,30 +1120,21 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, summary_type="last", use_proj=True, num_labels=2,
output_attentions=False, output_hidden_states=False):
def __init__(self, config):
super(XLNetForSequenceClassification, self).__init__(config)
self.output_attentions = output_attentions
self.output_hidden_states = output_hidden_states
self.attn_type = config.attn_type
self.same_length = config.same_length
self.summary_type = summary_type
self.num_labels = num_labels
self.transformer = XLNetModel(config, output_attentions=output_attentions,
output_hidden_states=output_hidden_states)
self.transformer = XLNetModel(config)
self.sequence_summary = XLNetSequenceSummary(config)
self.logits_proj = nn.Linear(config.d_model, config.num_labels)
self.sequence_summary = XLNetSequenceSummary(config, summary_type=summary_type, use_proj=use_proj)
self.logits_proj = nn.Linear(config.d_model, num_labels)
self.apply(self.init_weights)
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
labels=None, head_mask=None):
"""
Args:
inp_k: int32 Tensor in shape [bsz, len], the input token IDs.
input_ids: int32 Tensor in shape [bsz, len], the input token IDs.
token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
input_mask: float32 Tensor in shape [bsz, len], the input mask.
0 for real tokens and 1 for padding.
......@@ -1169,7 +1159,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
Only used during pretraining for two-stream attention.
Set to None during finetuning.
"""
transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask,
mems, perm_mask, target_mapping, inp_q, head_mask)
output = transformer_outputs[0]
......@@ -1247,20 +1237,18 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, output_attentions=False, output_hidden_states=False):
def __init__(self, config):
super(XLNetForQuestionAnswering, self).__init__(config)
self.output_attentions = output_attentions
self.output_hidden_states = output_hidden_states
self.transformer = XLNetModel(config, output_attentions=output_attentions,
output_hidden_states=output_hidden_states)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
self.transformer = XLNetModel(config)
self.qa_outputs = nn.Linear(config.d_model, config.num_labels)
self.apply(self.init_weights)
def forward(self, inp_k, token_type_ids=None, input_mask=None, attention_mask=None,
def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
mems=None, perm_mask=None, target_mapping=None, inp_q=None,
start_positions=None, end_positions=None, head_mask=None):
transformer_outputs = self.transformer(inp_k, token_type_ids, input_mask, attention_mask,
transformer_outputs = self.transformer(input_ids, token_type_ids, input_mask, attention_mask,
mems, perm_mask, target_mapping, inp_q, head_mask)
logits = self.qa_outputs(transformer_outputs[0])
......
# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import shutil
import json
import random
import torch
def create_and_check_for_headmasking(tester, model_classes, config, inputs_dict):
for model_class in model_classes:
config.output_hidden_states = True
model = model_class(config=config)
model.eval()
head_mask = torch.zeros(tester.num_hidden_layers, tester.num_attention_heads)
        # Set requires_grad only after the tensor has been prepared, to avoid the error "leaf variable has been moved into the graph interior"
head_mask.requires_grad_(requires_grad=True)
outputs = model(**inputs_dict, head_mask=head_mask)
# Compute some gradients
output = sum(t.sum() for t in outputs[0])
output = output.sum()
output.backward()
multihead_outputs = head_mask.grad
tester.parent.assertEqual(len(multihead_outputs), tester.num_hidden_layers)
# self.parent.assertListEqual(
# list(multihead_outputs[0].size()),
# [self.batch_size, self.num_attention_heads,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertEqual(
# len(multihead_outputs[0][:, 1:(self.num_attention_heads-1), :, :].nonzero()),
# 0)
# self.parent.assertEqual(
# len(multihead_outputs[0][:, 0, :, :].nonzero()),
# self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
# self.parent.assertEqual(
# len(multihead_outputs[0][:, self.num_attention_heads-1, :, :].nonzero()),
# self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
# self.parent.assertListEqual(
# list(multihead_outputs[1].size()),
# [self.batch_size, self.num_attention_heads,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertEqual(
# len(multihead_outputs[1].nonzero()),
# multihead_outputs[1].numel())
# self.parent.assertListEqual(
# list(multihead_outputs[-1].size()),
# [self.batch_size, self.num_attention_heads,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertEqual(
# len(multihead_outputs[-1][:, 1:, :, :].nonzero()),
# 0)
# self.parent.assertEqual(
# len(multihead_outputs[-1][:, 0, :, :].nonzero()),
# self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
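For reference, a small sketch of the `head_mask` convention this shared test relies on: one row per layer, one column per head, 1.0 keeps a head and 0.0 masks it. The BERT settings below are illustrative test-sized values, not taken from this commit:

```python
import torch
from pytorch_pretrained_bert import BertConfig, BertModel

config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=32,
                    num_hidden_layers=5, num_attention_heads=4, intermediate_size=37)
model = BertModel(config)
model.eval()

head_mask = torch.ones(config.num_hidden_layers, config.num_attention_heads)
head_mask[0, 1:] = 0.0   # on layer 0, keep only the first attention head

input_ids = torch.randint(0, 99, (2, 7), dtype=torch.long)
outputs = model(input_ids, head_mask=head_mask)
```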
def create_and_check_for_head_pruning(tester, model_classes, config, inputs_dict):
for model_class in model_classes:
model = model_class(config=config)
model.eval()
heads_to_prune = {0: list(range(1, tester.num_attention_heads)),
-1: [0]}
model.prune_heads(heads_to_prune)
outputs = model(**inputs_dict)
# output = sum(t.sum() for t in outputs[0])
# output = output.sum()
# output.backward()
# multihead_outputs = bert_model.get_multihead_outputs()
# self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
# self.parent.assertListEqual(
# list(multihead_outputs[0].size()),
# [self.batch_size, 1,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertListEqual(
# list(multihead_outputs[1].size()),
# [self.batch_size, self.num_attention_heads,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertListEqual(
# list(multihead_outputs[-1].size()),
# [self.batch_size, self.num_attention_heads-1,
# self.seq_length, self.hidden_size // self.num_attention_heads])
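A sketch of the pruning interface exercised above: `prune_heads` takes a dict mapping a layer index to the list of heads to remove in that layer (the `-1` key addresses the last layer). Shown here with BERT and illustrative test-sized settings:

```python
import torch
from pytorch_pretrained_bert import BertConfig, BertModel

config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=32,
                    num_hidden_layers=5, num_attention_heads=4, intermediate_size=37)
model = BertModel(config)
model.eval()

# Remove heads 1-3 in the first layer and head 0 in the last layer,
# then run a forward pass on the pruned model.
model.prune_heads({0: [1, 2, 3], -1: [0]})
outputs = model(torch.randint(0, 99, (2, 7), dtype=torch.long))
```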
def create_and_check_for_attentions(tester, model_classes, config, inputs_dict):
for model_class in model_classes:
config.output_attentions = True
config.output_hidden_states = False
model = model_class(config)
model.eval()
outputs = model(**inputs_dict)
attentions = outputs[-1]
tester.parent.assertEqual(model.config.output_attentions, True)
tester.parent.assertEqual(model.config.output_hidden_states, False)
tester.parent.assertEqual(len(attentions), tester.num_hidden_layers)
tester.parent.assertListEqual(
list(attentions[0].shape[-3:]),
[tester.num_attention_heads,
tester.seq_length,
tester.key_len if hasattr(tester, 'key_len') else tester.seq_length])
out_len = len(outputs)
# Check attention is always last and order is fine
config.output_attentions = True
config.output_hidden_states = True
model = model_class(config)
model.eval()
outputs = model(**inputs_dict)
tester.parent.assertEqual(out_len+1, len(outputs))
tester.parent.assertEqual(model.config.output_attentions, True)
tester.parent.assertEqual(model.config.output_hidden_states, True)
attentions = outputs[-1]
tester.parent.assertEqual(len(attentions), tester.num_hidden_layers)
tester.parent.assertListEqual(
list(attentions[0].shape[-3:]),
[tester.num_attention_heads,
tester.seq_length,
tester.key_len if hasattr(tester, 'key_len') else tester.seq_length])
def create_and_check_for_hidden_states(tester, model_classes, config, inputs_dict):
for model_class in model_classes:
config.output_hidden_states = True
config.output_attentions = False
model = model_class(config)
model.eval()
outputs = model(**inputs_dict)
hidden_states = outputs[-1]
tester.parent.assertEqual(model.config.output_attentions, False)
tester.parent.assertEqual(model.config.output_hidden_states, True)
tester.parent.assertEqual(len(hidden_states), tester.num_hidden_layers + 1)
tester.parent.assertListEqual(
list(hidden_states[0].shape[-2:]),
[tester.seq_length, tester.hidden_size])
def create_and_check_commons(tester, config, inputs_dict):
create_and_check_for_attentions(tester, tester.all_model_classes, config, inputs_dict)
create_and_check_for_headmasking(tester, tester.all_model_classes, config, inputs_dict)
create_and_check_for_head_pruning(tester, tester.all_model_classes, config, inputs_dict)
create_and_check_for_hidden_states(tester, tester.all_model_classes, config, inputs_dict)
def ids_tensor(shape, vocab_size, rng=None, name=None):
"""Creates a random int32 tensor of the shape within the vocab size."""
if rng is None:
rng = random.Random()
total_dims = 1
for dim in shape:
total_dims *= dim
values = []
for _ in range(total_dims):
values.append(rng.randint(0, vocab_size - 1))
return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
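A quick usage sketch of the helper above; it relies on `ids_tensor` and the `torch` import already present in this file:

```python
# ids_tensor draws token ids uniformly from [0, vocab_size) and returns a long tensor.
batch_size, seq_length, vocab_size = 13, 7, 99
input_ids = ids_tensor([batch_size, seq_length], vocab_size)
assert input_ids.shape == (batch_size, seq_length)
assert input_ids.dtype == torch.long
```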
class ConfigTester(object):
def __init__(self, parent, config_class=None, **kwargs):
self.parent = parent
self.config_class = config_class
self.inputs_dict = kwargs
def create_and_test_config_to_json_string(self):
config = self.config_class(**self.inputs_dict)
obj = json.loads(config.to_json_string())
for key, value in self.inputs_dict.items():
self.parent.assertEqual(obj[key], value)
def create_and_test_config_to_json_file(self):
config_first = self.config_class(**self.inputs_dict)
json_file_path = "/tmp/config.json"
config_first.to_json_file(json_file_path)
config_second = self.config_class.from_json_file(json_file_path)
os.remove(json_file_path)
self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())
def run_common_tests(self):
self.create_and_test_config_to_json_string()
self.create_and_test_config_to_json_file()
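The two checks above amount to the JSON round-trip below, shown with `BertConfig` as a stand-in for any config class (it mirrors the per-model test this shared tester replaces):

```python
import os
from pytorch_pretrained_bert import BertConfig

config_first = BertConfig(vocab_size_or_config_json_file=99, hidden_size=37)
json_file_path = "/tmp/config.json"
config_first.to_json_file(json_file_path)
config_second = BertConfig.from_json_file(json_file_path)
os.remove(json_file_path)
assert config_second.to_dict() == config_first.to_dict()
```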
class GPTModelTester(object):
def __init__(self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_position_ids=True,
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
n_special=1,
n_positions=33,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
n_choices=3,
type_sequence_label_size=2,
initializer_range=0.02,
num_labels=3,
scope=None,
config_class=None,
base_model_class=None,
lm_head_model_class=None,
double_head_model_class=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_position_ids = use_position_ids
self.use_token_type_ids = use_token_type_ids
self.use_labels = use_labels
self.vocab_size = vocab_size
self.n_special = n_special
self.n_positions = n_positions
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.n_choices = n_choices
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_labels = num_labels
self.scope = scope
self.config_class = config_class
self.base_model_class = base_model_class
self.lm_head_model_class = lm_head_model_class
self.double_head_model_class = double_head_model_class
self.all_model_classes = (base_model_class, lm_head_model_class, double_head_model_class)
def prepare_config_and_inputs(self):
total_num_tokens = self.vocab_size + self.n_special
input_ids = ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_num_tokens)
position_ids = None
if self.use_position_ids:
position_ids = ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.n_positions)
token_type_ids = None
if self.use_token_type_ids:
total_voc = self.vocab_size
token_type_ids = ids_tensor([self.batch_size, self.n_choices, self.seq_length], total_voc)
mc_labels = None
lm_labels = None
mc_token_ids = None
if self.use_labels:
mc_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
lm_labels = ids_tensor([self.batch_size, self.n_choices, self.seq_length], self.num_labels)
mc_token_ids = ids_tensor([self.batch_size, self.n_choices], self.seq_length)
config = self.config_class(
vocab_size_or_config_json_file=self.vocab_size,
n_special=self.n_special,
n_positions=self.n_positions,
n_embd=self.hidden_size,
n_layer=self.num_hidden_layers,
n_head=self.num_attention_heads,
initializer_range=self.initializer_range)
return (config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids)
def create_and_check_base_model(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
model = self.base_model_class(config)
model.eval()
outputs = model(input_ids, position_ids, token_type_ids)
hidden_state = outputs[0]
self.parent.assertListEqual(
list(hidden_state.size()),
[self.batch_size, self.n_choices, self.seq_length, self.hidden_size])
def create_and_check_lm_head(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
model = self.lm_head_model_class(config)
model.eval()
outputs = model(input_ids, position_ids, token_type_ids, lm_labels)
loss, lm_logits = outputs[:2]
total_voc = self.n_special + self.vocab_size
self.parent.assertListEqual(
list(lm_logits.size()),
[self.batch_size, self.n_choices, self.seq_length, total_voc])
self.parent.assertListEqual(
list(loss.size()),
[])
def create_and_check_presents(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
for model_class in self.all_model_classes:
model = model_class(config)
model.eval()
outputs = model(input_ids)
presents = outputs[-1]
self.parent.assertEqual(self.num_hidden_layers, len(presents))
self.parent.assertListEqual(
list(presents[0].size()),
[2, self.batch_size * self.n_choices, self.num_attention_heads,
self.seq_length, self.hidden_size // self.num_attention_heads])
def create_and_check_double_heads(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
model = self.double_head_model_class(config)
model.eval()
outputs = model(input_ids, mc_token_ids, lm_labels=lm_labels, mc_labels=mc_labels,
token_type_ids=token_type_ids, position_ids=position_ids)
lm_loss, mc_loss, lm_logits, mc_logits = outputs[:4]
loss = [lm_loss, mc_loss]
total_voc = self.n_special + self.vocab_size
self.parent.assertListEqual(
list(lm_logits.size()),
[self.batch_size, self.n_choices, self.seq_length, total_voc])
self.parent.assertListEqual(
list(mc_logits.size()),
[self.batch_size, self.n_choices])
self.parent.assertListEqual(
[list(l.size()) for l in loss],
[[], []])
def create_and_check_model_from_pretrained(self):
cache_dir = "/tmp/pytorch_pretrained_bert_test/"
for model_name in list(self.base_model_class.PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = self.base_model_class.from_pretrained(model_name, cache_dir=cache_dir)
shutil.rmtree(cache_dir)
self.parent.assertIsNotNone(model)
def create_and_check_commons(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
inputs_dict = {'input_ids': input_ids}
create_and_check_commons(self, config, inputs_dict)
def run_common_tests(self, test_presents=False):
config_and_inputs = self.prepare_config_and_inputs()
self.create_and_check_base_model(*config_and_inputs)
config_and_inputs = self.prepare_config_and_inputs()
self.create_and_check_lm_head(*config_and_inputs)
config_and_inputs = self.prepare_config_and_inputs()
self.create_and_check_double_heads(*config_and_inputs)
if test_presents:
config_and_inputs = self.prepare_config_and_inputs()
self.create_and_check_presents(*config_and_inputs)
config_and_inputs = self.prepare_config_and_inputs()
self.create_and_check_commons(*config_and_inputs)
def run_slow_tests(self):
config_and_inputs = self.prepare_config_and_inputs()
self.create_and_check_model_from_pretrained(*config_and_inputs)
# coding=utf-8
# Copyright 2018 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import unittest
import json
import random
import shutil
import pytest
import torch
from pytorch_pretrained_bert import PretrainedConfig, PreTrainedModel
from pytorch_pretrained_bert.modeling import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP, PRETRAINED_CONFIG_ARCHIVE_MAP
class ModelUtilsTest(unittest.TestCase):
def test_model_from_pretrained(self):
for model_name in list(PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
config = BertConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
self.assertIsInstance(config, PretrainedConfig)
model = BertModel.from_pretrained(model_name)
self.assertIsNotNone(model)
self.assertIsInstance(model, PreTrainedModel)
config = BertConfig.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
self.assertEqual(model.config.output_attentions, True)
self.assertEqual(model.config.output_hidden_states, True)
self.assertEqual(model.config, config)
if __name__ == "__main__":
unittest.main()
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import unittest
import json
import random
import shutil
import pytest
import torch
from pytorch_pretrained_bert import (GPT2Config, GPT2Model,
GPT2LMHeadModel, GPT2DoubleHeadsModel)
from .model_tests_commons import (create_and_check_for_attentions, create_and_check_for_head_pruning,
create_and_check_for_headmasking, create_and_check_for_hidden_states,
ConfigTester, GPTModelTester)
class GPT2ModelTest(unittest.TestCase):
def test_config(self):
config_tester = ConfigTester(self, config_class=GPT2Config, n_embd=37)
config_tester.run_common_tests()
def test_model(self):
model_tester = GPTModelTester(self, config_class=GPT2Config, base_model_class=GPT2Model,
lm_head_model_class=GPT2LMHeadModel,
double_head_model_class=GPT2DoubleHeadsModel)
model_tester.run_common_tests(test_presents=True)
@pytest.mark.slow
def test_pretrained(self):
model_tester = GPTModelTester(self, config_class=GPT2Config, base_model_class=GPT2Model,
lm_head_model_class=GPT2LMHeadModel,
double_head_model_class=GPT2DoubleHeadsModel)
model_tester.run_slow_tests()
if __name__ == "__main__":
unittest.main()
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import unittest
import json
import random
import shutil
import pytest
import torch
from pytorch_pretrained_bert import (OpenAIGPTConfig, OpenAIGPTModel,
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
from .model_tests_commons import (create_and_check_for_attentions, create_and_check_for_head_pruning,
create_and_check_for_headmasking, create_and_check_for_hidden_states,
ConfigTester, GPTModelTester)
class OpenAIModelTest(unittest.TestCase):
def test_config(self):
config_tester = ConfigTester(self, config_class=OpenAIGPTConfig, n_embd=37)
config_tester.run_common_tests()
def test_model(self):
model_tester = GPTModelTester(self, config_class=OpenAIGPTConfig, base_model_class=OpenAIGPTModel,
lm_head_model_class=OpenAIGPTLMHeadModel,
double_head_model_class=OpenAIGPTDoubleHeadsModel)
model_tester.run_common_tests(test_presents=False)
@pytest.mark.slow
def test_pretrained(self):
model_tester = GPTModelTester(self, config_class=OpenAIGPTConfig, base_model_class=OpenAIGPTModel,
lm_head_model_class=OpenAIGPTLMHeadModel,
double_head_model_class=OpenAIGPTDoubleHeadsModel)
model_tester.run_slow_tests()
if __name__ == "__main__":
unittest.main()
......@@ -31,6 +31,8 @@ from pytorch_pretrained_bert import (BertConfig, BertModel, BertForMaskedLM,
BertForTokenClassification, BertForMultipleChoice)
from pytorch_pretrained_bert.modeling import PRETRAINED_MODEL_ARCHIVE_MAP
from .model_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
class BertModelTest(unittest.TestCase):
class BertModelTester(object):
......@@ -57,7 +59,11 @@ class BertModelTest(unittest.TestCase):
initializer_range=0.02,
num_labels=3,
num_choices=4,
scope=None):
scope=None,
all_model_classes = (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
BertForTokenClassification),
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
......@@ -80,25 +86,26 @@ class BertModelTest(unittest.TestCase):
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
self.all_model_classes = all_model_classes
def prepare_config_and_inputs(self):
input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = BertModelTest.ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = BertModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = BertModelTest.ids_tensor([self.batch_size], self.num_choices)
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = BertConfig(
vocab_size_or_config_json_file=self.vocab_size,
......@@ -120,136 +127,117 @@ class BertModelTest(unittest.TestCase):
list(result["loss"].size()),
[])
def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
def create_and_check_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertModel(config=config)
model.eval()
sequence_output, pooled_output = model(input_ids, token_type_ids, input_mask)
model = BertModel(config=config, output_hidden_states=True)
model.eval()
_, _, all_encoder_layers = model(input_ids, token_type_ids, input_mask)
outputs = {
result = {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
"all_encoder_layers": all_encoder_layers,
}
return outputs
def check_bert_model_output(self, result):
self.parent.assertListEqual(
[size for layer in result["all_encoder_layers"] for size in layer.size()],
[self.batch_size, self.seq_length, self.hidden_size] * (self.num_hidden_layers + 1))
self.parent.assertListEqual(
list(result["sequence_output"].size()),
[self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
def create_and_check_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForMaskedLM(config=config)
model.eval()
loss, prediction_scores = model(input_ids, token_type_ids, input_mask, token_labels)
outputs = {
result = {
"loss": loss,
"prediction_scores": prediction_scores,
}
return outputs
def check_bert_for_masked_lm_output(self, result):
self.parent.assertListEqual(
list(result["prediction_scores"].size()),
[self.batch_size, self.seq_length, self.vocab_size])
self.check_loss_output(result)
def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
def create_and_check_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForNextSentencePrediction(config=config)
model.eval()
loss, seq_relationship_score = model(input_ids, token_type_ids, input_mask, sequence_labels)
outputs = {
result = {
"loss": loss,
"seq_relationship_score": seq_relationship_score,
}
return outputs
def check_bert_for_next_sequence_prediction_output(self, result):
self.parent.assertListEqual(
list(result["seq_relationship_score"].size()),
[self.batch_size, 2])
self.check_loss_output(result)
def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
def create_and_check_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForPreTraining(config=config)
model.eval()
loss, prediction_scores, seq_relationship_score = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels)
outputs = {
result = {
"loss": loss,
"prediction_scores": prediction_scores,
"seq_relationship_score": seq_relationship_score,
}
return outputs
def check_bert_for_pretraining_output(self, result):
self.parent.assertListEqual(
list(result["prediction_scores"].size()),
[self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertListEqual(
list(result["seq_relationship_score"].size()),
[self.batch_size, 2])
self.check_loss_output(result)
def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
def create_and_check_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForQuestionAnswering(config=config)
model.eval()
loss, start_logits, end_logits = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels)
outputs = {
result = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
return outputs
def check_bert_for_question_answering_output(self, result):
self.parent.assertListEqual(
list(result["start_logits"].size()),
[self.batch_size, self.seq_length])
self.parent.assertListEqual(
list(result["end_logits"].size()),
[self.batch_size, self.seq_length])
self.check_loss_output(result)
def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForSequenceClassification(config=config, num_labels=self.num_labels)
def create_and_check_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
config.num_labels = self.num_labels
model = BertForSequenceClassification(config)
model.eval()
loss, logits = model(input_ids, token_type_ids, input_mask, sequence_labels)
outputs = {
result = {
"loss": loss,
"logits": logits,
}
return outputs
def check_bert_for_sequence_classification_output(self, result):
self.parent.assertListEqual(
list(result["logits"].size()),
[self.batch_size, self.num_labels])
self.check_loss_output(result)
def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForTokenClassification(config=config, num_labels=self.num_labels)
def create_and_check_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
config.num_labels = self.num_labels
model = BertForTokenClassification(config=config)
model.eval()
loss, logits = model(input_ids, token_type_ids, input_mask, token_labels)
outputs = {
result = {
"loss": loss,
"logits": logits,
}
return outputs
def check_bert_for_token_classification_output(self, result):
self.parent.assertListEqual(
list(result["logits"].size()),
[self.batch_size, self.seq_length, self.num_labels])
self.check_loss_output(result)
def create_bert_for_multiple_choice(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
model = BertForMultipleChoice(config=config, num_choices=self.num_choices)
def create_and_check_bert_for_multiple_choice(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
config.num_choices = self.num_choices
model = BertForMultipleChoice(config=config)
model.eval()
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
......@@ -258,148 +246,26 @@ class BertModelTest(unittest.TestCase):
multiple_choice_token_type_ids,
multiple_choice_input_mask,
choice_labels)
outputs = {
result = {
"loss": loss,
"logits": logits,
}
return outputs
def check_bert_for_multiple_choice(self, result):
self.parent.assertListEqual(
list(result["logits"].size()),
[self.batch_size, self.num_choices])
self.check_loss_output(result)
def create_and_check_bert_for_attentions(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
BertForTokenClassification):
if model_class in [BertForSequenceClassification,
BertForTokenClassification]:
model = model_class(config=config, num_labels=self.num_labels, output_attentions=True)
else:
model = model_class(config=config, output_attentions=True)
model.eval()
outputs = model(input_ids, token_type_ids, input_mask)
attentions = outputs[-1]
self.parent.assertEqual(len(attentions), self.num_hidden_layers)
self.parent.assertListEqual(
list(attentions[0].size()),
[self.batch_size, self.num_attention_heads, self.seq_length, self.seq_length])
def create_and_check_bert_for_headmasking(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
BertForTokenClassification):
if model_class in [BertForSequenceClassification,
BertForTokenClassification]:
model = model_class(config=config,
num_labels=self.num_labels)
else:
model = model_class(config=config)
model.eval()
head_mask = torch.ones(self.num_hidden_layers, self.num_attention_heads).to(input_ids.device)
head_mask[0, 1:-1] = 0.0 # Mask all but the first and last heads on the first layer
head_mask[-1, 1:] = 0.0 # Mask all but the first head on the last layer
# Set that after having prepared the tensor to avoid error (leaf variable has been moved into the graph interior)
head_mask.requires_grad_(requires_grad=True)
outputs = model(input_ids, token_type_ids, input_mask, head_mask=head_mask)
# Compute some gradients
output = sum(t.sum() for t in outputs[0])
output = output.sum()
output.backward()
multihead_outputs = head_mask.grad
self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
# self.parent.assertListEqual(
# list(multihead_outputs[0].size()),
# [self.batch_size, self.num_attention_heads,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertEqual(
# len(multihead_outputs[0][:, 1:(self.num_attention_heads-1), :, :].nonzero()),
# 0)
# self.parent.assertEqual(
# len(multihead_outputs[0][:, 0, :, :].nonzero()),
# self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
# self.parent.assertEqual(
# len(multihead_outputs[0][:, self.num_attention_heads-1, :, :].nonzero()),
# self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
# self.parent.assertListEqual(
# list(multihead_outputs[1].size()),
# [self.batch_size, self.num_attention_heads,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertEqual(
# len(multihead_outputs[1].nonzero()),
# multihead_outputs[1].numel())
# self.parent.assertListEqual(
# list(multihead_outputs[-1].size()),
# [self.batch_size, self.num_attention_heads,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertEqual(
# len(multihead_outputs[-1][:, 1:, :, :].nonzero()),
# 0)
# self.parent.assertEqual(
# len(multihead_outputs[-1][:, 0, :, :].nonzero()),
# self.batch_size * self.seq_length * self.hidden_size // self.num_attention_heads)
def create_and_check_bert_for_head_pruning(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction,
BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification,
BertForTokenClassification):
if model_class in [BertForSequenceClassification,
BertForTokenClassification]:
model = model_class(config=config,
num_labels=self.num_labels)
else:
model = model_class(config=config)
model.eval()
bert_model = model if isinstance(model, BertModel) else model.bert
heads_to_prune = {0: list(range(1, self.num_attention_heads)),
-1: [0]}
bert_model.prune_heads(heads_to_prune)
outputs = model(input_ids, token_type_ids, input_mask)
# output = sum(t.sum() for t in outputs[0])
# output = output.sum()
# output.backward()
# multihead_outputs = bert_model.get_multihead_outputs()
# self.parent.assertEqual(len(multihead_outputs), self.num_hidden_layers)
# self.parent.assertListEqual(
# list(multihead_outputs[0].size()),
# [self.batch_size, 1,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertListEqual(
# list(multihead_outputs[1].size()),
# [self.batch_size, self.num_attention_heads,
# self.seq_length, self.hidden_size // self.num_attention_heads])
# self.parent.assertListEqual(
# list(multihead_outputs[-1].size()),
# [self.batch_size, self.num_attention_heads-1,
# self.seq_length, self.hidden_size // self.num_attention_heads])
def create_and_check_bert_commons(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
create_and_check_commons(self, config, inputs_dict)
def test_default(self):
self.run_tester(BertModelTest.BertModelTester(self))
def test_config_to_json_string(self):
config = BertConfig(vocab_size_or_config_json_file=99, hidden_size=37)
obj = json.loads(config.to_json_string())
self.assertEqual(obj["vocab_size"], 99)
self.assertEqual(obj["hidden_size"], 37)
def test_config_to_json_file(self):
config_first = BertConfig(vocab_size_or_config_json_file=99, hidden_size=37)
json_file_path = "/tmp/config.json"
config_first.to_json_file(json_file_path)
config_second = BertConfig.from_json_file(json_file_path)
os.remove(json_file_path)
self.assertEqual(config_second.to_dict(), config_first.to_dict())
def test_config(self):
config_tester = ConfigTester(self, config_class=BertConfig, hidden_size=37)
config_tester.run_common_tests()
@pytest.mark.slow
def test_model_from_pretrained(self):
......@@ -411,57 +277,31 @@ class BertModelTest(unittest.TestCase):
def run_tester(self, tester):
config_and_inputs = tester.prepare_config_and_inputs()
output_result = tester.create_bert_model(*config_and_inputs)
tester.check_bert_model_output(output_result)
output_result = tester.create_bert_for_masked_lm(*config_and_inputs)
tester.check_bert_for_masked_lm_output(output_result)
tester.check_loss_output(output_result)
output_result = tester.create_bert_for_next_sequence_prediction(*config_and_inputs)
tester.check_bert_for_next_sequence_prediction_output(output_result)
tester.check_loss_output(output_result)
output_result = tester.create_bert_for_pretraining(*config_and_inputs)
tester.check_bert_for_pretraining_output(output_result)
tester.check_loss_output(output_result)
output_result = tester.create_bert_for_question_answering(*config_and_inputs)
tester.check_bert_for_question_answering_output(output_result)
tester.check_loss_output(output_result)
tester.create_and_check_bert_model(*config_and_inputs)
output_result = tester.create_bert_for_sequence_classification(*config_and_inputs)
tester.check_bert_for_sequence_classification_output(output_result)
tester.check_loss_output(output_result)
output_result = tester.create_bert_for_token_classification(*config_and_inputs)
tester.check_bert_for_token_classification_output(output_result)
tester.check_loss_output(output_result)
config_and_inputs = tester.prepare_config_and_inputs()
tester.create_and_check_bert_for_masked_lm(*config_and_inputs)
output_result = tester.create_bert_for_multiple_choice(*config_and_inputs)
tester.check_bert_for_multiple_choice(output_result)
tester.check_loss_output(output_result)
config_and_inputs = tester.prepare_config_and_inputs()
tester.create_and_check_bert_for_multiple_choice(*config_and_inputs)
tester.create_and_check_bert_for_attentions(*config_and_inputs)
tester.create_and_check_bert_for_headmasking(*config_and_inputs)
tester.create_and_check_bert_for_head_pruning(*config_and_inputs)
config_and_inputs = tester.prepare_config_and_inputs()
tester.create_and_check_bert_for_next_sequence_prediction(*config_and_inputs)
@classmethod
def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
"""Creates a random int32 tensor of the shape within the vocab size."""
if rng is None:
rng = random.Random()
config_and_inputs = tester.prepare_config_and_inputs()
tester.create_and_check_bert_for_pretraining(*config_and_inputs)
total_dims = 1
for dim in shape:
total_dims *= dim
config_and_inputs = tester.prepare_config_and_inputs()
tester.create_and_check_bert_for_question_answering(*config_and_inputs)
values = []
for _ in range(total_dims):
values.append(rng.randint(0, vocab_size - 1))
config_and_inputs = tester.prepare_config_and_inputs()
tester.create_and_check_bert_for_sequence_classification(*config_and_inputs)
return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
config_and_inputs = tester.prepare_config_and_inputs()
tester.create_and_check_bert_for_token_classification(*config_and_inputs)
config_and_inputs = tester.prepare_config_and_inputs()
tester.create_and_check_bert_commons(*config_and_inputs)
if __name__ == "__main__":
unittest.main()
......@@ -28,6 +28,8 @@ import torch
from pytorch_pretrained_bert import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
from pytorch_pretrained_bert.modeling_transfo_xl import PRETRAINED_MODEL_ARCHIVE_MAP
from .model_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
class TransfoXLModelTest(unittest.TestCase):
class TransfoXLModelTester(object):
......@@ -41,54 +43,58 @@ class TransfoXLModelTest(unittest.TestCase):
use_labels=True,
vocab_size=99,
cutoffs=[10, 50, 80],
d_model=32,
hidden_size=32,
d_embed=32,
n_head=4,
num_attention_heads=4,
d_head=8,
d_inner=128,
div_val=2,
n_layer=5,
num_hidden_layers=5,
scope=None,
seed=1):
seed=1,
all_model_classes=(TransfoXLModel, TransfoXLLMHeadModel),
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.mem_len = mem_len
self.key_len = seq_length + mem_len
self.clamp_len = clamp_len
self.is_training = is_training
self.use_labels = use_labels
self.vocab_size = vocab_size
self.cutoffs = cutoffs
self.d_model = d_model
self.hidden_size = hidden_size
self.d_embed = d_embed
self.n_head = n_head
self.num_attention_heads = num_attention_heads
self.d_head = d_head
self.d_inner = d_inner
self.div_val = div_val
self.n_layer = n_layer
self.num_hidden_layers = num_hidden_layers
self.scope = scope
self.seed = seed
self.all_model_classes = all_model_classes
def prepare_config_and_inputs(self):
input_ids_1 = TransfoXLModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_ids_2 = TransfoXLModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
lm_labels = None
if self.use_labels:
lm_labels = TransfoXLModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
config = TransfoXLConfig(
vocab_size_or_config_json_file=self.vocab_size,
mem_len=self.mem_len,
clamp_len=self.clamp_len,
cutoffs=self.cutoffs,
d_model=self.d_model,
d_model=self.hidden_size,
d_embed=self.d_embed,
n_head=self.n_head,
n_head=self.num_attention_heads,
d_head=self.d_head,
d_inner=self.d_inner,
div_val=self.div_val,
n_layer=self.n_layer)
n_layer=self.num_hidden_layers)
return (config, input_ids_1, input_ids_2, lm_labels)
......@@ -113,37 +119,34 @@ class TransfoXLModelTest(unittest.TestCase):
def check_transfo_xl_model_output(self, result):
self.parent.assertListEqual(
list(result["hidden_states_1"].size()),
[self.batch_size, self.seq_length, self.d_model])
[self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertListEqual(
list(result["hidden_states_2"].size()),
[self.batch_size, self.seq_length, self.d_model])
[self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]),
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_2"]),
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels):
model = TransfoXLLMHeadModel(config)
model.eval()
loss_1, mems_1a = model(input_ids_1, labels=lm_labels)
lm_logits_1, mems_1b = model(input_ids_1)
loss_2, mems_2a = model(input_ids_2, labels=lm_labels, mems=mems_1a)
lm_logits_2, mems_2b = model(input_ids_2, mems=mems_1b)
lm_logits_1, mems_1 = model(input_ids_1)
loss_1, _, mems_1 = model(input_ids_1, labels=lm_labels)
lm_logits_2, mems_2 = model(input_ids_2, mems=mems_1)
loss_2, _, mems_2 = model(input_ids_2, labels=lm_labels, mems=mems_1)
outputs = {
"loss_1": loss_1,
"mems_1a": mems_1a,
"mems_1": mems_1,
"lm_logits_1": lm_logits_1,
"mems_1b": mems_1b,
"loss_2": loss_2,
"mems_2a": mems_2a,
"mems_2": mems_2,
"lm_logits_2": lm_logits_2,
"mems_2b": mems_2b,
}
return outputs
......@@ -155,14 +158,8 @@ class TransfoXLModelTest(unittest.TestCase):
list(result["lm_logits_1"].size()),
[self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1a"]),
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1b"]),
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
self.parent.assertListEqual(
list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1a"]),
list(mem[~torch.isnan(mem)].sum() for mem in result["mems_1b"]))
list(list(mem.size()) for mem in result["mems_1"]),
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
self.parent.assertListEqual(
list(result["loss_2"].size()),
......@@ -171,31 +168,19 @@ class TransfoXLModelTest(unittest.TestCase):
list(result["lm_logits_2"].size()),
[self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_2a"]),
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_2b"]),
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
self.parent.assertListEqual(
list(mem[~torch.isnan(mem)].sum() for mem in result["mems_2a"]),
list(mem[~torch.isnan(mem)].sum() for mem in result["mems_2b"]))
list(list(mem.size()) for mem in result["mems_2"]),
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_and_check_transfo_xl_commons(self, config, input_ids_1, input_ids_2, lm_labels):
inputs_dict = {'input_ids': input_ids_1}
create_and_check_commons(self, config, inputs_dict)
def test_default(self):
self.run_tester(TransfoXLModelTest.TransfoXLModelTester(self))
def test_config_to_json_string(self):
config = TransfoXLConfig(vocab_size_or_config_json_file=96, d_embed=37)
obj = json.loads(config.to_json_string())
self.assertEqual(obj["n_token"], 96)
self.assertEqual(obj["d_embed"], 37)
def test_config_to_json_file(self):
config_first = TransfoXLConfig(vocab_size_or_config_json_file=96, d_embed=37)
json_file_path = "/tmp/config.json"
config_first.to_json_file(json_file_path)
config_second = TransfoXLConfig.from_json_file(json_file_path)
os.remove(json_file_path)
self.assertEqual(config_second.to_dict(), config_first.to_dict())
def test_config(self):
config_tester = ConfigTester(self, config_class=TransfoXLConfig, d_embed=37)
config_tester.run_common_tests()
@pytest.mark.slow
def test_model_from_pretrained(self):
......@@ -209,28 +194,18 @@ class TransfoXLModelTest(unittest.TestCase):
config_and_inputs = tester.prepare_config_and_inputs()
tester.set_seed()
config_and_inputs = tester.prepare_config_and_inputs()
output_result = tester.create_transfo_xl_model(*config_and_inputs)
tester.check_transfo_xl_model_output(output_result)
tester.set_seed()
config_and_inputs = tester.prepare_config_and_inputs()
output_result = tester.create_transfo_xl_lm_head(*config_and_inputs)
tester.check_transfo_xl_lm_head_output(output_result)
@classmethod
def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
"""Creates a random int32 tensor of the shape within the vocab size."""
if rng is None:
rng = random.Random()
total_dims = 1
for dim in shape:
total_dims *= dim
values = []
for _ in range(total_dims):
values.append(rng.randint(0, vocab_size - 1))
return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
tester.set_seed()
config_and_inputs = tester.prepare_config_and_inputs()
tester.create_and_check_transfo_xl_commons(*config_and_inputs)
if __name__ == "__main__":
......
......@@ -25,9 +25,11 @@ import pytest
import torch
from pytorch_pretrained_bert import (XLNetConfig, XLNetModel, XLNetLMHeadModel)
from pytorch_pretrained_bert import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering)
from pytorch_pretrained_bert.modeling_xlnet import PRETRAINED_MODEL_ARCHIVE_MAP
from .model_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
class XLNetModelTest(unittest.TestCase):
class XLNetModelTester(object):
......@@ -42,43 +44,48 @@ class XLNetModelTest(unittest.TestCase):
use_labels=True,
vocab_size=99,
cutoffs=[10, 50, 80],
d_model=32,
n_head=4,
hidden_size=32,
num_attention_heads=4,
d_inner=128,
n_layer=5,
num_hidden_layers=5,
max_position_embeddings=10,
untie_r=True,
bi_data=False,
same_length=False,
seed=1,
type_vocab_size=2):
type_vocab_size=2,
all_model_classes=(XLNetModel, XLNetLMHeadModel,
XLNetForSequenceClassification, XLNetForQuestionAnswering),
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.mem_len = mem_len
# self.key_len = seq_length + mem_len
self.clamp_len = clamp_len
self.reuse_len = reuse_len
self.is_training = is_training
self.use_labels = use_labels
self.vocab_size = vocab_size
self.cutoffs = cutoffs
self.d_model = d_model
self.n_head = n_head
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.d_inner = d_inner
self.n_layer = n_layer
self.num_hidden_layers = num_hidden_layers
self.max_position_embeddings = max_position_embeddings
self.bi_data = bi_data
self.untie_r = untie_r
self.same_length = same_length
self.seed = seed
self.type_vocab_size = type_vocab_size
self.all_model_classes = all_model_classes
def prepare_config_and_inputs(self):
input_ids_1 = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_ids_2 = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
segment_ids = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
input_ids_1 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_ids_2 = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
segment_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
input_ids_q = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
perm_mask = torch.zeros(self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float)
perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
target_mapping = torch.zeros(self.batch_size, 1, self.seq_length + 1, dtype=torch.float)
......@@ -89,8 +96,8 @@ class XLNetModelTest(unittest.TestCase):
# token_type_ids: int32 Tensor in shape [bsz, len], the input segment IDs.
# input_mask: float32 Tensor in shape [bsz, len], the input mask.
# 0 for real tokens and 1 for padding.
# mems: a list of float32 Tensors in shape [bsz, mem_len, d_model], memory
# from previous batches. The length of the list equals n_layer.
# mems: a list of float32 Tensors in shape [bsz, mem_len, hidden_size], memory
# from previous batches. The length of the list equals num_hidden_layers.
# If None, no memory is used.
# perm_mask: float32 Tensor in shape [bsz, len, len].
# If perm_mask[k, i, j] = 0, i attend to j in batch k;
......@@ -108,14 +115,14 @@ class XLNetModelTest(unittest.TestCase):
lm_labels = None
if self.use_labels:
lm_labels = XLNetModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
lm_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
config = XLNetConfig(
vocab_size_or_config_json_file=self.vocab_size,
d_model=self.d_model,
n_head=self.n_head,
d_model=self.hidden_size,
n_head=self.num_attention_heads,
d_inner=self.d_inner,
n_layer=self.n_layer,
n_layer=self.num_hidden_layers,
untie_r=self.untie_r,
max_position_embeddings=self.max_position_embeddings,
mem_len=self.mem_len,
......@@ -159,7 +166,7 @@ class XLNetModelTest(unittest.TestCase):
[self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_1"]),
[[self.seq_length, self.batch_size, self.d_model]] * self.n_layer)
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
self.parent.assertListEqual(
list(result["loss_2"].size()),
......@@ -169,24 +176,18 @@ class XLNetModelTest(unittest.TestCase):
[self.batch_size, self.seq_length, self.vocab_size])
self.parent.assertListEqual(
list(list(mem.size()) for mem in result["mems_2"]),
[[self.mem_len, self.batch_size, self.d_model]] * self.n_layer)
[[self.mem_len, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
def create_and_check_xlnet_commons(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, target_mapping, inp_q, segment_ids, lm_labels):
inputs_dict = {'input_ids': input_ids_1}
create_and_check_commons(self, config, inputs_dict)
def test_default(self):
self.run_tester(XLNetModelTest.XLNetModelTester(self))
def test_config_to_json_string(self):
config = XLNetConfig(vocab_size_or_config_json_file=96, d_model=16*4)
obj = json.loads(config.to_json_string())
self.assertEqual(obj["n_token"], 96)
self.assertEqual(obj["d_model"], 16*4)
def test_config_to_json_file(self):
config_first = XLNetConfig(vocab_size_or_config_json_file=96, d_model=16*4)
json_file_path = "/tmp/config.json"
config_first.to_json_file(json_file_path)
config_second = XLNetConfig.from_json_file(json_file_path)
os.remove(json_file_path)
self.assertEqual(config_second.to_dict(), config_first.to_dict())
def test_config(self):
config_tester = ConfigTester(self, config_class=XLNetConfig, d_inner=37)
config_tester.run_common_tests()
@pytest.mark.slow
def test_model_from_pretrained(self):
......@@ -197,27 +198,14 @@ class XLNetModelTest(unittest.TestCase):
self.assertIsNotNone(model)
def run_tester(self, tester):
config_and_inputs = tester.prepare_config_and_inputs()
tester.set_seed()
config_and_inputs = tester.prepare_config_and_inputs()
output_result = tester.create_transfo_xl_lm_head(*config_and_inputs)
tester.check_transfo_xl_lm_head_output(output_result)
@classmethod
def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
"""Creates a random int32 tensor of the shape within the vocab size."""
if rng is None:
rng = random.Random()
total_dims = 1
for dim in shape:
total_dims *= dim
values = []
for _ in range(total_dims):
values.append(rng.randint(0, vocab_size - 1))
return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()
tester.set_seed()
config_and_inputs = tester.prepare_config_and_inputs()
tester.create_and_check_xlnet_commons(*config_and_inputs)
@classmethod
def mask_tensor(cls, shape, vocab_size, rng=None, name=None):
......