Unverified commit 8518d053, authored by pkulzc, committed by GitHub

Open source MnasFPN and minor fixes to OD API (#8484)

310447280  by lzc:

    Internal change

310420845  by Zhichao Lu:

    Open source the internal Context RCNN code.

--
310362339  by Zhichao Lu:

    Internal change

310259448  by lzc:

    Update required TF version for OD API.

--
310252159  by Zhichao Lu:

    Port patch_ops_test to TF1/TF2 as well as TPUs.

--
310247180  by Zhichao Lu:

    Ignore keypoint heatmap loss in the regions/bounding boxes with target keypoint
    class but no valid keypoint annotations.

--
310178294  by Zhichao Lu:

    Open source MnasFPN
    https://arxiv.org/abs/1912.01106

--
310094222  by lzc:

    Internal changes.

--
310085250  by lzc:

    Internal Change.

--
310016447  by huizhongc:

    Remove unrecognized classes from labeled_classes.

--
310009470  by rathodv:

    Mark batcher.py as TF1 only.

--
310001984  by rathodv:

    Update core/preprocessor.py to be compatible with TF1/TF2.

--
309455035  by Zhichao Lu:

    Makes the freezable_batch_norm_test run w/ v2 behavior.

    The main change is that in v2, updates happen right away when running batchnorm in training mode. So we need to restore the weights between batchnorm calls to make sure the numerical checks all start from the same place.
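
    A minimal sketch of the v2 behavior (illustrative, not the actual test):
    a Keras BatchNormalization layer updates its moving statistics immediately
    when called with training=True, so repeated calls must snapshot and
    restore the weights in between.

        import numpy as np
        import tensorflow as tf

        bn = tf.keras.layers.BatchNormalization(momentum=0.9)
        x = tf.constant(np.random.normal(size=(8, 4)).astype(np.float32))

        bn(x, training=False)               # builds the layer's variables
        initial_weights = bn.get_weights()  # snapshot moving mean/variance

        bn(x, training=True)                # moving statistics update right away
        bn.set_weights(initial_weights)     # restore so the next numerical
        bn(x, training=True)                # check starts from the same place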

--
309425881  by Zhichao Lu:

    Make TF1/TF2 optimizer builder tests explicit.

--
309408646  by Zhichao Lu:

    Make dataset builder tests TF1 and TF2 compatible.

--
309246305  by Zhichao Lu:

    Added the functionality of combining the person keypoints and object detection
    annotations in the binary that converts the COCO raw data to TfRecord.

--
309125076  by Zhichao Lu:

    Convert target_assigner_utils to TF1/TF2.

--
308966359  by huizhongc:

    Support SSD training with partially labeled groundtruth.

--
308937159  by rathodv:

    Update core/target_assigner.py to be compatible with TF1/TF2.

--
308774302  by Zhichao Lu:

    Internal

--
308732860  by rathodv:

    Make core/prefetcher.py  compatible with TF1 only.

--
308726984  by rathodv:

    Update core/multiclass_nms_test.py to be TF1/TF2 compatible.

--
308714718  by rathodv:

    Update core/region_similarity_calculator_test.py to be TF1/TF2 compatible.

--
308707960  by rathodv:

    Update core/minibatch_sampler_test.py to be TF1/TF2 compatible.

--
308700595  by rathodv:

    Update core/losses_test.py to be TF1/TF2 compatible and remove losses_test_v2.py

--
308361472  by rathodv:

    Update core/matcher_test.py to be TF1/TF2 compatible.

--
308335846  by Zhichao Lu:

    Updated the COCO evaluation logic and populated the groundtruth area
    information through. This change matches the groundtruth format expected by
    the COCO keypoint evaluation.

--
308256924  by rathodv:

    Update core/keypoints_ops_test.py to be TF1/TF2 compatible.

--
308256826  by rathodv:

    Update class_agnostic_nms_test.py to be TF1/TF2 compatible.

--
308256112  by rathodv:

    Update box_list_ops_test.py to be TF1/TF2 compatible.

--
308159360  by Zhichao Lu:

    Internal change

308145008  by Zhichao Lu:

    Added 'image/class/confidence' field in the TFExample decoder.

--
307651875  by rathodv:

    Refactor core/box_list.py to support TF1/TF2.

--
307651798  by rathodv:

    Modify box_coder.py base class to work with TF1/TF2

--
307651652  by rathodv:

    Refactor core/balanced_positive_negative_sampler.py to support TF1/TF2.

--
307651571  by rathodv:

    Modify BoxCoders tests to use test_case:execute method to allow testing with TF1.X and TF2.X

--
307651480  by rathodv:

    Modify Matcher tests to use test_case:execute method to allow testing with TF1.X and TF2.X

--
307651409  by rathodv:

    Modify AnchorGenerator tests to use test_case:execute method to allow testing with TF1.X and TF2.X

--
307651314  by rathodv:

    Refactor model_builder to support TF1 or TF2 models based on TensorFlow version.

--
307092053  by Zhichao Lu:

    Use manager to save checkpoint.

--
307071352  by ronnyvotel:

    Fixing keypoint visibilities. Now by default, the visibility is marked True if the keypoint is labeled (regardless of whether it is visible or not).
    Also, if visibilities are not present in the dataset, they will be created based on whether the keypoint coordinates are finite (vis = True) or NaN (vis = False).
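
    A small sketch of the second rule (the helper name is illustrative, not
    necessarily the codebase's): a keypoint counts as visible exactly when its
    coordinates are finite.

        import tensorflow as tf

        def visibilities_from_coordinates(keypoints):
          """keypoints: [num_instances, num_keypoints, 2], NaN if unlabeled."""
          return tf.reduce_all(tf.math.is_finite(keypoints), axis=-1)

        kpts = tf.constant([[[0.1, 0.2], [float('nan'), float('nan')]]])
        print(visibilities_from_coordinates(kpts))  # [[ True False]]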

--
307069557  by Zhichao Lu:

    Internal change to add few fields related to postprocessing parameters in
    center_net.proto and populate those parameters to the keypoint postprocessing
    functions.

--
307012091  by Zhichao Lu:

    Make Adam Optimizer's epsilon proto configurable.

    Potential issue: tf.compat.v1's AdamOptimizer has a default epsilon of 1e-08 ([doc-link](https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/AdamOptimizer)) whereas tf.keras's Adam optimizer has a default epsilon of 1e-07 ([doc-link](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adam))
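
    A short illustration of the mismatch and of pinning epsilon explicitly
    (a sketch, not the builder code):

        import tensorflow.compat.v1 as tf1
        import tensorflow as tf

        v1_adam = tf1.train.AdamOptimizer(learning_rate=1e-3)      # epsilon defaults to 1e-08
        keras_adam = tf.keras.optimizers.Adam(learning_rate=1e-3)  # epsilon defaults to 1e-07

        # With epsilon exposed in the proto, a config can pin it, e.g. to
        # reproduce the v1 default when using the Keras optimizer:
        keras_adam_v1_like = tf.keras.optimizers.Adam(learning_rate=1e-3,
                                                      epsilon=1e-8)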

--
306858598  by Zhichao Lu:

    Internal changes to update the CenterNet model:
    1) Modified eval job loss computation to avoid averaging over batches with zero loss.
    2) Updated CenterNet keypoint heatmap target assigner to apply box size to heatmap Gaussian standard deviation.
    3) Updated the CenterNet meta arch keypoint losses computation to apply weights outside of loss function.

--
306731223  by jonathanhuang:

    Internal change.

--
306549183  by rathodv:

    Internal Update.

--
306542930  by rathodv:

    Internal Update

--
306322697  by rathodv:

    Internal.

--
305345036  by Zhichao Lu:

    Adding COCO Camera Traps JSON to tf.Example Beam code

--
304104869  by lzc:

    Internal changes.

--
304068971  by jonathanhuang:

    Internal change.

--
304050469  by Zhichao Lu:

    Internal change.

--
303880642  by huizhongc:

    Support parsing partially labeled groundtruth.

--
303841743  by Zhichao Lu:

    Deprecate nms_on_host in SSDMetaArch.

--
303803204  by rathodv:

    Internal change.

--
303793895  by jonathanhuang:

    Internal change.

--
303467631  by rathodv:

    Py3 update for detection inference test.

--
303444542  by rathodv:

    Py3 update to metrics module

--
303421960  by rathodv:

    Update json_utils to python3.

--
302787583  by ronnyvotel:

    Coco results generator for submission to the coco test server.

--
302719091  by Zhichao Lu:

    Internal change to add the ResNet50 image feature extractor for CenterNet model.

--
302116230  by Zhichao Lu:

    Added functions to overlay heatmaps on images in the visualization util
    library.

--
301888316  by Zhichao Lu:

    Fix checkpoint_filepath not defined error.

--
301840312  by ronnyvotel:

    Adding keypoint_scores to visualizations.

--
301683475  by ronnyvotel:

    Introducing the ability to preprocess `keypoint_visibilities`.

    Some data augmentation ops such as random crop can filter instances and keypoints. It's important to also filter keypoint visibilities, so that the groundtruth tensors are always in alignment.

--
301532344  by Zhichao Lu:

    Don't use tf.divide since "Quantization not yet supported for op: DIV"

--
301480348  by ronnyvotel:

    Introducing keypoint evaluation into model lib v2.
    Also, making some fixes to coco keypoint evaluation.

--
301454018  by Zhichao Lu:

    Added image summaries to visualize the train/eval input images and eval's
    side-by-side prediction/groundtruth image.

--
301317527  by Zhichao Lu:

    Updated the random_absolute_pad_image function in the preprocessor library to
    support the keypoints argument.

--
301300324  by Zhichao Lu:

    Apply name change (experimental_run_v2 -> run) for all callers in TensorFlow.

--
301297115  by ronnyvotel:

    Utility function for setting keypoint visibilities based on keypoint coordinates.

--
301248885  by Zhichao Lu:

    Allow MultiWorkerMirroredStrategy (MWMS) use by adding checkpoint handling with temporary directories in model_lib_v2. Added the missing WeakKeyDictionary cfer_fn_cache field in CollectiveAllReduceStrategyExtended.
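
    The usual pattern behind this (a sketch of the idea, not the model_lib_v2
    code): only the chief writes checkpoints to the real model directory,
    while other workers write to throwaway temporary directories so they
    still participate in collective ops without clobbering the chief's files.

        import tempfile

        def _is_chief(task_type, task_id):
          return task_type is None or task_type == 'chief' or (
              task_type == 'worker' and task_id == 0)

        def checkpoint_dir(model_dir, task_type, task_id):
          if _is_chief(task_type, task_id):
            return model_dir
          return tempfile.mkdtemp()  # discarded after the save completes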

--
301224559  by Zhichao Lu:

    1) Fixes model_lib to also use keypoints while preparing model groundtruth.
    2) Tests model_lib with newly added keypoint metrics config.

--
300836556  by Zhichao Lu:

    Internal changes to add keypoint estimation parameters in CenterNet proto.

--
300795208  by Zhichao Lu:

    Updated the eval_util library to populate the keypoint groundtruth to
    eval_dict.

--
299474766  by Zhichao Lu:

    Modifies eval_util to create Keypoint Evaluator objects when configured in eval config.

--
299453920  by Zhichao Lu:

    Add swish activation as a hyperparams option.

--
299240093  by ronnyvotel:

    Keypoint postprocessing for CenterNetMetaArch.

--
299176395  by Zhichao Lu:

    Internal change.

--
299135608  by Zhichao Lu:

    Internal changes to refactor the CenterNet model in preparation for keypoint estimation tasks.

--
298915482  by Zhichao Lu:

    Make dataset_builder aware of input_context for distributed training.

--
298713595  by Zhichao Lu:

    Handling data with negative size boxes.

--
298695964  by Zhichao Lu:

    Expose change_coordinate_frame as a config parameter; fix multiclass_scores optional field.

--
298492150  by Zhichao Lu:

    Rename optimizer_builder_test_v2.py -> optimizer_builder_v2_test.py

--
298476471  by Zhichao Lu:

    Internal changes to support CenterNet keypoint estimation.

--
298365851  by ronnyvotel:

    Fixing a bug where groundtruth_keypoint_weights were being padded with a dynamic dimension.

--
297843700  by Zhichao Lu:

    Internal change.

--
297706988  by lzc:

    Internal change.

--
297705287  by ronnyvotel:

    Creating the "snapping" behavior in CenterNet, where regressed keypoints are refined with updated candidate keypoints from a heatmap.

--
297700447  by Zhichao Lu:

    Improve checkpoint checking logic with TF2 loop.

--
297686094  by Zhichao Lu:

    Convert "import tensorflow as tf" to "import tensorflow.compat.v1".

--
297670468  by lzc:

    Internal change.

--
297241327  by Zhichao Lu:

    Convert "import tensorflow as tf" to "import tensorflow.compat.v1".

--
297205959  by Zhichao Lu:

    Internal changes to refactor the CenterNet object detection target assigner into a separate library.

--
297143806  by Zhichao Lu:

    Convert "import tensorflow as tf" to "import tensorflow.compat.v1".

--
297129625  by Zhichao Lu:

    Explicitly replace "import tensorflow" with "tensorflow.compat.v1" for TF2.x migration

--
297117070  by Zhichao Lu:

    Explicitly replace "import tensorflow" with "tensorflow.compat.v1" for TF2.x migration

--
297030190  by Zhichao Lu:

    Add configuration options for visualizing keypoint edges

--
296359649  by Zhichao Lu:

    Support DepthwiseConv2dNative (of separable conv) in weight equalization loss.

--
296290582  by Zhichao Lu:

    Internal change.

--
296093857  by Zhichao Lu:

    Internal changes to add general target assigner utilities.

--
295975116  by Zhichao Lu:

    Fix visualize_boxes_and_labels_on_image_array to show max_boxes_to_draw correctly.

--
295819711  by Zhichao Lu:

    Adds a flag to visualize_boxes_and_labels_on_image_array to skip the drawing of axis aligned bounding boxes.

--
295811929  by Zhichao Lu:

    Keypoint support in random_square_crop_by_scale.

--
295788458  by rathodv:

    Remove unused checkpoint to reduce repo size on github

--
295787184  by Zhichao Lu:

    Enable visualization of edges between keypoints

--
295763508  by Zhichao Lu:

    [Context RCNN] Add an option to enable / disable the cropping feature in the
    post-process step in the meta architecture.

--
295605344  by Zhichao Lu:

    internal change.

--
294926050  by ronnyvotel:

    Adding per-keypoint groundtruth weights. These weights are intended to be used as multipliers in a keypoint loss function.

    Groundtruth keypoint weights are constructed as follows:
    - Initialize the weight for each keypoint type based on user-specified weights in the input_reader proto
    - Mask out (i.e. make zero) all keypoint weights that are not visible.
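
    A compact sketch of that construction (names are illustrative):

        import tensorflow as tf

        def keypoint_weights(visibilities, per_type_weights):
          """visibilities: bool [num_instances, num_keypoints];
          per_type_weights: num_keypoints floats from the input_reader proto."""
          weights = tf.constant(per_type_weights)[tf.newaxis, :]
          return weights * tf.cast(visibilities, tf.float32)  # zero if hidden

        vis = tf.constant([[True, False, True]])
        print(keypoint_weights(vis, [1.0, 1.0, 0.5]))  # [[1.0, 0.0, 0.5]]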

--
294829061  by lzc:

    Internal change.

--
294566503  by Zhichao Lu:

    Changed internal CenterNet Model configuration.

--
294346662  by ronnyvotel:

    Using NaN values in keypoint coordinates that are not visible.

--
294333339  by Zhichao Lu:

    Change experimental_distribute_dataset -> experimental_distribute_datasets_from_function

--
293928752  by Zhichao Lu:

    Internal change

--
293909384  by Zhichao Lu:

    Add capabilities to train 1024x1024 CenterNet models.

--
293637554  by ronnyvotel:

    Adding keypoint visibilities to TfExampleDecoder.

--
293501558  by lzc:

    Internal change.

--
293252851  by Zhichao Lu:

    Change tf.gfile.GFile to tf.io.gfile.GFile.

--
292730217  by Zhichao Lu:

    Internal change.

--
292456563  by lzc:

    Internal changes.

--
292355612  by Zhichao Lu:

    Use tf.gather and tf.scatter_nd instead of matrix ops.
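
    For example (illustrative only):

        import tensorflow as tf

        values = tf.constant([10.0, 20.0, 30.0])
        indices = tf.constant([2, 0])

        # Sparse lookup...
        selected = tf.gather(values, indices)               # [30.0, 10.0]
        # ...instead of the equivalent dense matrix product:
        one_hot = tf.one_hot(indices, depth=3)              # [2, 3]
        selected_dense = tf.linalg.matvec(one_hot, values)  # same result, more FLOPs

        # Likewise, scatter with tf.scatter_nd rather than a dense matrix op:
        updates = tf.constant([1.0, 2.0])
        scattered = tf.scatter_nd(indices[:, tf.newaxis], updates, shape=[3])
        # -> [2.0, 0.0, 1.0]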

--
292245265  by rathodv:

    Internal

--
291989323  by richardmunoz:

    Refactor out building a DataDecoder from building a tf.data.Dataset.

--
291950147  by Zhichao Lu:

    Flip bounding boxes in arbitrarily shaped tensors.
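
    A minimal sketch of a rank-agnostic horizontal flip for normalized
    [ymin, xmin, ymax, xmax] boxes (illustrative, not the library function):

        import tensorflow as tf

        def flip_boxes_left_right(boxes):
          """boxes: [..., 4] in normalized coordinates."""
          ymin, xmin, ymax, xmax = tf.unstack(boxes, axis=-1)
          return tf.stack([ymin, 1.0 - xmax, ymax, 1.0 - xmin], axis=-1)

        boxes = tf.constant([[[0.1, 0.2, 0.5, 0.4]]])  # any leading dims work
        print(flip_boxes_left_right(boxes))            # [[[0.1 0.6 0.5 0.8]]]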

--
291401052  by huizhongc:

    Fix multiscale grid anchor generator to allow fully convolutional inference. When exporting a model with identity_resizer as the image_resizer, there is an incorrect box offset in the detection results. We add the anchor offset to address this problem.

--
291298871  by Zhichao Lu:

    Py3 compatibility changes.

--
290957957  by Zhichao Lu:

    Hourglass feature extractor for CenterNet.

--
290564372  by Zhichao Lu:

    Internal change.

--
290155278  by rathodv:

    Remove Dataset Explorer.

--
290155153  by Zhichao Lu:

    Internal change

--
290122054  by Zhichao Lu:

    Unify the format in the faster_rcnn.proto

--
290116084  by Zhichao Lu:

    Deprecate tensorflow.contrib.

--
290100672  by Zhichao Lu:

    Update MobilenetV3 SSD candidates

--
289926392  by Zhichao Lu:

    Internal change

--
289553440  by Zhichao Lu:

    [Object Detection API] Fix the comments about the dimension of the rpn_box_encodings from 4-D to 3-D.

--
288994128  by lzc:

    Internal changes.

--
288942194  by lzc:

    Internal change.

--
288746124  by Zhichao Lu:

    Configurable channel mean/std. dev in CenterNet feature extractors.

--
288552509  by rathodv:

    Internal.

--
288541285  by rathodv:

    Internal update.

--
288396396  by Zhichao Lu:

    Make object detection import contrib explicitly

--
288255791  by rathodv:

    Internal

--
288078600  by Zhichao Lu:

    Fix model_lib_v2 test

--
287952244  by rathodv:

    Internal

--
287921774  by Zhichao Lu:

    internal change

--
287906173  by Zhichao Lu:

    internal change

--
287889407  by jonathanhuang:

    PY3 compatibility

--
287889042  by rathodv:

    Internal

--
287876178  by Zhichao Lu:

    Internal change.

--
287770490  by Zhichao Lu:

    Add CenterNet proto and builder

--
287694213  by Zhichao Lu:

    Support for running multiple steps per tf.function call.
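
    A minimal sketch of the pattern (toy model and data, not the OD API train
    loop): looping over tf.range inside the traced function amortizes the
    per-call overhead across many optimizer steps.

        import tensorflow as tf

        model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
        optimizer = tf.keras.optimizers.SGD(0.01)

        @tf.function
        def train_steps(iterator, num_steps):
          for _ in tf.range(num_steps):  # staged into the graph once
            x, y = next(iterator)
            with tf.GradientTape() as tape:
              loss = tf.reduce_mean(tf.square(model(x) - y))
            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

        ds = tf.data.Dataset.from_tensors(
            (tf.ones([4, 3]), tf.ones([4, 1]))).repeat()
        train_steps(iter(ds), tf.constant(100))  # one call, 100 optimizer steps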

--
287377183  by jonathanhuang:

    PY3 compatibility

--
287371344  by rathodv:

    Support loading keypoint labels and ids.

--
287368213  by rathodv:

    Add protos supporting keypoint evaluation.

--
286673200  by rathodv:

    dataset_tools PY3 migration

--
286635106  by Zhichao Lu:

    Update code for upcoming tf.contrib removal

--
286479439  by Zhichao Lu:

    Internal change

--
286311711  by Zhichao Lu:

    Skeleton of context model within TFODAPI

--
286005546  by Zhichao Lu:

    Fix Faster-RCNN training when using keep_aspect_ratio_resizer with pad_to_max_dimension

--
285906400  by derekjchow:

    Internal change

--
285822795  by Zhichao Lu:

    Add CenterNet meta arch target assigners.

--
285447238  by Zhichao Lu:

    Internal changes.

--
285016927  by Zhichao Lu:

    Make _dummy_computation a tf.function. This fixes breakage caused by
    cl/284256438

--
284827274  by Zhichao Lu:

    Convert to python 3.

--
284645593  by rathodv:

    Internal change

--
284639893  by rathodv:

    Add missing documentation for keypoints in eval_util.py.

--
284323712  by Zhichao Lu:

    Internal changes.

--
284295290  by Zhichao Lu:

    Updating input config proto and dataset builder to include context fields

    Updating standard_fields and tf_example_decoder to include context features

--
284226821  by derekjchow:

    Update exporter.

--
284211030  by Zhichao Lu:

    API changes in CenterNet informed by the experiments with the hourglass network.

--
284190451  by Zhichao Lu:

    Add support for CenterNet losses in protos and builders.

--
284093961  by lzc:

    Internal changes.

--
284028174  by Zhichao Lu:

    Internal change

--
284014719  by derekjchow:

    Do not pad top_down feature maps unnecessarily.

--
284005765  by Zhichao Lu:

    Add new pad_to_multiple_resizer

--
283858233  by Zhichao Lu:

    Make target assigner work when under tf.function.

--
283836611  by Zhichao Lu:

    Make config getters more general.

--
283808990  by Zhichao Lu:

    Internal change

--
283754588  by Zhichao Lu:

    Internal changes.

--
282460301  by Zhichao Lu:

    Add ability to restore v2 style checkpoints.

--
281605842  by lzc:

    Add option to disable loss computation in OD API eval job.

--
280298212  by Zhichao Lu:

    Add backwards compatible change

--
280237857  by Zhichao Lu:

    internal change

--

PiperOrigin-RevId: 310447280
parent ac5fff19
@@ -1,3 +1,4 @@
+# Lint as: python2, python3
 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -24,7 +25,13 @@ from object_detection.builders import hyperparams_builder
 from object_detection.core import freezable_batch_norm
 from object_detection.protos import hyperparams_pb2
 
-slim = tf.contrib.slim
+# pylint: disable=g-import-not-at-top
+try:
+  from tensorflow.contrib import slim
+except ImportError:
+  # TF 2.0 doesn't ship with contrib.
+  pass
+# pylint: enable=g-import-not-at-top
 
 
 def _get_scope_key(op):
@@ -49,7 +56,7 @@ class HyperparamsBuilderTest(tf.test.TestCase):
     scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
                                          is_training=True)
     scope = scope_fn()
-    self.assertTrue(_get_scope_key(slim.conv2d) in scope)
+    self.assertIn(_get_scope_key(slim.conv2d), scope)
 
   def test_default_arg_scope_has_separable_conv2d_op(self):
     conv_hyperparams_text_proto = """
@@ -67,7 +74,7 @@ class HyperparamsBuilderTest(tf.test.TestCase):
     scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
                                          is_training=True)
     scope = scope_fn()
-    self.assertTrue(_get_scope_key(slim.separable_conv2d) in scope)
+    self.assertIn(_get_scope_key(slim.separable_conv2d), scope)
 
   def test_default_arg_scope_has_conv2d_transpose_op(self):
     conv_hyperparams_text_proto = """
@@ -85,7 +92,7 @@ class HyperparamsBuilderTest(tf.test.TestCase):
     scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
                                          is_training=True)
     scope = scope_fn()
-    self.assertTrue(_get_scope_key(slim.conv2d_transpose) in scope)
+    self.assertIn(_get_scope_key(slim.conv2d_transpose), scope)
 
   def test_explicit_fc_op_arg_scope_has_fully_connected_op(self):
     conv_hyperparams_text_proto = """
@@ -104,7 +111,7 @@ class HyperparamsBuilderTest(tf.test.TestCase):
     scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
                                          is_training=True)
     scope = scope_fn()
-    self.assertTrue(_get_scope_key(slim.fully_connected) in scope)
+    self.assertIn(_get_scope_key(slim.fully_connected), scope)
 
   def test_separable_conv2d_and_conv2d_and_transpose_have_same_parameters(self):
     conv_hyperparams_text_proto = """
@@ -143,7 +150,7 @@ class HyperparamsBuilderTest(tf.test.TestCase):
     scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
                                          is_training=True)
     scope = scope_fn()
-    conv_scope_arguments = scope.values()[0]
+    conv_scope_arguments = list(scope.values())[0]
 
     regularizer = conv_scope_arguments['weights_regularizer']
     weights = np.array([1., -1, 4., 2.])
     with self.test_session() as sess:
@@ -284,8 +291,8 @@ class HyperparamsBuilderTest(tf.test.TestCase):
     self.assertTrue(batch_norm_params['scale'])
     batch_norm_layer = keras_config.build_batch_norm()
-    self.assertTrue(isinstance(batch_norm_layer,
-                               freezable_batch_norm.FreezableBatchNorm))
+    self.assertIsInstance(batch_norm_layer,
+                          freezable_batch_norm.FreezableBatchNorm)
 
   def test_return_non_default_batch_norm_params_keras_override(
       self):
@@ -420,8 +427,8 @@ class HyperparamsBuilderTest(tf.test.TestCase):
     # The batch norm builder should build an identity Lambda layer
     identity_layer = keras_config.build_batch_norm()
-    self.assertTrue(isinstance(identity_layer,
-                               tf.keras.layers.Lambda))
+    self.assertIsInstance(identity_layer,
+                          tf.keras.layers.Lambda)
 
   def test_use_none_activation(self):
     conv_hyperparams_text_proto = """
@@ -463,7 +470,7 @@ class HyperparamsBuilderTest(tf.test.TestCase):
     self.assertEqual(
         keras_config.params(include_activation=True)['activation'], None)
     activation_layer = keras_config.build_activation_layer()
-    self.assertTrue(isinstance(activation_layer, tf.keras.layers.Lambda))
+    self.assertIsInstance(activation_layer, tf.keras.layers.Lambda)
     self.assertEqual(activation_layer.function, tf.identity)
 
   def test_use_relu_activation(self):
@@ -506,7 +513,7 @@ class HyperparamsBuilderTest(tf.test.TestCase):
     self.assertEqual(
         keras_config.params(include_activation=True)['activation'], tf.nn.relu)
     activation_layer = keras_config.build_activation_layer()
-    self.assertTrue(isinstance(activation_layer, tf.keras.layers.Lambda))
+    self.assertIsInstance(activation_layer, tf.keras.layers.Lambda)
     self.assertEqual(activation_layer.function, tf.nn.relu)
 
   def test_use_relu_6_activation(self):
@@ -549,9 +556,52 @@ class HyperparamsBuilderTest(tf.test.TestCase):
     self.assertEqual(
         keras_config.params(include_activation=True)['activation'], tf.nn.relu6)
     activation_layer = keras_config.build_activation_layer()
-    self.assertTrue(isinstance(activation_layer, tf.keras.layers.Lambda))
+    self.assertIsInstance(activation_layer, tf.keras.layers.Lambda)
     self.assertEqual(activation_layer.function, tf.nn.relu6)
 
+  def test_use_swish_activation(self):
+    conv_hyperparams_text_proto = """
+      regularizer {
+        l2_regularizer {
+        }
+      }
+      initializer {
+        truncated_normal_initializer {
+        }
+      }
+      activation: SWISH
+    """
+    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
+    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
+    scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
+                                         is_training=True)
+    scope = scope_fn()
+    conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
+    self.assertEqual(conv_scope_arguments['activation_fn'], tf.nn.swish)
+
+  def test_use_swish_activation_keras(self):
+    conv_hyperparams_text_proto = """
+      regularizer {
+        l2_regularizer {
+        }
+      }
+      initializer {
+        truncated_normal_initializer {
+        }
+      }
+      activation: SWISH
+    """
+    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
+    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
+    keras_config = hyperparams_builder.KerasLayerHyperparams(
+        conv_hyperparams_proto)
+    self.assertEqual(keras_config.params()['activation'], None)
+    self.assertEqual(
+        keras_config.params(include_activation=True)['activation'], tf.nn.swish)
+    activation_layer = keras_config.build_activation_layer()
+    self.assertIsInstance(activation_layer, tf.keras.layers.Lambda)
+    self.assertEqual(activation_layer.function, tf.nn.swish)
+
   def test_override_activation_keras(self):
     conv_hyperparams_text_proto = """
       regularizer {
...
@@ -133,9 +133,22 @@ def build(image_resizer_config):
           'Invalid image resizer condition option for '
           'ConditionalShapeResizer: \'%s\'.'
           % conditional_shape_resize_config.condition)
 
     if not conditional_shape_resize_config.convert_to_grayscale:
       return image_resizer_fn
+  elif image_resizer_oneof == 'pad_to_multiple_resizer':
+    pad_to_multiple_resizer_config = (
+        image_resizer_config.pad_to_multiple_resizer)
+
+    if pad_to_multiple_resizer_config.multiple < 0:
+      raise ValueError('`multiple` for pad_to_multiple_resizer should be > 0.')
+
+    else:
+      image_resizer_fn = functools.partial(
+          preprocessor.resize_pad_to_multiple,
+          multiple=pad_to_multiple_resizer_config.multiple)
+
+    if not pad_to_multiple_resizer_config.convert_to_grayscale:
+      return image_resizer_fn
   else:
     raise ValueError(
         'Invalid image resizer option: \'%s\'.' % image_resizer_oneof)
@@ -149,16 +162,16 @@ def build(image_resizer_config):
         width] containing instance masks.
 
     Returns:
      Note that the position of the resized_image_shape changes based on whether
      masks are present.
      resized_image: A 3D tensor of shape [new_height, new_width, 1],
        where the image has been resized (with bilinear interpolation) so that
        min(new_height, new_width) == min_dimension or
        max(new_height, new_width) == max_dimension.
      resized_masks: If masks is not None, also outputs masks. A 3D tensor of
        shape [num_instances, new_height, new_width].
      resized_image_shape: A 1D tensor of shape [3] containing shape of the
        resized image.
     """
     # image_resizer_fn returns [resized_image, resized_image_shape] if
     # mask==None, otherwise it returns
...
@@ -211,6 +211,31 @@ class ImageResizerBuilderTest(tf.test.TestCase):
     with self.assertRaises(ValueError):
       image_resizer_builder.build(invalid_image_resizer_text_proto)
 
+  def test_build_pad_to_multiple_resizer(self):
+    """Test building a pad_to_multiple_resizer from proto."""
+    image_resizer_text_proto = """
+      pad_to_multiple_resizer {
+        multiple: 32
+      }
+    """
+    input_shape = (60, 30, 3)
+    expected_output_shape = (64, 32, 3)
+    output_shape = self._shape_of_resized_random_image_given_text_proto(
+        input_shape, image_resizer_text_proto)
+    self.assertEqual(output_shape, expected_output_shape)
+
+  def test_build_pad_to_multiple_resizer_invalid_multiple(self):
+    """Test that building a pad_to_multiple_resizer errors with invalid multiple."""
+    image_resizer_text_proto = """
+      pad_to_multiple_resizer {
+        multiple: -10
+      }
+    """
+    with self.assertRaises(ValueError):
+      image_resizer_builder.build(image_resizer_text_proto)
+
 
 if __name__ == '__main__':
   tf.test.main()
@@ -1,3 +1,4 @@
+# Lint as: python2, python3
 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -23,12 +24,24 @@ Detection configuration framework, they should define their own builder function
 that wraps the build function.
 """
 
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
 import tensorflow as tf
 
 from object_detection.data_decoders import tf_example_decoder
 from object_detection.protos import input_reader_pb2
 
-parallel_reader = tf.contrib.slim.parallel_reader
+# pylint: disable=g-import-not-at-top
+try:
+  from tensorflow.contrib import slim as contrib_slim
+except ImportError:
+  # TF 2.0 doesn't ship with contrib.
+  pass
+# pylint: enable=g-import-not-at-top
+
+parallel_reader = contrib_slim.parallel_reader
 
 
 def build(input_reader_config):
@@ -70,7 +83,8 @@ def build(input_reader_config):
     decoder = tf_example_decoder.TfExampleDecoder(
         load_instance_masks=input_reader_config.load_instance_masks,
         instance_mask_type=input_reader_config.mask_type,
-        label_map_proto_file=label_map_proto_file)
+        label_map_proto_file=label_map_proto_file,
+        load_context_features=input_reader_config.load_context_features)
     return decoder.decode(string_tensor)
 
   raise ValueError('Unsupported input_reader_config.')
@@ -54,6 +54,48 @@ class InputReaderBuilderTest(tf.test.TestCase):
 
     return path
 
+  def create_tf_record_with_context(self):
+    path = os.path.join(self.get_temp_dir(), 'tfrecord')
+    writer = tf.python_io.TFRecordWriter(path)
+
+    image_tensor = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
+    flat_mask = (4 * 5) * [1.0]
+    context_features = (10 * 3) * [1.0]
+    with self.test_session():
+      encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).eval()
+    example = tf.train.Example(
+        features=tf.train.Features(
+            feature={
+                'image/encoded':
+                    dataset_util.bytes_feature(encoded_jpeg),
+                'image/format':
+                    dataset_util.bytes_feature('jpeg'.encode('utf8')),
+                'image/height':
+                    dataset_util.int64_feature(4),
+                'image/width':
+                    dataset_util.int64_feature(5),
+                'image/object/bbox/xmin':
+                    dataset_util.float_list_feature([0.0]),
+                'image/object/bbox/xmax':
+                    dataset_util.float_list_feature([1.0]),
+                'image/object/bbox/ymin':
+                    dataset_util.float_list_feature([0.0]),
+                'image/object/bbox/ymax':
+                    dataset_util.float_list_feature([1.0]),
+                'image/object/class/label':
+                    dataset_util.int64_list_feature([2]),
+                'image/object/mask':
+                    dataset_util.float_list_feature(flat_mask),
+                'image/context_features':
+                    dataset_util.float_list_feature(context_features),
+                'image/context_feature_length':
+                    dataset_util.int64_list_feature([10]),
+            }))
+    writer.write(example.SerializeToString())
+    writer.close()
+
+    return path
+
   def test_build_tf_record_input_reader(self):
     tf_record_path = self.create_tf_record()
 
@@ -71,18 +113,53 @@ class InputReaderBuilderTest(tf.test.TestCase):
     with tf.train.MonitoredSession() as sess:
       output_dict = sess.run(tensor_dict)
 
-    self.assertTrue(fields.InputDataFields.groundtruth_instance_masks
-                    not in output_dict)
-    self.assertEquals(
-        (4, 5, 3), output_dict[fields.InputDataFields.image].shape)
-    self.assertEquals(
-        [2], output_dict[fields.InputDataFields.groundtruth_classes])
-    self.assertEquals(
-        (1, 4), output_dict[fields.InputDataFields.groundtruth_boxes].shape)
+    self.assertNotIn(fields.InputDataFields.groundtruth_instance_masks,
+                     output_dict)
+    self.assertEqual((4, 5, 3), output_dict[fields.InputDataFields.image].shape)
+    self.assertEqual([2],
+                     output_dict[fields.InputDataFields.groundtruth_classes])
+    self.assertEqual(
+        (1, 4), output_dict[fields.InputDataFields.groundtruth_boxes].shape)
     self.assertAllEqual(
         [0.0, 0.0, 1.0, 1.0],
         output_dict[fields.InputDataFields.groundtruth_boxes][0])
 
+  def test_build_tf_record_input_reader_with_context(self):
+    tf_record_path = self.create_tf_record_with_context()
+
+    input_reader_text_proto = """
+      shuffle: false
+      num_readers: 1
+      tf_record_input_reader {{
+        input_path: '{0}'
+      }}
+    """.format(tf_record_path)
+    input_reader_proto = input_reader_pb2.InputReader()
+    text_format.Merge(input_reader_text_proto, input_reader_proto)
+    input_reader_proto.load_context_features = True
+    tensor_dict = input_reader_builder.build(input_reader_proto)
+
+    with tf.train.MonitoredSession() as sess:
+      output_dict = sess.run(tensor_dict)
+
+    self.assertNotIn(fields.InputDataFields.groundtruth_instance_masks,
+                     output_dict)
+    self.assertEqual((4, 5, 3), output_dict[fields.InputDataFields.image].shape)
+    self.assertEqual([2],
+                     output_dict[fields.InputDataFields.groundtruth_classes])
+    self.assertEqual(
+        (1, 4), output_dict[fields.InputDataFields.groundtruth_boxes].shape)
+    self.assertAllEqual(
+        [0.0, 0.0, 1.0, 1.0],
+        output_dict[fields.InputDataFields.groundtruth_boxes][0])
+    self.assertAllEqual(
+        (3, 10), output_dict[fields.InputDataFields.context_features].shape)
+    self.assertAllEqual(
+        (10), output_dict[fields.InputDataFields.context_feature_length])
+
   def test_build_tf_record_input_reader_and_load_instance_masks(self):
     tf_record_path = self.create_tf_record()
 
@@ -101,11 +178,10 @@ class InputReaderBuilderTest(tf.test.TestCase):
     with tf.train.MonitoredSession() as sess:
       output_dict = sess.run(tensor_dict)
 
-    self.assertEquals(
-        (4, 5, 3), output_dict[fields.InputDataFields.image].shape)
-    self.assertEquals(
-        [2], output_dict[fields.InputDataFields.groundtruth_classes])
-    self.assertEquals(
-        (1, 4), output_dict[fields.InputDataFields.groundtruth_boxes].shape)
+    self.assertEqual((4, 5, 3), output_dict[fields.InputDataFields.image].shape)
+    self.assertEqual([2],
+                     output_dict[fields.InputDataFields.groundtruth_classes])
+    self.assertEqual(
+        (1, 4), output_dict[fields.InputDataFields.groundtruth_boxes].shape)
     self.assertAllEqual(
         [0.0, 0.0, 1.0, 1.0],
...
@@ -201,6 +201,9 @@ def _build_localization_loss(loss_config):
   if loss_type == 'weighted_iou':
     return losses.WeightedIOULocalizationLoss()
 
+  if loss_type == 'l1_localization_loss':
+    return losses.L1LocalizationLoss()
+
   raise ValueError('Empty loss config.')
 
@@ -249,4 +252,9 @@ def _build_classification_loss(loss_config):
         alpha=config.alpha,
         bootstrap_type=('hard' if config.hard_bootstrap else 'soft'))
 
+  if loss_type == 'penalty_reduced_logistic_focal_loss':
+    config = loss_config.penalty_reduced_logistic_focal_loss
+    return losses.PenaltyReducedLogisticFocalLoss(
+        alpha=config.alpha, beta=config.beta)
+
   raise ValueError('Empty loss config.')
@@ -40,8 +40,8 @@ class LocalizationLossBuilderTest(tf.test.TestCase):
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     _, localization_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
-    self.assertTrue(isinstance(localization_loss,
-                               losses.WeightedL2LocalizationLoss))
+    self.assertIsInstance(localization_loss,
+                          losses.WeightedL2LocalizationLoss)
 
   def test_build_weighted_smooth_l1_localization_loss_default_delta(self):
     losses_text_proto = """
@@ -57,8 +57,8 @@ class LocalizationLossBuilderTest(tf.test.TestCase):
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     _, localization_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
-    self.assertTrue(isinstance(localization_loss,
-                               losses.WeightedSmoothL1LocalizationLoss))
+    self.assertIsInstance(localization_loss,
+                          losses.WeightedSmoothL1LocalizationLoss)
     self.assertAlmostEqual(localization_loss._delta, 1.0)
 
   def test_build_weighted_smooth_l1_localization_loss_non_default_delta(self):
@@ -76,8 +76,8 @@ class LocalizationLossBuilderTest(tf.test.TestCase):
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     _, localization_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
-    self.assertTrue(isinstance(localization_loss,
-                               losses.WeightedSmoothL1LocalizationLoss))
+    self.assertIsInstance(localization_loss,
+                          losses.WeightedSmoothL1LocalizationLoss)
     self.assertAlmostEqual(localization_loss._delta, 0.1)
 
   def test_build_weighted_iou_localization_loss(self):
@@ -94,8 +94,8 @@ class LocalizationLossBuilderTest(tf.test.TestCase):
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     _, localization_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
-    self.assertTrue(isinstance(localization_loss,
-                               losses.WeightedIOULocalizationLoss))
+    self.assertIsInstance(localization_loss,
+                          losses.WeightedIOULocalizationLoss)
 
   def test_anchorwise_output(self):
     losses_text_proto = """
@@ -111,8 +111,8 @@ class LocalizationLossBuilderTest(tf.test.TestCase):
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     _, localization_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
-    self.assertTrue(isinstance(localization_loss,
-                               losses.WeightedSmoothL1LocalizationLoss))
+    self.assertIsInstance(localization_loss,
+                          losses.WeightedSmoothL1LocalizationLoss)
     predictions = tf.constant([[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]]])
     targets = tf.constant([[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]]])
     weights = tf.constant([[1.0, 1.0]])
@@ -132,6 +132,7 @@ class LocalizationLossBuilderTest(tf.test.TestCase):
       losses_builder._build_localization_loss(losses_proto)
 
 
+
 class ClassificationLossBuilderTest(tf.test.TestCase):
 
   def test_build_weighted_sigmoid_classification_loss(self):
@@ -148,8 +149,8 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     classification_loss, _, _, _, _, _, _ = losses_builder.build(losses_proto)
-    self.assertTrue(isinstance(classification_loss,
-                               losses.WeightedSigmoidClassificationLoss))
+    self.assertIsInstance(classification_loss,
+                          losses.WeightedSigmoidClassificationLoss)
 
   def test_build_weighted_sigmoid_focal_classification_loss(self):
     losses_text_proto = """
@@ -165,8 +166,8 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     classification_loss, _, _, _, _, _, _ = losses_builder.build(losses_proto)
-    self.assertTrue(isinstance(classification_loss,
-                               losses.SigmoidFocalClassificationLoss))
+    self.assertIsInstance(classification_loss,
+                          losses.SigmoidFocalClassificationLoss)
     self.assertAlmostEqual(classification_loss._alpha, None)
     self.assertAlmostEqual(classification_loss._gamma, 2.0)
 
@@ -186,8 +187,8 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     classification_loss, _, _, _, _, _, _ = losses_builder.build(losses_proto)
-    self.assertTrue(isinstance(classification_loss,
-                               losses.SigmoidFocalClassificationLoss))
+    self.assertIsInstance(classification_loss,
+                          losses.SigmoidFocalClassificationLoss)
     self.assertAlmostEqual(classification_loss._alpha, 0.25)
     self.assertAlmostEqual(classification_loss._gamma, 3.0)
 
@@ -205,8 +206,8 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     classification_loss, _, _, _, _, _, _ = losses_builder.build(losses_proto)
-    self.assertTrue(isinstance(classification_loss,
-                               losses.WeightedSoftmaxClassificationLoss))
+    self.assertIsInstance(classification_loss,
+                          losses.WeightedSoftmaxClassificationLoss)
 
   def test_build_weighted_logits_softmax_classification_loss(self):
     losses_text_proto = """
@@ -222,9 +223,9 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     classification_loss, _, _, _, _, _, _ = losses_builder.build(losses_proto)
-    self.assertTrue(
-        isinstance(classification_loss,
-                   losses.WeightedSoftmaxClassificationAgainstLogitsLoss))
+    self.assertIsInstance(
+        classification_loss,
+        losses.WeightedSoftmaxClassificationAgainstLogitsLoss)
 
   def test_build_weighted_softmax_classification_loss_with_logit_scale(self):
     losses_text_proto = """
@@ -241,8 +242,8 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     classification_loss, _, _, _, _, _, _ = losses_builder.build(losses_proto)
-    self.assertTrue(isinstance(classification_loss,
-                               losses.WeightedSoftmaxClassificationLoss))
+    self.assertIsInstance(classification_loss,
+                          losses.WeightedSoftmaxClassificationLoss)
 
   def test_build_bootstrapped_sigmoid_classification_loss(self):
     losses_text_proto = """
@@ -259,8 +260,8 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     classification_loss, _, _, _, _, _, _ = losses_builder.build(losses_proto)
-    self.assertTrue(isinstance(classification_loss,
-                               losses.BootstrappedSigmoidClassificationLoss))
+    self.assertIsInstance(classification_loss,
+                          losses.BootstrappedSigmoidClassificationLoss)
 
   def test_anchorwise_output(self):
     losses_text_proto = """
@@ -277,8 +278,8 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     classification_loss, _, _, _, _, _, _ = losses_builder.build(losses_proto)
-    self.assertTrue(isinstance(classification_loss,
-                               losses.WeightedSigmoidClassificationLoss))
+    self.assertIsInstance(classification_loss,
+                          losses.WeightedSigmoidClassificationLoss)
     predictions = tf.constant([[[0.0, 1.0, 0.0], [0.0, 0.5, 0.5]]])
     targets = tf.constant([[[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]])
     weights = tf.constant([[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]])
@@ -298,6 +299,7 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
       losses_builder.build(losses_proto)
 
 
+
 class HardExampleMinerBuilderTest(tf.test.TestCase):
 
   def test_do_not_build_hard_example_miner_by_default(self):
@@ -333,7 +335,7 @@ class HardExampleMinerBuilderTest(tf.test.TestCase):
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     _, _, _, _, hard_example_miner, _, _ = losses_builder.build(losses_proto)
-    self.assertTrue(isinstance(hard_example_miner, losses.HardExampleMiner))
+    self.assertIsInstance(hard_example_miner, losses.HardExampleMiner)
     self.assertEqual(hard_example_miner._loss_type, 'cls')
 
   def test_build_hard_example_miner_for_localization_loss(self):
@@ -353,7 +355,7 @@ class HardExampleMinerBuilderTest(tf.test.TestCase):
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     _, _, _, _, hard_example_miner, _, _ = losses_builder.build(losses_proto)
-    self.assertTrue(isinstance(hard_example_miner, losses.HardExampleMiner))
+    self.assertIsInstance(hard_example_miner, losses.HardExampleMiner)
     self.assertEqual(hard_example_miner._loss_type, 'loc')
 
   def test_build_hard_example_miner_with_non_default_values(self):
@@ -377,7 +379,7 @@ class HardExampleMinerBuilderTest(tf.test.TestCase):
     losses_proto = losses_pb2.Loss()
     text_format.Merge(losses_text_proto, losses_proto)
     _, _, _, _, hard_example_miner, _, _ = losses_builder.build(losses_proto)
-    self.assertTrue(isinstance(hard_example_miner, losses.HardExampleMiner))
+    self.assertIsInstance(hard_example_miner, losses.HardExampleMiner)
     self.assertEqual(hard_example_miner._num_hard_examples, 32)
     self.assertAlmostEqual(hard_example_miner._iou_threshold, 0.5)
     self.assertEqual(hard_example_miner._max_negatives_per_positive, 10)
@@ -406,11 +408,11 @@ class LossBuilderTest(tf.test.TestCase):
     (classification_loss, localization_loss, classification_weight,
      localization_weight, hard_example_miner, _,
      _) = losses_builder.build(losses_proto)
-    self.assertTrue(isinstance(hard_example_miner, losses.HardExampleMiner))
-    self.assertTrue(isinstance(classification_loss,
-                               losses.WeightedSoftmaxClassificationLoss))
-    self.assertTrue(isinstance(localization_loss,
-                               losses.WeightedL2LocalizationLoss))
+    self.assertIsInstance(hard_example_miner, losses.HardExampleMiner)
+    self.assertIsInstance(classification_loss,
+                          losses.WeightedSoftmaxClassificationLoss)
+    self.assertIsInstance(localization_loss,
+                          losses.WeightedL2LocalizationLoss)
     self.assertAlmostEqual(classification_weight, 0.8)
     self.assertAlmostEqual(localization_weight, 0.2)
@@ -434,12 +436,10 @@ class LossBuilderTest(tf.test.TestCase):
     (classification_loss, localization_loss, classification_weight,
      localization_weight, hard_example_miner, _,
      _) = losses_builder.build(losses_proto)
-    self.assertTrue(isinstance(hard_example_miner, losses.HardExampleMiner))
-    self.assertTrue(
-        isinstance(classification_loss,
-                   losses.WeightedSoftmaxClassificationLoss))
-    self.assertTrue(
-        isinstance(localization_loss, losses.WeightedL2LocalizationLoss))
+    self.assertIsInstance(hard_example_miner, losses.HardExampleMiner)
+    self.assertIsInstance(classification_loss,
+                          losses.WeightedSoftmaxClassificationLoss)
+    self.assertIsInstance(localization_loss, losses.WeightedL2LocalizationLoss)
     self.assertAlmostEqual(classification_weight, 0.8)
     self.assertAlmostEqual(localization_weight, 0.2)
@@ -464,12 +464,10 @@ class LossBuilderTest(tf.test.TestCase):
     (classification_loss, localization_loss, classification_weight,
      localization_weight, hard_example_miner, _,
      _) = losses_builder.build(losses_proto)
-    self.assertTrue(isinstance(hard_example_miner, losses.HardExampleMiner))
-    self.assertTrue(
-        isinstance(classification_loss,
-                   losses.WeightedSoftmaxClassificationLoss))
-    self.assertTrue(
-        isinstance(localization_loss, losses.WeightedL2LocalizationLoss))
+    self.assertIsInstance(hard_example_miner, losses.HardExampleMiner)
+    self.assertIsInstance(classification_loss,
+                          losses.WeightedSoftmaxClassificationLoss)
+    self.assertIsInstance(localization_loss, losses.WeightedL2LocalizationLoss)
     self.assertAlmostEqual(classification_weight, 0.8)
     self.assertAlmostEqual(localization_weight, 0.2)
@@ -505,8 +503,8 @@ class FasterRcnnClassificationLossBuilderTest(tf.test.TestCase):
     text_format.Merge(losses_text_proto, losses_proto)
     classification_loss = losses_builder.build_faster_rcnn_classification_loss(
         losses_proto)
-    self.assertTrue(isinstance(classification_loss,
-                               losses.WeightedSigmoidClassificationLoss))
+    self.assertIsInstance(classification_loss,
+                          losses.WeightedSigmoidClassificationLoss)
 
   def test_build_softmax_loss(self):
     losses_text_proto = """
@@ -517,8 +515,8 @@ class FasterRcnnClassificationLossBuilderTest(tf.test.TestCase):
     text_format.Merge(losses_text_proto, losses_proto)
     classification_loss = losses_builder.build_faster_rcnn_classification_loss(
         losses_proto)
-    self.assertTrue(isinstance(classification_loss,
-                               losses.WeightedSoftmaxClassificationLoss))
+    self.assertIsInstance(classification_loss,
+                          losses.WeightedSoftmaxClassificationLoss)
 
   def test_build_logits_softmax_loss(self):
     losses_text_proto = """
@@ -542,9 +540,8 @@ class FasterRcnnClassificationLossBuilderTest(tf.test.TestCase):
     text_format.Merge(losses_text_proto, losses_proto)
     classification_loss = losses_builder.build_faster_rcnn_classification_loss(
         losses_proto)
-    self.assertTrue(
-        isinstance(classification_loss,
-                   losses.SigmoidFocalClassificationLoss))
+    self.assertIsInstance(classification_loss,
+                          losses.SigmoidFocalClassificationLoss)
 
   def test_build_softmax_loss_by_default(self):
     losses_text_proto = """
@@ -553,8 +550,8 @@ class FasterRcnnClassificationLossBuilderTest(tf.test.TestCase):
     text_format.Merge(losses_text_proto, losses_proto)
     classification_loss = losses_builder.build_faster_rcnn_classification_loss(
         losses_proto)
-    self.assertTrue(isinstance(classification_loss,
-                               losses.WeightedSoftmaxClassificationLoss))
+    self.assertIsInstance(classification_loss,
+                          losses.WeightedSoftmaxClassificationLoss)
 
 
 if __name__ == '__main__':
...
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
"""A function to build a DetectionModel from configuration.""" """A function to build a DetectionModel from configuration."""
import functools import functools
from object_detection.builders import anchor_generator_builder from object_detection.builders import anchor_generator_builder
from object_detection.builders import box_coder_builder from object_detection.builders import box_coder_builder
from object_detection.builders import box_predictor_builder from object_detection.builders import box_predictor_builder
...@@ -32,96 +31,162 @@ from object_detection.core import target_assigner ...@@ -32,96 +31,162 @@ from object_detection.core import target_assigner
from object_detection.meta_architectures import faster_rcnn_meta_arch from object_detection.meta_architectures import faster_rcnn_meta_arch
from object_detection.meta_architectures import rfcn_meta_arch from object_detection.meta_architectures import rfcn_meta_arch
from object_detection.meta_architectures import ssd_meta_arch from object_detection.meta_architectures import ssd_meta_arch
from object_detection.models import faster_rcnn_inception_resnet_v2_feature_extractor as frcnn_inc_res
from object_detection.models import faster_rcnn_inception_resnet_v2_keras_feature_extractor as frcnn_inc_res_keras
from object_detection.models import faster_rcnn_inception_v2_feature_extractor as frcnn_inc_v2
from object_detection.models import faster_rcnn_nas_feature_extractor as frcnn_nas
from object_detection.models import faster_rcnn_pnas_feature_extractor as frcnn_pnas
from object_detection.models import faster_rcnn_resnet_v1_feature_extractor as frcnn_resnet_v1
from object_detection.models import ssd_resnet_v1_fpn_feature_extractor as ssd_resnet_v1_fpn
from object_detection.models import ssd_resnet_v1_fpn_keras_feature_extractor as ssd_resnet_v1_fpn_keras
from object_detection.models import ssd_resnet_v1_ppn_feature_extractor as ssd_resnet_v1_ppn
from object_detection.models.embedded_ssd_mobilenet_v1_feature_extractor import EmbeddedSSDMobileNetV1FeatureExtractor
from object_detection.models.ssd_inception_v2_feature_extractor import SSDInceptionV2FeatureExtractor
from object_detection.models.ssd_inception_v3_feature_extractor import SSDInceptionV3FeatureExtractor
from object_detection.models.ssd_mobilenet_edgetpu_feature_extractor import SSDMobileNetEdgeTPUFeatureExtractor
from object_detection.models.ssd_mobilenet_v1_feature_extractor import SSDMobileNetV1FeatureExtractor
from object_detection.models.ssd_mobilenet_v1_fpn_feature_extractor import SSDMobileNetV1FpnFeatureExtractor
from object_detection.models.ssd_mobilenet_v1_fpn_keras_feature_extractor import SSDMobileNetV1FpnKerasFeatureExtractor
from object_detection.models.ssd_mobilenet_v1_keras_feature_extractor import SSDMobileNetV1KerasFeatureExtractor
from object_detection.models.ssd_mobilenet_v1_ppn_feature_extractor import SSDMobileNetV1PpnFeatureExtractor
from object_detection.models.ssd_mobilenet_v2_feature_extractor import SSDMobileNetV2FeatureExtractor
from object_detection.models.ssd_mobilenet_v2_fpn_feature_extractor import SSDMobileNetV2FpnFeatureExtractor
from object_detection.models.ssd_mobilenet_v2_fpn_keras_feature_extractor import SSDMobileNetV2FpnKerasFeatureExtractor
from object_detection.models.ssd_mobilenet_v2_keras_feature_extractor import SSDMobileNetV2KerasFeatureExtractor
from object_detection.models.ssd_mobilenet_v3_feature_extractor import SSDMobileNetV3LargeFeatureExtractor
from object_detection.models.ssd_mobilenet_v3_feature_extractor import SSDMobileNetV3SmallFeatureExtractor
from object_detection.models.ssd_pnasnet_feature_extractor import SSDPNASNetFeatureExtractor
from object_detection.predictors import rfcn_box_predictor
from object_detection.predictors import rfcn_keras_box_predictor
from object_detection.predictors.heads import mask_head from object_detection.predictors.heads import mask_head
from object_detection.protos import losses_pb2
from object_detection.protos import model_pb2 from object_detection.protos import model_pb2
from object_detection.utils import label_map_util
from object_detection.utils import ops from object_detection.utils import ops
from object_detection.utils import tf_version
## Feature Extractors for TF
## This section conditionally imports different feature extractors based on the
## Tensorflow version.
##
# pylint: disable=g-import-not-at-top
if tf_version.is_tf2():
from object_detection.models import center_net_hourglass_feature_extractor
from object_detection.models import center_net_resnet_feature_extractor
from object_detection.models import faster_rcnn_inception_resnet_v2_keras_feature_extractor as frcnn_inc_res_keras
from object_detection.models import faster_rcnn_resnet_keras_feature_extractor as frcnn_resnet_keras
from object_detection.models import ssd_resnet_v1_fpn_keras_feature_extractor as ssd_resnet_v1_fpn_keras
from object_detection.models.ssd_mobilenet_v1_fpn_keras_feature_extractor import SSDMobileNetV1FpnKerasFeatureExtractor
from object_detection.models.ssd_mobilenet_v1_keras_feature_extractor import SSDMobileNetV1KerasFeatureExtractor
from object_detection.models.ssd_mobilenet_v2_fpn_keras_feature_extractor import SSDMobileNetV2FpnKerasFeatureExtractor
from object_detection.models.ssd_mobilenet_v2_keras_feature_extractor import SSDMobileNetV2KerasFeatureExtractor
from object_detection.predictors import rfcn_keras_box_predictor
if tf_version.is_tf1():
from object_detection.models import faster_rcnn_inception_resnet_v2_feature_extractor as frcnn_inc_res
from object_detection.models import faster_rcnn_inception_v2_feature_extractor as frcnn_inc_v2
from object_detection.models import faster_rcnn_nas_feature_extractor as frcnn_nas
from object_detection.models import faster_rcnn_pnas_feature_extractor as frcnn_pnas
from object_detection.models import faster_rcnn_resnet_v1_feature_extractor as frcnn_resnet_v1
from object_detection.models import ssd_resnet_v1_fpn_feature_extractor as ssd_resnet_v1_fpn
from object_detection.models import ssd_resnet_v1_ppn_feature_extractor as ssd_resnet_v1_ppn
from object_detection.models.embedded_ssd_mobilenet_v1_feature_extractor import EmbeddedSSDMobileNetV1FeatureExtractor
from object_detection.models.ssd_inception_v2_feature_extractor import SSDInceptionV2FeatureExtractor
from object_detection.models.ssd_mobilenet_v2_fpn_feature_extractor import SSDMobileNetV2FpnFeatureExtractor
from object_detection.models.ssd_mobilenet_v2_mnasfpn_feature_extractor import SSDMobileNetV2MnasFPNFeatureExtractor
from object_detection.models.ssd_inception_v3_feature_extractor import SSDInceptionV3FeatureExtractor
from object_detection.models.ssd_mobilenet_edgetpu_feature_extractor import SSDMobileNetEdgeTPUFeatureExtractor
from object_detection.models.ssd_mobilenet_v1_feature_extractor import SSDMobileNetV1FeatureExtractor
from object_detection.models.ssd_mobilenet_v1_fpn_feature_extractor import SSDMobileNetV1FpnFeatureExtractor
from object_detection.models.ssd_mobilenet_v1_ppn_feature_extractor import SSDMobileNetV1PpnFeatureExtractor
from object_detection.models.ssd_mobilenet_v2_feature_extractor import SSDMobileNetV2FeatureExtractor
from object_detection.models.ssd_mobilenet_v3_feature_extractor import SSDMobileNetV3LargeFeatureExtractor
from object_detection.models.ssd_mobilenet_v3_feature_extractor import SSDMobileNetV3SmallFeatureExtractor
from object_detection.models.ssd_pnasnet_feature_extractor import SSDPNASNetFeatureExtractor
from object_detection.predictors import rfcn_box_predictor
# pylint: enable=g-import-not-at-top
if tf_version.is_tf2():
SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP = {
'ssd_mobilenet_v1_keras': SSDMobileNetV1KerasFeatureExtractor,
'ssd_mobilenet_v1_fpn_keras': SSDMobileNetV1FpnKerasFeatureExtractor,
'ssd_mobilenet_v2_keras': SSDMobileNetV2KerasFeatureExtractor,
'ssd_mobilenet_v2_fpn_keras': SSDMobileNetV2FpnKerasFeatureExtractor,
'ssd_resnet50_v1_fpn_keras':
ssd_resnet_v1_fpn_keras.SSDResNet50V1FpnKerasFeatureExtractor,
'ssd_resnet101_v1_fpn_keras':
ssd_resnet_v1_fpn_keras.SSDResNet101V1FpnKerasFeatureExtractor,
'ssd_resnet152_v1_fpn_keras':
ssd_resnet_v1_fpn_keras.SSDResNet152V1FpnKerasFeatureExtractor,
}
# A map of names to SSD feature extractors. FASTER_RCNN_KERAS_FEATURE_EXTRACTOR_CLASS_MAP = {
SSD_FEATURE_EXTRACTOR_CLASS_MAP = { 'faster_rcnn_resnet50_keras':
'ssd_inception_v2': SSDInceptionV2FeatureExtractor, frcnn_resnet_keras.FasterRCNNResnet50KerasFeatureExtractor,
'ssd_inception_v3': SSDInceptionV3FeatureExtractor, 'faster_rcnn_resnet101_keras':
'ssd_mobilenet_v1': SSDMobileNetV1FeatureExtractor, frcnn_resnet_keras.FasterRCNNResnet101KerasFeatureExtractor,
'ssd_mobilenet_v1_fpn': SSDMobileNetV1FpnFeatureExtractor, 'faster_rcnn_resnet152_keras':
'ssd_mobilenet_v1_ppn': SSDMobileNetV1PpnFeatureExtractor, frcnn_resnet_keras.FasterRCNNResnet152KerasFeatureExtractor,
'ssd_mobilenet_v2': SSDMobileNetV2FeatureExtractor, 'faster_rcnn_inception_resnet_v2_keras':
'ssd_mobilenet_v2_fpn': SSDMobileNetV2FpnFeatureExtractor, frcnn_inc_res_keras.FasterRCNNInceptionResnetV2KerasFeatureExtractor,
'ssd_mobilenet_v3_large': SSDMobileNetV3LargeFeatureExtractor, }
'ssd_mobilenet_v3_small': SSDMobileNetV3SmallFeatureExtractor,
'ssd_mobilenet_edgetpu': SSDMobileNetEdgeTPUFeatureExtractor,
'ssd_resnet50_v1_fpn': ssd_resnet_v1_fpn.SSDResnet50V1FpnFeatureExtractor,
'ssd_resnet101_v1_fpn': ssd_resnet_v1_fpn.SSDResnet101V1FpnFeatureExtractor,
'ssd_resnet152_v1_fpn': ssd_resnet_v1_fpn.SSDResnet152V1FpnFeatureExtractor,
'ssd_resnet50_v1_ppn': ssd_resnet_v1_ppn.SSDResnet50V1PpnFeatureExtractor,
'ssd_resnet101_v1_ppn':
ssd_resnet_v1_ppn.SSDResnet101V1PpnFeatureExtractor,
'ssd_resnet152_v1_ppn':
ssd_resnet_v1_ppn.SSDResnet152V1PpnFeatureExtractor,
'embedded_ssd_mobilenet_v1': EmbeddedSSDMobileNetV1FeatureExtractor,
'ssd_pnasnet': SSDPNASNetFeatureExtractor,
}
SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP = { CENTER_NET_EXTRACTOR_FUNCTION_MAP = {
'ssd_mobilenet_v1_keras': SSDMobileNetV1KerasFeatureExtractor, 'resnet_v2_101': center_net_resnet_feature_extractor.resnet_v2_101,
'ssd_mobilenet_v1_fpn_keras': SSDMobileNetV1FpnKerasFeatureExtractor, 'resnet_v2_50': center_net_resnet_feature_extractor.resnet_v2_50,
'ssd_mobilenet_v2_keras': SSDMobileNetV2KerasFeatureExtractor, 'hourglass_104': center_net_hourglass_feature_extractor.hourglass_104,
'ssd_mobilenet_v2_fpn_keras': SSDMobileNetV2FpnKerasFeatureExtractor, }
'ssd_resnet50_v1_fpn_keras':
ssd_resnet_v1_fpn_keras.SSDResNet50V1FpnKerasFeatureExtractor,
'ssd_resnet101_v1_fpn_keras':
ssd_resnet_v1_fpn_keras.SSDResNet101V1FpnKerasFeatureExtractor,
'ssd_resnet152_v1_fpn_keras':
ssd_resnet_v1_fpn_keras.SSDResNet152V1FpnKerasFeatureExtractor,
}
# A map of names to Faster R-CNN feature extractors. FEATURE_EXTRACTOR_MAPS = [
FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP = { CENTER_NET_EXTRACTOR_FUNCTION_MAP,
'faster_rcnn_nas': FASTER_RCNN_KERAS_FEATURE_EXTRACTOR_CLASS_MAP,
frcnn_nas.FasterRCNNNASFeatureExtractor, SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP
'faster_rcnn_pnas': ]
frcnn_pnas.FasterRCNNPNASFeatureExtractor,
'faster_rcnn_inception_resnet_v2': if tf_version.is_tf1():
frcnn_inc_res.FasterRCNNInceptionResnetV2FeatureExtractor, SSD_FEATURE_EXTRACTOR_CLASS_MAP = {
'faster_rcnn_inception_v2': 'ssd_inception_v2':
frcnn_inc_v2.FasterRCNNInceptionV2FeatureExtractor, SSDInceptionV2FeatureExtractor,
'faster_rcnn_resnet50': 'ssd_inception_v3':
frcnn_resnet_v1.FasterRCNNResnet50FeatureExtractor, SSDInceptionV3FeatureExtractor,
'faster_rcnn_resnet101': 'ssd_mobilenet_v1':
frcnn_resnet_v1.FasterRCNNResnet101FeatureExtractor, SSDMobileNetV1FeatureExtractor,
'faster_rcnn_resnet152': 'ssd_mobilenet_v1_fpn':
frcnn_resnet_v1.FasterRCNNResnet152FeatureExtractor, SSDMobileNetV1FpnFeatureExtractor,
} 'ssd_mobilenet_v1_ppn':
SSDMobileNetV1PpnFeatureExtractor,
'ssd_mobilenet_v2':
SSDMobileNetV2FeatureExtractor,
'ssd_mobilenet_v2_fpn':
SSDMobileNetV2FpnFeatureExtractor,
'ssd_mobilenet_v2_mnasfpn':
SSDMobileNetV2MnasFPNFeatureExtractor,
'ssd_mobilenet_v3_large':
SSDMobileNetV3LargeFeatureExtractor,
'ssd_mobilenet_v3_small':
SSDMobileNetV3SmallFeatureExtractor,
'ssd_mobilenet_edgetpu':
SSDMobileNetEdgeTPUFeatureExtractor,
'ssd_resnet50_v1_fpn':
ssd_resnet_v1_fpn.SSDResnet50V1FpnFeatureExtractor,
'ssd_resnet101_v1_fpn':
ssd_resnet_v1_fpn.SSDResnet101V1FpnFeatureExtractor,
'ssd_resnet152_v1_fpn':
ssd_resnet_v1_fpn.SSDResnet152V1FpnFeatureExtractor,
'ssd_resnet50_v1_ppn':
ssd_resnet_v1_ppn.SSDResnet50V1PpnFeatureExtractor,
'ssd_resnet101_v1_ppn':
ssd_resnet_v1_ppn.SSDResnet101V1PpnFeatureExtractor,
'ssd_resnet152_v1_ppn':
ssd_resnet_v1_ppn.SSDResnet152V1PpnFeatureExtractor,
'embedded_ssd_mobilenet_v1':
EmbeddedSSDMobileNetV1FeatureExtractor,
'ssd_pnasnet':
SSDPNASNetFeatureExtractor,
}
FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP = {
'faster_rcnn_nas':
frcnn_nas.FasterRCNNNASFeatureExtractor,
'faster_rcnn_pnas':
frcnn_pnas.FasterRCNNPNASFeatureExtractor,
'faster_rcnn_inception_resnet_v2':
frcnn_inc_res.FasterRCNNInceptionResnetV2FeatureExtractor,
'faster_rcnn_inception_v2':
frcnn_inc_v2.FasterRCNNInceptionV2FeatureExtractor,
'faster_rcnn_resnet50':
frcnn_resnet_v1.FasterRCNNResnet50FeatureExtractor,
'faster_rcnn_resnet101':
frcnn_resnet_v1.FasterRCNNResnet101FeatureExtractor,
'faster_rcnn_resnet152':
frcnn_resnet_v1.FasterRCNNResnet152FeatureExtractor,
}
FEATURE_EXTRACTOR_MAPS = [
SSD_FEATURE_EXTRACTOR_CLASS_MAP,
FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP
]
FASTER_RCNN_KERAS_FEATURE_EXTRACTOR_CLASS_MAP = {
'faster_rcnn_inception_resnet_v2_keras': def _check_feature_extractor_exists(feature_extractor_type):
frcnn_inc_res_keras.FasterRCNNInceptionResnetV2KerasFeatureExtractor, feature_extractors = set().union(*FEATURE_EXTRACTOR_MAPS)
} if feature_extractor_type not in feature_extractors:
raise ValueError('{} is not supported. See `model_builder.py` for features '
'extractors compatible with different versions of '
'Tensorflow'.format(feature_extractor_type))
def _build_ssd_feature_extractor(feature_extractor_config, def _build_ssd_feature_extractor(feature_extractor_config,
...@@ -146,14 +211,14 @@ def _build_ssd_feature_extractor(feature_extractor_config, ...@@ -146,14 +211,14 @@ def _build_ssd_feature_extractor(feature_extractor_config,
ValueError: On invalid feature extractor type. ValueError: On invalid feature extractor type.
""" """
feature_type = feature_extractor_config.type feature_type = feature_extractor_config.type
is_keras_extractor = feature_type in SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP
depth_multiplier = feature_extractor_config.depth_multiplier depth_multiplier = feature_extractor_config.depth_multiplier
min_depth = feature_extractor_config.min_depth min_depth = feature_extractor_config.min_depth
pad_to_multiple = feature_extractor_config.pad_to_multiple pad_to_multiple = feature_extractor_config.pad_to_multiple
use_explicit_padding = feature_extractor_config.use_explicit_padding use_explicit_padding = feature_extractor_config.use_explicit_padding
use_depthwise = feature_extractor_config.use_depthwise use_depthwise = feature_extractor_config.use_depthwise
if is_keras_extractor: is_keras = tf_version.is_tf2()
if is_keras:
conv_hyperparams = hyperparams_builder.KerasLayerHyperparams( conv_hyperparams = hyperparams_builder.KerasLayerHyperparams(
feature_extractor_config.conv_hyperparams) feature_extractor_config.conv_hyperparams)
else: else:
...@@ -162,11 +227,10 @@ def _build_ssd_feature_extractor(feature_extractor_config, ...@@ -162,11 +227,10 @@ def _build_ssd_feature_extractor(feature_extractor_config,
override_base_feature_extractor_hyperparams = ( override_base_feature_extractor_hyperparams = (
feature_extractor_config.override_base_feature_extractor_hyperparams) feature_extractor_config.override_base_feature_extractor_hyperparams)
if (feature_type not in SSD_FEATURE_EXTRACTOR_CLASS_MAP) and ( if not is_keras and feature_type not in SSD_FEATURE_EXTRACTOR_CLASS_MAP:
not is_keras_extractor):
raise ValueError('Unknown ssd feature_extractor: {}'.format(feature_type)) raise ValueError('Unknown ssd feature_extractor: {}'.format(feature_type))
if is_keras_extractor: if is_keras:
feature_extractor_class = SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP[ feature_extractor_class = SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP[
feature_type] feature_type]
else: else:
...@@ -197,7 +261,7 @@ def _build_ssd_feature_extractor(feature_extractor_config, ...@@ -197,7 +261,7 @@ def _build_ssd_feature_extractor(feature_extractor_config,
if feature_extractor_config.HasField('num_layers'): if feature_extractor_config.HasField('num_layers'):
kwargs.update({'num_layers': feature_extractor_config.num_layers}) kwargs.update({'num_layers': feature_extractor_config.num_layers})
if is_keras_extractor: if is_keras:
kwargs.update({ kwargs.update({
'conv_hyperparams': conv_hyperparams, 'conv_hyperparams': conv_hyperparams,
'inplace_batchnorm_update': False, 'inplace_batchnorm_update': False,
...@@ -209,6 +273,7 @@ def _build_ssd_feature_extractor(feature_extractor_config, ...@@ -209,6 +273,7 @@ def _build_ssd_feature_extractor(feature_extractor_config,
'reuse_weights': reuse_weights, 'reuse_weights': reuse_weights,
}) })
if feature_extractor_config.HasField('fpn'): if feature_extractor_config.HasField('fpn'):
kwargs.update({ kwargs.update({
'fpn_min_level': 'fpn_min_level':
...@@ -239,6 +304,7 @@ def _build_ssd_model(ssd_config, is_training, add_summaries): ...@@ -239,6 +304,7 @@ def _build_ssd_model(ssd_config, is_training, add_summaries):
model_class_map). model_class_map).
""" """
num_classes = ssd_config.num_classes num_classes = ssd_config.num_classes
_check_feature_extractor_exists(ssd_config.feature_extractor.type)
# Feature extractor # Feature extractor
feature_extractor = _build_ssd_feature_extractor( feature_extractor = _build_ssd_feature_extractor(
...@@ -325,7 +391,7 @@ def _build_ssd_model(ssd_config, is_training, add_summaries): ...@@ -325,7 +391,7 @@ def _build_ssd_model(ssd_config, is_training, add_summaries):
def _build_faster_rcnn_feature_extractor( def _build_faster_rcnn_feature_extractor(
feature_extractor_config, is_training, reuse_weights=None, feature_extractor_config, is_training, reuse_weights=True,
inplace_batchnorm_update=False): inplace_batchnorm_update=False):
"""Builds a faster_rcnn_meta_arch.FasterRCNNFeatureExtractor based on config. """Builds a faster_rcnn_meta_arch.FasterRCNNFeatureExtractor based on config.
...@@ -422,9 +488,8 @@ def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries): ...@@ -422,9 +488,8 @@ def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries):
""" """
num_classes = frcnn_config.num_classes num_classes = frcnn_config.num_classes
image_resizer_fn = image_resizer_builder.build(frcnn_config.image_resizer) image_resizer_fn = image_resizer_builder.build(frcnn_config.image_resizer)
_check_feature_extractor_exists(frcnn_config.feature_extractor.type)
is_keras = (frcnn_config.feature_extractor.type in is_keras = tf_version.is_tf2()
FASTER_RCNN_KERAS_FEATURE_EXTRACTOR_CLASS_MAP)
if is_keras: if is_keras:
feature_extractor = _build_faster_rcnn_keras_feature_extractor( feature_extractor = _build_faster_rcnn_keras_feature_extractor(
...@@ -536,54 +601,98 @@ def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries): ...@@ -536,54 +601,98 @@ def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries):
frcnn_config.clip_anchors_to_image) frcnn_config.clip_anchors_to_image)
common_kwargs = { common_kwargs = {
'is_training': is_training, 'is_training':
'num_classes': num_classes, is_training,
'image_resizer_fn': image_resizer_fn, 'num_classes':
'feature_extractor': feature_extractor, num_classes,
'number_of_stages': number_of_stages, 'image_resizer_fn':
'first_stage_anchor_generator': first_stage_anchor_generator, image_resizer_fn,
'first_stage_target_assigner': first_stage_target_assigner, 'feature_extractor':
'first_stage_atrous_rate': first_stage_atrous_rate, feature_extractor,
'number_of_stages':
number_of_stages,
'first_stage_anchor_generator':
first_stage_anchor_generator,
'first_stage_target_assigner':
first_stage_target_assigner,
'first_stage_atrous_rate':
first_stage_atrous_rate,
'first_stage_box_predictor_arg_scope_fn': 'first_stage_box_predictor_arg_scope_fn':
first_stage_box_predictor_arg_scope_fn, first_stage_box_predictor_arg_scope_fn,
'first_stage_box_predictor_kernel_size': 'first_stage_box_predictor_kernel_size':
first_stage_box_predictor_kernel_size, first_stage_box_predictor_kernel_size,
'first_stage_box_predictor_depth': first_stage_box_predictor_depth, 'first_stage_box_predictor_depth':
'first_stage_minibatch_size': first_stage_minibatch_size, first_stage_box_predictor_depth,
'first_stage_sampler': first_stage_sampler, 'first_stage_minibatch_size':
'first_stage_non_max_suppression_fn': first_stage_non_max_suppression_fn, first_stage_minibatch_size,
'first_stage_max_proposals': first_stage_max_proposals, 'first_stage_sampler':
'first_stage_localization_loss_weight': first_stage_loc_loss_weight, first_stage_sampler,
'first_stage_objectness_loss_weight': first_stage_obj_loss_weight, 'first_stage_non_max_suppression_fn':
'second_stage_target_assigner': second_stage_target_assigner, first_stage_non_max_suppression_fn,
'second_stage_batch_size': second_stage_batch_size, 'first_stage_max_proposals':
'second_stage_sampler': second_stage_sampler, first_stage_max_proposals,
'first_stage_localization_loss_weight':
first_stage_loc_loss_weight,
'first_stage_objectness_loss_weight':
first_stage_obj_loss_weight,
'second_stage_target_assigner':
second_stage_target_assigner,
'second_stage_batch_size':
second_stage_batch_size,
'second_stage_sampler':
second_stage_sampler,
'second_stage_non_max_suppression_fn': 'second_stage_non_max_suppression_fn':
second_stage_non_max_suppression_fn, second_stage_non_max_suppression_fn,
'second_stage_score_conversion_fn': second_stage_score_conversion_fn, 'second_stage_score_conversion_fn':
second_stage_score_conversion_fn,
'second_stage_localization_loss_weight': 'second_stage_localization_loss_weight':
second_stage_localization_loss_weight, second_stage_localization_loss_weight,
'second_stage_classification_loss': 'second_stage_classification_loss':
second_stage_classification_loss, second_stage_classification_loss,
'second_stage_classification_loss_weight': 'second_stage_classification_loss_weight':
second_stage_classification_loss_weight, second_stage_classification_loss_weight,
'hard_example_miner': hard_example_miner, 'hard_example_miner':
'add_summaries': add_summaries, hard_example_miner,
'crop_and_resize_fn': crop_and_resize_fn, 'add_summaries':
'clip_anchors_to_image': clip_anchors_to_image, add_summaries,
'use_static_shapes': use_static_shapes, 'crop_and_resize_fn':
'resize_masks': frcnn_config.resize_masks, crop_and_resize_fn,
'return_raw_detections_during_predict': ( 'clip_anchors_to_image':
frcnn_config.return_raw_detections_during_predict) clip_anchors_to_image,
'use_static_shapes':
use_static_shapes,
'resize_masks':
frcnn_config.resize_masks,
'return_raw_detections_during_predict':
frcnn_config.return_raw_detections_during_predict,
'output_final_box_features':
frcnn_config.output_final_box_features
} }
if (isinstance(second_stage_box_predictor, if ((not is_keras and isinstance(second_stage_box_predictor,
rfcn_box_predictor.RfcnBoxPredictor) or rfcn_box_predictor.RfcnBoxPredictor)) or
isinstance(second_stage_box_predictor, (is_keras and
rfcn_keras_box_predictor.RfcnKerasBoxPredictor)): isinstance(second_stage_box_predictor,
rfcn_keras_box_predictor.RfcnKerasBoxPredictor))):
return rfcn_meta_arch.RFCNMetaArch( return rfcn_meta_arch.RFCNMetaArch(
second_stage_rfcn_box_predictor=second_stage_box_predictor, second_stage_rfcn_box_predictor=second_stage_box_predictor,
**common_kwargs) **common_kwargs)
elif frcnn_config.HasField('context_config'):
context_config = frcnn_config.context_config
common_kwargs.update({
'attention_bottleneck_dimension':
context_config.attention_bottleneck_dimension,
'attention_temperature':
context_config.attention_temperature
})
return context_rcnn_meta_arch.ContextRCNNMetaArch(
initial_crop_size=initial_crop_size,
maxpool_kernel_size=maxpool_kernel_size,
maxpool_stride=maxpool_stride,
second_stage_mask_rcnn_box_predictor=second_stage_box_predictor,
second_stage_mask_prediction_loss_weight=(
second_stage_mask_prediction_loss_weight),
**common_kwargs)
else: else:
return faster_rcnn_meta_arch.FasterRCNNMetaArch( return faster_rcnn_meta_arch.FasterRCNNMetaArch(
initial_crop_size=initial_crop_size, initial_crop_size=initial_crop_size,
...@@ -602,10 +711,170 @@ def _build_experimental_model(config, is_training, add_summaries=True): ...@@ -602,10 +711,170 @@ def _build_experimental_model(config, is_training, add_summaries=True):
return EXPERIMENTAL_META_ARCH_BUILDER_MAP[config.name]( return EXPERIMENTAL_META_ARCH_BUILDER_MAP[config.name](
is_training, add_summaries) is_training, add_summaries)
META_ARCHITECURE_BUILDER_MAP = {
# The class ID in the groundtruth/model architecture is usually 0-based while
# the ID in the label map is 1-based. The offset is used to convert between the
# the two.
CLASS_ID_OFFSET = 1
KEYPOINT_STD_DEV_DEFAULT = 1.0
def keypoint_proto_to_params(kp_config, keypoint_map_dict):
"""Converts CenterNet.KeypointEstimation proto to parameter namedtuple."""
label_map_item = keypoint_map_dict[kp_config.keypoint_class_name]
classification_loss, localization_loss, _, _, _, _, _ = (
losses_builder.build(kp_config.loss))
keypoint_indices = [
keypoint.id for keypoint in label_map_item.keypoints
]
keypoint_labels = [
keypoint.label for keypoint in label_map_item.keypoints
]
keypoint_std_dev_dict = {
label: KEYPOINT_STD_DEV_DEFAULT for label in keypoint_labels
}
if kp_config.keypoint_label_to_std:
for label, value in kp_config.keypoint_label_to_std.items():
keypoint_std_dev_dict[label] = value
keypoint_std_dev = [keypoint_std_dev_dict[label] for label in keypoint_labels]
return center_net_meta_arch.KeypointEstimationParams(
task_name=kp_config.task_name,
class_id=label_map_item.id - CLASS_ID_OFFSET,
keypoint_indices=keypoint_indices,
classification_loss=classification_loss,
localization_loss=localization_loss,
keypoint_labels=keypoint_labels,
keypoint_std_dev=keypoint_std_dev,
task_loss_weight=kp_config.task_loss_weight,
keypoint_regression_loss_weight=kp_config.keypoint_regression_loss_weight,
keypoint_heatmap_loss_weight=kp_config.keypoint_heatmap_loss_weight,
keypoint_offset_loss_weight=kp_config.keypoint_offset_loss_weight,
heatmap_bias_init=kp_config.heatmap_bias_init,
keypoint_candidate_score_threshold=(
kp_config.keypoint_candidate_score_threshold),
num_candidates_per_keypoint=kp_config.num_candidates_per_keypoint,
peak_max_pool_kernel_size=kp_config.peak_max_pool_kernel_size,
unmatched_keypoint_score=kp_config.unmatched_keypoint_score,
box_scale=kp_config.box_scale,
candidate_search_scale=kp_config.candidate_search_scale,
candidate_ranking_mode=kp_config.candidate_ranking_mode)
def object_detection_proto_to_params(od_config):
"""Converts CenterNet.ObjectDetection proto to parameter namedtuple."""
loss = losses_pb2.Loss()
# Add dummy classification loss to avoid the loss_builder throwing error.
# TODO(yuhuic): update the loss builder to take the classification loss
# directly.
loss.classification_loss.weighted_sigmoid.CopyFrom(
losses_pb2.WeightedSigmoidClassificationLoss())
loss.localization_loss.CopyFrom(od_config.localization_loss)
_, localization_loss, _, _, _, _, _ = (losses_builder.build(loss))
return center_net_meta_arch.ObjectDetectionParams(
localization_loss=localization_loss,
scale_loss_weight=od_config.scale_loss_weight,
offset_loss_weight=od_config.offset_loss_weight,
task_loss_weight=od_config.task_loss_weight)
def object_center_proto_to_params(oc_config):
"""Converts CenterNet.ObjectCenter proto to parameter namedtuple."""
loss = losses_pb2.Loss()
# Add dummy localization loss to avoid the loss_builder throwing error.
# TODO(yuhuic): update the loss builder to take the localization loss
# directly.
loss.localization_loss.weighted_l2.CopyFrom(
losses_pb2.WeightedL2LocalizationLoss())
loss.classification_loss.CopyFrom(oc_config.classification_loss)
classification_loss, _, _, _, _, _, _ = (losses_builder.build(loss))
return center_net_meta_arch.ObjectCenterParams(
classification_loss=classification_loss,
object_center_loss_weight=oc_config.object_center_loss_weight,
heatmap_bias_init=oc_config.heatmap_bias_init,
min_box_overlap_iou=oc_config.min_box_overlap_iou,
max_box_predictions=oc_config.max_box_predictions)
def _build_center_net_model(center_net_config, is_training, add_summaries):
"""Build a CenterNet detection model.
Args:
center_net_config: A CenterNet proto object with model configuration.
is_training: True if this model is being built for training purposes.
add_summaries: Whether to add tf summaries in the model.
Returns:
CenterNetMetaArch based on the config.
"""
image_resizer_fn = image_resizer_builder.build(
center_net_config.image_resizer)
_check_feature_extractor_exists(center_net_config.feature_extractor.type)
feature_extractor = _build_center_net_feature_extractor(
center_net_config.feature_extractor)
object_center_params = object_center_proto_to_params(
center_net_config.object_center_params)
object_detection_params = None
if center_net_config.HasField('object_detection_task'):
object_detection_params = object_detection_proto_to_params(
center_net_config.object_detection_task)
keypoint_params_dict = None
if center_net_config.keypoint_estimation_task:
label_map_proto = label_map_util.load_labelmap(
center_net_config.keypoint_label_map_path)
keypoint_map_dict = {
item.name: item for item in label_map_proto.item if item.keypoints
}
keypoint_params_dict = {}
keypoint_class_id_set = set()
all_keypoint_indices = []
for task in center_net_config.keypoint_estimation_task:
kp_params = keypoint_proto_to_params(task, keypoint_map_dict)
keypoint_params_dict[task.task_name] = kp_params
all_keypoint_indices.extend(kp_params.keypoint_indices)
if kp_params.class_id in keypoint_class_id_set:
raise ValueError(('Multiple keypoint tasks map to the same class id is '
'not allowed: %d' % kp_params.class_id))
else:
keypoint_class_id_set.add(kp_params.class_id)
if len(all_keypoint_indices) > len(set(all_keypoint_indices)):
raise ValueError('Some keypoint indices are used more than once.')
return center_net_meta_arch.CenterNetMetaArch(
is_training=is_training,
add_summaries=add_summaries,
num_classes=center_net_config.num_classes,
feature_extractor=feature_extractor,
image_resizer_fn=image_resizer_fn,
object_center_params=object_center_params,
object_detection_params=object_detection_params,
keypoint_params_dict=keypoint_params_dict)
def _build_center_net_feature_extractor(
feature_extractor_config):
"""Build a CenterNet feature extractor from the given config."""
if feature_extractor_config.type not in CENTER_NET_EXTRACTOR_FUNCTION_MAP:
raise ValueError('\'{}\' is not a known CenterNet feature extractor type'
.format(feature_extractor_config.type))
return CENTER_NET_EXTRACTOR_FUNCTION_MAP[feature_extractor_config.type](
channel_means=list(feature_extractor_config.channel_means),
channel_stds=list(feature_extractor_config.channel_stds),
bgr_ordering=feature_extractor_config.bgr_ordering
)
META_ARCH_BUILDER_MAP = {
'ssd': _build_ssd_model, 'ssd': _build_ssd_model,
'faster_rcnn': _build_faster_rcnn_model, 'faster_rcnn': _build_faster_rcnn_model,
'experimental_model': _build_experimental_model 'experimental_model': _build_experimental_model,
'center_net': _build_center_net_model
} }
...@@ -628,9 +897,9 @@ def build(model_config, is_training, add_summaries=True): ...@@ -628,9 +897,9 @@ def build(model_config, is_training, add_summaries=True):
meta_architecture = model_config.WhichOneof('model') meta_architecture = model_config.WhichOneof('model')
if meta_architecture not in META_ARCHITECURE_BUILDER_MAP: if meta_architecture not in META_ARCH_BUILDER_MAP:
raise ValueError('Unknown meta architecture: {}'.format(meta_architecture)) raise ValueError('Unknown meta architecture: {}'.format(meta_architecture))
else: else:
build_func = META_ARCHITECURE_BUILDER_MAP[meta_architecture] build_func = META_ARCH_BUILDER_MAP[meta_architecture]
return build_func(getattr(model_config, meta_architecture), is_training, return build_func(getattr(model_config, meta_architecture), is_training,
add_summaries) add_summaries)
# Lint as: python2, python3
# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
...@@ -12,25 +13,34 @@ ...@@ -12,25 +13,34 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Tests for object_detection.models.model_builder.""" """Tests for object_detection.models.model_builder."""
from absl.testing import parameterized from absl.testing import parameterized
import tensorflow as tf
from google.protobuf import text_format from google.protobuf import text_format
from object_detection.builders import model_builder from object_detection.builders import model_builder
from object_detection.meta_architectures import faster_rcnn_meta_arch from object_detection.meta_architectures import faster_rcnn_meta_arch
from object_detection.meta_architectures import rfcn_meta_arch from object_detection.meta_architectures import rfcn_meta_arch
from object_detection.meta_architectures import ssd_meta_arch from object_detection.meta_architectures import ssd_meta_arch
from object_detection.models import ssd_resnet_v1_fpn_feature_extractor as ssd_resnet_v1_fpn
from object_detection.protos import hyperparams_pb2 from object_detection.protos import hyperparams_pb2
from object_detection.protos import losses_pb2 from object_detection.protos import losses_pb2
from object_detection.protos import model_pb2 from object_detection.protos import model_pb2
from object_detection.utils import test_case
class ModelBuilderTest(test_case.TestCase, parameterized.TestCase):
def default_ssd_feature_extractor(self):
raise NotImplementedError
class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase): def default_faster_rcnn_feature_extractor(self):
raise NotImplementedError
def ssd_feature_extractors(self):
raise NotImplementedError
def faster_rcnn_feature_extractors(self):
raise NotImplementedError
def create_model(self, model_config, is_training=True): def create_model(self, model_config, is_training=True):
"""Builds a DetectionModel based on the model config. """Builds a DetectionModel based on the model config.
...@@ -50,7 +60,6 @@ class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase): ...@@ -50,7 +60,6 @@ class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase):
model_text_proto = """ model_text_proto = """
ssd { ssd {
feature_extractor { feature_extractor {
type: 'ssd_inception_v2'
conv_hyperparams { conv_hyperparams {
regularizer { regularizer {
l2_regularizer { l2_regularizer {
...@@ -113,6 +122,8 @@ class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase): ...@@ -113,6 +122,8 @@ class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase):
}""" }"""
model_proto = model_pb2.DetectionModel() model_proto = model_pb2.DetectionModel()
text_format.Merge(model_text_proto, model_proto) text_format.Merge(model_text_proto, model_proto)
model_proto.ssd.feature_extractor.type = (self.
default_ssd_feature_extractor())
return model_proto return model_proto
def create_default_faster_rcnn_model_proto(self): def create_default_faster_rcnn_model_proto(self):
...@@ -127,9 +138,6 @@ class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase): ...@@ -127,9 +138,6 @@ class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase):
max_dimension: 1024 max_dimension: 1024
} }
} }
feature_extractor {
type: 'faster_rcnn_resnet101'
}
first_stage_anchor_generator { first_stage_anchor_generator {
grid_anchor_generator { grid_anchor_generator {
scales: [0.25, 0.5, 1.0, 2.0] scales: [0.25, 0.5, 1.0, 2.0]
...@@ -188,17 +196,14 @@ class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase): ...@@ -188,17 +196,14 @@ class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase):
}""" }"""
model_proto = model_pb2.DetectionModel() model_proto = model_pb2.DetectionModel()
text_format.Merge(model_text_proto, model_proto) text_format.Merge(model_text_proto, model_proto)
(model_proto.faster_rcnn.feature_extractor.type
) = self.default_faster_rcnn_feature_extractor()
return model_proto return model_proto
def test_create_ssd_models_from_config(self): def test_create_ssd_models_from_config(self):
model_proto = self.create_default_ssd_model_proto() model_proto = self.create_default_ssd_model_proto()
ssd_feature_extractor_map = {} for extractor_type, extractor_class in self.ssd_feature_extractors().items(
ssd_feature_extractor_map.update( ):
model_builder.SSD_FEATURE_EXTRACTOR_CLASS_MAP)
ssd_feature_extractor_map.update(
model_builder.SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP)
for extractor_type, extractor_class in ssd_feature_extractor_map.items():
model_proto.ssd.feature_extractor.type = extractor_type model_proto.ssd.feature_extractor.type = extractor_type
model = model_builder.build(model_proto, is_training=True) model = model_builder.build(model_proto, is_training=True)
self.assertIsInstance(model, ssd_meta_arch.SSDMetaArch) self.assertIsInstance(model, ssd_meta_arch.SSDMetaArch)
...@@ -206,12 +211,9 @@ class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase): ...@@ -206,12 +211,9 @@ class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase):
def test_create_ssd_fpn_model_from_config(self): def test_create_ssd_fpn_model_from_config(self):
model_proto = self.create_default_ssd_model_proto() model_proto = self.create_default_ssd_model_proto()
model_proto.ssd.feature_extractor.type = 'ssd_resnet101_v1_fpn'
model_proto.ssd.feature_extractor.fpn.min_level = 3 model_proto.ssd.feature_extractor.fpn.min_level = 3
model_proto.ssd.feature_extractor.fpn.max_level = 7 model_proto.ssd.feature_extractor.fpn.max_level = 7
model = model_builder.build(model_proto, is_training=True) model = model_builder.build(model_proto, is_training=True)
self.assertIsInstance(model._feature_extractor,
ssd_resnet_v1_fpn.SSDResnet101V1FpnFeatureExtractor)
self.assertEqual(model._feature_extractor._fpn_min_level, 3) self.assertEqual(model._feature_extractor._fpn_min_level, 3)
self.assertEqual(model._feature_extractor._fpn_max_level, 7) self.assertEqual(model._feature_extractor._fpn_max_level, 7)
...@@ -238,8 +240,9 @@ class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase): ...@@ -238,8 +240,9 @@ class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase):
'enable_mask_prediction': False 'enable_mask_prediction': False
}, },
) )
def test_create_faster_rcnn_models_from_config( def test_create_faster_rcnn_models_from_config(self,
self, use_matmul_crop_and_resize, enable_mask_prediction): use_matmul_crop_and_resize,
enable_mask_prediction):
model_proto = self.create_default_faster_rcnn_model_proto() model_proto = self.create_default_faster_rcnn_model_proto()
faster_rcnn_config = model_proto.faster_rcnn faster_rcnn_config = model_proto.faster_rcnn
faster_rcnn_config.use_matmul_crop_and_resize = use_matmul_crop_and_resize faster_rcnn_config.use_matmul_crop_and_resize = use_matmul_crop_and_resize
...@@ -250,7 +253,7 @@ class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase): ...@@ -250,7 +253,7 @@ class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase):
mask_predictor_config.predict_instance_masks = True mask_predictor_config.predict_instance_masks = True
for extractor_type, extractor_class in ( for extractor_type, extractor_class in (
model_builder.FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP.items()): self.faster_rcnn_feature_extractors().items()):
faster_rcnn_config.feature_extractor.type = extractor_type faster_rcnn_config.feature_extractor.type = extractor_type
model = model_builder.build(model_proto, is_training=True) model = model_builder.build(model_proto, is_training=True)
self.assertIsInstance(model, faster_rcnn_meta_arch.FasterRCNNMetaArch) self.assertIsInstance(model, faster_rcnn_meta_arch.FasterRCNNMetaArch)
...@@ -270,52 +273,59 @@ class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase): ...@@ -270,52 +273,59 @@ class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase):
model_proto.faster_rcnn.second_stage_box_predictor.rfcn_box_predictor) model_proto.faster_rcnn.second_stage_box_predictor.rfcn_box_predictor)
rfcn_predictor_config.conv_hyperparams.op = hyperparams_pb2.Hyperparams.CONV rfcn_predictor_config.conv_hyperparams.op = hyperparams_pb2.Hyperparams.CONV
for extractor_type, extractor_class in ( for extractor_type, extractor_class in (
model_builder.FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP.items()): self.faster_rcnn_feature_extractors().items()):
model_proto.faster_rcnn.feature_extractor.type = extractor_type model_proto.faster_rcnn.feature_extractor.type = extractor_type
model = model_builder.build(model_proto, is_training=True) model = model_builder.build(model_proto, is_training=True)
self.assertIsInstance(model, rfcn_meta_arch.RFCNMetaArch) self.assertIsInstance(model, rfcn_meta_arch.RFCNMetaArch)
self.assertIsInstance(model._feature_extractor, extractor_class) self.assertIsInstance(model._feature_extractor, extractor_class)
@parameterized.parameters(True, False)
def test_create_faster_rcnn_from_config_with_crop_feature(
self, output_final_box_features):
model_proto = self.create_default_faster_rcnn_model_proto()
model_proto.faster_rcnn.output_final_box_features = (
output_final_box_features)
_ = model_builder.build(model_proto, is_training=True)
def test_invalid_model_config_proto(self): def test_invalid_model_config_proto(self):
model_proto = '' model_proto = ''
with self.assertRaisesRegexp( with self.assertRaisesRegex(
ValueError, 'model_config not of type model_pb2.DetectionModel.'): ValueError, 'model_config not of type model_pb2.DetectionModel.'):
model_builder.build(model_proto, is_training=True) model_builder.build(model_proto, is_training=True)
def test_unknown_meta_architecture(self): def test_unknown_meta_architecture(self):
model_proto = model_pb2.DetectionModel() model_proto = model_pb2.DetectionModel()
with self.assertRaisesRegexp(ValueError, 'Unknown meta architecture'): with self.assertRaisesRegex(ValueError, 'Unknown meta architecture'):
model_builder.build(model_proto, is_training=True) model_builder.build(model_proto, is_training=True)
def test_unknown_ssd_feature_extractor(self): def test_unknown_ssd_feature_extractor(self):
model_proto = self.create_default_ssd_model_proto() model_proto = self.create_default_ssd_model_proto()
model_proto.ssd.feature_extractor.type = 'unknown_feature_extractor' model_proto.ssd.feature_extractor.type = 'unknown_feature_extractor'
with self.assertRaisesRegexp(ValueError, 'Unknown ssd feature_extractor'): with self.assertRaises(ValueError):
model_builder.build(model_proto, is_training=True) model_builder.build(model_proto, is_training=True)
def test_unknown_faster_rcnn_feature_extractor(self): def test_unknown_faster_rcnn_feature_extractor(self):
model_proto = self.create_default_faster_rcnn_model_proto() model_proto = self.create_default_faster_rcnn_model_proto()
model_proto.faster_rcnn.feature_extractor.type = 'unknown_feature_extractor' model_proto.faster_rcnn.feature_extractor.type = 'unknown_feature_extractor'
with self.assertRaisesRegexp(ValueError, with self.assertRaises(ValueError):
'Unknown Faster R-CNN feature_extractor'):
model_builder.build(model_proto, is_training=True) model_builder.build(model_proto, is_training=True)
def test_invalid_first_stage_nms_iou_threshold(self): def test_invalid_first_stage_nms_iou_threshold(self):
model_proto = self.create_default_faster_rcnn_model_proto() model_proto = self.create_default_faster_rcnn_model_proto()
model_proto.faster_rcnn.first_stage_nms_iou_threshold = 1.1 model_proto.faster_rcnn.first_stage_nms_iou_threshold = 1.1
with self.assertRaisesRegexp(ValueError, with self.assertRaisesRegex(ValueError,
r'iou_threshold not in \[0, 1\.0\]'): r'iou_threshold not in \[0, 1\.0\]'):
model_builder.build(model_proto, is_training=True) model_builder.build(model_proto, is_training=True)
model_proto.faster_rcnn.first_stage_nms_iou_threshold = -0.1 model_proto.faster_rcnn.first_stage_nms_iou_threshold = -0.1
with self.assertRaisesRegexp(ValueError, with self.assertRaisesRegex(ValueError,
r'iou_threshold not in \[0, 1\.0\]'): r'iou_threshold not in \[0, 1\.0\]'):
model_builder.build(model_proto, is_training=True) model_builder.build(model_proto, is_training=True)
def test_invalid_second_stage_batch_size(self): def test_invalid_second_stage_batch_size(self):
model_proto = self.create_default_faster_rcnn_model_proto() model_proto = self.create_default_faster_rcnn_model_proto()
model_proto.faster_rcnn.first_stage_max_proposals = 1 model_proto.faster_rcnn.first_stage_max_proposals = 1
model_proto.faster_rcnn.second_stage_batch_size = 2 model_proto.faster_rcnn.second_stage_batch_size = 2
with self.assertRaisesRegexp( with self.assertRaisesRegex(
ValueError, 'second_stage_batch_size should be no greater ' ValueError, 'second_stage_batch_size should be no greater '
'than first_stage_max_proposals.'): 'than first_stage_max_proposals.'):
model_builder.build(model_proto, is_training=True) model_builder.build(model_proto, is_training=True)
...@@ -323,8 +333,8 @@ class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase): ...@@ -323,8 +333,8 @@ class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase):
def test_invalid_faster_rcnn_batchnorm_update(self): def test_invalid_faster_rcnn_batchnorm_update(self):
model_proto = self.create_default_faster_rcnn_model_proto() model_proto = self.create_default_faster_rcnn_model_proto()
model_proto.faster_rcnn.inplace_batchnorm_update = True model_proto.faster_rcnn.inplace_batchnorm_update = True
with self.assertRaisesRegexp(ValueError, with self.assertRaisesRegex(ValueError,
'inplace batchnorm updates not supported'): 'inplace batchnorm updates not supported'):
model_builder.build(model_proto, is_training=True) model_builder.build(model_proto, is_training=True)
def test_create_experimental_model(self): def test_create_experimental_model(self):
...@@ -340,7 +350,3 @@ class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase): ...@@ -340,7 +350,3 @@ class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase):
text_format.Merge(model_text_proto, model_proto) text_format.Merge(model_text_proto, model_proto)
self.assertEqual(model_builder.build(model_proto, is_training=True), 42) self.assertEqual(model_builder.build(model_proto, is_training=True), 42)
if __name__ == '__main__':
tf.test.main()
# Lint as: python2, python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for model_builder under TensorFlow 1.X."""
from absl.testing import parameterized
import tensorflow as tf
from object_detection.builders import model_builder
from object_detection.builders import model_builder_test
from object_detection.meta_architectures import ssd_meta_arch
from object_detection.protos import losses_pb2
class ModelBuilderTF1Test(model_builder_test.ModelBuilderTest):
def default_ssd_feature_extractor(self):
return 'ssd_resnet50_v1_fpn'
def default_faster_rcnn_feature_extractor(self):
return 'faster_rcnn_resnet101'
def ssd_feature_extractors(self):
return model_builder.SSD_FEATURE_EXTRACTOR_CLASS_MAP
def faster_rcnn_feature_extractors(self):
return model_builder.FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP
if __name__ == '__main__':
tf.test.main()
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
import tensorflow as tf import tensorflow as tf
from tensorflow.contrib import opt as tf_opt
from object_detection.utils import learning_schedules from object_detection.utils import learning_schedules
...@@ -64,14 +65,14 @@ def build_optimizers_tf_v1(optimizer_config, global_step=None): ...@@ -64,14 +65,14 @@ def build_optimizers_tf_v1(optimizer_config, global_step=None):
learning_rate = _create_learning_rate(config.learning_rate, learning_rate = _create_learning_rate(config.learning_rate,
global_step=global_step) global_step=global_step)
summary_vars.append(learning_rate) summary_vars.append(learning_rate)
optimizer = tf.train.AdamOptimizer(learning_rate) optimizer = tf.train.AdamOptimizer(learning_rate, epsilon=config.epsilon)
if optimizer is None: if optimizer is None:
raise ValueError('Optimizer %s not supported.' % optimizer_type) raise ValueError('Optimizer %s not supported.' % optimizer_type)
if optimizer_config.use_moving_average: if optimizer_config.use_moving_average:
optimizer = tf.contrib.opt.MovingAverageOptimizer( optimizer = tf_opt.MovingAverageOptimizer(
optimizer, average_decay=optimizer_config.moving_average_decay) optimizer, average_decay=optimizer_config.moving_average_decay)
return optimizer, summary_vars return optimizer, summary_vars
...@@ -120,7 +121,7 @@ def build_optimizers_tf_v2(optimizer_config, global_step=None): ...@@ -120,7 +121,7 @@ def build_optimizers_tf_v2(optimizer_config, global_step=None):
learning_rate = _create_learning_rate(config.learning_rate, learning_rate = _create_learning_rate(config.learning_rate,
global_step=global_step) global_step=global_step)
summary_vars.append(learning_rate) summary_vars.append(learning_rate)
optimizer = tf.keras.optimizers.Adam(learning_rate) optimizer = tf.keras.optimizers.Adam(learning_rate, epsilon=config.epsilon)
if optimizer is None: if optimizer is None:
raise ValueError('Optimizer %s not supported.' % optimizer_type) raise ValueError('Optimizer %s not supported.' % optimizer_type)
......
+# Lint as: python2, python3
 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,6 +16,11 @@
 """Tests for optimizer_builder."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
 import tensorflow as tf
 from google.protobuf import text_format
@@ -22,6 +28,14 @@ from google.protobuf import text_format
 from object_detection.builders import optimizer_builder
 from object_detection.protos import optimizer_pb2
+# pylint: disable=g-import-not-at-top
+try:
+  from tensorflow.contrib import opt as contrib_opt
+except ImportError:
+  # TF 2.0 doesn't ship with contrib.
+  pass
+# pylint: enable=g-import-not-at-top
 class LearningRateBuilderTest(tf.test.TestCase):
@@ -35,7 +49,8 @@ class LearningRateBuilderTest(tf.test.TestCase):
     text_format.Merge(learning_rate_text_proto, learning_rate_proto)
     learning_rate = optimizer_builder._create_learning_rate(
         learning_rate_proto)
-    self.assertTrue(learning_rate.op.name.endswith('learning_rate'))
+    self.assertTrue(
+        six.ensure_str(learning_rate.op.name).endswith('learning_rate'))
     with self.test_session():
       learning_rate_out = learning_rate.eval()
       self.assertAlmostEqual(learning_rate_out, 0.004)
@@ -53,8 +68,9 @@ class LearningRateBuilderTest(tf.test.TestCase):
     text_format.Merge(learning_rate_text_proto, learning_rate_proto)
     learning_rate = optimizer_builder._create_learning_rate(
         learning_rate_proto)
-    self.assertTrue(learning_rate.op.name.endswith('learning_rate'))
-    self.assertTrue(isinstance(learning_rate, tf.Tensor))
+    self.assertTrue(
+        six.ensure_str(learning_rate.op.name).endswith('learning_rate'))
+    self.assertIsInstance(learning_rate, tf.Tensor)
 
   def testBuildManualStepLearningRate(self):
     learning_rate_text_proto = """
@@ -75,7 +91,7 @@ class LearningRateBuilderTest(tf.test.TestCase):
     text_format.Merge(learning_rate_text_proto, learning_rate_proto)
     learning_rate = optimizer_builder._create_learning_rate(
         learning_rate_proto)
-    self.assertTrue(isinstance(learning_rate, tf.Tensor))
+    self.assertIsInstance(learning_rate, tf.Tensor)
 
   def testBuildCosineDecayLearningRate(self):
     learning_rate_text_proto = """
@@ -91,7 +107,7 @@ class LearningRateBuilderTest(tf.test.TestCase):
     text_format.Merge(learning_rate_text_proto, learning_rate_proto)
     learning_rate = optimizer_builder._create_learning_rate(
         learning_rate_proto)
-    self.assertTrue(isinstance(learning_rate, tf.Tensor))
+    self.assertIsInstance(learning_rate, tf.Tensor)
 
   def testRaiseErrorOnEmptyLearningRate(self):
     learning_rate_text_proto = """
@@ -123,7 +139,7 @@ class OptimizerBuilderTest(tf.test.TestCase):
     optimizer_proto = optimizer_pb2.Optimizer()
     text_format.Merge(optimizer_text_proto, optimizer_proto)
     optimizer, _ = optimizer_builder.build(optimizer_proto)
-    self.assertTrue(isinstance(optimizer, tf.train.RMSPropOptimizer))
+    self.assertIsInstance(optimizer, tf.train.RMSPropOptimizer)
 
   def testBuildMomentumOptimizer(self):
     optimizer_text_proto = """
@@ -140,11 +156,12 @@ class OptimizerBuilderTest(tf.test.TestCase):
     optimizer_proto = optimizer_pb2.Optimizer()
     text_format.Merge(optimizer_text_proto, optimizer_proto)
     optimizer, _ = optimizer_builder.build(optimizer_proto)
-    self.assertTrue(isinstance(optimizer, tf.train.MomentumOptimizer))
+    self.assertIsInstance(optimizer, tf.train.MomentumOptimizer)
 
   def testBuildAdamOptimizer(self):
     optimizer_text_proto = """
       adam_optimizer: {
+        epsilon: 1e-6
         learning_rate: {
           constant_learning_rate {
             learning_rate: 0.002
@@ -156,7 +173,7 @@ class OptimizerBuilderTest(tf.test.TestCase):
     optimizer_proto = optimizer_pb2.Optimizer()
     text_format.Merge(optimizer_text_proto, optimizer_proto)
     optimizer, _ = optimizer_builder.build(optimizer_proto)
-    self.assertTrue(isinstance(optimizer, tf.train.AdamOptimizer))
+    self.assertIsInstance(optimizer, tf.train.AdamOptimizer)
 
   def testBuildMovingAverageOptimizer(self):
     optimizer_text_proto = """
@@ -172,8 +189,7 @@ class OptimizerBuilderTest(tf.test.TestCase):
     optimizer_proto = optimizer_pb2.Optimizer()
     text_format.Merge(optimizer_text_proto, optimizer_proto)
     optimizer, _ = optimizer_builder.build(optimizer_proto)
-    self.assertTrue(
-        isinstance(optimizer, tf.contrib.opt.MovingAverageOptimizer))
+    self.assertIsInstance(optimizer, contrib_opt.MovingAverageOptimizer)
 
   def testBuildMovingAverageOptimizerWithNonDefaultDecay(self):
     optimizer_text_proto = """
@@ -190,8 +206,7 @@ class OptimizerBuilderTest(tf.test.TestCase):
     optimizer_proto = optimizer_pb2.Optimizer()
     text_format.Merge(optimizer_text_proto, optimizer_proto)
     optimizer, _ = optimizer_builder.build(optimizer_proto)
-    self.assertTrue(
-        isinstance(optimizer, tf.contrib.opt.MovingAverageOptimizer))
+    self.assertIsInstance(optimizer, contrib_opt.MovingAverageOptimizer)
     # TODO(rathodv): Find a way to not depend on the private members.
     self.assertAlmostEqual(optimizer._ema._decay, 0.2)
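The guarded contrib import added at the top of this file is what lets the same test module load under both TF versions: tf.contrib exists only in TF 1.x, so the import is attempted and silently skipped on TF 2.x, and the MovingAverageOptimizer assertions reference the local contrib_opt alias instead of tf.contrib.opt. A minimal standalone sketch of the same pattern (the _HAS_CONTRIB flag is illustrative, not part of this change):

# pylint: disable=g-import-not-at-top
try:
  from tensorflow.contrib import opt as contrib_opt  # TF 1.x only.
  _HAS_CONTRIB = True
except ImportError:
  # TF 2.x does not ship tf.contrib; keep a sentinel so dependent tests
  # can skip themselves instead of failing at import time.
  contrib_opt = None
  _HAS_CONTRIB = False
# pylint: enable=g-import-not-at-top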
...
@@ -102,7 +102,7 @@ def _build_non_max_suppressor(nms_config):
         soft_nms_sigma=nms_config.soft_nms_sigma,
         use_partitioned_nms=nms_config.use_partitioned_nms,
         use_combined_nms=nms_config.use_combined_nms,
-        change_coordinate_frame=True)
+        change_coordinate_frame=nms_config.change_coordinate_frame)
   return non_max_suppressor_fn
@@ -110,7 +110,7 @@ def _build_non_max_suppressor(nms_config):
 def _score_converter_fn_with_logit_scale(tf_score_converter_fn, logit_scale):
   """Create a function to scale logits then apply a Tensorflow function."""
   def score_converter_fn(logits):
-    scaled_logits = tf.divide(logits, logit_scale, name='scale_logits')
+    scaled_logits = tf.multiply(logits, 1.0 / logit_scale, name='scale_logits')
     return tf_score_converter_fn(scaled_logits, name='convert_scores')
   score_converter_fn.__name__ = '%s_with_logit_scale' % (
       tf_score_converter_fn.__name__)
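The scale_logits change swaps a per-element division for a multiplication by the precomputed reciprocal; for a scalar logit_scale the two are mathematically identical, so behavior is unchanged. A runnable sketch of the resulting closure, assuming (as the builder's callers require) that the wrapped function is returned; using tf.sigmoid as the converter is just an example:

import tensorflow as tf

def score_converter_fn_with_logit_scale(tf_score_converter_fn, logit_scale):
  """Sketch of the closure above: scale logits, then convert to scores."""
  def score_converter_fn(logits):
    # Equivalent to tf.divide(logits, logit_scale) for a scalar logit_scale,
    # expressed as a single elementwise multiply.
    scaled_logits = tf.multiply(logits, 1.0 / logit_scale,
                                name='scale_logits')
    return tf_score_converter_fn(scaled_logits, name='convert_scores')
  score_converter_fn.__name__ = '%s_with_logit_scale' % (
      tf_score_converter_fn.__name__)
  return score_converter_fn

converter = score_converter_fn_with_logit_scale(tf.sigmoid, 2.0)
# converter(logits) now computes tf.sigmoid(logits / 2.0).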
...
@@ -150,7 +150,7 @@ def build(preprocessor_step_config):
     return (preprocessor.random_horizontal_flip,
             {
                 'keypoint_flip_permutation': tuple(
-                    config.keypoint_flip_permutation),
+                    config.keypoint_flip_permutation) or None,
             })
 
   if step_type == 'random_vertical_flip':
@@ -158,7 +158,7 @@ def build(preprocessor_step_config):
     return (preprocessor.random_vertical_flip,
             {
                 'keypoint_flip_permutation': tuple(
-                    config.keypoint_flip_permutation),
+                    config.keypoint_flip_permutation) or None,
             })
 
   if step_type == 'random_rotation90':
@@ -400,4 +400,13 @@ def build(preprocessor_step_config):
       kwargs['random_coef'] = [op.random_coef for op in config.operations]
     return (preprocessor.ssd_random_crop_pad_fixed_aspect_ratio, kwargs)
 
+  if step_type == 'random_square_crop_by_scale':
+    config = preprocessor_step_config.random_square_crop_by_scale
+    return preprocessor.random_square_crop_by_scale, {
+        'scale_min': config.scale_min,
+        'scale_max': config.scale_max,
+        'max_border': config.max_border,
+        'num_scales': config.num_scales
+    }
+
   raise ValueError('Unknown preprocessing step.')
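Both flip steps now pass tuple(config.keypoint_flip_permutation) or None. A repeated proto field that was never populated converts to an empty tuple, which is falsy in Python, so the kwarg collapses to None and the preprocessor falls back to its default behavior instead of receiving an empty permutation. A two-line illustration of the idiom (values are made up):

assert (tuple([]) or None) is None               # unset repeated field -> None
assert (tuple([1, 0, 2]) or None) == (1, 0, 2)   # populated field passes through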
@@ -723,6 +723,25 @@ class PreprocessorBuilderTest(tf.test.TestCase):
     self.assertEqual(function, preprocessor.convert_class_logits_to_softmax)
     self.assertEqual(args, {'temperature': 2})
 
+  def test_random_crop_by_scale(self):
+    preprocessor_text_proto = """
+    random_square_crop_by_scale {
+      scale_min: 0.25
+      scale_max: 2.0
+      num_scales: 8
+    }
+    """
+    preprocessor_proto = preprocessor_pb2.PreprocessingStep()
+    text_format.Merge(preprocessor_text_proto, preprocessor_proto)
+    function, args = preprocessor_builder.build(preprocessor_proto)
+    self.assertEqual(function, preprocessor.random_square_crop_by_scale)
+    self.assertEqual(args, {
+        'scale_min': 0.25,
+        'scale_max': 2.0,
+        'num_scales': 8,
+        'max_border': 128
+    })
+
 
 if __name__ == '__main__':
   tf.test.main()
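Note that the new test never sets max_border in its text proto yet expects 128 in the built kwargs, implying the random_square_crop_by_scale message declares that field with a default of 128. A sketch of checking the default directly (field and message names come from the diff; the exact default lives in preprocessor.proto):

from google.protobuf import text_format
from object_detection.protos import preprocessor_pb2

step = preprocessor_pb2.PreprocessingStep()
text_format.Merge(
    'random_square_crop_by_scale { scale_min: 0.25 scale_max: 2.0 }', step)
# Unset fields fall back to the defaults declared in the proto definition,
# which is where the 128 in the test's expected kwargs comes from.
assert step.random_square_crop_by_scale.max_border == 128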
@@ -33,7 +33,6 @@ when number of examples set to True in indicator is less than batch_size.
 import tensorflow as tf
 
 from object_detection.core import minibatch_sampler
-from object_detection.utils import ops
 
 
 class BalancedPositiveNegativeSampler(minibatch_sampler.MinibatchSampler):
@@ -158,19 +157,17 @@ class BalancedPositiveNegativeSampler(minibatch_sampler.MinibatchSampler):
     # Shuffle indicator and label. Need to store the permutation to restore the
     # order post sampling.
     permutation = tf.random_shuffle(tf.range(input_length))
-    indicator = ops.matmul_gather_on_zeroth_axis(
-        tf.cast(indicator, tf.float32), permutation)
-    labels = ops.matmul_gather_on_zeroth_axis(
-        tf.cast(labels, tf.float32), permutation)
+    indicator = tf.gather(indicator, permutation, axis=0)
+    labels = tf.gather(labels, permutation, axis=0)
 
     # index (starting from 1) when indicator is True, 0 when False
     indicator_idx = tf.where(
-        tf.cast(indicator, tf.bool), tf.range(1, input_length + 1),
+        indicator, tf.range(1, input_length + 1),
         tf.zeros(input_length, tf.int32))
 
     # Replace -1 for negative, +1 for positive labels
     signed_label = tf.where(
-        tf.cast(labels, tf.bool), tf.ones(input_length, tf.int32),
+        labels, tf.ones(input_length, tf.int32),
         tf.scalar_mul(-1, tf.ones(input_length, tf.int32)))
     # negative of index for negative label, positive index for positive label,
     # 0 when indicator is False.
@@ -198,11 +195,10 @@ class BalancedPositiveNegativeSampler(minibatch_sampler.MinibatchSampler):
         axis=0), tf.bool)
 
     # project back the order based on stored permutations
-    reprojections = tf.one_hot(permutation, depth=input_length,
-                               dtype=tf.float32)
-    return tf.cast(tf.tensordot(
-        tf.cast(sampled_idx_indicator, tf.float32),
-        reprojections, axes=[0, 0]), tf.bool)
+    idx_indicator = tf.scatter_nd(
+        tf.expand_dims(permutation, -1), sampled_idx_indicator,
+        shape=(input_length,))
+    return idx_indicator
 
   def subsample(self, indicator, batch_size, labels, scope=None):
     """Returns subsampled minibatch.
...
@@ -24,24 +24,27 @@ from object_detection.utils import test_case
 class BalancedPositiveNegativeSamplerTest(test_case.TestCase):
 
-  def test_subsample_all_examples_dynamic(self):
+  def test_subsample_all_examples(self):
+    if self.has_tpu(): return
     numpy_labels = np.random.permutation(300)
-    indicator = tf.constant(np.ones(300) == 1)
+    indicator = np.array(np.ones(300) == 1, np.bool)
     numpy_labels = (numpy_labels - 200) > 0
-    labels = tf.constant(numpy_labels)
-    sampler = (
-        balanced_positive_negative_sampler.BalancedPositiveNegativeSampler())
-    is_sampled = sampler.subsample(indicator, 64, labels)
-    with self.test_session() as sess:
-      is_sampled = sess.run(is_sampled)
-      self.assertTrue(sum(is_sampled) == 64)
-      self.assertTrue(sum(np.logical_and(numpy_labels, is_sampled)) == 32)
-      self.assertTrue(sum(np.logical_and(
-          np.logical_not(numpy_labels), is_sampled)) == 32)
+    labels = np.array(numpy_labels, np.bool)
+
+    def graph_fn(indicator, labels):
+      sampler = (
+          balanced_positive_negative_sampler.BalancedPositiveNegativeSampler())
+      return sampler.subsample(indicator, 64, labels)
+
+    is_sampled = self.execute_cpu(graph_fn, [indicator, labels])
+    self.assertEqual(sum(is_sampled), 64)
+    self.assertEqual(sum(np.logical_and(numpy_labels, is_sampled)), 32)
+    self.assertEqual(sum(np.logical_and(
+        np.logical_not(numpy_labels), is_sampled)), 32)
 
   def test_subsample_all_examples_static(self):
+    if not self.has_tpu(): return
     numpy_labels = np.random.permutation(300)
     indicator = np.array(np.ones(300) == 1, np.bool)
     numpy_labels = (numpy_labels - 200) > 0
@@ -54,35 +57,37 @@ class BalancedPositiveNegativeSamplerTest(test_case.TestCase):
           is_static=True))
       return sampler.subsample(indicator, 64, labels)
 
-    is_sampled = self.execute(graph_fn, [indicator, labels])
-    self.assertTrue(sum(is_sampled) == 64)
-    self.assertTrue(sum(np.logical_and(numpy_labels, is_sampled)) == 32)
-    self.assertTrue(sum(np.logical_and(
-        np.logical_not(numpy_labels), is_sampled)) == 32)
+    is_sampled = self.execute_tpu(graph_fn, [indicator, labels])
+    self.assertEqual(sum(is_sampled), 64)
+    self.assertEqual(sum(np.logical_and(numpy_labels, is_sampled)), 32)
+    self.assertEqual(sum(np.logical_and(
+        np.logical_not(numpy_labels), is_sampled)), 32)
 
-  def test_subsample_selection_dynamic(self):
+  def test_subsample_selection(self):
+    if self.has_tpu(): return
     # Test random sampling when only some examples can be sampled:
-    # 100 samples, 20 positives, 10 positives cannot be sampled
+    # 100 samples, 20 positives, 10 positives cannot be sampled.
     numpy_labels = np.arange(100)
     numpy_indicator = numpy_labels < 90
-    indicator = tf.constant(numpy_indicator)
+    indicator = np.array(numpy_indicator, np.bool)
     numpy_labels = (numpy_labels - 80) >= 0
-    labels = tf.constant(numpy_labels)
-    sampler = (
-        balanced_positive_negative_sampler.BalancedPositiveNegativeSampler())
-    is_sampled = sampler.subsample(indicator, 64, labels)
-    with self.test_session() as sess:
-      is_sampled = sess.run(is_sampled)
-      self.assertTrue(sum(is_sampled) == 64)
-      self.assertTrue(sum(np.logical_and(numpy_labels, is_sampled)) == 10)
-      self.assertTrue(sum(np.logical_and(
-          np.logical_not(numpy_labels), is_sampled)) == 54)
-      self.assertAllEqual(is_sampled, np.logical_and(is_sampled,
-                                                     numpy_indicator))
+    labels = np.array(numpy_labels, np.bool)
+
+    def graph_fn(indicator, labels):
+      sampler = (
+          balanced_positive_negative_sampler.BalancedPositiveNegativeSampler())
+      return sampler.subsample(indicator, 64, labels)
+
+    is_sampled = self.execute_cpu(graph_fn, [indicator, labels])
+    self.assertEqual(sum(is_sampled), 64)
+    self.assertEqual(sum(np.logical_and(numpy_labels, is_sampled)), 10)
+    self.assertEqual(sum(np.logical_and(
+        np.logical_not(numpy_labels), is_sampled)), 54)
+    self.assertAllEqual(is_sampled, np.logical_and(is_sampled, numpy_indicator))
 
   def test_subsample_selection_static(self):
+    if not self.has_tpu(): return
     # Test random sampling when only some examples can be sampled:
     # 100 samples, 20 positives, 10 positives cannot be sampled.
     numpy_labels = np.arange(100)
@@ -98,37 +103,41 @@ class BalancedPositiveNegativeSamplerTest(test_case.TestCase):
           is_static=True))
       return sampler.subsample(indicator, 64, labels)
 
-    is_sampled = self.execute(graph_fn, [indicator, labels])
-    self.assertTrue(sum(is_sampled) == 64)
-    self.assertTrue(sum(np.logical_and(numpy_labels, is_sampled)) == 10)
-    self.assertTrue(sum(np.logical_and(
-        np.logical_not(numpy_labels), is_sampled)) == 54)
+    is_sampled = self.execute_tpu(graph_fn, [indicator, labels])
+    self.assertEqual(sum(is_sampled), 64)
+    self.assertEqual(sum(np.logical_and(numpy_labels, is_sampled)), 10)
+    self.assertEqual(sum(np.logical_and(
+        np.logical_not(numpy_labels), is_sampled)), 54)
     self.assertAllEqual(is_sampled, np.logical_and(is_sampled, numpy_indicator))
 
-  def test_subsample_selection_larger_batch_size_dynamic(self):
+  def test_subsample_selection_larger_batch_size(self):
+    if self.has_tpu(): return
    # Test random sampling when total number of examples that can be sampled are
     # less than batch size:
     # 100 samples, 50 positives, 40 positives cannot be sampled, batch size 64.
+    # It should still return 64 samples, with 4 of them that couldn't have been
+    # sampled.
     numpy_labels = np.arange(100)
     numpy_indicator = numpy_labels < 60
-    indicator = tf.constant(numpy_indicator)
+    indicator = np.array(numpy_indicator, np.bool)
     numpy_labels = (numpy_labels - 50) >= 0
-    labels = tf.constant(numpy_labels)
-    sampler = (
-        balanced_positive_negative_sampler.BalancedPositiveNegativeSampler())
-    is_sampled = sampler.subsample(indicator, 64, labels)
-    with self.test_session() as sess:
-      is_sampled = sess.run(is_sampled)
-      self.assertTrue(sum(is_sampled) == 60)
-      self.assertTrue(sum(np.logical_and(numpy_labels, is_sampled)) == 10)
-      self.assertTrue(
-          sum(np.logical_and(np.logical_not(numpy_labels), is_sampled)) == 50)
-      self.assertAllEqual(is_sampled, np.logical_and(is_sampled,
-                                                     numpy_indicator))
+    labels = np.array(numpy_labels, np.bool)
+
+    def graph_fn(indicator, labels):
+      sampler = (
+          balanced_positive_negative_sampler.BalancedPositiveNegativeSampler())
+      return sampler.subsample(indicator, 64, labels)
+
+    is_sampled = self.execute_cpu(graph_fn, [indicator, labels])
+    self.assertEqual(sum(is_sampled), 60)
+    self.assertGreaterEqual(sum(np.logical_and(numpy_labels, is_sampled)), 10)
+    self.assertGreaterEqual(
+        sum(np.logical_and(np.logical_not(numpy_labels), is_sampled)), 50)
+    self.assertEqual(sum(np.logical_and(is_sampled, numpy_indicator)), 60)
 
   def test_subsample_selection_larger_batch_size_static(self):
+    if not self.has_tpu(): return
    # Test random sampling when total number of examples that can be sampled are
     # less than batch size:
     # 100 samples, 50 positives, 40 positives cannot be sampled, batch size 64.
@@ -147,34 +156,33 @@ class BalancedPositiveNegativeSamplerTest(test_case.TestCase):
           is_static=True))
      return sampler.subsample(indicator, 64, labels)
 
-    is_sampled = self.execute(graph_fn, [indicator, labels])
-    self.assertTrue(sum(is_sampled) == 64)
-    self.assertTrue(sum(np.logical_and(numpy_labels, is_sampled)) >= 10)
-    self.assertTrue(
-        sum(np.logical_and(np.logical_not(numpy_labels), is_sampled)) >= 50)
-    self.assertTrue(sum(np.logical_and(is_sampled, numpy_indicator)) == 60)
+    is_sampled = self.execute_tpu(graph_fn, [indicator, labels])
+    self.assertEqual(sum(is_sampled), 64)
+    self.assertGreaterEqual(sum(np.logical_and(numpy_labels, is_sampled)), 10)
+    self.assertGreaterEqual(
+        sum(np.logical_and(np.logical_not(numpy_labels), is_sampled)), 50)
+    self.assertEqual(sum(np.logical_and(is_sampled, numpy_indicator)), 60)
 
   def test_subsample_selection_no_batch_size(self):
+    if self.has_tpu(): return
     # Test random sampling when only some examples can be sampled:
     # 1000 samples, 6 positives (5 can be sampled).
     numpy_labels = np.arange(1000)
     numpy_indicator = numpy_labels < 999
-    indicator = tf.constant(numpy_indicator)
     numpy_labels = (numpy_labels - 994) >= 0
-    labels = tf.constant(numpy_labels)
-
-    sampler = (balanced_positive_negative_sampler.
-               BalancedPositiveNegativeSampler(0.01))
-    is_sampled = sampler.subsample(indicator, None, labels)
-    with self.test_session() as sess:
-      is_sampled = sess.run(is_sampled)
-      self.assertTrue(sum(is_sampled) == 500)
-      self.assertTrue(sum(np.logical_and(numpy_labels, is_sampled)) == 5)
-      self.assertTrue(sum(np.logical_and(
-          np.logical_not(numpy_labels), is_sampled)) == 495)
-      self.assertAllEqual(is_sampled, np.logical_and(is_sampled,
-                                                     numpy_indicator))
+
+    def graph_fn(indicator, labels):
+      sampler = (balanced_positive_negative_sampler.
+                 BalancedPositiveNegativeSampler(0.01))
+      is_sampled = sampler.subsample(indicator, None, labels)
+      return is_sampled
+
+    is_sampled_out = self.execute_cpu(graph_fn, [numpy_indicator, numpy_labels])
+    self.assertEqual(sum(is_sampled_out), 500)
+    self.assertEqual(sum(np.logical_and(numpy_labels, is_sampled_out)), 5)
+    self.assertEqual(sum(np.logical_and(
+        np.logical_not(numpy_labels), is_sampled_out)), 495)
+    self.assertAllEqual(is_sampled_out, np.logical_and(is_sampled_out,
                                                        numpy_indicator))
 
   def test_subsample_selection_no_batch_size_static(self):
     labels = tf.constant([[True, False, False]])
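All of the converted tests above follow one shape: inputs become numpy arrays, graph construction moves into a nested graph_fn, and the test_case.TestCase helpers run it, execute_cpu for the dynamic-shape variants (guarded by `if self.has_tpu(): return`) and execute_tpu for the is_static=True variants (guarded by the inverse). A skeleton of that pattern with an illustrative test body, reusing names from the diff:

class ExampleSamplerTest(test_case.TestCase):

  def test_dynamic_path(self):
    if self.has_tpu(): return  # Dynamic shapes are exercised on CPU only.
    indicator = np.ones(300, np.bool)
    labels = np.arange(300) > 200  # 99 positives out of 300 candidates.

    def graph_fn(indicator, labels):
      sampler = (
          balanced_positive_negative_sampler.BalancedPositiveNegativeSampler())
      return sampler.subsample(indicator, 64, labels)

    # Enough positives and negatives exist, so a full batch is returned.
    is_sampled = self.execute_cpu(graph_fn, [indicator, labels])
    self.assertEqual(sum(is_sampled), 64)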
...
@@ -24,6 +24,10 @@ from six.moves import range
 import tensorflow as tf
 
 from object_detection.core import prefetcher
+from object_detection.utils import tf_version
+
+if not tf_version.is_tf1():
+  raise ValueError('`batcher.py` is only supported in Tensorflow 1.X')
 
 rt_shape_str = '_runtime_shapes'
...
@@ -22,10 +22,11 @@ from __future__ import print_function
 import numpy as np
 from six.moves import range
 import tensorflow as tf
+from tensorflow.contrib import slim as contrib_slim
 
 from object_detection.core import batcher
 
-slim = tf.contrib.slim
+slim = contrib_slim
 
 
 class BatcherTest(tf.test.TestCase):
...
@@ -14,11 +14,11 @@
 # ==============================================================================
 
 """Tests for object_detection.core.box_coder."""
 import tensorflow as tf
 
 from object_detection.core import box_coder
 from object_detection.core import box_list
+from object_detection.utils import test_case
 
 
 class MockBoxCoder(box_coder.BoxCoder):
@@ -34,27 +34,28 @@ class MockBoxCoder(box_coder.BoxCoder):
     return box_list.BoxList(rel_codes / 2.0)
 
 
-class BoxCoderTest(tf.test.TestCase):
+class BoxCoderTest(test_case.TestCase):
 
   def test_batch_decode(self):
-    mock_anchor_corners = tf.constant(
-        [[0, 0.1, 0.2, 0.3], [0.2, 0.4, 0.4, 0.6]], tf.float32)
-    mock_anchors = box_list.BoxList(mock_anchor_corners)
-    mock_box_coder = MockBoxCoder()
-
     expected_boxes = [[[0.0, 0.1, 0.5, 0.6], [0.5, 0.6, 0.7, 0.8]],
                       [[0.1, 0.2, 0.3, 0.4], [0.7, 0.8, 0.9, 1.0]]]
-    encoded_boxes_list = [mock_box_coder.encode(
-        box_list.BoxList(tf.constant(boxes)), mock_anchors)
-                          for boxes in expected_boxes]
-    encoded_boxes = tf.stack(encoded_boxes_list)
-    decoded_boxes = box_coder.batch_decode(
-        encoded_boxes, mock_box_coder, mock_anchors)
 
-    with self.test_session() as sess:
-      decoded_boxes_result = sess.run(decoded_boxes)
-      self.assertAllClose(expected_boxes, decoded_boxes_result)
+    def graph_fn():
+      mock_anchor_corners = tf.constant(
+          [[0, 0.1, 0.2, 0.3], [0.2, 0.4, 0.4, 0.6]], tf.float32)
+      mock_anchors = box_list.BoxList(mock_anchor_corners)
+      mock_box_coder = MockBoxCoder()
+
+      encoded_boxes_list = [mock_box_coder.encode(
+          box_list.BoxList(tf.constant(boxes)), mock_anchors)
+                            for boxes in expected_boxes]
+      encoded_boxes = tf.stack(encoded_boxes_list)
+      decoded_boxes = box_coder.batch_decode(
+          encoded_boxes, mock_box_coder, mock_anchors)
      return decoded_boxes
+
+    decoded_boxes_result = self.execute(graph_fn, [])
+    self.assertAllClose(expected_boxes, decoded_boxes_result)
 
 
 if __name__ == '__main__':
...