Unverified Commit 70203b59 authored by Joao Gante, committed by GitHub

TF generate refactor - past without encoder outputs (#15944)

* Remove packed past from generation_tf_utils

* update models with the new past format

* update template accordingly
parent 62d84760
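To make the refactor concrete before reading the diff: `past` no longer doubles as a carrier for the encoder outputs. The helper below is a toy sketch of the new calling convention only; the function and the dummy tensors are illustrative, not code from this commit:

import tensorflow as tf

# Toy version of the refactored hand-off: `encoder_outputs` travels under its own
# name and `past` carries only the decoder cache (None on the first step).
def prepare_inputs_for_generation(decoder_input_ids, past=None, encoder_outputs=None, **kwargs):
    if past is not None:
        # once a cache exists, only the last token is fed to the decoder
        decoder_input_ids = decoder_input_ids[:, -1:]
    return {
        "input_ids": None,
        "encoder_outputs": encoder_outputs,
        "past_key_values": past,
        "decoder_input_ids": decoder_input_ids,
    }

decoder_input_ids = tf.constant([[2, 14, 7]])
encoder_outputs = (tf.zeros((1, 5, 16)),)  # dummy encoder last_hidden_state
dummy_cache = ((tf.zeros((1, 2, 3, 16)),) * 4,)  # one layer of made-up key/value tensors
step_1 = prepare_inputs_for_generation(decoder_input_ids, past=None, encoder_outputs=encoder_outputs)
step_n = prepare_inputs_for_generation(decoder_input_ids, past=dummy_cache, encoder_outputs=encoder_outputs)
print(step_1["decoder_input_ids"].shape, step_n["decoder_input_ids"].shape)  # (1, 3) (1, 1)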
src/transformers/generation_tf_utils.py

@@ -867,9 +867,8 @@ class TFGenerationMixin:
         beam_scores = tf.reshape(beam_scores, (batch_size * num_beams,))
 
-        # cache compute states
-        past = encoder_outputs
-        # to stay similar to torch : past = (encoder_outputs, None) if encoder_outputs is not None else None
+        # variable to cache compute states
+        past = None
 
         # init attention / hidden states / scores tuples
         scores = () if (return_dict_in_generate and kwargs["output_scores"]) else None
@@ -886,6 +885,13 @@ class TFGenerationMixin:
             if (return_dict_in_generate and kwargs["encoder_hidden_states"])
             else None
         )
+        # the refactored generate, without the encoder outputs in `past`, expects the `encoder_outputs`
+        # variable to contain all (encoder_outputs, encoder_hidden_states, encoder_attentions) in
+        # `prepare_inputs_for_generation`
+        if encoder_hidden_states is not None:
+            encoder_outputs = (*encoder_outputs, encoder_hidden_states)
+        if encoder_attentions is not None:
+            encoder_outputs = (*encoder_outputs, encoder_attentions)
 
         # done sentences
         done = [False for _ in range(batch_size)]
@@ -896,6 +902,7 @@ class TFGenerationMixin:
                 past=past,
                 attention_mask=attention_mask,
                 use_cache=use_cache,
+                encoder_outputs=encoder_outputs,
                 **kwargs,
             )
             outputs = self(
@@ -1486,14 +1493,10 @@ class TFGenerationMixin:
         if model_kwargs.get("attention_mask", None) is None and requires_attention_mask:
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(input_ids, pad_token_id)
 
+        # 4. Prepare model inputs which will be used for auto-regressive generation
         if self.config.is_encoder_decoder:
-            # if model is encoder decoder model, we create encoder_outputs and add to `model_kwargs`
-            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
-                input_ids, return_dict_in_generate, model_kwargs
-            )
-
-        # 4. Prepare `input_ids` which will be used for auto-regressive generation
-        if self.config.is_encoder_decoder:
+            # if encoder-decoder, we create encoder_outputs and add to `model_kwargs`
+            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
             # if encoder-decoder then `input_ids` come from `decoder_start_token_id`
             input_ids = self._prepare_decoder_input_ids_for_generation(
                 batch_size,
@@ -1531,10 +1534,6 @@ class TFGenerationMixin:
                     f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
                 )
 
-            # TODO(Patrick) - ugly `past`/`encoder_output` hack here which requires a bigger refactor of all
-            # generation models in TF. `past` should be optional everywhere and not be set equal to encoder_outputs.
-            model_kwargs["past"] = model_kwargs.get("encoder_outputs")[:1] if self.config.is_encoder_decoder else None
-
             # 8. run greedy search
             return self.greedy_search(
                 input_ids,
@@ -1559,10 +1558,6 @@ class TFGenerationMixin:
                 **model_kwargs,
             )
 
-            # TODO(Patrick) - ugly `past`/`encoder_output` hack here which requires a bigger refactor of all
-            # generation models in TF. `past` should be optional everywhere and not be set equal to encoder_outputs.
-            model_kwargs["past"] = model_kwargs.get("encoder_outputs")[:1] if self.config.is_encoder_decoder else None
-
             # 10. run sample
             return self.sample(
                 input_ids,
@@ -1589,12 +1584,7 @@ class TFGenerationMixin:
         else:
             return tf.ones(input_ids.shape[:2], dtype=tf.int32)
 
-    def _prepare_encoder_decoder_kwargs_for_generation(
-        self, input_ids: tf.Tensor, return_dict_in_generate, model_kwargs
-    ) -> Dict[str, Any]:
-        # TODO(Patrick) - remove `return_dict_in_generate` flag input once `past`/`encoder_outputs`
-        # is cleaned
+    def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids: tf.Tensor, model_kwargs) -> Dict[str, Any]:
         # get encoder and store encoder outputs
         encoder = self.get_encoder()
@@ -1612,17 +1602,8 @@ class TFGenerationMixin:
         encoder_kwargs.pop("attention_mask")
         encoder_outputs = encoder(input_ids, **encoder_kwargs)
         model_kwargs["encoder_outputs"] = encoder_outputs
 
-        # TODO(Patrick): `encoder_outputs`, `past` hack. Currently, `encoder_attentions` and
-        # `encoder_hidden_states` have to be seperated from encoder_outputs and passed
-        # under other names because of `encoder_outputs`, `past` hack. Need to clean-up
-        # all encoder-decoder prepare_inputs_for_generation method to clean this
-        if return_dict_in_generate:
-            model_kwargs["encoder_attentions"] = encoder_outputs.get("attentions", None)
-            model_kwargs["encoder_hidden_states"] = encoder_outputs.get("hidden_states", None)
-
         return model_kwargs
 
     def _prepare_decoder_input_ids_for_generation(
@@ -1712,27 +1693,17 @@ class TFGenerationMixin:
         return inputs
 
+    @staticmethod
     def _update_model_kwargs_for_generation(
-        self, outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
+        outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
     ) -> Dict[str, Any]:
         # update past
-        if self._use_cache(outputs, model_kwargs["use_cache"]):
-            # TODO(Patrick): `past`/`encoder_outputs` hack. This should be
-            # removed when cleaning up the encoder-decoder models
-            # if model has past, then set the past variable to speed up decoding
-            # make this method static then as well
-            model_kwargs["past"] = outputs[1]
-        elif "past_key_values" in outputs:
+        if "past_key_values" in outputs:
             model_kwargs["past"] = outputs.past_key_values
         elif "mems" in outputs:
             model_kwargs["past"] = outputs.mems
         elif "past_buckets_states" in outputs:
             model_kwargs["past"] = outputs.past_buckets_states
-        elif "past" in model_kwargs:
-            # TODO(Patrick) `past`/`encoder_outputs` hack.
-            # removed when cleaning up the encoder-decoder models.
-            # The line should not be necessary.
-            pass
         else:
             model_kwargs["past"] = None
@@ -1907,26 +1878,18 @@ class TFGenerationMixin:
         cross_attentions = () if (return_dict_in_generate and output_attentions) else None
         decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
 
-        # TODO(Patrick): `encoder_outputs`, `past` hack. Currently T5, Bart expect `encoder_outputs`
-        # to be wrapped into `past` variable. Tis is a bad design and needs
-        # to be updated.
-        # Remove the following lines when updating all encoder-decoder models
-        encoder_outputs = model_kwargs.pop("encoder_outputs", None)
-
         # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
         if return_dict_in_generate and self.config.is_encoder_decoder:
-            encoder_attentions = encoder_outputs.get("attentions") if output_attentions else None
-            encoder_hidden_states = encoder_outputs.get("hidden_states") if output_hidden_states else None
+            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+            encoder_hidden_states = (
+                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+            )
 
         # keep track of which sequences are already finished
         unfinished_sequences = tf.ones_like(input_ids[:, 0])
         cur_len = input_ids.shape[-1]
 
         while cur_len < max_length:
-            # TODO(Patrick): remove following line by cleaning up `prepare_inputs_for_generation`
-            # in all models
-            model_kwargs["use_cache"] = None if "use_cache" not in model_kwargs else model_kwargs["use_cache"]
-
             # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
@@ -2129,25 +2092,18 @@ class TFGenerationMixin:
         cross_attentions = () if (return_dict_in_generate and output_attentions) else None
         decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
 
-        # TODO(Patrick): `encoder_outputs`, `past` hack. Currently T5, Bart expect `encoder_outputs`
-        # to be wrapped into `past` variable. This is a bad design and needs to be updated.
-        # Remove the following lines when updating all encoder-decoder models
-        encoder_outputs = model_kwargs.pop("encoder_outputs", None)
-
         # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
         if return_dict_in_generate and self.config.is_encoder_decoder:
-            encoder_attentions = encoder_outputs.get("attentions") if output_attentions else None
-            encoder_hidden_states = encoder_outputs.get("hidden_states") if output_hidden_states else None
+            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+            encoder_hidden_states = (
+                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+            )
 
         # keep track of which sequences are already finished
         unfinished_sequences = tf.ones_like(input_ids[:, 0])
         cur_len = input_ids.shape[-1]
 
         while cur_len < max_length:
-            # TODO(Patrick): remove following line by cleaning up `prepare_inputs_for_generation`
-            # in all models
-            model_kwargs["use_cache"] = None if "use_cache" not in model_kwargs else model_kwargs["use_cache"]
-
             # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
src/transformers/models/bart/modeling_tf_bart.py

@@ -16,7 +16,7 @@
 import random
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
 
 import tensorflow as tf
@@ -1012,9 +1012,6 @@ class TFBartDecoder(tf.keras.layers.Layer):
         if inputs["output_hidden_states"]:
             all_hidden_states += (hidden_states,)
 
-        if inputs["use_cache"]:
-            present_key_values = (inputs["encoder_hidden_states"], present_key_values)
-
         if not inputs["return_dict"]:
             return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
         else:
@@ -1449,43 +1446,23 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
     def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
-        past,
-        attention_mask,
+        past=None,
+        attention_mask=None,
         head_mask=None,
         decoder_head_mask=None,
         cross_attn_head_mask=None,
         use_cache=None,
-        **kwargs,
-    ) -> Dict:
-        assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
-        if len(past) == 1:
-            assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
-            encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
-            past_key_values = None
-        else:
-            assert (
-                len(past) == 2
-            ), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
-            encoder_outputs, past_key_values = past
-            if isinstance(encoder_outputs, tuple):
-                assert isinstance(
-                    encoder_outputs[0], tf.Tensor
-                ), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
-                encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
-            elif isinstance(encoder_outputs, tf.Tensor):
-                encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
-            assert (
-                past_key_values
-            ), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
+        encoder_outputs=None,
+        **kwargs
+    ):
+        # cut decoder_input_ids if past is used
+        if past is not None:
             decoder_input_ids = decoder_input_ids[:, -1:]
 
-        assert isinstance(
-            encoder_outputs, TFBaseModelOutput
-        ), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
             "encoder_outputs": encoder_outputs,
-            "past_key_values": past_key_values,
+            "past_key_values": past,
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "head_mask": head_mask,
@@ -1499,15 +1476,10 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
     @staticmethod
     def _reorder_cache(past, beam_idx):
-        if len(past) == 1:
-            return past
-
-        past_key_values = past[1]
         reordered_past = ()
-        for layer_past_key_values in past_key_values:
+        for layer_past in past:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2])
-                + layer_past_key_values[2:],
+                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
             )
-        return (past[0], reordered_past)
+        return reordered_past
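The simplified `_reorder_cache` above works because every cached tensor keeps the (batch * num_beams) axis first, so following the surviving beams is a single `tf.gather` along axis 0. A toy run with made-up shapes (not taken from the commit):

import tensorflow as tf

num_beams, num_heads, seq_len, head_dim = 3, 2, 4, 8
layer_past = tuple(tf.random.normal((num_beams, num_heads, seq_len, head_dim)) for _ in range(4))
beam_idx = tf.constant([2, 0, 0])  # new beams 0/1/2 are copied from old beams 2/0/0

# self-attention key/value (first two entries) follow the beams; the cached
# cross-attention states (the rest) are identical across beams and stay untouched
reordered = tuple(tf.gather(t, beam_idx, axis=0) for t in layer_past[:2]) + layer_past[2:]
print([tuple(t.shape) for t in reordered])  # four tensors, all still (3, 2, 4, 8)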
src/transformers/models/bert/modeling_tf_bert.py

@@ -1443,17 +1443,17 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
         warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
         return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name
 
-    def prepare_inputs_for_generation(self, inputs, past=None, attention_mask=None, **model_kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+        input_shape = input_ids.shape
+        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+        if attention_mask is None:
+            attention_mask = tf.ones(input_shape)
+
         # cut decoder_input_ids if past is used
-        if past:
-            inputs = tf.expand_dims(inputs[:, -1], -1)
+        if past is not None:
+            input_ids = input_ids[:, -1:]
 
-        return {
-            "input_ids": inputs,
-            "attention_mask": attention_mask,
-            "past_key_values": past,
-            "use_cache": model_kwargs["use_cache"],
-        }
+        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
 
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -1575,6 +1575,13 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
             logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
         )
 
+    @staticmethod
+    def _reorder_cache(past, beam_idx):
+        reordered_past = ()
+        for layer_past in past:
+            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
+        return reordered_past
+
 @add_start_docstrings(
     """Bert Model with a `next sentence prediction (classification)` head on top.""",
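A side note on the BERT change above: replacing `tf.expand_dims(input_ids[:, -1], -1)` with the slice `input_ids[:, -1:]` is purely cosmetic, since for a 2D ids tensor both keep a trailing dimension of 1. A quick check:

import tensorflow as tf

input_ids = tf.constant([[101, 7592, 2088], [101, 2023, 2003]])
a = tf.expand_dims(input_ids[:, -1], -1)
b = input_ids[:, -1:]
print(a.shape, b.shape)                      # (2, 1) (2, 1)
print(bool(tf.reduce_all(tf.equal(a, b))))   # True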
src/transformers/models/blenderbot/modeling_tf_blenderbot.py

@@ -18,7 +18,7 @@
 import os
 import random
 import warnings
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
 
 import tensorflow as tf
@@ -1011,9 +1011,6 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
         if inputs["output_hidden_states"]:
             all_hidden_states += (hidden_states,)
 
-        if inputs["use_cache"]:
-            present_key_values = (inputs["encoder_hidden_states"], present_key_values)
-
         if not inputs["return_dict"]:
             return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
         else:
@@ -1461,43 +1458,23 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
     def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
-        past,
-        attention_mask,
+        past=None,
+        attention_mask=None,
         head_mask=None,
         decoder_head_mask=None,
         cross_attn_head_mask=None,
         use_cache=None,
-        **kwargs,
-    ) -> Dict:
-        assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
-        if len(past) == 1:
-            assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
-            encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
-            past_key_values = None
-        else:
-            assert (
-                len(past) == 2
-            ), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
-            encoder_outputs, past_key_values = past
-            if isinstance(encoder_outputs, tuple):
-                assert isinstance(
-                    encoder_outputs[0], tf.Tensor
-                ), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
-                encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
-            elif isinstance(encoder_outputs, tf.Tensor):
-                encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
-            assert (
-                past_key_values
-            ), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
+        encoder_outputs=None,
+        **kwargs
+    ):
+        # cut decoder_input_ids if past is used
+        if past is not None:
             decoder_input_ids = decoder_input_ids[:, -1:]
 
-        assert isinstance(
-            encoder_outputs, TFBaseModelOutput
-        ), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
             "encoder_outputs": encoder_outputs,
-            "past_key_values": past_key_values,
+            "past_key_values": past,
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "head_mask": head_mask,
@@ -1509,15 +1486,10 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal
     @staticmethod
     # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
     def _reorder_cache(past, beam_idx):
-        if len(past) == 1:
-            return past
-
-        past_key_values = past[1]
         reordered_past = ()
-        for layer_past_key_values in past_key_values:
+        for layer_past in past:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2])
-                + layer_past_key_values[2:],
+                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
            )
-        return (past[0], reordered_past)
+        return reordered_past
src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py

@@ -16,7 +16,7 @@
 import random
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
 
 import tensorflow as tf
@@ -1010,9 +1010,6 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
         if inputs["output_hidden_states"]:
             all_hidden_states += (hidden_states,)
 
-        if inputs["use_cache"]:
-            present_key_values = (inputs["encoder_hidden_states"], present_key_values)
-
         if not inputs["return_dict"]:
             return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
         else:
@@ -1434,43 +1431,23 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
     def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
-        past,
-        attention_mask,
+        past=None,
+        attention_mask=None,
         head_mask=None,
         decoder_head_mask=None,
         cross_attn_head_mask=None,
         use_cache=None,
-        **kwargs,
-    ) -> Dict:
-        assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
-        if len(past) == 1:
-            assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
-            encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
-            past_key_values = None
-        else:
-            assert (
-                len(past) == 2
-            ), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
-            encoder_outputs, past_key_values = past
-            if isinstance(encoder_outputs, tuple):
-                assert isinstance(
-                    encoder_outputs[0], tf.Tensor
-                ), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
-                encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
-            elif isinstance(encoder_outputs, tf.Tensor):
-                encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
-            assert (
-                past_key_values
-            ), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
+        encoder_outputs=None,
+        **kwargs
+    ):
+        # cut decoder_input_ids if past is used
+        if past is not None:
             decoder_input_ids = decoder_input_ids[:, -1:]
 
-        assert isinstance(
-            encoder_outputs, TFBaseModelOutput
-        ), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
             "encoder_outputs": encoder_outputs,
-            "past_key_values": past_key_values,
+            "past_key_values": past,
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "head_mask": head_mask,
@@ -1482,15 +1459,10 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel
     @staticmethod
     # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
     def _reorder_cache(past, beam_idx):
-        if len(past) == 1:
-            return past
-
-        past_key_values = past[1]
         reordered_past = ()
-        for layer_past_key_values in past_key_values:
+        for layer_past in past:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2])
-                + layer_past_key_values[2:],
+                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
            )
-        return (past[0], reordered_past)
+        return reordered_past
src/transformers/models/ctrl/modeling_tf_ctrl.py

@@ -16,6 +16,7 @@
 """ TF 2.0 CTRL model."""
 import warnings
+from typing import Tuple
 
 import numpy as np
 import tensorflow as tf
@@ -659,12 +660,12 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
         warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
         return self.name + "/" + self.lm_head.name
 
-    def prepare_inputs_for_generation(self, inputs, past, **kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=None, **kwargs):
         # only last token for inputs_ids if past is defined in kwargs
         if past:
-            inputs = tf.expand_dims(inputs[:, -1], -1)
+            input_ids = tf.expand_dims(input_ids[:, -1], -1)
 
-        return {"input_ids": inputs, "past": past, "use_cache": kwargs["use_cache"]}
+        return {"input_ids": input_ids, "past_key_values": past, "use_cache": use_cache}
 
     @add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
@@ -758,6 +759,12 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel, TFCausalLanguageModelingLoss):
         return TFCausalLMOutputWithPast(logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns)
 
+    @staticmethod
+    def _reorder_cache(past: Tuple[Tuple[tf.Tensor]], beam_idx: tf.Tensor) -> Tuple[Tuple[tf.Tensor]]:
+        return tuple(
+            tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past) for layer_past in past
+        )
+
 @add_start_docstrings(
     """
src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py

@@ -692,52 +692,21 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
         )
 
     def prepare_inputs_for_generation(
-        self,
-        decoder_input_ids,
-        past,
-        attention_mask,
-        use_cache=None,
-        **kwargs,
+        self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
     ):
-        if past is None or len(past) not in {1, 2}:
-            raise ValueError(f"past has to be an iterable of length 1,2 got {past}")
-
-        if len(past) == 1:
-            if not isinstance(past[0], tf.Tensor):
-                raise ValueError(f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}")
-            encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
-            past_key_values = None
-        else:
-            if len(past) != 2:
-                raise ValueError(
-                    "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
-                )
-            encoder_outputs, past_key_values = past
-            if isinstance(encoder_outputs, tuple):
-                if not isinstance(encoder_outputs[0], tf.Tensor):
-                    raise ValueError(
-                        f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
-                    )
-                encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
-            elif isinstance(encoder_outputs, tf.Tensor):
-                encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
-            if not past_key_values:
-                raise ValueError(
-                    f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
-                )
-            decoder_input_ids = decoder_input_ids[:, -1:]
-
-        if not isinstance(encoder_outputs, TFBaseModelOutput):
-            raise ValueError(f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}.")
-        return {
-            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
-            "encoder_outputs": encoder_outputs,
-            "past_key_values": past_key_values,
-            "decoder_input_ids": decoder_input_ids,
+        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
+        decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
+        input_dict = {
+            "input_ids": None,  # needs to be passed to make Keras.layer.__call__ happy
             "attention_mask": attention_mask,
-            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
+            "decoder_attention_mask": decoder_attention_mask,
+            "decoder_input_ids": decoder_inputs["input_ids"],
+            # TODO (joao): the `TFBaseModelOutput` wrapper should not be needed after the generate refactor is complete
+            "encoder_outputs": TFBaseModelOutput(last_hidden_state=encoder_outputs[0]),
+            "past_key_values": decoder_inputs["past_key_values"],
+            "use_cache": use_cache,
         }
+        return input_dict
 
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
@@ -750,9 +719,4 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
     def _reorder_cache(self, past, beam_idx):
         # apply decoder cache reordering here
-        if len(past) == 1:
-            return past
-
-        encoder_outputs, past_key_values = past
-        return (encoder_outputs, self.decoder._reorder_cache(past_key_values, beam_idx))
+        return self.decoder._reorder_cache(past, beam_idx)
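The encoder-decoder wrapper now delegates the decoder-side bookkeeping instead of unpacking `past` itself. A minimal sketch of that delegation pattern with stub classes (assumed names, not transformers code):

import tensorflow as tf

class StubDecoder:
    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        if past is not None:
            input_ids = input_ids[:, -1:]
        return {"input_ids": input_ids, "attention_mask": None, "past_key_values": past}

class StubEncoderDecoder:
    def __init__(self):
        self.decoder = StubDecoder()

    def prepare_inputs_for_generation(self, input_ids, past=None, encoder_outputs=None, use_cache=None, **kwargs):
        # ask the decoder for its own inputs, then repackage them for the wrapper's call()
        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
        return {
            "input_ids": None,
            "decoder_input_ids": decoder_inputs["input_ids"],
            "encoder_outputs": encoder_outputs,
            "past_key_values": decoder_inputs["past_key_values"],
            "use_cache": use_cache,
        }

model = StubEncoderDecoder()
print(model.prepare_inputs_for_generation(tf.constant([[1, 2, 3]]), past=None)["decoder_input_ids"])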
src/transformers/models/gpt2/modeling_tf_gpt2.py

@@ -851,12 +851,15 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
     def set_output_embeddings(self, value):
         self.set_input_embeddings(value)
 
-    def prepare_inputs_for_generation(self, inputs, past, **kwargs):
+    def prepare_inputs_for_generation(self, inputs, past=None, use_cache=None, **kwargs):
+        # TODO: (Joao) after the TF generator is complete, update GPT2 TF generation to match PT's. NB -- some GPT2
+        # tests will need to be fixed after the change
         # only last token for inputs_ids if past is defined in kwargs
         if past:
             inputs = tf.expand_dims(inputs[:, -1], -1)
 
-        return {"input_ids": inputs, "past": past, "use_cache": kwargs["use_cache"]}
+        return {"input_ids": inputs, "past_key_values": past, "use_cache": use_cache}
 
     @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
src/transformers/models/led/modeling_tf_led.py

@@ -17,7 +17,7 @@
 import random
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import tensorflow as tf
@@ -2097,7 +2097,7 @@ class TFLEDDecoder(tf.keras.layers.Layer):
         all_self_attns = all_self_attns if inputs["output_attentions"] else None
         all_cross_attentions = all_cross_attentions if inputs["output_attentions"] else None
 
-        present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None
+        present_key_values = present_key_values if inputs["use_cache"] else None
 
         if not inputs["return_dict"]:
             return tuple(
@@ -2527,45 +2527,26 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
     def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
-        past,
-        attention_mask,
+        past=None,
+        attention_mask=None,
         head_mask=None,
-        decoder_head_mask=None,
         use_cache=None,
+        encoder_outputs=None,
         **kwargs,
-    ) -> Dict:
-        assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
-        if len(past) == 1:
-            assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
-            encoder_outputs = TFLEDEncoderBaseModelOutput(last_hidden_state=past[0])
-            past_key_values = None
-        else:
-            assert (
-                len(past) == 2
-            ), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
-            encoder_outputs, past_key_values = past
-            if isinstance(encoder_outputs, tuple):
-                assert isinstance(
-                    encoder_outputs[0], tf.Tensor
-                ), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
-                encoder_outputs = TFLEDEncoderBaseModelOutput(last_hidden_state=encoder_outputs[0])
-            elif isinstance(encoder_outputs, tf.Tensor):
-                encoder_outputs = TFLEDEncoderBaseModelOutput(last_hidden_state=encoder_outputs)
-            assert (
-                past_key_values
-            ), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
+    ):
+        # cut decoder_input_ids if past is used
+        if past is not None:
             decoder_input_ids = decoder_input_ids[:, -1:]
 
-        assert isinstance(
-            encoder_outputs,
-            TFLEDEncoderBaseModelOutput,
-        ), f"encoder_outputs should be a TFLEDEncoderBaseModelOutput, Instead got {type(encoder_outputs)}."
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
             "encoder_outputs": encoder_outputs,
-            "past_key_values": past_key_values,
+            "past_key_values": past,
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "head_mask": head_mask,
-            "decoder_head_mask": decoder_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
         }
@@ -2574,18 +2555,13 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
     @staticmethod
     def _reorder_cache(past, beam_idx):
-        if len(past) == 1:
-            return past
-
-        past_key_values = past[1]
         reordered_past = ()
-        for layer_past_key_values in past_key_values:
+        for layer_past in past:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2])
-                + layer_past_key_values[2:],
+                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
            )
-        return (past[0], reordered_past)
+        return reordered_past
 
     def hf_compute_loss(self, labels, logits):
         """CrossEntropyLoss that ignores pad tokens"""
src/transformers/models/marian/modeling_tf_marian.py

@@ -16,7 +16,7 @@
 import random
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
 
 import numpy as np
 import tensorflow as tf
@@ -1050,9 +1050,6 @@ class TFMarianDecoder(tf.keras.layers.Layer):
         if inputs["output_hidden_states"]:
             all_hidden_states += (hidden_states,)
 
-        if inputs["use_cache"]:
-            present_key_values = (inputs["encoder_hidden_states"], present_key_values)
-
         if not inputs["return_dict"]:
             return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
         else:
@@ -1477,43 +1474,23 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
     def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
-        past,
-        attention_mask,
+        past=None,
+        attention_mask=None,
         head_mask=None,
         decoder_head_mask=None,
         cross_attn_head_mask=None,
         use_cache=None,
-        **kwargs,
-    ) -> Dict:
-        assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
-        if len(past) == 1:
-            assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
-            encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
-            past_key_values = None
-        else:
-            assert (
-                len(past) == 2
-            ), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
-            encoder_outputs, past_key_values = past
-            if isinstance(encoder_outputs, tuple):
-                assert isinstance(
-                    encoder_outputs[0], tf.Tensor
-                ), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
-                encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
-            elif isinstance(encoder_outputs, tf.Tensor):
-                encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
-            assert (
-                past_key_values
-            ), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
+        encoder_outputs=None,
+        **kwargs
+    ):
+        # cut decoder_input_ids if past is used
+        if past is not None:
             decoder_input_ids = decoder_input_ids[:, -1:]
 
-        assert isinstance(
-            encoder_outputs, TFBaseModelOutput
-        ), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
             "encoder_outputs": encoder_outputs,
-            "past_key_values": past_key_values,
+            "past_key_values": past,
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "head_mask": head_mask,
@@ -1528,18 +1505,13 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss):
     @staticmethod
     # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
     def _reorder_cache(past, beam_idx):
-        if len(past) == 1:
-            return past
-
-        past_key_values = past[1]
         reordered_past = ()
-        for layer_past_key_values in past_key_values:
+        for layer_past in past:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2])
-                + layer_past_key_values[2:],
+                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
            )
-        return (past[0], reordered_past)
+        return reordered_past
 
     def adjust_logits_during_generation(
         self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs
src/transformers/models/mbart/modeling_tf_mbart.py

@@ -16,7 +16,7 @@
 import random
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
 
 import tensorflow as tf
@@ -1034,9 +1034,6 @@ class TFMBartDecoder(tf.keras.layers.Layer):
         if inputs["output_hidden_states"]:
             all_hidden_states += (hidden_states,)
 
-        if inputs["use_cache"]:
-            present_key_values = (inputs["encoder_hidden_states"], present_key_values)
-
         if not inputs["return_dict"]:
             return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
         else:
@@ -1462,43 +1459,23 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
     def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
-        past,
-        attention_mask,
+        past=None,
+        attention_mask=None,
         head_mask=None,
         decoder_head_mask=None,
         cross_attn_head_mask=None,
         use_cache=None,
-        **kwargs,
-    ) -> Dict:
-        assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
-        if len(past) == 1:
-            assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
-            encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
-            past_key_values = None
-        else:
-            assert (
-                len(past) == 2
-            ), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
-            encoder_outputs, past_key_values = past
-            if isinstance(encoder_outputs, tuple):
-                assert isinstance(
-                    encoder_outputs[0], tf.Tensor
-                ), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
-                encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
-            elif isinstance(encoder_outputs, tf.Tensor):
-                encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
-            assert (
-                past_key_values
-            ), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
+        encoder_outputs=None,
+        **kwargs
+    ):
+        # cut decoder_input_ids if past is used
+        if past is not None:
             decoder_input_ids = decoder_input_ids[:, -1:]
 
-        assert isinstance(
-            encoder_outputs, TFBaseModelOutput
-        ), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
             "encoder_outputs": encoder_outputs,
-            "past_key_values": past_key_values,
+            "past_key_values": past,
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "head_mask": head_mask,
@@ -1513,15 +1490,10 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo
     @staticmethod
     # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
     def _reorder_cache(past, beam_idx):
-        if len(past) == 1:
-            return past
-
-        past_key_values = past[1]
         reordered_past = ()
-        for layer_past_key_values in past_key_values:
+        for layer_past in past:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2])
-                + layer_past_key_values[2:],
+                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
            )
-        return (past[0], reordered_past)
+        return reordered_past
src/transformers/models/pegasus/modeling_tf_pegasus.py

@@ -16,7 +16,7 @@
 import random
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
 
 import numpy as np
 import tensorflow as tf
@@ -1058,9 +1058,6 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
         if inputs["output_hidden_states"]:
             all_hidden_states += (hidden_states,)
 
-        if inputs["use_cache"]:
-            present_key_values = (inputs["encoder_hidden_states"], present_key_values)
-
         if not inputs["return_dict"]:
             return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
         else:
@@ -1485,43 +1482,23 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
     def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
-        past,
-        attention_mask,
+        past=None,
+        attention_mask=None,
         head_mask=None,
         decoder_head_mask=None,
         cross_attn_head_mask=None,
         use_cache=None,
-        **kwargs,
-    ) -> Dict:
-        assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
-        if len(past) == 1:
-            assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
-            encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
-            past_key_values = None
-        else:
-            assert (
-                len(past) == 2
-            ), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
-            encoder_outputs, past_key_values = past
-            if isinstance(encoder_outputs, tuple):
-                assert isinstance(
-                    encoder_outputs[0], tf.Tensor
-                ), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
-                encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
-            elif isinstance(encoder_outputs, tf.Tensor):
-                encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
-            assert (
-                past_key_values
-            ), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
+        encoder_outputs=None,
+        **kwargs
+    ):
+        # cut decoder_input_ids if past is used
+        if past is not None:
             decoder_input_ids = decoder_input_ids[:, -1:]
 
-        assert isinstance(
-            encoder_outputs, TFBaseModelOutput
-        ), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
             "encoder_outputs": encoder_outputs,
-            "past_key_values": past_key_values,
+            "past_key_values": past,
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "head_mask": head_mask,
@@ -1536,15 +1513,10 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua
     @staticmethod
     # Copied from transformers.models.bart.modeling_tf_bart.TFBartForConditionalGeneration._reorder_cache
    def _reorder_cache(past, beam_idx):
-        if len(past) == 1:
-            return past
-
-        past_key_values = past[1]
         reordered_past = ()
-        for layer_past_key_values in past_key_values:
+        for layer_past in past:
             # cached cross_attention states don't have to be reordered -> they are always the same
             reordered_past += (
-                tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2])
-                + layer_past_key_values[2:],
+                tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past[:2]) + layer_past[2:],
            )
-        return (past[0], reordered_past)
+        return reordered_past
...@@ -16,14 +16,13 @@ ...@@ -16,14 +16,13 @@
"""TFRAG model implementation.""" """TFRAG model implementation."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...file_utils import ModelOutput, add_start_docstrings_to_model_forward, replace_return_docstrings from ...file_utils import ModelOutput, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...modeling_tf_outputs import TFBaseModelOutput
from ...modeling_tf_utils import TFCausalLanguageModelingLoss, TFPreTrainedModel, input_processing, shape_list from ...modeling_tf_utils import TFCausalLanguageModelingLoss, TFPreTrainedModel, input_processing, shape_list
from ...utils import logging from ...utils import logging
from .configuration_rag import RagConfig from .configuration_rag import RagConfig
...@@ -788,42 +787,28 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss ...@@ -788,42 +787,28 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
# Adapted from https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_tf_bart.py # Adapted from https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_tf_bart.py
def prepare_inputs_for_generation( def prepare_inputs_for_generation(
self, decoder_input_ids, past, attention_mask, use_cache, doc_scores, n_docs=None, **kwargs self,
) -> Dict: decoder_input_ids,
assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}" past=None,
attention_mask=None,
if len(past) == 1: use_cache=None,
assert isinstance(past[0], tf.Tensor) encoder_outputs=None,
encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0]) doc_scores=None,
decoder_cached_states = None n_docs=None,
else: **kwargs
assert len(past) == 2 ):
# Note: encoder_outputs is never changed by Bart as a generator if past is not None:
encoder_outputs, decoder_cached_states = past # if past is defined use only last decoder_input_ids
if isinstance(encoder_outputs, tuple):
assert isinstance(encoder_outputs[0], tf.Tensor)
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
elif isinstance(encoder_outputs, tf.Tensor):
encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
assert (
decoder_cached_states
), f"decoder cached states must be truthy. got {decoder_cached_states} from the 2nd element of past"
# if past is defined cut decoder_input_ids to last token
decoder_input_ids = decoder_input_ids[:, -1:] decoder_input_ids = decoder_input_ids[:, -1:]
assert isinstance(
encoder_outputs, TFBaseModelOutput
), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
return { return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed "input_ids": None,
"encoder_outputs": encoder_outputs, "encoder_outputs": encoder_outputs,
"doc_scores": doc_scores, "doc_scores": doc_scores,
"context_attention_mask": attention_mask, "context_attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,
"past_key_values": decoder_cached_states, "past_key_values": past,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging) "use_cache": use_cache,
"do_marginalize": True, "do_marginalize": True,
"n_docs": n_docs, "n_docs": n_docs,
} }
...@@ -844,46 +829,19 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss ...@@ -844,46 +829,19 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
     def _reorder_cache(past, beam_idx):
         """Reorders cache for generation. BART-inspired but we need to take care of the extra dimension for docs"""
-        def tf_index_select(input_, dim, indices):
-            """
-            Input:
-                input_(tensor): input tensor dim(int): dimension indices(list): selected indices list
-            Output:
-                mimic of torch_tensor.index_select(dim, indices)
-            credit:
-                https://stackoverflow.com/questions/58464790/is-there-an-equivalent-function-of-pytorch-named-index-select-in-tensorflow
-            """
-            shape = shape_list(input_)
-            if dim == -1:
-                dim = len(shape) - 1
-            shape[dim] = 1
-            tmp = []
-            for idx in indices:
-                begin = [0] * len(shape)
-                begin[dim] = idx
-                tmp.append(tf.slice(input_, begin, shape))
-            res = tf.concat(tmp, axis=dim)
-            return res
-        def _reorder_stacked(hidden_states, new_order=beam_idx):
+        def _reorder_stacked(hidden_states, new_order):
             n_docs = hidden_states.shape[0] // new_order.shape[0]
             hidden_states = tf.reshape(hidden_states, (-1, n_docs, *hidden_states.shape[1:]))
-            hidden_states = tf_index_select(hidden_states, 0, new_order)
-            return tf.reshape(hidden_states, (-1, *hidden_states.shape[2:]))
+            hidden_states = tf.gather(hidden_states, new_order, axis=0)
+            result = tf.reshape(hidden_states, (-1, *hidden_states.shape[2:]))
+            return result
-        if len(past) == 1:
-            return past
-        past_key_values = past[1]
         reordered_past = ()
-        for layer_past in past_key_values:
-            # get the correct batch idx from decoder layer's batch dim for cross and self-attn
+        for layer_past in past:
             reordered_past += (tuple(_reorder_stacked(past_state, beam_idx) for past_state in layer_past),)
-        return (past[0], reordered_past)
+        return reordered_past
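For reference, a toy check (shapes and values assumed, not taken from the repository) that the `tf.gather`-based `_reorder_stacked` logic behaves like the removed `tf_index_select` helper: the stacked cache is viewed as (batch, n_docs, ...), beams are selected along the batch axis, and the docs axis is then folded back in.

import tensorflow as tf

batch, n_docs, seq, dim = 2, 3, 4, 5
hidden_states = tf.reshape(
    tf.range(batch * n_docs * seq * dim, dtype=tf.float32), (batch * n_docs, seq, dim)
)
beam_idx = tf.constant([1, 0])  # swap the two batch entries

reshaped = tf.reshape(hidden_states, (-1, n_docs, seq, dim))  # (batch, n_docs, seq, dim)
reordered = tf.gather(reshaped, beam_idx, axis=0)             # pick beams along the batch axis
result = tf.reshape(reordered, (-1, seq, dim))                # back to (batch * n_docs, seq, dim)

# the first n_docs rows of the result are the rows that used to belong to batch entry 1
assert tf.reduce_all(result[:n_docs] == hidden_states[n_docs:2 * n_docs]).numpy()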
     def marginalize(self, seq_logits, doc_scores, n_docs=None):
         n_docs = n_docs if n_docs is not None else self.config.n_docs
@@ -1268,14 +1226,6 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
             return_dict=True,
         )
-        if return_dict_in_generate:
-            # TODO(Patrick): `encoder_outputs`, `past` hack.
-            # Remove after cleaning encoder-decoder outputs
-            if output_attentions:
-                model_kwargs["encoder_attentions"] = encoder_outputs.attentions
-            if output_hidden_states:
-                model_kwargs["encoder_hidden_states"] = encoder_outputs.hidden_states
         decoder_input_ids = tf.fill(
             (batch_size * num_beams, 1),
             tf.cast(decoder_start_token_id, tf.int32),
@@ -1366,10 +1316,6 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
         model_kwargs.pop("output_attentions", None)
         model_kwargs.pop("output_scores", None)
-        # TODO(Patrick): `encoder_outputs`, `past` hack.
-        # Remove after cleaning encoder-decoder outputs
-        model_kwargs["past"] = encoder_outputs
         return self.greedy_search(
             input_ids=decoder_input_ids,
             max_length=max_length,
...
@@ -1176,17 +1176,17 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos
         return self.mlm.predictions
     # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, inputs, past=None, attention_mask=None, **model_kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+        input_shape = input_ids.shape
+        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+        if attention_mask is None:
+            attention_mask = tf.ones(input_shape)
         # cut decoder_input_ids if past is used
-        if past:
-            inputs = tf.expand_dims(inputs[:, -1], -1)
-        return {
-            "input_ids": inputs,
-            "attention_mask": attention_mask,
-            "past_key_values": past,
-            "use_cache": model_kwargs["use_cache"],
-        }
+        if past is not None:
+            input_ids = input_ids[:, -1:]
+        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
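A small illustrative run of the two behaviours the rewritten method relies on (toy ids, not from the test suite): a default all-ones attention mask is built when none is supplied, and only the newest token is fed once a cache exists.

import tensorflow as tf

input_ids = tf.constant([[101, 7, 8, 9]])
attention_mask = tf.ones(input_ids.shape)  # created on the fly when the caller passes None
next_step_ids = input_ids[:, -1:]          # what gets fed once `past` is populated

print(attention_mask.shape)   # (1, 4)
print(next_step_ids.numpy())  # [[9]]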
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -1309,6 +1309,14 @@ class TFRemBertForCausalLM(TFRemBertPreTrainedModel, TFCausalLanguageModelingLos
             logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
         )
+    @staticmethod
+    # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel._reorder_cache
+    def _reorder_cache(past, beam_idx):
+        reordered_past = ()
+        for layer_past in past:
+            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
+        return reordered_past
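To see what the newly added `_reorder_cache` does per tensor, here is a toy example with assumed shapes (batch * num_beams, heads, seq, head_dim); selecting the surviving beams is a plain gather along the batch axis.

import tensorflow as tf

past_state = tf.random.normal((4, 2, 3, 8))  # e.g. 4 beams, 2 heads, 3 cached positions, head_dim 8
beam_idx = tf.constant([2, 2, 0, 1])         # beams kept after the current step
reordered = tf.gather(past_state, beam_idx, axis=0)

print(reordered.shape)                                       # (4, 2, 3, 8): same shape, rows rearranged
print(tf.reduce_all(reordered[0] == past_state[2]).numpy())  # True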
@add_start_docstrings(
    """
...
@@ -1209,17 +1209,17 @@ class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLos
         return self.name + "/" + self.lm_head.name
     # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, inputs, past=None, attention_mask=None, **model_kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+        input_shape = input_ids.shape
+        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+        if attention_mask is None:
+            attention_mask = tf.ones(input_shape)
         # cut decoder_input_ids if past is used
-        if past:
-            inputs = tf.expand_dims(inputs[:, -1], -1)
-        return {
-            "input_ids": inputs,
-            "attention_mask": attention_mask,
-            "past_key_values": past,
-            "use_cache": model_kwargs["use_cache"],
-        }
+        if past is not None:
+            input_ids = input_ids[:, -1:]
+        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
     @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1344,6 +1344,14 @@ class TFRobertaForCausalLM(TFRobertaPreTrainedModel, TFCausalLanguageModelingLos
             logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns
         )
+    @staticmethod
+    # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel._reorder_cache
+    def _reorder_cache(past, beam_idx):
+        reordered_past = ()
+        for layer_past in past:
+            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
+        return reordered_past
class TFRobertaClassificationHead(tf.keras.layers.Layer):
    """Head for sentence-level classification tasks."""
...
@@ -1139,7 +1139,7 @@ class TFSpeech2TextDecoder(tf.keras.layers.Layer):
         if inputs["output_hidden_states"]:
             all_hidden_states += (hidden_states,)
next_cache = (inputs["encoder_hidden_states"], next_decoder_cache) if use_cache else None next_cache = next_decoder_cache if use_cache else None
if not inputs["return_dict"]: if not inputs["return_dict"]:
return hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attns return hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attns
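The one-line change above is the decoder-side half of the refactor: the returned cache no longer packs the encoder hidden states in front of the per-layer key/value states. A sketch of the two layouts follows, with the layer count, the number of tensors per layer, and all shapes assumed purely for illustration.

import tensorflow as tf

n_layers, batch, heads, seq, head_dim = 3, 1, 4, 5, 8
decoder_cache = tuple(
    tuple(tf.zeros((batch, heads, seq, head_dim)) for _ in range(4))  # assumed: self/cross keys and values
    for _ in range(n_layers)
)
encoder_hidden_states = tf.zeros((batch, 7, 16))

old_style_past = (encoder_hidden_states, decoder_cache)  # packed tuple, pre-refactor
new_style_past = decoder_cache                           # plain per-layer cache, post-refactor

print(len(old_style_past), len(new_style_past))  # 2 3: a 2-tuple wrapper vs. one entry per layer
print(len(new_style_past[0]))                    # 4 tensors per layer in this sketch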
@@ -1571,26 +1571,17 @@ class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCaus
         decoder_head_mask=None,
         cross_attn_head_mask=None,
         use_cache=None,
+        encoder_outputs=None,
         **kwargs
     ):
-        if past is not None and len(past) <= 2:
-            if not isinstance(past[0], tf.Tensor):
-                raise ValueError(f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}")
-            encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
-            if len(past) == 1:
-                past_key_values = None
-            else:
-                past_key_values = past[1]
-                if not past_key_values:
-                    raise ValueError(f"decoder cached states must be truthy, got {past_key_values}")
-                decoder_input_ids = decoder_input_ids[:, -1:]
-        else:
-            raise ValueError(f"`past` must be an iterable with length 1 or 2, got {past}")
+        # cut decoder_input_ids if past is used
+        if past is not None:
+            decoder_input_ids = decoder_input_ids[:, -1:]
         return {
             "input_features": None,  # needs to be passed to make Keras.layer.__call__ happy
             "encoder_outputs": encoder_outputs,
-            "past_key_values": past_key_values,
+            "past_key_values": past,
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "head_mask": head_mask,
@@ -1601,15 +1592,7 @@ class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCaus
     @staticmethod
     def _reorder_cache(past, beam_idx):
-        if len(past) == 1:
-            return past
-        past_key_values = past[1]
         reordered_past = ()
-        for layer_past_key_values in past_key_values:
-            reordered_past += (
-                tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2])
-                + layer_past_key_values[2:],
-            )
-        return (past[0], reordered_past)
+        for layer_past in past:
+            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
+        return reordered_past
@@ -1256,15 +1256,13 @@ class TFT5Model(TFT5PreTrainedModel):
             return_dict=inputs["return_dict"],
             training=inputs["training"],
         )
+        past = decoder_outputs[1] if inputs["use_cache"] else None
         if not inputs["return_dict"]:
-            past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None
             if past is not None:
                 decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
             return decoder_outputs + inputs["encoder_outputs"]
-        past = (inputs["encoder_outputs"].to_tuple(), decoder_outputs[1]) if inputs["use_cache"] else None
         return TFSeq2SeqModelOutput(
             last_hidden_state=decoder_outputs.last_hidden_state,
             past_key_values=past,
@@ -1483,8 +1481,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
         loss = None if inputs["labels"] is None else self.hf_compute_loss(inputs["labels"], logits)
+        past = decoder_outputs[1] if inputs["use_cache"] else None
         if not inputs["return_dict"]:
-            past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None
             if past is not None:
                 decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
             output = (logits,) + decoder_outputs[1:] + inputs["encoder_outputs"]
@@ -1509,8 +1507,6 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
             attentions=attentions,
         )
-        past = (inputs["encoder_outputs"].to_tuple(), decoder_outputs[1]) if inputs["use_cache"] else None
         return TFSeq2SeqLMOutput(
             loss=loss,
             logits=logits,
@@ -1544,65 +1540,57 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
     def prepare_inputs_for_generation(
         self,
-        inputs,
-        past,
-        attention_mask,
-        head_mask=None,
-        decoder_head_mask=None,
+        input_ids,
+        past=None,
+        attention_mask=None,
         use_cache=None,
-        **kwargs,
+        encoder_outputs=None,
+        **kwargs
     ):
-        assert past is not None, "past has to be defined for encoder_outputs"
-        # first step
-        if len(past) < 2:
-            encoder_outputs, past_key_values = past, None
-        else:
-            encoder_outputs, past_key_values = past[0], past[1]
-        if "encoder_hidden_states" in kwargs:
-            encoder_outputs = (*encoder_outputs, kwargs["encoder_hidden_states"])
-        if "encoder_attentions" in kwargs:
-            encoder_outputs = (*encoder_outputs, kwargs["encoder_attentions"])
         # cut decoder_input_ids if past is used
-        if past_key_values is not None:
-            inputs = inputs[:, -1:]
+        if past is not None:
+            input_ids = input_ids[:, -1:]
         return {
-            "input_ids": None,  # inputs don't have to be defined, but still need to be passed to make Keras.layer.__call__ happy
-            "decoder_input_ids": inputs,  # inputs are the decoder_input_ids
-            "past_key_values": past_key_values,
+            "input_ids": None,  # needs to be passed to make Keras.layer.__call__ happy
+            "decoder_input_ids": input_ids,
+            "past_key_values": past,
             "encoder_outputs": encoder_outputs,
             "attention_mask": attention_mask,
-            "head_mask": head_mask,
-            "decoder_head_mask": decoder_head_mask,
             "use_cache": use_cache,
         }
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return self._shift_right(labels)
-    def _reorder_cache(self, past, beam_idx) -> Tuple:
+    def _reorder_cache(self, past, beam_idx):
         # if decoder past is not included in output
         # speedy decoding is disabled and no need to reorder
-        if len(past) < 2:
+        if past is None:
             logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
             return past
-        decoder_past = past[1]
-        past = (past[0],)
         reordered_decoder_past = ()
-        for layer_past_states in decoder_past:
+        for layer_past_states in past:
             # get the correct batch idx from layer past batch dim
             # batch dim of `past` is at 2nd position
             reordered_layer_past_states = ()
             for layer_past_state in layer_past_states:
                 # need to set correct `past` for each of the four key / value states
-                reordered_layer_past_states = reordered_layer_past_states + (tf.gather(layer_past_state, beam_idx),)
+                reordered_layer_past_states = reordered_layer_past_states + (
+                    tf.gather(layer_past_state, beam_idx, axis=0),
+                )
-            assert shape_list(reordered_layer_past_states[0]) == shape_list(layer_past_states[0])
+            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
             assert len(reordered_layer_past_states) == len(layer_past_states)
             reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
-        return past + (reordered_decoder_past,)
+        return reordered_decoder_past
@add_start_docstrings(
...
@@ -1058,15 +1058,22 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
             attentions=attns,
         )
-    def prepare_inputs_for_generation(self, inputs, past, **model_kwargs):
-        inputs = {"input_ids": inputs}
+    def prepare_inputs_for_generation(self, input_ids, past=None, **model_kwargs):
+        inputs = {}
         # if past is defined in model kwargs then use it for faster decoding
         if past:
             inputs["mems"] = past
+            inputs["input_ids"] = tf.expand_dims(input_ids[:, -1], axis=-1)
+        else:
+            inputs["input_ids"] = input_ids
         return inputs
+    @staticmethod
+    def _reorder_cache(mems: List[tf.Tensor], beam_idx: tf.Tensor) -> List[tf.Tensor]:
+        return [tf.gather(layer_past, beam_idx, axis=1) for layer_past in mems]
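Unlike the key/value caches above, the `axis=1` gather in the new `_reorder_cache` implies that TransfoXL `mems` carry the batch on the second axis. A toy check with assumed (mem_len, batch, d_model) shapes:

import tensorflow as tf

mem_len, batch, d_model = 5, 3, 4
mems = [tf.random.normal((mem_len, batch, d_model)) for _ in range(2)]  # one entry per layer
beam_idx = tf.constant([2, 0, 1])

reordered = [tf.gather(layer_past, beam_idx, axis=1) for layer_past in mems]
print(reordered[0].shape)                                          # (5, 3, 4): shape preserved
print(tf.reduce_all(reordered[0][:, 0] == mems[0][:, 2]).numpy())  # True: column 0 now holds old column 2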
@add_start_docstrings(
    """
...
@@ -722,45 +722,22 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
             cross_attentions=cross_attns,
         )
-    def prepare_inputs_for_generation(self, decoder_input_ids, past, use_cache=None, **kwargs):
-        if past is None or len(past) not in {1, 2}:
-            raise ValueError(f"past has to be an iterable of length 1,2 got {past}")
-        if len(past) == 1:
-            if not isinstance(past[0], tf.Tensor):
-                raise ValueError(f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}")
-            encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
-            past_key_values = None
-        else:
-            if len(past) != 2:
-                raise ValueError(
-                    "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
-                )
-            encoder_outputs, past_key_values = past
-            if isinstance(encoder_outputs, tuple):
-                if not isinstance(encoder_outputs[0], tf.Tensor):
-                    raise ValueError(
-                        f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
-                    )
-                encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
-            elif isinstance(encoder_outputs, tf.Tensor):
-                encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
-            if not past_key_values:
-                raise ValueError(
-                    f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
-                )
-            decoder_input_ids = decoder_input_ids[:, -1:]
-        if not isinstance(encoder_outputs, TFBaseModelOutput):
-            raise ValueError(f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}.")
-        return {
-            "pixel_values": None,  # encoder_outputs is defined. pixel_values not needed
-            "encoder_outputs": encoder_outputs,
-            "past_key_values": past_key_values,
-            "decoder_input_ids": decoder_input_ids,
-            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
-        }
+    def prepare_inputs_for_generation(
+        self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
+    ):
+        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
+        decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
+        input_dict = {
+            "pixel_values": None,  # needs to be passed to make Keras.layer.__call__ happy
+            "attention_mask": attention_mask,
+            "decoder_attention_mask": decoder_attention_mask,
+            "decoder_input_ids": decoder_inputs["input_ids"],
+            # TODO (joao): the `TFBaseModelOutput` wrapper should not be needed after the generate refactor is complete
+            "encoder_outputs": TFBaseModelOutput(last_hidden_state=encoder_outputs[0]),
+            "past_key_values": decoder_inputs["past_key_values"],
+            "use_cache": use_cache,
+        }
+        return input_dict
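The composite model now defers cache handling to its decoder. A minimal sketch of that delegation pattern with stand-in classes (names and shapes invented for illustration; the real method additionally wraps `encoder_outputs[0]` in `TFBaseModelOutput`, as shown above):

import tensorflow as tf

class ToyDecoder:
    # stand-in for the wrapped decoder's own prepare_inputs_for_generation
    def prepare_inputs_for_generation(self, input_ids, past=None):
        if past is not None:
            input_ids = input_ids[:, -1:]
        return {"input_ids": input_ids, "past_key_values": past}

class ToyEncoderDecoder:
    def __init__(self):
        self.decoder = ToyDecoder()

    def prepare_inputs_for_generation(self, input_ids, past=None, encoder_outputs=None, **kwargs):
        # build the composite input dict from whatever the decoder decided
        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
        return {
            "decoder_input_ids": decoder_inputs["input_ids"],
            "past_key_values": decoder_inputs["past_key_values"],
            "encoder_outputs": encoder_outputs,
        }

model = ToyEncoderDecoder()
out = model.prepare_inputs_for_generation(tf.constant([[1, 2, 3]]), past=("toy-cache",), encoder_outputs=("enc",))
print(out["decoder_input_ids"].shape)  # (1, 1): the slicing decision came from the decoder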
     def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
@@ -773,9 +750,4 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
     def _reorder_cache(self, past, beam_idx):
         # apply decoder cache reordering here
-        if len(past) == 1:
-            return past
-        encoder_outputs, past_key_values = past
-        return (encoder_outputs, self.decoder._reorder_cache(past_key_values, beam_idx))
+        return self.decoder._reorder_cache(past, beam_idx)
@@ -1246,17 +1246,17 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
         warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
         return self.name + "/" + self.lm_loss.name
-    def prepare_inputs_for_generation(self, inputs, past, use_mems=None, **kwargs):
+    def prepare_inputs_for_generation(self, inputs, past=None, use_mems=None, **kwargs):
         # Add dummy token at the end (no attention on this one)
+        effective_batch_size = inputs.shape[0]
+        dummy_token = tf.zeros((effective_batch_size, 1), dtype=inputs.dtype)
         # At every pass, the attention values for the new token and the two last generated tokens
         # are computed, the rest is reloaded from the `past` cache. A purely auto-regressive model would have
         # offset = 1; offset = 2 seems to have slightly better computation.
         offset = 2
-        effective_batch_size = inputs.shape[0]
-        dummy_token = tf.zeros((effective_batch_size, 1), dtype=inputs.dtype)
         if past:
             inputs = tf.concat([inputs[:, -offset:], dummy_token], axis=1)
         else:
@@ -1277,7 +1277,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
             "input_ids": inputs,
             "perm_mask": perm_mask,
             "target_mapping": target_mapping,
-            "use_mems": kwargs.get("use_mems"),
+            "use_mems": use_mems,
         }
         # if past is defined in model kwargs then use it for faster decoding
...
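A toy run of the XLNet input trick from the hunk above (token values assumed): with a cache, only the last `offset` real tokens are kept and a dummy token is appended for the model to predict into; without a cache, the dummy token is appended to the full sequence.

import tensorflow as tf

offset = 2
inputs = tf.constant([[11, 12, 13, 14]])
dummy_token = tf.zeros((inputs.shape[0], 1), dtype=inputs.dtype)

with_cache = tf.concat([inputs[:, -offset:], dummy_token], axis=1)
without_cache = tf.concat([inputs, dummy_token], axis=1)
print(with_cache.numpy())     # [[13 14  0]]
print(without_cache.numpy())  # [[11 12 13 14  0]]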