Unverified Commit 4a51b1dd authored by Daniel Stancl, committed by GitHub

FlaxBart (#11537)



* Start working on FlaxBart

* Create modeling_flax_bart.py

* Write FlaxBartAttention

* Add FlaxBartEncoderLayer

* Add FlaxBartDecoderLayer and some typing

* Add helper function for FlaxBart

* shift_tokens_right

* _make_causal_mask

* _expand_mask

* Add PositionalEmbedding and fix init_std naming

* Add FlaxBartPretrainedModel

* Add FlaxBartEncoder

* Add FlaxBartEncoder

* Add FlaxBartEncoder among modules to be imported

* YET WE CANNOT INITIALIZE THAT!! :(

* Make BartEncoder work

Change BartEncoder to an instance of nn.Module for now

* Add FlaxBartDecoder

* Add FlaxBartModel

* TODO to make model run -> Prepare model inputs

* Resolve padding

* Add FlaxBartModel

* Add FlaxBartModel into importable modules

* Remove FlaxBartEncoder and FlaxBartDecoder from importable modules

* make style; not properly working

* make style; make quality not pass due to some import I left

* Remove TODO for padding_idx in nn.Embed so far

* Add FlaxBartForConditionalGeneration

* Incorporate Flax model output classes, i.e. return_dict

* Add other models and incorporate use_cache arg

* Add FlaxBartForSequenceClassification and FlaxBartForQuestionAnswering

* Incorporate use_cache arg from PyTorch implementation

* Add all necessary Flax output utils

* Add FlaxBartForCausalLM; not working yet

* Add minor improvements; still lacks some functionality

* Update docs, src and tests

* Add support of FlaxBart to docs/source

* Fix some bugs in FlaxBart source code

* Add some necessary tests for FlaxBart models - jit_compilation not passing

* Fix tests and add test_head_masking

* Fix tests for @jax.jit computation

* Add test_head_masking

* Migrate FlaxBart tests from jax.numpy to numpy

* Remove FlaxBartForCausalLM

* Clean repo

* fix bart model weight structure

* Fix FlaxBartForSequenceClassification

Slicing cannot be used under jit, so the way the sentence
representation is selected from hidden_states must be changed (see the sketch below).
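A minimal sketch of one jit-friendly alternative (the helper name `last_eos_representation` is illustrative, not code from this commit): instead of boolean-mask slicing, which produces data-dependent shapes under `jax.jit`, the EOS position is turned into a one-hot mask and reduced with a weighted sum.

```python
# Illustrative only: a jit-compatible way to pick the sentence representation.
import jax
import jax.numpy as jnp

def last_eos_representation(hidden_states, input_ids, eos_token_id):
    """Return the hidden state at the last <eos> position of every example.

    Boolean-mask slicing (hidden_states[input_ids == eos_token_id]) yields
    data-dependent shapes and fails under jax.jit; a one-hot weighted sum
    keeps every shape static.
    """
    seq_len = input_ids.shape[-1]
    eos_mask = (input_ids == eos_token_id).astype(jnp.int32)        # (batch, seq)
    last_eos = jnp.argmax(eos_mask * jnp.arange(seq_len), axis=-1)  # (batch,)
    one_hot = jax.nn.one_hot(last_eos, seq_len, dtype=hidden_states.dtype)
    return (hidden_states * one_hot[..., None]).sum(axis=1)         # (batch, hidden)
```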

* Allow FlaxBartForSequenceClassification for testing pt_flax equivalence

* Allow testing for FlaxBartForQA for pt_flax equivalence

* Add a comment to FlaxBartForSequenceClassification + change noise from 1e-3 to 1e-6

* remove past_key_values

* remove inputs_embeds and make input_ids required

* add position ids

* re-write attention layer

* fix dataclass

* fix pos embeds and attention output

* fix pos embeds

* expose encode method

* expose decode method

* move docstring to top

* add cache for causal attn layer

* remove head masking for now

* s2s greedy search first pass

* boom boom

* fix typos

* fix greedy generate for bart

* use encoder, decoder layers instead of num_hidden_layers

* handle encoder_outputs

* cleanup

* simplify decoding

* more clean-up

* typos

* Change header + add {decoder_,}position_ids into 2 models

* add BartConfig

* fix existing tests

* add encode, decode methods

* Fix shift_tokens_right for JIT compilation + clarify one condition (see the sketch below)
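A rough sketch of what a jit-compatible `shift_tokens_right` can look like (illustrative only; the actual implementation lives in the collapsed modeling_flax_bart.py diff and may differ): functional `.at[...]` updates and `jnp.where` replace in-place assignment and Python-level branching, so the helper traces cleanly under `jax.jit`.

```python
import jax.numpy as jnp

def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
    """Shift input ids one position to the right to build decoder_input_ids."""
    shifted = jnp.zeros_like(input_ids)
    shifted = shifted.at[:, 1:].set(input_ids[:, :-1])      # copy tokens shifted by one
    shifted = shifted.at[:, 0].set(decoder_start_token_id)  # prepend the start token
    # replace possible -100 label padding with the pad token (branch-free for jit)
    return jnp.where(shifted == -100, pad_token_id, shifted)
```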

* fix decode

* encoder => encode

* simplify generate

* add tests for encode and decode

* style

* add tests for cache

* fix equivalence tests

* sample generate now works with seq2seq

* generation tests

* initialize dense layers

* docstring and cleanup

* quality

* remove get/set input_embeddings

* address Patrick's suggestions

* decode for every model, remove encoder_outputs from call

* update tests accordingly

* decode returns only decoder outputs and logits

* fix arguments

* doc encode, decode methods

* correct base_model_prefix

* fix test for seq classif model

* fix docs
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
parent d36fce82
@@ -299,7 +299,7 @@ Flax), PyTorch, and/or TensorFlow.
+=============================+================+================+=================+====================+==============+
| ALBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
-| BART | ✅ | ✅ | ✅ | ✅ | ❌ |
+| BART | ✅ | ✅ | ✅ | ✅ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| BERT | ✅ | ✅ | ✅ | ✅ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
......
@@ -131,6 +131,7 @@ BartForQuestionAnswering
.. autoclass:: transformers.BartForQuestionAnswering
:members: forward
BartForCausalLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -138,7 +139,6 @@ BartForCausalLM
:members: forward
TFBartModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -151,3 +151,32 @@ TFBartForConditionalGeneration
.. autoclass:: transformers.TFBartForConditionalGeneration
:members: call
FlaxBartModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxBartModel
:members: __call__, encode, decode
FlaxBartForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxBartForConditionalGeneration
:members: __call__, encode, decode
FlaxBartForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxBartForSequenceClassification
:members: __call__, encode, decode
FlaxBartForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxBartForQuestionAnswering
:members: __call__, encode, decode
@@ -1508,6 +1508,14 @@ if is_flax_available():
"FlaxAutoModelForTokenClassification",
]
)
_import_structure["models.bart"].extend(
[
"FlaxBartForConditionalGeneration",
"FlaxBartForQuestionAnswering",
"FlaxBartForSequenceClassification",
"FlaxBartModel",
]
)
_import_structure["models.bert"].extend( _import_structure["models.bert"].extend(
[ [
"FlaxBertForMaskedLM", "FlaxBertForMaskedLM",
@@ -2808,6 +2816,12 @@ if TYPE_CHECKING:
FlaxAutoModelForSequenceClassification,
FlaxAutoModelForTokenClassification,
)
from .models.bart import (
FlaxBartForConditionalGeneration,
FlaxBartForQuestionAnswering,
FlaxBartForSequenceClassification,
FlaxBartModel,
)
from .models.bert import (
FlaxBertForMaskedLM,
FlaxBertForMultipleChoice,
......
@@ -101,12 +101,23 @@ class FlaxGenerationMixin:
state = body_fn(state)
return state
def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids, model_kwargs):
encoder_kwargs = {
argument: value
for argument, value in model_kwargs.items()
if not (argument.startswith("decoder_") or argument.startswith("cross_attn"))
}
model_kwargs["encoder_outputs"] = self.encode(input_ids, return_dict=True, **encoder_kwargs)
return model_kwargs
def generate(
self,
input_ids: jax_xla.DeviceArray,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
bos_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
decoder_start_token_id: Optional[int] = None,
do_sample: Optional[bool] = None,
prng_key: Optional[jax_xla.DeviceArray] = None,
top_k: Optional[int] = None,
@@ -147,6 +158,8 @@ class FlaxGenerationMixin:
The id of the `beginning-of-sequence` token.
eos_token_id (:obj:`int`, `optional`):
The id of the `end-of-sequence` token.
decoder_start_token_id (:obj:`int`, `optional`):
If an encoder-decoder model starts decoding with a different token than `bos`, the id of that token.
trace (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to trace generation. Setting ``trace=False`` should only be used for debugging and will lead to
a considerably slower runtime.
@@ -170,10 +183,23 @@ class FlaxGenerationMixin:
"""
# set init values
max_length = max_length if max_length is not None else self.config.max_length
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id else self.config.decoder_start_token_id
)
prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
if decoder_start_token_id is None and self.config.is_encoder_decoder:
raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")
if self.config.is_encoder_decoder:
# add encoder_outputs to model_kwargs
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
# prepare decoder_input_ids for generation
input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
do_sample = do_sample if do_sample is not None else self.config.do_sample
if do_sample:
@@ -246,10 +272,11 @@
# per batch-item state bit indicating if sentence has finished.
is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
# For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
# and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
-model = self
+model = self.decode if self.config.is_encoder_decoder else self
# initialize model specific kwargs
-model_kwargs = model.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)
+model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)
# initialize state
state = GreedyState(
@@ -277,8 +304,7 @@
next_token = next_token[:, None]
next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
-next_model_kwargs = model.update_inputs_for_generation(model_outputs, model_kwargs)
+next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)
return GreedyState(
cur_len=state.cur_len + 1,
sequences=next_sequences,
@@ -288,7 +314,8 @@
)
# The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
-state = greedy_search_body_fn(state)
+if input_ids.shape[1] > 1:
+    state = greedy_search_body_fn(state)
if not trace:
state = self._run_loop_in_debug(greedy_search_cond_fn, greedy_search_body_fn, state)
@@ -327,10 +354,12 @@
# per batch-item state bit indicating if sentence has finished.
is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
# For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
# and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
-model = self
+model = self.decode if self.config.is_encoder_decoder else self
# initialize model specific kwargs
-model_kwargs = model.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)
+model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)
# initialize state
state = SampleState(
@@ -366,7 +395,7 @@
next_token = next_token[:, None]
next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
-next_model_kwargs = model.update_inputs_for_generation(model_outputs, model_kwargs)
+next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)
return SampleState(
cur_len=state.cur_len + 1,
@@ -378,7 +407,8 @@
)
# The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
-state = sample_search_body_fn(state)
+if input_ids.shape[1] > 1:
+    state = sample_search_body_fn(state)
if not trace:
state = self._run_loop_in_debug(sample_search_cond_fn, sample_search_body_fn, state)
......
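For orientation (not part of the diff), a usage sketch of the new seq2seq generation path: `encode` runs once to produce `encoder_outputs`, decoding starts from `decoder_start_token_id`, and the generation loop only calls `decode`. The checkpoint and `from_pt=True` follow the slow test further below; how the generated ids are returned is an assumption, so the snippet handles both a bare array and an output object.

```python
from transformers import BartTokenizer, FlaxBartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-base", from_pt=True)

inputs = tokenizer(["My friends are cool but they eat too many carbs."], return_tensors="np")
generated = model.generate(inputs["input_ids"], max_length=20)

# depending on the version, `generate` may return the token ids directly
# or an output object carrying them in a `sequences` field
sequences = getattr(generated, "sequences", generated)
print(tokenizer.batch_decode(sequences, skip_special_tokens=True))
```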
This diff is collapsed.
@@ -18,6 +18,12 @@
from collections import OrderedDict
from ...utils import logging
from ..bart.modeling_flax_bart import (
FlaxBartForConditionalGeneration,
FlaxBartForQuestionAnswering,
FlaxBartForSequenceClassification,
FlaxBartModel,
)
from ..bert.modeling_flax_bert import (
FlaxBertForMaskedLM,
FlaxBertForMultipleChoice,
@@ -49,7 +55,7 @@ from ..roberta.modeling_flax_roberta import (
)
from ..vit.modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
from .auto_factory import auto_class_factory
-from .configuration_auto import BertConfig, CLIPConfig, ElectraConfig, GPT2Config, RobertaConfig, ViTConfig
+from .configuration_auto import BartConfig, BertConfig, CLIPConfig, ElectraConfig, GPT2Config, RobertaConfig, ViTConfig
logger = logging.get_logger(__name__)
@@ -60,6 +66,7 @@ FLAX_MODEL_MAPPING = OrderedDict(
# Base model mapping
(RobertaConfig, FlaxRobertaModel),
(BertConfig, FlaxBertModel),
(BartConfig, FlaxBartModel),
(GPT2Config, FlaxGPT2Model),
(ElectraConfig, FlaxElectraModel),
(CLIPConfig, FlaxCLIPModel),
@@ -72,6 +79,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
# Model for pre-training mapping
(RobertaConfig, FlaxRobertaForMaskedLM),
(BertConfig, FlaxBertForPreTraining),
(BartConfig, FlaxBartForConditionalGeneration),
(ElectraConfig, FlaxElectraForPreTraining),
]
)
@@ -81,6 +89,7 @@ FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
# Model for Masked LM mapping
(RobertaConfig, FlaxRobertaForMaskedLM),
(BertConfig, FlaxBertForMaskedLM),
(BartConfig, FlaxBartForConditionalGeneration),
(ElectraConfig, FlaxElectraForMaskedLM),
]
)
@@ -104,6 +113,7 @@ FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
# Model for Sequence Classification mapping
(RobertaConfig, FlaxRobertaForSequenceClassification),
(BertConfig, FlaxBertForSequenceClassification),
(BartConfig, FlaxBartForSequenceClassification),
(ElectraConfig, FlaxElectraForSequenceClassification),
]
)
@@ -113,6 +123,7 @@ FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
# Model for Question Answering mapping
(RobertaConfig, FlaxRobertaForQuestionAnswering),
(BertConfig, FlaxBertForQuestionAnswering),
(BartConfig, FlaxBartForQuestionAnswering),
(ElectraConfig, FlaxElectraForQuestionAnswering),
]
)
......
@@ -17,7 +17,13 @@
# limitations under the License.
from typing import TYPE_CHECKING
-from ...file_utils import _BaseLazyModule, is_tf_available, is_tokenizers_available, is_torch_available
+from ...file_utils import (
_BaseLazyModule,
is_flax_available,
is_tf_available,
is_tokenizers_available,
is_torch_available,
)
_import_structure = {
@@ -43,6 +49,13 @@ if is_torch_available():
if is_tf_available():
_import_structure["modeling_tf_bart"] = ["TFBartForConditionalGeneration", "TFBartModel", "TFBartPretrainedModel"]
if is_flax_available():
_import_structure["modeling_flax_bart"] = [
"FlaxBartForConditionalGeneration",
"FlaxBartForQuestionAnswering",
"FlaxBartForSequenceClassification",
"FlaxBartModel",
]
if TYPE_CHECKING:
from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig
@@ -66,6 +79,14 @@ if TYPE_CHECKING:
if is_tf_available():
from .modeling_tf_bart import TFBartForConditionalGeneration, TFBartModel, TFBartPretrainedModel
if is_flax_available():
from .modeling_flax_bart import (
FlaxBartForConditionalGeneration,
FlaxBartForQuestionAnswering,
FlaxBartForSequenceClassification,
FlaxBartModel,
)
else:
import importlib
import os
......
This diff is collapsed.
@@ -149,6 +149,42 @@ class FlaxAutoModelForTokenClassification:
requires_backends(cls, ["flax"])
class FlaxBartForConditionalGeneration:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxBartForQuestionAnswering:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxBartForSequenceClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxBartModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxBertForMaskedLM:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
......
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import timeout_decorator # noqa
from transformers import BartConfig, is_flax_available
from transformers.testing_utils import require_flax, slow
from .test_generation_flax_utils import FlaxGenerationTesterMixin
from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
if is_flax_available():
import os
# The slow tests are often failing with OOM error on GPU
# This makes JAX allocate exactly what is needed on demand, and deallocate memory that is no longer needed
# but will be slower as stated here https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
import jax
import jax.numpy as jnp
from transformers.models.bart.modeling_flax_bart import (
FlaxBartForConditionalGeneration,
FlaxBartForQuestionAnswering,
FlaxBartForSequenceClassification,
FlaxBartModel,
shift_tokens_right,
)
def prepare_bart_inputs_dict(
config,
input_ids,
decoder_input_ids=None,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
):
if attention_mask is None:
attention_mask = np.where(input_ids != config.pad_token_id, 1, 0)
if decoder_attention_mask is None:
decoder_attention_mask = np.where(decoder_input_ids != config.pad_token_id, 1, 0)
if head_mask is None:
head_mask = np.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = np.ones((config.decoder_layers, config.decoder_attention_heads))
if cross_attn_head_mask is None:
cross_attn_head_mask = np.ones((config.decoder_layers, config.decoder_attention_heads))
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"decoder_attention_mask": attention_mask,
}
class FlaxBartModelTester(unittest.TestCase):
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_labels=False,
vocab_size=99,
hidden_size=16,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=4,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=32,
eos_token_id=2,
pad_token_id=1,
bos_token_id=0,
initializer_range=0.02,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.initializer_range = initializer_range
def prepare_config_and_inputs(self):
input_ids = np.clip(ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size), 3, self.vocab_size)
input_ids = np.concatenate((input_ids, 2 * np.ones((self.batch_size, 1), dtype=np.int64)), -1)
decoder_input_ids = shift_tokens_right(input_ids, 1, 2)
config = BartConfig(
vocab_size=self.vocab_size,
d_model=self.hidden_size,
encoder_layers=self.num_hidden_layers,
decoder_layers=self.num_hidden_layers,
encoder_attention_heads=self.num_attention_heads,
decoder_attention_heads=self.num_attention_heads,
encoder_ffn_dim=self.intermediate_size,
decoder_ffn_dim=self.intermediate_size,
dropout=self.hidden_dropout_prob,
attention_dropout=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
eos_token_id=self.eos_token_id,
bos_token_id=self.bos_token_id,
pad_token_id=self.pad_token_id,
initializer_range=self.initializer_range,
use_cache=False,
)
inputs_dict = prepare_bart_inputs_dict(config, input_ids, decoder_input_ids)
return config, inputs_dict
def prepare_config_and_inputs_for_common(self):
config, inputs_dict = self.prepare_config_and_inputs()
return config, inputs_dict
def check_use_cache_forward(self, model_class_name, config, inputs_dict):
max_decoder_length = 20
model = model_class_name(config)
encoder_outputs = model.encode(inputs_dict["input_ids"])
decoder_input_ids, decoder_attention_mask = (
inputs_dict["decoder_input_ids"],
inputs_dict["decoder_attention_mask"],
)
past_key_values = model.init_cache(decoder_input_ids.shape[0], max_decoder_length, encoder_outputs)
decoder_attention_mask = jnp.ones((decoder_input_ids.shape[0], max_decoder_length), dtype="i4")
decoder_position_ids = jnp.broadcast_to(
jnp.arange(decoder_input_ids.shape[-1] - 1)[None, :],
(decoder_input_ids.shape[0], decoder_input_ids.shape[-1] - 1),
)
outputs_cache = model.decode(
decoder_input_ids[:, :-1],
encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
past_key_values=past_key_values,
decoder_position_ids=decoder_position_ids,
)
decoder_position_ids = jnp.array(decoder_input_ids.shape[0] * [[decoder_input_ids.shape[-1] - 1]], dtype="i4")
outputs_cache_next = model.decode(
decoder_input_ids[:, -1:],
encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
past_key_values=outputs_cache.past_key_values,
decoder_position_ids=decoder_position_ids,
)
outputs = model.decode(decoder_input_ids, encoder_outputs)
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
def check_use_cache_forward_with_attn_mask(self, model_class_name, config, inputs_dict):
max_decoder_length = 20
model = model_class_name(config)
encoder_outputs = model.encode(inputs_dict["input_ids"])
decoder_input_ids, decoder_attention_mask = (
inputs_dict["decoder_input_ids"],
inputs_dict["decoder_attention_mask"],
)
decoder_attention_mask_cache = jnp.concatenate(
[
decoder_attention_mask,
jnp.zeros((decoder_attention_mask.shape[0], max_decoder_length - decoder_attention_mask.shape[1])),
],
axis=-1,
)
past_key_values = model.init_cache(decoder_input_ids.shape[0], max_decoder_length, encoder_outputs)
decoder_position_ids = jnp.broadcast_to(
jnp.arange(decoder_input_ids.shape[-1] - 1)[None, :],
(decoder_input_ids.shape[0], decoder_input_ids.shape[-1] - 1),
)
outputs_cache = model.decode(
decoder_input_ids[:, :-1],
encoder_outputs,
decoder_attention_mask=decoder_attention_mask_cache,
past_key_values=past_key_values,
decoder_position_ids=decoder_position_ids,
)
decoder_position_ids = jnp.array(decoder_input_ids.shape[0] * [[decoder_input_ids.shape[-1] - 1]], dtype="i4")
outputs_cache_next = model.decode(
decoder_input_ids[:, -1:],
encoder_outputs,
past_key_values=outputs_cache.past_key_values,
decoder_attention_mask=decoder_attention_mask_cache,
decoder_position_ids=decoder_position_ids,
)
outputs = model.decode(decoder_input_ids, encoder_outputs, decoder_attention_mask=decoder_attention_mask)
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
@require_flax
class BartHeadTests(unittest.TestCase):
vocab_size = 99
def _get_config_and_data(self):
input_ids = np.array(
[
[71, 82, 18, 33, 46, 91, 2],
[68, 34, 26, 58, 30, 82, 2],
[5, 97, 17, 39, 94, 40, 2],
[76, 83, 94, 25, 70, 78, 2],
[87, 59, 41, 35, 48, 66, 2],
[55, 13, 16, 58, 5, 2, 1], # note padding
[64, 27, 31, 51, 12, 75, 2],
[52, 64, 86, 17, 83, 39, 2],
[48, 61, 9, 24, 71, 82, 2],
[26, 1, 60, 48, 22, 13, 2],
[21, 5, 62, 28, 14, 76, 2],
[45, 98, 37, 86, 59, 48, 2],
[70, 70, 50, 9, 28, 0, 2],
],
dtype=np.int64,
)
batch_size = input_ids.shape[0]
config = BartConfig(
vocab_size=self.vocab_size,
d_model=24,
encoder_layers=2,
decoder_layers=2,
encoder_attention_heads=2,
decoder_attention_heads=2,
encoder_ffn_dim=32,
decoder_ffn_dim=32,
max_position_embeddings=48,
eos_token_id=2,
pad_token_id=1,
bos_token_id=0,
)
return config, input_ids, batch_size
def test_sequence_classification_forward(self):
config, input_ids, batch_size = self._get_config_and_data()
model = FlaxBartForSequenceClassification(config)
outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
expected_shape = (batch_size, config.num_labels)
self.assertEqual(outputs["logits"].shape, expected_shape)
def test_question_answering_forward(self):
config, input_ids, batch_size = self._get_config_and_data()
model = FlaxBartForQuestionAnswering(config)
outputs = model(input_ids=input_ids)
self.assertEqual(outputs["start_logits"].shape, input_ids.shape)
self.assertEqual(outputs["end_logits"].shape, input_ids.shape)
# @timeout_decorator.timeout(1) # not working with the decorator so far
def test_lm_forward(self):
config, input_ids, batch_size = self._get_config_and_data()
lm_model = FlaxBartForConditionalGeneration(config)
outputs = lm_model(input_ids=input_ids)
expected_shape = (batch_size, input_ids.shape[1], config.vocab_size)
self.assertEqual(outputs["logits"].shape, expected_shape)
def test_lm_uneven_forward(self):
config = BartConfig(
vocab_size=self.vocab_size,
d_model=14,
encoder_layers=2,
decoder_layers=2,
encoder_attention_heads=2,
decoder_attention_heads=2,
encoder_ffn_dim=8,
decoder_ffn_dim=8,
max_position_embeddings=48,
)
lm_model = FlaxBartForConditionalGeneration(config)
context = np.array([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]], dtype=np.int64)
summary = np.array([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]], dtype=np.int64)
outputs = lm_model(input_ids=context, decoder_input_ids=summary)
expected_shape = (*summary.shape, config.vocab_size)
self.assertEqual(outputs["logits"].shape, expected_shape)
def test_shift_tokens_right(self):
input_ids = np.array([[71, 82, 18, 33, 2, 1, 1], [68, 34, 26, 58, 30, 82, 2]], dtype=np.int64)
shifted = shift_tokens_right(input_ids, 1, 2)
n_pad_before = np.equal(input_ids, 1).astype(np.float32).sum()
n_pad_after = np.equal(shifted, 1).astype(np.float32).sum()
self.assertEqual(shifted.shape, input_ids.shape)
self.assertEqual(n_pad_after, n_pad_before - 1)
self.assertTrue(np.equal(shifted[:, 0], 2).all())
@require_flax
class FlaxBartModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationTesterMixin):
is_encoder_decoder = True
all_model_classes = (
(
FlaxBartModel,
FlaxBartForConditionalGeneration,
FlaxBartForSequenceClassification,
FlaxBartForQuestionAnswering,
)
if is_flax_available()
else ()
)
all_generative_model_classes = (FlaxBartForConditionalGeneration,) if is_flax_available() else ()
def setUp(self):
self.model_tester = FlaxBartModelTester(self)
def test_use_cache_forward(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
for model_class in self.all_model_classes:
self.model_tester.check_use_cache_forward(model_class, config, inputs_dict)
def test_use_cache_forward_with_attn_mask(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
for model_class in self.all_model_classes:
self.model_tester.check_use_cache_forward_with_attn_mask(model_class, config, inputs_dict)
def test_encode(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
@jax.jit
def encode_jitted(input_ids, attention_mask=None, **kwargs):
return model.encode(input_ids=input_ids, attention_mask=attention_mask)
with self.subTest("JIT Enabled"):
jitted_outputs = encode_jitted(**prepared_inputs_dict).to_tuple()
with self.subTest("JIT Disabled"):
with jax.disable_jit():
outputs = encode_jitted(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
def test_decode(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
model = model_class(config)
encoder_outputs = model.encode(inputs_dict["input_ids"], inputs_dict["attention_mask"])
prepared_inputs_dict = {
"decoder_input_ids": inputs_dict["decoder_input_ids"],
"decoder_attention_mask": inputs_dict["decoder_attention_mask"],
"encoder_outputs": encoder_outputs,
}
@jax.jit
def decode_jitted(decoder_input_ids, decoder_attention_mask, encoder_outputs):
return model.decode(
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
)
with self.subTest("JIT Enabled"):
jitted_outputs = decode_jitted(**prepared_inputs_dict).to_tuple()
with self.subTest("JIT Disabled"):
with jax.disable_jit():
outputs = decode_jitted(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
model = model_class_name.from_pretrained("facebook/bart-base", from_pt=True)
# FlaxBartForSequenceClassification expects eos token in input_ids
input_ids = np.ones((1, 1)) * model.config.eos_token_id
outputs = model(input_ids)
self.assertIsNotNone(outputs)
@@ -22,6 +22,7 @@ import numpy as np
import transformers
from transformers import is_flax_available, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import is_pt_flax_cross_test, require_flax
@@ -31,6 +32,7 @@ if is_flax_available():
import jax
import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
from transformers import FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
@@ -42,6 +44,14 @@ if is_torch_available():
import torch
def _config_zero_init(config):
configs_no_init = copy.deepcopy(config)
for key in configs_no_init.__dict__.keys():
if "_range" in key or "_std" in key or "initializer_factor" in key:
setattr(configs_no_init, key, 1e-10)
return configs_no_init
def ids_tensor(shape, vocab_size, rng=None):
"""Creates a random int32 tensor of the shape within the vocab size."""
if rng is None:
@@ -87,6 +97,7 @@ def random_attention_mask(shape, rng=None):
class FlaxModelTesterMixin:
model_tester = None
all_model_classes = ()
is_encoder_decoder = False
def _prepare_for_class(self, inputs_dict, model_class):
inputs_dict = copy.deepcopy(inputs_dict)
@@ -156,6 +167,9 @@ class FlaxModelTesterMixin:
pt_model_class = getattr(transformers, pt_model_class_name)
pt_model = pt_model_class(config).eval()
# Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False
fx_model = model_class(config, dtype=jnp.float32)
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
@@ -167,7 +181,7 @@ class FlaxModelTesterMixin:
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
-self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
+self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
@@ -178,7 +192,10 @@
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
-self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
+if not isinstance(fx_output_loaded, tuple):  # TODO(Patrick, Daniel) - let's discard use_cache for now
+    self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 1e-3)
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
@@ -195,6 +212,9 @@ class FlaxModelTesterMixin:
pt_model_class = getattr(transformers, pt_model_class_name)
pt_model = pt_model_class(config).eval()
# Flax models don't use the `use_cache` option and cache is not returned as a default.
# So we disable `use_cache` here for PyTorch model.
pt_model.config.use_cache = False
fx_model = model_class(config, dtype=jnp.float32)
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
@@ -207,8 +227,9 @@ class FlaxModelTesterMixin:
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
-self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
+self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-3)
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
@@ -221,7 +242,8 @@ class FlaxModelTesterMixin:
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
)
for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
-self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
+if not isinstance(fx_output, tuple):  # TODO(Patrick, Daniel) - let's discard use_cache for now
+    self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-3)
def test_from_pretrained_save_pretrained(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -276,6 +298,7 @@ class FlaxModelTesterMixin:
self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
def test_forward_signature(self):
@@ -287,8 +310,17 @@ class FlaxModelTesterMixin:
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["input_ids", "attention_mask"] if model.config.is_encoder_decoder:
self.assertListEqual(arg_names[:2], expected_arg_names) expected_arg_names = [
"input_ids",
"attention_mask",
"decoder_input_ids",
"decoder_attention_mask",
]
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
else:
expected_arg_names = ["input_ids", "attention_mask"]
self.assertListEqual(arg_names[:2], expected_arg_names)
def test_naming_convention(self):
for model_class in self.all_model_classes:
@@ -306,16 +338,36 @@ class FlaxModelTesterMixin:
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-hidden_states = outputs.hidden_states
-self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
-seq_length = self.model_tester.seq_length
+hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
+expected_num_layers = getattr(
+    self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
+)
+self.assertEqual(len(hidden_states), expected_num_layers)
+if hasattr(self.model_tester, "encoder_seq_length"):
+    seq_length = self.model_tester.encoder_seq_length
+else:
+    seq_length = self.model_tester.seq_length
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[seq_length, self.model_tester.hidden_size],
)
if config.is_encoder_decoder:
hidden_states = outputs.decoder_hidden_states
self.assertIsInstance(hidden_states, (list, tuple))
self.assertEqual(len(hidden_states), expected_num_layers)
seq_len = getattr(self.model_tester, "seq_length", None)
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[decoder_seq_length, self.model_tester.hidden_size],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
@@ -333,13 +385,17 @@ class FlaxModelTesterMixin:
config.return_dict = True
seq_length = getattr(self.model_tester, "seq_length", None)
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_length)
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length)
decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-attentions = outputs.attentions
+attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# check that output_attentions also work using config
@@ -347,22 +403,58 @@ class FlaxModelTesterMixin:
config.output_attentions = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-attentions = outputs.attentions
+attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
-[self.model_tester.num_attention_heads, seq_length, seq_length],
+[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
out_len = len(outputs)
if self.is_encoder_decoder:
correct_outlen = 5
# Question Answering model returns start_logits and end_logits
if model_class in get_values(FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING):
correct_outlen += 1 # start_logits and end_logits instead of only 1 output
self.assertEqual(out_len, correct_outlen)
# decoder attentions
decoder_attentions = outputs.decoder_attentions
self.assertIsInstance(decoder_attentions, (list, tuple))
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(decoder_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
)
# cross attentions
cross_attentions = outputs.cross_attentions
self.assertIsInstance(cross_attentions, (list, tuple))
self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(cross_attentions[0].shape[-3:]),
[
self.model_tester.num_attention_heads,
decoder_seq_length,
encoder_key_length,
],
)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-added_hidden_states = 1
+if hasattr(self.model_tester, "num_hidden_states_types"):
+    added_hidden_states = self.model_tester.num_hidden_states_types
+elif self.is_encoder_decoder:
+    added_hidden_states = 2
+else:
+    added_hidden_states = 1
self.assertEqual(out_len + added_hidden_states, len(outputs))
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
@@ -370,5 +462,5 @@ class FlaxModelTesterMixin:
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
-[self.model_tester.num_attention_heads, seq_length, seq_length],
+[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)