Unverified Commit 481a9578 authored by Abhiroop Tejomay, committed by GitHub

Enable dynamic resolution input for Swin Transformer and variants (#30656)



* add interpolation of positional encoding support to swin

* add style changes

* use default image processor and make size a dictionary
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* remove logits testing
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Refactor image size validation logic when interpolation is disabled
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* remove asserts in modeling
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* add dynamic resolution input support to swinv2

* change size to ensure interpolation encoding path is triggered

* set interpolate_pos_encoding default value to False
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* set interpolate_pos_encoding default value to False
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* set interpolate_pos_encoding default value to False
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* set interpolate_pos_encoding default value to False
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* set interpolate_pos_encoding default value to False
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* set interpolate_pos_encoding default value to False
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* set interpolate_pos_encoding default value to False
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* set interpolate_pos_encoding default value to False

* add dynamic resolution input to donut swin

* add dynamic resolution input to maskformer swin

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent b6eb708b
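The commit adds an `interpolate_pos_encoding` argument to the Swin, Swinv2, DonutSwin and MaskFormerSwin models so that the pre-trained position embeddings can be resized bicubically and the models accept resolutions other than the pre-training one. The per-file diffs follow. As a quick orientation, a minimal usage sketch (not part of the commit; checkpoint name, image file and target resolution are illustrative):

```python
import torch
from PIL import Image
from transformers import AutoImageProcessor, SwinModel

checkpoint = "microsoft/swin-tiny-patch4-window7-224"
model = SwinModel.from_pretrained(checkpoint).eval()
processor = AutoImageProcessor.from_pretrained(checkpoint)

image = Image.open("cat.jpg")  # any RGB image
# Resize to a resolution larger than the 224x224 used during pre-training.
inputs = processor(images=image, size={"height": 448, "width": 448}, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs, interpolate_pos_encoding=True)

# The sequence length now reflects the 448x448 input rather than the training resolution.
print(outputs.last_hidden_state.shape)
```

Note that the default Swin configurations set `use_absolute_embeddings=False`, so `position_embeddings` is `None` and the interpolation branch is a no-op; in that common case the flag's practical effect is to bypass the fixed-size check this commit adds to the patch embeddings.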
@@ -166,10 +166,48 @@ class DonutSwinEmbeddings(nn.Module):
        self.norm = nn.LayerNorm(config.embed_dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
        resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1
        if num_patches == num_positions and height == width:
            return self.position_embeddings
        class_pos_embed = self.position_embeddings[:, 0]
        patch_pos_embed = self.position_embeddings[:, 1:]
        dim = embeddings.shape[-1]
        h0 = height // self.config.patch_size
        w0 = width // self.config.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        h0, w0 = h0 + 0.1, w0 + 0.1
        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
            mode="bicubic",
            align_corners=False,
        )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor],
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Tuple[torch.Tensor]:
        _, num_channels, height, width = pixel_values.shape
        embeddings, output_dimensions = self.patch_embeddings(
            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
        )
        embeddings = self.norm(embeddings)
        batch_size, seq_len, _ = embeddings.size()
@@ -180,6 +218,9 @@ class DonutSwinEmbeddings(nn.Module):
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        if self.position_embeddings is not None:
            if interpolate_pos_encoding:
                embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
            else:
                embeddings = embeddings + self.position_embeddings

        embeddings = self.dropout(embeddings)
@@ -219,7 +260,9 @@ class DonutSwinPatchEmbeddings(nn.Module):
            pixel_values = nn.functional.pad(pixel_values, pad_values)
        return pixel_values

    def forward(
        self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
    ) -> Tuple[torch.Tensor, Tuple[int]]:
        _, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
@@ -227,6 +270,11 @@ class DonutSwinPatchEmbeddings(nn.Module):
            )
        # pad the input to be divisible by self.patch_size, if needed
        pixel_values = self.maybe_pad(pixel_values, height, width)
        if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model"
                f" ({self.image_size[0]}*{self.image_size[1]})."
            )
        embeddings = self.projection(pixel_values)
        _, _, height, width = embeddings.shape
        output_dimensions = (height, width)
@@ -849,6 +897,8 @@ SWIN_INPUTS_DOCSTRING = r"""
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
            Whether to interpolate the pre-trained position encodings.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@@ -899,6 +949,7 @@ class DonutSwinModel(DonutSwinPreTrainedModel):
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, DonutSwinModelOutput]:
        r"""
@@ -921,7 +972,9 @@ class DonutSwinModel(DonutSwinPreTrainedModel):
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, len(self.config.depths))

        embedding_output, input_dimensions = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
        )

        encoder_outputs = self.encoder(
            embedding_output,
......
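The `interpolate_pos_encoding` method added above (and repeated for the other backbones below) follows the DINO/ViT recipe it cites: reshape the stored patch positions into a square grid, resize the grid bicubically to the new number of patches, and carry the first (class-like) position over unchanged. A toy, self-contained sketch of that resizing with made-up shapes (a 7x7 pre-trained grid, patch size 4, and a 56x56 input; none of these numbers come from the diff):

```python
import torch
import torch.nn as nn

dim, patch_size, grid = 96, 4, 7             # toy values, not from the commit
pos = torch.randn(1, grid * grid + 1, dim)   # 1 class-like token + 49 patch positions

height = width = 56                          # new input resolution in pixels
# add a small offset to avoid floating point issues, as in facebookresearch/dino#8
h0 = height // patch_size + 0.1
w0 = width // patch_size + 0.1

patch_pos = pos[:, 1:].reshape(1, grid, grid, dim).permute(0, 3, 1, 2)
patch_pos = nn.functional.interpolate(
    patch_pos,
    scale_factor=(h0 / grid, w0 / grid),
    mode="bicubic",
    align_corners=False,
)
patch_pos = patch_pos.permute(0, 2, 3, 1).view(1, -1, dim)
new_pos = torch.cat((pos[:, :1], patch_pos), dim=1)
print(new_pos.shape)  # torch.Size([1, 197, 96]): 14 * 14 resized positions + 1
```

Only the patch positions are resized; the leading position is kept as-is, mirroring the referenced DINO implementation.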
@@ -163,11 +163,49 @@ class MaskFormerSwinEmbeddings(nn.Module):
        self.norm = nn.LayerNorm(config.embed_dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
        resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1
        if num_patches == num_positions and height == width:
            return self.position_embeddings
        class_pos_embed = self.position_embeddings[:, 0]
        patch_pos_embed = self.position_embeddings[:, 1:]
        dim = embeddings.shape[-1]
        h0 = height // self.config.patch_size
        w0 = width // self.config.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        h0, w0 = h0 + 0.1, w0 + 0.1
        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
            mode="bicubic",
            align_corners=False,
        )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def forward(self, pixel_values, interpolate_pos_encoding):
        _, num_channels, height, width = pixel_values.shape
        embeddings, output_dimensions = self.patch_embeddings(
            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
        )
        embeddings = self.norm(embeddings)

        if self.position_embeddings is not None:
            if interpolate_pos_encoding:
                embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
            else:
                embeddings = embeddings + self.position_embeddings

        embeddings = self.dropout(embeddings)
@@ -207,7 +245,9 @@ class MaskFormerSwinPatchEmbeddings(nn.Module):
            pixel_values = nn.functional.pad(pixel_values, pad_values)
        return pixel_values

    def forward(
        self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
    ) -> Tuple[torch.Tensor, Tuple[int]]:
        _, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
@@ -215,6 +255,11 @@ class MaskFormerSwinPatchEmbeddings(nn.Module):
            )
        # pad the input to be divisible by self.patch_size, if needed
        pixel_values = self.maybe_pad(pixel_values, height, width)
        if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model"
                f" ({self.image_size[0]}*{self.image_size[1]})."
            )
        embeddings = self.projection(pixel_values)
        _, _, height, width = embeddings.shape
        output_dimensions = (height, width)
@@ -780,6 +825,7 @@ class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel):
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        interpolate_pos_encoding=False,
        return_dict=None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -798,7 +844,9 @@ class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel):
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, len(self.config.depths))

        embedding_output, input_dimensions = self.embeddings(
            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
        )

        encoder_outputs = self.encoder(
            embedding_output,
......
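The Donut and MaskFormer backbones above receive the same two pieces as Swin itself: the interpolation method on the embeddings module and the explicit size check in the patch embeddings. A shape-only sketch against a randomly initialized DonutSwin (the configuration values are illustrative and the weights are random, so only the shapes and the raised error are meaningful):

```python
import torch
from transformers import DonutSwinConfig, DonutSwinModel

config = DonutSwinConfig(image_size=224, patch_size=4, embed_dim=96)  # illustrative config
model = DonutSwinModel(config).eval()

pixel_values = torch.randn(1, 3, 320, 320)  # not the configured 224x224

# Without interpolation, the new check in DonutSwinPatchEmbeddings rejects the mismatch.
try:
    model(pixel_values)
except ValueError as err:
    print(err)  # Input image size (320*320) doesn't match model (224*224).

# With interpolation enabled, the larger resolution is accepted.
with torch.no_grad():
    out = model(pixel_values, interpolate_pos_encoding=True)
# With the default four-stage depths: 320/4 = 80 -> 40 -> 20 -> 10 tokens per side.
print(out.last_hidden_state.shape)  # (1, 100, 768)
```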
@@ -252,10 +252,48 @@ class SwinEmbeddings(nn.Module):
        self.norm = nn.LayerNorm(config.embed_dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
        resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1
        if num_patches == num_positions and height == width:
            return self.position_embeddings
        class_pos_embed = self.position_embeddings[:, 0]
        patch_pos_embed = self.position_embeddings[:, 1:]
        dim = embeddings.shape[-1]
        h0 = height // self.config.patch_size
        w0 = width // self.config.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        h0, w0 = h0 + 0.1, w0 + 0.1
        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
            mode="bicubic",
            align_corners=False,
        )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor],
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Tuple[torch.Tensor]:
        _, num_channels, height, width = pixel_values.shape
        embeddings, output_dimensions = self.patch_embeddings(
            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
        )
        embeddings = self.norm(embeddings)
        batch_size, seq_len, _ = embeddings.size()
@@ -266,6 +304,9 @@ class SwinEmbeddings(nn.Module):
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        if self.position_embeddings is not None:
            if interpolate_pos_encoding:
                embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
            else:
                embeddings = embeddings + self.position_embeddings

        embeddings = self.dropout(embeddings)
@@ -304,7 +345,9 @@ class SwinPatchEmbeddings(nn.Module):
            pixel_values = nn.functional.pad(pixel_values, pad_values)
        return pixel_values

    def forward(
        self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
    ) -> Tuple[torch.Tensor, Tuple[int]]:
        _, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
@@ -312,6 +355,11 @@ class SwinPatchEmbeddings(nn.Module):
            )
        # pad the input to be divisible by self.patch_size, if needed
        pixel_values = self.maybe_pad(pixel_values, height, width)
        if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model"
                f" ({self.image_size[0]}*{self.image_size[1]})."
            )
        embeddings = self.projection(pixel_values)
        _, _, height, width = embeddings.shape
        output_dimensions = (height, width)
@@ -924,6 +972,8 @@ SWIN_INPUTS_DOCSTRING = r"""
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
            Whether to interpolate the pre-trained position encodings.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@@ -981,6 +1031,7 @@ class SwinModel(SwinPreTrainedModel):
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SwinModelOutput]:
        r"""
@@ -1003,7 +1054,9 @@ class SwinModel(SwinPreTrainedModel):
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, len(self.config.depths))

        embedding_output, input_dimensions = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
        )

        encoder_outputs = self.encoder(
            embedding_output,
@@ -1074,6 +1127,7 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SwinMaskedImageModelingOutput]:
        r"""
@@ -1113,6 +1167,7 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )
@@ -1156,6 +1211,14 @@ class SwinForMaskedImageModeling(SwinPreTrainedModel):
    """
    Swin Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
    the [CLS] token) e.g. for ImageNet.

    <Tip>

    Note that it's possible to fine-tune Swin on higher resolution images than the ones it has been trained on, by
    setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
    position embeddings to the higher resolution.

    </Tip>
    """,
    SWIN_START_DOCSTRING,
)
@@ -1188,6 +1251,7 @@ class SwinForImageClassification(SwinPreTrainedModel):
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SwinImageClassifierOutput]:
        r"""
@@ -1203,6 +1267,7 @@ class SwinForImageClassification(SwinPreTrainedModel):
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )
......
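The <Tip> added above points out that the flag is mainly useful for fine-tuning or running inference at a resolution above the pre-training one. A hedged sketch of that workflow for classification (checkpoint, image file and target resolution are illustrative):

```python
import torch
from PIL import Image
from transformers import AutoImageProcessor, SwinForImageClassification

checkpoint = "microsoft/swin-tiny-patch4-window7-224"  # any Swin classification checkpoint
model = SwinForImageClassification.from_pretrained(checkpoint).eval()
processor = AutoImageProcessor.from_pretrained(checkpoint)

image = Image.open("example.jpg")
inputs = processor(images=image, size={"height": 384, "width": 384}, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs, interpolate_pos_encoding=True).logits
print(model.config.id2label[logits.argmax(-1).item()])
```

For fine-tuning, the same forward call also accepts `labels` inside the usual training loop; nothing else changes.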
@@ -295,10 +295,48 @@ class Swinv2Embeddings(nn.Module):
        self.norm = nn.LayerNorm(config.embed_dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
        resolution images.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1
        if num_patches == num_positions and height == width:
            return self.position_embeddings
        class_pos_embed = self.position_embeddings[:, 0]
        patch_pos_embed = self.position_embeddings[:, 1:]
        dim = embeddings.shape[-1]
        h0 = height // self.config.patch_size
        w0 = width // self.config.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        h0, w0 = h0 + 0.1, w0 + 0.1
        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
            mode="bicubic",
            align_corners=False,
        )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor],
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Tuple[torch.Tensor]:
        _, num_channels, height, width = pixel_values.shape
        embeddings, output_dimensions = self.patch_embeddings(
            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
        )
        embeddings = self.norm(embeddings)
        batch_size, seq_len, _ = embeddings.size()
@@ -309,6 +347,9 @@ class Swinv2Embeddings(nn.Module):
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        if self.position_embeddings is not None:
            if interpolate_pos_encoding:
                embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
            else:
                embeddings = embeddings + self.position_embeddings

        embeddings = self.dropout(embeddings)
@@ -348,7 +389,9 @@ class Swinv2PatchEmbeddings(nn.Module):
            pixel_values = nn.functional.pad(pixel_values, pad_values)
        return pixel_values

    def forward(
        self, pixel_values: Optional[torch.FloatTensor], interpolate_pos_encoding: bool = False
    ) -> Tuple[torch.Tensor, Tuple[int]]:
        _, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
@@ -356,6 +399,11 @@ class Swinv2PatchEmbeddings(nn.Module):
            )
        # pad the input to be divisible by self.patch_size, if needed
        pixel_values = self.maybe_pad(pixel_values, height, width)
        if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model"
                f" ({self.image_size[0]}*{self.image_size[1]})."
            )
        embeddings = self.projection(pixel_values)
        _, _, height, width = embeddings.shape
        output_dimensions = (height, width)
@@ -979,6 +1027,8 @@ SWINV2_INPUTS_DOCSTRING = r"""
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
            Whether to interpolate the pre-trained position encodings.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@@ -1031,6 +1081,7 @@ class Swinv2Model(Swinv2PreTrainedModel):
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Swinv2ModelOutput]:
        r"""
@@ -1053,7 +1104,9 @@ class Swinv2Model(Swinv2PreTrainedModel):
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, len(self.config.depths))

        embedding_output, input_dimensions = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
        )

        encoder_outputs = self.encoder(
            embedding_output,
@@ -1126,6 +1179,7 @@ class Swinv2ForMaskedImageModeling(Swinv2PreTrainedModel):
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Swinv2MaskedImageModelingOutput]:
        r"""
@@ -1165,6 +1219,7 @@ class Swinv2ForMaskedImageModeling(Swinv2PreTrainedModel):
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )
@@ -1208,6 +1263,14 @@ class Swinv2ForMaskedImageModeling(Swinv2PreTrainedModel):
    """
    Swinv2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
    of the [CLS] token) e.g. for ImageNet.

    <Tip>

    Note that it's possible to fine-tune SwinV2 on higher resolution images than the ones it has been trained on, by
    setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
    position embeddings to the higher resolution.

    </Tip>
    """,
    SWINV2_START_DOCSTRING,
)
@@ -1241,6 +1304,7 @@ class Swinv2ForImageClassification(Swinv2PreTrainedModel):
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Swinv2ImageClassifierOutput]:
        r"""
@@ -1256,6 +1320,7 @@ class Swinv2ForImageClassification(Swinv2PreTrainedModel):
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )
......
@@ -493,6 +493,26 @@ class SwinModelIntegrationTest(unittest.TestCase):
        expected_slice = torch.tensor([-0.0948, -0.6454, -0.0921]).to(torch_device)
        self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))

    @slow
    def test_inference_interpolate_pos_encoding(self):
        # Swin models have an `interpolate_pos_encoding` argument in their forward method,
        # allowing to interpolate the pre-trained position embeddings in order to use
        # the model on higher resolutions.
        model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224").to(torch_device)

        image_processor = self.default_image_processor
        image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
        inputs = image_processor(images=image, size={"height": 481, "width": 481}, return_tensors="pt")
        pixel_values = inputs.pixel_values.to(torch_device)

        # forward pass
        with torch.no_grad():
            outputs = model(pixel_values, interpolate_pos_encoding=True)

        # verify the shape of the last hidden state
        expected_shape = torch.Size((1, 256, 768))
        self.assertEqual(outputs.last_hidden_state.shape, expected_shape)

@require_torch
class SwinBackboneTest(unittest.TestCase, BackboneTesterMixin):
......
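The expected shape in the test above follows from the architecture rather than from stored logits: with patch size 4 and embed_dim 96, a 481x481 input is padded to 484x484 and split into 121x121 patches, and the three patch-merging stages (each padding to an even size) reduce that to 61, 31 and finally 16 tokens per side, i.e. 16 * 16 = 256 tokens with 8 * 96 = 768 channels. The Swinv2 test below expects the same (1, 256, 768) shape for the same reason, since the swinv2-tiny checkpoint shares the patch size and embedding dimension.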
@@ -485,6 +485,26 @@ class Swinv2ModelIntegrationTest(unittest.TestCase):
        expected_slice = torch.tensor([-0.3947, -0.4306, 0.0026]).to(torch_device)
        self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))

    @slow
    def test_inference_interpolate_pos_encoding(self):
        # Swinv2 models have an `interpolate_pos_encoding` argument in their forward method,
        # allowing to interpolate the pre-trained position embeddings in order to use
        # the model on higher resolutions.
        model = Swinv2Model.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256").to(torch_device)

        image_processor = self.default_image_processor
        image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
        inputs = image_processor(images=image, size={"height": 481, "width": 481}, return_tensors="pt")
        pixel_values = inputs.pixel_values.to(torch_device)

        # forward pass
        with torch.no_grad():
            outputs = model(pixel_values, interpolate_pos_encoding=True)

        # verify the shape of the last hidden state
        expected_shape = torch.Size((1, 256, 768))
        self.assertEqual(outputs.last_hidden_state.shape, expected_shape)

@require_torch
class Swinv2BackboneTest(unittest.TestCase, BackboneTesterMixin):
......