Unverified commit d4b3e56d authored by NielsRogge, committed by GitHub

[Hotfix] Fix Swin model outputs (#15414)

* Fix Swin model outputs

* Rename pooler
parent 38dfb40a
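
In short, SwinModel now returns a BaseModelOutputWithPooling: last_hidden_state keeps the full patch sequence instead of being pooled away, and the average-pooled features move to pooler_output, which SwinForImageClassification classifies from. A minimal usage sketch of the new contract (the checkpoint name is an assumption, any Swin checkpoint would do):

import torch
from transformers import SwinModel

model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224")  # assumed checkpoint
pixel_values = torch.randn(1, 3, 224, 224)  # random image-shaped input, illustration only

with torch.no_grad():
    outputs = model(pixel_values)

print(outputs.last_hidden_state.shape)  # (batch, num_patches, hidden_dim), no longer pooled
print(outputs.pooler_output.shape)      # (batch, hidden_dim), averaged over the patch sequence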
@@ -21,11 +21,11 @@ import math
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
-from ...modeling_outputs import BaseModelOutput, SequenceClassifierOutput
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput
 from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import logging
 from .configuration_swin import SwinConfig
@@ -143,8 +143,8 @@ class SwinPatchEmbeddings(nn.Module):
         self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

     def forward(self, pixel_values):
-        pixel_values = self.projection(pixel_values).flatten(2).transpose(1, 2)
-        return pixel_values
+        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
+        return embeddings


 class SwinPatchMerging(nn.Module):
@@ -659,7 +659,7 @@ SWIN_INPUTS_DOCSTRING = r"""
     SWIN_START_DOCSTRING,
 )
 class SwinModel(SwinPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
         self.config = config
         self.num_layers = len(config.depths)
@@ -669,7 +669,7 @@ class SwinModel(SwinPreTrainedModel):
         self.encoder = SwinEncoder(config, self.embeddings.patch_grid)

         self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
-        self.pool = nn.AdaptiveAvgPool1d(1)
+        self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None

         # Initialize weights and apply final processing
         self.post_init()
@@ -686,7 +686,7 @@ class SwinModel(SwinPreTrainedModel):
             self.encoder.layer[layer].attention.prune_heads(heads)

     @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
-    @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         pixel_values=None,
@@ -744,14 +744,18 @@ class SwinModel(SwinPreTrainedModel):
         sequence_output = encoder_outputs[0]
         sequence_output = self.layernorm(sequence_output)

-        sequence_output = self.pool(sequence_output.transpose(1, 2))
-        sequence_output = torch.flatten(sequence_output, 1)
+        pooled_output = None
+        if self.pooler is not None:
+            pooled_output = self.pooler(sequence_output.transpose(1, 2))
+            pooled_output = torch.flatten(pooled_output, 1)

         if not return_dict:
-            return (sequence_output,) + encoder_outputs[1:]
+            return (sequence_output, pooled_output) + encoder_outputs[1:]

-        return BaseModelOutput(
+        return BaseModelOutputWithPooling(
             last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
             hidden_states=encoder_outputs.hidden_states,
             attentions=encoder_outputs.attentions,
         )
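
For reference, nn.AdaptiveAvgPool1d(1) followed by torch.flatten is just a mean over the sequence dimension; a standalone sketch of what the pooler computes:

import torch
from torch import nn

pooler = nn.AdaptiveAvgPool1d(1)
sequence_output = torch.randn(2, 49, 96)                 # (batch, seq_len, dim)
pooled_output = pooler(sequence_output.transpose(1, 2))  # (batch, dim, 1)
pooled_output = torch.flatten(pooled_output, 1)          # (batch, dim)
# equivalent to averaging the patch features
assert torch.allclose(pooled_output, sequence_output.mean(dim=1), atol=1e-6)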
@@ -829,22 +833,35 @@ class SwinForImageClassification(SwinPreTrainedModel):
             return_dict=return_dict,
         )

-        sequence_output = outputs[0]
-        logits = self.classifier(sequence_output)
+        pooled_output = outputs[1]
+        logits = self.classifier(pooled_output)

         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
-                loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"

+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)

         if not return_dict:
-            output = (logits,) + outputs[1:]
+            output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output

         return SequenceClassifierOutput(
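
Note the index shift: with return_dict=False, the pooled output now occupies index 1 of the model tuple, which is why the head reads outputs[1] and slices the remaining encoder outputs from outputs[2:]. The labels branch also adopts the problem_type convention: the type is inferred once from num_labels and the label dtype, then mapped to MSE, cross-entropy, or BCE-with-logits. A standalone sketch of the inference rule (infer_problem_type is a hypothetical helper, not library code):

import torch

def infer_problem_type(num_labels, labels):
    # mirrors the branch in SwinForImageClassification.forward above
    if num_labels == 1:
        return "regression"
    if num_labels > 1 and labels.dtype in (torch.long, torch.int):
        return "single_label_classification"
    return "multi_label_classification"

print(infer_problem_type(1, torch.tensor([0.7])))                # regression
print(infer_problem_type(5, torch.tensor([3])))                  # single_label_classification
print(infer_problem_type(5, torch.tensor([[1.0, 0, 1, 0, 0]])))  # multi_label_classification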
...
@@ -137,9 +137,11 @@ class SwinModelTester:
         model.eval()
         result = model(pixel_values)

-        num_features = int(config.embed_dim * 2 ** (len(config.depths) - 1))
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_features))
+        # since the model we're testing only consists of a single layer, expected_seq_len = number of patches
+        expected_seq_len = (config.image_size // config.patch_size) ** 2
+        expected_dim = int(config.embed_dim * 2 ** (len(config.depths) - 1))
+        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))

     def create_and_check_for_image_classification(self, config, pixel_values, labels):
         config.num_labels = self.type_sequence_label_size
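
A quick worked check of the new shape expectation, with assumed values (illustrative only; the real SwinModelTester config may differ):

image_size, patch_size, embed_dim, depths = 32, 4, 16, [1]  # assumed single-layer tester config
expected_seq_len = (image_size // patch_size) ** 2           # (32 // 4) ** 2 = 64 patches
expected_dim = int(embed_dim * 2 ** (len(depths) - 1))       # 16 * 2 ** 0 = 16
print(expected_seq_len, expected_dim)                        # 64 16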
@@ -392,6 +394,6 @@ class SwinModelIntegrationTest(unittest.TestCase):
         expected_shape = torch.Size((1, 1000))
         self.assertEqual(outputs.logits.shape, expected_shape)

-        expected_slice = torch.tensor([-0.2952, -0.4777, 0.2025]).to(torch_device)
+        expected_slice = torch.tensor([-0.0948, -0.6454, -0.0921]).to(torch_device)
         self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))