Unverified Commit 1dc96a76 authored by NielsRogge, committed by GitHub

Add SegFormer (#14019)



* First draft

* Make style & quality

* Improve conversion script

* Add print statement to see actual slice

* Make absolute tolerance smaller

* Fix image classification models

* Add post_process_semantic method

* Disable padding

* Improve conversion script

* Rename to ForSemanticSegmentation, add integration test, remove post_process methods

* Improve docs

* Fix code quality

* Fix feature extractor tests

* Fix tests for image classification model

* Delete file

* Add is_torch_available to feature extractor

* Improve documentation of feature extractor methods

* Apply suggestions from @sgugger's code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Apply some more suggestions from code review

* Rebase with master

* Fix rebase issues

* Make sure model only outputs hidden states when the user wants to

* Apply suggestions from code review

* Add pad method

* Support padding of 2d images

* Add print statement

* Add print statement

* Move padding method to SegformerFeatureExtractor

* Fix issue

* Add casting of segmentation maps

* Add test for padding

* Add small note about padding
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 123cce6f
......@@ -271,6 +271,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
1. **[RemBERT](https://huggingface.co/transformers/model_doc/rembert.html)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/pdf/2010.12821.pdf) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder.
1. **[RoBERTa](https://huggingface.co/transformers/model_doc/roberta.html)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
1. **[RoFormer](https://huggingface.co/transformers/model_doc/roformer.html)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/pdf/2104.09864v1.pdf) by Jianlin Su, Yu Lu, Shengfeng Pan, Bo Wen and Yunfeng Liu.
1. **[SegFormer](https://huggingface.co/transformers/master/model_doc/segformer.html)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo.
1. **[SEW](https://huggingface.co/transformers/model_doc/sew.html)** (from ASAPP) released with the paper [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi.
1. **[SEW-D](https://huggingface.co/transformers/model_doc/sew_d.html)** (from ASAPP) released with the paper [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi.
1. **[SpeechToTextTransformer](https://huggingface.co/transformers/model_doc/speech_to_text.html)** (from Facebook), released together with the paper [fairseq S2T: Fast Speech-to-Text Modeling with fairseq](https://arxiv.org/abs/2010.05171) by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino.
......
......@@ -269,6 +269,7 @@ how to install these with conda on the Flax, PyTorch, and TensorFlow installation pages
1. **[RemBERT](https://huggingface.co/transformers/model_doc/rembert.html)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/pdf/2010.12821.pdf) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder.
1. **[RoBERTa](https://huggingface.co/transformers/model_doc/roberta.html)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
1. **[RoFormer](https://huggingface.co/transformers/model_doc/roformer.html)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/pdf/2104.09864v1.pdf) by Jianlin Su, Yu Lu, Shengfeng Pan, Bo Wen and Yunfeng Liu.
1. **[SegFormer](https://huggingface.co/transformers/model_doc/segformer.html)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo.
1. **[SEW](https://huggingface.co/transformers/model_doc/sew.html)** (from ASAPP) released with the paper [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi.
1. **[SEW-D](https://huggingface.co/transformers/model_doc/sew_d.html)** (from ASAPP) released with the paper [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi.
1. **[SpeechToTextTransformer](https://huggingface.co/transformers/model_doc/speech_to_text.html)** (from Facebook), released together with the paper [fairseq S2T: Fast Speech-to-Text Modeling with fairseq](https://arxiv.org/abs/2010.05171) by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino.
......
......@@ -293,6 +293,7 @@ conda install -c huggingface transformers
1. **[RemBERT](https://huggingface.co/transformers/model_doc/rembert.html)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/pdf/2010.12821.pdf) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder.
1. **[RoBERTa](https://huggingface.co/transformers/model_doc/roberta.html)** (from Facebook), released with the paper [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
1. **[RoFormer](https://huggingface.co/transformers/model_doc/roformer.html)** (from ZhuiyiTechnology), released with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/pdf/2104.09864v1.pdf) by Jianlin Su, Yu Lu, Shengfeng Pan, Bo Wen and Yunfeng Liu.
1. **[SegFormer](https://huggingface.co/transformers/master/model_doc/segformer.html)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo.
1. **[SEW](https://huggingface.co/transformers/model_doc/sew.html)** (from ASAPP) released with the paper [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi.
1. **[SEW-D](https://huggingface.co/transformers/model_doc/sew_d.html)** (from ASAPP) released with the paper [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi.
1. **[SpeechEncoderDecoder](https://huggingface.co/transformers/model_doc/speechencoderdecoder.html)**
......
......@@ -305,6 +305,7 @@ conda install -c huggingface transformers
1. **[RemBERT](https://huggingface.co/transformers/model_doc/rembert.html)** (from Google Research) released with the paper [Rethinking embedding coupling in pre-trained language models](https://arxiv.org/pdf/2010.12821.pdf) by Hyung Won Chung, Thibault Févry, Henry Tsai, M. Johnson, Sebastian Ruder.
1. **[RoBERTa](https://huggingface.co/transformers/model_doc/roberta.html)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
1. **[RoFormer](https://huggingface.co/transformers/model_doc/roformer.html)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/pdf/2104.09864v1.pdf) by Jianlin Su, Yu Lu, Shengfeng Pan, Bo Wen and Yunfeng Liu.
1. **[SegFormer](https://huggingface.co/transformers/master/model_doc/segformer.html)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo.
1. **[SEW](https://huggingface.co/transformers/model_doc/sew.html)** (from ASAPP) released with the paper [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi.
1. **[SEW-D](https://huggingface.co/transformers/model_doc/sew_d.html)** (from ASAPP) released with the paper [Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition](https://arxiv.org/abs/2109.06870) by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi.
1. **[SpeechEncoderDecoder](https://huggingface.co/transformers/model_doc/speechencoderdecoder.html)**
......
......@@ -278,73 +278,77 @@ Supported models
60. :doc:`RoFormer <model_doc/roformer>` (from ZhuiyiTechnology), released together with the paper `RoFormer:
Enhanced Transformer with Rotary Position Embedding <https://arxiv.org/pdf/2104.09864v1.pdf>`__ by Jianlin Su and
Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
61. :doc:`SEW <model_doc/sew>` (from ASAPP) released with the paper `Performance-Efficiency Trade-offs in Unsupervised
61. `SegFormer <https://huggingface.co/transformers/master/model_doc/segformer.html>`__ (from NVIDIA) released with the
paper `SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers
<https://arxiv.org/abs/2105.15203>`__ by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping
Luo.
62. :doc:`SEW <model_doc/sew>` (from ASAPP) released with the paper `Performance-Efficiency Trade-offs in Unsupervised
Pre-training for Speech Recognition <https://arxiv.org/abs/2109.06870>`__ by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu
Han, Kilian Q. Weinberger, Yoav Artzi.
62. :doc:`SEW-D <model_doc/sew_d>` (from ASAPP) released with the paper `Performance-Efficiency Trade-offs in
63. :doc:`SEW-D <model_doc/sew_d>` (from ASAPP) released with the paper `Performance-Efficiency Trade-offs in
Unsupervised Pre-training for Speech Recognition <https://arxiv.org/abs/2109.06870>`__ by Felix Wu, Kwangyoun Kim,
Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi.
63. :doc:`SpeechToTextTransformer <model_doc/speech_to_text>` (from Facebook), released together with the paper
64. :doc:`SpeechToTextTransformer <model_doc/speech_to_text>` (from Facebook), released together with the paper
`fairseq S2T: Fast Speech-to-Text Modeling with fairseq <https://arxiv.org/abs/2010.05171>`__ by Changhan Wang, Yun
Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino.
64. :doc:`SpeechToTextTransformer2 <model_doc/speech_to_text_2>` (from Facebook), released together with the paper
65. :doc:`SpeechToTextTransformer2 <model_doc/speech_to_text_2>` (from Facebook), released together with the paper
`Large-Scale Self- and Semi-Supervised Learning for Speech Translation <https://arxiv.org/abs/2104.06678>`__ by
Changhan Wang, Anne Wu, Juan Pino, Alexei Baevski, Michael Auli, Alexis Conneau.
65. :doc:`Splinter <model_doc/splinter>` (from Tel Aviv University), released together with the paper `Few-Shot
66. :doc:`Splinter <model_doc/splinter>` (from Tel Aviv University), released together with the paper `Few-Shot
Question Answering by Pretraining Span Selection <https://arxiv.org/abs/2101.00438>`__ by Ori Ram, Yuval Kirstain,
Jonathan Berant, Amir Globerson, Omer Levy.
66. :doc:`SqueezeBert <model_doc/squeezebert>` (from Berkeley) released with the paper `SqueezeBERT: What can computer
67. :doc:`SqueezeBert <model_doc/squeezebert>` (from Berkeley) released with the paper `SqueezeBERT: What can computer
vision teach NLP about efficient neural networks? <https://arxiv.org/abs/2006.11316>`__ by Forrest N. Iandola,
Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer.
67. :doc:`T5 <model_doc/t5>` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a
68. :doc:`T5 <model_doc/t5>` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a
Unified Text-to-Text Transformer <https://arxiv.org/abs/1910.10683>`__ by Colin Raffel and Noam Shazeer and Adam
Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
68. :doc:`T5v1.1 <model_doc/t5v1.1>` (from Google AI) released in the repository
69. :doc:`T5v1.1 <model_doc/t5v1.1>` (from Google AI) released in the repository
`google-research/text-to-text-transfer-transformer
<https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511>`__ by
Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi
Zhou and Wei Li and Peter J. Liu.
69. :doc:`TAPAS <model_doc/tapas>` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via
70. :doc:`TAPAS <model_doc/tapas>` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via
Pre-training <https://arxiv.org/abs/2004.02349>`__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller,
Francesco Piccinno and Julian Martin Eisenschlos.
70. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
71. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
Attentive Language Models Beyond a Fixed-Length Context <https://arxiv.org/abs/1901.02860>`__ by Zihang Dai*,
Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
71. `TrOCR <https://huggingface.co/transformers/master/model_doc/trocr.html>`__ (from Microsoft), released together
72. `TrOCR <https://huggingface.co/transformers/master/model_doc/trocr.html>`__ (from Microsoft), released together
with the paper `TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models
<https://arxiv.org/abs/2109.10282>`__ by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang,
Zhoujun Li, Furu Wei.
72. `UniSpeech <https://huggingface.co/transformers/master/model_doc/unispeech.html>`__ (from Microsoft Research)
73. `UniSpeech <https://huggingface.co/transformers/master/model_doc/unispeech.html>`__ (from Microsoft Research)
released with the paper `UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data
<https://arxiv.org/abs/2101.07597>`__ by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei,
Michael Zeng, Xuedong Huang.
73. `UniSpeechSat <https://huggingface.co/transformers/master/model_doc/unispeech_sat.html>`__ (from Microsoft
74. `UniSpeechSat <https://huggingface.co/transformers/master/model_doc/unispeech_sat.html>`__ (from Microsoft
Research) released with the paper `UNISPEECH-SAT: UNIVERSAL SPEECH REPRESENTATION LEARNING WITH SPEAKER AWARE
PRE-TRAINING <https://arxiv.org/abs/2110.05752>`__ by Sanyuan Chen, Yu Wu, Chengyi Wang, Zhengyang Chen, Zhuo Chen,
Shujie Liu, Jian Wu, Yao Qian, Furu Wei, Jinyu Li, Xiangzhan Yu.
74. :doc:`Vision Transformer (ViT) <model_doc/vit>` (from Google AI) released with the paper `An Image is Worth 16x16
75. :doc:`Vision Transformer (ViT) <model_doc/vit>` (from Google AI) released with the paper `An Image is Worth 16x16
Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`__ by Alexey Dosovitskiy,
Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias
Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
75. :doc:`VisualBERT <model_doc/visual_bert>` (from UCLA NLP) released with the paper `VisualBERT: A Simple and
76. :doc:`VisualBERT <model_doc/visual_bert>` (from UCLA NLP) released with the paper `VisualBERT: A Simple and
Performant Baseline for Vision and Language <https://arxiv.org/pdf/1908.03557>`__ by Liunian Harold Li, Mark
Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang.
76. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
77. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
Self-Supervised Learning of Speech Representations <https://arxiv.org/abs/2006.11477>`__ by Alexei Baevski, Henry
Zhou, Abdelrahman Mohamed, Michael Auli.
77. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
78. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
Pretraining <https://arxiv.org/abs/1901.07291>`__ by Guillaume Lample and Alexis Conneau.
78. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
79. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
Predicting Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan,
Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
79. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
80. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
Cross-lingual Representation Learning at Scale <https://arxiv.org/abs/1911.02116>`__ by Alexis Conneau*, Kartikay
Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke
Zettlemoyer and Veselin Stoyanov.
80. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive
81. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive
Pretraining for Language Understanding <https://arxiv.org/abs/1906.08237>`__ by Zhilin Yang*, Zihang Dai*, Yiming
Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
81. :doc:`XLSR-Wav2Vec2 <model_doc/xlsr_wav2vec2>` (from Facebook AI) released with the paper `Unsupervised
82. :doc:`XLSR-Wav2Vec2 <model_doc/xlsr_wav2vec2>` (from Facebook AI) released with the paper `Unsupervised
Cross-Lingual Representation Learning For Speech Recognition <https://arxiv.org/abs/2006.13979>`__ by Alexis
Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
......@@ -368,7 +372,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| BART | ✅ | ✅ | ✅ | ✅ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| BeiT | ❌ | ❌ | ✅ | ❌ | ✅ |
| BEiT | ❌ | ❌ | ✅ | ❌ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| BERT | ✅ | ✅ | ✅ | ✅ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
......@@ -470,6 +474,8 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| RoFormer | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| SegFormer | ❌ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| SEW | ❌ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| SEW-D | ❌ | ❌ | ✅ | ❌ | ❌ |
......@@ -654,6 +660,7 @@ Flax), PyTorch, and/or TensorFlow.
model_doc/retribert
model_doc/roberta
model_doc/roformer
model_doc/segformer
model_doc/sew
model_doc/sew_d
model_doc/speechencoderdecoder
......
..
Copyright 2021 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
SegFormer
-----------------------------------------------------------------------------------------------------------------------
Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The SegFormer model was proposed in `SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers
<https://arxiv.org/abs/2105.15203>`__ by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping
Luo. The model consists of a hierarchical Transformer encoder and a lightweight all-MLP decode head to achieve great
results on image segmentation benchmarks such as ADE20K and Cityscapes.
The abstract from the paper is the following:
*We present SegFormer, a simple, efficient yet powerful semantic segmentation framework which unifies Transformers with
lightweight multilayer perception (MLP) decoders. SegFormer has two appealing features: 1) SegFormer comprises a novel
hierarchically structured Transformer encoder which outputs multiscale features. It does not need positional encoding,
thereby avoiding the interpolation of positional codes which leads to decreased performance when the testing resolution
differs from training. 2) SegFormer avoids complex decoders. The proposed MLP decoder aggregates information from
different layers, and thus combining both local attention and global attention to render powerful representations. We
show that this simple and lightweight design is the key to efficient segmentation on Transformers. We scale our
approach up to obtain a series of models from SegFormer-B0 to SegFormer-B5, reaching significantly better performance
and efficiency than previous counterparts. For example, SegFormer-B4 achieves 50.3% mIoU on ADE20K with 64M parameters,
being 5x smaller and 2.2% better than the previous best method. Our best model, SegFormer-B5, achieves 84.0% mIoU on
Cityscapes validation set and shows excellent zero-shot robustness on Cityscapes-C.*
This model was contributed by `nielsr <https://huggingface.co/nielsr>`__. The original code can be found `here
<https://github.com/NVlabs/SegFormer>`__.
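A minimal inference sketch (illustrative only; the image path is a placeholder) showing how the feature extractor and the semantic segmentation model documented below can be combined, using the `nvidia/segformer-b0-finetuned-ade-512-512 <https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512>`__ checkpoint referenced above::

    import torch
    from PIL import Image
    from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation

    # load the checkpoint referenced in the overview
    feature_extractor = SegformerFeatureExtractor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
    model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")

    image = Image.open("example.jpg")  # placeholder path, any RGB image works
    inputs = feature_extractor(images=image, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # logits come out at 1/4 of the input resolution: (batch_size, num_labels, height / 4, width / 4)
    segmentation = outputs.logits.argmax(dim=1)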
SegformerConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.SegformerConfig
:members:
SegformerFeatureExtractor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.SegformerFeatureExtractor
:members: __call__
SegformerModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.SegformerModel
:members: forward
SegformerDecodeHead
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.SegformerDecodeHead
:members: forward
SegformerForImageClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.SegformerForImageClassification
:members: forward
SegformerForSemanticSegmentation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.SegformerForSemanticSegmentation
:members: forward
......@@ -252,6 +252,7 @@ _import_structure = {
"models.retribert": ["RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RetriBertConfig", "RetriBertTokenizer"],
"models.roberta": ["ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "RobertaConfig", "RobertaTokenizer"],
"models.roformer": ["ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "RoFormerConfig", "RoFormerTokenizer"],
"models.segformer": ["SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "SegformerConfig"],
"models.sew": ["SEW_PRETRAINED_CONFIG_ARCHIVE_MAP", "SEWConfig"],
"models.sew_d": ["SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP", "SEWDConfig"],
"models.speech_encoder_decoder": ["SpeechEncoderDecoderConfig"],
......@@ -475,6 +476,7 @@ if is_vision_available():
_import_structure["models.detr"].append("DetrFeatureExtractor")
_import_structure["models.layoutlmv2"].append("LayoutLMv2FeatureExtractor")
_import_structure["models.layoutlmv2"].append("LayoutLMv2Processor")
_import_structure["models.segformer"].append("SegformerFeatureExtractor")
_import_structure["models.vit"].append("ViTFeatureExtractor")
else:
from .utils import dummy_vision_objects
......@@ -1147,6 +1149,17 @@ if is_torch_available():
"load_tf_weights_in_roformer",
]
)
_import_structure["models.segformer"].extend(
[
"SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"SegformerDecodeHead",
"SegformerForImageClassification",
"SegformerForSemanticSegmentation",
"SegformerLayer",
"SegformerModel",
"SegformerPreTrainedModel",
]
)
_import_structure["models.sew"].extend(
[
"SEW_PRETRAINED_MODEL_ARCHIVE_LIST",
......@@ -2145,6 +2158,7 @@ if TYPE_CHECKING:
from .models.retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig, RetriBertTokenizer
from .models.roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig, RobertaTokenizer
from .models.roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig, RoFormerTokenizer
from .models.segformer import SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, SegformerConfig
from .models.sew import SEW_PRETRAINED_CONFIG_ARCHIVE_MAP, SEWConfig
from .models.sew_d import SEW_D_PRETRAINED_CONFIG_ARCHIVE_MAP, SEWDConfig
from .models.speech_encoder_decoder import SpeechEncoderDecoderConfig
......@@ -2330,6 +2344,7 @@ if TYPE_CHECKING:
from .models.deit import DeiTFeatureExtractor
from .models.detr import DetrFeatureExtractor
from .models.layoutlmv2 import LayoutLMv2FeatureExtractor, LayoutLMv2Processor
from .models.segformer import SegformerFeatureExtractor
from .models.vit import ViTFeatureExtractor
else:
from .utils.dummy_vision_objects import *
......@@ -2890,6 +2905,15 @@ if TYPE_CHECKING:
RoFormerPreTrainedModel,
load_tf_weights_in_roformer,
)
from .models.segformer import (
SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
SegformerDecodeHead,
SegformerForImageClassification,
SegformerForSemanticSegmentation,
SegformerLayer,
SegformerModel,
SegformerPreTrainedModel,
)
from .models.sew import (
SEW_PRETRAINED_MODEL_ARCHIVE_LIST,
SEWForCTC,
......
......@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
import numpy as np
import PIL.Image
......@@ -24,6 +26,10 @@ IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]
IMAGENET_STANDARD_MEAN = [0.5, 0.5, 0.5]
IMAGENET_STANDARD_STD = [0.5, 0.5, 0.5]
ImageInput = Union[
PIL.Image.Image, np.ndarray, "torch.Tensor", List[PIL.Image.Image], List[np.ndarray], List["torch.Tensor"] # noqa
]
def is_torch_tensor(obj):
return _is_torch(obj) if is_torch_available() else False
......@@ -101,7 +107,7 @@ class ImageFeatureExtractionMixin:
if rescale:
image = image.astype(np.float32) / 255.0
if channel_first:
if channel_first and image.ndim == 3:
image = image.transpose(2, 0, 1)
return image
......@@ -156,8 +162,10 @@ class ImageFeatureExtractionMixin:
"""
self._ensure_format_supported(image)
if not isinstance(size, tuple):
if isinstance(size, int):
size = (size, size)
elif isinstance(size, list):
size = tuple(size)
if not isinstance(image, PIL.Image.Image):
image = self.to_pil_image(image)
......
......@@ -81,6 +81,7 @@ from . import (
retribert,
roberta,
roformer,
segformer,
sew,
sew_d,
speech_to_text,
......
......@@ -29,6 +29,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("vision-encoder-decoder", "VisionEncoderDecoderConfig"),
("trocr", "TrOCRConfig"),
("fnet", "FNetConfig"),
("segformer", "SegformerConfig"),
("gptj", "GPTJConfig"),
("layoutlmv2", "LayoutLMv2Config"),
("beit", "BeitConfig"),
......@@ -108,6 +109,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
# Add archive maps here
("fnet", "FNET_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("pegasus", "PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("segformer", "SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("gptj", "GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
......@@ -179,8 +181,9 @@ MODEL_NAMES_MAPPING = OrderedDict(
("vision-encoder-decoder", "Vision Encoder decoder"),
("trocr", "TrOCR"),
("fnet", "FNet"),
("segformer", "SegFormer"),
("gptj", "GPT-J"),
("beit", "BeiT"),
("beit", "BEiT"),
("rembert", "RemBERT"),
("layoutlmv2", "LayoutLMv2"),
("visual_bert", "VisualBert"),
......
......@@ -29,6 +29,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
[
# Base model mapping
("fnet", "FNetModel"),
("segformer", "SegformerModel"),
("gptj", "GPTJModel"),
("layoutlmv2", "LayoutLMv2Model"),
("beit", "BeitModel"),
......@@ -232,6 +233,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("vit", "ViTForImageClassification"),
("deit", ("DeiTForImageClassification", "DeiTForImageClassificationWithTeacher")),
("beit", "BeitForImageClassification"),
("segformer", "SegformerForImageClassification"),
]
)
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import _LazyModule, is_torch_available, is_vision_available
_import_structure = {
"configuration_segformer": ["SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "SegformerConfig"],
}
if is_vision_available():
_import_structure["feature_extraction_segformer"] = ["SegformerFeatureExtractor"]
if is_torch_available():
_import_structure["modeling_segformer"] = [
"SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"SegformerDecodeHead",
"SegformerForImageClassification",
"SegformerForSemanticSegmentation",
"SegformerLayer",
"SegformerModel",
"SegformerPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_segformer import SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, SegformerConfig
if is_vision_available():
from .feature_extraction_segformer import SegformerFeatureExtractor
if is_torch_available():
from .modeling_segformer import (
SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
SegformerDecodeHead,
SegformerForImageClassification,
SegformerForSemanticSegmentation,
SegformerLayer,
SegformerModel,
SegformerPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
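The lazy import structure above only registers names; whether a given class is actually usable depends on which optional backends are installed, which is what the dummy objects further down in this diff enforce. A rough illustrative sketch (not part of the PR itself):
from transformers import SegformerConfig                   # always importable, no extra backend required
from transformers import SegformerFeatureExtractor         # requires the vision backend (Pillow)
from transformers import SegformerForSemanticSegmentation  # requires the torch backend
# e.g. an ADE20K-sized label space, matching the num_labels the conversion script sets for "ade" checkpoints
config = SegformerConfig(num_labels=150)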
# coding=utf-8
# Copyright 2021 NVIDIA and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" SegFormer model configuration """
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"nvidia/segformer-b0-finetuned-ade-512-512": "https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512/resolve/main/config.json",
# See all SegFormer models at https://huggingface.co/models?filter=segformer
}
class SegformerConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a :class:`~transformers.SegformerModel`. It is used
to instantiate a SegFormer model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of the SegFormer
`nvidia/segformer-b0-finetuned-ade-512-512 <https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512>`__
architecture.
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
Args:
image_size (:obj:`int`, `optional`, defaults to 224):
The size (resolution) of each image.
num_channels (:obj:`int`, `optional`, defaults to 3):
The number of input channels.
num_encoder_blocks (:obj:`int`, `optional`, defaults to 4):
The number of encoder blocks (i.e. stages in the Mix Transformer encoder).
depths (:obj:`List[int]`, `optional`, defaults to [2, 2, 2, 2]):
The number of layers in each encoder block.
sr_ratios (:obj:`List[int]`, `optional`, defaults to [8, 4, 2, 1]):
Sequence reduction ratios in each encoder block.
hidden_sizes (:obj:`List[int]`, `optional`, defaults to [32, 64, 160, 256]):
Dimension of each of the encoder blocks.
downsampling_rates (:obj:`List[int]`, `optional`, defaults to [1, 4, 8, 16]):
Downsample rate of the image resolution compared to the original image size before each encoder block.
patch_sizes (:obj:`List[int]`, `optional`, defaults to [7, 3, 3, 3]):
Patch size before each encoder block.
strides (:obj:`List[int]`, `optional`, defaults to [4, 2, 2, 2]):
Stride before each encoder block.
num_attention_heads (:obj:`List[int]`, `optional`, defaults to [1, 2, 5, 8]):
Number of attention heads for each attention layer in each block of the Transformer encoder.
mlp_ratios (:obj:`List[int]`, `optional`, defaults to [4, 4, 4, 4]):
Ratio of the size of the hidden layer compared to the size of the input layer of the Mix FFNs in the
encoder blocks.
hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"selu"` and :obj:`"gelu_new"` are supported.
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.0):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.0):
The dropout ratio for the attention probabilities.
classifier_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability before the classification head.
initializer_range (:obj:`float`, `optional`, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
drop_path_rate (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for stochastic depth, used in the blocks of the Transformer encoder.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-6):
The epsilon used by the layer normalization layers.
decoder_hidden_size (:obj:`int`, `optional`, defaults to 256):
The dimension of the all-MLP decode head.
reshape_last_stage (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to reshape the features of the last stage back to :obj:`(batch_size, num_channels, height, width)`.
Only required for the semantic segmentation model.
Example::
>>> from transformers import SegformerModel, SegformerConfig
>>> # Initializing a SegFormer nvidia/segformer-b0-finetuned-ade-512-512 style configuration
>>> configuration = SegformerConfig()
>>> # Initializing a model from the nvidia/segformer-b0-finetuned-ade-512-512 style configuration
>>> model = SegformerModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
"""
model_type = "segformer"
def __init__(
self,
image_size=224,
num_channels=3,
num_encoder_blocks=4,
depths=[2, 2, 2, 2],
sr_ratios=[8, 4, 2, 1],
hidden_sizes=[32, 64, 160, 256],
downsampling_rates=[1, 4, 8, 16],
patch_sizes=[7, 3, 3, 3],
strides=[4, 2, 2, 2],
num_attention_heads=[1, 2, 5, 8],
mlp_ratios=[4, 4, 4, 4],
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
classifier_dropout_prob=0.1,
initializer_range=0.02,
drop_path_rate=0.1,
layer_norm_eps=1e-6,
decoder_hidden_size=256,
is_encoder_decoder=False,
reshape_last_stage=True,
**kwargs
):
super().__init__(**kwargs)
self.image_size = image_size
self.num_channels = num_channels
self.num_encoder_blocks = num_encoder_blocks
self.depths = depths
self.sr_ratios = sr_ratios
self.hidden_sizes = hidden_sizes
self.downsampling_rates = downsampling_rates
self.patch_sizes = patch_sizes
self.strides = strides
self.mlp_ratios = mlp_ratios
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.classifier_dropout_prob = classifier_dropout_prob
self.initializer_range = initializer_range
self.drop_path_rate = drop_path_rate
self.layer_norm_eps = layer_norm_eps
self.decoder_hidden_size = decoder_hidden_size
self.reshape_last_stage = reshape_last_stage
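As an illustrative bridge to the conversion script that follows, a configuration for one of the larger encoder variants can be sketched by overriding the defaults above with the values that script assigns for its "b2" checkpoints; this is a hypothetical snippet, not an official preset:
from transformers import SegformerConfig

# values mirror the `size == "b2"` branch of the conversion script below
b2_like_config = SegformerConfig(
    hidden_sizes=[64, 128, 320, 512],
    decoder_hidden_size=768,
    depths=[3, 4, 6, 3],
)
print(b2_like_config.num_attention_heads)  # [1, 2, 5, 8], inherited from the defaults above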
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert SegFormer checkpoints."""
import argparse
import json
from collections import OrderedDict
from pathlib import Path
import torch
from PIL import Image
import requests
from huggingface_hub import cached_download, hf_hub_url
from transformers import (
SegformerConfig,
SegformerFeatureExtractor,
SegformerForImageClassification,
SegformerForSemanticSegmentation,
)
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def rename_keys(state_dict, encoder_only=False):
new_state_dict = OrderedDict()
for key, value in state_dict.items():
if encoder_only and not key.startswith("head"):
key = "segformer.encoder." + key
if key.startswith("backbone"):
key = key.replace("backbone", "segformer.encoder")
if "patch_embed" in key:
# replace for example patch_embed1 by patch_embeddings.0
idx = key[key.find("patch_embed") + len("patch_embed")]
key = key.replace(f"patch_embed{idx}", f"patch_embeddings.{int(idx)-1}")
if "norm" in key:
key = key.replace("norm", "layer_norm")
if "segformer.encoder.layer_norm" in key:
# replace for example layer_norm1 by layer_norm.0
idx = key[key.find("segformer.encoder.layer_norm") + len("segformer.encoder.layer_norm")]
key = key.replace(f"layer_norm{idx}", f"layer_norm.{int(idx)-1}")
if "layer_norm1" in key:
key = key.replace("layer_norm1", "layer_norm_1")
if "layer_norm2" in key:
key = key.replace("layer_norm2", "layer_norm_2")
if "block" in key:
# replace for example block1 by block.0
idx = key[key.find("block") + len("block")]
key = key.replace(f"block{idx}", f"block.{int(idx)-1}")
if "attn.q" in key:
key = key.replace("attn.q", "attention.self.query")
if "attn.proj" in key:
key = key.replace("attn.proj", "attention.output.dense")
if "attn" in key:
key = key.replace("attn", "attention.self")
if "fc1" in key:
key = key.replace("fc1", "dense1")
if "fc2" in key:
key = key.replace("fc2", "dense2")
if "linear_pred" in key:
key = key.replace("linear_pred", "classifier")
if "linear_fuse" in key:
key = key.replace("linear_fuse.conv", "linear_fuse")
key = key.replace("linear_fuse.bn", "batch_norm")
if "linear_c" in key:
# replace for example linear_c4 by linear_c.3
idx = key[key.find("linear_c") + len("linear_c")]
key = key.replace(f"linear_c{idx}", f"linear_c.{int(idx)-1}")
if key.startswith("head"):
key = key.replace("head", "classifier")
new_state_dict[key] = value
return new_state_dict
def read_in_k_v(state_dict, config):
# for each of the encoder blocks:
for i in range(config.num_encoder_blocks):
for j in range(config.depths[i]):
# read in weights + bias of keys and values (which is a single matrix in the original implementation)
kv_weight = state_dict.pop(f"segformer.encoder.block.{i}.{j}.attention.self.kv.weight")
kv_bias = state_dict.pop(f"segformer.encoder.block.{i}.{j}.attention.self.kv.bias")
# next, add keys and values (in that order) to the state dict
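# (illustration: for the b0 encoder hidden_sizes[i] == 32, so kv_weight has shape (64, 32);
# rows [:32] become the key projection, rows [32:] the value projection, and the (64,) bias splits the same way)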
state_dict[f"segformer.encoder.block.{i}.{j}.attention.self.key.weight"] = kv_weight[
: config.hidden_sizes[i], :
]
state_dict[f"segformer.encoder.block.{i}.{j}.attention.self.key.bias"] = kv_bias[: config.hidden_sizes[i]]
state_dict[f"segformer.encoder.block.{i}.{j}.attention.self.value.weight"] = kv_weight[
config.hidden_sizes[i] :, :
]
state_dict[f"segformer.encoder.block.{i}.{j}.attention.self.value.bias"] = kv_bias[
config.hidden_sizes[i] :
]
# We will verify our results on a COCO image
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
return image
@torch.no_grad()
def convert_segformer_checkpoint(model_name, checkpoint_path, pytorch_dump_folder_path):
"""
Copy/paste/tweak model's weights to our SegFormer structure.
"""
# load default SegFormer configuration
config = SegformerConfig()
encoder_only = False
# set attributes based on model_name
repo_id = "datasets/huggingface/label-files"
if "segformer" in model_name:
size = model_name[len("segformer.") : len("segformer.") + 2]
if "ade" in model_name:
config.num_labels = 150
filename = "ade20k-id2label.json"
expected_shape = (1, 150, 128, 128)
elif "city" in model_name:
config.num_labels = 19
filename = "cityscapes-id2label.json"
expected_shape = (1, 19, 128, 128)
else:
raise ValueError(f"Model {model_name} not supported")
elif "mit" in model_name:
encoder_only = True
size = model_name[4:6]
config.num_labels = 1000
filename = "imagenet-1k-id2label.json"
expected_shape = (1, 1000)
else:
raise ValueError(f"Model {model_name} not supported")
# set config attributes
id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
id2label = {int(k): v for k, v in id2label.items()}
config.id2label = id2label
config.label2id = {v: k for k, v in id2label.items()}
if size == "b0":
pass
elif size == "b1":
config.hidden_sizes = [64, 128, 320, 512]
config.decoder_hidden_size = 256
elif size == "b2":
config.hidden_sizes = [64, 128, 320, 512]
config.decoder_hidden_size = 768
config.depths = [3, 4, 6, 3]
elif size == "b3":
config.hidden_sizes = [64, 128, 320, 512]
config.decoder_hidden_size = 768
config.depths = [3, 4, 18, 3]
elif size == "b4":
config.hidden_sizes = [64, 128, 320, 512]
config.decoder_hidden_size = 768
config.depths = [3, 8, 27, 3]
elif size == "b5":
config.hidden_sizes = [64, 128, 320, 512]
config.decoder_hidden_size = 768
config.depths = [3, 6, 40, 3]
else:
raise ValueError(f"Size {size} not supported")
# load feature extractor (only resize + normalize)
feature_extractor = SegformerFeatureExtractor(
image_scale=(512, 512), keep_ratio=False, align=False, do_random_crop=False
)
# prepare image
image = prepare_img()
pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
logger.info(f"Converting model {model_name}...")
# load original state dict
if encoder_only:
state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))
else:
state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))["state_dict"]
# rename keys
state_dict = rename_keys(state_dict, encoder_only=encoder_only)
if not encoder_only:
del state_dict["decode_head.conv_seg.weight"]
del state_dict["decode_head.conv_seg.bias"]
# key and value matrices need special treatment
read_in_k_v(state_dict, config)
# create HuggingFace model and load state dict
if encoder_only:
config.reshape_last_stage = False
model = SegformerForImageClassification(config)
else:
model = SegformerForSemanticSegmentation(config)
model.load_state_dict(state_dict)
model.eval()
# forward pass
outputs = model(pixel_values)
logits = outputs.logits
# set expected_slice based on model name
# ADE20k checkpoints
if model_name == "segformer.b0.512x512.ade.160k":
expected_slice = torch.tensor(
[
[[-4.6310, -5.5232, -6.2356], [-5.1921, -6.1444, -6.5996], [-5.4424, -6.2790, -6.7574]],
[[-12.1391, -13.3122, -13.9554], [-12.8732, -13.9352, -14.3563], [-12.9438, -13.8226, -14.2513]],
[[-12.5134, -13.4686, -14.4915], [-12.8669, -14.4343, -14.7758], [-13.2523, -14.5819, -15.0694]],
]
)
elif model_name == "segformer.b1.512x512.ade.160k":
expected_slice = torch.tensor(
[
[[-7.5820, -8.7231, -8.3215], [-8.0600, -10.3529, -10.0304], [-7.5208, -9.4103, -9.6239]],
[[-12.6918, -13.8994, -13.7137], [-13.3196, -15.7523, -15.4789], [-12.9343, -14.8757, -14.9689]],
[[-11.1911, -11.9421, -11.3243], [-11.3342, -13.6839, -13.3581], [-10.3909, -12.1832, -12.4858]],
]
)
elif model_name == "segformer.b2.512x512.ade.160k":
expected_slice = torch.tensor(
[
[[-11.8173, -14.3850, -16.3128], [-14.5648, -16.5804, -18.6568], [-14.7223, -15.7387, -18.4218]],
[[-15.7290, -17.9171, -19.4423], [-18.3105, -19.9448, -21.4661], [-17.9296, -18.6497, -20.7910]],
[[-15.0783, -17.0336, -18.2789], [-16.8771, -18.6870, -20.1612], [-16.2454, -17.1426, -19.5055]],
]
)
elif model_name == "segformer.b3.512x512.ade.160k":
expected_slice = torch.tensor(
[
[[-9.0878, -10.2081, -10.1891], [-9.3144, -10.7941, -10.9843], [-9.2294, -10.3855, -10.5704]],
[[-12.2316, -13.9068, -13.6102], [-12.9161, -14.3702, -14.3235], [-12.5233, -13.7174, -13.7932]],
[[-14.6275, -15.2490, -14.9727], [-14.3400, -15.9687, -16.2827], [-14.1484, -15.4033, -15.8937]],
]
)
elif model_name == "segformer.b4.512x512.ade.160k":
expected_slice = torch.tensor(
[
[[-12.3144, -13.2447, -14.0802], [-13.3614, -14.5816, -15.6117], [-13.3340, -14.4433, -16.2219]],
[[-19.2781, -20.4128, -20.7506], [-20.6153, -21.6566, -22.0998], [-19.9800, -21.0430, -22.1494]],
[[-18.8739, -19.7804, -21.1834], [-20.1233, -21.6765, -23.2944], [-20.0315, -21.2641, -23.6944]],
]
)
elif model_name == "segformer.b5.640x640.ade.160k":
expected_slice = torch.tensor(
[
[[-9.5524, -12.0835, -11.7348], [-10.5229, -13.6446, -14.5662], [-9.5842, -12.8851, -13.9414]],
[[-15.3432, -17.5323, -17.0818], [-16.3330, -18.9255, -19.2101], [-15.1340, -17.7848, -18.3971]],
[[-12.6072, -14.9486, -14.6631], [-13.7629, -17.0907, -17.7745], [-12.7899, -16.1695, -17.1671]],
]
)
# Cityscapes checkpoints
elif model_name == "segformer.b0.1024x1024.city.160k":
expected_slice = torch.tensor(
[
[[-11.9295, -13.4057, -14.8106], [-13.3431, -14.8179, -15.3781], [-14.2836, -15.5942, -16.1588]],
[[-11.4906, -12.8067, -13.6564], [-13.1189, -14.0500, -14.1543], [-13.8748, -14.5136, -14.8789]],
[[0.5374, 0.1067, -0.4742], [0.1141, -0.2255, -0.7099], [-0.3000, -0.5924, -1.3105]],
]
)
elif model_name == "segformer.b0.512x1024.city.160k":
expected_slice = torch.tensor(
[
[[-7.8217, -9.8767, -10.1717], [-9.4438, -10.9058, -11.4047], [-9.7939, -12.3495, -12.1079]],
[[-7.1514, -9.5336, -10.0860], [-9.7776, -11.6822, -11.8439], [-10.1411, -12.7655, -12.8972]],
[[0.3021, 0.0805, -0.2310], [-0.0328, -0.1605, -0.2714], [-0.1408, -0.5477, -0.6976]],
]
)
elif model_name == "segformer.b0.640x1280.city.160k":
expected_slice = torch.tensor(
[
[
[-1.1372e01, -1.2787e01, -1.3477e01],
[-1.2536e01, -1.4194e01, -1.4409e01],
[-1.3217e01, -1.4888e01, -1.5327e01],
],
[
[-1.4791e01, -1.7122e01, -1.8277e01],
[-1.7163e01, -1.9192e01, -1.9533e01],
[-1.7897e01, -1.9991e01, -2.0315e01],
],
[
[7.6723e-01, 4.1921e-01, -7.7878e-02],
[4.7772e-01, 9.5557e-03, -2.8082e-01],
[3.6032e-01, -2.4826e-01, -5.1168e-01],
],
]
)
elif model_name == "segformer.b0.768x768.city.160k":
expected_slice = torch.tensor(
[
[[-9.4959, -11.3087, -11.7479], [-11.0025, -12.6540, -12.3319], [-11.4064, -13.0487, -12.9905]],
[[-9.8905, -11.3084, -12.0854], [-11.1726, -12.7698, -12.9583], [-11.5985, -13.3278, -14.1774]],
[[0.2213, 0.0192, -0.2466], [-0.1731, -0.4213, -0.4874], [-0.3126, -0.6541, -1.1389]],
]
)
elif model_name == "segformer.b1.1024x1024.city.160k":
expected_slice = torch.tensor(
[
[[-13.5748, -13.9111, -12.6500], [-14.3500, -15.3683, -14.2328], [-14.7532, -16.0424, -15.6087]],
[[-17.1651, -15.8725, -12.9653], [-17.2580, -17.3718, -14.8223], [-16.6058, -16.8783, -16.7452]],
[[-3.6456, -3.0209, -1.4203], [-3.0797, -3.1959, -2.0000], [-1.8757, -1.9217, -1.6997]],
]
)
elif model_name == "segformer.b2.1024x1024.city.160k":
expected_slice = torch.tensor(
[
[[-16.0976, -16.4856, -17.3962], [-16.6234, -19.0342, -19.7685], [-16.0900, -18.0661, -19.1180]],
[[-18.4750, -18.8488, -19.5074], [-19.4030, -22.1570, -22.5977], [-19.1191, -20.8486, -22.3783]],
[[-4.5178, -5.5037, -6.5109], [-5.0884, -7.2174, -8.0334], [-4.4156, -5.8117, -7.2970]],
]
)
elif model_name == "segformer.b3.1024x1024.city.160k":
expected_slice = torch.tensor(
[
[[-14.2081, -14.4732, -14.1977], [-14.5867, -16.4423, -16.6356], [-13.4441, -14.9685, -16.8696]],
[[-14.4576, -14.7073, -15.0451], [-15.0816, -17.6237, -17.9873], [-14.4213, -16.0199, -18.5992]],
[[-4.7349, -4.9588, -5.0966], [-4.3210, -6.9325, -7.2591], [-3.4312, -4.7484, -7.1917]],
]
)
elif model_name == "segformer.b4.1024x1024.city.160k":
expected_slice = torch.tensor(
[
[[-11.7737, -11.9526, -11.3273], [-13.6692, -14.4574, -13.8878], [-13.8937, -14.6924, -15.9345]],
[[-14.6706, -14.5330, -14.1306], [-16.1502, -16.8180, -16.4269], [-16.8338, -17.8939, -20.1746]],
[[1.0491, 0.8289, 1.0310], [1.1044, 0.5219, 0.8055], [1.0899, 0.6926, 0.5590]],
]
)
elif model_name == "segformer.b5.1024x1024.city.160k":
expected_slice = torch.tensor(
[
[[-12.5641, -13.4777, -13.0684], [-13.9587, -15.8983, -16.6557], [-13.3109, -15.7350, -16.3141]],
[[-14.7074, -15.4352, -14.5944], [-16.6353, -18.1663, -18.6120], [-15.1702, -18.0329, -18.1547]],
[[-1.7990, -2.0951, -1.7784], [-2.6397, -3.8245, -3.9686], [-1.5264, -2.8126, -2.9316]],
]
)
else:
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
# verify logits
if not encoder_only:
assert logits.shape == expected_shape
assert torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-2)
# finally, save model and feature extractor
logger.info(f"Saving PyTorch model and feature extractor to {pytorch_dump_folder_path}...")
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
model.save_pretrained(pytorch_dump_folder_path)
feature_extractor.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name",
default="segformer.b0.512x512.ade.160k",
type=str,
help="Name of the model you'd like to convert.",
)
parser.add_argument(
"--checkpoint_path", default=None, type=str, help="Path to the original PyTorch checkpoint (.pth file)."
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
)
args = parser.parse_args()
convert_segformer_checkpoint(args.model_name, args.checkpoint_path, args.pytorch_dump_folder_path)
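For orientation only, the entry point above amounts to a call like the following; the checkpoint and output paths are placeholders, and in practice the script is run from the command line with the three arguments defined above:
convert_segformer_checkpoint(
    model_name="segformer.b0.512x512.ade.160k",            # the argparse default
    checkpoint_path="/path/to/original/segformer_b0.pth",  # placeholder: original NVIDIA .pth checkpoint
    pytorch_dump_folder_path="/path/to/hf_output",         # placeholder: where the HF model and feature extractor are saved
)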
This diff is collapsed.
......@@ -3281,6 +3281,47 @@ def load_tf_weights_in_roformer(*args, **kwargs):
requires_backends(load_tf_weights_in_roformer, ["torch"])
SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
class SegformerDecodeHead:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class SegformerForImageClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class SegformerForSemanticSegmentation:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class SegformerLayer:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class SegformerModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class SegformerPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
SEW_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
......@@ -50,6 +50,11 @@ class LayoutLMv2Processor:
requires_backends(cls, ["vision"])
class SegformerFeatureExtractor:
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
class ViTFeatureExtractor:
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
# coding=utf-8
# Copyright 2021 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers.file_utils import is_torch_available, is_vision_available
from transformers.testing_utils import require_torch, require_vision
from .test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
if is_torch_available():
import torch
if is_vision_available():
from PIL import Image
from transformers import SegformerFeatureExtractor
class SegformerFeatureExtractionTester(unittest.TestCase):
def __init__(
self,
parent,
batch_size=7,
num_channels=3,
min_resolution=30,
max_resolution=400,
do_resize=True,
keep_ratio=True,
image_scale=[100, 20],
align=True,
size_divisor=10,
do_random_crop=True,
crop_size=[20, 20],
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
do_pad=True,
):
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.min_resolution = min_resolution
self.max_resolution = max_resolution
self.do_resize = do_resize
self.keep_ratio = keep_ratio
self.image_scale = image_scale
self.align = align
self.size_divisor = size_divisor
self.do_random_crop = do_random_crop
self.crop_size = crop_size
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.do_pad = do_pad
def prepare_feat_extract_dict(self):
return {
"do_resize": self.do_resize,
"keep_ratio": self.keep_ratio,
"image_scale": self.image_scale,
"align": self.align,
"size_divisor": self.size_divisor,
"do_random_crop": self.do_random_crop,
"crop_size": self.crop_size,
"do_normalize": self.do_normalize,
"image_mean": self.image_mean,
"image_std": self.image_std,
"do_pad": self.do_pad,
}
@require_torch
@require_vision
class SegformerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
feature_extraction_class = SegformerFeatureExtractor if is_vision_available() else None
def setUp(self):
self.feature_extract_tester = SegformerFeatureExtractionTester(self)
@property
def feat_extract_dict(self):
return self.feature_extract_tester.prepare_feat_extract_dict()
def test_feat_extract_properties(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
self.assertTrue(hasattr(feature_extractor, "do_resize"))
self.assertTrue(hasattr(feature_extractor, "keep_ratio"))
self.assertTrue(hasattr(feature_extractor, "image_scale"))
self.assertTrue(hasattr(feature_extractor, "align"))
self.assertTrue(hasattr(feature_extractor, "size_divisor"))
self.assertTrue(hasattr(feature_extractor, "do_random_crop"))
self.assertTrue(hasattr(feature_extractor, "crop_size"))
self.assertTrue(hasattr(feature_extractor, "do_normalize"))
self.assertTrue(hasattr(feature_extractor, "image_mean"))
self.assertTrue(hasattr(feature_extractor, "image_std"))
self.assertTrue(hasattr(feature_extractor, "do_pad"))
def test_batch_feature(self):
pass
def test_call_pil(self):
# Initialize feature_extractor
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
# create random PIL images
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
for image in image_inputs:
self.assertIsInstance(image, Image.Image)
# Test not batched input
encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
self.assertEqual(
encoded_images.shape,
(
1,
self.feature_extract_tester.num_channels,
*self.feature_extract_tester.crop_size[::-1],
),
)
# Test batched
encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
self.assertEqual(
encoded_images.shape,
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
*self.feature_extract_tester.crop_size[::-1],
),
)
def test_call_numpy(self):
# Initialize feature_extractor
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
# create random numpy tensors
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, numpify=True)
for image in image_inputs:
self.assertIsInstance(image, np.ndarray)
# Test not batched input
encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
self.assertEqual(
encoded_images.shape,
(
1,
self.feature_extract_tester.num_channels,
*self.feature_extract_tester.crop_size[::-1],
),
)
# Test batched
encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
self.assertEqual(
encoded_images.shape,
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
*self.feature_extract_tester.crop_size[::-1],
),
)
def test_call_pytorch(self):
# Initialize feature_extractor
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
# create random PyTorch tensors
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)
# Test not batched input
encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
self.assertEqual(
encoded_images.shape,
(
1,
self.feature_extract_tester.num_channels,
*self.feature_extract_tester.crop_size[::-1],
),
)
# Test batched
encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
self.assertEqual(
encoded_images.shape,
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
*self.feature_extract_tester.crop_size[::-1],
),
)
def test_resize(self):
# Initialize feature_extractor: version 1 (no align, keep_ratio=True)
feature_extractor = SegformerFeatureExtractor(
image_scale=(1333, 800), align=False, do_random_crop=False, do_pad=False
)
# Create random PyTorch tensor
image = torch.randn((3, 288, 512))
# Verify shape
encoded_images = feature_extractor(image, return_tensors="pt").pixel_values
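# with keep_ratio=True the rescale factor respects both bounds: min(1333/512, 800/288) ≈ 2.60, so (288, 512) -> (750, 1333)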
expected_shape = (1, 3, 750, 1333)
self.assertEqual(encoded_images.shape, expected_shape)
# Initialize feature_extractor: version 2 (keep_ratio=False)
feature_extractor = SegformerFeatureExtractor(
image_scale=(1280, 800), align=False, keep_ratio=False, do_random_crop=False, do_pad=False
)
# Verify shape
encoded_images = feature_extractor(image, return_tensors="pt").pixel_values
expected_shape = (1, 3, 800, 1280)
self.assertEqual(encoded_images.shape, expected_shape)
def test_aligned_resize(self):
# Initialize feature_extractor: version 1
feature_extractor = SegformerFeatureExtractor(do_random_crop=False, do_pad=False)
# Create random PyTorch tensor
image = torch.randn((3, 256, 304))
# Verify shape
encoded_images = feature_extractor(image, return_tensors="pt").pixel_values
expected_shape = (1, 3, 512, 608)
self.assertEqual(encoded_images.shape, expected_shape)
# Initialize feature_extractor: version 2
feature_extractor = SegformerFeatureExtractor(image_scale=(1024, 2048), do_random_crop=False, do_pad=False)
# create random PyTorch tensor
image = torch.randn((3, 1024, 2048))
# Verify shape
encoded_images = feature_extractor(image, return_tensors="pt").pixel_values
expected_shape = (1, 3, 1024, 2048)
self.assertEqual(encoded_images.shape, expected_shape)
def test_random_crop(self):
from datasets import load_dataset
ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
image = Image.open(ds[0]["file"])
segmentation_map = Image.open(ds[1]["file"])
w, h = image.size
# Initialize feature_extractor
feature_extractor = SegformerFeatureExtractor(crop_size=[w - 20, h - 20], do_pad=False)
# Encode image + segmentation map
encoded_images = feature_extractor(images=image, segmentation_maps=segmentation_map, return_tensors="pt")
# Verify shape of pixel_values
self.assertEqual(encoded_images.pixel_values.shape[-2:], (h - 20, w - 20))
# Verify shape of labels
self.assertEqual(encoded_images.labels.shape[-2:], (h - 20, w - 20))
def test_pad(self):
# Initialize feature_extractor (note that padding should only be applied when random cropping)
feature_extractor = SegformerFeatureExtractor(
align=False, do_random_crop=True, crop_size=self.feature_extract_tester.crop_size, do_pad=True
)
# create random PyTorch tensors
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)
# Test not batched input
encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
self.assertEqual(
encoded_images.shape,
(
1,
self.feature_extract_tester.num_channels,
*self.feature_extract_tester.crop_size[::-1],
),
)
# Test batched
encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
self.assertEqual(
encoded_images.shape,
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
*self.feature_extract_tester.crop_size[::-1],
),
)
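For reference, a minimal usage sketch of the feature extractor exercised above, using the same resize-only settings as the integration tests further down; the file names are placeholders and the expected shapes follow from the tests above rather than from additional documentation.
from PIL import Image
from transformers import SegformerFeatureExtractor
# Same settings as the integration tests: resize to a fixed 512x512 (no aspect ratio, no cropping) and normalize.
feature_extractor = SegformerFeatureExtractor(
    image_scale=(512, 512), keep_ratio=False, align=False, do_random_crop=False
)
image = Image.open("image.png")  # placeholder path
segmentation_map = Image.open("segmentation_map.png")  # placeholder path
encoded = feature_extractor(images=image, segmentation_maps=segmentation_map, return_tensors="pt")
# Expected per the tests above: pixel_values of shape (1, 3, 512, 512) and labels of shape (1, 512, 512).
print(encoded.pixel_values.shape, encoded.labels.shape)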
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch SegFormer model. """
import inspect
import unittest
from transformers import is_torch_available, is_vision_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
import torch
from transformers import (
MODEL_MAPPING,
SegformerConfig,
SegformerForImageClassification,
SegformerForSemanticSegmentation,
SegformerModel,
)
from transformers.models.segformer.modeling_segformer import SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available():
from PIL import Image
from transformers import SegformerFeatureExtractor
class SegformerConfigTester(ConfigTester):
def create_and_test_config_common_properties(self):
config = self.config_class(**self.inputs_dict)
self.parent.assertTrue(hasattr(config, "hidden_sizes"))
self.parent.assertTrue(hasattr(config, "num_attention_heads"))
self.parent.assertTrue(hasattr(config, "num_encoder_blocks"))
class SegformerModelTester:
def __init__(
self,
parent,
batch_size=13,
image_size=64,
num_channels=3,
num_encoder_blocks=4,
depths=[2, 2, 2, 2],
sr_ratios=[8, 4, 2, 1],
hidden_sizes=[16, 32, 64, 128],
downsampling_rates=[1, 4, 8, 16],
num_attention_heads=[1, 2, 4, 8],
is_training=True,
use_labels=True,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
initializer_range=0.02,
num_labels=3,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.num_channels = num_channels
self.num_encoder_blocks = num_encoder_blocks
self.sr_ratios = sr_ratios
self.depths = depths
self.hidden_sizes = hidden_sizes
self.downsampling_rates = downsampling_rates
self.num_attention_heads = num_attention_heads
self.is_training = is_training
self.use_labels = use_labels
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.initializer_range = initializer_range
self.num_labels = num_labels
self.scope = scope
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
labels = None
if self.use_labels:
labels = ids_tensor([self.batch_size, self.image_size, self.image_size], self.num_labels)
config = SegformerConfig(
image_size=self.image_size,
num_channels=self.num_channels,
num_encoder_blocks=self.num_encoder_blocks,
depths=self.depths,
hidden_sizes=self.hidden_sizes,
num_attention_heads=self.num_attention_heads,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
initializer_range=self.initializer_range,
)
return config, pixel_values, labels
def create_and_check_model(self, config, pixel_values, labels):
model = SegformerModel(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
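# the deepest encoder stage downsamples the input 32x overall (patch embedding strides 4, 2, 2, 2), hence downsampling_rates[-1] * 2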
expected_height = expected_width = self.image_size // (self.downsampling_rates[-1] * 2)
self.parent.assertEqual(
result.last_hidden_state.shape, (self.batch_size, self.hidden_sizes[-1], expected_height, expected_width)
)
def create_and_check_for_image_segmentation(self, config, pixel_values, labels):
config.num_labels = self.num_labels
model = SegformerForSemanticSegmentation(config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
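# the all-MLP decode head fuses the multi-scale features and predicts logits at 1/4 of the input resolution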
self.parent.assertEqual(
result.logits.shape, (self.batch_size, self.num_labels, self.image_size // 4, self.image_size // 4)
)
result = model(pixel_values, labels=labels)
self.parent.assertEqual(
result.logits.shape, (self.batch_size, self.num_labels, self.image_size // 4, self.image_size // 4)
)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_torch
class SegformerModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (
(
SegformerModel,
SegformerForSemanticSegmentation,
SegformerForImageClassification,
)
if is_torch_available()
else ()
)
test_head_masking = False
test_pruning = False
test_resize_embeddings = False
test_torchscript = False
def setUp(self):
self.model_tester = SegformerModelTester(self)
self.config_tester = SegformerConfigTester(self, config_class=SegformerConfig)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_for_image_segmentation(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_segmentation(*config_and_inputs)
@unittest.skip("SegFormer does not use inputs_embeds")
def test_inputs_embeds(self):
pass
@unittest.skip("SegFormer does not have get_input_embeddings method and get_output_embeddings methods")
def test_model_common_attributes(self):
pass
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
expected_num_attentions = sum(self.model_tester.depths)
self.assertEqual(len(attentions), expected_num_attentions)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
self.assertEqual(len(attentions), expected_num_attentions)
# verify the first attentions (first block, first layer)
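# efficient self-attention downsamples keys/values spatially by sr_ratio, so attention maps have shape (seq_len, seq_len / sr_ratio**2)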
expected_seq_len = (self.model_tester.image_size // 4) ** 2
expected_reduced_seq_len = (self.model_tester.image_size // (4 * self.model_tester.sr_ratios[0])) ** 2
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads[0], expected_seq_len, expected_reduced_seq_len],
)
# verify the last attentions (last block, last layer)
expected_seq_len = (self.model_tester.image_size // 32) ** 2
expected_reduced_seq_len = (self.model_tester.image_size // (32 * self.model_tester.sr_ratios[-1])) ** 2
self.assertListEqual(
list(attentions[-1].shape[-3:]),
[self.model_tester.num_attention_heads[-1], expected_seq_len, expected_reduced_seq_len],
)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
self.assertEqual(3, len(outputs))
self_attentions = outputs.attentions
self.assertEqual(len(self_attentions), expected_num_attentions)
# verify the first attentions (first block, first layer)
expected_seq_len = (self.model_tester.image_size // 4) ** 2
expected_reduced_seq_len = (self.model_tester.image_size // (4 * self.model_tester.sr_ratios[0])) ** 2
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads[0], expected_seq_len, expected_reduced_seq_len],
)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.hidden_states
expected_num_layers = self.model_tester.num_encoder_blocks
self.assertEqual(len(hidden_states), expected_num_layers)
# verify the first hidden states (first block)
self.assertListEqual(
list(hidden_states[0].shape[-3:]),
[
self.model_tester.hidden_sizes[0],
self.model_tester.image_size // 4,
self.model_tester.image_size // 4,
],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
def test_training(self):
if not self.model_tester.is_training:
return
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
for model_class in self.all_model_classes:
if model_class in get_values(MODEL_MAPPING):
continue
# TODO: remove the following 3 lines once we have a MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
# this can then be incorporated into _prepare_for_class in test_modeling_common.py
if model_class.__name__ == "SegformerForSemanticSegmentation":
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
inputs_dict["labels"] = torch.zeros([self.model_tester.batch_size, height, width]).long()
model = model_class(config)
model.to(torch_device)
model.train()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
loss = model(**inputs).loss
loss.backward()
@slow
def test_model_from_pretrained(self):
for model_name in SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = SegformerModel.from_pretrained(model_name)
self.assertIsNotNone(model)
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_torch
class SegformerModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_image_segmentation_ade(self):
# only resize + normalize
feature_extractor = SegformerFeatureExtractor(
image_scale=(512, 512), keep_ratio=False, align=False, do_random_crop=False
)
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512").to(
torch_device
)
image = prepare_img()
encoded_inputs = feature_extractor(images=image, return_tensors="pt")
pixel_values = encoded_inputs.pixel_values.to(torch_device)
outputs = model(pixel_values)
expected_shape = torch.Size((1, model.config.num_labels, 128, 128))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = torch.tensor(
[
[[-4.6310, -5.5232, -6.2356], [-5.1921, -6.1444, -6.5996], [-5.4424, -6.2790, -6.7574]],
[[-12.1391, -13.3122, -13.9554], [-12.8732, -13.9352, -14.3563], [-12.9438, -13.8226, -14.2513]],
[[-12.5134, -13.4686, -14.4915], [-12.8669, -14.4343, -14.7758], [-13.2523, -14.5819, -15.0694]],
]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3, :3], expected_slice, atol=1e-4))
@slow
def test_inference_image_segmentation_city(self):
# only resize + normalize
feature_extractor = SegformerFeatureExtractor(
image_scale=(512, 512), keep_ratio=False, align=False, do_random_crop=False
)
model = SegformerForSemanticSegmentation.from_pretrained(
"nvidia/segformer-b1-finetuned-cityscapes-1024-1024"
).to(torch_device)
image = prepare_img()
encoded_inputs = feature_extractor(images=image, return_tensors="pt")
pixel_values = encoded_inputs.pixel_values.to(torch_device)
outputs = model(pixel_values)
expected_shape = torch.Size((1, model.config.num_labels, 128, 128))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = torch.tensor(
[
[[-13.5748, -13.9111, -12.6500], [-14.3500, -15.3683, -14.2328], [-14.7532, -16.0424, -15.6087]],
[[-17.1651, -15.8725, -12.9653], [-17.2580, -17.3718, -14.8223], [-16.6058, -16.8783, -16.7452]],
[[-3.6456, -3.0209, -1.4203], [-3.0797, -3.1959, -2.0000], [-1.8757, -1.9217, -1.6997]],
]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3, :3], expected_slice, atol=1e-1))
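For reference, a minimal inference sketch that mirrors the ADE20k integration test above; the bilinear upsampling at the end is our own illustrative post-processing (this PR removes the post_process methods), not a library helper.
import torch
from PIL import Image
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
feature_extractor = SegformerFeatureExtractor(
    image_scale=(512, 512), keep_ratio=False, align=False, do_random_crop=False
)
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
model.eval()
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")  # the fixture used by the tests above
pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
with torch.no_grad():
    logits = model(pixel_values).logits  # (1, model.config.num_labels, 128, 128), i.e. 1/4 of the 512x512 input
# Manual post-processing: upsample to the input size and take the per-pixel argmax.
upsampled = torch.nn.functional.interpolate(logits, size=(512, 512), mode="bilinear", align_corners=False)
segmentation = upsampled.argmax(dim=1)  # (1, 512, 512), one ADE20k class index per pixel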