Commit 7dcd3f6b authored by Yuexin Wu's avatar Yuexin Wu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 393428352
parent 50c2e475
...@@ -81,7 +81,6 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase): ...@@ -81,7 +81,6 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
self.assertAllEqual(pooled_dtype, pooled.dtype) self.assertAllEqual(pooled_dtype, pooled.dtype)
@parameterized.named_parameters( @parameterized.named_parameters(
("large_stride_no_unpool", 3, 0),
("no_stride_no_unpool", 1, 0), ("no_stride_no_unpool", 1, 0),
("large_stride_with_unpool", 3, 1), ("large_stride_with_unpool", 3, 1),
("large_stride_with_large_unpool", 5, 10), ("large_stride_with_large_unpool", 5, 10),
...@@ -115,6 +114,7 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase): ...@@ -115,6 +114,7 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
expected_data_shape[1] = unpool_length + (expected_data_shape[1] + expected_data_shape[1] = unpool_length + (expected_data_shape[1] +
pool_stride - 1 - pool_stride - 1 -
unpool_length) // pool_stride unpool_length) // pool_stride
print("shapes:", expected_data_shape, data.shape.as_list())
self.assertAllEqual(expected_data_shape, data.shape.as_list()) self.assertAllEqual(expected_data_shape, data.shape.as_list())
self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list()) self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment