Unverified Commit 420a7253 authored by pkulzc, committed by GitHub

Refactor tests for Object Detection API. (#8688)

Internal changes

--

PiperOrigin-RevId: 316837667
parent d0ef3913
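
The dominant pattern in the hunks below: session-scoped test bodies (`with self.test_session() as sess: ... sess.run(...)`) become nullary `graph_fn`s executed through `test_case.TestCase.execute`, which runs the same code under TF1 sessions and TF2 eager execution. A minimal sketch of that pattern, assuming the `object_detection.utils.test_case` API as used throughout this commit (the test class and values here are hypothetical):

import tensorflow.compat.v1 as tf

from object_detection.utils import test_case


class ExampleRefactoredTest(test_case.TestCase):  # hypothetical test class

  def test_add(self):
    def graph_fn():
      # Build ops here; execute() hides the session-vs-eager difference.
      return tf.constant(1.0) + tf.constant(2.0)
    result = self.execute(graph_fn, [])  # second argument: list of numpy inputs
    self.assertAllClose(result, 3.0)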
@@ -16,6 +16,7 @@
 """Tests for box_predictor_builder."""
+import unittest
 import mock
 import tensorflow.compat.v1 as tf
@@ -25,8 +26,10 @@ from object_detection.builders import hyperparams_builder
 from object_detection.predictors import mask_rcnn_box_predictor
 from object_detection.protos import box_predictor_pb2
 from object_detection.protos import hyperparams_pb2
+from object_detection.utils import tf_version


+@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only Tests.')
 class ConvolutionalBoxPredictorBuilderTest(tf.test.TestCase):

   def test_box_predictor_calls_conv_argscope_fn(self):
@@ -161,6 +164,7 @@ class ConvolutionalBoxPredictorBuilderTest(tf.test.TestCase):
     self.assertFalse(class_head._use_depthwise)


+@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only Tests.')
 class WeightSharedConvolutionalBoxPredictorBuilderTest(tf.test.TestCase):

   def test_box_predictor_calls_conv_argscope_fn(self):
@@ -357,6 +361,7 @@ class WeightSharedConvolutionalBoxPredictorBuilderTest(tf.test.TestCase):

+@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only Tests.')
 class MaskRCNNBoxPredictorBuilderTest(tf.test.TestCase):

   def test_box_predictor_builder_calls_fc_argscope_fn(self):
@@ -537,6 +542,7 @@ class MaskRCNNBoxPredictorBuilderTest(tf.test.TestCase):
                      ._convolve_then_upsample)


+@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only Tests.')
 class RfcnBoxPredictorBuilderTest(tf.test.TestCase):

   def test_box_predictor_calls_fc_argscope_fn(self):
...
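
Each TF1-only test class above gains the same guard. A standalone sketch of the gating (hypothetical test class, real `tf_version` helpers):

import unittest

import tensorflow.compat.v1 as tf

from object_detection.utils import tf_version


@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only Tests.')
class SomeTf1OnlyTest(tf.test.TestCase):  # hypothetical

  def test_contrib_dependent_behavior(self):
    # This body runs only when tf_version.is_tf2() is False at import time.
    self.assertTrue(tf_version.is_tf1())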
@@ -25,31 +25,34 @@ from six.moves import zip
 import tensorflow.compat.v1 as tf
 from object_detection.builders import calibration_builder
 from object_detection.protos import calibration_pb2
+from object_detection.utils import test_case


-class CalibrationBuilderTest(tf.test.TestCase):
+class CalibrationBuilderTest(test_case.TestCase):

   def test_tf_linear_interp1d_map(self):
     """Tests TF linear interpolation mapping to a single number."""
-    with self.test_session() as sess:
+    def graph_fn():
       tf_x = tf.constant([0., 0.5, 1.])
       tf_y = tf.constant([0.5, 0.5, 0.5])
       new_x = tf.constant([0., 0.25, 0.5, 0.75, 1.])
       tf_map_outputs = calibration_builder._tf_linear_interp1d(
           new_x, tf_x, tf_y)
-      tf_map_outputs_np = sess.run([tf_map_outputs])
-      self.assertAllClose(tf_map_outputs_np, [[0.5, 0.5, 0.5, 0.5, 0.5]])
+      return tf_map_outputs
+    tf_map_outputs_np = self.execute(graph_fn, [])
+    self.assertAllClose(tf_map_outputs_np, [0.5, 0.5, 0.5, 0.5, 0.5])

   def test_tf_linear_interp1d_interpolate(self):
     """Tests TF 1d linear interpolation not mapping to a single number."""
-    with self.test_session() as sess:
+    def graph_fn():
       tf_x = tf.constant([0., 0.5, 1.])
       tf_y = tf.constant([0.6, 0.7, 1.0])
       new_x = tf.constant([0., 0.25, 0.5, 0.75, 1.])
       tf_interpolate_outputs = calibration_builder._tf_linear_interp1d(
           new_x, tf_x, tf_y)
-      tf_interpolate_outputs_np = sess.run([tf_interpolate_outputs])
-      self.assertAllClose(tf_interpolate_outputs_np, [[0.6, 0.65, 0.7, 0.85, 1.]])
+      return tf_interpolate_outputs
+    tf_interpolate_outputs_np = self.execute(graph_fn, [])
+    self.assertAllClose(tf_interpolate_outputs_np, [0.6, 0.65, 0.7, 0.85, 1.])

   @staticmethod
   def _get_scipy_interp1d(new_x, x, y):
@@ -59,12 +62,13 @@ class CalibrationBuilderTest(tf.test.TestCase):
   def _get_tf_interp1d(self, new_x, x, y):
     """Helper performing 1d linear interpolation using Tensorflow."""
-    with self.test_session() as sess:
+    def graph_fn():
       tf_interp_outputs = calibration_builder._tf_linear_interp1d(
           tf.convert_to_tensor(new_x, dtype=tf.float32),
           tf.convert_to_tensor(x, dtype=tf.float32),
           tf.convert_to_tensor(y, dtype=tf.float32))
-      np_tf_interp_outputs = sess.run(tf_interp_outputs)
+      return tf_interp_outputs
+    np_tf_interp_outputs = self.execute(graph_fn, [])
     return np_tf_interp_outputs

   def test_tf_linear_interp1d_against_scipy_map(self):
@@ -128,8 +132,7 @@ class CalibrationBuilderTest(tf.test.TestCase):
     self._add_function_approximation_to_calibration_proto(
         calibration_config, class_agnostic_x, class_agnostic_y, class_id=None)
-    od_graph = tf.Graph()
-    with self.test_session(graph=od_graph) as sess:
+    def graph_fn():
       calibration_fn = calibration_builder.build(calibration_config)
       # batch_size = 2, num_classes = 2, num_anchors = 2.
       class_predictions_with_background = tf.constant(
@@ -140,7 +143,8 @@ class CalibrationBuilderTest(tf.test.TestCase):
       # Everything should map to 0.5 if classes are ignored.
       calibrated_scores = calibration_fn(class_predictions_with_background)
-      calibrated_scores_np = sess.run(calibrated_scores)
+      return calibrated_scores
+    calibrated_scores_np = self.execute(graph_fn, [])
     self.assertAllClose(calibrated_scores_np, [[[0.05, 0.1, 0.15],
                                                 [0.2, 0.25, 0.0]],
                                                [[0.35, 0.45, 0.55],
@@ -161,8 +165,7 @@ class CalibrationBuilderTest(tf.test.TestCase):
     self._add_function_approximation_to_calibration_proto(
         calibration_config, class_1_x, class_1_y, class_id=1)
-    od_graph = tf.Graph()
-    with self.test_session(graph=od_graph) as sess:
+    def graph_fn():
       calibration_fn = calibration_builder.build(calibration_config)
       # batch_size = 2, num_classes = 2, num_anchors = 2.
       class_predictions_with_background = tf.constant(
@@ -170,7 +173,8 @@ class CalibrationBuilderTest(tf.test.TestCase):
           [[0.6, 0.4], [0.08, 0.92]]],
          dtype=tf.float32)
       calibrated_scores = calibration_fn(class_predictions_with_background)
-      calibrated_scores_np = sess.run(calibrated_scores)
+      return calibrated_scores
+    calibrated_scores_np = self.execute(graph_fn, [])
     self.assertAllClose(calibrated_scores_np, [[[0.5, 0.6], [0.5, 0.3]],
                                                [[0.5, 0.7], [0.5, 0.96]]])
@@ -179,8 +183,7 @@ class CalibrationBuilderTest(tf.test.TestCase):
     calibration_config = calibration_pb2.CalibrationConfig()
     calibration_config.temperature_scaling_calibration.scaler = 2.0
-    od_graph = tf.Graph()
-    with self.test_session(graph=od_graph) as sess:
+    def graph_fn():
       calibration_fn = calibration_builder.build(calibration_config)
       # batch_size = 2, num_classes = 2, num_anchors = 2.
       class_predictions_with_background = tf.constant(
@@ -188,7 +191,8 @@ class CalibrationBuilderTest(tf.test.TestCase):
           [[0.6, 0.7, 0.8], [0.9, 1.0, 1.0]]],
          dtype=tf.float32)
       calibrated_scores = calibration_fn(class_predictions_with_background)
-      calibrated_scores_np = sess.run(calibrated_scores)
+      return calibrated_scores
+    calibrated_scores_np = self.execute(graph_fn, [])
     self.assertAllClose(calibrated_scores_np,
                         [[[0.05, 0.1, 0.15], [0.2, 0.25, 0.0]],
                          [[0.3, 0.35, 0.4], [0.45, 0.5, 0.5]]])
@@ -212,8 +216,7 @@ class CalibrationBuilderTest(tf.test.TestCase):
     calibration_config = calibration_pb2.CalibrationConfig()
     self._add_function_approximation_to_calibration_proto(
         calibration_config, class_0_x, class_0_y, class_id=0)
-    od_graph = tf.Graph()
-    with self.test_session(graph=od_graph) as sess:
+    def graph_fn():
       calibration_fn = calibration_builder.build(calibration_config)
       # batch_size = 2, num_classes = 2, num_anchors = 2.
       class_predictions_with_background = tf.constant(
@@ -221,7 +224,8 @@ class CalibrationBuilderTest(tf.test.TestCase):
           [[0.6, 0.4], [0.08, 0.92]]],
          dtype=tf.float32)
       calibrated_scores = calibration_fn(class_predictions_with_background)
-      calibrated_scores_np = sess.run(calibrated_scores)
+      return calibrated_scores
+    calibrated_scores_np = self.execute(graph_fn, [])
     self.assertAllClose(calibrated_scores_np, [[[0.5, 0.2], [0.5, 0.1]],
                                                [[0.5, 0.4], [0.5, 0.92]]])
...
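
As a quick cross-check of the interpolation values asserted above, the same piecewise-linear mapping can be reproduced outside TensorFlow with numpy (a sketch, not part of the commit):

import numpy as np

# Knots from test_tf_linear_interp1d_interpolate: f(0)=0.6, f(0.5)=0.7, f(1)=1.0.
x = np.array([0., 0.5, 1.])
y = np.array([0.6, 0.7, 1.0])
new_x = np.array([0., 0.25, 0.5, 0.75, 1.])

# np.interp performs the same 1d piecewise-linear interpolation as
# calibration_builder._tf_linear_interp1d does on these inputs.
print(np.interp(new_x, x, y))  # [0.6  0.65 0.7  0.85 1.  ]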
@@ -29,7 +29,6 @@ from __future__ import print_function
 import functools
 import tensorflow.compat.v1 as tf
-from tensorflow.contrib import data as tf_data

 from object_detection.builders import decoder_builder
 from object_detection.protos import input_reader_pb2
@@ -94,7 +93,7 @@ def read_dataset(file_read_func, input_files, config,
   filename_dataset = filename_dataset.repeat(config.num_epochs or None)
   records_dataset = filename_dataset.apply(
-      tf_data.parallel_interleave(
+      tf.data.experimental.parallel_interleave(
           file_read_func,
           cycle_length=num_readers,
           block_length=config.read_block_length,
@@ -153,6 +152,30 @@ def build(input_reader_config, batch_size=None, transform_input_data_fn=None,
     if not config.input_path:
       raise ValueError('At least one input path must be specified in '
                        '`input_reader_config`.')
+
+    def dataset_map_fn(dataset, fn_to_map, batch_size=None,
+                       input_reader_config=None):
+      """Handles whether or not to use the legacy map function.
+
+      Args:
+        dataset: A tf.Dataset.
+        fn_to_map: The function to be mapped for that dataset.
+        batch_size: Batch size. If batch size is None, no batching is performed.
+        input_reader_config: A input_reader_pb2.InputReader object.
+
+      Returns:
+        A tf.data.Dataset mapped with fn_to_map.
+      """
+      if hasattr(dataset, 'map_with_legacy_function'):
+        if batch_size:
+          num_parallel_calls = batch_size * (
+              input_reader_config.num_parallel_batches)
+        else:
+          num_parallel_calls = input_reader_config.num_parallel_map_calls
+        dataset = dataset.map_with_legacy_function(
+            fn_to_map, num_parallel_calls=num_parallel_calls)
+      else:
+        dataset = dataset.map(fn_to_map, tf.data.experimental.AUTOTUNE)
+      return dataset
+
     shard_fn = shard_function_for_context(input_context)
     if input_context is not None:
       batch_size = input_context.get_per_replica_batch_size(batch_size)
@@ -163,15 +186,16 @@ def build(input_reader_config, batch_size=None, transform_input_data_fn=None,
       dataset = dataset.shard(input_reader_config.sample_1_of_n_examples, 0)
     # TODO(rathodv): make batch size a required argument once the old binaries
     # are deleted.
-    dataset = dataset.map(decoder.decode, tf.data.experimental.AUTOTUNE)
+    dataset = dataset_map_fn(dataset, decoder.decode, batch_size,
+                             input_reader_config)
     if reduce_to_frame_fn:
-      dataset = reduce_to_frame_fn(dataset)
+      dataset = reduce_to_frame_fn(dataset, dataset_map_fn, batch_size,
+                                   input_reader_config)
     if transform_input_data_fn is not None:
-      dataset = dataset.map(transform_input_data_fn,
-                            tf.data.experimental.AUTOTUNE)
+      dataset = dataset_map_fn(dataset, transform_input_data_fn,
+                               batch_size, input_reader_config)
     if batch_size:
-      dataset = dataset.apply(
-          tf_data.batch_and_drop_remainder(batch_size))
+      dataset = dataset.batch(batch_size, drop_remainder=True)
     dataset = dataset.prefetch(input_reader_config.num_prefetch_batches)
     return dataset
...
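
The new `dataset_map_fn` helper keys off `hasattr(dataset, 'map_with_legacy_function')`: TF 1.x datasets expose `map_with_legacy_function` for functions that cannot be traced as a `tf.function`, while newer releases drop it, leaving plain `Dataset.map`. A standalone sketch of that feature-detection fallback (the helper name and toy pipeline are illustrative, not from the commit):

import tensorflow.compat.v1 as tf


def map_compat(dataset, fn_to_map, num_parallel_calls=2):
  # Prefer the legacy-function map where the installed TF still provides it;
  # otherwise fall back to the standard map with autotuned parallelism.
  if hasattr(dataset, 'map_with_legacy_function'):
    return dataset.map_with_legacy_function(
        fn_to_map, num_parallel_calls=num_parallel_calls)
  return dataset.map(fn_to_map, tf.data.experimental.AUTOTUNE)


dataset = map_compat(tf.data.Dataset.from_tensor_slices([1, 2, 3]),
                     lambda x: x * 2)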
@@ -197,13 +197,13 @@ class DatasetBuilderTest(test_case.TestCase):
         output_dict[fields.InputDataFields.groundtruth_boxes][0][0])

   def get_mock_reduce_to_frame_fn(self):
-    def mock_reduce_to_frame_fn(dataset):
+    def mock_reduce_to_frame_fn(dataset, dataset_map_fn, batch_size, config):
       def get_frame(tensor_dict):
         out_tensor_dict = {}
         out_tensor_dict[fields.InputDataFields.source_id] = (
             tensor_dict[fields.InputDataFields.source_id][0])
         return out_tensor_dict
-      return dataset.map(get_frame, tf.data.experimental.AUTOTUNE)
+      return dataset_map_fn(dataset, get_frame, batch_size, config)
     return mock_reduce_to_frame_fn

   def test_build_tf_record_input_reader_sequence_example_train(self):
@@ -537,8 +537,15 @@ class ReadDatasetTest(test_case.TestCase):
     def graph_fn():
       keys = [1, 0, -1]
       dataset = tf.data.Dataset.from_tensor_slices([[1, 2, -1, 5]])
-      table = contrib_lookup.HashTable(
-          initializer=contrib_lookup.KeyValueTensorInitializer(
+      try:
+        # Dynamically try to load the tf v2 lookup, falling back to contrib
+        lookup = tf.compat.v2.lookup
+        hash_table_class = tf.compat.v2.lookup.StaticHashTable
+      except AttributeError:
+        lookup = contrib_lookup
+        hash_table_class = contrib_lookup.HashTable
+      table = hash_table_class(
+          initializer=lookup.KeyValueTensorInitializer(
               keys=keys, values=list(reversed(keys))),
           default_value=100)
       dataset = dataset.map(table.lookup)
@@ -559,7 +566,7 @@ class ReadDatasetTest(test_case.TestCase):
     data = self.execute(graph_fn, [])
     # Note that the execute function extracts single outputs if the return
     # value is of size 1.
-    self.assertAllEqual(
+    self.assertCountEqual(
         data, [
             1, 10, 2, 20, 3, 30, 4, 40, 5, 50, 1, 10, 2, 20, 3, 30, 4, 40, 5,
             50
@@ -577,7 +584,7 @@ class ReadDatasetTest(test_case.TestCase):
     data = self.execute(graph_fn, [])
     # Note that the execute function extracts single outputs if the return
     # value is of size 1.
-    self.assertAllEqual(
+    self.assertCountEqual(
         data, [
             1, 10, 2, 20, 3, 30, 4, 40, 5, 50, 1, 10, 2, 20, 3, 30, 4, 40, 5,
             50
@@ -607,12 +614,14 @@ class ReadDatasetTest(test_case.TestCase):
     def graph_fn():
       return self._get_dataset_next(
           [self._shuffle_path_template % '*'], config, batch_size=10)
-    expected_non_shuffle_output = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
+    expected_non_shuffle_output1 = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
+    expected_non_shuffle_output2 = [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]
     # Note that the execute function extracts single outputs if the return
     # value is of size 1.
     data = self.execute(graph_fn, [])
-    self.assertAllEqual(data, expected_non_shuffle_output)
+    self.assertTrue(all(data == expected_non_shuffle_output1) or
+                    all(data == expected_non_shuffle_output2))

   def test_read_dataset_single_epoch(self):
     config = input_reader_pb2.InputReader()
...
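
The try/except above probes for the TF 2.x lookup API at runtime. For reference, the v2 path in isolation (values mirror the test, where each key maps to its reversed-list counterpart; assumes a TF build where `tf.compat.v2.lookup` exists):

import tensorflow.compat.v2 as tf

keys = tf.constant([1, 0, -1])
values = tf.constant([-1, 0, 1])  # list(reversed(keys)) from the test
table = tf.lookup.StaticHashTable(
    initializer=tf.lookup.KeyValueTensorInitializer(keys, values),
    default_value=100)
# Unknown keys (2 and 5 here) fall back to default_value.
print(table.lookup(tf.constant([1, 2, -1, 5])))  # [-1, 100, 1, 100]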
@@ -48,7 +48,7 @@ def build(input_reader_config):
   if input_reader_config.HasField('label_map_path'):
     label_map_proto_file = input_reader_config.label_map_path
   input_type = input_reader_config.input_type
-  if input_type == input_reader_pb2.InputType.TF_EXAMPLE:
+  if input_type == input_reader_pb2.InputType.Value('TF_EXAMPLE'):
     decoder = tf_example_decoder.TfExampleDecoder(
         load_instance_masks=input_reader_config.load_instance_masks,
         load_multiclass_scores=input_reader_config.load_multiclass_scores,
@@ -60,7 +60,7 @@ def build(input_reader_config):
         num_keypoints=input_reader_config.num_keypoints,
         expand_hierarchy_labels=input_reader_config.expand_labels_hierarchy)
     return decoder
-  elif input_type == input_reader_pb2.InputType.TF_SEQUENCE_EXAMPLE:
+  elif input_type == input_reader_pb2.InputType.Value('TF_SEQUENCE_EXAMPLE'):
     decoder = tf_sequence_example_decoder.TfSequenceExampleDecoder(
         label_map_proto_file=label_map_proto_file,
         load_context_features=input_reader_config.load_context_features)
...
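
Both branches switch from attribute-style enum access to the protobuf `EnumTypeWrapper.Value()` lookup; `Value()` resolves an enum name to its integer, and `Name()` inverts it. A small illustration against the same generated module (assumes `input_reader_pb2` is importable):

from object_detection.protos import input_reader_pb2

tf_example = input_reader_pb2.InputType.Value('TF_EXAMPLE')
print(tf_example)                                   # the integer enum value
print(input_reader_pb2.InputType.Name(tf_example))  # 'TF_EXAMPLE'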
@@ -29,6 +29,7 @@ from object_detection.core import standard_fields as fields
 from object_detection.dataset_tools import seq_example_util
 from object_detection.protos import input_reader_pb2
 from object_detection.utils import dataset_util
+from object_detection.utils import test_case


 def _get_labelmap_path():
@@ -38,17 +39,20 @@ def _get_labelmap_path():
                       'pet_label_map.pbtxt')


-class DecoderBuilderTest(tf.test.TestCase):
+class DecoderBuilderTest(test_case.TestCase):

   def _make_serialized_tf_example(self, has_additional_channels=False):
-    image_tensor = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
-    additional_channels_tensor = np.random.randint(
+    image_tensor_np = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
+    additional_channels_tensor_np = np.random.randint(
         255, size=(4, 5, 1)).astype(np.uint8)
     flat_mask = (4 * 5) * [1.0]
-    with self.test_session():
-      encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).eval()
-      encoded_additional_channels_jpeg = tf.image.encode_jpeg(
-          tf.constant(additional_channels_tensor)).eval()
+    def graph_fn(image_tensor):
+      encoded_jpeg = tf.image.encode_jpeg(image_tensor)
+      return encoded_jpeg
+    encoded_jpeg = self.execute_cpu(graph_fn, [image_tensor_np])
+    encoded_additional_channels_jpeg = self.execute_cpu(
+        graph_fn, [additional_channels_tensor_np])
     features = {
         'image/source_id': dataset_util.bytes_feature('0'.encode()),
         'image/encoded': dataset_util.bytes_feature(encoded_jpeg),
@@ -71,22 +75,21 @@ class DecoderBuilderTest(tf.test.TestCase):
   def _make_random_serialized_jpeg_images(self, num_frames, image_height,
                                           image_width):
+    def graph_fn():
       images = tf.cast(tf.random.uniform(
           [num_frames, image_height, image_width, 3],
           maxval=256,
           dtype=tf.int32), dtype=tf.uint8)
       images_list = tf.unstack(images, axis=0)
-      encoded_images_list = [tf.io.encode_jpeg(image) for image in images_list]
-      with tf.Session() as sess:
-        encoded_images = sess.run(encoded_images_list)
+      encoded_images = [tf.io.encode_jpeg(image) for image in images_list]
       return encoded_images
+    return self.execute_cpu(graph_fn, [])

   def _make_serialized_tf_sequence_example(self):
     num_frames = 4
     image_height = 20
     image_width = 30
     image_source_ids = [str(i) for i in range(num_frames)]
-    with self.test_session():
-      encoded_images = self._make_random_serialized_jpeg_images(
-          num_frames, image_height, image_width)
+    encoded_images = self._make_random_serialized_jpeg_images(
+        num_frames, image_height, image_width)
     sequence_example_serialized = seq_example_util.make_sequence_example(
@@ -119,21 +122,19 @@ class DecoderBuilderTest(tf.test.TestCase):
     text_format.Parse(input_reader_text_proto, input_reader_proto)
     decoder = decoder_builder.build(input_reader_proto)
-    tensor_dict = decoder.decode(self._make_serialized_tf_example())
-
-    with tf.train.MonitoredSession() as sess:
-      output_dict = sess.run(tensor_dict)
-
-    self.assertNotIn(
-        fields.InputDataFields.groundtruth_instance_masks, output_dict)
-    self.assertEqual((4, 5, 3), output_dict[fields.InputDataFields.image].shape)
-    self.assertAllEqual([2],
-                        output_dict[fields.InputDataFields.groundtruth_classes])
-    self.assertEqual(
-        (1, 4), output_dict[fields.InputDataFields.groundtruth_boxes].shape)
-    self.assertAllEqual(
-        [0.0, 0.0, 1.0, 1.0],
-        output_dict[fields.InputDataFields.groundtruth_boxes][0])
+    serialized_seq_example = self._make_serialized_tf_example()
+    def graph_fn():
+      tensor_dict = decoder.decode(serialized_seq_example)
+      return (tensor_dict[fields.InputDataFields.image],
+              tensor_dict[fields.InputDataFields.groundtruth_classes],
+              tensor_dict[fields.InputDataFields.groundtruth_boxes])
+    (image, groundtruth_classes,
+     groundtruth_boxes) = self.execute_cpu(graph_fn, [])
+    self.assertEqual((4, 5, 3), image.shape)
+    self.assertAllEqual([2], groundtruth_classes)
+    self.assertEqual((1, 4), groundtruth_boxes.shape)
+    self.assertAllEqual([0.0, 0.0, 1.0, 1.0], groundtruth_boxes[0])

   def test_build_tf_record_input_reader_sequence_example(self):
     label_map_path = _get_labelmap_path()
@@ -145,12 +146,16 @@ class DecoderBuilderTest(tf.test.TestCase):
     input_reader_proto.label_map_path = label_map_path
     text_format.Parse(input_reader_text_proto, input_reader_proto)
-
-    decoder = decoder_builder.build(input_reader_proto)
-    tensor_dict = decoder.decode(self._make_serialized_tf_sequence_example())
-
-    with tf.train.MonitoredSession() as sess:
-      output_dict = sess.run(tensor_dict)
+    serialized_seq_example = self._make_serialized_tf_sequence_example()
+    def graph_fn():
+      decoder = decoder_builder.build(input_reader_proto)
+      tensor_dict = decoder.decode(serialized_seq_example)
+      return (tensor_dict[fields.InputDataFields.image],
+              tensor_dict[fields.InputDataFields.groundtruth_classes],
+              tensor_dict[fields.InputDataFields.groundtruth_boxes],
+              tensor_dict[fields.InputDataFields.num_groundtruth_boxes])
+    (actual_image, actual_groundtruth_classes, actual_groundtruth_boxes,
+     actual_num_groundtruth_boxes) = self.execute_cpu(graph_fn, [])
     expected_groundtruth_classes = [[-1, -1], [1, -1], [1, 2], [-1, -1]]
     expected_groundtruth_boxes = [[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
                                   [[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]],
@@ -158,19 +163,14 @@ class DecoderBuilderTest(tf.test.TestCase):
                                   [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]
     expected_num_groundtruth_boxes = [0, 1, 2, 0]
-    self.assertNotIn(
-        fields.InputDataFields.groundtruth_instance_masks, output_dict)
     # Sequence example images are encoded.
-    self.assertEqual((4,), output_dict[fields.InputDataFields.image].shape)
+    self.assertEqual((4,), actual_image.shape)
     self.assertAllEqual(expected_groundtruth_classes,
-                        output_dict[fields.InputDataFields.groundtruth_classes])
-    self.assertEqual(
-        (4, 2, 4), output_dict[fields.InputDataFields.groundtruth_boxes].shape)
+                        actual_groundtruth_classes)
     self.assertAllClose(expected_groundtruth_boxes,
-                        output_dict[fields.InputDataFields.groundtruth_boxes])
+                        actual_groundtruth_boxes)
     self.assertAllClose(
-        expected_num_groundtruth_boxes,
-        output_dict[fields.InputDataFields.num_groundtruth_boxes])
+        expected_num_groundtruth_boxes, actual_num_groundtruth_boxes)

   def test_build_tf_record_input_reader_and_load_instance_masks(self):
     input_reader_text_proto = """
@@ -181,14 +181,12 @@ class DecoderBuilderTest(tf.test.TestCase):
     text_format.Parse(input_reader_text_proto, input_reader_proto)
     decoder = decoder_builder.build(input_reader_proto)
-    tensor_dict = decoder.decode(self._make_serialized_tf_example())
-
-    with tf.train.MonitoredSession() as sess:
-      output_dict = sess.run(tensor_dict)
-
-    self.assertAllEqual(
-        (1, 4, 5),
-        output_dict[fields.InputDataFields.groundtruth_instance_masks].shape)
+    serialized_seq_example = self._make_serialized_tf_example()
+    def graph_fn():
+      tensor_dict = decoder.decode(serialized_seq_example)
+      return tensor_dict[fields.InputDataFields.groundtruth_instance_masks]
+    masks = self.execute_cpu(graph_fn, [])
+    self.assertAllEqual((1, 4, 5), masks.shape)

 if __name__ == '__main__':
...
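
The refactored helpers above route `tf.image.encode_jpeg` through a `graph_fn` run via `execute_cpu` (the tests pin JPEG encoding to CPU). Under eager TF 2.x the equivalent one-off encoding is direct; a sketch, assuming eager execution is enabled:

import numpy as np
import tensorflow.compat.v2 as tf

image_np = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
encoded_jpeg = tf.io.encode_jpeg(tf.constant(image_np)).numpy()
print(type(encoded_jpeg), len(encoded_jpeg))  # bytes, non-zero length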
@@ -13,22 +13,21 @@
 # limitations under the License.
 # ==============================================================================
 """Tests for graph_rewriter_builder."""
+import unittest
 import mock
 import tensorflow.compat.v1 as tf
 import tf_slim as slim
 from object_detection.builders import graph_rewriter_builder
 from object_detection.protos import graph_rewriter_pb2
+from object_detection.utils import tf_version

-# pylint: disable=g-import-not-at-top
-try:
-  from tensorflow.contrib import quantize as contrib_quantize
-except ImportError:
-  # TF 2.0 doesn't ship with contrib.
-  pass
-# pylint: enable=g-import-not-at-top
+if tf_version.is_tf1():
+  from tensorflow.contrib import quantize as contrib_quantize  # pylint: disable=g-import-not-at-top


+@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
 class QuantizationBuilderTest(tf.test.TestCase):

   def testQuantizationBuilderSetsUpCorrectTrainArguments(self):
...
@@ -18,21 +18,23 @@ import tensorflow.compat.v1 as tf
 from google.protobuf import text_format
 from object_detection.builders import image_resizer_builder
 from object_detection.protos import image_resizer_pb2
+from object_detection.utils import test_case


-class ImageResizerBuilderTest(tf.test.TestCase):
+class ImageResizerBuilderTest(test_case.TestCase):

   def _shape_of_resized_random_image_given_text_proto(self, input_shape,
                                                       text_proto):
     image_resizer_config = image_resizer_pb2.ImageResizer()
     text_format.Merge(text_proto, image_resizer_config)
     image_resizer_fn = image_resizer_builder.build(image_resizer_config)
+    def graph_fn():
       images = tf.cast(
           tf.random_uniform(input_shape, minval=0, maxval=255, dtype=tf.int32),
           dtype=tf.float32)
       resized_images, _ = image_resizer_fn(images)
-    with self.test_session() as sess:
-      return sess.run(resized_images).shape
+      return resized_images
+    return self.execute_cpu(graph_fn, []).shape

   def test_build_keep_aspect_ratio_resizer_returns_expected_shape(self):
     image_resizer_text_proto = """
@@ -125,10 +127,10 @@ class ImageResizerBuilderTest(tf.test.TestCase):
     image_resizer_config = image_resizer_pb2.ImageResizer()
     text_format.Merge(text_proto, image_resizer_config)
     image_resizer_fn = image_resizer_builder.build(image_resizer_config)
-    image_placeholder = tf.placeholder(tf.uint8, [1, None, None, 3])
-    resized_image, _ = image_resizer_fn(image_placeholder)
-    with self.test_session() as sess:
-      return sess.run(resized_image, feed_dict={image_placeholder: image})
+    def graph_fn(image):
+      resized_image, _ = image_resizer_fn(image)
+      return resized_image
+    return self.execute_cpu(graph_fn, [image])

   def test_fixed_shape_resizer_nearest_neighbor_method(self):
     image_resizer_text_proto = """
...
@@ -29,19 +29,12 @@ from __future__ import division
 from __future__ import print_function

 import tensorflow.compat.v1 as tf
+import tf_slim as slim

 from object_detection.data_decoders import tf_example_decoder
 from object_detection.data_decoders import tf_sequence_example_decoder
 from object_detection.protos import input_reader_pb2

-# pylint: disable=g-import-not-at-top
-try:
-  import tf_slim as slim
-except ImportError:
-  # TF 2.0 doesn't ship with contrib.
-  pass
-# pylint: enable=g-import-not-at-top

 parallel_reader = slim.parallel_reader
@@ -82,14 +75,14 @@ def build(input_reader_config):
   if input_reader_config.HasField('label_map_path'):
     label_map_proto_file = input_reader_config.label_map_path
   input_type = input_reader_config.input_type
-  if input_type == input_reader_pb2.InputType.TF_EXAMPLE:
+  if input_type == input_reader_pb2.InputType.Value('TF_EXAMPLE'):
     decoder = tf_example_decoder.TfExampleDecoder(
         load_instance_masks=input_reader_config.load_instance_masks,
         instance_mask_type=input_reader_config.mask_type,
         label_map_proto_file=label_map_proto_file,
         load_context_features=input_reader_config.load_context_features)
     return decoder.decode(string_tensor)
-  elif input_type == input_reader_pb2.InputType.TF_SEQUENCE_EXAMPLE:
+  elif input_type == input_reader_pb2.InputType.Value('TF_SEQUENCE_EXAMPLE'):
     decoder = tf_sequence_example_decoder.TfSequenceExampleDecoder(
         label_map_proto_file=label_map_proto_file,
         load_context_features=input_reader_config.load_context_features)
...
@@ -16,6 +16,7 @@
 """Tests for input_reader_builder."""

 import os
+import unittest
 import numpy as np
 import tensorflow.compat.v1 as tf
@@ -26,6 +27,7 @@ from object_detection.core import standard_fields as fields
 from object_detection.dataset_tools import seq_example_util
 from object_detection.protos import input_reader_pb2
 from object_detection.utils import dataset_util
+from object_detection.utils import tf_version


 def _get_labelmap_path():
@@ -35,6 +37,7 @@ def _get_labelmap_path():
                       'pet_label_map.pbtxt')


+@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
 class InputReaderBuilderTest(tf.test.TestCase):

   def create_tf_record(self):
...
@@ -16,8 +16,11 @@
 """A function to build an object detection matcher from configuration."""

 from object_detection.matchers import argmax_matcher
-from object_detection.matchers import bipartite_matcher
 from object_detection.protos import matcher_pb2
+from object_detection.utils import tf_version
+
+if tf_version.is_tf1():
+  from object_detection.matchers import bipartite_matcher  # pylint: disable=g-import-not-at-top


 def build(matcher_config):
@@ -48,6 +51,8 @@ def build(matcher_config):
         force_match_for_each_row=matcher.force_match_for_each_row,
         use_matmul_gather=matcher.use_matmul_gather)
   if matcher_config.WhichOneof('matcher_oneof') == 'bipartite_matcher':
+    if tf_version.is_tf2():
+      raise ValueError('bipartite_matcher is not supported in TF 2.X')
     matcher = matcher_config.bipartite_matcher
     return bipartite_matcher.GreedyBipartiteMatcher(matcher.use_matmul_gather)
   raise ValueError('Empty matcher.')
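
With the gated import, requesting a bipartite matcher under TF 2.x now fails fast inside build() instead of failing at import time. A hedged usage sketch of both outcomes:

from google.protobuf import text_format

from object_detection.builders import matcher_builder
from object_detection.protos import matcher_pb2
from object_detection.utils import tf_version

matcher_proto = text_format.Merge('bipartite_matcher {}', matcher_pb2.Matcher())
if tf_version.is_tf1():
  matcher = matcher_builder.build(matcher_proto)  # GreedyBipartiteMatcher
else:
  try:
    matcher_builder.build(matcher_proto)
  except ValueError as e:
    print(e)  # 'bipartite_matcher is not supported in TF 2.X'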
@@ -20,11 +20,15 @@ import tensorflow.compat.v1 as tf
 from google.protobuf import text_format
 from object_detection.builders import matcher_builder
 from object_detection.matchers import argmax_matcher
-from object_detection.matchers import bipartite_matcher
 from object_detection.protos import matcher_pb2
+from object_detection.utils import test_case
+from object_detection.utils import tf_version
+
+if tf_version.is_tf1():
+  from object_detection.matchers import bipartite_matcher  # pylint: disable=g-import-not-at-top


-class MatcherBuilderTest(tf.test.TestCase):
+class MatcherBuilderTest(test_case.TestCase):

   def test_build_arg_max_matcher_with_defaults(self):
     matcher_text_proto = """
@@ -34,7 +38,7 @@ class MatcherBuilderTest(tf.test.TestCase):
     matcher_proto = matcher_pb2.Matcher()
     text_format.Merge(matcher_text_proto, matcher_proto)
     matcher_object = matcher_builder.build(matcher_proto)
-    self.assertTrue(isinstance(matcher_object, argmax_matcher.ArgMaxMatcher))
+    self.assertIsInstance(matcher_object, argmax_matcher.ArgMaxMatcher)
     self.assertAlmostEqual(matcher_object._matched_threshold, 0.5)
     self.assertAlmostEqual(matcher_object._unmatched_threshold, 0.5)
     self.assertTrue(matcher_object._negatives_lower_than_unmatched)
@@ -49,7 +53,7 @@ class MatcherBuilderTest(tf.test.TestCase):
     matcher_proto = matcher_pb2.Matcher()
     text_format.Merge(matcher_text_proto, matcher_proto)
     matcher_object = matcher_builder.build(matcher_proto)
-    self.assertTrue(isinstance(matcher_object, argmax_matcher.ArgMaxMatcher))
+    self.assertIsInstance(matcher_object, argmax_matcher.ArgMaxMatcher)
     self.assertEqual(matcher_object._matched_threshold, None)
     self.assertEqual(matcher_object._unmatched_threshold, None)
     self.assertTrue(matcher_object._negatives_lower_than_unmatched)
@@ -68,7 +72,7 @@ class MatcherBuilderTest(tf.test.TestCase):
     matcher_proto = matcher_pb2.Matcher()
     text_format.Merge(matcher_text_proto, matcher_proto)
     matcher_object = matcher_builder.build(matcher_proto)
-    self.assertTrue(isinstance(matcher_object, argmax_matcher.ArgMaxMatcher))
+    self.assertIsInstance(matcher_object, argmax_matcher.ArgMaxMatcher)
     self.assertAlmostEqual(matcher_object._matched_threshold, 0.7)
     self.assertAlmostEqual(matcher_object._unmatched_threshold, 0.3)
     self.assertFalse(matcher_object._negatives_lower_than_unmatched)
@@ -76,6 +80,8 @@ class MatcherBuilderTest(tf.test.TestCase):
     self.assertTrue(matcher_object._use_matmul_gather)

   def test_build_bipartite_matcher(self):
+    if tf_version.is_tf2():
+      self.skipTest('BipartiteMatcher unsupported in TF 2.X. Skipping.')
     matcher_text_proto = """
       bipartite_matcher {
       }
@@ -83,8 +89,8 @@ class MatcherBuilderTest(tf.test.TestCase):
     matcher_proto = matcher_pb2.Matcher()
     text_format.Merge(matcher_text_proto, matcher_proto)
     matcher_object = matcher_builder.build(matcher_proto)
-    self.assertTrue(
-        isinstance(matcher_object, bipartite_matcher.GreedyBipartiteMatcher))
+    self.assertIsInstance(matcher_object,
+                          bipartite_matcher.GreedyBipartiteMatcher)

   def test_raise_error_on_empty_matcher(self):
     matcher_text_proto = """
...
@@ -28,6 +28,8 @@ from object_detection.builders import region_similarity_calculator_builder as si
 from object_detection.core import balanced_positive_negative_sampler as sampler
 from object_detection.core import post_processing
 from object_detection.core import target_assigner
+from object_detection.meta_architectures import center_net_meta_arch
+from object_detection.meta_architectures import context_rcnn_meta_arch
 from object_detection.meta_architectures import faster_rcnn_meta_arch
 from object_detection.meta_architectures import rfcn_meta_arch
 from object_detection.meta_architectures import ssd_meta_arch
@@ -46,6 +48,7 @@ from object_detection.utils import tf_version
 if tf_version.is_tf2():
   from object_detection.models import center_net_hourglass_feature_extractor
   from object_detection.models import center_net_resnet_feature_extractor
+  from object_detection.models import center_net_resnet_v1_fpn_feature_extractor
   from object_detection.models import faster_rcnn_inception_resnet_v2_keras_feature_extractor as frcnn_inc_res_keras
   from object_detection.models import faster_rcnn_resnet_keras_feature_extractor as frcnn_resnet_keras
   from object_detection.models import ssd_resnet_v1_fpn_keras_feature_extractor as ssd_resnet_v1_fpn_keras
@@ -78,6 +81,7 @@ if tf_version.is_tf1():
   from object_detection.models.ssd_mobiledet_feature_extractor import SSDMobileDetCPUFeatureExtractor
   from object_detection.models.ssd_mobiledet_feature_extractor import SSDMobileDetDSPFeatureExtractor
   from object_detection.models.ssd_mobiledet_feature_extractor import SSDMobileDetEdgeTPUFeatureExtractor
+  from object_detection.models.ssd_mobiledet_feature_extractor import SSDMobileDetGPUFeatureExtractor
   from object_detection.models.ssd_pnasnet_feature_extractor import SSDPNASNetFeatureExtractor
   from object_detection.predictors import rfcn_box_predictor
 # pylint: enable=g-import-not-at-top
@@ -108,8 +112,12 @@
 }

 CENTER_NET_EXTRACTOR_FUNCTION_MAP = {
-    'resnet_v2_101': center_net_resnet_feature_extractor.resnet_v2_101,
     'resnet_v2_50': center_net_resnet_feature_extractor.resnet_v2_50,
+    'resnet_v2_101': center_net_resnet_feature_extractor.resnet_v2_101,
+    'resnet_v1_50_fpn':
+        center_net_resnet_v1_fpn_feature_extractor.resnet_v1_50_fpn,
+    'resnet_v1_101_fpn':
+        center_net_resnet_v1_fpn_feature_extractor.resnet_v1_101_fpn,
     'hourglass_104': center_net_hourglass_feature_extractor.hourglass_104,
 }
@@ -159,9 +167,14 @@ if tf_version.is_tf1():
         EmbeddedSSDMobileNetV1FeatureExtractor,
     'ssd_pnasnet':
         SSDPNASNetFeatureExtractor,
-    'ssd_mobiledet_cpu': SSDMobileDetCPUFeatureExtractor,
-    'ssd_mobiledet_dsp': SSDMobileDetDSPFeatureExtractor,
-    'ssd_mobiledet_edgetpu': SSDMobileDetEdgeTPUFeatureExtractor,
+    'ssd_mobiledet_cpu':
+        SSDMobileDetCPUFeatureExtractor,
+    'ssd_mobiledet_dsp':
+        SSDMobileDetDSPFeatureExtractor,
+    'ssd_mobiledet_edgetpu':
+        SSDMobileDetEdgeTPUFeatureExtractor,
+    'ssd_mobiledet_gpu':
+        SSDMobileDetGPUFeatureExtractor,
 }

 FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP = {
@@ -765,7 +778,9 @@ def keypoint_proto_to_params(kp_config, keypoint_map_dict):
       unmatched_keypoint_score=kp_config.unmatched_keypoint_score,
       box_scale=kp_config.box_scale,
       candidate_search_scale=kp_config.candidate_search_scale,
-      candidate_ranking_mode=kp_config.candidate_ranking_mode)
+      candidate_ranking_mode=kp_config.candidate_ranking_mode,
+      offset_peak_radius=kp_config.offset_peak_radius,
+      per_keypoint_offset=kp_config.per_keypoint_offset)


 def object_detection_proto_to_params(od_config):
...
@@ -14,16 +14,19 @@
 # limitations under the License.
 # ==============================================================================
 """Tests for model_builder under TensorFlow 1.X."""
+import unittest
 from absl.testing import parameterized

 import tensorflow.compat.v1 as tf

 from object_detection.builders import model_builder
 from object_detection.builders import model_builder_test
+from object_detection.meta_architectures import context_rcnn_meta_arch
 from object_detection.meta_architectures import ssd_meta_arch
 from object_detection.protos import losses_pb2
+from object_detection.utils import tf_version


+@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
 class ModelBuilderTF1Test(model_builder_test.ModelBuilderTest):

   def default_ssd_feature_extractor(self):
@@ -39,6 +42,14 @@ class ModelBuilderTF1Test(model_builder_test.ModelBuilderTest):
     return model_builder.FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP

+  @parameterized.parameters(True, False)
+  def test_create_context_rcnn_from_config_with_params(self, is_training):
+    model_proto = self.create_default_faster_rcnn_model_proto()
+    model_proto.faster_rcnn.context_config.attention_bottleneck_dimension = 10
+    model_proto.faster_rcnn.context_config.attention_temperature = 0.5
+    model = model_builder.build(model_proto, is_training=is_training)
+    self.assertIsInstance(model, context_rcnn_meta_arch.ContextRCNNMetaArch)
+

 if __name__ == '__main__':
   tf.test.main()
...
# Lint as: python2, python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for model_builder under TensorFlow 2.X."""
import os
import unittest
import tensorflow.compat.v1 as tf
from google.protobuf import text_format
from object_detection.builders import model_builder
from object_detection.builders import model_builder_test
from object_detection.core import losses
from object_detection.models import center_net_resnet_feature_extractor
from object_detection.protos import center_net_pb2
from object_detection.protos import model_pb2
from object_detection.utils import tf_version
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
def default_ssd_feature_extractor(self):
return 'ssd_resnet50_v1_fpn_keras'
def default_faster_rcnn_feature_extractor(self):
return 'faster_rcnn_resnet101_keras'
def ssd_feature_extractors(self):
return model_builder.SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP
def faster_rcnn_feature_extractors(self):
return model_builder.FASTER_RCNN_KERAS_FEATURE_EXTRACTOR_CLASS_MAP
def get_fake_label_map_file_path(self):
keypoint_spec_text = """
item {
name: "/m/01g317"
id: 1
display_name: "person"
keypoints {
id: 0
label: 'nose'
}
keypoints {
id: 1
label: 'left_shoulder'
}
keypoints {
id: 2
label: 'right_shoulder'
}
keypoints {
id: 3
label: 'hip'
}
}
"""
keypoint_label_map_path = os.path.join(
self.get_temp_dir(), 'keypoint_label_map')
with tf.gfile.Open(keypoint_label_map_path, 'wb') as f:
f.write(keypoint_spec_text)
return keypoint_label_map_path
def get_fake_keypoint_proto(self):
task_proto_txt = """
task_name: "human_pose"
task_loss_weight: 0.9
keypoint_regression_loss_weight: 1.0
keypoint_heatmap_loss_weight: 0.1
keypoint_offset_loss_weight: 0.5
heatmap_bias_init: 2.14
keypoint_class_name: "/m/01g317"
loss {
classification_loss {
penalty_reduced_logistic_focal_loss {
alpha: 3.0
beta: 4.0
}
}
localization_loss {
l1_localization_loss {
}
}
}
keypoint_label_to_std {
key: "nose"
value: 0.3
}
keypoint_label_to_std {
key: "hip"
value: 0.0
}
keypoint_candidate_score_threshold: 0.3
num_candidates_per_keypoint: 12
peak_max_pool_kernel_size: 5
unmatched_keypoint_score: 0.05
box_scale: 1.7
candidate_search_scale: 0.2
candidate_ranking_mode: "score_distance_ratio"
offset_peak_radius: 3
per_keypoint_offset: true
"""
config = text_format.Merge(task_proto_txt,
center_net_pb2.CenterNet.KeypointEstimation())
return config
def get_fake_object_center_proto(self):
proto_txt = """
object_center_loss_weight: 0.5
heatmap_bias_init: 3.14
min_box_overlap_iou: 0.2
max_box_predictions: 15
classification_loss {
penalty_reduced_logistic_focal_loss {
alpha: 3.0
beta: 4.0
}
}
"""
return text_format.Merge(proto_txt,
center_net_pb2.CenterNet.ObjectCenterParams())
def get_fake_object_detection_proto(self):
proto_txt = """
task_loss_weight: 0.5
offset_loss_weight: 0.1
scale_loss_weight: 0.2
localization_loss {
l1_localization_loss {
}
}
"""
return text_format.Merge(proto_txt,
center_net_pb2.CenterNet.ObjectDetection())
def get_fake_mask_proto(self):
proto_txt = """
task_loss_weight: 0.7
classification_loss {
weighted_softmax {}
}
mask_height: 8
mask_width: 8
score_threshold: 0.7
heatmap_bias_init: -2.0
"""
return text_format.Merge(proto_txt,
center_net_pb2.CenterNet.MaskEstimation())
def test_create_center_net_model(self):
"""Test building a CenterNet model from proto txt."""
proto_txt = """
center_net {
num_classes: 10
feature_extractor {
type: "resnet_v2_101"
channel_stds: [4, 5, 6]
bgr_ordering: true
}
image_resizer {
keep_aspect_ratio_resizer {
min_dimension: 512
max_dimension: 512
pad_to_max_dimension: true
}
}
}
"""
# Set up the configuration proto.
config = text_format.Merge(proto_txt, model_pb2.DetectionModel())
config.center_net.object_center_params.CopyFrom(
self.get_fake_object_center_proto())
config.center_net.object_detection_task.CopyFrom(
self.get_fake_object_detection_proto())
config.center_net.keypoint_estimation_task.append(
self.get_fake_keypoint_proto())
config.center_net.keypoint_label_map_path = (
self.get_fake_label_map_file_path())
config.center_net.mask_estimation_task.CopyFrom(
self.get_fake_mask_proto())
# Build the model from the configuration.
model = model_builder.build(config, is_training=True)
# Check object center related parameters.
self.assertEqual(model._num_classes, 10)
self.assertIsInstance(model._center_params.classification_loss,
losses.PenaltyReducedLogisticFocalLoss)
self.assertEqual(model._center_params.classification_loss._alpha, 3.0)
self.assertEqual(model._center_params.classification_loss._beta, 4.0)
self.assertAlmostEqual(model._center_params.min_box_overlap_iou, 0.2)
self.assertAlmostEqual(
model._center_params.heatmap_bias_init, 3.14, places=4)
self.assertEqual(model._center_params.max_box_predictions, 15)
# Check object detection related parameters.
self.assertAlmostEqual(model._od_params.offset_loss_weight, 0.1)
self.assertAlmostEqual(model._od_params.scale_loss_weight, 0.2)
self.assertAlmostEqual(model._od_params.task_loss_weight, 0.5)
self.assertIsInstance(model._od_params.localization_loss,
losses.L1LocalizationLoss)
# Check keypoint estimation related parameters.
kp_params = model._kp_params_dict['human_pose']
self.assertAlmostEqual(kp_params.task_loss_weight, 0.9)
self.assertAlmostEqual(kp_params.keypoint_regression_loss_weight, 1.0)
self.assertAlmostEqual(kp_params.keypoint_offset_loss_weight, 0.5)
self.assertAlmostEqual(kp_params.heatmap_bias_init, 2.14, places=4)
self.assertEqual(kp_params.classification_loss._alpha, 3.0)
self.assertEqual(kp_params.keypoint_indices, [0, 1, 2, 3])
self.assertEqual(kp_params.keypoint_labels,
['nose', 'left_shoulder', 'right_shoulder', 'hip'])
self.assertAllClose(kp_params.keypoint_std_dev, [0.3, 1.0, 1.0, 0.0])
self.assertEqual(kp_params.classification_loss._beta, 4.0)
self.assertIsInstance(kp_params.localization_loss,
losses.L1LocalizationLoss)
self.assertAlmostEqual(kp_params.keypoint_candidate_score_threshold, 0.3)
self.assertEqual(kp_params.num_candidates_per_keypoint, 12)
self.assertEqual(kp_params.peak_max_pool_kernel_size, 5)
self.assertAlmostEqual(kp_params.unmatched_keypoint_score, 0.05)
self.assertAlmostEqual(kp_params.box_scale, 1.7)
self.assertAlmostEqual(kp_params.candidate_search_scale, 0.2)
self.assertEqual(kp_params.candidate_ranking_mode, 'score_distance_ratio')
self.assertEqual(kp_params.offset_peak_radius, 3)
self.assertEqual(kp_params.per_keypoint_offset, True)
# Check mask related parameters.
self.assertAlmostEqual(model._mask_params.task_loss_weight, 0.7)
self.assertIsInstance(model._mask_params.classification_loss,
losses.WeightedSoftmaxClassificationLoss)
self.assertEqual(model._mask_params.mask_height, 8)
self.assertEqual(model._mask_params.mask_width, 8)
self.assertAlmostEqual(model._mask_params.score_threshold, 0.7)
self.assertAlmostEqual(
model._mask_params.heatmap_bias_init, -2.0, places=4)
# Check feature extractor parameters.
self.assertIsInstance(
model._feature_extractor,
center_net_resnet_feature_extractor.CenterNetResnetFeatureExtractor)
self.assertAllClose(model._feature_extractor._channel_means, [0, 0, 0])
self.assertAllClose(model._feature_extractor._channel_stds, [4, 5, 6])
self.assertTrue(model._feature_extractor._bgr_ordering)
if __name__ == '__main__':
tf.test.main()
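For orientation, the build call exercised above is the same public entry point used by the training binaries; a minimal usage sketch, assuming a DetectionModel text proto at a hypothetical path (/tmp/center_net.pbtxt is illustrative only):

import tensorflow.compat.v1 as tf

from google.protobuf import text_format
from object_detection.builders import model_builder
from object_detection.protos import model_pb2

# Hypothetical config path; any DetectionModel text proto works here.
with tf.gfile.Open('/tmp/center_net.pbtxt') as f:
  config = text_format.Merge(f.read(), model_pb2.DetectionModel())

# is_training toggles train-time behavior such as batch norm updates.
model = model_builder.build(config, is_training=False)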
@@ -17,10 +17,13 @@
 import tensorflow.compat.v1 as tf

-from tensorflow.contrib import opt as tf_opt
 from object_detection.utils import learning_schedules

+try:
+  from tensorflow.contrib import opt as tf_opt  # pylint: disable=g-import-not-at-top
+except:  # pylint: disable=bare-except
+  pass


 def build_optimizers_tf_v1(optimizer_config, global_step=None):
   """Create a TF v1 compatible optimizer based on config.
...
@@ -20,6 +20,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import unittest
 import six
 import tensorflow.compat.v1 as tf
@@ -27,16 +28,15 @@ from google.protobuf import text_format
 from object_detection.builders import optimizer_builder
 from object_detection.protos import optimizer_pb2
+from object_detection.utils import tf_version

 # pylint: disable=g-import-not-at-top
-try:
-  from tensorflow.contrib import opt as contrib_opt
-except ImportError:
-  # TF 2.0 doesn't ship with contrib.
-  pass
+if tf_version.is_tf1():
+  from tensorflow.contrib import opt as contrib_opt
 # pylint: enable=g-import-not-at-top


+@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
 class LearningRateBuilderTest(tf.test.TestCase):

   def testBuildConstantLearningRate(self):
@@ -118,6 +118,7 @@ class LearningRateBuilderTest(tf.test.TestCase):
     optimizer_builder._create_learning_rate(learning_rate_proto)


+@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
 class OptimizerBuilderTest(tf.test.TestCase):

   def testBuildRMSPropOptimizer(self):
...
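Both optimizer files now branch on object_detection.utils.tf_version instead of catching an ImportError, which makes the TF1/TF2 split explicit at import time. The module itself is not shown in this diff; a plausible sketch, assuming the helpers key off the installed TensorFlow version string (the real implementation may use a different mechanism):

import tensorflow as tf


def is_tf2():
  # Assumption: TF2 is detected from the version string; the actual
  # tf_version module in the repository may decide this differently.
  return tf.__version__.startswith('2')


def is_tf1():
  return not is_tf2()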
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for optimizer_builder."""
import unittest
import tensorflow.compat.v1 as tf
from google.protobuf import text_format
from object_detection.builders import optimizer_builder
from object_detection.protos import optimizer_pb2
from object_detection.utils import tf_version
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class OptimizerBuilderV2Test(tf.test.TestCase):
"""Test building optimizers in V2 mode."""
def testBuildRMSPropOptimizer(self):
optimizer_text_proto = """
rms_prop_optimizer: {
learning_rate: {
exponential_decay_learning_rate {
initial_learning_rate: 0.004
decay_steps: 800720
decay_factor: 0.95
}
}
momentum_optimizer_value: 0.9
decay: 0.9
epsilon: 1.0
}
use_moving_average: false
"""
optimizer_proto = optimizer_pb2.Optimizer()
text_format.Merge(optimizer_text_proto, optimizer_proto)
optimizer, _ = optimizer_builder.build(optimizer_proto)
self.assertIsInstance(optimizer, tf.keras.optimizers.RMSprop)
def testBuildMomentumOptimizer(self):
optimizer_text_proto = """
momentum_optimizer: {
learning_rate: {
constant_learning_rate {
learning_rate: 0.001
}
}
momentum_optimizer_value: 0.99
}
use_moving_average: false
"""
optimizer_proto = optimizer_pb2.Optimizer()
text_format.Merge(optimizer_text_proto, optimizer_proto)
optimizer, _ = optimizer_builder.build(optimizer_proto)
self.assertIsInstance(optimizer, tf.keras.optimizers.SGD)
def testBuildAdamOptimizer(self):
optimizer_text_proto = """
adam_optimizer: {
learning_rate: {
constant_learning_rate {
learning_rate: 0.002
}
}
}
use_moving_average: false
"""
optimizer_proto = optimizer_pb2.Optimizer()
text_format.Merge(optimizer_text_proto, optimizer_proto)
optimizer, _ = optimizer_builder.build(optimizer_proto)
self.assertIsInstance(optimizer, tf.keras.optimizers.Adam)
def testMovingAverageOptimizerUnsupported(self):
optimizer_text_proto = """
adam_optimizer: {
learning_rate: {
constant_learning_rate {
learning_rate: 0.002
}
}
}
use_moving_average: True
"""
optimizer_proto = optimizer_pb2.Optimizer()
text_format.Merge(optimizer_text_proto, optimizer_proto)
with self.assertRaises(ValueError):
optimizer_builder.build(optimizer_proto)
if __name__ == '__main__':
tf.enable_v2_behavior()
tf.test.main()
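Read together, these cases pin down the TF2 contract of optimizer_builder.build: each optimizer proto maps to the matching tf.keras.optimizers class, and setting use_moving_average raises ValueError. A condensed usage sketch (the second return value is ignored here):

from google.protobuf import text_format

from object_detection.builders import optimizer_builder
from object_detection.protos import optimizer_pb2

optimizer_proto = optimizer_pb2.Optimizer()
text_format.Merge("""
  momentum_optimizer: {
    learning_rate: { constant_learning_rate { learning_rate: 0.001 } }
    momentum_optimizer_value: 0.9
  }
  use_moving_average: false
""", optimizer_proto)

# Under TF2 this yields a tf.keras.optimizers.SGD configured with momentum.
optimizer, _ = optimizer_builder.build(optimizer_proto)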
@@ -19,9 +19,10 @@ import tensorflow.compat.v1 as tf
 from google.protobuf import text_format
 from object_detection.builders import post_processing_builder
 from object_detection.protos import post_processing_pb2
+from object_detection.utils import test_case


-class PostProcessingBuilderTest(tf.test.TestCase):
+class PostProcessingBuilderTest(test_case.TestCase):

   def test_build_non_max_suppressor_with_correct_parameters(self):
     post_processing_text_proto = """
@@ -77,13 +78,12 @@ class PostProcessingBuilderTest(tf.test.TestCase):
     _, score_converter = post_processing_builder.build(
         post_processing_config)
     self.assertEqual(score_converter.__name__, 'identity_with_logit_scale')
-    inputs = tf.constant([1, 1], tf.float32)
-    outputs = score_converter(inputs)
-    with self.test_session() as sess:
-      converted_scores = sess.run(outputs)
-      expected_converted_scores = sess.run(inputs)
-      self.assertAllClose(converted_scores, expected_converted_scores)
+    def graph_fn():
+      inputs = tf.constant([1, 1], tf.float32)
+      outputs = score_converter(inputs)
+      return outputs
+    converted_scores = self.execute_cpu(graph_fn, [])
+    self.assertAllClose(converted_scores, [1, 1])

   def test_build_identity_score_converter_with_logit_scale(self):
     post_processing_text_proto = """
@@ -95,12 +95,12 @@ class PostProcessingBuilderTest(tf.test.TestCase):
     _, score_converter = post_processing_builder.build(post_processing_config)
     self.assertEqual(score_converter.__name__, 'identity_with_logit_scale')
-    inputs = tf.constant([1, 1], tf.float32)
-    outputs = score_converter(inputs)
-    with self.test_session() as sess:
-      converted_scores = sess.run(outputs)
-      expected_converted_scores = sess.run(tf.constant([.5, .5], tf.float32))
-      self.assertAllClose(converted_scores, expected_converted_scores)
+    def graph_fn():
+      inputs = tf.constant([1, 1], tf.float32)
+      outputs = score_converter(inputs)
+      return outputs
+    converted_scores = self.execute_cpu(graph_fn, [])
+    self.assertAllClose(converted_scores, [.5, .5])

   def test_build_sigmoid_score_converter(self):
     post_processing_text_proto = """
@@ -153,12 +153,12 @@ class PostProcessingBuilderTest(tf.test.TestCase):
     self.assertEqual(calibrated_score_conversion_fn.__name__,
                      'calibrate_with_function_approximation')
-    input_scores = tf.constant([1, 1], tf.float32)
-    outputs = calibrated_score_conversion_fn(input_scores)
-    with self.test_session() as sess:
-      calibrated_scores = sess.run(outputs)
-      expected_calibrated_scores = sess.run(tf.constant([0.5, 0.5], tf.float32))
-      self.assertAllClose(calibrated_scores, expected_calibrated_scores)
+    def graph_fn():
+      input_scores = tf.constant([1, 1], tf.float32)
+      outputs = calibrated_score_conversion_fn(input_scores)
+      return outputs
+    calibrated_scores = self.execute_cpu(graph_fn, [])
+    self.assertAllClose(calibrated_scores, [0.5, 0.5])

   def test_build_temperature_scaling_calibrator(self):
     post_processing_text_proto = """
@@ -174,12 +174,12 @@ class PostProcessingBuilderTest(tf.test.TestCase):
     self.assertEqual(calibrated_score_conversion_fn.__name__,
                      'calibrate_with_temperature_scaling_calibration')
-    input_scores = tf.constant([1, 1], tf.float32)
-    outputs = calibrated_score_conversion_fn(input_scores)
-    with self.test_session() as sess:
-      calibrated_scores = sess.run(outputs)
-      expected_calibrated_scores = sess.run(tf.constant([0.5, 0.5], tf.float32))
-      self.assertAllClose(calibrated_scores, expected_calibrated_scores)
+    def graph_fn():
+      input_scores = tf.constant([1, 1], tf.float32)
+      outputs = calibrated_score_conversion_fn(input_scores)
+      return outputs
+    calibrated_scores = self.execute_cpu(graph_fn, [])
+    self.assertAllClose(calibrated_scores, [0.5, 0.5])


 if __name__ == '__main__':
   tf.test.main()
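The migration pattern here mirrors the rest of the refactor: instead of opening a TF1 session via self.test_session(), each test wraps its computation in a graph_fn and dispatches it through test_case.TestCase.execute_cpu, which runs the function under a graph in TF1 and eagerly in TF2. A minimal sketch with a hypothetical doubling computation:

import tensorflow.compat.v1 as tf

from object_detection.utils import test_case


class ExecuteCpuPatternTest(test_case.TestCase):

  def test_double(self):
    def graph_fn():
      # Build the computation; execute_cpu decides how to run it.
      inputs = tf.constant([1.0, 2.0], tf.float32)
      return inputs * 2.0

    outputs = self.execute_cpu(graph_fn, [])
    self.assertAllClose(outputs, [2.0, 4.0])


if __name__ == '__main__':
  tf.test.main()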