Commit 9b012b52 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Merge pull request #10081 from Atze00:padding_causal_convolution

PiperOrigin-RevId: 382547358
parents 4f655acd c703ef58
...@@ -723,7 +723,7 @@ class CausalConvMixin: ...@@ -723,7 +723,7 @@ class CausalConvMixin:
self._use_buffered_input = variable self._use_buffered_input = variable
def _compute_buffered_causal_padding(self, def _compute_buffered_causal_padding(self,
inputs: Optional[tf.Tensor] = None, inputs: tf.Tensor,
use_buffered_input: bool = False, use_buffered_input: bool = False,
time_axis: int = 1) -> List[List[int]]: time_axis: int = 1) -> List[List[int]]:
"""Calculates padding for 'causal' option for conv layers. """Calculates padding for 'causal' option for conv layers.
...@@ -737,7 +737,7 @@ class CausalConvMixin: ...@@ -737,7 +737,7 @@ class CausalConvMixin:
Returns: Returns:
A list of paddings for `tf.pad`. A list of paddings for `tf.pad`.
""" """
del inputs input_shape = tf.shape(inputs)[1:-1]
if tf.keras.backend.image_data_format() == 'channels_first': if tf.keras.backend.image_data_format() == 'channels_first':
raise ValueError('"channels_first" mode is unsupported.') raise ValueError('"channels_first" mode is unsupported.')
...@@ -747,7 +747,10 @@ class CausalConvMixin: ...@@ -747,7 +747,10 @@ class CausalConvMixin:
(self.kernel_size[i] - 1) * (self.dilation_rate[i] - 1)) (self.kernel_size[i] - 1) * (self.dilation_rate[i] - 1))
for i in range(self.rank) for i in range(self.rank)
] ]
pad_total = [kernel_size_effective[i] - 1 for i in range(self.rank)] pad_total = [kernel_size_effective[0] - 1]
for i in range(1, self.rank):
overlap = (input_shape[i] - 1) % self.strides[i] + 1
pad_total.append(tf.maximum(kernel_size_effective[i] - overlap, 0))
pad_beg = [pad_total[i] // 2 for i in range(self.rank)] pad_beg = [pad_total[i] // 2 for i in range(self.rank)]
pad_end = [pad_total[i] - pad_beg[i] for i in range(self.rank)] pad_end = [pad_total[i] - pad_beg[i] for i in range(self.rank)]
padding = [[pad_beg[i], pad_end[i]] for i in range(self.rank)] padding = [[pad_beg[i], pad_end[i]] for i in range(self.rank)]
...@@ -780,7 +783,8 @@ class CausalConvMixin: ...@@ -780,7 +783,8 @@ class CausalConvMixin:
# across time should be the input shape minus any padding, assuming # across time should be the input shape minus any padding, assuming
# the stride across time is 1. # the stride across time is 1.
if self._use_buffered_input and spatial_output_shape[0] is not None: if self._use_buffered_input and spatial_output_shape[0] is not None:
padding = self._compute_buffered_causal_padding(use_buffered_input=False) padding = self._compute_buffered_causal_padding(
tf.zeros([1] + spatial_output_shape + [1]), use_buffered_input=False)
spatial_output_shape[0] -= sum(padding[1]) spatial_output_shape[0] -= sum(padding[1])
return spatial_output_shape return spatial_output_shape
......
...@@ -279,14 +279,14 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase): ...@@ -279,14 +279,14 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase):
predicted = conv3d(padded_inputs) predicted = conv3d(padded_inputs)
expected = tf.constant( expected = tf.constant(
[[[[[12., 12., 12.], [[[[[27., 27., 27.],
[18., 18., 18.]], [18., 18., 18.]],
[[18., 18., 18.], [[18., 18., 18.],
[27., 27., 27.]]], [12., 12., 12.]]],
[[[24., 24., 24.], [[[54., 54., 54.],
[36., 36., 36.]], [36., 36., 36.]],
[[36., 36., 36.], [[36., 36., 36.],
[54., 54., 54.]]]]]) [24., 24., 24.]]]]])
self.assertEqual(predicted.shape, expected.shape) self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected) self.assertAllClose(predicted, expected)
...@@ -316,14 +316,17 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase): ...@@ -316,14 +316,17 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase):
predicted = conv3d(padded_inputs) predicted = conv3d(padded_inputs)
expected = tf.constant( expected = tf.constant(
[[[[[4.0, 4.0, 4.0], [[[[[9.0, 9.0, 9.0],
[6.0, 6.0, 6.0]], [6.0, 6.0, 6.0]],
[[6.0, 6.0, 6.0], [[6.0, 6.0, 6.0],
[9.0, 9.0, 9.0]]], [4.0, 4.0, 4.0]]],
[[[8.0, 8.0, 8.0], [[[18.0, 18.0, 18.0],
[12., 12., 12.]], [12., 12., 12.]],
[[12., 12., 12.], [[12., 12., 12.],
[18., 18., 18.]]]]]) [8., 8., 8.]]]]])
output_shape = conv3d._spatial_output_shape([4, 4, 4])
self.assertAllClose(output_shape, [2, 2, 2])
self.assertEqual(predicted.shape, expected.shape) self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected) self.assertAllClose(predicted, expected)
...@@ -334,5 +337,74 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase): ...@@ -334,5 +337,74 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase):
self.assertEqual(predicted.shape, expected.shape) self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected) self.assertAllClose(predicted, expected)
def test_conv3d_causal_padding_2d(self):
"""Test to ensure causal padding works like standard padding."""
conv3d = nn_layers.Conv3D(
filters=1,
kernel_size=(1, 3, 3),
strides=(1, 2, 2),
padding='causal',
use_buffered_input=False,
kernel_initializer='ones',
use_bias=False,
)
keras_conv3d = tf.keras.layers.Conv3D(
filters=1,
kernel_size=(1, 3, 3),
strides=(1, 2, 2),
padding='same',
kernel_initializer='ones',
use_bias=False,
)
inputs = tf.ones([1, 1, 4, 4, 1])
predicted = conv3d(inputs)
expected = keras_conv3d(inputs)
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected)
self.assertAllClose(predicted,
[[[[[9.],
[6.]],
[[6.],
[4.]]]]])
def test_conv3d_causal_padding_1d(self):
"""Test to ensure causal padding works like standard padding."""
conv3d = nn_layers.Conv3D(
filters=1,
kernel_size=(3, 1, 1),
strides=(2, 1, 1),
padding='causal',
use_buffered_input=False,
kernel_initializer='ones',
use_bias=False,
)
keras_conv1d = tf.keras.layers.Conv1D(
filters=1,
kernel_size=3,
strides=2,
padding='causal',
kernel_initializer='ones',
use_bias=False,
)
inputs = tf.ones([1, 4, 1, 1, 1])
predicted = conv3d(inputs)
expected = keras_conv1d(tf.squeeze(inputs, axis=[2, 3]))
expected = tf.reshape(expected, [1, 2, 1, 1, 1])
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected)
self.assertAllClose(predicted,
[[[[[1.]]],
[[[3.]]]]])
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment