Commit 1efe98bb authored by Zhichao Lu's avatar Zhichao Lu Committed by lzc5123016
Browse files

Merged commit includes the following changes:

185215255  by Zhichao Lu:

    Stop populating image/object/class/text field when generating COCO tf record.

--
185213306  by Zhichao Lu:

    Use the params batch size and not the one from train_config in input_fn

--
185209081  by Zhichao Lu:

    Handle the case when there are no ground-truth masks for an image.

--
185195531  by Zhichao Lu:

    Remove unstack and stack operations on features from third_party/object_detection/model.py.

--
185195017  by Zhichao Lu:

    Matrix multiplication based gather op implementation.

--
185187744  by Zhichao Lu:

    Fix eval_util minor issue.

--
185098733  by Zhichao Lu:

    Internal change

185076656  by Zhichao Lu:

    Increment the amount of boxes for coco17.

--
185074199  by Zhichao Lu:

    Add config for SSD Resnet50 v1 with FPN.

--
185060199  by Zhichao Lu:

    Fix a bug in clear_detections.
    This method set detection_keys to an empty dictionary instead of an empty set. I've refactored so that this ...
parent fbc5ba06
...@@ -4,10 +4,14 @@ package( ...@@ -4,10 +4,14 @@ package(
default_visibility = ["//visibility:public"], default_visibility = ["//visibility:public"],
) )
load("//learning/brain/contrib/learn/tpu:tpu.bzl", "cloud_tpu_py_binaries")
licenses(["notice"]) licenses(["notice"])
# Apache 2.0 # Apache 2.0
exports_files(["LICENSE"])
py_library( py_library(
name = "inputs", name = "inputs",
srcs = [ srcs = [
...@@ -15,11 +19,14 @@ py_library( ...@@ -15,11 +19,14 @@ py_library(
], ],
deps = [ deps = [
"//tensorflow", "//tensorflow",
"//tensorflow/models/research/object_detection:trainer",
"//tensorflow/models/research/object_detection/builders:dataset_builder", "//tensorflow/models/research/object_detection/builders:dataset_builder",
"//tensorflow/models/research/object_detection/builders:image_resizer_builder",
"//tensorflow/models/research/object_detection/builders:model_builder",
"//tensorflow/models/research/object_detection/builders:preprocessor_builder", "//tensorflow/models/research/object_detection/builders:preprocessor_builder",
"//tensorflow/models/research/object_detection/protos:input_reader_py_pb2", "//tensorflow/models/research/object_detection/protos:input_reader_py_pb2",
"//tensorflow/models/research/object_detection/protos:model_py_pb2",
"//tensorflow/models/research/object_detection/protos:train_py_pb2", "//tensorflow/models/research/object_detection/protos:train_py_pb2",
"//tensorflow/models/research/object_detection/utils:config_util",
"//tensorflow/models/research/object_detection/utils:dataset_util", "//tensorflow/models/research/object_detection/utils:dataset_util",
"//tensorflow/models/research/object_detection/utils:ops", "//tensorflow/models/research/object_detection/utils:ops",
], ],
...@@ -44,6 +51,109 @@ py_test( ...@@ -44,6 +51,109 @@ py_test(
], ],
) )
py_binary(
name = "model",
srcs = [
"model.py",
],
deps = [
":inputs",
":model_hparams",
"//tensorflow",
"//tensorflow/models/research/object_detection:eval_util",
"//tensorflow/models/research/object_detection/builders:model_builder",
"//tensorflow/models/research/object_detection/builders:optimizer_builder",
"//tensorflow/models/research/object_detection/metrics:coco_evaluation",
"//tensorflow/models/research/object_detection/utils:config_util",
"//tensorflow/models/research/object_detection/utils:label_map_util",
"//tensorflow/models/research/object_detection/utils:ops",
"//tensorflow/models/research/object_detection/utils:shape_utils",
"//tensorflow/models/research/object_detection/utils:variables_helper",
"//tensorflow/models/research/object_detection/utils:visualization_utils",
],
)
py_library(
name = "model_hparams",
srcs = [
"model_hparams.py",
],
deps = [
"//tensorflow",
],
)
py_test(
name = "model_test",
timeout = "long",
srcs = [
"model_test.py",
],
data = [
"//tensorflow/models/research/object_detection/data:pet_label_map.pbtxt",
"//tensorflow/models/research/object_detection/samples/configs:faster_rcnn_resnet50_pets.config",
"//tensorflow/models/research/object_detection/samples/configs:ssd_inception_v2_pets.config",
"//tensorflow/models/research/object_detection/test_data:pets_examples.record",
],
deps = [
":inputs",
":model",
":model_hparams",
":model_test_util",
"//mock",
"//tensorflow",
"//tensorflow/models/research/object_detection/core:standard_fields",
"//tensorflow/models/research/object_detection/data_decoders:tf_example_decoder",
"//tensorflow/models/research/object_detection/utils:config_util",
"//tensorflow/models/research/object_detection/utils:ops",
],
)
MODEL_TPU_DEPS = [
":inputs",
":model",
":model_hparams",
"//tensorflow",
"//tensorflow/models/research/object_detection:eval_util",
"//tensorflow/models/research/object_detection/builders:model_builder",
"//tensorflow/models/research/object_detection/builders:optimizer_builder",
"//tensorflow/models/research/object_detection/metrics:coco_evaluation",
"//tensorflow/models/research/object_detection/utils:config_util",
"//tensorflow/models/research/object_detection/utils:label_map_util",
"//tensorflow/models/research/object_detection/utils:ops",
"//tensorflow/models/research/object_detection/utils:variables_helper",
"//tensorflow/models/research/object_detection/utils:visualization_utils",
]
cloud_tpu_py_binaries(
name = "model_tpu",
srcs = [
"model_tpu.py",
],
main = "model_tpu.py",
deps = MODEL_TPU_DEPS,
)
py_library(
name = "model_tpu_lib",
srcs = [
"model_tpu.py",
],
deps = MODEL_TPU_DEPS,
)
py_library(
name = "model_test_util",
srcs = [
"model_test_util.py",
],
deps = [
":model",
":model_hparams",
"//tensorflow",
],
)
py_binary( py_binary(
name = "train", name = "train",
srcs = [ srcs = [
...@@ -113,6 +223,7 @@ py_library( ...@@ -113,6 +223,7 @@ py_library(
"//tensorflow/models/research/object_detection:eval_util", "//tensorflow/models/research/object_detection:eval_util",
"//tensorflow/models/research/object_detection/core:prefetcher", "//tensorflow/models/research/object_detection/core:prefetcher",
"//tensorflow/models/research/object_detection/core:standard_fields", "//tensorflow/models/research/object_detection/core:standard_fields",
"//tensorflow/models/research/object_detection/metrics:coco_evaluation",
"//tensorflow/models/research/object_detection/protos:eval_py_pb2", "//tensorflow/models/research/object_detection/protos:eval_py_pb2",
"//tensorflow/models/research/object_detection/utils:object_detection_evaluation", "//tensorflow/models/research/object_detection/utils:object_detection_evaluation",
], ],
......
...@@ -69,6 +69,8 @@ Extras: ...@@ -69,6 +69,8 @@ Extras:
Supported object detection evaluation protocols</a><br> Supported object detection evaluation protocols</a><br>
* <a href='g3doc/oid_inference_and_evaluation.md'> * <a href='g3doc/oid_inference_and_evaluation.md'>
Inference and evaluation on the Open Images dataset</a><br> Inference and evaluation on the Open Images dataset</a><br>
* <a href='g3doc/instance_segmentation.md'>
Run an instance segmentation model</a><br>
## Getting Help ## Getting Help
...@@ -77,7 +79,7 @@ API, create a new question on [StackOverflow](https://stackoverflow.com/) with ...@@ -77,7 +79,7 @@ API, create a new question on [StackOverflow](https://stackoverflow.com/) with
the tags "tensorflow" and "object-detection". the tags "tensorflow" and "object-detection".
Please report bugs (actually broken code, not usage questions) to the Please report bugs (actually broken code, not usage questions) to the
tensorflow/models Github tensorflow/models GitHub
[issue tracker](https://github.com/tensorflow/models/issues), prefixing the [issue tracker](https://github.com/tensorflow/models/issues), prefixing the
issue name with "object_detection". issue name with "object_detection".
...@@ -85,6 +87,15 @@ issue name with "object_detection". ...@@ -85,6 +87,15 @@ issue name with "object_detection".
## Release information ## Release information
### February 9, 2018
We now support instance segmentation!! In this API update we support a number of instance segmentation models similar to those discussed in the [Mask R-CNN paper](https://arxiv.org/abs/1703.06870). For further details refer to
[our slides](http://presentations.cocodataset.org/Places17-GMRI.pdf) from the 2017 COCO + Places Workshop.
Refer to the section on [Running an Instance Segmentation Model](g3doc/instance_segmentation.md) for instructions on how to configure a model
that predicts masks in addition to object bounding boxes.
<b>Thanks to contributors</b>: Alireza Fathi, Zhichao Lu, Vivek Rathod, Ronny Votel, Jonathan Huang
### November 17, 2017 ### November 17, 2017
As a part of the Open Images V3 release we have released: As a part of the Open Images V3 release we have released:
......
...@@ -16,8 +16,8 @@ ...@@ -16,8 +16,8 @@
Generates grid anchors on the fly corresponding to multiple CNN layers as Generates grid anchors on the fly corresponding to multiple CNN layers as
described in: described in:
"Focal Loss for Dense Object Detection" "Focal Loss for Dense Object Detection" (https://arxiv.org/abs/1708.02002)
T.-Y. Lin, P. Goyal, R. Girshick, K. He, P. Dollar (https://arxiv.org/abs/1708.02002) T.-Y. Lin, P. Goyal, R. Girshick, K. He, P. Dollar
""" """
from object_detection.anchor_generators import grid_anchor_generator from object_detection.anchor_generators import grid_anchor_generator
...@@ -77,11 +77,15 @@ class MultiscaleGridAnchorGenerator(object): ...@@ -77,11 +77,15 @@ class MultiscaleGridAnchorGenerator(object):
a list of integers, one for each expected feature map to be passed to a list of integers, one for each expected feature map to be passed to
the Generate function. the Generate function.
""" """
return self._aspect_ratios * self._scales_per_octave return len(self._anchor_grid_info) * [
len(self._aspect_ratios) * self._scales_per_octave]
def generate(self, feature_map_shape_list, im_height, im_width): def generate(self, feature_map_shape_list, im_height, im_width):
"""Generates a collection of bounding boxes to be used as anchors. """Generates a collection of bounding boxes to be used as anchors.
Currently we require the input image shape to be statically defined. That
is, im_height and im_width should be integers rather than tensors.
Args: Args:
feature_map_shape_list: list of pairs of convnet layer resolutions in the feature_map_shape_list: list of pairs of convnet layer resolutions in the
format [(height_0, width_0), (height_1, width_1), ...]. For example, format [(height_0, width_0), (height_1, width_1), ...]. For example,
...@@ -92,8 +96,12 @@ class MultiscaleGridAnchorGenerator(object): ...@@ -92,8 +96,12 @@ class MultiscaleGridAnchorGenerator(object):
Returns: Returns:
boxes: a BoxList holding a collection of N anchor boxes boxes: a BoxList holding a collection of N anchor boxes
Raises:
ValueError: if im_height and im_width are not integers.
""" """
if not isinstance(im_height, int) or not isinstance(im_width, int):
raise ValueError('MultiscaleGridAnchorGenerator currently requires '
'input image shape to be statically defined.')
anchor_grid_list = [] anchor_grid_list = []
for feat_shape, grid_info in zip(feature_map_shape_list, for feat_shape, grid_info in zip(feature_map_shape_list,
self._anchor_grid_info): self._anchor_grid_info):
......
...@@ -46,6 +46,30 @@ class MultiscaleGridAnchorGeneratorTest(test_case.TestCase): ...@@ -46,6 +46,30 @@ class MultiscaleGridAnchorGeneratorTest(test_case.TestCase):
anchor_corners_out = anchor_corners.eval() anchor_corners_out = anchor_corners.eval()
self.assertAllClose(anchor_corners_out, exp_anchor_corners) self.assertAllClose(anchor_corners_out, exp_anchor_corners)
def test_num_anchors_per_location(self):
  """Checks that num_anchors_per_location reports one count per level.

  Each entry is expected to be len(aspect_ratios) * scales_per_octave.
  """
  generator = mg.MultiscaleGridAnchorGenerator(
      5,           # min_level
      6,           # max_level
      4.0,         # anchor_scale
      [1.0, 2.0],  # aspect_ratios
      3)           # scales_per_octave
  # 2 aspect ratios * 3 scales per octave = 6 anchors, once for each of the
  # two entries expected for levels 5 and 6.
  self.assertEqual(generator.num_anchors_per_location(), [6, 6])
def test_construct_single_anchor_fails_with_tensor_image_size(self):
  """Checks that generate() rejects image dimensions given as tensors.

  The generator requires im_height and im_width to be Python integers and
  raises ValueError otherwise.
  """
  generator = mg.MultiscaleGridAnchorGenerator(
      5,      # min_level
      5,      # max_level
      4.0,    # anchor_scale
      [1.0],  # aspect_ratios
      1)      # scales_per_octave
  # Tensors, not ints: this is the condition under test.
  tensor_height = tf.constant(64)
  tensor_width = tf.constant(64)
  with self.assertRaises(ValueError):
    generator.generate([(2, 2)], tensor_height, tensor_width)
def test_construct_single_anchor_with_odd_input_dimension(self): def test_construct_single_anchor_with_odd_input_dimension(self):
def graph_fn(): def graph_fn():
......
...@@ -32,6 +32,7 @@ py_library( ...@@ -32,6 +32,7 @@ py_library(
"//tensorflow/models/research/object_detection/models:ssd_inception_v2_feature_extractor", "//tensorflow/models/research/object_detection/models:ssd_inception_v2_feature_extractor",
"//tensorflow/models/research/object_detection/models:ssd_inception_v3_feature_extractor", "//tensorflow/models/research/object_detection/models:ssd_inception_v3_feature_extractor",
"//tensorflow/models/research/object_detection/models:ssd_mobilenet_v1_feature_extractor", "//tensorflow/models/research/object_detection/models:ssd_mobilenet_v1_feature_extractor",
"//tensorflow/models/research/object_detection/models:ssd_resnet_v1_fpn_feature_extractor",
"//tensorflow/models/research/object_detection/protos:model_py_pb2", "//tensorflow/models/research/object_detection/protos:model_py_pb2",
], ],
) )
...@@ -44,6 +45,7 @@ py_test( ...@@ -44,6 +45,7 @@ py_test(
"//tensorflow", "//tensorflow",
"//tensorflow/models/research/object_detection/meta_architectures:faster_rcnn_meta_arch", "//tensorflow/models/research/object_detection/meta_architectures:faster_rcnn_meta_arch",
"//tensorflow/models/research/object_detection/meta_architectures:ssd_meta_arch", "//tensorflow/models/research/object_detection/meta_architectures:ssd_meta_arch",
"//tensorflow/models/research/object_detection/models:embedded_ssd_mobilenet_v1_feature_extractor",
"//tensorflow/models/research/object_detection/models:faster_rcnn_inception_resnet_v2_feature_extractor", "//tensorflow/models/research/object_detection/models:faster_rcnn_inception_resnet_v2_feature_extractor",
"//tensorflow/models/research/object_detection/models:faster_rcnn_inception_v2_feature_extractor", "//tensorflow/models/research/object_detection/models:faster_rcnn_inception_v2_feature_extractor",
"//tensorflow/models/research/object_detection/models:faster_rcnn_nas_feature_extractor", "//tensorflow/models/research/object_detection/models:faster_rcnn_nas_feature_extractor",
...@@ -51,6 +53,7 @@ py_test( ...@@ -51,6 +53,7 @@ py_test(
"//tensorflow/models/research/object_detection/models:ssd_inception_v2_feature_extractor", "//tensorflow/models/research/object_detection/models:ssd_inception_v2_feature_extractor",
"//tensorflow/models/research/object_detection/models:ssd_inception_v3_feature_extractor", "//tensorflow/models/research/object_detection/models:ssd_inception_v3_feature_extractor",
"//tensorflow/models/research/object_detection/models:ssd_mobilenet_v1_feature_extractor", "//tensorflow/models/research/object_detection/models:ssd_mobilenet_v1_feature_extractor",
"//tensorflow/models/research/object_detection/models:ssd_resnet_v1_fpn_feature_extractor",
"//tensorflow/models/research/object_detection/protos:model_py_pb2", "//tensorflow/models/research/object_detection/protos:model_py_pb2",
], ],
) )
......
...@@ -83,10 +83,10 @@ def build(anchor_generator_config): ...@@ -83,10 +83,10 @@ def build(anchor_generator_config):
'anchor_generator_oneof') == 'multiscale_anchor_generator': 'anchor_generator_oneof') == 'multiscale_anchor_generator':
cfg = anchor_generator_config.multiscale_anchor_generator cfg = anchor_generator_config.multiscale_anchor_generator
return multiscale_grid_anchor_generator.MultiscaleGridAnchorGenerator( return multiscale_grid_anchor_generator.MultiscaleGridAnchorGenerator(
cfg.min_lvl, cfg.min_level,
cfg.max_lvl, cfg.max_level,
cfg.anchor_scale, cfg.anchor_scale,
cfg.aspect_ratios, [float(aspect_ratio) for aspect_ratio in cfg.aspect_ratios],
cfg.scales_per_octave cfg.scales_per_octave
) )
else: else:
......
...@@ -22,6 +22,7 @@ import tensorflow as tf ...@@ -22,6 +22,7 @@ import tensorflow as tf
from google.protobuf import text_format from google.protobuf import text_format
from object_detection.anchor_generators import grid_anchor_generator from object_detection.anchor_generators import grid_anchor_generator
from object_detection.anchor_generators import multiple_grid_anchor_generator from object_detection.anchor_generators import multiple_grid_anchor_generator
from object_detection.anchor_generators import multiscale_grid_anchor_generator
from object_detection.builders import anchor_generator_builder from object_detection.builders import anchor_generator_builder
from object_detection.protos import anchor_generator_pb2 from object_detection.protos import anchor_generator_pb2
...@@ -252,6 +253,31 @@ class AnchorGeneratorBuilderTest(tf.test.TestCase): ...@@ -252,6 +253,31 @@ class AnchorGeneratorBuilderTest(tf.test.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
anchor_generator_builder.build(anchor_generator_proto) anchor_generator_builder.build(anchor_generator_proto)
def test_build_multiscale_anchor_generator_custom_aspect_ratios(self):
  """Builds a multiscale generator with a custom aspect ratio list.

  Only `aspect_ratios` is set in the proto; every other field keeps its
  default. The test then inspects the per-level grid info stored on the
  built generator.
  """
  anchor_generator_text_proto = """
    multiscale_anchor_generator {
      aspect_ratios: [1.0]
    }
  """
  anchor_generator_proto = anchor_generator_pb2.AnchorGenerator()
  text_format.Merge(anchor_generator_text_proto, anchor_generator_proto)
  anchor_generator_object = anchor_generator_builder.build(
      anchor_generator_proto)
  self.assertIsInstance(
      anchor_generator_object,
      multiscale_grid_anchor_generator.MultiscaleGridAnchorGenerator)
  # Levels 3..7 are the ones this test expects from the proto defaults.
  for level, anchor_grid_info in zip(
      range(3, 8), anchor_generator_object._anchor_grid_info):
    self.assertEqual(set(anchor_grid_info.keys()), set(['level', 'info']))
    # assertTrue(a, b) would treat `b` as the failure message and never
    # compare the two values, so use an explicit equality check.
    self.assertEqual(level, anchor_grid_info['level'])
    self.assertEqual(len(anchor_grid_info['info']), 4)
    self.assertAllClose(anchor_grid_info['info'][0], [2**0, 2**0.5])
    # NOTE(review): assumes info[1] holds the aspect ratio list from the
    # proto ([1.0]) -- confirm against MultiscaleGridAnchorGenerator.
    self.assertAllClose(anchor_grid_info['info'][1], [1.0])
    self.assertAllClose(anchor_grid_info['info'][2],
                        [4.0 * 2**level, 4.0 * 2**level])
    self.assertAllClose(anchor_grid_info['info'][3], [2**level, 2**level])
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -64,7 +64,9 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes): ...@@ -64,7 +64,9 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes):
kernel_size=conv_box_predictor.kernel_size, kernel_size=conv_box_predictor.kernel_size,
box_code_size=conv_box_predictor.box_code_size, box_code_size=conv_box_predictor.box_code_size,
apply_sigmoid_to_scores=conv_box_predictor.apply_sigmoid_to_scores, apply_sigmoid_to_scores=conv_box_predictor.apply_sigmoid_to_scores,
class_prediction_bias_init=conv_box_predictor.class_prediction_bias_init class_prediction_bias_init=(conv_box_predictor.
class_prediction_bias_init),
use_depthwise=conv_box_predictor.use_depthwise
) )
return box_predictor_object return box_predictor_object
......
...@@ -83,6 +83,7 @@ class ConvolutionalBoxPredictorBuilderTest(tf.test.TestCase): ...@@ -83,6 +83,7 @@ class ConvolutionalBoxPredictorBuilderTest(tf.test.TestCase):
box_code_size: 3 box_code_size: 3
apply_sigmoid_to_scores: true apply_sigmoid_to_scores: true
class_prediction_bias_init: 4.0 class_prediction_bias_init: 4.0
use_depthwise: true
} }
""" """
conv_hyperparams_text_proto = """ conv_hyperparams_text_proto = """
...@@ -118,6 +119,7 @@ class ConvolutionalBoxPredictorBuilderTest(tf.test.TestCase): ...@@ -118,6 +119,7 @@ class ConvolutionalBoxPredictorBuilderTest(tf.test.TestCase):
self.assertAlmostEqual(box_predictor._class_prediction_bias_init, 4.0) self.assertAlmostEqual(box_predictor._class_prediction_bias_init, 4.0)
self.assertEqual(box_predictor.num_classes, 10) self.assertEqual(box_predictor.num_classes, 10)
self.assertFalse(box_predictor._is_training) self.assertFalse(box_predictor._is_training)
self.assertTrue(box_predictor._use_depthwise)
def test_construct_default_conv_box_predictor(self): def test_construct_default_conv_box_predictor(self):
box_predictor_text_proto = """ box_predictor_text_proto = """
...@@ -148,6 +150,7 @@ class ConvolutionalBoxPredictorBuilderTest(tf.test.TestCase): ...@@ -148,6 +150,7 @@ class ConvolutionalBoxPredictorBuilderTest(tf.test.TestCase):
self.assertFalse(box_predictor._apply_sigmoid_to_scores) self.assertFalse(box_predictor._apply_sigmoid_to_scores)
self.assertEqual(box_predictor.num_classes, 90) self.assertEqual(box_predictor.num_classes, 90)
self.assertTrue(box_predictor._is_training) self.assertTrue(box_predictor._is_training)
self.assertFalse(box_predictor._use_depthwise)
class WeightSharedConvolutionalBoxPredictorBuilderTest(tf.test.TestCase): class WeightSharedConvolutionalBoxPredictorBuilderTest(tf.test.TestCase):
......
...@@ -24,18 +24,93 @@ that wraps the build function. ...@@ -24,18 +24,93 @@ that wraps the build function.
import tensorflow as tf import tensorflow as tf
from object_detection.core import standard_fields as fields
from object_detection.data_decoders import tf_example_decoder from object_detection.data_decoders import tf_example_decoder
from object_detection.protos import input_reader_pb2 from object_detection.protos import input_reader_pb2
from object_detection.utils import dataset_util from object_detection.utils import dataset_util
def build(input_reader_config, num_workers=1, worker_index=0): def _get_padding_shapes(dataset, max_num_boxes, num_classes,
"""Builds a tf.data.Dataset based on the InputReader config. spatial_image_shape):
"""Returns shapes to pad dataset tensors to before batching.
Args:
dataset: tf.data.Dataset object.
max_num_boxes: Max number of groundtruth boxes needed to compute shapes for
padding.
num_classes: Number of classes in the dataset needed to compute shapes for
padding.
spatial_image_shape: A list of two integers of the form [height, width]
containing expected spatial shape of the image.
Returns:
A dictionary keyed by fields.InputDataFields containing padding shapes for
tensors in the dataset.
"""
height, width = spatial_image_shape
padding_shapes = {
fields.InputDataFields.image: [height, width, 3],
fields.InputDataFields.source_id: [],
fields.InputDataFields.filename: [],
fields.InputDataFields.key: [],
fields.InputDataFields.groundtruth_difficult: [max_num_boxes],
fields.InputDataFields.groundtruth_boxes: [max_num_boxes, 4],
fields.InputDataFields.groundtruth_classes: [
max_num_boxes, num_classes
],
fields.InputDataFields.groundtruth_instance_masks: [max_num_boxes, height,
width],
fields.InputDataFields.groundtruth_is_crowd: [max_num_boxes],
fields.InputDataFields.groundtruth_group_of: [max_num_boxes],
fields.InputDataFields.groundtruth_area: [max_num_boxes],
fields.InputDataFields.groundtruth_weights: [max_num_boxes],
fields.InputDataFields.num_groundtruth_boxes: [],
fields.InputDataFields.groundtruth_label_types: [max_num_boxes],
fields.InputDataFields.groundtruth_label_scores: [max_num_boxes],
fields.InputDataFields.true_image_shape: [3]
}
if fields.InputDataFields.groundtruth_keypoints in dataset.output_shapes:
tensor_shape = dataset.output_shapes[fields.InputDataFields.
groundtruth_keypoints]
padding_shape = [max_num_boxes, tensor_shape[1].value,
tensor_shape[2].value]
padding_shapes[fields.InputDataFields.groundtruth_keypoints] = padding_shape
if (fields.InputDataFields.groundtruth_keypoint_visibilities
in dataset.output_shapes):
tensor_shape = dataset.output_shapes[fields.InputDataFields.
groundtruth_keypoint_visibilities]
padding_shape = [max_num_boxes, tensor_shape[1].value]
padding_shapes[fields.InputDataFields.
groundtruth_keypoint_visibilities] = padding_shape
return {tensor_key: padding_shapes[tensor_key]
for tensor_key, _ in dataset.output_shapes.items()}
def build(input_reader_config, transform_input_data_fn=None, num_workers=1,
worker_index=0, batch_size=1, max_num_boxes=None, num_classes=None,
spatial_image_shape=None):
"""Builds a tf.data.Dataset.
Builds a tf.data.Dataset by applying the `transform_input_data_fn` on all
records. Optionally, if `batch_size` > 1 and `max_num_boxes`, `num_classes`
and `spatial_image_shape` are not None, returns a padded batched
tf.data.Dataset.
Args: Args:
input_reader_config: A input_reader_pb2.InputReader object. input_reader_config: A input_reader_pb2.InputReader object.
num_workers: Number of workers / shards. transform_input_data_fn: Function to apply to all records, or None if
worker_index: Id for the current worker. no extra decoding is required.
num_workers: Number of workers (tpu shard).
worker_index: Id for the current worker (tpu shard).
batch_size: Batch size. If greater than 1, returns a padded batch dataset.
max_num_boxes: Max number of groundtruth boxes needed to compute shapes for
padding. This is only used if batch_size is greater than 1.
num_classes: Number of classes in the dataset needed to compute shapes for
padding. This is only used if batch_size is greater than 1.
spatial_image_shape: a list of two integers of the form [height, width]
containing expected spatial shape of the image after applying
transform_input_data_fn. This is needed to compute shapes for padding and
only used if batch_size is greater than 1.
Returns: Returns:
A tf.data.Dataset based on the input_reader_config. A tf.data.Dataset based on the input_reader_config.
...@@ -43,6 +118,8 @@ def build(input_reader_config, num_workers=1, worker_index=0): ...@@ -43,6 +118,8 @@ def build(input_reader_config, num_workers=1, worker_index=0):
Raises: Raises:
ValueError: On invalid input reader proto. ValueError: On invalid input reader proto.
ValueError: If no input paths are specified. ValueError: If no input paths are specified.
ValueError: If batch_size > 1 and any of (max_num_boxes, num_classes,
spatial_image_shape) is None.
""" """
if not isinstance(input_reader_config, input_reader_pb2.InputReader): if not isinstance(input_reader_config, input_reader_pb2.InputReader):
raise ValueError('input_reader_config not of type ' raise ValueError('input_reader_config not of type '
...@@ -62,8 +139,29 @@ def build(input_reader_config, num_workers=1, worker_index=0): ...@@ -62,8 +139,29 @@ def build(input_reader_config, num_workers=1, worker_index=0):
instance_mask_type=input_reader_config.mask_type, instance_mask_type=input_reader_config.mask_type,
label_map_proto_file=label_map_proto_file) label_map_proto_file=label_map_proto_file)
return dataset_util.read_dataset( def process_fn(value):
tf.data.TFRecordDataset, decoder.decode, config.input_path[:], processed = decoder.decode(value)
if transform_input_data_fn is not None:
return transform_input_data_fn(processed)
return processed
dataset = dataset_util.read_dataset(
tf.data.TFRecordDataset, process_fn, config.input_path[:],
input_reader_config, num_workers, worker_index) input_reader_config, num_workers, worker_index)
if batch_size > 1:
if num_classes is None:
raise ValueError('`num_classes` must be set when batch_size > 1.')
if max_num_boxes is None:
raise ValueError('`max_num_boxes` must be set when batch_size > 1.')
if spatial_image_shape is None:
raise ValueError('`spatial_image_shape` must be set when batch_size > '
'1 .')
padding_shapes = _get_padding_shapes(dataset, max_num_boxes, num_classes,
spatial_image_shape)
dataset = dataset.apply(
tf.contrib.data.padded_batch_and_drop_remainder(batch_size,
padding_shapes))
return dataset
raise ValueError('Unsupported input_reader_config.') raise ValueError('Unsupported input_reader_config.')
...@@ -130,18 +130,92 @@ class DatasetBuilderTest(tf.test.TestCase): ...@@ -130,18 +130,92 @@ class DatasetBuilderTest(tf.test.TestCase):
with sv.prepare_or_wait_for_session() as sess: with sv.prepare_or_wait_for_session() as sess:
sv.start_queue_runners(sess) sv.start_queue_runners(sess)
output_dict = sess.run(tensor_dict) output_dict = sess.run(tensor_dict)
self.assertAllEqual(
(1, 4, 5),
output_dict[fields.InputDataFields.groundtruth_instance_masks].shape)
self.assertEquals((4, 5, 3), def test_build_tf_record_input_reader_with_batch_size_two(self):
output_dict[fields.InputDataFields.image].shape) tf_record_path = self.create_tf_record()
self.assertEquals([2],
output_dict[fields.InputDataFields.groundtruth_classes]) input_reader_text_proto = """
self.assertEquals( shuffle: false
(1, 4), output_dict[fields.InputDataFields.groundtruth_boxes].shape) num_readers: 1
tf_record_input_reader {{
input_path: '{0}'
}}
""".format(tf_record_path)
input_reader_proto = input_reader_pb2.InputReader()
text_format.Merge(input_reader_text_proto, input_reader_proto)
def one_hot_class_encoding_fn(tensor_dict):
tensor_dict[fields.InputDataFields.groundtruth_classes] = tf.one_hot(
tensor_dict[fields.InputDataFields.groundtruth_classes] - 1, depth=3)
return tensor_dict
tensor_dict = dataset_util.make_initializable_iterator(
dataset_builder.build(
input_reader_proto,
transform_input_data_fn=one_hot_class_encoding_fn,
batch_size=2,
max_num_boxes=2,
num_classes=3,
spatial_image_shape=[4, 5])).get_next()
sv = tf.train.Supervisor(logdir=self.get_temp_dir())
with sv.prepare_or_wait_for_session() as sess:
sv.start_queue_runners(sess)
output_dict = sess.run(tensor_dict)
self.assertAllEqual([2, 4, 5, 3],
output_dict[fields.InputDataFields.image].shape)
self.assertAllEqual([2, 2, 3],
output_dict[fields.InputDataFields.groundtruth_classes].
shape)
self.assertAllEqual([2, 2, 4],
output_dict[fields.InputDataFields.groundtruth_boxes].
shape)
self.assertAllEqual( self.assertAllEqual(
[0.0, 0.0, 1.0, 1.0], [[[0.0, 0.0, 1.0, 1.0],
output_dict[fields.InputDataFields.groundtruth_boxes][0]) [0.0, 0.0, 0.0, 0.0]],
[[0.0, 0.0, 1.0, 1.0],
[0.0, 0.0, 0.0, 0.0]]],
output_dict[fields.InputDataFields.groundtruth_boxes])
def test_build_tf_record_input_reader_with_batch_size_two_and_masks(self):
tf_record_path = self.create_tf_record()
input_reader_text_proto = """
shuffle: false
num_readers: 1
load_instance_masks: true
tf_record_input_reader {{
input_path: '{0}'
}}
""".format(tf_record_path)
input_reader_proto = input_reader_pb2.InputReader()
text_format.Merge(input_reader_text_proto, input_reader_proto)
def one_hot_class_encoding_fn(tensor_dict):
tensor_dict[fields.InputDataFields.groundtruth_classes] = tf.one_hot(
tensor_dict[fields.InputDataFields.groundtruth_classes] - 1, depth=3)
return tensor_dict
tensor_dict = dataset_util.make_initializable_iterator(
dataset_builder.build(
input_reader_proto,
transform_input_data_fn=one_hot_class_encoding_fn,
batch_size=2,
max_num_boxes=2,
num_classes=3,
spatial_image_shape=[4, 5])).get_next()
sv = tf.train.Supervisor(logdir=self.get_temp_dir())
with sv.prepare_or_wait_for_session() as sess:
sv.start_queue_runners(sess)
output_dict = sess.run(tensor_dict)
self.assertAllEqual( self.assertAllEqual(
(1, 4, 5), [2, 2, 4, 5],
output_dict[fields.InputDataFields.groundtruth_instance_masks].shape) output_dict[fields.InputDataFields.groundtruth_instance_masks].shape)
def test_raises_error_with_no_input_paths(self): def test_raises_error_with_no_input_paths(self):
......
...@@ -79,19 +79,32 @@ def build(image_resizer_config): ...@@ -79,19 +79,32 @@ def build(image_resizer_config):
keep_aspect_ratio_config.max_dimension): keep_aspect_ratio_config.max_dimension):
raise ValueError('min_dimension > max_dimension') raise ValueError('min_dimension > max_dimension')
method = _tf_resize_method(keep_aspect_ratio_config.resize_method) method = _tf_resize_method(keep_aspect_ratio_config.resize_method)
return functools.partial( image_resizer_fn = functools.partial(
preprocessor.resize_to_range, preprocessor.resize_to_range,
min_dimension=keep_aspect_ratio_config.min_dimension, min_dimension=keep_aspect_ratio_config.min_dimension,
max_dimension=keep_aspect_ratio_config.max_dimension, max_dimension=keep_aspect_ratio_config.max_dimension,
method=method, method=method,
pad_to_max_dimension=keep_aspect_ratio_config.pad_to_max_dimension) pad_to_max_dimension=keep_aspect_ratio_config.pad_to_max_dimension)
if image_resizer_config.WhichOneof( if not keep_aspect_ratio_config.convert_to_grayscale:
return image_resizer_fn
elif image_resizer_config.WhichOneof(
'image_resizer_oneof') == 'fixed_shape_resizer': 'image_resizer_oneof') == 'fixed_shape_resizer':
fixed_shape_resizer_config = image_resizer_config.fixed_shape_resizer fixed_shape_resizer_config = image_resizer_config.fixed_shape_resizer
method = _tf_resize_method(fixed_shape_resizer_config.resize_method) method = _tf_resize_method(fixed_shape_resizer_config.resize_method)
return functools.partial( image_resizer_fn = functools.partial(
preprocessor.resize_image, preprocessor.resize_image,
new_height=fixed_shape_resizer_config.height, new_height=fixed_shape_resizer_config.height,
new_width=fixed_shape_resizer_config.width, new_width=fixed_shape_resizer_config.width,
method=method) method=method)
raise ValueError('Invalid image resizer option.') if not fixed_shape_resizer_config.convert_to_grayscale:
return image_resizer_fn
else:
raise ValueError('Invalid image resizer option.')
def grayscale_image_resizer(image):
  """Resizes `image` with `image_resizer_fn`, then converts it to grayscale.

  Args:
    image: The image passed straight through to `image_resizer_fn`.

  Returns:
    A two-element list: the output of `preprocessor.rgb_to_gray` applied to
    the resized image, and the resized shape with its last entry replaced
    by 1.
  """
  resized, resized_shape = image_resizer_fn(image)
  gray = preprocessor.rgb_to_gray(resized)
  # Keep every leading dimension and swap the trailing one for 1.
  leading_dims = resized_shape[:-1]
  gray_shape = tf.concat([leading_dims, [1]], 0)
  return [gray, gray_shape]
return functools.partial(grayscale_image_resizer)
...@@ -45,7 +45,9 @@ def build(matcher_config): ...@@ -45,7 +45,9 @@ def build(matcher_config):
matched_threshold=matched_threshold, matched_threshold=matched_threshold,
unmatched_threshold=unmatched_threshold, unmatched_threshold=unmatched_threshold,
negatives_lower_than_unmatched=matcher.negatives_lower_than_unmatched, negatives_lower_than_unmatched=matcher.negatives_lower_than_unmatched,
force_match_for_each_row=matcher.force_match_for_each_row) force_match_for_each_row=matcher.force_match_for_each_row,
use_matmul_gather=matcher.use_matmul_gather)
if matcher_config.WhichOneof('matcher_oneof') == 'bipartite_matcher': if matcher_config.WhichOneof('matcher_oneof') == 'bipartite_matcher':
return bipartite_matcher.GreedyBipartiteMatcher() matcher = matcher_config.bipartite_matcher
return bipartite_matcher.GreedyBipartiteMatcher(matcher.use_matmul_gather)
raise ValueError('Empty matcher.') raise ValueError('Empty matcher.')
...@@ -62,6 +62,7 @@ class MatcherBuilderTest(tf.test.TestCase): ...@@ -62,6 +62,7 @@ class MatcherBuilderTest(tf.test.TestCase):
unmatched_threshold: 0.3 unmatched_threshold: 0.3
negatives_lower_than_unmatched: false negatives_lower_than_unmatched: false
force_match_for_each_row: true force_match_for_each_row: true
use_matmul_gather: true
} }
""" """
matcher_proto = matcher_pb2.Matcher() matcher_proto = matcher_pb2.Matcher()
...@@ -72,6 +73,7 @@ class MatcherBuilderTest(tf.test.TestCase): ...@@ -72,6 +73,7 @@ class MatcherBuilderTest(tf.test.TestCase):
self.assertAlmostEqual(matcher_object._unmatched_threshold, 0.3) self.assertAlmostEqual(matcher_object._unmatched_threshold, 0.3)
self.assertFalse(matcher_object._negatives_lower_than_unmatched) self.assertFalse(matcher_object._negatives_lower_than_unmatched)
self.assertTrue(matcher_object._force_match_for_each_row) self.assertTrue(matcher_object._force_match_for_each_row)
self.assertTrue(matcher_object._use_matmul_gather)
def test_build_bipartite_matcher(self): def test_build_bipartite_matcher(self):
matcher_text_proto = """ matcher_text_proto = """
......
...@@ -31,6 +31,7 @@ from object_detection.models import faster_rcnn_inception_resnet_v2_feature_extr ...@@ -31,6 +31,7 @@ from object_detection.models import faster_rcnn_inception_resnet_v2_feature_extr
from object_detection.models import faster_rcnn_inception_v2_feature_extractor as frcnn_inc_v2 from object_detection.models import faster_rcnn_inception_v2_feature_extractor as frcnn_inc_v2
from object_detection.models import faster_rcnn_nas_feature_extractor as frcnn_nas from object_detection.models import faster_rcnn_nas_feature_extractor as frcnn_nas
from object_detection.models import faster_rcnn_resnet_v1_feature_extractor as frcnn_resnet_v1 from object_detection.models import faster_rcnn_resnet_v1_feature_extractor as frcnn_resnet_v1
from object_detection.models import ssd_resnet_v1_fpn_feature_extractor as ssd_resnet_v1_fpn
from object_detection.models.embedded_ssd_mobilenet_v1_feature_extractor import EmbeddedSSDMobileNetV1FeatureExtractor from object_detection.models.embedded_ssd_mobilenet_v1_feature_extractor import EmbeddedSSDMobileNetV1FeatureExtractor
from object_detection.models.ssd_inception_v2_feature_extractor import SSDInceptionV2FeatureExtractor from object_detection.models.ssd_inception_v2_feature_extractor import SSDInceptionV2FeatureExtractor
from object_detection.models.ssd_inception_v3_feature_extractor import SSDInceptionV3FeatureExtractor from object_detection.models.ssd_inception_v3_feature_extractor import SSDInceptionV3FeatureExtractor
...@@ -42,6 +43,9 @@ SSD_FEATURE_EXTRACTOR_CLASS_MAP = { ...@@ -42,6 +43,9 @@ SSD_FEATURE_EXTRACTOR_CLASS_MAP = {
'ssd_inception_v2': SSDInceptionV2FeatureExtractor, 'ssd_inception_v2': SSDInceptionV2FeatureExtractor,
'ssd_inception_v3': SSDInceptionV3FeatureExtractor, 'ssd_inception_v3': SSDInceptionV3FeatureExtractor,
'ssd_mobilenet_v1': SSDMobileNetV1FeatureExtractor, 'ssd_mobilenet_v1': SSDMobileNetV1FeatureExtractor,
'ssd_resnet50_v1_fpn': ssd_resnet_v1_fpn.SSDResnet50V1FpnFeatureExtractor,
'ssd_resnet101_v1_fpn': ssd_resnet_v1_fpn.SSDResnet101V1FpnFeatureExtractor,
'ssd_resnet152_v1_fpn': ssd_resnet_v1_fpn.SSDResnet152V1FpnFeatureExtractor,
'embedded_ssd_mobilenet_v1': EmbeddedSSDMobileNetV1FeatureExtractor, 'embedded_ssd_mobilenet_v1': EmbeddedSSDMobileNetV1FeatureExtractor,
} }
...@@ -62,13 +66,14 @@ FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP = { ...@@ -62,13 +66,14 @@ FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP = {
} }
def build(model_config, is_training): def build(model_config, is_training, add_summaries=True):
"""Builds a DetectionModel based on the model config. """Builds a DetectionModel based on the model config.
Args: Args:
model_config: A model.proto object containing the config for the desired model_config: A model.proto object containing the config for the desired
DetectionModel. DetectionModel.
is_training: True if this model is being built for training purposes. is_training: True if this model is being built for training purposes.
add_summaries: Whether to add tensorflow summaries in the model graph.
Returns: Returns:
DetectionModel based on the config. DetectionModel based on the config.
...@@ -80,9 +85,10 @@ def build(model_config, is_training): ...@@ -80,9 +85,10 @@ def build(model_config, is_training):
raise ValueError('model_config not of type model_pb2.DetectionModel.') raise ValueError('model_config not of type model_pb2.DetectionModel.')
meta_architecture = model_config.WhichOneof('model') meta_architecture = model_config.WhichOneof('model')
if meta_architecture == 'ssd': if meta_architecture == 'ssd':
return _build_ssd_model(model_config.ssd, is_training) return _build_ssd_model(model_config.ssd, is_training, add_summaries)
if meta_architecture == 'faster_rcnn': if meta_architecture == 'faster_rcnn':
return _build_faster_rcnn_model(model_config.faster_rcnn, is_training) return _build_faster_rcnn_model(model_config.faster_rcnn, is_training,
add_summaries)
raise ValueError('Unknown meta architecture: {}'.format(meta_architecture)) raise ValueError('Unknown meta architecture: {}'.format(meta_architecture))
...@@ -107,6 +113,7 @@ def _build_ssd_feature_extractor(feature_extractor_config, is_training, ...@@ -107,6 +113,7 @@ def _build_ssd_feature_extractor(feature_extractor_config, is_training,
pad_to_multiple = feature_extractor_config.pad_to_multiple pad_to_multiple = feature_extractor_config.pad_to_multiple
batch_norm_trainable = feature_extractor_config.batch_norm_trainable batch_norm_trainable = feature_extractor_config.batch_norm_trainable
use_explicit_padding = feature_extractor_config.use_explicit_padding use_explicit_padding = feature_extractor_config.use_explicit_padding
use_depthwise = feature_extractor_config.use_depthwise
conv_hyperparams = hyperparams_builder.build( conv_hyperparams = hyperparams_builder.build(
feature_extractor_config.conv_hyperparams, is_training) feature_extractor_config.conv_hyperparams, is_training)
...@@ -117,16 +124,17 @@ def _build_ssd_feature_extractor(feature_extractor_config, is_training, ...@@ -117,16 +124,17 @@ def _build_ssd_feature_extractor(feature_extractor_config, is_training,
return feature_extractor_class(is_training, depth_multiplier, min_depth, return feature_extractor_class(is_training, depth_multiplier, min_depth,
pad_to_multiple, conv_hyperparams, pad_to_multiple, conv_hyperparams,
batch_norm_trainable, reuse_weights, batch_norm_trainable, reuse_weights,
use_explicit_padding) use_explicit_padding, use_depthwise)
def _build_ssd_model(ssd_config, is_training): def _build_ssd_model(ssd_config, is_training, add_summaries):
"""Builds an SSD detection model based on the model config. """Builds an SSD detection model based on the model config.
Args: Args:
ssd_config: A ssd.proto object containing the config for the desired ssd_config: A ssd.proto object containing the config for the desired
SSDMetaArch. SSDMetaArch.
is_training: True if this model is being built for training purposes. is_training: True if this model is being built for training purposes.
add_summaries: Whether to add tf summaries in the model.
Returns: Returns:
SSDMetaArch based on the config. SSDMetaArch based on the config.
...@@ -173,7 +181,8 @@ def _build_ssd_model(ssd_config, is_training): ...@@ -173,7 +181,8 @@ def _build_ssd_model(ssd_config, is_training):
classification_weight, classification_weight,
localization_weight, localization_weight,
normalize_loss_by_num_matches, normalize_loss_by_num_matches,
hard_example_miner) hard_example_miner,
add_summaries=add_summaries)
def _build_faster_rcnn_feature_extractor( def _build_faster_rcnn_feature_extractor(
...@@ -207,7 +216,7 @@ def _build_faster_rcnn_feature_extractor( ...@@ -207,7 +216,7 @@ def _build_faster_rcnn_feature_extractor(
batch_norm_trainable, reuse_weights) batch_norm_trainable, reuse_weights)
def _build_faster_rcnn_model(frcnn_config, is_training): def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries):
"""Builds a Faster R-CNN or R-FCN detection model based on the model config. """Builds a Faster R-CNN or R-FCN detection model based on the model config.
Builds R-FCN model if the second_stage_box_predictor in the config is of type Builds R-FCN model if the second_stage_box_predictor in the config is of type
...@@ -215,8 +224,9 @@ def _build_faster_rcnn_model(frcnn_config, is_training): ...@@ -215,8 +224,9 @@ def _build_faster_rcnn_model(frcnn_config, is_training):
Args: Args:
frcnn_config: A faster_rcnn.proto object containing the config for the frcnn_config: A faster_rcnn.proto object containing the config for the
desired FasterRCNNMetaArch or RFCNMetaArch. desired FasterRCNNMetaArch or RFCNMetaArch.
is_training: True if this model is being built for training purposes. is_training: True if this model is being built for training purposes.
add_summaries: Whether to add tf summaries in the model.
Returns: Returns:
FasterRCNNMetaArch based on the config. FasterRCNNMetaArch based on the config.
...@@ -312,7 +322,8 @@ def _build_faster_rcnn_model(frcnn_config, is_training): ...@@ -312,7 +322,8 @@ def _build_faster_rcnn_model(frcnn_config, is_training):
second_stage_classification_loss, second_stage_classification_loss,
'second_stage_classification_loss_weight': 'second_stage_classification_loss_weight':
second_stage_classification_loss_weight, second_stage_classification_loss_weight,
'hard_example_miner': hard_example_miner} 'hard_example_miner': hard_example_miner,
'add_summaries': add_summaries}
if isinstance(second_stage_box_predictor, box_predictor.RfcnBoxPredictor): if isinstance(second_stage_box_predictor, box_predictor.RfcnBoxPredictor):
return rfcn_meta_arch.RFCNMetaArch( return rfcn_meta_arch.RFCNMetaArch(
......
...@@ -26,12 +26,14 @@ from object_detection.models import faster_rcnn_inception_resnet_v2_feature_extr ...@@ -26,12 +26,14 @@ from object_detection.models import faster_rcnn_inception_resnet_v2_feature_extr
from object_detection.models import faster_rcnn_inception_v2_feature_extractor as frcnn_inc_v2 from object_detection.models import faster_rcnn_inception_v2_feature_extractor as frcnn_inc_v2
from object_detection.models import faster_rcnn_nas_feature_extractor as frcnn_nas from object_detection.models import faster_rcnn_nas_feature_extractor as frcnn_nas
from object_detection.models import faster_rcnn_resnet_v1_feature_extractor as frcnn_resnet_v1 from object_detection.models import faster_rcnn_resnet_v1_feature_extractor as frcnn_resnet_v1
from object_detection.models import ssd_resnet_v1_fpn_feature_extractor as ssd_resnet_v1_fpn
from object_detection.models.embedded_ssd_mobilenet_v1_feature_extractor import EmbeddedSSDMobileNetV1FeatureExtractor
from object_detection.models.ssd_inception_v2_feature_extractor import SSDInceptionV2FeatureExtractor from object_detection.models.ssd_inception_v2_feature_extractor import SSDInceptionV2FeatureExtractor
from object_detection.models.ssd_inception_v3_feature_extractor import SSDInceptionV3FeatureExtractor from object_detection.models.ssd_inception_v3_feature_extractor import SSDInceptionV3FeatureExtractor
from object_detection.models.ssd_mobilenet_v1_feature_extractor import SSDMobileNetV1FeatureExtractor from object_detection.models.ssd_mobilenet_v1_feature_extractor import SSDMobileNetV1FeatureExtractor
from object_detection.protos import model_pb2 from object_detection.protos import model_pb2
FEATURE_EXTRACTOR_MAPS = { FRCNN_RESNET_FEAT_MAPS = {
'faster_rcnn_resnet50': 'faster_rcnn_resnet50':
frcnn_resnet_v1.FasterRCNNResnet50FeatureExtractor, frcnn_resnet_v1.FasterRCNNResnet50FeatureExtractor,
'faster_rcnn_resnet101': 'faster_rcnn_resnet101':
...@@ -40,6 +42,15 @@ FEATURE_EXTRACTOR_MAPS = { ...@@ -40,6 +42,15 @@ FEATURE_EXTRACTOR_MAPS = {
frcnn_resnet_v1.FasterRCNNResnet152FeatureExtractor frcnn_resnet_v1.FasterRCNNResnet152FeatureExtractor
} }
# Feature extractor type string -> SSD ResNet v1 FPN feature extractor class.
SSD_RESNET_V1_FPN_FEAT_MAPS = {
    'ssd_resnet50_v1_fpn': (
        ssd_resnet_v1_fpn.SSDResnet50V1FpnFeatureExtractor),
    'ssd_resnet101_v1_fpn': (
        ssd_resnet_v1_fpn.SSDResnet101V1FpnFeatureExtractor),
    'ssd_resnet152_v1_fpn': (
        ssd_resnet_v1_fpn.SSDResnet152V1FpnFeatureExtractor),
}
class ModelBuilderTest(tf.test.TestCase): class ModelBuilderTest(tf.test.TestCase):
...@@ -197,6 +208,87 @@ class ModelBuilderTest(tf.test.TestCase): ...@@ -197,6 +208,87 @@ class ModelBuilderTest(tf.test.TestCase):
self.assertIsInstance(model._feature_extractor, self.assertIsInstance(model._feature_extractor,
SSDInceptionV3FeatureExtractor) SSDInceptionV3FeatureExtractor)
def test_create_ssd_resnet_v1_fpn_model_from_config(self):
  """Builds an SSD model for each ResNet v1 FPN feature extractor type.

  The text proto below is parsed once; its feature extractor type is then
  overwritten with every key of SSD_RESNET_V1_FPN_FEAT_MAPS, and the built
  model is checked to be an SSDMetaArch whose feature extractor is an
  instance of the mapped class.
  """
  ssd_config_text = """
ssd {
feature_extractor {
type: 'ssd_resnet50_v1_fpn'
conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
batch_norm_trainable: true
}
box_coder {
faster_rcnn_box_coder {
}
}
matcher {
argmax_matcher {
}
}
similarity_calculator {
iou_similarity {
}
}
anchor_generator {
multiscale_anchor_generator {
aspect_ratios: [1.0, 2.0, 0.5]
scales_per_octave: 2
}
}
image_resizer {
fixed_shape_resizer {
height: 320
width: 320
}
}
box_predictor {
weight_shared_convolutional_box_predictor {
depth: 32
conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
num_layers_before_predictor: 1
}
}
loss {
classification_loss {
weighted_sigmoid_focal {
alpha: 0.25
gamma: 2.0
}
}
localization_loss {
weighted_smooth_l1 {
}
}
classification_weight: 1.0
localization_weight: 1.0
}
}"""
  detection_model_proto = model_pb2.DetectionModel()
  text_format.Merge(ssd_config_text, detection_model_proto)

  for type_name, expected_class in SSD_RESNET_V1_FPN_FEAT_MAPS.items():
    # Reuse the parsed config, swapping in one extractor type per pass.
    detection_model_proto.ssd.feature_extractor.type = type_name
    built_model = model_builder.build(detection_model_proto, is_training=True)
    self.assertIsInstance(built_model, ssd_meta_arch.SSDMetaArch)
    self.assertIsInstance(built_model._feature_extractor, expected_class)
def test_create_ssd_mobilenet_v1_model_from_config(self): def test_create_ssd_mobilenet_v1_model_from_config(self):
model_text_proto = """ model_text_proto = """
ssd { ssd {
...@@ -270,6 +362,78 @@ class ModelBuilderTest(tf.test.TestCase): ...@@ -270,6 +362,78 @@ class ModelBuilderTest(tf.test.TestCase):
SSDMobileNetV1FeatureExtractor) SSDMobileNetV1FeatureExtractor)
self.assertTrue(model._feature_extractor._batch_norm_trainable) self.assertTrue(model._feature_extractor._batch_norm_trainable)
def test_create_embedded_ssd_mobilenet_v1_model_from_config(self):
  """Builds an SSD model with the embedded MobileNet v1 feature extractor.

  Parses the text proto below, builds the model through `self.create_model`
  and checks that the result is an SSDMetaArch whose feature extractor is
  an EmbeddedSSDMobileNetV1FeatureExtractor.
  """
  ssd_config_text = """
ssd {
feature_extractor {
type: 'embedded_ssd_mobilenet_v1'
conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
batch_norm_trainable: true
}
box_coder {
faster_rcnn_box_coder {
}
}
matcher {
argmax_matcher {
}
}
similarity_calculator {
iou_similarity {
}
}
anchor_generator {
ssd_anchor_generator {
aspect_ratios: 1.0
}
}
image_resizer {
fixed_shape_resizer {
height: 256
width: 256
}
}
box_predictor {
convolutional_box_predictor {
conv_hyperparams {
regularizer {
l2_regularizer {
}
}
initializer {
truncated_normal_initializer {
}
}
}
}
}
loss {
classification_loss {
weighted_softmax {
}
}
localization_loss {
weighted_smooth_l1 {
}
}
}
}"""
  detection_model_proto = model_pb2.DetectionModel()
  text_format.Merge(ssd_config_text, detection_model_proto)

  built_model = self.create_model(detection_model_proto)
  self.assertIsInstance(built_model, ssd_meta_arch.SSDMetaArch)
  self.assertIsInstance(built_model._feature_extractor,
                        EmbeddedSSDMobileNetV1FeatureExtractor)
def test_create_faster_rcnn_resnet_v1_models_from_config(self): def test_create_faster_rcnn_resnet_v1_models_from_config(self):
model_text_proto = """ model_text_proto = """
faster_rcnn { faster_rcnn {
...@@ -331,7 +495,7 @@ class ModelBuilderTest(tf.test.TestCase): ...@@ -331,7 +495,7 @@ class ModelBuilderTest(tf.test.TestCase):
}""" }"""
model_proto = model_pb2.DetectionModel() model_proto = model_pb2.DetectionModel()
text_format.Merge(model_text_proto, model_proto) text_format.Merge(model_text_proto, model_proto)
for extractor_type, extractor_class in FEATURE_EXTRACTOR_MAPS.items(): for extractor_type, extractor_class in FRCNN_RESNET_FEAT_MAPS.items():
model_proto.faster_rcnn.feature_extractor.type = extractor_type model_proto.faster_rcnn.feature_extractor.type = extractor_type
model = model_builder.build(model_proto, is_training=True) model = model_builder.build(model_proto, is_training=True)
self.assertIsInstance(model, faster_rcnn_meta_arch.FasterRCNNMetaArch) self.assertIsInstance(model, faster_rcnn_meta_arch.FasterRCNNMetaArch)
...@@ -730,7 +894,7 @@ class ModelBuilderTest(tf.test.TestCase): ...@@ -730,7 +894,7 @@ class ModelBuilderTest(tf.test.TestCase):
}""" }"""
model_proto = model_pb2.DetectionModel() model_proto = model_pb2.DetectionModel()
text_format.Merge(model_text_proto, model_proto) text_format.Merge(model_text_proto, model_proto)
for extractor_type, extractor_class in FEATURE_EXTRACTOR_MAPS.items(): for extractor_type, extractor_class in FRCNN_RESNET_FEAT_MAPS.items():
model_proto.faster_rcnn.feature_extractor.type = extractor_type model_proto.faster_rcnn.feature_extractor.type = extractor_type
model = model_builder.build(model_proto, is_training=True) model = model_builder.build(model_proto, is_training=True)
self.assertIsInstance(model, rfcn_meta_arch.RFCNMetaArch) self.assertIsInstance(model, rfcn_meta_arch.RFCNMetaArch)
......
...@@ -19,15 +19,14 @@ import tensorflow as tf ...@@ -19,15 +19,14 @@ import tensorflow as tf
from object_detection.utils import learning_schedules from object_detection.utils import learning_schedules
def build(optimizer_config, global_summaries): def build(optimizer_config):
"""Create optimizer based on config. """Create optimizer based on config.
Args: Args:
optimizer_config: A Optimizer proto message. optimizer_config: A Optimizer proto message.
global_summaries: A set to attach learning rate summary to.
Returns: Returns:
An optimizer. An optimizer and a list of variables for summary.
Raises: Raises:
ValueError: when using an unsupported input data type. ValueError: when using an unsupported input data type.
...@@ -35,24 +34,30 @@ def build(optimizer_config, global_summaries): ...@@ -35,24 +34,30 @@ def build(optimizer_config, global_summaries):
optimizer_type = optimizer_config.WhichOneof('optimizer') optimizer_type = optimizer_config.WhichOneof('optimizer')
optimizer = None optimizer = None
summary_vars = []
if optimizer_type == 'rms_prop_optimizer': if optimizer_type == 'rms_prop_optimizer':
config = optimizer_config.rms_prop_optimizer config = optimizer_config.rms_prop_optimizer
learning_rate = _create_learning_rate(config.learning_rate)
summary_vars.append(learning_rate)
optimizer = tf.train.RMSPropOptimizer( optimizer = tf.train.RMSPropOptimizer(
_create_learning_rate(config.learning_rate, global_summaries), learning_rate,
decay=config.decay, decay=config.decay,
momentum=config.momentum_optimizer_value, momentum=config.momentum_optimizer_value,
epsilon=config.epsilon) epsilon=config.epsilon)
if optimizer_type == 'momentum_optimizer': if optimizer_type == 'momentum_optimizer':
config = optimizer_config.momentum_optimizer config = optimizer_config.momentum_optimizer
learning_rate = _create_learning_rate(config.learning_rate)
summary_vars.append(learning_rate)
optimizer = tf.train.MomentumOptimizer( optimizer = tf.train.MomentumOptimizer(
_create_learning_rate(config.learning_rate, global_summaries), learning_rate,
momentum=config.momentum_optimizer_value) momentum=config.momentum_optimizer_value)
if optimizer_type == 'adam_optimizer': if optimizer_type == 'adam_optimizer':
config = optimizer_config.adam_optimizer config = optimizer_config.adam_optimizer
optimizer = tf.train.AdamOptimizer( learning_rate = _create_learning_rate(config.learning_rate)
_create_learning_rate(config.learning_rate, global_summaries)) summary_vars.append(learning_rate)
optimizer = tf.train.AdamOptimizer(learning_rate)
if optimizer is None: if optimizer is None:
raise ValueError('Optimizer %s not supported.' % optimizer_type) raise ValueError('Optimizer %s not supported.' % optimizer_type)
...@@ -61,15 +66,14 @@ def build(optimizer_config, global_summaries): ...@@ -61,15 +66,14 @@ def build(optimizer_config, global_summaries):
optimizer = tf.contrib.opt.MovingAverageOptimizer( optimizer = tf.contrib.opt.MovingAverageOptimizer(
optimizer, average_decay=optimizer_config.moving_average_decay) optimizer, average_decay=optimizer_config.moving_average_decay)
return optimizer return optimizer, summary_vars
def _create_learning_rate(learning_rate_config, global_summaries): def _create_learning_rate(learning_rate_config):
"""Create optimizer learning rate based on config. """Create optimizer learning rate based on config.
Args: Args:
learning_rate_config: A LearningRate proto message. learning_rate_config: A LearningRate proto message.
global_summaries: A set to attach learning rate summary to.
Returns: Returns:
A learning rate. A learning rate.
...@@ -81,7 +85,7 @@ def _create_learning_rate(learning_rate_config, global_summaries): ...@@ -81,7 +85,7 @@ def _create_learning_rate(learning_rate_config, global_summaries):
learning_rate_type = learning_rate_config.WhichOneof('learning_rate') learning_rate_type = learning_rate_config.WhichOneof('learning_rate')
if learning_rate_type == 'constant_learning_rate': if learning_rate_type == 'constant_learning_rate':
config = learning_rate_config.constant_learning_rate config = learning_rate_config.constant_learning_rate
learning_rate = config.learning_rate learning_rate = tf.constant(config.learning_rate, dtype=tf.float32)
if learning_rate_type == 'exponential_decay_learning_rate': if learning_rate_type == 'exponential_decay_learning_rate':
config = learning_rate_config.exponential_decay_learning_rate config = learning_rate_config.exponential_decay_learning_rate
...@@ -115,5 +119,4 @@ def _create_learning_rate(learning_rate_config, global_summaries): ...@@ -115,5 +119,4 @@ def _create_learning_rate(learning_rate_config, global_summaries):
if learning_rate is None: if learning_rate is None:
raise ValueError('Learning_rate %s not supported.' % learning_rate_type) raise ValueError('Learning_rate %s not supported.' % learning_rate_type)
global_summaries.add(tf.summary.scalar('Learning_Rate', learning_rate))
return learning_rate return learning_rate
...@@ -31,12 +31,13 @@ class LearningRateBuilderTest(tf.test.TestCase): ...@@ -31,12 +31,13 @@ class LearningRateBuilderTest(tf.test.TestCase):
learning_rate: 0.004 learning_rate: 0.004
} }
""" """
global_summaries = set([])
learning_rate_proto = optimizer_pb2.LearningRate() learning_rate_proto = optimizer_pb2.LearningRate()
text_format.Merge(learning_rate_text_proto, learning_rate_proto) text_format.Merge(learning_rate_text_proto, learning_rate_proto)
learning_rate = optimizer_builder._create_learning_rate( learning_rate = optimizer_builder._create_learning_rate(
learning_rate_proto, global_summaries) learning_rate_proto)
self.assertAlmostEqual(learning_rate, 0.004) with self.test_session():
learning_rate_out = learning_rate.eval()
self.assertAlmostEqual(learning_rate_out, 0.004)
def testBuildExponentialDecayLearningRate(self): def testBuildExponentialDecayLearningRate(self):
learning_rate_text_proto = """ learning_rate_text_proto = """
...@@ -47,11 +48,10 @@ class LearningRateBuilderTest(tf.test.TestCase): ...@@ -47,11 +48,10 @@ class LearningRateBuilderTest(tf.test.TestCase):
staircase: false staircase: false
} }
""" """
global_summaries = set([])
learning_rate_proto = optimizer_pb2.LearningRate() learning_rate_proto = optimizer_pb2.LearningRate()
text_format.Merge(learning_rate_text_proto, learning_rate_proto) text_format.Merge(learning_rate_text_proto, learning_rate_proto)
learning_rate = optimizer_builder._create_learning_rate( learning_rate = optimizer_builder._create_learning_rate(
learning_rate_proto, global_summaries) learning_rate_proto)
self.assertTrue(isinstance(learning_rate, tf.Tensor)) self.assertTrue(isinstance(learning_rate, tf.Tensor))
def testBuildManualStepLearningRate(self): def testBuildManualStepLearningRate(self):
...@@ -67,11 +67,10 @@ class LearningRateBuilderTest(tf.test.TestCase): ...@@ -67,11 +67,10 @@ class LearningRateBuilderTest(tf.test.TestCase):
} }
} }
""" """
global_summaries = set([])
learning_rate_proto = optimizer_pb2.LearningRate() learning_rate_proto = optimizer_pb2.LearningRate()
text_format.Merge(learning_rate_text_proto, learning_rate_proto) text_format.Merge(learning_rate_text_proto, learning_rate_proto)
learning_rate = optimizer_builder._create_learning_rate( learning_rate = optimizer_builder._create_learning_rate(
learning_rate_proto, global_summaries) learning_rate_proto)
self.assertTrue(isinstance(learning_rate, tf.Tensor)) self.assertTrue(isinstance(learning_rate, tf.Tensor))
def testBuildCosineDecayLearningRate(self): def testBuildCosineDecayLearningRate(self):
...@@ -83,22 +82,19 @@ class LearningRateBuilderTest(tf.test.TestCase): ...@@ -83,22 +82,19 @@ class LearningRateBuilderTest(tf.test.TestCase):
warmup_steps: 1000 warmup_steps: 1000
} }
""" """
global_summaries = set([])
learning_rate_proto = optimizer_pb2.LearningRate() learning_rate_proto = optimizer_pb2.LearningRate()
text_format.Merge(learning_rate_text_proto, learning_rate_proto) text_format.Merge(learning_rate_text_proto, learning_rate_proto)
learning_rate = optimizer_builder._create_learning_rate( learning_rate = optimizer_builder._create_learning_rate(
learning_rate_proto, global_summaries) learning_rate_proto)
self.assertTrue(isinstance(learning_rate, tf.Tensor)) self.assertTrue(isinstance(learning_rate, tf.Tensor))
def testRaiseErrorOnEmptyLearningRate(self): def testRaiseErrorOnEmptyLearningRate(self):
learning_rate_text_proto = """ learning_rate_text_proto = """
""" """
global_summaries = set([])
learning_rate_proto = optimizer_pb2.LearningRate() learning_rate_proto = optimizer_pb2.LearningRate()
text_format.Merge(learning_rate_text_proto, learning_rate_proto) text_format.Merge(learning_rate_text_proto, learning_rate_proto)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
optimizer_builder._create_learning_rate( optimizer_builder._create_learning_rate(learning_rate_proto)
learning_rate_proto, global_summaries)
class OptimizerBuilderTest(tf.test.TestCase): class OptimizerBuilderTest(tf.test.TestCase):
...@@ -119,10 +115,9 @@ class OptimizerBuilderTest(tf.test.TestCase): ...@@ -119,10 +115,9 @@ class OptimizerBuilderTest(tf.test.TestCase):
} }
use_moving_average: false use_moving_average: false
""" """
global_summaries = set([])
optimizer_proto = optimizer_pb2.Optimizer() optimizer_proto = optimizer_pb2.Optimizer()
text_format.Merge(optimizer_text_proto, optimizer_proto) text_format.Merge(optimizer_text_proto, optimizer_proto)
optimizer = optimizer_builder.build(optimizer_proto, global_summaries) optimizer, _ = optimizer_builder.build(optimizer_proto)
self.assertTrue(isinstance(optimizer, tf.train.RMSPropOptimizer)) self.assertTrue(isinstance(optimizer, tf.train.RMSPropOptimizer))
def testBuildMomentumOptimizer(self): def testBuildMomentumOptimizer(self):
...@@ -137,10 +132,9 @@ class OptimizerBuilderTest(tf.test.TestCase): ...@@ -137,10 +132,9 @@ class OptimizerBuilderTest(tf.test.TestCase):
} }
use_moving_average: false use_moving_average: false
""" """
global_summaries = set([])
optimizer_proto = optimizer_pb2.Optimizer() optimizer_proto = optimizer_pb2.Optimizer()
text_format.Merge(optimizer_text_proto, optimizer_proto) text_format.Merge(optimizer_text_proto, optimizer_proto)
optimizer = optimizer_builder.build(optimizer_proto, global_summaries) optimizer, _ = optimizer_builder.build(optimizer_proto)
self.assertTrue(isinstance(optimizer, tf.train.MomentumOptimizer)) self.assertTrue(isinstance(optimizer, tf.train.MomentumOptimizer))
def testBuildAdamOptimizer(self): def testBuildAdamOptimizer(self):
...@@ -154,10 +148,9 @@ class OptimizerBuilderTest(tf.test.TestCase): ...@@ -154,10 +148,9 @@ class OptimizerBuilderTest(tf.test.TestCase):
} }
use_moving_average: false use_moving_average: false
""" """
global_summaries = set([])
optimizer_proto = optimizer_pb2.Optimizer() optimizer_proto = optimizer_pb2.Optimizer()
text_format.Merge(optimizer_text_proto, optimizer_proto) text_format.Merge(optimizer_text_proto, optimizer_proto)
optimizer = optimizer_builder.build(optimizer_proto, global_summaries) optimizer, _ = optimizer_builder.build(optimizer_proto)
self.assertTrue(isinstance(optimizer, tf.train.AdamOptimizer)) self.assertTrue(isinstance(optimizer, tf.train.AdamOptimizer))
def testBuildMovingAverageOptimizer(self): def testBuildMovingAverageOptimizer(self):
...@@ -171,10 +164,9 @@ class OptimizerBuilderTest(tf.test.TestCase): ...@@ -171,10 +164,9 @@ class OptimizerBuilderTest(tf.test.TestCase):
} }
use_moving_average: True use_moving_average: True
""" """
global_summaries = set([])
optimizer_proto = optimizer_pb2.Optimizer() optimizer_proto = optimizer_pb2.Optimizer()
text_format.Merge(optimizer_text_proto, optimizer_proto) text_format.Merge(optimizer_text_proto, optimizer_proto)
optimizer = optimizer_builder.build(optimizer_proto, global_summaries) optimizer, _ = optimizer_builder.build(optimizer_proto)
self.assertTrue( self.assertTrue(
isinstance(optimizer, tf.contrib.opt.MovingAverageOptimizer)) isinstance(optimizer, tf.contrib.opt.MovingAverageOptimizer))
...@@ -190,10 +182,9 @@ class OptimizerBuilderTest(tf.test.TestCase): ...@@ -190,10 +182,9 @@ class OptimizerBuilderTest(tf.test.TestCase):
use_moving_average: True use_moving_average: True
moving_average_decay: 0.2 moving_average_decay: 0.2
""" """
global_summaries = set([])
optimizer_proto = optimizer_pb2.Optimizer() optimizer_proto = optimizer_pb2.Optimizer()
text_format.Merge(optimizer_text_proto, optimizer_proto) text_format.Merge(optimizer_text_proto, optimizer_proto)
optimizer = optimizer_builder.build(optimizer_proto, global_summaries) optimizer, _ = optimizer_builder.build(optimizer_proto)
self.assertTrue( self.assertTrue(
isinstance(optimizer, tf.contrib.opt.MovingAverageOptimizer)) isinstance(optimizer, tf.contrib.opt.MovingAverageOptimizer))
# TODO: Find a way to not depend on the private members. # TODO: Find a way to not depend on the private members.
...@@ -202,11 +193,10 @@ class OptimizerBuilderTest(tf.test.TestCase): ...@@ -202,11 +193,10 @@ class OptimizerBuilderTest(tf.test.TestCase):
def testBuildEmptyOptimizer(self): def testBuildEmptyOptimizer(self):
optimizer_text_proto = """ optimizer_text_proto = """
""" """
global_summaries = set([])
optimizer_proto = optimizer_pb2.Optimizer() optimizer_proto = optimizer_pb2.Optimizer()
text_format.Merge(optimizer_text_proto, optimizer_proto) text_format.Merge(optimizer_text_proto, optimizer_proto)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
optimizer_builder.build(optimizer_proto, global_summaries) optimizer_builder.build(optimizer_proto)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -83,6 +83,7 @@ PREPROCESSING_FUNCTION_MAP = { ...@@ -83,6 +83,7 @@ PREPROCESSING_FUNCTION_MAP = {
'random_jitter_boxes': preprocessor.random_jitter_boxes, 'random_jitter_boxes': preprocessor.random_jitter_boxes,
'random_crop_to_aspect_ratio': preprocessor.random_crop_to_aspect_ratio, 'random_crop_to_aspect_ratio': preprocessor.random_crop_to_aspect_ratio,
'random_black_patches': preprocessor.random_black_patches, 'random_black_patches': preprocessor.random_black_patches,
'rgb_to_gray': preprocessor.rgb_to_gray,
'scale_boxes_to_pixel_coordinates': ( 'scale_boxes_to_pixel_coordinates': (
preprocessor.scale_boxes_to_pixel_coordinates), preprocessor.scale_boxes_to_pixel_coordinates),
'subtract_channel_mean': preprocessor.subtract_channel_mean, 'subtract_channel_mean': preprocessor.subtract_channel_mean,
......
...@@ -379,6 +379,16 @@ class PreprocessorBuilderTest(tf.test.TestCase): ...@@ -379,6 +379,16 @@ class PreprocessorBuilderTest(tf.test.TestCase):
'new_width': 100, 'new_width': 100,
'method': tf.image.ResizeMethod.BICUBIC}) 'method': tf.image.ResizeMethod.BICUBIC})
def test_build_rgb_to_gray(self):
  """Checks the builder output for an empty `rgb_to_gray` step.

  Parses the text proto `rgb_to_gray {}` into a `PreprocessingStep`, runs it
  through `preprocessor_builder.build`, and asserts that the returned pair is
  `preprocessor.rgb_to_gray` together with an empty argument dict.
  """
  step_text = """
  rgb_to_gray {}
  """
  step_proto = preprocessor_pb2.PreprocessingStep()
  text_format.Merge(step_text, step_proto)

  built_fn, built_kwargs = preprocessor_builder.build(step_proto)

  # The step carries no options, so no keyword arguments should be produced.
  self.assertEqual(built_fn, preprocessor.rgb_to_gray)
  self.assertEqual(built_kwargs, {})
def test_build_subtract_channel_mean(self): def test_build_subtract_channel_mean(self):
preprocessor_text_proto = """ preprocessor_text_proto = """
subtract_channel_mean { subtract_channel_mean {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment