Unverified Commit 9cfd2d93 authored by Jonathan Huang's avatar Jonathan Huang Committed by GitHub
Browse files

Merge pull request #2620 from tombstone/update_builder

update builders.
parents 20d74292 0a49aee8
...@@ -72,6 +72,7 @@ py_library( ...@@ -72,6 +72,7 @@ py_library(
srcs = ["box_coder_builder.py"], srcs = ["box_coder_builder.py"],
deps = [ deps = [
"//tensorflow_models/object_detection/box_coders:faster_rcnn_box_coder", "//tensorflow_models/object_detection/box_coders:faster_rcnn_box_coder",
"//tensorflow_models/object_detection/box_coders:keypoint_box_coder",
"//tensorflow_models/object_detection/box_coders:mean_stddev_box_coder", "//tensorflow_models/object_detection/box_coders:mean_stddev_box_coder",
"//tensorflow_models/object_detection/box_coders:square_box_coder", "//tensorflow_models/object_detection/box_coders:square_box_coder",
"//tensorflow_models/object_detection/protos:box_coder_py_pb2", "//tensorflow_models/object_detection/protos:box_coder_py_pb2",
...@@ -85,6 +86,7 @@ py_test( ...@@ -85,6 +86,7 @@ py_test(
":box_coder_builder", ":box_coder_builder",
"//tensorflow", "//tensorflow",
"//tensorflow_models/object_detection/box_coders:faster_rcnn_box_coder", "//tensorflow_models/object_detection/box_coders:faster_rcnn_box_coder",
"//tensorflow_models/object_detection/box_coders:keypoint_box_coder",
"//tensorflow_models/object_detection/box_coders:mean_stddev_box_coder", "//tensorflow_models/object_detection/box_coders:mean_stddev_box_coder",
"//tensorflow_models/object_detection/box_coders:square_box_coder", "//tensorflow_models/object_detection/box_coders:square_box_coder",
"//tensorflow_models/object_detection/protos:box_coder_py_pb2", "//tensorflow_models/object_detection/protos:box_coder_py_pb2",
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
"""A function to build an object detection box coder from configuration.""" """A function to build an object detection box coder from configuration."""
from object_detection.box_coders import faster_rcnn_box_coder from object_detection.box_coders import faster_rcnn_box_coder
from object_detection.box_coders import keypoint_box_coder
from object_detection.box_coders import mean_stddev_box_coder from object_detection.box_coders import mean_stddev_box_coder
from object_detection.box_coders import square_box_coder from object_detection.box_coders import square_box_coder
from object_detection.protos import box_coder_pb2 from object_detection.protos import box_coder_pb2
...@@ -43,6 +44,15 @@ def build(box_coder_config): ...@@ -43,6 +44,15 @@ def build(box_coder_config):
box_coder_config.faster_rcnn_box_coder.height_scale, box_coder_config.faster_rcnn_box_coder.height_scale,
box_coder_config.faster_rcnn_box_coder.width_scale box_coder_config.faster_rcnn_box_coder.width_scale
]) ])
if box_coder_config.WhichOneof('box_coder_oneof') == 'keypoint_box_coder':
return keypoint_box_coder.KeypointBoxCoder(
box_coder_config.keypoint_box_coder.num_keypoints,
scale_factors=[
box_coder_config.keypoint_box_coder.y_scale,
box_coder_config.keypoint_box_coder.x_scale,
box_coder_config.keypoint_box_coder.height_scale,
box_coder_config.keypoint_box_coder.width_scale
])
if (box_coder_config.WhichOneof('box_coder_oneof') == if (box_coder_config.WhichOneof('box_coder_oneof') ==
'mean_stddev_box_coder'): 'mean_stddev_box_coder'):
return mean_stddev_box_coder.MeanStddevBoxCoder() return mean_stddev_box_coder.MeanStddevBoxCoder()
......
...@@ -19,6 +19,7 @@ import tensorflow as tf ...@@ -19,6 +19,7 @@ import tensorflow as tf
from google.protobuf import text_format from google.protobuf import text_format
from object_detection.box_coders import faster_rcnn_box_coder from object_detection.box_coders import faster_rcnn_box_coder
from object_detection.box_coders import keypoint_box_coder
from object_detection.box_coders import mean_stddev_box_coder from object_detection.box_coders import mean_stddev_box_coder
from object_detection.box_coders import square_box_coder from object_detection.box_coders import square_box_coder
from object_detection.builders import box_coder_builder from object_detection.builders import box_coder_builder
...@@ -35,8 +36,8 @@ class BoxCoderBuilderTest(tf.test.TestCase): ...@@ -35,8 +36,8 @@ class BoxCoderBuilderTest(tf.test.TestCase):
box_coder_proto = box_coder_pb2.BoxCoder() box_coder_proto = box_coder_pb2.BoxCoder()
text_format.Merge(box_coder_text_proto, box_coder_proto) text_format.Merge(box_coder_text_proto, box_coder_proto)
box_coder_object = box_coder_builder.build(box_coder_proto) box_coder_object = box_coder_builder.build(box_coder_proto)
self.assertTrue(isinstance(box_coder_object, self.assertIsInstance(box_coder_object,
faster_rcnn_box_coder.FasterRcnnBoxCoder)) faster_rcnn_box_coder.FasterRcnnBoxCoder)
self.assertEqual(box_coder_object._scale_factors, [10.0, 10.0, 5.0, 5.0]) self.assertEqual(box_coder_object._scale_factors, [10.0, 10.0, 5.0, 5.0])
def test_build_faster_rcnn_box_coder_with_non_default_parameters(self): def test_build_faster_rcnn_box_coder_with_non_default_parameters(self):
...@@ -51,8 +52,36 @@ class BoxCoderBuilderTest(tf.test.TestCase): ...@@ -51,8 +52,36 @@ class BoxCoderBuilderTest(tf.test.TestCase):
box_coder_proto = box_coder_pb2.BoxCoder() box_coder_proto = box_coder_pb2.BoxCoder()
text_format.Merge(box_coder_text_proto, box_coder_proto) text_format.Merge(box_coder_text_proto, box_coder_proto)
box_coder_object = box_coder_builder.build(box_coder_proto) box_coder_object = box_coder_builder.build(box_coder_proto)
self.assertTrue(isinstance(box_coder_object, self.assertIsInstance(box_coder_object,
faster_rcnn_box_coder.FasterRcnnBoxCoder)) faster_rcnn_box_coder.FasterRcnnBoxCoder)
self.assertEqual(box_coder_object._scale_factors, [6.0, 3.0, 7.0, 8.0])
def test_build_keypoint_box_coder_with_defaults(self):
  """An empty keypoint_box_coder proto must yield default scale factors."""
  box_coder_text_proto = """
    keypoint_box_coder {
    }
  """
  box_coder_proto = box_coder_pb2.BoxCoder()
  text_format.Merge(box_coder_text_proto, box_coder_proto)
  coder = box_coder_builder.build(box_coder_proto)
  self.assertIsInstance(coder, keypoint_box_coder.KeypointBoxCoder)
  self.assertEqual([10.0, 10.0, 5.0, 5.0], coder._scale_factors)
def test_build_keypoint_box_coder_with_non_default_parameters(self):
  """Explicit proto fields must be forwarded to the KeypointBoxCoder."""
  box_coder_text_proto = """
    keypoint_box_coder {
      num_keypoints: 6
      y_scale: 6.0
      x_scale: 3.0
      height_scale: 7.0
      width_scale: 8.0
    }
  """
  box_coder_proto = box_coder_pb2.BoxCoder()
  text_format.Merge(box_coder_text_proto, box_coder_proto)
  coder = box_coder_builder.build(box_coder_proto)
  self.assertIsInstance(coder, keypoint_box_coder.KeypointBoxCoder)
  self.assertEqual(6, coder._num_keypoints)
self.assertEqual(box_coder_object._scale_factors, [6.0, 3.0, 7.0, 8.0]) self.assertEqual(box_coder_object._scale_factors, [6.0, 3.0, 7.0, 8.0])
def test_build_mean_stddev_box_coder(self): def test_build_mean_stddev_box_coder(self):
......
...@@ -63,7 +63,9 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes): ...@@ -63,7 +63,9 @@ def build(argscope_fn, box_predictor_config, is_training, num_classes):
dropout_keep_prob=conv_box_predictor.dropout_keep_probability, dropout_keep_prob=conv_box_predictor.dropout_keep_probability,
kernel_size=conv_box_predictor.kernel_size, kernel_size=conv_box_predictor.kernel_size,
box_code_size=conv_box_predictor.box_code_size, box_code_size=conv_box_predictor.box_code_size,
apply_sigmoid_to_scores=conv_box_predictor.apply_sigmoid_to_scores) apply_sigmoid_to_scores=conv_box_predictor.apply_sigmoid_to_scores,
class_prediction_bias_init=conv_box_predictor.class_prediction_bias_init
)
return box_predictor_object return box_predictor_object
if box_predictor_oneof == 'mask_rcnn_box_predictor': if box_predictor_oneof == 'mask_rcnn_box_predictor':
......
...@@ -82,6 +82,7 @@ class ConvolutionalBoxPredictorBuilderTest(tf.test.TestCase): ...@@ -82,6 +82,7 @@ class ConvolutionalBoxPredictorBuilderTest(tf.test.TestCase):
kernel_size: 3 kernel_size: 3
box_code_size: 3 box_code_size: 3
apply_sigmoid_to_scores: true apply_sigmoid_to_scores: true
class_prediction_bias_init: 4.0
} }
""" """
conv_hyperparams_text_proto = """ conv_hyperparams_text_proto = """
...@@ -114,6 +115,7 @@ class ConvolutionalBoxPredictorBuilderTest(tf.test.TestCase): ...@@ -114,6 +115,7 @@ class ConvolutionalBoxPredictorBuilderTest(tf.test.TestCase):
self.assertFalse(box_predictor._use_dropout) self.assertFalse(box_predictor._use_dropout)
self.assertAlmostEqual(box_predictor._dropout_keep_prob, 0.4) self.assertAlmostEqual(box_predictor._dropout_keep_prob, 0.4)
self.assertTrue(box_predictor._apply_sigmoid_to_scores) self.assertTrue(box_predictor._apply_sigmoid_to_scores)
self.assertAlmostEqual(box_predictor._class_prediction_bias_init, 4.0)
self.assertEqual(box_predictor.num_classes, 10) self.assertEqual(box_predictor.num_classes, 10)
self.assertFalse(box_predictor._is_training) self.assertFalse(box_predictor._is_training)
......
...@@ -163,7 +163,6 @@ def _build_batch_norm_params(batch_norm, is_training): ...@@ -163,7 +163,6 @@ def _build_batch_norm_params(batch_norm, is_training):
'center': batch_norm.center, 'center': batch_norm.center,
'scale': batch_norm.scale, 'scale': batch_norm.scale,
'epsilon': batch_norm.epsilon, 'epsilon': batch_norm.epsilon,
'fused': True,
'is_training': is_training and batch_norm.train, 'is_training': is_training and batch_norm.train,
} }
return batch_norm_params return batch_norm_params
...@@ -20,7 +20,6 @@ import tensorflow as tf ...@@ -20,7 +20,6 @@ import tensorflow as tf
from google.protobuf import text_format from google.protobuf import text_format
# TODO: Rewrite third_party imports.
from object_detection.builders import hyperparams_builder from object_detection.builders import hyperparams_builder
from object_detection.protos import hyperparams_pb2 from object_detection.protos import hyperparams_pb2
......
...@@ -12,14 +12,43 @@ ...@@ -12,14 +12,43 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Builder function for image resizing operations.""" """Builder function for image resizing operations."""
import functools import functools
import tensorflow as tf
from object_detection.core import preprocessor from object_detection.core import preprocessor
from object_detection.protos import image_resizer_pb2 from object_detection.protos import image_resizer_pb2
def _tf_resize_method(resize_method):
  """Maps image resize method from enumeration type to TensorFlow.

  Args:
    resize_method: The resize_method attribute of keep_aspect_ratio_resizer or
      fixed_shape_resizer.

  Returns:
    method: The corresponding TensorFlow ResizeMethod.

  Raises:
    ValueError: if `resize_method` is of unknown type.
  """
  dict_method = {
      image_resizer_pb2.BILINEAR:
          tf.image.ResizeMethod.BILINEAR,
      image_resizer_pb2.NEAREST_NEIGHBOR:
          tf.image.ResizeMethod.NEAREST_NEIGHBOR,
      image_resizer_pb2.BICUBIC:
          tf.image.ResizeMethod.BICUBIC,
      image_resizer_pb2.AREA:
          tf.image.ResizeMethod.AREA
  }
  try:
    return dict_method[resize_method]
  except KeyError:
    # Name the offending enum value so a misconfigured proto is easy to trace.
    raise ValueError('Unknown resize_method: {}'.format(resize_method))
def build(image_resizer_config): def build(image_resizer_config):
"""Builds callable for image resizing operations. """Builds callable for image resizing operations.
...@@ -46,17 +75,22 @@ def build(image_resizer_config): ...@@ -46,17 +75,22 @@ def build(image_resizer_config):
if image_resizer_config.WhichOneof( if image_resizer_config.WhichOneof(
'image_resizer_oneof') == 'keep_aspect_ratio_resizer': 'image_resizer_oneof') == 'keep_aspect_ratio_resizer':
keep_aspect_ratio_config = image_resizer_config.keep_aspect_ratio_resizer keep_aspect_ratio_config = image_resizer_config.keep_aspect_ratio_resizer
if not (keep_aspect_ratio_config.min_dimension if not (keep_aspect_ratio_config.min_dimension <=
<= keep_aspect_ratio_config.max_dimension): keep_aspect_ratio_config.max_dimension):
raise ValueError('min_dimension > max_dimension') raise ValueError('min_dimension > max_dimension')
method = _tf_resize_method(keep_aspect_ratio_config.resize_method)
return functools.partial( return functools.partial(
preprocessor.resize_to_range, preprocessor.resize_to_range,
min_dimension=keep_aspect_ratio_config.min_dimension, min_dimension=keep_aspect_ratio_config.min_dimension,
max_dimension=keep_aspect_ratio_config.max_dimension) max_dimension=keep_aspect_ratio_config.max_dimension,
method=method)
if image_resizer_config.WhichOneof( if image_resizer_config.WhichOneof(
'image_resizer_oneof') == 'fixed_shape_resizer': 'image_resizer_oneof') == 'fixed_shape_resizer':
fixed_shape_resizer_config = image_resizer_config.fixed_shape_resizer fixed_shape_resizer_config = image_resizer_config.fixed_shape_resizer
return functools.partial(preprocessor.resize_image, method = _tf_resize_method(fixed_shape_resizer_config.resize_method)
new_height=fixed_shape_resizer_config.height, return functools.partial(
new_width=fixed_shape_resizer_config.width) preprocessor.resize_image,
new_height=fixed_shape_resizer_config.height,
new_width=fixed_shape_resizer_config.width,
method=method)
raise ValueError('Invalid image resizer option.') raise ValueError('Invalid image resizer option.')
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Tests for object_detection.builders.image_resizer_builder.""" """Tests for object_detection.builders.image_resizer_builder."""
import numpy as np
import tensorflow as tf import tensorflow as tf
from google.protobuf import text_format from google.protobuf import text_format
from object_detection.builders import image_resizer_builder from object_detection.builders import image_resizer_builder
...@@ -22,13 +22,13 @@ from object_detection.protos import image_resizer_pb2 ...@@ -22,13 +22,13 @@ from object_detection.protos import image_resizer_pb2
class ImageResizerBuilderTest(tf.test.TestCase): class ImageResizerBuilderTest(tf.test.TestCase):
def _shape_of_resized_random_image_given_text_proto( def _shape_of_resized_random_image_given_text_proto(self, input_shape,
self, input_shape, text_proto): text_proto):
image_resizer_config = image_resizer_pb2.ImageResizer() image_resizer_config = image_resizer_pb2.ImageResizer()
text_format.Merge(text_proto, image_resizer_config) text_format.Merge(text_proto, image_resizer_config)
image_resizer_fn = image_resizer_builder.build(image_resizer_config) image_resizer_fn = image_resizer_builder.build(image_resizer_config)
images = tf.to_float(tf.random_uniform( images = tf.to_float(
input_shape, minval=0, maxval=255, dtype=tf.int32)) tf.random_uniform(input_shape, minval=0, maxval=255, dtype=tf.int32))
resized_images = image_resizer_fn(images) resized_images = image_resizer_fn(images)
with self.test_session() as sess: with self.test_session() as sess:
return sess.run(resized_images).shape return sess.run(resized_images).shape
...@@ -64,7 +64,33 @@ class ImageResizerBuilderTest(tf.test.TestCase): ...@@ -64,7 +64,33 @@ class ImageResizerBuilderTest(tf.test.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
image_resizer_builder.build(invalid_input) image_resizer_builder.build(invalid_input)
def _resized_image_given_text_proto(self, image, text_proto):
  """Builds the resizer described by `text_proto` and applies it to `image`.

  Args:
    image: a numpy array fed into the placeholder; must be compatible with
      shape [1, None, None, 3] and dtype uint8.
    text_proto: an ImageResizer text-format proto string.

  Returns:
    The resized image as produced by `sess.run`.
  """
  image_resizer_config = image_resizer_pb2.ImageResizer()
  text_format.Merge(text_proto, image_resizer_config)
  image_resizer_fn = image_resizer_builder.build(image_resizer_config)
  # Batch of one variable-size 3-channel uint8 image.
  image_placeholder = tf.placeholder(tf.uint8, [1, None, None, 3])
  resized_image = image_resizer_fn(image_placeholder)
  with self.test_session() as sess:
    return sess.run(resized_image, feed_dict={image_placeholder: image})
def test_fixed_shape_resizer_nearest_neighbor_method(self):
  """Resizing 3x3 to 1x1 with NEAREST_NEIGHBOR keeps the top-left pixel."""
  image_resizer_text_proto = """
    fixed_shape_resizer {
      height: 1
      width: 1
      resize_method: NEAREST_NEIGHBOR
    }
  """
  # Batch of one 3x3 image whose three channels each hold
  # [[1, 2, 3], [4, 5, 6], [7, 8, 9]].
  single_channel = np.arange(1, 10).reshape((3, 3, 1))
  image = np.tile(single_channel, (1, 1, 3))[np.newaxis, ...]
  resized_image = self._resized_image_given_text_proto(
      image, image_resizer_text_proto)
  vals = np.unique(resized_image).tolist()
  self.assertEqual(1, len(vals))
  self.assertEqual(1, vals[0])
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -42,6 +42,7 @@ def build(input_reader_config): ...@@ -42,6 +42,7 @@ def build(input_reader_config):
Raises: Raises:
ValueError: On invalid input reader proto. ValueError: On invalid input reader proto.
ValueError: If no input paths are specified.
""" """
if not isinstance(input_reader_config, input_reader_pb2.InputReader): if not isinstance(input_reader_config, input_reader_pb2.InputReader):
raise ValueError('input_reader_config not of type ' raise ValueError('input_reader_config not of type '
...@@ -49,8 +50,11 @@ def build(input_reader_config): ...@@ -49,8 +50,11 @@ def build(input_reader_config):
if input_reader_config.WhichOneof('input_reader') == 'tf_record_input_reader': if input_reader_config.WhichOneof('input_reader') == 'tf_record_input_reader':
config = input_reader_config.tf_record_input_reader config = input_reader_config.tf_record_input_reader
if not config.input_path:
raise ValueError('At least one input path must be specified in '
'`input_reader_config`.')
_, string_tensor = parallel_reader.parallel_read( _, string_tensor = parallel_reader.parallel_read(
config.input_path, config.input_path[:], # Convert `RepeatedScalarContainer` to list.
reader_class=tf.TFRecordReader, reader_class=tf.TFRecordReader,
num_epochs=(input_reader_config.num_epochs num_epochs=(input_reader_config.num_epochs
if input_reader_config.num_epochs else None), if input_reader_config.num_epochs else None),
...@@ -60,6 +64,12 @@ def build(input_reader_config): ...@@ -60,6 +64,12 @@ def build(input_reader_config):
capacity=input_reader_config.queue_capacity, capacity=input_reader_config.queue_capacity,
min_after_dequeue=input_reader_config.min_after_dequeue) min_after_dequeue=input_reader_config.min_after_dequeue)
return tf_example_decoder.TfExampleDecoder().decode(string_tensor) label_map_proto_file = None
if input_reader_config.HasField('label_map_path'):
label_map_proto_file = input_reader_config.label_map_path
decoder = tf_example_decoder.TfExampleDecoder(
load_instance_masks=input_reader_config.load_instance_masks,
label_map_proto_file=label_map_proto_file)
return decoder.decode(string_tensor)
raise ValueError('Unsupported input_reader_config.') raise ValueError('Unsupported input_reader_config.')
...@@ -35,6 +35,7 @@ class InputReaderBuilderTest(tf.test.TestCase): ...@@ -35,6 +35,7 @@ class InputReaderBuilderTest(tf.test.TestCase):
writer = tf.python_io.TFRecordWriter(path) writer = tf.python_io.TFRecordWriter(path)
image_tensor = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8) image_tensor = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
flat_mask = (4 * 5) * [1.0]
with self.test_session(): with self.test_session():
encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).eval() encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).eval()
example = example_pb2.Example(features=feature_pb2.Features(feature={ example = example_pb2.Example(features=feature_pb2.Features(feature={
...@@ -42,6 +43,10 @@ class InputReaderBuilderTest(tf.test.TestCase): ...@@ -42,6 +43,10 @@ class InputReaderBuilderTest(tf.test.TestCase):
bytes_list=feature_pb2.BytesList(value=[encoded_jpeg])), bytes_list=feature_pb2.BytesList(value=[encoded_jpeg])),
'image/format': feature_pb2.Feature( 'image/format': feature_pb2.Feature(
bytes_list=feature_pb2.BytesList(value=['jpeg'.encode('utf-8')])), bytes_list=feature_pb2.BytesList(value=['jpeg'.encode('utf-8')])),
'image/height': feature_pb2.Feature(
int64_list=feature_pb2.Int64List(value=[4])),
'image/width': feature_pb2.Feature(
int64_list=feature_pb2.Int64List(value=[5])),
'image/object/bbox/xmin': feature_pb2.Feature( 'image/object/bbox/xmin': feature_pb2.Feature(
float_list=feature_pb2.FloatList(value=[0.0])), float_list=feature_pb2.FloatList(value=[0.0])),
'image/object/bbox/xmax': feature_pb2.Feature( 'image/object/bbox/xmax': feature_pb2.Feature(
...@@ -52,6 +57,8 @@ class InputReaderBuilderTest(tf.test.TestCase): ...@@ -52,6 +57,8 @@ class InputReaderBuilderTest(tf.test.TestCase):
float_list=feature_pb2.FloatList(value=[1.0])), float_list=feature_pb2.FloatList(value=[1.0])),
'image/object/class/label': feature_pb2.Feature( 'image/object/class/label': feature_pb2.Feature(
int64_list=feature_pb2.Int64List(value=[2])), int64_list=feature_pb2.Int64List(value=[2])),
'image/object/mask': feature_pb2.Feature(
float_list=feature_pb2.FloatList(value=flat_mask)),
})) }))
writer.write(example.SerializeToString()) writer.write(example.SerializeToString())
writer.close() writer.close()
...@@ -77,6 +84,8 @@ class InputReaderBuilderTest(tf.test.TestCase): ...@@ -77,6 +84,8 @@ class InputReaderBuilderTest(tf.test.TestCase):
sv.start_queue_runners(sess) sv.start_queue_runners(sess)
output_dict = sess.run(tensor_dict) output_dict = sess.run(tensor_dict)
self.assertTrue(fields.InputDataFields.groundtruth_instance_masks
not in output_dict)
self.assertEquals( self.assertEquals(
(4, 5, 3), output_dict[fields.InputDataFields.image].shape) (4, 5, 3), output_dict[fields.InputDataFields.image].shape)
self.assertEquals( self.assertEquals(
...@@ -87,6 +96,49 @@ class InputReaderBuilderTest(tf.test.TestCase): ...@@ -87,6 +96,49 @@ class InputReaderBuilderTest(tf.test.TestCase):
[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0],
output_dict[fields.InputDataFields.groundtruth_boxes][0]) output_dict[fields.InputDataFields.groundtruth_boxes][0])
def test_build_tf_record_input_reader_and_load_instance_masks(self):
  """Reads the test TFRecord back with load_instance_masks enabled."""
  tf_record_path = self.create_tf_record()
  # Double braces escape the literal proto braces for str.format below.
  input_reader_text_proto = """
    shuffle: false
    num_readers: 1
    load_instance_masks: true
    tf_record_input_reader {{
      input_path: '{0}'
    }}
  """.format(tf_record_path)
  input_reader_proto = input_reader_pb2.InputReader()
  text_format.Merge(input_reader_text_proto, input_reader_proto)
  tensor_dict = input_reader_builder.build(input_reader_proto)
  # Supervisor starts the queue runners needed by parallel_read.
  sv = tf.train.Supervisor(logdir=self.get_temp_dir())
  with sv.prepare_or_wait_for_session() as sess:
    sv.start_queue_runners(sess)
    output_dict = sess.run(tensor_dict)
  self.assertEquals(
      (4, 5, 3), output_dict[fields.InputDataFields.image].shape)
  self.assertEquals(
      [2], output_dict[fields.InputDataFields.groundtruth_classes])
  self.assertEquals(
      (1, 4), output_dict[fields.InputDataFields.groundtruth_boxes].shape)
  self.assertAllEqual(
      [0.0, 0.0, 1.0, 1.0],
      output_dict[fields.InputDataFields.groundtruth_boxes][0])
  # One instance mask covering the full 4x5 image, as written by
  # create_tf_record's flat_mask.
  self.assertAllEqual(
      (1, 4, 5),
      output_dict[fields.InputDataFields.groundtruth_instance_masks].shape)
def test_raises_error_with_no_input_paths(self):
  """build() must raise ValueError when no input reader/path is configured."""
  input_reader_text_proto = """
    shuffle: false
    num_readers: 1
    load_instance_masks: true
  """
  reader_config = input_reader_pb2.InputReader()
  text_format.Merge(input_reader_text_proto, reader_config)
  with self.assertRaises(ValueError):
    input_reader_builder.build(reader_config)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -34,6 +34,9 @@ def build(loss_config): ...@@ -34,6 +34,9 @@ def build(loss_config):
classification_weight: Classification loss weight. classification_weight: Classification loss weight.
localization_weight: Localization loss weight. localization_weight: Localization loss weight.
hard_example_miner: Hard example miner object. hard_example_miner: Hard example miner object.
Raises:
ValueError: If hard_example_miner is used with sigmoid_focal_loss.
""" """
classification_loss = _build_classification_loss( classification_loss = _build_classification_loss(
loss_config.classification_loss) loss_config.classification_loss)
...@@ -43,6 +46,10 @@ def build(loss_config): ...@@ -43,6 +46,10 @@ def build(loss_config):
localization_weight = loss_config.localization_weight localization_weight = loss_config.localization_weight
hard_example_miner = None hard_example_miner = None
if loss_config.HasField('hard_example_miner'): if loss_config.HasField('hard_example_miner'):
if (loss_config.classification_loss.WhichOneof('classification_loss') ==
'weighted_sigmoid_focal'):
raise ValueError('HardExampleMiner should not be used with sigmoid focal '
'loss')
hard_example_miner = build_hard_example_miner( hard_example_miner = build_hard_example_miner(
loss_config.hard_example_miner, loss_config.hard_example_miner,
classification_weight, classification_weight,
...@@ -91,6 +98,38 @@ def build_hard_example_miner(config, ...@@ -91,6 +98,38 @@ def build_hard_example_miner(config,
return hard_example_miner return hard_example_miner
def build_faster_rcnn_classification_loss(loss_config):
  """Builds a classification loss for Faster RCNN based on the loss config.

  Args:
    loss_config: A losses_pb2.ClassificationLoss object.

  Returns:
    Loss based on the config.

  Raises:
    ValueError: On invalid loss_config.
  """
  if not isinstance(loss_config, losses_pb2.ClassificationLoss):
    raise ValueError('loss_config not of type losses_pb2.ClassificationLoss.')

  loss_type = loss_config.WhichOneof('classification_loss')
  if loss_type == 'weighted_sigmoid':
    sigmoid_config = loss_config.weighted_sigmoid
    return losses.WeightedSigmoidClassificationLoss(
        anchorwise_output=sigmoid_config.anchorwise_output)
  elif loss_type == 'weighted_softmax':
    softmax_config = loss_config.weighted_softmax
    return losses.WeightedSoftmaxClassificationLoss(
        anchorwise_output=softmax_config.anchorwise_output)
  # Any other (or unset) oneof falls through to the Faster RCNN default:
  # softmax loss with anchor-wise outputs.
  return losses.WeightedSoftmaxClassificationLoss(anchorwise_output=True)
def _build_localization_loss(loss_config): def _build_localization_loss(loss_config):
"""Builds a localization loss based on the loss config. """Builds a localization loss based on the loss config.
...@@ -146,10 +185,21 @@ def _build_classification_loss(loss_config): ...@@ -146,10 +185,21 @@ def _build_classification_loss(loss_config):
return losses.WeightedSigmoidClassificationLoss( return losses.WeightedSigmoidClassificationLoss(
anchorwise_output=config.anchorwise_output) anchorwise_output=config.anchorwise_output)
if loss_type == 'weighted_sigmoid_focal':
config = loss_config.weighted_sigmoid_focal
alpha = None
if config.HasField('alpha'):
alpha = config.alpha
return losses.SigmoidFocalClassificationLoss(
anchorwise_output=config.anchorwise_output,
gamma=config.gamma,
alpha=alpha)
if loss_type == 'weighted_softmax': if loss_type == 'weighted_softmax':
config = loss_config.weighted_softmax config = loss_config.weighted_softmax
return losses.WeightedSoftmaxClassificationLoss( return losses.WeightedSoftmaxClassificationLoss(
anchorwise_output=config.anchorwise_output) anchorwise_output=config.anchorwise_output,
logit_scale=config.logit_scale)
if loss_type == 'bootstrapped_sigmoid': if loss_type == 'bootstrapped_sigmoid':
config = loss_config.bootstrapped_sigmoid config = loss_config.bootstrapped_sigmoid
......
...@@ -131,6 +131,46 @@ class ClassificationLossBuilderTest(tf.test.TestCase): ...@@ -131,6 +131,46 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
self.assertTrue(isinstance(classification_loss, self.assertTrue(isinstance(classification_loss,
losses.WeightedSigmoidClassificationLoss)) losses.WeightedSigmoidClassificationLoss))
def test_build_weighted_sigmoid_focal_classification_loss(self):
  """An empty weighted_sigmoid_focal proto must yield default alpha/gamma."""
  losses_text_proto = """
    classification_loss {
      weighted_sigmoid_focal {
      }
    }
    localization_loss {
      weighted_l2 {
      }
    }
  """
  losses_proto = losses_pb2.Loss()
  text_format.Merge(losses_text_proto, losses_proto)
  classification_loss, _, _, _, _ = losses_builder.build(losses_proto)
  # assertIsInstance matches the assertion style used elsewhere in this file
  # and reports the actual type on failure.
  self.assertIsInstance(classification_loss,
                        losses.SigmoidFocalClassificationLoss)
  # Alpha is unset by default; assertIsNone is the correct check for None
  # (assertAlmostEqual against None only works by accident).
  self.assertIsNone(classification_loss._alpha)
  self.assertAlmostEqual(classification_loss._gamma, 2.0)
def test_build_weighted_sigmoid_focal_loss_non_default(self):
  """Explicit alpha/gamma in the proto must reach the focal loss object."""
  losses_text_proto = """
    classification_loss {
      weighted_sigmoid_focal {
        alpha: 0.25
        gamma: 3.0
      }
    }
    localization_loss {
      weighted_l2 {
      }
    }
  """
  losses_proto = losses_pb2.Loss()
  text_format.Merge(losses_text_proto, losses_proto)
  classification_loss, _, _, _, _ = losses_builder.build(losses_proto)
  # assertIsInstance for consistency with the rest of this file and clearer
  # failure output than assertTrue(isinstance(...)).
  self.assertIsInstance(classification_loss,
                        losses.SigmoidFocalClassificationLoss)
  self.assertAlmostEqual(classification_loss._alpha, 0.25)
  self.assertAlmostEqual(classification_loss._gamma, 3.0)
def test_build_weighted_softmax_classification_loss(self): def test_build_weighted_softmax_classification_loss(self):
losses_text_proto = """ losses_text_proto = """
classification_loss { classification_loss {
...@@ -148,6 +188,24 @@ class ClassificationLossBuilderTest(tf.test.TestCase): ...@@ -148,6 +188,24 @@ class ClassificationLossBuilderTest(tf.test.TestCase):
self.assertTrue(isinstance(classification_loss, self.assertTrue(isinstance(classification_loss,
losses.WeightedSoftmaxClassificationLoss)) losses.WeightedSoftmaxClassificationLoss))
def test_build_weighted_softmax_classification_loss_with_logit_scale(self):
  """Softmax loss must parse and retain a non-default logit_scale."""
  losses_text_proto = """
    classification_loss {
      weighted_softmax {
        logit_scale: 2.0
      }
    }
    localization_loss {
      weighted_l2 {
      }
    }
  """
  losses_proto = losses_pb2.Loss()
  text_format.Merge(losses_text_proto, losses_proto)
  classification_loss, _, _, _, _ = losses_builder.build(losses_proto)
  # assertIsInstance for consistency with the rest of this file.
  self.assertIsInstance(classification_loss,
                        losses.WeightedSoftmaxClassificationLoss)
  # Verify the configured logit_scale actually reached the loss object;
  # previously this test only checked the loss type. Private-attribute
  # checks follow the pattern of _gamma/_alpha/_scale_factors assertions
  # elsewhere in this file.
  self.assertAlmostEqual(classification_loss._logit_scale, 2.0)
def test_build_bootstrapped_sigmoid_classification_loss(self): def test_build_bootstrapped_sigmoid_classification_loss(self):
losses_text_proto = """ losses_text_proto = """
classification_loss { classification_loss {
...@@ -318,6 +376,63 @@ class LossBuilderTest(tf.test.TestCase): ...@@ -318,6 +376,63 @@ class LossBuilderTest(tf.test.TestCase):
self.assertAlmostEqual(classification_weight, 0.8) self.assertAlmostEqual(classification_weight, 0.8)
self.assertAlmostEqual(localization_weight, 0.2) self.assertAlmostEqual(localization_weight, 0.2)
def test_raise_error_when_both_focal_loss_and_hard_example_miner(self):
  """Combining focal loss with a hard example miner must be rejected."""
  losses_text_proto = """
    localization_loss {
      weighted_l2 {
      }
    }
    classification_loss {
      weighted_sigmoid_focal {
      }
    }
    hard_example_miner {
    }
    classification_weight: 0.8
    localization_weight: 0.2
  """
  loss_proto = losses_pb2.Loss()
  text_format.Merge(losses_text_proto, loss_proto)
  with self.assertRaises(ValueError):
    losses_builder.build(loss_proto)
class FasterRcnnClassificationLossBuilderTest(tf.test.TestCase):
  """Tests for losses_builder.build_faster_rcnn_classification_loss."""

  def test_build_sigmoid_loss(self):
    losses_text_proto = """
      weighted_sigmoid {
      }
    """
    losses_proto = losses_pb2.ClassificationLoss()
    text_format.Merge(losses_text_proto, losses_proto)
    classification_loss = losses_builder.build_faster_rcnn_classification_loss(
        losses_proto)
    # assertIsInstance matches the assertion style used elsewhere in this
    # file and reports the actual type on failure.
    self.assertIsInstance(classification_loss,
                          losses.WeightedSigmoidClassificationLoss)

  def test_build_softmax_loss(self):
    losses_text_proto = """
      weighted_softmax {
      }
    """
    losses_proto = losses_pb2.ClassificationLoss()
    text_format.Merge(losses_text_proto, losses_proto)
    classification_loss = losses_builder.build_faster_rcnn_classification_loss(
        losses_proto)
    self.assertIsInstance(classification_loss,
                          losses.WeightedSoftmaxClassificationLoss)

  def test_build_softmax_loss_by_default(self):
    # An empty ClassificationLoss proto (no oneof set) must fall back to
    # the default softmax loss.
    losses_text_proto = """
    """
    losses_proto = losses_pb2.ClassificationLoss()
    text_format.Merge(losses_text_proto, losses_proto)
    classification_loss = losses_builder.build_faster_rcnn_classification_loss(
        losses_proto)
    self.assertIsInstance(classification_loss,
                          losses.WeightedSoftmaxClassificationLoss)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -18,8 +18,6 @@ ...@@ -18,8 +18,6 @@
import tensorflow as tf import tensorflow as tf
from object_detection.utils import learning_schedules from object_detection.utils import learning_schedules
slim = tf.contrib.slim
def build(optimizer_config, global_summaries): def build(optimizer_config, global_summaries):
"""Create optimizer based on config. """Create optimizer based on config.
...@@ -89,7 +87,7 @@ def _create_learning_rate(learning_rate_config, global_summaries): ...@@ -89,7 +87,7 @@ def _create_learning_rate(learning_rate_config, global_summaries):
config = learning_rate_config.exponential_decay_learning_rate config = learning_rate_config.exponential_decay_learning_rate
learning_rate = tf.train.exponential_decay( learning_rate = tf.train.exponential_decay(
config.initial_learning_rate, config.initial_learning_rate,
slim.get_or_create_global_step(), tf.train.get_or_create_global_step(),
config.decay_steps, config.decay_steps,
config.decay_factor, config.decay_factor,
staircase=config.staircase) staircase=config.staircase)
...@@ -102,11 +100,20 @@ def _create_learning_rate(learning_rate_config, global_summaries): ...@@ -102,11 +100,20 @@ def _create_learning_rate(learning_rate_config, global_summaries):
learning_rate_sequence = [config.initial_learning_rate] learning_rate_sequence = [config.initial_learning_rate]
learning_rate_sequence += [x.learning_rate for x in config.schedule] learning_rate_sequence += [x.learning_rate for x in config.schedule]
learning_rate = learning_schedules.manual_stepping( learning_rate = learning_schedules.manual_stepping(
slim.get_or_create_global_step(), learning_rate_step_boundaries, tf.train.get_or_create_global_step(), learning_rate_step_boundaries,
learning_rate_sequence) learning_rate_sequence)
if learning_rate_type == 'cosine_decay_learning_rate':
config = learning_rate_config.cosine_decay_learning_rate
learning_rate = learning_schedules.cosine_decay_with_warmup(
tf.train.get_or_create_global_step(),
config.learning_rate_base,
config.total_steps,
config.warmup_learning_rate,
config.warmup_steps)
if learning_rate is None: if learning_rate is None:
raise ValueError('Learning_rate %s not supported.' % learning_rate_type) raise ValueError('Learning_rate %s not supported.' % learning_rate_type)
global_summaries.add(tf.summary.scalar('Learning Rate', learning_rate)) global_summaries.add(tf.summary.scalar('Learning_Rate', learning_rate))
return learning_rate return learning_rate
...@@ -74,6 +74,22 @@ class LearningRateBuilderTest(tf.test.TestCase): ...@@ -74,6 +74,22 @@ class LearningRateBuilderTest(tf.test.TestCase):
learning_rate_proto, global_summaries) learning_rate_proto, global_summaries)
self.assertTrue(isinstance(learning_rate, tf.Tensor)) self.assertTrue(isinstance(learning_rate, tf.Tensor))
def testBuildCosineDecayLearningRate(self):
  """A cosine_decay_learning_rate config should build to a tf.Tensor."""
  learning_rate_text_proto = """
    cosine_decay_learning_rate {
      learning_rate_base: 0.002
      total_steps: 20000
      warmup_learning_rate: 0.0001
      warmup_steps: 1000
    }
  """
  lr_config = optimizer_pb2.LearningRate()
  text_format.Merge(learning_rate_text_proto, lr_config)
  summaries = set([])  # collects the 'Learning_Rate' scalar summary
  built_rate = optimizer_builder._create_learning_rate(lr_config, summaries)
  self.assertTrue(isinstance(built_rate, tf.Tensor))
def testRaiseErrorOnEmptyLearningRate(self): def testRaiseErrorOnEmptyLearningRate(self):
learning_rate_text_proto = """ learning_rate_text_proto = """
""" """
...@@ -180,7 +196,7 @@ class OptimizerBuilderTest(tf.test.TestCase): ...@@ -180,7 +196,7 @@ class OptimizerBuilderTest(tf.test.TestCase):
optimizer = optimizer_builder.build(optimizer_proto, global_summaries) optimizer = optimizer_builder.build(optimizer_proto, global_summaries)
self.assertTrue( self.assertTrue(
isinstance(optimizer, tf.contrib.opt.MovingAverageOptimizer)) isinstance(optimizer, tf.contrib.opt.MovingAverageOptimizer))
# TODO: Find a way to not depend on the private members. # TODO(rathodv): Find a way to not depend on the private members.
self.assertAlmostEqual(optimizer._ema._decay, 0.2) self.assertAlmostEqual(optimizer._ema._decay, 0.2)
def testBuildEmptyOptimizer(self): def testBuildEmptyOptimizer(self):
......
...@@ -72,7 +72,6 @@ def _get_dict_from_proto(config): ...@@ -72,7 +72,6 @@ def _get_dict_from_proto(config):
# with _get_dict_from_proto. # with _get_dict_from_proto.
PREPROCESSING_FUNCTION_MAP = { PREPROCESSING_FUNCTION_MAP = {
'normalize_image': preprocessor.normalize_image, 'normalize_image': preprocessor.normalize_image,
'random_horizontal_flip': preprocessor.random_horizontal_flip,
'random_pixel_value_scale': preprocessor.random_pixel_value_scale, 'random_pixel_value_scale': preprocessor.random_pixel_value_scale,
'random_image_scale': preprocessor.random_image_scale, 'random_image_scale': preprocessor.random_image_scale,
'random_rgb_to_gray': preprocessor.random_rgb_to_gray, 'random_rgb_to_gray': preprocessor.random_rgb_to_gray,
...@@ -123,6 +122,25 @@ def build(preprocessor_step_config): ...@@ -123,6 +122,25 @@ def build(preprocessor_step_config):
function_args = _get_dict_from_proto(step_config) function_args = _get_dict_from_proto(step_config)
return (preprocessing_function, function_args) return (preprocessing_function, function_args)
if step_type == 'random_horizontal_flip':
config = preprocessor_step_config.random_horizontal_flip
return (preprocessor.random_horizontal_flip,
{
'keypoint_flip_permutation': tuple(
config.keypoint_flip_permutation),
})
if step_type == 'random_vertical_flip':
config = preprocessor_step_config.random_vertical_flip
return (preprocessor.random_vertical_flip,
{
'keypoint_flip_permutation': tuple(
config.keypoint_flip_permutation),
})
if step_type == 'random_rotation90':
return (preprocessor.random_rotation90, {})
if step_type == 'random_crop_image': if step_type == 'random_crop_image':
config = preprocessor_step_config.random_crop_image config = preprocessor_step_config.random_crop_image
return (preprocessor.random_crop_image, return (preprocessor.random_crop_image,
...@@ -274,4 +292,32 @@ def build(preprocessor_step_config): ...@@ -274,4 +292,32 @@ def build(preprocessor_step_config):
}) })
return (preprocessor.ssd_random_crop_fixed_aspect_ratio, {}) return (preprocessor.ssd_random_crop_fixed_aspect_ratio, {})
if step_type == 'ssd_random_crop_pad_fixed_aspect_ratio':
config = preprocessor_step_config.ssd_random_crop_pad_fixed_aspect_ratio
if config.operations:
min_object_covered = [op.min_object_covered for op in config.operations]
aspect_ratio_range = [(op.min_aspect_ratio, op.max_aspect_ratio)
for op in config.operations]
area_range = [(op.min_area, op.max_area) for op in config.operations]
overlap_thresh = [op.overlap_thresh for op in config.operations]
random_coef = [op.random_coef for op in config.operations]
min_padded_size_ratio = [
(op.min_padded_size_ratio[0], op.min_padded_size_ratio[1])
for op in config.operations]
max_padded_size_ratio = [
(op.max_padded_size_ratio[0], op.max_padded_size_ratio[1])
for op in config.operations]
return (preprocessor.ssd_random_crop_pad_fixed_aspect_ratio,
{
'min_object_covered': min_object_covered,
'aspect_ratio': config.aspect_ratio,
'aspect_ratio_range': aspect_ratio_range,
'area_range': area_range,
'overlap_thresh': overlap_thresh,
'random_coef': random_coef,
'min_padded_size_ratio': min_padded_size_ratio,
'max_padded_size_ratio': max_padded_size_ratio,
})
return (preprocessor.ssd_random_crop_pad_fixed_aspect_ratio, {})
raise ValueError('Unknown preprocessing step.') raise ValueError('Unknown preprocessing step.')
...@@ -59,12 +59,45 @@ class PreprocessorBuilderTest(tf.test.TestCase): ...@@ -59,12 +59,45 @@ class PreprocessorBuilderTest(tf.test.TestCase):
def test_build_random_horizontal_flip(self):
  """Keypoint flip indices should surface as a tuple-valued kwarg."""
  preprocessor_text_proto = """
    random_horizontal_flip {
      keypoint_flip_permutation: 1
      keypoint_flip_permutation: 0
      keypoint_flip_permutation: 2
      keypoint_flip_permutation: 3
      keypoint_flip_permutation: 5
      keypoint_flip_permutation: 4
    }
  """
  step_config = preprocessor_pb2.PreprocessingStep()
  text_format.Merge(preprocessor_text_proto, step_config)
  built_fn, built_args = preprocessor_builder.build(step_config)
  self.assertEqual(built_fn, preprocessor.random_horizontal_flip)
  # Repeated proto ints are converted to a tuple by the builder.
  self.assertEqual(built_args,
                   {'keypoint_flip_permutation': (1, 0, 2, 3, 5, 4)})
def test_build_random_vertical_flip(self):
  """Vertical flip mirrors the horizontal-flip keypoint handling."""
  preprocessor_text_proto = """
    random_vertical_flip {
      keypoint_flip_permutation: 1
      keypoint_flip_permutation: 0
      keypoint_flip_permutation: 2
      keypoint_flip_permutation: 3
      keypoint_flip_permutation: 5
      keypoint_flip_permutation: 4
    }
  """
  step_config = preprocessor_pb2.PreprocessingStep()
  text_format.Merge(preprocessor_text_proto, step_config)
  built_fn, built_args = preprocessor_builder.build(step_config)
  self.assertEqual(built_fn, preprocessor.random_vertical_flip)
  # Repeated proto ints are converted to a tuple by the builder.
  self.assertEqual(built_args,
                   {'keypoint_flip_permutation': (1, 0, 2, 3, 5, 4)})
def test_build_random_rotation90(self):
  """random_rotation90 takes no configuration, so args must be empty."""
  preprocessor_text_proto = """
    random_rotation90 {}
  """
  step_config = preprocessor_pb2.PreprocessingStep()
  text_format.Merge(preprocessor_text_proto, step_config)
  built_fn, built_args = preprocessor_builder.build(step_config)
  self.assertEqual(built_fn, preprocessor.random_rotation90)
  self.assertEqual(built_args, {})
def test_build_random_pixel_value_scale(self): def test_build_random_pixel_value_scale(self):
...@@ -382,7 +415,7 @@ class PreprocessorBuilderTest(tf.test.TestCase): ...@@ -382,7 +415,7 @@ class PreprocessorBuilderTest(tf.test.TestCase):
max_area: 1.0 max_area: 1.0
overlap_thresh: 0.0 overlap_thresh: 0.0
random_coef: 0.375 random_coef: 0.375
min_padded_size_ratio: [0.0, 0.0] min_padded_size_ratio: [1.0, 1.0]
max_padded_size_ratio: [2.0, 2.0] max_padded_size_ratio: [2.0, 2.0]
pad_color_r: 0.5 pad_color_r: 0.5
pad_color_g: 0.5 pad_color_g: 0.5
...@@ -396,7 +429,7 @@ class PreprocessorBuilderTest(tf.test.TestCase): ...@@ -396,7 +429,7 @@ class PreprocessorBuilderTest(tf.test.TestCase):
max_area: 1.0 max_area: 1.0
overlap_thresh: 0.25 overlap_thresh: 0.25
random_coef: 0.375 random_coef: 0.375
min_padded_size_ratio: [0.0, 0.0] min_padded_size_ratio: [1.0, 1.0]
max_padded_size_ratio: [2.0, 2.0] max_padded_size_ratio: [2.0, 2.0]
pad_color_r: 0.5 pad_color_r: 0.5
pad_color_g: 0.5 pad_color_g: 0.5
...@@ -413,7 +446,7 @@ class PreprocessorBuilderTest(tf.test.TestCase): ...@@ -413,7 +446,7 @@ class PreprocessorBuilderTest(tf.test.TestCase):
'area_range': [(0.5, 1.0), (0.5, 1.0)], 'area_range': [(0.5, 1.0), (0.5, 1.0)],
'overlap_thresh': [0.0, 0.25], 'overlap_thresh': [0.0, 0.25],
'random_coef': [0.375, 0.375], 'random_coef': [0.375, 0.375],
'min_padded_size_ratio': [(0.0, 0.0), (0.0, 0.0)], 'min_padded_size_ratio': [(1.0, 1.0), (1.0, 1.0)],
'max_padded_size_ratio': [(2.0, 2.0), (2.0, 2.0)], 'max_padded_size_ratio': [(2.0, 2.0), (2.0, 2.0)],
'pad_color': [(0.5, 0.5, 0.5), (0.5, 0.5, 0.5)]}) 'pad_color': [(0.5, 0.5, 0.5), (0.5, 0.5, 0.5)]})
...@@ -447,6 +480,48 @@ class PreprocessorBuilderTest(tf.test.TestCase): ...@@ -447,6 +480,48 @@ class PreprocessorBuilderTest(tf.test.TestCase):
'overlap_thresh': [0.0, 0.25], 'overlap_thresh': [0.0, 0.25],
'random_coef': [0.375, 0.375]}) 'random_coef': [0.375, 0.375]})
def test_build_ssd_random_crop_pad_fixed_aspect_ratio(self):
  """Per-operation fields become parallel lists; aspect_ratio stays scalar."""
  preprocessor_text_proto = """
    ssd_random_crop_pad_fixed_aspect_ratio {
      operations {
        min_object_covered: 0.0
        min_aspect_ratio: 0.875
        max_aspect_ratio: 1.125
        min_area: 0.5
        max_area: 1.0
        overlap_thresh: 0.0
        random_coef: 0.375
        min_padded_size_ratio: [1.0, 1.0]
        max_padded_size_ratio: [2.0, 2.0]
      }
      operations {
        min_object_covered: 0.25
        min_aspect_ratio: 0.75
        max_aspect_ratio: 1.5
        min_area: 0.5
        max_area: 1.0
        overlap_thresh: 0.25
        random_coef: 0.375
        min_padded_size_ratio: [1.0, 1.0]
        max_padded_size_ratio: [2.0, 2.0]
      }
      aspect_ratio: 0.875
    }
  """
  step_config = preprocessor_pb2.PreprocessingStep()
  text_format.Merge(preprocessor_text_proto, step_config)
  built_fn, built_args = preprocessor_builder.build(step_config)
  self.assertEqual(built_fn,
                   preprocessor.ssd_random_crop_pad_fixed_aspect_ratio)
  # One list entry per `operations` message, in declaration order; the
  # (min, max) proto pairs collapse into tuples.
  expected_args = {
      'min_object_covered': [0.0, 0.25],
      'aspect_ratio': 0.875,
      'aspect_ratio_range': [(0.875, 1.125), (0.75, 1.5)],
      'area_range': [(0.5, 1.0), (0.5, 1.0)],
      'overlap_thresh': [0.0, 0.25],
      'random_coef': [0.375, 0.375],
      'min_padded_size_ratio': [(1.0, 1.0), (1.0, 1.0)],
      'max_padded_size_ratio': [(2.0, 2.0), (2.0, 2.0)],
  }
  self.assertEqual(built_args, expected_args)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment