Unverified Commit f71895a6 authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Move logic into pixelshuffle layer (#17899)

* Move all pixelshuffle logic into layer

* Rename layer

* Use correct input to function
parent 0094565f
......@@ -1224,27 +1224,30 @@ class TFSwinModel(TFSwinPreTrainedModel):
return swin_outputs
class TFSwinPixelShuffle(tf.keras.layers.Layer):
    """TF layer implementation of torch.nn.PixelShuffle.

    Rearranges a NHWC tensor of shape (B, H, W, C * r**2) into
    (B, H * r, W * r, C), where r is `upscale_factor`, reproducing
    PyTorch's channel ordering (which differs from tf.nn.depth_to_space
    when the number of output channels is >= 2).
    """

    def __init__(self, upscale_factor: int, **kwargs) -> None:
        """
        Args:
            upscale_factor: integer upscaling factor r, must be >= 2.

        Raises:
            ValueError: if `upscale_factor` is not an int or is < 2.
        """
        super().__init__(**kwargs)
        if not isinstance(upscale_factor, int) or upscale_factor < 2:
            raise ValueError(f"upscale_factor must be an integer value >= 2 got {upscale_factor}")
        self.upscale_factor = upscale_factor

    def call(self, x: tf.Tensor) -> tf.Tensor:
        # assumes x is NHWC with channels divisible by upscale_factor**2 —
        # TODO confirm; shape_list is a project helper returning static/dynamic dims.
        hidden_states = x
        batch_size, _, _, num_input_channels = shape_list(hidden_states)
        block_size_squared = self.upscale_factor**2
        output_depth = int(num_input_channels / block_size_squared)
        # When the number of output channels >= 2, PyTorch's PixelShuffle and
        # TF's depth_to_space differ in their output as the order of channels selected for combining
        # is a permutation of the other c.f.
        # https://stackoverflow.com/questions/68272502/tf-depth-to-space-not-same-as-torchs-pixelshuffle-when-output-channels-1
        permutation = tf.constant(
            [[i + j * block_size_squared for i in range(block_size_squared) for j in range(output_depth)]]
        )
        # Reorder channels to PyTorch's layout, then let depth_to_space do the spatial rearrangement.
        hidden_states = tf.gather(params=hidden_states, indices=tf.tile(permutation, [batch_size, 1]), batch_dims=-1)
        hidden_states = tf.nn.depth_to_space(hidden_states, block_size=self.upscale_factor, data_format="NHWC")
        return hidden_states
class TFSwinDecoder(tf.keras.layers.Layer):
    """Swin MIM decoder head: a 1x1 Conv2D that expands channels, followed by
    pixel-shuffle upsampling by `config.encoder_stride`."""

    def __init__(self, config, **kwargs):
        # NOTE(review): the `def __init__` header is outside the visible diff hunk;
        # signature reconstructed from the visible body — confirm against upstream.
        super().__init__(**kwargs)
        # 1x1 conv producing encoder_stride**2 * num_channels filters, so the
        # subsequent pixel shuffle yields num_channels output channels.
        self.conv2d = tf.keras.layers.Conv2D(
            filters=config.encoder_stride**2 * config.num_channels, kernel_size=1, strides=1, name="0"
        )
        # All PyTorch-compatible channel-permutation logic lives inside the layer now.
        self.pixel_shuffle = TFSwinPixelShuffle(config.encoder_stride, name="1")

    def call(self, x: tf.Tensor) -> tf.Tensor:
        hidden_states = x
        # B,C,H,W -> B,H,W,C (Keras Conv2D expects channels-last)
        hidden_states = tf.transpose(hidden_states, (0, 2, 3, 1))
        hidden_states = self.conv2d(hidden_states)
        hidden_states = self.pixel_shuffle(hidden_states)
        # B,H,W,C -> B,C,H,W (back to the channels-first layout callers expect)
        hidden_states = tf.transpose(hidden_states, (0, 3, 1, 2))
        # NOTE(review): the method's tail is cut off in the diff view; the trailing
        # `return hidden_states` is reconstructed — confirm against upstream.
        return hidden_states
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment