Unverified Commit a0cbbba3 authored by Shubhamai's avatar Shubhamai Committed by GitHub
Browse files

Resnet flax (#21472)



* [WIP] flax resnet

* added pretrained flax models, results reproducible

* Added pretrained flax models, results reproducible

* working on tests

* no real code change, just some comments

* [flax] adding support for batch norm layers

* fixing bugs related to pt+flax integration

* removing loss from modeling flax output class

* fixing classifier tests

* fixing comments, model output

* cleaning comments

* review changes

* review changes

* Apply suggestions from code review
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* renaming Flax to PyTorch

---------
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 88dae78f
...@@ -285,7 +285,7 @@ Flax), PyTorch, und/oder TensorFlow haben. ...@@ -285,7 +285,7 @@ Flax), PyTorch, und/oder TensorFlow haben.
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | ✅ | ❌ | | RegNet | ❌ | ❌ | ✅ | ✅ | ❌ |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ResNet | ❌ | ❌ | ✅ | ✅ | | | ResNet | ❌ | ❌ | ✅ | ✅ | |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
| RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ | | RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ |
| RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ | | RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ |
......
...@@ -377,7 +377,7 @@ Flax), PyTorch, and/or TensorFlow. ...@@ -377,7 +377,7 @@ Flax), PyTorch, and/or TensorFlow.
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | ✅ | ❌ | | RegNet | ❌ | ❌ | ✅ | ✅ | ❌ |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ResNet | ❌ | ❌ | ✅ | ✅ | | | ResNet | ❌ | ❌ | ✅ | ✅ | |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
| RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ | | RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ |
| RoBERTa-PreLayerNorm | ❌ | ❌ | ✅ | ✅ | ✅ | | RoBERTa-PreLayerNorm | ❌ | ❌ | ✅ | ✅ | ✅ |
......
...@@ -71,3 +71,13 @@ If you're interested in submitting a resource to be included here, please feel f ...@@ -71,3 +71,13 @@ If you're interested in submitting a resource to be included here, please feel f
[[autodoc]] TFResNetForImageClassification [[autodoc]] TFResNetForImageClassification
- call - call
## FlaxResNetModel
[[autodoc]] FlaxResNetModel
- __call__
## FlaxResNetForImageClassification
[[autodoc]] FlaxResNetForImageClassification
- __call__
...@@ -237,7 +237,7 @@ Flax), PyTorch y/o TensorFlow. ...@@ -237,7 +237,7 @@ Flax), PyTorch y/o TensorFlow.
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | ❌ | ❌ | | RegNet | ❌ | ❌ | ✅ | ❌ | ❌ |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ResNet | ❌ | ❌ | ✅ | ❌ | | | ResNet | ❌ | ❌ | ✅ | ❌ | |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
| RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ | | RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ |
| RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ | | RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ |
......
...@@ -254,7 +254,7 @@ tokenizer (chiamato "slow"). Un tokenizer "fast" supportato dalla libreria 🤗 ...@@ -254,7 +254,7 @@ tokenizer (chiamato "slow"). Un tokenizer "fast" supportato dalla libreria 🤗
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | ❌ | ❌ | | RegNet | ❌ | ❌ | ✅ | ❌ | ❌ |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ResNet | ❌ | ❌ | ✅ | | | | ResNet | ❌ | ❌ | ✅ | | |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
| RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ | | RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ |
| RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ | | RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ |
......
...@@ -339,7 +339,7 @@ specific language governing permissions and limitations under the License. ...@@ -339,7 +339,7 @@ specific language governing permissions and limitations under the License.
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | ✅ | ❌ | | RegNet | ❌ | ❌ | ✅ | ✅ | ❌ |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ResNet | ❌ | ❌ | ✅ | ✅ | | | ResNet | ❌ | ❌ | ✅ | ✅ | |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
| RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ | | RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ |
| RoBERTa-PreLayerNorm | ❌ | ❌ | ✅ | ✅ | ✅ | | RoBERTa-PreLayerNorm | ❌ | ❌ | ✅ | ✅ | ✅ |
......
...@@ -308,7 +308,7 @@ specific language governing permissions and limitations under the License. ...@@ -308,7 +308,7 @@ specific language governing permissions and limitations under the License.
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | ✅ | ❌ | | RegNet | ❌ | ❌ | ✅ | ✅ | ❌ |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ResNet | ❌ | ❌ | ✅ | ✅ | | | ResNet | ❌ | ❌ | ✅ | ✅ | |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
| RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ | | RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ |
| RoCBert | ✅ | ❌ | ✅ | ❌ | ❌ | | RoCBert | ✅ | ❌ | ✅ | ❌ | ❌ |
......
...@@ -252,7 +252,7 @@ disso, são diferenciados pelo suporte em diferentes frameworks: JAX (por meio d ...@@ -252,7 +252,7 @@ disso, são diferenciados pelo suporte em diferentes frameworks: JAX (por meio d
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | ❌ | ❌ | | RegNet | ❌ | ❌ | ✅ | ❌ | ❌ |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ResNet | ❌ | ❌ | ✅ | ❌ | | | ResNet | ❌ | ❌ | ✅ | ❌ | |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
| RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ | | RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ |
| RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ | | RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ |
......
...@@ -3637,6 +3637,9 @@ else: ...@@ -3637,6 +3637,9 @@ else:
"FlaxPegasusPreTrainedModel", "FlaxPegasusPreTrainedModel",
] ]
) )
_import_structure["models.resnet"].extend(
["FlaxResNetForImageClassification", "FlaxResNetModel", "FlaxResNetPreTrainedModel"]
)
_import_structure["models.roberta"].extend( _import_structure["models.roberta"].extend(
[ [
"FlaxRobertaForCausalLM", "FlaxRobertaForCausalLM",
...@@ -6692,6 +6695,7 @@ if TYPE_CHECKING: ...@@ -6692,6 +6695,7 @@ if TYPE_CHECKING:
from .models.mt5 import FlaxMT5EncoderModel, FlaxMT5ForConditionalGeneration, FlaxMT5Model from .models.mt5 import FlaxMT5EncoderModel, FlaxMT5ForConditionalGeneration, FlaxMT5Model
from .models.opt import FlaxOPTForCausalLM, FlaxOPTModel, FlaxOPTPreTrainedModel from .models.opt import FlaxOPTForCausalLM, FlaxOPTModel, FlaxOPTPreTrainedModel
from .models.pegasus import FlaxPegasusForConditionalGeneration, FlaxPegasusModel, FlaxPegasusPreTrainedModel from .models.pegasus import FlaxPegasusForConditionalGeneration, FlaxPegasusModel, FlaxPegasusPreTrainedModel
from .models.resnet import FlaxResNetForImageClassification, FlaxResNetModel, FlaxResNetPreTrainedModel
from .models.roberta import ( from .models.roberta import (
FlaxRobertaForCausalLM, FlaxRobertaForCausalLM,
FlaxRobertaForMaskedLM, FlaxRobertaForMaskedLM,
......
...@@ -45,6 +45,64 @@ class FlaxBaseModelOutput(ModelOutput): ...@@ -45,6 +45,64 @@ class FlaxBaseModelOutput(ModelOutput):
attentions: Optional[Tuple[jnp.ndarray]] = None attentions: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
class FlaxBaseModelOutputWithNoAttention(ModelOutput):
"""
Base class for model's outputs, with potential hidden states.
Args:
last_hidden_state (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of the
model at the output of each layer plus the optional initial embedding outputs.
"""
last_hidden_state: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
class FlaxBaseModelOutputWithPoolingAndNoAttention(ModelOutput):
"""
Base class for model's outputs that also contains a pooling of the last hidden states.
Args:
last_hidden_state (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
Last layer hidden-state after a pooling operation on the spatial dimensions.
hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of the
model at the output of each layer plus the optional initial embedding outputs.
"""
last_hidden_state: jnp.ndarray = None
pooler_output: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass
class FlaxImageClassifierOutputWithNoAttention(ModelOutput):
"""
Base class for outputs of image classification models.
Args:
logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when
`config.output_hidden_states=True`):
Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also
called feature maps) of the model at the output of each stage.
"""
logits: jnp.ndarray = None
hidden_states: Optional[Tuple[jnp.ndarray]] = None
@flax.struct.dataclass @flax.struct.dataclass
class FlaxBaseModelOutputWithPast(ModelOutput): class FlaxBaseModelOutputWithPast(ModelOutput):
""" """
......
...@@ -48,6 +48,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict( ...@@ -48,6 +48,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict(
("mt5", "FlaxMT5Model"), ("mt5", "FlaxMT5Model"),
("opt", "FlaxOPTModel"), ("opt", "FlaxOPTModel"),
("pegasus", "FlaxPegasusModel"), ("pegasus", "FlaxPegasusModel"),
("resnet", "FlaxResNetModel"),
("roberta", "FlaxRobertaModel"), ("roberta", "FlaxRobertaModel"),
("roberta-prelayernorm", "FlaxRobertaPreLayerNormModel"), ("roberta-prelayernorm", "FlaxRobertaPreLayerNormModel"),
("roformer", "FlaxRoFormerModel"), ("roformer", "FlaxRoFormerModel"),
...@@ -119,6 +120,7 @@ FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ...@@ -119,6 +120,7 @@ FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[ [
# Model for Image-classsification # Model for Image-classsification
("beit", "FlaxBeitForImageClassification"), ("beit", "FlaxBeitForImageClassification"),
("resnet", "FlaxResNetForImageClassification"),
("vit", "FlaxViTForImageClassification"), ("vit", "FlaxViTForImageClassification"),
] ]
) )
......
...@@ -13,7 +13,13 @@ ...@@ -13,7 +13,13 @@
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_tf_available,
is_torch_available,
)
_import_structure = { _import_structure = {
...@@ -47,6 +53,17 @@ else: ...@@ -47,6 +53,17 @@ else:
"TFResNetPreTrainedModel", "TFResNetPreTrainedModel",
] ]
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_resnet"] = [
"FlaxResNetForImageClassification",
"FlaxResNetModel",
"FlaxResNetPreTrainedModel",
]
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_resnet import RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ResNetConfig, ResNetOnnxConfig from .configuration_resnet import RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ResNetConfig, ResNetOnnxConfig
...@@ -78,6 +95,14 @@ if TYPE_CHECKING: ...@@ -78,6 +95,14 @@ if TYPE_CHECKING:
TFResNetPreTrainedModel, TFResNetPreTrainedModel,
) )
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_resnet import FlaxResNetForImageClassification, FlaxResNetModel, FlaxResNetPreTrainedModel
else: else:
import sys import sys
......
# coding=utf-8
# Copyright 2023 HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Optional, Tuple
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict
from ...modeling_flax_outputs import (
FlaxBaseModelOutputWithNoAttention,
FlaxBaseModelOutputWithPoolingAndNoAttention,
FlaxImageClassifierOutputWithNoAttention,
)
from ...modeling_flax_utils import (
ACT2FN,
FlaxPreTrainedModel,
append_replace_return_docstrings,
overwrite_call_docstring,
)
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward
from .configuration_resnet import ResNetConfig
RESNET_START_DOCSTRING = r"""
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading, saving and converting weights from PyTorch models)
This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
general usage and behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
Parameters:
config ([`ResNetConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
`jax.numpy.bfloat16` (on TPUs).
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified all the computation will be performed with the given `dtype`.
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
parameters.**
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
[`~FlaxPreTrainedModel.to_bf16`].
"""
RESNET_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`jax.numpy.float32` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
[`AutoImageProcessor.__call__`] for details.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class Identity(nn.Module):
"""Identity function."""
@nn.compact
def __call__(self, x):
return x
class FlaxResNetConvLayer(nn.Module):
out_channels: int
kernel_size: int = 3
stride: int = 1
activation: Optional[str] = "relu"
dtype: jnp.dtype = jnp.float32
def setup(self):
self.convolution = nn.Conv(
self.out_channels,
kernel_size=(self.kernel_size, self.kernel_size),
strides=self.stride,
padding=self.kernel_size // 2,
dtype=self.dtype,
use_bias=False,
kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="normal", dtype=self.dtype),
)
self.normalization = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=self.dtype)
self.activation_func = ACT2FN[self.activation] if self.activation is not None else Identity()
def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
hidden_state = self.convolution(x)
hidden_state = self.normalization(hidden_state, use_running_average=deterministic)
hidden_state = self.activation_func(hidden_state)
return hidden_state
class FlaxResNetEmbeddings(nn.Module):
"""
ResNet Embeddings (stem) composed of a single aggressive convolution.
"""
config: ResNetConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.embedder = FlaxResNetConvLayer(
self.config.embedding_size,
kernel_size=7,
stride=2,
activation=self.config.hidden_act,
dtype=self.dtype,
)
self.max_pool = partial(nn.max_pool, window_shape=(3, 3), strides=(2, 2), padding=((1, 1), (1, 1)))
def __call__(self, pixel_values: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
num_channels = pixel_values.shape[-1]
if num_channels != self.config.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
embedding = self.embedder(pixel_values, deterministic=deterministic)
embedding = self.max_pool(embedding)
return embedding
class FlaxResNetShortCut(nn.Module):
"""
ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
downsample the input using `stride=2`.
"""
out_channels: int
stride: int = 2
dtype: jnp.dtype = jnp.float32
def setup(self):
self.convolution = nn.Conv(
self.out_channels,
kernel_size=(1, 1),
strides=self.stride,
use_bias=False,
kernel_init=nn.initializers.variance_scaling(2.0, mode="fan_out", distribution="truncated_normal"),
dtype=self.dtype,
)
self.normalization = nn.BatchNorm(momentum=0.9, epsilon=1e-05, dtype=self.dtype)
def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
hidden_state = self.convolution(x)
hidden_state = self.normalization(hidden_state, use_running_average=deterministic)
return hidden_state
class FlaxResNetBasicLayerCollection(nn.Module):
out_channels: int
stride: int = 1
dtype: jnp.dtype = jnp.float32
def setup(self):
self.layer = [
FlaxResNetConvLayer(self.out_channels, stride=self.stride, dtype=self.dtype),
FlaxResNetConvLayer(self.out_channels, activation=None, dtype=self.dtype),
]
def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
for layer in self.layer:
hidden_state = layer(hidden_state, deterministic=deterministic)
return hidden_state
class FlaxResNetBasicLayer(nn.Module):
"""
A classic ResNet's residual layer composed by two `3x3` convolutions.
"""
in_channels: int
out_channels: int
stride: int = 1
activation: Optional[str] = "relu"
dtype: jnp.dtype = jnp.float32
def setup(self):
should_apply_shortcut = self.in_channels != self.out_channels or self.stride != 1
self.shortcut = (
FlaxResNetShortCut(self.out_channels, stride=self.stride, dtype=self.dtype)
if should_apply_shortcut
else None
)
self.layer = FlaxResNetBasicLayerCollection(
out_channels=self.out_channels,
stride=self.stride,
activation=self.activation,
dtype=self.dtype,
)
self.activation_func = ACT2FN[self.activation]
def __call__(self, hidden_state, deterministic: bool = True):
residual = hidden_state
hidden_state = self.layer(hidden_state, deterministic=deterministic)
if self.shortcut is not None:
residual = self.shortcut(residual, deterministic=deterministic)
hidden_state += residual
hidden_state = self.activation_func(hidden_state)
return hidden_state
class FlaxResNetBottleNeckLayerCollection(nn.Module):
out_channels: int
stride: int = 1
activation: Optional[str] = "relu"
reduction: int = 4
dtype: jnp.dtype = jnp.float32
def setup(self):
reduces_channels = self.out_channels // self.reduction
self.layer = [
FlaxResNetConvLayer(reduces_channels, kernel_size=1, dtype=self.dtype, name="0"),
FlaxResNetConvLayer(reduces_channels, stride=self.stride, dtype=self.dtype, name="1"),
FlaxResNetConvLayer(self.out_channels, kernel_size=1, activation=None, dtype=self.dtype, name="2"),
]
def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
for layer in self.layer:
hidden_state = layer(hidden_state, deterministic=deterministic)
return hidden_state
class FlaxResNetBottleNeckLayer(nn.Module):
"""
A classic ResNet's bottleneck layer composed by three `3x3` convolutions. The first `1x1` convolution reduces the
input by a factor of `reduction` in order to make the second `3x3` convolution faster. The last `1x1` convolution
remaps the reduced features to `out_channels`.
"""
in_channels: int
out_channels: int
stride: int = 1
activation: Optional[str] = "relu"
reduction: int = 4
dtype: jnp.dtype = jnp.float32
def setup(self):
should_apply_shortcut = self.in_channels != self.out_channels or self.stride != 1
self.shortcut = (
FlaxResNetShortCut(self.out_channels, stride=self.stride, dtype=self.dtype)
if should_apply_shortcut
else None
)
self.layer = FlaxResNetBottleNeckLayerCollection(
self.out_channels,
stride=self.stride,
activation=self.activation,
reduction=self.reduction,
dtype=self.dtype,
)
self.activation_func = ACT2FN[self.activation]
def __call__(self, hidden_state: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
residual = hidden_state
if self.shortcut is not None:
residual = self.shortcut(residual, deterministic=deterministic)
hidden_state = self.layer(hidden_state, deterministic)
hidden_state += residual
hidden_state = self.activation_func(hidden_state)
return hidden_state
class FlaxResNetStageLayersCollection(nn.Module):
"""
A ResNet stage composed by stacked layers.
"""
config: ResNetConfig
in_channels: int
out_channels: int
stride: int = 2
depth: int = 2
dtype: jnp.dtype = jnp.float32
def setup(self):
layer = FlaxResNetBottleNeckLayer if self.config.layer_type == "bottleneck" else FlaxResNetBasicLayer
layers = [
# downsampling is done in the first layer with stride of 2
layer(
self.in_channels,
self.out_channels,
stride=self.stride,
activation=self.config.hidden_act,
dtype=self.dtype,
name="0",
),
]
for i in range(self.depth - 1):
layers.append(
layer(
self.out_channels,
self.out_channels,
activation=self.config.hidden_act,
dtype=self.dtype,
name=str(i + 1),
)
)
self.layers = layers
def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
hidden_state = x
for layer in self.layers:
hidden_state = layer(hidden_state, deterministic=deterministic)
return hidden_state
class FlaxResNetStage(nn.Module):
"""
A ResNet stage composed by stacked layers.
"""
config: ResNetConfig
in_channels: int
out_channels: int
stride: int = 2
depth: int = 2
dtype: jnp.dtype = jnp.float32
def setup(self):
self.layers = FlaxResNetStageLayersCollection(
self.config,
in_channels=self.in_channels,
out_channels=self.out_channels,
stride=self.stride,
depth=self.depth,
dtype=self.dtype,
)
def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
return self.layers(x, deterministic=deterministic)
class FlaxResNetStageCollection(nn.Module):
config: ResNetConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
in_out_channels = zip(self.config.hidden_sizes, self.config.hidden_sizes[1:])
stages = [
FlaxResNetStage(
self.config,
self.config.embedding_size,
self.config.hidden_sizes[0],
stride=2 if self.config.downsample_in_first_stage else 1,
depth=self.config.depths[0],
dtype=self.dtype,
name="0",
)
]
for i, ((in_channels, out_channels), depth) in enumerate(zip(in_out_channels, self.config.depths[1:])):
stages.append(
FlaxResNetStage(self.config, in_channels, out_channels, depth=depth, dtype=self.dtype, name=str(i + 1))
)
self.stages = stages
def __call__(
self,
hidden_state: jnp.ndarray,
output_hidden_states: bool = False,
deterministic: bool = True,
) -> FlaxBaseModelOutputWithNoAttention:
hidden_states = () if output_hidden_states else None
for stage_module in self.stages:
if output_hidden_states:
hidden_states = hidden_states + (hidden_state.transpose(0, 3, 1, 2),)
hidden_state = stage_module(hidden_state, deterministic=deterministic)
return hidden_state, hidden_states
class FlaxResNetEncoder(nn.Module):
config: ResNetConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.stages = FlaxResNetStageCollection(self.config, dtype=self.dtype)
def __call__(
self,
hidden_state: jnp.ndarray,
output_hidden_states: bool = False,
return_dict: bool = True,
deterministic: bool = True,
) -> FlaxBaseModelOutputWithNoAttention:
hidden_state, hidden_states = self.stages(
hidden_state, output_hidden_states=output_hidden_states, deterministic=deterministic
)
if output_hidden_states:
hidden_states = hidden_states + (hidden_state.transpose(0, 3, 1, 2),)
if not return_dict:
return tuple(v for v in [hidden_state, hidden_states] if v is not None)
return FlaxBaseModelOutputWithNoAttention(
last_hidden_state=hidden_state,
hidden_states=hidden_states,
)
class FlaxResNetPreTrainedModel(FlaxPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = ResNetConfig
base_model_prefix = "resnet"
main_input_name = "pixel_values"
module_class: nn.Module = None
def __init__(
self,
config: ResNetConfig,
input_shape=(1, 224, 224, 3),
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
**kwargs,
):
module = self.module_class(config=config, dtype=dtype, **kwargs)
if input_shape is None:
input_shape = (1, config.image_size, config.image_size, config.num_channels)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensors
pixel_values = jnp.zeros(input_shape, dtype=self.dtype)
rngs = {"params": rng}
random_params = self.module.init(rngs, pixel_values, return_dict=False)
if params is not None:
random_params = flatten_dict(unfreeze(random_params))
params = flatten_dict(unfreeze(params))
for missing_key in self._missing_keys:
params[missing_key] = random_params[missing_key]
self._missing_keys = set()
return freeze(unflatten_dict(params))
else:
return random_params
@add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING)
def __call__(
self,
pixel_values,
params: dict = None,
train: bool = False,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.return_dict
pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
# Handle any PRNG if needed
rngs = {}
return self.module.apply(
{
"params": params["params"] if params is not None else self.params["params"],
"batch_stats": params["batch_stats"] if params is not None else self.params["batch_stats"],
},
jnp.array(pixel_values, dtype=jnp.float32),
not train,
output_hidden_states,
return_dict,
rngs=rngs,
mutable=["batch_stats"] if train else False, # Returing tuple with batch_stats only when train is True
)
class FlaxResNetModule(nn.Module):
config: ResNetConfig
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
def setup(self):
self.embedder = FlaxResNetEmbeddings(self.config, dtype=self.dtype)
self.encoder = FlaxResNetEncoder(self.config, dtype=self.dtype)
# Adaptive average pooling used in resnet
self.pooler = partial(
nn.avg_pool,
padding=((0, 0), (0, 0)),
)
def __call__(
self,
pixel_values,
deterministic: bool = True,
output_hidden_states: bool = False,
return_dict: bool = True,
) -> FlaxBaseModelOutputWithPoolingAndNoAttention:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
embedding_output = self.embedder(pixel_values, deterministic=deterministic)
encoder_outputs = self.encoder(
embedding_output,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
deterministic=deterministic,
)
last_hidden_state = encoder_outputs[0]
pooled_output = self.pooler(
last_hidden_state,
window_shape=(last_hidden_state.shape[1], last_hidden_state.shape[2]),
strides=(last_hidden_state.shape[1], last_hidden_state.shape[2]),
).transpose(0, 3, 1, 2)
last_hidden_state = last_hidden_state.transpose(0, 3, 1, 2)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return FlaxBaseModelOutputWithPoolingAndNoAttention(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
)
@add_start_docstrings(
"The bare ResNet model outputting raw features without any specific head on top.",
RESNET_START_DOCSTRING,
)
class FlaxResNetModel(FlaxResNetPreTrainedModel):
module_class = FlaxResNetModule
FLAX_VISION_MODEL_DOCSTRING = """
Returns:
Examples:
```python
>>> from transformers import AutoImageProcessor, FlaxResNetModel
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
>>> model = FlaxResNetModel.from_pretrained("microsoft/resnet-50")
>>> inputs = image_processor(images=image, return_tensors="np")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
```
"""
overwrite_call_docstring(FlaxResNetModel, FLAX_VISION_MODEL_DOCSTRING)
append_replace_return_docstrings(
FlaxResNetModel, output_type=FlaxBaseModelOutputWithPoolingAndNoAttention, config_class=ResNetConfig
)
class FlaxResNetClassifierCollection(nn.Module):
config: ResNetConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype, name="1")
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
return self.classifier(x)
class FlaxResNetForImageClassificationModule(nn.Module):
config: ResNetConfig
dtype: jnp.dtype = jnp.float32
def setup(self):
self.resnet = FlaxResNetModule(config=self.config, dtype=self.dtype)
if self.config.num_labels > 0:
self.classifier = FlaxResNetClassifierCollection(self.config, dtype=self.dtype)
else:
self.classifier = Identity()
def __call__(
self,
pixel_values=None,
deterministic: bool = True,
output_hidden_states=None,
return_dict=None,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.resnet(
pixel_values,
deterministic=deterministic,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = outputs.pooler_output if return_dict else outputs[1]
logits = self.classifier(pooled_output[:, :, 0, 0])
if not return_dict:
output = (logits,) + outputs[2:]
return output
return FlaxImageClassifierOutputWithNoAttention(logits=logits, hidden_states=outputs.hidden_states)
@add_start_docstrings(
"""
ResNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
ImageNet.
""",
RESNET_START_DOCSTRING,
)
class FlaxResNetForImageClassification(FlaxResNetPreTrainedModel):
module_class = FlaxResNetForImageClassificationModule
FLAX_VISION_CLASSIF_DOCSTRING = """
Returns:
Example:
```python
>>> from transformers import AutoImageProcessor, FlaxResNetForImageClassification
>>> from PIL import Image
>>> import jax
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
>>> model = FlaxResNetForImageClassification.from_pretrained("microsoft/resnet-50")
>>> inputs = image_processor(images=image, return_tensors="np")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1)
>>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()])
```
"""
overwrite_call_docstring(FlaxResNetForImageClassification, FLAX_VISION_CLASSIF_DOCSTRING)
append_replace_return_docstrings(
FlaxResNetForImageClassification, output_type=FlaxImageClassifierOutputWithNoAttention, config_class=ResNetConfig
)
...@@ -881,6 +881,27 @@ class FlaxPegasusPreTrainedModel(metaclass=DummyObject): ...@@ -881,6 +881,27 @@ class FlaxPegasusPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["flax"]) requires_backends(self, ["flax"])
class FlaxResNetForImageClassification(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxResNetModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxResNetPreTrainedModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxRobertaForCausalLM(metaclass=DummyObject): class FlaxRobertaForCausalLM(metaclass=DummyObject):
_backends = ["flax"] _backends = ["flax"]
......
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import unittest
from transformers import ResNetConfig, is_flax_available
from transformers.testing_utils import require_flax, slow
from transformers.utils import cached_property, is_vision_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor
if is_flax_available():
import jax
import jax.numpy as jnp
from transformers.models.resnet.modeling_flax_resnet import FlaxResNetForImageClassification, FlaxResNetModel
if is_vision_available():
from PIL import Image
from transformers import AutoFeatureExtractor
class FlaxResNetModelTester(unittest.TestCase):
def __init__(
self,
parent,
batch_size=3,
image_size=32,
num_channels=3,
embeddings_size=10,
hidden_sizes=[10, 20, 30, 40],
depths=[1, 1, 2, 1],
is_training=True,
use_labels=True,
hidden_act="relu",
num_labels=3,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.num_channels = num_channels
self.embeddings_size = embeddings_size
self.hidden_sizes = hidden_sizes
self.depths = depths
self.is_training = is_training
self.use_labels = use_labels
self.hidden_act = hidden_act
self.num_labels = num_labels
self.scope = scope
self.num_stages = len(hidden_sizes)
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
config = self.get_config()
return config, pixel_values
def get_config(self):
return ResNetConfig(
num_channels=self.num_channels,
embeddings_size=self.embeddings_size,
hidden_sizes=self.hidden_sizes,
depths=self.depths,
hidden_act=self.hidden_act,
num_labels=self.num_labels,
image_size=self.image_size,
)
def create_and_check_model(self, config, pixel_values):
model = FlaxResNetModel(config=config)
result = model(pixel_values)
# Output shape (b, c, h, w)
self.parent.assertEqual(
result.last_hidden_state.shape,
(self.batch_size, self.hidden_sizes[-1], self.image_size // 32, self.image_size // 32),
)
def create_and_check_for_image_classification(self, config, pixel_values):
config.num_labels = self.num_labels
model = FlaxResNetForImageClassification(config=config)
result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_flax
class FlaxResNetModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (FlaxResNetModel, FlaxResNetForImageClassification) if is_flax_available() else ()
is_encoder_decoder = False
test_head_masking = False
has_attentions = False
def setUp(self) -> None:
self.model_tester = FlaxResNetModelTester(self)
self.config_tester = ConfigTester(self, config_class=ResNetConfig, has_text_modality=False)
def test_config(self):
self.create_and_test_config_common_properties()
self.config_tester.create_and_test_config_to_json_string()
self.config_tester.create_and_test_config_to_json_file()
self.config_tester.create_and_test_config_from_and_save_pretrained()
self.config_tester.create_and_test_config_with_num_labels()
self.config_tester.check_config_can_be_init_without_params()
self.config_tester.check_config_arguments_init()
def create_and_test_config_common_properties(self):
return
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
@unittest.skip(reason="ResNet does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="ResNet does not support input and output embeddings")
def test_model_common_attributes(self):
pass
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.__call__)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
expected_num_stages = self.model_tester.num_stages
self.assertEqual(len(hidden_states), expected_num_stages + 1)
@unittest.skip(reason="ResNet does not use feedforward chunking")
def test_feed_forward_chunking(self):
pass
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
@jax.jit
def model_jitted(pixel_values, **kwargs):
return model(pixel_values=pixel_values, **kwargs)
with self.subTest("JIT Enabled"):
jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()
with self.subTest("JIT Disabled"):
with jax.disable_jit():
outputs = model_jitted(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs, outputs):
self.assertEqual(jitted_output.shape, output.shape)
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_flax
class FlaxResNetModelIntegrationTest(unittest.TestCase):
@cached_property
def default_feature_extractor(self):
return AutoFeatureExtractor.from_pretrained("microsoft/resnet-50") if is_vision_available() else None
@slow
def test_inference_image_classification_head(self):
model = FlaxResNetForImageClassification.from_pretrained("microsoft/resnet-50")
feature_extractor = self.default_feature_extractor
image = prepare_img()
inputs = feature_extractor(images=image, return_tensors="np")
outputs = model(**inputs)
# verify the logits
expected_shape = (1, 1000)
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = jnp.array([-11.1069, -9.7877, -8.3777])
self.assertTrue(jnp.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment