Unverified Commit bf174f91 authored by Seunghwan Hong's avatar Seunghwan Hong Committed by GitHub
Browse files

Refactor `TFSwinLayer` to increase serving compatibility (#18352)



* Refactor `TFSwinLayer` to increase serving compatibility
Signed-off-by: default avatarSeunghwan Hong <seunghwan@scatterlab.co.kr>

* Fix missed parameters while refactoring
Signed-off-by: default avatarSeunghwan Hong <seunghwan@scatterlab.co.kr>

* Fix window_reverse to calculate batch size
Signed-off-by: default avatarSeunghwan Hong <harrydrippin@gmail.com>
Co-Authored-By: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 575aa6ef
......@@ -226,9 +226,9 @@ def window_reverse(windows: tf.Tensor, window_size: int, height: int, width: int
"""
Merges windows to produce higher resolution features.
"""
x = shape_list(windows)[0]
x = tf.shape(windows)[0]
y = tf.cast(height * width / (window_size * window_size), tf.int32)
batch_size = int(x / y)
batch_size = tf.math.floordiv(x, y)
windows = tf.reshape(
windows, (batch_size, height // window_size, width // window_size, window_size, window_size, -1)
)
......@@ -695,16 +695,18 @@ class TFSwinLayer(tf.keras.layers.Layer):
img_mask = tf.expand_dims(img_mask, -1)
img_mask = tf.expand_dims(img_mask, 0)
mask_windows = window_partition(img_mask, self.window_size)
mask_windows = tf.reshape(mask_windows, (-1, self.window_size * self.window_size))
mask_windows = window_partition(img_mask, window_size)
mask_windows = tf.reshape(mask_windows, (-1, window_size * window_size))
attn_mask = tf.expand_dims(mask_windows, 1) - tf.expand_dims(mask_windows, 2)
attn_mask = tf.where(attn_mask != 0, float(-100.0), attn_mask)
attn_mask = tf.where(attn_mask == 0, float(0.0), attn_mask)
return attn_mask
def maybe_pad(self, hidden_states: tf.Tensor, height: int, width: int) -> Tuple[tf.Tensor, tf.Tensor]:
pad_right = (self.window_size - width % self.window_size) % self.window_size
pad_bottom = (self.window_size - height % self.window_size) % self.window_size
def maybe_pad(
self, hidden_states: tf.Tensor, window_size: int, height: int, width: int
) -> Tuple[tf.Tensor, tf.Tensor]:
pad_right = (window_size - width % window_size) % window_size
pad_bottom = (window_size - height % window_size) % window_size
pad_values = [[0, 0], [0, pad_bottom], [0, pad_right], [0, 0]]
hidden_states = tf.pad(hidden_states, pad_values)
pad_values = tf.reshape(pad_values, (-1,))
......@@ -730,7 +732,7 @@ class TFSwinLayer(tf.keras.layers.Layer):
hidden_states = self.layernorm_before(hidden_states, training=training)
hidden_states = tf.reshape(hidden_states, (batch_size, height, width, channels))
# pad hidden_states to multiples of window size
hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
hidden_states, pad_values = self.maybe_pad(hidden_states, window_size, height, width)
_, height_pad, width_pad, _ = shape_list(hidden_states)
# cyclic shift
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment