"...resnet50_tensorflow.git" did not exist on "adc01cd76ae0d9d3b2e8dde3ec6bf4086f7da046"
Unverified Commit 3e07196f authored by Arthur's avatar Arthur Committed by GitHub
Browse files

check decoder_inputs_embeds is None before shifting labels (#19671)

parent d356b89f
...@@ -1352,7 +1352,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode ...@@ -1352,7 +1352,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
labels, labels,
) )
use_cache = False use_cache = False
if decoder_input_ids is None: if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right( decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id labels, self.config.pad_token_id, self.config.decoder_start_token_id
) )
......
...@@ -1319,7 +1319,7 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel): ...@@ -1319,7 +1319,7 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
if use_cache: if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False use_cache = False
if decoder_input_ids is None: if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right( decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id labels, self.config.pad_token_id, self.config.decoder_start_token_id
) )
......
...@@ -1371,7 +1371,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal ...@@ -1371,7 +1371,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
labels, labels,
) )
use_cache = False use_cache = False
if decoder_input_ids is None: if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right( decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id labels, self.config.pad_token_id, self.config.decoder_start_token_id
) )
......
...@@ -1286,7 +1286,7 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel): ...@@ -1286,7 +1286,7 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
if use_cache: if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False use_cache = False
if decoder_input_ids is None: if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right( decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id labels, self.config.pad_token_id, self.config.decoder_start_token_id
) )
......
...@@ -1351,7 +1351,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel ...@@ -1351,7 +1351,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
labels, labels,
) )
use_cache = False use_cache = False
if decoder_input_ids is None: if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right( decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id labels, self.config.pad_token_id, self.config.decoder_start_token_id
) )
......
...@@ -2428,7 +2428,7 @@ class LEDForConditionalGeneration(LEDPreTrainedModel): ...@@ -2428,7 +2428,7 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
if use_cache: if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False use_cache = False
if decoder_input_ids is None: if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right( decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id labels, self.config.pad_token_id, self.config.decoder_start_token_id
) )
......
...@@ -2445,7 +2445,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel): ...@@ -2445,7 +2445,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
if labels is not None: if labels is not None:
use_cache = False use_cache = False
if decoder_input_ids is None: if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right( decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id labels, self.config.pad_token_id, self.config.decoder_start_token_id
) )
......
...@@ -1432,7 +1432,7 @@ class MarianMTModel(MarianPreTrainedModel): ...@@ -1432,7 +1432,7 @@ class MarianMTModel(MarianPreTrainedModel):
if use_cache: if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False use_cache = False
if decoder_input_ids is None: if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right( decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id labels, self.config.pad_token_id, self.config.decoder_start_token_id
) )
......
...@@ -1388,7 +1388,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss): ...@@ -1388,7 +1388,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
labels, labels,
) )
use_cache = False use_cache = False
if decoder_input_ids is None: if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right( decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id labels, self.config.pad_token_id, self.config.decoder_start_token_id
) )
......
...@@ -1347,7 +1347,7 @@ class MBartForConditionalGeneration(MBartPreTrainedModel): ...@@ -1347,7 +1347,7 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
if use_cache: if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False use_cache = False
if decoder_input_ids is None: if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id) decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
outputs = self.model( outputs = self.model(
......
...@@ -1387,7 +1387,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo ...@@ -1387,7 +1387,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
labels, labels,
) )
use_cache = False use_cache = False
if decoder_input_ids is None: if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id) decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
outputs = self.model( outputs = self.model(
......
...@@ -1393,7 +1393,7 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel): ...@@ -1393,7 +1393,7 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
if use_cache: if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False use_cache = False
if decoder_input_ids is None: if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right( decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id labels, self.config.pad_token_id, self.config.decoder_start_token_id
) )
......
...@@ -1397,7 +1397,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua ...@@ -1397,7 +1397,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
labels, labels,
) )
use_cache = False use_cache = False
if decoder_input_ids is None: if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right( decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id labels, self.config.pad_token_id, self.config.decoder_start_token_id
) )
......
...@@ -1605,7 +1605,7 @@ class PegasusXForConditionalGeneration(PegasusXPreTrainedModel): ...@@ -1605,7 +1605,7 @@ class PegasusXForConditionalGeneration(PegasusXPreTrainedModel):
if use_cache: if use_cache:
logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
use_cache = False use_cache = False
if decoder_input_ids is None: if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right( decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id labels, self.config.pad_token_id, self.config.decoder_start_token_id
) )
......
...@@ -1314,7 +1314,7 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel): ...@@ -1314,7 +1314,7 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None: if labels is not None:
if decoder_input_ids is None: if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id) decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
outputs = self.model( outputs = self.model(
......
...@@ -1341,7 +1341,7 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel): ...@@ -1341,7 +1341,7 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None: if labels is not None:
if decoder_input_ids is None: if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right( decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id labels, self.config.pad_token_id, self.config.decoder_start_token_id
) )
......
...@@ -1405,7 +1405,7 @@ class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCaus ...@@ -1405,7 +1405,7 @@ class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCaus
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None: if labels is not None:
if decoder_input_ids is None: if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right( decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id labels, self.config.pad_token_id, self.config.decoder_start_token_id
) )
......
...@@ -1293,7 +1293,7 @@ class TFWhisperForConditionalGeneration(TFWhisperPreTrainedModel, TFCausalLangua ...@@ -1293,7 +1293,7 @@ class TFWhisperForConditionalGeneration(TFWhisperPreTrainedModel, TFCausalLangua
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None: if labels is not None:
if decoder_input_ids is None: if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right( decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id labels, self.config.pad_token_id, self.config.decoder_start_token_id
) )
......
...@@ -1183,7 +1183,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): ...@@ -1183,7 +1183,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None: if labels is not None:
if decoder_input_ids is None: if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right( decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id labels, self.config.pad_token_id, self.config.decoder_start_token_id
) )
......
...@@ -2938,7 +2938,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec ...@@ -2938,7 +2938,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
if labels is not None: if labels is not None:
use_cache = False use_cache = False
if decoder_input_ids is None: if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right( decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id labels, self.config.pad_token_id, self.config.decoder_start_token_id
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment