Unverified Commit d2cec09b authored by João David, committed by GitHub

Add TF swiftformer (#23342)



* Duplicate swiftformer

* Convert SwiftFormerPatchEmbedding

* Convert SwiftFormerEmbeddings

* Convert TFSwiftFormerMlp

* Convert TFSwiftFormerConvEncoder

* Convert TFSwiftFormerLocalRepresentation

* convert TFSwiftFormerEncoderBlock

* Convert SwiftFormerStage

* Convert SwiftFormerEncoder

* Add TFSwiftFormerPreTrainedModel

* Convert SwiftFormerForImageClassification

* Add kwargs and start drop path

* Fix syntax

* Change Model class name

* Add TFSwiftFormer to __init__

* Duplicate test_modeling_swiftformer

* First test conversions

* Change require_torch to require_tf

* Add exports to swiftformer __init__

* Add TFSwiftFormerModel wrapper

* Fix __init__ and run black

* Remove docstring from MainLayer, fix padding

* Use keras.layers.Activation on keras.Sequential

* Fix swiftformer exports

* Fix activation layer from config

* Remove post_inits

* Use tf.keras.layers.ZeroPadding2D

* Convert torch normalize

* Change tf test input shape

* Fix softmax and reduce_sum

* Convert expand_dims and repeat

* Add missing reshape and transpose

* Simplify TFSwiftFormerEncoderBlock.call

* Fix mismatch in patch embeddings

* Fix expected output shape to match channels last

* Fix swiftformer typo

* Disable test_onnx

* Fix TFSwiftFormerForImageClassification call

* Add unpack inputs

* Convert flatten(2).mean(-1)

* Change vision dummy inputs (to be reviewed)

* Change test_forward_signature to use .call

* Fix @unpack_inputs

* Set return_tensors="tf" and rename class

* Rename wrongly named patch_embeddings layer

* Add serving_output and change dummy_input shape

* Make dimensions BCHW and transpose inside embedding layer

* Change SwiftFormerEncoderBlock

* Fix ruff problems

* Add image size to swiftformer config

* Change transpose to MainLayer and use -1 for reshape

* Remove serving_outputs and dummy_inputs

* Remove test_initialization test from tf model

* Make Sequential component a separate layer

* Fix layers' names

* Transpose encoder outputs

* Fix tests and check if hidden states is not None

* Fix TFSwiftFormerForImageClassification

* Run make fixup

* Run make fix-copies

* Update modeling_tf_auto

* Update docs

* Fix modeling auto mapping

* Update modeling_tf_swiftformer docs

* Fill image_size doc and type

* Add reduction=None to loss computation

* Update docs

* make style

* Debug: Delete the tip to see if that changes anything

* Re-add tip

* Remove add_code_sample_docstrings

* Remove unused import

* Get the debug to actually tell us the problem it has with the docs

* Try a substitution to match the PyTorch file?

* Add swiftformer to ignore list

* Add build() methods

* Update copyright year
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Remove FIXME comment

* Remove from_pt

* Update copyright year
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Rename one-letter variables

* Remove FIXMEs related to momentum

* Remove old TODO comment

* Remove outstanding FIXME comments

* Get dropout rate from config

* Add specific dropout config for MLP

* Add convencoder dropout to config

* Pass config to SwiftFormerDropPath layer

* Fix drop_path variable name and add Adapted from comment

* Run ruff

* Removed copied from comment

* Run fix copies

* Change drop_path to identity to match pt

* Cleanup build() methods and move to new keras imports

* Update docs/source/en/model_doc/swiftformer.md
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Raise error if drop_path_rate > 0.0

* Apply suggestions from code review

Replace (self.dim), with self.dim,
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Remove drop_path function

* Add training to TFSwiftFormerEncoder

* Set self.built = True last
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Should have been added to previous commit
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Apply suggestions from code review
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Change default_feature_extractor to default_image_processor
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Import Keras from modeling_tf_utils

* Remove relative import

* Run ruff --fix

* Move import keras to tf_available

* Add copied from comment to test_forward_signature

* Reduce batch size and num_labels

* Extract loss logic to hf_compute_loss

* Run ruff format

---------
Co-authored-by: Matt <rocketknight1@gmail.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
parent 21c912e7
@@ -275,7 +275,7 @@ Flax), PyTorch, and/or TensorFlow.
| [StableLm](model_doc/stablelm) | ✅ | ❌ | ❌ |
| [Starcoder2](model_doc/starcoder2) | ✅ | ❌ | ❌ |
| [SuperPoint](model_doc/superpoint) | ✅ | ❌ | ❌ |
| [SwiftFormer](model_doc/swiftformer) | ✅ | ❌ | ❌ | | [SwiftFormer](model_doc/swiftformer) | ✅ | ✅ | ❌ |
| [Swin Transformer](model_doc/swin) | ✅ | ✅ | ❌ |
| [Swin Transformer V2](model_doc/swinv2) | ✅ | ❌ | ❌ |
| [Swin2SR](model_doc/swin2sr) | ✅ | ❌ | ❌ |
......
@@ -26,7 +26,7 @@ The abstract from the paper is the following:
*Self-attention has become a defacto choice for capturing global context in various vision applications. However, its quadratic computational complexity with respect to image resolution limits its use in real-time applications, especially for deployment on resource-constrained mobile devices. Although hybrid approaches have been proposed to combine the advantages of convolutions and self-attention for a better speed-accuracy trade-off, the expensive matrix multiplication operations in self-attention remain a bottleneck. In this work, we introduce a novel efficient additive attention mechanism that effectively replaces the quadratic matrix multiplication operations with linear element-wise multiplications. Our design shows that the key-value interaction can be replaced with a linear layer without sacrificing any accuracy. Unlike previous state-of-the-art methods, our efficient formulation of self-attention enables its usage at all stages of the network. Using our proposed efficient additive attention, we build a series of models called "SwiftFormer" which achieves state-of-the-art performance in terms of both accuracy and mobile inference speed. Our small variant achieves 78.5% top-1 ImageNet-1K accuracy with only 0.8 ms latency on iPhone 14, which is more accurate and 2x faster compared to MobileViT-v2.*
This model was contributed by [shehan97](https://huggingface.co/shehan97). This model was contributed by [shehan97](https://huggingface.co/shehan97). The TensorFlow version was contributed by [joaocmd](https://huggingface.co/joaocmd).
The original code can be found [here](https://github.com/Amshaker/SwiftFormer).
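The efficient additive attention described in the abstract replaces the quadratic query-key matrix multiplication with per-token scalar scores and a single pooled global query. The snippet below is an illustrative, hedged sketch of that idea in plain TensorFlow; the toy sizes and weight tensors are assumptions for illustration, and this is not the library layer itself.

```python
import tensorflow as tf

# Toy sizes, chosen only for illustration.
batch, tokens, dim = 2, 49, 64
x = tf.random.normal([batch, tokens, dim])   # flattened spatial tokens
w_query = tf.random.normal([dim, dim])
w_key = tf.random.normal([dim, dim])
w_g = tf.random.normal([dim, 1])             # learnable global-attention vector

query = tf.math.l2_normalize(x @ w_query, axis=-1)
key = tf.math.l2_normalize(x @ w_key, axis=-1)

# One scalar score per token, normalized over the token axis:
# cost grows linearly with the number of tokens instead of quadratically.
scores = tf.nn.softmax((query @ w_g) * dim**-0.5, axis=1)   # [batch, tokens, 1]
global_query = tf.reduce_sum(scores * query, axis=1)        # [batch, dim]

# Element-wise interaction with the keys replaces the key-value matrix multiplication.
out = tf.expand_dims(global_query, 1) * key + query         # [batch, tokens, dim]
print(out.shape)  # (2, 49, 64)
```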
## SwiftFormerConfig
@@ -42,3 +42,13 @@ The original code can be found [here](https://github.com/Amshaker/SwiftFormer).
[[autodoc]] SwiftFormerForImageClassification
- forward
## TFSwiftFormerModel
[[autodoc]] TFSwiftFormerModel
- call
## TFSwiftFormerForImageClassification
[[autodoc]] TFSwiftFormerForImageClassification
- call
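As a quick orientation for the new TF classes documented above, here is a minimal, hedged inference sketch. It assumes the `MBZUAI/swiftformer-xs` checkpoint provides TensorFlow weights and that the example image URL is reachable; it is not part of the shipped documentation.

```python
import tensorflow as tf
import requests
from PIL import Image
from transformers import AutoImageProcessor, TFSwiftFormerForImageClassification

# Example image (URL is illustrative).
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("MBZUAI/swiftformer-xs")
model = TFSwiftFormerForImageClassification.from_pretrained("MBZUAI/swiftformer-xs")

inputs = processor(images=image, return_tensors="tf")
logits = model(**inputs).logits
predicted_class = int(tf.math.argmax(logits, axis=-1)[0])
print(model.config.id2label[predicted_class])
```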
@@ -4517,6 +4517,14 @@ else:
"TFSpeech2TextPreTrainedModel",
]
)
_import_structure["models.swiftformer"].extend(
[
"TF_SWIFTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFSwiftFormerForImageClassification",
"TFSwiftFormerModel",
"TFSwiftFormerPreTrainedModel",
]
)
_import_structure["models.swin"].extend(
[
"TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -8901,6 +8909,12 @@ if TYPE_CHECKING:
TFSpeech2TextModel,
TFSpeech2TextPreTrainedModel,
)
from .models.swiftformer import (
TF_SWIFTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFSwiftFormerForImageClassification,
TFSwiftFormerModel,
TFSwiftFormerPreTrainedModel,
)
from .models.swin import (
TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST,
TFSwinForImageClassification,
......
@@ -81,6 +81,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
("sam", "TFSamModel"),
("segformer", "TFSegformerModel"),
("speech_to_text", "TFSpeech2TextModel"),
("swiftformer", "TFSwiftFormerModel"),
("swin", "TFSwinModel"), ("swin", "TFSwinModel"),
("t5", "TFT5Model"), ("t5", "TFT5Model"),
("tapas", "TFTapasModel"), ("tapas", "TFTapasModel"),
...@@ -213,6 +214,7 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ...@@ -213,6 +214,7 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("regnet", "TFRegNetForImageClassification"), ("regnet", "TFRegNetForImageClassification"),
("resnet", "TFResNetForImageClassification"), ("resnet", "TFResNetForImageClassification"),
("segformer", "TFSegformerForImageClassification"), ("segformer", "TFSegformerForImageClassification"),
("swiftformer", "TFSwiftFormerForImageClassification"),
("swin", "TFSwinForImageClassification"), ("swin", "TFSwinForImageClassification"),
("vit", "TFViTForImageClassification"), ("vit", "TFViTForImageClassification"),
] ]
......
@@ -16,6 +16,7 @@ from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_tf_available,
is_torch_available,
)
@@ -41,6 +42,19 @@ else:
"SwiftFormerPreTrainedModel",
]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_swiftformer"] = [
"TF_SWIFTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFSwiftFormerForImageClassification",
"TFSwiftFormerModel",
"TFSwiftFormerPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_swiftformer import (
SWIFTFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
@@ -60,6 +74,18 @@ if TYPE_CHECKING:
SwiftFormerModel,
SwiftFormerPreTrainedModel,
)
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_swiftformer import (
TF_SWIFTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFSwiftFormerForImageClassification,
TFSwiftFormerModel,
TFSwiftFormerPreTrainedModel,
)
else:
import sys
......
@@ -42,6 +42,8 @@ class SwiftFormerConfig(PretrainedConfig):
Args:
image_size (`int`, *optional*, defaults to 224):
The size (resolution) of each image
num_channels (`int`, *optional*, defaults to 3):
The number of input channels
depths (`List[int]`, *optional*, defaults to `[3, 3, 6, 4]`):
@@ -62,6 +64,10 @@ class SwiftFormerConfig(PretrainedConfig):
Padding in downsampling layers.
drop_path_rate (`float`, *optional*, defaults to 0.0):
Rate at which to increase dropout probability in DropPath.
drop_mlp_rate (`float`, *optional*, defaults to 0.0):
Dropout rate for the MLP component of SwiftFormer.
drop_conv_encoder_rate (`float`, *optional*, defaults to 0.0):
Dropout rate for the ConvEncoder component of SwiftFormer.
use_layer_scale (`bool`, *optional*, defaults to `True`):
Whether to scale outputs from token mixers.
layer_scale_init_value (`float`, *optional*, defaults to 1e-05):
@@ -89,6 +95,7 @@ class SwiftFormerConfig(PretrainedConfig):
def __init__(
self,
image_size=224,
num_channels=3,
depths=[3, 3, 6, 4],
embed_dims=[48, 56, 112, 220],
@@ -99,12 +106,15 @@ class SwiftFormerConfig(PretrainedConfig):
down_stride=2,
down_pad=1,
drop_path_rate=0.0,
drop_mlp_rate=0.0,
drop_conv_encoder_rate=0.0,
use_layer_scale=True,
layer_scale_init_value=1e-5,
batch_norm_eps=1e-5,
**kwargs,
):
super().__init__(**kwargs)
self.image_size = image_size
self.num_channels = num_channels
self.depths = depths
self.embed_dims = embed_dims
@@ -115,6 +125,8 @@ class SwiftFormerConfig(PretrainedConfig):
self.down_stride = down_stride
self.down_pad = down_pad
self.drop_path_rate = drop_path_rate
self.drop_mlp_rate = drop_mlp_rate
self.drop_conv_encoder_rate = drop_conv_encoder_rate
self.use_layer_scale = use_layer_scale
self.layer_scale_init_value = layer_scale_init_value
self.batch_norm_eps = batch_norm_eps
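A short, hedged sketch of how the configuration fields added above might be used; the dropout values are arbitrary examples, not recommended settings.

```python
from transformers import SwiftFormerConfig, SwiftFormerModel

config = SwiftFormerConfig(
    image_size=224,              # new: expected input resolution
    drop_mlp_rate=0.1,           # new: dropout inside SwiftFormerMlp
    drop_conv_encoder_rate=0.1,  # new: dropout inside SwiftFormerConvEncoder
    drop_path_rate=0.0,
)
model = SwiftFormerModel(config)  # randomly initialized PyTorch model
```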
......
@@ -103,13 +103,12 @@ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = Fals
return output
# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Swiftformer
class SwiftFormerDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob: Optional[float] = None) -> None: def __init__(self, config: SwiftFormerConfig) -> None:
super().__init__()
self.drop_prob = drop_prob self.drop_prob = config.drop_path_rate
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return drop_path(hidden_states, self.drop_prob, self.training)
@@ -169,7 +168,7 @@ class SwiftFormerConvEncoder(nn.Module):
self.point_wise_conv1 = nn.Conv2d(dim, hidden_dim, kernel_size=1)
self.act = nn.GELU()
self.point_wise_conv2 = nn.Conv2d(hidden_dim, dim, kernel_size=1)
self.drop_path = nn.Identity() self.drop_path = nn.Dropout(p=config.drop_conv_encoder_rate)
self.layer_scale = nn.Parameter(torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True)
def forward(self, x):
@@ -200,7 +199,7 @@ class SwiftFormerMlp(nn.Module):
act_layer = ACT2CLS[config.hidden_act]
self.act = act_layer()
self.fc2 = nn.Conv2d(hidden_features, in_features, 1)
self.drop = nn.Dropout(p=0.0) self.drop = nn.Dropout(p=config.drop_mlp_rate)
def forward(self, x):
x = self.norm1(x)
@@ -302,7 +301,7 @@ class SwiftFormerEncoderBlock(nn.Module):
self.local_representation = SwiftFormerLocalRepresentation(config, dim=dim)
self.attn = SwiftFormerEfficientAdditiveAttention(config, dim=dim)
self.linear = SwiftFormerMlp(config, in_features=dim)
self.drop_path = SwiftFormerDropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.drop_path = SwiftFormerDropPath(config) if drop_path > 0.0 else nn.Identity()
self.use_layer_scale = use_layer_scale
if use_layer_scale:
self.layer_scale_1 = nn.Parameter(
@@ -315,21 +314,13 @@ class SwiftFormerEncoderBlock(nn.Module):
def forward(self, x):
x = self.local_representation(x)
batch_size, channels, height, width = x.shape
res = self.attn(x.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels))
res = res.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
if self.use_layer_scale:
x = x + self.drop_path( x = x + self.drop_path(self.layer_scale_1 * res)
self.layer_scale_1
* self.attn(x.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels))
.reshape(batch_size, height, width, channels)
.permute(0, 3, 1, 2)
)
x = x + self.drop_path(self.layer_scale_2 * self.linear(x))
else:
x = x + self.drop_path( x = x + self.drop_path(res)
self.attn(x.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels))
.reshape(batch_size, height, width, channels)
.permute(0, 3, 1, 2)
)
x = x + self.drop_path(self.linear(x))
return x
......
# coding=utf-8
# Copyright 2024 MBZUAI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TensorFlow SwiftFormer model."""
import collections.abc
from typing import Optional, Tuple, Union
import tensorflow as tf
from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import (
TFBaseModelOutputWithNoAttention,
TFImageClassifierOutputWithNoAttention,
)
from ...modeling_tf_utils import TFPreTrainedModel, keras, keras_serializable, unpack_inputs
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
)
from .configuration_swiftformer import SwiftFormerConfig
logger = logging.get_logger(__name__)
# General docstring
_CONFIG_FOR_DOC = "SwiftFormerConfig"
# Base docstring
_CHECKPOINT_FOR_DOC = "MBZUAI/swiftformer-xs"
_EXPECTED_OUTPUT_SHAPE = [1, 220, 7, 7]
# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "MBZUAI/swiftformer-xs"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
TF_SWIFTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"MBZUAI/swiftformer-xs",
# See all SwiftFormer models at https://huggingface.co/models?filter=swiftformer
]
class TFSwiftFormerPatchEmbeddingSequential(keras.layers.Layer):
"""
The sequential component of the patch embedding layer.
Input: tensor of shape `[batch_size, in_channels, height, width]`
Output: tensor of shape `[batch_size, out_channels, height/4, width/4]`
"""
def __init__(self, config: SwiftFormerConfig, **kwargs):
super().__init__(**kwargs)
self.out_chs = config.embed_dims[0]
self.zero_padding = keras.layers.ZeroPadding2D(padding=(1, 1))
self.conv1 = keras.layers.Conv2D(self.out_chs // 2, kernel_size=3, strides=2, name="0")
self.batch_norm1 = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="1")
self.conv2 = keras.layers.Conv2D(self.out_chs, kernel_size=3, strides=2, name="3")
self.batch_norm2 = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="4")
self.config = config
def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor:
x = self.zero_padding(x)
x = self.conv1(x)
x = self.batch_norm1(x, training=training)
x = get_tf_activation("relu")(x)
x = self.zero_padding(x)
x = self.conv2(x)
x = self.batch_norm2(x, training=training)
x = get_tf_activation("relu")(x)
return x
def build(self, input_shape=None):
if self.built:
return
if getattr(self, "conv1", None) is not None:
with tf.name_scope(self.conv1.name):
self.conv1.build(self.config.num_channels)
if getattr(self, "batch_norm1", None) is not None:
with tf.name_scope(self.batch_norm1.name):
self.batch_norm1.build((None, None, None, self.out_chs // 2))
if getattr(self, "conv2", None) is not None:
with tf.name_scope(self.conv2.name):
self.conv2.build((None, None, None, self.out_chs // 2))
if getattr(self, "batch_norm2", None) is not None:
with tf.name_scope(self.batch_norm2.name):
self.batch_norm2.build((None, None, None, self.out_chs))
self.built = True
class TFSwiftFormerPatchEmbedding(keras.layers.Layer):
"""
Patch Embedding Layer constructed of two 2D convolutional layers.
Input: tensor of shape `[batch_size, in_channels, height, width]`
Output: tensor of shape `[batch_size, out_channels, height/4, width/4]`
"""
def __init__(self, config: SwiftFormerConfig, **kwargs):
super().__init__(**kwargs)
self.patch_embedding = TFSwiftFormerPatchEmbeddingSequential(config, name="patch_embedding")
def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor:
return self.patch_embedding(x, training=training)
def build(self, input_shape=None):
if self.built:
return
if getattr(self, "patch_embedding", None) is not None:
with tf.name_scope(self.patch_embedding.name):
self.patch_embedding.build(None)
self.built = True
class TFSwiftFormerDropPath(keras.layers.Layer):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, config: SwiftFormerConfig, **kwargs) -> None:
super().__init__(**kwargs)
raise NotImplementedError("Drop path is not implemented in TF port")
def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
raise NotImplementedError("Drop path is not implemented in TF port")
class TFSwiftFormerEmbeddings(keras.layers.Layer):
"""
Embeddings layer consisting of a single 2D convolutional and batch normalization layer.
Input: tensor of shape `[batch_size, channels, height, width]`
Output: tensor of shape `[batch_size, channels, height/stride, width/stride]`
"""
def __init__(self, config: SwiftFormerConfig, index: int, **kwargs):
super().__init__(**kwargs)
patch_size = config.down_patch_size
stride = config.down_stride
padding = config.down_pad
embed_dims = config.embed_dims
self.in_chans = embed_dims[index]
self.embed_dim = embed_dims[index + 1]
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
stride = stride if isinstance(stride, collections.abc.Iterable) else (stride, stride)
padding = padding if isinstance(padding, collections.abc.Iterable) else (padding, padding)
self.pad = keras.layers.ZeroPadding2D(padding=padding)
self.proj = keras.layers.Conv2D(self.embed_dim, kernel_size=patch_size, strides=stride, name="proj")
self.norm = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="norm")
def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor:
x = self.pad(x)
x = self.proj(x)
x = self.norm(x, training=training)
return x
def build(self, input_shape=None):
if self.built:
return
if getattr(self, "proj", None) is not None:
with tf.name_scope(self.proj.name):
self.proj.build(self.in_chans)
if getattr(self, "norm", None) is not None:
with tf.name_scope(self.norm.name):
self.norm.build((None, None, None, self.embed_dim))
self.built = True
class TFSwiftFormerConvEncoder(keras.layers.Layer):
"""
`SwiftFormerConvEncoder` with 3*3 and 1*1 convolutions.
Input: tensor of shape `[batch_size, channels, height, width]`
Output: tensor of shape `[batch_size, channels, height, width]`
"""
def __init__(self, config: SwiftFormerConfig, dim: int, **kwargs):
super().__init__(**kwargs)
hidden_dim = int(config.mlp_ratio * dim)
self.dim = dim
self.pad = keras.layers.ZeroPadding2D(padding=(1, 1))
self.depth_wise_conv = keras.layers.Conv2D(dim, kernel_size=3, groups=dim, name="depth_wise_conv")
self.norm = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="norm")
self.point_wise_conv1 = keras.layers.Conv2D(hidden_dim, kernel_size=1, name="point_wise_conv1")
self.act = get_tf_activation("gelu")
self.point_wise_conv2 = keras.layers.Conv2D(dim, kernel_size=1, name="point_wise_conv2")
self.drop_path = keras.layers.Dropout(name="drop_path", rate=config.drop_conv_encoder_rate)
self.hidden_dim = int(config.mlp_ratio * self.dim)
def build(self, input_shape=None):
if self.built:
return
self.layer_scale = self.add_weight(
name="layer_scale",
shape=self.dim,
initializer="ones",
trainable=True,
)
if getattr(self, "depth_wise_conv", None) is not None:
with tf.name_scope(self.depth_wise_conv.name):
self.depth_wise_conv.build(self.dim)
if getattr(self, "norm", None) is not None:
with tf.name_scope(self.norm.name):
self.norm.build((None, None, None, self.dim))
if getattr(self, "point_wise_conv1", None) is not None:
with tf.name_scope(self.point_wise_conv1.name):
self.point_wise_conv1.build(self.dim)
if getattr(self, "point_wise_conv2", None) is not None:
with tf.name_scope(self.point_wise_conv2.name):
self.point_wise_conv2.build(self.hidden_dim)
if getattr(self, "drop_path", None) is not None:
with tf.name_scope(self.drop_path.name):
self.drop_path.build(None)
self.built = True
def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor:
input = x
x = self.pad(x)
x = self.depth_wise_conv(x)
x = self.norm(x, training=training)
x = self.point_wise_conv1(x)
x = self.act(x)
x = self.point_wise_conv2(x)
x = input + self.drop_path(self.layer_scale * x)
return x
class TFSwiftFormerMlp(keras.layers.Layer):
"""
MLP layer with 1*1 convolutions.
Input: tensor of shape `[batch_size, channels, height, width]`
Output: tensor of shape `[batch_size, channels, height, width]`
"""
def __init__(self, config: SwiftFormerConfig, in_features: int, **kwargs):
super().__init__(**kwargs)
hidden_features = int(in_features * config.mlp_ratio)
self.norm1 = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="norm1")
self.fc1 = keras.layers.Conv2D(hidden_features, 1, name="fc1")
act_layer = get_tf_activation(config.hidden_act)
self.act = act_layer
self.fc2 = keras.layers.Conv2D(in_features, 1, name="fc2")
self.drop = keras.layers.Dropout(rate=config.drop_mlp_rate)
self.hidden_features = hidden_features
self.in_features = in_features
def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor:
x = self.norm1(x, training=training)
x = self.fc1(x)
x = self.act(x)
x = self.drop(x, training=training)
x = self.fc2(x)
x = self.drop(x, training=training)
return x
def build(self, input_shape=None):
if self.built:
return
if getattr(self, "norm1", None) is not None:
with tf.name_scope(self.norm1.name):
self.norm1.build((None, None, None, self.in_features))
if getattr(self, "fc1", None) is not None:
with tf.name_scope(self.fc1.name):
self.fc1.build((None, None, None, self.in_features))
if getattr(self, "fc2", None) is not None:
with tf.name_scope(self.fc2.name):
self.fc2.build((None, None, None, self.hidden_features))
self.built = True
class TFSwiftFormerEfficientAdditiveAttention(keras.layers.Layer):
"""
Efficient Additive Attention module for SwiftFormer.
Input: tensor of shape `[batch_size, channels, height, width]`
Output: tensor of shape `[batch_size, channels, height, width]`
"""
def __init__(self, config: SwiftFormerConfig, dim: int = 512, **kwargs):
super().__init__(**kwargs)
self.dim = dim
self.to_query = keras.layers.Dense(dim, name="to_query")
self.to_key = keras.layers.Dense(dim, name="to_key")
self.scale_factor = dim**-0.5
self.proj = keras.layers.Dense(dim, name="proj")
self.final = keras.layers.Dense(dim, name="final")
def build(self, input_shape=None):
if self.built:
return
self.w_g = self.add_weight(
name="w_g",
shape=(self.dim, 1),
initializer=keras.initializers.RandomNormal(mean=0, stddev=1),
trainable=True,
)
if getattr(self, "to_query", None) is not None:
with tf.name_scope(self.to_query.name):
self.to_query.build(self.dim)
if getattr(self, "to_key", None) is not None:
with tf.name_scope(self.to_key.name):
self.to_key.build(self.dim)
if getattr(self, "proj", None) is not None:
with tf.name_scope(self.proj.name):
self.proj.build(self.dim)
if getattr(self, "final", None) is not None:
with tf.name_scope(self.final.name):
self.final.build(self.dim)
self.built = True
def call(self, x: tf.Tensor) -> tf.Tensor:
query = self.to_query(x)
key = self.to_key(x)
query = tf.math.l2_normalize(query, dim=-1)
key = tf.math.l2_normalize(key, dim=-1)
query_weight = query @ self.w_g
scaled_query_weight = query_weight * self.scale_factor
scaled_query_weight = tf.nn.softmax(scaled_query_weight, axis=-1)
global_queries = tf.math.reduce_sum(scaled_query_weight * query, axis=1)
global_queries = tf.tile(tf.expand_dims(global_queries, 1), (1, key.shape[1], 1))
out = self.proj(global_queries * key) + query
out = self.final(out)
return out
class TFSwiftFormerLocalRepresentation(keras.layers.Layer):
"""
Local Representation module for SwiftFormer that is implemented by 3*3 depth-wise and point-wise convolutions.
Input: tensor of shape `[batch_size, channels, height, width]`
Output: tensor of shape `[batch_size, channels, height, width]`
"""
def __init__(self, config: SwiftFormerConfig, dim: int, **kwargs):
super().__init__(**kwargs)
self.dim = dim
self.pad = keras.layers.ZeroPadding2D(padding=(1, 1))
self.depth_wise_conv = keras.layers.Conv2D(dim, kernel_size=3, groups=dim, name="depth_wise_conv")
self.norm = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="norm")
self.point_wise_conv1 = keras.layers.Conv2D(dim, kernel_size=1, name="point_wise_conv1")
self.act = get_tf_activation("gelu")
self.point_wise_conv2 = keras.layers.Conv2D(dim, kernel_size=1, name="point_wise_conv2")
self.drop_path = keras.layers.Identity(name="drop_path")
def build(self, input_shape=None):
if self.built:
return
self.layer_scale = self.add_weight(
name="layer_scale",
shape=self.dim,
initializer="ones",
trainable=True,
)
if getattr(self, "depth_wise_conv", None) is not None:
with tf.name_scope(self.depth_wise_conv.name):
self.depth_wise_conv.build((None, None, None, self.dim))
if getattr(self, "norm", None) is not None:
with tf.name_scope(self.norm.name):
self.norm.build((None, None, None, self.dim))
if getattr(self, "point_wise_conv1", None) is not None:
with tf.name_scope(self.point_wise_conv1.name):
self.point_wise_conv1.build(self.dim)
if getattr(self, "point_wise_conv2", None) is not None:
with tf.name_scope(self.point_wise_conv2.name):
self.point_wise_conv2.build(self.dim)
if getattr(self, "drop_path", None) is not None:
with tf.name_scope(self.drop_path.name):
self.drop_path.build(None)
self.built = True
def call(self, x: tf.Tensor, training: bool = False) -> tf.Tensor:
input = x
x = self.pad(x)
x = self.depth_wise_conv(x)
x = self.norm(x, training=training)
x = self.point_wise_conv1(x)
x = self.act(x)
x = self.point_wise_conv2(x)
x = input + self.drop_path(self.layer_scale * x, training=training)
return x
class TFSwiftFormerEncoderBlock(keras.layers.Layer):
"""
SwiftFormer Encoder Block for SwiftFormer. It consists of (1) Local representation module, (2)
SwiftFormerEfficientAdditiveAttention, and (3) MLP block.
Input: tensor of shape `[batch_size, channels, height, width]`
Output: tensor of shape `[batch_size, channels, height, width]`
"""
def __init__(self, config: SwiftFormerConfig, dim: int, drop_path: float = 0.0, **kwargs):
super().__init__(**kwargs)
layer_scale_init_value = config.layer_scale_init_value
use_layer_scale = config.use_layer_scale
self.local_representation = TFSwiftFormerLocalRepresentation(config, dim=dim, name="local_representation")
self.attn = TFSwiftFormerEfficientAdditiveAttention(config, dim=dim, name="attn")
self.linear = TFSwiftFormerMlp(config, in_features=dim, name="linear")
self.drop_path = TFSwiftFormerDropPath(config) if drop_path > 0.0 else keras.layers.Identity()
self.use_layer_scale = use_layer_scale
if use_layer_scale:
self.dim = dim
self.layer_scale_init_value = layer_scale_init_value
def build(self, input_shape=None):
if self.built:
return
self.layer_scale_1 = self.add_weight(
name="layer_scale_1",
shape=self.dim,
initializer=keras.initializers.constant(self.layer_scale_init_value),
trainable=True,
)
self.layer_scale_2 = self.add_weight(
name="layer_scale_2",
shape=self.dim,
initializer=keras.initializers.constant(self.layer_scale_init_value),
trainable=True,
)
if getattr(self, "local_representation", None) is not None:
with tf.name_scope(self.local_representation.name):
self.local_representation.build(None)
if getattr(self, "attn", None) is not None:
with tf.name_scope(self.attn.name):
self.attn.build(None)
if getattr(self, "linear", None) is not None:
with tf.name_scope(self.linear.name):
self.linear.build(None)
self.built = True
def call(self, x: tf.Tensor, training: bool = False):
x = self.local_representation(x, training=training)
batch_size, height, width, channels = x.shape
res = tf.reshape(x, [-1, height * width, channels])
res = self.attn(res)
res = tf.reshape(res, [-1, height, width, channels])
if self.use_layer_scale:
x = x + self.drop_path(self.layer_scale_1 * res, training=training)
x = x + self.drop_path(self.layer_scale_2 * self.linear(x), training=training)
else:
x = x + self.drop_path(res, training=training)
x = x + self.drop_path(self.linear(x), training=training)
return x
class TFSwiftFormerStage(keras.layers.Layer):
"""
A Swiftformer stage consisting of a series of `SwiftFormerConvEncoder` blocks and a final
`SwiftFormerEncoderBlock`.
Input: tensor in shape `[batch_size, channels, height, width]`
Output: tensor in shape `[batch_size, channels, height, width]`
"""
def __init__(self, config: SwiftFormerConfig, index: int, **kwargs) -> None:
super().__init__(**kwargs)
layer_depths = config.depths
dim = config.embed_dims[index]
depth = layer_depths[index]
self.blocks = []
for block_idx in range(depth):
block_dpr = config.drop_path_rate * (block_idx + sum(layer_depths[:index])) / (sum(layer_depths) - 1)
if depth - block_idx <= 1:
self.blocks.append(
TFSwiftFormerEncoderBlock(config, dim=dim, drop_path=block_dpr, name=f"blocks_._{block_idx}")
)
else:
self.blocks.append(TFSwiftFormerConvEncoder(config, dim=dim, name=f"blocks_._{block_idx}"))
def call(self, input: tf.Tensor, training: bool = False) -> tf.Tensor:
for i, block in enumerate(self.blocks):
input = block(input, training=training)
return input
def build(self, input_shape=None):
for layer in self.blocks:
with tf.name_scope(layer.name):
layer.build(None)
class TFSwiftFormerEncoder(keras.layers.Layer):
def __init__(self, config: SwiftFormerConfig, **kwargs) -> None:
super().__init__(**kwargs)
self.config = config
embed_dims = config.embed_dims
downsamples = config.downsamples
layer_depths = config.depths
# Transformer model
self.network = []
name_i = 0
for i in range(len(layer_depths)):
stage = TFSwiftFormerStage(config, index=i, name=f"network_._{name_i}")
self.network.append(stage)
name_i += 1
if i >= len(layer_depths) - 1:
break
if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
# downsampling between two stages
self.network.append(TFSwiftFormerEmbeddings(config, index=i, name=f"network_._{name_i}"))
name_i += 1
self.gradient_checkpointing = False
def call(
self,
hidden_states: tf.Tensor,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
) -> Union[tuple, TFBaseModelOutputWithNoAttention]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
all_hidden_states = (hidden_states,) if output_hidden_states else None
for i, block in enumerate(self.network):
hidden_states = block(hidden_states, training=training)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
hidden_states = tf.transpose(hidden_states, perm=[0, 3, 1, 2])
if all_hidden_states:
all_hidden_states = tuple(tf.transpose(s, perm=[0, 3, 1, 2]) for s in all_hidden_states)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
return TFBaseModelOutputWithNoAttention(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
)
def build(self, input_shape=None):
for layer in self.network:
with tf.name_scope(layer.name):
layer.build(None)
class TFSwiftFormerPreTrainedModel(TFPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = SwiftFormerConfig
base_model_prefix = "swiftformer"
main_input_name = "pixel_values"
TFSWIFTFORMER_START_DOCSTRING = r"""
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
behavior.
<Tip>
TF 2.0 models accept two formats as inputs:
- having all inputs as keyword arguments (like PyTorch models), or
- having all inputs as a list, tuple or dict in the first positional arguments.
This second option is useful when using [`keras.Model.fit`] method which currently requires having all the
tensors in the first argument of the model call function: `model(inputs)`.
If you choose this second option, there are three possibilities you can use to gather all the input Tensors in the
first positional argument :
- a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
`model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
- a dictionary with one or several input Tensors associated to the input names given in the docstring:
`model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
</Tip>
Parameters:
config ([`SwiftFormerConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
TFSWIFTFORMER_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
for details.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
training (`bool`, *optional*, defaults to `False`):
Whether or not to run the model in training mode.
"""
@keras_serializable
class TFSwiftFormerMainLayer(keras.layers.Layer):
config_class = SwiftFormerConfig
def __init__(self, config: SwiftFormerConfig, **kwargs):
super().__init__(**kwargs)
self.config = config
self.patch_embed = TFSwiftFormerPatchEmbedding(config, name="patch_embed")
self.encoder = TFSwiftFormerEncoder(config, name="encoder")
@unpack_inputs
def call(
self,
pixel_values: Optional[tf.Tensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
) -> Union[Tuple, TFBaseModelOutputWithNoAttention]:
r""" """
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
# TF 2.0 image layers can't use NCHW format when running on CPU.
# We transpose to NHWC format and then transpose back after the full forward pass.
# (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels)
pixel_values = tf.transpose(pixel_values, perm=[0, 2, 3, 1])
embedding_output = self.patch_embed(pixel_values, training=training)
encoder_outputs = self.encoder(
embedding_output,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
if not return_dict:
return tuple(v for v in encoder_outputs if v is not None)
return TFBaseModelOutputWithNoAttention(
last_hidden_state=encoder_outputs.last_hidden_state,
hidden_states=encoder_outputs.hidden_states,
)
def build(self, input_shape=None):
if self.built:
return
if getattr(self, "patch_embed", None) is not None:
with tf.name_scope(self.patch_embed.name):
self.patch_embed.build(None)
if getattr(self, "encoder", None) is not None:
with tf.name_scope(self.encoder.name):
self.encoder.build(None)
self.built = True
@add_start_docstrings(
"The bare TFSwiftFormer Model transformer outputting raw hidden-states without any specific head on top.",
TFSWIFTFORMER_START_DOCSTRING,
)
class TFSwiftFormerModel(TFSwiftFormerPreTrainedModel):
def __init__(self, config: SwiftFormerConfig, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.swiftformer = TFSwiftFormerMainLayer(config, name="swiftformer")
@unpack_inputs
@add_start_docstrings_to_model_forward(TFSWIFTFORMER_INPUTS_DOCSTRING)
def call(
self,
pixel_values: Optional[tf.Tensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
) -> Union[TFBaseModelOutputWithNoAttention, Tuple[tf.Tensor]]:
outputs = self.swiftformer(
pixel_values=pixel_values,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
return outputs
def build(self, input_shape=None):
if self.built:
return
if getattr(self, "swiftformer", None) is not None:
with tf.name_scope(self.swiftformer.name):
self.swiftformer.build(None)
self.built = True
@add_start_docstrings(
"""
TFSwiftFormer Model transformer with an image classification head on top (e.g. for ImageNet).
""",
TFSWIFTFORMER_START_DOCSTRING,
)
class TFSwiftFormerForImageClassification(TFSwiftFormerPreTrainedModel):
def __init__(self, config: SwiftFormerConfig, **kwargs) -> None:
super().__init__(config, **kwargs)
self.num_labels = config.num_labels
self.swiftformer = TFSwiftFormerMainLayer(config, name="swiftformer")
# Classifier head
self.norm = keras.layers.BatchNormalization(epsilon=config.batch_norm_eps, momentum=0.9, name="norm")
self.head = (
keras.layers.Dense(self.num_labels, name="head")
if self.num_labels > 0
else keras.layers.Identity(name="head")
)
self.dist_head = (
keras.layers.Dense(self.num_labels, name="dist_head")
if self.num_labels > 0
else keras.layers.Identity(name="dist_head")
)
def hf_compute_loss(self, labels, logits):
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == tf.int64 or labels.dtype == tf.int32):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = keras.losses.MSE
if self.num_labels == 1:
loss = loss_fct(labels.squeeze(), logits.squeeze())
else:
loss = loss_fct(labels, logits)
elif self.config.problem_type == "single_label_classification":
loss_fct = keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction=keras.losses.Reduction.NONE
)
loss = loss_fct(labels, logits)
elif self.config.problem_type == "multi_label_classification":
loss_fct = keras.losses.SparseCategoricalCrossentropy(
from_logits=True,
reduction=keras.losses.Reduction.NONE,
)
loss = loss_fct(labels, logits)
else:
loss = None
return loss
@unpack_inputs
@add_start_docstrings_to_model_forward(TFSWIFTFORMER_INPUTS_DOCSTRING)
def call(
self,
pixel_values: Optional[tf.Tensor] = None,
labels: Optional[tf.Tensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
) -> Union[tuple, TFImageClassifierOutputWithNoAttention]:
r"""
labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# run base model
outputs = self.swiftformer(
pixel_values,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
sequence_output = outputs.last_hidden_state if return_dict else outputs[0]
sequence_output = tf.transpose(sequence_output, perm=[0, 2, 3, 1])
# run classification head
sequence_output = self.norm(sequence_output, training=training)
sequence_output = tf.transpose(sequence_output, perm=[0, 3, 1, 2])
_, num_channels, height, width = sequence_output.shape
sequence_output = tf.reshape(sequence_output, [-1, num_channels, height * width])
sequence_output = tf.reduce_mean(sequence_output, axis=-1)
cls_out = self.head(sequence_output)
distillation_out = self.dist_head(sequence_output)
logits = (cls_out + distillation_out) / 2
# calculate loss
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return TFImageClassifierOutputWithNoAttention(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
)
def build(self, input_shape=None):
if self.built:
return
if getattr(self, "swiftformer", None) is not None:
with tf.name_scope(self.swiftformer.name):
self.swiftformer.build(None)
if getattr(self, "norm", None) is not None:
with tf.name_scope(self.norm.name):
self.norm.build((None, None, None, self.config.embed_dims[-1]))
if getattr(self, "head", None) is not None:
with tf.name_scope(self.head.name):
self.head.build(self.config.embed_dims[-1])
if getattr(self, "dist_head", None) is not None:
with tf.name_scope(self.dist_head.name):
self.dist_head.build(self.config.embed_dims[-1])
self.built = True
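The classification head above pools the `(batch, channels, height, width)` features the same way the PyTorch model does with `x.flatten(2).mean(-1)`. A hedged, standalone sketch of that equivalence on toy tensors:

```python
import numpy as np
import tensorflow as tf

# Toy feature map in (batch, channels, height, width) layout, arbitrary values.
features = np.random.rand(1, 220, 7, 7).astype(np.float32)

# TensorFlow version used in the head: flatten the spatial dims, then average them.
tf_pooled = tf.reduce_mean(tf.reshape(features, [-1, 220, 7 * 7]), axis=-1)

# Reference for PyTorch's x.flatten(2).mean(-1), computed here with NumPy.
np_pooled = features.reshape(1, 220, -1).mean(-1)

print(np.allclose(tf_pooled.numpy(), np_pooled, atol=1e-6))  # True
```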
@@ -2554,6 +2554,30 @@ class TFSpeech2TextPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["tf"])
TF_SWIFTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
class TFSwiftFormerForImageClassification(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFSwiftFormerModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFSwiftFormerPreTrainedModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
TF_SWIN_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the TensorFlow SwiftFormer model. """
import inspect
import unittest
from transformers import SwiftFormerConfig
from transformers.testing_utils import (
require_tf,
require_vision,
slow,
)
from transformers.utils import cached_property, is_tf_available, is_vision_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available():
import tensorflow as tf
from transformers import TFSwiftFormerForImageClassification, TFSwiftFormerModel
from transformers.modeling_tf_utils import keras
from transformers.models.swiftformer.modeling_tf_swiftformer import TF_SWIFTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available():
from PIL import Image
from transformers import ViTImageProcessor
class TFSwiftFormerModelTester:
def __init__(
self,
parent,
batch_size=1,
num_channels=3,
is_training=True,
use_labels=True,
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
image_size=224,
num_labels=2,
layer_depths=[3, 3, 6, 4],
embed_dims=[48, 56, 112, 220],
):
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.is_training = is_training
self.use_labels = use_labels
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.num_labels = num_labels
self.image_size = image_size
self.layer_depths = layer_depths
self.embed_dims = embed_dims
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
labels = None
if self.use_labels:
labels = ids_tensor([self.batch_size], self.num_labels)
config = self.get_config()
return config, pixel_values, labels
def get_config(self):
return SwiftFormerConfig(
depths=self.layer_depths,
embed_dims=self.embed_dims,
mlp_ratio=4,
downsamples=[True, True, True, True],
hidden_act="gelu",
num_labels=self.num_labels,
down_patch_size=3,
down_stride=2,
down_pad=1,
drop_rate=0.0,
drop_path_rate=0.0,
use_layer_scale=True,
layer_scale_init_value=1e-5,
)
def create_and_check_model(self, config, pixel_values, labels):
model = TFSwiftFormerModel(config=config)
result = model(pixel_values)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.embed_dims[-1], 7, 7))
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.num_labels
model = TFSwiftFormerForImageClassification(config)
result = model(pixel_values, labels=labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
model = TFSwiftFormerForImageClassification(config)
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def prepare_config_and_inputs_for_common(self):
(config, pixel_values, labels) = self.prepare_config_and_inputs()
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_tf
class TFSwiftFormerModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_tf_common.py, as SwiftFormer does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
all_model_classes = (TFSwiftFormerModel, TFSwiftFormerForImageClassification) if is_tf_available() else ()
pipeline_model_mapping = (
{"feature-extraction": TFSwiftFormerModel, "image-classification": TFSwiftFormerForImageClassification}
if is_tf_available()
else {}
)
fx_compatible = False
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
has_attentions = False
test_onnx = False
def setUp(self):
self.model_tester = TFSwiftFormerModelTester(self)
self.config_tester = ConfigTester(
self,
config_class=SwiftFormerConfig,
has_text_modality=False,
hidden_size=37,
num_attention_heads=12,
num_hidden_layers=12,
)
def test_config(self):
self.config_tester.run_common_tests()
@unittest.skip(reason="TFSwiftFormer does not use inputs_embeds")
def test_inputs_embeds(self):
pass
def test_model_common_attributes(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, keras.layers.Dense))
# Copied from transformers.tests.models.deit.test_modeling_tf_deit.py
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.call)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
for model_name in TF_SWIFTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = TFSwiftFormerModel.from_pretrained(model_name)
self.assertIsNotNone(model)
@unittest.skip(reason="TFSwiftFormer does not output attentions")
def test_attention_outputs(self):
pass
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.hidden_states
expected_num_stages = 8
self.assertEqual(len(hidden_states), expected_num_stages)
# SwiftFormer's feature maps are of shape (batch_size, embed_dims, height, width)
# with the width and height being successively divided by 2, after every 2 blocks
for i in range(len(hidden_states)):
self.assertEqual(
hidden_states[i].shape,
tf.TensorShape(
[
self.model_tester.batch_size,
self.model_tester.embed_dims[i // 2],
(self.model_tester.image_size // 4) // 2 ** (i // 2),
(self.model_tester.image_size // 4) // 2 ** (i // 2),
]
),
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_tf
@require_vision
class TFSwiftFormerModelIntegrationTest(unittest.TestCase):
@cached_property
def default_image_processor(self):
return ViTImageProcessor.from_pretrained("MBZUAI/swiftformer-xs") if is_vision_available() else None
@slow
def test_inference_image_classification_head(self):
model = TFSwiftFormerForImageClassification.from_pretrained("MBZUAI/swiftformer-xs")
image_processor = self.default_image_processor
image = prepare_img()
inputs = image_processor(images=image, return_tensors="tf")
# forward pass
outputs = model(**inputs)
# verify the logits
expected_shape = tf.TensorShape((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = tf.constant([[-2.1703e00, 2.1107e00, -2.0811e00]])
tf.debugging.assert_near(outputs.logits[0, :3], expected_slice, atol=1e-4)
@@ -697,6 +697,8 @@ OBJECTS_TO_IGNORE = [
"TFSegformerModel",
"TFSpeech2TextForConditionalGeneration",
"TFSpeech2TextModel",
"TFSwiftFormerForImageClassification",
"TFSwiftFormerModel",
"TFSwinForImageClassification", "TFSwinForImageClassification",
"TFSwinForMaskedImageModeling", "TFSwinForMaskedImageModeling",
"TFSwinModel", "TFSwinModel",
......