Unverified Commit 2629c8f3 authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

[DINOv2] Convert more checkpoints (#26177)

* Convert checkpoints

* Update doc test

* Address comment
parent 897a826d
...@@ -508,7 +508,7 @@ ...@@ -508,7 +508,7 @@
- local: model_doc/dinat - local: model_doc/dinat
title: DiNAT title: DiNAT
- local: model_doc/dinov2 - local: model_doc/dinov2
title: DINO V2 title: DINOV2
- local: model_doc/dit - local: model_doc/dit
title: DiT title: DiT
- local: model_doc/dpt - local: model_doc/dpt
......
...@@ -19,14 +19,17 @@ URL: https://github.com/facebookresearch/dinov2/tree/main ...@@ -19,14 +19,17 @@ URL: https://github.com/facebookresearch/dinov2/tree/main
import argparse import argparse
import json
from pathlib import Path from pathlib import Path
import requests import requests
import torch import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image from PIL import Image
from torchvision import transforms from torchvision import transforms
from transformers import BitImageProcessor, Dinov2Config, Dinov2Model from transformers import BitImageProcessor, Dinov2Config, Dinov2ForImageClassification, Dinov2Model
from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, PILImageResampling from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, PILImageResampling
from transformers.utils import logging from transformers.utils import logging
...@@ -35,7 +38,7 @@ logging.set_verbosity_info() ...@@ -35,7 +38,7 @@ logging.set_verbosity_info()
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
def get_dinov2_config(model_name): def get_dinov2_config(model_name, image_classifier=False):
config = Dinov2Config(image_size=518, patch_size=14) config = Dinov2Config(image_size=518, patch_size=14)
# size of the architecture # size of the architecture
...@@ -56,6 +59,13 @@ def get_dinov2_config(model_name): ...@@ -56,6 +59,13 @@ def get_dinov2_config(model_name):
else: else:
raise ValueError("Model not supported") raise ValueError("Model not supported")
if image_classifier:
repo_id = "huggingface/label-files"
filename = "imagenet-1k-id2label.json"
config.num_labels = 1000
config.id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
config.id2label = {int(k): v for k, v in config.id2label.items()}
return config return config
...@@ -140,10 +150,11 @@ def convert_dinov2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub= ...@@ -140,10 +150,11 @@ def convert_dinov2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=
""" """
# define default Dinov2 configuration # define default Dinov2 configuration
config = get_dinov2_config(model_name) image_classifier = "1layer" in model_name
config = get_dinov2_config(model_name, image_classifier=image_classifier)
# load original model from torch hub # load original model from torch hub
original_model = torch.hub.load("facebookresearch/dinov2", model_name) original_model = torch.hub.load("facebookresearch/dinov2", model_name.replace("_1layer", ""))
original_model.eval() original_model.eval()
# load state_dict of original model, remove and rename some keys # load state_dict of original model, remove and rename some keys
...@@ -162,7 +173,21 @@ def convert_dinov2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub= ...@@ -162,7 +173,21 @@ def convert_dinov2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=
state_dict[key] = val state_dict[key] = val
# load HuggingFace model # load HuggingFace model
model = Dinov2Model(config, add_pooling_layer=False).eval() if image_classifier:
model = Dinov2ForImageClassification(config).eval()
model.dinov2.load_state_dict(state_dict)
model_name_to_classifier_dict_url = {
"dinov2_vits14_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_linear_head.pth",
"dinov2_vitb14_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_linear_head.pth",
"dinov2_vitl14_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_linear_head.pth",
"dinov2_vitg14_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_linear_head.pth",
}
url = model_name_to_classifier_dict_url[model_name]
classifier_state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
model.classifier.weight = nn.Parameter(classifier_state_dict["weight"])
model.classifier.bias = nn.Parameter(classifier_state_dict["bias"])
else:
model = Dinov2Model(config).eval()
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
# load image # load image
...@@ -195,10 +220,15 @@ def convert_dinov2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub= ...@@ -195,10 +220,15 @@ def convert_dinov2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=
assert torch.allclose(original_pixel_values, pixel_values) assert torch.allclose(original_pixel_values, pixel_values)
with torch.no_grad(): with torch.no_grad():
outputs = model(pixel_values) outputs = model(pixel_values, output_hidden_states=True)
original_outputs = original_model(pixel_values) original_outputs = original_model(pixel_values)
# assert values # assert values
if image_classifier:
print("Predicted class:")
class_idx = outputs.logits.argmax(-1).item()
print(model.config.id2label[class_idx])
else:
assert outputs.last_hidden_state[:, 0].shape == original_outputs.shape assert outputs.last_hidden_state[:, 0].shape == original_outputs.shape
assert torch.allclose(outputs.last_hidden_state[:, 0], original_outputs, atol=1e-3) assert torch.allclose(outputs.last_hidden_state[:, 0], original_outputs, atol=1e-3)
print("Looks ok!") print("Looks ok!")
...@@ -216,6 +246,10 @@ def convert_dinov2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub= ...@@ -216,6 +246,10 @@ def convert_dinov2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=
"dinov2_vitb14": "dinov2-base", "dinov2_vitb14": "dinov2-base",
"dinov2_vitl14": "dinov2-large", "dinov2_vitl14": "dinov2-large",
"dinov2_vitg14": "dinov2-giant", "dinov2_vitg14": "dinov2-giant",
"dinov2_vits14_1layer": "dinov2-small-imagenet1k-1-layer",
"dinov2_vitb14_1layer": "dinov2-base-imagenet1k-1-layer",
"dinov2_vitl14_1layer": "dinov2-large-imagenet1k-1-layer",
"dinov2_vitg14_1layer": "dinov2-giant-imagenet1k-1-layer",
} }
name = model_name_to_hf_name[model_name] name = model_name_to_hf_name[model_name]
...@@ -230,7 +264,16 @@ if __name__ == "__main__": ...@@ -230,7 +264,16 @@ if __name__ == "__main__":
"--model_name", "--model_name",
default="dinov2_vitb14", default="dinov2_vitb14",
type=str, type=str,
choices=["dinov2_vits14", "dinov2_vitb14", "dinov2_vitl14", "dinov2_vitg14"], choices=[
"dinov2_vits14",
"dinov2_vitb14",
"dinov2_vitl14",
"dinov2_vitg14",
"dinov2_vits14_1layer",
"dinov2_vitb14_1layer",
"dinov2_vitl14_1layer",
"dinov2_vitg14_1layer",
],
help="Name of the model you'd like to convert.", help="Name of the model you'd like to convert.",
) )
parser.add_argument( parser.add_argument(
......
...@@ -54,7 +54,8 @@ _CHECKPOINT_FOR_DOC = "facebook/dinov2-base" ...@@ -54,7 +54,8 @@ _CHECKPOINT_FOR_DOC = "facebook/dinov2-base"
_EXPECTED_OUTPUT_SHAPE = [1, 257, 768] _EXPECTED_OUTPUT_SHAPE = [1, 257, 768]
# Image classification docstring # Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-base" _IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-small-imagenet1k-1-layer"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = [ DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
...@@ -693,6 +694,7 @@ class Dinov2ForImageClassification(Dinov2PreTrainedModel): ...@@ -693,6 +694,7 @@ class Dinov2ForImageClassification(Dinov2PreTrainedModel):
checkpoint=_IMAGE_CLASS_CHECKPOINT, checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=ImageClassifierOutput, output_type=ImageClassifierOutput,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
) )
def forward( def forward(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment