Commit 6db23a1f authored by Atze00's avatar Atze00
Browse files

added tests and fixed some wrong behaviour

parent bb3cc770
......@@ -706,7 +706,7 @@ class CausalConvMixin:
self._use_buffered_input = variable
def _compute_buffered_causal_padding(self,
inputs: Optional[tf.Tensor] = None,
inputs: tf.Tensor,
use_buffered_input: bool = False,
time_axis: int = 1) -> List[List[int]]:
"""Calculates padding for 'causal' option for conv layers.
......@@ -720,7 +720,7 @@ class CausalConvMixin:
Returns:
A list of paddings for `tf.pad`.
"""
shape_in = inputs.shape[1:-1]
input_shape = tf.shape(inputs)[1:-1]
del inputs
if tf.keras.backend.image_data_format() == 'channels_first':
......@@ -731,11 +731,10 @@ class CausalConvMixin:
(self.kernel_size[i] - 1) * (self.dilation_rate[i] - 1))
for i in range(self.rank)
]
pad_total = [max(kernel_size_effective[i] - (self.strides[i]), 0)
if (shape_in[i]%self.strides[i]) == 0 else
max(kernel_size_effective[i] -
(shape_in[i]%self.strides[i]), 0)
for i in range(self.rank)]
pad_total = [kernel_size_effective[0] - 1]
for i in range(1, self.rank):
overlap = (input_shape[i] - 1) % self.strides[i] + 1
pad_total.append(tf.maximum(kernel_size_effective[i] - overlap, 0))
pad_beg = [pad_total[i] // 2 for i in range(self.rank)]
pad_end = [pad_total[i] - pad_beg[i] for i in range(self.rank)]
padding = [[pad_beg[i], pad_end[i]] for i in range(self.rank)]
......@@ -768,7 +767,8 @@ class CausalConvMixin:
# across time should be the input shape minus any padding, assuming
# the stride across time is 1.
if self._use_buffered_input and spatial_output_shape[0] is not None:
padding = self._compute_buffered_causal_padding(use_buffered_input=False)
padding = self._compute_buffered_causal_padding(
tf.zeros([1] + spatial_output_shape + [1]), use_buffered_input=False)
spatial_output_shape[0] -= sum(padding[1])
return spatial_output_shape
......
......@@ -320,6 +320,9 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase):
[[12., 12., 12.],
[8., 8., 8.]]]]])
output_shape = conv3d._spatial_output_shape([4, 4, 4])
self.assertAllClose(output_shape, [2, 2, 2])
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected)
......@@ -329,5 +332,74 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase):
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected)
def test_conv3d_causal_padding_2d(self):
"""Test to ensure causal padding works like standard padding."""
conv3d = nn_layers.Conv3D(
filters=1,
kernel_size=(1, 3, 3),
strides=(1, 2, 2),
padding='causal',
use_buffered_input=False,
kernel_initializer='ones',
use_bias=False,
)
keras_conv3d = tf.keras.layers.Conv3D(
filters=1,
kernel_size=(1, 3, 3),
strides=(1, 2, 2),
padding='same',
kernel_initializer='ones',
use_bias=False,
)
inputs = tf.ones([1, 1, 4, 4, 1])
predicted = conv3d(inputs)
expected = keras_conv3d(inputs)
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected)
self.assertAllClose(predicted,
[[[[[9.],
[6.]],
[[6.],
[4.]]]]])
def test_conv3d_causal_padding_1d(self):
"""Test to ensure causal padding works like standard padding."""
conv3d = nn_layers.Conv3D(
filters=1,
kernel_size=(3, 1, 1),
strides=(2, 1, 1),
padding='causal',
use_buffered_input=False,
kernel_initializer='ones',
use_bias=False,
)
keras_conv1d = tf.keras.layers.Conv1D(
filters=1,
kernel_size=3,
strides=2,
padding='causal',
kernel_initializer='ones',
use_bias=False,
)
inputs = tf.ones([1, 4, 1, 1, 1])
predicted = conv3d(inputs)
expected = keras_conv1d(tf.squeeze(inputs, axis=[2, 3]))
expected = tf.reshape(expected, [1, 2, 1, 1, 1])
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected)
self.assertAllClose(predicted,
[[[[[1.]]],
[[[3.]]]]])
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment