Unverified commit 30677dc7, authored by NielsRogge and committed by GitHub

Add Vision Transformer and ViTFeatureExtractor (#10950)



* Squash all commits into one

* Update ViTFeatureExtractor to use image_utils instead of torchvision

* Remove torchvision and add Pillow

* Small docs improvement

* Address most comments by @sgugger

* Fix tests

* Clean up conversion script

* Pooler first draft

* Fix quality

* Improve conversion script

* Make style and quality

* Make fix-copies

* Minor docs improvements

* Should use fix-copies instead of manual handling

* Revert "Should use fix-copies instead of manual handling"

This reverts commit fd4e591bce4496d41406425c82606a8fdaf8a50b.

* Place ViT in alphabetical order
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent af673222
@@ -80,8 +80,8 @@ jobs:
 - v0.4-{{ checksum "setup.py" }}
 - run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
 - run: pip install --upgrade pip
-- run: pip install .[sklearn,tf-cpu,torch,testing,sentencepiece,speech]
+- run: pip install .[sklearn,tf-cpu,torch,testing,sentencepiece,speech,vision]
-- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
+- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
 - save_cache:
     key: v0.4-{{ checksum "setup.py" }}
     paths:
@@ -110,8 +110,8 @@ jobs:
 - v0.4-{{ checksum "setup.py" }}
 - run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
 - run: pip install --upgrade pip
-- run: pip install .[sklearn,flax,torch,testing,sentencepiece,speech]
+- run: pip install .[sklearn,flax,torch,testing,sentencepiece,speech,vision]
-- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
+- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
 - save_cache:
     key: v0.4-{{ checksum "setup.py" }}
     paths:
@@ -139,8 +139,8 @@ jobs:
 - v0.4-{{ checksum "setup.py" }}
 - run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
 - run: pip install --upgrade pip
-- run: pip install .[sklearn,torch,testing,sentencepiece,speech]
+- run: pip install .[sklearn,torch,testing,sentencepiece,speech,vision]
-- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
+- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
 - save_cache:
     key: v0.4-torch-{{ checksum "setup.py" }}
     paths:
@@ -223,8 +223,8 @@ jobs:
 - v0.4-{{ checksum "setup.py" }}
 - run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
 - run: pip install --upgrade pip
-- run: pip install .[sklearn,torch,testing,sentencepiece,speech]
+- run: pip install .[sklearn,torch,testing,sentencepiece,speech,vision]
-- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
+- run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
 - save_cache:
     key: v0.4-torch-{{ checksum "setup.py" }}
     paths:
......
@@ -234,6 +234,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
 1. **[T5](https://huggingface.co/transformers/model_doc/t5.html)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
 1. **[TAPAS](https://huggingface.co/transformers/model_doc/tapas.html)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos.
 1. **[Transformer-XL](https://huggingface.co/transformers/model_doc/transformerxl.html)** (from Google/CMU) released with the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
+1. **[Vision Transformer (ViT)](https://huggingface.co/transformers/model_doc/vit.html)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
 1. **[Wav2Vec2](https://huggingface.co/transformers/model_doc/wav2vec2.html)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
 1. **[XLM](https://huggingface.co/transformers/model_doc/xlm.html)** (from Facebook) released together with the paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) by Guillaume Lample and Alexis Conneau.
 1. **[XLM-ProphetNet](https://huggingface.co/transformers/model_doc/xlmprophetnet.html)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
......
@@ -210,22 +210,26 @@ and conversion utilities for the following models:
 43. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
     Attentive Language Models Beyond a Fixed-Length Context <https://arxiv.org/abs/1901.02860>`__ by Zihang Dai*,
     Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
-44. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
+44. :doc:`Vision Transformer (ViT) <model_doc/vit>` (from Google AI) released with the paper `An Image is Worth 16x16
+    Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`__ by Alexey Dosovitskiy,
+    Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias
+    Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
+45. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
     Self-Supervised Learning of Speech Representations <https://arxiv.org/abs/2006.11477>`__ by Alexei Baevski, Henry
     Zhou, Abdelrahman Mohamed, Michael Auli.
-45. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
+46. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
     Pretraining <https://arxiv.org/abs/1901.07291>`__ by Guillaume Lample and Alexis Conneau.
-46. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
+47. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
     Predicting Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan,
     Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
-47. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
+48. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
     Cross-lingual Representation Learning at Scale <https://arxiv.org/abs/1911.02116>`__ by Alexis Conneau*, Kartikay
     Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke
     Zettlemoyer and Veselin Stoyanov.
-48. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive
+49. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive
     Pretraining for Language Understanding <https://arxiv.org/abs/1906.08237>`__ by Zhilin Yang*, Zihang Dai*, Yiming
     Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
-49. :doc:`XLSR-Wav2Vec2 <model_doc/xlsr_wav2vec2>` (from Facebook AI) released with the paper `Unsupervised
+50. :doc:`XLSR-Wav2Vec2 <model_doc/xlsr_wav2vec2>` (from Facebook AI) released with the paper `Unsupervised
     Cross-Lingual Representation Learning For Speech Recognition <https://arxiv.org/abs/2006.13979>`__ by Alexis
     Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
@@ -328,6 +332,8 @@ TensorFlow and/or Flax.
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | Transformer-XL | ✅ | ❌ | ✅ | ✅ | ❌ |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
+| ViT | ❌ | ❌ | ✅ | ❌ | ❌ |
++-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | Wav2Vec2 | ✅ | ❌ | ✅ | ❌ | ❌ |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | XLM | ✅ | ❌ | ✅ | ✅ | ❌ |
@@ -460,6 +466,7 @@ TensorFlow and/or Flax.
    model_doc/t5
    model_doc/tapas
    model_doc/transformerxl
+   model_doc/vit
    model_doc/wav2vec2
    model_doc/xlm
    model_doc/xlmprophetnet
......
..
Copyright 2020 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
Vision Transformer (ViT)
-----------------------------------------------------------------------------------------------------------------------
.. note::

    This is a recently introduced model so the API hasn't been tested extensively. There may be some bugs or slight
    breaking changes to be fixed in the future. If you see something strange, file a `GitHub Issue
    <https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title>`__.
Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The Vision Transformer (ViT) model was proposed in `An Image is Worth 16x16 Words: Transformers for Image Recognition
at Scale <https://arxiv.org/abs/2010.11929>`__ by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk
Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob
Uszkoreit, Neil Houlsby. It is the first paper to successfully train a pure Transformer encoder on ImageNet, attaining
very good results compared to familiar convolutional architectures.
The abstract from the paper is the following:
*While the Transformer architecture has become the de-facto standard for natural language processing tasks, its
applications to computer vision remain limited. In vision, attention is either applied in conjunction with
convolutional networks, or used to replace certain components of convolutional networks while keeping their overall
structure in place. We show that this reliance on CNNs is not necessary and a pure transformer applied directly to
sequences of image patches can perform very well on image classification tasks. When pre-trained on large amounts of
data and transferred to multiple mid-sized or small image recognition benchmarks (ImageNet, CIFAR-100, VTAB, etc.),
Vision Transformer (ViT) attains excellent results compared to state-of-the-art convolutional networks while requiring
substantially fewer computational resources to train.*
Tips:
- To feed images to the Transformer encoder, each image is split into a sequence of fixed-size, non-overlapping
  patches, which are then linearly embedded. A [CLS] token is added to serve as a representation of the entire image,
  which can be used for classification. The authors also add absolute position embeddings, and feed the resulting
  sequence of vectors to a standard Transformer encoder. For example, a 224x224 image with 16x16 patches yields
  (224/16)² = 196 patch embeddings plus the [CLS] token, i.e. a sequence of 197 vectors (see the sketch after these
  tips).
- The Vision Transformer was pre-trained using a resolution of 224x224. During fine-tuning, it is often beneficial to
  use a higher resolution than during pre-training `(Touvron et al., 2019) <https://arxiv.org/abs/1906.06423>`__,
  `(Kolesnikov et al., 2020) <https://arxiv.org/abs/1912.11370>`__. The authors report the best results with a
  resolution of 384x384 during fine-tuning.
- As the Vision Transformer expects each image to be of the same size (resolution), one can use
  :class:`~transformers.ViTFeatureExtractor` to resize (or rescale) and normalize images for the model.
- Both the patch resolution and image resolution used during pre-training or fine-tuning are reflected in the name of
  each checkpoint. For example, :obj:`google/vit-base-patch16-224` refers to a base-sized architecture with patch
  resolution of 16x16 and fine-tuning resolution of 224x224. All checkpoints can be found on the `hub
  <https://huggingface.co/models?search=vit>`__.
- The available checkpoints are either (1) pre-trained on `ImageNet-21k <http://www.image-net.org/>`__ (a collection of
  14 million images and 21k classes) only, or (2) also fine-tuned on `ImageNet
  <http://www.image-net.org/challenges/LSVRC/2012/>`__ (also referred to as ILSVRC 2012, a collection of 1.3 million
  images and 1,000 classes).
- The best results are obtained with supervised pre-training, which is not the case in NLP. The authors also performed
  an experiment with a self-supervised pre-training objective, namely masked patch prediction (inspired by masked
  language modeling). With this approach, the smaller ViT-B/16 model achieves 79.9% accuracy on ImageNet, a significant
  improvement of 2% over training from scratch, but still 4% behind supervised pre-training.
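A minimal sketch (not the library's own code) of the patch-embedding step described in the first tip; the layer and
tensor names are illustrative assumptions, with sizes matching the base configuration::

    import torch
    from torch import nn

    # A Conv2d whose kernel size and stride both equal the patch size is equivalent to
    # splitting the image into non-overlapping patches and linearly projecting each one.
    image_size, patch_size, num_channels, hidden_size = 224, 16, 3, 768
    projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    pixel_values = torch.randn(1, num_channels, image_size, image_size)  # dummy image batch
    patches = projection(pixel_values).flatten(2).transpose(1, 2)  # shape (1, 196, 768)

    cls_token = torch.zeros(1, 1, hidden_size)  # a learned parameter in the real model
    embeddings = torch.cat([cls_token, patches], dim=1)  # shape (1, 197, 768)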
The original code (written in JAX) can be found `here <https://github.com/google-research/vision_transformer>`__.
Note that we converted the weights from Ross Wightman's `timm library
<https://github.com/rwightman/pytorch-image-models>`__, who already converted the weights from JAX to PyTorch. Credits
go to him!
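As a quick usage sketch (assuming a fine-tuned checkpoint such as :obj:`google/vit-base-patch16-224` is available on
the hub; the cat image is illustrative), classifying a single image looks roughly as follows::

    >>> from PIL import Image
    >>> import requests
    >>> from transformers import ViTFeatureExtractor, ViTForImageClassification

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
    >>> model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

    >>> inputs = feature_extractor(images=image, return_tensors="pt")
    >>> outputs = model(**inputs)
    >>> predicted_class_idx = outputs.logits.argmax(-1).item()
    >>> print(model.config.id2label[predicted_class_idx])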
ViTConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.ViTConfig
    :members:


ViTFeatureExtractor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.ViTFeatureExtractor
    :members: __call__


ViTModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.ViTModel
    :members: forward


ViTForImageClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.ViTForImageClassification
    :members: forward
@@ -107,6 +107,7 @@ _deps = [
     "onnxruntime>=1.4.0",
     "packaging",
     "parameterized",
+    "Pillow",
     "protobuf",
     "psutil",
     "pydantic",
@@ -230,6 +231,7 @@ extras["sagemaker"] = deps_list("sagemaker")
 extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette")
 extras["speech"] = deps_list("soundfile", "torchaudio")
+extras["vision"] = deps_list("Pillow")
 extras["sentencepiece"] = deps_list("sentencepiece", "protobuf")
 extras["testing"] = (
@@ -242,7 +244,7 @@ extras["testing"] = (
 extras["docs"] = deps_list("recommonmark", "sphinx", "sphinx-markdown-tables", "sphinx-rtd-theme", "sphinx-copybutton")
 extras["quality"] = deps_list("black", "isort", "flake8")
-extras["all"] = extras["tf"] + extras["torch"] + extras["flax"] + extras["sentencepiece"] + extras["tokenizers"]
+extras["all"] = extras["tf"] + extras["torch"] + extras["flax"] + extras["sentencepiece"] + extras["tokenizers"] + extras["speech"] + extras["vision"]
 extras["dev"] = (
     extras["all"]
......
@@ -213,6 +213,7 @@ _import_structure = {
         "TransfoXLCorpus",
         "TransfoXLTokenizer",
     ],
+    "models.vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
     "models.wav2vec2": [
         "WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP",
         "Wav2Vec2Config",
@@ -299,7 +300,7 @@ else:
         name for name in dir(dummy_sentencepiece_objects) if not name.startswith("_")
     ]
-# tokenziers-backed objects
+# tokenizers-backed objects
 if is_tokenizers_available():
     # Fast tokenizers
     _import_structure["models.convbert"].append("ConvBertTokenizerFast")
@@ -348,6 +349,7 @@ else:
 # Vision-specific objects
 if is_vision_available():
     _import_structure["image_utils"] = ["ImageFeatureExtractionMixin"]
+    _import_structure["models.vit"].append("ViTFeatureExtractor")
 else:
     from .utils import dummy_vision_objects
@@ -426,6 +428,7 @@ if is_torch_available():
     _import_structure["models.auto"].extend(
         [
             "MODEL_FOR_CAUSAL_LM_MAPPING",
+            "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
             "MODEL_FOR_MASKED_LM_MAPPING",
             "MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
             "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
@@ -867,6 +870,14 @@ if is_torch_available():
             "load_tf_weights_in_transfo_xl",
         ]
     )
+    _import_structure["models.vit"].extend(
+        [
+            "VIT_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "ViTForImageClassification",
+            "ViTModel",
+            "ViTPreTrainedModel",
+        ]
+    )
     _import_structure["models.wav2vec2"].extend(
         [
             "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1311,7 +1322,6 @@ else:
         name for name in dir(dummy_flax_objects) if not name.startswith("_")
     ]
-
 # Direct imports for type-checking
 if TYPE_CHECKING:
     # Configuration
@@ -1479,6 +1489,7 @@ if TYPE_CHECKING:
         TransfoXLCorpus,
         TransfoXLTokenizer,
     )
+    from .models.vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
     from .models.wav2vec2 import (
         WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
         Wav2Vec2Config,
@@ -1601,6 +1612,7 @@ if TYPE_CHECKING:
     if is_vision_available():
         from .image_utils import ImageFeatureExtractionMixin
+        from .models.vit import ViTFeatureExtractor
     else:
         from .utils.dummy_vision_objects import *
@@ -1666,6 +1678,7 @@ if TYPE_CHECKING:
     )
     from .models.auto import (
         MODEL_FOR_CAUSAL_LM_MAPPING,
+        MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
         MODEL_FOR_MASKED_LM_MAPPING,
         MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
         MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
@@ -2025,6 +2038,12 @@ if TYPE_CHECKING:
         TransfoXLPreTrainedModel,
         load_tf_weights_in_transfo_xl,
     )
+    from .models.vit import (
+        VIT_PRETRAINED_MODEL_ARCHIVE_LIST,
+        ViTForImageClassification,
+        ViTModel,
+        ViTPreTrainedModel,
+    )
     from .models.wav2vec2 import (
         WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
         Wav2Vec2ForCTC,
@@ -2400,6 +2419,7 @@ if TYPE_CHECKING:
     # Import the same objects as dummies to get them in the namespace.
     # They will raise an import error if the user tries to instantiate / use them.
     from .utils.dummy_flax_objects import *
+
 else:
     import importlib
     import os
......
@@ -24,6 +24,7 @@ deps = {
     "onnxruntime": "onnxruntime>=1.4.0",
     "packaging": "packaging",
     "parameterized": "parameterized",
+    "Pillow": "Pillow",
     "protobuf": "protobuf",
     "psutil": "psutil",
     "pydantic": "pydantic",
......
@@ -175,10 +175,11 @@ try:
 except importlib_metadata.PackageNotFoundError:
     _soundfile_available = False
-_torchaudio_available = importlib.util.find_spec("torchaudio")
+_torchaudio_available = importlib.util.find_spec("torchaudio") is not None
 try:
     _torchaudio_version = importlib_metadata.version("torchaudio")
-    logger.debug(f"Successfully imported soundfile version {_torchaudio_version}")
+    logger.debug(f"Successfully imported torchaudio version {_torchaudio_version}")
 except importlib_metadata.PackageNotFoundError:
     _torchaudio_available = False
......
@@ -120,9 +120,9 @@ class ImageFeatureExtractionMixin:
         if isinstance(image, np.ndarray):
             if not isinstance(mean, np.ndarray):
-                mean = np.array(mean)
+                mean = np.array(mean).astype(image.dtype)
             if not isinstance(std, np.ndarray):
-                std = np.array(std)
+                std = np.array(std).astype(image.dtype)
         elif is_torch_tensor(image):
             import torch
......
@@ -67,6 +67,7 @@ from . import (
     t5,
     tapas,
     transfo_xl,
+    vit,
     wav2vec2,
     xlm,
     xlm_roberta,
......
@@ -29,6 +29,7 @@ _import_structure = {
 if is_torch_available():
     _import_structure["modeling_auto"] = [
         "MODEL_FOR_CAUSAL_LM_MAPPING",
+        "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
         "MODEL_FOR_MASKED_LM_MAPPING",
         "MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
         "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
@@ -42,6 +43,7 @@ if is_torch_available():
         "MODEL_WITH_LM_HEAD_MAPPING",
         "AutoModel",
         "AutoModelForCausalLM",
+        "AutoModelForImageClassification",
         "AutoModelForMaskedLM",
         "AutoModelForMultipleChoice",
         "AutoModelForNextSentencePrediction",
@@ -90,6 +92,7 @@ if TYPE_CHECKING:
 if is_torch_available():
     from .modeling_auto import (
         MODEL_FOR_CAUSAL_LM_MAPPING,
+        MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
         MODEL_FOR_MASKED_LM_MAPPING,
         MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
         MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
@@ -103,6 +106,7 @@ if TYPE_CHECKING:
         MODEL_WITH_LM_HEAD_MAPPING,
         AutoModel,
         AutoModelForCausalLM,
+        AutoModelForImageClassification,
         AutoModelForMaskedLM,
         AutoModelForMultipleChoice,
         AutoModelForNextSentencePrediction,
......
@@ -68,6 +68,7 @@ from ..squeezebert.configuration_squeezebert import SQUEEZEBERT_PRETRAINED_CONFI
 from ..t5.configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
 from ..tapas.configuration_tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig
 from ..transfo_xl.configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
+from ..vit.configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
 from ..wav2vec2.configuration_wav2vec2 import WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2Config
 from ..xlm.configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
 from ..xlm_prophetnet.configuration_xlm_prophetnet import (
@@ -85,6 +86,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
         GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP,
         BIG_BIRD_PRETRAINED_CONFIG_ARCHIVE_MAP,
         SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP,
+        VIT_PRETRAINED_CONFIG_ARCHIVE_MAP,
         WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
         M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP,
         CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
@@ -134,6 +136,7 @@ CONFIG_MAPPING = OrderedDict(
         ("gpt_neo", GPTNeoConfig),
         ("big_bird", BigBirdConfig),
         ("speech_to_text", Speech2TextConfig),
+        ("vit", ViTConfig),
         ("wav2vec2", Wav2Vec2Config),
         ("m2m_100", M2M100Config),
         ("convbert", ConvBertConfig),
@@ -189,6 +192,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
         ("gpt_neo", "GPT Neo"),
         ("big_bird", "BigBird"),
         ("speech_to_text", "Speech2Text"),
+        ("vit", "ViT"),
         ("wav2vec2", "Wav2Vec2"),
         ("m2m_100", "M2M100"),
         ("convbert", "ConvBERT"),
......
@@ -237,6 +237,7 @@ from ..tapas.modeling_tapas import (
     TapasModel,
 )
 from ..transfo_xl.modeling_transfo_xl import TransfoXLForSequenceClassification, TransfoXLLMHeadModel, TransfoXLModel
+from ..vit.modeling_vit import ViTForImageClassification, ViTModel
 from ..wav2vec2.modeling_wav2vec2 import Wav2Vec2ForMaskedLM, Wav2Vec2Model
 from ..xlm.modeling_xlm import (
     XLMForMultipleChoice,
@@ -313,6 +314,7 @@ from .configuration_auto import (
     T5Config,
     TapasConfig,
     TransfoXLConfig,
+    ViTConfig,
     Wav2Vec2Config,
     XLMConfig,
     XLMProphetNetConfig,
@@ -331,6 +333,7 @@ MODEL_MAPPING = OrderedDict(
         (GPTNeoConfig, GPTNeoModel),
         (BigBirdConfig, BigBirdModel),
         (Speech2TextConfig, Speech2TextModel),
+        (ViTConfig, ViTModel),
         (Wav2Vec2Config, Wav2Vec2Model),
         (M2M100Config, M2M100Model),
         (ConvBertConfig, ConvBertModel),
@@ -490,6 +493,13 @@ MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
     ]
 )
+MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = OrderedDict(
+    [
+        # Model for Image Classification mapping
+        (ViTConfig, ViTForImageClassification),
+    ]
+)
 MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
     [
         # Model for Masked LM mapping
@@ -1864,3 +1874,100 @@ class AutoModelForNextSentencePrediction:
             f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
             f"Model type should be one of {', '.join(c.__name__ for c in MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.keys())}."
         )
class AutoModelForImageClassification:
    r"""
    This is a generic model class that will be instantiated as one of the model classes of the library---with an image
    classification head---when created with the :meth:`~transformers.AutoModelForImageClassification.from_pretrained`
    class method or the :meth:`~transformers.AutoModelForImageClassification.from_config` class method.

    This class cannot be instantiated directly using ``__init__()`` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoModelForImageClassification is designed to be instantiated "
            "using the `AutoModelForImageClassification.from_pretrained(pretrained_model_name_or_path)` or "
            "`AutoModelForImageClassification.from_config(config)` methods."
        )

    @classmethod
    @replace_list_option_in_docstrings(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, use_model_types=False)
    def from_config(cls, config):
        r"""
        Instantiates one of the model classes of the library---with an image classification head---from a
        configuration.

        Note:
            Loading a model from its configuration file does **not** load the model weights. It only affects the
            model's configuration. Use :meth:`~transformers.AutoModelForImageClassification.from_pretrained` to load
            the model weights.

        Args:
            config (:class:`~transformers.PretrainedConfig`):
                The model class to instantiate is selected based on the configuration class:

                List options

        Examples::

            >>> from transformers import AutoConfig, AutoModelForImageClassification
            >>> # Download configuration from huggingface.co and cache.
            >>> config = AutoConfig.from_pretrained('google/vit_base_patch16_224')
            >>> model = AutoModelForImageClassification.from_config(config)
        """
        if type(config) in MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys():
            return MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING[type(config)](config)
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys()),
            )
        )

    @classmethod
    @replace_list_option_in_docstrings(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING)
    @add_start_docstrings(
        "Instantiate one of the model classes of the library---with an image classification head---from a "
        "pretrained model.",
        AUTO_MODEL_PRETRAINED_DOCSTRING,
    )
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r"""
        Examples::

            >>> from transformers import AutoConfig, AutoModelForImageClassification

            >>> # Download model and configuration from huggingface.co and cache.
            >>> model = AutoModelForImageClassification.from_pretrained('google/vit_base_patch16_224')

            >>> # Update configuration during loading
            >>> model = AutoModelForImageClassification.from_pretrained('google/vit_base_patch16_224', output_attentions=True)
            >>> model.config.output_attentions
            True

            >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
            >>> config = AutoConfig.from_json_file('./tf_model/vit_tf_model_config.json')
            >>> model = AutoModelForImageClassification.from_pretrained('./tf_model/vit_tf_checkpoint.ckpt.index', from_tf=True, config=config)
        """
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            config, kwargs = AutoConfig.from_pretrained(
                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
            )

        if type(config) in MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys():
            return MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING[type(config)].from_pretrained(
                pretrained_model_name_or_path, *model_args, config=config, **kwargs
            )
        raise ValueError(
            "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
            "Model type should be one of {}.".format(
                config.__class__,
                cls.__name__,
                ", ".join(c.__name__ for c in MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys()),
            )
        )
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...file_utils import _BaseLazyModule, is_torch_available, is_vision_available


_import_structure = {
    "configuration_vit": ["VIT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ViTConfig"],
}

if is_vision_available():
    _import_structure["feature_extraction_vit"] = ["ViTFeatureExtractor"]

if is_torch_available():
    _import_structure["modeling_vit"] = [
        "VIT_PRETRAINED_MODEL_ARCHIVE_LIST",
        "ViTForImageClassification",
        "ViTModel",
        "ViTPreTrainedModel",
    ]


if TYPE_CHECKING:
    from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig

    if is_vision_available():
        from .feature_extraction_vit import ViTFeatureExtractor

    if is_torch_available():
        from .modeling_vit import (
            VIT_PRETRAINED_MODEL_ARCHIVE_LIST,
            ViTForImageClassification,
            ViTModel,
            ViTPreTrainedModel,
        )

else:
    import importlib
    import os
    import sys

    class _LazyModule(_BaseLazyModule):
        """
        Module class that surfaces all objects but only performs associated imports when the objects are requested.
        """

        __file__ = globals()["__file__"]
        __path__ = [os.path.dirname(__file__)]

        def _get_module(self, module_name: str):
            return importlib.import_module("." + module_name, self.__name__)

    sys.modules[__name__] = _LazyModule(__name__, _import_structure)
# coding=utf-8
# Copyright 2021 Google AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" ViT model configuration """
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
VIT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"nielsr/vit-base-patch16-224": "https://huggingface.co/vit-base-patch16-224/resolve/main/config.json",
# See all ViT models at https://huggingface.co/models?filter=vit
}
class ViTConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a :class:`~transformers.ViTModel`. It is used to
    instantiate a ViT model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the ViT `google/vit-base-patch16-224
    <https://huggingface.co/google/vit-base-patch16-224>`__ architecture.

    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
    outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.

    Args:
        hidden_size (:obj:`int`, `optional`, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (:obj:`int`, `optional`, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (:obj:`int`, `optional`, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string,
            :obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
        hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (:obj:`float`, `optional`, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If True, use gradient checkpointing to save memory at the expense of a slower backward pass.
        image_size (:obj:`int`, `optional`, defaults to :obj:`224`):
            The size (resolution) of each image.
        patch_size (:obj:`int`, `optional`, defaults to :obj:`16`):
            The size (resolution) of each patch.
        num_channels (:obj:`int`, `optional`, defaults to :obj:`3`):
            The number of input channels.

    Example::

        >>> from transformers import ViTModel, ViTConfig

        >>> # Initializing a ViT vit-base-patch16-224 style configuration
        >>> configuration = ViTConfig()

        >>> # Initializing a model from the vit-base-patch16-224 style configuration
        >>> model = ViTModel(configuration)

        >>> # Accessing the model configuration
        >>> configuration = model.config
    """
model_type = "vit"
def __init__(
self,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
initializer_range=0.02,
layer_norm_eps=1e-12,
is_encoder_decoder=False,
image_size=224,
patch_size=16,
num_channels=3,
**kwargs
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert ViT checkpoints from the timm library."""
import argparse
from pathlib import Path
import torch
from PIL import Image
import requests
import timm
from transformers import ViTConfig, ViTFeatureExtractor, ViTForImageClassification, ViTModel
from transformers.utils import logging
from transformers.utils.imagenet_classes import id2label
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
# here we list all keys to be renamed (original name on the left, our name on the right)
def create_rename_keys(config, base_model=False):
    rename_keys = []
    for i in range(config.num_hidden_layers):
        # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
        rename_keys.append((f"blocks.{i}.norm1.weight", f"vit.encoder.layer.{i}.layernorm_before.weight"))
        rename_keys.append((f"blocks.{i}.norm1.bias", f"vit.encoder.layer.{i}.layernorm_before.bias"))
        rename_keys.append((f"blocks.{i}.attn.proj.weight", f"vit.encoder.layer.{i}.attention.output.dense.weight"))
        rename_keys.append((f"blocks.{i}.attn.proj.bias", f"vit.encoder.layer.{i}.attention.output.dense.bias"))
        rename_keys.append((f"blocks.{i}.norm2.weight", f"vit.encoder.layer.{i}.layernorm_after.weight"))
        rename_keys.append((f"blocks.{i}.norm2.bias", f"vit.encoder.layer.{i}.layernorm_after.bias"))
        rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"vit.encoder.layer.{i}.intermediate.dense.weight"))
        rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"vit.encoder.layer.{i}.intermediate.dense.bias"))
        rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"vit.encoder.layer.{i}.output.dense.weight"))
        rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"vit.encoder.layer.{i}.output.dense.bias"))

    # projection layer + position embeddings
    rename_keys.extend(
        [
            ("cls_token", "vit.embeddings.cls_token"),
            ("patch_embed.proj.weight", "vit.embeddings.patch_embeddings.projection.weight"),
            ("patch_embed.proj.bias", "vit.embeddings.patch_embeddings.projection.bias"),
            ("pos_embed", "vit.embeddings.position_embeddings"),
        ]
    )

    if base_model:
        # layernorm + pooler
        rename_keys.extend(
            [
                ("norm.weight", "layernorm.weight"),
                ("norm.bias", "layernorm.bias"),
                ("pre_logits.fc.weight", "pooler.dense.weight"),
                ("pre_logits.fc.bias", "pooler.dense.bias"),
            ]
        )

        # if just the base model, we should remove "vit" from all keys that start with "vit"
        rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("vit") else pair for pair in rename_keys]
    else:
        # layernorm + classification head
        rename_keys.extend(
            [
                ("norm.weight", "vit.layernorm.weight"),
                ("norm.bias", "vit.layernorm.bias"),
                ("head.weight", "classifier.weight"),
                ("head.bias", "classifier.bias"),
            ]
        )

    return rename_keys
# we split up the matrix of each encoder layer into queries, keys and values
def read_in_q_k_v(state_dict, config, base_model=False):
    for i in range(config.num_hidden_layers):
        if base_model:
            prefix = ""
        else:
            prefix = "vit."
        # read in weights + bias of input projection layer (in timm, this is a single matrix + bias)
        in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight")
        in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias")
        # next, add query, keys and values (in that order) to the state dict
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[
            : config.hidden_size, :
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[
            config.hidden_size : config.hidden_size * 2, :
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[
            config.hidden_size : config.hidden_size * 2
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[
            -config.hidden_size :, :
        ]
        state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :]


def remove_classification_head_(state_dict):
    ignore_keys = ["head.weight", "head.bias"]
    for k in ignore_keys:
        state_dict.pop(k, None)


def rename_key(dct, old, new):
    val = dct.pop(old)
    dct[new] = val


# We will verify our results on an image of cute cats
def prepare_img():
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    im = Image.open(requests.get(url, stream=True).raw)
    return im
@torch.no_grad()
def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path):
    """
    Copy/paste/tweak model's weights to our ViT structure.
    """

    # define default ViT configuration
    config = ViTConfig()
    base_model = False
    # dataset (ImageNet-21k only or also fine-tuned on ImageNet 2012), patch_size and image_size
    if vit_name[-5:] == "in21k":
        base_model = True
        config.patch_size = int(vit_name[-12:-10])
        config.image_size = int(vit_name[-9:-6])
    else:
        config.num_labels = 1000
        config.id2label = id2label
        config.label2id = {v: k for k, v in id2label.items()}
        config.patch_size = int(vit_name[-6:-4])
        config.image_size = int(vit_name[-3:])
    # size of the architecture
    if vit_name[4:].startswith("small"):
        config.hidden_size = 768
        config.intermediate_size = 2304
        config.num_hidden_layers = 8
        config.num_attention_heads = 8
    if vit_name[4:].startswith("base"):
        pass
    elif vit_name[4:].startswith("large"):
        config.hidden_size = 1024
        config.intermediate_size = 4096
        config.num_hidden_layers = 24
        config.num_attention_heads = 16
    elif vit_name[4:].startswith("huge"):
        config.hidden_size = 1280
        config.intermediate_size = 5120
        config.num_hidden_layers = 32
        config.num_attention_heads = 16

    # load original model from timm
    timm_model = timm.create_model(vit_name, pretrained=True)
    timm_model.eval()

    # load state_dict of original model, remove and rename some keys
    state_dict = timm_model.state_dict()
    if base_model:
        remove_classification_head_(state_dict)
    rename_keys = create_rename_keys(config, base_model)
    for src, dest in rename_keys:
        rename_key(state_dict, src, dest)
    read_in_q_k_v(state_dict, config, base_model)

    # load HuggingFace model
    if vit_name[-5:] == "in21k":
        model = ViTModel(config).eval()
    else:
        model = ViTForImageClassification(config).eval()
    model.load_state_dict(state_dict)

    # Check outputs on an image, prepared by ViTFeatureExtractor
    feature_extractor = ViTFeatureExtractor(size=config.image_size)
    encoding = feature_extractor(images=prepare_img(), return_tensors="pt")
    pixel_values = encoding["pixel_values"]
    outputs = model(pixel_values)

    if base_model:
        timm_pooled_output = timm_model.forward_features(pixel_values)
        assert timm_pooled_output.shape == outputs.pooler_output.shape
        assert torch.allclose(timm_pooled_output, outputs.pooler_output, atol=1e-3)
    else:
        timm_logits = timm_model(pixel_values)
        assert timm_logits.shape == outputs.logits.shape
        assert torch.allclose(timm_logits, outputs.logits, atol=1e-3)

    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    print(f"Saving model {vit_name} to {pytorch_dump_folder_path}")
    model.save_pretrained(pytorch_dump_folder_path)
    print(f"Saving feature extractor to {pytorch_dump_folder_path}")
    feature_extractor.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--vit_name",
        default="vit_base_patch16_224",
        type=str,
        help="Name of the ViT timm model you'd like to convert.",
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
    )

    args = parser.parse_args()
    convert_vit_checkpoint(args.vit_name, args.pytorch_dump_folder_path)
# coding=utf-8
# Copyright Google AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature extractor class for ViT."""
from typing import List, Optional, Union
import numpy as np
from PIL import Image
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...file_utils import TensorType
from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor
from ...utils import logging
logger = logging.get_logger(__name__)
class ViTFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
    r"""
    Constructs a ViT feature extractor.

    This feature extractor inherits from :class:`~transformers.FeatureExtractionMixin` which contains most of the main
    methods. Users should refer to this superclass for more information regarding those methods.

    Args:
        image_mean (:obj:`List[float]`, `optional`, defaults to :obj:`[0.5, 0.5, 0.5]`):
            The sequence of means for each channel, to be used when normalizing images.
        image_std (:obj:`List[float]`, `optional`, defaults to :obj:`[0.5, 0.5, 0.5]`):
            The sequence of standard deviations for each channel, to be used when normalizing images.
        do_normalize (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to normalize the input with mean and standard deviation.
        do_resize (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether to resize the input to a certain :obj:`size`.
        size (:obj:`int`, `optional`, defaults to 224):
            Resize the input to the given size. Only has an effect if :obj:`do_resize` is set to :obj:`True`.
    """

    model_input_names = ["pixel_values"]

    def __init__(self, image_mean=None, image_std=None, do_normalize=True, do_resize=True, size=224, **kwargs):
        super().__init__(**kwargs)
        # fall back to the defaults when no mean/std are passed, instead of
        # silently discarding the constructor arguments
        self.image_mean = image_mean if image_mean is not None else [0.5, 0.5, 0.5]
        self.image_std = image_std if image_std is not None else [0.5, 0.5, 0.5]
        self.do_normalize = do_normalize
        self.do_resize = do_resize
        self.size = size
    def __call__(
        self,
        images: Union[
            Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"]  # noqa
        ],
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs
    ) -> BatchFeature:
        """
        Main method to prepare for the model one or several image(s).

        .. warning::

           NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to
           pass PIL images.

        Args:
            images (:obj:`PIL.Image.Image`, :obj:`np.ndarray`, :obj:`torch.Tensor`, :obj:`List[PIL.Image.Image]`, :obj:`List[np.ndarray]`, :obj:`List[torch.Tensor]`):
                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
                tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
                number of channels, H and W are image height and width.
            return_tensors (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`):
                If set, will return tensors instead of list of python integers. Acceptable values are:

                * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
                * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
                * :obj:`'np'`: Return NumPy :obj:`np.ndarray` objects.
                * :obj:`'jax'`: Return JAX :obj:`jnp.ndarray` objects.

        Returns:
            :class:`~transformers.BatchFeature`: A :class:`~transformers.BatchFeature` with the following fields:

            - **pixel_values** -- Pixel values to be fed to a model.
        """
        # Input type checking for clearer error
        valid_images = False

        # Check that images has a valid type
        if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
            valid_images = True
        elif isinstance(images, (list, tuple)):
            if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
                valid_images = True

        if not valid_images:
            raise ValueError(
                "Images must be of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
                "`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
            )

        is_batched = bool(
            isinstance(images, (list, tuple))
            and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
        )

        if not is_batched:
            images = [images]

        # transformations (resizing + normalization)
        if self.do_resize and self.size is not None:
            images = [self.resize(image=image, size=self.size) for image in images]
        if self.do_normalize:
            images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]

        # return as BatchFeature
        data = {"pixel_values": images}
        encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)

        return encoded_inputs
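As a usage sketch for the feature extractor above (the dummy images and the installed ``torch`` backend are
assumptions for illustration), preparing a batch of two differently sized images:

    import numpy as np
    from PIL import Image

    from transformers import ViTFeatureExtractor

    # One PIL image and one NumPy array in (C, H, W) layout, as the docstring requires.
    images = [
        Image.new("RGB", (640, 480)),
        np.zeros((3, 300, 400), dtype=np.float32),
    ]

    feature_extractor = ViTFeatureExtractor(size=224)
    encoding = feature_extractor(images=images, return_tensors="pt")
    print(encoding["pixel_values"].shape)  # torch.Size([2, 3, 224, 224])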
This diff is collapsed.
@@ -302,6 +302,9 @@ def load_tf_weights_in_albert(*args, **kwargs):
 MODEL_FOR_CAUSAL_LM_MAPPING = None
+MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = None
 MODEL_FOR_MASKED_LM_MAPPING = None
@@ -2512,6 +2515,32 @@ def load_tf_weights_in_transfo_xl(*args, **kwargs):
     requires_pytorch(load_tf_weights_in_transfo_xl)
+VIT_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class ViTForImageClassification:
+    def __init__(self, *args, **kwargs):
+        requires_pytorch(self)
+
+
+class ViTModel:
+    def __init__(self, *args, **kwargs):
+        requires_pytorch(self)
+
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_pytorch(self)
+
+
+class ViTPreTrainedModel:
+    def __init__(self, *args, **kwargs):
+        requires_pytorch(self)
+
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_pytorch(self)
+
 WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
@@ -5,3 +5,8 @@ from ..file_utils import requires_vision
 class ImageFeatureExtractionMixin:
     def __init__(self, *args, **kwargs):
         requires_vision(self)
+
+
+class ViTFeatureExtractor:
+    def __init__(self, *args, **kwargs):
+        requires_vision(self)