Unverified Commit 2fa1c808 authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

[`Backbone`] Use `load_backbone` instead of `AutoBackbone.from_config` (#28661)

* Enable instantiating model with pretrained backbone weights

* Remove doc updates until changes made in modeling code

* Use load_backbone instead

* Add use_timm_backbone to the model configs

* Add missing imports and arguments

* Update docstrings

* Make sure test is properly configured

* Include recent DPT updates
parent c24c5245
......@@ -37,7 +37,7 @@ from ...utils import (
replace_return_docstrings,
requires_backends,
)
from ..auto import AutoBackbone
from ...utils.backbone_utils import load_backbone
from .configuration_conditional_detr import ConditionalDetrConfig
......@@ -363,7 +363,7 @@ class ConditionalDetrConvEncoder(nn.Module):
**kwargs,
)
else:
backbone = AutoBackbone.from_config(config.backbone_config)
backbone = load_backbone(config)
# replace batch norm by frozen batch norm
with torch.no_grad():
......
......@@ -44,7 +44,7 @@ from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import meshgrid
from ...utils import is_ninja_available, logging
from ..auto import AutoBackbone
from ...utils.backbone_utils import load_backbone
from .configuration_deformable_detr import DeformableDetrConfig
from .load_custom import load_cuda_kernels
......@@ -409,7 +409,7 @@ class DeformableDetrConvEncoder(nn.Module):
**kwargs,
)
else:
backbone = AutoBackbone.from_config(config.backbone_config)
backbone = load_backbone(config)
# replace batch norm by frozen batch norm
with torch.no_grad():
......
......@@ -46,6 +46,9 @@ class DetaConfig(PretrainedConfig):
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, `False`):
Whether to use pretrained weights for the backbone.
use_timm_backbone (`bool`, *optional*, `False`):
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
library.
num_queries (`int`, *optional*, defaults to 900):
Number of object queries, i.e. detection slots. This is the maximal number of objects [`DetaModel`] can
detect in a single image. In case `two_stage` is set to `True`, we use `two_stage_num_proposals` instead.
......@@ -146,6 +149,7 @@ class DetaConfig(PretrainedConfig):
backbone_config=None,
backbone=None,
use_pretrained_backbone=False,
use_timm_backbone=False,
num_queries=900,
max_position_embeddings=2048,
encoder_layers=6,
......@@ -203,6 +207,7 @@ class DetaConfig(PretrainedConfig):
self.backbone_config = backbone_config
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
self.num_queries = num_queries
self.max_position_embeddings = max_position_embeddings
self.d_model = d_model
......
......@@ -39,7 +39,7 @@ from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import meshgrid
from ...utils import is_accelerate_available, is_torchvision_available, logging, requires_backends
from ..auto import AutoBackbone
from ...utils.backbone_utils import load_backbone
from .configuration_deta import DetaConfig
......@@ -338,7 +338,7 @@ class DetaBackboneWithPositionalEncodings(nn.Module):
def __init__(self, config):
super().__init__()
backbone = AutoBackbone.from_config(config.backbone_config)
backbone = load_backbone(config)
with torch.no_grad():
replace_batch_norm(backbone)
self.model = backbone
......
......@@ -37,7 +37,7 @@ from ...utils import (
replace_return_docstrings,
requires_backends,
)
from ..auto import AutoBackbone
from ...utils.backbone_utils import load_backbone
from .configuration_detr import DetrConfig
......@@ -356,7 +356,7 @@ class DetrConvEncoder(nn.Module):
**kwargs,
)
else:
backbone = AutoBackbone.from_config(config.backbone_config)
backbone = load_backbone(config)
# replace batch norm by frozen batch norm
with torch.no_grad():
......
......@@ -117,6 +117,9 @@ class DPTConfig(PretrainedConfig):
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use pretrained weights for the backbone.
use_timm_backbone (`bool`, *optional*, defaults to `False`):
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
library.
Example:
......@@ -169,6 +172,7 @@ class DPTConfig(PretrainedConfig):
backbone_config=None,
backbone=None,
use_pretrained_backbone=False,
use_timm_backbone=False,
**kwargs,
):
super().__init__(**kwargs)
......@@ -179,9 +183,6 @@ class DPTConfig(PretrainedConfig):
if use_pretrained_backbone:
raise ValueError("Pretrained backbones are not supported yet.")
if backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
use_autobackbone = False
if self.is_hybrid:
if backbone_config is None and backbone is None:
......@@ -193,17 +194,17 @@ class DPTConfig(PretrainedConfig):
"out_features": ["stage1", "stage2", "stage3"],
"embedding_dynamic_padding": True,
}
self.backbone_config = BitConfig(**backbone_config)
backbone_config = BitConfig(**backbone_config)
elif isinstance(backbone_config, dict):
logger.info("Initializing the config with a `BiT` backbone.")
self.backbone_config = BitConfig(**backbone_config)
backbone_config = BitConfig(**backbone_config)
elif isinstance(backbone_config, PretrainedConfig):
self.backbone_config = backbone_config
backbone_config = backbone_config
else:
raise ValueError(
f"backbone_config must be a dictionary or a `PretrainedConfig`, got {backbone_config.__class__}."
)
self.backbone_config = backbone_config
self.backbone_featmap_shape = backbone_featmap_shape
self.neck_ignore_stages = neck_ignore_stages
......@@ -221,14 +222,17 @@ class DPTConfig(PretrainedConfig):
self.backbone_config = backbone_config
self.backbone_featmap_shape = None
self.neck_ignore_stages = []
else:
self.backbone_config = backbone_config
self.backbone_featmap_shape = None
self.neck_ignore_stages = []
if use_autobackbone and backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
self.num_hidden_layers = None if use_autobackbone else num_hidden_layers
self.num_attention_heads = None if use_autobackbone else num_attention_heads
self.intermediate_size = None if use_autobackbone else intermediate_size
......
......@@ -41,7 +41,7 @@ from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticS
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ModelOutput, logging
from ..auto import AutoBackbone
from ...utils.backbone_utils import load_backbone
from .configuration_dpt import DPTConfig
......@@ -131,12 +131,10 @@ class DPTViTHybridEmbeddings(nn.Module):
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.backbone = AutoBackbone.from_config(config.backbone_config)
self.backbone = load_backbone(config)
feature_dim = self.backbone.channels[-1]
if len(config.backbone_config.out_features) != 3:
raise ValueError(
f"Expected backbone to have 3 output features, got {len(config.backbone_config.out_features)}"
)
if len(self.backbone.channels) != 3:
raise ValueError(f"Expected backbone to have 3 output features, got {len(self.backbone.channels)}")
self.residual_feature_map_index = [0, 1] # Always take the output of the first and second backbone stage
if feature_size is None:
......@@ -1082,7 +1080,7 @@ class DPTForDepthEstimation(DPTPreTrainedModel):
self.backbone = None
if config.backbone_config is not None and config.is_hybrid is False:
self.backbone = AutoBackbone.from_config(config.backbone_config)
self.backbone = load_backbone(config)
else:
self.dpt = DPTModel(config, add_pooling_layer=False)
......
......@@ -53,6 +53,9 @@ class Mask2FormerConfig(PretrainedConfig):
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, `False`):
Whether to use pretrained weights for the backbone.
use_timm_backbone (`bool`, *optional*, `False`):
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
library.
feature_size (`int`, *optional*, defaults to 256):
The features (channels) of the resulting feature maps.
mask_feature_size (`int`, *optional*, defaults to 256):
......@@ -162,6 +165,7 @@ class Mask2FormerConfig(PretrainedConfig):
output_auxiliary_logits: bool = None,
backbone=None,
use_pretrained_backbone=False,
use_timm_backbone=False,
**kwargs,
):
if use_pretrained_backbone:
......@@ -228,6 +232,7 @@ class Mask2FormerConfig(PretrainedConfig):
self.num_hidden_layers = decoder_layers
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
super().__init__(**kwargs)
......
......@@ -23,7 +23,6 @@ import numpy as np
import torch
from torch import Tensor, nn
from ... import AutoBackbone
from ...activations import ACT2FN
from ...file_utils import (
ModelOutput,
......@@ -36,6 +35,7 @@ from ...file_utils import (
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from ...utils.backbone_utils import load_backbone
from .configuration_mask2former import Mask2FormerConfig
......@@ -1376,7 +1376,7 @@ class Mask2FormerPixelLevelModule(nn.Module):
"""
super().__init__()
self.encoder = AutoBackbone.from_config(config.backbone_config)
self.encoder = load_backbone(config)
self.decoder = Mask2FormerPixelDecoder(config, feature_channels=self.encoder.channels)
def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> Mask2FormerPixelLevelModuleOutput:
......
......@@ -63,6 +63,9 @@ class MaskFormerConfig(PretrainedConfig):
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, `False`):
Whether to use pretrained weights for the backbone.
use_timm_backbone (`bool`, *optional*, `False`):
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
library.
decoder_config (`Dict`, *optional*):
The configuration passed to the transformer decoder model, if unset the base config for `detr-resnet-50`
will be used.
......@@ -122,6 +125,7 @@ class MaskFormerConfig(PretrainedConfig):
output_auxiliary_logits: Optional[bool] = None,
backbone: Optional[str] = None,
use_pretrained_backbone: bool = False,
use_timm_backbone: bool = False,
**kwargs,
):
if use_pretrained_backbone:
......@@ -193,6 +197,7 @@ class MaskFormerConfig(PretrainedConfig):
self.num_hidden_layers = self.decoder_config.num_hidden_layers
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
super().__init__(**kwargs)
@classmethod
......
......@@ -23,7 +23,6 @@ import numpy as np
import torch
from torch import Tensor, nn
from ... import AutoBackbone
from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import BaseModelOutputWithCrossAttentions
......@@ -37,6 +36,7 @@ from ...utils import (
replace_return_docstrings,
requires_backends,
)
from ...utils.backbone_utils import load_backbone
from ..detr import DetrConfig
from .configuration_maskformer import MaskFormerConfig
from .configuration_maskformer_swin import MaskFormerSwinConfig
......@@ -1428,14 +1428,13 @@ class MaskFormerPixelLevelModule(nn.Module):
The configuration used to instantiate this model.
"""
super().__init__()
# TODD: add method to load pretrained weights of backbone
backbone_config = config.backbone_config
if backbone_config.model_type == "swin":
if hasattr(config, "backbone_config") and config.backbone_config.model_type == "swin":
# for backwards compatibility
backbone_config = config.backbone_config
backbone_config = MaskFormerSwinConfig.from_dict(backbone_config.to_dict())
backbone_config.out_features = ["stage1", "stage2", "stage3", "stage4"]
self.encoder = AutoBackbone.from_config(backbone_config)
config.backbone_config = backbone_config
self.encoder = load_backbone(config)
feature_channels = self.encoder.channels
self.decoder = MaskFormerPixelDecoder(
......
......@@ -50,6 +50,9 @@ class OneFormerConfig(PretrainedConfig):
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use pretrained weights for the backbone.
use_timm_backbone (`bool`, *optional*, defaults to `False`):
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
library.
ignore_value (`int`, *optional*, defaults to 255):
Values to be ignored in GT label while calculating loss.
num_queries (`int`, *optional*, defaults to 150):
......@@ -152,6 +155,7 @@ class OneFormerConfig(PretrainedConfig):
backbone_config: Optional[Dict] = None,
backbone: Optional[str] = None,
use_pretrained_backbone: bool = False,
use_timm_backbone: bool = False,
ignore_value: int = 255,
num_queries: int = 150,
no_object_weight: int = 0.1,
......@@ -222,6 +226,7 @@ class OneFormerConfig(PretrainedConfig):
self.backbone_config = backbone_config
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
self.ignore_value = ignore_value
self.num_queries = num_queries
self.no_object_weight = no_object_weight
......
......@@ -24,7 +24,6 @@ import torch
from torch import Tensor, nn
from torch.cuda.amp import autocast
from ... import AutoBackbone
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
......@@ -37,6 +36,7 @@ from ...utils import (
replace_return_docstrings,
requires_backends,
)
from ...utils.backbone_utils import load_backbone
from .configuration_oneformer import OneFormerConfig
......@@ -1478,8 +1478,7 @@ class OneFormerPixelLevelModule(nn.Module):
The configuration used to instantiate this model.
"""
super().__init__()
backbone_config = config.backbone_config
self.encoder = AutoBackbone.from_config(backbone_config)
self.encoder = load_backbone(config)
self.decoder = OneFormerPixelDecoder(config, feature_channels=self.encoder.channels)
def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> OneFormerPixelLevelModuleOutput:
......
......@@ -37,7 +37,7 @@ from ...utils import (
replace_return_docstrings,
requires_backends,
)
from ..auto import AutoBackbone
from ...utils.backbone_utils import load_backbone
from .configuration_table_transformer import TableTransformerConfig
......@@ -290,7 +290,7 @@ class TableTransformerConvEncoder(nn.Module):
**kwargs,
)
else:
backbone = AutoBackbone.from_config(config.backbone_config)
backbone = load_backbone(config)
# replace batch norm by frozen batch norm
with torch.no_grad():
......
......@@ -49,6 +49,9 @@ class TvpConfig(PretrainedConfig):
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use pretrained weights for the backbone.
use_timm_backbone (`bool`, *optional*, defaults to `False`):
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
library.
distance_loss_weight (`float`, *optional*, defaults to 1.0):
The weight of distance loss.
duration_loss_weight (`float`, *optional*, defaults to 0.1):
......@@ -103,6 +106,7 @@ class TvpConfig(PretrainedConfig):
backbone_config=None,
backbone=None,
use_pretrained_backbone=False,
use_timm_backbone=False,
distance_loss_weight=1.0,
duration_loss_weight=0.1,
visual_prompter_type="framepad",
......@@ -143,6 +147,7 @@ class TvpConfig(PretrainedConfig):
self.backbone_config = backbone_config
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
self.distance_loss_weight = distance_loss_weight
self.duration_loss_weight = duration_loss_weight
self.visual_prompter_type = visual_prompter_type
......
......@@ -28,7 +28,7 @@ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, Mod
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import prune_linear_layer
from ...utils import logging
from ..auto import AutoBackbone
from ...utils.backbone_utils import load_backbone
from .configuration_tvp import TvpConfig
......@@ -148,7 +148,7 @@ class TvpLoss(nn.Module):
class TvpVisionModel(nn.Module):
def __init__(self, config):
super().__init__()
self.backbone = AutoBackbone.from_config(config.backbone_config)
self.backbone = load_backbone(config)
self.grid_encoder_conv = nn.Conv2d(
config.backbone_config.hidden_sizes[-1],
config.hidden_size,
......
......@@ -42,6 +42,9 @@ class UperNetConfig(PretrainedConfig):
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, `False`):
Whether to use pretrained weights for the backbone.
use_timm_backbone (`bool`, *optional*, `False`):
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
library.
hidden_size (`int`, *optional*, defaults to 512):
The number of hidden units in the convolutional layers.
initializer_range (`float`, *optional*, defaults to 0.02):
......@@ -83,6 +86,7 @@ class UperNetConfig(PretrainedConfig):
backbone_config=None,
backbone=None,
use_pretrained_backbone=False,
use_timm_backbone=False,
hidden_size=512,
initializer_range=0.02,
pool_scales=[1, 2, 3, 6],
......@@ -113,6 +117,7 @@ class UperNetConfig(PretrainedConfig):
self.backbone_config = backbone_config
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
self.hidden_size = hidden_size
self.initializer_range = initializer_range
self.pool_scales = pool_scales
......
......@@ -20,10 +20,10 @@ import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from ... import AutoBackbone
from ...modeling_outputs import SemanticSegmenterOutput
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...utils.backbone_utils import load_backbone
from .configuration_upernet import UperNetConfig
......@@ -348,7 +348,7 @@ class UperNetForSemanticSegmentation(UperNetPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.backbone = AutoBackbone.from_config(config.backbone_config)
self.backbone = load_backbone(config)
# Semantic segmentation head(s)
self.decode_head = UperNetHead(config, in_channels=self.backbone.channels)
......
......@@ -48,6 +48,9 @@ class ViTHybridConfig(PretrainedConfig):
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use pretrained weights for the backbone.
use_timm_backbone (`bool`, *optional*, defaults to `False`):
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
library.
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (`int`, *optional*, defaults to 12):
......@@ -100,6 +103,7 @@ class ViTHybridConfig(PretrainedConfig):
backbone_config=None,
backbone=None,
use_pretrained_backbone=False,
use_timm_backbone=False,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
......@@ -147,6 +151,7 @@ class ViTHybridConfig(PretrainedConfig):
self.backbone_config = backbone_config
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
......
......@@ -29,7 +29,7 @@ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, Ima
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from ..auto import AutoBackbone
from ...utils.backbone_utils import load_backbone
from .configuration_vit_hybrid import ViTHybridConfig
......@@ -150,7 +150,7 @@ class ViTHybridPatchEmbeddings(nn.Module):
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
self.backbone = AutoBackbone.from_config(config.backbone_config)
self.backbone = load_backbone(config)
if self.backbone.config.model_type != "bit":
raise ValueError(f"Backbone model type {self.backbone.model_type} is not supported.")
feature_dim = self.backbone.channels[-1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment