chenpangpang / transformers

Commit 3b23a846, authored Jul 03, 2019 by thomwolf
Merge branch 'xlnet' of https://github.com/huggingface/pytorch-pretrained-BERT into xlnet
Parents: 8fa3a1f0, 64ce4dbd

Showing 7 changed files with 154 additions and 81 deletions (+154 -81)
pytorch_pretrained_bert/model_utils.py                      +1   -0
pytorch_pretrained_bert/modeling_bert.py                    +33  -27
pytorch_pretrained_bert/modeling_gpt2.py                    +20  -15
pytorch_pretrained_bert/modeling_openai.py                  +19  -15
pytorch_pretrained_bert/modeling_xlnet.py                   +26  -21
pytorch_pretrained_bert/tests/model_tests_commons.py        +54  -2
pytorch_pretrained_bert/tests/modeling_transfo_xl_test.py   +1   -1
pytorch_pretrained_bert/model_utils.py

@@ -46,6 +46,7 @@ class PretrainedConfig(object):
         self.num_labels = kwargs.pop('num_labels', 2)
         self.output_attentions = kwargs.pop('output_attentions', False)
         self.output_hidden_states = kwargs.pop('output_hidden_states', False)
+        self.torchscript = kwargs.pop('torchscript', False)

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
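Note: the new `torchscript` option follows the same `kwargs.pop(name, default)` pattern as the existing config flags, so it defaults to False and is consumed from the keyword arguments before anything else sees them. A minimal sketch of that pattern, using a hypothetical stand-in class rather than the library's actual `PretrainedConfig`:

# Minimal sketch of the kwargs.pop pattern used above (ToyConfig is a stand-in, not the library class).
class ToyConfig(object):
    def __init__(self, **kwargs):
        self.output_attentions = kwargs.pop('output_attentions', False)
        self.output_hidden_states = kwargs.pop('output_hidden_states', False)
        self.torchscript = kwargs.pop('torchscript', False)  # new flag, defaults to False
        if kwargs:
            # The stand-in rejects unknown options; the real class may handle leftovers differently.
            raise ValueError("Unrecognized config options: {}".format(list(kwargs)))

config = ToyConfig(torchscript=True)
assert config.torchscript and not config.output_attentions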
pytorch_pretrained_bert/modeling_bert.py

@@ -323,7 +323,7 @@ class BertSelfAttention(nn.Module):
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         context_layer = context_layer.view(*new_context_layer_shape)
-        outputs = [context_layer, attention_probs] if self.output_attentions else [context_layer]
+        outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
         return outputs

@@ -367,7 +367,7 @@ class BertAttention(nn.Module):
     def forward(self, input_tensor, attention_mask, head_mask=None):
         self_outputs = self.self(input_tensor, attention_mask, head_mask)
         attention_output = self.output(self_outputs[0], input_tensor)
-        outputs = [attention_output] + self_outputs[1:]  # add attentions if we output them
+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
         return outputs

@@ -412,7 +412,7 @@ class BertLayer(nn.Module):
         attention_output = attention_outputs[0]
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)
-        outputs = [layer_output] + attention_outputs[1:]  # add attentions if we output them
+        outputs = (layer_output,) + attention_outputs[1:]  # add attentions if we output them
         return outputs

@@ -424,27 +424,27 @@ class BertEncoder(nn.Module):
         self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

     def forward(self, hidden_states, attention_mask, head_mask=None):
-        all_hidden_states = []
-        all_attentions = []
+        all_hidden_states = ()
+        all_attentions = ()
         for i, layer_module in enumerate(self.layer):
             if self.output_hidden_states:
-                all_hidden_states.append(hidden_states)
+                all_hidden_states = all_hidden_states + (hidden_states,)
             layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
             hidden_states = layer_outputs[0]
             if self.output_attentions:
-                all_attentions.append(layer_outputs[1])
+                all_attentions = all_attentions + (layer_outputs[1],)
         # Add last layer
         if self.output_hidden_states:
-            all_hidden_states.append(hidden_states)
+            all_hidden_states = all_hidden_states + (hidden_states,)
-        outputs = [hidden_states]
+        outputs = (hidden_states,)
         if self.output_hidden_states:
-            outputs.append(all_hidden_states)
+            outputs = outputs + (all_hidden_states,)
         if self.output_attentions:
-            outputs.append(all_attentions)
+            outputs = outputs + (all_attentions,)
         return outputs  # outputs, (hidden states), (attentions)

@@ -484,13 +484,19 @@ class BertLMPredictionHead(nn.Module):
     def __init__(self, config, bert_model_embedding_weights):
         super(BertLMPredictionHead, self).__init__()
         self.transform = BertPredictionHeadTransform(config)
+        self.torchscript = config.torchscript
         # The output weights are the same as the input embeddings, but there is
         # an output-only bias for each token.
         self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
                                  bert_model_embedding_weights.size(0),
                                  bias=False)
-        self.decoder.weight = bert_model_embedding_weights
+        if self.torchscript:
+            self.decoder.weight = nn.Parameter(bert_model_embedding_weights.clone())
+        else:
+            self.decoder.weight = bert_model_embedding_weights
         self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))

     def forward(self, hidden_states):

@@ -666,7 +672,7 @@ class BertModel(BertPreTrainedModel):
         sequence_output = encoder_outputs[0]
         pooled_output = self.pooler(sequence_output)
-        outputs = [sequence_output, pooled_output] + encoder_outputs[1:]  # add hidden_states and attentions if they are here
+        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]  # add hidden_states and attentions if they are here
         return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)

@@ -739,14 +745,14 @@ class BertForPreTraining(BertPreTrainedModel):
         sequence_output, pooled_output = outputs[:2]
         prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
-        outputs = [prediction_scores, seq_relationship_score] + outputs[2:]  # add hidden states and attention if they are here
+        outputs = (prediction_scores, seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here
         if masked_lm_labels is not None and next_sentence_label is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
             next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
             total_loss = masked_lm_loss + next_sentence_loss
-            outputs = [total_loss] + outputs
+            outputs = (total_loss,) + outputs
         return outputs  # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)

@@ -815,11 +821,11 @@ class BertForMaskedLM(BertPreTrainedModel):
         sequence_output = outputs[0]
         prediction_scores = self.cls(sequence_output)
-        outputs = [prediction_scores] + outputs[2:]  # Add hidden states and attention is they are here
+        outputs = (prediction_scores,) + outputs[2:]  # Add hidden states and attention is they are here
         if masked_lm_labels is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
-            outputs = [masked_lm_loss] + outputs
+            outputs = (masked_lm_loss,) + outputs
         return outputs  # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)

@@ -885,11 +891,11 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
         seq_relationship_score = self.cls(pooled_output)
-        outputs = [seq_relationship_score] + outputs[2:]  # add hidden states and attention if they are here
+        outputs = (seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here
         if next_sentence_label is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
-            outputs = [next_sentence_loss] + outputs
+            outputs = (next_sentence_loss,) + outputs
         return outputs  # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)

@@ -960,7 +966,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)
-        outputs = [logits] + outputs[2:]  # add hidden states and attention if they are here
+        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
         if labels is not None:
             if self.num_labels == 1:

@@ -970,7 +976,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
             else:
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
         return outputs  # (loss), logits, (hidden_states), (attentions)

@@ -1043,12 +1049,12 @@ class BertForMultipleChoice(BertPreTrainedModel):
         logits = self.classifier(pooled_output)
         reshaped_logits = logits.view(-1, num_choices)
-        outputs = [reshaped_logits] + outputs[2:]  # add hidden states and attention if they are here
+        outputs = (reshaped_logits,) + outputs[2:]  # add hidden states and attention if they are here
         if labels is not None:
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(reshaped_logits, labels)
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
         return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)

@@ -1119,7 +1125,7 @@ class BertForTokenClassification(BertPreTrainedModel):
         sequence_output = self.dropout(sequence_output)
         logits = self.classifier(sequence_output)
-        outputs = [logits] + outputs[2:]  # add hidden states and attention if they are here
+        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
         if labels is not None:
             loss_fct = CrossEntropyLoss()
             # Only keep active parts of the loss

@@ -1130,7 +1136,7 @@ class BertForTokenClassification(BertPreTrainedModel):
                 loss = loss_fct(active_logits, active_labels)
             else:
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
         return outputs  # (loss), logits, (hidden_states), (attentions)

@@ -1205,7 +1211,7 @@ class BertForQuestionAnswering(BertPreTrainedModel):
         start_logits = start_logits.squeeze(-1)
         end_logits = end_logits.squeeze(-1)
-        outputs = [start_logits, end_logits] + outputs[2:]
+        outputs = (start_logits, end_logits,) + outputs[2:]
         if start_positions is not None and end_positions is not None:
             # If we are on multi-GPU, split add a dimension
             if len(start_positions.size()) > 1:

@@ -1221,6 +1227,6 @@ class BertForQuestionAnswering(BertPreTrainedModel):
             start_loss = loss_fct(start_logits, start_positions)
             end_loss = loss_fct(end_logits, end_positions)
             total_loss = (start_loss + end_loss) / 2
-            outputs = [total_loss] + outputs
+            outputs = (total_loss,) + outputs
         return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
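Note: throughout this file the forward passes now build and return plain tuples instead of lists, with the optional loss first and the optional hidden states / attentions last, so downstream indexing and slicing keeps working unchanged. A toy illustration of that convention (placeholder strings instead of tensors, no real model involved):

# Toy illustration of the tuple-based return convention used in the hunks above.
output_hidden_states, output_attentions = True, False

sequence_output, pooled_output = "seq", "pool"      # placeholders for tensors
hidden_states_per_layer = ("h0", "h1", "h2")        # placeholder per-layer tuple

outputs = (sequence_output, pooled_output)
if output_hidden_states:
    outputs = outputs + (hidden_states_per_layer,)  # optional trailing entry
if output_attentions:
    outputs = outputs + (("attn0", "attn1"),)       # optional trailing entry

# Callers keep using positional indexing / slicing, exactly as with the old lists:
seq, pool = outputs[:2]
extras = outputs[2:]   # () when both flags are off
assert isinstance(outputs, tuple)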
pytorch_pretrained_bert/modeling_gpt2.py

@@ -322,13 +322,18 @@ class GPT2LMHead(nn.Module):
         self.n_embd = config.n_embd
         self.vocab_size = config.vocab_size
         self.predict_special_tokens = config.predict_special_tokens
+        self.torchscript = config.torchscript
         embed_shape = model_embeddings_weights.shape
         self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
         self.set_embeddings_weights(model_embeddings_weights)

     def set_embeddings_weights(self, model_embeddings_weights, predict_special_tokens=True):
         self.predict_special_tokens = predict_special_tokens
-        self.decoder.weight = model_embeddings_weights  # Tied weights
+        # Export to TorchScript can't handle parameter sharing so we are cloning them.
+        if self.torchscript:
+            self.decoder.weight = nn.Parameter(model_embeddings_weights.clone())
+        else:
+            self.decoder.weight = model_embeddings_weights  # Tied weights

     def forward(self, hidden_state):
         lm_logits = self.decoder(hidden_state)

@@ -557,16 +562,16 @@ class GPT2Model(GPT2PreTrainedModel):
         output_shape = input_shape + (hidden_states.size(-1),)
-        presents = []
+        presents = ()
         all_attentions = []
-        all_hidden_states = []
+        all_hidden_states = ()
         for i, (block, layer_past) in enumerate(zip(self.h, past)):
             if self.output_hidden_states:
-                all_hidden_states.append(hidden_states.view(*output_shape))
+                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
             outputs = block(hidden_states, layer_past, head_mask[i])
             hidden_states, present = outputs[:2]
-            presents.append(present)
+            presents = presents + (present,)
             if self.output_attentions:
                 all_attentions.append(outputs[2])

@@ -576,16 +581,16 @@ class GPT2Model(GPT2PreTrainedModel):
         hidden_states = hidden_states.view(*output_shape)
         # Add last hidden state
         if self.output_hidden_states:
-            all_hidden_states.append(hidden_states)
+            all_hidden_states = all_hidden_states + (hidden_states,)
-        outputs = [hidden_states, presents]
+        outputs = (hidden_states, presents)
         if self.output_hidden_states:
-            outputs.append(all_hidden_states)
+            outputs = outputs + (all_hidden_states,)
         if self.output_attentions:
             # let the number of heads free (-1) so we can extract attention even after head pruning
             attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
-            all_attentions = list(t.view(*attention_output_shape) for t in all_attentions)
-            outputs.append(all_attentions)
+            all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
+            outputs = outputs + (all_attentions,)
         return outputs  # last hidden state, presents, (all hidden_states), (attentions)

@@ -658,7 +663,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         lm_logits = self.lm_head(hidden_states)
-        outputs = [lm_logits] + transformer_outputs[1:]
+        outputs = (lm_logits,) + transformer_outputs[1:]
         if lm_labels is not None:
             # Shift so that tokens < n predict n
             shift_logits = lm_logits[..., :-1, :].contiguous()

@@ -667,7 +672,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
         return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)

@@ -750,18 +755,18 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         lm_logits = self.lm_head(hidden_states)
         mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
-        outputs = [lm_logits, mc_logits] + transformer_outputs[1:]
+        outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
         if mc_labels is not None:
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
         if lm_labels is not None:
             shift_logits = lm_logits[..., :-1, :].contiguous()
             shift_labels = lm_labels[..., 1:].contiguous()
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
         return outputs  # (lm loss), (mc loss), lm logits, mc logits, presents, (all hidden_states), (attentions)
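Note: the comment added in `GPT2LMHead.set_embeddings_weights` states the reason for the clone: TorchScript export can't handle parameter sharing, so when `config.torchscript` is set the tied decoder weight is replaced by an independent copy. A minimal sketch of the same pattern on a hypothetical toy head (not the library's class), which then traces cleanly:

import torch
import torch.nn as nn

class ToyLMHead(nn.Module):
    """Toy head that either ties or clones the embedding weights (stand-in for GPT2LMHead)."""
    def __init__(self, embedding_weights, torchscript=False):
        super(ToyLMHead, self).__init__()
        vocab_size, hidden_size = embedding_weights.shape
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        if torchscript:
            # Tracing can't handle shared parameters, so work on an independent copy.
            self.decoder.weight = nn.Parameter(embedding_weights.clone())
        else:
            self.decoder.weight = embedding_weights  # tied: same Parameter object as the embedding

    def forward(self, hidden_state):
        return self.decoder(hidden_state)

embeddings = nn.Embedding(10, 4)
head = ToyLMHead(embeddings.weight, torchscript=True)
traced = torch.jit.trace(head.eval(), torch.randn(2, 4))  # traces with the cloned weights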
pytorch_pretrained_bert/modeling_openai.py

@@ -348,14 +348,18 @@ class OpenAIGPTLMHead(nn.Module):
         self.n_embd = config.n_embd
         self.vocab_size = config.vocab_size
         self.predict_special_tokens = config.predict_special_tokens
+        self.torchscript = config.torchscript
         embed_shape = model_embeddings_weights.shape
         self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
         self.set_embeddings_weights(model_embeddings_weights)

     def set_embeddings_weights(self, model_embeddings_weights, predict_special_tokens=True):
         self.predict_special_tokens = predict_special_tokens
         embed_shape = model_embeddings_weights.shape
-        self.decoder.weight = model_embeddings_weights  # Tied weights
+        if self.torchscript:
+            self.decoder.weight = nn.Parameter(model_embeddings_weights.clone())
+        else:
+            self.decoder.weight = model_embeddings_weights  # Tied weights

     def forward(self, hidden_state):
         lm_logits = self.decoder(hidden_state)

@@ -579,26 +583,26 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         output_shape = input_shape + (hidden_states.size(-1),)
-        all_attentions = []
-        all_hidden_states = []
+        all_attentions = ()
+        all_hidden_states = ()
         for i, block in enumerate(self.h):
             if self.output_hidden_states:
-                all_hidden_states.append(hidden_states.view(*output_shape))
+                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
             outputs = block(hidden_states, head_mask[i])
             hidden_states = outputs[0]
             if self.output_attentions:
-                all_attentions.append(outputs[1])
+                all_attentions = all_attentions + (outputs[1],)
         # Add last layer
         if self.output_hidden_states:
-            all_hidden_states.append(hidden_states.view(*output_shape))
+            all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
-        outputs = [hidden_states.view(*output_shape)]
+        outputs = (hidden_states.view(*output_shape),)
         if self.output_hidden_states:
-            outputs.append(all_hidden_states)
+            outputs = outputs + (all_hidden_states,)
         if self.output_attentions:
-            outputs.append(all_attentions)
+            outputs = outputs + (all_attentions,)
         return outputs  # last hidden state, (all hidden states), (all attentions)

@@ -682,7 +686,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
         hidden_states = transformer_outputs[0]
         lm_logits = self.lm_head(hidden_states)
-        outputs = [lm_logits] + transformer_outputs[1:]
+        outputs = (lm_logits,) + transformer_outputs[1:]
         if lm_labels is not None:
             # Shift so that tokens < n predict n
             shift_logits = lm_logits[..., :-1, :].contiguous()

@@ -691,7 +695,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
         return outputs  # (loss), lm_logits, (all hidden states), (all attentions)

@@ -785,18 +789,18 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
         lm_logits = self.lm_head(hidden_states)
         mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
-        outputs = [lm_logits, mc_logits] + transformer_outputs[1:]
+        outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
         if mc_labels is not None:
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
         if lm_labels is not None:
             shift_logits = lm_logits[..., :-1, :].contiguous()
             shift_labels = lm_labels[..., 1:].contiguous()
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
         return outputs  # (lm loss), (mc loss), lm logits, mc logits, (all hidden_states), (attentions)
pytorch_pretrained_bert/modeling_xlnet.py

@@ -384,7 +384,8 @@ class XLNetRelativeAttention(nn.Module):
         x = x.reshape(x_size[1], x_size[0], x_size[2], x_size[3])
         x = x[1:, ...]
         x = x.reshape(x_size[0], x_size[1] - 1, x_size[2], x_size[3])
-        x = x[:, 0:klen, :, :]
+        # x = x[:, 0:klen, :, :]
+        x = torch.index_select(x, 1, torch.arange(klen))

         return x

@@ -527,9 +528,9 @@ class XLNetRelativeAttention(nn.Module):
             output_h = self.post_attention(h, attn_vec)
             output_g = None
-        outputs = [output_h, output_g]
+        outputs = (output_h, output_g)
         if self.output_attentions:
-            outputs = outputs + [attn_prob]
+            outputs = outputs + (attn_prob,)
         return outputs

 class XLNetFeedForward(nn.Module):

@@ -574,7 +575,7 @@ class XLNetLayer(nn.Module):
             output_g = self.ff(output_g)
         output_h = self.ff(output_h)
-        outputs = [output_h, output_g] + outputs[2:]  # Add again attentions if there are there
+        outputs = (output_h, output_g) + outputs[2:]  # Add again attentions if there are there
         return outputs

@@ -688,7 +689,7 @@ class XLNetModel(XLNetPreTrainedModel):
     def relative_positional_encoding(self, qlen, klen, bsz=None):
         """create relative positional encoding."""
         freq_seq = torch.arange(0, self.d_model, 2.0, dtype=torch.float)
-        inv_freq = 1 / (10000 ** (freq_seq / self.d_model))
+        inv_freq = 1 / torch.pow(10000, (freq_seq / self.d_model))

         if self.attn_type == 'bi':
             # beg, end = klen - 1, -qlen

@@ -869,7 +870,7 @@ class XLNetModel(XLNetPreTrainedModel):
         else:
             head_mask = [None] * self.n_layer
-        new_mems = []
+        new_mems = ()
         if mems is None:
             mems = [None] * len(self.layer)

@@ -877,7 +878,7 @@ class XLNetModel(XLNetPreTrainedModel):
         hidden_states = []
         for i, layer_module in enumerate(self.layer):
             # cache new mems
-            new_mems.append(self.cache_mem(output_h, mems[i]))
+            new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
             if self.output_hidden_states:
                 hidden_states.append((output_h, output_g) if output_g is not None else output_h)

@@ -895,16 +896,16 @@ class XLNetModel(XLNetPreTrainedModel):
         output = self.dropout(output_g if output_g is not None else output_h)
         # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
-        outputs = [output.permute(1, 0, 2).contiguous(), new_mems]
+        outputs = (output.permute(1, 0, 2).contiguous(), new_mems)
         if self.output_hidden_states:
             if output_g is not None:
-                hidden_states = [h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs]
+                hidden_states = tuple(h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs)
             else:
-                hidden_states = [hs.permute(1, 0, 2).contiguous() for hs in hidden_states]
-            outputs.append(hidden_states)
+                hidden_states = tuple(hs.permute(1, 0, 2).contiguous() for hs in hidden_states)
+            outputs = outputs + (hidden_states,)
         if self.output_attentions:
-            attentions = list(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
-            outputs.append(attentions)
+            attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
+            outputs = outputs + (attentions,)
         return outputs  # outputs, new_mems, (hidden_states), (attentions)

@@ -974,6 +975,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         super(XLNetLMHeadModel, self).__init__(config)
         self.attn_type = config.attn_type
         self.same_length = config.same_length
+        self.torchscript = config.torchscript
         self.transformer = XLNetModel(config)
         self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)

@@ -986,7 +988,10 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
     def tie_weights(self):
         """ Make sure we are sharing the embeddings
         """
-        self.lm_loss.weight = self.transformer.word_embedding.weight
+        if self.torchscript:
+            self.lm_loss.weight = nn.Parameter(self.transformer.word_embedding.weight.clone())
+        else:
+            self.lm_loss.weight = self.transformer.word_embedding.weight

     def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None, inp_q=None,

@@ -1026,14 +1031,14 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         logits = self.lm_loss(transformer_outputs[0])
-        outputs = [logits] + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it
+        outputs = (logits,) + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it
         if labels is not None:
             # Flatten the tokens
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
         return outputs  # return (loss), logits, (mems), (hidden states), (attentions)

@@ -1061,7 +1066,7 @@ class XLNetSequenceSummary(nn.Module):
             output = hidden_states[:, 0]
         elif self.summary_type == 'mean':
             output = hidden_states.mean(dim=1)
-        elif summary_type == 'attn':
+        elif self.summary_type == 'attn':
             raise NotImplementedError
         output = self.summary(output)

@@ -1180,7 +1185,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
         output = self.sequence_summary(output)
         logits = self.logits_proj(output)
-        outputs = [logits] + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it
+        outputs = (logits,) + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it
         if labels is not None:
             if self.num_labels == 1:

@@ -1190,7 +1195,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
             else:
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
         return outputs  # return (loss), logits, (mems), (hidden states), (attentions)

@@ -1271,7 +1276,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
         start_logits = start_logits.squeeze(-1)
         end_logits = end_logits.squeeze(-1)
-        outputs = [start_logits, end_logits] + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it
+        outputs = (start_logits, end_logits,) + transformer_outputs[1:]  # Keep mems, hidden states, attentions if there are in it
         if start_positions is not None and end_positions is not None:
             # If we are on multi-GPU, split add a dimension

@@ -1288,6 +1293,6 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
             start_loss = loss_fct(start_logits, start_positions)
             end_loss = loss_fct(end_logits, end_positions)
             total_loss = (start_loss + end_loss) / 2
-            outputs = [total_loss] + outputs
+            outputs = (total_loss,) + outputs
         return outputs  # return (loss), logits, (mems), (hidden states), (attentions)
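Note: two of the XLNet changes swap Python-level expressions for explicit torch ops (`x[:, 0:klen, :, :]` becomes `torch.index_select`, and `10000 ** ...` becomes `torch.pow`), presumably to keep the operations explicit in the traced graph. A quick equivalence check with toy tensors (shapes chosen here purely for illustration):

import torch

x = torch.randn(5, 7, 2, 3)   # toy tensor standing in for the attention scores
klen = 4

sliced = x[:, 0:klen, :, :]
selected = torch.index_select(x, 1, torch.arange(klen))
assert torch.equal(sliced, selected)  # same values, now via an explicit op

d_model = 8
freq_seq = torch.arange(0, d_model, 2.0, dtype=torch.float)
assert torch.allclose(1 / (10000 ** (freq_seq / d_model)),
                      1 / torch.pow(10000, (freq_seq / d_model)))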
pytorch_pretrained_bert/tests/model_tests_commons.py

@@ -31,6 +31,52 @@ def _config_zero_init(config):
             setattr(configs_no_init, key, 0.0)
     return configs_no_init

+def _create_and_check_torchscript_output_attentions(tester, model_classes, config, inputs_dict):
+    config.output_attentions = True
+    _create_and_check_torchscript(tester, model_classes, config, inputs_dict)
+
+def _create_and_check_torchscript_output_hidden_state(tester, model_classes, config, inputs_dict):
+    config.output_hidden_states = True
+    _create_and_check_torchscript(tester, model_classes, config, inputs_dict)
+
+def _create_and_check_torchscript(tester, model_classes, config, inputs_dict):
+    configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
+    configs_no_init.torchscript = True
+    for model_class in model_classes:
+        model = model_class(config=configs_no_init)
+        model.eval()
+        inputs = inputs_dict['input_ids']  # Let's keep only input_ids
+
+        try:
+            torch.jit.trace(model, inputs)
+        except RuntimeError:
+            tester.parent.fail("Couldn't trace module.")
+
+        try:
+            traced_gpt2 = torch.jit.trace(model, inputs)
+            torch.jit.save(traced_gpt2, "traced_model.pt")
+        except RuntimeError:
+            tester.parent.fail("Couldn't save module.")
+
+        try:
+            loaded_model = torch.jit.load("traced_model.pt")
+            os.remove("traced_model.pt")
+        except ValueError:
+            tester.parent.fail("Couldn't load module.")
+
+        model.eval()
+        loaded_model.eval()
+
+        model_params = model.parameters()
+        loaded_model_params = loaded_model.parameters()
+
+        models_equal = True
+        for p1, p2 in zip(model_params, loaded_model_params):
+            if p1.data.ne(p2.data).sum() > 0:
+                models_equal = False
+
+        tester.parent.assertTrue(models_equal)
+
 def _create_and_check_initialization(tester, model_classes, config, inputs_dict):
     configs_no_init = _config_zero_init(config)
     for model_class in model_classes:

@@ -41,7 +87,7 @@ def _create_and_check_initialization(tester, model_classes, config, inputs_dict)
                 msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))

 def _create_and_check_for_headmasking(tester, model_classes, config, inputs_dict):
-    configs_no_init = _config_zero_init(config)
+    configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
     for model_class in model_classes:
         config.output_attentions = True
         config.output_hidden_states = True

@@ -157,11 +203,17 @@ def _create_and_check_for_hidden_states(tester, model_classes, config, inputs_di
             [tester.seq_length, tester.hidden_size])

-def create_and_check_commons(tester, config, inputs_dict, test_pruning=True):
+def create_and_check_commons(tester, config, inputs_dict, test_pruning=True, test_torchscript=True):
     _create_and_check_initialization(tester, tester.all_model_classes, config, inputs_dict)
     _create_and_check_for_attentions(tester, tester.all_model_classes, config, inputs_dict)
     _create_and_check_for_headmasking(tester, tester.all_model_classes, config, inputs_dict)
     _create_and_check_for_hidden_states(tester, tester.all_model_classes, config, inputs_dict)
+    if test_torchscript:
+        _create_and_check_torchscript(tester, tester.all_model_classes, config, inputs_dict)
+        _create_and_check_torchscript_output_attentions(tester, tester.all_model_classes, config, inputs_dict)
+        _create_and_check_torchscript_output_hidden_state(tester, tester.all_model_classes, config, inputs_dict)
     if test_pruning:
         _create_and_check_for_head_pruning(tester, tester.all_model_classes, config, inputs_dict)
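Note: the new `_create_and_check_torchscript` helper exercises the full trace, save, and load round trip and then compares parameters element-wise. A standalone sketch of that round trip on a hypothetical tiny module (the class and file name here are chosen only for illustration):

import os
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    """Tiny stand-in model so the round trip can run without the library."""
    def __init__(self):
        super(TinyModel, self).__init__()
        self.linear = nn.Linear(8, 3)

    def forward(self, input_ids):
        return self.linear(input_ids)

model = TinyModel()
model.eval()
example_inputs = torch.randn(2, 8)

traced = torch.jit.trace(model, example_inputs)    # trace
torch.jit.save(traced, "traced_tiny_model.pt")     # save
loaded = torch.jit.load("traced_tiny_model.pt")    # load
os.remove("traced_tiny_model.pt")

# Same parameter comparison as the test helper: every element must match.
models_equal = all(p1.data.eq(p2.data).all()
                   for p1, p2 in zip(model.parameters(), loaded.parameters()))
assert models_equal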
pytorch_pretrained_bert/tests/modeling_transfo_xl_test.py

@@ -173,7 +173,7 @@ class TransfoXLModelTest(unittest.TestCase):
         def create_and_check_transfo_xl_commons(self, config, input_ids_1, input_ids_2, lm_labels):
             inputs_dict = {'input_ids': input_ids_1}
-            create_and_check_commons(self, config, inputs_dict, test_pruning=False)
+            create_and_check_commons(self, config, inputs_dict, test_pruning=False, test_torchscript=False)

     def test_default(self):
         self.run_tester(TransfoXLModelTest.TransfoXLModelTester(self))