Unverified Commit ad25fd62 authored by Suraj Patil, committed by GitHub

Add FlaxCLIP (#11883)

* add flax CLIP

* default input_shape

* add tests

* fix test

* fix name

* fix docs

* fix shapes

* attend at least 1 token

* flax conv to torch conv

* return floats

* fix equivalence tests

* fix import

* return attention_weights and update tests

* fix dosctrings

* address patricks comments

* input_shape arg

* add tests for get_image_features and get_text_features methods

* fix tests
parent cfca638a
...@@ -304,7 +304,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| BlenderbotSmall | ✅ | ❌ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| CLIP | ✅ | ✅ | ✅ | ❌ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| CTRL | ✅ | ❌ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
...
...@@ -152,3 +152,24 @@ CLIPVisionModel
.. autoclass:: transformers.CLIPVisionModel
    :members: forward
FlaxCLIPModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxCLIPModel
:members: __call__, get_text_features, get_image_features
FlaxCLIPTextModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxCLIPTextModel
:members: __call__
FlaxCLIPVisionModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxCLIPVisionModel
:members: __call__
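As a quick orientation for the two projection helpers documented on FlaxCLIPModel above, here is a minimal usage sketch (not part of the diff; the checkpoint name mirrors the examples in the modeling file, and it assumes the published PyTorch weights are either already available in Flax or loaded with from_pt=True):

    import requests
    from PIL import Image
    from transformers import CLIPProcessor, FlaxCLIPModel

    model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    # text side: get_text_features projects the pooled EOS representation
    text_inputs = processor(text=["a photo of a cat", "a photo of a dog"], return_tensors="np", padding=True)
    text_features = model.get_text_features(**text_inputs)  # shape (2, projection_dim)

    # image side: get_image_features projects the pooled CLS representation
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    image_inputs = processor(images=image, return_tensors="np")
    image_features = model.get_image_features(image_inputs["pixel_values"])  # shape (1, projection_dim)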
...@@ -1482,6 +1482,14 @@ if is_flax_available():
"FlaxBertPreTrainedModel",
]
)
_import_structure["models.clip"].extend(
[
"FlaxCLIPModel",
"FlaxCLIPPreTrainedModel",
"FlaxCLIPTextModel",
"FlaxCLIPVisionModel",
]
)
_import_structure["models.electra"].extend( _import_structure["models.electra"].extend(
[ [
"FlaxElectraForMaskedLM", "FlaxElectraForMaskedLM",
...@@ -2743,6 +2751,7 @@ if TYPE_CHECKING: ...@@ -2743,6 +2751,7 @@ if TYPE_CHECKING:
FlaxBertModel, FlaxBertModel,
FlaxBertPreTrainedModel, FlaxBertPreTrainedModel,
) )
from .models.clip import FlaxCLIPModel, FlaxCLIPPreTrainedModel, FlaxCLIPTextModel, FlaxCLIPVisionModel
from .models.electra import (
FlaxElectraForMaskedLM,
FlaxElectraForMultipleChoice,
...
...@@ -90,7 +90,12 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
elif pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and pt_tuple_key not in random_flax_state_dict:
# conv layer
pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
elif pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
# linear layer
pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
pt_tensor = pt_tensor.T
elif pt_tuple_key[-1] == "gamma":
...@@ -170,7 +175,12 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
flax_key_tuple = (pt_model.base_model_prefix,) + flax_key_tuple
# rename flax weights to PyTorch format
if flax_key_tuple[-1] == "kernel" and flax_tensor.ndim == 4 and ".".join(flax_key_tuple) not in pt_model_dict:
# conv layer
flax_key_tuple = flax_key_tuple[:-1] + ("weight",)
flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1))
elif flax_key_tuple[-1] == "kernel" and ".".join(flax_key_tuple) not in pt_model_dict:
# linear layer
flax_key_tuple = flax_key_tuple[:-1] + ("weight",)
flax_tensor = flax_tensor.T
elif flax_key_tuple[-1] in ["scale", "embedding"]:
...
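The two new branches above handle convolution kernels: PyTorch's nn.Conv2d stores weights as (out_channels, in_channels, kH, kW), while Flax's nn.Conv expects (kH, kW, in_channels, out_channels), so 4-D "weight"/"kernel" tensors get the permutations (2, 3, 1, 0) and (3, 2, 0, 1) instead of the plain transpose used for linear layers. A small standalone check of the round trip (shapes chosen to resemble CLIP's ViT-B/32 patch embedding, purely for illustration):

    import numpy as np

    # PyTorch Conv2d weight layout: (out_channels, in_channels, kH, kW)
    pt_kernel = np.random.randn(768, 3, 32, 32)

    # PyTorch -> Flax: move to (kH, kW, in_channels, out_channels), as in convert_pytorch_state_dict_to_flax
    flax_kernel = pt_kernel.transpose(2, 3, 1, 0)
    assert flax_kernel.shape == (32, 32, 3, 768)

    # Flax -> PyTorch: the inverse permutation used in load_flax_weights_in_pytorch_model
    pt_roundtrip = flax_kernel.transpose(3, 2, 0, 1)
    assert pt_roundtrip.shape == (768, 3, 32, 32)
    assert np.array_equal(pt_roundtrip, pt_kernel)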
...@@ -49,12 +49,17 @@ from .utils import logging
logger = logging.get_logger(__name__)
def quick_gelu(x):
return x * jax.nn.sigmoid(1.702 * x)
ACT2FN = {
"gelu": partial(nn.gelu, approximate=False),
"relu": nn.relu,
"silu": nn.swish,
"swish": nn.swish,
"gelu_new": partial(nn.gelu, approximate=True),
"quick_gelu": quick_gelu,
}
...
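quick_gelu, newly registered in ACT2FN above, is the sigmoid-based GELU approximation x * sigmoid(1.702 * x) that the original CLIP weights were trained with. A tiny sketch comparing it against the exact erf-based GELU already in the table (the size of the gap is indicative only):

    import jax
    import jax.numpy as jnp

    def quick_gelu(x):
        # same definition as added to modeling_flax_utils.py above
        return x * jax.nn.sigmoid(1.702 * x)

    x = jnp.linspace(-4.0, 4.0, 101)
    exact = jax.nn.gelu(x, approximate=False)              # erf-based GELU ("gelu" in ACT2FN)
    print(float(jnp.max(jnp.abs(exact - quick_gelu(x)))))  # small but nonzero, roughly of order 1e-2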
...@@ -28,6 +28,7 @@ from ..bert.modeling_flax_bert import (
FlaxBertForTokenClassification,
FlaxBertModel,
)
from ..clip.modeling_flax_clip import FlaxCLIPModel
from ..electra.modeling_flax_electra import (
FlaxElectraForMaskedLM,
FlaxElectraForMultipleChoice,
...@@ -47,7 +48,7 @@ from ..roberta.modeling_flax_roberta import (
FlaxRobertaModel,
)
from .auto_factory import auto_class_factory
from .configuration_auto import BertConfig, CLIPConfig, ElectraConfig, GPT2Config, RobertaConfig
logger = logging.get_logger(__name__)
...@@ -60,6 +61,7 @@ FLAX_MODEL_MAPPING = OrderedDict(
(BertConfig, FlaxBertModel),
(GPT2Config, FlaxGPT2Model),
(ElectraConfig, FlaxElectraModel),
(CLIPConfig, FlaxCLIPModel),
]
)
...
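With (CLIPConfig, FlaxCLIPModel) registered in FLAX_MODEL_MAPPING above, the generic Flax auto class can resolve CLIP configurations to the new model. A minimal sketch (assuming FlaxAutoModel is exposed at the top level, as it is for the other Flax models):

    from transformers import CLIPConfig, FlaxAutoModel

    config = CLIPConfig()                      # default text + vision sub-configs
    model = FlaxAutoModel.from_config(config)  # resolved through FLAX_MODEL_MAPPING
    print(type(model).__name__)                # FlaxCLIPModel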
...@@ -17,7 +17,13 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import (
_BaseLazyModule,
is_flax_available,
is_tokenizers_available,
is_torch_available,
is_vision_available,
)
_import_structure = {
...@@ -41,6 +47,14 @@ if is_torch_available():
"CLIPVisionModel",
]
if is_flax_available():
_import_structure["modeling_flax_clip"] = [
"FlaxCLIPModel",
"FlaxCLIPPreTrainedModel",
"FlaxCLIPTextModel",
"FlaxCLIPVisionModel",
]
if TYPE_CHECKING:
from .configuration_clip import CLIP_PRETRAINED_CONFIG_ARCHIVE_MAP, CLIPConfig, CLIPTextConfig, CLIPVisionConfig
...@@ -62,6 +76,9 @@ if TYPE_CHECKING:
CLIPVisionModel,
)
if is_flax_available():
from .modeling_flax_clip import FlaxCLIPModel, FlaxCLIPPreTrainedModel, FlaxCLIPTextModel, FlaxCLIPVisionModel
else:
import importlib
...
...@@ -95,7 +95,7 @@ class CLIPTextConfig(PretrainedConfig):
num_attention_heads=8,
max_position_embeddings=77,
hidden_act="quick_gelu",
layer_norm_eps=0.00001,
dropout=0.0,
attention_dropout=0.0,
initializer_range=0.02,
...@@ -189,7 +189,7 @@ class CLIPVisionConfig(PretrainedConfig):
image_size=224,
patch_size=32,
hidden_act="quick_gelu",
layer_norm_eps=0.00001,
dropout=0.0,
attention_dropout=0.0,
initializer_range=0.02,
...
# coding=utf-8
# Copyright 2021 The OpenAI Team Authors, The Google Flax Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Tuple, Union
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import jaxlib.xla_extension as jax_xla
from flax.core.frozen_dict import FrozenDict
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from jax import lax
from ...file_utils import ModelOutput, add_start_docstrings
from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling
from ...modeling_flax_utils import (
ACT2FN,
FlaxPreTrainedModel,
append_replace_return_docstrings,
overwrite_call_docstring,
)
from ...utils import logging
from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
logger = logging.get_logger(__name__)
CLIP_START_DOCSTRING = r"""
This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
generic methods the library implements for all its models (such as downloading, saving and converting weights from
PyTorch models).
This model is also a Flax Linen `flax.linen.Module
<https://flax.readthedocs.io/en/latest/flax.linen.html#module>`__ subclass. Use it as a regular Flax linen Module
and refer to the Flax documentation for all matters related to general usage and behavior.
Finally, this model supports inherent JAX features such as:
- `Just-In-Time (JIT) compilation <https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit>`__
- `Automatic Differentiation <https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation>`__
- `Vectorization <https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap>`__
- `Parallelization <https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`__
Parameters:
config (:class:`~transformers.CLIPConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
model weights.
"""
CLIP_TEXT_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using :class:`~transformers.CLIPTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
position_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
config.max_position_embeddings - 1]``.
`What are position IDs? <../glossary.html#position-ids>`_
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""
CLIP_VISION_INPUTS_DOCSTRING = r"""
Args:
pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
:class:`~transformers.CLIPFeatureExtractor`. See :meth:`transformers.CLIPFeatureExtractor.__call__` for
details.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""
CLIP_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using :class:`~transformers.CLIPTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
position_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
config.max_position_embeddings - 1]``.
`What are position IDs? <../glossary.html#position-ids>`_
pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
:class:`~transformers.CLIPFeatureExtractor`. See :meth:`transformers.CLIPFeatureExtractor.__call__` for
details.
return_loss (:obj:`bool`, `optional`):
Whether or not to return the contrastive loss.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""
@flax.struct.dataclass
class FlaxCLIPOutput(ModelOutput):
"""
Args:
logits_per_image (:obj:`jax_xla.DeviceArray` of shape :obj:`(image_batch_size, text_batch_size)`):
The scaled dot product scores between :obj:`image_embeds` and :obj:`text_embeds`. This represents the
image-text similarity scores.
logits_per_text (:obj:`jax_xla.DeviceArray` of shape :obj:`(text_batch_size, image_batch_size)`):
The scaled dot product scores between :obj:`text_embeds` and :obj:`image_embeds`. This represents the
text-image similarity scores.
text_embeds (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim)`):
The text embeddings obtained by applying the projection layer to the pooled output of
:class:`~transformers.FlaxCLIPTextModel`.
image_embeds (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim)`):
The image embeddings obtained by applying the projection layer to the pooled output of
:class:`~transformers.FlaxCLIPVisionModel`.
text_model_output (:obj:`FlaxBaseModelOutputWithPooling`):
The output of the :class:`~transformers.FlaxCLIPTextModel`.
vision_model_output (:obj:`FlaxBaseModelOutputWithPooling`):
The output of the :class:`~transformers.FlaxCLIPVisionModel`.
"""
logits_per_image: jax_xla.DeviceArray = None
logits_per_text: jax_xla.DeviceArray = None
text_embeds: jax_xla.DeviceArray = None
image_embeds: jax_xla.DeviceArray = None
text_model_output: FlaxBaseModelOutputWithPooling = None
vision_model_output: FlaxBaseModelOutputWithPooling = None
def to_tuple(self) -> Tuple[Any]:
return tuple(
self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
for k in self.keys()
)
class FlaxCLIPVisionEmbeddings(nn.Module):
config: CLIPVisionConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
embed_dim = self.config.hidden_size
image_size = self.config.image_size
patch_size = self.config.patch_size
self.class_embedding = self.param("class_embedding", jax.nn.initializers.normal(stddev=0.02), (embed_dim,))
self.patch_embedding = nn.Conv(
embed_dim,
kernel_size=(patch_size, patch_size),
strides=(patch_size, patch_size),
padding="VALID",
use_bias=False,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(),
)
self.num_patches = (image_size // patch_size) ** 2
num_positions = self.num_patches + 1
self.position_embedding = nn.Embed(num_positions, embed_dim, embedding_init=jax.nn.initializers.normal())
self.position_ids = jnp.expand_dims(jnp.arange(0, num_positions, dtype="i4"), axis=0)
def __call__(self, pixel_values):
patch_embeds = self.patch_embedding(pixel_values)
batch_size, height, width, channels = patch_embeds.shape
patch_embeds = jnp.reshape(patch_embeds, (batch_size, height * width, channels))
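# shape walk-through (illustrative, default config): a 224x224 image with patch_size 32 gives
# patch_embeds of shape (batch, 7, 7, embed_dim) from the NHWC conv above, reshaped here to
# (batch, 49, embed_dim); prepending the class embedding below makes the sequence length 49 + 1 = 50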
class_embeds = jnp.expand_dims(self.class_embedding, axis=(0, 1))
class_embeds = jnp.tile(class_embeds, (batch_size, 1, 1))
embeddings = jnp.concatenate([class_embeds, patch_embeds], axis=1)
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
class FlaxCLIPTextEmbeddings(nn.Module):
config: CLIPTextConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
embed_dim = self.config.hidden_size
self.token_embedding = nn.Embed(self.config.vocab_size, embed_dim, embedding_init=jax.nn.initializers.normal())
self.position_embedding = nn.Embed(
self.config.max_position_embeddings, embed_dim, embedding_init=jax.nn.initializers.normal()
)
self.position_ids = jnp.expand_dims(
jnp.arange(0, self.config.max_position_embeddings, dtype="i4"), axis=(0, 1)
)
def __call__(self, input_ids, position_ids):
input_embeds = self.token_embedding(input_ids.astype("i4"))
position_embeds = self.position_embedding(position_ids.astype("i4"))
embeddings = input_embeds + position_embeds
return embeddings
class FlaxCLIPAttention(nn.Module):
config: Union[CLIPTextConfig, CLIPVisionConfig]
dtype: jnp.dtype = jnp.float32
def setup(self):
self.embed_dim = self.config.hidden_size
self.num_heads = self.config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
assert (
self.head_dim * self.num_heads == self.embed_dim
), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
self.scale = self.head_dim ** -0.5
self.dropout = self.config.attention_dropout
self.k_proj = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype)
)
self.v_proj = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype)
)
self.q_proj = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype)
)
self.out_proj = nn.Dense(
self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype)
)
self.causal = isinstance(self.config, CLIPTextConfig)
if self.causal:
self.causal_mask = make_causal_mask(jnp.ones((1, self.config.max_position_embeddings), dtype="i4"))
def _split_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
def _merge_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
def __call__(
self,
hidden_states,
attention_mask=None,
deterministic: bool = True,
output_attentions: bool = False,
):
query = self.q_proj(hidden_states)
key = self.k_proj(hidden_states)
value = self.v_proj(hidden_states)
query = self._split_heads(query)
key = self._split_heads(key)
value = self._split_heads(value)
causal_attention_mask = None
if self.causal:
query_length, key_length = query.shape[1], key.shape[1]
causal_attention_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length]
if attention_mask is not None and causal_attention_mask is not None:
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
attention_mask = combine_masks(attention_mask, causal_attention_mask, dtype="i4")
elif causal_attention_mask is not None:
attention_mask = causal_attention_mask
elif attention_mask is not None:
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
if attention_mask is not None:
attention_bias = lax.select(
attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
)
else:
attention_bias = None
dropout_rng = None
if not deterministic and self.dropout > 0.0:
dropout_rng = self.make_rng("dropout")
attn_weights = dot_product_attention_weights(
query,
key,
bias=attention_bias,
dropout_rng=dropout_rng,
dropout_rate=self.dropout,
deterministic=deterministic,
dtype=self.dtype,
precision=None,
)
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
attn_output = self._merge_heads(attn_output)
attn_output = self.out_proj(attn_output)
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
return outputs
class FlaxCLIPMLP(nn.Module):
config: Union[CLIPTextConfig, CLIPVisionConfig]
dtype: jnp.dtype = jnp.float32
def setup(self):
self.activation_fn = ACT2FN[self.config.hidden_act]
self.fc1 = nn.Dense(
self.config.intermediate_size,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype),
)
self.fc2 = nn.Dense(
self.config.hidden_size, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(0.01, dtype=self.dtype)
)
def __call__(self, hidden_states):
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class FlaxCLIPEncoderLayer(nn.Module):
config: Union[CLIPTextConfig, CLIPVisionConfig]
dtype: jnp.dtype = jnp.float32
def setup(self):
self.self_attn = FlaxCLIPAttention(self.config, dtype=self.dtype)
self.layer_norm1 = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.mlp = FlaxCLIPMLP(self.config, dtype=self.dtype)
self.layer_norm2 = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
def __call__(
self,
hidden_states,
attention_mask,
deterministic: bool = True,
output_attentions: bool = False,
):
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
attn_outputs = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
)
hidden_states = attn_outputs[0]
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += attn_outputs[1:]
return outputs
class FlaxCLIPLayerCollection(nn.Module):
config: Union[CLIPTextConfig, CLIPVisionConfig]
dtype: jnp.dtype = jnp.float32
def setup(self):
self.layers = [
FlaxCLIPEncoderLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.num_hidden_layers)
]
def __call__(
self,
hidden_states,
attention_mask=None,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = layer(
hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions += (layer_outputs[1],)
if output_hidden_states:
all_hidden_states += (hidden_states,)
outputs = (hidden_states,)
if not return_dict:
return tuple(v for v in outputs if v is not None)
return FlaxBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
class FlaxCLIPEncoder(nn.Module):
config: Union[CLIPTextConfig, CLIPVisionConfig]
dtype: jnp.dtype = jnp.float32
def setup(self):
self.layers = FlaxCLIPLayerCollection(self.config, dtype=self.dtype)
def __call__(
self,
inputs_embeds,
attention_mask=None,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
return self.layers(
hidden_states=inputs_embeds,
attention_mask=attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
class FlaxCLIPTextTransformer(nn.Module):
config: CLIPTextConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.embeddings = FlaxCLIPTextEmbeddings(self.config, dtype=self.dtype)
self.encoder = FlaxCLIPEncoder(self.config, dtype=self.dtype)
self.final_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
position_ids,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
attention_mask=attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.final_layer_norm(last_hidden_state)
# text_embeds.shape = [batch_size, n_ctx, transformer.width]
# take features from the EOS embedding (eos_token_id is the highest number in each sequence)
pooled_output = last_hidden_state[jnp.arange(last_hidden_state.shape[0]), input_ids.argmax(axis=-1)]
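# note: argmax returns the first occurrence of the maximum, and eos_token_id is the largest id
# in CLIP's vocabulary, so this selects the hidden state at the first EOS position of each sequence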
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return FlaxBaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
class FlaxCLIPVisionTransformer(nn.Module):
config: CLIPVisionConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.embeddings = FlaxCLIPVisionEmbeddings(self.config, dtype=self.dtype)
self.pre_layrnorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
self.encoder = FlaxCLIPEncoder(self.config, dtype=self.dtype)
self.post_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
def __call__(
self,
pixel_values=None,
deterministic: bool = True,
output_attentions=None,
output_hidden_states=None,
return_dict: bool = True,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
hidden_states = self.embeddings(pixel_values)
hidden_states = self.pre_layrnorm(hidden_states)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0]
pooled_output = last_hidden_state[:, 0, :]
pooled_output = self.post_layernorm(pooled_output)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return FlaxBaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
class FlaxCLIPTextPreTrainedModel(FlaxPreTrainedModel):
config_class = CLIPTextConfig
module_class: nn.Module = None
def __init__(
self, config: CLIPTextConfig, input_shape=(1, 1), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
# init input tensor
input_ids = jnp.zeros(input_shape, dtype="i4")
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
attention_mask = jnp.ones_like(input_ids)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(rngs, input_ids, attention_mask, position_ids)["params"]
def __call__(
self,
input_ids,
attention_mask=None,
position_ids=None,
params: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
if position_ids is None:
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
return self.module.apply(
{"params": params or self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
not train,
output_attentions,
output_hidden_states,
return_dict,
rngs=rngs,
)
class FlaxCLIPVisionPreTrainedModel(FlaxPreTrainedModel):
config_class = CLIPVisionConfig
module_class: nn.Module = None
def __init__(
self,
config: CLIPVisionConfig,
input_shape: Optional[Tuple] = None,
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
**kwargs
):
if input_shape is None:
input_shape = (1, config.image_size, config.image_size, 3)
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
# init input tensor
pixel_values = jax.random.normal(rng, input_shape)
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(rngs, pixel_values)["params"]
def __call__(
self,
pixel_values,
params: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
return self.module.apply(
{"params": params or self.params},
jnp.array(pixel_values, dtype=jnp.float32),
not train,
output_attentions,
output_hidden_states,
return_dict,
rngs=rngs,
)
class FlaxCLIPPreTrainedModel(FlaxPreTrainedModel):
config_class = CLIPConfig
module_class: nn.Module = None
def __init__(
self,
config: CLIPConfig,
input_shape: Optional[Tuple] = None,
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
**kwargs
):
if input_shape is None:
input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
# init input tensor
input_ids = jnp.zeros(input_shape[0], dtype="i4")
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])
attention_mask = jnp.ones_like(input_ids)
pixel_values = jax.random.normal(rng, input_shape[1])
params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}
return self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids)["params"]
def __call__(
self,
input_ids,
pixel_values,
attention_mask=None,
position_ids=None,
params: dict = None,
dropout_rng: jax.random.PRNGKey = None,
train: bool = False,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
if position_ids is None:
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
return self.module.apply(
{"params": params or self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(pixel_values, dtype=jnp.float32),
jnp.array(attention_mask, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
not train,
output_attentions,
output_hidden_states,
return_dict,
rngs=rngs,
)
def get_text_features(
self, input_ids, attention_mask=None, position_ids=None, dropout_rng: jax.random.PRNGKey = None, train=False
):
r"""
Args:
input_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using :class:`~transformers.CLIPTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
for details.
`What are input IDs? <../glossary.html#input-ids>`__
Returns:
text_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim)`): The text embeddings
obtained by applying the projection layer to the pooled output of :class:`~transformers.FlaxCLIPTextModel`.
"""
if position_ids is None:
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
if attention_mask is None:
attention_mask = jnp.ones_like(input_ids)
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
def _get_features(module, input_ids, attention_mask, position_ids, deterministic):
text_outputs = module.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
deterministic=deterministic,
)
pooled_output = text_outputs[1]
text_features = module.text_projection(pooled_output)
return text_features
return self.module.apply(
{"params": self.params},
jnp.array(input_ids, dtype="i4"),
jnp.array(attention_mask, dtype="i4"),
jnp.array(position_ids, dtype="i4"),
not train,
method=_get_features,
rngs=rngs,
)
def get_image_features(self, pixel_values, dropout_rng: jax.random.PRNGKey = None, train=False):
r"""
Args:
pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`):
Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained
using :class:`~transformers.CLIPFeatureExtractor`. See
:meth:`transformers.CLIPFeatureExtractor.__call__` for details.
Returns:
image_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim)`): The image embeddings
obtained by applying the projection layer to the pooled output of
:class:`~transformers.FlaxCLIPVisionModel`.
"""
pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
# Handle any PRNG if needed
rngs = {}
if dropout_rng is not None:
rngs["dropout"] = dropout_rng
def _get_features(module, pixel_values, deterministic):
vision_outputs = module.vision_model(pixel_values=pixel_values, deterministic=deterministic)
pooled_output = vision_outputs[1] # pooled_output
image_features = module.visual_projection(pooled_output)
return image_features
return self.module.apply(
{"params": self.params},
jnp.array(pixel_values, dtype=jnp.float32),
not train,
method=_get_features,
rngs=rngs,
)
class FlaxCLIPTextModule(nn.Module):
config: CLIPTextConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.text_model = FlaxCLIPTextTransformer(self.config, dtype=self.dtype)
def __call__(
self,
input_ids,
attention_mask,
position_ids,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
return self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
class FlaxCLIPTextModel(FlaxCLIPTextPreTrainedModel):
module_class = FlaxCLIPTextModule
FLAX_CLIP_TEXT_MODEL_DOCSTRING = """
Returns:
Example::
>>> from transformers import CLIPTokenizer, FlaxCLIPTextModel
>>> model = FlaxCLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
>>> tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
>>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="np")
>>> outputs = model(**inputs)
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooled_output # pooled (EOS token) states
"""
overwrite_call_docstring(FlaxCLIPTextModel, CLIP_TEXT_INPUTS_DOCSTRING + FLAX_CLIP_TEXT_MODEL_DOCSTRING)
append_replace_return_docstrings(
FlaxCLIPTextModel, output_type=FlaxBaseModelOutputWithPooling, config_class=CLIPTextConfig
)
class FlaxCLIPVisionModule(nn.Module):
config: CLIPVisionConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.vision_model = FlaxCLIPVisionTransformer(self.config, dtype=self.dtype)
def __call__(
self,
pixel_values,
deterministic: bool = True,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
return self.vision_model(
pixel_values=pixel_values,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
class FlaxCLIPVisionModel(FlaxCLIPVisionPreTrainedModel):
module_class = FlaxCLIPVisionModule
FLAX_CLIP_VISION_MODEL_DOCSTRING = """
Returns:
Example::
>>> from PIL import Image
>>> import requests
>>> from transformers import CLIPProcessor, FlaxCLIPVisionModel
>>> model = FlaxCLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
>>> processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, return_tensors="np")
>>> outputs = model(**inputs)
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooled_output # pooled CLS states
"""
overwrite_call_docstring(FlaxCLIPVisionModel, CLIP_VISION_INPUTS_DOCSTRING + FLAX_CLIP_VISION_MODEL_DOCSTRING)
append_replace_return_docstrings(
FlaxCLIPVisionModel, output_type=FlaxBaseModelOutputWithPooling, config_class=CLIPVisionConfig
)
class FlaxCLIPModule(nn.Module):
config: CLIPConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
text_config = self.config.text_config
vision_config = self.config.vision_config
self.projection_dim = self.config.projection_dim
self.text_embed_dim = text_config.hidden_size
self.vision_embed_dim = vision_config.hidden_size
self.text_model = FlaxCLIPTextTransformer(text_config, dtype=self.dtype)
self.vision_model = FlaxCLIPVisionTransformer(vision_config, dtype=self.dtype)
self.visual_projection = nn.Dense(
self.projection_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
use_bias=False,
)
self.text_projection = nn.Dense(
self.projection_dim,
dtype=self.dtype,
kernel_init=jax.nn.initializers.normal(0.02, dtype=self.dtype),
use_bias=False,
)
self.logit_scale = self.param("logit_scale", jax.nn.initializers.ones, [])
def __call__(
self,
input_ids=None,
pixel_values=None,
attention_mask=None,
position_ids=None,
deterministic: bool = True,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
return_dict = return_dict if return_dict is not None else self.config.return_dict
vision_outputs = self.vision_model(
pixel_values=pixel_values,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
text_outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
image_embeds = vision_outputs[1]
image_embeds = self.visual_projection(image_embeds)
text_embeds = text_outputs[1]
text_embeds = self.text_projection(text_embeds)
# normalized features
image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
# cosine similarity as logits
logit_scale = jnp.exp(self.logit_scale)
logits_per_text = jnp.matmul(text_embeds, image_embeds.T) * logit_scale
logits_per_image = logits_per_text.T
if not return_dict:
return (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
return FlaxCLIPOutput(
logits_per_image=logits_per_image,
logits_per_text=logits_per_text,
text_embeds=text_embeds,
image_embeds=image_embeds,
text_model_output=text_outputs,
vision_model_output=vision_outputs,
)
@add_start_docstrings(CLIP_START_DOCSTRING)
class FlaxCLIPModel(FlaxCLIPPreTrainedModel):
module_class = FlaxCLIPModule
FLAX_CLIP_MODEL_DOCSTRING = """
Returns:
Example::
>>> import jax
>>> from PIL import Image
>>> import requests
>>> from transformers import CLIPProcessor, FlaxCLIPModel
>>> model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
>>> processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="np", padding=True)
>>> outputs = model(**inputs)
>>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
>>> probs = jax.nn.softmax(logits_per_image, axis=1) # we can take the softmax to get the label probabilities
"""
overwrite_call_docstring(FlaxCLIPModel, CLIP_INPUTS_DOCSTRING + FLAX_CLIP_MODEL_DOCSTRING)
append_replace_return_docstrings(FlaxCLIPModel, output_type=FlaxCLIPOutput, config_class=CLIPConfig)
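The contrastive head at the end of FlaxCLIPModule.__call__ boils down to L2-normalizing the projected embeddings and taking a scaled dot product. The following standalone sketch reproduces just that step; the batch sizes, projection_dim, and logit_scale value are made up for illustration and not taken from the diff:

    import numpy as np
    import jax
    import jax.numpy as jnp

    # illustrative shapes: 2 texts, 3 images, projection_dim = 4
    text_embeds = jnp.asarray(np.random.randn(2, 4))
    image_embeds = jnp.asarray(np.random.randn(3, 4))
    logit_scale = jnp.exp(jnp.asarray(0.0))  # exp of the learned scalar parameter (1.0 here, for illustration)

    # L2-normalize, then scaled cosine similarity, mirroring the end of FlaxCLIPModule.__call__
    text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
    image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
    logits_per_text = jnp.matmul(text_embeds, image_embeds.T) * logit_scale  # shape (2, 3)
    logits_per_image = logits_per_text.T                                     # shape (3, 2)
    probs = jax.nn.softmax(logits_per_image, axis=1)  # each image's distribution over the 2 texts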
...@@ -222,6 +222,42 @@ class FlaxBertPreTrainedModel:
requires_backends(self, ["flax"])
class FlaxCLIPModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxCLIPPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxCLIPTextModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxCLIPVisionModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxElectraForMaskedLM:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
...
import inspect
import tempfile
import unittest
import numpy as np
import transformers
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig, is_flax_available, is_torch_available
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
from .test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
if is_flax_available():
import jax
import jax.numpy as jnp
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
from transformers.models.clip.modeling_flax_clip import FlaxCLIPModel, FlaxCLIPTextModel, FlaxCLIPVisionModel
if is_torch_available():
import torch
class FlaxCLIPVisionModelTester:
def __init__(
self,
parent,
batch_size=12,
image_size=30,
patch_size=2,
num_channels=3,
is_training=True,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
dropout=0.1,
attention_dropout=0.1,
initializer_range=0.02,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.is_training = is_training
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.dropout = dropout
self.attention_dropout = attention_dropout
self.initializer_range = initializer_range
self.scope = scope
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
config = CLIPVisionConfig(
image_size=self.image_size,
patch_size=self.patch_size,
num_channels=self.num_channels,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
dropout=self.dropout,
attention_dropout=self.attention_dropout,
initializer_range=self.initializer_range,
)
return config, pixel_values
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_flax
class FlaxCLIPVisionModelTest(FlaxModelTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_flax_common.py, as CLIP does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
all_model_classes = (FlaxCLIPVisionModel,) if is_flax_available() else ()
def setUp(self):
self.model_tester = FlaxCLIPVisionModelTester(self)
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.__call__)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
@jax.jit
def model_jitted(pixel_values, **kwargs):
return model(pixel_values=pixel_values, **kwargs).to_tuple()
with self.subTest("JIT Enabled"):
jitted_outputs = model_jitted(**prepared_inputs_dict)
with self.subTest("JIT Disabled"):
with jax.disable_jit():
outputs = model_jitted(**prepared_inputs_dict)
self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.hidden_states
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
# CLIP has a different seq_length
image_size = (self.model_tester.image_size, self.model_tester.image_size)
patch_size = (self.model_tester.patch_size, self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
seq_length = num_patches + 1
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[seq_length, self.model_tester.hidden_size],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
# in CLIP, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
image_size = (self.model_tester.image_size, self.model_tester.image_size)
patch_size = (self.model_tester.patch_size, self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
seq_length = num_patches + 1
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, seq_length, seq_length],
)
out_len = len(outputs)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
added_hidden_states = 1
self.assertEqual(out_len + added_hidden_states, len(outputs))
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, seq_length, seq_length],
)
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
model = model_class_name.from_pretrained("openai/clip-vit-base-patch32", from_pt=True)
outputs = model(np.ones((1, 3, 224, 224)))
self.assertIsNotNone(outputs)
class FlaxCLIPTextModelTester:
def __init__(
self,
parent,
batch_size=12,
seq_length=7,
is_training=True,
use_input_mask=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
dropout=0.1,
attention_dropout=0.1,
max_position_embeddings=512,
initializer_range=0.02,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.dropout = dropout
self.attention_dropout = attention_dropout
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.scope = scope
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
if input_mask is not None:
batch_size, seq_length = input_mask.shape
rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,))
for batch_idx, start_index in enumerate(rnd_start_indices):
input_mask[batch_idx, :start_index] = 1
input_mask[batch_idx, start_index:] = 0
config = CLIPTextConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
dropout=self.dropout,
attention_dropout=self.attention_dropout,
max_position_embeddings=self.max_position_embeddings,
initializer_range=self.initializer_range,
)
return config, input_ids, input_mask
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, input_mask = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
@require_flax
class FlaxCLIPTextModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (FlaxCLIPTextModel,) if is_flax_available() else ()
def setUp(self):
self.model_tester = FlaxCLIPTextModelTester(self)
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
model = model_class_name.from_pretrained("openai/clip-vit-base-patch32", from_pt=True)
outputs = model(np.ones((1, 1)))
self.assertIsNotNone(outputs)
class FlaxCLIPModelTester:
def __init__(self, parent, is_training=True):
self.parent = parent
self.text_model_tester = FlaxCLIPTextModelTester(parent)
self.vision_model_tester = FlaxCLIPVisionModelTester(parent)
self.is_training = is_training
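    # Combines the text and vision sub-testers into a single CLIPConfig with a shared projection_dim.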
def prepare_config_and_inputs(self):
text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
config = CLIPConfig.from_text_vision_configs(text_config, vision_config, projection_dim=64)
return config, input_ids, attention_mask, pixel_values
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, attention_mask, pixel_values = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"pixel_values": pixel_values,
}
return config, inputs_dict
@require_flax
class FlaxCLIPModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (FlaxCLIPModel,) if is_flax_available() else ()
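    # attention weights are returned per sub-model inside the nested output, so the common attention test does not apply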
test_attention_outputs = False
def setUp(self):
self.model_tester = FlaxCLIPModelTester(self)
# hidden_states are tested in individual model tests
def test_hidden_states_output(self):
pass
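    # Compile the full forward pass with jax.jit and check that jitted and eager outputs agree in shape.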
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
@jax.jit
def model_jitted(input_ids, pixel_values, **kwargs):
return model(input_ids=input_ids, pixel_values=pixel_values, **kwargs).to_tuple()
with self.subTest("JIT Enabled"):
jitted_outputs = model_jitted(**prepared_inputs_dict)
with self.subTest("JIT Disabled"):
with jax.disable_jit():
outputs = model_jitted(**prepared_inputs_dict)
self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs[:4], outputs[:4]):
self.assertEqual(jitted_output.shape, output.shape)
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.__call__)
            # signature.parameters is an OrderedDict, so the order of arg_names is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["input_ids", "pixel_values", "attention_mask", "position_ids"]
self.assertListEqual(arg_names[:4], expected_arg_names)
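    # get_image_features projects the pooled vision output into the shared embedding space;
    # here we only check jit/eager consistency of that projection.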
def test_get_image_features(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = FlaxCLIPModel(config)
@jax.jit
def model_jitted(pixel_values):
return model.get_image_features(pixel_values=pixel_values)
with self.subTest("JIT Enabled"):
jitted_output = model_jitted(inputs_dict["pixel_values"])
with self.subTest("JIT Disabled"):
with jax.disable_jit():
output = model_jitted(inputs_dict["pixel_values"])
self.assertEqual(jitted_output.shape, output.shape)
self.assertTrue(np.allclose(jitted_output, output, atol=1e-3))
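    # Same consistency check for the text side: get_text_features returns the projected text embeddings.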
def test_get_text_features(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = FlaxCLIPModel(config)
@jax.jit
def model_jitted(input_ids, attention_mask, **kwargs):
return model.get_text_features(input_ids=input_ids, attention_mask=attention_mask)
with self.subTest("JIT Enabled"):
jitted_output = model_jitted(**inputs_dict)
with self.subTest("JIT Disabled"):
with jax.disable_jit():
output = model_jitted(**inputs_dict)
self.assertEqual(jitted_output.shape, output.shape)
self.assertTrue(np.allclose(jitted_output, output, atol=1e-3))
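    # Integration test against the public openai/clip-vit-base-patch32 checkpoint (converted from PyTorch).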
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
model = model_class_name.from_pretrained("openai/clip-vit-base-patch32", from_pt=True)
outputs = model(input_ids=np.ones((1, 1)), pixel_values=np.ones((1, 3, 224, 224)))
self.assertIsNotNone(outputs)
# overwrite from common since FlaxCLIPModel returns nested output
# which is not supported in the common test
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
# load corresponding PyTorch class
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
pt_model = pt_model_class(config).eval()
fx_model = model_class(config, dtype=jnp.float32)
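                # convert the randomly initialized PyTorch weights into a Flax parameter dict so both models share identical parameters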
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
# PyTorch CLIPModel returns loss, we skip it here as we don't return loss in JAX/Flax models
pt_outputs = pt_outputs[1:]
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple()
self.assertEqual(
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
# overwrite from common since FlaxCLIPModel returns nested output
# which is not supported in the common test
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
# load corresponding PyTorch class
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
pt_model = pt_model_class(config).eval()
fx_model = model_class(config, dtype=jnp.float32)
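                # load the randomly initialized Flax parameters into the PyTorch model so both models share identical parameters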
pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
# make sure weights are tied in PyTorch
pt_model.tie_weights()
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
# PyTorch CLIPModel returns loss, we skip it here as we don't return loss in JAX/Flax models
pt_outputs = pt_outputs[1:]
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
pt_outputs_loaded = pt_outputs_loaded[1:]
self.assertEqual(
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
)
for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs_loaded[:4]):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
...@@ -60,6 +60,22 @@ def ids_tensor(shape, vocab_size, rng=None):
    return output
def floats_tensor(shape, scale=1.0, rng=None, name=None):
"""Creates a random float32 tensor"""
if rng is None:
rng = random.Random()
total_dims = 1
for dim in shape:
total_dims *= dim
values = []
for _ in range(total_dims):
values.append(rng.random() * scale)
return np.array(values, dtype=jnp.float32).reshape(shape)
def random_attention_mask(shape, rng=None):
    attn_mask = ids_tensor(shape, vocab_size=2, rng=rng)
    # make sure that at least one token is attended to for each batch
...
...@@ -93,6 +93,8 @@ IGNORE_NON_AUTO_CONFIGURED = [
    # models to ignore for model xxx mapping
    "CLIPTextModel",
    "CLIPVisionModel",
    "FlaxCLIPTextModel",
    "FlaxCLIPVisionModel",
    "DPRReader",
    "DPRSpanPredictor",
    "FlaubertForQuestionAnswering",
...