Unverified Commit 4fd89e49 authored by Joao Gante, committed by GitHub

Generate: delete unused TF `_reorder_cache` (#20964)

parent a3e8d3cb
@@ -449,10 +449,6 @@ class TFGenerationMixin:
     supports_xla_generation = True

-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        return tuple(tf.gather(layer_past, beam_idx, axis=1) for layer_past in past)
-
     def adjust_logits_during_generation(
         self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs
     ):
...
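Background on what this default did: during beam search, row *i* of every cached key/value tensor must track the *i*-th surviving beam after each step, so the cache is re-indexed with the chosen beam indices. A minimal, self-contained sketch of the same `tf.gather` pattern (shapes are illustrative, not taken from the diff):

```python
import tensorflow as tf

# Toy cache: 2 layers, each (seq_len=3, num_beams=4, hidden=8). The deleted
# mixin default assumed the beam dimension on axis 1, hence axis=1 below.
past = tuple(tf.random.normal((3, 4, 8)) for _ in range(2))

# After a step, new beams 0 and 1 both descend from old beam 0, new beam 2
# from old beam 2, and new beam 3 from old beam 1.
beam_idx = tf.constant([0, 0, 2, 1])

reordered = tuple(tf.gather(layer_past, beam_idx, axis=1) for layer_past in past)
assert reordered[0].shape == past[0].shape  # reordering never changes shapes
```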
@@ -1475,16 +1475,6 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
-
 @add_start_docstrings(
     """
...
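Background on the `layer_past[:2]` split: each BART decoder layer caches four tensors, the self-attention key/value followed by the cross-attention key/value. The cross-attention states come from the encoder output, which is identical across all beams of a sample, so only the first two entries need re-indexing. A hedged sketch assuming that four-tuple layout (the helper name is illustrative):

```python
import tensorflow as tf

def reorder_bart_layer(layer_past, beam_idx):
    # layer_past = (self_k, self_v, cross_k, cross_v). Only the self-attention
    # states depend on the tokens generated so far, so only they are gathered.
    return tuple(tf.gather(s, beam_idx, axis=0) for s in layer_past[:2]) + layer_past[2:]

layer_past = tuple(tf.random.normal((4, 2, 5, 8)) for _ in range(4))  # batch*beams on axis 0
reordered = reorder_bart_layer(layer_past, tf.constant([1, 0, 3, 2]))
assert reordered[2] is layer_past[2]  # cross-attention cache is passed through untouched
```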
@@ -1508,13 +1508,6 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
         logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
     )

-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
-
 @add_start_docstrings(
     """Bert Model with a `next sentence prediction (classification)` head on top.""",
...
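The decoder-only variant above has no cross-attention cache to preserve, so every tensor in `layer_past` is gathered along the batch axis. The same logic as a compact, self-contained function (cache layout assumed, names illustrative):

```python
import tensorflow as tf

def reorder_decoder_only(past, beam_idx):
    # Every cached state depends on the generated prefix, so gather them all.
    return tuple(
        tuple(tf.gather(state, beam_idx, axis=0) for state in layer_past) for layer_past in past
    )

past = (tuple(tf.random.normal((4, 2, 5, 8)) for _ in range(2)),)  # 1 layer: (key, value)
out = reorder_decoder_only(past, tf.constant([3, 2, 1, 0]))
assert out[0][0].shape == past[0][0].shape
```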
@@ -1473,14 +1473,3 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
             "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
-
-    @staticmethod
-    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
@@ -1453,14 +1453,3 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
             "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
-
-    @staticmethod
-    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
@@ -1726,11 +1726,3 @@ class TFCamembertForCausalLM(TFCamembertPreTrainedModel, TFCausalLanguageModelin
         return TFCausalLMOutputWithCrossAttentions(
             logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
         )
-
-    @staticmethod
-    # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
@@ -722,12 +722,6 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
         return TFCausalLMOutputWithPast(logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns)

-    @staticmethod
-    def _reorder_cache(past: Tuple[Tuple[tf.Tensor]], beam_idx: tf.Tensor) -> Tuple[Tuple[tf.Tensor]]:
-        return tuple(
-            tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past) for layer_past in past
-        )
-
 @add_start_docstrings(
     """
...
@@ -720,7 +720,3 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
             " respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or"
             " model.decoder.resize_token_embeddings(...))"
         )
-
-    def _reorder_cache(self, past, beam_idx):
-        # apply decoder cache reordering here
-        return self.decoder._reorder_cache(past, beam_idx)
@@ -2538,16 +2538,6 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
-
     def hf_compute_loss(self, labels, logits):
         """CrossEntropyLoss that ignores pad tokens"""
         loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
...
@@ -1494,17 +1494,6 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

-    @staticmethod
-    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
-
     def adjust_logits_during_generation(
         self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs
     ):
...
@@ -1490,14 +1490,3 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id)
-
-    @staticmethod
-    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
@@ -1503,14 +1503,3 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
-
-    @staticmethod
-    # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            # cached cross_attention states don't have to be reordered -> they are always the same
-            reordered_past += (
-                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
-            )
-        return reordered_past
@@ -799,24 +799,6 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
     def question_encoder(self):
         return self.rag.question_encoder

-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        """Reorders cache for generation. BART-inspired but we need to take care of the extra dimension for docs"""
-
-        def _reorder_stacked(hidden_states, new_order):
-            n_docs = hidden_states.shape[0] // new_order.shape[0]
-            hidden_states = tf.reshape(hidden_states, (-1, n_docs, *hidden_states.shape[1:]))
-            hidden_states = tf.gather(hidden_states, new_order, axis=0)
-            result = tf.reshape(hidden_states, (-1, *hidden_states.shape[2:]))
-            return result
-
-        reordered_past = ()
-        for layer_past in past:
-            # get the correct batch idx from decoder layer's batch dim for cross and self-attn
-            reordered_past += (tuple(_reorder_stacked(past_state, beam_idx) for past_state in layer_past),)
-        return reordered_past
-
     @staticmethod
     def _gather_beams(nested, beam_indices, batch_axis=0):
         """
...
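The RAG variant is the only structurally different one: RAG tiles `n_docs` retrieved documents along the batch axis, so the cache's leading dimension is `num_beams * n_docs` while `beam_idx` indexes beams only. The deleted helper therefore unflattened the doc dimension, gathered on the beam axis, and flattened back; a toy round trip with illustrative shapes:

```python
import tensorflow as tf

def reorder_stacked(hidden_states, new_order):
    # Leading dim is num_beams * n_docs; recover n_docs, gather beams, re-flatten.
    n_docs = hidden_states.shape[0] // new_order.shape[0]
    hidden_states = tf.reshape(hidden_states, (-1, n_docs, *hidden_states.shape[1:]))
    hidden_states = tf.gather(hidden_states, new_order, axis=0)
    return tf.reshape(hidden_states, (-1, *hidden_states.shape[2:]))

states = tf.reshape(tf.range(6, dtype=tf.float32), (6, 1))  # 2 beams x 3 docs
print(reorder_stacked(states, tf.constant([1, 0])))  # doc blocks [3, 4, 5] and [0, 1, 2] swap as units
```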
@@ -1244,14 +1244,6 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos
             logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
         )

-    @staticmethod
-    # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
-
 @add_start_docstrings(
     """
...
@@ -1286,14 +1286,6 @@ class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLos
             logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
         )

-    @staticmethod
-    # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
-
 class TFRobertaClassificationHead(tf.keras.layers.Layer):
     """Head for sentence-level classification tasks."""
...
@@ -1301,14 +1301,6 @@ class TFRobertaPreLayerNormForCausalLM(TFRobertaPreLayerNormPreTrainedModel, TFC
             logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
         )

-    @staticmethod
-    # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel._reorder_cache
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
-
 # Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaClassificationHead with Roberta->RobertaPreLayerNorm
 class TFRobertaPreLayerNormClassificationHead(tf.keras.layers.Layer):
...
@@ -1501,10 +1501,3 @@ class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCaus
             "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
-
-    @staticmethod
-    def _reorder_cache(past, beam_idx):
-        reordered_past = ()
-        for layer_past in past:
-            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
-        return reordered_past
@@ -1528,30 +1528,6 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return self._shift_right(labels)

-    def _reorder_cache(self, past, beam_idx):
-        # if decoder past is not included in output
-        # speedy decoding is disabled and no need to reorder
-        if past is None:
-            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
-            return past
-
-        reordered_decoder_past = ()
-        for layer_past_states in past:
-            # get the correct batch idx from layer past batch dim
-            # batch dim of `past` is at 2nd position
-            reordered_layer_past_states = ()
-            for layer_past_state in layer_past_states:
-                # need to set correct `past` for each of the four key / value states
-                reordered_layer_past_states = reordered_layer_past_states + (
-                    tf.gather(layer_past_state, beam_idx, axis=0),
-                )
-
-            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
-            assert len(reordered_layer_past_states) == len(layer_past_states)
-
-            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
-        return reordered_decoder_past
-
 @add_start_docstrings(
     "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
...
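Unlike the other removals, the T5 variant guarded against `past` being `None` (the `use_cache=False` case), warning and returning it unchanged instead of failing. A hedged sketch of that guard combined with the per-state gather (function name is illustrative):

```python
import tensorflow as tf

def reorder_t5_cache(past, beam_idx):
    # With use_cache=False there is no cache to reorder; the removed method
    # logged a warning here and returned `past` as-is.
    if past is None:
        return past
    return tuple(
        tuple(tf.gather(state, beam_idx, axis=0) for state in layer_past) for layer_past in past
    )

assert reorder_t5_cache(None, tf.constant([0])) is None  # no-op path
```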
@@ -1039,10 +1039,6 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
         return inputs

-    @staticmethod
-    def _reorder_cache(mems: List[tf.Tensor], beam_idx: tf.Tensor) -> List[tf.Tensor]:
-        return [tf.gather(layer_past, beam_idx, axis=1) for layer_past in mems]
-
 @add_start_docstrings(
     """
...
@@ -756,7 +756,3 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
             "Resizing the embedding layers via the TFVisionEncoderDecoderModel directly is not supported."
             "Please use the respective methods of the wrapped objects (model.decoder.resize_token_embeddings(...))"
         )
-
-    def _reorder_cache(self, past, beam_idx):
-        # apply decoder cache reordering here
-        return self.decoder._reorder_cache(past, beam_idx)