Unverified Commit 44f6d511 authored by Srihari Humbarwadi's avatar Srihari Humbarwadi Committed by GitHub
Browse files

Merge branch 'tensorflow:master' into panoptic-deeplab

parents 686a287d 8bc5a1a5
......@@ -101,15 +101,21 @@ def rewrite_nn_resize_op(is_quantized=False):
'FakeQuantWithMinMaxVars' if is_quantized else '*')
stack_1_pattern = graph_matcher.OpTypePattern(
'Pack', inputs=[input_pattern, input_pattern], ordered_inputs=False)
reshape_1_pattern = graph_matcher.OpTypePattern(
'Reshape', inputs=[stack_1_pattern, 'Const'], ordered_inputs=False)
stack_2_pattern = graph_matcher.OpTypePattern(
'Pack', inputs=[stack_1_pattern, stack_1_pattern], ordered_inputs=False)
reshape_pattern = graph_matcher.OpTypePattern(
'Pack',
inputs=[reshape_1_pattern, reshape_1_pattern],
ordered_inputs=False)
reshape_2_pattern = graph_matcher.OpTypePattern(
'Reshape', inputs=[stack_2_pattern, 'Const'], ordered_inputs=False)
consumer_pattern1 = graph_matcher.OpTypePattern(
'Add|AddV2|Max|Mul', inputs=[reshape_pattern, '*'],
'Add|AddV2|Max|Mul',
inputs=[reshape_2_pattern, '*'],
ordered_inputs=False)
consumer_pattern2 = graph_matcher.OpTypePattern(
'StridedSlice', inputs=[reshape_pattern, '*', '*', '*'],
'StridedSlice',
inputs=[reshape_2_pattern, '*', '*', '*'],
ordered_inputs=False)
def replace_matches(consumer_pattern):
......@@ -119,16 +125,17 @@ def rewrite_nn_resize_op(is_quantized=False):
for match in matcher.match_graph(tf.get_default_graph()):
match_counter += 1
projection_op = match.get_op(input_pattern)
reshape_op = match.get_op(reshape_pattern)
reshape_2_op = match.get_op(reshape_2_pattern)
consumer_op = match.get_op(consumer_pattern)
nn_resize = tf.image.resize_nearest_neighbor(
projection_op.outputs[0],
reshape_op.outputs[0].shape.dims[1:3],
reshape_2_op.outputs[0].shape.dims[1:3],
align_corners=False,
name=os.path.split(reshape_op.name)[0] + '/resize_nearest_neighbor')
name=os.path.split(reshape_2_op.name)[0] +
'/resize_nearest_neighbor')
for index, op_input in enumerate(consumer_op.inputs):
if op_input == reshape_op.outputs[0]:
if op_input == reshape_2_op.outputs[0]:
consumer_op._update_input(index, nn_resize) # pylint: disable=protected-access
break
......
......@@ -1168,16 +1168,20 @@ class ExportInferenceGraphTest(tf.test.TestCase):
g = tf.Graph()
with g.as_default():
with tf.name_scope('nearest_upsampling'):
x = array_ops.placeholder(dtypes.float32, shape=(8, 10, 10, 8))
x_stack = tf.stack([tf.stack([x] * 2, axis=3)] * 2, axis=2)
x_reshape = tf.reshape(x_stack, [8, 20, 20, 8])
x_1 = array_ops.placeholder(dtypes.float32, shape=(8, 10, 10, 8))
x_1_stack_1 = tf.stack([x_1] * 2, axis=3)
x_1_reshape_1 = tf.reshape(x_1_stack_1, [8, 10, 20, 8])
x_1_stack_2 = tf.stack([x_1_reshape_1] * 2, axis=2)
x_1_reshape_2 = tf.reshape(x_1_stack_2, [8, 20, 20, 8])
with tf.name_scope('nearest_upsampling'):
x_2 = array_ops.placeholder(dtypes.float32, shape=(8, 10, 10, 8))
x_stack_2 = tf.stack([tf.stack([x_2] * 2, axis=3)] * 2, axis=2)
x_reshape_2 = tf.reshape(x_stack_2, [8, 20, 20, 8])
x_2_stack_1 = tf.stack([x_2] * 2, axis=3)
x_2_reshape_1 = tf.reshape(x_2_stack_1, [8, 10, 20, 8])
x_2_stack_2 = tf.stack([x_2_reshape_1] * 2, axis=2)
x_2_reshape_2 = tf.reshape(x_2_stack_2, [8, 20, 20, 8])
t = x_reshape + x_reshape_2
t = x_1_reshape_2 + x_2_reshape_2
exporter.rewrite_nn_resize_op()
......
......@@ -54,18 +54,20 @@ INPUT_BUILDER_UTIL_MAP = {
}
def _multiclass_scores_or_one_hot_labels(multiclass_scores,
groundtruth_boxes,
def _multiclass_scores_or_one_hot_labels(multiclass_scores, groundtruth_boxes,
groundtruth_classes, num_classes):
"""Returns one-hot encoding of classes when multiclass_scores is empty."""
# Replace groundtruth_classes tensor with multiclass_scores tensor when its
# non-empty. If multiclass_scores is empty fall back on groundtruth_classes
# tensor.
def true_fn():
return tf.reshape(multiclass_scores,
[tf.shape(groundtruth_boxes)[0], num_classes])
def false_fn():
return tf.one_hot(groundtruth_classes, num_classes)
return tf.cond(tf.size(multiclass_scores) > 0, true_fn, false_fn)
......@@ -132,8 +134,7 @@ def assert_or_prune_invalid_boxes(boxes):
This is not supported on TPUs.
"""
ymin, xmin, ymax, xmax = tf.split(
boxes, num_or_size_splits=4, axis=1)
ymin, xmin, ymax, xmax = tf.split(boxes, num_or_size_splits=4, axis=1)
height_check = tf.Assert(tf.reduce_all(ymax >= ymin), [ymin, ymax])
width_check = tf.Assert(tf.reduce_all(xmax >= xmin), [xmin, xmax])
......@@ -157,7 +158,8 @@ def transform_input_data(tensor_dict,
use_multiclass_scores=False,
use_bfloat16=False,
retain_original_image_additional_channels=False,
keypoint_type_weight=None):
keypoint_type_weight=None,
image_classes_field_map_empty_to_ones=True):
"""A single function that is responsible for all input data transformations.
Data transformation functions are applied in the following order.
......@@ -206,6 +208,9 @@ def transform_input_data(tensor_dict,
keypoint_type_weight: A list (of length num_keypoints) containing
groundtruth loss weights to use for each keypoint. If None, will use a
weight of 1.
image_classes_field_map_empty_to_ones: A boolean flag indicating if empty
image classes field indicates that all classes have been labeled on this
image [true] or none [false].
Returns:
A dictionary keyed by fields.InputDataFields containing the tensors obtained
......@@ -229,9 +234,9 @@ def transform_input_data(tensor_dict,
raise KeyError('groundtruth_labeled_classes and groundtruth_image_classes'
'are provided by the decoder, but only one should be set.')
for field, map_empty_to_ones in [
(labeled_classes_field, True),
(image_classes_field, True),
for field, map_empty_to_ones in [(labeled_classes_field, True),
(image_classes_field,
image_classes_field_map_empty_to_ones),
(verified_neg_classes_field, False),
(not_exhaustive_field, False)]:
if field in out_tensor_dict:
......@@ -1044,7 +1049,9 @@ def eval_input(eval_config, eval_input_config, model_config,
retain_original_image=eval_config.retain_original_images,
retain_original_image_additional_channels=
eval_config.retain_original_image_additional_channels,
keypoint_type_weight=keypoint_type_weight)
keypoint_type_weight=keypoint_type_weight,
image_classes_field_map_empty_to_ones=eval_config
.image_classes_field_map_empty_to_ones)
tensor_dict = pad_input_data_to_static_shapes(
tensor_dict=transform_data_fn(tensor_dict),
max_num_boxes=eval_input_config.max_number_of_boxes,
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment