Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
4e0f2434
Commit
4e0f2434
authored
Oct 17, 2019
by
Rémi Louf
Browse files
document the MLM modification + raise exception on MLM training with encoder-decoder
parent
624a5644
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
5 deletions
+14
-5
transformers/modeling_bert.py
transformers/modeling_bert.py
+14
-5
No files found.
transformers/modeling_bert.py
View file @
4e0f2434
...
@@ -830,21 +830,30 @@ class BertForMaskedLM(BertPreTrainedModel):
...
@@ -830,21 +830,30 @@ class BertForMaskedLM(BertPreTrainedModel):
prediction_scores
=
self
.
cls
(
sequence_output
)
prediction_scores
=
self
.
cls
(
sequence_output
)
outputs
=
(
prediction_scores
,)
+
outputs
[
2
:]
# Add hidden states and attention if they are here
outputs
=
(
prediction_scores
,)
+
outputs
[
2
:]
# Add hidden states and attention if they are here
# Although this may seem awkward, BertForMaskedLM supports two scenarios:
# 1. If a tensor containing the indices of the masked labels is provided,
#    the loss is the MLM cross-entropy, which measures the likelihood
#    of the predictions for the masked words.
# 2. If encoder hidden states are provided, we are in a causal scenario where we
#    try to predict the next token for each input in the decoder.
if
masked_lm_labels
is
not
None
and
encoder_hidden_states
is
not
None
:
raise
AttributeError
(
"Masked LM training with an encoder-decoder is not supported."
)
if
masked_lm_labels
is
not
None
:
if
masked_lm_labels
is
not
None
:
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
# -1 index = padding token
masked_lm_loss
=
loss_fct
(
prediction_scores
.
view
(
-
1
,
self
.
config
.
vocab_size
),
masked_lm_labels
.
view
(
-
1
))
masked_lm_loss
=
loss_fct
(
prediction_scores
.
view
(
-
1
,
self
.
config
.
vocab_size
),
masked_lm_labels
.
view
(
-
1
))
outputs
=
(
masked_lm_loss
,)
+
outputs
outputs
=
(
masked_lm_loss
,)
+
outputs
if
encoder_hidden_states
is
not
None
:
if
encoder_hidden_states
is
not
None
:
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
# we are doing next-token prediction; shift prediction scores and input ids by one
# shift predictions scores and input ids by one before computing loss
prediction_scores
=
prediction_scores
[:,
:
-
1
,
:]
prediction_scores
=
prediction_scores
[:,
:
-
1
,
:]
input_ids
=
input_ids
[:,
1
:,
:]
input_ids
=
input_ids
[:,
1
:,
:]
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
seq2seq_loss
=
loss_fct
(
prediction_scores
.
view
(
-
1
,
self
.
config
.
vocab_size
),
input_ids
.
view
(
-
1
))
seq2seq_loss
=
loss_fct
(
prediction_scores
.
view
(
-
1
,
self
.
config
.
vocab_size
),
input_ids
.
view
(
-
1
))
outputs
=
(
seq2seq_loss
,)
+
outputs
outputs
=
(
seq2seq_loss
,)
+
outputs
return
outputs
# (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
return
outputs
# (mlm_or_seq2seq_loss), prediction_scores, (hidden_states), (attentions)
@
add_start_docstrings
(
"""Bert Model with a `next sentence prediction (classification)` head on top. """
,
@
add_start_docstrings
(
"""Bert Model with a `next sentence prediction (classification)` head on top. """
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment