Unverified Commit 44f6d511 authored by Srihari Humbarwadi's avatar Srihari Humbarwadi Committed by GitHub
Browse files

Merge branch 'tensorflow:master' into panoptic-deeplab

parents 686a287d 8bc5a1a5
...@@ -101,15 +101,21 @@ def rewrite_nn_resize_op(is_quantized=False): ...@@ -101,15 +101,21 @@ def rewrite_nn_resize_op(is_quantized=False):
'FakeQuantWithMinMaxVars' if is_quantized else '*') 'FakeQuantWithMinMaxVars' if is_quantized else '*')
stack_1_pattern = graph_matcher.OpTypePattern( stack_1_pattern = graph_matcher.OpTypePattern(
'Pack', inputs=[input_pattern, input_pattern], ordered_inputs=False) 'Pack', inputs=[input_pattern, input_pattern], ordered_inputs=False)
reshape_1_pattern = graph_matcher.OpTypePattern(
'Reshape', inputs=[stack_1_pattern, 'Const'], ordered_inputs=False)
stack_2_pattern = graph_matcher.OpTypePattern( stack_2_pattern = graph_matcher.OpTypePattern(
'Pack', inputs=[stack_1_pattern, stack_1_pattern], ordered_inputs=False) 'Pack',
reshape_pattern = graph_matcher.OpTypePattern( inputs=[reshape_1_pattern, reshape_1_pattern],
ordered_inputs=False)
reshape_2_pattern = graph_matcher.OpTypePattern(
'Reshape', inputs=[stack_2_pattern, 'Const'], ordered_inputs=False) 'Reshape', inputs=[stack_2_pattern, 'Const'], ordered_inputs=False)
consumer_pattern1 = graph_matcher.OpTypePattern( consumer_pattern1 = graph_matcher.OpTypePattern(
'Add|AddV2|Max|Mul', inputs=[reshape_pattern, '*'], 'Add|AddV2|Max|Mul',
inputs=[reshape_2_pattern, '*'],
ordered_inputs=False) ordered_inputs=False)
consumer_pattern2 = graph_matcher.OpTypePattern( consumer_pattern2 = graph_matcher.OpTypePattern(
'StridedSlice', inputs=[reshape_pattern, '*', '*', '*'], 'StridedSlice',
inputs=[reshape_2_pattern, '*', '*', '*'],
ordered_inputs=False) ordered_inputs=False)
def replace_matches(consumer_pattern): def replace_matches(consumer_pattern):
...@@ -119,16 +125,17 @@ def rewrite_nn_resize_op(is_quantized=False): ...@@ -119,16 +125,17 @@ def rewrite_nn_resize_op(is_quantized=False):
for match in matcher.match_graph(tf.get_default_graph()): for match in matcher.match_graph(tf.get_default_graph()):
match_counter += 1 match_counter += 1
projection_op = match.get_op(input_pattern) projection_op = match.get_op(input_pattern)
reshape_op = match.get_op(reshape_pattern) reshape_2_op = match.get_op(reshape_2_pattern)
consumer_op = match.get_op(consumer_pattern) consumer_op = match.get_op(consumer_pattern)
nn_resize = tf.image.resize_nearest_neighbor( nn_resize = tf.image.resize_nearest_neighbor(
projection_op.outputs[0], projection_op.outputs[0],
reshape_op.outputs[0].shape.dims[1:3], reshape_2_op.outputs[0].shape.dims[1:3],
align_corners=False, align_corners=False,
name=os.path.split(reshape_op.name)[0] + '/resize_nearest_neighbor') name=os.path.split(reshape_2_op.name)[0] +
'/resize_nearest_neighbor')
for index, op_input in enumerate(consumer_op.inputs): for index, op_input in enumerate(consumer_op.inputs):
if op_input == reshape_op.outputs[0]: if op_input == reshape_2_op.outputs[0]:
consumer_op._update_input(index, nn_resize) # pylint: disable=protected-access consumer_op._update_input(index, nn_resize) # pylint: disable=protected-access
break break
......
...@@ -1168,16 +1168,20 @@ class ExportInferenceGraphTest(tf.test.TestCase): ...@@ -1168,16 +1168,20 @@ class ExportInferenceGraphTest(tf.test.TestCase):
g = tf.Graph() g = tf.Graph()
with g.as_default(): with g.as_default():
with tf.name_scope('nearest_upsampling'): with tf.name_scope('nearest_upsampling'):
x = array_ops.placeholder(dtypes.float32, shape=(8, 10, 10, 8)) x_1 = array_ops.placeholder(dtypes.float32, shape=(8, 10, 10, 8))
x_stack = tf.stack([tf.stack([x] * 2, axis=3)] * 2, axis=2) x_1_stack_1 = tf.stack([x_1] * 2, axis=3)
x_reshape = tf.reshape(x_stack, [8, 20, 20, 8]) x_1_reshape_1 = tf.reshape(x_1_stack_1, [8, 10, 20, 8])
x_1_stack_2 = tf.stack([x_1_reshape_1] * 2, axis=2)
x_1_reshape_2 = tf.reshape(x_1_stack_2, [8, 20, 20, 8])
with tf.name_scope('nearest_upsampling'): with tf.name_scope('nearest_upsampling'):
x_2 = array_ops.placeholder(dtypes.float32, shape=(8, 10, 10, 8)) x_2 = array_ops.placeholder(dtypes.float32, shape=(8, 10, 10, 8))
x_stack_2 = tf.stack([tf.stack([x_2] * 2, axis=3)] * 2, axis=2) x_2_stack_1 = tf.stack([x_2] * 2, axis=3)
x_reshape_2 = tf.reshape(x_stack_2, [8, 20, 20, 8]) x_2_reshape_1 = tf.reshape(x_2_stack_1, [8, 10, 20, 8])
x_2_stack_2 = tf.stack([x_2_reshape_1] * 2, axis=2)
x_2_reshape_2 = tf.reshape(x_2_stack_2, [8, 20, 20, 8])
t = x_reshape + x_reshape_2 t = x_1_reshape_2 + x_2_reshape_2
exporter.rewrite_nn_resize_op() exporter.rewrite_nn_resize_op()
......
This diff is collapsed.
This diff is collapsed.
...@@ -403,7 +403,7 @@ message CenterNet { ...@@ -403,7 +403,7 @@ message CenterNet {
// Mask prediction support using DeepMAC. See https://arxiv.org/abs/2104.00613 // Mask prediction support using DeepMAC. See https://arxiv.org/abs/2104.00613
// Next ID 25 // Next ID 33
message DeepMACMaskEstimation { message DeepMACMaskEstimation {
// The loss used for penalizing mask predictions. // The loss used for penalizing mask predictions.
optional ClassificationLoss classification_loss = 1; optional ClassificationLoss classification_loss = 1;
...@@ -505,6 +505,21 @@ message CenterNet { ...@@ -505,6 +505,21 @@ message CenterNet {
optional bool use_only_last_stage = 24 [default = false]; optional bool use_only_last_stage = 24 [default = false];
optional float augmented_self_supervision_max_translation = 25 [default=0.0];
optional float augmented_self_supervision_flip_probability = 26 [default=0.0];
optional float augmented_self_supervision_loss_weight = 27 [default=0.0];
optional int32 augmented_self_supervision_warmup_start = 28 [default=0];
optional int32 augmented_self_supervision_warmup_steps = 29 [default=0];
optional AugmentedSelfSupervisionLoss augmented_self_supervision_loss = 30 [default=LOSS_DICE];
optional float augmented_self_supervision_scale_min = 31 [default=1.0];
optional float augmented_self_supervision_scale_max = 32 [default=1.0];
} }
optional DeepMACMaskEstimation deepmac_mask_estimation = 14; optional DeepMACMaskEstimation deepmac_mask_estimation = 14;
...@@ -527,6 +542,13 @@ enum LossNormalize { ...@@ -527,6 +542,13 @@ enum LossNormalize {
NORMALIZE_BALANCED = 3; NORMALIZE_BALANCED = 3;
} }
enum AugmentedSelfSupervisionLoss {
LOSS_UNSET = 0;
LOSS_DICE = 1;
LOSS_MSE = 2;
LOSS_KL_DIV = 3;
}
message CenterNetFeatureExtractor { message CenterNetFeatureExtractor {
optional string type = 1; optional string type = 1;
......
...@@ -3,7 +3,7 @@ syntax = "proto2"; ...@@ -3,7 +3,7 @@ syntax = "proto2";
package object_detection.protos; package object_detection.protos;
// Message for configuring DetectionModel evaluation jobs (eval.py). // Message for configuring DetectionModel evaluation jobs (eval.py).
// Next id - 36 // Next id - 37
message EvalConfig { message EvalConfig {
optional uint32 batch_size = 25 [default = 1]; optional uint32 batch_size = 25 [default = 1];
// Number of visualization images to generate. // Number of visualization images to generate.
...@@ -118,6 +118,11 @@ message EvalConfig { ...@@ -118,6 +118,11 @@ message EvalConfig {
// will be ignored. This is useful for evaluating on test data that are not // will be ignored. This is useful for evaluating on test data that are not
// exhaustively labeled. // exhaustively labeled.
optional bool skip_predictions_for_unlabeled_class = 33 [default = false]; optional bool skip_predictions_for_unlabeled_class = 33 [default = false];
// If image_classes_field for a given image is empty and this field set to
// true, it is interpreted as if the annotations on this image were
// exhaustive.
optional bool image_classes_field_map_empty_to_ones = 36 [default = true];
} }
// A message to configure parameterized evaluation metric. // A message to configure parameterized evaluation metric.
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment