"vscode:/vscode.git/clone" did not exist on "56c3f07a13be9e5bb1408479c0d9ffc75f03da26"
Unverified Commit f8afb2b2 authored by Andi Powers Holmes's avatar Andi Powers Holmes Committed by GitHub
Browse files

Add TensorFlow implementation of ConvNeXTv2 (#25558)

* Add type annotations to TFConvNextDropPath

* Use tf.debugging.assert_equal for TFConvNextEmbeddings shape check

* Add TensorFlow implementation of ConvNeXTV2

* check_docstrings: add TFConvNextV2Model to exclusions

TFConvNextV2Model and TFConvNextV2ForImageClassification have docstrings
which are equivalent to their PyTorch cousins, but a parsing issue prevents them
from passing the test.

Adding exclusions for these two classes as discussed in #25558.
parent 391d14e8
......@@ -97,7 +97,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Conditional DETR](model_doc/conditional_detr) | ✅ | ❌ | ❌ |
| [ConvBERT](model_doc/convbert) | ✅ | ✅ | ❌ |
| [ConvNeXT](model_doc/convnext) | ✅ | ✅ | ❌ |
| [ConvNeXTV2](model_doc/convnextv2) | ✅ | | ❌ |
| [ConvNeXTV2](model_doc/convnextv2) | ✅ | | ❌ |
| [CPM](model_doc/cpm) | ✅ | ✅ | ✅ |
| [CPM-Ant](model_doc/cpmant) | ✅ | ❌ | ❌ |
| [CTRL](model_doc/ctrl) | ✅ | ✅ | ❌ |
......
......@@ -59,3 +59,14 @@ If you're interested in submitting a resource to be included here, please feel f
[[autodoc]] ConvNextV2ForImageClassification
- forward
## TFConvNextV2Model
[[autodoc]] TFConvNextV2Model
- call
## TFConvNextV2ForImageClassification
[[autodoc]] TFConvNextV2ForImageClassification
- call
......@@ -3415,6 +3415,13 @@ else:
"TFConvNextPreTrainedModel",
]
)
_import_structure["models.convnextv2"].extend(
[
"TFConvNextV2ForImageClassification",
"TFConvNextV2Model",
"TFConvNextV2PreTrainedModel",
]
)
_import_structure["models.ctrl"].extend(
[
"TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST",
......@@ -7127,6 +7134,11 @@ if TYPE_CHECKING:
TFConvBertPreTrainedModel,
)
from .models.convnext import TFConvNextForImageClassification, TFConvNextModel, TFConvNextPreTrainedModel
from .models.convnextv2 import (
TFConvNextV2ForImageClassification,
TFConvNextV2Model,
TFConvNextV2PreTrainedModel,
)
from .models.ctrl import (
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST,
TFCTRLForSequenceClassification,
......
......@@ -39,6 +39,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
("clip", "TFCLIPModel"),
("convbert", "TFConvBertModel"),
("convnext", "TFConvNextModel"),
("convnextv2", "TFConvNextV2Model"),
("ctrl", "TFCTRLModel"),
("cvt", "TFCvtModel"),
("data2vec-vision", "TFData2VecVisionModel"),
......@@ -200,6 +201,7 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Image-classsification
("convnext", "TFConvNextForImageClassification"),
("convnextv2", "TFConvNextV2ForImageClassification"),
("cvt", "TFCvtForImageClassification"),
("data2vec-vision", "TFData2VecVisionForImageClassification"),
("deit", ("TFDeiTForImageClassification", "TFDeiTForImageClassificationWithTeacher")),
......
......@@ -17,7 +17,7 @@
from __future__ import annotations
from typing import Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
......@@ -50,11 +50,11 @@ class TFConvNextDropPath(tf.keras.layers.Layer):
(1) github.com:rwightman/pytorch-image-models
"""
def __init__(self, drop_path, **kwargs):
def __init__(self, drop_path: float, **kwargs):
super().__init__(**kwargs)
self.drop_path = drop_path
def call(self, x, training=None):
def call(self, x: tf.Tensor, training=None):
if training:
keep_prob = 1 - self.drop_path
shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
......@@ -69,7 +69,7 @@ class TFConvNextEmbeddings(tf.keras.layers.Layer):
found in src/transformers/models/swin/modeling_swin.py.
"""
def __init__(self, config, **kwargs):
def __init__(self, config: ConvNextConfig, **kwargs):
super().__init__(**kwargs)
self.patch_embeddings = tf.keras.layers.Conv2D(
filters=config.hidden_sizes[0],
......@@ -77,7 +77,7 @@ class TFConvNextEmbeddings(tf.keras.layers.Layer):
strides=config.patch_size,
name="patch_embeddings",
kernel_initializer=get_initializer(config.initializer_range),
bias_initializer="zeros",
bias_initializer=tf.keras.initializers.Zeros(),
)
self.layernorm = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="layernorm")
self.num_channels = config.num_channels
......@@ -86,15 +86,15 @@ class TFConvNextEmbeddings(tf.keras.layers.Layer):
if isinstance(pixel_values, dict):
pixel_values = pixel_values["pixel_values"]
num_channels = shape_list(pixel_values)[1]
if tf.executing_eagerly() and num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
tf.debugging.assert_equal(
shape_list(pixel_values)[1],
self.num_channels,
message="Make sure that the channel dimension of the pixel values match with the one set in the configuration.",
)
# When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format.
# So change the input format from `NCHW` to `NHWC`.
# shape = (batch_size, in_height, in_width, in_channels=num_channels)
# shape = (batch_size, in_height, in_width, in_channels)
pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
embeddings = self.patch_embeddings(pixel_values)
......@@ -188,15 +188,28 @@ class TFConvNextStage(tf.keras.layers.Layer):
"""ConvNext stage, consisting of an optional downsampling layer + multiple residual blocks.
Args:
config ([`ConvNextConfig`]): Model configuration class.
in_channels (`int`): Number of input channels.
out_channels (`int`): Number of output channels.
depth (`int`): Number of residual blocks.
drop_path_rates(`List[float]`): Stochastic depth rates for each layer.
config (`ConvNextV2Config`):
Model configuration class.
in_channels (`int`):
Number of input channels.
out_channels (`int`):
Number of output channels.
depth (`int`):
Number of residual blocks.
drop_path_rates(`List[float]`):
Stochastic depth rates for each layer.
"""
def __init__(
self, config, in_channels, out_channels, kernel_size=2, stride=2, depth=2, drop_path_rates=None, **kwargs
self,
config: ConvNextConfig,
in_channels: int,
out_channels: int,
kernel_size: int = 2,
stride: int = 2,
depth: int = 2,
drop_path_rates: Optional[List[float]] = None,
**kwargs,
):
super().__init__(**kwargs)
if in_channels != out_channels or stride > 1:
......@@ -215,7 +228,7 @@ class TFConvNextStage(tf.keras.layers.Layer):
kernel_size=kernel_size,
strides=stride,
kernel_initializer=get_initializer(config.initializer_range),
bias_initializer="zeros",
bias_initializer=tf.keras.initializers.Zeros(),
name="downsampling_layer.1",
),
]
......
......@@ -22,6 +22,7 @@ from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
is_tf_available,
)
......@@ -46,6 +47,17 @@ else:
"ConvNextV2Backbone",
]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_convnextv2"] = [
"TFConvNextV2ForImageClassification",
"TFConvNextV2Model",
"TFConvNextV2PreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_convnextv2 import (
......@@ -67,6 +79,18 @@ if TYPE_CHECKING:
ConvNextV2PreTrainedModel,
)
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_convnextv2 import (
TFConvNextV2ForImageClassification,
TFConvNextV2Model,
TFConvNextV2PreTrainedModel,
)
else:
import sys
......
This diff is collapsed.
......@@ -337,11 +337,11 @@ class TFEfficientFormerDropPath(tf.keras.layers.Layer):
(1) github.com:rwightman/pytorch-image-models
"""
def __init__(self, drop_path, **kwargs):
def __init__(self, drop_path: float, **kwargs):
super().__init__(**kwargs)
self.drop_path = drop_path
def call(self, x, training=None):
def call(self, x: tf.Tensor, training=None):
if training:
keep_prob = 1 - self.drop_path
shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
......
......@@ -62,11 +62,11 @@ class TFSegformerDropPath(tf.keras.layers.Layer):
(1) github.com:rwightman/pytorch-image-models
"""
def __init__(self, drop_path, **kwargs):
def __init__(self, drop_path: float, **kwargs):
super().__init__(**kwargs)
self.drop_path = drop_path
def call(self, x, training=None):
def call(self, x: tf.Tensor, training=None):
if training:
keep_prob = 1 - self.drop_path
shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
......
......@@ -836,6 +836,27 @@ class TFConvNextPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["tf"])
class TFConvNextV2ForImageClassification(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFConvNextV2Model(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFConvNextV2PreTrainedModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the TensorFlow ConvNext model. """
from __future__ import annotations
import inspect
import unittest
from typing import List, Tuple
import numpy as np
from transformers import ConvNextV2Config
from transformers.testing_utils import require_tf, require_vision, slow
from transformers.utils import cached_property, is_tf_available, is_vision_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_tf_available():
import tensorflow as tf
from transformers import TFConvNextV2ForImageClassification, TFConvNextV2Model
if is_vision_available():
from PIL import Image
from transformers import ConvNextImageProcessor
class TFConvNextV2ModelTester:
def __init__(
self,
parent,
batch_size=13,
image_size=32,
num_channels=3,
num_stages=4,
hidden_sizes=[10, 20, 30, 40],
depths=[2, 2, 3, 2],
is_training=True,
use_labels=True,
intermediate_size=37,
hidden_act="gelu",
type_sequence_label_size=10,
initializer_range=0.02,
num_labels=3,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.num_channels = num_channels
self.num_stages = num_stages
self.hidden_sizes = hidden_sizes
self.depths = depths
self.is_training = is_training
self.use_labels = use_labels
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.scope = scope
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
labels = None
if self.use_labels:
labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
config = self.get_config()
return config, pixel_values, labels
def get_config(self):
return ConvNextV2Config(
num_channels=self.num_channels,
hidden_sizes=self.hidden_sizes,
depths=self.depths,
num_stages=self.num_stages,
hidden_act=self.hidden_act,
is_decoder=False,
initializer_range=self.initializer_range,
)
def create_and_check_model(self, config, pixel_values, labels):
model = TFConvNextV2Model(config=config)
result = model(pixel_values, training=False)
# expected last hidden states: batch_size, channels, height // 32, width // 32
self.parent.assertEqual(
result.last_hidden_state.shape,
(self.batch_size, self.hidden_sizes[-1], self.image_size // 32, self.image_size // 32),
)
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
model = TFConvNextV2ForImageClassification(config)
result = model(pixel_values, labels=labels, training=False)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_tf
class TFConvNextV2ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_common.py, as ConvNext does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
all_model_classes = (TFConvNextV2Model, TFConvNextV2ForImageClassification) if is_tf_available() else ()
pipeline_model_mapping = (
{"feature-extraction": TFConvNextV2Model, "image-classification": TFConvNextV2ForImageClassification}
if is_tf_available()
else {}
)
test_pruning = False
test_onnx = False
test_resize_embeddings = False
test_head_masking = False
has_attentions = False
def setUp(self):
self.model_tester = TFConvNextV2ModelTester(self)
self.config_tester = ConfigTester(
self,
config_class=ConvNextV2Config,
has_text_modality=False,
hidden_size=37,
)
@unittest.skip(reason="ConvNext does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skipIf(
not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0,
reason="TF does not support backprop for grouped convolutions on CPU.",
)
@slow
def test_keras_fit(self):
super().test_keras_fit()
@unittest.skip(reason="ConvNext does not support input and output embeddings")
def test_model_common_attributes(self):
pass
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.call)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skipIf(
not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0,
reason="TF does not support backprop for grouped convolutions on CPU.",
)
def test_dataset_conversion(self):
super().test_dataset_conversion()
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
expected_num_stages = self.model_tester.num_stages
self.assertEqual(len(hidden_states), expected_num_stages + 1)
# ConvNext's feature maps are of shape (batch_size, num_channels, height, width)
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[self.model_tester.image_size // 4, self.model_tester.image_size // 4],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
# Since ConvNext does not have any attention we need to rewrite this test.
def test_model_outputs_equivalence(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
tuple_output = model(tuple_inputs, return_dict=False, **additional_kwargs)
dict_output = model(dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
self.assertTrue(
all(tf.equal(tuple_object, dict_object)),
msg=(
"Tuple and dict output are not equal. Difference:"
f" {tf.math.reduce_max(tf.abs(tuple_object - dict_object))}"
),
)
recursive_check(tuple_output, dict_output)
for model_class in self.all_model_classes:
model = model_class(config)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
model = TFConvNextV2Model.from_pretrained("facebook/convnextv2-tiny-1k-224")
self.assertIsNotNone(model)
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_tf
@require_vision
class TFConvNextV2ModelIntegrationTest(unittest.TestCase):
@cached_property
def default_image_processor(self):
return (
ConvNextImageProcessor.from_pretrained("facebook/convnextv2-tiny-1k-224")
if is_vision_available()
else None
)
@slow
def test_inference_image_classification_head(self):
model = TFConvNextV2ForImageClassification.from_pretrained("facebook/convnextv2-tiny-1k-224")
image_processor = self.default_image_processor
image = prepare_img()
inputs = image_processor(images=image, return_tensors="tf")
# forward pass
outputs = model(**inputs)
# verify the logits
expected_shape = tf.TensorShape((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = np.array([0.9996, 0.1966, -0.4386])
self.assertTrue(np.allclose(outputs.logits[0, :3].numpy(), expected_slice, atol=1e-4))
......@@ -533,6 +533,8 @@ OBJECTS_TO_IGNORE = [
"TFConvBertModel",
"TFConvNextForImageClassification",
"TFConvNextModel",
"TFConvNextV2Model", # Parsing issue. Equivalent to PT ConvNextV2Model, see PR #25558
"TFConvNextV2ForImageClassification",
"TFCvtForImageClassification",
"TFCvtModel",
"TFDPRReader",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment