Commit bb3cc770 authored by Atze00's avatar Atze00
Browse files

fixed same padding in causal convolutions

parent 5ad16f95
...@@ -720,6 +720,7 @@ class CausalConvMixin: ...@@ -720,6 +720,7 @@ class CausalConvMixin:
Returns: Returns:
A list of paddings for `tf.pad`. A list of paddings for `tf.pad`.
""" """
shape_in = inputs.shape[1:-1]
del inputs del inputs
if tf.keras.backend.image_data_format() == 'channels_first': if tf.keras.backend.image_data_format() == 'channels_first':
...@@ -730,7 +731,11 @@ class CausalConvMixin: ...@@ -730,7 +731,11 @@ class CausalConvMixin:
(self.kernel_size[i] - 1) * (self.dilation_rate[i] - 1)) (self.kernel_size[i] - 1) * (self.dilation_rate[i] - 1))
for i in range(self.rank) for i in range(self.rank)
] ]
pad_total = [kernel_size_effective[i] - 1 for i in range(self.rank)] pad_total = [max(kernel_size_effective[i] - (self.strides[i]), 0)
if (shape_in[i]%self.strides[i]) == 0 else
max(kernel_size_effective[i] -
(shape_in[i]%self.strides[i]), 0)
for i in range(self.rank)]
pad_beg = [pad_total[i] // 2 for i in range(self.rank)] pad_beg = [pad_total[i] // 2 for i in range(self.rank)]
pad_end = [pad_total[i] - pad_beg[i] for i in range(self.rank)] pad_end = [pad_total[i] - pad_beg[i] for i in range(self.rank)]
padding = [[pad_beg[i], pad_end[i]] for i in range(self.rank)] padding = [[pad_beg[i], pad_end[i]] for i in range(self.rank)]
......
...@@ -274,14 +274,14 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase): ...@@ -274,14 +274,14 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase):
predicted = conv3d(padded_inputs) predicted = conv3d(padded_inputs)
expected = tf.constant( expected = tf.constant(
[[[[[12., 12., 12.], [[[[[27., 27., 27.],
[18., 18., 18.]], [18., 18., 18.]],
[[18., 18., 18.], [[18., 18., 18.],
[27., 27., 27.]]], [12., 12., 12.]]],
[[[24., 24., 24.], [[[54., 54., 54.],
[36., 36., 36.]], [36., 36., 36.]],
[[36., 36., 36.], [[36., 36., 36.],
[54., 54., 54.]]]]]) [24., 24., 24.]]]]])
self.assertEqual(predicted.shape, expected.shape) self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected) self.assertAllClose(predicted, expected)
...@@ -311,14 +311,14 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase): ...@@ -311,14 +311,14 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase):
predicted = conv3d(padded_inputs) predicted = conv3d(padded_inputs)
expected = tf.constant( expected = tf.constant(
[[[[[4.0, 4.0, 4.0], [[[[[9.0, 9.0, 9.0],
[6.0, 6.0, 6.0]], [6.0, 6.0, 6.0]],
[[6.0, 6.0, 6.0], [[6.0, 6.0, 6.0],
[9.0, 9.0, 9.0]]], [4.0, 4.0, 4.0]]],
[[[8.0, 8.0, 8.0], [[[18.0, 18.0, 18.0],
[12., 12., 12.]], [12., 12., 12.]],
[[12., 12., 12.], [[12., 12., 12.],
[18., 18., 18.]]]]]) [8., 8., 8.]]]]])
self.assertEqual(predicted.shape, expected.shape) self.assertEqual(predicted.shape, expected.shape)
self.assertAllClose(predicted, expected) self.assertAllClose(predicted, expected)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment