Unverified Commit cd9274d0 authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[FlaxBert] Add ForCausalLM (#16995)

* [FlaxBert] Add ForCausalLM

* make style

* fix output attentions

* Add RobertaForCausalLM

* remove comment

* fix fx-to-pt model loading

* remove comment

* add modeling tests

* add enc-dec model tests

* add big_bird

* add electra

* make style

* make repo-consitency

* add to docs

* remove roberta test

* quality

* amend cookiecutter

* fix attention_mask bug in flax bert model tester

* tighten pt-fx thresholds to 1e-5

* add 'copied from' statements

* amend 'copied from' statements

* amend 'copied from' statements

* quality
parent 31616b8d
......@@ -166,6 +166,11 @@ This model was contributed by [thomwolf](https://huggingface.co/thomwolf). The o
[[autodoc]] FlaxBertForPreTraining
- __call__
## FlaxBertForCausalLM
[[autodoc]] FlaxBertForCausalLM
- __call__
## FlaxBertForMaskedLM
[[autodoc]] FlaxBertForMaskedLM
......
......@@ -120,6 +120,11 @@ This model was contributed by [vasudevgupta](https://huggingface.co/vasudevgupta
[[autodoc]] FlaxBigBirdForPreTraining
- __call__
## FlaxBigBirdForCausalLM
[[autodoc]] FlaxBigBirdForCausalLM
- __call__
## FlaxBigBirdForMaskedLM
[[autodoc]] FlaxBigBirdForMaskedLM
......
......@@ -158,6 +158,11 @@ This model was contributed by [lysandre](https://huggingface.co/lysandre). The o
[[autodoc]] FlaxElectraForPreTraining
- __call__
## FlaxElectraForCausalLM
[[autodoc]] FlaxElectraForCausalLM
- __call__
## FlaxElectraForMaskedLM
[[autodoc]] FlaxElectraForMaskedLM
......
......@@ -136,6 +136,11 @@ This model was contributed by [julien-c](https://huggingface.co/julien-c). The o
[[autodoc]] FlaxRobertaModel
- __call__
## FlaxRobertaForCausalLM
[[autodoc]] FlaxRobertaForCausalLM
- __call__
## FlaxRobertaForMaskedLM
[[autodoc]] FlaxRobertaForMaskedLM
......
......@@ -2314,6 +2314,7 @@ if is_flax_available():
)
_import_structure["models.bert"].extend(
[
"FlaxBertForCausalLM",
"FlaxBertForMaskedLM",
"FlaxBertForMultipleChoice",
"FlaxBertForNextSentencePrediction",
......@@ -2327,6 +2328,7 @@ if is_flax_available():
)
_import_structure["models.big_bird"].extend(
[
"FlaxBigBirdForCausalLM",
"FlaxBigBirdForMaskedLM",
"FlaxBigBirdForMultipleChoice",
"FlaxBigBirdForPreTraining",
......@@ -2370,6 +2372,7 @@ if is_flax_available():
)
_import_structure["models.electra"].extend(
[
"FlaxElectraForCausalLM",
"FlaxElectraForMaskedLM",
"FlaxElectraForMultipleChoice",
"FlaxElectraForPreTraining",
......@@ -2412,6 +2415,7 @@ if is_flax_available():
)
_import_structure["models.roberta"].extend(
[
"FlaxRobertaForCausalLM",
"FlaxRobertaForMaskedLM",
"FlaxRobertaForMultipleChoice",
"FlaxRobertaForQuestionAnswering",
......@@ -4363,6 +4367,7 @@ if TYPE_CHECKING:
FlaxBeitPreTrainedModel,
)
from .models.bert import (
FlaxBertForCausalLM,
FlaxBertForMaskedLM,
FlaxBertForMultipleChoice,
FlaxBertForNextSentencePrediction,
......@@ -4374,6 +4379,7 @@ if TYPE_CHECKING:
FlaxBertPreTrainedModel,
)
from .models.big_bird import (
FlaxBigBirdForCausalLM,
FlaxBigBirdForMaskedLM,
FlaxBigBirdForMultipleChoice,
FlaxBigBirdForPreTraining,
......@@ -4411,6 +4417,7 @@ if TYPE_CHECKING:
FlaxDistilBertPreTrainedModel,
)
from .models.electra import (
FlaxElectraForCausalLM,
FlaxElectraForMaskedLM,
FlaxElectraForMultipleChoice,
FlaxElectraForPreTraining,
......@@ -4435,6 +4442,7 @@ if TYPE_CHECKING:
from .models.mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model
from .models.pegasus import FlaxPegasusForConditionalGeneration, FlaxPegasusModel, FlaxPegasusPreTrainedModel
from .models.roberta import (
FlaxRobertaForCausalLM,
FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice,
FlaxRobertaForQuestionAnswering,
......
......@@ -106,6 +106,55 @@ class FlaxBaseModelOutputWithPooling(ModelOutput):
attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
class FlaxBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
"""
Base class for model's outputs that also contains a pooling of the last hidden states.
Args:
last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
Last layer hidden-state of the first token of the sequence (classification token) after further processing
through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
the classification token after processing through a linear layer and a tanh activation function. The linear
layer weights are trained from the next sentence prediction (classification) objective during pretraining.
hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
`config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
input) to speed up sequential decoding.
"""
last_hidden_state: jnp.ndarray = None
pooler_output: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
attentions: Optional[Tuple[jnp.ndarray]] = None
cross_attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
class FlaxBaseModelOutputWithPastAndCrossAttentions(ModelOutput):
"""
......
......@@ -127,6 +127,10 @@ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("gptj", "FlaxGPTJForCausalLM"),
("xglm", "FlaxXGLMForCausalLM"),
("bart", "FlaxBartForCausalLM"),
("bert", "FlaxBertForCausalLM"),
("roberta", "FlaxRobertaForCausalLM"),
("big_bird", "FlaxBigBirdForCausalLM"),
("electra", "FlaxElectraForCausalLM"),
]
)
......
......@@ -65,6 +65,7 @@ if is_tf_available():
if is_flax_available():
_import_structure["modeling_flax_bert"] = [
"FlaxBertForCausalLM",
"FlaxBertForMaskedLM",
"FlaxBertForMultipleChoice",
"FlaxBertForNextSentencePrediction",
......@@ -119,6 +120,7 @@ if TYPE_CHECKING:
if is_flax_available():
from .modeling_flax_bert import (
FlaxBertForCausalLM,
FlaxBertForMaskedLM,
FlaxBertForMultipleChoice,
FlaxBertForNextSentencePrediction,
......
......@@ -55,6 +55,7 @@ if is_torch_available():
if is_flax_available():
_import_structure["modeling_flax_big_bird"] = [
"FlaxBigBirdForCausalLM",
"FlaxBigBirdForMaskedLM",
"FlaxBigBirdForMultipleChoice",
"FlaxBigBirdForPreTraining",
......@@ -92,6 +93,7 @@ if TYPE_CHECKING:
if is_flax_available():
from .modeling_flax_big_bird import (
FlaxBigBirdForCausalLM,
FlaxBigBirdForMaskedLM,
FlaxBigBirdForMultipleChoice,
FlaxBigBirdForPreTraining,
......
......@@ -59,6 +59,7 @@ if is_tf_available():
if is_flax_available():
_import_structure["modeling_flax_electra"] = [
"FlaxElectraForCausalLM",
"FlaxElectraForMaskedLM",
"FlaxElectraForMultipleChoice",
"FlaxElectraForPreTraining",
......@@ -107,6 +108,7 @@ if TYPE_CHECKING:
if is_flax_available():
from .modeling_flax_electra import (
FlaxElectraForCausalLM,
FlaxElectraForMaskedLM,
FlaxElectraForMultipleChoice,
FlaxElectraForPreTraining,
......
......@@ -58,6 +58,7 @@ if is_tf_available():
if is_flax_available():
_import_structure["modeling_flax_roberta"] = [
"FlaxRobertaForCausalLM",
"FlaxRobertaForMaskedLM",
"FlaxRobertaForMultipleChoice",
"FlaxRobertaForQuestionAnswering",
......@@ -103,7 +104,8 @@ if TYPE_CHECKING:
)
if is_flax_available():
from .modeling_tf_roberta import (
from .modeling_flax_roberta import (
FlaxRobertaForCausalLM,
FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice,
FlaxRobertaForQuestionAnswering,
......
......@@ -326,6 +326,13 @@ class FlaxBeitPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["flax"])
class FlaxBertForCausalLM(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxBertForMaskedLM(metaclass=DummyObject):
_backends = ["flax"]
......@@ -389,6 +396,13 @@ class FlaxBertPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["flax"])
class FlaxBigBirdForCausalLM(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxBigBirdForMaskedLM(metaclass=DummyObject):
_backends = ["flax"]
......@@ -578,6 +592,13 @@ class FlaxDistilBertPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["flax"])
class FlaxElectraForCausalLM(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxElectraForMaskedLM(metaclass=DummyObject):
_backends = ["flax"]
......@@ -795,6 +816,13 @@ class FlaxPegasusPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["flax"])
class FlaxRobertaForCausalLM(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxRobertaForMaskedLM(metaclass=DummyObject):
_backends = ["flax"]
......
......@@ -19,7 +19,7 @@ import numpy as np
from transformers import BertConfig, is_flax_available
from transformers.testing_utils import require_flax, slow
from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_attention_mask
from ..test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_flax_available():
......@@ -114,6 +114,22 @@ class FlaxBertModelTester(unittest.TestCase):
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}
return config, inputs_dict
def prepare_config_and_inputs_for_decoder(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, token_type_ids, attention_mask = config_and_inputs
config.is_decoder = True
encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
return (
config,
input_ids,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
)
@require_flax
class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
......
......@@ -25,6 +25,7 @@ from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random
if is_flax_available():
import jax
from transformers.models.big_bird.modeling_flax_big_bird import (
FlaxBigBirdForCausalLM,
FlaxBigBirdForMaskedLM,
FlaxBigBirdForMultipleChoice,
FlaxBigBirdForPreTraining,
......@@ -136,6 +137,7 @@ class FlaxBigBirdModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
FlaxBigBirdForCausalLM,
FlaxBigBirdModel,
FlaxBigBirdForPreTraining,
FlaxBigBirdForMaskedLM,
......
......@@ -10,6 +10,7 @@ from ..test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random
if is_flax_available():
from transformers.models.electra.modeling_flax_electra import (
FlaxElectraForCausalLM,
FlaxElectraForMaskedLM,
FlaxElectraForMultipleChoice,
FlaxElectraForPreTraining,
......@@ -110,6 +111,7 @@ class FlaxElectraModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
FlaxElectraModel,
FlaxElectraForCausalLM,
FlaxElectraForMaskedLM,
FlaxElectraForPreTraining,
FlaxElectraForTokenClassification,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment