Unverified Commit 6f68d559 authored by Thomas Wolf, committed by GitHub

Merge pull request #2130 from huggingface/ignored-index-coherence

[BREAKING CHANGE] Setting all ignored index to the PyTorch standard
parents 18601c3b dc667ce1
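Context for the change: PyTorch's `nn.CrossEntropyLoss` (and `nn.NLLLoss`) already uses `-100` as its default `ignore_index`, so standardizing every label sentinel on `-100` lets the loss functions in the diff below be instantiated with no arguments. A minimal sketch (not part of this diff) of the equivalence:

```python
# Minimal check that -100 is PyTorch's built-in ignore_index default.
import torch
import torch.nn as nn

logits = torch.randn(4, 10)                 # 4 positions, vocab of 10
labels = torch.tensor([3, -100, 7, -100])   # -100 marks positions to skip

default_fct = nn.CrossEntropyLoss()                    # ignore_index=-100 by default
explicit_fct = nn.CrossEntropyLoss(ignore_index=-100)
assert torch.allclose(default_fct(logits, labels), explicit_fct(logits, labels))
```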
@@ -75,7 +75,7 @@ def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, d
         n_batch = len(dataset)
         input_ids = np.zeros((n_batch, 2, input_len), dtype=np.int64)
         mc_token_ids = np.zeros((n_batch, 2), dtype=np.int64)
-        lm_labels = np.full((n_batch, 2, input_len), fill_value=-1, dtype=np.int64)
+        lm_labels = np.full((n_batch, 2, input_len), fill_value=-100, dtype=np.int64)
         mc_labels = np.zeros((n_batch,), dtype=np.int64)
         for i, (story, cont1, cont2, mc_label), in enumerate(dataset):
             with_cont1 = [start_token] + story[:cap_length] + [delimiter_token] + cont1[:cap_length] + [clf_token]
...
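The hunk above pre-fills the label buffer with the ignore value, so any slot never overwritten by a real token id drops out of the loss. A toy sketch of the pattern (sizes and values are illustrative):

```python
import numpy as np

n_batch, n_choices, input_len = 2, 2, 8
lm_labels = np.full((n_batch, n_choices, input_len), fill_value=-100, dtype=np.int64)

encoded = [101, 42, 17, 102]               # a short encoded sequence
lm_labels[0, 0, :len(encoded)] = encoded   # real ids only where tokens exist
# positions after len(encoded) stay at -100 and are ignored by the loss
```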
@@ -112,7 +112,7 @@ class Distiller:
         self.last_log = 0
         self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
-        self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
+        self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
         if self.alpha_mse > 0.:
             self.mse_loss_fct = nn.MSELoss(reduction='sum')
         if self.alpha_cos > 0.:
@@ -186,7 +186,7 @@ class Distiller:
         -------
         token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
         attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
-        mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels. There is a -1 where there is nothing to predict.
+        mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels. There is a -100 where there is nothing to predict.
         """
         token_ids, lengths = batch
         token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
@@ -224,7 +224,7 @@ class Distiller:
         _token_ids = _token_ids_mask * (probs == 0).long() + _token_ids_real * (probs == 1).long() + _token_ids_rand * (probs == 2).long()
         token_ids = token_ids.masked_scatter(pred_mask, _token_ids)
-        mlm_labels[~pred_mask] = -1  # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility
+        mlm_labels[~pred_mask] = -100  # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility
         # sanity checks
         assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size
@@ -246,7 +246,7 @@ class Distiller:
         -------
         token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
         attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
-        clm_labels: `torch.tensor(bs, seq_length)` - The causal language modeling labels. There is a -1 where there is nothing to predict.
+        clm_labels: `torch.tensor(bs, seq_length)` - The causal language modeling labels. There is a -100 where there is nothing to predict.
         """
         token_ids, lengths = batch
         token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
@@ -254,7 +254,7 @@ class Distiller:
         attn_mask = (torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None])
         clm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
-        clm_labels[~attn_mask] = -1  # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility
+        clm_labels[~attn_mask] = -100  # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility
         # sanity checks
         assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size
...
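Both Distiller hunks rely on the same idiom: build a boolean mask of the positions that should contribute to the loss, then write `-100` everywhere else. A hedged sketch with toy values:

```python
import torch

token_ids = torch.tensor([[5, 9, 3, 0, 0]])  # toy batch, trailing 0s are padding
lengths = torch.tensor([3])

# True for real tokens, False for padding (same construction as above)
attn_mask = torch.arange(token_ids.size(1)) < lengths[:, None]
clm_labels = token_ids.clone()
clm_labels[~attn_mask] = -100
# clm_labels == tensor([[5, 9, 3, -100, -100]])
```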
@@ -150,7 +150,7 @@ def mask_tokens(inputs, tokenizer, args):
     special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()]
     probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
     masked_indices = torch.bernoulli(probability_matrix).bool()
-    labels[~masked_indices] = -1  # We only compute loss on masked tokens
+    labels[~masked_indices] = -100  # We only compute loss on masked tokens
     # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
     indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
...
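For reference, a compact, self-contained sketch of the `mask_tokens` logic above (the 80/10/10 split); `mask_token_id` and `vocab_size` are illustrative stand-ins for the tokenizer attributes used in the real function:

```python
import torch

def sketch_mask_tokens(inputs, mask_token_id, vocab_size, mlm_probability=0.15):
    labels = inputs.clone()
    masked_indices = torch.bernoulli(torch.full(labels.shape, mlm_probability)).bool()
    labels[~masked_indices] = -100  # we only compute loss on masked tokens

    # 80% of masked positions become [MASK]
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = mask_token_id

    # 10% become a random token (half of the remaining 20%)
    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(vocab_size, labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    # the final 10% keep the original token, but are still predicted
    return inputs, labels
```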
@@ -94,7 +94,7 @@ def convert_examples_to_features(examples,
                                  pad_on_left=False,
                                  pad_token=0,
                                  pad_token_segment_id=0,
-                                 pad_token_label_id=-1,
+                                 pad_token_label_id=-100,
                                  sequence_a_segment_id=0,
                                  mask_padding_with_zero=True):
     """ Loads a data file into a list of `InputBatch`s
...
@@ -364,7 +364,7 @@ class XxxForMaskedLM(XxxPreTrainedModel):
         **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
             Labels for computing the masked language modeling loss.
             Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
-            Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
+            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
             in ``[0, ..., config.vocab_size]``
     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
@@ -415,7 +415,7 @@ class XxxForMaskedLM(XxxPreTrainedModel):
         outputs = (prediction_scores,) + outputs[2:]  # Add hidden states and attention if they are here
         if masked_lm_labels is not None:
-            loss_fct = CrossEntropyLoss(ignore_index=-1)
+            loss_fct = CrossEntropyLoss()
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
             outputs = (masked_lm_loss,) + outputs
...
@@ -572,7 +572,7 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
         **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
             Labels for computing the masked language modeling loss.
             Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
-            Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
+            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
             in ``[0, ..., config.vocab_size]``
     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
@@ -624,7 +624,7 @@ class AlbertForMaskedLM(AlbertPreTrainedModel):
         outputs = (prediction_scores,) + outputs[2:]  # Add hidden states and attention if they are here
         if masked_lm_labels is not None:
-            loss_fct = CrossEntropyLoss(ignore_index=-1)
+            loss_fct = CrossEntropyLoss()
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
             outputs = (masked_lm_loss,) + outputs
...
@@ -754,7 +754,7 @@ class BertForPreTraining(BertPreTrainedModel):
         **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
             Labels for computing the masked language modeling loss.
             Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
-            Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
+            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
             in ``[0, ..., config.vocab_size]``
         **next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
             Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
@@ -813,7 +813,7 @@ class BertForPreTraining(BertPreTrainedModel):
         outputs = (prediction_scores, seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here
         if masked_lm_labels is not None and next_sentence_label is not None:
-            loss_fct = CrossEntropyLoss(ignore_index=-1)
+            loss_fct = CrossEntropyLoss()
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
             next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
             total_loss = masked_lm_loss + next_sentence_loss
@@ -830,12 +830,12 @@ class BertForMaskedLM(BertPreTrainedModel):
         **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
             Labels for computing the masked language modeling loss.
             Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
-            Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
+            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
             in ``[0, ..., config.vocab_size]``
         **lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
             Labels for computing the left-to-right language modeling loss (next word prediction).
             Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
-            Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
+            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
             in ``[0, ..., config.vocab_size]``
     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
@@ -897,7 +897,7 @@ class BertForMaskedLM(BertPreTrainedModel):
         # 2. If `lm_labels` is provided we are in a causal scenario where we
         #    try to predict the next token for each input in the decoder.
         if masked_lm_labels is not None:
-            loss_fct = CrossEntropyLoss(ignore_index=-1)  # -1 index = padding token
+            loss_fct = CrossEntropyLoss()  # -100 index = padding token
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
             outputs = (masked_lm_loss,) + outputs
@@ -905,7 +905,7 @@ class BertForMaskedLM(BertPreTrainedModel):
             # we are doing next-token prediction; shift prediction scores and input ids by one
             prediction_scores = prediction_scores[:, :-1, :].contiguous()
             lm_labels = lm_labels[:, 1:].contiguous()
-            loss_fct = CrossEntropyLoss(ignore_index=-1)
+            loss_fct = CrossEntropyLoss()
             ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1))
             outputs = (ltr_lm_loss,) + outputs
@@ -969,7 +969,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
         outputs = (seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here
         if next_sentence_label is not None:
-            loss_fct = CrossEntropyLoss(ignore_index=-1)
+            loss_fct = CrossEntropyLoss()
             next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
             outputs = (next_sentence_loss,) + outputs
...
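A hedged usage sketch for the updated BERT heads (written against the tuple-returning 2.x API of this era): every label is set to `-100` except the masked position, so the returned loss covers only that token.

```python
import torch
from transformers import BertForMaskedLM, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')

input_ids = torch.tensor([tokenizer.encode("Paris is the [MASK] of France.",
                                           add_special_tokens=True)])
masked_lm_labels = torch.full_like(input_ids, -100)   # ignore everything...
mask_positions = input_ids == tokenizer.mask_token_id
masked_lm_labels[mask_positions] = tokenizer.convert_tokens_to_ids('capital')

loss, prediction_scores = model(input_ids, masked_lm_labels=masked_lm_labels)[:2]
```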
@@ -156,7 +156,7 @@ class CamembertForMaskedLM(RobertaForMaskedLM):
         **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
             Labels for computing the masked language modeling loss.
             Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
-            Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
+            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
             in ``[0, ..., config.vocab_size]``
     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
...
@@ -429,7 +429,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
             Labels for language modeling.
             Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
             Indices are selected in ``[-1, 0, ..., config.vocab_size]``
-            All labels set to ``-1`` are ignored (masked), the loss is only
+            All labels set to ``-100`` are ignored (masked), the loss is only
             computed for labels in ``[0, ..., config.vocab_size]``
     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
@@ -494,7 +494,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
             shift_logits = lm_logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
             # Flatten the tokens
-            loss_fct = CrossEntropyLoss(ignore_index=-1)
+            loss_fct = CrossEntropyLoss()
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                             shift_labels.view(-1))
             outputs = (loss,) + outputs
...
@@ -491,7 +491,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
         **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
             Labels for computing the masked language modeling loss.
             Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
-            Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
+            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
             in ``[0, ..., config.vocab_size]``
     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
@@ -528,7 +528,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
         self.init_weights()
-        self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
+        self.mlm_loss_fct = nn.CrossEntropyLoss()
     def get_output_embeddings(self):
         return self.vocab_projector
...
@@ -494,7 +494,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
             Labels for language modeling.
             Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
             Indices are selected in ``[-1, 0, ..., config.vocab_size]``
-            All labels set to ``-1`` are ignored (masked), the loss is only
+            All labels set to ``-100`` are ignored (masked), the loss is only
             computed for labels in ``[0, ..., config.vocab_size]``
     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
@@ -557,7 +557,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
             shift_logits = lm_logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
             # Flatten the tokens
-            loss_fct = CrossEntropyLoss(ignore_index=-1)
+            loss_fct = CrossEntropyLoss()
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                             shift_labels.view(-1))
             outputs = (loss,) + outputs
@@ -579,7 +579,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
             Labels for language modeling.
             Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
             Indices are selected in ``[-1, 0, ..., config.vocab_size]``
-            All labels set to ``-1`` are ignored (masked), the loss is only
+            All labels set to ``-100`` are ignored (masked), the loss is only
             computed for labels in ``[0, ..., config.vocab_size]``
         **mc_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size)``:
             Labels for computing the multiple choice classification loss.
@@ -668,7 +668,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         if lm_labels is not None:
             shift_logits = lm_logits[..., :-1, :].contiguous()
             shift_labels = lm_labels[..., 1:].contiguous()
-            loss_fct = CrossEntropyLoss(ignore_index=-1)
+            loss_fct = CrossEntropyLoss()
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                             shift_labels.view(-1))
             outputs = (loss,) + outputs
...
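The causal-LM hunks above all share the same shift-and-flatten pattern; after this change the flattened loss leans on the default `ignore_index` instead of passing it explicitly. A toy sketch (shapes are illustrative):

```python
import torch
from torch.nn import CrossEntropyLoss

bs, seq_len, vocab = 2, 6, 100
lm_logits = torch.randn(bs, seq_len, vocab)
labels = torch.randint(vocab, (bs, seq_len))
labels[:, -2:] = -100                        # e.g. a padded tail to be ignored

# score logits at position t against the token at t+1
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

loss_fct = CrossEntropyLoss()                # ignore_index=-100 is the default
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
```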
@@ -471,7 +471,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
             Labels for language modeling.
             Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``
             Indices are selected in ``[-1, 0, ..., config.vocab_size]``
-            All labels set to ``-1`` are ignored (masked), the loss is only
+            All labels set to ``-100`` are ignored (masked), the loss is only
             computed for labels in ``[0, ..., config.vocab_size]``
     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
@@ -523,7 +523,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
             shift_logits = lm_logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
             # Flatten the tokens
-            loss_fct = CrossEntropyLoss(ignore_index=-1)
+            loss_fct = CrossEntropyLoss()
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                             shift_labels.view(-1))
             outputs = (loss,) + outputs
@@ -545,7 +545,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
             Labels for language modeling.
             Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
             Indices are selected in ``[-1, 0, ..., config.vocab_size]``
-            All labels set to ``-1`` are ignored (masked), the loss is only
+            All labels set to ``-100`` are ignored (masked), the loss is only
             computed for labels in ``[0, ..., config.vocab_size]``
         **mc_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size)``:
             Labels for computing the multiple choice classification loss.
@@ -622,7 +622,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
         if lm_labels is not None:
             shift_logits = lm_logits[..., :-1, :].contiguous()
             shift_labels = lm_labels[..., 1:].contiguous()
-            loss_fct = CrossEntropyLoss(ignore_index=-1)
+            loss_fct = CrossEntropyLoss()
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                             shift_labels.view(-1))
             outputs = (loss,) + outputs
...
@@ -216,7 +216,7 @@ class RobertaForMaskedLM(BertPreTrainedModel):
         **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
             Labels for computing the masked language modeling loss.
             Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
-            Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
+            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
             in ``[0, ..., config.vocab_size]``
     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
@@ -270,7 +270,7 @@ class RobertaForMaskedLM(BertPreTrainedModel):
         outputs = (prediction_scores,) + outputs[2:]  # Add hidden states and attention if they are here
         if masked_lm_labels is not None:
-            loss_fct = CrossEntropyLoss(ignore_index=-1)
+            loss_fct = CrossEntropyLoss()
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
             outputs = (masked_lm_loss,) + outputs
...
@@ -250,12 +250,6 @@ class TFRobertaLMHead(tf.keras.layers.Layer):
                       ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
 class TFRobertaForMaskedLM(TFRobertaPreTrainedModel):
     r"""
-    **masked_lm_labels**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
-        Labels for computing the masked language modeling loss.
-        Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
-        Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
-        in ``[0, ..., config.vocab_size]``
     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
         **loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``tf.Tensor`` of shape ``(1,)``:
             Masked language modeling loss.
...
@@ -796,17 +796,17 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
                       TRANSFO_XL_START_DOCSTRING, TRANSFO_XL_INPUTS_DOCSTRING)
 class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
     r"""
-    **lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
+    **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
         Labels for language modeling.
-        Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
+        Note that the labels **are shifted** inside the model, i.e. you can set ``labels = input_ids``
         Indices are selected in ``[-1, 0, ..., config.vocab_size]``
-        All labels set to ``-1`` are ignored (masked), the loss is only
+        All labels set to ``-100`` are ignored (masked), the loss is only
         computed for labels in ``[0, ..., config.vocab_size]``
     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
-        **loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
+        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
             Language modeling loss.
-        **prediction_scores**: ``None`` if ``lm_labels`` is provided else ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
+        **prediction_scores**: ``None`` if ``labels`` is provided else ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
             Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
             We don't output them when the loss is computed to speedup adaptive softmax decoding.
     **mems**:
...
@@ -614,7 +614,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
             Labels for language modeling.
             Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
             Indices are selected in ``[-1, 0, ..., config.vocab_size]``
-            All labels set to ``-1`` are ignored (masked), the loss is only
+            All labels set to ``-100`` are ignored (masked), the loss is only
             computed for labels in ``[0, ..., config.vocab_size]``
     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
...
@@ -898,7 +898,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
             Labels for language modeling.
             Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
             Indices are selected in ``[-1, 0, ..., config.vocab_size]``
-            All labels set to ``-1`` are ignored (masked), the loss is only
+            All labels set to ``-100`` are ignored (masked), the loss is only
             computed for labels in ``[0, ..., config.vocab_size]``
     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
@@ -989,7 +989,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         if labels is not None:
             # Flatten the tokens
-            loss_fct = CrossEntropyLoss(ignore_index=-1)
+            loss_fct = CrossEntropyLoss()
             loss = loss_fct(logits.view(-1, logits.size(-1)),
                             labels.view(-1))
             outputs = (loss,) + outputs
...
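Since this is a breaking change, downstream code that still builds label tensors with the old `-1` sentinel must remap them before calling the updated models; a one-line migration sketch:

```python
import torch

old_labels = torch.tensor([[42, -1, 17, -1]])   # labels built for the pre-#2130 models
new_labels = old_labels.clone()
new_labels[new_labels == -1] = -100             # remap to the PyTorch standard
```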