"projects/web/src/vscode:/vscode.git/clone" did not exist on "ece7f8d5a476d6fdcf3aa948f1786e69e8c96aed"
Unverified Commit 44f6d511 authored by Srihari Humbarwadi's avatar Srihari Humbarwadi Committed by GitHub
Browse files

Merge branch 'tensorflow:master' into panoptic-deeplab

parents 686a287d 8bc5a1a5
......@@ -101,15 +101,21 @@ def rewrite_nn_resize_op(is_quantized=False):
'FakeQuantWithMinMaxVars' if is_quantized else '*')
stack_1_pattern = graph_matcher.OpTypePattern(
'Pack', inputs=[input_pattern, input_pattern], ordered_inputs=False)
reshape_1_pattern = graph_matcher.OpTypePattern(
'Reshape', inputs=[stack_1_pattern, 'Const'], ordered_inputs=False)
stack_2_pattern = graph_matcher.OpTypePattern(
'Pack', inputs=[stack_1_pattern, stack_1_pattern], ordered_inputs=False)
reshape_pattern = graph_matcher.OpTypePattern(
'Pack',
inputs=[reshape_1_pattern, reshape_1_pattern],
ordered_inputs=False)
reshape_2_pattern = graph_matcher.OpTypePattern(
'Reshape', inputs=[stack_2_pattern, 'Const'], ordered_inputs=False)
consumer_pattern1 = graph_matcher.OpTypePattern(
'Add|AddV2|Max|Mul', inputs=[reshape_pattern, '*'],
'Add|AddV2|Max|Mul',
inputs=[reshape_2_pattern, '*'],
ordered_inputs=False)
consumer_pattern2 = graph_matcher.OpTypePattern(
'StridedSlice', inputs=[reshape_pattern, '*', '*', '*'],
'StridedSlice',
inputs=[reshape_2_pattern, '*', '*', '*'],
ordered_inputs=False)
def replace_matches(consumer_pattern):
......@@ -119,16 +125,17 @@ def rewrite_nn_resize_op(is_quantized=False):
for match in matcher.match_graph(tf.get_default_graph()):
match_counter += 1
projection_op = match.get_op(input_pattern)
reshape_op = match.get_op(reshape_pattern)
reshape_2_op = match.get_op(reshape_2_pattern)
consumer_op = match.get_op(consumer_pattern)
nn_resize = tf.image.resize_nearest_neighbor(
projection_op.outputs[0],
reshape_op.outputs[0].shape.dims[1:3],
reshape_2_op.outputs[0].shape.dims[1:3],
align_corners=False,
name=os.path.split(reshape_op.name)[0] + '/resize_nearest_neighbor')
name=os.path.split(reshape_2_op.name)[0] +
'/resize_nearest_neighbor')
for index, op_input in enumerate(consumer_op.inputs):
if op_input == reshape_op.outputs[0]:
if op_input == reshape_2_op.outputs[0]:
consumer_op._update_input(index, nn_resize) # pylint: disable=protected-access
break
......
......@@ -1168,16 +1168,20 @@ class ExportInferenceGraphTest(tf.test.TestCase):
g = tf.Graph()
with g.as_default():
with tf.name_scope('nearest_upsampling'):
x = array_ops.placeholder(dtypes.float32, shape=(8, 10, 10, 8))
x_stack = tf.stack([tf.stack([x] * 2, axis=3)] * 2, axis=2)
x_reshape = tf.reshape(x_stack, [8, 20, 20, 8])
x_1 = array_ops.placeholder(dtypes.float32, shape=(8, 10, 10, 8))
x_1_stack_1 = tf.stack([x_1] * 2, axis=3)
x_1_reshape_1 = tf.reshape(x_1_stack_1, [8, 10, 20, 8])
x_1_stack_2 = tf.stack([x_1_reshape_1] * 2, axis=2)
x_1_reshape_2 = tf.reshape(x_1_stack_2, [8, 20, 20, 8])
with tf.name_scope('nearest_upsampling'):
x_2 = array_ops.placeholder(dtypes.float32, shape=(8, 10, 10, 8))
x_stack_2 = tf.stack([tf.stack([x_2] * 2, axis=3)] * 2, axis=2)
x_reshape_2 = tf.reshape(x_stack_2, [8, 20, 20, 8])
x_2_stack_1 = tf.stack([x_2] * 2, axis=3)
x_2_reshape_1 = tf.reshape(x_2_stack_1, [8, 10, 20, 8])
x_2_stack_2 = tf.stack([x_2_reshape_1] * 2, axis=2)
x_2_reshape_2 = tf.reshape(x_2_stack_2, [8, 20, 20, 8])
t = x_reshape + x_reshape_2
t = x_1_reshape_2 + x_2_reshape_2
exporter.rewrite_nn_resize_op()
......
This diff is collapsed.
This diff is collapsed.
......@@ -403,7 +403,7 @@ message CenterNet {
// Mask prediction support using DeepMAC. See https://arxiv.org/abs/2104.00613
// Next ID 25
// Next ID 33
message DeepMACMaskEstimation {
// The loss used for penalizing mask predictions.
optional ClassificationLoss classification_loss = 1;
......@@ -505,6 +505,21 @@ message CenterNet {
optional bool use_only_last_stage = 24 [default = false];
optional float augmented_self_supervision_max_translation = 25 [default=0.0];
optional float augmented_self_supervision_flip_probability = 26 [default=0.0];
optional float augmented_self_supervision_loss_weight = 27 [default=0.0];
optional int32 augmented_self_supervision_warmup_start = 28 [default=0];
optional int32 augmented_self_supervision_warmup_steps = 29 [default=0];
optional AugmentedSelfSupervisionLoss augmented_self_supervision_loss = 30 [default=LOSS_DICE];
optional float augmented_self_supervision_scale_min = 31 [default=1.0];
optional float augmented_self_supervision_scale_max = 32 [default=1.0];
}
optional DeepMACMaskEstimation deepmac_mask_estimation = 14;
......@@ -527,6 +542,13 @@ enum LossNormalize {
NORMALIZE_BALANCED = 3;
}
enum AugmentedSelfSupervisionLoss {
LOSS_UNSET = 0;
LOSS_DICE = 1;
LOSS_MSE = 2;
LOSS_KL_DIV = 3;
}
message CenterNetFeatureExtractor {
optional string type = 1;
......
......@@ -3,7 +3,7 @@ syntax = "proto2";
package object_detection.protos;
// Message for configuring DetectionModel evaluation jobs (eval.py).
// Next id - 36
// Next id - 37
message EvalConfig {
optional uint32 batch_size = 25 [default = 1];
// Number of visualization images to generate.
......@@ -118,6 +118,11 @@ message EvalConfig {
// will be ignored. This is useful for evaluating on test data that are not
// exhaustively labeled.
optional bool skip_predictions_for_unlabeled_class = 33 [default = false];
// If image_classes_field for a given image is empty and this field set to
// true, it is interpreted as if the annotations on this image were
// exhaustive.
optional bool image_classes_field_map_empty_to_ones = 36 [default = true];
}
// A message to configure parameterized evaluation metric.
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment