"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "a4616c6767b02c8c868bab08ccf56d59f861278d"
Unverified Commit bf174f91 authored by Seunghwan Hong's avatar Seunghwan Hong Committed by GitHub
Browse files

Refactor `TFSwinLayer` to increase serving compatibility (#18352)



* Refactor `TFSwinLayer` to increase serving compatibility
Signed-off-by: default avatarSeunghwan Hong <seunghwan@scatterlab.co.kr>

* Fix missed parameters while refactoring
Signed-off-by: default avatarSeunghwan Hong <seunghwan@scatterlab.co.kr>

* Fix window_reverse to calculate batch size
Signed-off-by: default avatarSeunghwan Hong <harrydrippin@gmail.com>
Co-Authored-By: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 575aa6ef
...@@ -226,9 +226,9 @@ def window_reverse(windows: tf.Tensor, window_size: int, height: int, width: int ...@@ -226,9 +226,9 @@ def window_reverse(windows: tf.Tensor, window_size: int, height: int, width: int
""" """
Merges windows to produce higher resolution features. Merges windows to produce higher resolution features.
""" """
x = shape_list(windows)[0] x = tf.shape(windows)[0]
y = tf.cast(height * width / (window_size * window_size), tf.int32) y = tf.cast(height * width / (window_size * window_size), tf.int32)
batch_size = int(x / y) batch_size = tf.math.floordiv(x, y)
windows = tf.reshape( windows = tf.reshape(
windows, (batch_size, height // window_size, width // window_size, window_size, window_size, -1) windows, (batch_size, height // window_size, width // window_size, window_size, window_size, -1)
) )
...@@ -695,16 +695,18 @@ class TFSwinLayer(tf.keras.layers.Layer): ...@@ -695,16 +695,18 @@ class TFSwinLayer(tf.keras.layers.Layer):
img_mask = tf.expand_dims(img_mask, -1) img_mask = tf.expand_dims(img_mask, -1)
img_mask = tf.expand_dims(img_mask, 0) img_mask = tf.expand_dims(img_mask, 0)
mask_windows = window_partition(img_mask, self.window_size) mask_windows = window_partition(img_mask, window_size)
mask_windows = tf.reshape(mask_windows, (-1, self.window_size * self.window_size)) mask_windows = tf.reshape(mask_windows, (-1, window_size * window_size))
attn_mask = tf.expand_dims(mask_windows, 1) - tf.expand_dims(mask_windows, 2) attn_mask = tf.expand_dims(mask_windows, 1) - tf.expand_dims(mask_windows, 2)
attn_mask = tf.where(attn_mask != 0, float(-100.0), attn_mask) attn_mask = tf.where(attn_mask != 0, float(-100.0), attn_mask)
attn_mask = tf.where(attn_mask == 0, float(0.0), attn_mask) attn_mask = tf.where(attn_mask == 0, float(0.0), attn_mask)
return attn_mask return attn_mask
def maybe_pad(self, hidden_states: tf.Tensor, height: int, width: int) -> Tuple[tf.Tensor, tf.Tensor]: def maybe_pad(
pad_right = (self.window_size - width % self.window_size) % self.window_size self, hidden_states: tf.Tensor, window_size: int, height: int, width: int
pad_bottom = (self.window_size - height % self.window_size) % self.window_size ) -> Tuple[tf.Tensor, tf.Tensor]:
pad_right = (window_size - width % window_size) % window_size
pad_bottom = (window_size - height % window_size) % window_size
pad_values = [[0, 0], [0, pad_bottom], [0, pad_right], [0, 0]] pad_values = [[0, 0], [0, pad_bottom], [0, pad_right], [0, 0]]
hidden_states = tf.pad(hidden_states, pad_values) hidden_states = tf.pad(hidden_states, pad_values)
pad_values = tf.reshape(pad_values, (-1,)) pad_values = tf.reshape(pad_values, (-1,))
...@@ -730,7 +732,7 @@ class TFSwinLayer(tf.keras.layers.Layer): ...@@ -730,7 +732,7 @@ class TFSwinLayer(tf.keras.layers.Layer):
hidden_states = self.layernorm_before(hidden_states, training=training) hidden_states = self.layernorm_before(hidden_states, training=training)
hidden_states = tf.reshape(hidden_states, (batch_size, height, width, channels)) hidden_states = tf.reshape(hidden_states, (batch_size, height, width, channels))
# pad hidden_states to multiples of window size # pad hidden_states to multiples of window size
hidden_states, pad_values = self.maybe_pad(hidden_states, height, width) hidden_states, pad_values = self.maybe_pad(hidden_states, window_size, height, width)
_, height_pad, width_pad, _ = shape_list(hidden_states) _, height_pad, width_pad, _ = shape_list(hidden_states)
# cyclic shift # cyclic shift
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment