Unverified Commit 2e20c0f3 authored by Yih-Dar, committed by GitHub

Make Flax GPT2 working with cross attention (#13008)



* make flax gpt2 working with cross attention

* Remove encoder->decoder projection layer

* A draft (incomplete) for FlaxEncoderDecoderModel

* Add the method from_encoder_decoder_pretrained + the docstrings

* Fix the mistakes of using EncoderDecoderModel

* Fix style

* Add FlaxEncoderDecoderModel to the library

* Fix cyclic imports

* Add FlaxEncoderDecoderModel to modeling_flax_auto.py

* Remove question comments

* add tests for FlaxEncoderDecoderModel

* add flax_encoder_decoder to the lists of ignored entries in check_repo.py

* fix missing required positional arguments

* Remove **kwargs when creating FlaxEncoderDecoderModel in from_encoder_decoder_pretrained()

Also fix generation eos/pad tokens issue

* Fix: Use sequences from the generated_output

* Change a check from assert to raise ValueError

* Fix examples and token ids issues

* Fix missing all_cross_attentions when outputting tuple in modeling_gpt2

* Remove the changes in configuration docstrings.

* allow for bert 2 gpt2

* make fix-copies

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Change remaining examples to bert2gpt2

* Change the test to Bert2GPT2

* Fix examples

* Fix import

* Fix unpack bug

* Rename to FlaxEncoderDecoderModelTest and change the test to bert2gpt2

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Fix: NotImplentedError -> NotImplementedError

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* up

* finalize
Co-authored-by: ydshieh <ydshieh@user.noreply>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 7223844d
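A minimal end-to-end sketch of what this change enables, mirroring the examples in the new docstrings below (same placeholder checkpoints as those examples; the cross-attention weights are randomly initialized, so outputs are only meaningful after fine-tuning):

from transformers import FlaxEncoderDecoderModel, BertTokenizer, GPT2Tokenizer

# initialize a bert2gpt2 model from pretrained BERT and GPT2 checkpoints
model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")
input_tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
output_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

input_ids = input_tokenizer("My friends are cool but they eat too many carbs.", return_tensors="np").input_ids

# GPT2 has no pad token, so reuse its eos token for both (as in the __call__ example below)
model.config.eos_token_id = model.config.decoder.eos_token_id
model.config.pad_token_id = model.config.eos_token_id

sequences = model.generate(input_ids, num_beams=4, max_length=12).sequences
print(output_tokenizer.batch_decode(sequences, skip_special_tokens=True))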
...@@ -356,7 +356,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| ELECTRA | ✅ | ✅ | ✅ | ✅ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Encoder decoder | ❌ | ❌ | ✅ | ❌ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| FairSeq Machine-Translation | ✅ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
......
...@@ -40,3 +40,10 @@ EncoderDecoderModel
.. autoclass:: transformers.EncoderDecoderModel
:members: forward, from_encoder_decoder_pretrained
FlaxEncoderDecoderModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxEncoderDecoderModel
:members: __call__, from_encoder_decoder_pretrained
...@@ -1703,6 +1703,7 @@ if is_flax_available():
"FlaxElectraPreTrainedModel",
]
)
_import_structure["models.encoder_decoder"].append("FlaxEncoderDecoderModel")
_import_structure["models.gpt2"].extend(["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"])
_import_structure["models.gpt_neo"].extend(
["FlaxGPTNeoForCausalLM", "FlaxGPTNeoModel", "FlaxGPTNeoPreTrainedModel"]
...@@ -3171,6 +3172,7 @@ if TYPE_CHECKING:
FlaxElectraModel,
FlaxElectraPreTrainedModel,
)
from .models.encoder_decoder import FlaxEncoderDecoderModel
from .models.gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel
from .models.gpt_neo import FlaxGPTNeoForCausalLM, FlaxGPTNeoModel, FlaxGPTNeoPreTrainedModel
from .models.marian import FlaxMarianModel, FlaxMarianMTModel, FlaxMarianPreTrainedModel
......
...@@ -79,6 +79,7 @@ FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("t5", "FlaxT5ForConditionalGeneration"),
("mt5", "FlaxMT5ForConditionalGeneration"),
("marian", "FlaxMarianMTModel"),
("encoder-decoder", "FlaxEncoderDecoderModel"),
]
)
......
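With the `encoder-decoder` entry registered above, the Flax seq2seq auto class can dispatch to the new model. A hedged sketch, assuming a fine-tuned FlaxEncoderDecoderModel has been saved to a local directory named `./bert2gpt2` (hypothetical path):

from transformers import FlaxAutoModelForSeq2SeqLM

# an EncoderDecoderConfig now resolves to FlaxEncoderDecoderModel through the mapping above
model = FlaxAutoModelForSeq2SeqLM.from_pretrained("./bert2gpt2")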
...@@ -625,13 +625,21 @@ class FlaxBertModule(nn.Module):
self,
input_ids,
attention_mask,
token_type_ids: Optional[np.ndarray] = None,
position_ids: Optional[np.ndarray] = None,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# make sure `token_type_ids` is correctly initialized when not passed
if token_type_ids is None:
token_type_ids = jnp.zeros_like(input_ids)
# make sure `position_ids` is correctly initialized when not passed
if position_ids is None:
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
hidden_states = self.embeddings(
input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
)
......
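The two fallbacks added to `FlaxBertModule.__call__` above can be illustrated in isolation; a small sketch of how the default `token_type_ids` and `position_ids` are derived from `input_ids` (plain jax.numpy, toy ids):

import jax.numpy as jnp

input_ids = jnp.array([[101, 7592, 102]])  # (batch_size, seq_len)

token_type_ids = jnp.zeros_like(input_ids)  # all-zero segment ids, same shape as input_ids
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)  # [[0, 1, 2]]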
...@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
from ...file_utils import _LazyModule, is_flax_available, is_torch_available
_import_structure = {
...@@ -28,6 +28,8 @@ _import_structure = {
if is_torch_available():
_import_structure["modeling_encoder_decoder"] = ["EncoderDecoderModel"]
if is_flax_available():
_import_structure["modeling_flax_encoder_decoder"] = ["FlaxEncoderDecoderModel"]
if TYPE_CHECKING:
from .configuration_encoder_decoder import EncoderDecoderConfig
...@@ -35,6 +37,9 @@ if TYPE_CHECKING:
if is_torch_available():
from .modeling_encoder_decoder import EncoderDecoderModel
if is_flax_available():
from .modeling_flax_encoder_decoder import FlaxEncoderDecoderModel
else:
import sys
......
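A minimal sketch of how the lazy-import guard above behaves for downstream code; `is_flax_available` is the same helper used in the diff:

from transformers.file_utils import is_flax_available

if is_flax_available():
    # only importable when jax/flax are installed, thanks to the guard above
    from transformers import FlaxEncoderDecoderModel
else:
    FlaxEncoderDecoderModel = None  # graceful fallback in torch-only environments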
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Classes to support Flax Encoder-Decoder architectures """
import os
from typing import Optional, Tuple, Union
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from jax import lax
from jax.random import PRNGKey
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput
from ...modeling_flax_utils import FlaxPreTrainedModel
from ...utils import logging
from .configuration_encoder_decoder import EncoderDecoderConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "EncoderDecoderConfig"
ENCODER_DECODER_START_DOCSTRING = r"""
This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the
encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via
:meth:`~transformers.AutoModel.from_pretrained` function and the decoder is loaded via
:meth:`~transformers.AutoModelForCausalLM.from_pretrained` function. Cross-attention layers are automatically added
to the decoder and should be fine-tuned on a downstream generative task, like summarization.
The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation
tasks was shown in `Leveraging Pre-trained Checkpoints for Sequence Generation Tasks
<https://arxiv.org/abs/1907.12461>`__ by Sascha Rothe, Shashi Narayan, and Aliaksei Severyn.
After such an Encoder Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other models
(see the examples for more information).
This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
generic methods the library implements for all its models (such as downloading or saving, resizing the input
embeddings, pruning heads etc.)
This model is also a Flax Linen `flax.nn.Module
<https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
Module and refer to the Flax documentation for all matter related to general usage and behavior.
Parameters:
config (:class:`~transformers.EncoderDecoderConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights.
"""
ENCODER_DECODER_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
decoder_input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Indices of decoder input sequence tokens in the vocabulary.
Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are decoder input IDs? <../glossary.html#decoder-input-ids>`__
For sequence to sequence training, :obj:`decoder_input_ids` should be provided. If no
:obj:`decoder_input_ids` is provided, the model will create this tensor by shifting the :obj:`input_ids` to
the right for denoising pre-training.
decoder_attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default.
position_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
config.encoder.max_position_embeddings - 1]``.
decoder_position_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
range ``[0, config.decoder.max_position_embeddings - 1]``.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
If set to ``True``, the model will return a :class:`~transformers.file_utils.FlaxSeq2SeqLMOutput` instead
of a plain tuple.
"""
ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
position_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
config.encoder.max_position_embeddings - 1]``.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
If set to ``True``, the model will return a :class:`~transformers.file_utils.FlaxBaseModelOutput` instead
of a plain tuple.
"""
ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r"""
Args:
decoder_input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Indices of decoder input sequence tokens in the vocabulary.
Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are decoder input IDs? <../glossary.html#decoder-input-ids>`__
If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_input_ids` have to be input (see
:obj:`past_key_values`).
For sequence to sequence training, :obj:`decoder_input_ids` should be provided. If no
:obj:`decoder_input_ids` is provided, the model will create this tensor by shifting the :obj:`input_ids` to
the right for denoising pre-training.
encoder_outputs (:obj:`tuple(tuple(jnp.ndarray))`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
:obj:`attentions`). :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`
is a sequence of hidden-states at the output of the last layer of the encoder, used in the
cross-attention of the decoder.
encoder_attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
decoder_attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default.
decoder_position_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
range ``[0, config.decoder.max_position_embeddings - 1]``.
past_key_values (:obj:`Dict[str, np.ndarray]`, `optional`, returned by ``init_cache`` or when passing previous ``past_key_values``):
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
auto-regressive decoding. Pre-computed key and value hidden-states are of shape `[batch_size, max_length]`.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
If set to ``True``, the model will return a
:class:`~transformers.file_utils.FlaxCausalLMOutputWithCrossAttentions` instead of a plain tuple.
"""
class FlaxEncoderDecoderModule(nn.Module):
config: EncoderDecoderConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
encoder_config = self.config.encoder
decoder_config = self.config.decoder
# Copied from `modeling_hybrid_clip.py` with modifications.
from ...models.auto.modeling_flax_auto import FLAX_MODEL_FOR_CAUSAL_LM_MAPPING, FLAX_MODEL_MAPPING
encoder_module = FLAX_MODEL_MAPPING[encoder_config.__class__].module_class
decoder_module = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING[decoder_config.__class__].module_class
self.encoder = encoder_module(encoder_config, dtype=self.dtype)
self.decoder = decoder_module(decoder_config, dtype=self.dtype)
def _get_encoder_module(self):
return self.encoder
def _get_decoder_module(self):
return self.decoder
def __call__(
self,
input_ids,
attention_mask,
decoder_input_ids,
decoder_attention_mask,
position_ids,
decoder_position_ids,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
position_ids=decoder_position_ids,
encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
if not return_dict:
return decoder_outputs + encoder_outputs
return FlaxSeq2SeqLMOutput(
logits=decoder_outputs.logits,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
r"""
:class:`~transformers.FlaxEncoderDecoderModel` is a generic model class that will be instantiated as a transformer
architecture with the module (flax.nn.Module) of one of the base model classes of the library as encoder module and
another one as decoder module when created with the :meth:`~transformers.FlaxAutoModel.from_pretrained` class method
for the encoder and the :meth:`~transformers.FlaxAutoModelForCausalLM.from_pretrained` class method for the decoder.
"""
config_class = EncoderDecoderConfig
base_model_prefix = "encoder_decoder"
module_class = FlaxEncoderDecoderModule
def __init__(
self,
config: EncoderDecoderConfig,
input_shape: Optional[Tuple] = None,
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
**kwargs
):
if input_shape is None:
input_shape = ((1, 1), (1, 1))
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
encoder_input_shape, decoder_input_shape = input_shape
# init input tensors
input_ids = jnp.zeros(encoder_input_shape, dtype="i4")
attention_mask = jnp.ones_like(input_ids)
decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4")
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
batch_size, sequence_length = input_ids.shape
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape
if not decoder_batch_size == batch_size:
raise ValueError(
f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder and {decoder_batch_size} for decoder."
)
decoder_position_ids = jnp.broadcast_to(
jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length)
)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(
rngs,
input_ids,
attention_mask,
decoder_input_ids,
decoder_attention_mask,
position_ids,
decoder_position_ids,
)["params"]
def init_cache(self, batch_size, max_length, encoder_outputs):
r"""
Args:
batch_size (:obj:`int`):
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
max_length (:obj:`int`):
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
cache.
encoder_outputs (:obj:`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray))]`):
``encoder_outputs`` consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`,
`optional`: :obj:`attentions`). :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length,
hidden_size)` is a sequence of hidden-states at the output of the last layer of the
encoder. Used in the cross-attention of the decoder.
"""
# init input variables to retrieve cache
decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
decoder_position_ids = jnp.broadcast_to(
jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
)
def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
decoder_module = module._get_decoder_module()
return decoder_module(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
position_ids=decoder_position_ids,
**kwargs,
)
init_variables = self.module.init(
jax.random.PRNGKey(0),
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
decoder_position_ids=decoder_position_ids,
encoder_hidden_states=encoder_outputs[0],
init_cache=True,
method=_decoder_forward, # we only need to call the decoder to init the cache
)
return unfreeze(init_variables["cache"])
@add_start_docstrings(ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC)
def encode(
self,
input_ids: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
position_ids: Optional[jnp.ndarray] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
r"""
Returns:
Example::
>>> from transformers import FlaxEncoderDecoderModel, BertTokenizer
>>> # initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized
>>> model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-cased', 'gpt2')
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
>>> text = "My friends are cool but they eat too many carbs."
>>> input_ids = tokenizer.encode(text, return_tensors='np')
>>> encoder_outputs = model.encode(input_ids)
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
if position_ids is None:
batch_size, sequence_length = input_ids.shape
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs):
encode_module = module._get_encoder_module()
return encode_module(input_ids, attention_mask, position_ids, **kwargs)
outputs = self.module.apply(
{"params": params or self.params},
input_ids=jnp.array(input_ids, dtype="i4"),
attention_mask=jnp.array(attention_mask, dtype="i4"),
position_ids=jnp.array(position_ids, dtype="i4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=not train,
rngs=rngs,
method=_encoder_forward,
)
if return_dict:
outputs = FlaxBaseModelOutput(
last_hidden_state=outputs.last_hidden_state,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
return outputs
@add_start_docstrings(ENCODER_DECODER_DECODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def decode(
self,
decoder_input_ids,
encoder_outputs,
encoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_position_ids: Optional[jnp.ndarray] = None,
past_key_values: dict = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
r"""
Returns:
Example::
>>> from transformers import FlaxEncoderDecoderModel, BertTokenizer
>>> import jax.numpy as jnp
>>> # initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized
>>> model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-cased', 'gpt2')
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
>>> text = "My friends are cool but they eat too many carbs."
>>> input_ids = tokenizer.encode(text, max_length=1024, return_tensors='np')
>>> encoder_outputs = model.encode(input_ids)
>>> decoder_start_token_id = model.config.decoder.bos_token_id
>>> decoder_input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
>>> outputs = model.decode(decoder_input_ids, encoder_outputs)
>>> logits = outputs.logits
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
encoder_hidden_states = encoder_outputs[0]
if encoder_attention_mask is None:
batch_size, sequence_length = encoder_hidden_states.shape[:2]
encoder_attention_mask = jnp.ones((batch_size, sequence_length))
batch_size, sequence_length = decoder_input_ids.shape
if decoder_attention_mask is None:
decoder_attention_mask = jnp.ones((batch_size, sequence_length))
if decoder_position_ids is None:
if past_key_values is not None:
raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
decoder_position_ids = jnp.broadcast_to(
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
)
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
inputs = {"params": params or self.params}
# if past_key_values are passed then the cache is already initialized; a private flag init_cache has to be
# passed down to ensure the cache is used. It also has to be made sure that the cache is marked as mutable
# so that it can be changed by the decoder's attention module
if past_key_values:
inputs["cache"] = past_key_values
mutable = ["cache"]
else:
mutable = False
def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
decoder_module = module._get_decoder_module()
return decoder_module(
decoder_input_ids,
decoder_attention_mask,
decoder_position_ids,
**kwargs,
)
outputs = self.module.apply(
inputs,
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=not train,
rngs=rngs,
mutable=mutable,
method=_decoder_forward,
)
# add updated cache to model output
if past_key_values is not None and return_dict:
outputs, past = outputs
outputs["past_key_values"] = unfreeze(past["cache"])
return outputs
elif past_key_values is not None and not return_dict:
outputs, past = outputs
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
return outputs
@add_start_docstrings_to_model_forward(ENCODER_DECODER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def __call__(
self,
input_ids: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
decoder_input_ids: Optional[jnp.ndarray] = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
position_ids: Optional[jnp.ndarray] = None,
decoder_position_ids: Optional[jnp.ndarray] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
r"""
Returns:
Examples::
>>> from transformers import FlaxEncoderDecoderModel, BertTokenizer, GPT2Tokenizer
>>> # load a fine-tuned bert2gpt2 model
>>> model = FlaxEncoderDecoderModel.from_pretrained("patrickvonplaten/bert2gpt2-cnn_dailymail-fp16")
>>> # load input & output tokenizer
>>> tokenizer_input = BertTokenizer.from_pretrained('bert-base-cased')
>>> tokenizer_output = GPT2Tokenizer.from_pretrained('gpt2')
>>> article = '''Sigma Alpha Epsilon is under fire for a video showing party-bound fraternity members
... singing a racist chant. SAE's national chapter suspended the students,
... but University of Oklahoma President David Boren took it a step further,
... saying the university's affiliation with the fraternity is permanently done.'''
>>> input_ids = tokenizer_input(article, add_special_tokens=True, return_tensors='np').input_ids
>>> # use GPT2's eos_token as the pad as well as eos token
>>> model.config.eos_token_id = model.config.decoder.eos_token_id
>>> model.config.pad_token_id = model.config.eos_token_id
>>> sequences = model.generate(input_ids, num_beams=4, max_length=12).sequences
>>> summary = tokenizer_output.batch_decode(sequences, skip_special_tokens=True)[0]
>>> assert summary == "SAS Alpha Epsilon suspended Sigma Alpha Epsilon members"
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
# prepare encoder inputs
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
if position_ids is None:
batch_size, sequence_length = input_ids.shape
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
if decoder_attention_mask is None:
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
if decoder_position_ids is None:
batch_size, sequence_length = decoder_input_ids.shape
decoder_position_ids = jnp.broadcast_to(
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
)
# Handle any PRNG if needed
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
return self.module.apply(
{"params": params or self.params},
input_ids=jnp.array(input_ids, dtype="i4"),
attention_mask=jnp.array(attention_mask, dtype="i4"),
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
position_ids=jnp.array(position_ids, dtype="i4"),
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=not train,
rngs=rngs,
)
def prepare_inputs_for_generation(
self,
decoder_input_ids,
max_length,
attention_mask: Optional[jnp.DeviceArray] = None,
decoder_attention_mask: Optional[jnp.DeviceArray] = None,
encoder_outputs=None,
**kwargs
):
# initializing the cache
batch_size, seq_length = decoder_input_ids.shape
past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
# But since the decoder uses a causal mask, those positions are masked anyway.
# Thus, we can create a single static attention_mask here, which is more efficient for compilation.
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
if decoder_attention_mask is not None:
decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
else:
decoder_position_ids = jnp.broadcast_to(
jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
)
return {
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"encoder_attention_mask": attention_mask,
"decoder_attention_mask": extended_attention_mask,
"decoder_position_ids": decoder_position_ids,
}
def update_inputs_for_generation(self, model_outputs, model_kwargs):
model_kwargs["past_key_values"] = model_outputs.past_key_values
model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
return model_kwargs
@classmethod
def from_encoder_decoder_pretrained(
cls,
encoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
decoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
*model_args,
**kwargs
) -> FlaxPreTrainedModel:
r"""
Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
checkpoints.
Params:
encoder_pretrained_model_name_or_path (:obj:`Union[str, os.PathLike]`, `optional`):
Information necessary to initiate the encoder. Can be either:
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
a user or organization name, like ``dbmdz/bert-base-german-cased``.
- A path to a `directory` containing model weights saved using
:func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
decoder_pretrained_model_name_or_path (:obj:`Union[str, os.PathLike]`, `optional`, defaults to `None`):
Information necessary to initiate the decoder. Can be either:
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
a user or organization name, like ``dbmdz/bert-base-german-cased``.
- A path to a `directory` containing model weights saved using
:func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
model_args (remaining positional arguments, `optional`):
All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
kwargs (remaining dictionary of keyword arguments, `optional`):
Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g.,
:obj:`output_attentions=True`).
- To update the encoder configuration, use the prefix `encoder_` for each configuration parameter.
- To update the decoder configuration, use the prefix `decoder_` for each configuration parameter.
- To update the parent model configuration, do not use a prefix for each configuration parameter.
Behaves differently depending on whether a :obj:`config` is provided or automatically loaded.
Example::
>>> from transformers import FlaxEncoderDecoderModel
>>> # initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized
>>> model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-cased', 'gpt2')
>>> # saving model after fine-tuning
>>> model.save_pretrained("./bert2gpt2")
>>> # load fine-tuned model
>>> model = FlaxEncoderDecoderModel.from_pretrained("./bert2gpt2")
"""
kwargs_encoder = {
argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
}
kwargs_decoder = {
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
}
# remove encoder, decoder kwargs from kwargs
for key in kwargs_encoder.keys():
del kwargs["encoder_" + key]
for key in kwargs_decoder.keys():
del kwargs["decoder_" + key]
# Load and initialize the encoder and decoder
# The distinction between encoder and decoder at the model level is made
# by the value of the flag `is_decoder` that we need to set correctly.
encoder = kwargs_encoder.pop("model", None)
if encoder is None:
assert (
encoder_pretrained_model_name_or_path is not None
), "If `model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has to be defined"
from ..auto.modeling_flax_auto import FlaxAutoModel
if "config" not in kwargs_encoder:
from ..auto.configuration_auto import AutoConfig
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
logger.info(
f"Initializing {encoder_pretrained_model_name_or_path} as an encoder model from a decoder model. Cross-attention and causal mask are disabled."
)
encoder_config.is_decoder = False
encoder_config.add_cross_attention = False
kwargs_encoder["config"] = encoder_config
encoder = FlaxAutoModel.from_pretrained(
encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
)
decoder = kwargs_decoder.pop("model", None)
if decoder is None:
assert (
decoder_pretrained_model_name_or_path is not None
), "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined"
from ..auto.modeling_flax_auto import FlaxAutoModelForCausalLM
if "config" not in kwargs_decoder:
from ..auto.configuration_auto import AutoConfig
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
logger.info(
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
)
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
kwargs_decoder["config"] = decoder_config
if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
logger.warning(
f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
)
decoder = FlaxAutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
# instantiate config with corresponding kwargs
dtype = kwargs.pop("dtype", jnp.float32)
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
# init model
model = cls(config, dtype=dtype)
model.params["encoder"] = encoder.params
model.params["decoder"] = decoder.params
return model
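The `encode`/`decode` pair defined above is what `generate` drives internally (with `init_cache` supplying the fast-decoding cache via `prepare_inputs_for_generation`). A hedged sketch of a single greedy decoding step wired together by hand, using the same placeholder checkpoints as the docstring examples:

import jax.numpy as jnp
from transformers import FlaxEncoderDecoderModel, BertTokenizer

model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

input_ids = tokenizer("My friends are cool but they eat too many carbs.", return_tensors="np").input_ids

# run the encoder once; its hidden states are reused at every decoding step
encoder_outputs = model.encode(input_ids)

# start decoding from the decoder's BOS token
decoder_input_ids = jnp.full((input_ids.shape[0], 1), model.config.decoder.bos_token_id, dtype="i4")

# one decoder step with cross-attention over the encoder hidden states
outputs = model.decode(decoder_input_ids, encoder_outputs)
next_token_id = jnp.argmax(outputs.logits[:, -1, :], axis=-1)  # greedy choice for the next token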
...@@ -24,7 +24,10 @@ from flax.linen.attention import dot_product_attention_weights
from jax import lax
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_flax_outputs import (
FlaxBaseModelOutputWithPastAndCrossAttentions,
FlaxCausalLMOutputWithCrossAttentions,
)
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
from ...utils import logging from ...utils import logging
from .configuration_gpt2 import GPT2Config from .configuration_gpt2 import GPT2Config
...@@ -117,6 +120,8 @@ class FlaxConv1D(nn.Module): ...@@ -117,6 +120,8 @@ class FlaxConv1D(nn.Module):
class FlaxGPT2Attention(nn.Module): class FlaxGPT2Attention(nn.Module):
config: GPT2Config config: GPT2Config
dtype: jnp.dtype = jnp.float32 dtype: jnp.dtype = jnp.float32
causal: bool = True
is_cross_attention: bool = False
def setup(self): def setup(self):
config = self.config config = self.config
...@@ -124,10 +129,19 @@ class FlaxGPT2Attention(nn.Module): ...@@ -124,10 +129,19 @@ class FlaxGPT2Attention(nn.Module):
self.num_heads = config.num_attention_heads self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads self.head_dim = self.embed_dim // self.num_heads
self.c_attn = FlaxConv1D(features=3 * self.embed_dim, dtype=self.dtype) if self.is_cross_attention:
self.c_attn = FlaxConv1D(2 * self.embed_dim, dtype=self.dtype)
self.q_attn = FlaxConv1D(self.embed_dim, dtype=self.dtype)
else:
self.c_attn = FlaxConv1D(3 * self.embed_dim, dtype=self.dtype)
self.c_proj = FlaxConv1D(self.embed_dim, dtype=self.dtype)
self.resid_dropout = nn.Dropout(rate=config.resid_pdrop)
if self.causal:
self.causal_mask = make_causal_mask(
jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool"
)
def _split_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
...@@ -170,13 +184,26 @@ class FlaxGPT2Attention(nn.Module):
def __call__(
self,
hidden_states,
key_value_states: Optional[jnp.ndarray] = None,
attention_mask=None,
deterministic: bool = True,
init_cache: bool = False,
output_attentions: bool = False,
):
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
batch_size = hidden_states.shape[0]
if not is_cross_attention:
qkv_out = self.c_attn(hidden_states)
query, key, value = jnp.split(qkv_out, 3, axis=2)
else:
q_out = self.q_attn(hidden_states)
(query,) = jnp.split(q_out, 1, axis=2)
kv_out = self.c_attn(key_value_states)
key, value = jnp.split(kv_out, 2, axis=2)
query = self._split_heads(query)
key = self._split_heads(key)
...@@ -184,6 +211,7 @@ class FlaxGPT2Attention(nn.Module):
query_length, key_length = query.shape[1], key.shape[1]
if self.causal:
if self.has_variable("cache", "cached_key"):
mask_shift = self.variables["cache"]["cache_index"]
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
...@@ -192,12 +220,16 @@ class FlaxGPT2Attention(nn.Module):
)
else:
causal_mask = self.causal_mask[:, :, :query_length, :key_length]
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
# combine masks if needed
if attention_mask is not None and self.causal:
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
attention_mask = combine_masks(attention_mask, causal_mask)
elif self.causal:
attention_mask = causal_mask
elif attention_mask is not None:
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
dropout_rng = None
if not deterministic and self.config.attn_pdrop > 0.0:
...@@ -205,15 +237,18 @@ class FlaxGPT2Attention(nn.Module):
# During fast autoregressive decoding, we feed one position at a time,
# and cache the keys and values step by step.
if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)
# transform boolean mask into float mask
if attention_mask is not None:
attention_bias = lax.select(
attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
)
else:
attention_bias = None
# usual dot product attention
attn_weights = dot_product_attention_weights(
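To make the branch added in `FlaxGPT2Attention.__call__` concrete: in the cross-attention case the query is projected from the decoder hidden states via `q_attn`, while keys and values come from the encoder output via the 2x-wide `c_attn`, so they carry the source length. A shape-only sketch with assumed toy dimensions:

import jax.numpy as jnp

batch, tgt_len, src_len, embed_dim = 2, 5, 7, 768

qkv_out = jnp.zeros((batch, tgt_len, 3 * embed_dim))  # stand-in for c_attn(hidden_states), self-attention path
query, key, value = jnp.split(qkv_out, 3, axis=2)     # each (2, 5, 768)

q_out = jnp.zeros((batch, tgt_len, embed_dim))         # stand-in for q_attn(hidden_states), cross-attention path
kv_out = jnp.zeros((batch, src_len, 2 * embed_dim))    # stand-in for c_attn(key_value_states)
key_x, value_x = jnp.split(kv_out, 2, axis=2)          # each (2, 7, 768): source length, not target length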
...@@ -267,19 +302,28 @@ class FlaxGPT2Block(nn.Module):
self.ln_1 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
self.attn = FlaxGPT2Attention(self.config, dtype=self.dtype)
self.ln_2 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
if self.config.add_cross_attention:
self.crossattention = FlaxGPT2Attention(
config=self.config, dtype=self.dtype, causal=False, is_cross_attention=True
)
self.ln_cross_attn = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
self.mlp = FlaxGPT2MLP(self.config, inner_dim, dtype=self.dtype)
def __call__(
self,
hidden_states,
attention_mask=None,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
deterministic: bool = True,
init_cache: bool = False,
output_attentions: bool = False,
):
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_outputs = self.attn(
hidden_states,
attention_mask=attention_mask,
deterministic=deterministic,
...@@ -287,16 +331,42 @@ class FlaxGPT2Block(nn.Module):
output_attentions=output_attentions,
)
# residual connection
attn_output = attn_outputs[0] # output_attn: a, (attentions)
outputs = attn_outputs[1:]
# residual connection
hidden_states = attn_output + residual
# Cross-Attention Block
if encoder_hidden_states is not None:
# add one self-attention block for cross-attention
if not hasattr(self, "crossattention"):
raise ValueError(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
"cross-attention layers by setting `config.add_cross_attention=True`"
)
residual = hidden_states
hidden_states = self.ln_cross_attn(hidden_states)
cross_attn_outputs = self.crossattention(
hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
)
attn_output = cross_attn_outputs[0]
# residual connection
hidden_states = residual + attn_output
outputs = outputs + cross_attn_outputs[1:] # add cross attentions if we output attention weights
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
feed_forward_hidden_states = self.mlp(hidden_states, deterministic=deterministic)
# residual connection
hidden_states = residual + feed_forward_hidden_states
outputs = (hidden_states,) + outputs
return outputs
class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
...@@ -328,7 +398,22 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
if self.config.add_cross_attention:
encoder_hidden_states = jnp.zeros(input_shape + (self.config.n_embd,))
encoder_attention_mask = attention_mask
module_init_outputs = self.module.init(
rngs,
input_ids,
attention_mask,
position_ids,
encoder_hidden_states,
encoder_attention_mask,
return_dict=False,
)
else:
module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)
return module_init_outputs["params"]
def init_cache(self, batch_size, max_length):
r"""
...@@ -355,6 +440,8 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
input_ids,
attention_mask=None,
position_ids=None,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
params: dict = None,
past_key_values: dict = None,
dropout_rng: jax.random.PRNGKey = None,
...@@ -369,6 +456,10 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
if encoder_hidden_states is not None and encoder_attention_mask is None:
batch_size, sequence_length = encoder_hidden_states.shape[:2]
encoder_attention_mask = jnp.ones((batch_size, sequence_length))
batch_size, sequence_length = input_ids.shape
if position_ids is None:
...@@ -399,6 +490,8 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
encoder_hidden_states,
encoder_attention_mask,
not train,
False,
output_attentions,
...@@ -433,6 +526,8 @@ class FlaxGPT2BlockCollection(nn.Module):
self,
hidden_states,
attention_mask=None,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
deterministic: bool = True,
init_cache: bool = False,
output_attentions: bool = False,
...@@ -441,6 +536,7 @@ class FlaxGPT2BlockCollection(nn.Module):
):
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
for block in self.blocks:
if output_hidden_states:
...@@ -449,6 +545,8 @@ class FlaxGPT2BlockCollection(nn.Module):
layer_outputs = block(
hidden_states,
attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
deterministic=deterministic,
init_cache=init_cache,
output_attentions=output_attentions,
...@@ -458,8 +556,11 @@ class FlaxGPT2BlockCollection(nn.Module):
if output_attentions:
all_attentions += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
# this contains possible `None` values - `FlaxGPT2Module` will filter them out
outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions)
return outputs
...@@ -492,6 +593,8 @@ class FlaxGPT2Module(nn.Module):
input_ids,
attention_mask,
position_ids,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
deterministic=True,
init_cache: bool = False,
output_attentions: bool = False,
...@@ -507,6 +610,8 @@ class FlaxGPT2Module(nn.Module):
outputs = self.h(
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
deterministic=deterministic,
init_cache=init_cache,
output_attentions=output_attentions,
...@@ -526,10 +631,11 @@ class FlaxGPT2Module(nn.Module):
if not return_dict:
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
hidden_states=outputs[1],
attentions=outputs[2],
cross_attentions=outputs[3],
)
...@@ -542,7 +648,11 @@ class FlaxGPT2Model(FlaxGPT2PreTrainedModel):
append_call_sample_docstring(
FlaxGPT2Model,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxBaseModelOutputWithPastAndCrossAttentions,
_CONFIG_FOR_DOC,
)
...@@ -564,6 +674,8 @@ class FlaxGPT2LMHeadModule(nn.Module):
input_ids,
attention_mask,
position_ids,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
deterministic: bool = True,
init_cache: bool = False,
output_attentions: bool = False,
...@@ -574,6 +686,8 @@ class FlaxGPT2LMHeadModule(nn.Module):
input_ids,
attention_mask,
position_ids,
encoder_hidden_states,
encoder_attention_mask,
deterministic=deterministic,
init_cache=init_cache,
output_attentions=output_attentions,
...@@ -592,7 +706,12 @@ class FlaxGPT2LMHeadModule(nn.Module):
if not return_dict:
return (lm_logits,) + outputs[1:]
return FlaxCausalLMOutputWithCrossAttentions(
logits=lm_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
@add_start_docstrings(
...@@ -633,5 +752,9 @@ class FlaxGPT2LMHeadModel(FlaxGPT2PreTrainedModel):
append_call_sample_docstring(
FlaxGPT2LMHeadModel,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxCausalLMOutputWithCrossAttentions,
_CONFIG_FOR_DOC,
)
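A minimal sketch of how the new cross-attention inputs could be exercised on a randomly initialized FlaxGPT2LMHeadModel. The tiny config values, input shapes, and dummy encoder outputs below are illustrative assumptions, not part of this diff:

import jax.numpy as jnp
from transformers import GPT2Config, FlaxGPT2LMHeadModel

# small, randomly initialized decoder with cross-attention layers enabled
config = GPT2Config(n_embd=32, n_layer=2, n_head=2, add_cross_attention=True)
model = FlaxGPT2LMHeadModel(config)

input_ids = jnp.ones((1, 5), dtype="i4")
# stand-in for encoder outputs (e.g. from a Flax BERT encoder); the last dim must equal n_embd
encoder_hidden_states = jnp.zeros((1, 7, config.n_embd))
encoder_attention_mask = jnp.ones((1, 7), dtype="i4")

outputs = model(
    input_ids,
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_mask=encoder_attention_mask,
    output_attentions=True,
)
print(outputs.logits.shape)           # (1, 5, config.vocab_size)
print(len(outputs.cross_attentions))  # one cross-attention tensor per block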
...@@ -822,7 +822,11 @@ class GPT2Model(GPT2PreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
...
...@@ -620,13 +620,21 @@ class FlaxRobertaModule(nn.Module):
self,
input_ids,
attention_mask,
token_type_ids: Optional[np.ndarray] = None,
position_ids: Optional[np.ndarray] = None,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# make sure `token_type_ids` is correctly initialized when not passed
if token_type_ids is None:
token_type_ids = jnp.zeros_like(input_ids)
# make sure `position_ids` is correctly initialized when not passed
if position_ids is None:
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
hidden_states = self.embeddings(
input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
)
...
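A standalone sketch of the fallback values introduced above, with an illustrative batch of input ids, so the defaults can be inspected outside the module:

import jax.numpy as jnp

input_ids = jnp.array([[10, 11, 12, 13]])

# token_type_ids default to all zeros ...
token_type_ids = jnp.zeros_like(input_ids)
# ... and position_ids to 0..seq_len-1, broadcast over the batch
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

print(token_type_ids)  # [[0 0 0 0]]
print(position_ids)    # [[0 1 2 3]]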
...@@ -516,6 +516,15 @@ class FlaxElectraPreTrainedModel:
requires_backends(cls, ["flax"])
class FlaxEncoderDecoderModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxGPT2LMHeadModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
...
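For context, a short sketch of what the dummy object buys in an environment that is assumed to lack the flax backend: the name stays importable, and a descriptive error is only raised on use.

from transformers import FlaxEncoderDecoderModel  # importable even without flax installed

try:
    FlaxEncoderDecoderModel()  # the dummy raises because the flax backend is unavailable
except ImportError as err:
    print(err)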
...@@ -55,13 +55,13 @@ if is_torch_available():
@require_torch
class EncoderDecoderMixin:
def get_encoder_decoder_model(self, config, decoder_config):
raise NotImplementedError
def prepare_config_and_inputs(self):
raise NotImplementedError
def get_pretrained_model(self):
raise NotImplementedError
def check_encoder_decoder_model_from_pretrained_configs(
self,
...@@ -776,6 +776,24 @@ class GPT2EncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
def test_encoder_decoder_model_shared_weights(self):
pass
@slow
def test_bert2gpt2_summarization(self):
model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2gpt2-cnn_dailymail-fp16")
model.to(torch_device)
tokenizer_in = AutoTokenizer.from_pretrained("bert-base-cased")
tokenizer_out = AutoTokenizer.from_pretrained("gpt2")
ARTICLE_STUDENTS = """(CNN)Sigma Alpha Epsilon is under fire for a video showing party-bound fraternity members singing a racist chant. SAE's national chapter suspended the students, but University of Oklahoma President David Boren took it a step further, saying the university's affiliation with the fraternity is permanently done. The news is shocking, but it's not the first time SAE has faced controversy. SAE was founded March 9, 1856, at the University of Alabama, five years before the American Civil War, according to the fraternity website. When the war began, the group had fewer than 400 members, of which "369 went to war for the Confederate States and seven for the Union Army," the website says. The fraternity now boasts more than 200,000 living alumni, along with about 15,000 undergraduates populating 219 chapters and 20 "colonies" seeking full membership at universities. SAE has had to work hard to change recently after a string of member deaths, many blamed on the hazing of new recruits, SAE national President Bradley Cohen wrote in a message on the fraternity's website. The fraternity's website lists more than 130 chapters cited or suspended for "health and safety incidents" since 2010. At least 30 of the incidents involved hazing, and dozens more involved alcohol. However, the list is missing numerous incidents from recent months. Among them, according to various media outlets: Yale University banned the SAEs from campus activities last month after members allegedly tried to interfere with a sexual misconduct investigation connected to an initiation rite. Stanford University in December suspended SAE housing privileges after finding sorority members attending a fraternity function were subjected to graphic sexual content. And Johns Hopkins University in November suspended the fraternity for underage drinking. "The media has labeled us as the 'nation's deadliest fraternity,' " Cohen said. In 2011, for example, a student died while being coerced into excessive alcohol consumption, according to a lawsuit. SAE's previous insurer dumped the fraternity. "As a result, we are paying Lloyd's of London the highest insurance rates in the Greek-letter world," Cohen said. Universities have turned down SAE's attempts to open new chapters, and the fraternity had to close 12 in 18 months over hazing incidents."""
EXPECTED_SUMMARY_STUDENTS = """SAS Alpha Epsilon suspended the students, but university president says it's permanent.\nThe fraternity has had to deal with a string of student deaths since 2010.\nSAS has more than 200,000 members, many of whom are students.\nA student died while being forced into excessive alcohol consumption."""
input_dict = tokenizer_in(ARTICLE_STUDENTS, return_tensors="pt")
output_ids = model.generate(input_dict["input_ids"].to(torch_device))
summary = tokenizer_out.batch_decode(output_ids, skip_special_tokens=True)
self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS])
@require_torch
class ProphetNetEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
...
# coding=utf-8
# Copyright 2020 HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import numpy as np
from transformers import is_flax_available
from transformers.testing_utils import require_flax, slow
from .test_modeling_flax_bert import FlaxBertModelTester
from .test_modeling_flax_common import ids_tensor
from .test_modeling_flax_gpt2 import FlaxGPT2ModelTester
if is_flax_available():
from transformers import (
AutoConfig,
AutoTokenizer,
EncoderDecoderConfig,
FlaxBertModel,
FlaxEncoderDecoderModel,
FlaxGPT2LMHeadModel,
)
@require_flax
class FlaxEncoderDecoderMixin:
def get_encoder_decoder_model(self, config, decoder_config):
raise NotImplementedError
def prepare_config_and_inputs(self):
raise NotImplementedError
def get_pretrained_model(self):
raise NotImplementedError
def check_encoder_decoder_model_from_pretrained_configs(
self,
config,
input_ids,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
):
encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
self.assertTrue(encoder_decoder_config.decoder.is_decoder)
enc_dec_model = FlaxEncoderDecoderModel(encoder_decoder_config)
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
outputs_encoder_decoder = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
self.assertEqual(
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
)
self.assertEqual(
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
)
def check_encoder_decoder_model_from_pretrained(
self,
config,
input_ids,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
return_dict,
**kwargs
):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model, "return_dict": return_dict}
enc_dec_model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
outputs_encoder_decoder = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
return_dict=True,
)
self.assertEqual(
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
)
self.assertEqual(
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
)
def check_save_and_load(
self,
config,
input_ids,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model}
enc_dec_model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
outputs = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
out_2 = np.array(outputs[0])
out_2[np.isnan(out_2)] = 0
with tempfile.TemporaryDirectory() as tmpdirname:
enc_dec_model.save_pretrained(tmpdirname)
enc_dec_model = FlaxEncoderDecoderModel.from_pretrained(tmpdirname)
after_outputs = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
out_1 = np.array(after_outputs[0])
out_1[np.isnan(out_1)] = 0
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
def check_encoder_decoder_model_output_attentions(
self,
config,
input_ids,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
):
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids = decoder_input_ids[:, :-1]
decoder_attention_mask = decoder_attention_mask[:, :-1]
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model}
enc_dec_model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
outputs_encoder_decoder = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
output_attentions=True,
)
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
self.assertEqual(
encoder_attentions[0].shape[-3:], (config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1])
)
decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
num_decoder_layers = (
decoder_config.num_decoder_layers
if hasattr(decoder_config, "num_decoder_layers")
else decoder_config.num_hidden_layers
)
self.assertEqual(len(decoder_attentions), num_decoder_layers)
self.assertEqual(
decoder_attentions[0].shape[-3:],
(decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]),
)
cross_attentions = outputs_encoder_decoder["cross_attentions"]
self.assertEqual(len(cross_attentions), num_decoder_layers)
cross_attention_input_seq_len = decoder_input_ids.shape[-1] * (
1 + (decoder_config.ngram if hasattr(decoder_config, "ngram") else 0)
)
self.assertEqual(
cross_attentions[0].shape[-3:],
(decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]),
)
def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model}
enc_dec_model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
pad_token_id = enc_dec_model.config.decoder.pad_token_id
eos_token_id = enc_dec_model.config.decoder.eos_token_id
decoder_start_token_id = enc_dec_model.config.decoder.decoder_start_token_id
# Copied from generation_utils (GPT2 doesn't have `pad_token_id`)
if pad_token_id is None and eos_token_id is not None:
pad_token_id = eos_token_id
if decoder_start_token_id is None:
decoder_start_token_id = enc_dec_model.config.decoder.bos_token_id
# Bert does not have a bos token id, so use pad_token_id instead
# Copied from `test_modeling_encoder_decoder.py`
if decoder_start_token_id is None:
decoder_start_token_id = pad_token_id
generated_output = enc_dec_model.generate(
input_ids,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
decoder_start_token_id=decoder_start_token_id,
)
generated_sequences = generated_output.sequences
self.assertEqual(generated_sequences.shape, (input_ids.shape[0],) + (decoder_config.max_length,))
def test_encoder_decoder_model_from_pretrained_configs(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_from_pretrained_configs(**input_ids_dict)
def test_encoder_decoder_model_from_pretrained(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_from_pretrained(**input_ids_dict, return_dict=False)
def test_encoder_decoder_model_from_pretrained_return_dict(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_from_pretrained(**input_ids_dict, return_dict=True)
def test_save_and_load_from_pretrained(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_save_and_load(**input_ids_dict)
def test_encoder_decoder_model_output_attentions(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_output_attentions(**input_ids_dict)
def test_encoder_decoder_model_generate(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_generate(**input_ids_dict)
@slow
def test_real_model_save_load_from_pretrained(self):
model_2 = self.get_pretrained_model()
input_ids = ids_tensor([13, 5], model_2.config.encoder.vocab_size)
decoder_input_ids = ids_tensor([13, 1], model_2.config.encoder.vocab_size)
attention_mask = ids_tensor([13, 5], vocab_size=2)
outputs = model_2(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
)
out_2 = np.array(outputs[0])
out_2[np.isnan(out_2)] = 0
with tempfile.TemporaryDirectory() as tmp_dirname:
model_2.save_pretrained(tmp_dirname)
model_1 = FlaxEncoderDecoderModel.from_pretrained(tmp_dirname)
after_outputs = model_1(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
)
out_1 = np.array(after_outputs[0])
out_1[np.isnan(out_1)] = 0
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
@require_flax
class FlaxGPT2EncoderDecoderModelTest(FlaxEncoderDecoderMixin, unittest.TestCase):
def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = FlaxBertModel(config)
decoder_model = FlaxGPT2LMHeadModel(decoder_config)
return encoder_model, decoder_model
def prepare_config_and_inputs(self):
model_tester_encoder = FlaxBertModelTester(self, batch_size=13)
model_tester_decoder = FlaxGPT2ModelTester(self, batch_size=13)
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
(config, input_ids, token_type_ids, attention_mask) = encoder_config_and_inputs
(
decoder_config,
decoder_input_ids,
decoder_attention_mask,
encoder_hidden_states,
encoder_attention_mask,
) = decoder_config_and_inputs
# make sure that cross attention layers are added
decoder_config.add_cross_attention = True
return {
"config": config,
"input_ids": input_ids,
"attention_mask": attention_mask,
"decoder_config": decoder_config,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"encoder_hidden_states": encoder_hidden_states,
}
def get_pretrained_model(self):
return FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")
@slow
def test_bert2gpt2_summarization(self):
tokenizer_in = AutoTokenizer.from_pretrained("bert-base-cased")
tokenizer_out = AutoTokenizer.from_pretrained("gpt2")
model = FlaxEncoderDecoderModel.from_pretrained(
"patrickvonplaten/bert2gpt2-cnn_dailymail-fp16", pad_token_id=tokenizer_out.eos_token_id
)
ARTICLE_STUDENTS = """(CNN)Sigma Alpha Epsilon is under fire for a video showing party-bound fraternity members singing a racist chant. SAE's national chapter suspended the students, but University of Oklahoma President David Boren took it a step further, saying the university's affiliation with the fraternity is permanently done. The news is shocking, but it's not the first time SAE has faced controversy. SAE was founded March 9, 1856, at the University of Alabama, five years before the American Civil War, according to the fraternity website. When the war began, the group had fewer than 400 members, of which "369 went to war for the Confederate States and seven for the Union Army," the website says. The fraternity now boasts more than 200,000 living alumni, along with about 15,000 undergraduates populating 219 chapters and 20 "colonies" seeking full membership at universities. SAE has had to work hard to change recently after a string of member deaths, many blamed on the hazing of new recruits, SAE national President Bradley Cohen wrote in a message on the fraternity's website. The fraternity's website lists more than 130 chapters cited or suspended for "health and safety incidents" since 2010. At least 30 of the incidents involved hazing, and dozens more involved alcohol. However, the list is missing numerous incidents from recent months. Among them, according to various media outlets: Yale University banned the SAEs from campus activities last month after members allegedly tried to interfere with a sexual misconduct investigation connected to an initiation rite. Stanford University in December suspended SAE housing privileges after finding sorority members attending a fraternity function were subjected to graphic sexual content. And Johns Hopkins University in November suspended the fraternity for underage drinking. "The media has labeled us as the 'nation's deadliest fraternity,' " Cohen said. In 2011, for example, a student died while being coerced into excessive alcohol consumption, according to a lawsuit. SAE's previous insurer dumped the fraternity. "As a result, we are paying Lloyd's of London the highest insurance rates in the Greek-letter world," Cohen said. Universities have turned down SAE's attempts to open new chapters, and the fraternity had to close 12 in 18 months over hazing incidents."""
EXPECTED_SUMMARY_STUDENTS = """SAE's national chapter suspended the students, but university president says it's permanent.\nSAE's national chapter has had to work hard to change recently.\nSAE's chapter has more than 200,000 members.\nSAE's chapter has been criticized for its hazing of new recruits."""
input_dict = tokenizer_in(ARTICLE_STUDENTS, return_tensors="np")
output_ids = model.generate(input_dict["input_ids"]).sequences
summary = tokenizer_out.batch_decode(output_ids, skip_special_tokens=True)
self.assertEqual(summary, [EXPECTED_SUMMARY_STUDENTS])
@require_flax
class FlaxEncoderDecoderModelTest(unittest.TestCase):
def get_from_encoderdecoder_pretrained_model(self):
return FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")
def get_decoder_config(self):
config = AutoConfig.from_pretrained("gpt2")
config.is_decoder = True
config.add_cross_attention = True
return config
def _check_configuration_tie(self, model):
assert id(model.decoder.config) == id(model.config.decoder)
assert id(model.encoder.config) == id(model.config.encoder)
@slow
def test_configuration_tie(self):
model = self.get_from_encoderdecoder_pretrained_model()
self._check_configuration_tie(model)
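As a usage-level companion to the tests above, a minimal sketch of warm-starting a bert2gpt2 FlaxEncoderDecoderModel and generating from it. It assumes hub access and the flax backend; the pad and decoder-start token handling mirrors the GPT2 workarounds in the tests, and since the cross-attention weights are untrained the generated text itself is meaningless:

from transformers import AutoTokenizer, FlaxEncoderDecoderModel

tokenizer_in = AutoTokenizer.from_pretrained("bert-base-cased")
tokenizer_out = AutoTokenizer.from_pretrained("gpt2")
model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")

inputs = tokenizer_in("Plants create energy through photosynthesis.", return_tensors="np")

# GPT2 has no pad token, so reuse eos; start decoding from the GPT2 bos token
generated = model.generate(
    inputs["input_ids"],
    max_length=16,
    pad_token_id=tokenizer_out.eos_token_id,
    decoder_start_token_id=model.config.decoder.bos_token_id,
)
print(tokenizer_out.batch_decode(generated.sequences, skip_special_tokens=True))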
...@@ -23,7 +23,7 @@ from transformers import GPT2Config, GPT2Tokenizer, is_flax_available, is_torch_
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
from .test_generation_flax_utils import FlaxGenerationTesterMixin
from .test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_flax_available():
...@@ -111,6 +111,20 @@ class FlaxGPT2ModelTester:
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
return config, inputs_dict
def prepare_config_and_inputs_for_decoder(self):
config, input_ids, attention_mask = self.prepare_config_and_inputs()
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
return (
config,
input_ids,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
)
def check_use_cache_forward(self, model_class_name, config, input_ids, attention_mask):
max_decoder_length = 20
model = model_class_name(config)
...
...@@ -155,6 +155,7 @@ def get_model_modules():
"modeling_retribert",
"modeling_utils",
"modeling_flax_auto",
"modeling_flax_encoder_decoder",
"modeling_flax_utils", "modeling_flax_utils",
"modeling_transfo_xl_utilities", "modeling_transfo_xl_utilities",
"modeling_tf_auto", "modeling_tf_auto",
...@@ -226,6 +227,7 @@ def get_model_test_files(): ...@@ -226,6 +227,7 @@ def get_model_test_files():
_ignore_files = [ _ignore_files = [
"test_modeling_common", "test_modeling_common",
"test_modeling_encoder_decoder", "test_modeling_encoder_decoder",
"test_modeling_flax_encoder_decoder",
"test_modeling_marian", "test_modeling_marian",
"test_modeling_tf_common", "test_modeling_tf_common",
] ]
......