Unverified Commit 2fa1c808 authored by amyeroberts, committed by GitHub

[`Backbone`] Use `load_backbone` instead of `AutoBackbone.from_config` (#28661)

* Enable instantiating model with pretrained backbone weights

* Remove doc updates until changes made in modeling code

* Use load_backbone instead

* Add use_timm_backbone to the model configs

* Add missing imports and arguments

* Update docstrings

* Make sure test is properly configured

* Include recent DPT updates
parent c24c5245
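
In short: every per-model call to `AutoBackbone.from_config(config.backbone_config)` is replaced by the shared `load_backbone(config)` helper, and the configs that did not already carry it gain a `use_timm_backbone` flag so the helper can tell timm and transformers backbones apart. As orientation for the hunks below, here is a minimal sketch of the dispatch `load_backbone` performs, assuming a config exposing `backbone`, `backbone_config`, `use_timm_backbone`, and `use_pretrained_backbone` (paraphrased; the real helper in `transformers.utils.backbone_utils` does more validation):

```python
# Hedged sketch of load_backbone's dispatch, not the verbatim implementation.
from transformers import AutoBackbone


def load_backbone_sketch(config):
    if getattr(config, "use_timm_backbone", False):
        # config.backbone names a timm architecture, e.g. "resnet50"; the
        # kwargs route the request to the TimmBackbone wrapper.
        return AutoBackbone.from_pretrained(
            config.backbone,
            use_timm_backbone=True,
            use_pretrained_backbone=config.use_pretrained_backbone,
        )
    if getattr(config, "use_pretrained_backbone", False):
        # config.backbone is a transformers checkpoint id, e.g. "microsoft/resnet-50".
        return AutoBackbone.from_pretrained(config.backbone)
    # Fall back to a randomly initialized backbone built from the nested
    # config, matching the AutoBackbone.from_config calls removed below.
    return AutoBackbone.from_config(config.backbone_config)
```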
src/transformers/models/conditional_detr/modeling_conditional_detr.py
@@ -37,7 +37,7 @@ from ...utils import (
     replace_return_docstrings,
     requires_backends,
 )
-from ..auto import AutoBackbone
+from ...utils.backbone_utils import load_backbone
 from .configuration_conditional_detr import ConditionalDetrConfig
@@ -363,7 +363,7 @@ class ConditionalDetrConvEncoder(nn.Module):
                 **kwargs,
             )
         else:
-            backbone = AutoBackbone.from_config(config.backbone_config)
+            backbone = load_backbone(config)

         # replace batch norm by frozen batch norm
         with torch.no_grad():
src/transformers/models/deformable_detr/modeling_deformable_detr.py
@@ -44,7 +44,7 @@ from ...modeling_outputs import BaseModelOutput
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import meshgrid
 from ...utils import is_ninja_available, logging
-from ..auto import AutoBackbone
+from ...utils.backbone_utils import load_backbone
 from .configuration_deformable_detr import DeformableDetrConfig
 from .load_custom import load_cuda_kernels
@@ -409,7 +409,7 @@ class DeformableDetrConvEncoder(nn.Module):
                 **kwargs,
             )
         else:
-            backbone = AutoBackbone.from_config(config.backbone_config)
+            backbone = load_backbone(config)

         # replace batch norm by frozen batch norm
         with torch.no_grad():
src/transformers/models/deta/configuration_deta.py
@@ -46,6 +46,9 @@ class DetaConfig(PretrainedConfig):
             is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
         use_pretrained_backbone (`bool`, *optional*, `False`):
             Whether to use pretrained weights for the backbone.
+        use_timm_backbone (`bool`, *optional*, `False`):
+            Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+            library.
         num_queries (`int`, *optional*, defaults to 900):
             Number of object queries, i.e. detection slots. This is the maximal number of objects [`DetaModel`] can
             detect in a single image. In case `two_stage` is set to `True`, we use `two_stage_num_proposals` instead.
@@ -146,6 +149,7 @@ class DetaConfig(PretrainedConfig):
         backbone_config=None,
         backbone=None,
         use_pretrained_backbone=False,
+        use_timm_backbone=False,
         num_queries=900,
         max_position_embeddings=2048,
         encoder_layers=6,
@@ -203,6 +207,7 @@ class DetaConfig(PretrainedConfig):
         self.backbone_config = backbone_config
         self.backbone = backbone
         self.use_pretrained_backbone = use_pretrained_backbone
+        self.use_timm_backbone = use_timm_backbone
         self.num_queries = num_queries
         self.max_position_embeddings = max_position_embeddings
         self.d_model = d_model
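
On the config side the change is purely additive: the new argument is stored as-is. A small usage sketch (the timm name `"resnet50"` is illustrative; `use_pretrained_backbone` stays `False` because these configs still reject pretrained backbone weights):

```python
from transformers import DetaConfig

config = DetaConfig(
    backbone="resnet50",  # illustrative timm architecture name
    use_timm_backbone=True,
    use_pretrained_backbone=False,
)
print(config.backbone, config.use_timm_backbone)  # resnet50 True
```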
src/transformers/models/deta/modeling_deta.py
@@ -39,7 +39,7 @@ from ...modeling_outputs import BaseModelOutput
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import meshgrid
 from ...utils import is_accelerate_available, is_torchvision_available, logging, requires_backends
-from ..auto import AutoBackbone
+from ...utils.backbone_utils import load_backbone
 from .configuration_deta import DetaConfig
@@ -338,7 +338,7 @@ class DetaBackboneWithPositionalEncodings(nn.Module):
     def __init__(self, config):
         super().__init__()

-        backbone = AutoBackbone.from_config(config.backbone_config)
+        backbone = load_backbone(config)
         with torch.no_grad():
             replace_batch_norm(backbone)
         self.model = backbone
src/transformers/models/detr/modeling_detr.py
@@ -37,7 +37,7 @@ from ...utils import (
     replace_return_docstrings,
     requires_backends,
 )
-from ..auto import AutoBackbone
+from ...utils.backbone_utils import load_backbone
 from .configuration_detr import DetrConfig
@@ -356,7 +356,7 @@ class DetrConvEncoder(nn.Module):
                 **kwargs,
             )
         else:
-            backbone = AutoBackbone.from_config(config.backbone_config)
+            backbone = load_backbone(config)

         # replace batch norm by frozen batch norm
         with torch.no_grad():
src/transformers/models/dpt/configuration_dpt.py
@@ -117,6 +117,9 @@ class DPTConfig(PretrainedConfig):
             is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
         use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
             Whether to use pretrained weights for the backbone.
+        use_timm_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+            library.

     Example:
@@ -169,6 +172,7 @@
         backbone_config=None,
         backbone=None,
         use_pretrained_backbone=False,
+        use_timm_backbone=False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -179,9 +183,6 @@
         if use_pretrained_backbone:
             raise ValueError("Pretrained backbones are not supported yet.")

-        if backbone_config is not None and backbone is not None:
-            raise ValueError("You can't specify both `backbone` and `backbone_config`.")
-
         use_autobackbone = False
         if self.is_hybrid:
             if backbone_config is None and backbone is None:
@@ -193,17 +194,17 @@
                     "out_features": ["stage1", "stage2", "stage3"],
                     "embedding_dynamic_padding": True,
                 }
-                self.backbone_config = BitConfig(**backbone_config)
+                backbone_config = BitConfig(**backbone_config)
             elif isinstance(backbone_config, dict):
                 logger.info("Initializing the config with a `BiT` backbone.")
-                self.backbone_config = BitConfig(**backbone_config)
+                backbone_config = BitConfig(**backbone_config)
             elif isinstance(backbone_config, PretrainedConfig):
-                self.backbone_config = backbone_config
+                backbone_config = backbone_config
             else:
                 raise ValueError(
                     f"backbone_config must be a dictionary or a `PretrainedConfig`, got {backbone_config.__class__}."
                 )
+            self.backbone_config = backbone_config
             self.backbone_featmap_shape = backbone_featmap_shape
             self.neck_ignore_stages = neck_ignore_stages
@@ -221,14 +222,17 @@
             self.backbone_config = backbone_config
             self.backbone_featmap_shape = None
             self.neck_ignore_stages = []
         else:
             self.backbone_config = backbone_config
             self.backbone_featmap_shape = None
             self.neck_ignore_stages = []
+        if use_autobackbone and backbone_config is not None and backbone is not None:
+            raise ValueError("You can't specify both `backbone` and `backbone_config`.")

         self.backbone = backbone
         self.use_pretrained_backbone = use_pretrained_backbone
+        self.use_timm_backbone = use_timm_backbone
         self.num_hidden_layers = None if use_autobackbone else num_hidden_layers
         self.num_attention_heads = None if use_autobackbone else num_attention_heads
         self.intermediate_size = None if use_autobackbone else intermediate_size
src/transformers/models/dpt/modeling_dpt.py
@@ -41,7 +41,7 @@ from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticS
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import ModelOutput, logging
-from ..auto import AutoBackbone
+from ...utils.backbone_utils import load_backbone
 from .configuration_dpt import DPTConfig
@@ -131,12 +131,10 @@ class DPTViTHybridEmbeddings(nn.Module):
         patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
         num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])

-        self.backbone = AutoBackbone.from_config(config.backbone_config)
+        self.backbone = load_backbone(config)
         feature_dim = self.backbone.channels[-1]
-        if len(config.backbone_config.out_features) != 3:
-            raise ValueError(
-                f"Expected backbone to have 3 output features, got {len(config.backbone_config.out_features)}"
-            )
+        if len(self.backbone.channels) != 3:
+            raise ValueError(f"Expected backbone to have 3 output features, got {len(self.backbone.channels)}")
         self.residual_feature_map_index = [0, 1]  # Always take the output of the first and second backbone stage

         if feature_size is None:
@@ -1082,7 +1080,7 @@ class DPTForDepthEstimation(DPTPreTrainedModel):
         self.backbone = None
         if config.backbone_config is not None and config.is_hybrid is False:
-            self.backbone = AutoBackbone.from_config(config.backbone_config)
+            self.backbone = load_backbone(config)
         else:
             self.dpt = DPTModel(config, add_pooling_layer=False)
src/transformers/models/mask2former/configuration_mask2former.py
@@ -53,6 +53,9 @@ class Mask2FormerConfig(PretrainedConfig):
             is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
         use_pretrained_backbone (`bool`, *optional*, `False`):
             Whether to use pretrained weights for the backbone.
+        use_timm_backbone (`bool`, *optional*, `False`):
+            Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+            library.
         feature_size (`int`, *optional*, defaults to 256):
             The features (channels) of the resulting feature maps.
         mask_feature_size (`int`, *optional*, defaults to 256):
@@ -162,6 +165,7 @@ class Mask2FormerConfig(PretrainedConfig):
         output_auxiliary_logits: bool = None,
         backbone=None,
         use_pretrained_backbone=False,
+        use_timm_backbone=False,
         **kwargs,
     ):
         if use_pretrained_backbone:
@@ -228,6 +232,7 @@ class Mask2FormerConfig(PretrainedConfig):
         self.num_hidden_layers = decoder_layers
         self.backbone = backbone
         self.use_pretrained_backbone = use_pretrained_backbone
+        self.use_timm_backbone = use_timm_backbone

         super().__init__(**kwargs)
src/transformers/models/mask2former/modeling_mask2former.py
@@ -23,7 +23,6 @@ import numpy as np
 import torch
 from torch import Tensor, nn

-from ... import AutoBackbone
 from ...activations import ACT2FN
 from ...file_utils import (
     ModelOutput,
@@ -36,6 +35,7 @@ from ...file_utils import (
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions
 from ...modeling_utils import PreTrainedModel
 from ...utils import logging
+from ...utils.backbone_utils import load_backbone
 from .configuration_mask2former import Mask2FormerConfig
@@ -1376,7 +1376,7 @@ class Mask2FormerPixelLevelModule(nn.Module):
         """
         super().__init__()

-        self.encoder = AutoBackbone.from_config(config.backbone_config)
+        self.encoder = load_backbone(config)
         self.decoder = Mask2FormerPixelDecoder(config, feature_channels=self.encoder.channels)

     def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> Mask2FormerPixelLevelModuleOutput:
src/transformers/models/maskformer/configuration_maskformer.py
@@ -63,6 +63,9 @@ class MaskFormerConfig(PretrainedConfig):
             is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
         use_pretrained_backbone (`bool`, *optional*, `False`):
             Whether to use pretrained weights for the backbone.
+        use_timm_backbone (`bool`, *optional*, `False`):
+            Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+            library.
         decoder_config (`Dict`, *optional*):
             The configuration passed to the transformer decoder model, if unset the base config for `detr-resnet-50`
             will be used.
@@ -122,6 +125,7 @@ class MaskFormerConfig(PretrainedConfig):
         output_auxiliary_logits: Optional[bool] = None,
         backbone: Optional[str] = None,
         use_pretrained_backbone: bool = False,
+        use_timm_backbone: bool = False,
         **kwargs,
     ):
         if use_pretrained_backbone:
@@ -193,6 +197,7 @@ class MaskFormerConfig(PretrainedConfig):
         self.num_hidden_layers = self.decoder_config.num_hidden_layers
         self.backbone = backbone
         self.use_pretrained_backbone = use_pretrained_backbone
+        self.use_timm_backbone = use_timm_backbone
         super().__init__(**kwargs)

     @classmethod
src/transformers/models/maskformer/modeling_maskformer.py
@@ -23,7 +23,6 @@ import numpy as np
 import torch
 from torch import Tensor, nn

-from ... import AutoBackbone
 from ...activations import ACT2FN
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
 from ...modeling_outputs import BaseModelOutputWithCrossAttentions
@@ -37,6 +36,7 @@ from ...utils import (
     replace_return_docstrings,
     requires_backends,
 )
+from ...utils.backbone_utils import load_backbone
 from ..detr import DetrConfig
 from .configuration_maskformer import MaskFormerConfig
 from .configuration_maskformer_swin import MaskFormerSwinConfig
@@ -1428,14 +1428,13 @@ class MaskFormerPixelLevelModule(nn.Module):
             The configuration used to instantiate this model.
         """
         super().__init__()
-        if hasattr(config, "backbone_config") and config.backbone_config.model_type == "swin":
+        # TODD: add method to load pretrained weights of backbone
+        backbone_config = config.backbone_config
+        if backbone_config.model_type == "swin":
             # for backwards compatibility
-            backbone_config = config.backbone_config
             backbone_config = MaskFormerSwinConfig.from_dict(backbone_config.to_dict())
             backbone_config.out_features = ["stage1", "stage2", "stage3", "stage4"]
-        self.encoder = AutoBackbone.from_config(backbone_config)
+            config.backbone_config = backbone_config
+        self.encoder = load_backbone(config)

         feature_channels = self.encoder.channels
         self.decoder = MaskFormerPixelDecoder(
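
One detail worth flagging in the MaskFormer hunk: the backwards-compatibility shim now writes the converted `MaskFormerSwinConfig` back into `config.backbone_config` before calling `load_backbone`, so configs that serialize a plain `swin` backbone keep resolving to the MaskFormer-specific Swin backbone. A hedged illustration of that path (assumes torch is installed; builds a randomly initialized model):

```python
from transformers import MaskFormerConfig, MaskFormerModel, SwinConfig

# An "old style" config carrying a vanilla Swin backbone config.
config = MaskFormerConfig(backbone_config=SwinConfig())
model = MaskFormerModel(config)

# The pixel-level module rewrote the nested config in place before the
# backbone was loaded, so load_backbone saw the MaskFormer-specific class.
print(type(config.backbone_config).__name__)  # MaskFormerSwinConfig
```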
src/transformers/models/oneformer/configuration_oneformer.py
@@ -50,6 +50,9 @@ class OneFormerConfig(PretrainedConfig):
             is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
         use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
             Whether to use pretrained weights for the backbone.
+        use_timm_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+            library.
         ignore_value (`int`, *optional*, defaults to 255):
             Values to be ignored in GT label while calculating loss.
         num_queries (`int`, *optional*, defaults to 150):
@@ -152,6 +155,7 @@ class OneFormerConfig(PretrainedConfig):
         backbone_config: Optional[Dict] = None,
         backbone: Optional[str] = None,
         use_pretrained_backbone: bool = False,
+        use_timm_backbone: bool = False,
         ignore_value: int = 255,
         num_queries: int = 150,
         no_object_weight: int = 0.1,
@@ -222,6 +226,7 @@ class OneFormerConfig(PretrainedConfig):
         self.backbone_config = backbone_config
         self.backbone = backbone
         self.use_pretrained_backbone = use_pretrained_backbone
+        self.use_timm_backbone = use_timm_backbone
         self.ignore_value = ignore_value
         self.num_queries = num_queries
         self.no_object_weight = no_object_weight
src/transformers/models/oneformer/modeling_oneformer.py
@@ -24,7 +24,6 @@ import torch
 from torch import Tensor, nn
 from torch.cuda.amp import autocast

-from ... import AutoBackbone
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput
 from ...modeling_utils import PreTrainedModel
@@ -37,6 +36,7 @@ from ...utils import (
     replace_return_docstrings,
     requires_backends,
 )
+from ...utils.backbone_utils import load_backbone
 from .configuration_oneformer import OneFormerConfig
@@ -1478,8 +1478,7 @@ class OneFormerPixelLevelModule(nn.Module):
             The configuration used to instantiate this model.
         """
         super().__init__()
-        backbone_config = config.backbone_config
-        self.encoder = AutoBackbone.from_config(backbone_config)
+        self.encoder = load_backbone(config)
         self.decoder = OneFormerPixelDecoder(config, feature_channels=self.encoder.channels)

     def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> OneFormerPixelLevelModuleOutput:
src/transformers/models/table_transformer/modeling_table_transformer.py
@@ -37,7 +37,7 @@ from ...utils import (
     replace_return_docstrings,
     requires_backends,
 )
-from ..auto import AutoBackbone
+from ...utils.backbone_utils import load_backbone
 from .configuration_table_transformer import TableTransformerConfig
@@ -290,7 +290,7 @@ class TableTransformerConvEncoder(nn.Module):
                 **kwargs,
             )
         else:
-            backbone = AutoBackbone.from_config(config.backbone_config)
+            backbone = load_backbone(config)

         # replace batch norm by frozen batch norm
         with torch.no_grad():
src/transformers/models/tvp/configuration_tvp.py
@@ -49,6 +49,9 @@ class TvpConfig(PretrainedConfig):
             is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
         use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
             Whether to use pretrained weights for the backbone.
+        use_timm_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+            library.
         distance_loss_weight (`float`, *optional*, defaults to 1.0):
             The weight of distance loss.
         duration_loss_weight (`float`, *optional*, defaults to 0.1):
@@ -103,6 +106,7 @@
         backbone_config=None,
         backbone=None,
         use_pretrained_backbone=False,
+        use_timm_backbone=False,
         distance_loss_weight=1.0,
         duration_loss_weight=0.1,
         visual_prompter_type="framepad",
@@ -143,6 +147,7 @@
         self.backbone_config = backbone_config
         self.backbone = backbone
         self.use_pretrained_backbone = use_pretrained_backbone
+        self.use_timm_backbone = use_timm_backbone
         self.distance_loss_weight = distance_loss_weight
         self.duration_loss_weight = duration_loss_weight
         self.visual_prompter_type = visual_prompter_type
src/transformers/models/tvp/modeling_tvp.py
@@ -28,7 +28,7 @@ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, Mod
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import prune_linear_layer
 from ...utils import logging
-from ..auto import AutoBackbone
+from ...utils.backbone_utils import load_backbone
 from .configuration_tvp import TvpConfig
@@ -148,7 +148,7 @@ class TvpLoss(nn.Module):
 class TvpVisionModel(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.backbone = AutoBackbone.from_config(config.backbone_config)
+        self.backbone = load_backbone(config)
         self.grid_encoder_conv = nn.Conv2d(
             config.backbone_config.hidden_sizes[-1],
             config.hidden_size,
src/transformers/models/upernet/configuration_upernet.py
@@ -42,6 +42,9 @@ class UperNetConfig(PretrainedConfig):
             is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
         use_pretrained_backbone (`bool`, *optional*, `False`):
             Whether to use pretrained weights for the backbone.
+        use_timm_backbone (`bool`, *optional*, `False`):
+            Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+            library.
         hidden_size (`int`, *optional*, defaults to 512):
             The number of hidden units in the convolutional layers.
         initializer_range (`float`, *optional*, defaults to 0.02):
@@ -83,6 +86,7 @@
         backbone_config=None,
         backbone=None,
         use_pretrained_backbone=False,
+        use_timm_backbone=False,
         hidden_size=512,
         initializer_range=0.02,
         pool_scales=[1, 2, 3, 6],
@@ -113,6 +117,7 @@
         self.backbone_config = backbone_config
         self.backbone = backbone
         self.use_pretrained_backbone = use_pretrained_backbone
+        self.use_timm_backbone = use_timm_backbone
         self.hidden_size = hidden_size
         self.initializer_range = initializer_range
         self.pool_scales = pool_scales
src/transformers/models/upernet/modeling_upernet.py
@@ -20,10 +20,10 @@ import torch
 from torch import nn
 from torch.nn import CrossEntropyLoss

-from ... import AutoBackbone
 from ...modeling_outputs import SemanticSegmenterOutput
 from ...modeling_utils import PreTrainedModel
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
+from ...utils.backbone_utils import load_backbone
 from .configuration_upernet import UperNetConfig
@@ -348,7 +348,7 @@ class UperNetForSemanticSegmentation(UperNetPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)

-        self.backbone = AutoBackbone.from_config(config.backbone_config)
+        self.backbone = load_backbone(config)

         # Semantic segmentation head(s)
         self.decode_head = UperNetHead(config, in_channels=self.backbone.channels)
src/transformers/models/vit_hybrid/configuration_vit_hybrid.py
@@ -48,6 +48,9 @@ class ViTHybridConfig(PretrainedConfig):
             is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
         use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
             Whether to use pretrained weights for the backbone.
+        use_timm_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+            library.
         hidden_size (`int`, *optional*, defaults to 768):
             Dimensionality of the encoder layers and the pooler layer.
         num_hidden_layers (`int`, *optional*, defaults to 12):
@@ -100,6 +103,7 @@
         backbone_config=None,
         backbone=None,
         use_pretrained_backbone=False,
+        use_timm_backbone=False,
         hidden_size=768,
         num_hidden_layers=12,
         num_attention_heads=12,
@@ -147,6 +151,7 @@
         self.backbone_config = backbone_config
         self.backbone = backbone
         self.use_pretrained_backbone = use_pretrained_backbone
+        self.use_timm_backbone = use_timm_backbone
         self.hidden_size = hidden_size
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
src/transformers/models/vit_hybrid/modeling_vit_hybrid.py
@@ -29,7 +29,7 @@ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, Ima
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
-from ..auto import AutoBackbone
+from ...utils.backbone_utils import load_backbone
 from .configuration_vit_hybrid import ViTHybridConfig
@@ -150,7 +150,7 @@ class ViTHybridPatchEmbeddings(nn.Module):
         image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
         patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)

-        self.backbone = AutoBackbone.from_config(config.backbone_config)
+        self.backbone = load_backbone(config)
         if self.backbone.config.model_type != "bit":
             raise ValueError(f"Backbone model type {self.backbone.model_type} is not supported.")
         feature_dim = self.backbone.channels[-1]
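
Taken together, every touched model now resolves its backbone through the same entry point. A quick sanity check with `UperNetConfig`, whose default nested config is a ResNet (the printed channel tuple is indicative of that default, not guaranteed):

```python
from transformers import UperNetConfig
from transformers.utils.backbone_utils import load_backbone

# With neither `backbone` nor `backbone_config` given, UperNetConfig falls
# back to a default ResNet backbone_config; load_backbone then builds a
# randomly initialized ResNetBackbone from it, exactly where the removed
# AutoBackbone.from_config call used to.
config = UperNetConfig()
backbone = load_backbone(config)
print(backbone.channels)  # e.g. (256, 512, 1024, 2048) for the default ResNet
```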