"tests/vscode:/vscode.git/clone" did not exist on "b469ebc5cfd34846ab02e9038e60bc73b4c74a3a"
Unverified commit 57882177, authored by NielsRogge, committed by GitHub

Add SimMIM (#15586)



* Add first draft

* Make model importable

* Make SwinForMaskedImageModeling importable

* Fix imports

* Add missing inits

* Add support for Swin

* Fix bug

* Fix bug

* Fix another bug

* Fix Swin MIM implementation

* Fix default encoder stride

* Fix Swin

* Add print statements for debugging

* Add image_size data argument

* Fix Swin

* Fix image_size

* Add print statements for debugging

* Fix print statement

* Remove print statements

* Improve reshaping of bool_masked_pos

* Add support for DeiT, fix tests

* Improve docstrings

* Apply new black version

* Improve script

* Fix bug

* Improve README

* Apply suggestions from code review

* Remove DS_Store and add to gitignore

* Apply suggestions from code review + fix BEiT Flax

* Revert BEiT changes

* Improve README

* Fix code quality

* Improve README
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MBP.localdomain>
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 426b9623
@@ -24,8 +24,13 @@ from torch import nn
 from torch.nn import CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
-from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
-from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput
+from ...file_utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    replace_return_docstrings,
+)
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput
 from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import logging
 from .configuration_vit import ViTConfig
@@ -67,14 +72,15 @@ def to_2tuple(x):
 
 class ViTEmbeddings(nn.Module):
     """
-    Construct the CLS token, position and patch embeddings.
+    Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
     """
 
-    def __init__(self, config):
+    def __init__(self, config, use_mask_token=False):
         super().__init__()
 
         self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
+        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
         self.patch_embeddings = PatchEmbeddings(
             image_size=config.image_size,
             patch_size=config.patch_size,
@@ -117,10 +123,17 @@ class ViTEmbeddings(nn.Module):
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
         return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
 
-    def forward(self, pixel_values, interpolate_pos_encoding=False):
+    def forward(self, pixel_values, bool_masked_pos=None, interpolate_pos_encoding=False):
         batch_size, num_channels, height, width = pixel_values.shape
         embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
 
+        batch_size, seq_len, _ = embeddings.size()
+
+        if bool_masked_pos is not None:
+            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
+            # replace the masked visual tokens by mask_tokens
+            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
+            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
+
         # add the [CLS] token to the embedded patch tokens
         cls_tokens = self.cls_token.expand(batch_size, -1, -1)
         embeddings = torch.cat((cls_tokens, embeddings), dim=1)
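An aside on the masking branch just added to `ViTEmbeddings.forward`: the arithmetic replaces each masked patch embedding with the learned mask token. A self-contained toy rerun of the same lines (tensor sizes invented for illustration):

```python
import torch

batch_size, seq_len, hidden_size = 1, 4, 8
embeddings = torch.randn(batch_size, seq_len, hidden_size)
mask_token = torch.zeros(1, 1, hidden_size)  # an nn.Parameter in the real model

bool_masked_pos = torch.tensor([[0, 1, 0, 1]]).bool()  # mask patches 1 and 3

mask_tokens = mask_token.expand(batch_size, seq_len, -1)
mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)  # (1, 4, 1), values 0.0/1.0
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

# masked positions now hold the mask token; the rest are untouched
assert torch.equal(embeddings[0, 1], mask_token[0, 0])
```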
@@ -422,10 +435,6 @@ class ViTPreTrainedModel(PreTrainedModel):
             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
             if module.bias is not None:
                 module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
@@ -476,11 +485,11 @@ VIT_INPUTS_DOCSTRING = r"""
     VIT_START_DOCSTRING,
 )
 class ViTModel(ViTPreTrainedModel):
-    def __init__(self, config, add_pooling_layer=True):
+    def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
         super().__init__(config)
         self.config = config
 
-        self.embeddings = ViTEmbeddings(config)
+        self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token)
         self.encoder = ViTEncoder(config)
 
         self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -512,6 +521,7 @@ class ViTModel(ViTPreTrainedModel):
     def forward(
         self,
         pixel_values=None,
+        bool_masked_pos=None,
         head_mask=None,
         output_attentions=None,
         output_hidden_states=None,
@@ -534,7 +544,9 @@ class ViTModel(ViTPreTrainedModel):
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
 
-        embedding_output = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+        embedding_output = self.embeddings(
+            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
+        )
 
         encoder_outputs = self.encoder(
             embedding_output,
@@ -573,6 +585,107 @@ class ViTPooler(nn.Module):
         return pooled_output
 
 
+@add_start_docstrings(
+    "ViT Model with a decoder on top for masked image modeling, as proposed in `SimMIM <https://arxiv.org/abs/2111.09886>`__.",
+    VIT_START_DOCSTRING,
+)
+class ViTForMaskedImageModeling(ViTPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.vit = ViTModel(config, add_pooling_layer=False, use_mask_token=True)
+
+        self.decoder = nn.Sequential(
+            nn.Conv2d(in_channels=config.hidden_size, out_channels=config.encoder_stride**2 * 3, kernel_size=1),
+            nn.PixelShuffle(config.encoder_stride),
+        )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        pixel_values=None,
+        bool_masked_pos=None,
+        head_mask=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        interpolate_pos_encoding=None,
+        return_dict=None,
+    ):
+        r"""
+        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
+            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> import torch
+        >>> from transformers import ViTFeatureExtractor, ViTForMaskedImageModeling
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
+        >>> model = ViTForMaskedImageModeling.from_pretrained("google/vit-base-patch16-224-in21k")
+
+        >>> pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
+        >>> # create a random boolean mask, one flag per patch
+        >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
+        >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
+
+        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
+        >>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits
+        ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.vit(
+            pixel_values,
+            bool_masked_pos=bool_masked_pos,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+
+        # Reshape to (batch_size, num_channels, height, width)
+        sequence_output = sequence_output[:, 1:]
+        batch_size, sequence_length, num_channels = sequence_output.shape
+        height = width = int(sequence_length**0.5)
+        sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
+
+        # Reconstruct pixel values
+        reconstructed_pixel_values = self.decoder(sequence_output)
+
+        masked_im_loss = None
+        if bool_masked_pos is not None:
+            size = self.config.image_size // self.config.patch_size
+            bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
+            mask = (
+                bool_masked_pos.repeat_interleave(self.config.patch_size, 1)
+                .repeat_interleave(self.config.patch_size, 2)
+                .unsqueeze(1)
+                .contiguous()
+            )
+            reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
+            masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels
+
+        if not return_dict:
+            output = (reconstructed_pixel_values,) + outputs[2:]
+            return ((masked_im_loss,) + output) if masked_im_loss is not None else output
+
+        return MaskedLMOutput(
+            loss=masked_im_loss,
+            logits=reconstructed_pixel_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
 @add_start_docstrings(
     """
     ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
......
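Two mechanics in `ViTForMaskedImageModeling` above are easy to trace by hand: the `Conv2d` + `PixelShuffle` decoder maps the `(batch_size, hidden_size, height, width)` token grid back to full-resolution pixels, and the loss mask is upsampled from one flag per patch to one flag per pixel via `repeat_interleave`. A minimal sketch with ViT-Base-like numbers (`hidden_size=768`, `patch_size=16`, `image_size=224`, `encoder_stride=16` — assumed here for illustration):

```python
import torch
from torch import nn

hidden_size, patch_size, image_size = 768, 16, 224
encoder_stride = patch_size  # for plain ViT, the decoder upsamples by the patch size
grid = image_size // patch_size  # 14 patches per side

# decoder as in the commit: 1x1 conv to encoder_stride**2 * 3 channels, then pixel shuffle
decoder = nn.Sequential(
    nn.Conv2d(hidden_size, encoder_stride**2 * 3, kernel_size=1),
    nn.PixelShuffle(encoder_stride),
)
tokens = torch.randn(1, hidden_size, grid, grid)  # (1, 768, 14, 14)
print(decoder(tokens).shape)  # torch.Size([1, 3, 224, 224])

# loss mask: one flag per patch, blown up to one flag per pixel
bool_masked_pos = torch.randint(0, 2, (1, grid * grid))
mask = (
    bool_masked_pos.reshape(-1, grid, grid)
    .repeat_interleave(patch_size, 1)
    .repeat_interleave(patch_size, 2)
    .unsqueeze(1)
)
print(mask.shape)  # torch.Size([1, 1, 224, 224])
```

Dividing the masked L1 sum by `mask.sum()` then averages only over masked pixels, so unmasked regions contribute nothing to the training signal.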
@@ -362,6 +362,9 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = None
 MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = None
 
+MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = None
+
 MODEL_FOR_MASKED_LM_MAPPING = None
@@ -460,6 +463,13 @@ class AutoModelForImageSegmentation(metaclass=DummyObject):
         requires_backends(self, ["torch"])
 
 
+class AutoModelForMaskedImageModeling(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class AutoModelForMaskedLM(metaclass=DummyObject):
     _backends = ["torch"]
@@ -1305,6 +1315,13 @@ class DeiTForImageClassificationWithTeacher(metaclass=DummyObject):
         requires_backends(self, ["torch"])
 
 
+class DeiTForMaskedImageModeling(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class DeiTModel(metaclass=DummyObject):
     _backends = ["torch"]
@@ -3449,6 +3466,13 @@ class SwinForImageClassification(metaclass=DummyObject):
         requires_backends(self, ["torch"])
 
 
+class SwinForMaskedImageModeling(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class SwinModel(metaclass=DummyObject):
     _backends = ["torch"]
@@ -3782,6 +3806,13 @@ class ViTForImageClassification(metaclass=DummyObject):
         requires_backends(self, ["torch"])
 
 
+class ViTForMaskedImageModeling(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class ViTModel(metaclass=DummyObject):
     _backends = ["torch"]
......
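For readers unfamiliar with the dummy classes extended above: they let `from transformers import ViTForMaskedImageModeling` succeed even when PyTorch is not installed, deferring a helpful error to first use. A simplified sketch of the pattern (not the library's exact code):

```python
# Simplified sketch: the real DummyObject/requires_backends live in transformers' utils.
def requires_backends(obj, backends):
    name = getattr(obj, "__name__", obj.__class__.__name__)
    raise ImportError(f"{name} requires the following backends: {', '.join(backends)}")


class DummyObject(type):
    """Metaclass: any public attribute access on the class raises the backend error."""

    def __getattribute__(cls, key):
        if key.startswith("_"):
            return super().__getattribute__(key)
        requires_backends(cls, cls._backends)


class ViTForMaskedImageModeling(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
```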
@@ -72,6 +72,7 @@ if is_torch_available():
         MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
         MODEL_FOR_CAUSAL_LM_MAPPING,
         MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
+        MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
         MODEL_FOR_MASKED_LM_MAPPING,
         MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
         MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
@@ -165,6 +166,11 @@ class ModelTesterMixin:
                 inputs_dict["labels"] = torch.zeros(
                     (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
                 )
+            elif model_class in get_values(MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING):
+                num_patches = self.model_tester.image_size // self.model_tester.patch_size
+                inputs_dict["bool_masked_pos"] = torch.zeros(
+                    (self.model_tester.batch_size, num_patches**2), dtype=torch.long, device=torch_device
+                )
 
         return inputs_dict
 
     def test_save_load(self):
......
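The all-zeros `bool_masked_pos` used by the common tester above masks nothing, which is enough to exercise the forward pass and loss code path. For actual pre-training one would sample a random mask per image; a hypothetical minimal generator in the spirit of SimMIM (function name and mask ratio are illustrative, not from this commit):

```python
import torch


def random_bool_masked_pos(batch_size, num_patches, mask_ratio=0.6):
    """Sample a boolean mask per image with roughly mask_ratio of patches masked."""
    num_masked = int(mask_ratio * num_patches)
    scores = torch.rand(batch_size, num_patches)
    # mask the num_masked patches that drew the highest random scores
    threshold = scores.topk(num_masked, dim=1).values[:, -1:]
    return scores >= threshold


bool_masked_pos = random_bool_masked_pos(batch_size=2, num_patches=196)
print(bool_masked_pos.shape)  # torch.Size([2, 196]), ~117 True per row
```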
@@ -35,6 +35,7 @@ if is_torch_available():
         MODEL_MAPPING,
         DeiTForImageClassification,
         DeiTForImageClassificationWithTeacher,
+        DeiTForMaskedImageModeling,
         DeiTModel,
     )
     from transformers.models.deit.modeling_deit import DEIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
@@ -67,6 +68,7 @@ class DeiTModelTester:
         initializer_range=0.02,
         num_labels=3,
         scope=None,
+        encoder_stride=2,
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -85,6 +87,7 @@ class DeiTModelTester:
         self.type_sequence_label_size = type_sequence_label_size
         self.initializer_range = initializer_range
         self.scope = scope
+        self.encoder_stride = encoder_stride
 
     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -111,6 +114,7 @@ class DeiTModelTester:
             attention_probs_dropout_prob=self.attention_probs_dropout_prob,
             is_decoder=False,
             initializer_range=self.initializer_range,
+            encoder_stride=self.encoder_stride,
         )
 
     def create_and_check_model(self, config, pixel_values, labels):
@@ -155,6 +159,7 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
         (
             DeiTModel,
             DeiTForImageClassification,
             DeiTForImageClassificationWithTeacher,
+            DeiTForMaskedImageModeling,
         )
         if is_torch_available()
         else ()
......
@@ -31,7 +31,7 @@ if is_torch_available():
     import torch
     from torch import nn
 
-    from transformers import SwinForImageClassification, SwinModel
+    from transformers import SwinForImageClassification, SwinForMaskedImageModeling, SwinModel
     from transformers.models.swin.modeling_swin import SWIN_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
 
 if is_vision_available():
@@ -74,6 +74,7 @@ class SwinModelTester:
         scope=None,
         use_labels=True,
         type_sequence_label_size=10,
+        encoder_stride=2,
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -98,6 +99,7 @@ class SwinModelTester:
         self.scope = scope
         self.use_labels = use_labels
         self.type_sequence_label_size = type_sequence_label_size
+        self.encoder_stride = encoder_stride
 
     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -129,6 +131,7 @@ class SwinModelTester:
             patch_norm=self.patch_norm,
             layer_norm_eps=self.layer_norm_eps,
             initializer_range=self.initializer_range,
+            encoder_stride=self.encoder_stride,
         )
 
     def create_and_check_model(self, config, pixel_values, labels):
@@ -169,6 +172,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
         (
             SwinModel,
             SwinForImageClassification,
+            SwinForMaskedImageModeling,
         )
         if is_torch_available()
         else ()
......
@@ -30,7 +30,7 @@ if is_torch_available():
     import torch
     from torch import nn
 
-    from transformers import ViTForImageClassification, ViTModel
+    from transformers import ViTForImageClassification, ViTForMaskedImageModeling, ViTModel
     from transformers.models.vit.modeling_vit import VIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
@@ -61,6 +61,7 @@ class ViTModelTester:
         initializer_range=0.02,
         num_labels=3,
         scope=None,
+        encoder_stride=2,
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -79,6 +80,7 @@ class ViTModelTester:
         self.type_sequence_label_size = type_sequence_label_size
         self.initializer_range = initializer_range
         self.scope = scope
+        self.encoder_stride = encoder_stride
 
     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -105,6 +107,7 @@ class ViTModelTester:
             attention_probs_dropout_prob=self.attention_probs_dropout_prob,
             is_decoder=False,
             initializer_range=self.initializer_range,
+            encoder_stride=self.encoder_stride,
         )
 
     def create_and_check_model(self, config, pixel_values, labels):
@@ -148,6 +151,7 @@ class ViTModelTest(ModelTesterMixin, unittest.TestCase):
         (
             ViTModel,
             ViTForImageClassification,
+            ViTForMaskedImageModeling,
         )
         if is_torch_available()
        else ()
......
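A note on the `encoder_stride=2` defaults added to the testers above: for ViT and DeiT, the decoder reconstructs `(image_size // patch_size) * encoder_stride` pixels per side, and the masked L1 loss compares that against the input, so the product must equal `image_size`. A quick check with tester-style dimensions (`image_size=30` and `patch_size=2` are assumed from the elided tester defaults, for illustration):

```python
# Decoder output must match the input resolution for the masked L1 loss to broadcast:
# (image_size // patch_size) * encoder_stride == image_size
image_size, patch_size, encoder_stride = 30, 2, 2  # assumed tester-style values
grid = image_size // patch_size  # 15 patches per side
assert grid * encoder_stride == image_size  # 15 * 2 == 30
```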