Unverified Commit 2fa1c808 authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

[`Backbone`] Use `load_backbone` instead of `AutoBackbone.from_config` (#28661)

* Enable instantiating model with pretrained backbone weights

* Remove doc updates until changes made in modeling code

* Use load_backbone instead

* Add use_timm_backbone to the model configs

* Add missing imports and arguments

* Update docstrings

* Make sure test is properly configured

* Include recent DPT updates
parent c24c5245
......@@ -48,6 +48,9 @@ class VitMatteConfig(PretrainedConfig):
is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
Whether to use pretrained weights for the backbone.
use_timm_backbone (`bool`, *optional*, defaults to `False`):
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
library.
hidden_size (`int`, *optional*, defaults to 384):
The number of input channels of the decoder.
batch_norm_eps (`float`, *optional*, defaults to 1e-05):
......@@ -81,6 +84,7 @@ class VitMatteConfig(PretrainedConfig):
backbone_config: PretrainedConfig = None,
backbone=None,
use_pretrained_backbone=False,
use_timm_backbone=False,
hidden_size: int = 384,
batch_norm_eps: float = 1e-5,
initializer_range: float = 0.02,
......@@ -107,6 +111,7 @@ class VitMatteConfig(PretrainedConfig):
self.backbone_config = backbone_config
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
self.batch_norm_eps = batch_norm_eps
self.hidden_size = hidden_size
self.initializer_range = initializer_range
......
......@@ -20,7 +20,6 @@ from typing import Optional, Tuple
import torch
from torch import nn
from ... import AutoBackbone
from ...modeling_utils import PreTrainedModel
from ...utils import (
ModelOutput,
......@@ -28,6 +27,7 @@ from ...utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...utils.backbone_utils import load_backbone
from .configuration_vitmatte import VitMatteConfig
......@@ -259,7 +259,7 @@ class VitMatteForImageMatting(VitMattePreTrainedModel):
super().__init__(config)
self.config = config
self.backbone = AutoBackbone.from_config(config.backbone_config)
self.backbone = load_backbone(config)
self.decoder = VitMatteDetailCaptureModule(config)
# Initialize weights and apply final processing
......
......@@ -443,6 +443,7 @@ class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
# let's pick a random timm backbone
config.backbone = "tf_mobilenetv3_small_075"
config.use_timm_backbone = True
for model_class in self.all_model_classes:
model = model_class(config)
......
......@@ -219,7 +219,11 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s
"out_features",
"out_indices",
"sampling_rate",
# backbone related arguments passed to load_backbone
"use_pretrained_backbone",
"backbone",
"backbone_config",
"use_timm_backbone",
]
attributes_used_in_generation = ["encoder_no_repeat_ngram_size"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment