Commit fe31beae authored by Zhichao Lu's avatar Zhichao Lu Committed by lzc5123016
Browse files

Ensure that `ops.matmul_gather_on_zeroth_axis` works with dynamic shapes.

PiperOrigin-RevId: 185218732
parent 1efe98bb
......@@ -813,9 +813,10 @@ def matmul_gather_on_zeroth_axis(params, indices, scope=None):
from indices given by indices, with shape indices.shape + params.shape[1:].
"""
with tf.name_scope(scope, 'MatMulGather'):
index_range = params.shape[0]
params2d = tf.reshape(params, [index_range, -1])
indicator_matrix = tf.one_hot(indices, index_range)
params_shape = shape_utils.combined_static_and_dynamic_shape(params)
indices_shape = shape_utils.combined_static_and_dynamic_shape(indices)
params2d = tf.reshape(params, [params_shape[0], -1])
indicator_matrix = tf.one_hot(indices, params_shape[0])
gathered_result_flattened = tf.matmul(indicator_matrix, params2d)
return tf.reshape(gathered_result_flattened,
indices.shape.concatenate(params.shape[1:]))
tf.stack(indices_shape + params_shape[1:]))
......@@ -840,7 +840,7 @@ class OpsTestPositionSensitiveCropRegions(tf.test.TestCase):
# All channels are equal so position-sensitive crop and resize should
# work as the usual crop and resize for just one channel.
crop = tf.image.crop_and_resize(image, boxes, box_ind, crop_size)
crop_and_pool = tf.reduce_mean(crop, [1, 2], keep_dims=True)
crop_and_pool = tf.reduce_mean(crop, [1, 2], keepdims=True)
ps_crop_and_pool = ops.position_sensitive_crop_regions(
tiled_image,
......@@ -866,7 +866,7 @@ class OpsTestPositionSensitiveCropRegions(tf.test.TestCase):
# When a single bin is used, position-sensitive crop and pool should be
# the same as non-position sensitive crop and pool.
crop = tf.image.crop_and_resize(image, boxes, box_ind, crop_size)
crop_and_pool = tf.reduce_mean(crop, [1, 2], keep_dims=True)
crop_and_pool = tf.reduce_mean(crop, [1, 2], keepdims=True)
ps_crop_and_pool = ops.position_sensitive_crop_regions(
image, boxes, box_ind, crop_size, num_spatial_bins, global_pool=True)
......@@ -1054,7 +1054,7 @@ class OpsTestPositionSensitiveCropRegions(tf.test.TestCase):
ps_crop = ops.position_sensitive_crop_regions(
image, boxes, box_ind, crop_size, num_spatial_bins, global_pool=False)
ps_crop_and_pool = tf.reduce_mean(
ps_crop, reduction_indices=(1, 2), keep_dims=True)
ps_crop, reduction_indices=(1, 2), keepdims=True)
with self.test_session() as sess:
output = sess.run(ps_crop_and_pool)
......@@ -1225,6 +1225,21 @@ class MatmulGatherOnZerothAxis(test_case.TestCase):
gather_output = self.execute(graph_fn, [params, indices])
self.assertAllClose(gather_output, expected_output)
def test_gather_with_dynamic_shape_input(self):
params_placeholder = tf.placeholder(tf.float32, shape=[None, 4])
indices_placeholder = tf.placeholder(tf.int32, shape=[None])
gather_result = ops.matmul_gather_on_zeroth_axis(
params_placeholder, indices_placeholder)
params = np.array([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[0, 1, 0, 0]], dtype=np.float32)
indices = np.array([0, 0, 0, 0, 0, 0])
expected_output = np.array(6*[[1, 2, 3, 4]])
with self.test_session() as sess:
gather_output = sess.run(gather_result, feed_dict={
params_placeholder: params, indices_placeholder: indices})
self.assertAllClose(gather_output, expected_output)
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment