Unverified Commit 77ea5130 authored by amyeroberts, committed by GitHub

Add TF ResNet model (#17427)



* Rough TF conversion outline

* Tidy up

* Fix padding differences between layers

* Add back embedder - whoops

* Match test file to main

* Match upstream test file

* Correctly pass and assign image_size parameter
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Add in MainLayer

* Correctly name layer

* Tidy up AdaptivePooler

* Small tidy-up

More accurate type hints and removed whitespace

* Change AdaptiveAvgPool

Use the AdaptiveAvgPool implementation by @Rocketknight1, which pools correctly even when the input shape is not evenly divisible by the output shape (a minimal sketch of the idea follows this bullet), c.f. https://github.com/huggingface/transformers/pull/17554/files/9e26607e22aa8d069c86b50196656012ff0ce62a#r900109509

Co-authored-by: matt <rocketknight1@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
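
Not the implementation merged here, just a minimal 1D sketch (with a hypothetical helper name) of PyTorch-style adaptive average pooling, to illustrate why bins must be allowed to overlap when the input length is not evenly divisible by the output size:

```python
import tensorflow as tf

def adaptive_avg_pool_1d(x: tf.Tensor, output_size: int) -> tf.Tensor:
    """Pool (batch, length, channels) down to (batch, output_size, channels)."""
    length = x.shape[1]
    pooled = []
    for i in range(output_size):
        # PyTorch bin boundaries: floor(i * L / out) to ceil((i + 1) * L / out),
        # so neighbouring bins overlap when L is not evenly divisible by out.
        start = (i * length) // output_size
        end = -((-(i + 1) * length) // output_size)  # ceil division
        pooled.append(tf.reduce_mean(x[:, start:end, :], axis=1))
    return tf.stack(pooled, axis=1)
```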

* Use updated AdaptiveAvgPool
Co-authored-by: matt <rocketknight1@gmail.com>

* Make AdaptiveAvgPool compatible with CPU

* Remove image_size from configuration

* Fixup

* Tensorflow -> TensorFlow

* Fix pt references in tests

* Apply suggestions from code review - grammar and wording
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Add TFResNet to doc tests

* PR comments - GlobalAveragePooling and clearer comments

* Remove unused import

* Add in keepdims argument

* Add num_channels check

* grammar fix: by -> of
Co-authored-by: matt <rocketknight1@gmail.com>
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Remove transposes - keep NHWC throughout forward pass

* Fixup look sharp

* Add missing layer names

* Final tidy-up - remove from_pt now that weights are on hub
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: matt <rocketknight1@gmail.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
parent 7b18702c
@@ -273,7 +273,7 @@ Flax), PyTorch, and/or TensorFlow.
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | ✅ | ❌ |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ResNet | ❌ | ❌ | ✅ | ✅ | ❌ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
| RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ |
| RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ |
@@ -31,7 +31,7 @@ The figure below illustrates the architecture of ResNet. Taken from the [origina
<img width="600" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/resnet_architecture.png"/>

This model was contributed by [Francesco](https://huggingface.co/Francesco). The TensorFlow version of this model was added by [amyeroberts](https://huggingface.co/amyeroberts). The original code can be found [here](https://github.com/KaimingHe/deep-residual-networks).

## ResNetConfig

@@ -47,4 +47,16 @@ This model was contributed by [Francesco](https://huggingface.co/Francesco). The

## ResNetForImageClassification

[[autodoc]] ResNetForImageClassification
- forward
\ No newline at end of file
## TFResNetModel
[[autodoc]] TFResNetModel
- call
## TFResNetForImageClassification
[[autodoc]] TFResNetForImageClassification
- call
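
A minimal usage sketch for the new TF classes (mirroring the integration test added in this PR; `microsoft/resnet-50` is the checkpoint used throughout):

```python
from PIL import Image
import tensorflow as tf
from transformers import AutoFeatureExtractor, TFResNetForImageClassification

feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-50")
model = TFResNetForImageClassification.from_pretrained("microsoft/resnet-50")

image = Image.open("cats.png")  # any RGB image
inputs = feature_extractor(images=image, return_tensors="tf")

logits = model(**inputs).logits  # shape (1, 1000): the ImageNet classes
predicted_label = int(tf.math.argmax(logits, axis=-1)[0])
print(model.config.id2label[predicted_label])
```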
@@ -2380,6 +2380,14 @@ else:
"TFRemBertPreTrainedModel",
]
)
_import_structure["models.resnet"].extend(
[
"TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFResNetForImageClassification",
"TFResNetModel",
"TFResNetPreTrainedModel",
]
)
_import_structure["models.roberta"].extend( _import_structure["models.roberta"].extend(
[ [
"TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST", "TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -4721,6 +4729,12 @@ if TYPE_CHECKING:
TFRemBertModel,
TFRemBertPreTrainedModel,
)
from .models.resnet import (
TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST,
TFResNetForImageClassification,
TFResNetModel,
TFResNetPreTrainedModel,
)
from .models.roberta import (
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRobertaForCausalLM,
@@ -62,7 +62,7 @@ class TFBaseModelOutputWithNoAttention(ModelOutput):
"""
last_hidden_state: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor, ...]] = None
@dataclass
@@ -118,7 +118,7 @@ class TFBaseModelOutputWithPoolingAndNoAttention(ModelOutput):
last_hidden_state: tf.Tensor = None
pooler_output: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor, ...]] = None
@dataclass
@@ -886,4 +886,4 @@ class TFImageClassifierOutputWithNoAttention(ModelOutput):
loss: Optional[tf.Tensor] = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor, ...]] = None
@@ -64,6 +64,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
("pegasus", "TFPegasusModel"),
("regnet", "TFRegNetModel"),
("rembert", "TFRemBertModel"),
("resnet", "TFResNetModel"),
("roberta", "TFRobertaModel"), ("roberta", "TFRobertaModel"),
("roformer", "TFRoFormerModel"), ("roformer", "TFRoFormerModel"),
("speech_to_text", "TFSpeech2TextModel"), ("speech_to_text", "TFSpeech2TextModel"),
...@@ -175,6 +176,7 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ...@@ -175,6 +176,7 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("convnext", "TFConvNextForImageClassification"), ("convnext", "TFConvNextForImageClassification"),
("data2vec-vision", "TFData2VecVisionForImageClassification"), ("data2vec-vision", "TFData2VecVisionForImageClassification"),
("regnet", "TFRegNetForImageClassification"), ("regnet", "TFRegNetForImageClassification"),
("resnet", "TFResNetForImageClassification"),
("swin", "TFSwinForImageClassification"), ("swin", "TFSwinForImageClassification"),
("vit", "TFViTForImageClassification"), ("vit", "TFViTForImageClassification"),
] ]
@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING

# rely on isort to merge the imports
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available

_import_structure = {
@@ -38,6 +38,19 @@ else:
"ResNetPreTrainedModel",
]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_resnet"] = [
"TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFResNetForImageClassification",
"TFResNetModel",
"TFResNetPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_resnet import RESNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ResNetConfig, ResNetOnnxConfig
@@ -55,6 +68,19 @@ if TYPE_CHECKING:
ResNetPreTrainedModel,
)
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_resnet import (
TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST,
TFResNetForImageClassification,
TFResNetModel,
TFResNetPreTrainedModel,
)
else:
import sys
# coding=utf-8
# Copyright 2022 Microsoft Research, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TensorFlow ResNet model."""
from typing import Dict, Optional, Tuple, Union
import tensorflow as tf
from ...activations_tf import ACT2FN
from ...modeling_tf_outputs import (
TFBaseModelOutputWithNoAttention,
TFBaseModelOutputWithPoolingAndNoAttention,
TFImageClassifierOutputWithNoAttention,
)
from ...modeling_tf_utils import TFPreTrainedModel, TFSequenceClassificationLoss, keras_serializable, unpack_inputs
from ...tf_utils import shape_list
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_resnet import ResNetConfig
logger = logging.get_logger(__name__)
# General docstring
_CONFIG_FOR_DOC = "ResNetConfig"
_FEAT_EXTRACTOR_FOR_DOC = "AutoFeatureExtractor"
# Base docstring
_CHECKPOINT_FOR_DOC = "microsoft/resnet-50"
_EXPECTED_OUTPUT_SHAPE = [1, 2048, 7, 7]
# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "microsoft/resnet-50"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tiger cat"
TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
"microsoft/resnet-50",
# See all resnet models at https://huggingface.co/models?filter=resnet
]
class TFResNetConvLayer(tf.keras.layers.Layer):
def __init__(
self, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: str = "relu", **kwargs
) -> None:
super().__init__(**kwargs)
self.pad_value = kernel_size // 2
self.conv = tf.keras.layers.Conv2D(
out_channels, kernel_size=kernel_size, strides=stride, padding="valid", use_bias=False, name="convolution"
)
# Use same default momentum and epsilon as PyTorch equivalent
self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="normalization")
self.activation = ACT2FN[activation] if activation is not None else tf.keras.layers.Activation("linear")
def convolution(self, hidden_state: tf.Tensor) -> tf.Tensor:
# Pad to match that done in the PyTorch Conv2D model
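# TF's padding="same" can pad asymmetrically (the extra pixel goes to the
# bottom/right) when stride > 1, whereas PyTorch always pads kernel_size // 2
# on every side, so we pad explicitly and convolve with padding="valid".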
height_pad = width_pad = (self.pad_value, self.pad_value)
hidden_state = tf.pad(hidden_state, [(0, 0), height_pad, width_pad, (0, 0)])
hidden_state = self.conv(hidden_state)
return hidden_state
def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_state = self.convolution(hidden_state)
hidden_state = self.normalization(hidden_state, training=training)
hidden_state = self.activation(hidden_state)
return hidden_state
class TFResNetEmbeddings(tf.keras.layers.Layer):
"""
ResNet Embeddings (stem) composed of a single aggressive convolution.
"""
def __init__(self, config: ResNetConfig, **kwargs) -> None:
super().__init__(**kwargs)
self.embedder = TFResNetConvLayer(
config.embedding_size,
kernel_size=7,
stride=2,
activation=config.hidden_act,
name="embedder",
)
self.pooler = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding="valid", name="pooler")
self.num_channels = config.num_channels
def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
_, _, _, num_channels = shape_list(pixel_values)
if tf.executing_eagerly() and num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
hidden_state = pixel_values
hidden_state = self.embedder(hidden_state)
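# Equivalent of PyTorch's MaxPool2d(kernel_size=3, stride=2, padding=1):
# pad one pixel on each spatial side, then max-pool with padding="valid".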
hidden_state = tf.pad(hidden_state, [[0, 0], [1, 1], [1, 1], [0, 0]])
hidden_state = self.pooler(hidden_state)
return hidden_state
class TFResNetShortCut(tf.keras.layers.Layer):
"""
ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
downsample the input using `stride=2`.
"""
def __init__(self, out_channels: int, stride: int = 2, **kwargs) -> None:
super().__init__(**kwargs)
self.convolution = tf.keras.layers.Conv2D(
out_channels, kernel_size=1, strides=stride, use_bias=False, name="convolution"
)
# Use same default momentum and epsilon as PyTorch equivalent
self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="normalization")
def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_state = x
hidden_state = self.convolution(hidden_state)
hidden_state = self.normalization(hidden_state, training=training)
return hidden_state
class TFResNetBasicLayer(tf.keras.layers.Layer):
"""
A classic ResNet's residual layer, composed of two `3x3` convolutions.
"""
def __init__(
self, in_channels: int, out_channels: int, stride: int = 1, activation: str = "relu", **kwargs
) -> None:
super().__init__(**kwargs)
should_apply_shortcut = in_channels != out_channels or stride != 1
self.conv1 = TFResNetConvLayer(out_channels, stride=stride, name="layer.0")
self.conv2 = TFResNetConvLayer(out_channels, activation=None, name="layer.1")
self.shortcut = (
TFResNetShortCut(out_channels, stride=stride, name="shortcut")
if should_apply_shortcut
else tf.keras.layers.Activation("linear", name="shortcut")
)
self.activation = ACT2FN[activation]
def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
residual = hidden_state
hidden_state = self.conv1(hidden_state, training=training)
hidden_state = self.conv2(hidden_state, training=training)
residual = self.shortcut(residual, training=training)
hidden_state += residual
hidden_state = self.activation(hidden_state)
return hidden_state
class TFResNetBottleNeckLayer(tf.keras.layers.Layer):
"""
A classic ResNet's bottleneck layer, composed of three convolutions.
The first `1x1` convolution reduces the number of channels by a factor of `reduction` in order to make the second
`3x3` convolution cheaper. The last `1x1` convolution remaps the reduced features to `out_channels`.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
stride: int = 1,
activation: str = "relu",
reduction: int = 4,
**kwargs
) -> None:
super().__init__(**kwargs)
should_apply_shortcut = in_channels != out_channels or stride != 1
reduces_channels = out_channels // reduction
self.conv0 = TFResNetConvLayer(reduces_channels, kernel_size=1, name="layer.0")
self.conv1 = TFResNetConvLayer(reduces_channels, stride=stride, name="layer.1")
self.conv2 = TFResNetConvLayer(out_channels, kernel_size=1, activation=None, name="layer.2")
self.shortcut = (
TFResNetShortCut(out_channels, stride=stride, name="shortcut")
if should_apply_shortcut
else tf.keras.layers.Activation("linear", name="shortcut")
)
self.activation = ACT2FN[activation]
def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
residual = hidden_state
hidden_state = self.conv0(hidden_state, training=training)
hidden_state = self.conv1(hidden_state, training=training)
hidden_state = self.conv2(hidden_state, training=training)
residual = self.shortcut(residual, training=training)
hidden_state += residual
hidden_state = self.activation(hidden_state)
return hidden_state
class TFResNetStage(tf.keras.layers.Layer):
"""
A ResNet stage composed of stacked layers.
"""
def __init__(
self, config: ResNetConfig, in_channels: int, out_channels: int, stride: int = 2, depth: int = 2, **kwargs
) -> None:
super().__init__(**kwargs)
layer = TFResNetBottleNeckLayer if config.layer_type == "bottleneck" else TFResNetBasicLayer
layers = [layer(in_channels, out_channels, stride=stride, activation=config.hidden_act, name="layers.0")]
layers += [
layer(out_channels, out_channels, activation=config.hidden_act, name=f"layers.{i + 1}")
for i in range(depth - 1)
]
self.stage_layers = layers
def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
for layer in self.stage_layers:
hidden_state = layer(hidden_state, training=training)
return hidden_state
class TFResNetEncoder(tf.keras.layers.Layer):
def __init__(self, config: ResNetConfig, **kwargs) -> None:
super().__init__(**kwargs)
# Based on `downsample_in_first_stage`, the first layer of the first stage may or may not downsample the input
self.stages = [
TFResNetStage(
config,
config.embedding_size,
config.hidden_sizes[0],
stride=2 if config.downsample_in_first_stage else 1,
depth=config.depths[0],
name="stages.0",
)
]
for i, (in_channels, out_channels, depth) in enumerate(
zip(config.hidden_sizes, config.hidden_sizes[1:], config.depths[1:])
):
self.stages.append(TFResNetStage(config, in_channels, out_channels, depth=depth, name=f"stages.{i + 1}"))
def call(
self,
hidden_state: tf.Tensor,
output_hidden_states: bool = False,
return_dict: bool = True,
training: bool = False,
) -> TFBaseModelOutputWithNoAttention:
hidden_states = () if output_hidden_states else None
for stage_module in self.stages:
if output_hidden_states:
hidden_states = hidden_states + (hidden_state,)
hidden_state = stage_module(hidden_state, training=training)
if output_hidden_states:
hidden_states = hidden_states + (hidden_state,)
if not return_dict:
return tuple(v for v in [hidden_state, hidden_states] if v is not None)
return TFBaseModelOutputWithNoAttention(
last_hidden_state=hidden_state,
hidden_states=hidden_states,
)
class TFResNetPreTrainedModel(TFPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = ResNetConfig
base_model_prefix = "resnet"
main_input_name = "pixel_values"
@property
def dummy_inputs(self) -> Dict[str, tf.Tensor]:
"""
Dummy inputs to build the network.

Returns:
    `Dict[str, tf.Tensor]`: The dummy inputs.
"""
VISION_DUMMY_INPUTS = tf.random.uniform(shape=(3, self.config.num_channels, 224, 224), dtype=tf.float32)
return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)}
RESNET_START_DOCSTRING = r"""
This model is a TensorFlow
[tf.keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer) sub-class. Use it as a
regular TensorFlow Module and refer to the TensorFlow documentation for all matters related to general usage and
behavior.
Parameters:
config ([`ResNetConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""
RESNET_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See
[`AutoFeatureExtractor.__call__`] for details.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@keras_serializable
class TFResNetMainLayer(tf.keras.layers.Layer):
config_class = ResNetConfig
def __init__(self, config: ResNetConfig, **kwargs) -> None:
super().__init__(**kwargs)
self.config = config
self.embedder = TFResNetEmbeddings(config, name="embedder")
self.encoder = TFResNetEncoder(config, name="encoder")
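# GlobalAveragePooling2D with keepdims=True matches PyTorch's
# nn.AdaptiveAvgPool2d((1, 1)): it averages over all spatial positions and
# returns a (batch_size, 1, 1, num_channels) tensor in NHWC format.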
self.pooler = tf.keras.layers.GlobalAveragePooling2D(keepdims=True)
@unpack_inputs
def call(
self,
pixel_values: tf.Tensor,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
) -> Union[Tuple[tf.Tensor], TFBaseModelOutputWithPoolingAndNoAttention]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# TF 2.0 image layers can't use NCHW format when running on CPU.
# We transpose to NHWC format and then transpose back after the full forward pass.
# (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels)
pixel_values = tf.transpose(pixel_values, perm=[0, 2, 3, 1])
embedding_output = self.embedder(pixel_values, training=training)
encoder_outputs = self.encoder(
embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training
)
last_hidden_state = encoder_outputs[0]
pooled_output = self.pooler(last_hidden_state)
# Transpose all the outputs to the NCHW format
# (batch_size, height, width, num_channels) -> (batch_size, num_channels, height, width)
last_hidden_state = tf.transpose(last_hidden_state, (0, 3, 1, 2))
pooled_output = tf.transpose(pooled_output, (0, 3, 1, 2))
hidden_states = ()
for hidden_state in encoder_outputs[1:]:
hidden_states = hidden_states + tuple(tf.transpose(h, (0, 3, 1, 2)) for h in hidden_state)
if not return_dict:
return (last_hidden_state, pooled_output) + hidden_states
hidden_states = hidden_states if output_hidden_states else None
return TFBaseModelOutputWithPoolingAndNoAttention(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=hidden_states,
)
@add_start_docstrings(
"The bare ResNet model outputting raw features without any specific head on top.",
RESNET_START_DOCSTRING,
)
class TFResNetModel(TFResNetPreTrainedModel):
def __init__(self, config: ResNetConfig, **kwargs) -> None:
super().__init__(config, **kwargs)
self.resnet = TFResNetMainLayer(config=config, name="resnet")
@add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFBaseModelOutputWithPoolingAndNoAttention,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
@unpack_inputs
def call(
self,
pixel_values: tf.Tensor,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
) -> Union[Tuple[tf.Tensor], TFBaseModelOutputWithPoolingAndNoAttention]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
resnet_outputs = self.resnet(
pixel_values=pixel_values,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
return resnet_outputs
@add_start_docstrings(
"""
ResNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
ImageNet.
""",
RESNET_START_DOCSTRING,
)
class TFResNetForImageClassification(TFResNetPreTrainedModel, TFSequenceClassificationLoss):
def __init__(self, config: ResNetConfig, **kwargs) -> None:
super().__init__(config, **kwargs)
self.num_labels = config.num_labels
self.resnet = TFResNetMainLayer(config, name="resnet")
# classification head
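# Named "classifier.1" so the weights line up with the PyTorch head, an
# nn.Sequential(nn.Flatten(), nn.Linear(...)) in which the linear layer
# sits at index 1.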
self.classifier_layer = (
tf.keras.layers.Dense(config.num_labels, name="classifier.1")
if config.num_labels > 0
else tf.keras.layers.Activation("linear", name="classifier.1")
)
def classifier(self, x: tf.Tensor) -> tf.Tensor:
x = tf.keras.layers.Flatten()(x)
logits = self.classifier_layer(x)
return logits
@add_start_docstrings_to_model_forward(RESNET_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=TFImageClassifierOutputWithNoAttention,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
@unpack_inputs
def call(
self,
pixel_values: Optional[tf.Tensor] = None,
labels: Optional[tf.Tensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
) -> Union[Tuple[tf.Tensor], TFImageClassifierOutputWithNoAttention]:
r"""
labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.resnet(
pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training
)
pooled_output = outputs.pooler_output if return_dict else outputs[1]
logits = self.classifier(pooled_output)
loss = None if labels is None else self.hf_compute_loss(labels, logits)
if not return_dict:
output = (logits,) + outputs[2:]
return (loss,) + output if loss is not None else output
return TFImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
@@ -1786,6 +1786,30 @@ class TFRemBertPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["tf"])
TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST = None
class TFResNetForImageClassification(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFResNetModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFResNetPreTrainedModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = None
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the Tensorflow ResNet model. """
import inspect
import unittest
import numpy as np
from transformers import ResNetConfig
from transformers.testing_utils import require_tf, require_vision, slow
from transformers.utils import cached_property, is_tf_available, is_vision_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
if is_tf_available():
import tensorflow as tf
from transformers import TFResNetForImageClassification, TFResNetModel
from transformers.models.resnet.modeling_tf_resnet import TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available():
from PIL import Image
from transformers import AutoFeatureExtractor
class ResNetModelTester:
def __init__(
self,
parent,
batch_size=3,
image_size=32,
num_channels=3,
embeddings_size=10,
hidden_sizes=[10, 20, 30, 40],
depths=[1, 1, 2, 1],
is_training=True,
use_labels=True,
hidden_act="relu",
num_labels=3,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.num_channels = num_channels
self.embeddings_size = embeddings_size
self.hidden_sizes = hidden_sizes
self.depths = depths
self.is_training = is_training
self.use_labels = use_labels
self.hidden_act = hidden_act
self.num_labels = num_labels
self.scope = scope
self.num_stages = len(hidden_sizes)
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
labels = None
if self.use_labels:
labels = ids_tensor([self.batch_size], self.num_labels)
config = self.get_config()
return config, pixel_values, labels
def get_config(self):
return ResNetConfig(
num_channels=self.num_channels,
embeddings_size=self.embeddings_size,
hidden_sizes=self.hidden_sizes,
depths=self.depths,
hidden_act=self.hidden_act,
num_labels=self.num_labels,
image_size=self.image_size,
)
def create_and_check_model(self, config, pixel_values, labels):
model = TFResNetModel(config=config)
result = model(pixel_values)
# expected last hidden states: B, C, H // 32, W // 32
self.parent.assertEqual(
result.last_hidden_state.shape,
(self.batch_size, self.hidden_sizes[-1], self.image_size // 32, self.image_size // 32),
)
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.num_labels
model = TFResNetForImageClassification(config)
result = model(pixel_values, labels=labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_tf
class ResNetModelTest(TFModelTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_tf_common.py, as ResNet does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
all_model_classes = (TFResNetModel, TFResNetForImageClassification) if is_tf_available() else ()
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_onnx = False
has_attentions = False
def setUp(self):
self.model_tester = ResNetModelTester(self)
self.config_tester = ConfigTester(self, config_class=ResNetConfig, has_text_modality=False)
def test_config(self):
self.create_and_test_config_common_properties()
self.config_tester.create_and_test_config_to_json_string()
self.config_tester.create_and_test_config_to_json_file()
self.config_tester.create_and_test_config_from_and_save_pretrained()
self.config_tester.create_and_test_config_with_num_labels()
self.config_tester.check_config_can_be_init_without_params()
self.config_tester.check_config_arguments_init()
def create_and_test_config_common_properties(self):
return
@unittest.skip(reason="ResNet does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="ResNet does not output attentions")
def test_attention_outputs(self):
pass
@unittest.skip(reason="ResNet does not support input and output embeddings")
def test_model_common_attributes(self):
pass
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.call)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
expected_num_stages = self.model_tester.num_stages
self.assertEqual(len(hidden_states), expected_num_stages + 1)
# ResNet's feature maps are of shape (batch_size, num_channels, height, width)
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[self.model_tester.image_size // 4, self.model_tester.image_size // 4],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
layers_type = ["basic", "bottleneck"]
for model_class in self.all_model_classes:
for layer_type in layers_type:
config.layer_type = layer_type
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
for model_name in TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = TFResNetModel.from_pretrained(model_name)
self.assertIsNotNone(model)
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_tf
@require_vision
class ResNetModelIntegrationTest(unittest.TestCase):
@cached_property
def default_feature_extractor(self):
return (
AutoFeatureExtractor.from_pretrained(TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST[0])
if is_vision_available()
else None
)
@slow
def test_inference_image_classification_head(self):
model = TFResNetForImageClassification.from_pretrained(TF_RESNET_PRETRAINED_MODEL_ARCHIVE_LIST[0])
feature_extractor = self.default_feature_extractor
image = prepare_img()
inputs = feature_extractor(images=image, return_tensors="tf")
# forward pass
outputs = model(**inputs)
# verify the logits
expected_shape = tf.TensorShape((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = tf.constant([-11.1069, -9.7877, -8.3777])
self.assertTrue(np.allclose(outputs.logits[0, :3].numpy(), expected_slice, atol=1e-4))
@@ -54,6 +54,7 @@ src/transformers/models/reformer/modeling_reformer.py
src/transformers/models/regnet/modeling_regnet.py
src/transformers/models/regnet/modeling_tf_regnet.py
src/transformers/models/resnet/modeling_resnet.py
src/transformers/models/resnet/modeling_tf_resnet.py
src/transformers/models/roberta/modeling_roberta.py
src/transformers/models/roberta/modeling_tf_roberta.py
src/transformers/models/segformer/modeling_segformer.py
@@ -76,5 +77,5 @@ src/transformers/models/wav2vec2/modeling_wav2vec2.py
src/transformers/models/wav2vec2/tokenization_wav2vec2.py
src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
src/transformers/models/wavlm/modeling_wavlm.py
src/transformers/models/yolos/modeling_yolos.py