Unverified Commit 462f77cb authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Allow backbones not in backbones_supported - Maskformer Mask2Former (#24532)

Allow backbones not in backbones_supported
parent 8e5d1619
...@@ -170,11 +170,19 @@ class Mask2FormerConfig(PretrainedConfig): ...@@ -170,11 +170,19 @@ class Mask2FormerConfig(PretrainedConfig):
use_absolute_embeddings=False, use_absolute_embeddings=False,
out_features=["stage1", "stage2", "stage3", "stage4"], out_features=["stage1", "stage2", "stage3", "stage4"],
) )
elif isinstance(backbone_config, dict):
backbone_model_type = backbone_config.get("model_type") if isinstance(backbone_config, dict):
backbone_model_type = backbone_config.pop("model_type")
config_class = CONFIG_MAPPING[backbone_model_type] config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config) backbone_config = config_class.from_dict(backbone_config)
# verify that the backbone is supported
if backbone_config.model_type not in self.backbones_supported:
logger.warning_once(
f"Backbone {backbone_config.model_type} is not a supported model and may not be compatible with Mask2Former. "
f"Supported model types: {','.join(self.backbones_supported)}"
)
self.backbone_config = backbone_config self.backbone_config = backbone_config
self.feature_size = feature_size self.feature_size = feature_size
self.mask_feature_size = mask_feature_size self.mask_feature_size = mask_feature_size
......
...@@ -129,19 +129,18 @@ class MaskFormerConfig(PretrainedConfig): ...@@ -129,19 +129,18 @@ class MaskFormerConfig(PretrainedConfig):
drop_path_rate=0.3, drop_path_rate=0.3,
out_features=["stage1", "stage2", "stage3", "stage4"], out_features=["stage1", "stage2", "stage3", "stage4"],
) )
else:
# verify that the backbone is supported if isinstance(backbone_config, dict):
backbone_model_type = ( backbone_model_type = backbone_config.pop("model_type")
backbone_config.pop("model_type") if isinstance(backbone_config, dict) else backbone_config.model_type config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
# verify that the backbone is supported
if backbone_config.model_type not in self.backbones_supported:
logger.warning_once(
f"Backbone {backbone_config.model_type} is not a supported model and may not be compatible with MaskFormer. "
f"Supported model types: {','.join(self.backbones_supported)}"
) )
if backbone_model_type not in self.backbones_supported:
raise ValueError(
f"Backbone {backbone_model_type} not supported, please use one of"
f" {','.join(self.backbones_supported)}"
)
if isinstance(backbone_config, dict):
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
if decoder_config is None: if decoder_config is None:
# fall back to https://huggingface.co/facebook/detr-resnet-50 # fall back to https://huggingface.co/facebook/detr-resnet-50
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment