"vscode:/vscode.git/clone" did not exist on "367f96937e2fc62521649cbb789d06860c1f2777"
Unverified commit 54659048, authored by amyeroberts and committed by GitHub

Early labels validation (#31240)

* Move label validation checks so they fail early

* Remove some formatting changes; add back the labels change for wav2vec2
parent 03ea1609
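The pattern is the same in every file touched below: the labels check moves out of the loss branch (which only runs after the full encoder forward pass) to the top of `forward()`, so invalid labels raise immediately instead of after the expensive compute. A minimal sketch of that pattern, using a hypothetical `ToyCTCHead` with a stand-in backbone and a plain cross-entropy loss; none of these names come from the diff:

```python
# Illustrative sketch only; ToyCTCHead, self.backbone, and the shapes are assumptions,
# not code from this commit.
from typing import Optional

import torch
from torch import nn


class ToyCTCHead(nn.Module):
    def __init__(self, vocab_size: int, hidden_size: int = 32):
        super().__init__()
        self.vocab_size = vocab_size
        self.backbone = nn.Linear(hidden_size, hidden_size)  # stand-in for the expensive encoder
        self.lm_head = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_values: torch.Tensor, labels: Optional[torch.Tensor] = None):
        # Fail early: the same check that previously ran inside the loss branch,
        # i.e. only after the whole encoder forward pass had already completed.
        if labels is not None and labels.max() >= self.vocab_size:
            raise ValueError(f"Label values must be <= vocab_size: {self.vocab_size}")

        hidden_states = self.backbone(input_values)  # the costly part we now skip on bad labels
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # The loss computation stays where it was; only the validation moved.
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(logits.view(-1, self.vocab_size), labels.view(-1))
        return logits, loss
```

With this sketch, `ToyCTCHead(vocab_size=10)(torch.randn(2, 5, 32), labels=torch.full((2, 5), 99))` raises the `ValueError` before the backbone runs, which is the behavior the diff gives the real models.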
@@ -1575,9 +1575,11 @@ class SEWDForCTC(SEWDPreTrainedModel):
             All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
             config.vocab_size - 1]`.
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if labels is not None and labels.max() >= self.config.vocab_size:
+            raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
         outputs = self.sew_d(
             input_values,
             attention_mask=attention_mask,
@@ -1593,9 +1595,6 @@ class SEWDForCTC(SEWDPreTrainedModel):
         loss = None
         if labels is not None:
-            if labels.max() >= self.config.vocab_size:
-                raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
             # retrieve loss input_lengths from attention_mask
             attention_mask = (
                 attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
@@ -1128,6 +1128,10 @@ class Swin2SRForImageSuperResolution(Swin2SRPreTrainedModel):
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        loss = None
+        if labels is not None:
+            raise NotImplementedError("Training is not supported at the moment")
         height, width = pixel_values.shape[2:]
         if self.config.upsampler == "pixelshuffle_aux":
@@ -1159,10 +1163,6 @@ class Swin2SRForImageSuperResolution(Swin2SRPreTrainedModel):
         reconstruction = reconstruction / self.swin2sr.img_range + self.swin2sr.mean
         reconstruction = reconstruction[:, :, : height * self.upscale, : width * self.upscale]
-        loss = None
-        if labels is not None:
-            raise NotImplementedError("Training is not supported at the moment")
         if not return_dict:
             output = (reconstruction,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
@@ -1824,9 +1824,11 @@ class UniSpeechForCTC(UniSpeechPreTrainedModel):
             All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
             config.vocab_size - 1]`.
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if labels is not None and labels.max() >= self.config.vocab_size:
+            raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
         outputs = self.unispeech(
             input_values,
             attention_mask=attention_mask,
@@ -1842,9 +1844,6 @@ class UniSpeechForCTC(UniSpeechPreTrainedModel):
         loss = None
         if labels is not None:
-            if labels.max() >= self.config.vocab_size:
-                raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
             # retrieve loss input_lengths from attention_mask
             attention_mask = (
                 attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
@@ -1834,9 +1834,11 @@ class UniSpeechSatForCTC(UniSpeechSatPreTrainedModel):
             All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
             config.vocab_size - 1]`.
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if labels is not None and labels.max() >= self.config.vocab_size:
+            raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
         outputs = self.unispeech_sat(
             input_values,
             attention_mask=attention_mask,
@@ -1852,9 +1854,6 @@ class UniSpeechSatForCTC(UniSpeechSatPreTrainedModel):
         loss = None
         if labels is not None:
-            if labels.max() >= self.config.vocab_size:
-                raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
             # retrieve loss input_lengths from attention_mask
             attention_mask = (
                 attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
@@ -392,6 +392,8 @@ class UperNetForSemanticSegmentation(UperNetPreTrainedModel):
         >>> list(logits.shape)
         [1, 150, 512, 512]
         ```"""
+        if labels is not None and self.config.num_labels == 1:
+            raise ValueError("The number of labels should be greater than one")
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         output_hidden_states = (
@@ -416,9 +418,6 @@ class UperNetForSemanticSegmentation(UperNetPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.config.num_labels == 1:
-                raise ValueError("The number of labels should be greater than one")
-            else:
             # compute weighted loss
             loss_fct = CrossEntropyLoss(ignore_index=self.config.loss_ignore_index)
             loss = loss_fct(logits, labels)
@@ -1226,6 +1226,10 @@ class ViltForImageAndTextRetrieval(ViltPreTrainedModel):
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        loss = None
+        if labels is not None:
+            raise NotImplementedError("Training is not yet supported.")
         outputs = self.vilt(
             input_ids,
             attention_mask=attention_mask,
@@ -1244,12 +1248,6 @@ class ViltForImageAndTextRetrieval(ViltPreTrainedModel):
         logits = self.rank_output(pooler_output)
-        loss = None
-        if labels is not None:
-            # move labels to correct device to enable PP
-            labels = labels.to(logits.device)
-            raise NotImplementedError("Training is not yet supported.")
         if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
@@ -939,6 +939,14 @@ class VisualBertForPreTraining(VisualBertPreTrainedModel):
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if labels is not None:
+            total_size = attention_mask.size(-1) + visual_attention_mask.size(-1)
+            if labels.size(-1) != total_size:
+                raise ValueError(
+                    "The labels provided should have same sequence length as total attention mask. "
+                    f"Found labels with sequence length {labels.size(-1)}, expected {total_size}."
+                )
         outputs = self.visual_bert(
             input_ids,
             attention_mask=attention_mask,
@@ -960,26 +968,12 @@ class VisualBertForPreTraining(VisualBertPreTrainedModel):
         total_loss = None
         if labels is not None and sentence_image_labels is not None:
-            total_size = attention_mask.size(-1) + visual_attention_mask.size(-1)
-            if labels.size(-1) != total_size:
-                raise ValueError(
-                    "The labels provided should have same sequence length as total attention mask. "
-                    f"Found labels with sequence length {labels.size(-1)}, expected {total_size}."
-                )
             loss_fct = CrossEntropyLoss()
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
             sentence_image_loss = loss_fct(seq_relationship_score.view(-1, 2), sentence_image_labels.view(-1))
             total_loss = masked_lm_loss + sentence_image_loss
-        if labels is not None and sentence_image_labels is None:
-            total_size = attention_mask.size(-1) + visual_attention_mask.size(-1)
-            if labels.size(-1) != total_size:
-                raise ValueError(
-                    "The labels provided should have same sequence length as total attention mask. "
-                    f"Found labels with sequence length {labels.size(-1)}, expected {total_size}."
-                )
+        elif labels is not None:
             loss_fct = CrossEntropyLoss()
             total_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
@@ -310,6 +310,10 @@ class VitMatteForImageMatting(VitMattePreTrainedModel):
         )
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        loss = None
+        if labels is not None:
+            raise NotImplementedError("Training is not yet supported")
         outputs = self.backbone.forward_with_filtered_kwargs(
             pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions
         )
@@ -317,10 +321,6 @@ class VitMatteForImageMatting(VitMattePreTrainedModel):
         features = outputs.feature_maps[-1]
         alphas = self.decoder(features, pixel_values)
-        loss = None
-        if labels is not None:
-            raise NotImplementedError("Training is not yet supported")
         if not return_dict:
             output = (alphas,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
@@ -1394,6 +1394,9 @@ class VitsModel(VitsPreTrainedModel):
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if labels is not None:
+            raise NotImplementedError("Training of VITS is not supported yet.")
         if attention_mask is not None:
             input_padding_mask = attention_mask.unsqueeze(-1).float()
         else:
@@ -1408,9 +1411,6 @@ class VitsModel(VitsPreTrainedModel):
         else:
             speaker_embeddings = None
-        if labels is not None:
-            raise NotImplementedError("Training of VITS is not supported yet.")
         text_encoder_output = self.text_encoder(
             input_ids=input_ids,
             padding_mask=input_padding_mask,
@@ -1671,6 +1671,8 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
         >>> loss = model(input_values, labels=labels).loss
         ```"""
+        if labels is not None and tf.reduce_max(labels) >= self.config.vocab_size:
+            raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
         outputs = self.wav2vec2(
             input_values=input_values,
@@ -1690,9 +1692,6 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
         logits = self.lm_head(hidden_states)
         if labels is not None:
-            if tf.reduce_max(labels) >= self.config.vocab_size:
-                raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
             attention_mask = (
                 attention_mask if attention_mask is not None else tf.ones_like(input_values, dtype=tf.float32)
             )
@@ -2327,9 +2327,11 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
             All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
             config.vocab_size - 1]`.
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if labels is not None and labels.max() >= self.config.vocab_size:
+            raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
         outputs = self.wav2vec2(
             input_values,
             attention_mask=attention_mask,
@@ -2345,9 +2347,6 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
         loss = None
         if labels is not None:
-            if labels.max() >= self.config.vocab_size:
-                raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
             # retrieve loss input_lengths from attention_mask
             attention_mask = (
                 attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
@@ -1219,6 +1219,8 @@ class Wav2Vec2BertForCTC(Wav2Vec2BertPreTrainedModel):
             All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
             config.vocab_size - 1]`.
         """
+        if labels is not None and labels.max() >= self.config.vocab_size:
+            raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -1237,9 +1239,6 @@ class Wav2Vec2BertForCTC(Wav2Vec2BertPreTrainedModel):
         loss = None
         if labels is not None:
-            if labels.max() >= self.config.vocab_size:
-                raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
             # retrieve loss input_lengths from attention_mask
             attention_mask = (
                 attention_mask
@@ -1645,9 +1645,11 @@ class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
             All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
             config.vocab_size - 1]`.
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if labels is not None and labels.max() >= self.config.vocab_size:
+            raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
         outputs = self.wav2vec2_conformer(
             input_values,
             attention_mask=attention_mask,
@@ -1663,9 +1665,6 @@ class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel):
         loss = None
         if labels is not None:
-            if labels.max() >= self.config.vocab_size:
-                raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
             # retrieve loss input_lengths from attention_mask
             attention_mask = (
                 attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
@@ -1349,9 +1349,11 @@ class WavLMForCTC(WavLMPreTrainedModel):
             All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
             config.vocab_size - 1]`.
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if labels is not None and labels.max() >= self.config.vocab_size:
+            raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
         outputs = self.wavlm(
             input_values,
             attention_mask=attention_mask,
@@ -1367,9 +1369,6 @@ class WavLMForCTC(WavLMPreTrainedModel):
         loss = None
         if labels is not None:
-            if labels.max() >= self.config.vocab_size:
-                raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
             # retrieve loss input_lengths from attention_mask
             attention_mask = (
                 attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)