Unverified commit 575aa6ef, authored by Seunghwan Hong and committed by GitHub
Browse files

Fix TFSwinSelfAttention to have relative position index as non-trainable weight (#18226)


Signed-off-by: Seunghwan Hong <seunghwan@scatterlab.co.kr>
parent 586dcf6b
...@@ -461,21 +461,6 @@ class TFSwinSelfAttention(tf.keras.layers.Layer): ...@@ -461,21 +461,6 @@ class TFSwinSelfAttention(tf.keras.layers.Layer):
window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size) window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
) )
# get pair-wise relative position index for each token inside the window
coords_h = tf.range(self.window_size[0])
coords_w = tf.range(self.window_size[1])
coords = tf.stack(tf.meshgrid(coords_h, coords_w, indexing="ij"))
coords_flatten = tf.reshape(coords, (shape_list(coords)[0], -1))
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = tf.transpose(relative_coords, (1, 2, 0))
stack_0, stack_1 = tf.unstack(relative_coords, axis=2)
stack_0 += self.window_size[0] - 1
stack_0 *= 2 * self.window_size[1] - 1
stack_1 += self.window_size[1] - 1
relative_coords = tf.stack([stack_0, stack_1], axis=2)
self.relative_position_index = tf.reduce_sum(relative_coords, axis=-1)
self.query = tf.keras.layers.Dense( self.query = tf.keras.layers.Dense(
self.all_head_size, self.all_head_size,
kernel_initializer=get_initializer(config.initializer_range), kernel_initializer=get_initializer(config.initializer_range),
...@@ -503,6 +488,28 @@ class TFSwinSelfAttention(tf.keras.layers.Layer): ...@@ -503,6 +488,28 @@ class TFSwinSelfAttention(tf.keras.layers.Layer):
initializer="zeros", initializer="zeros",
name="relative_position_bias_table", name="relative_position_bias_table",
) )
self.relative_position_index = self.add_weight(
shape=(self.window_size[0] ** 2, self.window_size[1] ** 2),
trainable=False,
dtype=tf.int32,
name="relative_position_index",
)
# get pair-wise relative position index for each token inside the window
coords_h = tf.range(self.window_size[0])
coords_w = tf.range(self.window_size[1])
coords = tf.stack(tf.meshgrid(coords_h, coords_w, indexing="ij"))
coords_flatten = tf.reshape(coords, (shape_list(coords)[0], -1))
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = tf.transpose(relative_coords, (1, 2, 0))
stack_0, stack_1 = tf.unstack(relative_coords, axis=2)
stack_0 += self.window_size[0] - 1
stack_0 *= 2 * self.window_size[1] - 1
stack_1 += self.window_size[1] - 1
relative_coords = tf.stack([stack_0, stack_1], axis=2)
self.relative_position_index.assign(tf.cast(tf.reduce_sum(relative_coords, axis=-1), tf.int32))
super().build(input_shape) super().build(input_shape)
def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor: def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment