"docs/source/en/tasks/semantic_segmentation.md" did not exist on "da005253b82395b6097623bcee44b819bfe72b87"
Unverified commit 73768147, authored by amyeroberts, committed by GitHub

Revert 22152 MaskedImageCompletionOutput changes (#22187)

Revert changes: remove the MaskedImageCompletionOutput dataclass added in #22152 and return ViTForMaskedImageModeling to MaskedLMOutput, moving the reconstructed image back from the reconstruction field to logits.
Parent: 7b0e2cfd
@@ -1281,34 +1281,6 @@ class ImageSuperResolutionOutput(ModelOutput):
     attentions: Optional[Tuple[torch.FloatTensor]] = None
 
 
-@dataclass
-class MaskedImageCompletionOutput(ModelOutput):
-    """
-    Base class for outputs of masked image completion / in-painting models.
-
-    Args:
-        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
-            Reconstruction loss.
-        reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
-            Reconstructed / completed images.
-        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
-            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
-            one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
-            (also called feature maps) of the model at the output of each stage.
-        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
-            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
-            sequence_length)`.
-
-            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
-            heads.
-    """
-
-    loss: Optional[torch.FloatTensor] = None
-    reconstruction: torch.FloatTensor = None
-    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-    attentions: Optional[Tuple[torch.FloatTensor]] = None
-
-
 @dataclass
 class Wav2Vec2BaseModelOutput(ModelOutput):
     """
...
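With the dataclass above deleted, ViTForMaskedImageModeling falls back to the generic MaskedLMOutput, and the reconstructed image travels in its logits field. A minimal sketch of how that output container behaves, assuming a transformers install from around this commit (the tensor values and shapes are illustrative):

import torch
from transformers.modeling_outputs import MaskedLMOutput

# Build the output the way ViTForMaskedImageModeling does after the revert:
# the reconstructed image goes into `logits`, not `reconstruction`.
out = MaskedLMOutput(
    loss=torch.tensor(0.5),
    logits=torch.zeros(1, 3, 224, 224),  # (batch_size, num_channels, height, width)
)

print(out.logits.shape)     # attribute access: torch.Size([1, 3, 224, 224])
print(out["logits"].shape)  # ModelOutput also supports dict-style access
print(out[1].shape)         # and tuple-style indexing over the non-None fields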
@@ -25,12 +25,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
-from ...modeling_outputs import (
-    BaseModelOutput,
-    BaseModelOutputWithPooling,
-    ImageClassifierOutput,
-    MaskedImageCompletionOutput,
-)
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput, MaskedLMOutput
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
@@ -648,7 +643,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
         self.post_init()
 
     @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
-    @replace_return_docstrings(output_type=MaskedImageCompletionOutput, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         pixel_values: Optional[torch.Tensor] = None,
@@ -658,7 +653,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         interpolate_pos_encoding: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-    ) -> Union[tuple, MaskedImageCompletionOutput]:
+    ) -> Union[tuple, MaskedLMOutput]:
         r"""
         bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
             Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
@@ -728,9 +723,9 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
             output = (reconstructed_pixel_values,) + outputs[1:]
             return ((masked_im_loss,) + output) if masked_im_loss is not None else output
 
-        return MaskedImageCompletionOutput(
+        return MaskedLMOutput(
             loss=masked_im_loss,
-            reconstruction=reconstructed_pixel_values,
+            logits=reconstructed_pixel_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
...
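Putting the reverted forward signature to work: a usage sketch adapted from the model's docstring example (the checkpoint name, test image URL, and random mask are illustrative, not prescribed by this commit):

import torch
import requests
from PIL import Image
from transformers import AutoImageProcessor, ViTForMaskedImageModeling

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
model = ViTForMaskedImageModeling.from_pretrained("google/vit-base-patch16-224-in21k")

num_patches = (model.config.image_size // model.config.patch_size) ** 2
pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
# Random boolean mask of shape (batch_size, num_patches): 1 = masked patch.
bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
# After the revert, the reconstructed image lives in `logits` again.
loss, reconstructed_pixel_values = outputs.loss, outputs.logits
print(reconstructed_pixel_values.shape)  # torch.Size([1, 3, 224, 224])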
@@ -134,7 +134,7 @@ class ViTModelTester:
         model.eval()
         result = model(pixel_values)
         self.parent.assertEqual(
-            result.reconstruction.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
+            result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
         )
 
         # test greyscale images
@@ -145,7 +145,7 @@ class ViTModelTester:
         pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
         result = model(pixel_values)
-        self.parent.assertEqual(result.reconstruction.shape, (self.batch_size, 1, self.image_size, self.image_size))
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size))
 
     def create_and_check_for_image_classification(self, config, pixel_values, labels):
         config.num_labels = self.type_sequence_label_size
...
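The updated tests assert that the reconstruction keeps the input's spatial shape for both RGB and greyscale inputs. A self-contained sketch of the same check, using a deliberately tiny made-up config rather than ViTModelTester's actual values:

import torch
from transformers import ViTConfig, ViTForMaskedImageModeling

# Tiny illustrative config so this runs quickly on CPU.
config = ViTConfig(
    image_size=32,
    patch_size=4,
    encoder_stride=4,  # upsampling factor of the reconstruction head
    num_channels=3,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=37,
)
model = ViTForMaskedImageModeling(config).eval()

pixel_values = torch.rand(2, config.num_channels, config.image_size, config.image_size)
with torch.no_grad():
    result = model(pixel_values)

# Mirrors the reverted assertion: `logits` holds the reconstructed image,
# with the same (batch_size, num_channels, height, width) as the input.
assert result.logits.shape == (2, 3, 32, 32)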