# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utils for MAE."""

import math

import tensorflow as tf

from official.modeling import tf_utils


# TODO(frederickliu): Move this to vision ops and add tests.
def position_embedding_sine(attention_mask,
                            num_pos_features=256,
                            temperature=10000.,
                            normalize=True,
                            scale=2 * math.pi):
  """Sine-based positional embeddings for 2D images.

  Args:
    attention_mask: a `bool` Tensor of shape [batch_size, height, width]
      specifying the size of the input image to the Transformer and which
      elements are padded.
    num_pos_features: an `int` specifying the number of positional features;
      should be equal to the hidden size of the Transformer network.
    temperature: a `float` specifying the temperature of the positional
      embedding. Any type that is convertible to a `float` is also accepted.
    normalize: a `bool` determining whether the positional embeddings should
      be normalized to [0, scale] before application of the sine and cosine
      functions.
    scale: a `float` specifying the scale of the embeddings before application
      of the embedding function; used only when `normalize` is True.

  Returns:
    embeddings: a `float` tensor of shape [batch_size, height, width,
      num_pos_features] containing the positional embeddings based on sine
      features.

  Raises:
    ValueError: if `num_pos_features` is odd.
  """
  if num_pos_features % 2 != 0:
    raise ValueError(
        "Number of embedding features (num_pos_features) must be even when "
        "column and row embeddings are concatenated.")
  num_pos_features = num_pos_features // 2

  # Produce row and column embeddings based on the total size of the image
  # [batch_size, height, width].
  attention_mask = tf.cast(attention_mask, tf.float32)
  row_embedding = tf.cumsum(attention_mask, 1)
  col_embedding = tf.cumsum(attention_mask, 2)

  if normalize:
    eps = 1e-6
    row_embedding = row_embedding / (row_embedding[:, -1:, :] + eps) * scale
    col_embedding = col_embedding / (col_embedding[:, :, -1:] + eps) * scale

  dim_t = tf.range(num_pos_features, dtype=row_embedding.dtype)
  dim_t = tf.pow(temperature, 2 * (dim_t // 2) / num_pos_features)

  # Creates positional embeddings for each row and column position
  # [batch_size, height, width, num_pos_features].
  pos_row = tf.expand_dims(row_embedding, -1) / dim_t
  pos_col = tf.expand_dims(col_embedding, -1) / dim_t
  pos_row = tf.stack(
      [tf.sin(pos_row[:, :, :, 0::2]),
       tf.cos(pos_row[:, :, :, 1::2])], axis=4)
  pos_col = tf.stack(
      [tf.sin(pos_col[:, :, :, 0::2]),
       tf.cos(pos_col[:, :, :, 1::2])], axis=4)

  # Use get_shape_list so the reshape also works for dynamic (None) shapes.
  final_shape = tf_utils.get_shape_list(pos_row)[:3] + [-1]
  pos_row = tf.reshape(pos_row, final_shape)
  pos_col = tf.reshape(pos_col, final_shape)
  output = tf.concat([pos_row, pos_col], -1)

  embeddings = tf.cast(output, tf.float32)
  return embeddings
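

# A minimal usage sketch, not part of the original module: it builds a
# hypothetical all-ones mask for a single 4x4 image (no padded pixels) and
# checks the output shape. Half of each position's embedding comes from its
# normalized row index and half from its column index, so the result has
# shape [batch_size, height, width, num_pos_features]. Assumes eager
# TensorFlow; run this file directly to execute it.
if __name__ == "__main__":
  demo_mask = tf.ones([1, 4, 4], dtype=tf.bool)  # hypothetical mask
  demo_embeddings = position_embedding_sine(demo_mask, num_pos_features=256)
  # 128 row features concatenated with 128 column features per position.
  assert demo_embeddings.shape == (1, 4, 4, 256)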