Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
e891bb43
Commit
e891bb43
authored
Jul 02, 2019
by
LysandreJik
Browse files
BERT can be exported to TorchScript
parent
6ce1ee04
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
27 additions
and
27 deletions
+27
-27
pytorch_pretrained_bert/modeling.py
pytorch_pretrained_bert/modeling.py
+27
-27
No files found.
pytorch_pretrained_bert/modeling.py
View file @
e891bb43
...
...
@@ -323,7 +323,7 @@ class BertSelfAttention(nn.Module):
new_context_layer_shape
=
context_layer
.
size
()[:
-
2
]
+
(
self
.
all_head_size
,)
context_layer
=
context_layer
.
view
(
*
new_context_layer_shape
)
outputs
=
[
context_layer
,
attention_probs
]
if
self
.
output_attentions
else
[
context_layer
]
outputs
=
(
context_layer
,
attention_probs
)
if
self
.
output_attentions
else
(
context_layer
,)
return
outputs
...
...
@@ -367,7 +367,7 @@ class BertAttention(nn.Module):
def
forward
(
self
,
input_tensor
,
attention_mask
,
head_mask
=
None
):
self_outputs
=
self
.
self
(
input_tensor
,
attention_mask
,
head_mask
)
attention_output
=
self
.
output
(
self_outputs
[
0
],
input_tensor
)
outputs
=
[
attention_output
]
+
self_outputs
[
1
:]
# add attentions if we output them
outputs
=
(
attention_output
,)
+
self_outputs
[
1
:]
# add attentions if we output them
return
outputs
...
...
@@ -412,7 +412,7 @@ class BertLayer(nn.Module):
attention_output
=
attention_outputs
[
0
]
intermediate_output
=
self
.
intermediate
(
attention_output
)
layer_output
=
self
.
output
(
intermediate_output
,
attention_output
)
outputs
=
[
layer_output
]
+
attention_outputs
[
1
:]
# add attentions if we output them
outputs
=
(
layer_output
,)
+
attention_outputs
[
1
:]
# add attentions if we output them
return
outputs
...
...
@@ -424,27 +424,27 @@ class BertEncoder(nn.Module):
self
.
layer
=
nn
.
ModuleList
([
BertLayer
(
config
)
for
_
in
range
(
config
.
num_hidden_layers
)])
def
forward
(
self
,
hidden_states
,
attention_mask
,
head_mask
=
None
):
all_hidden_states
=
[]
all_attentions
=
[]
all_hidden_states
=
()
all_attentions
=
()
for
i
,
layer_module
in
enumerate
(
self
.
layer
):
if
self
.
output_hidden_states
:
all_hidden_states
.
append
(
hidden_states
)
all_hidden_states
+=
(
hidden_states
,
)
layer_outputs
=
layer_module
(
hidden_states
,
attention_mask
,
head_mask
[
i
])
hidden_states
=
layer_outputs
[
0
]
if
self
.
output_attentions
:
all_attentions
.
append
(
layer_outputs
[
1
])
all_attentions
+=
(
layer_outputs
[
1
]
,
)
# Add last layer
if
self
.
output_hidden_states
:
all_hidden_states
.
append
(
hidden_states
)
all_hidden_states
+=
(
hidden_states
,
)
outputs
=
[
hidden_states
]
outputs
=
(
hidden_states
,)
if
self
.
output_hidden_states
:
outputs
.
append
(
all_hidden_states
)
outputs
+=
(
all_hidden_states
,
)
if
self
.
output_attentions
:
outputs
.
append
(
all_attentions
)
outputs
+=
(
all_attentions
,
)
return
outputs
# outputs, (hidden states), (attentions)
...
...
@@ -490,7 +490,7 @@ class BertLMPredictionHead(nn.Module):
self
.
decoder
=
nn
.
Linear
(
bert_model_embedding_weights
.
size
(
1
),
bert_model_embedding_weights
.
size
(
0
),
bias
=
False
)
self
.
decoder
.
weight
=
bert_model_embedding_weights
self
.
decoder
.
weight
=
nn
.
Parameter
(
bert_model_embedding_weights
.
clone
())
self
.
bias
=
nn
.
Parameter
(
torch
.
zeros
(
bert_model_embedding_weights
.
size
(
0
)))
def
forward
(
self
,
hidden_states
):
...
...
@@ -666,7 +666,7 @@ class BertModel(BertPreTrainedModel):
sequence_output
=
encoder_outputs
[
0
]
pooled_output
=
self
.
pooler
(
sequence_output
)
outputs
=
[
sequence_output
,
pooled_output
]
+
encoder_outputs
[
1
:]
# add hidden_states and attentions if they are here
outputs
=
(
sequence_output
,
pooled_output
,)
+
encoder_outputs
[
1
:]
# add hidden_states and attentions if they are here
return
outputs
# sequence_output, pooled_output, (hidden_states), (attentions)
...
...
@@ -739,14 +739,14 @@ class BertForPreTraining(BertPreTrainedModel):
sequence_output
,
pooled_output
=
outputs
[:
2
]
prediction_scores
,
seq_relationship_score
=
self
.
cls
(
sequence_output
,
pooled_output
)
outputs
=
[
prediction_scores
,
seq_relationship_score
]
+
outputs
[
2
:]
# add hidden states and attention if they are here
outputs
=
(
prediction_scores
,
seq_relationship_score
,)
+
outputs
[
2
:]
# add hidden states and attention if they are here
if
masked_lm_labels
is
not
None
and
next_sentence_label
is
not
None
:
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
masked_lm_loss
=
loss_fct
(
prediction_scores
.
view
(
-
1
,
self
.
config
.
vocab_size
),
masked_lm_labels
.
view
(
-
1
))
next_sentence_loss
=
loss_fct
(
seq_relationship_score
.
view
(
-
1
,
2
),
next_sentence_label
.
view
(
-
1
))
total_loss
=
masked_lm_loss
+
next_sentence_loss
outputs
=
[
total_loss
]
+
outputs
outputs
=
(
total_loss
,)
+
outputs
return
outputs
# (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
...
...
@@ -815,11 +815,11 @@ class BertForMaskedLM(BertPreTrainedModel):
sequence_output
=
outputs
[
0
]
prediction_scores
=
self
.
cls
(
sequence_output
)
outputs
=
[
prediction_scores
]
+
outputs
[
2
:]
# Add hidden states and attention if they are here
outputs
=
(
prediction_scores
,)
+
outputs
[
2
:]
# Add hidden states and attention if they are here
if
masked_lm_labels
is
not
None
:
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
masked_lm_loss
=
loss_fct
(
prediction_scores
.
view
(
-
1
,
self
.
config
.
vocab_size
),
masked_lm_labels
.
view
(
-
1
))
outputs
=
[
masked_lm_loss
]
+
outputs
outputs
=
(
masked_lm_loss
,)
+
outputs
return
outputs
# (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
...
...
@@ -885,11 +885,11 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
seq_relationship_score
=
self
.
cls
(
pooled_output
)
outputs
=
[
seq_relationship_score
]
+
outputs
[
2
:]
# add hidden states and attention if they are here
outputs
=
(
seq_relationship_score
,)
+
outputs
[
2
:]
# add hidden states and attention if they are here
if
next_sentence_label
is
not
None
:
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
next_sentence_loss
=
loss_fct
(
seq_relationship_score
.
view
(
-
1
,
2
),
next_sentence_label
.
view
(
-
1
))
outputs
=
[
next_sentence_loss
]
+
outputs
outputs
=
(
next_sentence_loss
,)
+
outputs
return
outputs
# (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
...
...
@@ -960,7 +960,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
pooled_output
=
self
.
dropout
(
pooled_output
)
logits
=
self
.
classifier
(
pooled_output
)
outputs
=
[
logits
]
+
outputs
[
2
:]
# add hidden states and attention if they are here
outputs
=
(
logits
,)
+
outputs
[
2
:]
# add hidden states and attention if they are here
if
labels
is
not
None
:
if
self
.
num_labels
==
1
:
...
...
@@ -970,7 +970,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
else
:
loss_fct
=
CrossEntropyLoss
()
loss
=
loss_fct
(
logits
.
view
(
-
1
,
self
.
num_labels
),
labels
.
view
(
-
1
))
outputs
=
[
loss
]
+
outputs
outputs
=
(
loss
,)
+
outputs
return
outputs
# (loss), logits, (hidden_states), (attentions)
...
...
@@ -1043,12 +1043,12 @@ class BertForMultipleChoice(BertPreTrainedModel):
logits
=
self
.
classifier
(
pooled_output
)
reshaped_logits
=
logits
.
view
(
-
1
,
num_choices
)
outputs
=
[
reshaped_logits
]
+
outputs
[
2
:]
# add hidden states and attention if they are here
outputs
=
(
reshaped_logits
,)
+
outputs
[
2
:]
# add hidden states and attention if they are here
if
labels
is
not
None
:
loss_fct
=
CrossEntropyLoss
()
loss
=
loss_fct
(
reshaped_logits
,
labels
)
outputs
=
[
loss
]
+
outputs
outputs
=
(
loss
,)
+
outputs
return
outputs
# (loss), reshaped_logits, (hidden_states), (attentions)
...
...
@@ -1119,7 +1119,7 @@ class BertForTokenClassification(BertPreTrainedModel):
sequence_output
=
self
.
dropout
(
sequence_output
)
logits
=
self
.
classifier
(
sequence_output
)
outputs
=
[
logits
]
+
outputs
[
2
:]
# add hidden states and attention if they are here
outputs
=
(
logits
,)
+
outputs
[
2
:]
# add hidden states and attention if they are here
if
labels
is
not
None
:
loss_fct
=
CrossEntropyLoss
()
# Only keep active parts of the loss
...
...
@@ -1130,7 +1130,7 @@ class BertForTokenClassification(BertPreTrainedModel):
loss
=
loss_fct
(
active_logits
,
active_labels
)
else
:
loss
=
loss_fct
(
logits
.
view
(
-
1
,
self
.
num_labels
),
labels
.
view
(
-
1
))
outputs
=
[
loss
]
+
outputs
outputs
=
(
loss
,)
+
outputs
return
outputs
# (loss), logits, (hidden_states), (attentions)
...
...
@@ -1205,7 +1205,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
start_logits
=
start_logits
.
squeeze
(
-
1
)
end_logits
=
end_logits
.
squeeze
(
-
1
)
outputs
=
[
start_logits
,
end_logits
]
+
outputs
[
2
:]
outputs
=
(
start_logits
,
end_logits
,)
+
outputs
[
2
:]
if
start_positions
is
not
None
and
end_positions
is
not
None
:
# If we are on multi-GPU, split adds a dimension
if
len
(
start_positions
.
size
())
>
1
:
...
...
@@ -1221,6 +1221,6 @@ class BertForQuestionAnswering(BertPreTrainedModel):
start_loss
=
loss_fct
(
start_logits
,
start_positions
)
end_loss
=
loss_fct
(
end_logits
,
end_positions
)
total_loss
=
(
start_loss
+
end_loss
)
/
2
outputs
=
[
total_loss
]
+
outputs
outputs
=
(
total_loss
,)
+
outputs
return
outputs
# (loss), start_logits, end_logits, (hidden_states), (attentions)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment