Unverified Commit ed4e22b8 authored by pkulzc's avatar pkulzc Committed by GitHub
Browse files

Merge pull request #3973 from pkulzc/master

Object detection internal changes
parents cac90a0e 13b89b93
...@@ -49,12 +49,12 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes): ...@@ -49,12 +49,12 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes):
if box_predictor_oneof == 'convolutional_box_predictor': if box_predictor_oneof == 'convolutional_box_predictor':
conv_box_predictor = box_predictor_config.convolutional_box_predictor conv_box_predictor = box_predictor_config.convolutional_box_predictor
conv_hyperparams = argscope_fn(conv_box_predictor.conv_hyperparams, conv_hyperparams_fn = argscope_fn(conv_box_predictor.conv_hyperparams,
is_training) is_training)
box_predictor_object = box_predictor.ConvolutionalBoxPredictor( box_predictor_object = box_predictor.ConvolutionalBoxPredictor(
is_training=is_training, is_training=is_training,
num_classes=num_classes, num_classes=num_classes,
conv_hyperparams=conv_hyperparams, conv_hyperparams_fn=conv_hyperparams_fn,
min_depth=conv_box_predictor.min_depth, min_depth=conv_box_predictor.min_depth,
max_depth=conv_box_predictor.max_depth, max_depth=conv_box_predictor.max_depth,
num_layers_before_predictor=(conv_box_predictor. num_layers_before_predictor=(conv_box_predictor.
...@@ -73,12 +73,12 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes): ...@@ -73,12 +73,12 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes):
if box_predictor_oneof == 'weight_shared_convolutional_box_predictor': if box_predictor_oneof == 'weight_shared_convolutional_box_predictor':
conv_box_predictor = (box_predictor_config. conv_box_predictor = (box_predictor_config.
weight_shared_convolutional_box_predictor) weight_shared_convolutional_box_predictor)
conv_hyperparams = argscope_fn(conv_box_predictor.conv_hyperparams, conv_hyperparams_fn = argscope_fn(conv_box_predictor.conv_hyperparams,
is_training) is_training)
box_predictor_object = box_predictor.WeightSharedConvolutionalBoxPredictor( box_predictor_object = box_predictor.WeightSharedConvolutionalBoxPredictor(
is_training=is_training, is_training=is_training,
num_classes=num_classes, num_classes=num_classes,
conv_hyperparams=conv_hyperparams, conv_hyperparams_fn=conv_hyperparams_fn,
depth=conv_box_predictor.depth, depth=conv_box_predictor.depth,
num_layers_before_predictor=(conv_box_predictor. num_layers_before_predictor=(conv_box_predictor.
num_layers_before_predictor), num_layers_before_predictor),
...@@ -90,20 +90,20 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes): ...@@ -90,20 +90,20 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes):
if box_predictor_oneof == 'mask_rcnn_box_predictor': if box_predictor_oneof == 'mask_rcnn_box_predictor':
mask_rcnn_box_predictor = box_predictor_config.mask_rcnn_box_predictor mask_rcnn_box_predictor = box_predictor_config.mask_rcnn_box_predictor
fc_hyperparams = argscope_fn(mask_rcnn_box_predictor.fc_hyperparams, fc_hyperparams_fn = argscope_fn(mask_rcnn_box_predictor.fc_hyperparams,
is_training) is_training)
conv_hyperparams = None conv_hyperparams_fn = None
if mask_rcnn_box_predictor.HasField('conv_hyperparams'): if mask_rcnn_box_predictor.HasField('conv_hyperparams'):
conv_hyperparams = argscope_fn(mask_rcnn_box_predictor.conv_hyperparams, conv_hyperparams_fn = argscope_fn(
is_training) mask_rcnn_box_predictor.conv_hyperparams, is_training)
box_predictor_object = box_predictor.MaskRCNNBoxPredictor( box_predictor_object = box_predictor.MaskRCNNBoxPredictor(
is_training=is_training, is_training=is_training,
num_classes=num_classes, num_classes=num_classes,
fc_hyperparams=fc_hyperparams, fc_hyperparams_fn=fc_hyperparams_fn,
use_dropout=mask_rcnn_box_predictor.use_dropout, use_dropout=mask_rcnn_box_predictor.use_dropout,
dropout_keep_prob=mask_rcnn_box_predictor.dropout_keep_probability, dropout_keep_prob=mask_rcnn_box_predictor.dropout_keep_probability,
box_code_size=mask_rcnn_box_predictor.box_code_size, box_code_size=mask_rcnn_box_predictor.box_code_size,
conv_hyperparams=conv_hyperparams, conv_hyperparams_fn=conv_hyperparams_fn,
predict_instance_masks=mask_rcnn_box_predictor.predict_instance_masks, predict_instance_masks=mask_rcnn_box_predictor.predict_instance_masks,
mask_height=mask_rcnn_box_predictor.mask_height, mask_height=mask_rcnn_box_predictor.mask_height,
mask_width=mask_rcnn_box_predictor.mask_width, mask_width=mask_rcnn_box_predictor.mask_width,
...@@ -111,17 +111,19 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes): ...@@ -111,17 +111,19 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes):
mask_rcnn_box_predictor.mask_prediction_num_conv_layers), mask_rcnn_box_predictor.mask_prediction_num_conv_layers),
mask_prediction_conv_depth=( mask_prediction_conv_depth=(
mask_rcnn_box_predictor.mask_prediction_conv_depth), mask_rcnn_box_predictor.mask_prediction_conv_depth),
masks_are_class_agnostic=(
mask_rcnn_box_predictor.masks_are_class_agnostic),
predict_keypoints=mask_rcnn_box_predictor.predict_keypoints) predict_keypoints=mask_rcnn_box_predictor.predict_keypoints)
return box_predictor_object return box_predictor_object
if box_predictor_oneof == 'rfcn_box_predictor': if box_predictor_oneof == 'rfcn_box_predictor':
rfcn_box_predictor = box_predictor_config.rfcn_box_predictor rfcn_box_predictor = box_predictor_config.rfcn_box_predictor
conv_hyperparams = argscope_fn(rfcn_box_predictor.conv_hyperparams, conv_hyperparams_fn = argscope_fn(rfcn_box_predictor.conv_hyperparams,
is_training) is_training)
box_predictor_object = box_predictor.RfcnBoxPredictor( box_predictor_object = box_predictor.RfcnBoxPredictor(
is_training=is_training, is_training=is_training,
num_classes=num_classes, num_classes=num_classes,
conv_hyperparams=conv_hyperparams, conv_hyperparams_fn=conv_hyperparams_fn,
crop_size=[rfcn_box_predictor.crop_height, crop_size=[rfcn_box_predictor.crop_height,
rfcn_box_predictor.crop_width], rfcn_box_predictor.crop_width],
num_spatial_bins=[rfcn_box_predictor.num_spatial_bins_height, num_spatial_bins=[rfcn_box_predictor.num_spatial_bins_height,
......
...@@ -54,7 +54,7 @@ class ConvolutionalBoxPredictorBuilderTest(tf.test.TestCase): ...@@ -54,7 +54,7 @@ class ConvolutionalBoxPredictorBuilderTest(tf.test.TestCase):
box_predictor_config=box_predictor_proto, box_predictor_config=box_predictor_proto,
is_training=False, is_training=False,
num_classes=10) num_classes=10)
(conv_hyperparams_actual, is_training) = box_predictor._conv_hyperparams (conv_hyperparams_actual, is_training) = box_predictor._conv_hyperparams_fn
self.assertAlmostEqual((hyperparams_proto.regularizer. self.assertAlmostEqual((hyperparams_proto.regularizer.
l1_regularizer.weight), l1_regularizer.weight),
(conv_hyperparams_actual.regularizer.l1_regularizer. (conv_hyperparams_actual.regularizer.l1_regularizer.
...@@ -183,7 +183,7 @@ class WeightSharedConvolutionalBoxPredictorBuilderTest(tf.test.TestCase): ...@@ -183,7 +183,7 @@ class WeightSharedConvolutionalBoxPredictorBuilderTest(tf.test.TestCase):
box_predictor_config=box_predictor_proto, box_predictor_config=box_predictor_proto,
is_training=False, is_training=False,
num_classes=10) num_classes=10)
(conv_hyperparams_actual, is_training) = box_predictor._conv_hyperparams (conv_hyperparams_actual, is_training) = box_predictor._conv_hyperparams_fn
self.assertAlmostEqual((hyperparams_proto.regularizer. self.assertAlmostEqual((hyperparams_proto.regularizer.
l1_regularizer.weight), l1_regularizer.weight),
(conv_hyperparams_actual.regularizer.l1_regularizer. (conv_hyperparams_actual.regularizer.l1_regularizer.
...@@ -297,7 +297,7 @@ class MaskRCNNBoxPredictorBuilderTest(tf.test.TestCase): ...@@ -297,7 +297,7 @@ class MaskRCNNBoxPredictorBuilderTest(tf.test.TestCase):
is_training=False, is_training=False,
num_classes=10) num_classes=10)
mock_argscope_fn.assert_called_with(hyperparams_proto, False) mock_argscope_fn.assert_called_with(hyperparams_proto, False)
self.assertEqual(box_predictor._fc_hyperparams, 'arg_scope') self.assertEqual(box_predictor._fc_hyperparams_fn, 'arg_scope')
def test_non_default_mask_rcnn_box_predictor(self): def test_non_default_mask_rcnn_box_predictor(self):
fc_hyperparams_text_proto = """ fc_hyperparams_text_proto = """
...@@ -417,7 +417,7 @@ class RfcnBoxPredictorBuilderTest(tf.test.TestCase): ...@@ -417,7 +417,7 @@ class RfcnBoxPredictorBuilderTest(tf.test.TestCase):
box_predictor_config=box_predictor_proto, box_predictor_config=box_predictor_proto,
is_training=False, is_training=False,
num_classes=10) num_classes=10)
(conv_hyperparams_actual, is_training) = box_predictor._conv_hyperparams (conv_hyperparams_actual, is_training) = box_predictor._conv_hyperparams_fn
self.assertAlmostEqual((hyperparams_proto.regularizer. self.assertAlmostEqual((hyperparams_proto.regularizer.
l1_regularizer.weight), l1_regularizer.weight),
(conv_hyperparams_actual.regularizer.l1_regularizer. (conv_hyperparams_actual.regularizer.l1_regularizer.
......
...@@ -72,7 +72,9 @@ def _get_padding_shapes(dataset, max_num_boxes=None, num_classes=None, ...@@ -72,7 +72,9 @@ def _get_padding_shapes(dataset, max_num_boxes=None, num_classes=None,
fields.InputDataFields.num_groundtruth_boxes: [], fields.InputDataFields.num_groundtruth_boxes: [],
fields.InputDataFields.groundtruth_label_types: [max_num_boxes], fields.InputDataFields.groundtruth_label_types: [max_num_boxes],
fields.InputDataFields.groundtruth_label_scores: [max_num_boxes], fields.InputDataFields.groundtruth_label_scores: [max_num_boxes],
fields.InputDataFields.true_image_shape: [3] fields.InputDataFields.true_image_shape: [3],
fields.InputDataFields.multiclass_scores: [
max_num_boxes, num_classes + 1 if num_classes is not None else None],
} }
# Determine whether groundtruth_classes are integers or one-hot encodings, and # Determine whether groundtruth_classes are integers or one-hot encodings, and
# apply batching appropriately. # apply batching appropriately.
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import tensorflow as tf import tensorflow as tf
from object_detection.protos import hyperparams_pb2 from object_detection.protos import hyperparams_pb2
from object_detection.utils import context_manager
slim = tf.contrib.slim slim = tf.contrib.slim
...@@ -43,7 +44,8 @@ def build(hyperparams_config, is_training): ...@@ -43,7 +44,8 @@ def build(hyperparams_config, is_training):
is_training: Whether the network is in training mode. is_training: Whether the network is in training mode.
Returns: Returns:
arg_scope: tf-slim arg_scope containing hyperparameters for ops. arg_scope_fn: A function to construct tf-slim arg_scope containing
hyperparameters for ops.
Raises: Raises:
ValueError: if hyperparams_config is not of type hyperparams.Hyperparams. ValueError: if hyperparams_config is not of type hyperparams.Hyperparams.
...@@ -64,6 +66,10 @@ def build(hyperparams_config, is_training): ...@@ -64,6 +66,10 @@ def build(hyperparams_config, is_training):
if hyperparams_config.HasField('op') and ( if hyperparams_config.HasField('op') and (
hyperparams_config.op == hyperparams_pb2.Hyperparams.FC): hyperparams_config.op == hyperparams_pb2.Hyperparams.FC):
affected_ops = [slim.fully_connected] affected_ops = [slim.fully_connected]
def scope_fn():
with (slim.arg_scope([slim.batch_norm], **batch_norm_params)
if batch_norm_params is not None else
context_manager.IdentityContextManager()):
with slim.arg_scope( with slim.arg_scope(
affected_ops, affected_ops,
weights_regularizer=_build_regularizer( weights_regularizer=_build_regularizer(
...@@ -71,10 +77,11 @@ def build(hyperparams_config, is_training): ...@@ -71,10 +77,11 @@ def build(hyperparams_config, is_training):
weights_initializer=_build_initializer( weights_initializer=_build_initializer(
hyperparams_config.initializer), hyperparams_config.initializer),
activation_fn=_build_activation_fn(hyperparams_config.activation), activation_fn=_build_activation_fn(hyperparams_config.activation),
normalizer_fn=batch_norm, normalizer_fn=batch_norm) as sc:
normalizer_params=batch_norm_params) as sc:
return sc return sc
return scope_fn
def _build_activation_fn(activation_fn): def _build_activation_fn(activation_fn):
"""Builds a callable activation from config. """Builds a callable activation from config.
...@@ -167,6 +174,9 @@ def _build_batch_norm_params(batch_norm, is_training): ...@@ -167,6 +174,9 @@ def _build_batch_norm_params(batch_norm, is_training):
'center': batch_norm.center, 'center': batch_norm.center,
'scale': batch_norm.scale, 'scale': batch_norm.scale,
'epsilon': batch_norm.epsilon, 'epsilon': batch_norm.epsilon,
# Remove is_training parameter from here and deprecate it in the proto
# once we refactor Faster RCNN models to set is_training through an outer
# arg_scope in the meta architecture.
'is_training': is_training and batch_norm.train, 'is_training': is_training and batch_norm.train,
} }
return batch_norm_params return batch_norm_params
...@@ -26,12 +26,12 @@ from object_detection.protos import hyperparams_pb2 ...@@ -26,12 +26,12 @@ from object_detection.protos import hyperparams_pb2
slim = tf.contrib.slim slim = tf.contrib.slim
class HyperparamsBuilderTest(tf.test.TestCase): def _get_scope_key(op):
# TODO(rathodv): Make this a public api in slim arg_scope.py.
def _get_scope_key(self, op):
return getattr(op, '_key_op', str(op)) return getattr(op, '_key_op', str(op))
class HyperparamsBuilderTest(tf.test.TestCase):
def test_default_arg_scope_has_conv2d_op(self): def test_default_arg_scope_has_conv2d_op(self):
conv_hyperparams_text_proto = """ conv_hyperparams_text_proto = """
regularizer { regularizer {
...@@ -45,8 +45,10 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -45,8 +45,10 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
self.assertTrue(self._get_scope_key(slim.conv2d) in scope) is_training=True)
scope = scope_fn()
self.assertTrue(_get_scope_key(slim.conv2d) in scope)
def test_default_arg_scope_has_separable_conv2d_op(self): def test_default_arg_scope_has_separable_conv2d_op(self):
conv_hyperparams_text_proto = """ conv_hyperparams_text_proto = """
...@@ -61,8 +63,10 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -61,8 +63,10 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
self.assertTrue(self._get_scope_key(slim.separable_conv2d) in scope) is_training=True)
scope = scope_fn()
self.assertTrue(_get_scope_key(slim.separable_conv2d) in scope)
def test_default_arg_scope_has_conv2d_transpose_op(self): def test_default_arg_scope_has_conv2d_transpose_op(self):
conv_hyperparams_text_proto = """ conv_hyperparams_text_proto = """
...@@ -77,8 +81,10 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -77,8 +81,10 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
self.assertTrue(self._get_scope_key(slim.conv2d_transpose) in scope) is_training=True)
scope = scope_fn()
self.assertTrue(_get_scope_key(slim.conv2d_transpose) in scope)
def test_explicit_fc_op_arg_scope_has_fully_connected_op(self): def test_explicit_fc_op_arg_scope_has_fully_connected_op(self):
conv_hyperparams_text_proto = """ conv_hyperparams_text_proto = """
...@@ -94,8 +100,10 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -94,8 +100,10 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
self.assertTrue(self._get_scope_key(slim.fully_connected) in scope) is_training=True)
scope = scope_fn()
self.assertTrue(_get_scope_key(slim.fully_connected) in scope)
def test_separable_conv2d_and_conv2d_and_transpose_have_same_parameters(self): def test_separable_conv2d_and_conv2d_and_transpose_have_same_parameters(self):
conv_hyperparams_text_proto = """ conv_hyperparams_text_proto = """
...@@ -110,7 +118,9 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -110,7 +118,9 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
kwargs_1, kwargs_2, kwargs_3 = scope.values() kwargs_1, kwargs_2, kwargs_3 = scope.values()
self.assertDictEqual(kwargs_1, kwargs_2) self.assertDictEqual(kwargs_1, kwargs_2)
self.assertDictEqual(kwargs_1, kwargs_3) self.assertDictEqual(kwargs_1, kwargs_3)
...@@ -129,7 +139,9 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -129,7 +139,9 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope.values()[0] conv_scope_arguments = scope.values()[0]
regularizer = conv_scope_arguments['weights_regularizer'] regularizer = conv_scope_arguments['weights_regularizer']
weights = np.array([1., -1, 4., 2.]) weights = np.array([1., -1, 4., 2.])
...@@ -151,8 +163,10 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -151,8 +163,10 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
conv_scope_arguments = scope.values()[0] is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
regularizer = conv_scope_arguments['weights_regularizer'] regularizer = conv_scope_arguments['weights_regularizer']
weights = np.array([1., -1, 4., 2.]) weights = np.array([1., -1, 4., 2.])
...@@ -180,10 +194,12 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -180,10 +194,12 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
conv_scope_arguments = scope.values()[0] is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
self.assertEqual(conv_scope_arguments['normalizer_fn'], slim.batch_norm) self.assertEqual(conv_scope_arguments['normalizer_fn'], slim.batch_norm)
batch_norm_params = conv_scope_arguments['normalizer_params'] batch_norm_params = scope[_get_scope_key(slim.batch_norm)]
self.assertAlmostEqual(batch_norm_params['decay'], 0.7) self.assertAlmostEqual(batch_norm_params['decay'], 0.7)
self.assertAlmostEqual(batch_norm_params['epsilon'], 0.03) self.assertAlmostEqual(batch_norm_params['epsilon'], 0.03)
self.assertFalse(batch_norm_params['center']) self.assertFalse(batch_norm_params['center'])
...@@ -210,10 +226,12 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -210,10 +226,12 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=False) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
conv_scope_arguments = scope.values()[0] is_training=False)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
self.assertEqual(conv_scope_arguments['normalizer_fn'], slim.batch_norm) self.assertEqual(conv_scope_arguments['normalizer_fn'], slim.batch_norm)
batch_norm_params = conv_scope_arguments['normalizer_params'] batch_norm_params = scope[_get_scope_key(slim.batch_norm)]
self.assertAlmostEqual(batch_norm_params['decay'], 0.7) self.assertAlmostEqual(batch_norm_params['decay'], 0.7)
self.assertAlmostEqual(batch_norm_params['epsilon'], 0.03) self.assertAlmostEqual(batch_norm_params['epsilon'], 0.03)
self.assertFalse(batch_norm_params['center']) self.assertFalse(batch_norm_params['center'])
...@@ -240,10 +258,12 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -240,10 +258,12 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
conv_scope_arguments = scope.values()[0] is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
self.assertEqual(conv_scope_arguments['normalizer_fn'], slim.batch_norm) self.assertEqual(conv_scope_arguments['normalizer_fn'], slim.batch_norm)
batch_norm_params = conv_scope_arguments['normalizer_params'] batch_norm_params = scope[_get_scope_key(slim.batch_norm)]
self.assertAlmostEqual(batch_norm_params['decay'], 0.7) self.assertAlmostEqual(batch_norm_params['decay'], 0.7)
self.assertAlmostEqual(batch_norm_params['epsilon'], 0.03) self.assertAlmostEqual(batch_norm_params['epsilon'], 0.03)
self.assertFalse(batch_norm_params['center']) self.assertFalse(batch_norm_params['center'])
...@@ -263,10 +283,11 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -263,10 +283,11 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
conv_scope_arguments = scope.values()[0] is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
self.assertEqual(conv_scope_arguments['normalizer_fn'], None) self.assertEqual(conv_scope_arguments['normalizer_fn'], None)
self.assertEqual(conv_scope_arguments['normalizer_params'], None)
def test_use_none_activation(self): def test_use_none_activation(self):
conv_hyperparams_text_proto = """ conv_hyperparams_text_proto = """
...@@ -282,8 +303,10 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -282,8 +303,10 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
conv_scope_arguments = scope.values()[0] is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
self.assertEqual(conv_scope_arguments['activation_fn'], None) self.assertEqual(conv_scope_arguments['activation_fn'], None)
def test_use_relu_activation(self): def test_use_relu_activation(self):
...@@ -300,8 +323,10 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -300,8 +323,10 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
conv_scope_arguments = scope.values()[0] is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
self.assertEqual(conv_scope_arguments['activation_fn'], tf.nn.relu) self.assertEqual(conv_scope_arguments['activation_fn'], tf.nn.relu)
def test_use_relu_6_activation(self): def test_use_relu_6_activation(self):
...@@ -318,8 +343,10 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -318,8 +343,10 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
conv_scope_arguments = scope.values()[0] is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
self.assertEqual(conv_scope_arguments['activation_fn'], tf.nn.relu6) self.assertEqual(conv_scope_arguments['activation_fn'], tf.nn.relu6)
def _assert_variance_in_range(self, initializer, shape, variance, def _assert_variance_in_range(self, initializer, shape, variance,
...@@ -351,8 +378,10 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -351,8 +378,10 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
conv_scope_arguments = scope.values()[0] is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
initializer = conv_scope_arguments['weights_initializer'] initializer = conv_scope_arguments['weights_initializer']
self._assert_variance_in_range(initializer, shape=[100, 40], self._assert_variance_in_range(initializer, shape=[100, 40],
variance=2. / 100.) variance=2. / 100.)
...@@ -373,8 +402,10 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -373,8 +402,10 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
conv_scope_arguments = scope.values()[0] is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
initializer = conv_scope_arguments['weights_initializer'] initializer = conv_scope_arguments['weights_initializer']
self._assert_variance_in_range(initializer, shape=[100, 40], self._assert_variance_in_range(initializer, shape=[100, 40],
variance=2. / 40.) variance=2. / 40.)
...@@ -395,8 +426,10 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -395,8 +426,10 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
conv_scope_arguments = scope.values()[0] is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
initializer = conv_scope_arguments['weights_initializer'] initializer = conv_scope_arguments['weights_initializer']
self._assert_variance_in_range(initializer, shape=[100, 40], self._assert_variance_in_range(initializer, shape=[100, 40],
variance=4. / (100. + 40.)) variance=4. / (100. + 40.))
...@@ -417,8 +450,10 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -417,8 +450,10 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
conv_scope_arguments = scope.values()[0] is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
initializer = conv_scope_arguments['weights_initializer'] initializer = conv_scope_arguments['weights_initializer']
self._assert_variance_in_range(initializer, shape=[100, 40], self._assert_variance_in_range(initializer, shape=[100, 40],
variance=2. / 100.) variance=2. / 100.)
...@@ -438,8 +473,10 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -438,8 +473,10 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
conv_scope_arguments = scope.values()[0] is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
initializer = conv_scope_arguments['weights_initializer'] initializer = conv_scope_arguments['weights_initializer']
self._assert_variance_in_range(initializer, shape=[100, 40], self._assert_variance_in_range(initializer, shape=[100, 40],
variance=0.49, tol=1e-1) variance=0.49, tol=1e-1)
...@@ -459,8 +496,10 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -459,8 +496,10 @@ class HyperparamsBuilderTest(tf.test.TestCase):
""" """
conv_hyperparams_proto = hyperparams_pb2.Hyperparams() conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto) text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True) scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
conv_scope_arguments = scope.values()[0] is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
initializer = conv_scope_arguments['weights_initializer'] initializer = conv_scope_arguments['weights_initializer']
self._assert_variance_in_range(initializer, shape=[100, 40], self._assert_variance_in_range(initializer, shape=[100, 40],
variance=0.64, tol=1e-1) variance=0.64, tol=1e-1)
......
...@@ -71,7 +71,8 @@ FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP = { ...@@ -71,7 +71,8 @@ FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP = {
} }
def build(model_config, is_training, add_summaries=True): def build(model_config, is_training, add_summaries=True,
add_background_class=True):
"""Builds a DetectionModel based on the model config. """Builds a DetectionModel based on the model config.
Args: Args:
...@@ -79,7 +80,10 @@ def build(model_config, is_training, add_summaries=True): ...@@ -79,7 +80,10 @@ def build(model_config, is_training, add_summaries=True):
DetectionModel. DetectionModel.
is_training: True if this model is being built for training purposes. is_training: True if this model is being built for training purposes.
add_summaries: Whether to add tensorflow summaries in the model graph. add_summaries: Whether to add tensorflow summaries in the model graph.
add_background_class: Whether to add an implicit background class to one-hot
encodings of groundtruth labels. Set to false if using groundtruth labels
with an explicit background class or using multiclass scores instead of
truth in the case of distillation. Ignored in the case of faster_rcnn.
Returns: Returns:
DetectionModel based on the config. DetectionModel based on the config.
...@@ -90,7 +94,8 @@ def build(model_config, is_training, add_summaries=True): ...@@ -90,7 +94,8 @@ def build(model_config, is_training, add_summaries=True):
raise ValueError('model_config not of type model_pb2.DetectionModel.') raise ValueError('model_config not of type model_pb2.DetectionModel.')
meta_architecture = model_config.WhichOneof('model') meta_architecture = model_config.WhichOneof('model')
if meta_architecture == 'ssd': if meta_architecture == 'ssd':
return _build_ssd_model(model_config.ssd, is_training, add_summaries) return _build_ssd_model(model_config.ssd, is_training, add_summaries,
add_background_class)
if meta_architecture == 'faster_rcnn': if meta_architecture == 'faster_rcnn':
return _build_faster_rcnn_model(model_config.faster_rcnn, is_training, return _build_faster_rcnn_model(model_config.faster_rcnn, is_training,
add_summaries) add_summaries)
...@@ -98,19 +103,13 @@ def build(model_config, is_training, add_summaries=True): ...@@ -98,19 +103,13 @@ def build(model_config, is_training, add_summaries=True):
def _build_ssd_feature_extractor(feature_extractor_config, is_training, def _build_ssd_feature_extractor(feature_extractor_config, is_training,
reuse_weights=None, reuse_weights=None):
inplace_batchnorm_update=False):
"""Builds a ssd_meta_arch.SSDFeatureExtractor based on config. """Builds a ssd_meta_arch.SSDFeatureExtractor based on config.
Args: Args:
feature_extractor_config: A SSDFeatureExtractor proto config from ssd.proto. feature_extractor_config: A SSDFeatureExtractor proto config from ssd.proto.
is_training: True if this feature extractor is being built for training. is_training: True if this feature extractor is being built for training.
reuse_weights: if the feature extractor should reuse weights. reuse_weights: if the feature extractor should reuse weights.
inplace_batchnorm_update: Whether to update batch_norm inplace during
training. This is required for batch norm to work correctly on TPUs. When
this is false, user must add a control dependency on
tf.GraphKeys.UPDATE_OPS for train/loss op in order to update the batch
norm moving average parameters.
Returns: Returns:
ssd_meta_arch.SSDFeatureExtractor based on config. ssd_meta_arch.SSDFeatureExtractor based on config.
...@@ -122,24 +121,25 @@ def _build_ssd_feature_extractor(feature_extractor_config, is_training, ...@@ -122,24 +121,25 @@ def _build_ssd_feature_extractor(feature_extractor_config, is_training,
depth_multiplier = feature_extractor_config.depth_multiplier depth_multiplier = feature_extractor_config.depth_multiplier
min_depth = feature_extractor_config.min_depth min_depth = feature_extractor_config.min_depth
pad_to_multiple = feature_extractor_config.pad_to_multiple pad_to_multiple = feature_extractor_config.pad_to_multiple
batch_norm_trainable = feature_extractor_config.batch_norm_trainable
use_explicit_padding = feature_extractor_config.use_explicit_padding use_explicit_padding = feature_extractor_config.use_explicit_padding
use_depthwise = feature_extractor_config.use_depthwise use_depthwise = feature_extractor_config.use_depthwise
conv_hyperparams = hyperparams_builder.build( conv_hyperparams = hyperparams_builder.build(
feature_extractor_config.conv_hyperparams, is_training) feature_extractor_config.conv_hyperparams, is_training)
override_base_feature_extractor_hyperparams = (
feature_extractor_config.override_base_feature_extractor_hyperparams)
if feature_type not in SSD_FEATURE_EXTRACTOR_CLASS_MAP: if feature_type not in SSD_FEATURE_EXTRACTOR_CLASS_MAP:
raise ValueError('Unknown ssd feature_extractor: {}'.format(feature_type)) raise ValueError('Unknown ssd feature_extractor: {}'.format(feature_type))
feature_extractor_class = SSD_FEATURE_EXTRACTOR_CLASS_MAP[feature_type] feature_extractor_class = SSD_FEATURE_EXTRACTOR_CLASS_MAP[feature_type]
return feature_extractor_class(is_training, depth_multiplier, min_depth, return feature_extractor_class(
pad_to_multiple, conv_hyperparams, is_training, depth_multiplier, min_depth, pad_to_multiple,
batch_norm_trainable, reuse_weights, conv_hyperparams, reuse_weights, use_explicit_padding, use_depthwise,
use_explicit_padding, use_depthwise, override_base_feature_extractor_hyperparams)
inplace_batchnorm_update)
def _build_ssd_model(ssd_config, is_training, add_summaries): def _build_ssd_model(ssd_config, is_training, add_summaries,
add_background_class=True):
"""Builds an SSD detection model based on the model config. """Builds an SSD detection model based on the model config.
Args: Args:
...@@ -147,7 +147,10 @@ def _build_ssd_model(ssd_config, is_training, add_summaries): ...@@ -147,7 +147,10 @@ def _build_ssd_model(ssd_config, is_training, add_summaries):
SSDMetaArch. SSDMetaArch.
is_training: True if this model is being built for training purposes. is_training: True if this model is being built for training purposes.
add_summaries: Whether to add tf summaries in the model. add_summaries: Whether to add tf summaries in the model.
add_background_class: Whether to add an implicit background class to one-hot
encodings of groundtruth labels. Set to false if using groundtruth labels
with an explicit background class or using multiclass scores instead of
truth in the case of distillation.
Returns: Returns:
SSDMetaArch based on the config. SSDMetaArch based on the config.
...@@ -160,8 +163,7 @@ def _build_ssd_model(ssd_config, is_training, add_summaries): ...@@ -160,8 +163,7 @@ def _build_ssd_model(ssd_config, is_training, add_summaries):
# Feature extractor # Feature extractor
feature_extractor = _build_ssd_feature_extractor( feature_extractor = _build_ssd_feature_extractor(
feature_extractor_config=ssd_config.feature_extractor, feature_extractor_config=ssd_config.feature_extractor,
is_training=is_training, is_training=is_training)
inplace_batchnorm_update=ssd_config.inplace_batchnorm_update)
box_coder = box_coder_builder.build(ssd_config.box_coder) box_coder = box_coder_builder.build(ssd_config.box_coder)
matcher = matcher_builder.build(ssd_config.matcher) matcher = matcher_builder.build(ssd_config.matcher)
...@@ -203,7 +205,10 @@ def _build_ssd_model(ssd_config, is_training, add_summaries): ...@@ -203,7 +205,10 @@ def _build_ssd_model(ssd_config, is_training, add_summaries):
normalize_loss_by_num_matches, normalize_loss_by_num_matches,
hard_example_miner, hard_example_miner,
add_summaries=add_summaries, add_summaries=add_summaries,
normalize_loc_loss_by_codesize=normalize_loc_loss_by_codesize) normalize_loc_loss_by_codesize=normalize_loc_loss_by_codesize,
freeze_batchnorm=ssd_config.freeze_batchnorm,
inplace_batchnorm_update=ssd_config.inplace_batchnorm_update,
add_background_class=add_background_class)
def _build_faster_rcnn_feature_extractor( def _build_faster_rcnn_feature_extractor(
...@@ -276,7 +281,7 @@ def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries): ...@@ -276,7 +281,7 @@ def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries):
frcnn_config.first_stage_anchor_generator) frcnn_config.first_stage_anchor_generator)
first_stage_atrous_rate = frcnn_config.first_stage_atrous_rate first_stage_atrous_rate = frcnn_config.first_stage_atrous_rate
first_stage_box_predictor_arg_scope = hyperparams_builder.build( first_stage_box_predictor_arg_scope_fn = hyperparams_builder.build(
frcnn_config.first_stage_box_predictor_conv_hyperparams, is_training) frcnn_config.first_stage_box_predictor_conv_hyperparams, is_training)
first_stage_box_predictor_kernel_size = ( first_stage_box_predictor_kernel_size = (
frcnn_config.first_stage_box_predictor_kernel_size) frcnn_config.first_stage_box_predictor_kernel_size)
...@@ -329,8 +334,8 @@ def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries): ...@@ -329,8 +334,8 @@ def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries):
'number_of_stages': number_of_stages, 'number_of_stages': number_of_stages,
'first_stage_anchor_generator': first_stage_anchor_generator, 'first_stage_anchor_generator': first_stage_anchor_generator,
'first_stage_atrous_rate': first_stage_atrous_rate, 'first_stage_atrous_rate': first_stage_atrous_rate,
'first_stage_box_predictor_arg_scope': 'first_stage_box_predictor_arg_scope_fn':
first_stage_box_predictor_arg_scope, first_stage_box_predictor_arg_scope_fn,
'first_stage_box_predictor_kernel_size': 'first_stage_box_predictor_kernel_size':
first_stage_box_predictor_kernel_size, first_stage_box_predictor_kernel_size,
'first_stage_box_predictor_depth': first_stage_box_predictor_depth, 'first_stage_box_predictor_depth': first_stage_box_predictor_depth,
......
...@@ -83,6 +83,7 @@ class ModelBuilderTest(tf.test.TestCase): ...@@ -83,6 +83,7 @@ class ModelBuilderTest(tf.test.TestCase):
} }
} }
} }
override_base_feature_extractor_hyperparams: true
} }
box_coder { box_coder {
faster_rcnn_box_coder { faster_rcnn_box_coder {
...@@ -154,6 +155,7 @@ class ModelBuilderTest(tf.test.TestCase): ...@@ -154,6 +155,7 @@ class ModelBuilderTest(tf.test.TestCase):
} }
} }
} }
override_base_feature_extractor_hyperparams: true
} }
box_coder { box_coder {
faster_rcnn_box_coder { faster_rcnn_box_coder {
...@@ -225,7 +227,6 @@ class ModelBuilderTest(tf.test.TestCase): ...@@ -225,7 +227,6 @@ class ModelBuilderTest(tf.test.TestCase):
} }
} }
} }
batch_norm_trainable: true
} }
box_coder { box_coder {
faster_rcnn_box_coder { faster_rcnn_box_coder {
...@@ -298,6 +299,7 @@ class ModelBuilderTest(tf.test.TestCase): ...@@ -298,6 +299,7 @@ class ModelBuilderTest(tf.test.TestCase):
def test_create_ssd_mobilenet_v1_model_from_config(self): def test_create_ssd_mobilenet_v1_model_from_config(self):
model_text_proto = """ model_text_proto = """
ssd { ssd {
freeze_batchnorm: true
inplace_batchnorm_update: true inplace_batchnorm_update: true
feature_extractor { feature_extractor {
type: 'ssd_mobilenet_v1' type: 'ssd_mobilenet_v1'
...@@ -311,7 +313,6 @@ class ModelBuilderTest(tf.test.TestCase): ...@@ -311,7 +313,6 @@ class ModelBuilderTest(tf.test.TestCase):
} }
} }
} }
batch_norm_trainable: true
} }
box_coder { box_coder {
faster_rcnn_box_coder { faster_rcnn_box_coder {
...@@ -368,8 +369,9 @@ class ModelBuilderTest(tf.test.TestCase): ...@@ -368,8 +369,9 @@ class ModelBuilderTest(tf.test.TestCase):
self.assertIsInstance(model, ssd_meta_arch.SSDMetaArch) self.assertIsInstance(model, ssd_meta_arch.SSDMetaArch)
self.assertIsInstance(model._feature_extractor, self.assertIsInstance(model._feature_extractor,
SSDMobileNetV1FeatureExtractor) SSDMobileNetV1FeatureExtractor)
self.assertTrue(model._feature_extractor._batch_norm_trainable)
self.assertTrue(model._normalize_loc_loss_by_codesize) self.assertTrue(model._normalize_loc_loss_by_codesize)
self.assertTrue(model._freeze_batchnorm)
self.assertTrue(model._inplace_batchnorm_update)
def test_create_ssd_mobilenet_v2_model_from_config(self): def test_create_ssd_mobilenet_v2_model_from_config(self):
model_text_proto = """ model_text_proto = """
...@@ -386,7 +388,6 @@ class ModelBuilderTest(tf.test.TestCase): ...@@ -386,7 +388,6 @@ class ModelBuilderTest(tf.test.TestCase):
} }
} }
} }
batch_norm_trainable: true
} }
box_coder { box_coder {
faster_rcnn_box_coder { faster_rcnn_box_coder {
...@@ -443,7 +444,6 @@ class ModelBuilderTest(tf.test.TestCase): ...@@ -443,7 +444,6 @@ class ModelBuilderTest(tf.test.TestCase):
self.assertIsInstance(model, ssd_meta_arch.SSDMetaArch) self.assertIsInstance(model, ssd_meta_arch.SSDMetaArch)
self.assertIsInstance(model._feature_extractor, self.assertIsInstance(model._feature_extractor,
SSDMobileNetV2FeatureExtractor) SSDMobileNetV2FeatureExtractor)
self.assertTrue(model._feature_extractor._batch_norm_trainable)
self.assertTrue(model._normalize_loc_loss_by_codesize) self.assertTrue(model._normalize_loc_loss_by_codesize)
def test_create_embedded_ssd_mobilenet_v1_model_from_config(self): def test_create_embedded_ssd_mobilenet_v1_model_from_config(self):
...@@ -461,7 +461,6 @@ class ModelBuilderTest(tf.test.TestCase): ...@@ -461,7 +461,6 @@ class ModelBuilderTest(tf.test.TestCase):
} }
} }
} }
batch_norm_trainable: true
} }
box_coder { box_coder {
faster_rcnn_box_coder { faster_rcnn_box_coder {
......
...@@ -147,7 +147,7 @@ class RfcnBoxPredictor(BoxPredictor): ...@@ -147,7 +147,7 @@ class RfcnBoxPredictor(BoxPredictor):
def __init__(self, def __init__(self,
is_training, is_training,
num_classes, num_classes,
conv_hyperparams, conv_hyperparams_fn,
num_spatial_bins, num_spatial_bins,
depth, depth,
crop_size, crop_size,
...@@ -160,8 +160,8 @@ class RfcnBoxPredictor(BoxPredictor): ...@@ -160,8 +160,8 @@ class RfcnBoxPredictor(BoxPredictor):
include the background category, so if groundtruth labels take values include the background category, so if groundtruth labels take values
in {0, 1, .., K-1}, num_classes=K (and not K+1, even though the in {0, 1, .., K-1}, num_classes=K (and not K+1, even though the
assigned classification targets can range from {0,... K}). assigned classification targets can range from {0,... K}).
conv_hyperparams: Slim arg_scope with hyperparameters for conolutional conv_hyperparams_fn: A function to construct tf-slim arg_scope with
layers. hyperparameters for convolutional layers.
num_spatial_bins: A list of two integers `[spatial_bins_y, num_spatial_bins: A list of two integers `[spatial_bins_y,
spatial_bins_x]`. spatial_bins_x]`.
depth: Target depth to reduce the input feature maps to. depth: Target depth to reduce the input feature maps to.
...@@ -169,7 +169,7 @@ class RfcnBoxPredictor(BoxPredictor): ...@@ -169,7 +169,7 @@ class RfcnBoxPredictor(BoxPredictor):
box_code_size: Size of encoding for each box. box_code_size: Size of encoding for each box.
""" """
super(RfcnBoxPredictor, self).__init__(is_training, num_classes) super(RfcnBoxPredictor, self).__init__(is_training, num_classes)
self._conv_hyperparams = conv_hyperparams self._conv_hyperparams_fn = conv_hyperparams_fn
self._num_spatial_bins = num_spatial_bins self._num_spatial_bins = num_spatial_bins
self._depth = depth self._depth = depth
self._crop_size = crop_size self._crop_size = crop_size
...@@ -227,7 +227,7 @@ class RfcnBoxPredictor(BoxPredictor): ...@@ -227,7 +227,7 @@ class RfcnBoxPredictor(BoxPredictor):
return tf.reshape(ones_mat * multiplier, [-1]) return tf.reshape(ones_mat * multiplier, [-1])
net = image_feature net = image_feature
with slim.arg_scope(self._conv_hyperparams): with slim.arg_scope(self._conv_hyperparams_fn()):
net = slim.conv2d(net, self._depth, [1, 1], scope='reduce_depth') net = slim.conv2d(net, self._depth, [1, 1], scope='reduce_depth')
# Location predictions. # Location predictions.
location_feature_map_depth = (self._num_spatial_bins[0] * location_feature_map_depth = (self._num_spatial_bins[0] *
...@@ -297,16 +297,17 @@ class MaskRCNNBoxPredictor(BoxPredictor): ...@@ -297,16 +297,17 @@ class MaskRCNNBoxPredictor(BoxPredictor):
def __init__(self, def __init__(self,
is_training, is_training,
num_classes, num_classes,
fc_hyperparams, fc_hyperparams_fn,
use_dropout, use_dropout,
dropout_keep_prob, dropout_keep_prob,
box_code_size, box_code_size,
conv_hyperparams=None, conv_hyperparams_fn=None,
predict_instance_masks=False, predict_instance_masks=False,
mask_height=14, mask_height=14,
mask_width=14, mask_width=14,
mask_prediction_num_conv_layers=2, mask_prediction_num_conv_layers=2,
mask_prediction_conv_depth=256, mask_prediction_conv_depth=256,
masks_are_class_agnostic=False,
predict_keypoints=False): predict_keypoints=False):
"""Constructor. """Constructor.
...@@ -316,16 +317,16 @@ class MaskRCNNBoxPredictor(BoxPredictor): ...@@ -316,16 +317,16 @@ class MaskRCNNBoxPredictor(BoxPredictor):
include the background category, so if groundtruth labels take values include the background category, so if groundtruth labels take values
in {0, 1, .., K-1}, num_classes=K (and not K+1, even though the in {0, 1, .., K-1}, num_classes=K (and not K+1, even though the
assigned classification targets can range from {0,... K}). assigned classification targets can range from {0,... K}).
fc_hyperparams: Slim arg_scope with hyperparameters for fully fc_hyperparams_fn: A function to generate tf-slim arg_scope with
connected ops. hyperparameters for fully connected ops.
use_dropout: Option to use dropout or not. Note that a single dropout use_dropout: Option to use dropout or not. Note that a single dropout
op is applied here prior to both box and class predictions, which stands op is applied here prior to both box and class predictions, which stands
in contrast to the ConvolutionalBoxPredictor below. in contrast to the ConvolutionalBoxPredictor below.
dropout_keep_prob: Keep probability for dropout. dropout_keep_prob: Keep probability for dropout.
This is only used if use_dropout is True. This is only used if use_dropout is True.
box_code_size: Size of encoding for each box. box_code_size: Size of encoding for each box.
conv_hyperparams: Slim arg_scope with hyperparameters for convolution conv_hyperparams_fn: A function to generate tf-slim arg_scope with
ops. hyperparameters for convolution ops.
predict_instance_masks: Whether to predict object masks inside detection predict_instance_masks: Whether to predict object masks inside detection
boxes. boxes.
mask_height: Desired output mask height. The default value is 14. mask_height: Desired output mask height. The default value is 14.
...@@ -337,6 +338,8 @@ class MaskRCNNBoxPredictor(BoxPredictor): ...@@ -337,6 +338,8 @@ class MaskRCNNBoxPredictor(BoxPredictor):
to 0, the depth of the convolution layers will be automatically chosen to 0, the depth of the convolution layers will be automatically chosen
based on the number of object classes and the number of channels in the based on the number of object classes and the number of channels in the
image features. image features.
masks_are_class_agnostic: Boolean determining if the mask-head is
class-agnostic or not.
predict_keypoints: Whether to predict keypoints insde detection boxes. predict_keypoints: Whether to predict keypoints insde detection boxes.
...@@ -347,21 +350,22 @@ class MaskRCNNBoxPredictor(BoxPredictor): ...@@ -347,21 +350,22 @@ class MaskRCNNBoxPredictor(BoxPredictor):
ValueError: If mask_prediction_num_conv_layers is smaller than two. ValueError: If mask_prediction_num_conv_layers is smaller than two.
""" """
super(MaskRCNNBoxPredictor, self).__init__(is_training, num_classes) super(MaskRCNNBoxPredictor, self).__init__(is_training, num_classes)
self._fc_hyperparams = fc_hyperparams self._fc_hyperparams_fn = fc_hyperparams_fn
self._use_dropout = use_dropout self._use_dropout = use_dropout
self._box_code_size = box_code_size self._box_code_size = box_code_size
self._dropout_keep_prob = dropout_keep_prob self._dropout_keep_prob = dropout_keep_prob
self._conv_hyperparams = conv_hyperparams self._conv_hyperparams_fn = conv_hyperparams_fn
self._predict_instance_masks = predict_instance_masks self._predict_instance_masks = predict_instance_masks
self._mask_height = mask_height self._mask_height = mask_height
self._mask_width = mask_width self._mask_width = mask_width
self._mask_prediction_num_conv_layers = mask_prediction_num_conv_layers self._mask_prediction_num_conv_layers = mask_prediction_num_conv_layers
self._mask_prediction_conv_depth = mask_prediction_conv_depth self._mask_prediction_conv_depth = mask_prediction_conv_depth
self._masks_are_class_agnostic = masks_are_class_agnostic
self._predict_keypoints = predict_keypoints self._predict_keypoints = predict_keypoints
if self._predict_keypoints: if self._predict_keypoints:
raise ValueError('Keypoint prediction is unimplemented.') raise ValueError('Keypoint prediction is unimplemented.')
if ((self._predict_instance_masks or self._predict_keypoints) and if ((self._predict_instance_masks or self._predict_keypoints) and
self._conv_hyperparams is None): self._conv_hyperparams_fn is None):
raise ValueError('`conv_hyperparams` must be provided when predicting ' raise ValueError('`conv_hyperparams` must be provided when predicting '
'masks.') 'masks.')
if self._mask_prediction_num_conv_layers < 2: if self._mask_prediction_num_conv_layers < 2:
...@@ -399,7 +403,7 @@ class MaskRCNNBoxPredictor(BoxPredictor): ...@@ -399,7 +403,7 @@ class MaskRCNNBoxPredictor(BoxPredictor):
flattened_image_features = slim.dropout(flattened_image_features, flattened_image_features = slim.dropout(flattened_image_features,
keep_prob=self._dropout_keep_prob, keep_prob=self._dropout_keep_prob,
is_training=self._is_training) is_training=self._is_training)
with slim.arg_scope(self._fc_hyperparams): with slim.arg_scope(self._fc_hyperparams_fn()):
box_encodings = slim.fully_connected( box_encodings = slim.fully_connected(
flattened_image_features, flattened_image_features,
self._num_classes * self._box_code_size, self._num_classes * self._box_code_size,
...@@ -463,7 +467,7 @@ class MaskRCNNBoxPredictor(BoxPredictor): ...@@ -463,7 +467,7 @@ class MaskRCNNBoxPredictor(BoxPredictor):
num_feature_channels = image_features.get_shape().as_list()[3] num_feature_channels = image_features.get_shape().as_list()[3]
num_conv_channels = self._get_mask_predictor_conv_depth( num_conv_channels = self._get_mask_predictor_conv_depth(
num_feature_channels, self.num_classes) num_feature_channels, self.num_classes)
with slim.arg_scope(self._conv_hyperparams): with slim.arg_scope(self._conv_hyperparams_fn()):
upsampled_features = tf.image.resize_bilinear( upsampled_features = tf.image.resize_bilinear(
image_features, image_features,
[self._mask_height, self._mask_width], [self._mask_height, self._mask_width],
...@@ -473,8 +477,9 @@ class MaskRCNNBoxPredictor(BoxPredictor): ...@@ -473,8 +477,9 @@ class MaskRCNNBoxPredictor(BoxPredictor):
upsampled_features, upsampled_features,
num_outputs=num_conv_channels, num_outputs=num_conv_channels,
kernel_size=[3, 3]) kernel_size=[3, 3])
num_masks = 1 if self._masks_are_class_agnostic else self.num_classes
mask_predictions = slim.conv2d(upsampled_features, mask_predictions = slim.conv2d(upsampled_features,
num_outputs=self.num_classes, num_outputs=num_masks,
activation_fn=None, activation_fn=None,
kernel_size=[3, 3]) kernel_size=[3, 3])
return tf.expand_dims( return tf.expand_dims(
...@@ -578,7 +583,7 @@ class ConvolutionalBoxPredictor(BoxPredictor): ...@@ -578,7 +583,7 @@ class ConvolutionalBoxPredictor(BoxPredictor):
def __init__(self, def __init__(self,
is_training, is_training,
num_classes, num_classes,
conv_hyperparams, conv_hyperparams_fn,
min_depth, min_depth,
max_depth, max_depth,
num_layers_before_predictor, num_layers_before_predictor,
...@@ -597,8 +602,9 @@ class ConvolutionalBoxPredictor(BoxPredictor): ...@@ -597,8 +602,9 @@ class ConvolutionalBoxPredictor(BoxPredictor):
include the background category, so if groundtruth labels take values include the background category, so if groundtruth labels take values
in {0, 1, .., K-1}, num_classes=K (and not K+1, even though the in {0, 1, .., K-1}, num_classes=K (and not K+1, even though the
assigned classification targets can range from {0,... K}). assigned classification targets can range from {0,... K}).
conv_hyperparams: Slim arg_scope with hyperparameters for convolution ops. conv_hyperparams_fn: A function to generate tf-slim arg_scope with
min_depth: Minumum feature depth prior to predicting box encodings hyperparameters for convolution ops.
min_depth: Minimum feature depth prior to predicting box encodings
and class predictions. and class predictions.
max_depth: Maximum feature depth prior to predicting box encodings max_depth: Maximum feature depth prior to predicting box encodings
and class predictions. If max_depth is set to 0, no additional and class predictions. If max_depth is set to 0, no additional
...@@ -626,7 +632,7 @@ class ConvolutionalBoxPredictor(BoxPredictor): ...@@ -626,7 +632,7 @@ class ConvolutionalBoxPredictor(BoxPredictor):
super(ConvolutionalBoxPredictor, self).__init__(is_training, num_classes) super(ConvolutionalBoxPredictor, self).__init__(is_training, num_classes)
if min_depth > max_depth: if min_depth > max_depth:
raise ValueError('min_depth should be less than or equal to max_depth') raise ValueError('min_depth should be less than or equal to max_depth')
self._conv_hyperparams = conv_hyperparams self._conv_hyperparams_fn = conv_hyperparams_fn
self._min_depth = min_depth self._min_depth = min_depth
self._max_depth = max_depth self._max_depth = max_depth
self._num_layers_before_predictor = num_layers_before_predictor self._num_layers_before_predictor = num_layers_before_predictor
...@@ -679,7 +685,7 @@ class ConvolutionalBoxPredictor(BoxPredictor): ...@@ -679,7 +685,7 @@ class ConvolutionalBoxPredictor(BoxPredictor):
# Add a slot for the background class. # Add a slot for the background class.
num_class_slots = self.num_classes + 1 num_class_slots = self.num_classes + 1
net = image_feature net = image_feature
with slim.arg_scope(self._conv_hyperparams), \ with slim.arg_scope(self._conv_hyperparams_fn()), \
slim.arg_scope([slim.dropout], is_training=self._is_training): slim.arg_scope([slim.dropout], is_training=self._is_training):
# Add additional conv layers before the class predictor. # Add additional conv layers before the class predictor.
features_depth = static_shape.get_depth(image_feature.get_shape()) features_depth = static_shape.get_depth(image_feature.get_shape())
...@@ -767,7 +773,7 @@ class WeightSharedConvolutionalBoxPredictor(BoxPredictor): ...@@ -767,7 +773,7 @@ class WeightSharedConvolutionalBoxPredictor(BoxPredictor):
def __init__(self, def __init__(self,
is_training, is_training,
num_classes, num_classes,
conv_hyperparams, conv_hyperparams_fn,
depth, depth,
num_layers_before_predictor, num_layers_before_predictor,
box_code_size, box_code_size,
...@@ -781,7 +787,8 @@ class WeightSharedConvolutionalBoxPredictor(BoxPredictor): ...@@ -781,7 +787,8 @@ class WeightSharedConvolutionalBoxPredictor(BoxPredictor):
include the background category, so if groundtruth labels take values include the background category, so if groundtruth labels take values
in {0, 1, .., K-1}, num_classes=K (and not K+1, even though the in {0, 1, .., K-1}, num_classes=K (and not K+1, even though the
assigned classification targets can range from {0,... K}). assigned classification targets can range from {0,... K}).
conv_hyperparams: Slim arg_scope with hyperparameters for convolution ops. conv_hyperparams_fn: A function to generate tf-slim arg_scope with
hyperparameters for convolution ops.
depth: depth of conv layers. depth: depth of conv layers.
num_layers_before_predictor: Number of the additional conv layers before num_layers_before_predictor: Number of the additional conv layers before
the predictor. the predictor.
...@@ -792,7 +799,7 @@ class WeightSharedConvolutionalBoxPredictor(BoxPredictor): ...@@ -792,7 +799,7 @@ class WeightSharedConvolutionalBoxPredictor(BoxPredictor):
""" """
super(WeightSharedConvolutionalBoxPredictor, self).__init__(is_training, super(WeightSharedConvolutionalBoxPredictor, self).__init__(is_training,
num_classes) num_classes)
self._conv_hyperparams = conv_hyperparams self._conv_hyperparams_fn = conv_hyperparams_fn
self._depth = depth self._depth = depth
self._num_layers_before_predictor = num_layers_before_predictor self._num_layers_before_predictor = num_layers_before_predictor
self._box_code_size = box_code_size self._box_code_size = box_code_size
...@@ -846,7 +853,7 @@ class WeightSharedConvolutionalBoxPredictor(BoxPredictor): ...@@ -846,7 +853,7 @@ class WeightSharedConvolutionalBoxPredictor(BoxPredictor):
num_class_slots = self.num_classes + 1 num_class_slots = self.num_classes + 1
box_encodings_net = image_feature box_encodings_net = image_feature
class_predictions_net = image_feature class_predictions_net = image_feature
with slim.arg_scope(self._conv_hyperparams): with slim.arg_scope(self._conv_hyperparams_fn()):
for i in range(self._num_layers_before_predictor): for i in range(self._num_layers_before_predictor):
box_encodings_net = slim.conv2d( box_encodings_net = slim.conv2d(
box_encodings_net, box_encodings_net,
......
...@@ -49,7 +49,7 @@ class MaskRCNNBoxPredictorTest(tf.test.TestCase): ...@@ -49,7 +49,7 @@ class MaskRCNNBoxPredictorTest(tf.test.TestCase):
mask_box_predictor = box_predictor.MaskRCNNBoxPredictor( mask_box_predictor = box_predictor.MaskRCNNBoxPredictor(
is_training=False, is_training=False,
num_classes=5, num_classes=5,
fc_hyperparams=self._build_arg_scope_with_hyperparams(), fc_hyperparams_fn=self._build_arg_scope_with_hyperparams(),
use_dropout=False, use_dropout=False,
dropout_keep_prob=0.5, dropout_keep_prob=0.5,
box_code_size=4, box_code_size=4,
...@@ -75,7 +75,7 @@ class MaskRCNNBoxPredictorTest(tf.test.TestCase): ...@@ -75,7 +75,7 @@ class MaskRCNNBoxPredictorTest(tf.test.TestCase):
box_predictor.MaskRCNNBoxPredictor( box_predictor.MaskRCNNBoxPredictor(
is_training=False, is_training=False,
num_classes=5, num_classes=5,
fc_hyperparams=self._build_arg_scope_with_hyperparams(), fc_hyperparams_fn=self._build_arg_scope_with_hyperparams(),
use_dropout=False, use_dropout=False,
dropout_keep_prob=0.5, dropout_keep_prob=0.5,
box_code_size=4, box_code_size=4,
...@@ -86,11 +86,11 @@ class MaskRCNNBoxPredictorTest(tf.test.TestCase): ...@@ -86,11 +86,11 @@ class MaskRCNNBoxPredictorTest(tf.test.TestCase):
mask_box_predictor = box_predictor.MaskRCNNBoxPredictor( mask_box_predictor = box_predictor.MaskRCNNBoxPredictor(
is_training=False, is_training=False,
num_classes=5, num_classes=5,
fc_hyperparams=self._build_arg_scope_with_hyperparams(), fc_hyperparams_fn=self._build_arg_scope_with_hyperparams(),
use_dropout=False, use_dropout=False,
dropout_keep_prob=0.5, dropout_keep_prob=0.5,
box_code_size=4, box_code_size=4,
conv_hyperparams=self._build_arg_scope_with_hyperparams( conv_hyperparams_fn=self._build_arg_scope_with_hyperparams(
op_type=hyperparams_pb2.Hyperparams.CONV), op_type=hyperparams_pb2.Hyperparams.CONV),
predict_instance_masks=True) predict_instance_masks=True)
box_predictions = mask_box_predictor.predict( box_predictions = mask_box_predictor.predict(
...@@ -108,7 +108,7 @@ class MaskRCNNBoxPredictorTest(tf.test.TestCase): ...@@ -108,7 +108,7 @@ class MaskRCNNBoxPredictorTest(tf.test.TestCase):
mask_box_predictor = box_predictor.MaskRCNNBoxPredictor( mask_box_predictor = box_predictor.MaskRCNNBoxPredictor(
is_training=False, is_training=False,
num_classes=5, num_classes=5,
fc_hyperparams=self._build_arg_scope_with_hyperparams(), fc_hyperparams_fn=self._build_arg_scope_with_hyperparams(),
use_dropout=False, use_dropout=False,
dropout_keep_prob=0.5, dropout_keep_prob=0.5,
box_code_size=4) box_code_size=4)
...@@ -125,7 +125,7 @@ class MaskRCNNBoxPredictorTest(tf.test.TestCase): ...@@ -125,7 +125,7 @@ class MaskRCNNBoxPredictorTest(tf.test.TestCase):
box_predictor.MaskRCNNBoxPredictor( box_predictor.MaskRCNNBoxPredictor(
is_training=False, is_training=False,
num_classes=5, num_classes=5,
fc_hyperparams=self._build_arg_scope_with_hyperparams(), fc_hyperparams_fn=self._build_arg_scope_with_hyperparams(),
use_dropout=False, use_dropout=False,
dropout_keep_prob=0.5, dropout_keep_prob=0.5,
box_code_size=4, box_code_size=4,
...@@ -155,7 +155,7 @@ class RfcnBoxPredictorTest(tf.test.TestCase): ...@@ -155,7 +155,7 @@ class RfcnBoxPredictorTest(tf.test.TestCase):
rfcn_box_predictor = box_predictor.RfcnBoxPredictor( rfcn_box_predictor = box_predictor.RfcnBoxPredictor(
is_training=False, is_training=False,
num_classes=2, num_classes=2,
conv_hyperparams=self._build_arg_scope_with_conv_hyperparams(), conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
num_spatial_bins=[3, 3], num_spatial_bins=[3, 3],
depth=4, depth=4,
crop_size=[12, 12], crop_size=[12, 12],
...@@ -205,7 +205,7 @@ class ConvolutionalBoxPredictorTest(test_case.TestCase): ...@@ -205,7 +205,7 @@ class ConvolutionalBoxPredictorTest(test_case.TestCase):
conv_box_predictor = box_predictor.ConvolutionalBoxPredictor( conv_box_predictor = box_predictor.ConvolutionalBoxPredictor(
is_training=False, is_training=False,
num_classes=0, num_classes=0,
conv_hyperparams=self._build_arg_scope_with_conv_hyperparams(), conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
min_depth=0, min_depth=0,
max_depth=32, max_depth=32,
num_layers_before_predictor=1, num_layers_before_predictor=1,
...@@ -234,7 +234,7 @@ class ConvolutionalBoxPredictorTest(test_case.TestCase): ...@@ -234,7 +234,7 @@ class ConvolutionalBoxPredictorTest(test_case.TestCase):
conv_box_predictor = box_predictor.ConvolutionalBoxPredictor( conv_box_predictor = box_predictor.ConvolutionalBoxPredictor(
is_training=False, is_training=False,
num_classes=0, num_classes=0,
conv_hyperparams=self._build_arg_scope_with_conv_hyperparams(), conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
min_depth=0, min_depth=0,
max_depth=32, max_depth=32,
num_layers_before_predictor=1, num_layers_before_predictor=1,
...@@ -265,7 +265,7 @@ class ConvolutionalBoxPredictorTest(test_case.TestCase): ...@@ -265,7 +265,7 @@ class ConvolutionalBoxPredictorTest(test_case.TestCase):
conv_box_predictor = box_predictor.ConvolutionalBoxPredictor( conv_box_predictor = box_predictor.ConvolutionalBoxPredictor(
is_training=False, is_training=False,
num_classes=num_classes_without_background, num_classes=num_classes_without_background,
conv_hyperparams=self._build_arg_scope_with_conv_hyperparams(), conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
min_depth=0, min_depth=0,
max_depth=32, max_depth=32,
num_layers_before_predictor=1, num_layers_before_predictor=1,
...@@ -297,7 +297,7 @@ class ConvolutionalBoxPredictorTest(test_case.TestCase): ...@@ -297,7 +297,7 @@ class ConvolutionalBoxPredictorTest(test_case.TestCase):
conv_box_predictor = box_predictor.ConvolutionalBoxPredictor( conv_box_predictor = box_predictor.ConvolutionalBoxPredictor(
is_training=False, is_training=False,
num_classes=0, num_classes=0,
conv_hyperparams=self._build_arg_scope_with_conv_hyperparams(), conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
min_depth=0, min_depth=0,
max_depth=32, max_depth=32,
num_layers_before_predictor=1, num_layers_before_predictor=1,
...@@ -344,7 +344,7 @@ class ConvolutionalBoxPredictorTest(test_case.TestCase): ...@@ -344,7 +344,7 @@ class ConvolutionalBoxPredictorTest(test_case.TestCase):
conv_box_predictor = box_predictor.ConvolutionalBoxPredictor( conv_box_predictor = box_predictor.ConvolutionalBoxPredictor(
is_training=False, is_training=False,
num_classes=0, num_classes=0,
conv_hyperparams=self._build_arg_scope_with_conv_hyperparams(), conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
min_depth=0, min_depth=0,
max_depth=32, max_depth=32,
num_layers_before_predictor=1, num_layers_before_predictor=1,
...@@ -416,7 +416,7 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase): ...@@ -416,7 +416,7 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase):
conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor( conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor(
is_training=False, is_training=False,
num_classes=0, num_classes=0,
conv_hyperparams=self._build_arg_scope_with_conv_hyperparams(), conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
depth=32, depth=32,
num_layers_before_predictor=1, num_layers_before_predictor=1,
box_code_size=4) box_code_size=4)
...@@ -442,7 +442,7 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase): ...@@ -442,7 +442,7 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase):
conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor( conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor(
is_training=False, is_training=False,
num_classes=num_classes_without_background, num_classes=num_classes_without_background,
conv_hyperparams=self._build_arg_scope_with_conv_hyperparams(), conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
depth=32, depth=32,
num_layers_before_predictor=1, num_layers_before_predictor=1,
box_code_size=4) box_code_size=4)
...@@ -471,7 +471,7 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase): ...@@ -471,7 +471,7 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase):
conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor( conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor(
is_training=False, is_training=False,
num_classes=num_classes_without_background, num_classes=num_classes_without_background,
conv_hyperparams=self._build_arg_scope_with_conv_hyperparams(), conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
depth=32, depth=32,
num_layers_before_predictor=1, num_layers_before_predictor=1,
box_code_size=4) box_code_size=4)
...@@ -500,7 +500,7 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase): ...@@ -500,7 +500,7 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase):
conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor( conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor(
is_training=False, is_training=False,
num_classes=num_classes_without_background, num_classes=num_classes_without_background,
conv_hyperparams=self._build_arg_scope_with_conv_hyperparams(), conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
depth=32, depth=32,
num_layers_before_predictor=2, num_layers_before_predictor=2,
box_code_size=4) box_code_size=4)
...@@ -553,7 +553,7 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase): ...@@ -553,7 +553,7 @@ class WeightSharedConvolutionalBoxPredictorTest(test_case.TestCase):
conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor( conv_box_predictor = box_predictor.WeightSharedConvolutionalBoxPredictor(
is_training=False, is_training=False,
num_classes=0, num_classes=0,
conv_hyperparams=self._build_arg_scope_with_conv_hyperparams(), conv_hyperparams_fn=self._build_arg_scope_with_conv_hyperparams(),
depth=32, depth=32,
num_layers_before_predictor=1, num_layers_before_predictor=1,
box_code_size=4) box_code_size=4)
......
...@@ -69,7 +69,7 @@ class DetectionModel(object): ...@@ -69,7 +69,7 @@ class DetectionModel(object):
Args: Args:
num_classes: number of classes. Note that num_classes *does not* include num_classes: number of classes. Note that num_classes *does not* include
background categories that might be implicitly be predicted in various background categories that might be implicitly predicted in various
implementations. implementations.
""" """
self._num_classes = num_classes self._num_classes = num_classes
......
...@@ -119,6 +119,9 @@ class PreprocessorTest(tf.test.TestCase): ...@@ -119,6 +119,9 @@ class PreprocessorTest(tf.test.TestCase):
[[-0.1, 0.25, 0.75, 1], [0.25, 0.5, 0.75, 1.1]], dtype=tf.float32) [[-0.1, 0.25, 0.75, 1], [0.25, 0.5, 0.75, 1.1]], dtype=tf.float32)
return boxes return boxes
def createTestMultiClassScores(self):
return tf.constant([[1.0, 0.0], [0.5, 0.5]], dtype=tf.float32)
def expectedImagesAfterNormalization(self): def expectedImagesAfterNormalization(self):
images_r = tf.constant([[[0, 0, 0, 0], [-1, -1, 0, 0], images_r = tf.constant([[[0, 0, 0, 0], [-1, -1, 0, 0],
[-1, 0, 0, 0], [0.5, 0.5, 0, 0]]], [-1, 0, 0, 0], [0.5, 0.5, 0, 0]]],
...@@ -269,6 +272,9 @@ class PreprocessorTest(tf.test.TestCase): ...@@ -269,6 +272,9 @@ class PreprocessorTest(tf.test.TestCase):
def expectedLabelsAfterThresholding(self): def expectedLabelsAfterThresholding(self):
return tf.constant([1], dtype=tf.float32) return tf.constant([1], dtype=tf.float32)
def expectedMultiClassScoresAfterThresholding(self):
return tf.constant([[1.0, 0.0]], dtype=tf.float32)
def expectedMasksAfterThresholding(self): def expectedMasksAfterThresholding(self):
mask = np.array([ mask = np.array([
[[255.0, 0.0, 0.0], [[255.0, 0.0, 0.0],
...@@ -345,6 +351,28 @@ class PreprocessorTest(tf.test.TestCase): ...@@ -345,6 +351,28 @@ class PreprocessorTest(tf.test.TestCase):
self.assertAllClose( self.assertAllClose(
retained_label_scores_, expected_retained_label_scores_) retained_label_scores_, expected_retained_label_scores_)
def testRetainBoxesAboveThresholdWithMultiClassScores(self):
boxes = self.createTestBoxes()
labels = self.createTestLabels()
label_scores = self.createTestLabelScores()
multiclass_scores = self.createTestMultiClassScores()
(_, _, _,
retained_multiclass_scores) = preprocessor.retain_boxes_above_threshold(
boxes,
labels,
label_scores,
multiclass_scores=multiclass_scores,
threshold=0.6)
with self.test_session() as sess:
(retained_multiclass_scores_,
expected_retained_multiclass_scores_) = sess.run([
retained_multiclass_scores,
self.expectedMultiClassScoresAfterThresholding()
])
self.assertAllClose(retained_multiclass_scores_,
expected_retained_multiclass_scores_)
def testRetainBoxesAboveThresholdWithMasks(self): def testRetainBoxesAboveThresholdWithMasks(self):
boxes = self.createTestBoxes() boxes = self.createTestBoxes()
labels = self.createTestLabels() labels = self.createTestLabels()
...@@ -1264,6 +1292,56 @@ class PreprocessorTest(tf.test.TestCase): ...@@ -1264,6 +1292,56 @@ class PreprocessorTest(tf.test.TestCase):
self.assertAllClose(distorted_boxes_, expected_boxes_) self.assertAllClose(distorted_boxes_, expected_boxes_)
self.assertAllEqual(distorted_labels_, expected_labels_) self.assertAllEqual(distorted_labels_, expected_labels_)
def testRandomCropImageWithMultiClassScores(self):
preprocessing_options = []
preprocessing_options.append((preprocessor.normalize_image, {
'original_minval': 0,
'original_maxval': 255,
'target_minval': 0,
'target_maxval': 1
}))
preprocessing_options.append((preprocessor.random_crop_image, {}))
images = self.createTestImages()
boxes = self.createTestBoxes()
labels = self.createTestLabels()
multiclass_scores = self.createTestMultiClassScores()
tensor_dict = {
fields.InputDataFields.image: images,
fields.InputDataFields.groundtruth_boxes: boxes,
fields.InputDataFields.groundtruth_classes: labels,
fields.InputDataFields.multiclass_scores: multiclass_scores
}
distorted_tensor_dict = preprocessor.preprocess(tensor_dict,
preprocessing_options)
distorted_images = distorted_tensor_dict[fields.InputDataFields.image]
distorted_boxes = distorted_tensor_dict[
fields.InputDataFields.groundtruth_boxes]
distorted_multiclass_scores = distorted_tensor_dict[
fields.InputDataFields.multiclass_scores]
boxes_rank = tf.rank(boxes)
distorted_boxes_rank = tf.rank(distorted_boxes)
images_rank = tf.rank(images)
distorted_images_rank = tf.rank(distorted_images)
multiclass_scores_rank = tf.rank(multiclass_scores)
distorted_multiclass_scores_rank = tf.rank(distorted_multiclass_scores)
with self.test_session() as sess:
(boxes_rank_, distorted_boxes_, distorted_boxes_rank_, images_rank_,
distorted_images_rank_, multiclass_scores_rank_,
distorted_multiclass_scores_rank_,
distorted_multiclass_scores_) = sess.run([
boxes_rank, distorted_boxes, distorted_boxes_rank, images_rank,
distorted_images_rank, multiclass_scores_rank,
distorted_multiclass_scores_rank, distorted_multiclass_scores
])
self.assertAllEqual(boxes_rank_, distorted_boxes_rank_)
self.assertAllEqual(images_rank_, distorted_images_rank_)
self.assertAllEqual(multiclass_scores_rank_,
distorted_multiclass_scores_rank_)
self.assertAllEqual(distorted_boxes_.shape[0],
distorted_multiclass_scores_.shape[0])
def testStrictRandomCropImageWithLabelScores(self): def testStrictRandomCropImageWithLabelScores(self):
image = self.createColorfulTestImage()[0] image = self.createColorfulTestImage()[0]
boxes = self.createTestBoxes() boxes = self.createTestBoxes()
...@@ -2510,6 +2588,49 @@ class PreprocessorTest(tf.test.TestCase): ...@@ -2510,6 +2588,49 @@ class PreprocessorTest(tf.test.TestCase):
self.assertAllEqual(boxes_rank_, distorted_boxes_rank_) self.assertAllEqual(boxes_rank_, distorted_boxes_rank_)
self.assertAllEqual(images_rank_, distorted_images_rank_) self.assertAllEqual(images_rank_, distorted_images_rank_)
def testSSDRandomCropWithMultiClassScores(self):
preprocessing_options = [(preprocessor.normalize_image, {
'original_minval': 0,
'original_maxval': 255,
'target_minval': 0,
'target_maxval': 1
}), (preprocessor.ssd_random_crop, {})]
images = self.createTestImages()
boxes = self.createTestBoxes()
labels = self.createTestLabels()
multiclass_scores = self.createTestMultiClassScores()
tensor_dict = {
fields.InputDataFields.image: images,
fields.InputDataFields.groundtruth_boxes: boxes,
fields.InputDataFields.groundtruth_classes: labels,
fields.InputDataFields.multiclass_scores: multiclass_scores,
}
preprocessor_arg_map = preprocessor.get_default_func_arg_map(
include_multiclass_scores=True)
distorted_tensor_dict = preprocessor.preprocess(
tensor_dict, preprocessing_options, func_arg_map=preprocessor_arg_map)
distorted_images = distorted_tensor_dict[fields.InputDataFields.image]
distorted_boxes = distorted_tensor_dict[
fields.InputDataFields.groundtruth_boxes]
distorted_multiclass_scores = distorted_tensor_dict[
fields.InputDataFields.multiclass_scores]
images_rank = tf.rank(images)
distorted_images_rank = tf.rank(distorted_images)
boxes_rank = tf.rank(boxes)
distorted_boxes_rank = tf.rank(distorted_boxes)
with self.test_session() as sess:
(boxes_rank_, distorted_boxes_rank_, images_rank_, distorted_images_rank_,
multiclass_scores_, distorted_multiclass_scores_) = sess.run([
boxes_rank, distorted_boxes_rank, images_rank, distorted_images_rank,
multiclass_scores, distorted_multiclass_scores
])
self.assertAllEqual(boxes_rank_, distorted_boxes_rank_)
self.assertAllEqual(images_rank_, distorted_images_rank_)
self.assertAllEqual(multiclass_scores_, distorted_multiclass_scores_)
def testSSDRandomCropPad(self): def testSSDRandomCropPad(self):
images = self.createTestImages() images = self.createTestImages()
boxes = self.createTestBoxes() boxes = self.createTestBoxes()
...@@ -2562,28 +2683,31 @@ class PreprocessorTest(tf.test.TestCase): ...@@ -2562,28 +2683,31 @@ class PreprocessorTest(tf.test.TestCase):
def _testSSDRandomCropFixedAspectRatio(self, def _testSSDRandomCropFixedAspectRatio(self,
include_label_scores, include_label_scores,
include_multiclass_scores,
include_instance_masks, include_instance_masks,
include_keypoints): include_keypoints):
images = self.createTestImages() images = self.createTestImages()
boxes = self.createTestBoxes() boxes = self.createTestBoxes()
labels = self.createTestLabels() labels = self.createTestLabels()
preprocessing_options = [ preprocessing_options = [(preprocessor.normalize_image, {
(preprocessor.normalize_image, {
'original_minval': 0, 'original_minval': 0,
'original_maxval': 255, 'original_maxval': 255,
'target_minval': 0, 'target_minval': 0,
'target_maxval': 1 'target_maxval': 1
}), }), (preprocessor.ssd_random_crop_fixed_aspect_ratio, {})]
(preprocessor.ssd_random_crop_fixed_aspect_ratio, {})]
tensor_dict = { tensor_dict = {
fields.InputDataFields.image: images, fields.InputDataFields.image: images,
fields.InputDataFields.groundtruth_boxes: boxes, fields.InputDataFields.groundtruth_boxes: boxes,
fields.InputDataFields.groundtruth_classes: labels fields.InputDataFields.groundtruth_classes: labels,
} }
if include_label_scores: if include_label_scores:
label_scores = self.createTestLabelScores() label_scores = self.createTestLabelScores()
tensor_dict[fields.InputDataFields.groundtruth_label_scores] = ( tensor_dict[fields.InputDataFields.groundtruth_label_scores] = (
label_scores) label_scores)
if include_multiclass_scores:
multiclass_scores = self.createTestMultiClassScores()
tensor_dict[fields.InputDataFields.multiclass_scores] = (
multiclass_scores)
if include_instance_masks: if include_instance_masks:
masks = self.createTestMasks() masks = self.createTestMasks()
tensor_dict[fields.InputDataFields.groundtruth_instance_masks] = masks tensor_dict[fields.InputDataFields.groundtruth_instance_masks] = masks
...@@ -2593,6 +2717,7 @@ class PreprocessorTest(tf.test.TestCase): ...@@ -2593,6 +2717,7 @@ class PreprocessorTest(tf.test.TestCase):
preprocessor_arg_map = preprocessor.get_default_func_arg_map( preprocessor_arg_map = preprocessor.get_default_func_arg_map(
include_label_scores=include_label_scores, include_label_scores=include_label_scores,
include_multiclass_scores=include_multiclass_scores,
include_instance_masks=include_instance_masks, include_instance_masks=include_instance_masks,
include_keypoints=include_keypoints) include_keypoints=include_keypoints)
distorted_tensor_dict = preprocessor.preprocess( distorted_tensor_dict = preprocessor.preprocess(
...@@ -2615,16 +2740,25 @@ class PreprocessorTest(tf.test.TestCase): ...@@ -2615,16 +2740,25 @@ class PreprocessorTest(tf.test.TestCase):
def testSSDRandomCropFixedAspectRatio(self): def testSSDRandomCropFixedAspectRatio(self):
self._testSSDRandomCropFixedAspectRatio(include_label_scores=False, self._testSSDRandomCropFixedAspectRatio(include_label_scores=False,
include_multiclass_scores=False,
include_instance_masks=False,
include_keypoints=False)
def testSSDRandomCropFixedAspectRatioWithMultiClassScores(self):
self._testSSDRandomCropFixedAspectRatio(include_label_scores=False,
include_multiclass_scores=True,
include_instance_masks=False, include_instance_masks=False,
include_keypoints=False) include_keypoints=False)
def testSSDRandomCropFixedAspectRatioWithMasksAndKeypoints(self): def testSSDRandomCropFixedAspectRatioWithMasksAndKeypoints(self):
self._testSSDRandomCropFixedAspectRatio(include_label_scores=False, self._testSSDRandomCropFixedAspectRatio(include_label_scores=False,
include_multiclass_scores=False,
include_instance_masks=True, include_instance_masks=True,
include_keypoints=True) include_keypoints=True)
def testSSDRandomCropFixedAspectRatioWithLabelScoresMasksAndKeypoints(self): def testSSDRandomCropFixedAspectRatioWithLabelScoresMasksAndKeypoints(self):
self._testSSDRandomCropFixedAspectRatio(include_label_scores=True, self._testSSDRandomCropFixedAspectRatio(include_label_scores=True,
include_multiclass_scores=False,
include_instance_masks=True, include_instance_masks=True,
include_keypoints=True) include_keypoints=True)
......
...@@ -61,6 +61,9 @@ class InputDataFields(object): ...@@ -61,6 +61,9 @@ class InputDataFields(object):
num_groundtruth_boxes: number of groundtruth boxes. num_groundtruth_boxes: number of groundtruth boxes.
true_image_shapes: true shapes of images in the resized images, as resized true_image_shapes: true shapes of images in the resized images, as resized
images can be padded with zeros. images can be padded with zeros.
verified_labels: list of human-verified image-level labels (note, that a
label can be verified both as positive and negative).
multiclass_scores: the label score per class for each box.
""" """
image = 'image' image = 'image'
original_image = 'original_image' original_image = 'original_image'
...@@ -86,6 +89,8 @@ class InputDataFields(object): ...@@ -86,6 +89,8 @@ class InputDataFields(object):
groundtruth_weights = 'groundtruth_weights' groundtruth_weights = 'groundtruth_weights'
num_groundtruth_boxes = 'num_groundtruth_boxes' num_groundtruth_boxes = 'num_groundtruth_boxes'
true_image_shape = 'true_image_shape' true_image_shape = 'true_image_shape'
verified_labels = 'verified_labels'
multiclass_scores = 'multiclass_scores'
class DetectionResultFields(object): class DetectionResultFields(object):
......
...@@ -104,8 +104,7 @@ def dict_to_tf_example(data, ...@@ -104,8 +104,7 @@ def dict_to_tf_example(data,
truncated = [] truncated = []
poses = [] poses = []
difficult_obj = [] difficult_obj = []
if 'object' in data:
if data.has_key('object'):
for obj in data['object']: for obj in data['object']:
difficult = bool(int(obj['difficult'])) difficult = bool(int(obj['difficult']))
if ignore_difficult_instances and difficult: if ignore_difficult_instances and difficult:
......
...@@ -136,6 +136,7 @@ def dict_to_tf_example(data, ...@@ -136,6 +136,7 @@ def dict_to_tf_example(data,
poses = [] poses = []
difficult_obj = [] difficult_obj = []
masks = [] masks = []
if 'object' in data:
for obj in data['object']: for obj in data['object']:
difficult = bool(int(obj['difficult'])) difficult = bool(int(obj['difficult']))
if ignore_difficult_instances and difficult: if ignore_difficult_instances and difficult:
......
...@@ -229,7 +229,7 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -229,7 +229,7 @@ class FasterRCNNMetaArch(model.DetectionModel):
number_of_stages, number_of_stages,
first_stage_anchor_generator, first_stage_anchor_generator,
first_stage_atrous_rate, first_stage_atrous_rate,
first_stage_box_predictor_arg_scope, first_stage_box_predictor_arg_scope_fn,
first_stage_box_predictor_kernel_size, first_stage_box_predictor_kernel_size,
first_stage_box_predictor_depth, first_stage_box_predictor_depth,
first_stage_minibatch_size, first_stage_minibatch_size,
...@@ -291,8 +291,9 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -291,8 +291,9 @@ class FasterRCNNMetaArch(model.DetectionModel):
denser resolutions. The atrous rate is used to compensate for the denser resolutions. The atrous rate is used to compensate for the
denser feature maps by using an effectively larger receptive field. denser feature maps by using an effectively larger receptive field.
(This should typically be set to 1). (This should typically be set to 1).
first_stage_box_predictor_arg_scope: Slim arg_scope for conv2d, first_stage_box_predictor_arg_scope_fn: A function to construct tf-slim
separable_conv2d and fully_connected ops for the RPN box predictor. arg_scope for conv2d, separable_conv2d and fully_connected ops for the
RPN box predictor.
first_stage_box_predictor_kernel_size: Kernel size to use for the first_stage_box_predictor_kernel_size: Kernel size to use for the
convolution op just prior to RPN box predictions. convolution op just prior to RPN box predictions.
first_stage_box_predictor_depth: Output depth for the convolution op first_stage_box_predictor_depth: Output depth for the convolution op
...@@ -396,8 +397,8 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -396,8 +397,8 @@ class FasterRCNNMetaArch(model.DetectionModel):
# (First stage) Region proposal network parameters # (First stage) Region proposal network parameters
self._first_stage_anchor_generator = first_stage_anchor_generator self._first_stage_anchor_generator = first_stage_anchor_generator
self._first_stage_atrous_rate = first_stage_atrous_rate self._first_stage_atrous_rate = first_stage_atrous_rate
self._first_stage_box_predictor_arg_scope = ( self._first_stage_box_predictor_arg_scope_fn = (
first_stage_box_predictor_arg_scope) first_stage_box_predictor_arg_scope_fn)
self._first_stage_box_predictor_kernel_size = ( self._first_stage_box_predictor_kernel_size = (
first_stage_box_predictor_kernel_size) first_stage_box_predictor_kernel_size)
self._first_stage_box_predictor_depth = first_stage_box_predictor_depth self._first_stage_box_predictor_depth = first_stage_box_predictor_depth
...@@ -406,7 +407,7 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -406,7 +407,7 @@ class FasterRCNNMetaArch(model.DetectionModel):
positive_fraction=first_stage_positive_balance_fraction) positive_fraction=first_stage_positive_balance_fraction)
self._first_stage_box_predictor = box_predictor.ConvolutionalBoxPredictor( self._first_stage_box_predictor = box_predictor.ConvolutionalBoxPredictor(
self._is_training, num_classes=1, self._is_training, num_classes=1,
conv_hyperparams=self._first_stage_box_predictor_arg_scope, conv_hyperparams_fn=self._first_stage_box_predictor_arg_scope_fn,
min_depth=0, max_depth=0, num_layers_before_predictor=0, min_depth=0, max_depth=0, num_layers_before_predictor=0,
use_dropout=False, dropout_keep_prob=1.0, kernel_size=1, use_dropout=False, dropout_keep_prob=1.0, kernel_size=1,
box_code_size=self._box_coder.code_size) box_code_size=self._box_coder.code_size)
...@@ -450,8 +451,6 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -450,8 +451,6 @@ class FasterRCNNMetaArch(model.DetectionModel):
if self._number_of_stages <= 0 or self._number_of_stages > 3: if self._number_of_stages <= 0 or self._number_of_stages > 3:
raise ValueError('Number of stages should be a value in {1, 2, 3}.') raise ValueError('Number of stages should be a value in {1, 2, 3}.')
if self._is_training and self._number_of_stages == 3:
self._number_of_stages = 2
@property @property
def first_stage_feature_extractor_scope(self): def first_stage_feature_extractor_scope(self):
...@@ -738,9 +737,6 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -738,9 +737,6 @@ class FasterRCNNMetaArch(model.DetectionModel):
of the image. of the image.
6) box_classifier_features: a 4-D float32 tensor representing the 6) box_classifier_features: a 4-D float32 tensor representing the
features for each proposal. features for each proposal.
7) mask_predictions: (optional) a 4-D tensor with shape
[total_num_padded_proposals, num_classes, mask_height, mask_width]
containing instance mask predictions.
""" """
image_shape_2d = self._image_batch_shape_2d(image_shape) image_shape_2d = self._image_batch_shape_2d(image_shape)
proposal_boxes_normalized, _, num_proposals = self._postprocess_rpn( proposal_boxes_normalized, _, num_proposals = self._postprocess_rpn(
...@@ -756,20 +752,18 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -756,20 +752,18 @@ class FasterRCNNMetaArch(model.DetectionModel):
flattened_proposal_feature_maps, flattened_proposal_feature_maps,
scope=self.second_stage_feature_extractor_scope)) scope=self.second_stage_feature_extractor_scope))
predict_auxiliary_outputs = False
if self._number_of_stages == 2:
predict_auxiliary_outputs = True
box_predictions = self._mask_rcnn_box_predictor.predict( box_predictions = self._mask_rcnn_box_predictor.predict(
[box_classifier_features], [box_classifier_features],
num_predictions_per_location=[1], num_predictions_per_location=[1],
scope=self.second_stage_box_predictor_scope, scope=self.second_stage_box_predictor_scope,
predict_boxes_and_classes=True, predict_boxes_and_classes=True)
predict_auxiliary_outputs=predict_auxiliary_outputs)
refined_box_encodings = tf.squeeze( refined_box_encodings = tf.squeeze(
box_predictions[box_predictor.BOX_ENCODINGS], axis=1) box_predictions[box_predictor.BOX_ENCODINGS],
class_predictions_with_background = tf.squeeze(box_predictions[ axis=1, name='all_refined_box_encodings')
box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND], axis=1) class_predictions_with_background = tf.squeeze(
box_predictions[box_predictor.CLASS_PREDICTIONS_WITH_BACKGROUND],
axis=1, name='all_class_predictions_with_background')
absolute_proposal_boxes = ops.normalized_to_image_coordinates( absolute_proposal_boxes = ops.normalized_to_image_coordinates(
proposal_boxes_normalized, image_shape, self._parallel_iterations) proposal_boxes_normalized, image_shape, self._parallel_iterations)
...@@ -783,16 +777,17 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -783,16 +777,17 @@ class FasterRCNNMetaArch(model.DetectionModel):
'box_classifier_features': box_classifier_features, 'box_classifier_features': box_classifier_features,
'proposal_boxes_normalized': proposal_boxes_normalized, 'proposal_boxes_normalized': proposal_boxes_normalized,
} }
if box_predictor.MASK_PREDICTIONS in box_predictions:
mask_predictions = tf.squeeze(box_predictions[
box_predictor.MASK_PREDICTIONS], axis=1)
prediction_dict['mask_predictions'] = mask_predictions
return prediction_dict return prediction_dict
def _predict_third_stage(self, prediction_dict, image_shapes): def _predict_third_stage(self, prediction_dict, image_shapes):
"""Predicts non-box, non-class outputs using refined detections. """Predicts non-box, non-class outputs using refined detections.
For training, masks as predicted directly on the box_classifier_features,
which are region-features from the initial anchor boxes.
For inference, this happens after calling the post-processing stage, such
that masks are only calculated for the top scored boxes.
Args: Args:
prediction_dict: a dictionary holding "raw" prediction tensors: prediction_dict: a dictionary holding "raw" prediction tensors:
1) refined_box_encodings: a 3-D tensor with shape 1) refined_box_encodings: a 3-D tensor with shape
...@@ -813,16 +808,30 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -813,16 +808,30 @@ class FasterRCNNMetaArch(model.DetectionModel):
4) proposal_boxes: A float32 tensor of shape 4) proposal_boxes: A float32 tensor of shape
[batch_size, self.max_num_proposals, 4] representing [batch_size, self.max_num_proposals, 4] representing
decoded proposal bounding boxes in absolute coordinates. decoded proposal bounding boxes in absolute coordinates.
5) box_classifier_features: a 4-D float32 tensor representing the
features for each proposal.
image_shapes: A 2-D int32 tensors of shape [batch_size, 3] containing image_shapes: A 2-D int32 tensors of shape [batch_size, 3] containing
shapes of images in the batch. shapes of images in the batch.
Returns: Returns:
prediction_dict: a dictionary that in addition to the input predictions prediction_dict: a dictionary that in addition to the input predictions
does hold the following predictions as well: does hold the following predictions as well:
1) mask_predictions: (optional) a 4-D tensor with shape 1) mask_predictions: a 4-D tensor with shape
[batch_size, max_detection, mask_height, mask_width] containing [batch_size, max_detection, mask_height, mask_width] containing
instance mask predictions. instance mask predictions.
""" """
if self._is_training:
curr_box_classifier_features = prediction_dict['box_classifier_features']
detection_classes = prediction_dict['class_predictions_with_background']
box_predictions = self._mask_rcnn_box_predictor.predict(
[curr_box_classifier_features],
num_predictions_per_location=[1],
scope=self.second_stage_box_predictor_scope,
predict_boxes_and_classes=False,
predict_auxiliary_outputs=True)
prediction_dict['mask_predictions'] = tf.squeeze(box_predictions[
box_predictor.MASK_PREDICTIONS], axis=1)
else:
detections_dict = self._postprocess_box_classifier( detections_dict = self._postprocess_box_classifier(
prediction_dict['refined_box_encodings'], prediction_dict['refined_box_encodings'],
prediction_dict['class_predictions_with_background'], prediction_dict['class_predictions_with_background'],
...@@ -840,26 +849,32 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -840,26 +849,32 @@ class FasterRCNNMetaArch(model.DetectionModel):
flattened_detected_feature_maps = ( flattened_detected_feature_maps = (
self._compute_second_stage_input_feature_maps( self._compute_second_stage_input_feature_maps(
rpn_features_to_crop, detection_boxes)) rpn_features_to_crop, detection_boxes))
detected_box_classifier_features = ( curr_box_classifier_features = (
self._feature_extractor.extract_box_classifier_features( self._feature_extractor.extract_box_classifier_features(
flattened_detected_feature_maps, flattened_detected_feature_maps,
scope=self.second_stage_feature_extractor_scope)) scope=self.second_stage_feature_extractor_scope))
box_predictions = self._mask_rcnn_box_predictor.predict( box_predictions = self._mask_rcnn_box_predictor.predict(
[detected_box_classifier_features], [curr_box_classifier_features],
num_predictions_per_location=[1], num_predictions_per_location=[1],
scope=self.second_stage_box_predictor_scope, scope=self.second_stage_box_predictor_scope,
predict_boxes_and_classes=False, predict_boxes_and_classes=False,
predict_auxiliary_outputs=True) predict_auxiliary_outputs=True)
if box_predictor.MASK_PREDICTIONS in box_predictions:
detection_masks = tf.squeeze(box_predictions[ detection_masks = tf.squeeze(box_predictions[
box_predictor.MASK_PREDICTIONS], axis=1) box_predictor.MASK_PREDICTIONS], axis=1)
detection_masks = self._gather_instance_masks(detection_masks,
detection_classes) _, num_classes, mask_height, mask_width = (
mask_height = tf.shape(detection_masks)[1] detection_masks.get_shape().as_list())
mask_width = tf.shape(detection_masks)[2] _, max_detection = detection_classes.get_shape().as_list()
if num_classes > 1:
detection_masks = self._gather_instance_masks(
detection_masks, detection_classes)
prediction_dict[fields.DetectionResultFields.detection_masks] = ( prediction_dict[fields.DetectionResultFields.detection_masks] = (
tf.reshape(detection_masks, tf.reshape(detection_masks,
[batch_size, max_detection, mask_height, mask_width])) [batch_size, max_detection, mask_height, mask_width]))
return prediction_dict return prediction_dict
def _gather_instance_masks(self, instance_masks, classes): def _gather_instance_masks(self, instance_masks, classes):
...@@ -873,16 +888,12 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -873,16 +888,12 @@ class FasterRCNNMetaArch(model.DetectionModel):
Returns: Returns:
masks: a 3-D float32 tensor with shape [K, mask_height, mask_width]. masks: a 3-D float32 tensor with shape [K, mask_height, mask_width].
""" """
_, num_classes, height, width = instance_masks.get_shape().as_list()
k = tf.shape(instance_masks)[0] k = tf.shape(instance_masks)[0]
num_mask_classes = tf.shape(instance_masks)[1] instance_masks = tf.reshape(instance_masks, [-1, height, width])
instance_mask_height = tf.shape(instance_masks)[2] classes = tf.to_int32(tf.reshape(classes, [-1]))
instance_mask_width = tf.shape(instance_masks)[3] gather_idx = tf.range(k) * num_classes + classes
classes = tf.reshape(classes, [-1]) return tf.gather(instance_masks, gather_idx)
instance_masks = tf.reshape(instance_masks, [
-1, instance_mask_height, instance_mask_width
])
return tf.gather(instance_masks,
tf.range(k) * num_mask_classes + tf.to_int32(classes))
def _extract_rpn_feature_maps(self, preprocessed_inputs): def _extract_rpn_feature_maps(self, preprocessed_inputs):
"""Extracts RPN features. """Extracts RPN features.
...@@ -914,7 +925,7 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -914,7 +925,7 @@ class FasterRCNNMetaArch(model.DetectionModel):
anchors = box_list_ops.concatenate( anchors = box_list_ops.concatenate(
self._first_stage_anchor_generator.generate([(feature_map_shape[1], self._first_stage_anchor_generator.generate([(feature_map_shape[1],
feature_map_shape[2])])) feature_map_shape[2])]))
with slim.arg_scope(self._first_stage_box_predictor_arg_scope): with slim.arg_scope(self._first_stage_box_predictor_arg_scope_fn()):
kernel_size = self._first_stage_box_predictor_kernel_size kernel_size = self._first_stage_box_predictor_kernel_size
rpn_box_predictor_features = slim.conv2d( rpn_box_predictor_features = slim.conv2d(
rpn_features_to_crop, rpn_features_to_crop,
...@@ -1814,11 +1825,18 @@ class FasterRCNNMetaArch(model.DetectionModel): ...@@ -1814,11 +1825,18 @@ class FasterRCNNMetaArch(model.DetectionModel):
# Pad the prediction_masks with to add zeros for background class to be # Pad the prediction_masks with to add zeros for background class to be
# consistent with class predictions. # consistent with class predictions.
if prediction_masks.get_shape().as_list()[1] == 1:
# Class agnostic masks or masks for one-class prediction. Logic for
# both cases is the same since background predictions are ignored
# through the batch_mask_target_weights.
prediction_masks_masked_by_class_targets = prediction_masks
else:
prediction_masks_with_background = tf.pad( prediction_masks_with_background = tf.pad(
prediction_masks, [[0, 0], [1, 0], [0, 0], [0, 0]]) prediction_masks, [[0, 0], [1, 0], [0, 0], [0, 0]])
prediction_masks_masked_by_class_targets = tf.boolean_mask( prediction_masks_masked_by_class_targets = tf.boolean_mask(
prediction_masks_with_background, prediction_masks_with_background,
tf.greater(one_hot_flat_cls_targets_with_background, 0)) tf.greater(one_hot_flat_cls_targets_with_background, 0))
mask_height = prediction_masks.shape[2].value mask_height = prediction_masks.shape[2].value
mask_width = prediction_masks.shape[3].value mask_width = prediction_masks.shape[3].value
reshaped_prediction_masks = tf.reshape( reshaped_prediction_masks = tf.reshape(
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
"""Tests for object_detection.meta_architectures.faster_rcnn_meta_arch.""" """Tests for object_detection.meta_architectures.faster_rcnn_meta_arch."""
from absl.testing import parameterized
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
...@@ -22,7 +23,8 @@ from object_detection.meta_architectures import faster_rcnn_meta_arch_test_lib ...@@ -22,7 +23,8 @@ from object_detection.meta_architectures import faster_rcnn_meta_arch_test_lib
class FasterRCNNMetaArchTest( class FasterRCNNMetaArchTest(
faster_rcnn_meta_arch_test_lib.FasterRCNNMetaArchTestBase): faster_rcnn_meta_arch_test_lib.FasterRCNNMetaArchTestBase,
parameterized.TestCase):
def test_postprocess_second_stage_only_inference_mode_with_masks(self): def test_postprocess_second_stage_only_inference_mode_with_masks(self):
model = self._build_model( model = self._build_model(
...@@ -83,8 +85,12 @@ class FasterRCNNMetaArchTest( ...@@ -83,8 +85,12 @@ class FasterRCNNMetaArchTest(
self.assertTrue(np.amax(detections_out['detection_masks'] <= 1.0)) self.assertTrue(np.amax(detections_out['detection_masks'] <= 1.0))
self.assertTrue(np.amin(detections_out['detection_masks'] >= 0.0)) self.assertTrue(np.amin(detections_out['detection_masks'] >= 0.0))
@parameterized.parameters(
{'masks_are_class_agnostic': False},
{'masks_are_class_agnostic': True},
)
def test_predict_correct_shapes_in_inference_mode_three_stages_with_masks( def test_predict_correct_shapes_in_inference_mode_three_stages_with_masks(
self): self, masks_are_class_agnostic):
batch_size = 2 batch_size = 2
image_size = 10 image_size = 10
max_num_proposals = 8 max_num_proposals = 8
...@@ -126,7 +132,8 @@ class FasterRCNNMetaArchTest( ...@@ -126,7 +132,8 @@ class FasterRCNNMetaArchTest(
is_training=False, is_training=False,
number_of_stages=3, number_of_stages=3,
second_stage_batch_size=2, second_stage_batch_size=2,
predict_masks=True) predict_masks=True,
masks_are_class_agnostic=masks_are_class_agnostic)
preprocessed_inputs = tf.placeholder(tf.float32, shape=input_shape) preprocessed_inputs = tf.placeholder(tf.float32, shape=input_shape)
_, true_image_shapes = model.preprocess(preprocessed_inputs) _, true_image_shapes = model.preprocess(preprocessed_inputs)
result_tensor_dict = model.predict(preprocessed_inputs, result_tensor_dict = model.predict(preprocessed_inputs,
...@@ -153,16 +160,20 @@ class FasterRCNNMetaArchTest( ...@@ -153,16 +160,20 @@ class FasterRCNNMetaArchTest(
self.assertAllEqual(tensor_dict_out['detection_scores'].shape, [2, 5]) self.assertAllEqual(tensor_dict_out['detection_scores'].shape, [2, 5])
self.assertAllEqual(tensor_dict_out['num_detections'].shape, [2]) self.assertAllEqual(tensor_dict_out['num_detections'].shape, [2])
@parameterized.parameters(
{'masks_are_class_agnostic': False},
{'masks_are_class_agnostic': True},
)
def test_predict_gives_correct_shapes_in_train_mode_both_stages_with_masks( def test_predict_gives_correct_shapes_in_train_mode_both_stages_with_masks(
self): self, masks_are_class_agnostic):
test_graph = tf.Graph() test_graph = tf.Graph()
with test_graph.as_default(): with test_graph.as_default():
model = self._build_model( model = self._build_model(
is_training=True, is_training=True,
number_of_stages=2, number_of_stages=3,
second_stage_batch_size=7, second_stage_batch_size=7,
predict_masks=True) predict_masks=True,
masks_are_class_agnostic=masks_are_class_agnostic)
batch_size = 2 batch_size = 2
image_size = 10 image_size = 10
max_num_proposals = 7 max_num_proposals = 7
...@@ -184,6 +195,7 @@ class FasterRCNNMetaArchTest( ...@@ -184,6 +195,7 @@ class FasterRCNNMetaArchTest(
groundtruth_classes_list) groundtruth_classes_list)
result_tensor_dict = model.predict(preprocessed_inputs, true_image_shapes) result_tensor_dict = model.predict(preprocessed_inputs, true_image_shapes)
mask_shape_1 = 1 if masks_are_class_agnostic else model._num_classes
expected_shapes = { expected_shapes = {
'rpn_box_predictor_features': (2, image_size, image_size, 512), 'rpn_box_predictor_features': (2, image_size, image_size, 512),
'rpn_features_to_crop': (2, image_size, image_size, 3), 'rpn_features_to_crop': (2, image_size, image_size, 3),
...@@ -197,7 +209,7 @@ class FasterRCNNMetaArchTest( ...@@ -197,7 +209,7 @@ class FasterRCNNMetaArchTest(
self._get_box_classifier_features_shape( self._get_box_classifier_features_shape(
image_size, batch_size, max_num_proposals, initial_crop_size, image_size, batch_size, max_num_proposals, initial_crop_size,
maxpool_stride, 3), maxpool_stride, 3),
'mask_predictions': (2 * max_num_proposals, 2, 14, 14) 'mask_predictions': (2 * max_num_proposals, mask_shape_1, 14, 14)
} }
init_op = tf.global_variables_initializer() init_op = tf.global_variables_initializer()
......
...@@ -90,10 +90,13 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase): ...@@ -90,10 +90,13 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
""" """
return box_predictor_text_proto return box_predictor_text_proto
def _add_mask_to_second_stage_box_predictor_text_proto(self): def _add_mask_to_second_stage_box_predictor_text_proto(
self, masks_are_class_agnostic=False):
agnostic = 'true' if masks_are_class_agnostic else 'false'
box_predictor_text_proto = """ box_predictor_text_proto = """
mask_rcnn_box_predictor { mask_rcnn_box_predictor {
predict_instance_masks: true predict_instance_masks: true
masks_are_class_agnostic: """ + agnostic + """
mask_height: 14 mask_height: 14
mask_width: 14 mask_width: 14
conv_hyperparams { conv_hyperparams {
...@@ -114,13 +117,14 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase): ...@@ -114,13 +117,14 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
return box_predictor_text_proto return box_predictor_text_proto
def _get_second_stage_box_predictor(self, num_classes, is_training, def _get_second_stage_box_predictor(self, num_classes, is_training,
predict_masks): predict_masks, masks_are_class_agnostic):
box_predictor_proto = box_predictor_pb2.BoxPredictor() box_predictor_proto = box_predictor_pb2.BoxPredictor()
text_format.Merge(self._get_second_stage_box_predictor_text_proto(), text_format.Merge(self._get_second_stage_box_predictor_text_proto(),
box_predictor_proto) box_predictor_proto)
if predict_masks: if predict_masks:
text_format.Merge( text_format.Merge(
self._add_mask_to_second_stage_box_predictor_text_proto(), self._add_mask_to_second_stage_box_predictor_text_proto(
masks_are_class_agnostic),
box_predictor_proto) box_predictor_proto)
return box_predictor_builder.build( return box_predictor_builder.build(
...@@ -146,7 +150,8 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase): ...@@ -146,7 +150,8 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
hard_mining=False, hard_mining=False,
softmax_second_stage_classification_loss=True, softmax_second_stage_classification_loss=True,
predict_masks=False, predict_masks=False,
pad_to_max_dimension=None): pad_to_max_dimension=None,
masks_are_class_agnostic=False):
def image_resizer_fn(image, masks=None): def image_resizer_fn(image, masks=None):
"""Fake image resizer function.""" """Fake image resizer function."""
...@@ -196,7 +201,7 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase): ...@@ -196,7 +201,7 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
} }
} }
""" """
first_stage_box_predictor_arg_scope = ( first_stage_box_predictor_arg_scope_fn = (
self._build_arg_scope_with_hyperparams( self._build_arg_scope_with_hyperparams(
first_stage_box_predictor_hyperparams_text_proto, is_training)) first_stage_box_predictor_hyperparams_text_proto, is_training))
...@@ -255,8 +260,8 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase): ...@@ -255,8 +260,8 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
'number_of_stages': number_of_stages, 'number_of_stages': number_of_stages,
'first_stage_anchor_generator': first_stage_anchor_generator, 'first_stage_anchor_generator': first_stage_anchor_generator,
'first_stage_atrous_rate': first_stage_atrous_rate, 'first_stage_atrous_rate': first_stage_atrous_rate,
'first_stage_box_predictor_arg_scope': 'first_stage_box_predictor_arg_scope_fn':
first_stage_box_predictor_arg_scope, first_stage_box_predictor_arg_scope_fn,
'first_stage_box_predictor_kernel_size': 'first_stage_box_predictor_kernel_size':
first_stage_box_predictor_kernel_size, first_stage_box_predictor_kernel_size,
'first_stage_box_predictor_depth': first_stage_box_predictor_depth, 'first_stage_box_predictor_depth': first_stage_box_predictor_depth,
...@@ -287,7 +292,8 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase): ...@@ -287,7 +292,8 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
self._get_second_stage_box_predictor( self._get_second_stage_box_predictor(
num_classes=num_classes, num_classes=num_classes,
is_training=is_training, is_training=is_training,
predict_masks=predict_masks), **common_kwargs) predict_masks=predict_masks,
masks_are_class_agnostic=masks_are_class_agnostic), **common_kwargs)
def test_predict_gives_correct_shapes_in_inference_mode_first_stage_only( def test_predict_gives_correct_shapes_in_inference_mode_first_stage_only(
self): self):
......
...@@ -56,7 +56,7 @@ class RFCNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch): ...@@ -56,7 +56,7 @@ class RFCNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
number_of_stages, number_of_stages,
first_stage_anchor_generator, first_stage_anchor_generator,
first_stage_atrous_rate, first_stage_atrous_rate,
first_stage_box_predictor_arg_scope, first_stage_box_predictor_arg_scope_fn,
first_stage_box_predictor_kernel_size, first_stage_box_predictor_kernel_size,
first_stage_box_predictor_depth, first_stage_box_predictor_depth,
first_stage_minibatch_size, first_stage_minibatch_size,
...@@ -103,8 +103,9 @@ class RFCNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch): ...@@ -103,8 +103,9 @@ class RFCNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
denser resolutions. The atrous rate is used to compensate for the denser resolutions. The atrous rate is used to compensate for the
denser feature maps by using an effectively larger receptive field. denser feature maps by using an effectively larger receptive field.
(This should typically be set to 1). (This should typically be set to 1).
first_stage_box_predictor_arg_scope: Slim arg_scope for conv2d, first_stage_box_predictor_arg_scope_fn: A function to generate tf-slim
separable_conv2d and fully_connected ops for the RPN box predictor. arg_scope for conv2d, separable_conv2d and fully_connected ops for the
RPN box predictor.
first_stage_box_predictor_kernel_size: Kernel size to use for the first_stage_box_predictor_kernel_size: Kernel size to use for the
convolution op just prior to RPN box predictions. convolution op just prior to RPN box predictions.
first_stage_box_predictor_depth: Output depth for the convolution op first_stage_box_predictor_depth: Output depth for the convolution op
...@@ -174,7 +175,7 @@ class RFCNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch): ...@@ -174,7 +175,7 @@ class RFCNMetaArch(faster_rcnn_meta_arch.FasterRCNNMetaArch):
number_of_stages, number_of_stages,
first_stage_anchor_generator, first_stage_anchor_generator,
first_stage_atrous_rate, first_stage_atrous_rate,
first_stage_box_predictor_arg_scope, first_stage_box_predictor_arg_scope_fn,
first_stage_box_predictor_kernel_size, first_stage_box_predictor_kernel_size,
first_stage_box_predictor_depth, first_stage_box_predictor_depth,
first_stage_minibatch_size, first_stage_minibatch_size,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment