"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "1e00ef681d213938cfafd678b9ec11c786405bbf"
Unverified Commit ba71bf4c authored by Aritra Roy Gosthipaty's avatar Aritra Roy Gosthipaty Committed by GitHub
Browse files

fix: renamed variable name (#18850)

The sequence_masked variable is actually the part of the sequence that is kept unmasked for the encoder. This commit renames the variable.
parent 4824741c
...@@ -254,7 +254,7 @@ class TFViTMAEEmbeddings(tf.keras.layers.Layer): ...@@ -254,7 +254,7 @@ class TFViTMAEEmbeddings(tf.keras.layers.Layer):
# keep the first subset # keep the first subset
ids_keep = ids_shuffle[:, :len_keep] ids_keep = ids_shuffle[:, :len_keep]
sequence_masked = tf.gather( sequence_unmasked = tf.gather(
sequence, sequence,
axis=1, axis=1,
batch_dims=1, batch_dims=1,
...@@ -271,7 +271,7 @@ class TFViTMAEEmbeddings(tf.keras.layers.Layer): ...@@ -271,7 +271,7 @@ class TFViTMAEEmbeddings(tf.keras.layers.Layer):
# unshuffle to get the binary mask # unshuffle to get the binary mask
mask = tf.gather(mask, axis=1, batch_dims=1, indices=ids_restore) mask = tf.gather(mask, axis=1, batch_dims=1, indices=ids_restore)
return sequence_masked, mask, ids_restore return sequence_unmasked, mask, ids_restore
def call(self, pixel_values: tf.Tensor, noise: tf.Tensor = None) -> tf.Tensor: def call(self, pixel_values: tf.Tensor, noise: tf.Tensor = None) -> tf.Tensor:
embeddings = self.patch_embeddings(pixel_values) embeddings = self.patch_embeddings(pixel_values)
......
...@@ -251,7 +251,7 @@ class ViTMAEEmbeddings(nn.Module): ...@@ -251,7 +251,7 @@ class ViTMAEEmbeddings(nn.Module):
# keep the first subset # keep the first subset
ids_keep = ids_shuffle[:, :len_keep] ids_keep = ids_shuffle[:, :len_keep]
sequence_masked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim)) sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))
# generate the binary mask: 0 is keep, 1 is remove # generate the binary mask: 0 is keep, 1 is remove
mask = torch.ones([batch_size, seq_length], device=sequence.device) mask = torch.ones([batch_size, seq_length], device=sequence.device)
...@@ -259,7 +259,7 @@ class ViTMAEEmbeddings(nn.Module): ...@@ -259,7 +259,7 @@ class ViTMAEEmbeddings(nn.Module):
# unshuffle to get the binary mask # unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore) mask = torch.gather(mask, dim=1, index=ids_restore)
return sequence_masked, mask, ids_restore return sequence_unmasked, mask, ids_restore
def forward(self, pixel_values, noise=None): def forward(self, pixel_values, noise=None):
batch_size, num_channels, height, width = pixel_values.shape batch_size, num_channels, height, width = pixel_values.shape
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment