Commit 8518d053 (unverified), authored by pkulzc, committed by GitHub

Open source MnasFPN and minor fixes to OD API (#8484)

310447280  by lzc:

    Internal change

310420845  by Zhichao Lu:

    Open source the internal Context RCNN code.

--
310362339  by Zhichao Lu:

    Internal change

310259448  by lzc:

    Update required TF version for OD API.

--
310252159  by Zhichao Lu:

    Port patch_ops_test to TF1/TF2 as well as TPUs.

--
310247180  by Zhichao Lu:

    Ignore keypoint heatmap loss in the regions/bounding boxes with target keypoint
    class but no valid keypoint annotations.

--
310178294  by Zhichao Lu:

    Open source MnasFPN
    https://arxiv.org/abs/1912.01106
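
    For anyone trying the released model: MnasFPN is wired in as an SSD
    feature extractor, so it is selected through the usual pipeline config.
    A minimal sketch is below; the extractor key is an assumption for
    illustration, not confirmed by this commit message -- check the SSD
    feature-extractor class maps in model_builder for the exact name.

        from google.protobuf import text_format
        from object_detection.protos import model_pb2

        model_proto = model_pb2.DetectionModel()
        text_format.Merge("ssd { }", model_proto)
        # 'ssd_mobilenet_v2_mnasfpn' is a hypothetical registry key.
        model_proto.ssd.feature_extractor.type = 'ssd_mobilenet_v2_mnasfpn'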

--
310094222  by lzc:

    Internal changes.

--
310085250  by lzc:

    Internal Change.

--
310016447  by huizhongc:

    Remove unrecognized classes from labeled_classes.

--
310009470  by rathodv:

    Mark batcher.py as TF1 only.

--
310001984  by rathodv:

    Update core/preprocessor.py to be compatible with TF1/TF2.

--
309455035  by Zhi...
parent ac5fff19
# Lint as: python2, python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -24,7 +25,13 @@ from object_detection.builders import hyperparams_builder
from object_detection.core import freezable_batch_norm
from object_detection.protos import hyperparams_pb2

# pylint: disable=g-import-not-at-top
try:
  from tensorflow.contrib import slim
except ImportError:
  # TF 2.0 doesn't ship with contrib.
  pass
# pylint: enable=g-import-not-at-top
def _get_scope_key(op):
@@ -49,7 +56,7 @@ class HyperparamsBuilderTest(tf.test.TestCase):
    scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
                                         is_training=True)
    scope = scope_fn()
    self.assertIn(_get_scope_key(slim.conv2d), scope)

  def test_default_arg_scope_has_separable_conv2d_op(self):
    conv_hyperparams_text_proto = """
@@ -67,7 +74,7 @@ class HyperparamsBuilderTest(tf.test.TestCase):
    scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
                                         is_training=True)
    scope = scope_fn()
    self.assertIn(_get_scope_key(slim.separable_conv2d), scope)

  def test_default_arg_scope_has_conv2d_transpose_op(self):
    conv_hyperparams_text_proto = """
@@ -85,7 +92,7 @@ class HyperparamsBuilderTest(tf.test.TestCase):
    scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
                                         is_training=True)
    scope = scope_fn()
    self.assertIn(_get_scope_key(slim.conv2d_transpose), scope)

  def test_explicit_fc_op_arg_scope_has_fully_connected_op(self):
    conv_hyperparams_text_proto = """
@@ -104,7 +111,7 @@ class HyperparamsBuilderTest(tf.test.TestCase):
    scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
                                         is_training=True)
    scope = scope_fn()
    self.assertIn(_get_scope_key(slim.fully_connected), scope)

  def test_separable_conv2d_and_conv2d_and_transpose_have_same_parameters(self):
    conv_hyperparams_text_proto = """
@@ -143,7 +150,7 @@ class HyperparamsBuilderTest(tf.test.TestCase):
    scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
                                         is_training=True)
    scope = scope_fn()
    conv_scope_arguments = list(scope.values())[0]

    regularizer = conv_scope_arguments['weights_regularizer']
    weights = np.array([1., -1, 4., 2.])
    with self.test_session() as sess:
@@ -284,8 +291,8 @@ class HyperparamsBuilderTest(tf.test.TestCase):
    self.assertTrue(batch_norm_params['scale'])
    batch_norm_layer = keras_config.build_batch_norm()
    self.assertIsInstance(batch_norm_layer,
                          freezable_batch_norm.FreezableBatchNorm)

  def test_return_non_default_batch_norm_params_keras_override(
      self):
@@ -420,8 +427,8 @@ class HyperparamsBuilderTest(tf.test.TestCase):
    # The batch norm builder should build an identity Lambda layer
    identity_layer = keras_config.build_batch_norm()
    self.assertIsInstance(identity_layer,
                          tf.keras.layers.Lambda)

  def test_use_none_activation(self):
    conv_hyperparams_text_proto = """
@@ -463,7 +470,7 @@ class HyperparamsBuilderTest(tf.test.TestCase):
    self.assertEqual(
        keras_config.params(include_activation=True)['activation'], None)
    activation_layer = keras_config.build_activation_layer()
    self.assertIsInstance(activation_layer, tf.keras.layers.Lambda)
    self.assertEqual(activation_layer.function, tf.identity)

  def test_use_relu_activation(self):
@@ -506,7 +513,7 @@ class HyperparamsBuilderTest(tf.test.TestCase):
    self.assertEqual(
        keras_config.params(include_activation=True)['activation'], tf.nn.relu)
    activation_layer = keras_config.build_activation_layer()
    self.assertIsInstance(activation_layer, tf.keras.layers.Lambda)
    self.assertEqual(activation_layer.function, tf.nn.relu)

  def test_use_relu_6_activation(self):
@@ -549,9 +556,52 @@ class HyperparamsBuilderTest(tf.test.TestCase):
    self.assertEqual(
        keras_config.params(include_activation=True)['activation'], tf.nn.relu6)
    activation_layer = keras_config.build_activation_layer()
    self.assertIsInstance(activation_layer, tf.keras.layers.Lambda)
    self.assertEqual(activation_layer.function, tf.nn.relu6)

  def test_use_swish_activation(self):
    conv_hyperparams_text_proto = """
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        truncated_normal_initializer {
        }
      }
      activation: SWISH
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
                                         is_training=True)
    scope = scope_fn()
    conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
    self.assertEqual(conv_scope_arguments['activation_fn'], tf.nn.swish)

  def test_use_swish_activation_keras(self):
    conv_hyperparams_text_proto = """
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        truncated_normal_initializer {
        }
      }
      activation: SWISH
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    keras_config = hyperparams_builder.KerasLayerHyperparams(
        conv_hyperparams_proto)
    self.assertEqual(keras_config.params()['activation'], None)
    self.assertEqual(
        keras_config.params(include_activation=True)['activation'], tf.nn.swish)
    activation_layer = keras_config.build_activation_layer()
    self.assertIsInstance(activation_layer, tf.keras.layers.Lambda)
    self.assertEqual(activation_layer.function, tf.nn.swish)

  def test_override_activation_keras(self):
    conv_hyperparams_text_proto = """
      regularizer {
......
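
For context on the new SWISH activation tested above: tf.nn.swish computes
x * sigmoid(x) elementwise. A quick standalone check of that identity (a
sketch, independent of the builder code):

import numpy as np

def swish(x):
  # swish(x) = x * sigmoid(x) = x / (1 + exp(-x))
  return x / (1.0 + np.exp(-x))

assert abs(swish(1.0) - 0.7310585786300049) < 1e-12  # 1 * sigmoid(1)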
@@ -133,9 +133,22 @@ def build(image_resizer_config):
          'Invalid image resizer condition option for '
          'ConditionalShapeResizer: \'%s\'.'
          % conditional_shape_resize_config.condition)
    if not conditional_shape_resize_config.convert_to_grayscale:
      return image_resizer_fn
  elif image_resizer_oneof == 'pad_to_multiple_resizer':
    pad_to_multiple_resizer_config = (
        image_resizer_config.pad_to_multiple_resizer)

    if pad_to_multiple_resizer_config.multiple < 0:
      raise ValueError('`multiple` for pad_to_multiple_resizer should be > 0.')
    else:
      image_resizer_fn = functools.partial(
          preprocessor.resize_pad_to_multiple,
          multiple=pad_to_multiple_resizer_config.multiple)

    if not pad_to_multiple_resizer_config.convert_to_grayscale:
      return image_resizer_fn
  else:
    raise ValueError(
        'Invalid image resizer option: \'%s\'.' % image_resizer_oneof)
......
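
The pad_to_multiple_resizer added above pads each spatial dimension up to
the next multiple of `multiple`, as the test below demonstrates (a 60x30
image becomes 64x32 with multiple: 32). A standalone sketch of the shape
arithmetic, assuming only that behavior:

def pad_up_to_multiple(dim, multiple):
  # Ceiling division without floats: 60 -> 64 and 30 -> 32 for multiple 32.
  return -(-dim // multiple) * multiple

assert (pad_up_to_multiple(60, 32), pad_up_to_multiple(30, 32)) == (64, 32)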
@@ -211,6 +211,31 @@ class ImageResizerBuilderTest(tf.test.TestCase):
    with self.assertRaises(ValueError):
      image_resizer_builder.build(invalid_image_resizer_text_proto)

  def test_build_pad_to_multiple_resizer(self):
    """Test building a pad_to_multiple_resizer from proto."""
    image_resizer_text_proto = """
      pad_to_multiple_resizer {
        multiple: 32
      }
    """
    input_shape = (60, 30, 3)
    expected_output_shape = (64, 32, 3)
    output_shape = self._shape_of_resized_random_image_given_text_proto(
        input_shape, image_resizer_text_proto)
    self.assertEqual(output_shape, expected_output_shape)

  def test_build_pad_to_multiple_resizer_invalid_multiple(self):
    """Test that building a pad_to_multiple_resizer errors with invalid multiple."""
    image_resizer_text_proto = """
      pad_to_multiple_resizer {
        multiple: -10
      }
    """
    with self.assertRaises(ValueError):
      image_resizer_builder.build(image_resizer_text_proto)


if __name__ == '__main__':
  tf.test.main()
# Lint as: python2, python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -23,12 +24,24 @@ Detection configuration framework, they should define their own builder function
that wraps the build function.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from object_detection.data_decoders import tf_example_decoder
from object_detection.protos import input_reader_pb2

# pylint: disable=g-import-not-at-top
try:
  from tensorflow.contrib import slim as contrib_slim
except ImportError:
  # TF 2.0 doesn't ship with contrib.
  pass
# pylint: enable=g-import-not-at-top

parallel_reader = contrib_slim.parallel_reader
def build(input_reader_config):
@@ -70,7 +83,8 @@ def build(input_reader_config):
    decoder = tf_example_decoder.TfExampleDecoder(
        load_instance_masks=input_reader_config.load_instance_masks,
        instance_mask_type=input_reader_config.mask_type,
        label_map_proto_file=label_map_proto_file,
        load_context_features=input_reader_config.load_context_features)
    return decoder.decode(string_tensor)

  raise ValueError('Unsupported input_reader_config.')
@@ -54,6 +54,48 @@ class InputReaderBuilderTest(tf.test.TestCase):
    return path

  def create_tf_record_with_context(self):
    path = os.path.join(self.get_temp_dir(), 'tfrecord')
    writer = tf.python_io.TFRecordWriter(path)

    image_tensor = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
    flat_mask = (4 * 5) * [1.0]
    context_features = (10 * 3) * [1.0]
    with self.test_session():
      encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).eval()
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                'image/encoded':
                    dataset_util.bytes_feature(encoded_jpeg),
                'image/format':
                    dataset_util.bytes_feature('jpeg'.encode('utf8')),
                'image/height':
                    dataset_util.int64_feature(4),
                'image/width':
                    dataset_util.int64_feature(5),
                'image/object/bbox/xmin':
                    dataset_util.float_list_feature([0.0]),
                'image/object/bbox/xmax':
                    dataset_util.float_list_feature([1.0]),
                'image/object/bbox/ymin':
                    dataset_util.float_list_feature([0.0]),
                'image/object/bbox/ymax':
                    dataset_util.float_list_feature([1.0]),
                'image/object/class/label':
                    dataset_util.int64_list_feature([2]),
                'image/object/mask':
                    dataset_util.float_list_feature(flat_mask),
                'image/context_features':
                    dataset_util.float_list_feature(context_features),
                'image/context_feature_length':
                    dataset_util.int64_list_feature([10]),
            }))
    writer.write(example.SerializeToString())
    writer.close()

    return path
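
A note on the layout above: 'image/context_features' holds a flat float
list and 'image/context_feature_length' gives the width of each context
feature, so 30 floats with length 10 decode to a (3, 10) tensor, which the
with-context test below asserts. A sketch of that reshape, independent of
the decoder:

import numpy as np

flat = np.ones(10 * 3, dtype=np.float32)    # 'image/context_features'
feature_length = 10                         # 'image/context_feature_length'
context = flat.reshape(-1, feature_length)  # 3 context features of width 10
assert context.shape == (3, 10)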
  def test_build_tf_record_input_reader(self):
    tf_record_path = self.create_tf_record()
@@ -71,17 +113,52 @@ class InputReaderBuilderTest(tf.test.TestCase):
    with tf.train.MonitoredSession() as sess:
      output_dict = sess.run(tensor_dict)

    self.assertNotIn(fields.InputDataFields.groundtruth_instance_masks,
                     output_dict)
    self.assertEqual((4, 5, 3), output_dict[fields.InputDataFields.image].shape)
    self.assertEqual([2],
                     output_dict[fields.InputDataFields.groundtruth_classes])
    self.assertEqual(
        (1, 4), output_dict[fields.InputDataFields.groundtruth_boxes].shape)
    self.assertAllEqual(
        [0.0, 0.0, 1.0, 1.0],
        output_dict[fields.InputDataFields.groundtruth_boxes][0])

  def test_build_tf_record_input_reader_with_context(self):
    tf_record_path = self.create_tf_record_with_context()

    input_reader_text_proto = """
      shuffle: false
      num_readers: 1
      tf_record_input_reader {{
        input_path: '{0}'
      }}
    """.format(tf_record_path)
    input_reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(input_reader_text_proto, input_reader_proto)
    input_reader_proto.load_context_features = True
    tensor_dict = input_reader_builder.build(input_reader_proto)

    with tf.train.MonitoredSession() as sess:
      output_dict = sess.run(tensor_dict)

    self.assertNotIn(fields.InputDataFields.groundtruth_instance_masks,
                     output_dict)
    self.assertEqual((4, 5, 3), output_dict[fields.InputDataFields.image].shape)
    self.assertEqual([2],
                     output_dict[fields.InputDataFields.groundtruth_classes])
    self.assertEqual(
        (1, 4), output_dict[fields.InputDataFields.groundtruth_boxes].shape)
    self.assertAllEqual(
        [0.0, 0.0, 1.0, 1.0],
        output_dict[fields.InputDataFields.groundtruth_boxes][0])
    self.assertAllEqual(
        (3, 10), output_dict[fields.InputDataFields.context_features].shape)
    self.assertAllEqual(
        (10), output_dict[fields.InputDataFields.context_feature_length])

  def test_build_tf_record_input_reader_and_load_instance_masks(self):
    tf_record_path = self.create_tf_record()
@@ -101,11 +178,10 @@ class InputReaderBuilderTest(tf.test.TestCase):
    with tf.train.MonitoredSession() as sess:
      output_dict = sess.run(tensor_dict)

    self.assertEqual((4, 5, 3), output_dict[fields.InputDataFields.image].shape)
    self.assertEqual([2],
                     output_dict[fields.InputDataFields.groundtruth_classes])
    self.assertEqual(
        (1, 4), output_dict[fields.InputDataFields.groundtruth_boxes].shape)
    self.assertAllEqual(
        [0.0, 0.0, 1.0, 1.0],
......
@@ -201,6 +201,9 @@ def _build_localization_loss(loss_config):
  if loss_type == 'weighted_iou':
    return losses.WeightedIOULocalizationLoss()

  if loss_type == 'l1_localization_loss':
    return losses.L1LocalizationLoss()

  raise ValueError('Empty loss config.')
@@ -249,4 +252,9 @@ def _build_classification_loss(loss_config):
        alpha=config.alpha,
        bootstrap_type=('hard' if config.hard_bootstrap else 'soft'))

  if loss_type == 'penalty_reduced_logistic_focal_loss':
    config = loss_config.penalty_reduced_logistic_focal_loss
    return losses.PenaltyReducedLogisticFocalLoss(
        alpha=config.alpha, beta=config.beta)

  raise ValueError('Empty loss config.')
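
Of the two losses added above, L1LocalizationLoss is plain absolute error
on box coordinates, while PenaltyReducedLogisticFocalLoss follows the
CornerNet/CenterNet-style focal loss: alpha focuses training on hard
examples and beta reduces the penalty for negatives near a ground-truth
peak. A NumPy sketch of that formulation (an illustration, not the
library's implementation):

import numpy as np

def penalty_reduced_focal_loss(pred, target, alpha=2.0, beta=4.0):
  # target is a soft heatmap in [0, 1], exactly 1.0 at positive locations.
  positive = ((1.0 - pred) ** alpha) * np.log(pred)
  negative = ((1.0 - target) ** beta) * (pred ** alpha) * np.log(1.0 - pred)
  return -np.where(target == 1.0, positive, negative)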
@@ -40,8 +40,8 @@ class LocalizationLossBuilderTest(tf.test.TestCase):
    losses_proto = losses_pb2.Loss()
    text_format.Merge(losses_text_proto, losses_proto)
    _, localization_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
    self.assertIsInstance(localization_loss,
                          losses.WeightedL2LocalizationLoss)

  def test_build_weighted_smooth_l1_localization_loss_default_delta(self):
    losses_text_proto = """
@@ -57,8 +57,8 @@ class LocalizationLossBuilderTest(tf.test.TestCase):
    losses_proto = losses_pb2.Loss()
    text_format.Merge(losses_text_proto, losses_proto)
    _, localization_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
    self.assertIsInstance(localization_loss,
                          losses.WeightedSmoothL1LocalizationLoss)
    self.assertAlmostEqual(localization_loss._delta, 1.0)

  def test_build_weighted_smooth_l1_localization_loss_non_default_delta(self):
@@ -76,8 +76,8 @@ class LocalizationLossBuilderTest(tf.test.TestCase):
    losses_proto = losses_pb2.Loss()
    text_format.Merge(losses_text_proto, losses_proto)
    _, localization_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
    self.assertIsInstance(localization_loss,
                          losses.WeightedSmoothL1LocalizationLoss)
    self.assertAlmostEqual(localization_loss._delta, 0.1)

  def test_build_weighted_iou_localization_loss(self):
@@ -94,8 +94,8 @@ class LocalizationLossBuilderTest(tf.test.TestCase):
    losses_proto = losses_pb2.Loss()
    text_format.Merge(losses_text_proto, losses_proto)
    _, localization_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
    self.assertIsInstance(localization_loss,
                          losses.WeightedIOULocalizationLoss)

  def test_anchorwise_output(self):
    losses_text_proto = """
@@ -111,8 +111,8 @@ class LocalizationLossBuilderTest(tf.test.TestCase):
    losses_proto = losses_pb2.Loss()
    text_format.Merge(losses_text_proto, losses_proto)
    _, localization_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
    self.assertIsInstance(localization_loss,
                          losses.WeightedSmoothL1LocalizationLoss)
    predictions = tf.constant([[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]]])
    targets = tf.constant([[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]]])
    weights = tf.constant([[1.0, 1.0]])
@@ -132,6 +132,7 @@ class LocalizationLossBuilderTest(tf.test.TestCase):
      losses_builder._build_localization_loss(losses_proto)


class ClassificationLossBuilderTest(tf.test.TestCase):

  def test_build_weighted_sigmoid_classification_loss(self):
@@ -148,8 +149,8 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
    losses_proto = losses_pb2.Loss()
    text_format.Merge(losses_text_proto, losses_proto)
    classification_loss, _, _, _, _, _, _ = losses_builder.build(losses_proto)
    self.assertIsInstance(classification_loss,
                          losses.WeightedSigmoidClassificationLoss)

  def test_build_weighted_sigmoid_focal_classification_loss(self):
    losses_text_proto = """
@@ -165,8 +166,8 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
    losses_proto = losses_pb2.Loss()
    text_format.Merge(losses_text_proto, losses_proto)
    classification_loss, _, _, _, _, _, _ = losses_builder.build(losses_proto)
    self.assertIsInstance(classification_loss,
                          losses.SigmoidFocalClassificationLoss)
    self.assertAlmostEqual(classification_loss._alpha, None)
    self.assertAlmostEqual(classification_loss._gamma, 2.0)
@@ -186,8 +187,8 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
    losses_proto = losses_pb2.Loss()
    text_format.Merge(losses_text_proto, losses_proto)
    classification_loss, _, _, _, _, _, _ = losses_builder.build(losses_proto)
    self.assertIsInstance(classification_loss,
                          losses.SigmoidFocalClassificationLoss)
    self.assertAlmostEqual(classification_loss._alpha, 0.25)
    self.assertAlmostEqual(classification_loss._gamma, 3.0)
@@ -205,8 +206,8 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
    losses_proto = losses_pb2.Loss()
    text_format.Merge(losses_text_proto, losses_proto)
    classification_loss, _, _, _, _, _, _ = losses_builder.build(losses_proto)
    self.assertIsInstance(classification_loss,
                          losses.WeightedSoftmaxClassificationLoss)

  def test_build_weighted_logits_softmax_classification_loss(self):
    losses_text_proto = """
@@ -222,9 +223,9 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
    losses_proto = losses_pb2.Loss()
    text_format.Merge(losses_text_proto, losses_proto)
    classification_loss, _, _, _, _, _, _ = losses_builder.build(losses_proto)
    self.assertIsInstance(
        classification_loss,
        losses.WeightedSoftmaxClassificationAgainstLogitsLoss)

  def test_build_weighted_softmax_classification_loss_with_logit_scale(self):
    losses_text_proto = """
@@ -241,8 +242,8 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
    losses_proto = losses_pb2.Loss()
    text_format.Merge(losses_text_proto, losses_proto)
    classification_loss, _, _, _, _, _, _ = losses_builder.build(losses_proto)
    self.assertIsInstance(classification_loss,
                          losses.WeightedSoftmaxClassificationLoss)

  def test_build_bootstrapped_sigmoid_classification_loss(self):
    losses_text_proto = """
@@ -259,8 +260,8 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
    losses_proto = losses_pb2.Loss()
    text_format.Merge(losses_text_proto, losses_proto)
    classification_loss, _, _, _, _, _, _ = losses_builder.build(losses_proto)
    self.assertIsInstance(classification_loss,
                          losses.BootstrappedSigmoidClassificationLoss)

  def test_anchorwise_output(self):
    losses_text_proto = """
@@ -277,8 +278,8 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
    losses_proto = losses_pb2.Loss()
    text_format.Merge(losses_text_proto, losses_proto)
    classification_loss, _, _, _, _, _, _ = losses_builder.build(losses_proto)
    self.assertIsInstance(classification_loss,
                          losses.WeightedSigmoidClassificationLoss)
    predictions = tf.constant([[[0.0, 1.0, 0.0], [0.0, 0.5, 0.5]]])
    targets = tf.constant([[[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]])
    weights = tf.constant([[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]])
@@ -298,6 +299,7 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
      losses_builder.build(losses_proto)


class HardExampleMinerBuilderTest(tf.test.TestCase):

  def test_do_not_build_hard_example_miner_by_default(self):
@@ -333,7 +335,7 @@ class HardExampleMinerBuilderTest(tf.test.TestCase):
    losses_proto = losses_pb2.Loss()
    text_format.Merge(losses_text_proto, losses_proto)
    _, _, _, _, hard_example_miner, _, _ = losses_builder.build(losses_proto)
    self.assertIsInstance(hard_example_miner, losses.HardExampleMiner)
    self.assertEqual(hard_example_miner._loss_type, 'cls')

  def test_build_hard_example_miner_for_localization_loss(self):
@@ -353,7 +355,7 @@ class HardExampleMinerBuilderTest(tf.test.TestCase):
    losses_proto = losses_pb2.Loss()
    text_format.Merge(losses_text_proto, losses_proto)
    _, _, _, _, hard_example_miner, _, _ = losses_builder.build(losses_proto)
    self.assertIsInstance(hard_example_miner, losses.HardExampleMiner)
    self.assertEqual(hard_example_miner._loss_type, 'loc')

  def test_build_hard_example_miner_with_non_default_values(self):
@@ -377,7 +379,7 @@ class HardExampleMinerBuilderTest(tf.test.TestCase):
    losses_proto = losses_pb2.Loss()
    text_format.Merge(losses_text_proto, losses_proto)
    _, _, _, _, hard_example_miner, _, _ = losses_builder.build(losses_proto)
    self.assertIsInstance(hard_example_miner, losses.HardExampleMiner)
    self.assertEqual(hard_example_miner._num_hard_examples, 32)
    self.assertAlmostEqual(hard_example_miner._iou_threshold, 0.5)
    self.assertEqual(hard_example_miner._max_negatives_per_positive, 10)
@@ -406,11 +408,11 @@ class LossBuilderTest(tf.test.TestCase):
    (classification_loss, localization_loss, classification_weight,
     localization_weight, hard_example_miner, _,
     _) = losses_builder.build(losses_proto)
    self.assertIsInstance(hard_example_miner, losses.HardExampleMiner)
    self.assertIsInstance(classification_loss,
                          losses.WeightedSoftmaxClassificationLoss)
    self.assertIsInstance(localization_loss,
                          losses.WeightedL2LocalizationLoss)
    self.assertAlmostEqual(classification_weight, 0.8)
    self.assertAlmostEqual(localization_weight, 0.2)
@@ -434,12 +436,10 @@ class LossBuilderTest(tf.test.TestCase):
    (classification_loss, localization_loss, classification_weight,
     localization_weight, hard_example_miner, _,
     _) = losses_builder.build(losses_proto)
    self.assertIsInstance(hard_example_miner, losses.HardExampleMiner)
    self.assertIsInstance(classification_loss,
                          losses.WeightedSoftmaxClassificationLoss)
    self.assertIsInstance(localization_loss, losses.WeightedL2LocalizationLoss)
    self.assertAlmostEqual(classification_weight, 0.8)
    self.assertAlmostEqual(localization_weight, 0.2)
@@ -464,12 +464,10 @@ class LossBuilderTest(tf.test.TestCase):
    (classification_loss, localization_loss, classification_weight,
     localization_weight, hard_example_miner, _,
     _) = losses_builder.build(losses_proto)
    self.assertIsInstance(hard_example_miner, losses.HardExampleMiner)
    self.assertIsInstance(classification_loss,
                          losses.WeightedSoftmaxClassificationLoss)
    self.assertIsInstance(localization_loss, losses.WeightedL2LocalizationLoss)
    self.assertAlmostEqual(classification_weight, 0.8)
    self.assertAlmostEqual(localization_weight, 0.2)
@@ -505,8 +503,8 @@ class FasterRcnnClassificationLossBuilderTest(tf.test.TestCase):
    text_format.Merge(losses_text_proto, losses_proto)
    classification_loss = losses_builder.build_faster_rcnn_classification_loss(
        losses_proto)
    self.assertIsInstance(classification_loss,
                          losses.WeightedSigmoidClassificationLoss)

  def test_build_softmax_loss(self):
    losses_text_proto = """
@@ -517,8 +515,8 @@ class FasterRcnnClassificationLossBuilderTest(tf.test.TestCase):
    text_format.Merge(losses_text_proto, losses_proto)
    classification_loss = losses_builder.build_faster_rcnn_classification_loss(
        losses_proto)
    self.assertIsInstance(classification_loss,
                          losses.WeightedSoftmaxClassificationLoss)

  def test_build_logits_softmax_loss(self):
    losses_text_proto = """
@@ -542,9 +540,8 @@ class FasterRcnnClassificationLossBuilderTest(tf.test.TestCase):
    text_format.Merge(losses_text_proto, losses_proto)
    classification_loss = losses_builder.build_faster_rcnn_classification_loss(
        losses_proto)
    self.assertIsInstance(classification_loss,
                          losses.SigmoidFocalClassificationLoss)

  def test_build_softmax_loss_by_default(self):
    losses_text_proto = """
@@ -553,8 +550,8 @@ class FasterRcnnClassificationLossBuilderTest(tf.test.TestCase):
    text_format.Merge(losses_text_proto, losses_proto)
    classification_loss = losses_builder.build_faster_rcnn_classification_loss(
        losses_proto)
    self.assertIsInstance(classification_loss,
                          losses.WeightedSoftmaxClassificationLoss)


if __name__ == '__main__':
......
# Lint as: python2, python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,25 +13,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for object_detection.models.model_builder."""

from absl.testing import parameterized

from google.protobuf import text_format
from object_detection.builders import model_builder
from object_detection.meta_architectures import faster_rcnn_meta_arch
from object_detection.meta_architectures import rfcn_meta_arch
from object_detection.meta_architectures import ssd_meta_arch
from object_detection.protos import hyperparams_pb2
from object_detection.protos import losses_pb2
from object_detection.protos import model_pb2
from object_detection.utils import test_case


class ModelBuilderTest(test_case.TestCase, parameterized.TestCase):

  def default_ssd_feature_extractor(self):
    raise NotImplementedError

  def default_faster_rcnn_feature_extractor(self):
    raise NotImplementedError

  def ssd_feature_extractors(self):
    raise NotImplementedError

  def faster_rcnn_feature_extractors(self):
    raise NotImplementedError

  def create_model(self, model_config, is_training=True):
    """Builds a DetectionModel based on the model config.
@@ -50,7 +60,6 @@ class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase):
    model_text_proto = """
      ssd {
        feature_extractor {
          conv_hyperparams {
            regularizer {
              l2_regularizer {
@@ -113,6 +122,8 @@ class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase):
      }"""
    model_proto = model_pb2.DetectionModel()
    text_format.Merge(model_text_proto, model_proto)
    model_proto.ssd.feature_extractor.type = (self.
                                              default_ssd_feature_extractor())
    return model_proto

  def create_default_faster_rcnn_model_proto(self):
@@ -127,9 +138,6 @@ class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase):
          max_dimension: 1024
        }
      }
      first_stage_anchor_generator {
        grid_anchor_generator {
          scales: [0.25, 0.5, 1.0, 2.0]
@@ -188,17 +196,14 @@ class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase):
      }"""
    model_proto = model_pb2.DetectionModel()
    text_format.Merge(model_text_proto, model_proto)
    (model_proto.faster_rcnn.feature_extractor.type
    ) = self.default_faster_rcnn_feature_extractor()
    return model_proto

  def test_create_ssd_models_from_config(self):
    model_proto = self.create_default_ssd_model_proto()
    for extractor_type, extractor_class in self.ssd_feature_extractors().items(
    ):
      model_proto.ssd.feature_extractor.type = extractor_type
      model = model_builder.build(model_proto, is_training=True)
      self.assertIsInstance(model, ssd_meta_arch.SSDMetaArch)
@@ -206,12 +211,9 @@ class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase):
  def test_create_ssd_fpn_model_from_config(self):
    model_proto = self.create_default_ssd_model_proto()
    model_proto.ssd.feature_extractor.fpn.min_level = 3
    model_proto.ssd.feature_extractor.fpn.max_level = 7
    model = model_builder.build(model_proto, is_training=True)
    self.assertEqual(model._feature_extractor._fpn_min_level, 3)
    self.assertEqual(model._feature_extractor._fpn_max_level, 7)
@@ -238,8 +240,9 @@ class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase):
          'enable_mask_prediction': False
      },
  )
  def test_create_faster_rcnn_models_from_config(self,
                                                 use_matmul_crop_and_resize,
                                                 enable_mask_prediction):
    model_proto = self.create_default_faster_rcnn_model_proto()
    faster_rcnn_config = model_proto.faster_rcnn
    faster_rcnn_config.use_matmul_crop_and_resize = use_matmul_crop_and_resize
@@ -250,7 +253,7 @@ class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase):
      mask_predictor_config.predict_instance_masks = True
    for extractor_type, extractor_class in (
        self.faster_rcnn_feature_extractors().items()):
      faster_rcnn_config.feature_extractor.type = extractor_type
      model = model_builder.build(model_proto, is_training=True)
      self.assertIsInstance(model, faster_rcnn_meta_arch.FasterRCNNMetaArch)
@@ -270,44 +273,51 @@ class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase):
        model_proto.faster_rcnn.second_stage_box_predictor.rfcn_box_predictor)
    rfcn_predictor_config.conv_hyperparams.op = hyperparams_pb2.Hyperparams.CONV
    for extractor_type, extractor_class in (
        self.faster_rcnn_feature_extractors().items()):
      model_proto.faster_rcnn.feature_extractor.type = extractor_type
      model = model_builder.build(model_proto, is_training=True)
      self.assertIsInstance(model, rfcn_meta_arch.RFCNMetaArch)
      self.assertIsInstance(model._feature_extractor, extractor_class)

  @parameterized.parameters(True, False)
  def test_create_faster_rcnn_from_config_with_crop_feature(
      self, output_final_box_features):
    model_proto = self.create_default_faster_rcnn_model_proto()
    model_proto.faster_rcnn.output_final_box_features = (
        output_final_box_features)
    _ = model_builder.build(model_proto, is_training=True)

  def test_invalid_model_config_proto(self):
    model_proto = ''
    with self.assertRaisesRegex(
        ValueError, 'model_config not of type model_pb2.DetectionModel.'):
      model_builder.build(model_proto, is_training=True)

  def test_unknown_meta_architecture(self):
    model_proto = model_pb2.DetectionModel()
    with self.assertRaisesRegex(ValueError, 'Unknown meta architecture'):
      model_builder.build(model_proto, is_training=True)

  def test_unknown_ssd_feature_extractor(self):
    model_proto = self.create_default_ssd_model_proto()
    model_proto.ssd.feature_extractor.type = 'unknown_feature_extractor'
    with self.assertRaises(ValueError):
      model_builder.build(model_proto, is_training=True)

  def test_unknown_faster_rcnn_feature_extractor(self):
    model_proto = self.create_default_faster_rcnn_model_proto()
    model_proto.faster_rcnn.feature_extractor.type = 'unknown_feature_extractor'
    with self.assertRaises(ValueError):
      model_builder.build(model_proto, is_training=True)

  def test_invalid_first_stage_nms_iou_threshold(self):
    model_proto = self.create_default_faster_rcnn_model_proto()
    model_proto.faster_rcnn.first_stage_nms_iou_threshold = 1.1
    with self.assertRaisesRegex(ValueError,
                                r'iou_threshold not in \[0, 1\.0\]'):
      model_builder.build(model_proto, is_training=True)
    model_proto.faster_rcnn.first_stage_nms_iou_threshold = -0.1
    with self.assertRaisesRegex(ValueError,
                                r'iou_threshold not in \[0, 1\.0\]'):
      model_builder.build(model_proto, is_training=True)
@@ -315,7 +325,7 @@ class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase):
    model_proto = self.create_default_faster_rcnn_model_proto()
    model_proto.faster_rcnn.first_stage_max_proposals = 1
    model_proto.faster_rcnn.second_stage_batch_size = 2
    with self.assertRaisesRegex(
        ValueError, 'second_stage_batch_size should be no greater '
        'than first_stage_max_proposals.'):
      model_builder.build(model_proto, is_training=True)
@@ -323,7 +333,7 @@ class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase):
  def test_invalid_faster_rcnn_batchnorm_update(self):
    model_proto = self.create_default_faster_rcnn_model_proto()
    model_proto.faster_rcnn.inplace_batchnorm_update = True
    with self.assertRaisesRegex(ValueError,
                                'inplace batchnorm updates not supported'):
      model_builder.build(model_proto, is_training=True)
@@ -340,7 +350,3 @@ class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase):
    text_format.Merge(model_text_proto, model_proto)
    self.assertEqual(model_builder.build(model_proto, is_training=True), 42)
+# Lint as: python2, python3
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for model_builder under TensorFlow 1.X."""
+from absl.testing import parameterized
+
+import tensorflow as tf
+
+from object_detection.builders import model_builder
+from object_detection.builders import model_builder_test
+from object_detection.meta_architectures import ssd_meta_arch
+from object_detection.protos import losses_pb2
+
+
+class ModelBuilderTF1Test(model_builder_test.ModelBuilderTest):
+
+  def default_ssd_feature_extractor(self):
+    return 'ssd_resnet50_v1_fpn'
+
+  def default_faster_rcnn_feature_extractor(self):
+    return 'faster_rcnn_resnet101'
+
+  def ssd_feature_extractors(self):
+    return model_builder.SSD_FEATURE_EXTRACTOR_CLASS_MAP
+
+  def faster_rcnn_feature_extractors(self):
+    return model_builder.FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP
+
+
+if __name__ == '__main__':
+  tf.test.main()
@@ -18,6 +18,7 @@

 import tensorflow as tf

+from tensorflow.contrib import opt as tf_opt
 from object_detection.utils import learning_schedules

@@ -64,14 +65,14 @@ def build_optimizers_tf_v1(optimizer_config, global_step=None):
     learning_rate = _create_learning_rate(config.learning_rate,
                                           global_step=global_step)
     summary_vars.append(learning_rate)
-    optimizer = tf.train.AdamOptimizer(learning_rate)
+    optimizer = tf.train.AdamOptimizer(learning_rate, epsilon=config.epsilon)

   if optimizer is None:
     raise ValueError('Optimizer %s not supported.' % optimizer_type)

   if optimizer_config.use_moving_average:
-    optimizer = tf.contrib.opt.MovingAverageOptimizer(
+    optimizer = tf_opt.MovingAverageOptimizer(
         optimizer, average_decay=optimizer_config.moving_average_decay)

   return optimizer, summary_vars

@@ -120,7 +121,7 @@ def build_optimizers_tf_v2(optimizer_config, global_step=None):
     learning_rate = _create_learning_rate(config.learning_rate,
                                           global_step=global_step)
     summary_vars.append(learning_rate)
-    optimizer = tf.keras.optimizers.Adam(learning_rate)
+    optimizer = tf.keras.optimizers.Adam(learning_rate, epsilon=config.epsilon)

   if optimizer is None:
     raise ValueError('Optimizer %s not supported.' % optimizer_type)
...
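
The two Adam changes above plumb the new epsilon field of the optimizer proto into both constructors. This matters because the defaults differ between stacks; a minimal sketch under TF 1.x (the literal epsilon value is illustrative, standing in for config.epsilon):

    import tensorflow as tf

    learning_rate = 0.002
    epsilon = 1e-6  # stand-in for config.epsilon from the optimizer proto

    # TF1 AdamOptimizer defaults to epsilon=1e-08; tf.keras Adam defaults to
    # 1e-07. Passing it explicitly keeps the two builders numerically aligned.
    optimizer_v1 = tf.train.AdamOptimizer(learning_rate, epsilon=epsilon)
    optimizer_v2 = tf.keras.optimizers.Adam(learning_rate, epsilon=epsilon)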
+# Lint as: python2, python3
 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,6 +16,11 @@

 """Tests for optimizer_builder."""

+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
 import tensorflow as tf

 from google.protobuf import text_format
@@ -22,6 +28,14 @@ from google.protobuf import text_format
 from object_detection.builders import optimizer_builder
 from object_detection.protos import optimizer_pb2

+# pylint: disable=g-import-not-at-top
+try:
+  from tensorflow.contrib import opt as contrib_opt
+except ImportError:
+  # TF 2.0 doesn't ship with contrib.
+  pass
+# pylint: enable=g-import-not-at-top


 class LearningRateBuilderTest(tf.test.TestCase):

@@ -35,7 +49,8 @@ class LearningRateBuilderTest(tf.test.TestCase):
     text_format.Merge(learning_rate_text_proto, learning_rate_proto)
     learning_rate = optimizer_builder._create_learning_rate(
         learning_rate_proto)
-    self.assertTrue(learning_rate.op.name.endswith('learning_rate'))
+    self.assertTrue(
+        six.ensure_str(learning_rate.op.name).endswith('learning_rate'))
     with self.test_session():
       learning_rate_out = learning_rate.eval()
     self.assertAlmostEqual(learning_rate_out, 0.004)
@@ -53,8 +68,9 @@ class LearningRateBuilderTest(tf.test.TestCase):
     text_format.Merge(learning_rate_text_proto, learning_rate_proto)
     learning_rate = optimizer_builder._create_learning_rate(
         learning_rate_proto)
-    self.assertTrue(learning_rate.op.name.endswith('learning_rate'))
-    self.assertTrue(isinstance(learning_rate, tf.Tensor))
+    self.assertTrue(
+        six.ensure_str(learning_rate.op.name).endswith('learning_rate'))
+    self.assertIsInstance(learning_rate, tf.Tensor)

   def testBuildManualStepLearningRate(self):
     learning_rate_text_proto = """
@@ -75,7 +91,7 @@ class LearningRateBuilderTest(tf.test.TestCase):
     text_format.Merge(learning_rate_text_proto, learning_rate_proto)
     learning_rate = optimizer_builder._create_learning_rate(
         learning_rate_proto)
-    self.assertTrue(isinstance(learning_rate, tf.Tensor))
+    self.assertIsInstance(learning_rate, tf.Tensor)

   def testBuildCosineDecayLearningRate(self):
     learning_rate_text_proto = """
@@ -91,7 +107,7 @@ class LearningRateBuilderTest(tf.test.TestCase):
     text_format.Merge(learning_rate_text_proto, learning_rate_proto)
     learning_rate = optimizer_builder._create_learning_rate(
         learning_rate_proto)
-    self.assertTrue(isinstance(learning_rate, tf.Tensor))
+    self.assertIsInstance(learning_rate, tf.Tensor)

   def testRaiseErrorOnEmptyLearningRate(self):
     learning_rate_text_proto = """
@@ -123,7 +139,7 @@ class OptimizerBuilderTest(tf.test.TestCase):
     optimizer_proto = optimizer_pb2.Optimizer()
     text_format.Merge(optimizer_text_proto, optimizer_proto)
     optimizer, _ = optimizer_builder.build(optimizer_proto)
-    self.assertTrue(isinstance(optimizer, tf.train.RMSPropOptimizer))
+    self.assertIsInstance(optimizer, tf.train.RMSPropOptimizer)

   def testBuildMomentumOptimizer(self):
     optimizer_text_proto = """
@@ -140,11 +156,12 @@ class OptimizerBuilderTest(tf.test.TestCase):
     optimizer_proto = optimizer_pb2.Optimizer()
     text_format.Merge(optimizer_text_proto, optimizer_proto)
     optimizer, _ = optimizer_builder.build(optimizer_proto)
-    self.assertTrue(isinstance(optimizer, tf.train.MomentumOptimizer))
+    self.assertIsInstance(optimizer, tf.train.MomentumOptimizer)

   def testBuildAdamOptimizer(self):
     optimizer_text_proto = """
       adam_optimizer: {
+        epsilon: 1e-6
         learning_rate: {
           constant_learning_rate {
             learning_rate: 0.002
@@ -156,7 +173,7 @@ class OptimizerBuilderTest(tf.test.TestCase):
     optimizer_proto = optimizer_pb2.Optimizer()
     text_format.Merge(optimizer_text_proto, optimizer_proto)
     optimizer, _ = optimizer_builder.build(optimizer_proto)
-    self.assertTrue(isinstance(optimizer, tf.train.AdamOptimizer))
+    self.assertIsInstance(optimizer, tf.train.AdamOptimizer)

   def testBuildMovingAverageOptimizer(self):
     optimizer_text_proto = """
@@ -172,8 +189,7 @@ class OptimizerBuilderTest(tf.test.TestCase):
     optimizer_proto = optimizer_pb2.Optimizer()
     text_format.Merge(optimizer_text_proto, optimizer_proto)
     optimizer, _ = optimizer_builder.build(optimizer_proto)
-    self.assertTrue(
-        isinstance(optimizer, tf.contrib.opt.MovingAverageOptimizer))
+    self.assertIsInstance(optimizer, contrib_opt.MovingAverageOptimizer)

   def testBuildMovingAverageOptimizerWithNonDefaultDecay(self):
     optimizer_text_proto = """
@@ -190,8 +206,7 @@ class OptimizerBuilderTest(tf.test.TestCase):
     optimizer_proto = optimizer_pb2.Optimizer()
     text_format.Merge(optimizer_text_proto, optimizer_proto)
     optimizer, _ = optimizer_builder.build(optimizer_proto)
-    self.assertTrue(
-        isinstance(optimizer, tf.contrib.opt.MovingAverageOptimizer))
+    self.assertIsInstance(optimizer, contrib_opt.MovingAverageOptimizer)
     # TODO(rathodv): Find a way to not depend on the private members.
     self.assertAlmostEqual(optimizer._ema._decay, 0.2)
...
@@ -102,7 +102,7 @@ def _build_non_max_suppressor(nms_config):
         soft_nms_sigma=nms_config.soft_nms_sigma,
         use_partitioned_nms=nms_config.use_partitioned_nms,
         use_combined_nms=nms_config.use_combined_nms,
-        change_coordinate_frame=True)
+        change_coordinate_frame=nms_config.change_coordinate_frame)
   return non_max_suppressor_fn

@@ -110,7 +110,7 @@ def _build_non_max_suppressor(nms_config):
 def _score_converter_fn_with_logit_scale(tf_score_converter_fn, logit_scale):
   """Create a function to scale logits then apply a Tensorflow function."""
   def score_converter_fn(logits):
-    scaled_logits = tf.divide(logits, logit_scale, name='scale_logits')
+    scaled_logits = tf.multiply(logits, 1.0 / logit_scale, name='scale_logits')
     return tf_score_converter_fn(scaled_logits, name='convert_scores')
   score_converter_fn.__name__ = '%s_with_logit_scale' % (
       tf_score_converter_fn.__name__)
...
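
The _score_converter_fn_with_logit_scale change replaces an elementwise divide with a multiply by the reciprocal. Because logit_scale is a plain Python scalar from the config, the two forms are mathematically identical and the reciprocal folds into a single constant; a quick sketch of the equivalence (values illustrative):

    import tensorflow as tf

    logits = tf.constant([1.0, 2.0, 3.0])
    logit_scale = 2.0  # illustrative; the builder reads this from the config

    # Equal results up to float rounding; the multiply form bakes
    # 1.0 / logit_scale into one constant instead of emitting a div op.
    by_divide = tf.divide(logits, logit_scale)
    by_multiply = tf.multiply(logits, 1.0 / logit_scale)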
@@ -150,7 +150,7 @@ def build(preprocessor_step_config):
     return (preprocessor.random_horizontal_flip,
             {
                 'keypoint_flip_permutation': tuple(
-                    config.keypoint_flip_permutation),
+                    config.keypoint_flip_permutation) or None,
             })

   if step_type == 'random_vertical_flip':
@@ -158,7 +158,7 @@ def build(preprocessor_step_config):
     return (preprocessor.random_vertical_flip,
             {
                 'keypoint_flip_permutation': tuple(
-                    config.keypoint_flip_permutation),
+                    config.keypoint_flip_permutation) or None,
             })

   if step_type == 'random_rotation90':
@@ -400,4 +400,13 @@ def build(preprocessor_step_config):
     kwargs['random_coef'] = [op.random_coef for op in config.operations]
     return (preprocessor.ssd_random_crop_pad_fixed_aspect_ratio, kwargs)

+  if step_type == 'random_square_crop_by_scale':
+    config = preprocessor_step_config.random_square_crop_by_scale
+    return preprocessor.random_square_crop_by_scale, {
+        'scale_min': config.scale_min,
+        'scale_max': config.scale_max,
+        'max_border': config.max_border,
+        'num_scales': config.num_scales
+    }
+
   raise ValueError('Unknown preprocessing step.')
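
The `or None` appended to the flip permutations relies on the empty tuple being falsy: when the repeated keypoint_flip_permutation field is unset in the proto, tuple(...) yields (), and the preprocessor now receives None instead of an empty tuple, so downstream code can tell "not configured" apart from a real permutation. In plain Python:

    tuple([]) or None         # -> None   (unset repeated field)
    tuple([0, 2, 1]) or None  # -> (0, 2, 1)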
@@ -723,6 +723,25 @@ class PreprocessorBuilderTest(tf.test.TestCase):
     self.assertEqual(function, preprocessor.convert_class_logits_to_softmax)
     self.assertEqual(args, {'temperature': 2})

+  def test_random_crop_by_scale(self):
+    preprocessor_text_proto = """
+    random_square_crop_by_scale {
+      scale_min: 0.25
+      scale_max: 2.0
+      num_scales: 8
+    }
+    """
+    preprocessor_proto = preprocessor_pb2.PreprocessingStep()
+    text_format.Merge(preprocessor_text_proto, preprocessor_proto)
+    function, args = preprocessor_builder.build(preprocessor_proto)
+    self.assertEqual(function, preprocessor.random_square_crop_by_scale)
+    self.assertEqual(args, {
+        'scale_min': 0.25,
+        'scale_max': 2.0,
+        'num_scales': 8,
+        'max_border': 128
+    })
+

 if __name__ == '__main__':
   tf.test.main()
@@ -33,7 +33,6 @@ when number of examples set to True in indicator is less than batch_size.
 import tensorflow as tf

 from object_detection.core import minibatch_sampler
-from object_detection.utils import ops


 class BalancedPositiveNegativeSampler(minibatch_sampler.MinibatchSampler):

@@ -158,19 +157,17 @@ class BalancedPositiveNegativeSampler(minibatch_sampler.MinibatchSampler):
     # Shuffle indicator and label. Need to store the permutation to restore the
     # order post sampling.
     permutation = tf.random_shuffle(tf.range(input_length))
-    indicator = ops.matmul_gather_on_zeroth_axis(
-        tf.cast(indicator, tf.float32), permutation)
-    labels = ops.matmul_gather_on_zeroth_axis(
-        tf.cast(labels, tf.float32), permutation)
+    indicator = tf.gather(indicator, permutation, axis=0)
+    labels = tf.gather(labels, permutation, axis=0)

     # index (starting from 1) when indicator is True, 0 when False
     indicator_idx = tf.where(
-        tf.cast(indicator, tf.bool), tf.range(1, input_length + 1),
+        indicator, tf.range(1, input_length + 1),
         tf.zeros(input_length, tf.int32))

     # Replace -1 for negative, +1 for positive labels
     signed_label = tf.where(
-        tf.cast(labels, tf.bool), tf.ones(input_length, tf.int32),
+        labels, tf.ones(input_length, tf.int32),
         tf.scalar_mul(-1, tf.ones(input_length, tf.int32)))

     # negative of index for negative label, positive index for positive label,
     # 0 when indicator is False.
@@ -198,11 +195,10 @@ class BalancedPositiveNegativeSampler(minibatch_sampler.MinibatchSampler):
         axis=0), tf.bool)

     # project back the order based on stored permutations
-    reprojections = tf.one_hot(permutation, depth=input_length,
-                               dtype=tf.float32)
-    return tf.cast(tf.tensordot(
-        tf.cast(sampled_idx_indicator, tf.float32),
-        reprojections, axes=[0, 0]), tf.bool)
+    idx_indicator = tf.scatter_nd(
+        tf.expand_dims(permutation, -1), sampled_idx_indicator,
+        shape=(input_length,))
+    return idx_indicator

   def subsample(self, indicator, batch_size, labels, scope=None):
     """Returns subsampled minibatch.
...
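
The sampler rewrite above drops the one-hot matmul trick (gather via matrix product, reprojection via tensordot) in favor of native tf.gather and tf.scatter_nd, which keeps the bool dtype intact and avoids materializing an input_length x input_length matrix. A small sketch of the shuffle/unshuffle round trip, with illustrative values:

    import tensorflow as tf

    input_length = 5
    values = tf.constant([10, 20, 30, 40, 50])
    permutation = tf.random_shuffle(tf.range(input_length))

    # scatter_nd writes shuffled[i] back to position permutation[i],
    # so `restored` equals `values` for any permutation.
    shuffled = tf.gather(values, permutation, axis=0)
    restored = tf.scatter_nd(
        tf.expand_dims(permutation, -1), shuffled, shape=(input_length,))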
@@ -24,24 +24,27 @@ from object_detection.utils import test_case

 class BalancedPositiveNegativeSamplerTest(test_case.TestCase):

-  def test_subsample_all_examples_dynamic(self):
+  def test_subsample_all_examples(self):
+    if self.has_tpu(): return
     numpy_labels = np.random.permutation(300)
-    indicator = tf.constant(np.ones(300) == 1)
+    indicator = np.array(np.ones(300) == 1, np.bool)
     numpy_labels = (numpy_labels - 200) > 0
-    labels = tf.constant(numpy_labels)
+    labels = np.array(numpy_labels, np.bool)

-    sampler = (
-        balanced_positive_negative_sampler.BalancedPositiveNegativeSampler())
-    is_sampled = sampler.subsample(indicator, 64, labels)
-    with self.test_session() as sess:
-      is_sampled = sess.run(is_sampled)
-      self.assertTrue(sum(is_sampled) == 64)
-      self.assertTrue(sum(np.logical_and(numpy_labels, is_sampled)) == 32)
-      self.assertTrue(sum(np.logical_and(
-          np.logical_not(numpy_labels), is_sampled)) == 32)
+    def graph_fn(indicator, labels):
+      sampler = (
+          balanced_positive_negative_sampler.BalancedPositiveNegativeSampler())
+      return sampler.subsample(indicator, 64, labels)
+
+    is_sampled = self.execute_cpu(graph_fn, [indicator, labels])
+    self.assertEqual(sum(is_sampled), 64)
+    self.assertEqual(sum(np.logical_and(numpy_labels, is_sampled)), 32)
+    self.assertEqual(sum(np.logical_and(
+        np.logical_not(numpy_labels), is_sampled)), 32)

   def test_subsample_all_examples_static(self):
+    if not self.has_tpu(): return
     numpy_labels = np.random.permutation(300)
     indicator = np.array(np.ones(300) == 1, np.bool)
     numpy_labels = (numpy_labels - 200) > 0
@@ -54,35 +57,37 @@ class BalancedPositiveNegativeSamplerTest(test_case.TestCase):
             is_static=True))
       return sampler.subsample(indicator, 64, labels)

-    is_sampled = self.execute(graph_fn, [indicator, labels])
-    self.assertTrue(sum(is_sampled) == 64)
-    self.assertTrue(sum(np.logical_and(numpy_labels, is_sampled)) == 32)
-    self.assertTrue(sum(np.logical_and(
-        np.logical_not(numpy_labels), is_sampled)) == 32)
+    is_sampled = self.execute_tpu(graph_fn, [indicator, labels])
+    self.assertEqual(sum(is_sampled), 64)
+    self.assertEqual(sum(np.logical_and(numpy_labels, is_sampled)), 32)
+    self.assertEqual(sum(np.logical_and(
+        np.logical_not(numpy_labels), is_sampled)), 32)

-  def test_subsample_selection_dynamic(self):
+  def test_subsample_selection(self):
+    if self.has_tpu(): return
     # Test random sampling when only some examples can be sampled:
-    # 100 samples, 20 positives, 10 positives cannot be sampled
+    # 100 samples, 20 positives, 10 positives cannot be sampled.
     numpy_labels = np.arange(100)
     numpy_indicator = numpy_labels < 90
-    indicator = tf.constant(numpy_indicator)
+    indicator = np.array(numpy_indicator, np.bool)
     numpy_labels = (numpy_labels - 80) >= 0
-    labels = tf.constant(numpy_labels)
+    labels = np.array(numpy_labels, np.bool)

-    sampler = (
-        balanced_positive_negative_sampler.BalancedPositiveNegativeSampler())
-    is_sampled = sampler.subsample(indicator, 64, labels)
-    with self.test_session() as sess:
-      is_sampled = sess.run(is_sampled)
-      self.assertTrue(sum(is_sampled) == 64)
-      self.assertTrue(sum(np.logical_and(numpy_labels, is_sampled)) == 10)
-      self.assertTrue(sum(np.logical_and(
-          np.logical_not(numpy_labels), is_sampled)) == 54)
-      self.assertAllEqual(is_sampled, np.logical_and(is_sampled,
-                                                     numpy_indicator))
+    def graph_fn(indicator, labels):
+      sampler = (
+          balanced_positive_negative_sampler.BalancedPositiveNegativeSampler())
+      return sampler.subsample(indicator, 64, labels)
+
+    is_sampled = self.execute_cpu(graph_fn, [indicator, labels])
+    self.assertEqual(sum(is_sampled), 64)
+    self.assertEqual(sum(np.logical_and(numpy_labels, is_sampled)), 10)
+    self.assertEqual(sum(np.logical_and(
+        np.logical_not(numpy_labels), is_sampled)), 54)
+    self.assertAllEqual(is_sampled, np.logical_and(is_sampled, numpy_indicator))

   def test_subsample_selection_static(self):
+    if not self.has_tpu(): return
     # Test random sampling when only some examples can be sampled:
     # 100 samples, 20 positives, 10 positives cannot be sampled.
     numpy_labels = np.arange(100)
@@ -98,37 +103,41 @@ class BalancedPositiveNegativeSamplerTest(test_case.TestCase):
             is_static=True))
       return sampler.subsample(indicator, 64, labels)

-    is_sampled = self.execute(graph_fn, [indicator, labels])
-    self.assertTrue(sum(is_sampled) == 64)
-    self.assertTrue(sum(np.logical_and(numpy_labels, is_sampled)) == 10)
-    self.assertTrue(sum(np.logical_and(
-        np.logical_not(numpy_labels), is_sampled)) == 54)
+    is_sampled = self.execute_tpu(graph_fn, [indicator, labels])
+    self.assertEqual(sum(is_sampled), 64)
+    self.assertEqual(sum(np.logical_and(numpy_labels, is_sampled)), 10)
+    self.assertEqual(sum(np.logical_and(
+        np.logical_not(numpy_labels), is_sampled)), 54)
     self.assertAllEqual(is_sampled, np.logical_and(is_sampled, numpy_indicator))

-  def test_subsample_selection_larger_batch_size_dynamic(self):
+  def test_subsample_selection_larger_batch_size(self):
+    if self.has_tpu(): return
     # Test random sampling when total number of examples that can be sampled are
     # less than batch size:
     # 100 samples, 50 positives, 40 positives cannot be sampled, batch size 64.
-    # It should still return 64 samples, with 4 of them that couldn't have been
-    # sampled.
     numpy_labels = np.arange(100)
     numpy_indicator = numpy_labels < 60
-    indicator = tf.constant(numpy_indicator)
+    indicator = np.array(numpy_indicator, np.bool)
     numpy_labels = (numpy_labels - 50) >= 0
-    labels = tf.constant(numpy_labels)
+    labels = np.array(numpy_labels, np.bool)

-    sampler = (
-        balanced_positive_negative_sampler.BalancedPositiveNegativeSampler())
-    is_sampled = sampler.subsample(indicator, 64, labels)
-    with self.test_session() as sess:
-      is_sampled = sess.run(is_sampled)
-      self.assertTrue(sum(is_sampled) == 60)
-      self.assertTrue(sum(np.logical_and(numpy_labels, is_sampled)) == 10)
-      self.assertTrue(
-          sum(np.logical_and(np.logical_not(numpy_labels), is_sampled)) == 50)
-      self.assertAllEqual(is_sampled, np.logical_and(is_sampled,
-                                                     numpy_indicator))
+    def graph_fn(indicator, labels):
+      sampler = (
+          balanced_positive_negative_sampler.BalancedPositiveNegativeSampler())
+      return sampler.subsample(indicator, 64, labels)
+
+    is_sampled = self.execute_cpu(graph_fn, [indicator, labels])
+    self.assertEqual(sum(is_sampled), 60)
+    self.assertGreaterEqual(sum(np.logical_and(numpy_labels, is_sampled)), 10)
+    self.assertGreaterEqual(
+        sum(np.logical_and(np.logical_not(numpy_labels), is_sampled)), 50)
+    self.assertEqual(sum(np.logical_and(is_sampled, numpy_indicator)), 60)

   def test_subsample_selection_larger_batch_size_static(self):
+    if not self.has_tpu(): return
     # Test random sampling when total number of examples that can be sampled are
     # less than batch size:
     # 100 samples, 50 positives, 40 positives cannot be sampled, batch size 64.
@@ -147,33 +156,32 @@ class BalancedPositiveNegativeSamplerTest(test_case.TestCase):
             is_static=True))
       return sampler.subsample(indicator, 64, labels)

-    is_sampled = self.execute(graph_fn, [indicator, labels])
-    self.assertTrue(sum(is_sampled) == 64)
-    self.assertTrue(sum(np.logical_and(numpy_labels, is_sampled)) >= 10)
-    self.assertTrue(
-        sum(np.logical_and(np.logical_not(numpy_labels), is_sampled)) >= 50)
-    self.assertTrue(sum(np.logical_and(is_sampled, numpy_indicator)) == 60)
+    is_sampled = self.execute_tpu(graph_fn, [indicator, labels])
+    self.assertEqual(sum(is_sampled), 64)
+    self.assertGreaterEqual(sum(np.logical_and(numpy_labels, is_sampled)), 10)
+    self.assertGreaterEqual(
+        sum(np.logical_and(np.logical_not(numpy_labels), is_sampled)), 50)
+    self.assertEqual(sum(np.logical_and(is_sampled, numpy_indicator)), 60)

   def test_subsample_selection_no_batch_size(self):
+    if self.has_tpu(): return
     # Test random sampling when only some examples can be sampled:
     # 1000 samples, 6 positives (5 can be sampled).
     numpy_labels = np.arange(1000)
     numpy_indicator = numpy_labels < 999
-    indicator = tf.constant(numpy_indicator)
     numpy_labels = (numpy_labels - 994) >= 0
-    labels = tf.constant(numpy_labels)

-    sampler = (balanced_positive_negative_sampler.
-               BalancedPositiveNegativeSampler(0.01))
-    is_sampled = sampler.subsample(indicator, None, labels)
-    with self.test_session() as sess:
-      is_sampled = sess.run(is_sampled)
-      self.assertTrue(sum(is_sampled) == 500)
-      self.assertTrue(sum(np.logical_and(numpy_labels, is_sampled)) == 5)
-      self.assertTrue(sum(np.logical_and(
-          np.logical_not(numpy_labels), is_sampled)) == 495)
-      self.assertAllEqual(is_sampled, np.logical_and(is_sampled,
-                                                     numpy_indicator))
+    def graph_fn(indicator, labels):
+      sampler = (balanced_positive_negative_sampler.
+                 BalancedPositiveNegativeSampler(0.01))
+      is_sampled = sampler.subsample(indicator, None, labels)
+      return is_sampled
+
+    is_sampled_out = self.execute_cpu(graph_fn,
+                                      [numpy_indicator, numpy_labels])
+    self.assertEqual(sum(is_sampled_out), 500)
+    self.assertEqual(sum(np.logical_and(numpy_labels, is_sampled_out)), 5)
+    self.assertEqual(sum(np.logical_and(
+        np.logical_not(numpy_labels), is_sampled_out)), 495)
+    self.assertAllEqual(is_sampled_out, np.logical_and(is_sampled_out,
+                                                       numpy_indicator))

   def test_subsample_selection_no_batch_size_static(self):
...
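
All of the test rewrites above follow the same recipe from object_detection.utils.test_case: inputs become numpy arrays, graph construction moves into a local graph_fn, and execute_cpu / execute_tpu run it on the right device, with has_tpu() guards keeping each variant to a single platform. A minimal, illustrative test showing the contract (not part of this commit):

    import numpy as np
    import tensorflow as tf

    from object_detection.utils import test_case


    class GraphFnPatternTest(test_case.TestCase):

      def test_identity(self):
        if self.has_tpu(): return  # a *_static twin would use execute_tpu
        x = np.array([1.0, 2.0, 3.0], np.float32)

        def graph_fn(x):
          # x arrives as a tensor; returned tensors come back as numpy.
          return tf.identity(x)

        self.assertAllClose(self.execute_cpu(graph_fn, [x]), x)


    if __name__ == '__main__':
      tf.test.main()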
@@ -24,6 +24,10 @@ from six.moves import range
 import tensorflow as tf

 from object_detection.core import prefetcher
+from object_detection.utils import tf_version
+
+if not tf_version.is_tf1():
+  raise ValueError('`batcher.py` is only supported in Tensorflow 1.X')

 rt_shape_str = '_runtime_shapes'
...
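
Since batcher.py now raises at import time under TF2, code that must load in both worlds has to guard the import. A sketch of such a guard, assuming the same tf_version utility the module itself uses:

    from object_detection.utils import tf_version

    if tf_version.is_tf1():
      from object_detection.core import batcher  # TF1-only module
    else:
      batcher = None  # TF2 code paths must not reach the batcher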
@@ -22,10 +22,11 @@ from __future__ import print_function
 import numpy as np
 from six.moves import range
 import tensorflow as tf
+from tensorflow.contrib import slim as contrib_slim

 from object_detection.core import batcher

-slim = tf.contrib.slim
+slim = contrib_slim


 class BatcherTest(tf.test.TestCase):
...
@@ -14,11 +14,11 @@
 # ==============================================================================

 """Tests for object_detection.core.box_coder."""

 import tensorflow as tf

 from object_detection.core import box_coder
 from object_detection.core import box_list
+from object_detection.utils import test_case


 class MockBoxCoder(box_coder.BoxCoder):
@@ -34,26 +34,27 @@ class MockBoxCoder(box_coder.BoxCoder):
     return box_list.BoxList(rel_codes / 2.0)


-class BoxCoderTest(tf.test.TestCase):
+class BoxCoderTest(test_case.TestCase):

   def test_batch_decode(self):
-    mock_anchor_corners = tf.constant(
-        [[0, 0.1, 0.2, 0.3], [0.2, 0.4, 0.4, 0.6]], tf.float32)
-    mock_anchors = box_list.BoxList(mock_anchor_corners)
-    mock_box_coder = MockBoxCoder()
-
-    expected_boxes = [[[0.0, 0.1, 0.5, 0.6], [0.5, 0.6, 0.7, 0.8]],
-                      [[0.1, 0.2, 0.3, 0.4], [0.7, 0.8, 0.9, 1.0]]]
-
-    encoded_boxes_list = [mock_box_coder.encode(
-        box_list.BoxList(tf.constant(boxes)), mock_anchors)
-                          for boxes in expected_boxes]
-    encoded_boxes = tf.stack(encoded_boxes_list)
-    decoded_boxes = box_coder.batch_decode(
-        encoded_boxes, mock_box_coder, mock_anchors)
-
-    with self.test_session() as sess:
-      decoded_boxes_result = sess.run(decoded_boxes)
-      self.assertAllClose(expected_boxes, decoded_boxes_result)
+    expected_boxes = [[[0.0, 0.1, 0.5, 0.6], [0.5, 0.6, 0.7, 0.8]],
+                      [[0.1, 0.2, 0.3, 0.4], [0.7, 0.8, 0.9, 1.0]]]
+
+    def graph_fn():
+      mock_anchor_corners = tf.constant(
+          [[0, 0.1, 0.2, 0.3], [0.2, 0.4, 0.4, 0.6]], tf.float32)
+      mock_anchors = box_list.BoxList(mock_anchor_corners)
+      mock_box_coder = MockBoxCoder()
+      encoded_boxes_list = [mock_box_coder.encode(
+          box_list.BoxList(tf.constant(boxes)), mock_anchors)
+                            for boxes in expected_boxes]
+      encoded_boxes = tf.stack(encoded_boxes_list)
+      decoded_boxes = box_coder.batch_decode(
+          encoded_boxes, mock_box_coder, mock_anchors)
+      return decoded_boxes
+
+    decoded_boxes_result = self.execute(graph_fn, [])
+    self.assertAllClose(expected_boxes, decoded_boxes_result)
...