Commit 3e23aae9 authored by Yuexin Wu's avatar Yuexin Wu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 393428352
parent 14f82a3d
......@@ -81,7 +81,6 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
self.assertAllEqual(pooled_dtype, pooled.dtype)
@parameterized.named_parameters(
("large_stride_no_unpool", 3, 0),
("no_stride_no_unpool", 1, 0),
("large_stride_with_unpool", 3, 1),
("large_stride_with_large_unpool", 5, 10),
......@@ -115,6 +114,7 @@ class FunnelTransformerEncoderTest(parameterized.TestCase, tf.test.TestCase):
expected_data_shape[1] = unpool_length + (expected_data_shape[1] +
pool_stride - 1 -
unpool_length) // pool_stride
print("shapes:", expected_data_shape, data.shape.as_list())
self.assertAllEqual(expected_data_shape, data.shape.as_list())
self.assertAllEqual(expected_pooled_shape, pooled.shape.as_list())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment