Unverified commit 582d085b authored by Sayak Paul, committed by GitHub

Add expected output to the sample code for `ViTMSNForImageClassification` (#19183)

* chore: add expected output to the sample code.

* add: imagenet-1k labels to the model config.

* chore: apply code formatting.

* chore: change the expected output.
parent 368b649a
@@ -15,11 +15,13 @@
 """Convert ViT MSN checkpoints from the original repository: https://github.com/facebookresearch/msn"""
 import argparse
+import json

 import torch
 from PIL import Image
 import requests

+from huggingface_hub import hf_hub_download

 from transformers import ViTFeatureExtractor, ViTMSNConfig, ViTMSNModel
 from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@@ -147,6 +149,13 @@ def convert_vit_msn_checkpoint(checkpoint_url, pytorch_dump_folder_path):
     config = ViTMSNConfig()
     config.num_labels = 1000

+    repo_id = "datasets/huggingface/label-files"
+    filename = "imagenet-1k-id2label.json"
+    id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+    id2label = {int(k): v for k, v in id2label.items()}
+    config.id2label = id2label
+    config.label2id = {v: k for k, v in id2label.items()}
+
     if "s16" in checkpoint_url:
         config.hidden_size = 384
         config.intermediate_size = 1536
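The block added to the conversion script can be exercised on its own. Below is a minimal sketch of the same idea; it assumes the `huggingface/label-files` dataset repo referenced in the diff, and uses `repo_type="dataset"` (recent `huggingface_hub` versions express the `datasets/` prefix this way):

```python
import json

from huggingface_hub import hf_hub_download

from transformers import ViTMSNConfig

# Fetch the ImageNet-1k id -> label mapping used by the conversion script.
# The diff addresses the repo as "datasets/huggingface/label-files";
# here the same repo is addressed via repo_type="dataset".
label_file = hf_hub_download(
    repo_id="huggingface/label-files",
    filename="imagenet-1k-id2label.json",
    repo_type="dataset",
)
with open(label_file, "r") as f:
    id2label = {int(k): v for k, v in json.load(f).items()}

# Attach both directions of the mapping to the model config so that
# model.config.id2label[predicted_label] resolves to a class name.
config = ViTMSNConfig()
config.num_labels = 1000
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
```

Storing the mapping in the config is what makes the new expected-output line in the docstring possible: without `id2label`, the sample code could only print an integer class index.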
@@ -632,6 +632,8 @@ class ViTMSNForImageClassification(ViTMSNPreTrainedModel):
 >>> from PIL import Image
 >>> import requests

+>>> torch.manual_seed(2)
 >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
 >>> image = Image.open(requests.get(url, stream=True).raw)
@@ -644,6 +646,7 @@ class ViTMSNForImageClassification(ViTMSNPreTrainedModel):
 >>> # model predicts one of the 1000 ImageNet classes
 >>> predicted_label = logits.argmax(-1).item()
 >>> print(model.config.id2label[predicted_label])
+Kerry blue terrier
 ```"""
 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
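For context, the updated docstring corresponds to an end-to-end run along these lines. The `facebook/vit-msn-small` checkpoint name is an assumption (the line that loads the model sits outside the shown hunks), and the fixed seed matters because the MSN checkpoints ship without a trained classification head, so the head is randomly initialized:

```python
import torch
import requests
from PIL import Image

from transformers import ViTFeatureExtractor, ViTMSNForImageClassification

# Seed so the randomly initialized classification head gives a reproducible prediction.
torch.manual_seed(2)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Checkpoint name assumed for illustration; it does not appear in the shown diff lines.
checkpoint = "facebook/vit-msn-small"
feature_extractor = ViTFeatureExtractor.from_pretrained(checkpoint)
model = ViTMSNForImageClassification.from_pretrained(checkpoint)

inputs = feature_extractor(images=image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# The model predicts one of the 1000 ImageNet classes.
predicted_label = logits.argmax(-1).item()
print(model.config.id2label[predicted_label])  # expected output in the docstring: Kerry blue terrier
```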