Unverified Commit aafa7ce7 authored by amyeroberts, committed by GitHub

[`DETR`] Remove timm hardcoded logic in modeling files (#29038)



* Enable instantiating model with pretrained backbone weights

* Clarify pretrained import

* Use load_backbone instead

* Add backbone_kwargs to config

* Fix up

* Add tests

* Tidy up

* Enable instantiating model with pretrained backbone weights

* Update tests so backbone checkpoint isn't passed in

* Clarify pretrained import

* Update configs - docs and validation check

* Update src/transformers/utils/backbone_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Clarify exception message

* Update config init in tests

* Add test for when use_timm_backbone=True

* Use load_backbone instead

* Add use_timm_backbone to the model configs

* Add backbone_kwargs to config

* Pass kwargs to constructors

* Draft

* Fix tests

* Add back timm - weight naming

* More tidying up

* Whoops

* Tidy up

* Handle when kwargs are none

* Update tests

* Revert test changes

* Deformable detr test - don't use default

* Don't mutate; correct model attributes

* Add some clarifying comments

* nit - grammar is hard

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 77ff304d
@@ -192,10 +192,16 @@ class ConditionalDetrConfig(PretrainedConfig):
if backbone_config is not None and use_timm_backbone:
raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
if not use_timm_backbone:
# We default to values which were previously hard-coded in the model. This enables configurability of the config
# while keeping the default behavior the same.
if use_timm_backbone and backbone_kwargs is None:
backbone_kwargs = {}
if dilation:
backbone_kwargs["output_stride"] = 16
backbone_kwargs["out_indices"] = [1, 2, 3, 4]
backbone_kwargs["in_chans"] = num_channels
# Backwards compatibility
elif not use_timm_backbone and backbone in (None, "resnet50"):
if backbone_config is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
......
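The hunk above moves the previously hard-coded timm arguments onto the config itself. A minimal sketch of the resulting behavior, assuming this patch is applied to an installed `transformers` (building a config downloads no weights):

```python
from transformers import ConditionalDetrConfig

# With a timm backbone and no explicit backbone_kwargs, the old hard-coded
# values are now materialized on the config.
config = ConditionalDetrConfig(use_timm_backbone=True, backbone="resnet50", dilation=True)
print(config.backbone_kwargs)
# -> {'output_stride': 16, 'out_indices': [1, 2, 3, 4], 'in_chans': 3}

# Passing backbone_kwargs explicitly skips the defaulting entirely.
config = ConditionalDetrConfig(
    use_timm_backbone=True, backbone="resnet50", backbone_kwargs={"out_indices": [2, 3, 4]}
)
print(config.backbone_kwargs)  # -> {'out_indices': [2, 3, 4]}
```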
@@ -338,12 +338,12 @@ def replace_batch_norm(model):
replace_batch_norm(module)
# Copied from transformers.models.detr.modeling_detr.DetrConvEncoder
# Copied from transformers.models.detr.modeling_detr.DetrConvEncoder with Detr->ConditionalDetr
class ConditionalDetrConvEncoder(nn.Module):
"""
Convolutional backbone, using either the AutoBackbone API or one from the timm library.
nn.BatchNorm2d layers are replaced by DetrFrozenBatchNorm2d as defined above.
nn.BatchNorm2d layers are replaced by ConditionalDetrFrozenBatchNorm2d as defined above.
"""
@@ -352,17 +352,23 @@ class ConditionalDetrConvEncoder(nn.Module):
self.config = config
# For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API
if config.use_timm_backbone:
# We default to values which were previously hard-coded. This enables configurability from the config
# using backbone arguments, while keeping the default behavior the same.
requires_backends(self, ["timm"])
kwargs = {}
kwargs = getattr(config, "backbone_kwargs", {})
kwargs = {} if kwargs is None else kwargs.copy()
out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
num_channels = kwargs.pop("in_chans", config.num_channels)
if config.dilation:
kwargs["output_stride"] = 16
kwargs["output_stride"] = kwargs.get("output_stride", 16)
backbone = create_model(
config.backbone,
pretrained=config.use_pretrained_backbone,
features_only=True,
out_indices=(1, 2, 3, 4),
in_chans=config.num_channels,
out_indices=out_indices,
in_chans=num_channels,
**kwargs,
)
else:
......
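For reference, the timm call the encoder now assembles looks like the following sketch. It assumes `timm` is installed; `resnet18` merely stands in for `config.backbone`:

```python
import torch
import timm

backbone = timm.create_model(
    "resnet18",
    pretrained=False,        # config.use_pretrained_backbone
    features_only=True,      # return intermediate feature maps rather than logits
    out_indices=(1, 2, 3, 4),
    in_chans=3,
)
features = backbone(torch.randn(1, 3, 224, 224))
print([tuple(f.shape) for f in features])   # four maps at strides 4, 8, 16, 32
print(backbone.feature_info.channels())     # -> [64, 128, 256, 512] for resnet18
```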
@@ -212,7 +212,16 @@ class DeformableDetrConfig(PretrainedConfig):
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
if not use_timm_backbone:
# We default to values which were previously hard-coded in the model. This enables configurability of the config
# while keeping the default behavior the same.
if use_timm_backbone and backbone_kwargs is None:
backbone_kwargs = {}
if dilation:
backbone_kwargs["output_stride"] = 16
backbone_kwargs["out_indices"] = [2, 3, 4] if num_feature_levels > 1 else [4]
backbone_kwargs["in_chans"] = num_channels
# Backwards compatibility
elif not use_timm_backbone and backbone in (None, "resnet50"):
if backbone_config is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
@@ -220,6 +229,7 @@ class DeformableDetrConfig(PretrainedConfig):
backbone_model_type = backbone_config.get("model_type")
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
self.use_timm_backbone = use_timm_backbone
self.backbone_config = backbone_config
self.num_channels = num_channels
......
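Deformable DETR's defaults differ from plain DETR: multi-scale deformable attention consumes several feature levels, so the default `out_indices` follows `num_feature_levels`. A sketch, again assuming this patch is applied:

```python
from transformers import DeformableDetrConfig

multi = DeformableDetrConfig(use_timm_backbone=True, backbone="resnet50", num_feature_levels=4)
single = DeformableDetrConfig(use_timm_backbone=True, backbone="resnet50", num_feature_levels=1)
print(multi.backbone_kwargs["out_indices"])   # -> [2, 3, 4]
print(single.backbone_kwargs["out_indices"])  # -> [4]
```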
@@ -88,11 +88,31 @@ def load_cuda_kernels():
if is_vision_available():
from transformers.image_transforms import center_to_corners_format
if is_accelerate_available():
from accelerate import PartialState
from accelerate.utils import reduce
if is_timm_available():
from timm import create_model
if is_scipy_available():
from scipy.optimize import linear_sum_assignment
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "DeformableDetrConfig"
_CHECKPOINT_FOR_DOC = "sensetime/deformable-detr"
DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST = [
"sensetime/deformable-detr",
# See all Deformable DETR models at https://huggingface.co/models?filter=deformable-detr
]
class MultiScaleDeformableAttentionFunction(Function):
@staticmethod
def forward(
@@ -141,21 +161,6 @@ class MultiScaleDeformableAttentionFunction(Function):
return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
if is_scipy_available():
from scipy.optimize import linear_sum_assignment
if is_timm_available():
from timm import create_model
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "DeformableDetrConfig"
_CHECKPOINT_FOR_DOC = "sensetime/deformable-detr"
from ..deprecated._archive_maps import DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST # noqa: F401, E402
@dataclass
class DeformableDetrDecoderOutput(ModelOutput):
"""
@@ -420,17 +425,23 @@ class DeformableDetrConvEncoder(nn.Module):
self.config = config
# For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API
if config.use_timm_backbone:
# We default to values which were previously hard-coded. This enables configurability from the config
# using backbone arguments, while keeping the default behavior the same.
requires_backends(self, ["timm"])
kwargs = {}
kwargs = getattr(config, "backbone_kwargs", {})
kwargs = {} if kwargs is None else kwargs.copy()
out_indices = kwargs.pop("out_indices", (2, 3, 4) if config.num_feature_levels > 1 else (4,))
num_channels = kwargs.pop("in_chans", config.num_channels)
if config.dilation:
kwargs["output_stride"] = 16
kwargs["output_stride"] = kwargs.get("output_stride", 16)
backbone = create_model(
config.backbone,
pretrained=config.use_pretrained_backbone,
features_only=True,
out_indices=(2, 3, 4) if config.num_feature_levels > 1 else (4,),
in_chans=config.num_channels,
out_indices=out_indices,
in_chans=num_channels,
**kwargs,
)
else:
......
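Note the `kwargs.get("output_stride", 16)` line above: with `dilation` enabled, a caller-supplied `output_stride` in `backbone_kwargs` now takes precedence over the former hard-coded 16. A toy illustration of that precedence rule, using plain dicts rather than the real config:

```python
def resolve_output_stride(backbone_kwargs, dilation):
    kwargs = dict(backbone_kwargs or {})
    if dilation:
        # Fall back to 16 only when the caller didn't choose a value.
        kwargs["output_stride"] = kwargs.get("output_stride", 16)
    return kwargs

print(resolve_output_stride({}, dilation=True))                    # {'output_stride': 16}
print(resolve_output_stride({"output_stride": 8}, dilation=True))  # {'output_stride': 8}
print(resolve_output_stride({}, dilation=False))                   # {}
```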
@@ -193,7 +193,16 @@ class DetrConfig(PretrainedConfig):
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
if not use_timm_backbone:
# We default to values which were previously hard-coded in the model. This enables configurability of the config
# while keeping the default behavior the same.
if use_timm_backbone and backbone_kwargs is None:
backbone_kwargs = {}
if dilation:
backbone_kwargs["output_stride"] = 16
backbone_kwargs["out_indices"] = [1, 2, 3, 4]
backbone_kwargs["in_chans"] = num_channels
# Backwards compatibility
elif not use_timm_backbone and backbone in (None, "resnet50"):
if backbone_config is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
@@ -201,8 +210,9 @@ class DetrConfig(PretrainedConfig):
backbone_model_type = backbone_config.get("model_type")
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
backbone = None
# set timm attributes to None
dilation, backbone, use_pretrained_backbone = None, None, None
dilation = None
self.use_timm_backbone = use_timm_backbone
self.backbone_config = backbone_config
......
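When `use_timm_backbone` is `False`, the timm-only attributes are now cleared one by one instead of in a single tuple assignment, and `use_pretrained_backbone` is no longer nulled. A sketch of the observable result, assuming this patch is applied:

```python
from transformers import DetrConfig

# `backbone` defaults to "resnet50", so this takes the backwards-compatibility
# branch: a default ResNet backbone_config is created, then the timm-only
# attributes are cleared.
config = DetrConfig(use_timm_backbone=False)
print(config.backbone, config.dilation)   # -> None None
print(config.backbone_config.model_type)  # -> resnet
```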
@@ -49,9 +49,11 @@ if is_accelerate_available():
if is_scipy_available():
from scipy.optimize import linear_sum_assignment
if is_timm_available():
from timm import create_model
if is_vision_available():
from transformers.image_transforms import center_to_corners_format
@@ -345,17 +347,23 @@ class DetrConvEncoder(nn.Module):
self.config = config
# For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API
if config.use_timm_backbone:
# We default to values which were previously hard-coded. This enables configurability from the config
# using backbone arguments, while keeping the default behavior the same.
requires_backends(self, ["timm"])
kwargs = {}
kwargs = getattr(config, "backbone_kwargs", {})
kwargs = {} if kwargs is None else kwargs.copy()
out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
num_channels = kwargs.pop("in_chans", config.num_channels)
if config.dilation:
kwargs["output_stride"] = 16
kwargs["output_stride"] = kwargs.get("output_stride", 16)
backbone = create_model(
config.backbone,
pretrained=config.use_pretrained_backbone,
features_only=True,
out_indices=(1, 2, 3, 4),
in_chans=config.num_channels,
out_indices=out_indices,
in_chans=num_channels,
**kwargs,
)
else:
......
@@ -1075,10 +1075,10 @@ class DPTForDepthEstimation(DPTPreTrainedModel):
super().__init__(config)
self.backbone = None
if config.backbone_config is not None and config.is_hybrid is False:
self.backbone = load_backbone(config)
else:
if config.is_hybrid or config.backbone_config is None:
self.dpt = DPTModel(config, add_pooling_layer=False)
else:
self.backbone = load_backbone(config)
# Neck
self.neck = DPTNeck(config)
......
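The DPT change only inverts the branch condition: by De Morgan's laws, `not (backbone_config is not None and is_hybrid is False)` is exactly `is_hybrid or backbone_config is None`, so old and new code select the same branch in every case. A quick exhaustive check:

```python
from itertools import product

for has_backbone_config, is_hybrid in product([True, False], repeat=2):
    backbone_config = object() if has_backbone_config else None
    old_uses_backbone = backbone_config is not None and is_hybrid is False
    new_uses_dpt = is_hybrid or backbone_config is None
    assert old_uses_backbone == (not new_uses_dpt)
print("old and new branching agree on all four cases")
```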
@@ -193,7 +193,16 @@ class TableTransformerConfig(PretrainedConfig):
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
if not use_timm_backbone:
# We default to values which were previously hard-coded in the model. This enables configurability of the config
# while keeping the default behavior the same.
if use_timm_backbone and backbone_kwargs is None:
backbone_kwargs = {}
if dilation:
backbone_kwargs["output_stride"] = 16
backbone_kwargs["out_indices"] = [1, 2, 3, 4]
backbone_kwargs["in_chans"] = num_channels
# Backwards compatibility
elif not use_timm_backbone and backbone in (None, "resnet50"):
if backbone_config is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
@@ -201,8 +210,9 @@ class TableTransformerConfig(PretrainedConfig):
backbone_model_type = backbone_config.get("model_type")
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
backbone = None
# set timm attributes to None
dilation, backbone, use_pretrained_backbone = None, None, None
dilation = None
self.use_timm_backbone = use_timm_backbone
self.backbone_config = backbone_config
......
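The same `backbone_kwargs`/`backbone_config` mutual-exclusion guard is copied into every config in this family. A sketch of the failure mode it rejects, assuming this patch is applied (the exact keyword combination here is illustrative):

```python
from transformers import TableTransformerConfig

try:
    TableTransformerConfig(
        use_timm_backbone=False,
        backbone=None,
        use_pretrained_backbone=False,
        backbone_config={"model_type": "resnet"},
        backbone_kwargs={"out_indices": [1, 2, 3, 4]},
    )
except ValueError as err:
    print(err)  # -> You can't specify both `backbone_kwargs` and `backbone_config`.
```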
@@ -279,17 +279,23 @@ class TableTransformerConvEncoder(nn.Module):
self.config = config
# For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API
if config.use_timm_backbone:
# We default to values which were previously hard-coded. This enables configurability from the config
# using backbone arguments, while keeping the default behavior the same.
requires_backends(self, ["timm"])
kwargs = {}
kwargs = getattr(config, "backbone_kwargs", {})
kwargs = {} if kwargs is None else kwargs.copy()
out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
num_channels = kwargs.pop("in_chans", config.num_channels)
if config.dilation:
kwargs["output_stride"] = 16
kwargs["output_stride"] = kwargs.get("output_stride", 16)
backbone = create_model(
config.backbone,
pretrained=config.use_pretrained_backbone,
features_only=True,
out_indices=(1, 2, 3, 4),
in_chans=config.num_channels,
out_indices=out_indices,
in_chans=num_channels,
**kwargs,
)
else:
......
@@ -63,12 +63,13 @@ class TimmBackbone(PreTrainedModel, BackboneMixin):
# We just take the final layer by default. This matches the default for the transformers models.
out_indices = config.out_indices if getattr(config, "out_indices", None) is not None else (-1,)
in_chans = kwargs.pop("in_chans", config.num_channels)
self._backbone = timm.create_model(
config.backbone,
pretrained=pretrained,
# This is currently not possible for transformer architectures.
features_only=config.features_only,
in_chans=config.num_channels,
in_chans=in_chans,
out_indices=out_indices,
**kwargs,
)
@@ -79,7 +80,9 @@ class TimmBackbone(PreTrainedModel, BackboneMixin):
# These are used to control the output of the model when called. If output_hidden_states is True, then
# return_layers is modified to include all layers.
self._return_layers = self._backbone.return_layers
self._return_layers = {
layer["module"]: str(layer["index"]) for layer in self._backbone.feature_info.get_dicts()
}
self._all_layers = {layer["module"]: str(i) for i, layer in enumerate(self._backbone.feature_info.info)}
super()._init_backbone(config)
......
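`_return_layers` is now rebuilt from timm's public feature metadata instead of reaching into the private `return_layers` attribute of the wrapped model. A sketch of what that metadata looks like, assuming `timm` is installed and a version whose `feature_info` entries expose `module` and `index` keys, as the patched code relies on:

```python
import timm

backbone = timm.create_model("resnet18", pretrained=False, features_only=True)
for entry in backbone.feature_info.get_dicts():
    print(entry)  # e.g. {'num_chs': 64, 'reduction': 2, 'module': 'act1', 'index': 0}

return_layers = {
    # "index" is assumed present in supported timm versions; fall back to the
    # positional index otherwise.
    d["module"]: str(d.get("index", i))
    for i, d in enumerate(backbone.feature_info.get_dicts())
}
print(return_layers)  # maps module names such as 'layer4' to their feature index
```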
@@ -444,7 +444,9 @@ class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
# let's pick a random timm backbone
config.backbone = "tf_mobilenetv3_small_075"
config.backbone_config = None
config.use_timm_backbone = True
config.backbone_kwargs = {"out_indices": [2, 3, 4]}
for model_class in self.all_model_classes:
model = model_class(config)
@@ -460,6 +462,14 @@ class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
self.model_tester.num_labels,
)
self.assertEqual(outputs.logits.shape, expected_shape)
# Confirm out_indices was propagated to backbone
self.assertEqual(len(model.model.backbone.conv_encoder.intermediate_channel_sizes), 3)
elif model_class.__name__ == "ConditionalDetrForSegmentation":
# Confirm out_indices was propagated to backbone
self.assertEqual(len(model.conditional_detr.model.backbone.conv_encoder.intermediate_channel_sizes), 3)
else:
# Confirm out_indices was propagated to backbone
self.assertEqual(len(model.backbone.conv_encoder.intermediate_channel_sizes), 3)
self.assertTrue(outputs)
......
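The new test assertions reduce to one invariant: the conv encoder must report as many intermediate channel sizes as there are requested `out_indices`. A standalone version of that check, assuming this patch and an installed `timm` (no pretrained weights are downloaded):

```python
from transformers import ConditionalDetrConfig, ConditionalDetrModel

config = ConditionalDetrConfig(
    use_timm_backbone=True,
    backbone="tf_mobilenetv3_small_075",
    use_pretrained_backbone=False,
    backbone_config=None,
    backbone_kwargs={"out_indices": [2, 3, 4]},
)
model = ConditionalDetrModel(config)
# Three requested indices -> three intermediate feature maps from the encoder.
assert len(model.backbone.conv_encoder.intermediate_channel_sizes) == 3
```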
@@ -521,8 +521,9 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
# let's pick a random timm backbone
config.backbone = "tf_mobilenetv3_small_075"
config.use_timm_backbone = True
config.backbone_config = None
config.use_timm_backbone = True
config.backbone_kwargs = {"out_indices": [1, 2, 3, 4]}
for model_class in self.all_model_classes:
model = model_class(config)
@@ -538,6 +539,14 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
self.model_tester.num_labels,
)
self.assertEqual(outputs.logits.shape, expected_shape)
# Confirm out_indices was propagated to backbone
self.assertEqual(len(model.model.backbone.conv_encoder.intermediate_channel_sizes), 4)
elif model_class.__name__ == "ConditionalDetrForSegmentation":
# Confirm out_indices was propagated to backbone
self.assertEqual(len(model.deformable_detr.model.backbone.conv_encoder.intermediate_channel_sizes), 4)
else:
# Confirm out_indices was propagated to backbone
self.assertEqual(len(model.backbone.conv_encoder.intermediate_channel_sizes), 4)
self.assertTrue(outputs)
......
@@ -444,6 +444,9 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
# let's pick a random timm backbone
config.backbone = "tf_mobilenetv3_small_075"
config.backbone_config = None
config.use_timm_backbone = True
config.backbone_kwargs = {"out_indices": [2, 3, 4]}
for model_class in self.all_model_classes:
model = model_class(config)
@@ -459,6 +462,14 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
self.model_tester.num_labels + 1,
)
self.assertEqual(outputs.logits.shape, expected_shape)
# Confirm out_indices was propagated to backbone
self.assertEqual(len(model.model.backbone.conv_encoder.intermediate_channel_sizes), 3)
elif model_class.__name__ == "DetrForSegmentation":
# Confirm out_indices was propagated to backbone
self.assertEqual(len(model.detr.model.backbone.conv_encoder.intermediate_channel_sizes), 3)
else:
# Confirm out_indices was propagated to backbone
self.assertEqual(len(model.backbone.conv_encoder.intermediate_channel_sizes), 3)
self.assertTrue(outputs)
......
@@ -456,6 +456,9 @@ class TableTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, Pipelin
# let's pick a random timm backbone
config.backbone = "tf_mobilenetv3_small_075"
config.backbone_config = None
config.use_timm_backbone = True
config.backbone_kwargs = {"out_indices": [2, 3, 4]}
for model_class in self.all_model_classes:
model = model_class(config)
@@ -471,6 +474,11 @@ class TableTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, Pipelin
self.model_tester.num_labels + 1,
)
self.assertEqual(outputs.logits.shape, expected_shape)
# Confirm out_indices was propagated to backbone
self.assertEqual(len(model.model.backbone.conv_encoder.intermediate_channel_sizes), 3)
else:
# Confirm out_indices was propagated to backbone
self.assertEqual(len(model.backbone.conv_encoder.intermediate_channel_sizes), 3)
self.assertTrue(outputs)
......