"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "566b083eb11cd4a742d3616559484409fbdb34b5"
Commit 3b23a846 authored by thomwolf

Merge branch 'xlnet' of https://github.com/huggingface/pytorch-pretrained-BERT into xlnet

parents 8fa3a1f0 64ce4dbd
@@ -46,6 +46,7 @@ class PretrainedConfig(object):
         self.num_labels = kwargs.pop('num_labels', 2)
         self.output_attentions = kwargs.pop('output_attentions', False)
         self.output_hidden_states = kwargs.pop('output_hidden_states', False)
+        self.torchscript = kwargs.pop('torchscript', False)

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
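Note (not part of the diff): the flag is popped in PretrainedConfig.__init__, so any config built with extra keyword arguments can enable it. A minimal sketch, assuming from_pretrained forwards unknown kwargs onto the loaded config; the checkpoint name is just a placeholder:

# Hypothetical usage of the new `torchscript` flag; whether from_pretrained forwards
# the kwarg to the config is an assumption here, not something this diff shows.
from pytorch_pretrained_bert import BertConfig, BertModel

config = BertConfig.from_pretrained('bert-base-uncased', torchscript=True)
model = BertModel(config)
model.eval()  # models are normally put in eval mode before tracing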
@@ -323,7 +323,7 @@ class BertSelfAttention(nn.Module):
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         context_layer = context_layer.view(*new_context_layer_shape)

-        outputs = [context_layer, attention_probs] if self.output_attentions else [context_layer]
+        outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
         return outputs
@@ -367,7 +367,7 @@ class BertAttention(nn.Module):
     def forward(self, input_tensor, attention_mask, head_mask=None):
         self_outputs = self.self(input_tensor, attention_mask, head_mask)
         attention_output = self.output(self_outputs[0], input_tensor)
-        outputs = [attention_output] + self_outputs[1:]  # add attentions if we output them
+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
         return outputs
@@ -412,7 +412,7 @@ class BertLayer(nn.Module):
         attention_output = attention_outputs[0]
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)
-        outputs = [layer_output] + attention_outputs[1:]  # add attentions if we output them
+        outputs = (layer_output,) + attention_outputs[1:]  # add attentions if we output them
         return outputs
@@ -424,27 +424,27 @@ class BertEncoder(nn.Module):
         self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

     def forward(self, hidden_states, attention_mask, head_mask=None):
-        all_hidden_states = []
-        all_attentions = []
+        all_hidden_states = ()
+        all_attentions = ()
         for i, layer_module in enumerate(self.layer):
             if self.output_hidden_states:
-                all_hidden_states.append(hidden_states)
+                all_hidden_states = all_hidden_states + (hidden_states,)

             layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
             hidden_states = layer_outputs[0]

             if self.output_attentions:
-                all_attentions.append(layer_outputs[1])
+                all_attentions = all_attentions + (layer_outputs[1],)

         # Add last layer
         if self.output_hidden_states:
-            all_hidden_states.append(hidden_states)
+            all_hidden_states = all_hidden_states + (hidden_states,)

-        outputs = [hidden_states]
+        outputs = (hidden_states,)
         if self.output_hidden_states:
-            outputs.append(all_hidden_states)
+            outputs = outputs + (all_hidden_states,)
         if self.output_attentions:
-            outputs.append(all_attentions)
+            outputs = outputs + (all_attentions,)
         return outputs  # outputs, (hidden states), (attentions)
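Note (not part of the diff): list building and .append are replaced with tuple concatenation throughout, presumably because torch.jit.trace only accepts tensors and (nested) tuples of tensors as module outputs. A self-contained toy sketch of the pattern; the class and sizes below are illustrative, not from the repository:

import torch
import torch.nn as nn

class ToyEncoder(nn.Module):
    """Toy stand-in mirroring the tuple-accumulation pattern in BertEncoder above."""
    def __init__(self, num_layers=2, output_hidden_states=False):
        super(ToyEncoder, self).__init__()
        self.output_hidden_states = output_hidden_states
        self.layer = nn.ModuleList([nn.Linear(4, 4) for _ in range(num_layers)])

    def forward(self, hidden_states):
        all_hidden_states = ()
        for layer_module in self.layer:
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            hidden_states = layer_module(hidden_states)
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)  # add last layer
        outputs = (hidden_states,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        return outputs  # last hidden state, (all hidden states)

encoder = ToyEncoder().eval()
last_hidden = encoder(torch.randn(1, 4))[0]
traced = torch.jit.trace(encoder, torch.randn(1, 4))  # works because the output is a tuple of tensors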
@@ -484,13 +484,19 @@ class BertLMPredictionHead(nn.Module):
     def __init__(self, config, bert_model_embedding_weights):
         super(BertLMPredictionHead, self).__init__()
         self.transform = BertPredictionHeadTransform(config)
+        self.torchscript = config.torchscript

         # The output weights are the same as the input embeddings, but there is
         # an output-only bias for each token.
         self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
                                  bert_model_embedding_weights.size(0),
                                  bias=False)
-        self.decoder.weight = bert_model_embedding_weights
+
+        if self.torchscript:
+            self.decoder.weight = nn.Parameter(bert_model_embedding_weights.clone())
+        else:
+            self.decoder.weight = bert_model_embedding_weights
+
         self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))

     def forward(self, hidden_states):
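Note (not part of the diff): the branch above exists because TorchScript export cannot handle two modules sharing one Parameter (the GPT-2 hunk below says so explicitly), so with torchscript=True the decoder gets an independent copy of the embedding matrix instead of a tied reference. A minimal sketch of the two behaviours:

import torch
import torch.nn as nn

vocab_size, hidden_size, torchscript = 100, 16, True
embedding = nn.Embedding(vocab_size, hidden_size)
decoder = nn.Linear(hidden_size, vocab_size, bias=False)

if torchscript:
    # Cloning breaks the sharing: safe to export, but the copies no longer stay in sync.
    decoder.weight = nn.Parameter(embedding.weight.clone())
else:
    # Standard weight tying: both modules reference the same Parameter object.
    decoder.weight = embedding.weight

print(decoder.weight.data_ptr() == embedding.weight.data_ptr())  # False when cloned, True when tied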
@@ -666,7 +672,7 @@ class BertModel(BertPreTrainedModel):
         sequence_output = encoder_outputs[0]
         pooled_output = self.pooler(sequence_output)

-        outputs = [sequence_output, pooled_output] + encoder_outputs[1:]  # add hidden_states and attentions if they are here
+        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]  # add hidden_states and attentions if they are here
         return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)
@@ -739,14 +745,14 @@ class BertForPreTraining(BertPreTrainedModel):
         sequence_output, pooled_output = outputs[:2]
         prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

-        outputs = [prediction_scores, seq_relationship_score] + outputs[2:]  # add hidden states and attention if they are here
+        outputs = (prediction_scores, seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here

         if masked_lm_labels is not None and next_sentence_label is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
             next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
             total_loss = masked_lm_loss + next_sentence_loss
-            outputs = [total_loss] + outputs
+            outputs = (total_loss,) + outputs

         return outputs  # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
@@ -815,11 +821,11 @@ class BertForMaskedLM(BertPreTrainedModel):
         sequence_output = outputs[0]
         prediction_scores = self.cls(sequence_output)

-        outputs = [prediction_scores] + outputs[2:]  # Add hidden states and attention is they are here
+        outputs = (prediction_scores,) + outputs[2:]  # Add hidden states and attention is they are here
         if masked_lm_labels is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
-            outputs = [masked_lm_loss] + outputs
+            outputs = (masked_lm_loss,) + outputs

         return outputs  # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
@@ -885,11 +891,11 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
         seq_relationship_score = self.cls(pooled_output)

-        outputs = [seq_relationship_score] + outputs[2:]  # add hidden states and attention if they are here
+        outputs = (seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here
         if next_sentence_label is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
-            outputs = [next_sentence_loss] + outputs
+            outputs = (next_sentence_loss,) + outputs

         return outputs  # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
@@ -960,7 +966,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)

-        outputs = [logits] + outputs[2:]  # add hidden states and attention if they are here
+        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

         if labels is not None:
             if self.num_labels == 1:
@@ -970,7 +976,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
             else:
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs

         return outputs  # (loss), logits, (hidden_states), (attentions)
@@ -1043,12 +1049,12 @@ class BertForMultipleChoice(BertPreTrainedModel):
         logits = self.classifier(pooled_output)
         reshaped_logits = logits.view(-1, num_choices)

-        outputs = [reshaped_logits] + outputs[2:]  # add hidden states and attention if they are here
+        outputs = (reshaped_logits,) + outputs[2:]  # add hidden states and attention if they are here

         if labels is not None:
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(reshaped_logits, labels)
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs

         return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)
@@ -1119,7 +1125,7 @@ class BertForTokenClassification(BertPreTrainedModel):
         sequence_output = self.dropout(sequence_output)
         logits = self.classifier(sequence_output)

-        outputs = [logits] + outputs[2:]  # add hidden states and attention if they are here
+        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
         if labels is not None:
             loss_fct = CrossEntropyLoss()
             # Only keep active parts of the loss
@@ -1130,7 +1136,7 @@ class BertForTokenClassification(BertPreTrainedModel):
                 loss = loss_fct(active_logits, active_labels)
             else:
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs

         return outputs  # (loss), logits, (hidden_states), (attentions)
@@ -1205,7 +1211,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
         start_logits = start_logits.squeeze(-1)
         end_logits = end_logits.squeeze(-1)

-        outputs = [start_logits, end_logits] + outputs[2:]
+        outputs = (start_logits, end_logits,) + outputs[2:]
         if start_positions is not None and end_positions is not None:
             # If we are on multi-GPU, split add a dimension
             if len(start_positions.size()) > 1:
@@ -1221,6 +1227,6 @@ class BertForQuestionAnswering(BertPreTrainedModel):
             start_loss = loss_fct(start_logits, start_positions)
             end_loss = loss_fct(end_logits, end_positions)
             total_loss = (start_loss + end_loss) / 2
-            outputs = [total_loss] + outputs
+            outputs = (total_loss,) + outputs

         return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
@@ -322,13 +322,18 @@ class GPT2LMHead(nn.Module):
         self.n_embd = config.n_embd
         self.vocab_size = config.vocab_size
         self.predict_special_tokens = config.predict_special_tokens
+        self.torchscript = config.torchscript
         embed_shape = model_embeddings_weights.shape
         self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
         self.set_embeddings_weights(model_embeddings_weights)

     def set_embeddings_weights(self, model_embeddings_weights, predict_special_tokens=True):
         self.predict_special_tokens = predict_special_tokens
-        self.decoder.weight = model_embeddings_weights  # Tied weights
+        # Export to TorchScript can't handle parameter sharing so we are cloning them.
+        if self.torchscript:
+            self.decoder.weight = nn.Parameter(model_embeddings_weights.clone())
+        else:
+            self.decoder.weight = model_embeddings_weights  # Tied weights

     def forward(self, hidden_state):
         lm_logits = self.decoder(hidden_state)
@@ -557,16 +562,16 @@ class GPT2Model(GPT2PreTrainedModel):
         output_shape = input_shape + (hidden_states.size(-1),)

-        presents = []
+        presents = ()
         all_attentions = []
-        all_hidden_states = []
+        all_hidden_states = ()
         for i, (block, layer_past) in enumerate(zip(self.h, past)):
             if self.output_hidden_states:
-                all_hidden_states.append(hidden_states.view(*output_shape))
+                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
             outputs = block(hidden_states, layer_past, head_mask[i])

             hidden_states, present = outputs[:2]
-            presents.append(present)
+            presents = presents + (present,)

             if self.output_attentions:
                 all_attentions.append(outputs[2])
@@ -576,16 +581,16 @@ class GPT2Model(GPT2PreTrainedModel):
         hidden_states = hidden_states.view(*output_shape)
         # Add last hidden state
         if self.output_hidden_states:
-            all_hidden_states.append(hidden_states)
+            all_hidden_states = all_hidden_states + (hidden_states,)

-        outputs = [hidden_states, presents]
+        outputs = (hidden_states, presents)
         if self.output_hidden_states:
-            outputs.append(all_hidden_states)
+            outputs = outputs + (all_hidden_states,)
         if self.output_attentions:
             # let the number of heads free (-1) so we can extract attention even after head pruning
             attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
-            all_attentions = list(t.view(*attention_output_shape) for t in all_attentions)
-            outputs.append(all_attentions)
+            all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
+            outputs = outputs + (all_attentions,)
         return outputs  # last hidden state, presents, (all hidden_states), (attentions)
@@ -658,7 +663,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         lm_logits = self.lm_head(hidden_states)

-        outputs = [lm_logits] + transformer_outputs[1:]
+        outputs = (lm_logits,) + transformer_outputs[1:]
         if lm_labels is not None:
             # Shift so that tokens < n predict n
             shift_logits = lm_logits[..., :-1, :].contiguous()
@@ -667,7 +672,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                             shift_labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs

         return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)
@@ -750,18 +755,18 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         lm_logits = self.lm_head(hidden_states)
         mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)

-        outputs = [lm_logits, mc_logits] + transformer_outputs[1:]
+        outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
         if mc_labels is not None:
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)),
                             mc_labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
         if lm_labels is not None:
             shift_logits = lm_logits[..., :-1, :].contiguous()
             shift_labels = lm_labels[..., 1:].contiguous()
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                             shift_labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs

         return outputs  # (lm loss), (mc loss), lm logits, mc logits, presents, (all hidden_states), (attentions)
@@ -348,14 +348,18 @@ class OpenAIGPTLMHead(nn.Module):
         self.n_embd = config.n_embd
         self.vocab_size = config.vocab_size
         self.predict_special_tokens = config.predict_special_tokens
+        self.torchscript = config.torchscript
         embed_shape = model_embeddings_weights.shape
         self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
         self.set_embeddings_weights(model_embeddings_weights)

     def set_embeddings_weights(self, model_embeddings_weights, predict_special_tokens=True):
         self.predict_special_tokens = predict_special_tokens
         embed_shape = model_embeddings_weights.shape
-        self.decoder.weight = model_embeddings_weights  # Tied weights
+        if self.torchscript:
+            self.decoder.weight = nn.Parameter(model_embeddings_weights.clone())
+        else:
+            self.decoder.weight = model_embeddings_weights  # Tied weights

     def forward(self, hidden_state):
         lm_logits = self.decoder(hidden_state)
@@ -579,26 +583,26 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         output_shape = input_shape + (hidden_states.size(-1),)

-        all_attentions = []
-        all_hidden_states = []
+        all_attentions = ()
+        all_hidden_states = ()
         for i, block in enumerate(self.h):
             if self.output_hidden_states:
-                all_hidden_states.append(hidden_states.view(*output_shape))
+                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
             outputs = block(hidden_states, head_mask[i])
             hidden_states = outputs[0]
             if self.output_attentions:
-                all_attentions.append(outputs[1])
+                all_attentions = all_attentions + (outputs[1],)

         # Add last layer
         if self.output_hidden_states:
-            all_hidden_states.append(hidden_states.view(*output_shape))
+            all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

-        outputs = [hidden_states.view(*output_shape)]
+        outputs = (hidden_states.view(*output_shape),)
         if self.output_hidden_states:
-            outputs.append(all_hidden_states)
+            outputs = outputs + (all_hidden_states,)
         if self.output_attentions:
-            outputs.append(all_attentions)
+            outputs = outputs + (all_attentions,)
         return outputs  # last hidden state, (all hidden states), (all attentions)
@@ -682,7 +686,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
         hidden_states = transformer_outputs[0]
         lm_logits = self.lm_head(hidden_states)

-        outputs = [lm_logits] + transformer_outputs[1:]
+        outputs = (lm_logits,) + transformer_outputs[1:]
         if lm_labels is not None:
             # Shift so that tokens < n predict n
             shift_logits = lm_logits[..., :-1, :].contiguous()
@@ -691,7 +695,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                             shift_labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs

         return outputs  # (loss), lm_logits, (all hidden states), (all attentions)
@@ -785,18 +789,18 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
         lm_logits = self.lm_head(hidden_states)
         mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)

-        outputs = [lm_logits, mc_logits] + transformer_outputs[1:]
+        outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
         if mc_labels is not None:
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)),
                             mc_labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
         if lm_labels is not None:
             shift_logits = lm_logits[..., :-1, :].contiguous()
             shift_labels = lm_labels[..., 1:].contiguous()
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                             shift_labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs

         return outputs  # (lm loss), (mc loss), lm logits, mc logits, (all hidden_states), (attentions)
@@ -384,7 +384,8 @@ class XLNetRelativeAttention(nn.Module):
         x = x.reshape(x_size[1], x_size[0], x_size[2], x_size[3])
         x = x[1:, ...]
         x = x.reshape(x_size[0], x_size[1] - 1, x_size[2], x_size[3])
-        x = x[:, 0:klen, :, :]
+        # x = x[:, 0:klen, :, :]
+        x = torch.index_select(x, 1, torch.arange(klen))

         return x
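Note (not part of the diff): the slice over the key length is replaced by an explicit torch.index_select, presumably because it traces more cleanly under torch.jit; the two forms pick the same sub-tensor along dimension 1:

import torch

x = torch.randn(4, 7, 2, 3)
klen = 5
sliced = x[:, 0:klen, :, :]
selected = torch.index_select(x, 1, torch.arange(klen))
assert torch.equal(sliced, selected)  # identical result, different op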
@@ -527,9 +528,9 @@ class XLNetRelativeAttention(nn.Module):
             output_h = self.post_attention(h, attn_vec)
             output_g = None

-        outputs = [output_h, output_g]
+        outputs = (output_h, output_g)
         if self.output_attentions:
-            outputs = outputs + [attn_prob]
+            outputs = outputs + (attn_prob,)
         return outputs

 class XLNetFeedForward(nn.Module):
@@ -574,7 +575,7 @@ class XLNetLayer(nn.Module):
             output_g = self.ff(output_g)
         output_h = self.ff(output_h)

-        outputs = [output_h, output_g] + outputs[2:]  # Add again attentions if there are there
+        outputs = (output_h, output_g) + outputs[2:]  # Add again attentions if there are there
         return outputs
@@ -688,7 +689,7 @@ class XLNetModel(XLNetPreTrainedModel):
     def relative_positional_encoding(self, qlen, klen, bsz=None):
         """create relative positional encoding."""
         freq_seq = torch.arange(0, self.d_model, 2.0, dtype=torch.float)
-        inv_freq = 1 / (10000 ** (freq_seq / self.d_model))
+        inv_freq = 1 / torch.pow(10000, (freq_seq / self.d_model))

         if self.attn_type == 'bi':
             # beg, end = klen - 1, -qlen
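Note (not part of the diff): likewise, the Python ** power with a tensor exponent is swapped for torch.pow, again presumably for tracing friendliness; numerically the two are the same:

import torch

d_model = 8
freq_seq = torch.arange(0, d_model, 2.0, dtype=torch.float)
old_inv_freq = 1 / (10000 ** (freq_seq / d_model))
new_inv_freq = 1 / torch.pow(10000, freq_seq / d_model)
assert torch.allclose(old_inv_freq, new_inv_freq)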
@@ -869,7 +870,7 @@ class XLNetModel(XLNetPreTrainedModel):
         else:
             head_mask = [None] * self.n_layer

-        new_mems = []
+        new_mems = ()
         if mems is None:
             mems = [None] * len(self.layer)
@@ -877,7 +878,7 @@ class XLNetModel(XLNetPreTrainedModel):
         hidden_states = []
         for i, layer_module in enumerate(self.layer):
             # cache new mems
-            new_mems.append(self.cache_mem(output_h, mems[i]))
+            new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
             if self.output_hidden_states:
                 hidden_states.append((output_h, output_g) if output_g is not None else output_h)
@@ -895,16 +896,16 @@ class XLNetModel(XLNetPreTrainedModel):
         output = self.dropout(output_g if output_g is not None else output_h)

         # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
-        outputs = [output.permute(1, 0, 2).contiguous(), new_mems]
+        outputs = (output.permute(1, 0, 2).contiguous(), new_mems)
         if self.output_hidden_states:
             if output_g is not None:
-                hidden_states = [h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs]
+                hidden_states = tuple(h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs)
             else:
-                hidden_states = [hs.permute(1, 0, 2).contiguous() for hs in hidden_states]
-            outputs.append(hidden_states)
+                hidden_states = tuple(hs.permute(1, 0, 2).contiguous() for hs in hidden_states)
+            outputs = outputs + (hidden_states,)
         if self.output_attentions:
-            attentions = list(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
-            outputs.append(attentions)
+            attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
+            outputs = outputs + (attentions,)

         return outputs  # outputs, new_mems, (hidden_states), (attentions)
@@ -974,6 +975,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         super(XLNetLMHeadModel, self).__init__(config)
         self.attn_type = config.attn_type
         self.same_length = config.same_length
+        self.torchscript = config.torchscript

         self.transformer = XLNetModel(config)
         self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)
@@ -986,7 +988,10 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
     def tie_weights(self):
         """ Make sure we are sharing the embeddings
         """
-        self.lm_loss.weight = self.transformer.word_embedding.weight
+        if self.torchscript:
+            self.lm_loss.weight = nn.Parameter(self.transformer.word_embedding.weight.clone())
+        else:
+            self.lm_loss.weight = self.transformer.word_embedding.weight

     def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None, inp_q=None,
@@ -1026,14 +1031,14 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         logits = self.lm_loss(transformer_outputs[0])

-        outputs = [logits] + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it
+        outputs = (logits,) + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it

         if labels is not None:
             # Flatten the tokens
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(logits.view(-1, logits.size(-1)),
                             labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs

         return outputs  # return (loss), logits, (mems), (hidden states), (attentions)
@@ -1061,7 +1066,7 @@ class XLNetSequenceSummary(nn.Module):
             output = hidden_states[:, 0]
         elif self.summary_type == 'mean':
             output = hidden_states.mean(dim=1)
-        elif summary_type == 'attn':
+        elif self.summary_type == 'attn':
             raise NotImplementedError

         output = self.summary(output)
@@ -1180,7 +1185,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
         output = self.sequence_summary(output)
         logits = self.logits_proj(output)

-        outputs = [logits] + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it
+        outputs = (logits,) + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it

         if labels is not None:
             if self.num_labels == 1:
@@ -1190,7 +1195,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
             else:
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs

         return outputs  # return (loss), logits, (mems), (hidden states), (attentions)
@@ -1271,7 +1276,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
         start_logits = start_logits.squeeze(-1)
         end_logits = end_logits.squeeze(-1)

-        outputs = [start_logits, end_logits] + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it
+        outputs = (start_logits, end_logits,) + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it

         if start_positions is not None and end_positions is not None:
             # If we are on multi-GPU, split add a dimension
@@ -1288,6 +1293,6 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
             start_loss = loss_fct(start_logits, start_positions)
             end_loss = loss_fct(end_logits, end_positions)
             total_loss = (start_loss + end_loss) / 2
-            outputs = [total_loss] + outputs
+            outputs = (total_loss,) + outputs

         return outputs  # return (loss), logits, (mems), (hidden states), (attentions)
@@ -31,6 +31,52 @@ def _config_zero_init(config):
             setattr(configs_no_init, key, 0.0)
     return configs_no_init

+
+def _create_and_check_torchscript_output_attentions(tester, model_classes, config, inputs_dict):
+    config.output_attentions = True
+    _create_and_check_torchscript(tester, model_classes, config, inputs_dict)
+
+
+def _create_and_check_torchscript_output_hidden_state(tester, model_classes, config, inputs_dict):
+    config.output_hidden_states = True
+    _create_and_check_torchscript(tester, model_classes, config, inputs_dict)
+
+
+def _create_and_check_torchscript(tester, model_classes, config, inputs_dict):
+    configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
+    configs_no_init.torchscript = True
+    for model_class in model_classes:
+        model = model_class(config=configs_no_init)
+        model.eval()
+        inputs = inputs_dict['input_ids']  # Let's keep only input_ids
+
+        try:
+            torch.jit.trace(model, inputs)
+        except RuntimeError:
+            tester.parent.fail("Couldn't trace module.")
+
+        try:
+            traced_gpt2 = torch.jit.trace(model, inputs)
+            torch.jit.save(traced_gpt2, "traced_model.pt")
+        except RuntimeError:
+            tester.parent.fail("Couldn't save module.")
+
+        try:
+            loaded_model = torch.jit.load("traced_model.pt")
+            os.remove("traced_model.pt")
+        except ValueError:
+            tester.parent.fail("Couldn't load module.")
+
+        model.eval()
+        loaded_model.eval()
+
+        model_params = model.parameters()
+        loaded_model_params = loaded_model.parameters()
+
+        models_equal = True
+        for p1, p2 in zip(model_params, loaded_model_params):
+            if p1.data.ne(p2.data).sum() > 0:
+                models_equal = False
+
+        tester.parent.assertTrue(models_equal)
+
+
 def _create_and_check_initialization(tester, model_classes, config, inputs_dict):
     configs_no_init = _config_zero_init(config)
     for model_class in model_classes:
@@ -41,7 +87,7 @@ def _create_and_check_initialization(tester, model_classes, config, inputs_dict)
                                        msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))


 def _create_and_check_for_headmasking(tester, model_classes, config, inputs_dict):
-    configs_no_init = _config_zero_init(config)
+    configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
     for model_class in model_classes:
         config.output_attentions = True
         config.output_hidden_states = True
@@ -157,11 +203,17 @@ def _create_and_check_for_hidden_states(tester, model_classes, config, inputs_dict)
                                 [tester.seq_length, tester.hidden_size])


-def create_and_check_commons(tester, config, inputs_dict, test_pruning=True):
+def create_and_check_commons(tester, config, inputs_dict, test_pruning=True, test_torchscript=True):
     _create_and_check_initialization(tester, tester.all_model_classes, config, inputs_dict)
     _create_and_check_for_attentions(tester, tester.all_model_classes, config, inputs_dict)
     _create_and_check_for_headmasking(tester, tester.all_model_classes, config, inputs_dict)
     _create_and_check_for_hidden_states(tester, tester.all_model_classes, config, inputs_dict)

+    if test_torchscript:
+        _create_and_check_torchscript(tester, tester.all_model_classes, config, inputs_dict)
+        _create_and_check_torchscript_output_attentions(tester, tester.all_model_classes, config, inputs_dict)
+        _create_and_check_torchscript_output_hidden_state(tester, tester.all_model_classes, config, inputs_dict)
+
     if test_pruning:
         _create_and_check_for_head_pruning(tester, tester.all_model_classes, config, inputs_dict)
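Note (not part of the diff): a condensed, standalone version of the round trip that the new _create_and_check_torchscript helper exercises (trace, save, reload, compare parameters), with a toy module standing in for the real model classes:

import os
import torch
import torch.nn as nn

model = nn.Sequential(nn.Embedding(10, 4), nn.Linear(4, 2)).eval()
input_ids = torch.randint(0, 10, (1, 6))

traced = torch.jit.trace(model, input_ids)
torch.jit.save(traced, "traced_model.pt")
loaded = torch.jit.load("traced_model.pt")
os.remove("traced_model.pt")

# Same equality check as the helper above.
for p1, p2 in zip(model.parameters(), loaded.parameters()):
    assert p1.data.ne(p2.data).sum() == 0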
@@ -173,7 +173,7 @@ class TransfoXLModelTest(unittest.TestCase):
         def create_and_check_transfo_xl_commons(self, config, input_ids_1, input_ids_2, lm_labels):
             inputs_dict = {'input_ids': input_ids_1}
-            create_and_check_commons(self, config, inputs_dict, test_pruning=False)
+            create_and_check_commons(self, config, inputs_dict, test_pruning=False, test_torchscript=False)

     def test_default(self):
         self.run_tester(TransfoXLModelTest.TransfoXLModelTester(self))