Unverified Commit 0558914d authored by Alara Dirik's avatar Alara Dirik Committed by GitHub
Browse files

Add MaskedImageModelingOutput (#22212)

* Add MaskedImageModelingOutput
parent 0dcb46e7
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple
...@@ -1622,3 +1623,40 @@ class SampleTSPredictionOutput(ModelOutput): ...@@ -1622,3 +1623,40 @@ class SampleTSPredictionOutput(ModelOutput):
""" """
sequences: torch.FloatTensor = None sequences: torch.FloatTensor = None
@dataclass
class MaskedImageModelingOutput(ModelOutput):
    """
    Base class for outputs of masked image completion / in-painting models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
            Reconstruction loss.
        reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Reconstructed / completed images.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or
            when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
            (also called feature maps) of the model at the output of each stage.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when
            `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    reconstruction: torch.FloatTensor = None
    # Variadic tuples: one entry per stage/layer, so the length is model-dependent
    # (matches the `Tuple[..., ...]` convention used elsewhere in this file).
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None

    @property
    def logits(self):
        # Deprecated alias: this output class replaced MaskedLMOutput for masked image
        # modeling models, whose callers used to read `.logits`.
        warnings.warn(
            "logits attribute is deprecated and will be removed in version 5 of Transformers."
            " Please use the reconstruction attribute to retrieve the final output instead.",
            FutureWarning,
        )
        return self.reconstruction
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
...@@ -55,8 +56,8 @@ class TFBaseModelOutputWithNoAttention(ModelOutput): ...@@ -55,8 +56,8 @@ class TFBaseModelOutputWithNoAttention(ModelOutput):
last_hidden_state (`tf.Tensor` shape `(batch_size, num_channels, height, width)`): last_hidden_state (`tf.Tensor` shape `(batch_size, num_channels, height, width)`):
Sequence of hidden-states at the output of the last layer of the model. Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
one for the output of each layer) of shape `(batch_size, num_channels, height, width)`. the output of each layer) of shape `(batch_size, num_channels, height, width)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
""" """
...@@ -949,3 +950,40 @@ class TFImageClassifierOutputWithNoAttention(ModelOutput): ...@@ -949,3 +950,40 @@ class TFImageClassifierOutputWithNoAttention(ModelOutput):
loss: Optional[tf.Tensor] = None loss: Optional[tf.Tensor] = None
logits: tf.Tensor = None logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor, ...]] = None hidden_states: Optional[Tuple[tf.Tensor, ...]] = None
@dataclass
class TFMaskedImageModelingOutput(ModelOutput):
    """
    Base class for outputs of masked image completion / in-painting models.

    Args:
        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
            Reconstruction loss.
        reconstruction (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
            Reconstructed / completed images.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when
            `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
            the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called
            feature maps) of the model at the output of each stage.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when
            `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`.
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[tf.Tensor] = None
    reconstruction: tf.Tensor = None
    # Variadic tuples: one entry per stage/layer, so the length is model-dependent
    # (matches the `Tuple[tf.Tensor, ...]` convention used elsewhere in this file).
    hidden_states: Optional[Tuple[tf.Tensor, ...]] = None
    attentions: Optional[Tuple[tf.Tensor, ...]] = None

    @property
    def logits(self):
        # Deprecated alias: this output class replaced TFMaskedLMOutput for masked image
        # modeling models, whose callers used to read `.logits`.
        warnings.warn(
            "logits attribute is deprecated and will be removed in version 5 of Transformers."
            " Please use the reconstruction attribute to retrieve the final output instead.",
            FutureWarning,
        )
        return self.reconstruction
...@@ -26,7 +26,12 @@ from torch import nn ...@@ -26,7 +26,12 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput, MaskedLMOutput from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
ImageClassifierOutput,
MaskedImageModelingOutput,
)
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ( from ...utils import (
...@@ -592,7 +597,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel): ...@@ -592,7 +597,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
self.post_init() self.post_init()
@add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=MaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
pixel_values: Optional[torch.Tensor] = None, pixel_values: Optional[torch.Tensor] = None,
...@@ -601,7 +606,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel): ...@@ -601,7 +606,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
) -> Union[tuple, MaskedLMOutput]: ) -> Union[tuple, MaskedImageModelingOutput]:
r""" r"""
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
...@@ -627,7 +632,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel): ...@@ -627,7 +632,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
>>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool() >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
>>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos) >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
>>> list(reconstructed_pixel_values.shape) >>> list(reconstructed_pixel_values.shape)
[1, 3, 224, 224] [1, 3, 224, 224]
```""" ```"""
...@@ -670,9 +675,9 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel): ...@@ -670,9 +675,9 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
output = (reconstructed_pixel_values,) + outputs[1:] output = (reconstructed_pixel_values,) + outputs[1:]
return ((masked_im_loss,) + output) if masked_im_loss is not None else output return ((masked_im_loss,) + output) if masked_im_loss is not None else output
return MaskedLMOutput( return MaskedImageModelingOutput(
loss=masked_im_loss, loss=masked_im_loss,
logits=reconstructed_pixel_values, reconstruction=reconstructed_pixel_values,
hidden_states=outputs.hidden_states, hidden_states=outputs.hidden_states,
attentions=outputs.attentions, attentions=outputs.attentions,
) )
......
...@@ -27,7 +27,7 @@ from ...modeling_tf_outputs import ( ...@@ -27,7 +27,7 @@ from ...modeling_tf_outputs import (
TFBaseModelOutput, TFBaseModelOutput,
TFBaseModelOutputWithPooling, TFBaseModelOutputWithPooling,
TFImageClassifierOutput, TFImageClassifierOutput,
TFMaskedLMOutput, TFMaskedImageModelingOutput,
) )
from ...modeling_tf_utils import ( from ...modeling_tf_utils import (
TFPreTrainedModel, TFPreTrainedModel,
...@@ -769,7 +769,7 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel): ...@@ -769,7 +769,7 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel):
@unpack_inputs @unpack_inputs
@add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(DEIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFMaskedLMOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=TFMaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC)
def call( def call(
self, self,
pixel_values: Optional[tf.Tensor] = None, pixel_values: Optional[tf.Tensor] = None,
...@@ -779,7 +779,7 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel): ...@@ -779,7 +779,7 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel):
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
training: bool = False, training: bool = False,
) -> Union[tuple, TFMaskedLMOutput]: ) -> Union[tuple, TFMaskedImageModelingOutput]:
r""" r"""
bool_masked_pos (`tf.Tensor` of type bool and shape `(batch_size, num_patches)`): bool_masked_pos (`tf.Tensor` of type bool and shape `(batch_size, num_patches)`):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
...@@ -805,7 +805,7 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel): ...@@ -805,7 +805,7 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel):
>>> bool_masked_pos = tf.cast(tf.random.uniform((1, num_patches), minval=0, maxval=2, dtype=tf.int32), tf.bool) >>> bool_masked_pos = tf.cast(tf.random.uniform((1, num_patches), minval=0, maxval=2, dtype=tf.int32), tf.bool)
>>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos) >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
>>> list(reconstructed_pixel_values.shape) >>> list(reconstructed_pixel_values.shape)
[1, 3, 224, 224] [1, 3, 224, 224]
```""" ```"""
...@@ -860,18 +860,20 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel): ...@@ -860,18 +860,20 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel):
output = (reconstructed_pixel_values,) + outputs[1:] output = (reconstructed_pixel_values,) + outputs[1:]
return ((masked_im_loss,) + output) if masked_im_loss is not None else output return ((masked_im_loss,) + output) if masked_im_loss is not None else output
return TFMaskedLMOutput( return TFMaskedImageModelingOutput(
loss=masked_im_loss, loss=masked_im_loss,
logits=reconstructed_pixel_values, reconstruction=reconstructed_pixel_values,
hidden_states=outputs.hidden_states, hidden_states=outputs.hidden_states,
attentions=outputs.attentions, attentions=outputs.attentions,
) )
def serving_output(self, output: TFMaskedLMOutput) -> TFMaskedLMOutput: def serving_output(self, output: TFMaskedImageModelingOutput) -> TFMaskedImageModelingOutput:
hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None hidden_states = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None attentions = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFMaskedLMOutput(logits=output.logits, hidden_states=hidden_states, attentions=attentions) return TFMaskedImageModelingOutput(
reconstruction=output.reconstruction, hidden_states=hidden_states, attentions=attentions
)
@add_start_docstrings( @add_start_docstrings(
......
...@@ -25,7 +25,12 @@ from torch import nn ...@@ -25,7 +25,12 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput, MaskedLMOutput from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPooling,
ImageClassifierOutput,
MaskedImageModelingOutput,
)
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ( from ...utils import (
...@@ -647,7 +652,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel): ...@@ -647,7 +652,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
self.post_init() self.post_init()
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=MaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
pixel_values: Optional[torch.Tensor] = None, pixel_values: Optional[torch.Tensor] = None,
...@@ -657,7 +662,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel): ...@@ -657,7 +662,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = None, interpolate_pos_encoding: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
) -> Union[tuple, MaskedLMOutput]: ) -> Union[tuple, MaskedImageModelingOutput]:
r""" r"""
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
...@@ -683,7 +688,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel): ...@@ -683,7 +688,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
>>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool() >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
>>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos) >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
>>> list(reconstructed_pixel_values.shape) >>> list(reconstructed_pixel_values.shape)
[1, 3, 224, 224] [1, 3, 224, 224]
```""" ```"""
...@@ -727,9 +732,9 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel): ...@@ -727,9 +732,9 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
output = (reconstructed_pixel_values,) + outputs[1:] output = (reconstructed_pixel_values,) + outputs[1:]
return ((masked_im_loss,) + output) if masked_im_loss is not None else output return ((masked_im_loss,) + output) if masked_im_loss is not None else output
return MaskedLMOutput( return MaskedImageModelingOutput(
loss=masked_im_loss, loss=masked_im_loss,
logits=reconstructed_pixel_values, reconstruction=reconstructed_pixel_values,
hidden_states=outputs.hidden_states, hidden_states=outputs.hidden_states,
attentions=outputs.attentions, attentions=outputs.attentions,
) )
......
...@@ -145,7 +145,7 @@ class DeiTModelTester: ...@@ -145,7 +145,7 @@ class DeiTModelTester:
model.eval() model.eval()
result = model(pixel_values) result = model(pixel_values)
self.parent.assertEqual( self.parent.assertEqual(
result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size) result.reconstruction.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
) )
# test greyscale images # test greyscale images
...@@ -156,7 +156,7 @@ class DeiTModelTester: ...@@ -156,7 +156,7 @@ class DeiTModelTester:
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values) result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size)) self.parent.assertEqual(result.reconstruction.shape, (self.batch_size, 1, self.image_size, self.image_size))
def create_and_check_for_image_classification(self, config, pixel_values, labels): def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size config.num_labels = self.type_sequence_label_size
......
...@@ -130,7 +130,7 @@ class TFDeiTModelTester: ...@@ -130,7 +130,7 @@ class TFDeiTModelTester:
model = TFDeiTForMaskedImageModeling(config=config) model = TFDeiTForMaskedImageModeling(config=config)
result = model(pixel_values) result = model(pixel_values)
self.parent.assertEqual( self.parent.assertEqual(
result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size) result.reconstruction.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
) )
# test greyscale images # test greyscale images
...@@ -139,7 +139,7 @@ class TFDeiTModelTester: ...@@ -139,7 +139,7 @@ class TFDeiTModelTester:
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values) result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size)) self.parent.assertEqual(result.reconstruction.shape, (self.batch_size, 1, self.image_size, self.image_size))
def create_and_check_for_image_classification(self, config, pixel_values, labels): def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size config.num_labels = self.type_sequence_label_size
......
...@@ -134,7 +134,7 @@ class ViTModelTester: ...@@ -134,7 +134,7 @@ class ViTModelTester:
model.eval() model.eval()
result = model(pixel_values) result = model(pixel_values)
self.parent.assertEqual( self.parent.assertEqual(
result.logits.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size) result.reconstruction.shape, (self.batch_size, self.num_channels, self.image_size, self.image_size)
) )
# test greyscale images # test greyscale images
...@@ -145,7 +145,7 @@ class ViTModelTester: ...@@ -145,7 +145,7 @@ class ViTModelTester:
pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size]) pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
result = model(pixel_values) result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, 1, self.image_size, self.image_size)) self.parent.assertEqual(result.reconstruction.shape, (self.batch_size, 1, self.image_size, self.image_size))
def create_and_check_for_image_classification(self, config, pixel_values, labels): def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size config.num_labels = self.type_sequence_label_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment