"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "b5ca8fcd20da3ed6aa562ca926c4c8e2c56fe6a4"
Unverified Commit 0759f251 authored by NielsRogge, committed by GitHub

Add DINO conversion script (#13265)

* First commit

* Add interpolation of patch embeddings

* Comment out code

* Fix bug

* Fix another bug

* Fix bug

* Fix another bug

* Remove print statements

* Update conversion script

* Use the official vit implementation

* Add support for converting dino_vits8

* Add DINO to docs of ViT

* Remove assertion

* Add interpolation of position encodings

* Fix bug

* Add align_corners

* Add interpolate_pos_encoding option to forward pass of ViTModel

* Improve interpolate_pos_encoding method

* Add docstring
parent 14e52783
@@ -66,6 +66,23 @@ Tips:
  language modeling). With this approach, the smaller ViT-B/16 model achieves 79.9% accuracy on ImageNet, a significant
  improvement of 2% over training from scratch, but still 4% behind supervised pre-training.

Following the original Vision Transformer, several follow-up works have been released:

- DeiT (Data-efficient Image Transformers) by Facebook AI. DeiT models are distilled vision transformers. Refer to
  :doc:`DeiT's documentation page <deit>`. The authors of DeiT also released more efficiently trained ViT models, which
  you can directly plug into :class:`~transformers.ViTModel` or :class:`~transformers.ViTForImageClassification`. There
  are 4 variants available (in 3 different sizes): `facebook/deit-tiny-patch16-224`, `facebook/deit-small-patch16-224`,
  `facebook/deit-base-patch16-224` and `facebook/deit-base-patch16-384`. Note that one should use
  :class:`~transformers.DeiTFeatureExtractor` in order to prepare images for the model.
- BEiT (BERT pre-training of Image Transformers) by Microsoft Research. BEiT models outperform supervised pre-trained
  vision transformers using a self-supervised method inspired by BERT (masked image modeling) and based on a VQ-VAE.
  Refer to :doc:`BEiT's documentation page <beit>`.
- DINO (a method for self-supervised training of Vision Transformers) by Facebook AI. Vision Transformers trained using
  the DINO method show very interesting properties not seen with convolutional models: they are capable of segmenting
  objects without ever having been trained to do so. DINO checkpoints can be found on the `hub
  <https://huggingface.co/models?other=dino>`__ (a short loading sketch follows this list).
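Editor's note, not part of the PR: loading one of the converted DINO checkpoints could look roughly as follows. The checkpoint name `facebook/dino-vitb16` is an assumption here (any checkpoint from the hub filter above should behave the same); a minimal sketch:

import requests
from PIL import Image
from transformers import ViTFeatureExtractor, ViTModel

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# assumed checkpoint name; DINO backbones are saved without pooler weights,
# so a warning about newly initialized pooler weights may appear
feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/dino-vitb16")
model = ViTModel.from_pretrained("facebook/dino-vitb16")

inputs = feature_extractor(images=image, return_tensors="pt")
outputs = model(**inputs)
last_hidden_states = outputs.last_hidden_state  # (batch_size, num_patches + 1, hidden_size)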
This model was contributed by `nielsr <https://huggingface.co/nielsr>`__. The original code (written in JAX) can be
found `here <https://github.com/google-research/vision_transformer>`__.
@@ -93,7 +93,6 @@ class DeiTEmbeddings(nn.Module):
        return embeddings


# Copied from transformers.models.vit.modeling_vit.PatchEmbeddings
class PatchEmbeddings(nn.Module):
    """
    Image to Patch Embedding.
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ViT checkpoints trained with the DINO method."""
import argparse
import json
from pathlib import Path
import torch
from PIL import Image
import requests
from huggingface_hub import cached_download, hf_hub_url
from transformers import ViTConfig, ViTFeatureExtractor, ViTForImageClassification, ViTModel
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
# here we list all keys to be renamed (original name on the left, our name on the right)
def create_rename_keys(config, base_model=False):
    rename_keys = []
    for i in range(config.num_hidden_layers):
        # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
        rename_keys.append((f"blocks.{i}.norm1.weight", f"vit.encoder.layer.{i}.layernorm_before.weight"))
        rename_keys.append((f"blocks.{i}.norm1.bias", f"vit.encoder.layer.{i}.layernorm_before.bias"))
        rename_keys.append((f"blocks.{i}.attn.proj.weight", f"vit.encoder.layer.{i}.attention.output.dense.weight"))
        rename_keys.append((f"blocks.{i}.attn.proj.bias", f"vit.encoder.layer.{i}.attention.output.dense.bias"))
        rename_keys.append((f"blocks.{i}.norm2.weight", f"vit.encoder.layer.{i}.layernorm_after.weight"))
        rename_keys.append((f"blocks.{i}.norm2.bias", f"vit.encoder.layer.{i}.layernorm_after.bias"))
        rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"vit.encoder.layer.{i}.intermediate.dense.weight"))
        rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"vit.encoder.layer.{i}.intermediate.dense.bias"))
        rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"vit.encoder.layer.{i}.output.dense.weight"))
        rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"vit.encoder.layer.{i}.output.dense.bias"))

    # projection layer + position embeddings
    rename_keys.extend(
        [
            ("cls_token", "vit.embeddings.cls_token"),
            ("patch_embed.proj.weight", "vit.embeddings.patch_embeddings.projection.weight"),
            ("patch_embed.proj.bias", "vit.embeddings.patch_embeddings.projection.bias"),
            ("pos_embed", "vit.embeddings.position_embeddings"),
        ]
    )

    if base_model:
        # layernorm + pooler
        rename_keys.extend(
            [
                ("norm.weight", "layernorm.weight"),
                ("norm.bias", "layernorm.bias"),
            ]
        )

        # if just the base model, we should remove "vit" from all keys that start with "vit"
        rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("vit") else pair for pair in rename_keys]
    else:
        # layernorm + classification head
        rename_keys.extend(
            [
                ("norm.weight", "vit.layernorm.weight"),
                ("norm.bias", "vit.layernorm.bias"),
                ("head.weight", "classifier.weight"),
                ("head.bias", "classifier.bias"),
            ]
        )

    return rename_keys

# we split up the matrix of each encoder layer into queries, keys and values
def read_in_q_k_v(state_dict, config, base_model=False):
    for i in range(config.num_hidden_layers):
        if base_model:
            prefix = ""
        else:
            prefix = "vit."
        # read in weights + bias of input projection layer (in timm, this is a single matrix + bias)
        in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight")
        in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias")
        # next, add query, keys and values (in that order) to the state dict
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
            : config.hidden_size, :
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
            config.hidden_size : config.hidden_size * 2, :
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[
            config.hidden_size : config.hidden_size * 2
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
            -config.hidden_size :, :
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :]

def remove_classification_head_(state_dict):
    ignore_keys = ["head.weight", "head.bias"]
    for k in ignore_keys:
        state_dict.pop(k, None)


def rename_key(dct, old, new):
    val = dct.pop(old)
    dct[new] = val


# We will verify our results on an image of cute cats
def prepare_img():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    im = Image.open(requests.get(url, stream=True).raw)
    return im

@torch.no_grad()
def convert_vit_checkpoint(model_name, pytorch_dump_folder_path, base_model=True):
    """
    Copy/paste/tweak model's weights to our ViT structure.
    """

    # define default ViT configuration
    config = ViTConfig()
    # patch_size
    if model_name[-1] == "8":
        config.patch_size = 8
    # set labels if required
    if not base_model:
        config.num_labels = 1000
        repo_id = "datasets/huggingface/label-files"
        filename = "imagenet-1k-id2label.json"
        id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
        id2label = {int(k): v for k, v in id2label.items()}
        config.id2label = id2label
        config.label2id = {v: k for k, v in id2label.items()}
    # size of the architecture
    if model_name in ["dino_vits8", "dino_vits16"]:
        config.hidden_size = 384
        config.intermediate_size = 1536
        config.num_hidden_layers = 12
        config.num_attention_heads = 6

    # load original model from torch hub
    original_model = torch.hub.load("facebookresearch/dino:main", model_name)
    original_model.eval()

    # load state_dict of original model, remove and rename some keys
    state_dict = original_model.state_dict()
    if base_model:
        remove_classification_head_(state_dict)
    rename_keys = create_rename_keys(config, base_model=base_model)
    for src, dest in rename_keys:
        rename_key(state_dict, src, dest)
    read_in_q_k_v(state_dict, config, base_model)

    # load HuggingFace model
    if base_model:
        model = ViTModel(config, add_pooling_layer=False).eval()
    else:
        model = ViTForImageClassification(config).eval()
    model.load_state_dict(state_dict)

    # Check outputs on an image, prepared by ViTFeatureExtractor
    feature_extractor = ViTFeatureExtractor()
    encoding = feature_extractor(images=prepare_img(), return_tensors="pt")
    pixel_values = encoding["pixel_values"]
    outputs = model(pixel_values)

    if base_model:
        final_hidden_state_cls_token = original_model(pixel_values)
        assert torch.allclose(final_hidden_state_cls_token, outputs.last_hidden_state[:, 0, :], atol=1e-1)
    else:
        logits = original_model(pixel_values)
        assert logits.shape == outputs.logits.shape
        assert torch.allclose(logits, outputs.logits, atol=1e-3)

    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)
    print(f"Saving feature extractor to {pytorch_dump_folder_path}")
    feature_extractor.save_pretrained(pytorch_dump_folder_path)

if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--model_name",
default="dino_vitb16",
type=str,
help="Name of the model trained with DINO you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
)
parser.add_argument(
"--base_model",
action="store_true",
help="Whether to only convert the base model (no projection head weights).",
)
parser.set_defaults(base_model=True)
args = parser.parse_args()
convert_vit_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.base_model)
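Editor's note, not part of the script: a rough usage sketch of the conversion entry point, expressed as a direct call to the function above instead of the argparse CLI. The model name is one of the torch hub names the script checks for, and the output directory is just an example:

# assumed example: convert the ViT-S/8 DINO backbone and save it locally
convert_vit_checkpoint(
    model_name="dino_vits8",
    pytorch_dump_folder_path="./dino_vits8_hf",  # example output directory
    base_model=True,  # keep only the backbone, as in the script's default
)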
@@ -74,15 +74,55 @@ class ViTEmbeddings(nn.Module):
        num_patches = self.patch_embeddings.num_patches
        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.config = config
    def interpolate_pos_encoding(self, embeddings, height, width):
        """
        Interpolates the pre-trained position encodings so that the model can be used on higher-resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """

        npatch = embeddings.shape[1] - 1
        N = self.position_embeddings.shape[1] - 1
        if npatch == N and height == width:
            return self.position_embeddings
        class_pos_embed = self.position_embeddings[:, 0]
        patch_pos_embed = self.position_embeddings[:, 1:]
        dim = embeddings.shape[-1]
        h0 = height // self.config.patch_size
        w0 = width // self.config.patch_size
        # we add a small number to avoid floating point errors in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        h0, w0 = h0 + 0.1, w0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(h0 / math.sqrt(N), w0 / math.sqrt(N)),
            mode="bicubic",
            align_corners=False,
        )
        assert int(h0) == patch_pos_embed.shape[-1] and int(w0) == patch_pos_embed.shape[-2]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
    def forward(self, pixel_values, interpolate_pos_encoding=False):
        batch_size, num_channels, height, width = pixel_values.shape
        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        # add the [CLS] token to the embedded patch tokens
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # add positional encoding to each token
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embeddings

        embeddings = self.dropout(embeddings)
        return embeddings
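Editor's note, not part of the diff: to make the interpolation concrete with assumed numbers, a ViT pre-trained at 224x224 with patch size 16 has N = 196 patch position embeddings on a 14x14 grid; for a 480x480 input, h0 = w0 = 480 // 16 = 30, so the grid is bicubically resized to 30x30 and flattened back into 900 position vectors, while the [CLS] position embedding is left untouched. A standalone sketch of the same shape math:

import math
import torch
import torch.nn as nn

hidden_size, patch_size = 768, 16  # ViT-B/16 style numbers (assumed)
pos_embed = torch.randn(1, 14 * 14 + 1, hidden_size)  # 196 patch positions + 1 [CLS]

height = width = 480
h0 = w0 = height // patch_size + 0.1  # small offset, as in the method above
N = pos_embed.shape[1] - 1

patch_pos = pos_embed[:, 1:].reshape(1, 14, 14, hidden_size).permute(0, 3, 1, 2)
patch_pos = nn.functional.interpolate(
    patch_pos, scale_factor=(h0 / math.sqrt(N), w0 / math.sqrt(N)), mode="bicubic", align_corners=False
)
print(patch_pos.shape)  # torch.Size([1, 768, 30, 30]) -> 900 patch positions after flattening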
@@ -105,13 +145,13 @@ class PatchEmbeddings(nn.Module):
        self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values, interpolate_pos_encoding=False):
        batch_size, num_channels, height, width = pixel_values.shape
        if not interpolate_pos_encoding:
            if height != self.image_size[0] or width != self.image_size[1]:
                raise ValueError(
                    f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
                )
        x = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return x
@@ -419,6 +459,8 @@ VIT_INPUTS_DOCSTRING = r"""
        output_hidden_states (:obj:`bool`, `optional`):
            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
            more detail.
        interpolate_pos_encoding (:obj:`bool`, `optional`):
            Whether to interpolate the pre-trained position encodings.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""
@@ -460,6 +502,7 @@ class ViTModel(ViTPreTrainedModel):
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        interpolate_pos_encoding=None,
        return_dict=None,
    ):
        r"""
@@ -497,7 +540,7 @@ class ViTModel(ViTPreTrainedModel):
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        encoder_outputs = self.encoder(
            embedding_output,
@@ -564,6 +607,7 @@ class ViTForImageClassification(ViTPreTrainedModel):
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        interpolate_pos_encoding=None,
        return_dict=None,
    ):
        r"""
@@ -600,6 +644,7 @@ class ViTForImageClassification(ViTPreTrainedModel):
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )
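Editor's note, not part of the diff: a minimal sketch of what the new flag enables, namely running a ViT pre-trained at 224x224 on larger inputs. The checkpoint name and the 480x480 input are assumptions:

import torch
from transformers import ViTModel

# assumed checkpoint pre-trained at 224x224 with patch size 16
model = ViTModel.from_pretrained("facebook/dino-vitb16")

pixel_values = torch.randn(1, 3, 480, 480)  # stand-in for a real preprocessed image
outputs = model(pixel_values, interpolate_pos_encoding=True)
print(outputs.last_hidden_state.shape)  # (1, 1 + 900, hidden_size): 30x30 patches plus the [CLS] token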