Unverified commit 90067748 authored by Shubhamai, committed by GitHub

Flax Regnet (#21867)

* initial commit

* review changes

* post model PR merge

* updating doc
parent fc5b7419
...@@ -283,7 +283,7 @@ Flax), PyTorch, und/oder TensorFlow haben.
| RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
| REALM | ✅ | ✅ | ✅ | ❌ | ❌ |
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ResNet | ❌ | ❌ | ✅ | ✅ | ✅ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
...
...@@ -377,7 +377,7 @@ Flax), PyTorch, and/or TensorFlow.
| RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
| REALM | ✅ | ✅ | ✅ | ❌ | ❌ |
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ResNet | ❌ | ❌ | ✅ | ✅ | ✅ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
...
...@@ -67,4 +67,16 @@ If you're interested in submitting a resource to be included here, please feel f
## TFRegNetForImageClassification
[[autodoc]] TFRegNetForImageClassification
- call
\ No newline at end of file
## FlaxRegNetModel
[[autodoc]] FlaxRegNetModel
- __call__
## FlaxRegNetForImageClassification
[[autodoc]] FlaxRegNetForImageClassification
- __call__
\ No newline at end of file
...@@ -235,7 +235,7 @@ Flax), PyTorch y/o TensorFlow.
| RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
| Realm | ✅ | ✅ | ✅ | ❌ | ❌ |
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | | ✅ |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ResNet | ❌ | ❌ | ✅ | ❌ | ✅ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
...
...@@ -347,7 +347,7 @@ Le tableau ci-dessous représente la prise en charge actuelle dans la bibliothè
| RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
| REALM | ✅ | ✅ | ✅ | ❌ | ❌ |
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ResNet | ❌ | ❌ | ✅ | ✅ | ❌ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
...
...@@ -252,7 +252,7 @@ tokenizer (chiamato "slow"). Un tokenizer "fast" supportato dalla libreria 🤗
| RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
| Realm | ✅ | ✅ | ✅ | ❌ | ❌ |
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | | ✅ |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ResNet | ❌ | ❌ | ✅ | ✅ | ✅ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
...
...@@ -337,7 +337,7 @@ specific language governing permissions and limitations under the License.
| RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
| REALM | ✅ | ✅ | ✅ | ❌ | ❌ |
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ResNet | ❌ | ❌ | ✅ | ✅ | ✅ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
...
...@@ -306,7 +306,7 @@ specific language governing permissions and limitations under the License.
| RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
| REALM | ✅ | ✅ | ✅ | ❌ | ❌ |
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ResNet | ❌ | ❌ | ✅ | ✅ | ✅ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
...
...@@ -250,7 +250,7 @@ disso, são diferenciados pelo suporte em diferentes frameworks: JAX (por meio d
| RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
| Realm | ✅ | ✅ | ✅ | ❌ | ❌ |
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | | ✅ |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ResNet | ❌ | ❌ | ✅ | ❌ | ✅ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
...
...@@ -336,7 +336,7 @@ Flax), PyTorch, 和/或者 TensorFlow.
| RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
| REALM | ✅ | ✅ | ✅ | ❌ | ❌ |
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | ✅ | ✅ |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ResNet | ❌ | ❌ | ✅ | ✅ | ❌ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
...
...@@ -3661,6 +3661,9 @@ else:
"FlaxPegasusPreTrainedModel",
]
)
_import_structure["models.regnet"].extend(
["FlaxRegNetForImageClassification", "FlaxRegNetModel", "FlaxRegNetPreTrainedModel"]
)
_import_structure["models.resnet"].extend( _import_structure["models.resnet"].extend(
["FlaxResNetForImageClassification", "FlaxResNetModel", "FlaxResNetPreTrainedModel"] ["FlaxResNetForImageClassification", "FlaxResNetModel", "FlaxResNetPreTrainedModel"]
) )
...@@ -6739,6 +6742,7 @@ if TYPE_CHECKING:
from .models.mt5 import FlaxMT5EncoderModel, FlaxMT5ForConditionalGeneration, FlaxMT5Model
from .models.opt import FlaxOPTForCausalLM, FlaxOPTModel, FlaxOPTPreTrainedModel
from .models.pegasus import FlaxPegasusForConditionalGeneration, FlaxPegasusModel, FlaxPegasusPreTrainedModel
from .models.regnet import FlaxRegNetForImageClassification, FlaxRegNetModel, FlaxRegNetPreTrainedModel
from .models.resnet import FlaxResNetForImageClassification, FlaxResNetModel, FlaxResNetPreTrainedModel
from .models.roberta import (
FlaxRobertaForCausalLM,
...
...@@ -48,6 +48,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict(
("mt5", "FlaxMT5Model"),
("opt", "FlaxOPTModel"),
("pegasus", "FlaxPegasusModel"),
("regnet", "FlaxRegNetModel"),
("resnet", "FlaxResNetModel"), ("resnet", "FlaxResNetModel"),
("roberta", "FlaxRobertaModel"), ("roberta", "FlaxRobertaModel"),
("roberta-prelayernorm", "FlaxRobertaPreLayerNormModel"), ("roberta-prelayernorm", "FlaxRobertaPreLayerNormModel"),
...@@ -120,6 +121,7 @@ FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ...@@ -120,6 +121,7 @@ FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[ [
# Model for Image-classsification # Model for Image-classsification
("beit", "FlaxBeitForImageClassification"), ("beit", "FlaxBeitForImageClassification"),
("regnet", "FlaxRegNetForImageClassification"),
("resnet", "FlaxResNetForImageClassification"), ("resnet", "FlaxResNetForImageClassification"),
("vit", "FlaxViTForImageClassification"), ("vit", "FlaxViTForImageClassification"),
] ]
......
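With these two mapping entries, the generic Flax auto classes can dispatch to the new RegNet classes. A minimal sketch, assuming Flax is installed and reusing the `facebook/regnet-y-040` checkpoint from the docstring examples further down (add `from_pt=True` if the checkpoint only hosts PyTorch weights):

```python
from transformers import FlaxAutoModelForImageClassification

# The auto class looks up "regnet" in the mapping above and returns a FlaxRegNetForImageClassification.
model = FlaxAutoModelForImageClassification.from_pretrained("facebook/regnet-y-040")  # from_pt=True if needed
print(type(model).__name__)  # FlaxRegNetForImageClassification
```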
...@@ -13,7 +13,13 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_tf_available,
is_torch_available,
)
_import_structure = {"configuration_regnet": ["REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "RegNetConfig"]} _import_structure = {"configuration_regnet": ["REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "RegNetConfig"]}
...@@ -44,6 +50,18 @@ else: ...@@ -44,6 +50,18 @@ else:
"TFRegNetPreTrainedModel", "TFRegNetPreTrainedModel",
] ]
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_regnet"] = [
"FlaxRegNetForImageClassification",
"FlaxRegNetModel",
"FlaxRegNetPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_regnet import REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP, RegNetConfig
...@@ -74,6 +92,18 @@ if TYPE_CHECKING:
TFRegNetPreTrainedModel,
)
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_regnet import (
FlaxRegNetForImageClassification,
FlaxRegNetModel,
FlaxRegNetPreTrainedModel,
)
else:
import sys
...
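The `try`/`except OptionalDependencyNotAvailable` blocks above follow the library's optional-backend pattern: the Flax classes are only wired into the lazy import structure when JAX/Flax are installed. A minimal sketch of how downstream code typically guards on that, using only the public API:

```python
from transformers.utils import is_flax_available

if is_flax_available():
    # Resolvable here because the __init__ hunks above registered the Flax classes.
    from transformers import FlaxRegNetModel
else:
    FlaxRegNetModel = None  # skip Flax-only code paths when the backend is missing
```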
# coding=utf-8
# Copyright 2023 The Google Flax Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict
from transformers import RegNetConfig
from transformers.modeling_flax_outputs import (
FlaxBaseModelOutputWithNoAttention,
FlaxBaseModelOutputWithPooling,
FlaxBaseModelOutputWithPoolingAndNoAttention,
FlaxImageClassifierOutputWithNoAttention,
)
from transformers.modeling_flax_utils import (
ACT2FN,
FlaxPreTrainedModel,
append_replace_return_docstrings,
overwrite_call_docstring,
)
from transformers.utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
)
REGNET_START_DOCSTRING = r"""
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading, saving and converting weights from PyTorch models).
This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
subclass. Use it as a regular Flax Linen Module and refer to the Flax documentation for all matters related to
general usage and behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
Parameters:
config ([`RegNetConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given `dtype`.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
[`~FlaxPreTrainedModel.to_bf16`].
"""
REGNET_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
[`RegNetImageProcessor.__call__`] for details.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# Copied from transformers.models.resnet.modeling_flax_resnet.Identity
class Identity(nn.Module):
"""Identity function."""
@nn.compact
def __call__(self, x, **kwargs):
return x
class FlaxRegNetConvLayer(nn.Module):
out_channels: int
kernel_size: int = 3
stride: int = 1
groups: int = 1
activation: Optional[str] = "relu"
dtype: jnp.dtype = jnp.float32
def setup(self):
self.convolution = nn.Conv(
self.out_channels,
kernel_size=(self.kernel_size, self.kernel_size),
strides=self.stride,
padding=self.kernel_size // 2,
feature_group_count=self.groups,
use_bias=False,
kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"),
dtype=self.dtype,
)
self.normalization = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=self.dtype)
self.activation_func = ACT2FN[self.activation] if self.activation is not None else Identity()
def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
hidden_state = self.convolution(hidden_state)
hidden_state = self.normalization(hidden_state, use_running_average=deterministic)
hidden_state = self.activation_func(hidden_state)
return hidden_state
class FlaxRegNetEmbeddings(nn.Module):
config: RegNetConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.embedder = FlaxRegNetConvLayer(
self.config.embedding_size,
kernel_size=3,
stride=2,
activation=self.config.hidden_act,
dtype=self.dtype,
)
def __call__(self, pixel_values: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
num_channels = pixel_values.shape[-1]
if num_channels != self.config.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
hidden_state = self.embedder(pixel_values, deterministic=deterministic)
return hidden_state
# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetShortCut with ResNet->RegNet
class FlaxRegNetShortCut(nn.Module):
"""
RegNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
downsample the input using `stride=2`.
"""
out_channels: int
stride: int = 2
dtype: jnp.dtype = jnp.float32
def setup(self):
self.convolution = nn.Conv(
self.out_channels,
kernel_size=(1, 1),
strides=self.stride,
use_bias=False,
kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"),
dtype=self.dtype,
)
self.normalization = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=self.dtype)
def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
hidden_state = self.convolution(x)
hidden_state = self.normalization(hidden_state, use_running_average=deterministic)
return hidden_state
class FlaxRegNetSELayerCollection(nn.Module):
in_channels: int
reduced_channels: int
dtype: jnp.dtype = jnp.float32
def setup(self):
self.conv_1 = nn.Conv(
self.reduced_channels,
kernel_size=(1, 1),
kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"),
dtype=self.dtype,
name="0",
) # 0 is the name used in corresponding pytorch implementation
self.conv_2 = nn.Conv(
self.in_channels,
kernel_size=(1, 1),
kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"),
dtype=self.dtype,
name="2",
) # 2 is the name used in corresponding pytorch implementation
def __call__(self, hidden_state: jnp.ndarray) -> jnp.ndarray:
hidden_state = self.conv_1(hidden_state)
hidden_state = nn.relu(hidden_state)
hidden_state = self.conv_2(hidden_state)
attention = nn.sigmoid(hidden_state)
return attention
class FlaxRegNetSELayer(nn.Module):
"""
Squeeze and Excitation layer (SE) proposed in [Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507).
"""
in_channels: int
reduced_channels: int
dtype: jnp.dtype = jnp.float32
def setup(self):
self.pooler = partial(nn.avg_pool, padding=((0, 0), (0, 0)))
self.attention = FlaxRegNetSELayerCollection(self.in_channels, self.reduced_channels, dtype=self.dtype)
def __call__(self, hidden_state: jnp.ndarray) -> jnp.ndarray:
pooled = self.pooler(
hidden_state,
window_shape=(hidden_state.shape[1], hidden_state.shape[2]),
strides=(hidden_state.shape[1], hidden_state.shape[2]),
)
attention = self.attention(pooled)
hidden_state = hidden_state * attention
return hidden_state
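A shape-only sketch of the squeeze-and-excitation path above, in the channels-last layout this port uses; the sizes are invented for illustration:

```python
import jax.numpy as jnp

x = jnp.ones((2, 56, 56, 232))               # NHWC feature map (example sizes)
pooled = x.mean(axis=(1, 2), keepdims=True)  # "squeeze": (2, 1, 1, 232), same result as the avg_pool call above
# FlaxRegNetSELayerCollection maps 232 -> reduced_channels -> 232 with 1x1 convolutions and a sigmoid,
# yielding per-channel weights with the same (2, 1, 1, 232) shape.
attention = jnp.ones_like(pooled)            # stand-in for the learned attention weights
y = x * attention                            # "excitation": broadcasts over the spatial dimensions
```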
class FlaxRegNetXLayerCollection(nn.Module):
config: RegNetConfig
out_channels: int
stride: int = 1
dtype: jnp.dtype = jnp.float32
def setup(self):
groups = max(1, self.out_channels // self.config.groups_width)
self.layer = [
FlaxRegNetConvLayer(
self.out_channels,
kernel_size=1,
activation=self.config.hidden_act,
dtype=self.dtype,
name="0",
),
FlaxRegNetConvLayer(
self.out_channels,
stride=self.stride,
groups=groups,
activation=self.config.hidden_act,
dtype=self.dtype,
name="1",
),
FlaxRegNetConvLayer(
self.out_channels,
kernel_size=1,
activation=None,
dtype=self.dtype,
name="2",
),
]
def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
for layer in self.layer:
hidden_state = layer(hidden_state, deterministic=deterministic)
return hidden_state
class FlaxRegNetXLayer(nn.Module):
"""
RegNet's X layer: a `1x1`, a grouped `3x3`, and a `1x1` convolution, same as a ResNet bottleneck layer with reduction = 1.
"""
config: RegNetConfig
in_channels: int
out_channels: int
stride: int = 1
dtype: jnp.dtype = jnp.float32
def setup(self):
should_apply_shortcut = self.in_channels != self.out_channels or self.stride != 1
self.shortcut = (
FlaxRegNetShortCut(
self.out_channels,
stride=self.stride,
dtype=self.dtype,
)
if should_apply_shortcut
else Identity()
)
self.layer = FlaxRegNetXLayerCollection(
self.config,
in_channels=self.in_channels,
out_channels=self.out_channels,
stride=self.stride,
dtype=self.dtype,
)
self.activation_func = ACT2FN[self.config.hidden_act]
def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
residual = hidden_state
hidden_state = self.layer(hidden_state)
residual = self.shortcut(residual, deterministic=deterministic)
hidden_state += residual
hidden_state = self.activation_func(hidden_state)
return hidden_state
class FlaxRegNetYLayerCollection(nn.Module):
config: RegNetConfig
in_channels: int
out_channels: int
stride: int = 1
dtype: jnp.dtype = jnp.float32
def setup(self):
groups = max(1, self.out_channels // self.config.groups_width)
self.layer = [
FlaxRegNetConvLayer(
self.out_channels,
kernel_size=1,
activation=self.config.hidden_act,
dtype=self.dtype,
name="0",
),
FlaxRegNetConvLayer(
self.out_channels,
stride=self.stride,
groups=groups,
activation=self.config.hidden_act,
dtype=self.dtype,
name="1",
),
FlaxRegNetSELayer(
self.out_channels,
reduced_channels=int(round(self.in_channels / 4)),
dtype=self.dtype,
name="2",
),
FlaxRegNetConvLayer(
self.out_channels,
kernel_size=1,
activation=None,
dtype=self.dtype,
name="3",
),
]
def __call__(self, hidden_state: jnp.ndarray) -> jnp.ndarray:
for layer in self.layer:
hidden_state = layer(hidden_state)
return hidden_state
class FlaxRegNetYLayer(nn.Module):
"""
RegNet's Y layer: an X layer with Squeeze and Excitation.
"""
config: RegNetConfig
in_channels: int
out_channels: int
stride: int = 1
dtype: jnp.dtype = jnp.float32
def setup(self):
should_apply_shortcut = self.in_channels != self.out_channels or self.stride != 1
self.shortcut = (
FlaxRegNetShortCut(
self.out_channels,
stride=self.stride,
dtype=self.dtype,
)
if should_apply_shortcut
else Identity()
)
self.layer = FlaxRegNetYLayerCollection(
self.config,
in_channels=self.in_channels,
out_channels=self.out_channels,
stride=self.stride,
dtype=self.dtype,
)
self.activation_func = ACT2FN[self.config.hidden_act]
def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
residual = hidden_state
hidden_state = self.layer(hidden_state)
residual = self.shortcut(residual, deterministic=deterministic)
hidden_state += residual
hidden_state = self.activation_func(hidden_state)
return hidden_state
class FlaxRegNetStageLayersCollection(nn.Module):
"""
A RegNet stage composed of stacked layers.
"""
config: RegNetConfig
in_channels: int
out_channels: int
stride: int = 2
depth: int = 2
dtype: jnp.dtype = jnp.float32
def setup(self):
layer = FlaxRegNetXLayer if self.config.layer_type == "x" else FlaxRegNetYLayer
layers = [
# downsampling is done in the first layer with stride of 2
layer(
self.config,
self.in_channels,
self.out_channels,
stride=self.stride,
dtype=self.dtype,
name="0",
)
]
for i in range(self.depth - 1):
layers.append(
layer(
self.config,
self.out_channels,
self.out_channels,
dtype=self.dtype,
name=str(i + 1),
)
)
self.layers = layers
def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
hidden_state = x
for layer in self.layers:
hidden_state = layer(hidden_state, deterministic=deterministic)
return hidden_state
# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetStage with ResNet->RegNet
class FlaxRegNetStage(nn.Module):
"""
A RegNet stage composed of stacked layers.
"""
config: RegNetConfig
in_channels: int
out_channels: int
stride: int = 2
depth: int = 2
dtype: jnp.dtype = jnp.float32
def setup(self):
self.layers = FlaxRegNetStageLayersCollection(
self.config,
in_channels=self.in_channels,
out_channels=self.out_channels,
stride=self.stride,
depth=self.depth,
dtype=self.dtype,
)
def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
return self.layers(x, deterministic=deterministic)
# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetStageCollection with ResNet->RegNet
class FlaxRegNetStageCollection(nn.Module):
config: RegNetConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
in_out_channels = zip(self.config.hidden_sizes, self.config.hidden_sizes[1:])
stages = [
FlaxRegNetStage(
self.config,
self.config.embedding_size,
self.config.hidden_sizes[0],
stride=2 if self.config.downsample_in_first_stage else 1,
depth=self.config.depths[0],
dtype=self.dtype,
name="0",
)
]
for i, ((in_channels, out_channels), depth) in enumerate(zip(in_out_channels, self.config.depths[1:])):
stages.append(
FlaxRegNetStage(self.config, in_channels, out_channels, depth=depth, dtype=self.dtype, name=str(i + 1))
)
self.stages = stages
def __call__(
self,
hidden_state: jnp.ndarray,
output_hidden_states: bool = False,
deterministic: bool = True,
) -> FlaxBaseModelOutputWithNoAttention:
hidden_states = () if output_hidden_states else None
for stage_module in self.stages:
if output_hidden_states:
hidden_states = hidden_states + (hidden_state.transpose(0, 3, 1, 2),)
hidden_state = stage_module(hidden_state, deterministic=deterministic)
return hidden_state, hidden_states
# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetEncoder with ResNet->RegNet
class FlaxRegNetEncoder(nn.Module):
config: RegNetConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.stages = FlaxRegNetStageCollection(self.config, dtype=self.dtype)
def __call__(
self,
hidden_state: jnp.ndarray,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
) -> FlaxBaseModelOutputWithNoAttention:
hidden_state, hidden_states = self.stages(
hidden_state, output_hidden_states=output_hidden_states, deterministic=deterministic
)
if output_hidden_states:
hidden_states = hidden_states + (hidden_state.transpose(0, 3, 1, 2),)
if not return_dict:
return tuple(v for v in [hidden_state, hidden_states] if v is not None)
return FlaxBaseModelOutputWithNoAttention(
last_hidden_state=hidden_state,
hidden_states=hidden_states,
)
# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetPreTrainedModel with ResNet->RegNet,resnet->regnet,RESNET->REGNET
class FlaxRegNetPreTrainedModel(FlaxPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = RegNetConfig
base_model_prefix = "regnet"
main_input_name = "pixel_values"
module_class: nn.Module = None
def __init__(
self,
config: RegNetConfig,
input_shape=(1, 224, 224, 3),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs,
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
if input_shape is None:
input_shape = (1, config.image_size, config.image_size, config.num_channels)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
pixel_values = jnp.zeros(input_shape, dtype=self.dtype)
rngs = {"params": rng}
random_params = self.module.init(rngs, pixel_values, return_dict=False)
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
@add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING)
def __call__(
self,
pixel_values,
params: dict = None,
train: bool = False,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
# Handle any PRNG if needed
rngs = {}
return self.module.apply(
{
"params": params["params"] if params is not None else self.params["params"],
"batch_stats": params["batch_stats"] if params is not None else self.params["batch_stats"],
},
jnp.array(pixel_values, dtype=jnp.float32),
not train,
output_hidden_states,
return_dict,
rngs=rngs,
mutable=["batch_stats"] if train else False, # Returing tuple with batch_stats only when train is True
)
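Because `mutable=["batch_stats"]` is only passed when `train=True`, the wrapper returns a plain model output in eval mode and an `(output, updated_variables)` pair in train mode. A hedged usage sketch, reusing the checkpoint from the docstring examples below:

```python
import jax.numpy as jnp
from transformers import FlaxRegNetModel

model = FlaxRegNetModel.from_pretrained("facebook/regnet-y-040")
pixel_values = jnp.zeros((1, 3, 224, 224), dtype=jnp.float32)  # NCHW; transposed to NHWC internally

outputs = model(pixel_values)                       # eval: just the output (BatchNorm uses running stats)
outputs, mutated = model(pixel_values, train=True)  # train: also returns {"batch_stats": ...} updates
```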
# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetModule with ResNet->RegNet
class FlaxRegNetModule(nn.Module):
config: RegNetConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.embedder = FlaxRegNetEmbeddings(self.config, dtype=self.dtype)
self.encoder = FlaxRegNetEncoder(self.config, dtype=self.dtype)
# Adaptive average pooling used in resnet
self.pooler = partial(
nn.avg_pool,
padding=((0, 0), (0, 0)),
)
def __call__(
self,
pixel_values,
deterministic: bool = True,
output_hidden_states: bool = False,
return_dict: bool = True,
) -> FlaxBaseModelOutputWithPoolingAndNoAttention:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
embedding_output = self.embedder(pixel_values, deterministic=deterministic)
encoder_outputs = self.encoder(
embedding_output,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
last_hidden_state = encoder_outputs[0]
pooled_output = self.pooler(
last_hidden_state,
window_shape=(last_hidden_state.shape[1], last_hidden_state.shape[2]),
strides=(last_hidden_state.shape[1], last_hidden_state.shape[2]),
).transpose(0, 3, 1, 2)
last_hidden_state = last_hidden_state.transpose(0, 3, 1, 2)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return FlaxBaseModelOutputWithPoolingAndNoAttention(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
)
@add_start_docstrings(
"The bare RegNet model outputting raw features without any specific head on top.",
REGNET_START_DOCSTRING,
)
class FlaxRegNetModel(FlaxRegNetPreTrainedModel):
module_class = FlaxRegNetModule
FLAX_VISION_MODEL_DOCSTRING = """
Returns:
Examples:
```python
>>> from transformers import AutoImageProcessor, FlaxRegNetModel
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> image_processor = AutoImageProcessor.from_pretrained("facebook/regnet-y-040")
>>> model = FlaxRegNetModel.from_pretrained("facebook/regnet-y-040")
>>> inputs = image_processor(images=image, return_tensors="np")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
```
"""
overwrite_call_docstring(FlaxRegNetModel, FLAX_VISION_MODEL_DOCSTRING)
append_replace_return_docstrings(
FlaxRegNetModel,
output_type=FlaxBaseModelOutputWithPooling,
config_class=RegNetConfig,
)
# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetClassifierCollection with ResNet->RegNet
class FlaxRegNetClassifierCollection(nn.Module):
config: RegNetConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype, name="1")
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
return self.classifier(x)
# Copied from transformers.models.resnet.modeling_flax_resnet.FlaxResNetForImageClassificationModule with ResNet->RegNet,resnet->regnet,RESNET->REGNET
class FlaxRegNetForImageClassificationModule(nn.Module):
config: RegNetConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.regnet = FlaxRegNetModule(config=self.config, dtype=self.dtype)
if self.config.num_labels > 0:
self.classifier = FlaxRegNetClassifierCollection(self.config, dtype=self.dtype)
else:
self.classifier = Identity()
def __call__(
self,
pixel_values=None,
deterministic: bool = True,
output_hidden_states=None,
return_dict=None,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.regnet(
pixel_values,
deterministic=deterministic,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = outputs.pooler_output if return_dict else outputs[1]
logits = self.classifier(pooled_output[:, :, 0, 0])
if not return_dict:
output = (logits,) + outputs[2:]
return output
return FlaxImageClassifierOutputWithNoAttention(logits=logits, hidden_states=outputs.hidden_states)
@add_start_docstrings(
"""
RegNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
ImageNet.
""",
REGNET_START_DOCSTRING,
)
class FlaxRegNetForImageClassification(FlaxRegNetPreTrainedModel):
module_class = FlaxRegNetForImageClassificationModule
FLAX_VISION_CLASSIF_DOCSTRING = """
Returns:
Example:
```python
>>> from transformers import AutoImageProcessor, FlaxRegNetForImageClassification
>>> from PIL import Image
>>> import jax
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> image_processor = AutoImageProcessor.from_pretrained("facebook/regnet-y-040")
>>> model = FlaxRegNetForImageClassification.from_pretrained("facebook/regnet-y-040")
>>> inputs = image_processor(images=image, return_tensors="np")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1)
>>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()])
```
"""
overwrite_call_docstring(FlaxRegNetForImageClassification, FLAX_VISION_CLASSIF_DOCSTRING)
append_replace_return_docstrings(
FlaxRegNetForImageClassification,
output_type=FlaxImageClassifierOutputWithNoAttention,
config_class=RegNetConfig,
)
...@@ -89,7 +89,7 @@ class Identity(nn.Module):
"""Identity function."""
@nn.compact
def __call__(self, x, **kwargs):
return x
...
...@@ -881,6 +881,27 @@ class FlaxPegasusPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["flax"])
class FlaxRegNetForImageClassification(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxRegNetModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxRegNetPreTrainedModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxResNetForImageClassification(metaclass=DummyObject):
_backends = ["flax"]
...
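The dummy classes added above keep `from transformers import FlaxRegNetModel` importable even when JAX/Flax are absent; only instantiation fails. A small sketch of the user-facing behaviour in an environment without the Flax backend:

```python
from transformers import FlaxRegNetModel  # succeeds: this resolves to the dummy object without Flax

try:
    FlaxRegNetModel()  # the dummy's __init__ calls requires_backends and raises
except ImportError as err:
    print(err)  # message asks to install the missing "flax" backend
```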
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import unittest
from transformers import RegNetConfig, is_flax_available
from transformers.testing_utils import require_flax, slow
from transformers.utils import cached_property, is_vision_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor
if is_flax_available():
import jax
import jax.numpy as jnp
from transformers.models.regnet.modeling_flax_regnet import FlaxRegNetForImageClassification, FlaxRegNetModel
if is_vision_available():
from PIL import Image
from transformers import AutoFeatureExtractor
class FlaxRegNetModelTester(unittest.TestCase):
def __init__(
self,
parent,
batch_size=3,
image_size=32,
num_channels=3,
embeddings_size=10,
hidden_sizes=[10, 20, 30, 40],
depths=[1, 1, 2, 1],
is_training=True,
use_labels=True,
hidden_act="relu",
num_labels=3,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.num_channels = num_channels
self.embeddings_size = embeddings_size
self.hidden_sizes = hidden_sizes
self.depths = depths
self.is_training = is_training
self.use_labels = use_labels
self.hidden_act = hidden_act
self.num_labels = num_labels
self.scope = scope
self.num_stages = len(hidden_sizes)
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
config = self.get_config()
return config, pixel_values
def get_config(self):
return RegNetConfig(
num_channels=self.num_channels,
embeddings_size=self.embeddings_size,
hidden_sizes=self.hidden_sizes,
depths=self.depths,
hidden_act=self.hidden_act,
num_labels=self.num_labels,
image_size=self.image_size,
)
def create_and_check_model(self, config, pixel_values):
model = FlaxRegNetModel(config=config)
result = model(pixel_values)
# Output shape (b, c, h, w)
self.parent.assertEqual(
result.last_hidden_state.shape,
(self.batch_size, self.hidden_sizes[-1], self.image_size // 32, self.image_size // 32),
)
def create_and_check_for_image_classification(self, config, pixel_values):
config.num_labels = self.num_labels
model = FlaxRegNetForImageClassification(config=config)
result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_flax
class FlaxRegNetModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (FlaxRegNetModel, FlaxRegNetForImageClassification) if is_flax_available() else ()
is_encoder_decoder = False
test_head_masking = False
has_attentions = False
def setUp(self) -> None:
self.model_tester = FlaxRegNetModelTester(self)
self.config_tester = ConfigTester(self, config_class=RegNetConfig, has_text_modality=False)
def test_config(self):
self.create_and_test_config_common_properties()
self.config_tester.create_and_test_config_to_json_string()
self.config_tester.create_and_test_config_to_json_file()
self.config_tester.create_and_test_config_from_and_save_pretrained()
self.config_tester.create_and_test_config_with_num_labels()
self.config_tester.check_config_can_be_init_without_params()
self.config_tester.check_config_arguments_init()
def create_and_test_config_common_properties(self):
return
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
@unittest.skip(reason="RegNet does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="RegNet does not support input and output embeddings")
def test_model_common_attributes(self):
pass
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.__call__)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
expected_num_stages = self.model_tester.num_stages
self.assertEqual(len(hidden_states), expected_num_stages + 1)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
@jax.jit
def model_jitted(pixel_values, **kwargs):
return model(pixel_values=pixel_values, **kwargs)
with self.subTest("JIT Enabled"):
jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()
with self.subTest("JIT Disabled"):
with jax.disable_jit():
outputs = model_jitted(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_flax
class FlaxRegNetModelIntegrationTest(unittest.TestCase):
@cached_property
def default_feature_extractor(self):
return AutoFeatureExtractor.from_pretrained("facebook/regnet-y-040") if is_vision_available() else None
@slow
def test_inference_image_classification_head(self):
model = FlaxRegNetForImageClassification.from_pretrained("facebook/regnet-y-040")
feature_extractor = self.default_feature_extractor
image = prepare_img()
inputs = feature_extractor(images=image, return_tensors="np")
outputs = model(**inputs)
# verify the logits
expected_shape = (1, 1000)
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = jnp.array([-0.4180, -1.5051, -3.4836])
self.assertTrue(jnp.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
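The integration test above is marked `@slow`, so it is skipped by default; transformers gates slow tests behind the `RUN_SLOW` environment variable. A sketch of running this module with slow tests enabled (the file path is assumed from the repository layout implied by the relative imports, not stated in the diff):

```python
import os
import pytest

os.environ["RUN_SLOW"] = "1"  # must be set before transformers.testing_utils is imported
pytest.main(["tests/models/regnet/test_modeling_flax_regnet.py", "-q"])  # assumed path
```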