Unverified Commit e98233dd authored by Vasudev Gupta and committed by GitHub

Flax T5 (#12150)



* copy pytorch-t5

* init

* boom boom

* forward pass same

* make generation work

* add more tests

* make test work

* finish normal tests

* make fix-copies

* finish quality

* correct slow example

* correct slow test

* version table

* upload models

* Update tests/test_modeling_flax_t5.py

* correct incorrectly deleted line
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Patrick von Platen <patrick@huggingface.co>
parent 7d4cfa3b
@@ -396,7 +396,7 @@ Flax), PyTorch, and/or TensorFlow.
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | SqueezeBERT                 | ✅              | ✅              | ✅               | ❌                  | ❌            |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
-| T5                          | ✅              | ✅              | ✅               | ✅                  | ❌            |
+| T5                          | ✅              | ✅              | ✅               | ✅                  | ✅            |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | TAPAS                       | ✅              | ❌              | ✅               | ❌                  | ❌            |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
...
@@ -160,3 +160,15 @@ TFT5EncoderModel
 .. autoclass:: transformers.TFT5EncoderModel
     :members: call
+
+FlaxT5Model
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaxT5Model
+    :members: __call__, encode, decode
+
+FlaxT5ForConditionalGeneration
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaxT5ForConditionalGeneration
+    :members: __call__, encode, decode
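
Both new classes document encode and decode alongside __call__, following the split encoder/decoder API of the other Flax seq2seq models. A minimal usage sketch, assuming the t5-small checkpoint (the checkpoint name and example text are not part of this diff):

import jax.numpy as jnp
from transformers import FlaxT5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-small")

input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="np").input_ids

# encode() runs only the encoder stack; decode() runs the decoder over its outputs.
encoder_outputs = model.encode(input_ids)
decoder_input_ids = jnp.array([[model.config.decoder_start_token_id]], dtype="i4")
decoder_outputs = model.decode(decoder_input_ids, encoder_outputs)
next_token_logits = decoder_outputs.logits[:, -1]  # scores for the first generated token
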
@@ -114,6 +114,7 @@ _deps = [
     "onnxruntime-tools>=1.4.2",
     "onnxruntime>=1.4.0",
     "optuna",
+    "optax>=0.0.8",
     "packaging",
     "parameterized",
     "protobuf",
...
@@ -234,7 +235,7 @@ if os.name == "nt":  # windows
     extras["flax"] = []  # jax is not supported on windows
 else:
     extras["retrieval"] = deps_list("faiss-cpu", "datasets")
-    extras["flax"] = deps_list("jax", "jaxlib", "flax")
+    extras["flax"] = deps_list("jax", "jaxlib", "flax", "optax")
 
 extras["tokenizers"] = deps_list("tokenizers")
 extras["onnxruntime"] = deps_list("onnxruntime", "onnxruntime-tools")
...
@@ -1597,6 +1597,7 @@ if is_flax_available():
             "FlaxRobertaPreTrainedModel",
         ]
     )
+    _import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model"])
     _import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel"])
 else:
     from .utils import dummy_flax_objects
@@ -2920,6 +2921,7 @@ if TYPE_CHECKING:
             FlaxRobertaModel,
             FlaxRobertaPreTrainedModel,
         )
+        from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model
         from .models.vit import FlaxViTForImageClassification, FlaxViTModel
     else:
         # Import the same objects as dummies to get them in the namespace.
...
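
With the import structure extended as above, both classes resolve through the lazy import machinery at the top level of the package. A small sanity-check sketch (nothing below is copied from the diff itself):

import transformers

# With Flax installed these are the real classes from models.t5;
# without it they fall back to the dummy objects added later in this commit.
print(transformers.FlaxT5Model)
print(transformers.FlaxT5ForConditionalGeneration)
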
@@ -31,6 +31,7 @@ deps = {
     "onnxruntime-tools": "onnxruntime-tools>=1.4.2",
     "onnxruntime": "onnxruntime>=1.4.0",
     "optuna": "optuna",
+    "optax": "optax>=0.0.8",
     "packaging": "packaging",
     "parameterized": "parameterized",
     "protobuf": "protobuf",
...
@@ -62,6 +62,7 @@ from ..roberta.modeling_flax_roberta import (
     FlaxRobertaForTokenClassification,
     FlaxRobertaModel,
 )
+from ..t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model
 from ..vit.modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
 from .auto_factory import auto_class_factory
 from .configuration_auto import (
@@ -72,6 +73,7 @@ from .configuration_auto import (
     ElectraConfig,
     GPT2Config,
     RobertaConfig,
+    T5Config,
     ViTConfig,
 )
@@ -90,6 +92,7 @@ FLAX_MODEL_MAPPING = OrderedDict(
         (ElectraConfig, FlaxElectraModel),
         (CLIPConfig, FlaxCLIPModel),
         (ViTConfig, FlaxViTModel),
+        (T5Config, FlaxT5Model),
     ]
 )
@@ -101,6 +104,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
         (BigBirdConfig, FlaxBigBirdForPreTraining),
         (BartConfig, FlaxBartForConditionalGeneration),
         (ElectraConfig, FlaxElectraForPreTraining),
+        (T5Config, FlaxT5ForConditionalGeneration),
     ]
 )
@@ -115,6 +119,14 @@ FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
     ]
 )
+
+FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
+    [
+        # Model for Seq2Seq Causal LM mapping
+        (BartConfig, FlaxBartForConditionalGeneration),
+        (T5Config, FlaxT5ForConditionalGeneration),
+    ]
+)
 
 FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = OrderedDict(
     [
         # Model for Image-classsification
@@ -234,3 +246,9 @@ FlaxAutoModelForNextSentencePrediction = auto_class_factory(
     FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
     head_doc="next sentence prediction",
 )
+
+FlaxAutoModelForSeq2SeqLM = auto_class_factory(
+    "FlaxAutoModelForSeq2SeqLM",
+    FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
+    head_doc="sequence-to-sequence language modeling",
+)
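
The new auto class routes configs through FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, so T5 checkpoints resolve to FlaxT5ForConditionalGeneration. A rough usage sketch, importing straight from the module (a top-level export of FlaxAutoModelForSeq2SeqLM is not part of this diff, and the checkpoint name is an assumption):

from transformers import T5Tokenizer
from transformers.models.auto.modeling_flax_auto import FlaxAutoModelForSeq2SeqLM

tokenizer = T5Tokenizer.from_pretrained("t5-small")
# T5Config maps to FlaxT5ForConditionalGeneration in the new seq2seq mapping above.
model = FlaxAutoModelForSeq2SeqLM.from_pretrained("t5-small")
print(type(model).__name__)  # expected: FlaxT5ForConditionalGeneration
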
@@ -229,7 +229,6 @@ class FlaxBartAttention(nn.Module):
     embed_dim: int
     num_heads: int
     dropout: float = 0.0
-    is_decoder: bool = False
     causal: bool = False
     bias: bool = True
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
@@ -510,7 +509,6 @@ class FlaxBartDecoderLayer(nn.Module):
             embed_dim=self.embed_dim,
             num_heads=self.config.decoder_attention_heads,
             dropout=self.config.attention_dropout,
-            is_decoder=True,
             causal=True,
         )
         self.dropout_layer = nn.Dropout(rate=self.config.dropout)
@@ -523,7 +521,6 @@ class FlaxBartDecoderLayer(nn.Module):
             embed_dim=self.embed_dim,
             num_heads=self.config.decoder_attention_heads,
             dropout=self.config.attention_dropout,
-            is_decoder=True,
         )
         self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype)
         self.fc1 = nn.Dense(
...
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING
 from ...file_utils import (
     _BaseLazyModule,
+    is_flax_available,
     is_sentencepiece_available,
     is_tf_available,
     is_tokenizers_available,
@@ -56,6 +57,13 @@ if is_tf_available():
         "TFT5PreTrainedModel",
     ]
+
+if is_flax_available():
+    _import_structure["modeling_flax_t5"] = [
+        "FlaxT5ForConditionalGeneration",
+        "FlaxT5Model",
+        "FlaxT5PreTrainedModel",
+    ]
 
 if TYPE_CHECKING:
     from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
@@ -85,6 +93,10 @@ if TYPE_CHECKING:
         TFT5PreTrainedModel,
     )
+
+    if is_flax_available():
+        from .modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
+
 else:
     import importlib
     import os
...
This diff is collapsed.
@@ -570,6 +570,24 @@ class FlaxRobertaPreTrainedModel:
         requires_backends(cls, ["flax"])
+
+
+class FlaxT5ForConditionalGeneration:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["flax"])
+
+
+class FlaxT5Model:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["flax"])
 
 
 class FlaxViTForImageClassification:
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["flax"])
...
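
These dummy classes keep the names importable when Flax is not installed; calling them defers to requires_backends, which raises an ImportError naming the missing backend. A sketch of the intended behaviour in a Flax-less environment (the error wording is paraphrased, not taken from the diff):

from transformers import FlaxT5Model

try:
    FlaxT5Model()  # the dummy class is used because the flax backend is unavailable
except ImportError as err:
    print(err)  # message explains that FlaxT5Model requires the flax library
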
@@ -72,7 +72,7 @@ def prepare_bart_inputs_dict(
     }
 
 
-class FlaxBartModelTester(unittest.TestCase):
+class FlaxBartModelTester:
     def __init__(
         self,
         parent,
...
This diff is collapsed.
@@ -794,6 +794,21 @@ class T5ModelIntegrationTests(unittest.TestCase):
     def tokenizer(self):
         return T5Tokenizer.from_pretrained("t5-base")
 
+    @slow
+    def test_small_generation(self):
+        model = T5ForConditionalGeneration.from_pretrained("t5-small").to(torch_device)
+        model.config.max_length = 8
+        model.config.num_beams = 1
+        model.config.do_sample = False
+        tokenizer = T5Tokenizer.from_pretrained("t5-small")
+
+        input_ids = tokenizer("summarize: Hello there", return_tensors="pt").input_ids
+
+        sequences = model.generate(input_ids)
+
+        output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
+        self.assertTrue(output_str == "Hello there!")
+
     @slow
     def test_small_integration_test(self):
         """
...
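
The PyTorch test above pins down greedy decoding (num_beams=1, do_sample=False, max_length=8) for t5-small. A hedged sketch of the same check against the new Flax port (this is not the collapsed tests/test_modeling_flax_t5.py; the expected string is simply reused from the PyTorch test, and whether the Flax port reproduces it exactly is an assumption):

from transformers import FlaxT5ForConditionalGeneration, T5Tokenizer

model = FlaxT5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")

input_ids = tokenizer("summarize: Hello there", return_tensors="np").input_ids

# Greedy search is the default; max_length mirrors the config override in the PyTorch test.
sequences = model.generate(input_ids, max_length=8).sequences
output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
assert output_str == "Hello there!"  # assumes the Flax weights match the PyTorch output exactly
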