"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "db5ef3004c539688f15dbd3fb3ee9d8c0b48fe05"
Commit 2398e6a5, authored by Zhichao Lu and committed by pkulzc.
Browse files

Tfmini requires positive values for the dimensions. Therefore, we need to test for no padding in both the width and height directions.

PiperOrigin-RevId: 190520529
parent 4a91d110
@@ -180,26 +180,24 @@ def pad_to_multiple(tensor, multiple):
   padded_tensor_width = int(
       math.ceil(float(tensor_width) / multiple) * multiple)
 
-  if (padded_tensor_height == tensor_height and
-      padded_tensor_width == tensor_width):
-    return tensor
-
   if tensor_depth is None:
     tensor_depth = tf.shape(tensor)[3]
 
   # Use tf.concat instead of tf.pad to preserve static shape
-  height_pad = tf.zeros([
-      batch_size, padded_tensor_height - tensor_height, tensor_width,
-      tensor_depth
-  ])
-  padded_tensor = tf.concat([tensor, height_pad], 1)
-  width_pad = tf.zeros([
-      batch_size, padded_tensor_height, padded_tensor_width - tensor_width,
-      tensor_depth
-  ])
-  padded_tensor = tf.concat([padded_tensor, width_pad], 2)
-
-  return padded_tensor
+  if padded_tensor_height != tensor_height:
+    height_pad = tf.zeros([
+        batch_size, padded_tensor_height - tensor_height, tensor_width,
+        tensor_depth
+    ])
+    tensor = tf.concat([tensor, height_pad], 1)
+  if padded_tensor_width != tensor_width:
+    width_pad = tf.zeros([
+        batch_size, padded_tensor_height, padded_tensor_width - tensor_width,
+        tensor_depth
+    ])
+    tensor = tf.concat([tensor, width_pad], 2)
+
+  return tensor
 
 def padded_one_hot_encoding(indices, depth, left_pad):
@@ -136,6 +136,13 @@ class OpsTestPadToMultiple(tf.test.TestCase):
       padded_tensor_out = sess.run(padded_tensor)
     self.assertEqual((1, 2, 2, 1), padded_tensor_out.shape)
def test_non_square_padding(self):
  """Pads a non-square (1x2 HxW) input to 2x2 with multiple=2."""
  # Input shape is (batch=1, height=1, width=2, depth=1): only the height
  # dimension actually needs padding, the width is already a multiple of 2.
  input_tensor = tf.constant([[[[0.], [0.]]]])
  result = ops.pad_to_multiple(input_tensor, 2)
  with self.test_session() as sess:
    result_out = sess.run(result)
  self.assertEqual((1, 2, 2, 1), result_out.shape)
  def test_padding(self):
    tensor = tf.constant([[[[0.], [0.]], [[0.], [0.]]]])
    padded_tensor = ops.pad_to_multiple(tensor, 4)
    ...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.