Unverified Commit 3db2e404 authored by Alara Dirik's avatar Alara Dirik Committed by GitHub
Browse files

Update Swin MIM output class (#22893)

Updates Swin MIM output class to match other masked image modeling outputs
parent 1e1cb6f8
......@@ -17,6 +17,7 @@
import collections.abc
import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union
......@@ -139,7 +140,7 @@ class SwinMaskedImageModelingOutput(ModelOutput):
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
Masked image modeling (MIM) loss.
logits (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Reconstructed pixel values.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
......@@ -161,11 +162,20 @@ class SwinMaskedImageModelingOutput(ModelOutput):
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
reconstruction: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@property
def logits(self):
    """Deprecated alias for the `reconstruction` attribute.

    Calls `warnings.warn` with a `FutureWarning` and then returns the value
    stored in `self.reconstruction`.
    """
    warnings.warn(
        "logits attribute is deprecated and will be removed in version 5 of Transformers."
        " Please use the reconstruction attribute to retrieve the final output instead.",
        FutureWarning,
        # Attribute the warning to the code that accessed `.logits`, not to this property.
        stacklevel=2,
    )
    return self.reconstruction
@dataclass
class SwinImageClassifierOutput(ModelOutput):
......@@ -1094,7 +1104,7 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
>>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
>>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
>>> list(reconstructed_pixel_values.shape)
[1, 3, 192, 192]
```"""
......@@ -1138,7 +1148,7 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
return SwinMaskedImageModelingOutput(
loss=masked_im_loss,
logits=reconstructed_pixel_values,
reconstruction=reconstructed_pixel_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
reshaped_hidden_states=outputs.reshaped_hidden_states,
......
......@@ -17,6 +17,7 @@
import collections.abc
import math
import warnings
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
......@@ -143,7 +144,7 @@ class TFSwinMaskedImageModelingOutput(ModelOutput):
Args:
loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
Masked image modeling (MIM) loss.
logits (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
reconstruction (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
Reconstructed pixel values.
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each stage) of shape
......@@ -165,11 +166,20 @@ class TFSwinMaskedImageModelingOutput(ModelOutput):
"""
loss: Optional[tf.Tensor] = None
logits: tf.Tensor = None
reconstruction: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
reshaped_hidden_states: Optional[Tuple[tf.Tensor]] = None
@property
def logits(self):
    """Deprecated alias for the `reconstruction` attribute.

    Calls `warnings.warn` with a `FutureWarning` and then returns the value
    stored in `self.reconstruction`.
    """
    warnings.warn(
        "logits attribute is deprecated and will be removed in version 5 of Transformers."
        " Please use the reconstruction attribute to retrieve the final output instead.",
        FutureWarning,
        # Attribute the warning to the code that accessed `.logits`, not to this property.
        stacklevel=2,
    )
    return self.reconstruction
@dataclass
class TFSwinImageClassifierOutput(ModelOutput):
......@@ -1340,7 +1350,7 @@ class TFSwinForMaskedImageModeling(TFSwinPreTrainedModel):
>>> bool_masked_pos = tf.random.uniform((1, num_patches)) >= 0.5
>>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
>>> list(reconstructed_pixel_values.shape)
[1, 3, 224, 224]
```"""
......@@ -1392,7 +1402,7 @@ class TFSwinForMaskedImageModeling(TFSwinPreTrainedModel):
return TFSwinMaskedImageModelingOutput(
loss=masked_im_loss,
logits=reconstructed_pixel_values,
reconstruction=reconstructed_pixel_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
reshaped_hidden_states=outputs.reshaped_hidden_states,
......@@ -1401,7 +1411,7 @@ class TFSwinForMaskedImageModeling(TFSwinPreTrainedModel):
def serving_output(self, output: TFSwinMaskedImageModelingOutput) -> TFSwinMaskedImageModelingOutput:
# hidden_states and attentions not converted to Tensor with tf.convert_to_tensor as they are all of different dimensions
return TFSwinMaskedImageModelingOutput(
logits=output.logits,
reconstruction=output.reconstruction,
hidden_states=output.hidden_states,
attentions=output.attentions,
reshaped_hidden_states=output.reshaped_hidden_states,
......
......@@ -17,6 +17,7 @@
import collections.abc
import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union
......@@ -142,7 +143,7 @@ class Swinv2MaskedImageModelingOutput(ModelOutput):
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
Masked image modeling (MIM) loss.
logits (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Reconstructed pixel values.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
......@@ -164,11 +165,20 @@ class Swinv2MaskedImageModelingOutput(ModelOutput):
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
reconstruction: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@property
def logits(self):
    """Deprecated alias for the `reconstruction` attribute.

    Calls `warnings.warn` with a `FutureWarning` and then returns the value
    stored in `self.reconstruction`.
    """
    warnings.warn(
        "logits attribute is deprecated and will be removed in version 5 of Transformers."
        " Please use the reconstruction attribute to retrieve the final output instead.",
        FutureWarning,
        # Attribute the warning to the code that accessed `.logits`, not to this property.
        stacklevel=2,
    )
    return self.reconstruction
@dataclass
# Copied from transformers.models.swin.modeling_swin.SwinImageClassifierOutput with Swin->Swinv2
......@@ -1175,7 +1185,7 @@ class Swinv2ForMaskedImageModeling(Swinv2PreTrainedModel):
>>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
>>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
>>> list(reconstructed_pixel_values.shape)
[1, 3, 256, 256]
```"""
......@@ -1219,7 +1229,7 @@ class Swinv2ForMaskedImageModeling(Swinv2PreTrainedModel):
return Swinv2MaskedImageModelingOutput(
loss=masked_im_loss,
logits=reconstructed_pixel_values,
reconstruction=reconstructed_pixel_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
reshaped_hidden_states=outputs.reshaped_hidden_states,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment