Unverified Commit 44f6d511 authored by Srihari Humbarwadi's avatar Srihari Humbarwadi Committed by GitHub
Browse files

Merge branch 'tensorflow:master' into panoptic-deeplab

parents 686a287d 8bc5a1a5
...@@ -101,15 +101,21 @@ def rewrite_nn_resize_op(is_quantized=False): ...@@ -101,15 +101,21 @@ def rewrite_nn_resize_op(is_quantized=False):
'FakeQuantWithMinMaxVars' if is_quantized else '*') 'FakeQuantWithMinMaxVars' if is_quantized else '*')
stack_1_pattern = graph_matcher.OpTypePattern( stack_1_pattern = graph_matcher.OpTypePattern(
'Pack', inputs=[input_pattern, input_pattern], ordered_inputs=False) 'Pack', inputs=[input_pattern, input_pattern], ordered_inputs=False)
reshape_1_pattern = graph_matcher.OpTypePattern(
'Reshape', inputs=[stack_1_pattern, 'Const'], ordered_inputs=False)
stack_2_pattern = graph_matcher.OpTypePattern( stack_2_pattern = graph_matcher.OpTypePattern(
'Pack', inputs=[stack_1_pattern, stack_1_pattern], ordered_inputs=False) 'Pack',
reshape_pattern = graph_matcher.OpTypePattern( inputs=[reshape_1_pattern, reshape_1_pattern],
ordered_inputs=False)
reshape_2_pattern = graph_matcher.OpTypePattern(
'Reshape', inputs=[stack_2_pattern, 'Const'], ordered_inputs=False) 'Reshape', inputs=[stack_2_pattern, 'Const'], ordered_inputs=False)
consumer_pattern1 = graph_matcher.OpTypePattern( consumer_pattern1 = graph_matcher.OpTypePattern(
'Add|AddV2|Max|Mul', inputs=[reshape_pattern, '*'], 'Add|AddV2|Max|Mul',
inputs=[reshape_2_pattern, '*'],
ordered_inputs=False) ordered_inputs=False)
consumer_pattern2 = graph_matcher.OpTypePattern( consumer_pattern2 = graph_matcher.OpTypePattern(
'StridedSlice', inputs=[reshape_pattern, '*', '*', '*'], 'StridedSlice',
inputs=[reshape_2_pattern, '*', '*', '*'],
ordered_inputs=False) ordered_inputs=False)
def replace_matches(consumer_pattern): def replace_matches(consumer_pattern):
...@@ -119,16 +125,17 @@ def rewrite_nn_resize_op(is_quantized=False): ...@@ -119,16 +125,17 @@ def rewrite_nn_resize_op(is_quantized=False):
for match in matcher.match_graph(tf.get_default_graph()): for match in matcher.match_graph(tf.get_default_graph()):
match_counter += 1 match_counter += 1
projection_op = match.get_op(input_pattern) projection_op = match.get_op(input_pattern)
reshape_op = match.get_op(reshape_pattern) reshape_2_op = match.get_op(reshape_2_pattern)
consumer_op = match.get_op(consumer_pattern) consumer_op = match.get_op(consumer_pattern)
nn_resize = tf.image.resize_nearest_neighbor( nn_resize = tf.image.resize_nearest_neighbor(
projection_op.outputs[0], projection_op.outputs[0],
reshape_op.outputs[0].shape.dims[1:3], reshape_2_op.outputs[0].shape.dims[1:3],
align_corners=False, align_corners=False,
name=os.path.split(reshape_op.name)[0] + '/resize_nearest_neighbor') name=os.path.split(reshape_2_op.name)[0] +
'/resize_nearest_neighbor')
for index, op_input in enumerate(consumer_op.inputs): for index, op_input in enumerate(consumer_op.inputs):
if op_input == reshape_op.outputs[0]: if op_input == reshape_2_op.outputs[0]:
consumer_op._update_input(index, nn_resize) # pylint: disable=protected-access consumer_op._update_input(index, nn_resize) # pylint: disable=protected-access
break break
......
...@@ -1168,16 +1168,20 @@ class ExportInferenceGraphTest(tf.test.TestCase): ...@@ -1168,16 +1168,20 @@ class ExportInferenceGraphTest(tf.test.TestCase):
g = tf.Graph() g = tf.Graph()
with g.as_default(): with g.as_default():
with tf.name_scope('nearest_upsampling'): with tf.name_scope('nearest_upsampling'):
x = array_ops.placeholder(dtypes.float32, shape=(8, 10, 10, 8)) x_1 = array_ops.placeholder(dtypes.float32, shape=(8, 10, 10, 8))
x_stack = tf.stack([tf.stack([x] * 2, axis=3)] * 2, axis=2) x_1_stack_1 = tf.stack([x_1] * 2, axis=3)
x_reshape = tf.reshape(x_stack, [8, 20, 20, 8]) x_1_reshape_1 = tf.reshape(x_1_stack_1, [8, 10, 20, 8])
x_1_stack_2 = tf.stack([x_1_reshape_1] * 2, axis=2)
x_1_reshape_2 = tf.reshape(x_1_stack_2, [8, 20, 20, 8])
with tf.name_scope('nearest_upsampling'): with tf.name_scope('nearest_upsampling'):
x_2 = array_ops.placeholder(dtypes.float32, shape=(8, 10, 10, 8)) x_2 = array_ops.placeholder(dtypes.float32, shape=(8, 10, 10, 8))
x_stack_2 = tf.stack([tf.stack([x_2] * 2, axis=3)] * 2, axis=2) x_2_stack_1 = tf.stack([x_2] * 2, axis=3)
x_reshape_2 = tf.reshape(x_stack_2, [8, 20, 20, 8]) x_2_reshape_1 = tf.reshape(x_2_stack_1, [8, 10, 20, 8])
x_2_stack_2 = tf.stack([x_2_reshape_1] * 2, axis=2)
x_2_reshape_2 = tf.reshape(x_2_stack_2, [8, 20, 20, 8])
t = x_reshape + x_reshape_2 t = x_1_reshape_2 + x_2_reshape_2
exporter.rewrite_nn_resize_op() exporter.rewrite_nn_resize_op()
......
...@@ -54,18 +54,20 @@ INPUT_BUILDER_UTIL_MAP = { ...@@ -54,18 +54,20 @@ INPUT_BUILDER_UTIL_MAP = {
} }
def _multiclass_scores_or_one_hot_labels(multiclass_scores, def _multiclass_scores_or_one_hot_labels(multiclass_scores, groundtruth_boxes,
groundtruth_boxes,
groundtruth_classes, num_classes): groundtruth_classes, num_classes):
"""Returns one-hot encoding of classes when multiclass_scores is empty.""" """Returns one-hot encoding of classes when multiclass_scores is empty."""
# Replace groundtruth_classes tensor with multiclass_scores tensor when its # Replace groundtruth_classes tensor with multiclass_scores tensor when its
# non-empty. If multiclass_scores is empty fall back on groundtruth_classes # non-empty. If multiclass_scores is empty fall back on groundtruth_classes
# tensor. # tensor.
def true_fn(): def true_fn():
return tf.reshape(multiclass_scores, return tf.reshape(multiclass_scores,
[tf.shape(groundtruth_boxes)[0], num_classes]) [tf.shape(groundtruth_boxes)[0], num_classes])
def false_fn(): def false_fn():
return tf.one_hot(groundtruth_classes, num_classes) return tf.one_hot(groundtruth_classes, num_classes)
return tf.cond(tf.size(multiclass_scores) > 0, true_fn, false_fn) return tf.cond(tf.size(multiclass_scores) > 0, true_fn, false_fn)
...@@ -132,8 +134,7 @@ def assert_or_prune_invalid_boxes(boxes): ...@@ -132,8 +134,7 @@ def assert_or_prune_invalid_boxes(boxes):
This is not supported on TPUs. This is not supported on TPUs.
""" """
ymin, xmin, ymax, xmax = tf.split( ymin, xmin, ymax, xmax = tf.split(boxes, num_or_size_splits=4, axis=1)
boxes, num_or_size_splits=4, axis=1)
height_check = tf.Assert(tf.reduce_all(ymax >= ymin), [ymin, ymax]) height_check = tf.Assert(tf.reduce_all(ymax >= ymin), [ymin, ymax])
width_check = tf.Assert(tf.reduce_all(xmax >= xmin), [xmin, xmax]) width_check = tf.Assert(tf.reduce_all(xmax >= xmin), [xmin, xmax])
...@@ -157,7 +158,8 @@ def transform_input_data(tensor_dict, ...@@ -157,7 +158,8 @@ def transform_input_data(tensor_dict,
use_multiclass_scores=False, use_multiclass_scores=False,
use_bfloat16=False, use_bfloat16=False,
retain_original_image_additional_channels=False, retain_original_image_additional_channels=False,
keypoint_type_weight=None): keypoint_type_weight=None,
image_classes_field_map_empty_to_ones=True):
"""A single function that is responsible for all input data transformations. """A single function that is responsible for all input data transformations.
Data transformation functions are applied in the following order. Data transformation functions are applied in the following order.
...@@ -206,6 +208,9 @@ def transform_input_data(tensor_dict, ...@@ -206,6 +208,9 @@ def transform_input_data(tensor_dict,
keypoint_type_weight: A list (of length num_keypoints) containing keypoint_type_weight: A list (of length num_keypoints) containing
groundtruth loss weights to use for each keypoint. If None, will use a groundtruth loss weights to use for each keypoint. If None, will use a
weight of 1. weight of 1.
image_classes_field_map_empty_to_ones: A boolean flag indicating if empty
image classes field indicates that all classes have been labeled on this
image [true] or none [false].
Returns: Returns:
A dictionary keyed by fields.InputDataFields containing the tensors obtained A dictionary keyed by fields.InputDataFields containing the tensors obtained
...@@ -229,9 +234,9 @@ def transform_input_data(tensor_dict, ...@@ -229,9 +234,9 @@ def transform_input_data(tensor_dict,
raise KeyError('groundtruth_labeled_classes and groundtruth_image_classes' raise KeyError('groundtruth_labeled_classes and groundtruth_image_classes'
'are provided by the decoder, but only one should be set.') 'are provided by the decoder, but only one should be set.')
for field, map_empty_to_ones in [ for field, map_empty_to_ones in [(labeled_classes_field, True),
(labeled_classes_field, True), (image_classes_field,
(image_classes_field, True), image_classes_field_map_empty_to_ones),
(verified_neg_classes_field, False), (verified_neg_classes_field, False),
(not_exhaustive_field, False)]: (not_exhaustive_field, False)]:
if field in out_tensor_dict: if field in out_tensor_dict:
...@@ -1044,7 +1049,9 @@ def eval_input(eval_config, eval_input_config, model_config, ...@@ -1044,7 +1049,9 @@ def eval_input(eval_config, eval_input_config, model_config,
retain_original_image=eval_config.retain_original_images, retain_original_image=eval_config.retain_original_images,
retain_original_image_additional_channels= retain_original_image_additional_channels=
eval_config.retain_original_image_additional_channels, eval_config.retain_original_image_additional_channels,
keypoint_type_weight=keypoint_type_weight) keypoint_type_weight=keypoint_type_weight,
image_classes_field_map_empty_to_ones=eval_config
.image_classes_field_map_empty_to_ones)
tensor_dict = pad_input_data_to_static_shapes( tensor_dict = pad_input_data_to_static_shapes(
tensor_dict=transform_data_fn(tensor_dict), tensor_dict=transform_data_fn(tensor_dict),
max_num_boxes=eval_input_config.max_number_of_boxes, max_num_boxes=eval_input_config.max_number_of_boxes,
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment