Commit 78d5f8f8 authored by Zhichao Lu's avatar Zhichao Lu Committed by lzc5123016
Browse files

Merged commit includes the following changes:

187187978  by Zhichao Lu:

    Only updating hyperparameters if they have non-null values.

--
187097690  by Zhichao Lu:

    Rewrite some conditions a bit more clearly.

--
187085190  by Zhichao Lu:

    More informative error message.

--
186935376  by Zhichao Lu:

    Added option to evaluator.evaluate to use custom evaluator objects.

--
186808249  by Zhichao Lu:

    Fix documentation re: number of stages.

--
186775014  by Zhichao Lu:

    Change anchor generator interface to return a list of BoxLists containing anchors for different feature map layers.

--
186729028  by Zhichao Lu:

    Minor fixes to object detection.

--
186723716  by Zhichao Lu:

    Fix tf_example_decoder.py initailization issue.

--
186668505  by Zhichao Lu:

    Remove unused import.

--
186475361  by Zhichao Lu:

    Update the box predictor interface to return list of predictions - one from each feature map - instead of stacking them into one large tensor.

--
186410844  by Zhich...
parent 629adffa
# Tensorflow Object Detection API: main runnables.
package(
default_visibility = ["//visibility:public"],
)
load("//learning/brain/contrib/learn/tpu:tpu.bzl", "cloud_tpu_py_binaries")
licenses(["notice"])
# Apache 2.0
exports_files(["LICENSE"])
py_library(
name = "inputs",
srcs = [
"inputs.py",
],
deps = [
"//tensorflow",
"//tensorflow/models/research/object_detection/builders:dataset_builder",
"//tensorflow/models/research/object_detection/builders:image_resizer_builder",
"//tensorflow/models/research/object_detection/builders:model_builder",
"//tensorflow/models/research/object_detection/builders:preprocessor_builder",
"//tensorflow/models/research/object_detection/protos:input_reader_py_pb2",
"//tensorflow/models/research/object_detection/protos:model_py_pb2",
"//tensorflow/models/research/object_detection/protos:train_py_pb2",
"//tensorflow/models/research/object_detection/utils:config_util",
"//tensorflow/models/research/object_detection/utils:dataset_util",
"//tensorflow/models/research/object_detection/utils:ops",
],
)
py_test(
name = "inputs_test",
srcs = [
"inputs_test.py",
],
data = [
"//tensorflow/models/research/object_detection/data:pet_label_map.pbtxt",
"//tensorflow/models/research/object_detection/samples/configs:faster_rcnn_resnet50_pets.config",
"//tensorflow/models/research/object_detection/samples/configs:ssd_inception_v2_pets.config",
"//tensorflow/models/research/object_detection/test_data:pets_examples.record",
],
deps = [
":inputs",
"//tensorflow",
"//tensorflow/models/research/object_detection/core:standard_fields",
"//tensorflow/models/research/object_detection/utils:config_util",
],
)
py_binary(
name = "model",
srcs = [
"model.py",
],
deps = [
":inputs",
":model_hparams",
"//tensorflow",
"//tensorflow/models/research/object_detection:eval_util",
"//tensorflow/models/research/object_detection/builders:model_builder",
"//tensorflow/models/research/object_detection/builders:optimizer_builder",
"//tensorflow/models/research/object_detection/metrics:coco_evaluation",
"//tensorflow/models/research/object_detection/utils:config_util",
"//tensorflow/models/research/object_detection/utils:label_map_util",
"//tensorflow/models/research/object_detection/utils:ops",
"//tensorflow/models/research/object_detection/utils:shape_utils",
"//tensorflow/models/research/object_detection/utils:variables_helper",
"//tensorflow/models/research/object_detection/utils:visualization_utils",
],
)
py_library(
name = "model_hparams",
srcs = [
"model_hparams.py",
],
deps = [
"//tensorflow",
],
)
py_test(
name = "model_test",
timeout = "long",
srcs = [
"model_test.py",
],
data = [
"//tensorflow/models/research/object_detection/data:pet_label_map.pbtxt",
"//tensorflow/models/research/object_detection/samples/configs:faster_rcnn_resnet50_pets.config",
"//tensorflow/models/research/object_detection/samples/configs:ssd_inception_v2_pets.config",
"//tensorflow/models/research/object_detection/test_data:pets_examples.record",
],
deps = [
":inputs",
":model",
":model_hparams",
":model_test_util",
"//mock",
"//tensorflow",
"//tensorflow/models/research/object_detection/core:standard_fields",
"//tensorflow/models/research/object_detection/data_decoders:tf_example_decoder",
"//tensorflow/models/research/object_detection/utils:config_util",
"//tensorflow/models/research/object_detection/utils:ops",
],
)
MODEL_TPU_DEPS = [
":inputs",
":model",
":model_hparams",
"//tensorflow",
"//tensorflow/models/research/object_detection:eval_util",
"//tensorflow/models/research/object_detection/builders:model_builder",
"//tensorflow/models/research/object_detection/builders:optimizer_builder",
"//tensorflow/models/research/object_detection/metrics:coco_evaluation",
"//tensorflow/models/research/object_detection/utils:config_util",
"//tensorflow/models/research/object_detection/utils:label_map_util",
"//tensorflow/models/research/object_detection/utils:ops",
"//tensorflow/models/research/object_detection/utils:variables_helper",
"//tensorflow/models/research/object_detection/utils:visualization_utils",
]
cloud_tpu_py_binaries(
name = "model_tpu",
srcs = [
"model_tpu.py",
],
main = "model_tpu.py",
deps = MODEL_TPU_DEPS,
)
py_library(
name = "model_tpu_lib",
srcs = [
"model_tpu.py",
],
deps = MODEL_TPU_DEPS,
)
py_library(
name = "model_test_util",
srcs = [
"model_test_util.py",
],
deps = [
":model",
":model_hparams",
"//tensorflow",
],
)
py_binary(
name = "train",
srcs = [
"train.py",
],
deps = [
":trainer",
"//tensorflow",
"//tensorflow/models/research/object_detection/builders:dataset_builder",
"//tensorflow/models/research/object_detection/builders:model_builder",
"//tensorflow/models/research/object_detection/utils:config_util",
"//tensorflow/models/research/object_detection/utils:dataset_util",
],
)
py_library(
name = "trainer",
srcs = ["trainer.py"],
deps = [
"//tensorflow",
"//tensorflow/models/research/object_detection/builders:optimizer_builder",
"//tensorflow/models/research/object_detection/builders:preprocessor_builder",
"//tensorflow/models/research/object_detection/core:batcher",
"//tensorflow/models/research/object_detection/core:preprocessor",
"//tensorflow/models/research/object_detection/core:standard_fields",
"//tensorflow/models/research/object_detection/utils:ops",
"//tensorflow/models/research/object_detection/utils:variables_helper",
"//third_party/tensorflow_models/slim:model_deploy",
],
)
py_test(
name = "trainer_test",
srcs = ["trainer_test.py"],
deps = [
":trainer",
"//tensorflow",
"//tensorflow/models/research/object_detection/core:losses",
"//tensorflow/models/research/object_detection/core:model",
"//tensorflow/models/research/object_detection/core:standard_fields",
"//tensorflow/models/research/object_detection/protos:train_py_pb2",
],
)
py_library(
name = "eval_util",
srcs = [
"eval_util.py",
],
deps = [
"//tensorflow",
"//tensorflow/models/research/object_detection/core:box_list",
"//tensorflow/models/research/object_detection/core:box_list_ops",
"//tensorflow/models/research/object_detection/core:keypoint_ops",
"//tensorflow/models/research/object_detection/core:standard_fields",
"//tensorflow/models/research/object_detection/utils:label_map_util",
"//tensorflow/models/research/object_detection/utils:ops",
"//tensorflow/models/research/object_detection/utils:visualization_utils",
],
)
py_library(
name = "evaluator",
srcs = ["evaluator.py"],
deps = [
"//tensorflow",
"//tensorflow/models/research/object_detection:eval_util",
"//tensorflow/models/research/object_detection/core:prefetcher",
"//tensorflow/models/research/object_detection/core:standard_fields",
"//tensorflow/models/research/object_detection/metrics:coco_evaluation",
"//tensorflow/models/research/object_detection/protos:eval_py_pb2",
"//tensorflow/models/research/object_detection/utils:object_detection_evaluation",
],
)
py_binary(
name = "eval",
srcs = [
"eval.py",
],
deps = [
":evaluator",
"//tensorflow",
"//tensorflow/models/research/object_detection/builders:dataset_builder",
"//tensorflow/models/research/object_detection/builders:model_builder",
"//tensorflow/models/research/object_detection/utils:config_util",
"//tensorflow/models/research/object_detection/utils:dataset_util",
"//tensorflow/models/research/object_detection/utils:label_map_util",
],
)
py_library(
name = "exporter",
srcs = [
"exporter.py",
],
deps = [
"//tensorflow",
"//tensorflow/python/tools:freeze_graph_lib",
"//tensorflow/models/research/object_detection/builders:model_builder",
"//tensorflow/models/research/object_detection/core:standard_fields",
"//tensorflow/models/research/object_detection/data_decoders:tf_example_decoder",
],
)
py_test(
name = "exporter_test",
srcs = [
"exporter_test.py",
],
deps = [
":exporter",
"//tensorflow",
"//tensorflow/models/research/object_detection/builders:model_builder",
"//tensorflow/models/research/object_detection/core:model",
"//tensorflow/models/research/object_detection/protos:pipeline_py_pb2",
],
)
py_binary(
name = "export_inference_graph",
srcs = [
"export_inference_graph.py",
],
deps = [
":exporter",
"//tensorflow",
"//tensorflow/models/research/object_detection/protos:pipeline_py_pb2",
],
)
# Tensorflow Object Detection API: Anchor Generator implementations.
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"])
# Apache 2.0
py_library(
name = "grid_anchor_generator",
srcs = [
"grid_anchor_generator.py",
],
deps = [
"//tensorflow",
"//tensorflow/models/research/object_detection/core:anchor_generator",
"//tensorflow/models/research/object_detection/core:box_list",
"//tensorflow/models/research/object_detection/utils:ops",
],
)
py_test(
name = "grid_anchor_generator_test",
srcs = [
"grid_anchor_generator_test.py",
],
deps = [
":grid_anchor_generator",
"//tensorflow",
"//tensorflow/models/research/object_detection/utils:test_case",
],
)
py_library(
name = "multiple_grid_anchor_generator",
srcs = [
"multiple_grid_anchor_generator.py",
],
deps = [
":grid_anchor_generator",
"//tensorflow",
"//tensorflow/models/research/object_detection/core:anchor_generator",
"//tensorflow/models/research/object_detection/core:box_list_ops",
],
)
py_test(
name = "multiple_grid_anchor_generator_test",
srcs = [
"multiple_grid_anchor_generator_test.py",
],
deps = [
":multiple_grid_anchor_generator",
"//numpy",
"//tensorflow/models/research/object_detection/utils:test_case",
],
)
py_library(
name = "multiscale_grid_anchor_generator",
srcs = [
"multiscale_grid_anchor_generator.py",
],
deps = [
":grid_anchor_generator",
"//tensorflow",
"//tensorflow/models/research/object_detection/core:anchor_generator",
"//tensorflow/models/research/object_detection/core:box_list_ops",
],
)
py_test(
name = "multiscale_grid_anchor_generator_test",
srcs = [
"multiscale_grid_anchor_generator_test.py",
],
deps = [
":multiscale_grid_anchor_generator",
"//numpy",
"//tensorflow/models/research/object_detection/utils:test_case",
],
)
...@@ -93,11 +93,9 @@ class GridAnchorGenerator(anchor_generator.AnchorGenerator): ...@@ -93,11 +93,9 @@ class GridAnchorGenerator(anchor_generator.AnchorGenerator):
allowed. allowed.
Returns: Returns:
boxes: a BoxList holding a collection of N anchor boxes. Additionally boxes_list: a list of BoxLists each holding anchor boxes corresponding to
this BoxList also holds a `feature_map_index` field which is set to 0 the input feature map shapes.
for each anchor; this field exists for interchangeability reasons with
the MultipleGridAnchorGenerator (see the docstring for the corresponding
`_generate` function in multiple_grid_anchor_generator.py)
Raises: Raises:
ValueError: if feature_map_shape_list, box_specs_list do not have the same ValueError: if feature_map_shape_list, box_specs_list do not have the same
length. length.
...@@ -128,7 +126,7 @@ class GridAnchorGenerator(anchor_generator.AnchorGenerator): ...@@ -128,7 +126,7 @@ class GridAnchorGenerator(anchor_generator.AnchorGenerator):
num_anchors = anchors.num_boxes() num_anchors = anchors.num_boxes()
anchor_indices = tf.zeros([num_anchors]) anchor_indices = tf.zeros([num_anchors])
anchors.add_field('feature_map_index', anchor_indices) anchors.add_field('feature_map_index', anchor_indices)
return anchors return [anchors]
def tile_anchors(grid_height, def tile_anchors(grid_height,
......
...@@ -31,8 +31,8 @@ class GridAnchorGeneratorTest(test_case.TestCase): ...@@ -31,8 +31,8 @@ class GridAnchorGeneratorTest(test_case.TestCase):
anchor_offset = [7, -3] anchor_offset = [7, -3]
anchor_generator = grid_anchor_generator.GridAnchorGenerator( anchor_generator = grid_anchor_generator.GridAnchorGenerator(
scales, aspect_ratios, anchor_offset=anchor_offset) scales, aspect_ratios, anchor_offset=anchor_offset)
anchors = anchor_generator.generate(feature_map_shape_list=[(1, 1)]) anchors_list = anchor_generator.generate(feature_map_shape_list=[(1, 1)])
anchor_corners = anchors.get() anchor_corners = anchors_list[0].get()
return (anchor_corners,) return (anchor_corners,)
exp_anchor_corners = [[-121, -35, 135, 29], [-249, -67, 263, 61], exp_anchor_corners = [[-121, -35, 135, 29], [-249, -67, 263, 61],
[-505, -131, 519, 125], [-57, -67, 71, 61], [-505, -131, 519, 125], [-57, -67, 71, 61],
...@@ -57,8 +57,8 @@ class GridAnchorGeneratorTest(test_case.TestCase): ...@@ -57,8 +57,8 @@ class GridAnchorGeneratorTest(test_case.TestCase):
anchor_stride=anchor_stride, anchor_stride=anchor_stride,
anchor_offset=anchor_offset) anchor_offset=anchor_offset)
anchors = anchor_generator.generate(feature_map_shape_list=[(2, 2)]) anchors_list = anchor_generator.generate(feature_map_shape_list=[(2, 2)])
anchor_corners = anchors.get() anchor_corners = anchors_list[0].get()
return (anchor_corners,) return (anchor_corners,)
exp_anchor_corners = [[-2.5, -2.5, 2.5, 2.5], [-5., -5., 5., 5.], exp_anchor_corners = [[-2.5, -2.5, 2.5, 2.5], [-5., -5., 5., 5.],
[-10., -10., 10., 10.], [-2.5, 16.5, 2.5, 21.5], [-10., -10., 10., 10.], [-2.5, 16.5, 2.5, 21.5],
...@@ -83,9 +83,9 @@ class GridAnchorGeneratorTest(test_case.TestCase): ...@@ -83,9 +83,9 @@ class GridAnchorGeneratorTest(test_case.TestCase):
anchor_stride=anchor_stride, anchor_stride=anchor_stride,
anchor_offset=anchor_offset) anchor_offset=anchor_offset)
anchors = anchor_generator.generate( anchors_list = anchor_generator.generate(
feature_map_shape_list=[(feature_map_height, feature_map_width)]) feature_map_shape_list=[(feature_map_height, feature_map_width)])
anchor_corners = anchors.get() anchor_corners = anchors_list[0].get()
return (anchor_corners,) return (anchor_corners,)
exp_anchor_corners = [[-2.5, -2.5, 2.5, 2.5], [-5., -5., 5., 5.], exp_anchor_corners = [[-2.5, -2.5, 2.5, 2.5], [-5., -5., 5., 5.],
......
...@@ -165,10 +165,9 @@ class MultipleGridAnchorGenerator(anchor_generator.AnchorGenerator): ...@@ -165,10 +165,9 @@ class MultipleGridAnchorGenerator(anchor_generator.AnchorGenerator):
grid. grid.
Returns: Returns:
boxes: a BoxList holding a collection of N anchor boxes. Additionally boxes_list: a list of BoxLists each holding anchor boxes corresponding to
this BoxList also holds a `feature_map_index` field which, for each the input feature map shapes.
anchor, stores the index of the corresponding feature map which was used
to generate it.
Raises: Raises:
ValueError: if feature_map_shape_list, box_specs_list do not have the same ValueError: if feature_map_shape_list, box_specs_list do not have the same
length. length.
...@@ -211,7 +210,6 @@ class MultipleGridAnchorGenerator(anchor_generator.AnchorGenerator): ...@@ -211,7 +210,6 @@ class MultipleGridAnchorGenerator(anchor_generator.AnchorGenerator):
raise ValueError('%s must be a list of pairs.' % arg_name) raise ValueError('%s must be a list of pairs.' % arg_name)
anchor_grid_list = [] anchor_grid_list = []
anchor_indices_list = []
min_im_shape = tf.minimum(im_height, im_width) min_im_shape = tf.minimum(im_height, im_width)
scale_height = min_im_shape / im_height scale_height = min_im_shape / im_height
scale_width = min_im_shape / im_width scale_width = min_im_shape / im_width
...@@ -219,10 +217,11 @@ class MultipleGridAnchorGenerator(anchor_generator.AnchorGenerator): ...@@ -219,10 +217,11 @@ class MultipleGridAnchorGenerator(anchor_generator.AnchorGenerator):
scale_height * self._base_anchor_size[0], scale_height * self._base_anchor_size[0],
scale_width * self._base_anchor_size[1] scale_width * self._base_anchor_size[1]
] ]
for feature_map_index, ( for feature_map_index, (grid_size, scales, aspect_ratios, stride,
grid_size, scales, aspect_ratios, stride, offset) in enumerate( offset) in enumerate(
zip(feature_map_shape_list, self._scales, self._aspect_ratios, zip(feature_map_shape_list, self._scales,
anchor_strides, anchor_offsets)): self._aspect_ratios, anchor_strides,
anchor_offsets)):
tiled_anchors = grid_anchor_generator.tile_anchors( tiled_anchors = grid_anchor_generator.tile_anchors(
grid_height=grid_size[0], grid_height=grid_size[0],
grid_width=grid_size[1], grid_width=grid_size[1],
...@@ -231,30 +230,17 @@ class MultipleGridAnchorGenerator(anchor_generator.AnchorGenerator): ...@@ -231,30 +230,17 @@ class MultipleGridAnchorGenerator(anchor_generator.AnchorGenerator):
base_anchor_size=base_anchor_size, base_anchor_size=base_anchor_size,
anchor_stride=stride, anchor_stride=stride,
anchor_offset=offset) anchor_offset=offset)
anchor_grid_list.append(tiled_anchors) if self._clip_window is not None:
tiled_anchors = box_list_ops.clip_to_window(
tiled_anchors, self._clip_window, filter_nonoverlapping=False)
num_anchors_in_layer = tiled_anchors.num_boxes_static() num_anchors_in_layer = tiled_anchors.num_boxes_static()
if num_anchors_in_layer is None: if num_anchors_in_layer is None:
num_anchors_in_layer = tiled_anchors.num_boxes() num_anchors_in_layer = tiled_anchors.num_boxes()
anchor_indices_list.append( anchor_indices = feature_map_index * tf.ones([num_anchors_in_layer])
feature_map_index * tf.ones([num_anchors_in_layer])) tiled_anchors.add_field('feature_map_index', anchor_indices)
concatenated_anchors = box_list_ops.concatenate(anchor_grid_list) anchor_grid_list.append(tiled_anchors)
anchor_indices = tf.concat(anchor_indices_list, 0)
num_anchors = concatenated_anchors.num_boxes_static() return anchor_grid_list
if num_anchors is None:
num_anchors = concatenated_anchors.num_boxes()
if self._clip_window is not None:
concatenated_anchors = box_list_ops.clip_to_window(
concatenated_anchors, self._clip_window, filter_nonoverlapping=False)
# TODO: make reshape an option for the clip_to_window op
concatenated_anchors.set(
tf.reshape(concatenated_anchors.get(), [num_anchors, 4]))
stddevs_tensor = 0.01 * tf.ones(
[num_anchors, 4], dtype=tf.float32, name='stddevs')
concatenated_anchors.add_field('stddev', stddevs_tensor)
concatenated_anchors.add_field('feature_map_index', anchor_indices)
return concatenated_anchors
def create_ssd_anchors(num_layers=6, def create_ssd_anchors(num_layers=6,
...@@ -285,7 +271,7 @@ def create_ssd_anchors(num_layers=6, ...@@ -285,7 +271,7 @@ def create_ssd_anchors(num_layers=6,
grid sizes passed in at generation time) grid sizes passed in at generation time)
min_scale: scale of anchors corresponding to finest resolution (float) min_scale: scale of anchors corresponding to finest resolution (float)
max_scale: scale of anchors corresponding to coarsest resolution (float) max_scale: scale of anchors corresponding to coarsest resolution (float)
scales: As list of anchor scales to use. When not None and not emtpy, scales: As list of anchor scales to use. When not None and not empty,
min_scale and max_scale are not used. min_scale and max_scale are not used.
aspect_ratios: list or tuple of (float) aspect ratios to place on each aspect_ratios: list or tuple of (float) aspect ratios to place on each
grid point. grid point.
......
...@@ -37,8 +37,8 @@ class MultipleGridAnchorGeneratorTest(test_case.TestCase): ...@@ -37,8 +37,8 @@ class MultipleGridAnchorGeneratorTest(test_case.TestCase):
base_anchor_size=tf.constant([256, 256], dtype=tf.float32), base_anchor_size=tf.constant([256, 256], dtype=tf.float32),
anchor_strides=[(16, 16)], anchor_strides=[(16, 16)],
anchor_offsets=[(7, -3)]) anchor_offsets=[(7, -3)])
anchors = anchor_generator.generate(feature_map_shape_list=[(1, 1)]) anchors_list = anchor_generator.generate(feature_map_shape_list=[(1, 1)])
return anchors.get() return anchors_list[0].get()
exp_anchor_corners = [[-121, -35, 135, 29], [-249, -67, 263, 61], exp_anchor_corners = [[-121, -35, 135, 29], [-249, -67, 263, 61],
[-505, -131, 519, 125], [-57, -67, 71, 61], [-505, -131, 519, 125], [-57, -67, 71, 61],
[-121, -131, 135, 125], [-249, -259, 263, 253], [-121, -131, 135, 125], [-249, -259, 263, 253],
...@@ -57,8 +57,8 @@ class MultipleGridAnchorGeneratorTest(test_case.TestCase): ...@@ -57,8 +57,8 @@ class MultipleGridAnchorGeneratorTest(test_case.TestCase):
base_anchor_size=tf.constant([10, 10], dtype=tf.float32), base_anchor_size=tf.constant([10, 10], dtype=tf.float32),
anchor_strides=[(19, 19)], anchor_strides=[(19, 19)],
anchor_offsets=[(0, 0)]) anchor_offsets=[(0, 0)])
anchors = anchor_generator.generate(feature_map_shape_list=[(2, 2)]) anchors_list = anchor_generator.generate(feature_map_shape_list=[(2, 2)])
return anchors.get() return anchors_list[0].get()
exp_anchor_corners = [[-2.5, -2.5, 2.5, 2.5], [-5., -5., 5., 5.], exp_anchor_corners = [[-2.5, -2.5, 2.5, 2.5], [-5., -5., 5., 5.],
[-10., -10., 10., 10.], [-2.5, 16.5, 2.5, 21.5], [-10., -10., 10., 10.], [-2.5, 16.5, 2.5, 21.5],
[-5., 14., 5, 24], [-10., 9., 10, 29], [-5., 14., 5, 24], [-10., 9., 10, 29],
...@@ -76,9 +76,9 @@ class MultipleGridAnchorGeneratorTest(test_case.TestCase): ...@@ -76,9 +76,9 @@ class MultipleGridAnchorGeneratorTest(test_case.TestCase):
anchor_generator = ag.MultipleGridAnchorGenerator( anchor_generator = ag.MultipleGridAnchorGenerator(
box_specs_list, base_anchor_size=tf.constant([1, 1], box_specs_list, base_anchor_size=tf.constant([1, 1],
dtype=tf.float32)) dtype=tf.float32))
anchors = anchor_generator.generate(feature_map_shape_list=[(tf.constant( anchors_list = anchor_generator.generate(feature_map_shape_list=[(
1, dtype=tf.int32), tf.constant(2, dtype=tf.int32))]) tf.constant(1, dtype=tf.int32), tf.constant(2, dtype=tf.int32))])
return anchors.get() return anchors_list[0].get()
exp_anchor_corners = [[0., -0.25, 1., 0.75], [0., 0.25, 1., 1.25]] exp_anchor_corners = [[0., -0.25, 1., 0.75], [0., 0.25, 1., 1.25]]
anchor_corners_out = self.execute(graph_fn, []) anchor_corners_out = self.execute(graph_fn, [])
...@@ -91,9 +91,9 @@ class MultipleGridAnchorGeneratorTest(test_case.TestCase): ...@@ -91,9 +91,9 @@ class MultipleGridAnchorGeneratorTest(test_case.TestCase):
anchor_generator = ag.MultipleGridAnchorGenerator( anchor_generator = ag.MultipleGridAnchorGenerator(
box_specs_list, base_anchor_size=tf.constant([1, 1], box_specs_list, base_anchor_size=tf.constant([1, 1],
dtype=tf.float32)) dtype=tf.float32))
anchors = anchor_generator.generate(feature_map_shape_list=[(height, anchors_list = anchor_generator.generate(feature_map_shape_list=[(height,
width)]) width)])
return anchors.get() return anchors_list[0].get()
exp_anchor_corners = [[0., -0.25, 1., 0.75], [0., 0.25, 1., 1.25]] exp_anchor_corners = [[0., -0.25, 1., 0.75], [0., 0.25, 1., 1.25]]
...@@ -109,12 +109,12 @@ class MultipleGridAnchorGeneratorTest(test_case.TestCase): ...@@ -109,12 +109,12 @@ class MultipleGridAnchorGeneratorTest(test_case.TestCase):
anchor_generator = ag.MultipleGridAnchorGenerator( anchor_generator = ag.MultipleGridAnchorGenerator(
box_specs_list, base_anchor_size=tf.constant([1, 1], box_specs_list, base_anchor_size=tf.constant([1, 1],
dtype=tf.float32)) dtype=tf.float32))
anchors = anchor_generator.generate( anchors_list = anchor_generator.generate(
feature_map_shape_list=[(tf.constant(1, dtype=tf.int32), tf.constant( feature_map_shape_list=[(tf.constant(1, dtype=tf.int32), tf.constant(
2, dtype=tf.int32))], 2, dtype=tf.int32))],
im_height=320, im_height=320,
im_width=640) im_width=640)
return anchors.get() return anchors_list[0].get()
exp_anchor_corners = [[0., 0., 1., 0.5], [0., 0.5, 1., 1.]] exp_anchor_corners = [[0., 0., 1., 0.5], [0., 0.5, 1., 1.]]
anchor_corners_out = self.execute(graph_fn, []) anchor_corners_out = self.execute(graph_fn, [])
...@@ -131,9 +131,9 @@ class MultipleGridAnchorGeneratorTest(test_case.TestCase): ...@@ -131,9 +131,9 @@ class MultipleGridAnchorGeneratorTest(test_case.TestCase):
base_anchor_size=tf.constant([1.0, 1.0], dtype=tf.float32), base_anchor_size=tf.constant([1.0, 1.0], dtype=tf.float32),
anchor_strides=[(.25, .25), (.5, .5)], anchor_strides=[(.25, .25), (.5, .5)],
anchor_offsets=[(.125, .125), (.25, .25)]) anchor_offsets=[(.125, .125), (.25, .25)])
anchors = anchor_generator.generate(feature_map_shape_list=[(4, 4), anchors_list = anchor_generator.generate(feature_map_shape_list=[(4, 4), (
(2, 2)]) 2, 2)])
return anchors.get() return [anchors.get() for anchors in anchors_list]
# height and width of box with .5 aspect ratio # height and width of box with .5 aspect ratio
h = np.sqrt(2) h = np.sqrt(2)
w = 1.0/np.sqrt(2) w = 1.0/np.sqrt(2)
...@@ -150,7 +150,7 @@ class MultipleGridAnchorGeneratorTest(test_case.TestCase): ...@@ -150,7 +150,7 @@ class MultipleGridAnchorGeneratorTest(test_case.TestCase):
[.125-1.0, .125-1.0, .125+1.0, .125+1.0], [.125-1.0, .125-1.0, .125+1.0, .125+1.0],
[.125-.5*h, .125-.5*w, .125+.5*h, .125+.5*w],] [.125-.5*h, .125-.5*w, .125+.5*h, .125+.5*w],]
anchor_corners_out = self.execute(graph_fn, []) anchor_corners_out = np.concatenate(self.execute(graph_fn, []), axis=0)
self.assertEquals(anchor_corners_out.shape, (56, 4)) self.assertEquals(anchor_corners_out.shape, (56, 4))
big_grid_corners = anchor_corners_out[0:3, :] big_grid_corners = anchor_corners_out[0:3, :]
small_grid_corners = anchor_corners_out[48:, :] small_grid_corners = anchor_corners_out[48:, :]
...@@ -168,9 +168,9 @@ class MultipleGridAnchorGeneratorTest(test_case.TestCase): ...@@ -168,9 +168,9 @@ class MultipleGridAnchorGeneratorTest(test_case.TestCase):
box_specs_list, box_specs_list,
base_anchor_size=tf.constant([1.0, 1.0], dtype=tf.float32), base_anchor_size=tf.constant([1.0, 1.0], dtype=tf.float32),
clip_window=clip_window) clip_window=clip_window)
anchors = anchor_generator.generate(feature_map_shape_list=[(4, 4), anchors_list = anchor_generator.generate(feature_map_shape_list=[(4, 4), (
(2, 2)]) 2, 2)])
return anchors.get() return [anchors.get() for anchors in anchors_list]
# height and width of box with .5 aspect ratio # height and width of box with .5 aspect ratio
h = np.sqrt(2) h = np.sqrt(2)
w = 1.0/np.sqrt(2) w = 1.0/np.sqrt(2)
...@@ -183,7 +183,7 @@ class MultipleGridAnchorGeneratorTest(test_case.TestCase): ...@@ -183,7 +183,7 @@ class MultipleGridAnchorGeneratorTest(test_case.TestCase):
[.25, .25, 1, 1], [.25, .25, 1, 1],
[.75-.5*h, .75-.5*w, 1, 1]] [.75-.5*h, .75-.5*w, 1, 1]]
anchor_corners_out = self.execute(graph_fn, []) anchor_corners_out = np.concatenate(self.execute(graph_fn, []), axis=0)
small_grid_corners = anchor_corners_out[48:, :] small_grid_corners = anchor_corners_out[48:, :]
self.assertAllClose(small_grid_corners, exp_small_grid_corners) self.assertAllClose(small_grid_corners, exp_small_grid_corners)
...@@ -264,10 +264,10 @@ class CreateSSDAnchorsTest(test_case.TestCase): ...@@ -264,10 +264,10 @@ class CreateSSDAnchorsTest(test_case.TestCase):
feature_map_shape_list = [(38, 38), (19, 19), (10, 10), feature_map_shape_list = [(38, 38), (19, 19), (10, 10),
(5, 5), (3, 3), (1, 1)] (5, 5), (3, 3), (1, 1)]
anchors = anchor_generator.generate( anchors_list = anchor_generator.generate(
feature_map_shape_list=feature_map_shape_list) feature_map_shape_list=feature_map_shape_list)
return anchors.get() return [anchors.get() for anchors in anchors_list]
anchor_corners_out = self.execute(graph_fn1, []) anchor_corners_out = np.concatenate(self.execute(graph_fn1, []), axis=0)
self.assertEquals(anchor_corners_out.shape, (7308, 4)) self.assertEquals(anchor_corners_out.shape, (7308, 4))
def graph_fn2(): def graph_fn2():
...@@ -278,10 +278,10 @@ class CreateSSDAnchorsTest(test_case.TestCase): ...@@ -278,10 +278,10 @@ class CreateSSDAnchorsTest(test_case.TestCase):
feature_map_shape_list = [(38, 38), (19, 19), (10, 10), feature_map_shape_list = [(38, 38), (19, 19), (10, 10),
(5, 5), (3, 3), (1, 1)] (5, 5), (3, 3), (1, 1)]
anchors = anchor_generator.generate( anchors_list = anchor_generator.generate(
feature_map_shape_list=feature_map_shape_list) feature_map_shape_list=feature_map_shape_list)
return anchors.get() return [anchors.get() for anchors in anchors_list]
anchor_corners_out = self.execute(graph_fn2, []) anchor_corners_out = np.concatenate(self.execute(graph_fn2, []), axis=0)
self.assertEquals(anchor_corners_out.shape, (11640, 4)) self.assertEquals(anchor_corners_out.shape, (11640, 4))
......
...@@ -21,14 +21,15 @@ T.-Y. Lin, P. Goyal, R. Girshick, K. He, P. Dollar ...@@ -21,14 +21,15 @@ T.-Y. Lin, P. Goyal, R. Girshick, K. He, P. Dollar
""" """
from object_detection.anchor_generators import grid_anchor_generator from object_detection.anchor_generators import grid_anchor_generator
from object_detection.core import anchor_generator
from object_detection.core import box_list_ops from object_detection.core import box_list_ops
class MultiscaleGridAnchorGenerator(object): class MultiscaleGridAnchorGenerator(anchor_generator.AnchorGenerator):
"""Generate a grid of anchors for multiple CNN layers of different scale.""" """Generate a grid of anchors for multiple CNN layers of different scale."""
def __init__(self, min_level, max_level, anchor_scale, aspect_ratios, def __init__(self, min_level, max_level, anchor_scale, aspect_ratios,
scales_per_octave): scales_per_octave, normalize_coordinates=True):
"""Constructs a MultiscaleGridAnchorGenerator. """Constructs a MultiscaleGridAnchorGenerator.
To construct anchors, at multiple scale resolutions, one must provide a To construct anchors, at multiple scale resolutions, one must provide a
...@@ -48,10 +49,13 @@ class MultiscaleGridAnchorGenerator(object): ...@@ -48,10 +49,13 @@ class MultiscaleGridAnchorGenerator(object):
aspect_ratios: list or tuple of (float) aspect ratios to place on each aspect_ratios: list or tuple of (float) aspect ratios to place on each
grid point. grid point.
scales_per_octave: integer number of intermediate scales per scale octave. scales_per_octave: integer number of intermediate scales per scale octave.
normalize_coordinates: whether to produce anchors in normalized
coordinates. (defaults to True).
""" """
self._anchor_grid_info = [] self._anchor_grid_info = []
self._aspect_ratios = aspect_ratios self._aspect_ratios = aspect_ratios
self._scales_per_octave = scales_per_octave self._scales_per_octave = scales_per_octave
self._normalize_coordinates = normalize_coordinates
for level in range(min_level, max_level + 1): for level in range(min_level, max_level + 1):
anchor_stride = [2**level, 2**level] anchor_stride = [2**level, 2**level]
...@@ -80,7 +84,7 @@ class MultiscaleGridAnchorGenerator(object): ...@@ -80,7 +84,7 @@ class MultiscaleGridAnchorGenerator(object):
return len(self._anchor_grid_info) * [ return len(self._anchor_grid_info) * [
len(self._aspect_ratios) * self._scales_per_octave] len(self._aspect_ratios) * self._scales_per_octave]
def generate(self, feature_map_shape_list, im_height, im_width): def _generate(self, feature_map_shape_list, im_height, im_width):
"""Generates a collection of bounding boxes to be used as anchors. """Generates a collection of bounding boxes to be used as anchors.
Currently we require the input image shape to be statically defined. That Currently we require the input image shape to be statically defined. That
...@@ -95,7 +99,8 @@ class MultiscaleGridAnchorGenerator(object): ...@@ -95,7 +99,8 @@ class MultiscaleGridAnchorGenerator(object):
im_width: the width of the image to generate the grid for. im_width: the width of the image to generate the grid for.
Returns: Returns:
boxes: a BoxList holding a collection of N anchor boxes boxes_list: a list of BoxLists each holding anchor boxes corresponding to
the input feature map shapes.
Raises: Raises:
ValueError: if im_height and im_width are not integers. ValueError: if im_height and im_width are not integers.
""" """
...@@ -105,7 +110,7 @@ class MultiscaleGridAnchorGenerator(object): ...@@ -105,7 +110,7 @@ class MultiscaleGridAnchorGenerator(object):
anchor_grid_list = [] anchor_grid_list = []
for feat_shape, grid_info in zip(feature_map_shape_list, for feat_shape, grid_info in zip(feature_map_shape_list,
self._anchor_grid_info): self._anchor_grid_info):
# TODO check the feature_map_shape_list is consistent with # TODO(rathodv) check the feature_map_shape_list is consistent with
# self._anchor_grid_info # self._anchor_grid_info
level = grid_info['level'] level = grid_info['level']
stride = 2**level stride = 2**level
...@@ -123,9 +128,11 @@ class MultiscaleGridAnchorGenerator(object): ...@@ -123,9 +128,11 @@ class MultiscaleGridAnchorGenerator(object):
base_anchor_size=base_anchor_size, base_anchor_size=base_anchor_size,
anchor_stride=anchor_stride, anchor_stride=anchor_stride,
anchor_offset=anchor_offset) anchor_offset=anchor_offset)
anchor_grid_list.append( (anchor_grid,) = ag.generate(feature_map_shape_list=[(feat_h, feat_w)])
ag.generate(feature_map_shape_list=[(feat_h, feat_w)]))
concatenated_anchors = box_list_ops.concatenate(anchor_grid_list) if self._normalize_coordinates:
anchor_grid = box_list_ops.to_normalized_coordinates(
anchor_grid, im_height, im_width, check_range=False)
anchor_grid_list.append(anchor_grid)
return concatenated_anchors return anchor_grid_list
...@@ -37,10 +37,35 @@ class MultiscaleGridAnchorGeneratorTest(test_case.TestCase): ...@@ -37,10 +37,35 @@ class MultiscaleGridAnchorGeneratorTest(test_case.TestCase):
[-16, -48, 112, 80], [-16, -48, 112, 80],
[-16, -16, 112, 112]] [-16, -16, 112, 112]]
anchor_generator = mg.MultiscaleGridAnchorGenerator( anchor_generator = mg.MultiscaleGridAnchorGenerator(
min_level, max_level, anchor_scale, aspect_ratios, scales_per_octave) min_level, max_level, anchor_scale, aspect_ratios, scales_per_octave,
anchors = anchor_generator.generate(feature_map_shape_list, normalize_coordinates=False)
im_height, im_width) anchors_list = anchor_generator.generate(
anchor_corners = anchors.get() feature_map_shape_list, im_height=im_height, im_width=im_width)
anchor_corners = anchors_list[0].get()
with self.test_session():
anchor_corners_out = anchor_corners.eval()
self.assertAllClose(anchor_corners_out, exp_anchor_corners)
def test_construct_single_anchor_in_normalized_coordinates(self):
min_level = 5
max_level = 5
anchor_scale = 4.0
aspect_ratios = [1.0]
scales_per_octave = 1
im_height = 64
im_width = 128
feature_map_shape_list = [(2, 2)]
exp_anchor_corners = [[-48./64, -48./128, 80./64, 80./128],
[-48./64, -16./128, 80./64, 112./128],
[-16./64, -48./128, 112./64, 80./128],
[-16./64, -16./128, 112./64, 112./128]]
anchor_generator = mg.MultiscaleGridAnchorGenerator(
min_level, max_level, anchor_scale, aspect_ratios, scales_per_octave,
normalize_coordinates=True)
anchors_list = anchor_generator.generate(
feature_map_shape_list, im_height=im_height, im_width=im_width)
anchor_corners = anchors_list[0].get()
with self.test_session(): with self.test_session():
anchor_corners_out = anchor_corners.eval() anchor_corners_out = anchor_corners.eval()
...@@ -53,7 +78,8 @@ class MultiscaleGridAnchorGeneratorTest(test_case.TestCase): ...@@ -53,7 +78,8 @@ class MultiscaleGridAnchorGeneratorTest(test_case.TestCase):
aspect_ratios = [1.0, 2.0] aspect_ratios = [1.0, 2.0]
scales_per_octave = 3 scales_per_octave = 3
anchor_generator = mg.MultiscaleGridAnchorGenerator( anchor_generator = mg.MultiscaleGridAnchorGenerator(
min_level, max_level, anchor_scale, aspect_ratios, scales_per_octave) min_level, max_level, anchor_scale, aspect_ratios, scales_per_octave,
normalize_coordinates=False)
self.assertEqual(anchor_generator.num_anchors_per_location(), [6, 6]) self.assertEqual(anchor_generator.num_anchors_per_location(), [6, 6])
def test_construct_single_anchor_fails_with_tensor_image_size(self): def test_construct_single_anchor_fails_with_tensor_image_size(self):
...@@ -66,9 +92,11 @@ class MultiscaleGridAnchorGeneratorTest(test_case.TestCase): ...@@ -66,9 +92,11 @@ class MultiscaleGridAnchorGeneratorTest(test_case.TestCase):
im_width = tf.constant(64) im_width = tf.constant(64)
feature_map_shape_list = [(2, 2)] feature_map_shape_list = [(2, 2)]
anchor_generator = mg.MultiscaleGridAnchorGenerator( anchor_generator = mg.MultiscaleGridAnchorGenerator(
min_level, max_level, anchor_scale, aspect_ratios, scales_per_octave) min_level, max_level, anchor_scale, aspect_ratios, scales_per_octave,
normalize_coordinates=False)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
anchor_generator.generate(feature_map_shape_list, im_height, im_width) anchor_generator.generate(
feature_map_shape_list, im_height=im_height, im_width=im_width)
def test_construct_single_anchor_with_odd_input_dimension(self): def test_construct_single_anchor_with_odd_input_dimension(self):
...@@ -82,10 +110,11 @@ class MultiscaleGridAnchorGeneratorTest(test_case.TestCase): ...@@ -82,10 +110,11 @@ class MultiscaleGridAnchorGeneratorTest(test_case.TestCase):
im_width = 65 im_width = 65
feature_map_shape_list = [(3, 3)] feature_map_shape_list = [(3, 3)]
anchor_generator = mg.MultiscaleGridAnchorGenerator( anchor_generator = mg.MultiscaleGridAnchorGenerator(
min_level, max_level, anchor_scale, aspect_ratios, scales_per_octave) min_level, max_level, anchor_scale, aspect_ratios, scales_per_octave,
anchors = anchor_generator.generate(feature_map_shape_list, im_height, normalize_coordinates=False)
im_width) anchors_list = anchor_generator.generate(
anchor_corners = anchors.get() feature_map_shape_list, im_height=im_height, im_width=im_width)
anchor_corners = anchors_list[0].get()
return (anchor_corners,) return (anchor_corners,)
anchor_corners_out = self.execute(graph_fn, []) anchor_corners_out = self.execute(graph_fn, [])
exp_anchor_corners = [[-64, -64, 64, 64], exp_anchor_corners = [[-64, -64, 64, 64],
...@@ -111,13 +140,15 @@ class MultiscaleGridAnchorGeneratorTest(test_case.TestCase): ...@@ -111,13 +140,15 @@ class MultiscaleGridAnchorGeneratorTest(test_case.TestCase):
im_width = 64 im_width = 64
feature_map_shape_list = [(2, 2), (1, 1)] feature_map_shape_list = [(2, 2), (1, 1)]
anchor_generator = mg.MultiscaleGridAnchorGenerator( anchor_generator = mg.MultiscaleGridAnchorGenerator(
min_level, max_level, anchor_scale, aspect_ratios, scales_per_octave) min_level, max_level, anchor_scale, aspect_ratios, scales_per_octave,
anchors = anchor_generator.generate(feature_map_shape_list, im_height, normalize_coordinates=False)
im_width) anchors_list = anchor_generator.generate(feature_map_shape_list,
anchor_corners = anchors.get() im_height=im_height,
return (anchor_corners,) im_width=im_width)
anchor_corners = [anchors.get() for anchors in anchors_list]
return anchor_corners
anchor_corners_out = self.execute(graph_fn, []) anchor_corners_out = np.concatenate(self.execute(graph_fn, []), axis=0)
exp_anchor_corners = [[-48, -48, 80, 80], exp_anchor_corners = [[-48, -48, 80, 80],
[-48, -16, 80, 112], [-48, -16, 80, 112],
[-16, -48, 112, 80], [-16, -48, 112, 80],
...@@ -135,19 +166,22 @@ class MultiscaleGridAnchorGeneratorTest(test_case.TestCase): ...@@ -135,19 +166,22 @@ class MultiscaleGridAnchorGeneratorTest(test_case.TestCase):
scales_per_octave = 2 scales_per_octave = 2
im_height = 64 im_height = 64
im_width = 64 im_width = 64
feature_map_shape_list = [(1, 1), (1, 1)] feature_map_shape_list = [(1, 1)]
anchor_generator = mg.MultiscaleGridAnchorGenerator( anchor_generator = mg.MultiscaleGridAnchorGenerator(
min_level, max_level, anchor_scale, aspect_ratios, scales_per_octave) min_level, max_level, anchor_scale, aspect_ratios, scales_per_octave,
anchors = anchor_generator.generate(feature_map_shape_list, im_height, normalize_coordinates=False)
im_width) anchors_list = anchor_generator.generate(feature_map_shape_list,
anchor_corners = anchors.get() im_height=im_height,
return (anchor_corners,) im_width=im_width)
anchor_corners = [anchors.get() for anchors in anchors_list]
return anchor_corners
# There are 4 set of anchors in this configuration. The order is: # There are 4 set of anchors in this configuration. The order is:
# [[2**0.0 intermediate scale + 1.0 aspect], # [[2**0.0 intermediate scale + 1.0 aspect],
# [2**0.5 intermediate scale + 1.0 aspect]] # [2**0.5 intermediate scale + 1.0 aspect]]
exp_anchor_corners = [[-96., -96., 160., 160.], exp_anchor_corners = [[-96., -96., 160., 160.],
[-149.0193, -149.0193, 213.0193, 213.0193]] [-149.0193, -149.0193, 213.0193, 213.0193]]
anchor_corners_out = self.execute(graph_fn, []) anchor_corners_out = self.execute(graph_fn, [])
self.assertAllClose(anchor_corners_out, exp_anchor_corners) self.assertAllClose(anchor_corners_out, exp_anchor_corners)
...@@ -160,18 +194,21 @@ class MultiscaleGridAnchorGeneratorTest(test_case.TestCase): ...@@ -160,18 +194,21 @@ class MultiscaleGridAnchorGeneratorTest(test_case.TestCase):
scales_per_octave = 2 scales_per_octave = 2
im_height = 64 im_height = 64
im_width = 64 im_width = 64
feature_map_shape_list = [(1, 1), (1, 1), (1, 1), (1, 1)] feature_map_shape_list = [(1, 1)]
anchor_generator = mg.MultiscaleGridAnchorGenerator( anchor_generator = mg.MultiscaleGridAnchorGenerator(
min_level, max_level, anchor_scale, aspect_ratios, scales_per_octave) min_level, max_level, anchor_scale, aspect_ratios, scales_per_octave,
anchors = anchor_generator.generate(feature_map_shape_list, im_height, normalize_coordinates=False)
im_width) anchors_list = anchor_generator.generate(feature_map_shape_list,
anchor_corners = anchors.get() im_height=im_height,
im_width=im_width)
anchor_corners = [anchors.get() for anchors in anchors_list]
return anchor_corners return anchor_corners
# There are 4 set of anchors in this configuration. The order is: # There are 4 set of anchors in this configuration. The order is:
# [[2**0.0 intermediate scale + 1.0 aspect], # [[2**0.0 intermediate scale + 1.0 aspect],
# [2**0.5 intermediate scale + 1.0 aspect], # [2**0.5 intermediate scale + 1.0 aspect],
# [2**0.0 intermediate scale + 2.0 aspect], # [2**0.0 intermediate scale + 2.0 aspect],
# [2**0.5 intermediate scale + 2.0 aspect]] # [2**0.5 intermediate scale + 2.0 aspect]]
exp_anchor_corners = [[-96., -96., 160., 160.], exp_anchor_corners = [[-96., -96., 160., 160.],
[-149.0193, -149.0193, 213.0193, 213.0193], [-149.0193, -149.0193, 213.0193, 213.0193],
[-58.50967, -149.0193, 122.50967, 213.0193], [-58.50967, -149.0193, 122.50967, 213.0193],
...@@ -193,18 +230,22 @@ class MultiscaleGridAnchorGeneratorTest(test_case.TestCase): ...@@ -193,18 +230,22 @@ class MultiscaleGridAnchorGeneratorTest(test_case.TestCase):
feature_map_shape_list = [(feature_map1_height, feature_map1_width), feature_map_shape_list = [(feature_map1_height, feature_map1_width),
(feature_map2_height, feature_map2_width)] (feature_map2_height, feature_map2_width)]
anchor_generator = mg.MultiscaleGridAnchorGenerator( anchor_generator = mg.MultiscaleGridAnchorGenerator(
min_level, max_level, anchor_scale, aspect_ratios, scales_per_octave) min_level, max_level, anchor_scale, aspect_ratios, scales_per_octave,
anchors = anchor_generator.generate(feature_map_shape_list, im_height, normalize_coordinates=False)
im_width) anchors_list = anchor_generator.generate(feature_map_shape_list,
anchor_corners = anchors.get() im_height=im_height,
return (anchor_corners,) im_width=im_width)
anchor_corners = [anchors.get() for anchors in anchors_list]
return anchor_corners
anchor_corners_out = self.execute_cpu(graph_fn, [ anchor_corners_out = np.concatenate(
np.array(2, dtype=np.int32), self.execute_cpu(graph_fn, [
np.array(2, dtype=np.int32), np.array(2, dtype=np.int32),
np.array(1, dtype=np.int32), np.array(2, dtype=np.int32),
np.array(1, dtype=np.int32) np.array(1, dtype=np.int32),
]) np.array(1, dtype=np.int32)
]),
axis=0)
exp_anchor_corners = [[-48, -48, 80, 80], exp_anchor_corners = [[-48, -48, 80, 80],
[-48, -16, 80, 112], [-48, -16, 80, 112],
[-16, -48, 112, 80], [-16, -48, 112, 80],
......
# Tensorflow Object Detection API: Box Coder implementations.
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"])
# Apache 2.0
py_library(
name = "faster_rcnn_box_coder",
srcs = [
"faster_rcnn_box_coder.py",
],
deps = [
"//tensorflow/models/research/object_detection/core:box_coder",
"//tensorflow/models/research/object_detection/core:box_list",
],
)
py_test(
name = "faster_rcnn_box_coder_test",
srcs = [
"faster_rcnn_box_coder_test.py",
],
deps = [
":faster_rcnn_box_coder",
"//tensorflow",
"//tensorflow/models/research/object_detection/core:box_list",
],
)
py_library(
name = "keypoint_box_coder",
srcs = [
"keypoint_box_coder.py",
],
deps = [
"//tensorflow/models/research/object_detection/core:box_coder",
"//tensorflow/models/research/object_detection/core:box_list",
"//tensorflow/models/research/object_detection/core:standard_fields",
],
)
py_test(
name = "keypoint_box_coder_test",
srcs = [
"keypoint_box_coder_test.py",
],
deps = [
":keypoint_box_coder",
"//tensorflow",
"//tensorflow/models/research/object_detection/core:box_list",
"//tensorflow/models/research/object_detection/core:standard_fields",
],
)
py_library(
name = "mean_stddev_box_coder",
srcs = [
"mean_stddev_box_coder.py",
],
deps = [
"//tensorflow/models/research/object_detection/core:box_coder",
"//tensorflow/models/research/object_detection/core:box_list",
],
)
py_test(
name = "mean_stddev_box_coder_test",
srcs = [
"mean_stddev_box_coder_test.py",
],
deps = [
":mean_stddev_box_coder",
"//tensorflow",
"//tensorflow/models/research/object_detection/core:box_list",
],
)
py_library(
name = "square_box_coder",
srcs = [
"square_box_coder.py",
],
deps = [
"//tensorflow/models/research/object_detection/core:box_coder",
"//tensorflow/models/research/object_detection/core:box_list",
],
)
py_test(
name = "square_box_coder_test",
srcs = [
"square_box_coder_test.py",
],
deps = [
":square_box_coder",
"//tensorflow",
"//tensorflow/models/research/object_detection/core:box_list",
],
)
# Tensorflow Object Detection API: component builders.
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"])
# Apache 2.0
py_library(
name = "model_builder",
srcs = ["model_builder.py"],
deps = [
":anchor_generator_builder",
":box_coder_builder",
":box_predictor_builder",
":hyperparams_builder",
":image_resizer_builder",
":losses_builder",
":matcher_builder",
":post_processing_builder",
":region_similarity_calculator_builder",
"//tensorflow/models/research/object_detection/core:box_predictor",
"//tensorflow/models/research/object_detection/meta_architectures:faster_rcnn_meta_arch",
"//tensorflow/models/research/object_detection/meta_architectures:rfcn_meta_arch",
"//tensorflow/models/research/object_detection/meta_architectures:ssd_meta_arch",
"//tensorflow/models/research/object_detection/models:embedded_ssd_mobilenet_v1_feature_extractor",
"//tensorflow/models/research/object_detection/models:faster_rcnn_inception_resnet_v2_feature_extractor",
"//tensorflow/models/research/object_detection/models:faster_rcnn_inception_v2_feature_extractor",
"//tensorflow/models/research/object_detection/models:faster_rcnn_nas_feature_extractor",
"//tensorflow/models/research/object_detection/models:faster_rcnn_resnet_v1_feature_extractor",
"//tensorflow/models/research/object_detection/models:ssd_inception_v2_feature_extractor",
"//tensorflow/models/research/object_detection/models:ssd_inception_v3_feature_extractor",
"//tensorflow/models/research/object_detection/models:ssd_mobilenet_v1_feature_extractor",
"//tensorflow/models/research/object_detection/models:ssd_resnet_v1_fpn_feature_extractor",
"//tensorflow/models/research/object_detection/protos:model_py_pb2",
],
)
py_test(
name = "model_builder_test",
srcs = ["model_builder_test.py"],
deps = [
":model_builder",
"//tensorflow",
"//tensorflow/models/research/object_detection/meta_architectures:faster_rcnn_meta_arch",
"//tensorflow/models/research/object_detection/meta_architectures:ssd_meta_arch",
"//tensorflow/models/research/object_detection/models:embedded_ssd_mobilenet_v1_feature_extractor",
"//tensorflow/models/research/object_detection/models:faster_rcnn_inception_resnet_v2_feature_extractor",
"//tensorflow/models/research/object_detection/models:faster_rcnn_inception_v2_feature_extractor",
"//tensorflow/models/research/object_detection/models:faster_rcnn_nas_feature_extractor",
"//tensorflow/models/research/object_detection/models:faster_rcnn_resnet_v1_feature_extractor",
"//tensorflow/models/research/object_detection/models:ssd_inception_v2_feature_extractor",
"//tensorflow/models/research/object_detection/models:ssd_inception_v3_feature_extractor",
"//tensorflow/models/research/object_detection/models:ssd_mobilenet_v1_feature_extractor",
"//tensorflow/models/research/object_detection/models:ssd_resnet_v1_fpn_feature_extractor",
"//tensorflow/models/research/object_detection/protos:model_py_pb2",
],
)
py_library(
name = "matcher_builder",
srcs = ["matcher_builder.py"],
deps = [
"//tensorflow/models/research/object_detection/matchers:argmax_matcher",
"//tensorflow/models/research/object_detection/matchers:bipartite_matcher",
"//tensorflow/models/research/object_detection/protos:matcher_py_pb2",
],
)
py_test(
name = "matcher_builder_test",
srcs = ["matcher_builder_test.py"],
deps = [
":matcher_builder",
"//tensorflow/models/research/object_detection/matchers:argmax_matcher",
"//tensorflow/models/research/object_detection/matchers:bipartite_matcher",
"//tensorflow/models/research/object_detection/protos:matcher_py_pb2",
],
)
py_library(
name = "box_coder_builder",
srcs = ["box_coder_builder.py"],
deps = [
"//tensorflow/models/research/object_detection/box_coders:faster_rcnn_box_coder",
"//tensorflow/models/research/object_detection/box_coders:keypoint_box_coder",
"//tensorflow/models/research/object_detection/box_coders:mean_stddev_box_coder",
"//tensorflow/models/research/object_detection/box_coders:square_box_coder",
"//tensorflow/models/research/object_detection/protos:box_coder_py_pb2",
],
)
py_test(
name = "box_coder_builder_test",
srcs = ["box_coder_builder_test.py"],
deps = [
":box_coder_builder",
"//tensorflow",
"//tensorflow/models/research/object_detection/box_coders:faster_rcnn_box_coder",
"//tensorflow/models/research/object_detection/box_coders:keypoint_box_coder",
"//tensorflow/models/research/object_detection/box_coders:mean_stddev_box_coder",
"//tensorflow/models/research/object_detection/box_coders:square_box_coder",
"//tensorflow/models/research/object_detection/protos:box_coder_py_pb2",
],
)
py_library(
name = "anchor_generator_builder",
srcs = ["anchor_generator_builder.py"],
deps = [
"//tensorflow/models/research/object_detection/anchor_generators:grid_anchor_generator",
"//tensorflow/models/research/object_detection/anchor_generators:multiple_grid_anchor_generator",
"//tensorflow/models/research/object_detection/anchor_generators:multiscale_grid_anchor_generator",
"//tensorflow/models/research/object_detection/protos:anchor_generator_py_pb2",
],
)
py_test(
name = "anchor_generator_builder_test",
srcs = ["anchor_generator_builder_test.py"],
deps = [
":anchor_generator_builder",
"//tensorflow",
"//tensorflow/models/research/object_detection/anchor_generators:grid_anchor_generator",
"//tensorflow/models/research/object_detection/anchor_generators:multiple_grid_anchor_generator",
"//tensorflow/models/research/object_detection/anchor_generators:multiscale_grid_anchor_generator",
"//tensorflow/models/research/object_detection/protos:anchor_generator_py_pb2",
],
)
py_library(
name = "dataset_builder",
srcs = ["dataset_builder.py"],
deps = [
"//tensorflow",
"//tensorflow/models/research/object_detection/data_decoders:tf_example_decoder",
"//tensorflow/models/research/object_detection/protos:input_reader_py_pb2",
"//tensorflow/models/research/object_detection/utils:dataset_util",
],
)
py_test(
name = "dataset_builder_test",
srcs = [
"dataset_builder_test.py",
],
deps = [
":dataset_builder",
"//tensorflow",
"//tensorflow/models/research/object_detection/core:standard_fields",
"//tensorflow/models/research/object_detection/protos:input_reader_py_pb2",
"//tensorflow/models/research/object_detection/utils:dataset_util",
],
)
py_library(
name = "input_reader_builder",
srcs = ["input_reader_builder.py"],
deps = [
"//tensorflow",
"//tensorflow/models/research/object_detection/data_decoders:tf_example_decoder",
"//tensorflow/models/research/object_detection/protos:input_reader_py_pb2",
],
)
py_test(
name = "input_reader_builder_test",
srcs = [
"input_reader_builder_test.py",
],
deps = [
":input_reader_builder",
"//tensorflow",
"//tensorflow/models/research/object_detection/core:standard_fields",
"//tensorflow/models/research/object_detection/protos:input_reader_py_pb2",
],
)
py_library(
name = "losses_builder",
srcs = ["losses_builder.py"],
deps = [
"//tensorflow/models/research/object_detection/core:losses",
"//tensorflow/models/research/object_detection/protos:losses_py_pb2",
],
)
py_test(
name = "losses_builder_test",
srcs = ["losses_builder_test.py"],
deps = [
":losses_builder",
"//tensorflow/models/research/object_detection/core:losses",
"//tensorflow/models/research/object_detection/protos:losses_py_pb2",
],
)
py_library(
name = "optimizer_builder",
srcs = ["optimizer_builder.py"],
deps = [
"//tensorflow",
"//tensorflow/models/research/object_detection/utils:learning_schedules",
],
)
py_test(
name = "optimizer_builder_test",
srcs = ["optimizer_builder_test.py"],
deps = [
":optimizer_builder",
"//tensorflow",
"//tensorflow/models/research/object_detection/protos:optimizer_py_pb2",
],
)
py_library(
name = "post_processing_builder",
srcs = ["post_processing_builder.py"],
deps = [
"//tensorflow",
"//tensorflow/models/research/object_detection/core:post_processing",
"//tensorflow/models/research/object_detection/protos:post_processing_py_pb2",
],
)
py_test(
name = "post_processing_builder_test",
srcs = ["post_processing_builder_test.py"],
deps = [
":post_processing_builder",
"//tensorflow",
"//tensorflow/models/research/object_detection/protos:post_processing_py_pb2",
],
)
py_library(
name = "hyperparams_builder",
srcs = ["hyperparams_builder.py"],
deps = [
"//tensorflow/models/research/object_detection/protos:hyperparams_py_pb2",
],
)
py_test(
name = "hyperparams_builder_test",
srcs = ["hyperparams_builder_test.py"],
deps = [
":hyperparams_builder",
"//tensorflow",
"//tensorflow/models/research/object_detection/protos:hyperparams_py_pb2",
],
)
py_library(
name = "box_predictor_builder",
srcs = ["box_predictor_builder.py"],
deps = [
":hyperparams_builder",
"//tensorflow/models/research/object_detection/core:box_predictor",
"//tensorflow/models/research/object_detection/protos:box_predictor_py_pb2",
],
)
py_test(
name = "box_predictor_builder_test",
srcs = ["box_predictor_builder_test.py"],
deps = [
":box_predictor_builder",
":hyperparams_builder",
"//tensorflow",
"//tensorflow/models/research/object_detection/protos:box_predictor_py_pb2",
"//tensorflow/models/research/object_detection/protos:hyperparams_py_pb2",
],
)
py_library(
name = "region_similarity_calculator_builder",
srcs = ["region_similarity_calculator_builder.py"],
deps = [
"//tensorflow/models/research/object_detection/core:region_similarity_calculator",
"//tensorflow/models/research/object_detection/protos:region_similarity_calculator_py_pb2",
],
)
py_test(
name = "region_similarity_calculator_builder_test",
srcs = ["region_similarity_calculator_builder_test.py"],
deps = [
":region_similarity_calculator_builder",
"//tensorflow",
],
)
py_library(
name = "preprocessor_builder",
srcs = ["preprocessor_builder.py"],
deps = [
"//tensorflow",
"//tensorflow/models/research/object_detection/core:preprocessor",
"//tensorflow/models/research/object_detection/protos:preprocessor_py_pb2",
],
)
py_test(
name = "preprocessor_builder_test",
srcs = [
"preprocessor_builder_test.py",
],
deps = [
":preprocessor_builder",
"//tensorflow",
"//tensorflow/models/research/object_detection/core:preprocessor",
"//tensorflow/models/research/object_detection/protos:preprocessor_py_pb2",
],
)
py_library(
name = "image_resizer_builder",
srcs = ["image_resizer_builder.py"],
deps = [
"//tensorflow",
"//tensorflow/models/research/object_detection/core:preprocessor",
"//tensorflow/models/research/object_detection/protos:image_resizer_py_pb2",
],
)
py_test(
name = "image_resizer_builder_test",
srcs = ["image_resizer_builder_test.py"],
deps = [
":image_resizer_builder",
"//tensorflow",
"//tensorflow/models/research/object_detection/protos:image_resizer_py_pb2",
],
)
...@@ -87,7 +87,8 @@ def build(anchor_generator_config): ...@@ -87,7 +87,8 @@ def build(anchor_generator_config):
cfg.max_level, cfg.max_level,
cfg.anchor_scale, cfg.anchor_scale,
[float(aspect_ratio) for aspect_ratio in cfg.aspect_ratios], [float(aspect_ratio) for aspect_ratio in cfg.aspect_ratios],
cfg.scales_per_octave cfg.scales_per_octave,
cfg.normalize_coordinates
) )
else: else:
raise ValueError('Empty anchor generator.') raise ValueError('Empty anchor generator.')
...@@ -276,6 +276,24 @@ class AnchorGeneratorBuilderTest(tf.test.TestCase): ...@@ -276,6 +276,24 @@ class AnchorGeneratorBuilderTest(tf.test.TestCase):
self.assertAllClose(anchor_grid_info['info'][2], self.assertAllClose(anchor_grid_info['info'][2],
[4.0 * 2**level, 4.0 * 2**level]) [4.0 * 2**level, 4.0 * 2**level])
self.assertAllClose(anchor_grid_info['info'][3], [2**level, 2**level]) self.assertAllClose(anchor_grid_info['info'][3], [2**level, 2**level])
self.assertTrue(anchor_generator_object._normalize_coordinates)
def test_build_multiscale_anchor_generator_with_anchors_in_pixel_coordinates(
self):
anchor_generator_text_proto = """
multiscale_anchor_generator {
aspect_ratios: [1.0]
normalize_coordinates: false
}
"""
anchor_generator_proto = anchor_generator_pb2.AnchorGenerator()
text_format.Merge(anchor_generator_text_proto, anchor_generator_proto)
anchor_generator_object = anchor_generator_builder.build(
anchor_generator_proto)
self.assertTrue(isinstance(anchor_generator_object,
multiscale_grid_anchor_generator.
MultiscaleGridAnchorGenerator))
self.assertFalse(anchor_generator_object._normalize_coordinates)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -134,6 +134,10 @@ def _build_initializer(initializer): ...@@ -134,6 +134,10 @@ def _build_initializer(initializer):
return tf.truncated_normal_initializer( return tf.truncated_normal_initializer(
mean=initializer.truncated_normal_initializer.mean, mean=initializer.truncated_normal_initializer.mean,
stddev=initializer.truncated_normal_initializer.stddev) stddev=initializer.truncated_normal_initializer.stddev)
if initializer_oneof == 'random_normal_initializer':
return tf.random_normal_initializer(
mean=initializer.random_normal_initializer.mean,
stddev=initializer.random_normal_initializer.stddev)
if initializer_oneof == 'variance_scaling_initializer': if initializer_oneof == 'variance_scaling_initializer':
enum_descriptor = (hyperparams_pb2.VarianceScalingInitializer. enum_descriptor = (hyperparams_pb2.VarianceScalingInitializer.
DESCRIPTOR.enum_types_by_name['Mode']) DESCRIPTOR.enum_types_by_name['Mode'])
......
...@@ -28,7 +28,7 @@ slim = tf.contrib.slim ...@@ -28,7 +28,7 @@ slim = tf.contrib.slim
class HyperparamsBuilderTest(tf.test.TestCase): class HyperparamsBuilderTest(tf.test.TestCase):
# TODO: Make this a public api in slim arg_scope.py. # TODO(rathodv): Make this a public api in slim arg_scope.py.
def _get_scope_key(self, op): def _get_scope_key(self, op):
return getattr(op, '_key_op', str(op)) return getattr(op, '_key_op', str(op))
...@@ -444,6 +444,27 @@ class HyperparamsBuilderTest(tf.test.TestCase): ...@@ -444,6 +444,27 @@ class HyperparamsBuilderTest(tf.test.TestCase):
self._assert_variance_in_range(initializer, shape=[100, 40], self._assert_variance_in_range(initializer, shape=[100, 40],
variance=0.49, tol=1e-1) variance=0.49, tol=1e-1)
def test_variance_in_range_with_random_normal_initializer(self):
conv_hyperparams_text_proto = """
regularizer {
l2_regularizer {
}
}
initializer {
random_normal_initializer {
mean: 0.0
stddev: 0.8
}
}
"""
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
scope = hyperparams_builder.build(conv_hyperparams_proto, is_training=True)
conv_scope_arguments = scope.values()[0]
initializer = conv_scope_arguments['weights_initializer']
self._assert_variance_in_range(initializer, shape=[100, 40],
variance=0.64, tol=1e-1)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -72,8 +72,8 @@ def build(image_resizer_config): ...@@ -72,8 +72,8 @@ def build(image_resizer_config):
raise ValueError('image_resizer_config not of type ' raise ValueError('image_resizer_config not of type '
'image_resizer_pb2.ImageResizer.') 'image_resizer_pb2.ImageResizer.')
if image_resizer_config.WhichOneof( image_resizer_oneof = image_resizer_config.WhichOneof('image_resizer_oneof')
'image_resizer_oneof') == 'keep_aspect_ratio_resizer': if image_resizer_oneof == 'keep_aspect_ratio_resizer':
keep_aspect_ratio_config = image_resizer_config.keep_aspect_ratio_resizer keep_aspect_ratio_config = image_resizer_config.keep_aspect_ratio_resizer
if not (keep_aspect_ratio_config.min_dimension <= if not (keep_aspect_ratio_config.min_dimension <=
keep_aspect_ratio_config.max_dimension): keep_aspect_ratio_config.max_dimension):
...@@ -87,8 +87,7 @@ def build(image_resizer_config): ...@@ -87,8 +87,7 @@ def build(image_resizer_config):
pad_to_max_dimension=keep_aspect_ratio_config.pad_to_max_dimension) pad_to_max_dimension=keep_aspect_ratio_config.pad_to_max_dimension)
if not keep_aspect_ratio_config.convert_to_grayscale: if not keep_aspect_ratio_config.convert_to_grayscale:
return image_resizer_fn return image_resizer_fn
elif image_resizer_config.WhichOneof( elif image_resizer_oneof == 'fixed_shape_resizer':
'image_resizer_oneof') == 'fixed_shape_resizer':
fixed_shape_resizer_config = image_resizer_config.fixed_shape_resizer fixed_shape_resizer_config = image_resizer_config.fixed_shape_resizer
method = _tf_resize_method(fixed_shape_resizer_config.resize_method) method = _tf_resize_method(fixed_shape_resizer_config.resize_method)
image_resizer_fn = functools.partial( image_resizer_fn = functools.partial(
...@@ -99,7 +98,8 @@ def build(image_resizer_config): ...@@ -99,7 +98,8 @@ def build(image_resizer_config):
if not fixed_shape_resizer_config.convert_to_grayscale: if not fixed_shape_resizer_config.convert_to_grayscale:
return image_resizer_fn return image_resizer_fn
else: else:
raise ValueError('Invalid image resizer option.') raise ValueError(
'Invalid image resizer option: \'%s\'.' % image_resizer_oneof)
def grayscale_image_resizer(image): def grayscale_image_resizer(image):
[resized_image, resized_image_shape] = image_resizer_fn(image) [resized_image, resized_image_shape] = image_resizer_fn(image)
......
...@@ -150,7 +150,8 @@ def _build_localization_loss(loss_config): ...@@ -150,7 +150,8 @@ def _build_localization_loss(loss_config):
return losses.WeightedL2LocalizationLoss() return losses.WeightedL2LocalizationLoss()
if loss_type == 'weighted_smooth_l1': if loss_type == 'weighted_smooth_l1':
return losses.WeightedSmoothL1LocalizationLoss() return losses.WeightedSmoothL1LocalizationLoss(
loss_config.weighted_smooth_l1.delta)
if loss_type == 'weighted_iou': if loss_type == 'weighted_iou':
return losses.WeightedIOULocalizationLoss() return losses.WeightedIOULocalizationLoss()
......
...@@ -42,7 +42,7 @@ class LocalizationLossBuilderTest(tf.test.TestCase): ...@@ -42,7 +42,7 @@ class LocalizationLossBuilderTest(tf.test.TestCase):
self.assertTrue(isinstance(localization_loss, self.assertTrue(isinstance(localization_loss,
losses.WeightedL2LocalizationLoss)) losses.WeightedL2LocalizationLoss))
def test_build_weighted_smooth_l1_localization_loss(self): def test_build_weighted_smooth_l1_localization_loss_default_delta(self):
losses_text_proto = """ losses_text_proto = """
localization_loss { localization_loss {
weighted_smooth_l1 { weighted_smooth_l1 {
...@@ -58,6 +58,26 @@ class LocalizationLossBuilderTest(tf.test.TestCase): ...@@ -58,6 +58,26 @@ class LocalizationLossBuilderTest(tf.test.TestCase):
_, localization_loss, _, _, _ = losses_builder.build(losses_proto) _, localization_loss, _, _, _ = losses_builder.build(losses_proto)
self.assertTrue(isinstance(localization_loss, self.assertTrue(isinstance(localization_loss,
losses.WeightedSmoothL1LocalizationLoss)) losses.WeightedSmoothL1LocalizationLoss))
self.assertAlmostEqual(localization_loss._delta, 1.0)
def test_build_weighted_smooth_l1_localization_loss_non_default_delta(self):
losses_text_proto = """
localization_loss {
weighted_smooth_l1 {
delta: 0.1
}
}
classification_loss {
weighted_softmax {
}
}
"""
losses_proto = losses_pb2.Loss()
text_format.Merge(losses_text_proto, losses_proto)
_, localization_loss, _, _, _ = losses_builder.build(losses_proto)
self.assertTrue(isinstance(localization_loss,
losses.WeightedSmoothL1LocalizationLoss))
self.assertAlmostEqual(localization_loss._delta, 0.1)
def test_build_weighted_iou_localization_loss(self): def test_build_weighted_iou_localization_loss(self):
losses_text_proto = """ losses_text_proto = """
......
...@@ -153,6 +153,7 @@ def _build_ssd_model(ssd_config, is_training, add_summaries): ...@@ -153,6 +153,7 @@ def _build_ssd_model(ssd_config, is_training, add_summaries):
region_similarity_calculator = sim_calc.build( region_similarity_calculator = sim_calc.build(
ssd_config.similarity_calculator) ssd_config.similarity_calculator)
encode_background_as_zeros = ssd_config.encode_background_as_zeros encode_background_as_zeros = ssd_config.encode_background_as_zeros
negative_class_weight = ssd_config.negative_class_weight
ssd_box_predictor = box_predictor_builder.build(hyperparams_builder.build, ssd_box_predictor = box_predictor_builder.build(hyperparams_builder.build,
ssd_config.box_predictor, ssd_config.box_predictor,
is_training, num_classes) is_training, num_classes)
...@@ -165,6 +166,7 @@ def _build_ssd_model(ssd_config, is_training, add_summaries): ...@@ -165,6 +166,7 @@ def _build_ssd_model(ssd_config, is_training, add_summaries):
localization_weight, localization_weight,
hard_example_miner) = losses_builder.build(ssd_config.loss) hard_example_miner) = losses_builder.build(ssd_config.loss)
normalize_loss_by_num_matches = ssd_config.normalize_loss_by_num_matches normalize_loss_by_num_matches = ssd_config.normalize_loss_by_num_matches
normalize_loc_loss_by_codesize = ssd_config.normalize_loc_loss_by_codesize
return ssd_meta_arch.SSDMetaArch( return ssd_meta_arch.SSDMetaArch(
is_training, is_training,
...@@ -175,6 +177,7 @@ def _build_ssd_model(ssd_config, is_training, add_summaries): ...@@ -175,6 +177,7 @@ def _build_ssd_model(ssd_config, is_training, add_summaries):
matcher, matcher,
region_similarity_calculator, region_similarity_calculator,
encode_background_as_zeros, encode_background_as_zeros,
negative_class_weight,
image_resizer_fn, image_resizer_fn,
non_max_suppression_fn, non_max_suppression_fn,
score_conversion_fn, score_conversion_fn,
...@@ -184,7 +187,8 @@ def _build_ssd_model(ssd_config, is_training, add_summaries): ...@@ -184,7 +187,8 @@ def _build_ssd_model(ssd_config, is_training, add_summaries):
localization_weight, localization_weight,
normalize_loss_by_num_matches, normalize_loss_by_num_matches,
hard_example_miner, hard_example_miner,
add_summaries=add_summaries) add_summaries=add_summaries,
normalize_loc_loss_by_codesize=normalize_loc_loss_by_codesize)
def _build_faster_rcnn_feature_extractor( def _build_faster_rcnn_feature_extractor(
......
...@@ -259,13 +259,15 @@ class ModelBuilderTest(tf.test.TestCase): ...@@ -259,13 +259,15 @@ class ModelBuilderTest(tf.test.TestCase):
} }
} }
initializer { initializer {
truncated_normal_initializer { random_normal_initializer {
} }
} }
} }
num_layers_before_predictor: 1 num_layers_before_predictor: 1
} }
} }
normalize_loss_by_num_matches: true
normalize_loc_loss_by_codesize: true
loss { loss {
classification_loss { classification_loss {
weighted_sigmoid_focal { weighted_sigmoid_focal {
...@@ -275,6 +277,7 @@ class ModelBuilderTest(tf.test.TestCase): ...@@ -275,6 +277,7 @@ class ModelBuilderTest(tf.test.TestCase):
} }
localization_loss { localization_loss {
weighted_smooth_l1 { weighted_smooth_l1 {
delta: 0.1
} }
} }
classification_weight: 1.0 classification_weight: 1.0
...@@ -344,6 +347,7 @@ class ModelBuilderTest(tf.test.TestCase): ...@@ -344,6 +347,7 @@ class ModelBuilderTest(tf.test.TestCase):
} }
} }
} }
normalize_loc_loss_by_codesize: true
loss { loss {
classification_loss { classification_loss {
weighted_softmax { weighted_softmax {
...@@ -362,6 +366,7 @@ class ModelBuilderTest(tf.test.TestCase): ...@@ -362,6 +366,7 @@ class ModelBuilderTest(tf.test.TestCase):
self.assertIsInstance(model._feature_extractor, self.assertIsInstance(model._feature_extractor,
SSDMobileNetV1FeatureExtractor) SSDMobileNetV1FeatureExtractor)
self.assertTrue(model._feature_extractor._batch_norm_trainable) self.assertTrue(model._feature_extractor._batch_norm_trainable)
self.assertTrue(model._normalize_loc_loss_by_codesize)
def test_create_embedded_ssd_mobilenet_v1_model_from_config(self): def test_create_embedded_ssd_mobilenet_v1_model_from_config(self):
model_text_proto = """ model_text_proto = """
......
...@@ -187,7 +187,7 @@ class OptimizerBuilderTest(tf.test.TestCase): ...@@ -187,7 +187,7 @@ class OptimizerBuilderTest(tf.test.TestCase):
optimizer, _ = optimizer_builder.build(optimizer_proto) optimizer, _ = optimizer_builder.build(optimizer_proto)
self.assertTrue( self.assertTrue(
isinstance(optimizer, tf.contrib.opt.MovingAverageOptimizer)) isinstance(optimizer, tf.contrib.opt.MovingAverageOptimizer))
# TODO: Find a way to not depend on the private members. # TODO(rathodv): Find a way to not depend on the private members.
self.assertAlmostEqual(optimizer._ema._decay, 0.2) self.assertAlmostEqual(optimizer._ema._decay, 0.2)
def testBuildEmptyOptimizer(self): def testBuildEmptyOptimizer(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment