Unverified commit ba71bf4c authored by Aritra Roy Gosthipaty, committed by GitHub

fix: renamed variable name (#18850)

The sequence_masked variable actually holds the part of the sequence that is kept unmasked for the encoder, so this commit renames it to sequence_unmasked.
parent 4824741c
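For context, a minimal PyTorch sketch of the random-masking step the renamed variable comes from (toy shapes, the 0.5 keep ratio, and the prints are illustrative and not part of the original code): the gathered subset is exactly the set of patches the encoder sees, which is why sequence_unmasked is the clearer name.

import torch

# Toy input: 1 sample, 4 patches, embedding dim 2 (values chosen for illustration only).
batch_size, seq_length, dim = 1, 4, 2
sequence = torch.arange(batch_size * seq_length * dim, dtype=torch.float32).reshape(batch_size, seq_length, dim)

# Per-patch random noise decides which patches are kept (small noise = keep).
noise = torch.rand(batch_size, seq_length)
ids_shuffle = torch.argsort(noise, dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)

# Keep the first subset of the shuffled order (here: half the patches, an assumed ratio).
len_keep = seq_length // 2
ids_keep = ids_shuffle[:, :len_keep]

# The gathered patches are the ones that stay visible to the encoder,
# hence the rename from sequence_masked to sequence_unmasked.
sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))

# Binary mask: 0 is keep, 1 is remove, unshuffled back to the original patch order.
mask = torch.ones(batch_size, seq_length)
mask[:, :len_keep] = 0
mask = torch.gather(mask, dim=1, index=ids_restore)

print(sequence_unmasked.shape)  # torch.Size([1, 2, 2]): only the kept (visible) patches
print(mask)                     # 0 where a patch was kept, 1 where it was masked out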
@@ -254,7 +254,7 @@ class TFViTMAEEmbeddings(tf.keras.layers.Layer):
         # keep the first subset
         ids_keep = ids_shuffle[:, :len_keep]
-        sequence_masked = tf.gather(
+        sequence_unmasked = tf.gather(
             sequence,
             axis=1,
             batch_dims=1,
@@ -271,7 +271,7 @@ class TFViTMAEEmbeddings(tf.keras.layers.Layer):
         # unshuffle to get the binary mask
         mask = tf.gather(mask, axis=1, batch_dims=1, indices=ids_restore)
-        return sequence_masked, mask, ids_restore
+        return sequence_unmasked, mask, ids_restore

     def call(self, pixel_values: tf.Tensor, noise: tf.Tensor = None) -> tf.Tensor:
         embeddings = self.patch_embeddings(pixel_values)
@@ -251,7 +251,7 @@ class ViTMAEEmbeddings(nn.Module):
         # keep the first subset
         ids_keep = ids_shuffle[:, :len_keep]
-        sequence_masked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))
+        sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))

         # generate the binary mask: 0 is keep, 1 is remove
         mask = torch.ones([batch_size, seq_length], device=sequence.device)
@@ -259,7 +259,7 @@ class ViTMAEEmbeddings(nn.Module):
         # unshuffle to get the binary mask
         mask = torch.gather(mask, dim=1, index=ids_restore)
-        return sequence_masked, mask, ids_restore
+        return sequence_unmasked, mask, ids_restore

     def forward(self, pixel_values, noise=None):
         batch_size, num_channels, height, width = pixel_values.shape