Unverified Commit cd9274d0 authored by Sanchit Gandhi, committed by GitHub

[FlaxBert] Add ForCausalLM (#16995)

* [FlaxBert] Add ForCausalLM

* make style

* fix output attentions

* Add RobertaForCausalLM

* remove comment

* fix fx-to-pt model loading

* remove comment

* add modeling tests

* add enc-dec model tests

* add big_bird

* add electra

* make style

* make repo-consistency

* add to docs

* remove roberta test

* quality

* amend cookiecutter

* fix attention_mask bug in flax bert model tester

* tighten pt-fx thresholds to 1e-5

* add 'copied from' statements

* amend 'copied from' statements

* amend 'copied from' statements

* quality
parent 31616b8d
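For reviewers who want to try the new head quickly, here is a minimal, illustrative sketch of using `FlaxBertForCausalLM` once this commit is installed. The checkpoint name and prompt are placeholders, not part of the change:

```python
# Illustrative only: exercising the new Flax causal-LM head added by this commit.
from transformers import AutoTokenizer, FlaxBertForCausalLM

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# The LM head on top of bert-base-uncased is (partly) freshly initialised here;
# this just demonstrates the forward pass and the output shape.
model = FlaxBertForCausalLM.from_pretrained("bert-base-uncased")

inputs = tokenizer("The quick brown fox", return_tensors="np")
outputs = model(**inputs)

# Next-token logits, one distribution over the vocabulary per input position.
print(outputs.logits.shape)  # (batch_size, sequence_length, vocab_size)
```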
@@ -166,6 +166,11 @@ This model was contributed by [thomwolf](https://huggingface.co/thomwolf). The o
[[autodoc]] FlaxBertForPreTraining
- __call__
## FlaxBertForCausalLM
[[autodoc]] FlaxBertForCausalLM
- __call__
## FlaxBertForMaskedLM
[[autodoc]] FlaxBertForMaskedLM
...
@@ -120,6 +120,11 @@ This model was contributed by [vasudevgupta](https://huggingface.co/vasudevgupta
[[autodoc]] FlaxBigBirdForPreTraining
- __call__
## FlaxBigBirdForCausalLM
[[autodoc]] FlaxBigBirdForCausalLM
- __call__
## FlaxBigBirdForMaskedLM
[[autodoc]] FlaxBigBirdForMaskedLM
...
@@ -158,6 +158,11 @@ This model was contributed by [lysandre](https://huggingface.co/lysandre). The o
[[autodoc]] FlaxElectraForPreTraining
- __call__
## FlaxElectraForCausalLM
[[autodoc]] FlaxElectraForCausalLM
- __call__
## FlaxElectraForMaskedLM
[[autodoc]] FlaxElectraForMaskedLM
...
@@ -136,6 +136,11 @@ This model was contributed by [julien-c](https://huggingface.co/julien-c). The o
[[autodoc]] FlaxRobertaModel
- __call__
## FlaxRobertaForCausalLM
[[autodoc]] FlaxRobertaForCausalLM
- __call__
## FlaxRobertaForMaskedLM
[[autodoc]] FlaxRobertaForMaskedLM
...
@@ -2314,6 +2314,7 @@ if is_flax_available():
)
_import_structure["models.bert"].extend(
[
"FlaxBertForCausalLM",
"FlaxBertForMaskedLM", "FlaxBertForMaskedLM",
"FlaxBertForMultipleChoice", "FlaxBertForMultipleChoice",
"FlaxBertForNextSentencePrediction", "FlaxBertForNextSentencePrediction",
...@@ -2327,6 +2328,7 @@ if is_flax_available(): ...@@ -2327,6 +2328,7 @@ if is_flax_available():
) )
_import_structure["models.big_bird"].extend( _import_structure["models.big_bird"].extend(
[ [
"FlaxBigBirdForCausalLM",
"FlaxBigBirdForMaskedLM", "FlaxBigBirdForMaskedLM",
"FlaxBigBirdForMultipleChoice", "FlaxBigBirdForMultipleChoice",
"FlaxBigBirdForPreTraining", "FlaxBigBirdForPreTraining",
...@@ -2370,6 +2372,7 @@ if is_flax_available(): ...@@ -2370,6 +2372,7 @@ if is_flax_available():
) )
_import_structure["models.electra"].extend( _import_structure["models.electra"].extend(
[ [
"FlaxElectraForCausalLM",
"FlaxElectraForMaskedLM", "FlaxElectraForMaskedLM",
"FlaxElectraForMultipleChoice", "FlaxElectraForMultipleChoice",
"FlaxElectraForPreTraining", "FlaxElectraForPreTraining",
...@@ -2412,6 +2415,7 @@ if is_flax_available(): ...@@ -2412,6 +2415,7 @@ if is_flax_available():
) )
_import_structure["models.roberta"].extend( _import_structure["models.roberta"].extend(
[ [
"FlaxRobertaForCausalLM",
"FlaxRobertaForMaskedLM", "FlaxRobertaForMaskedLM",
"FlaxRobertaForMultipleChoice", "FlaxRobertaForMultipleChoice",
"FlaxRobertaForQuestionAnswering", "FlaxRobertaForQuestionAnswering",
...@@ -4363,6 +4367,7 @@ if TYPE_CHECKING: ...@@ -4363,6 +4367,7 @@ if TYPE_CHECKING:
FlaxBeitPreTrainedModel, FlaxBeitPreTrainedModel,
) )
from .models.bert import ( from .models.bert import (
FlaxBertForCausalLM,
FlaxBertForMaskedLM,
FlaxBertForMultipleChoice,
FlaxBertForNextSentencePrediction,
@@ -4374,6 +4379,7 @@ if TYPE_CHECKING:
FlaxBertPreTrainedModel,
)
from .models.big_bird import (
FlaxBigBirdForCausalLM,
FlaxBigBirdForMaskedLM,
FlaxBigBirdForMultipleChoice,
FlaxBigBirdForPreTraining,
@@ -4411,6 +4417,7 @@ if TYPE_CHECKING:
FlaxDistilBertPreTrainedModel,
)
from .models.electra import (
FlaxElectraForCausalLM,
FlaxElectraForMaskedLM,
FlaxElectraForMultipleChoice,
FlaxElectraForPreTraining,
@@ -4435,6 +4442,7 @@ if TYPE_CHECKING:
from .models.mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model
from .models.pegasus import FlaxPegasusForConditionalGeneration, FlaxPegasusModel, FlaxPegasusPreTrainedModel
from .models.roberta import (
FlaxRobertaForCausalLM,
FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice,
FlaxRobertaForQuestionAnswering,
...
@@ -106,6 +106,55 @@ class FlaxBaseModelOutputWithPooling(ModelOutput):
attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
class FlaxBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
"""
Base class for model's outputs that also contains a pooling of the last hidden states.
Args:
last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
Last layer hidden-state of the first token of the sequence (classification token) after further processing
through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
the classification token after processing through a linear layer and a tanh activation function. The linear
layer weights are trained from the next sentence prediction (classification) objective during pretraining.
hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
`config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
input) to speed up sequential decoding.
"""
last_hidden_state: jnp.ndarray = None
pooler_output: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
attentions: Optional[Tuple[jnp.ndarray]] = None
cross_attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
class FlaxBaseModelOutputWithPastAndCrossAttentions(ModelOutput):
"""
...
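As a rough sketch of what the new output class exposes (assuming `FlaxBertModel` now returns it, as the rest of this diff suggests), its fields can be inspected like this; the checkpoint and text are illustrative:

```python
# Illustrative only: reading the fields of the new pooled output class.
from transformers import AutoTokenizer, FlaxBertModel

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = FlaxBertModel.from_pretrained("bert-base-uncased")

inputs = tokenizer("Hello world", return_tensors="np")
outputs = model(**inputs, output_attentions=True, output_hidden_states=True)

print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)
print(outputs.pooler_output.shape)      # (batch_size, hidden_size)
print(len(outputs.hidden_states))       # embedding output + one entry per layer
print(len(outputs.attentions))          # one entry per layer
# cross_attentions and past_key_values stay None unless cross-attention / caching is used.
```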
@@ -127,6 +127,10 @@ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("gptj", "FlaxGPTJForCausalLM"),
("xglm", "FlaxXGLMForCausalLM"),
("bart", "FlaxBartForCausalLM"),
("bert", "FlaxBertForCausalLM"),
("roberta", "FlaxRobertaForCausalLM"),
("big_bird", "FlaxBigBirdForCausalLM"),
("electra", "FlaxElectraForCausalLM"),
]
)
...
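With the four new entries in `FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES`, the auto class should now dispatch BERT-family configs to the new heads. A small sketch, with an illustrative checkpoint name:

```python
# Illustrative only: the auto class should now resolve a BERT config to the new Flax head.
from transformers import FlaxAutoModelForCausalLM

model = FlaxAutoModelForCausalLM.from_pretrained("bert-base-uncased")
print(type(model).__name__)  # expected: FlaxBertForCausalLM
```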
@@ -65,6 +65,7 @@ if is_tf_available():
if is_flax_available():
_import_structure["modeling_flax_bert"] = [
"FlaxBertForCausalLM",
"FlaxBertForMaskedLM", "FlaxBertForMaskedLM",
"FlaxBertForMultipleChoice", "FlaxBertForMultipleChoice",
"FlaxBertForNextSentencePrediction", "FlaxBertForNextSentencePrediction",
...@@ -119,6 +120,7 @@ if TYPE_CHECKING: ...@@ -119,6 +120,7 @@ if TYPE_CHECKING:
if is_flax_available(): if is_flax_available():
from .modeling_flax_bert import ( from .modeling_flax_bert import (
FlaxBertForCausalLM,
FlaxBertForMaskedLM,
FlaxBertForMultipleChoice,
FlaxBertForNextSentencePrediction,
...
@@ -55,6 +55,7 @@ if is_torch_available():
if is_flax_available():
_import_structure["modeling_flax_big_bird"] = [
"FlaxBigBirdForCausalLM",
"FlaxBigBirdForMaskedLM", "FlaxBigBirdForMaskedLM",
"FlaxBigBirdForMultipleChoice", "FlaxBigBirdForMultipleChoice",
"FlaxBigBirdForPreTraining", "FlaxBigBirdForPreTraining",
...@@ -92,6 +93,7 @@ if TYPE_CHECKING: ...@@ -92,6 +93,7 @@ if TYPE_CHECKING:
if is_flax_available(): if is_flax_available():
from .modeling_flax_big_bird import ( from .modeling_flax_big_bird import (
FlaxBigBirdForCausalLM,
FlaxBigBirdForMaskedLM,
FlaxBigBirdForMultipleChoice,
FlaxBigBirdForPreTraining,
...
@@ -59,6 +59,7 @@ if is_tf_available():
if is_flax_available():
_import_structure["modeling_flax_electra"] = [
"FlaxElectraForCausalLM",
"FlaxElectraForMaskedLM", "FlaxElectraForMaskedLM",
"FlaxElectraForMultipleChoice", "FlaxElectraForMultipleChoice",
"FlaxElectraForPreTraining", "FlaxElectraForPreTraining",
...@@ -107,6 +108,7 @@ if TYPE_CHECKING: ...@@ -107,6 +108,7 @@ if TYPE_CHECKING:
if is_flax_available(): if is_flax_available():
from .modeling_flax_electra import ( from .modeling_flax_electra import (
FlaxElectraForCausalLM,
FlaxElectraForMaskedLM,
FlaxElectraForMultipleChoice,
FlaxElectraForPreTraining,
...
@@ -58,6 +58,7 @@ if is_tf_available():
if is_flax_available():
_import_structure["modeling_flax_roberta"] = [
"FlaxRobertaForCausalLM",
"FlaxRobertaForMaskedLM", "FlaxRobertaForMaskedLM",
"FlaxRobertaForMultipleChoice", "FlaxRobertaForMultipleChoice",
"FlaxRobertaForQuestionAnswering", "FlaxRobertaForQuestionAnswering",
...@@ -103,7 +104,8 @@ if TYPE_CHECKING: ...@@ -103,7 +104,8 @@ if TYPE_CHECKING:
) )
if is_flax_available(): if is_flax_available():
from .modeling_tf_roberta import ( from .modeling_flax_roberta import (
FlaxRobertaForCausalLM,
FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice,
FlaxRobertaForQuestionAnswering,
...
@@ -326,6 +326,13 @@ class FlaxBeitPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["flax"])
class FlaxBertForCausalLM(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxBertForMaskedLM(metaclass=DummyObject):
_backends = ["flax"]
@@ -389,6 +396,13 @@ class FlaxBertPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["flax"])
class FlaxBigBirdForCausalLM(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxBigBirdForMaskedLM(metaclass=DummyObject):
_backends = ["flax"]
@@ -578,6 +592,13 @@ class FlaxDistilBertPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["flax"])
class FlaxElectraForCausalLM(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxElectraForMaskedLM(metaclass=DummyObject):
_backends = ["flax"]
@@ -795,6 +816,13 @@ class FlaxPegasusPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["flax"])
class FlaxRobertaForCausalLM(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxRobertaForMaskedLM(metaclass=DummyObject):
_backends = ["flax"]
...
@@ -19,7 +19,7 @@ import numpy as np
from transformers import BertConfig, is_flax_available
from transformers.testing_utils import require_flax, slow
from ..test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_flax_available():
@@ -114,6 +114,22 @@ class FlaxBertModelTester(unittest.TestCase):
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}
return config, inputs_dict
def prepare_config_and_inputs_for_decoder(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, token_type_ids, attention_mask = config_and_inputs
config.is_decoder = True
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
return (
config,
input_ids,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
)
@require_flax
class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
...
@@ -25,6 +25,7 @@ from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random
if is_flax_available():
import jax
from transformers.models.big_bird.modeling_flax_big_bird import (
FlaxBigBirdForCausalLM,
FlaxBigBirdForMaskedLM,
FlaxBigBirdForMultipleChoice,
FlaxBigBirdForPreTraining,
@@ -136,6 +137,7 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
FlaxBigBirdForCausalLM,
FlaxBigBirdModel,
FlaxBigBirdForPreTraining,
FlaxBigBirdForMaskedLM,
...
@@ -10,6 +10,7 @@ from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random
if is_flax_available():
from transformers.models.electra.modeling_flax_electra import (
FlaxElectraForCausalLM,
FlaxElectraForMaskedLM,
FlaxElectraForMultipleChoice,
FlaxElectraForPreTraining,
@@ -110,6 +111,7 @@ class FlaxElectraModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
FlaxElectraModel,
FlaxElectraForCausalLM,
FlaxElectraForMaskedLM,
FlaxElectraForPreTraining,
FlaxElectraForTokenClassification,
...