Commit 91ff7efe (unverified), authored by NielsRogge, committed by GitHub
Browse files

[DETR and friends] Use AutoBackbone as alternative to timm (#20833)



* First draft

* More improvements

* Add conversion script

* More improvements

* Add docs

* Address review

* Rename class to ConvEncoder

* Address review

* Apply suggestion

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update all DETR friends

* Add corresponding test

* Improve test

* Fix bug

* Add more tests

* Set out_features to last stage by default
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MBP.localdomain>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent c8d719ff
...@@ -22,6 +22,7 @@ from packaging import version ...@@ -22,6 +22,7 @@ from packaging import version
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig from ...onnx import OnnxConfig
from ...utils import logging from ...utils import logging
from ..auto import CONFIG_MAPPING
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -44,6 +45,12 @@ class ConditionalDetrConfig(PretrainedConfig): ...@@ -44,6 +45,12 @@ class ConditionalDetrConfig(PretrainedConfig):
documentation from [`PretrainedConfig`] for more information. documentation from [`PretrainedConfig`] for more information.
Args: Args:
use_timm_backbone (`bool`, *optional*, defaults to `True`):
Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
API.
backbone_config (`PretrainedConfig` or `dict`, *optional*):
The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which
case it will default to `ResNetConfig()`.
num_channels (`int`, *optional*, defaults to 3): num_channels (`int`, *optional*, defaults to 3):
The number of input channels. The number of input channels.
num_queries (`int`, *optional*, defaults to 100): num_queries (`int`, *optional*, defaults to 100):
...@@ -87,13 +94,14 @@ class ConditionalDetrConfig(PretrainedConfig): ...@@ -87,13 +94,14 @@ class ConditionalDetrConfig(PretrainedConfig):
position_embedding_type (`str`, *optional*, defaults to `"sine"`): position_embedding_type (`str`, *optional*, defaults to `"sine"`):
Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`. Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`.
backbone (`str`, *optional*, defaults to `"resnet50"`): backbone (`str`, *optional*, defaults to `"resnet50"`):
Name of convolutional backbone to use. Supports any convolutional backbone from the timm package. For a Name of convolutional backbone to use in case `use_timm_backbone` = `True`. Supports any convolutional
list of all available models, see [this backbone from the timm package. For a list of all available models, see [this
page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model). page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model).
use_pretrained_backbone (`bool`, *optional*, defaults to `True`): use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
Whether to use pretrained weights for the backbone. Whether to use pretrained weights for the backbone. Only supported when `use_timm_backbone` = `True`.
dilation (`bool`, *optional*, defaults to `False`): dilation (`bool`, *optional*, defaults to `False`):
Whether to replace stride with dilation in the last convolutional block (DC5). Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
`use_timm_backbone` = `True`.
class_cost (`float`, *optional*, defaults to 1): class_cost (`float`, *optional*, defaults to 1):
Relative weight of the classification error in the Hungarian matching cost. Relative weight of the classification error in the Hungarian matching cost.
bbox_cost (`float`, *optional*, defaults to 5): bbox_cost (`float`, *optional*, defaults to 5):
...@@ -136,6 +144,8 @@ class ConditionalDetrConfig(PretrainedConfig): ...@@ -136,6 +144,8 @@ class ConditionalDetrConfig(PretrainedConfig):
def __init__( def __init__(
self, self,
use_timm_backbone=True,
backbone_config=None,
num_channels=3, num_channels=3,
num_queries=300, num_queries=300,
encoder_layers=6, encoder_layers=6,
...@@ -172,6 +182,20 @@ class ConditionalDetrConfig(PretrainedConfig): ...@@ -172,6 +182,20 @@ class ConditionalDetrConfig(PretrainedConfig):
focal_alpha=0.25, focal_alpha=0.25,
**kwargs **kwargs
): ):
if backbone_config is not None and use_timm_backbone:
raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
if not use_timm_backbone:
if backbone_config is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
elif isinstance(backbone_config, dict):
backbone_model_type = backbone_config.get("model_type")
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
self.use_timm_backbone = use_timm_backbone
self.backbone_config = backbone_config
self.num_channels = num_channels self.num_channels = num_channels
self.num_queries = num_queries self.num_queries = num_queries
self.d_model = d_model self.d_model = d_model
......
...@@ -38,6 +38,7 @@ from ...utils import ( ...@@ -38,6 +38,7 @@ from ...utils import (
replace_return_docstrings, replace_return_docstrings,
requires_backends, requires_backends,
) )
from ..auto import AutoBackbone
from .configuration_conditional_detr import ConditionalDetrConfig from .configuration_conditional_detr import ConditionalDetrConfig
...@@ -326,46 +327,57 @@ def replace_batch_norm(m, name=""): ...@@ -326,46 +327,57 @@ def replace_batch_norm(m, name=""):
replace_batch_norm(ch, n) replace_batch_norm(ch, n)
# Copied from transformers.models.detr.modeling_detr.DetrTimmConvEncoder # Copied from transformers.models.detr.modeling_detr.DetrConvEncoder
class ConditionalDetrTimmConvEncoder(nn.Module): class ConditionalDetrConvEncoder(nn.Module):
""" """
Convolutional encoder (backbone) from the timm library. Convolutional backbone, using either the AutoBackbone API or one from the timm library.
nn.BatchNorm2d layers are replaced by DetrFrozenBatchNorm2d as defined above. nn.BatchNorm2d layers are replaced by DetrFrozenBatchNorm2d as defined above.
""" """
def __init__(self, name: str, dilation: bool, use_pretrained_backbone: bool, num_channels: int = 3): def __init__(self, config):
super().__init__() super().__init__()
kwargs = {} self.config = config
if dilation:
kwargs["output_stride"] = 16
if config.use_timm_backbone:
requires_backends(self, ["timm"]) requires_backends(self, ["timm"])
kwargs = {}
if config.dilation:
kwargs["output_stride"] = 16
backbone = create_model( backbone = create_model(
name, config.backbone,
pretrained=use_pretrained_backbone, pretrained=config.use_pretrained_backbone,
features_only=True, features_only=True,
out_indices=(1, 2, 3, 4), out_indices=(1, 2, 3, 4),
in_chans=num_channels, in_chans=config.num_channels,
**kwargs, **kwargs,
) )
else:
backbone = AutoBackbone.from_config(config.backbone_config)
# replace batch norm by frozen batch norm # replace batch norm by frozen batch norm
with torch.no_grad(): with torch.no_grad():
replace_batch_norm(backbone) replace_batch_norm(backbone)
self.model = backbone self.model = backbone
self.intermediate_channel_sizes = self.model.feature_info.channels() self.intermediate_channel_sizes = (
self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
)
if "resnet" in name: backbone_model_type = config.backbone if config.use_timm_backbone else config.backbone_config.model_type
if "resnet" in backbone_model_type:
for name, parameter in self.model.named_parameters(): for name, parameter in self.model.named_parameters():
if config.use_timm_backbone:
if "layer2" not in name and "layer3" not in name and "layer4" not in name: if "layer2" not in name and "layer3" not in name and "layer4" not in name:
parameter.requires_grad_(False) parameter.requires_grad_(False)
else:
if "stage.1" not in name and "stage.2" not in name and "stage.3" not in name:
parameter.requires_grad_(False)
def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor): def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
# send pixel_values through the model to get list of feature maps # send pixel_values through the model to get list of feature maps
features = self.model(pixel_values) features = self.model(pixel_values) if self.config.use_timm_backbone else self.model(pixel_values).feature_maps
out = [] out = []
for feature_map in features: for feature_map in features:
...@@ -1468,9 +1480,7 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel): ...@@ -1468,9 +1480,7 @@ class ConditionalDetrModel(ConditionalDetrPreTrainedModel):
super().__init__(config) super().__init__(config)
# Create backbone + positional encoding # Create backbone + positional encoding
backbone = ConditionalDetrTimmConvEncoder( backbone = ConditionalDetrConvEncoder(config)
config.backbone, config.dilation, config.use_pretrained_backbone, config.num_channels
)
position_embeddings = build_position_encoding(config) position_embeddings = build_position_encoding(config)
self.backbone = ConditionalDetrConvModel(backbone, position_embeddings) self.backbone = ConditionalDetrConvModel(backbone, position_embeddings)
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...utils import logging
from ..auto import CONFIG_MAPPING
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -37,6 +38,14 @@ class DeformableDetrConfig(PretrainedConfig): ...@@ -37,6 +38,14 @@ class DeformableDetrConfig(PretrainedConfig):
documentation from [`PretrainedConfig`] for more information. documentation from [`PretrainedConfig`] for more information.
Args: Args:
use_timm_backbone (`bool`, *optional*, defaults to `True`):
Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
API.
backbone_config (`PretrainedConfig` or `dict`, *optional*):
The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which
case it will default to `ResNetConfig()`.
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
num_queries (`int`, *optional*, defaults to 300): num_queries (`int`, *optional*, defaults to 300):
Number of object queries, i.e. detection slots. This is the maximal number of objects Number of object queries, i.e. detection slots. This is the maximal number of objects
[`DeformableDetrModel`] can detect in a single image. In case `two_stage` is set to `True`, we use [`DeformableDetrModel`] can detect in a single image. In case `two_stage` is set to `True`, we use
...@@ -79,11 +88,14 @@ class DeformableDetrConfig(PretrainedConfig): ...@@ -79,11 +88,14 @@ class DeformableDetrConfig(PretrainedConfig):
position_embedding_type (`str`, *optional*, defaults to `"sine"`): position_embedding_type (`str`, *optional*, defaults to `"sine"`):
Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`. Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`.
backbone (`str`, *optional*, defaults to `"resnet50"`): backbone (`str`, *optional*, defaults to `"resnet50"`):
Name of convolutional backbone to use. Supports any convolutional backbone from the timm package. For a Name of convolutional backbone to use in case `use_timm_backbone` = `True`. Supports any convolutional
list of all available models, see [this backbone from the timm package. For a list of all available models, see [this
page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model). page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model).
use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
Whether to use pretrained weights for the backbone. Only supported when `use_timm_backbone` = `True`.
dilation (`bool`, *optional*, defaults to `False`): dilation (`bool`, *optional*, defaults to `False`):
Whether to replace stride with dilation in the last convolutional block (DC5). Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
`use_timm_backbone` = `True`.
class_cost (`float`, *optional*, defaults to 1): class_cost (`float`, *optional*, defaults to 1):
Relative weight of the classification error in the Hungarian matching cost. Relative weight of the classification error in the Hungarian matching cost.
bbox_cost (`float`, *optional*, defaults to 5): bbox_cost (`float`, *optional*, defaults to 5):
...@@ -139,6 +151,9 @@ class DeformableDetrConfig(PretrainedConfig): ...@@ -139,6 +151,9 @@ class DeformableDetrConfig(PretrainedConfig):
def __init__( def __init__(
self, self,
use_timm_backbone=True,
backbone_config=None,
num_channels=3,
num_queries=300, num_queries=300,
max_position_embeddings=1024, max_position_embeddings=1024,
encoder_layers=6, encoder_layers=6,
...@@ -161,6 +176,7 @@ class DeformableDetrConfig(PretrainedConfig): ...@@ -161,6 +176,7 @@ class DeformableDetrConfig(PretrainedConfig):
auxiliary_loss=False, auxiliary_loss=False,
position_embedding_type="sine", position_embedding_type="sine",
backbone="resnet50", backbone="resnet50",
use_pretrained_backbone=True,
dilation=False, dilation=False,
num_feature_levels=4, num_feature_levels=4,
encoder_n_points=4, encoder_n_points=4,
...@@ -179,6 +195,20 @@ class DeformableDetrConfig(PretrainedConfig): ...@@ -179,6 +195,20 @@ class DeformableDetrConfig(PretrainedConfig):
focal_alpha=0.25, focal_alpha=0.25,
**kwargs **kwargs
): ):
if backbone_config is not None and use_timm_backbone:
raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
if not use_timm_backbone:
if backbone_config is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
elif isinstance(backbone_config, dict):
backbone_model_type = backbone_config.get("model_type")
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
self.use_timm_backbone = use_timm_backbone
self.backbone_config = backbone_config
self.num_channels = num_channels
self.num_queries = num_queries self.num_queries = num_queries
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.d_model = d_model self.d_model = d_model
...@@ -199,6 +229,7 @@ class DeformableDetrConfig(PretrainedConfig): ...@@ -199,6 +229,7 @@ class DeformableDetrConfig(PretrainedConfig):
self.auxiliary_loss = auxiliary_loss self.auxiliary_loss = auxiliary_loss
self.position_embedding_type = position_embedding_type self.position_embedding_type = position_embedding_type
self.backbone = backbone self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.dilation = dilation self.dilation = dilation
# deformable attributes # deformable attributes
self.num_feature_levels = num_feature_levels self.num_feature_levels = num_feature_levels
......
...@@ -43,6 +43,7 @@ from ...modeling_outputs import BaseModelOutput ...@@ -43,6 +43,7 @@ from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import meshgrid from ...pytorch_utils import meshgrid
from ...utils import is_ninja_available, logging from ...utils import is_ninja_available, logging
from ..auto import AutoBackbone
from .configuration_deformable_detr import DeformableDetrConfig from .configuration_deformable_detr import DeformableDetrConfig
from .load_custom import load_cuda_kernels from .load_custom import load_cuda_kernels
...@@ -371,45 +372,57 @@ def replace_batch_norm(m, name=""): ...@@ -371,45 +372,57 @@ def replace_batch_norm(m, name=""):
replace_batch_norm(ch, n) replace_batch_norm(ch, n)
class DeformableDetrTimmConvEncoder(nn.Module): class DeformableDetrConvEncoder(nn.Module):
""" """
Convolutional encoder (backbone) from the timm library. Convolutional backbone, using either the AutoBackbone API or one from the timm library.
nn.BatchNorm2d layers are replaced by DeformableDetrFrozenBatchNorm2d as defined above. nn.BatchNorm2d layers are replaced by DeformableDetrFrozenBatchNorm2d as defined above.
""" """
def __init__(self, config): def __init__(self, config):
super().__init__() super().__init__()
self.config = config
if config.use_timm_backbone:
requires_backends(self, ["timm"])
kwargs = {} kwargs = {}
if config.dilation: if config.dilation:
kwargs["output_stride"] = 16 kwargs["output_stride"] = 16
requires_backends(self, ["timm"])
out_indices = (2, 3, 4) if config.num_feature_levels > 1 else (4,)
backbone = create_model( backbone = create_model(
config.backbone, pretrained=True, features_only=True, out_indices=out_indices, **kwargs config.backbone,
pretrained=config.use_pretrained_backbone,
features_only=True,
out_indices=(2, 3, 4) if config.num_feature_levels > 1 else (4,),
in_chans=config.num_channels,
**kwargs,
) )
else:
backbone = AutoBackbone.from_config(config.backbone_config)
# replace batch norm by frozen batch norm # replace batch norm by frozen batch norm
with torch.no_grad(): with torch.no_grad():
replace_batch_norm(backbone) replace_batch_norm(backbone)
self.model = backbone self.model = backbone
self.intermediate_channel_sizes = self.model.feature_info.channels() self.intermediate_channel_sizes = (
self.strides = self.model.feature_info.reduction() self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
)
if "resnet" in config.backbone: backbone_model_type = config.backbone if config.use_timm_backbone else config.backbone_config.model_type
if "resnet" in backbone_model_type:
for name, parameter in self.model.named_parameters(): for name, parameter in self.model.named_parameters():
if config.use_timm_backbone:
if "layer2" not in name and "layer3" not in name and "layer4" not in name: if "layer2" not in name and "layer3" not in name and "layer4" not in name:
parameter.requires_grad_(False) parameter.requires_grad_(False)
else:
if "stage.1" not in name and "stage.2" not in name and "stage.3" not in name:
parameter.requires_grad_(False)
# Copied from transformers.models.detr.modeling_detr.DetrConvEncoder.forward with Detr->DeformableDetr
def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor): def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
"""
Outputs feature maps of latter stages C_3 through C_5 in ResNet if `config.num_feature_levels > 1`, otherwise
outputs feature maps of C_5.
"""
# send pixel_values through the model to get list of feature maps # send pixel_values through the model to get list of feature maps
features = self.model(pixel_values) features = self.model(pixel_values) if self.config.use_timm_backbone else self.model(pixel_values).feature_maps
out = [] out = []
for feature_map in features: for feature_map in features:
...@@ -1438,13 +1451,13 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel): ...@@ -1438,13 +1451,13 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
super().__init__(config) super().__init__(config)
# Create backbone + positional encoding # Create backbone + positional encoding
backbone = DeformableDetrTimmConvEncoder(config) backbone = DeformableDetrConvEncoder(config)
position_embeddings = build_position_encoding(config) position_embeddings = build_position_encoding(config)
self.backbone = DeformableDetrConvModel(backbone, position_embeddings) self.backbone = DeformableDetrConvModel(backbone, position_embeddings)
# Create input projection layers # Create input projection layers
if config.num_feature_levels > 1: if config.num_feature_levels > 1:
num_backbone_outs = len(backbone.strides) num_backbone_outs = len(backbone.intermediate_channel_sizes)
input_proj_list = [] input_proj_list = []
for _ in range(num_backbone_outs): for _ in range(num_backbone_outs):
in_channels = backbone.intermediate_channel_sizes[_] in_channels = backbone.intermediate_channel_sizes[_]
......
...@@ -22,6 +22,7 @@ from packaging import version ...@@ -22,6 +22,7 @@ from packaging import version
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig from ...onnx import OnnxConfig
from ...utils import logging from ...utils import logging
from ..auto import CONFIG_MAPPING
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -43,6 +44,12 @@ class DetrConfig(PretrainedConfig): ...@@ -43,6 +44,12 @@ class DetrConfig(PretrainedConfig):
documentation from [`PretrainedConfig`] for more information. documentation from [`PretrainedConfig`] for more information.
Args: Args:
use_timm_backbone (`bool`, *optional*, defaults to `True`):
Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
API.
backbone_config (`PretrainedConfig` or `dict`, *optional*):
The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which
case it will default to `ResNetConfig()`.
num_channels (`int`, *optional*, defaults to 3): num_channels (`int`, *optional*, defaults to 3):
The number of input channels. The number of input channels.
num_queries (`int`, *optional*, defaults to 100): num_queries (`int`, *optional*, defaults to 100):
...@@ -86,13 +93,14 @@ class DetrConfig(PretrainedConfig): ...@@ -86,13 +93,14 @@ class DetrConfig(PretrainedConfig):
position_embedding_type (`str`, *optional*, defaults to `"sine"`): position_embedding_type (`str`, *optional*, defaults to `"sine"`):
Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`. Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`.
backbone (`str`, *optional*, defaults to `"resnet50"`): backbone (`str`, *optional*, defaults to `"resnet50"`):
Name of convolutional backbone to use. Supports any convolutional backbone from the timm package. For a Name of convolutional backbone to use in case `use_timm_backbone` = `True`. Supports any convolutional
list of all available models, see [this backbone from the timm package. For a list of all available models, see [this
page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model). page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model).
use_pretrained_backbone (`bool`, *optional*, defaults to `True`): use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
Whether to use pretrained weights for the backbone. Whether to use pretrained weights for the backbone. Only supported when `use_timm_backbone` = `True`.
dilation (`bool`, *optional*, defaults to `False`): dilation (`bool`, *optional*, defaults to `False`):
Whether to replace stride with dilation in the last convolutional block (DC5). Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
`use_timm_backbone` = `True`.
class_cost (`float`, *optional*, defaults to 1): class_cost (`float`, *optional*, defaults to 1):
Relative weight of the classification error in the Hungarian matching cost. Relative weight of the classification error in the Hungarian matching cost.
bbox_cost (`float`, *optional*, defaults to 5): bbox_cost (`float`, *optional*, defaults to 5):
...@@ -133,6 +141,8 @@ class DetrConfig(PretrainedConfig): ...@@ -133,6 +141,8 @@ class DetrConfig(PretrainedConfig):
def __init__( def __init__(
self, self,
use_timm_backbone=True,
backbone_config=None,
num_channels=3, num_channels=3,
num_queries=100, num_queries=100,
encoder_layers=6, encoder_layers=6,
...@@ -168,6 +178,20 @@ class DetrConfig(PretrainedConfig): ...@@ -168,6 +178,20 @@ class DetrConfig(PretrainedConfig):
eos_coefficient=0.1, eos_coefficient=0.1,
**kwargs **kwargs
): ):
if backbone_config is not None and use_timm_backbone:
raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
if not use_timm_backbone:
if backbone_config is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
elif isinstance(backbone_config, dict):
backbone_model_type = backbone_config.get("model_type")
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
self.use_timm_backbone = use_timm_backbone
self.backbone_config = backbone_config
self.num_channels = num_channels self.num_channels = num_channels
self.num_queries = num_queries self.num_queries = num_queries
self.d_model = d_model self.d_model = d_model
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Convert DETR checkpoints.""" """Convert DETR checkpoints with timm backbone."""
import argparse import argparse
......
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert DETR checkpoints with native (Transformers) backbone."""
import argparse
import json
from pathlib import Path
import torch
from PIL import Image
import requests
from huggingface_hub import hf_hub_download
from transformers import DetrConfig, DetrForObjectDetection, DetrForSegmentation, DetrImageProcessor, ResNetConfig
from transformers.utils import logging
# Switch the Transformers logging verbosity to info level for this script.
logging.set_verbosity_info()
# Module-level logger, named after this module.
logger = logging.get_logger(__name__)
def get_detr_config(model_name):
    """
    Build the `DetrConfig` (native Transformers backbone, no timm) for the checkpoint called `model_name`.

    Args:
        model_name (`str`):
            Name of the checkpoint. Must contain either `"resnet50"` or `"resnet101"`; a name containing
            `"panoptic"` is treated as a panoptic segmentation checkpoint.

    Returns:
        `tuple`: `(config, is_panoptic)` where `config` is the populated `DetrConfig` and `is_panoptic` is a
        `bool` telling whether `model_name` contains `"panoptic"`.

    Raises:
        ValueError: If `model_name` contains neither `"resnet50"` nor `"resnet101"`.
    """
    config = DetrConfig(use_timm_backbone=False)

    # set backbone attributes
    if "resnet50" in model_name:
        # keep the backbone config that `DetrConfig` creates by default
        pass
    elif "resnet101" in model_name:
        config.backbone_config = ResNetConfig.from_pretrained("microsoft/resnet-101")
    else:
        raise ValueError("Model name should include either resnet50 or resnet101")

    # set label attributes
    is_panoptic = "panoptic" in model_name
    if is_panoptic:
        config.num_labels = 250
    else:
        config.num_labels = 91
        # Label names come from the COCO *detection* mapping, so they are attached to detection checkpoints
        # only. NOTE(review): confirm this nesting against the upstream script.
        repo_id = "huggingface/label-files"
        filename = "coco-detection-id2label.json"
        label_file = hf_hub_download(repo_id, filename, repo_type="dataset")
        # read through a context manager so the file handle is closed deterministically
        with open(label_file, "r", encoding="utf-8") as f:
            id2label = json.load(f)
        # JSON object keys are strings; the config expects integer class ids
        id2label = {int(k): v for k, v in id2label.items()}
        config.id2label = id2label
        config.label2id = {v: k for k, v in id2label.items()}

    return config, is_panoptic
def create_rename_keys(config):
    """
    List every `(original_name, converted_name)` pair needed to rename the keys of an original DETR state dict.

    Covers the ResNet backbone (stem + every bottleneck layer of every stage), the non-attention-input weights
    of the transformer encoder/decoder layers, and the projection / query / prediction heads.

    Args:
        config:
            Object exposing `backbone_config.depths` (number of layers per backbone stage), `encoder_layers`
            and, optionally, `decoder_layers`. When `decoder_layers` is absent the encoder depth is reused.

    Returns:
        `list[tuple[str, str]]`: original name on the left, our name on the right.
    """
    rename_keys = []
    # parameters + buffers carried by every batch norm layer, in the order they are emitted
    norm_stats = ("weight", "bias", "running_mean", "running_var")

    # stem
    src_stem = "backbone.0.body"
    dst_stem = "backbone.conv_encoder.model.embedder.embedder"
    rename_keys.append((f"{src_stem}.conv1.weight", f"{dst_stem}.convolution.weight"))
    rename_keys.extend((f"{src_stem}.bn1.{stat}", f"{dst_stem}.normalization.{stat}") for stat in norm_stats)

    # stages
    for stage_idx, depth in enumerate(config.backbone_config.depths):
        for layer_idx in range(depth):
            src = f"backbone.0.body.layer{stage_idx + 1}.{layer_idx}"
            dst = f"backbone.conv_encoder.model.encoder.stages.{stage_idx}.layers.{layer_idx}"
            # shortcut: only the first layer of a stage gets downsample (conv + batch norm) keys
            if layer_idx == 0:
                rename_keys.append((f"{src}.downsample.0.weight", f"{dst}.shortcut.convolution.weight"))
                rename_keys.extend(
                    (f"{src}.downsample.1.{stat}", f"{dst}.shortcut.normalization.{stat}") for stat in norm_stats
                )
            # 3 convs, each followed by a batch norm
            for i in range(3):
                rename_keys.append((f"{src}.conv{i + 1}.weight", f"{dst}.layer.{i}.convolution.weight"))
                rename_keys.extend(
                    (f"{src}.bn{i + 1}.{stat}", f"{dst}.layer.{i}.normalization.{stat}") for stat in norm_stats
                )

    # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
    encoder_modules = (
        ("self_attn.out_proj", "self_attn.out_proj"),
        ("linear1", "fc1"),
        ("linear2", "fc2"),
        ("norm1", "self_attn_layer_norm"),
        ("norm2", "final_layer_norm"),
    )
    # decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms
    decoder_modules = (
        ("self_attn.out_proj", "self_attn.out_proj"),
        ("multihead_attn.out_proj", "encoder_attn.out_proj"),
        ("linear1", "fc1"),
        ("linear2", "fc2"),
        ("norm1", "self_attn_layer_norm"),
        ("norm2", "encoder_attn_layer_norm"),
        ("norm3", "final_layer_norm"),
    )

    num_encoder_layers = config.encoder_layers
    # The decoder depth is taken from its own setting so configs with encoder_layers != decoder_layers get
    # exactly the keys that exist; fall back to the encoder depth when the attribute is missing.
    num_decoder_layers = getattr(config, "decoder_layers", num_encoder_layers)

    # Encoder and decoder keys are interleaved per layer index, which keeps the output order identical to a
    # single shared loop whenever both depths are equal.
    for i in range(max(num_encoder_layers, num_decoder_layers)):
        if i < num_encoder_layers:
            for old, new in encoder_modules:
                for param in ("weight", "bias"):
                    rename_keys.append(
                        (f"transformer.encoder.layers.{i}.{old}.{param}", f"encoder.layers.{i}.{new}.{param}")
                    )
        if i < num_decoder_layers:
            for old, new in decoder_modules:
                for param in ("weight", "bias"):
                    rename_keys.append(
                        (f"transformer.decoder.layers.{i}.{old}.{param}", f"decoder.layers.{i}.{new}.{param}")
                    )

    # convolutional projection + query embeddings + layernorm of decoder + class and bounding box heads
    rename_keys.extend(
        [
            ("input_proj.weight", "input_projection.weight"),
            ("input_proj.bias", "input_projection.bias"),
            ("query_embed.weight", "query_position_embeddings.weight"),
            ("transformer.decoder.norm.weight", "decoder.layernorm.weight"),
            ("transformer.decoder.norm.bias", "decoder.layernorm.bias"),
            ("class_embed.weight", "class_labels_classifier.weight"),
            ("class_embed.bias", "class_labels_classifier.bias"),
            ("bbox_embed.layers.0.weight", "bbox_predictor.layers.0.weight"),
            ("bbox_embed.layers.0.bias", "bbox_predictor.layers.0.bias"),
            ("bbox_embed.layers.1.weight", "bbox_predictor.layers.1.weight"),
            ("bbox_embed.layers.1.bias", "bbox_predictor.layers.1.bias"),
            ("bbox_embed.layers.2.weight", "bbox_predictor.layers.2.weight"),
            ("bbox_embed.layers.2.bias", "bbox_predictor.layers.2.bias"),
        ]
    )

    return rename_keys
def rename_key(state_dict, old, new):
    """
    Move the value stored under `old` to the key `new`, modifying `state_dict` in place.

    Raises:
        KeyError: If `old` is not a key of `state_dict`.
    """
    state_dict[new] = state_dict.pop(old)
def _split_in_proj(state_dict, src_prefix, dest_prefix, hidden_size):
    """Pop the fused input projection stored under `src_prefix` and store separate q/k/v entries under `dest_prefix`."""
    in_proj_weight = state_dict.pop(f"{src_prefix}.in_proj_weight")
    in_proj_bias = state_dict.pop(f"{src_prefix}.in_proj_bias")
    # the fused matrix/bias is sliced along its first dimension: query first, then key, value last
    state_dict[f"{dest_prefix}.q_proj.weight"] = in_proj_weight[:hidden_size, :]
    state_dict[f"{dest_prefix}.q_proj.bias"] = in_proj_bias[:hidden_size]
    state_dict[f"{dest_prefix}.k_proj.weight"] = in_proj_weight[hidden_size : 2 * hidden_size, :]
    state_dict[f"{dest_prefix}.k_proj.bias"] = in_proj_bias[hidden_size : 2 * hidden_size]
    state_dict[f"{dest_prefix}.v_proj.weight"] = in_proj_weight[-hidden_size:, :]
    state_dict[f"{dest_prefix}.v_proj.bias"] = in_proj_bias[-hidden_size:]


def read_in_q_k_v(state_dict, is_panoptic=False, *, encoder_layers=6, decoder_layers=6, hidden_size=256):
    """
    Replace the fused attention input projections of the original checkpoint by separate q/k/v projections.

    For every encoder and decoder layer, the `in_proj_weight` / `in_proj_bias` entries are removed from
    `state_dict` and `q_proj`, `k_proj` and `v_proj` weight and bias entries are added instead. `state_dict` is
    modified in place and nothing is returned.

    Args:
        state_dict: Mapping from parameter names to tensors that support `[rows, :]` slicing.
        is_panoptic: Whether the original keys carry the `"detr."` prefix.
        encoder_layers: Number of encoder layers to process (defaults to the previously hard-coded 6).
        decoder_layers: Number of decoder layers to process (defaults to the previously hard-coded 6).
        hidden_size: Number of rows of each of the q/k/v slices (defaults to the previously hard-coded 256).

    Raises:
        KeyError: If an expected `in_proj_weight` / `in_proj_bias` key is missing from `state_dict`.
    """
    prefix = "detr." if is_panoptic else ""

    # first: transformer encoder
    for i in range(encoder_layers):
        # in PyTorch's MultiHeadAttention, the input projection is a single matrix + bias
        _split_in_proj(
            state_dict,
            f"{prefix}transformer.encoder.layers.{i}.self_attn",
            f"encoder.layers.{i}.self_attn",
            hidden_size,
        )

    # next: transformer decoder (which is a bit more complex because it also includes cross-attention)
    for i in range(decoder_layers):
        # input projection layer of self-attention
        _split_in_proj(
            state_dict,
            f"{prefix}transformer.decoder.layers.{i}.self_attn",
            f"decoder.layers.{i}.self_attn",
            hidden_size,
        )
        # input projection layer of cross-attention
        _split_in_proj(
            state_dict,
            f"{prefix}transformer.decoder.layers.{i}.multihead_attn",
            f"decoder.layers.{i}.encoder_attn",
            hidden_size,
        )
# We will verify our results on an image of cute cats
def prepare_img():
    """Fetch the COCO val2017 sample picture over HTTP and return it as opened by `Image.open`."""
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    # stream the response so that `Image.open` can read from the raw response body
    response = requests.get(url, stream=True)
    return Image.open(response.raw)
@torch.no_grad()
def convert_detr_checkpoint(model_name, pytorch_dump_folder_path):
    """
    Copy/paste/tweak model's weights to our DETR structure.

    Args:
        model_name (`str`):
            Name of the model to load from the `facebookresearch/detr` torch hub repository.
        pytorch_dump_folder_path (`str`, *optional*):
            Folder in which the converted model and its image processor are saved. Nothing is saved when `None`.

    Raises:
        AssertionError: If the logits, boxes or (for panoptic models) masks of the converted model differ from
            the outputs of the original model by more than the tolerances used below.
    """
    # load default config
    config, is_panoptic = get_detr_config(model_name)

    # load original model from torch hub
    logger.info(f"Converting model {model_name}...")
    detr = torch.hub.load("facebookresearch/detr", model_name, pretrained=True).eval()
    state_dict = detr.state_dict()

    # rename keys
    for src, dest in create_rename_keys(config):
        if is_panoptic:
            src = "detr." + src
        rename_key(state_dict, src, dest)
    # query, key and value matrices need special treatment
    read_in_q_k_v(state_dict, is_panoptic=is_panoptic)

    # important: we need to prepend a prefix to each of the base model keys as the head models use different
    # attributes for them
    prefix = "detr.model." if is_panoptic else "model."
    head_prefixes = ("class_labels_classifier", "bbox_predictor")
    # iterate over a snapshot of the keys, since entries are moved while looping
    for key in list(state_dict):
        if is_panoptic:
            if key.startswith("detr"):
                # keys that still carry the original prefix: "detr<rest>" becomes "detr.model<rest>"
                state_dict["detr.model" + key[4:]] = state_dict.pop(key)
            elif "class_labels_classifier" in key or "bbox_predictor" in key:
                state_dict["detr." + key] = state_dict.pop(key)
            elif key.startswith(("bbox_attention", "mask_head")):
                # these entries keep their current names
                continue
            else:
                state_dict[prefix + key] = state_dict.pop(key)
        elif not key.startswith(head_prefixes):
            state_dict[prefix + key] = state_dict.pop(key)

    # finally, create HuggingFace model and load state dict
    model = DetrForSegmentation(config) if is_panoptic else DetrForObjectDetection(config)
    model.load_state_dict(state_dict)
    model.eval()

    # verify our conversion on an image
    # (named `annotation_format` so that the builtin `format` is not shadowed)
    annotation_format = "coco_panoptic" if is_panoptic else "coco_detection"
    processor = DetrImageProcessor(format=annotation_format)

    encoding = processor(images=prepare_img(), return_tensors="pt")
    pixel_values = encoding["pixel_values"]

    original_outputs = detr(pixel_values)
    outputs = model(pixel_values)

    print("Logits:", outputs.logits[0, :3, :3])
    print("Original logits:", original_outputs["pred_logits"][0, :3, :3])

    assert torch.allclose(outputs.logits, original_outputs["pred_logits"], atol=1e-3)
    assert torch.allclose(outputs.pred_boxes, original_outputs["pred_boxes"], atol=1e-3)
    if is_panoptic:
        assert torch.allclose(outputs.pred_masks, original_outputs["pred_masks"], atol=1e-4)
    print("Looks ok!")

    if pytorch_dump_folder_path is not None:
        # Save model and image processor
        logger.info(f"Saving PyTorch model and image processor to {pytorch_dump_folder_path}...")
        # `parents=True` so that a nested output folder can be created in one go
        Path(pytorch_dump_folder_path).mkdir(parents=True, exist_ok=True)
        model.save_pretrained(pytorch_dump_folder_path)
        processor.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
    # command-line entry point: parse the two options and run the conversion
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--model_name",
        type=str,
        default="detr_resnet50",
        help="Name of the DETR model you'd like to convert.",
    )
    cli.add_argument(
        "--pytorch_dump_folder_path",
        type=str,
        default=None,
        help="Path to the folder to output PyTorch model.",
    )
    options = cli.parse_args()
    convert_detr_checkpoint(options.model_name, options.pytorch_dump_folder_path)
...@@ -38,6 +38,7 @@ from ...utils import ( ...@@ -38,6 +38,7 @@ from ...utils import (
replace_return_docstrings, replace_return_docstrings,
requires_backends, requires_backends,
) )
from ..auto import AutoBackbone
from .configuration_detr import DetrConfig from .configuration_detr import DetrConfig
...@@ -320,45 +321,56 @@ def replace_batch_norm(m, name=""): ...@@ -320,45 +321,56 @@ def replace_batch_norm(m, name=""):
replace_batch_norm(ch, n) replace_batch_norm(ch, n)
class DetrTimmConvEncoder(nn.Module): class DetrConvEncoder(nn.Module):
""" """
Convolutional encoder (backbone) from the timm library. Convolutional backbone, using either the AutoBackbone API or one from the timm library.
nn.BatchNorm2d layers are replaced by DetrFrozenBatchNorm2d as defined above. nn.BatchNorm2d layers are replaced by DetrFrozenBatchNorm2d as defined above.
""" """
def __init__(self, name: str, dilation: bool, use_pretrained_backbone: bool, num_channels: int = 3): def __init__(self, config):
super().__init__() super().__init__()
kwargs = {} self.config = config
if dilation:
kwargs["output_stride"] = 16
if config.use_timm_backbone:
requires_backends(self, ["timm"]) requires_backends(self, ["timm"])
kwargs = {}
if config.dilation:
kwargs["output_stride"] = 16
backbone = create_model( backbone = create_model(
name, config.backbone,
pretrained=use_pretrained_backbone, pretrained=config.use_pretrained_backbone,
features_only=True, features_only=True,
out_indices=(1, 2, 3, 4), out_indices=(1, 2, 3, 4),
in_chans=num_channels, in_chans=config.num_channels,
**kwargs, **kwargs,
) )
else:
backbone = AutoBackbone.from_config(config.backbone_config)
# replace batch norm by frozen batch norm # replace batch norm by frozen batch norm
with torch.no_grad(): with torch.no_grad():
replace_batch_norm(backbone) replace_batch_norm(backbone)
self.model = backbone self.model = backbone
self.intermediate_channel_sizes = self.model.feature_info.channels() self.intermediate_channel_sizes = (
self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
)
if "resnet" in name: backbone_model_type = config.backbone if config.use_timm_backbone else config.backbone_config.model_type
if "resnet" in backbone_model_type:
for name, parameter in self.model.named_parameters(): for name, parameter in self.model.named_parameters():
if config.use_timm_backbone:
if "layer2" not in name and "layer3" not in name and "layer4" not in name: if "layer2" not in name and "layer3" not in name and "layer4" not in name:
parameter.requires_grad_(False) parameter.requires_grad_(False)
else:
if "stage.1" not in name and "stage.2" not in name and "stage.3" not in name:
parameter.requires_grad_(False)
def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor): def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
# send pixel_values through the model to get list of feature maps # send pixel_values through the model to get list of feature maps
features = self.model(pixel_values) features = self.model(pixel_values) if self.config.use_timm_backbone else self.model(pixel_values).feature_maps
out = [] out = []
for feature_map in features: for feature_map in features:
...@@ -1191,9 +1203,7 @@ class DetrModel(DetrPreTrainedModel): ...@@ -1191,9 +1203,7 @@ class DetrModel(DetrPreTrainedModel):
super().__init__(config) super().__init__(config)
# Create backbone + positional encoding # Create backbone + positional encoding
backbone = DetrTimmConvEncoder( backbone = DetrConvEncoder(config)
config.backbone, config.dilation, config.use_pretrained_backbone, config.num_channels
)
position_embeddings = build_position_encoding(config) position_embeddings = build_position_encoding(config)
self.backbone = DetrConvModel(backbone, position_embeddings) self.backbone = DetrConvModel(backbone, position_embeddings)
......
...@@ -22,6 +22,7 @@ from packaging import version ...@@ -22,6 +22,7 @@ from packaging import version
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig from ...onnx import OnnxConfig
from ...utils import logging from ...utils import logging
from ..auto import CONFIG_MAPPING
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -44,6 +45,12 @@ class TableTransformerConfig(PretrainedConfig): ...@@ -44,6 +45,12 @@ class TableTransformerConfig(PretrainedConfig):
documentation from [`PretrainedConfig`] for more information. documentation from [`PretrainedConfig`] for more information.
Args: Args:
use_timm_backbone (`bool`, *optional*, defaults to `True`):
Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
API.
backbone_config (`PretrainedConfig` or `dict`, *optional*):
The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which
case it will default to `ResNetConfig()`.
num_channels (`int`, *optional*, defaults to 3): num_channels (`int`, *optional*, defaults to 3):
The number of input channels. The number of input channels.
num_queries (`int`, *optional*, defaults to 100): num_queries (`int`, *optional*, defaults to 100):
...@@ -87,13 +94,14 @@ class TableTransformerConfig(PretrainedConfig): ...@@ -87,13 +94,14 @@ class TableTransformerConfig(PretrainedConfig):
position_embedding_type (`str`, *optional*, defaults to `"sine"`): position_embedding_type (`str`, *optional*, defaults to `"sine"`):
Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`. Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`.
backbone (`str`, *optional*, defaults to `"resnet50"`): backbone (`str`, *optional*, defaults to `"resnet50"`):
Name of convolutional backbone to use. Supports any convolutional backbone from the timm package. For a Name of convolutional backbone to use in case `use_timm_backbone` = `True`. Supports any convolutional
list of all available models, see [this backbone from the timm package. For a list of all available models, see [this
page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model). page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model).
use_pretrained_backbone (`bool`, *optional*, defaults to `True`): use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
Whether to use pretrained weights for the backbone. Whether to use pretrained weights for the backbone. Only supported when `use_timm_backbone` = `True`.
dilation (`bool`, *optional*, defaults to `False`): dilation (`bool`, *optional*, defaults to `False`):
Whether to replace stride with dilation in the last convolutional block (DC5). Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
`use_timm_backbone` = `True`.
class_cost (`float`, *optional*, defaults to 1): class_cost (`float`, *optional*, defaults to 1):
Relative weight of the classification error in the Hungarian matching cost. Relative weight of the classification error in the Hungarian matching cost.
bbox_cost (`float`, *optional*, defaults to 5): bbox_cost (`float`, *optional*, defaults to 5):
...@@ -135,6 +143,8 @@ class TableTransformerConfig(PretrainedConfig): ...@@ -135,6 +143,8 @@ class TableTransformerConfig(PretrainedConfig):
# Copied from transformers.models.detr.configuration_detr.DetrConfig.__init__ # Copied from transformers.models.detr.configuration_detr.DetrConfig.__init__
def __init__( def __init__(
self, self,
use_timm_backbone=True,
backbone_config=None,
num_channels=3, num_channels=3,
num_queries=100, num_queries=100,
encoder_layers=6, encoder_layers=6,
...@@ -170,6 +180,20 @@ class TableTransformerConfig(PretrainedConfig): ...@@ -170,6 +180,20 @@ class TableTransformerConfig(PretrainedConfig):
eos_coefficient=0.1, eos_coefficient=0.1,
**kwargs **kwargs
): ):
if backbone_config is not None and use_timm_backbone:
raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
if not use_timm_backbone:
if backbone_config is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
elif isinstance(backbone_config, dict):
backbone_model_type = backbone_config.get("model_type")
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
self.use_timm_backbone = use_timm_backbone
self.backbone_config = backbone_config
self.num_channels = num_channels self.num_channels = num_channels
self.num_queries = num_queries self.num_queries = num_queries
self.d_model = d_model self.d_model = d_model
......
...@@ -38,6 +38,7 @@ from ...utils import ( ...@@ -38,6 +38,7 @@ from ...utils import (
replace_return_docstrings, replace_return_docstrings,
requires_backends, requires_backends,
) )
from ..auto import AutoBackbone
from .configuration_table_transformer import TableTransformerConfig from .configuration_table_transformer import TableTransformerConfig
...@@ -255,46 +256,57 @@ def replace_batch_norm(m, name=""): ...@@ -255,46 +256,57 @@ def replace_batch_norm(m, name=""):
replace_batch_norm(ch, n) replace_batch_norm(ch, n)
# Copied from transformers.models.detr.modeling_detr.DetrTimmConvEncoder with Detr->TableTransformer # Copied from transformers.models.detr.modeling_detr.DetrConvEncoder with Detr->TableTransformer
class TableTransformerTimmConvEncoder(nn.Module): class TableTransformerConvEncoder(nn.Module):
""" """
Convolutional encoder (backbone) from the timm library. Convolutional backbone, using either the AutoBackbone API or one from the timm library.
nn.BatchNorm2d layers are replaced by TableTransformerFrozenBatchNorm2d as defined above. nn.BatchNorm2d layers are replaced by TableTransformerFrozenBatchNorm2d as defined above.
""" """
def __init__(self, name: str, dilation: bool, use_pretrained_backbone: bool, num_channels: int = 3): def __init__(self, config):
super().__init__() super().__init__()
kwargs = {} self.config = config
if dilation:
kwargs["output_stride"] = 16
if config.use_timm_backbone:
requires_backends(self, ["timm"]) requires_backends(self, ["timm"])
kwargs = {}
if config.dilation:
kwargs["output_stride"] = 16
backbone = create_model( backbone = create_model(
name, config.backbone,
pretrained=use_pretrained_backbone, pretrained=config.use_pretrained_backbone,
features_only=True, features_only=True,
out_indices=(1, 2, 3, 4), out_indices=(1, 2, 3, 4),
in_chans=num_channels, in_chans=config.num_channels,
**kwargs, **kwargs,
) )
else:
backbone = AutoBackbone.from_config(config.backbone_config)
# replace batch norm by frozen batch norm # replace batch norm by frozen batch norm
with torch.no_grad(): with torch.no_grad():
replace_batch_norm(backbone) replace_batch_norm(backbone)
self.model = backbone self.model = backbone
self.intermediate_channel_sizes = self.model.feature_info.channels() self.intermediate_channel_sizes = (
self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
)
if "resnet" in name: backbone_model_type = config.backbone if config.use_timm_backbone else config.backbone_config.model_type
if "resnet" in backbone_model_type:
for name, parameter in self.model.named_parameters(): for name, parameter in self.model.named_parameters():
if config.use_timm_backbone:
if "layer2" not in name and "layer3" not in name and "layer4" not in name: if "layer2" not in name and "layer3" not in name and "layer4" not in name:
parameter.requires_grad_(False) parameter.requires_grad_(False)
else:
if "stage.1" not in name and "stage.2" not in name and "stage.3" not in name:
parameter.requires_grad_(False)
def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor): def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
# send pixel_values through the model to get list of feature maps # send pixel_values through the model to get list of feature maps
features = self.model(pixel_values) features = self.model(pixel_values) if self.config.use_timm_backbone else self.model(pixel_values).feature_maps
out = [] out = []
for feature_map in features: for feature_map in features:
...@@ -1136,9 +1148,7 @@ class TableTransformerModel(TableTransformerPreTrainedModel): ...@@ -1136,9 +1148,7 @@ class TableTransformerModel(TableTransformerPreTrainedModel):
super().__init__(config) super().__init__(config)
# Create backbone + positional encoding # Create backbone + positional encoding
backbone = TableTransformerTimmConvEncoder( backbone = TableTransformerConvEncoder(config)
config.backbone, config.dilation, config.use_pretrained_backbone, config.num_channels
)
position_embeddings = build_position_encoding(config) position_embeddings = build_position_encoding(config)
self.backbone = TableTransformerConvModel(backbone, position_embeddings) self.backbone = TableTransformerConvModel(backbone, position_embeddings)
......
...@@ -31,7 +31,12 @@ from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_ ...@@ -31,7 +31,12 @@ from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_
if is_timm_available(): if is_timm_available():
import torch import torch
from transformers import ConditionalDetrForObjectDetection, ConditionalDetrForSegmentation, ConditionalDetrModel from transformers import (
ConditionalDetrForObjectDetection,
ConditionalDetrForSegmentation,
ConditionalDetrModel,
ResNetConfig,
)
if is_vision_available(): if is_vision_available():
...@@ -153,6 +158,25 @@ class ConditionalDetrModelTester: ...@@ -153,6 +158,25 @@ class ConditionalDetrModelTester:
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels)) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4)) self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
def create_and_check_no_timm_backbone(self, config, pixel_values, pixel_mask, labels):
    """
    Check `ConditionalDetrForObjectDetection` when the backbone is built from `backbone_config` instead of timm.

    The passed `config` is modified in place: `use_timm_backbone` is set to `False` and `backbone_config` to a
    default `ResNetConfig()`. Output shapes are checked for a forward pass with a pixel mask, without a pixel
    mask, and with labels (in which case the loss must also be a scalar).
    """
    config.use_timm_backbone = False
    config.backbone_config = ResNetConfig()
    model = ConditionalDetrForObjectDetection(config=config)
    model.to(torch_device)
    model.eval()

    expected_logits_shape = (self.batch_size, self.num_queries, self.num_labels)
    expected_boxes_shape = (self.batch_size, self.num_queries, 4)

    # assert after each forward pass so that the output of the masked call does not go unchecked
    result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
    self.parent.assertEqual(result.logits.shape, expected_logits_shape)
    self.parent.assertEqual(result.pred_boxes.shape, expected_boxes_shape)

    result = model(pixel_values)
    self.parent.assertEqual(result.logits.shape, expected_logits_shape)
    self.parent.assertEqual(result.pred_boxes.shape, expected_boxes_shape)

    result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
    self.parent.assertEqual(result.loss.shape, ())
    self.parent.assertEqual(result.logits.shape, expected_logits_shape)
    self.parent.assertEqual(result.pred_boxes.shape, expected_boxes_shape)
@require_timm @require_timm
class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
...@@ -213,6 +237,10 @@ class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest ...@@ -213,6 +237,10 @@ class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_conditional_detr_object_detection_head_model(*config_and_inputs) self.model_tester.create_and_check_conditional_detr_object_detection_head_model(*config_and_inputs)
def test_conditional_detr_no_timm_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_no_timm_backbone(*config_and_inputs)
@unittest.skip(reason="Conditional DETR does not use inputs_embeds") @unittest.skip(reason="Conditional DETR does not use inputs_embeds")
def test_inputs_embeds(self): def test_inputs_embeds(self):
pass pass
......
...@@ -32,7 +32,7 @@ from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_ ...@@ -32,7 +32,7 @@ from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_
if is_timm_available(): if is_timm_available():
import torch import torch
from transformers import DeformableDetrForObjectDetection, DeformableDetrModel from transformers import DeformableDetrForObjectDetection, DeformableDetrModel, ResNetConfig
if is_vision_available(): if is_vision_available():
...@@ -164,6 +164,25 @@ class DeformableDetrModelTester: ...@@ -164,6 +164,25 @@ class DeformableDetrModelTester:
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels)) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels))
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4)) self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
def create_and_check_no_timm_backbone(self, config, pixel_values, pixel_mask, labels):
    """
    Check `DeformableDetrForObjectDetection` when the backbone is built from `backbone_config` instead of timm.

    The passed `config` is modified in place: `use_timm_backbone` is set to `False` and `backbone_config` to a
    default `ResNetConfig()`. Output shapes are checked for a forward pass with a pixel mask, without a pixel
    mask, and with labels (in which case the loss must also be a scalar).
    """
    config.use_timm_backbone = False
    config.backbone_config = ResNetConfig()
    model = DeformableDetrForObjectDetection(config=config)
    model.to(torch_device)
    model.eval()

    expected_logits_shape = (self.batch_size, self.num_queries, self.num_labels)
    expected_boxes_shape = (self.batch_size, self.num_queries, 4)

    # assert after each forward pass so that the output of the masked call does not go unchecked
    result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
    self.parent.assertEqual(result.logits.shape, expected_logits_shape)
    self.parent.assertEqual(result.pred_boxes.shape, expected_boxes_shape)

    result = model(pixel_values)
    self.parent.assertEqual(result.logits.shape, expected_logits_shape)
    self.parent.assertEqual(result.pred_boxes.shape, expected_boxes_shape)

    result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
    self.parent.assertEqual(result.loss.shape, ())
    self.parent.assertEqual(result.logits.shape, expected_logits_shape)
    self.parent.assertEqual(result.pred_boxes.shape, expected_boxes_shape)
@require_timm @require_timm
class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
...@@ -221,6 +240,10 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest. ...@@ -221,6 +240,10 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_deformable_detr_object_detection_head_model(*config_and_inputs) self.model_tester.create_and_check_deformable_detr_object_detection_head_model(*config_and_inputs)
def test_deformable_detr_no_timm_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_no_timm_backbone(*config_and_inputs)
@unittest.skip(reason="Deformable DETR does not use inputs_embeds") @unittest.skip(reason="Deformable DETR does not use inputs_embeds")
def test_inputs_embeds(self): def test_inputs_embeds(self):
pass pass
......
...@@ -31,7 +31,7 @@ from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_ ...@@ -31,7 +31,7 @@ from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_
if is_timm_available(): if is_timm_available():
import torch import torch
from transformers import DetrForObjectDetection, DetrForSegmentation, DetrModel from transformers import DetrForObjectDetection, DetrForSegmentation, DetrModel, ResNetConfig
if is_vision_available(): if is_vision_available():
...@@ -153,6 +153,25 @@ class DetrModelTester: ...@@ -153,6 +153,25 @@ class DetrModelTester:
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels + 1)) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels + 1))
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4)) self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
def create_and_check_no_timm_backbone(self, config, pixel_values, pixel_mask, labels):
    """
    Check `DetrForObjectDetection` when the backbone is built from `backbone_config` instead of timm.

    The passed `config` is modified in place: `use_timm_backbone` is set to `False` and `backbone_config` to a
    default `ResNetConfig()`. Output shapes are checked for a forward pass with a pixel mask, without a pixel
    mask, and with labels (in which case the loss must also be a scalar).
    """
    config.use_timm_backbone = False
    config.backbone_config = ResNetConfig()
    model = DetrForObjectDetection(config=config)
    model.to(torch_device)
    model.eval()

    expected_logits_shape = (self.batch_size, self.num_queries, self.num_labels + 1)
    expected_boxes_shape = (self.batch_size, self.num_queries, 4)

    # assert after each forward pass so that the output of the masked call does not go unchecked
    result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
    self.parent.assertEqual(result.logits.shape, expected_logits_shape)
    self.parent.assertEqual(result.pred_boxes.shape, expected_boxes_shape)

    result = model(pixel_values)
    self.parent.assertEqual(result.logits.shape, expected_logits_shape)
    self.parent.assertEqual(result.pred_boxes.shape, expected_boxes_shape)

    result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
    self.parent.assertEqual(result.loss.shape, ())
    self.parent.assertEqual(result.logits.shape, expected_logits_shape)
    self.parent.assertEqual(result.pred_boxes.shape, expected_boxes_shape)
@require_timm @require_timm
class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
...@@ -213,6 +232,10 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): ...@@ -213,6 +232,10 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_detr_object_detection_head_model(*config_and_inputs) self.model_tester.create_and_check_detr_object_detection_head_model(*config_and_inputs)
def test_detr_no_timm_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_no_timm_backbone(*config_and_inputs)
@unittest.skip(reason="DETR does not use inputs_embeds") @unittest.skip(reason="DETR does not use inputs_embeds")
def test_inputs_embeds(self): def test_inputs_embeds(self):
pass pass
......
...@@ -31,7 +31,7 @@ from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_ ...@@ -31,7 +31,7 @@ from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_
if is_timm_available(): if is_timm_available():
import torch import torch
from transformers import TableTransformerForObjectDetection, TableTransformerModel from transformers import ResNetConfig, TableTransformerForObjectDetection, TableTransformerModel
if is_vision_available(): if is_vision_available():
...@@ -153,6 +153,25 @@ class TableTransformerModelTester: ...@@ -153,6 +153,25 @@ class TableTransformerModelTester:
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels + 1)) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_queries, self.num_labels + 1))
self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4)) self.parent.assertEqual(result.pred_boxes.shape, (self.batch_size, self.num_queries, 4))
def create_and_check_table_transformer_no_timm_backbone(self, config, pixel_values, pixel_mask, labels):
    """
    Check `TableTransformerForObjectDetection` when the backbone is built from `backbone_config` instead of timm.

    The passed `config` is modified in place: `use_timm_backbone` is set to `False` and `backbone_config` to a
    default `ResNetConfig()`. Output shapes are checked for a forward pass with a pixel mask, without a pixel
    mask, and with labels (in which case the loss must also be a scalar).
    """
    config.use_timm_backbone = False
    config.backbone_config = ResNetConfig()
    model = TableTransformerForObjectDetection(config=config)
    model.to(torch_device)
    model.eval()

    expected_logits_shape = (self.batch_size, self.num_queries, self.num_labels + 1)
    expected_boxes_shape = (self.batch_size, self.num_queries, 4)

    # assert after each forward pass so that the output of the masked call does not go unchecked
    result = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
    self.parent.assertEqual(result.logits.shape, expected_logits_shape)
    self.parent.assertEqual(result.pred_boxes.shape, expected_boxes_shape)

    result = model(pixel_values)
    self.parent.assertEqual(result.logits.shape, expected_logits_shape)
    self.parent.assertEqual(result.pred_boxes.shape, expected_boxes_shape)

    result = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
    self.parent.assertEqual(result.loss.shape, ())
    self.parent.assertEqual(result.logits.shape, expected_logits_shape)
    self.parent.assertEqual(result.pred_boxes.shape, expected_boxes_shape)
@require_timm @require_timm
class TableTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class TableTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
...@@ -212,6 +231,10 @@ class TableTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, unittes ...@@ -212,6 +231,10 @@ class TableTransformerModelTest(ModelTesterMixin, GenerationTesterMixin, unittes
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_table_transformer_object_detection_head_model(*config_and_inputs) self.model_tester.create_and_check_table_transformer_object_detection_head_model(*config_and_inputs)
def test_table_transformer_no_timm_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_table_transformer_no_timm_backbone(*config_and_inputs)
@unittest.skip(reason="Table Transformer does not use inputs_embeds") @unittest.skip(reason="Table Transformer does not use inputs_embeds")
def test_inputs_embeds(self): def test_inputs_embeds(self):
pass pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment