Unverified Commit 88f50a1e authored by Denisa Roberts's avatar Denisa Roberts Committed by GitHub
Browse files

Add TensorFlow implementation of EfficientFormer (#22620)

* Add tf code for efficientformer

* Fix return dict bug - return last hidden state after last stage

* Fix corresponding return dict bug

* Override test tol

* Change default values of training to False

* Set training to default False X3

* Rm axis from ln

* Set init in dense projection

* Rm debug stuff

* Make style; all tests pass.

* Modify year to 2023

* Fix attention biases codes

* Update the shape list logic

* Add a batch norm eps config

* Remove extract comments in test files

* Add conditional attn and hidden states return for serving output

* Change channel dim checking logic

* Add exception for withteacher model in training mode

* Revert layer count for now

* Add layer count for conditional layer naming

* Transpose for conv happens only in main layer

* Make tests smaller

* Make style

* Update doc

* Rm from_pt

* Change to actual expect image class label

* Remove stray print in tests

* Update image processor test

* Remove the old serving output logic

* Make style

* Make style

* Complete test
parent 9fea71b4
......@@ -313,7 +313,7 @@ Flax), PyTorch, and/or TensorFlow.
| DonutSwin | ❌ | ❌ | ✅ | ❌ | ❌ |
| DPR | ✅ | ✅ | ✅ | ✅ | ❌ |
| DPT | ❌ | ❌ | ✅ | ❌ | ❌ |
| EfficientFormer | ❌ | ❌ | ✅ | | ❌ |
| EfficientFormer | ❌ | ❌ | ✅ | | ❌ |
| EfficientNet | ❌ | ❌ | ✅ | ❌ | ❌ |
| ELECTRA | ✅ | ✅ | ✅ | ✅ | ✅ |
| Encoder decoder | ❌ | ❌ | ✅ | ✅ | ✅ |
......
......@@ -37,7 +37,7 @@ EfficientFormer-L7, obtains 83.3% accuracy with only 7.0 ms latency. Our work pr
reach extremely low latency on mobile devices while maintaining high performance.*
This model was contributed by [novice03](https://huggingface.co/novice03) and [Bearnardd](https://huggingface.co/Bearnardd).
The original code can be found [here](https://github.com/snap-research/EfficientFormer).
The original code can be found [here](https://github.com/snap-research/EfficientFormer). The TensorFlow version of this model was added by [D-Roberts](https://huggingface.co/D-Roberts).
## Documentation resources
......@@ -66,3 +66,18 @@ The original code can be found [here](https://github.com/snap-research/Efficient
[[autodoc]] EfficientFormerForImageClassificationWithTeacher
- forward
## TFEfficientFormerModel
[[autodoc]] TFEfficientFormerModel
- call
## TFEfficientFormerForImageClassification
[[autodoc]] TFEfficientFormerForImageClassification
- call
## TFEfficientFormerForImageClassificationWithTeacher
[[autodoc]] TFEfficientFormerForImageClassificationWithTeacher
- call
......@@ -3142,6 +3142,15 @@ else:
"TFDPRReader",
]
)
_import_structure["models.efficientformer"].extend(
[
"TF_EFFICIENTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFEfficientFormerForImageClassification",
"TFEfficientFormerForImageClassificationWithTeacher",
"TFEfficientFormerModel",
"TFEfficientFormerPreTrainedModel",
]
)
_import_structure["models.electra"].extend(
[
"TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST",
......@@ -6471,6 +6480,13 @@ if TYPE_CHECKING:
TFDPRQuestionEncoder,
TFDPRReader,
)
from .models.efficientformer import (
TF_EFFICIENTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFEfficientFormerForImageClassification,
TFEfficientFormerForImageClassificationWithTeacher,
TFEfficientFormerModel,
TFEfficientFormerPreTrainedModel,
)
from .models.electra import (
TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFElectraForMaskedLM,
......
......@@ -47,6 +47,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
("deit", "TFDeiTModel"),
("distilbert", "TFDistilBertModel"),
("dpr", "TFDPRQuestionEncoder"),
("efficientformer", "TFEfficientFormerModel"),
("electra", "TFElectraModel"),
("esm", "TFEsmModel"),
("flaubert", "TFFlaubertModel"),
......@@ -202,6 +203,10 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("cvt", "TFCvtForImageClassification"),
("data2vec-vision", "TFData2VecVisionForImageClassification"),
("deit", ("TFDeiTForImageClassification", "TFDeiTForImageClassificationWithTeacher")),
(
"efficientformer",
("TFEfficientFormerForImageClassification", "TFEfficientFormerForImageClassificationWithTeacher"),
),
("mobilevit", "TFMobileViTForImageClassification"),
("regnet", "TFRegNetForImageClassification"),
("resnet", "TFResNetForImageClassification"),
......
......@@ -13,7 +13,13 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_tf_available,
is_torch_available,
is_vision_available,
)
_import_structure = {
......@@ -45,6 +51,20 @@ else:
"EfficientFormerPreTrainedModel",
]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_efficientformer"] = [
"TF_EFFICIENTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFEfficientFormerForImageClassification",
"TFEfficientFormerForImageClassificationWithTeacher",
"TFEfficientFormerModel",
"TFEfficientFormerPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_efficientformer import EFFICIENTFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, EfficientFormerConfig
......@@ -69,6 +89,19 @@ if TYPE_CHECKING:
EfficientFormerModel,
EfficientFormerPreTrainedModel,
)
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_efficientformer import (
TF_EFFICIENTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFEfficientFormerForImageClassification,
TFEfficientFormerForImageClassificationWithTeacher,
TFEfficientFormerModel,
TFEfficientFormerPreTrainedModel,
)
else:
import sys
......
......@@ -52,7 +52,7 @@ class EfficientFormerConfig(PretrainedConfig):
The size of the key in meta3D block.
attention_ratio (`int`, *optional*, defaults to 4):
Ratio of the dimension of the query and value to the dimension of the key in MSHA block
resolution (`int`, *optional*, defaults to 5)
resolution (`int`, *optional*, defaults to 7)
Size of each patch
num_hidden_layers (`int`, *optional*, defaults to 5):
Number of hidden layers in the Transformer encoder.
......@@ -91,6 +91,8 @@ class EfficientFormerConfig(PretrainedConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
The epsilon used by the layer normalization layers.
image_size (`int`, *optional*, defaults to `224`):
The size (resolution) of each image.
Example:
......@@ -136,6 +138,8 @@ class EfficientFormerConfig(PretrainedConfig):
hidden_act: str = "gelu",
initializer_range: float = 0.02,
layer_norm_eps: float = 1e-12,
image_size: int = 224,
batch_norm_eps: float = 1e-05,
**kwargs,
) -> None:
super().__init__(**kwargs)
......@@ -165,3 +169,5 @@ class EfficientFormerConfig(PretrainedConfig):
self.distillation = distillation
self.use_layer_scale = use_layer_scale
self.layer_scale_init_value = layer_scale_init_value
self.image_size = image_size
self.batch_norm_eps = batch_norm_eps
......@@ -43,7 +43,7 @@ _CONFIG_FOR_DOC = "EfficientFormerConfig"
# Base docstring
_CHECKPOINT_FOR_DOC = "snap-research/efficientformer-l1-300"
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
_EXPECTED_OUTPUT_SHAPE = [1, 49, 448]
# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "snap-research/efficientformer-l1-300"
......@@ -73,7 +73,7 @@ class EfficientFormerPatchEmbeddings(nn.Module):
stride=config.downsample_stride,
padding=config.downsample_pad,
)
self.norm = nn.BatchNorm2d(embed_dim) if apply_norm else nn.Identity()
self.norm = nn.BatchNorm2d(embed_dim, eps=config.batch_norm_eps) if apply_norm else nn.Identity()
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
......@@ -157,10 +157,10 @@ class EfficientFormerConvStem(nn.Module):
super().__init__()
self.convolution1 = nn.Conv2d(config.num_channels, out_channels // 2, kernel_size=3, stride=2, padding=1)
self.batchnorm_before = nn.BatchNorm2d(out_channels // 2)
self.batchnorm_before = nn.BatchNorm2d(out_channels // 2, eps=config.batch_norm_eps)
self.convolution2 = nn.Conv2d(out_channels // 2, out_channels, kernel_size=3, stride=2, padding=1)
self.batchnorm_after = nn.BatchNorm2d(out_channels)
self.batchnorm_after = nn.BatchNorm2d(out_channels, eps=config.batch_norm_eps)
self.activation = nn.ReLU()
......@@ -224,24 +224,24 @@ class EfficientFormerConvMlp(nn.Module):
hidden_features = hidden_features or in_features
self.convolution1 = nn.Conv2d(in_features, hidden_features, 1)
self.actvation = ACT2FN[config.hidden_act]
self.activation = ACT2FN[config.hidden_act]
self.convolution2 = nn.Conv2d(hidden_features, out_features, 1)
self.dropout = nn.Dropout(drop)
self.batchnorm_before = nn.BatchNorm2d(hidden_features)
self.batchnorm_after = nn.BatchNorm2d(out_features)
self.batchnorm_before = nn.BatchNorm2d(hidden_features, eps=config.batch_norm_eps)
self.batchnorm_after = nn.BatchNorm2d(out_features, eps=config.batch_norm_eps)
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
hidden_state = self.convolution1(hidden_state)
hidden_state = self.batchnorm_before(hidden_state)
hidden_state = self.actvation(hidden_state)
hidden_state = self.activation(hidden_state)
hidden_state = self.dropout(hidden_state)
hidden_state = self.convolution2(hidden_state)
hidden_state = self.batchnorm_after(hidden_state)
hidden_state = self.dropout(hidden_state)
return hidden_state
......@@ -266,7 +266,7 @@ def drop_path(input, drop_prob: float = 0.0, training: bool = False):
return output
# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Bit
# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->EfficientFormer
class EfficientFormerDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
......@@ -301,8 +301,10 @@ class EfficientFormerMeta3D(nn.Module):
attention_ratio=config.attention_ratio,
resolution=config.resolution,
)
self.layernorm1 = nn.LayerNorm(dim)
self.layernorm2 = nn.LayerNorm(dim)
self.layernorm1 = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.layernorm2 = nn.LayerNorm(dim, eps=config.layer_norm_eps)
mlp_hidden_dim = int(dim * config.mlp_expansion_ratio)
self.mlp = EfficientFormerDenseMlp(config, in_features=dim, hidden_features=mlp_hidden_dim)
......@@ -346,15 +348,20 @@ class EfficientFormerMeta3DLayers(nn.Module):
def forward(self, hidden_states: torch.Tensor, output_attentions: bool = False) -> Tuple[torch.Tensor]:
all_attention_outputs = () if output_attentions else None
for layer_module in self.blocks:
if isinstance(hidden_states, tuple):
hidden_states = hidden_states[0]
hidden_states = layer_module(hidden_states, output_attentions)
if output_attentions:
all_attention_outputs = all_attention_outputs + (hidden_states[1],)
if output_attentions:
outputs = (hidden_states[0],) + all_attention_outputs
return outputs
return hidden_states
......@@ -379,6 +386,7 @@ class EfficientFormerMeta4D(nn.Module):
if self.use_layer_scale:
layer_output = hidden_states + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * outputs)
layer_output = layer_output + self.drop_path(
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(layer_output)
)
......@@ -398,6 +406,7 @@ class EfficientFormerMeta4DLayers(nn.Module):
drop_paths = [
config.drop_path_rate * (block_idx + sum(config.depths[:stage_idx])) for block_idx in range(num_layers)
]
self.blocks = nn.ModuleList(
[
EfficientFormerMeta4D(config, config.hidden_sizes[stage_idx], drop_path=drop_path)
......@@ -446,6 +455,7 @@ class EfficientFormerEncoder(nn.Module):
for i in range(num_intermediate_stages)
]
intermediate_stages = []
for i in range(num_intermediate_stages):
intermediate_stages.append(EfficientFormerIntermediateStage(config, i))
if downsamples[i]:
......@@ -475,6 +485,7 @@ class EfficientFormerEncoder(nn.Module):
all_hidden_states = all_hidden_states + (hidden_states,)
layer_output = self.last_stage(hidden_states, output_attentions=output_attentions)
if output_attentions:
all_self_attentions = all_self_attentions + layer_output[1:]
......@@ -482,7 +493,7 @@ class EfficientFormerEncoder(nn.Module):
all_hidden_states = all_hidden_states + (layer_output[0],)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return tuple(v for v in [layer_output[0], all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=layer_output[0],
......
......@@ -1099,6 +1099,37 @@ class TFDPRReader(metaclass=DummyObject):
requires_backends(self, ["tf"])
TF_EFFICIENTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
class TFEfficientFormerForImageClassification(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFEfficientFormerForImageClassificationWithTeacher(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFEfficientFormerModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFEfficientFormerPreTrainedModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
......@@ -18,6 +18,7 @@
import inspect
import unittest
import warnings
from typing import List
from transformers import EfficientFormerConfig
from transformers.models.auto import get_values
......@@ -55,15 +56,16 @@ class EfficientFormerModelTester:
self,
parent,
batch_size: int = 13,
image_size: int = 224,
image_size: int = 64,
patch_size: int = 2,
embed_dim: int = 48, # last embed dim of stem
embed_dim: int = 3,
num_channels: int = 3,
is_training: bool = True,
use_labels: bool = True,
hidden_size: int = 448,
num_hidden_layers: int = 7, # For the l1
num_attention_heads: int = 8,
hidden_size: int = 128,
hidden_sizes=[16, 32, 64, 128],
num_hidden_layers: int = 7,
num_attention_heads: int = 4,
intermediate_size: int = 37,
hidden_act: str = "gelu",
hidden_dropout_prob: float = 0.1,
......@@ -71,7 +73,11 @@ class EfficientFormerModelTester:
type_sequence_label_size: int = 10,
initializer_range: float = 0.02,
encoder_stride: int = 2,
num_attention_outputs: int = 1, # For l1
num_attention_outputs: int = 1,
dim: int = 128,
depths: List[int] = [2, 2, 2, 2],
resolution: int = 2,
mlp_expansion_ratio: int = 2,
):
self.parent = parent
self.batch_size = batch_size
......@@ -93,6 +99,11 @@ class EfficientFormerModelTester:
self.num_attention_outputs = num_attention_outputs
self.embed_dim = embed_dim
self.seq_length = embed_dim + 1
self.resolution = resolution
self.depths = depths
self.hidden_sizes = hidden_sizes
self.dim = dim
self.mlp_expansion_ratio = mlp_expansion_ratio
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
......@@ -119,6 +130,11 @@ class EfficientFormerModelTester:
is_decoder=False,
initializer_range=self.initializer_range,
encoder_stride=self.encoder_stride,
resolution=self.resolution,
depths=self.depths,
hidden_sizes=self.hidden_sizes,
dim=self.dim,
mlp_expansion_ratio=self.mlp_expansion_ratio,
)
def create_and_check_model(self, config, pixel_values, labels):
......@@ -379,6 +395,7 @@ class EfficientFormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.T
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
chunk_length = getattr(self.model_tester, "chunk_length", None)
if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
......
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the TensorFlow EfficientFormer model. """
import inspect
import unittest
from typing import List
import numpy as np
from transformers import EfficientFormerConfig
from transformers.testing_utils import require_tf, require_vision, slow
from transformers.utils import cached_property, is_tf_available, is_vision_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available():
import tensorflow as tf
from transformers import (
TFEfficientFormerForImageClassification,
TFEfficientFormerForImageClassificationWithTeacher,
TFEfficientFormerModel,
)
from transformers.models.efficientformer.modeling_tf_efficientformer import (
TF_EFFICIENTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
)
if is_vision_available():
from PIL import Image
from transformers import EfficientFormerImageProcessor
class TFEfficientFormerModelTester:
def __init__(
self,
parent,
batch_size: int = 13,
image_size: int = 64,
patch_size: int = 2,
embed_dim: int = 3,
num_channels: int = 3,
is_training: bool = True,
use_labels: bool = True,
hidden_size: int = 128,
hidden_sizes=[16, 32, 64, 128],
num_hidden_layers: int = 7,
num_attention_heads: int = 4,
intermediate_size: int = 37,
hidden_act: str = "gelu",
hidden_dropout_prob: float = 0.1,
attention_probs_dropout_prob: float = 0.1,
type_sequence_label_size: int = 10,
initializer_range: float = 0.02,
encoder_stride: int = 2,
num_attention_outputs: int = 1,
dim: int = 128,
depths: List[int] = [2, 2, 2, 2],
resolution: int = 2,
mlp_expansion_ratio: int = 2,
):
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.is_training = is_training
self.use_labels = use_labels
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.encoder_stride = encoder_stride
self.num_attention_outputs = num_attention_outputs
self.embed_dim = embed_dim
self.seq_length = embed_dim + 1
self.resolution = resolution
self.depths = depths
self.hidden_sizes = hidden_sizes
self.dim = dim
self.mlp_expansion_ratio = mlp_expansion_ratio
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
labels = None
if self.use_labels:
labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
config = self.get_config()
return config, pixel_values, labels
def get_config(self):
return EfficientFormerConfig(
image_size=self.image_size,
patch_size=self.patch_size,
num_channels=self.num_channels,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
is_decoder=False,
initializer_range=self.initializer_range,
encoder_stride=self.encoder_stride,
resolution=self.resolution,
depths=self.depths,
hidden_sizes=self.hidden_sizes,
dim=self.dim,
mlp_expansion_ratio=self.mlp_expansion_ratio,
)
def create_and_check_model(self, config, pixel_values, labels):
model = TFEfficientFormerModel(config=config)
result = model(pixel_values, training=False)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
model = TFEfficientFormerForImageClassification(config)
result = model(pixel_values, labels=labels, training=False)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
# test greyscale images
config.num_channels = 1
model = TFEfficientFormerForImageClassification(config)
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values, labels=labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_tf
class TFEfficientFormerModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_tf_common.py, as EfficientFormer does not use input_ids,
inputs_embeds, attention_mask and seq_length.
"""
all_model_classes = (
(
TFEfficientFormerModel,
TFEfficientFormerForImageClassificationWithTeacher,
TFEfficientFormerForImageClassification,
)
if is_tf_available()
else ()
)
pipeline_model_mapping = (
{
"feature-extraction": TFEfficientFormerModel,
"image-classification": (
TFEfficientFormerForImageClassification,
TFEfficientFormerForImageClassificationWithTeacher,
),
}
if is_tf_available()
else {}
)
fx_compatible = False
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = TFEfficientFormerModelTester(self)
self.config_tester = ConfigTester(
self, config_class=EfficientFormerConfig, has_text_modality=False, hidden_size=37
)
def test_config(self):
self.config_tester.run_common_tests()
@unittest.skip(reason="EfficientFormer does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="EfficientFormer does not support input and output embeddings")
def test_model_common_attributes(self):
pass
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.call)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
)
self.assertEqual(len(hidden_states), expected_num_layers)
if hasattr(self.model_tester, "encoder_seq_length"):
seq_length = self.model_tester.encoder_seq_length
if hasattr(self.model_tester, "chunk_length") and self.model_tester.chunk_length > 1:
seq_length = seq_length * self.model_tester.chunk_length
else:
seq_length = self.model_tester.seq_length
self.assertListEqual(
list(hidden_states[-1].shape[-2:]),
[seq_length, self.model_tester.hidden_size],
)
if config.is_encoder_decoder:
hidden_states = outputs.decoder_hidden_states
self.asseretIsInstance(hidden_states, (list, tuple))
self.assertEqual(len(hidden_states), expected_num_layers)
seq_len = getattr(self.model_tester, "seq_length", None)
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
self.assertListEqual(
list(hidden_states[-1].shape[-2:]),
[decoder_seq_length, self.model_tester.hidden_size],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels:
if model_class.__name__ == "TFEfficientFormerForImageClassificationWithTeacher":
del inputs_dict["labels"]
return inputs_dict
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(reason="EfficientFormer does not implement masked image modeling yet")
def test_for_masked_image_modeling(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs)
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
for model_name in TF_EFFICIENTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = TFEfficientFormerModel.from_pretrained(model_name)
self.assertIsNotNone(model)
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
seq_len = getattr(self.model_tester, "seq_length", None)
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
chunk_length = getattr(self.model_tester, "chunk_length", None)
if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_attention_outputs)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_attention_outputs)
if chunk_length is not None:
self.assertListEqual(
list(attentions[0].shape[-4:]),
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
)
else:
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_tf
@require_vision
class EfficientFormerModelIntegrationTest(unittest.TestCase):
@cached_property
def default_image_processor(self):
return (
EfficientFormerImageProcessor.from_pretrained("snap-research/efficientformer-l1-300")
if is_vision_available()
else None
)
@slow
def test_inference_image_classification_head(self):
model = TFEfficientFormerForImageClassification.from_pretrained("snap-research/efficientformer-l1-300")
image_processor = self.default_image_processor
image = prepare_img()
inputs = image_processor(images=image, return_tensors="tf")
# forward pass
outputs = model(**inputs, training=False)
# verify the logits
expected_shape = tf.TensorShape((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = tf.constant([-0.0555, 0.4825, -0.0852])
self.assertTrue(np.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
@slow
def test_inference_image_classification_head_with_teacher(self):
model = TFEfficientFormerForImageClassificationWithTeacher.from_pretrained(
"snap-research/efficientformer-l1-300"
)
image_processor = self.default_image_processor
image = prepare_img()
inputs = image_processor(images=image, return_tensors="tf")
# forward pass
outputs = model(**inputs, training=False)
# verify the logits
expected_shape = tf.TensorShape((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = tf.constant([-0.1312, 0.4353, -1.0499])
self.assertTrue(np.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
......@@ -82,6 +82,7 @@ src/transformers/models/dpt/modeling_dpt.py
src/transformers/models/electra/configuration_electra.py
src/transformers/models/electra/modeling_electra.py
src/transformers/models/electra/modeling_tf_electra.py
src/transformers/models/efficientformer/modeling_tf_efficientformer.py
src/transformers/models/ernie/configuration_ernie.py
src/transformers/models/ernie_m/configuration_ernie_m.py
src/transformers/models/ernie_m/modeling_ernie_m.py
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment