chenpangpang / transformers · Commits

Commit 4e0f2434, authored Oct 17, 2019 by Rémi Louf

document the MLM modification + raise exception on MLM training with encoder-decoder

Parent: 624a5644
Showing 1 changed file with 14 additions and 5 deletions
transformers/modeling_bert.py (+14, -5)

...
@@ -830,21 +830,30 @@ class BertForMaskedLM(BertPreTrainedModel):
         prediction_scores = self.cls(sequence_output)
 
         outputs = (prediction_scores,) + outputs[2:]  # Add hidden states and attention if they are here
 
+        # Although this may seem awkward, BertForMaskedLM supports two scenarios:
+        # 1. If a tensor that contains the indices of masked labels is provided,
+        #    the cross-entropy is the MLM cross-entropy that measures the likelihood
+        #    of predictions for masked words.
+        # 2. If encoder hidden states are provided we are in a causal situation where we
+        #    try to predict the next word for each input in the encoder.
+        if masked_lm_labels is not None and encoder_hidden_states is not None:
+            raise AttributeError("Masked LM training with an encoder-decoder is not supported.")
+
         if masked_lm_labels is not None:
-            loss_fct = CrossEntropyLoss(ignore_index=-1)
+            loss_fct = CrossEntropyLoss(ignore_index=-1)  # -1 index = padding token
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
             outputs = (masked_lm_loss,) + outputs
 
         if encoder_hidden_states is not None:
-            # shift predictions scores and input ids by one before computing loss
+            loss_fct = CrossEntropyLoss(ignore_index=-1)
+            # we are doing next-token prediction; shift prediction scores and input ids by one
             prediction_scores = prediction_scores[:, :-1, :]
             input_ids = input_ids[:, 1:, :]
-            loss_fct = CrossEntropyLoss(ignore_index=-1)
             seq2seq_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), input_ids.view(-1))
             outputs = (seq2seq_loss,) + outputs
 
-        return outputs  # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
+        return outputs  # (mlm_or_seq2seq_loss), prediction_scores, (hidden_states), (attentions)
 
 
 @add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """,
...
...
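For readers skimming the diff, here is a minimal standalone sketch of the two loss paths the new comment block describes. It mimics the logic on dummy tensors rather than calling BertForMaskedLM itself; the tensor shapes, the toy vocabulary size, and every name other than loss_fct, prediction_scores, input_ids and masked_lm_labels are illustrative assumptions, not part of the commit.

# Standalone illustration with assumed shapes and a toy vocabulary; not library code.
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 11
batch, seq_len = 2, 6
prediction_scores = torch.randn(batch, seq_len, vocab_size)  # stand-in for self.cls(sequence_output)
input_ids = torch.randint(0, vocab_size, (batch, seq_len))

loss_fct = CrossEntropyLoss(ignore_index=-1)  # -1 labels are treated as padding and ignored

# Scenario 1: masked LM. Labels are -1 everywhere except the (pretend) masked positions.
masked_lm_labels = input_ids.clone()
masked_lm_labels[:, :-2] = -1
masked_lm_loss = loss_fct(prediction_scores.view(-1, vocab_size), masked_lm_labels.view(-1))

# Scenario 2: causal / seq2seq. Shift scores and ids by one so each position predicts the next token.
shifted_scores = prediction_scores[:, :-1, :]
shifted_ids = input_ids[:, 1:]
seq2seq_loss = loss_fct(shifted_scores.reshape(-1, vocab_size), shifted_ids.reshape(-1))

print(masked_lm_loss.item(), seq2seq_loss.item())

Since both paths reuse the same prediction_scores, a single forward pass cannot do both at once, which is why the commit raises AttributeError when masked_lm_labels and encoder_hidden_states are passed together.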