Unverified commit 8b240a06, authored by Yih-Dar and committed by GitHub

Add TFEncoderDecoderModel + Add cross-attention to some TF models (#13222)



* Add cross attentions to TFGPT2Model

* Add TFEncoderDecoderModel

* Add TFBaseModelOutputWithPoolingAndCrossAttentions

* Add cross attentions to TFBertModel

* Fix past or past_key_values argument issue

* Fix generation

* Fix save and load

* Add some checks and comments

* Clean the code that deals with past keys/values

* Add kwargs to processing_inputs

* Add serving_output to TFEncoderDecoderModel

* Some cleaning + fix use_cache value issue

* Fix tests + add bert2bert/bert2gpt2 tests

* Fix more tests

* Ignore crossattention.bias when loading GPT2 weights into TFGPT2

* Fix return_dict_in_generate in tf generation

* Fix is_token_logit_eos_token bug in tf generation

* Finalize the tests after fixing some bugs

* Fix another is_token_logit_eos_token bug in tf generation

* Add/Update docs

* Add TFBertEncoderDecoderModelTest

* Clean test script

* Add TFEncoderDecoderModel to the library

* Add cross attentions to TFRobertaModel

* Add TFRobertaEncoderDecoderModelTest

* make style

* Change the way of position_ids computation

* bug fix

* Fix copies in tf_albert

* Remove some copied from and apply some fix-copies

* Remove some copied

* Add cross attentions to some other TF models

* Remove encoder_hidden_states from TFLayoutLMModel.call for now

* Make style

* Fix TFRemBertForCausalLM

* Revert the change to longformer + Remove copies

* Revert the change to albert and convbert + Remove copies

* make quality

* make style

* Add TFRembertEncoderDecoderModelTest

* make quality and fix-copies

* test TFRobertaForCausalLM

* Fixes for failed tests

* Fixes for failed tests

* fix more tests

* Fixes for failed tests

* Fix Auto mapping order

* Fix TFRemBertEncoder return value

* fix tf_rembert

* Check copies are OK

* Fix "TFBaseModelOutputWithPastAndCrossAttentions is not defined" error

* Add TFEncoderDecoderModelSaveLoadTests

* fix tf weight loading

* check the change of use_cache

* Revert the change

* Add missing test_for_causal_lm for TFRobertaModelTest

* Try cleaning past

* fix _reorder_cache

* Revert some files to original versions

* Keep as many copies as possible

* Apply suggested changes - Use raise ValueError instead of assert

* Move import to top

* Fix wrong require_torch

* Replace more assert by raise ValueError

* Add test_pt_tf_model_equivalence (the test won't pass for now)

* add test for loading/saving

* finish

* finish

* Remove test_pt_tf_model_equivalence

* Update tf modeling template

* Remove pooling, added in the prev. commit, from MainLayer

* Update tf modeling test template

* Move inputs["use_cache"] = False to modeling_tf_utils.py

* Fix torch.Tensor in the comment

* fix use_cache

* Fix missing use_cache in ElectraConfig

* Add a note to from_pretrained

* Fix style

* Change test_encoder_decoder_save_load_from_encoder_decoder_from_pt

* Fix TFMLP (in TFGPT2) activation issue

* Fix None past_key_values value in serving_output

* Don't call get_encoderdecoder_model in TFEncoderDecoderModelTest.test_configuration_tie until we have a TF checkpoint on Hub

* Apply review suggestions - style for cross_attns in serving_output

* Apply review suggestions - change assert + docstrings

* break the error message to respect the char limit

* deprecate the argument past

* fix docstring style

* Update the encoder-decoder rst file

* fix Unknown interpreted text role "method"

* fix typo
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 26b6ef79
@@ -379,7 +379,7 @@ Flax), PyTorch, and/or TensorFlow.
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | ELECTRA | ✅ | ✅ | ✅ | ✅ | ✅ |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
-| Encoder decoder | ❌ | ❌ | ✅ | ❌ | ✅ |
+| Encoder decoder | ❌ | ❌ | ✅ | ✅ | ✅ |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | FairSeq Machine-Translation | ✅ | ❌ | ✅ | ❌ | ❌ |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
......
@@ -210,6 +210,13 @@ TFBaseModelOutputWithPooling
     :members:
+TFBaseModelOutputWithPoolingAndCrossAttentions
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: transformers.modeling_tf_outputs.TFBaseModelOutputWithPoolingAndCrossAttentions
+    :members:
 TFBaseModelOutputWithPast
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -217,6 +224,13 @@ TFBaseModelOutputWithPast
     :members:
+TFBaseModelOutputWithPastAndCrossAttentions
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: transformers.modeling_tf_outputs.TFBaseModelOutputWithPastAndCrossAttentions
+    :members:
 TFSeq2SeqModelOutput
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -231,6 +245,13 @@ TFCausalLMOutput
     :members:
+TFCausalLMOutputWithCrossAttentions
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: transformers.modeling_tf_outputs.TFCausalLMOutputWithCrossAttentions
+    :members:
 TFCausalLMOutputWithPast
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
@@ -27,6 +27,25 @@ An application of this architecture could be to leverage two pretrained :class:`
 and decoder for a summarization model as was shown in: `Text Summarization with Pretrained Encoders
 <https://arxiv.org/abs/1908.08345>`__ by Yang Liu and Mirella Lapata.
+The :meth:`~transformers.TFEncoderDecoderModel.from_pretrained` method currently doesn't support initializing the
+model from a PyTorch checkpoint. Passing ``from_pt=True`` to this method will throw an exception. If there are only
+PyTorch checkpoints for a particular encoder-decoder model, a workaround is:
+
+.. code-block::
+
+    >>> # a workaround to load from pytorch checkpoint
+    >>> _model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
+    >>> _model.encoder.save_pretrained("./encoder")
+    >>> _model.decoder.save_pretrained("./decoder")
+    >>> model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
+    ...     "./encoder", "./decoder", encoder_from_pt=True, decoder_from_pt=True
+    ... )
+    >>> # This is only for copying some specific attributes of this particular model.
+    >>> model.config = _model.config
+
+This model was contributed by `thomwolf <https://github.com/thomwolf>`__. This model's TensorFlow and Flax versions
+were contributed by `ydshieh <https://github.com/ydshieh>`__.
 EncoderDecoderConfig
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -42,6 +61,13 @@ EncoderDecoderModel
     :members: forward, from_encoder_decoder_pretrained
+TFEncoderDecoderModel
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: transformers.TFEncoderDecoderModel
+    :members: call, from_encoder_decoder_pretrained
 FlaxEncoderDecoderModel
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
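The new ``TFEncoderDecoderModel`` composes any TF encoder with any TF decoder that supports cross-attention. A minimal sketch of the intended usage, assuming ``bert-base-uncased`` as a stand-in checkpoint for both sides (the decoder gets freshly initialized cross-attention weights):

.. code-block:: python

    from transformers import BertTokenizer, TFEncoderDecoderModel

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    # Both checkpoint names are illustrative, not a recommendation.
    model = TFEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")

    input_ids = tokenizer("Hello, my dog is cute", return_tensors="tf").input_ids
    # Feed the same ids as decoder inputs just to exercise the forward pass.
    outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
    print(outputs.logits.shape)  # (batch_size, sequence_length, decoder vocab size)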
@@ -126,6 +126,13 @@ TFRobertaModel
     :members: call
+TFRobertaForCausalLM
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: transformers.TFRobertaForCausalLM
+    :members: call
 TFRobertaForMaskedLM
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
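``TFRobertaForCausalLM`` mirrors its PyTorch counterpart. A minimal sketch; ``roberta-base`` is only an example checkpoint, and ``is_decoder=True`` merely switches on the causal mask, it does not turn the checkpoint into a trained decoder:

.. code-block:: python

    from transformers import RobertaTokenizer, TFRobertaForCausalLM

    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    model = TFRobertaForCausalLM.from_pretrained("roberta-base", is_decoder=True)

    inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
    outputs = model(inputs)
    print(outputs.logits.shape)  # (1, sequence_length, vocab_size)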
@@ -1444,6 +1444,7 @@ if is_tf_available():
             "TFElectraPreTrainedModel",
         ]
     )
+    _import_structure["models.encoder_decoder"].append("TFEncoderDecoderModel")
     _import_structure["models.flaubert"].extend(
         [
             "TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1596,6 +1597,7 @@ if is_tf_available():
     _import_structure["models.roberta"].extend(
         [
             "TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "TFRobertaForCausalLM",
             "TFRobertaForMaskedLM",
             "TFRobertaForMultipleChoice",
             "TFRobertaForQuestionAnswering",
@@ -3096,6 +3098,7 @@ if TYPE_CHECKING:
         TFElectraModel,
         TFElectraPreTrainedModel,
     )
+    from .models.encoder_decoder import TFEncoderDecoderModel
    from .models.flaubert import (
         TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
         TFFlaubertForMultipleChoice,
@@ -3205,6 +3208,7 @@ if TYPE_CHECKING:
     )
     from .models.roberta import (
         TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
+        TFRobertaForCausalLM,
         TFRobertaForMaskedLM,
         TFRobertaForMultipleChoice,
         TFRobertaForQuestionAnswering,
......
@@ -76,6 +76,7 @@ from . import (
     TFLxmertForPreTraining,
     TFLxmertVisualFeatureEncoder,
     TFOpenAIGPTLMHeadModel,
+    TFRobertaForCausalLM,
     TFRobertaForMaskedLM,
     TFRobertaForSequenceClassification,
     TFT5ForConditionalGeneration,
@@ -215,6 +216,7 @@ MODEL_CLASSES = {
     ),
     "roberta": (
         RobertaConfig,
+        TFRobertaForCausalLM,
         TFRobertaForMaskedLM,
         RobertaForMaskedLM,
         ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
......
@@ -650,7 +650,11 @@ class TFGenerationMixin:
         # current position and vocab size
         cur_len = shape_list(input_ids)[1]  # unused
-        vocab_size = self.config.vocab_size
+        vocab_size = getattr(self.config, "vocab_size", None)
+        if vocab_size is None and self.config.is_encoder_decoder:
+            decoder_config = getattr(self.config, "decoder", None)
+            if decoder_config is not None:
+                vocab_size = getattr(self.config.decoder, "vocab_size", None)
         # set effective batch size and effective batch multiplier according to do_sample
         if do_sample:
@@ -678,6 +682,7 @@ class TFGenerationMixin:
                 attention_mask=attention_mask,
                 output_attentions=output_attentions,
                 output_hidden_states=output_hidden_states,
+                return_dict=return_dict_in_generate,
             )
             if return_dict_in_generate:
                 if output_attentions:
@@ -911,7 +916,7 @@ class TFGenerationMixin:
             if eos_token_id is not None and cur_len < min_length:
                 # create eos_token_id boolean mask
                 is_token_logit_eos_token = tf.convert_to_tensor(
-                    [True if token is eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
+                    [True if token == eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
                 )
                 eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
@@ -1142,7 +1147,7 @@ class TFGenerationMixin:
         num_batch_hypotheses = batch_size * num_beams
             is_token_logit_eos_token = tf.convert_to_tensor(
-                [True if token is eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
+                [True if token == eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
             )
             eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [num_batch_hypotheses, vocab_size])
@@ -1446,11 +1451,17 @@ class TFGenerationMixin:
         Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to adjust the logits in
         the generate method.
         """
+        vocab_size = getattr(self.config, "vocab_size", None)
+        if vocab_size is None and self.config.is_encoder_decoder:
+            decoder_config = getattr(self.config, "decoder", None)
+            if decoder_config is not None:
+                vocab_size = getattr(self.config.decoder, "vocab_size", None)
         if cur_len == 1 and forced_bos_token_id is not None:
-            vocab_range = tf.constant(range(self.config.vocab_size))
+            vocab_range = tf.constant(range(vocab_size))
             return tf.where(vocab_range != forced_bos_token_id, -1e8, logits)
         elif cur_len == max_length - 1 and forced_eos_token_id is not None:
-            vocab_range = tf.constant(range(self.config.vocab_size))
+            vocab_range = tf.constant(range(vocab_size))
             return tf.where(vocab_range != forced_eos_token_id, -1e8, logits)
         else:
             return logits
......
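The reason for the ``getattr`` fallback above is that the composite encoder-decoder config defines no top-level ``vocab_size``, so generation has to read it from the decoder sub-config. A quick check, assuming the default BERT and GPT-2 configs:

.. code-block:: python

    from transformers import BertConfig, EncoderDecoderConfig, GPT2Config

    config = EncoderDecoderConfig.from_encoder_decoder_configs(BertConfig(), GPT2Config())

    # Same lookup order as the generation code above.
    vocab_size = getattr(config, "vocab_size", None)
    if vocab_size is None and config.is_encoder_decoder:
        decoder_config = getattr(config, "decoder", None)
        if decoder_config is not None:
            vocab_size = getattr(decoder_config, "vocab_size", None)

    print(vocab_size)  # 50257, GPT-2's default vocabulary size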
@@ -80,6 +80,54 @@ class TFBaseModelOutputWithPooling(ModelOutput):
     attentions: Optional[Tuple[tf.Tensor]] = None
+@dataclass
+class TFBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
+    """
+    Base class for model's outputs that also contains a pooling of the last hidden states.
+
+    Args:
+        last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        pooler_output (:obj:`tf.Tensor` of shape :obj:`(batch_size, hidden_size)`):
+            Last layer hidden-state of the first token of the sequence (classification token) further processed by a
+            Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
+            prediction (classification) objective during pretraining.
+            This output is usually *not* a good summary of the semantic content of the input, you're often better with
+            averaging or pooling the sequence of hidden-states for the whole input sequence.
+        past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
+            List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, batch_size,
+            num_heads, sequence_length, embed_size_per_head)`).
+            Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
+            :obj:`past_key_values` input) to speed up sequential decoding.
+        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+            Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
+            shape :obj:`(batch_size, sequence_length, hidden_size)`.
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        cross_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
+    """
+
+    last_hidden_state: tf.Tensor = None
+    pooler_output: tf.Tensor = None
+    past_key_values: Optional[List[tf.Tensor]] = None
+    hidden_states: Optional[Tuple[tf.Tensor]] = None
+    attentions: Optional[Tuple[tf.Tensor]] = None
+    cross_attentions: Optional[Tuple[tf.Tensor]] = None
 @dataclass
 class TFBaseModelOutputWithPast(ModelOutput):
     """
@@ -317,6 +365,49 @@ class TFCausalLMOutputWithPast(ModelOutput):
     attentions: Optional[Tuple[tf.Tensor]] = None
+@dataclass
+class TFCausalLMOutputWithCrossAttentions(ModelOutput):
+    """
+    Base class for causal language model (or autoregressive) outputs.
+
+    Args:
+        loss (:obj:`tf.Tensor` of shape :obj:`(n,)`, `optional`, where n is the number of non-masked labels, returned when :obj:`labels` is provided):
+            Language modeling loss (for next-token prediction).
+        logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+            Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
+            shape :obj:`(batch_size, sequence_length, hidden_size)`.
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        cross_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
+        past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
+            List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, batch_size,
+            num_heads, sequence_length, embed_size_per_head)`).
+            Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
+            :obj:`past_key_values` input) to speed up sequential decoding.
+    """
+
+    loss: Optional[tf.Tensor] = None
+    logits: tf.Tensor = None
+    past_key_values: Optional[List[tf.Tensor]] = None
+    hidden_states: Optional[Tuple[tf.Tensor]] = None
+    attentions: Optional[Tuple[tf.Tensor]] = None
+    cross_attentions: Optional[Tuple[tf.Tensor]] = None
 @dataclass
 class TFMaskedLMOutput(ModelOutput):
     """
......
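The two new output classes are what a TF model returns once it is configured as a decoder with cross-attention. A minimal sketch with a randomly initialized ``TFBertModel``; the shapes and the fake encoder states are purely illustrative:

.. code-block:: python

    import tensorflow as tf
    from transformers import BertConfig, TFBertModel

    config = BertConfig(is_decoder=True, add_cross_attention=True)
    model = TFBertModel(config)  # randomly initialized, no checkpoint needed for the sketch

    input_ids = tf.constant([[101, 7592, 102]])
    encoder_hidden_states = tf.random.normal((1, 5, config.hidden_size))

    outputs = model(
        input_ids,
        encoder_hidden_states=encoder_hidden_states,
        output_attentions=True,
        use_cache=True,
    )
    print(type(outputs).__name__)         # TFBaseModelOutputWithPoolingAndCrossAttentions
    print(len(outputs.cross_attentions))  # one tensor per layer
    print(len(outputs.past_key_values))   # cached key/values, also one entry per layer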
@@ -296,7 +296,9 @@ def booleans_processing(config, **kwargs):
         )
         if "use_cache" in kwargs:
-            final_booleans["use_cache"] = kwargs["use_cache"] if kwargs["use_cache"] is not None else config.use_cache
+            final_booleans["use_cache"] = (
+                kwargs["use_cache"] if kwargs["use_cache"] is not None else getattr(config, "use_cache", None)
+            )
     else:
         if (
             kwargs["output_attentions"] not in (None, config.output_attentions)
@@ -318,7 +320,7 @@ def booleans_processing(config, **kwargs):
         final_booleans["return_dict"] = True
         if "use_cache" in kwargs:
-            final_booleans["use_cache"] = config.use_cache
+            final_booleans["use_cache"] = getattr(config, "use_cache", None)
     return final_booleans
@@ -362,6 +364,15 @@ def input_processing(func, config, input_ids, **kwargs):
         )
         output["past_key_values"] = kwargs["kwargs_call"].pop("decoder_cached_states")
+    if "past" in kwargs["kwargs_call"] and "past_key_values" in kwargs:
+        warnings.warn(
+            "The `past` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
+            FutureWarning,
+        )
+        kwargs["past_key_values"] = kwargs["kwargs_call"].pop("past")
+    elif "past_key_values" in kwargs["kwargs_call"] and "past" in kwargs:
+        kwargs["past"] = kwargs["kwargs_call"].pop("past_key_values")
     if len(kwargs["kwargs_call"]) > 0:
         raise ValueError(
             f"The following keyword arguments are not supported by this model: {list(kwargs['kwargs_call'].keys())}."
......
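The shim added to ``input_processing`` lets callers use the newer ``past_key_values`` name even when a TF model's ``call`` signature still exposes ``past`` (and maps the deprecated ``past`` spelling the other way, with a ``FutureWarning``). A minimal sketch of the preferred call pattern, assuming the ``gpt2`` checkpoint as an example:

.. code-block:: python

    import tensorflow as tf
    from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = TFGPT2LMHeadModel.from_pretrained("gpt2")

    input_ids = tokenizer("Hello, my dog", return_tensors="tf").input_ids
    outputs = model(input_ids, use_cache=True)

    # Greedily pick the next token and reuse the cache under the new argument name;
    # input_processing maps it onto the model's `past` parameter when needed.
    next_token = tf.argmax(outputs.logits[:, -1, :], axis=-1, output_type=tf.int32)[:, None]
    next_outputs = model(next_token, past_key_values=outputs.past_key_values)
    print(next_outputs.logits.shape)  # (1, 1, vocab_size)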
@@ -157,6 +157,7 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
         position_ids: tf.Tensor = None,
         token_type_ids: tf.Tensor = None,
         inputs_embeds: tf.Tensor = None,
+        past_key_values_length=0,
         training: bool = False,
     ) -> tf.Tensor:
         """
@@ -177,7 +178,9 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
             token_type_ids = tf.fill(dims=input_shape, value=0)
         if position_ids is None:
-            position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)
+            position_ids = tf.expand_dims(
+                tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0
+            )
         position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
         position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
@@ -829,7 +832,6 @@ class TFAlbertModel(TFAlbertPreTrainedModel):
         return outputs
-    # Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output
     def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
         hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
         attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
......
@@ -133,6 +133,7 @@ TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
         # Model for Causal LM mapping
         ("rembert", "TFRemBertForCausalLM"),
         ("roformer", "TFRoFormerForCausalLM"),
+        ("roberta", "TFRobertaForCausalLM"),
         ("bert", "TFBertLMHeadModel"),
         ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
         ("gpt2", "TFGPT2LMHeadModel"),
@@ -181,6 +182,7 @@ TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
         ("blenderbot", "TFBlenderbotForConditionalGeneration"),
         ("blenderbot-small", "TFBlenderbotSmallForConditionalGeneration"),
         ("bart", "TFBartForConditionalGeneration"),
+        ("encoder-decoder", "TFEncoderDecoderModel"),
     ]
 )
......
@@ -110,6 +110,7 @@ class TFConvBertEmbeddings(tf.keras.layers.Layer):
         position_ids: tf.Tensor = None,
         token_type_ids: tf.Tensor = None,
         inputs_embeds: tf.Tensor = None,
+        past_key_values_length=0,
         training: bool = False,
     ) -> tf.Tensor:
         """
@@ -130,7 +131,9 @@ class TFConvBertEmbeddings(tf.keras.layers.Layer):
             token_type_ids = tf.fill(dims=input_shape, value=0)
         if position_ids is None:
-            position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)
+            position_ids = tf.expand_dims(
+                tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0
+            )
         position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
         position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
......
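The same change is applied to several TF embedding layers: position ids now start at ``past_key_values_length`` instead of 0, so cached decoding keeps counting positions rather than restarting. A small numeric sketch:

.. code-block:: python

    import tensorflow as tf

    past_key_values_length = 4   # positions already covered by the cache
    input_shape = (1, 3)         # (batch_size, number of new tokens)

    position_ids = tf.expand_dims(
        tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0
    )
    print(position_ids.numpy())  # [[4 5 6]]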
@@ -104,6 +104,9 @@ class ElectraConfig(PretrainedConfig):
             <https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
             `Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
             <https://arxiv.org/abs/2009.13658>`__.
+        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if ``config.is_decoder=True``.
         classifier_dropout (:obj:`float`, `optional`):
             The dropout ratio for the classification head.
@@ -143,6 +146,7 @@ class ElectraConfig(PretrainedConfig):
         summary_last_dropout=0.1,
         pad_token_id=0,
         position_embedding_type="absolute",
+        use_cache=True,
         classifier_dropout=None,
         **kwargs
     ):
@@ -167,4 +171,5 @@ class ElectraConfig(PretrainedConfig):
         self.summary_activation = summary_activation
         self.summary_last_dropout = summary_last_dropout
         self.position_embedding_type = position_embedding_type
+        self.use_cache = use_cache
         self.classifier_dropout = classifier_dropout
@@ -18,7 +18,7 @@
 from typing import TYPE_CHECKING
-from ...file_utils import _LazyModule, is_flax_available, is_torch_available
+from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
 _import_structure = {
@@ -28,6 +28,9 @@ _import_structure = {
 if is_torch_available():
     _import_structure["modeling_encoder_decoder"] = ["EncoderDecoderModel"]
+if is_tf_available():
+    _import_structure["modeling_tf_encoder_decoder"] = ["TFEncoderDecoderModel"]
 if is_flax_available():
     _import_structure["modeling_flax_encoder_decoder"] = ["FlaxEncoderDecoderModel"]
@@ -37,6 +40,9 @@ if TYPE_CHECKING:
     if is_torch_available():
         from .modeling_encoder_decoder import EncoderDecoderModel
+    if is_tf_available():
+        from .modeling_tf_encoder_decoder import TFEncoderDecoderModel
     if is_flax_available():
         from .modeling_flax_encoder_decoder import FlaxEncoderDecoderModel
......
@@ -24,8 +24,8 @@ import tensorflow as tf
 from ...activations_tf import get_tf_activation
 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
 from ...modeling_tf_outputs import (
-    TFBaseModelOutput,
-    TFBaseModelOutputWithPooling,
+    TFBaseModelOutputWithPastAndCrossAttentions,
+    TFBaseModelOutputWithPoolingAndCrossAttentions,
     TFMaskedLMOutput,
     TFSequenceClassifierOutput,
     TFTokenClassifierOutput,
@@ -216,6 +216,8 @@ class TFLayoutLMSelfAttention(tf.keras.layers.Layer):
         )
         self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
+        self.is_decoder = config.is_decoder
     def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
         # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
         tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
@@ -228,16 +230,49 @@ class TFLayoutLMSelfAttention(tf.keras.layers.Layer):
         hidden_states: tf.Tensor,
         attention_mask: tf.Tensor,
         head_mask: tf.Tensor,
+        encoder_hidden_states: tf.Tensor,
+        encoder_attention_mask: tf.Tensor,
+        past_key_value: Tuple[tf.Tensor],
         output_attentions: bool,
         training: bool = False,
     ) -> Tuple[tf.Tensor]:
         batch_size = shape_list(hidden_states)[0]
         mixed_query_layer = self.query(inputs=hidden_states)
-        mixed_key_layer = self.key(inputs=hidden_states)
-        mixed_value_layer = self.value(inputs=hidden_states)
+
+        # If this is instantiated as a cross-attention module, the keys
+        # and values come from an encoder; the attention mask needs to be
+        # such that the encoder's padding tokens are not attended to.
+        is_cross_attention = encoder_hidden_states is not None
+
+        if is_cross_attention and past_key_value is not None:
+            # reuse k,v, cross_attentions
+            key_layer = past_key_value[0]
+            value_layer = past_key_value[1]
+            attention_mask = encoder_attention_mask
+        elif is_cross_attention:
+            key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size)
+            value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size)
+            attention_mask = encoder_attention_mask
+        elif past_key_value is not None:
+            key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
+            value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
+            key_layer = tf.concat([past_key_value[0], key_layer], axis=2)
+            value_layer = tf.concat([past_key_value[1], value_layer], axis=2)
+        else:
+            key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
+            value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
         query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
-        key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
-        value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_layer, value_layer)
         # Take the dot product between "query" and "key" to get the raw attention scores.
         # (batch size, num_heads, seq_len_q, seq_len_k)
@@ -267,6 +302,8 @@ class TFLayoutLMSelfAttention(tf.keras.layers.Layer):
         attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
         outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
+        if self.is_decoder:
+            outputs = outputs + (past_key_value,)
         return outputs
@@ -305,6 +342,9 @@ class TFLayoutLMAttention(tf.keras.layers.Layer):
         input_tensor: tf.Tensor,
         attention_mask: tf.Tensor,
         head_mask: tf.Tensor,
+        encoder_hidden_states: tf.Tensor,
+        encoder_attention_mask: tf.Tensor,
+        past_key_value: Tuple[tf.Tensor],
         output_attentions: bool,
         training: bool = False,
     ) -> Tuple[tf.Tensor]:
@@ -312,13 +352,17 @@ class TFLayoutLMAttention(tf.keras.layers.Layer):
             hidden_states=input_tensor,
             attention_mask=attention_mask,
             head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_value=past_key_value,
             output_attentions=output_attentions,
             training=training,
         )
         attention_output = self.dense_output(
             hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
         )
-        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
+        # add attentions (possibly with past_key_value) if we output them
+        outputs = (attention_output,) + self_outputs[1:]
         return outputs
@@ -369,6 +413,12 @@ class TFLayoutLMLayer(tf.keras.layers.Layer):
         super().__init__(**kwargs)
         self.attention = TFLayoutLMAttention(config, name="attention")
+        self.is_decoder = config.is_decoder
+        self.add_cross_attention = config.add_cross_attention
+        if self.add_cross_attention:
+            if not self.is_decoder:
+                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
+            self.crossattention = TFLayoutLMAttention(config, name="crossattention")
         self.intermediate = TFLayoutLMIntermediate(config, name="intermediate")
         self.bert_output = TFLayoutLMOutput(config, name="output")
@@ -377,22 +427,69 @@ class TFLayoutLMLayer(tf.keras.layers.Layer):
         hidden_states: tf.Tensor,
         attention_mask: tf.Tensor,
         head_mask: tf.Tensor,
+        encoder_hidden_states: Optional[tf.Tensor],
+        encoder_attention_mask: Optional[tf.Tensor],
+        past_key_value: Optional[Tuple[tf.Tensor]],
         output_attentions: bool,
         training: bool = False,
     ) -> Tuple[tf.Tensor]:
-        attention_outputs = self.attention(
+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+        self_attention_outputs = self.attention(
             input_tensor=hidden_states,
             attention_mask=attention_mask,
             head_mask=head_mask,
+            encoder_hidden_states=None,
+            encoder_attention_mask=None,
+            past_key_value=self_attn_past_key_value,
             output_attentions=output_attentions,
             training=training,
         )
-        attention_output = attention_outputs[0]
+        attention_output = self_attention_outputs[0]
+
+        # if decoder, the last output is tuple of self-attn cache
+        if self.is_decoder:
+            outputs = self_attention_outputs[1:-1]
+            present_key_value = self_attention_outputs[-1]
+        else:
+            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
+
+        cross_attn_present_key_value = None
+        if self.is_decoder and encoder_hidden_states is not None:
+            if not hasattr(self, "crossattention"):
+                raise ValueError(
+                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers "
+                    "by setting `config.add_cross_attention=True`"
+                )
+
+            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
+            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+            cross_attention_outputs = self.crossattention(
+                input_tensor=attention_output,
+                attention_mask=attention_mask,
+                head_mask=head_mask,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                past_key_value=cross_attn_past_key_value,
+                output_attentions=output_attentions,
+                training=training,
+            )
+            attention_output = cross_attention_outputs[0]
+            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights
+
+            # add cross-attn cache to positions 3,4 of present_key_value tuple
+            cross_attn_present_key_value = cross_attention_outputs[-1]
+            present_key_value = present_key_value + cross_attn_present_key_value
+
         intermediate_output = self.intermediate(hidden_states=attention_output)
         layer_output = self.bert_output(
             hidden_states=intermediate_output, input_tensor=attention_output, training=training
         )
-        outputs = (layer_output,) + attention_outputs[1:]  # add attentions if we output them
+        outputs = (layer_output,) + outputs  # add attentions if we output them
+
+        # if decoder, return the attn key/values as the last output
+        if self.is_decoder:
+            outputs = outputs + (present_key_value,)
         return outputs
@@ -401,7 +498,7 @@ class TFLayoutLMLayer(tf.keras.layers.Layer):
 class TFLayoutLMEncoder(tf.keras.layers.Layer):
     def __init__(self, config: LayoutLMConfig, **kwargs):
         super().__init__(**kwargs)
+        self.config = config
         self.layer = [TFLayoutLMLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
     def call(
@@ -409,39 +506,61 @@ class TFLayoutLMEncoder(tf.keras.layers.Layer):
         hidden_states: tf.Tensor,
         attention_mask: tf.Tensor,
         head_mask: tf.Tensor,
+        encoder_hidden_states: Optional[tf.Tensor],
+        encoder_attention_mask: Optional[tf.Tensor],
+        past_key_values: Optional[Tuple[Tuple[tf.Tensor]]],
+        use_cache: Optional[bool],
         output_attentions: bool,
         output_hidden_states: bool,
         return_dict: bool,
         training: bool = False,
-    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
+    ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
         all_hidden_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None
+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+        next_decoder_cache = () if use_cache else None
         for i, layer_module in enumerate(self.layer):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
+            past_key_value = past_key_values[i] if past_key_values is not None else None
             layer_outputs = layer_module(
                 hidden_states=hidden_states,
                 attention_mask=attention_mask,
                 head_mask=head_mask[i],
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                past_key_value=past_key_value,
                 output_attentions=output_attentions,
                 training=training,
             )
             hidden_states = layer_outputs[0]
+            if use_cache:
+                next_decoder_cache += (layer_outputs[-1],)
             if output_attentions:
                 all_attentions = all_attentions + (layer_outputs[1],)
+                if self.config.add_cross_attention and encoder_hidden_states is not None:
+                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
         # Add last layer
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
         if not return_dict:
-            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
+            return tuple(
+                v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None
+            )
-        return TFBaseModelOutput(
-            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
+        return TFBaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=next_decoder_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_attentions,
+            cross_attentions=all_cross_attentions,
         )
@@ -585,12 +704,14 @@ class TFLayoutLMMainLayer(tf.keras.layers.Layer):
         position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
         head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
         inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         training: bool = False,
         **kwargs,
-    ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
+    ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
         inputs = input_processing(
             func=self.call,
             config=self.config,
@@ -665,6 +786,11 @@ class TFLayoutLMMainLayer(tf.keras.layers.Layer):
             hidden_states=embedding_output,
             attention_mask=extended_attention_mask,
             head_mask=inputs["head_mask"],
+            # Need to pass these required positional arguments to `Encoder`
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=None,
+            past_key_values=None,
+            use_cache=False,
             output_attentions=inputs["output_attentions"],
             output_hidden_states=inputs["output_hidden_states"],
             return_dict=inputs["return_dict"],
@@ -680,11 +806,12 @@ class TFLayoutLMMainLayer(tf.keras.layers.Layer):
                 pooled_output,
             ) + encoder_outputs[1:]
-        return TFBaseModelOutputWithPooling(
+        return TFBaseModelOutputWithPoolingAndCrossAttentions(
             last_hidden_state=sequence_output,
             pooler_output=pooled_output,
             hidden_states=encoder_outputs.hidden_states,
             attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
         )
@@ -802,7 +929,9 @@ class TFLayoutLMModel(TFLayoutLMPreTrainedModel):
         self.layoutlm = TFLayoutLMMainLayer(config, name="layoutlm")
     @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(
+        output_type=TFBaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC
+    )
     def call(
         self,
         input_ids: Optional[TFModelInputType] = None,
@@ -812,12 +941,14 @@ class TFLayoutLMModel(TFLayoutLMPreTrainedModel):
         position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
         head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
         inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
+        encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         training: Optional[bool] = False,
         **kwargs,
-    ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
+    ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
         r"""
         Returns:
@@ -859,6 +990,8 @@ class TFLayoutLMModel(TFLayoutLMPreTrainedModel):
             position_ids=position_ids,
            head_mask=head_mask,
             inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
@@ -881,15 +1014,25 @@ class TFLayoutLMModel(TFLayoutLMPreTrainedModel):
         return outputs
-    def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
+    # Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output
+    def serving_output(
+        self, output: TFBaseModelOutputWithPoolingAndCrossAttentions
+    ) -> TFBaseModelOutputWithPoolingAndCrossAttentions:
+        output_cache = self.config.use_cache and self.config.is_decoder
+        pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None
         hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
         attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
+        cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None
+        if not (self.config.output_attentions and self.config.add_cross_attention):
+            cross_attns = None
-        return TFBaseModelOutputWithPooling(
+        return TFBaseModelOutputWithPoolingAndCrossAttentions(
             last_hidden_state=output.last_hidden_state,
             pooler_output=output.pooler_output,
+            past_key_values=pkv,
             hidden_states=hs,
             attentions=attns,
+            cross_attentions=cross_attns,
         )
......
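The decoder layer above stores its cache as a 4-tuple per layer: self-attention key/value states at positions 1,2 and cross-attention key/value states at positions 3,4, with new self-attention states concatenated along the sequence axis. A shape-only sketch; all sizes are illustrative:

.. code-block:: python

    import tensorflow as tf

    batch, heads, head_dim = 1, 12, 64
    self_key = tf.random.normal((batch, heads, 4, head_dim))    # 4 cached decoder positions
    self_value = tf.random.normal((batch, heads, 4, head_dim))
    cross_key = tf.random.normal((batch, heads, 7, head_dim))   # 7 encoder positions, computed once
    cross_value = tf.random.normal((batch, heads, 7, head_dim))
    past_key_value = (self_key, self_value, cross_key, cross_value)

    new_key = tf.random.normal((batch, heads, 1, head_dim))     # projection of one new token
    key_layer = tf.concat([past_key_value[0], new_key], axis=2)
    print(key_layer.shape)  # (1, 12, 5, 64)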
@@ -510,7 +510,7 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer):
         super().build(input_shape)
-    def create_position_ids_from_input_ids(self, input_ids):
+    def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0):
         """
         Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
         symbols are ignored. This is modified from fairseq's `utils.make_positions`.
@@ -520,11 +520,19 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer):
         Returns: tf.Tensor
         """
         mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype)
-        incremental_indices = tf.math.cumsum(mask, axis=1) * mask
+        incremental_indices = (tf.math.cumsum(mask, axis=1) + past_key_values_length) * mask
         return incremental_indices + self.padding_idx
-    def call(self, input_ids=None, position_ids=None, token_type_ids=None, inputs_embeds=None, training=False):
+    def call(
+        self,
+        input_ids=None,
+        position_ids=None,
+        token_type_ids=None,
+        inputs_embeds=None,
+        past_key_values_length=0,
+        training=False,
+    ):
         """
         Applies embedding based on inputs tensor.
@@ -544,7 +552,9 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer):
         if position_ids is None:
             if input_ids is not None:
                 # Create the position ids from the input token ids. Any padded tokens remain padded.
-                position_ids = self.create_position_ids_from_input_ids(input_ids=input_ids)
+                position_ids = self.create_position_ids_from_input_ids(
+                    input_ids=input_ids, past_key_values_length=past_key_values_length
+                )
             else:
                 position_ids = tf.expand_dims(
                     tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1), axis=0
......
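For RoBERTa-style embeddings such as Longformer's, position ids must also skip padding: only non-padding tokens are counted, the count starts at ``padding_idx + 1``, and it is shifted by the cached length. A small numeric sketch:

.. code-block:: python

    import tensorflow as tf

    padding_idx = 1
    past_key_values_length = 0
    input_ids = tf.constant([[5, 7, 9, 1, 1]])  # 1 is the padding token id

    mask = tf.cast(tf.math.not_equal(input_ids, padding_idx), dtype=input_ids.dtype)
    incremental_indices = (tf.math.cumsum(mask, axis=1) + past_key_values_length) * mask
    position_ids = incremental_indices + padding_idx
    print(position_ids.numpy())  # [[2 3 4 1 1]]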
@@ -982,7 +982,6 @@ class TFMobileBertModel(TFMobileBertPreTrainedModel):
         return outputs
-    # Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output
     def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
         hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
         attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
......