Unverified commit ac6aa10f, authored by Sylvain Gugger and committed by GitHub

Standardize semantic segmentation models outputs (#15469)



* Standardize instance segmentation models outputs

* Rename output

* Update src/transformers/modeling_outputs.py
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Add legacy argument to the config and model forward

* Update src/transformers/models/beit/modeling_beit.py
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Copy fix in Segformer
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent 31be2f45
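
In short: semantic segmentation models gain a dedicated auto class and now return logits at the resolution of the input image by default. A minimal sketch of the new usage (the checkpoint name is illustrative; any BEiT or SegFormer semantic segmentation checkpoint should work):

```python
import requests
import torch
from PIL import Image

from transformers import AutoFeatureExtractor, AutoModelForSemanticSegmentation

checkpoint = "nvidia/segformer-b0-finetuned-ade-512-512"  # illustrative
feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
model = AutoModelForSemanticSegmentation.from_pretrained(checkpoint)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = feature_extractor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# After this change, the logits match the input resolution:
# (batch_size, num_labels, height, width)
print(outputs.logits.shape)
```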
@@ -150,6 +150,10 @@ Likewise, if your `NewModel` is a subclass of [`PreTrainedModel`], make sure its
 [[autodoc]] AutoModelForImageSegmentation

+## AutoModelForSemanticSegmentation
+
+[[autodoc]] AutoModelForSemanticSegmentation
+
 ## TFAutoModel

 [[autodoc]] TFAutoModel
...
@@ -688,6 +688,7 @@ if is_torch_available():
         "AutoModelForObjectDetection",
         "AutoModelForPreTraining",
         "AutoModelForQuestionAnswering",
+        "AutoModelForSemanticSegmentation",
         "AutoModelForSeq2SeqLM",
         "AutoModelForSequenceClassification",
         "AutoModelForSpeechSeq2Seq",
@@ -2797,6 +2798,7 @@ if TYPE_CHECKING:
         AutoModelForObjectDetection,
         AutoModelForPreTraining,
         AutoModelForQuestionAnswering,
+        AutoModelForSemanticSegmentation,
         AutoModelForSeq2SeqLM,
         AutoModelForSequenceClassification,
         AutoModelForSpeechSeq2Seq,
...
@@ -812,3 +812,32 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput):
     encoder_last_hidden_state: Optional[torch.FloatTensor] = None
     encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
     encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class SemanticSegmentationModelOutput(ModelOutput):
+    """
+    Base class for outputs of semantic segmentation models.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Classification (or regression if config.num_labels==1) loss.
+        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels, height, width)`):
+            Classification scores for each pixel.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+            shape `(batch_size, patch_size, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
+            sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: torch.FloatTensor = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
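
Since `logits` hold per-pixel class scores at the input resolution, turning them into a segmentation map is a single `argmax`. A small sketch, assuming `outputs` is the return value of a forward pass of one of the converted models:

```python
import torch

# outputs.logits has shape (batch_size, num_labels, height, width)
probs = outputs.logits.softmax(dim=1)   # per-pixel class probabilities
segmentation_map = probs.argmax(dim=1)  # (batch_size, height, width), class indices
```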
@@ -43,6 +43,7 @@ if is_torch_available():
         "MODEL_FOR_OBJECT_DETECTION_MAPPING",
         "MODEL_FOR_PRETRAINING_MAPPING",
         "MODEL_FOR_QUESTION_ANSWERING_MAPPING",
+        "MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING",
         "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
         "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
         "MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
@@ -65,6 +66,7 @@ if is_torch_available():
         "AutoModelForObjectDetection",
         "AutoModelForPreTraining",
         "AutoModelForQuestionAnswering",
+        "AutoModelForSemanticSegmentation",
         "AutoModelForSeq2SeqLM",
         "AutoModelForSequenceClassification",
         "AutoModelForSpeechSeq2Seq",
@@ -155,6 +157,7 @@ if TYPE_CHECKING:
         MODEL_FOR_OBJECT_DETECTION_MAPPING,
         MODEL_FOR_PRETRAINING_MAPPING,
         MODEL_FOR_QUESTION_ANSWERING_MAPPING,
+        MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING,
         MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
         MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
         MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
@@ -177,6 +180,7 @@ if TYPE_CHECKING:
         AutoModelForObjectDetection,
         AutoModelForPreTraining,
         AutoModelForQuestionAnswering,
+        AutoModelForSemanticSegmentation,
         AutoModelForSeq2SeqLM,
         AutoModelForSequenceClassification,
         AutoModelForSpeechSeq2Seq,
...
@@ -278,11 +278,20 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
 MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = OrderedDict(
     [
+        # Do not add new models here, this class will be deprecated in the future.
         # Model for Image Segmentation mapping
         ("detr", "DetrForSegmentation"),
     ]
 )

+MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
+    [
+        # Model for Semantic Segmentation mapping
+        ("beit", "BeitForSemanticSegmentation"),
+        ("segformer", "SegformerForSemanticSegmentation"),
+    ]
+)
+
 MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
     [
         ("vision-encoder-decoder", "VisionEncoderDecoderModel"),
@@ -603,6 +612,9 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
 MODEL_FOR_IMAGE_SEGMENTATION_MAPPING = _LazyAutoMapping(
     CONFIG_MAPPING_NAMES, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES
 )
+MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _LazyAutoMapping(
+    CONFIG_MAPPING_NAMES, MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES
+)
 MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)
 MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES)
 MODEL_FOR_OBJECT_DETECTION_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES)
@@ -745,6 +757,15 @@ class AutoModelForImageSegmentation(_BaseAutoModelClass):
 AutoModelForImageSegmentation = auto_class_update(AutoModelForImageSegmentation, head_doc="image segmentation")

+class AutoModelForSemanticSegmentation(_BaseAutoModelClass):
+    _model_mapping = MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
+
+
+AutoModelForSemanticSegmentation = auto_class_update(
+    AutoModelForSemanticSegmentation, head_doc="semantic segmentation"
+)
+
 class AutoModelForObjectDetection(_BaseAutoModelClass):
     _model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING
...
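
The new mapping slots into the standard auto-class machinery: `_LazyAutoMapping` resolves a config class to its model class without importing anything eagerly. A sketch of the dispatch (checkpoint name is illustrative; note `from_config` builds a randomly initialized model, use `from_pretrained` for weights):

```python
from transformers import AutoConfig, AutoModelForSemanticSegmentation

# Dispatch is keyed on the config type: a SegformerConfig resolves to
# SegformerForSemanticSegmentation, a BeitConfig to BeitForSemanticSegmentation.
config = AutoConfig.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
model = AutoModelForSemanticSegmentation.from_config(config)
print(type(model).__name__)  # SegformerForSemanticSegmentation
```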
@@ -93,6 +93,10 @@ class BeitConfig(PretrainedConfig):
             Whether to concatenate the output of the auxiliary head with the input before the classification layer.
         semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
             The index that is ignored by the loss function of the semantic segmentation model.
+        legacy_output (`bool`, *optional*, defaults to `False`):
+            Whether to return the legacy outputs or not (with logits of shape `height / 4, width / 4`).
+
+            This argument is only present for backward compatibility reasons and will be removed in v5 of Transformers.

     Example:
@@ -141,6 +145,7 @@ class BeitConfig(PretrainedConfig):
         auxiliary_num_convs=1,
         auxiliary_concat_input=False,
         semantic_loss_ignore_index=255,
+        legacy_output=False,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -176,3 +181,4 @@ class BeitConfig(PretrainedConfig):
         self.auxiliary_num_convs = auxiliary_num_convs
         self.auxiliary_concat_input = auxiliary_concat_input
         self.semantic_loss_ignore_index = semantic_loss_ignore_index
+        self.legacy_output = legacy_output
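
Because extra `from_pretrained` keyword arguments are forwarded to the config, the flag can also be set at load time. A sketch (checkpoint name is illustrative):

```python
from transformers import BeitForSemanticSegmentation

# Opt back into the pre-v4.17 output (logits at height/4, width/4) for every
# forward call of this model; the flag is scheduled for removal in v5.
model = BeitForSemanticSegmentation.from_pretrained(
    "microsoft/beit-base-finetuned-ade-640-640", legacy_output=True
)
assert model.config.legacy_output
```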
@@ -17,6 +17,7 @@

 import collections.abc
 import math
+import warnings
 from dataclasses import dataclass

 import torch
@@ -31,7 +32,13 @@ from ...file_utils import (
     add_start_docstrings_to_model_forward,
     replace_return_docstrings,
 )
-from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput
+from ...modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithPooling,
+    MaskedLMOutput,
+    SemanticSegmentationModelOutput,
+    SequenceClassifierOutput,
+)
 from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import logging
 from .configuration_beit import BeitConfig
@@ -1114,11 +1121,8 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

-    def compute_loss(self, logits, auxiliary_logits, labels):
+    def compute_loss(self, upsampled_logits, auxiliary_logits, labels):
         # upsample logits to the images' original size
-        upsampled_logits = nn.functional.interpolate(
-            logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
-        )
         if auxiliary_logits is not None:
             upsampled_auxiliary_logits = nn.functional.interpolate(
                 auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
@@ -1132,7 +1136,7 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
         return loss
     @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
-    @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(output_type=SemanticSegmentationModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         pixel_values=None,
@@ -1141,11 +1145,17 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        legacy_output=None,
     ):
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
             Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
             config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
+        legacy_output (`bool`, *optional*):
+            Whether to return the legacy outputs or not (with logits of shape `height / 4, width / 4`). Will default
+            to `self.config.legacy_output`.
+
+            This argument is only present for backward compatibility reasons and will be removed in v5 of Transformers.

         Returns:
@@ -1164,13 +1174,21 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
         >>> inputs = feature_extractor(images=image, return_tensors="pt")
         >>> outputs = model(**inputs)
-        >>> # logits are of shape (batch_size, num_labels, height/4, width/4)
+        >>> # logits are of shape (batch_size, num_labels, height, width)
         >>> logits = outputs.logits
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
+        legacy_output = legacy_output if legacy_output is not None else self.config.legacy_output
+        if not legacy_output:
+            warnings.warn(
+                "The output of this model has changed in v4.17.0 and the logits now have the same size as the inputs. "
+                "You can activate the previous behavior by passing `legacy_output=True` to this call or the "
+                "configuration of this model (only until v5, then that argument will be removed).",
+                FutureWarning,
+            )

         outputs = self.beit(
             pixel_values,
@@ -1197,6 +1215,11 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
             features[i] = ops[i](features[i])

         logits = self.decode_head(features)
+        upsampled_logits = nn.functional.interpolate(
+            logits, size=pixel_values.shape[-2:], mode="bilinear", align_corners=False
+        )

         auxiliary_logits = None
         if self.auxiliary_head is not None:
             auxiliary_logits = self.auxiliary_head(features)
@@ -1206,18 +1229,26 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
             if self.config.num_labels == 1:
                 raise ValueError("The number of labels should be greater than one")
             else:
-                loss = self.compute_loss(logits, auxiliary_logits, labels)
+                loss = self.compute_loss(upsampled_logits, auxiliary_logits, labels)

         if not return_dict:
             if output_hidden_states:
-                output = (logits,) + outputs[2:]
+                output = (logits if legacy_output else upsampled_logits,) + outputs[2:]
             else:
-                output = (logits,) + outputs[3:]
+                output = (logits if legacy_output else upsampled_logits,) + outputs[3:]
             return ((loss,) + output) if loss is not None else output

-        return SequenceClassifierOutput(
-            loss=loss,
-            logits=logits,
-            hidden_states=outputs.hidden_states if output_hidden_states else None,
-            attentions=outputs.attentions,
-        )
+        if legacy_output:
+            return SequenceClassifierOutput(
+                loss=loss,
+                logits=logits,
+                hidden_states=outputs.hidden_states if output_hidden_states else None,
+                attentions=outputs.attentions,
+            )
+        else:
+            return SemanticSegmentationModelOutput(
+                loss=loss,
+                logits=upsampled_logits,
+                hidden_states=outputs.hidden_states if output_hidden_states else None,
+                attentions=outputs.attentions,
+            )
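
The relation between the two modes is exactly the `interpolate` call added above: the new default logits are the legacy quarter-resolution logits bilinearly upsampled to the input size. A sketch, assuming `model` is a `BeitForSemanticSegmentation` with the default `legacy_output=False` config and `inputs` comes from a feature extractor as in the first example:

```python
import torch
from torch import nn

model.eval()  # disable dropout so both passes are deterministic
with torch.no_grad():
    legacy = model(**inputs, legacy_output=True)  # logits: (B, C, H/4, W/4)
    new = model(**inputs)                         # logits: (B, C, H, W)

upsampled = nn.functional.interpolate(
    legacy.logits, size=inputs["pixel_values"].shape[-2:], mode="bilinear", align_corners=False
)
print(torch.allclose(upsampled, new.logits))  # True
```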
@@ -83,6 +83,10 @@ class SegformerConfig(PretrainedConfig):
             required for the semantic segmentation model.
         semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
             The index that is ignored by the loss function of the semantic segmentation model.
+        legacy_output (`bool`, *optional*, defaults to `False`):
+            Whether to return the legacy outputs or not (with logits of shape `height / 4, width / 4`).
+
+            This argument is only present for backward compatibility reasons and will be removed in v5 of Transformers.

     Example:
@@ -124,6 +128,7 @@ class SegformerConfig(PretrainedConfig):
         is_encoder_decoder=False,
         reshape_last_stage=True,
         semantic_loss_ignore_index=255,
+        legacy_output=False,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -149,3 +154,4 @@ class SegformerConfig(PretrainedConfig):
         self.decoder_hidden_size = decoder_hidden_size
         self.reshape_last_stage = reshape_last_stage
         self.semantic_loss_ignore_index = semantic_loss_ignore_index
+        self.legacy_output = legacy_output
@@ -17,6 +17,7 @@

 import collections
 import math
+import warnings

 import torch
 import torch.utils.checkpoint
@@ -30,7 +31,7 @@ from ...file_utils import (
     add_start_docstrings_to_model_forward,
     replace_return_docstrings,
 )
-from ...modeling_outputs import BaseModelOutput, SequenceClassifierOutput
+from ...modeling_outputs import BaseModelOutput, SemanticSegmentationModelOutput, SequenceClassifierOutput
 from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import logging
 from .configuration_segformer import SegformerConfig
@@ -688,7 +689,7 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
         self.post_init()

     @add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(output_type=SemanticSegmentationModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         pixel_values,
@@ -696,11 +697,17 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
+        legacy_output=None,
     ):
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
             Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
             config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
+        legacy_output (`bool`, *optional*):
+            Whether to return the legacy outputs or not (with logits of shape `height / 4, width / 4`). Will default
+            to `self.config.legacy_output`.
+
+            This argument is only present for backward compatibility reasons and will be removed in v5 of Transformers.

         Returns:
@@ -719,12 +726,20 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
         >>> inputs = feature_extractor(images=image, return_tensors="pt")
         >>> outputs = model(**inputs)
-        >>> logits = outputs.logits  # shape (batch_size, num_labels, height/4, width/4)
+        >>> logits = outputs.logits  # shape (batch_size, num_labels, height, width)
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
+        legacy_output = legacy_output if legacy_output is not None else self.config.legacy_output
+        if not legacy_output:
+            warnings.warn(
+                "The output of this model has changed in v4.17.0 and the logits now have the same size as the inputs. "
+                "You can activate the previous behavior by passing `legacy_output=True` to this call or the "
+                "configuration of this model (only until v5, then that argument will be removed).",
+                FutureWarning,
+            )

         outputs = self.segformer(
             pixel_values,
@@ -737,28 +752,37 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):

         logits = self.decode_head(encoder_hidden_states)
+        upsampled_logits = nn.functional.interpolate(
+            logits, size=pixel_values.shape[-2:], mode="bilinear", align_corners=False
+        )

         loss = None
         if labels is not None:
             if self.config.num_labels == 1:
                 raise ValueError("The number of labels should be greater than one")
             else:
                 # upsample logits to the images' original size
-                upsampled_logits = nn.functional.interpolate(
-                    logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
-                )
                 loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
                 loss = loss_fct(upsampled_logits, labels)

         if not return_dict:
             if output_hidden_states:
-                output = (logits,) + outputs[1:]
+                output = (logits if legacy_output else upsampled_logits,) + outputs[1:]
             else:
-                output = (logits,) + outputs[2:]
+                output = (logits if legacy_output else upsampled_logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output

-        return SequenceClassifierOutput(
-            loss=loss,
-            logits=logits,
-            hidden_states=outputs.hidden_states if output_hidden_states else None,
-            attentions=outputs.attentions,
-        )
+        if legacy_output:
+            return SequenceClassifierOutput(
+                loss=loss,
+                logits=logits,
+                hidden_states=outputs.hidden_states if output_hidden_states else None,
+                attentions=outputs.attentions,
+            )
+        else:
+            return SemanticSegmentationModelOutput(
+                loss=loss,
+                logits=upsampled_logits,
+                hidden_states=outputs.hidden_states if output_hidden_states else None,
+                attentions=outputs.attentions,
+            )
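
Note that the `FutureWarning` fires on every forward call under the new default. Callers that have already migrated can silence it locally; a sketch:

```python
import warnings

with warnings.catch_warnings():
    warnings.simplefilter("ignore", FutureWarning)
    outputs = model(**inputs)
```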
@@ -474,6 +474,13 @@ class AutoModelForQuestionAnswering(metaclass=DummyObject):
         requires_backends(self, ["torch"])

+class AutoModelForSemanticSegmentation(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
 class AutoModelForSeq2SeqLM(metaclass=DummyObject):
     _backends = ["torch"]
...
@@ -92,17 +92,20 @@ class BeitModelTester:
         self.initializer_range = initializer_range
         self.scope = scope
         self.out_indices = out_indices
+        self.num_labels = num_labels

     def prepare_config_and_inputs(self):
         pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])

         labels = None
+        pixel_labels = None
         if self.use_labels:
             labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+            pixel_labels = ids_tensor([self.batch_size, self.image_size, self.image_size], self.num_labels)

         config = self.get_config()

-        return config, pixel_values, labels
+        return config, pixel_values, labels, pixel_labels

     def get_config(self):
         return BeitConfig(
@@ -122,7 +125,7 @@ class BeitModelTester:
             out_indices=self.out_indices,
         )

-    def create_and_check_model(self, config, pixel_values, labels):
+    def create_and_check_model(self, config, pixel_values, labels, pixel_labels):
         model = BeitModel(config=config)
         model.to(torch_device)
         model.eval()
@@ -133,7 +136,7 @@ class BeitModelTester:
         num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
         self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))

-    def create_and_check_for_masked_lm(self, config, pixel_values, labels):
+    def create_and_check_for_masked_lm(self, config, pixel_values, labels, pixel_labels):
         model = BeitForMaskedImageModeling(config=config)
         model.to(torch_device)
         model.eval()
@@ -144,7 +147,7 @@ class BeitModelTester:
         num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
         self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches, self.vocab_size))

-    def create_and_check_for_image_classification(self, config, pixel_values, labels):
+    def create_and_check_for_image_classification(self, config, pixel_values, labels, pixel_labels):
         config.num_labels = self.type_sequence_label_size
         model = BeitForImageClassification(config)
         model.to(torch_device)
@@ -152,13 +155,23 @@ class BeitModelTester:
         result = model(pixel_values, labels=labels)
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))

+    def create_and_check_for_image_segmentation(self, config, pixel_values, labels, pixel_labels):
+        config.num_labels = self.num_labels
+        model = BeitForSemanticSegmentation(config)
+        model.to(torch_device)
+        model.eval()
+        result = model(pixel_values)
+        self.parent.assertEqual(
+            result.logits.shape, (self.batch_size, self.num_labels, self.image_size, self.image_size)
+        )
+        result = model(pixel_values, labels=pixel_labels)
+        self.parent.assertEqual(
+            result.logits.shape, (self.batch_size, self.num_labels, self.image_size, self.image_size)
+        )
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
-        (
-            config,
-            pixel_values,
-            labels,
-        ) = config_and_inputs
+        config, pixel_values, labels, pixel_labels = config_and_inputs
         inputs_dict = {"pixel_values": pixel_values}
         return config, inputs_dict
@@ -217,6 +230,10 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

+    def test_for_image_segmentation(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_image_segmentation(*config_and_inputs)
+
     def test_training(self):
         if not self.model_tester.is_training:
             return
@@ -516,14 +533,14 @@ class BeitModelIntegrationTest(unittest.TestCase):
         logits = outputs.logits

         # verify the logits
-        expected_shape = torch.Size((1, 150, 160, 160))
+        expected_shape = torch.Size((1, 150, 640, 640))
         self.assertEqual(logits.shape, expected_shape)

         expected_slice = torch.tensor(
             [
-                [[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]],
-                [[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]],
-                [[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]],
+                [[-4.9225, -4.9225, -4.6066], [-4.9225, -4.9225, -4.6066], [-4.6675, -4.6675, -4.3617]],
+                [[-5.8168, -5.8168, -5.5163], [-5.8168, -5.8168, -5.5163], [-5.5728, -5.5728, -5.2842]],
+                [[-0.0078, -0.0078, 0.4926], [-0.0078, -0.0078, 0.4926], [0.3664, 0.3664, 0.8309]],
             ]
         ).to(torch_device)
...
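
The repeated values in the new expected slices are what bilinear 4x upsampling produces near a corner: adjacent output pixels sample (nearly) the same low-resolution cell. A tiny demonstration:

```python
import torch
from torch import nn

x = torch.tensor([[[[0.0, 1.0], [2.0, 3.0]]]])  # (1, 1, 2, 2) "low-res logits"
up = nn.functional.interpolate(x, size=(8, 8), mode="bilinear", align_corners=False)
print(up[0, 0, :3, :3])
# tensor([[0.0000, 0.0000, 0.1250],
#         [0.0000, 0.0000, 0.1250],
#         [0.2500, 0.2500, 0.3750]])  # first two rows/columns repeat, as above
```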
@@ -133,11 +133,11 @@ class SegformerModelTester:
         model.eval()
         result = model(pixel_values)
         self.parent.assertEqual(
-            result.logits.shape, (self.batch_size, self.num_labels, self.image_size // 4, self.image_size // 4)
+            result.logits.shape, (self.batch_size, self.num_labels, self.image_size, self.image_size)
         )
         result = model(pixel_values, labels=labels)
         self.parent.assertEqual(
-            result.logits.shape, (self.batch_size, self.num_labels, self.image_size // 4, self.image_size // 4)
+            result.logits.shape, (self.batch_size, self.num_labels, self.image_size, self.image_size)
         )

     def prepare_config_and_inputs_for_common(self):
@@ -245,6 +245,7 @@ class SegformerModelTest(ModelTesterMixin, unittest.TestCase):
             list(attentions[-1].shape[-3:]),
             [self.model_tester.num_attention_heads[-1], expected_seq_len, expected_reduced_seq_len],
         )
+        out_len = len(outputs)

         # Check attention is always last and order is fine
         inputs_dict["output_attentions"] = True
@@ -255,7 +256,7 @@ class SegformerModelTest(ModelTesterMixin, unittest.TestCase):
         with torch.no_grad():
             outputs = model(**self._prepare_for_class(inputs_dict, model_class))

-        self.assertEqual(3, len(outputs))
+        self.assertEqual(out_len + 1, len(outputs))

         self_attentions = outputs.attentions
@@ -357,16 +358,17 @@ class SegformerModelIntegrationTest(unittest.TestCase):
         encoded_inputs = feature_extractor(images=image, return_tensors="pt")
         pixel_values = encoded_inputs.pixel_values.to(torch_device)

-        outputs = model(pixel_values)
+        with torch.no_grad():
+            outputs = model(pixel_values)

-        expected_shape = torch.Size((1, model.config.num_labels, 128, 128))
+        expected_shape = torch.Size((1, model.config.num_labels, 512, 512))
         self.assertEqual(outputs.logits.shape, expected_shape)

         expected_slice = torch.tensor(
             [
-                [[-4.6310, -5.5232, -6.2356], [-5.1921, -6.1444, -6.5996], [-5.4424, -6.2790, -6.7574]],
-                [[-12.1391, -13.3122, -13.9554], [-12.8732, -13.9352, -14.3563], [-12.9438, -13.8226, -14.2513]],
-                [[-12.5134, -13.4686, -14.4915], [-12.8669, -14.4343, -14.7758], [-13.2523, -14.5819, -15.0694]],
+                [[-4.6309, -4.6309, -4.7425], [-4.6309, -4.6309, -4.7425], [-4.7011, -4.7011, -4.8136]],
+                [[-12.1391, -12.1391, -12.2858], [-12.1391, -12.1391, -12.2858], [-12.2309, -12.2309, -12.3758]],
+                [[-12.5134, -12.5134, -12.6328], [-12.5134, -12.5134, -12.6328], [-12.5576, -12.5576, -12.6865]],
             ]
         ).to(torch_device)
         self.assertTrue(torch.allclose(outputs.logits[0, :3, :3, :3], expected_slice, atol=1e-4))
@@ -385,16 +387,17 @@ class SegformerModelIntegrationTest(unittest.TestCase):
         encoded_inputs = feature_extractor(images=image, return_tensors="pt")
         pixel_values = encoded_inputs.pixel_values.to(torch_device)

-        outputs = model(pixel_values)
+        with torch.no_grad():
+            outputs = model(pixel_values)

-        expected_shape = torch.Size((1, model.config.num_labels, 128, 128))
+        expected_shape = torch.Size((1, model.config.num_labels, 512, 512))
         self.assertEqual(outputs.logits.shape, expected_shape)

         expected_slice = torch.tensor(
             [
-                [[-13.5748, -13.9111, -12.6500], [-14.3500, -15.3683, -14.2328], [-14.7532, -16.0424, -15.6087]],
-                [[-17.1651, -15.8725, -12.9653], [-17.2580, -17.3718, -14.8223], [-16.6058, -16.8783, -16.7452]],
-                [[-3.6456, -3.0209, -1.4203], [-3.0797, -3.1959, -2.0000], [-1.8757, -1.9217, -1.6997]],
+                [[-13.5729, -13.5729, -13.6149], [-13.5729, -13.5729, -13.6149], [-13.6697, -13.6697, -13.7224]],
+                [[-17.1638, -17.1638, -17.0022], [-17.1638, -17.1638, -17.0022], [-17.1754, -17.1754, -17.0358]],
+                [[-3.6452, -3.6452, -3.5670], [-3.6452, -3.6452, -3.5670], [-3.5744, -3.5744, -3.5079]],
             ]
         ).to(torch_device)
         self.assertTrue(torch.allclose(outputs.logits[0, :3, :3, :3], expected_slice, atol=1e-1))
@@ -118,8 +118,6 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
     "PerceiverForMultimodalAutoencoding",
     "PerceiverForOpticalFlow",
     "SegformerDecodeHead",
-    "SegformerForSemanticSegmentation",
-    "BeitForSemanticSegmentation",
     "FlaxBeitForMaskedImageModeling",
     "BeitForMaskedImageModeling",
     "CLIPTextModel",
...