Commit 6db23a1f authored by Atze00's avatar Atze00
Browse files

added tests and fixed some wrong behaviour

parent bb3cc770
...@@ -706,7 +706,7 @@ class CausalConvMixin: ...@@ -706,7 +706,7 @@ class CausalConvMixin:
self._use_buffered_input = variable self._use_buffered_input = variable
def _compute_buffered_causal_padding(self, def _compute_buffered_causal_padding(self,
inputs: Optional[tf.Tensor] = None, inputs: tf.Tensor,
use_buffered_input: bool = False, use_buffered_input: bool = False,
time_axis: int = 1) -> List[List[int]]: time_axis: int = 1) -> List[List[int]]:
"""Calculates padding for 'causal' option for conv layers. """Calculates padding for 'causal' option for conv layers.
...@@ -720,7 +720,7 @@ class CausalConvMixin: ...@@ -720,7 +720,7 @@ class CausalConvMixin:
Returns: Returns:
A list of paddings for `tf.pad`. A list of paddings for `tf.pad`.
""" """
shape_in = inputs.shape[1:-1] input_shape = tf.shape(inputs)[1:-1]
del inputs del inputs
if tf.keras.backend.image_data_format() == 'channels_first': if tf.keras.backend.image_data_format() == 'channels_first':
...@@ -731,11 +731,10 @@ class CausalConvMixin: ...@@ -731,11 +731,10 @@ class CausalConvMixin:
(self.kernel_size[i] - 1) * (self.dilation_rate[i] - 1)) (self.kernel_size[i] - 1) * (self.dilation_rate[i] - 1))
for i in range(self.rank) for i in range(self.rank)
] ]
pad_total = [max(kernel_size_effective[i] - (self.strides[i]), 0) pad_total = [kernel_size_effective[0] - 1]
if (shape_in[i]%self.strides[i]) == 0 else for i in range(1, self.rank):
max(kernel_size_effective[i] - overlap = (input_shape[i] - 1) % self.strides[i] + 1
(shape_in[i]%self.strides[i]), 0) pad_total.append(tf.maximum(kernel_size_effective[i] - overlap, 0))
for i in range(self.rank)]
pad_beg = [pad_total[i] // 2 for i in range(self.rank)] pad_beg = [pad_total[i] // 2 for i in range(self.rank)]
pad_end = [pad_total[i] - pad_beg[i] for i in range(self.rank)] pad_end = [pad_total[i] - pad_beg[i] for i in range(self.rank)]
padding = [[pad_beg[i], pad_end[i]] for i in range(self.rank)] padding = [[pad_beg[i], pad_end[i]] for i in range(self.rank)]
...@@ -768,7 +767,8 @@ class CausalConvMixin: ...@@ -768,7 +767,8 @@ class CausalConvMixin:
# across time should be the input shape minus any padding, assuming # across time should be the input shape minus any padding, assuming
# the stride across time is 1. # the stride across time is 1.
if self._use_buffered_input and spatial_output_shape[0] is not None: if self._use_buffered_input and spatial_output_shape[0] is not None:
padding = self._compute_buffered_causal_padding(use_buffered_input=False) padding = self._compute_buffered_causal_padding(
tf.zeros([1] + spatial_output_shape + [1]), use_buffered_input=False)
spatial_output_shape[0] -= sum(padding[1]) spatial_output_shape[0] -= sum(padding[1])
return spatial_output_shape return spatial_output_shape
......
...@@ -320,6 +320,9 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase): ...@@ -320,6 +320,9 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase):
[[12., 12., 12.], [[12., 12., 12.],
[8., 8., 8.]]]]]) [8., 8., 8.]]]]])
output_shape = conv3d._spatial_output_shape([4, 4, 4])
self.assertAllClose(output_shape, [2, 2, 2])
self.assertEqual(predicted.shape, expected.shape) self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected) self.assertAllClose(predicted, expected)
...@@ -329,5 +332,74 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase): ...@@ -329,5 +332,74 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase):
self.assertEqual(predicted.shape, expected.shape) self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected) self.assertAllClose(predicted, expected)
def test_conv3d_causal_padding_2d(self):
"""Test to ensure causal padding works like standard padding."""
conv3d = nn_layers.Conv3D(
filters=1,
kernel_size=(1, 3, 3),
strides=(1, 2, 2),
padding='causal',
use_buffered_input=False,
kernel_initializer='ones',
use_bias=False,
)
keras_conv3d = tf.keras.layers.Conv3D(
filters=1,
kernel_size=(1, 3, 3),
strides=(1, 2, 2),
padding='same',
kernel_initializer='ones',
use_bias=False,
)
inputs = tf.ones([1, 1, 4, 4, 1])
predicted = conv3d(inputs)
expected = keras_conv3d(inputs)
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected)
self.assertAllClose(predicted,
[[[[[9.],
[6.]],
[[6.],
[4.]]]]])
def test_conv3d_causal_padding_1d(self):
"""Test to ensure causal padding works like standard padding."""
conv3d = nn_layers.Conv3D(
filters=1,
kernel_size=(3, 1, 1),
strides=(2, 1, 1),
padding='causal',
use_buffered_input=False,
kernel_initializer='ones',
use_bias=False,
)
keras_conv1d = tf.keras.layers.Conv1D(
filters=1,
kernel_size=3,
strides=2,
padding='causal',
kernel_initializer='ones',
use_bias=False,
)
inputs = tf.ones([1, 4, 1, 1, 1])
predicted = conv3d(inputs)
expected = keras_conv1d(tf.squeeze(inputs, axis=[2, 3]))
expected = tf.reshape(expected, [1, 2, 1, 1, 1])
self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected)
self.assertAllClose(predicted,
[[[[[1.]]],
[[[3.]]]]])
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment