Unverified Commit d4b3e56d authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

[Hotfix] Fix Swin model outputs (#15414)

* Fix Swin model outputs

* Rename pooler
parent 38dfb40a
......@@ -21,11 +21,11 @@ import math
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...modeling_outputs import BaseModelOutput, SequenceClassifierOutput
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import logging
from .configuration_swin import SwinConfig
......@@ -143,8 +143,8 @@ class SwinPatchEmbeddings(nn.Module):
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, pixel_values):
pixel_values = self.projection(pixel_values).flatten(2).transpose(1, 2)
return pixel_values
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
return embeddings
class SwinPatchMerging(nn.Module):
......@@ -659,7 +659,7 @@ SWIN_INPUTS_DOCSTRING = r"""
SWIN_START_DOCSTRING,
)
class SwinModel(SwinPreTrainedModel):
def __init__(self, config):
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config
self.num_layers = len(config.depths)
......@@ -669,7 +669,7 @@ class SwinModel(SwinPreTrainedModel):
self.encoder = SwinEncoder(config, self.embeddings.patch_grid)
self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
self.pool = nn.AdaptiveAvgPool1d(1)
self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
# Initialize weights and apply final processing
self.post_init()
......@@ -686,7 +686,7 @@ class SwinModel(SwinPreTrainedModel):
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values=None,
......@@ -744,14 +744,18 @@ class SwinModel(SwinPreTrainedModel):
sequence_output = encoder_outputs[0]
sequence_output = self.layernorm(sequence_output)
sequence_output = self.pool(sequence_output.transpose(1, 2))
sequence_output = torch.flatten(sequence_output, 1)
pooled_output = None
if self.pooler is not None:
pooled_output = self.pooler(sequence_output.transpose(1, 2))
pooled_output = torch.flatten(pooled_output, 1)
if not return_dict:
return (sequence_output,) + encoder_outputs[1:]
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutput(
return BaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
......@@ -829,22 +833,35 @@ class SwinForImageClassification(SwinPreTrainedModel):
return_dict=return_dict,
)
sequence_output = outputs[0]
pooled_output = outputs[1]
logits = self.classifier(sequence_output)
logits = self.classifier(pooled_output)
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
# We are doing regression
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[1:]
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
......
......@@ -137,9 +137,11 @@ class SwinModelTester:
model.eval()
result = model(pixel_values)
num_features = int(config.embed_dim * 2 ** (len(config.depths) - 1))
# since the model we're testing only consists of a single layer, expected_seq_len = number of patches
expected_seq_len = (config.image_size // config.patch_size) ** 2
expected_dim = int(config.embed_dim * 2 ** (len(config.depths) - 1))
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_features))
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
......@@ -392,6 +394,6 @@ class SwinModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = torch.tensor([-0.2952, -0.4777, 0.2025]).to(torch_device)
expected_slice = torch.tensor([-0.0948, -0.6454, -0.0921]).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment