"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "43d17c18360ac9c3d3491389328e2fe55fe8f9ce"
Unverified commit a717e031 authored by amyeroberts, committed by GitHub

Add TimmBackbone model (#22619)



* Add test_backbone for convnext

* Add TimmBackbone model

* Add check for backbone type

* Tidying up - config checks

* Update convnextv2

* Tidy up

* Fix indices & clearer comment

* Exceptions for config checks

* Correctly update config for tests

* Safer imports

* Safer safer imports

* Fix where decorators go

* Update import logic and backbone tests

* More import fixes

* Fixup

* Only import all_models if torch available

* Fix kwarg updates in from_pretrained & main rebase

* Tidy up

* Add tests for AutoBackbone

* Tidy up

* Fix import error

* Fix up

* Install natten in doc_test_job

* Revert back to setting self._out_xxx directly

* Bug fix - out_indices mapping from out_features

* Fix tests

* Don't accept output_loading_info for Timm models

* Set out_xxx and don't remap

* Use smaller checkpoint for test

* Don't remap timm indices - check out_indices based on stage names

* Skip test as it's n/a

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Cleaner imports / spelling is hard

---------
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent b8935980
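For context on the API this commit introduces, here is a minimal usage sketch. The checkpoint names, kwargs, and expected attribute values mirror the tests added in the diff below; treat them as illustrative rather than normative, and note they assume torch (and timm for the second example) is installed.

```python
# Hedged sketch of the AutoBackbone entry points added in this PR; the values
# mirror the tests in this diff and assume torch (and timm) are available.
from transformers import AutoBackbone

# Native transformers backbone: out_features and out_indices are kept in sync.
resnet = AutoBackbone.from_pretrained("microsoft/resnet-18", out_features=["stage2", "stage4"])
print(resnet.out_indices)   # [2, 4]
print(resnet.out_features)  # ["stage2", "stage4"]

# timm-wrapped backbone: selected via use_timm_backbone=True and configured with out_indices only;
# passing out_features here raises a ValueError, as the tests below assert.
timm_resnet = AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True, out_indices=(-1, -2))
print(timm_resnet.out_indices)  # (-1, -2)
```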
@@ -6806,6 +6806,13 @@ class TimesformerPreTrainedModel(metaclass=DummyObject):
        requires_backends(self, ["torch"])


class TimmBackbone(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
...
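The hunk above registers a torch-only dummy for the new class so that importing `TimmBackbone` without PyTorch fails with a clear error. Below is a small, self-contained sketch of that dummy-object pattern; `requires_backends` and `DummyObject` here are simplified stand-ins for the helpers in `transformers.utils`, not the library's exact implementations.

```python
# Simplified sketch of the dummy-object pattern (stand-ins, not the exact transformers helpers).
import importlib.util


def requires_backends(obj, backends):
    # Raise a helpful error if a required backend (e.g. torch) is not installed.
    missing = [b for b in backends if importlib.util.find_spec(b) is None]
    if missing:
        name = obj.__name__ if isinstance(obj, type) else type(obj).__name__
        raise ImportError(f"{name} requires the following backends: {', '.join(missing)}")


class DummyObject(type):
    """Metaclass that re-raises the backend error on any public attribute access."""

    def __getattribute__(cls, key):
        if key.startswith("_"):
            return super().__getattribute__(key)
        requires_backends(cls, cls._backends)


class TimmBackbone(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
```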
@@ -45,6 +45,7 @@ if is_torch_available():
    from test_module.custom_modeling import CustomModel
    from transformers import (
        AutoBackbone,
        AutoConfig,
        AutoModel,
        AutoModelForCausalLM,
@@ -66,11 +67,13 @@ if is_torch_available():
        FunnelModel,
        GPT2Config,
        GPT2LMHeadModel,
        ResNetBackbone,
        RobertaForMaskedLM,
        T5Config,
        T5ForConditionalGeneration,
        TapasConfig,
        TapasForQuestionAnswering,
        TimmBackbone,
    )
    from transformers.models.auto.modeling_auto import (
        MODEL_FOR_CAUSAL_LM_MAPPING,
@@ -224,6 +227,42 @@ class AutoModelTest(unittest.TestCase):
        self.assertIsNotNone(model)
        self.assertIsInstance(model, BertForTokenClassification)

    @slow
    def test_auto_backbone_timm_model_from_pretrained(self):
        # Configs can't be loaded for timm models
        model = AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True)

        with pytest.raises(ValueError):
            # We can't pass output_loading_info=True as we're loading from timm
            AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True, output_loading_info=True)

        self.assertIsNotNone(model)
        self.assertIsInstance(model, TimmBackbone)

        # Check kwargs are correctly passed to the backbone
        model = AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True, out_indices=(-1, -2))
        self.assertEqual(model.out_indices, (-1, -2))

        # Check out_features cannot be passed to Timm backbones
        with self.assertRaises(ValueError):
            _ = AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True, out_features=["stage1"])

    @slow
    def test_auto_backbone_from_pretrained(self):
        model = AutoBackbone.from_pretrained("microsoft/resnet-18")
        model, loading_info = AutoBackbone.from_pretrained("microsoft/resnet-18", output_loading_info=True)
        self.assertIsNotNone(model)
        self.assertIsInstance(model, ResNetBackbone)

        # Check kwargs are correctly passed to the backbone
        model = AutoBackbone.from_pretrained("microsoft/resnet-18", out_indices=[-1, -2])
        self.assertEqual(model.out_indices, [-1, -2])
        self.assertEqual(model.out_features, ["stage4", "stage3"])

        model = AutoBackbone.from_pretrained("microsoft/resnet-18", out_features=["stage2", "stage4"])
        self.assertEqual(model.out_indices, [2, 4])
        self.assertEqual(model.out_features, ["stage2", "stage4"])

    def test_from_pretrained_identifier(self):
        model = AutoModelWithLMHead.from_pretrained(SMALL_MODEL_IDENTIFIER)
        self.assertIsInstance(model, BertForMaskedLM)
...
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect
import unittest

from transformers import AutoBackbone
from transformers.configuration_utils import PretrainedConfig
from transformers.testing_utils import require_timm, require_torch, torch_device
from transformers.utils.import_utils import is_torch_available

from ...test_backbone_common import BackboneTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor


if is_torch_available():
    import torch

    from transformers import TimmBackbone, TimmBackboneConfig
class TimmBackboneModelTester:
    def __init__(
        self,
        parent,
        out_indices=None,
        out_features=None,
        stage_names=None,
        backbone="resnet50",
        batch_size=3,
        image_size=32,
        num_channels=3,
        is_training=True,
        use_pretrained_backbone=True,
    ):
        self.parent = parent
        self.out_indices = out_indices if out_indices is not None else [4]
        self.stage_names = stage_names
        self.out_features = out_features
        self.backbone = backbone
        self.batch_size = batch_size
        self.image_size = image_size
        self.num_channels = num_channels
        self.use_pretrained_backbone = use_pretrained_backbone
        self.is_training = is_training

    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
        config = self.get_config()
        return config, pixel_values

    def get_config(self):
        return TimmBackboneConfig(
            image_size=self.image_size,
            num_channels=self.num_channels,
            out_features=self.out_features,
            out_indices=self.out_indices,
            stage_names=self.stage_names,
            use_pretrained_backbone=self.use_pretrained_backbone,
            backbone=self.backbone,
        )
    def create_and_check_model(self, config, pixel_values):
        model = TimmBackbone(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(pixel_values)

        self.parent.assertEqual(
            result.feature_maps[-1].shape,
            (self.batch_size, model.channels[-1], 14, 14),
        )
    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values = config_and_inputs
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict


@require_torch
@require_timm
class TimmBackboneModelTest(ModelTesterMixin, BackboneTesterMixin, unittest.TestCase):
    all_model_classes = (TimmBackbone,) if is_torch_available() else ()
    test_resize_embeddings = False
    test_head_masking = False
    test_pruning = False
    has_attentions = False

    def setUp(self):
        self.model_tester = TimmBackboneModelTester(self)
        self.config_tester = ConfigTester(self, config_class=PretrainedConfig, has_text_modality=False)

    def test_config(self):
        self.config_tester.create_and_test_config_to_json_string()
        self.config_tester.create_and_test_config_to_json_file()
        self.config_tester.create_and_test_config_from_and_save_pretrained()
        self.config_tester.create_and_test_config_with_num_labels()
        self.config_tester.check_config_can_be_init_without_params()
        self.config_tester.check_config_arguments_init()

    def test_timm_transformer_backbone_equivalence(self):
        timm_checkpoint = "resnet18"
        transformers_checkpoint = "microsoft/resnet-18"

        timm_model = AutoBackbone.from_pretrained(timm_checkpoint, use_timm_backbone=True)
        transformers_model = AutoBackbone.from_pretrained(transformers_checkpoint)

        self.assertEqual(len(timm_model.out_features), len(transformers_model.out_features))
        self.assertEqual(len(timm_model.stage_names), len(transformers_model.stage_names))
        self.assertEqual(timm_model.channels, transformers_model.channels)
        # Out indices are set to the last layer by default. For timm models, we don't know
        # the number of layers in advance, so we set it to (-1,), whereas for transformers
        # models, we set it to [len(stage_names) - 1] (kept for backward compatibility).
        self.assertEqual(timm_model.out_indices, (-1,))
        self.assertEqual(transformers_model.out_indices, [len(timm_model.stage_names) - 1])

        timm_model = AutoBackbone.from_pretrained(timm_checkpoint, use_timm_backbone=True, out_indices=[1, 2, 3])
        transformers_model = AutoBackbone.from_pretrained(transformers_checkpoint, out_indices=[1, 2, 3])

        self.assertEqual(timm_model.out_indices, transformers_model.out_indices)
        self.assertEqual(len(timm_model.out_features), len(transformers_model.out_features))
        self.assertEqual(timm_model.channels, transformers_model.channels)

    @unittest.skip("TimmBackbone doesn't support feed forward chunking")
    def test_feed_forward_chunking(self):
        pass

    @unittest.skip("TimmBackbone doesn't have num_hidden_layers attribute")
    def test_hidden_states_output(self):
        pass

    @unittest.skip("TimmBackbone initialization is managed on the timm side")
    def test_initialization(self):
        pass
@unittest.skip("TimmBackbone models doesn't have inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip("TimmBackbone models doesn't have inputs_embeds")
def test_model_common_attributes(self):
pass
@unittest.skip("TimmBackbone model cannot be created without specifying a backbone checkpoint")
def test_from_pretrained_no_checkpoint(self):
pass
@unittest.skip("Only checkpoints on timm can be loaded into TimmBackbone")
def test_save_load(self):
pass
@unittest.skip("model weights aren't tied in TimmBackbone.")
def test_tie_model_weights(self):
pass
@unittest.skip("model weights aren't tied in TimmBackbone.")
def test_tied_model_weights_key_ignore(self):
pass
@unittest.skip("TimmBackbone doesn't have hidden size info in its configuration.")
def test_channels(self):
pass
@unittest.skip("TimmBackbone doesn't support output_attentions.")
def test_torchscript_output_attentions(self):
pass
@unittest.skip("Safetensors is not supported by timm.")
def test_can_use_safetensors(self):
pass
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_retain_grad_hidden_states_attentions(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True
config.output_attentions = self.has_attentions
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)
model.to(torch_device)
inputs = self._prepare_for_class(inputs_dict, model_class)
outputs = model(**inputs)
output = outputs[0][-1]
# Encoder-/Decoder-only models
hidden_states = outputs.hidden_states[0]
hidden_states.retain_grad()
if self.has_attentions:
attentions = outputs.attentions[0]
attentions.retain_grad()
output.flatten()[0].backward(retain_graph=True)
self.assertIsNotNone(hidden_states.grad)
if self.has_attentions:
self.assertIsNotNone(attentions.grad)
# TimmBackbone config doesn't have out_features attribute
def test_create_from_modified_config(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
result = model(**inputs_dict)
self.assertEqual(len(result.feature_maps), len(config.out_indices))
self.assertEqual(len(model.channels), len(config.out_indices))
# Check output of last stage is taken if out_features=None, out_indices=None
modified_config = copy.deepcopy(config)
modified_config.out_indices = None
model = model_class(modified_config)
model.to(torch_device)
model.eval()
result = model(**inputs_dict)
self.assertEqual(len(result.feature_maps), 1)
self.assertEqual(len(model.channels), 1)
# Check backbone can be initialized with fresh weights
modified_config = copy.deepcopy(config)
modified_config.use_pretrained_backbone = False
model = model_class(modified_config)
model.to(torch_device)
model.eval()
result = model(**inputs_dict)
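To make the tester above more concrete, here is a brief, hedged sketch of building a `TimmBackbone` directly from a `TimmBackboneConfig` and inspecting its outputs; the `resnet18` backbone name, input size, and `out_indices` are illustrative choices, and `use_pretrained_backbone=False` simply avoids a weight download.

```python
# Illustrative sketch (assumes torch and timm are installed); values are not mandated by the PR.
import torch

from transformers import TimmBackbone, TimmBackboneConfig

config = TimmBackboneConfig(backbone="resnet18", out_indices=(1, 2, 3, 4), use_pretrained_backbone=False)
model = TimmBackbone(config=config)
model.eval()

pixel_values = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    outputs = model(pixel_values)

# One feature map per requested stage; for a ResNet this corresponds to strides 4/8/16/32.
for name, feature_map in zip(model.out_features, outputs.feature_maps):
    print(name, tuple(feature_map.shape))
```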
@@ -17,6 +17,7 @@ import copy
import inspect

from transformers.testing_utils import require_torch, torch_device
from transformers.utils.backbone_utils import BackboneType


@require_torch
@@ -104,6 +105,8 @@ class BackboneTesterMixin:
            self.assertEqual(len(result.feature_maps), len(config.out_features))
            self.assertEqual(len(model.channels), len(config.out_features))
            self.assertEqual(len(result.feature_maps), len(config.out_indices))
            self.assertEqual(len(model.channels), len(config.out_indices))

            # Check output of last stage is taken if out_features=None, out_indices=None
            modified_config = copy.deepcopy(config)
@@ -140,6 +143,7 @@ class BackboneTesterMixin:
        for backbone_class in self.all_model_classes:
            backbone = backbone_class(config)

            self.assertTrue(hasattr(backbone, "backbone_type"))
            self.assertTrue(hasattr(backbone, "stage_names"))
            self.assertTrue(hasattr(backbone, "num_features"))
            self.assertTrue(hasattr(backbone, "out_indices"))
@@ -147,6 +151,7 @@ class BackboneTesterMixin:
            self.assertTrue(hasattr(backbone, "out_feature_channels"))
            self.assertTrue(hasattr(backbone, "channels"))

            self.assertIsInstance(backbone.backbone_type, BackboneType)
            # Verify num_features has been initialized in the backbone init
            self.assertIsNotNone(backbone.num_features)
            self.assertTrue(len(backbone.channels) == len(backbone.out_indices))
...
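The `backbone_type` attribute checked above distinguishes timm-wrapped backbones from native transformers ones. Below is a short, hedged sketch of how calling code might branch on it; it assumes `BackboneType` exposes `TIMM` and `TRANSFORMERS` members, as the `isinstance` assertion above implies.

```python
# Hedged sketch: assumes BackboneType has TIMM and TRANSFORMERS members.
from transformers import AutoBackbone
from transformers.utils.backbone_utils import BackboneType

backbone = AutoBackbone.from_pretrained("microsoft/resnet-18")

if backbone.backbone_type is BackboneType.TIMM:
    # timm-wrapped: configured through out_indices only, weights loaded via timm.
    print("timm backbone:", backbone.config.backbone)
else:
    # Native transformers backbone: behaves like any other PreTrainedModel.
    print("transformers backbone with stages:", backbone.stage_names)
```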
@@ -77,6 +77,7 @@ SPECIAL_CASES_TO_ALLOW = {
    "AutoformerConfig": ["num_static_real_features", "num_time_features"],
}

# TODO (ydshieh): Check the failing cases, try to fix them or move some cases to the above block once we are sure
SPECIAL_CASES_TO_ALLOW.update(
    {
@@ -172,6 +173,8 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s
        "mask_index",
        "image_size",
        "use_cache",
        "out_features",
        "out_indices",
    ]
    attributes_used_in_generation = ["encoder_no_repeat_ngram_size"]
...
@@ -39,6 +39,7 @@ CONFIG_CLASSES_TO_IGNORE_FOR_DOCSTRING_CHECKPOINT_CHECK = {
    "EncoderDecoderConfig",
    "RagConfig",
    "SpeechEncoderDecoderConfig",
    "TimmBackboneConfig",
    "VisionEncoderDecoderConfig",
    "VisionTextDualEncoderConfig",
    "LlamaConfig",
...
@@ -517,6 +517,7 @@ MODELS_NOT_IN_README = [
    "Speech Encoder decoder",
    "Speech2Text",
    "Speech2Text2",
    "TimmBackbone",
    "Vision Encoder decoder",
    "VisionTextDualEncoder",
]
...
@@ -408,6 +408,7 @@ def get_model_modules():
        "modeling_speech_encoder_decoder",
        "modeling_flax_speech_encoder_decoder",
        "modeling_flax_vision_encoder_decoder",
        "modeling_timm_backbone",
        "modeling_transfo_xl_utilities",
        "modeling_tf_auto",
        "modeling_tf_encoder_decoder",
@@ -846,6 +847,8 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [
    "NatBackbone",
    "ResNetBackbone",
    "SwinBackbone",
    "TimmBackbone",
    "TimmBackboneConfig",
]
...