Unverified commit 3db2e404, authored by Alara Dirik, committed by GitHub
Browse files

Update Swin MIM output class (#22893)

Updates Swin MIM output class to match other masked image modeling outputs
parent 1e1cb6f8
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import collections.abc import collections.abc
import math import math
import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
...@@ -139,7 +140,7 @@ class SwinMaskedImageModelingOutput(ModelOutput): ...@@ -139,7 +140,7 @@ class SwinMaskedImageModelingOutput(ModelOutput):
Args: Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided): loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
Masked image modeling (MIM) loss. Masked image modeling (MIM) loss.
logits (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Reconstructed pixel values. Reconstructed pixel values.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
...@@ -161,11 +162,20 @@ class SwinMaskedImageModelingOutput(ModelOutput): ...@@ -161,11 +162,20 @@ class SwinMaskedImageModelingOutput(ModelOutput):
""" """
loss: Optional[torch.FloatTensor] = None loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None reconstruction: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@property
def logits(self):
    """Deprecated alias for the `reconstruction` field.

    Emits a `FutureWarning` on every access and then returns
    `self.reconstruction` unchanged.
    """
    warnings.warn(
        "logits attribute is deprecated and will be removed in version 5 of Transformers."
        " Please use the reconstruction attribute to retrieve the final output instead.",
        FutureWarning,
        # Attribute the warning to the code that accessed `.logits`, not to this line.
        stacklevel=2,
    )
    return self.reconstruction
@dataclass @dataclass
class SwinImageClassifierOutput(ModelOutput): class SwinImageClassifierOutput(ModelOutput):
...@@ -1094,7 +1104,7 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel): ...@@ -1094,7 +1104,7 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
>>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool() >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
>>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos) >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
>>> list(reconstructed_pixel_values.shape) >>> list(reconstructed_pixel_values.shape)
[1, 3, 192, 192] [1, 3, 192, 192]
```""" ```"""
...@@ -1138,7 +1148,7 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel): ...@@ -1138,7 +1148,7 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
return SwinMaskedImageModelingOutput( return SwinMaskedImageModelingOutput(
loss=masked_im_loss, loss=masked_im_loss,
logits=reconstructed_pixel_values, reconstruction=reconstructed_pixel_values,
hidden_states=outputs.hidden_states, hidden_states=outputs.hidden_states,
attentions=outputs.attentions, attentions=outputs.attentions,
reshaped_hidden_states=outputs.reshaped_hidden_states, reshaped_hidden_states=outputs.reshaped_hidden_states,
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import collections.abc import collections.abc
import math import math
import warnings
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
...@@ -143,7 +144,7 @@ class TFSwinMaskedImageModelingOutput(ModelOutput): ...@@ -143,7 +144,7 @@ class TFSwinMaskedImageModelingOutput(ModelOutput):
Args: Args:
loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided): loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
Masked image modeling (MIM) loss. Masked image modeling (MIM) loss.
logits (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`): reconstruction (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
Reconstructed pixel values. Reconstructed pixel values.
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape
...@@ -165,11 +166,20 @@ class TFSwinMaskedImageModelingOutput(ModelOutput): ...@@ -165,11 +166,20 @@ class TFSwinMaskedImageModelingOutput(ModelOutput):
""" """
loss: Optional[tf.Tensor] = None loss: Optional[tf.Tensor] = None
logits: tf.Tensor = None reconstruction: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None attentions: Optional[Tuple[tf.Tensor]] = None
reshaped_hidden_states: Optional[Tuple[tf.Tensor]] = None reshaped_hidden_states: Optional[Tuple[tf.Tensor]] = None
@property
def logits(self):
    """Deprecated alias for the `reconstruction` field.

    Emits a `FutureWarning` on every access and then returns
    `self.reconstruction` unchanged.
    """
    warnings.warn(
        "logits attribute is deprecated and will be removed in version 5 of Transformers."
        " Please use the reconstruction attribute to retrieve the final output instead.",
        FutureWarning,
        # Attribute the warning to the code that accessed `.logits`, not to this line.
        stacklevel=2,
    )
    return self.reconstruction
@dataclass @dataclass
class TFSwinImageClassifierOutput(ModelOutput): class TFSwinImageClassifierOutput(ModelOutput):
...@@ -1340,7 +1350,7 @@ class TFSwinForMaskedImageModeling(TFSwinPreTrainedModel): ...@@ -1340,7 +1350,7 @@ class TFSwinForMaskedImageModeling(TFSwinPreTrainedModel):
>>> bool_masked_pos = tf.random.uniform((1, num_patches)) >= 0.5 >>> bool_masked_pos = tf.random.uniform((1, num_patches)) >= 0.5
>>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos) >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
>>> list(reconstructed_pixel_values.shape) >>> list(reconstructed_pixel_values.shape)
[1, 3, 224, 224] [1, 3, 224, 224]
```""" ```"""
...@@ -1392,7 +1402,7 @@ class TFSwinForMaskedImageModeling(TFSwinPreTrainedModel): ...@@ -1392,7 +1402,7 @@ class TFSwinForMaskedImageModeling(TFSwinPreTrainedModel):
return TFSwinMaskedImageModelingOutput( return TFSwinMaskedImageModelingOutput(
loss=masked_im_loss, loss=masked_im_loss,
logits=reconstructed_pixel_values, reconstruction=reconstructed_pixel_values,
hidden_states=outputs.hidden_states, hidden_states=outputs.hidden_states,
attentions=outputs.attentions, attentions=outputs.attentions,
reshaped_hidden_states=outputs.reshaped_hidden_states, reshaped_hidden_states=outputs.reshaped_hidden_states,
...@@ -1401,7 +1411,7 @@ class TFSwinForMaskedImageModeling(TFSwinPreTrainedModel): ...@@ -1401,7 +1411,7 @@ class TFSwinForMaskedImageModeling(TFSwinPreTrainedModel):
def serving_output(self, output: TFSwinMaskedImageModelingOutput) -> TFSwinMaskedImageModelingOutput: def serving_output(self, output: TFSwinMaskedImageModelingOutput) -> TFSwinMaskedImageModelingOutput:
# hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions # hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions
return TFSwinMaskedImageModelingOutput( return TFSwinMaskedImageModelingOutput(
logits=output.logits, reconstruction=output.reconstruction,
hidden_states=output.hidden_states, hidden_states=output.hidden_states,
attentions=output.attentions, attentions=output.attentions,
reshaped_hidden_states=output.reshaped_hidden_states, reshaped_hidden_states=output.reshaped_hidden_states,
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import collections.abc import collections.abc
import math import math
import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
...@@ -142,7 +143,7 @@ class Swinv2MaskedImageModelingOutput(ModelOutput): ...@@ -142,7 +143,7 @@ class Swinv2MaskedImageModelingOutput(ModelOutput):
Args: Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided): loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
Masked image modeling (MIM) loss. Masked image modeling (MIM) loss.
logits (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Reconstructed pixel values. Reconstructed pixel values.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
...@@ -164,11 +165,20 @@ class Swinv2MaskedImageModelingOutput(ModelOutput): ...@@ -164,11 +165,20 @@ class Swinv2MaskedImageModelingOutput(ModelOutput):
""" """
loss: Optional[torch.FloatTensor] = None loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None reconstruction: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@property
def logits(self):
    """Deprecated alias for the `reconstruction` field.

    Emits a `FutureWarning` on every access and then returns
    `self.reconstruction` unchanged.
    """
    warnings.warn(
        "logits attribute is deprecated and will be removed in version 5 of Transformers."
        " Please use the reconstruction attribute to retrieve the final output instead.",
        FutureWarning,
        # Attribute the warning to the code that accessed `.logits`, not to this line.
        stacklevel=2,
    )
    return self.reconstruction
@dataclass @dataclass
# Copied from transformers.models.swin.modeling_swin.SwinImageClassifierOutput with Swin->Swinv2 # Copied from transformers.models.swin.modeling_swin.SwinImageClassifierOutput with Swin->Swinv2
...@@ -1175,7 +1185,7 @@ class Swinv2ForMaskedImageModeling(Swinv2PreTrainedModel): ...@@ -1175,7 +1185,7 @@ class Swinv2ForMaskedImageModeling(Swinv2PreTrainedModel):
>>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool() >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
>>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos) >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
>>> list(reconstructed_pixel_values.shape) >>> list(reconstructed_pixel_values.shape)
[1, 3, 256, 256] [1, 3, 256, 256]
```""" ```"""
...@@ -1219,7 +1229,7 @@ class Swinv2ForMaskedImageModeling(Swinv2PreTrainedModel): ...@@ -1219,7 +1229,7 @@ class Swinv2ForMaskedImageModeling(Swinv2PreTrainedModel):
return Swinv2MaskedImageModelingOutput( return Swinv2MaskedImageModelingOutput(
loss=masked_im_loss, loss=masked_im_loss,
logits=reconstructed_pixel_values, reconstruction=reconstructed_pixel_values,
hidden_states=outputs.hidden_states, hidden_states=outputs.hidden_states,
attentions=outputs.attentions, attentions=outputs.attentions,
reshaped_hidden_states=outputs.reshaped_hidden_states, reshaped_hidden_states=outputs.reshaped_hidden_states,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment