Unverified Commit aa6cfe9c authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

Rename to SemanticSegmenterOutput (#15849)


Co-authored-by: default avatarNiels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 70a9bc69
...@@ -815,7 +815,7 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput): ...@@ -815,7 +815,7 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput):
@dataclass @dataclass
class SemanticSegmentationModelOutput(ModelOutput): class SemanticSegmenterOutput(ModelOutput):
""" """
Base class for outputs of semantic segmentation models. Base class for outputs of semantic segmentation models.
......
...@@ -30,7 +30,7 @@ from ...modeling_outputs import ( ...@@ -30,7 +30,7 @@ from ...modeling_outputs import (
BaseModelOutput, BaseModelOutput,
BaseModelOutputWithPooling, BaseModelOutputWithPooling,
MaskedLMOutput, MaskedLMOutput,
SemanticSegmentationModelOutput, SemanticSegmenterOutput,
SequenceClassifierOutput, SequenceClassifierOutput,
) )
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
...@@ -1188,7 +1188,7 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel): ...@@ -1188,7 +1188,7 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
return loss return loss
@add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=SemanticSegmentationModelOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=SemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
pixel_values: Optional[torch.Tensor] = None, pixel_values: Optional[torch.Tensor] = None,
...@@ -1197,7 +1197,7 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel): ...@@ -1197,7 +1197,7 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
) -> Union[tuple, SemanticSegmentationModelOutput]: ) -> Union[tuple, SemanticSegmenterOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ..., Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
...@@ -1272,7 +1272,7 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel): ...@@ -1272,7 +1272,7 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
output = (logits,) + outputs[3:] output = (logits,) + outputs[3:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return SemanticSegmentationModelOutput( return SemanticSegmenterOutput(
loss=loss, loss=loss,
logits=logits, logits=logits,
hidden_states=outputs.hidden_states if output_hidden_states else None, hidden_states=outputs.hidden_states if output_hidden_states else None,
......
...@@ -24,7 +24,7 @@ from torch import nn ...@@ -24,7 +24,7 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, SemanticSegmentationModelOutput, SequenceClassifierOutput from ...modeling_outputs import BaseModelOutput, SemanticSegmenterOutput, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ( from ...utils import (
add_code_sample_docstrings, add_code_sample_docstrings,
...@@ -720,7 +720,7 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel): ...@@ -720,7 +720,7 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
self.post_init() self.post_init()
@add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=SemanticSegmentationModelOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=SemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
pixel_values: torch.FloatTensor, pixel_values: torch.FloatTensor,
...@@ -728,7 +728,7 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel): ...@@ -728,7 +728,7 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
) -> Union[Tuple, SemanticSegmentationModelOutput]: ) -> Union[Tuple, SemanticSegmenterOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ..., Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
...@@ -788,7 +788,7 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel): ...@@ -788,7 +788,7 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
output = (logits,) + outputs[2:] output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
return SemanticSegmentationModelOutput( return SemanticSegmenterOutput(
loss=loss, loss=loss,
logits=logits, logits=logits,
hidden_states=outputs.hidden_states if output_hidden_states else None, hidden_states=outputs.hidden_states if output_hidden_states else None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment