Unverified Commit a717e031 authored by amyeroberts, committed by GitHub

Add TimmBackbone model (#22619)



* Add test_backbone for convnext

* Add TimmBackbone model

* Add check for backbone type

* Tidying up - config checks

* Update convnextv2

* Tidy up

* Fix indices & clearer comment

* Exceptions for config checks

* Correctly update config for tests

* Safer imports

* Safer safer imports

* Fix where decorators go

* Update import logic and backbone tests

* More import fixes

* Fixup

* Only import all_models if torch available

* Fix kwarg updates in from_pretrained & main rebase

* Tidy up

* Add tests for AutoBackbone

* Tidy up

* Fix import error

* Fix up

* Install natten in doc_test_job

* Revert back to setting self._out_xxx directly

* Bug fix - out_indices mapping from out_features

* Fix tests

* Don't accept output_loading_info for Timm models

* Set out_xxx and don't remap

* Use smaller checkpoint for test

* Don't remap timm indices - check out_indices based on stage names

* Skip test as it's n/a

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Cleaner imports / spelling is hard

---------
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent b8935980
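For context, a minimal sketch of the usage this PR enables; the checkpoint name and index values are illustrative assumptions, not taken from the diff:

```python
from transformers import AutoBackbone

# Load a timm checkpoint through the Auto API. `use_timm_backbone=True`
# routes loading to the TimmBackbone wrapper added by this PR.
backbone = AutoBackbone.from_pretrained(
    "resnet18",              # any name accepted by timm.create_model
    use_timm_backbone=True,
    out_indices=(1, 2, 3, 4),
)
```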
......@@ -456,6 +456,7 @@ doc_test_job = CircleCIJob(
"pip install -e .[dev]",
"pip install git+https://github.com/huggingface/accelerate",
"pip install --upgrade pytest pytest-sugar",
"pip install natten",
"find -name __pycache__ -delete",
"find . -name \*.pyc -delete",
# Add an empty file to keep the test step running correctly even if no file is selected to be tested.
......
......@@ -424,6 +424,7 @@ Flax), PyTorch, and/or TensorFlow.
| TAPAS | ✅ | ❌ | ✅ | ✅ | ❌ |
| Time Series Transformer | ❌ | ❌ | ✅ | ❌ | ❌ |
| TimeSformer | ❌ | ❌ | ✅ | ❌ | ❌ |
| TimmBackbone | ❌ | ❌ | ❌ | ❌ | ❌ |
| Trajectory Transformer | ❌ | ❌ | ✅ | ❌ | ❌ |
| Transformer-XL | ✅ | ❌ | ✅ | ✅ | ❌ |
| TrOCR | ❌ | ❌ | ✅ | ❌ | ❌ |
......
......@@ -483,6 +483,7 @@ _import_structure = {
"TimeSeriesTransformerConfig",
],
"models.timesformer": ["TIMESFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "TimesformerConfig"],
"models.timm_backbone": ["TimmBackboneConfig"],
"models.trajectory_transformer": [
"TRAJECTORY_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
"TrajectoryTransformerConfig",
......@@ -2578,6 +2579,7 @@ else:
"TimesformerPreTrainedModel",
]
)
_import_structure["models.timm_backbone"].extend(["TimmBackbone"])
_import_structure["models.trajectory_transformer"].extend(
[
"TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
......@@ -4288,6 +4290,7 @@ if TYPE_CHECKING:
TimeSeriesTransformerConfig,
)
from .models.timesformer import TIMESFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, TimesformerConfig
from .models.timm_backbone import TimmBackboneConfig
from .models.trajectory_transformer import (
TRAJECTORY_TRANSFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
TrajectoryTransformerConfig,
......@@ -6024,6 +6027,7 @@ if TYPE_CHECKING:
TimesformerModel,
TimesformerPreTrainedModel,
)
from .models.timm_backbone import TimmBackbone
from .models.trajectory_transformer import (
TRAJECTORY_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TrajectoryTransformerModel,
......
......@@ -186,6 +186,7 @@ from . import (
tapex,
time_series_transformer,
timesformer,
timm_backbone,
trajectory_transformer,
transfo_xl,
trocr,
......
......@@ -19,7 +19,7 @@ from collections import OrderedDict
from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module
from ...utils import copy_func, logging
from ...utils import copy_func, logging, requires_backends
from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings
......@@ -515,6 +515,48 @@ class _BaseAutoModelClass:
cls._model_mapping.register(config_class, model_class)
class _BaseAutoBackboneClass(_BaseAutoModelClass):
# Base class for auto backbone models.
_model_mapping = None
@classmethod
def _load_timm_backbone_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
requires_backends(cls, ["vision", "timm"])
from ...models.timm_backbone import TimmBackboneConfig
config = kwargs.pop("config", TimmBackboneConfig())
use_timm = kwargs.pop("use_timm_backbone", True)
if not use_timm:
raise ValueError("`use_timm_backbone` must be `True` for timm backbones")
if kwargs.get("out_features", None) is not None:
raise ValueError("Cannot specify `out_features` for timm backbones")
if kwargs.get("output_loading_info", False):
raise ValueError("Cannot specify `output_loading_info=True` when loading from timm")
num_channels = kwargs.pop("num_channels", config.num_channels)
features_only = kwargs.pop("features_only", config.features_only)
use_pretrained_backbone = kwargs.pop("use_pretrained_backbone", config.use_pretrained_backbone)
out_indices = kwargs.pop("out_indices", config.out_indices)
config = TimmBackboneConfig(
backbone=pretrained_model_name_or_path,
num_channels=num_channels,
features_only=features_only,
use_pretrained_backbone=use_pretrained_backbone,
out_indices=out_indices,
)
return super().from_config(config, **kwargs)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
if kwargs.get("use_timm_backbone", False):
return cls._load_timm_backbone_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
def insert_head_doc(docstring, head_doc=""):
if len(head_doc) > 0:
return docstring.replace(
......
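A hedged sketch of the dispatch this hunk adds: `use_timm_backbone=True` takes the timm path, everything else falls through to the regular `from_pretrained`, and timm-incompatible kwargs fail fast (checkpoint names are illustrative):

```python
from transformers import AutoBackbone

# Timm path: builds a TimmBackboneConfig and instantiates TimmBackbone.
timm_backbone = AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True)

# Transformers path: unchanged behavior, e.g. a ResNetBackbone checkpoint.
hf_backbone = AutoBackbone.from_pretrained(
    "microsoft/resnet-50", out_features=["stage2", "stage4"]
)

# Both of these raise ValueError per the checks above:
# AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True, out_features=["layer1"])
# AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True, output_loading_info=True)
```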
......@@ -186,6 +186,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("tapas", "TapasConfig"),
("time_series_transformer", "TimeSeriesTransformerConfig"),
("timesformer", "TimesformerConfig"),
("timm_backbone", "TimmBackboneConfig"),
("trajectory_transformer", "TrajectoryTransformerConfig"),
("transfo-xl", "TransfoXLConfig"),
("trocr", "TrOCRConfig"),
......@@ -579,6 +580,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("tapex", "TAPEX"),
("time_series_transformer", "Time Series Transformer"),
("timesformer", "TimeSformer"),
("timm_backbone", "TimmBackbone"),
("trajectory_transformer", "Trajectory Transformer"),
("transfo-xl", "Transformer-XL"),
("trocr", "TrOCR"),
......
......@@ -18,7 +18,7 @@ import warnings
from collections import OrderedDict
from ...utils import logging
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from .auto_factory import _BaseAutoBackboneClass, _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from .configuration_auto import CONFIG_MAPPING_NAMES
......@@ -179,6 +179,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("tapas", "TapasModel"),
("time_series_transformer", "TimeSeriesTransformerModel"),
("timesformer", "TimesformerModel"),
("timm_backbone", "TimmBackbone"),
("trajectory_transformer", "TrajectoryTransformerModel"),
("transfo-xl", "TransfoXLModel"),
("tvlt", "TvltModel"),
......@@ -999,6 +1000,7 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
("nat", "NatBackbone"),
("resnet", "ResNetBackbone"),
("swin", "SwinBackbone"),
("timm_backbone", "TimmBackbone"),
]
)
......@@ -1330,7 +1332,7 @@ class AutoModelForAudioXVector(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING
class AutoBackbone(_BaseAutoModelClass):
class AutoBackbone(_BaseAutoBackboneClass):
_model_mapping = MODEL_FOR_BACKBONE_MAPPING
......
......@@ -39,7 +39,7 @@ from ...utils import (
logging,
replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from ...utils.backbone_utils import BackboneMixin
from .configuration_bit import BitConfig
......@@ -845,14 +845,10 @@ class BitForImageClassification(BitPreTrainedModel):
class BitBackbone(BitPreTrainedModel, BackboneMixin):
def __init__(self, config):
super().__init__(config)
super()._init_backbone(config)
self.stage_names = config.stage_names
self.bit = BitModel(config)
self.num_features = [config.embedding_size] + config.hidden_sizes
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
config.out_features, config.out_indices, self.stage_names
)
# initialize weights and apply final processing
self.post_init()
......
......@@ -37,7 +37,7 @@ from ...utils import (
logging,
replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from ...utils.backbone_utils import BackboneMixin
from .configuration_convnext import ConvNextConfig
......@@ -481,15 +481,11 @@ class ConvNextForImageClassification(ConvNextPreTrainedModel):
class ConvNextBackbone(ConvNextPreTrainedModel, BackboneMixin):
def __init__(self, config):
super().__init__(config)
super()._init_backbone(config)
self.stage_names = config.stage_names
self.embeddings = ConvNextEmbeddings(config)
self.encoder = ConvNextEncoder(config)
self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
config.out_features, config.out_indices, self.stage_names
)
# Add layer norms to hidden states of out_features
hidden_states_norms = {}
......
......@@ -37,7 +37,7 @@ from ...utils import (
logging,
replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from ...utils.backbone_utils import BackboneMixin
from .configuration_convnextv2 import ConvNextV2Config
......@@ -504,15 +504,11 @@ class ConvNextV2ForImageClassification(ConvNextV2PreTrainedModel):
class ConvNextV2Backbone(ConvNextV2PreTrainedModel, BackboneMixin):
def __init__(self, config):
super().__init__(config)
super()._init_backbone(config)
self.stage_names = config.stage_names
self.embeddings = ConvNextV2Embeddings(config)
self.encoder = ConvNextV2Encoder(config)
self.num_features = [config.hidden_sizes[0]] + config.hidden_sizes
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
config.out_features, config.out_indices, self.stage_names
)
# Add layer norms to hidden states of out_features
hidden_states_norms = {}
......
......@@ -39,7 +39,7 @@ from ...utils import (
replace_return_docstrings,
requires_backends,
)
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from ...utils.backbone_utils import BackboneMixin
from .configuration_dinat import DinatConfig
......@@ -883,17 +883,12 @@ class DinatForImageClassification(DinatPreTrainedModel):
class DinatBackbone(DinatPreTrainedModel, BackboneMixin):
def __init__(self, config):
super().__init__(config)
super()._init_backbone(config)
requires_backends(self, ["natten"])
self.stage_names = config.stage_names
self.embeddings = DinatEmbeddings(config)
self.encoder = DinatEncoder(config)
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
config.out_features, config.out_indices, self.stage_names
)
self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
# Add layer norms to hidden states of out_features
......
......@@ -36,7 +36,7 @@ from ...utils import (
logging,
replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from ...utils.backbone_utils import BackboneMixin
from .configuration_focalnet import FocalNetConfig
......@@ -981,16 +981,12 @@ class FocalNetForImageClassification(FocalNetPreTrainedModel):
FOCALNET_START_DOCSTRING,
)
class FocalNetBackbone(FocalNetPreTrainedModel, BackboneMixin):
def __init__(self, config):
def __init__(self, config: FocalNetConfig):
super().__init__(config)
self.stage_names = config.stage_names
self.focalnet = FocalNetModel(config)
super()._init_backbone(config)
self.num_features = [config.embed_dim] + config.hidden_sizes
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
config.out_features, config.out_indices, self.stage_names
)
self.focalnet = FocalNetModel(config)
# initialize weights and apply final processing
self.post_init()
......
......@@ -29,7 +29,7 @@ from ...file_utils import ModelOutput
from ...modeling_outputs import BackboneOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from ...utils.backbone_utils import BackboneMixin
from .configuration_maskformer_swin import MaskFormerSwinConfig
......@@ -852,17 +852,11 @@ class MaskFormerSwinBackbone(MaskFormerSwinPreTrainedModel, BackboneMixin):
def __init__(self, config: MaskFormerSwinConfig):
super().__init__(config)
super()._init_backbone(config)
self.stage_names = config.stage_names
self.model = MaskFormerSwinModel(config)
self._out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
if "stem" in self.out_features:
raise ValueError("This backbone does not support 'stem' in the `out_features`.")
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
config.out_features, config.out_indices, self.stage_names
)
self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
self.hidden_states_norms = nn.ModuleList(
[nn.LayerNorm(num_channels) for num_channels in self.num_features[1:]]
......
......@@ -39,7 +39,7 @@ from ...utils import (
replace_return_docstrings,
requires_backends,
)
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from ...utils.backbone_utils import BackboneMixin
from .configuration_nat import NatConfig
......@@ -861,17 +861,12 @@ class NatForImageClassification(NatPreTrainedModel):
class NatBackbone(NatPreTrainedModel, BackboneMixin):
def __init__(self, config):
super().__init__(config)
super()._init_backbone(config)
requires_backends(self, ["natten"])
self.stage_names = config.stage_names
self.embeddings = NatEmbeddings(config)
self.encoder = NatEncoder(config)
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
config.out_features, config.out_indices, self.stage_names
)
self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
# Add layer norms to hidden states of out_features
......
......@@ -36,7 +36,7 @@ from ...utils import (
logging,
replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from ...utils.backbone_utils import BackboneMixin
from .configuration_resnet import ResNetConfig
......@@ -432,16 +432,12 @@ class ResNetForImageClassification(ResNetPreTrainedModel):
class ResNetBackbone(ResNetPreTrainedModel, BackboneMixin):
def __init__(self, config):
super().__init__(config)
super()._init_backbone(config)
self.stage_names = config.stage_names
self.num_features = [config.embedding_size] + config.hidden_sizes
self.embedder = ResNetEmbeddings(config)
self.encoder = ResNetEncoder(config)
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
config.out_features, config.out_indices, self.stage_names
)
self.num_features = [config.embedding_size] + config.hidden_sizes
# initialize weights and apply final processing
self.post_init()
......
......@@ -38,7 +38,7 @@ from ...utils import (
logging,
replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin, get_aligned_output_features_output_indices
from ...utils.backbone_utils import BackboneMixin
from .configuration_swin import SwinConfig
......@@ -1259,17 +1259,12 @@ class SwinForImageClassification(SwinPreTrainedModel):
class SwinBackbone(SwinPreTrainedModel, BackboneMixin):
def __init__(self, config: SwinConfig):
super().__init__(config)
super()._init_backbone(config)
self.stage_names = config.stage_names
self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
self.embeddings = SwinEmbeddings(config)
self.encoder = SwinEncoder(config, self.embeddings.patch_grid)
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
config.out_features, config.out_indices, self.stage_names
)
self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
# Add layer norms to hidden states of out_features
hidden_states_norms = {}
for stage, num_channels in zip(self._out_features, self.channels):
......
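The model hunks above all apply the same refactor: per-model alignment of `out_features`/`out_indices` moves into the shared `_init_backbone`. A hedged sketch of the resulting contract, using `ResNetBackbone` as a representative (config values are illustrative):

```python
from transformers import ResNetBackbone, ResNetConfig

config = ResNetConfig(out_features=["stage1", "stage4"])
backbone = ResNetBackbone(config)

# _init_backbone derived these from the config; no per-model alignment code remains.
print(backbone.stage_names)   # ['stem', 'stage1', 'stage2', 'stage3', 'stage4']
print(backbone.out_features)  # ['stage1', 'stage4']
print(backbone.out_indices)   # [1, 4]
print(backbone.channels)      # channel count for each selected stage
```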
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {"configuration_timm_backbone": ["TimmBackboneConfig"]}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_timm_backbone"] = ["TimmBackbone"]
if TYPE_CHECKING:
from .configuration_timm_backbone import TimmBackboneConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_timm_backbone import TimmBackbone
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Configuration for Backbone models"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
class TimmBackboneConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration for a timm backbone [`TimmBackbone`].
It is used to instantiate a timm backbone model according to the specified arguments, defining the model.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
backbone (`str`, *optional*):
The timm checkpoint to load.
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
features_only (`bool`, *optional*, defaults to `True`):
Whether to output only the features or also the logits.
use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
Whether to use a pretrained backbone.
out_indices (`List[int]`, *optional*):
If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
many stages the model has). Will default to the last stage if unset.
Example:
```python
>>> from transformers import TimmBackboneConfig, TimmBackbone
>>> # Initializing a timm backbone
>>> configuration = TimmBackboneConfig("resnet50")
>>> # Initializing a model from the configuration
>>> model = TimmBackbone(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "timm_backbone"
def __init__(
self,
backbone=None,
num_channels=3,
features_only=True,
use_pretrained_backbone=True,
out_indices=None,
**kwargs,
):
super().__init__(**kwargs)
self.backbone = backbone
self.num_channels = num_channels
self.features_only = features_only
self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = True
self.out_indices = out_indices if out_indices is not None else (-1,)
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Union
from ...modeling_outputs import BackboneOutput
from ...modeling_utils import PreTrainedModel
from ...utils import is_timm_available, is_torch_available, requires_backends
from ...utils.backbone_utils import BackboneMixin
from .configuration_timm_backbone import TimmBackboneConfig
if is_timm_available():
import timm
if is_torch_available():
from torch import Tensor
class TimmBackbone(PreTrainedModel, BackboneMixin):
"""
Wrapper class for timm models to be used as backbones. This enables using the timm models interchangeably with the
other models in the library while keeping the same API.
"""
main_input_name = "pixel_values"
supports_gradient_checkpointing = False
config_class = TimmBackboneConfig
def __init__(self, config, **kwargs):
requires_backends(self, "timm")
super().__init__(config)
self.config = config
if config.backbone is None:
raise ValueError("backbone is not set in the config. Please set it to a timm model name.")
if config.backbone not in timm.list_models():
raise ValueError(f"backbone {config.backbone} is not supported by timm.")
if hasattr(config, "out_features") and config.out_features is not None:
raise ValueError("out_features is not supported by TimmBackbone. Please use out_indices instead.")
pretrained = getattr(config, "use_pretrained_backbone", None)
if pretrained is None:
raise ValueError("use_pretrained_backbone is not set in the config. Please set it to True or False.")
# We just take the final layer by default. This matches the default for the transformers models.
out_indices = config.out_indices if getattr(config, "out_indices", None) is not None else (-1,)
self._backbone = timm.create_model(
config.backbone,
pretrained=pretrained,
# This is currently not possible for transformer architectures.
features_only=config.features_only,
in_chans=config.num_channels,
out_indices=out_indices,
**kwargs,
)
# These are used to control the output of the model when called. If output_hidden_states is True, then
# return_layers is modified to include all layers.
self._return_layers = self._backbone.return_layers
self._all_layers = {layer["module"]: str(i) for i, layer in enumerate(self._backbone.feature_info.info)}
super()._init_backbone(config)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
requires_backends(cls, ["vision", "timm"])
from ...models.timm_backbone import TimmBackboneConfig
config = kwargs.pop("config", TimmBackboneConfig())
use_timm = kwargs.pop("use_timm_backbone", True)
if not use_timm:
raise ValueError("use_timm_backbone must be True for timm backbones")
num_channels = kwargs.pop("num_channels", config.num_channels)
features_only = kwargs.pop("features_only", config.features_only)
use_pretrained_backbone = kwargs.pop("use_pretrained_backbone", config.use_pretrained_backbone)
out_indices = kwargs.pop("out_indices", config.out_indices)
config = TimmBackboneConfig(
backbone=pretrained_model_name_or_path,
num_channels=num_channels,
features_only=features_only,
use_pretrained_backbone=use_pretrained_backbone,
out_indices=out_indices,
)
return super()._from_config(config, **kwargs)
def _init_weights(self, module):
"""
Empty init weights function to ensure compatibility of the class in the library.
"""
pass
def forward(
self, pixel_values, output_attentions=None, output_hidden_states=None, return_dict=None, **kwargs
) -> Union[BackboneOutput, Tuple[Tensor, ...]]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
if output_attentions:
raise ValueError("Cannot output attentions for timm backbones at the moment")
if output_hidden_states:
# We modify the return layers to include all the stages of the backbone
self._backbone.return_layers = self._all_layers
hidden_states = self._backbone(pixel_values, **kwargs)
self._backbone.return_layers = self._return_layers
feature_maps = tuple(hidden_states[i] for i in self.out_indices)
else:
feature_maps = self._backbone(pixel_values, **kwargs)
hidden_states = None
feature_maps = tuple(feature_maps)
hidden_states = tuple(hidden_states) if hidden_states is not None else None
if not return_dict:
output = (feature_maps,)
if output_hidden_states:
output = output + (hidden_states,)
return output
return BackboneOutput(feature_maps=feature_maps, hidden_states=hidden_states, attentions=None)
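A minimal forward-pass sketch for the wrapper above; it assumes `torch` and `timm` are installed and uses an illustrative checkpoint name:

```python
import torch

from transformers import TimmBackbone

# from_pretrained builds a TimmBackboneConfig; timm loads the actual weights.
backbone = TimmBackbone.from_pretrained("resnet18", out_indices=(2, 3, 4))

pixel_values = torch.randn(1, 3, 224, 224)
outputs = backbone(pixel_values, output_hidden_states=True)

# feature_maps holds the stages selected by out_indices;
# hidden_states holds every stage because output_hidden_states=True.
print(len(outputs.feature_maps))   # 3
print(len(outputs.hidden_states))  # all stages, e.g. 5 for resnet18
```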
......@@ -15,10 +15,16 @@
""" Collection of utils to be used by backbones and their components."""
import enum
import inspect
from typing import Iterable, List, Optional, Tuple, Union
class BackboneType(enum.Enum):
TIMM = "timm"
TRANSFORMERS = "transformers"
def verify_out_features_out_indices(
out_features: Optional[Iterable[str]], out_indices: Optional[Iterable[int]], stage_names: Optional[Iterable[str]]
):
......@@ -72,7 +78,7 @@ def _align_output_features_output_indices(
out_indices = [len(stage_names) - 1]
out_features = [stage_names[-1]]
elif out_indices is None and out_features is not None:
out_indices = [stage_names.index(layer) for layer in stage_names if layer in out_features]
out_indices = [stage_names.index(layer) for layer in out_features]
elif out_features is None and out_indices is not None:
out_features = [stage_names[idx] for idx in out_indices]
return out_features, out_indices
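A plain-Python illustration of the one-line fix above: the old comprehension emitted `out_indices` in stage order, while the new one follows the order the user gave in `out_features`:

```python
stage_names = ["stem", "stage1", "stage2", "stage3", "stage4"]
out_features = ["stage4", "stage1"]  # user-specified order

# Old: iterate stage_names, so the user's ordering is silently lost.
old = [stage_names.index(layer) for layer in stage_names if layer in out_features]
# New: iterate out_features, preserving the requested ordering.
new = [stage_names.index(layer) for layer in out_features]

print(old)  # [1, 4]
print(new)  # [4, 1]
```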
......@@ -110,29 +116,57 @@ def get_aligned_output_features_output_indices(
class BackboneMixin:
@property
def out_feature_channels(self):
# the current backbones will output the number of channels for each stage
# even if that stage is not in the out_features list.
return {stage: self.num_features[i] for i, stage in enumerate(self.stage_names)}
backbone_type: Optional[BackboneType] = None
@property
def channels(self):
return [self.out_feature_channels[name] for name in self.out_features]
def _init_timm_backbone(self, config) -> None:
"""
Initialize the backbone model from timm. The backbone must already be loaded to self._backbone
"""
if getattr(self, "_backbone", None) is None:
raise ValueError("self._backbone must be set before calling _init_timm_backbone")
def forward_with_filtered_kwargs(self, *args, **kwargs):
signature = dict(inspect.signature(self.forward).parameters)
filtered_kwargs = {k: v for k, v in kwargs.items() if k in signature}
return self(*args, **filtered_kwargs)
# These will disagree with the defaults for the transformers models, e.g. for resnet50
# the transformer model has out_features = ['stem', 'stage1', 'stage2', 'stage3', 'stage4']
# the timm model has out_features = ['act', 'layer1', 'layer2', 'layer3', 'layer4']
self.stage_names = [stage["module"] for stage in self._backbone.feature_info.info]
self.num_features = [stage["num_chs"] for stage in self._backbone.feature_info.info]
out_indices = self._backbone.feature_info.out_indices
out_features = self._backbone.feature_info.module_name()
def forward(
self,
pixel_values,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
raise NotImplementedError("This method should be implemented by the derived class.")
# We verify the out indices and out features are valid
verify_out_features_out_indices(
out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
)
self._out_features, self._out_indices = out_features, out_indices
def _init_transformers_backbone(self, config) -> None:
stage_names = getattr(config, "stage_names")
out_features = getattr(config, "out_features", None)
out_indices = getattr(config, "out_indices", None)
self.stage_names = stage_names
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
out_features=out_features, out_indices=out_indices, stage_names=stage_names
)
# Number of channels for each stage. This is set in the transformer backbone model init
self.num_features = None
def _init_backbone(self, config) -> None:
"""
Method to initialize the backbone. This method is called by the constructor of the base class after the
pretrained model weights have been loaded.
"""
self.config = config
self.use_timm_backbone = getattr(config, "use_timm_backbone", False)
self.backbone_type = BackboneType.TIMM if self.use_timm_backbone else BackboneType.TRANSFORMERS
if self.backbone_type == BackboneType.TIMM:
self._init_timm_backbone(config)
elif self.backbone_type == BackboneType.TRANSFORMERS:
self._init_transformers_backbone(config)
else:
raise ValueError(f"backbone_type {self.backbone_type} not supported.")
@property
def out_features(self):
......@@ -160,6 +194,40 @@ class BackboneMixin:
out_features=None, out_indices=out_indices, stage_names=self.stage_names
)
@property
def out_feature_channels(self):
# the current backbones will output the number of channels for each stage
# even if that stage is not in the out_features list.
return {stage: self.num_features[i] for i, stage in enumerate(self.stage_names)}
@property
def channels(self):
return [self.out_feature_channels[name] for name in self.out_features]
def forward_with_filtered_kwargs(self, *args, **kwargs):
signature = dict(inspect.signature(self.forward).parameters)
filtered_kwargs = {k: v for k, v in kwargs.items() if k in signature}
return self(*args, **filtered_kwargs)
def forward(
self,
pixel_values,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
raise NotImplementedError("This method should be implemented by the derived class.")
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default `to_dict()` from `PretrainedConfig` to
include the `out_features` and `out_indices` attributes.
"""
output = super().to_dict()
output["out_features"] = output.pop("_out_features")
output["out_indices"] = output.pop("_out_indices")
return output
class BackboneConfigMixin:
"""
......