"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "02b176c4ce14340d26d42825523f406959c6c202"
Unverified Commit 4a51b1dd authored by Daniel Stancl's avatar Daniel Stancl Committed by GitHub
Browse files

FlaxBart (#11537)



* Start working on FlaxBart

* Create modeling_flax_bart.py

* Write FlaxBartAttention

* Add FlaxBartEncoderLayer

* Add FlaxBartDecoderLayer and some typing

* Add helper function for FlaxBart

* shift_tokens_right

* _make_causal_mask

* _expand_mask

* Add PositionalEmbedding and fix init_std naming

* Add FlaxBartPretrainedModel

* Add FlaxBartEncoder

* Add FlaxBartEncoder

* Add FlaxBartEncoder among modules to be imported

* YET WE CANNOT INITIALIZE THAT!! :(

* Make BartEncoder work

Change BartEncoder to an instance of nn.Module for now

* Add FlaxBartDecoder

* Add FlaxBartModel

* TODO to make model run -> Prepare model inputs

* Resolve padding

* Add FlaxBartModel

* Add FlaxBartModel into importable modules

* Remove FlaxBartEncoder and FlaxBartDecoder from importable modules

* make style; not properly working

* make style; make quality does not pass due to an import I left

* Remove TODO for padding_idx in nn.Embed for now

* Add FlaxBartForConditionalGeneration

* Incorporate Flax model output classes, i.e. return_dict

* Add other models and incorporate use_cache arg

* Add FlaxBartForSequenceClassification and FlaxBartForQuestionAnswering

* Incorporate use_cache arg from PyTorch implementation

* Add all necessary Flax output utils

* Add FlaxBartForCausalLM; not working yet

* Add minor improvements; still lacks some functionality

* Update docs, src and tests

* Add support of FlaxBart to docs/source

* Fix some bugs in FlaxBart source code

* Add some necessary tests for FlaxBart models - jit_compilation not passing

* Fix tests and add test_head_masking

* Fix tests for @jax.jit computation

* Add test_head_masking

* Migrate FlaxBart tests from jax.numpy to numpy

* Remove FlaxBartForCausalLM

* Clean repo

* fix bart model weight structure

* Fix FlaxBartForSequenceClassification

Slicing cannot be used under jit, so the way the sentence
representation is selected from hidden_states must be changed.

* Allow FlaxBartForSequenceClassification for testing pt_flax equivalence

* Allow testing for FlaxBartForQA for pt_flax equivalence

* Add a comment to FlaxBartForSequenceClassification + change noise from 1e-3 to 1e-6

* remove past_key_values

* remove inputs_embeds and make input_ids required

* add position ids

* re-write attention layer

* fix dataclass

* fix pos embeds and attention output

* fix pos embeds

* expose encode method

* expose decode method

* move docstring to top

* add cache for causal attn layer

* remove head masking for now

* s2s greedy search first pass

* boom boom

* fix typos

* fix greedy generate for bart

* use encoder, decoder layers instead of num_hidden_layers

* handle encoder_outputs

* cleanup

* simplify decoding

* more clean-up

* typos

* Change header + add {decoder_,}position_ids into 2 models

* add BartConfig

* fix existing tests

* add encode, decode methods

* Fix shift_tokens_right for JIT compilation + clarify one condition

* fix decode

* encoder => encode

* simplify generate

* add tests for encode and decode

* style

* add tests for cache

* fix equivalence tests

* sample generate now works with seq2seq

* generation tests

* initialize dense layers

* docstring and cleanup

* quality

* remove get/set input_embeddings

* address Patrick's suggestions

* decode for every model, remove encoder_outputs from call

* update tests accordingly

* decode returns only decoder outputs and logits

* fix arguments

* doc encode, decode methods

* correct base_model_prefix

* fix test for seq classif model

* fix docs
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
parent d36fce82
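A minimal usage sketch of what this commit enables (not part of the diff itself); the checkpoint name, the from_pt conversion, and the generation settings are illustrative assumptions:

from transformers import BartTokenizer, FlaxBartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
# from_pt=True converts PyTorch weights in case the checkpoint ships no Flax weights (assumption)
model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-base", from_pt=True)

inputs = tokenizer(["My friends are cool but they eat too many carbs."], return_tensors="np")
# Seq2seq greedy generation goes through FlaxGenerationMixin.generate, which now prepares
# encoder_outputs and decoder_start_token_id for encoder-decoder models (see the diff below);
# the returned output object carries the generated token ids in its `sequences` field.
output = model.generate(inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=20)
print(tokenizer.batch_decode(output.sequences, skip_special_tokens=True))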
...@@ -299,7 +299,7 @@ Flax), PyTorch, and/or TensorFlow.
+=============================+================+================+=================+====================+==============+
| ALBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| BART | ✅ | ✅ | ✅ | ✅ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| BERT | ✅ | ✅ | ✅ | ✅ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
......
...@@ -131,6 +131,7 @@ BartForQuestionAnswering
.. autoclass:: transformers.BartForQuestionAnswering
:members: forward
BartForCausalLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...@@ -138,7 +139,6 @@ BartForCausalLM
:members: forward
TFBartModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...@@ -151,3 +151,32 @@ TFBartForConditionalGeneration
.. autoclass:: transformers.TFBartForConditionalGeneration
:members: call
FlaxBartModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxBartModel
:members: __call__, encode, decode
FlaxBartForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxBartForConditionalGeneration
:members: __call__, encode, decode
FlaxBartForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxBartForSequenceClassification
:members: __call__, encode, decode
FlaxBartForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxBartForQuestionAnswering
:members: __call__, encode, decode
...@@ -1508,6 +1508,14 @@ if is_flax_available():
"FlaxAutoModelForTokenClassification",
]
)
_import_structure["models.bart"].extend(
[
"FlaxBartForConditionalGeneration",
"FlaxBartForQuestionAnswering",
"FlaxBartForSequenceClassification",
"FlaxBartModel",
]
)
_import_structure["models.bert"].extend(
[
"FlaxBertForMaskedLM",
...@@ -2808,6 +2816,12 @@ if TYPE_CHECKING:
FlaxAutoModelForSequenceClassification,
FlaxAutoModelForTokenClassification,
)
from .models.bart import (
FlaxBartForConditionalGeneration,
FlaxBartForQuestionAnswering,
FlaxBartForSequenceClassification,
FlaxBartModel,
)
from .models.bert import (
FlaxBertForMaskedLM,
FlaxBertForMultipleChoice,
......
...@@ -101,12 +101,23 @@ class FlaxGenerationMixin:
state = body_fn(state)
return state
def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids, model_kwargs):
encoder_kwargs = {
argument: value
for argument, value in model_kwargs.items()
if not (argument.startswith("decoder_") or argument.startswith("cross_attn"))
}
model_kwargs["encoder_outputs"] = self.encode(input_ids, return_dict=True, **encoder_kwargs)
return model_kwargs
def generate(
self,
input_ids: jax_xla.DeviceArray,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
bos_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
decoder_start_token_id: Optional[int] = None,
do_sample: Optional[bool] = None,
prng_key: Optional[jax_xla.DeviceArray] = None,
top_k: Optional[int] = None,
...@@ -147,6 +158,8 @@ class FlaxGenerationMixin:
The id of the `beginning-of-sequence` token.
eos_token_id (:obj:`int`, `optional`):
The id of the `end-of-sequence` token.
decoder_start_token_id (:obj:`int`, `optional`):
If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
trace (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to trace generation. Setting ``trace=False`` should only be used for debugging and will lead to
a considerably slower runtime.
...@@ -170,10 +183,23 @@ class FlaxGenerationMixin:
"""
# set init values
max_length = max_length if max_length is not None else self.config.max_length
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id else self.config.decoder_start_token_id
)
prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
if decoder_start_token_id is None and self.config.is_encoder_decoder:
raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")
if self.config.is_encoder_decoder:
# add encoder_outputs to model_kwargs
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
# prepare decoder_input_ids for generation
input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
do_sample = do_sample if do_sample is not None else self.config.do_sample
if do_sample:
...@@ -246,10 +272,11 @@ class FlaxGenerationMixin:
# per batch-item state bit indicating if sentence has finished.
is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
# For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
# and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
model = self.decode if self.config.is_encoder_decoder else self
# initialize model specific kwargs
model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)
# initialize state
state = GreedyState(
...@@ -277,8 +304,7 @@ class FlaxGenerationMixin:
next_token = next_token[:, None]
next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)
return GreedyState(
cur_len=state.cur_len + 1,
sequences=next_sequences,
...@@ -288,7 +314,8 @@ class FlaxGenerationMixin:
)
# The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
if input_ids.shape[1] > 1:
state = greedy_search_body_fn(state)
if not trace:
state = self._run_loop_in_debug(greedy_search_cond_fn, greedy_search_body_fn, state)
...@@ -327,10 +354,12 @@ class FlaxGenerationMixin:
# per batch-item state bit indicating if sentence has finished.
is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
# For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
# and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
model = self.decode if self.config.is_encoder_decoder else self
# initialize model specific kwargs
model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)
# initialize state
state = SampleState(
...@@ -366,7 +395,7 @@ class FlaxGenerationMixin:
next_token = next_token[:, None]
next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)
return SampleState(
cur_len=state.cur_len + 1,
...@@ -378,7 +407,8 @@ class FlaxGenerationMixin:
)
# The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
if input_ids.shape[1] > 1:
state = sample_search_body_fn(state)
if not trace:
state = self._run_loop_in_debug(sample_search_cond_fn, sample_search_body_fn, state)
......
...@@ -106,6 +106,154 @@ class FlaxBaseModelOutputWithPooling(ModelOutput):
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@flax.struct.dataclass
class FlaxBaseModelOutputWithPastAndCrossAttentions(ModelOutput):
"""
Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
Args:
last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size,
1, hidden_size)` is output.
past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2
tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
``config.is_encoder_decoder=True`` 2 additional tensors of shape :obj:`(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
``config.is_encoder_decoder=True`` in the cross-attention blocks) that can be used (see
:obj:`past_key_values` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` and ``config.add_cross_attention=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
"""
last_hidden_state: jax_xla.DeviceArray = None
past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@flax.struct.dataclass
class FlaxSeq2SeqModelOutput(ModelOutput):
"""
Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential
decoding.
Args:
last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the decoder of the model.
If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size,
1, hidden_size)` is output.
past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2
tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional
tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder of the model.
encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
"""
last_hidden_state: jax_xla.DeviceArray = None
past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None
decoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
decoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
encoder_last_hidden_state: Optional[jax_xla.DeviceArray] = None
encoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
encoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@flax.struct.dataclass
class FlaxCausalLMOutputWithCrossAttentions(ModelOutput):
"""
Base class for causal language model (or autoregressive) outputs.
Args:
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Cross attentions weights after the attention softmax, used to compute the weighted average in the
cross-attention heads.
past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
Tuple of :obj:`jax_xla.DeviceArray` tuples of length :obj:`config.n_layers`, with each tuple containing the
cached key, value states of the self-attention and the cross-attention layers if model is used in
encoder-decoder setting. Only relevant if ``config.is_decoder = True``.
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
:obj:`past_key_values` input) to speed up sequential decoding.
"""
logits: jax_xla.DeviceArray = None
past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@flax.struct.dataclass
class FlaxMaskedLMOutput(ModelOutput):
"""
...@@ -135,6 +283,63 @@ class FlaxMaskedLMOutput(ModelOutput):
FlaxCausalLMOutput = FlaxMaskedLMOutput
@flax.struct.dataclass
class FlaxSeq2SeqLMOutput(ModelOutput):
"""
Base class for sequence-to-sequence language models outputs.
Args:
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2
tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional
tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder of the model.
encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
"""
logits: jax_xla.DeviceArray = None
past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None
decoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
decoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
encoder_last_hidden_state: Optional[jax_xla.DeviceArray] = None
encoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
encoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@flax.struct.dataclass
class FlaxNextSentencePredictorOutput(ModelOutput):
"""
...@@ -188,6 +393,63 @@ class FlaxSequenceClassifierOutput(ModelOutput):
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@flax.struct.dataclass
class FlaxSeq2SeqSequenceClassifierOutput(ModelOutput):
"""
Base class for outputs of sequence-to-sequence sentence classification models.
Args:
logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2
tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional
tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder of the model.
encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
"""
logits: jax_xla.DeviceArray = None
past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None
decoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
decoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
encoder_last_hidden_state: Optional[jax_xla.DeviceArray] = None
encoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
encoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@flax.struct.dataclass
class FlaxMultipleChoiceModelOutput(ModelOutput):
"""
...@@ -269,3 +531,63 @@ class FlaxQuestionAnsweringModelOutput(ModelOutput):
end_logits: jax_xla.DeviceArray = None
hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
@flax.struct.dataclass
class FlaxSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
"""
Base class for outputs of sequence-to-sequence question answering models.
Args:
start_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
Span-start scores (before SoftMax).
end_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length)`):
Span-end scores (before SoftMax).
past_key_values (:obj:`tuple(tuple(jax_xla.DeviceArray))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
Tuple of :obj:`tuple(jax_xla.DeviceArray)` of length :obj:`config.n_layers`, with each tuple having 2
tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional
tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
decoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
cross_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
encoder_last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder of the model.
encoder_hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each
layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
encoder_attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
"""
start_logits: jax_xla.DeviceArray = None
end_logits: jax_xla.DeviceArray = None
past_key_values: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None
decoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
decoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
cross_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
encoder_last_hidden_state: Optional[jax_xla.DeviceArray] = None
encoder_hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None
encoder_attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
...@@ -18,6 +18,12 @@
from collections import OrderedDict
from ...utils import logging
from ..bart.modeling_flax_bart import (
FlaxBartForConditionalGeneration,
FlaxBartForQuestionAnswering,
FlaxBartForSequenceClassification,
FlaxBartModel,
)
from ..bert.modeling_flax_bert import (
FlaxBertForMaskedLM,
FlaxBertForMultipleChoice,
...@@ -49,7 +55,7 @@ from ..roberta.modeling_flax_roberta import (
)
from ..vit.modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
from .auto_factory import auto_class_factory
from .configuration_auto import BartConfig, BertConfig, CLIPConfig, ElectraConfig, GPT2Config, RobertaConfig, ViTConfig
logger = logging.get_logger(__name__)
...@@ -60,6 +66,7 @@ FLAX_MODEL_MAPPING = OrderedDict(
# Base model mapping
(RobertaConfig, FlaxRobertaModel),
(BertConfig, FlaxBertModel),
(BartConfig, FlaxBartModel),
(GPT2Config, FlaxGPT2Model),
(ElectraConfig, FlaxElectraModel),
(CLIPConfig, FlaxCLIPModel),
...@@ -72,6 +79,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
# Model for pre-training mapping
(RobertaConfig, FlaxRobertaForMaskedLM),
(BertConfig, FlaxBertForPreTraining),
(BartConfig, FlaxBartForConditionalGeneration),
(ElectraConfig, FlaxElectraForPreTraining),
]
)
...@@ -81,6 +89,7 @@ FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
# Model for Masked LM mapping
(RobertaConfig, FlaxRobertaForMaskedLM),
(BertConfig, FlaxBertForMaskedLM),
(BartConfig, FlaxBartForConditionalGeneration),
(ElectraConfig, FlaxElectraForMaskedLM),
]
)
...@@ -104,6 +113,7 @@ FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
# Model for Sequence Classification mapping
(RobertaConfig, FlaxRobertaForSequenceClassification),
(BertConfig, FlaxBertForSequenceClassification),
(BartConfig, FlaxBartForSequenceClassification),
(ElectraConfig, FlaxElectraForSequenceClassification),
]
)
...@@ -113,6 +123,7 @@ FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
# Model for Question Answering mapping
(RobertaConfig, FlaxRobertaForQuestionAnswering),
(BertConfig, FlaxBertForQuestionAnswering),
(BartConfig, FlaxBartForQuestionAnswering),
(ElectraConfig, FlaxElectraForQuestionAnswering),
]
)
......
...@@ -17,7 +17,13 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import (
_BaseLazyModule,
is_flax_available,
is_tf_available,
is_tokenizers_available,
is_torch_available,
)
_import_structure = {
...@@ -43,6 +49,13 @@ if is_torch_available():
if is_tf_available():
_import_structure["modeling_tf_bart"] = ["TFBartForConditionalGeneration", "TFBartModel", "TFBartPretrainedModel"]
if is_flax_available():
_import_structure["modeling_flax_bart"] = [
"FlaxBartForConditionalGeneration",
"FlaxBartForQuestionAnswering",
"FlaxBartForSequenceClassification",
"FlaxBartModel",
]
if TYPE_CHECKING:
from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig
...@@ -66,6 +79,14 @@ if TYPE_CHECKING:
if is_tf_available():
from .modeling_tf_bart import TFBartForConditionalGeneration, TFBartModel, TFBartPretrainedModel
if is_flax_available():
from .modeling_flax_bart import (
FlaxBartForConditionalGeneration,
FlaxBartForQuestionAnswering,
FlaxBartForSequenceClassification,
FlaxBartModel,
)
else:
import importlib
import os
......
# coding=utf-8
# Copyright 2021 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Flax Bart model. """
import math
import random
from functools import partial
from typing import Callable, Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from jax import lax
from jax.random import PRNGKey
from ...file_utils import add_start_docstrings, replace_return_docstrings
from ...modeling_flax_outputs import (
FlaxBaseModelOutput,
FlaxBaseModelOutputWithPastAndCrossAttentions,
FlaxCausalLMOutputWithCrossAttentions,
FlaxSeq2SeqLMOutput,
FlaxSeq2SeqModelOutput,
FlaxSeq2SeqQuestionAnsweringModelOutput,
FlaxSeq2SeqSequenceClassifierOutput,
)
from ...modeling_flax_utils import (
ACT2FN,
FlaxPreTrainedModel,
append_call_sample_docstring,
append_replace_return_docstrings,
overwrite_call_docstring,
)
from ...utils import logging
from .configuration_bart import BartConfig
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "facebook/bart-base"
_CONFIG_FOR_DOC = "BartConfig"
_TOKENIZER_FOR_DOC = "BartTokenizer"
BART_START_DOCSTRING = r"""
This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
generic methods the library implements for all its model (such as downloading or saving, resizing the input
embeddings, pruning heads etc.)
This model is also a Flax Linen `flax.nn.Module
<https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
Module and refer to the Flax documentation for all matter related to general usage and behavior.
Finally, this model supports inherent JAX features such as:
- `Just-In-Time (JIT) compilation <https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit>`__
- `Automatic Differentiation <https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation>`__
- `Vectorization <https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap>`__
- `Parallelization <https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`__
Parameters:
config (:class:`~transformers.BartConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights.
"""
BART_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using :class:`~transformers.BartTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
decoder_input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Indices of decoder input sequence tokens in the vocabulary.
Indices can be obtained using :class:`~transformers.BartTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are decoder input IDs? <../glossary.html#decoder-input-ids>`__
For translation and summarization training, :obj:`decoder_input_ids` should be provided. If no
:obj:`decoder_input_ids` is provided, the model will create this tensor by shifting the :obj:`input_ids` to
the right for denoising pre-training following the paper.
decoder_attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default.
If you want to change padding behavior, you should modify to your needs. See diagram 1 in `the paper
<https://arxiv.org/abs/1910.13461>`__ for more information on the default strategy.
position_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
config.max_position_embeddings - 1]``.
decoder_position_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
range ``[0, config.max_position_embeddings - 1]``.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""
BART_ENCODE_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using :class:`~transformers.BartTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
position_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
config.max_position_embeddings - 1]``.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""
BART_DECODE_INPUTS_DOCSTRING = r"""
Args:
decoder_input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`):
Indices of decoder input sequence tokens in the vocabulary.
Indices can be obtained using :class:`~transformers.BartTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are decoder input IDs? <../glossary.html#decoder-input-ids>`__
For translation and summarization training, :obj:`decoder_input_ids` should be provided. If no
:obj:`decoder_input_ids` is provided, the model will create this tensor by shifting the :obj:`input_ids` to
the right for denoising pre-training following the paper.
encoder_outputs (:obj:`tuple(tuple(jnp.ndarray)`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`:
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
`optional`) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
cross-attention of the decoder.
encoder_attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
decoder_attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default.
If you want to change padding behavior, you should modify to your needs. See diagram 1 in `the paper
<https://arxiv.org/abs/1910.13461>`__ for more information on the default strategy.
decoder_position_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
range ``[0, config.max_position_embeddings - 1]``.
past_key_values (:obj:`Dict[str, np.ndarray]`, `optional`, returned by ``init_cache`` or when passing previous ``past_key_values``):
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
auto-regressive decoding. Pre-computed key and value hidden-states are of shape `[batch_size, max_length]`.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""
def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
"""
Shift input ids one token to the right.
"""
shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id)
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
return shifted_input_ids
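# Illustrative note (not part of the original file): with pad_token_id=1 and decoder_start_token_id=2,
# shift_tokens_right applied to [[5, 6, 7]] returns [[2, 5, 6]] -- the ids are rolled one position to the
# right and the first position is set to the decoder start token; any -100 label values carried over are
# then replaced by the pad token id so they remain valid embedding indices.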
class FlaxBartAttention(nn.Module):
config: BartConfig
embed_dim: int
num_heads: int
dropout: float = 0.0
is_decoder: bool = False
causal: bool = False
bias: bool = True
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self) -> None:
self.head_dim = self.embed_dim // self.num_heads
assert (
self.head_dim * self.num_heads == self.embed_dim
), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
dense = partial(
nn.Dense,
self.embed_dim,
use_bias=self.bias,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
)
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
self.out_proj = dense()
self.dropout_layer = nn.Dropout(rate=self.dropout)
if self.causal:
self.causal_mask = make_causal_mask(
jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
)
def _split_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
def _merge_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
@nn.compact
def _concatenate_to_cache(self, key, value, query, attention_mask):
"""
This function takes projected key, value states from a single input token and concatenates the states to cached
states from previous steps. This function is slightly adapted from the official Flax repository:
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
"""
# detect if we're initializing by absence of existing cache data.
is_initialized = self.has_variable("cache", "cached_key")
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
if is_initialized:
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
# update key, value caches with our new 1d spatial slices
cur_index = cache_index.value
indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
key = lax.dynamic_update_slice(cached_key.value, key, indices)
value = lax.dynamic_update_slice(cached_value.value, value, indices)
cached_key.value = key
cached_value.value = value
num_updated_cache_vectors = query.shape[1]
cache_index.value = cache_index.value + num_updated_cache_vectors
# causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
pad_mask = jnp.broadcast_to(
jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
)
attention_mask = combine_masks(pad_mask, attention_mask)
return key, value, attention_mask
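# Note: the cached key/value variables have shape (batch, max_length, num_heads, head_dim);
# `cache_index` tracks how many positions have been written so far, and `pad_mask` hides the
# slots that have not been written yet.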
def __call__(
self,
hidden_states: jnp.ndarray,
key_value_states: Optional[jnp.ndarray] = None,
attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
deterministic: bool = True,
) -> Tuple[jnp.ndarray]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
batch_size = hidden_states.shape[0]
# get query proj
query_states = self.q_proj(hidden_states)
# get key, value proj
if is_cross_attention:
# cross_attentions
key_states = self.k_proj(key_value_states)
value_states = self.v_proj(key_value_states)
else:
# self_attention
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = self._split_heads(query_states)
key_states = self._split_heads(key_states)
value_states = self._split_heads(value_states)
# handle cache prepare causal attention mask
if self.causal:
query_length, key_length = query_states.shape[1], key_states.shape[1]
if self.has_variable("cache", "cached_key"):
mask_shift = self.variables["cache"]["cache_index"]
max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
causal_mask = lax.dynamic_slice(
self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
)
else:
causal_mask = self.causal_mask[:, :, :query_length, :key_length]
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
# combine masks if needed
if attention_mask is not None and self.causal:
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
attention_mask = combine_masks(attention_mask, causal_mask)
elif self.causal:
attention_mask = causal_mask
elif attention_mask is not None:
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
# During fast autoregressive decoding, we feed one position at a time,
# and cache the keys and values step by step.
if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
key_states, value_states, attention_mask = self._concatenate_to_cache(
key_states, value_states, query_states, attention_mask
)
# Convert the boolean attention mask to an attention bias.
if attention_mask is not None:
# attention mask in the form of attention bias
attention_bias = lax.select(
attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
)
else:
attention_bias = None
dropout_rng = None
if not deterministic and self.dropout > 0.0:
dropout_rng = self.make_rng("dropout")
attn_weights = dot_product_attention_weights(
query_states,
key_states,
bias=attention_bias,
dropout_rng=dropout_rng,
dropout_rate=self.dropout,
broadcast_dropout=True,
deterministic=deterministic,
dtype=self.dtype,
precision=None,
)
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
attn_output = self._merge_heads(attn_output)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights
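# Standalone usage sketch for FlaxBartAttention (illustrative only; the tiny hyper-parameters
# below are hypothetical and do not correspond to any released checkpoint):
#
#   attn = FlaxBartAttention(config=BartConfig(d_model=16), embed_dim=16, num_heads=4)
#   hidden = jnp.ones((1, 5, 16))
#   variables = attn.init(jax.random.PRNGKey(0), hidden)
#   attn_output, attn_weights = attn.apply(variables, hidden)
#   # attn_output.shape == (1, 5, 16); attn_weights.shape == (1, 4, 5, 5)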
class FlaxBartEncoderLayer(nn.Module):
config: BartConfig
dtype: jnp.dtype = jnp.float32
def setup(self) -> None:
self.embed_dim = self.config.d_model
self.self_attn = FlaxBartAttention(
config=self.config,
embed_dim=self.embed_dim,
num_heads=self.config.encoder_attention_heads,
dropout=self.config.attention_dropout,
)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.activation_fn = ACT2FN[self.config.activation_function]
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
self.fc1 = nn.Dense(
self.config.encoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
)
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
def __call__(
self,
hidden_states: jnp.ndarray,
attention_mask: jnp.ndarray,
output_attentions: bool = True,
deterministic: bool = True,
) -> Tuple[jnp.ndarray]:
residual = hidden_states
hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask)
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = self.fc2(hidden_states)
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class FlaxBartEncoderLayerCollection(nn.Module):
config: BartConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.layers = [
FlaxBartEncoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.encoder_layers)
]
self.layerdrop = self.config.encoder_layerdrop
def __call__(
self,
hidden_states,
attention_mask,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for encoder_layer in self.layers:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1)
if not deterministic and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions,
deterministic,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states += (hidden_states,)
outputs = (hidden_states, all_hidden_states, all_attentions)
if not return_dict:
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
class FlaxBartDecoderLayer(nn.Module):
config: BartConfig
dtype: jnp.dtype = jnp.float32
def setup(self) -> None:
self.embed_dim = self.config.d_model
self.self_attn = FlaxBartAttention(
config=self.config,
embed_dim=self.embed_dim,
num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout,
is_decoder=True,
causal=True,
)
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.activation_fn = ACT2FN[self.config.activation_function]
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.encoder_attn = FlaxBartAttention(
config=self.config,
embed_dim=self.embed_dim,
num_heads=self.config.decoder_attention_heads,
dropout=self.config.attention_dropout,
is_decoder=True,
)
self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
self.fc1 = nn.Dense(
self.config.decoder_ffn_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
)
self.fc2 = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
)
self.final_layer_norm = nn.LayerNorm(dtype=self.dtype)
def __call__(
self,
hidden_states: jnp.ndarray,
attention_mask: jnp.ndarray,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
output_attentions: bool = True,
deterministic: bool = True,
) -> Tuple[jnp.ndarray]:
residual = hidden_states
# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
)
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
# Cross-Attention Block
cross_attn_weights = None
if encoder_hidden_states is not None:
residual = hidden_states
hidden_states, cross_attn_weights = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
)
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = residual + hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
# Fully Connected
residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = self.fc2(hidden_states)
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights, cross_attn_weights)
return outputs
class FlaxBartDecoderLayerCollection(nn.Module):
config: BartConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.layers = [
FlaxBartDecoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.decoder_layers)
]
self.layerdrop = self.config.decoder_layerdrop
def __call__(
self,
hidden_states,
attention_mask,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
deterministic: bool = True,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1)
if not deterministic and (dropout_probability < self.layerdrop):
layer_outputs = (None, None, None)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
init_cache=init_cache,
output_attentions=output_attentions,
deterministic=deterministic,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions]
if not return_dict:
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
)
class FlaxBartClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
config: BartConfig
inner_dim: int
num_classes: int
pooler_dropout: float
dtype: jnp.dtype = jnp.float32
def setup(self):
self.dense = nn.Dense(
self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
)
self.dropout = nn.Dropout(rate=self.pooler_dropout)
self.out_proj = nn.Dense(
self.num_classes,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
)
def __call__(self, hidden_states: jnp.ndarray, deterministic: bool):
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = self.dense(hidden_states)
hidden_states = jnp.tanh(hidden_states)
hidden_states = self.dropout(hidden_states, deterministic=deterministic)
hidden_states = self.out_proj(hidden_states)
return hidden_states
class FlaxBartEncoder(nn.Module):
config: BartConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
embed_tokens: Optional[nn.Embed] = None
def setup(self):
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.layerdrop = self.config.encoder_layerdrop
embed_dim = self.config.d_model
self.padding_idx = self.config.pad_token_id
self.max_source_positions = self.config.max_position_embeddings
self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0
if self.embed_tokens is None:
self.embed_tokens = nn.Embed(
self.config.vocab_size,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
)
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models don't have this hack
self.offset = 2
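# e.g. with offset=2, position id 0 is looked up at row 2 of the position-embedding table
# below, which therefore has `max_position_embeddings + 2` rows.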
self.embed_positions = nn.Embed(
self.config.max_position_embeddings + self.offset,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
)
self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
position_ids,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
input_shape = input_ids.shape
input_ids = input_ids.reshape(-1, input_shape[-1])
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
embed_pos = self.embed_positions(position_ids + self.offset)
hidden_states = inputs_embeds + embed_pos
hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
outputs = self.layers(
hidden_states,
attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if not return_dict:
return outputs
return FlaxBaseModelOutput(
last_hidden_state=outputs.last_hidden_state,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class FlaxBartDecoder(nn.Module):
config: BartConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
embed_tokens: Optional[nn.Embed] = None
def setup(self):
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
self.layerdrop = self.config.decoder_layerdrop
embed_dim = self.config.d_model
self.padding_idx = self.config.pad_token_id
self.max_target_positions = self.config.max_position_embeddings
self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
if self.embed_tokens is None:
self.embed_tokens = nn.Embed(
self.config.vocab_size,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
)
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models don't have this hack
self.offset = 2
self.embed_positions = nn.Embed(
self.config.max_position_embeddings + self.offset,
embed_dim,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
)
self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
position_ids,
encoder_hidden_states: Optional[jnp.ndarray] = None,
encoder_attention_mask: Optional[jnp.ndarray] = None,
init_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
input_shape = input_ids.shape
input_ids = input_ids.reshape(-1, input_shape[-1])
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
# embed positions
positions = self.embed_positions(position_ids + self.offset)
hidden_states = inputs_embeds + positions
hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
outputs = self.layers(
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
deterministic=deterministic,
init_cache=init_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if not return_dict:
return outputs
return FlaxBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=outputs.last_hidden_state,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
class FlaxBartModule(nn.Module):
config: BartConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.shared = nn.Embed(
self.config.vocab_size,
self.config.d_model,
embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
dtype=self.dtype,
)
self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
self.decoder = FlaxBartDecoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
def _get_encoder_module(self):
return self.encoder
def _get_decoder_module(self):
return self.decoder
def __call__(
self,
input_ids,
attention_mask,
decoder_input_ids,
decoder_attention_mask,
position_ids,
decoder_position_ids,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
position_ids=decoder_position_ids,
encoder_hidden_states=encoder_outputs[0],
encoder_attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
if not return_dict:
return decoder_outputs + encoder_outputs
return FlaxSeq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
class FlaxBartPretrainedModel(FlaxPreTrainedModel):
config_class = BartConfig
base_model_prefix: str = "model"
module_class: nn.Module = None
def __init__(
self,
config: BartConfig,
input_shape: Tuple[int] = (1, 1),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
**kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
# init input tensors
input_ids = jnp.zeros(input_shape, dtype="i4")
# make sure initialization pass will work for FlaxBartForSequenceClassificationModule
input_ids = jax.ops.index_update(input_ids, (..., -1), self.config.eos_token_id)
attention_mask = jnp.ones_like(input_ids)
decoder_input_ids = input_ids
decoder_attention_mask = jnp.ones_like(input_ids)
batch_size, sequence_length = input_ids.shape
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(
rngs,
input_ids,
attention_mask,
decoder_input_ids,
decoder_attention_mask,
position_ids,
decoder_position_ids,
)["params"]
def init_cache(self, batch_size, max_length, encoder_outputs):
r"""
Args:
batch_size (:obj:`int`):
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
max_length (:obj:`int`):
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
cache.
encoder_outputs (:obj:`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
``encoder_outputs`` consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`,
`optional`: :obj:`attentions`). :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length,
hidden_size)`, `optional`, is a sequence of hidden-states at the output of the last layer of the
encoder. Used in the cross-attention of the decoder.
"""
# init input variables to retrieve cache
decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
decoder_position_ids = jnp.broadcast_to(
jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
)
def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
decoder_module = module._get_decoder_module()
return decoder_module(
decoder_input_ids,
decoder_attention_mask,
decoder_position_ids,
**kwargs,
)
init_variables = self.module.init(
jax.random.PRNGKey(0),
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
decoder_position_ids=decoder_position_ids,
encoder_hidden_states=encoder_outputs[0],
init_cache=True,
method=_decoder_forward, # we only need to call the decoder to init the cache
)
return unfreeze(init_variables["cache"])
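# A minimal usage sketch for `init_cache` (the checkpoint name and max_length are only
# illustrative, and `input_ids`/`attention_mask` stand for integer jnp arrays):
#
#   model = FlaxBartForConditionalGeneration.from_pretrained('facebook/bart-base')
#   encoder_outputs = model.encode(input_ids, attention_mask)
#   past_key_values = model.init_cache(input_ids.shape[0], max_length=32, encoder_outputs=encoder_outputs)
#   # `past_key_values` can then be passed to `decode`, together with `decoder_position_ids`,
#   # to enable fast auto-regressive decoding.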
@add_start_docstrings(BART_ENCODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=BartConfig)
def encode(
self,
input_ids: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
position_ids: Optional[jnp.ndarray] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
r"""
Returns:
Example::
>>> from transformers import BartTokenizer, FlaxBartForConditionalGeneration
>>> model = FlaxBartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
>>> tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
>>> text = "My friends are cool but they eat too many carbs."
>>> inputs = tokenizer(text, max_length=1024, return_tensors='jax')
>>> encoder_outputs = model.encode(**inputs)
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
if position_ids is None:
batch_size, sequence_length = input_ids.shape
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs):
encode_module = module._get_encoder_module()
return encode_module(input_ids, attention_mask, position_ids, **kwargs)
return self.module.apply(
{"params": params or self.params},
input_ids=jnp.array(input_ids, dtype="i4"),
attention_mask=jnp.array(attention_mask, dtype="i4"),
position_ids=jnp.array(position_ids, dtype="i4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=not train,
rngs=rngs,
method=_encoder_forward,
)
@add_start_docstrings(BART_DECODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=BartConfig)
def decode(
self,
decoder_input_ids,
encoder_outputs,
encoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_position_ids: Optional[jnp.ndarray] = None,
past_key_values: dict = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
r"""
Returns:
Example::
>>> from transformers import BartTokenizer, FlaxBartForConditionalGeneration
>>> model = FlaxBartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
>>> tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
>>> text = "My friends are cool but they eat too many carbs."
>>> inputs = tokenizer(text, max_length=1024, return_tensors='jax')
>>> encoder_outputs = model.encode(**inputs)
>>> decoder_start_token_id = model.config.decoder_start_token_id
>>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
>>> outputs = model.decode(decoder_input_ids, encoder_outputs)
>>> last_decoder_hidden_states = outputs.last_hidden_state
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
encoder_hidden_states = encoder_outputs[0]
if encoder_attention_mask is None:
batch_size, sequence_length = encoder_hidden_states.shape[:2]
encoder_attention_mask = jnp.ones((batch_size, sequence_length))
batch_size, sequence_length = decoder_input_ids.shape
if decoder_attention_mask is None:
decoder_attention_mask = jnp.ones((batch_size, sequence_length))
if decoder_position_ids is None:
if past_key_values is not None:
raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
decoder_position_ids = jnp.broadcast_to(
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
)
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
inputs = {"params": params or self.params}
# if past_key_values are passed then the cache is already initialized; a private flag init_cache has to be
# passed down to ensure the cache is used. It also has to be made sure that the cache is marked as mutable
# so that it can be changed by the FlaxBartAttention module
if past_key_values:
inputs["cache"] = past_key_values
mutable = ["cache"]
else:
mutable = False
def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
decoder_module = module._get_decoder_module()
return decoder_module(
decoder_input_ids,
decoder_attention_mask,
decoder_position_ids,
**kwargs,
)
outputs = self.module.apply(
inputs,
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=not train,
rngs=rngs,
mutable=mutable,
method=_decoder_forward,
)
# add updated cache to model output
if past_key_values is not None and return_dict:
outputs, past = outputs
outputs["past_key_values"] = unfreeze(past["cache"])
return outputs
elif past_key_values is not None and not return_dict:
outputs, past = outputs
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
return outputs
def __call__(
self,
input_ids: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
decoder_input_ids: Optional[jnp.ndarray] = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
position_ids: Optional[jnp.ndarray] = None,
decoder_position_ids: Optional[jnp.ndarray] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
train: bool = False,
params: dict = None,
dropout_rng: PRNGKey = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
# prepare encoder inputs
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
if position_ids is None:
batch_size, sequence_length = input_ids.shape
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
# prepare decoder inputs
if decoder_input_ids is None:
decoder_input_ids = shift_tokens_right(
input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id
)
if decoder_attention_mask is None:
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
if decoder_position_ids is None:
batch_size, sequence_length = decoder_input_ids.shape
decoder_position_ids = jnp.broadcast_to(
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
)
# Handle any PRNG if needed
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
return self.module.apply(
{"params": params or self.params},
input_ids=jnp.array(input_ids, dtype="i4"),
attention_mask=jnp.array(attention_mask, dtype="i4"),
position_ids=jnp.array(position_ids, dtype="i4"),
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=not train,
rngs=rngs,
)
@add_start_docstrings(
"The bare Bart Model transformer outputting raw hidden-states without any specific head on top.",
BART_START_DOCSTRING,
)
class FlaxBartModel(FlaxBartPretrainedModel):
config: BartConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
module_class = FlaxBartModule
append_call_sample_docstring(
FlaxBartModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC
)
class FlaxBartForConditionalGenerationModule(nn.Module):
config: BartConfig
dtype: jnp.dtype = jnp.float32
bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros
def setup(self):
self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
self.lm_head = nn.Dense(
self.model.shared.num_embeddings,
use_bias=False,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
)
self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))
def _get_encoder_module(self):
return self.model.encoder
def _get_decoder_module(self):
return self.model.decoder
def __call__(
self,
input_ids,
attention_mask,
decoder_input_ids,
decoder_attention_mask,
position_ids,
decoder_position_ids,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
position_ids=position_ids,
decoder_position_ids=decoder_position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
hidden_states = outputs[0]
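# when word embeddings are tied, the shared token-embedding matrix (transposed) is re-used
# as the kernel of `lm_head` below instead of the separately initialized kernel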
if self.config.tie_word_embeddings:
shared_embedding = self.model.variables["params"]["shared"]["embedding"]
lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
else:
lm_logits = self.lm_head(hidden_states)
lm_logits += self.final_logits_bias
if not return_dict:
output = (lm_logits,) + outputs[1:]
return output
return FlaxSeq2SeqLMOutput(
logits=lm_logits,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
@add_start_docstrings(
"The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING
)
class FlaxBartForConditionalGeneration(FlaxBartPretrainedModel):
module_class = FlaxBartForConditionalGenerationModule
dtype: jnp.dtype = jnp.float32
@add_start_docstrings(BART_DECODE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=BartConfig)
def decode(
self,
decoder_input_ids,
encoder_outputs,
encoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_attention_mask: Optional[jnp.ndarray] = None,
decoder_position_ids: Optional[jnp.ndarray] = None,
past_key_values: dict = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
deterministic: bool = True,
params: dict = None,
dropout_rng: PRNGKey = None,
):
r"""
Returns:
Example::
>>> from transformers import BartTokenizer, FlaxBartForConditionalGeneration
>>> model = FlaxBartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
>>> tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
>>> text = "My friends are cool but they eat too many carbs."
>>> inputs = tokenizer(text, max_length=1024, return_tensors='jax')
>>> encoder_outputs = model.encode(**inputs)
>>> decoder_start_token_id = model.config.decoder_start_token_id
>>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
>>> outputs = model.decode(decoder_input_ids, encoder_outputs)
>>> logits = outputs.logits
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
encoder_hidden_states = encoder_outputs[0]
if encoder_attention_mask is None:
batch_size, sequence_length = encoder_hidden_states.shape[:2]
encoder_attention_mask = jnp.ones((batch_size, sequence_length))
batch_size, sequence_length = decoder_input_ids.shape
if decoder_attention_mask is None:
decoder_attention_mask = jnp.ones((batch_size, sequence_length))
if decoder_position_ids is None:
if past_key_values is not None:
raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
decoder_position_ids = jnp.broadcast_to(
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
)
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
inputs = {"params": params or self.params}
# if past_key_values are passed then the cache is already initialized; a private flag init_cache has to be
# passed down to ensure the cache is used. It also has to be made sure that the cache is marked as mutable
# so that it can be changed by the FlaxBartAttention module
if past_key_values:
inputs["cache"] = past_key_values
mutable = ["cache"]
else:
mutable = False
def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
decoder_module = module._get_decoder_module()
outputs = decoder_module(
decoder_input_ids,
decoder_attention_mask,
decoder_position_ids,
**kwargs,
)
hidden_states = outputs[0]
if self.config.tie_word_embeddings:
shared_embedding = module.model.variables["params"]["shared"]["embedding"]
lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
else:
lm_logits = module.lm_head(hidden_states)
lm_logits += module.final_logits_bias
return lm_logits, outputs
outputs = self.module.apply(
inputs,
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
rngs=rngs,
mutable=mutable,
method=_decoder_forward,
)
if past_key_values is None:
lm_logits, decoder_outputs = outputs
else:
(lm_logits, decoder_outputs), past = outputs
if return_dict:
outputs = FlaxCausalLMOutputWithCrossAttentions(
logits=lm_logits,
hidden_states=decoder_outputs.hidden_states,
attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
)
else:
outputs = (lm_logits,) + decoder_outputs[1:]
# add updated cache to model output
if past_key_values is not None and return_dict:
outputs["past_key_values"] = unfreeze(past["cache"])
return outputs
elif past_key_values is not None and not return_dict:
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
return outputs
def prepare_inputs_for_generation(
self,
decoder_input_ids,
max_length,
attention_mask: Optional[jnp.DeviceArray] = None,
decoder_attention_mask: Optional[jnp.DeviceArray] = None,
encoder_outputs=None,
**kwargs
):
# initializing the cache
batch_size, seq_length = decoder_input_ids.shape
past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
# But since the decoder uses a causal mask, those positions are masked anyway.
# Thus we can create a single static attention_mask here, which is more efficient for compilation
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
if decoder_attention_mask is not None:
position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
else:
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
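# Worked example of the position-id computation in the branch above (hypothetical mask):
#   decoder_attention_mask = [[1, 1, 0, 0]]  ->  cumsum(axis=-1) - 1 = [[0, 1, 1, 1]]
# i.e. padded positions simply repeat the last valid position id.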
return {
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"encoder_attention_mask": attention_mask,
"decoder_attention_mask": extended_attention_mask,
"decoder_position_ids": position_ids,
}
def update_inputs_for_generation(self, model_outputs, model_kwargs):
model_kwargs["past_key_values"] = model_outputs.past_key_values
model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
return model_kwargs
FLAX_BART_CONDITIONAL_GENERATION_DOCSTRING = """
Returns:
Summarization example::
>>> from transformers import BartTokenizer, FlaxBartForConditionalGeneration
>>> model = FlaxBartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
>>> tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
>>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
>>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='jax')
>>> # Generate Summary
>>> summary_ids = model.generate(inputs['input_ids']).sequences
>>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
Mask filling example::
>>> from transformers import BartTokenizer, FlaxBartForConditionalGeneration
>>> tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
>>> TXT = "My friends are <mask> but they eat too many carbs."
>>> model = FlaxBartForConditionalGeneration.from_pretrained('facebook/bart-large')
>>> input_ids = tokenizer([TXT], return_tensors='jax')['input_ids']
>>> logits = model(input_ids).logits
>>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
>>> probs = jax.nn.softmax(logits[0, masked_index], axis=0)
>>> values, predictions = jax.lax.top_k(probs, k=1)
>>> tokenizer.decode(predictions).split()
"""
overwrite_call_docstring(
FlaxBartForConditionalGeneration, BART_INPUTS_DOCSTRING + FLAX_BART_CONDITIONAL_GENERATION_DOCSTRING
)
append_replace_return_docstrings(
FlaxBartForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
)
class FlaxBartForSequenceClassificationModule(nn.Module):
config: BartConfig
dtype: jnp.dtype = jnp.float32
num_labels: Optional[int] = None
def setup(self):
self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
self.classification_head = FlaxBartClassificationHead(
config=self.config,
inner_dim=self.config.d_model,
num_classes=self.num_labels if self.num_labels is not None else self.config.num_labels,
pooler_dropout=self.config.classifier_dropout,
)
def _get_encoder_module(self):
return self.model.encoder
def _get_decoder_module(self):
return self.model.decoder
def __call__(
self,
input_ids,
attention_mask,
decoder_input_ids,
decoder_attention_mask,
position_ids,
decoder_position_ids,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
position_ids=position_ids,
decoder_position_ids=decoder_position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
hidden_states = outputs[0] # last hidden state
eos_mask = jnp.where(input_ids == self.config.eos_token_id, 1, 0)
# The first condition is necessary to overcome jax._src.errors.ConcretizationTypeError during JIT compilation
if type(eos_mask) != jax.interpreters.partial_eval.DynamicJaxprTracer:
if len(jnp.unique(eos_mask.sum(1))) > 1:
raise ValueError("All examples must have the same number of <eos> tokens.")
if any(eos_mask.sum(1) == 0):
raise ValueError("There are missing <eos> tokens in input_ids")
# Ensure to keep 1 only for the last <eos> token for each example
eos_mask_noised = eos_mask + jnp.arange(eos_mask.shape[1]) * 1e-6
eos_mask = jnp.where(eos_mask_noised == eos_mask_noised.max(1).reshape(-1, 1), 1, 0)
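# Worked example of the noise trick above (illustrative values): for a row with
# eos_mask = [0, 1, 0, 1], the added noise gives [0, 1.000001, 0.000002, 1.000003], whose
# row-wise maximum sits at the *last* <eos> position, so the re-thresholded mask becomes
# [0, 0, 0, 1].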
sentence_representation = jnp.einsum("ijk, ij -> ijk", hidden_states, eos_mask).sum(1)
logits = self.classification_head(sentence_representation, deterministic=deterministic)
if not return_dict:
output = (logits,) + outputs[1:]
return output
return FlaxSeq2SeqSequenceClassifierOutput(
logits=logits,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
@add_start_docstrings(
"""
Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
tasks.
""",
BART_START_DOCSTRING,
)
class FlaxBartForSequenceClassification(FlaxBartPretrainedModel):
module_class = FlaxBartForSequenceClassificationModule
dtype = jnp.float32
append_call_sample_docstring(
FlaxBartForSequenceClassification,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxSeq2SeqSequenceClassifierOutput,
_CONFIG_FOR_DOC,
)
class FlaxBartForQuestionAnsweringModule(nn.Module):
config: BartConfig
dtype: jnp.dtype = jnp.float32
num_labels = 2
def setup(self):
self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
self.qa_outputs = nn.Dense(
self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype)
)
def _get_encoder_module(self):
return self.model.encoder
def _get_decoder_module(self):
return self.model.decoder
def __call__(
self,
input_ids,
attention_mask,
decoder_input_ids,
decoder_attention_mask,
position_ids,
decoder_position_ids,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
position_ids=position_ids,
decoder_position_ids=decoder_position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
sequence_output = outputs[0]
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = jnp.split(logits, logits.shape[-1], axis=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
if not return_dict:
output = (start_logits, end_logits) + outputs[1:]
return output
return FlaxSeq2SeqQuestionAnsweringModelOutput(
start_logits=start_logits,
end_logits=end_logits,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
@add_start_docstrings(
"""
BART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
BART_START_DOCSTRING,
)
class FlaxBartForQuestionAnswering(FlaxBartPretrainedModel):
module_class = FlaxBartForQuestionAnsweringModule
dtype = jnp.float32
append_call_sample_docstring(
FlaxBartForQuestionAnswering,
_TOKENIZER_FOR_DOC,
_CHECKPOINT_FOR_DOC,
FlaxSeq2SeqQuestionAnsweringModelOutput,
_CONFIG_FOR_DOC,
)
@@ -149,6 +149,42 @@ class FlaxAutoModelForTokenClassification:
requires_backends(cls, ["flax"])
class FlaxBartForConditionalGeneration:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxBartForQuestionAnswering:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxBartForSequenceClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxBartModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxBertForMaskedLM:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
......
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import timeout_decorator # noqa
from transformers import BartConfig, is_flax_available
from transformers.testing_utils import require_flax, slow
from .test_generation_flax_utils import FlaxGenerationTesterMixin
from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
if is_flax_available():
import os
# The slow tests are often failing with OOM error on GPU
# This makes JAX allocate exactly what is needed on demand, and deallocate memory that is no longer needed
# but will be slower as stated here https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
import jax
import jax.numpy as jnp
from transformers.models.bart.modeling_flax_bart import (
FlaxBartForConditionalGeneration,
FlaxBartForQuestionAnswering,
FlaxBartForSequenceClassification,
FlaxBartModel,
shift_tokens_right,
)
def prepare_bart_inputs_dict(
config,
input_ids,
decoder_input_ids=None,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
):
if attention_mask is None:
attention_mask = np.where(input_ids != config.pad_token_id, 1, 0)
if decoder_attention_mask is None:
decoder_attention_mask = np.where(decoder_input_ids != config.pad_token_id, 1, 0)
if head_mask is None:
head_mask = np.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = np.ones((config.decoder_layers, config.decoder_attention_heads))
if cross_attn_head_mask is None:
cross_attn_head_mask = np.ones((config.decoder_layers, config.decoder_attention_heads))
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": attention_mask,
}
class FlaxBartModelTester(unittest.TestCase):
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_labels=False,
vocab_size=99,
hidden_size=16,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=4,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=32,
eos_token_id=2,
pad_token_id=1,
bos_token_id=0,
initializer_range=0.02,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.initializer_range = initializer_range
def prepare_config_and_inputs(self):
input_ids = np.clip(ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size), 3, self.vocab_size)
input_ids = np.concatenate((input_ids, 2 * np.ones((self.batch_size, 1), dtype=np.int64)), -1)
decoder_input_ids = shift_tokens_right(input_ids, 1, 2)
config = BartConfig(
vocab_size=self.vocab_size,
d_model=self.hidden_size,
encoder_layers=self.num_hidden_layers,
decoder_layers=self.num_hidden_layers,
encoder_attention_heads=self.num_attention_heads,
decoder_attention_heads=self.num_attention_heads,
encoder_ffn_dim=self.intermediate_size,
decoder_ffn_dim=self.intermediate_size,
dropout=self.hidden_dropout_prob,
attention_dropout=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
eos_token_id=self.eos_token_id,
bos_token_id=self.bos_token_id,
pad_token_id=self.pad_token_id,
initializer_range=self.initializer_range,
use_cache=False,
)
inputs_dict = prepare_bart_inputs_dict(config, input_ids, decoder_input_ids)
return config, inputs_dict
def prepare_config_and_inputs_for_common(self):
config, inputs_dict = self.prepare_config_and_inputs()
return config, inputs_dict
def check_use_cache_forward(self, model_class_name, config, inputs_dict):
max_decoder_length = 20
model = model_class_name(config)
encoder_outputs = model.encode(inputs_dict["input_ids"])
decoder_input_ids, decoder_attention_mask = (
inputs_dict["decoder_input_ids"],
inputs_dict["decoder_attention_mask"],
)
past_key_values = model.init_cache(decoder_input_ids.shape[0], max_decoder_length, encoder_outputs)
decoder_attention_mask = jnp.ones((decoder_input_ids.shape[0], max_decoder_length), dtype="i4")
decoder_position_ids = jnp.broadcast_to(
jnp.arange(decoder_input_ids.shape[-1] - 1)[None, :],
(decoder_input_ids.shape[0], decoder_input_ids.shape[-1] - 1),
)
outputs_cache = model.decode(
decoder_input_ids[:, :-1],
encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
past_key_values=past_key_values,
decoder_position_ids=decoder_position_ids,
)
decoder_position_ids = jnp.array(decoder_input_ids.shape[0] * [[decoder_input_ids.shape[-1] - 1]], dtype="i4")
outputs_cache_next = model.decode(
decoder_input_ids[:, -1:],
encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
past_key_values=outputs_cache.past_key_values,
decoder_position_ids=decoder_position_ids,
)
outputs = model.decode(decoder_input_ids, encoder_outputs)
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
def check_use_cache_forward_with_attn_mask(self, model_class_name, config, inputs_dict):
max_decoder_length = 20
model = model_class_name(config)
encoder_outputs = model.encode(inputs_dict["input_ids"])
decoder_input_ids, decoder_attention_mask = (
inputs_dict["decoder_input_ids"],
inputs_dict["decoder_attention_mask"],
)
decoder_attention_mask_cache = jnp.concatenate(
[
decoder_attention_mask,
jnp.zeros((decoder_attention_mask.shape[0], max_decoder_length - decoder_attention_mask.shape[1])),
],
axis=-1,
)
past_key_values = model.init_cache(decoder_input_ids.shape[0], max_decoder_length, encoder_outputs)
decoder_position_ids = jnp.broadcast_to(
jnp.arange(decoder_input_ids.shape[-1] - 1)[None, :],
(decoder_input_ids.shape[0], decoder_input_ids.shape[-1] - 1),
)
outputs_cache = model.decode(
decoder_input_ids[:, :-1],
encoder_outputs,
decoder_attention_mask=decoder_attention_mask_cache,
past_key_values=past_key_values,
decoder_position_ids=decoder_position_ids,
)
decoder_position_ids = jnp.array(decoder_input_ids.shape[0] * [[decoder_input_ids.shape[-1] - 1]], dtype="i4")
outputs_cache_next = model.decode(
decoder_input_ids[:, -1:],
encoder_outputs,
past_key_values=outputs_cache.past_key_values,
decoder_attention_mask=decoder_attention_mask_cache,
decoder_position_ids=decoder_position_ids,
)
outputs = model.decode(decoder_input_ids, encoder_outputs, decoder_attention_mask=decoder_attention_mask)
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
@require_flax
class BartHeadTests(unittest.TestCase):
vocab_size = 99
def _get_config_and_data(self):
input_ids = np.array(
[
[71, 82, 18, 33, 46, 91, 2],
[68, 34, 26, 58, 30, 82, 2],
[5, 97, 17, 39, 94, 40, 2],
[76, 83, 94, 25, 70, 78, 2],
[87, 59, 41, 35, 48, 66, 2],
[55, 13, 16, 58, 5, 2, 1], # note padding
[64, 27, 31, 51, 12, 75, 2],
[52, 64, 86, 17, 83, 39, 2],
[48, 61, 9, 24, 71, 82, 2],
[26, 1, 60, 48, 22, 13, 2],
[21, 5, 62, 28, 14, 76, 2],
[45, 98, 37, 86, 59, 48, 2],
[70, 70, 50, 9, 28, 0, 2],
],
dtype=np.int64,
)
batch_size = input_ids.shape[0]
config = BartConfig(
vocab_size=self.vocab_size,
d_model=24,
encoder_layers=2,
decoder_layers=2,
encoder_attention_heads=2,
decoder_attention_heads=2,
encoder_ffn_dim=32,
decoder_ffn_dim=32,
max_position_embeddings=48,
eos_token_id=2,
pad_token_id=1,
bos_token_id=0,
)
return config, input_ids, batch_size
def test_sequence_classification_forward(self):
config, input_ids, batch_size = self._get_config_and_data()
model = FlaxBartForSequenceClassification(config)
outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
expected_shape = (batch_size, config.num_labels)
self.assertEqual(outputs["logits"].shape, expected_shape)
def test_question_answering_forward(self):
config, input_ids, batch_size = self._get_config_and_data()
model = FlaxBartForQuestionAnswering(config)
outputs = model(input_ids=input_ids)
self.assertEqual(outputs["start_logits"].shape, input_ids.shape)
self.assertEqual(outputs["end_logits"].shape, input_ids.shape)
# @timeout_decorator.timeout(1) # not working with the decorator so far
def test_lm_forward(self):
config, input_ids, batch_size = self._get_config_and_data()
lm_model = FlaxBartForConditionalGeneration(config)
outputs = lm_model(input_ids=input_ids)
expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)
self.assertEqual(outputs["logits"].shape, expected_shape)
def test_lm_uneven_forward(self):
config = BartConfig(
vocab_size=self.vocab_size,
d_model=14,
encoder_layers=2,
decoder_layers=2,
encoder_attention_heads=2,
decoder_attention_heads=2,
encoder_ffn_dim=8,
decoder_ffn_dim=8,
max_position_embeddings=48,
)
lm_model = FlaxBartForConditionalGeneration(config)
context = np.array([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]], dtype=np.int64)
summary = np.array([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]], dtype=np.int64)
outputs = lm_model(input_ids=context, decoder_input_ids=summary)
expected_shape = (*summary.shape, config.vocab_size)
self.assertEqual(outputs["logits"].shape, expected_shape)
def test_shift_tokens_right(self):
input_ids = np.array([[71, 82, 18, 33, 2, 1, 1], [68, 34, 26, 58, 30, 82, 2]], dtype=np.int64)
shifted = shift_tokens_right(input_ids, 1, 2)
n_pad_before = np.equal(input_ids, 1).astype(np.float32).sum()
n_pad_after = np.equal(shifted, 1).astype(np.float32).sum()
self.assertEqual(shifted.shape, input_ids.shape)
self.assertEqual(n_pad_after, n_pad_before - 1)
self.assertTrue(np.equal(shifted[:, 0], 2).all())
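# Illustrative sketch (not part of the test file): what shift_tokens_right does to one row,
# assuming the function is importable from transformers.models.bart.modeling_flax_bart.
# Each row is shifted one position to the right, the decoder start token is prepended, the
# last token is dropped, and (as in the PyTorch version) any -100 is replaced by the pad token.
import numpy as np
from transformers.models.bart.modeling_flax_bart import shift_tokens_right

input_ids = np.array([[71, 82, 18, 33, 2, 1, 1]], dtype=np.int64)
print(shift_tokens_right(input_ids, 1, 2))  # pad_token_id=1, decoder_start_token_id=2
# expected: [[ 2 71 82 18 33  2  1]]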
@require_flax
class FlaxBartModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationTesterMixin):
is_encoder_decoder = True
all_model_classes = (
(
FlaxBartModel,
FlaxBartForConditionalGeneration,
FlaxBartForSequenceClassification,
FlaxBartForQuestionAnswering,
)
if is_flax_available()
else ()
)
all_generative_model_classes = (FlaxBartForConditionalGeneration,) if is_flax_available() else ()
def setUp(self):
self.model_tester = FlaxBartModelTester(self)
def test_use_cache_forward(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
for model_class in self.all_model_classes:
self.model_tester.check_use_cache_forward(model_class, config, inputs_dict)
def test_use_cache_forward_with_attn_mask(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
for model_class in self.all_model_classes:
self.model_tester.check_use_cache_forward_with_attn_mask(model_class, config, inputs_dict)
def test_encode(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
@jax.jit
def encode_jitted(input_ids, attention_mask=None, **kwargs):
return model.encode(input_ids=input_ids, attention_mask=attention_mask)
with self.subTest("JIT Enabled"):
jitted_outputs = encode_jitted(**prepared_inputs_dict).to_tuple()
with self.subTest("JIT Disabled"):
with jax.disable_jit():
outputs = encode_jitted(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
def test_decode(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
model = model_class(config)
encoder_outputs = model.encode(inputs_dict["input_ids"], inputs_dict["attention_mask"])
prepared_inputs_dict = {
"decoder_input_ids": inputs_dict["decoder_input_ids"],
"decoder_attention_mask": inputs_dict["decoder_attention_mask"],
"encoder_outputs": encoder_outputs,
}
@jax.jit
def decode_jitted(decoder_input_ids, decoder_attention_mask, encoder_outputs):
return model.decode(
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
)
with self.subTest("JIT Enabled"):
jitted_outputs = decode_jitted(**prepared_inputs_dict).to_tuple()
with self.subTest("JIT Disabled"):
with jax.disable_jit():
outputs = decode_jitted(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
model = model_class_name.from_pretrained("facebook/bart-base", from_pt=True)
# FlaxBartForSequenceClassification expects eos token in input_ids
input_ids = np.ones((1, 1)) * model.config.eos_token_id
outputs = model(input_ids)
self.assertIsNotNone(outputs)
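# Illustrative sketch (not part of the test file): loading the PyTorch checkpoint into the
# Flax model, as test_model_from_pretrained above does. "facebook/bart-base" is the checkpoint
# the slow test uses; the token ids below are placeholders.
import numpy as np
from transformers import FlaxBartForConditionalGeneration

model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-base", from_pt=True)
input_ids = np.array([[0, 31414, 232, 2]], dtype=np.int64)
outputs = model(input_ids)  # decoder_input_ids default to the shifted input_ids
print(outputs.logits.shape)  # (1, 4, vocab_size)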
@@ -22,6 +22,7 @@ import numpy as np
import transformers
from transformers import is_flax_available, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import is_pt_flax_cross_test, require_flax

@@ -31,6 +32,7 @@ if is_flax_available():
import jax
import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
from transformers import FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,

@@ -42,6 +44,14 @@ if is_torch_available():
import torch
def _config_zero_init(config):
configs_no_init = copy.deepcopy(config)
for key in configs_no_init.__dict__.keys():
if "_range" in key or "_std" in key or "initializer_factor" in key:
setattr(configs_no_init, key, 1e-10)
return configs_no_init
def ids_tensor(shape, vocab_size, rng=None):
"""Creates a random int32 tensor of the shape within the vocab size."""
if rng is None:

@@ -87,6 +97,7 @@ def random_attention_mask(shape, rng=None):
class FlaxModelTesterMixin:
model_tester = None
all_model_classes = ()
is_encoder_decoder = False
def _prepare_for_class(self, inputs_dict, model_class):
inputs_dict = copy.deepcopy(inputs_dict)

@@ -156,6 +167,9 @@ class FlaxModelTesterMixin:
pt_model_class = getattr(transformers, pt_model_class_name)
pt_model = pt_model_class(config).eval()
# Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False
fx_model = model_class(config, dtype=jnp.float32)
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)

@@ -167,7 +181,7 @@ class FlaxModelTesterMixin:
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)

@@ -178,7 +192,10 @@ class FlaxModelTesterMixin:
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
if not isinstance(
fx_output_loaded, tuple
): # TODO(Patrick, Daniel) - let's discard use_cache for now
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-3)
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):

@@ -195,6 +212,9 @@ class FlaxModelTesterMixin:
pt_model_class = getattr(transformers, pt_model_class_name)
pt_model = pt_model_class(config).eval()
# Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False
fx_model = model_class(config, dtype=jnp.float32)
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)

@@ -207,8 +227,9 @@ class FlaxModelTesterMixin:
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)

@@ -221,7 +242,8 @@ class FlaxModelTesterMixin:
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
)
for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
if not isinstance(fx_output, tuple): # TODO(Patrick, Daniel) - let's discard use_cache for now
self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3)
def test_from_pretrained_save_pretrained(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

@@ -276,6 +298,7 @@ class FlaxModelTesterMixin:
self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
def test_forward_signature(self):

@@ -287,8 +310,17 @@ class FlaxModelTesterMixin:
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
if model.config.is_encoder_decoder:
expected_arg_names = [
"input_ids",
"attention_mask",
"decoder_input_ids",
"decoder_attention_mask",
]
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
else:
expected_arg_names = ["input_ids", "attention_mask"]
self.assertListEqual(arg_names[:2], expected_arg_names)
def test_naming_convention(self):
for model_class in self.all_model_classes:

@@ -306,16 +338,36 @@ class FlaxModelTesterMixin:
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
)
self.assertEqual(len(hidden_states), expected_num_layers)
if hasattr(self.model_tester, "encoder_seq_length"):
seq_length = self.model_tester.encoder_seq_length
else:
seq_length = self.model_tester.seq_length
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[seq_length, self.model_tester.hidden_size],
)
if config.is_encoder_decoder:
hidden_states = outputs.decoder_hidden_states
self.assertIsInstance(hidden_states, (list, tuple))
self.assertEqual(len(hidden_states), expected_num_layers)
seq_len = getattr(self.model_tester, "seq_length", None)
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[decoder_seq_length, self.model_tester.hidden_size],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:

@@ -333,13 +385,17 @@ class FlaxModelTesterMixin:
config.return_dict = True
seq_length = getattr(self.model_tester, "seq_length", None)
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_length)
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length)
decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# check that output_attentions also work using config

@@ -347,22 +403,58 @@ class FlaxModelTesterMixin:
config.output_attentions = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
out_len = len(outputs)
if self.is_encoder_decoder:
correct_outlen = 5
# Question Answering model returns start_logits and end_logits
if model_class in get_values(FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING):
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
self.assertEqual(out_len, correct_outlen)
# decoder attentions
decoder_attentions = outputs.decoder_attentions
self.assertIsInstance(decoder_attentions, (list, tuple))
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(decoder_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
)
# cross attentions
cross_attentions = outputs.cross_attentions
self.assertIsInstance(cross_attentions, (list, tuple))
self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(cross_attentions[0].shape[-3:]),
[
self.model_tester.num_attention_heads,
decoder_seq_length,
encoder_key_length,
],
)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
if hasattr(self.model_tester, "num_hidden_states_types"):
added_hidden_states = self.model_tester.num_hidden_states_types
elif self.is_encoder_decoder:
added_hidden_states = 2
else:
added_hidden_states = 1
self.assertEqual(out_len + added_hidden_states, len(outputs))
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions

@@ -370,5 +462,5 @@ class FlaxModelTesterMixin:
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
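# Illustrative sketch (not part of the diff): the PyTorch -> Flax weight conversion that the
# equivalence tests above rely on, applied to BART. The config sizes and token ids are
# assumptions; the tolerance mirrors the 1e-3 used in the tests.
import numpy as np
import jax.numpy as jnp
import torch
from transformers import BartConfig, BartForConditionalGeneration, FlaxBartForConditionalGeneration
from transformers.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax

config = BartConfig(
    vocab_size=99, d_model=16, encoder_layers=1, decoder_layers=1,
    encoder_attention_heads=2, decoder_attention_heads=2,
    encoder_ffn_dim=32, decoder_ffn_dim=32, max_position_embeddings=32,
)

pt_model = BartForConditionalGeneration(config).eval()
pt_model.config.use_cache = False  # Flax does not return a cache by default

fx_model = FlaxBartForConditionalGeneration(config, dtype=jnp.float32)
fx_model.params = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)

input_ids = np.array([[0, 4, 5, 2]], dtype=np.int64)
with torch.no_grad():
    pt_logits = pt_model(torch.tensor(input_ids)).logits.numpy()
fx_logits = fx_model(input_ids).logits

print(np.max(np.abs(pt_logits - np.asarray(fx_logits))))  # expected to be below ~1e-3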