Unverified commit ad25fd62 authored by Suraj Patil, committed by GitHub

Add FlaxCLIP (#11883)

* add flax CLIP

* default input_shape

* add tests

* fix test

* fix name

* fix docs

* fix shapes

* attend at least 1 token

* flax conv to torch conv

* return floats

* fix equivalence tests

* fix import

* return attention_weights and update tests

* fix docstrings

* address Patrick's comments

* input_shape arg

* add tests for get_image_features and get_text_features methods

* fix tests
parent cfca638a
@@ -304,7 +304,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| BlenderbotSmall | ✅ | ❌ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| CLIP | ✅ | ✅ | ✅ | ❌ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| CTRL | ✅ | ❌ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
...
@@ -152,3 +152,24 @@ CLIPVisionModel
.. autoclass:: transformers.CLIPVisionModel
    :members: forward

FlaxCLIPModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxCLIPModel
    :members: __call__, get_text_features, get_image_features

FlaxCLIPTextModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxCLIPTextModel
    :members: __call__

FlaxCLIPVisionModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.FlaxCLIPVisionModel
    :members: __call__
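
The newly documented FlaxCLIPModel exposes ``__call__`` plus the two standalone feature extractors listed above. A minimal usage sketch, assuming the released openai/clip-vit-base-patch32 checkpoint and the existing CLIPProcessor (neither appears in this diff; output field names mirror the PyTorch CLIP model):

```python
import jax
import requests
from PIL import Image
from transformers import CLIPProcessor, FlaxCLIPModel

# checkpoint name and processor are assumptions, not part of this diff
model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(
    text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="np", padding=True
)

outputs = model(**inputs)
probs = jax.nn.softmax(outputs.logits_per_image, axis=1)  # image-text similarity as probabilities

# the standalone projections documented above
text_embeds = model.get_text_features(inputs["input_ids"], inputs["attention_mask"])
image_embeds = model.get_image_features(inputs["pixel_values"])
```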
@@ -1482,6 +1482,14 @@ if is_flax_available():
            "FlaxBertPreTrainedModel",
        ]
    )
    _import_structure["models.clip"].extend(
        [
            "FlaxCLIPModel",
            "FlaxCLIPPreTrainedModel",
            "FlaxCLIPTextModel",
            "FlaxCLIPVisionModel",
        ]
    )
    _import_structure["models.electra"].extend(
        [
            "FlaxElectraForMaskedLM",
@@ -2743,6 +2751,7 @@ if TYPE_CHECKING:
            FlaxBertModel,
            FlaxBertPreTrainedModel,
        )
        from .models.clip import FlaxCLIPModel, FlaxCLIPPreTrainedModel, FlaxCLIPTextModel, FlaxCLIPVisionModel
        from .models.electra import (
            FlaxElectraForMaskedLM,
            FlaxElectraForMultipleChoice,
...
@@ -90,7 +90,12 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
            pt_tuple_key = pt_tuple_key[:-1] + ("scale",)

        if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
            pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
        elif pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and pt_tuple_key not in random_flax_state_dict:
            # conv layer
            pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
            pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
        elif pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
            # linear layer
            pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
            pt_tensor = pt_tensor.T
        elif pt_tuple_key[-1] == "gamma":
@@ -170,7 +175,12 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
            flax_key_tuple = (pt_model.base_model_prefix,) + flax_key_tuple

        # rename flax weights to PyTorch format
        if flax_key_tuple[-1] == "kernel" and flax_tensor.ndim == 4 and ".".join(flax_key_tuple) not in pt_model_dict:
            # conv layer
            flax_key_tuple = flax_key_tuple[:-1] + ("weight",)
            flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1))
        elif flax_key_tuple[-1] == "kernel" and ".".join(flax_key_tuple) not in pt_model_dict:
            # linear layer
            flax_key_tuple = flax_key_tuple[:-1] + ("weight",)
            flax_tensor = flax_tensor.T
        elif flax_key_tuple[-1] in ["scale", "embedding"]:
...
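
The new branches above exist because CLIP's vision tower has a Conv2d patch embedding, and the two frameworks lay out convolution kernels differently: PyTorch uses (out_channels, in_channels, kH, kW) while Flax's nn.Conv expects (kH, kW, in_channels, out_channels). A small self-contained check of the two permutations used above (the shape is illustrative, roughly a ViT-B/32 patch embedding):

```python
import numpy as np

pt_kernel = np.random.randn(768, 3, 32, 32)    # PyTorch Conv2d weight: (out, in, kH, kW)

flax_kernel = pt_kernel.transpose(2, 3, 1, 0)  # PyTorch -> Flax, as in convert_pytorch_state_dict_to_flax
assert flax_kernel.shape == (32, 32, 3, 768)   # Flax nn.Conv kernel: (kH, kW, in, out)

roundtrip = flax_kernel.transpose(3, 2, 0, 1)  # Flax -> PyTorch, as in load_flax_weights_in_pytorch_model
assert np.array_equal(roundtrip, pt_kernel)    # the two permutations are inverses of each other
```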
@@ -49,12 +49,17 @@ from .utils import logging
logger = logging.get_logger(__name__)


def quick_gelu(x):
    return x * jax.nn.sigmoid(1.702 * x)


ACT2FN = {
    "gelu": partial(nn.gelu, approximate=False),
    "relu": nn.relu,
    "silu": nn.swish,
    "swish": nn.swish,
    "gelu_new": partial(nn.gelu, approximate=True),
    "quick_gelu": quick_gelu,
}
...
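
quick_gelu is the sigmoid-based GELU approximation used by the original CLIP weights: quick_gelu(x) = x * sigmoid(1.702 * x). It tracks the exact GELU closely (maximum absolute difference on the order of 0.02). A quick comparison, for illustration only:

```python
import jax
import jax.numpy as jnp
from flax import linen as nn

def quick_gelu(x):
    return x * jax.nn.sigmoid(1.702 * x)

x = jnp.linspace(-4.0, 4.0, 9)
approx = quick_gelu(x)
exact = nn.gelu(x, approximate=False)
print(jnp.max(jnp.abs(approx - exact)))  # roughly 1e-2
```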
@@ -28,6 +28,7 @@ from ..bert.modeling_flax_bert import (
    FlaxBertForTokenClassification,
    FlaxBertModel,
)
from ..clip.modeling_flax_clip import FlaxCLIPModel
from ..electra.modeling_flax_electra import (
    FlaxElectraForMaskedLM,
    FlaxElectraForMultipleChoice,
@@ -47,7 +48,7 @@ from ..roberta.modeling_flax_roberta import (
    FlaxRobertaModel,
)
from .auto_factory import auto_class_factory
from .configuration_auto import BertConfig, CLIPConfig, ElectraConfig, GPT2Config, RobertaConfig

logger = logging.get_logger(__name__)
@@ -60,6 +61,7 @@ FLAX_MODEL_MAPPING = OrderedDict(
        (BertConfig, FlaxBertModel),
        (GPT2Config, FlaxGPT2Model),
        (ElectraConfig, FlaxElectraModel),
        (CLIPConfig, FlaxCLIPModel),
    ]
)
...
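
With CLIPConfig registered in FLAX_MODEL_MAPPING, the FlaxAutoModel class generated by auto_class_factory can dispatch to FlaxCLIPModel from a config alone. A hedged sketch (the checkpoint name is an assumption, not part of this diff):

```python
from transformers import FlaxAutoModel

# resolves CLIPConfig -> FlaxCLIPModel through FLAX_MODEL_MAPPING
model = FlaxAutoModel.from_pretrained("openai/clip-vit-base-patch32")
```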
@@ -17,7 +17,13 @@
# limitations under the License.
from typing import TYPE_CHECKING

from ...file_utils import (
    _BaseLazyModule,
    is_flax_available,
    is_tokenizers_available,
    is_torch_available,
    is_vision_available,
)

_import_structure = {
@@ -41,6 +47,14 @@ if is_torch_available():
        "CLIPVisionModel",
    ]

if is_flax_available():
    _import_structure["modeling_flax_clip"] = [
        "FlaxCLIPModel",
        "FlaxCLIPPreTrainedModel",
        "FlaxCLIPTextModel",
        "FlaxCLIPVisionModel",
    ]

if TYPE_CHECKING:
    from .configuration_clip import CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP, CLIPConfig, CLIPTextConfig, CLIPVisionConfig
@@ -62,6 +76,9 @@ if TYPE_CHECKING:
        CLIPVisionModel,
    )

    if is_flax_available():
        from .modeling_flax_clip import FlaxCLIPModel, FlaxCLIPPreTrainedModel, FlaxCLIPTextModel, FlaxCLIPVisionModel

else:
    import importlib
...
@@ -95,7 +95,7 @@ class CLIPTextConfig(PretrainedConfig):
        num_attention_heads=8,
        max_position_embeddings=77,
        hidden_act="quick_gelu",
        layer_norm_eps=0.00001,
        dropout=0.0,
        attention_dropout=0.0,
        initializer_range=0.02,
@@ -189,7 +189,7 @@ class CLIPVisionConfig(PretrainedConfig):
        image_size=224,
        patch_size=32,
        hidden_act="quick_gelu",
        layer_norm_eps=0.00001,
        dropout=0.0,
        attention_dropout=0.0,
        initializer_range=0.02,
...
@@ -222,6 +222,42 @@ class FlaxBertPreTrainedModel:
        requires_backends(self, ["flax"])


class FlaxCLIPModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(self, *args, **kwargs):
        requires_backends(self, ["flax"])


class FlaxCLIPPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(self, *args, **kwargs):
        requires_backends(self, ["flax"])


class FlaxCLIPTextModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(self, *args, **kwargs):
        requires_backends(self, ["flax"])


class FlaxCLIPVisionModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_pretrained(self, *args, **kwargs):
        requires_backends(self, ["flax"])


class FlaxElectraForMaskedLM:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])
...
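
These dummy classes keep the top-level import surface stable when Flax/JAX is not installed: `from transformers import FlaxCLIPModel` still succeeds, and requires_backends only complains once the class is constructed or from_pretrained is called. A sketch of the intended behaviour, assuming an environment without flax:

```python
# assumes flax/jax are NOT installed, so the dummy class is what gets imported
from transformers import FlaxCLIPModel

try:
    FlaxCLIPModel()  # requires_backends raises here, naming the missing flax backend
except ImportError as err:
    print(err)
```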
@@ -60,6 +60,22 @@ def ids_tensor(shape, vocab_size, rng=None):
    return output


def floats_tensor(shape, scale=1.0, rng=None, name=None):
    """Creates a random float32 tensor"""
    if rng is None:
        rng = random.Random()

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(rng.random() * scale)

    return np.array(values, dtype=jnp.float32).reshape(shape)


def random_attention_mask(shape, rng=None):
    attn_mask = ids_tensor(shape, vocab_size=2, rng=rng)
    # make sure that at least one token is attended to for each batch
...
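
floats_tensor complements the existing ids_tensor helper: the Flax CLIP vision tests need dense float inputs (pixel values) rather than token ids. A hedged usage sketch inside a test, assuming floats_tensor from the utilities above is in scope; the batch size, channel count, and image size are illustrative:

```python
import numpy as np

# hypothetical test usage
pixel_values = floats_tensor([2, 3, 224, 224])  # (batch, channels, height, width)
assert pixel_values.shape == (2, 3, 224, 224)
assert pixel_values.dtype == np.float32         # values drawn uniformly from [0, scale)
```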
@@ -93,6 +93,8 @@ IGNORE_NON_AUTO_CONFIGURED = [
    # models to ignore for model xxx mapping
    "CLIPTextModel",
    "CLIPVisionModel",
    "FlaxCLIPTextModel",
    "FlaxCLIPVisionModel",
    "DPRReader",
    "DPRSpanPredictor",
    "FlaubertForQuestionAnswering",
...