Unverified Commit 99de3a84 authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Move out common backbone config param validation (#31144)

* Move out common validation

* Add missing backbone config arguments
parent 485d913d
...@@ -22,6 +22,7 @@ from packaging import version ...@@ -22,6 +22,7 @@ from packaging import version
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig from ...onnx import OnnxConfig
from ...utils import logging from ...utils import logging
from ...utils.backbone_utils import verify_backbone_config_arguments
from ..auto import CONFIG_MAPPING from ..auto import CONFIG_MAPPING
...@@ -179,17 +180,6 @@ class ConditionalDetrConfig(PretrainedConfig): ...@@ -179,17 +180,6 @@ class ConditionalDetrConfig(PretrainedConfig):
focal_alpha=0.25, focal_alpha=0.25,
**kwargs, **kwargs,
): ):
if not use_timm_backbone and use_pretrained_backbone:
raise ValueError(
"Loading pretrained backbone weights from the transformers library is not supported yet. `use_timm_backbone` must be set to `True` when `use_pretrained_backbone=True`"
)
if backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
if backbone_config is not None and use_timm_backbone:
raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
# We default to values which were previously hard-coded in the model. This enables configurability of the config # We default to values which were previously hard-coded in the model. This enables configurability of the config
# while keeping the default behavior the same. # while keeping the default behavior the same.
if use_timm_backbone and backbone_kwargs is None: if use_timm_backbone and backbone_kwargs is None:
...@@ -208,6 +198,14 @@ class ConditionalDetrConfig(PretrainedConfig): ...@@ -208,6 +198,14 @@ class ConditionalDetrConfig(PretrainedConfig):
config_class = CONFIG_MAPPING[backbone_model_type] config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config) backbone_config = config_class.from_dict(backbone_config)
verify_backbone_config_arguments(
use_timm_backbone=use_timm_backbone,
use_pretrained_backbone=use_pretrained_backbone,
backbone=backbone,
backbone_config=backbone_config,
backbone_kwargs=backbone_kwargs,
)
self.use_timm_backbone = use_timm_backbone self.use_timm_backbone = use_timm_backbone
self.backbone_config = backbone_config self.backbone_config = backbone_config
self.num_channels = num_channels self.num_channels = num_channels
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...utils import logging
from ...utils.backbone_utils import verify_backbone_config_arguments
from ..auto import CONFIG_MAPPING from ..auto import CONFIG_MAPPING
...@@ -195,20 +196,6 @@ class DeformableDetrConfig(PretrainedConfig): ...@@ -195,20 +196,6 @@ class DeformableDetrConfig(PretrainedConfig):
disable_custom_kernels=False, disable_custom_kernels=False,
**kwargs, **kwargs,
): ):
if not use_timm_backbone and use_pretrained_backbone:
raise ValueError(
"Loading pretrained backbone weights from the transformers library is not supported yet. `use_timm_backbone` must be set to `True` when `use_pretrained_backbone=True`"
)
if backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
if backbone_config is not None and use_timm_backbone:
raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
# We default to values which were previously hard-coded in the model. This enables configurability of the config # We default to values which were previously hard-coded in the model. This enables configurability of the config
# while keeping the default behavior the same. # while keeping the default behavior the same.
if use_timm_backbone and backbone_kwargs is None: if use_timm_backbone and backbone_kwargs is None:
...@@ -227,6 +214,14 @@ class DeformableDetrConfig(PretrainedConfig): ...@@ -227,6 +214,14 @@ class DeformableDetrConfig(PretrainedConfig):
config_class = CONFIG_MAPPING[backbone_model_type] config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config) backbone_config = config_class.from_dict(backbone_config)
verify_backbone_config_arguments(
use_timm_backbone=use_timm_backbone,
use_pretrained_backbone=use_pretrained_backbone,
backbone=backbone,
backbone_config=backbone_config,
backbone_kwargs=backbone_kwargs,
)
self.use_timm_backbone = use_timm_backbone self.use_timm_backbone = use_timm_backbone
self.backbone_config = backbone_config self.backbone_config = backbone_config
self.num_channels = num_channels self.num_channels = num_channels
......
...@@ -18,6 +18,7 @@ import copy ...@@ -18,6 +18,7 @@ import copy
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...utils import logging
from ...utils.backbone_utils import verify_backbone_config_arguments
from ..auto.configuration_auto import CONFIG_MAPPING from ..auto.configuration_auto import CONFIG_MAPPING
...@@ -44,6 +45,12 @@ class DepthAnythingConfig(PretrainedConfig): ...@@ -44,6 +45,12 @@ class DepthAnythingConfig(PretrainedConfig):
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `False`): use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use pretrained weights for the backbone. Whether to use pretrained weights for the backbone.
use_timm_backbone (`bool`, *optional*, defaults to `False`):
Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
API.
backbone_kwargs (`dict`, *optional*):
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
patch_size (`int`, *optional*, defaults to 14): patch_size (`int`, *optional*, defaults to 14):
The size of the patches to extract from the backbone features. The size of the patches to extract from the backbone features.
initializer_range (`float`, *optional*, defaults to 0.02): initializer_range (`float`, *optional*, defaults to 0.02):
...@@ -83,6 +90,8 @@ class DepthAnythingConfig(PretrainedConfig): ...@@ -83,6 +90,8 @@ class DepthAnythingConfig(PretrainedConfig):
backbone_config=None, backbone_config=None,
backbone=None, backbone=None,
use_pretrained_backbone=False, use_pretrained_backbone=False,
use_timm_backbone=False,
backbone_kwargs=None,
patch_size=14, patch_size=14,
initializer_range=0.02, initializer_range=0.02,
reassemble_hidden_size=384, reassemble_hidden_size=384,
...@@ -94,13 +103,6 @@ class DepthAnythingConfig(PretrainedConfig): ...@@ -94,13 +103,6 @@ class DepthAnythingConfig(PretrainedConfig):
**kwargs, **kwargs,
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
if use_pretrained_backbone:
raise ValueError("Pretrained backbones are not supported yet.")
if backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
if backbone_config is None and backbone is None: if backbone_config is None and backbone is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `Dinov2` backbone.") logger.info("`backbone_config` is `None`. Initializing the config with the default `Dinov2` backbone.")
backbone_config = CONFIG_MAPPING["dinov2"]( backbone_config = CONFIG_MAPPING["dinov2"](
...@@ -116,6 +118,14 @@ class DepthAnythingConfig(PretrainedConfig): ...@@ -116,6 +118,14 @@ class DepthAnythingConfig(PretrainedConfig):
config_class = CONFIG_MAPPING[backbone_model_type] config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config) backbone_config = config_class.from_dict(backbone_config)
verify_backbone_config_arguments(
use_timm_backbone=use_timm_backbone,
use_pretrained_backbone=use_pretrained_backbone,
backbone=backbone,
backbone_config=backbone_config,
backbone_kwargs=backbone_kwargs,
)
self.backbone_config = backbone_config self.backbone_config = backbone_config
self.backbone = backbone self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone self.use_pretrained_backbone = use_pretrained_backbone
......
...@@ -22,6 +22,7 @@ from packaging import version ...@@ -22,6 +22,7 @@ from packaging import version
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig from ...onnx import OnnxConfig
from ...utils import logging from ...utils import logging
from ...utils.backbone_utils import verify_backbone_config_arguments
from ..auto import CONFIG_MAPPING from ..auto import CONFIG_MAPPING
...@@ -176,20 +177,6 @@ class DetrConfig(PretrainedConfig): ...@@ -176,20 +177,6 @@ class DetrConfig(PretrainedConfig):
eos_coefficient=0.1, eos_coefficient=0.1,
**kwargs, **kwargs,
): ):
if not use_timm_backbone and use_pretrained_backbone:
raise ValueError(
"Loading pretrained backbone weights from the transformers library is not supported yet. `use_timm_backbone` must be set to `True` when `use_pretrained_backbone=True`"
)
if backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
if backbone_config is not None and use_timm_backbone:
raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
# We default to values which were previously hard-coded in the model. This enables configurability of the config # We default to values which were previously hard-coded in the model. This enables configurability of the config
# while keeping the default behavior the same. # while keeping the default behavior the same.
if use_timm_backbone and backbone_kwargs is None: if use_timm_backbone and backbone_kwargs is None:
...@@ -211,6 +198,14 @@ class DetrConfig(PretrainedConfig): ...@@ -211,6 +198,14 @@ class DetrConfig(PretrainedConfig):
# set timm attributes to None # set timm attributes to None
dilation = None dilation = None
verify_backbone_config_arguments(
use_timm_backbone=use_timm_backbone,
use_pretrained_backbone=use_pretrained_backbone,
backbone=backbone,
backbone_config=backbone_config,
backbone_kwargs=backbone_kwargs,
)
self.use_timm_backbone = use_timm_backbone self.use_timm_backbone = use_timm_backbone
self.backbone_config = backbone_config self.backbone_config = backbone_config
self.num_channels = num_channels self.num_channels = num_channels
......
...@@ -18,6 +18,7 @@ import copy ...@@ -18,6 +18,7 @@ import copy
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...utils import logging
from ...utils.backbone_utils import verify_backbone_config_arguments
from ..auto.configuration_auto import CONFIG_MAPPING from ..auto.configuration_auto import CONFIG_MAPPING
from ..bit import BitConfig from ..bit import BitConfig
...@@ -179,9 +180,6 @@ class DPTConfig(PretrainedConfig): ...@@ -179,9 +180,6 @@ class DPTConfig(PretrainedConfig):
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.is_hybrid = is_hybrid self.is_hybrid = is_hybrid
if use_pretrained_backbone:
raise ValueError("Pretrained backbones are not supported yet.")
use_autobackbone = False use_autobackbone = False
if self.is_hybrid: if self.is_hybrid:
if backbone_config is None and backbone is None: if backbone_config is None and backbone is None:
...@@ -226,11 +224,13 @@ class DPTConfig(PretrainedConfig): ...@@ -226,11 +224,13 @@ class DPTConfig(PretrainedConfig):
self.backbone_featmap_shape = None self.backbone_featmap_shape = None
self.neck_ignore_stages = [] self.neck_ignore_stages = []
if use_autobackbone and backbone_config is not None and backbone is not None: verify_backbone_config_arguments(
raise ValueError("You can't specify both `backbone` and `backbone_config`.") use_timm_backbone=use_timm_backbone,
use_pretrained_backbone=use_pretrained_backbone,
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None: backbone=backbone,
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.") backbone_config=backbone_config,
backbone_kwargs=backbone_kwargs,
)
self.backbone = backbone self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone self.use_pretrained_backbone = use_pretrained_backbone
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...utils import logging
from ...utils.backbone_utils import verify_backbone_config_arguments
from ..auto import CONFIG_MAPPING from ..auto import CONFIG_MAPPING
...@@ -198,14 +199,6 @@ class GroundingDinoConfig(PretrainedConfig): ...@@ -198,14 +199,6 @@ class GroundingDinoConfig(PretrainedConfig):
layer_norm_eps=1e-5, layer_norm_eps=1e-5,
**kwargs, **kwargs,
): ):
if not use_timm_backbone and use_pretrained_backbone:
raise ValueError(
"Loading pretrained backbone weights from the transformers library is not supported yet. `use_timm_backbone` must be set to `True` when `use_pretrained_backbone=True`"
)
if backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
if backbone_config is None and backbone is None: if backbone_config is None and backbone is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `Swin` backbone.") logger.info("`backbone_config` is `None`. Initializing the config with the default `Swin` backbone.")
backbone_config = CONFIG_MAPPING["swin"]( backbone_config = CONFIG_MAPPING["swin"](
...@@ -221,8 +214,13 @@ class GroundingDinoConfig(PretrainedConfig): ...@@ -221,8 +214,13 @@ class GroundingDinoConfig(PretrainedConfig):
config_class = CONFIG_MAPPING[backbone_model_type] config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config) backbone_config = config_class.from_dict(backbone_config)
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None: verify_backbone_config_arguments(
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.") use_timm_backbone=use_timm_backbone,
use_pretrained_backbone=use_pretrained_backbone,
backbone=backbone,
backbone_config=backbone_config,
backbone_kwargs=backbone_kwargs,
)
if text_config is None: if text_config is None:
text_config = {} text_config = {}
......
...@@ -18,6 +18,7 @@ from typing import Dict, List, Optional ...@@ -18,6 +18,7 @@ from typing import Dict, List, Optional
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...utils import logging
from ...utils.backbone_utils import verify_backbone_config_arguments
from ..auto import CONFIG_MAPPING from ..auto import CONFIG_MAPPING
...@@ -166,12 +167,6 @@ class Mask2FormerConfig(PretrainedConfig): ...@@ -166,12 +167,6 @@ class Mask2FormerConfig(PretrainedConfig):
backbone_kwargs: Optional[Dict] = None, backbone_kwargs: Optional[Dict] = None,
**kwargs, **kwargs,
): ):
if use_pretrained_backbone:
raise ValueError("Pretrained backbones are not supported yet.")
if backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
if backbone_config is None and backbone is None: if backbone_config is None and backbone is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `Swin` backbone.") logger.info("`backbone_config` is `None`. Initializing the config with the default `Swin` backbone.")
backbone_config = CONFIG_MAPPING["swin"]( backbone_config = CONFIG_MAPPING["swin"](
...@@ -186,15 +181,18 @@ class Mask2FormerConfig(PretrainedConfig): ...@@ -186,15 +181,18 @@ class Mask2FormerConfig(PretrainedConfig):
use_absolute_embeddings=False, use_absolute_embeddings=False,
out_features=["stage1", "stage2", "stage3", "stage4"], out_features=["stage1", "stage2", "stage3", "stage4"],
) )
elif isinstance(backbone_config, dict):
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
if isinstance(backbone_config, dict):
backbone_model_type = backbone_config.pop("model_type") backbone_model_type = backbone_config.pop("model_type")
config_class = CONFIG_MAPPING[backbone_model_type] config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config) backbone_config = config_class.from_dict(backbone_config)
verify_backbone_config_arguments(
use_timm_backbone=use_timm_backbone,
use_pretrained_backbone=use_pretrained_backbone,
backbone=backbone,
backbone_config=backbone_config,
backbone_kwargs=backbone_kwargs,
)
# verify that the backbone is supported # verify that the backbone is supported
if backbone_config is not None and backbone_config.model_type not in self.backbones_supported: if backbone_config is not None and backbone_config.model_type not in self.backbones_supported:
logger.warning_once( logger.warning_once(
......
...@@ -18,6 +18,7 @@ from typing import Dict, Optional ...@@ -18,6 +18,7 @@ from typing import Dict, Optional
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...utils import logging
from ...utils.backbone_utils import verify_backbone_config_arguments
from ..auto import CONFIG_MAPPING from ..auto import CONFIG_MAPPING
from ..detr import DetrConfig from ..detr import DetrConfig
from ..swin import SwinConfig from ..swin import SwinConfig
...@@ -126,15 +127,6 @@ class MaskFormerConfig(PretrainedConfig): ...@@ -126,15 +127,6 @@ class MaskFormerConfig(PretrainedConfig):
backbone_kwargs: Optional[Dict] = None, backbone_kwargs: Optional[Dict] = None,
**kwargs, **kwargs,
): ):
if use_pretrained_backbone:
raise ValueError("Pretrained backbones are not supported yet.")
if backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
if backbone_config is None and backbone is None: if backbone_config is None and backbone is None:
# fall back to https://huggingface.co/microsoft/swin-base-patch4-window12-384-in22k # fall back to https://huggingface.co/microsoft/swin-base-patch4-window12-384-in22k
backbone_config = SwinConfig( backbone_config = SwinConfig(
...@@ -148,12 +140,18 @@ class MaskFormerConfig(PretrainedConfig): ...@@ -148,12 +140,18 @@ class MaskFormerConfig(PretrainedConfig):
drop_path_rate=0.3, drop_path_rate=0.3,
out_features=["stage1", "stage2", "stage3", "stage4"], out_features=["stage1", "stage2", "stage3", "stage4"],
) )
elif isinstance(backbone_config, dict):
if isinstance(backbone_config, dict):
backbone_model_type = backbone_config.pop("model_type") backbone_model_type = backbone_config.pop("model_type")
config_class = CONFIG_MAPPING[backbone_model_type] config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config) backbone_config = config_class.from_dict(backbone_config)
verify_backbone_config_arguments(
use_timm_backbone=use_timm_backbone,
use_pretrained_backbone=use_pretrained_backbone,
backbone=backbone,
backbone_config=backbone_config,
backbone_kwargs=backbone_kwargs,
)
# verify that the backbone is supported # verify that the backbone is supported
if backbone_config is not None and backbone_config.model_type not in self.backbones_supported: if backbone_config is not None and backbone_config.model_type not in self.backbones_supported:
logger.warning_once( logger.warning_once(
......
...@@ -18,6 +18,7 @@ from typing import Dict, Optional ...@@ -18,6 +18,7 @@ from typing import Dict, Optional
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...utils import logging
from ...utils.backbone_utils import verify_backbone_config_arguments
from ..auto import CONFIG_MAPPING from ..auto import CONFIG_MAPPING
...@@ -196,12 +197,6 @@ class OneFormerConfig(PretrainedConfig): ...@@ -196,12 +197,6 @@ class OneFormerConfig(PretrainedConfig):
common_stride: int = 4, common_stride: int = 4,
**kwargs, **kwargs,
): ):
if use_pretrained_backbone:
raise ValueError("Pretrained backbones are not supported yet.")
if backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
if backbone_config is None and backbone is None: if backbone_config is None and backbone is None:
logger.info("`backbone_config` is unset. Initializing the config with the default `Swin` backbone.") logger.info("`backbone_config` is unset. Initializing the config with the default `Swin` backbone.")
backbone_config = CONFIG_MAPPING["swin"]( backbone_config = CONFIG_MAPPING["swin"](
...@@ -221,8 +216,13 @@ class OneFormerConfig(PretrainedConfig): ...@@ -221,8 +216,13 @@ class OneFormerConfig(PretrainedConfig):
config_class = CONFIG_MAPPING[backbone_model_type] config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config) backbone_config = config_class.from_dict(backbone_config)
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None: verify_backbone_config_arguments(
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.") use_timm_backbone=use_timm_backbone,
use_pretrained_backbone=use_pretrained_backbone,
backbone=backbone,
backbone_config=backbone_config,
backbone_kwargs=backbone_kwargs,
)
self.backbone_config = backbone_config self.backbone_config = backbone_config
self.backbone = backbone self.backbone = backbone
......
...@@ -22,6 +22,7 @@ from packaging import version ...@@ -22,6 +22,7 @@ from packaging import version
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig from ...onnx import OnnxConfig
from ...utils import logging from ...utils import logging
from ...utils.backbone_utils import verify_backbone_config_arguments
from ..auto import CONFIG_MAPPING from ..auto import CONFIG_MAPPING
...@@ -177,20 +178,6 @@ class TableTransformerConfig(PretrainedConfig): ...@@ -177,20 +178,6 @@ class TableTransformerConfig(PretrainedConfig):
eos_coefficient=0.1, eos_coefficient=0.1,
**kwargs, **kwargs,
): ):
if not use_timm_backbone and use_pretrained_backbone:
raise ValueError(
"Loading pretrained backbone weights from the transformers library is not supported yet. `use_timm_backbone` must be set to `True` when `use_pretrained_backbone=True`"
)
if backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
if backbone_config is not None and use_timm_backbone:
raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
# We default to values which were previously hard-coded in the model. This enables configurability of the config # We default to values which were previously hard-coded in the model. This enables configurability of the config
# while keeping the default behavior the same. # while keeping the default behavior the same.
if use_timm_backbone and backbone_kwargs is None: if use_timm_backbone and backbone_kwargs is None:
...@@ -212,6 +199,14 @@ class TableTransformerConfig(PretrainedConfig): ...@@ -212,6 +199,14 @@ class TableTransformerConfig(PretrainedConfig):
# set timm attributes to None # set timm attributes to None
dilation = None dilation = None
verify_backbone_config_arguments(
use_timm_backbone=use_timm_backbone,
use_pretrained_backbone=use_pretrained_backbone,
backbone=backbone,
backbone_config=backbone_config,
backbone_kwargs=backbone_kwargs,
)
self.use_timm_backbone = use_timm_backbone self.use_timm_backbone = use_timm_backbone
self.backbone_config = backbone_config self.backbone_config = backbone_config
self.num_channels = num_channels self.num_channels = num_channels
......
...@@ -18,6 +18,7 @@ import copy ...@@ -18,6 +18,7 @@ import copy
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...utils import logging
from ...utils.backbone_utils import verify_backbone_config_arguments
from ..auto import CONFIG_MAPPING from ..auto import CONFIG_MAPPING
...@@ -129,12 +130,6 @@ class TvpConfig(PretrainedConfig): ...@@ -129,12 +130,6 @@ class TvpConfig(PretrainedConfig):
**kwargs, **kwargs,
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
if use_pretrained_backbone:
raise ValueError("Pretrained backbones are not supported yet.")
if backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
if backbone_config is None and backbone is None: if backbone_config is None and backbone is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.") logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"]) backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
...@@ -143,8 +138,13 @@ class TvpConfig(PretrainedConfig): ...@@ -143,8 +138,13 @@ class TvpConfig(PretrainedConfig):
config_class = CONFIG_MAPPING[backbone_model_type] config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config) backbone_config = config_class.from_dict(backbone_config)
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None: verify_backbone_config_arguments(
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.") use_timm_backbone=use_timm_backbone,
use_pretrained_backbone=use_pretrained_backbone,
backbone=backbone,
backbone_config=backbone_config,
backbone_kwargs=backbone_kwargs,
)
self.backbone_config = backbone_config self.backbone_config = backbone_config
self.backbone = backbone self.backbone = backbone
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...utils import logging
from ...utils.backbone_utils import verify_backbone_config_arguments
from ..auto.configuration_auto import CONFIG_MAPPING from ..auto.configuration_auto import CONFIG_MAPPING
...@@ -103,12 +104,6 @@ class UperNetConfig(PretrainedConfig): ...@@ -103,12 +104,6 @@ class UperNetConfig(PretrainedConfig):
**kwargs, **kwargs,
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
if use_pretrained_backbone:
raise ValueError("Pretrained backbones are not supported yet.")
if backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
if backbone_config is None and backbone is None: if backbone_config is None and backbone is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.") logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage1", "stage2", "stage3", "stage4"]) backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage1", "stage2", "stage3", "stage4"])
...@@ -117,8 +112,13 @@ class UperNetConfig(PretrainedConfig): ...@@ -117,8 +112,13 @@ class UperNetConfig(PretrainedConfig):
config_class = CONFIG_MAPPING[backbone_model_type] config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config) backbone_config = config_class.from_dict(backbone_config)
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None: verify_backbone_config_arguments(
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.") use_timm_backbone=use_timm_backbone,
use_pretrained_backbone=use_pretrained_backbone,
backbone=backbone,
backbone_config=backbone_config,
backbone_kwargs=backbone_kwargs,
)
self.backbone_config = backbone_config self.backbone_config = backbone_config
self.backbone = backbone self.backbone = backbone
......
...@@ -19,6 +19,7 @@ from typing import List ...@@ -19,6 +19,7 @@ from typing import List
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...utils import logging from ...utils import logging
from ...utils.backbone_utils import verify_backbone_config_arguments
from ..auto.configuration_auto import CONFIG_MAPPING from ..auto.configuration_auto import CONFIG_MAPPING
...@@ -94,12 +95,6 @@ class VitMatteConfig(PretrainedConfig): ...@@ -94,12 +95,6 @@ class VitMatteConfig(PretrainedConfig):
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
if use_pretrained_backbone:
raise ValueError("Pretrained backbones are not supported yet.")
if backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
if backbone_config is None and backbone is None: if backbone_config is None and backbone is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `VitDet` backbone.") logger.info("`backbone_config` is `None`. Initializing the config with the default `VitDet` backbone.")
backbone_config = CONFIG_MAPPING["vitdet"](out_features=["stage4"]) backbone_config = CONFIG_MAPPING["vitdet"](out_features=["stage4"])
...@@ -108,8 +103,13 @@ class VitMatteConfig(PretrainedConfig): ...@@ -108,8 +103,13 @@ class VitMatteConfig(PretrainedConfig):
config_class = CONFIG_MAPPING[backbone_model_type] config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config) backbone_config = config_class.from_dict(backbone_config)
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None: verify_backbone_config_arguments(
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.") use_timm_backbone=use_timm_backbone,
use_pretrained_backbone=use_pretrained_backbone,
backbone=backbone,
backbone_config=backbone_config,
backbone_kwargs=backbone_kwargs,
)
self.backbone_config = backbone_config self.backbone_config = backbone_config
self.backbone = backbone self.backbone = backbone
......
...@@ -17,7 +17,11 @@ ...@@ -17,7 +17,11 @@
import enum import enum
import inspect import inspect
from typing import Iterable, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
if TYPE_CHECKING:
from .configuration_utils import PretrainedConfig
class BackboneType(enum.Enum): class BackboneType(enum.Enum):
...@@ -352,3 +356,28 @@ def load_backbone(config): ...@@ -352,3 +356,28 @@ def load_backbone(config):
backbone_config = AutoConfig.from_pretrained(backbone_checkpoint, **backbone_kwargs) backbone_config = AutoConfig.from_pretrained(backbone_checkpoint, **backbone_kwargs)
backbone = AutoBackbone.from_config(config=backbone_config) backbone = AutoBackbone.from_config(config=backbone_config)
return backbone return backbone
def verify_backbone_config_arguments(
    use_timm_backbone: bool,
    use_pretrained_backbone: bool,
    backbone: Optional[str],
    backbone_config: Optional[Union[dict, "PretrainedConfig"]],
    backbone_kwargs: Optional[dict],
):
    """
    Verify that the config arguments to be passed to load_backbone are valid.

    Args:
        use_timm_backbone (`bool`):
            Whether the backbone is loaded through the `timm` library.
        use_pretrained_backbone (`bool`):
            Whether pretrained backbone weights should be loaded.
        backbone (`str`, *optional*):
            Name of the backbone checkpoint to load.
        backbone_config (`dict` or `PretrainedConfig`, *optional*):
            Configuration used to build the backbone from scratch.
        backbone_kwargs (`dict`, *optional*):
            Keyword arguments forwarded to `AutoBackbone` when loading from a checkpoint.

    Raises:
        ValueError: If a mutually exclusive or unsupported combination of arguments is given.
    """
    if not use_timm_backbone and use_pretrained_backbone:
        raise ValueError(
            "Loading pretrained backbone weights from the transformers library is not supported yet. `use_timm_backbone` must be set to `True` when `use_pretrained_backbone=True`"
        )
    if backbone_config is not None and backbone is not None:
        raise ValueError("You can't specify both `backbone` and `backbone_config`.")
    if backbone_config is not None and use_timm_backbone:
        raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
    # Truthiness already implies `backbone_kwargs` is not None; an empty dict is deliberately
    # allowed to coexist with `backbone_config` (matches the behavior of the original check).
    if backbone_kwargs and backbone_config is not None:
        raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment