Commit 64ce4dbd (unverified)
Authored Jul 03, 2019 by Thomas Wolf, committed by GitHub on Jul 03, 2019
Parents: 288be7b7, b43b130f

Merge pull request #748 from huggingface/torchscript

Release 0.7 - Add Torchscript capabilities

Showing 7 changed files with 154 additions and 81 deletions:

  pytorch_pretrained_bert/model_utils.py                        +1   -0
  pytorch_pretrained_bert/modeling.py                           +33  -27
  pytorch_pretrained_bert/modeling_gpt2.py                      +20  -15
  pytorch_pretrained_bert/modeling_openai.py                    +19  -15
  pytorch_pretrained_bert/modeling_xlnet.py                     +26  -21
  pytorch_pretrained_bert/tests/model_tests_commons.py          +54  -2
  pytorch_pretrained_bert/tests/modeling_transfo_xl_test.py     +1   -1
pytorch_pretrained_bert/model_utils.py

@@ -46,6 +46,7 @@ class PretrainedConfig(object):
         self.num_labels = kwargs.pop('num_labels', 2)
         self.output_attentions = kwargs.pop('output_attentions', False)
         self.output_hidden_states = kwargs.pop('output_hidden_states', False)
+        self.torchscript = kwargs.pop('torchscript', False)
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
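The new `torchscript` config flag is consumed by the model heads changed below. A minimal sketch of the intended usage, mirroring what the new test helper further down does (config values, input shapes, and the file name are illustrative, not taken from the diff):

    import torch
    from pytorch_pretrained_bert.modeling import BertConfig, BertForMaskedLM

    config = BertConfig(vocab_size_or_config_json_file=32000)
    config.torchscript = True      # the LM head clones, rather than shares, the embedding weights
    model = BertForMaskedLM(config)
    model.eval()

    input_ids = torch.randint(0, 32000, (1, 8), dtype=torch.long)
    traced = torch.jit.trace(model, input_ids)   # trace with input_ids only, as in the tests
    torch.jit.save(traced, "traced_bert.pt")
    loaded = torch.jit.load("traced_bert.pt")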
pytorch_pretrained_bert/modeling.py

@@ -323,7 +323,7 @@ class BertSelfAttention(nn.Module):
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         context_layer = context_layer.view(*new_context_layer_shape)
 
-        outputs = [context_layer, attention_probs] if self.output_attentions else [context_layer]
+        outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
         return outputs

@@ -367,7 +367,7 @@ class BertAttention(nn.Module):
     def forward(self, input_tensor, attention_mask, head_mask=None):
         self_outputs = self.self(input_tensor, attention_mask, head_mask)
         attention_output = self.output(self_outputs[0], input_tensor)
-        outputs = [attention_output] + self_outputs[1:]  # add attentions if we output them
+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
         return outputs

@@ -412,7 +412,7 @@ class BertLayer(nn.Module):
         attention_output = attention_outputs[0]
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)
-        outputs = [layer_output] + attention_outputs[1:]  # add attentions if we output them
+        outputs = (layer_output,) + attention_outputs[1:]  # add attentions if we output them
         return outputs

@@ -424,27 +424,27 @@ class BertEncoder(nn.Module):
         self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
 
     def forward(self, hidden_states, attention_mask, head_mask=None):
-        all_hidden_states = []
-        all_attentions = []
+        all_hidden_states = ()
+        all_attentions = ()
         for i, layer_module in enumerate(self.layer):
             if self.output_hidden_states:
-                all_hidden_states.append(hidden_states)
+                all_hidden_states = all_hidden_states + (hidden_states,)
 
             layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
             hidden_states = layer_outputs[0]
 
             if self.output_attentions:
-                all_attentions.append(layer_outputs[1])
+                all_attentions = all_attentions + (layer_outputs[1],)
 
         # Add last layer
         if self.output_hidden_states:
-            all_hidden_states.append(hidden_states)
+            all_hidden_states = all_hidden_states + (hidden_states,)
 
-        outputs = [hidden_states]
+        outputs = (hidden_states,)
         if self.output_hidden_states:
-            outputs.append(all_hidden_states)
+            outputs = outputs + (all_hidden_states,)
         if self.output_attentions:
-            outputs.append(all_attentions)
+            outputs = outputs + (all_attentions,)
         return outputs  # outputs, (hidden states), (attentions)

@@ -484,13 +484,19 @@ class BertLMPredictionHead(nn.Module):
     def __init__(self, config, bert_model_embedding_weights):
         super(BertLMPredictionHead, self).__init__()
         self.transform = BertPredictionHeadTransform(config)
+        self.torchscript = config.torchscript
 
         # The output weights are the same as the input embeddings, but there is
         # an output-only bias for each token.
         self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
                                  bert_model_embedding_weights.size(0),
                                  bias=False)
-        self.decoder.weight = bert_model_embedding_weights
+
+        if self.torchscript:
+            self.decoder.weight = nn.Parameter(bert_model_embedding_weights.clone())
+        else:
+            self.decoder.weight = bert_model_embedding_weights
+
         self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))
 
     def forward(self, hidden_states):
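The `config.torchscript` branch above exists because, as the comment in modeling_gpt2.py below puts it, export to TorchScript can't handle parameter sharing, and the LM decoder normally reuses the embedding matrix. A standalone sketch of the two wiring options (not from the repo; the trade-off of the cloned form is that the decoder no longer tracks later updates to the embeddings):

    import torch.nn as nn

    embedding = nn.Embedding(30522, 768)
    decoder = nn.Linear(768, 30522, bias=False)

    # Tied (default): both attributes point at the same Parameter storage.
    decoder.weight = embedding.weight

    # Cloned (torchscript=True): an independent copy that TorchScript export can serialize.
    decoder.weight = nn.Parameter(embedding.weight.clone())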
@@ -666,7 +672,7 @@ class BertModel(BertPreTrainedModel):
         sequence_output = encoder_outputs[0]
         pooled_output = self.pooler(sequence_output)
 
-        outputs = [sequence_output, pooled_output] + encoder_outputs[1:]  # add hidden_states and attentions if they are here
+        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]  # add hidden_states and attentions if they are here
         return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)

@@ -739,14 +745,14 @@ class BertForPreTraining(BertPreTrainedModel):
         sequence_output, pooled_output = outputs[:2]
         prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
 
-        outputs = [prediction_scores, seq_relationship_score] + outputs[2:]  # add hidden states and attention if they are here
+        outputs = (prediction_scores, seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here
 
         if masked_lm_labels is not None and next_sentence_label is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
             next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
             total_loss = masked_lm_loss + next_sentence_loss
-            outputs = [total_loss] + outputs
+            outputs = (total_loss,) + outputs
 
         return outputs  # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)

@@ -815,11 +821,11 @@ class BertForMaskedLM(BertPreTrainedModel):
         sequence_output = outputs[0]
         prediction_scores = self.cls(sequence_output)
 
-        outputs = [prediction_scores] + outputs[2:]  # Add hidden states and attention is they are here
+        outputs = (prediction_scores,) + outputs[2:]  # Add hidden states and attention is they are here
         if masked_lm_labels is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
-            outputs = [masked_lm_loss] + outputs
+            outputs = (masked_lm_loss,) + outputs
 
         return outputs  # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)

@@ -885,11 +891,11 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
         seq_relationship_score = self.cls(pooled_output)
 
-        outputs = [seq_relationship_score] + outputs[2:]  # add hidden states and attention if they are here
+        outputs = (seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here
         if next_sentence_label is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
-            outputs = [next_sentence_loss] + outputs
+            outputs = (next_sentence_loss,) + outputs
 
         return outputs  # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)

@@ -960,7 +966,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)
 
-        outputs = [logits] + outputs[2:]  # add hidden states and attention if they are here
+        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
 
         if labels is not None:
             if self.num_labels == 1:

@@ -970,7 +976,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
             else:
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
 
         return outputs  # (loss), logits, (hidden_states), (attentions)

@@ -1043,12 +1049,12 @@ class BertForMultipleChoice(BertPreTrainedModel):
         logits = self.classifier(pooled_output)
         reshaped_logits = logits.view(-1, num_choices)
 
-        outputs = [reshaped_logits] + outputs[2:]  # add hidden states and attention if they are here
+        outputs = (reshaped_logits,) + outputs[2:]  # add hidden states and attention if they are here
 
         if labels is not None:
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(reshaped_logits, labels)
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
 
         return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)

@@ -1119,7 +1125,7 @@ class BertForTokenClassification(BertPreTrainedModel):
         sequence_output = self.dropout(sequence_output)
         logits = self.classifier(sequence_output)
 
-        outputs = [logits] + outputs[2:]  # add hidden states and attention if they are here
+        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
         if labels is not None:
             loss_fct = CrossEntropyLoss()
             # Only keep active parts of the loss

@@ -1130,7 +1136,7 @@ class BertForTokenClassification(BertPreTrainedModel):
                 loss = loss_fct(active_logits, active_labels)
             else:
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
 
         return outputs  # (loss), logits, (hidden_states), (attentions)

@@ -1205,7 +1211,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
         start_logits = start_logits.squeeze(-1)
         end_logits = end_logits.squeeze(-1)
 
-        outputs = [start_logits, end_logits] + outputs[2:]
+        outputs = (start_logits, end_logits,) + outputs[2:]
         if start_positions is not None and end_positions is not None:
             # If we are on multi-GPU, split add a dimension
             if len(start_positions.size()) > 1:

@@ -1221,6 +1227,6 @@ class BertForQuestionAnswering(BertPreTrainedModel):
             start_loss = loss_fct(start_logits, start_positions)
             end_loss = loss_fct(end_logits, end_positions)
             total_loss = (start_loss + end_loss) / 2
-            outputs = [total_loss] + outputs
+            outputs = (total_loss,) + outputs
 
         return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
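The pattern repeated across these hunks, every model now returning a tuple instead of a list, follows from a TorchScript tracing constraint: traced functions may only return tensors or (possibly nested) tuples of tensors. A tiny standalone demonstration of the constraint, not from the repo (exact error type and wording vary across PyTorch versions):

    import torch
    import torch.nn as nn

    class ReturnsList(nn.Module):
        def forward(self, x):
            return [x, x + 1]      # list output: rejected by the tracer

    class ReturnsTuple(nn.Module):
        def forward(self, x):
            return (x, x + 1)      # tuple output: traceable

    x = torch.randn(2, 3)
    torch.jit.trace(ReturnsTuple(), x)       # works
    try:
        torch.jit.trace(ReturnsList(), x)
    except RuntimeError as err:
        print("list output rejected by tracing:", err)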
pytorch_pretrained_bert/modeling_gpt2.py

@@ -322,12 +322,17 @@ class GPT2LMHead(nn.Module):
         self.n_embd = config.n_embd
         self.vocab_size = config.vocab_size
         self.predict_special_tokens = config.predict_special_tokens
+        self.torchscript = config.torchscript
         embed_shape = model_embeddings_weights.shape
 
         self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
         self.set_embeddings_weights(model_embeddings_weights)
 
     def set_embeddings_weights(self, model_embeddings_weights, predict_special_tokens=True):
         self.predict_special_tokens = predict_special_tokens
-        self.decoder.weight = model_embeddings_weights  # Tied weights
+        # Export to TorchScript can't handle parameter sharing so we are cloning them.
+        if self.torchscript:
+            self.decoder.weight = nn.Parameter(model_embeddings_weights.clone())
+        else:
+            self.decoder.weight = model_embeddings_weights  # Tied weights
 
     def forward(self, hidden_state):

@@ -557,16 +562,16 @@ class GPT2Model(GPT2PreTrainedModel):
         output_shape = input_shape + (hidden_states.size(-1),)
 
-        presents = []
+        presents = ()
         all_attentions = []
-        all_hidden_states = []
+        all_hidden_states = ()
         for i, (block, layer_past) in enumerate(zip(self.h, past)):
             if self.output_hidden_states:
-                all_hidden_states.append(hidden_states.view(*output_shape))
+                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
             outputs = block(hidden_states, layer_past, head_mask[i])
             hidden_states, present = outputs[:2]
-            presents.append(present)
+            presents = presents + (present,)
             if self.output_attentions:
                 all_attentions.append(outputs[2])

@@ -576,16 +581,16 @@ class GPT2Model(GPT2PreTrainedModel):
         hidden_states = hidden_states.view(*output_shape)
         # Add last hidden state
         if self.output_hidden_states:
-            all_hidden_states.append(hidden_states)
+            all_hidden_states = all_hidden_states + (hidden_states,)
 
-        outputs = [hidden_states, presents]
+        outputs = (hidden_states, presents)
         if self.output_hidden_states:
-            outputs.append(all_hidden_states)
+            outputs = outputs + (all_hidden_states,)
         if self.output_attentions:
             # let the number of heads free (-1) so we can extract attention even after head pruning
             attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
-            all_attentions = list(t.view(*attention_output_shape) for t in all_attentions)
-            outputs.append(all_attentions)
+            all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
+            outputs = outputs + (all_attentions,)
         return outputs  # last hidden state, presents, (all hidden_states), (attentions)

@@ -658,7 +663,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         lm_logits = self.lm_head(hidden_states)
 
-        outputs = [lm_logits] + transformer_outputs[1:]
+        outputs = (lm_logits,) + transformer_outputs[1:]
         if lm_labels is not None:
             # Shift so that tokens < n predict n
             shift_logits = lm_logits[..., :-1, :].contiguous()

@@ -667,7 +672,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
 
         return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)

@@ -750,18 +755,18 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         lm_logits = self.lm_head(hidden_states)
         mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
 
-        outputs = [lm_logits, mc_logits] + transformer_outputs[1:]
+        outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
         if mc_labels is not None:
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
         if lm_labels is not None:
             shift_logits = lm_logits[..., :-1, :].contiguous()
             shift_labels = lm_labels[..., 1:].contiguous()
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
 
         return outputs  # (lm loss), (mc loss), lm logits, mc logits, presents, (all hidden_states), (attentions)
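The `presents` cache returned by GPT2Model is now a tuple as well. A hedged sketch of feeding it back as `past` on the next call, which is how the cache is normally consumed; the keyword name, model identifier, and tokenizer calls are assumptions about this repo's GPT-2 API and are not shown in the diff:

    import torch
    from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = GPT2LMHeadModel.from_pretrained('gpt2')
    model.eval()

    context = torch.tensor([tokenizer.encode("TorchScript release")])
    with torch.no_grad():
        lm_logits, presents = model(context)                    # full context on the first pass
        next_token = lm_logits[:, -1].argmax(dim=-1, keepdim=True)
        lm_logits, presents = model(next_token, past=presents)  # later passes reuse the cache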
pytorch_pretrained_bert/modeling_openai.py

@@ -348,13 +348,17 @@ class OpenAIGPTLMHead(nn.Module):
         self.n_embd = config.n_embd
         self.vocab_size = config.vocab_size
         self.predict_special_tokens = config.predict_special_tokens
+        self.torchscript = config.torchscript
         embed_shape = model_embeddings_weights.shape
 
         self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
         self.set_embeddings_weights(model_embeddings_weights)
 
     def set_embeddings_weights(self, model_embeddings_weights, predict_special_tokens=True):
         self.predict_special_tokens = predict_special_tokens
         embed_shape = model_embeddings_weights.shape
-        self.decoder.weight = model_embeddings_weights  # Tied weights
+        if self.torchscript:
+            self.decoder.weight = nn.Parameter(model_embeddings_weights.clone())
+        else:
+            self.decoder.weight = model_embeddings_weights  # Tied weights
 
     def forward(self, hidden_state):

@@ -579,26 +583,26 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         output_shape = input_shape + (hidden_states.size(-1),)
 
-        all_attentions = []
-        all_hidden_states = []
+        all_attentions = ()
+        all_hidden_states = ()
         for i, block in enumerate(self.h):
             if self.output_hidden_states:
-                all_hidden_states.append(hidden_states.view(*output_shape))
+                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
             outputs = block(hidden_states, head_mask[i])
             hidden_states = outputs[0]
             if self.output_attentions:
-                all_attentions.append(outputs[1])
+                all_attentions = all_attentions + (outputs[1],)
 
         # Add last layer
         if self.output_hidden_states:
-            all_hidden_states.append(hidden_states.view(*output_shape))
+            all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
 
-        outputs = [hidden_states.view(*output_shape)]
+        outputs = (hidden_states.view(*output_shape),)
         if self.output_hidden_states:
-            outputs.append(all_hidden_states)
+            outputs = outputs + (all_hidden_states,)
         if self.output_attentions:
-            outputs.append(all_attentions)
+            outputs = outputs + (all_attentions,)
         return outputs  # last hidden state, (all hidden states), (all attentions)

@@ -682,7 +686,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
         hidden_states = transformer_outputs[0]
         lm_logits = self.lm_head(hidden_states)
 
-        outputs = [lm_logits] + transformer_outputs[1:]
+        outputs = (lm_logits,) + transformer_outputs[1:]
         if lm_labels is not None:
             # Shift so that tokens < n predict n
             shift_logits = lm_logits[..., :-1, :].contiguous()

@@ -691,7 +695,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
 
         return outputs  # (loss), lm_logits, (all hidden states), (all attentions)

@@ -785,18 +789,18 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
         lm_logits = self.lm_head(hidden_states)
         mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
 
-        outputs = [lm_logits, mc_logits] + transformer_outputs[1:]
+        outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
         if mc_labels is not None:
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
         if lm_labels is not None:
             shift_logits = lm_logits[..., :-1, :].contiguous()
             shift_labels = lm_labels[..., 1:].contiguous()
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
 
         return outputs  # (lm loss), (mc loss), lm logits, mc logits, (all hidden_states), (attentions)
pytorch_pretrained_bert/modeling_xlnet.py

@@ -384,7 +384,8 @@ class XLNetRelativeAttention(nn.Module):
         x = x.reshape(x_size[1], x_size[0], x_size[2], x_size[3])
         x = x[1:, ...]
         x = x.reshape(x_size[0], x_size[1] - 1, x_size[2], x_size[3])
-        x = x[:, 0:klen, :, :]
+        # x = x[:, 0:klen, :, :]
+        x = torch.index_select(x, 1, torch.arange(klen))
 
         return x
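The rewrite above keeps the same semantics as the commented-out slice; `torch.index_select` is simply a (presumably tracing-friendlier) way to take the first `klen` positions along dimension 1. A quick standalone equivalence check, with illustrative shapes:

    import torch

    x = torch.randn(4, 7, 2, 3)
    klen = 5
    assert torch.equal(x[:, 0:klen, :, :],
                       torch.index_select(x, 1, torch.arange(klen)))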
@@ -527,9 +528,9 @@ class XLNetRelativeAttention(nn.Module):
             output_h = self.post_attention(h, attn_vec)
             output_g = None
 
-        outputs = [output_h, output_g]
+        outputs = (output_h, output_g)
         if self.output_attentions:
-            outputs = outputs + [attn_prob]
+            outputs = outputs + (attn_prob,)
         return outputs
 
 class XLNetFeedForward(nn.Module):

@@ -574,7 +575,7 @@ class XLNetLayer(nn.Module):
             output_g = self.ff(output_g)
         output_h = self.ff(output_h)
 
-        outputs = [output_h, output_g] + outputs[2:]  # Add again attentions if there are there
+        outputs = (output_h, output_g) + outputs[2:]  # Add again attentions if there are there
         return outputs

@@ -688,7 +689,7 @@ class XLNetModel(XLNetPreTrainedModel):
     def relative_positional_encoding(self, qlen, klen, bsz=None):
         """create relative positional encoding."""
         freq_seq = torch.arange(0, self.d_model, 2.0, dtype=torch.float)
-        inv_freq = 1 / (10000 ** (freq_seq / self.d_model))
+        inv_freq = 1 / torch.pow(10000, (freq_seq / self.d_model))
 
         if self.attn_type == 'bi':
             # beg, end = klen - 1, -qlen

@@ -869,7 +870,7 @@ class XLNetModel(XLNetPreTrainedModel):
         else:
             head_mask = [None] * self.n_layer
 
-        new_mems = []
+        new_mems = ()
         if mems is None:
             mems = [None] * len(self.layer)

@@ -877,7 +878,7 @@ class XLNetModel(XLNetPreTrainedModel):
         hidden_states = []
         for i, layer_module in enumerate(self.layer):
             # cache new mems
-            new_mems.append(self.cache_mem(output_h, mems[i]))
+            new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
             if self.output_hidden_states:
                 hidden_states.append((output_h, output_g) if output_g is not None else output_h)

@@ -895,16 +896,16 @@ class XLNetModel(XLNetPreTrainedModel):
         output = self.dropout(output_g if output_g is not None else output_h)
 
         # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
-        outputs = [output.permute(1, 0, 2).contiguous(), new_mems]
+        outputs = (output.permute(1, 0, 2).contiguous(), new_mems)
         if self.output_hidden_states:
             if output_g is not None:
-                hidden_states = [h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs]
+                hidden_states = tuple(h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs)
             else:
-                hidden_states = [hs.permute(1, 0, 2).contiguous() for hs in hidden_states]
-            outputs.append(hidden_states)
+                hidden_states = tuple(hs.permute(1, 0, 2).contiguous() for hs in hidden_states)
+            outputs = outputs + (hidden_states,)
         if self.output_attentions:
-            attentions = list(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
-            outputs.append(attentions)
+            attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
+            outputs = outputs + (attentions,)
 
         return outputs  # outputs, new_mems, (hidden_states), (attentions)

@@ -974,6 +975,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         super(XLNetLMHeadModel, self).__init__(config)
         self.attn_type = config.attn_type
         self.same_length = config.same_length
+        self.torchscript = config.torchscript
 
         self.transformer = XLNetModel(config)
         self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)

@@ -986,6 +988,9 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
     def tie_weights(self):
         """ Make sure we are sharing the embeddings
         """
-        self.lm_loss.weight = self.transformer.word_embedding.weight
+        if self.torchscript:
+            self.lm_loss.weight = nn.Parameter(self.transformer.word_embedding.weight.clone())
+        else:
+            self.lm_loss.weight = self.transformer.word_embedding.weight
 
     def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,

@@ -1026,14 +1031,14 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         logits = self.lm_loss(transformer_outputs[0])
 
-        outputs = [logits] + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it
+        outputs = (logits,) + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it
 
         if labels is not None:
             # Flatten the tokens
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
 
         return outputs  # return (loss), logits, (mems), (hidden states), (attentions)

@@ -1061,7 +1066,7 @@ class XLNetSequenceSummary(nn.Module):
             output = hidden_states[:, 0]
         elif self.summary_type == 'mean':
             output = hidden_states.mean(dim=1)
-        elif summary_type == 'attn':
+        elif self.summary_type == 'attn':
             raise NotImplementedError
 
         output = self.summary(output)

@@ -1180,7 +1185,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
         output = self.sequence_summary(output)
         logits = self.logits_proj(output)
 
-        outputs = [logits] + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it
+        outputs = (logits,) + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it
 
         if labels is not None:
             if self.num_labels == 1:

@@ -1190,7 +1195,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
             else:
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
 
         return outputs  # return (loss), logits, (mems), (hidden states), (attentions)

@@ -1271,7 +1276,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
         start_logits = start_logits.squeeze(-1)
         end_logits = end_logits.squeeze(-1)
 
-        outputs = [start_logits, end_logits] + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it
+        outputs = (start_logits, end_logits,) + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it
 
         if start_positions is not None and end_positions is not None:
             # If we are on multi-GPU, split add a dimension

@@ -1288,6 +1293,6 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
             start_loss = loss_fct(start_logits, start_positions)
             end_loss = loss_fct(end_logits, end_positions)
             total_loss = (start_loss + end_loss) / 2
-            outputs = [total_loss] + outputs
+            outputs = (total_loss,) + outputs
 
         return outputs  # return (loss), logits, (mems), (hidden states), (attentions)
pytorch_pretrained_bert/tests/model_tests_commons.py

@@ -31,6 +31,52 @@ def _config_zero_init(config):
             setattr(configs_no_init, key, 0.0)
     return configs_no_init
 
+def _create_and_check_torchscript_output_attentions(tester, model_classes, config, inputs_dict):
+    config.output_attentions = True
+    _create_and_check_torchscript(tester, model_classes, config, inputs_dict)
+
+def _create_and_check_torchscript_output_hidden_state(tester, model_classes, config, inputs_dict):
+    config.output_hidden_states = True
+    _create_and_check_torchscript(tester, model_classes, config, inputs_dict)
+
+def _create_and_check_torchscript(tester, model_classes, config, inputs_dict):
+    configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
+    configs_no_init.torchscript = True
+    for model_class in model_classes:
+        model = model_class(config=configs_no_init)
+        model.eval()
+        inputs = inputs_dict['input_ids']  # Let's keep only input_ids
+
+        try:
+            torch.jit.trace(model, inputs)
+        except RuntimeError:
+            tester.parent.fail("Couldn't trace module.")
+
+        try:
+            traced_gpt2 = torch.jit.trace(model, inputs)
+            torch.jit.save(traced_gpt2, "traced_model.pt")
+        except RuntimeError:
+            tester.parent.fail("Couldn't save module.")
+
+        try:
+            loaded_model = torch.jit.load("traced_model.pt")
+            os.remove("traced_model.pt")
+        except ValueError:
+            tester.parent.fail("Couldn't load module.")
+
+        model.eval()
+        loaded_model.eval()
+
+        model_params = model.parameters()
+        loaded_model_params = loaded_model.parameters()
+
+        models_equal = True
+        for p1, p2 in zip(model_params, loaded_model_params):
+            if p1.data.ne(p2.data).sum() > 0:
+                models_equal = False
+
+        tester.parent.assertTrue(models_equal)
+
 def _create_and_check_initialization(tester, model_classes, config, inputs_dict):
     configs_no_init = _config_zero_init(config)
     for model_class in model_classes:

@@ -39,7 +85,7 @@ def _create_and_check_initialization(tester, model_classes, config, inputs_dict):
             tester.parent.assertIn(param.data.mean().item(), [0.0, 1.0],
                 msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
 
 def _create_and_check_for_headmasking(tester, model_classes, config, inputs_dict):
-    configs_no_init = _config_zero_init(config)
+    configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
     for model_class in model_classes:
         config.output_attentions = True
         config.output_hidden_states = True

@@ -153,11 +199,17 @@ def _create_and_check_for_hidden_states(tester, model_classes, config, inputs_dict):
             [tester.seq_length, tester.hidden_size])
 
-def create_and_check_commons(tester, config, inputs_dict, test_pruning=True):
+def create_and_check_commons(tester, config, inputs_dict, test_pruning=True, test_torchscript=True):
     _create_and_check_initialization(tester, tester.all_model_classes, config, inputs_dict)
     _create_and_check_for_attentions(tester, tester.all_model_classes, config, inputs_dict)
     _create_and_check_for_headmasking(tester, tester.all_model_classes, config, inputs_dict)
     _create_and_check_for_hidden_states(tester, tester.all_model_classes, config, inputs_dict)
+
+    if test_torchscript:
+        _create_and_check_torchscript(tester, tester.all_model_classes, config, inputs_dict)
+        _create_and_check_torchscript_output_attentions(tester, tester.all_model_classes, config, inputs_dict)
+        _create_and_check_torchscript_output_hidden_state(tester, tester.all_model_classes, config, inputs_dict)
+
     if test_pruning:
         _create_and_check_for_head_pruning(tester, tester.all_model_classes, config, inputs_dict)
pytorch_pretrained_bert/tests/modeling_transfo_xl_test.py

@@ -173,7 +173,7 @@ class TransfoXLModelTest(unittest.TestCase):
         def create_and_check_transfo_xl_commons(self, config, input_ids_1, input_ids_2, lm_labels):
             inputs_dict = {'input_ids': input_ids_1}
-            create_and_check_commons(self, config, inputs_dict, test_pruning=False)
+            create_and_check_commons(self, config, inputs_dict, test_pruning=False, test_torchscript=False)
 
     def test_default(self):
         self.run_tester(TransfoXLModelTest.TransfoXLModelTester(self))