Unverified Commit e98233dd authored by Vasudev Gupta, committed by GitHub

Flax T5 (#12150)



* copy pytorch-t5

* init

* boom boom

* forward pass same

* make generation work

* add more tests

* make test work

* finish normal tests

* make fix-copies

* finish quality

* correct slow example

* correct slow test

* version table

* upload models

* Update tests/test_modeling_flax_t5.py

* correct incorrectly deleted line
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Patrick von Platen <patrick@huggingface.co>
parent 7d4cfa3b
@@ -396,7 +396,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| SqueezeBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| T5 | ✅ | ✅ | ✅ | ✅ | ❌ |
| T5 | ✅ | ✅ | ✅ | ✅ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| TAPAS | ✅ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
@@ -160,3 +160,15 @@ TFT5EncoderModel
.. autoclass:: transformers.TFT5EncoderModel
    :members: call

FlaxT5Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxT5Model
    :members: __call__, encode, decode

FlaxT5ForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxT5ForConditionalGeneration
    :members: __call__, encode, decode
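
A minimal usage sketch of the encode/decode API documented above, in the style of the model docstrings; the t5-small checkpoint is an illustrative choice:

import jax.numpy as jnp
from transformers import T5Tokenizer, FlaxT5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-small")

inputs = tokenizer("summarize: My friends are cool but they eat too many carbs.", return_tensors="np")

# run the encoder once; its output can be reused across decoding steps
encoder_outputs = model.encode(**inputs)

# T5 starts decoding from the configured decoder start token (the pad token)
decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * model.config.decoder_start_token_id

outputs = model.decode(decoder_input_ids, encoder_outputs)
logits = outputs.logits  # scores for the first generated position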
@@ -114,6 +114,7 @@ _deps = [
    "onnxruntime-tools>=1.4.2",
    "onnxruntime>=1.4.0",
    "optuna",
    "optax>=0.0.8",
    "packaging",
    "parameterized",
    "protobuf",
@@ -234,7 +235,7 @@ if os.name == "nt": # windows
    extras["flax"] = []  # jax is not supported on windows
else:
    extras["retrieval"] = deps_list("faiss-cpu", "datasets")
    extras["flax"] = deps_list("jax", "jaxlib", "flax")
    extras["flax"] = deps_list("jax", "jaxlib", "flax", "optax")

extras["tokenizers"] = deps_list("tokenizers")
extras["onnxruntime"] = deps_list("onnxruntime", "onnxruntime-tools")
@@ -325,7 +326,7 @@ install_requires = [
    deps["huggingface-hub"],
    deps["numpy"],
    deps["packaging"],  # utilities from PyPA to e.g., compare versions
    deps["pyyaml"],  # used for the model cards metadata
    deps["regex"],  # for OpenAI GPT
    deps["requests"],  # for downloading models over HTTPS
    deps["sacremoses"],  # for XLM
@@ -1597,6 +1597,7 @@ if is_flax_available():
            "FlaxRobertaPreTrainedModel",
        ]
    )
    _import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model"])
    _import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel"])
else:
    from .utils import dummy_flax_objects
@@ -2920,6 +2921,7 @@ if TYPE_CHECKING:
            FlaxRobertaModel,
            FlaxRobertaPreTrainedModel,
        )
        from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model
        from .models.vit import FlaxViTForImageClassification, FlaxViTModel
    else:
        # Import the same objects as dummies to get them in the namespace.
@@ -31,6 +31,7 @@ deps = {
    "onnxruntime-tools": "onnxruntime-tools>=1.4.2",
    "onnxruntime": "onnxruntime>=1.4.0",
    "optuna": "optuna",
    "optax": "optax>=0.0.8",
    "packaging": "packaging",
    "parameterized": "parameterized",
    "protobuf": "protobuf",
@@ -62,6 +62,7 @@ from ..roberta.modeling_flax_roberta import (
    FlaxRobertaForTokenClassification,
    FlaxRobertaModel,
)
from ..t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model
from ..vit.modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
from .auto_factory import auto_class_factory
from .configuration_auto import (
@@ -72,6 +73,7 @@ from .configuration_auto import (
    ElectraConfig,
    GPT2Config,
    RobertaConfig,
    T5Config,
    ViTConfig,
)
@@ -90,6 +92,7 @@ FLAX_MODEL_MAPPING = OrderedDict(
        (ElectraConfig, FlaxElectraModel),
        (CLIPConfig, FlaxCLIPModel),
        (ViTConfig, FlaxViTModel),
        (T5Config, FlaxT5Model),
    ]
)
@@ -101,6 +104,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
        (BigBirdConfig, FlaxBigBirdForPreTraining),
        (BartConfig, FlaxBartForConditionalGeneration),
        (ElectraConfig, FlaxElectraForPreTraining),
        (T5Config, FlaxT5ForConditionalGeneration),
    ]
)
@@ -115,6 +119,14 @@ FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
    ]
)

FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
    [
        # Model for Seq2Seq Causal LM mapping
        (BartConfig, FlaxBartForConditionalGeneration),
        (T5Config, FlaxT5ForConditionalGeneration),
    ]
)

FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = OrderedDict(
    [
        # Model for Image-classification
@@ -234,3 +246,9 @@ FlaxAutoModelForNextSentencePrediction = auto_class_factory(
    FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
    head_doc="next sentence prediction",
)

FlaxAutoModelForSeq2SeqLM = auto_class_factory(
    "FlaxAutoModelForSeq2SeqLM",
    FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    head_doc="sequence-to-sequence language modeling",
)
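
A short sketch of how the new mapping and the generated auto class fit together; the mapping keys are config classes, and the auto class dispatches on the checkpoint's config:

from transformers import T5Config
from transformers.models.auto.modeling_flax_auto import (
    FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    FlaxAutoModelForSeq2SeqLM,
)

# the mapping resolves T5Config to the Flax seq2seq head
assert FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING[T5Config].__name__ == "FlaxT5ForConditionalGeneration"

# the generated auto class performs the same lookup from a pretrained config
model = FlaxAutoModelForSeq2SeqLM.from_pretrained("t5-small")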
@@ -229,7 +229,6 @@ class FlaxBartAttention(nn.Module):
    embed_dim: int
    num_heads: int
    dropout: float = 0.0
    is_decoder: bool = False
    causal: bool = False
    bias: bool = True
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
@@ -510,7 +509,6 @@ class FlaxBartDecoderLayer(nn.Module):
            embed_dim=self.embed_dim,
            num_heads=self.config.decoder_attention_heads,
            dropout=self.config.attention_dropout,
            is_decoder=True,
            causal=True,
        )
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -523,7 +521,6 @@ class FlaxBartDecoderLayer(nn.Module):
            embed_dim=self.embed_dim,
            num_heads=self.config.decoder_attention_heads,
            dropout=self.config.attention_dropout,
            is_decoder=True,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
        self.fc1 = nn.Dense(
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING
from ...file_utils import (
    _BaseLazyModule,
    is_flax_available,
    is_sentencepiece_available,
    is_tf_available,
    is_tokenizers_available,
@@ -56,6 +57,13 @@ if is_tf_available():
        "TFT5PreTrainedModel",
    ]

if is_flax_available():
    _import_structure["modeling_flax_t5"] = [
        "FlaxT5ForConditionalGeneration",
        "FlaxT5Model",
        "FlaxT5PreTrainedModel",
    ]

if TYPE_CHECKING:
    from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
@@ -85,6 +93,10 @@ if TYPE_CHECKING:
        TFT5PreTrainedModel,
    )

    if is_flax_available():
        from .modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel

else:
    import importlib
    import os
This diff is collapsed.
@@ -570,6 +570,24 @@ class FlaxRobertaPreTrainedModel:
        requires_backends(cls, ["flax"])


class FlaxT5ForConditionalGeneration:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])


class FlaxT5Model:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])


class FlaxViTForImageClassification:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])
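The dummies keep the new names importable when flax is missing; instantiation is what fails. A sketch of the expected behavior in an environment without flax installed:

from transformers import FlaxT5Model  # resolves to the dummy above

try:
    FlaxT5Model()
except ImportError as err:
    # requires_backends raises with instructions to install flax
    print(err)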
@@ -72,7 +72,7 @@ def prepare_bart_inputs_dict(
    }


class FlaxBartModelTester(unittest.TestCase):
class FlaxBartModelTester:
    def __init__(
        self,
        parent,
This diff is collapsed.
@@ -794,6 +794,21 @@ class T5ModelIntegrationTests(unittest.TestCase):
    def tokenizer(self):
        return T5Tokenizer.from_pretrained("t5-base")

    @slow
    def test_small_generation(self):
        model = T5ForConditionalGeneration.from_pretrained("t5-small").to(torch_device)
        model.config.max_length = 8
        model.config.num_beams = 1
        model.config.do_sample = False

        tokenizer = T5Tokenizer.from_pretrained("t5-small")
        input_ids = tokenizer("summarize: Hello there", return_tensors="pt").input_ids
        sequences = model.generate(input_ids)

        output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
        self.assertTrue(output_str == "Hello there!")

    @slow
    def test_small_integration_test(self):
        """
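A sketch of a possible Flax counterpart to test_small_generation above (not part of this commit), assuming greedy decoding; Flax's generate returns an output object whose sequences field holds the token ids:

from transformers import FlaxT5ForConditionalGeneration, T5Tokenizer

model = FlaxT5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")

input_ids = tokenizer("summarize: Hello there", return_tensors="np").input_ids

sequences = model.generate(input_ids, max_length=8, do_sample=False).sequences
print(tokenizer.batch_decode(sequences, skip_special_tokens=True)[0])  # expected: "Hello there!"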