Unverified commit 787620e2, authored by NielsRogge, committed by GitHub

[Swin] Add Swin SimMIM checkpoints (#20034)



* Fix Swin

* Remove file

* Update code snippet

* Add copied from to maskformer

* Fix docstring

* Add whole name to replace
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 3936411b
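
Once converted and pushed to the Hub, the SimMIM-pretrained checkpoints can be used for masked image modeling roughly as follows. This is a minimal sketch that mirrors the updated SwinForMaskedImageModeling docstring snippet further down in this diff; the random mask is purely illustrative.

```python
import torch
import requests
from PIL import Image
from transformers import AutoFeatureExtractor, SwinForMaskedImageModeling

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/swin-base-simmim-window6-192")
model = SwinForMaskedImageModeling.from_pretrained("microsoft/swin-base-simmim-window6-192")

num_patches = (model.config.image_size // model.config.patch_size) ** 2
pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values

# Boolean mask of shape (batch_size, num_patches) marking which patches are masked out
bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
loss, reconstructed_pixel_values = outputs.loss, outputs.logits
```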
@@ -334,7 +334,7 @@ class DonutSwinDropPath(nn.Module):
 # Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->DonutSwin
 class DonutSwinSelfAttention(nn.Module):
-    def __init__(self, config, dim, num_heads):
+    def __init__(self, config, dim, num_heads, window_size):
         super().__init__()
         if dim % num_heads != 0:
             raise ValueError(
@@ -344,7 +344,6 @@ class DonutSwinSelfAttention(nn.Module):
         self.num_attention_heads = num_heads
         self.attention_head_size = int(dim / num_heads)
         self.all_head_size = self.num_attention_heads * self.attention_head_size
-        window_size = config.window_size
         self.window_size = (
             window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
         )
@@ -450,9 +449,9 @@ class DonutSwinSelfOutput(nn.Module):
 # Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->DonutSwin
 class DonutSwinAttention(nn.Module):
-    def __init__(self, config, dim, num_heads):
+    def __init__(self, config, dim, num_heads, window_size):
         super().__init__()
-        self.self = DonutSwinSelfAttention(config, dim, num_heads)
+        self.self = DonutSwinSelfAttention(config, dim, num_heads, window_size)
         self.output = DonutSwinSelfOutput(config, dim)
         self.pruned_heads = set()
@@ -526,7 +525,7 @@ class DonutSwinLayer(nn.Module):
         self.input_resolution = input_resolution
         self.set_shift_and_window_size(input_resolution)
         self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
-        self.attention = DonutSwinAttention(config, dim, num_heads)
+        self.attention = DonutSwinAttention(config, dim, num_heads, window_size=self.window_size)
         self.drop_path = DonutSwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
         self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
         self.intermediate = DonutSwinIntermediate(config, dim)
......
@@ -665,7 +665,7 @@ class MaskFormerSwinDropPath(nn.Module):
 # Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->MaskFormerSwin
 class MaskFormerSwinSelfAttention(nn.Module):
-    def __init__(self, config, dim, num_heads):
+    def __init__(self, config, dim, num_heads, window_size):
         super().__init__()
         if dim % num_heads != 0:
             raise ValueError(
@@ -675,7 +675,6 @@ class MaskFormerSwinSelfAttention(nn.Module):
         self.num_attention_heads = num_heads
         self.attention_head_size = int(dim / num_heads)
         self.all_head_size = self.num_attention_heads * self.attention_head_size
-        window_size = config.window_size
         self.window_size = (
             window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
         )
@@ -781,9 +780,9 @@ class MaskFormerSwinSelfOutput(nn.Module):
 # Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->MaskFormerSwin
 class MaskFormerSwinAttention(nn.Module):
-    def __init__(self, config, dim, num_heads):
+    def __init__(self, config, dim, num_heads, window_size):
         super().__init__()
-        self.self = MaskFormerSwinSelfAttention(config, dim, num_heads)
+        self.self = MaskFormerSwinSelfAttention(config, dim, num_heads, window_size)
         self.output = MaskFormerSwinSelfOutput(config, dim)
         self.pruned_heads = set()
@@ -847,7 +846,7 @@ class MaskFormerSwinOutput(nn.Module):
         return hidden_states
-class MaskFormerSwinBlock(nn.Module):
+class MaskFormerSwinLayer(nn.Module):
     def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -855,7 +854,7 @@ class MaskFormerSwinBlock(nn.Module):
         self.window_size = config.window_size
         self.input_resolution = input_resolution
         self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
-        self.attention = MaskFormerSwinAttention(config, dim, num_heads)
+        self.attention = MaskFormerSwinAttention(config, dim, num_heads, self.window_size)
         self.drop_path = (
             MaskFormerSwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
         )
@@ -960,14 +959,15 @@ class MaskFormerSwinBlock(nn.Module):
         return outputs
-class MaskFormerSwinLayer(nn.Module):
+class MaskFormerSwinStage(nn.Module):
+    # Copied from transformers.models.swin.modeling_swin.SwinStage.__init__ with Swin->MaskFormerSwin
     def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
         super().__init__()
         self.config = config
         self.dim = dim
         self.blocks = nn.ModuleList(
             [
-                MaskFormerSwinBlock(
+                MaskFormerSwinLayer(
                     config=config,
                     dim=dim,
                     input_resolution=input_resolution,
@@ -1016,6 +1016,7 @@ class MaskFormerSwinLayer(nn.Module):
 class MaskFormerSwinEncoder(nn.Module):
+    # Copied from transformers.models.swin.modeling_swin.SwinEncoder.__init__ with Swin->MaskFormerSwin
     def __init__(self, config, grid_size):
         super().__init__()
         self.num_layers = len(config.depths)
@@ -1023,7 +1024,7 @@ class MaskFormerSwinEncoder(nn.Module):
         dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
         self.layers = nn.ModuleList(
             [
-                MaskFormerSwinLayer(
+                MaskFormerSwinStage(
                     config=config,
                     dim=int(config.embed_dim * 2**i_layer),
                     input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
......
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Swin SimMIM checkpoints from the original repository.
URL: https://github.com/microsoft/Swin-Transformer/blob/main/MODELHUB.md#simmim-pretrained-swin-v1-models"""
import argparse

import torch
from PIL import Image

import requests
from transformers import SwinConfig, SwinForMaskedImageModeling, ViTFeatureExtractor


def get_swin_config(model_name):
    config = SwinConfig(image_size=192)

    if "base" in model_name:
        window_size = 6
        embed_dim = 128
        depths = (2, 2, 18, 2)
        num_heads = (4, 8, 16, 32)
    elif "large" in model_name:
        window_size = 12
        embed_dim = 192
        depths = (2, 2, 18, 2)
        num_heads = (6, 12, 24, 48)
    else:
        raise ValueError("Model not supported, only supports base and large variants")

    config.window_size = window_size
    config.embed_dim = embed_dim
    config.depths = depths
    config.num_heads = num_heads

    return config


def rename_key(name):
    if "encoder.mask_token" in name:
        name = name.replace("encoder.mask_token", "embeddings.mask_token")
    if "encoder.patch_embed.proj" in name:
        name = name.replace("encoder.patch_embed.proj", "embeddings.patch_embeddings.projection")
    if "encoder.patch_embed.norm" in name:
        name = name.replace("encoder.patch_embed.norm", "embeddings.norm")
    if "attn.proj" in name:
        name = name.replace("attn.proj", "attention.output.dense")
    if "attn" in name:
        name = name.replace("attn", "attention.self")
    if "norm1" in name:
        name = name.replace("norm1", "layernorm_before")
    if "norm2" in name:
        name = name.replace("norm2", "layernorm_after")
    if "mlp.fc1" in name:
        name = name.replace("mlp.fc1", "intermediate.dense")
    if "mlp.fc2" in name:
        name = name.replace("mlp.fc2", "output.dense")

    if name == "encoder.norm.weight":
        name = "layernorm.weight"
    if name == "encoder.norm.bias":
        name = "layernorm.bias"

    if "decoder" in name:
        pass
    else:
        name = "swin." + name

    return name


def convert_state_dict(orig_state_dict, model):
    for key in orig_state_dict.copy().keys():
        val = orig_state_dict.pop(key)

        if "attn_mask" in key:
            pass
        elif "qkv" in key:
            key_split = key.split(".")
            layer_num = int(key_split[2])
            block_num = int(key_split[4])
            dim = model.swin.encoder.layers[layer_num].blocks[block_num].attention.self.all_head_size

            if "weight" in key:
                orig_state_dict[
                    f"swin.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.query.weight"
                ] = val[:dim, :]
                orig_state_dict[f"swin.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.key.weight"] = val[
                    dim : dim * 2, :
                ]
                orig_state_dict[
                    f"swin.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.value.weight"
                ] = val[-dim:, :]
            else:
                orig_state_dict[f"swin.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.query.bias"] = val[
                    :dim
                ]
                orig_state_dict[f"swin.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.key.bias"] = val[
                    dim : dim * 2
                ]
                orig_state_dict[f"swin.encoder.layers.{layer_num}.blocks.{block_num}.attention.self.value.bias"] = val[
                    -dim:
                ]
        else:
            orig_state_dict[rename_key(key)] = val

    return orig_state_dict


def convert_swin_checkpoint(model_name, checkpoint_path, pytorch_dump_folder_path, push_to_hub):
    state_dict = torch.load(checkpoint_path, map_location="cpu")["model"]

    config = get_swin_config(model_name)
    model = SwinForMaskedImageModeling(config)
    model.eval()

    new_state_dict = convert_state_dict(state_dict, model)
    model.load_state_dict(new_state_dict)

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"

    feature_extractor = ViTFeatureExtractor(size={"height": 192, "width": 192})
    image = Image.open(requests.get(url, stream=True).raw)
    inputs = feature_extractor(images=image, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs).logits

    # outputs holds the reconstructed pixel values; print its shape as a sanity check
    print(outputs.shape)
    print("Looks ok!")

    if pytorch_dump_folder_path is not None:
        print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
        model.save_pretrained(pytorch_dump_folder_path)

        print(f"Saving feature extractor to {pytorch_dump_folder_path}")
        feature_extractor.save_pretrained(pytorch_dump_folder_path)

    if push_to_hub:
        print(f"Pushing model and feature extractor for {model_name} to hub")
        model.push_to_hub(f"microsoft/{model_name}")
        feature_extractor.push_to_hub(f"microsoft/{model_name}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--model_name",
        default="swin-base-simmim-window6-192",
        type=str,
        choices=["swin-base-simmim-window6-192", "swin-large-simmim-window12-192"],
        help="Name of the Swin SimMIM model you'd like to convert.",
    )
    parser.add_argument(
        "--checkpoint_path",
        default="/Users/nielsrogge/Documents/SwinSimMIM/simmim_pretrain__swin_base__img192_window6__100ep.pth",
        type=str,
        help="Path to the original PyTorch checkpoint (.pth file).",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
    )
    parser.add_argument(
        "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
    )

    args = parser.parse_args()
    convert_swin_checkpoint(args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path, args.push_to_hub)
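
For reference, the qkv handling in convert_state_dict above boils down to the slicing below: the original SimMIM checkpoint stores query, key and value as one fused projection, which is split into the three separate projections that SwinSelfAttention expects. The tensor sizes here are made up for illustration and are not taken from a real checkpoint.

```python
import torch

dim = 128  # illustrative hidden size for one block
fused_qkv_weight = torch.randn(3 * dim, dim)  # rows stacked as [query; key; value]

query_weight = fused_qkv_weight[:dim, :]
key_weight = fused_qkv_weight[dim : dim * 2, :]
value_weight = fused_qkv_weight[-dim:, :]

# The three slices together reproduce the fused matrix exactly
assert torch.equal(torch.cat([query_weight, key_weight, value_weight], dim=0), fused_qkv_weight)
```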
@@ -405,7 +405,7 @@ class SwinDropPath(nn.Module):
 class SwinSelfAttention(nn.Module):
-    def __init__(self, config, dim, num_heads):
+    def __init__(self, config, dim, num_heads, window_size):
         super().__init__()
         if dim % num_heads != 0:
             raise ValueError(
@@ -415,7 +415,6 @@ class SwinSelfAttention(nn.Module):
         self.num_attention_heads = num_heads
         self.attention_head_size = int(dim / num_heads)
         self.all_head_size = self.num_attention_heads * self.attention_head_size
-        window_size = config.window_size
         self.window_size = (
             window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
         )
@@ -519,9 +518,9 @@ class SwinSelfOutput(nn.Module):
 class SwinAttention(nn.Module):
-    def __init__(self, config, dim, num_heads):
+    def __init__(self, config, dim, num_heads, window_size):
         super().__init__()
-        self.self = SwinSelfAttention(config, dim, num_heads)
+        self.self = SwinSelfAttention(config, dim, num_heads, window_size)
         self.output = SwinSelfOutput(config, dim)
         self.pruned_heads = set()
@@ -592,7 +591,7 @@ class SwinLayer(nn.Module):
         self.input_resolution = input_resolution
         self.set_shift_and_window_size(input_resolution)
         self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
-        self.attention = SwinAttention(config, dim, num_heads)
+        self.attention = SwinAttention(config, dim, num_heads, window_size=self.window_size)
         self.drop_path = SwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
         self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
         self.intermediate = SwinIntermediate(config, dim)
@@ -1062,8 +1061,8 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
         >>> image = Image.open(requests.get(url, stream=True).raw)
-        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
-        >>> model = SwinForMaskedImageModeling.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
+        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/swin-base-simmim-window6-192")
+        >>> model = SwinForMaskedImageModeling.from_pretrained("microsoft/swin-base-simmim-window6-192")
         >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
         >>> pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
......
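
The modeling change above makes window_size an explicit constructor argument of the attention modules instead of a value read from config.window_size, presumably so that each layer passes along the window size it actually uses (SwinLayer.set_shift_and_window_size can reduce it when the input resolution is smaller than the configured window). A minimal sketch of the new signature, using assumed tiny-sized values for dim and num_heads:

```python
from transformers import SwinConfig
from transformers.models.swin.modeling_swin import SwinSelfAttention

config = SwinConfig()  # defaults include embed_dim=96 and window_size=7

# window_size is now passed in explicitly and may differ from config.window_size
attention = SwinSelfAttention(config, dim=96, num_heads=3, window_size=6)

# The relative position bias table is sized from the window size that was passed in:
# ((2 * 6 - 1) ** 2, num_heads) = (121, 3)
print(attention.relative_position_bias_table.shape)
```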
@@ -1099,7 +1099,7 @@ class Swinv2Model(Swinv2PreTrainedModel):
     """,
     SWINV2_START_DOCSTRING,
 )
-# Copied from transformers.models.swin.modeling_swin.SwinForMaskedImageModeling with SWIN->SWINV2,Swin->Swinv2,swin->swinv2,224->256,window7->window8
+# Copied from transformers.models.swin.modeling_swin.SwinForMaskedImageModeling with swin->swinv2, base-simmim-window6-192->tiny-patch4-window8-256, SWIN->SWINV2,Swin->Swinv2, 224->256
 class Swinv2ForMaskedImageModeling(Swinv2PreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
......