Unverified Commit 2629c8f3 authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

[DINOv2] Convert more checkpoints (#26177)

* Convert checkpoints

* Update doc test

* Address comment
parent 897a826d
......@@ -508,7 +508,7 @@
- local: model_doc/dinat
title: DiNAT
- local: model_doc/dinov2
title: DINO V2
title: DINOV2
- local: model_doc/dit
title: DiT
- local: model_doc/dpt
......
......@@ -19,14 +19,17 @@ URL: https://github.com/facebookresearch/dinov2/tree/main
import argparse
import json
from pathlib import Path
import requests
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
from transformers import BitImageProcessor, Dinov2Config, Dinov2Model
from transformers import BitImageProcessor, Dinov2Config, Dinov2ForImageClassification, Dinov2Model
from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, PILImageResampling
from transformers.utils import logging
......@@ -35,7 +38,7 @@ logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def get_dinov2_config(model_name):
def get_dinov2_config(model_name, image_classifier=False):
config = Dinov2Config(image_size=518, patch_size=14)
# size of the architecture
......@@ -56,6 +59,13 @@ def get_dinov2_config(model_name):
else:
raise ValueError("Model not supported")
if image_classifier:
repo_id = "huggingface/label-files"
filename = "imagenet-1k-id2label.json"
config.num_labels = 1000
config.id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
config.id2label = {int(k): v for k, v in config.id2label.items()}
return config
......@@ -140,10 +150,11 @@ def convert_dinov2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=
"""
# define default Dinov2 configuration
config = get_dinov2_config(model_name)
image_classifier = "1layer" in model_name
config = get_dinov2_config(model_name, image_classifier=image_classifier)
# load original model from torch hub
original_model = torch.hub.load("facebookresearch/dinov2", model_name)
original_model = torch.hub.load("facebookresearch/dinov2", model_name.replace("_1layer", ""))
original_model.eval()
# load state_dict of original model, remove and rename some keys
......@@ -162,7 +173,21 @@ def convert_dinov2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=
state_dict[key] = val
# load HuggingFace model
model = Dinov2Model(config, add_pooling_layer=False).eval()
if image_classifier:
model = Dinov2ForImageClassification(config).eval()
model.dinov2.load_state_dict(state_dict)
model_name_to_classifier_dict_url = {
"dinov2_vits14_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_linear_head.pth",
"dinov2_vitb14_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_linear_head.pth",
"dinov2_vitl14_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_linear_head.pth",
"dinov2_vitg14_1layer": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_linear_head.pth",
}
url = model_name_to_classifier_dict_url[model_name]
classifier_state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
model.classifier.weight = nn.Parameter(classifier_state_dict["weight"])
model.classifier.bias = nn.Parameter(classifier_state_dict["bias"])
else:
model = Dinov2Model(config).eval()
model.load_state_dict(state_dict)
# load image
......@@ -195,10 +220,15 @@ def convert_dinov2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=
assert torch.allclose(original_pixel_values, pixel_values)
with torch.no_grad():
outputs = model(pixel_values)
outputs = model(pixel_values, output_hidden_states=True)
original_outputs = original_model(pixel_values)
# assert values
if image_classifier:
print("Predicted class:")
class_idx = outputs.logits.argmax(-1).item()
print(model.config.id2label[class_idx])
else:
assert outputs.last_hidden_state[:, 0].shape == original_outputs.shape
assert torch.allclose(outputs.last_hidden_state[:, 0], original_outputs, atol=1e-3)
print("Looks ok!")
......@@ -216,6 +246,10 @@ def convert_dinov2_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub=
"dinov2_vitb14": "dinov2-base",
"dinov2_vitl14": "dinov2-large",
"dinov2_vitg14": "dinov2-giant",
"dinov2_vits14_1layer": "dinov2-small-imagenet1k-1-layer",
"dinov2_vitb14_1layer": "dinov2-base-imagenet1k-1-layer",
"dinov2_vitl14_1layer": "dinov2-large-imagenet1k-1-layer",
"dinov2_vitg14_1layer": "dinov2-giant-imagenet1k-1-layer",
}
name = model_name_to_hf_name[model_name]
......@@ -230,7 +264,16 @@ if __name__ == "__main__":
"--model_name",
default="dinov2_vitb14",
type=str,
choices=["dinov2_vits14", "dinov2_vitb14", "dinov2_vitl14", "dinov2_vitg14"],
choices=[
"dinov2_vits14",
"dinov2_vitb14",
"dinov2_vitl14",
"dinov2_vitg14",
"dinov2_vits14_1layer",
"dinov2_vitb14_1layer",
"dinov2_vitl14_1layer",
"dinov2_vitg14_1layer",
],
help="Name of the model you'd like to convert.",
)
parser.add_argument(
......
......@@ -54,7 +54,8 @@ _CHECKPOINT_FOR_DOC = "facebook/dinov2-base"
_EXPECTED_OUTPUT_SHAPE = [1, 257, 768]
# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-base"
_IMAGE_CLASS_CHECKPOINT = "facebook/dinov2-small-imagenet1k-1-layer"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
DINOV2_PRETRAINED_MODEL_ARCHIVE_LIST = [
......@@ -693,6 +694,7 @@ class Dinov2ForImageClassification(Dinov2PreTrainedModel):
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=ImageClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
def forward(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment