"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "29baa8fabe15393ec4451beceee6d025881ec992"
Unverified commit b681e12d authored by NielsRogge, committed by GitHub

[ViTMAE] Fix docstrings and variable names (#17710)



* Fix docstrings and variable names

* Rename x to something better

* Improve messages

* Fix docstrings and add test for greyscale images
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 3fab17fc
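
For context on the docstring fixes below: the ViTMAE decoder reconstructs every patch, so its logits carry a sequence (patch) dimension that the old shape annotations omitted. A quick back-of-the-envelope check, as a sketch assuming the default ViT-MAE base configuration (image_size=224, patch_size=16, num_channels=3):

# assumed defaults (e.g. facebook/vit-mae-base): image_size=224, patch_size=16, num_channels=3
image_size, patch_size, num_channels = 224, 16, 3

num_patches = (image_size // patch_size) ** 2  # 196, the sequence_length of the decoder output
patch_dim = patch_size**2 * num_channels       # 768 reconstructed values per patch

# logits are therefore (batch_size, sequence_length, patch_size ** 2 * num_channels),
# e.g. (1, 196, 768), not (batch_size, patch_size ** 2 * num_channels) as previously documented
print(num_patches, patch_dim)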
@@ -84,7 +84,7 @@ class TFViTMAEDecoderOutput(ModelOutput):
     Class for TFViTMAEDecoder's outputs, with potential hidden states and attentions.

     Args:
-        logits (`tf.Tensor` of shape `(batch_size, patch_size ** 2 * num_channels)`):
+        logits (`tf.Tensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
             Pixel reconstruction logits.
         hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
             Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
@@ -109,7 +109,7 @@ class TFViTMAEForPreTrainingOutput(ModelOutput):
     Args:
         loss (`tf.Tensor` of shape `(1,)`):
             Pixel reconstruction loss.
-        logits (`tf.Tensor` of shape `(batch_size, patch_size ** 2 * num_channels)`):
+        logits (`tf.Tensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
             Pixel reconstruction logits.
         mask (`tf.Tensor` of shape `(batch_size, sequence_length)`):
             Tensor indicating which patches are masked (1) and which are not (0).
@@ -969,50 +969,110 @@ class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
     def _prune_heads(self, heads_to_prune):
         raise NotImplementedError

-    def patchify(self, imgs):
-        """
-        imgs: (batch_size, height, width, 3) x: (batch_size, num_patches, patch_size**2 *3)
-        """
-        imgs = tf.cond(
-            tf.math.equal(shape_list(imgs)[1], 3), lambda: tf.transpose(imgs, perm=(0, 2, 3, 1)), lambda: imgs
-        )
-
-        p = self.vit.embeddings.patch_embeddings.patch_size[0]
-        tf.debugging.assert_equal(shape_list(imgs)[1], shape_list(imgs)[2])
-        tf.debugging.assert_equal(shape_list(imgs)[1] % p, 0)
-
-        h = w = shape_list(imgs)[2] // p
-        x = tf.reshape(imgs, (shape_list(imgs)[0], h, p, w, p, 3))
-        x = tf.einsum("nhpwqc->nhwpqc", x)
-        x = tf.reshape(x, (shape_list(imgs)[0], h * w, p**2 * 3))
-        return x
+    def patchify(self, pixel_values):
+        """
+        Args:
+            pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)` or `(batch_size, num_channels, height, width)`):
+                Pixel values.
+
+        Returns:
+            `tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
+                Patchified pixel values.
+        """
+        patch_size, num_channels = self.config.patch_size, self.config.num_channels
+        # make sure channels are last
+        pixel_values = tf.cond(
+            tf.math.equal(shape_list(pixel_values)[1], num_channels),
+            lambda: tf.transpose(pixel_values, perm=(0, 2, 3, 1)),
+            lambda: pixel_values,
+        )
+
+        # sanity checks
+        tf.debugging.assert_equal(
+            shape_list(pixel_values)[1],
+            shape_list(pixel_values)[2],
+            message="Make sure the pixel values have a squared size",
+        )
+        tf.debugging.assert_equal(
+            shape_list(pixel_values)[1] % patch_size,
+            0,
+            message="Make sure the pixel values have a size that is divisible by the patch size",
+        )
+        tf.debugging.assert_equal(
+            shape_list(pixel_values)[3],
+            num_channels,
+            message=(
+                "Make sure the number of channels of the pixel values is equal to the one set in the configuration"
+            ),
+        )
+
+        # patchify
+        batch_size = shape_list(pixel_values)[0]
+        num_patches_one_direction = shape_list(pixel_values)[2] // patch_size
+        patchified_pixel_values = tf.reshape(
+            pixel_values,
+            (batch_size, num_patches_one_direction, patch_size, num_patches_one_direction, patch_size, num_channels),
+        )
+        patchified_pixel_values = tf.einsum("nhpwqc->nhwpqc", patchified_pixel_values)
+        patchified_pixel_values = tf.reshape(
+            patchified_pixel_values,
+            (batch_size, num_patches_one_direction * num_patches_one_direction, patch_size**2 * num_channels),
+        )
+        return patchified_pixel_values

-    def unpatchify(self, x):
-        """
-        x: (batch_size, num_patches, patch_size**2 *3) imgs: (batch_size, height, width, 3)
-        """
-        p = self.vit.embeddings.patch_embeddings.patch_size[0]
-        h = w = int(shape_list(x)[1] ** 0.5)
-        tf.debugging.assert_equal(h * w, shape_list(x)[1])
-
-        x = tf.reshape(x, (shape_list(x)[0], h, w, p, p, 3))
-        x = tf.einsum("nhwpqc->nhpwqc", x)
-        imgs = tf.reshape(x, (shape_list(x)[0], h * p, h * p, 3))
-        return imgs
+    def unpatchify(self, patchified_pixel_values):
+        """
+        Args:
+            patchified_pixel_values (`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`):
+                Patchified pixel values.
+
+        Returns:
+            `tf.Tensor` of shape `(batch_size, height, width, num_channels)`:
+                Pixel values.
+        """
+        patch_size, num_channels = self.config.patch_size, self.config.num_channels
+        num_patches_one_direction = int(shape_list(patchified_pixel_values)[1] ** 0.5)
+        # sanity check
+        tf.debugging.assert_equal(
+            num_patches_one_direction * num_patches_one_direction,
+            shape_list(patchified_pixel_values)[1],
+            message="Make sure that the number of patches can be squared",
+        )
+
+        # unpatchify
+        batch_size = shape_list(patchified_pixel_values)[0]
+        patchified_pixel_values = tf.reshape(
+            patchified_pixel_values,
+            (batch_size, num_patches_one_direction, num_patches_one_direction, patch_size, patch_size, num_channels),
+        )
+        patchified_pixel_values = tf.einsum("nhwpqc->nhpwqc", patchified_pixel_values)
+        pixel_values = tf.reshape(
+            patchified_pixel_values,
+            (batch_size, num_patches_one_direction * patch_size, num_patches_one_direction * patch_size, num_channels),
+        )
+        return pixel_values

-    def forward_loss(self, imgs, pred, mask):
-        """
-        imgs: [batch_size, height, width, 3] pred: [batch_size, num_patches, patch_size**2*3] mask: [N, L], 0 is keep,
-        1 is remove,
-        """
-        target = self.patchify(imgs)
+    def forward_loss(self, pixel_values, pred, mask):
+        """
+        Args:
+            pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)`):
+                Pixel values.
+            pred (`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`):
+                Predicted pixel values.
+            mask (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+                Tensor indicating which patches are masked (1) and which are not (0).
+
+        Returns:
+            `tf.Tensor`: Pixel reconstruction loss.
+        """
+        target = self.patchify(pixel_values)
         if self.config.norm_pix_loss:
             mean = tf.reduce_mean(target, axis=-1, keepdims=True)
             var = tf.math.reduce_variance(target, axis=-1, keepdims=True)
             target = (target - mean) / (var + 1.0e-6) ** 0.5

         loss = (pred - target) ** 2
-        loss = tf.reduce_mean(loss, axis=-1)  # [N, L], mean loss per patch
+        loss = tf.reduce_mean(loss, axis=-1)  # [batch_size, num_patches], mean loss per patch

         loss = tf.reduce_sum(loss * mask) / tf.reduce_sum(mask)  # mean loss on removed patches
         return loss
...
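
For reference, a minimal sketch (not part of this commit) of the round-trip the renamed TF helpers implement; the model is randomly initialized from the default config, so no checkpoint download is needed:

import tensorflow as tf
from transformers import TFViTMAEForPreTraining, ViTMAEConfig

config = ViTMAEConfig()  # defaults: image_size=224, patch_size=16, num_channels=3
model = TFViTMAEForPreTraining(config)  # randomly initialized weights

# channels-last input of shape (batch_size, height, width, num_channels)
pixel_values = tf.random.uniform((1, 224, 224, 3))

# patchify: (1, 224, 224, 3) -> (1, 196, 768), i.e. (batch_size, num_patches, patch_size**2 * num_channels)
patchified_pixel_values = model.patchify(pixel_values)

# unpatchify restores the channels-last image: (1, 196, 768) -> (1, 224, 224, 3)
reconstructed = model.unpatchify(patchified_pixel_values)
print(patchified_pixel_values.shape, reconstructed.shape)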
@@ -86,7 +86,7 @@ class ViTMAEDecoderOutput(ModelOutput):
     Class for ViTMAEDecoder's outputs, with potential hidden states and attentions.

     Args:
-        logits (`torch.FloatTensor` of shape `(batch_size, patch_size ** 2 * num_channels)`):
+        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
             Pixel reconstruction logits.
         hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
             Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
@@ -111,7 +111,7 @@ class ViTMAEForPreTrainingOutput(ModelOutput):
     Args:
         loss (`torch.FloatTensor` of shape `(1,)`):
             Pixel reconstruction loss.
-        logits (`torch.FloatTensor` of shape `(batch_size, patch_size ** 2 * num_channels)`):
+        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
             Pixel reconstruction logits.
         mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
             Tensor indicating which patches are masked (1) and which are not (0).
@@ -868,37 +868,86 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
         for layer, heads in heads_to_prune.items():
             self.encoder.layer[layer].attention.prune_heads(heads)

-    def patchify(self, imgs):
-        """
-        imgs: (N, 3, H, W) x: (N, L, patch_size**2 *3)
-        """
-        p = self.vit.embeddings.patch_embeddings.patch_size[0]
-        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
-
-        h = w = imgs.shape[2] // p
-        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
-        x = torch.einsum("nchpwq->nhwpqc", x)
-        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
-        return x
+    def patchify(self, pixel_values):
+        """
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+                Pixel values.
+
+        Returns:
+            `torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
+                Patchified pixel values.
+        """
+        patch_size, num_channels = self.config.patch_size, self.config.num_channels
+        # sanity checks
+        if (pixel_values.shape[2] != pixel_values.shape[3]) or (pixel_values.shape[2] % patch_size != 0):
+            raise ValueError("Make sure the pixel values have a squared size that is divisible by the patch size")
+        if pixel_values.shape[1] != num_channels:
+            raise ValueError(
+                "Make sure the number of channels of the pixel values is equal to the one set in the configuration"
+            )
+
+        # patchify
+        batch_size = pixel_values.shape[0]
+        num_patches_one_direction = pixel_values.shape[2] // patch_size
+        patchified_pixel_values = pixel_values.reshape(
+            batch_size, num_channels, num_patches_one_direction, patch_size, num_patches_one_direction, patch_size
+        )
+        patchified_pixel_values = torch.einsum("nchpwq->nhwpqc", patchified_pixel_values)
+        patchified_pixel_values = patchified_pixel_values.reshape(
+            batch_size, num_patches_one_direction * num_patches_one_direction, patch_size**2 * num_channels
+        )
+        return patchified_pixel_values

-    def unpatchify(self, x):
-        """
-        x: (N, L, patch_size**2 *3) imgs: (N, 3, H, W)
-        """
-        p = self.vit.embeddings.patch_embeddings.patch_size[0]
-        h = w = int(x.shape[1] ** 0.5)
-        assert h * w == x.shape[1]
-
-        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
-        x = torch.einsum("nhwpqc->nchpwq", x)
-        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
-        return imgs
+    def unpatchify(self, patchified_pixel_values):
+        """
+        Args:
+            patchified_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`):
+                Patchified pixel values.
+
+        Returns:
+            `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
+                Pixel values.
+        """
+        patch_size, num_channels = self.config.patch_size, self.config.num_channels
+        num_patches_one_direction = int(patchified_pixel_values.shape[1] ** 0.5)
+        # sanity check
+        if num_patches_one_direction**2 != patchified_pixel_values.shape[1]:
+            raise ValueError("Make sure that the number of patches can be squared")
+
+        # unpatchify
+        batch_size = patchified_pixel_values.shape[0]
+        patchified_pixel_values = patchified_pixel_values.reshape(
+            batch_size,
+            num_patches_one_direction,
+            num_patches_one_direction,
+            patch_size,
+            patch_size,
+            num_channels,
+        )
+        patchified_pixel_values = torch.einsum("nhwpqc->nchpwq", patchified_pixel_values)
+        pixel_values = patchified_pixel_values.reshape(
+            batch_size,
+            num_channels,
+            num_patches_one_direction * patch_size,
+            num_patches_one_direction * patch_size,
+        )
+        return pixel_values

-    def forward_loss(self, imgs, pred, mask):
-        """
-        imgs: [N, 3, H, W] pred: [N, L, p*p*3] mask: [N, L], 0 is keep, 1 is remove,
-        """
-        target = self.patchify(imgs)
+    def forward_loss(self, pixel_values, pred, mask):
+        """
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+                Pixel values.
+            pred (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`):
+                Predicted pixel values.
+            mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+                Tensor indicating which patches are masked (1) and which are not (0).
+
+        Returns:
+            `torch.FloatTensor`: Pixel reconstruction loss.
+        """
+        target = self.patchify(pixel_values)
         if self.config.norm_pix_loss:
             mean = target.mean(dim=-1, keepdim=True)
             var = target.var(dim=-1, keepdim=True)
@@ -958,8 +1007,8 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
         ids_restore = outputs.ids_restore
         mask = outputs.mask

-        decoder_outputs = self.decoder(latent, ids_restore)  # [N, L, p*p*3]
-        logits = decoder_outputs.logits
+        decoder_outputs = self.decoder(latent, ids_restore)
+        logits = decoder_outputs.logits  # shape (batch_size, num_patches, patch_size*patch_size*num_channels)

         loss = self.forward_loss(pixel_values, logits, mask)
...
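
And a minimal sketch (not part of this commit) of the PyTorch side, illustrating the shapes the updated docstrings describe; it assumes the facebook/vit-mae-base checkpoint is available:

import torch
from transformers import ViTMAEForPreTraining

model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
model.eval()

# channels-first input of shape (batch_size, num_channels, height, width)
pixel_values = torch.rand(1, 3, 224, 224)

with torch.no_grad():
    outputs = model(pixel_values)

# logits: (batch_size, num_patches, patch_size**2 * num_channels) == (1, 196, 768)
# mask:   (batch_size, sequence_length) with 1 for masked patches and 0 for kept ones
print(outputs.logits.shape, outputs.mask.shape)

# the decoder predictions can be mapped back to image space with the renamed unpatchify
reconstruction = model.unpatchify(outputs.logits)
print(reconstruction.shape)  # torch.Size([1, 3, 224, 224])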
@@ -140,6 +140,15 @@ class TFViTMAEModelTester:
         expected_num_channels = self.patch_size**2 * self.num_channels
         self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels))

+        # test greyscale images
+        config.num_channels = 1
+        model = TFViTMAEForPreTraining(config)
+
+        pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
+        result = model(pixel_values, training=False)
+        expected_num_channels = self.patch_size**2
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels))
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (config, pixel_values, labels) = config_and_inputs
...
@@ -137,6 +137,16 @@ class ViTMAEModelTester:
         expected_num_channels = self.patch_size**2 * self.num_channels
         self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels))

+        # test greyscale images
+        config.num_channels = 1
+        model = ViTMAEForPreTraining(config)
+        model.to(torch_device)
+        model.eval()
+        pixel_values = floats_tensor([self.batch_size, 1, self.image_size, self.image_size])
+        result = model(pixel_values)
+        expected_num_channels = self.patch_size**2
+        self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels))
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         config, pixel_values, labels = config_and_inputs
...
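
The new greyscale test above boils down to the following standalone check (a sketch with a randomly initialized model, not part of this commit): with num_channels=1 the last logits dimension shrinks to patch_size**2.

import torch
from transformers import ViTMAEConfig, ViTMAEForPreTraining

config = ViTMAEConfig(num_channels=1, image_size=224, patch_size=16)
model = ViTMAEForPreTraining(config)  # randomly initialized, no pretrained weights
model.eval()

pixel_values = torch.rand(1, 1, 224, 224)  # single-channel (greyscale) input
with torch.no_grad():
    outputs = model(pixel_values)

# (batch_size, num_patches, patch_size**2) == (1, 196, 256)
print(outputs.logits.shape)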