Unverified Commit d2cec09b authored by João David, committed by GitHub

Add TF swiftformer (#23342)



* Duplicate swiftformer

* Convert SwiftFormerPatchEmbedding

* Convert SwiftFormerEmbeddings

* Convert TFSwiftFormerMlp

* Convert TFSwiftFormerConvEncoder

* Convert TFSwiftFormerLocalRepresentation

* convert TFSwiftFormerEncoderBlock

* Convert SwiftFormerStage

* Convert SwiftFormerEncoder

* Add TFSwiftFormerPreTrainedModel

* Convert SwiftFormerForImageClassification

* Add kwargs and start drop path

* Fix syntax

* Change Model class name

* Add TFSwiftFormer to __init__

* Duplicate test_modeling_swiftformer

* First test conversions

* Change require_torch to require_tf

* Add exports to swiftformer __init__

* Add TFSwiftFormerModel wrapper

* Fix __init__ and run black

* Remove docstring from MainLayer, fix padding

* Use keras.layers.Activation on keras.Sequential

* Fix swiftformer exports

* Fix activation layer from config

* Remove post_inits

* Use tf.keras.layers.ZeroPadding2D

* Convert torch normalize

* Change tf test input shape

* Fix softmax and reduce_sum

* Convert expand_dims and repeat

* Add missing reshape and transpose

* Simplify TFSwiftFormerEncoderBlock.call

* Fix mismatch in patch embeddings

* Fix expected output shape to match channels last

* Fix swiftformer typo

* Disable test_onnx

* Fix TFSwiftFormerForImageClassification call

* Add unpack inputs

* Convert flatten(2).mean(-1)

* Change vision dummy inputs (to be reviewed)

* Change test_forward_signature to use .call

* Fix @unpack_inputs

* Set return_tensors="tf" and rename class

* Rename wrongly named patch_embeddings layer

* Add serving_output and change dummy_input shape

* Make dimensions BCHW and transpose inside embedding layer

* Change SwiftFormerEncoderBlock

* Fix ruff problems

* Add image size to swiftformer config

* Change transpose to MainLayer and use -1 for reshape

* Remove serving_outputs and dummy_inputs

* Remove test_initialization test from tf model

* Make Sequential component a separate layer

* Fix layers' names

* Transpose encoder outputs

* Fix tests and check if hidden states is not None

* Fix TFSwiftFormerForImageClassification

* Run make fixup

* Run make fix-copies

* Update modeling_tf_auto

* Update docs

* Fix modeling auto mapping

* Update modeling_tf_swiftformer docs

* Fill image_size doc and type

* Add reduction=None to loss computation

* Update docs

* make style

* Debug: Delete the tip to see if that changes anything

* Re-add tip

* Remove add_code_sample_docstrings

* Remove unused import

* Get the debug to actually tell us the problem it has with the docs

* Try a substitution to match the PyTorch file?

* Add swiftformer to ignore list

* Add build() methods

* Update copyright year
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Remove FIXME comment

* Remove from_pt

* Update copyright year
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Rename one-letter variables

* Remove FIXMEs related to momentum

* Remove old TODO comment

* Remove outstanding FIXME comments

* Get dropout rate from config

* Add specific dropout config for MLP

* Add convencoder dropout to config

* Pass config to SwiftFormerDropPath layer

* Fix drop_path variable name and add Adapted from comment

* Run ruff

* Removed copied from comment

* Run fix copies

* Change drop_path to identity to match pt

* Cleanup build() methods and move to new keras imports

* Update docs/source/en/model_doc/swiftformer.md
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Raise error if drop_path_rate > 0.0

* Apply suggestions from code review

Replace (self.dim), with self.dim,
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Remove drop_path function

* Add training to TFSwiftFormerEncoder

* Set self.built = True last
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Should have been added to previous commit
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Apply suggestions from code review
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Change default_feature_extractor to default_image_processor
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Import Keras from modeling_tf_utils

* Remove relative import

* Run ruff --fix

* Move import keras to tf_available

* Add copied from comment to test_forward_signature

* Reduce batch size and num_labels

* Extract loss logic to hf_compute_loss

* Run ruff format

---------
Co-authored-by: Matt <rocketknight1@gmail.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
parent 21c912e7
@@ -275,7 +275,7 @@ Flax), PyTorch, and/or TensorFlow.
| [StableLm](model_doc/stablelm) | ✅ | ❌ | ❌ |
| [Starcoder2](model_doc/starcoder2) | ✅ | ❌ | ❌ |
| [SuperPoint](model_doc/superpoint) | ✅ | ❌ | ❌ |
- | [SwiftFormer](model_doc/swiftformer) | ✅ | ❌ | ❌ |
+ | [SwiftFormer](model_doc/swiftformer) | ✅ | ✅ | ❌ |
| [Swin Transformer](model_doc/swin) | ✅ | ✅ | ❌ |
| [Swin Transformer V2](model_doc/swinv2) | ✅ | ❌ | ❌ |
| [Swin2SR](model_doc/swin2sr) | ✅ | ❌ | ❌ |
@@ -26,7 +26,7 @@ The abstract from the paper is the following:
*Self-attention has become a defacto choice for capturing global context in various vision applications. However, its quadratic computational complexity with respect to image resolution limits its use in real-time applications, especially for deployment on resource-constrained mobile devices. Although hybrid approaches have been proposed to combine the advantages of convolutions and self-attention for a better speed-accuracy trade-off, the expensive matrix multiplication operations in self-attention remain a bottleneck. In this work, we introduce a novel efficient additive attention mechanism that effectively replaces the quadratic matrix multiplication operations with linear element-wise multiplications. Our design shows that the key-value interaction can be replaced with a linear layer without sacrificing any accuracy. Unlike previous state-of-the-art methods, our efficient formulation of self-attention enables its usage at all stages of the network. Using our proposed efficient additive attention, we build a series of models called "SwiftFormer" which achieves state-of-the-art performance in terms of both accuracy and mobile inference speed. Our small variant achieves 78.5% top-1 ImageNet-1K accuracy with only 0.8 ms latency on iPhone 14, which is more accurate and 2x faster compared to MobileViT-v2.*
- This model was contributed by [shehan97](https://huggingface.co/shehan97).
+ This model was contributed by [shehan97](https://huggingface.co/shehan97). The TensorFlow version was contributed by [joaocmd](https://huggingface.co/joaocmd).
The original code can be found [here](https://github.com/Amshaker/SwiftFormer).

## SwiftFormerConfig

@@ -42,3 +42,13 @@ The original code can be found [here](https://github.com/Amshaker/SwiftFormer).
[[autodoc]] SwiftFormerForImageClassification
- forward
## TFSwiftFormerModel
[[autodoc]] TFSwiftFormerModel
- call
## TFSwiftFormerForImageClassification
[[autodoc]] TFSwiftFormerForImageClassification
- call
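For orientation, here is a minimal inference sketch using the newly added TF classes. It mirrors the integration test further down in this diff; the `MBZUAI/swiftformer-xs` checkpoint and `ViTImageProcessor` are the ones used there, and the image URL is just the usual COCO sample image, not something mandated by the model.

```python
from PIL import Image
import requests

from transformers import TFSwiftFormerForImageClassification, ViTImageProcessor

# Standard COCO sample image (cats), as used throughout the Transformers docs and tests.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = ViTImageProcessor.from_pretrained("MBZUAI/swiftformer-xs")
model = TFSwiftFormerForImageClassification.from_pretrained("MBZUAI/swiftformer-xs")

inputs = image_processor(images=image, return_tensors="tf")
outputs = model(**inputs)

predicted_class_idx = int(outputs.logits.numpy()[0].argmax())
print(model.config.id2label[predicted_class_idx])
```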
@@ -4517,6 +4517,14 @@ else:
"TFSpeech2TextPreTrainedModel",
]
)
_import_structure["models.swiftformer"].extend(
[
"TF_SWIFTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFSwiftFormerForImageClassification",
"TFSwiftFormerModel",
"TFSwiftFormerPreTrainedModel",
]
)
_import_structure["models.swin"].extend( _import_structure["models.swin"].extend(
[ [
"TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST", "TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST",
...@@ -8901,6 +8909,12 @@ if TYPE_CHECKING: ...@@ -8901,6 +8909,12 @@ if TYPE_CHECKING:
TFSpeech2TextModel, TFSpeech2TextModel,
TFSpeech2TextPreTrainedModel, TFSpeech2TextPreTrainedModel,
) )
from .models.swiftformer import (
TF_SWIFTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFSwiftFormerForImageClassification,
TFSwiftFormerModel,
TFSwiftFormerPreTrainedModel,
)
from .models.swin import (
TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST,
TFSwinForImageClassification,
@@ -81,6 +81,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
("sam", "TFSamModel"),
("segformer", "TFSegformerModel"),
("speech_to_text", "TFSpeech2TextModel"),
("swiftformer", "TFSwiftFormerModel"),
("swin", "TFSwinModel"), ("swin", "TFSwinModel"),
("t5", "TFT5Model"), ("t5", "TFT5Model"),
("tapas", "TFTapasModel"), ("tapas", "TFTapasModel"),
...@@ -213,6 +214,7 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ...@@ -213,6 +214,7 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("regnet", "TFRegNetForImageClassification"), ("regnet", "TFRegNetForImageClassification"),
("resnet", "TFResNetForImageClassification"), ("resnet", "TFResNetForImageClassification"),
("segformer", "TFSegformerForImageClassification"), ("segformer", "TFSegformerForImageClassification"),
("swiftformer", "TFSwiftFormerForImageClassification"),
("swin", "TFSwinForImageClassification"), ("swin", "TFSwinForImageClassification"),
("vit", "TFViTForImageClassification"), ("vit", "TFViTForImageClassification"),
] ]
......
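With the two mappings above registered, the TF port should also be reachable through the TF auto classes. A hedged sketch (assuming the `MBZUAI/swiftformer-xs` repo ships TF weights; otherwise `from_pt=True` would be needed):

```python
from transformers import TFAutoModel, TFAutoModelForImageClassification

# The auto mappings added above resolve "swiftformer" to the new TF classes.
backbone = TFAutoModel.from_pretrained("MBZUAI/swiftformer-xs")
classifier = TFAutoModelForImageClassification.from_pretrained("MBZUAI/swiftformer-xs")

print(type(backbone).__name__)    # expected: TFSwiftFormerModel
print(type(classifier).__name__)  # expected: TFSwiftFormerForImageClassification
```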
@@ -16,6 +16,7 @@ from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_tf_available,
is_torch_available,
)

@@ -41,6 +42,19 @@ else:
"SwiftFormerPreTrainedModel",
]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_swiftformer"] = [
"TF_SWIFTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFSwiftFormerForImageClassification",
"TFSwiftFormerModel",
"TFSwiftFormerPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_swiftformer import (
SWIFTFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,

@@ -60,6 +74,18 @@ if TYPE_CHECKING:
SwiftFormerModel,
SwiftFormerPreTrainedModel,
)
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_swiftformer import (
TF_SWIFTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFSwiftFormerForImageClassification,
TFSwiftFormerModel,
TFSwiftFormerPreTrainedModel,
)
else:
import sys
@@ -42,6 +42,8 @@ class SwiftFormerConfig(PretrainedConfig):
Args:
image_size (`int`, *optional*, defaults to 224):
The size (resolution) of each image
num_channels (`int`, *optional*, defaults to 3):
The number of input channels
depths (`List[int]`, *optional*, defaults to `[3, 3, 6, 4]`):

@@ -62,6 +64,10 @@ class SwiftFormerConfig(PretrainedConfig):
Padding in downsampling layers.
drop_path_rate (`float`, *optional*, defaults to 0.0):
Rate at which to increase dropout probability in DropPath.
drop_mlp_rate (`float`, *optional*, defaults to 0.0):
Dropout rate for the MLP component of SwiftFormer.
drop_conv_encoder_rate (`float`, *optional*, defaults to 0.0):
Dropout rate for the ConvEncoder component of SwiftFormer.
use_layer_scale (`bool`, *optional*, defaults to `True`):
Whether to scale outputs from token mixers.
layer_scale_init_value (`float`, *optional*, defaults to 1e-05):

@@ -89,6 +95,7 @@ class SwiftFormerConfig(PretrainedConfig):
def __init__(
self,
image_size=224,
num_channels=3,
depths=[3, 3, 6, 4],
embed_dims=[48, 56, 112, 220],

@@ -99,12 +106,15 @@ class SwiftFormerConfig(PretrainedConfig):
down_stride=2,
down_pad=1,
drop_path_rate=0.0,
drop_mlp_rate=0.0,
drop_conv_encoder_rate=0.0,
use_layer_scale=True,
layer_scale_init_value=1e-5,
batch_norm_eps=1e-5,
**kwargs,
):
super().__init__(**kwargs)
self.image_size = image_size
self.num_channels = num_channels
self.depths = depths
self.embed_dims = embed_dims

@@ -115,6 +125,8 @@ class SwiftFormerConfig(PretrainedConfig):
self.down_stride = down_stride
self.down_pad = down_pad
self.drop_path_rate = drop_path_rate
self.drop_mlp_rate = drop_mlp_rate
self.drop_conv_encoder_rate = drop_conv_encoder_rate
self.use_layer_scale = use_layer_scale
self.layer_scale_init_value = layer_scale_init_value
self.batch_norm_eps = batch_norm_eps
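To make the new configuration fields concrete, here is a small sketch that builds an untrained TF model from a `SwiftFormerConfig` using the arguments added in this PR; the non-zero dropout values are illustrative, not recommended settings.

```python
from transformers import SwiftFormerConfig, TFSwiftFormerForImageClassification

config = SwiftFormerConfig(
    image_size=224,              # new: input resolution, defaults to 224
    drop_mlp_rate=0.1,           # new: dropout inside the SwiftFormer MLP blocks
    drop_conv_encoder_rate=0.1,  # new: dropout inside the ConvEncoder blocks
    drop_path_rate=0.0,          # note: the TF port raises if this is > 0.0
    num_labels=10,
)
model = TFSwiftFormerForImageClassification(config)  # randomly initialized
```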
@@ -103,13 +103,12 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False)
return output

- # Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Swiftformer
class SwiftFormerDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
- def __init__(self, drop_prob: Optional[float] = None) -> None:
+ def __init__(self, config: SwiftFormerConfig) -> None:
super().__init__()
- self.drop_prob = drop_prob
+ self.drop_prob = config.drop_path_rate
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return drop_path(hidden_states, self.drop_prob, self.training)

@@ -169,7 +168,7 @@ class SwiftFormerConvEncoder(nn.Module):
self.point_wise_conv1 = nn.Conv2d(dim, hidden_dim, kernel_size=1)
self.act = nn.GELU()
self.point_wise_conv2 = nn.Conv2d(hidden_dim, dim, kernel_size=1)
- self.drop_path = nn.Identity()
+ self.drop_path = nn.Dropout(p=config.drop_conv_encoder_rate)
self.layer_scale = nn.Parameter(torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True)
def forward(self, x):

@@ -200,7 +199,7 @@ class SwiftFormerMlp(nn.Module):
act_layer = ACT2CLS[config.hidden_act]
self.act = act_layer()
self.fc2 = nn.Conv2d(hidden_features, in_features, 1)
- self.drop = nn.Dropout(p=0.0)
+ self.drop = nn.Dropout(p=config.drop_mlp_rate)
def forward(self, x):
x = self.norm1(x)

@@ -302,7 +301,7 @@ class SwiftFormerEncoderBlock(nn.Module):
self.local_representation = SwiftFormerLocalRepresentation(config, dim=dim)
self.attn = SwiftFormerEfficientAdditiveAttention(config, dim=dim)
self.linear = SwiftFormerMlp(config, in_features=dim)
- self.drop_path = SwiftFormerDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.drop_path = SwiftFormerDropPath(config) if drop_path > 0.0 else nn.Identity()
self.use_layer_scale = use_layer_scale
if use_layer_scale:
self.layer_scale_1 = nn.Parameter(

@@ -315,21 +314,13 @@ class SwiftFormerEncoderBlock(nn.Module):
def forward(self, x):
x = self.local_representation(x)
batch_size, channels, height, width = x.shape
+ res = self.attn(x.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels))
+ res = res.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
if self.use_layer_scale:
- x = x + self.drop_path(
- self.layer_scale_1
- * self.attn(x.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels))
- .reshape(batch_size, height, width, channels)
- .permute(0, 3, 1, 2)
- )
+ x = x + self.drop_path(self.layer_scale_1 * res)
x = x + self.drop_path(self.layer_scale_2 * self.linear(x))
else:
- x = x + self.drop_path(
- self.attn(x.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels))
- .reshape(batch_size, height, width, channels)
- .permute(0, 3, 1, 2)
- )
+ x = x + self.drop_path(res)
x = x + self.drop_path(self.linear(x))
return x
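For readers skimming the refactor above: `res` is the additive-attention output computed once on a token-sequence view of the feature map, instead of being recomputed in both branches. The layout round trip it relies on is an exact inverse, as this standalone sketch (plain tensors, no model code) illustrates:

```python
import torch

batch_size, channels, height, width = 2, 220, 7, 7
x = torch.randn(batch_size, channels, height, width)  # SwiftFormer keeps features as (B, C, H, W)

# (B, C, H, W) -> (B, H*W, C): the token sequence consumed by the additive attention
tokens = x.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)

# (B, H*W, C) -> (B, C, H, W): restore the spatial layout for the residual connection
restored = tokens.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)

assert torch.equal(x, restored)
```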
@@ -2554,6 +2554,30 @@ class TFSpeech2TextPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["tf"])
TF_SWIFTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
class TFSwiftFormerForImageClassification(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFSwiftFormerModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFSwiftFormerPreTrainedModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST = None
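The dummy objects above are what `from transformers import TFSwiftFormerModel` resolves to when TensorFlow is not installed: the import itself succeeds, but any use of the class raises through `requires_backends`. A hedged sketch of that behaviour (the exact error text varies by version):

```python
from transformers import TFSwiftFormerModel
from transformers.utils import is_tf_available

if not is_tf_available():
    try:
        TFSwiftFormerModel()  # DummyObject.__init__ calls requires_backends(self, ["tf"])
    except ImportError as err:
        print(err)  # points at the missing TensorFlow backend
```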
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the TensorFlow SwiftFormer model. """
import inspect
import unittest
from transformers import SwiftFormerConfig
from transformers.testing_utils import (
require_tf,
require_vision,
slow,
)
from transformers.utils import cached_property, is_tf_available, is_vision_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available():
import tensorflow as tf
from transformers import TFSwiftFormerForImageClassification, TFSwiftFormerModel
from transformers.modeling_tf_utils import keras
from transformers.models.swiftformer.modeling_tf_swiftformer import TF_SWIFTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available():
from PIL import Image
from transformers import ViTImageProcessor
class TFSwiftFormerModelTester:
def __init__(
self,
parent,
batch_size=1,
num_channels=3,
is_training=True,
use_labels=True,
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
image_size=224,
num_labels=2,
layer_depths=[3, 3, 6, 4],
embed_dims=[48, 56, 112, 220],
):
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.is_training = is_training
self.use_labels = use_labels
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.num_labels = num_labels
self.image_size = image_size
self.layer_depths = layer_depths
self.embed_dims = embed_dims
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
labels = None
if self.use_labels:
labels = ids_tensor([self.batch_size], self.num_labels)
config = self.get_config()
return config, pixel_values, labels
def get_config(self):
return SwiftFormerConfig(
depths=self.layer_depths,
embed_dims=self.embed_dims,
mlp_ratio=4,
downsamples=[True, True, True, True],
hidden_act="gelu",
num_labels=self.num_labels,
down_patch_size=3,
down_stride=2,
down_pad=1,
drop_rate=0.0,
drop_path_rate=0.0,
use_layer_scale=True,
layer_scale_init_value=1e-5,
)
def create_and_check_model(self, config, pixel_values, labels):
model = TFSwiftFormerModel(config=config)
result = model(pixel_values)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.embed_dims[-1], 7, 7))
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.num_labels
model = TFSwiftFormerForImageClassification(config)
result = model(pixel_values, labels=labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
model = TFSwiftFormerForImageClassification(config)
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def prepare_config_and_inputs_for_common(self):
(config, pixel_values, labels) = self.prepare_config_and_inputs()
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_tf
class TFSwiftFormerModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_common.py, as SwiftFormer does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
all_model_classes = (TFSwiftFormerModel, TFSwiftFormerForImageClassification) if is_tf_available() else ()
pipeline_model_mapping = (
{"feature-extraction": TFSwiftFormerModel, "image-classification": TFSwiftFormerForImageClassification}
if is_tf_available()
else {}
)
fx_compatible = False
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
has_attentions = False
test_onnx = False
def setUp(self):
self.model_tester = TFSwiftFormerModelTester(self)
self.config_tester = ConfigTester(
self,
config_class=SwiftFormerConfig,
has_text_modality=False,
hidden_size=37,
num_attention_heads=12,
num_hidden_layers=12,
)
def test_config(self):
self.config_tester.run_common_tests()
@unittest.skip(reason="TFSwiftFormer does not use inputs_embeds")
def test_inputs_embeds(self):
pass
def test_model_common_attributes(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, keras.layers.Dense))
# Copied from transformers.tests.models.deit.test_modeling_tf_deit.py
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.call)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
for model_name in TF_SWIFTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = TFSwiftFormerModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@unittest.skip(reason="TFSwiftFormer does not output attentions")
def test_attention_outputs(self):
pass
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.hidden_states
expected_num_stages = 8
self.assertEqual(len(hidden_states), expected_num_stages)
# SwiftFormer's feature maps are of shape (batch_size, embed_dims, height, width)
# with the width and height being successively divided by 2, after every 2 blocks
for i in range(len(hidden_states)):
self.assertEqual(
hidden_states[i].shape,
tf.TensorShape(
[
self.model_tester.batch_size,
self.model_tester.embed_dims[i // 2],
(self.model_tester.image_size // 4) // 2 ** (i // 2),
(self.model_tester.image_size // 4) // 2 ** (i // 2),
]
),
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_tf
@require_vision
class TFSwiftFormerModelIntegrationTest(unittest.TestCase):
@cached_property
def default_image_processor(self):
return ViTImageProcessor.from_pretrained("MBZUAI/swiftformer-xs") if is_vision_available() else None
@slow
def test_inference_image_classification_head(self):
model = TFSwiftFormerForImageClassification.from_pretrained("MBZUAI/swiftformer-xs")
image_processor = self.default_image_processor
image = prepare_img()
inputs = image_processor(images=image, return_tensors="tf")
# forward pass
outputs = model(**inputs)
# verify the logits
expected_shape = tf.TensorShape((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = tf.constant([[-2.1703e00, 2.1107e00, -2.0811e00]])
tf.debugging.assert_near(outputs.logits[0, :3], expected_slice, atol=1e-4)
@@ -697,6 +697,8 @@ OBJECTS_TO_IGNORE = [
"TFSegformerModel",
"TFSpeech2TextForConditionalGeneration",
"TFSpeech2TextModel",
"TFSwiftFormerForImageClassification",
"TFSwiftFormerModel",
"TFSwinForImageClassification", "TFSwinForImageClassification",
"TFSwinForMaskedImageModeling", "TFSwinForMaskedImageModeling",
"TFSwinModel", "TFSwinModel",
......