Unverified Commit a7eba831 authored by Aritra Roy Gosthipaty's avatar Aritra Roy Gosthipaty Committed by GitHub
Browse files

TF implementation of RegNets (#17554)



* chore: initial commit

Copied the torch implementation of regnets and porting the code to tf step by step. Also introduced an output layer which was needed for regnets.

* chore: porting the rest of the modules to tensorflow

did not change the documentation yet, yet to try the playground on the model

* Fix initilizations (#1)

* fix: code structure in few cases.

* fix: code structure to align tf models.

* fix: layer naming, bn layer still remains.

* chore: change default epsilon and momentum in bn.

* chore: styling nits.

* fix: cross-loading bn params.

* fix: regnet tf model, integration passing.

* add: tests for TF regnet.

* fix: code quality related issues.

* chore: added rest of the files.

* minor additions..

* fix: repo consistency.

* fix: regnet tf tests.

* chore: reorganize dummy_tf_objects for regnet.

* chore: remove checkpoint var.

* chore: remov unnecessary files.

* chore: run make style.

* Update docs/source/en/model_doc/regnet.mdx
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* chore: PR feedback I.

* fix: pt test. thanks to @ydshieh.

* New adaptive pooler (#3)

* feat: new adaptive pooler

Co-authored-by: @Rocketknight1

* chore: remove image_size argument.
Co-authored-by: default avatarmatt <rocketknight1@gmail.com>
Co-authored-by: default avatarmatt <rocketknight1@gmail.com>

* Empty-Commit

* chore: remove image_size comment.

* chore: remove playground_tf.py

* chore: minor changes related to spacing.

* chore: make style.

* Update src/transformers/models/regnet/modeling_tf_regnet.py
Co-authored-by: default avataramyeroberts <aeroberts4444@gmail.com>

* Update src/transformers/models/regnet/modeling_tf_regnet.py
Co-authored-by: default avataramyeroberts <aeroberts4444@gmail.com>

* chore: refactored __init__.

* chore: copied from -> taken from./g

* adaptive pool -> global avg pool, channel check.

* chore: move channel check to stem.

* pr comments - minor refactor and add regnets to doc tests.

* Update src/transformers/models/regnet/modeling_tf_regnet.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* minor fix in the xlayer.

* Empty-Commit

* chore: removed from_pt=True.
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarmatt <rocketknight1@gmail.com>
Co-authored-by: default avataramyeroberts <aeroberts4444@gmail.com>
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>
parent e6d27ca5
......@@ -267,7 +267,7 @@ Flax), PyTorch, and/or TensorFlow.
| RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
| REALM | ✅ | ✅ | ✅ | ❌ | ❌ |
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | | ❌ |
| RegNet | ❌ | ❌ | ✅ | | ❌ |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ResNet | ❌ | ❌ | ✅ | ❌ | ❌ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
......
......@@ -27,7 +27,8 @@ Tips:
- One can use [`AutoFeatureExtractor`] to prepare images for the model.
- The huge 10B model from [Self-supervised Pretraining of Visual Features in the Wild](https://arxiv.org/abs/2103.01988), trained on one billion Instagram images, is available on the [hub](https://huggingface.co/facebook/regnet-y-10b-seer)
This model was contributed by [Francesco](https://huggingface.co/Francesco).
This model was contributed by [Francesco](https://huggingface.co/Francesco). The TensorFlow version of the model
was contributed by [sayakpaul](https://huggingface.com/sayakpaul) and [ariG23498](https://huggingface.com/ariG23498).
The original code can be found [here](https://github.com/facebookresearch/pycls).
......@@ -46,3 +47,14 @@ The original code can be found [here](https://github.com/facebookresearch/pycls)
[[autodoc]] RegNetForImageClassification
- forward
## TFRegNetModel
[[autodoc]] TFRegNetModel
- call
## TFRegNetForImageClassification
[[autodoc]] TFRegNetForImageClassification
- call
\ No newline at end of file
......@@ -2334,6 +2334,14 @@ else:
"TFRagTokenForGeneration",
]
)
_import_structure["models.regnet"].extend(
[
"TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFRegNetForImageClassification",
"TFRegNetModel",
"TFRegNetPreTrainedModel",
]
)
_import_structure["models.rembert"].extend(
[
"TF_REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
......@@ -4649,6 +4657,12 @@ if TYPE_CHECKING:
from .models.opt import TFOPTForCausalLM, TFOPTModel, TFOPTPreTrainedModel
from .models.pegasus import TFPegasusForConditionalGeneration, TFPegasusModel, TFPegasusPreTrainedModel
from .models.rag import TFRagModel, TFRagPreTrainedModel, TFRagSequenceForGeneration, TFRagTokenForGeneration
from .models.regnet import (
TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRegNetForImageClassification,
TFRegNetModel,
TFRegNetPreTrainedModel,
)
from .models.rembert import (
TF_REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRemBertForCausalLM,
......
......@@ -46,6 +46,25 @@ class TFBaseModelOutput(ModelOutput):
attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFBaseModelOutputWithNoAttention(ModelOutput):
"""
Base class for model's outputs, with potential hidden states.
Args:
last_hidden_state (`tf.Tensor` shape `(batch_size, num_channels, height, width)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, num_channels, height, width)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
"""
last_hidden_state: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFBaseModelOutputWithPooling(ModelOutput):
"""
......@@ -80,6 +99,28 @@ class TFBaseModelOutputWithPooling(ModelOutput):
attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFBaseModelOutputWithPoolingAndNoAttention(ModelOutput):
"""
Base class for model's outputs that also contains a pooling of the last hidden states.
Args:
last_hidden_state (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`):
Last layer hidden-state after a pooling operation on the spatial dimensions.
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
the output of each layer) of shape `(batch_size, num_channels, height, width)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
"""
last_hidden_state: tf.Tensor = None
pooler_output: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
"""
......@@ -825,3 +866,24 @@ class TFSequenceClassifierOutputWithPast(ModelOutput):
past_key_values: Optional[List[tf.Tensor]] = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFImageClassifierOutputWithNoAttention(ModelOutput):
"""
Base class for outputs of image classification models.
Args:
loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also called
feature maps) of the model at the output of each stage.
"""
loss: Optional[tf.Tensor] = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
......@@ -62,6 +62,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
("openai-gpt", "TFOpenAIGPTModel"),
("opt", "TFOPTModel"),
("pegasus", "TFPegasusModel"),
("regnet", "TFRegNetModel"),
("rembert", "TFRemBertModel"),
("roberta", "TFRobertaModel"),
("roformer", "TFRoFormerModel"),
......@@ -173,6 +174,7 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
# Model for Image-classsification
("convnext", "TFConvNextForImageClassification"),
("data2vec-vision", "TFData2VecVisionForImageClassification"),
("regnet", "TFRegNetForImageClassification"),
("swin", "TFSwinForImageClassification"),
("vit", "TFViTForImageClassification"),
]
......
......@@ -18,8 +18,7 @@
from typing import TYPE_CHECKING
# rely on isort to merge the imports
from ...file_utils import _LazyModule, is_torch_available
from ...utils import OptionalDependencyNotAvailable
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
_import_structure = {"configuration_regnet": ["REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "RegNetConfig"]}
......@@ -37,6 +36,19 @@ else:
"RegNetPreTrainedModel",
]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_regnet"] = [
"TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFRegNetForImageClassification",
"TFRegNetModel",
"TFRegNetPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_regnet import REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP, RegNetConfig
......@@ -54,6 +66,19 @@ if TYPE_CHECKING:
RegNetPreTrainedModel,
)
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_regnet import (
TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRegNetForImageClassification,
TFRegNetModel,
TFRegNetPreTrainedModel,
)
else:
import sys
......
# coding=utf-8
# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TensorFlow RegNet model."""
from typing import Dict, Optional, Tuple, Union
import tensorflow as tf
from ...activations_tf import ACT2FN
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_tf_outputs import (
TFBaseModelOutputWithNoAttention,
TFBaseModelOutputWithPoolingAndNoAttention,
TFSequenceClassifierOutput,
)
from ...modeling_tf_utils import TFPreTrainedModel, TFSequenceClassificationLoss, keras_serializable, unpack_inputs
from ...tf_utils import shape_list
from ...utils import logging
from .configuration_regnet import RegNetConfig
logger = logging.get_logger(__name__)
# General docstring
_CONFIG_FOR_DOC = "RegNetConfig"
_FEAT_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor"
# Base docstring
_CHECKPOINT_FOR_DOC = "facebook/regnet-y-040"
_EXPECTED_OUTPUT_SHAPE = [1, 1088, 7, 7]
# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "facebook/regnet-y-040"
_IMAGE_CLASS_EXPECTED_OUTPUT = "'tabby, tabby cat'"
TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/regnet-y-040",
# See all regnet models at https://huggingface.co/models?filter=regnet
]
class TFRegNetConvLayer(tf.keras.layers.Layer):
def __init__(
self,
out_channels: int,
kernel_size: int = 3,
stride: int = 1,
groups: int = 1,
activation: Optional[str] = "relu",
**kwargs,
):
super().__init__(**kwargs)
# The padding and conv has been verified in
# https://colab.research.google.com/gist/sayakpaul/854bc10eeaf21c9ee2119e0b9f3841a7/scratchpad.ipynb
self.padding = tf.keras.layers.ZeroPadding2D(padding=kernel_size // 2)
self.convolution = tf.keras.layers.Conv2D(
filters=out_channels,
kernel_size=kernel_size,
strides=stride,
padding="VALID",
groups=groups,
use_bias=False,
name="convolution",
)
self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="normalization")
self.activation = ACT2FN[activation] if activation is not None else tf.identity
def call(self, hidden_state):
hidden_state = self.convolution(self.padding(hidden_state))
hidden_state = self.normalization(hidden_state)
hidden_state = self.activation(hidden_state)
return hidden_state
class TFRegNetEmbeddings(tf.keras.layers.Layer):
"""
RegNet Embeddings (stem) composed of a single aggressive convolution.
"""
def __init__(self, config: RegNetConfig, **kwargs):
super().__init__(**kwargs)
self.num_channels = config.num_channels
self.embedder = TFRegNetConvLayer(
out_channels=config.embedding_size,
kernel_size=3,
stride=2,
activation=config.hidden_act,
name="embedder",
)
def call(self, pixel_values):
num_channels = shape_list(pixel_values)[1]
if tf.executing_eagerly() and num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
# When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format.
# So change the input format from `NCHW` to `NHWC`.
# shape = (batch_size, in_height, in_width, in_channels=num_channels)
pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
hidden_state = self.embedder(pixel_values)
return hidden_state
class TFRegNetShortCut(tf.keras.layers.Layer):
"""
RegNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
downsample the input using `stride=2`.
"""
def __init__(self, out_channels: int, stride: int = 2, **kwargs):
super().__init__(**kwargs)
self.convolution = tf.keras.layers.Conv2D(
filters=out_channels, kernel_size=1, strides=stride, use_bias=False, name="convolution"
)
self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="normalization")
def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
return self.normalization(self.convolution(inputs), training=training)
class TFRegNetSELayer(tf.keras.layers.Layer):
"""
Squeeze and Excitation layer (SE) proposed in [Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507).
"""
def __init__(self, in_channels: int, reduced_channels: int, **kwargs):
super().__init__(**kwargs)
self.pooler = tf.keras.layers.GlobalAveragePooling2D(keepdims=True, name="pooler")
self.attention = [
tf.keras.layers.Conv2D(filters=reduced_channels, kernel_size=1, activation="relu", name="attention.0"),
tf.keras.layers.Conv2D(filters=in_channels, kernel_size=1, activation="sigmoid", name="attention.2"),
]
def call(self, hidden_state):
# [batch_size, h, w, num_channels] -> [batch_size, 1, 1, num_channels]
pooled = self.pooler(hidden_state)
for layer_module in self.attention:
pooled = layer_module(pooled)
hidden_state = hidden_state * pooled
return hidden_state
class TFRegNetXLayer(tf.keras.layers.Layer):
"""
RegNet's layer composed by three `3x3` convolutions, same as a ResNet bottleneck layer with reduction = 1.
"""
def __init__(self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 1, **kwargs):
super().__init__(**kwargs)
should_apply_shortcut = in_channels != out_channels or stride != 1
groups = max(1, out_channels // config.groups_width)
self.shortcut = (
TFRegNetShortCut(out_channels, stride=stride, name="shortcut")
if should_apply_shortcut
else tf.keras.layers.Activation("linear", name="shortcut")
)
# `self.layers` instead of `self.layer` because that is a reserved argument.
self.layers = [
TFRegNetConvLayer(out_channels, kernel_size=1, activation=config.hidden_act, name="layer.0"),
TFRegNetConvLayer(
out_channels, stride=stride, groups=groups, activation=config.hidden_act, name="layer.1"
),
TFRegNetConvLayer(out_channels, kernel_size=1, activation=None, name="layer.2"),
]
self.activation = ACT2FN[config.hidden_act]
def call(self, hidden_state):
residual = hidden_state
for layer_module in self.layers:
hidden_state = layer_module(hidden_state)
residual = self.shortcut(residual)
hidden_state += residual
hidden_state = self.activation(hidden_state)
return hidden_state
class TFRegNetYLayer(tf.keras.layers.Layer):
"""
RegNet's Y layer: an X layer with Squeeze and Excitation.
"""
def __init__(self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 1, **kwargs):
super().__init__(**kwargs)
should_apply_shortcut = in_channels != out_channels or stride != 1
groups = max(1, out_channels // config.groups_width)
self.shortcut = (
TFRegNetShortCut(out_channels, stride=stride, name="shortcut")
if should_apply_shortcut
else tf.keras.layers.Activation("linear", name="shortcut")
)
self.layers = [
TFRegNetConvLayer(out_channels, kernel_size=1, activation=config.hidden_act, name="layer.0"),
TFRegNetConvLayer(
out_channels, stride=stride, groups=groups, activation=config.hidden_act, name="layer.1"
),
TFRegNetSELayer(out_channels, reduced_channels=int(round(in_channels / 4)), name="layer.2"),
TFRegNetConvLayer(out_channels, kernel_size=1, activation=None, name="layer.3"),
]
self.activation = ACT2FN[config.hidden_act]
def call(self, hidden_state):
residual = hidden_state
for layer_module in self.layers:
hidden_state = layer_module(hidden_state)
residual = self.shortcut(residual)
hidden_state += residual
hidden_state = self.activation(hidden_state)
return hidden_state
class TFRegNetStage(tf.keras.layers.Layer):
"""
A RegNet stage composed by stacked layers.
"""
def __init__(
self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 2, depth: int = 2, **kwargs
):
super().__init__(**kwargs)
layer = TFRegNetXLayer if config.layer_type == "x" else TFRegNetYLayer
self.layers = [
# downsampling is done in the first layer with stride of 2
layer(config, in_channels, out_channels, stride=stride, name="layers.0"),
*[layer(config, out_channels, out_channels, name=f"layers.{i+1}") for i in range(depth - 1)],
]
def call(self, hidden_state):
for layer_module in self.layers:
hidden_state = layer_module(hidden_state)
return hidden_state
class TFRegNetEncoder(tf.keras.layers.Layer):
def __init__(self, config: RegNetConfig, **kwargs):
super().__init__(**kwargs)
self.stages = list()
# based on `downsample_in_first_stage`, the first layer of the first stage may or may not downsample the input
self.stages.append(
TFRegNetStage(
config,
config.embedding_size,
config.hidden_sizes[0],
stride=2 if config.downsample_in_first_stage else 1,
depth=config.depths[0],
name="stages.0",
)
)
in_out_channels = zip(config.hidden_sizes, config.hidden_sizes[1:])
for i, ((in_channels, out_channels), depth) in enumerate(zip(in_out_channels, config.depths[1:])):
self.stages.append(TFRegNetStage(config, in_channels, out_channels, depth=depth, name=f"stages.{i+1}"))
def call(
self, hidden_state: tf.Tensor, output_hidden_states: bool = False, return_dict: bool = True
) -> TFBaseModelOutputWithNoAttention:
hidden_states = () if output_hidden_states else None
for stage_module in self.stages:
if output_hidden_states:
hidden_states = hidden_states + (hidden_state,)
hidden_state = stage_module(hidden_state)
if output_hidden_states:
hidden_states = hidden_states + (hidden_state,)
if not return_dict:
return tuple(v for v in [hidden_state, hidden_states] if v is not None)
return TFBaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states)
@keras_serializable
class TFRegNetMainLayer(tf.keras.layers.Layer):
config_class = RegNetConfig
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.config = config
self.embedder = TFRegNetEmbeddings(config, name="embedder")
self.encoder = TFRegNetEncoder(config, name="encoder")
self.pooler = tf.keras.layers.GlobalAveragePooling2D(keepdims=True, name="pooler")
@unpack_inputs
def call(
self,
pixel_values: tf.Tensor,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
) -> TFBaseModelOutputWithPoolingAndNoAttention:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
embedding_output = self.embedder(pixel_values, training=training)
encoder_outputs = self.encoder(
embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training
)
last_hidden_state = encoder_outputs[0]
pooled_output = self.pooler(last_hidden_state)
# Change to NCHW output format have uniformity in the modules
pooled_output = tf.transpose(pooled_output, perm=(0, 3, 1, 2))
last_hidden_state = tf.transpose(last_hidden_state, perm=(0, 3, 1, 2))
# Change the other hidden state outputs to NCHW as well
if output_hidden_states:
hidden_states = tuple([tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]])
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return TFBaseModelOutputWithPoolingAndNoAttention(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states,
)
class TFRegNetPreTrainedModel(TFPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = RegNetConfig
base_model_prefix = "regnet"
main_input_name = "pixel_values"
@property
def dummy_inputs(self) -> Dict[str, tf.Tensor]:
"""
Dummy inputs to build the network.
Returns:
`Dict[str, tf.Tensor]`: The dummy inputs.
"""
VISION_DUMMY_INPUTS = tf.random.uniform(shape=(3, self.config.num_channels, 224, 224), dtype=tf.float32)
return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)}
@tf.function(
input_signature=[
{
"pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"),
}
]
)
def serving(self, inputs):
"""
Method used for serving the model.
Args:
inputs (`Dict[str, tf.Tensor]`):
The input of the saved model as a dictionary of tensors.
"""
return self.call(inputs)
REGNET_START_DOCSTRING = r"""
Parameters:
This model is a Tensorflow
[tf.keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer) sub-class. Use it as a
regular Tensorflow Module and refer to the Tensorflow documentation for all matter related to general usage and
behavior.
config ([`RegNetConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""
REGNET_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See
[`AutoFeatureExtractor.__call__`] for details.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"The bare RegNet model outputting raw features without any specific head on top.",
REGNET_START_DOCSTRING,
)
class TFRegNetModel(TFRegNetPreTrainedModel):
def __init__(self, config: RegNetConfig, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.regnet = TFRegNetMainLayer(config, name="regnet")
@unpack_inputs
@add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFBaseModelOutputWithPoolingAndNoAttention,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def call(
self,
pixel_values: tf.Tensor,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training=False,
) -> Union[TFBaseModelOutputWithPoolingAndNoAttention, Tuple[tf.Tensor]]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.regnet(
pixel_values=pixel_values,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
if not return_dict:
return (outputs[0],) + outputs[1:]
return TFBaseModelOutputWithPoolingAndNoAttention(
last_hidden_state=outputs.last_hidden_state,
pooler_output=outputs.pooler_output,
hidden_states=outputs.hidden_states,
)
@add_start_docstrings(
"""
RegNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
ImageNet.
""",
REGNET_START_DOCSTRING,
)
class TFRegNetForImageClassification(TFRegNetPreTrainedModel, TFSequenceClassificationLoss):
def __init__(self, config: RegNetConfig, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.regnet = TFRegNetMainLayer(config, name="regnet")
# classification head
self.classifier = [
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(config.num_labels, name="classifier.1") if config.num_labels > 0 else tf.identity,
]
@unpack_inputs
@add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=TFSequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
def call(
self,
pixel_values: tf.Tensor = None,
labels: tf.Tensor = None,
output_hidden_states: bool = None,
return_dict: bool = None,
training=False,
) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
r"""
labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.regnet(
pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training
)
pooled_output = outputs.pooler_output if return_dict else outputs[1]
flattened_output = self.classifier[0](pooled_output)
logits = self.classifier[1](flattened_output)
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TFSequenceClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
......@@ -1696,6 +1696,30 @@ class TFRagTokenForGeneration(metaclass=DummyObject):
requires_backends(self, ["tf"])
TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST = None
class TFRegNetForImageClassification(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFRegNetModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFRegNetPreTrainedModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
TF_REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the TensorFlow RegNet model. """
import inspect
import unittest
from typing import List, Tuple
from transformers import RegNetConfig
from transformers.testing_utils import require_tf, require_vision, slow
from transformers.utils import cached_property, is_tf_available, is_vision_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
if is_tf_available():
import tensorflow as tf
from transformers import TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST, TFRegNetForImageClassification, TFRegNetModel
if is_vision_available():
from PIL import Image
from transformers import AutoFeatureExtractor
class TFRegNetModelTester:
def __init__(
self,
parent,
batch_size=3,
image_size=32,
num_channels=3,
embeddings_size=10,
hidden_sizes=[10, 20, 30, 40],
depths=[1, 1, 2, 1],
is_training=True,
use_labels=True,
hidden_act="relu",
num_labels=3,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.num_channels = num_channels
self.embeddings_size = embeddings_size
self.hidden_sizes = hidden_sizes
self.depths = depths
self.is_training = is_training
self.use_labels = use_labels
self.hidden_act = hidden_act
self.num_labels = num_labels
self.scope = scope
self.num_stages = len(hidden_sizes)
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
labels = None
if self.use_labels:
labels = ids_tensor([self.batch_size], self.num_labels)
config = self.get_config()
return config, pixel_values, labels
def get_config(self):
return RegNetConfig(
num_channels=self.num_channels,
embeddings_size=self.embeddings_size,
hidden_sizes=self.hidden_sizes,
depths=self.depths,
hidden_act=self.hidden_act,
num_labels=self.num_labels,
)
def create_and_check_model(self, config, pixel_values, labels):
model = TFRegNetModel(config=config)
result = model(pixel_values, training=False)
# expected last hidden states: B, C, H // 32, W // 32
self.parent.assertEqual(
result.last_hidden_state.shape,
(self.batch_size, self.hidden_sizes[-1], self.image_size // 32, self.image_size // 32),
)
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.num_labels
model = TFRegNetForImageClassification(config)
result = model(pixel_values, labels=labels, training=False)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_tf
class TFRegNetModelTest(TFModelTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_common.py, as RegNet does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
all_model_classes = (TFRegNetModel, TFRegNetForImageClassification) if is_tf_available() else ()
test_pruning = False
test_onnx = False
test_resize_embeddings = False
test_head_masking = False
has_attentions = False
def setUp(self):
self.model_tester = TFRegNetModelTester(self)
self.config_tester = ConfigTester(self, config_class=RegNetConfig, has_text_modality=False)
def create_and_test_config_common_properties(self):
return
@unittest.skip(reason="RegNet does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skipIf(
not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0,
reason="TF (<=2.8) does not support backprop for grouped convolutions on CPU.",
)
def test_keras_fit(self):
pass
@unittest.skip(reason="RegNet does not support input and output embeddings")
def test_model_common_attributes(self):
pass
@unittest.skip(reason="Model doesn't have attention layers")
def test_attention_outputs(self):
pass
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.call)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
expected_num_stages = self.model_tester.num_stages
self.assertEqual(len(hidden_states), expected_num_stages + 1)
# RegNet's feature maps are of shape (batch_size, num_channels, height, width)
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[self.model_tester.image_size // 2, self.model_tester.image_size // 2],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
layers_type = ["basic", "bottleneck"]
for model_class in self.all_model_classes:
for layer_type in layers_type:
config.layer_type = layer_type
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
# Since RegNet does not have any attention we need to rewrite this test.
def test_model_outputs_equivalence(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
tuple_output = model(tuple_inputs, return_dict=False, **additional_kwargs)
dict_output = model(dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
self.assertTrue(
all(tf.equal(tuple_object, dict_object)),
msg=(
"Tuple and dict output are not equal. Difference:"
f" {tf.math.reduce_max(tf.abs(tuple_object - dict_object))}"
),
)
recursive_check(tuple_output, dict_output)
for model_class in self.all_model_classes:
model = model_class(config)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
for model_name in TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = TFRegNetModel.from_pretrained(model_name)
self.assertIsNotNone(model)
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_tf
@require_vision
class RegNetModelIntegrationTest(unittest.TestCase):
@cached_property
def default_feature_extractor(self):
return (
AutoFeatureExtractor.from_pretrained(TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST[0])
if is_vision_available()
else None
)
@slow
def test_inference_image_classification_head(self):
model = TFRegNetForImageClassification.from_pretrained(TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST[0])
feature_extractor = self.default_feature_extractor
image = prepare_img()
inputs = feature_extractor(images=image, return_tensors="tf")
# forward pass
outputs = model(**inputs, training=False)
# verify the logits
expected_shape = tf.TensorShape((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = tf.constant([-0.4180, -1.5051, -3.4836])
tf.debugging.assert_near(outputs.logits[0, :3], expected_slice, atol=1e-4)
......@@ -50,6 +50,8 @@ src/transformers/models/pegasus/modeling_pegasus.py
src/transformers/models/plbart/modeling_plbart.py
src/transformers/models/poolformer/modeling_poolformer.py
src/transformers/models/reformer/modeling_reformer.py
src/transformers/models/regnet/modeling_regnet.py
src/transformers/models/regnet/modeling_tf_regnet.py
src/transformers/models/resnet/modeling_resnet.py
src/transformers/models/roberta/modeling_roberta.py
src/transformers/models/roberta/modeling_tf_roberta.py
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment