Unverified Commit e5103a76 authored by BHUVAN M's avatar BHUVAN M Committed by GitHub
Browse files

added interpolation for vitmae model in pytorch as well as tf. (#30732)



* added interpolation for vitmae model in pytorch as well as tf.

* Update modeling_vit_mae.py

irreugalr import fixed

* small changes and proper formatting

* changes suggested in review.

* modified decoder interpolate_func

* arguments and docstring fix

* Apply suggestions from code review

doc fixes
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent a3cdff41
......@@ -240,6 +240,38 @@ class TFViTMAEEmbeddings(keras.layers.Layer):
with tf.name_scope(self.patch_embeddings.name):
self.patch_embeddings.build(None)
def interpolate_pos_encoding(self, embeddings, height, width) -> tf.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
batch_size, seq_len, dim = shape_list(embeddings)
num_patches = seq_len - 1
_, num_positions, _ = shape_list(self.position_embeddings)
num_positions -= 1
if num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, :1]
patch_pos_embed = self.position_embeddings[:, 1:]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
patch_pos_embed = tf.image.resize(
images=tf.reshape(
patch_pos_embed, shape=(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
),
size=(h0, w0),
method="bicubic",
)
patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim))
return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1)
def random_masking(self, sequence: tf.Tensor, noise: tf.Tensor | None = None):
"""
Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
......@@ -281,17 +313,23 @@ class TFViTMAEEmbeddings(keras.layers.Layer):
return sequence_unmasked, mask, ids_restore
def call(self, pixel_values: tf.Tensor, noise: tf.Tensor = None) -> tf.Tensor:
embeddings = self.patch_embeddings(pixel_values)
def call(
self, pixel_values: tf.Tensor, noise: tf.Tensor = None, interpolate_pos_encoding: bool = False
) -> tf.Tensor:
batch_size, num_channels, height, width = shape_list(pixel_values)
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
if interpolate_pos_encoding:
position_embeddings = self.interpolate_pos_encoding(embeddings, height, width)
else:
position_embeddings = self.position_embeddings
# add position embeddings w/o cls token
embeddings = embeddings + self.position_embeddings[:, 1:, :]
embeddings = embeddings + position_embeddings[:, 1:, :]
# masking: length -> length * config.mask_ratio
embeddings, mask, ids_restore = self.random_masking(embeddings, noise)
# append cls token
cls_token = self.cls_token + self.position_embeddings[:, :1, :]
cls_token = self.cls_token + position_embeddings[:, :1, :]
cls_tokens = tf.tile(cls_token, (shape_list(embeddings)[0], 1, 1))
embeddings = tf.concat([cls_tokens, embeddings], axis=1)
......@@ -329,7 +367,9 @@ class TFViTMAEPatchEmbeddings(keras.layers.Layer):
name="projection",
)
def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
def call(
self, pixel_values: tf.Tensor, training: bool = False, interpolate_pos_encoding: bool = False
) -> tf.Tensor:
batch_size, num_channels, height, width = shape_list(pixel_values)
if tf.executing_eagerly():
if num_channels != self.num_channels:
......@@ -337,7 +377,7 @@ class TFViTMAEPatchEmbeddings(keras.layers.Layer):
"Make sure that the channel dimension of the pixel values match with the one set in the"
" configuration."
)
if height != self.image_size[0] or width != self.image_size[1]:
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model"
f" ({self.image_size[0]}*{self.image_size[1]})."
......@@ -741,9 +781,13 @@ class TFViTMAEMainLayer(keras.layers.Layer):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
interpolate_pos_encoding: bool = False,
) -> Union[TFViTMAEModelOutput, Tuple[tf.Tensor]]:
embedding_output, mask, ids_restore = self.embeddings(
pixel_values=pixel_values, training=training, noise=noise
pixel_values=pixel_values,
training=training,
noise=noise,
interpolate_pos_encoding=interpolate_pos_encoding,
)
# Prepare head mask if needed
......@@ -874,6 +918,9 @@ VIT_MAE_INPUTS_DOCSTRING = r"""
training (`bool`, *optional*, defaults to `False``):
Whether or not to use the model in training mode (some modules like dropout modules have different
behaviors between training and evaluation).
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the position encodings at the encoder and decoder.
"""
......@@ -902,6 +949,7 @@ class TFViTMAEModel(TFViTMAEPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
interpolate_pos_encoding: bool = False,
) -> Union[TFViTMAEModelOutput, Tuple[tf.Tensor]]:
r"""
Returns:
......@@ -931,6 +979,7 @@ class TFViTMAEModel(TFViTMAEPreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
interpolate_pos_encoding=interpolate_pos_encoding,
)
return outputs
......@@ -1004,6 +1053,39 @@ class TFViTMAEDecoder(keras.layers.Layer):
with tf.name_scope(layer.name):
layer.build(None)
def interpolate_pos_encoding(self, embeddings) -> tf.Tensor:
"""
This method is a modified version of the interpolation function for ViT-mae model at the deocder, that
allows to interpolate the pre-trained decoder position encodings, to be able to use the model on higher
resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
# [batch_size, num_patches + 1, hidden_size]
_, num_positions, dim = shape_list(self.decoder_pos_embed)
# -1 removes the class dimension since we later append it without interpolation
seq_len = shape_list(embeddings)[1] - 1
num_positions = num_positions - 1
# Separation of class token and patch tokens
class_pos_embed = self.decoder_pos_embed[:, :1, :]
patch_pos_embed = self.decoder_pos_embed[:, 1:, :]
# interpolate the position embeddings
patch_pos_embed = tf.image.resize(
images=tf.reshape(patch_pos_embed, shape=(1, 1, -1, dim)),
size=(1, seq_len),
method="bicubic",
)
# [1, seq_len, hidden_size]
patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim))
# Adding the class token back
return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1)
def call(
self,
hidden_states,
......@@ -1011,10 +1093,10 @@ class TFViTMAEDecoder(keras.layers.Layer):
output_attentions=False,
output_hidden_states=False,
return_dict=True,
interpolate_pos_encoding=False,
):
# embed tokens
x = self.decoder_embed(hidden_states)
# append mask tokens to sequence
mask_tokens = tf.tile(
self.mask_token,
......@@ -1023,10 +1105,12 @@ class TFViTMAEDecoder(keras.layers.Layer):
x_ = tf.concat([x[:, 1:, :], mask_tokens], axis=1) # no cls token
x_ = tf.gather(x_, axis=1, batch_dims=1, indices=ids_restore) # unshuffle
x = tf.concat([x[:, :1, :], x_], axis=1) # append cls token
if interpolate_pos_encoding:
decoder_pos_embed = self.interpolate_pos_encoding(x)
else:
decoder_pos_embed = self.decoder_pos_embed
# add pos embed
hidden_states = x + self.decoder_pos_embed
hidden_states = x + decoder_pos_embed
# apply Transformer layers (blocks)
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
......@@ -1083,11 +1167,13 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
def _prune_heads(self, heads_to_prune):
raise NotImplementedError
def patchify(self, pixel_values):
def patchify(self, pixel_values, interpolate_pos_encoding: bool = False):
"""
Args:
pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)` or `(batch_size, num_channels, height, width)`):
Pixel values.
interpolate_pos_encoding (`bool`, default `False`):
interpolation flag passed during the forward pass.
Returns:
`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
......@@ -1099,6 +1185,7 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
# sanity checks
if not interpolate_pos_encoding:
tf.debugging.assert_equal(
shape_list(pixel_values)[1],
shape_list(pixel_values)[2],
......@@ -1119,51 +1206,61 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
# patchify
batch_size = shape_list(pixel_values)[0]
num_patches_one_direction = shape_list(pixel_values)[2] // patch_size
num_patches_h = shape_list(pixel_values)[1] // patch_size
num_patches_w = shape_list(pixel_values)[2] // patch_size
patchified_pixel_values = tf.reshape(
pixel_values,
(batch_size, num_patches_one_direction, patch_size, num_patches_one_direction, patch_size, num_channels),
(batch_size, num_patches_h, patch_size, num_patches_w, patch_size, num_channels),
)
patchified_pixel_values = tf.einsum("nhpwqc->nhwpqc", patchified_pixel_values)
patchified_pixel_values = tf.reshape(
patchified_pixel_values,
(batch_size, num_patches_one_direction * num_patches_one_direction, patch_size**2 * num_channels),
(batch_size, num_patches_h * num_patches_w, patch_size**2 * num_channels),
)
return patchified_pixel_values
def unpatchify(self, patchified_pixel_values):
def unpatchify(self, patchified_pixel_values, original_image_size: Optional[Tuple[int, int]] = None):
"""
Args:
patchified_pixel_values (`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
Patchified pixel values.
original_image_size (`Tuple[int, int]`, *optional*):
Original image size.
Returns:
`tf.Tensor` of shape `(batch_size, height, width, num_channels)`:
Pixel values.
"""
patch_size, num_channels = self.config.patch_size, self.config.num_channels
num_patches_one_direction = int(shape_list(patchified_pixel_values)[1] ** 0.5)
original_image_size = (
original_image_size
if original_image_size is not None
else (self.config.image_size, self.config.image_size)
)
original_height, original_width = original_image_size
num_patches_h = original_height // patch_size
num_patches_w = original_width // patch_size
# sanity check
tf.debugging.assert_equal(
num_patches_one_direction * num_patches_one_direction,
num_patches_h * num_patches_w,
shape_list(patchified_pixel_values)[1],
message="Make sure that the number of patches can be squared",
message=f"The number of patches in the patchified pixel values is {shape_list(patchified_pixel_values)[1]} does not match the patches of original image {num_patches_w}*{num_patches_h}",
)
# unpatchify
batch_size = shape_list(patchified_pixel_values)[0]
patchified_pixel_values = tf.reshape(
patchified_pixel_values,
(batch_size, num_patches_one_direction, num_patches_one_direction, patch_size, patch_size, num_channels),
(batch_size, num_patches_h, num_patches_w, patch_size, patch_size, num_channels),
)
patchified_pixel_values = tf.einsum("nhwpqc->nhpwqc", patchified_pixel_values)
pixel_values = tf.reshape(
patchified_pixel_values,
(batch_size, num_patches_one_direction * patch_size, num_patches_one_direction * patch_size, num_channels),
(batch_size, num_patches_h * patch_size, num_patches_w * patch_size, num_channels),
)
return pixel_values
def forward_loss(self, pixel_values, pred, mask):
def forward_loss(self, pixel_values, pred, mask, interpolate_pos_encoding: bool = False):
"""
Args:
pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)`):
......@@ -1172,11 +1269,13 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
Predicted pixel values.
mask (`tf.Tensor` of shape `(batch_size, sequence_length)`):
Tensor indicating which patches are masked (1) and which are not (0).
interpolate_pos_encoding (`bool`, *optional*, default `False`):
interpolation flag passed during the forward pass.
Returns:
`tf.Tensor`: Pixel reconstruction loss.
"""
target = self.patchify(pixel_values)
target = self.patchify(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
if self.config.norm_pix_loss:
mean = tf.reduce_mean(target, axis=-1, keepdims=True)
var = tf.math.reduce_variance(target, axis=-1, keepdims=True)
......@@ -1201,6 +1300,7 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
interpolate_pos_encoding: bool = False,
) -> Union[TFViTMAEForPreTrainingOutput, Tuple[tf.Tensor]]:
r"""
Returns:
......@@ -1234,16 +1334,18 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
interpolate_pos_encoding=interpolate_pos_encoding,
)
latent = outputs.last_hidden_state
ids_restore = outputs.ids_restore
mask = outputs.mask
decoder_outputs = self.decoder(latent, ids_restore) # [batch_size, num_patches, patch_size**2*3]
# [batch_size, num_patches, patch_size**2*3]
decoder_outputs = self.decoder(latent, ids_restore, interpolate_pos_encoding=interpolate_pos_encoding)
logits = decoder_outputs.logits
loss = self.forward_loss(pixel_values, logits, mask)
loss = self.forward_loss(pixel_values, logits, mask, interpolate_pos_encoding=interpolate_pos_encoding)
if not return_dict:
output = (logits, mask, ids_restore) + outputs[2:]
......
......@@ -223,6 +223,41 @@ class ViTMAEEmbeddings(nn.Module):
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
torch.nn.init.normal_(self.cls_token, std=self.config.initializer_range)
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, 0, :]
patch_pos_embed = self.position_embeddings[:, 1:, :]
dim = embeddings.shape[-1]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
if int(h0) != patch_pos_embed.shape[-2] or int(w0) != patch_pos_embed.shape[-1]:
raise ValueError("Width or height does not match with the interpolated position embeddings")
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
def random_masking(self, sequence, noise=None):
"""
Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
......@@ -255,18 +290,22 @@ class ViTMAEEmbeddings(nn.Module):
return sequence_unmasked, mask, ids_restore
def forward(self, pixel_values, noise=None):
def forward(self, pixel_values, noise=None, interpolate_pos_encoding: bool = False):
batch_size, num_channels, height, width = pixel_values.shape
embeddings = self.patch_embeddings(pixel_values)
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
if interpolate_pos_encoding:
position_embeddings = self.interpolate_pos_encoding(embeddings, height, width)
else:
position_embeddings = self.position_embeddings
# add position embeddings w/o cls token
embeddings = embeddings + self.position_embeddings[:, 1:, :]
embeddings = embeddings + position_embeddings[:, 1:, :]
# masking: length -> length * config.mask_ratio
embeddings, mask, ids_restore = self.random_masking(embeddings, noise)
# append cls token
cls_token = self.cls_token + self.position_embeddings[:, :1, :]
cls_token = self.cls_token + position_embeddings[:, :1, :]
cls_tokens = cls_token.expand(embeddings.shape[0], -1, -1)
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
......@@ -294,13 +333,14 @@ class ViTMAEPatchEmbeddings(nn.Module):
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
def forward(self, pixel_values):
def forward(self, pixel_values, interpolate_pos_encoding: bool = False):
batch_size, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
)
if height != self.image_size[0] or width != self.image_size[1]:
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
)
......@@ -657,6 +697,9 @@ VIT_MAE_INPUTS_DOCSTRING = r"""
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
interpolate_pos_encoding (`bool`, *optional*, default `False`):
Whether to interpolate the pre-trained position encodings. This is mainly used to use the model on higher
resolution images.
"""
......@@ -698,6 +741,7 @@ class ViTMAEModel(ViTMAEPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple, ViTMAEModelOutput]:
r"""
Returns:
......@@ -735,7 +779,9 @@ class ViTMAEModel(ViTMAEPreTrainedModel):
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output, mask, ids_restore = self.embeddings(pixel_values, noise=noise)
embedding_output, mask, ids_restore = self.embeddings(
pixel_values, noise=noise, interpolate_pos_encoding=interpolate_pos_encoding
)
encoder_outputs = self.encoder(
embedding_output,
......@@ -785,6 +831,47 @@ class ViTMAEDecoder(nn.Module):
self.config = config
self.initialize_weights(num_patches)
def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor:
"""
This method is a modified version of the interpolation function for ViT-mae model at the deocder, that
allows to interpolate the pre-trained decoder position encodings, to be able to use the model on higher
resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
# -1 removes the class dimension since we later append it without interpolation
embeddings_positions = embeddings.shape[1] - 1
num_positions = self.decoder_pos_embed.shape[1] - 1
# Separation of class token and patch tokens
class_pos_embed = self.decoder_pos_embed[:, 0, :]
patch_pos_embed = self.decoder_pos_embed[:, 1:, :]
# To retain the final 3d tensor with the required dimensions
dim = self.decoder_pos_embed.shape[-1]
# Increasing a dimension to enable bicubic interpolation
patch_pos_embed = patch_pos_embed.reshape(1, 1, -1, dim)
# permute to bring the dimension to be interpolated, to the last
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
# Interpolating the decoder position embeddings shape wrt embeddings shape i.e (x).
# 1 keeps the other dimension constant
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(1, embeddings_positions / num_positions),
mode="bicubic",
align_corners=False,
)
# Converting back to the original shape
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
# Adding the class token back
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
def initialize_weights(self, num_patches):
# initialize (and freeze) position embeddings by sin-cos embedding
decoder_pos_embed = get_2d_sincos_pos_embed(
......@@ -802,6 +889,7 @@ class ViTMAEDecoder(nn.Module):
output_attentions=False,
output_hidden_states=False,
return_dict=True,
interpolate_pos_encoding: bool = False,
):
# embed tokens
x = self.decoder_embed(hidden_states)
......@@ -812,9 +900,12 @@ class ViTMAEDecoder(nn.Module):
# unshuffle
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x_.device))
x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
# add pos embed
hidden_states = x + self.decoder_pos_embed
if interpolate_pos_encoding:
decoder_pos_embed = self.interpolate_pos_encoding(x)
else:
decoder_pos_embed = self.decoder_pos_embed
hidden_states = x + decoder_pos_embed
# apply Transformer layers (blocks)
all_hidden_states = () if output_hidden_states else None
......@@ -893,11 +984,13 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
def patchify(self, pixel_values):
def patchify(self, pixel_values, interpolate_pos_encoding: bool = False):
"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values.
interpolate_pos_encoding (`bool`, *optional*, default `False`):
interpolation flag passed during the forward pass.
Returns:
`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
......@@ -905,7 +998,9 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
"""
patch_size, num_channels = self.config.patch_size, self.config.num_channels
# sanity checks
if (pixel_values.shape[2] != pixel_values.shape[3]) or (pixel_values.shape[2] % patch_size != 0):
if not interpolate_pos_encoding and (
pixel_values.shape[2] != pixel_values.shape[3] or pixel_values.shape[2] % patch_size != 0
):
raise ValueError("Make sure the pixel values have a squared size that is divisible by the patch size")
if pixel_values.shape[1] != num_channels:
raise ValueError(
......@@ -914,38 +1009,50 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
# patchify
batch_size = pixel_values.shape[0]
num_patches_one_direction = pixel_values.shape[2] // patch_size
num_patches_h = pixel_values.shape[2] // patch_size
num_patches_w = pixel_values.shape[3] // patch_size
patchified_pixel_values = pixel_values.reshape(
batch_size, num_channels, num_patches_one_direction, patch_size, num_patches_one_direction, patch_size
batch_size, num_channels, num_patches_h, patch_size, num_patches_w, patch_size
)
patchified_pixel_values = torch.einsum("nchpwq->nhwpqc", patchified_pixel_values)
patchified_pixel_values = patchified_pixel_values.reshape(
batch_size, num_patches_one_direction * num_patches_one_direction, patch_size**2 * num_channels
batch_size, num_patches_h * num_patches_w, patch_size**2 * num_channels
)
return patchified_pixel_values
def unpatchify(self, patchified_pixel_values):
def unpatchify(self, patchified_pixel_values, original_image_size: Optional[Tuple[int, int]] = None):
"""
Args:
patchified_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
Patchified pixel values.
original_image_size (`Tuple[int, int]`, *optional*):
Original image size.
Returns:
`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
Pixel values.
"""
patch_size, num_channels = self.config.patch_size, self.config.num_channels
num_patches_one_direction = int(patchified_pixel_values.shape[1] ** 0.5)
original_image_size = (
original_image_size
if original_image_size is not None
else (self.config.image_size, self.config.image_size)
)
original_height, original_width = original_image_size
num_patches_h = original_height // patch_size
num_patches_w = original_width // patch_size
# sanity check
if num_patches_one_direction**2 != patchified_pixel_values.shape[1]:
raise ValueError("Make sure that the number of patches can be squared")
if num_patches_h * num_patches_w != patchified_pixel_values.shape[1]:
raise ValueError(
f"The number of patches in the patchified pixel values {patchified_pixel_values.shape[1]}, does not match the number of patches on original image {num_patches_h}*{num_patches_w}"
)
# unpatchify
batch_size = patchified_pixel_values.shape[0]
patchified_pixel_values = patchified_pixel_values.reshape(
batch_size,
num_patches_one_direction,
num_patches_one_direction,
num_patches_h,
num_patches_w,
patch_size,
patch_size,
num_channels,
......@@ -954,12 +1061,12 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
pixel_values = patchified_pixel_values.reshape(
batch_size,
num_channels,
num_patches_one_direction * patch_size,
num_patches_one_direction * patch_size,
num_patches_h * patch_size,
num_patches_w * patch_size,
)
return pixel_values
def forward_loss(self, pixel_values, pred, mask):
def forward_loss(self, pixel_values, pred, mask, interpolate_pos_encoding: bool = False):
"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
......@@ -968,11 +1075,13 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
Predicted pixel values.
mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Tensor indicating which patches are masked (1) and which are not (0).
interpolate_pos_encoding (`bool`, *optional*, default `False`):
interpolation flag passed during the forward pass.
Returns:
`torch.FloatTensor`: Pixel reconstruction loss.
"""
target = self.patchify(pixel_values)
target = self.patchify(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
if self.config.norm_pix_loss:
mean = target.mean(dim=-1, keepdim=True)
var = target.var(dim=-1, keepdim=True)
......@@ -980,7 +1089,6 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
loss = (pred - target) ** 2
loss = loss.mean(dim=-1) # [N, L], mean loss per patch
loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
return loss
......@@ -994,6 +1102,7 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple, ViTMAEForPreTrainingOutput]:
r"""
Returns:
......@@ -1026,16 +1135,17 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
interpolate_pos_encoding=interpolate_pos_encoding,
)
latent = outputs.last_hidden_state
ids_restore = outputs.ids_restore
mask = outputs.mask
decoder_outputs = self.decoder(latent, ids_restore)
decoder_outputs = self.decoder(latent, ids_restore, interpolate_pos_encoding=interpolate_pos_encoding)
logits = decoder_outputs.logits # shape (batch_size, num_patches, patch_size*patch_size*num_channels)
loss = self.forward_loss(pixel_values, logits, mask)
loss = self.forward_loss(pixel_values, logits, mask, interpolate_pos_encoding=interpolate_pos_encoding)
if not return_dict:
output = (logits, mask, ids_restore) + outputs[2:]
......
......@@ -426,7 +426,7 @@ def prepare_img():
class TFViTMAEModelIntegrationTest(unittest.TestCase):
@cached_property
def default_image_processor(self):
return ViTImageProcessor.from_pretrained("facebook/vit-mae-base") if is_vision_available() else None
return ViTImageProcessor.from_pretrained("facebook/vit-mae-base")
@slow
def test_inference_for_pretraining(self):
......@@ -457,3 +457,32 @@ class TFViTMAEModelIntegrationTest(unittest.TestCase):
)
tf.debugging.assert_near(outputs.logits[0, :3, :3], expected_slice, atol=1e-4)
@slow
def test_inference_interpolate_pos_encoding(self):
# ViTMAE models have an `interpolate_pos_encoding` argument in their forward method,
# allowing to interpolate the pre-trained position embeddings in order to use
# the model on higher resolutions. The DINO model by Facebook AI leverages this
# to visualize self-attention on higher resolution images.
# make random mask reproducible across the PT and TF model
np.random.seed(2)
model = TFViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
image_processor = self.default_image_processor
image = prepare_img()
inputs = image_processor(images=image, do_resize=False, return_tensors="tf")
# prepare a noise vector that will be also used for testing the TF model
# (this way we can ensure that the PT and TF models operate on the same inputs)
vit_mae_config = ViTMAEConfig()
num_patches = (image.height // vit_mae_config.patch_size) * (image.width // vit_mae_config.patch_size)
noise = np.random.uniform(size=(1, num_patches))
# forward pass
outputs = model(**inputs, noise=noise, interpolate_pos_encoding=True)
# verify the logits
expected_shape = tf.convert_to_tensor([1, 1200, 768])
self.assertEqual(outputs.logits.shape, expected_shape)
......@@ -296,7 +296,7 @@ def prepare_img():
class ViTMAEModelIntegrationTest(unittest.TestCase):
@cached_property
def default_image_processor(self):
return ViTImageProcessor.from_pretrained("facebook/vit-mae-base") if is_vision_available() else None
return ViTImageProcessor.from_pretrained("facebook/vit-mae-base")
@slow
def test_inference_for_pretraining(self):
......@@ -328,3 +328,35 @@ class ViTMAEModelIntegrationTest(unittest.TestCase):
)
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice.to(torch_device), atol=1e-4))
@slow
def test_inference_interpolate_pos_encoding(self):
# ViTMAE models have an `interpolate_pos_encoding` argument in their forward method,
# allowing to interpolate the pre-trained position embeddings in order to use
# the model on higher resolutions. The DINO model by Facebook AI leverages this
# to visualize self-attention on higher resolution images.
# make random mask reproducible across the PT and TF model
np.random.seed(2)
model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base").to(torch_device)
image_processor = self.default_image_processor
image = prepare_img()
inputs = image_processor(images=image, return_tensors="pt", do_resize=False).to(torch_device)
# prepare a noise vector that will be also used for testing the TF model
# (this way we can ensure that the PT and TF models operate on the same inputs)
vit_mae_config = ViTMAEConfig()
num_patches = (image.height // vit_mae_config.patch_size) * (image.width // vit_mae_config.patch_size)
noise = np.random.uniform(size=(1, num_patches))
# forward pass
with torch.no_grad():
outputs = model(
**inputs, noise=torch.from_numpy(noise).to(device=torch_device), interpolate_pos_encoding=True
)
# verify the logits
expected_shape = torch.Size((1, 1200, 768))
self.assertEqual(outputs.logits.shape, expected_shape)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment