Commit 2398e6a5 authored by Zhichao Lu's avatar Zhichao Lu Committed by pkulzc
Browse files

Tfmini requires positive values for the dimensions. Therefore, we need to test...

Tfmini requires positive values for the dimensions. Therefore, we need to test for no paddings in both width and height direction.

PiperOrigin-RevId: 190520529
parent 4a91d110
...@@ -180,26 +180,24 @@ def pad_to_multiple(tensor, multiple): ...@@ -180,26 +180,24 @@ def pad_to_multiple(tensor, multiple):
padded_tensor_width = int( padded_tensor_width = int(
math.ceil(float(tensor_width) / multiple) * multiple) math.ceil(float(tensor_width) / multiple) * multiple)
if (padded_tensor_height == tensor_height and
padded_tensor_width == tensor_width):
return tensor
if tensor_depth is None: if tensor_depth is None:
tensor_depth = tf.shape(tensor)[3] tensor_depth = tf.shape(tensor)[3]
# Use tf.concat instead of tf.pad to preserve static shape # Use tf.concat instead of tf.pad to preserve static shape
if padded_tensor_height != tensor_height:
height_pad = tf.zeros([ height_pad = tf.zeros([
batch_size, padded_tensor_height - tensor_height, tensor_width, batch_size, padded_tensor_height - tensor_height, tensor_width,
tensor_depth tensor_depth
]) ])
padded_tensor = tf.concat([tensor, height_pad], 1) tensor = tf.concat([tensor, height_pad], 1)
if padded_tensor_width != tensor_width:
width_pad = tf.zeros([ width_pad = tf.zeros([
batch_size, padded_tensor_height, padded_tensor_width - tensor_width, batch_size, padded_tensor_height, padded_tensor_width - tensor_width,
tensor_depth tensor_depth
]) ])
padded_tensor = tf.concat([padded_tensor, width_pad], 2) tensor = tf.concat([tensor, width_pad], 2)
return padded_tensor return tensor
def padded_one_hot_encoding(indices, depth, left_pad): def padded_one_hot_encoding(indices, depth, left_pad):
......
...@@ -136,6 +136,13 @@ class OpsTestPadToMultiple(tf.test.TestCase): ...@@ -136,6 +136,13 @@ class OpsTestPadToMultiple(tf.test.TestCase):
padded_tensor_out = sess.run(padded_tensor) padded_tensor_out = sess.run(padded_tensor)
self.assertEqual((1, 2, 2, 1), padded_tensor_out.shape) self.assertEqual((1, 2, 2, 1), padded_tensor_out.shape)
def test_non_square_padding(self):
tensor = tf.constant([[[[0.], [0.]]]])
padded_tensor = ops.pad_to_multiple(tensor, 2)
with self.test_session() as sess:
padded_tensor_out = sess.run(padded_tensor)
self.assertEqual((1, 2, 2, 1), padded_tensor_out.shape)
def test_padding(self): def test_padding(self):
tensor = tf.constant([[[[0.], [0.]], [[0.], [0.]]]]) tensor = tf.constant([[[[0.], [0.]], [[0.], [0.]]]])
padded_tensor = ops.pad_to_multiple(tensor, 4) padded_tensor = ops.pad_to_multiple(tensor, 4)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment