Commit e891bb43 authored by LysandreJik

BERT can be exported to TorchScript

parent 6ce1ee04
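
With every forward() returning plain tuples of tensors (the change below), the model becomes traceable. A minimal export sketch, not part of this commit, assuming a library build where these tuple outputs are in place; the model name, input shapes, and file name are arbitrary choices, the import path depends on the release (older ones ship as pytorch_pretrained_bert / pytorch_transformers), and on much later releases you may also need torchscript=True or return_dict=False when loading:

import torch
from transformers import BertModel  # import path varies by release; an assumption here

model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

# Dummy input used only to record the graph; the shape is an arbitrary choice.
input_ids = torch.randint(0, model.config.vocab_size, (1, 8), dtype=torch.long)

traced = torch.jit.trace(model, (input_ids,))
traced.save("traced_bert.pt")

# The traced module keeps the tuple structure: (sequence_output, pooled_output, ...)
sequence_output, pooled_output = traced(input_ids)[:2]
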
@@ -323,7 +323,7 @@ class BertSelfAttention(nn.Module):
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         context_layer = context_layer.view(*new_context_layer_shape)
-        outputs = [context_layer, attention_probs] if self.output_attentions else [context_layer]
+        outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
         return outputs
@@ -367,7 +367,7 @@ class BertAttention(nn.Module):
     def forward(self, input_tensor, attention_mask, head_mask=None):
         self_outputs = self.self(input_tensor, attention_mask, head_mask)
         attention_output = self.output(self_outputs[0], input_tensor)
-        outputs = [attention_output] + self_outputs[1:] # add attentions if we output them
+        outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
         return outputs
@@ -412,7 +412,7 @@ class BertLayer(nn.Module):
         attention_output = attention_outputs[0]
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)
-        outputs = [layer_output] + attention_outputs[1:] # add attentions if we output them
+        outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
         return outputs
@@ -424,27 +424,27 @@ class BertEncoder(nn.Module):
         self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
     def forward(self, hidden_states, attention_mask, head_mask=None):
-        all_hidden_states = []
-        all_attentions = []
+        all_hidden_states = ()
+        all_attentions = ()
         for i, layer_module in enumerate(self.layer):
             if self.output_hidden_states:
-                all_hidden_states.append(hidden_states)
+                all_hidden_states += (hidden_states,)
             layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
             hidden_states = layer_outputs[0]
             if self.output_attentions:
-                all_attentions.append(layer_outputs[1])
+                all_attentions += (layer_outputs[1],)
         # Add last layer
         if self.output_hidden_states:
-            all_hidden_states.append(hidden_states)
-        outputs = [hidden_states]
+            all_hidden_states += (hidden_states,)
+        outputs = (hidden_states,)
         if self.output_hidden_states:
-            outputs.append(all_hidden_states)
+            outputs += (all_hidden_states,)
         if self.output_attentions:
-            outputs.append(all_attentions)
+            outputs += (all_attentions,)
         return outputs # outputs, (hidden states), (attentions)
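
A hedged aside on the BertEncoder hunk above: at the time, torch.jit.trace only accepted tensors and (possibly nested) tuples of tensors as module outputs, which is why the list accumulators and list returns become tuples. A toy illustration of the traceable pattern, using a hypothetical module rather than anything from this diff:

import torch
import torch.nn as nn

class TupleOutputs(nn.Module):
    def forward(self, x):
        all_states = ()                 # accumulate into a tuple, as BertEncoder now does
        for scale in (1.0, 2.0, 3.0):
            x = x * scale
            all_states += (x,)
        return (x,) + (all_states,)     # nested tuple of tensors: fine for tracing

traced = torch.jit.trace(TupleOutputs(), torch.randn(2, 4))
last_state, all_states = traced(torch.randn(2, 4))
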
@@ -490,7 +490,7 @@ class BertLMPredictionHead(nn.Module):
         self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
                                  bert_model_embedding_weights.size(0),
                                  bias=False)
-        self.decoder.weight = bert_model_embedding_weights
+        self.decoder.weight = nn.Parameter(bert_model_embedding_weights.clone())
         self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))
     def forward(self, hidden_states):
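
A hedged note on the BertLMPredictionHead hunk above: cloning the embedding matrix gives the decoder its own nn.Parameter, so the traced graph no longer contains a weight shared between two modules (a limitation of TorchScript export at the time); the trade-off is that the output projection is no longer tied to the input embeddings. A minimal, self-contained illustration with toy sizes, not from this commit:

import torch
import torch.nn as nn

embedding = nn.Embedding(10, 4)
decoder = nn.Linear(4, 10, bias=False)

# Tied: both modules reference the same Parameter object.
decoder.weight = embedding.weight
assert decoder.weight.data_ptr() == embedding.weight.data_ptr()

# Untied, as in the hunk above: an independent copy of the same values.
decoder.weight = nn.Parameter(embedding.weight.clone())
assert decoder.weight.data_ptr() != embedding.weight.data_ptr()
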
@@ -666,7 +666,7 @@ class BertModel(BertPreTrainedModel):
         sequence_output = encoder_outputs[0]
         pooled_output = self.pooler(sequence_output)
-        outputs = [sequence_output, pooled_output] + encoder_outputs[1:] # add hidden_states and attentions if they are here
+        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here
         return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
@@ -739,14 +739,14 @@ class BertForPreTraining(BertPreTrainedModel):
         sequence_output, pooled_output = outputs[:2]
         prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
-        outputs = [prediction_scores, seq_relationship_score] + outputs[2:] # add hidden states and attention if they are here
+        outputs = (prediction_scores, seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
         if masked_lm_labels is not None and next_sentence_label is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
             next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
             total_loss = masked_lm_loss + next_sentence_loss
-            outputs = [total_loss] + outputs
+            outputs = (total_loss,) + outputs
         return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
@@ -815,11 +815,11 @@ class BertForMaskedLM(BertPreTrainedModel):
         sequence_output = outputs[0]
         prediction_scores = self.cls(sequence_output)
-        outputs = [prediction_scores] + outputs[2:] # Add hidden states and attention is they are here
+        outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention is they are here
         if masked_lm_labels is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
-            outputs = [masked_lm_loss] + outputs
+            outputs = (masked_lm_loss,) + outputs
         return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
@@ -885,11 +885,11 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
         seq_relationship_score = self.cls(pooled_output)
-        outputs = [seq_relationship_score] + outputs[2:] # add hidden states and attention if they are here
+        outputs = (seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
         if next_sentence_label is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
-            outputs = [next_sentence_loss] + outputs
+            outputs = (next_sentence_loss,) + outputs
         return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
@@ -960,7 +960,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)
-        outputs = [logits] + outputs[2:] # add hidden states and attention if they are here
+        outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
         if labels is not None:
             if self.num_labels == 1:
@@ -970,7 +970,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
             else:
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
         return outputs # (loss), logits, (hidden_states), (attentions)
@@ -1043,12 +1043,12 @@ class BertForMultipleChoice(BertPreTrainedModel):
         logits = self.classifier(pooled_output)
         reshaped_logits = logits.view(-1, num_choices)
-        outputs = [reshaped_logits] + outputs[2:] # add hidden states and attention if they are here
+        outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
         if labels is not None:
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(reshaped_logits, labels)
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
         return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
@@ -1119,7 +1119,7 @@ class BertForTokenClassification(BertPreTrainedModel):
         sequence_output = self.dropout(sequence_output)
         logits = self.classifier(sequence_output)
-        outputs = [logits] + outputs[2:] # add hidden states and attention if they are here
+        outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
         if labels is not None:
             loss_fct = CrossEntropyLoss()
             # Only keep active parts of the loss
@@ -1130,7 +1130,7 @@ class BertForTokenClassification(BertPreTrainedModel):
                 loss = loss_fct(active_logits, active_labels)
             else:
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
         return outputs # (loss), logits, (hidden_states), (attentions)
@@ -1205,7 +1205,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
         start_logits = start_logits.squeeze(-1)
         end_logits = end_logits.squeeze(-1)
-        outputs = [start_logits, end_logits] + outputs[2:]
+        outputs = (start_logits, end_logits,) + outputs[2:]
         if start_positions is not None and end_positions is not None:
             # If we are on multi-GPU, split add a dimension
             if len(start_positions.size()) > 1:
@@ -1221,6 +1221,6 @@ class BertForQuestionAnswering(BertPreTrainedModel):
             start_loss = loss_fct(start_logits, start_positions)
             end_loss = loss_fct(end_logits, end_positions)
             total_loss = (start_loss + end_loss) / 2
-            outputs = [total_loss] + outputs
+            outputs = (total_loss,) + outputs
         return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)