Commit 47bc1813 authored by syiming's avatar syiming
Browse files

Merge remote-tracking branch 'upstream/master' into add_multilevel_crop_and_resize

parents d8611151 b035a227
......@@ -13,22 +13,21 @@
# limitations under the License.
# ==============================================================================
"""Tests for graph_rewriter_builder."""
import unittest
import mock
import tensorflow.compat.v1 as tf
import tf_slim as slim
from object_detection.builders import graph_rewriter_builder
from object_detection.protos import graph_rewriter_pb2
from object_detection.utils import tf_version
# pylint: disable=g-import-not-at-top
try:
from tensorflow.contrib import quantize as contrib_quantize
except ImportError:
# TF 2.0 doesn't ship with contrib.
pass
# pylint: enable=g-import-not-at-top
if tf_version.is_tf1():
from tensorflow.contrib import quantize as contrib_quantize # pylint: disable=g-import-not-at-top
@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
class QuantizationBuilderTest(tf.test.TestCase):
def testQuantizationBuilderSetsUpCorrectTrainArguments(self):
......
......@@ -64,6 +64,7 @@ class KerasLayerHyperparams(object):
self._batch_norm_params = _build_keras_batch_norm_params(
hyperparams_config.batch_norm)
self._force_use_bias = hyperparams_config.force_use_bias
self._activation_fn = _build_activation_fn(hyperparams_config.activation)
# TODO(kaftan): Unclear if these kwargs apply to separable & depthwise conv
# (Those might use depthwise_* instead of kernel_*)
......@@ -80,6 +81,13 @@ class KerasLayerHyperparams(object):
def use_batch_norm(self):
return self._batch_norm_params is not None
def force_use_bias(self):
return self._force_use_bias
def use_bias(self):
return (self._force_use_bias or not
(self.use_batch_norm() and self.batch_norm_params()['center']))
def batch_norm_params(self, **overrides):
"""Returns a dict containing batchnorm layer construction hyperparameters.
......@@ -168,10 +176,7 @@ class KerasLayerHyperparams(object):
new_params['activation'] = None
if include_activation:
new_params['activation'] = self._activation_fn
if self.use_batch_norm() and self.batch_norm_params()['center']:
new_params['use_bias'] = False
else:
new_params['use_bias'] = True
new_params['use_bias'] = self.use_bias()
new_params.update(**overrides)
return new_params
......@@ -210,6 +215,10 @@ def build(hyperparams_config, is_training):
raise ValueError('hyperparams_config not of type '
'hyperparams_pb.Hyperparams.')
if hyperparams_config.force_use_bias:
raise ValueError('Hyperparams force_use_bias only supported by '
'KerasLayerHyperparams.')
normalizer_fn = None
batch_norm_params = None
if hyperparams_config.HasField('batch_norm'):
......
......@@ -16,6 +16,7 @@
"""Tests object_detection.core.hyperparams_builder."""
import unittest
import numpy as np
import tensorflow.compat.v1 as tf
import tf_slim as slim
......@@ -24,12 +25,14 @@ from google.protobuf import text_format
from object_detection.builders import hyperparams_builder
from object_detection.core import freezable_batch_norm
from object_detection.protos import hyperparams_pb2
from object_detection.utils import tf_version
def _get_scope_key(op):
return getattr(op, '_key_op', str(op))
@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only tests.')
class HyperparamsBuilderTest(tf.test.TestCase):
def test_default_arg_scope_has_conv2d_op(self):
......@@ -149,29 +152,6 @@ class HyperparamsBuilderTest(tf.test.TestCase):
result = sess.run(regularizer(tf.constant(weights)))
self.assertAllClose(np.abs(weights).sum() * 0.5, result)
def test_return_l1_regularized_weights_keras(self):
conv_hyperparams_text_proto = """
regularizer {
l1_regularizer {
weight: 0.5
}
}
initializer {
truncated_normal_initializer {
}
}
"""
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
keras_config = hyperparams_builder.KerasLayerHyperparams(
conv_hyperparams_proto)
regularizer = keras_config.params()['kernel_regularizer']
weights = np.array([1., -1, 4., 2.])
with self.test_session() as sess:
result = sess.run(regularizer(tf.constant(weights)))
self.assertAllClose(np.abs(weights).sum() * 0.5, result)
def test_return_l2_regularizer_weights(self):
conv_hyperparams_text_proto = """
regularizer {
......@@ -197,30 +177,39 @@ class HyperparamsBuilderTest(tf.test.TestCase):
result = sess.run(regularizer(tf.constant(weights)))
self.assertAllClose(np.power(weights, 2).sum() / 2.0 * 0.42, result)
def test_return_l2_regularizer_weights_keras(self):
def test_return_non_default_batch_norm_params_with_train_during_train(self):
conv_hyperparams_text_proto = """
regularizer {
l2_regularizer {
weight: 0.42
}
}
initializer {
truncated_normal_initializer {
}
}
batch_norm {
decay: 0.7
center: false
scale: true
epsilon: 0.03
train: true
}
"""
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
keras_config = hyperparams_builder.KerasLayerHyperparams(
conv_hyperparams_proto)
regularizer = keras_config.params()['kernel_regularizer']
weights = np.array([1., -1, 4., 2.])
with self.test_session() as sess:
result = sess.run(regularizer(tf.constant(weights)))
self.assertAllClose(np.power(weights, 2).sum() / 2.0 * 0.42, result)
scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
self.assertEqual(conv_scope_arguments['normalizer_fn'], slim.batch_norm)
batch_norm_params = scope[_get_scope_key(slim.batch_norm)]
self.assertAlmostEqual(batch_norm_params['decay'], 0.7)
self.assertAlmostEqual(batch_norm_params['epsilon'], 0.03)
self.assertFalse(batch_norm_params['center'])
self.assertTrue(batch_norm_params['scale'])
self.assertTrue(batch_norm_params['is_training'])
def test_return_non_default_batch_norm_params_with_train_during_train(self):
def test_return_batch_norm_params_with_notrain_during_eval(self):
conv_hyperparams_text_proto = """
regularizer {
l2_regularizer {
......@@ -241,7 +230,7 @@ class HyperparamsBuilderTest(tf.test.TestCase):
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
is_training=False)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
self.assertEqual(conv_scope_arguments['normalizer_fn'], slim.batch_norm)
......@@ -250,10 +239,9 @@ class HyperparamsBuilderTest(tf.test.TestCase):
self.assertAlmostEqual(batch_norm_params['epsilon'], 0.03)
self.assertFalse(batch_norm_params['center'])
self.assertTrue(batch_norm_params['scale'])
self.assertTrue(batch_norm_params['is_training'])
self.assertFalse(batch_norm_params['is_training'])
def test_return_non_default_batch_norm_params_keras(
self):
def test_return_batch_norm_params_with_notrain_when_train_is_false(self):
conv_hyperparams_text_proto = """
regularizer {
l2_regularizer {
......@@ -268,26 +256,43 @@ class HyperparamsBuilderTest(tf.test.TestCase):
center: false
scale: true
epsilon: 0.03
train: false
}
"""
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
keras_config = hyperparams_builder.KerasLayerHyperparams(
conv_hyperparams_proto)
self.assertTrue(keras_config.use_batch_norm())
batch_norm_params = keras_config.batch_norm_params()
self.assertAlmostEqual(batch_norm_params['momentum'], 0.7)
scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
self.assertEqual(conv_scope_arguments['normalizer_fn'], slim.batch_norm)
batch_norm_params = scope[_get_scope_key(slim.batch_norm)]
self.assertAlmostEqual(batch_norm_params['decay'], 0.7)
self.assertAlmostEqual(batch_norm_params['epsilon'], 0.03)
self.assertFalse(batch_norm_params['center'])
self.assertTrue(batch_norm_params['scale'])
self.assertFalse(batch_norm_params['is_training'])
batch_norm_layer = keras_config.build_batch_norm()
self.assertIsInstance(batch_norm_layer,
freezable_batch_norm.FreezableBatchNorm)
def test_do_not_use_batch_norm_if_default(self):
conv_hyperparams_text_proto = """
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
"""
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
self.assertEqual(conv_scope_arguments['normalizer_fn'], None)
def test_return_non_default_batch_norm_params_keras_override(
self):
def test_use_none_activation(self):
conv_hyperparams_text_proto = """
regularizer {
l2_regularizer {
......@@ -297,26 +302,57 @@ class HyperparamsBuilderTest(tf.test.TestCase):
truncated_normal_initializer {
}
}
batch_norm {
decay: 0.7
center: false
scale: true
epsilon: 0.03
activation: NONE
"""
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
self.assertEqual(conv_scope_arguments['activation_fn'], None)
def test_use_relu_activation(self):
conv_hyperparams_text_proto = """
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
activation: RELU
"""
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
keras_config = hyperparams_builder.KerasLayerHyperparams(
conv_hyperparams_proto)
scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
self.assertEqual(conv_scope_arguments['activation_fn'], tf.nn.relu)
self.assertTrue(keras_config.use_batch_norm())
batch_norm_params = keras_config.batch_norm_params(momentum=0.4)
self.assertAlmostEqual(batch_norm_params['momentum'], 0.4)
self.assertAlmostEqual(batch_norm_params['epsilon'], 0.03)
self.assertFalse(batch_norm_params['center'])
self.assertTrue(batch_norm_params['scale'])
def test_use_relu_6_activation(self):
conv_hyperparams_text_proto = """
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
activation: RELU_6
"""
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
self.assertEqual(conv_scope_arguments['activation_fn'], tf.nn.relu6)
def test_return_batch_norm_params_with_notrain_during_eval(self):
def test_use_swish_activation(self):
conv_hyperparams_text_proto = """
regularizer {
l2_regularizer {
......@@ -326,44 +362,89 @@ class HyperparamsBuilderTest(tf.test.TestCase):
truncated_normal_initializer {
}
}
batch_norm {
decay: 0.7
center: false
scale: true
epsilon: 0.03
train: true
activation: SWISH
"""
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
self.assertEqual(conv_scope_arguments['activation_fn'], tf.nn.swish)
def _assert_variance_in_range(self, initializer, shape, variance,
tol=1e-2):
with tf.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
var = tf.get_variable(
name='test',
shape=shape,
dtype=tf.float32,
initializer=initializer)
sess.run(tf.global_variables_initializer())
values = sess.run(var)
self.assertAllClose(np.var(values), variance, tol, tol)
def test_variance_in_range_with_variance_scaling_initializer_fan_in(self):
conv_hyperparams_text_proto = """
regularizer {
l2_regularizer {
}
}
initializer {
variance_scaling_initializer {
factor: 2.0
mode: FAN_IN
uniform: false
}
}
"""
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=False)
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
self.assertEqual(conv_scope_arguments['normalizer_fn'], slim.batch_norm)
batch_norm_params = scope[_get_scope_key(slim.batch_norm)]
self.assertAlmostEqual(batch_norm_params['decay'], 0.7)
self.assertAlmostEqual(batch_norm_params['epsilon'], 0.03)
self.assertFalse(batch_norm_params['center'])
self.assertTrue(batch_norm_params['scale'])
self.assertFalse(batch_norm_params['is_training'])
initializer = conv_scope_arguments['weights_initializer']
self._assert_variance_in_range(initializer, shape=[100, 40],
variance=2. / 100.)
def test_return_batch_norm_params_with_notrain_when_train_is_false(self):
def test_variance_in_range_with_variance_scaling_initializer_fan_out(self):
conv_hyperparams_text_proto = """
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
variance_scaling_initializer {
factor: 2.0
mode: FAN_OUT
uniform: false
}
}
batch_norm {
decay: 0.7
center: false
scale: true
epsilon: 0.03
train: false
"""
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
initializer = conv_scope_arguments['weights_initializer']
self._assert_variance_in_range(initializer, shape=[100, 40],
variance=2. / 40.)
def test_variance_in_range_with_variance_scaling_initializer_fan_avg(self):
conv_hyperparams_text_proto = """
regularizer {
l2_regularizer {
}
}
initializer {
variance_scaling_initializer {
factor: 2.0
mode: FAN_AVG
uniform: false
}
}
"""
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
......@@ -372,15 +453,35 @@ class HyperparamsBuilderTest(tf.test.TestCase):
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
self.assertEqual(conv_scope_arguments['normalizer_fn'], slim.batch_norm)
batch_norm_params = scope[_get_scope_key(slim.batch_norm)]
self.assertAlmostEqual(batch_norm_params['decay'], 0.7)
self.assertAlmostEqual(batch_norm_params['epsilon'], 0.03)
self.assertFalse(batch_norm_params['center'])
self.assertTrue(batch_norm_params['scale'])
self.assertFalse(batch_norm_params['is_training'])
initializer = conv_scope_arguments['weights_initializer']
self._assert_variance_in_range(initializer, shape=[100, 40],
variance=4. / (100. + 40.))
def test_do_not_use_batch_norm_if_default(self):
def test_variance_in_range_with_variance_scaling_initializer_uniform(self):
conv_hyperparams_text_proto = """
regularizer {
l2_regularizer {
}
}
initializer {
variance_scaling_initializer {
factor: 2.0
mode: FAN_IN
uniform: true
}
}
"""
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
initializer = conv_scope_arguments['weights_initializer']
self._assert_variance_in_range(initializer, shape=[100, 40],
variance=2. / 100.)
def test_variance_in_range_with_truncated_normal_initializer(self):
conv_hyperparams_text_proto = """
regularizer {
l2_regularizer {
......@@ -388,6 +489,8 @@ class HyperparamsBuilderTest(tf.test.TestCase):
}
initializer {
truncated_normal_initializer {
mean: 0.0
stddev: 0.8
}
}
"""
......@@ -397,7 +500,149 @@ class HyperparamsBuilderTest(tf.test.TestCase):
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
self.assertEqual(conv_scope_arguments['normalizer_fn'], None)
initializer = conv_scope_arguments['weights_initializer']
self._assert_variance_in_range(initializer, shape=[100, 40],
variance=0.49, tol=1e-1)
def test_variance_in_range_with_random_normal_initializer(self):
conv_hyperparams_text_proto = """
regularizer {
l2_regularizer {
}
}
initializer {
random_normal_initializer {
mean: 0.0
stddev: 0.8
}
}
"""
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
initializer = conv_scope_arguments['weights_initializer']
self._assert_variance_in_range(initializer, shape=[100, 40],
variance=0.64, tol=1e-1)
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only tests.')
class KerasHyperparamsBuilderTest(tf.test.TestCase):
def _assert_variance_in_range(self, initializer, shape, variance,
tol=1e-2):
var = tf.Variable(initializer(shape=shape, dtype=tf.float32))
self.assertAllClose(np.var(var.numpy()), variance, tol, tol)
def test_return_l1_regularized_weights_keras(self):
conv_hyperparams_text_proto = """
regularizer {
l1_regularizer {
weight: 0.5
}
}
initializer {
truncated_normal_initializer {
}
}
"""
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
keras_config = hyperparams_builder.KerasLayerHyperparams(
conv_hyperparams_proto)
regularizer = keras_config.params()['kernel_regularizer']
weights = np.array([1., -1, 4., 2.])
result = regularizer(tf.constant(weights)).numpy()
self.assertAllClose(np.abs(weights).sum() * 0.5, result)
def test_return_l2_regularizer_weights_keras(self):
conv_hyperparams_text_proto = """
regularizer {
l2_regularizer {
weight: 0.42
}
}
initializer {
truncated_normal_initializer {
}
}
"""
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
keras_config = hyperparams_builder.KerasLayerHyperparams(
conv_hyperparams_proto)
regularizer = keras_config.params()['kernel_regularizer']
weights = np.array([1., -1, 4., 2.])
result = regularizer(tf.constant(weights)).numpy()
self.assertAllClose(np.power(weights, 2).sum() / 2.0 * 0.42, result)
def test_return_non_default_batch_norm_params_keras(
self):
conv_hyperparams_text_proto = """
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
batch_norm {
decay: 0.7
center: false
scale: true
epsilon: 0.03
}
"""
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
keras_config = hyperparams_builder.KerasLayerHyperparams(
conv_hyperparams_proto)
self.assertTrue(keras_config.use_batch_norm())
batch_norm_params = keras_config.batch_norm_params()
self.assertAlmostEqual(batch_norm_params['momentum'], 0.7)
self.assertAlmostEqual(batch_norm_params['epsilon'], 0.03)
self.assertFalse(batch_norm_params['center'])
self.assertTrue(batch_norm_params['scale'])
batch_norm_layer = keras_config.build_batch_norm()
self.assertIsInstance(batch_norm_layer,
freezable_batch_norm.FreezableBatchNorm)
def test_return_non_default_batch_norm_params_keras_override(
self):
conv_hyperparams_text_proto = """
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
batch_norm {
decay: 0.7
center: false
scale: true
epsilon: 0.03
}
"""
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
keras_config = hyperparams_builder.KerasLayerHyperparams(
conv_hyperparams_proto)
self.assertTrue(keras_config.use_batch_norm())
batch_norm_params = keras_config.batch_norm_params(momentum=0.4)
self.assertAlmostEqual(batch_norm_params['momentum'], 0.4)
self.assertAlmostEqual(batch_norm_params['epsilon'], 0.03)
self.assertFalse(batch_norm_params['center'])
self.assertTrue(batch_norm_params['scale'])
def test_do_not_use_batch_norm_if_default_keras(self):
conv_hyperparams_text_proto = """
......@@ -422,7 +667,7 @@ class HyperparamsBuilderTest(tf.test.TestCase):
self.assertIsInstance(identity_layer,
tf.keras.layers.Lambda)
def test_use_none_activation(self):
def test_do_not_use_bias_if_batch_norm_center_keras(self):
conv_hyperparams_text_proto = """
regularizer {
l2_regularizer {
......@@ -432,17 +677,27 @@ class HyperparamsBuilderTest(tf.test.TestCase):
truncated_normal_initializer {
}
}
activation: NONE
batch_norm {
decay: 0.7
center: true
scale: true
epsilon: 0.03
train: true
}
"""
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
self.assertEqual(conv_scope_arguments['activation_fn'], None)
keras_config = hyperparams_builder.KerasLayerHyperparams(
conv_hyperparams_proto)
def test_use_none_activation_keras(self):
self.assertTrue(keras_config.use_batch_norm())
batch_norm_params = keras_config.batch_norm_params()
self.assertTrue(batch_norm_params['center'])
self.assertTrue(batch_norm_params['scale'])
hyperparams = keras_config.params()
self.assertFalse(hyperparams['use_bias'])
def test_force_use_bias_if_batch_norm_center_keras(self):
conv_hyperparams_text_proto = """
regularizer {
l2_regularizer {
......@@ -452,20 +707,28 @@ class HyperparamsBuilderTest(tf.test.TestCase):
truncated_normal_initializer {
}
}
activation: NONE
batch_norm {
decay: 0.7
center: true
scale: true
epsilon: 0.03
train: true
}
force_use_bias: true
"""
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
keras_config = hyperparams_builder.KerasLayerHyperparams(
conv_hyperparams_proto)
self.assertEqual(keras_config.params()['activation'], None)
self.assertEqual(
keras_config.params(include_activation=True)['activation'], None)
activation_layer = keras_config.build_activation_layer()
self.assertIsInstance(activation_layer, tf.keras.layers.Lambda)
self.assertEqual(activation_layer.function, tf.identity)
def test_use_relu_activation(self):
self.assertTrue(keras_config.use_batch_norm())
batch_norm_params = keras_config.batch_norm_params()
self.assertTrue(batch_norm_params['center'])
self.assertTrue(batch_norm_params['scale'])
hyperparams = keras_config.params()
self.assertTrue(hyperparams['use_bias'])
def test_use_none_activation_keras(self):
conv_hyperparams_text_proto = """
regularizer {
l2_regularizer {
......@@ -475,15 +738,18 @@ class HyperparamsBuilderTest(tf.test.TestCase):
truncated_normal_initializer {
}
}
activation: RELU
activation: NONE
"""
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
self.assertEqual(conv_scope_arguments['activation_fn'], tf.nn.relu)
keras_config = hyperparams_builder.KerasLayerHyperparams(
conv_hyperparams_proto)
self.assertIsNone(keras_config.params()['activation'])
self.assertIsNone(
keras_config.params(include_activation=True)['activation'])
activation_layer = keras_config.build_activation_layer()
self.assertIsInstance(activation_layer, tf.keras.layers.Lambda)
self.assertEqual(activation_layer.function, tf.identity)
def test_use_relu_activation_keras(self):
conv_hyperparams_text_proto = """
......@@ -501,33 +767,13 @@ class HyperparamsBuilderTest(tf.test.TestCase):
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
keras_config = hyperparams_builder.KerasLayerHyperparams(
conv_hyperparams_proto)
self.assertEqual(keras_config.params()['activation'], None)
self.assertIsNone(keras_config.params()['activation'])
self.assertEqual(
keras_config.params(include_activation=True)['activation'], tf.nn.relu)
activation_layer = keras_config.build_activation_layer()
self.assertIsInstance(activation_layer, tf.keras.layers.Lambda)
self.assertEqual(activation_layer.function, tf.nn.relu)
def test_use_relu_6_activation(self):
conv_hyperparams_text_proto = """
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
activation: RELU_6
"""
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
self.assertEqual(conv_scope_arguments['activation_fn'], tf.nn.relu6)
def test_use_relu_6_activation_keras(self):
conv_hyperparams_text_proto = """
regularizer {
......@@ -544,33 +790,13 @@ class HyperparamsBuilderTest(tf.test.TestCase):
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
keras_config = hyperparams_builder.KerasLayerHyperparams(
conv_hyperparams_proto)
self.assertEqual(keras_config.params()['activation'], None)
self.assertIsNone(keras_config.params()['activation'])
self.assertEqual(
keras_config.params(include_activation=True)['activation'], tf.nn.relu6)
activation_layer = keras_config.build_activation_layer()
self.assertIsInstance(activation_layer, tf.keras.layers.Lambda)
self.assertEqual(activation_layer.function, tf.nn.relu6)
def test_use_swish_activation(self):
conv_hyperparams_text_proto = """
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
activation: SWISH
"""
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
self.assertEqual(conv_scope_arguments['activation_fn'], tf.nn.swish)
def test_use_swish_activation_keras(self):
conv_hyperparams_text_proto = """
regularizer {
......@@ -587,7 +813,7 @@ class HyperparamsBuilderTest(tf.test.TestCase):
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
keras_config = hyperparams_builder.KerasLayerHyperparams(
conv_hyperparams_proto)
self.assertEqual(keras_config.params()['activation'], None)
self.assertIsNone(keras_config.params()['activation'])
self.assertEqual(
keras_config.params(include_activation=True)['activation'], tf.nn.swish)
activation_layer = keras_config.build_activation_layer()
......@@ -613,43 +839,6 @@ class HyperparamsBuilderTest(tf.test.TestCase):
new_params = keras_config.params(activation=tf.nn.relu)
self.assertEqual(new_params['activation'], tf.nn.relu)
def _assert_variance_in_range(self, initializer, shape, variance,
tol=1e-2):
with tf.Graph().as_default() as g:
with self.test_session(graph=g) as sess:
var = tf.get_variable(
name='test',
shape=shape,
dtype=tf.float32,
initializer=initializer)
sess.run(tf.global_variables_initializer())
values = sess.run(var)
self.assertAllClose(np.var(values), variance, tol, tol)
def test_variance_in_range_with_variance_scaling_initializer_fan_in(self):
conv_hyperparams_text_proto = """
regularizer {
l2_regularizer {
}
}
initializer {
variance_scaling_initializer {
factor: 2.0
mode: FAN_IN
uniform: false
}
}
"""
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
initializer = conv_scope_arguments['weights_initializer']
self._assert_variance_in_range(initializer, shape=[100, 40],
variance=2. / 100.)
def test_variance_in_range_with_variance_scaling_initializer_fan_in_keras(
self):
conv_hyperparams_text_proto = """
......@@ -673,30 +862,6 @@ class HyperparamsBuilderTest(tf.test.TestCase):
self._assert_variance_in_range(initializer, shape=[100, 40],
variance=2. / 100.)
def test_variance_in_range_with_variance_scaling_initializer_fan_out(self):
conv_hyperparams_text_proto = """
regularizer {
l2_regularizer {
}
}
initializer {
variance_scaling_initializer {
factor: 2.0
mode: FAN_OUT
uniform: false
}
}
"""
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
initializer = conv_scope_arguments['weights_initializer']
self._assert_variance_in_range(initializer, shape=[100, 40],
variance=2. / 40.)
def test_variance_in_range_with_variance_scaling_initializer_fan_out_keras(
self):
conv_hyperparams_text_proto = """
......@@ -720,30 +885,6 @@ class HyperparamsBuilderTest(tf.test.TestCase):
self._assert_variance_in_range(initializer, shape=[100, 40],
variance=2. / 40.)
def test_variance_in_range_with_variance_scaling_initializer_fan_avg(self):
conv_hyperparams_text_proto = """
regularizer {
l2_regularizer {
}
}
initializer {
variance_scaling_initializer {
factor: 2.0
mode: FAN_AVG
uniform: false
}
}
"""
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
initializer = conv_scope_arguments['weights_initializer']
self._assert_variance_in_range(initializer, shape=[100, 40],
variance=4. / (100. + 40.))
def test_variance_in_range_with_variance_scaling_initializer_fan_avg_keras(
self):
conv_hyperparams_text_proto = """
......@@ -767,30 +908,6 @@ class HyperparamsBuilderTest(tf.test.TestCase):
self._assert_variance_in_range(initializer, shape=[100, 40],
variance=4. / (100. + 40.))
def test_variance_in_range_with_variance_scaling_initializer_uniform(self):
conv_hyperparams_text_proto = """
regularizer {
l2_regularizer {
}
}
initializer {
variance_scaling_initializer {
factor: 2.0
mode: FAN_IN
uniform: true
}
}
"""
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
initializer = conv_scope_arguments['weights_initializer']
self._assert_variance_in_range(initializer, shape=[100, 40],
variance=2. / 100.)
def test_variance_in_range_with_variance_scaling_initializer_uniform_keras(
self):
conv_hyperparams_text_proto = """
......@@ -814,29 +931,6 @@ class HyperparamsBuilderTest(tf.test.TestCase):
self._assert_variance_in_range(initializer, shape=[100, 40],
variance=2. / 100.)
def test_variance_in_range_with_truncated_normal_initializer(self):
conv_hyperparams_text_proto = """
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
mean: 0.0
stddev: 0.8
}
}
"""
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
initializer = conv_scope_arguments['weights_initializer']
self._assert_variance_in_range(initializer, shape=[100, 40],
variance=0.49, tol=1e-1)
def test_variance_in_range_with_truncated_normal_initializer_keras(self):
conv_hyperparams_text_proto = """
regularizer {
......@@ -858,29 +952,6 @@ class HyperparamsBuilderTest(tf.test.TestCase):
self._assert_variance_in_range(initializer, shape=[100, 40],
variance=0.49, tol=1e-1)
def test_variance_in_range_with_random_normal_initializer(self):
conv_hyperparams_text_proto = """
regularizer {
l2_regularizer {
}
}
initializer {
random_normal_initializer {
mean: 0.0
stddev: 0.8
}
}
"""
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
is_training=True)
scope = scope_fn()
conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
initializer = conv_scope_arguments['weights_initializer']
self._assert_variance_in_range(initializer, shape=[100, 40],
variance=0.64, tol=1e-1)
def test_variance_in_range_with_random_normal_initializer_keras(self):
conv_hyperparams_text_proto = """
regularizer {
......@@ -902,6 +973,5 @@ class HyperparamsBuilderTest(tf.test.TestCase):
self._assert_variance_in_range(initializer, shape=[100, 40],
variance=0.64, tol=1e-1)
if __name__ == '__main__':
tf.test.main()
......@@ -18,21 +18,23 @@ import tensorflow.compat.v1 as tf
from google.protobuf import text_format
from object_detection.builders import image_resizer_builder
from object_detection.protos import image_resizer_pb2
from object_detection.utils import test_case
class ImageResizerBuilderTest(tf.test.TestCase):
class ImageResizerBuilderTest(test_case.TestCase):
def _shape_of_resized_random_image_given_text_proto(self, input_shape,
text_proto):
image_resizer_config = image_resizer_pb2.ImageResizer()
text_format.Merge(text_proto, image_resizer_config)
image_resizer_fn = image_resizer_builder.build(image_resizer_config)
images = tf.cast(
tf.random_uniform(input_shape, minval=0, maxval=255, dtype=tf.int32),
dtype=tf.float32)
resized_images, _ = image_resizer_fn(images)
with self.test_session() as sess:
return sess.run(resized_images).shape
def graph_fn():
images = tf.cast(
tf.random_uniform(input_shape, minval=0, maxval=255, dtype=tf.int32),
dtype=tf.float32)
resized_images, _ = image_resizer_fn(images)
return resized_images
return self.execute_cpu(graph_fn, []).shape
def test_build_keep_aspect_ratio_resizer_returns_expected_shape(self):
image_resizer_text_proto = """
......@@ -125,10 +127,10 @@ class ImageResizerBuilderTest(tf.test.TestCase):
image_resizer_config = image_resizer_pb2.ImageResizer()
text_format.Merge(text_proto, image_resizer_config)
image_resizer_fn = image_resizer_builder.build(image_resizer_config)
image_placeholder = tf.placeholder(tf.uint8, [1, None, None, 3])
resized_image, _ = image_resizer_fn(image_placeholder)
with self.test_session() as sess:
return sess.run(resized_image, feed_dict={image_placeholder: image})
def graph_fn(image):
resized_image, _ = image_resizer_fn(image)
return resized_image
return self.execute_cpu(graph_fn, [image])
def test_fixed_shape_resizer_nearest_neighbor_method(self):
image_resizer_text_proto = """
......
......@@ -29,19 +29,12 @@ from __future__ import division
from __future__ import print_function
import tensorflow.compat.v1 as tf
import tf_slim as slim
from object_detection.data_decoders import tf_example_decoder
from object_detection.data_decoders import tf_sequence_example_decoder
from object_detection.protos import input_reader_pb2
# pylint: disable=g-import-not-at-top
try:
import tf_slim as slim
except ImportError:
# TF 2.0 doesn't ship with contrib.
pass
# pylint: enable=g-import-not-at-top
parallel_reader = slim.parallel_reader
......@@ -82,14 +75,14 @@ def build(input_reader_config):
if input_reader_config.HasField('label_map_path'):
label_map_proto_file = input_reader_config.label_map_path
input_type = input_reader_config.input_type
if input_type == input_reader_pb2.InputType.TF_EXAMPLE:
if input_type == input_reader_pb2.InputType.Value('TF_EXAMPLE'):
decoder = tf_example_decoder.TfExampleDecoder(
load_instance_masks=input_reader_config.load_instance_masks,
instance_mask_type=input_reader_config.mask_type,
label_map_proto_file=label_map_proto_file,
load_context_features=input_reader_config.load_context_features)
return decoder.decode(string_tensor)
elif input_type == input_reader_pb2.InputType.TF_SEQUENCE_EXAMPLE:
elif input_type == input_reader_pb2.InputType.Value('TF_SEQUENCE_EXAMPLE'):
decoder = tf_sequence_example_decoder.TfSequenceExampleDecoder(
label_map_proto_file=label_map_proto_file,
load_context_features=input_reader_config.load_context_features)
......
......@@ -16,6 +16,7 @@
"""Tests for input_reader_builder."""
import os
import unittest
import numpy as np
import tensorflow.compat.v1 as tf
......@@ -26,6 +27,7 @@ from object_detection.core import standard_fields as fields
from object_detection.dataset_tools import seq_example_util
from object_detection.protos import input_reader_pb2
from object_detection.utils import dataset_util
from object_detection.utils import tf_version
def _get_labelmap_path():
......@@ -35,6 +37,7 @@ def _get_labelmap_path():
'pet_label_map.pbtxt')
@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
class InputReaderBuilderTest(tf.test.TestCase):
def create_tf_record(self):
......
......@@ -16,8 +16,11 @@
"""A function to build an object detection matcher from configuration."""
from object_detection.matchers import argmax_matcher
from object_detection.matchers import bipartite_matcher
from object_detection.protos import matcher_pb2
from object_detection.utils import tf_version
if tf_version.is_tf1():
from object_detection.matchers import bipartite_matcher # pylint: disable=g-import-not-at-top
def build(matcher_config):
......@@ -48,6 +51,8 @@ def build(matcher_config):
force_match_for_each_row=matcher.force_match_for_each_row,
use_matmul_gather=matcher.use_matmul_gather)
if matcher_config.WhichOneof('matcher_oneof') == 'bipartite_matcher':
if tf_version.is_tf2():
raise ValueError('bipartite_matcher is not supported in TF 2.X')
matcher = matcher_config.bipartite_matcher
return bipartite_matcher.GreedyBipartiteMatcher(matcher.use_matmul_gather)
raise ValueError('Empty matcher.')
......@@ -20,11 +20,15 @@ import tensorflow.compat.v1 as tf
from google.protobuf import text_format
from object_detection.builders import matcher_builder
from object_detection.matchers import argmax_matcher
from object_detection.matchers import bipartite_matcher
from object_detection.protos import matcher_pb2
from object_detection.utils import test_case
from object_detection.utils import tf_version
if tf_version.is_tf1():
from object_detection.matchers import bipartite_matcher # pylint: disable=g-import-not-at-top
class MatcherBuilderTest(tf.test.TestCase):
class MatcherBuilderTest(test_case.TestCase):
def test_build_arg_max_matcher_with_defaults(self):
matcher_text_proto = """
......@@ -34,7 +38,7 @@ class MatcherBuilderTest(tf.test.TestCase):
matcher_proto = matcher_pb2.Matcher()
text_format.Merge(matcher_text_proto, matcher_proto)
matcher_object = matcher_builder.build(matcher_proto)
self.assertTrue(isinstance(matcher_object, argmax_matcher.ArgMaxMatcher))
self.assertIsInstance(matcher_object, argmax_matcher.ArgMaxMatcher)
self.assertAlmostEqual(matcher_object._matched_threshold, 0.5)
self.assertAlmostEqual(matcher_object._unmatched_threshold, 0.5)
self.assertTrue(matcher_object._negatives_lower_than_unmatched)
......@@ -49,7 +53,7 @@ class MatcherBuilderTest(tf.test.TestCase):
matcher_proto = matcher_pb2.Matcher()
text_format.Merge(matcher_text_proto, matcher_proto)
matcher_object = matcher_builder.build(matcher_proto)
self.assertTrue(isinstance(matcher_object, argmax_matcher.ArgMaxMatcher))
self.assertIsInstance(matcher_object, argmax_matcher.ArgMaxMatcher)
self.assertEqual(matcher_object._matched_threshold, None)
self.assertEqual(matcher_object._unmatched_threshold, None)
self.assertTrue(matcher_object._negatives_lower_than_unmatched)
......@@ -68,7 +72,7 @@ class MatcherBuilderTest(tf.test.TestCase):
matcher_proto = matcher_pb2.Matcher()
text_format.Merge(matcher_text_proto, matcher_proto)
matcher_object = matcher_builder.build(matcher_proto)
self.assertTrue(isinstance(matcher_object, argmax_matcher.ArgMaxMatcher))
self.assertIsInstance(matcher_object, argmax_matcher.ArgMaxMatcher)
self.assertAlmostEqual(matcher_object._matched_threshold, 0.7)
self.assertAlmostEqual(matcher_object._unmatched_threshold, 0.3)
self.assertFalse(matcher_object._negatives_lower_than_unmatched)
......@@ -76,6 +80,8 @@ class MatcherBuilderTest(tf.test.TestCase):
self.assertTrue(matcher_object._use_matmul_gather)
def test_build_bipartite_matcher(self):
if tf_version.is_tf2():
self.skipTest('BipartiteMatcher unsupported in TF 2.X. Skipping.')
matcher_text_proto = """
bipartite_matcher {
}
......@@ -83,8 +89,8 @@ class MatcherBuilderTest(tf.test.TestCase):
matcher_proto = matcher_pb2.Matcher()
text_format.Merge(matcher_text_proto, matcher_proto)
matcher_object = matcher_builder.build(matcher_proto)
self.assertTrue(
isinstance(matcher_object, bipartite_matcher.GreedyBipartiteMatcher))
self.assertIsInstance(matcher_object,
bipartite_matcher.GreedyBipartiteMatcher)
def test_raise_error_on_empty_matcher(self):
matcher_text_proto = """
......
......@@ -28,6 +28,8 @@ from object_detection.builders import region_similarity_calculator_builder as si
from object_detection.core import balanced_positive_negative_sampler as sampler
from object_detection.core import post_processing
from object_detection.core import target_assigner
from object_detection.meta_architectures import center_net_meta_arch
from object_detection.meta_architectures import context_rcnn_meta_arch
from object_detection.meta_architectures import faster_rcnn_meta_arch
from object_detection.meta_architectures import rfcn_meta_arch
from object_detection.meta_architectures import ssd_meta_arch
......@@ -47,6 +49,7 @@ from object_detection.utils import tf_version
if tf_version.is_tf2():
from object_detection.models import center_net_hourglass_feature_extractor
from object_detection.models import center_net_resnet_feature_extractor
from object_detection.models import center_net_resnet_v1_fpn_feature_extractor
from object_detection.models import faster_rcnn_inception_resnet_v2_keras_feature_extractor as frcnn_inc_res_keras
from object_detection.models import faster_rcnn_resnet_keras_feature_extractor as frcnn_resnet_keras
from object_detection.models import ssd_resnet_v1_fpn_keras_feature_extractor as ssd_resnet_v1_fpn_keras
......@@ -79,6 +82,7 @@ if tf_version.is_tf1():
from object_detection.models.ssd_mobiledet_feature_extractor import SSDMobileDetCPUFeatureExtractor
from object_detection.models.ssd_mobiledet_feature_extractor import SSDMobileDetDSPFeatureExtractor
from object_detection.models.ssd_mobiledet_feature_extractor import SSDMobileDetEdgeTPUFeatureExtractor
from object_detection.models.ssd_mobiledet_feature_extractor import SSDMobileDetGPUFeatureExtractor
from object_detection.models.ssd_pnasnet_feature_extractor import SSDPNASNetFeatureExtractor
from object_detection.predictors import rfcn_box_predictor
# pylint: enable=g-import-not-at-top
......@@ -109,8 +113,12 @@ if tf_version.is_tf2():
}
CENTER_NET_EXTRACTOR_FUNCTION_MAP = {
'resnet_v2_101': center_net_resnet_feature_extractor.resnet_v2_101,
'resnet_v2_50': center_net_resnet_feature_extractor.resnet_v2_50,
'resnet_v2_101': center_net_resnet_feature_extractor.resnet_v2_101,
'resnet_v1_50_fpn':
center_net_resnet_v1_fpn_feature_extractor.resnet_v1_50_fpn,
'resnet_v1_101_fpn':
center_net_resnet_v1_fpn_feature_extractor.resnet_v1_101_fpn,
'hourglass_104': center_net_hourglass_feature_extractor.hourglass_104,
}
......@@ -160,9 +168,14 @@ if tf_version.is_tf1():
EmbeddedSSDMobileNetV1FeatureExtractor,
'ssd_pnasnet':
SSDPNASNetFeatureExtractor,
'ssd_mobiledet_cpu': SSDMobileDetCPUFeatureExtractor,
'ssd_mobiledet_dsp': SSDMobileDetDSPFeatureExtractor,
'ssd_mobiledet_edgetpu': SSDMobileDetEdgeTPUFeatureExtractor,
'ssd_mobiledet_cpu':
SSDMobileDetCPUFeatureExtractor,
'ssd_mobiledet_dsp':
SSDMobileDetDSPFeatureExtractor,
'ssd_mobiledet_edgetpu':
SSDMobileDetEdgeTPUFeatureExtractor,
'ssd_mobiledet_gpu':
SSDMobileDetGPUFeatureExtractor,
}
FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP = {
......@@ -767,7 +780,9 @@ def keypoint_proto_to_params(kp_config, keypoint_map_dict):
unmatched_keypoint_score=kp_config.unmatched_keypoint_score,
box_scale=kp_config.box_scale,
candidate_search_scale=kp_config.candidate_search_scale,
candidate_ranking_mode=kp_config.candidate_ranking_mode)
candidate_ranking_mode=kp_config.candidate_ranking_mode,
offset_peak_radius=kp_config.offset_peak_radius,
per_keypoint_offset=kp_config.per_keypoint_offset)
def object_detection_proto_to_params(od_config):
......
......@@ -14,16 +14,19 @@
# limitations under the License.
# ==============================================================================
"""Tests for model_builder under TensorFlow 1.X."""
import unittest
from absl.testing import parameterized
import tensorflow.compat.v1 as tf
from object_detection.builders import model_builder
from object_detection.builders import model_builder_test
from object_detection.meta_architectures import context_rcnn_meta_arch
from object_detection.meta_architectures import ssd_meta_arch
from object_detection.protos import losses_pb2
from object_detection.utils import tf_version
@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
class ModelBuilderTF1Test(model_builder_test.ModelBuilderTest):
def default_ssd_feature_extractor(self):
......@@ -39,6 +42,14 @@ class ModelBuilderTF1Test(model_builder_test.ModelBuilderTest):
return model_builder.FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP
@parameterized.parameters(True, False)
def test_create_context_rcnn_from_config_with_params(self, is_training):
model_proto = self.create_default_faster_rcnn_model_proto()
model_proto.faster_rcnn.context_config.attention_bottleneck_dimension = 10
model_proto.faster_rcnn.context_config.attention_temperature = 0.5
model = model_builder.build(model_proto, is_training=is_training)
self.assertIsInstance(model, context_rcnn_meta_arch.ContextRCNNMetaArch)
if __name__ == '__main__':
tf.test.main()
# Lint as: python2, python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for model_builder under TensorFlow 2.X."""
import os
import unittest
import tensorflow.compat.v1 as tf
from google.protobuf import text_format
from object_detection.builders import model_builder
from object_detection.builders import model_builder_test
from object_detection.core import losses
from object_detection.models import center_net_resnet_feature_extractor
from object_detection.protos import center_net_pb2
from object_detection.protos import model_pb2
from object_detection.utils import tf_version
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
def default_ssd_feature_extractor(self):
return 'ssd_resnet50_v1_fpn_keras'
def default_faster_rcnn_feature_extractor(self):
return 'faster_rcnn_resnet101_keras'
def ssd_feature_extractors(self):
return model_builder.SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP
def faster_rcnn_feature_extractors(self):
return model_builder.FASTER_RCNN_KERAS_FEATURE_EXTRACTOR_CLASS_MAP
def get_fake_label_map_file_path(self):
keypoint_spec_text = """
item {
name: "/m/01g317"
id: 1
display_name: "person"
keypoints {
id: 0
label: 'nose'
}
keypoints {
id: 1
label: 'left_shoulder'
}
keypoints {
id: 2
label: 'right_shoulder'
}
keypoints {
id: 3
label: 'hip'
}
}
"""
keypoint_label_map_path = os.path.join(
self.get_temp_dir(), 'keypoint_label_map')
with tf.gfile.Open(keypoint_label_map_path, 'wb') as f:
f.write(keypoint_spec_text)
return keypoint_label_map_path
def get_fake_keypoint_proto(self):
task_proto_txt = """
task_name: "human_pose"
task_loss_weight: 0.9
keypoint_regression_loss_weight: 1.0
keypoint_heatmap_loss_weight: 0.1
keypoint_offset_loss_weight: 0.5
heatmap_bias_init: 2.14
keypoint_class_name: "/m/01g317"
loss {
classification_loss {
penalty_reduced_logistic_focal_loss {
alpha: 3.0
beta: 4.0
}
}
localization_loss {
l1_localization_loss {
}
}
}
keypoint_label_to_std {
key: "nose"
value: 0.3
}
keypoint_label_to_std {
key: "hip"
value: 0.0
}
keypoint_candidate_score_threshold: 0.3
num_candidates_per_keypoint: 12
peak_max_pool_kernel_size: 5
unmatched_keypoint_score: 0.05
box_scale: 1.7
candidate_search_scale: 0.2
candidate_ranking_mode: "score_distance_ratio"
offset_peak_radius: 3
per_keypoint_offset: true
"""
config = text_format.Merge(task_proto_txt,
center_net_pb2.CenterNet.KeypointEstimation())
return config
def get_fake_object_center_proto(self):
proto_txt = """
object_center_loss_weight: 0.5
heatmap_bias_init: 3.14
min_box_overlap_iou: 0.2
max_box_predictions: 15
classification_loss {
penalty_reduced_logistic_focal_loss {
alpha: 3.0
beta: 4.0
}
}
"""
return text_format.Merge(proto_txt,
center_net_pb2.CenterNet.ObjectCenterParams())
def get_fake_object_detection_proto(self):
proto_txt = """
task_loss_weight: 0.5
offset_loss_weight: 0.1
scale_loss_weight: 0.2
localization_loss {
l1_localization_loss {
}
}
"""
return text_format.Merge(proto_txt,
center_net_pb2.CenterNet.ObjectDetection())
def get_fake_mask_proto(self):
proto_txt = """
task_loss_weight: 0.7
classification_loss {
weighted_softmax {}
}
mask_height: 8
mask_width: 8
score_threshold: 0.7
heatmap_bias_init: -2.0
"""
return text_format.Merge(proto_txt,
center_net_pb2.CenterNet.MaskEstimation())
def test_create_center_net_model(self):
"""Test building a CenterNet model from proto txt."""
proto_txt = """
center_net {
num_classes: 10
feature_extractor {
type: "resnet_v2_101"
channel_stds: [4, 5, 6]
bgr_ordering: true
}
image_resizer {
keep_aspect_ratio_resizer {
min_dimension: 512
max_dimension: 512
pad_to_max_dimension: true
}
}
}
"""
# Set up the configuration proto.
config = text_format.Merge(proto_txt, model_pb2.DetectionModel())
config.center_net.object_center_params.CopyFrom(
self.get_fake_object_center_proto())
config.center_net.object_detection_task.CopyFrom(
self.get_fake_object_detection_proto())
config.center_net.keypoint_estimation_task.append(
self.get_fake_keypoint_proto())
config.center_net.keypoint_label_map_path = (
self.get_fake_label_map_file_path())
config.center_net.mask_estimation_task.CopyFrom(
self.get_fake_mask_proto())
# Build the model from the configuration.
model = model_builder.build(config, is_training=True)
# Check object center related parameters.
self.assertEqual(model._num_classes, 10)
self.assertIsInstance(model._center_params.classification_loss,
losses.PenaltyReducedLogisticFocalLoss)
self.assertEqual(model._center_params.classification_loss._alpha, 3.0)
self.assertEqual(model._center_params.classification_loss._beta, 4.0)
self.assertAlmostEqual(model._center_params.min_box_overlap_iou, 0.2)
self.assertAlmostEqual(
model._center_params.heatmap_bias_init, 3.14, places=4)
self.assertEqual(model._center_params.max_box_predictions, 15)
# Check object detection related parameters.
self.assertAlmostEqual(model._od_params.offset_loss_weight, 0.1)
self.assertAlmostEqual(model._od_params.scale_loss_weight, 0.2)
self.assertAlmostEqual(model._od_params.task_loss_weight, 0.5)
self.assertIsInstance(model._od_params.localization_loss,
losses.L1LocalizationLoss)
# Check keypoint estimation related parameters.
kp_params = model._kp_params_dict['human_pose']
self.assertAlmostEqual(kp_params.task_loss_weight, 0.9)
self.assertAlmostEqual(kp_params.keypoint_regression_loss_weight, 1.0)
self.assertAlmostEqual(kp_params.keypoint_offset_loss_weight, 0.5)
self.assertAlmostEqual(kp_params.heatmap_bias_init, 2.14, places=4)
self.assertEqual(kp_params.classification_loss._alpha, 3.0)
self.assertEqual(kp_params.keypoint_indices, [0, 1, 2, 3])
self.assertEqual(kp_params.keypoint_labels,
['nose', 'left_shoulder', 'right_shoulder', 'hip'])
self.assertAllClose(kp_params.keypoint_std_dev, [0.3, 1.0, 1.0, 0.0])
self.assertEqual(kp_params.classification_loss._beta, 4.0)
self.assertIsInstance(kp_params.localization_loss,
losses.L1LocalizationLoss)
self.assertAlmostEqual(kp_params.keypoint_candidate_score_threshold, 0.3)
self.assertEqual(kp_params.num_candidates_per_keypoint, 12)
self.assertEqual(kp_params.peak_max_pool_kernel_size, 5)
self.assertAlmostEqual(kp_params.unmatched_keypoint_score, 0.05)
self.assertAlmostEqual(kp_params.box_scale, 1.7)
self.assertAlmostEqual(kp_params.candidate_search_scale, 0.2)
self.assertEqual(kp_params.candidate_ranking_mode, 'score_distance_ratio')
self.assertEqual(kp_params.offset_peak_radius, 3)
self.assertEqual(kp_params.per_keypoint_offset, True)
# Check mask related parameters.
self.assertAlmostEqual(model._mask_params.task_loss_weight, 0.7)
self.assertIsInstance(model._mask_params.classification_loss,
losses.WeightedSoftmaxClassificationLoss)
self.assertEqual(model._mask_params.mask_height, 8)
self.assertEqual(model._mask_params.mask_width, 8)
self.assertAlmostEqual(model._mask_params.score_threshold, 0.7)
self.assertAlmostEqual(
model._mask_params.heatmap_bias_init, -2.0, places=4)
# Check feature extractor parameters.
self.assertIsInstance(
model._feature_extractor,
center_net_resnet_feature_extractor.CenterNetResnetFeatureExtractor)
self.assertAllClose(model._feature_extractor._channel_means, [0, 0, 0])
self.assertAllClose(model._feature_extractor._channel_stds, [4, 5, 6])
self.assertTrue(model._feature_extractor._bgr_ordering)
if __name__ == '__main__':
tf.test.main()
......@@ -17,10 +17,13 @@
import tensorflow.compat.v1 as tf
from tensorflow.contrib import opt as tf_opt
from object_detection.utils import learning_schedules
try:
from tensorflow.contrib import opt as tf_opt # pylint: disable=g-import-not-at-top
except: # pylint: disable=bare-except
pass
def build_optimizers_tf_v1(optimizer_config, global_step=None):
"""Create a TF v1 compatible optimizer based on config.
......
......@@ -20,6 +20,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import six
import tensorflow.compat.v1 as tf
......@@ -27,16 +28,15 @@ from google.protobuf import text_format
from object_detection.builders import optimizer_builder
from object_detection.protos import optimizer_pb2
from object_detection.utils import tf_version
# pylint: disable=g-import-not-at-top
try:
if tf_version.is_tf1():
from tensorflow.contrib import opt as contrib_opt
except ImportError:
# TF 2.0 doesn't ship with contrib.
pass
# pylint: enable=g-import-not-at-top
@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
class LearningRateBuilderTest(tf.test.TestCase):
def testBuildConstantLearningRate(self):
......@@ -118,6 +118,7 @@ class LearningRateBuilderTest(tf.test.TestCase):
optimizer_builder._create_learning_rate(learning_rate_proto)
@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
class OptimizerBuilderTest(tf.test.TestCase):
def testBuildRMSPropOptimizer(self):
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for optimizer_builder."""
import unittest
import tensorflow.compat.v1 as tf
from google.protobuf import text_format
from object_detection.builders import optimizer_builder
from object_detection.protos import optimizer_pb2
from object_detection.utils import tf_version
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class OptimizerBuilderV2Test(tf.test.TestCase):
"""Test building optimizers in V2 mode."""
def testBuildRMSPropOptimizer(self):
optimizer_text_proto = """
rms_prop_optimizer: {
learning_rate: {
exponential_decay_learning_rate {
initial_learning_rate: 0.004
decay_steps: 800720
decay_factor: 0.95
}
}
momentum_optimizer_value: 0.9
decay: 0.9
epsilon: 1.0
}
use_moving_average: false
"""
optimizer_proto = optimizer_pb2.Optimizer()
text_format.Merge(optimizer_text_proto, optimizer_proto)
optimizer, _ = optimizer_builder.build(optimizer_proto)
self.assertIsInstance(optimizer, tf.keras.optimizers.RMSprop)
def testBuildMomentumOptimizer(self):
optimizer_text_proto = """
momentum_optimizer: {
learning_rate: {
constant_learning_rate {
learning_rate: 0.001
}
}
momentum_optimizer_value: 0.99
}
use_moving_average: false
"""
optimizer_proto = optimizer_pb2.Optimizer()
text_format.Merge(optimizer_text_proto, optimizer_proto)
optimizer, _ = optimizer_builder.build(optimizer_proto)
self.assertIsInstance(optimizer, tf.keras.optimizers.SGD)
def testBuildAdamOptimizer(self):
optimizer_text_proto = """
adam_optimizer: {
learning_rate: {
constant_learning_rate {
learning_rate: 0.002
}
}
}
use_moving_average: false
"""
optimizer_proto = optimizer_pb2.Optimizer()
text_format.Merge(optimizer_text_proto, optimizer_proto)
optimizer, _ = optimizer_builder.build(optimizer_proto)
self.assertIsInstance(optimizer, tf.keras.optimizers.Adam)
def testMovingAverageOptimizerUnsupported(self):
optimizer_text_proto = """
adam_optimizer: {
learning_rate: {
constant_learning_rate {
learning_rate: 0.002
}
}
}
use_moving_average: True
"""
optimizer_proto = optimizer_pb2.Optimizer()
text_format.Merge(optimizer_text_proto, optimizer_proto)
with self.assertRaises(ValueError):
optimizer_builder.build(optimizer_proto)
if __name__ == '__main__':
tf.enable_v2_behavior()
tf.test.main()
......@@ -19,9 +19,10 @@ import tensorflow.compat.v1 as tf
from google.protobuf import text_format
from object_detection.builders import post_processing_builder
from object_detection.protos import post_processing_pb2
from object_detection.utils import test_case
class PostProcessingBuilderTest(tf.test.TestCase):
class PostProcessingBuilderTest(test_case.TestCase):
def test_build_non_max_suppressor_with_correct_parameters(self):
post_processing_text_proto = """
......@@ -77,13 +78,12 @@ class PostProcessingBuilderTest(tf.test.TestCase):
_, score_converter = post_processing_builder.build(
post_processing_config)
self.assertEqual(score_converter.__name__, 'identity_with_logit_scale')
inputs = tf.constant([1, 1], tf.float32)
outputs = score_converter(inputs)
with self.test_session() as sess:
converted_scores = sess.run(outputs)
expected_converted_scores = sess.run(inputs)
self.assertAllClose(converted_scores, expected_converted_scores)
def graph_fn():
inputs = tf.constant([1, 1], tf.float32)
outputs = score_converter(inputs)
return outputs
converted_scores = self.execute_cpu(graph_fn, [])
self.assertAllClose(converted_scores, [1, 1])
def test_build_identity_score_converter_with_logit_scale(self):
post_processing_text_proto = """
......@@ -95,12 +95,12 @@ class PostProcessingBuilderTest(tf.test.TestCase):
_, score_converter = post_processing_builder.build(post_processing_config)
self.assertEqual(score_converter.__name__, 'identity_with_logit_scale')
inputs = tf.constant([1, 1], tf.float32)
outputs = score_converter(inputs)
with self.test_session() as sess:
converted_scores = sess.run(outputs)
expected_converted_scores = sess.run(tf.constant([.5, .5], tf.float32))
self.assertAllClose(converted_scores, expected_converted_scores)
def graph_fn():
inputs = tf.constant([1, 1], tf.float32)
outputs = score_converter(inputs)
return outputs
converted_scores = self.execute_cpu(graph_fn, [])
self.assertAllClose(converted_scores, [.5, .5])
def test_build_sigmoid_score_converter(self):
post_processing_text_proto = """
......@@ -153,12 +153,12 @@ class PostProcessingBuilderTest(tf.test.TestCase):
self.assertEqual(calibrated_score_conversion_fn.__name__,
'calibrate_with_function_approximation')
input_scores = tf.constant([1, 1], tf.float32)
outputs = calibrated_score_conversion_fn(input_scores)
with self.test_session() as sess:
calibrated_scores = sess.run(outputs)
expected_calibrated_scores = sess.run(tf.constant([0.5, 0.5], tf.float32))
self.assertAllClose(calibrated_scores, expected_calibrated_scores)
def graph_fn():
input_scores = tf.constant([1, 1], tf.float32)
outputs = calibrated_score_conversion_fn(input_scores)
return outputs
calibrated_scores = self.execute_cpu(graph_fn, [])
self.assertAllClose(calibrated_scores, [0.5, 0.5])
def test_build_temperature_scaling_calibrator(self):
post_processing_text_proto = """
......@@ -174,12 +174,12 @@ class PostProcessingBuilderTest(tf.test.TestCase):
self.assertEqual(calibrated_score_conversion_fn.__name__,
'calibrate_with_temperature_scaling_calibration')
input_scores = tf.constant([1, 1], tf.float32)
outputs = calibrated_score_conversion_fn(input_scores)
with self.test_session() as sess:
calibrated_scores = sess.run(outputs)
expected_calibrated_scores = sess.run(tf.constant([0.5, 0.5], tf.float32))
self.assertAllClose(calibrated_scores, expected_calibrated_scores)
def graph_fn():
input_scores = tf.constant([1, 1], tf.float32)
outputs = calibrated_score_conversion_fn(input_scores)
return outputs
calibrated_scores = self.execute_cpu(graph_fn, [])
self.assertAllClose(calibrated_scores, [0.5, 0.5])
if __name__ == '__main__':
tf.test.main()
......@@ -151,6 +151,7 @@ def build(preprocessor_step_config):
{
'keypoint_flip_permutation': tuple(
config.keypoint_flip_permutation) or None,
'probability': config.probability or None,
})
if step_type == 'random_vertical_flip':
......@@ -159,10 +160,17 @@ def build(preprocessor_step_config):
{
'keypoint_flip_permutation': tuple(
config.keypoint_flip_permutation) or None,
'probability': config.probability or None,
})
if step_type == 'random_rotation90':
return (preprocessor.random_rotation90, {})
config = preprocessor_step_config.random_rotation90
return (preprocessor.random_rotation90,
{
'keypoint_rot_permutation': tuple(
config.keypoint_rot_permutation) or None,
'probability': config.probability or None,
})
if step_type == 'random_crop_image':
config = preprocessor_step_config.random_crop_image
......
......@@ -65,13 +65,15 @@ class PreprocessorBuilderTest(tf.test.TestCase):
keypoint_flip_permutation: 3
keypoint_flip_permutation: 5
keypoint_flip_permutation: 4
probability: 0.5
}
"""
preprocessor_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(preprocessor_text_proto, preprocessor_proto)
function, args = preprocessor_builder.build(preprocessor_proto)
self.assertEqual(function, preprocessor.random_horizontal_flip)
self.assertEqual(args, {'keypoint_flip_permutation': (1, 0, 2, 3, 5, 4)})
self.assertEqual(args, {'keypoint_flip_permutation': (1, 0, 2, 3, 5, 4),
'probability': 0.5})
def test_build_random_vertical_flip(self):
preprocessor_text_proto = """
......@@ -82,23 +84,32 @@ class PreprocessorBuilderTest(tf.test.TestCase):
keypoint_flip_permutation: 3
keypoint_flip_permutation: 5
keypoint_flip_permutation: 4
probability: 0.5
}
"""
preprocessor_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(preprocessor_text_proto, preprocessor_proto)
function, args = preprocessor_builder.build(preprocessor_proto)
self.assertEqual(function, preprocessor.random_vertical_flip)
self.assertEqual(args, {'keypoint_flip_permutation': (1, 0, 2, 3, 5, 4)})
self.assertEqual(args, {'keypoint_flip_permutation': (1, 0, 2, 3, 5, 4),
'probability': 0.5})
def test_build_random_rotation90(self):
preprocessor_text_proto = """
random_rotation90 {}
random_rotation90 {
keypoint_rot_permutation: 3
keypoint_rot_permutation: 0
keypoint_rot_permutation: 1
keypoint_rot_permutation: 2
probability: 0.5
}
"""
preprocessor_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(preprocessor_text_proto, preprocessor_proto)
function, args = preprocessor_builder.build(preprocessor_proto)
self.assertEqual(function, preprocessor.random_rotation90)
self.assertEqual(args, {})
self.assertEqual(args, {'keypoint_rot_permutation': (3, 0, 1, 2),
'probability': 0.5})
def test_build_random_pixel_value_scale(self):
preprocessor_text_proto = """
......
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -10,11 +10,11 @@
"# Object Detection API Demo\n",
"\n",
"\u003ctable align=\"left\"\u003e\u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.sandbox.google.com/github/tensorflow/models/blob/master/research/object_detection/object_detection_tutorial.ipynb\"\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.sandbox.google.com/github/tensorflow/models/blob/master/research/object_detection/colab_tutorials/colab_tutorials/object_detection_tutorial.ipynb\"\u003e\n",
" \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\n",
" \u003c/a\u003e\n",
"\u003c/td\u003e\u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/research/object_detection/object_detection_tutorial.ipynb\"\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/models/blob/master/research/object_detection/colab_tutorials/colab_tutorials/object_detection_tutorial.ipynb\"\u003e\n",
" \u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
"\u003c/td\u003e\u003c/table\u003e"
]
......
......@@ -27,21 +27,20 @@ from object_detection.utils import test_case
class BatchMulticlassNonMaxSuppressionTest(test_case.TestCase,
parameterized.TestCase):
@parameterized.named_parameters(('', False), ('_use_static_shapes', True))
def test_batch_multiclass_nms_with_batch_size_1(self, use_static_shapes):
boxes = tf.constant([[[[0, 0, 1, 1], [0, 0, 4, 5]],
[[0, 0.1, 1, 1.1], [0, 0.1, 2, 1.1]],
[[0, -0.1, 1, 0.9], [0, -0.1, 1, 0.9]],
[[0, 10, 1, 11], [0, 10, 1, 11]],
[[0, 10.1, 1, 11.1], [0, 10.1, 1, 11.1]],
[[0, 100, 1, 101], [0, 100, 1, 101]],
[[0, 1000, 1, 1002], [0, 999, 2, 1004]],
[[0, 1000, 1, 1002.1], [0, 999, 2, 1002.7]]]],
tf.float32)
scores = tf.constant([[[.9, 0.01], [.75, 0.05],
[.6, 0.01], [.95, 0],
[.5, 0.01], [.3, 0.01],
[.01, .85], [.01, .5]]])
def test_batch_multiclass_nms_with_batch_size_1(self):
boxes = np.array([[[[0, 0, 1, 1], [0, 0, 4, 5]],
[[0, 0.1, 1, 1.1], [0, 0.1, 2, 1.1]],
[[0, -0.1, 1, 0.9], [0, -0.1, 1, 0.9]],
[[0, 10, 1, 11], [0, 10, 1, 11]],
[[0, 10.1, 1, 11.1], [0, 10.1, 1, 11.1]],
[[0, 100, 1, 101], [0, 100, 1, 101]],
[[0, 1000, 1, 1002], [0, 999, 2, 1004]],
[[0, 1000, 1, 1002.1], [0, 999, 2, 1002.7]]]],
np.float32)
scores = np.array([[[.9, 0.01], [.75, 0.05],
[.6, 0.01], [.95, 0],
[.5, 0.01], [.3, 0.01],
[.01, .85], [.01, .5]]], np.float32)
score_thresh = 0.1
iou_thresh = .5
max_output_size = 4
......@@ -52,56 +51,51 @@ class BatchMulticlassNonMaxSuppressionTest(test_case.TestCase,
[0, 100, 1, 101]]]
exp_nms_scores = [[.95, .9, .85, .3]]
exp_nms_classes = [[0, 0, 1, 0]]
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
nmsed_additional_fields,
num_detections) = post_processing.batch_multiclass_non_max_suppression(
boxes,
scores,
score_thresh,
iou_thresh,
max_size_per_class=max_output_size,
max_total_size=max_output_size,
use_static_shapes=use_static_shapes)
self.assertIsNone(nmsed_masks)
self.assertIsNone(nmsed_additional_fields)
with self.test_session() as sess:
(nmsed_boxes, nmsed_scores, nmsed_classes,
num_detections) = sess.run([nmsed_boxes, nmsed_scores, nmsed_classes,
num_detections])
self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nmsed_classes, exp_nms_classes)
self.assertEqual(num_detections, [4])
def graph_fn(boxes, scores):
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
nmsed_additional_fields, num_detections
) = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size,
max_total_size=max_output_size)
self.assertIsNone(nmsed_masks)
self.assertIsNone(nmsed_additional_fields)
return (nmsed_boxes, nmsed_scores, nmsed_classes, num_detections)
(nmsed_boxes, nmsed_scores, nmsed_classes,
num_detections) = self.execute_cpu(graph_fn, [boxes, scores])
self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nmsed_classes, exp_nms_classes)
self.assertEqual(num_detections, [4])
def test_batch_iou_with_negative_data(self):
boxes = tf.constant([[[0, -0.01, 0.1, 1.1], [0, 0.2, 0.2, 5.0],
[0, -0.01, 0.1, 1.], [-1, -1, -1, -1]]], tf.float32)
iou = post_processing.batch_iou(boxes, boxes)
def graph_fn():
boxes = tf.constant([[[0, -0.01, 0.1, 1.1], [0, 0.2, 0.2, 5.0],
[0, -0.01, 0.1, 1.], [-1, -1, -1, -1]]], tf.float32)
iou = post_processing.batch_iou(boxes, boxes)
return iou
iou = self.execute_cpu(graph_fn, [])
expected_iou = [[[0.99999994, 0.0917431, 0.9099099, -1.],
[0.0917431, 1., 0.08154944, -1.],
[0.9099099, 0.08154944, 1., -1.], [-1., -1., -1., -1.]]]
with self.test_session() as sess:
iou = sess.run(iou)
self.assertAllClose(iou, expected_iou)
self.assertAllClose(iou, expected_iou)
@parameterized.parameters(False, True)
def test_batch_multiclass_nms_with_batch_size_2(self, use_dynamic_map_fn):
boxes = tf.constant([[[[0, 0, 1, 1], [0, 0, 4, 5]],
[[0, 0.1, 1, 1.1], [0, 0.1, 2, 1.1]],
[[0, -0.1, 1, 0.9], [0, -0.1, 1, 0.9]],
[[0, 10, 1, 11], [0, 10, 1, 11]]],
[[[0, 10.1, 1, 11.1], [0, 10.1, 1, 11.1]],
[[0, 100, 1, 101], [0, 100, 1, 101]],
[[0, 1000, 1, 1002], [0, 999, 2, 1004]],
[[0, 1000, 1, 1002.1], [0, 999, 2, 1002.7]]]],
tf.float32)
scores = tf.constant([[[.9, 0.01], [.75, 0.05],
[.6, 0.01], [.95, 0]],
[[.5, 0.01], [.3, 0.01],
[.01, .85], [.01, .5]]])
boxes = np.array([[[[0, 0, 1, 1], [0, 0, 4, 5]],
[[0, 0.1, 1, 1.1], [0, 0.1, 2, 1.1]],
[[0, -0.1, 1, 0.9], [0, -0.1, 1, 0.9]],
[[0, 10, 1, 11], [0, 10, 1, 11]]],
[[[0, 10.1, 1, 11.1], [0, 10.1, 1, 11.1]],
[[0, 100, 1, 101], [0, 100, 1, 101]],
[[0, 1000, 1, 1002], [0, 999, 2, 1004]],
[[0, 1000, 1, 1002.1], [0, 999, 2, 1002.7]]]],
np.float32)
scores = np.array([[[.9, 0.01], [.75, 0.05],
[.6, 0.01], [.95, 0]],
[[.5, 0.01], [.3, 0.01],
[.01, .85], [.01, .5]]], np.float32)
score_thresh = 0.1
iou_thresh = .5
max_output_size = 4
......@@ -118,49 +112,48 @@ class BatchMulticlassNonMaxSuppressionTest(test_case.TestCase,
[.85, .5, .3, 0]])
exp_nms_classes = np.array([[0, 0, 0, 0],
[1, 0, 0, 0]])
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
nmsed_additional_fields, num_detections
) = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size,
use_dynamic_map_fn=use_dynamic_map_fn)
self.assertIsNone(nmsed_masks)
self.assertIsNone(nmsed_additional_fields)
# Check static shapes
self.assertAllEqual(nmsed_boxes.shape.as_list(),
exp_nms_corners.shape)
self.assertAllEqual(nmsed_scores.shape.as_list(),
exp_nms_scores.shape)
self.assertAllEqual(nmsed_classes.shape.as_list(),
exp_nms_classes.shape)
self.assertEqual(num_detections.shape.as_list(), [2])
with self.test_session() as sess:
(nmsed_boxes, nmsed_scores, nmsed_classes,
num_detections) = sess.run([nmsed_boxes, nmsed_scores, nmsed_classes,
num_detections])
self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nmsed_classes, exp_nms_classes)
self.assertAllClose(num_detections, [2, 3])
def graph_fn(boxes, scores):
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
nmsed_additional_fields, num_detections
) = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size,
max_total_size=max_output_size,
use_dynamic_map_fn=use_dynamic_map_fn)
self.assertIsNone(nmsed_masks)
self.assertIsNone(nmsed_additional_fields)
# Check static shapes
self.assertAllEqual(nmsed_boxes.shape.as_list(),
exp_nms_corners.shape)
self.assertAllEqual(nmsed_scores.shape.as_list(),
exp_nms_scores.shape)
self.assertAllEqual(nmsed_classes.shape.as_list(),
exp_nms_classes.shape)
self.assertEqual(num_detections.shape.as_list(), [2])
return (nmsed_boxes, nmsed_scores, nmsed_classes, num_detections)
(nmsed_boxes, nmsed_scores, nmsed_classes,
num_detections) = self.execute_cpu(graph_fn, [boxes, scores])
self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nmsed_classes, exp_nms_classes)
self.assertAllClose(num_detections, [2, 3])
def test_batch_multiclass_nms_with_per_batch_clip_window(self):
boxes = tf.constant([[[[0, 0, 1, 1], [0, 0, 4, 5]],
[[0, 0.1, 1, 1.1], [0, 0.1, 2, 1.1]],
[[0, -0.1, 1, 0.9], [0, -0.1, 1, 0.9]],
[[0, 10, 1, 11], [0, 10, 1, 11]]],
[[[0, 10.1, 1, 11.1], [0, 10.1, 1, 11.1]],
[[0, 100, 1, 101], [0, 100, 1, 101]],
[[0, 1000, 1, 1002], [0, 999, 2, 1004]],
[[0, 1000, 1, 1002.1], [0, 999, 2, 1002.7]]]],
tf.float32)
scores = tf.constant([[[.9, 0.01], [.75, 0.05],
[.6, 0.01], [.95, 0]],
[[.5, 0.01], [.3, 0.01],
[.01, .85], [.01, .5]]])
clip_window = tf.constant([0., 0., 200., 200.])
boxes = np.array([[[[0, 0, 1, 1], [0, 0, 4, 5]],
[[0, 0.1, 1, 1.1], [0, 0.1, 2, 1.1]],
[[0, -0.1, 1, 0.9], [0, -0.1, 1, 0.9]],
[[0, 10, 1, 11], [0, 10, 1, 11]]],
[[[0, 10.1, 1, 11.1], [0, 10.1, 1, 11.1]],
[[0, 100, 1, 101], [0, 100, 1, 101]],
[[0, 1000, 1, 1002], [0, 999, 2, 1004]],
[[0, 1000, 1, 1002.1], [0, 999, 2, 1002.7]]]],
np.float32)
scores = np.array([[[.9, 0.01], [.75, 0.05],
[.6, 0.01], [.95, 0]],
[[.5, 0.01], [.3, 0.01],
[.01, .85], [.01, .5]]], np.float32)
clip_window = np.array([0., 0., 200., 200.], np.float32)
score_thresh = 0.1
iou_thresh = .5
max_output_size = 4
......@@ -177,50 +170,48 @@ class BatchMulticlassNonMaxSuppressionTest(test_case.TestCase,
[.5, .3, 0, 0]])
exp_nms_classes = np.array([[0, 0, 0, 0],
[0, 0, 0, 0]])
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
nmsed_additional_fields, num_detections
) = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size,
clip_window=clip_window)
self.assertIsNone(nmsed_masks)
self.assertIsNone(nmsed_additional_fields)
# Check static shapes
self.assertAllEqual(nmsed_boxes.shape.as_list(),
exp_nms_corners.shape)
self.assertAllEqual(nmsed_scores.shape.as_list(),
exp_nms_scores.shape)
self.assertAllEqual(nmsed_classes.shape.as_list(),
exp_nms_classes.shape)
self.assertEqual(num_detections.shape.as_list(), [2])
with self.test_session() as sess:
(nmsed_boxes, nmsed_scores, nmsed_classes,
num_detections) = sess.run([nmsed_boxes, nmsed_scores, nmsed_classes,
num_detections])
self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nmsed_classes, exp_nms_classes)
self.assertAllClose(num_detections, [2, 2])
def graph_fn(boxes, scores, clip_window):
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
nmsed_additional_fields, num_detections
) = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size,
clip_window=clip_window)
self.assertIsNone(nmsed_masks)
self.assertIsNone(nmsed_additional_fields)
# Check static shapes
self.assertAllEqual(nmsed_boxes.shape.as_list(),
exp_nms_corners.shape)
self.assertAllEqual(nmsed_scores.shape.as_list(),
exp_nms_scores.shape)
self.assertAllEqual(nmsed_classes.shape.as_list(),
exp_nms_classes.shape)
self.assertEqual(num_detections.shape.as_list(), [2])
return nmsed_boxes, nmsed_scores, nmsed_classes, num_detections
(nmsed_boxes, nmsed_scores, nmsed_classes,
num_detections) = self.execute_cpu(graph_fn, [boxes, scores, clip_window])
self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nmsed_classes, exp_nms_classes)
self.assertAllClose(num_detections, [2, 2])
def test_batch_multiclass_nms_with_per_image_clip_window(self):
boxes = tf.constant([[[[0, 0, 1, 1], [0, 0, 4, 5]],
[[0, 0.1, 1, 1.1], [0, 0.1, 2, 1.1]],
[[0, -0.1, 1, 0.9], [0, -0.1, 1, 0.9]],
[[0, 10, 1, 11], [0, 10, 1, 11]]],
[[[0, 10.1, 1, 11.1], [0, 10.1, 1, 11.1]],
[[0, 100, 1, 101], [0, 100, 1, 101]],
[[0, 1000, 1, 1002], [0, 999, 2, 1004]],
[[0, 1000, 1, 1002.1], [0, 999, 2, 1002.7]]]],
tf.float32)
scores = tf.constant([[[.9, 0.01], [.75, 0.05],
[.6, 0.01], [.95, 0]],
[[.5, 0.01], [.3, 0.01],
[.01, .85], [.01, .5]]])
clip_window = tf.constant([[0., 0., 5., 5.],
[0., 0., 200., 200.]])
boxes = np.array([[[[0, 0, 1, 1], [0, 0, 4, 5]],
[[0, 0.1, 1, 1.1], [0, 0.1, 2, 1.1]],
[[0, -0.1, 1, 0.9], [0, -0.1, 1, 0.9]],
[[0, 10, 1, 11], [0, 10, 1, 11]]],
[[[0, 10.1, 1, 11.1], [0, 10.1, 1, 11.1]],
[[0, 100, 1, 101], [0, 100, 1, 101]],
[[0, 1000, 1, 1002], [0, 999, 2, 1004]],
[[0, 1000, 1, 1002.1], [0, 999, 2, 1002.7]]]],
np.float32)
scores = np.array([[[.9, 0.01], [.75, 0.05],
[.6, 0.01], [.95, 0]],
[[.5, 0.01], [.3, 0.01],
[.01, .85], [.01, .5]]], np.float32)
clip_window = np.array([[0., 0., 5., 5.],
[0., 0., 200., 200.]], np.float32)
score_thresh = 0.1
iou_thresh = .5
max_output_size = 4
......@@ -238,56 +229,55 @@ class BatchMulticlassNonMaxSuppressionTest(test_case.TestCase,
exp_nms_classes = np.array([[0, 0, 0, 0],
[0, 0, 0, 0]])
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
nmsed_additional_fields, num_detections
) = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size,
clip_window=clip_window)
self.assertIsNone(nmsed_masks)
self.assertIsNone(nmsed_additional_fields)
# Check static shapes
self.assertAllEqual(nmsed_boxes.shape.as_list(),
exp_nms_corners.shape)
self.assertAllEqual(nmsed_scores.shape.as_list(),
exp_nms_scores.shape)
self.assertAllEqual(nmsed_classes.shape.as_list(),
exp_nms_classes.shape)
self.assertEqual(num_detections.shape.as_list(), [2])
with self.test_session() as sess:
(nmsed_boxes, nmsed_scores, nmsed_classes,
num_detections) = sess.run([nmsed_boxes, nmsed_scores, nmsed_classes,
num_detections])
self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nmsed_classes, exp_nms_classes)
self.assertAllClose(num_detections, [1, 2])
def graph_fn(boxes, scores, clip_window):
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
nmsed_additional_fields, num_detections
) = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size,
clip_window=clip_window)
self.assertIsNone(nmsed_masks)
self.assertIsNone(nmsed_additional_fields)
# Check static shapes
self.assertAllEqual(nmsed_boxes.shape.as_list(),
exp_nms_corners.shape)
self.assertAllEqual(nmsed_scores.shape.as_list(),
exp_nms_scores.shape)
self.assertAllEqual(nmsed_classes.shape.as_list(),
exp_nms_classes.shape)
self.assertEqual(num_detections.shape.as_list(), [2])
return nmsed_boxes, nmsed_scores, nmsed_classes, num_detections
(nmsed_boxes, nmsed_scores, nmsed_classes,
num_detections) = self.execute_cpu(graph_fn, [boxes, scores, clip_window])
self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nmsed_classes, exp_nms_classes)
self.assertAllClose(num_detections, [1, 2])
def test_batch_multiclass_nms_with_masks(self):
boxes = tf.constant([[[[0, 0, 1, 1], [0, 0, 4, 5]],
[[0, 0.1, 1, 1.1], [0, 0.1, 2, 1.1]],
[[0, -0.1, 1, 0.9], [0, -0.1, 1, 0.9]],
[[0, 10, 1, 11], [0, 10, 1, 11]]],
[[[0, 10.1, 1, 11.1], [0, 10.1, 1, 11.1]],
[[0, 100, 1, 101], [0, 100, 1, 101]],
[[0, 1000, 1, 1002], [0, 999, 2, 1004]],
[[0, 1000, 1, 1002.1], [0, 999, 2, 1002.7]]]],
tf.float32)
scores = tf.constant([[[.9, 0.01], [.75, 0.05],
[.6, 0.01], [.95, 0]],
[[.5, 0.01], [.3, 0.01],
[.01, .85], [.01, .5]]])
masks = tf.constant([[[[[0, 1], [2, 3]], [[1, 2], [3, 4]]],
[[[2, 3], [4, 5]], [[3, 4], [5, 6]]],
[[[4, 5], [6, 7]], [[5, 6], [7, 8]]],
[[[6, 7], [8, 9]], [[7, 8], [9, 10]]]],
[[[[8, 9], [10, 11]], [[9, 10], [11, 12]]],
[[[10, 11], [12, 13]], [[11, 12], [13, 14]]],
[[[12, 13], [14, 15]], [[13, 14], [15, 16]]],
[[[14, 15], [16, 17]], [[15, 16], [17, 18]]]]],
tf.float32)
boxes = np.array([[[[0, 0, 1, 1], [0, 0, 4, 5]],
[[0, 0.1, 1, 1.1], [0, 0.1, 2, 1.1]],
[[0, -0.1, 1, 0.9], [0, -0.1, 1, 0.9]],
[[0, 10, 1, 11], [0, 10, 1, 11]]],
[[[0, 10.1, 1, 11.1], [0, 10.1, 1, 11.1]],
[[0, 100, 1, 101], [0, 100, 1, 101]],
[[0, 1000, 1, 1002], [0, 999, 2, 1004]],
[[0, 1000, 1, 1002.1], [0, 999, 2, 1002.7]]]],
np.float32)
scores = np.array([[[.9, 0.01], [.75, 0.05],
[.6, 0.01], [.95, 0]],
[[.5, 0.01], [.3, 0.01],
[.01, .85], [.01, .5]]], np.float32)
masks = np.array([[[[[0, 1], [2, 3]], [[1, 2], [3, 4]]],
[[[2, 3], [4, 5]], [[3, 4], [5, 6]]],
[[[4, 5], [6, 7]], [[5, 6], [7, 8]]],
[[[6, 7], [8, 9]], [[7, 8], [9, 10]]]],
[[[[8, 9], [10, 11]], [[9, 10], [11, 12]]],
[[[10, 11], [12, 13]], [[11, 12], [13, 14]]],
[[[12, 13], [14, 15]], [[13, 14], [15, 16]]],
[[[14, 15], [16, 17]], [[15, 16], [17, 18]]]]],
np.float32)
score_thresh = 0.1
iou_thresh = .5
max_output_size = 4
......@@ -313,61 +303,58 @@ class BatchMulticlassNonMaxSuppressionTest(test_case.TestCase,
[[10, 11], [12, 13]],
[[0, 0], [0, 0]]]])
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
nmsed_additional_fields, num_detections
) = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size,
masks=masks)
self.assertIsNone(nmsed_additional_fields)
# Check static shapes
self.assertAllEqual(nmsed_boxes.shape.as_list(), exp_nms_corners.shape)
self.assertAllEqual(nmsed_scores.shape.as_list(), exp_nms_scores.shape)
self.assertAllEqual(nmsed_classes.shape.as_list(), exp_nms_classes.shape)
self.assertAllEqual(nmsed_masks.shape.as_list(), exp_nms_masks.shape)
self.assertEqual(num_detections.shape.as_list(), [2])
with self.test_session() as sess:
def graph_fn(boxes, scores, masks):
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
num_detections) = sess.run([nmsed_boxes, nmsed_scores, nmsed_classes,
nmsed_masks, num_detections])
nmsed_additional_fields, num_detections
) = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size,
masks=masks)
self.assertIsNone(nmsed_additional_fields)
# Check static shapes
self.assertAllEqual(nmsed_boxes.shape.as_list(), exp_nms_corners.shape)
self.assertAllEqual(nmsed_scores.shape.as_list(), exp_nms_scores.shape)
self.assertAllEqual(nmsed_classes.shape.as_list(), exp_nms_classes.shape)
self.assertAllEqual(nmsed_masks.shape.as_list(), exp_nms_masks.shape)
self.assertEqual(num_detections.shape.as_list(), [2])
return (nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
num_detections)
self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nmsed_classes, exp_nms_classes)
self.assertAllClose(num_detections, [2, 3])
self.assertAllClose(nmsed_masks, exp_nms_masks)
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
num_detections) = self.execute_cpu(graph_fn, [boxes, scores, masks])
self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nmsed_classes, exp_nms_classes)
self.assertAllClose(num_detections, [2, 3])
self.assertAllClose(nmsed_masks, exp_nms_masks)
def test_batch_multiclass_nms_with_additional_fields(self):
boxes = tf.constant([[[[0, 0, 1, 1], [0, 0, 4, 5]],
[[0, 0.1, 1, 1.1], [0, 0.1, 2, 1.1]],
[[0, -0.1, 1, 0.9], [0, -0.1, 1, 0.9]],
[[0, 10, 1, 11], [0, 10, 1, 11]]],
[[[0, 10.1, 1, 11.1], [0, 10.1, 1, 11.1]],
[[0, 100, 1, 101], [0, 100, 1, 101]],
[[0, 1000, 1, 1002], [0, 999, 2, 1004]],
[[0, 1000, 1, 1002.1], [0, 999, 2, 1002.7]]]],
tf.float32)
scores = tf.constant([[[.9, 0.01], [.75, 0.05],
[.6, 0.01], [.95, 0]],
[[.5, 0.01], [.3, 0.01],
[.01, .85], [.01, .5]]])
additional_fields = {
'keypoints': tf.constant(
[[[[6, 7], [8, 9]],
[[0, 1], [2, 3]],
[[0, 0], [0, 0]],
[[0, 0], [0, 0]]],
[[[13, 14], [15, 16]],
[[8, 9], [10, 11]],
[[10, 11], [12, 13]],
[[0, 0], [0, 0]]]],
tf.float32)
}
additional_fields['size'] = tf.constant(
boxes = np.array([[[[0, 0, 1, 1], [0, 0, 4, 5]],
[[0, 0.1, 1, 1.1], [0, 0.1, 2, 1.1]],
[[0, -0.1, 1, 0.9], [0, -0.1, 1, 0.9]],
[[0, 10, 1, 11], [0, 10, 1, 11]]],
[[[0, 10.1, 1, 11.1], [0, 10.1, 1, 11.1]],
[[0, 100, 1, 101], [0, 100, 1, 101]],
[[0, 1000, 1, 1002], [0, 999, 2, 1004]],
[[0, 1000, 1, 1002.1], [0, 999, 2, 1002.7]]]],
np.float32)
scores = np.array([[[.9, 0.01], [.75, 0.05],
[.6, 0.01], [.95, 0]],
[[.5, 0.01], [.3, 0.01],
[.01, .85], [.01, .5]]], np.float32)
keypoints = np.array(
[[[[6, 7], [8, 9]],
[[0, 1], [2, 3]],
[[0, 0], [0, 0]],
[[0, 0], [0, 0]]],
[[[13, 14], [15, 16]],
[[8, 9], [10, 11]],
[[10, 11], [12, 13]],
[[0, 0], [0, 0]]]],
np.float32)
size = np.array(
[[[[6], [8]], [[0], [2]], [[0], [0]], [[0], [0]]],
[[[13], [15]], [[8], [10]], [[10], [12]], [[0], [0]]]], tf.float32)
[[[13], [15]], [[8], [10]], [[10], [12]], [[0], [0]]]], np.float32)
score_thresh = 0.1
iou_thresh = .5
max_output_size = 4
......@@ -399,43 +386,43 @@ class BatchMulticlassNonMaxSuppressionTest(test_case.TestCase,
[[[10], [12]], [[13], [15]],
[[8], [10]], [[0], [0]]]])
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
nmsed_additional_fields, num_detections
) = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size,
additional_fields=additional_fields)
self.assertIsNone(nmsed_masks)
# Check static shapes
self.assertAllEqual(nmsed_boxes.shape.as_list(), exp_nms_corners.shape)
self.assertAllEqual(nmsed_scores.shape.as_list(), exp_nms_scores.shape)
self.assertAllEqual(nmsed_classes.shape.as_list(), exp_nms_classes.shape)
self.assertEqual(len(nmsed_additional_fields),
len(exp_nms_additional_fields))
for key in exp_nms_additional_fields:
self.assertAllEqual(nmsed_additional_fields[key].shape.as_list(),
exp_nms_additional_fields[key].shape)
self.assertEqual(num_detections.shape.as_list(), [2])
with self.test_session() as sess:
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_additional_fields,
num_detections) = sess.run([nmsed_boxes, nmsed_scores, nmsed_classes,
nmsed_additional_fields, num_detections])
self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nmsed_classes, exp_nms_classes)
def graph_fn(boxes, scores, keypoints, size):
additional_fields = {'keypoints': keypoints, 'size': size}
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
nmsed_additional_fields, num_detections
) = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size,
additional_fields=additional_fields)
self.assertIsNone(nmsed_masks)
# Check static shapes
self.assertAllEqual(nmsed_boxes.shape.as_list(), exp_nms_corners.shape)
self.assertAllEqual(nmsed_scores.shape.as_list(), exp_nms_scores.shape)
self.assertAllEqual(nmsed_classes.shape.as_list(), exp_nms_classes.shape)
self.assertEqual(len(nmsed_additional_fields),
len(exp_nms_additional_fields))
for key in exp_nms_additional_fields:
self.assertAllClose(nmsed_additional_fields[key],
exp_nms_additional_fields[key])
self.assertAllClose(num_detections, [2, 3])
def test_batch_multiclass_nms_with_dynamic_batch_size(self):
boxes_placeholder = tf.placeholder(tf.float32, shape=(None, None, 2, 4))
scores_placeholder = tf.placeholder(tf.float32, shape=(None, None, 2))
masks_placeholder = tf.placeholder(tf.float32, shape=(None, None, 2, 2, 2))
self.assertAllEqual(nmsed_additional_fields[key].shape.as_list(),
exp_nms_additional_fields[key].shape)
self.assertEqual(num_detections.shape.as_list(), [2])
return (nmsed_boxes, nmsed_scores, nmsed_classes,
nmsed_additional_fields['keypoints'],
nmsed_additional_fields['size'],
num_detections)
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_keypoints, nmsed_size,
num_detections) = self.execute_cpu(graph_fn, [boxes, scores, keypoints,
size])
self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nmsed_classes, exp_nms_classes)
self.assertAllClose(nmsed_keypoints,
exp_nms_additional_fields['keypoints'])
self.assertAllClose(nmsed_size,
exp_nms_additional_fields['size'])
self.assertAllClose(num_detections, [2, 3])
def test_batch_multiclass_nms_with_masks_and_num_valid_boxes(self):
boxes = np.array([[[[0, 0, 1, 1], [0, 0, 4, 5]],
[[0, 0.1, 1, 1.1], [0, 0.1, 2, 1.1]],
[[0, -0.1, 1, 0.9], [0, -0.1, 1, 0.9]],
......@@ -443,11 +430,12 @@ class BatchMulticlassNonMaxSuppressionTest(test_case.TestCase,
[[[0, 10.1, 1, 11.1], [0, 10.1, 1, 11.1]],
[[0, 100, 1, 101], [0, 100, 1, 101]],
[[0, 1000, 1, 1002], [0, 999, 2, 1004]],
[[0, 1000, 1, 1002.1], [0, 999, 2, 1002.7]]]])
[[0, 1000, 1, 1002.1], [0, 999, 2, 1002.7]]]],
np.float32)
scores = np.array([[[.9, 0.01], [.75, 0.05],
[.6, 0.01], [.95, 0]],
[[.5, 0.01], [.3, 0.01],
[.01, .85], [.01, .5]]])
[.01, .85], [.01, .5]]], np.float32)
masks = np.array([[[[[0, 1], [2, 3]], [[1, 2], [3, 4]]],
[[[2, 3], [4, 5]], [[3, 4], [5, 6]]],
[[[4, 5], [6, 7]], [[5, 6], [7, 8]]],
......@@ -455,84 +443,9 @@ class BatchMulticlassNonMaxSuppressionTest(test_case.TestCase,
[[[[8, 9], [10, 11]], [[9, 10], [11, 12]]],
[[[10, 11], [12, 13]], [[11, 12], [13, 14]]],
[[[12, 13], [14, 15]], [[13, 14], [15, 16]]],
[[[14, 15], [16, 17]], [[15, 16], [17, 18]]]]])
score_thresh = 0.1
iou_thresh = .5
max_output_size = 4
exp_nms_corners = np.array([[[0, 10, 1, 11],
[0, 0, 1, 1],
[0, 0, 0, 0],
[0, 0, 0, 0]],
[[0, 999, 2, 1004],
[0, 10.1, 1, 11.1],
[0, 100, 1, 101],
[0, 0, 0, 0]]])
exp_nms_scores = np.array([[.95, .9, 0, 0],
[.85, .5, .3, 0]])
exp_nms_classes = np.array([[0, 0, 0, 0],
[1, 0, 0, 0]])
exp_nms_masks = np.array([[[[6, 7], [8, 9]],
[[0, 1], [2, 3]],
[[0, 0], [0, 0]],
[[0, 0], [0, 0]]],
[[[13, 14], [15, 16]],
[[8, 9], [10, 11]],
[[10, 11], [12, 13]],
[[0, 0], [0, 0]]]])
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
nmsed_additional_fields, num_detections
) = post_processing.batch_multiclass_non_max_suppression(
boxes_placeholder, scores_placeholder, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size,
masks=masks_placeholder)
self.assertIsNone(nmsed_additional_fields)
# Check static shapes
self.assertAllEqual(nmsed_boxes.shape.as_list(), [None, 4, 4])
self.assertAllEqual(nmsed_scores.shape.as_list(), [None, 4])
self.assertAllEqual(nmsed_classes.shape.as_list(), [None, 4])
self.assertAllEqual(nmsed_masks.shape.as_list(), [None, 4, 2, 2])
self.assertEqual(num_detections.shape.as_list(), [None])
with self.test_session() as sess:
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
num_detections) = sess.run([nmsed_boxes, nmsed_scores, nmsed_classes,
nmsed_masks, num_detections],
feed_dict={boxes_placeholder: boxes,
scores_placeholder: scores,
masks_placeholder: masks})
self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nmsed_classes, exp_nms_classes)
self.assertAllClose(num_detections, [2, 3])
self.assertAllClose(nmsed_masks, exp_nms_masks)
def test_batch_multiclass_nms_with_masks_and_num_valid_boxes(self):
boxes = tf.constant([[[[0, 0, 1, 1], [0, 0, 4, 5]],
[[0, 0.1, 1, 1.1], [0, 0.1, 2, 1.1]],
[[0, -0.1, 1, 0.9], [0, -0.1, 1, 0.9]],
[[0, 10, 1, 11], [0, 10, 1, 11]]],
[[[0, 10.1, 1, 11.1], [0, 10.1, 1, 11.1]],
[[0, 100, 1, 101], [0, 100, 1, 101]],
[[0, 1000, 1, 1002], [0, 999, 2, 1004]],
[[0, 1000, 1, 1002.1], [0, 999, 2, 1002.7]]]],
tf.float32)
scores = tf.constant([[[.9, 0.01], [.75, 0.05],
[.6, 0.01], [.95, 0]],
[[.5, 0.01], [.3, 0.01],
[.01, .85], [.01, .5]]])
masks = tf.constant([[[[[0, 1], [2, 3]], [[1, 2], [3, 4]]],
[[[2, 3], [4, 5]], [[3, 4], [5, 6]]],
[[[4, 5], [6, 7]], [[5, 6], [7, 8]]],
[[[6, 7], [8, 9]], [[7, 8], [9, 10]]]],
[[[[8, 9], [10, 11]], [[9, 10], [11, 12]]],
[[[10, 11], [12, 13]], [[11, 12], [13, 14]]],
[[[12, 13], [14, 15]], [[13, 14], [15, 16]]],
[[[14, 15], [16, 17]], [[15, 16], [17, 18]]]]],
tf.float32)
num_valid_boxes = tf.constant([1, 1], tf.int32)
[[[14, 15], [16, 17]], [[15, 16], [17, 18]]]]],
np.float32)
num_valid_boxes = np.array([1, 1], np.int32)
score_thresh = 0.1
iou_thresh = .5
max_output_size = 4
......@@ -558,58 +471,56 @@ class BatchMulticlassNonMaxSuppressionTest(test_case.TestCase,
[[0, 0], [0, 0]],
[[0, 0], [0, 0]]]]
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
nmsed_additional_fields, num_detections
) = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size,
num_valid_boxes=num_valid_boxes, masks=masks)
self.assertIsNone(nmsed_additional_fields)
with self.test_session() as sess:
def graph_fn(boxes, scores, masks, num_valid_boxes):
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
num_detections) = sess.run([nmsed_boxes, nmsed_scores, nmsed_classes,
nmsed_masks, num_detections])
self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nmsed_classes, exp_nms_classes)
self.assertAllClose(num_detections, [1, 1])
self.assertAllClose(nmsed_masks, exp_nms_masks)
nmsed_additional_fields, num_detections
) = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size,
masks=masks, num_valid_boxes=num_valid_boxes)
self.assertIsNone(nmsed_additional_fields)
return (nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
num_detections)
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
num_detections) = self.execute_cpu(graph_fn, [boxes, scores, masks,
num_valid_boxes])
self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nmsed_classes, exp_nms_classes)
self.assertAllClose(num_detections, [1, 1])
self.assertAllClose(nmsed_masks, exp_nms_masks)
def test_batch_multiclass_nms_with_additional_fields_and_num_valid_boxes(
self):
boxes = tf.constant([[[[0, 0, 1, 1], [0, 0, 4, 5]],
[[0, 0.1, 1, 1.1], [0, 0.1, 2, 1.1]],
[[0, -0.1, 1, 0.9], [0, -0.1, 1, 0.9]],
[[0, 10, 1, 11], [0, 10, 1, 11]]],
[[[0, 10.1, 1, 11.1], [0, 10.1, 1, 11.1]],
[[0, 100, 1, 101], [0, 100, 1, 101]],
[[0, 1000, 1, 1002], [0, 999, 2, 1004]],
[[0, 1000, 1, 1002.1], [0, 999, 2, 1002.7]]]],
tf.float32)
scores = tf.constant([[[.9, 0.01], [.75, 0.05],
[.6, 0.01], [.95, 0]],
[[.5, 0.01], [.3, 0.01],
[.01, .85], [.01, .5]]])
additional_fields = {
'keypoints': tf.constant(
[[[[6, 7], [8, 9]],
[[0, 1], [2, 3]],
[[0, 0], [0, 0]],
[[0, 0], [0, 0]]],
[[[13, 14], [15, 16]],
[[8, 9], [10, 11]],
[[10, 11], [12, 13]],
[[0, 0], [0, 0]]]],
tf.float32)
}
additional_fields['size'] = tf.constant(
boxes = np.array([[[[0, 0, 1, 1], [0, 0, 4, 5]],
[[0, 0.1, 1, 1.1], [0, 0.1, 2, 1.1]],
[[0, -0.1, 1, 0.9], [0, -0.1, 1, 0.9]],
[[0, 10, 1, 11], [0, 10, 1, 11]]],
[[[0, 10.1, 1, 11.1], [0, 10.1, 1, 11.1]],
[[0, 100, 1, 101], [0, 100, 1, 101]],
[[0, 1000, 1, 1002], [0, 999, 2, 1004]],
[[0, 1000, 1, 1002.1], [0, 999, 2, 1002.7]]]],
np.float32)
scores = np.array([[[.9, 0.01], [.75, 0.05],
[.6, 0.01], [.95, 0]],
[[.5, 0.01], [.3, 0.01],
[.01, .85], [.01, .5]]], np.float32)
keypoints = np.array(
[[[[6, 7], [8, 9]],
[[0, 1], [2, 3]],
[[0, 0], [0, 0]],
[[0, 0], [0, 0]]],
[[[13, 14], [15, 16]],
[[8, 9], [10, 11]],
[[10, 11], [12, 13]],
[[0, 0], [0, 0]]]],
np.float32)
size = np.array(
[[[[7], [9]], [[1], [3]], [[0], [0]], [[0], [0]]],
[[[14], [16]], [[9], [11]], [[11], [13]], [[0], [0]]]], tf.float32)
[[[14], [16]], [[9], [11]], [[11], [13]], [[0], [0]]]], np.float32)
num_valid_boxes = tf.constant([1, 1], tf.int32)
num_valid_boxes = np.array([1, 1], np.int32)
score_thresh = 0.1
iou_thresh = .5
max_output_size = 4
......@@ -641,45 +552,48 @@ class BatchMulticlassNonMaxSuppressionTest(test_case.TestCase,
[[0], [0]], [[0], [0]]],
[[[14], [16]], [[0], [0]],
[[0], [0]], [[0], [0]]]])
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
nmsed_additional_fields, num_detections
) = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size,
num_valid_boxes=num_valid_boxes,
additional_fields=additional_fields)
self.assertIsNone(nmsed_masks)
with self.test_session() as sess:
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_additional_fields,
num_detections) = sess.run([nmsed_boxes, nmsed_scores, nmsed_classes,
nmsed_additional_fields, num_detections])
self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nmsed_classes, exp_nms_classes)
for key in exp_nms_additional_fields:
self.assertAllClose(nmsed_additional_fields[key],
exp_nms_additional_fields[key])
self.assertAllClose(num_detections, [1, 1])
def graph_fn(boxes, scores, keypoints, size, num_valid_boxes):
additional_fields = {'keypoints': keypoints, 'size': size}
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
nmsed_additional_fields, num_detections
) = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size,
num_valid_boxes=num_valid_boxes,
additional_fields=additional_fields)
self.assertIsNone(nmsed_masks)
return (nmsed_boxes, nmsed_scores, nmsed_classes,
nmsed_additional_fields['keypoints'],
nmsed_additional_fields['size'], num_detections)
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_keypoints, nmsed_size,
num_detections) = self.execute_cpu(graph_fn, [boxes, scores, keypoints,
size, num_valid_boxes])
self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nmsed_classes, exp_nms_classes)
self.assertAllClose(nmsed_keypoints,
exp_nms_additional_fields['keypoints'])
self.assertAllClose(nmsed_size,
exp_nms_additional_fields['size'])
self.assertAllClose(num_detections, [1, 1])
def test_combined_nms_with_batch_size_2(self):
"""Test use_combined_nms."""
boxes = tf.constant([[[[0, 0, 0.1, 0.1], [0, 0, 0.1, 0.1]],
[[0, 0.01, 1, 0.11], [0, 0.6, 0.1, 0.7]],
[[0, -0.01, 0.1, 0.09], [0, -0.1, 0.1, 0.09]],
[[0, 0.11, 0.1, 0.2], [0, 0.11, 0.1, 0.2]]],
[[[0, 0, 0.2, 0.2], [0, 0, 0.2, 0.2]],
[[0, 0.02, 0.2, 0.22], [0, 0.02, 0.2, 0.22]],
[[0, -0.02, 0.2, 0.19], [0, -0.02, 0.2, 0.19]],
[[0, 0.21, 0.2, 0.3], [0, 0.21, 0.2, 0.3]]]],
tf.float32)
scores = tf.constant([[[.1, 0.9], [.75, 0.8],
[.6, 0.3], [0.95, 0.1]],
[[.1, 0.9], [.75, 0.8],
[.6, .3], [.95, .1]]])
boxes = np.array([[[[0, 0, 0.1, 0.1], [0, 0, 0.1, 0.1]],
[[0, 0.01, 1, 0.11], [0, 0.6, 0.1, 0.7]],
[[0, -0.01, 0.1, 0.09], [0, -0.1, 0.1, 0.09]],
[[0, 0.11, 0.1, 0.2], [0, 0.11, 0.1, 0.2]]],
[[[0, 0, 0.2, 0.2], [0, 0, 0.2, 0.2]],
[[0, 0.02, 0.2, 0.22], [0, 0.02, 0.2, 0.22]],
[[0, -0.02, 0.2, 0.19], [0, -0.02, 0.2, 0.19]],
[[0, 0.21, 0.2, 0.3], [0, 0.21, 0.2, 0.3]]]],
np.float32)
scores = np.array([[[.1, 0.9], [.75, 0.8],
[.6, 0.3], [0.95, 0.1]],
[[.1, 0.9], [.75, 0.8],
[.6, .3], [.95, .1]]], np.float32)
score_thresh = 0.1
iou_thresh = .5
max_output_size = 3
......@@ -695,27 +609,78 @@ class BatchMulticlassNonMaxSuppressionTest(test_case.TestCase,
exp_nms_classes = np.array([[0, 1, 1],
[0, 1, 0]])
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
nmsed_additional_fields, num_detections
) = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size,
use_static_shapes=True,
use_combined_nms=True)
self.assertIsNone(nmsed_masks)
self.assertIsNone(nmsed_additional_fields)
with self.test_session() as sess:
(nmsed_boxes, nmsed_scores, nmsed_classes,
num_detections) = sess.run([nmsed_boxes, nmsed_scores, nmsed_classes,
num_detections])
self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nmsed_classes, exp_nms_classes)
self.assertListEqual(num_detections.tolist(), [3, 3])
# TODO(bhattad): Remove conditional after CMLE moves to TF 1.9
def graph_fn(boxes, scores):
(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
nmsed_additional_fields, num_detections
) = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, max_total_size=max_output_size,
use_static_shapes=True,
use_combined_nms=True)
self.assertIsNone(nmsed_masks)
self.assertIsNone(nmsed_additional_fields)
return (nmsed_boxes, nmsed_scores, nmsed_classes, num_detections)
(nmsed_boxes, nmsed_scores, nmsed_classes,
num_detections) = self.execute_cpu(graph_fn, [boxes, scores])
self.assertAllClose(nmsed_boxes, exp_nms_corners)
self.assertAllClose(nmsed_scores, exp_nms_scores)
self.assertAllClose(nmsed_classes, exp_nms_classes)
self.assertListEqual(num_detections.tolist(), [3, 3])
def test_batch_multiclass_nms_with_use_static_shapes(self):
boxes = np.array([[[[0, 0, 1, 1], [0, 0, 4, 5]],
[[0, 0.1, 1, 1.1], [0, 0.1, 2, 1.1]],
[[0, -0.1, 1, 0.9], [0, -0.1, 1, 0.9]],
[[0, 10, 1, 11], [0, 10, 1, 11]]],
[[[0, 10.1, 1, 11.1], [0, 10.1, 1, 11.1]],
[[0, 100, 1, 101], [0, 100, 1, 101]],
[[0, 1000, 1, 1002], [0, 999, 2, 1004]],
[[0, 1000, 1, 1002.1], [0, 999, 2, 1002.7]]]],
np.float32)
scores = np.array([[[.9, 0.01], [.75, 0.05],
[.6, 0.01], [.95, 0]],
[[.5, 0.01], [.3, 0.01],
[.01, .85], [.01, .5]]],
np.float32)
clip_window = np.array([[0., 0., 5., 5.],
[0., 0., 200., 200.]],
np.float32)
score_thresh = 0.1
iou_thresh = .5
max_output_size = 4
exp_nms_corners = np.array([[[0, 0, 1, 1],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0]],
[[0, 10.1, 1, 11.1],
[0, 100, 1, 101],
[0, 0, 0, 0],
[0, 0, 0, 0]]])
exp_nms_scores = np.array([[.9, 0., 0., 0.],
[.5, .3, 0, 0]])
exp_nms_classes = np.array([[0, 0, 0, 0],
[0, 0, 0, 0]])
def graph_fn(boxes, scores, clip_window):
(nmsed_boxes, nmsed_scores, nmsed_classes, _, _, num_detections
) = post_processing.batch_multiclass_non_max_suppression(
boxes, scores, score_thresh, iou_thresh,
max_size_per_class=max_output_size, clip_window=clip_window,
use_static_shapes=True)
return nmsed_boxes, nmsed_scores, nmsed_classes, num_detections
(nmsed_boxes, nmsed_scores, nmsed_classes,
num_detections) = self.execute(graph_fn, [boxes, scores, clip_window])
for i in range(len(num_detections)):
self.assertAllClose(nmsed_boxes[i, 0:num_detections[i]],
exp_nms_corners[i, 0:num_detections[i]])
self.assertAllClose(nmsed_scores[i, 0:num_detections[i]],
exp_nms_scores[i, 0:num_detections[i]])
self.assertAllClose(nmsed_classes[i, 0:num_detections[i]],
exp_nms_classes[i, 0:num_detections[i]])
self.assertAllClose(num_detections, [1, 2])
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment