Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f7cd7392
Commit
f7cd7392
authored
Jul 15, 2019
by
thomwolf
Browse files
fixed tests
parent
e28d8bde
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
63 additions
and
38 deletions
+63
-38
pytorch_transformers/modeling_bert.py
pytorch_transformers/modeling_bert.py
+31
-19
pytorch_transformers/modeling_gpt2.py
pytorch_transformers/modeling_gpt2.py
+4
-2
pytorch_transformers/modeling_openai.py
pytorch_transformers/modeling_openai.py
+4
-2
pytorch_transformers/modeling_transfo_xl.py
pytorch_transformers/modeling_transfo_xl.py
+1
-1
pytorch_transformers/modeling_utils.py
pytorch_transformers/modeling_utils.py
+1
-1
pytorch_transformers/modeling_xlm.py
pytorch_transformers/modeling_xlm.py
+9
-6
pytorch_transformers/modeling_xlnet.py
pytorch_transformers/modeling_xlnet.py
+13
-7
No files found.
pytorch_transformers/modeling_bert.py
View file @
f7cd7392
...
@@ -253,7 +253,7 @@ class BertEmbeddings(nn.Module):
...
@@ -253,7 +253,7 @@ class BertEmbeddings(nn.Module):
self
.
LayerNorm
=
BertLayerNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_eps
)
self
.
LayerNorm
=
BertLayerNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_eps
)
self
.
dropout
=
nn
.
Dropout
(
config
.
hidden_dropout_prob
)
self
.
dropout
=
nn
.
Dropout
(
config
.
hidden_dropout_prob
)
def
forward
(
self
,
input_ids
,
position
_ids
=
None
,
token_type
_ids
=
None
):
def
forward
(
self
,
input_ids
,
token_type
_ids
=
None
,
position
_ids
=
None
):
seq_length
=
input_ids
.
size
(
1
)
seq_length
=
input_ids
.
size
(
1
)
if
position_ids
is
None
:
if
position_ids
is
None
:
position_ids
=
torch
.
arange
(
seq_length
,
dtype
=
torch
.
long
,
device
=
input_ids
.
device
)
position_ids
=
torch
.
arange
(
seq_length
,
dtype
=
torch
.
long
,
device
=
input_ids
.
device
)
...
@@ -667,7 +667,7 @@ class BertModel(BertPreTrainedModel):
...
@@ -667,7 +667,7 @@ class BertModel(BertPreTrainedModel):
for
layer
,
heads
in
heads_to_prune
.
items
():
for
layer
,
heads
in
heads_to_prune
.
items
():
self
.
encoder
.
layer
[
layer
].
attention
.
prune_heads
(
heads
)
self
.
encoder
.
layer
[
layer
].
attention
.
prune_heads
(
heads
)
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
,
attention_mask
=
None
,
head_mask
=
None
):
def
forward
(
self
,
input_ids
,
token_type_ids
=
None
,
attention_mask
=
None
,
position_ids
=
None
,
head_mask
=
None
):
if
attention_mask
is
None
:
if
attention_mask
is
None
:
attention_mask
=
torch
.
ones_like
(
input_ids
)
attention_mask
=
torch
.
ones_like
(
input_ids
)
if
token_type_ids
is
None
:
if
token_type_ids
is
None
:
...
@@ -703,7 +703,7 @@ class BertModel(BertPreTrainedModel):
...
@@ -703,7 +703,7 @@ class BertModel(BertPreTrainedModel):
else
:
else
:
head_mask
=
[
None
]
*
self
.
config
.
num_hidden_layers
head_mask
=
[
None
]
*
self
.
config
.
num_hidden_layers
embedding_output
=
self
.
embeddings
(
input_ids
,
position_ids
,
token_type_ids
)
embedding_output
=
self
.
embeddings
(
input_ids
,
position_ids
=
position_ids
,
token_type_ids
=
token_type_ids
)
encoder_outputs
=
self
.
encoder
(
embedding_output
,
encoder_outputs
=
self
.
encoder
(
embedding_output
,
extended_attention_mask
,
extended_attention_mask
,
head_mask
=
head_mask
)
head_mask
=
head_mask
)
...
@@ -772,9 +772,10 @@ class BertForPreTraining(BertPreTrainedModel):
...
@@ -772,9 +772,10 @@ class BertForPreTraining(BertPreTrainedModel):
self
.
_tie_or_clone_weights
(
self
.
cls
.
predictions
.
decoder
,
self
.
_tie_or_clone_weights
(
self
.
cls
.
predictions
.
decoder
,
self
.
bert
.
embeddings
.
word_embeddings
)
self
.
bert
.
embeddings
.
word_embeddings
)
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
,
attention_mask
=
None
,
masked_lm_labels
=
None
,
def
forward
(
self
,
input_ids
,
token_type_ids
=
None
,
attention_mask
=
None
,
masked_lm_labels
=
None
,
next_sentence_label
=
None
,
head_mask
=
None
):
next_sentence_label
=
None
,
position_ids
=
None
,
head_mask
=
None
):
outputs
=
self
.
bert
(
input_ids
,
position_ids
,
token_type_ids
,
attention_mask
,
head_mask
=
head_mask
)
outputs
=
self
.
bert
(
input_ids
,
position_ids
=
position_ids
,
token_type_ids
=
token_type_ids
,
attention_mask
=
attention_mask
,
head_mask
=
head_mask
)
sequence_output
,
pooled_output
=
outputs
[:
2
]
sequence_output
,
pooled_output
=
outputs
[:
2
]
prediction_scores
,
seq_relationship_score
=
self
.
cls
(
sequence_output
,
pooled_output
)
prediction_scores
,
seq_relationship_score
=
self
.
cls
(
sequence_output
,
pooled_output
)
...
@@ -841,8 +842,10 @@ class BertForMaskedLM(BertPreTrainedModel):
...
@@ -841,8 +842,10 @@ class BertForMaskedLM(BertPreTrainedModel):
self
.
_tie_or_clone_weights
(
self
.
cls
.
predictions
.
decoder
,
self
.
_tie_or_clone_weights
(
self
.
cls
.
predictions
.
decoder
,
self
.
bert
.
embeddings
.
word_embeddings
)
self
.
bert
.
embeddings
.
word_embeddings
)
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
,
attention_mask
=
None
,
masked_lm_labels
=
None
,
head_mask
=
None
):
def
forward
(
self
,
input_ids
,
token_type_ids
=
None
,
attention_mask
=
None
,
masked_lm_labels
=
None
,
outputs
=
self
.
bert
(
input_ids
,
position_ids
,
token_type_ids
,
attention_mask
,
head_mask
=
head_mask
)
position_ids
=
None
,
head_mask
=
None
):
outputs
=
self
.
bert
(
input_ids
,
position_ids
=
position_ids
,
token_type_ids
=
token_type_ids
,
attention_mask
=
attention_mask
,
head_mask
=
head_mask
)
sequence_output
=
outputs
[
0
]
sequence_output
=
outputs
[
0
]
prediction_scores
=
self
.
cls
(
sequence_output
)
prediction_scores
=
self
.
cls
(
sequence_output
)
...
@@ -898,8 +901,10 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
...
@@ -898,8 +901,10 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
self
.
apply
(
self
.
init_weights
)
self
.
apply
(
self
.
init_weights
)
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
,
attention_mask
=
None
,
next_sentence_label
=
None
,
head_mask
=
None
):
def
forward
(
self
,
input_ids
,
token_type_ids
=
None
,
attention_mask
=
None
,
next_sentence_label
=
None
,
outputs
=
self
.
bert
(
input_ids
,
position_ids
,
token_type_ids
,
attention_mask
,
head_mask
=
head_mask
)
position_ids
=
None
,
head_mask
=
None
):
outputs
=
self
.
bert
(
input_ids
,
position_ids
=
position_ids
,
token_type_ids
=
token_type_ids
,
attention_mask
=
attention_mask
,
head_mask
=
head_mask
)
pooled_output
=
outputs
[
1
]
pooled_output
=
outputs
[
1
]
seq_relationship_score
=
self
.
cls
(
pooled_output
)
seq_relationship_score
=
self
.
cls
(
pooled_output
)
...
@@ -959,8 +964,10 @@ class BertForSequenceClassification(BertPreTrainedModel):
...
@@ -959,8 +964,10 @@ class BertForSequenceClassification(BertPreTrainedModel):
self
.
apply
(
self
.
init_weights
)
self
.
apply
(
self
.
init_weights
)
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
,
attention_mask
=
None
,
labels
=
None
,
head_mask
=
None
):
def
forward
(
self
,
input_ids
,
token_type_ids
=
None
,
attention_mask
=
None
,
labels
=
None
,
outputs
=
self
.
bert
(
input_ids
,
position_ids
,
token_type_ids
,
attention_mask
,
head_mask
=
head_mask
)
position_ids
=
None
,
head_mask
=
None
):
outputs
=
self
.
bert
(
input_ids
,
position_ids
=
position_ids
,
token_type_ids
=
token_type_ids
,
attention_mask
=
attention_mask
,
head_mask
=
head_mask
)
pooled_output
=
outputs
[
1
]
pooled_output
=
outputs
[
1
]
pooled_output
=
self
.
dropout
(
pooled_output
)
pooled_output
=
self
.
dropout
(
pooled_output
)
...
@@ -1063,14 +1070,16 @@ class BertForMultipleChoice(BertPreTrainedModel):
...
@@ -1063,14 +1070,16 @@ class BertForMultipleChoice(BertPreTrainedModel):
self
.
apply
(
self
.
init_weights
)
self
.
apply
(
self
.
init_weights
)
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
,
attention_mask
=
None
,
labels
=
None
,
head_mask
=
None
):
def
forward
(
self
,
input_ids
,
token_type_ids
=
None
,
attention_mask
=
None
,
labels
=
None
,
position_ids
=
None
,
head_mask
=
None
):
num_choices
=
input_ids
.
shape
[
1
]
num_choices
=
input_ids
.
shape
[
1
]
flat_input_ids
=
input_ids
.
view
(
-
1
,
input_ids
.
size
(
-
1
))
flat_input_ids
=
input_ids
.
view
(
-
1
,
input_ids
.
size
(
-
1
))
flat_position_ids
=
position_ids
.
view
(
-
1
,
position_ids
.
size
(
-
1
))
if
position_ids
is
not
None
else
None
flat_position_ids
=
position_ids
.
view
(
-
1
,
position_ids
.
size
(
-
1
))
if
position_ids
is
not
None
else
None
flat_token_type_ids
=
token_type_ids
.
view
(
-
1
,
token_type_ids
.
size
(
-
1
))
if
token_type_ids
is
not
None
else
None
flat_token_type_ids
=
token_type_ids
.
view
(
-
1
,
token_type_ids
.
size
(
-
1
))
if
token_type_ids
is
not
None
else
None
flat_attention_mask
=
attention_mask
.
view
(
-
1
,
attention_mask
.
size
(
-
1
))
if
attention_mask
is
not
None
else
None
flat_attention_mask
=
attention_mask
.
view
(
-
1
,
attention_mask
.
size
(
-
1
))
if
attention_mask
is
not
None
else
None
outputs
=
self
.
bert
(
flat_input_ids
,
flat_position_ids
,
flat_token_type_ids
,
flat_attention_mask
,
head_mask
=
head_mask
)
outputs
=
self
.
bert
(
flat_input_ids
,
position_ids
=
flat_position_ids
,
token_type_ids
=
flat_token_type_ids
,
attention_mask
=
flat_attention_mask
,
head_mask
=
head_mask
)
pooled_output
=
outputs
[
1
]
pooled_output
=
outputs
[
1
]
pooled_output
=
self
.
dropout
(
pooled_output
)
pooled_output
=
self
.
dropout
(
pooled_output
)
...
@@ -1131,8 +1140,10 @@ class BertForTokenClassification(BertPreTrainedModel):
...
@@ -1131,8 +1140,10 @@ class BertForTokenClassification(BertPreTrainedModel):
self
.
apply
(
self
.
init_weights
)
self
.
apply
(
self
.
init_weights
)
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
,
attention_mask
=
None
,
labels
=
None
,
head_mask
=
None
):
def
forward
(
self
,
input_ids
,
token_type_ids
=
None
,
attention_mask
=
None
,
labels
=
None
,
outputs
=
self
.
bert
(
input_ids
,
position_ids
,
token_type_ids
,
attention_mask
,
head_mask
=
head_mask
)
position_ids
=
None
,
head_mask
=
None
):
outputs
=
self
.
bert
(
input_ids
,
position_ids
=
position_ids
,
token_type_ids
=
token_type_ids
,
attention_mask
=
attention_mask
,
head_mask
=
head_mask
)
sequence_output
=
outputs
[
0
]
sequence_output
=
outputs
[
0
]
sequence_output
=
self
.
dropout
(
sequence_output
)
sequence_output
=
self
.
dropout
(
sequence_output
)
...
@@ -1205,9 +1216,10 @@ class BertForQuestionAnswering(BertPreTrainedModel):
...
@@ -1205,9 +1216,10 @@ class BertForQuestionAnswering(BertPreTrainedModel):
self
.
apply
(
self
.
init_weights
)
self
.
apply
(
self
.
init_weights
)
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
,
attention_mask
=
None
,
start_positions
=
None
,
def
forward
(
self
,
input_ids
,
token_type_ids
=
None
,
attention_mask
=
None
,
start_positions
=
None
,
end_positions
=
None
,
head_mask
=
None
):
end_positions
=
None
,
position_ids
=
None
,
head_mask
=
None
):
outputs
=
self
.
bert
(
input_ids
,
position_ids
,
token_type_ids
,
attention_mask
,
head_mask
=
head_mask
)
outputs
=
self
.
bert
(
input_ids
,
position_ids
=
position_ids
,
token_type_ids
=
token_type_ids
,
attention_mask
=
attention_mask
,
head_mask
=
head_mask
)
sequence_output
=
outputs
[
0
]
sequence_output
=
outputs
[
0
]
logits
=
self
.
qa_outputs
(
sequence_output
)
logits
=
self
.
qa_outputs
(
sequence_output
)
...
...
pytorch_transformers/modeling_gpt2.py
View file @
f7cd7392
...
@@ -591,7 +591,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
...
@@ -591,7 +591,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
self
.
transformer
.
wte
)
self
.
transformer
.
wte
)
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
,
labels
=
None
,
past
=
None
,
head_mask
=
None
):
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
,
labels
=
None
,
past
=
None
,
head_mask
=
None
):
transformer_outputs
=
self
.
transformer
(
input_ids
,
position_ids
,
token_type_ids
,
past
,
head_mask
)
transformer_outputs
=
self
.
transformer
(
input_ids
,
position_ids
=
position_ids
,
token_type_ids
=
token_type_ids
,
past
=
past
,
head_mask
=
head_mask
)
hidden_states
=
transformer_outputs
[
0
]
hidden_states
=
transformer_outputs
[
0
]
lm_logits
=
self
.
lm_head
(
hidden_states
)
lm_logits
=
self
.
lm_head
(
hidden_states
)
...
@@ -709,7 +710,8 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
...
@@ -709,7 +710,8 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
def
forward
(
self
,
input_ids
,
mc_token_ids
=
None
,
lm_labels
=
None
,
mc_labels
=
None
,
token_type_ids
=
None
,
def
forward
(
self
,
input_ids
,
mc_token_ids
=
None
,
lm_labels
=
None
,
mc_labels
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
past
=
None
,
head_mask
=
None
):
position_ids
=
None
,
past
=
None
,
head_mask
=
None
):
transformer_outputs
=
self
.
transformer
(
input_ids
,
position_ids
,
token_type_ids
,
past
,
head_mask
)
transformer_outputs
=
self
.
transformer
(
input_ids
,
position_ids
=
position_ids
,
token_type_ids
=
token_type_ids
,
past
=
past
,
head_mask
=
head_mask
)
hidden_states
=
transformer_outputs
[
0
]
hidden_states
=
transformer_outputs
[
0
]
lm_logits
=
self
.
lm_head
(
hidden_states
)
lm_logits
=
self
.
lm_head
(
hidden_states
)
...
...
pytorch_transformers/modeling_openai.py
View file @
f7cd7392
...
@@ -582,7 +582,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
...
@@ -582,7 +582,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
self
.
transformer
.
tokens_embed
)
self
.
transformer
.
tokens_embed
)
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
,
labels
=
None
,
head_mask
=
None
):
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
,
labels
=
None
,
head_mask
=
None
):
transformer_outputs
=
self
.
transformer
(
input_ids
,
position_ids
,
token_type_ids
,
head_mask
)
transformer_outputs
=
self
.
transformer
(
input_ids
,
position_ids
=
position_ids
,
token_type_ids
=
token_type_ids
,
head_mask
=
head_mask
)
hidden_states
=
transformer_outputs
[
0
]
hidden_states
=
transformer_outputs
[
0
]
lm_logits
=
self
.
lm_head
(
hidden_states
)
lm_logits
=
self
.
lm_head
(
hidden_states
)
...
@@ -693,7 +694,8 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
...
@@ -693,7 +694,8 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
def
forward
(
self
,
input_ids
,
mc_token_ids
=
None
,
lm_labels
=
None
,
mc_labels
=
None
,
token_type_ids
=
None
,
def
forward
(
self
,
input_ids
,
mc_token_ids
=
None
,
lm_labels
=
None
,
mc_labels
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
head_mask
=
None
):
position_ids
=
None
,
head_mask
=
None
):
transformer_outputs
=
self
.
transformer
(
input_ids
,
position_ids
,
token_type_ids
,
head_mask
)
transformer_outputs
=
self
.
transformer
(
input_ids
,
position_ids
=
position_ids
,
token_type_ids
=
token_type_ids
,
head_mask
=
head_mask
)
hidden_states
=
transformer_outputs
[
0
]
hidden_states
=
transformer_outputs
[
0
]
lm_logits
=
self
.
lm_head
(
hidden_states
)
lm_logits
=
self
.
lm_head
(
hidden_states
)
...
...
pytorch_transformers/modeling_transfo_xl.py
View file @
f7cd7392
...
@@ -1344,7 +1344,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
...
@@ -1344,7 +1344,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
bsz
=
input_ids
.
size
(
0
)
bsz
=
input_ids
.
size
(
0
)
tgt_len
=
input_ids
.
size
(
1
)
tgt_len
=
input_ids
.
size
(
1
)
transformer_outputs
=
self
.
transformer
(
input_ids
,
mems
,
head_mask
)
transformer_outputs
=
self
.
transformer
(
input_ids
,
mems
=
mems
,
head_mask
=
head_mask
)
last_hidden
=
transformer_outputs
[
0
]
last_hidden
=
transformer_outputs
[
0
]
pred_hid
=
last_hidden
[:,
-
tgt_len
:]
pred_hid
=
last_hidden
[:,
-
tgt_len
:]
...
...
pytorch_transformers/modeling_utils.py
View file @
f7cd7392
...
@@ -594,7 +594,7 @@ class SQuADHead(nn.Module):
...
@@ -594,7 +594,7 @@ class SQuADHead(nn.Module):
"""
"""
outputs
=
()
outputs
=
()
start_logits
=
self
.
start_logits
(
hidden_states
,
p_mask
)
start_logits
=
self
.
start_logits
(
hidden_states
,
p_mask
=
p_mask
)
if
start_positions
is
not
None
and
end_positions
is
not
None
:
if
start_positions
is
not
None
and
end_positions
is
not
None
:
# If we are on multi-GPU, let's remove the dimension added by batch splitting
# If we are on multi-GPU, let's remove the dimension added by batch splitting
...
...
pytorch_transformers/modeling_xlm.py
View file @
f7cd7392
...
@@ -768,8 +768,9 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
...
@@ -768,8 +768,9 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
def
forward
(
self
,
input_ids
,
lengths
=
None
,
position_ids
=
None
,
langs
=
None
,
token_type_ids
=
None
,
def
forward
(
self
,
input_ids
,
lengths
=
None
,
position_ids
=
None
,
langs
=
None
,
token_type_ids
=
None
,
attention_mask
=
None
,
cache
=
None
,
labels
=
None
,
head_mask
=
None
):
attention_mask
=
None
,
cache
=
None
,
labels
=
None
,
head_mask
=
None
):
transformer_outputs
=
self
.
transformer
(
input_ids
,
lengths
=
lengths
,
position_ids
=
position_ids
,
token_type_ids
=
token_type_ids
,
transformer_outputs
=
self
.
transformer
(
input_ids
,
lengths
=
lengths
,
position_ids
=
position_ids
,
langs
=
langs
,
attention_mask
=
attention_mask
,
cache
=
cache
,
head_mask
=
head_mask
)
token_type_ids
=
token_type_ids
,
langs
=
langs
,
attention_mask
=
attention_mask
,
cache
=
cache
,
head_mask
=
head_mask
)
output
=
transformer_outputs
[
0
]
output
=
transformer_outputs
[
0
]
outputs
=
self
.
pred_layer
(
output
,
labels
)
outputs
=
self
.
pred_layer
(
output
,
labels
)
...
@@ -825,8 +826,9 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
...
@@ -825,8 +826,9 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
def
forward
(
self
,
input_ids
,
lengths
=
None
,
position_ids
=
None
,
langs
=
None
,
token_type_ids
=
None
,
def
forward
(
self
,
input_ids
,
lengths
=
None
,
position_ids
=
None
,
langs
=
None
,
token_type_ids
=
None
,
attention_mask
=
None
,
cache
=
None
,
labels
=
None
,
head_mask
=
None
):
attention_mask
=
None
,
cache
=
None
,
labels
=
None
,
head_mask
=
None
):
transformer_outputs
=
self
.
transformer
(
input_ids
,
lengths
=
lengths
,
position_ids
=
position_ids
,
token_type_ids
=
token_type_ids
,
transformer_outputs
=
self
.
transformer
(
input_ids
,
lengths
=
lengths
,
position_ids
=
position_ids
,
langs
=
langs
,
attention_mask
=
attention_mask
,
cache
=
cache
,
head_mask
=
head_mask
)
token_type_ids
=
token_type_ids
,
langs
=
langs
,
attention_mask
=
attention_mask
,
cache
=
cache
,
head_mask
=
head_mask
)
output
=
transformer_outputs
[
0
]
output
=
transformer_outputs
[
0
]
logits
=
self
.
sequence_summary
(
output
)
logits
=
self
.
sequence_summary
(
output
)
...
@@ -905,8 +907,9 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
...
@@ -905,8 +907,9 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
def
forward
(
self
,
input_ids
,
lengths
=
None
,
position_ids
=
None
,
langs
=
None
,
token_type_ids
=
None
,
def
forward
(
self
,
input_ids
,
lengths
=
None
,
position_ids
=
None
,
langs
=
None
,
token_type_ids
=
None
,
attention_mask
=
None
,
cache
=
None
,
start_positions
=
None
,
end_positions
=
None
,
attention_mask
=
None
,
cache
=
None
,
start_positions
=
None
,
end_positions
=
None
,
cls_index
=
None
,
is_impossible
=
None
,
p_mask
=
None
,
head_mask
=
None
):
cls_index
=
None
,
is_impossible
=
None
,
p_mask
=
None
,
head_mask
=
None
):
transformer_outputs
=
self
.
transformer
(
input_ids
,
lengths
=
lengths
,
position_ids
=
position_ids
,
token_type_ids
=
token_type_ids
,
transformer_outputs
=
self
.
transformer
(
input_ids
,
lengths
=
lengths
,
position_ids
=
position_ids
,
langs
=
langs
,
attention_mask
=
attention_mask
,
cache
=
cache
,
head_mask
=
head_mask
)
token_type_ids
=
token_type_ids
,
langs
=
langs
,
attention_mask
=
attention_mask
,
cache
=
cache
,
head_mask
=
head_mask
)
output
=
transformer_outputs
[
0
]
output
=
transformer_outputs
[
0
]
...
...
pytorch_transformers/modeling_xlnet.py
View file @
f7cd7392
...
@@ -1049,8 +1049,10 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
...
@@ -1049,8 +1049,10 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
def
forward
(
self
,
input_ids
,
token_type_ids
=
None
,
input_mask
=
None
,
attention_mask
=
None
,
def
forward
(
self
,
input_ids
,
token_type_ids
=
None
,
input_mask
=
None
,
attention_mask
=
None
,
mems
=
None
,
perm_mask
=
None
,
target_mapping
=
None
,
mems
=
None
,
perm_mask
=
None
,
target_mapping
=
None
,
labels
=
None
,
head_mask
=
None
):
labels
=
None
,
head_mask
=
None
):
transformer_outputs
=
self
.
transformer
(
input_ids
,
token_type_ids
,
input_mask
,
attention_mask
,
transformer_outputs
=
self
.
transformer
(
input_ids
,
token_type_ids
=
token_type_ids
,
mems
,
perm_mask
,
target_mapping
,
head_mask
)
input_mask
=
input_mask
,
attention_mask
=
attention_mask
,
mems
=
mems
,
perm_mask
=
perm_mask
,
target_mapping
=
target_mapping
,
head_mask
=
head_mask
)
logits
=
self
.
lm_loss
(
transformer_outputs
[
0
])
logits
=
self
.
lm_loss
(
transformer_outputs
[
0
])
...
@@ -1119,8 +1121,10 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
...
@@ -1119,8 +1121,10 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
def
forward
(
self
,
input_ids
,
token_type_ids
=
None
,
input_mask
=
None
,
attention_mask
=
None
,
def
forward
(
self
,
input_ids
,
token_type_ids
=
None
,
input_mask
=
None
,
attention_mask
=
None
,
mems
=
None
,
perm_mask
=
None
,
target_mapping
=
None
,
mems
=
None
,
perm_mask
=
None
,
target_mapping
=
None
,
labels
=
None
,
head_mask
=
None
):
labels
=
None
,
head_mask
=
None
):
transformer_outputs
=
self
.
transformer
(
input_ids
,
token_type_ids
,
input_mask
,
attention_mask
,
transformer_outputs
=
self
.
transformer
(
input_ids
,
token_type_ids
=
token_type_ids
,
mems
,
perm_mask
,
target_mapping
,
head_mask
)
input_mask
=
input_mask
,
attention_mask
=
attention_mask
,
mems
=
mems
,
perm_mask
=
perm_mask
,
target_mapping
=
target_mapping
,
head_mask
=
head_mask
)
output
=
transformer_outputs
[
0
]
output
=
transformer_outputs
[
0
]
output
=
self
.
sequence_summary
(
output
)
output
=
self
.
sequence_summary
(
output
)
...
@@ -1209,10 +1213,12 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
...
@@ -1209,10 +1213,12 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
mems
=
None
,
perm_mask
=
None
,
target_mapping
=
None
,
mems
=
None
,
perm_mask
=
None
,
target_mapping
=
None
,
start_positions
=
None
,
end_positions
=
None
,
cls_index
=
None
,
is_impossible
=
None
,
p_mask
=
None
,
start_positions
=
None
,
end_positions
=
None
,
cls_index
=
None
,
is_impossible
=
None
,
p_mask
=
None
,
head_mask
=
None
):
head_mask
=
None
):
transformer_outputs
=
self
.
transformer
(
input_ids
,
token_type_ids
,
input_mask
,
attention_mask
,
transformer_outputs
=
self
.
transformer
(
input_ids
,
token_type_ids
=
token_type_ids
,
mems
,
perm_mask
,
target_mapping
,
head_mask
)
input_mask
=
input_mask
,
attention_mask
=
attention_mask
,
mems
=
mems
,
perm_mask
=
perm_mask
,
target_mapping
=
target_mapping
,
head_mask
=
head_mask
)
hidden_states
=
transformer_outputs
[
0
]
hidden_states
=
transformer_outputs
[
0
]
start_logits
=
self
.
start_logits
(
hidden_states
,
p_mask
)
start_logits
=
self
.
start_logits
(
hidden_states
,
p_mask
=
p_mask
)
outputs
=
transformer_outputs
[
1
:]
# Keep mems, hidden states, attentions if there are in it
outputs
=
transformer_outputs
[
1
:]
# Keep mems, hidden states, attentions if there are in it
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment