"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "fbd879219584e4a3b438d41306d68a6d85ce3da8"
Commit 4e0f2434 authored by Rémi Louf

document the MLM modification + raise exception on MLM training with encoder-decoder

parent 624a5644
```diff
@@ -830,21 +830,30 @@ class BertForMaskedLM(BertPreTrainedModel):
         prediction_scores = self.cls(sequence_output)
         outputs = (prediction_scores,) + outputs[2:]  # Add hidden states and attention if they are here
 
+        # Although this may seem awkward, BertForMaskedLM supports two scenarios:
+        # 1. If a tensor that contains the indices of masked labels is provided,
+        #    the cross-entropy is the MLM cross-entropy that measures the likelihood
+        #    of predictions for masked words.
+        # 2. If encoder hidden states are provided we are in a causal situation where we
+        #    try to predict the next word for each input in the encoder.
+        if masked_lm_labels is not None and encoder_hidden_states is not None:
+            raise AttributeError("Masked LM training with an encoder-decoder is not supported.")
+
         if masked_lm_labels is not None:
-            loss_fct = CrossEntropyLoss(ignore_index=-1)
+            loss_fct = CrossEntropyLoss(ignore_index=-1)  # -1 index = padding token
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
             outputs = (masked_lm_loss,) + outputs
 
         if encoder_hidden_states is not None:
-            loss_fct = CrossEntropyLoss(ignore_index=-1)
-            # shift predictions scores and input ids by one before computing loss
+            # we are doing next-token prediction; shift prediction scores and input ids by one
             prediction_scores = prediction_scores[:, :-1, :]
             input_ids = input_ids[:, 1:, :]
+            loss_fct = CrossEntropyLoss(ignore_index=-1)
             seq2seq_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), input_ids.view(-1))
             outputs = (seq2seq_loss,) + outputs
 
-        return outputs  # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
+        return outputs  # (mlm_or_seq2seq_loss), prediction_scores, (hidden_states), (attentions)
 
 @add_start_docstrings("""Bert Model with a `next sentence prediction (classification)` head on top. """,
...
```
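The new guard is the behavioral change in this commit: before it, passing both `masked_lm_labels` and `encoder_hidden_states` would silently run both `if` branches, computing both losses and prepending both to `outputs`; now the ambiguous combination fails fast. A minimal sketch of the resulting dispatch, where `masked_lm_forward` is a hypothetical stand-in for the real `forward` method (it mirrors only the mutual-exclusion check, not the loss computation):

```python
# Hypothetical, stripped-down stand-in for BertForMaskedLM.forward's dispatch.
def masked_lm_forward(masked_lm_labels=None, encoder_hidden_states=None):
    if masked_lm_labels is not None and encoder_hidden_states is not None:
        raise AttributeError("Masked LM training with an encoder-decoder is not supported.")
    if masked_lm_labels is not None:
        return "mlm mode"
    if encoder_hidden_states is not None:
        return "causal/seq2seq mode"
    return "no loss, prediction scores only"

print(masked_lm_forward(masked_lm_labels=object()))       # mlm mode
print(masked_lm_forward(encoder_hidden_states=object()))  # causal/seq2seq mode
try:
    masked_lm_forward(masked_lm_labels=object(), encoder_hidden_states=object())
except AttributeError as err:
    print(err)  # Masked LM training with an encoder-decoder is not supported.
```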
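In the MLM branch, the comment added to `CrossEntropyLoss(ignore_index=-1)` documents the labels convention: every position whose label is `-1` (in practice, every position that was not masked, not just padding) contributes nothing to the loss. A plain-PyTorch sketch of that convention; the shapes and token ids here are illustrative, not taken from the patch:

```python
import torch
from torch.nn import CrossEntropyLoss

# Illustrative shapes: batch of 2, sequence length 5, toy vocabulary of 10.
batch_size, seq_len, vocab_size = 2, 5, 10
prediction_scores = torch.randn(batch_size, seq_len, vocab_size)

# masked_lm_labels: -1 everywhere except at masked positions, which carry
# the id of the original (pre-masking) token.
masked_lm_labels = torch.full((batch_size, seq_len), -1, dtype=torch.long)
masked_lm_labels[0, 2] = 7  # position 2 of sequence 0 was masked
masked_lm_labels[1, 4] = 3  # position 4 of sequence 1 was masked

# ignore_index=-1: only the two masked positions enter the cross-entropy.
loss_fct = CrossEntropyLoss(ignore_index=-1)
masked_lm_loss = loss_fct(
    prediction_scores.view(-1, vocab_size),  # (batch * seq_len, vocab)
    masked_lm_labels.view(-1),               # (batch * seq_len,)
)
print(masked_lm_loss)
```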
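The causal branch only reorders the loss construction; the pre-existing shift logic is untouched. Note that `input_ids` is a 2D `(batch, seq_len)` tensor in BERT's `forward`, so the three-axis slice `input_ids[:, 1:, :]` left in place by this hunk looks like a latent bug; the sketch below assumes the intended `input_ids[:, 1:]`. It also uses `reshape` rather than `view`, since slicing along the sequence axis produces non-contiguous tensors that `view` would reject:

```python
import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_len, vocab_size = 2, 5, 10
prediction_scores = torch.randn(batch_size, seq_len, vocab_size)
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))

# Next-token prediction: the score at position i is matched against the
# token at position i + 1, so drop the last score and the first token id.
shifted_scores = prediction_scores[:, :-1, :]  # (batch, seq_len - 1, vocab)
shifted_ids = input_ids[:, 1:]                 # (batch, seq_len - 1); 2D slice

loss_fct = CrossEntropyLoss(ignore_index=-1)
# reshape handles the non-contiguous slices above; .contiguous().view(...)
# would work as well.
seq2seq_loss = loss_fct(
    shifted_scores.reshape(-1, vocab_size),
    shifted_ids.reshape(-1),
)
print(seq2seq_loss)
```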