"git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "38cd5fb1e06ccfc64ccaa07a0a735093fb91ecad"
Commit 6b72b5cd authored by Zhichao Lu's avatar Zhichao Lu Committed by pkulzc
Browse files

Merged commit includes the following changes:

191649512  by Zhichao Lu:

    Introduce two parameters in ssd.proto - freeze_batchnorm, inplace_batchnorm_update - and set up slim arg_scopes in ssd_meta_arch.py so that they are applied to all batchnorm ops in the predict() method.

    This centralizes the control of freezing and doing inplace batchnorm updates.

--
191620303  by Zhichao Lu:

    Modifications to the preprocessor to support multiclass scores

--
191610773  by Zhichao Lu:

    Adding multiclass_scores to InputDataFields and adding padding for multiclass_scores.

--
191595011  by Zhichao Lu:

    Contains implementation of the detection metric for the Open Images Challenge.

--
191449408  by Zhichao Lu:

    Change hyperparams_builder to return a callable so the users can inherit values from outer arg_scopes. This allows us to easily set batch_norm parameters like "is_training" and "inplace_batchnorm_update" for all feature extractors from the base class and propagate it correctly to the nested scopes.

--
191437008  by Zhichao Lu:

    Contains implementation of the Recall@N and MedianRank@N metrics.

--
191385254  by Zhichao Lu:

    Add config rewrite flag to eval.py

--
191382500  by Zhichao Lu:

    Fix bug in config_util.

--

PiperOrigin-RevId: 191649512
parent 143464d2
...@@ -49,12 +49,12 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes): ...@@ -49,12 +49,12 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes):
if box_predictor_oneof == 'convolutional_box_predictor': if box_predictor_oneof == 'convolutional_box_predictor':
conv_box_predictor = box_predictor_config.convolutional_box_predictor conv_box_predictor = box_predictor_config.convolutional_box_predictor
conv_hyperparams = argscope_fn(conv_box_predictor.conv_hyperparams, conv_hyperparams_fn = argscope_fn(conv_box_predictor.conv_hyperparams,
is_training) is_training)
box_predictor_object = box_predictor.ConvolutionalBoxPredictor( box_predictor_object = box_predictor.ConvolutionalBoxPredictor(
is_training=is_training, is_training=is_training,
num_classes=num_classes, num_classes=num_classes,
conv_hyperparams=conv_hyperparams, conv_hyperparams_fn=conv_hyperparams_fn,
min_depth=conv_box_predictor.min_depth, min_depth=conv_box_predictor.min_depth,
max_depth=conv_box_predictor.max_depth, max_depth=conv_box_predictor.max_depth,
num_layers_before_predictor=(conv_box_predictor. num_layers_before_predictor=(conv_box_predictor.
...@@ -73,12 +73,12 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes): ...@@ -73,12 +73,12 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes):
if box_predictor_oneof == 'weight_shared_convolutional_box_predictor': if box_predictor_oneof == 'weight_shared_convolutional_box_predictor':
conv_box_predictor = (box_predictor_config. conv_box_predictor = (box_predictor_config.
weight_shared_convolutional_box_predictor) weight_shared_convolutional_box_predictor)
conv_hyperparams = argscope_fn(conv_box_predictor.conv_hyperparams, conv_hyperparams_fn = argscope_fn(conv_box_predictor.conv_hyperparams,
is_training) is_training)
box_predictor_object = box_predictor.WeightSharedConvolutionalBoxPredictor( box_predictor_object = box_predictor.WeightSharedConvolutionalBoxPredictor(
is_training=is_training, is_training=is_training,
num_classes=num_classes, num_classes=num_classes,
conv_hyperparams=conv_hyperparams, conv_hyperparams_fn=conv_hyperparams_fn,
depth=conv_box_predictor.depth, depth=conv_box_predictor.depth,
num_layers_before_predictor=(conv_box_predictor. num_layers_before_predictor=(conv_box_predictor.
num_layers_before_predictor), num_layers_before_predictor),
...@@ -90,20 +90,20 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes): ...@@ -90,20 +90,20 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes):
if box_predictor_oneof == 'mask_rcnn_box_predictor': if box_predictor_oneof == 'mask_rcnn_box_predictor':
mask_rcnn_box_predictor = box_predictor_config.mask_rcnn_box_predictor mask_rcnn_box_predictor = box_predictor_config.mask_rcnn_box_predictor
fc_hyperparams = argscope_fn(mask_rcnn_box_predictor.fc_hyperparams, fc_hyperparams_fn = argscope_fn(mask_rcnn_box_predictor.fc_hyperparams,
is_training) is_training)
conv_hyperparams = None conv_hyperparams_fn = None
if mask_rcnn_box_predictor.HasField('conv_hyperparams'): if mask_rcnn_box_predictor.HasField('conv_hyperparams'):
conv_hyperparams = argscope_fn(mask_rcnn_box_predictor.conv_hyperparams, conv_hyperparams_fn = argscope_fn(
is_training) mask_rcnn_box_predictor.conv_hyperparams, is_training)
box_predictor_object = box_predictor.MaskRCNNBoxPredictor( box_predictor_object = box_predictor.MaskRCNNBoxPredictor(
is_training=is_training, is_training=is_training,
num_classes=num_classes, num_classes=num_classes,
fc_hyperparams=fc_hyperparams, fc_hyperparams_fn=fc_hyperparams_fn,
use_dropout=mask_rcnn_box_predictor.use_dropout, use_dropout=mask_rcnn_box_predictor.use_dropout,
dropout_keep_prob=mask_rcnn_box_predictor.dropout_keep_probability, dropout_keep_prob=mask_rcnn_box_predictor.dropout_keep_probability,
box_code_size=mask_rcnn_box_predictor.box_code_size, box_code_size=mask_rcnn_box_predictor.box_code_size,
conv_hyperparams=conv_hyperparams, conv_hyperparams_fn=conv_hyperparams_fn,
predict_instance_masks=mask_rcnn_box_predictor.predict_instance_masks, predict_instance_masks=mask_rcnn_box_predictor.predict_instance_masks,
mask_height=mask_rcnn_box_predictor.mask_height, mask_height=mask_rcnn_box_predictor.mask_height,
mask_width=mask_rcnn_box_predictor.mask_width, mask_width=mask_rcnn_box_predictor.mask_width,
...@@ -116,12 +116,12 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes): ...@@ -116,12 +116,12 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes):
if box_predictor_oneof == 'rfcn_box_predictor': if box_predictor_oneof == 'rfcn_box_predictor':
rfcn_box_predictor = box_predictor_config.rfcn_box_predictor rfcn_box_predictor = box_predictor_config.rfcn_box_predictor
conv_hyperparams = argscope_fn(rfcn_box_predictor.conv_hyperparams, conv_hyperparams_fn = argscope_fn(rfcn_box_predictor.conv_hyperparams,
is_training) is_training)
box_predictor_object = box_predictor.RfcnBoxPredictor( box_predictor_object = box_predictor.RfcnBoxPredictor(
is_training=is_training, is_training=is_training,
num_classes=num_classes, num_classes=num_classes,
conv_hyperparams=conv_hyperparams, conv_hyperparams_fn=conv_hyperparams_fn,
crop_size=[rfcn_box_predictor.crop_height, crop_size=[rfcn_box_predictor.crop_height,
rfcn_box_predictor.crop_width], rfcn_box_predictor.crop_width],
num_spatial_bins=[rfcn_box_predictor.num_spatial_bins_height, num_spatial_bins=[rfcn_box_predictor.num_spatial_bins_height,
......
...@@ -54,7 +54,7 @@ class ConvolutionalBoxPredictorBuilderTest(tf.test.TestCase): ...@@ -54,7 +54,7 @@ class ConvolutionalBoxPredictorBuilderTest(tf.test.TestCase):
box_predictor_config=box_predictor_proto, box_predictor_config=box_predictor_proto,
is_training=False, is_training=False,
num_classes=10) num_classes=10)
(conv_hyperparams_actual, is_training) = box_predictor._conv_hyperparams (conv_hyperparams_actual, is_training) = box_predictor._conv_hyperparams_fn
self.assertAlmostEqual((hyperparams_proto.regularizer. self.assertAlmostEqual((hyperparams_proto.regularizer.
l1_regularizer.weight), l1_regularizer.weight),
(conv_hyperparams_actual.regularizer.l1_regularizer. (conv_hyperparams_actual.regularizer.l1_regularizer.
...@@ -183,7 +183,7 @@ class WeightSharedConvolutionalBoxPredictorBuilderTest(tf.test.TestCase): ...@@ -183,7 +183,7 @@ class WeightSharedConvolutionalBoxPredictorBuilderTest(tf.test.TestCase):
box_predictor_config=box_predictor_proto, box_predictor_config=box_predictor_proto,
is_training=False, is_training=False,
num_classes=10) num_classes=10)
(conv_hyperparams_actual, is_training) = box_predictor._conv_hyperparams (conv_hyperparams_actual, is_training) = box_predictor._conv_hyperparams_fn
self.assertAlmostEqual((hyperparams_proto.regularizer. self.assertAlmostEqual((hyperparams_proto.regularizer.
l1_regularizer.weight), l1_regularizer.weight),
(conv_hyperparams_actual.regularizer.l1_regularizer. (conv_hyperparams_actual.regularizer.l1_regularizer.
...@@ -297,7 +297,7 @@ class MaskRCNNBoxPredictorBuilderTest(tf.test.TestCase): ...@@ -297,7 +297,7 @@ class MaskRCNNBoxPredictorBuilderTest(tf.test.TestCase):
is_training=False, is_training=False,
num_classes=10) num_classes=10)
mock_argscope_fn.assert_called_with(hyperparams_proto, False) mock_argscope_fn.assert_called_with(hyperparams_proto, False)
self.assertEqual(box_predictor._fc_hyperparams, 'arg_scope') self.assertEqual(box_predictor._fc_hyperparams_fn, 'arg_scope')
def test_non_default_mask_rcnn_box_predictor(self): def test_non_default_mask_rcnn_box_predictor(self):
fc_hyperparams_text_proto = """ fc_hyperparams_text_proto = """
...@@ -417,7 +417,7 @@ class RfcnBoxPredictorBuilderTest(tf.test.TestCase): ...@@ -417,7 +417,7 @@ class RfcnBoxPredictorBuilderTest(tf.test.TestCase):
box_predictor_config=box_predictor_proto, box_predictor_config=box_predictor_proto,
is_training=False, is_training=False,
num_classes=10) num_classes=10)
(conv_hyperparams_actual, is_training) = box_predictor._conv_hyperparams (conv_hyperparams_actual, is_training) = box_predictor._conv_hyperparams_fn
self.assertAlmostEqual((hyperparams_proto.regularizer. self.assertAlmostEqual((hyperparams_proto.regularizer.
l1_regularizer.weight), l1_regularizer.weight),
(conv_hyperparams_actual.regularizer.l1_regularizer. (conv_hyperparams_actual.regularizer.l1_regularizer.
......
...@@ -72,7 +72,9 @@ def _get_padding_shapes(dataset, max_num_boxes=None, num_classes=None, ...@@ -72,7 +72,9 @@ def _get_padding_shapes(dataset, max_num_boxes=None, num_classes=None,
fields.InputDataFields.num_groundtruth_boxes: [], fields.InputDataFields.num_groundtruth_boxes: [],
fields.InputDataFields.groundtruth_label_types: [max_num_boxes], fields.InputDataFields.groundtruth_label_types: [max_num_boxes],
fields.InputDataFields.groundtruth_label_scores: [max_num_boxes], fields.InputDataFields.groundtruth_label_scores: [max_num_boxes],
fields.InputDataFields.true_image_shape: [3] fields.InputDataFields.true_image_shape: [3],
fields.InputDataFields.multiclass_scores: [
max_num_boxes, num_classes + 1 if num_classes is not None else None],
} }
# Determine whether groundtruth_classes are integers or one-hot encodings, and # Determine whether groundtruth_classes are integers or one-hot encodings, and
# apply batching appropriately. # apply batching appropriately.
......
...@@ -43,7 +43,8 @@ def build(hyperparams_config, is_training): ...@@ -43,7 +43,8 @@ def build(hyperparams_config, is_training):
is_training: Whether the network is in training mode. is_training: Whether the network is in training mode.
Returns: Returns:
arg_scope: tf-slim arg_scope containing hyperparameters for ops. arg_scope_fn: A function to construct tf-slim arg_scope containing
hyperparameters for ops.
Raises: Raises:
ValueError: if hyperparams_config is not of type hyperparams.Hyperparams. ValueError: if hyperparams_config is not of type hyperparams.Hyperparams.
...@@ -64,16 +65,18 @@ def build(hyperparams_config, is_training): ...@@ -64,16 +65,18 @@ def build(hyperparams_config, is_training):
if hyperparams_config.HasField('op') and ( if hyperparams_config.HasField('op') and (
hyperparams_config.op == hyperparams_pb2.Hyperparams.FC): hyperparams_config.op == hyperparams_pb2.Hyperparams.FC):
affected_ops = [slim.fully_connected] affected_ops = [slim.fully_connected]
with slim.arg_scope( def scope_fn():
affected_ops, with slim.arg_scope(
weights_regularizer=_build_regularizer( affected_ops,
hyperparams_config.regularizer), weights_regularizer=_build_regularizer(
weights_initializer=_build_initializer( hyperparams_config.regularizer),
hyperparams_config.initializer), weights_initializer=_build_initializer(
activation_fn=_build_activation_fn(hyperparams_config.activation), hyperparams_config.initializer),
normalizer_fn=batch_norm, activation_fn=_build_activation_fn(hyperparams_config.activation),
normalizer_params=batch_norm_params) as sc: normalizer_fn=batch_norm,
return sc normalizer_params=batch_norm_params) as sc:
return sc
return scope_fn
def _build_activation_fn(activation_fn): def _build_activation_fn(activation_fn):
......
...@@ -45,7 +45,9 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -45,7 +45,9 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
self.assertTrue(self._get_scope_key(slim.conv2d) in scope) self.assertTrue(self._get_scope_key(slim.conv2d) in scope)
def test_default_arg_scope_has_separable_conv2d_op(self): def test_default_arg_scope_has_separable_conv2d_op(self):
...@@ -61,7 +63,9 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -61,7 +63,9 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
self.assertTrue(self._get_scope_key(slim.separable_conv2d) in scope) self.assertTrue(self._get_scope_key(slim.separable_conv2d) in scope)
def test_default_arg_scope_has_conv2d_transpose_op(self): def test_default_arg_scope_has_conv2d_transpose_op(self):
...@@ -77,7 +81,9 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -77,7 +81,9 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
self.assertTrue(self._get_scope_key(slim.conv2d_transpose) in scope) self.assertTrue(self._get_scope_key(slim.conv2d_transpose) in scope)
def test_explicit_fc_op_arg_scope_has_fully_connected_op(self): def test_explicit_fc_op_arg_scope_has_fully_connected_op(self):
...@@ -94,7 +100,9 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -94,7 +100,9 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
self.assertTrue(self._get_scope_key(slim.fully_connected) in scope) self.assertTrue(self._get_scope_key(slim.fully_connected) in scope)
def test_separable_conv2d_and_conv2d_and_transpose_have_same_parameters(self): def test_separable_conv2d_and_conv2d_and_transpose_have_same_parameters(self):
...@@ -110,7 +118,9 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -110,7 +118,9 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
kwargs_1, kwargs_2, kwargs_3 = scope.values() kwargs_1, kwargs_2, kwargs_3 = scope.values()
self.assertDictEqual(kwargs_1, kwargs_2) self.assertDictEqual(kwargs_1, kwargs_2)
self.assertDictEqual(kwargs_1, kwargs_3) self.assertDictEqual(kwargs_1, kwargs_3)
...@@ -129,7 +139,9 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -129,7 +139,9 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope.values()[0] conv_scope_arguments = scope.values()[0]
regularizer = conv_scope_arguments['weights_regularizer'] regularizer = conv_scope_arguments['weights_regularizer']
weights = np.array([1., -1, 4., 2.]) weights = np.array([1., -1, 4., 2.])
...@@ -151,7 +163,9 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -151,7 +163,9 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope.values()[0] conv_scope_arguments = scope.values()[0]
regularizer = conv_scope_arguments['weights_regularizer'] regularizer = conv_scope_arguments['weights_regularizer']
...@@ -180,7 +194,9 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -180,7 +194,9 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope.values()[0] conv_scope_arguments = scope.values()[0]
self.assertEqual(conv_scope_arguments['normalizer_fn'], slim.batch_norm) self.assertEqual(conv_scope_arguments['normalizer_fn'], slim.batch_norm)
batch_norm_params = conv_scope_arguments['normalizer_params'] batch_norm_params = conv_scope_arguments['normalizer_params']
...@@ -210,7 +226,9 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -210,7 +226,9 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=False) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=False)
scope = scope_fn()
conv_scope_arguments = scope.values()[0] conv_scope_arguments = scope.values()[0]
self.assertEqual(conv_scope_arguments['normalizer_fn'], slim.batch_norm) self.assertEqual(conv_scope_arguments['normalizer_fn'], slim.batch_norm)
batch_norm_params = conv_scope_arguments['normalizer_params'] batch_norm_params = conv_scope_arguments['normalizer_params']
...@@ -240,7 +258,9 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -240,7 +258,9 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope.values()[0] conv_scope_arguments = scope.values()[0]
self.assertEqual(conv_scope_arguments['normalizer_fn'], slim.batch_norm) self.assertEqual(conv_scope_arguments['normalizer_fn'], slim.batch_norm)
batch_norm_params = conv_scope_arguments['normalizer_params'] batch_norm_params = conv_scope_arguments['normalizer_params']
...@@ -263,7 +283,9 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -263,7 +283,9 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope.values()[0] conv_scope_arguments = scope.values()[0]
self.assertEqual(conv_scope_arguments['normalizer_fn'], None) self.assertEqual(conv_scope_arguments['normalizer_fn'], None)
self.assertEqual(conv_scope_arguments['normalizer_params'], None) self.assertEqual(conv_scope_arguments['normalizer_params'], None)
...@@ -282,7 +304,9 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -282,7 +304,9 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope.values()[0] conv_scope_arguments = scope.values()[0]
self.assertEqual(conv_scope_arguments['activation_fn'], None) self.assertEqual(conv_scope_arguments['activation_fn'], None)
...@@ -300,7 +324,9 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -300,7 +324,9 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope.values()[0] conv_scope_arguments = scope.values()[0]
self.assertEqual(conv_scope_arguments['activation_fn'], tf.nn.relu) self.assertEqual(conv_scope_arguments['activation_fn'], tf.nn.relu)
...@@ -318,7 +344,9 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -318,7 +344,9 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope.values()[0] conv_scope_arguments = scope.values()[0]
self.assertEqual(conv_scope_arguments['activation_fn'], tf.nn.relu6) self.assertEqual(conv_scope_arguments['activation_fn'], tf.nn.relu6)
...@@ -351,7 +379,9 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -351,7 +379,9 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope.values()[0] conv_scope_arguments = scope.values()[0]
initializer = conv_scope_arguments['weights_initializer'] initializer = conv_scope_arguments['weights_initializer']
self._assert_variance_in_range(initializer, shape=[100, 40], self._assert_variance_in_range(initializer, shape=[100, 40],
...@@ -373,7 +403,9 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -373,7 +403,9 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope.values()[0] conv_scope_arguments = scope.values()[0]
initializer = conv_scope_arguments['weights_initializer'] initializer = conv_scope_arguments['weights_initializer']
self._assert_variance_in_range(initializer, shape=[100, 40], self._assert_variance_in_range(initializer, shape=[100, 40],
...@@ -395,7 +427,9 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -395,7 +427,9 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope.values()[0] conv_scope_arguments = scope.values()[0]
initializer = conv_scope_arguments['weights_initializer'] initializer = conv_scope_arguments['weights_initializer']
self._assert_variance_in_range(initializer, shape=[100, 40], self._assert_variance_in_range(initializer, shape=[100, 40],
...@@ -417,7 +451,9 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -417,7 +451,9 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope.values()[0] conv_scope_arguments = scope.values()[0]
initializer = conv_scope_arguments['weights_initializer'] initializer = conv_scope_arguments['weights_initializer']
self._assert_variance_in_range(initializer, shape=[100, 40], self._assert_variance_in_range(initializer, shape=[100, 40],
...@@ -438,7 +474,9 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -438,7 +474,9 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope.values()[0] conv_scope_arguments = scope.values()[0]
initializer = conv_scope_arguments['weights_initializer'] initializer = conv_scope_arguments['weights_initializer']
self._assert_variance_in_range(initializer, shape=[100, 40], self._assert_variance_in_range(initializer, shape=[100, 40],
...@@ -459,7 +497,9 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -459,7 +497,9 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope.values()[0] conv_scope_arguments = scope.values()[0]
initializer = conv_scope_arguments['weights_initializer'] initializer = conv_scope_arguments['weights_initializer']
self._assert_variance_in_range(initializer, shape=[100, 40], self._assert_variance_in_range(initializer, shape=[100, 40],
......
...@@ -98,19 +98,13 @@ def build(model_config, is_training, add_summaries=True): ...@@ -98,19 +98,13 @@ def build(model_config, is_training, add_summaries=True):
def _build_ssd_feature_extractor(feature_extractor_config, is_training, def _build_ssd_feature_extractor(feature_extractor_config, is_training,
reuse_weights=None, reuse_weights=None):
inplace_batchnorm_update=False):
"""Builds a ssd_meta_arch.SSDFeatureExtractor based on config. """Builds a ssd_meta_arch.SSDFeatureExtractor based on config.
Args: Args:
feature_extractor_config: A SSDFeatureExtractor proto config from ssd.proto. feature_extractor_config: A SSDFeatureExtractor proto config from ssd.proto.
is_training: True if this feature extractor is being built for training. is_training: True if this feature extractor is being built for training.
reuse_weights: if the feature extractor should reuse weights. reuse_weights: if the feature extractor should reuse weights.
inplace_batchnorm_update: Whether to update batch_norm inplace during
training. This is required for batch norm to work correctly on TPUs. When
this is false, user must add a control dependency on
tf.GraphKeys.UPDATE_OPS for train/loss op in order to update the batch
norm moving average parameters.
Returns: Returns:
ssd_meta_arch.SSDFeatureExtractor based on config. ssd_meta_arch.SSDFeatureExtractor based on config.
...@@ -122,7 +116,6 @@ def _build_ssd_feature_extractor(feature_extractor_config, is_training, ...@@ -122,7 +116,6 @@ def _build_ssd_feature_extractor(feature_extractor_config, is_training,
depth_multiplier = feature_extractor_config.depth_multiplier depth_multiplier = feature_extractor_config.depth_multiplier
min_depth = feature_extractor_config.min_depth min_depth = feature_extractor_config.min_depth
pad_to_multiple = feature_extractor_config.pad_to_multiple pad_to_multiple = feature_extractor_config.pad_to_multiple
batch_norm_trainable = feature_extractor_config.batch_norm_trainable
use_explicit_padding = feature_extractor_config.use_explicit_padding use_explicit_padding = feature_extractor_config.use_explicit_padding
use_depthwise = feature_extractor_config.use_depthwise use_depthwise = feature_extractor_config.use_depthwise
conv_hyperparams = hyperparams_builder.build( conv_hyperparams = hyperparams_builder.build(
...@@ -132,11 +125,9 @@ def _build_ssd_feature_extractor(feature_extractor_config, is_training, ...@@ -132,11 +125,9 @@ def _build_ssd_feature_extractor(feature_extractor_config, is_training,
raise ValueError('Unknown ssd feature_extractor: {}'.format(feature_type)) raise ValueError('Unknown ssd feature_extractor: {}'.format(feature_type))
feature_extractor_class = SSD_FEATURE_EXTRACTOR_CLASS_MAP[feature_type] feature_extractor_class = SSD_FEATURE_EXTRACTOR_CLASS_MAP[feature_type]
return feature_extractor_class(is_training, depth_multiplier, min_depth, return feature_extractor_class(
pad_to_multiple, conv_hyperparams, is_training, depth_multiplier, min_depth, pad_to_multiple,
batch_norm_trainable, reuse_weights, conv_hyperparams, reuse_weights, use_explicit_padding, use_depthwise)
use_explicit_padding, use_depthwise,
inplace_batchnorm_update)
def _build_ssd_model(ssd_config, is_training, add_summaries): def _build_ssd_model(ssd_config, is_training, add_summaries):
...@@ -160,8 +151,7 @@ def _build_ssd_model(ssd_config, is_training, add_summaries): ...@@ -160,8 +151,7 @@ def _build_ssd_model(ssd_config, is_training, add_summaries):
# Feature extractor # Feature extractor
feature_extractor = _build_ssd_feature_extractor( feature_extractor = _build_ssd_feature_extractor(
feature_extractor_config=ssd_config.feature_extractor, feature_extractor_config=ssd_config.feature_extractor,
is_training=is_training, is_training=is_training)
inplace_batchnorm_update=ssd_config.inplace_batchnorm_update)
box_coder = box_coder_builder.build(ssd_config.box_coder) box_coder = box_coder_builder.build(ssd_config.box_coder)
matcher = matcher_builder.build(ssd_config.matcher) matcher = matcher_builder.build(ssd_config.matcher)
...@@ -203,7 +193,9 @@ def _build_ssd_model(ssd_config, is_training, add_summaries): ...@@ -203,7 +193,9 @@ def _build_ssd_model(ssd_config, is_training, add_summaries):
normalize_loss_by_num_matches, normalize_loss_by_num_matches,
hard_example_miner, hard_example_miner,
add_summaries=add_summaries, add_summaries=add_summaries,
normalize_loc_loss_by_codesize=normalize_loc_loss_by_codesize) normalize_loc_loss_by_codesize=normalize_loc_loss_by_codesize,
freeze_batchnorm=ssd_config.freeze_batchnorm,
inplace_batchnorm_update=ssd_config.inplace_batchnorm_update)
def _build_faster_rcnn_feature_extractor( def _build_faster_rcnn_feature_extractor(
...@@ -276,7 +268,7 @@ def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries): ...@@ -276,7 +268,7 @@ def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries):
frcnn_config.first_stage_anchor_generator) frcnn_config.first_stage_anchor_generator)
first_stage_atrous_rate = frcnn_config.first_stage_atrous_rate first_stage_atrous_rate = frcnn_config.first_stage_atrous_rate
first_stage_box_predictor_arg_scope = hyperparams_builder.build( first_stage_box_predictor_arg_scope_fn = hyperparams_builder.build(
frcnn_config.first_stage_box_predictor_conv_hyperparams, is_training) frcnn_config.first_stage_box_predictor_conv_hyperparams, is_training)
first_stage_box_predictor_kernel_size = ( first_stage_box_predictor_kernel_size = (
frcnn_config.first_stage_box_predictor_kernel_size) frcnn_config.first_stage_box_predictor_kernel_size)
...@@ -329,8 +321,8 @@ def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries): ...@@ -329,8 +321,8 @@ def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries):
'number_of_stages': number_of_stages, 'number_of_stages': number_of_stages,
'first_stage_anchor_generator': first_stage_anchor_generator, 'first_stage_anchor_generator': first_stage_anchor_generator,
'first_stage_atrous_rate': first_stage_atrous_rate, 'first_stage_atrous_rate': first_stage_atrous_rate,
'first_stage_box_predictor_arg_scope': 'first_stage_box_predictor_arg_scope_fn':
first_stage_box_predictor_arg_scope, first_stage_box_predictor_arg_scope_fn,
'first_stage_box_predictor_kernel_size': 'first_stage_box_predictor_kernel_size':
first_stage_box_predictor_kernel_size, first_stage_box_predictor_kernel_size,
'first_stage_box_predictor_depth': first_stage_box_predictor_depth, 'first_stage_box_predictor_depth': first_stage_box_predictor_depth,
......
...@@ -225,7 +225,6 @@ class ModelBuilderTest(tf.test.TestCase): ...@@ -225,7 +225,6 @@ class ModelBuilderTest(tf.test.TestCase):
} }
} }
} }
batch_norm_trainable: true
} }
box_coder { box_coder {
faster_rcnn_box_coder { faster_rcnn_box_coder {
...@@ -298,6 +297,7 @@ class ModelBuilderTest(tf.test.TestCase): ...@@ -298,6 +297,7 @@ class ModelBuilderTest(tf.test.TestCase):
def test_create_ssd_mobilenet_v1_model_from_config(self): def test_create_ssd_mobilenet_v1_model_from_config(self):
model_text_proto = """ model_text_proto = """
ssd { ssd {
freeze_batchnorm: true
inplace_batchnorm_update: true inplace_batchnorm_update: true
feature_extractor { feature_extractor {
type: 'ssd_mobilenet_v1' type: 'ssd_mobilenet_v1'
...@@ -311,7 +311,6 @@ class ModelBuilderTest(tf.test.TestCase): ...@@ -311,7 +311,6 @@ class ModelBuilderTest(tf.test.TestCase):
} }
} }
} }
batch_norm_trainable: true
} }
box_coder { box_coder {
faster_rcnn_box_coder { faster_rcnn_box_coder {
...@@ -368,8 +367,9 @@ class ModelBuilderTest(tf.test.TestCase): ...@@ -368,8 +367,9 @@ class ModelBuilderTest(tf.test.TestCase):
self.assertIsInstance(model, ssd_meta_arch.SSDMetaArch) self.assertIsInstance(model, ssd_meta_arch.SSDMetaArch)
self.assertIsInstance(model._feature_extractor, self.assertIsInstance(model._feature_extractor,
SSDMobileNetV1FeatureExtractor) SSDMobileNetV1FeatureExtractor)
self.assertTrue(model._feature_extractor._batch_norm_trainable)
self.assertTrue(model._normalize_loc_loss_by_codesize) self.assertTrue(model._normalize_loc_loss_by_codesize)
self.assertTrue(model._freeze_batchnorm)
self.assertTrue(model._inplace_batchnorm_update)
def test_create_ssd_mobilenet_v2_model_from_config(self): def test_create_ssd_mobilenet_v2_model_from_config(self):
model_text_proto = """ model_text_proto = """
...@@ -386,7 +386,6 @@ class ModelBuilderTest(tf.test.TestCase): ...@@ -386,7 +386,6 @@ class ModelBuilderTest(tf.test.TestCase):
} }
} }
} }
batch_norm_trainable: true
} }
box_coder { box_coder {
faster_rcnn_box_coder { faster_rcnn_box_coder {
...@@ -443,7 +442,6 @@ class ModelBuilderTest(tf.test.TestCase): ...@@ -443,7 +442,6 @@ class ModelBuilderTest(tf.test.TestCase):
self.assertIsInstance(model, ssd_meta_arch.SSDMetaArch) self.assertIsInstance(model, ssd_meta_arch.SSDMetaArch)
self.assertIsInstance(model._feature_extractor, self.assertIsInstance(model._feature_extractor,
SSDMobileNetV2FeatureExtractor) SSDMobileNetV2FeatureExtractor)
self.assertTrue(model._feature_extractor._batch_norm_trainable)
self.assertTrue(model._normalize_loc_loss_by_codesize) self.assertTrue(model._normalize_loc_loss_by_codesize)
def test_create_embedded_ssd_mobilenet_v1_model_from_config(self): def test_create_embedded_ssd_mobilenet_v1_model_from_config(self):
...@@ -461,7 +459,6 @@ class ModelBuilderTest(tf.test.TestCase): ...@@ -461,7 +459,6 @@ class ModelBuilderTest(tf.test.TestCase):
} }
} }
} }
batch_norm_trainable: true
} }
box_coder { box_coder {
faster_rcnn_box_coder { faster_rcnn_box_coder {
......
...@@ -147,7 +147,7 @@ class RfcnBoxPredictor(BoxPredictor): ...@@ -147,7 +147,7 @@ class RfcnBoxPredictor(BoxPredictor):
def __init__(self, def __init__(self,
is_training, is_training,
num_classes, num_classes,
conv_hyperparams, conv_hyperparams_fn,
num_spatial_bins, num_spatial_bins,
depth, depth,
crop_size, crop_size,
...@@ -160,8 +160,8 @@ class RfcnBoxPredictor(BoxPredictor): ...@@ -160,8 +160,8 @@ class RfcnBoxPredictor(BoxPredictor):
include the background category, so if groundtruth labels take values include the background category, so if groundtruth labels take values
in {0, 1, .., K-1}, num_classes=K (and not K+1, even though the in {0, 1, .., K-1}, num_classes=K (and not K+1, even though the
assigned classification targets can range from {0,... K}). assigned classification targets can range from {0,... K}).
conv_hyperparams: Slim arg_scope with hyperparameters for conolutional conv_hyperparams_fn: A function to construct tf-slim arg_scope with
layers. hyperparameters for convolutional layers.
num_spatial_bins: A list of two integers `[spatial_bins_y, num_spatial_bins: A list of two integers `[spatial_bins_y,
spatial_bins_x]`. spatial_bins_x]`.
depth: Target depth to reduce the input feature maps to. depth: Target depth to reduce the input feature maps to.
...@@ -169,7 +169,7 @@ class RfcnBoxPredictor(BoxPredictor): ...@@ -169,7 +169,7 @@ class RfcnBoxPredictor(BoxPredictor):
box_code_size: Size of encoding for each box. box_code_size: Size of encoding for each box.
""" """
super(RfcnBoxPredictor, self).__init__(is_training, num_classes) super(RfcnBoxPredictor, self).__init__(is_training, num_classes)
self._conv_hyperparams = conv_hyperparams self._conv_hyperparams_fn = conv_hyperparams_fn
self._num_spatial_bins = num_spatial_bins self._num_spatial_bins = num_spatial_bins
self._depth = depth self._depth = depth
self._crop_size = crop_size self._crop_size = crop_size
...@@ -227,7 +227,7 @@ class RfcnBoxPredictor(BoxPredictor): ...@@ -227,7 +227,7 @@ class RfcnBoxPredictor(BoxPredictor):
return tf.reshape(ones_mat * multiplier, [-1]) return tf.reshape(ones_mat * multiplier, [-1])
net = image_feature net = image_feature
with slim.arg_scope(self._conv_hyperparams): with slim.arg_scope(self._conv_hyperparams_fn()):
net = slim.conv2d(net, self._depth, [1, 1], scope='reduce_depth') net = slim.conv2d(net, self._depth, [1, 1], scope='reduce_depth')
# Location predictions. # Location predictions.
location_feature_map_depth = (self._num_spatial_bins[0] * location_feature_map_depth = (self._num_spatial_bins[0] *
...@@ -297,11 +297,11 @@ class MaskRCNNBoxPredictor(BoxPredictor): ...@@ -297,11 +297,11 @@ class MaskRCNNBoxPredictor(BoxPredictor):
def __init__(self, def __init__(self,
is_training, is_training,
num_classes, num_classes,
fc_hyperparams, fc_hyperparams_fn,
use_dropout, use_dropout,
dropout_keep_prob, dropout_keep_prob,
box_code_size, box_code_size,
conv_hyperparams=None, conv_hyperparams_fn=None,
predict_instance_masks=False, predict_instance_masks=False,
mask_height=14, mask_height=14,
mask_width=14, mask_width=14,
...@@ -316,16 +316,16 @@ class MaskRCNNBoxPredictor(BoxPredictor): ...@@ -316,16 +316,16 @@ class MaskRCNNBoxPredictor(BoxPredictor):
include the background category, so if groundtruth labels take values include the background category, so if groundtruth labels take values
in {0, 1, .., K-1}, num_classes=K (and not K+1, even though the in {0, 1, .., K-1}, num_classes=K (and not K+1, even though the
assigned classification targets can range from {0,... K}). assigned classification targets can range from {0,... K}).
fc_hyperparams: Slim arg_scope with hyperparameters for fully fc_hyperparams_fn: A function to generate tf-slim arg_scope with
connected ops. hyperparameters for fully connected ops.
use_dropout: Option to use dropout or not. Note that a single dropout use_dropout: Option to use dropout or not. Note that a single dropout
op is applied here prior to both box and class predictions, which stands op is applied here prior to both box and class predictions, which stands
in contrast to the ConvolutionalBoxPredictor below. in contrast to the ConvolutionalBoxPredictor below.
dropout_keep_prob: Keep probability for dropout. dropout_keep_prob: Keep probability for dropout.
This is only used if use_dropout is True. This is only used if use_dropout is True.
box_code_size: Size of encoding for each box. box_code_size: Size of encoding for each box.
conv_hyperparams: Slim arg_scope with hyperparameters for convolution conv_hyperparams_fn: A function to generate tf-slim arg_scope with
ops. hyperparameters for convolution ops.
predict_instance_masks: Whether to predict object masks inside detection predict_instance_masks: Whether to predict object masks inside detection
boxes. boxes.
mask_height: Desired output mask height. The default value is 14. mask_height: Desired output mask height. The default value is 14.
...@@ -347,11 +347,11 @@ class MaskRCNNBoxPredictor(BoxPredictor): ...@@ -347,11 +347,11 @@ class MaskRCNNBoxPredictor(BoxPredictor):
ValueError: If mask_prediction_num_conv_layers is smaller than two. ValueError: If mask_prediction_num_conv_layers is smaller than two.
""" """
super(MaskRCNNBoxPredictor, self).__init__(is_training, num_classes) super(MaskRCNNBoxPredictor, self).__init__(is_training, num_classes)
self._fc_hyperparams = fc_hyperparams self._fc_hyperparams_fn = fc_hyperparams_fn
self._use_dropout = use_dropout self._use_dropout = use_dropout
self._box_code_size = box_code_size self._box_code_size = box_code_size
self._dropout_keep_prob = dropout_keep_prob self._dropout_keep_prob = dropout_keep_prob
self._conv_hyperparams = conv_hyperparams self._conv_hyperparams_fn = conv_hyperparams_fn
self._predict_instance_masks = predict_instance_masks self._predict_instance_masks = predict_instance_masks
self._mask_height = mask_height self._mask_height = mask_height
self._mask_width = mask_width self._mask_width = mask_width
...@@ -361,7 +361,7 @@ class MaskRCNNBoxPredictor(BoxPredictor): ...@@ -361,7 +361,7 @@ class MaskRCNNBoxPredictor(BoxPredictor):
if self._predict_keypoints: if self._predict_keypoints:
raise ValueError('Keypoint prediction is unimplemented.') raise ValueError('Keypoint prediction is unimplemented.')
if ((self._predict_instance_masks or self._predict_keypoints) and if ((self._predict_instance_masks or self._predict_keypoints) and
self._conv_hyperparams is None): self._conv_hyperparams_fn is None):
raise ValueError('`conv_hyperparams` must be provided when predicting ' raise ValueError('`conv_hyperparams` must be provided when predicting '
'masks.') 'masks.')
if self._mask_prediction_num_conv_layers < 2: if self._mask_prediction_num_conv_layers < 2:
...@@ -399,7 +399,7 @@ class MaskRCNNBoxPredictor(BoxPredictor): ...@@ -399,7 +399,7 @@ class MaskRCNNBoxPredictor(BoxPredictor):
flattened_image_features = slim.dropout(flattened_image_features, flattened_image_features = slim.dropout(flattened_image_features,
keep_prob=self._dropout_keep_prob, keep_prob=self._dropout_keep_prob,
is_training=self._is_training) is_training=self._is_training)
with slim.arg_scope(self._fc_hyperparams): with slim.arg_scope(self._fc_hyperparams_fn()):
box_encodings = slim.fully_connected( box_encodings = slim.fully_connected(
flattened_image_features, flattened_image_features,
self._num_classes * self._box_code_size, self._num_classes * self._box_code_size,
...@@ -463,7 +463,7 @@ class MaskRCNNBoxPredictor(BoxPredictor): ...@@ -463,7 +463,7 @@ class MaskRCNNBoxPredictor(BoxPredictor):
num_feature_channels = image_features.get_shape().as_list()[3] num_feature_channels = image_features.get_shape().as_list()[3]
num_conv_channels = self._get_mask_predictor_conv_depth( num_conv_channels = self._get_mask_predictor_conv_depth(
num_feature_channels, self.num_classes) num_feature_channels, self.num_classes)
with slim.arg_scope(self._conv_hyperparams): with slim.arg_scope(self._conv_hyperparams_fn()):
upsampled_features = tf.image.resize_bilinear( upsampled_features = tf.image.resize_bilinear(
image_features, image_features,
[self._mask_height, self._mask_width], [self._mask_height, self._mask_width],
...@@ -578,7 +578,7 @@ class ConvolutionalBoxPredictor(BoxPredictor): ...@@ -578,7 +578,7 @@ class ConvolutionalBoxPredictor(BoxPredictor):
def __init__(self, def __init__(self,
is_training, is_training,
num_classes, num_classes,
conv_hyperparams, conv_hyperparams_fn,
min_depth, min_depth,
max_depth, max_depth,
num_layers_before_predictor, num_layers_before_predictor,
...@@ -597,8 +597,9 @@ class ConvolutionalBoxPredictor(BoxPredictor): ...@@ -597,8 +597,9 @@ class ConvolutionalBoxPredictor(BoxPredictor):
include the background category, so if groundtruth labels take values include the background category, so if groundtruth labels take values
in {0, 1, .., K-1}, num_classes=K (and not K+1, even though the in {0, 1, .., K-1}, num_classes=K (and not K+1, even though the
assigned classification targets can range from {0,... K}). assigned classification targets can range from {0,... K}).
conv_hyperparams: Slim arg_scope with hyperparameters for convolution ops. conv_hyperparams_fn: A function to generate tf-slim arg_scope with
min_depth: Minumum feature depth prior to predicting box encodings hyperparameters for convolution ops.
min_depth: Minimum feature depth prior to predicting box encodings
and class predictions. and class predictions.
max_depth: Maximum feature depth prior to predicting box encodings max_depth: Maximum feature depth prior to predicting box encodings
and class predictions. If max_depth is set to 0, no additional and class predictions. If max_depth is set to 0, no additional
...@@ -626,7 +627,7 @@ class ConvolutionalBoxPredictor(BoxPredictor): ...@@ -626,7 +627,7 @@ class ConvolutionalBoxPredictor(BoxPredictor):
super(ConvolutionalBoxPredictor, self).__init__(is_training, num_classes) super(ConvolutionalBoxPredictor, self).__init__(is_training, num_classes)
if min_depth > max_depth: if min_depth > max_depth:
raise ValueError('min_depth should be less than or equal to max_depth') raise ValueError('min_depth should be less than or equal to max_depth')
self._conv_hyperparams = conv_hyperparams self._conv_hyperparams_fn = conv_hyperparams_fn
self._min_depth = min_depth self._min_depth = min_depth
self._max_depth = max_depth self._max_depth = max_depth
self._num_layers_before_predictor = num_layers_before_predictor self._num_layers_before_predictor = num_layers_before_predictor
...@@ -679,7 +680,7 @@ class ConvolutionalBoxPredictor(BoxPredictor): ...@@ -679,7 +680,7 @@ class ConvolutionalBoxPredictor(BoxPredictor):
# Add a slot for the background class. # Add a slot for the background class.
num_class_slots = self.num_classes + 1 num_class_slots = self.num_classes + 1
net = image_feature net = image_feature
with slim.arg_scope(self._conv_hyperparams), \ with slim.arg_scope(self._conv_hyperparams_fn()), \
slim.arg_scope([slim.dropout], is_training=self._is_training): slim.arg_scope([slim.dropout], is_training=self._is_training):
# Add additional conv layers before the class predictor. # Add additional conv layers before the class predictor.
features_depth = static_shape.get_depth(image_feature.get_shape()) features_depth = static_shape.get_depth(image_feature.get_shape())
...@@ -767,7 +768,7 @@ class WeightSharedConvolutionalBoxPredictor(BoxPredictor): ...@@ -767,7 +768,7 @@ class WeightSharedConvolutionalBoxPredictor(BoxPredictor):
def __init__(self, def __init__(self,
is_training, is_training,
num_classes, num_classes,
conv_hyperparams, conv_hyperparams_fn,
depth, depth,
num_layers_before_predictor, num_layers_before_predictor,
box_code_size, box_code_size,
...@@ -781,7 +782,8 @@ class WeightSharedConvolutionalBoxPredictor(BoxPredictor): ...@@ -781,7 +782,8 @@ class WeightSharedConvolutionalBoxPredictor(BoxPredictor):
include the background category, so if groundtruth labels take values include the background category, so if groundtruth labels take values
in {0, 1, .., K-1}, num_classes=K (and not K+1, even though the in {0, 1, .., K-1}, num_classes=K (and not K+1, even though the
assigned classification targets can range from {0,... K}). assigned classification targets can range from {0,... K}).
conv_hyperparams: Slim arg_scope with hyperparameters for convolution ops. conv_hyperparams_fn: A function to generate tf-slim arg_scope with
hyperparameters for convolution ops.
depth: depth of conv layers. depth: depth of conv layers.
num_layers_before_predictor: Number of the additional conv layers before num_layers_before_predictor: Number of the additional conv layers before
the predictor. the predictor.
...@@ -792,7 +794,7 @@ class WeightSharedConvolutionalBoxPredictor(BoxPredictor): ...@@ -792,7 +794,7 @@ class WeightSharedConvolutionalBoxPredictor(BoxPredictor):
""" """
super(WeightSharedConvolutionalBoxPredictor, self).__init__(is_training, super(WeightSharedConvolutionalBoxPredictor, self).__init__(is_training,
num_classes) num_classes)
self._conv_hyperparams = conv_hyperparams self._conv_hyperparams_fn = conv_hyperparams_fn
self._depth = depth self._depth = depth
self._num_layers_before_predictor = num_layers_before_predictor self._num_layers_before_predictor = num_layers_before_predictor
self._box_code_size = box_code_size self._box_code_size = box_code_size
...@@ -846,7 +848,7 @@ class WeightSharedConvolutionalBoxPredictor(BoxPredictor): ...@@ -846,7 +848,7 @@ class WeightSharedConvolutionalBoxPredictor(BoxPredictor):
num_class_slots = self.num_classes + 1 num_class_slots = self.num_classes + 1
box_encodings_net = image_feature box_encodings_net = image_feature
class_predictions_net = image_feature class_predictions_net = image_feature
with slim.arg_scope(self._conv_hyperparams): with slim.arg_scope(self._conv_hyperparams_fn()):
for i in range(self._num_layers_before_predictor): for i in range(self._num_layers_before_predictor):
box_encodings_net = slim.conv2d( box_encodings_net = slim.conv2d(
box_encodings_net, box_encodings_net,
......
...@@ -49,7 +49,7 @@ class MaskRCNNBoxPredictorTest(tf.test.TestCase): ...@@ -49,7 +49,7 @@ class MaskRCNNBoxPredictorTest(tf.test.TestCase):
mask_box_predictor = box_predictor.MaskRCNNBoxPredictor( mask_box_predictor = box_predictor.MaskRCNNBoxPredictor(
is_training=False, is_training=False,
num_classes=5, num_classes=5,
fc_hyperparams=self._build_arg_scope_with_hyperparams(), fc_hyperparams_fn=self._build_arg_scope_with_hyperparams(),
use_dropout=False, use_dropout=False,
dropout_keep_prob=0.5, dropout_keep_prob=0.5,
box_code_size=4, box_code_size=4,
...@@ -75,7 +75,7 @@ class MaskRCNNBoxPredictorTest(tf.test.TestCase): ...@@ -75,7 +75,7 @@ class MaskRCNNBoxPredictorTest(tf.test.TestCase):
box_predictor.MaskRCNNBoxPredictor( box_predictor.MaskRCNNBoxPredictor(
is_training=False, is_training=False,
num_classes=5, num_classes=5,
fc_hyperparams=self._build_arg_scope_with_hyperparams(), fc_hyperparams_fn=self._build_arg_scope_with_hyperparams(),
use_dropout=False, use_dropout=False,
dropout_keep_prob=0.5, dropout_keep_prob=0.5,
box_code_size=4, box_code_size=4,
...@@ -86,11 +86,11 @@ class MaskRCNNBoxPredictorTest(tf.test.TestCase): ...@@ -86,11 +86,11 @@ class MaskRCNNBoxPredictorTest(tf.test.TestCase):
mask_box_predictor = box_predictor.MaskRCNNBoxPredictor( mask_box_predictor = box_predictor.MaskRCNNBoxPredictor(
is_training=False, is_training=False,
num_classes=5, num_classes=5,
fc_hyperparams=self._build_arg_scope_with_hyperparams(), fc_hyperparams_fn=self._build_arg_scope_with_hyperparams(),
use_dropout=False, use_dropout=False,
dropout_keep_prob=0.5, dropout_keep_prob=0.5,
box_code_size=4, box_code_size=4,
conv_hyperparams=self._build_arg_scope_with_hyperparams( conv_hyperparams_fn=self._build_arg_scope_with_hyperparams(
op_type=hyperparams_pb2.Hyperparams.CONV), op_type=hyperparams_pb2.Hyperparams.CONV),
predict_instance_masks=True) predict_instance_masks=True)
box_predictions = mask_box_predictor.predict( box_predictions = mask_box_predictor.predict(
...@@ -108,7 +108,7 @@ class MaskRCNNBoxPredictorTest(tf.test.TestCase): ...@@ -108,7 +108,7 @@ class MaskRCNNBoxPredictorTest(tf.test.TestCase):
mask_box_predictor = box_predictor.MaskRCNNBoxPredictor( mask_box_predictor = box_predictor.MaskRCNNBoxPredictor(
is_training=False, is_training=False,
num_classes=5, num_classes=5,
fc_hyperparams=self._build_arg_scope_with_hyperparams(), fc_hyperparams_fn=self._build_arg_scope_with_hyperparams(),
use_dropout=False, use_dropout=False,
dropout_keep_prob=0.5, dropout_keep_prob=0.5,
box_code_size=4) box_code_size=4)
...@@ -125,7 +125,7 @@ class MaskRCNNBoxPredictorTest(tf.test.TestCase): ...@@ -125,7 +125,7 @@ class MaskRCNNBoxPredictorTest(tf.test.TestCase):
box_predictor.MaskRCNNBoxPredictor( box_predictor.MaskRCNNBoxPredictor(
is_training=False, is_training=False,
num_classes=5, num_classes=5,
fc_hyperparams=self._build_arg_scope_with_hyperparams(), fc_hyperparams_fn=self._build_arg_scope_with_hyperparams(),
use_dropout=False, use_dropout=False,
dropout_keep_prob=0.5, dropout_keep_prob=0.5,
box_code_size=4, box_code_size=4,
...@@ -155,7 +155,7 @@ class RfcnBoxPredictorTest(tf.test.TestCase): ...@@ -155,7 +155,7 @@ class RfcnBoxPredictorTest(tf.test.TestCase):
rfcn_box_predictor = box_predictor.RfcnBoxPredictor( rfcn_box_predictor = box_predictor.RfcnBoxPredictor(
is_training=False, is_training=False,
num_classes=2, num_classes=2,
conv_hyperparams=self._build_arg_scope_with_conv_hyperparams(), conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
num_spatial_bins=[3, 3], num_spatial_bins=[3, 3],
depth=4, depth=4,
crop_size=[12, 12], crop_size=[12, 12],
...@@ -205,7 +205,7 @@ class ConvolutionalBoxPredictorTest(test_case.TestCase): ...@@ -205,7 +205,7 @@ class ConvolutionalBoxPredictorTest(test_case.TestCase):
conv_box_predictor = box_predictor.ConvolutionalBoxPredictor( conv_box_predictor = box_predictor.ConvolutionalBoxPredictor(
is_training=False, is_training=False,
num_classes=0, num_classes=0,
conv_hyperparams=self._build_arg_scope_with_conv_hyperparams(), conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
min_depth=0, min_depth=0,
max_depth=32, max_depth=32,
num_layers_before_predictor=1, num_layers_before_predictor=1,
...@@ -234,7 +234,7 @@ class ConvolutionalBoxPredictorTest(test_case.TestCase): ...@@ -234,7 +234,7 @@ class ConvolutionalBoxPredictorTest(test_case.TestCase):
conv_box_predictor = box_predictor.ConvolutionalBoxPredictor( conv_box_predictor = box_predictor.ConvolutionalBoxPredictor(
is_training=False, is_training=False,
num_classes=0, num_classes=0,
conv_hyperparams=self._build_arg_scope_with_conv_hyperparams(), conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
min_depth=0, min_depth=0,
max_depth=32, max_depth=32,
num_layers_before_predictor=1, num_layers_before_predictor=1,
...@@ -265,7 +265,7 @@ class ConvolutionalBoxPredictorTest(test_case.TestCase): ...@@ -265,7 +265,7 @@ class ConvolutionalBoxPredictorTest(test_case.TestCase):
conv_box_predictor = box_predictor.ConvolutionalBoxPredictor( conv_box_predictor = box_predictor.ConvolutionalBoxPredictor(
is_training=False, is_training=False,
num_classes=num_classes_without_background, num_classes=num_classes_without_background,
conv_hyperparams=self._build_arg_scope_with_conv_hyperparams(), conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
min_depth=0, min_depth=0,
max_depth=32, max_depth=32,
num_layers_before_predictor=1, num_layers_before_predictor=1,
...@@ -297,7 +297,7 @@ class ConvolutionalBoxPredictorTest(test_case.TestCase): ...@@ -297,7 +297,7 @@ class ConvolutionalBoxPredictorTest(test_case.TestCase):
conv_box_predictor = box_predictor.ConvolutionalBoxPredictor( conv_box_predictor = box_predictor.ConvolutionalBoxPredictor(
is_training=False, is_training=False,
num_classes=0, num_classes=0,
conv_hyperparams=self._build_arg_scope_with_conv_hyperparams(), conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
min_depth=0, min_depth=0,
max_depth=32, max_depth=32,
num_layers_before_predictor=1, num_layers_before_predictor=1,
...@@ -344,7 +344,7 @@ class ConvolutionalBoxPredictorTest(test_case.TestCase): ...@@ -344,7 +344,7 @@ class ConvolutionalBoxPredictorTest(test_case.TestCase):
conv_box_predictor = box_predictor.ConvolutionalBoxPredictor( conv_box_predictor = box_predictor.ConvolutionalBoxPredictor(
is_training=False, is_training=False,
num_classes=0, num_classes=0,
conv_hyperparams=self._build_arg_scope_with_conv_hyperparams(), conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
min_depth=0, min_depth=0,
max_depth=32, max_depth=32,
num_layers_before_predictor=1, num_layers_before_predictor=1,
...@@ -416,7 +416,7 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase): ...@@ -416,7 +416,7 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase):
conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor( conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor(
is_training=False, is_training=False,
num_classes=0, num_classes=0,
conv_hyperparams=self._build_arg_scope_with_conv_hyperparams(), conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
depth=32, depth=32,
num_layers_before_predictor=1, num_layers_before_predictor=1,
box_code_size=4) box_code_size=4)
...@@ -442,7 +442,7 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase): ...@@ -442,7 +442,7 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase):
conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor( conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor(
is_training=False, is_training=False,
num_classes=num_classes_without_background, num_classes=num_classes_without_background,
conv_hyperparams=self._build_arg_scope_with_conv_hyperparams(), conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
depth=32, depth=32,
num_layers_before_predictor=1, num_layers_before_predictor=1,
box_code_size=4) box_code_size=4)
...@@ -471,7 +471,7 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase): ...@@ -471,7 +471,7 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase):
conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor( conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor(
is_training=False, is_training=False,
num_classes=num_classes_without_background, num_classes=num_classes_without_background,
conv_hyperparams=self._build_arg_scope_with_conv_hyperparams(), conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
depth=32, depth=32,
num_layers_before_predictor=1, num_layers_before_predictor=1,
box_code_size=4) box_code_size=4)
...@@ -500,7 +500,7 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase): ...@@ -500,7 +500,7 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase):
conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor( conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor(
is_training=False, is_training=False,
num_classes=num_classes_without_background, num_classes=num_classes_without_background,
conv_hyperparams=self._build_arg_scope_with_conv_hyperparams(), conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
depth=32, depth=32,
num_layers_before_predictor=2, num_layers_before_predictor=2,
box_code_size=4) box_code_size=4)
...@@ -553,7 +553,7 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase): ...@@ -553,7 +553,7 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase):
conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor( conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor(
is_training=False, is_training=False,
num_classes=0, num_classes=0,
conv_hyperparams=self._build_arg_scope_with_conv_hyperparams(), conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
depth=32, depth=32,
num_layers_before_predictor=1, num_layers_before_predictor=1,
box_code_size=4) box_code_size=4)
......
...@@ -272,6 +272,7 @@ def normalize_image(image, original_minval, original_maxval, target_minval, ...@@ -272,6 +272,7 @@ def normalize_image(image, original_minval, original_maxval, target_minval,
def retain_boxes_above_threshold(boxes, def retain_boxes_above_threshold(boxes,
labels, labels,
label_scores, label_scores,
multiclass_scores=None,
masks=None, masks=None,
keypoints=None, keypoints=None,
threshold=0.0): threshold=0.0):
...@@ -288,6 +289,9 @@ def retain_boxes_above_threshold(boxes, ...@@ -288,6 +289,9 @@ def retain_boxes_above_threshold(boxes,
classes. classes.
label_scores: float32 tensor of shape [num_instance] representing the label_scores: float32 tensor of shape [num_instance] representing the
score for each box. score for each box.
multiclass_scores: (optional) float32 tensor of shape
[num_instances, num_classes] representing the score for each box for each
class.
masks: (optional) rank 3 float32 tensor with shape masks: (optional) rank 3 float32 tensor with shape
[num_instances, height, width] containing instance masks. The masks are of [num_instances, height, width] containing instance masks. The masks are of
the same height, width as the input `image`. the same height, width as the input `image`.
...@@ -301,8 +305,10 @@ def retain_boxes_above_threshold(boxes, ...@@ -301,8 +305,10 @@ def retain_boxes_above_threshold(boxes,
retianed_labels: [num_retained_instance] retianed_labels: [num_retained_instance]
retained_label_scores: [num_retained_instance] retained_label_scores: [num_retained_instance]
If masks, or keypoints are not None, the function also returns: If multiclass_scores, masks, or keypoints are not None, the function also
returns:
retained_multiclass_scores: [num_retained_instance, num_classes]
retained_masks: [num_retained_instance, height, width] retained_masks: [num_retained_instance, height, width]
retained_keypoints: [num_retained_instance, num_keypoints, 2] retained_keypoints: [num_retained_instance, num_keypoints, 2]
""" """
...@@ -316,6 +322,10 @@ def retain_boxes_above_threshold(boxes, ...@@ -316,6 +322,10 @@ def retain_boxes_above_threshold(boxes,
retained_label_scores = tf.gather(label_scores, indices) retained_label_scores = tf.gather(label_scores, indices)
result = [retained_boxes, retained_labels, retained_label_scores] result = [retained_boxes, retained_labels, retained_label_scores]
if multiclass_scores is not None:
retained_multiclass_scores = tf.gather(multiclass_scores, indices)
result.append(retained_multiclass_scores)
if masks is not None: if masks is not None:
retained_masks = tf.gather(masks, indices) retained_masks = tf.gather(masks, indices)
result.append(retained_masks) result.append(retained_masks)
...@@ -1097,6 +1107,7 @@ def _strict_random_crop_image(image, ...@@ -1097,6 +1107,7 @@ def _strict_random_crop_image(image,
boxes, boxes,
labels, labels,
label_scores=None, label_scores=None,
multiclass_scores=None,
masks=None, masks=None,
keypoints=None, keypoints=None,
min_object_covered=1.0, min_object_covered=1.0,
...@@ -1123,6 +1134,9 @@ def _strict_random_crop_image(image, ...@@ -1123,6 +1134,9 @@ def _strict_random_crop_image(image,
labels: rank 1 int32 tensor containing the object classes. labels: rank 1 int32 tensor containing the object classes.
label_scores: (optional) float32 tensor of shape [num_instances] label_scores: (optional) float32 tensor of shape [num_instances]
representing the score for each box. representing the score for each box.
multiclass_scores: (optional) float32 tensor of shape
[num_instances, num_classes] representing the score for each box for each
class.
masks: (optional) rank 3 float32 tensor with shape masks: (optional) rank 3 float32 tensor with shape
[num_instances, height, width] containing instance masks. The masks [num_instances, height, width] containing instance masks. The masks
are of the same height, width as the input `image`. are of the same height, width as the input `image`.
...@@ -1147,8 +1161,11 @@ def _strict_random_crop_image(image, ...@@ -1147,8 +1161,11 @@ def _strict_random_crop_image(image,
Boxes are in normalized form. Boxes are in normalized form.
labels: new labels. labels: new labels.
If label_scores, masks, or keypoints is not None, the function also returns: If label_scores, multiclass_scores, masks, or keypoints is not None, the
function also returns:
label_scores: rank 1 float32 tensor with shape [num_instances]. label_scores: rank 1 float32 tensor with shape [num_instances].
multiclass_scores: rank 2 float32 tensor with shape
[num_instances, num_classes]
masks: rank 3 float32 tensor with shape [num_instances, height, width] masks: rank 3 float32 tensor with shape [num_instances, height, width]
containing instance masks. containing instance masks.
keypoints: rank 3 float32 tensor with shape keypoints: rank 3 float32 tensor with shape
...@@ -1195,6 +1212,9 @@ def _strict_random_crop_image(image, ...@@ -1195,6 +1212,9 @@ def _strict_random_crop_image(image,
if label_scores is not None: if label_scores is not None:
boxlist.add_field('label_scores', label_scores) boxlist.add_field('label_scores', label_scores)
if multiclass_scores is not None:
boxlist.add_field('multiclass_scores', multiclass_scores)
im_boxlist = box_list.BoxList(im_box_rank2) im_boxlist = box_list.BoxList(im_box_rank2)
# remove boxes that are outside cropped image # remove boxes that are outside cropped image
...@@ -1219,6 +1239,10 @@ def _strict_random_crop_image(image, ...@@ -1219,6 +1239,10 @@ def _strict_random_crop_image(image,
new_label_scores = overlapping_boxlist.get_field('label_scores') new_label_scores = overlapping_boxlist.get_field('label_scores')
result.append(new_label_scores) result.append(new_label_scores)
if multiclass_scores is not None:
new_multiclass_scores = overlapping_boxlist.get_field('multiclass_scores')
result.append(new_multiclass_scores)
if masks is not None: if masks is not None:
masks_of_boxes_inside_window = tf.gather(masks, inside_window_ids) masks_of_boxes_inside_window = tf.gather(masks, inside_window_ids)
masks_of_boxes_completely_inside_window = tf.gather( masks_of_boxes_completely_inside_window = tf.gather(
...@@ -1247,6 +1271,7 @@ def random_crop_image(image, ...@@ -1247,6 +1271,7 @@ def random_crop_image(image,
boxes, boxes,
labels, labels,
label_scores=None, label_scores=None,
multiclass_scores=None,
masks=None, masks=None,
keypoints=None, keypoints=None,
min_object_covered=1.0, min_object_covered=1.0,
...@@ -1282,6 +1307,9 @@ def random_crop_image(image, ...@@ -1282,6 +1307,9 @@ def random_crop_image(image,
labels: rank 1 int32 tensor containing the object classes. labels: rank 1 int32 tensor containing the object classes.
label_scores: (optional) float32 tensor of shape [num_instances]. label_scores: (optional) float32 tensor of shape [num_instances].
representing the score for each box. representing the score for each box.
multiclass_scores: (optional) float32 tensor of shape
[num_instances, num_classes] representing the score for each box for each
class.
masks: (optional) rank 3 float32 tensor with shape masks: (optional) rank 3 float32 tensor with shape
[num_instances, height, width] containing instance masks. The masks [num_instances, height, width] containing instance masks. The masks
are of the same height, width as the input `image`. are of the same height, width as the input `image`.
...@@ -1311,9 +1339,11 @@ def random_crop_image(image, ...@@ -1311,9 +1339,11 @@ def random_crop_image(image,
form. form.
labels: new labels. labels: new labels.
If label_scores, masks, or keypoints are not None, the function also If label_scores, multiclass_scores, masks, or keypoints is not None, the
returns: function also returns:
label_scores: new scores. label_scores: rank 1 float32 tensor with shape [num_instances].
multiclass_scores: rank 2 float32 tensor with shape
[num_instances, num_classes]
masks: rank 3 float32 tensor with shape [num_instances, height, width] masks: rank 3 float32 tensor with shape [num_instances, height, width]
containing instance masks. containing instance masks.
keypoints: rank 3 float32 tensor with shape keypoints: rank 3 float32 tensor with shape
...@@ -1326,6 +1356,7 @@ def random_crop_image(image, ...@@ -1326,6 +1356,7 @@ def random_crop_image(image,
boxes, boxes,
labels, labels,
label_scores=label_scores, label_scores=label_scores,
multiclass_scores=multiclass_scores,
masks=masks, masks=masks,
keypoints=keypoints, keypoints=keypoints,
min_object_covered=min_object_covered, min_object_covered=min_object_covered,
...@@ -1348,6 +1379,8 @@ def random_crop_image(image, ...@@ -1348,6 +1379,8 @@ def random_crop_image(image,
if label_scores is not None: if label_scores is not None:
outputs.append(label_scores) outputs.append(label_scores)
if multiclass_scores is not None:
outputs.append(multiclass_scores)
if masks is not None: if masks is not None:
outputs.append(masks) outputs.append(masks)
if keypoints is not None: if keypoints is not None:
...@@ -1481,6 +1514,7 @@ def random_crop_pad_image(image, ...@@ -1481,6 +1514,7 @@ def random_crop_pad_image(image,
boxes, boxes,
labels, labels,
label_scores=None, label_scores=None,
multiclass_scores=None,
min_object_covered=1.0, min_object_covered=1.0,
aspect_ratio_range=(0.75, 1.33), aspect_ratio_range=(0.75, 1.33),
area_range=(0.1, 1.0), area_range=(0.1, 1.0),
...@@ -1512,6 +1546,9 @@ def random_crop_pad_image(image, ...@@ -1512,6 +1546,9 @@ def random_crop_pad_image(image,
Each row is in the form of [ymin, xmin, ymax, xmax]. Each row is in the form of [ymin, xmin, ymax, xmax].
labels: rank 1 int32 tensor containing the object classes. labels: rank 1 int32 tensor containing the object classes.
label_scores: rank 1 float32 containing the label scores. label_scores: rank 1 float32 containing the label scores.
multiclass_scores: (optional) float32 tensor of shape
[num_instances, num_classes] representing the score for each box for each
class.
min_object_covered: the cropped image must cover at least this fraction of min_object_covered: the cropped image must cover at least this fraction of
at least one of the input bounding boxes. at least one of the input bounding boxes.
aspect_ratio_range: allowed range for aspect ratio of cropped image. aspect_ratio_range: allowed range for aspect ratio of cropped image.
...@@ -1543,6 +1580,9 @@ def random_crop_pad_image(image, ...@@ -1543,6 +1580,9 @@ def random_crop_pad_image(image,
cropped_labels: cropped labels. cropped_labels: cropped labels.
if label_scores is not None also returns: if label_scores is not None also returns:
cropped_label_scores: cropped label scores. cropped_label_scores: cropped label scores.
if multiclass_scores is not None also returns:
cropped_multiclass_scores: cropped_multiclass_scores.
""" """
image_size = tf.shape(image) image_size = tf.shape(image)
image_height = image_size[0] image_height = image_size[0]
...@@ -1552,6 +1592,7 @@ def random_crop_pad_image(image, ...@@ -1552,6 +1592,7 @@ def random_crop_pad_image(image,
boxes=boxes, boxes=boxes,
labels=labels, labels=labels,
label_scores=label_scores, label_scores=label_scores,
multiclass_scores=multiclass_scores,
min_object_covered=min_object_covered, min_object_covered=min_object_covered,
aspect_ratio_range=aspect_ratio_range, aspect_ratio_range=aspect_ratio_range,
area_range=area_range, area_range=area_range,
...@@ -1580,9 +1621,15 @@ def random_crop_pad_image(image, ...@@ -1580,9 +1621,15 @@ def random_crop_pad_image(image,
cropped_padded_output = (padded_image, padded_boxes, cropped_labels) cropped_padded_output = (padded_image, padded_boxes, cropped_labels)
index = 3
if label_scores is not None: if label_scores is not None:
cropped_label_scores = result[3] cropped_label_scores = result[index]
cropped_padded_output += (cropped_label_scores,) cropped_padded_output += (cropped_label_scores,)
index += 1
if multiclass_scores is not None:
cropped_multiclass_scores = result[index]
cropped_padded_output += (cropped_multiclass_scores,)
return cropped_padded_output return cropped_padded_output
...@@ -1591,6 +1638,7 @@ def random_crop_to_aspect_ratio(image, ...@@ -1591,6 +1638,7 @@ def random_crop_to_aspect_ratio(image,
boxes, boxes,
labels, labels,
label_scores=None, label_scores=None,
multiclass_scores=None,
masks=None, masks=None,
keypoints=None, keypoints=None,
aspect_ratio=1.0, aspect_ratio=1.0,
...@@ -1618,6 +1666,9 @@ def random_crop_to_aspect_ratio(image, ...@@ -1618,6 +1666,9 @@ def random_crop_to_aspect_ratio(image,
labels: rank 1 int32 tensor containing the object classes. labels: rank 1 int32 tensor containing the object classes.
label_scores: (optional) float32 tensor of shape [num_instances] label_scores: (optional) float32 tensor of shape [num_instances]
representing the score for each box. representing the score for each box.
multiclass_scores: (optional) float32 tensor of shape
[num_instances, num_classes] representing the score for each box for each
class.
masks: (optional) rank 3 float32 tensor with shape masks: (optional) rank 3 float32 tensor with shape
[num_instances, height, width] containing instance masks. The masks [num_instances, height, width] containing instance masks. The masks
are of the same height, width as the input `image`. are of the same height, width as the input `image`.
...@@ -1639,12 +1690,15 @@ def random_crop_to_aspect_ratio(image, ...@@ -1639,12 +1690,15 @@ def random_crop_to_aspect_ratio(image,
Boxes are in normalized form. Boxes are in normalized form.
labels: new labels. labels: new labels.
If label_scores, masks, or keypoints is not None, the function also returns: If label_scores, masks, keypoints, or multiclass_scores is not None, the
label_scores: new label scores. function also returns:
label_scores: rank 1 float32 tensor with shape [num_instances].
masks: rank 3 float32 tensor with shape [num_instances, height, width] masks: rank 3 float32 tensor with shape [num_instances, height, width]
containing instance masks. containing instance masks.
keypoints: rank 3 float32 tensor with shape keypoints: rank 3 float32 tensor with shape
[num_instances, num_keypoints, 2] [num_instances, num_keypoints, 2]
multiclass_scores: rank 2 float32 tensor with shape
[num_instances, num_classes]
Raises: Raises:
ValueError: If image is not a 3D tensor. ValueError: If image is not a 3D tensor.
...@@ -1698,6 +1752,9 @@ def random_crop_to_aspect_ratio(image, ...@@ -1698,6 +1752,9 @@ def random_crop_to_aspect_ratio(image,
if label_scores is not None: if label_scores is not None:
boxlist.add_field('label_scores', label_scores) boxlist.add_field('label_scores', label_scores)
if multiclass_scores is not None:
boxlist.add_field('multiclass_scores', multiclass_scores)
im_boxlist = box_list.BoxList(tf.expand_dims(im_box, 0)) im_boxlist = box_list.BoxList(tf.expand_dims(im_box, 0))
# remove boxes whose overlap with the image is less than overlap_thresh # remove boxes whose overlap with the image is less than overlap_thresh
...@@ -1719,6 +1776,10 @@ def random_crop_to_aspect_ratio(image, ...@@ -1719,6 +1776,10 @@ def random_crop_to_aspect_ratio(image,
new_label_scores = overlapping_boxlist.get_field('label_scores') new_label_scores = overlapping_boxlist.get_field('label_scores')
result.append(new_label_scores) result.append(new_label_scores)
if multiclass_scores is not None:
new_multiclass_scores = overlapping_boxlist.get_field('multiclass_scores')
result.append(new_multiclass_scores)
if masks is not None: if masks is not None:
masks_inside_window = tf.gather(masks, keep_ids) masks_inside_window = tf.gather(masks, keep_ids)
masks_box_begin = tf.stack([0, offset_height, offset_width]) masks_box_begin = tf.stack([0, offset_height, offset_width])
...@@ -1784,8 +1845,7 @@ def random_pad_to_aspect_ratio(image, ...@@ -1784,8 +1845,7 @@ def random_pad_to_aspect_ratio(image,
Boxes are in normalized form. Boxes are in normalized form.
labels: new labels. labels: new labels.
If label_scores, masks, or keypoints is not None, the function also returns: If masks, or keypoints is not None, the function also returns:
label_scores: new label scores.
masks: rank 3 float32 tensor with shape [num_instances, height, width] masks: rank 3 float32 tensor with shape [num_instances, height, width]
containing instance masks. containing instance masks.
keypoints: rank 3 float32 tensor with shape keypoints: rank 3 float32 tensor with shape
...@@ -2356,6 +2416,7 @@ def ssd_random_crop(image, ...@@ -2356,6 +2416,7 @@ def ssd_random_crop(image,
boxes, boxes,
labels, labels,
label_scores=None, label_scores=None,
multiclass_scores=None,
masks=None, masks=None,
keypoints=None, keypoints=None,
min_object_covered=(0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0), min_object_covered=(0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0),
...@@ -2380,6 +2441,9 @@ def ssd_random_crop(image, ...@@ -2380,6 +2441,9 @@ def ssd_random_crop(image,
Each row is in the form of [ymin, xmin, ymax, xmax]. Each row is in the form of [ymin, xmin, ymax, xmax].
labels: rank 1 int32 tensor containing the object classes. labels: rank 1 int32 tensor containing the object classes.
label_scores: rank 1 float32 tensor containing the scores. label_scores: rank 1 float32 tensor containing the scores.
multiclass_scores: (optional) float32 tensor of shape
[num_instances, num_classes] representing the score for each box for each
class.
masks: (optional) rank 3 float32 tensor with shape masks: (optional) rank 3 float32 tensor with shape
[num_instances, height, width] containing instance masks. The masks [num_instances, height, width] containing instance masks. The masks
are of the same height, width as the input `image`. are of the same height, width as the input `image`.
...@@ -2409,8 +2473,11 @@ def ssd_random_crop(image, ...@@ -2409,8 +2473,11 @@ def ssd_random_crop(image,
Boxes are in normalized form. Boxes are in normalized form.
labels: new labels. labels: new labels.
If label_scores, masks, or keypoints is not None, the function also returns: If label_scores, multiclass_scores, masks, or keypoints is not None, the
label_scores: new label scores. function also returns:
label_scores: rank 1 float32 tensor with shape [num_instances].
multiclass_scores: rank 2 float32 tensor with shape
[num_instances, num_classes]
masks: rank 3 float32 tensor with shape [num_instances, height, width] masks: rank 3 float32 tensor with shape [num_instances, height, width]
containing instance masks. containing instance masks.
keypoints: rank 3 float32 tensor with shape keypoints: rank 3 float32 tensor with shape
...@@ -2428,14 +2495,19 @@ def ssd_random_crop(image, ...@@ -2428,14 +2495,19 @@ def ssd_random_crop(image,
Returns: A tuple containing image, boxes, labels, keypoints (if not None), Returns: A tuple containing image, boxes, labels, keypoints (if not None),
and masks (if not None). and masks (if not None).
""" """
i = 3 i = 3
image, boxes, labels = selected_result[:i] image, boxes, labels = selected_result[:i]
selected_label_scores = None selected_label_scores = None
selected_multiclass_scores = None
selected_masks = None selected_masks = None
selected_keypoints = None selected_keypoints = None
if label_scores is not None: if label_scores is not None:
selected_label_scores = selected_result[i] selected_label_scores = selected_result[i]
i += 1 i += 1
if multiclass_scores is not None:
selected_multiclass_scores = selected_result[i]
i += 1
if masks is not None: if masks is not None:
selected_masks = selected_result[i] selected_masks = selected_result[i]
i += 1 i += 1
...@@ -2447,6 +2519,7 @@ def ssd_random_crop(image, ...@@ -2447,6 +2519,7 @@ def ssd_random_crop(image,
boxes=boxes, boxes=boxes,
labels=labels, labels=labels,
label_scores=selected_label_scores, label_scores=selected_label_scores,
multiclass_scores=selected_multiclass_scores,
masks=selected_masks, masks=selected_masks,
keypoints=selected_keypoints, keypoints=selected_keypoints,
min_object_covered=min_object_covered[index], min_object_covered=min_object_covered[index],
...@@ -2459,8 +2532,8 @@ def ssd_random_crop(image, ...@@ -2459,8 +2532,8 @@ def ssd_random_crop(image,
result = _apply_with_random_selector_tuples( result = _apply_with_random_selector_tuples(
tuple( tuple(
t for t in (image, boxes, labels, label_scores, masks, keypoints) t for t in (image, boxes, labels, label_scores, multiclass_scores,
if t is not None), masks, keypoints) if t is not None),
random_crop_selector, random_crop_selector,
num_cases=len(min_object_covered), num_cases=len(min_object_covered),
preprocess_vars_cache=preprocess_vars_cache, preprocess_vars_cache=preprocess_vars_cache,
...@@ -2472,6 +2545,7 @@ def ssd_random_crop_pad(image, ...@@ -2472,6 +2545,7 @@ def ssd_random_crop_pad(image,
boxes, boxes,
labels, labels,
label_scores=None, label_scores=None,
multiclass_scores=None,
min_object_covered=(0.1, 0.3, 0.5, 0.7, 0.9, 1.0), min_object_covered=(0.1, 0.3, 0.5, 0.7, 0.9, 1.0),
aspect_ratio_range=((0.5, 2.0),) * 6, aspect_ratio_range=((0.5, 2.0),) * 6,
area_range=((0.1, 1.0),) * 6, area_range=((0.1, 1.0),) * 6,
...@@ -2498,6 +2572,9 @@ def ssd_random_crop_pad(image, ...@@ -2498,6 +2572,9 @@ def ssd_random_crop_pad(image,
labels: rank 1 int32 tensor containing the object classes. labels: rank 1 int32 tensor containing the object classes.
label_scores: float32 tensor of shape [num_instances] representing the label_scores: float32 tensor of shape [num_instances] representing the
score for each box. score for each box.
multiclass_scores: (optional) float32 tensor of shape
[num_instances, num_classes] representing the score for each box for each
class.
min_object_covered: the cropped image must cover at least this fraction of min_object_covered: the cropped image must cover at least this fraction of
at least one of the input bounding boxes. at least one of the input bounding boxes.
aspect_ratio_range: allowed range for aspect ratio of cropped image. aspect_ratio_range: allowed range for aspect ratio of cropped image.
...@@ -2531,17 +2608,23 @@ def ssd_random_crop_pad(image, ...@@ -2531,17 +2608,23 @@ def ssd_random_crop_pad(image,
""" """
def random_crop_pad_selector(image_boxes_labels, index): def random_crop_pad_selector(image_boxes_labels, index):
"""Random crop preprocessing helper."""
i = 3 i = 3
image, boxes, labels = image_boxes_labels[:i] image, boxes, labels = image_boxes_labels[:i]
selected_label_scores = None selected_label_scores = None
selected_multiclass_scores = None
if label_scores is not None: if label_scores is not None:
selected_label_scores = image_boxes_labels[i] selected_label_scores = image_boxes_labels[i]
i += 1
if multiclass_scores is not None:
selected_multiclass_scores = image_boxes_labels[i]
return random_crop_pad_image( return random_crop_pad_image(
image, image,
boxes, boxes,
labels, labels,
selected_label_scores, label_scores=selected_label_scores,
multiclass_scores=selected_multiclass_scores,
min_object_covered=min_object_covered[index], min_object_covered=min_object_covered[index],
aspect_ratio_range=aspect_ratio_range[index], aspect_ratio_range=aspect_ratio_range[index],
area_range=area_range[index], area_range=area_range[index],
...@@ -2554,7 +2637,8 @@ def ssd_random_crop_pad(image, ...@@ -2554,7 +2637,8 @@ def ssd_random_crop_pad(image,
preprocess_vars_cache=preprocess_vars_cache) preprocess_vars_cache=preprocess_vars_cache)
return _apply_with_random_selector_tuples( return _apply_with_random_selector_tuples(
tuple(t for t in (image, boxes, labels, label_scores) if t is not None), tuple(t for t in (image, boxes, labels, label_scores, multiclass_scores)
if t is not None),
random_crop_pad_selector, random_crop_pad_selector,
num_cases=len(min_object_covered), num_cases=len(min_object_covered),
preprocess_vars_cache=preprocess_vars_cache, preprocess_vars_cache=preprocess_vars_cache,
...@@ -2566,6 +2650,7 @@ def ssd_random_crop_fixed_aspect_ratio( ...@@ -2566,6 +2650,7 @@ def ssd_random_crop_fixed_aspect_ratio(
boxes, boxes,
labels, labels,
label_scores=None, label_scores=None,
multiclass_scores=None,
masks=None, masks=None,
keypoints=None, keypoints=None,
min_object_covered=(0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0), min_object_covered=(0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0),
...@@ -2593,6 +2678,9 @@ def ssd_random_crop_fixed_aspect_ratio( ...@@ -2593,6 +2678,9 @@ def ssd_random_crop_fixed_aspect_ratio(
labels: rank 1 int32 tensor containing the object classes. labels: rank 1 int32 tensor containing the object classes.
label_scores: (optional) float32 tensor of shape [num_instances] label_scores: (optional) float32 tensor of shape [num_instances]
representing the score for each box. representing the score for each box.
multiclass_scores: (optional) float32 tensor of shape
[num_instances, num_classes] representing the score for each box for each
class.
masks: (optional) rank 3 float32 tensor with shape masks: (optional) rank 3 float32 tensor with shape
[num_instances, height, width] containing instance masks. The masks [num_instances, height, width] containing instance masks. The masks
are of the same height, width as the input `image`. are of the same height, width as the input `image`.
...@@ -2622,8 +2710,11 @@ def ssd_random_crop_fixed_aspect_ratio( ...@@ -2622,8 +2710,11 @@ def ssd_random_crop_fixed_aspect_ratio(
Boxes are in normalized form. Boxes are in normalized form.
labels: new labels. labels: new labels.
If masks or keypoints is not None, the function also returns: If mulitclass_scores, masks, or keypoints is not None, the function also
returns:
multiclass_scores: rank 2 float32 tensor with shape
[num_instances, num_classes]
masks: rank 3 float32 tensor with shape [num_instances, height, width] masks: rank 3 float32 tensor with shape [num_instances, height, width]
containing instance masks. containing instance masks.
keypoints: rank 3 float32 tensor with shape keypoints: rank 3 float32 tensor with shape
...@@ -2632,29 +2723,46 @@ def ssd_random_crop_fixed_aspect_ratio( ...@@ -2632,29 +2723,46 @@ def ssd_random_crop_fixed_aspect_ratio(
aspect_ratio_range = ((aspect_ratio, aspect_ratio),) * len(area_range) aspect_ratio_range = ((aspect_ratio, aspect_ratio),) * len(area_range)
crop_result = ssd_random_crop( crop_result = ssd_random_crop(
image, boxes, labels, label_scores, masks, keypoints, min_object_covered, image,
aspect_ratio_range, area_range, overlap_thresh, random_coef, seed, boxes,
preprocess_vars_cache) labels,
label_scores=label_scores,
multiclass_scores=multiclass_scores,
masks=masks,
keypoints=keypoints,
min_object_covered=min_object_covered,
aspect_ratio_range=aspect_ratio_range,
area_range=area_range,
overlap_thresh=overlap_thresh,
random_coef=random_coef,
seed=seed,
preprocess_vars_cache=preprocess_vars_cache)
i = 3 i = 3
new_image, new_boxes, new_labels = crop_result[:i] new_image, new_boxes, new_labels = crop_result[:i]
new_label_scores = None new_label_scores = None
new_multiclass_scores = None
new_masks = None new_masks = None
new_keypoints = None new_keypoints = None
if label_scores is not None: if label_scores is not None:
new_label_scores = crop_result[i] new_label_scores = crop_result[i]
i += 1 i += 1
if multiclass_scores is not None:
new_multiclass_scores = crop_result[i]
i += 1
if masks is not None: if masks is not None:
new_masks = crop_result[i] new_masks = crop_result[i]
i += 1 i += 1
if keypoints is not None: if keypoints is not None:
new_keypoints = crop_result[i] new_keypoints = crop_result[i]
result = random_crop_to_aspect_ratio( result = random_crop_to_aspect_ratio(
new_image, new_image,
new_boxes, new_boxes,
new_labels, new_labels,
new_label_scores, label_scores=new_label_scores,
new_masks, multiclass_scores=new_multiclass_scores,
new_keypoints, masks=new_masks,
keypoints=new_keypoints,
aspect_ratio=aspect_ratio, aspect_ratio=aspect_ratio,
seed=seed, seed=seed,
preprocess_vars_cache=preprocess_vars_cache) preprocess_vars_cache=preprocess_vars_cache)
...@@ -2667,6 +2775,7 @@ def ssd_random_crop_pad_fixed_aspect_ratio( ...@@ -2667,6 +2775,7 @@ def ssd_random_crop_pad_fixed_aspect_ratio(
boxes, boxes,
labels, labels,
label_scores=None, label_scores=None,
multiclass_scores=None,
masks=None, masks=None,
keypoints=None, keypoints=None,
min_object_covered=(0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0), min_object_covered=(0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0),
...@@ -2698,6 +2807,9 @@ def ssd_random_crop_pad_fixed_aspect_ratio( ...@@ -2698,6 +2807,9 @@ def ssd_random_crop_pad_fixed_aspect_ratio(
labels: rank 1 int32 tensor containing the object classes. labels: rank 1 int32 tensor containing the object classes.
label_scores: (optional) float32 tensor of shape [num_instances] label_scores: (optional) float32 tensor of shape [num_instances]
representing the score for each box. representing the score for each box.
multiclass_scores: (optional) float32 tensor of shape
[num_instances, num_classes] representing the score for each box for each
class.
masks: (optional) rank 3 float32 tensor with shape masks: (optional) rank 3 float32 tensor with shape
[num_instances, height, width] containing instance masks. The masks [num_instances, height, width] containing instance masks. The masks
are of the same height, width as the input `image`. are of the same height, width as the input `image`.
...@@ -2732,35 +2844,53 @@ def ssd_random_crop_pad_fixed_aspect_ratio( ...@@ -2732,35 +2844,53 @@ def ssd_random_crop_pad_fixed_aspect_ratio(
Boxes are in normalized form. Boxes are in normalized form.
labels: new labels. labels: new labels.
If masks or keypoints is not None, the function also returns: If multiclass_scores, masks, or keypoints is not None, the function also
returns:
multiclass_scores: rank 2 with shape [num_instances, num_classes]
masks: rank 3 float32 tensor with shape [num_instances, height, width] masks: rank 3 float32 tensor with shape [num_instances, height, width]
containing instance masks. containing instance masks.
keypoints: rank 3 float32 tensor with shape keypoints: rank 3 float32 tensor with shape
[num_instances, num_keypoints, 2] [num_instances, num_keypoints, 2]
""" """
crop_result = ssd_random_crop( crop_result = ssd_random_crop(
image, boxes, labels, label_scores, masks, keypoints, min_object_covered, image,
aspect_ratio_range, area_range, overlap_thresh, random_coef, seed, boxes,
preprocess_vars_cache) labels,
label_scores=label_scores,
multiclass_scores=multiclass_scores,
masks=masks,
keypoints=keypoints,
min_object_covered=min_object_covered,
aspect_ratio_range=aspect_ratio_range,
area_range=area_range,
overlap_thresh=overlap_thresh,
random_coef=random_coef,
seed=seed,
preprocess_vars_cache=preprocess_vars_cache)
i = 3 i = 3
new_image, new_boxes, new_labels = crop_result[:i] new_image, new_boxes, new_labels = crop_result[:i]
new_label_scores = None new_label_scores = None
new_multiclass_scores = None
new_masks = None new_masks = None
new_keypoints = None new_keypoints = None
if label_scores is not None: if label_scores is not None:
new_label_scores = crop_result[i] new_label_scores = crop_result[i]
i += 1 i += 1
if multiclass_scores is not None:
new_multiclass_scores = crop_result[i]
i += 1
if masks is not None: if masks is not None:
new_masks = crop_result[i] new_masks = crop_result[i]
i += 1 i += 1
if keypoints is not None: if keypoints is not None:
new_keypoints = crop_result[i] new_keypoints = crop_result[i]
result = random_pad_to_aspect_ratio( result = random_pad_to_aspect_ratio(
new_image, new_image,
new_boxes, new_boxes,
new_masks, masks=new_masks,
new_keypoints, keypoints=new_keypoints,
aspect_ratio=aspect_ratio, aspect_ratio=aspect_ratio,
min_padded_size_ratio=min_padded_size_ratio, min_padded_size_ratio=min_padded_size_ratio,
max_padded_size_ratio=max_padded_size_ratio, max_padded_size_ratio=max_padded_size_ratio,
...@@ -2768,15 +2898,20 @@ def ssd_random_crop_pad_fixed_aspect_ratio( ...@@ -2768,15 +2898,20 @@ def ssd_random_crop_pad_fixed_aspect_ratio(
preprocess_vars_cache=preprocess_vars_cache) preprocess_vars_cache=preprocess_vars_cache)
result = list(result) result = list(result)
if new_label_scores is not None: i = 3
result.insert(2, new_label_scores)
result.insert(2, new_labels) result.insert(2, new_labels)
if new_label_scores is not None:
result.insert(i, new_label_scores)
i += 1
if multiclass_scores is not None:
result.insert(i, new_multiclass_scores)
result = tuple(result) result = tuple(result)
return result return result
def get_default_func_arg_map(include_label_scores=False, def get_default_func_arg_map(include_label_scores=False,
include_multiclass_scores=False,
include_instance_masks=False, include_instance_masks=False,
include_keypoints=False): include_keypoints=False):
"""Returns the default mapping from a preprocessor function to its args. """Returns the default mapping from a preprocessor function to its args.
...@@ -2784,6 +2919,8 @@ def get_default_func_arg_map(include_label_scores=False, ...@@ -2784,6 +2919,8 @@ def get_default_func_arg_map(include_label_scores=False,
Args: Args:
include_label_scores: If True, preprocessing functions will modify the include_label_scores: If True, preprocessing functions will modify the
label scores, too. label scores, too.
include_multiclass_scores: If True, preprocessing functions will modify the
multiclass scores, too.
include_instance_masks: If True, preprocessing functions will modify the include_instance_masks: If True, preprocessing functions will modify the
instance masks, too. instance masks, too.
include_keypoints: If True, preprocessing functions will modify the include_keypoints: If True, preprocessing functions will modify the
...@@ -2796,6 +2933,10 @@ def get_default_func_arg_map(include_label_scores=False, ...@@ -2796,6 +2933,10 @@ def get_default_func_arg_map(include_label_scores=False,
if include_label_scores: if include_label_scores:
groundtruth_label_scores = (fields.InputDataFields.groundtruth_label_scores) groundtruth_label_scores = (fields.InputDataFields.groundtruth_label_scores)
multiclass_scores = None
if include_multiclass_scores:
multiclass_scores = (fields.InputDataFields.multiclass_scores)
groundtruth_instance_masks = None groundtruth_instance_masks = None
if include_instance_masks: if include_instance_masks:
groundtruth_instance_masks = ( groundtruth_instance_masks = (
...@@ -2811,21 +2952,25 @@ def get_default_func_arg_map(include_label_scores=False, ...@@ -2811,21 +2952,25 @@ def get_default_func_arg_map(include_label_scores=False,
fields.InputDataFields.image, fields.InputDataFields.image,
fields.InputDataFields.groundtruth_boxes, fields.InputDataFields.groundtruth_boxes,
groundtruth_instance_masks, groundtruth_instance_masks,
groundtruth_keypoints,), groundtruth_keypoints,
),
random_vertical_flip: ( random_vertical_flip: (
fields.InputDataFields.image, fields.InputDataFields.image,
fields.InputDataFields.groundtruth_boxes, fields.InputDataFields.groundtruth_boxes,
groundtruth_instance_masks, groundtruth_instance_masks,
groundtruth_keypoints,), groundtruth_keypoints,
),
random_rotation90: ( random_rotation90: (
fields.InputDataFields.image, fields.InputDataFields.image,
fields.InputDataFields.groundtruth_boxes, fields.InputDataFields.groundtruth_boxes,
groundtruth_instance_masks, groundtruth_instance_masks,
groundtruth_keypoints,), groundtruth_keypoints,
),
random_pixel_value_scale: (fields.InputDataFields.image,), random_pixel_value_scale: (fields.InputDataFields.image,),
random_image_scale: ( random_image_scale: (
fields.InputDataFields.image, fields.InputDataFields.image,
groundtruth_instance_masks,), groundtruth_instance_masks,
),
random_rgb_to_gray: (fields.InputDataFields.image,), random_rgb_to_gray: (fields.InputDataFields.image,),
random_adjust_brightness: (fields.InputDataFields.image,), random_adjust_brightness: (fields.InputDataFields.image,),
random_adjust_contrast: (fields.InputDataFields.image,), random_adjust_contrast: (fields.InputDataFields.image,),
...@@ -2833,53 +2978,61 @@ def get_default_func_arg_map(include_label_scores=False, ...@@ -2833,53 +2978,61 @@ def get_default_func_arg_map(include_label_scores=False,
random_adjust_saturation: (fields.InputDataFields.image,), random_adjust_saturation: (fields.InputDataFields.image,),
random_distort_color: (fields.InputDataFields.image,), random_distort_color: (fields.InputDataFields.image,),
random_jitter_boxes: (fields.InputDataFields.groundtruth_boxes,), random_jitter_boxes: (fields.InputDataFields.groundtruth_boxes,),
random_crop_image: ( random_crop_image: (fields.InputDataFields.image,
fields.InputDataFields.image, fields.InputDataFields.groundtruth_boxes,
fields.InputDataFields.groundtruth_boxes, fields.InputDataFields.groundtruth_classes,
fields.InputDataFields.groundtruth_classes, groundtruth_label_scores, multiclass_scores,
groundtruth_label_scores, groundtruth_instance_masks, groundtruth_keypoints),
groundtruth_instance_masks,
groundtruth_keypoints,),
random_pad_image: (fields.InputDataFields.image, random_pad_image: (fields.InputDataFields.image,
fields.InputDataFields.groundtruth_boxes), fields.InputDataFields.groundtruth_boxes),
random_crop_pad_image: (fields.InputDataFields.image, random_crop_pad_image: (fields.InputDataFields.image,
fields.InputDataFields.groundtruth_boxes, fields.InputDataFields.groundtruth_boxes,
fields.InputDataFields.groundtruth_classes, fields.InputDataFields.groundtruth_classes,
groundtruth_label_scores), groundtruth_label_scores,
multiclass_scores),
random_crop_to_aspect_ratio: ( random_crop_to_aspect_ratio: (
fields.InputDataFields.image, fields.InputDataFields.image,
fields.InputDataFields.groundtruth_boxes, fields.InputDataFields.groundtruth_boxes,
fields.InputDataFields.groundtruth_classes, fields.InputDataFields.groundtruth_classes,
groundtruth_label_scores, groundtruth_label_scores,
multiclass_scores,
groundtruth_instance_masks, groundtruth_instance_masks,
groundtruth_keypoints,), groundtruth_keypoints,
),
random_pad_to_aspect_ratio: ( random_pad_to_aspect_ratio: (
fields.InputDataFields.image, fields.InputDataFields.image,
fields.InputDataFields.groundtruth_boxes, fields.InputDataFields.groundtruth_boxes,
groundtruth_instance_masks, groundtruth_instance_masks,
groundtruth_keypoints,), groundtruth_keypoints,
),
random_black_patches: (fields.InputDataFields.image,), random_black_patches: (fields.InputDataFields.image,),
retain_boxes_above_threshold: ( retain_boxes_above_threshold: (
fields.InputDataFields.groundtruth_boxes, fields.InputDataFields.groundtruth_boxes,
fields.InputDataFields.groundtruth_classes, fields.InputDataFields.groundtruth_classes,
groundtruth_label_scores, groundtruth_label_scores,
multiclass_scores,
groundtruth_instance_masks, groundtruth_instance_masks,
groundtruth_keypoints,), groundtruth_keypoints,
),
image_to_float: (fields.InputDataFields.image,), image_to_float: (fields.InputDataFields.image,),
random_resize_method: (fields.InputDataFields.image,), random_resize_method: (fields.InputDataFields.image,),
resize_to_range: ( resize_to_range: (
fields.InputDataFields.image, fields.InputDataFields.image,
groundtruth_instance_masks,), groundtruth_instance_masks,
),
resize_to_min_dimension: ( resize_to_min_dimension: (
fields.InputDataFields.image, fields.InputDataFields.image,
groundtruth_instance_masks,), groundtruth_instance_masks,
),
scale_boxes_to_pixel_coordinates: ( scale_boxes_to_pixel_coordinates: (
fields.InputDataFields.image, fields.InputDataFields.image,
fields.InputDataFields.groundtruth_boxes, fields.InputDataFields.groundtruth_boxes,
groundtruth_keypoints,), groundtruth_keypoints,
),
resize_image: ( resize_image: (
fields.InputDataFields.image, fields.InputDataFields.image,
groundtruth_instance_masks,), groundtruth_instance_masks,
),
subtract_channel_mean: (fields.InputDataFields.image,), subtract_channel_mean: (fields.InputDataFields.image,),
one_hot_encoding: (fields.InputDataFields.groundtruth_image_classes,), one_hot_encoding: (fields.InputDataFields.groundtruth_image_classes,),
rgb_to_gray: (fields.InputDataFields.image,), rgb_to_gray: (fields.InputDataFields.image,),
...@@ -2888,26 +3041,29 @@ def get_default_func_arg_map(include_label_scores=False, ...@@ -2888,26 +3041,29 @@ def get_default_func_arg_map(include_label_scores=False,
fields.InputDataFields.groundtruth_boxes, fields.InputDataFields.groundtruth_boxes,
fields.InputDataFields.groundtruth_classes, fields.InputDataFields.groundtruth_classes,
groundtruth_label_scores, groundtruth_label_scores,
multiclass_scores,
groundtruth_instance_masks, groundtruth_instance_masks,
groundtruth_keypoints,), groundtruth_keypoints
),
ssd_random_crop_pad: (fields.InputDataFields.image, ssd_random_crop_pad: (fields.InputDataFields.image,
fields.InputDataFields.groundtruth_boxes, fields.InputDataFields.groundtruth_boxes,
fields.InputDataFields.groundtruth_classes, fields.InputDataFields.groundtruth_classes,
groundtruth_label_scores), groundtruth_label_scores,
multiclass_scores),
ssd_random_crop_fixed_aspect_ratio: ( ssd_random_crop_fixed_aspect_ratio: (
fields.InputDataFields.image, fields.InputDataFields.image,
fields.InputDataFields.groundtruth_boxes, fields.InputDataFields.groundtruth_boxes,
fields.InputDataFields.groundtruth_classes, fields.InputDataFields.groundtruth_classes, groundtruth_label_scores,
groundtruth_label_scores, multiclass_scores, groundtruth_instance_masks, groundtruth_keypoints),
groundtruth_instance_masks,
groundtruth_keypoints,),
ssd_random_crop_pad_fixed_aspect_ratio: ( ssd_random_crop_pad_fixed_aspect_ratio: (
fields.InputDataFields.image, fields.InputDataFields.image,
fields.InputDataFields.groundtruth_boxes, fields.InputDataFields.groundtruth_boxes,
fields.InputDataFields.groundtruth_classes, fields.InputDataFields.groundtruth_classes,
groundtruth_label_scores, groundtruth_label_scores,
multiclass_scores,
groundtruth_instance_masks, groundtruth_instance_masks,
groundtruth_keypoints,), groundtruth_keypoints,
),
} }
return prep_func_arg_map return prep_func_arg_map
......
...@@ -119,6 +119,9 @@ class PreprocessorTest(tf.test.TestCase): ...@@ -119,6 +119,9 @@ class PreprocessorTest(tf.test.TestCase):
[[-0.1, 0.25, 0.75, 1], [0.25, 0.5, 0.75, 1.1]], dtype=tf.float32) [[-0.1, 0.25, 0.75, 1], [0.25, 0.5, 0.75, 1.1]], dtype=tf.float32)
return boxes return boxes
def createTestMultiClassScores(self):
return tf.constant([[1.0, 0.0], [0.5, 0.5]], dtype=tf.float32)
def expectedImagesAfterNormalization(self): def expectedImagesAfterNormalization(self):
images_r = tf.constant([[[0, 0, 0, 0], [-1, -1, 0, 0], images_r = tf.constant([[[0, 0, 0, 0], [-1, -1, 0, 0],
[-1, 0, 0, 0], [0.5, 0.5, 0, 0]]], [-1, 0, 0, 0], [0.5, 0.5, 0, 0]]],
...@@ -269,6 +272,9 @@ class PreprocessorTest(tf.test.TestCase): ...@@ -269,6 +272,9 @@ class PreprocessorTest(tf.test.TestCase):
def expectedLabelsAfterThresholding(self): def expectedLabelsAfterThresholding(self):
return tf.constant([1], dtype=tf.float32) return tf.constant([1], dtype=tf.float32)
def expectedMultiClassScoresAfterThresholding(self):
return tf.constant([[1.0, 0.0]], dtype=tf.float32)
def expectedMasksAfterThresholding(self): def expectedMasksAfterThresholding(self):
mask = np.array([ mask = np.array([
[[255.0, 0.0, 0.0], [[255.0, 0.0, 0.0],
...@@ -345,6 +351,28 @@ class PreprocessorTest(tf.test.TestCase): ...@@ -345,6 +351,28 @@ class PreprocessorTest(tf.test.TestCase):
self.assertAllClose( self.assertAllClose(
retained_label_scores_, expected_retained_label_scores_) retained_label_scores_, expected_retained_label_scores_)
def testRetainBoxesAboveThresholdWithMultiClassScores(self):
boxes = self.createTestBoxes()
labels = self.createTestLabels()
label_scores = self.createTestLabelScores()
multiclass_scores = self.createTestMultiClassScores()
(_, _, _,
retained_multiclass_scores) = preprocessor.retain_boxes_above_threshold(
boxes,
labels,
label_scores,
multiclass_scores=multiclass_scores,
threshold=0.6)
with self.test_session() as sess:
(retained_multiclass_scores_,
expected_retained_multiclass_scores_) = sess.run([
retained_multiclass_scores,
self.expectedMultiClassScoresAfterThresholding()
])
self.assertAllClose(retained_multiclass_scores_,
expected_retained_multiclass_scores_)
def testRetainBoxesAboveThresholdWithMasks(self): def testRetainBoxesAboveThresholdWithMasks(self):
boxes = self.createTestBoxes() boxes = self.createTestBoxes()
labels = self.createTestLabels() labels = self.createTestLabels()
...@@ -1264,6 +1292,48 @@ class PreprocessorTest(tf.test.TestCase): ...@@ -1264,6 +1292,48 @@ class PreprocessorTest(tf.test.TestCase):
self.assertAllClose(distorted_boxes_, expected_boxes_) self.assertAllClose(distorted_boxes_, expected_boxes_)
self.assertAllEqual(distorted_labels_, expected_labels_) self.assertAllEqual(distorted_labels_, expected_labels_)
def testRandomCropImageWithMultiClassScores(self):
preprocessing_options = []
preprocessing_options.append((preprocessor.normalize_image, {
'original_minval': 0,
'original_maxval': 255,
'target_minval': 0,
'target_maxval': 1
}))
preprocessing_options.append((preprocessor.random_crop_image, {}))
images = self.createTestImages()
boxes = self.createTestBoxes()
labels = self.createTestLabels()
multiclass_scores = self.createTestMultiClassScores()
tensor_dict = {
fields.InputDataFields.image: images,
fields.InputDataFields.groundtruth_boxes: boxes,
fields.InputDataFields.groundtruth_classes: labels,
fields.InputDataFields.multiclass_scores: multiclass_scores
}
distorted_tensor_dict = preprocessor.preprocess(tensor_dict,
preprocessing_options)
distorted_images = distorted_tensor_dict[fields.InputDataFields.image]
distorted_boxes = distorted_tensor_dict[
fields.InputDataFields.groundtruth_boxes]
distorted_multiclass_scores = distorted_tensor_dict[
fields.InputDataFields.multiclass_scores]
boxes_rank = tf.rank(boxes)
distorted_boxes_rank = tf.rank(distorted_boxes)
images_rank = tf.rank(images)
distorted_images_rank = tf.rank(distorted_images)
with self.test_session() as sess:
(boxes_rank_, distorted_boxes_rank_, images_rank_, distorted_images_rank_,
multiclass_scores_, distorted_multiclass_scores_) = sess.run([
boxes_rank, distorted_boxes_rank, images_rank, distorted_images_rank,
multiclass_scores, distorted_multiclass_scores
])
self.assertAllEqual(boxes_rank_, distorted_boxes_rank_)
self.assertAllEqual(images_rank_, distorted_images_rank_)
self.assertAllEqual(multiclass_scores_, distorted_multiclass_scores_)
def testStrictRandomCropImageWithLabelScores(self): def testStrictRandomCropImageWithLabelScores(self):
image = self.createColorfulTestImage()[0] image = self.createColorfulTestImage()[0]
boxes = self.createTestBoxes() boxes = self.createTestBoxes()
...@@ -2510,6 +2580,49 @@ class PreprocessorTest(tf.test.TestCase): ...@@ -2510,6 +2580,49 @@ class PreprocessorTest(tf.test.TestCase):
self.assertAllEqual(boxes_rank_, distorted_boxes_rank_) self.assertAllEqual(boxes_rank_, distorted_boxes_rank_)
self.assertAllEqual(images_rank_, distorted_images_rank_) self.assertAllEqual(images_rank_, distorted_images_rank_)
def testSSDRandomCropWithMultiClassScores(self):
preprocessing_options = [(preprocessor.normalize_image, {
'original_minval': 0,
'original_maxval': 255,
'target_minval': 0,
'target_maxval': 1
}), (preprocessor.ssd_random_crop, {})]
images = self.createTestImages()
boxes = self.createTestBoxes()
labels = self.createTestLabels()
multiclass_scores = self.createTestMultiClassScores()
tensor_dict = {
fields.InputDataFields.image: images,
fields.InputDataFields.groundtruth_boxes: boxes,
fields.InputDataFields.groundtruth_classes: labels,
fields.InputDataFields.multiclass_scores: multiclass_scores,
}
preprocessor_arg_map = preprocessor.get_default_func_arg_map(
include_multiclass_scores=True)
distorted_tensor_dict = preprocessor.preprocess(
tensor_dict, preprocessing_options, func_arg_map=preprocessor_arg_map)
distorted_images = distorted_tensor_dict[fields.InputDataFields.image]
distorted_boxes = distorted_tensor_dict[
fields.InputDataFields.groundtruth_boxes]
distorted_multiclass_scores = distorted_tensor_dict[
fields.InputDataFields.multiclass_scores]
images_rank = tf.rank(images)
distorted_images_rank = tf.rank(distorted_images)
boxes_rank = tf.rank(boxes)
distorted_boxes_rank = tf.rank(distorted_boxes)
with self.test_session() as sess:
(boxes_rank_, distorted_boxes_rank_, images_rank_, distorted_images_rank_,
multiclass_scores_, distorted_multiclass_scores_) = sess.run([
boxes_rank, distorted_boxes_rank, images_rank, distorted_images_rank,
multiclass_scores, distorted_multiclass_scores
])
self.assertAllEqual(boxes_rank_, distorted_boxes_rank_)
self.assertAllEqual(images_rank_, distorted_images_rank_)
self.assertAllEqual(multiclass_scores_, distorted_multiclass_scores_)
def testSSDRandomCropPad(self): def testSSDRandomCropPad(self):
images = self.createTestImages() images = self.createTestImages()
boxes = self.createTestBoxes() boxes = self.createTestBoxes()
...@@ -2562,28 +2675,31 @@ class PreprocessorTest(tf.test.TestCase): ...@@ -2562,28 +2675,31 @@ class PreprocessorTest(tf.test.TestCase):
def _testSSDRandomCropFixedAspectRatio(self, def _testSSDRandomCropFixedAspectRatio(self,
include_label_scores, include_label_scores,
include_multiclass_scores,
include_instance_masks, include_instance_masks,
include_keypoints): include_keypoints):
images = self.createTestImages() images = self.createTestImages()
boxes = self.createTestBoxes() boxes = self.createTestBoxes()
labels = self.createTestLabels() labels = self.createTestLabels()
preprocessing_options = [ preprocessing_options = [(preprocessor.normalize_image, {
(preprocessor.normalize_image, { 'original_minval': 0,
'original_minval': 0, 'original_maxval': 255,
'original_maxval': 255, 'target_minval': 0,
'target_minval': 0, 'target_maxval': 1
'target_maxval': 1 }), (preprocessor.ssd_random_crop_fixed_aspect_ratio, {})]
}),
(preprocessor.ssd_random_crop_fixed_aspect_ratio, {})]
tensor_dict = { tensor_dict = {
fields.InputDataFields.image: images, fields.InputDataFields.image: images,
fields.InputDataFields.groundtruth_boxes: boxes, fields.InputDataFields.groundtruth_boxes: boxes,
fields.InputDataFields.groundtruth_classes: labels fields.InputDataFields.groundtruth_classes: labels,
} }
if include_label_scores: if include_label_scores:
label_scores = self.createTestLabelScores() label_scores = self.createTestLabelScores()
tensor_dict[fields.InputDataFields.groundtruth_label_scores] = ( tensor_dict[fields.InputDataFields.groundtruth_label_scores] = (
label_scores) label_scores)
if include_multiclass_scores:
multiclass_scores = self.createTestMultiClassScores()
tensor_dict[fields.InputDataFields.multiclass_scores] = (
multiclass_scores)
if include_instance_masks: if include_instance_masks:
masks = self.createTestMasks() masks = self.createTestMasks()
tensor_dict[fields.InputDataFields.groundtruth_instance_masks] = masks tensor_dict[fields.InputDataFields.groundtruth_instance_masks] = masks
...@@ -2593,6 +2709,7 @@ class PreprocessorTest(tf.test.TestCase): ...@@ -2593,6 +2709,7 @@ class PreprocessorTest(tf.test.TestCase):
preprocessor_arg_map = preprocessor.get_default_func_arg_map( preprocessor_arg_map = preprocessor.get_default_func_arg_map(
include_label_scores=include_label_scores, include_label_scores=include_label_scores,
include_multiclass_scores=include_multiclass_scores,
include_instance_masks=include_instance_masks, include_instance_masks=include_instance_masks,
include_keypoints=include_keypoints) include_keypoints=include_keypoints)
distorted_tensor_dict = preprocessor.preprocess( distorted_tensor_dict = preprocessor.preprocess(
...@@ -2615,16 +2732,25 @@ class PreprocessorTest(tf.test.TestCase): ...@@ -2615,16 +2732,25 @@ class PreprocessorTest(tf.test.TestCase):
def testSSDRandomCropFixedAspectRatio(self): def testSSDRandomCropFixedAspectRatio(self):
self._testSSDRandomCropFixedAspectRatio(include_label_scores=False, self._testSSDRandomCropFixedAspectRatio(include_label_scores=False,
include_multiclass_scores=False,
include_instance_masks=False,
include_keypoints=False)
def testSSDRandomCropFixedAspectRatioWithMultiClassScores(self):
self._testSSDRandomCropFixedAspectRatio(include_label_scores=False,
include_multiclass_scores=True,
include_instance_masks=False, include_instance_masks=False,
include_keypoints=False) include_keypoints=False)
def testSSDRandomCropFixedAspectRatioWithMasksAndKeypoints(self): def testSSDRandomCropFixedAspectRatioWithMasksAndKeypoints(self):
self._testSSDRandomCropFixedAspectRatio(include_label_scores=False, self._testSSDRandomCropFixedAspectRatio(include_label_scores=False,
include_multiclass_scores=False,
include_instance_masks=True, include_instance_masks=True,
include_keypoints=True) include_keypoints=True)
def testSSDRandomCropFixedAspectRatioWithLabelScoresMasksAndKeypoints(self): def testSSDRandomCropFixedAspectRatioWithLabelScoresMasksAndKeypoints(self):
self._testSSDRandomCropFixedAspectRatio(include_label_scores=True, self._testSSDRandomCropFixedAspectRatio(include_label_scores=True,
include_multiclass_scores=False,
include_instance_masks=True, include_instance_masks=True,
include_keypoints=True) include_keypoints=True)
......
...@@ -61,6 +61,9 @@ class InputDataFields(object): ...@@ -61,6 +61,9 @@ class InputDataFields(object):
num_groundtruth_boxes: number of groundtruth boxes. num_groundtruth_boxes: number of groundtruth boxes.
true_image_shapes: true shapes of images in the resized images, as resized true_image_shapes: true shapes of images in the resized images, as resized
images can be padded with zeros. images can be padded with zeros.
verified_labels: list of human-verified image-level labels (note, that a
label can be verified both as positive and negative).
multiclass_scores: the label score per class for each box.
""" """
image = 'image' image = 'image'
original_image = 'original_image' original_image = 'original_image'
...@@ -86,6 +89,8 @@ class InputDataFields(object): ...@@ -86,6 +89,8 @@ class InputDataFields(object):
groundtruth_weights = 'groundtruth_weights' groundtruth_weights = 'groundtruth_weights'
num_groundtruth_boxes = 'num_groundtruth_boxes' num_groundtruth_boxes = 'num_groundtruth_boxes'
true_image_shape = 'true_image_shape' true_image_shape = 'true_image_shape'
verified_labels = 'verified_labels'
multiclass_scores = 'multiclass_scores'
class DetectionResultFields(object): class DetectionResultFields(object):
......
confidential;1;confidentialit,confidentiality
dogfood;1;
fishfood;1;
catfood;1;
teamfood;1;
droidfood;1;
//go/;1;
//sites/;1;
a/google.com;1;
corp.google.com;1;
.googleplex.com;1;
sandbox.;1;wallet-web.sandbox.,sandbox.google.com/checkout, sandbox.,paymentssandbox
stupid;1;astupidi
caution:;2;
fixme:;2;
fixme(;2;
internal only;2;
internal_only;2;
backdoor;2;
STOPSHIP;2;
ridiculous;1;
notasecret;1;
@google.com;1;noreply@google.com
$RE:chmod [0-9]?777;3;chmod (0)777
mactruck;2;
seastar;2;
...@@ -229,7 +229,7 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -229,7 +229,7 @@ class FasterRCNNMetaArch(model.DetectionModel):
number_of_stages, number_of_stages,
first_stage_anchor_generator, first_stage_anchor_generator,
first_stage_atrous_rate, first_stage_atrous_rate,
first_stage_box_predictor_arg_scope, first_stage_box_predictor_arg_scope_fn,
first_stage_box_predictor_kernel_size, first_stage_box_predictor_kernel_size,
first_stage_box_predictor_depth, first_stage_box_predictor_depth,
first_stage_minibatch_size, first_stage_minibatch_size,
...@@ -291,8 +291,9 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -291,8 +291,9 @@ class FasterRCNNMetaArch(model.DetectionModel):
denser resolutions. The atrous rate is used to compensate for the denser resolutions. The atrous rate is used to compensate for the
denser feature maps by using an effectively larger receptive field. denser feature maps by using an effectively larger receptive field.
(This should typically be set to 1). (This should typically be set to 1).
first_stage_box_predictor_arg_scope: Slim arg_scope for conv2d, first_stage_box_predictor_arg_scope_fn: A function to construct tf-slim
separable_conv2d and fully_connected ops for the RPN box predictor. arg_scope for conv2d, separable_conv2d and fully_connected ops for the
RPN box predictor.
first_stage_box_predictor_kernel_size: Kernel size to use for the first_stage_box_predictor_kernel_size: Kernel size to use for the
convolution op just prior to RPN box predictions. convolution op just prior to RPN box predictions.
first_stage_box_predictor_depth: Output depth for the convolution op first_stage_box_predictor_depth: Output depth for the convolution op
...@@ -396,8 +397,8 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -396,8 +397,8 @@ class FasterRCNNMetaArch(model.DetectionModel):
# (First stage) Region proposal network parameters # (First stage) Region proposal network parameters
self._first_stage_anchor_generator = first_stage_anchor_generator self._first_stage_anchor_generator = first_stage_anchor_generator
self._first_stage_atrous_rate = first_stage_atrous_rate self._first_stage_atrous_rate = first_stage_atrous_rate
self._first_stage_box_predictor_arg_scope = ( self._first_stage_box_predictor_arg_scope_fn = (
first_stage_box_predictor_arg_scope) first_stage_box_predictor_arg_scope_fn)
self._first_stage_box_predictor_kernel_size = ( self._first_stage_box_predictor_kernel_size = (
first_stage_box_predictor_kernel_size) first_stage_box_predictor_kernel_size)
self._first_stage_box_predictor_depth = first_stage_box_predictor_depth self._first_stage_box_predictor_depth = first_stage_box_predictor_depth
...@@ -406,7 +407,7 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -406,7 +407,7 @@ class FasterRCNNMetaArch(model.DetectionModel):
positive_fraction=first_stage_positive_balance_fraction) positive_fraction=first_stage_positive_balance_fraction)
self._first_stage_box_predictor = box_predictor.ConvolutionalBoxPredictor( self._first_stage_box_predictor = box_predictor.ConvolutionalBoxPredictor(
self._is_training, num_classes=1, self._is_training, num_classes=1,
conv_hyperparams=self._first_stage_box_predictor_arg_scope, conv_hyperparams_fn=self._first_stage_box_predictor_arg_scope_fn,
min_depth=0, max_depth=0, num_layers_before_predictor=0, min_depth=0, max_depth=0, num_layers_before_predictor=0,
use_dropout=False, dropout_keep_prob=1.0, kernel_size=1, use_dropout=False, dropout_keep_prob=1.0, kernel_size=1,
box_code_size=self._box_coder.code_size) box_code_size=self._box_coder.code_size)
...@@ -914,7 +915,7 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -914,7 +915,7 @@ class FasterRCNNMetaArch(model.DetectionModel):
anchors = box_list_ops.concatenate( anchors = box_list_ops.concatenate(
self._first_stage_anchor_generator.generate([(feature_map_shape[1], self._first_stage_anchor_generator.generate([(feature_map_shape[1],
feature_map_shape[2])])) feature_map_shape[2])]))
with slim.arg_scope(self._first_stage_box_predictor_arg_scope): with slim.arg_scope(self._first_stage_box_predictor_arg_scope_fn()):
kernel_size = self._first_stage_box_predictor_kernel_size kernel_size = self._first_stage_box_predictor_kernel_size
rpn_box_predictor_features = slim.conv2d( rpn_box_predictor_features = slim.conv2d(
rpn_features_to_crop, rpn_features_to_crop,
......
...@@ -196,7 +196,7 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase): ...@@ -196,7 +196,7 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
} }
} }
""" """
first_stage_box_predictor_arg_scope = ( first_stage_box_predictor_arg_scope_fn = (
self._build_arg_scope_with_hyperparams( self._build_arg_scope_with_hyperparams(
first_stage_box_predictor_hyperparams_text_proto, is_training)) first_stage_box_predictor_hyperparams_text_proto, is_training))
...@@ -255,8 +255,8 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase): ...@@ -255,8 +255,8 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
'number_of_stages': number_of_stages, 'number_of_stages': number_of_stages,
'first_stage_anchor_generator': first_stage_anchor_generator, 'first_stage_anchor_generator': first_stage_anchor_generator,
'first_stage_atrous_rate': first_stage_atrous_rate, 'first_stage_atrous_rate': first_stage_atrous_rate,
'first_stage_box_predictor_arg_scope': 'first_stage_box_predictor_arg_scope_fn':
first_stage_box_predictor_arg_scope, first_stage_box_predictor_arg_scope_fn,
'first_stage_box_predictor_kernel_size': 'first_stage_box_predictor_kernel_size':
first_stage_box_predictor_kernel_size, first_stage_box_predictor_kernel_size,
'first_stage_box_predictor_depth': first_stage_box_predictor_depth, 'first_stage_box_predictor_depth': first_stage_box_predictor_depth,
......
...@@ -56,7 +56,7 @@ class RFCNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch): ...@@ -56,7 +56,7 @@ class RFCNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
number_of_stages, number_of_stages,
first_stage_anchor_generator, first_stage_anchor_generator,
first_stage_atrous_rate, first_stage_atrous_rate,
first_stage_box_predictor_arg_scope, first_stage_box_predictor_arg_scope_fn,
first_stage_box_predictor_kernel_size, first_stage_box_predictor_kernel_size,
first_stage_box_predictor_depth, first_stage_box_predictor_depth,
first_stage_minibatch_size, first_stage_minibatch_size,
...@@ -103,8 +103,9 @@ class RFCNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch): ...@@ -103,8 +103,9 @@ class RFCNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
denser resolutions. The atrous rate is used to compensate for the denser resolutions. The atrous rate is used to compensate for the
denser feature maps by using an effectively larger receptive field. denser feature maps by using an effectively larger receptive field.
(This should typically be set to 1). (This should typically be set to 1).
first_stage_box_predictor_arg_scope: Slim arg_scope for conv2d, first_stage_box_predictor_arg_scope_fn: A function to generate tf-slim
separable_conv2d and fully_connected ops for the RPN box predictor. arg_scope for conv2d, separable_conv2d and fully_connected ops for the
RPN box predictor.
first_stage_box_predictor_kernel_size: Kernel size to use for the first_stage_box_predictor_kernel_size: Kernel size to use for the
convolution op just prior to RPN box predictions. convolution op just prior to RPN box predictions.
first_stage_box_predictor_depth: Output depth for the convolution op first_stage_box_predictor_depth: Output depth for the convolution op
...@@ -174,7 +175,7 @@ class RFCNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch): ...@@ -174,7 +175,7 @@ class RFCNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
number_of_stages, number_of_stages,
first_stage_anchor_generator, first_stage_anchor_generator,
first_stage_atrous_rate, first_stage_atrous_rate,
first_stage_box_predictor_arg_scope, first_stage_box_predictor_arg_scope_fn,
first_stage_box_predictor_kernel_size, first_stage_box_predictor_kernel_size,
first_stage_box_predictor_depth, first_stage_box_predictor_depth,
first_stage_minibatch_size, first_stage_minibatch_size,
......
...@@ -42,12 +42,10 @@ class SSDFeatureExtractor(object): ...@@ -42,12 +42,10 @@ class SSDFeatureExtractor(object):
depth_multiplier, depth_multiplier,
min_depth, min_depth,
pad_to_multiple, pad_to_multiple,
conv_hyperparams, conv_hyperparams_fn,
batch_norm_trainable=True,
reuse_weights=None, reuse_weights=None,
use_explicit_padding=False, use_explicit_padding=False,
use_depthwise=False, use_depthwise=False):
inplace_batchnorm_update=False):
"""Constructor. """Constructor.
Args: Args:
...@@ -56,27 +54,19 @@ class SSDFeatureExtractor(object): ...@@ -56,27 +54,19 @@ class SSDFeatureExtractor(object):
min_depth: minimum feature extractor depth. min_depth: minimum feature extractor depth.
pad_to_multiple: the nearest multiple to zero pad the input height and pad_to_multiple: the nearest multiple to zero pad the input height and
width dimensions to. width dimensions to.
conv_hyperparams: tf slim arg_scope for conv2d and separable_conv2d ops. conv_hyperparams_fn: A function to construct tf slim arg_scope for conv2d
batch_norm_trainable: Whether to update batch norm parameters during and separable_conv2d ops.
training or not. When training with a small batch size
(e.g. 1), it is desirable to disable batch norm update and use
pretrained batch norm params.
reuse_weights: whether to reuse variables. Default is None. reuse_weights: whether to reuse variables. Default is None.
use_explicit_padding: Whether to use explicit padding when extracting use_explicit_padding: Whether to use explicit padding when extracting
features. Default is False. features. Default is False.
use_depthwise: Whether to use depthwise convolutions. Default is False. use_depthwise: Whether to use depthwise convolutions. Default is False.
inplace_batchnorm_update: Whether to update batch norm moving average
values inplace. When this is false train op must add a control
dependency on tf.graphkeys.UPDATE_OPS collection in order to update
batch norm statistics.
""" """
self._is_training = is_training self._is_training = is_training
self._depth_multiplier = depth_multiplier self._depth_multiplier = depth_multiplier
self._min_depth = min_depth self._min_depth = min_depth
self._pad_to_multiple = pad_to_multiple self._pad_to_multiple = pad_to_multiple
self._conv_hyperparams = conv_hyperparams self._conv_hyperparams_fn = conv_hyperparams_fn
self._batch_norm_trainable = batch_norm_trainable
self._inplace_batchnorm_update = inplace_batchnorm_update
self._reuse_weights = reuse_weights self._reuse_weights = reuse_weights
self._use_explicit_padding = use_explicit_padding self._use_explicit_padding = use_explicit_padding
self._use_depthwise = use_depthwise self._use_depthwise = use_depthwise
...@@ -106,28 +96,6 @@ class SSDFeatureExtractor(object): ...@@ -106,28 +96,6 @@ class SSDFeatureExtractor(object):
This function is responsible for extracting feature maps from preprocessed This function is responsible for extracting feature maps from preprocessed
images. images.
Args:
preprocessed_inputs: a [batch, height, width, channels] float tensor
representing a batch of images.
Returns:
feature_maps: a list of tensors where the ith tensor has shape
[batch, height_i, width_i, depth_i]
"""
batchnorm_updates_collections = (None if self._inplace_batchnorm_update
else tf.GraphKeys.UPDATE_OPS)
with slim.arg_scope([slim.batch_norm],
updates_collections=batchnorm_updates_collections):
return self._extract_features(preprocessed_inputs)
@abstractmethod
def _extract_features(self, preprocessed_inputs):
"""Extracts features from preprocessed inputs.
This function is responsible for extracting feature maps from preprocessed
images.
Args: Args:
preprocessed_inputs: a [batch, height, width, channels] float tensor preprocessed_inputs: a [batch, height, width, channels] float tensor
representing a batch of images. representing a batch of images.
...@@ -162,7 +130,9 @@ class SSDMetaArch(model.DetectionModel): ...@@ -162,7 +130,9 @@ class SSDMetaArch(model.DetectionModel):
normalize_loss_by_num_matches, normalize_loss_by_num_matches,
hard_example_miner, hard_example_miner,
add_summaries=True, add_summaries=True,
normalize_loc_loss_by_codesize=False): normalize_loc_loss_by_codesize=False,
freeze_batchnorm=False,
inplace_batchnorm_update=False):
"""SSDMetaArch Constructor. """SSDMetaArch Constructor.
TODO(rathodv,jonathanhuang): group NMS parameters + score converter into TODO(rathodv,jonathanhuang): group NMS parameters + score converter into
...@@ -209,9 +179,19 @@ class SSDMetaArch(model.DetectionModel): ...@@ -209,9 +179,19 @@ class SSDMetaArch(model.DetectionModel):
should be added to tensorflow graph. should be added to tensorflow graph.
normalize_loc_loss_by_codesize: whether to normalize localization loss normalize_loc_loss_by_codesize: whether to normalize localization loss
by code size of the box encoder. by code size of the box encoder.
freeze_batchnorm: Whether to freeze batch norm parameters during
training or not. When training with a small batch size (e.g. 1), it is
desirable to freeze batch norm update and use pretrained batch norm
params.
inplace_batchnorm_update: Whether to update batch norm moving average
values inplace. When this is false train op must add a control
dependency on tf.graphkeys.UPDATE_OPS collection in order to update
batch norm statistics.
""" """
super(SSDMetaArch, self).__init__(num_classes=box_predictor.num_classes) super(SSDMetaArch, self).__init__(num_classes=box_predictor.num_classes)
self._is_training = is_training self._is_training = is_training
self._freeze_batchnorm = freeze_batchnorm
self._inplace_batchnorm_update = inplace_batchnorm_update
# Needed for fine-tuning from classification checkpoints whose # Needed for fine-tuning from classification checkpoints whose
# variables do not have the feature extractor scope. # variables do not have the feature extractor scope.
...@@ -372,32 +352,40 @@ class SSDMetaArch(model.DetectionModel): ...@@ -372,32 +352,40 @@ class SSDMetaArch(model.DetectionModel):
5) anchors: 2-D float tensor of shape [num_anchors, 4] containing 5) anchors: 2-D float tensor of shape [num_anchors, 4] containing
the generated anchors in normalized coordinates. the generated anchors in normalized coordinates.
""" """
with tf.variable_scope(None, self._extract_features_scope, batchnorm_updates_collections = (None if self._inplace_batchnorm_update
[preprocessed_inputs]): else tf.GraphKeys.UPDATE_OPS)
feature_maps = self._feature_extractor.extract_features( with slim.arg_scope([slim.batch_norm],
is_training=(self._is_training and
not self._freeze_batchnorm),
updates_collections=batchnorm_updates_collections):
with tf.variable_scope(None, self._extract_features_scope,
[preprocessed_inputs]):
feature_maps = self._feature_extractor.extract_features(
preprocessed_inputs)
feature_map_spatial_dims = self._get_feature_map_spatial_dims(
feature_maps)
image_shape = shape_utils.combined_static_and_dynamic_shape(
preprocessed_inputs) preprocessed_inputs)
feature_map_spatial_dims = self._get_feature_map_spatial_dims(feature_maps) self._anchors = box_list_ops.concatenate(
image_shape = shape_utils.combined_static_and_dynamic_shape( self._anchor_generator.generate(
preprocessed_inputs) feature_map_spatial_dims,
self._anchors = box_list_ops.concatenate( im_height=image_shape[1],
self._anchor_generator.generate( im_width=image_shape[2]))
feature_map_spatial_dims, prediction_dict = self._box_predictor.predict(
im_height=image_shape[1], feature_maps, self._anchor_generator.num_anchors_per_location())
im_width=image_shape[2])) box_encodings = tf.squeeze(
prediction_dict = self._box_predictor.predict( tf.concat(prediction_dict['box_encodings'], axis=1), axis=2)
feature_maps, self._anchor_generator.num_anchors_per_location()) class_predictions_with_background = tf.concat(
box_encodings = tf.squeeze( prediction_dict['class_predictions_with_background'], axis=1)
tf.concat(prediction_dict['box_encodings'], axis=1), axis=2) predictions_dict = {
class_predictions_with_background = tf.concat( 'preprocessed_inputs': preprocessed_inputs,
prediction_dict['class_predictions_with_background'], axis=1) 'box_encodings': box_encodings,
predictions_dict = { 'class_predictions_with_background':
'preprocessed_inputs': preprocessed_inputs, class_predictions_with_background,
'box_encodings': box_encodings, 'feature_maps': feature_maps,
'class_predictions_with_background': class_predictions_with_background, 'anchors': self._anchors.get()
'feature_maps': feature_maps, }
'anchors': self._anchors.get() return predictions_dict
}
return predictions_dict
def _get_feature_map_spatial_dims(self, feature_maps): def _get_feature_map_spatial_dims(self, feature_maps):
"""Return list of spatial dimensions for each feature map in a list. """Return list of spatial dimensions for each feature map in a list.
......
...@@ -38,8 +38,7 @@ class FakeSSDFeatureExtractor(ssd_meta_arch.SSDFeatureExtractor): ...@@ -38,8 +38,7 @@ class FakeSSDFeatureExtractor(ssd_meta_arch.SSDFeatureExtractor):
depth_multiplier=0, depth_multiplier=0,
min_depth=0, min_depth=0,
pad_to_multiple=1, pad_to_multiple=1,
batch_norm_trainable=True, conv_hyperparams_fn=None)
conv_hyperparams=None)
def preprocess(self, resized_inputs): def preprocess(self, resized_inputs):
return tf.identity(resized_inputs) return tf.identity(resized_inputs)
...@@ -124,7 +123,8 @@ class SsdMetaArchTest(test_case.TestCase): ...@@ -124,7 +123,8 @@ class SsdMetaArchTest(test_case.TestCase):
non_max_suppression_fn, tf.identity, classification_loss, non_max_suppression_fn, tf.identity, classification_loss,
localization_loss, classification_loss_weight, localization_loss_weight, localization_loss, classification_loss_weight, localization_loss_weight,
normalize_loss_by_num_matches, hard_example_miner, add_summaries=False, normalize_loss_by_num_matches, hard_example_miner, add_summaries=False,
normalize_loc_loss_by_codesize=normalize_loc_loss_by_codesize) normalize_loc_loss_by_codesize=normalize_loc_loss_by_codesize,
freeze_batchnorm=False, inplace_batchnorm_update=False)
return model, num_classes, mock_anchor_generator.num_anchors(), code_size return model, num_classes, mock_anchor_generator.num_anchors(), code_size
def test_preprocess_preserves_shapes_with_dynamic_input_image(self): def test_preprocess_preserves_shapes_with_dynamic_input_image(self):
......
...@@ -49,12 +49,10 @@ class EmbeddedSSDMobileNetV1FeatureExtractor( ...@@ -49,12 +49,10 @@ class EmbeddedSSDMobileNetV1FeatureExtractor(
depth_multiplier, depth_multiplier,
min_depth, min_depth,
pad_to_multiple, pad_to_multiple,
conv_hyperparams, conv_hyperparams_fn,
batch_norm_trainable=True,
reuse_weights=None, reuse_weights=None,
use_explicit_padding=False, use_explicit_padding=False,
use_depthwise=False, use_depthwise=False):
inplace_batchnorm_update=False):
"""MobileNetV1 Feature Extractor for Embedded-friendly SSD Models. """MobileNetV1 Feature Extractor for Embedded-friendly SSD Models.
Args: Args:
...@@ -63,20 +61,12 @@ class EmbeddedSSDMobileNetV1FeatureExtractor( ...@@ -63,20 +61,12 @@ class EmbeddedSSDMobileNetV1FeatureExtractor(
min_depth: minimum feature extractor depth. min_depth: minimum feature extractor depth.
pad_to_multiple: the nearest multiple to zero pad the input height and pad_to_multiple: the nearest multiple to zero pad the input height and
width dimensions to. For EmbeddedSSD it must be set to 1. width dimensions to. For EmbeddedSSD it must be set to 1.
conv_hyperparams: tf slim arg_scope for conv2d and separable_conv2d ops. conv_hyperparams_fn: A function to construct tf slim arg_scope for conv2d
batch_norm_trainable: Whether to update batch norm parameters during and separable_conv2d ops.
training or not. When training with a small batch size
(e.g. 1), it is desirable to disable batch norm update and use
pretrained batch norm params.
reuse_weights: Whether to reuse variables. Default is None. reuse_weights: Whether to reuse variables. Default is None.
use_explicit_padding: Whether to use explicit padding when extracting use_explicit_padding: Whether to use explicit padding when extracting
features. Default is False. features. Default is False.
use_depthwise: Whether to use depthwise convolutions. Default is False. use_depthwise: Whether to use depthwise convolutions. Default is False.
inplace_batchnorm_update: Whether to update batch_norm inplace during
training. This is required for batch norm to work correctly on TPUs.
When this is false, user must add a control dependency on
tf.GraphKeys.UPDATE_OPS for train/loss op in order to update the batch
norm moving average parameters.
Raises: Raises:
ValueError: upon invalid `pad_to_multiple` values. ValueError: upon invalid `pad_to_multiple` values.
...@@ -87,10 +77,9 @@ class EmbeddedSSDMobileNetV1FeatureExtractor( ...@@ -87,10 +77,9 @@ class EmbeddedSSDMobileNetV1FeatureExtractor(
super(EmbeddedSSDMobileNetV1FeatureExtractor, self).__init__( super(EmbeddedSSDMobileNetV1FeatureExtractor, self).__init__(
is_training, depth_multiplier, min_depth, pad_to_multiple, is_training, depth_multiplier, min_depth, pad_to_multiple,
conv_hyperparams, batch_norm_trainable, reuse_weights, conv_hyperparams_fn, reuse_weights, use_explicit_padding, use_depthwise)
use_explicit_padding, use_depthwise, inplace_batchnorm_update)
def _extract_features(self, preprocessed_inputs): def extract_features(self, preprocessed_inputs):
"""Extract features from preprocessed inputs. """Extract features from preprocessed inputs.
Args: Args:
...@@ -130,7 +119,7 @@ class EmbeddedSSDMobileNetV1FeatureExtractor( ...@@ -130,7 +119,7 @@ class EmbeddedSSDMobileNetV1FeatureExtractor(
'use_depthwise': self._use_depthwise, 'use_depthwise': self._use_depthwise,
} }
with slim.arg_scope(self._conv_hyperparams): with slim.arg_scope(self._conv_hyperparams_fn()):
with slim.arg_scope([slim.batch_norm], fused=False): with slim.arg_scope([slim.batch_norm], fused=False):
with tf.variable_scope('MobilenetV1', with tf.variable_scope('MobilenetV1',
reuse=self._reuse_weights) as scope: reuse=self._reuse_weights) as scope:
......
...@@ -25,7 +25,7 @@ class EmbeddedSSDMobileNetV1FeatureExtractorTest( ...@@ -25,7 +25,7 @@ class EmbeddedSSDMobileNetV1FeatureExtractorTest(
ssd_feature_extractor_test.SsdFeatureExtractorTestBase): ssd_feature_extractor_test.SsdFeatureExtractorTestBase):
def _create_feature_extractor(self, depth_multiplier, pad_to_multiple, def _create_feature_extractor(self, depth_multiplier, pad_to_multiple,
is_training=True, batch_norm_trainable=True): is_training=True):
"""Constructs a new feature extractor. """Constructs a new feature extractor.
Args: Args:
...@@ -33,18 +33,15 @@ class EmbeddedSSDMobileNetV1FeatureExtractorTest( ...@@ -33,18 +33,15 @@ class EmbeddedSSDMobileNetV1FeatureExtractorTest(
pad_to_multiple: the nearest multiple to zero pad the input height and pad_to_multiple: the nearest multiple to zero pad the input height and
width dimensions to. width dimensions to.
is_training: whether the network is in training mode. is_training: whether the network is in training mode.
batch_norm_trainable: whether to update batch norm parameters during
training.
Returns: Returns:
an ssd_meta_arch.SSDFeatureExtractor object. an ssd_meta_arch.SSDFeatureExtractor object.
""" """
min_depth = 32 min_depth = 32
conv_hyperparams = {}
return (embedded_ssd_mobilenet_v1_feature_extractor. return (embedded_ssd_mobilenet_v1_feature_extractor.
EmbeddedSSDMobileNetV1FeatureExtractor( EmbeddedSSDMobileNetV1FeatureExtractor(
is_training, depth_multiplier, min_depth, pad_to_multiple, is_training, depth_multiplier, min_depth, pad_to_multiple,
conv_hyperparams, batch_norm_trainable)) self.conv_hyperparams_fn))
def test_extract_features_returns_correct_shapes_256(self): def test_extract_features_returns_correct_shapes_256(self):
image_height = 256 image_height = 256
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment