Unverified Commit 84eaa6ac authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

Add TFConvNextModel (#15750)



* feat: initial implementation of convnext in tensorflow.

* fix: sample code for the classification model.

* chore: added checked for  from the classification model.

* chore: set bias initializer in the classification head.

* chore: updated license terms.

* chore: removed ununsed imports

* feat: enabled  argument during using drop_path.

* chore: replaced tf.identity with layers.Activation(linear).

* chore: edited default checkpoint.

* fix: minor bugs in the initializations.

* partial-fix: tf model errors for loading pretrained pt weights.

* partial-fix: call method updated

* partial-fix: cross loading of weights (4x3 variables to be matched)

* chore: removed unneeded comment.

* removed playground.py

* rebasing

* rebasing and removing playground.py.

* fix: renaming TFConvNextStage conv and layer norm layers

* chore: added initializers and other minor additions.

* chore: added initializers and other minor additions.

* add: tests for convnext.

* fix: integration tester class.

* fix: issues mentioned in pr feedback (round 1).

* fix: how output_hidden_states arg is propoagated inside the network.

* feat: handling of  arg for pure cnn models.

* chore: added a note on equal contribution in model docs.

* rebasing

* rebasing and removing playground.py.

* feat: encapsulation for the convnext trunk.

* Fix variable naming; Test-related corrections; Run make fixup

* chore: added Joao as a contributor to convnext.

* rebasing

* rebasing and removing playground.py.

* rebasing

* rebasing and removing playground.py.

* chore: corrected copyright year and added comment on NHWC.

* chore: fixed the black version and ran formatting.

* chore: ran make style.

* chore: removed from_pt argument from test, ran make style.

* rebasing

* rebasing and removing playground.py.

* rebasing

* rebasing and removing playground.py.

* fix: tests in the convnext subclass, ran make style.

* rebasing

* rebasing and removing playground.py.

* rebasing

* rebasing and removing playground.py.

* chore: moved convnext test to the correct location

* fix: locations for the test file of convnext.

* fix: convnext tests.

* chore: applied  sgugger's suggestion for dealing w/ output_attentions.

* chore: added comments.

* chore: applied updated quality enviornment style.

* chore: applied formatting with quality enviornment.

* chore: revert to the previous tests/test_modeling_common.py.

* chore: revert to the original test_modeling_common.py

* chore: revert to previous states for test_modeling_tf_common.py and modeling_tf_utils.py

* fix: tests for convnext.

* chore: removed output_attentions argument from convnext config.

* chore: revert to the earlier tf utils.

* fix: output shapes of the hidden states

* chore: removed unnecessary comment

* chore: reverting to the right test_modeling_tf_common.py.

* Styling nits
Co-authored-by: default avatarariG23498 <aritra.born2fly@gmail.com>
Co-authored-by: default avatarJoao Gante <joao@huggingface.co>
Co-authored-by: default avatarSylvain Gugger <Sylvain.gugger@gmail.com>
parent 0b5bf6ab
...@@ -179,7 +179,7 @@ Flax), PyTorch, and/or TensorFlow. ...@@ -179,7 +179,7 @@ Flax), PyTorch, and/or TensorFlow.
| Canine | ✅ | ❌ | ✅ | ❌ | ❌ | | Canine | ✅ | ❌ | ✅ | ❌ | ❌ |
| CLIP | ✅ | ✅ | ✅ | ✅ | ✅ | | CLIP | ✅ | ✅ | ✅ | ✅ | ✅ |
| ConvBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | ConvBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ConvNext | ❌ | ❌ | ✅ | | ❌ | | ConvNext | ❌ | ❌ | ✅ | | ❌ |
| CTRL | ✅ | ❌ | ✅ | ✅ | ❌ | | CTRL | ✅ | ❌ | ✅ | ✅ | ❌ |
| DeBERTa | ✅ | ✅ | ✅ | ✅ | ❌ | | DeBERTa | ✅ | ✅ | ✅ | ✅ | ❌ |
| DeBERTa-v2 | ✅ | ❌ | ✅ | ✅ | ❌ | | DeBERTa-v2 | ✅ | ❌ | ✅ | ✅ | ❌ |
......
...@@ -37,7 +37,8 @@ alt="drawing" width="600"/> ...@@ -37,7 +37,8 @@ alt="drawing" width="600"/>
<small> ConvNeXT architecture. Taken from the <a href="https://arxiv.org/abs/2201.03545">original paper</a>.</small> <small> ConvNeXT architecture. Taken from the <a href="https://arxiv.org/abs/2201.03545">original paper</a>.</small>
This model was contributed by [nielsr](https://huggingface.co/nielsr). The original code can be found [here](https://github.com/facebookresearch/ConvNeXt). This model was contributed by [nielsr](https://huggingface.co/nielsr). TensorFlow version of the model was contributed by [ariG23498](https://github.com/ariG23498),
[gante](https://github.com/gante), and [sayakpaul](https://github.com/sayakpaul) (equal contribution). The original code can be found [here](https://github.com/facebookresearch/ConvNeXt).
## ConvNeXT specific outputs ## ConvNeXT specific outputs
...@@ -64,3 +65,15 @@ This model was contributed by [nielsr](https://huggingface.co/nielsr). The origi ...@@ -64,3 +65,15 @@ This model was contributed by [nielsr](https://huggingface.co/nielsr). The origi
[[autodoc]] ConvNextForImageClassification [[autodoc]] ConvNextForImageClassification
- forward - forward
## TFConvNextModel
[[autodoc]] TFConvNextModel
- call
## TFConvNextForImageClassification
[[autodoc]] TFConvNextForImageClassification
- call
\ No newline at end of file
...@@ -1743,6 +1743,13 @@ if is_tf_available(): ...@@ -1743,6 +1743,13 @@ if is_tf_available():
"TFConvBertPreTrainedModel", "TFConvBertPreTrainedModel",
] ]
) )
_import_structure["models.convnext"].extend(
[
"TFConvNextForImageClassification",
"TFConvNextModel",
"TFConvNextPreTrainedModel",
]
)
_import_structure["models.ctrl"].extend( _import_structure["models.ctrl"].extend(
[ [
"TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST", "TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST",
...@@ -3751,6 +3758,7 @@ if TYPE_CHECKING: ...@@ -3751,6 +3758,7 @@ if TYPE_CHECKING:
TFConvBertModel, TFConvBertModel,
TFConvBertPreTrainedModel, TFConvBertPreTrainedModel,
) )
from .models.convnext import TFConvNextForImageClassification, TFConvNextModel, TFConvNextPreTrainedModel
from .models.ctrl import ( from .models.ctrl import (
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST,
TFCTRLForSequenceClassification, TFCTRLForSequenceClassification,
......
...@@ -311,9 +311,10 @@ def booleans_processing(config, **kwargs): ...@@ -311,9 +311,10 @@ def booleans_processing(config, **kwargs):
final_booleans = {} final_booleans = {}
if tf.executing_eagerly(): if tf.executing_eagerly():
final_booleans["output_attentions"] = ( # Pure conv models (such as ConvNext) do not have `output_attentions`
kwargs["output_attentions"] if kwargs["output_attentions"] is not None else config.output_attentions final_booleans["output_attentions"] = kwargs.get("output_attentions", None)
) if final_booleans["output_attentions"] is None:
final_booleans["output_attentions"] = config.output_attentions
final_booleans["output_hidden_states"] = ( final_booleans["output_hidden_states"] = (
kwargs["output_hidden_states"] kwargs["output_hidden_states"]
if kwargs["output_hidden_states"] is not None if kwargs["output_hidden_states"] is not None
......
...@@ -36,6 +36,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict( ...@@ -36,6 +36,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
("rembert", "TFRemBertModel"), ("rembert", "TFRemBertModel"),
("roformer", "TFRoFormerModel"), ("roformer", "TFRoFormerModel"),
("convbert", "TFConvBertModel"), ("convbert", "TFConvBertModel"),
("convnext", "TFConvNextModel"),
("led", "TFLEDModel"), ("led", "TFLEDModel"),
("lxmert", "TFLxmertModel"), ("lxmert", "TFLxmertModel"),
("mt5", "TFMT5Model"), ("mt5", "TFMT5Model"),
...@@ -155,6 +156,7 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ...@@ -155,6 +156,7 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[ [
# Model for Image-classsification # Model for Image-classsification
("vit", "TFViTForImageClassification"), ("vit", "TFViTForImageClassification"),
("convnext", "TFConvNextForImageClassification"),
] ]
) )
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
# rely on isort to merge the imports # rely on isort to merge the imports
from ...file_utils import _LazyModule, is_torch_available, is_vision_available from ...file_utils import _LazyModule, is_tf_available, is_torch_available, is_vision_available
_import_structure = { _import_structure = {
...@@ -36,6 +36,12 @@ if is_torch_available(): ...@@ -36,6 +36,12 @@ if is_torch_available():
"ConvNextPreTrainedModel", "ConvNextPreTrainedModel",
] ]
if is_tf_available():
_import_structure["modeling_tf_convnext"] = [
"TFConvNextForImageClassification",
"TFConvNextModel",
"TFConvNextPreTrainedModel",
]
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_convnext import CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvNextConfig from .configuration_convnext import CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvNextConfig
...@@ -51,6 +57,9 @@ if TYPE_CHECKING: ...@@ -51,6 +57,9 @@ if TYPE_CHECKING:
ConvNextPreTrainedModel, ConvNextPreTrainedModel,
) )
if is_tf_available():
from .modeling_convnext import TFConvNextForImageClassification, TFConvNextModel, TFConvNextPreTrainedModel
else: else:
import sys import sys
......
...@@ -85,6 +85,7 @@ class ConvNextConfig(PretrainedConfig): ...@@ -85,6 +85,7 @@ class ConvNextConfig(PretrainedConfig):
is_encoder_decoder=False, is_encoder_decoder=False,
layer_scale_init_value=1e-6, layer_scale_init_value=1e-6,
drop_path_rate=0.0, drop_path_rate=0.0,
image_size=224,
**kwargs **kwargs
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
...@@ -99,3 +100,4 @@ class ConvNextConfig(PretrainedConfig): ...@@ -99,3 +100,4 @@ class ConvNextConfig(PretrainedConfig):
self.layer_norm_eps = layer_norm_eps self.layer_norm_eps = layer_norm_eps
self.layer_scale_init_value = layer_scale_init_value self.layer_scale_init_value = layer_scale_init_value
self.drop_path_rate = drop_path_rate self.drop_path_rate = drop_path_rate
self.image_size = image_size
This diff is collapsed.
...@@ -641,6 +641,27 @@ class TFConvBertPreTrainedModel(metaclass=DummyObject): ...@@ -641,6 +641,27 @@ class TFConvBertPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["tf"]) requires_backends(self, ["tf"])
class TFConvNextForImageClassification(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFConvNextModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFConvNextPreTrainedModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = None TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the TensorFlow ConvNext model. """
import inspect
import unittest
from typing import List, Tuple
from transformers import ConvNextConfig
from transformers.file_utils import cached_property, is_tf_available, is_vision_available
from transformers.testing_utils import require_tf, require_vision, slow
from ..test_configuration_common import ConfigTester
from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
if is_tf_available():
import tensorflow as tf
from transformers import TFConvNextForImageClassification, TFConvNextModel
if is_vision_available():
from PIL import Image
from transformers import ConvNextFeatureExtractor
class TFConvNextModelTester:
def __init__(
self,
parent,
batch_size=13,
image_size=32,
num_channels=3,
num_stages=4,
hidden_sizes=[10, 20, 30, 40],
depths=[2, 2, 3, 2],
is_training=True,
use_labels=True,
intermediate_size=37,
hidden_act="gelu",
type_sequence_label_size=10,
initializer_range=0.02,
num_labels=3,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.num_channels = num_channels
self.num_stages = num_stages
self.hidden_sizes = hidden_sizes
self.depths = depths
self.is_training = is_training
self.use_labels = use_labels
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.scope = scope
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
labels = None
if self.use_labels:
labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
config = self.get_config()
return config, pixel_values, labels
def get_config(self):
return ConvNextConfig(
num_channels=self.num_channels,
hidden_sizes=self.hidden_sizes,
depths=self.depths,
num_stages=self.num_stages,
hidden_act=self.hidden_act,
is_decoder=False,
initializer_range=self.initializer_range,
)
def create_and_check_model(self, config, pixel_values, labels):
model = TFConvNextModel(config=config)
result = model(pixel_values, training=False)
# expected last hidden states: B, C, H // 32, W // 32
self.parent.assertEqual(
result.last_hidden_state.shape,
(self.batch_size, self.hidden_sizes[-1], self.image_size // 32, self.image_size // 32),
)
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
model = TFConvNextForImageClassification(config)
result = model(pixel_values, labels=labels, training=False)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_tf
class TFConvNextModelTest(TFModelTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_common.py, as ConvNext does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
all_model_classes = (TFConvNextModel, TFConvNextForImageClassification) if is_tf_available() else ()
test_pruning = False
test_onnx = False
test_resize_embeddings = False
test_head_masking = False
def setUp(self):
self.model_tester = TFConvNextModelTester(self)
self.config_tester = ConfigTester(
self,
config_class=ConvNextConfig,
has_text_modality=False,
hidden_size=37,
)
@unittest.skip(reason="ConvNext does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip(reason="ConvNext does not support input and output embeddings")
def test_model_common_attributes(self):
pass
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.call)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(reason="Model doesn't have attention layers")
def test_attention_outputs(self):
pass
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
expected_num_stages = self.model_tester.num_stages
self.assertEqual(len(hidden_states), expected_num_stages + 1)
# ConvNext's feature maps are of shape (batch_size, num_channels, height, width)
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[self.model_tester.image_size // 4, self.model_tester.image_size // 4],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
# Since ConvNext does not have any attention we need to rewrite this test.
def test_model_outputs_equivalence(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
tuple_output = model(tuple_inputs, return_dict=False, **additional_kwargs)
dict_output = model(dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
self.assertTrue(
all(tf.equal(tuple_object, dict_object)),
msg=f"Tuple and dict output are not equal. Difference: {tf.math.reduce_max(tf.abs(tuple_object - dict_object))}",
)
recursive_check(tuple_output, dict_output)
for model_class in self.all_model_classes:
model = model_class(config)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
model = TFConvNextModel.from_pretrained("facebook/convnext-tiny-224")
self.assertIsNotNone(model)
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_tf
@require_vision
class TFConvNextModelIntegrationTest(unittest.TestCase):
@cached_property
def default_feature_extractor(self):
return (
ConvNextFeatureExtractor.from_pretrained("facebook/convnext-tiny-224") if is_vision_available() else None
)
@slow
def test_inference_image_classification_head(self):
model = TFConvNextForImageClassification.from_pretrained("facebook/convnext-tiny-224")
feature_extractor = self.default_feature_extractor
image = prepare_img()
inputs = feature_extractor(images=image, return_tensors="tf")
# forward pass
outputs = model(**inputs)
# verify the logits
expected_shape = tf.TensorShape((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = tf.constant([-0.0260, -0.4739, 0.1911])
tf.debugging.assert_near(outputs.logits[0, :3], expected_slice, atol=1e-4)
...@@ -474,8 +474,8 @@ class TFModelTesterMixin: ...@@ -474,8 +474,8 @@ class TFModelTesterMixin:
), ),
"input_ids": tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32"), "input_ids": tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32"),
} }
# TODO: A better way to handle vision models # `pixel_values` implies that the input is an image
elif model_class.__name__ in ["TFViTModel", "TFViTForImageClassification", "TFCLIPVisionModel"]: elif model_class.main_input_name == "pixel_values":
inputs = tf.keras.Input( inputs = tf.keras.Input(
batch_shape=( batch_shape=(
3, 3,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment