Unverified Commit b610c47f authored by NielsRogge, committed by GitHub

[MaskFormer] Add support for ResNet backbone (#20483)



* Add SwinBackbone

* Add hidden_states_before_downsampling support

* Fix Swin tests

* Improve conversion script

* Add id2label mappings

* Add vistas mapping

* Update comments

* Fix backbone

* Improve tests

* Extend conversion script

* Add Swin conversion script

* Fix style

* Revert config attribute

* Remove SwinBackbone from main init

* Remove unused attribute

* Use encoder for ResNet backbone

* Improve conversion script and add integration test

* Apply suggestion
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 6c1a0b39
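Before the diff itself, a minimal sketch of what this commit enables. `ResNetConfig` and its `out_features` argument are assumptions implied by the ResNet backbone changes further down, not something shown verbatim in this diff:

```python
from transformers import MaskFormerConfig, ResNetConfig

# MaskFormer can now sit on top of a ResNet backbone instead of the default Swin.
backbone_config = ResNetConfig(out_features=["stage1", "stage2", "stage3", "stage4"])
config = MaskFormerConfig(backbone_config=backbone_config)

print(config.backbone_config.model_type)  # "resnet"
```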
diff --git a/src/transformers/models/maskformer/configuration_maskformer.py b/src/transformers/models/maskformer/configuration_maskformer.py
@@ -18,7 +18,7 @@ from typing import Dict, Optional
 
 from ...configuration_utils import PretrainedConfig
 from ...utils import logging
-from ..auto.configuration_auto import AutoConfig
+from ..auto import CONFIG_MAPPING
 from ..detr import DetrConfig
 from ..swin import SwinConfig
@@ -97,7 +97,7 @@ class MaskFormerConfig(PretrainedConfig):
     """
 
     model_type = "maskformer"
     attribute_map = {"hidden_size": "mask_feature_size"}
-    backbones_supported = ["swin"]
+    backbones_supported = ["resnet", "swin"]
     decoders_supported = ["detr"]
 
@@ -127,27 +127,38 @@ def __init__(
                 num_heads=[4, 8, 16, 32],
                 window_size=12,
                 drop_path_rate=0.3,
+                out_features=["stage1", "stage2", "stage3", "stage4"],
             )
         else:
-            backbone_model_type = backbone_config.pop("model_type")
+            # verify that the backbone is supported
+            backbone_model_type = (
+                backbone_config.pop("model_type") if isinstance(backbone_config, dict) else backbone_config.model_type
+            )
             if backbone_model_type not in self.backbones_supported:
                 raise ValueError(
                     f"Backbone {backbone_model_type} not supported, please use one of"
                     f" {','.join(self.backbones_supported)}"
                 )
-            backbone_config = AutoConfig.for_model(backbone_model_type, **backbone_config)
+            if isinstance(backbone_config, dict):
+                config_class = CONFIG_MAPPING[backbone_model_type]
+                backbone_config = config_class.from_dict(backbone_config)
 
         if decoder_config is None:
             # fall back to https://huggingface.co/facebook/detr-resnet-50
             decoder_config = DetrConfig()
         else:
-            decoder_type = decoder_config.pop("model_type")
+            # verify that the decoder is supported
+            decoder_type = (
+                decoder_config.pop("model_type") if isinstance(decoder_config, dict) else decoder_config.model_type
+            )
             if decoder_type not in self.decoders_supported:
                 raise ValueError(
                     f"Transformer Decoder {decoder_type} not supported, please use one of"
                     f" {','.join(self.decoders_supported)}"
                 )
-            decoder_config = AutoConfig.for_model(decoder_type, **decoder_config)
+            if isinstance(decoder_config, dict):
+                config_class = CONFIG_MAPPING[decoder_type]
+                decoder_config = config_class.from_dict(decoder_config)
 
         self.backbone_config = backbone_config
         self.decoder_config = decoder_config
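Replacing `AutoConfig.for_model` with a `CONFIG_MAPPING` lookup means the constructor now accepts either a plain dict (carrying a `model_type` key) or an already-instantiated config object. A short sketch of both call styles, assuming a library version that contains this change:

```python
from transformers import MaskFormerConfig, ResNetConfig

# dict input: "model_type" is popped and routed through CONFIG_MAPPING
config = MaskFormerConfig(backbone_config={"model_type": "resnet", "layer_type": "bottleneck"})

# config-object input: model_type is read directly off the object
config = MaskFormerConfig(backbone_config=ResNetConfig())
```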
@@ -186,8 +197,8 @@
             [`MaskFormerConfig`]: An instance of a configuration object
         """
         return cls(
-            backbone_config=backbone_config.to_dict(),
-            decoder_config=decoder_config.to_dict(),
+            backbone_config=backbone_config,
+            decoder_config=decoder_config,
             **kwargs,
         )
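With the `to_dict()` round-trip removed, `from_backbone_and_decoder_configs` hands the config objects to `__init__` unchanged, which the new `isinstance` branches above handle. A usage sketch under the same assumptions:

```python
from transformers import DetrConfig, MaskFormerConfig, ResNetConfig

config = MaskFormerConfig.from_backbone_and_decoder_configs(
    backbone_config=ResNetConfig(), decoder_config=DetrConfig()
)
assert isinstance(config.backbone_config, ResNetConfig)  # kept as a config object
```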
......
diff --git a/src/transformers/models/maskformer/configuration_maskformer_swin.py b/src/transformers/models/maskformer/configuration_maskformer_swin.py
@@ -69,7 +69,7 @@ class MaskFormerSwinConfig(PretrainedConfig):
         layer_norm_eps (`float`, *optional*, defaults to 1e-12):
             The epsilon used by the layer normalization layers.
         out_features (`List[str]`, *optional*):
-            If used as a backbone, list of feature names to output, e.g. `["stem", "stage1"]`.
+            If used as a backbone, list of feature names to output, e.g. `["stage1", "stage2"]`.
 
     Example:
......
diff --git a/src/transformers/models/regnet/modeling_regnet.py b/src/transformers/models/regnet/modeling_regnet.py
@@ -275,7 +275,6 @@ class RegNetEncoder(nn.Module):
         return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states)
 
 
-# Copied from transformers.models.resnet.modeling_resnet.ResNetPreTrainedModel with ResNet->RegNet,resnet->regnet
 class RegNetPreTrainedModel(PreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -287,6 +286,7 @@ class RegNetPreTrainedModel(PreTrainedModel):
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
 
+    # Copied from transformers.models.resnet.modeling_resnet.ResNetPreTrainedModel._init_weights
     def _init_weights(self, module):
         if isinstance(module, nn.Conv2d):
             nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
......
diff --git a/src/transformers/models/resnet/modeling_resnet.py b/src/transformers/models/resnet/modeling_resnet.py
@@ -267,7 +267,7 @@ class ResNetPreTrainedModel(PreTrainedModel):
             nn.init.constant_(module.bias, 0)
 
     def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, ResNetModel):
+        if isinstance(module, (ResNetModel, ResNetBackbone)):
             module.gradient_checkpointing = value
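Adding `ResNetBackbone` to the `isinstance` check lets the generic `PreTrainedModel` gradient-checkpointing toggle reach the backbone as well; roughly:

```python
from transformers import ResNetBackbone, ResNetConfig

# gradient_checkpointing_enable() walks submodules and calls
# _set_gradient_checkpointing, which now also matches ResNetBackbone
model = ResNetBackbone(ResNetConfig(out_features=["stage4"]))
model.gradient_checkpointing_enable()
```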
@@ -436,7 +436,8 @@ class ResNetBackbone(ResNetPreTrainedModel):
         super().__init__(config)
 
         self.stage_names = config.stage_names
-        self.resnet = ResNetModel(config)
+        self.embedder = ResNetEmbeddings(config)
+        self.encoder = ResNetEncoder(config)
 
         self.out_features = config.out_features
@@ -490,7 +491,9 @@ class ResNetBackbone(ResNetPreTrainedModel):
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
 
-        outputs = self.resnet(pixel_values, output_hidden_states=True, return_dict=True)
+        embedding_output = self.embedder(pixel_values)
+
+        outputs = self.encoder(embedding_output, output_hidden_states=True, return_dict=True)
 
         hidden_states = outputs.hidden_states
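Running the embedder and encoder directly (instead of a full `ResNetModel`, whose pooled output the backbone never used) leaves the public behaviour unchanged: the backbone still returns one feature map per requested stage. A sketch, with the checkpoint name as an illustrative assumption rather than something taken from this diff:

```python
import torch
from transformers import ResNetBackbone

backbone = ResNetBackbone.from_pretrained("microsoft/resnet-50", out_features=["stage2", "stage4"])

with torch.no_grad():
    outputs = backbone(torch.randn(1, 3, 224, 224))

# e.g. stage4 of a ResNet-50 at 224x224 -> (1, 2048, 7, 7)
for name, feature_map in zip(backbone.out_features, outputs.feature_maps):
    print(name, feature_map.shape)
```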
......
diff --git a/src/transformers/models/swin/configuration_swin.py b/src/transformers/models/swin/configuration_swin.py
@@ -84,7 +84,7 @@ class SwinConfig(PretrainedConfig):
         encoder_stride (`int`, `optional`, defaults to 32):
             Factor to increase the spatial resolution by in the decoder head for masked image modeling.
 
-    Example:
+    Example:
 
     ```python
     >>> from transformers import SwinConfig, SwinModel
......
diff --git a/tests/models/maskformer/test_modeling_maskformer.py b/tests/models/maskformer/test_modeling_maskformer.py
@@ -320,16 +320,16 @@ def prepare_img():
 @require_vision
 @slow
 class MaskFormerModelIntegrationTest(unittest.TestCase):
-    @cached_property
-    def model_checkpoints(self):
-        return "facebook/maskformer-swin-small-coco"
-
     @cached_property
     def default_feature_extractor(self):
-        return MaskFormerFeatureExtractor.from_pretrained(self.model_checkpoints) if is_vision_available() else None
+        return (
+            MaskFormerFeatureExtractor.from_pretrained("facebook/maskformer-swin-small-coco")
+            if is_vision_available()
+            else None
+        )
 
     def test_inference_no_head(self):
-        model = MaskFormerModel.from_pretrained(self.model_checkpoints).to(torch_device)
+        model = MaskFormerModel.from_pretrained("facebook/maskformer-swin-small-coco").to(torch_device)
         feature_extractor = self.default_feature_extractor
         image = prepare_img()
         inputs = feature_extractor(image, return_tensors="pt").to(torch_device)
@@ -370,7 +370,11 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
         )
 
     def test_inference_instance_segmentation_head(self):
-        model = MaskFormerForInstanceSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
+        model = (
+            MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-small-coco")
+            .to(torch_device)
+            .eval()
+        )
         feature_extractor = self.default_feature_extractor
         image = prepare_img()
         inputs = feature_extractor(image, return_tensors="pt").to(torch_device)
@@ -385,7 +389,8 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
         # masks_queries_logits
         masks_queries_logits = outputs.masks_queries_logits
         self.assertEqual(
-            masks_queries_logits.shape, (1, model.config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4)
+            masks_queries_logits.shape,
+            (1, model.config.decoder_config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4),
         )
         expected_slice = [
             [-1.3737124, -1.7724937, -1.9364233],
@@ -396,7 +401,9 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
         self.assertTrue(torch.allclose(masks_queries_logits[0, 0, :3, :3], expected_slice, atol=TOLERANCE))
 
         # class_queries_logits
         class_queries_logits = outputs.class_queries_logits
-        self.assertEqual(class_queries_logits.shape, (1, model.config.num_queries, model.config.num_labels + 1))
+        self.assertEqual(
+            class_queries_logits.shape, (1, model.config.decoder_config.num_queries, model.config.num_labels + 1)
+        )
         expected_slice = torch.tensor(
             [
                 [1.6512e00, -5.2572e00, -3.3519e00],
@@ -406,8 +413,48 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
         ).to(torch_device)
         self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))
 
+    def test_inference_instance_segmentation_head_resnet_backbone(self):
+        model = (
+            MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-resnet101-coco-stuff")
+            .to(torch_device)
+            .eval()
+        )
+        feature_extractor = self.default_feature_extractor
+        image = prepare_img()
+        inputs = feature_extractor(image, return_tensors="pt").to(torch_device)
+        inputs_shape = inputs["pixel_values"].shape
+        # check size is divisible by 32
+        self.assertTrue((inputs_shape[-1] % 32) == 0 and (inputs_shape[-2] % 32) == 0)
+        # check size
+        self.assertEqual(inputs_shape, (1, 3, 800, 1088))
+
+        with torch.no_grad():
+            outputs = model(**inputs)
+
+        # masks_queries_logits
+        masks_queries_logits = outputs.masks_queries_logits
+        self.assertEqual(
+            masks_queries_logits.shape,
+            (1, model.config.decoder_config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4),
+        )
+        expected_slice = [[-0.9046, -2.6366, -4.6062], [-3.4179, -5.7890, -8.8057], [-4.9179, -7.6560, -10.7711]]
+        expected_slice = torch.tensor(expected_slice).to(torch_device)
+        self.assertTrue(torch.allclose(masks_queries_logits[0, 0, :3, :3], expected_slice, atol=TOLERANCE))
+
+        # class_queries_logits
+        class_queries_logits = outputs.class_queries_logits
+        self.assertEqual(
+            class_queries_logits.shape, (1, model.config.decoder_config.num_queries, model.config.num_labels + 1)
+        )
+        expected_slice = torch.tensor(
+            [[4.7188, -3.2585, -2.8857], [6.6871, -2.9181, -1.2487], [7.2449, -2.2764, -2.1874]]
+        ).to(torch_device)
+        self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))
+
     def test_with_segmentation_maps_and_loss(self):
-        model = MaskFormerForInstanceSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
+        model = (
+            MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-small-coco")
+            .to(torch_device)
+            .eval()
+        )
         feature_extractor = self.default_feature_extractor
         inputs = feature_extractor(
......
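The new integration test stops at raw logits; for completeness, a sketch of running the ResNet-backed checkpoint end to end and collapsing the per-query outputs into a per-pixel map via the feature extractor's post-processing. The test image URL is an assumption; the checkpoint name is the one used in the test above:

```python
import requests
import torch
from PIL import Image

from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation

# standard COCO sample image, assumed here purely for illustration
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = MaskFormerFeatureExtractor.from_pretrained("facebook/maskformer-resnet101-coco-stuff")
model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-resnet101-coco-stuff").eval()

inputs = feature_extractor(image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# per-pixel class map at the original resolution; PIL size is (width, height)
semantic_map = feature_extractor.post_process_semantic_segmentation(
    outputs, target_sizes=[image.size[::-1]]
)[0]
print(semantic_map.shape)
```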