chenpangpang / transformers

Commit e891bb43, authored Jul 02, 2019 by LysandreJik

BERT can be exported to TorchScript

Parent: 6ce1ee04
Showing 1 changed file with 27 additions and 27 deletions.

pytorch_pretrained_bert/modeling.py  (+27, -27)
@@ -323,7 +323,7 @@ class BertSelfAttention(nn.Module):
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         context_layer = context_layer.view(*new_context_layer_shape)
-        outputs = [context_layer, attention_probs] if self.output_attentions else [context_layer]
+        outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
         return outputs
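The same pattern repeats through the rest of the file: every forward method that used to build its return value as a Python list now builds a tuple. This is what the commit title refers to, since torch.jit.trace only accepts tensors and (nested) tuples of tensors as traced outputs; a list output is rejected (or, on recent PyTorch versions, requires strict=False). A minimal sketch of that constraint, independent of this repository and purely illustrative:

import torch
import torch.nn as nn

class ReturnsTuple(nn.Module):
    def forward(self, x):
        return (x, x + 1)   # tuple of tensors: accepted by torch.jit.trace

class ReturnsList(nn.Module):
    def forward(self, x):
        return [x, x + 1]   # Python list: rejected by torch.jit.trace

x = torch.randn(2, 3)
traced = torch.jit.trace(ReturnsTuple(), x)    # works
# torch.jit.trace(ReturnsList(), x)            # raises (or needs strict=False on newer PyTorch)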
@@ -367,7 +367,7 @@ class BertAttention(nn.Module):
     def forward(self, input_tensor, attention_mask, head_mask=None):
         self_outputs = self.self(input_tensor, attention_mask, head_mask)
         attention_output = self.output(self_outputs[0], input_tensor)
-        outputs = [attention_output] + self_outputs[1:]  # add attentions if we output them
+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
         return outputs
@@ -412,7 +412,7 @@ class BertLayer(nn.Module):
         attention_output = attention_outputs[0]
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)
-        outputs = [layer_output] + attention_outputs[1:]  # add attentions if we output them
+        outputs = (layer_output,) + attention_outputs[1:]  # add attentions if we output them
         return outputs
@@ -424,27 +424,27 @@ class BertEncoder(nn.Module):
         self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

     def forward(self, hidden_states, attention_mask, head_mask=None):
-        all_hidden_states = []
-        all_attentions = []
+        all_hidden_states = ()
+        all_attentions = ()
         for i, layer_module in enumerate(self.layer):
             if self.output_hidden_states:
-                all_hidden_states.append(hidden_states)
+                all_hidden_states += (hidden_states,)
             layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
             hidden_states = layer_outputs[0]
             if self.output_attentions:
-                all_attentions.append(layer_outputs[1])
+                all_attentions += (layer_outputs[1],)
         # Add last layer
         if self.output_hidden_states:
-            all_hidden_states.append(hidden_states)
+            all_hidden_states += (hidden_states,)
-        outputs = [hidden_states]
+        outputs = (hidden_states,)
         if self.output_hidden_states:
-            outputs.append(all_hidden_states)
+            outputs += (all_hidden_states,)
         if self.output_attentions:
-            outputs.append(all_attentions)
+            outputs += (all_attentions,)
         return outputs  # outputs, (hidden states), (attentions)
@@ -490,7 +490,7 @@ class BertLMPredictionHead(nn.Module):
         self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
                                  bert_model_embedding_weights.size(0),
                                  bias=False)
-        self.decoder.weight = bert_model_embedding_weights
+        self.decoder.weight = nn.Parameter(bert_model_embedding_weights.clone())
         self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))

     def forward(self, hidden_states):
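Note on the BertLMPredictionHead hunk above: previously the decoder weight was the very same Parameter object as the input embedding matrix (weight tying); the commit replaces it with an independent, cloned Parameter, presumably because the exporter could not handle one Parameter shared across modules at the time. The trade-off is that the LM head and the embeddings no longer share storage. A minimal sketch of the difference, with illustrative module names that are not from the file:

import torch
import torch.nn as nn

emb = nn.Embedding(10, 4)
decoder = nn.Linear(4, 10, bias=False)

# Tied: both modules reference the same Parameter object,
# so an update through one is visible through the other.
decoder.weight = emb.weight
assert decoder.weight is emb.weight

# Cloned (what the hunk switches to): an independent copy of the values.
decoder.weight = nn.Parameter(emb.weight.clone())
assert decoder.weight is not emb.weight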
@@ -666,7 +666,7 @@ class BertModel(BertPreTrainedModel):
         sequence_output = encoder_outputs[0]
         pooled_output = self.pooler(sequence_output)
-        outputs = [sequence_output, pooled_output] + encoder_outputs[1:]  # add hidden_states and attentions if they are here
+        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]  # add hidden_states and attentions if they are here
         return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)
@@ -739,14 +739,14 @@ class BertForPreTraining(BertPreTrainedModel):
         sequence_output, pooled_output = outputs[:2]
         prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
-        outputs = [prediction_scores, seq_relationship_score] + outputs[2:]  # add hidden states and attention if they are here
+        outputs = (prediction_scores, seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here
         if masked_lm_labels is not None and next_sentence_label is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
             next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
             total_loss = masked_lm_loss + next_sentence_loss
-            outputs = [total_loss] + outputs
+            outputs = (total_loss,) + outputs
         return outputs  # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
@@ -815,11 +815,11 @@ class BertForMaskedLM(BertPreTrainedModel):
         sequence_output = outputs[0]
         prediction_scores = self.cls(sequence_output)
-        outputs = [prediction_scores] + outputs[2:]  # Add hidden states and attention is they are here
+        outputs = (prediction_scores,) + outputs[2:]  # Add hidden states and attention is they are here
         if masked_lm_labels is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
-            outputs = [masked_lm_loss] + outputs
+            outputs = (masked_lm_loss,) + outputs
         return outputs  # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
@@ -885,11 +885,11 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
         seq_relationship_score = self.cls(pooled_output)
-        outputs = [seq_relationship_score] + outputs[2:]  # add hidden states and attention if they are here
+        outputs = (seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here
         if next_sentence_label is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
-            outputs = [next_sentence_loss] + outputs
+            outputs = (next_sentence_loss,) + outputs
         return outputs  # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
@@ -960,7 +960,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)
-        outputs = [logits] + outputs[2:]  # add hidden states and attention if they are here
+        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
         if labels is not None:
             if self.num_labels == 1:
@@ -970,7 +970,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
             else:
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
         return outputs  # (loss), logits, (hidden_states), (attentions)
@@ -1043,12 +1043,12 @@ class BertForMultipleChoice(BertPreTrainedModel):
         logits = self.classifier(pooled_output)
         reshaped_logits = logits.view(-1, num_choices)
-        outputs = [reshaped_logits] + outputs[2:]  # add hidden states and attention if they are here
+        outputs = (reshaped_logits,) + outputs[2:]  # add hidden states and attention if they are here
         if labels is not None:
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(reshaped_logits, labels)
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
         return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)
@@ -1119,7 +1119,7 @@ class BertForTokenClassification(BertPreTrainedModel):
         sequence_output = self.dropout(sequence_output)
         logits = self.classifier(sequence_output)
-        outputs = [logits] + outputs[2:]  # add hidden states and attention if they are here
+        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
         if labels is not None:
             loss_fct = CrossEntropyLoss()
             # Only keep active parts of the loss
@@ -1130,7 +1130,7 @@ class BertForTokenClassification(BertPreTrainedModel):
                 loss = loss_fct(active_logits, active_labels)
             else:
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
         return outputs  # (loss), logits, (hidden_states), (attentions)
@@ -1205,7 +1205,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
         start_logits = start_logits.squeeze(-1)
         end_logits = end_logits.squeeze(-1)
-        outputs = [start_logits, end_logits] + outputs[2:]
+        outputs = (start_logits, end_logits,) + outputs[2:]
         if start_positions is not None and end_positions is not None:
             # If we are on multi-GPU, split add a dimension
             if len(start_positions.size()) > 1:
@@ -1221,6 +1221,6 @@ class BertForQuestionAnswering(BertPreTrainedModel):
             start_loss = loss_fct(start_logits, start_positions)
             end_loss = loss_fct(end_logits, end_positions)
             total_loss = (start_loss + end_loss) / 2
-            outputs = [total_loss] + outputs
+            outputs = (total_loss,) + outputs
         return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
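With every head returning tuples, the model can be handed to torch.jit.trace, which is what the commit title claims. The following is a hedged usage sketch, not something this page documents: the class and module path come from pytorch_pretrained_bert/modeling.py at this revision, but the config values, argument order, and whether this exact revision traces cleanly end to end are assumptions.

import torch
from pytorch_pretrained_bert.modeling import BertConfig, BertModel

# Small illustrative config; a real checkpoint would be loaded via from_pretrained().
config = BertConfig(vocab_size_or_config_json_file=32000)
model = BertModel(config)
model.eval()

# Example inputs for tracing; shapes are arbitrary.
input_ids = torch.randint(0, 32000, (1, 128), dtype=torch.long)
token_type_ids = torch.zeros_like(input_ids)
attention_mask = torch.ones_like(input_ids)

# torch.jit.trace requires tensor/tuple outputs, which this commit provides.
traced = torch.jit.trace(model, (input_ids, token_type_ids, attention_mask))
traced.save("bert_traced.pt")

# The traced module returns the same tuple as the eager model.
sequence_output, pooled_output = traced(input_ids, token_type_ids, attention_mask)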