chenpangpang / transformers · Commits

Commit dc580dd4, authored Oct 17, 2019 by Rémi Louf
Parent: f873a3ed

add lm_labels for the LM cross-entropy
Showing 1 changed file with 5 additions and 5 deletions.
transformers/modeling_bert.py (+5, -5)
@@ -819,7 +819,7 @@ class BertForMaskedLM(BertPreTrainedModel):
                                     self.bert.embeddings.word_embeddings)
 
     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
-                masked_lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None):
+                masked_lm_labels=None, lm_labels=None, encoder_hidden_states=None, encoder_attention_mask=None):
 
         outputs = self.bert(input_ids,
                             attention_mask=attention_mask,
@@ -840,7 +840,7 @@ class BertForMaskedLM(BertPreTrainedModel):
         # of predictions for masked words.
         # 2. If encoder hidden states are provided we are in a causal situation where we
         # try to predict the next word for each input in the encoder.
-        if masked_lm_labels is not None and encoder_hidden_states is not None:
+        if masked_lm_labels is not None and lm_labels is not None:
             raise AttributeError("Masked LM training with an encoder-decoder is not supported.")
 
         if masked_lm_labels is not None:
@@ -848,12 +848,12 @@ class BertForMaskedLM(BertPreTrainedModel):
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
             outputs = (masked_lm_loss,) + outputs
 
-        if encoder_hidden_states is not None:
+        if lm_labels is not None:
             # we are doing next-token prediction; shift prediction scores and input ids by one
             prediction_scores = prediction_scores[:, :-1, :]
-            input_ids = input_ids[:, 1:, :]
+            lm_labels = lm_labels[:, 1:, :]
             loss_fct = CrossEntropyLoss(ignore_index=-1)
-            seq2seq_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), input_ids.view(-1))
+            seq2seq_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1))
             outputs = (seq2seq_loss,) + outputs
 
         return outputs  # (mlm_or_seq2seq_loss), prediction_scores, (hidden_states), (attentions)
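Read on its own, the new lm_labels branch is a standard left-to-right language-modeling loss: the score the model emits at position t is compared against the token at position t + 1, which is why both tensors are shifted by one before the cross-entropy. The sketch below is a minimal, self-contained illustration of that step, not the library's code: the tensor sizes are made up, and it assumes the usual 2-D labels of shape (batch, seq_len), so it slices with [:, 1:] where the diff's [:, 1:, :] would require a 3-D tensor.

import torch
from torch.nn import CrossEntropyLoss

# Illustrative sizes only; not taken from the commit.
batch, seq_len, vocab_size = 2, 8, 30522

prediction_scores = torch.randn(batch, seq_len, vocab_size)  # decoder logits
lm_labels = torch.randint(0, vocab_size, (batch, seq_len))   # target token ids

# Next-token prediction: drop the scores at the last position and the first
# label so that scores[:, t] are compared against labels[:, t + 1].
shifted_scores = prediction_scores[:, :-1, :]
shifted_labels = lm_labels[:, 1:]

# ignore_index=-1 mirrors the diff: positions labeled -1 contribute no loss.
loss_fct = CrossEntropyLoss(ignore_index=-1)
seq2seq_loss = loss_fct(
    shifted_scores.reshape(-1, vocab_size),  # (batch * (seq_len - 1), vocab_size)
    shifted_labels.reshape(-1),              # (batch * (seq_len - 1),)
)

Two details worth noting: .reshape is used here because the shifted slices are no longer contiguous, so the diff's .view(-1, ...) would need a .contiguous() call first; and the guard added in the second hunk makes masked_lm_labels and lm_labels mutually exclusive, since a single forward pass computes either the masked-LM loss or the causal seq2seq loss, never both.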