Unverified Commit d83d22f5 authored by Francesco Saverio Zuppichini, committed by GitHub

Maskformer (#15682)



* maskformer

* conflicts

* conflicts

* minor fixes

* feature extractor test fix

refactor MaskFormerLoss following conversation

MaskFormer-related types should not trigger an error at module import time

missed one

removed all the types that are not used

update config mapping

minor updates in the doc

resolved conversation that doesn't need a discussion

minor changes

resolved conversations

fixed DetrDecoder

* minor changes

minor changes

fixed mdx file

test feature_extractor return types

functional losses -> classes

removed the return type test for the feature extractor

minor changes + style + quality

* conflicts?

* rebase master

* readme

* added missing files

* deleted poolformer tests that were in the wrong place

* CI

* minor changes

* Apply suggestions from code review
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* resolved conversations

* minor changes

* conversations

[Unispeech] Fix slow tests (#15818)

* remove soundfile old way of loading audio

* Adapt slow test

[Barthez Tokenizer] Fix saving (#15815)

[TFXLNet] Correct tf xlnet generate (#15822)

* [TFXLNet] Correct tf xlnet

* adapt test comment

Fix the push run (#15807)

Fix semantic segmentation pipeline test (#15826)

Fix dummy_inputs() to dummy_inputs in symbolic_trace doc (#15776)

Add model specific output classes to PoolFormer model docs (#15746)

* Added model specific output classes to poolformer docs

* Fixed Segformer typo in Poolformer docs

Adding the option to return_timestamps on pure CTC ASR models. (#15792)

* Adding the option to return_timestamps on pure CTC ASR models.

* Remove `math.prod` which was introduced in Python 3.8

* int are not floats.

* Reworking the PR to support "char" vs "word" output.

* Fixup!

* Update src/transformers/pipelines/automatic_speech_recognition.py

* Quality.

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

HFTracer.trace should use/return self.graph to be compatible with torch.fx.Tracer (#15824)

Fix tf.concatenate + test past_key_values for TF models (#15774)

* fix wrong method name tf.concatenate

* add tests related to causal LM / decoder

* make style and quality

* clean-up

* Fix TFBertModel's extended_attention_mask when past_key_values is provided

* Fix tests

* fix copies

* More tf.int8 -> tf.int32 in TF test template

* clean-up

* Update TF test template

* revert the previous commit + update the TF test template

* Fix TF template extended_attention_mask when past_key_values is provided

* Fix some styles manually

* clean-up

* Fix ValueError: too many values to unpack in the test

* Fix more: too many values to unpack in the test

* Add a comment for extended_attention_mask when there is past_key_values

* Fix TFElectra extended_attention_mask when past_key_values is provided

* Add tests to other TF models

* Fix for TF Electra test: add prepare_config_and_inputs_for_decoder

* Fix not passing training arg to lm_head in TFRobertaForCausalLM

* Fix tests (with past) for TF Roberta

* add testing for pask_key_values for TFElectra model
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>

[examples/summarization and translation] fix readme (#15833)

Add ONNX Runtime quantization for text classification notebook (#15817)

Re-enable doctests for the quicktour (#15828)

* Re-enable doctests for the quicktour

* Re-enable doctests for task_summary (#15830)

* Remove &

Framework split model report (#15825)

Add TFConvNextModel (#15750)

* feat: initial implementation of convnext in tensorflow.

* fix: sample code for the classification model.

* chore: added check for  from the classification model.

* chore: set bias initializer in the classification head.

* chore: updated license terms.

* chore: removed unused imports

* feat: enabled  argument when using drop_path.

* chore: replaced tf.identity with layers.Activation(linear).

* chore: edited default checkpoint.

* fix: minor bugs in the initializations.

* partial-fix: tf model errors for loading pretrained pt weights.

* partial-fix: call method updated

* partial-fix: cross loading of weights (4x3 variables to be matched)

* chore: removed unneeded comment.

* removed playground.py

* rebasing

* rebasing and removing playground.py.

* fix: renaming TFConvNextStage conv and layer norm layers

* chore: added initializers and other minor additions.

* chore: added initializers and other minor additions.

* add: tests for convnext.

* fix: integration tester class.

* fix: issues mentioned in pr feedback (round 1).

* fix: how the output_hidden_states arg is propagated inside the network.

* feat: handling of  arg for pure cnn models.

* chore: added a note on equal contribution in model docs.

* rebasing

* rebasing and removing playground.py.

* feat: encapsulation for the convnext trunk.

* Fix variable naming; Test-related corrections; Run make fixup

* chore: added Joao as a contributor to convnext.

* rebasing

* rebasing and removing playground.py.

* rebasing

* rebasing and removing playground.py.

* chore: corrected copyright year and added comment on NHWC.

* chore: fixed the black version and ran formatting.

* chore: ran make style.

* chore: removed from_pt argument from test, ran make style.

* rebasing

* rebasing and removing playground.py.

* rebasing

* rebasing and removing playground.py.

* fix: tests in the convnext subclass, ran make style.

* rebasing

* rebasing and removing playground.py.

* rebasing

* rebasing and removing playground.py.

* chore: moved convnext test to the correct location

* fix: locations for the test file of convnext.

* fix: convnext tests.

* chore: applied sgugger's suggestion for dealing with output_attentions.

* chore: added comments.

* chore: applied updated quality environment style.

* chore: applied formatting with quality environment.

* chore: revert to the previous tests/test_modeling_common.py.

* chore: revert to the original test_modeling_common.py

* chore: revert to previous states for test_modeling_tf_common.py and modeling_tf_utils.py

* fix: tests for convnext.

* chore: removed output_attentions argument from convnext config.

* chore: revert to the earlier tf utils.

* fix: output shapes of the hidden states

* chore: removed unnecessary comment

* chore: reverting to the right test_modeling_tf_common.py.

* Styling nits
Co-authored-by: ariG23498 <aritra.born2fly@gmail.com>
Co-authored-by: Joao Gante <joao@huggingface.co>
Co-authored-by: Sylvain Gugger <Sylvain.gugger@gmail.com>

* minor changes

* doc fix in feature extractor

* doc

* typos

* removed detr logic from config

* removed detr logic from config

* removed num_labels

* small fix in the config

* auxilary -> auxiliary

* make style

* some test is failing

* fix a weird char in config preventing doc-builder

* retry to fix the doc-builder issue

* make style

* new try to fix the doc builder

* CI

* change weights to facebook
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: ariG23498 <aritra.born2fly@gmail.com>
Co-authored-by: Joao Gante <joao@huggingface.co>
Co-authored-by: Sylvain Gugger <Sylvain.gugger@gmail.com>
parent e535c389
......@@ -281,6 +281,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
1. **[LXMERT](https://huggingface.co/docs/transformers/model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal.
1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin.
1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team.
1. **[MaskFormer](https://huggingface.co/docs/transformers/master/model_doc/maskformer)** (from Meta and UIUC) released with the paper [Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://arxiv.org/abs/2107.06278) by Bowen Cheng, Alexander G. Schwing, Alexander Kirillov.
1. **[MBart](https://huggingface.co/docs/transformers/model_doc/mbart)** (from Facebook) released with the paper [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer.
1. **[MBart-50](https://huggingface.co/docs/transformers/model_doc/mbart)** (from Facebook) released with the paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan.
1. **[Megatron-BERT](https://huggingface.co/docs/transformers/model_doc/megatron-bert)** (from NVIDIA) released with the paper [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro.
......
......@@ -259,6 +259,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는
1. **[LXMERT](https://huggingface.co/docs/transformers/model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal.
1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin.
1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team.
1. **[MaskFormer](https://huggingface.co/docs/transformers/master/model_doc/maskformer)** (from Meta and UIUC) released with the paper [Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://arxiv.org/abs/2107.06278) by Bowen Cheng, Alexander G. Schwing, Alexander Kirillov.
1. **[MBart](https://huggingface.co/docs/transformers/model_doc/mbart)** (from Facebook) released with the paper [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer.
1. **[MBart-50](https://huggingface.co/docs/transformers/model_doc/mbart)** (from Facebook) released with the paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan.
1. **[Megatron-BERT](https://huggingface.co/docs/transformers/model_doc/megatron-bert)** (from NVIDIA) released with the paper [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro.
......
......@@ -283,6 +283,7 @@ conda install -c huggingface transformers
1. **[LXMERT](https://huggingface.co/docs/transformers/model_doc/lxmert)** (来自 UNC Chapel Hill) 伴随论文 [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) 由 Hao Tan and Mohit Bansal 发布。
1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (来自 Facebook) 伴随论文 [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) 由 Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin 发布。
1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)**[OPUS](http://opus.nlpl.eu/) 数据训练的机器翻译模型由 Jörg Tiedemann 发布。[Marian Framework](https://marian-nmt.github.io/) 由微软翻译团队开发。
1. **[MaskFormer](https://huggingface.co/docs/transformers/master/model_doc/maskformer)** (from Meta and UIUC) released with the paper [Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://arxiv.org/abs/2107.06278) by Bowen Cheng, Alexander G. Schwing, Alexander Kirillov
1. **[MBart](https://huggingface.co/docs/transformers/model_doc/mbart)** (来自 Facebook) 伴随论文 [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) 由 Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer 发布。
1. **[MBart-50](https://huggingface.co/docs/transformers/model_doc/mbart)** (来自 Facebook) 伴随论文 [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) 由 Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan 发布。
1. **[Megatron-BERT](https://huggingface.co/docs/transformers/model_doc/megatron-bert)** (来自 NVIDIA) 伴随论文 [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) 由 Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro 发布。
......
......@@ -295,6 +295,7 @@ conda install -c huggingface transformers
1. **[LXMERT](https://huggingface.co/docs/transformers/model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal.
1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin.
1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team.
1. **[MaskFormer](https://huggingface.co/docs/transformers/master/model_doc/maskformer)** (from Meta and UIUC) released with the paper [Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://arxiv.org/abs/2107.06278) by Bowen Cheng, Alexander G. Schwing, Alexander Kirillov
1. **[MBart](https://huggingface.co/docs/transformers/model_doc/mbart)** (from Facebook) released with the paper [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer.
1. **[MBart-50](https://huggingface.co/docs/transformers/model_doc/mbart)** (from Facebook) released with the paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan.
1. **[Megatron-BERT](https://huggingface.co/docs/transformers/model_doc/megatron-bert)** (from NVIDIA) released with the paper [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro.
......
......@@ -230,6 +230,8 @@
title: LXMERT
- local: model_doc/marian
title: MarianMT
- local: model_doc/maskformer
title: MaskFormer
- local: model_doc/m2m_100
title: M2M100
- local: model_doc/mbart
......
......@@ -105,6 +105,7 @@ conversion utilities for the following models.
1. **[LXMERT](model_doc/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal.
1. **[M2M100](model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin.
1. **[MarianMT](model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team.
1. **[MaskFormer](model_doc/maskformer)** (from Meta and UIUC) released with the paper [Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://arxiv.org/abs/2107.06278) by Bowen Cheng, Alexander G. Schwing, Alexander Kirillov.
1. **[MBart](model_doc/mbart)** (from Facebook) released with the paper [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer.
1. **[MBart-50](model_doc/mbart)** (from Facebook) released with the paper [Multilingual Translation with Extensible Multilingual Pretraining and Finetuning](https://arxiv.org/abs/2008.00401) by Yuqing Tang, Chau Tran, Xian Li, Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan.
1. **[Megatron-BERT](model_doc/megatron-bert)** (from NVIDIA) released with the paper [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) by Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro.
......@@ -209,6 +210,7 @@ Flax), PyTorch, and/or TensorFlow.
| LXMERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| M2M100 | ✅ | ❌ | ✅ | ❌ | ❌ |
| Marian | ✅ | ❌ | ✅ | ✅ | ✅ |
| MaskFormer | ❌ | ❌ | ✅ | ❌ | ❌ |
| mBART | ✅ | ✅ | ✅ | ✅ | ✅ |
| MegatronBert | ❌ | ❌ | ✅ | ❌ | ❌ |
| MobileBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
......
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# MaskFormer
<Tip>
This is a recently introduced model, so the API hasn't been tested extensively. There may be some bugs or slight
breaking changes to fix in the future. If you see something strange, file a [GitHub Issue](https://github.com/huggingface/transformers/issues/new?assignees=&labels=&template=bug-report.md&title).
</Tip>
## Overview
The MaskFormer model was proposed in [Per-Pixel Classification is Not All You Need for Semantic Segmentation](https://arxiv.org/abs/2107.06278) by Bowen Cheng, Alexander G. Schwing, Alexander Kirillov. MaskFormer addresses semantic segmentation with a mask classification paradigm instead of performing classic pixel-level classification.
The abstract from the paper is the following:
*Modern approaches typically formulate semantic segmentation as a per-pixel classification task, while instance-level segmentation is handled with an alternative mask classification. Our key insight: mask classification is sufficiently general to solve both semantic- and instance-level segmentation tasks in a unified manner using the exact same model, loss, and training procedure. Following this observation, we propose MaskFormer, a simple mask classification model which predicts a set of binary masks, each associated with a single global class label prediction. Overall, the proposed mask classification-based method simplifies the landscape of effective approaches to semantic and panoptic segmentation tasks and shows excellent empirical results. In particular, we observe that MaskFormer outperforms per-pixel classification baselines when the number of classes is large. Our mask classification-based method outperforms both current state-of-the-art semantic (55.6 mIoU on ADE20K) and panoptic segmentation (52.7 PQ on COCO) models.*
Tips:
- MaskFormer's Transformer decoder is identical to the decoder of [DETR](detr). During training, the authors of DETR did find it helpful to use auxiliary losses in the decoder, especially to help the model output the correct number of objects of each class. If you set the parameter `use_auxiliary_loss` of [`MaskFormerConfig`] to `True`, then prediction feedforward neural networks and Hungarian losses are added after each decoder layer (with the FFNs sharing parameters).
- If you want to train the model in a distributed environment across multiple nodes, then you should update the
`get_num_masks` function inside the `MaskFormerLoss` class of `modeling_maskformer.py`. When training on multiple nodes, this should be
set to the average number of target masks across all nodes, as can be seen in the original implementation [here](https://github.com/facebookresearch/MaskFormer/blob/da3e60d85fdeedcb31476b5edd7d328826ce56cc/mask_former/modeling/criterion.py#L169).
- One can use [`MaskFormerFeatureExtractor`] to prepare images and optional targets for the model.
- To get the final segmentation, depending on the task, you can call [`~MaskFormerFeatureExtractor.post_process_semantic_segmentation`] or [`~MaskFormerFeatureExtractor.post_process_panoptic_segmentation`]. Both tasks can be solved using the [`MaskFormerForInstanceSegmentation`] output; the panoptic variant additionally needs an `is_thing_map` to know which instances must be merged together (see the example after this list).
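A minimal sketch of this workflow with the `facebook/maskformer-swin-base-ade` checkpoint (the exact arguments and return type of the post-processing helpers may vary between versions):

```python
import requests
from PIL import Image

from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation

# checkpoint trained on ADE20k-150 (semantic segmentation)
feature_extractor = MaskFormerFeatureExtractor.from_pretrained("facebook/maskformer-swin-base-ade")
model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-base-ade")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# prepare the image for the model
inputs = feature_extractor(images=image, return_tensors="pt")
outputs = model(**inputs)

# turn the class and mask logits into a per-pixel semantic map;
# the signature of this helper may differ across releases
semantic_map = feature_extractor.post_process_semantic_segmentation(outputs)[0]
print(semantic_map.shape)
```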
The figure below illustrates the architecture of MaskFormer. Taken from the [original paper](https://arxiv.org/abs/2107.06278).
<img width="600" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/maskformer_architecture.png"/>
This model was contributed by [francesco](https://huggingface.co/francesco). The original code can be found [here](https://github.com/facebookresearch/MaskFormer).
## MaskFormer specific outputs
[[autodoc]] models.maskformer.modeling_maskformer.MaskFormerModelOutput
[[autodoc]] models.maskformer.modeling_maskformer.MaskFormerForInstanceSegmentationOutput
## MaskFormerConfig
[[autodoc]] MaskFormerConfig
## MaskFormerFeatureExtractor
[[autodoc]] MaskFormerFeatureExtractor
- __call__
- encode_inputs
- post_process_segmentation
- post_process_semantic_segmentation
- post_process_panoptic_segmentation
## MaskFormerModel
[[autodoc]] MaskFormerModel
- forward
## MaskFormerForInstanceSegmentation
[[autodoc]] MaskFormerForInstanceSegmentation
- forward
......@@ -247,6 +247,7 @@ _import_structure = {
"models.lxmert": ["LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LxmertConfig", "LxmertTokenizer"],
"models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"],
"models.marian": ["MarianConfig"],
"models.maskformer": ["MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "MaskFormerConfig"],
"models.mbart": ["MBartConfig"],
"models.mbart50": [],
"models.megatron_bert": ["MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MegatronBertConfig"],
......@@ -527,6 +528,7 @@ if is_vision_available():
_import_structure["models.layoutlmv2"].append("LayoutLMv2FeatureExtractor")
_import_structure["models.layoutlmv2"].append("LayoutLMv2Processor")
_import_structure["models.layoutxlm"].append("LayoutXLMProcessor")
_import_structure["models.maskformer"].append("MaskFormerFeatureExtractor")
_import_structure["models.perceiver"].append("PerceiverFeatureExtractor")
_import_structure["models.poolformer"].append("PoolFormerFeatureExtractor")
_import_structure["models.segformer"].append("SegformerFeatureExtractor")
......@@ -1147,6 +1149,14 @@ if is_torch_available():
]
)
_import_structure["models.marian"].extend(["MarianForCausalLM", "MarianModel", "MarianMTModel"])
_import_structure["models.maskformer"].extend(
[
"MASKFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"MaskFormerForInstanceSegmentation",
"MaskFormerModel",
"MaskFormerPreTrainedModel",
]
)
_import_structure["models.mbart"].extend(
[
"MBartForCausalLM",
......@@ -2532,6 +2542,7 @@ if TYPE_CHECKING:
from .models.lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig, LxmertTokenizer
from .models.m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config
from .models.marian import MarianConfig
from .models.maskformer import MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, MaskFormerConfig
from .models.mbart import MBartConfig
from .models.megatron_bert import MEGATRON_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MegatronBertConfig
from .models.mmbt import MMBTConfig
......@@ -2763,6 +2774,7 @@ if TYPE_CHECKING:
from .models.imagegpt import ImageGPTFeatureExtractor
from .models.layoutlmv2 import LayoutLMv2FeatureExtractor, LayoutLMv2Processor
from .models.layoutxlm import LayoutXLMProcessor
from .models.maskformer import MaskFormerFeatureExtractor
from .models.perceiver import PerceiverFeatureExtractor
from .models.poolformer import PoolFormerFeatureExtractor
from .models.segformer import SegformerFeatureExtractor
......@@ -3273,6 +3285,12 @@ if TYPE_CHECKING:
M2M100PreTrainedModel,
)
from .models.marian import MarianForCausalLM, MarianModel, MarianMTModel
from .models.maskformer import (
MASKFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
MaskFormerForInstanceSegmentation,
MaskFormerModel,
MaskFormerPreTrainedModel,
)
from .models.mbart import (
MBartForCausalLM,
MBartForConditionalGeneration,
......
......@@ -70,6 +70,7 @@ from . import (
lxmert,
m2m_100,
marian,
maskformer,
mbart,
mbart50,
megatron_bert,
......
......@@ -30,6 +30,7 @@ logger = logging.get_logger(__name__)
CONFIG_MAPPING_NAMES = OrderedDict(
[
# Add configs here
("maskformer", "MaskFormerConfig"),
("poolformer", "PoolFormerConfig"),
("convnext", "ConvNextConfig"),
("yoso", "YosoConfig"),
......@@ -129,6 +130,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
[
# Add archive maps here
("maskformer", "MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("poolformer", "POOLFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("convnext", "CONVNEXT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("yoso", "YOSO_PRETRAINED_CONFIG_ARCHIVE_MAP"),
......@@ -215,6 +217,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
MODEL_NAMES_MAPPING = OrderedDict(
[
# Add full (and cased) model names here
("maskformer", "MaskFormer"),
("poolformer", "PoolFormer"),
("convnext", "ConvNext"),
("yoso", "YOSO"),
......
......@@ -28,6 +28,7 @@ logger = logging.get_logger(__name__)
MODEL_MAPPING_NAMES = OrderedDict(
[
# Base model mapping
("maskformer", "MaskFormerModel"),
("poolformer", "PoolFormerModel"),
("convnext", "ConvNextModel"),
("yoso", "YosoModel"),
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import _LazyModule, is_torch_available, is_vision_available
_import_structure = {
"configuration_maskformer": ["MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "MaskFormerConfig"],
}
if is_vision_available():
_import_structure["feature_extraction_maskformer"] = ["MaskFormerFeatureExtractor"]
if is_torch_available():
_import_structure["modeling_maskformer"] = [
"MASKFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"MaskFormerForInstanceSegmentation",
"MaskFormerModel",
"MaskFormerPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_maskformer import MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, MaskFormerConfig
if is_vision_available():
from .feature_extraction_maskformer import MaskFormerFeatureExtractor
if is_torch_available():
from .modeling_maskformer import (
MASKFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
MaskFormerForInstanceSegmentation,
MaskFormerModel,
MaskFormerPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
# coding=utf-8
# Copyright 2022 Meta Platforms, Inc.and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" MaskFormer model configuration"""
import copy
from typing import Dict, Optional
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto.configuration_auto import AutoConfig
from ..detr import DetrConfig
from ..swin import SwinConfig
MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/maskformer-swin-base-ade": "https://huggingface.co/facebook/maskformer-swin-base-ade/blob/main/config.json"
# See all MaskFormer models at https://huggingface.co/models?filter=maskformer
}
logger = logging.get_logger(__name__)
class MaskFormerConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MaskFormerModel`]. It is used to instantiate a
MaskFormer model according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the
"facebook/maskformer-swin-base-ade" architecture trained on
[ADE20k-150](https://huggingface.co/datasets/scene_parse_150).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Currently, MaskFormer only supports the [Swin Transformer](swin) as backbone.
Args:
mask_feature_size (`int`, *optional*, defaults to 256):
The masks' feature size; this value will also be used to specify the Feature Pyramid Network feature
size.
no_object_weight (`float`, *optional*, defaults to 0.1):
Weight to apply to the null (no object) class.
use_auxiliary_loss (`bool`, *optional*, defaults to `False`):
If `True`, [`MaskFormerForInstanceSegmentationOutput`] will contain the auxiliary losses computed using the
logits from each decoder stage.
backbone_config (`Dict`, *optional*):
The configuration passed to the backbone. If unset, the configuration corresponding to
`swin-base-patch4-window12-384` will be used.
decoder_config (`Dict`, *optional*):
The configuration passed to the transformer decoder model. If unset, the base configuration of `detr-resnet-50`
will be used.
init_std (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
init_xavier_std (`float`, *optional*, defaults to 1):
The scaling factor used for the Xavier initialization gain in the HM Attention map module.
dice_weight (`float`, *optional*, defaults to 1.0):
The weight for the dice loss.
cross_entropy_weight (`float`, *optional*, defaults to 1.0):
The weight for the cross entropy loss.
mask_weight (`float`, *optional*, defaults to 20.0):
The weight for the mask loss.
Raises:
`ValueError`:
Raised if the backbone model type selected is not in `["swin"]` or the decoder model type selected is not
in `["detr"]`
Examples:
```python
>>> from transformers import MaskFormerConfig, MaskFormerModel
>>> # Initializing a MaskFormer facebook/maskformer-swin-base-ade configuration
>>> configuration = MaskFormerConfig()
>>> # Initializing a model from the facebook/maskformer-swin-base-ade style configuration
>>> model = MaskFormerModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
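A minimal sketch (assuming `SwinConfig` and `DetrConfig` are importable from the top-level `transformers` package): a configuration can also be composed from explicit backbone and decoder configurations via [`~MaskFormerConfig.from_backbone_and_decoder_configs`]:
```python
>>> from transformers import DetrConfig, MaskFormerConfig, SwinConfig
>>> # build a MaskFormer configuration from a Swin backbone config and a DETR decoder config
>>> backbone_config = SwinConfig()
>>> decoder_config = DetrConfig()
>>> config = MaskFormerConfig.from_backbone_and_decoder_configs(backbone_config, decoder_config)
```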
"""
model_type = "maskformer"
attribute_map = {"hidden_size": "mask_feature_size"}
backbones_supported = ["swin"]
decoders_supported = ["detr"]
def __init__(
self,
fpn_feature_size: int = 256,
mask_feature_size: int = 256,
no_object_weight: float = 0.1,
use_auxiliary_loss: bool = False,
backbone_config: Optional[Dict] = None,
decoder_config: Optional[Dict] = None,
init_std: float = 0.02,
init_xavier_std: float = 1.0,
dice_weight: float = 1.0,
cross_entropy_weight: float = 1.0,
mask_weight: float = 20.0,
**kwargs,
):
if backbone_config is None:
# fall back to https://huggingface.co/microsoft/swin-base-patch4-window12-384-in22k
backbone_config = SwinConfig(
image_size=384,
in_channels=3,
patch_size=4,
embed_dim=128,
depths=[2, 2, 18, 2],
num_heads=[4, 8, 16, 32],
window_size=12,
drop_path_rate=0.3,
)
else:
backbone_model_type = backbone_config.pop("model_type")
if backbone_model_type not in self.backbones_supported:
raise ValueError(
f"Backbone {backbone_model_type} not supported, please use one of {','.join(self.backbones_supported)}"
)
backbone_config = AutoConfig.for_model(backbone_model_type, **backbone_config)
if decoder_config is None:
# fall back to https://huggingface.co/facebook/detr-resnet-50
decoder_config = DetrConfig()
else:
decoder_type = decoder_config.pop("model_type")
if decoder_type not in self.decoders_supported:
raise ValueError(
f"Transformer Decoder {decoder_type} not supported, please use one of {','.join(self.decoders_supported)}"
)
decoder_config = AutoConfig.for_model(decoder_type, **decoder_config)
self.backbone_config = backbone_config
self.decoder_config = decoder_config
# main feature dimension for the model
self.fpn_feature_size = fpn_feature_size
self.mask_feature_size = mask_feature_size
# initializer
self.init_std = init_std
self.init_xavier_std = init_xavier_std
# Hungarian matcher && loss
self.cross_entropy_weight = cross_entropy_weight
self.dice_weight = dice_weight
self.mask_weight = mask_weight
self.use_auxiliary_loss = use_auxiliary_loss
self.no_object_weight = no_object_weight
self.num_attention_heads = self.decoder_config.encoder_attention_heads
self.num_hidden_layers = self.decoder_config.num_hidden_layers
super().__init__(**kwargs)
@classmethod
def from_backbone_and_decoder_configs(
cls, backbone_config: PretrainedConfig, decoder_config: PretrainedConfig, **kwargs
):
"""Instantiate a [`MaskFormerConfig`] (or a derived class) from a pre-trained backbone model configuration and DETR model
configuration.
Args:
backbone_config ([`PretrainedConfig`]):
The backbone configuration.
decoder_config ([`PretrainedConfig`]):
The transformer decoder configuration to use.
Returns:
[`MaskFormerConfig`]: An instance of a configuration object
"""
return cls(
backbone_config=backbone_config.to_dict(),
decoder_config=decoder_config.to_dict(),
**kwargs,
)
def to_dict(self) -> Dict[str, any]:
"""
Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].
Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = copy.deepcopy(self.__dict__)
output["backbone_config"] = self.backbone_config.to_dict()
output["decoder_config"] = self.decoder_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
# coding=utf-8
# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from argparse import ArgumentParser
from dataclasses import dataclass
from pathlib import Path
from pprint import pformat
from typing import Any, Dict, Iterator, List, Set, Tuple
import torch
import torchvision.transforms as T
from PIL import Image
from torch import Tensor, nn
import requests
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.projects.deeplab import add_deeplab_config
from transformers.models.maskformer.feature_extraction_maskformer import MaskFormerFeatureExtractor
from transformers.models.maskformer.modeling_maskformer import (
MaskFormerConfig,
MaskFormerForInstanceSegmentation,
MaskFormerForInstanceSegmentationOutput,
MaskFormerModel,
MaskFormerModelOutput,
)
from transformers.utils import logging
StateDict = Dict[str, Tensor]
logging.set_verbosity_info()
logger = logging.get_logger()
torch.manual_seed(0)
class TrackedStateDict:
def __init__(self, to_track: Dict):
"""This class "tracks" a python dictionary by keeping track of which item is accessed.
Args:
to_track (Dict): The dictionary we wish to track
"""
self.to_track = to_track
self._seen: Set[str] = set()
def __getitem__(self, key: str) -> Any:
return self.to_track[key]
def __setitem__(self, key: str, item: Any):
self._seen.add(key)
self.to_track[key] = item
def diff(self) -> List[str]:
"""This method returns a set difference between the keys in the tracked state dict and the one we have access so far.
This is an effective method to check if we have update all the keys
Returns:
List[str]: List of keys not yet updated
"""
return set(list(self.to_track.keys())) - self._seen
def copy(self) -> Dict:
# proxy the call to the internal dictionary
return self.to_track.copy()
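# Illustrative usage only (not part of the conversion flow): `diff` reports which keys
# of the destination state dict were never written during conversion, e.g.
#   tracked = TrackedStateDict({"a": 1, "b": 2})
#   tracked["a"] = 3
#   tracked.diff()  # -> {"b"}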
# We will verify our results on an image of cute cats
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
img_data = requests.get(url, stream=True).raw
im = Image.open(img_data)
return im
@dataclass
class Args:
"""Fake command line arguments needed by maskformer/detectron implementation"""
config_file: str
def setup_cfg(args: Args):
# load config from file and command-line arguments
cfg = get_cfg()
add_deeplab_config(cfg)
add_mask_former_config(cfg)
cfg.merge_from_file(args.config_file)
cfg.freeze()
return cfg
class OriginalMaskFormerConfigToOursConverter:
def __call__(self, original_config: object) -> MaskFormerConfig:
model = original_config.MODEL
mask_former = model.MASK_FORMER
swin = model.SWIN
dataset_catalog = MetadataCatalog.get(original_config.DATASETS.TEST[0])
id2label = {idx: label for idx, label in enumerate(dataset_catalog.stuff_classes)}
label2id = {label: idx for idx, label in id2label.items()}
config: MaskFormerConfig = MaskFormerConfig(
fpn_feature_size=model.SEM_SEG_HEAD.CONVS_DIM,
mask_feature_size=model.SEM_SEG_HEAD.MASK_DIM,
num_labels=model.SEM_SEG_HEAD.NUM_CLASSES,
no_object_weight=mask_former.NO_OBJECT_WEIGHT,
num_queries=mask_former.NUM_OBJECT_QUERIES,
backbone_config=dict(
pretrain_img_size=swin.PRETRAIN_IMG_SIZE,
image_size=swin.PRETRAIN_IMG_SIZE,
in_channels=3,
patch_size=swin.PATCH_SIZE,
embed_dim=swin.EMBED_DIM,
depths=swin.DEPTHS,
num_heads=swin.NUM_HEADS,
window_size=swin.WINDOW_SIZE,
drop_path_rate=swin.DROP_PATH_RATE,
model_type="swin",
),
dice_weight=mask_former.DICE_WEIGHT,
ce_weight=1.0,
mask_weight=mask_former.MASK_WEIGHT,
decoder_config=dict(
model_type="detr",
max_position_embeddings=1024,
encoder_layers=6,
encoder_ffn_dim=2048,
encoder_attention_heads=8,
decoder_layers=mask_former.DEC_LAYERS,
decoder_ffn_dim=mask_former.DIM_FEEDFORWARD,
decoder_attention_heads=mask_former.NHEADS,
encoder_layerdrop=0.0,
decoder_layerdrop=0.0,
d_model=mask_former.HIDDEN_DIM,
dropout=mask_former.DROPOUT,
attention_dropout=0.0,
activation_dropout=0.0,
init_std=0.02,
init_xavier_std=1.0,
scale_embedding=False,
auxiliary_loss=False,
dilation=False,
# default pretrained config values
),
id2label=id2label,
label2id=label2id,
)
return config
class OriginalMaskFormerConfigToFeatureExtractorConverter:
def __call__(self, original_config: object) -> MaskFormerFeatureExtractor:
model = original_config.MODEL
model_input = original_config.INPUT
return MaskFormerFeatureExtractor(
image_mean=(torch.tensor(model.PIXEL_MEAN) / 255).tolist(),
image_std=(torch.tensor(model.PIXEL_STD) / 255).tolist(),
size=model_input.MIN_SIZE_TEST,
max_size=model_input.MAX_SIZE_TEST,
size_divisibility=32, # 32 is required by swin
)
class OriginalMaskFormerCheckpointToOursConverter:
def __init__(self, original_model: nn.Module, config: MaskFormerConfig):
self.original_model = original_model
self.config = config
def pop_all(self, renamed_keys: List[Tuple[str, str]], dst_state_dict: StateDict, src_state_dict: StateDict):
for (src_key, dst_key) in renamed_keys:
dst_state_dict[dst_key] = src_state_dict.pop(src_key)
def replace_backbone(self, dst_state_dict: StateDict, src_state_dict: StateDict, config: MaskFormerConfig):
dst_prefix: str = "pixel_level_module.encoder"
src_prefix: str = "backbone"
renamed_keys = [
(
f"{src_prefix}.patch_embed.proj.weight",
f"{dst_prefix}.model.embeddings.patch_embeddings.projection.weight",
),
(f"{src_prefix}.patch_embed.proj.bias", f"{dst_prefix}.model.embeddings.patch_embeddings.projection.bias"),
(f"{src_prefix}.patch_embed.norm.weight", f"{dst_prefix}.model.embeddings.norm.weight"),
(f"{src_prefix}.patch_embed.norm.bias", f"{dst_prefix}.model.embeddings.norm.bias"),
]
num_layers = len(config.backbone_config.depths)
for layer_idx in range(num_layers):
for block_idx in range(config.backbone_config.depths[layer_idx]):
renamed_keys.extend(
[ # src, dst
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm1.weight",
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_before.weight",
),
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm1.bias",
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_before.bias",
),
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.relative_position_bias_table",
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.relative_position_bias_table",
),
]
)
# now we need to handle the attentions
# read in weights + bias of input projection layer of cross-attention
src_att_weight = src_state_dict[f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.weight"]
src_att_bias = src_state_dict[f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.bias"]
size = src_att_weight.shape[0]
offset = size // 3
dst_state_dict[
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.query.weight"
] = src_att_weight[:offset, :]
dst_state_dict[
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.query.bias"
] = src_att_bias[:offset]
dst_state_dict[
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.key.weight"
] = src_att_weight[offset : offset * 2, :]
dst_state_dict[
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.key.bias"
] = src_att_bias[offset : offset * 2]
dst_state_dict[
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.value.weight"
] = src_att_weight[-offset:, :]
dst_state_dict[
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.value.bias"
] = src_att_bias[-offset:]
# let's pop them
src_state_dict.pop(f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.weight")
src_state_dict.pop(f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.qkv.bias")
# proj
renamed_keys.extend(
[
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.proj.weight",
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.output.dense.weight",
),
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.proj.bias",
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.output.dense.bias",
),
]
)
# second norm
renamed_keys.extend(
[
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm2.weight",
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_after.weight",
),
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.norm2.bias",
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.layernorm_after.bias",
),
]
)
# mlp
renamed_keys.extend(
[
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc1.weight",
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.intermediate.dense.weight",
),
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc1.bias",
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.intermediate.dense.bias",
),
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc2.weight",
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.output.dense.weight",
),
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.mlp.fc2.bias",
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.output.dense.bias",
),
]
)
renamed_keys.extend(
[
(
f"{src_prefix}.layers.{layer_idx}.blocks.{block_idx}.attn.relative_position_index",
f"{dst_prefix}.model.encoder.layers.{layer_idx}.blocks.{block_idx}.attention.self.relative_position_index",
)
]
)
if layer_idx < num_layers - 1:
# patch merging
renamed_keys.extend(
[
(
f"{src_prefix}.layers.{layer_idx}.downsample.reduction.weight",
f"{dst_prefix}.model.encoder.layers.{layer_idx}.downsample.reduction.weight",
),
(
f"{src_prefix}.layers.{layer_idx}.downsample.norm.weight",
f"{dst_prefix}.model.encoder.layers.{layer_idx}.downsample.norm.weight",
),
(
f"{src_prefix}.layers.{layer_idx}.downsample.norm.bias",
f"{dst_prefix}.model.encoder.layers.{layer_idx}.downsample.norm.bias",
),
]
)
# hidden states norms
renamed_keys.extend(
[
(
f"{src_prefix}.norm{layer_idx}.weight",
f"{dst_prefix}.hidden_states_norms.{layer_idx}.weight",
),
(
f"{src_prefix}.norm{layer_idx}.bias",
f"{dst_prefix}.hidden_states_norms.{layer_idx}.bias",
),
]
)
self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
def replace_pixel_module(self, dst_state_dict: StateDict, src_state_dict: StateDict):
dst_prefix: str = "pixel_level_module.decoder"
src_prefix: str = "sem_seg_head.pixel_decoder"
self.replace_backbone(dst_state_dict, src_state_dict, self.config)
def rename_keys_for_conv(detectron_conv: str, mine_conv: str):
return [
(f"{detectron_conv}.weight", f"{mine_conv}.0.weight"),
# the original block has an activation in the middle, hence the index shift when renaming the norm
(f"{detectron_conv}.norm.weight", f"{mine_conv}.1.weight"),
(f"{detectron_conv}.norm.bias", f"{mine_conv}.1.bias"),
]
renamed_keys = [
(f"{src_prefix}.mask_features.weight", f"{dst_prefix}.mask_projection.weight"),
(f"{src_prefix}.mask_features.bias", f"{dst_prefix}.mask_projection.bias"),
# the layers in the original one are in reverse order, stem is the last one!
]
renamed_keys.extend(rename_keys_for_conv(f"{src_prefix}.layer_4", f"{dst_prefix}.fpn.stem"))
# add all the fpn layers (here we need some config parameters to know the size in advance)
for src_i, dst_i in zip(range(3, 0, -1), range(0, 3)):
renamed_keys.extend(
rename_keys_for_conv(f"{src_prefix}.adapter_{src_i}", f"{dst_prefix}.fpn.layers.{dst_i}.proj")
)
renamed_keys.extend(
rename_keys_for_conv(f"{src_prefix}.layer_{src_i}", f"{dst_prefix}.fpn.layers.{dst_i}.block")
)
self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
def rename_keys_in_detr_decoder(self, dst_state_dict: StateDict, src_state_dict: StateDict):
dst_prefix: str = "transformer_module.decoder"
src_prefix: str = "sem_seg_head.predictor.transformer.decoder"
# not sure why we are not popping directly here!
# here we list all keys to be renamed (original name on the left, our name on the right)
rename_keys = []
for i in range(self.config.decoder_config.decoder_layers):
# decoder layers: 2 times output projection, 2 feedforward neural networks and 3 layernorms
rename_keys.append(
(
f"{src_prefix}.layers.{i}.self_attn.out_proj.weight",
f"{dst_prefix}.layers.{i}.self_attn.out_proj.weight",
)
)
rename_keys.append(
(
f"{src_prefix}.layers.{i}.self_attn.out_proj.bias",
f"{dst_prefix}.layers.{i}.self_attn.out_proj.bias",
)
)
rename_keys.append(
(
f"{src_prefix}.layers.{i}.multihead_attn.out_proj.weight",
f"{dst_prefix}.layers.{i}.encoder_attn.out_proj.weight",
)
)
rename_keys.append(
(
f"{src_prefix}.layers.{i}.multihead_attn.out_proj.bias",
f"{dst_prefix}.layers.{i}.encoder_attn.out_proj.bias",
)
)
rename_keys.append((f"{src_prefix}.layers.{i}.linear1.weight", f"{dst_prefix}.layers.{i}.fc1.weight"))
rename_keys.append((f"{src_prefix}.layers.{i}.linear1.bias", f"{dst_prefix}.layers.{i}.fc1.bias"))
rename_keys.append((f"{src_prefix}.layers.{i}.linear2.weight", f"{dst_prefix}.layers.{i}.fc2.weight"))
rename_keys.append((f"{src_prefix}.layers.{i}.linear2.bias", f"{dst_prefix}.layers.{i}.fc2.bias"))
rename_keys.append(
(f"{src_prefix}.layers.{i}.norm1.weight", f"{dst_prefix}.layers.{i}.self_attn_layer_norm.weight")
)
rename_keys.append(
(f"{src_prefix}.layers.{i}.norm1.bias", f"{dst_prefix}.layers.{i}.self_attn_layer_norm.bias")
)
rename_keys.append(
(f"{src_prefix}.layers.{i}.norm2.weight", f"{dst_prefix}.layers.{i}.encoder_attn_layer_norm.weight")
)
rename_keys.append(
(f"{src_prefix}.layers.{i}.norm2.bias", f"{dst_prefix}.layers.{i}.encoder_attn_layer_norm.bias")
)
rename_keys.append(
(f"{src_prefix}.layers.{i}.norm3.weight", f"{dst_prefix}.layers.{i}.final_layer_norm.weight")
)
rename_keys.append(
(f"{src_prefix}.layers.{i}.norm3.bias", f"{dst_prefix}.layers.{i}.final_layer_norm.bias")
)
return rename_keys
def replace_q_k_v_in_detr_decoder(self, dst_state_dict: StateDict, src_state_dict: StateDict):
dst_prefix: str = "transformer_module.decoder"
src_prefix: str = "sem_seg_head.predictor.transformer.decoder"
for i in range(self.config.decoder_config.decoder_layers):
# read in weights + bias of input projection layer of self-attention
in_proj_weight = src_state_dict.pop(f"{src_prefix}.layers.{i}.self_attn.in_proj_weight")
in_proj_bias = src_state_dict.pop(f"{src_prefix}.layers.{i}.self_attn.in_proj_bias")
# next, add query, keys and values (in that order) to the state dict
dst_state_dict[f"{dst_prefix}.layers.{i}.self_attn.q_proj.weight"] = in_proj_weight[:256, :]
dst_state_dict[f"{dst_prefix}.layers.{i}.self_attn.q_proj.bias"] = in_proj_bias[:256]
dst_state_dict[f"{dst_prefix}.layers.{i}.self_attn.k_proj.weight"] = in_proj_weight[256:512, :]
dst_state_dict[f"{dst_prefix}.layers.{i}.self_attn.k_proj.bias"] = in_proj_bias[256:512]
dst_state_dict[f"{dst_prefix}.layers.{i}.self_attn.v_proj.weight"] = in_proj_weight[-256:, :]
dst_state_dict[f"{dst_prefix}.layers.{i}.self_attn.v_proj.bias"] = in_proj_bias[-256:]
# read in weights + bias of input projection layer of cross-attention
in_proj_weight_cross_attn = src_state_dict.pop(f"{src_prefix}.layers.{i}.multihead_attn.in_proj_weight")
in_proj_bias_cross_attn = src_state_dict.pop(f"{src_prefix}.layers.{i}.multihead_attn.in_proj_bias")
# next, add query, keys and values (in that order) of cross-attention to the state dict
dst_state_dict[f"{dst_prefix}.layers.{i}.encoder_attn.q_proj.weight"] = in_proj_weight_cross_attn[:256, :]
dst_state_dict[f"{dst_prefix}.layers.{i}.encoder_attn.q_proj.bias"] = in_proj_bias_cross_attn[:256]
dst_state_dict[f"{dst_prefix}.layers.{i}.encoder_attn.k_proj.weight"] = in_proj_weight_cross_attn[
256:512, :
]
dst_state_dict[f"{dst_prefix}.layers.{i}.encoder_attn.k_proj.bias"] = in_proj_bias_cross_attn[256:512]
dst_state_dict[f"{dst_prefix}.layers.{i}.encoder_attn.v_proj.weight"] = in_proj_weight_cross_attn[-256:, :]
dst_state_dict[f"{dst_prefix}.layers.{i}.encoder_attn.v_proj.bias"] = in_proj_bias_cross_attn[-256:]
def replace_detr_decoder(self, dst_state_dict: StateDict, src_state_dict: StateDict):
dst_prefix: str = "transformer_module.decoder"
src_prefix: str = "sem_seg_head.predictor.transformer.decoder"
renamed_keys = self.rename_keys_in_detr_decoder(dst_state_dict, src_state_dict)
# add more
renamed_keys.extend(
[
(f"{src_prefix}.norm.weight", f"{dst_prefix}.layernorm.weight"),
(f"{src_prefix}.norm.bias", f"{dst_prefix}.layernorm.bias"),
]
)
self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
self.replace_q_k_v_in_detr_decoder(dst_state_dict, src_state_dict)
def replace_transformer_module(self, dst_state_dict: StateDict, src_state_dict: StateDict):
dst_prefix: str = "transformer_module"
src_prefix: str = "sem_seg_head.predictor"
self.replace_detr_decoder(dst_state_dict, src_state_dict)
renamed_keys = [
(f"{src_prefix}.query_embed.weight", f"{dst_prefix}.queries_embedder.weight"),
(f"{src_prefix}.input_proj.weight", f"{dst_prefix}.input_projection.weight"),
(f"{src_prefix}.input_proj.bias", f"{dst_prefix}.input_projection.bias"),
]
self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
def replace_instance_segmentation_module(self, dst_state_dict: StateDict, src_state_dict: StateDict):
# NOTE in our case we don't have a prefix, thus we removed the "." from the keys later on!
dst_prefix: str = ""
src_prefix: str = "sem_seg_head.predictor"
renamed_keys = [
(f"{src_prefix}.class_embed.weight", f"{dst_prefix}class_predictor.weight"),
(f"{src_prefix}.class_embed.bias", f"{dst_prefix}class_predictor.bias"),
]
mlp_len = 3
for i in range(mlp_len):
renamed_keys.extend(
[
(f"{src_prefix}.mask_embed.layers.{i}.weight", f"{dst_prefix}mask_embedder.{i}.0.weight"),
(f"{src_prefix}.mask_embed.layers.{i}.bias", f"{dst_prefix}mask_embedder.{i}.0.bias"),
]
)
logger.info(f"Replacing keys {pformat(renamed_keys)}")
self.pop_all(renamed_keys, dst_state_dict, src_state_dict)
def convert(self, mask_former: MaskFormerModel) -> MaskFormerModel:
dst_state_dict = TrackedStateDict(mask_former.state_dict())
src_state_dict = self.original_model.state_dict()
self.replace_pixel_module(dst_state_dict, src_state_dict)
self.replace_transformer_module(dst_state_dict, src_state_dict)
logger.info(f"Missed keys are {pformat(dst_state_dict.diff())}")
logger.info(f"Not copied keys are {pformat(src_state_dict.keys())}")
logger.info("🙌 Done")
mask_former.load_state_dict(dst_state_dict)
return mask_former
def convert_instance_segmentation(
self, mask_former: MaskFormerForInstanceSegmentation
) -> MaskFormerForInstanceSegmentation:
dst_state_dict = TrackedStateDict(mask_former.state_dict())
src_state_dict = self.original_model.state_dict()
self.replace_instance_segmentation_module(dst_state_dict, src_state_dict)
mask_former.load_state_dict(dst_state_dict)
return mask_former
@staticmethod
def using_dirs(checkpoints_dir: Path, config_dir: Path) -> Iterator[Tuple[Path, Path]]:
checkpoints: List[Path] = list(checkpoints_dir.glob("**/*.pkl"))
for checkpoint in checkpoints:
logger.info(f"💪 Converting {checkpoint.stem}")
# find associated config file
config: Path = config_dir / checkpoint.parents[0].stem / "swin" / f"{checkpoint.stem}.yaml"
yield config, checkpoint
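# Expected on-disk layout (file names below are illustrative), matching the --checkpoints_dir and
# --configs_dir arguments parsed in the __main__ block:
#   <checkpoints_dir>/<dataset_name>/<checkpoint_name>.pkl
#   <configs_dir>/<dataset_name>/swin/<checkpoint_name>.yaml
# e.g. checkpoints/ade20k-150/maskformer_swin_base_IN21k_384_bs16_160k_res640.pkl pairs with
#      configs/ade20k-150/swin/maskformer_swin_base_IN21k_384_bs16_160k_res640.yaml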
def test(original_model, our_model: MaskFormerForInstanceSegmentation):
with torch.no_grad():
original_model = original_model.eval()
our_model = our_model.eval()
im = prepare_img()
tr = T.Compose(
[
T.Resize((384, 384)),
T.ToTensor(),
T.Normalize(
mean=torch.tensor([123.675, 116.280, 103.530]) / 255.0,
std=torch.tensor([58.395, 57.120, 57.375]) / 255.0,
),
],
)
x = tr(im).unsqueeze(0)
original_model_backbone_features = original_model.backbone(x.clone())
our_model_output: MaskFormerModelOutput = our_model.model(x.clone(), output_hidden_states=True)
for original_model_feature, our_model_feature in zip(
original_model_backbone_features.values(), our_model_output.encoder_hidden_states
):
assert torch.allclose(
original_model_feature, our_model_feature, atol=1e-3
), "The backbone features are not the same."
original_model_pixel_out = original_model.sem_seg_head.pixel_decoder.forward_features(
original_model_backbone_features
)
assert torch.allclose(
original_model_pixel_out[0], our_model_output.pixel_decoder_last_hidden_state, atol=1e-4
), "The pixel decoder feature are not the same"
# let's test the full model
original_model_out = original_model([{"image": x.squeeze(0)}])
original_segmentation = original_model_out[0]["sem_seg"]
our_model_out: MaskFormerForInstanceSegmentationOutput = our_model(x)
feature_extractor = MaskFormerFeatureExtractor()
our_segmentation = feature_extractor.post_process_segmentation(our_model_out, target_size=(384, 384))
assert torch.allclose(
original_segmentation, our_segmentation, atol=1e-3
), "The segmentation image is not the same."
logger.info("✅ Test passed!")
def get_name(checkpoint_file: Path):
model_name_raw: str = checkpoint_file.stem
# model_name_raw is something like maskformer_panoptic_swin_base_IN21k_384_bs64_554k
parent_name: str = checkpoint_file.parents[0].stem
backbone = "swin"
dataset = ""
if "coco" in parent_name:
dataset = "coco"
elif "ade" in parent_name:
dataset = "ade"
else:
raise ValueError(f"{parent_name} must be wrong since we didn't find 'coco' or 'ade' in it ")
backbone_types = ["tiny", "small", "base", "large"]
backbone_type = list(filter(lambda x: x in model_name_raw, backbone_types))[0]
model_name = f"maskformer-{backbone}-{backbone_type}-{dataset}"
return model_name
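# Illustrative example (hypothetical path):
#   get_name(Path("checkpoints/ade20k-150/maskformer_panoptic_swin_base_IN21k_384_bs64_554k.pkl"))
#   -> "maskformer-swin-base-ade"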
if __name__ == "__main__":
parser = ArgumentParser(
description="Command line tool to convert the original MaskFormer models (with Swin backbone) to our implementation."
)
parser.add_argument(
"--checkpoints_dir",
type=Path,
help="A directory containing the model's checkpoints. The directory has to have the following structure: <DIR_NAME>/<DATASET_NAME>/<CONFIG_NAME>.pkl",
)
parser.add_argument(
"--configs_dir",
type=Path,
help="A directory containing the model's configs, see detectron2 doc. The directory has to have the following structure: <DIR_NAME>/<DATASET_NAME>/<CONFIG_NAME>.yaml",
)
parser.add_argument(
"--pytorch_dump_folder_path",
required=True,
type=Path,
help="Path to the folder to output PyTorch models.",
)
parser.add_argument(
"--maskformer_dir",
required=True,
type=Path,
help="A path to MaskFormer's original implementation directory. You can download from here: https://github.com/facebookresearch/MaskFormer",
)
args = parser.parse_args()
checkpoints_dir: Path = args.checkpoints_dir
config_dir: Path = args.configs_dir
save_directory: Path = args.pytorch_dump_folder_path
maskformer_dir: Path = args.maskformer_dir
# append the parent directory of the MaskFormer repo to the path
sys.path.append(str(maskformer_dir.parent))
# and import what's needed
from MaskFormer.mask_former import add_mask_former_config
from MaskFormer.mask_former.mask_former_model import MaskFormer as OriginalMaskFormer
if not save_directory.exists():
save_directory.mkdir(parents=True)
for config_file, checkpoint_file in OriginalMaskFormerCheckpointToOursConverter.using_dirs(
checkpoints_dir, config_dir
):
feature_extractor = OriginalMaskFormerConfigToFeatureExtractorConverter()(
setup_cfg(Args(config_file=config_file))
)
original_config = setup_cfg(Args(config_file=config_file))
mask_former_kwargs = OriginalMaskFormer.from_config(original_config)
original_model = OriginalMaskFormer(**mask_former_kwargs).eval()
DetectionCheckpointer(original_model).load(str(checkpoint_file))
config: MaskFormerConfig = OriginalMaskFormerConfigToOursConverter()(original_config)
mask_former = MaskFormerModel(config=config).eval()
converter = OriginalMaskFormerCheckpointToOursConverter(original_model, config)
maskformer = converter.convert(mask_former)
mask_former_for_instance_segmentation = MaskFormerForInstanceSegmentation(config=config).eval()
mask_former_for_instance_segmentation.model = mask_former
mask_former_for_instance_segmentation = converter.convert_instance_segmentation(
mask_former_for_instance_segmentation
)
test(original_model, mask_former_for_instance_segmentation)
model_name = get_name(checkpoint_file)
logger.info(f"🪄 Saving {model_name}")
feature_extractor.save_pretrained(save_directory / model_name)
mask_former_for_instance_segmentation.save_pretrained(save_directory / model_name)
feature_extractor.push_to_hub(
repo_path_or_name=save_directory / model_name,
commit_message="Add model",
use_temp_dir=True,
)
mask_former_for_instance_segmentation.push_to_hub(
repo_path_or_name=save_directory / model_name,
commit_message="Add model",
use_temp_dir=True,
)
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature extractor class for MaskFormer."""
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import numpy as np
from PIL import Image
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...file_utils import TensorType, is_torch_available
from ...image_utils import ImageFeatureExtractionMixin, ImageInput, is_torch_tensor
from ...utils import logging
if is_torch_available():
import torch
from torch import Tensor, nn
from torch.nn.functional import interpolate
if TYPE_CHECKING:
from transformers.models.maskformer.modeling_maskformer import MaskFormerForInstanceSegmentationOutput
logger = logging.get_logger(__name__)
class MaskFormerFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
r"""
Constructs a MaskFormer feature extractor. The feature extractor can be used to prepare image(s) and optional
targets for the model.
This feature extractor inherits from [`FeatureExtractionMixin`] which contains most of the main methods. Users
should refer to this superclass for more information regarding those methods.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the input to a certain `size`.
size (`int`, *optional*, defaults to 800):
Resize the input to the given size. Only has an effect if `do_resize` is set to `True`. If size is a
sequence like `(width, height)`, output size will be matched to this. If size is an int, smaller edge of
the image will be matched to this number. i.e., if `height > width`, then the image will be rescaled to `(size *
height / width, size)`.
max_size (`int`, *optional*, defaults to 1333):
The largest size an image dimension can have (otherwise it's capped). Only has an effect if `do_resize` is
set to `True`.
size_divisibility (`int`, *optional*, defaults to 32):
Some backbones need images divisible by a certain number. If not passed, it defaults to the value used in
Swin Transformer.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to normalize the input with mean and standard deviation.
image_mean (`List[float]`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
The sequence of means for each channel, to be used when normalizing images. Defaults to the ImageNet mean.
image_std (`List[float]`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
The sequence of standard deviations for each channel, to be used when normalizing images. Defaults to the
ImageNet std.
ignore_index (`int`, *optional*, defaults to 255):
Value of the index (label) to ignore.
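Example (a minimal preprocessing sketch; the checkpoint name and image URL are illustrative):
```python
>>> from transformers import MaskFormerFeatureExtractor
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = MaskFormerFeatureExtractor.from_pretrained("facebook/maskformer-swin-base-ade")
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> list(inputs.keys())
['pixel_values', 'pixel_mask']
```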
"""
model_input_names = ["pixel_values", "pixel_mask"]
def __init__(
self,
do_resize=True,
size=800,
max_size=1333,
size_divisibility=32,
do_normalize=True,
image_mean=None,
image_std=None,
ignore_index=255,
**kwargs
):
super().__init__(**kwargs)
self.do_resize = do_resize
self.size = size
self.max_size = max_size
self.size_divisibility = size_divisibility
self.ignore_index = ignore_index
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else [0.485, 0.456, 0.406] # ImageNet mean
self.image_std = image_std if image_std is not None else [0.229, 0.224, 0.225] # ImageNet std
def _resize(self, image, size, target=None, max_size=None):
"""
Resize the image to the given size. Size can be min_size (scalar) or (width, height) tuple. If size is an int,
smaller edge of the image will be matched to this number.
If given, also resize the target accordingly.
"""
if not isinstance(image, Image.Image):
image = self.to_pil_image(image)
def get_size_with_aspect_ratio(image_size, size, max_size=None):
width, height = image_size
if max_size is not None:
min_original_size = float(min((width, height)))
max_original_size = float(max((width, height)))
if max_original_size / min_original_size * size > max_size:
size = int(round(max_size * min_original_size / max_original_size))
if (width <= height and width == size) or (height <= width and height == size):
return (height, width)
if width < height:
output_width = size
output_height = int(size * height / width)
else:
output_height = size
output_width = int(size * width / height)
return (output_height, output_width)
def get_size(image_size, size, max_size=None):
if isinstance(size, (list, tuple)):
return size
else:
# size returned must be (width, height) since we use PIL to resize images
# so we reverse the tuple
return get_size_with_aspect_ratio(image_size, size, max_size)[::-1]
width, height = get_size(image.size, size, max_size)
if self.size_divisibility > 0:
height = int(np.ceil(height / self.size_divisibility)) * self.size_divisibility
width = int(np.ceil(width / self.size_divisibility)) * self.size_divisibility
size = (width, height)
rescaled_image = self.resize(image, size=size)
has_target = target is not None
if has_target:
target = target.copy()
# store original_size
target["original_size"] = image.size
if "masks" in target:
masks = torch.from_numpy(target["masks"])[:, None].float()
# use PyTorch as current workaround
# TODO replace by self.resize
interpolated_masks = (
nn.functional.interpolate(masks, size=(height, width), mode="nearest")[:, 0] > 0.5
).float()
target["masks"] = interpolated_masks.numpy()
return rescaled_image, target
def __call__(
self,
images: ImageInput,
annotations: Union[List[Dict], List[List[Dict]]] = None,
pad_and_return_pixel_mask: Optional[bool] = True,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
) -> BatchFeature:
"""
Main method to prepare for the model one or several image(s) and optional annotations. Images are by default
padded up to the largest image in a batch, and a pixel mask is created that indicates which pixels are
real/which are padding.
<Tip warning={true}>
NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so it is most efficient to pass
PIL images.
</Tip>
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
annotations (`Dict`, `List[Dict]`, *optional*):
The corresponding annotations as a dictionary of NumPy arrays with the following keys:
- **masks** (`np.ndarray`) The target mask of shape `(num_classes, height, width)`.
- **labels** (`np.ndarray`) The target labels of shape `(num_classes)`.
pad_and_return_pixel_mask (`bool`, *optional*, defaults to `True`):
Whether or not to pad images up to the largest image in a batch and create a pixel mask.
If left to the default, will return a pixel mask that is:
- 1 for pixels that are real (i.e. **not masked**),
- 0 for pixels that are padding (i.e. **masked**).
return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor`
objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **pixel_values** -- Pixel values to be fed to a model.
- **pixel_mask** -- Pixel mask to be fed to a model (when `pad_and_return_pixel_mask=True` or if
*"pixel_mask"* is in `self.model_input_names`).
- **mask_labels** -- Optional mask labels of shape `(batch_size, num_classes, height, width)` to be fed to a
model (when `annotations` are provided).
- **class_labels** -- Optional class labels of shape `(batch_size, num_classes)` to be fed to a model (when
`annotations` are provided).
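Example of the expected `annotations` format (a minimal sketch; the image size, mask shapes and labels are illustrative):
```python
>>> import numpy as np
>>> from PIL import Image
>>> from transformers import MaskFormerFeatureExtractor
>>> feature_extractor = MaskFormerFeatureExtractor()
>>> image = Image.new("RGB", (640, 480))
>>> # one annotation per image: binary masks of shape (num_classes, height, width) and labels of shape (num_classes,)
>>> annotation = {"masks": np.random.rand(2, 480, 640) > 0.5, "labels": np.array([0, 1])}
>>> inputs = feature_extractor(images=image, annotations=annotation, return_tensors="pt")
>>> list(inputs.keys())
['pixel_values', 'pixel_mask', 'mask_labels', 'class_labels']
```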
"""
# Input type checking for clearer error
valid_images = False
valid_annotations = False
# Check that images has a valid type
if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
valid_images = True
elif isinstance(images, (list, tuple)):
if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
valid_images = True
if not valid_images:
raise ValueError(
"Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example), "
"`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
)
is_batched = bool(
isinstance(images, (list, tuple))
and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
)
if not is_batched:
images = [images]
if annotations is not None:
annotations = [annotations]
# Check that annotations has a valid type
if annotations is not None:
valid_annotations = type(annotations) is list and "masks" in annotations[0] and "labels" in annotations[0]
if not valid_annotations:
raise ValueError(
"Annotations must of type `Dict` (single image) or `List[Dict]` (batch of images)."
"The annotations must be numpy arrays in the following format:"
"{ 'masks' : the target mask, with shape [C,H,W], 'labels' : the target labels, with shape [C]}"
)
# transformations (resizing + normalization)
if self.do_resize and self.size is not None:
if annotations is not None:
for idx, (image, target) in enumerate(zip(images, annotations)):
image, target = self._resize(image=image, target=target, size=self.size, max_size=self.max_size)
images[idx] = image
annotations[idx] = target
else:
for idx, image in enumerate(images):
images[idx] = self._resize(image=image, target=None, size=self.size, max_size=self.max_size)[0]
if self.do_normalize:
images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
# NOTE: images are always padded since they have to be stacked along the batch dimension
encoded_inputs = self.encode_inputs(
images, annotations, pad_and_return_pixel_mask, return_tensors=return_tensors
)
# Convert to TensorType
tensor_type = return_tensors
if not isinstance(tensor_type, TensorType):
tensor_type = TensorType(tensor_type)
if not tensor_type == TensorType.PYTORCH:
raise ValueError("Only PyTorch is supported for the moment.")
else:
if not is_torch_available():
raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
return encoded_inputs
def _max_by_axis(self, the_list: List[List[int]]) -> List[int]:
maxes = the_list[0]
for sublist in the_list[1:]:
for index, item in enumerate(sublist):
maxes[index] = max(maxes[index], item)
return maxes
def encode_inputs(
self,
pixel_values_list: List["torch.Tensor"],
annotations: Optional[List[Dict]] = None,
pad_and_return_pixel_mask: Optional[bool] = True,
return_tensors: Optional[Union[str, TensorType]] = None,
):
"""
Pad images up to the largest image in a batch and create a corresponding `pixel_mask`.
Args:
pixel_values_list (`List[torch.Tensor]`):
List of images (pixel values) to be padded. Each image should be a tensor of shape `(channels, height,
width)`.
annotations (`Dict`, `List[Dict]`, *optional*):
The corresponding annotations as a dictionary of NumPy arrays with the following keys:
- **masks** (`np.ndarray`) The target mask of shape `(num_classes, height, width)`.
- **labels** (`np.ndarray`) The target labels of shape `(num_classes)`.
pad_and_return_pixel_mask (`bool`, *optional*, defaults to `True`):
Whether or not to pad images up to the largest image in a batch and create a pixel mask.
If left to the default, will return a pixel mask that is:
- 1 for pixels that are real (i.e. **not masked**),
- 0 for pixels that are padding (i.e. **masked**).
return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
If set, will return tensors instead of NumPy arrays. If set to `'pt'`, return PyTorch `torch.Tensor`
objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **pixel_values** -- Pixel values to be fed to a model.
- **pixel_mask** -- Pixel mask to be fed to a model (when `pad_and_return_pixel_mask=True` or if
*"pixel_mask"* is in `self.model_input_names`).
- **mask_labels** -- Optional mask labels of shape `(batch_size, num_classes, height, width)` to be fed to a
model (when `annotations` are provided).
- **class_labels** -- Optional class labels of shape `(batch_size, num_classes)` to be fed to a model (when
`annotations` are provided).
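Example (a minimal padding sketch; the tensor sizes are illustrative):
```python
>>> import torch
>>> from transformers import MaskFormerFeatureExtractor
>>> feature_extractor = MaskFormerFeatureExtractor()
>>> # two images of different sizes, each of shape (channels, height, width)
>>> pixel_values_list = [torch.rand(3, 480, 640), torch.rand(3, 512, 512)]
>>> encoded = feature_extractor.encode_inputs(pixel_values_list, return_tensors="pt")
>>> batch_size, num_channels, height, width = encoded["pixel_values"].shape  # padded to (2, 3, 512, 640)
```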
"""
max_size = self._max_by_axis([list(image.shape) for image in pixel_values_list])
channels, height, width = max_size
pixel_values = []
pixel_mask = []
mask_labels = []
class_labels = []
for idx, image in enumerate(pixel_values_list):
# create padded image
if pad_and_return_pixel_mask:
padded_image = np.zeros((channels, height, width), dtype=np.float32)
padded_image[: image.shape[0], : image.shape[1], : image.shape[2]] = np.copy(image)
image = padded_image
pixel_values.append(image)
# if we have a target, pad it
if annotations:
annotation = annotations[idx]
masks = annotation["masks"]
if pad_and_return_pixel_mask:
padded_masks = np.zeros((masks.shape[0], height, width), dtype=masks.dtype)
padded_masks[:, : masks.shape[1], : masks.shape[2]] = np.copy(masks)
masks = padded_masks
mask_labels.append(masks)
class_labels.append(annotation["labels"])
if pad_and_return_pixel_mask:
# create pixel mask
mask = np.zeros((height, width), dtype=np.int64)
mask[: image.shape[1], : image.shape[2]] = True
pixel_mask.append(mask)
# return as BatchFeature
data = {"pixel_values": pixel_values, "pixel_mask": pixel_mask}
if annotations:
data["mask_labels"] = mask_labels
data["class_labels"] = class_labels
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs
def post_process_segmentation(
self, outputs: "MaskFormerForInstanceSegmentationOutput", target_size: Tuple[int, int] = None
) -> "torch.Tensor":
"""
Converts the output of [`MaskFormerForInstanceSegmentationOutput`] into image segmentation predictions. Only
supports PyTorch.
Args:
outputs ([`MaskFormerForInstanceSegmentationOutput`]):
The outputs from [`MaskFormerForInstanceSegmentation`].
target_size (`Tuple[int, int]`, *optional*):
If set, the `masks_queries_logits` will be resized to `target_size`.
Returns:
`torch.Tensor`:
A tensor of shape `(batch_size, num_labels, height, width)`.
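Example (a minimal end-to-end sketch; the checkpoint name and image size are illustrative):
```python
>>> from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation
>>> from PIL import Image
>>> feature_extractor = MaskFormerFeatureExtractor.from_pretrained("facebook/maskformer-swin-base-ade")
>>> model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-base-ade")
>>> image = Image.new("RGB", (640, 480))
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> segmentation = feature_extractor.post_process_segmentation(outputs, target_size=(480, 640))
>>> batch_size, num_labels, height, width = segmentation.shape
```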
"""
# class_queries_logits has shape [BATCH, QUERIES, CLASSES + 1]
class_queries_logits = outputs.class_queries_logits
# masks_queries_logits has shape [BATCH, QUERIES, HEIGHT, WIDTH]
masks_queries_logits = outputs.masks_queries_logits
if target_size is not None:
masks_queries_logits = interpolate(
masks_queries_logits,
size=target_size,
mode="bilinear",
align_corners=False,
)
# remove the null class `[..., :-1]`
masks_classes = class_queries_logits.softmax(dim=-1)[..., :-1]
# mask probs has shape [BATCH, QUERIES, HEIGHT, WIDTH]
masks_probs = masks_queries_logits.sigmoid()
# now we want to sum over the queries,
# $ out_{c,h,w} = \sum_q p_{q,c} * m_{q,h,w} $
# where $ softmax(p) \in R^{q, c} $ is the mask classes
# and $ sigmoid(m) \in R^{q, h, w}$ is the mask probabilities
# b(atch)q(uery)c(lasses), b(atch)q(uery)h(eight)w(idth)
segmentation = torch.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs)
return segmentation
def remove_low_and_no_objects(self, masks, scores, labels, object_mask_threshold, num_labels):
"""
Filters out the queries that are predicted as the null class or whose score does not exceed
`object_mask_threshold`, returning the kept entries of `masks`, `scores` and `labels`.
Args:
masks (`torch.Tensor`):
A tensor of shape `(num_queries, height, width)`.
scores (`torch.Tensor`):
A tensor of shape `(num_queries)`.
labels (`torch.Tensor`):
A tensor of shape `(num_queries)`.
object_mask_threshold (`float`):
A number between 0 and 1; queries scoring at or below this threshold are discarded.
Raises:
`ValueError`: Raised when the first dimension doesn't match in all input tensors.
Returns:
`Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`: The `masks`, `scores` and `labels` with the null-class
and low-score queries removed.
"""
if not (masks.shape[0] == scores.shape[0] == labels.shape[0]):
raise ValueError("mask, scores and labels must have the same shape!")
to_keep = labels.ne(num_labels) & (scores > object_mask_threshold)
return masks[to_keep], scores[to_keep], labels[to_keep]
def post_process_semantic_segmentation(
self, outputs: "MaskFormerForInstanceSegmentationOutput", target_size: Tuple[int, int] = None
) -> "torch.Tensor":
"""
Converts the output of [`MaskFormerForInstanceSegmentationOutput`] into semantic segmentation predictions. Only
supports PyTorch.
Args:
outputs ([`MaskFormerForInstanceSegmentationOutput`]):
The outputs from [`MaskFormerForInstanceSegmentation`].
Returns:
`torch.Tensor`: A tensor of shape `(batch_size, height, width)`.
"""
segmentation = self.post_process_segmentation(outputs, target_size)
semantic_segmentation = segmentation.argmax(dim=1)
return semantic_segmentation
def post_process_panoptic_segmentation(
self,
outputs: "MaskFormerForInstanceSegmentationOutput",
object_mask_threshold: float = 0.8,
overlap_mask_area_threshold: float = 0.8,
is_thing_map: Optional[Dict[int, bool]] = None,
) -> List[Dict]:
"""
Converts the output of [`MaskFormerForInstanceSegmentationOutput`] into image panoptic segmentation
predictions. Only supports PyTorch.
Args:
outputs ([`MaskFormerForInstanceSegmentationOutput`]):
The outputs from [`MaskFormerForInstanceSegmentation`].
object_mask_threshold (`float`, *optional*, defaults to 0.8):
Queries whose maximum class score is at or below this threshold are discarded.
overlap_mask_area_threshold (`float`, *optional*, defaults to 0.8):
The minimum fraction of a query's original mask area that must survive the per-pixel argmax for the mask
to be kept as a segment.
is_thing_map (`Dict[int, bool]`, *optional*):
Dictionary mapping class indices to either `True` or `False`, depending on whether or not they are a
thing. If not set, defaults to the `is_thing_map` of COCO panoptic.
Returns:
`List[Dict]`: A list of dictionaries, one per image, each dictionary containing two keys:
- **segmentation** -- a tensor of shape `(height, width)` where each pixel represents a `segment_id`.
- **segments** -- a list of dictionaries, each with the following keys:
- **id** -- an integer representing the `segment_id`.
- **category_id** -- an integer representing the segment's label.
- **is_thing** -- a boolean, `True` if `category_id` was in `is_thing_map`, `False` otherwise.
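Example (a minimal sketch; the checkpoint name is illustrative and `is_thing_map` is left to its COCO default):
```python
>>> from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation
>>> from PIL import Image
>>> feature_extractor = MaskFormerFeatureExtractor.from_pretrained("facebook/maskformer-swin-base-ade")
>>> model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-base-ade")
>>> image = Image.new("RGB", (640, 480))
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> results = feature_extractor.post_process_panoptic_segmentation(outputs)
>>> segmentation, segments = results[0]["segmentation"], results[0]["segments"]
```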
"""
if is_thing_map is None:
logger.warning("`is_thing_map` unset. Default to COCO.")
# default to is_thing_map of COCO panoptic
is_thing_map = {i: i <= 90 for i in range(201)}
# class_queries_logits has shape [BATCH, QUERIES, CLASSES + 1]
class_queries_logits = outputs.class_queries_logits
# keep track of the number of labels, subtracting 1 for the null class
num_labels = class_queries_logits.shape[-1] - 1
# masks_queries_logits has shape [BATCH, QUERIES, HEIGHT, WIDTH]
masks_queries_logits = outputs.masks_queries_logits
# since all images are padded, they all have the same spatial dimensions
_, _, height, width = masks_queries_logits.shape
# for each query, the best score and its index
pred_scores, pred_labels = nn.functional.softmax(class_queries_logits, dim=-1).max(-1)
# pred_scores and pred_labels shape = [BATCH, NUM_QUERIES]
mask_probs = masks_queries_logits.sigmoid()
# mask probs has shape [BATCH, QUERIES, HEIGHT, WIDTH]
# now we need to iterate over the batch dimension to correctly process the segmentation obtained from the queries using our thresholds. Even if the predicted masks have the same shape across the batch, after thresholding they won't, so batch-wise operations are impossible
results: List[Dict[str, Tensor]] = []
for (mask_probs, pred_scores, pred_labels) in zip(mask_probs, pred_scores, pred_labels):
mask_probs, pred_scores, pred_labels = self.remove_low_and_no_objects(
mask_probs, pred_scores, pred_labels, object_mask_threshold, num_labels
)
we_detect_something = mask_probs.shape[0] > 0
segmentation = torch.zeros((height, width), dtype=torch.int32, device=mask_probs.device)
segments: List[Dict] = []
if we_detect_something:
current_segment_id = 0
# weight each mask by its score
mask_probs *= pred_scores.view(-1, 1, 1)
# for each pixel, find the query with the highest (score-weighted) mask probability
mask_labels = mask_probs.argmax(0)
# mask_labels shape = [H, W]; each pixel stores the index of its most likely query
stuff_memory_list: Dict[str, int] = {}
# this maps a stuff class to its segment id; it is used to merge all instances of the same stuff class into a single segment
for k in range(pred_labels.shape[0]):
pred_class = pred_labels[k].item()
# check if pred_class is not a "thing", so it can be merged with other instances. For example, class "sky" cannot have more than one instance
is_stuff = not is_thing_map[pred_class]
# get the mask associated with the k-th query
mask_k = mask_labels == k
# compute the area; since the mask is boolean we just need to sum it :)
mask_k_area = mask_k.sum()
# this is the area of the k-th query's own mask, thresholded at 0.5
# TODO: not 100% sure why the k-th query is taken here
original_area = (mask_probs[k] >= 0.5).sum()
mask_does_exist = mask_k_area > 0 and original_area > 0
if mask_does_exist:
# find out how much of the original mask area mask_k covers
area_ratio = mask_k_area / original_area
mask_k_is_overlapping_enough = area_ratio.item() > overlap_mask_area_threshold
if mask_k_is_overlapping_enough:
# merge stuff regions
if pred_class in stuff_memory_list:
current_segment_id = stuff_memory_list[pred_class]
else:
current_segment_id += 1
# then we update our segmentation map with the current segment
segmentation[mask_k] = current_segment_id
segments.append(
{
"id": current_segment_id,
"category_id": pred_class,
"is_thing": not is_stuff,
}
)
if is_stuff:
stuff_memory_list[pred_class] = current_segment_id
results.append({"segmentation": segmentation, "segments": segments})
return results
# coding=utf-8
# Copyright 2022 Meta Platforms, Inc.s and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch MaskFormer model."""
import collections.abc
import math
import random
from dataclasses import dataclass
from numbers import Number
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from torch import Tensor, nn
from transformers.utils import logging
from ...activations import ACT2FN
from ...file_utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_scipy_available,
replace_return_docstrings,
requires_backends,
)
from ...modeling_outputs import BaseModelOutputWithCrossAttentions
from ...modeling_utils import ModuleUtilsMixin, PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ..detr import DetrConfig
from ..swin import SwinConfig
from .configuration_maskformer import MaskFormerConfig
if is_scipy_available():
from scipy.optimize import linear_sum_assignment
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "MaskFormerConfig"
_CHECKPOINT_FOR_DOC = "facebook/maskformer-swin-base-ade"
_FEAT_EXTRACTOR_FOR_DOC = "MaskFormerFeatureExtractor"
MASKFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/maskformer-swin-base-ade",
# See all MaskFormer models at https://huggingface.co/models?filter=maskformer
]
@dataclass
class MaskFormerSwinModelOutputWithPooling(ModelOutput):
"""
Class for MaskFormerSwinModel's outputs that also contains the spatial dimensions of the hidden states.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
Last layer hidden-state after a mean pooling operation.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
hidden_states_spatial_dimensions (`tuple(tuple(int, int))`, *optional*):
A tuple containing the spatial dimension of each `hidden_state` needed to reshape the `hidden_states` to
`batch, channels, height, width`. Due to padding, their spatial size cannot be inferred before the
`forward` method.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
last_hidden_state: torch.FloatTensor = None
pooler_output: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
hidden_states_spatial_dimensions: Tuple[Tuple[int, int]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class MaskFormerSwinBaseModelOutput(ModelOutput):
"""
Class for MaskFormerSwinEncoder's outputs.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
hidden_states_spatial_dimensions (`tuple(tuple(int, int))`, *optional*):
A tuple containing the spatial dimension of each `hidden_state` needed to reshape the `hidden_states` to
`batch, channels, height, width`. Due to padding, their spatial size cannot be inferred before the `forward`
method.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
hidden_states_spatial_dimensions: Tuple[Tuple[int, int]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
# Copied from transformers.models.detr.modeling_detr.DetrDecoderOutput
class DetrDecoderOutput(BaseModelOutputWithCrossAttentions):
"""
Base class for outputs of the DETR decoder. This class adds one attribute to BaseModelOutputWithCrossAttentions,
namely an optional stack of intermediate decoder activations, i.e. the output of each decoder layer, each of them
gone through a layernorm. This is useful when training the model with auxiliary decoding losses.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
the self-attention heads.
cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
used to compute the weighted average in the cross-attention heads.
intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
layernorm.
"""
intermediate_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
class MaskFormerPixelLevelModuleOutput(ModelOutput):
"""
MaskFormer's pixel level module output. It returns both the last and (optionally) the hidden states from the
`encoder` and `decoder`. By default, the `encoder` is a MaskFormerSwin Transformer and the `decoder` is a Feature
Pyramid Network (FPN).
The `encoder_last_hidden_state` is referred to in the paper as **images features**, while the
`decoder_last_hidden_state` is referred to as **pixel embeddings**.
Args:
encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Last hidden states (final feature map) of the last stage of the encoder.
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the model at
the output of each stage.
decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Last hidden states (final feature map) of the last stage of the decoder.
decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the model at
the output of each stage.
"""
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
decoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class MaskFormerPixelDecoderOutput(ModelOutput):
"""
MaskFormer's pixel decoder module output, in practice a Feature Pyramid Network (FPN). It returns the last hidden state
and (optionally) the hidden states.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Last hidden states (final feature map) of the last stage of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, num_channels, height, width)`. Hidden-states of the model at the output of each layer
plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`. Attentions weights from Detr's decoder after the attention softmax, used to compute the
weighted average in the self-attention heads.
"""
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class MaskFormerModelOutput(ModelOutput):
"""
Class for outputs of [`MaskFormerModel`]. This class returns all the needed hidden states to compute the logits.
Args:
encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Last hidden states (final feature map) of the last stage of the encoder model (backbone).
pixel_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Last hidden states (final feature map) of the last stage of the pixel decoder model (FPN).
transformer_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Last hidden states (final feature map) of the last stage of the transformer decoder model.
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the encoder
model at the output of each stage.
pixel_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the pixel
decoder model at the output of each stage.
transformer_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called feature maps) of the
transformer decoder at the output of each stage.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` containing `encoder_hidden_states`, `pixel_decoder_hidden_states` and
`decoder_hidden_states`.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`. Attentions weights from Detr's decoder after the attention softmax, used to compute the
weighted average in the self-attention heads.
"""
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
pixel_decoder_last_hidden_state: Optional[torch.FloatTensor] = None
transformer_decoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
pixel_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
transformer_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class MaskFormerForInstanceSegmentationOutput(ModelOutput):
"""
Class for outputs of [`MaskFormerForInstanceSegmentation`].
This output can be directly passed to [`~MaskFormerFeatureExtractor.post_process_segmentation`] or
[`~MaskFormerFeatureExtractor.post_process_panoptic_segmentation`] depending on the task. Please, see
[`~MaskFormerFeatureExtractor`] for details regarding usage.
Args:
loss (`torch.Tensor`, *optional*):
The computed loss, returned when labels are present.
class_queries_logits (`torch.FloatTensor`):
A tensor of shape `(batch_size, num_queries, num_classes + 1)` representing the proposed classes for each
query. Note the `+ 1` is needed because we incorporate the null class.
masks_queries_logits (`torch.FloatTensor`):
A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each
query.
encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Last hidden states (final feature map) of the last stage of the encoder model (backbone).
pixel_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Last hidden states (final feature map) of the last stage of the pixel decoder model (FPN).
transformer_decoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Last hidden states (final feature map) of the last stage of the transformer decoder model.
encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the encoder
model at the output of each stage.
pixel_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the pixel
decoder model at the output of each stage.
transformer_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the transformer decoder at the output
of each stage.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` containing `encoder_hidden_states`, `pixel_decoder_hidden_states` and
`decoder_hidden_states`.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`. Attentions weights from Detr's decoder after the attention softmax, used to compute the
weighted average in the self-attention heads.
"""
loss: Optional[torch.FloatTensor] = None
class_queries_logits: torch.FloatTensor = None
masks_queries_logits: torch.FloatTensor = None
auxiliary_logits: torch.FloatTensor = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
pixel_decoder_last_hidden_state: Optional[torch.FloatTensor] = None
transformer_decoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
pixel_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
transformer_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
def upsample_like(pixel_values: Tensor, like: Tensor, mode: str = "bilinear") -> Tensor:
"""
A utility function that upsamples `pixel_values` to match the dimension of `like`.
Args:
pixel_values (`torch.Tensor`):
The tensor we wish to upsample.
like (`torch.Tensor`):
The tensor we wish to use as size target.
mode (str, *optional*, defaults to `"bilinear"`):
The interpolation mode.
Returns:
`torch.Tensor`: The upsampled tensor
"""
_, _, height, width = like.shape
upsampled = nn.functional.interpolate(pixel_values, size=(height, width), mode=mode, align_corners=False)
return upsampled
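# Shape sketch (illustrative): upsample_like(torch.rand(1, 8, 16, 16), torch.rand(1, 3, 64, 64))
# returns a tensor of shape (1, 8, 64, 64): only the spatial dimensions are taken from `like`.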
# refactored from original implementation
def dice_loss(inputs: Tensor, labels: Tensor, num_masks: int) -> Tensor:
r"""
Compute the DICE loss, similar to generalized IOU for masks, as follows:
$$ \mathcal{L}_{\text{dice}}(x, y) = 1 - \frac{2 * x \cap y}{x \cup y + 1} $$
In practice, since `labels` is a binary mask (only 0s and 1s), dice can be computed as follows:
$$ \mathcal{L}_{\text{dice}}(x, y) = 1 - \frac{2 * x * y}{x + y + 1} $$
Args:
inputs (`torch.Tensor`):
A tensor representing a mask.
labels (`torch.Tensor`):
A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
(0 for the negative class and 1 for the positive class).
num_masks (`int`):
The number of masks present in the current batch, used for normalization.
Returns:
`torch.Tensor`: The computed loss.
"""
probs = inputs.sigmoid().flatten(1)
numerator = 2 * (probs * labels).sum(-1)
denominator = probs.sum(-1) + labels.sum(-1)
loss = 1 - (numerator + 1) / (denominator + 1)
loss = loss.sum() / num_masks
return loss
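# Quick illustrative sanity check (values are hypothetical): for logits that strongly agree with a binary
# label mask, e.g. inputs = 10 * torch.ones(1, 4) and labels = torch.ones(1, 4), probs ≈ 1, so
# numerator ≈ 8 and denominator ≈ 8, giving a per-mask loss of ≈ 1 - 9/9 = 0.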
# refactored from original implementation
def sigmoid_focal_loss(
inputs: Tensor, labels: Tensor, num_masks: int, alpha: float = 0.25, gamma: float = 2
) -> Tensor:
r"""
Focal loss proposed in [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) originally used in
RetinaNet. The loss is computed as follows:
$$ \mathcal{L}_{\text{focal loss}} = -(1 - p_t)^{\gamma}\log{(p_t)} $$
where \\(\text{CE}(p_t) = -\log{(p_t)}\\), CE is the standard Cross Entropy Loss
Please refer to equation (1,2,3) of the paper for a better understanding.
Args:
inputs (`torch.Tensor`):
A float tensor of arbitrary shape.
labels (`torch.Tensor`):
A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
(0 for the negative class and 1 for the positive class).
num_masks (`int`):
The number of masks present in the current batch, used for normalization.
alpha (float, *optional*, defaults to 0.25):
Weighting factor in range (0,1) to balance positive vs negative examples.
gamma (float, *optional*, defaults to 2.0):
Exponent of the modulating factor \\(1 - p_t\\) to balance easy vs hard examples.
Returns:
`torch.Tensor`: The computed loss.
"""
criterion = nn.BCEWithLogitsLoss(reduction="none")
probs = inputs.sigmoid()
cross_entropy_loss = criterion(inputs, labels)
p_t = probs * labels + (1 - probs) * (1 - labels)
loss = cross_entropy_loss * ((1 - p_t) ** gamma)
if alpha >= 0:
alpha_t = alpha * labels + (1 - alpha) * (1 - labels)
loss = alpha_t * loss
loss = loss.mean(1).sum() / num_masks
return loss
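# Illustrative note: with gamma=0 and alpha=0.5 the expression above reduces to 0.5 times the standard
# binary cross-entropy; increasing gamma down-weights well-classified (easy) pixels.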
# refactored from original implementation
def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
"""
A pair wise version of the dice loss, see `dice_loss` for usage.
Args:
inputs (`torch.Tensor`):
A tensor representing a mask
labels (`torch.Tensor`):
A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
(0 for the negative class and 1 for the positive class).
Returns:
`torch.Tensor`: The computed loss between each pair.
"""
inputs = inputs.sigmoid().flatten(1)
numerator = 2 * torch.einsum("nc,mc->nm", inputs, labels)
# using broadcasting to get a [NUM_QUERIES, NUM_CLASSES] matrix
denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :]
loss = 1 - (numerator + 1) / (denominator + 1)
return loss
# refactored from original implementation
def pair_wise_sigmoid_focal_loss(inputs: Tensor, labels: Tensor, alpha: float = 0.25, gamma: float = 2.0) -> Tensor:
r"""
A pair wise version of the focal loss, see `sigmoid_focal_loss` for usage.
Args:
inputs (`torch.Tensor`):
A tensor representing a mask.
labels (`torch.Tensor`):
A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
(0 for the negative class and 1 for the positive class).
alpha (float, *optional*, defaults to 0.25):
Weighting factor in range (0,1) to balance positive vs negative examples.
gamma (float, *optional*, defaults to 2.0):
Exponent of the modulating factor \\(1 - p_t\\) to balance easy vs hard examples.
Returns:
`torch.Tensor`: The computed loss between each pair.
"""
if alpha < 0:
raise ValueError("alpha must be positive")
height_and_width = inputs.shape[1]
criterion = nn.BCEWithLogitsLoss(reduction="none")
prob = inputs.sigmoid()
cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs))
focal_pos = ((1 - prob) ** gamma) * cross_entropy_loss_pos
focal_pos *= alpha
cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs))
focal_neg = (prob**gamma) * cross_entropy_loss_neg
focal_neg *= 1 - alpha
loss = torch.einsum("nc,mc->nm", focal_pos, labels) + torch.einsum("nc,mc->nm", focal_neg, (1 - labels))
return loss / height_and_width
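# Note on the computation above: because the target masks are binary, the focal loss of query n against
# target m splits into a "positive" term (pixels where the target is 1) and a "negative" term (pixels where
# it is 0). Both terms are precomputed per pixel and combined with a single einsum, avoiding the explicit
# materialization of all N x M query/target pairs.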
# Copied from transformers.models.vit.modeling_vit.to_2tuple
def to_2tuple(x):
if isinstance(x, collections.abc.Iterable):
return x
return (x, x)
# Copied from transformers.models.swin.modeling_swin.window_partition
def window_partition(input_feature, window_size):
"""
Partitions the given input into windows.
"""
batch_size, height, width, num_channels = input_feature.shape
input_feature = input_feature.view(
batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
)
windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
return windows
# Copied from transformers.models.swin.modeling_swin.window_reverse
def window_reverse(windows, window_size, height, width):
"""
Merges windows to produce higher resolution features.
"""
batch_size = int(windows.shape[0] / (height * width / window_size / window_size))
windows = windows.view(batch_size, height // window_size, width // window_size, window_size, window_size, -1)
windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1)
return windows
# Copied from transformers.models.swin.modeling_swin.drop_path
def drop_path(input, drop_prob=0.0, training=False, scale_by_keep=True):
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
if drop_prob == 0.0 or not training:
return input
keep_prob = 1 - drop_prob
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = input.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and scale_by_keep:
random_tensor.div_(keep_prob)
return input * random_tensor
class MaskFormerSwinEmbeddings(nn.Module):
"""
Construct the patch and position embeddings.
"""
def __init__(self, config):
super().__init__()
self.patch_embeddings = MaskFormerSwinPatchEmbeddings(
image_size=config.image_size,
patch_size=config.patch_size,
num_channels=config.num_channels,
embed_dim=config.embed_dim,
)
num_patches = self.patch_embeddings.num_patches
self.patch_grid = self.patch_embeddings.grid_size
if config.use_absolute_embeddings:
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
else:
self.position_embeddings = None
self.norm = nn.LayerNorm(config.embed_dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, pixel_values):
embeddings, output_dimensions = self.patch_embeddings(pixel_values)
embeddings = self.norm(embeddings)
if self.position_embeddings is not None:
embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings, output_dimensions
class MaskFormerSwinPatchEmbeddings(nn.Module):
"""
Image to Patch Embedding, including padding.
"""
def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768):
super().__init__()
image_size = to_2tuple(image_size)
patch_size = to_2tuple(patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
self.num_patches = num_patches
self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
def maybe_pad(self, pixel_values, height, width):
if width % self.patch_size[1] != 0:
pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
pixel_values = nn.functional.pad(pixel_values, pad_values)
if height % self.patch_size[0] != 0:
pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
pixel_values = nn.functional.pad(pixel_values, pad_values)
return pixel_values
def forward(self, pixel_values):
_, _, height, width = pixel_values.shape
# pad the input to be divisible by self.patch_size, if needed
pixel_values = self.maybe_pad(pixel_values, height, width)
embeddings = self.projection(pixel_values)
_, _, height, width = embeddings.shape
output_dimensions = (height, width)
embeddings_flat = embeddings.flatten(2).transpose(1, 2)
return embeddings_flat, output_dimensions
class MaskFormerSwinPatchMerging(nn.Module):
"""
Patch Merging Layer for the MaskFormer model.
Args:
input_resolution (`Tuple[int]`):
Resolution of input feature.
dim (`int`):
Number of input channels.
norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
Normalization layer class.
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def maybe_pad(self, input_feature, width, height):
should_pad = (height % 2 == 1) or (width % 2 == 1)
if should_pad:
pad_values = (0, 0, 0, width % 2, 0, height % 2)
input_feature = nn.functional.pad(input_feature, pad_values)
return input_feature
def forward(self, input_feature, input_dimensions):
height, width = input_dimensions
# `dim` is height * width
batch_size, dim, num_channels = input_feature.shape
input_feature = input_feature.view(batch_size, height, width, num_channels)
        # pad input so that height and width are divisible by 2, if needed
input_feature = self.maybe_pad(input_feature, height, width)
# [batch_size, height/2, width/2, num_channels]
input_feature_0 = input_feature[:, 0::2, 0::2, :]
# [batch_size, height/2, width/2, num_channels]
input_feature_1 = input_feature[:, 1::2, 0::2, :]
# [batch_size, height/2, width/2, num_channels]
input_feature_2 = input_feature[:, 0::2, 1::2, :]
# [batch_size, height/2, width/2, num_channels]
input_feature_3 = input_feature[:, 1::2, 1::2, :]
# batch_size height/2 width/2 4*num_channels
input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # batch_size height/2*width/2 4*C
input_feature = self.norm(input_feature)
input_feature = self.reduction(input_feature)
return input_feature
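# Illustrative sketch (not part of the model): patch merging concatenates every 2x2 neighborhood of
# patches channel-wise (4 * num_channels) and projects it back to 2 * num_channels, halving the
# spatial resolution while doubling the channel dimension:
#
#     merge = MaskFormerSwinPatchMerging(input_resolution=(14, 14), dim=96)
#     merged = merge(torch.randn(1, 14 * 14, 96), (14, 14))
#     # merged.shape == (1, 7 * 7, 192)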
# Copied from transformers.models.swin.modeling_swin.SwinDropPath with Swin->MaskFormerSwin
class MaskFormerSwinDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None, scale_by_keep=True):
super(MaskFormerSwinDropPath, self).__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
def forward(self, input):
return drop_path(input, self.drop_prob, self.training, self.scale_by_keep)
# Copied from transformers.models.swin.modeling_swin.SwinSelfAttention with Swin->MaskFormerSwin
class MaskFormerSwinSelfAttention(nn.Module):
def __init__(self, config, dim, num_heads):
super().__init__()
if dim % num_heads != 0:
raise ValueError(
f"The hidden size ({dim}) is not a multiple of the number of attention " f"heads ({num_heads})"
)
self.num_attention_heads = num_heads
self.attention_head_size = int(dim / num_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.window_size = to_2tuple(config.window_size)
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
)
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
coords_flatten = torch.flatten(coords, 1)
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1)
self.register_buffer("relative_position_index", relative_position_index)
self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
output_attentions=False,
):
batch_size, dim, num_channels = hidden_states.shape
mixed_query_layer = self.query(hidden_states)
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
relative_position_bias = relative_position_bias.view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
attention_scores = attention_scores + relative_position_bias.unsqueeze(0)
if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in MaskFormerSwinModel forward() function)
mask_shape = attention_mask.shape[0]
attention_scores = attention_scores.view(
batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
)
attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0)
attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim)
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
return outputs
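# Note on the relative position bias used above: a window of size (Wh, Ww) admits
# (2 * Wh - 1) * (2 * Ww - 1) distinct relative offsets between any two tokens inside it, so the
# bias table stores one learnable value per offset and per head, and `relative_position_index`
# maps every (query, key) pair of window positions to the row of the table that is added to its
# attention score.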
# Copied from transformers.models.swin.modeling_swin.SwinSelfOutput with Swin->MaskFormerSwin
class MaskFormerSwinSelfOutput(nn.Module):
def __init__(self, config, dim):
super().__init__()
self.dense = nn.Linear(dim, dim)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
# Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->MaskFormerSwin
class MaskFormerSwinAttention(nn.Module):
def __init__(self, config, dim, num_heads):
super().__init__()
self.self = MaskFormerSwinSelfAttention(config, dim, num_heads)
self.output = MaskFormerSwinSelfOutput(config, dim)
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
)
# Prune linear layers
self.self.query = prune_linear_layer(self.self.query, index)
self.self.key = prune_linear_layer(self.self.key, index)
self.self.value = prune_linear_layer(self.self.value, index)
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
# Update hyper params and store pruned heads
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)
def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False):
self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
# Copied from transformers.models.swin.modeling_swin.SwinIntermediate with Swin->MaskFormerSwin
class MaskFormerSwinIntermediate(nn.Module):
def __init__(self, config, dim):
super().__init__()
self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
# Copied from transformers.models.swin.modeling_swin.SwinOutput with Swin->MaskFormerSwin
class MaskFormerSwinOutput(nn.Module):
def __init__(self, config, dim):
super().__init__()
self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
return hidden_states
class MaskFormerSwinBlock(nn.Module):
def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.shift_size = shift_size
self.window_size = config.window_size
self.input_resolution = input_resolution
self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.attention = MaskFormerSwinAttention(config, dim, num_heads)
self.drop_path = (
MaskFormerSwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
)
self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
self.intermediate = MaskFormerSwinIntermediate(config, dim)
self.output = MaskFormerSwinOutput(config, dim)
def get_attn_mask(self, input_resolution):
if self.shift_size > 0:
# calculate attention mask for SW-MSA
height, width = input_resolution
img_mask = torch.zeros((1, height, width, 1))
height_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
width_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
count = 0
for height_slice in height_slices:
for width_slice in width_slices:
img_mask[:, height_slice, width_slice, :] = count
count += 1
mask_windows = window_partition(img_mask, self.window_size)
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
return attn_mask
def maybe_pad(self, hidden_states, height, width):
pad_left = pad_top = 0
        pad_right = (self.window_size - width % self.window_size) % self.window_size
        pad_bottom = (self.window_size - height % self.window_size) % self.window_size
        pad_values = (0, 0, pad_left, pad_right, pad_top, pad_bottom)
hidden_states = nn.functional.pad(hidden_states, pad_values)
return hidden_states, pad_values
def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False):
height, width = input_dimensions
batch_size, dim, channels = hidden_states.size()
shortcut = hidden_states
hidden_states = self.layernorm_before(hidden_states)
hidden_states = hidden_states.view(batch_size, height, width, channels)
# pad hidden_states to multiples of window size
hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
_, height_pad, width_pad, _ = hidden_states.shape
# cyclic shift
if self.shift_size > 0:
shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_hidden_states = hidden_states
# partition windows
hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
attn_mask = self.get_attn_mask((height_pad, width_pad))
if attn_mask is not None:
attn_mask = attn_mask.to(hidden_states_windows.device)
self_attention_outputs = self.attention(
hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
attention_windows = shifted_windows
was_padded = pad_values[3] > 0 or pad_values[5] > 0
if was_padded:
attention_windows = attention_windows[:, :height, :width, :].contiguous()
attention_windows = attention_windows.view(batch_size, height * width, channels)
hidden_states = shortcut + self.drop_path(attention_windows)
layer_output = self.layernorm_after(hidden_states)
layer_output = self.intermediate(layer_output)
layer_output = hidden_states + self.output(layer_output)
outputs = (layer_output,) + outputs
return outputs
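# Illustrative sketch (not part of the model; `backbone_config` is a hypothetical Swin-style config
# with window_size=7): a block with shift_size=0 applies regular window attention, while
# shift_size=window_size // 2 applies shifted-window attention; either way the output keeps the
# (batch_size, height * width, channels) layout of the input:
#
#     block = MaskFormerSwinBlock(backbone_config, dim=96, input_resolution=(56, 56), num_heads=3, shift_size=0)
#     hidden_states = block(torch.randn(1, 56 * 56, 96), (56, 56))[0]
#     # hidden_states.shape == (1, 56 * 56, 96)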
class MaskFormerSwinLayer(nn.Module):
def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
super().__init__()
self.config = config
self.dim = dim
self.blocks = nn.ModuleList(
[
MaskFormerSwinBlock(
config=config,
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
shift_size=0 if (i % 2 == 0) else config.window_size // 2,
)
for i in range(depth)
]
)
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm)
else:
self.downsample = None
self.pointing = False
def forward(
self, hidden_states, input_dimensions, head_mask=None, output_attentions=False, output_hidden_states=False
):
all_hidden_states = () if output_hidden_states else None
height, width = input_dimensions
for i, block_module in enumerate(self.blocks):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
block_hidden_states = block_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)
hidden_states = block_hidden_states[0]
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.downsample is not None:
height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
output_dimensions = (height, width, height_downsampled, width_downsampled)
hidden_states = self.downsample(hidden_states, input_dimensions)
else:
output_dimensions = (height, width, height, width)
return hidden_states, output_dimensions, all_hidden_states
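# Note (not part of the model): `output_dimensions` packs the spatial size before and after the
# optional patch merging as (height, width, height_downsampled, width_downsampled); the encoder
# below reads the last two entries to get the input resolution of the next stage.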
class MaskFormerSwinEncoder(nn.Module):
def __init__(self, config, grid_size):
super().__init__()
self.num_layers = len(config.depths)
self.config = config
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
self.layers = nn.ModuleList(
[
MaskFormerSwinLayer(
config=config,
dim=int(config.embed_dim * 2**i_layer),
input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
depth=config.depths[i_layer],
num_heads=config.num_heads[i_layer],
drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
downsample=MaskFormerSwinPatchMerging if (i_layer < self.num_layers - 1) else None,
)
for i_layer in range(self.num_layers)
]
)
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
input_dimensions,
head_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
all_hidden_states = () if output_hidden_states else None
all_input_dimensions = ()
all_self_attentions = () if output_attentions else None
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
for i, layer_module in enumerate(self.layers):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_hidden_states, output_dimensions, layer_all_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module), hidden_states, layer_head_mask
)
else:
layer_hidden_states, output_dimensions, layer_all_hidden_states = layer_module(
hidden_states,
input_dimensions,
layer_head_mask,
output_attentions,
output_hidden_states,
)
input_dimensions = (output_dimensions[-2], output_dimensions[-1])
all_input_dimensions += (input_dimensions,)
if output_hidden_states:
all_hidden_states += (layer_all_hidden_states,)
hidden_states = layer_hidden_states
if output_attentions:
all_self_attentions = all_self_attentions + (layer_all_hidden_states[1],)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return MaskFormerSwinBaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
hidden_states_spatial_dimensions=all_input_dimensions,
attentions=all_self_attentions,
)
class MaskFormerSwinModel(nn.Module, ModuleUtilsMixin):
def __init__(self, config, add_pooling_layer=True):
super().__init__()
self.config = config
self.num_layers = len(config.depths)
self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
self.embeddings = MaskFormerSwinEmbeddings(config)
self.encoder = MaskFormerSwinEncoder(config, self.embeddings.patch_grid)
self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
def get_input_embeddings(self):
return self.embeddings.patch_embeddings
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
def forward(
self,
pixel_values=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, len(self.config.depths))
embedding_output, input_dimensions = self.embeddings(pixel_values)
encoder_outputs = self.encoder(
embedding_output,
input_dimensions,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = encoder_outputs.last_hidden_state
sequence_output = self.layernorm(sequence_output)
pooled_output = None
if self.pooler is not None:
pooled_output = self.pooler(sequence_output.transpose(1, 2))
pooled_output = torch.flatten(pooled_output, 1)
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
hidden_states_spatial_dimensions = (input_dimensions,) + encoder_outputs.hidden_states_spatial_dimensions
return MaskFormerSwinModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
hidden_states_spatial_dimensions=hidden_states_spatial_dimensions,
attentions=encoder_outputs.attentions,
)
# Copied from transformers.models.detr.modeling_detr.DetrAttention
class DetrAttention(nn.Module):
"""
Multi-headed attention from 'Attention Is All You Need' paper.
Here, we add position embeddings to the queries and keys (as explained in the DETR paper).
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})."
self.scaling = self.head_dim**-0.5
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
return tensor if position_embeddings is None else tensor + position_embeddings
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[torch.Tensor] = None,
key_value_states: Optional[torch.Tensor] = None,
key_value_position_embeddings: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, embed_dim = hidden_states.size()
# add position embeddings to the hidden states before projecting to queries and keys
if position_embeddings is not None:
hidden_states_original = hidden_states
hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
# add key-value position embeddings to the key value states
if key_value_position_embeddings is not None:
key_value_states_original = key_value_states
key_value_states = self.with_pos_embed(key_value_states, key_value_position_embeddings)
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
# get key, value proj
if is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states_original), -1, bsz)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states_original), -1, bsz)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped
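# Note (not part of the model): position embeddings above are only added to the hidden states that
# feed the query/key projections; the value projection uses the original, un-embedded tensors,
# which is why `hidden_states_original` and `key_value_states_original` are kept around.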
# Copied from transformers.models.detr.modeling_detr.DetrDecoderLayer
class DetrDecoderLayer(nn.Module):
def __init__(self, config: DetrConfig):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = DetrAttention(
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.encoder_attn = DetrAttention(
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[torch.Tensor] = None,
query_position_embeddings: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
):
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
position_embeddings (`torch.FloatTensor`, *optional*):
position embeddings that are added to the queries and keys
in the cross-attention layer.
query_position_embeddings (`torch.FloatTensor`, *optional*):
position embeddings that are added to the queries and keys
in the self-attention layer.
encoder_hidden_states (`torch.FloatTensor`):
cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
position_embeddings=query_position_embeddings,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
# Cross-Attention Block
cross_attn_weights = None
if encoder_hidden_states is not None:
residual = hidden_states
hidden_states, cross_attn_weights = self.encoder_attn(
hidden_states=hidden_states,
position_embeddings=query_position_embeddings,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
key_value_position_embeddings=position_embeddings,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
# Fully Connected
residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states))
hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights, cross_attn_weights)
return outputs
# Copied from transformers.models.detr.modeling_detr._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
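# Illustrative sketch (not part of the model): a padding mask with 1 for real positions and 0 for
# padded ones becomes an additive mask of shape (batch_size, 1, tgt_len, src_len), holding 0.0
# where attention is allowed and the dtype minimum where it is not:
#
#     mask = torch.tensor([[1, 1, 0]])
#     expanded = _expand_mask(mask, torch.float32)
#     # expanded.shape == (1, 1, 3, 3) and expanded[0, 0, :, 2] == torch.finfo(torch.float32).min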
class DetrDecoder(nn.Module):
"""
Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DetrDecoderLayer`].
The decoder updates the query embeddings through multiple self-attention and cross-attention layers.
Some small tweaks for DETR:
- position_embeddings and query_position_embeddings are added to the forward pass.
- if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers.
Args:
config: DetrConfig
"""
def __init__(self, config: DetrConfig):
super().__init__()
self.config = config
self.dropout = config.dropout
self.layerdrop = config.decoder_layerdrop
self.layers = nn.ModuleList([DetrDecoderLayer(config) for _ in range(config.decoder_layers)])
# in DETR, the decoder uses layernorm after the last decoder layer output
self.layernorm = nn.LayerNorm(config.d_model)
self.gradient_checkpointing = False
def forward(
self,
inputs_embeds=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
position_embeddings=None,
query_position_embeddings=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
The query embeddings that are passed into the decoder.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on certain queries. Mask values selected in `[0, 1]`:
- 1 for queries that are **not masked**,
- 0 for queries that are **masked**.
[What are attention masks?](../glossary#attention-mask)
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
of the decoder.
encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected
in `[0, 1]`:
- 1 for pixels that are real (i.e. **not masked**),
- 0 for pixels that are padding (i.e. **masked**).
position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Position embeddings that are added to the queries and keys in each cross-attention layer.
            query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
                Position embeddings that are added to the queries and keys in each self-attention layer.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is not None:
hidden_states = inputs_embeds
input_shape = inputs_embeds.size()[:-1]
combined_attention_mask = None
if attention_mask is not None and combined_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = combined_attention_mask + _expand_mask(
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
)
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
# optional intermediate hidden states
intermediate = () if self.config.auxiliary_loss else None
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
all_hidden_states += (hidden_states,)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop):
continue
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
combined_attention_mask,
encoder_hidden_states,
encoder_attention_mask,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=combined_attention_mask,
position_embeddings=position_embeddings,
query_position_embeddings=query_position_embeddings,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if self.config.auxiliary_loss:
hidden_states = self.layernorm(hidden_states)
intermediate += (hidden_states,)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
# finally, apply layernorm
hidden_states = self.layernorm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
# stack intermediate decoder activations
if self.config.auxiliary_loss:
intermediate = torch.stack(intermediate)
if not return_dict:
return tuple(
v
for v in [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions, intermediate]
if v is not None
)
return DetrDecoderOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
intermediate_hidden_states=intermediate,
)
# refactored from original implementation
class MaskFormerHungarianMatcher(nn.Module):
"""This class computes an assignment between the labels and the predictions of the network.
For efficiency reasons, the labels don't include the no_object. Because of this, in general, there are more
predictions than labels. In this case, we do a 1-to-1 matching of the best predictions, while the others are
un-matched (and thus treated as non-objects).
"""
def __init__(self, cost_class: float = 1.0, cost_mask: float = 1.0, cost_dice: float = 1.0):
"""Creates the matcher
Params:
cost_class (float, *optional*, defaults to 1.0):
This is the relative weight of the classification error in the matching cost.
cost_mask (float, *optional*, defaults to 1.0):
This is the relative weight of the focal loss of the binary mask in the matching cost.
cost_dice (float, *optional*, defaults to 1.0):
This is the relative weight of the dice loss of the binary mask in the matching cost
"""
super().__init__()
if cost_class == 0 and cost_mask == 0 and cost_dice == 0:
raise ValueError("All costs cant be 0")
self.cost_class = cost_class
self.cost_mask = cost_mask
self.cost_dice = cost_dice
@torch.no_grad()
def forward(self, masks_queries_logits, class_queries_logits, mask_labels, class_labels) -> List[Tuple[Tensor]]:
"""Performs the matching
Params:
            masks_queries_logits (`torch.Tensor`):
                A tensor of dim `batch_size, num_queries, height, width` with the predicted masks.
            class_queries_logits (`torch.Tensor`):
                A tensor of dim `batch_size, num_queries, num_classes` with the classification logits.
            class_labels (`torch.Tensor`):
                A tensor of dim `num_target_boxes` (where num_target_boxes is the number of ground-truth objects in
                the target) containing the class labels.
            mask_labels (`torch.Tensor`):
                A tensor of dim `num_target_boxes, height, width` containing the target masks.
Returns:
`List[Tuple[Tensor]]`: A list of size batch_size, containing tuples of (index_i, index_j) where:
- index_i is the indices of the selected predictions (in order)
- index_j is the indices of the corresponding selected labels (in order)
For each batch element, it holds:
len(index_i) = len(index_j) = min(num_queries, num_target_boxes).
"""
indices: List[Tuple[np.array]] = []
preds_masks = masks_queries_logits
preds_probs = class_queries_logits.softmax(dim=-1)
# downsample all masks in one go -> save memory
mask_labels = nn.functional.interpolate(mask_labels, size=preds_masks.shape[-2:], mode="nearest")
# iterate through batch size
for pred_probs, pred_mask, target_mask, labels in zip(preds_probs, preds_masks, mask_labels, class_labels):
# Compute the classification cost. Contrary to the loss, we don't use the NLL,
# but approximate it in 1 - proba[target class].
            # The 1 is a constant that doesn't change the matching, so it can be omitted.
cost_class = -pred_probs[:, labels]
# flatten spatial dimension "q h w -> q (h w)"
num_queries, height, width = pred_mask.shape
pred_mask_flat = pred_mask.view(num_queries, height * width) # [num_queries, H*W]
# same for target_mask "c h w -> c (h w)"
num_channels, height, width = target_mask.shape
target_mask_flat = target_mask.view(num_channels, height * width) # [num_total_labels, H*W]
            # compute the focal loss between each pair of masks -> shape [NUM_QUERIES, NUM_LABELS]
cost_mask = pair_wise_sigmoid_focal_loss(pred_mask_flat, target_mask_flat)
            # Compute the dice loss between each pair of masks -> shape [NUM_QUERIES, NUM_LABELS]
cost_dice = pair_wise_dice_loss(pred_mask_flat, target_mask_flat)
# final cost matrix
cost_matrix = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice
            # do the assignment using the Hungarian algorithm from scipy
assigned_indices: Tuple[np.array] = linear_sum_assignment(cost_matrix.cpu())
indices.append(assigned_indices)
# It could be stacked in one tensor
matched_indices = [
(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices
]
return matched_indices
def __repr__(self):
head = "Matcher " + self.__class__.__name__
body = [
f"cost_class: {self.cost_class}",
f"cost_mask: {self.cost_mask}",
f"cost_dice: {self.cost_dice}",
]
_repr_indent = 4
lines = [head] + [" " * _repr_indent + line for line in body]
return "\n".join(lines)
# copied and adapted from original implementation
class MaskFormerLoss(nn.Module):
def __init__(
self,
num_classes: int,
matcher: MaskFormerHungarianMatcher,
weight_dict: Dict[str, float],
eos_coef: float,
):
"""
        The MaskFormer Loss. The loss is computed very similarly to DETR. The process happens in two steps: 1) we
        compute the Hungarian assignment between the ground truth masks and the outputs of the model, 2) we
        supervise each pair of matched ground-truth / prediction (supervise class and mask).
Args:
num_classes (`int`):
The number of classes.
matcher (`MaskFormerHungarianMatcher`):
                A torch module that computes the assignments between the predictions and the labels.
weight_dict (`Dict[str, float]`):
A dictionary of weights to be applied to the different losses.
eos_coef (`float`):
Weight to apply to the null class.
"""
super().__init__()
requires_backends(self, ["scipy"])
self.num_classes = num_classes
self.matcher = matcher
self.weight_dict = weight_dict
self.eos_coef = eos_coef
empty_weight = torch.ones(self.num_classes + 1)
empty_weight[-1] = self.eos_coef
self.register_buffer("empty_weight", empty_weight)
def loss_labels(
self, class_queries_logits: Tensor, class_labels: Tensor, indices: Tuple[np.array]
) -> Dict[str, Tensor]:
"""Compute the losses related to the labels using cross entropy.
Args:
class_queries_logits (`torch.Tensor`):
A tensor of shape `batch_size, num_queries, num_classes`
            class_labels (`torch.Tensor`):
                A tensor of shape `batch_size, num_labels`
            indices (`Tuple[np.array]`):
The indices computed by the Hungarian matcher.
Returns:
`Dict[str, Tensor]`: A dict of `torch.Tensor` containing the following key:
- **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels.
"""
pred_logits = class_queries_logits
batch_size, num_queries, _ = pred_logits.shape
criterion = nn.CrossEntropyLoss(weight=self.empty_weight)
idx = self._get_predictions_permutation_indices(indices)
        # target_classes_o is a 1D tensor with the matched class labels, concatenated over the batch
target_classes_o = torch.cat([target[j] for target, (_, j) in zip(class_labels, indices)])
# shape = [BATCH, N_QUERIES]
target_classes = torch.full(
(batch_size, num_queries), fill_value=self.num_classes, dtype=torch.int64, device=pred_logits.device
)
target_classes[idx] = target_classes_o
        # CrossEntropyLoss expects logits of shape [BATCH, CLASSES, N_QUERIES], so we permute pred_logits "b q c -> b c q"
pred_logits_permuted = pred_logits.permute(0, 2, 1)
loss_ce = criterion(pred_logits_permuted, target_classes)
losses = {"loss_cross_entropy": loss_ce}
return losses
def loss_masks(
self, masks_queries_logits: Tensor, mask_labels: Tensor, indices: Tuple[np.array], num_masks: int
) -> Dict[str, Tensor]:
"""Compute the losses related to the masks using focal and dice loss.
Args:
masks_queries_logits (`torch.Tensor`):
A tensor of shape `batch_size, num_queries, height, width`
mask_labels (`torch.Tensor`):
A tensor of shape `batch_size, num_queries, height, width`
            indices (`Tuple[np.array]`):
The indices computed by the Hungarian matcher.
            num_masks (`int`):
The number of masks, used for normalization.
Returns:
`Dict[str, Tensor]`: A dict of `torch.Tensor` containing two keys:
- **loss_mask** -- The loss computed using sigmoid focal loss on the predicted and ground truth masks.
            - **loss_dice** -- The loss computed using dice loss on the predicted and ground truth masks.
"""
src_idx = self._get_predictions_permutation_indices(indices)
tgt_idx = self._get_targets_permutation_indices(indices)
pred_masks = masks_queries_logits # shape [BATCH, NUM_QUERIES, H, W]
pred_masks = pred_masks[src_idx] # shape [BATCH * NUM_QUERIES, H, W]
target_masks = mask_labels # shape [BATCH, NUM_QUERIES, H, W]
target_masks = target_masks[tgt_idx] # shape [BATCH * NUM_QUERIES, H, W]
# upsample predictions to the target size, we have to add one dim to use interpolate
pred_masks = nn.functional.interpolate(
pred_masks[:, None], size=target_masks.shape[-2:], mode="bilinear", align_corners=False
)
pred_masks = pred_masks[:, 0].flatten(1)
target_masks = target_masks.flatten(1)
target_masks = target_masks.view(pred_masks.shape)
losses = {
"loss_mask": sigmoid_focal_loss(pred_masks, target_masks, num_masks),
"loss_dice": dice_loss(pred_masks, target_masks, num_masks),
}
return losses
def _get_predictions_permutation_indices(self, indices):
# permute predictions following indices
batch_indices = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
predictions_indices = torch.cat([src for (src, _) in indices])
return batch_indices, predictions_indices
def _get_targets_permutation_indices(self, indices):
# permute labels following indices
batch_indices = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
target_indices = torch.cat([tgt for (_, tgt) in indices])
return batch_indices, target_indices
def get_loss(self, loss, outputs, labels, indices, num_masks):
loss_map = {"labels": self.loss_labels, "masks": self.loss_masks}
if loss not in loss_map:
raise KeyError(f"{loss} not in loss_map")
return loss_map[loss](outputs, labels, indices, num_masks)
def forward(
self,
masks_queries_logits: torch.Tensor,
class_queries_logits: torch.Tensor,
mask_labels: torch.Tensor,
class_labels: torch.Tensor,
auxiliary_predictions: Optional[Dict[str, torch.Tensor]] = None,
) -> Dict[str, Tensor]:
"""
This performs the loss computation.
Args:
masks_queries_logits (`torch.Tensor`):
A tensor of shape `batch_size, num_queries, height, width`
class_queries_logits (`torch.Tensor`):
A tensor of shape `batch_size, num_queries, num_classes`
mask_labels (`torch.Tensor`):
A tensor of shape `batch_size, num_classes, height, width`
class_labels (`torch.Tensor`):
A tensor of shape `batch_size, num_classes`
auxiliary_predictions (`Dict[str, torch.Tensor]`, *optional*):
                If `use_auxiliary_loss` was set to `true` in [`MaskFormerConfig`], this contains the logits from the
                inner layers of the DETR decoder.
Returns:
`Dict[str, Tensor]`: A dict of `torch.Tensor` containing two keys:
- **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels.
- **loss_mask** -- The loss computed using sigmoid focal loss on the predicted and ground truth masks.
            - **loss_dice** -- The loss computed using dice loss on the predicted and ground truth masks.
            If `use_auxiliary_loss` was set to `true` in [`MaskFormerConfig`], the dictionary contains additional
            losses for each auxiliary prediction.
"""
# Retrieve the matching between the outputs of the last layer and the labels
indices = self.matcher(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
        # Compute the average number of target masks across all nodes, for normalization purposes
num_masks: Number = self.get_num_masks(class_labels, device=class_labels.device)
# Compute all the requested losses
losses: Dict[str, Tensor] = {
**self.loss_masks(masks_queries_logits, mask_labels, indices, num_masks),
**self.loss_labels(class_queries_logits, class_labels, indices),
}
# In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
if auxiliary_predictions is not None:
for idx, aux_outputs in enumerate(auxiliary_predictions):
masks_queries_logits = aux_outputs["masks_queries_logits"]
class_queries_logits = aux_outputs["class_queries_logits"]
loss_dict = self.forward(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
loss_dict = {f"{key}_{idx}": value for key, value in loss_dict.items()}
losses.update(loss_dict)
return losses
def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> torch.Tensor:
        # Compute the average number of target masks across all nodes, for normalization purposes
num_masks = class_labels.shape[0]
num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device)
return num_masks_pt
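# Note (not part of the model): the loss above pairs each ground-truth mask with a query via the
# Hungarian matcher, then combines a cross-entropy term over the classes (with `eos_coef`
# down-weighting the "no object" class) with focal and dice terms over the matched masks. The
# returned dictionary therefore holds "loss_cross_entropy", "loss_mask" and "loss_dice", plus one
# extra set of suffixed keys per auxiliary prediction when auxiliary losses are used.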
class MaskFormerSwinTransformerBackbone(nn.Module):
"""
    This class uses [`MaskFormerSwinModel`] to reshape its `hidden_states` from `(batch_size, sequence_length,
    hidden_size)` to `(batch_size, num_channels, height, width)`.
Args:
config (`SwinConfig`):
The configuration used by [`MaskFormerSwinModel`].
"""
def __init__(self, config: SwinConfig):
super().__init__()
self.model = MaskFormerSwinModel(config)
self.hidden_states_norms = nn.ModuleList([nn.LayerNorm(out_shape) for out_shape in self.outputs_shapes])
def forward(self, *args, **kwargs) -> List[Tensor]:
output = self.model(*args, **kwargs, output_hidden_states=True)
hidden_states_permuted: List[Tensor] = []
        # we need to reshape the hidden states to their original spatial dimensions
# skipping the embeddings
hidden_states: Tuple[Tuple[Tensor]] = output.hidden_states[1:]
        # spatial dimensions contain the heights and widths of each stage, starting with the output of the embeddings
spatial_dimensions: Tuple[Tuple[int, int]] = output.hidden_states_spatial_dimensions
for i, (hidden_state, (height, width)) in enumerate(zip(hidden_states, spatial_dimensions)):
norm = self.hidden_states_norms[i]
            # the last element corresponds to the layer's last block output, before patch merging
            hidden_state_unpooled = hidden_state[-1]
            hidden_state_norm = norm(hidden_state_unpooled)
            # our pixel decoder (FPN) expects feature maps of shape (batch_size, num_channels, height, width)
batch_size, _, hidden_size = hidden_state_norm.shape
# reshape our tensor "b (h w) d -> b d h w"
hidden_state_permuted = (
hidden_state_norm.permute(0, 2, 1).view((batch_size, hidden_size, height, width)).contiguous()
)
hidden_states_permuted.append(hidden_state_permuted)
return hidden_states_permuted
@property
def input_resolutions(self) -> List[int]:
return [layer.input_resolution for layer in self.model.encoder.layers]
@property
def outputs_shapes(self) -> List[int]:
return [layer.dim for layer in self.model.encoder.layers]
class MaskFormerFPNConvLayer(nn.Sequential):
def __init__(self, in_features: int, out_features: int, kernel_size: int = 3, padding: int = 1):
"""
        A basic module that applies a convolution, group normalization and a ReLU activation in sequence, as used
        in MaskFormer.
Args:
in_features (`int`):
The number of input features (channels).
out_features (`int`):
                The number of output features (channels).
"""
super().__init__(
nn.Conv2d(in_features, out_features, kernel_size=kernel_size, padding=padding, bias=False),
nn.GroupNorm(32, out_features),
nn.ReLU(inplace=True),
)
class MaskFormerFPNLayer(nn.Module):
def __init__(self, in_features: int, lateral_features: int):
"""
A Feature Pyramid Network Layer (FPN) layer. It creates a feature map by aggregating features from the previous
and backbone layer. Due to the spatial mismatch, the tensor coming from the previous layer is upsampled.
Args:
in_features (`int`):
The number of input features (channels).
lateral_features (`int`):
The number of lateral features (channels).
"""
super().__init__()
self.proj = nn.Sequential(
nn.Conv2d(lateral_features, in_features, kernel_size=1, padding=0, bias=False),
nn.GroupNorm(32, in_features),
)
self.block = MaskFormerFPNConvLayer(in_features, in_features)
def forward(self, down: Tensor, left: Tensor) -> Tensor:
left = self.proj(left)
down = nn.functional.interpolate(down, size=left.shape[-2:], mode="nearest")
down += left
down = self.block(down)
return down
class MaskFormerFPNModel(nn.Module):
def __init__(self, in_features: int, lateral_widths: List[int], feature_size: int = 256):
"""
        Feature Pyramid Network. Given an input tensor and a set of feature maps of different feature/spatial sizes,
        it creates a list of feature maps that all share the same feature (channel) size.
Args:
in_features (`int`):
The number of input features (channels).
lateral_widths (`List[int]`):
A list with the features (channels) size of each lateral connection.
feature_size (int, *optional*, defaults to 256):
The features (channels) of the resulting feature maps.
"""
super().__init__()
self.stem = MaskFormerFPNConvLayer(in_features, feature_size)
self.layers = nn.Sequential(
*[MaskFormerFPNLayer(feature_size, lateral_width) for lateral_width in lateral_widths[::-1]]
)
def forward(self, features: List[Tensor]) -> List[Tensor]:
fpn_features = []
last_feature = features[-1]
other_features = features[:-1]
output = self.stem(last_feature)
for layer, left in zip(self.layers, other_features[::-1]):
output = layer(output, left)
fpn_features.append(output)
return fpn_features
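# Illustrative sketch (not part of the model): the FPN starts from the lowest-resolution backbone
# feature map and walks back towards the higher resolutions, upsampling and fusing each lateral
# connection, so it returns one feature map per lateral connection, all with `feature_size`
# channels. The shapes below are hypothetical backbone outputs:
#
#     fpn = MaskFormerFPNModel(in_features=768, lateral_widths=[96, 192, 384], feature_size=256)
#     features = [torch.randn(1, 96, 56, 56), torch.randn(1, 192, 28, 28),
#                 torch.randn(1, 384, 14, 14), torch.randn(1, 768, 7, 7)]
#     fpn_features = fpn(features)
#     # len(fpn_features) == 3, each with 256 channels, at resolutions 14x14, 28x28 and 56x56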
class MaskFormerPixelDecoder(nn.Module):
def __init__(self, *args, feature_size: int = 256, mask_feature_size: int = 256, **kwargs):
"""
        Pixel Decoder Module proposed in [Per-Pixel Classification is Not All You Need for Semantic
        Segmentation](https://arxiv.org/abs/2107.06278). It first runs the backbone's features through a Feature
        Pyramid Network, creating a list of feature maps. Then, it projects the last one to `mask_feature_size`.
Args:
feature_size (`int`, *optional*, defaults to 256):
The feature size (channel dimension) of the FPN feature maps.
mask_feature_size (`int`, *optional*, defaults to 256):
                The features (channels) of the target masks, \\(C_\epsilon\\) in the paper.
"""
super().__init__()
self.fpn = MaskFormerFPNModel(*args, feature_size=feature_size, **kwargs)
self.mask_projection = nn.Conv2d(feature_size, mask_feature_size, kernel_size=3, padding=1)
def forward(self, features: List[Tensor], output_hidden_states: bool = False) -> MaskFormerPixelDecoderOutput:
fpn_features: List[Tensor] = self.fpn(features)
# we use the last feature map
last_feature_projected = self.mask_projection(fpn_features[-1])
return MaskFormerPixelDecoderOutput(
last_hidden_state=last_feature_projected, hidden_states=tuple(fpn_features) if output_hidden_states else ()
)
# copied and adapted from original implementation, also practically equal to DetrSinePositionEmbedding
class MaskFormerSinePositionEmbedding(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one used by the Attention is all you
need paper, generalized to work on images.
"""
def __init__(
self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: Optional[float] = None
):
super().__init__()
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
self.scale = 2 * torch.pi if scale is None else scale
def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
if mask is None:
mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
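# Illustrative sketch (not part of the model): for a feature map of shape
# (batch_size, num_channels, height, width) the embedding returns a tensor of shape
# (batch_size, 2 * num_pos_feats, height, width), with the sine/cosine features of the y coordinate
# in the first half of the channels and those of the x coordinate in the second half:
#
#     position_embedder = MaskFormerSinePositionEmbedding(num_pos_feats=128, normalize=True)
#     pos = position_embedder(torch.randn(2, 256, 32, 32))
#     # pos.shape == (2, 256, 32, 32)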
class MaskformerMLPPredictionHead(nn.Sequential):
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int = 3):
"""
A classic Multi Layer Perceptron (MLP).
Args:
input_dim (`int`):
The input dimensions.
hidden_dim (`int`):
The hidden dimensions.
output_dim (`int`):
The output dimensions.
num_layers (int, *optional*, defaults to 3):
The number of layers.
"""
in_dims = [input_dim] + [hidden_dim] * (num_layers - 1)
out_dims = [hidden_dim] * (num_layers - 1) + [output_dim]
layers = []
for i, (in_dim, out_dim) in enumerate(zip(in_dims, out_dims)):
layer = nn.Sequential(
nn.Linear(in_dim, out_dim), nn.ReLU(inplace=True) if i < num_layers - 1 else nn.Identity()
)
layers.append(layer)
super().__init__(*layers)
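# Illustrative sketch (not part of the model): with the default num_layers=3 the head is
# Linear -> ReLU -> Linear -> ReLU -> Linear, applied position-wise to the last dimension:
#
#     head = MaskformerMLPPredictionHead(input_dim=256, hidden_dim=256, output_dim=256)
#     out = head(torch.randn(1, 100, 256))
#     # out.shape == (1, 100, 256)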
class MaskFormerPixelLevelModule(nn.Module):
def __init__(self, config: MaskFormerConfig):
"""
Pixel Level Module proposed in [Per-Pixel Classification is Not All You Need for Semantic
Segmentation](https://arxiv.org/abs/2107.06278). It runs the input image through a backbone and a pixel
decoder, generating an image feature map and pixel embeddings.
Args:
config ([`MaskFormerConfig`]):
The configuration used to instantiate this model.
"""
super().__init__()
self.encoder = MaskFormerSwinTransformerBackbone(config.backbone_config)
self.decoder = MaskFormerPixelDecoder(
in_features=self.encoder.outputs_shapes[-1],
feature_size=config.fpn_feature_size,
mask_feature_size=config.mask_feature_size,
lateral_widths=self.encoder.outputs_shapes[:-1],
)
def forward(self, pixel_values: Tensor, output_hidden_states: bool = False) -> MaskFormerPixelLevelModuleOutput:
features: List[Tensor] = self.encoder(pixel_values)
decoder_output: MaskFormerPixelDecoderOutput = self.decoder(features, output_hidden_states)
return MaskFormerPixelLevelModuleOutput(
# the last feature is actually the output from the last layer
encoder_last_hidden_state=features[-1],
decoder_last_hidden_state=decoder_output.last_hidden_state,
encoder_hidden_states=tuple(features) if output_hidden_states else (),
decoder_hidden_states=decoder_output.hidden_states if output_hidden_states else (),
)
class MaskFormerTransformerModule(nn.Module):
"""
The MaskFormer's transformer module.
"""
def __init__(self, in_features: int, config: MaskFormerConfig):
super().__init__()
hidden_size = config.decoder_config.hidden_size
should_project = in_features != hidden_size
self.position_embedder = MaskFormerSinePositionEmbedding(num_pos_feats=hidden_size // 2, normalize=True)
self.queries_embedder = nn.Embedding(config.decoder_config.num_queries, hidden_size)
self.input_projection = nn.Conv2d(in_features, hidden_size, kernel_size=1) if should_project else None
self.decoder = DetrDecoder(config=config.decoder_config)
def forward(
self, image_features: Tensor, output_hidden_states: bool = False, output_attentions: bool = False
) -> DetrDecoderOutput:
if self.input_projection is not None:
image_features = self.input_projection(image_features)
position_embeddings = self.position_embedder(image_features)
# repeat the queries "q c -> b q c"
batch_size = image_features.shape[0]
queries_embeddings = self.queries_embedder.weight.unsqueeze(0).repeat(batch_size, 1, 1)
inputs_embeds = torch.zeros_like(queries_embeddings, requires_grad=True)
batch_size, num_channels, height, width = image_features.shape
# rearrange both image_features and position_embeddings "b c h w -> b (h w) c"
image_features = image_features.view(batch_size, num_channels, height * width).permute(0, 2, 1)
position_embeddings = position_embeddings.view(batch_size, num_channels, height * width).permute(0, 2, 1)
decoder_output: DetrDecoderOutput = self.decoder(
inputs_embeds=inputs_embeds,
attention_mask=None,
encoder_hidden_states=image_features,
encoder_attention_mask=None,
position_embeddings=position_embeddings,
query_position_embeddings=queries_embeddings,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=None,
)
return decoder_output
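# A small sketch of the two rearrangements performed in the transformer module above (shapes are
# illustrative): image features are flattened "b c h w -> b (h w) c" to act as encoder hidden
# states, and the learned queries are repeated "q c -> b q c" over the batch.
import torch

batch_size, num_channels, height, width, num_queries = 2, 256, 16, 16, 100
image_features = torch.randn(batch_size, num_channels, height, width)
flattened = image_features.view(batch_size, num_channels, height * width).permute(0, 2, 1)
assert flattened.shape == (batch_size, height * width, num_channels)

query_weight = torch.randn(num_queries, num_channels)  # stands in for queries_embedder.weight
repeated_queries = query_weight.unsqueeze(0).repeat(batch_size, 1, 1)
assert repeated_queries.shape == (batch_size, num_queries, num_channels)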
MASKFORMER_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
behavior.
Parameters:
config ([`MaskFormerConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
MASKFORMER_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See
[`AutoFeatureExtractor.__call__`] for details.
pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`:
- 1 for pixels that are real (i.e. **not masked**),
- 0 for pixels that are padding (i.e. **masked**).
[What are attention masks?](../glossary#attention-mask)
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
output_attentions (`bool`, *optional*):
Whether or not to return the attention tensors of the DETR decoder's attention layers.
return_dict (`bool`, *optional*):
Whether or not to return a [`~MaskFormerModelOutput`] instead of a plain tuple.
"""
class MaskFormerPreTrainedModel(PreTrainedModel):
config_class = MaskFormerConfig
base_model_prefix = "model"
main_input_name = "pixel_values"
def _init_weights(self, module: nn.Module):
xavier_std = self.config.init_xavier_std
std = self.config.init_std
if isinstance(module, MaskFormerTransformerModule):
if module.input_projection is not None:
nn.init.xavier_uniform_(module.input_projection.weight, gain=xavier_std)
nn.init.constant_(module.input_projection.bias, 0)
# FPN
elif isinstance(module, MaskFormerFPNModel):
nn.init.xavier_uniform_(module.stem[0].weight, gain=xavier_std)
elif isinstance(module, MaskFormerFPNLayer):
nn.init.xavier_uniform_(module.proj[0].weight, gain=xavier_std)
elif isinstance(module, MaskFormerFPNConvLayer):
nn.init.xavier_uniform_(module[0].weight, gain=xavier_std)
# The MLP head
elif isinstance(module, MaskformerMLPPredictionHead):
# The correct initializer could not be found in the original implementation,
# so we fall back to Xavier initialization
for layer in module:
nn.init.xavier_uniform_(layer[0].weight, gain=xavier_std)
nn.init.constant_(layer[0].bias, 0)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
# copied from DETR
if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, MaskFormerSwinEncoder):
module.gradient_checkpointing = value
if isinstance(module, DetrDecoder):
module.gradient_checkpointing = value
@add_start_docstrings(
"The bare MaskFormer Model outputting raw hidden-states without any specific head on top.",
MASKFORMER_START_DOCSTRING,
)
class MaskFormerModel(MaskFormerPreTrainedModel):
def __init__(self, config: MaskFormerConfig):
super().__init__(config)
self.pixel_level_module = MaskFormerPixelLevelModule(config)
self.transformer_module = MaskFormerTransformerModule(
in_features=self.pixel_level_module.encoder.outputs_shapes[-1], config=config
)
self.post_init()
@add_start_docstrings_to_model_forward(MASKFORMER_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=MaskFormerModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="vision",
)
def forward(
self,
pixel_values: Tensor,
pixel_mask: Optional[Tensor] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> MaskFormerModelOutput:
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
batch_size, _, height, width = pixel_values.shape
if pixel_mask is None:
pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device)
pixel_level_module_output: MaskFormerPixelLevelModuleOutput = self.pixel_level_module(
pixel_values, output_hidden_states
)
image_features = pixel_level_module_output.encoder_last_hidden_state
pixel_embeddings = pixel_level_module_output.decoder_last_hidden_state
transformer_module_output: DetrDecoderOutput = self.transformer_module(
image_features, output_hidden_states, output_attentions
)
queries = transformer_module_output.last_hidden_state
encoder_hidden_states = pixel_level_module_output.encoder_hidden_states if output_hidden_states else ()
pixel_decoder_hidden_states = pixel_level_module_output.decoder_hidden_states if output_hidden_states else ()
transformer_decoder_hidden_states = transformer_module_output.hidden_states if output_hidden_states else ()
output = MaskFormerModelOutput(
encoder_last_hidden_state=image_features,
pixel_decoder_last_hidden_state=pixel_embeddings,
transformer_decoder_last_hidden_state=queries,
encoder_hidden_states=encoder_hidden_states,
pixel_decoder_hidden_states=pixel_decoder_hidden_states,
transformer_decoder_hidden_states=transformer_decoder_hidden_states,
hidden_states=encoder_hidden_states + pixel_decoder_hidden_states + transformer_decoder_hidden_states,
attentions=transformer_module_output.attentions,
)
if not return_dict:
output = tuple(v for v in output.values())
return output
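# A hedged usage sketch for the bare MaskFormerModel above; the checkpoint name is the one used in
# the docstring example further below, and the commented shapes depend on the configuration.
import requests
import torch
from PIL import Image

from transformers import MaskFormerFeatureExtractor, MaskFormerModel

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
feature_extractor = MaskFormerFeatureExtractor.from_pretrained("facebook/maskformer-swin-base-ade")
model = MaskFormerModel.from_pretrained("facebook/maskformer-swin-base-ade")

with torch.no_grad():
    outputs = model(**feature_extractor(images=image, return_tensors="pt"))

queries = outputs.transformer_decoder_last_hidden_state  # (batch_size, num_queries, hidden_size)
pixel_embeddings = outputs.pixel_decoder_last_hidden_state  # (batch_size, mask_feature_size, height', width')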
class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
def __init__(self, config: MaskFormerConfig):
super().__init__(config)
self.model = MaskFormerModel(config)
hidden_size = config.decoder_config.hidden_size
# + 1 because we add the "null" class
self.class_predictor = nn.Linear(hidden_size, config.num_labels + 1)
self.mask_embedder = MaskformerMLPPredictionHead(hidden_size, hidden_size, config.mask_feature_size)
self.matcher = MaskFormerHungarianMatcher(
cost_class=1.0, cost_dice=config.dice_weight, cost_mask=config.mask_weight
)
self.weight_dict: Dict[str, float] = {
"loss_cross_entropy": config.cross_entropy_weight,
"loss_mask": config.mask_weight,
"loss_dice": config.dice_weight,
}
self.criterion = MaskFormerLoss(
config.num_labels,
matcher=self.matcher,
weight_dict=self.weight_dict,
eos_coef=config.no_object_weight,
)
self.post_init()
def get_loss_dict(
self,
masks_queries_logits: Tensor,
class_queries_logits: Tensor,
mask_labels: Tensor,
class_labels: Tensor,
auxiliary_logits: List[Dict[str, Tensor]],
) -> Dict[str, Tensor]:
loss_dict: Dict[str, Tensor] = self.criterion(
masks_queries_logits, class_queries_logits, mask_labels, class_labels, auxiliary_logits
)
# weight each loss by `self.weight_dict[<LOSS_NAME>]`
weighted_loss_dict: Dict[str, Tensor] = {
k: v * self.weight_dict[k] for k, v in loss_dict.items() if k in self.weight_dict
}
return weighted_loss_dict
def get_loss(self, loss_dict: Dict[str, Tensor]) -> Tensor:
return sum(loss_dict.values())
def get_logits(self, outputs: MaskFormerModelOutput) -> Tuple[Tensor, Tensor, List[Dict[str, Tensor]]]:
pixel_embeddings = outputs.pixel_decoder_last_hidden_state
# get the auxiliary predictions (one for each decoder's layer)
auxiliary_logits: List[Dict[str, Tensor]] = []
# This code is a bit cumbersome; a cleaner approach would be to always return a list of predictions.
# When auxiliary losses are enabled, the list holds one set of logits per decoder layer.
if self.config.use_auxiliary_loss:
stacked_transformer_decoder_outputs = torch.stack(outputs.transformer_decoder_hidden_states)
classes = self.class_predictor(stacked_transformer_decoder_outputs)
class_queries_logits = classes[-1]
# get the masks
mask_embeddings = self.mask_embedder(stacked_transformer_decoder_outputs)
# sum up over the channels for each embedding
binaries_masks = torch.einsum("lbqc, bchw -> lbqhw", mask_embeddings, pixel_embeddings)
masks_queries_logits = binaries_masks[-1]
# stop at [:-1]: the last decoder layer's predictions are already returned as the main outputs
for aux_binary_masks, aux_classes in zip(binaries_masks[:-1], classes[:-1]):
auxiliary_logits.append(
{"masks_queries_logits": aux_binary_masks, "class_queries_logits": aux_classes}
)
else:
transformer_decoder_hidden_states = outputs.transformer_decoder_last_hidden_state
classes = self.class_predictor(transformer_decoder_hidden_states)
class_queries_logits = classes
# get the masks
mask_embeddings = self.mask_embedder(transformer_decoder_hidden_states)
# sum up over the channels; see the toy einsum sketch after this class
masks_queries_logits = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings)
return class_queries_logits, masks_queries_logits, auxiliary_logits
@add_start_docstrings_to_model_forward(MASKFORMER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=MaskFormerForInstanceSegmentationOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Tensor,
mask_labels: Optional[Tensor] = None,
class_labels: Optional[Tensor] = None,
pixel_mask: Optional[Tensor] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> MaskFormerForInstanceSegmentationOutput:
r"""
mask_labels (`torch.FloatTensor`, *optional*):
The target mask of shape `(num_classes, height, width)`.
class_labels (`torch.LongTensor`, *optional*):
The target labels of shape `(num_classes)`.
Returns:
Examples:
```python
>>> from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = MaskFormerFeatureExtractor.from_pretrained("facebook/maskformer-swin-base-ade")
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-base-ade")
>>> outputs = model(**inputs)
>>> # model predicts class_queries_logits of shape `(batch_size, num_queries, num_labels + 1)`
>>> # and masks_queries_logits of shape `(batch_size, num_queries, height, width)`
>>> class_queries_logits = outputs.class_queries_logits
>>> masks_queries_logits = outputs.masks_queries_logits
>>> # you can pass them to feature_extractor for postprocessing
>>> output = feature_extractor.post_process_segmentation(outputs)
>>> output = feature_extractor.post_process_semantic_segmentation(outputs)
>>> output = feature_extractor.post_process_panoptic_segmentation(outputs)
```
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs: MaskFormerModelOutput = self.model(
pixel_values,
pixel_mask,
output_hidden_states=output_hidden_states,
return_dict=True,
output_attentions=output_attentions,
)
loss, loss_dict, auxiliary_logits = None, None, None
class_queries_logits, masks_queries_logits, auxiliary_logits = self.get_logits(outputs)
if mask_labels is not None and class_labels is not None:
loss_dict: Dict[str, Tensor] = self.get_loss_dict(
masks_queries_logits, class_queries_logits, mask_labels, class_labels, auxiliary_logits
)
loss = self.get_loss(loss_dict)
output = MaskFormerForInstanceSegmentationOutput(
loss=loss,
**outputs,
class_queries_logits=class_queries_logits,
masks_queries_logits=masks_queries_logits,
auxiliary_logits=auxiliary_logits,
)
if not return_dict:
output = tuple(v for v in output.values())
if loss is not None:
output = (loss,) + output
return output
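# A toy sketch of how the mask logits are formed in `get_logits` above: each query's mask
# embedding is combined with the per-pixel embeddings by summing over the channel dimension
# (shapes below are illustrative).
import torch

batch_size, num_queries, mask_feature_size, height, width = 2, 100, 32, 56, 56
mask_embeddings = torch.randn(batch_size, num_queries, mask_feature_size)
pixel_embeddings = torch.randn(batch_size, mask_feature_size, height, width)
masks_queries_logits = torch.einsum("bqc, bchw -> bqhw", mask_embeddings, pixel_embeddings)
assert masks_queries_logits.shape == (batch_size, num_queries, height, width)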
@@ -2401,6 +2401,30 @@ class MarianMTModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
MASKFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
class MaskFormerForInstanceSegmentation(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MaskFormerModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MaskFormerPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MBartForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
......
@@ -80,6 +80,13 @@ class LayoutXLMProcessor(metaclass=DummyObject):
requires_backends(self, ["vision"])
class MaskFormerFeatureExtractor(metaclass=DummyObject):
_backends = ["vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
class PerceiverFeatureExtractor(metaclass=DummyObject):
_backends = ["vision"]
......
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers.file_utils import is_torch_available, is_vision_available
from transformers.testing_utils import require_torch, require_vision
from ..test_feature_extraction_common import FeatureExtractionSavingTestMixin, prepare_image_inputs
if is_torch_available():
import torch
if is_vision_available():
from transformers import MaskFormerFeatureExtractor
if is_vision_available():
from PIL import Image
class MaskFormerFeatureExtractionTester(unittest.TestCase):
def __init__(
self,
parent,
batch_size=7,
num_channels=3,
min_resolution=30,
max_resolution=400,
do_resize=True,
size=32,
max_size=1333, # by setting max_size > max_resolution we're effectively not testing this :p
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
):
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.min_resolution = min_resolution
self.max_resolution = max_resolution
self.do_resize = do_resize
self.size = size
self.max_size = max_size
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.size_divisibility = 0
def prepare_feat_extract_dict(self):
return {
"do_resize": self.do_resize,
"size": self.size,
"max_size": self.max_size,
"do_normalize": self.do_normalize,
"image_mean": self.image_mean,
"image_std": self.image_std,
"size_divisibility": self.size_divisibility,
}
def get_expected_values(self, image_inputs, batched=False):
"""
This function computes the expected height and width when providing images to MaskFormerFeatureExtractor,
assuming do_resize is set to True with a scalar size.
"""
if not batched:
image = image_inputs[0]
if isinstance(image, Image.Image):
w, h = image.size
else:
h, w = image.shape[1], image.shape[2]
if w < h:
expected_height = int(self.size * h / w)
expected_width = self.size
elif w > h:
expected_height = self.size
expected_width = int(self.size * w / h)
else:
expected_height = self.size
expected_width = self.size
else:
expected_values = []
for image in image_inputs:
expected_height, expected_width = self.get_expected_values([image])
expected_values.append((expected_height, expected_width))
expected_height = max(expected_values, key=lambda item: item[0])[0]
expected_width = max(expected_values, key=lambda item: item[1])[1]
return expected_height, expected_width
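# A quick worked example of the resize rule implemented in get_expected_values above (scalar
# size, aspect ratio preserved by matching the shorter side; the numbers are illustrative and
# ignore max_size):
size = 32
height, width = 30, 400  # width > height, so the height is set to `size`
expected_height, expected_width = size, int(size * width / height)
assert (expected_height, expected_width) == (32, 426)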
@require_torch
@require_vision
class MaskFormerFeatureExtractionTest(FeatureExtractionSavingTestMixin, unittest.TestCase):
feature_extraction_class = MaskFormerFeatureExtractor if (is_vision_available() and is_torch_available()) else None
def setUp(self):
self.feature_extract_tester = MaskFormerFeatureExtractionTester(self)
@property
def feat_extract_dict(self):
return self.feature_extract_tester.prepare_feat_extract_dict()
def test_feat_extract_properties(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
self.assertTrue(hasattr(feature_extractor, "image_mean"))
self.assertTrue(hasattr(feature_extractor, "image_std"))
self.assertTrue(hasattr(feature_extractor, "do_normalize"))
self.assertTrue(hasattr(feature_extractor, "do_resize"))
self.assertTrue(hasattr(feature_extractor, "size"))
self.assertTrue(hasattr(feature_extractor, "max_size"))
def test_batch_feature(self):
pass
def test_call_pil(self):
# Initialize feature_extractor
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
# create random PIL images
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
for image in image_inputs:
self.assertIsInstance(image, Image.Image)
# Test not batched input
encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
expected_height, expected_width = self.feature_extract_tester.get_expected_values(image_inputs)
self.assertEqual(
encoded_images.shape,
(1, self.feature_extract_tester.num_channels, expected_height, expected_width),
)
# Test batched
expected_height, expected_width = self.feature_extract_tester.get_expected_values(image_inputs, batched=True)
encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
self.assertEqual(
encoded_images.shape,
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
expected_height,
expected_width,
),
)
def test_call_numpy(self):
# Initialize feature_extractor
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
# create random numpy tensors
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, numpify=True)
for image in image_inputs:
self.assertIsInstance(image, np.ndarray)
# Test not batched input
encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
expected_height, expected_width = self.feature_extract_tester.get_expected_values(image_inputs)
self.assertEqual(
encoded_images.shape,
(1, self.feature_extract_tester.num_channels, expected_height, expected_width),
)
# Test batched
encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
expected_height, expected_width = self.feature_extract_tester.get_expected_values(image_inputs, batched=True)
self.assertEqual(
encoded_images.shape,
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
expected_height,
expected_width,
),
)
def test_call_pytorch(self):
# Initialize feature_extractor
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
# create random PyTorch tensors
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)
# Test not batched input
encoded_images = feature_extractor(image_inputs[0], return_tensors="pt").pixel_values
expected_height, expected_width = self.feature_extract_tester.get_expected_values(image_inputs)
self.assertEqual(
encoded_images.shape,
(1, self.feature_extract_tester.num_channels, expected_height, expected_width),
)
# Test batched
encoded_images = feature_extractor(image_inputs, return_tensors="pt").pixel_values
expected_height, expected_width = self.feature_extract_tester.get_expected_values(image_inputs, batched=True)
self.assertEqual(
encoded_images.shape,
(
self.feature_extract_tester.batch_size,
self.feature_extract_tester.num_channels,
expected_height,
expected_width,
),
)
def test_equivalence_pad_and_create_pixel_mask(self):
# Initialize feature_extractors
feature_extractor_1 = self.feature_extraction_class(**self.feat_extract_dict)
feature_extractor_2 = self.feature_extraction_class(do_resize=False, do_normalize=False)
# create random PyTorch tensors
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False, torchify=True)
for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)
# Test that the `encode_inputs` method and calling the feature extractor directly return the same tensors
encoded_images_with_method = feature_extractor_1.encode_inputs(image_inputs, return_tensors="pt")
encoded_images = feature_extractor_2(image_inputs, return_tensors="pt")
self.assertTrue(
torch.allclose(encoded_images_with_method["pixel_values"], encoded_images["pixel_values"], atol=1e-4)
)
self.assertTrue(
torch.allclose(encoded_images_with_method["pixel_mask"], encoded_images["pixel_mask"], atol=1e-4)
)
def comm_get_feature_extractor_inputs(self, with_annotations=False):
feature_extractor = self.feature_extraction_class(**self.feat_extract_dict)
# prepare image and target
num_classes = 8
batch_size = self.feature_extract_tester.batch_size
annotations = None
if with_annotations:
annotations = [
{
"masks": np.random.rand(num_classes, 384, 384).astype(np.float32),
"labels": (np.random.rand(num_classes) > 0.5).astype(np.int64),
}
for _ in range(batch_size)
]
image_inputs = prepare_image_inputs(self.feature_extract_tester, equal_resolution=False)
inputs = feature_extractor(image_inputs, annotations, return_tensors="pt", pad_and_return_pixel_mask=True)
return inputs
def test_with_size_divisibility(self):
size_divisibilities = [8, 16, 32]
weird_input_sizes = [(407, 802), (582, 1094)]
for size_divisibility in size_divisibilities:
feat_extract_dict = {**self.feat_extract_dict, **{"size_divisibility": size_divisibility}}
feature_extractor = self.feature_extraction_class(**feat_extract_dict)
for weird_input_size in weird_input_sizes:
inputs = feature_extractor([np.ones((3, *weird_input_size))], return_tensors="pt")
pixel_values = inputs["pixel_values"]
# check if divisible
self.assertTrue((pixel_values.shape[-1] % size_divisibility) == 0)
self.assertTrue((pixel_values.shape[-2] % size_divisibility) == 0)
def test_call_with_numpy_annotations(self):
num_classes = 8
batch_size = self.feature_extract_tester.batch_size
inputs = self.comm_get_feature_extractor_inputs(with_annotations=True)
# check the batch_size
for el in inputs.values():
self.assertEqual(el.shape[0], batch_size)
pixel_values = inputs["pixel_values"]
mask_labels = inputs["mask_labels"]
class_labels = inputs["class_labels"]
self.assertEqual(pixel_values.shape[-2], mask_labels.shape[-2])
self.assertEqual(pixel_values.shape[-1], mask_labels.shape[-1])
self.assertEqual(mask_labels.shape[1], class_labels.shape[1])
self.assertEqual(mask_labels.shape[1], num_classes)