Unverified Commit b610c47f authored by NielsRogge, committed by GitHub

[MaskFormer] Add support for ResNet backbone (#20483)



* Add SwinBackbone

* Add hidden_states_before_downsampling support

* Fix Swin tests

* Improve conversion script

* Add id2label mappings

* Add vistas mapping

* Update comments

* Fix backbone

* Improve tests

* Extend conversion script

* Add Swin conversion script

* Fix style

* Revert config attribute

* Remove SwinBackbone from main init

* Remove unused attribute

* Use encoder for ResNet backbone

* Improve conversion script and add integration test

* Apply suggestion
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 6c1a0b39
@@ -18,7 +18,7 @@ from typing import Dict, Optional
 from ...configuration_utils import PretrainedConfig
 from ...utils import logging
-from ..auto.configuration_auto import AutoConfig
+from ..auto import CONFIG_MAPPING
 from ..detr import DetrConfig
 from ..swin import SwinConfig
@@ -97,7 +97,7 @@ class MaskFormerConfig(PretrainedConfig):
     """
     model_type = "maskformer"
     attribute_map = {"hidden_size": "mask_feature_size"}
-    backbones_supported = ["swin"]
+    backbones_supported = ["resnet", "swin"]
     decoders_supported = ["detr"]

     def __init__(
@@ -127,27 +127,38 @@ class MaskFormerConfig(PretrainedConfig):
                 num_heads=[4, 8, 16, 32],
                 window_size=12,
                 drop_path_rate=0.3,
+                out_features=["stage1", "stage2", "stage3", "stage4"],
             )
         else:
-            backbone_model_type = backbone_config.pop("model_type")
+            # verify that the backbone is supported
+            backbone_model_type = (
+                backbone_config.pop("model_type") if isinstance(backbone_config, dict) else backbone_config.model_type
+            )
             if backbone_model_type not in self.backbones_supported:
                 raise ValueError(
                     f"Backbone {backbone_model_type} not supported, please use one of"
                     f" {','.join(self.backbones_supported)}"
                 )
-            backbone_config = AutoConfig.for_model(backbone_model_type, **backbone_config)
+            if isinstance(backbone_config, dict):
+                config_class = CONFIG_MAPPING[backbone_model_type]
+                backbone_config = config_class.from_dict(backbone_config)

         if decoder_config is None:
             # fall back to https://huggingface.co/facebook/detr-resnet-50
             decoder_config = DetrConfig()
         else:
-            decoder_type = decoder_config.pop("model_type")
+            # verify that the decoder is supported
+            decoder_type = (
+                decoder_config.pop("model_type") if isinstance(decoder_config, dict) else decoder_config.model_type
+            )
             if decoder_type not in self.decoders_supported:
                 raise ValueError(
                     f"Transformer Decoder {decoder_type} not supported, please use one of"
                     f" {','.join(self.decoders_supported)}"
                 )
-            decoder_config = AutoConfig.for_model(decoder_type, **decoder_config)
+            if isinstance(decoder_config, dict):
+                config_class = CONFIG_MAPPING[decoder_type]
+                decoder_config = config_class.from_dict(decoder_config)

         self.backbone_config = backbone_config
         self.decoder_config = decoder_config
@@ -186,8 +197,8 @@ class MaskFormerConfig(PretrainedConfig):
            [`MaskFormerConfig`]: An instance of a configuration object
        """
        return cls(
-            backbone_config=backbone_config.to_dict(),
-            decoder_config=decoder_config.to_dict(),
+            backbone_config=backbone_config,
+            decoder_config=decoder_config,
            **kwargs,
        )
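With the changes above, `MaskFormerConfig` accepts the backbone either as a config object or as a plain dict (resolved through `CONFIG_MAPPING`), and ResNet joins Swin as a supported backbone. A minimal sketch of what this enables; the `out_features` values are illustrative, not required by the change:

```python
from transformers import DetrConfig, MaskFormerConfig, ResNetConfig

# Backbone passed as a config object; a dict containing "model_type": "resnet"
# should also work, since it is resolved via CONFIG_MAPPING and from_dict.
backbone_config = ResNetConfig(out_features=["stage1", "stage2", "stage3", "stage4"])
decoder_config = DetrConfig()

config = MaskFormerConfig.from_backbone_and_decoder_configs(backbone_config, decoder_config)
print(type(config.backbone_config).__name__)  # ResNetConfig
```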
@@ -69,7 +69,7 @@ class MaskFormerSwinConfig(PretrainedConfig):
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        out_features (`List[str]`, *optional*):
-            If used as a backbone, list of feature names to output, e.g. `["stem", "stage1"]`.
+            If used as a backbone, list of feature names to output, e.g. `["stage1", "stage2"]`.

    Example:
@@ -275,7 +275,6 @@ class RegNetEncoder(nn.Module):
         return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states)


-# Copied from transformers.models.resnet.modeling_resnet.ResNetPreTrainedModel with ResNet->RegNet,resnet->regnet
 class RegNetPreTrainedModel(PreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -287,6 +286,7 @@ class RegNetPreTrainedModel(PreTrainedModel):
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True

+    # Copied from transformers.models.resnet.modeling_resnet.ResNetPreTrainedModel._init_weights
     def _init_weights(self, module):
         if isinstance(module, nn.Conv2d):
             nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
@@ -267,7 +267,7 @@ class ResNetPreTrainedModel(PreTrainedModel):
             nn.init.constant_(module.bias, 0)

     def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, ResNetModel):
+        if isinstance(module, (ResNetModel, ResNetBackbone)):
             module.gradient_checkpointing = value
@@ -436,7 +436,8 @@ class ResNetBackbone(ResNetPreTrainedModel):
         super().__init__(config)

         self.stage_names = config.stage_names
-        self.resnet = ResNetModel(config)
+        self.embedder = ResNetEmbeddings(config)
+        self.encoder = ResNetEncoder(config)
         self.out_features = config.out_features
@@ -490,7 +491,9 @@ class ResNetBackbone(ResNetPreTrainedModel):
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )

-        outputs = self.resnet(pixel_values, output_hidden_states=True, return_dict=True)
+        embedding_output = self.embedder(pixel_values)
+
+        outputs = self.encoder(embedding_output, output_hidden_states=True, return_dict=True)

         hidden_states = outputs.hidden_states
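`ResNetBackbone` now runs the embedder and encoder directly instead of wrapping a full `ResNetModel`, so only the pieces MaskFormer actually needs are instantiated. A rough usage sketch with randomly initialized weights; the chosen `out_features` and input size are illustrative:

```python
import torch
from transformers import ResNetBackbone, ResNetConfig

config = ResNetConfig(out_features=["stage2", "stage3", "stage4"])
model = ResNetBackbone(config)

pixel_values = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    outputs = model(pixel_values)

# One feature map per requested stage, at decreasing spatial resolution.
for name, feature_map in zip(config.out_features, outputs.feature_maps):
    print(name, tuple(feature_map.shape))
```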
@@ -84,7 +84,7 @@ class SwinConfig(PretrainedConfig):
        encoder_stride (`int`, `optional`, defaults to 32):
            Factor to increase the spatial resolution by in the decoder head for masked image modeling.

    Example:

    ```python
    >>> from transformers import SwinConfig, SwinModel
@@ -320,16 +320,16 @@ def prepare_img():
 @require_vision
 @slow
 class MaskFormerModelIntegrationTest(unittest.TestCase):
-    @cached_property
-    def model_checkpoints(self):
-        return "facebook/maskformer-swin-small-coco"
-
     @cached_property
     def default_feature_extractor(self):
-        return MaskFormerFeatureExtractor.from_pretrained(self.model_checkpoints) if is_vision_available() else None
+        return (
+            MaskFormerFeatureExtractor.from_pretrained("facebook/maskformer-swin-small-coco")
+            if is_vision_available()
+            else None
+        )

     def test_inference_no_head(self):
-        model = MaskFormerModel.from_pretrained(self.model_checkpoints).to(torch_device)
+        model = MaskFormerModel.from_pretrained("facebook/maskformer-swin-small-coco").to(torch_device)
         feature_extractor = self.default_feature_extractor
         image = prepare_img()
         inputs = feature_extractor(image, return_tensors="pt").to(torch_device)
@@ -370,7 +370,11 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
         )

     def test_inference_instance_segmentation_head(self):
-        model = MaskFormerForInstanceSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
+        model = (
+            MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-small-coco")
+            .to(torch_device)
+            .eval()
+        )
         feature_extractor = self.default_feature_extractor
         image = prepare_img()
         inputs = feature_extractor(image, return_tensors="pt").to(torch_device)
@@ -385,7 +389,8 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
         # masks_queries_logits
         masks_queries_logits = outputs.masks_queries_logits
         self.assertEqual(
-            masks_queries_logits.shape, (1, model.config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4)
+            masks_queries_logits.shape,
+            (1, model.config.decoder_config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4),
         )
         expected_slice = [
             [-1.3737124, -1.7724937, -1.9364233],
@@ -396,7 +401,9 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
         self.assertTrue(torch.allclose(masks_queries_logits[0, 0, :3, :3], expected_slice, atol=TOLERANCE))
         # class_queries_logits
         class_queries_logits = outputs.class_queries_logits
-        self.assertEqual(class_queries_logits.shape, (1, model.config.num_queries, model.config.num_labels + 1))
+        self.assertEqual(
+            class_queries_logits.shape, (1, model.config.decoder_config.num_queries, model.config.num_labels + 1)
+        )
         expected_slice = torch.tensor(
             [
                 [1.6512e00, -5.2572e00, -3.3519e00],
@@ -406,8 +413,48 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
         ).to(torch_device)
         self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))

+    def test_inference_instance_segmentation_head_resnet_backbone(self):
+        model = (
+            MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-resnet101-coco-stuff")
+            .to(torch_device)
+            .eval()
+        )
+        feature_extractor = self.default_feature_extractor
+        image = prepare_img()
+        inputs = feature_extractor(image, return_tensors="pt").to(torch_device)
+        inputs_shape = inputs["pixel_values"].shape
+        # check size is divisible by 32
+        self.assertTrue((inputs_shape[-1] % 32) == 0 and (inputs_shape[-2] % 32) == 0)
+        # check size
+        self.assertEqual(inputs_shape, (1, 3, 800, 1088))
+
+        with torch.no_grad():
+            outputs = model(**inputs)
+
+        # masks_queries_logits
+        masks_queries_logits = outputs.masks_queries_logits
+        self.assertEqual(
+            masks_queries_logits.shape,
+            (1, model.config.decoder_config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4),
+        )
+        expected_slice = [[-0.9046, -2.6366, -4.6062], [-3.4179, -5.7890, -8.8057], [-4.9179, -7.6560, -10.7711]]
+        expected_slice = torch.tensor(expected_slice).to(torch_device)
+        self.assertTrue(torch.allclose(masks_queries_logits[0, 0, :3, :3], expected_slice, atol=TOLERANCE))
+        # class_queries_logits
+        class_queries_logits = outputs.class_queries_logits
+        self.assertEqual(
+            class_queries_logits.shape, (1, model.config.decoder_config.num_queries, model.config.num_labels + 1)
+        )
+        expected_slice = torch.tensor(
+            [[4.7188, -3.2585, -2.8857], [6.6871, -2.9181, -1.2487], [7.2449, -2.2764, -2.1874]]
+        ).to(torch_device)
+        self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))
+
     def test_with_segmentation_maps_and_loss(self):
-        model = MaskFormerForInstanceSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
+        model = (
+            MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-small-coco")
+            .to(torch_device)
+            .eval()
+        )
         feature_extractor = self.default_feature_extractor
         inputs = feature_extractor(
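The new integration test corresponds to the following kind of end-to-end usage. This is a sketch: the checkpoint name is taken from the test above, and the post-processing call assumes the existing `MaskFormerFeatureExtractor.post_process_semantic_segmentation` API.

```python
import requests
import torch
from PIL import Image
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation

# Same COCO image used by prepare_img() in the tests
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = MaskFormerFeatureExtractor.from_pretrained("facebook/maskformer-resnet101-coco-stuff")
model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-resnet101-coco-stuff").eval()

inputs = feature_extractor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Post-process the query logits into a per-pixel semantic map at the original image size
segmentation = feature_extractor.post_process_semantic_segmentation(
    outputs, target_sizes=[image.size[::-1]]
)[0]
print(segmentation.shape)
```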