Unverified Commit 954e18ab authored by Sayak Paul, committed by GitHub

TensorFlow MobileViT (#18555)



* initial implementation.

* add: working model till image classification.

* add: initial implementation that passes intg tests.
Co-authored-by: Amy <aeroberts4444@gmail.com>

* chore: formatting.

* add: tests (still breaking because of config mismatch).
Co-authored-by: Yih <2521628+ydshieh@users.noreply.github.com>

* add: corrected tests and remaining changes.

* fix code style and repo consistency.

* address PR comments.

* address Amy's comments.

* chore: remove from_pt argument.

* chore: add full-stop.

* fix: TFLite model conversion in the doc.

* Update src/transformers/models/mobilevit/modeling_tf_mobilevit.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/mobilevit/modeling_tf_mobilevit.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/mobilevit/modeling_tf_mobilevit.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/mobilevit/modeling_tf_mobilevit.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/mobilevit/modeling_tf_mobilevit.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* apply formatting.

* chore: remove comments from the example block.

* remove indentation in the example.
Co-authored-by: Amy <aeroberts4444@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent fe58929a
......@@ -259,7 +259,7 @@ Flax), PyTorch, and/or TensorFlow.
| mBART | ✅ | ✅ | ✅ | ✅ | ✅ |
| Megatron-BERT | ❌ | ❌ | ✅ | ❌ | ❌ |
| MobileBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| MobileViT | ❌ | ❌ | ✅ | ❌ | ❌ |
| MobileViT | ❌ | ❌ | ✅ | ✅ | ❌ |
| MPNet | ✅ | ✅ | ✅ | ✅ | ❌ |
| MT5 | ✅ | ✅ | ✅ | ✅ | ✅ |
| MVP | ✅ | ✅ | ✅ | ❌ | ❌ |
......
......@@ -22,12 +22,40 @@ The abstract from the paper is the following:
Tips:
- MobileViT is more like a CNN than a Transformer model. It does not work on sequence data but on batches of images. Unlike ViT, there are no embeddings. The backbone model outputs a feature map.
- MobileViT is more like a CNN than a Transformer model. It does not work on sequence data but on batches of images. Unlike ViT, there are no embeddings. The backbone model outputs a feature map. You can follow [this tutorial](https://keras.io/examples/vision/mobilevit) for a lightweight introduction.
- One can use [`MobileViTFeatureExtractor`] to prepare images for the model (see the example after these tips). Note that if you do your own preprocessing, the pretrained checkpoints expect images to be in BGR pixel order (not RGB).
- The available image classification checkpoints are pre-trained on [ImageNet-1k](https://huggingface.co/datasets/imagenet-1k) (also referred to as ILSVRC 2012, a collection of 1.3 million images and 1,000 classes).
- The segmentation model uses a [DeepLabV3](https://arxiv.org/abs/1706.05587) head. The available semantic segmentation checkpoints are pre-trained on [PASCAL VOC](http://host.robots.ox.ac.uk/pascal/VOC/).
- As its name suggests, MobileViT was designed to be performant and efficient on mobile phones. The TensorFlow versions of the MobileViT models are fully compatible with [TensorFlow Lite](https://www.tensorflow.org/lite).
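For example, the following is a minimal sketch of classifying an image with the TensorFlow checkpoints (the checkpoint name and the `id2label` lookup are illustrative; any of the ImageNet-1k classification checkpoints mentioned above should work):

```py
import tensorflow as tf
import requests
from PIL import Image

from transformers import MobileViTFeatureExtractor, TFMobileViTForImageClassification

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# the feature extractor resizes the image and converts it to the BGR pixel order the checkpoints expect
feature_extractor = MobileViTFeatureExtractor.from_pretrained("apple/mobilevit-small")
model = TFMobileViTForImageClassification.from_pretrained("apple/mobilevit-small")

inputs = feature_extractor(images=image, return_tensors="tf")
outputs = model(**inputs)

# the classification head predicts one of the 1,000 ImageNet-1k classes
predicted_class_idx = int(tf.math.argmax(outputs.logits, axis=-1)[0])
print(model.config.id2label[predicted_class_idx])
```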
This model was contributed by [matthijs](https://huggingface.co/Matthijs). The original code and weights can be found [here](https://github.com/apple/ml-cvnets).
You can use the following code to convert a MobileViT checkpoint (either image classification or semantic segmentation) into a TensorFlow Lite model:
```py
from transformers import TFMobileViTForImageClassification
import tensorflow as tf
model_ckpt = "apple/mobilevit-xx-small"
model = TFMobileViTForImageClassification.from_pretrained(model_ckpt)
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS,
tf.lite.OpsSet.SELECT_TF_OPS,
]
tflite_model = converter.convert()
tflite_filename = model_ckpt.split("/")[-1] + ".tflite"
with open(tflite_filename, "wb") as f:
f.write(tflite_model)
```
The resulting model will be just **about an MB** in size, making it a good fit for mobile applications where resources and network bandwidth can be constrained.
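To sanity-check the exported file, you can run it with the TFLite interpreter. The sketch below is illustrative; the 256x256 resolution and channels-first layout are assumptions that mirror the model's `pixel_values` signature:

```py
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="mobilevit-xx-small.tflite")
input_details = interpreter.get_input_details()[0]

# the serving signature accepts dynamic shapes, so pin one before allocating tensors
interpreter.resize_tensor_input(input_details["index"], [1, 3, 256, 256])
interpreter.allocate_tensors()

dummy_pixel_values = np.random.uniform(size=(1, 3, 256, 256)).astype(np.float32)
interpreter.set_tensor(input_details["index"], dummy_pixel_values)
interpreter.invoke()

output_details = interpreter.get_output_details()[0]
print(interpreter.get_tensor(output_details["index"]).shape)
```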
This model was contributed by [matthijs](https://huggingface.co/Matthijs). The TensorFlow version of the model was contributed by [sayakpaul](https://huggingface.co/sayakpaul). The original code and weights can be found [here](https://github.com/apple/ml-cvnets).
## MobileViTConfig
......@@ -53,3 +81,18 @@ This model was contributed by [matthijs](https://huggingface.co/Matthijs). The o
[[autodoc]] MobileViTForSemanticSegmentation
- forward
## TFMobileViTModel
[[autodoc]] TFMobileViTModel
- call
## TFMobileViTForImageClassification
[[autodoc]] TFMobileViTForImageClassification
- call
## TFMobileViTForSemanticSegmentation
[[autodoc]] TFMobileViTForSemanticSegmentation
- call
......@@ -2398,6 +2398,15 @@ else:
"TFMobileBertPreTrainedModel",
]
)
_import_structure["models.mobilevit"].extend(
[
"TF_MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFMobileViTPreTrainedModel",
"TFMobileViTModel",
"TFMobileViTForImageClassification",
"TFMobileViTForSemanticSegmentation",
]
)
_import_structure["models.mpnet"].extend(
[
"TF_MPNET_PRETRAINED_MODEL_ARCHIVE_LIST",
......@@ -4847,6 +4856,7 @@ if TYPE_CHECKING:
from .models.mbart import TFMBartForConditionalGeneration, TFMBartModel, TFMBartPreTrainedModel
from .models.mobilebert import (
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TF_MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFMobileBertForMaskedLM,
TFMobileBertForMultipleChoice,
TFMobileBertForNextSentencePrediction,
......@@ -4857,6 +4867,10 @@ if TYPE_CHECKING:
TFMobileBertMainLayer,
TFMobileBertModel,
TFMobileBertPreTrainedModel,
TFMobileViTForImageClassification,
TFMobileViTForSemanticSegmentation,
TFMobileViTModel,
TFMobileViTPreTrainedModel,
)
from .models.mpnet import (
TF_MPNET_PRETRAINED_MODEL_ARCHIVE_LIST,
......
......@@ -685,6 +685,37 @@ class TFSemanticSegmenterOutput(ModelOutput):
attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFSemanticSegmenterOutputWithNoAttention(ModelOutput):
"""
Base class for outputs of semantic segmentation models that do not output attention scores.
Args:
loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (`tf.Tensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
Classification scores for each pixel.
<Tip warning={true}>
The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
to avoid doing two interpolations and losing some quality when a user needs to resize the logits to the
original image size as post-processing. You should always check your logits shape and resize as needed.
</Tip>
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
the output of each layer) of shape `(batch_size, patch_size, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
"""
loss: Optional[tf.Tensor] = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFImageClassifierOutput(ModelOutput):
"""
......
......@@ -59,6 +59,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
("marian", "TFMarianModel"),
("mbart", "TFMBartModel"),
("mobilebert", "TFMobileBertModel"),
("mobilevit", "TFMobileViTModel"),
("mpnet", "TFMPNetModel"),
("mt5", "TFMT5Model"),
("openai-gpt", "TFOpenAIGPTModel"),
......@@ -182,6 +183,7 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("convnext", "TFConvNextForImageClassification"),
("data2vec-vision", "TFData2VecVisionForImageClassification"),
("deit", ("TFDeiTForImageClassification", "TFDeiTForImageClassificationWithTeacher")),
("mobilevit", "TFMobileViTForImageClassification"),
("regnet", "TFRegNetForImageClassification"),
("resnet", "TFResNetForImageClassification"),
("segformer", "TFSegformerForImageClassification"),
......@@ -194,6 +196,7 @@ TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
[
# Model for Semantic Segmentation mapping
("data2vec-vision", "TFData2VecVisionForSemanticSegmentation"),
("mobilevit", "TFMobileViTForSemanticSegmentation"),
("segformer", "TFSegformerForSemanticSegmentation"),
]
)
......
......@@ -17,7 +17,13 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_tf_available,
is_torch_available,
is_vision_available,
)
_import_structure = {
......@@ -46,6 +52,19 @@ else:
"MobileViTPreTrainedModel",
]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_mobilevit"] = [
"TF_MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFMobileViTForImageClassification",
"TFMobileViTForSemanticSegmentation",
"TFMobileViTModel",
"TFMobileViTPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_mobilevit import MOBILEVIT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileViTConfig, MobileViTOnnxConfig
......@@ -72,6 +91,28 @@ if TYPE_CHECKING:
MobileViTPreTrainedModel,
)
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .feature_extraction_mobilevit import MobileViTFeatureExtractor
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_mobilevit import (
TF_MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFMobileViTForImageClassification,
TFMobileViTForSemanticSegmentation,
TFMobileViTModel,
TFMobileViTPreTrainedModel,
)
else:
import sys
......
# coding=utf-8
# Copyright 2022 Apple Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Original license: https://github.com/apple/ml-cvnets/blob/main/LICENSE
""" TensorFlow 2.0 MobileViT model."""
from typing import Dict, Optional, Tuple, Union
import tensorflow as tf
from ...activations_tf import get_tf_activation
from ...file_utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_tf_outputs import (
TFBaseModelOutput,
TFBaseModelOutputWithPooling,
TFImageClassifierOutputWithNoAttention,
TFSemanticSegmenterOutputWithNoAttention,
)
from ...modeling_tf_utils import TFPreTrainedModel, TFSequenceClassificationLoss, keras_serializable, unpack_inputs
from ...tf_utils import shape_list, stable_softmax
from ...utils import logging
from .configuration_mobilevit import MobileViTConfig
logger = logging.get_logger(__name__)
# General docstring
_CONFIG_FOR_DOC = "MobileViTConfig"
_FEAT_EXTRACTOR_FOR_DOC = "MobileViTFeatureExtractor"
# Base docstring
_CHECKPOINT_FOR_DOC = "apple/mobilevit-small"
_EXPECTED_OUTPUT_SHAPE = [1, 640, 8, 8]
# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "apple/mobilevit-small"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
TF_MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"apple/mobilevit-small",
"apple/mobilevit-x-small",
"apple/mobilevit-xx-small",
"apple/deeplabv3-mobilevit-small",
"apple/deeplabv3-mobilevit-x-small",
"apple/deeplabv3-mobilevit-xx-small",
# See all MobileViT models at https://huggingface.co/models?filter=mobilevit
]
def make_divisible(value: int, divisor: int = 8, min_value: Optional[int] = None) -> int:
"""
Ensure that all layers have a channel count that is divisible by `divisor`. This function is taken from the
original TensorFlow repo. It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
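For example (illustrative), `make_divisible(37, divisor=8)` rounds 37 up to 40.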
"""
if min_value is None:
min_value = divisor
new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_value < 0.9 * value:
new_value += divisor
return int(new_value)
class TFMobileViTConvLayer(tf.keras.layers.Layer):
def __init__(
self,
config: MobileViTConfig,
out_channels: int,
kernel_size: int,
stride: int = 1,
groups: int = 1,
bias: bool = False,
dilation: int = 1,
use_normalization: bool = True,
use_activation: Union[bool, str] = True,
**kwargs
) -> None:
super().__init__(**kwargs)
logger.warning(
f"\n{self.__class__.__name__} has backpropagation operations that are NOT supported on CPU. If you wish "
"to train/fine-tune this model, you need a GPU or a TPU"
)
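# Pad explicitly and run a VALID convolution so the result matches the PyTorch implementation's
# symmetric padding of `(kernel_size - 1) // 2 * dilation` pixels on each side.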
padding = int((kernel_size - 1) / 2) * dilation
self.padding = tf.keras.layers.ZeroPadding2D(padding)
if out_channels % groups != 0:
raise ValueError(f"Output channels ({out_channels}) are not divisible by {groups} groups.")
self.convolution = tf.keras.layers.Conv2D(
filters=out_channels,
kernel_size=kernel_size,
strides=stride,
padding="VALID",
dilation_rate=dilation,
groups=groups,
use_bias=bias,
name="convolution",
)
if use_normalization:
self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="normalization")
else:
self.normalization = None
if use_activation:
if isinstance(use_activation, str):
self.activation = get_tf_activation(use_activation)
elif isinstance(config.hidden_act, str):
self.activation = get_tf_activation(config.hidden_act)
else:
self.activation = config.hidden_act
else:
self.activation = None
def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor:
padded_features = self.padding(features)
features = self.convolution(padded_features)
if self.normalization is not None:
features = self.normalization(features, training=training)
if self.activation is not None:
features = self.activation(features)
return features
class TFMobileViTInvertedResidual(tf.keras.layers.Layer):
"""
Inverted residual block (MobileNetv2): https://arxiv.org/abs/1801.04381
"""
def __init__(
self, config: MobileViTConfig, in_channels: int, out_channels: int, stride: int, dilation: int = 1, **kwargs
) -> None:
super().__init__(**kwargs)
expanded_channels = make_divisible(int(round(in_channels * config.expand_ratio)), 8)
if stride not in [1, 2]:
raise ValueError(f"Invalid stride {stride}.")
self.use_residual = (stride == 1) and (in_channels == out_channels)
self.expand_1x1 = TFMobileViTConvLayer(
config, out_channels=expanded_channels, kernel_size=1, name="expand_1x1"
)
self.conv_3x3 = TFMobileViTConvLayer(
config,
out_channels=expanded_channels,
kernel_size=3,
stride=stride,
groups=expanded_channels,
dilation=dilation,
name="conv_3x3",
)
self.reduce_1x1 = TFMobileViTConvLayer(
config,
out_channels=out_channels,
kernel_size=1,
use_activation=False,
name="reduce_1x1",
)
def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor:
residual = features
features = self.expand_1x1(features, training=training)
features = self.conv_3x3(features, training=training)
features = self.reduce_1x1(features, training=training)
return residual + features if self.use_residual else features
class TFMobileViTMobileNetLayer(tf.keras.layers.Layer):
def __init__(
self,
config: MobileViTConfig,
in_channels: int,
out_channels: int,
stride: int = 1,
num_stages: int = 1,
**kwargs
) -> None:
super().__init__(**kwargs)
self.layers = []
for i in range(num_stages):
layer = TFMobileViTInvertedResidual(
config,
in_channels=in_channels,
out_channels=out_channels,
stride=stride if i == 0 else 1,
name=f"layer.{i}",
)
self.layers.append(layer)
in_channels = out_channels
def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor:
for layer_module in self.layers:
features = layer_module(features, training=training)
return features
class TFMobileViTSelfAttention(tf.keras.layers.Layer):
def __init__(self, config: MobileViTConfig, hidden_size: int, **kwargs) -> None:
super().__init__(**kwargs)
if hidden_size % config.num_attention_heads != 0:
raise ValueError(
f"The hidden size {hidden_size} is not a multiple of the number of attention "
f"heads {config.num_attention_heads}."
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
scale = tf.cast(self.attention_head_size, dtype=tf.float32)
self.scale = tf.math.sqrt(scale)
self.query = tf.keras.layers.Dense(self.all_head_size, use_bias=config.qkv_bias, name="query")
self.key = tf.keras.layers.Dense(self.all_head_size, use_bias=config.qkv_bias, name="key")
self.value = tf.keras.layers.Dense(self.all_head_size, use_bias=config.qkv_bias, name="value")
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor:
batch_size = tf.shape(x)[0]
x = tf.reshape(x, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
return tf.transpose(x, perm=[0, 2, 1, 3])
def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
batch_size = tf.shape(hidden_states)[0]
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(self.query(hidden_states))
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
attention_scores = attention_scores / self.scale
# Normalize the attention scores to probabilities.
attention_probs = stable_softmax(attention_scores, axis=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs, training=training)
context_layer = tf.matmul(attention_probs, value_layer)
context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
context_layer = tf.reshape(context_layer, shape=(batch_size, -1, self.all_head_size))
return context_layer
class TFMobileViTSelfOutput(tf.keras.layers.Layer):
def __init__(self, config: MobileViTConfig, hidden_size: int, **kwargs) -> None:
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(hidden_size, name="dense")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
return hidden_states
class TFMobileViTAttention(tf.keras.layers.Layer):
def __init__(self, config: MobileViTConfig, hidden_size: int, **kwargs) -> None:
super().__init__(**kwargs)
self.attention = TFMobileViTSelfAttention(config, hidden_size, name="attention")
self.dense_output = TFMobileViTSelfOutput(config, hidden_size, name="output")
def prune_heads(self, heads):
raise NotImplementedError
def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
self_outputs = self.attention(hidden_states, training=training)
attention_output = self.dense_output(self_outputs, training=training)
return attention_output
class TFMobileViTIntermediate(tf.keras.layers.Layer):
def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int, **kwargs) -> None:
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(intermediate_size, name="dense")
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class TFMobileViTOutput(tf.keras.layers.Layer):
def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int, **kwargs) -> None:
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(hidden_size, name="dense")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = hidden_states + input_tensor
return hidden_states
class TFMobileViTTransformerLayer(tf.keras.layers.Layer):
def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int, **kwargs) -> None:
super().__init__(**kwargs)
self.attention = TFMobileViTAttention(config, hidden_size, name="attention")
self.intermediate = TFMobileViTIntermediate(config, hidden_size, intermediate_size, name="intermediate")
self.mobilevit_output = TFMobileViTOutput(config, hidden_size, intermediate_size, name="output")
self.layernorm_before = tf.keras.layers.LayerNormalization(
epsilon=config.layer_norm_eps, name="layernorm_before"
)
self.layernorm_after = tf.keras.layers.LayerNormalization(
epsilon=config.layer_norm_eps, name="layernorm_after"
)
def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
attention_output = self.attention(self.layernorm_before(hidden_states), training=training)
hidden_states = attention_output + hidden_states
layer_output = self.layernorm_after(hidden_states)
layer_output = self.intermediate(layer_output)
layer_output = self.mobilevit_output(layer_output, hidden_states, training=training)
return layer_output
class TFMobileViTTransformer(tf.keras.layers.Layer):
def __init__(self, config: MobileViTConfig, hidden_size: int, num_stages: int, **kwargs) -> None:
super().__init__(**kwargs)
self.layers = []
for i in range(num_stages):
transformer_layer = TFMobileViTTransformerLayer(
config,
hidden_size=hidden_size,
intermediate_size=int(hidden_size * config.mlp_ratio),
name=f"layer.{i}",
)
self.layers.append(transformer_layer)
def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
for layer_module in self.layers:
hidden_states = layer_module(hidden_states, training=training)
return hidden_states
class TFMobileViTLayer(tf.keras.layers.Layer):
"""
MobileViT block: https://arxiv.org/abs/2110.02178
"""
def __init__(
self,
config: MobileViTConfig,
in_channels: int,
out_channels: int,
stride: int,
hidden_size: int,
num_stages: int,
dilation: int = 1,
**kwargs
) -> None:
super().__init__(**kwargs)
self.patch_width = config.patch_size
self.patch_height = config.patch_size
if stride == 2:
self.downsampling_layer = TFMobileViTInvertedResidual(
config,
in_channels=in_channels,
out_channels=out_channels,
stride=stride if dilation == 1 else 1,
dilation=dilation // 2 if dilation > 1 else 1,
name="downsampling_layer",
)
in_channels = out_channels
else:
self.downsampling_layer = None
self.conv_kxk = TFMobileViTConvLayer(
config, out_channels=in_channels, kernel_size=config.conv_kernel_size, name="conv_kxk"
)
self.conv_1x1 = TFMobileViTConvLayer(
config,
out_channels=hidden_size,
kernel_size=1,
use_normalization=False,
use_activation=False,
name="conv_1x1",
)
self.transformer = TFMobileViTTransformer(
config, hidden_size=hidden_size, num_stages=num_stages, name="transformer"
)
self.layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
self.conv_projection = TFMobileViTConvLayer(
config, out_channels=in_channels, kernel_size=1, name="conv_projection"
)
self.fusion = TFMobileViTConvLayer(
config, out_channels=in_channels, kernel_size=config.conv_kernel_size, name="fusion"
)
def unfolding(self, features: tf.Tensor) -> Tuple[tf.Tensor, Dict]:
patch_width, patch_height = self.patch_width, self.patch_height
patch_area = tf.cast(patch_width * patch_height, "int32")
batch_size = tf.shape(features)[0]
orig_height = tf.shape(features)[1]
orig_width = tf.shape(features)[2]
channels = tf.shape(features)[3]
new_height = tf.cast(tf.math.ceil(orig_height / patch_height) * patch_height, "int32")
new_width = tf.cast(tf.math.ceil(orig_width / patch_width) * patch_width, "int32")
interpolate = new_width != orig_width or new_height != orig_height
if interpolate:
# Note: Padding can be done, but then it needs to be handled in the attention function.
features = tf.image.resize(features, size=(new_height, new_width), method="bilinear")
# number of patches along width and height
num_patch_width = new_width // patch_width
num_patch_height = new_height // patch_height
num_patches = num_patch_height * num_patch_width
# convert from shape (batch_size, orig_height, orig_width, channels)
# to the shape (batch_size * patch_area, num_patches, channels)
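# e.g. a (1, 32, 32, 96) feature map with 2x2 patches becomes (4, 256, 96):
# 4 pixels per patch on the batch axis, 256 patches on the sequence axis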
features = tf.transpose(features, [0, 3, 1, 2])
patches = tf.reshape(
features, (batch_size * channels * num_patch_height, patch_height, num_patch_width, patch_width)
)
patches = tf.transpose(patches, [0, 2, 1, 3])
patches = tf.reshape(patches, (batch_size, channels, num_patches, patch_area))
patches = tf.transpose(patches, [0, 3, 2, 1])
patches = tf.reshape(patches, (batch_size * patch_area, num_patches, channels))
info_dict = {
"orig_size": (orig_height, orig_width),
"batch_size": batch_size,
"channels": channels,
"interpolate": interpolate,
"num_patches": num_patches,
"num_patches_width": num_patch_width,
"num_patches_height": num_patch_height,
}
return patches, info_dict
def folding(self, patches: tf.Tensor, info_dict: Dict) -> tf.Tensor:
patch_width, patch_height = self.patch_width, self.patch_height
patch_area = int(patch_width * patch_height)
batch_size = info_dict["batch_size"]
channels = info_dict["channels"]
num_patches = info_dict["num_patches"]
num_patch_height = info_dict["num_patches_height"]
num_patch_width = info_dict["num_patches_width"]
# convert from shape (batch_size * patch_area, num_patches, channels)
# back to shape (batch_size, channels, orig_height, orig_width)
features = tf.reshape(patches, (batch_size, patch_area, num_patches, -1))
features = tf.transpose(features, perm=(0, 3, 2, 1))
features = tf.reshape(
features, (batch_size * channels * num_patch_height, num_patch_width, patch_height, patch_width)
)
features = tf.transpose(features, perm=(0, 2, 1, 3))
features = tf.reshape(
features, (batch_size, channels, num_patch_height * patch_height, num_patch_width * patch_width)
)
features = tf.transpose(features, perm=(0, 2, 3, 1))
if info_dict["interpolate"]:
features = tf.image.resize(features, size=info_dict["orig_size"], method="bilinear")
return features
def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor:
# reduce spatial dimensions if needed
if self.downsampling_layer:
features = self.downsampling_layer(features, training=training)
residual = features
# local representation
features = self.conv_kxk(features, training=training)
features = self.conv_1x1(features, training=training)
# convert feature map to patches
patches, info_dict = self.unfolding(features)
# learn global representations
patches = self.transformer(patches, training=training)
patches = self.layernorm(patches)
# convert patches back to feature maps
features = self.folding(patches, info_dict)
features = self.conv_projection(features, training=training)
features = self.fusion(tf.concat([residual, features], axis=-1), training=training)
return features
class TFMobileViTEncoder(tf.keras.layers.Layer):
def __init__(self, config: MobileViTConfig, **kwargs) -> None:
super().__init__(**kwargs)
self.config = config
self.layers = []
# segmentation architectures like DeepLab and PSPNet modify the strides
# of the classification backbones
dilate_layer_4 = dilate_layer_5 = False
if config.output_stride == 8:
dilate_layer_4 = True
dilate_layer_5 = True
elif config.output_stride == 16:
dilate_layer_5 = True
dilation = 1
layer_1 = TFMobileViTMobileNetLayer(
config,
in_channels=config.neck_hidden_sizes[0],
out_channels=config.neck_hidden_sizes[1],
stride=1,
num_stages=1,
name="layer.0",
)
self.layers.append(layer_1)
layer_2 = TFMobileViTMobileNetLayer(
config,
in_channels=config.neck_hidden_sizes[1],
out_channels=config.neck_hidden_sizes[2],
stride=2,
num_stages=3,
name="layer.1",
)
self.layers.append(layer_2)
layer_3 = TFMobileViTLayer(
config,
in_channels=config.neck_hidden_sizes[2],
out_channels=config.neck_hidden_sizes[3],
stride=2,
hidden_size=config.hidden_sizes[0],
num_stages=2,
name="layer.2",
)
self.layers.append(layer_3)
if dilate_layer_4:
dilation *= 2
layer_4 = TFMobileViTLayer(
config,
in_channels=config.neck_hidden_sizes[3],
out_channels=config.neck_hidden_sizes[4],
stride=2,
hidden_size=config.hidden_sizes[1],
num_stages=4,
dilation=dilation,
name="layer.3",
)
self.layers.append(layer_4)
if dilate_layer_5:
dilation *= 2
layer_5 = TFMobileViTLayer(
config,
in_channels=config.neck_hidden_sizes[4],
out_channels=config.neck_hidden_sizes[5],
stride=2,
hidden_size=config.hidden_sizes[2],
num_stages=3,
dilation=dilation,
name="layer.4",
)
self.layers.append(layer_5)
def call(
self,
hidden_states: tf.Tensor,
output_hidden_states: bool = False,
return_dict: bool = True,
training: bool = False,
) -> Union[tuple, TFBaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
for i, layer_module in enumerate(self.layers):
hidden_states = layer_module(hidden_states, training=training)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
return TFBaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
@keras_serializable
class TFMobileViTMainLayer(tf.keras.layers.Layer):
config_class = MobileViTConfig
def __init__(self, config: MobileViTConfig, expand_output: bool = True, **kwargs):
super().__init__(**kwargs)
self.config = config
self.expand_output = expand_output
self.conv_stem = TFMobileViTConvLayer(
config,
out_channels=config.neck_hidden_sizes[0],
kernel_size=3,
stride=2,
name="conv_stem",
)
self.encoder = TFMobileViTEncoder(config, name="encoder")
if self.expand_output:
self.conv_1x1_exp = TFMobileViTConvLayer(
config, out_channels=config.neck_hidden_sizes[6], kernel_size=1, name="conv_1x1_exp"
)
self.pooler = tf.keras.layers.GlobalAveragePooling2D(data_format="channels_first", name="pooler")
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
raise NotImplementedError
@unpack_inputs
def call(
self,
pixel_values: Optional[tf.Tensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
) -> Union[Tuple[tf.Tensor], TFBaseModelOutputWithPooling]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format.
# So change the input format from `NCHW` to `NHWC`.
# shape = (batch_size, in_height, in_width, in_channels=num_channels)
pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
embedding_output = self.conv_stem(pixel_values, training=training)
encoder_outputs = self.encoder(
embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training
)
if self.expand_output:
last_hidden_state = self.conv_1x1_exp(encoder_outputs[0])
# Change to NCHW output format to have uniformity in the modules
last_hidden_state = tf.transpose(last_hidden_state, perm=[0, 3, 1, 2])
# global average pooling: (batch_size, channels, height, width) -> (batch_size, channels)
pooled_output = self.pooler(last_hidden_state)
else:
last_hidden_state = encoder_outputs[0]
# Change to NCHW output format to have uniformity in the modules
last_hidden_state = tf.transpose(last_hidden_state, perm=[0, 3, 1, 2])
pooled_output = None
if not return_dict:
output = (last_hidden_state, pooled_output) if pooled_output is not None else (last_hidden_state,)
# Change to NCHW output format to have uniformity in the modules
if not self.expand_output:
remaining_encoder_outputs = encoder_outputs[1:]
remaining_encoder_outputs = tuple(
[tf.transpose(h, perm=(0, 3, 1, 2)) for h in remaining_encoder_outputs[0]]
)
remaining_encoder_outputs = (remaining_encoder_outputs,)
return output + remaining_encoder_outputs
else:
return output + encoder_outputs[1:]
# Change the other hidden state outputs to NCHW as well
if output_hidden_states:
hidden_states = tuple([tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]])
return TFBaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states,
)
class TFMobileViTPreTrainedModel(TFPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = MobileViTConfig
base_model_prefix = "mobilevit"
main_input_name = "pixel_values"
@property
def dummy_inputs(self) -> Dict[str, tf.Tensor]:
"""
Dummy inputs to build the network.
Returns:
`Dict[str, tf.Tensor]`: The dummy inputs.
"""
VISION_DUMMY_INPUTS = tf.random.uniform(
shape=(3, self.config.num_channels, self.config.image_size, self.config.image_size),
dtype=tf.float32,
)
return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)}
@tf.function(
input_signature=[
{
"pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"),
}
]
)
def serving(self, inputs):
"""
Method used for serving the model.
Args:
inputs (`Dict[str, tf.Tensor]`):
The input of the saved model as a dictionary of tensors.
"""
output = self.call(inputs)
return self.serving_output(output)
MOBILEVIT_START_DOCSTRING = r"""
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
etc.)
This model is also a [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
behavior.
<Tip>
TF 2.0 models accept two formats as inputs:
- having all inputs as keyword arguments (like PyTorch models), or
- having all inputs as a list, tuple or dict in the first positional argument.
This second option is useful when using the [`tf.keras.Model.fit`] method, which currently requires having all the
tensors in the first argument of the model call function: `model(inputs)`.
</Tip>
Parameters:
config ([`MobileViTConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""
MOBILEVIT_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`MobileViTFeatureExtractor`]. See
[`MobileViTFeatureExtractor.__call__`] for details.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail. This argument can be used only in eager mode; in graph mode the value in the config will be
used instead.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
eager mode; in graph mode the value will always be set to True.
"""
@add_start_docstrings(
"The bare MobileViT model outputting raw hidden-states without any specific head on top.",
MOBILEVIT_START_DOCSTRING,
)
class TFMobileViTModel(TFMobileViTPreTrainedModel):
def __init__(self, config: MobileViTConfig, expand_output: bool = True, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.config = config
self.expand_output = expand_output
self.mobilevit = TFMobileViTMainLayer(config, expand_output=expand_output, name="mobilevit")
@unpack_inputs
@add_start_docstrings_to_model_forward(MOBILEVIT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFBaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def call(
self,
pixel_values: Optional[tf.Tensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
) -> Union[Tuple[tf.Tensor], TFBaseModelOutputWithPooling]:
output = self.mobilevit(pixel_values, output_hidden_states, return_dict, training=training)
return output
def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
# hidden_states not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions
return TFBaseModelOutputWithPooling(
last_hidden_state=output.last_hidden_state,
pooler_output=output.pooler_output,
hidden_states=output.hidden_states,
)
@add_start_docstrings(
"""
MobileViT model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
ImageNet.
""",
MOBILEVIT_START_DOCSTRING,
)
class TFMobileViTForImageClassification(TFMobileViTPreTrainedModel, TFSequenceClassificationLoss):
def __init__(self, config: MobileViTConfig, *inputs, **kwargs) -> None:
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.mobilevit = TFMobileViTMainLayer(config, name="mobilevit")
# Classifier head
self.dropout = tf.keras.layers.Dropout(config.classifier_dropout_prob)
self.classifier = (
tf.keras.layers.Dense(config.num_labels, name="classifier") if config.num_labels > 0 else tf.identity
)
@unpack_inputs
@add_start_docstrings_to_model_forward(MOBILEVIT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=TFImageClassifierOutputWithNoAttention,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
def call(
self,
pixel_values: Optional[tf.Tensor] = None,
output_hidden_states: Optional[bool] = None,
labels: Optional[tf.Tensor] = None,
return_dict: Optional[bool] = None,
training: Optional[bool] = False,
) -> Union[tuple, TFImageClassifierOutputWithNoAttention]:
r"""
labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.mobilevit(
pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training
)
pooled_output = outputs.pooler_output if return_dict else outputs[1]
logits = self.classifier(self.dropout(pooled_output, training=training))
loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TFImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
def serving_output(self, output: TFImageClassifierOutputWithNoAttention) -> TFImageClassifierOutputWithNoAttention:
# hidden_states and attention not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions
return TFImageClassifierOutputWithNoAttention(logits=output.logits, hidden_states=output.hidden_states)
class TFMobileViTASPPPooling(tf.keras.layers.Layer):
def __init__(self, config: MobileViTConfig, out_channels: int, **kwargs) -> None:
super().__init__(**kwargs)
self.global_pool = tf.keras.layers.GlobalAveragePooling2D(keepdims=True, name="global_pool")
self.conv_1x1 = TFMobileViTConvLayer(
config,
out_channels=out_channels,
kernel_size=1,
stride=1,
use_normalization=True,
use_activation="relu",
name="conv_1x1",
)
def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor:
spatial_size = shape_list(features)[1:-1]
features = self.global_pool(features)
features = self.conv_1x1(features, training=training)
features = tf.image.resize(features, size=spatial_size, method="bilinear")
return features
class TFMobileViTASPP(tf.keras.layers.Layer):
"""
ASPP module defined in DeepLab papers: https://arxiv.org/abs/1606.00915, https://arxiv.org/abs/1706.05587
"""
def __init__(self, config: MobileViTConfig, **kwargs) -> None:
super().__init__(**kwargs)
out_channels = config.aspp_out_channels
if len(config.atrous_rates) != 3:
raise ValueError("Expected 3 values for atrous_rates")
self.convs = []
in_projection = TFMobileViTConvLayer(
config,
out_channels=out_channels,
kernel_size=1,
use_activation="relu",
name="convs.0",
)
self.convs.append(in_projection)
self.convs.extend(
[
TFMobileViTConvLayer(
config,
out_channels=out_channels,
kernel_size=3,
dilation=rate,
use_activation="relu",
name=f"convs.{i + 1}",
)
for i, rate in enumerate(config.atrous_rates)
]
)
pool_layer = TFMobileViTASPPPooling(config, out_channels, name=f"convs.{len(config.atrous_rates) + 1}")
self.convs.append(pool_layer)
self.project = TFMobileViTConvLayer(
config,
out_channels=out_channels,
kernel_size=1,
use_activation="relu",
name="project",
)
self.dropout = tf.keras.layers.Dropout(config.aspp_dropout_prob)
def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor:
# since the hidden states were transposed to have `(batch_size, channels, height, width)`
# layout, we transpose them back to have `(batch_size, height, width, channels)` layout.
features = tf.transpose(features, perm=[0, 2, 3, 1])
pyramid = []
for conv in self.convs:
pyramid.append(conv(features, training=training))
pyramid = tf.concat(pyramid, axis=-1)
pooled_features = self.project(pyramid, training=training)
pooled_features = self.dropout(pooled_features, training=training)
return pooled_features
class TFMobileViTDeepLabV3(tf.keras.layers.Layer):
"""
DeepLabv3 architecture: https://arxiv.org/abs/1706.05587
"""
def __init__(self, config: MobileViTConfig, **kwargs) -> None:
super().__init__(**kwargs)
self.aspp = TFMobileViTASPP(config, name="aspp")
self.dropout = tf.keras.layers.Dropout(config.classifier_dropout_prob)
self.classifier = TFMobileViTConvLayer(
config,
out_channels=config.num_labels,
kernel_size=1,
use_normalization=False,
use_activation=False,
bias=True,
name="classifier",
)
def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
features = self.aspp(hidden_states[-1], training=training)
features = self.dropout(features, training=training)
features = self.classifier(features, training=training)
return features
@add_start_docstrings(
"""
MobileViT model with a semantic segmentation head on top, e.g. for Pascal VOC.
""",
MOBILEVIT_START_DOCSTRING,
)
class TFMobileViTForSemanticSegmentation(TFMobileViTPreTrainedModel):
def __init__(self, config: MobileViTConfig, **kwargs) -> None:
super().__init__(config, **kwargs)
self.num_labels = config.num_labels
self.mobilevit = TFMobileViTMainLayer(config, expand_output=False, name="mobilevit")
self.segmentation_head = TFMobileViTDeepLabV3(config, name="segmentation_head")
def hf_compute_loss(self, logits, labels):
# upsample logits to the images' original size
# `labels` is of shape (batch_size, height, width)
label_interp_shape = shape_list(labels)[1:]
upsampled_logits = tf.image.resize(logits, size=label_interp_shape, method="bilinear")
# compute weighted loss
loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")
def masked_loss(real, pred):
unmasked_loss = loss_fct(real, pred)
mask = tf.cast(real != self.config.semantic_loss_ignore_index, dtype=unmasked_loss.dtype)
masked_loss = unmasked_loss * mask
# Reduction strategy in a similar spirit to
# https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_utils.py#L210
reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(mask)
return tf.reshape(reduced_masked_loss, (1,))
return masked_loss(labels, upsampled_logits)
@unpack_inputs
@add_start_docstrings_to_model_forward(MOBILEVIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFSemanticSegmenterOutputWithNoAttention, config_class=_CONFIG_FOR_DOC)
def call(
self,
pixel_values: Optional[tf.Tensor] = None,
labels: Optional[tf.Tensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
) -> Union[tuple, TFSemanticSegmenterOutputWithNoAttention]:
r"""
labels (`tf.Tensor` of shape `(batch_size, height, width)`, *optional*):
Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
Returns:
Examples:
```python
>>> from transformers import MobileViTFeatureExtractor, TFMobileViTForSemanticSegmentation
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = MobileViTFeatureExtractor.from_pretrained("apple/deeplabv3-mobilevit-small")
>>> model = TFMobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-small")
>>> inputs = feature_extractor(images=image, return_tensors="tf")
>>> outputs = model(**inputs)
>>> # logits are of shape (batch_size, num_labels, height, width)
>>> logits = outputs.logits
```"""
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.mobilevit(
pixel_values,
output_hidden_states=True, # we need the intermediate hidden states
return_dict=return_dict,
training=training,
)
encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]
logits = self.segmentation_head(encoder_hidden_states, training=training)
loss = None
if labels is not None:
if not self.config.num_labels > 1:
raise ValueError("The number of labels should be greater than one")
else:
loss = self.hf_compute_loss(logits=logits, labels=labels)
# make logits of shape (batch_size, num_labels, height, width) to
# keep them consistent across APIs
logits = tf.transpose(logits, perm=[0, 3, 1, 2])
if not return_dict:
if output_hidden_states:
output = (logits,) + outputs[1:]
else:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TFSemanticSegmenterOutputWithNoAttention(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states if output_hidden_states else None,
)
def serving_output(
self, output: TFSemanticSegmenterOutputWithNoAttention
) -> TFSemanticSegmenterOutputWithNoAttention:
return TFSemanticSegmenterOutputWithNoAttention(logits=output.logits, hidden_states=output.hidden_states)
......@@ -1524,6 +1524,9 @@ class TFMBartPreTrainedModel(metaclass=DummyObject):
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
TF_MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST = None
class TFMobileBertForMaskedLM(metaclass=DummyObject):
_backends = ["tf"]
......@@ -1594,6 +1597,34 @@ class TFMobileBertPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["tf"])
class TFMobileViTForImageClassification(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFMobileViTForSemanticSegmentation(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFMobileViTModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFMobileViTPreTrainedModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
TF_MPNET_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the TensorFlow MobileViT model. """
import inspect
import unittest
from transformers import MobileViTConfig
from transformers.file_utils import is_tf_available, is_vision_available
from transformers.testing_utils import require_tf, slow
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
if is_tf_available():
import numpy as np
import tensorflow as tf
from transformers import TFMobileViTForImageClassification, TFMobileViTForSemanticSegmentation, TFMobileViTModel
from transformers.models.mobilevit.modeling_tf_mobilevit import TF_MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available():
from PIL import Image
from transformers import MobileViTFeatureExtractor
class TFMobileViTConfigTester(ConfigTester):
def create_and_test_config_common_properties(self):
config = self.config_class(**self.inputs_dict)
self.parent.assertTrue(hasattr(config, "hidden_sizes"))
self.parent.assertTrue(hasattr(config, "neck_hidden_sizes"))
self.parent.assertTrue(hasattr(config, "num_attention_heads"))
class TFMobileViTModelTester:
def __init__(
self,
parent,
batch_size=13,
image_size=32,
patch_size=2,
num_channels=3,
last_hidden_size=640,
num_attention_heads=4,
hidden_act="silu",
conv_kernel_size=3,
output_stride=32,
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
classifier_dropout_prob=0.1,
initializer_range=0.02,
is_training=True,
use_labels=True,
num_labels=10,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.last_hidden_size = last_hidden_size
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.conv_kernel_size = conv_kernel_size
self.output_stride = output_stride
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.classifier_dropout_prob = classifier_dropout_prob
self.use_labels = use_labels
self.is_training = is_training
self.num_labels = num_labels
self.initializer_range = initializer_range
self.scope = scope
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
labels = None
pixel_labels = None
if self.use_labels:
labels = ids_tensor([self.batch_size], self.num_labels)
pixel_labels = ids_tensor([self.batch_size, self.image_size, self.image_size], self.num_labels)
config = self.get_config()
return config, pixel_values, labels, pixel_labels
def get_config(self):
return MobileViTConfig(
image_size=self.image_size,
patch_size=self.patch_size,
num_channels=self.num_channels,
num_attention_heads=self.num_attention_heads,
hidden_act=self.hidden_act,
conv_kernel_size=self.conv_kernel_size,
output_stride=self.output_stride,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
classifier_dropout_prob=self.classifier_dropout_prob,
initializer_range=self.initializer_range,
)
def create_and_check_model(self, config, pixel_values, labels, pixel_labels):
model = TFMobileViTModel(config=config)
result = model(pixel_values, training=False)
expected_height = expected_width = self.image_size // self.output_stride
self.parent.assertEqual(
result.last_hidden_state.shape, (self.batch_size, self.last_hidden_size, expected_height, expected_width)
)
def create_and_check_for_image_classification(self, config, pixel_values, labels, pixel_labels):
config.num_labels = self.num_labels
model = TFMobileViTForImageClassification(config)
result = model(pixel_values, labels=labels, training=False)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_for_semantic_segmentation(self, config, pixel_values, labels, pixel_labels):
config.num_labels = self.num_labels
model = TFMobileViTForSemanticSegmentation(config)
expected_height = expected_width = self.image_size // self.output_stride
result = model(pixel_values, training=False)
self.parent.assertEqual(
result.logits.shape, (self.batch_size, self.num_labels, expected_height, expected_width)
)
result = model(pixel_values, labels=pixel_labels, training=False)
self.parent.assertEqual(
result.logits.shape, (self.batch_size, self.num_labels, expected_height, expected_width)
)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels, pixel_labels = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_tf
class MobileViTModelTest(TFModelTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_common.py, as MobileViT does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
all_model_classes = (
(TFMobileViTModel, TFMobileViTForImageClassification, TFMobileViTForSemanticSegmentation)
if is_tf_available()
else ()
)
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
has_attentions = False
test_onnx = False
def setUp(self):
self.model_tester = TFMobileViTModelTester(self)
self.config_tester = TFMobileViTConfigTester(self, config_class=MobileViTConfig, has_text_modality=False)
def test_config(self):
self.config_tester.run_common_tests()
@unittest.skip(reason="MobileViT does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="MobileViT does not support input and output embeddings")
def test_model_common_attributes(self):
pass
@unittest.skip(reason="MobileViT does not output attentions")
def test_attention_outputs(self):
pass
@unittest.skip("Test was written for TF 1.x and isn't really relevant here")
def test_compile_tf_model(self):
pass
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.call)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.hidden_states
expected_num_stages = 5
self.assertEqual(len(hidden_states), expected_num_stages)
# MobileViT's feature maps are of shape (batch_size, num_channels, height, width)
# with the width and height being successively divided by 2.
divisor = 2
for i in range(len(hidden_states)):
self.assertListEqual(
list(hidden_states[i].shape[-2:]),
[self.model_tester.image_size // divisor, self.model_tester.image_size // divisor],
)
divisor *= 2
self.assertEqual(self.model_tester.output_stride, divisor // 2)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
def test_for_semantic_segmentation(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_semantic_segmentation(*config_and_inputs)
@unittest.skipIf(
not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0,
reason="TF does not support backprop for grouped convolutions on CPU.",
)
def test_dataset_conversion(self):
super().test_dataset_conversion()
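# Validation losses from `keras.fit()` are compared with relaxed tolerances.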
def check_keras_fit_results(self, val_loss1, val_loss2, atol=2e-1, rtol=2e-1):
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=atol, rtol=rtol))
@unittest.skipIf(
not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0,
reason="TF does not support backprop for grouped convolutions on CPU.",
)
def test_keras_fit(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
# `TFMobileViTModel` cannot operate with the default `fit()` method, so it is skipped here.
if model_class.__name__ != "TFMobileViTModel":
model = model_class(config)
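# Only run `fit()` for heads that define a loss (`hf_compute_loss`).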
if getattr(model, "hf_compute_loss", None):
super().test_keras_fit()
# The default test_loss_computation() uses -100 as a proxy ignore_index
# to test masked losses. Overriding to avoid -100 since semantic segmentation
# models use `semantic_loss_ignore_index` from the config.
def test_loss_computation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
# set an ignore index to correctly test the masked loss used in
# `TFMobileViTForSemanticSegmentation`.
if model_class.__name__ != "TFMobileViTForSemanticSegmentation":
config.semantic_loss_ignore_index = 5
model = model_class(config)
if getattr(model, "hf_compute_loss", None):
# The number of elements in the loss should be the same as the number of elements in the label
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
added_label = prepared_for_class[
sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0]
]
expected_loss_size = added_label.shape.as_list()[:1]
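# The computed loss should either keep the batch dimension of the labels or be reduced to a single element.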
# Test that the model correctly computes the loss with kwargs
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
possible_input_names = {"input_ids", "pixel_values", "input_features"}
input_name = possible_input_names.intersection(set(prepared_for_class)).pop()
model_input = prepared_for_class.pop(input_name)
loss = model(model_input, **prepared_for_class)[0]
self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])
# Test that the model correctly computes the loss when we mask some positions
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
possible_input_names = {"input_ids", "pixel_values", "input_features"}
input_name = possible_input_names.intersection(set(prepared_for_class)).pop()
model_input = prepared_for_class.pop(input_name)
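# Unlike the default common test, no -100 values are injected into the labels; masking relies on `semantic_loss_ignore_index` from the config instead.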
if "labels" in prepared_for_class:
labels = prepared_for_class["labels"].numpy()
if len(labels.shape) > 1 and labels.shape[1] != 1:
# labels[0] = -100
prepared_for_class["labels"] = tf.convert_to_tensor(labels)
loss = model(model_input, **prepared_for_class)[0]
self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])
self.assertTrue(not np.any(np.isnan(loss.numpy())))
# Test that the model correctly computes the loss with a dict
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
loss = model(prepared_for_class)[0]
self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])
# Test that the model correctly computes the loss with a tuple
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
# Get keys that were added with the _prepare_for_class function
label_keys = prepared_for_class.keys() - inputs_dict.keys()
signature = inspect.signature(model.call).parameters
signature_names = list(signature.keys())
# Create a dictionary holding the location of the tensors in the tuple
tuple_index_mapping = {0: input_name}
for label_key in label_keys:
label_key_index = signature_names.index(label_key)
tuple_index_mapping[label_key_index] = label_key
sorted_tuple_index_mapping = sorted(tuple_index_mapping.items())
# Initialize a list with their default values, update the values and convert to a tuple
list_input = []
for name in signature_names:
if name != "kwargs":
list_input.append(signature[name].default)
for index, value in sorted_tuple_index_mapping:
list_input[index] = prepared_for_class[value]
tuple_input = tuple(list_input)
# Send to model
loss = model(tuple_input[:-1])[0]
self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])
@slow
def test_model_from_pretrained(self):
for model_name in TF_MOBILEVIT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = TFMobileViTModel.from_pretrained(model_name)
self.assertIsNotNone(model)
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_tf
class TFMobileViTModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_image_classification_head(self):
model = TFMobileViTForImageClassification.from_pretrained("apple/mobilevit-xx-small")
feature_extractor = MobileViTFeatureExtractor.from_pretrained("apple/mobilevit-xx-small")
image = prepare_img()
inputs = feature_extractor(images=image, return_tensors="tf")
# forward pass
outputs = model(**inputs, training=False)
# verify the logits
expected_shape = tf.TensorShape((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = tf.constant([-1.9364, -1.2327, -0.4653])
tf.debugging.assert_near(outputs.logits[0, :3], expected_slice, atol=1e-4, rtol=1e-4)
@slow
def test_inference_semantic_segmentation(self):
model = TFMobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-xx-small")
feature_extractor = MobileViTFeatureExtractor.from_pretrained("apple/deeplabv3-mobilevit-xx-small")
image = prepare_img()
inputs = feature_extractor(images=image, return_tensors="tf")
# forward pass
outputs = model(inputs.pixel_values, training=False)
logits = outputs.logits
# verify the logits
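# 21 output channels correspond to the PASCAL VOC classes (20 object categories plus background).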
expected_shape = tf.TensorShape((1, 21, 32, 32))
self.assertEqual(logits.shape, expected_shape)
expected_slice = tf.constant(
[
[[6.9713, 6.9786, 7.2422], [7.2893, 7.2825, 7.4446], [7.6580, 7.8797, 7.9420]],
[[-10.6869, -10.3250, -10.3471], [-10.4228, -9.9868, -9.7132], [-11.0405, -11.0221, -10.7318]],
[[-3.3089, -2.8539, -2.6740], [-3.2706, -2.5621, -2.5108], [-3.2534, -2.6615, -2.6651]],
]
)
tf.debugging.assert_near(logits[0, :3, :3, :3], expected_slice, rtol=1e-4, atol=1e-4)
......@@ -51,6 +51,7 @@ src/transformers/models/mbart/modeling_mbart.py
src/transformers/models/mobilebert/modeling_mobilebert.py
src/transformers/models/mobilebert/modeling_tf_mobilebert.py
src/transformers/models/mobilevit/modeling_mobilevit.py
src/transformers/models/mobilevit/modeling_tf_mobilevit.py
src/transformers/models/opt/modeling_opt.py
src/transformers/models/opt/modeling_tf_opt.py
src/transformers/models/owlvit/modeling_owlvit.py
......