Unverified Commit 27c79a0f authored by amyeroberts, committed by GitHub

Enable instantiating model with pretrained backbone weights (#28214)



* Enable instantiating model with pretrained backbone weights

* Update tests so backbone checkpoint isn't passed in

* Remove doc updates until changes made in modeling code

* Clarify pretrained import

* Update configs - docs and validation check

* Update src/transformers/utils/backbone_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Clarify exception message

* Update config init in tests

* Add test for when use_timm_backbone=True

* Small test updates

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 008a6a22
@@ -602,10 +602,6 @@ class _BaseAutoBackboneClass(_BaseAutoModelClass):
         config = kwargs.pop("config", TimmBackboneConfig())

-        use_timm = kwargs.pop("use_timm_backbone", True)
-        if not use_timm:
-            raise ValueError("`use_timm_backbone` must be `True` for timm backbones")
-
         if kwargs.get("out_features", None) is not None:
             raise ValueError("Cannot specify `out_features` for timm backbones")
@@ -627,7 +623,8 @@ class _BaseAutoBackboneClass(_BaseAutoModelClass):
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        if kwargs.get("use_timm_backbone", False):
+        use_timm_backbone = kwargs.pop("use_timm_backbone", False)
+        if use_timm_backbone:
             return cls._load_timm_backbone_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
         return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
...
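With the check removed from `_load_timm_backbone_from_pretrained` and the kwarg popped in `from_pretrained`, a timm backbone can be requested directly through the auto class. A minimal sketch, assuming the timm package is installed and using an illustrative timm model name:

from transformers import AutoBackbone

# use_timm_backbone routes the call to _load_timm_backbone_from_pretrained;
# use_pretrained_backbone decides whether timm weights are downloaded or the
# backbone is left randomly initialized.
backbone = AutoBackbone.from_pretrained(
    "resnet18", use_timm_backbone=True, use_pretrained_backbone=True
)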
@@ -93,11 +93,11 @@ class ConditionalDetrConfig(PretrainedConfig):
         position_embedding_type (`str`, *optional*, defaults to `"sine"`):
             Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`.
         backbone (`str`, *optional*, defaults to `"resnet50"`):
-            Name of convolutional backbone to use in case `use_timm_backbone` = `True`. Supports any convolutional
-            backbone from the timm package. For a list of all available models, see [this
-            page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model).
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
         use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
-            Whether to use pretrained weights for the backbone. Only supported when `use_timm_backbone` = `True`.
+            Whether to use pretrained weights for the backbone.
         dilation (`bool`, *optional*, defaults to `False`):
             Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
             `use_timm_backbone` = `True`.
@@ -180,6 +180,14 @@ class ConditionalDetrConfig(PretrainedConfig):
         focal_alpha=0.25,
         **kwargs,
     ):
+        if not use_timm_backbone and use_pretrained_backbone:
+            raise ValueError(
+                "Loading pretrained backbone weights from the transformers library is not supported yet. `use_timm_backbone` must be set to `True` when `use_pretrained_backbone=True`"
+            )
+
+        if backbone_config is not None and backbone is not None:
+            raise ValueError("You can't specify both `backbone` and `backbone_config`.")
+
         if backbone_config is not None and use_timm_backbone:
             raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
...
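The two new checks make the accepted argument combinations explicit. A short sketch of what now passes and what raises (an illustration, not part of the diff):

from transformers import ConditionalDetrConfig

# Default timm path: a backbone name plus pretrained timm weights.
config = ConditionalDetrConfig(
    backbone="resnet50", use_timm_backbone=True, use_pretrained_backbone=True
)

# Pretrained transformers-library backbones are not wired up yet, so this raises.
try:
    ConditionalDetrConfig(use_timm_backbone=False, use_pretrained_backbone=True)
except ValueError as err:
    print(err)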
@@ -85,11 +85,11 @@ class DeformableDetrConfig(PretrainedConfig):
         position_embedding_type (`str`, *optional*, defaults to `"sine"`):
             Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`.
         backbone (`str`, *optional*, defaults to `"resnet50"`):
-            Name of convolutional backbone to use in case `use_timm_backbone` = `True`. Supports any convolutional
-            backbone from the timm package. For a list of all available models, see [this
-            page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model).
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
         use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
-            Whether to use pretrained weights for the backbone. Only supported when `use_timm_backbone` = `True`.
+            Whether to use pretrained weights for the backbone.
         dilation (`bool`, *optional*, defaults to `False`):
             Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
             `use_timm_backbone` = `True`.
@@ -196,6 +196,14 @@ class DeformableDetrConfig(PretrainedConfig):
         disable_custom_kernels=False,
         **kwargs,
     ):
+        if not use_timm_backbone and use_pretrained_backbone:
+            raise ValueError(
+                "Loading pretrained backbone weights from the transformers library is not supported yet. `use_timm_backbone` must be set to `True` when `use_pretrained_backbone=True`"
+            )
+
+        if backbone_config is not None and backbone is not None:
+            raise ValueError("You can't specify both `backbone` and `backbone_config`.")
+
         if backbone_config is not None and use_timm_backbone:
             raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
...
@@ -40,6 +40,12 @@ class DetaConfig(PretrainedConfig):
     Args:
         backbone_config (`PretrainedConfig` or `dict`, *optional*, defaults to `ResNetConfig()`):
             The configuration of the backbone model.
+        backbone (`str`, *optional*):
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to use pretrained weights for the backbone.
         num_queries (`int`, *optional*, defaults to 900):
             Number of object queries, i.e. detection slots. This is the maximal number of objects [`DetaModel`] can
             detect in a single image. In case `two_stage` is set to `True`, we use `two_stage_num_proposals` instead.
@@ -138,6 +144,8 @@ class DetaConfig(PretrainedConfig):
     def __init__(
         self,
         backbone_config=None,
+        backbone=None,
+        use_pretrained_backbone=False,
         num_queries=900,
         max_position_embeddings=2048,
         encoder_layers=6,
@@ -177,7 +185,13 @@ class DetaConfig(PretrainedConfig):
         focal_alpha=0.25,
         **kwargs,
     ):
-        if backbone_config is None:
+        if use_pretrained_backbone:
+            raise ValueError("Pretrained backbones are not supported yet.")
+
+        if backbone_config is not None and backbone is not None:
+            raise ValueError("You can't specify both `backbone` and `backbone_config`.")
+
+        if backbone_config is None and backbone is None:
             logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
             backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage2", "stage3", "stage4"])
         else:
@@ -187,6 +201,8 @@ class DetaConfig(PretrainedConfig):
             backbone_config = config_class.from_dict(backbone_config)

         self.backbone_config = backbone_config
+        self.backbone = backbone
+        self.use_pretrained_backbone = use_pretrained_backbone
         self.num_queries = num_queries
         self.max_position_embeddings = max_position_embeddings
         self.d_model = d_model
...
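For configs like DetaConfig, a checkpoint name can now stand in for a full `backbone_config`; with `use_pretrained_backbone=False` only the backbone's config is resolved later and the weights stay random. A sketch under that assumption (the checkpoint name is illustrative):

from transformers import DetaConfig

# backbone_config stays None; the named checkpoint's config is used later by
# load_backbone to build a randomly initialized backbone.
config = DetaConfig(backbone="microsoft/resnet-50", use_pretrained_backbone=False)
assert config.backbone_config is None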
@@ -93,11 +93,11 @@ class DetrConfig(PretrainedConfig):
         position_embedding_type (`str`, *optional*, defaults to `"sine"`):
             Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`.
         backbone (`str`, *optional*, defaults to `"resnet50"`):
-            Name of convolutional backbone to use in case `use_timm_backbone` = `True`. Supports any convolutional
-            backbone from the timm package. For a list of all available models, see [this
-            page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model).
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
         use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
-            Whether to use pretrained weights for the backbone. Only supported when `use_timm_backbone` = `True`.
+            Whether to use pretrained weights for the backbone.
         dilation (`bool`, *optional*, defaults to `False`):
             Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
             `use_timm_backbone` = `True`.
@@ -177,6 +177,14 @@ class DetrConfig(PretrainedConfig):
         eos_coefficient=0.1,
         **kwargs,
     ):
+        if not use_timm_backbone and use_pretrained_backbone:
+            raise ValueError(
+                "Loading pretrained backbone weights from the transformers library is not supported yet. `use_timm_backbone` must be set to `True` when `use_pretrained_backbone=True`"
+            )
+
+        if backbone_config is not None and backbone is not None:
+            raise ValueError("You can't specify both `backbone` and `backbone_config`.")
+
         if backbone_config is not None and use_timm_backbone:
             raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
...
@@ -111,6 +111,12 @@ class DPTConfig(PretrainedConfig):
         backbone_config (`Union[Dict[str, Any], PretrainedConfig]`, *optional*):
             The configuration of the backbone model. Only used in case `is_hybrid` is `True` or in case you want to
             leverage the [`AutoBackbone`] API.
+        backbone (`str`, *optional*):
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to use pretrained weights for the backbone.

     Example:
@@ -161,6 +167,8 @@ class DPTConfig(PretrainedConfig):
         backbone_featmap_shape=[1, 1024, 24, 24],
         neck_ignore_stages=[0, 1],
         backbone_config=None,
+        backbone=None,
+        use_pretrained_backbone=False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -168,9 +176,15 @@ class DPTConfig(PretrainedConfig):
         self.hidden_size = hidden_size
         self.is_hybrid = is_hybrid

+        if use_pretrained_backbone:
+            raise ValueError("Pretrained backbones are not supported yet.")
+
+        if backbone_config is not None and backbone is not None:
+            raise ValueError("You can't specify both `backbone` and `backbone_config`.")
+
         use_autobackbone = False
         if self.is_hybrid:
-            if backbone_config is None:
+            if backbone_config is None and backbone is None:
                 logger.info("Initializing the config with a `BiT` backbone.")
                 backbone_config = {
                     "global_padding": "same",
@@ -213,6 +227,8 @@ class DPTConfig(PretrainedConfig):
             self.backbone_featmap_shape = None
             self.neck_ignore_stages = []

+        self.backbone = backbone
+        self.use_pretrained_backbone = use_pretrained_backbone
         self.num_hidden_layers = None if use_autobackbone else num_hidden_layers
         self.num_attention_heads = None if use_autobackbone else num_attention_heads
         self.intermediate_size = None if use_autobackbone else intermediate_size
...
@@ -47,6 +47,12 @@ class Mask2FormerConfig(PretrainedConfig):
         backbone_config (`PretrainedConfig` or `dict`, *optional*, defaults to `SwinConfig()`):
             The configuration of the backbone model. If unset, the configuration corresponding to
             `swin-base-patch4-window12-384` will be used.
+        backbone (`str`, *optional*):
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to use pretrained weights for the backbone.
         feature_size (`int`, *optional*, defaults to 256):
             The features (channels) of the resulting feature maps.
         mask_feature_size (`int`, *optional*, defaults to 256):
@@ -154,9 +160,17 @@ class Mask2FormerConfig(PretrainedConfig):
         use_auxiliary_loss: bool = True,
         feature_strides: List[int] = [4, 8, 16, 32],
         output_auxiliary_logits: bool = None,
+        backbone=None,
+        use_pretrained_backbone=False,
         **kwargs,
     ):
-        if backbone_config is None:
+        if use_pretrained_backbone:
+            raise ValueError("Pretrained backbones are not supported yet.")
+
+        if backbone_config is not None and backbone is not None:
+            raise ValueError("You can't specify both `backbone` and `backbone_config`.")
+
+        if backbone_config is None and backbone is None:
             logger.info("`backbone_config` is `None`. Initializing the config with the default `Swin` backbone.")
             backbone_config = CONFIG_MAPPING["swin"](
                 image_size=224,
@@ -177,7 +191,7 @@ class Mask2FormerConfig(PretrainedConfig):
             backbone_config = config_class.from_dict(backbone_config)

         # verify that the backbone is supported
-        if backbone_config.model_type not in self.backbones_supported:
+        if backbone_config is not None and backbone_config.model_type not in self.backbones_supported:
             logger.warning_once(
                 f"Backbone {backbone_config.model_type} is not a supported model and may not be compatible with Mask2Former. "
                 f"Supported model types: {','.join(self.backbones_supported)}"
@@ -212,6 +226,8 @@ class Mask2FormerConfig(PretrainedConfig):
         self.feature_strides = feature_strides
         self.output_auxiliary_logits = output_auxiliary_logits
         self.num_hidden_layers = decoder_layers
+        self.backbone = backbone
+        self.use_pretrained_backbone = use_pretrained_backbone

         super().__init__(**kwargs)
...
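Note the guarded `model_type` check: when only `backbone` is given, `backbone_config` stays `None`, and the old unguarded comparison would have failed on `None.model_type`. A sketch of the now-valid call (checkpoint name is illustrative):

from transformers import Mask2FormerConfig

# No backbone_config here, so the supported-backbone warning is skipped rather
# than crashing on a None config.
config = Mask2FormerConfig(backbone="microsoft/swin-base-patch4-window12-384")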
@@ -57,6 +57,12 @@ class MaskFormerConfig(PretrainedConfig):
         backbone_config (`Dict`, *optional*):
             The configuration passed to the backbone, if unset, the configuration corresponding to
             `swin-base-patch4-window12-384` will be used.
+        backbone (`str`, *optional*):
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to use pretrained weights for the backbone.
         decoder_config (`Dict`, *optional*):
             The configuration passed to the transformer decoder model, if unset the base config for `detr-resnet-50`
             will be used.
@@ -114,9 +120,17 @@ class MaskFormerConfig(PretrainedConfig):
         cross_entropy_weight: float = 1.0,
         mask_weight: float = 20.0,
         output_auxiliary_logits: Optional[bool] = None,
+        backbone: Optional[str] = None,
+        use_pretrained_backbone: bool = False,
         **kwargs,
     ):
-        if backbone_config is None:
+        if use_pretrained_backbone:
+            raise ValueError("Pretrained backbones are not supported yet.")
+
+        if backbone_config is not None and backbone is not None:
+            raise ValueError("You can't specify both `backbone` and `backbone_config`.")
+
+        if backbone_config is None and backbone is None:
             # fall back to https://huggingface.co/microsoft/swin-base-patch4-window12-384-in22k
             backbone_config = SwinConfig(
                 image_size=384,
@@ -136,7 +150,7 @@ class MaskFormerConfig(PretrainedConfig):
             backbone_config = config_class.from_dict(backbone_config)

         # verify that the backbone is supported
-        if backbone_config.model_type not in self.backbones_supported:
+        if backbone_config is not None and backbone_config.model_type not in self.backbones_supported:
             logger.warning_once(
                 f"Backbone {backbone_config.model_type} is not a supported model and may not be compatible with MaskFormer. "
                 f"Supported model types: {','.join(self.backbones_supported)}"
@@ -177,6 +191,8 @@ class MaskFormerConfig(PretrainedConfig):
         self.num_attention_heads = self.decoder_config.encoder_attention_heads
         self.num_hidden_layers = self.decoder_config.num_hidden_layers
+        self.backbone = backbone
+        self.use_pretrained_backbone = use_pretrained_backbone
         super().__init__(**kwargs)

     @classmethod
...
@@ -44,6 +44,12 @@ class OneFormerConfig(PretrainedConfig):
     Args:
         backbone_config (`PretrainedConfig`, *optional*, defaults to `SwinConfig`):
             The configuration of the backbone model.
+        backbone (`str`, *optional*):
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to use pretrained weights for the backbone.
         ignore_value (`int`, *optional*, defaults to 255):
             Values to be ignored in GT label while calculating loss.
         num_queries (`int`, *optional*, defaults to 150):
@@ -144,6 +150,8 @@ class OneFormerConfig(PretrainedConfig):
     def __init__(
         self,
         backbone_config: Optional[Dict] = None,
+        backbone: Optional[str] = None,
+        use_pretrained_backbone: bool = False,
         ignore_value: int = 255,
         num_queries: int = 150,
         no_object_weight: int = 0.1,
@@ -186,7 +194,13 @@ class OneFormerConfig(PretrainedConfig):
         common_stride: int = 4,
         **kwargs,
     ):
-        if backbone_config is None:
+        if use_pretrained_backbone:
+            raise ValueError("Pretrained backbones are not supported yet.")
+
+        if backbone_config is not None and backbone is not None:
+            raise ValueError("You can't specify both `backbone` and `backbone_config`.")
+
+        if backbone_config is None and backbone is None:
             logger.info("`backbone_config` is unset. Initializing the config with the default `Swin` backbone.")
             backbone_config = CONFIG_MAPPING["swin"](
                 image_size=224,
@@ -206,7 +220,8 @@ class OneFormerConfig(PretrainedConfig):
             backbone_config = config_class.from_dict(backbone_config)

         self.backbone_config = backbone_config
+        self.backbone = backbone
+        self.use_pretrained_backbone = use_pretrained_backbone
         self.ignore_value = ignore_value
         self.num_queries = num_queries
         self.no_object_weight = no_object_weight
...
@@ -92,12 +92,12 @@ class TableTransformerConfig(PretrainedConfig):
             Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
         position_embedding_type (`str`, *optional*, defaults to `"sine"`):
             Type of position embeddings to be used on top of the image features. One of `"sine"` or `"learned"`.
-        backbone (`str`, *optional*, defaults to `"resnet50"`):
-            Name of convolutional backbone to use in case `use_timm_backbone` = `True`. Supports any convolutional
-            backbone from the timm package. For a list of all available models, see [this
-            page](https://rwightman.github.io/pytorch-image-models/#load-a-pretrained-model).
+        backbone (`str`, *optional*):
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
         use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
-            Whether to use pretrained weights for the backbone. Only supported when `use_timm_backbone` = `True`.
+            Whether to use pretrained weights for the backbone.
         dilation (`bool`, *optional*, defaults to `False`):
             Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
             `use_timm_backbone` = `True`.
@@ -178,6 +178,14 @@ class TableTransformerConfig(PretrainedConfig):
         eos_coefficient=0.1,
         **kwargs,
     ):
+        if not use_timm_backbone and use_pretrained_backbone:
+            raise ValueError(
+                "Loading pretrained backbone weights from the transformers library is not supported yet. `use_timm_backbone` must be set to `True` when `use_pretrained_backbone=True`"
+            )
+
+        if backbone_config is not None and backbone is not None:
+            raise ValueError("You can't specify both `backbone` and `backbone_config`.")
+
         if backbone_config is not None and use_timm_backbone:
             raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
...
@@ -43,6 +43,12 @@ class TvpConfig(PretrainedConfig):
     Args:
         backbone_config (`PretrainedConfig` or `dict`, *optional*):
             The configuration of the backbone model.
+        backbone (`str`, *optional*):
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to use pretrained weights for the backbone.
         distance_loss_weight (`float`, *optional*, defaults to 1.0):
             The weight of distance loss.
         duration_loss_weight (`float`, *optional*, defaults to 0.1):
@@ -95,6 +101,8 @@ class TvpConfig(PretrainedConfig):
     def __init__(
         self,
         backbone_config=None,
+        backbone=None,
+        use_pretrained_backbone=False,
         distance_loss_weight=1.0,
         duration_loss_weight=0.1,
         visual_prompter_type="framepad",
@@ -118,8 +126,13 @@ class TvpConfig(PretrainedConfig):
         **kwargs,
     ):
         super().__init__(**kwargs)

-        if backbone_config is None:
+        if use_pretrained_backbone:
+            raise ValueError("Pretrained backbones are not supported yet.")
+
+        if backbone_config is not None and backbone is not None:
+            raise ValueError("You can't specify both `backbone` and `backbone_config`.")
+
+        if backbone_config is None and backbone is None:
             logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
             backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
         elif isinstance(backbone_config, dict):
@@ -128,6 +141,8 @@ class TvpConfig(PretrainedConfig):
             backbone_config = config_class.from_dict(backbone_config)

         self.backbone_config = backbone_config
+        self.backbone = backbone
+        self.use_pretrained_backbone = use_pretrained_backbone
         self.distance_loss_weight = distance_loss_weight
         self.duration_loss_weight = duration_loss_weight
         self.visual_prompter_type = visual_prompter_type
...
@@ -36,6 +36,12 @@ class UperNetConfig(PretrainedConfig):
     Args:
         backbone_config (`PretrainedConfig` or `dict`, *optional*, defaults to `ResNetConfig()`):
             The configuration of the backbone model.
+        backbone (`str`, *optional*):
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to use pretrained weights for the backbone.
         hidden_size (`int`, *optional*, defaults to 512):
             The number of hidden units in the convolutional layers.
         initializer_range (`float`, *optional*, defaults to 0.02):
@@ -75,6 +81,8 @@ class UperNetConfig(PretrainedConfig):
     def __init__(
         self,
         backbone_config=None,
+        backbone=None,
+        use_pretrained_backbone=False,
         hidden_size=512,
         initializer_range=0.02,
         pool_scales=[1, 2, 3, 6],
@@ -88,8 +96,13 @@ class UperNetConfig(PretrainedConfig):
         **kwargs,
     ):
         super().__init__(**kwargs)

-        if backbone_config is None:
+        if use_pretrained_backbone:
+            raise ValueError("Pretrained backbones are not supported yet.")
+
+        if backbone_config is not None and backbone is not None:
+            raise ValueError("You can't specify both `backbone` and `backbone_config`.")
+
+        if backbone_config is None and backbone is None:
             logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
             backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage1", "stage2", "stage3", "stage4"])
         elif isinstance(backbone_config, dict):
@@ -98,6 +111,8 @@ class UperNetConfig(PretrainedConfig):
             backbone_config = config_class.from_dict(backbone_config)

         self.backbone_config = backbone_config
+        self.backbone = backbone
+        self.use_pretrained_backbone = use_pretrained_backbone
         self.hidden_size = hidden_size
         self.initializer_range = initializer_range
         self.pool_scales = pool_scales
...
@@ -42,6 +42,12 @@ class ViTHybridConfig(PretrainedConfig):
     Args:
         backbone_config (`Union[Dict[str, Any], PretrainedConfig]`, *optional*):
             The configuration of the backbone in a dictionary or the config object of the backbone.
+        backbone (`str`, *optional*):
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to use pretrained weights for the backbone.
         hidden_size (`int`, *optional*, defaults to 768):
             Dimensionality of the encoder layers and the pooler layer.
         num_hidden_layers (`int`, *optional*, defaults to 12):
@@ -92,6 +98,8 @@ class ViTHybridConfig(PretrainedConfig):
     def __init__(
         self,
         backbone_config=None,
+        backbone=None,
+        use_pretrained_backbone=False,
         hidden_size=768,
         num_hidden_layers=12,
         num_attention_heads=12,
@@ -109,8 +117,13 @@ class ViTHybridConfig(PretrainedConfig):
         **kwargs,
     ):
         super().__init__(**kwargs)

-        if backbone_config is None:
+        if use_pretrained_backbone:
+            raise ValueError("Pretrained backbones are not supported yet.")
+
+        if backbone_config is not None and backbone is not None:
+            raise ValueError("You can't specify both `backbone` and `backbone_config`.")
+
+        if backbone_config is None and backbone is None:
             logger.info("`backbone_config` is `None`. Initializing the config with a `BiT` backbone.")
             backbone_config = {
                 "global_padding": "same",
@@ -132,6 +145,8 @@ class ViTHybridConfig(PretrainedConfig):
         self.backbone_featmap_shape = backbone_featmap_shape
         self.backbone_config = backbone_config
+        self.backbone = backbone
+        self.use_pretrained_backbone = use_pretrained_backbone
         self.hidden_size = hidden_size
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
...
@@ -42,6 +42,12 @@ class VitMatteConfig(PretrainedConfig):
     Args:
         backbone_config (`PretrainedConfig` or `dict`, *optional*, defaults to `VitDetConfig()`):
             The configuration of the backbone model.
+        backbone (`str`, *optional*):
+            Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this
+            will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone`
+            is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights.
+        use_pretrained_backbone (`bool`, *optional*, defaults to `False`):
+            Whether to use pretrained weights for the backbone.
         hidden_size (`int`, *optional*, defaults to 384):
             The number of input channels of the decoder.
         batch_norm_eps (`float`, *optional*, defaults to 1e-05):
@@ -73,6 +79,8 @@ class VitMatteConfig(PretrainedConfig):
     def __init__(
         self,
         backbone_config: PretrainedConfig = None,
+        backbone=None,
+        use_pretrained_backbone=False,
         hidden_size: int = 384,
         batch_norm_eps: float = 1e-5,
         initializer_range: float = 0.02,
@@ -82,7 +90,13 @@ class VitMatteConfig(PretrainedConfig):
     ):
         super().__init__(**kwargs)

-        if backbone_config is None:
+        if use_pretrained_backbone:
+            raise ValueError("Pretrained backbones are not supported yet.")
+
+        if backbone_config is not None and backbone is not None:
+            raise ValueError("You can't specify both `backbone` and `backbone_config`.")
+
+        if backbone_config is None and backbone is None:
             logger.info("`backbone_config` is `None`. Initializing the config with the default `VitDet` backbone.")
             backbone_config = CONFIG_MAPPING["vitdet"](out_features=["stage4"])
         elif isinstance(backbone_config, dict):
@@ -91,6 +105,8 @@ class VitMatteConfig(PretrainedConfig):
             backbone_config = config_class.from_dict(backbone_config)

         self.backbone_config = backbone_config
+        self.backbone = backbone
+        self.use_pretrained_backbone = use_pretrained_backbone
         self.batch_norm_eps = batch_norm_eps
         self.hidden_size = hidden_size
         self.initializer_range = initializer_range
...
@@ -286,3 +286,56 @@ class BackboneConfigMixin:
         output["out_features"] = output.pop("_out_features")
         output["out_indices"] = output.pop("_out_indices")
         return output
+
+
+def load_backbone(config):
+    """
+    Loads the backbone model from a config object.
+
+    If the config is from the backbone model itself, then we return a backbone model with randomly initialized
+    weights.
+
+    If the config is from the parent model of the backbone model itself, then we load the pretrained backbone weights
+    if specified.
+    """
+    from transformers import AutoBackbone, AutoConfig
+
+    backbone_config = getattr(config, "backbone_config", None)
+    use_timm_backbone = getattr(config, "use_timm_backbone", None)
+    use_pretrained_backbone = getattr(config, "use_pretrained_backbone", None)
+    backbone_checkpoint = getattr(config, "backbone", None)
+
+    # If there is a backbone_config and a backbone checkpoint, and use_pretrained_backbone=False then the desired
+    # behaviour is ill-defined: do you want to load from the checkpoint's config or the backbone_config?
+    if backbone_config is not None and backbone_checkpoint is not None and use_pretrained_backbone is not None:
+        raise ValueError("Cannot specify both config.backbone_config and config.backbone")
+
+    # If none of the following are set, then the config passed in is from the backbone model itself, rather than
+    # from a parent model which contains a backbone.
+    if (
+        backbone_config is None
+        and use_timm_backbone is None
+        and backbone_checkpoint is None
+        and use_pretrained_backbone is None
+    ):
+        return AutoBackbone.from_config(config=config)
+
+    # config from the parent model that has a backbone
+    if use_timm_backbone:
+        if backbone_checkpoint is None:
+            raise ValueError("config.backbone must be set if use_timm_backbone is True")
+        # Because of how timm backbones were originally added to models, we need to pass in use_pretrained_backbone
+        # to determine whether to load the pretrained weights.
+        backbone = AutoBackbone.from_pretrained(
+            backbone_checkpoint, use_timm_backbone=use_timm_backbone, use_pretrained_backbone=use_pretrained_backbone
+        )
+    elif use_pretrained_backbone:
+        if backbone_checkpoint is None:
+            raise ValueError("config.backbone must be set if use_pretrained_backbone is True")
+        backbone = AutoBackbone.from_pretrained(backbone_checkpoint)
+    else:
+        if backbone_config is None and backbone_checkpoint is None:
+            raise ValueError("Either config.backbone_config or config.backbone must be set")
+        if backbone_config is None:
+            backbone_config = AutoConfig.from_pretrained(backbone_checkpoint)
+        backbone = AutoBackbone.from_config(config=backbone_config)
+    return backbone
...
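A quick sketch of how modeling code can call the new helper, assuming timm is installed (the values mirror the DETR defaults):

from transformers import DetrConfig
from transformers.utils.backbone_utils import load_backbone

# The config names a timm checkpoint; load_backbone inspects the backbone
# fields and returns the instantiated backbone with pretrained timm weights.
config = DetrConfig(backbone="resnet50", use_timm_backbone=True, use_pretrained_backbone=True)
backbone = load_backbone(config)
print(type(backbone).__name__)  # expected: TimmBackbone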
@@ -134,6 +134,8 @@ class ConditionalDetrModelTester:
             num_labels=self.num_labels,
             use_timm_backbone=False,
             backbone_config=resnet_config,
+            backbone=None,
+            use_pretrained_backbone=False,
         )

     def prepare_config_and_inputs_for_common(self):
...
@@ -149,7 +149,9 @@ class DeformableDetrModelTester:
             encoder_n_points=self.encoder_n_points,
             decoder_n_points=self.decoder_n_points,
             use_timm_backbone=False,
+            backbone=None,
             backbone_config=resnet_config,
+            use_pretrained_backbone=False,
         )

     def prepare_config_and_inputs_for_common(self):
@@ -518,6 +520,8 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
         # let's pick a random timm backbone
         config.backbone = "tf_mobilenetv3_small_075"
+        config.use_timm_backbone = True
+        config.backbone_config = None

         for model_class in self.all_model_classes:
             model = model_class(config)
...
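Expanded into a standalone snippet, the pattern this test exercises looks roughly like the following (a sketch; the model class stands in for the classes the test loops over):

from transformers import DeformableDetrConfig, DeformableDetrModel

config = DeformableDetrConfig()
config.backbone = "tf_mobilenetv3_small_075"  # a randomly picked timm backbone
config.use_timm_backbone = True
config.backbone_config = None  # clear the transformers backbone config to avoid a conflict
model = DeformableDetrModel(config)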
@@ -157,6 +157,7 @@ class DetaModelTester:
             assign_first_stage=assign_first_stage,
             assign_second_stage=assign_second_stage,
             backbone_config=resnet_config,
+            backbone=None,
         )

     def prepare_config_and_inputs_for_common(self, model_class_name="DetaModel"):
...
@@ -130,6 +130,8 @@ class DetrModelTester:
             num_labels=self.num_labels,
             use_timm_backbone=False,
             backbone_config=resnet_config,
+            backbone=None,
+            use_pretrained_backbone=False,
         )

     def prepare_config_and_inputs_for_common(self):
@@ -622,7 +624,7 @@ class DetrModelIntegrationTestsTimmBackbone(unittest.TestCase):
             torch_device
         )

         expected_number_of_segments = 5
-        expected_first_segment = {"id": 1, "label_id": 17, "was_fused": False, "score": 0.994096}
+        expected_first_segment = {"id": 1, "label_id": 17, "was_fused": False, "score": 0.994097}

         number_of_unique_segments = len(torch.unique(results["segmentation"]))
         self.assertTrue(
...
@@ -95,6 +95,7 @@ class DPTModelTester:
     def get_config(self):
         return DPTConfig(
             backbone_config=self.get_backbone_config(),
+            backbone=None,
             neck_hidden_sizes=self.neck_hidden_sizes,
             fusion_hidden_size=self.fusion_hidden_size,
         )
...