Commit 6a98d040 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower Committed by TF Object Detection Team
Browse files

Avoids generating 6D tensors in `nearest_neighbor_upsampling`.

It creates unnecessary complexity for backend compilers to deal with.

PiperOrigin-RevId: 443509141
parent 2aee5734
......@@ -101,15 +101,21 @@ def rewrite_nn_resize_op(is_quantized=False):
'FakeQuantWithMinMaxVars' if is_quantized else '*')
stack_1_pattern = graph_matcher.OpTypePattern(
'Pack', inputs=[input_pattern, input_pattern], ordered_inputs=False)
reshape_1_pattern = graph_matcher.OpTypePattern(
'Reshape', inputs=[stack_1_pattern, 'Const'], ordered_inputs=False)
stack_2_pattern = graph_matcher.OpTypePattern(
'Pack', inputs=[stack_1_pattern, stack_1_pattern], ordered_inputs=False)
reshape_pattern = graph_matcher.OpTypePattern(
'Pack',
inputs=[reshape_1_pattern, reshape_1_pattern],
ordered_inputs=False)
reshape_2_pattern = graph_matcher.OpTypePattern(
'Reshape', inputs=[stack_2_pattern, 'Const'], ordered_inputs=False)
consumer_pattern1 = graph_matcher.OpTypePattern(
'Add|AddV2|Max|Mul', inputs=[reshape_pattern, '*'],
'Add|AddV2|Max|Mul',
inputs=[reshape_2_pattern, '*'],
ordered_inputs=False)
consumer_pattern2 = graph_matcher.OpTypePattern(
'StridedSlice', inputs=[reshape_pattern, '*', '*', '*'],
'StridedSlice',
inputs=[reshape_2_pattern, '*', '*', '*'],
ordered_inputs=False)
def replace_matches(consumer_pattern):
......@@ -119,16 +125,17 @@ def rewrite_nn_resize_op(is_quantized=False):
for match in matcher.match_graph(tf.get_default_graph()):
match_counter += 1
projection_op = match.get_op(input_pattern)
reshape_op = match.get_op(reshape_pattern)
reshape_2_op = match.get_op(reshape_2_pattern)
consumer_op = match.get_op(consumer_pattern)
nn_resize = tf.image.resize_nearest_neighbor(
projection_op.outputs[0],
reshape_op.outputs[0].shape.dims[1:3],
reshape_2_op.outputs[0].shape.dims[1:3],
align_corners=False,
name=os.path.split(reshape_op.name)[0] + '/resize_nearest_neighbor')
name=os.path.split(reshape_2_op.name)[0] +
'/resize_nearest_neighbor')
for index, op_input in enumerate(consumer_op.inputs):
if op_input == reshape_op.outputs[0]:
if op_input == reshape_2_op.outputs[0]:
consumer_op._update_input(index, nn_resize) # pylint: disable=protected-access
break
......
......@@ -1168,16 +1168,20 @@ class ExportInferenceGraphTest(tf.test.TestCase):
g = tf.Graph()
with g.as_default():
with tf.name_scope('nearest_upsampling'):
x = array_ops.placeholder(dtypes.float32, shape=(8, 10, 10, 8))
x_stack = tf.stack([tf.stack([x] * 2, axis=3)] * 2, axis=2)
x_reshape = tf.reshape(x_stack, [8, 20, 20, 8])
x_1 = array_ops.placeholder(dtypes.float32, shape=(8, 10, 10, 8))
x_1_stack_1 = tf.stack([x_1] * 2, axis=3)
x_1_reshape_1 = tf.reshape(x_1_stack_1, [8, 10, 20, 8])
x_1_stack_2 = tf.stack([x_1_reshape_1] * 2, axis=2)
x_1_reshape_2 = tf.reshape(x_1_stack_2, [8, 20, 20, 8])
with tf.name_scope('nearest_upsampling'):
x_2 = array_ops.placeholder(dtypes.float32, shape=(8, 10, 10, 8))
x_stack_2 = tf.stack([tf.stack([x_2] * 2, axis=3)] * 2, axis=2)
x_reshape_2 = tf.reshape(x_stack_2, [8, 20, 20, 8])
x_2_stack_1 = tf.stack([x_2] * 2, axis=3)
x_2_reshape_1 = tf.reshape(x_2_stack_1, [8, 10, 20, 8])
x_2_stack_2 = tf.stack([x_2_reshape_1] * 2, axis=2)
x_2_reshape_2 = tf.reshape(x_2_stack_2, [8, 20, 20, 8])
t = x_reshape + x_reshape_2
t = x_1_reshape_2 + x_2_reshape_2
exporter.rewrite_nn_resize_op()
......
......@@ -998,6 +998,10 @@ def nearest_neighbor_upsampling(input_tensor, scale=None, height_scale=None,
(batch_size, height, width,
channels) = shape_utils.combined_static_and_dynamic_shape(input_tensor)
output_tensor = tf.stack([input_tensor] * w_scale, axis=3, name='w_stack')
# Adds a reshape op to avoid generating high-dimensional tensors that some
# compilers cannot deal with.
output_tensor = tf.reshape(output_tensor,
[batch_size, height, width * w_scale, channels])
output_tensor = tf.stack([output_tensor] * h_scale, axis=2, name='h_stack')
return tf.reshape(output_tensor,
[batch_size, height * h_scale, width * w_scale, channels])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment