"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "7579a52b55611ba7651b6d05cba6f45539a6089d"
Unverified commit b610c47f authored by NielsRogge, committed by GitHub

[MaskFormer] Add support for ResNet backbone (#20483)



* Add SwinBackbone

* Add hidden_states_before_downsampling support

* Fix Swin tests

* Improve conversion script

* Add id2label mappings

* Add vistas mapping

* Update comments

* Fix backbone

* Improve tests

* Extend conversion script

* Add Swin conversion script

* Fix style

* Revert config attribute

* Remove SwinBackbone from main init

* Remove unused attribute

* Use encoder for ResNet backbone

* Improve conversion script and add integration test

* Apply suggestion
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 6c1a0b39
@@ -18,7 +18,7 @@ from typing import Dict, Optional
from ...configuration_utils import PretrainedConfig
from ...utils import logging
-from ..auto.configuration_auto import AutoConfig
+from ..auto import CONFIG_MAPPING
from ..detr import DetrConfig
from ..swin import SwinConfig
@@ -97,7 +97,7 @@ class MaskFormerConfig(PretrainedConfig):
    """
    model_type = "maskformer"
    attribute_map = {"hidden_size": "mask_feature_size"}
-    backbones_supported = ["swin"]
+    backbones_supported = ["resnet", "swin"]
    decoders_supported = ["detr"]
    def __init__(
@@ -127,27 +127,38 @@ class MaskFormerConfig(PretrainedConfig):
                num_heads=[4, 8, 16, 32],
                window_size=12,
                drop_path_rate=0.3,
+                out_features=["stage1", "stage2", "stage3", "stage4"],
            )
        else:
-            backbone_model_type = backbone_config.pop("model_type")
+            # verify that the backbone is supported
+            backbone_model_type = (
+                backbone_config.pop("model_type") if isinstance(backbone_config, dict) else backbone_config.model_type
+            )
            if backbone_model_type not in self.backbones_supported:
                raise ValueError(
                    f"Backbone {backbone_model_type} not supported, please use one of"
                    f" {','.join(self.backbones_supported)}"
                )
-            backbone_config = AutoConfig.for_model(backbone_model_type, **backbone_config)
+            if isinstance(backbone_config, dict):
+                config_class = CONFIG_MAPPING[backbone_model_type]
+                backbone_config = config_class.from_dict(backbone_config)
        if decoder_config is None:
            # fall back to https://huggingface.co/facebook/detr-resnet-50
            decoder_config = DetrConfig()
        else:
-            decoder_type = decoder_config.pop("model_type")
+            # verify that the decoder is supported
+            decoder_type = (
+                decoder_config.pop("model_type") if isinstance(decoder_config, dict) else decoder_config.model_type
+            )
            if decoder_type not in self.decoders_supported:
                raise ValueError(
                    f"Transformer Decoder {decoder_type} not supported, please use one of"
                    f" {','.join(self.decoders_supported)}"
                )
-            decoder_config = AutoConfig.for_model(decoder_type, **decoder_config)
+            if isinstance(decoder_config, dict):
+                config_class = CONFIG_MAPPING[decoder_type]
+                decoder_config = config_class.from_dict(decoder_config)
        self.backbone_config = backbone_config
        self.decoder_config = decoder_config
@@ -186,8 +197,8 @@ class MaskFormerConfig(PretrainedConfig):
            [`MaskFormerConfig`]: An instance of a configuration object
        """
        return cls(
-            backbone_config=backbone_config.to_dict(),
+            backbone_config=backbone_config,
-            decoder_config=decoder_config.to_dict(),
+            decoder_config=decoder_config,
            **kwargs,
        )
...
@@ -69,7 +69,7 @@ class MaskFormerSwinConfig(PretrainedConfig):
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        out_features (`List[str]`, *optional*):
-            If used as a backbone, list of feature names to output, e.g. `["stem", "stage1"]`.
+            If used as a backbone, list of feature names to output, e.g. `["stage1", "stage2"]`.
    Example:
...
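Taken together, the configuration changes above mean `MaskFormerConfig` now accepts a ResNet backbone, passed either as a config object or as a plain dict whose `model_type` key is resolved through `CONFIG_MAPPING`. A minimal sketch of the intended usage (the `out_features` values mirror those used in the conversion script below):

```python
from transformers import MaskFormerConfig, ResNetConfig

# backbone passed as a config object
backbone_config = ResNetConfig.from_pretrained(
    "microsoft/resnet-50", out_features=["stage1", "stage2", "stage3", "stage4"]
)
config = MaskFormerConfig(backbone_config=backbone_config)

# backbone passed as a dict: the "model_type" key selects the config class via CONFIG_MAPPING
config_from_dict = MaskFormerConfig(backbone_config=backbone_config.to_dict())
assert config_from_dict.backbone_config.model_type == "resnet"
```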
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert MaskFormer checkpoints with ResNet backbone from the original repository. URL:
https://github.com/facebookresearch/MaskFormer"""
import argparse
import json
import pickle
from pathlib import Path
import torch
from PIL import Image
import requests
from huggingface_hub import hf_hub_download
from transformers import MaskFormerConfig, MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation, ResNetConfig
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def get_maskformer_config(model_name: str):
if "resnet101c" in model_name:
# TODO add support for ResNet-C backbone, which uses a "deeplab" stem
raise NotImplementedError("To do")
elif "resnet101" in model_name:
backbone_config = ResNetConfig.from_pretrained(
"microsoft/resnet-101", out_features=["stage1", "stage2", "stage3", "stage4"]
)
else:
backbone_config = ResNetConfig.from_pretrained(
"microsoft/resnet-50", out_features=["stage1", "stage2", "stage3", "stage4"]
)
config = MaskFormerConfig(backbone_config=backbone_config)
repo_id = "huggingface/label-files"
if "ade20k-full" in model_name:
config.num_labels = 847
filename = "maskformer-ade20k-full-id2label.json"
elif "ade" in model_name:
config.num_labels = 150
filename = "ade20k-id2label.json"
elif "coco-stuff" in model_name:
config.num_labels = 171
filename = "maskformer-coco-stuff-id2label.json"
elif "coco" in model_name:
# TODO
config.num_labels = 133
filename = "coco-panoptic-id2label.json"
elif "cityscapes" in model_name:
config.num_labels = 19
filename = "cityscapes-id2label.json"
elif "vistas" in model_name:
config.num_labels = 65
filename = "mapillary-vistas-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
return config
def create_rename_keys(config):
rename_keys = []
# stem
# fmt: off
rename_keys.append(("backbone.stem.conv1.weight", "model.pixel_level_module.encoder.embedder.embedder.convolution.weight"))
rename_keys.append(("backbone.stem.conv1.norm.weight", "model.pixel_level_module.encoder.embedder.embedder.normalization.weight"))
rename_keys.append(("backbone.stem.conv1.norm.bias", "model.pixel_level_module.encoder.embedder.embedder.normalization.bias"))
rename_keys.append(("backbone.stem.conv1.norm.running_mean", "model.pixel_level_module.encoder.embedder.embedder.normalization.running_mean"))
rename_keys.append(("backbone.stem.conv1.norm.running_var", "model.pixel_level_module.encoder.embedder.embedder.normalization.running_var"))
# fmt: on
# stages
for stage_idx in range(len(config.backbone_config.depths)):
for layer_idx in range(config.backbone_config.depths[stage_idx]):
# shortcut
if layer_idx == 0:
rename_keys.append(
(
f"backbone.res{stage_idx + 2}.{layer_idx}.shortcut.weight",
f"model.pixel_level_module.encoder.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.convolution.weight",
)
)
rename_keys.append(
(
f"backbone.res{stage_idx + 2}.{layer_idx}.shortcut.norm.weight",
f"model.pixel_level_module.encoder.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.weight",
)
)
rename_keys.append(
(
f"backbone.res{stage_idx + 2}.{layer_idx}.shortcut.norm.bias",
f"model.pixel_level_module.encoder.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.bias",
)
)
rename_keys.append(
(
f"backbone.res{stage_idx + 2}.{layer_idx}.shortcut.norm.running_mean",
f"model.pixel_level_module.encoder.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.running_mean",
)
)
rename_keys.append(
(
f"backbone.res{stage_idx + 2}.{layer_idx}.shortcut.norm.running_var",
f"model.pixel_level_module.encoder.encoder.stages.{stage_idx}.layers.{layer_idx}.shortcut.normalization.running_var",
)
)
# 3 convs
for i in range(3):
rename_keys.append(
(
f"backbone.res{stage_idx + 2}.{layer_idx}.conv{i+1}.weight",
f"model.pixel_level_module.encoder.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.convolution.weight",
)
)
rename_keys.append(
(
f"backbone.res{stage_idx + 2}.{layer_idx}.conv{i+1}.norm.weight",
f"model.pixel_level_module.encoder.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.normalization.weight",
)
)
rename_keys.append(
(
f"backbone.res{stage_idx + 2}.{layer_idx}.conv{i+1}.norm.bias",
f"model.pixel_level_module.encoder.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.normalization.bias",
)
)
rename_keys.append(
(
f"backbone.res{stage_idx + 2}.{layer_idx}.conv{i+1}.norm.running_mean",
f"model.pixel_level_module.encoder.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.normalization.running_mean",
)
)
rename_keys.append(
(
f"backbone.res{stage_idx + 2}.{layer_idx}.conv{i+1}.norm.running_var",
f"model.pixel_level_module.encoder.encoder.stages.{stage_idx}.layers.{layer_idx}.layer.{i}.normalization.running_var",
)
)
# FPN
# fmt: off
rename_keys.append(("sem_seg_head.layer_4.weight", "model.pixel_level_module.decoder.fpn.stem.0.weight"))
rename_keys.append(("sem_seg_head.layer_4.norm.weight", "model.pixel_level_module.decoder.fpn.stem.1.weight"))
rename_keys.append(("sem_seg_head.layer_4.norm.bias", "model.pixel_level_module.decoder.fpn.stem.1.bias"))
for source_index, target_index in zip(range(3, 0, -1), range(0, 3)):
rename_keys.append((f"sem_seg_head.adapter_{source_index}.weight", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.proj.0.weight"))
rename_keys.append((f"sem_seg_head.adapter_{source_index}.norm.weight", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.proj.1.weight"))
rename_keys.append((f"sem_seg_head.adapter_{source_index}.norm.bias", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.proj.1.bias"))
rename_keys.append((f"sem_seg_head.layer_{source_index}.weight", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.block.0.weight"))
rename_keys.append((f"sem_seg_head.layer_{source_index}.norm.weight", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.block.1.weight"))
rename_keys.append((f"sem_seg_head.layer_{source_index}.norm.bias", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.block.1.bias"))
rename_keys.append(("sem_seg_head.mask_features.weight", "model.pixel_level_module.decoder.mask_projection.weight"))
rename_keys.append(("sem_seg_head.mask_features.bias", "model.pixel_level_module.decoder.mask_projection.bias"))
# fmt: on
# Transformer decoder
# fmt: off
for idx in range(config.decoder_config.decoder_layers):
# self-attention out projection
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.self_attn.out_proj.weight", f"model.transformer_module.decoder.layers.{idx}.self_attn.out_proj.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.self_attn.out_proj.bias", f"model.transformer_module.decoder.layers.{idx}.self_attn.out_proj.bias"))
# cross-attention out projection
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.multihead_attn.out_proj.weight", f"model.transformer_module.decoder.layers.{idx}.encoder_attn.out_proj.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.multihead_attn.out_proj.bias", f"model.transformer_module.decoder.layers.{idx}.encoder_attn.out_proj.bias"))
# MLP 1
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.linear1.weight", f"model.transformer_module.decoder.layers.{idx}.fc1.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.linear1.bias", f"model.transformer_module.decoder.layers.{idx}.fc1.bias"))
# MLP 2
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.linear2.weight", f"model.transformer_module.decoder.layers.{idx}.fc2.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.linear2.bias", f"model.transformer_module.decoder.layers.{idx}.fc2.bias"))
# layernorm 1 (self-attention layernorm)
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm1.weight", f"model.transformer_module.decoder.layers.{idx}.self_attn_layer_norm.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm1.bias", f"model.transformer_module.decoder.layers.{idx}.self_attn_layer_norm.bias"))
# layernorm 2 (cross-attention layernorm)
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm2.weight", f"model.transformer_module.decoder.layers.{idx}.encoder_attn_layer_norm.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm2.bias", f"model.transformer_module.decoder.layers.{idx}.encoder_attn_layer_norm.bias"))
# layernorm 3 (final layernorm)
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm3.weight", f"model.transformer_module.decoder.layers.{idx}.final_layer_norm.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm3.bias", f"model.transformer_module.decoder.layers.{idx}.final_layer_norm.bias"))
rename_keys.append(("sem_seg_head.predictor.transformer.decoder.norm.weight", "model.transformer_module.decoder.layernorm.weight"))
rename_keys.append(("sem_seg_head.predictor.transformer.decoder.norm.bias", "model.transformer_module.decoder.layernorm.bias"))
# fmt: on
# heads on top
# fmt: off
rename_keys.append(("sem_seg_head.predictor.query_embed.weight", "model.transformer_module.queries_embedder.weight"))
rename_keys.append(("sem_seg_head.predictor.input_proj.weight", "model.transformer_module.input_projection.weight"))
rename_keys.append(("sem_seg_head.predictor.input_proj.bias", "model.transformer_module.input_projection.bias"))
rename_keys.append(("sem_seg_head.predictor.class_embed.weight", "class_predictor.weight"))
rename_keys.append(("sem_seg_head.predictor.class_embed.bias", "class_predictor.bias"))
for i in range(3):
rename_keys.append((f"sem_seg_head.predictor.mask_embed.layers.{i}.weight", f"mask_embedder.{i}.0.weight"))
rename_keys.append((f"sem_seg_head.predictor.mask_embed.layers.{i}.bias", f"mask_embedder.{i}.0.bias"))
# fmt: on
return rename_keys
def rename_key(dct, old, new):
val = dct.pop(old)
dct[new] = val
# we split up the matrix of each decoder layer into queries, keys and values
def read_in_decoder_q_k_v(state_dict, config):
# fmt: off
hidden_size = config.decoder_config.hidden_size
for idx in range(config.decoder_config.decoder_layers):
# read in weights + bias of self-attention input projection layer (in the original implementation, this is a single matrix + bias)
in_proj_weight = state_dict.pop(f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.self_attn.in_proj_weight")
in_proj_bias = state_dict.pop(f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.self_attn.in_proj_bias")
# next, add query, keys and values (in that order) to the state dict
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.q_proj.weight"] = in_proj_weight[: hidden_size, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.q_proj.bias"] = in_proj_bias[:config.hidden_size]
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.k_proj.weight"] = in_proj_weight[hidden_size : hidden_size * 2, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.k_proj.bias"] = in_proj_bias[hidden_size : hidden_size * 2]
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.v_proj.weight"] = in_proj_weight[-hidden_size :, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.v_proj.bias"] = in_proj_bias[-hidden_size :]
# read in weights + bias of cross-attention input projection layer (in the original implementation, this is a single matrix + bias)
in_proj_weight = state_dict.pop(f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.multihead_attn.in_proj_weight")
in_proj_bias = state_dict.pop(f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.multihead_attn.in_proj_bias")
# next, add query, keys and values (in that order) to the state dict
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.q_proj.weight"] = in_proj_weight[: hidden_size, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.q_proj.bias"] = in_proj_bias[:config.hidden_size]
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.k_proj.weight"] = in_proj_weight[hidden_size : hidden_size * 2, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.k_proj.bias"] = in_proj_bias[hidden_size : hidden_size * 2]
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.v_proj.weight"] = in_proj_weight[-hidden_size :, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.v_proj.bias"] = in_proj_bias[-hidden_size :]
# fmt: on
# We will verify our results on an image of cute cats
def prepare_img() -> Image.Image:
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im
@torch.no_grad()
def convert_maskformer_checkpoint(
model_name: str, checkpoint_path: str, pytorch_dump_folder_path: str, push_to_hub: bool = False
):
"""
Copy/paste/tweak model's weights to our MaskFormer structure.
"""
config = get_maskformer_config(model_name)
# load original state_dict
with open(checkpoint_path, "rb") as f:
data = pickle.load(f)
state_dict = data["model"]
# rename keys
rename_keys = create_rename_keys(config)
for src, dest in rename_keys:
rename_key(state_dict, src, dest)
read_in_decoder_q_k_v(state_dict, config)
# update to torch tensors
for key, value in state_dict.items():
state_dict[key] = torch.from_numpy(value)
# load 🤗 model
model = MaskFormerForInstanceSegmentation(config)
model.eval()
model.load_state_dict(state_dict)
# verify results
image = prepare_img()
if "vistas" in model_name:
ignore_index = 65
elif "cityscapes" in model_name:
ignore_index = 65535
else:
ignore_index = 255
reduce_labels = True if "ade" in model_name else False
feature_extractor = MaskFormerFeatureExtractor(ignore_index=ignore_index, reduce_labels=reduce_labels)
inputs = feature_extractor(image, return_tensors="pt")
outputs = model(**inputs)
if model_name == "maskformer-resnet50-ade":
expected_logits = torch.tensor(
[[6.7710, -0.1452, -3.5687], [1.9165, -1.0010, -1.8614], [3.6209, -0.2950, -1.3813]]
)
elif model_name == "maskformer-resnet101-ade":
expected_logits = torch.tensor(
[[4.0381, -1.1483, -1.9688], [2.7083, -1.9147, -2.2555], [3.4367, -1.3711, -2.1609]]
)
elif model_name == "maskformer-resnet50-coco-stuff":
expected_logits = torch.tensor(
[[3.2309, -3.0481, -2.8695], [5.4986, -5.4242, -2.4211], [6.2100, -5.2279, -2.7786]]
)
elif model_name == "maskformer-resnet101-coco-stuff":
expected_logits = torch.tensor(
[[4.7188, -3.2585, -2.8857], [6.6871, -2.9181, -1.2487], [7.2449, -2.2764, -2.1874]]
)
elif model_name == "maskformer-resnet101-cityscapes":
expected_logits = torch.tensor(
[[-1.8861, -1.5465, 0.6749], [-2.3677, -1.6707, -0.0867], [-2.2314, -1.9530, -0.9132]]
)
elif model_name == "maskformer-resnet50-vistas":
expected_logits = torch.tensor(
[[-6.3917, -1.5216, -1.1392], [-5.5335, -4.5318, -1.8339], [-4.3576, -4.0301, 0.2162]]
)
elif model_name == "maskformer-resnet50-ade20k-full":
expected_logits = torch.tensor(
[[3.6146, -1.9367, -3.2534], [4.0099, 0.2027, -2.7576], [3.3913, -2.3644, -3.9519]]
)
elif model_name == "maskformer-resnet101-ade20k-full":
expected_logits = torch.tensor(
[[3.2211, -1.6550, -2.7605], [2.8559, -2.4512, -2.9574], [2.6331, -2.6775, -2.1844]]
)
assert torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_logits, atol=1e-4)
print("Looks ok!")
if pytorch_dump_folder_path is not None:
print(f"Saving model and feature extractor of {model_name} to {pytorch_dump_folder_path}")
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
model.save_pretrained(pytorch_dump_folder_path)
feature_extractor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
print(f"Pushing model and feature extractor of {model_name} to the hub...")
model.push_to_hub(f"facebook/{model_name}")
feature_extractor.push_to_hub(f"facebook/{model_name}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--model_name",
default="maskformer-resnet50-ade",
type=str,
required=True,
choices=[
"maskformer-resnet50-ade",
"maskformer-resnet101-ade",
"maskformer-resnet50-coco-stuff",
"maskformer-resnet101-coco-stuff",
"maskformer-resnet101-cityscapes",
"maskformer-resnet50-vistas",
"maskformer-resnet50-ade20k-full",
"maskformer-resnet101-ade20k-full",
],
help=("Name of the MaskFormer model you'd like to convert",),
)
parser.add_argument(
"--checkpoint_path",
type=str,
required=True,
help=("Path to the original pickle file (.pkl) of the original checkpoint.",),
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
)
parser.add_argument(
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
)
args = parser.parse_args()
convert_maskformer_checkpoint(
args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub
)
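For reference, a converted checkpoint can then be used like any other MaskFormer model. A hedged sketch, assuming the `facebook/maskformer-resnet101-coco-stuff` checkpoint produced by this script is available on the Hub (the same checkpoint the new integration test further below relies on):

```python
import requests
import torch
from PIL import Image

from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = MaskFormerFeatureExtractor.from_pretrained("facebook/maskformer-resnet101-coco-stuff")
model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-resnet101-coco-stuff")

inputs = feature_extractor(image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# class logits: (batch, num_queries, num_labels + 1); mask logits: (batch, num_queries, height / 4, width / 4)
print(outputs.class_queries_logits.shape)
print(outputs.masks_queries_logits.shape)
```

The per-query outputs can then be turned into segmentation maps with the feature extractor's post-processing methods, e.g. `post_process_semantic_segmentation` or `post_process_panoptic_segmentation`.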
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert MaskFormer checkpoints with Swin backbone from the original repository. URL:
https://github.com/facebookresearch/MaskFormer"""
import argparse
import json
import pickle
from pathlib import Path
import torch
from PIL import Image
import requests
from huggingface_hub import hf_hub_download
from transformers import MaskFormerConfig, MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation, SwinConfig
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def get_maskformer_config(model_name: str):
backbone_config = SwinConfig.from_pretrained(
"microsoft/swin-tiny-patch4-window7-224", out_features=["stage1", "stage2", "stage3", "stage4"]
)
config = MaskFormerConfig(backbone_config=backbone_config)
repo_id = "huggingface/label-files"
if "ade20k-full" in model_name:
# this should be ok
config.num_labels = 847
filename = "maskformer-ade20k-full-id2label.json"
elif "ade" in model_name:
# this should be ok
config.num_labels = 150
filename = "ade20k-id2label.json"
elif "coco-stuff" in model_name:
# this should be ok
config.num_labels = 171
filename = "maskformer-coco-stuff-id2label.json"
elif "coco" in model_name:
# TODO
config.num_labels = 133
filename = "coco-panoptic-id2label.json"
elif "cityscapes" in model_name:
# this should be ok
config.num_labels = 19
filename = "cityscapes-id2label.json"
elif "vistas" in model_name:
# this should be ok
config.num_labels = 65
filename = "mapillary-vistas-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
    id2label = {int(k): v for k, v in id2label.items()}
    config.id2label = id2label
    config.label2id = {v: k for k, v in id2label.items()}
    return config
def create_rename_keys(config):
rename_keys = []
# stem
# fmt: off
rename_keys.append(("backbone.patch_embed.proj.weight", "model.pixel_level_module.encoder.model.embeddings.patch_embeddings.projection.weight"))
rename_keys.append(("backbone.patch_embed.proj.bias", "model.pixel_level_module.encoder.model.embeddings.patch_embeddings.projection.bias"))
rename_keys.append(("backbone.patch_embed.norm.weight", "model.pixel_level_module.encoder.model.embeddings.norm.weight"))
rename_keys.append(("backbone.patch_embed.norm.bias", "model.pixel_level_module.encoder.model.embeddings.norm.bias"))
# stages
for i in range(len(config.backbone_config.depths)):
for j in range(config.backbone_config.depths[i]):
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.norm1.weight", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.layernorm_before.weight"))
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.norm1.bias", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.layernorm_before.bias"))
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.attn.relative_position_bias_table", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.self.relative_position_bias_table"))
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.attn.relative_position_index", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.self.relative_position_index"))
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.attn.proj.weight", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.output.dense.weight"))
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.attn.proj.bias", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.output.dense.bias"))
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.norm2.weight", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.layernorm_after.weight"))
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.norm2.bias", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.layernorm_after.bias"))
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.mlp.fc1.weight", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.intermediate.dense.weight"))
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.mlp.fc1.bias", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.intermediate.dense.bias"))
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.mlp.fc2.weight", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.output.dense.weight"))
rename_keys.append((f"backbone.layers.{i}.blocks.{j}.mlp.fc2.bias", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.output.dense.bias"))
if i < 3:
rename_keys.append((f"backbone.layers.{i}.downsample.reduction.weight", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.downsample.reduction.weight"))
rename_keys.append((f"backbone.layers.{i}.downsample.norm.weight", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.downsample.norm.weight"))
rename_keys.append((f"backbone.layers.{i}.downsample.norm.bias", f"model.pixel_level_module.encoder.model.encoder.layers.{i}.downsample.norm.bias"))
rename_keys.append((f"backbone.norm{i}.weight", f"model.pixel_level_module.encoder.hidden_states_norms.{i}.weight"))
rename_keys.append((f"backbone.norm{i}.bias", f"model.pixel_level_module.encoder.hidden_states_norms.{i}.bias"))
# FPN
rename_keys.append(("sem_seg_head.layer_4.weight", "model.pixel_level_module.decoder.fpn.stem.0.weight"))
rename_keys.append(("sem_seg_head.layer_4.norm.weight", "model.pixel_level_module.decoder.fpn.stem.1.weight"))
rename_keys.append(("sem_seg_head.layer_4.norm.bias", "model.pixel_level_module.decoder.fpn.stem.1.bias"))
for source_index, target_index in zip(range(3, 0, -1), range(0, 3)):
rename_keys.append((f"sem_seg_head.adapter_{source_index}.weight", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.proj.0.weight"))
rename_keys.append((f"sem_seg_head.adapter_{source_index}.norm.weight", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.proj.1.weight"))
rename_keys.append((f"sem_seg_head.adapter_{source_index}.norm.bias", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.proj.1.bias"))
rename_keys.append((f"sem_seg_head.layer_{source_index}.weight", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.block.0.weight"))
rename_keys.append((f"sem_seg_head.layer_{source_index}.norm.weight", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.block.1.weight"))
rename_keys.append((f"sem_seg_head.layer_{source_index}.norm.bias", f"model.pixel_level_module.decoder.fpn.layers.{target_index}.block.1.bias"))
rename_keys.append(("sem_seg_head.mask_features.weight", "model.pixel_level_module.decoder.mask_projection.weight"))
rename_keys.append(("sem_seg_head.mask_features.bias", "model.pixel_level_module.decoder.mask_projection.bias"))
# Transformer decoder
for idx in range(config.decoder_config.decoder_layers):
# self-attention out projection
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.self_attn.out_proj.weight", f"model.transformer_module.decoder.layers.{idx}.self_attn.out_proj.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.self_attn.out_proj.bias", f"model.transformer_module.decoder.layers.{idx}.self_attn.out_proj.bias"))
# cross-attention out projection
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.multihead_attn.out_proj.weight", f"model.transformer_module.decoder.layers.{idx}.encoder_attn.out_proj.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.multihead_attn.out_proj.bias", f"model.transformer_module.decoder.layers.{idx}.encoder_attn.out_proj.bias"))
# MLP 1
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.linear1.weight", f"model.transformer_module.decoder.layers.{idx}.fc1.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.linear1.bias", f"model.transformer_module.decoder.layers.{idx}.fc1.bias"))
# MLP 2
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.linear2.weight", f"model.transformer_module.decoder.layers.{idx}.fc2.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.linear2.bias", f"model.transformer_module.decoder.layers.{idx}.fc2.bias"))
# layernorm 1 (self-attention layernorm)
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm1.weight", f"model.transformer_module.decoder.layers.{idx}.self_attn_layer_norm.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm1.bias", f"model.transformer_module.decoder.layers.{idx}.self_attn_layer_norm.bias"))
# layernorm 2 (cross-attention layernorm)
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm2.weight", f"model.transformer_module.decoder.layers.{idx}.encoder_attn_layer_norm.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm2.bias", f"model.transformer_module.decoder.layers.{idx}.encoder_attn_layer_norm.bias"))
# layernorm 3 (final layernorm)
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm3.weight", f"model.transformer_module.decoder.layers.{idx}.final_layer_norm.weight"))
rename_keys.append((f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.norm3.bias", f"model.transformer_module.decoder.layers.{idx}.final_layer_norm.bias"))
rename_keys.append(("sem_seg_head.predictor.transformer.decoder.norm.weight", "model.transformer_module.decoder.layernorm.weight"))
rename_keys.append(("sem_seg_head.predictor.transformer.decoder.norm.bias", "model.transformer_module.decoder.layernorm.bias"))
# heads on top
rename_keys.append(("sem_seg_head.predictor.query_embed.weight", "model.transformer_module.queries_embedder.weight"))
rename_keys.append(("sem_seg_head.predictor.input_proj.weight", "model.transformer_module.input_projection.weight"))
rename_keys.append(("sem_seg_head.predictor.input_proj.bias", "model.transformer_module.input_projection.bias"))
rename_keys.append(("sem_seg_head.predictor.class_embed.weight", "class_predictor.weight"))
rename_keys.append(("sem_seg_head.predictor.class_embed.bias", "class_predictor.bias"))
for i in range(3):
rename_keys.append((f"sem_seg_head.predictor.mask_embed.layers.{i}.weight", f"mask_embedder.{i}.0.weight"))
rename_keys.append((f"sem_seg_head.predictor.mask_embed.layers.{i}.bias", f"mask_embedder.{i}.0.bias"))
# fmt: on
return rename_keys
def rename_key(dct, old, new):
val = dct.pop(old)
dct[new] = val
# we split up the matrix of each encoder layer into queries, keys and values
def read_in_swin_q_k_v(state_dict, backbone_config):
num_features = [int(backbone_config.embed_dim * 2**i) for i in range(len(backbone_config.depths))]
for i in range(len(backbone_config.depths)):
dim = num_features[i]
for j in range(backbone_config.depths[i]):
# fmt: off
# read in weights + bias of input projection layer (in original implementation, this is a single matrix + bias)
in_proj_weight = state_dict.pop(f"backbone.layers.{i}.blocks.{j}.attn.qkv.weight")
in_proj_bias = state_dict.pop(f"backbone.layers.{i}.blocks.{j}.attn.qkv.bias")
# next, add query, keys and values (in that order) to the state dict
state_dict[f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.self.query.weight"] = in_proj_weight[:dim, :]
state_dict[f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.self.query.bias"] = in_proj_bias[: dim]
state_dict[f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.self.key.weight"] = in_proj_weight[
dim : dim * 2, :
]
state_dict[f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.self.key.bias"] = in_proj_bias[
dim : dim * 2
]
state_dict[f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.self.value.weight"] = in_proj_weight[
-dim :, :
]
state_dict[f"model.pixel_level_module.encoder.model.encoder.layers.{i}.blocks.{j}.attention.self.value.bias"] = in_proj_bias[-dim :]
# fmt: on
# we split up the matrix of each decoder layer into queries, keys and values
def read_in_decoder_q_k_v(state_dict, config):
# fmt: off
hidden_size = config.decoder_config.hidden_size
for idx in range(config.decoder_config.decoder_layers):
# read in weights + bias of self-attention input projection layer (in the original implementation, this is a single matrix + bias)
in_proj_weight = state_dict.pop(f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.self_attn.in_proj_weight")
in_proj_bias = state_dict.pop(f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.self_attn.in_proj_bias")
# next, add query, keys and values (in that order) to the state dict
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.q_proj.weight"] = in_proj_weight[: hidden_size, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.q_proj.bias"] = in_proj_bias[:config.hidden_size]
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.k_proj.weight"] = in_proj_weight[hidden_size : hidden_size * 2, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.k_proj.bias"] = in_proj_bias[hidden_size : hidden_size * 2]
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.v_proj.weight"] = in_proj_weight[-hidden_size :, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.self_attn.v_proj.bias"] = in_proj_bias[-hidden_size :]
# read in weights + bias of cross-attention input projection layer (in the original implementation, this is a single matrix + bias)
in_proj_weight = state_dict.pop(f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.multihead_attn.in_proj_weight")
in_proj_bias = state_dict.pop(f"sem_seg_head.predictor.transformer.decoder.layers.{idx}.multihead_attn.in_proj_bias")
# next, add query, keys and values (in that order) to the state dict
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.q_proj.weight"] = in_proj_weight[: hidden_size, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.q_proj.bias"] = in_proj_bias[:config.hidden_size]
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.k_proj.weight"] = in_proj_weight[hidden_size : hidden_size * 2, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.k_proj.bias"] = in_proj_bias[hidden_size : hidden_size * 2]
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.v_proj.weight"] = in_proj_weight[-hidden_size :, :]
state_dict[f"model.transformer_module.decoder.layers.{idx}.encoder_attn.v_proj.bias"] = in_proj_bias[-hidden_size :]
# fmt: on
# We will verify our results on an image of cute cats
def prepare_img() -> Image.Image:
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im
@torch.no_grad()
def convert_maskformer_checkpoint(
model_name: str, checkpoint_path: str, pytorch_dump_folder_path: str, push_to_hub: bool = False
):
"""
Copy/paste/tweak model's weights to our MaskFormer structure.
"""
config = get_maskformer_config(model_name)
# load original state_dict
with open(checkpoint_path, "rb") as f:
data = pickle.load(f)
state_dict = data["model"]
# for name, param in state_dict.items():
# print(name, param.shape)
# rename keys
rename_keys = create_rename_keys(config)
for src, dest in rename_keys:
rename_key(state_dict, src, dest)
read_in_swin_q_k_v(state_dict, config.backbone_config)
read_in_decoder_q_k_v(state_dict, config)
# update to torch tensors
for key, value in state_dict.items():
state_dict[key] = torch.from_numpy(value)
# load 🤗 model
model = MaskFormerForInstanceSegmentation(config)
model.eval()
for name, param in model.named_parameters():
print(name, param.shape)
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
assert missing_keys == [
"model.pixel_level_module.encoder.model.layernorm.weight",
"model.pixel_level_module.encoder.model.layernorm.bias",
]
assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}"
# verify results
image = prepare_img()
if "vistas" in model_name:
ignore_index = 65
elif "cityscapes" in model_name:
ignore_index = 65535
else:
ignore_index = 255
reduce_labels = True if "ade" in model_name else False
feature_extractor = MaskFormerFeatureExtractor(ignore_index=ignore_index, reduce_labels=reduce_labels)
inputs = feature_extractor(image, return_tensors="pt")
outputs = model(**inputs)
print("Logits:", outputs.class_queries_logits[0, :3, :3])
if model_name == "maskformer-swin-tiny-ade":
expected_logits = torch.tensor(
[[3.6353, -4.4770, -2.6065], [0.5081, -4.2394, -3.5343], [2.1909, -5.0353, -1.9323]]
)
assert torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_logits, atol=1e-4)
print("Looks ok!")
if pytorch_dump_folder_path is not None:
print(f"Saving model and feature extractor to {pytorch_dump_folder_path}")
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
model.save_pretrained(pytorch_dump_folder_path)
feature_extractor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
print("Pushing model and feature extractor to the hub...")
model.push_to_hub(f"nielsr/{model_name}")
feature_extractor.push_to_hub(f"nielsr/{model_name}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--model_name",
default="maskformer-swin-tiny-ade",
type=str,
help=("Name of the MaskFormer model you'd like to convert",),
)
parser.add_argument(
"--checkpoint_path",
default="/Users/nielsrogge/Documents/MaskFormer_checkpoints/MaskFormer-Swin-tiny-ADE20k/model.pkl",
type=str,
help="Path to the original state dict (.pth file).",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
)
parser.add_argument(
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
)
args = parser.parse_args()
convert_maskformer_checkpoint(
args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub
)
@@ -275,7 +275,6 @@ class RegNetEncoder(nn.Module):
        return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states)
-# Copied from transformers.models.resnet.modeling_resnet.ResNetPreTrainedModel with ResNet->RegNet,resnet->regnet
class RegNetPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -287,6 +286,7 @@ class RegNetPreTrainedModel(PreTrainedModel):
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
+    # Copied from transformers.models.resnet.modeling_resnet.ResNetPreTrainedModel._init_weights
    def _init_weights(self, module):
        if isinstance(module, nn.Conv2d):
            nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
...
@@ -267,7 +267,7 @@ class ResNetPreTrainedModel(PreTrainedModel):
            nn.init.constant_(module.bias, 0)
    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, ResNetModel):
+        if isinstance(module, (ResNetModel, ResNetBackbone)):
            module.gradient_checkpointing = value
@@ -436,7 +436,8 @@ class ResNetBackbone(ResNetPreTrainedModel):
        super().__init__(config)
        self.stage_names = config.stage_names
-        self.resnet = ResNetModel(config)
+        self.embedder = ResNetEmbeddings(config)
+        self.encoder = ResNetEncoder(config)
        self.out_features = config.out_features
@@ -490,7 +491,9 @@ class ResNetBackbone(ResNetPreTrainedModel):
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
-        outputs = self.resnet(pixel_values, output_hidden_states=True, return_dict=True)
+        embedding_output = self.embedder(pixel_values)
+        outputs = self.encoder(embedding_output, output_hidden_states=True, return_dict=True)
        hidden_states = outputs.hidden_states
...
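Since `ResNetBackbone` now composes `ResNetEmbeddings` and `ResNetEncoder` directly, rather than wrapping a full `ResNetModel`, MaskFormer's pixel-level module can map the original backbone weights onto the encoder as shown in the rename tables above. As a standalone illustration, a rough sketch of the backbone API (assuming `ResNetBackbone` is exported from the top-level `transformers` namespace and returns its per-stage outputs as `feature_maps`):

```python
import torch

from transformers import ResNetBackbone, ResNetConfig

config = ResNetConfig(out_features=["stage1", "stage2", "stage3", "stage4"])
backbone = ResNetBackbone(config)
backbone.eval()

pixel_values = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    outputs = backbone(pixel_values)

# one feature map per requested stage, at strides 4, 8, 16 and 32
for name, feature_map in zip(config.out_features, outputs.feature_maps):
    print(name, feature_map.shape)
```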
@@ -84,7 +84,7 @@ class SwinConfig(PretrainedConfig):
-        encoder_stride (`int`, `optional`, defaults to 32):
+        encoder_stride (`int`, *optional*, defaults to 32):
            Factor to increase the spatial resolution by in the decoder head for masked image modeling.
    Example:
    ```python
    >>> from transformers import SwinConfig, SwinModel
...
@@ -320,16 +320,16 @@ def prepare_img():
@require_vision
@slow
class MaskFormerModelIntegrationTest(unittest.TestCase):
-    @cached_property
-    def model_checkpoints(self):
-        return "facebook/maskformer-swin-small-coco"
    @cached_property
    def default_feature_extractor(self):
-        return MaskFormerFeatureExtractor.from_pretrained(self.model_checkpoints) if is_vision_available() else None
+        return (
+            MaskFormerFeatureExtractor.from_pretrained("facebook/maskformer-swin-small-coco")
+            if is_vision_available()
+            else None
+        )
    def test_inference_no_head(self):
-        model = MaskFormerModel.from_pretrained(self.model_checkpoints).to(torch_device)
+        model = MaskFormerModel.from_pretrained("facebook/maskformer-swin-small-coco").to(torch_device)
        feature_extractor = self.default_feature_extractor
        image = prepare_img()
        inputs = feature_extractor(image, return_tensors="pt").to(torch_device)
@@ -370,7 +370,11 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
        )
    def test_inference_instance_segmentation_head(self):
-        model = MaskFormerForInstanceSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
+        model = (
+            MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-small-coco")
+            .to(torch_device)
+            .eval()
+        )
        feature_extractor = self.default_feature_extractor
        image = prepare_img()
        inputs = feature_extractor(image, return_tensors="pt").to(torch_device)
@@ -385,7 +389,8 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
        # masks_queries_logits
        masks_queries_logits = outputs.masks_queries_logits
        self.assertEqual(
-            masks_queries_logits.shape, (1, model.config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4)
+            masks_queries_logits.shape,
+            (1, model.config.decoder_config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4),
        )
        expected_slice = [
            [-1.3737124, -1.7724937, -1.9364233],
@@ -396,7 +401,9 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
        self.assertTrue(torch.allclose(masks_queries_logits[0, 0, :3, :3], expected_slice, atol=TOLERANCE))
        # class_queries_logits
        class_queries_logits = outputs.class_queries_logits
-        self.assertEqual(class_queries_logits.shape, (1, model.config.num_queries, model.config.num_labels + 1))
+        self.assertEqual(
+            class_queries_logits.shape, (1, model.config.decoder_config.num_queries, model.config.num_labels + 1)
+        )
        expected_slice = torch.tensor(
            [
                [1.6512e00, -5.2572e00, -3.3519e00],
@@ -406,8 +413,48 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
        ).to(torch_device)
        self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))
+    def test_inference_instance_segmentation_head_resnet_backbone(self):
+        model = (
+            MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-resnet101-coco-stuff")
+            .to(torch_device)
+            .eval()
+        )
+        feature_extractor = self.default_feature_extractor
+        image = prepare_img()
+        inputs = feature_extractor(image, return_tensors="pt").to(torch_device)
+        inputs_shape = inputs["pixel_values"].shape
+        # check size is divisible by 32
+        self.assertTrue((inputs_shape[-1] % 32) == 0 and (inputs_shape[-2] % 32) == 0)
+        # check size
+        self.assertEqual(inputs_shape, (1, 3, 800, 1088))
+        with torch.no_grad():
+            outputs = model(**inputs)
+        # masks_queries_logits
+        masks_queries_logits = outputs.masks_queries_logits
+        self.assertEqual(
+            masks_queries_logits.shape,
+            (1, model.config.decoder_config.num_queries, inputs_shape[-2] // 4, inputs_shape[-1] // 4),
+        )
+        expected_slice = [[-0.9046, -2.6366, -4.6062], [-3.4179, -5.7890, -8.8057], [-4.9179, -7.6560, -10.7711]]
+        expected_slice = torch.tensor(expected_slice).to(torch_device)
+        self.assertTrue(torch.allclose(masks_queries_logits[0, 0, :3, :3], expected_slice, atol=TOLERANCE))
+        # class_queries_logits
+        class_queries_logits = outputs.class_queries_logits
+        self.assertEqual(
+            class_queries_logits.shape, (1, model.config.decoder_config.num_queries, model.config.num_labels + 1)
+        )
+        expected_slice = torch.tensor(
+            [[4.7188, -3.2585, -2.8857], [6.6871, -2.9181, -1.2487], [7.2449, -2.2764, -2.1874]]
+        ).to(torch_device)
+        self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))
    def test_with_segmentation_maps_and_loss(self):
-        model = MaskFormerForInstanceSegmentation.from_pretrained(self.model_checkpoints).to(torch_device).eval()
+        model = (
+            MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-small-coco")
+            .to(torch_device)
+            .eval()
+        )
        feature_extractor = self.default_feature_extractor
        inputs = feature_extractor(
...