Unverified Commit a8b6443e authored by Francesco Saverio Zuppichini, committed by GitHub

Refactor Modeling Outputs (#16341)



* first proposal

* replace model outputs in various models

* conflicts

* docstring

* update poolformer

* minor change in docstring

* CI

* removed poolformer specific outputs from doc

* removed convnext specific outputs from doc

* CI

* weird char in segformer

* conversations

* reverted docstring for BaseModelOutputWithPooling

* update outputs

* changed docstring in BaseModelOutput

* updated docstring in modeling outputs

* typos :)

* fixed typo after copying & pasting it all around

* CI

* Apply suggestions from code review
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* segformer
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
parent 857eb87c
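
The gist of the refactor: per-model output dataclasses (ConvNextModelOutput, PoolFormerModelOutput, ResNetModelOutput, VanModelOutput, and friends) are deleted in favor of shared classes in transformers.modeling_outputs. A minimal sketch of the shared classes' shape, reconstructed from the fields of the per-model classes removed in the diffs below (the canonical definitions live in src/transformers/modeling_outputs.py and may carry fuller docstrings):

from dataclasses import dataclass
from typing import Optional, Tuple

import torch

from transformers.utils import ModelOutput


@dataclass
class BaseModelOutputWithNoAttention(ModelOutput):
    # Replaces ConvNextEncoderOutput, PoolFormerModelOutput, ResNetEncoderOutput, VanEncoderOutput.
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class BaseModelOutputWithPoolingAndNoAttention(ModelOutput):
    # Replaces ConvNextModelOutput, ResNetModelOutput, VanModelOutput.
    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class ImageClassifierOutputWithNoAttention(ModelOutput):
    # Replaces ConvNextClassifierOutput, PoolFormerClassifierOutput, VanClassifierOutput.
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
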
......@@ -40,11 +40,6 @@ alt="drawing" width="600"/>
This model was contributed by [nielsr](https://huggingface.co/nielsr). The TensorFlow version of the model was contributed by [ariG23498](https://github.com/ariG23498),
[gante](https://github.com/gante), and [sayakpaul](https://github.com/sayakpaul) (equal contribution). The original code can be found [here](https://github.com/facebookresearch/ConvNeXt).
## ConvNeXT specific outputs
[[autodoc]] models.convnext.modeling_convnext.ConvNextModelOutput
## ConvNextConfig
[[autodoc]] ConvNextConfig
......
......@@ -41,12 +41,6 @@ Tips:
This model was contributed by [heytanay](https://huggingface.co/heytanay). The original code can be found [here](https://github.com/sail-sg/poolformer).
## PoolFormer specific outputs
[[autodoc]] models.poolformer.modeling_poolformer.PoolFormerModelOutput
[[autodoc]] models.poolformer.modeling_poolformer.PoolFormerClassifierOutput
## PoolFormerConfig
[[autodoc]] PoolFormerConfig
......
This diff is collapsed.
......@@ -29,9 +29,9 @@ from ...activations import ACT2FN
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
ImageClassifierOutput,
MaskedLMOutput,
SemanticSegmenterOutput,
SequenceClassifierOutput,
)
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
......@@ -851,7 +851,7 @@ class BeitForImageClassification(BeitPreTrainedModel):
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=SequenceClassifierOutput,
output_type=ImageClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
......@@ -863,7 +863,7 @@ class BeitForImageClassification(BeitPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, SequenceClassifierOutput]:
) -> Union[tuple, ImageClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
......@@ -909,7 +909,7 @@ class BeitForImageClassification(BeitPreTrainedModel):
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
return ImageClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
......
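
Note that ImageClassifierOutput carries the same four fields as SequenceClassifierOutput (loss, logits, hidden_states, attentions), so swaps like the one above change the documented output type, not the runtime behavior. A quick sanity-check sketch, using the BEiT checkpoint referenced in the file's docstrings:

import torch
from transformers import BeitForImageClassification
from transformers.modeling_outputs import ImageClassifierOutput

model = BeitForImageClassification.from_pretrained("microsoft/beit-base-patch16-224")
pixel_values = torch.randn(1, 3, 224, 224)  # stand-in for a preprocessed image

with torch.no_grad():
    outputs = model(pixel_values)

assert isinstance(outputs, ImageClassifierOutput)
print(outputs.logits.shape)  # torch.Size([1, 1000]) for the ImageNet checkpoint
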
......@@ -14,8 +14,6 @@
# limitations under the License.
""" PyTorch ConvNext model."""
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.utils.checkpoint
......@@ -23,14 +21,13 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_utils import PreTrainedModel
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
from ...modeling_outputs import (
BaseModelOutputWithNoAttention,
BaseModelOutputWithPoolingAndNoAttention,
ImageClassifierOutputWithNoAttention,
)
from ...modeling_utils import PreTrainedModel
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_convnext import ConvNextConfig
......@@ -54,66 +51,6 @@ CONVNEXT_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
@dataclass
class ConvNextEncoderOutput(ModelOutput):
"""
Class for [`ConvNextEncoder`]'s outputs, with potential hidden states (feature maps).
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Last hidden states (final feature map) of the last stage of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the model at
the output of each stage.
"""
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class ConvNextModelOutput(ModelOutput):
"""
Class for [`ConvNextModel`]'s outputs, with potential hidden states (feature maps).
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Last hidden states (final feature map) of the last stage of the model.
pooler_output (`torch.FloatTensor` of shape `(batch_size, config.dim[-1])`):
Global average pooling of the last feature map followed by a layernorm.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the model at
the output of each stage.
"""
last_hidden_state: torch.FloatTensor = None
pooler_output: Optional[torch.FloatTensor] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class ConvNextClassifierOutput(ModelOutput):
"""
Class for [`ConvNextForImageClassification`]'s outputs, with potential hidden states (feature maps).
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the model at
the output of each stage.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
# Stochastic depth implementation
# Taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
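
For readers skimming the diff: the body of drop_path is elided here. It implements stochastic depth, zeroing entire residual branches per sample during training and rescaling the survivors. The timm implementation this comment credits looks roughly like the following sketch (not the exact vendored code):

import torch

def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # one Bernoulli draw per sample, broadcast over the remaining dims
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    return x.div(keep_prob) * random_tensor
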
......@@ -302,7 +239,7 @@ class ConvNextEncoder(nn.Module):
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
return ConvNextEncoderOutput(
return BaseModelOutputWithNoAttention(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
)
......@@ -383,7 +320,7 @@ class ConvNextModel(ConvNextPreTrainedModel):
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=ConvNextModelOutput,
output_type=BaseModelOutputWithPoolingAndNoAttention,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
......@@ -413,7 +350,7 @@ class ConvNextModel(ConvNextPreTrainedModel):
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return ConvNextModelOutput(
return BaseModelOutputWithPoolingAndNoAttention(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
......@@ -446,7 +383,7 @@ class ConvNextForImageClassification(ConvNextPreTrainedModel):
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=ConvNextClassifierOutput,
output_type=ImageClassifierOutputWithNoAttention,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
......@@ -491,7 +428,7 @@ class ConvNextForImageClassification(ConvNextPreTrainedModel):
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return ConvNextClassifierOutput(
return ImageClassifierOutputWithNoAttention(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
......
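
After this change, ConvNextModel returns the shared pooled-output class rather than ConvNextModelOutput. A minimal usage sketch (checkpoint and expected shapes follow the model's docstrings):

import torch
from transformers import ConvNextModel
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndNoAttention

model = ConvNextModel.from_pretrained("facebook/convnext-tiny-224")
pixel_values = torch.randn(1, 3, 224, 224)  # stand-in for a preprocessed image

with torch.no_grad():
    outputs = model(pixel_values)

assert isinstance(outputs, BaseModelOutputWithPoolingAndNoAttention)
print(outputs.last_hidden_state.shape)  # torch.Size([1, 768, 7, 7])
print(outputs.pooler_output.shape)      # torch.Size([1, 768])
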
......@@ -26,7 +26,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput, MaskedLMOutput
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
ModelOutput,
......@@ -693,7 +693,7 @@ class DeiTForImageClassification(DeiTPreTrainedModel):
self.post_init()
@add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
@replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
......@@ -702,7 +702,7 @@ class DeiTForImageClassification(DeiTPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, SequenceClassifierOutput]:
) -> Union[tuple, ImageClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
......@@ -777,7 +777,7 @@ class DeiTForImageClassification(DeiTPreTrainedModel):
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
return ImageClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
......
......@@ -16,7 +16,6 @@
import collections.abc
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
......@@ -25,14 +24,9 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutputWithNoAttention, ImageClassifierOutputWithNoAttention
from ...modeling_utils import PreTrainedModel
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
)
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_poolformer import PoolFormerConfig
......@@ -63,47 +57,6 @@ def to_2tuple(x):
return (x, x)
@dataclass
class PoolFormerModelOutput(ModelOutput):
"""
Class for PoolFormerModel's outputs, with potential hidden states.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
plus the initial embedding outputs.
"""
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class PoolFormerClassifierOutput(ModelOutput):
"""
Class for PoolformerForImageClassification's outputs.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, num_channels, height, width)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, the original name is
......@@ -295,7 +248,7 @@ class PoolFormerEncoder(nn.Module):
# Get patch embeddings from hidden_states
hidden_states = embedding_layer(hidden_states)
# Send the embeddings through the blocks
for i, blk in enumerate(block_layer):
for _, blk in enumerate(block_layer):
layer_outputs = blk(hidden_states)
hidden_states = layer_outputs[0]
......@@ -305,7 +258,7 @@ class PoolFormerEncoder(nn.Module):
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
return PoolFormerModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
class PoolFormerPreTrainedModel(PreTrainedModel):
......@@ -374,7 +327,7 @@ class PoolFormerModel(PoolFormerPreTrainedModel):
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=PoolFormerModelOutput,
output_type=BaseModelOutputWithNoAttention,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
......@@ -384,7 +337,7 @@ class PoolFormerModel(PoolFormerPreTrainedModel):
pixel_values: Optional[torch.FloatTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, PoolFormerModelOutput]:
) -> Union[Tuple, BaseModelOutputWithNoAttention]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
......@@ -403,7 +356,7 @@ class PoolFormerModel(PoolFormerPreTrainedModel):
if not return_dict:
return (sequence_output, None) + encoder_outputs[1:]
return PoolFormerModelOutput(
return BaseModelOutputWithNoAttention(
last_hidden_state=sequence_output,
hidden_states=encoder_outputs.hidden_states,
)
......@@ -445,7 +398,7 @@ class PoolFormerForImageClassification(PoolFormerPreTrainedModel):
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=PoolFormerClassifierOutput,
output_type=ImageClassifierOutputWithNoAttention,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
......@@ -455,7 +408,7 @@ class PoolFormerForImageClassification(PoolFormerPreTrainedModel):
labels: Optional[torch.LongTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, PoolFormerClassifierOutput]:
) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
......@@ -501,4 +454,4 @@ class PoolFormerForImageClassification(PoolFormerPreTrainedModel):
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return PoolFormerClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
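
A recurring pattern in these diffs is the tuple fallback: when return_dict=False, models return a plain tuple instead of an output dataclass, and the classifier heads splice loss and logits in front of the backbone's remaining entries. Schematically (the helper below is a placeholder for illustration, not real library API):

from transformers.modeling_outputs import ImageClassifierOutputWithNoAttention

def build_classifier_output(loss, logits, backbone_outputs, return_dict):
    if not return_dict:
        # backbone_outputs[2:] skips last_hidden_state and pooler_output,
        # keeping optional entries such as hidden_states
        output = (logits,) + backbone_outputs[2:]
        return ((loss,) + output) if loss is not None else output
    return ImageClassifierOutputWithNoAttention(
        loss=loss, logits=logits, hidden_states=backbone_outputs.hidden_states
    )
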
......@@ -14,8 +14,7 @@
# limitations under the License.
""" PyTorch ResNet model."""
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Optional
import torch
import torch.utils.checkpoint
......@@ -23,7 +22,11 @@ from torch import Tensor, nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import ImageClassifierOutput, ModelOutput
from ...modeling_outputs import (
BaseModelOutputWithNoAttention,
BaseModelOutputWithPoolingAndNoAttention,
ImageClassifierOutputWithNoAttention,
)
from ...modeling_utils import PreTrainedModel
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_resnet import ResNetConfig
......@@ -49,47 +52,6 @@ RESNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
@dataclass
class ResNetEncoderOutput(ModelOutput):
"""
ResNet encoder's output, with potential hidden states.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, num_channels, height, width)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
"""
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class ResNetModelOutput(ModelOutput):
"""
ResNet model's output, with potential hidden states.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (`torch.FloatTensor` of shape `(batch_size, config.hidden_sizes[-1])`):
The pooled last hidden state.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, num_channels, height, width)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
"""
last_hidden_state: torch.FloatTensor = None
pooler_output: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
class ResNetConvLayer(nn.Sequential):
def __init__(
self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: str = "relu"
......@@ -228,7 +190,7 @@ class ResNetEncoder(nn.Module):
def forward(
self, hidden_state: Tensor, output_hidden_states: bool = False, return_dict: bool = True
) -> ResNetEncoderOutput:
) -> BaseModelOutputWithNoAttention:
hidden_states = () if output_hidden_states else None
for stage_module in self.stages:
......@@ -243,7 +205,7 @@ class ResNetEncoder(nn.Module):
if not return_dict:
return tuple(v for v in [hidden_state, hidden_states] if v is not None)
return ResNetEncoderOutput(
return BaseModelOutputWithNoAttention(
last_hidden_state=hidden_state,
hidden_states=hidden_states,
)
......@@ -315,14 +277,14 @@ class ResNetModel(ResNetPreTrainedModel):
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=ResNetModelOutput,
output_type=BaseModelOutputWithPoolingAndNoAttention,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(
self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None
) -> ResNetModelOutput:
) -> BaseModelOutputWithPoolingAndNoAttention:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
......@@ -341,7 +303,7 @@ class ResNetModel(ResNetPreTrainedModel):
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return ResNetModelOutput(
return BaseModelOutputWithPoolingAndNoAttention(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
......@@ -372,7 +334,7 @@ class ResNetForImageClassification(ResNetPreTrainedModel):
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=ImageClassifierOutput,
output_type=ImageClassifierOutputWithNoAttention,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
......@@ -382,7 +344,7 @@ class ResNetForImageClassification(ResNetPreTrainedModel):
labels: Tensor = None,
output_hidden_states: bool = None,
return_dict: bool = None,
) -> ImageClassifierOutput:
) -> ImageClassifierOutputWithNoAttention:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
......@@ -423,4 +385,4 @@ class ResNetForImageClassification(ResNetPreTrainedModel):
output = (logits,) + outputs[2:]
return (loss,) + output if loss is not None else output
return ImageClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
......@@ -24,7 +24,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, SemanticSegmenterOutput, SequenceClassifierOutput
from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput, SemanticSegmenterOutput
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_code_sample_docstrings,
......@@ -57,6 +57,33 @@ SEGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
class SegFormerImageClassifierOutput(ImageClassifierOutput):
"""
Base class for outputs of image classification models.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also
called feature maps) of the model at the output of each stage.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
# Copied from transformers.models.convnext.modeling_convnext.drop_path
def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep=True):
"""
......@@ -558,7 +585,7 @@ class SegformerForImageClassification(SegformerPreTrainedModel):
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=SequenceClassifierOutput,
output_type=SegFormerImageClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
......@@ -569,7 +596,7 @@ class SegformerForImageClassification(SegformerPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutput]:
) -> Union[Tuple, SegFormerImageClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
......@@ -625,7 +652,7 @@ class SegformerForImageClassification(SegformerPreTrainedModel):
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
return SegFormerImageClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
......
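
Unlike the convolutional models above, SegFormer's classifier does expose attention maps, hence the dedicated ImageClassifierOutput subclass rather than the no-attention variant. A minimal sketch (nvidia/mit-b0 is the checkpoint used in the file's docstrings):

import torch
from transformers import SegformerForImageClassification

model = SegformerForImageClassification.from_pretrained("nvidia/mit-b0")
pixel_values = torch.randn(1, 3, 224, 224)  # stand-in for a preprocessed image

with torch.no_grad():
    outputs = model(pixel_values, output_attentions=True)

print(type(outputs).__name__)   # SegFormerImageClassifierOutput
print(outputs.logits.shape)     # torch.Size([1, 1000]) for the ImageNet checkpoint
print(len(outputs.attentions))  # one tensor per transformer block
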
......@@ -16,8 +16,6 @@
import math
from collections import OrderedDict
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.utils.checkpoint
......@@ -25,14 +23,13 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_utils import PreTrainedModel
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
from ...modeling_outputs import (
BaseModelOutputWithNoAttention,
BaseModelOutputWithPoolingAndNoAttention,
ImageClassifierOutputWithNoAttention,
)
from ...modeling_utils import PreTrainedModel
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_van import VanConfig
......@@ -56,63 +53,6 @@ VAN_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
@dataclass
class VanEncoderOutput(ModelOutput):
"""
Class for [`VanEncoder`]'s outputs, with potential hidden states (feature maps).
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Last hidden states (final feature map) of the last stage of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, num_channels,
height, width)`. Hidden-states (also called feature maps) of the model at the output of each stage.
"""
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class VanModelOutput(ModelOutput):
"""
Class for [`VanModel`]'s outputs, with potential hidden states (feature maps).
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Last hidden states (final feature map) of the last stage of the model.
pooler_output (`torch.FloatTensor` of shape `(batch_size, config.hidden_sizes[-1])`):
Global average pooling of the last feature map followed by a layernorm.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, num_channels,
height, width)`. Hidden-states (also called feature maps) of the model at the output of each stage.
"""
last_hidden_state: torch.FloatTensor = None
pooler_output: Optional[torch.FloatTensor] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class VanClassifierOutput(ModelOutput):
"""
Class for [`VanForImageClassification`]'s outputs, with potential hidden states (feature maps).
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, num_channels,
height, width)`. Hidden-states (also called feature maps) of the model at the output of each stage.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
# Stochastic depth implementation
# Taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
......@@ -388,7 +328,7 @@ class VanEncoder(nn.Module):
if not return_dict:
return tuple(v for v in [hidden_state, all_hidden_states] if v is not None)
return VanEncoderOutput(last_hidden_state=hidden_state, hidden_states=all_hidden_states)
return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=all_hidden_states)
class VanPreTrainedModel(PreTrainedModel):
......@@ -466,7 +406,7 @@ class VanModel(VanPreTrainedModel):
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=VanModelOutput,
output_type=BaseModelOutputWithPoolingAndNoAttention,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
......@@ -489,7 +429,7 @@ class VanModel(VanPreTrainedModel):
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return VanModelOutput(
return BaseModelOutputWithPoolingAndNoAttention(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
......@@ -519,7 +459,7 @@ class VanForImageClassification(VanPreTrainedModel):
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=VanClassifierOutput,
output_type=ImageClassifierOutputWithNoAttention,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
......@@ -565,4 +505,4 @@ class VanForImageClassification(VanPreTrainedModel):
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return VanClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
......@@ -25,7 +25,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput, MaskedLMOutput
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_code_sample_docstrings,
......@@ -740,7 +740,7 @@ class ViTForImageClassification(ViTPreTrainedModel):
@add_code_sample_docstrings(
processor_class=_FEAT_EXTRACTOR_FOR_DOC,
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=SequenceClassifierOutput,
output_type=ImageClassifierOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
......@@ -753,7 +753,7 @@ class ViTForImageClassification(ViTPreTrainedModel):
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, SequenceClassifierOutput]:
) -> Union[tuple, ImageClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
......@@ -801,7 +801,7 @@ class ViTForImageClassification(ViTPreTrainedModel):
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
return ImageClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
......