"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "17ea43cf985829634bd86b36b44e5410c6f83e36"
Unverified Commit 0199a484 authored by amyeroberts, committed by GitHub

Backbone kwargs in config (#28784)



* Enable instantiating model with pretrained backbone weights

* Clarify pretrained import

* Use load_backbone instead

* Add backbone_kwargs to config

* Pass kwargs to constructors

* Fix up

* Input verification

* Add tests

* Tidy up

* Update tests/utils/test_backbone_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 725f4ad1
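The new `backbone_kwargs` config argument forwards loading options such as `out_indices` to `AutoBackbone` when the backbone is built from a checkpoint, without requiring a full nested `backbone_config`. A minimal usage sketch, mirroring the tests added in this commit (it assumes the `microsoft/resnet-18` checkpoint can be downloaded):

```python
from transformers import MaskFormerConfig
from transformers.utils.backbone_utils import load_backbone

# Configure the backbone via a checkpoint plus kwargs instead of a nested backbone_config.
config = MaskFormerConfig(
    backbone="microsoft/resnet-18",
    backbone_kwargs={"out_indices": (0, 2)},  # forwarded to AutoBackbone when loading
)

backbone = load_backbone(config)
print(backbone.out_indices)   # (0, 2)
print(backbone.out_features)  # ["stem", "stage2"]
```

Specifying both `backbone_kwargs` and `backbone_config` raises a `ValueError` in every updated config class.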
......@@ -98,6 +98,9 @@ class ConditionalDetrConfig(PretrainedConfig):
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
Whether to use pretrained weights for the backbone.
backbone_kwargs (`dict`, *optional*):
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
dilation (`bool`, *optional*, defaults to `False`):
Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
`use_timm_backbone` = `True`.
......@@ -168,6 +171,7 @@ class ConditionalDetrConfig(PretrainedConfig):
position_embedding_type="sine",
backbone="resnet50",
use_pretrained_backbone=True,
backbone_kwargs=None,
dilation=False,
class_cost=2,
bbox_cost=5,
......@@ -191,6 +195,9 @@ class ConditionalDetrConfig(PretrainedConfig):
if backbone_config is not None and use_timm_backbone:
raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
if not use_timm_backbone:
if backbone_config is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
......@@ -224,6 +231,7 @@ class ConditionalDetrConfig(PretrainedConfig):
self.position_embedding_type = position_embedding_type
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.backbone_kwargs = backbone_kwargs
self.dilation = dilation
# Hungarian matcher
self.class_cost = class_cost
......
......@@ -90,6 +90,9 @@ class DeformableDetrConfig(PretrainedConfig):
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
Whether to use pretrained weights for the backbone.
backbone_kwargs (`dict`, *optional*):
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
dilation (`bool`, *optional*, defaults to `False`):
Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
`use_timm_backbone` = `True`.
......@@ -177,6 +180,7 @@ class DeformableDetrConfig(PretrainedConfig):
position_embedding_type="sine",
backbone="resnet50",
use_pretrained_backbone=True,
backbone_kwargs=None,
dilation=False,
num_feature_levels=4,
encoder_n_points=4,
......@@ -207,6 +211,9 @@ class DeformableDetrConfig(PretrainedConfig):
if backbone_config is not None and use_timm_backbone:
raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
if not use_timm_backbone:
if backbone_config is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
......@@ -238,6 +245,7 @@ class DeformableDetrConfig(PretrainedConfig):
self.position_embedding_type = position_embedding_type
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.backbone_kwargs = backbone_kwargs
self.dilation = dilation
# deformable attributes
self.num_feature_levels = num_feature_levels
......
......@@ -49,6 +49,9 @@ class DetaConfig(PretrainedConfig):
use_timm_backbone (`bool`, *optional*, `False`):
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
library.
backbone_kwargs (`dict`, *optional*):
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
num_queries (`int`, *optional*, defaults to 900):
Number of object queries, i.e. detection slots. This is the maximal number of objects [`DetaModel`] can
detect in a single image. In case `two_stage` is set to `True`, we use `two_stage_num_proposals` instead.
......@@ -150,6 +153,7 @@ class DetaConfig(PretrainedConfig):
backbone=None,
use_pretrained_backbone=False,
use_timm_backbone=False,
backbone_kwargs=None,
num_queries=900,
max_position_embeddings=2048,
encoder_layers=6,
......@@ -204,10 +208,14 @@ class DetaConfig(PretrainedConfig):
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
self.backbone_config = backbone_config
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
self.backbone_kwargs = backbone_kwargs
self.num_queries = num_queries
self.max_position_embeddings = max_position_embeddings
self.d_model = d_model
......
......@@ -98,6 +98,9 @@ class DetrConfig(PretrainedConfig):
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, `True`):
Whether to use pretrained weights for the backbone.
backbone_kwargs (`dict`, *optional*):
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
dilation (`bool`, *optional*, defaults to `False`):
Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
`use_timm_backbone` = `True`.
......@@ -166,6 +169,7 @@ class DetrConfig(PretrainedConfig):
position_embedding_type="sine",
backbone="resnet50",
use_pretrained_backbone=True,
backbone_kwargs=None,
dilation=False,
class_cost=1,
bbox_cost=5,
......@@ -188,6 +192,9 @@ class DetrConfig(PretrainedConfig):
if backbone_config is not None and use_timm_backbone:
raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
if not use_timm_backbone:
if backbone_config is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
......@@ -223,6 +230,7 @@ class DetrConfig(PretrainedConfig):
self.position_embedding_type = position_embedding_type
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.backbone_kwargs = backbone_kwargs
self.dilation = dilation
# Hungarian matcher
self.class_cost = class_cost
......
......@@ -120,6 +120,9 @@ class DPTConfig(PretrainedConfig):
use_timm_backbone (`bool`, *optional*, defaults to `False`):
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
library.
backbone_kwargs (`dict`, *optional*):
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
Example:
......@@ -173,6 +176,7 @@ class DPTConfig(PretrainedConfig):
backbone=None,
use_pretrained_backbone=False,
use_timm_backbone=False,
backbone_kwargs=None,
**kwargs,
):
super().__init__(**kwargs)
......@@ -230,9 +234,13 @@ class DPTConfig(PretrainedConfig):
if use_autobackbone and backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
self.backbone_kwargs = backbone_kwargs
self.num_hidden_layers = None if use_autobackbone else num_hidden_layers
self.num_attention_heads = None if use_autobackbone else num_attention_heads
self.intermediate_size = None if use_autobackbone else intermediate_size
......
......@@ -56,6 +56,9 @@ class Mask2FormerConfig(PretrainedConfig):
use_timm_backbone (`bool`, *optional*, `False`):
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
library.
backbone_kwargs (`dict`, *optional*):
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
feature_size (`int`, *optional*, defaults to 256):
The features (channels) of the resulting feature maps.
mask_feature_size (`int`, *optional*, defaults to 256):
......@@ -163,9 +166,10 @@ class Mask2FormerConfig(PretrainedConfig):
use_auxiliary_loss: bool = True,
feature_strides: List[int] = [4, 8, 16, 32],
output_auxiliary_logits: bool = None,
backbone=None,
use_pretrained_backbone=False,
use_timm_backbone=False,
backbone: Optional[str] = None,
use_pretrained_backbone: bool = False,
use_timm_backbone: bool = False,
backbone_kwargs: Optional[Dict] = None,
**kwargs,
):
if use_pretrained_backbone:
......@@ -189,6 +193,9 @@ class Mask2FormerConfig(PretrainedConfig):
out_features=["stage1", "stage2", "stage3", "stage4"],
)
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
if isinstance(backbone_config, dict):
backbone_model_type = backbone_config.pop("model_type")
config_class = CONFIG_MAPPING[backbone_model_type]
......@@ -233,6 +240,7 @@ class Mask2FormerConfig(PretrainedConfig):
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
self.backbone_kwargs = backbone_kwargs
super().__init__(**kwargs)
......
......@@ -66,6 +66,9 @@ class MaskFormerConfig(PretrainedConfig):
use_timm_backbone (`bool`, *optional*, `False`):
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
library.
backbone_kwargs (`dict`, *optional*):
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
decoder_config (`Dict`, *optional*):
The configuration passed to the transformer decoder model, if unset the base config for `detr-resnet-50`
will be used.
......@@ -126,6 +129,7 @@ class MaskFormerConfig(PretrainedConfig):
backbone: Optional[str] = None,
use_pretrained_backbone: bool = False,
use_timm_backbone: bool = False,
backbone_kwargs: Optional[Dict] = None,
**kwargs,
):
if use_pretrained_backbone:
......@@ -134,6 +138,9 @@ class MaskFormerConfig(PretrainedConfig):
if backbone_config is not None and backbone is not None:
raise ValueError("You can't specify both `backbone` and `backbone_config`.")
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
if backbone_config is None and backbone is None:
# fall back to https://huggingface.co/microsoft/swin-base-patch4-window12-384-in22k
backbone_config = SwinConfig(
......@@ -198,6 +205,7 @@ class MaskFormerConfig(PretrainedConfig):
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
self.backbone_kwargs = backbone_kwargs
super().__init__(**kwargs)
@classmethod
......
......@@ -53,6 +53,9 @@ class OneFormerConfig(PretrainedConfig):
use_timm_backbone (`bool`, *optional*, defaults to `False`):
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
library.
backbone_kwargs (`dict`, *optional*):
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
ignore_value (`int`, *optional*, defaults to 255):
Values to be ignored in GT label while calculating loss.
num_queries (`int`, *optional*, defaults to 150):
......@@ -156,6 +159,7 @@ class OneFormerConfig(PretrainedConfig):
backbone: Optional[str] = None,
use_pretrained_backbone: bool = False,
use_timm_backbone: bool = False,
backbone_kwargs: Optional[Dict] = None,
ignore_value: int = 255,
num_queries: int = 150,
no_object_weight: int = 0.1,
......@@ -223,10 +227,14 @@ class OneFormerConfig(PretrainedConfig):
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
self.backbone_config = backbone_config
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
self.backbone_kwargs = backbone_kwargs
self.ignore_value = ignore_value
self.num_queries = num_queries
self.no_object_weight = no_object_weight
......
......@@ -98,6 +98,9 @@ class TableTransformerConfig(PretrainedConfig):
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, `True`):
Whether to use pretrained weights for the backbone.
backbone_kwargs (`dict`, *optional*):
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
dilation (`bool`, *optional*, defaults to `False`):
Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
`use_timm_backbone` = `True`.
......@@ -167,6 +170,7 @@ class TableTransformerConfig(PretrainedConfig):
position_embedding_type="sine",
backbone="resnet50",
use_pretrained_backbone=True,
backbone_kwargs=None,
dilation=False,
class_cost=1,
bbox_cost=5,
......@@ -189,6 +193,9 @@ class TableTransformerConfig(PretrainedConfig):
if backbone_config is not None and use_timm_backbone:
raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
if not use_timm_backbone:
if backbone_config is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
......@@ -224,6 +231,7 @@ class TableTransformerConfig(PretrainedConfig):
self.position_embedding_type = position_embedding_type
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.backbone_kwargs = backbone_kwargs
self.dilation = dilation
# Hungarian matcher
self.class_cost = class_cost
......
......@@ -52,6 +52,9 @@ class TvpConfig(PretrainedConfig):
use_timm_backbone (`bool`, *optional*, defaults to `False`):
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
library.
backbone_kwargs (`dict`, *optional*):
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
distance_loss_weight (`float`, *optional*, defaults to 1.0):
The weight of distance loss.
duration_loss_weight (`float`, *optional*, defaults to 0.1):
......@@ -107,6 +110,7 @@ class TvpConfig(PretrainedConfig):
backbone=None,
use_pretrained_backbone=False,
use_timm_backbone=False,
backbone_kwargs=None,
distance_loss_weight=1.0,
duration_loss_weight=0.1,
visual_prompter_type="framepad",
......@@ -144,10 +148,14 @@ class TvpConfig(PretrainedConfig):
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
self.backbone_config = backbone_config
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
self.backbone_kwargs = backbone_kwargs
self.distance_loss_weight = distance_loss_weight
self.duration_loss_weight = duration_loss_weight
self.visual_prompter_type = visual_prompter_type
......
......@@ -45,6 +45,9 @@ class UperNetConfig(PretrainedConfig):
use_timm_backbone (`bool`, *optional*, `False`):
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
library.
backbone_kwargs (`dict`, *optional*):
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
hidden_size (`int`, *optional*, defaults to 512):
The number of hidden units in the convolutional layers.
initializer_range (`float`, *optional*, defaults to 0.02):
......@@ -87,6 +90,7 @@ class UperNetConfig(PretrainedConfig):
backbone=None,
use_pretrained_backbone=False,
use_timm_backbone=False,
backbone_kwargs=None,
hidden_size=512,
initializer_range=0.02,
pool_scales=[1, 2, 3, 6],
......@@ -114,10 +118,14 @@ class UperNetConfig(PretrainedConfig):
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
self.backbone_config = backbone_config
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
self.backbone_kwargs = backbone_kwargs
self.hidden_size = hidden_size
self.initializer_range = initializer_range
self.pool_scales = pool_scales
......
......@@ -51,6 +51,9 @@ class ViTHybridConfig(PretrainedConfig):
use_timm_backbone (`bool`, *optional*, defaults to `False`):
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
library.
backbone_kwargs (`dict`, *optional*):
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (`int`, *optional*, defaults to 12):
......@@ -104,6 +107,7 @@ class ViTHybridConfig(PretrainedConfig):
backbone=None,
use_pretrained_backbone=False,
use_timm_backbone=False,
backbone_kwargs=None,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
......@@ -137,6 +141,9 @@ class ViTHybridConfig(PretrainedConfig):
"embedding_dynamic_padding": True,
}
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
if isinstance(backbone_config, dict):
if "model_type" in backbone_config:
backbone_config_class = CONFIG_MAPPING[backbone_config["model_type"]]
......@@ -152,6 +159,7 @@ class ViTHybridConfig(PretrainedConfig):
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
self.backbone_kwargs = backbone_kwargs
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
......
......@@ -51,6 +51,9 @@ class VitMatteConfig(PretrainedConfig):
use_timm_backbone (`bool`, *optional*, defaults to `False`):
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
library.
backbone_kwargs (`dict`, *optional*):
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
hidden_size (`int`, *optional*, defaults to 384):
The number of input channels of the decoder.
batch_norm_eps (`float`, *optional*, defaults to 1e-05):
......@@ -85,6 +88,7 @@ class VitMatteConfig(PretrainedConfig):
backbone=None,
use_pretrained_backbone=False,
use_timm_backbone=False,
backbone_kwargs=None,
hidden_size: int = 384,
batch_norm_eps: float = 1e-5,
initializer_range: float = 0.02,
......@@ -108,10 +112,14 @@ class VitMatteConfig(PretrainedConfig):
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
self.backbone_config = backbone_config
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
self.backbone_kwargs = backbone_kwargs
self.batch_norm_eps = batch_norm_eps
self.hidden_size = hidden_size
self.initializer_range = initializer_range
......
......@@ -304,6 +304,12 @@ def load_backbone(config):
use_timm_backbone = getattr(config, "use_timm_backbone", None)
use_pretrained_backbone = getattr(config, "use_pretrained_backbone", None)
backbone_checkpoint = getattr(config, "backbone", None)
backbone_kwargs = getattr(config, "backbone_kwargs", None)
backbone_kwargs = {} if backbone_kwargs is None else backbone_kwargs
if backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
# If there is a backbone_config and a backbone checkpoint, and use_pretrained_backbone=False then the desired
# behaviour is ill-defined: do you want to load from the checkpoint's config or the backbone_config?
......@@ -317,7 +323,7 @@ def load_backbone(config):
and backbone_checkpoint is None
and backbone_checkpoint is None
):
return AutoBackbone.from_config(config=config)
return AutoBackbone.from_config(config=config, **backbone_kwargs)
# config from the parent model that has a backbone
if use_timm_backbone:
......@@ -326,16 +332,19 @@ def load_backbone(config):
# Because of how timm backbones were originally added to models, we need to pass in use_pretrained_backbone
# to determine whether to load the pretrained weights.
backbone = AutoBackbone.from_pretrained(
backbone_checkpoint, use_timm_backbone=use_timm_backbone, use_pretrained_backbone=use_pretrained_backbone
backbone_checkpoint,
use_timm_backbone=use_timm_backbone,
use_pretrained_backbone=use_pretrained_backbone,
**backbone_kwargs,
)
elif use_pretrained_backbone:
if backbone_checkpoint is None:
raise ValueError("config.backbone must be set if use_pretrained_backbone is True")
backbone = AutoBackbone.from_pretrained(backbone_checkpoint)
backbone = AutoBackbone.from_pretrained(backbone_checkpoint, **backbone_kwargs)
else:
if backbone_config is None and backbone_checkpoint is None:
raise ValueError("Either config.backbone_config or config.backbone must be set")
if backbone_config is None:
backbone_config = AutoConfig.from_pretrained(backbone_checkpoint)
backbone_config = AutoConfig.from_pretrained(backbone_checkpoint, **backbone_kwargs)
backbone = AutoBackbone.from_config(config=backbone_config)
return backbone
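With this change, `load_backbone` forwards `config.backbone_kwargs` to `AutoBackbone.from_pretrained`, `AutoBackbone.from_config`, or `AutoConfig.from_pretrained`, depending on which loading path applies. A short sketch of the two checkpoint paths and of the mutual-exclusivity check, mirroring the new tests below (the timm case assumes `timm` is installed and both checkpoints are reachable):

```python
from transformers import MaskFormerConfig, ResNetConfig
from transformers.utils.backbone_utils import load_backbone

# timm path: kwargs are forwarded to AutoBackbone.from_pretrained
config = MaskFormerConfig(backbone="resnet18", use_timm_backbone=True, backbone_kwargs={"out_indices": (0, 1)})
assert load_backbone(config).out_indices == (0, 1)

# transformers path: kwargs are forwarded to AutoConfig.from_pretrained for the backbone
config = MaskFormerConfig(backbone="microsoft/resnet-18", backbone_kwargs={"out_indices": (0, 2)})
assert load_backbone(config).out_indices == (0, 2)

# backbone_kwargs and backbone_config are mutually exclusive
try:
    MaskFormerConfig(
        backbone_config=ResNetConfig(out_indices=(0, 2)),
        backbone_kwargs={"out_indices": (0, 1)},
    )
except ValueError as err:
    print(err)  # You can't specify both `backbone_kwargs` and `backbone_config`.
```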
......@@ -16,7 +16,7 @@ import unittest
import pytest
from transformers import DetrConfig, MaskFormerConfig
from transformers import DetrConfig, MaskFormerConfig, ResNetBackbone, ResNetConfig, TimmBackbone
from transformers.testing_utils import require_torch, slow
from transformers.utils.backbone_utils import (
BackboneMixin,
......@@ -137,6 +137,65 @@ class BackboneUtilsTester(unittest.TestCase):
self.assertEqual(backbone.out_features, ["a", "c"])
self.assertEqual(backbone.out_indices, [-3, -1])
@slow
@require_torch
def test_load_backbone_from_config(self):
"""
Test that load_backbone correctly loads a backbone from a backbone config.
"""
config = MaskFormerConfig(backbone_config=ResNetConfig(out_indices=(0, 2)))
backbone = load_backbone(config)
self.assertEqual(backbone.out_features, ["stem", "stage2"])
self.assertEqual(backbone.out_indices, (0, 2))
self.assertIsInstance(backbone, ResNetBackbone)
@slow
@require_torch
def test_load_backbone_from_checkpoint(self):
"""
Test that load_backbone correctly loads a backbone from a checkpoint.
"""
config = MaskFormerConfig(backbone="microsoft/resnet-18", backbone_config=None)
backbone = load_backbone(config)
self.assertEqual(backbone.out_indices, [4])
self.assertEqual(backbone.out_features, ["stage4"])
self.assertIsInstance(backbone, ResNetBackbone)
config = MaskFormerConfig(
backbone="resnet18",
use_timm_backbone=True,
)
backbone = load_backbone(config)
# We can't know ahead of time the exact output features and indices, or the layer names before
# creating the timm model, so it defaults to the last layer (-1,) and has a different layer name
self.assertEqual(backbone.out_indices, (-1,))
self.assertEqual(backbone.out_features, ["layer4"])
self.assertIsInstance(backbone, TimmBackbone)
@slow
@require_torch
def test_load_backbone_backbone_kwargs(self):
"""
Test that load_backbone correctly configures the loaded backbone with the provided kwargs.
"""
config = MaskFormerConfig(backbone="resnet18", use_timm_backbone=True, backbone_kwargs={"out_indices": (0, 1)})
backbone = load_backbone(config)
self.assertEqual(backbone.out_indices, (0, 1))
self.assertIsInstance(backbone, TimmBackbone)
config = MaskFormerConfig(backbone="microsoft/resnet-18", backbone_kwargs={"out_indices": (0, 2)})
backbone = load_backbone(config)
self.assertEqual(backbone.out_indices, (0, 2))
self.assertIsInstance(backbone, ResNetBackbone)
# Check can't be passed with a backbone config
with pytest.raises(ValueError):
config = MaskFormerConfig(
backbone="microsoft/resnet-18",
backbone_config=ResNetConfig(out_indices=(0, 2)),
backbone_kwargs={"out_indices": (0, 1)},
)
@slow
@require_torch
def test_load_backbone_in_new_model(self):
......
......@@ -224,6 +224,7 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s
"backbone",
"backbone_config",
"use_timm_backbone",
"backbone_kwargs",
]
attributes_used_in_generation = ["encoder_no_repeat_ngram_size"]
......