Unverified commit 2fa1c808 authored by amyeroberts, committed by GitHub

[`Backbone`] Use `load_backbone` instead of `AutoBackbone.from_config` (#28661)

* Enable instantiating model with pretrained backbone weights

* Remove doc updates until changes made in modeling code

* Use load_backbone instead

* Add use_timm_backbone to the model configs

* Add missing imports and arguments

* Update docstrings

* Make sure test is properly configured

* Include recent DPT updates
parent c24c5245
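
To put the change in context, here is a minimal usage sketch of the pattern this commit switches the models to, assuming a `transformers` build that includes this change; the config values are only illustrative.

```python
# Minimal sketch, assuming a transformers version that ships this change.
from transformers import VitMatteConfig
from transformers.utils.backbone_utils import load_backbone

config = VitMatteConfig()  # falls back to the default backbone config when none is given

# Before: the backbone was built directly from the nested backbone config.
#     backbone = AutoBackbone.from_config(config.backbone_config)
# After: load_backbone reads `backbone`, `backbone_config`, `use_pretrained_backbone`
# and `use_timm_backbone` from the model config and resolves the backbone itself.
backbone = load_backbone(config)
```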
```diff
@@ -48,6 +48,9 @@ class VitMatteConfig(PretrainedConfig):
             is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
         use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
             Whether to use pretrained weights for the backbone.
+        use_timm_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
+            library.
         hidden_size (`int`, *optional*, defaults to 384):
             The number of input channels of the decoder.
         batch_norm_eps (`float`, *optional*, defaults to 1e-05):
@@ -81,6 +84,7 @@ class VitMatteConfig(PretrainedConfig):
         backbone_config: PretrainedConfig = None,
         backbone=None,
         use_pretrained_backbone=False,
+        use_timm_backbone=False,
         hidden_size: int = 384,
         batch_norm_eps: float = 1e-5,
         initializer_range: float = 0.02,
@@ -107,6 +111,7 @@ class VitMatteConfig(PretrainedConfig):
         self.backbone_config = backbone_config
         self.backbone = backbone
         self.use_pretrained_backbone = use_pretrained_backbone
+        self.use_timm_backbone = use_timm_backbone
         self.batch_norm_eps = batch_norm_eps
         self.hidden_size = hidden_size
         self.initializer_range = initializer_range
```
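
As a hedged illustration of the new `use_timm_backbone` flag documented above, a VitMatte config can now point at a timm backbone directly; `"resnet18"` below is an arbitrary timm model name chosen for the example.

```python
# Sketch only: source the VitMatte backbone from timm rather than from a
# transformers backbone_config. "resnet18" is an arbitrary timm model name.
from transformers import VitMatteConfig

timm_config = VitMatteConfig(
    backbone="resnet18",            # timm architecture/checkpoint name
    use_timm_backbone=True,         # resolve the backbone through timm
    use_pretrained_backbone=False,  # keep randomly initialized weights
)
```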
```diff
@@ -20,7 +20,6 @@ from typing import Optional, Tuple
 import torch
 from torch import nn

-from ... import AutoBackbone
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
     ModelOutput,
@@ -28,6 +27,7 @@ from ...utils import (
     add_start_docstrings_to_model_forward,
     replace_return_docstrings,
 )
+from ...utils.backbone_utils import load_backbone
 from .configuration_vitmatte import VitMatteConfig
@@ -259,7 +259,7 @@ class VitMatteForImageMatting(VitMattePreTrainedModel):
         super().__init__(config)
         self.config = config

-        self.backbone = AutoBackbone.from_config(config.backbone_config)
+        self.backbone = load_backbone(config)
         self.decoder = VitMatteDetailCaptureModule(config)

         # Initialize weights and apply final processing
```
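
With the modeling change above, the backbone is resolved from the full model config rather than from `config.backbone_config` alone; a short instantiation sketch (default config, random weights):

```python
# Sketch: VitMatteForImageMatting now builds its backbone via load_backbone(config),
# so backbone_config, backbone, use_pretrained_backbone and use_timm_backbone on the
# model config together decide where the backbone comes from.
from transformers import VitMatteConfig, VitMatteForImageMatting

config = VitMatteConfig()                # default, randomly initialized backbone
model = VitMatteForImageMatting(config)  # backbone created by load_backbone
print(type(model.backbone).__name__)
```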
```diff
@@ -443,6 +443,7 @@ class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
         # let's pick a random timm backbone
         config.backbone = "tf_mobilenetv3_small_075"
+        config.use_timm_backbone = True

         for model_class in self.all_model_classes:
             model = model_class(config)
```
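
Mirroring the test change, configs that name a timm backbone now set the flag explicitly; a sketch outside the test suite, reusing the same timm backbone name:

```python
# Sketch: the same setup as the test, on a standalone ConditionalDETR config.
from transformers import ConditionalDetrConfig

config = ConditionalDetrConfig()
config.backbone = "tf_mobilenetv3_small_075"  # timm backbone used in the test
config.use_timm_backbone = True               # tell load_backbone to go through timm
```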
```diff
@@ -219,7 +219,11 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s
         "out_features",
         "out_indices",
         "sampling_rate",
+        # backbone related arguments passed to load_backbone
         "use_pretrained_backbone",
+        "backbone",
+        "backbone_config",
+        "use_timm_backbone",
     ]
     attributes_used_in_generation = ["encoder_no_repeat_ngram_size"]
```