Unverified Commit 95425d54 authored by Matt, committed by GitHub

Adding prepare_decoder_input_ids_from_labels methods to all ConditionalGeneration TF models (#12560)
parent ebc69afc
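Seen across all six models below, the new method exposes the label-to-decoder-input shift these models already apply internally when only `labels` are passed, so training code no longer has to hand-roll it. A minimal usage sketch (the checkpoint name and toy strings are illustrative examples, not part of this commit):

```python
import tensorflow as tf
from transformers import BartTokenizer, TFBartForConditionalGeneration

model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-base")
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

inputs = tokenizer(["Hello world"], return_tensors="tf")
labels = tokenizer(["Bonjour le monde"], return_tensors="tf").input_ids

# Build teacher-forcing decoder inputs from the labels,
# instead of shifting them manually in the training script.
decoder_input_ids = model.prepare_decoder_input_ids_from_labels(labels=labels)
outputs = model(inputs.input_ids, decoder_input_ids=decoder_input_ids, labels=labels)
print(outputs.loss)
```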
src/transformers/models/bart/modeling_tf_bart.py

@@ -1494,6 +1494,9 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageModelingLoss):
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }

+    def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
+        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+
     @staticmethod
     def _reorder_cache(past, beam_idx):
         if len(past) == 1:
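The BART-style `shift_tokens_right` called above (and by LED, Marian, and Pegasus below) amounts to: map the -100 loss-mask sentinel to the pad id, then prepend `decoder_start_token_id` and drop the last position, so the decoder predicts token t from tokens before t. A standalone sketch of that behavior, not the library code verbatim:

```python
import tensorflow as tf

def shift_tokens_right_sketch(labels: tf.Tensor, pad_token_id: int, decoder_start_token_id: int) -> tf.Tensor:
    # -100 marks positions ignored by the loss; decoder inputs need real ids.
    labels = tf.where(labels == -100, tf.fill(tf.shape(labels), tf.cast(pad_token_id, labels.dtype)), labels)
    # Prepend the start token and drop the final position (teacher forcing).
    start = tf.fill([tf.shape(labels)[0], 1], tf.cast(decoder_start_token_id, labels.dtype))
    return tf.concat([start, labels[:, :-1]], axis=-1)

# e.g. labels [[8, 9, 2, -100]] -> [[<start>, 8, 9, 2]]
```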
src/transformers/models/led/modeling_tf_led.py

@@ -2522,6 +2522,9 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }

+    def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
+        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+
     @staticmethod
     def _reorder_cache(past, beam_idx):
         if len(past) == 1:
src/transformers/models/marian/modeling_tf_marian.py

@@ -1522,6 +1522,9 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }

+    def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
+        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+
     @staticmethod
     # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
     def _reorder_cache(past, beam_idx):
src/transformers/models/mbart/modeling_tf_mbart.py

@@ -1506,6 +1506,9 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageModelingLoss):
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }

+    def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
+        return shift_tokens_right(labels, self.config.pad_token_id)
+
     @staticmethod
     # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
     def _reorder_cache(past, beam_idx):
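Note that the mBART call passes no `decoder_start_token_id`: mBART has no single decoder start token, and its module-local `shift_tokens_right` instead wraps each sequence's last non-pad token (the language-id token placed after the end-of-sequence token) around to position 0. A sketch of that behavior, assuming right-padded labels, not the library code verbatim:

```python
import tensorflow as tf

def mbart_shift_tokens_right_sketch(labels: tf.Tensor, pad_token_id: int) -> tf.Tensor:
    # Map the -100 loss-mask sentinel to the pad id first.
    labels = tf.where(labels == -100, tf.fill(tf.shape(labels), tf.cast(pad_token_id, labels.dtype)), labels)
    # Each sequence starts from its own last non-pad token (the language id)
    # rather than a fixed decoder start token.
    last_nonpad = tf.reduce_sum(tf.cast(labels != pad_token_id, tf.int32), axis=-1) - 1
    start = tf.gather(labels, last_nonpad, batch_dims=1)[:, tf.newaxis]
    return tf.concat([start, labels[:, :-1]], axis=-1)
```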
src/transformers/models/pegasus/modeling_tf_pegasus.py

@@ -1531,6 +1531,9 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLanguageModelingLoss):
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }

+    def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
+        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
+
     @staticmethod
     # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
     def _reorder_cache(past, beam_idx):
src/transformers/models/t5/modeling_tf_t5.py

@@ -1499,6 +1499,9 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModelingLoss):
             "use_cache": use_cache,
         }

+    def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
+        return self._shift_right(labels)
+
     def _reorder_cache(self, past, beam_idx) -> Tuple:
         # if decoder past is not included in output
         # speedy decoding is disabled and no need to reorder
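T5 delegates to its existing `_shift_right` helper rather than a BART-style module function; the effect is the same kind of shift, with the detail that T5's `decoder_start_token_id` is its pad token (id 0 in the released configs). A sketch of that behavior, not the library code verbatim:

```python
import tensorflow as tf

def t5_shift_right_sketch(labels: tf.Tensor, decoder_start_token_id: int, pad_token_id: int) -> tf.Tensor:
    # Prepend the start token (T5 reuses the pad token here) and drop the
    # last position, then map the -100 loss-mask sentinel to a real id.
    start = tf.fill([tf.shape(labels)[0], 1], tf.cast(decoder_start_token_id, labels.dtype))
    shifted = tf.concat([start, labels[:, :-1]], axis=-1)
    return tf.where(shifted == -100, tf.fill(tf.shape(shifted), tf.cast(pad_token_id, labels.dtype)), shifted)
```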