Unverified Commit 8b240a06 authored by Yih-Dar, committed by GitHub

Add TFEncoderDecoderModel + Add cross-attention to some TF models (#13222)



* Add cross attentions to TFGPT2Model

* Add TFEncoderDecoderModel

* Add TFBaseModelOutputWithPoolingAndCrossAttentions

* Add cross attentions to TFBertModel

* Fix past or past_key_values argument issue

* Fix generation

* Fix save and load

* Add some checks and comments

* Clean the code that deals with past keys/values

* Add kwargs to processing_inputs

* Add serving_output to TFEncoderDecoderModel

* Some cleaning + fix use_cache value issue

* Fix tests + add bert2bert/bert2gpt2 tests

* Fix more tests

* Ignore crossattention.bias when loading GPT2 weights into TFGPT2

* Fix return_dict_in_generate in tf generation

* Fix is_token_logit_eos_token bug in tf generation

* Finalize the tests after fixing some bugs

* Fix another is_token_logit_eos_token bug in tf generation

* Add/Update docs

* Add TFBertEncoderDecoderModelTest

* Clean test script

* Add TFEncoderDecoderModel to the library

* Add cross attentions to TFRobertaModel

* Add TFRobertaEncoderDecoderModelTest

* make style

* Change the way of position_ids computation

* bug fix

* Fix copies in tf_albert

* Remove some copied from and apply some fix-copies

* Remove some copied

* Add cross attentions to some other TF models

* Remove encoder_hidden_states from TFLayoutLMModel.call for now

* Make style

* Fix TFRemBertForCausalLM

* Revert the change to longformer + Remove copies

* Revert the change to albert and convbert + Remove copies

* make quality

* make style

* Add TFRembertEncoderDecoderModelTest

* make quality and fix-copies

* test TFRobertaForCausalLM

* Fixes for failed tests

* Fixes for failed tests

* fix more tests

* Fixes for failed tests

* Fix Auto mapping order

* Fix TFRemBertEncoder return value

* fix tf_rembert

* Check copies are OK

* Fix "TFBaseModelOutputWithPastAndCrossAttentions is not defined" error

* Add TFEncoderDecoderModelSaveLoadTests

* fix tf weight loading

* check the change of use_cache

* Revert the change

* Add missing test_for_causal_lm for TFRobertaModelTest

* Try cleaning past

* fix _reorder_cache

* Revert some files to original versions

* Keep as many copies as possible

* Apply suggested changes - Use raise ValueError instead of assert

* Move import to top

* Fix wrong require_torch

* Replace more assert by raise ValueError

* Add test_pt_tf_model_equivalence (the test won't pass for now)

* add test for loading/saving

* finish

* finish

* Remove test_pt_tf_model_equivalence

* Update tf modeling template

* Remove pooling, added in the prev. commit, from MainLayer

* Update tf modeling test template

* Move inputs["use_cache"] = False to modeling_tf_utils.py

* Fix torch.Tensor in the comment

* fix use_cache

* Fix missing use_cache in ElectraConfig

* Add a note to from_pretrained

* Fix style

* Change test_encoder_decoder_save_load_from_encoder_decoder_from_pt

* Fix TFMLP (in TFGPT2) activation issue

* Fix None past_key_values value in serving_output

* Don't call get_encoderdecoder_model in TFEncoderDecoderModelTest.test_configuration_tie until we have a TF checkpoint on Hub

* Apply review suggestions - style for cross_attns in serving_output

* Apply review suggestions - change assert + docstrings

* break the error message to respect the char limit

* deprecate the argument past

* fix docstring style

* Update the encoder-decoder rst file

* fix Unknown interpreted text role "method"

* fix typo
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 26b6ef79
......@@ -379,7 +379,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| ELECTRA | ✅ | ✅ | ✅ | ✅ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Encoder decoder | ❌ | ❌ | ✅ | ❌ | ✅ |
| Encoder decoder | ❌ | ❌ | ✅ | ✅ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| FairSeq Machine-Translation | ✅ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
......
......@@ -210,6 +210,13 @@ TFBaseModelOutputWithPooling
:members:
TFBaseModelOutputWithPoolingAndCrossAttentions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.modeling_tf_outputs.TFBaseModelOutputWithPoolingAndCrossAttentions
:members:
TFBaseModelOutputWithPast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......@@ -217,6 +224,13 @@ TFBaseModelOutputWithPast
:members:
TFBaseModelOutputWithPastAndCrossAttentions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.modeling_tf_outputs.TFBaseModelOutputWithPastAndCrossAttentions
:members:
TFSeq2SeqModelOutput
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......@@ -231,6 +245,13 @@ TFCausalLMOutput
:members:
TFCausalLMOutputWithCrossAttentions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.modeling_tf_outputs.TFCausalLMOutputWithCrossAttentions
:members:
TFCausalLMOutputWithPast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
......@@ -27,6 +27,25 @@ An application of this architecture could be to leverage two pretrained :class:`
and decoder for a summarization model as was shown in: `Text Summarization with Pretrained Encoders
<https://arxiv.org/abs/1908.08345>`__ by Yang Liu and Mirella Lapata.
The :meth:`~transformers.TFEncoderDecoderModel.from_pretrained` method currently doesn't support initializing the model from a
PyTorch checkpoint. Passing ``from_pt=True`` to this method will throw an exception. If there are only PyTorch
checkpoints for a particular encoder-decoder model, a workaround is:
.. code-block::
>>> # a workaround to load from the PyTorch checkpoint
>>> from transformers import EncoderDecoderModel, TFEncoderDecoderModel
>>> _model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
>>> _model.encoder.save_pretrained("./encoder")
>>> _model.decoder.save_pretrained("./decoder")
>>> model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
... "./encoder", "./decoder", encoder_from_pt=True, decoder_from_pt=True
... )
>>> # This is only for copying some specific attributes of this particular model.
>>> model.config = _model.config
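After this one-time conversion, the resulting model can be saved as a regular TensorFlow checkpoint and reloaded directly, so the PyTorch detour is only needed once (the local path below is only illustrative):
.. code-block::
>>> model.save_pretrained("./tf_bert2bert-cnn_dailymail-fp16")
>>> model = TFEncoderDecoderModel.from_pretrained("./tf_bert2bert-cnn_dailymail-fp16")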
This model was contributed by `thomwolf <https://github.com/thomwolf>`__. This model's TensorFlow and Flax versions
were contributed by `ydshieh <https://github.com/ydshieh>`__.
EncoderDecoderConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......@@ -42,6 +61,13 @@ EncoderDecoderModel
:members: forward, from_encoder_decoder_pretrained
TFEncoderDecoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFEncoderDecoderModel
:members: call, from_encoder_decoder_pretrained
FlaxEncoderDecoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
......@@ -126,6 +126,13 @@ TFRobertaModel
:members: call
TFRobertaForCausalLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFRobertaForCausalLM
:members: call
TFRobertaForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
......@@ -1444,6 +1444,7 @@ if is_tf_available():
"TFElectraPreTrainedModel",
]
)
_import_structure["models.encoder_decoder"].append("TFEncoderDecoderModel")
_import_structure["models.flaubert"].extend(
[
"TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
......@@ -1596,6 +1597,7 @@ if is_tf_available():
_import_structure["models.roberta"].extend(
[
"TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFRobertaForCausalLM",
"TFRobertaForMaskedLM",
"TFRobertaForMultipleChoice",
"TFRobertaForQuestionAnswering",
......@@ -3096,6 +3098,7 @@ if TYPE_CHECKING:
TFElectraModel,
TFElectraPreTrainedModel,
)
from .models.encoder_decoder import TFEncoderDecoderModel
from .models.flaubert import (
TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFFlaubertForMultipleChoice,
......@@ -3205,6 +3208,7 @@ if TYPE_CHECKING:
)
from .models.roberta import (
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRobertaForCausalLM,
TFRobertaForMaskedLM,
TFRobertaForMultipleChoice,
TFRobertaForQuestionAnswering,
......
......@@ -76,6 +76,7 @@ from . import (
TFLxmertForPreTraining,
TFLxmertVisualFeatureEncoder,
TFOpenAIGPTLMHeadModel,
TFRobertaForCausalLM,
TFRobertaForMaskedLM,
TFRobertaForSequenceClassification,
TFT5ForConditionalGeneration,
......@@ -215,6 +216,7 @@ MODEL_CLASSES = {
),
"roberta": (
RobertaConfig,
TFRobertaForCausalLM,
TFRobertaForMaskedLM,
RobertaForMaskedLM,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
......
......@@ -650,7 +650,11 @@ class TFGenerationMixin:
# current position and vocab size
cur_len = shape_list(input_ids)[1] # unused
vocab_size = self.config.vocab_size
vocab_size = getattr(self.config, "vocab_size", None)
if vocab_size is None and self.config.is_encoder_decoder:
decoder_config = getattr(self.config, "decoder", None)
if decoder_config is not None:
vocab_size = getattr(self.config.decoder, "vocab_size", None)
# set effective batch size and effective batch multiplier according to do_sample
if do_sample:
......@@ -678,6 +682,7 @@ class TFGenerationMixin:
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict_in_generate,
)
if return_dict_in_generate:
if output_attentions:
......@@ -911,7 +916,7 @@ class TFGenerationMixin:
if eos_token_id is not None and cur_len < min_length:
# create eos_token_id boolean mask
is_token_logit_eos_token = tf.convert_to_tensor(
[True if token is eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
[True if token == eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
)
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
......@@ -1142,7 +1147,7 @@ class TFGenerationMixin:
num_batch_hypotheses = batch_size * num_beams
is_token_logit_eos_token = tf.convert_to_tensor(
[True if token is eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
[True if token == eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
)
eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [num_batch_hypotheses, vocab_size])
......@@ -1446,11 +1451,17 @@ class TFGenerationMixin:
Implement in subclasses of :class:`~transformers.PreTrainedModel` for custom behavior to adjust the logits in
the generate method.
"""
vocab_size = getattr(self.config, "vocab_size", None)
if vocab_size is None and self.config.is_encoder_decoder:
decoder_config = getattr(self.config, "decoder", None)
if decoder_config is not None:
vocab_size = getattr(self.config.decoder, "vocab_size", None)
if cur_len == 1 and forced_bos_token_id is not None:
vocab_range = tf.constant(range(self.config.vocab_size))
vocab_range = tf.constant(range(vocab_size))
return tf.where(vocab_range != forced_bos_token_id, -1e8, logits)
elif cur_len == max_length - 1 and forced_eos_token_id is not None:
vocab_range = tf.constant(range(self.config.vocab_size))
vocab_range = tf.constant(range(vocab_size))
return tf.where(vocab_range != forced_eos_token_id, -1e8, logits)
else:
return logits
......
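The ``is`` → ``==`` replacements above are the "is_token_logit_eos_token" fixes listed in the commit message: CPython only caches small integers, so an identity check against ``eos_token_id`` silently matches nothing for realistic vocabulary ids. The added ``vocab_size`` fallback handles composite encoder-decoder configurations, where ``vocab_size`` lives on ``config.decoder`` rather than on the top-level config. A quick illustration of the identity pitfall (plain Python, not part of the diff):
.. code-block::
    eos_token_id = 50256  # e.g. GPT-2's EOS id; ints this large are not interned by CPython

    # Identity check (the old code): typically matches zero positions.
    matches_is = sum(token is eos_token_id for token in range(50257))

    # Equality check (the fixed code): matches exactly one position.
    matches_eq = sum(token == eos_token_id for token in range(50257))

    print(matches_is, matches_eq)  # typically: 0 1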
......@@ -80,6 +80,54 @@ class TFBaseModelOutputWithPooling(ModelOutput):
attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
"""
Base class for model's outputs that also contains a pooling of the last hidden states.
Args:
last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (:obj:`tf.Tensor` of shape :obj:`(batch_size, hidden_size)`):
Last layer hidden-state of the first token of the sequence (classification token) further processed by a
Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
prediction (classification) objective during pretraining.
This output is usually *not* a good summary of the semantic content of the input, you're often better with
averaging or pooling the sequence of hidden-states for the whole input sequence.
past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, batch_size,
num_heads, sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
:obj:`past_key_values` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
cross_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
"""
last_hidden_state: tf.Tensor = None
pooler_output: tf.Tensor = None
past_key_values: Optional[List[tf.Tensor]] = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
cross_attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFBaseModelOutputWithPast(ModelOutput):
"""
......@@ -317,6 +365,49 @@ class TFCausalLMOutputWithPast(ModelOutput):
attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFCausalLMOutputWithCrossAttentions(ModelOutput):
"""
Base class for causal language model (or autoregressive) outputs.
Args:
loss (:obj:`tf.Tensor` of shape :obj:`(n,)`, `optional`, where n is the number of non-masked labels, returned when :obj:`labels` is provided):
Language modeling loss (for next-token prediction).
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of
shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
cross_attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`tf.Tensor` (one for each layer) of shape :obj:`(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, batch_size,
num_heads, sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
:obj:`past_key_values` input) to speed up sequential decoding.
"""
loss: Optional[tf.Tensor] = None
logits: tf.Tensor = None
past_key_values: Optional[List[tf.Tensor]] = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
cross_attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFMaskedLMOutput(ModelOutput):
"""
......
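The two new output classes expose ``cross_attentions`` and, for the causal-LM variant, ``past_key_values``. A minimal sketch of how they surface at runtime, using a small randomly initialized TF BERT decoder (all configuration values below are illustrative):
.. code-block::
    import tensorflow as tf
    from transformers import BertConfig, TFBertLMHeadModel

    # Tiny decoder config with cross-attention enabled.
    config = BertConfig(
        vocab_size=100, hidden_size=32, num_hidden_layers=2, num_attention_heads=4,
        intermediate_size=64, is_decoder=True, add_cross_attention=True,
    )
    model = TFBertLMHeadModel(config)

    input_ids = tf.constant([[1, 2, 3, 4]])
    encoder_hidden_states = tf.random.normal((1, 4, config.hidden_size))

    outputs = model(
        input_ids,
        encoder_hidden_states=encoder_hidden_states,
        output_attentions=True,
        use_cache=True,
    )
    # One entry per decoder layer in both fields.
    print(len(outputs.cross_attentions), len(outputs.past_key_values))  # 2 2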
......@@ -296,7 +296,9 @@ def booleans_processing(config, **kwargs):
)
if "use_cache" in kwargs:
final_booleans["use_cache"] = kwargs["use_cache"] if kwargs["use_cache"] is not None else config.use_cache
final_booleans["use_cache"] = (
kwargs["use_cache"] if kwargs["use_cache"] is not None else getattr(config, "use_cache", None)
)
else:
if (
kwargs["output_attentions"] not in (None, config.output_attentions)
......@@ -318,7 +320,7 @@ def booleans_processing(config, **kwargs):
final_booleans["return_dict"] = True
if "use_cache" in kwargs:
final_booleans["use_cache"] = config.use_cache
final_booleans["use_cache"] = getattr(config, "use_cache", None)
return final_booleans
......@@ -362,6 +364,15 @@ def input_processing(func, config, input_ids, **kwargs):
)
output["past_key_values"] = kwargs["kwargs_call"].pop("decoder_cached_states")
if "past" in kwargs["kwargs_call"] and "past_key_values" in kwargs:
warnings.warn(
"The `past` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
kwargs["past_key_values"] = kwargs["kwargs_call"].pop("past")
elif "past_key_values" in kwargs["kwargs_call"] and "past" in kwargs:
kwargs["past"] = kwargs["kwargs_call"].pop("past_key_values")
if len(kwargs["kwargs_call"]) > 0:
raise ValueError(
f"The following keyword arguments are not supported by this model: {list(kwargs['kwargs_call'].keys())}."
......
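The new branches in ``input_processing`` deprecate ``past`` in favour of ``past_key_values`` while still routing either name to whatever the model's ``call`` signature actually declares. A standalone toy version of that mapping, with names of my own choosing rather than library API, for illustration:
.. code-block::
    import warnings

    def map_past_kwarg(signature_args, extra_kwargs):
        """Toy re-implementation of the handling above: `signature_args` holds the
        argument names a model's `call` declares, `extra_kwargs` the leftover kwargs."""
        mapped = {}
        if "past" in extra_kwargs and "past_key_values" in signature_args:
            warnings.warn(
                "The `past` argument is deprecated and will be removed in a future version, "
                "use `past_key_values` instead.",
                FutureWarning,
            )
            mapped["past_key_values"] = extra_kwargs.pop("past")
        elif "past_key_values" in extra_kwargs and "past" in signature_args:
            # Older signatures that still declare `past` keep working transparently.
            mapped["past"] = extra_kwargs.pop("past_key_values")
        return mapped

    print(map_past_kwarg({"input_ids", "past_key_values"}, {"past": "cache"}))
    # {'past_key_values': 'cache'}, plus a FutureWarning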
......@@ -157,6 +157,7 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
position_ids: tf.Tensor = None,
token_type_ids: tf.Tensor = None,
inputs_embeds: tf.Tensor = None,
past_key_values_length=0,
training: bool = False,
) -> tf.Tensor:
"""
......@@ -177,7 +178,9 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
token_type_ids = tf.fill(dims=input_shape, value=0)
if position_ids is None:
position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)
position_ids = tf.expand_dims(
tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0
)
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
......@@ -829,7 +832,6 @@ class TFAlbertModel(TFAlbertPreTrainedModel):
return outputs
# Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output
def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
......
......@@ -133,6 +133,7 @@ TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
# Model for Causal LM mapping
("rembert", "TFRemBertForCausalLM"),
("roformer", "TFRoFormerForCausalLM"),
("roberta", "TFRobertaForCausalLM"),
("bert", "TFBertLMHeadModel"),
("openai-gpt", "TFOpenAIGPTLMHeadModel"),
("gpt2", "TFGPT2LMHeadModel"),
......@@ -181,6 +182,7 @@ TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("blenderbot", "TFBlenderbotForConditionalGeneration"),
("blenderbot-small", "TFBlenderbotSmallForConditionalGeneration"),
("bart", "TFBartForConditionalGeneration"),
("encoder-decoder", "TFEncoderDecoderModel"),
]
)
......
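With the mapping entries above, the TF auto classes can resolve the newly added heads. For example (checkpoint name is illustrative, and the model logs a warning suggesting ``is_decoder=True`` for standalone use):
.. code-block::
    from transformers import TFAutoModelForCausalLM

    # Resolves to TFRobertaForCausalLM via TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.
    model = TFAutoModelForCausalLM.from_pretrained("roberta-base")
    print(type(model).__name__)  # TFRobertaForCausalLM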
......@@ -110,6 +110,7 @@ class TFConvBertEmbeddings(tf.keras.layers.Layer):
position_ids: tf.Tensor = None,
token_type_ids: tf.Tensor = None,
inputs_embeds: tf.Tensor = None,
past_key_values_length=0,
training: bool = False,
) -> tf.Tensor:
"""
......@@ -130,7 +131,9 @@ class TFConvBertEmbeddings(tf.keras.layers.Layer):
token_type_ids = tf.fill(dims=input_shape, value=0)
if position_ids is None:
position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)
position_ids = tf.expand_dims(
tf.range(start=past_key_values_length, limit=input_shape[1] + past_key_values_length), axis=0
)
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
......
......@@ -104,6 +104,9 @@ class ElectraConfig(PretrainedConfig):
<https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
`Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
<https://arxiv.org/abs/2009.13658>`__.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if ``config.is_decoder=True``.
classifier_dropout (:obj:`float`, `optional`):
The dropout ratio for the classification head.
......@@ -143,6 +146,7 @@ class ElectraConfig(PretrainedConfig):
summary_last_dropout=0.1,
pad_token_id=0,
position_embedding_type="absolute",
use_cache=True,
classifier_dropout=None,
**kwargs
):
......@@ -167,4 +171,5 @@ class ElectraConfig(PretrainedConfig):
self.summary_activation = summary_activation
self.summary_last_dropout = summary_last_dropout
self.position_embedding_type = position_embedding_type
self.use_cache = use_cache
self.classifier_dropout = classifier_dropout
......@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
from ...file_utils import _LazyModule, is_flax_available, is_torch_available
from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available
_import_structure = {
......@@ -28,6 +28,9 @@ _import_structure = {
if is_torch_available():
_import_structure["modeling_encoder_decoder"] = ["EncoderDecoderModel"]
if is_tf_available():
_import_structure["modeling_tf_encoder_decoder"] = ["TFEncoderDecoderModel"]
if is_flax_available():
_import_structure["modeling_flax_encoder_decoder"] = ["FlaxEncoderDecoderModel"]
......@@ -37,6 +40,9 @@ if TYPE_CHECKING:
if is_torch_available():
from .modeling_encoder_decoder import EncoderDecoderModel
if is_tf_available():
from .modeling_tf_encoder_decoder import TFEncoderDecoderModel
if is_flax_available():
from .modeling_flax_encoder_decoder import FlaxEncoderDecoderModel
......
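This registers :class:`~transformers.TFEncoderDecoderModel` with the package. A hedged usage sketch, composing two pretrained BERT checkpoints (the decoder's cross-attention weights are newly initialized, so the composed model needs fine-tuning before it is useful):
.. code-block::
    from transformers import BertTokenizer, TFEncoderDecoderModel

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = TFEncoderDecoderModel.from_encoder_decoder_pretrained(
        "bert-base-uncased", "bert-base-uncased"
    )
    model.config.decoder_start_token_id = tokenizer.cls_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    inputs = tokenizer("The Eiffel Tower is 324 metres tall.", return_tensors="tf")
    outputs = model(input_ids=inputs.input_ids, decoder_input_ids=inputs.input_ids)
    print(outputs.logits.shape)  # (batch_size, decoder_sequence_length, vocab_size)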
......@@ -24,8 +24,8 @@ import tensorflow as tf
from ...activations_tf import get_tf_activation
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...modeling_tf_outputs import (
TFBaseModelOutput,
TFBaseModelOutputWithPooling,
TFBaseModelOutputWithPastAndCrossAttentions,
TFBaseModelOutputWithPoolingAndCrossAttentions,
TFMaskedLMOutput,
TFSequenceClassifierOutput,
TFTokenClassifierOutput,
......@@ -216,6 +216,8 @@ class TFLayoutLMSelfAttention(tf.keras.layers.Layer):
)
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
self.is_decoder = config.is_decoder
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
......@@ -228,16 +230,49 @@ class TFLayoutLMSelfAttention(tf.keras.layers.Layer):
hidden_states: tf.Tensor,
attention_mask: tf.Tensor,
head_mask: tf.Tensor,
encoder_hidden_states: tf.Tensor,
encoder_attention_mask: tf.Tensor,
past_key_value: Tuple[tf.Tensor],
output_attentions: bool,
training: bool = False,
) -> Tuple[tf.Tensor]:
batch_size = shape_list(hidden_states)[0]
mixed_query_layer = self.query(inputs=hidden_states)
mixed_key_layer = self.key(inputs=hidden_states)
mixed_value_layer = self.value(inputs=hidden_states)
# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size)
value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size)
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
key_layer = tf.concat([past_key_value[0], key_layer], axis=2)
value_layer = tf.concat([past_key_value[1], value_layer], axis=2)
else:
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
if self.is_decoder:
# if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
# (batch size, num_heads, seq_len_q, seq_len_k)
......@@ -267,6 +302,8 @@ class TFLayoutLMSelfAttention(tf.keras.layers.Layer):
attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs
......@@ -305,6 +342,9 @@ class TFLayoutLMAttention(tf.keras.layers.Layer):
input_tensor: tf.Tensor,
attention_mask: tf.Tensor,
head_mask: tf.Tensor,
encoder_hidden_states: tf.Tensor,
encoder_attention_mask: tf.Tensor,
past_key_value: Tuple[tf.Tensor],
output_attentions: bool,
training: bool = False,
) -> Tuple[tf.Tensor]:
......@@ -312,13 +352,17 @@ class TFLayoutLMAttention(tf.keras.layers.Layer):
hidden_states=input_tensor,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
training=training,
)
attention_output = self.dense_output(
hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
# add attentions (possibly with past_key_value) if we output them
outputs = (attention_output,) + self_outputs[1:]
return outputs
......@@ -369,6 +413,12 @@ class TFLayoutLMLayer(tf.keras.layers.Layer):
super().__init__(**kwargs)
self.attention = TFLayoutLMAttention(config, name="attention")
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
if not self.is_decoder:
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
self.crossattention = TFLayoutLMAttention(config, name="crossattention")
self.intermediate = TFLayoutLMIntermediate(config, name="intermediate")
self.bert_output = TFLayoutLMOutput(config, name="output")
......@@ -377,22 +427,69 @@ class TFLayoutLMLayer(tf.keras.layers.Layer):
hidden_states: tf.Tensor,
attention_mask: tf.Tensor,
head_mask: tf.Tensor,
encoder_hidden_states: Optional[tf.Tensor],
encoder_attention_mask: Optional[tf.Tensor],
past_key_value: Optional[Tuple[tf.Tensor]],
output_attentions: bool,
training: bool = False,
) -> Tuple[tf.Tensor]:
attention_outputs = self.attention(
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
input_tensor=hidden_states,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=self_attn_past_key_value,
output_attentions=output_attentions,
training=training,
)
attention_output = attention_outputs[0]
attention_output = self_attention_outputs[0]
# if decoder, the last output is tuple of self-attn cache
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers "
"by setting `config.add_cross_attention=True`"
)
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention(
input_tensor=attention_output,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions,
training=training,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
# add cross-attn cache to positions 3,4 of present_key_value tuple
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value
intermediate_output = self.intermediate(hidden_states=attention_output)
layer_output = self.bert_output(
hidden_states=intermediate_output, input_tensor=attention_output, training=training
)
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
outputs = (layer_output,) + outputs # add attentions if we output them
# if decoder, return the attn key/values as the last output
if self.is_decoder:
outputs = outputs + (present_key_value,)
return outputs
......@@ -401,7 +498,7 @@ class TFLayoutLMLayer(tf.keras.layers.Layer):
class TFLayoutLMEncoder(tf.keras.layers.Layer):
def __init__(self, config: LayoutLMConfig, **kwargs):
super().__init__(**kwargs)
self.config = config
self.layer = [TFLayoutLMLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
def call(
......@@ -409,39 +506,61 @@ class TFLayoutLMEncoder(tf.keras.layers.Layer):
hidden_states: tf.Tensor,
attention_mask: tf.Tensor,
head_mask: tf.Tensor,
encoder_hidden_states: Optional[tf.Tensor],
encoder_attention_mask: Optional[tf.Tensor],
past_key_values: Optional[Tuple[Tuple[tf.Tensor]]],
use_cache: Optional[bool],
output_attentions: bool,
output_hidden_states: bool,
return_dict: bool,
training: bool = False,
) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
past_key_value = past_key_values[i] if past_key_values is not None else None
layer_outputs = layer_module(
hidden_states=hidden_states,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
training=training,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if self.config.add_cross_attention and encoder_hidden_states is not None:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
# Add last layer
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return tuple(
v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None
)
return TFBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_attentions,
cross_attentions=all_cross_attentions,
)
......@@ -585,12 +704,14 @@ class TFLayoutLMMainLayer(tf.keras.layers.Layer):
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
**kwargs,
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
inputs = input_processing(
func=self.call,
config=self.config,
......@@ -665,6 +786,11 @@ class TFLayoutLMMainLayer(tf.keras.layers.Layer):
hidden_states=embedding_output,
attention_mask=extended_attention_mask,
head_mask=inputs["head_mask"],
# Need to pass these required positional arguments to `Encoder`
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=None,
past_key_values=None,
use_cache=False,
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
......@@ -680,11 +806,12 @@ class TFLayoutLMMainLayer(tf.keras.layers.Layer):
pooled_output,
) + encoder_outputs[1:]
return TFBaseModelOutputWithPooling(
return TFBaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)
......@@ -802,7 +929,9 @@ class TFLayoutLMModel(TFLayoutLMPreTrainedModel):
self.layoutlm = TFLayoutLMMainLayer(config, name="layoutlm")
@add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
@replace_return_docstrings(
output_type=TFBaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC
)
def call(
self,
input_ids: Optional[TFModelInputType] = None,
......@@ -812,12 +941,14 @@ class TFLayoutLMModel(TFLayoutLMPreTrainedModel):
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: Optional[bool] = False,
**kwargs,
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]:
r"""
Returns:
......@@ -859,6 +990,8 @@ class TFLayoutLMModel(TFLayoutLMPreTrainedModel):
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
......@@ -881,15 +1014,25 @@ class TFLayoutLMModel(TFLayoutLMPreTrainedModel):
return outputs
def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
# Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output
def serving_output(
self, output: TFBaseModelOutputWithPoolingAndCrossAttentions
) -> TFBaseModelOutputWithPoolingAndCrossAttentions:
output_cache = self.config.use_cache and self.config.is_decoder
pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None
if not (self.config.output_attentions and self.config.add_cross_attention):
cross_attns = None
return TFBaseModelOutputWithPooling(
return TFBaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=output.last_hidden_state,
pooler_output=output.pooler_output,
past_key_values=pkv,
hidden_states=hs,
attentions=attns,
cross_attentions=cross_attns,
)
......
......@@ -510,7 +510,7 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer):
super().build(input_shape)
def create_position_ids_from_input_ids(self, input_ids):
def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0):
"""
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
symbols are ignored. This is modified from fairseq's `utils.make_positions`.
......@@ -520,11 +520,19 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer):
Returns: tf.Tensor
"""
mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype)
incremental_indices = tf.math.cumsum(mask, axis=1) * mask
incremental_indices = (tf.math.cumsum(mask, axis=1) + past_key_values_length) * mask
return incremental_indices + self.padding_idx
def call(self, input_ids=None, position_ids=None, token_type_ids=None, inputs_embeds=None, training=False):
def call(
self,
input_ids=None,
position_ids=None,
token_type_ids=None,
inputs_embeds=None,
past_key_values_length=0,
training=False,
):
"""
Applies embedding based on inputs tensor.
......@@ -544,7 +552,9 @@ class TFLongformerEmbeddings(tf.keras.layers.Layer):
if position_ids is None:
if input_ids is not None:
# Create the position ids from the input token ids. Any padded tokens remain padded.
position_ids = self.create_position_ids_from_input_ids(input_ids=input_ids)
position_ids = self.create_position_ids_from_input_ids(
input_ids=input_ids, past_key_values_length=past_key_values_length
)
else:
position_ids = tf.expand_dims(
tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1), axis=0
......
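The modified position id computation keeps the RoBERTa/Longformer convention (real positions start at ``padding_idx + 1``, padded positions stay at ``padding_idx``) while shifting real tokens by ``past_key_values_length`` when a cache is used. A small numeric check (values are illustrative):
.. code-block::
    import tensorflow as tf

    padding_idx = 1
    input_ids = tf.constant([[7, 9, 1, 1]])  # two real tokens followed by two pads
    mask = tf.cast(tf.math.not_equal(input_ids, padding_idx), dtype=input_ids.dtype)

    # Without a cache (past_key_values_length=0): positions [[2, 3, 1, 1]]
    print(((tf.math.cumsum(mask, axis=1) + 0) * mask + padding_idx).numpy())

    # With past_key_values_length=3: real tokens shift, pads stay at padding_idx -> [[5, 6, 1, 1]]
    print(((tf.math.cumsum(mask, axis=1) + 3) * mask + padding_idx).numpy())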
......@@ -982,7 +982,6 @@ class TFMobileBertModel(TFMobileBertPreTrainedModel):
return outputs
# Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output
def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
......