Unverified Commit a7eba831 authored by Aritra Roy Gosthipaty's avatar Aritra Roy Gosthipaty Committed by GitHub
Browse files

TF implementation of RegNets (#17554)



* chore: initial commit

Copied the torch implementation of regnets and porting the code to tf step by step. Also introduced an output layer which was needed for regnets.

* chore: porting the rest of the modules to tensorflow

did not change the documentation yet, yet to try the playground on the model

* Fix initilizations (#1)

* fix: code structure in few cases.

* fix: code structure to align tf models.

* fix: layer naming, bn layer still remains.

* chore: change default epsilon and momentum in bn.

* chore: styling nits.

* fix: cross-loading bn params.

* fix: regnet tf model, integration passing.

* add: tests for TF regnet.

* fix: code quality related issues.

* chore: added rest of the files.

* minor additions..

* fix: repo consistency.

* fix: regnet tf tests.

* chore: reorganize dummy_tf_objects for regnet.

* chore: remove checkpoint var.

* chore: remov unnecessary files.

* chore: run make style.

* Update docs/source/en/model_doc/regnet.mdx
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* chore: PR feedback I.

* fix: pt test. thanks to @ydshieh.

* New adaptive pooler (#3)

* feat: new adaptive pooler

Co-authored-by: @Rocketknight1

* chore: remove image_size argument.
Co-authored-by: default avatarmatt <rocketknight1@gmail.com>
Co-authored-by: default avatarmatt <rocketknight1@gmail.com>

* Empty-Commit

* chore: remove image_size comment.

* chore: remove playground_tf.py

* chore: minor changes related to spacing.

* chore: make style.

* Update src/transformers/models/regnet/modeling_tf_regnet.py
Co-authored-by: default avataramyeroberts <aeroberts4444@gmail.com>

* Update src/transformers/models/regnet/modeling_tf_regnet.py
Co-authored-by: default avataramyeroberts <aeroberts4444@gmail.com>

* chore: refactored __init__.

* chore: copied from -> taken from./g

* adaptive pool -> global avg pool, channel check.

* chore: move channel check to stem.

* pr comments - minor refactor and add regnets to doc tests.

* Update src/transformers/models/regnet/modeling_tf_regnet.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* minor fix in the xlayer.

* Empty-Commit

* chore: removed from_pt=True.
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarmatt <rocketknight1@gmail.com>
Co-authored-by: default avataramyeroberts <aeroberts4444@gmail.com>
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>
parent e6d27ca5
...@@ -267,7 +267,7 @@ Flax), PyTorch, and/or TensorFlow. ...@@ -267,7 +267,7 @@ Flax), PyTorch, and/or TensorFlow.
| RAG | ✅ | ❌ | ✅ | ✅ | ❌ | | RAG | ✅ | ❌ | ✅ | ✅ | ❌ |
| REALM | ✅ | ✅ | ✅ | ❌ | ❌ | | REALM | ✅ | ✅ | ✅ | ❌ | ❌ |
| Reformer | ✅ | ✅ | ✅ | ❌ | ❌ | | Reformer | ✅ | ✅ | ✅ | ❌ | ❌ |
| RegNet | ❌ | ❌ | ✅ | | ❌ | | RegNet | ❌ | ❌ | ✅ | | ❌ |
| RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | RemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| ResNet | ❌ | ❌ | ✅ | ❌ | ❌ | | ResNet | ❌ | ❌ | ✅ | ❌ | ❌ |
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ | | RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
......
...@@ -27,7 +27,8 @@ Tips: ...@@ -27,7 +27,8 @@ Tips:
- One can use [`AutoFeatureExtractor`] to prepare images for the model. - One can use [`AutoFeatureExtractor`] to prepare images for the model.
- The huge 10B model from [Self-supervised Pretraining of Visual Features in the Wild](https://arxiv.org/abs/2103.01988), trained on one billion Instagram images, is available on the [hub](https://huggingface.co/facebook/regnet-y-10b-seer) - The huge 10B model from [Self-supervised Pretraining of Visual Features in the Wild](https://arxiv.org/abs/2103.01988), trained on one billion Instagram images, is available on the [hub](https://huggingface.co/facebook/regnet-y-10b-seer)
This model was contributed by [Francesco](https://huggingface.co/Francesco). This model was contributed by [Francesco](https://huggingface.co/Francesco). The TensorFlow version of the model
was contributed by [sayakpaul](https://huggingface.com/sayakpaul) and [ariG23498](https://huggingface.com/ariG23498).
The original code can be found [here](https://github.com/facebookresearch/pycls). The original code can be found [here](https://github.com/facebookresearch/pycls).
...@@ -46,3 +47,14 @@ The original code can be found [here](https://github.com/facebookresearch/pycls) ...@@ -46,3 +47,14 @@ The original code can be found [here](https://github.com/facebookresearch/pycls)
[[autodoc]] RegNetForImageClassification [[autodoc]] RegNetForImageClassification
- forward - forward
## TFRegNetModel
[[autodoc]] TFRegNetModel
- call
## TFRegNetForImageClassification
[[autodoc]] TFRegNetForImageClassification
- call
\ No newline at end of file
...@@ -2334,6 +2334,14 @@ else: ...@@ -2334,6 +2334,14 @@ else:
"TFRagTokenForGeneration", "TFRagTokenForGeneration",
] ]
) )
_import_structure["models.regnet"].extend(
[
"TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFRegNetForImageClassification",
"TFRegNetModel",
"TFRegNetPreTrainedModel",
]
)
_import_structure["models.rembert"].extend( _import_structure["models.rembert"].extend(
[ [
"TF_REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST", "TF_REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
...@@ -4649,6 +4657,12 @@ if TYPE_CHECKING: ...@@ -4649,6 +4657,12 @@ if TYPE_CHECKING:
from .models.opt import TFOPTForCausalLM, TFOPTModel, TFOPTPreTrainedModel from .models.opt import TFOPTForCausalLM, TFOPTModel, TFOPTPreTrainedModel
from .models.pegasus import TFPegasusForConditionalGeneration, TFPegasusModel, TFPegasusPreTrainedModel from .models.pegasus import TFPegasusForConditionalGeneration, TFPegasusModel, TFPegasusPreTrainedModel
from .models.rag import TFRagModel, TFRagPreTrainedModel, TFRagSequenceForGeneration, TFRagTokenForGeneration from .models.rag import TFRagModel, TFRagPreTrainedModel, TFRagSequenceForGeneration, TFRagTokenForGeneration
from .models.regnet import (
TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRegNetForImageClassification,
TFRegNetModel,
TFRegNetPreTrainedModel,
)
from .models.rembert import ( from .models.rembert import (
TF_REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST, TF_REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRemBertForCausalLM, TFRemBertForCausalLM,
......
...@@ -46,6 +46,25 @@ class TFBaseModelOutput(ModelOutput): ...@@ -46,6 +46,25 @@ class TFBaseModelOutput(ModelOutput):
attentions: Optional[Tuple[tf.Tensor]] = None attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFBaseModelOutputWithNoAttention(ModelOutput):
"""
Base class for model's outputs, with potential hidden states.
Args:
last_hidden_state (`tf.Tensor` shape `(batch_size, num_channels, height, width)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, num_channels, height, width)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
"""
last_hidden_state: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
@dataclass @dataclass
class TFBaseModelOutputWithPooling(ModelOutput): class TFBaseModelOutputWithPooling(ModelOutput):
""" """
...@@ -80,6 +99,28 @@ class TFBaseModelOutputWithPooling(ModelOutput): ...@@ -80,6 +99,28 @@ class TFBaseModelOutputWithPooling(ModelOutput):
attentions: Optional[Tuple[tf.Tensor]] = None attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFBaseModelOutputWithPoolingAndNoAttention(ModelOutput):
"""
Base class for model's outputs that also contains a pooling of the last hidden states.
Args:
last_hidden_state (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`):
Last layer hidden-state after a pooling operation on the spatial dimensions.
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
the output of each layer) of shape `(batch_size, num_channels, height, width)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
"""
last_hidden_state: tf.Tensor = None
pooler_output: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
@dataclass @dataclass
class TFBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput): class TFBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
""" """
...@@ -825,3 +866,24 @@ class TFSequenceClassifierOutputWithPast(ModelOutput): ...@@ -825,3 +866,24 @@ class TFSequenceClassifierOutputWithPast(ModelOutput):
past_key_values: Optional[List[tf.Tensor]] = None past_key_values: Optional[List[tf.Tensor]] = None
hidden_states: Optional[Tuple[tf.Tensor]] = None hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFImageClassifierOutputWithNoAttention(ModelOutput):
"""
Base class for outputs of image classification models.
Args:
loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also called
feature maps) of the model at the output of each stage.
"""
loss: Optional[tf.Tensor] = None
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
...@@ -62,6 +62,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict( ...@@ -62,6 +62,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
("openai-gpt", "TFOpenAIGPTModel"), ("openai-gpt", "TFOpenAIGPTModel"),
("opt", "TFOPTModel"), ("opt", "TFOPTModel"),
("pegasus", "TFPegasusModel"), ("pegasus", "TFPegasusModel"),
("regnet", "TFRegNetModel"),
("rembert", "TFRemBertModel"), ("rembert", "TFRemBertModel"),
("roberta", "TFRobertaModel"), ("roberta", "TFRobertaModel"),
("roformer", "TFRoFormerModel"), ("roformer", "TFRoFormerModel"),
...@@ -173,6 +174,7 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict( ...@@ -173,6 +174,7 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
# Model for Image-classsification # Model for Image-classsification
("convnext", "TFConvNextForImageClassification"), ("convnext", "TFConvNextForImageClassification"),
("data2vec-vision", "TFData2VecVisionForImageClassification"), ("data2vec-vision", "TFData2VecVisionForImageClassification"),
("regnet", "TFRegNetForImageClassification"),
("swin", "TFSwinForImageClassification"), ("swin", "TFSwinForImageClassification"),
("vit", "TFViTForImageClassification"), ("vit", "TFViTForImageClassification"),
] ]
......
...@@ -18,8 +18,7 @@ ...@@ -18,8 +18,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
# rely on isort to merge the imports # rely on isort to merge the imports
from ...file_utils import _LazyModule, is_torch_available from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
from ...utils import OptionalDependencyNotAvailable
_import_structure = {"configuration_regnet": ["REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "RegNetConfig"]} _import_structure = {"configuration_regnet": ["REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "RegNetConfig"]}
...@@ -37,6 +36,19 @@ else: ...@@ -37,6 +36,19 @@ else:
"RegNetPreTrainedModel", "RegNetPreTrainedModel",
] ]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_regnet"] = [
"TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFRegNetForImageClassification",
"TFRegNetModel",
"TFRegNetPreTrainedModel",
]
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_regnet import REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP, RegNetConfig from .configuration_regnet import REGNET_PRETRAINED_CONFIG_ARCHIVE_MAP, RegNetConfig
...@@ -54,6 +66,19 @@ if TYPE_CHECKING: ...@@ -54,6 +66,19 @@ if TYPE_CHECKING:
RegNetPreTrainedModel, RegNetPreTrainedModel,
) )
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_regnet import (
TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRegNetForImageClassification,
TFRegNetModel,
TFRegNetPreTrainedModel,
)
else: else:
import sys import sys
......
This diff is collapsed.
...@@ -1696,6 +1696,30 @@ class TFRagTokenForGeneration(metaclass=DummyObject): ...@@ -1696,6 +1696,30 @@ class TFRagTokenForGeneration(metaclass=DummyObject):
requires_backends(self, ["tf"]) requires_backends(self, ["tf"])
TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST = None
class TFRegNetForImageClassification(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFRegNetModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFRegNetPreTrainedModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
TF_REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None TF_REMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the TensorFlow RegNet model. """
import inspect
import unittest
from typing import List, Tuple
from transformers import RegNetConfig
from transformers.testing_utils import require_tf, require_vision, slow
from transformers.utils import cached_property, is_tf_available, is_vision_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
if is_tf_available():
import tensorflow as tf
from transformers import TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST, TFRegNetForImageClassification, TFRegNetModel
if is_vision_available():
from PIL import Image
from transformers import AutoFeatureExtractor
class TFRegNetModelTester:
def __init__(
self,
parent,
batch_size=3,
image_size=32,
num_channels=3,
embeddings_size=10,
hidden_sizes=[10, 20, 30, 40],
depths=[1, 1, 2, 1],
is_training=True,
use_labels=True,
hidden_act="relu",
num_labels=3,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.num_channels = num_channels
self.embeddings_size = embeddings_size
self.hidden_sizes = hidden_sizes
self.depths = depths
self.is_training = is_training
self.use_labels = use_labels
self.hidden_act = hidden_act
self.num_labels = num_labels
self.scope = scope
self.num_stages = len(hidden_sizes)
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
labels = None
if self.use_labels:
labels = ids_tensor([self.batch_size], self.num_labels)
config = self.get_config()
return config, pixel_values, labels
def get_config(self):
return RegNetConfig(
num_channels=self.num_channels,
embeddings_size=self.embeddings_size,
hidden_sizes=self.hidden_sizes,
depths=self.depths,
hidden_act=self.hidden_act,
num_labels=self.num_labels,
)
def create_and_check_model(self, config, pixel_values, labels):
model = TFRegNetModel(config=config)
result = model(pixel_values, training=False)
# expected last hidden states: B, C, H // 32, W // 32
self.parent.assertEqual(
result.last_hidden_state.shape,
(self.batch_size, self.hidden_sizes[-1], self.image_size // 32, self.image_size // 32),
)
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.num_labels
model = TFRegNetForImageClassification(config)
result = model(pixel_values, labels=labels, training=False)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_tf
class TFRegNetModelTest(TFModelTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_common.py, as RegNet does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
all_model_classes = (TFRegNetModel, TFRegNetForImageClassification) if is_tf_available() else ()
test_pruning = False
test_onnx = False
test_resize_embeddings = False
test_head_masking = False
has_attentions = False
def setUp(self):
self.model_tester = TFRegNetModelTester(self)
self.config_tester = ConfigTester(self, config_class=RegNetConfig, has_text_modality=False)
def create_and_test_config_common_properties(self):
return
@unittest.skip(reason="RegNet does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skipIf(
not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0,
reason="TF (<=2.8) does not support backprop for grouped convolutions on CPU.",
)
def test_keras_fit(self):
pass
@unittest.skip(reason="RegNet does not support input and output embeddings")
def test_model_common_attributes(self):
pass
@unittest.skip(reason="Model doesn't have attention layers")
def test_attention_outputs(self):
pass
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.call)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
expected_num_stages = self.model_tester.num_stages
self.assertEqual(len(hidden_states), expected_num_stages + 1)
# RegNet's feature maps are of shape (batch_size, num_channels, height, width)
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[self.model_tester.image_size // 2, self.model_tester.image_size // 2],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
layers_type = ["basic", "bottleneck"]
for model_class in self.all_model_classes:
for layer_type in layers_type:
config.layer_type = layer_type
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
# Since RegNet does not have any attention we need to rewrite this test.
def test_model_outputs_equivalence(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
tuple_output = model(tuple_inputs, return_dict=False, **additional_kwargs)
dict_output = model(dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
self.assertTrue(
all(tf.equal(tuple_object, dict_object)),
msg=(
"Tuple and dict output are not equal. Difference:"
f" {tf.math.reduce_max(tf.abs(tuple_object - dict_object))}"
),
)
recursive_check(tuple_output, dict_output)
for model_class in self.all_model_classes:
model = model_class(config)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
for model_name in TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = TFRegNetModel.from_pretrained(model_name)
self.assertIsNotNone(model)
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_tf
@require_vision
class RegNetModelIntegrationTest(unittest.TestCase):
@cached_property
def default_feature_extractor(self):
return (
AutoFeatureExtractor.from_pretrained(TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST[0])
if is_vision_available()
else None
)
@slow
def test_inference_image_classification_head(self):
model = TFRegNetForImageClassification.from_pretrained(TF_REGNET_PRETRAINED_MODEL_ARCHIVE_LIST[0])
feature_extractor = self.default_feature_extractor
image = prepare_img()
inputs = feature_extractor(images=image, return_tensors="tf")
# forward pass
outputs = model(**inputs, training=False)
# verify the logits
expected_shape = tf.TensorShape((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = tf.constant([-0.4180, -1.5051, -3.4836])
tf.debugging.assert_near(outputs.logits[0, :3], expected_slice, atol=1e-4)
...@@ -50,6 +50,8 @@ src/transformers/models/pegasus/modeling_pegasus.py ...@@ -50,6 +50,8 @@ src/transformers/models/pegasus/modeling_pegasus.py
src/transformers/models/plbart/modeling_plbart.py src/transformers/models/plbart/modeling_plbart.py
src/transformers/models/poolformer/modeling_poolformer.py src/transformers/models/poolformer/modeling_poolformer.py
src/transformers/models/reformer/modeling_reformer.py src/transformers/models/reformer/modeling_reformer.py
src/transformers/models/regnet/modeling_regnet.py
src/transformers/models/regnet/modeling_tf_regnet.py
src/transformers/models/resnet/modeling_resnet.py src/transformers/models/resnet/modeling_resnet.py
src/transformers/models/roberta/modeling_roberta.py src/transformers/models/roberta/modeling_roberta.py
src/transformers/models/roberta/modeling_tf_roberta.py src/transformers/models/roberta/modeling_tf_roberta.py
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment