Unverified Commit 667b823b authored by Francesco Saverio Zuppichini, committed by GitHub

Swin support for any input size (#15986)



* padding done

* correctly return one attention per layer

* almost correct, attentions are not yet flattened, one tuple per stage

* tests green

* doc

* conversations

* reshaping hidden_states

* view in the test

* reshape_hidden_states in Encoder and Model

* new outputs with reshaped_hidden_states

* conversations

* doc

* Update docs/source/model_doc/swin.mdx
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Apply suggestions from code review
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* conversations

* fix tests

* minor changes

* resolved conversations

* attentions one per stage

* typo

* typos

* typos

* function signature

* CI

* clean up tests
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
parent 204c54d4
@@ -34,6 +34,8 @@ The hierarchical design and the shifted window approach also prove beneficial fo
 Tips:
 - One can use the [`AutoFeatureExtractor`] API to prepare images for the model.
+- Swin pads the inputs, supporting any input height and width (if divisible by `32`).
+- Swin can be used as a *backbone*. When `output_hidden_states = True`, it will output both `hidden_states` and `reshaped_hidden_states`. The `reshaped_hidden_states` have a shape of `(batch_size, num_channels, height, width)` rather than `(batch_size, sequence_length, num_channels)`.
 <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/swin_transformer_architecture.png"
 alt="drawing" width="600"/>
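To make the two added tips concrete, here is a short usage sketch (illustrative only, not part of the diff). It assumes the public `microsoft/swin-tiny-patch4-window7-224` checkpoint and a local `example.jpg`; any Swin checkpoint should behave the same way:

```python
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, SwinModel

# assumption: this public checkpoint name is available on the Hub
checkpoint = "microsoft/swin-tiny-patch4-window7-224"
feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
model = SwinModel.from_pretrained(checkpoint)

image = Image.open("example.jpg")  # hypothetical local image
inputs = feature_extractor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

# sequence-shaped hidden states: (batch_size, sequence_length, num_channels)
print(outputs.hidden_states[0].shape)
# spatially reshaped hidden states: (batch_size, num_channels, height, width),
# convenient when Swin is used as a backbone for dense prediction
print(outputs.reshaped_hidden_states[0].shape)
```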
...
@@ -17,6 +17,8 @@
 import collections.abc
 import math
+from dataclasses import dataclass
+from typing import Optional, Tuple

 import torch
 import torch.utils.checkpoint
@@ -25,12 +27,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...file_utils import (
+    ModelOutput,
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     replace_return_docstrings,
 )
-from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput
 from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import logging
 from .configuration_swin import SwinConfig
@@ -56,10 +58,150 @@ SWIN_PRETRAINED_MODEL_ARCHIVE_LIST = [
     # See all Swin models at https://huggingface.co/models?filter=swin
 ]

 # to_2tuple, drop_path, SwinPatchEmbeddings, SwinPatchMerging and SwinDropPath are from the timm library.
+
+@dataclass
+class SwinEncoderOutput(ModelOutput):
+    """
+    Swin encoder's outputs, with potential hidden states and attentions.
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+            shape `(batch_size, sequence_length, hidden_size)`.
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+            shape `(batch_size, hidden_size, height, width)`.
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+            include the spatial dimensions.
+    """
+
+    last_hidden_state: torch.FloatTensor = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class SwinModelOutput(ModelOutput):
+    """
+    Swin model's outputs that also contains a pooling of the last hidden states.
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
+            Average pooling of the last layer hidden-state.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+            shape `(batch_size, sequence_length, hidden_size)`.
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+            shape `(batch_size, hidden_size, height, width)`.
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+            include the spatial dimensions.
+    """
+
+    last_hidden_state: torch.FloatTensor = None
+    pooler_output: torch.FloatTensor = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class SwinMaskedImageModelingOutput(ModelOutput):
+    """
+    Swin masked image model outputs.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
+            Masked image modeling (MLM) loss.
+        logits (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Reconstructed pixel values.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+            shape `(batch_size, sequence_length, hidden_size)`.
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+            shape `(batch_size, hidden_size, height, width)`.
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+            include the spatial dimensions.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: torch.FloatTensor = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class SwinImageClassifierOutput(ModelOutput):
+    """
+    Swin outputs for image classification.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Classification (or regression if config.num_labels==1) loss.
+        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+            Classification (or regression if config.num_labels==1) scores (before SoftMax).
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+            shape `(batch_size, sequence_length, hidden_size)`.
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+            shape `(batch_size, hidden_size, height, width)`.
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
+            include the spatial dimensions.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: torch.FloatTensor = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None

 # Copied from transformers.models.vit.modeling_vit.to_2tuple
 def to_2tuple(x):
     if isinstance(x, collections.abc.Iterable):
@@ -130,7 +272,7 @@ class SwinEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)

     def forward(self, pixel_values, bool_masked_pos=None):
-        embeddings = self.patch_embeddings(pixel_values)
+        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
         embeddings = self.norm(embeddings)
         batch_size, seq_len, _ = embeddings.size()
@@ -145,7 +287,7 @@ class SwinEmbeddings(nn.Module):
         embeddings = self.dropout(embeddings)

-        return embeddings
+        return embeddings, output_dimensions


 class SwinPatchEmbeddings(nn.Module):
@@ -165,9 +307,25 @@ class SwinPatchEmbeddings(nn.Module):
         self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

+    def maybe_pad(self, pixel_values, height, width):
+        if width % self.patch_size[1] != 0:
+            pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
+            pixel_values = nn.functional.pad(pixel_values, pad_values)
+        if height % self.patch_size[0] != 0:
+            pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
+            pixel_values = nn.functional.pad(pixel_values, pad_values)
+        return pixel_values
+
     def forward(self, pixel_values):
-        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
-        return embeddings
+        _, _, height, width = pixel_values.shape
+        # pad the input to be divisible by self.patch_size, if needed
+        pixel_values = self.maybe_pad(pixel_values, height, width)
+        embeddings = self.projection(pixel_values)
+        _, _, height, width = embeddings.shape
+        output_dimensions = (height, width)
+        embeddings = embeddings.flatten(2).transpose(1, 2)
+
+        return embeddings, output_dimensions
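As an aside, the padding arithmetic in `maybe_pad` can be checked in isolation. A standalone sketch (not part of the diff), assuming a patch size of 4 and a toy 37x50 input:

```python
import torch
import torch.nn as nn

patch_size = (4, 4)                       # assumed patch size
pixel_values = torch.randn(1, 3, 37, 50)  # height and width not divisible by 4

_, _, height, width = pixel_values.shape
if width % patch_size[1] != 0:
    # the last pair passed to F.pad pads the width dimension (left, right)
    pixel_values = nn.functional.pad(pixel_values, (0, patch_size[1] - width % patch_size[1]))
if height % patch_size[0] != 0:
    # the next pair pads the height dimension (top, bottom)
    pixel_values = nn.functional.pad(pixel_values, (0, 0, 0, patch_size[0] - height % patch_size[0]))

print(pixel_values.shape)  # torch.Size([1, 3, 40, 52]), now divisible by the patch size
```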
 class SwinPatchMerging(nn.Module):
@@ -190,17 +348,30 @@ class SwinPatchMerging(nn.Module):
         self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
         self.norm = norm_layer(4 * dim)

-    def forward(self, input_feature):
-        height, width = self.input_resolution
+    def maybe_pad(self, input_feature, height, width):
+        should_pad = (height % 2 == 1) or (width % 2 == 1)
+        if should_pad:
+            pad_values = (0, 0, 0, width % 2, 0, height % 2)
+            input_feature = nn.functional.pad(input_feature, pad_values)
+
+        return input_feature
+
+    def forward(self, input_feature, input_dimensions):
+        height, width = input_dimensions
         # `dim` is height * width
         batch_size, dim, num_channels = input_feature.shape

         input_feature = input_feature.view(batch_size, height, width, num_channels)
-        input_feature_0 = input_feature[:, 0::2, 0::2, :]  # batch_size height/2 width/2 num_channels
-        input_feature_1 = input_feature[:, 1::2, 0::2, :]  # batch_size height/2 width/2 num_channels
-        input_feature_2 = input_feature[:, 0::2, 1::2, :]  # batch_size height/2 width/2 num_channels
-        input_feature_3 = input_feature[:, 1::2, 1::2, :]  # batch_size height/2 width/2 num_channels
+        # pad input to be divisible by width and height, if needed
+        input_feature = self.maybe_pad(input_feature, height, width)
+        # [batch_size, height/2, width/2, num_channels]
+        input_feature_0 = input_feature[:, 0::2, 0::2, :]
+        # [batch_size, height/2, width/2, num_channels]
+        input_feature_1 = input_feature[:, 1::2, 0::2, :]
+        # [batch_size, height/2, width/2, num_channels]
+        input_feature_2 = input_feature[:, 0::2, 1::2, :]
+        # [batch_size, height/2, width/2, num_channels]
+        input_feature_3 = input_feature[:, 1::2, 1::2, :]
         # batch_size height/2 width/2 4*num_channels
         input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
         input_feature = input_feature.view(batch_size, -1, 4 * num_channels)  # batch_size height/2*width/2 4*C
@@ -393,19 +564,14 @@ class SwinOutput(nn.Module):
         return hidden_states


-class SwinBlock(nn.Module):
+class SwinLayer(nn.Module):
     def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
         self.shift_size = shift_size
         self.window_size = config.window_size
         self.input_resolution = input_resolution
+        self.set_shift_and_window_size(input_resolution)
-        if min(self.input_resolution) <= self.window_size:
-            # if window size is larger than input resolution, we don't partition windows
-            self.shift_size = 0
-            self.window_size = min(self.input_resolution)
         self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
         self.attention = SwinAttention(config, dim, num_heads)
         self.drop_path = SwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
@@ -413,9 +579,15 @@ class SwinBlock(nn.Module):
         self.intermediate = SwinIntermediate(config, dim)
         self.output = SwinOutput(config, dim)

+    def set_shift_and_window_size(self, input_resolution):
+        if min(input_resolution) <= self.window_size:
+            # if window size is larger than input resolution, we don't partition windows
+            self.shift_size = 0
+            self.window_size = min(input_resolution)
+
+    def get_attn_mask(self, height, width):
         if self.shift_size > 0:
             # calculate attention mask for SW-MSA
-            height, width = self.input_resolution
             img_mask = torch.zeros((1, height, width, 1))
             height_slices = (
                 slice(0, -self.window_size),
@@ -439,17 +611,27 @@ class SwinBlock(nn.Module):
             attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
         else:
             attn_mask = None
+        return attn_mask
-        self.attn_mask = attn_mask

-    def forward(self, hidden_states, head_mask=None, output_attentions=False):
-        height, width = self.input_resolution
-        batch_size, dim, channels = hidden_states.size()
+    def maybe_pad(self, hidden_states, height, width):
+        pad_right = (self.window_size - width % self.window_size) % self.window_size
+        pad_bottom = (self.window_size - height % self.window_size) % self.window_size
+        pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
+        hidden_states = nn.functional.pad(hidden_states, pad_values)
+        return hidden_states, pad_values
+
+    def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False):
+        self.set_shift_and_window_size(input_dimensions)
+        height, width = input_dimensions
+        batch_size, _, channels = hidden_states.size()
         shortcut = hidden_states

         hidden_states = self.layernorm_before(hidden_states)
         hidden_states = hidden_states.view(batch_size, height, width, channels)
+        # pad hidden_states to multiples of window size
+        hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
+
+        _, height_pad, width_pad, _ = hidden_states.shape
         # cyclic shift
         if self.shift_size > 0:
             shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
@@ -459,23 +641,18 @@ class SwinBlock(nn.Module):
         # partition windows
         hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
         hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
+        attn_mask = self.get_attn_mask(height_pad, width_pad)
+        if attn_mask is not None:
+            attn_mask = attn_mask.to(hidden_states_windows.device)

-        if self.attn_mask is not None:
-            self.attn_mask = self.attn_mask.to(hidden_states_windows.device)
-
-        self_attention_outputs = self.attention(
-            hidden_states_windows,
-            self.attn_mask,
-            head_mask,
-            output_attentions=output_attentions,
-        )
+        attention_outputs = self.attention(
+            hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
+        )

-        attention_output = self_attention_outputs[0]
-        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
+        attention_output = attention_outputs[0]

         attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
-        shifted_windows = window_reverse(attention_windows, self.window_size, height, width)  # B H' W' C
+        shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad)

         # reverse cyclic shift
         if self.shift_size > 0:
@@ -483,6 +660,10 @@ class SwinBlock(nn.Module):
         else:
             attention_windows = shifted_windows

+        was_padded = pad_values[3] > 0 or pad_values[5] > 0
+        if was_padded:
+            attention_windows = attention_windows[:, :height, :width, :].contiguous()
+
         attention_windows = attention_windows.view(batch_size, height * width, channels)

         hidden_states = shortcut + self.drop_path(attention_windows)
@@ -491,19 +672,18 @@ class SwinBlock(nn.Module):
         layer_output = self.intermediate(layer_output)
         layer_output = hidden_states + self.output(layer_output)

-        outputs = (layer_output,) + outputs
-
-        return outputs
+        layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
+        return layer_outputs
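For intuition, here is what `maybe_pad` plus the final crop do to the spatial dimensions; a standalone sketch (not part of the diff) with assumed toy numbers:

```python
# assumed numbers: a 10 x 13 feature map with window_size = 7
window_size = 7
height, width = 10, 13

pad_right = (window_size - width % window_size) % window_size    # 1
pad_bottom = (window_size - height % window_size) % window_size  # 4
height_pad, width_pad = height + pad_bottom, width + pad_right   # 14, 14

# window attention runs on the padded 14 x 14 map (four 7 x 7 windows);
# afterwards the padding is cropped away again with [:, :height, :width, :]
print(height_pad, width_pad)
```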

-class SwinLayer(nn.Module):
+class SwinStage(nn.Module):
     def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
         super().__init__()
         self.config = config
         self.dim = dim
         self.blocks = nn.ModuleList(
             [
-                SwinBlock(
+                SwinLayer(
                     config=config,
                     dim=dim,
                     input_resolution=input_resolution,
@@ -522,29 +702,28 @@ class SwinLayer(nn.Module):
         self.pointing = False

-    def forward(self, hidden_states, head_mask=None, output_attentions=False, output_hidden_states=False):
-        all_hidden_states = () if output_hidden_states else None
-
-        for i, block_module in enumerate(self.blocks):
-            if output_hidden_states:
-                all_hidden_states = all_hidden_states + (hidden_states,)
-
+    def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False):
+        height, width = input_dimensions
+        for i, layer_module in enumerate(self.blocks):
             layer_head_mask = head_mask[i] if head_mask is not None else None

-            layer_outputs = block_module(
-                hidden_states,
-                layer_head_mask,
-                output_attentions,
-            )
+            layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)

             hidden_states = layer_outputs[0]

         if self.downsample is not None:
-            layer_outputs_list = list(layer_outputs)
-            layer_outputs_list[0] = self.downsample(layer_outputs[0])
-            layer_outputs = tuple(layer_outputs_list)
-
-        return layer_outputs
+            height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
+            output_dimensions = (height, width, height_downsampled, width_downsampled)
+            hidden_states = self.downsample(layer_outputs[0], input_dimensions)
+        else:
+            output_dimensions = (height, width, height, width)
+
+        stage_outputs = (hidden_states, output_dimensions)
+
+        if output_attentions:
+            stage_outputs += layer_outputs[1:]
+        return stage_outputs
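The `(height + 1) // 2` rounding matches the odd-size padding in `SwinPatchMerging.maybe_pad`. A standalone sketch (not part of the diff) of how `output_dimensions` would evolve across four stages, starting from an assumed 10 x 13 feature map:

```python
height, width = 10, 13  # assumed resolution after the patch embeddings
for stage in range(4):
    has_downsample = stage < 3  # the last stage has no patch merging
    if has_downsample:
        output_dimensions = (height, width, (height + 1) // 2, (width + 1) // 2)
        height, width = output_dimensions[2], output_dimensions[3]
    else:
        output_dimensions = (height, width, height, width)
    print(stage, output_dimensions)
# (10, 13, 5, 7) -> (5, 7, 3, 4) -> (3, 4, 2, 2) -> (2, 2, 2, 2)
```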

 class SwinEncoder(nn.Module):
@@ -555,7 +734,7 @@ class SwinEncoder(nn.Module):
         dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
         self.layers = nn.ModuleList(
             [
-                SwinLayer(
+                SwinStage(
                     config=config,
                     dim=int(config.embed_dim * 2**i_layer),
                     input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
@@ -573,18 +752,26 @@ class SwinEncoder(nn.Module):
     def forward(
         self,
         hidden_states,
+        input_dimensions,
         head_mask=None,
         output_attentions=False,
         output_hidden_states=False,
         return_dict=True,
     ):
+        all_input_dimensions = ()
         all_hidden_states = () if output_hidden_states else None
+        all_reshaped_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None

-        for i, layer_module in enumerate(self.layers):
-            if output_hidden_states:
-                all_hidden_states = all_hidden_states + (hidden_states,)
+        if output_hidden_states:
+            batch_size, _, hidden_size = hidden_states.shape
+            # rearrange b (h w) c -> b c h w
+            reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
+            reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
+            all_hidden_states += (hidden_states,)
+            all_reshaped_hidden_states += (reshaped_hidden_state,)
+
+        for i, layer_module in enumerate(self.layers):
             layer_head_mask = head_mask[i] if head_mask is not None else None

             if self.gradient_checkpointing and self.training:
@@ -596,23 +783,36 @@ class SwinEncoder(nn.Module):
                     return custom_forward

                 layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module), hidden_states, layer_head_mask
+                    create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask
                 )
             else:
-                layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
+                layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)

             hidden_states = layer_outputs[0]
+            output_dimensions = layer_outputs[1]

-            if output_attentions:
-                all_self_attentions = all_self_attentions + (layer_outputs[1],)
-
-            if output_hidden_states:
-                all_hidden_states = all_hidden_states + (hidden_states,)
+            input_dimensions = (output_dimensions[-2], output_dimensions[-1])
+            all_input_dimensions += (input_dimensions,)
+
+            if output_hidden_states:
+                batch_size, _, hidden_size = hidden_states.shape
+                # rearrange b (h w) c -> b c h w
+                reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
+                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
+                all_hidden_states += (hidden_states,)
+                all_reshaped_hidden_states += (reshaped_hidden_state,)
+
+            if output_attentions:
+                all_self_attentions += layer_outputs[2:]

         if not return_dict:
             return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)

-        return BaseModelOutput(
-            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
+        return SwinEncoderOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            reshaped_hidden_states=all_reshaped_hidden_states,
         )
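The `b (h w) c -> b c h w` rearrangement used for `reshaped_hidden_states` is easy to verify on its own; a standalone sketch (not part of the diff) with assumed toy dimensions:

```python
import torch

batch_size, height, width, hidden_size = 2, 10, 13, 96  # assumed toy dimensions
hidden_states = torch.randn(batch_size, height * width, hidden_size)

# rearrange b (h w) c -> b c h w, as done in the encoder above
reshaped = hidden_states.view(batch_size, height, width, hidden_size).permute(0, 3, 1, 2)
print(reshaped.shape)  # torch.Size([2, 96, 10, 13])
```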
@@ -712,7 +912,7 @@ class SwinModel(SwinPreTrainedModel):
     @add_code_sample_docstrings(
         processor_class=_FEAT_EXTRACTOR_FOR_DOC,
         checkpoint=_CHECKPOINT_FOR_DOC,
-        output_type=BaseModelOutputWithPooling,
+        output_type=SwinModelOutput,
         config_class=_CONFIG_FOR_DOC,
         modality="vision",
         expected_output=_EXPECTED_OUTPUT_SHAPE,
@@ -742,10 +942,11 @@ class SwinModel(SwinPreTrainedModel):
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
         head_mask = self.get_head_mask(head_mask, len(self.config.depths))

-        embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
+        embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)

         encoder_outputs = self.encoder(
             embedding_output,
+            input_dimensions,
             head_mask=head_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -761,13 +962,16 @@ class SwinModel(SwinPreTrainedModel):
             pooled_output = torch.flatten(pooled_output, 1)

         if not return_dict:
-            return (sequence_output, pooled_output) + encoder_outputs[1:]
+            output = (sequence_output, pooled_output) + encoder_outputs[1:]
+            return output

-        return BaseModelOutputWithPooling(
+        return SwinModelOutput(
             last_hidden_state=sequence_output,
             pooler_output=pooled_output,
             hidden_states=encoder_outputs.hidden_states,
             attentions=encoder_outputs.attentions,
+            reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
         )
@@ -791,7 +995,7 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
         self.post_init()

     @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
-    @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(output_type=SwinMaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         pixel_values=None,
@@ -869,11 +1073,12 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
             output = (reconstructed_pixel_values,) + outputs[2:]
             return ((masked_im_loss,) + output) if masked_im_loss is not None else output

-        return MaskedLMOutput(
+        return SwinMaskedImageModelingOutput(
             loss=masked_im_loss,
             logits=reconstructed_pixel_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
+            reshaped_hidden_states=outputs.reshaped_hidden_states,
         )
@@ -903,7 +1108,7 @@ class SwinForImageClassification(SwinPreTrainedModel):
     @add_code_sample_docstrings(
         processor_class=_FEAT_EXTRACTOR_FOR_DOC,
         checkpoint=_IMAGE_CLASS_CHECKPOINT,
-        output_type=SequenceClassifierOutput,
+        output_type=SwinImageClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
         expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
     )
@@ -963,9 +1168,10 @@ class SwinForImageClassification(SwinPreTrainedModel):
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output

-        return SequenceClassifierOutput(
+        return SwinImageClassifierOutput(
             loss=loss,
             logits=logits,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
+            reshaped_hidden_states=outputs.reshaped_hidden_states,
         )
...
@@ -230,15 +230,6 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.return_dict = True

-        image_size = to_2tuple(self.model_tester.image_size)
-        patch_size = to_2tuple(self.model_tester.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        seq_len = num_patches
-        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
-        chunk_length = getattr(self.model_tester, "chunk_length", None)
-        if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
-            encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
-
         for model_class in self.all_model_classes:
             inputs_dict["output_attentions"] = True
             inputs_dict["output_hidden_states"] = False
@@ -248,8 +239,9 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
             model.eval()
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
-            self.assertEqual(len(attentions), len(self.model_tester.depths))
+            attentions = outputs.attentions
+            expected_num_attentions = len(self.model_tester.depths)
+            self.assertEqual(len(attentions), expected_num_attentions)

             # check that output_attentions also work using config
             del inputs_dict["output_attentions"]
@@ -260,19 +252,13 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
             model.eval()
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
-            self.assertEqual(len(attentions), len(self.model_tester.depths))
+            attentions = outputs.attentions
+            self.assertEqual(len(attentions), expected_num_attentions)

-            if chunk_length is not None:
-                self.assertListEqual(
-                    list(attentions[0].shape[-4:]),
-                    [self.model_tester.num_heads[0], window_size_squared, chunk_length, window_size_squared],
-                )
-            else:
-                self.assertListEqual(
-                    list(attentions[0].shape[-3:]),
-                    [self.model_tester.num_heads[0], window_size_squared, window_size_squared],
-                )
+            self.assertListEqual(
+                list(attentions[0].shape[-3:]),
+                [self.model_tester.num_heads[0], window_size_squared, window_size_squared],
+            )
             out_len = len(outputs)

             # Check attention is always last and order is fine
@@ -286,25 +272,19 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
             if hasattr(self.model_tester, "num_hidden_states_types"):
                 added_hidden_states = self.model_tester.num_hidden_states_types
-            elif self.is_encoder_decoder:
-                added_hidden_states = 2
             else:
-                added_hidden_states = 1
+                # also another +1 for reshaped_hidden_states
+                added_hidden_states = 2
             self.assertEqual(out_len + added_hidden_states, len(outputs))

-            self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
-            self.assertEqual(len(self_attentions), len(self.model_tester.depths))
+            self_attentions = outputs.attentions
+            self.assertEqual(len(self_attentions), expected_num_attentions)

-            if chunk_length is not None:
-                self.assertListEqual(
-                    list(self_attentions[0].shape[-4:]),
-                    [self.model_tester.num_heads[0], window_size_squared, chunk_length, window_size_squared],
-                )
-            else:
-                self.assertListEqual(
-                    list(self_attentions[0].shape[-3:]),
-                    [self.model_tester.num_heads[0], window_size_squared, window_size_squared],
-                )
+            self.assertListEqual(
+                list(self_attentions[0].shape[-3:]),
+                [self.model_tester.num_heads[0], window_size_squared, window_size_squared],
+            )

     def test_hidden_states_output(self):
         def check_hidden_states_output(inputs_dict, config, model_class):
@@ -315,7 +295,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))

-            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
+            hidden_states = outputs.hidden_states

             expected_num_layers = getattr(
                 self.model_tester, "expected_num_hidden_layers", len(self.model_tester.depths) + 1
@@ -325,6 +305,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
             # Swin has a different seq_length
             image_size = to_2tuple(self.model_tester.image_size)
             patch_size = to_2tuple(self.model_tester.patch_size)
             num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])

             self.assertListEqual(
@@ -332,6 +313,18 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
                 [num_patches, self.model_tester.embed_dim],
             )

+            reshaped_hidden_states = outputs.reshaped_hidden_states
+            self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
+
+            batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
+            reshaped_hidden_states = (
+                reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1)
+            )
+            self.assertListEqual(
+                list(reshaped_hidden_states.shape[-2:]),
+                [num_patches, self.model_tester.embed_dim],
+            )
+
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

         for model_class in self.all_model_classes:
@@ -395,7 +388,5 @@ class SwinModelIntegrationTest(unittest.TestCase):
         # verify the logits
         expected_shape = torch.Size((1, 1000))
         self.assertEqual(outputs.logits.shape, expected_shape)
         expected_slice = torch.tensor([-0.0948, -0.6454, -0.0921]).to(torch_device)
         self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))