Unverified commit e5103a76, authored by BHUVAN M and committed by GitHub

Added interpolation for the ViTMAE model in PyTorch as well as TF (#30732)



* Added interpolation for the ViTMAE model in PyTorch as well as TF.

* Update modeling_vit_mae.py

irregular import fixed

* small changes and proper formatting

* changes suggested in review.

* modified decoder interpolate_func

* arguments and docstring fix

* Apply suggestions from code review

doc fixes
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent a3cdff41
@@ -240,6 +240,38 @@ class TFViTMAEEmbeddings(keras.layers.Layer):
             with tf.name_scope(self.patch_embeddings.name):
                 self.patch_embeddings.build(None)

+    def interpolate_pos_encoding(self, embeddings, height, width) -> tf.Tensor:
+        """
+        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
+        resolution images.
+
+        Source:
+        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+        """
+        batch_size, seq_len, dim = shape_list(embeddings)
+        num_patches = seq_len - 1
+
+        _, num_positions, _ = shape_list(self.position_embeddings)
+        num_positions -= 1
+
+        if num_patches == num_positions and height == width:
+            return self.position_embeddings
+        class_pos_embed = self.position_embeddings[:, :1]
+        patch_pos_embed = self.position_embeddings[:, 1:]
+        h0 = height // self.config.patch_size
+        w0 = width // self.config.patch_size
+        patch_pos_embed = tf.image.resize(
+            images=tf.reshape(
+                patch_pos_embed, shape=(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+            ),
+            size=(h0, w0),
+            method="bicubic",
+        )
+
+        patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim))
+        return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1)
+
     def random_masking(self, sequence: tf.Tensor, noise: tf.Tensor | None = None):
         """
         Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
@@ -281,17 +313,23 @@ class TFViTMAEEmbeddings(keras.layers.Layer):

         return sequence_unmasked, mask, ids_restore

-    def call(self, pixel_values: tf.Tensor, noise: tf.Tensor = None) -> tf.Tensor:
-        embeddings = self.patch_embeddings(pixel_values)
+    def call(
+        self, pixel_values: tf.Tensor, noise: tf.Tensor = None, interpolate_pos_encoding: bool = False
+    ) -> tf.Tensor:
+        batch_size, num_channels, height, width = shape_list(pixel_values)
+        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+        if interpolate_pos_encoding:
+            position_embeddings = self.interpolate_pos_encoding(embeddings, height, width)
+        else:
+            position_embeddings = self.position_embeddings

         # add position embeddings w/o cls token
-        embeddings = embeddings + self.position_embeddings[:, 1:, :]
+        embeddings = embeddings + position_embeddings[:, 1:, :]

         # masking: length -> length * config.mask_ratio
         embeddings, mask, ids_restore = self.random_masking(embeddings, noise)

         # append cls token
-        cls_token = self.cls_token + self.position_embeddings[:, :1, :]
+        cls_token = self.cls_token + position_embeddings[:, :1, :]
         cls_tokens = tf.tile(cls_token, (shape_list(embeddings)[0], 1, 1))
         embeddings = tf.concat([cls_tokens, embeddings], axis=1)
@@ -329,7 +367,9 @@ class TFViTMAEPatchEmbeddings(keras.layers.Layer):
             name="projection",
         )

-    def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
+    def call(
+        self, pixel_values: tf.Tensor, training: bool = False, interpolate_pos_encoding: bool = False
+    ) -> tf.Tensor:
         batch_size, num_channels, height, width = shape_list(pixel_values)
         if tf.executing_eagerly():
             if num_channels != self.num_channels:
@@ -337,7 +377,7 @@ class TFViTMAEPatchEmbeddings(keras.layers.Layer):
                     "Make sure that the channel dimension of the pixel values match with the one set in the"
                     " configuration."
                 )
-            if height != self.image_size[0] or width != self.image_size[1]:
+            if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
                 raise ValueError(
                     f"Input image size ({height}*{width}) doesn't match model"
                     f" ({self.image_size[0]}*{self.image_size[1]})."
@@ -741,9 +781,13 @@ class TFViTMAEMainLayer(keras.layers.Layer):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         training: bool = False,
+        interpolate_pos_encoding: bool = False,
     ) -> Union[TFViTMAEModelOutput, Tuple[tf.Tensor]]:
         embedding_output, mask, ids_restore = self.embeddings(
-            pixel_values=pixel_values, training=training, noise=noise
+            pixel_values=pixel_values,
+            training=training,
+            noise=noise,
+            interpolate_pos_encoding=interpolate_pos_encoding,
         )

         # Prepare head mask if needed
@@ -874,6 +918,9 @@ VIT_MAE_INPUTS_DOCSTRING = r"""
         training (`bool`, *optional*, defaults to `False`):
             Whether or not to use the model in training mode (some modules like dropout modules have different
             behaviors between training and evaluation).
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the position encodings at the encoder and decoder.
 """
@@ -902,6 +949,7 @@ class TFViTMAEModel(TFViTMAEPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         training: bool = False,
+        interpolate_pos_encoding: bool = False,
     ) -> Union[TFViTMAEModelOutput, Tuple[tf.Tensor]]:
         r"""
         Returns:
@@ -931,6 +979,7 @@ class TFViTMAEModel(TFViTMAEPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
+            interpolate_pos_encoding=interpolate_pos_encoding,
         )

         return outputs
@@ -1004,6 +1053,39 @@ class TFViTMAEDecoder(keras.layers.Layer):
                 with tf.name_scope(layer.name):
                     layer.build(None)

+    def interpolate_pos_encoding(self, embeddings) -> tf.Tensor:
+        """
+        This method is a modified version of the interpolation function for the ViT-MAE model at the decoder, that
+        allows to interpolate the pre-trained decoder position encodings, to be able to use the model on higher
+        resolution images.
+
+        Source:
+        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+        """
+        # [batch_size, num_patches + 1, hidden_size]
+        _, num_positions, dim = shape_list(self.decoder_pos_embed)
+
+        # -1 removes the class dimension since we later append it without interpolation
+        seq_len = shape_list(embeddings)[1] - 1
+        num_positions = num_positions - 1
+
+        # Separation of class token and patch tokens
+        class_pos_embed = self.decoder_pos_embed[:, :1, :]
+        patch_pos_embed = self.decoder_pos_embed[:, 1:, :]
+
+        # interpolate the position embeddings
+        patch_pos_embed = tf.image.resize(
+            images=tf.reshape(patch_pos_embed, shape=(1, 1, -1, dim)),
+            size=(1, seq_len),
+            method="bicubic",
+        )
+
+        # [1, seq_len, hidden_size]
+        patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim))
+
+        # Adding the class token back
+        return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1)
+
     def call(
         self,
         hidden_states,
@@ -1011,10 +1093,10 @@ class TFViTMAEDecoder(keras.layers.Layer):
         output_attentions=False,
         output_hidden_states=False,
         return_dict=True,
+        interpolate_pos_encoding=False,
     ):
         # embed tokens
         x = self.decoder_embed(hidden_states)

         # append mask tokens to sequence
         mask_tokens = tf.tile(
             self.mask_token,
@@ -1023,10 +1105,12 @@ class TFViTMAEDecoder(keras.layers.Layer):
         x_ = tf.concat([x[:, 1:, :], mask_tokens], axis=1)  # no cls token
         x_ = tf.gather(x_, axis=1, batch_dims=1, indices=ids_restore)  # unshuffle
         x = tf.concat([x[:, :1, :], x_], axis=1)  # append cls token
+        if interpolate_pos_encoding:
+            decoder_pos_embed = self.interpolate_pos_encoding(x)
+        else:
+            decoder_pos_embed = self.decoder_pos_embed

         # add pos embed
-        hidden_states = x + self.decoder_pos_embed
+        hidden_states = x + decoder_pos_embed

         # apply Transformer layers (blocks)
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
@@ -1083,11 +1167,13 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
     def _prune_heads(self, heads_to_prune):
         raise NotImplementedError

-    def patchify(self, pixel_values):
+    def patchify(self, pixel_values, interpolate_pos_encoding: bool = False):
         """
         Args:
             pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)` or `(batch_size, num_channels, height, width)`):
                 Pixel values.
+            interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+                Interpolation flag passed during the forward pass.

         Returns:
             `tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
@@ -1099,6 +1185,7 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
             pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))

         # sanity checks
+        if not interpolate_pos_encoding:
             tf.debugging.assert_equal(
                 shape_list(pixel_values)[1],
                 shape_list(pixel_values)[2],
@@ -1119,51 +1206,61 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
         # patchify
         batch_size = shape_list(pixel_values)[0]
-        num_patches_one_direction = shape_list(pixel_values)[2] // patch_size
+        num_patches_h = shape_list(pixel_values)[1] // patch_size
+        num_patches_w = shape_list(pixel_values)[2] // patch_size
         patchified_pixel_values = tf.reshape(
             pixel_values,
-            (batch_size, num_patches_one_direction, patch_size, num_patches_one_direction, patch_size, num_channels),
+            (batch_size, num_patches_h, patch_size, num_patches_w, patch_size, num_channels),
         )
         patchified_pixel_values = tf.einsum("nhpwqc->nhwpqc", patchified_pixel_values)
         patchified_pixel_values = tf.reshape(
             patchified_pixel_values,
-            (batch_size, num_patches_one_direction * num_patches_one_direction, patch_size**2 * num_channels),
+            (batch_size, num_patches_h * num_patches_w, patch_size**2 * num_channels),
         )
         return patchified_pixel_values

-    def unpatchify(self, patchified_pixel_values):
+    def unpatchify(self, patchified_pixel_values, original_image_size: Optional[Tuple[int, int]] = None):
         """
         Args:
             patchified_pixel_values (`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
                 Patchified pixel values.
+            original_image_size (`Tuple[int, int]`, *optional*):
+                Original image size.

         Returns:
             `tf.Tensor` of shape `(batch_size, height, width, num_channels)`:
                 Pixel values.
         """
         patch_size, num_channels = self.config.patch_size, self.config.num_channels
-        num_patches_one_direction = int(shape_list(patchified_pixel_values)[1] ** 0.5)
+        original_image_size = (
+            original_image_size
+            if original_image_size is not None
+            else (self.config.image_size, self.config.image_size)
+        )
+        original_height, original_width = original_image_size
+        num_patches_h = original_height // patch_size
+        num_patches_w = original_width // patch_size
         # sanity check
         tf.debugging.assert_equal(
-            num_patches_one_direction * num_patches_one_direction,
+            num_patches_h * num_patches_w,
             shape_list(patchified_pixel_values)[1],
-            message="Make sure that the number of patches can be squared",
+            message=f"The number of patches in the patchified pixel values, {shape_list(patchified_pixel_values)[1]}, does not match the number of patches of the original image, {num_patches_h}*{num_patches_w}",
         )

         # unpatchify
         batch_size = shape_list(patchified_pixel_values)[0]
         patchified_pixel_values = tf.reshape(
             patchified_pixel_values,
-            (batch_size, num_patches_one_direction, num_patches_one_direction, patch_size, patch_size, num_channels),
+            (batch_size, num_patches_h, num_patches_w, patch_size, patch_size, num_channels),
         )
         patchified_pixel_values = tf.einsum("nhwpqc->nhpwqc", patchified_pixel_values)
         pixel_values = tf.reshape(
             patchified_pixel_values,
-            (batch_size, num_patches_one_direction * patch_size, num_patches_one_direction * patch_size, num_channels),
+            (batch_size, num_patches_h * patch_size, num_patches_w * patch_size, num_channels),
         )
         return pixel_values

-    def forward_loss(self, pixel_values, pred, mask):
+    def forward_loss(self, pixel_values, pred, mask, interpolate_pos_encoding: bool = False):
         """
         Args:
             pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)`):
@@ -1172,11 +1269,13 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
                 Predicted pixel values.
             mask (`tf.Tensor` of shape `(batch_size, sequence_length)`):
                 Tensor indicating which patches are masked (1) and which are not (0).
+            interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+                Interpolation flag passed during the forward pass.

         Returns:
             `tf.Tensor`: Pixel reconstruction loss.
         """
-        target = self.patchify(pixel_values)
+        target = self.patchify(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
         if self.config.norm_pix_loss:
             mean = tf.reduce_mean(target, axis=-1, keepdims=True)
             var = tf.math.reduce_variance(target, axis=-1, keepdims=True)
@@ -1201,6 +1300,7 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         training: bool = False,
+        interpolate_pos_encoding: bool = False,
     ) -> Union[TFViTMAEForPreTrainingOutput, Tuple[tf.Tensor]]:
         r"""
         Returns:
@@ -1234,16 +1334,18 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
+            interpolate_pos_encoding=interpolate_pos_encoding,
         )
         latent = outputs.last_hidden_state
         ids_restore = outputs.ids_restore
         mask = outputs.mask

-        decoder_outputs = self.decoder(latent, ids_restore)  # [batch_size, num_patches, patch_size**2*3]
+        # [batch_size, num_patches, patch_size**2*3]
+        decoder_outputs = self.decoder(latent, ids_restore, interpolate_pos_encoding=interpolate_pos_encoding)
         logits = decoder_outputs.logits
-        loss = self.forward_loss(pixel_values, logits, mask)
+        loss = self.forward_loss(pixel_values, logits, mask, interpolate_pos_encoding=interpolate_pos_encoding)

         if not return_dict:
             output = (logits, mask, ids_restore) + outputs[2:]
...
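For context on what the new flag enables on the TF side, here is a minimal usage sketch, not part of the diff itself. It assumes the `facebook/vit-mae-base` checkpoint used by the tests below and a 480x640 input; with `interpolate_pos_encoding=True` the pre-trained 14x14 grid of position embeddings is resized to the 30x40 grid implied by the image.

```python
# Minimal sketch (not from the diff): run TFViTMAEForPreTraining on a
# non-224x224 image by interpolating the pre-trained position embeddings.
# Assumes the `facebook/vit-mae-base` checkpoint and network access.
import numpy as np
from transformers import TFViTMAEForPreTraining, ViTImageProcessor

model = TFViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
image_processor = ViTImageProcessor.from_pretrained("facebook/vit-mae-base")

# A dummy 480x640 RGB image; do_resize=False keeps the original resolution,
# so the 14x14 pre-trained position grid must be interpolated to 30x40.
image = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)
inputs = image_processor(images=image, do_resize=False, return_tensors="tf")

outputs = model(**inputs, interpolate_pos_encoding=True)
# (480 // 16) * (640 // 16) = 1200 patches, each decoded to 16 * 16 * 3 = 768 values
print(outputs.logits.shape)  # (1, 1200, 768)
```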
@@ -223,6 +223,41 @@ class ViTMAEEmbeddings(nn.Module):
         # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
         torch.nn.init.normal_(self.cls_token, std=self.config.initializer_range)

+    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        """
+        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
+        resolution images.
+
+        Source:
+        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+        """
+        num_patches = embeddings.shape[1] - 1
+        num_positions = self.position_embeddings.shape[1] - 1
+        if num_patches == num_positions and height == width:
+            return self.position_embeddings
+        class_pos_embed = self.position_embeddings[:, 0, :]
+        patch_pos_embed = self.position_embeddings[:, 1:, :]
+        dim = embeddings.shape[-1]
+        h0 = height // self.config.patch_size
+        w0 = width // self.config.patch_size
+        # we add a small number to avoid floating point error in the interpolation
+        # see discussion at https://github.com/facebookresearch/dino/issues/8
+        h0, w0 = h0 + 0.1, w0 + 0.1
+        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
+        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed,
+            scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
+            mode="bicubic",
+            align_corners=False,
+        )
+        if int(h0) != patch_pos_embed.shape[-2] or int(w0) != patch_pos_embed.shape[-1]:
+            raise ValueError("Width or height does not match with the interpolated position embeddings")
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
     def random_masking(self, sequence, noise=None):
         """
         Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
@@ -255,18 +290,22 @@ class ViTMAEEmbeddings(nn.Module):

         return sequence_unmasked, mask, ids_restore

-    def forward(self, pixel_values, noise=None):
+    def forward(self, pixel_values, noise=None, interpolate_pos_encoding: bool = False):
         batch_size, num_channels, height, width = pixel_values.shape
-        embeddings = self.patch_embeddings(pixel_values)
+        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+        if interpolate_pos_encoding:
+            position_embeddings = self.interpolate_pos_encoding(embeddings, height, width)
+        else:
+            position_embeddings = self.position_embeddings

         # add position embeddings w/o cls token
-        embeddings = embeddings + self.position_embeddings[:, 1:, :]
+        embeddings = embeddings + position_embeddings[:, 1:, :]

         # masking: length -> length * config.mask_ratio
         embeddings, mask, ids_restore = self.random_masking(embeddings, noise)

         # append cls token
-        cls_token = self.cls_token + self.position_embeddings[:, :1, :]
+        cls_token = self.cls_token + position_embeddings[:, :1, :]
         cls_tokens = cls_token.expand(embeddings.shape[0], -1, -1)
         embeddings = torch.cat((cls_tokens, embeddings), dim=1)
@@ -294,13 +333,14 @@ class ViTMAEPatchEmbeddings(nn.Module):
         self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

-    def forward(self, pixel_values):
+    def forward(self, pixel_values, interpolate_pos_encoding: bool = False):
         batch_size, num_channels, height, width = pixel_values.shape
         if num_channels != self.num_channels:
             raise ValueError(
                 "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
             )
-        if height != self.image_size[0] or width != self.image_size[1]:
+        if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
             raise ValueError(
                 f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
             )
@@ -657,6 +697,9 @@ VIT_MAE_INPUTS_DOCSTRING = r"""
             more detail.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings. This is mainly used to run the model on
+            higher resolution images.
 """
@@ -698,6 +741,7 @@ class ViTMAEModel(ViTMAEPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
     ) -> Union[Tuple, ViTMAEModelOutput]:
         r"""
         Returns:
@@ -735,7 +779,9 @@ class ViTMAEModel(ViTMAEPreTrainedModel):
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

-        embedding_output, mask, ids_restore = self.embeddings(pixel_values, noise=noise)
+        embedding_output, mask, ids_restore = self.embeddings(
+            pixel_values, noise=noise, interpolate_pos_encoding=interpolate_pos_encoding
+        )

         encoder_outputs = self.encoder(
             embedding_output,
@@ -785,6 +831,47 @@ class ViTMAEDecoder(nn.Module):
         self.config = config
         self.initialize_weights(num_patches)

+    def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor:
+        """
+        This method is a modified version of the interpolation function for the ViT-MAE model at the decoder, that
+        allows to interpolate the pre-trained decoder position encodings, to be able to use the model on higher
+        resolution images.
+
+        Source:
+        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+        """
+        # -1 removes the class dimension since we later append it without interpolation
+        embeddings_positions = embeddings.shape[1] - 1
+        num_positions = self.decoder_pos_embed.shape[1] - 1
+
+        # Separation of class token and patch tokens
+        class_pos_embed = self.decoder_pos_embed[:, 0, :]
+        patch_pos_embed = self.decoder_pos_embed[:, 1:, :]
+
+        # To retain the final 3d tensor with the required dimensions
+        dim = self.decoder_pos_embed.shape[-1]
+
+        # Increasing a dimension to enable bicubic interpolation
+        patch_pos_embed = patch_pos_embed.reshape(1, 1, -1, dim)
+        # permute to bring the dimension to be interpolated to the last position
+        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+        # Interpolate the decoder position embeddings with respect to the embeddings shape (x);
+        # the leading 1 keeps the other dimension constant
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed,
+            scale_factor=(1, embeddings_positions / num_positions),
+            mode="bicubic",
+            align_corners=False,
+        )
+
+        # Converting back to the original shape
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        # Adding the class token back
+        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
+
     def initialize_weights(self, num_patches):
         # initialize (and freeze) position embeddings by sin-cos embedding
         decoder_pos_embed = get_2d_sincos_pos_embed(
@@ -802,6 +889,7 @@ class ViTMAEDecoder(nn.Module):
         output_attentions=False,
         output_hidden_states=False,
         return_dict=True,
+        interpolate_pos_encoding: bool = False,
     ):
         # embed tokens
         x = self.decoder_embed(hidden_states)
@@ -812,9 +900,12 @@ class ViTMAEDecoder(nn.Module):
         # unshuffle
         x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x_.device))
         x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

         # add pos embed
-        hidden_states = x + self.decoder_pos_embed
+        if interpolate_pos_encoding:
+            decoder_pos_embed = self.interpolate_pos_encoding(x)
+        else:
+            decoder_pos_embed = self.decoder_pos_embed
+        hidden_states = x + decoder_pos_embed

         # apply Transformer layers (blocks)
         all_hidden_states = () if output_hidden_states else None
@@ -893,11 +984,13 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
         for layer, heads in heads_to_prune.items():
             self.encoder.layer[layer].attention.prune_heads(heads)

-    def patchify(self, pixel_values):
+    def patchify(self, pixel_values, interpolate_pos_encoding: bool = False):
         """
         Args:
             pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                 Pixel values.
+            interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+                Interpolation flag passed during the forward pass.

         Returns:
             `torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
@@ -905,7 +998,9 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
         """
         patch_size, num_channels = self.config.patch_size, self.config.num_channels
         # sanity checks
-        if (pixel_values.shape[2] != pixel_values.shape[3]) or (pixel_values.shape[2] % patch_size != 0):
+        if not interpolate_pos_encoding and (
+            pixel_values.shape[2] != pixel_values.shape[3] or pixel_values.shape[2] % patch_size != 0
+        ):
             raise ValueError("Make sure the pixel values have a squared size that is divisible by the patch size")
         if pixel_values.shape[1] != num_channels:
             raise ValueError(
@@ -914,38 +1009,50 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
         # patchify
         batch_size = pixel_values.shape[0]
-        num_patches_one_direction = pixel_values.shape[2] // patch_size
+        num_patches_h = pixel_values.shape[2] // patch_size
+        num_patches_w = pixel_values.shape[3] // patch_size
         patchified_pixel_values = pixel_values.reshape(
-            batch_size, num_channels, num_patches_one_direction, patch_size, num_patches_one_direction, patch_size
+            batch_size, num_channels, num_patches_h, patch_size, num_patches_w, patch_size
         )
         patchified_pixel_values = torch.einsum("nchpwq->nhwpqc", patchified_pixel_values)
         patchified_pixel_values = patchified_pixel_values.reshape(
-            batch_size, num_patches_one_direction * num_patches_one_direction, patch_size**2 * num_channels
+            batch_size, num_patches_h * num_patches_w, patch_size**2 * num_channels
         )
         return patchified_pixel_values

-    def unpatchify(self, patchified_pixel_values):
+    def unpatchify(self, patchified_pixel_values, original_image_size: Optional[Tuple[int, int]] = None):
         """
         Args:
             patchified_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
                 Patchified pixel values.
+            original_image_size (`Tuple[int, int]`, *optional*):
+                Original image size.

         Returns:
             `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
                 Pixel values.
         """
         patch_size, num_channels = self.config.patch_size, self.config.num_channels
-        num_patches_one_direction = int(patchified_pixel_values.shape[1] ** 0.5)
+        original_image_size = (
+            original_image_size
+            if original_image_size is not None
+            else (self.config.image_size, self.config.image_size)
+        )
+        original_height, original_width = original_image_size
+        num_patches_h = original_height // patch_size
+        num_patches_w = original_width // patch_size
         # sanity check
-        if num_patches_one_direction**2 != patchified_pixel_values.shape[1]:
-            raise ValueError("Make sure that the number of patches can be squared")
+        if num_patches_h * num_patches_w != patchified_pixel_values.shape[1]:
+            raise ValueError(
+                f"The number of patches in the patchified pixel values, {patchified_pixel_values.shape[1]}, does not match the number of patches of the original image, {num_patches_h}*{num_patches_w}"
+            )

         # unpatchify
         batch_size = patchified_pixel_values.shape[0]
         patchified_pixel_values = patchified_pixel_values.reshape(
             batch_size,
-            num_patches_one_direction,
-            num_patches_one_direction,
+            num_patches_h,
+            num_patches_w,
             patch_size,
             patch_size,
             num_channels,
@@ -954,12 +1061,12 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
         pixel_values = patchified_pixel_values.reshape(
             batch_size,
             num_channels,
-            num_patches_one_direction * patch_size,
-            num_patches_one_direction * patch_size,
+            num_patches_h * patch_size,
+            num_patches_w * patch_size,
         )
         return pixel_values

-    def forward_loss(self, pixel_values, pred, mask):
+    def forward_loss(self, pixel_values, pred, mask, interpolate_pos_encoding: bool = False):
         """
         Args:
             pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
@@ -968,11 +1075,13 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
                 Predicted pixel values.
             mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
                 Tensor indicating which patches are masked (1) and which are not (0).
+            interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+                Interpolation flag passed during the forward pass.

         Returns:
             `torch.FloatTensor`: Pixel reconstruction loss.
         """
-        target = self.patchify(pixel_values)
+        target = self.patchify(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
         if self.config.norm_pix_loss:
             mean = target.mean(dim=-1, keepdim=True)
             var = target.var(dim=-1, keepdim=True)
@@ -980,7 +1089,6 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
         loss = (pred - target) ** 2
         loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
         loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
-
         return loss
@@ -994,6 +1102,7 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
     ) -> Union[Tuple, ViTMAEForPreTrainingOutput]:
         r"""
         Returns:
@@ -1026,16 +1135,17 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            interpolate_pos_encoding=interpolate_pos_encoding,
        )

         latent = outputs.last_hidden_state
         ids_restore = outputs.ids_restore
         mask = outputs.mask

-        decoder_outputs = self.decoder(latent, ids_restore)
+        decoder_outputs = self.decoder(latent, ids_restore, interpolate_pos_encoding=interpolate_pos_encoding)
         logits = decoder_outputs.logits  # shape (batch_size, num_patches, patch_size*patch_size*num_channels)

-        loss = self.forward_loss(pixel_values, logits, mask)
+        loss = self.forward_loss(pixel_values, logits, mask, interpolate_pos_encoding=interpolate_pos_encoding)

         if not return_dict:
             output = (logits, mask, ids_restore) + outputs[2:]
...
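To make the encoder-side logic in this file easier to follow, here is a self-contained sketch of the same idea, assuming a square pre-trained grid; the helper name `interpolate_patch_positions` and the direct use of `size=` instead of the `scale_factor` trick are illustrative, not the PR's exact code. The 1D sequence of patch position embeddings is reshaped back to its 2D grid, resized bicubically to the grid implied by the new resolution, and flattened again.

```python
# Illustrative sketch of the interpolation idea used above (not the PR's exact code).
import math
import torch
import torch.nn as nn

def interpolate_patch_positions(pos_embed: torch.Tensor, new_h: int, new_w: int) -> torch.Tensor:
    """pos_embed: (1, num_positions, dim) patch position embeddings, cls token already stripped."""
    _, num_positions, dim = pos_embed.shape
    grid = int(math.sqrt(num_positions))  # pre-training grid is square, e.g. 224 // 16 = 14
    grid_embed = pos_embed.reshape(1, grid, grid, dim).permute(0, 3, 1, 2)  # (1, dim, grid, grid)
    grid_embed = nn.functional.interpolate(grid_embed, size=(new_h, new_w), mode="bicubic", align_corners=False)
    return grid_embed.permute(0, 2, 3, 1).reshape(1, new_h * new_w, dim)

# A 480x640 input with patch_size=16 needs a 30x40 grid instead of the pre-trained 14x14 one.
pretrained = torch.randn(1, 14 * 14, 768)
print(interpolate_patch_positions(pretrained, new_h=30, new_w=40).shape)  # torch.Size([1, 1200, 768])
```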
@@ -426,7 +426,7 @@ def prepare_img():
 class TFViTMAEModelIntegrationTest(unittest.TestCase):
     @cached_property
     def default_image_processor(self):
-        return ViTImageProcessor.from_pretrained("facebook/vit-mae-base") if is_vision_available() else None
+        return ViTImageProcessor.from_pretrained("facebook/vit-mae-base")

     @slow
     def test_inference_for_pretraining(self):
@@ -457,3 +457,32 @@ class TFViTMAEModelIntegrationTest(unittest.TestCase):
         )
         tf.debugging.assert_near(outputs.logits[0, :3, :3], expected_slice, atol=1e-4)
+
+    @slow
+    def test_inference_interpolate_pos_encoding(self):
+        # ViTMAE models have an `interpolate_pos_encoding` argument in their forward method,
+        # allowing to interpolate the pre-trained position embeddings in order to use
+        # the model on higher resolutions. The DINO model by Facebook AI leverages this
+        # to visualize self-attention on higher resolution images.
+
+        # make random mask reproducible across the PT and TF model
+        np.random.seed(2)
+
+        model = TFViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
+
+        image_processor = self.default_image_processor
+        image = prepare_img()
+        inputs = image_processor(images=image, do_resize=False, return_tensors="tf")
+
+        # prepare a noise vector that will be also used for testing the TF model
+        # (this way we can ensure that the PT and TF models operate on the same inputs)
+        vit_mae_config = ViTMAEConfig()
+        num_patches = (image.height // vit_mae_config.patch_size) * (image.width // vit_mae_config.patch_size)
+        noise = np.random.uniform(size=(1, num_patches))
+
+        # forward pass
+        outputs = model(**inputs, noise=noise, interpolate_pos_encoding=True)
+
+        # verify the logits
+        expected_shape = tf.convert_to_tensor([1, 1200, 768])
+        self.assertEqual(outputs.logits.shape, expected_shape)
@@ -296,7 +296,7 @@ def prepare_img():
 class ViTMAEModelIntegrationTest(unittest.TestCase):
     @cached_property
     def default_image_processor(self):
-        return ViTImageProcessor.from_pretrained("facebook/vit-mae-base") if is_vision_available() else None
+        return ViTImageProcessor.from_pretrained("facebook/vit-mae-base")

     @slow
     def test_inference_for_pretraining(self):
@@ -328,3 +328,35 @@ class ViTMAEModelIntegrationTest(unittest.TestCase):
         )
         self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice.to(torch_device), atol=1e-4))
+
+    @slow
+    def test_inference_interpolate_pos_encoding(self):
+        # ViTMAE models have an `interpolate_pos_encoding` argument in their forward method,
+        # allowing to interpolate the pre-trained position embeddings in order to use
+        # the model on higher resolutions. The DINO model by Facebook AI leverages this
+        # to visualize self-attention on higher resolution images.
+
+        # make random mask reproducible across the PT and TF model
+        np.random.seed(2)
+
+        model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base").to(torch_device)
+
+        image_processor = self.default_image_processor
+        image = prepare_img()
+        inputs = image_processor(images=image, return_tensors="pt", do_resize=False).to(torch_device)
+
+        # prepare a noise vector that will be also used for testing the TF model
+        # (this way we can ensure that the PT and TF models operate on the same inputs)
+        vit_mae_config = ViTMAEConfig()
+        num_patches = (image.height // vit_mae_config.patch_size) * (image.width // vit_mae_config.patch_size)
+        noise = np.random.uniform(size=(1, num_patches))
+
+        # forward pass
+        with torch.no_grad():
+            outputs = model(
+                **inputs, noise=torch.from_numpy(noise).to(device=torch_device), interpolate_pos_encoding=True
+            )
+
+        # verify the logits
+        expected_shape = torch.Size((1, 1200, 768))
+        self.assertEqual(outputs.logits.shape, expected_shape)
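As a quick sanity check on the `(1, 1200, 768)` shape asserted in both new tests: assuming the usual `prepare_img()` COCO sample of 640x480 pixels and the default ViT-MAE config (`patch_size=16`, `num_channels=3`), the numbers work out as follows.

```python
# Assumed inputs: a 640x480 test image and the default ViT-MAE config.
patch_size, num_channels = 16, 3
height, width = 480, 640
num_patches = (height // patch_size) * (width // patch_size)  # 30 * 40 = 1200
patch_dim = patch_size**2 * num_channels                       # 16 * 16 * 3 = 768
print((1, num_patches, patch_dim))  # (1, 1200, 768) -> the asserted logits shape
```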