Unverified Commit 27c79a0f authored by amyeroberts, committed by GitHub

Enable instantiating model with pretrained backbone weights (#28214)



* Enable instantiating model with pretrained backbone weights

* Update tests so backbone checkpoint isn't passed in

* Remove doc updates until changes made in modeling code

* Clarify pretrained import

* Update configs - docs and validation check

* Update src/transformers/utils/backbone_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Clarify exception message

* Update config init in tests

* Add test for when use_timm_backbone=True

* Small test updates

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 008a6a22
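
At a glance, this change lets a parent model's config carry a backbone checkpoint name (`backbone`) alongside the existing `backbone_config`, with `use_pretrained_backbone` deciding whether weights are loaded or randomly initialized. A minimal sketch of the intended usage at this commit (checkpoint names are illustrative; the timm path requires `timm` to be installed and downloads weights, and pretrained transformers backbones still raise for the DETR-family configs):

from transformers import DetrConfig, DetrForObjectDetection, MaskFormerConfig, SwinConfig

# timm path: name a timm model and load its pretrained weights
config = DetrConfig(backbone="resnet50", use_timm_backbone=True, use_pretrained_backbone=True)
model = DetrForObjectDetection(config)

# transformers path: build the backbone from a config, with random weights
config = MaskFormerConfig(backbone_config=SwinConfig(), use_pretrained_backbone=False)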
......@@ -602,10 +602,6 @@ class _BaseAutoBackboneClass(_BaseAutoModelClass):
config = kwargs.pop("config", TimmBackboneConfig())
use_timm = kwargs.pop("use_timm_backbone", True)
if not use_timm:
raise ValueError("`use_timm_backbone` must be `True` for timm backbones")
if kwargs.get("out_features", None) is not None:
raise ValueError("Cannot specify `out_features` for timm backbones")
......@@ -627,7 +623,8 @@ class _BaseAutoBackboneClass(_BaseAutoModelClass):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
if kwargs.get("use_timm_backbone", False):
use_timm_backbone = kwargs.pop("use_timm_backbone", False)
if use_timm_backbone:
return cls._load_timm_backbone_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
......
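
Popping `use_timm_backbone` before dispatch means the flag no longer leaks into the downstream `from_pretrained` call. A hedged sketch of the two routes through `AutoBackbone` (checkpoint names illustrative; the first call assumes `timm` is installed):

from transformers import AutoBackbone

# use_timm_backbone=True routes to _load_timm_backbone_from_pretrained
timm_backbone = AutoBackbone.from_pretrained("resnet18", use_timm_backbone=True)

# default route: a transformers checkpoint
hf_backbone = AutoBackbone.from_pretrained("microsoft/resnet-18", out_features=["stage4"])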
......@@ -93,11 +93,11 @@ class ConditionalDetrConfig(PretrainedConfig):
position_embedding_type (`str`, *optional*, defaults to `"sine"`):
Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`.
backbone (`str`, *optional*, defaults to `"resnet50"`):
Name of convolutional backbone to use in case `use_timm_backbone` = `True`. Supports any convolutional
backbone from the timm package. For a list of all available models, see [this
page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model).
Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
Whether to use pretrained weights for the backbone. Only supported when `use_timm_backbone` = `True`.
Whether to use pretrained weights for the backbone.
dilation (`bool`, *optional*, defaults to `False`):
Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
`use_timm_backbone` = `True`.
......@@ -180,6 +180,14 @@ class ConditionalDetrConfig(PretrainedConfig):
focal_alpha=0.25,
**kwargs,
):
if not use_timm_backbone and use_pretrained_backbone:
raise ValueError(
"Loading pretrained backbone weights from the transformers library is not supported yet. `use_timm_backbone` must be set to `True` when `use_pretrained_backbone=True`"
)
if backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
if backbone_config is not None and use_timm_backbone:
raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
......
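
The same three guards are added to DeformableDetrConfig, DetrConfig and TableTransformerConfig below; a sketch of what they accept and reject, using ConditionalDetrConfig:

from transformers import ConditionalDetrConfig, ResNetConfig

# Accepted: a timm backbone by name
config = ConditionalDetrConfig(backbone="resnet50", use_timm_backbone=True)

# Rejected: pretrained transformers backbones are not supported here yet
try:
    ConditionalDetrConfig(use_timm_backbone=False, use_pretrained_backbone=True)
except ValueError as err:
    print(err)

# Rejected: `backbone` and `backbone_config` are mutually exclusive
try:
    ConditionalDetrConfig(backbone="resnet50", backbone_config=ResNetConfig())
except ValueError as err:
    print(err)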
......@@ -85,11 +85,11 @@ class DeformableDetrConfig(PretrainedConfig):
position_embedding_type (`str`, *optional*, defaults to `"sine"`):
Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`.
backbone (`str`, *optional*, defaults to `"resnet50"`):
Name of convolutional backbone to use in case `use_timm_backbone` = `True`. Supports any convolutional
backbone from the timm package. For a list of all available models, see [this
page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model).
Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
Whether to use pretrained weights for the backbone. Only supported when `use_timm_backbone` = `True`.
Whether to use pretrained weights for the backbone.
dilation (`bool`, *optional*, defaults to `False`):
Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
`use_timm_backbone` = `True`.
......@@ -196,6 +196,14 @@ class DeformableDetrConfig(PretrainedConfig):
disable_custom_kernels=False,
**kwargs,
):
if not use_timm_backbone and use_pretrained_backbone:
raise ValueError(
"Loading pretrained backbone weights from the transformers library is not supported yet. `use_timm_backbone` must be set to `True` when `use_pretrained_backbone=True`"
)
if backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
if backbone_config is not None and use_timm_backbone:
raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
......
......@@ -40,6 +40,12 @@ class DetaConfig(PretrainedConfig):
Args:
backbone_config (`PretrainedConfig` or `dict`, *optional*, defaults to `ResNetConfig()`):
The configuration of the backbone model.
backbone (`str`, *optional*):
Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use pretrained weights for the backbone.
num_queries (`int`, *optional*, defaults to 900):
Number of object queries, i.e. detection slots. This is the maximal number of objects [`DetaModel`] can
detect in a single image. In case `two_stage` is set to `True`, we use `two_stage_num_proposals` instead.
......@@ -138,6 +144,8 @@ class DetaConfig(PretrainedConfig):
def __init__(
self,
backbone_config=None,
backbone=None,
use_pretrained_backbone=False,
num_queries=900,
max_position_embeddings=2048,
encoder_layers=6,
......@@ -177,7 +185,13 @@ class DetaConfig(PretrainedConfig):
focal_alpha=0.25,
**kwargs,
):
if backbone_config is None:
if use_pretrained_backbone:
raise ValueError("Pretrained backbones are not supported yet.")
if backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
if backbone_config is None and backbone is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage2", "stage3", "stage4"])
else:
......@@ -187,6 +201,8 @@ class DetaConfig(PretrainedConfig):
backbone_config = config_class.from_dict(backbone_config)
self.backbone_config = backbone_config
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.num_queries = num_queries
self.max_position_embeddings = max_position_embeddings
self.d_model = d_model
......
......@@ -93,11 +93,11 @@ class DetrConfig(PretrainedConfig):
position_embedding_type (`str`, *optional*, defaults to `"sine"`):
Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`.
backbone (`str`, *optional*, defaults to `"resnet50"`):
Name of convolutional backbone to use in case `use_timm_backbone` = `True`. Supports any convolutional
backbone from the timm package. For a list of all available models, see [this
page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model).
use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
Whether to use pretrained weights for the backbone. Only supported when `use_timm_backbone` = `True`.
Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
Whether to use pretrained weights for the backbone.
dilation (`bool`, *optional*, defaults to `False`):
Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
`use_timm_backbone` = `True`.
......@@ -177,6 +177,14 @@ class DetrConfig(PretrainedConfig):
eos_coefficient=0.1,
**kwargs,
):
if not use_timm_backbone and use_pretrained_backbone:
raise ValueError(
"Loading pretrained backbone weights from the transformers library is not supported yet. `use_timm_backbone` must be set to `True` when `use_pretrained_backbone=True`"
)
if backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
if backbone_config is not None and use_timm_backbone:
raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
......
......@@ -111,6 +111,12 @@ class DPTConfig(PretrainedConfig):
backbone_config (`Union[Dict[str, Any], PretrainedConfig]`, *optional*):
The configuration of the backbone model. Only used in case `is_hybrid` is `True` or in case you want to
leverage the [`AutoBackbone`] API.
backbone (`str`, *optional*):
Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use pretrained weights for the backbone.
Example:
......@@ -161,6 +167,8 @@ class DPTConfig(PretrainedConfig):
backbone_featmap_shape=[1, 1024, 24, 24],
neck_ignore_stages=[0, 1],
backbone_config=None,
backbone=None,
use_pretrained_backbone=False,
**kwargs,
):
super().__init__(**kwargs)
......@@ -168,9 +176,15 @@ class DPTConfig(PretrainedConfig):
self.hidden_size = hidden_size
self.is_hybrid = is_hybrid
if use_pretrained_backbone:
raise ValueError("Pretrained backbones are not supported yet.")
if backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
use_autobackbone = False
if self.is_hybrid:
if backbone_config is None:
if backbone_config is None and backbone is None:
logger.info("Initializing the config with a `BiT` backbone.")
backbone_config = {
"global_padding": "same",
......@@ -213,6 +227,8 @@ class DPTConfig(PretrainedConfig):
self.backbone_featmap_shape = None
self.neck_ignore_stages = []
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.num_hidden_layers = None if use_autobackbone else num_hidden_layers
self.num_attention_heads = None if use_autobackbone else num_attention_heads
self.intermediate_size = None if use_autobackbone else intermediate_size
......
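
DPT's handling differs slightly because of its hybrid mode: when `is_hybrid` is `True` and neither field is set, the BiT defaults above are filled in. A sketch (the checkpoint name in the second call is purely illustrative):

from transformers import BitConfig, DPTConfig

# Hybrid mode with neither field set: the default BiT backbone dict is used
hybrid_config = DPTConfig(is_hybrid=True)

# Both fields together are rejected, as in the other configs
try:
    DPTConfig(backbone_config=BitConfig(), backbone="some-backbone-checkpoint")
except ValueError as err:
    print(err)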
......@@ -47,6 +47,12 @@ class Mask2FormerConfig(PretrainedConfig):
backbone_config (`PretrainedConfig` or `dict`, *optional*, defaults to `SwinConfig()`):
The configuration of the backbone model. If unset, the configuration corresponding to
`swin-base-patch4-window12-384` will be used.
backbone (`str`, *optional*):
Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use pretrained weights for the backbone.
feature_size (`int`, *optional*, defaults to 256):
The features (channels) of the resulting feature maps.
mask_feature_size (`int`, *optional*, defaults to 256):
......@@ -154,9 +160,17 @@ class Mask2FormerConfig(PretrainedConfig):
use_auxiliary_loss: bool = True,
feature_strides: List[int] = [4, 8, 16, 32],
output_auxiliary_logits: bool = None,
backbone=None,
use_pretrained_backbone=False,
**kwargs,
):
if backbone_config is None:
if use_pretrained_backbone:
raise ValueError("Pretrained backbones are not supported yet.")
if backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
if backbone_config is None and backbone is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `Swin` backbone.")
backbone_config = CONFIG_MAPPING["swin"](
image_size=224,
......@@ -177,7 +191,7 @@ class Mask2FormerConfig(PretrainedConfig):
backbone_config = config_class.from_dict(backbone_config)
# verify that the backbone is supported
if backbone_config.model_type not in self.backbones_supported:
if backbone_config is not None and backbone_config.model_type not in self.backbones_supported:
logger.warning_once(
f"Backbone {backbone_config.model_type} is not a supported model and may not be compatible with Mask2Former. "
f"Supported model types: {','.join(self.backbones_supported)}"
......@@ -212,6 +226,8 @@ class Mask2FormerConfig(PretrainedConfig):
self.feature_strides = feature_strides
self.output_auxiliary_logits = output_auxiliary_logits
self.num_hidden_layers = decoder_layers
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
super().__init__(**kwargs)
......
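
For the mask-based models the Swin fallback only triggers when both fields are unset, and the supported-backbone warning is now skipped when `backbone_config` is `None`. A sketch (the checkpoint name is illustrative):

from transformers import Mask2FormerConfig

# Neither field set: falls back to the default Swin config
config = Mask2FormerConfig()
print(config.backbone_config.model_type)  # "swin"

# Named backbone only: backbone_config stays None and no compatibility warning fires
config = Mask2FormerConfig(backbone="microsoft/swin-base-patch4-window12-384")
print(config.backbone)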
......@@ -57,6 +57,12 @@ class MaskFormerConfig(PretrainedConfig):
backbone_config (`Dict`, *optional*):
The configuration passed to the backbone, if unset, the configuration corresponding to
`swin-base-patch4-window12-384` will be used.
backbone (`str`, *optional*):
Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use pretrained weights for the backbone.
decoder_config (`Dict`, *optional*):
The configuration passed to the transformer decoder model, if unset the base config for `detr-resnet-50`
will be used.
......@@ -114,9 +120,17 @@ class MaskFormerConfig(PretrainedConfig):
cross_entropy_weight: float = 1.0,
mask_weight: float = 20.0,
output_auxiliary_logits: Optional[bool] = None,
backbone: Optional[str] = None,
use_pretrained_backbone: bool = False,
**kwargs,
):
if backbone_config is None:
if use_pretrained_backbone:
raise ValueError("Pretrained backbones are not supported yet.")
if backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
if backbone_config is None and backbone is None:
# fall back to https://huggingface.co/microsoft/swin-base-patch4-window12-384-in22k
backbone_config = SwinConfig(
image_size=384,
......@@ -136,7 +150,7 @@ class MaskFormerConfig(PretrainedConfig):
backbone_config = config_class.from_dict(backbone_config)
# verify that the backbone is supported
if backbone_config.model_type not in self.backbones_supported:
if backbone_config is not None and backbone_config.model_type not in self.backbones_supported:
logger.warning_once(
f"Backbone {backbone_config.model_type} is not a supported model and may not be compatible with MaskFormer. "
f"Supported model types: {','.join(self.backbones_supported)}"
......@@ -177,6 +191,8 @@ class MaskFormerConfig(PretrainedConfig):
self.num_attention_heads = self.decoder_config.encoder_attention_heads
self.num_hidden_layers = self.decoder_config.num_hidden_layers
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
super().__init__(**kwargs)
@classmethod
......
......@@ -44,6 +44,12 @@ class OneFormerConfig(PretrainedConfig):
Args:
backbone_config (`PretrainedConfig`, *optional*, defaults to `SwinConfig`):
The configuration of the backbone model.
backbone (`str`, *optional*):
Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use pretrained weights for the backbone.
ignore_value (`int`, *optional*, defaults to 255):
Values to be ignored in GT label while calculating loss.
num_queries (`int`, *optional*, defaults to 150):
......@@ -144,6 +150,8 @@ class OneFormerConfig(PretrainedConfig):
def __init__(
self,
backbone_config: Optional[Dict] = None,
backbone: Optional[str] = None,
use_pretrained_backbone: bool = False,
ignore_value: int = 255,
num_queries: int = 150,
no_object_weight: int = 0.1,
......@@ -186,7 +194,13 @@ class OneFormerConfig(PretrainedConfig):
common_stride: int = 4,
**kwargs,
):
if backbone_config is None:
if use_pretrained_backbone:
raise ValueError("Pretrained backbones are not supported yet.")
if backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
if backbone_config is None and backbone is None:
logger.info("`backbone_config` is unset. Initializing the config with the default `Swin` backbone.")
backbone_config = CONFIG_MAPPING["swin"](
image_size=224,
......@@ -206,7 +220,8 @@ class OneFormerConfig(PretrainedConfig):
backbone_config = config_class.from_dict(backbone_config)
self.backbone_config = backbone_config
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.ignore_value = ignore_value
self.num_queries = num_queries
self.no_object_weight = no_object_weight
......
......@@ -92,12 +92,12 @@ class TableTransformerConfig(PretrainedConfig):
Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
position_embedding_type (`str`, *optional*, defaults to `"sine"`):
Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`.
backbone (`str`, *optional*, defaults to `"resnet50"`):
Name of convolutional backbone to use in case `use_timm_backbone` = `True`. Supports any convolutional
backbone from the timm package. For a list of all available models, see [this
page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model).
use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
Whether to use pretrained weights for the backbone. Only supported when `use_timm_backbone` = `True`.
backbone (`str`, *optional*):
Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
Whether to use pretrained weights for the backbone.
dilation (`bool`, *optional*, defaults to `False`):
Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
`use_timm_backbone` = `True`.
......@@ -178,6 +178,14 @@ class TableTransformerConfig(PretrainedConfig):
eos_coefficient=0.1,
**kwargs,
):
if not use_timm_backbone and use_pretrained_backbone:
raise ValueError(
"Loading pretrained backbone weights from the transformers library is not supported yet. `use_timm_backbone` must be set to `True` when `use_pretrained_backbone=True`"
)
if backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
if backbone_config is not None and use_timm_backbone:
raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
......
......@@ -43,6 +43,12 @@ class TvpConfig(PretrainedConfig):
Args:
backbone_config (`PretrainedConfig` or `dict`, *optional*):
The configuration of the backbone model.
backbone (`str`, *optional*):
Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use pretrained weights for the backbone.
distance_loss_weight (`float`, *optional*, defaults to 1.0):
The weight of distance loss.
duration_loss_weight (`float`, *optional*, defaults to 0.1):
......@@ -95,6 +101,8 @@ class TvpConfig(PretrainedConfig):
def __init__(
self,
backbone_config=None,
backbone=None,
use_pretrained_backbone=False,
distance_loss_weight=1.0,
duration_loss_weight=0.1,
visual_prompter_type="framepad",
......@@ -118,8 +126,13 @@ class TvpConfig(PretrainedConfig):
**kwargs,
):
super().__init__(**kwargs)
if use_pretrained_backbone:
raise ValueError("Pretrained backbones are not supported yet.")
if backbone_config is None:
if backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
if backbone_config is None and backbone is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
elif isinstance(backbone_config, dict):
......@@ -128,6 +141,8 @@ class TvpConfig(PretrainedConfig):
backbone_config = config_class.from_dict(backbone_config)
self.backbone_config = backbone_config
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.distance_loss_weight = distance_loss_weight
self.duration_loss_weight = duration_loss_weight
self.visual_prompter_type = visual_prompter_type
......
......@@ -36,6 +36,12 @@ class UperNetConfig(PretrainedConfig):
Args:
backbone_config (`PretrainedConfig` or `dict`, *optional*, defaults to `ResNetConfig()`):
The configuration of the backbone model.
backbone (`str`, *optional*):
Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use pretrained weights for the backbone.
hidden_size (`int`, *optional*, defaults to 512):
The number of hidden units in the convolutional layers.
initializer_range (`float`, *optional*, defaults to 0.02):
......@@ -75,6 +81,8 @@ class UperNetConfig(PretrainedConfig):
def __init__(
self,
backbone_config=None,
backbone=None,
use_pretrained_backbone=False,
hidden_size=512,
initializer_range=0.02,
pool_scales=[1, 2, 3, 6],
......@@ -88,8 +96,13 @@ class UperNetConfig(PretrainedConfig):
**kwargs,
):
super().__init__(**kwargs)
if use_pretrained_backbone:
raise ValueError("Pretrained backbones are not supported yet.")
if backbone_config is None:
if backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
if backbone_config is None and backbone is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage1", "stage2", "stage3", "stage4"])
elif isinstance(backbone_config, dict):
......@@ -98,6 +111,8 @@ class UperNetConfig(PretrainedConfig):
backbone_config = config_class.from_dict(backbone_config)
self.backbone_config = backbone_config
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.hidden_size = hidden_size
self.initializer_range = initializer_range
self.pool_scales = pool_scales
......
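
TVP above and the ViT-hybrid and VitMatte configs below get the identical treatment. A sketch of the two accepted states, using UperNetConfig (checkpoint name illustrative; at this commit the name is only recorded on the config, since the modeling code is updated separately):

from transformers import UperNetConfig

# Default: a ResNet backbone config exposing all four stages
config = UperNetConfig()
print(config.backbone_config.out_features)  # ['stage1', 'stage2', 'stage3', 'stage4']

# Checkpoint name instead of a config; pretrained weights are still disallowed here
config = UperNetConfig(backbone="microsoft/resnet-50", use_pretrained_backbone=False)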
......@@ -42,6 +42,12 @@ class ViTHybridConfig(PretrainedConfig):
Args:
backbone_config (`Union[Dict[str, Any], PretrainedConfig]`, *optional*):
The configuration of the backbone in a dictionary or the config object of the backbone.
backbone (`str`, *optional*):
Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use pretrained weights for the backbone.
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (`int`, *optional*, defaults to 12):
......@@ -92,6 +98,8 @@ class ViTHybridConfig(PretrainedConfig):
def __init__(
self,
backbone_config=None,
backbone=None,
use_pretrained_backbone=False,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
......@@ -109,8 +117,13 @@ class ViTHybridConfig(PretrainedConfig):
**kwargs,
):
super().__init__(**kwargs)
if use_pretrained_backbone:
raise ValueError("Pretrained backbones are not supported yet.")
if backbone_config is None:
if backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
if backbone_config is None and backbone is None:
logger.info("`backbone_config` is `None`. Initializing the config with a `BiT` backbone.")
backbone_config = {
"global_padding": "same",
......@@ -132,6 +145,8 @@ class ViTHybridConfig(PretrainedConfig):
self.backbone_featmap_shape = backbone_featmap_shape
self.backbone_config = backbone_config
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
......
......@@ -42,6 +42,12 @@ class VitMatteConfig(PretrainedConfig):
Args:
backbone_config (`PretrainedConfig` or `dict`, *optional*, defaults to `VitDetConfig()`):
The configuration of the backbone model.
backbone (`str`, *optional*):
Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use pretrained weights for the backbone.
hidden_size (`int`, *optional*, defaults to 384):
The number of input channels of the decoder.
batch_norm_eps (`float`, *optional*, defaults to 1e-05):
......@@ -73,6 +79,8 @@ class VitMatteConfig(PretrainedConfig):
def __init__(
self,
backbone_config: PretrainedConfig = None,
backbone=None,
use_pretrained_backbone=False,
hidden_size: int = 384,
batch_norm_eps: float = 1e-5,
initializer_range: float = 0.02,
......@@ -82,7 +90,13 @@ class VitMatteConfig(PretrainedConfig):
):
super().__init__(**kwargs)
if backbone_config is None:
if use_pretrained_backbone:
raise ValueError("Pretrained backbones are not supported yet.")
if backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
if backbone_config is None and backbone is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `VitDet` backbone.")
backbone_config = CONFIG_MAPPING["vitdet"](out_features=["stage4"])
elif isinstance(backbone_config, dict):
......@@ -91,6 +105,8 @@ class VitMatteConfig(PretrainedConfig):
backbone_config = config_class.from_dict(backbone_config)
self.backbone_config = backbone_config
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.batch_norm_eps = batch_norm_eps
self.hidden_size = hidden_size
self.initializer_range = initializer_range
......
......@@ -286,3 +286,56 @@ class BackboneConfigMixin:
output["out_features"] = output.pop("_out_features")
output["out_indices"] = output.pop("_out_indices")
return output
def load_backbone(config):
"""
Loads the backbone model from a config object.
If the config is from the backbone model itself, then we return a backbone model with randomly initialized
weights.
If the config is from the parent model of the backbone model itself, then we load the pretrained backbone weights
if specified.
"""
from transformers import AutoBackbone, AutoConfig
backbone_config = getattr(config, "backbone_config", None)
use_timm_backbone = getattr(config, "use_timm_backbone", None)
use_pretrained_backbone = getattr(config, "use_pretrained_backbone", None)
backbone_checkpoint = getattr(config, "backbone", None)
# If there is both a backbone_config and a backbone checkpoint, the desired behaviour is ill-defined:
# do you want to load from the checkpoint's config or the backbone_config?
if backbone_config is not None and backbone_checkpoint is not None and use_pretrained_backbone is not None:
raise ValueError("Cannot specify both config.backbone_config and config.backbone")
# If any of the following are set, then the config passed in is from a model which contains a backbone.
if (
backbone_config is None
and use_timm_backbone is None
and backbone_checkpoint is None
):
return AutoBackbone.from_config(config=config)
return AutoBackbone.from_config(config=config)
# config from the parent model that has a backbone
if use_timm_backbone:
if backbone_checkpoint is None:
raise ValueError("config.backbone must be set if use_timm_backbone is True")
# Because of how timm backbones were originally added to models, we need to pass in use_pretrained_backbone
# to determine whether to load the pretrained weights.
backbone = AutoBackbone.from_pretrained(
backbone_checkpoint, use_timm_backbone=use_timm_backbone, use_pretrained_backbone=use_pretrained_backbone
)
elif use_pretrained_backbone:
if backbone_checkpoint is None:
raise ValueError("config.backbone must be set if use_pretrained_backbone is True")
backbone = AutoBackbone.from_pretrained(backbone_checkpoint)
else:
if backbone_config is None and backbone_checkpoint is None:
raise ValueError("Either config.backbone_config or config.backbone must be set")
if backbone_config is None:
backbone_config = AutoConfig.from_pretrained(backbone_checkpoint)
backbone = AutoBackbone.from_config(config=backbone_config)
return backbone
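
A sketch of how `load_backbone` resolves the different config states (model classes are expected to call it internally; the timm example downloads weights, so it assumes `timm` and network access):

from transformers import DetrConfig, MaskFormerConfig
from transformers.utils.backbone_utils import load_backbone

# Parent config naming a timm checkpoint: pretrained timm weights are loaded
detr_config = DetrConfig(backbone="resnet50", use_timm_backbone=True, use_pretrained_backbone=True)
timm_backbone = load_backbone(detr_config)

# Parent config with only a backbone_config: randomly initialized transformers backbone
maskformer_config = MaskFormerConfig()
swin_backbone = load_backbone(maskformer_config)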
......@@ -134,6 +134,8 @@ class ConditionalDetrModelTester:
num_labels=self.num_labels,
use_timm_backbone=False,
backbone_config=resnet_config,
backbone=None,
use_pretrained_backbone=False,
)
def prepare_config_and_inputs_for_common(self):
......
......@@ -149,7 +149,9 @@ class DeformableDetrModelTester:
encoder_n_points=self.encoder_n_points,
decoder_n_points=self.decoder_n_points,
use_timm_backbone=False,
backbone=None,
backbone_config=resnet_config,
use_pretrained_backbone=False,
)
def prepare_config_and_inputs_for_common(self):
......@@ -518,6 +520,8 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
# let's pick a random timm backbone
config.backbone = "tf_mobilenetv3_small_075"
config.use_timm_backbone = True
config.backbone_config = None
for model_class in self.all_model_classes:
model = model_class(config)
......
......@@ -157,6 +157,7 @@ class DetaModelTester:
assign_first_stage=assign_first_stage,
assign_second_stage=assign_second_stage,
backbone_config=resnet_config,
backbone=None,
)
def prepare_config_and_inputs_for_common(self, model_class_name="DetaModel"):
......
......@@ -130,6 +130,8 @@ class DetrModelTester:
num_labels=self.num_labels,
use_timm_backbone=False,
backbone_config=resnet_config,
backbone=None,
use_pretrained_backbone=False,
)
def prepare_config_and_inputs_for_common(self):
......@@ -622,7 +624,7 @@ class DetrModelIntegrationTestsTimmBackbone(unittest.TestCase):
torch_device
)
expected_number_of_segments = 5
expected_first_segment = {"id": 1, "label_id": 17, "was_fused": False, "score": 0.994096}
expected_first_segment = {"id": 1, "label_id": 17, "was_fused": False, "score": 0.994097}
number_of_unique_segments = len(torch.unique(results["segmentation"]))
self.assertTrue(
......
......@@ -95,6 +95,7 @@ class DPTModelTester:
def get_config(self):
return DPTConfig(
backbone_config=self.get_backbone_config(),
backbone=None,
neck_hidden_sizes=self.neck_hidden_sizes,
fusion_hidden_size=self.fusion_hidden_size,
)
......