Unverified Commit 561b9a8c authored by Sayak Paul, committed by GitHub

[SegFormer] TensorFlow port (#17910)



* add: segformer utils and img. classification.

* add: segmentation layer.

* feat: working implementation of segformer.

* chore: remove unused variable.

* add test, remaining modifications.

* remove: unnecessary files.

* add: rest of the files.
Co-authored-by: matt <rocketknight1@gmail.com>

* chore: remove ModuleList comment.

* chore: apply make style.

* chore: apply make fixup-copies.

* add to check_repo.py

* add decode head to IGNORE_NON_TESTED

* chore: run make style.

* chore: PR comments.

* chore: minor changes to model doc.

* tests: reduction across samples.

* add a note on the space.

* sort imports.

* fix: reduction in loss computation.

* chore: align loss function with that of NER.

* chore: correct utils/documentation_tests.txt
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* chore: simplify the interpolation of logits in loss computation.

* chore: return transposed logits when return_dict=False.

* chore: add link to the tf fine-tuning repo.

* address pr comments.

* address niels's comments.

* remove from_pt=True since tf weights are in.

* remove comment from pt model.

* address niels's comments.
Co-authored-by: matt <rocketknight1@gmail.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
parent 2c5747ed
@@ -278,7 +278,7 @@ Flax), PyTorch, and/or TensorFlow.
| RetriBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
| RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ |
| RoFormer | ✅ | ✅ | ✅ | ✅ | ✅ |
| SegFormer | ❌ | ❌ | ✅ | ✅ | ❌ |
| SEW | ❌ | ❌ | ✅ | ❌ | ❌ |
| SEW-D | ❌ | ❌ | ✅ | ❌ | ❌ |
| Speech Encoder decoder | ❌ | ❌ | ✅ | ❌ | ✅ |
......
@@ -36,13 +36,14 @@ The figure below illustrates the architecture of SegFormer. Taken from the [orig
<img width="600" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/segformer_architecture.png"/>

This model was contributed by [nielsr](https://huggingface.co/nielsr). The TensorFlow version
of the model was contributed by [sayakpaul](https://huggingface.co/sayakpaul). The original code can be found [here](https://github.com/NVlabs/SegFormer).

Tips:

- SegFormer consists of a hierarchical Transformer encoder, and a lightweight all-MLP decoder head.
[`SegformerModel`] is the hierarchical Transformer encoder (which in the paper is also referred to
as Mix Transformer or MiT). [`SegformerForSemanticSegmentation`] adds the all-MLP decoder head on
top to perform semantic segmentation of images. In addition, there's
[`SegformerForImageClassification`] which can be used to - you guessed it - classify images. The
authors of SegFormer first pre-trained the Transformer encoder on ImageNet-1k to classify images. Next, they throw
@@ -51,6 +52,9 @@ Tips:
found on the [hub](https://huggingface.co/models?other=segformer).
- The quickest way to get started with SegFormer is by checking the [example notebooks](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/SegFormer) (which showcase both inference and
fine-tuning on custom data). One can also check out the [blog post](https://huggingface.co/blog/fine-tune-segformer) introducing SegFormer and illustrating how it can be fine-tuned on custom data.
- TensorFlow users should refer to [this repository](https://github.com/deep-diver/segformer-tf-transformers) that shows off-the-shelf inference and fine-tuning.
- One can also check out [this interactive demo on Hugging Face Spaces](https://huggingface.co/spaces/chansung/segformer-tf-transformers)
to try out a SegFormer model on custom images.
- SegFormer works on any input size, as it pads the input to be divisible by `config.patch_sizes`.
- One can use [`SegformerFeatureExtractor`] to prepare images and corresponding segmentation maps
for the model. Note that this feature extractor is fairly basic and does not include all data augmentations used in
@@ -65,7 +69,8 @@ Tips:
used by [`SegformerForSemanticSegmentation`]). However, other datasets use the 0 index as
background class and include this class as part of all labels. In that case, `reduce_labels` should be set to
`False`, as loss should also be computed for the background class.
- As with most models, SegFormer comes in different sizes, the details of which can be found in the table below
(taken from Table 7 of the [original paper](https://arxiv.org/abs/2105.15203)).
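The `reduce_labels` behaviour described above can be illustrated with a small standalone sketch. This is plain Python for illustration, not the actual feature extractor code: the background class (0) is mapped to the ignore index 255 and every other class id is shifted down by one.

```python
IGNORE_INDEX = 255  # pixels with this label are excluded from the loss


def reduce_label(label: int) -> int:
    """Mimic `reduce_labels=True`: drop the background class (0) and
    shift the remaining class ids down by one."""
    if label == 0:
        return IGNORE_INDEX  # background is not part of the label set
    return label - 1


# A few pixel labels from a hypothetical ADE20k-style segmentation map
segmentation_map = [0, 1, 2, 150]
reduced = [reduce_label(p) for p in segmentation_map]
print(reduced)  # [255, 0, 1, 149]
```

With `reduce_labels=False`, no shift would be applied and class 0 would contribute to the loss like any other class.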
| **Model variant** | **Depths** | **Hidden sizes** | **Decoder hidden size** | **Params (M)** | **ImageNet-1k Top 1** |
| :---------------: | ------------- | ------------------- | :---------------------: | :------------: | :-------------------: |
@@ -76,6 +81,10 @@ Tips:
| MiT-b4 | [3, 8, 27, 3] | [64, 128, 320, 512] | 768 | 62.6 | 83.6 |
| MiT-b5 | [3, 6, 40, 3] | [64, 128, 320, 512] | 768 | 82.0 | 83.8 |
Note that MiT in the above table refers to the Mix Transformer encoder backbone introduced in SegFormer. For
SegFormer's results on the segmentation datasets like ADE20k, refer to the [paper](https://arxiv.org/abs/2105.15203).
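The variant details above can also be captured programmatically. The dict below is a hypothetical helper, not the `transformers` API (the real values live in `SegformerConfig`), and it lists only the two variants visible in this diff view; the smaller MiT-b0 to MiT-b3 rows are collapsed.

```python
# Depths and hidden sizes per MiT variant, copied from the table above.
# Illustrative only; the MiT-b0..b3 rows are collapsed in this diff view.
MIT_VARIANTS = {
    "b4": {"depths": [3, 8, 27, 3], "hidden_sizes": [64, 128, 320, 512], "decoder_hidden_size": 768},
    "b5": {"depths": [3, 6, 40, 3], "hidden_sizes": [64, 128, 320, 512], "decoder_hidden_size": 768},
}


def num_encoder_blocks(variant: str) -> int:
    """Total number of Transformer blocks across the four encoder stages."""
    return sum(MIT_VARIANTS[variant]["depths"])


print(num_encoder_blocks("b4"))  # 41
print(num_encoder_blocks("b5"))  # 52
```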
## SegformerConfig

[[autodoc]] SegformerConfig
@@ -104,3 +113,23 @@ Tips:
[[autodoc]] SegformerForSemanticSegmentation
- forward
## TFSegformerDecodeHead
[[autodoc]] TFSegformerDecodeHead
- call
## TFSegformerModel
[[autodoc]] TFSegformerModel
- call
## TFSegformerForImageClassification
[[autodoc]] TFSegformerForImageClassification
- call
## TFSegformerForSemanticSegmentation
[[autodoc]] TFSegformerForSemanticSegmentation
- call
@@ -2430,6 +2430,16 @@ else:
"TFRoFormerPreTrainedModel",
]
)
_import_structure["models.segformer"].extend(
[
"TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFSegformerDecodeHead",
"TFSegformerForImageClassification",
"TFSegformerForSemanticSegmentation",
"TFSegformerModel",
"TFSegformerPreTrainedModel",
]
)
_import_structure["models.speech_to_text"].extend(
[
"TF_SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -4789,6 +4799,14 @@ if TYPE_CHECKING:
TFRoFormerModel,
TFRoFormerPreTrainedModel,
)
from .models.segformer import (
TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFSegformerDecodeHead,
TFSegformerForImageClassification,
TFSegformerForSemanticSegmentation,
TFSegformerModel,
TFSegformerPreTrainedModel,
)
from .models.speech_to_text import (
TF_SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFSpeech2TextForConditionalGeneration,
......
@@ -68,6 +68,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
("resnet", "TFResNetModel"),
("roberta", "TFRobertaModel"),
("roformer", "TFRoFormerModel"),
("segformer", "TFSegformerModel"),
("speech_to_text", "TFSpeech2TextModel"),
("swin", "TFSwinModel"),
("t5", "TFT5Model"),
@@ -180,6 +181,7 @@ TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("deit", ("TFDeiTForImageClassification", "TFDeiTForImageClassificationWithTeacher")),
("regnet", "TFRegNetForImageClassification"),
("resnet", "TFResNetForImageClassification"),
("segformer", "TFSegformerForImageClassification"),
("swin", "TFSwinForImageClassification"),
("vit", "TFViTForImageClassification"),
]
@@ -189,6 +191,7 @@ TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
[
# Model for Semantic Segmentation mapping
("data2vec-vision", "TFData2VecVisionForSemanticSegmentation"),
("segformer", "TFSegformerForSemanticSegmentation"),
]
)
......
@@ -17,7 +17,13 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_tf_available,
is_torch_available,
is_vision_available,
)
_import_structure = {"configuration_segformer": ["SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "SegformerConfig"]}
@@ -46,6 +52,21 @@ else:
"SegformerPreTrainedModel",
]
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_tf_segformer"] = [
"TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"TFSegformerDecodeHead",
"TFSegformerForImageClassification",
"TFSegformerForSemanticSegmentation",
"TFSegformerModel",
"TFSegformerPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_segformer import SEGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, SegformerConfig
@@ -73,7 +94,20 @@ if TYPE_CHECKING:
SegformerModel,
SegformerPreTrainedModel,
)
try:
if not is_tf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_tf_segformer import (
TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TFSegformerDecodeHead,
TFSegformerForImageClassification,
TFSegformerForSemanticSegmentation,
TFSegformerModel,
TFSegformerPreTrainedModel,
)
else:
import sys
......
@@ -785,6 +785,8 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
>>> inputs = feature_extractor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits  # shape (batch_size, num_labels, height, width)
>>> list(logits.shape)
[1, 150, 128, 128]
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = (
@@ -804,7 +806,7 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
loss = None
if labels is not None:
if not self.config.num_labels > 1:
raise ValueError("The number of labels should be greater than one")
else:
# upsample logits to the images' original size
......
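The "upsample logits to the images' original size" step in the loss computation can be illustrated with a dependency-free sketch. The real models resize with bilinear interpolation (`nn.functional.interpolate` in PyTorch, `tf.image.resize` in TensorFlow); the nearest-neighbour version below is only an assumption-light stand-in that shows the shape handling before the loss is computed.

```python
def upsample_nearest(grid, size):
    """Nearest-neighbour upsampling of a 2-D grid (list of rows).
    Stand-in for bilinear interpolation, just to show how low-resolution
    logits are stretched to the label map's resolution."""
    h, w = len(grid), len(grid[0])
    H, W = size
    return [[grid[i * h // H][j * w // W] for j in range(W)] for i in range(H)]


# A 2x2 map of per-pixel scores, upsampled to 4x4 to match the labels
logits = [[0.1, 0.9], [0.4, 0.6]]
up = upsample_nearest(logits, (4, 4))
print(len(up), len(up[0]))  # 4 4
print(up[0])  # [0.1, 0.1, 0.9, 0.9]
```

In the model this happens per class channel, so `(batch, num_labels, h/4, w/4)` logits are brought to the labels' `(height, width)` before the cross-entropy loss.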
@@ -1980,6 +1980,44 @@ class TFRoFormerPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["tf"])
TF_SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
class TFSegformerDecodeHead(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFSegformerForImageClassification(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFSegformerForSemanticSegmentation(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFSegformerModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFSegformerPreTrainedModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
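The dummy classes above all follow one pattern. A minimal standalone imitation of it (not the actual `transformers` implementation; `requires_backends` here is a simplified stand-in that always raises) shows the idea: the placeholder can be imported freely, but instantiating it fails with an informative error about the missing backend.

```python
def requires_backends(obj, backends):
    """Simplified stand-in: pretend none of the optional backends is installed."""
    raise ImportError(
        f"{type(obj).__name__} requires the {backends} backend(s), which are not installed."
    )


class DummyObject(type):
    """Metaclass marking placeholder classes used when a backend is missing."""


class TFSegformerModel(metaclass=DummyObject):
    _backends = ["tf"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])


try:
    TFSegformerModel()
except ImportError as err:
    print(err)  # mentions TFSegformerModel and the missing 'tf' backend
```

This lets `from transformers import TFSegformerModel` succeed even without TensorFlow, deferring the failure to the point of use.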
TF_SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
@@ -18,7 +18,7 @@
import inspect
import unittest
from transformers import SegformerConfig, is_torch_available, is_vision_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, slow, torch_device
@@ -31,7 +31,6 @@ if is_torch_available():
from transformers import (
MODEL_MAPPING,
SegformerForImageClassification,
SegformerForSemanticSegmentation,
SegformerModel,
......
@@ -98,6 +98,7 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
"FlaxBartForCausalLM",  # Building part of bigger (tested) model.
"FlaxBertForCausalLM",  # Building part of bigger (tested) model. Tested implicitly through FlaxRobertaForCausalLM.
"OPTDecoderWrapper",
"TFSegformerDecodeHead", # Not a regular model.
]
# Update this list with test files that don't have a tester with a `all_model_classes` variable and which don't
@@ -137,6 +138,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"PerceiverForMultimodalAutoencoding",
"PerceiverForOpticalFlow",
"SegformerDecodeHead",
"TFSegformerDecodeHead",
"FlaxBeitForMaskedImageModeling",
"PLBartEncoder",
"PLBartDecoder",
......
@@ -64,6 +64,7 @@ src/transformers/models/sew_d/modeling_sew_d.py
src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py
src/transformers/models/speech_to_text/modeling_speech_to_text.py
src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py
src/transformers/models/segformer/modeling_tf_segformer.py
src/transformers/models/swin/modeling_swin.py
src/transformers/models/trocr/modeling_trocr.py
src/transformers/models/unispeech/modeling_unispeech.py
......