Unverified commit 67acb07e authored by NielsRogge, committed by GitHub

Add Swin backbone (#20769)



* Add Swin backbone

* Remove line

* Add code example
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 94f8e21c
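As context for the diff below, here is a minimal, hedged sketch of how the new backbone is meant to be used once this commit is in place. The checkpoint name and stage names follow the code example added in modeling_swin.py in this commit; treat it as an illustration, not canonical API documentation.

```python
import requests
import torch
from PIL import Image

from transformers import AutoBackbone, AutoImageProcessor

# Checkpoint assumed from the docstring example added in this commit.
checkpoint = "microsoft/swin-tiny-patch4-window7-224"
processor = AutoImageProcessor.from_pretrained(checkpoint)
# out_features selects which Swin stages are exposed as feature maps.
model = AutoBackbone.from_pretrained(checkpoint, out_features=["stage2", "stage3", "stage4"])

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

with torch.no_grad():
    outputs = model(**processor(image, return_tensors="pt"))

# One feature map per requested stage, in NCHW format.
for name, feature_map in zip(model.out_features, outputs.feature_maps):
    print(name, tuple(feature_map.shape))
```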
@@ -2078,6 +2078,7 @@ else:
    _import_structure["models.swin"].extend(
        [
            "SWIN_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "SwinBackbone",
            "SwinForImageClassification",
            "SwinForMaskedImageModeling",
            "SwinModel",
@@ -5041,6 +5042,7 @@ if TYPE_CHECKING:
        )
        from .models.swin import (
            SWIN_PRETRAINED_MODEL_ARCHIVE_LIST,
+            SwinBackbone,
            SwinForImageClassification,
            SwinForMaskedImageModeling,
            SwinModel,
...
@@ -869,6 +869,7 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
        ("maskformer-swin", "MaskFormerSwinBackbone"),
        ("nat", "NatBackbone"),
        ("resnet", "ResNetBackbone"),
+        ("swin", "SwinBackbone"),
    ]
)
...
@@ -523,7 +523,6 @@ class DonutSwinLayer(nn.Module):
        self.shift_size = shift_size
        self.window_size = config.window_size
        self.input_resolution = input_resolution
-        self.set_shift_and_window_size(input_resolution)
        self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.attention = DonutSwinAttention(config, dim, num_heads, window_size=self.window_size)
        self.drop_path = DonutSwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
@@ -585,7 +584,9 @@ class DonutSwinLayer(nn.Module):
        shortcut = hidden_states

        hidden_states = self.layernorm_before(hidden_states)
        hidden_states = hidden_states.view(batch_size, height, width, channels)

        # pad hidden_states to multiples of window size
        hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
@@ -677,14 +678,15 @@ class DonutSwinStage(nn.Module):
            hidden_states = layer_outputs[0]

+        hidden_states_before_downsampling = hidden_states
        if self.downsample is not None:
            height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
            output_dimensions = (height, width, height_downsampled, width_downsampled)
-            hidden_states = self.downsample(layer_outputs[0], input_dimensions)
+            hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)
        else:
            output_dimensions = (height, width, height, width)

-        stage_outputs = (hidden_states, output_dimensions)
+        stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)

        if output_attentions:
            stage_outputs += layer_outputs[1:]
@@ -722,9 +724,9 @@ class DonutSwinEncoder(nn.Module):
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
+        output_hidden_states_before_downsampling: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, DonutSwinEncoderOutput]:
+        all_input_dimensions = ()
        all_hidden_states = () if output_hidden_states else None
        all_reshaped_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
@@ -755,12 +757,22 @@
                layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]
-            output_dimensions = layer_outputs[1]
+            hidden_states_before_downsampling = layer_outputs[1]
+            output_dimensions = layer_outputs[2]

            input_dimensions = (output_dimensions[-2], output_dimensions[-1])
+            all_input_dimensions += (input_dimensions,)

-            if output_hidden_states:
+            if output_hidden_states and output_hidden_states_before_downsampling:
+                batch_size, _, hidden_size = hidden_states_before_downsampling.shape
+                # rearrange b (h w) c -> b c h w
+                # here we use the original (not downsampled) height and width
+                reshaped_hidden_state = hidden_states_before_downsampling.view(
+                    batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
+                )
+                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
+                all_hidden_states += (hidden_states_before_downsampling,)
+                all_reshaped_hidden_states += (reshaped_hidden_state,)
+            elif output_hidden_states and not output_hidden_states_before_downsampling:
                batch_size, _, hidden_size = hidden_states.shape
                # rearrange b (h w) c -> b c h w
                reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
@@ -769,7 +781,7 @@
                all_reshaped_hidden_states += (reshaped_hidden_state,)

            if output_attentions:
-                all_self_attentions += layer_outputs[2:]
+                all_self_attentions += layer_outputs[3:]

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
...
@@ -36,6 +36,7 @@ else:
        "SwinForMaskedImageModeling",
        "SwinModel",
        "SwinPreTrainedModel",
+        "SwinBackbone",
    ]

try:
@@ -63,6 +64,7 @@ if TYPE_CHECKING:
    else:
        from .modeling_swin import (
            SWIN_PRETRAINED_MODEL_ARCHIVE_LIST,
+            SwinBackbone,
            SwinForImageClassification,
            SwinForMaskedImageModeling,
            SwinModel,
...
@@ -83,6 +83,9 @@ class SwinConfig(PretrainedConfig):
            The epsilon used by the layer normalization layers.
        encoder_stride (`int`, `optional`, defaults to 32):
            Factor to increase the spatial resolution by in the decoder head for masked image modeling.
+        out_features (`List[str]`, *optional*):
+            If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+            (depending on how many stages the model has). Will default to the last stage if unset.

    Example:
@@ -125,6 +128,7 @@ class SwinConfig(PretrainedConfig):
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        encoder_stride=32,
+        out_features=None,
        **kwargs
    ):
        super().__init__(**kwargs)
@@ -151,6 +155,16 @@ class SwinConfig(PretrainedConfig):
        # we set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel
        # this indicates the channel dimension after the last stage of the model
        self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
+        self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
+        if out_features is not None:
+            if not isinstance(out_features, list):
+                raise ValueError("out_features should be a list")
+            for feature in out_features:
+                if feature not in self.stage_names:
+                    raise ValueError(
+                        f"Feature {feature} is not a valid feature name. Valid names are {self.stage_names}"
+                    )
+        self.out_features = out_features


class SwinOnnxConfig(OnnxConfig):
...
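To make the new config option concrete, a small, hedged sketch of how `out_features` interacts with the `stage_names` validation added above. No checkpoint is involved; with the default `depths=[2, 2, 6, 2]`, the stage names become `["stem", "stage1", "stage2", "stage3", "stage4"]`, and the exact out_features values here are only illustrative.

```python
from transformers import SwinConfig

# Default depths give five stage names: stem + stage1..stage4.
config = SwinConfig(out_features=["stage2", "stage3", "stage4"])
print(config.stage_names)   # ['stem', 'stage1', 'stage2', 'stage3', 'stage4']
print(config.out_features)  # ['stage2', 'stage3', 'stage4']

# Names outside stage_names are rejected by the validation shown above.
try:
    SwinConfig(out_features=["stage5"])
except ValueError as err:
    print(err)
```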
@@ -26,7 +26,8 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
-from ...modeling_utils import PreTrainedModel
+from ...modeling_outputs import BackboneOutput
+from ...modeling_utils import BackboneMixin, PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...utils import (
    ModelOutput,
@@ -589,7 +590,6 @@ class SwinLayer(nn.Module):
        self.shift_size = shift_size
        self.window_size = config.window_size
        self.input_resolution = input_resolution
-        self.set_shift_and_window_size(input_resolution)
        self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.attention = SwinAttention(config, dim, num_heads, window_size=self.window_size)
        self.drop_path = SwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
@@ -651,7 +651,9 @@ class SwinLayer(nn.Module):
        shortcut = hidden_states

        hidden_states = self.layernorm_before(hidden_states)
        hidden_states = hidden_states.view(batch_size, height, width, channels)

        # pad hidden_states to multiples of window size
        hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
@@ -742,14 +744,15 @@ class SwinStage(nn.Module):
            hidden_states = layer_outputs[0]

+        hidden_states_before_downsampling = hidden_states
        if self.downsample is not None:
            height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
            output_dimensions = (height, width, height_downsampled, width_downsampled)
-            hidden_states = self.downsample(layer_outputs[0], input_dimensions)
+            hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)
        else:
            output_dimensions = (height, width, height, width)

-        stage_outputs = (hidden_states, output_dimensions)
+        stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)

        if output_attentions:
            stage_outputs += layer_outputs[1:]
@@ -786,9 +789,9 @@ class SwinEncoder(nn.Module):
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
+        output_hidden_states_before_downsampling: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, SwinEncoderOutput]:
+        all_input_dimensions = ()
        all_hidden_states = () if output_hidden_states else None
        all_reshaped_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
@@ -819,12 +822,22 @@
                layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]
-            output_dimensions = layer_outputs[1]
+            hidden_states_before_downsampling = layer_outputs[1]
+            output_dimensions = layer_outputs[2]

            input_dimensions = (output_dimensions[-2], output_dimensions[-1])
+            all_input_dimensions += (input_dimensions,)

-            if output_hidden_states:
+            if output_hidden_states and output_hidden_states_before_downsampling:
+                batch_size, _, hidden_size = hidden_states_before_downsampling.shape
+                # rearrange b (h w) c -> b c h w
+                # here we use the original (not downsampled) height and width
+                reshaped_hidden_state = hidden_states_before_downsampling.view(
+                    batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
+                )
+                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
+                all_hidden_states += (hidden_states_before_downsampling,)
+                all_reshaped_hidden_states += (reshaped_hidden_state,)
+            elif output_hidden_states and not output_hidden_states_before_downsampling:
                batch_size, _, hidden_size = hidden_states.shape
                # rearrange b (h w) c -> b c h w
                reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
@@ -833,7 +846,7 @@
                all_reshaped_hidden_states += (reshaped_hidden_state,)

            if output_attentions:
-                all_self_attentions += layer_outputs[2:]
+                all_self_attentions += layer_outputs[3:]

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
@@ -1214,3 +1227,118 @@ class SwinForImageClassification(SwinPreTrainedModel):
            attentions=outputs.attentions,
            reshaped_hidden_states=outputs.reshaped_hidden_states,
        )
+
+
+@add_start_docstrings(
+    """
+    Swin backbone, to be used with frameworks like DETR and MaskFormer.
+    """,
+    SWIN_START_DOCSTRING,
+)
+class SwinBackbone(SwinPreTrainedModel, BackboneMixin):
+    def __init__(self, config: SwinConfig):
+        super().__init__(config)
+
+        self.stage_names = config.stage_names
+
+        self.embeddings = SwinEmbeddings(config)
+        self.encoder = SwinEncoder(config, self.embeddings.patch_grid)
+
+        self.out_features = config.out_features if config.out_features is not None else [self.stage_names[-1]]
+
+        num_features = [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
+        self.out_feature_channels = {}
+        self.out_feature_channels["stem"] = config.embed_dim
+        for i, stage in enumerate(self.stage_names[1:]):
+            self.out_feature_channels[stage] = num_features[i]
+
+        # Add layer norms to hidden states of out_features
+        hidden_states_norms = dict()
+        for stage, num_channels in zip(self.out_features, self.channels):
+            hidden_states_norms[stage] = nn.LayerNorm(num_channels)
+        self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings.patch_embeddings
+
+    @property
+    def channels(self):
+        return [self.out_feature_channels[name] for name in self.out_features]
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> BackboneOutput:
+        """
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoImageProcessor, AutoBackbone
+        >>> import torch
+        >>> from PIL import Image
+        >>> import requests
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
+        >>> model = AutoBackbone.from_pretrained(
+        ...     "microsoft/swin-tiny-patch4-window7-224", out_features=["stage1", "stage2", "stage3", "stage4"]
+        ... )
+
+        >>> inputs = processor(image, return_tensors="pt")
+        >>> outputs = model(**inputs)
+
+        >>> feature_maps = outputs.feature_maps
+        >>> list(feature_maps[-1].shape)
+        [1, 768, 7, 7]
+        ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+        embedding_output, input_dimensions = self.embeddings(pixel_values)
+
+        outputs = self.encoder(
+            embedding_output,
+            input_dimensions,
+            head_mask=None,
+            output_attentions=output_attentions,
+            output_hidden_states=True,
+            output_hidden_states_before_downsampling=True,
+            return_dict=True,
+        )
+
+        hidden_states = outputs.reshaped_hidden_states
+
+        feature_maps = ()
+        for stage, hidden_state in zip(self.stage_names, hidden_states):
+            if stage in self.out_features:
+                batch_size, num_channels, height, width = hidden_state.shape
+                hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
+                hidden_state = hidden_state.view(batch_size, height * width, num_channels)
+                hidden_state = self.hidden_states_norms[stage](hidden_state)
+                hidden_state = hidden_state.view(batch_size, height, width, num_channels)
+                hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
+                feature_maps += (hidden_state,)
+
+        if not return_dict:
+            output = (feature_maps,)
+            if output_hidden_states:
+                output += (outputs.hidden_states,)
+            return output
+
+        return BackboneOutput(
+            feature_maps=feature_maps,
+            hidden_states=outputs.hidden_states if output_hidden_states else None,
+            attentions=outputs.attentions,
+        )
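Continuing the docstring example above (same `model` and `outputs`), a hedged sketch of the feature-map shapes one would expect per stage for that checkpoint with a 224x224 input. Only the last line is asserted by the example itself; the others follow from patch size 4, embed_dim=96 and depths=[2, 2, 6, 2].

```python
# Assumes `model` and `outputs` from the docstring example above
# (swin-tiny, out_features=["stage1", "stage2", "stage3", "stage4"]).
for name, feature_map in zip(model.out_features, outputs.feature_maps):
    print(name, list(feature_map.shape))
# Expected output (224x224 input, patch size 4, embed_dim=96):
#   stage1 [1, 96, 56, 56]    # 224 / 4
#   stage2 [1, 192, 28, 28]   # 224 / 8
#   stage3 [1, 384, 14, 14]   # 224 / 16
#   stage4 [1, 768, 7, 7]     # 224 / 32
```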
@@ -817,14 +817,15 @@ class Swinv2Stage(nn.Module):
            hidden_states = layer_outputs[0]

+        hidden_states_before_downsampling = hidden_states
        if self.downsample is not None:
            height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
            output_dimensions = (height, width, height_downsampled, width_downsampled)
-            hidden_states = self.downsample(layer_outputs[0], input_dimensions)
+            hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)
        else:
            output_dimensions = (height, width, height, width)

-        stage_outputs = (hidden_states, output_dimensions)
+        stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)

        if output_attentions:
            stage_outputs += layer_outputs[1:]
@@ -865,9 +866,9 @@ class Swinv2Encoder(nn.Module):
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
+        output_hidden_states_before_downsampling: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, Swinv2EncoderOutput]:
+        all_input_dimensions = ()
        all_hidden_states = () if output_hidden_states else None
        all_reshaped_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
@@ -898,12 +899,22 @@
                layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]
-            output_dimensions = layer_outputs[1]
+            hidden_states_before_downsampling = layer_outputs[1]
+            output_dimensions = layer_outputs[2]

            input_dimensions = (output_dimensions[-2], output_dimensions[-1])
+            all_input_dimensions += (input_dimensions,)

-            if output_hidden_states:
+            if output_hidden_states and output_hidden_states_before_downsampling:
+                batch_size, _, hidden_size = hidden_states_before_downsampling.shape
+                # rearrange b (h w) c -> b c h w
+                # here we use the original (not downsampled) height and width
+                reshaped_hidden_state = hidden_states_before_downsampling.view(
+                    batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
+                )
+                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
+                all_hidden_states += (hidden_states_before_downsampling,)
+                all_reshaped_hidden_states += (reshaped_hidden_state,)
+            elif output_hidden_states and not output_hidden_states_before_downsampling:
                batch_size, _, hidden_size = hidden_states.shape
                # rearrange b (h w) c -> b c h w
                reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
@@ -912,7 +923,7 @@
                all_reshaped_hidden_states += (reshaped_hidden_state,)

            if output_attentions:
-                all_self_attentions += layer_outputs[2:]
+                all_self_attentions += layer_outputs[3:]

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
...
@@ -5243,6 +5243,13 @@ class SqueezeBertPreTrainedModel(metaclass=DummyObject):
SWIN_PRETRAINED_MODEL_ARCHIVE_LIST = None


+class SwinBackbone(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
class SwinForImageClassification(metaclass=DummyObject):
    _backends = ["torch"]
...
@@ -30,7 +30,7 @@ if is_torch_available():
    import torch
    from torch import nn

-    from transformers import SwinForImageClassification, SwinForMaskedImageModeling, SwinModel
+    from transformers import SwinBackbone, SwinForImageClassification, SwinForMaskedImageModeling, SwinModel
    from transformers.models.swin.modeling_swin import SWIN_PRETRAINED_MODEL_ARCHIVE_LIST


if is_vision_available():
@@ -66,6 +66,7 @@ class SwinModelTester:
        use_labels=True,
        type_sequence_label_size=10,
        encoder_stride=8,
+        out_features=["stage1", "stage2"],
    ):
        self.parent = parent
        self.batch_size = batch_size
@@ -91,6 +92,7 @@ class SwinModelTester:
        self.use_labels = use_labels
        self.type_sequence_label_size = type_sequence_label_size
        self.encoder_stride = encoder_stride
+        self.out_features = out_features

    def prepare_config_and_inputs(self):
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
@@ -123,6 +125,7 @@ class SwinModelTester:
            layer_norm_eps=self.layer_norm_eps,
            initializer_range=self.initializer_range,
            encoder_stride=self.encoder_stride,
+            out_features=self.out_features,
        )

    def create_and_check_model(self, config, pixel_values, labels):
@@ -136,6 +139,33 @@ class SwinModelTester:
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))

+    def create_and_check_backbone(self, config, pixel_values, labels):
+        model = SwinBackbone(config=config)
+        model.to(torch_device)
+        model.eval()
+        result = model(pixel_values)
+
+        # verify hidden states
+        self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
+        self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, model.channels[0], 16, 16])
+
+        # verify channels
+        self.parent.assertEqual(len(model.channels), len(config.out_features))
+
+        # verify backbone works with out_features=None
+        config.out_features = None
+        model = SwinBackbone(config=config)
+        model.to(torch_device)
+        model.eval()
+        result = model(pixel_values)
+
+        # verify feature maps
+        self.parent.assertEqual(len(result.feature_maps), 1)
+        self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, model.channels[-1], 4, 4])
+
+        # verify channels
+        self.parent.assertEqual(len(model.channels), 1)
+
    def create_and_check_for_masked_image_modeling(self, config, pixel_values, labels):
        model = SwinForMaskedImageModeling(config=config)
        model.to(torch_device)
@@ -190,6 +220,7 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (
        (
            SwinModel,
+            SwinBackbone,
            SwinForImageClassification,
            SwinForMaskedImageModeling,
        )
@@ -222,6 +253,10 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

+    def test_backbone(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_backbone(*config_and_inputs)
+
    def test_for_masked_image_modeling(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_masked_image_modeling(*config_and_inputs)
@@ -230,8 +265,12 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_image_classification(*config_and_inputs)

+    @unittest.skip(reason="Swin does not use inputs_embeds")
    def test_inputs_embeds(self):
-        # Swin does not use inputs_embeds
        pass

+    @unittest.skip(reason="Swin Transformer does not use feedforward chunking")
+    def test_feed_forward_chunking(self):
+        pass
+
    def test_model_common_attributes(self):
@@ -299,11 +338,8 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

-            if hasattr(self.model_tester, "num_hidden_states_types"):
-                added_hidden_states = self.model_tester.num_hidden_states_types
-            else:
-                # also another +1 for reshaped_hidden_states
-                added_hidden_states = 2
+            # also another +1 for reshaped_hidden_states
+            added_hidden_states = 1 if model_class.__name__ == "SwinBackbone" else 2
            self.assertEqual(out_len + added_hidden_states, len(outputs))

            self_attentions = outputs.attentions
@@ -344,17 +380,18 @@ class SwinModelTest(ModelTesterMixin, unittest.TestCase):
            [num_patches, self.model_tester.embed_dim],
        )

+        if not model_class.__name__ == "SwinBackbone":
            reshaped_hidden_states = outputs.reshaped_hidden_states
            self.assertEqual(len(reshaped_hidden_states), expected_num_layers)

            batch_size, num_channels, height, width = reshaped_hidden_states[0].shape
            reshaped_hidden_states = (
                reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1)
            )
            self.assertListEqual(
                list(reshaped_hidden_states.shape[-2:]),
                [num_patches, self.model_tester.embed_dim],
            )

    def test_hidden_states_output(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...
@@ -681,6 +681,7 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [
    "NatBackbone",
    "MaskFormerSwinConfig",
    "MaskFormerSwinModel",
+    "SwinBackbone",
]
...