Commit a4944a57 authored by derekjchow, committed by Sergio Guadarrama

Add Tensorflow Object Detection API. (#1561)

For details see our paper:
"Speed/accuracy trade-offs for modern convolutional object detectors."
Huang J, Rathod V, Sun C, Zhu M, Korattikara A, Fathi A, Fischer I,
Wojna Z, Song Y, Guadarrama S, Murphy K, CVPR 2017
https://arxiv.org/abs/1611.10012
parent 60c3ed2e
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions to build DetectionModel training optimizers."""
import tensorflow as tf
from object_detection.utils import learning_schedules
slim = tf.contrib.slim
def build(optimizer_config, global_summaries):
"""Create optimizer based on config.
Args:
optimizer_config: An Optimizer proto message.
global_summaries: A set to attach learning rate summary to.
Returns:
An optimizer.
Raises:
ValueError: when using an unsupported optimizer type.
"""
optimizer_type = optimizer_config.WhichOneof('optimizer')
optimizer = None
if optimizer_type == 'rms_prop_optimizer':
config = optimizer_config.rms_prop_optimizer
optimizer = tf.train.RMSPropOptimizer(
_create_learning_rate(config.learning_rate, global_summaries),
decay=config.decay,
momentum=config.momentum_optimizer_value,
epsilon=config.epsilon)
if optimizer_type == 'momentum_optimizer':
config = optimizer_config.momentum_optimizer
optimizer = tf.train.MomentumOptimizer(
_create_learning_rate(config.learning_rate, global_summaries),
momentum=config.momentum_optimizer_value)
if optimizer_type == 'adam_optimizer':
config = optimizer_config.adam_optimizer
optimizer = tf.train.AdamOptimizer(
_create_learning_rate(config.learning_rate, global_summaries))
if optimizer is None:
raise ValueError('Optimizer %s not supported.' % optimizer_type)
if optimizer_config.use_moving_average:
optimizer = tf.contrib.opt.MovingAverageOptimizer(
optimizer, average_decay=optimizer_config.moving_average_decay)
return optimizer
def _create_learning_rate(learning_rate_config, global_summaries):
"""Create optimizer learning rate based on config.
Args:
learning_rate_config: A LearningRate proto message.
global_summaries: A set to attach learning rate summary to.
Returns:
A learning rate.
Raises:
ValueError: when using an unsupported learning rate type or an empty schedule.
"""
learning_rate = None
learning_rate_type = learning_rate_config.WhichOneof('learning_rate')
if learning_rate_type == 'constant_learning_rate':
config = learning_rate_config.constant_learning_rate
learning_rate = config.learning_rate
if learning_rate_type == 'exponential_decay_learning_rate':
config = learning_rate_config.exponential_decay_learning_rate
learning_rate = tf.train.exponential_decay(
config.initial_learning_rate,
slim.get_or_create_global_step(),
config.decay_steps,
config.decay_factor,
staircase=config.staircase)
if learning_rate_type == 'manual_step_learning_rate':
config = learning_rate_config.manual_step_learning_rate
if not config.schedule:
raise ValueError('Empty learning rate schedule.')
learning_rate_step_boundaries = [x.step for x in config.schedule]
learning_rate_sequence = [config.initial_learning_rate]
learning_rate_sequence += [x.learning_rate for x in config.schedule]
learning_rate = learning_schedules.manual_stepping(
slim.get_or_create_global_step(), learning_rate_step_boundaries,
learning_rate_sequence)
if learning_rate is None:
raise ValueError('Learning_rate %s not supported.' % learning_rate_type)
global_summaries.add(tf.summary.scalar('Learning Rate', learning_rate))
return learning_rate
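
if __name__ == '__main__':
  # Usage sketch: one plausible way to drive `build` from a text-format
  # Optimizer proto; the config literal below is an arbitrary example.
  from google.protobuf import text_format
  from object_detection.protos import optimizer_pb2

  example_proto = optimizer_pb2.Optimizer()
  text_format.Merge("""
      momentum_optimizer {
        learning_rate { constant_learning_rate { learning_rate: 0.001 } }
        momentum_optimizer_value: 0.9
      }
      use_moving_average: false
      """, example_proto)
  example_optimizer = build(example_proto, global_summaries=set())
  print(example_optimizer)  # a tf.train.MomentumOptimizer instance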
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for optimizer_builder."""
import tensorflow as tf
from google.protobuf import text_format
from object_detection.builders import optimizer_builder
from object_detection.protos import optimizer_pb2
class LearningRateBuilderTest(tf.test.TestCase):
def testBuildConstantLearningRate(self):
learning_rate_text_proto = """
constant_learning_rate {
learning_rate: 0.004
}
"""
global_summaries = set([])
learning_rate_proto = optimizer_pb2.LearningRate()
text_format.Merge(learning_rate_text_proto, learning_rate_proto)
learning_rate = optimizer_builder._create_learning_rate(
learning_rate_proto, global_summaries)
self.assertAlmostEqual(learning_rate, 0.004)
def testBuildExponentialDecayLearningRate(self):
learning_rate_text_proto = """
exponential_decay_learning_rate {
initial_learning_rate: 0.004
decay_steps: 99999
decay_factor: 0.85
staircase: false
}
"""
global_summaries = set([])
learning_rate_proto = optimizer_pb2.LearningRate()
text_format.Merge(learning_rate_text_proto, learning_rate_proto)
learning_rate = optimizer_builder._create_learning_rate(
learning_rate_proto, global_summaries)
self.assertTrue(isinstance(learning_rate, tf.Tensor))
def testBuildManualStepLearningRate(self):
learning_rate_text_proto = """
manual_step_learning_rate {
schedule {
step: 0
learning_rate: 0.006
}
schedule {
step: 90000
learning_rate: 0.00006
}
}
"""
global_summaries = set([])
learning_rate_proto = optimizer_pb2.LearningRate()
text_format.Merge(learning_rate_text_proto, learning_rate_proto)
learning_rate = optimizer_builder._create_learning_rate(
learning_rate_proto, global_summaries)
self.assertTrue(isinstance(learning_rate, tf.Tensor))
def testRaiseErrorOnEmptyLearningRate(self):
learning_rate_text_proto = """
"""
global_summaries = set([])
learning_rate_proto = optimizer_pb2.LearningRate()
text_format.Merge(learning_rate_text_proto, learning_rate_proto)
with self.assertRaises(ValueError):
optimizer_builder._create_learning_rate(
learning_rate_proto, global_summaries)
class OptimizerBuilderTest(tf.test.TestCase):
def testBuildRMSPropOptimizer(self):
optimizer_text_proto = """
rms_prop_optimizer: {
learning_rate: {
exponential_decay_learning_rate {
initial_learning_rate: 0.004
decay_steps: 800720
decay_factor: 0.95
}
}
momentum_optimizer_value: 0.9
decay: 0.9
epsilon: 1.0
}
use_moving_average: false
"""
global_summaries = set([])
optimizer_proto = optimizer_pb2.Optimizer()
text_format.Merge(optimizer_text_proto, optimizer_proto)
optimizer = optimizer_builder.build(optimizer_proto, global_summaries)
self.assertTrue(isinstance(optimizer, tf.train.RMSPropOptimizer))
def testBuildMomentumOptimizer(self):
optimizer_text_proto = """
momentum_optimizer: {
learning_rate: {
constant_learning_rate {
learning_rate: 0.001
}
}
momentum_optimizer_value: 0.99
}
use_moving_average: false
"""
global_summaries = set([])
optimizer_proto = optimizer_pb2.Optimizer()
text_format.Merge(optimizer_text_proto, optimizer_proto)
optimizer = optimizer_builder.build(optimizer_proto, global_summaries)
self.assertTrue(isinstance(optimizer, tf.train.MomentumOptimizer))
def testBuildAdamOptimizer(self):
optimizer_text_proto = """
adam_optimizer: {
learning_rate: {
constant_learning_rate {
learning_rate: 0.002
}
}
}
use_moving_average: false
"""
global_summaries = set([])
optimizer_proto = optimizer_pb2.Optimizer()
text_format.Merge(optimizer_text_proto, optimizer_proto)
optimizer = optimizer_builder.build(optimizer_proto, global_summaries)
self.assertTrue(isinstance(optimizer, tf.train.AdamOptimizer))
def testBuildMovingAverageOptimizer(self):
optimizer_text_proto = """
adam_optimizer: {
learning_rate: {
constant_learning_rate {
learning_rate: 0.002
}
}
}
use_moving_average: True
"""
global_summaries = set([])
optimizer_proto = optimizer_pb2.Optimizer()
text_format.Merge(optimizer_text_proto, optimizer_proto)
optimizer = optimizer_builder.build(optimizer_proto, global_summaries)
self.assertTrue(
isinstance(optimizer, tf.contrib.opt.MovingAverageOptimizer))
def testBuildMovingAverageOptimizerWithNonDefaultDecay(self):
optimizer_text_proto = """
adam_optimizer: {
learning_rate: {
constant_learning_rate {
learning_rate: 0.002
}
}
}
use_moving_average: True
moving_average_decay: 0.2
"""
global_summaries = set([])
optimizer_proto = optimizer_pb2.Optimizer()
text_format.Merge(optimizer_text_proto, optimizer_proto)
optimizer = optimizer_builder.build(optimizer_proto, global_summaries)
self.assertTrue(
isinstance(optimizer, tf.contrib.opt.MovingAverageOptimizer))
# TODO: Find a way to not depend on the private members.
self.assertAlmostEqual(optimizer._ema._decay, 0.2)
def testBuildEmptyOptimizer(self):
optimizer_text_proto = """
"""
global_summaries = set([])
optimizer_proto = optimizer_pb2.Optimizer()
text_format.Merge(optimizer_text_proto, optimizer_proto)
with self.assertRaises(ValueError):
optimizer_builder.build(optimizer_proto, global_summaries)
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Builder function for post processing operations."""
import functools
import tensorflow as tf
from object_detection.core import post_processing
from object_detection.protos import post_processing_pb2
def build(post_processing_config):
"""Builds callables for post-processing operations.
Builds callables for non-max suppression and score conversion based on the
configuration.
Non-max suppression callable takes `boxes`, `scores`, and optionally
`clip_window`, `parallel_iterations` and `scope` as inputs. It returns
`nms_boxes`, `nms_scores`, `nms_classes` and `num_detections`. See
post_processing.batch_multiclass_non_max_suppression for the type and shape
of these tensors.
The score converter callable should be called with an `input` tensor. It
returns the output of one of three tf operations, chosen by the configuration:
tf.identity, tf.sigmoid or tf.nn.softmax. See the TensorFlow documentation for
argument and return value descriptions.
Args:
post_processing_config: post_processing.proto object containing the
parameters for the post-processing operations.
Returns:
non_max_suppressor_fn: Callable for non-max suppression.
score_converter_fn: Callable for score conversion.
Raises:
ValueError: if the post_processing_config is of incorrect type.
"""
if not isinstance(post_processing_config, post_processing_pb2.PostProcessing):
raise ValueError('post_processing_config not of type '
'post_processing_pb2.PostProcessing.')
non_max_suppressor_fn = _build_non_max_suppressor(
post_processing_config.batch_non_max_suppression)
score_converter_fn = _build_score_converter(
post_processing_config.score_converter)
return non_max_suppressor_fn, score_converter_fn
def _build_non_max_suppressor(nms_config):
"""Builds non-max suppresson based on the nms config.
Args:
nms_config: post_processing_pb2.PostProcessing.BatchNonMaxSuppression proto.
Returns:
non_max_suppressor_fn: Callable non-max suppressor.
Raises:
ValueError: On incorrect iou_threshold or on incompatible values of
max_total_detections and max_detections_per_class.
"""
if nms_config.iou_threshold < 0 or nms_config.iou_threshold > 1.0:
raise ValueError('iou_threshold not in [0, 1.0].')
if nms_config.max_detections_per_class > nms_config.max_total_detections:
raise ValueError('max_detections_per_class should be no greater than '
'max_total_detections.')
non_max_suppressor_fn = functools.partial(
post_processing.batch_multiclass_non_max_suppression,
score_thresh=nms_config.score_threshold,
iou_thresh=nms_config.iou_threshold,
max_size_per_class=nms_config.max_detections_per_class,
max_total_size=nms_config.max_total_detections)
return non_max_suppressor_fn
def _build_score_converter(score_converter_config):
"""Builds score converter based on the config.
Builds one of [tf.identity, tf.sigmoid, tf.softmax] score converters based on
the config.
Args:
score_converter_config: post_processing_pb2.PostProcessing.score_converter.
Returns:
Callable score converter op.
Raises:
ValueError: On unknown score converter.
"""
if score_converter_config == post_processing_pb2.PostProcessing.IDENTITY:
return tf.identity
if score_converter_config == post_processing_pb2.PostProcessing.SIGMOID:
return tf.sigmoid
if score_converter_config == post_processing_pb2.PostProcessing.SOFTMAX:
return tf.nn.softmax
raise ValueError('Unknown score converter.')
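
if __name__ == '__main__':
  # Usage sketch: building both callables from a text-format PostProcessing
  # proto; the thresholds below are arbitrary example values.
  from google.protobuf import text_format

  pp_config = post_processing_pb2.PostProcessing()
  text_format.Merge("""
      batch_non_max_suppression {
        score_threshold: 0.5
        iou_threshold: 0.6
        max_detections_per_class: 100
        max_total_detections: 300
      }
      score_converter: SIGMOID
      """, pp_config)
  nms_fn, score_converter_fn = build(pp_config)
  example_logits = tf.constant([[2.0, -1.0, 0.5]])
  example_scores = score_converter_fn(example_logits)  # element-wise sigmoid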
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for post_processing_builder."""
import tensorflow as tf
from google.protobuf import text_format
from object_detection.builders import post_processing_builder
from object_detection.protos import post_processing_pb2
class PostProcessingBuilderTest(tf.test.TestCase):
def test_build_non_max_suppressor_with_correct_parameters(self):
post_processing_text_proto = """
batch_non_max_suppression {
score_threshold: 0.7
iou_threshold: 0.6
max_detections_per_class: 100
max_total_detections: 300
}
"""
post_processing_config = post_processing_pb2.PostProcessing()
text_format.Merge(post_processing_text_proto, post_processing_config)
non_max_suppressor, _ = post_processing_builder.build(
post_processing_config)
self.assertEqual(non_max_suppressor.keywords['max_size_per_class'], 100)
self.assertEqual(non_max_suppressor.keywords['max_total_size'], 300)
self.assertAlmostEqual(non_max_suppressor.keywords['score_thresh'], 0.7)
self.assertAlmostEqual(non_max_suppressor.keywords['iou_thresh'], 0.6)
def test_build_identity_score_converter(self):
post_processing_text_proto = """
score_converter: IDENTITY
"""
post_processing_config = post_processing_pb2.PostProcessing()
text_format.Merge(post_processing_text_proto, post_processing_config)
_, score_converter = post_processing_builder.build(post_processing_config)
self.assertEqual(score_converter, tf.identity)
def test_build_sigmoid_score_converter(self):
post_processing_text_proto = """
score_converter: SIGMOID
"""
post_processing_config = post_processing_pb2.PostProcessing()
text_format.Merge(post_processing_text_proto, post_processing_config)
_, score_converter = post_processing_builder.build(post_processing_config)
self.assertEqual(score_converter, tf.sigmoid)
def test_build_softmax_score_converter(self):
post_processing_text_proto = """
score_converter: SOFTMAX
"""
post_processing_config = post_processing_pb2.PostProcessing()
text_format.Merge(post_processing_text_proto, post_processing_config)
_, score_converter = post_processing_builder.build(post_processing_config)
self.assertEqual(score_converter, tf.nn.softmax)
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Builder for preprocessing steps."""
import tensorflow as tf
from object_detection.core import preprocessor
from object_detection.protos import preprocessor_pb2
def _get_step_config_from_proto(preprocessor_step_config, step_name):
"""Returns the value of a field named step_name from proto.
Args:
preprocessor_step_config: A preprocessor_pb2.PreprocessingStep object.
step_name: Name of the field to get value from.
Returns:
result_dict: a sub proto message from preprocessor_step_config which will be
later converted to a dictionary.
Raises:
ValueError: If field does not exist in proto.
"""
for field, value in preprocessor_step_config.ListFields():
if field.name == step_name:
return value
raise ValueError('Could not get field %s from proto!' % step_name)
def _get_dict_from_proto(config):
"""Helper function to put all proto fields into a dictionary.
For many preprocessing steps, there's a trivial one-to-one mapping from proto
fields to function arguments. This function automatically populates a
dictionary with the arguments from the proto.
Protos that CANNOT be trivially populated include:
* nested messages.
* steps that check if an optional field is set (i.e. where None != 0).
* protos that don't map 1-1 to arguments (i.e. a list should be reshaped).
* fields requiring additional validation (i.e. a repeated field has n elements).
Args:
config: A protobuf object that does not violate the conditions above.
Returns:
result_dict: |config| converted into a python dictionary.
"""
result_dict = {}
for field, value in config.ListFields():
result_dict[field.name] = value
return result_dict
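
# Illustrative example (assumed field values): a step parsed from
# `random_adjust_brightness { max_delta: 0.2 }` yields
# _get_dict_from_proto(step_config) == {'max_delta': 0.2}, which build() below
# passes as keyword arguments to preprocessor.random_adjust_brightness.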
# A map from a PreprocessingStep proto config field name to the preprocessing
# function that should be used. The PreprocessingStep proto should be parsable
# with _get_dict_from_proto.
PREPROCESSING_FUNCTION_MAP = {
'normalize_image': preprocessor.normalize_image,
'random_horizontal_flip': preprocessor.random_horizontal_flip,
'random_pixel_value_scale': preprocessor.random_pixel_value_scale,
'random_image_scale': preprocessor.random_image_scale,
'random_rgb_to_gray': preprocessor.random_rgb_to_gray,
'random_adjust_brightness': preprocessor.random_adjust_brightness,
'random_adjust_contrast': preprocessor.random_adjust_contrast,
'random_adjust_hue': preprocessor.random_adjust_hue,
'random_adjust_saturation': preprocessor.random_adjust_saturation,
'random_distort_color': preprocessor.random_distort_color,
'random_jitter_boxes': preprocessor.random_jitter_boxes,
'random_crop_to_aspect_ratio': preprocessor.random_crop_to_aspect_ratio,
'random_black_patches': preprocessor.random_black_patches,
'scale_boxes_to_pixel_coordinates': (
preprocessor.scale_boxes_to_pixel_coordinates),
'subtract_channel_mean': preprocessor.subtract_channel_mean,
}
# A map to convert from preprocessor_pb2.ResizeImage.Method enum to
# tf.image.ResizeMethod.
RESIZE_METHOD_MAP = {
preprocessor_pb2.ResizeImage.AREA: tf.image.ResizeMethod.AREA,
preprocessor_pb2.ResizeImage.BICUBIC: tf.image.ResizeMethod.BICUBIC,
preprocessor_pb2.ResizeImage.BILINEAR: tf.image.ResizeMethod.BILINEAR,
preprocessor_pb2.ResizeImage.NEAREST_NEIGHBOR: (
tf.image.ResizeMethod.NEAREST_NEIGHBOR),
}
def build(preprocessor_step_config):
"""Builds preprocessing step based on the configuration.
Args:
preprocessor_step_config: PreprocessingStep configuration proto.
Returns:
function, argmap: A callable function and an argument map to call function
with.
Raises:
ValueError: On invalid configuration.
"""
step_type = preprocessor_step_config.WhichOneof('preprocessing_step')
if step_type in PREPROCESSING_FUNCTION_MAP:
preprocessing_function = PREPROCESSING_FUNCTION_MAP[step_type]
step_config = _get_step_config_from_proto(preprocessor_step_config,
step_type)
function_args = _get_dict_from_proto(step_config)
return (preprocessing_function, function_args)
if step_type == 'random_crop_image':
config = preprocessor_step_config.random_crop_image
return (preprocessor.random_crop_image,
{
'min_object_covered': config.min_object_covered,
'aspect_ratio_range': (config.min_aspect_ratio,
config.max_aspect_ratio),
'area_range': (config.min_area, config.max_area),
'overlap_thresh': config.overlap_thresh,
'random_coef': config.random_coef,
})
if step_type == 'random_pad_image':
config = preprocessor_step_config.random_pad_image
min_image_size = None
if (config.HasField('min_image_height') !=
config.HasField('min_image_width')):
raise ValueError('min_image_height and min_image_width should be either '
'both set or both unset.')
if config.HasField('min_image_height'):
min_image_size = (config.min_image_height, config.min_image_width)
max_image_size = None
if (config.HasField('max_image_height') !=
config.HasField('max_image_width')):
raise ValueError('max_image_height and max_image_width should be either '
'both set or both unset.')
if config.HasField('max_image_height'):
max_image_size = (config.max_image_height, config.max_image_width)
pad_color = config.pad_color
if pad_color and len(pad_color) != 3:
raise ValueError('pad_color should have 3 elements (RGB) if set!')
if not pad_color:
pad_color = None
return (preprocessor.random_pad_image,
{
'min_image_size': min_image_size,
'max_image_size': max_image_size,
'pad_color': pad_color,
})
if step_type == 'random_crop_pad_image':
config = preprocessor_step_config.random_crop_pad_image
min_padded_size_ratio = config.min_padded_size_ratio
if min_padded_size_ratio and len(min_padded_size_ratio) != 2:
raise ValueError('min_padded_size_ratio should have 2 elements if set!')
max_padded_size_ratio = config.max_padded_size_ratio
if max_padded_size_ratio and len(max_padded_size_ratio) != 2:
raise ValueError('max_padded_size_ratio should have 2 elements if set!')
pad_color = config.pad_color
if pad_color and len(pad_color) != 3:
raise ValueError('pad_color should have 3 elements if set!')
return (preprocessor.random_crop_pad_image,
{
'min_object_covered': config.min_object_covered,
'aspect_ratio_range': (config.min_aspect_ratio,
config.max_aspect_ratio),
'area_range': (config.min_area, config.max_area),
'overlap_thresh': config.overlap_thresh,
'random_coef': config.random_coef,
'min_padded_size_ratio': (min_padded_size_ratio if
min_padded_size_ratio else None),
'max_padded_size_ratio': (max_padded_size_ratio if
max_padded_size_ratio else None),
'pad_color': (pad_color if pad_color else None),
})
if step_type == 'random_resize_method':
config = preprocessor_step_config.random_resize_method
return (preprocessor.random_resize_method,
{
'target_size': [config.target_height, config.target_width],
})
if step_type == 'resize_image':
config = preprocessor_step_config.resize_image
method = RESIZE_METHOD_MAP[config.method]
return (preprocessor.resize_image,
{
'new_height': config.new_height,
'new_width': config.new_width,
'method': method
})
if step_type == 'ssd_random_crop':
config = preprocessor_step_config.ssd_random_crop
if config.operations:
min_object_covered = [op.min_object_covered for op in config.operations]
aspect_ratio_range = [(op.min_aspect_ratio, op.max_aspect_ratio)
for op in config.operations]
area_range = [(op.min_area, op.max_area) for op in config.operations]
overlap_thresh = [op.overlap_thresh for op in config.operations]
random_coef = [op.random_coef for op in config.operations]
return (preprocessor.ssd_random_crop,
{
'min_object_covered': min_object_covered,
'aspect_ratio_range': aspect_ratio_range,
'area_range': area_range,
'overlap_thresh': overlap_thresh,
'random_coef': random_coef,
})
return (preprocessor.ssd_random_crop, {})
if step_type == 'ssd_random_crop_pad':
config = preprocessor_step_config.ssd_random_crop_pad
if config.operations:
min_object_covered = [op.min_object_covered for op in config.operations]
aspect_ratio_range = [(op.min_aspect_ratio, op.max_aspect_ratio)
for op in config.operations]
area_range = [(op.min_area, op.max_area) for op in config.operations]
overlap_thresh = [op.overlap_thresh for op in config.operations]
random_coef = [op.random_coef for op in config.operations]
min_padded_size_ratio = [
(op.min_padded_size_ratio[0], op.min_padded_size_ratio[1])
for op in config.operations]
max_padded_size_ratio = [
(op.max_padded_size_ratio[0], op.max_padded_size_ratio[1])
for op in config.operations]
pad_color = [(op.pad_color_r, op.pad_color_g, op.pad_color_b)
for op in config.operations]
return (preprocessor.ssd_random_crop_pad,
{
'min_object_covered': min_object_covered,
'aspect_ratio_range': aspect_ratio_range,
'area_range': area_range,
'overlap_thresh': overlap_thresh,
'random_coef': random_coef,
'min_padded_size_ratio': min_padded_size_ratio,
'max_padded_size_ratio': max_padded_size_ratio,
'pad_color': pad_color,
})
return (preprocessor.ssd_random_crop_pad, {})
if step_type == 'ssd_random_crop_fixed_aspect_ratio':
config = preprocessor_step_config.ssd_random_crop_fixed_aspect_ratio
if config.operations:
min_object_covered = [op.min_object_covered for op in config.operations]
area_range = [(op.min_area, op.max_area) for op in config.operations]
overlap_thresh = [op.overlap_thresh for op in config.operations]
random_coef = [op.random_coef for op in config.operations]
return (preprocessor.ssd_random_crop_fixed_aspect_ratio,
{
'min_object_covered': min_object_covered,
'aspect_ratio': config.aspect_ratio,
'area_range': area_range,
'overlap_thresh': overlap_thresh,
'random_coef': random_coef,
})
return (preprocessor.ssd_random_crop_fixed_aspect_ratio, {})
raise ValueError('Unknown preprocessing step.')
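
if __name__ == '__main__':
  # Usage sketch: parsing one normalize_image step from text format and
  # applying the returned function with its argument map; the min/max values
  # below are arbitrary examples.
  from google.protobuf import text_format

  step_proto = preprocessor_pb2.PreprocessingStep()
  text_format.Merge("""
      normalize_image {
        original_minval: 0.0
        original_maxval: 255.0
        target_minval: -1.0
        target_maxval: 1.0
      }
      """, step_proto)
  preprocess_fn, preprocess_args = build(step_proto)
  image = tf.placeholder(tf.float32, shape=[None, None, 3])
  normalized_image = preprocess_fn(image, **preprocess_args)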
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for preprocessor_builder."""
import tensorflow as tf
from google.protobuf import text_format
from object_detection.builders import preprocessor_builder
from object_detection.core import preprocessor
from object_detection.protos import preprocessor_pb2
class PreprocessorBuilderTest(tf.test.TestCase):
def assert_dictionary_close(self, dict1, dict2):
"""Helper to check if two dicts with floatst or integers are close."""
self.assertEqual(sorted(dict1.keys()), sorted(dict2.keys()))
for key in dict1:
value = dict1[key]
if isinstance(value, float):
self.assertAlmostEqual(value, dict2[key])
else:
self.assertEqual(value, dict2[key])
def test_build_normalize_image(self):
preprocessor_text_proto = """
normalize_image {
original_minval: 0.0
original_maxval: 255.0
target_minval: -1.0
target_maxval: 1.0
}
"""
preprocessor_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(preprocessor_text_proto, preprocessor_proto)
function, args = preprocessor_builder.build(preprocessor_proto)
self.assertEqual(function, preprocessor.normalize_image)
self.assertEqual(args, {
'original_minval': 0.0,
'original_maxval': 255.0,
'target_minval': -1.0,
'target_maxval': 1.0,
})
def test_build_random_horizontal_flip(self):
preprocessor_text_proto = """
random_horizontal_flip {
}
"""
preprocessor_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(preprocessor_text_proto, preprocessor_proto)
function, args = preprocessor_builder.build(preprocessor_proto)
self.assertEqual(function, preprocessor.random_horizontal_flip)
self.assertEqual(args, {})
def test_build_random_pixel_value_scale(self):
preprocessor_text_proto = """
random_pixel_value_scale {
minval: 0.8
maxval: 1.2
}
"""
preprocessor_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(preprocessor_text_proto, preprocessor_proto)
function, args = preprocessor_builder.build(preprocessor_proto)
self.assertEqual(function, preprocessor.random_pixel_value_scale)
self.assert_dictionary_close(args, {'minval': 0.8, 'maxval': 1.2})
def test_build_random_image_scale(self):
preprocessor_text_proto = """
random_image_scale {
min_scale_ratio: 0.8
max_scale_ratio: 2.2
}
"""
preprocessor_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(preprocessor_text_proto, preprocessor_proto)
function, args = preprocessor_builder.build(preprocessor_proto)
self.assertEqual(function, preprocessor.random_image_scale)
self.assert_dictionary_close(args, {'min_scale_ratio': 0.8,
'max_scale_ratio': 2.2})
def test_build_random_rgb_to_gray(self):
preprocessor_text_proto = """
random_rgb_to_gray {
probability: 0.8
}
"""
preprocessor_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(preprocessor_text_proto, preprocessor_proto)
function, args = preprocessor_builder.build(preprocessor_proto)
self.assertEqual(function, preprocessor.random_rgb_to_gray)
self.assert_dictionary_close(args, {'probability': 0.8})
def test_build_random_adjust_brightness(self):
preprocessor_text_proto = """
random_adjust_brightness {
max_delta: 0.2
}
"""
preprocessor_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(preprocessor_text_proto, preprocessor_proto)
function, args = preprocessor_builder.build(preprocessor_proto)
self.assertEqual(function, preprocessor.random_adjust_brightness)
self.assert_dictionary_close(args, {'max_delta': 0.2})
def test_build_random_adjust_contrast(self):
preprocessor_text_proto = """
random_adjust_contrast {
min_delta: 0.7
max_delta: 1.1
}
"""
preprocessor_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(preprocessor_text_proto, preprocessor_proto)
function, args = preprocessor_builder.build(preprocessor_proto)
self.assertEqual(function, preprocessor.random_adjust_contrast)
self.assert_dictionary_close(args, {'min_delta': 0.7, 'max_delta': 1.1})
def test_build_random_adjust_hue(self):
preprocessor_text_proto = """
random_adjust_hue {
max_delta: 0.01
}
"""
preprocessor_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(preprocessor_text_proto, preprocessor_proto)
function, args = preprocessor_builder.build(preprocessor_proto)
self.assertEqual(function, preprocessor.random_adjust_hue)
self.assert_dictionary_close(args, {'max_delta': 0.01})
def test_build_random_adjust_saturation(self):
preprocessor_text_proto = """
random_adjust_saturation {
min_delta: 0.75
max_delta: 1.15
}
"""
preprocessor_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(preprocessor_text_proto, preprocessor_proto)
function, args = preprocessor_builder.build(preprocessor_proto)
self.assertEqual(function, preprocessor.random_adjust_saturation)
self.assert_dictionary_close(args, {'min_delta': 0.75, 'max_delta': 1.15})
def test_build_random_distort_color(self):
preprocessor_text_proto = """
random_distort_color {
color_ordering: 1
}
"""
preprocessor_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(preprocessor_text_proto, preprocessor_proto)
function, args = preprocessor_builder.build(preprocessor_proto)
self.assertEqual(function, preprocessor.random_distort_color)
self.assertEqual(args, {'color_ordering': 1})
def test_build_random_jitter_boxes(self):
preprocessor_text_proto = """
random_jitter_boxes {
ratio: 0.1
}
"""
preprocessor_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(preprocessor_text_proto, preprocessor_proto)
function, args = preprocessor_builder.build(preprocessor_proto)
self.assertEqual(function, preprocessor.random_jitter_boxes)
self.assert_dictionary_close(args, {'ratio': 0.1})
def test_build_random_crop_image(self):
preprocessor_text_proto = """
random_crop_image {
min_object_covered: 0.75
min_aspect_ratio: 0.75
max_aspect_ratio: 1.5
min_area: 0.25
max_area: 0.875
overlap_thresh: 0.5
random_coef: 0.125
}
"""
preprocessor_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(preprocessor_text_proto, preprocessor_proto)
function, args = preprocessor_builder.build(preprocessor_proto)
self.assertEqual(function, preprocessor.random_crop_image)
self.assertEqual(args, {
'min_object_covered': 0.75,
'aspect_ratio_range': (0.75, 1.5),
'area_range': (0.25, 0.875),
'overlap_thresh': 0.5,
'random_coef': 0.125,
})
def test_build_random_pad_image(self):
preprocessor_text_proto = """
random_pad_image {
}
"""
preprocessor_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(preprocessor_text_proto, preprocessor_proto)
function, args = preprocessor_builder.build(preprocessor_proto)
self.assertEqual(function, preprocessor.random_pad_image)
self.assertEqual(args, {
'min_image_size': None,
'max_image_size': None,
'pad_color': None,
})
def test_build_random_crop_pad_image(self):
preprocessor_text_proto = """
random_crop_pad_image {
min_object_covered: 0.75
min_aspect_ratio: 0.75
max_aspect_ratio: 1.5
min_area: 0.25
max_area: 0.875
overlap_thresh: 0.5
random_coef: 0.125
}
"""
preprocessor_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(preprocessor_text_proto, preprocessor_proto)
function, args = preprocessor_builder.build(preprocessor_proto)
self.assertEqual(function, preprocessor.random_crop_pad_image)
self.assertEqual(args, {
'min_object_covered': 0.75,
'aspect_ratio_range': (0.75, 1.5),
'area_range': (0.25, 0.875),
'overlap_thresh': 0.5,
'random_coef': 0.125,
'min_padded_size_ratio': None,
'max_padded_size_ratio': None,
'pad_color': None,
})
def test_build_random_crop_to_aspect_ratio(self):
preprocessor_text_proto = """
random_crop_to_aspect_ratio {
aspect_ratio: 0.85
overlap_thresh: 0.35
}
"""
preprocessor_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(preprocessor_text_proto, preprocessor_proto)
function, args = preprocessor_builder.build(preprocessor_proto)
self.assertEqual(function, preprocessor.random_crop_to_aspect_ratio)
self.assert_dictionary_close(args, {'aspect_ratio': 0.85,
'overlap_thresh': 0.35})
def test_build_random_black_patches(self):
preprocessor_text_proto = """
random_black_patches {
max_black_patches: 20
probability: 0.95
size_to_image_ratio: 0.12
}
"""
preprocessor_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(preprocessor_text_proto, preprocessor_proto)
function, args = preprocessor_builder.build(preprocessor_proto)
self.assertEqual(function, preprocessor.random_black_patches)
self.assert_dictionary_close(args, {'max_black_patches': 20,
'probability': 0.95,
'size_to_image_ratio': 0.12})
def test_build_random_resize_method(self):
preprocessor_text_proto = """
random_resize_method {
target_height: 75
target_width: 100
}
"""
preprocessor_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(preprocessor_text_proto, preprocessor_proto)
function, args = preprocessor_builder.build(preprocessor_proto)
self.assertEqual(function, preprocessor.random_resize_method)
self.assert_dictionary_close(args, {'target_size': [75, 100]})
def test_build_scale_boxes_to_pixel_coordinates(self):
preprocessor_text_proto = """
scale_boxes_to_pixel_coordinates {}
"""
preprocessor_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(preprocessor_text_proto, preprocessor_proto)
function, args = preprocessor_builder.build(preprocessor_proto)
self.assertEqual(function, preprocessor.scale_boxes_to_pixel_coordinates)
self.assertEqual(args, {})
def test_build_resize_image(self):
preprocessor_text_proto = """
resize_image {
new_height: 75
new_width: 100
method: BICUBIC
}
"""
preprocessor_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(preprocessor_text_proto, preprocessor_proto)
function, args = preprocessor_builder.build(preprocessor_proto)
self.assertEqual(function, preprocessor.resize_image)
self.assertEqual(args, {'new_height': 75,
'new_width': 100,
'method': tf.image.ResizeMethod.BICUBIC})
def test_build_subtract_channel_mean(self):
preprocessor_text_proto = """
subtract_channel_mean {
means: [1.0, 2.0, 3.0]
}
"""
preprocessor_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(preprocessor_text_proto, preprocessor_proto)
function, args = preprocessor_builder.build(preprocessor_proto)
self.assertEqual(function, preprocessor.subtract_channel_mean)
self.assertEqual(args, {'means': [1.0, 2.0, 3.0]})
def test_build_ssd_random_crop(self):
preprocessor_text_proto = """
ssd_random_crop {
operations {
min_object_covered: 0.0
min_aspect_ratio: 0.875
max_aspect_ratio: 1.125
min_area: 0.5
max_area: 1.0
overlap_thresh: 0.0
random_coef: 0.375
}
operations {
min_object_covered: 0.25
min_aspect_ratio: 0.75
max_aspect_ratio: 1.5
min_area: 0.5
max_area: 1.0
overlap_thresh: 0.25
random_coef: 0.375
}
}
"""
preprocessor_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(preprocessor_text_proto, preprocessor_proto)
function, args = preprocessor_builder.build(preprocessor_proto)
self.assertEqual(function, preprocessor.ssd_random_crop)
self.assertEqual(args, {'min_object_covered': [0.0, 0.25],
'aspect_ratio_range': [(0.875, 1.125), (0.75, 1.5)],
'area_range': [(0.5, 1.0), (0.5, 1.0)],
'overlap_thresh': [0.0, 0.25],
'random_coef': [0.375, 0.375]})
def test_build_ssd_random_crop_empty_operations(self):
preprocessor_text_proto = """
ssd_random_crop {
}
"""
preprocessor_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(preprocessor_text_proto, preprocessor_proto)
function, args = preprocessor_builder.build(preprocessor_proto)
self.assertEqual(function, preprocessor.ssd_random_crop)
self.assertEqual(args, {})
def test_build_ssd_random_crop_pad(self):
preprocessor_text_proto = """
ssd_random_crop_pad {
operations {
min_object_covered: 0.0
min_aspect_ratio: 0.875
max_aspect_ratio: 1.125
min_area: 0.5
max_area: 1.0
overlap_thresh: 0.0
random_coef: 0.375
min_padded_size_ratio: [0.0, 0.0]
max_padded_size_ratio: [2.0, 2.0]
pad_color_r: 0.5
pad_color_g: 0.5
pad_color_b: 0.5
}
operations {
min_object_covered: 0.25
min_aspect_ratio: 0.75
max_aspect_ratio: 1.5
min_area: 0.5
max_area: 1.0
overlap_thresh: 0.25
random_coef: 0.375
min_padded_size_ratio: [0.0, 0.0]
max_padded_size_ratio: [2.0, 2.0]
pad_color_r: 0.5
pad_color_g: 0.5
pad_color_b: 0.5
}
}
"""
preprocessor_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(preprocessor_text_proto, preprocessor_proto)
function, args = preprocessor_builder.build(preprocessor_proto)
self.assertEqual(function, preprocessor.ssd_random_crop_pad)
self.assertEqual(args, {'min_object_covered': [0.0, 0.25],
'aspect_ratio_range': [(0.875, 1.125), (0.75, 1.5)],
'area_range': [(0.5, 1.0), (0.5, 1.0)],
'overlap_thresh': [0.0, 0.25],
'random_coef': [0.375, 0.375],
'min_padded_size_ratio': [(0.0, 0.0), (0.0, 0.0)],
'max_padded_size_ratio': [(2.0, 2.0), (2.0, 2.0)],
'pad_color': [(0.5, 0.5, 0.5), (0.5, 0.5, 0.5)]})
def test_build_ssd_random_crop_fixed_aspect_ratio(self):
preprocessor_text_proto = """
ssd_random_crop_fixed_aspect_ratio {
operations {
min_object_covered: 0.0
min_area: 0.5
max_area: 1.0
overlap_thresh: 0.0
random_coef: 0.375
}
operations {
min_object_covered: 0.25
min_area: 0.5
max_area: 1.0
overlap_thresh: 0.25
random_coef: 0.375
}
aspect_ratio: 0.875
}
"""
preprocessor_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(preprocessor_text_proto, preprocessor_proto)
function, args = preprocessor_builder.build(preprocessor_proto)
self.assertEqual(function, preprocessor.ssd_random_crop_fixed_aspect_ratio)
self.assertEqual(args, {'min_object_covered': [0.0, 0.25],
'aspect_ratio': 0.875,
'area_range': [(0.5, 1.0), (0.5, 1.0)],
'overlap_thresh': [0.0, 0.25],
'random_coef': [0.375, 0.375]})
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Builder for region similarity calculators."""
from object_detection.core import region_similarity_calculator
from object_detection.protos import region_similarity_calculator_pb2
def build(region_similarity_calculator_config):
"""Builds region similarity calculator based on the configuration.
Builds one of [IouSimilarity, IoaSimilarity, NegSqDistSimilarity] objects. See
core/region_similarity_calculator.py for details.
Args:
region_similarity_calculator_config: RegionSimilarityCalculator
configuration proto.
Returns:
region_similarity_calculator: RegionSimilarityCalculator object.
Raises:
ValueError: On unknown region similarity calculator.
"""
if not isinstance(
region_similarity_calculator_config,
region_similarity_calculator_pb2.RegionSimilarityCalculator):
raise ValueError(
'region_similarity_calculator_config not of type '
'region_similarity_calculator_pb2.RegionSimilarityCalculator')
similarity_calculator = region_similarity_calculator_config.WhichOneof(
'region_similarity')
if similarity_calculator == 'iou_similarity':
return region_similarity_calculator.IouSimilarity()
if similarity_calculator == 'ioa_similarity':
return region_similarity_calculator.IoaSimilarity()
if similarity_calculator == 'neg_sq_dist_similarity':
return region_similarity_calculator.NegSqDistSimilarity()
raise ValueError('Unknown region similarity calculator.')
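
if __name__ == '__main__':
  # Usage sketch: building an IOU calculator from text format and comparing
  # two single-box BoxLists; the coordinates are arbitrary examples.
  import tensorflow as tf
  from google.protobuf import text_format
  from object_detection.core import box_list

  calc_proto = region_similarity_calculator_pb2.RegionSimilarityCalculator()
  text_format.Merge('iou_similarity { }', calc_proto)
  iou_calculator = build(calc_proto)
  boxes1 = box_list.BoxList(tf.constant([[0.0, 0.0, 1.0, 1.0]]))
  boxes2 = box_list.BoxList(tf.constant([[0.0, 0.0, 0.5, 1.0]]))
  similarity = iou_calculator.compare(boxes1, boxes2)  # IOU matrix, here [[0.5]]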
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for region_similarity_calculator_builder."""
import tensorflow as tf
from google.protobuf import text_format
from object_detection.builders import region_similarity_calculator_builder
from object_detection.core import region_similarity_calculator
from object_detection.protos import region_similarity_calculator_pb2 as sim_calc_pb2
class RegionSimilarityCalculatorBuilderTest(tf.test.TestCase):
def testBuildIoaSimilarityCalculator(self):
similarity_calc_text_proto = """
ioa_similarity {
}
"""
similarity_calc_proto = sim_calc_pb2.RegionSimilarityCalculator()
text_format.Merge(similarity_calc_text_proto, similarity_calc_proto)
similarity_calc = region_similarity_calculator_builder.build(
similarity_calc_proto)
self.assertTrue(isinstance(similarity_calc,
region_similarity_calculator.IoaSimilarity))
def testBuildIouSimilarityCalculator(self):
similarity_calc_text_proto = """
iou_similarity {
}
"""
similarity_calc_proto = sim_calc_pb2.RegionSimilarityCalculator()
text_format.Merge(similarity_calc_text_proto, similarity_calc_proto)
similarity_calc = region_similarity_calculator_builder.build(
similarity_calc_proto)
self.assertTrue(isinstance(similarity_calc,
region_similarity_calculator.IouSimilarity))
def testBuildNegSqDistSimilarityCalculator(self):
similarity_calc_text_proto = """
neg_sq_dist_similarity {
}
"""
similarity_calc_proto = sim_calc_pb2.RegionSimilarityCalculator()
text_format.Merge(similarity_calc_text_proto, similarity_calc_proto)
similarity_calc = region_similarity_calculator_builder.build(
similarity_calc_proto)
self.assertTrue(isinstance(similarity_calc,
region_similarity_calculator.
NegSqDistSimilarity))
if __name__ == '__main__':
tf.test.main()
# Tensorflow Object Detection API: Core.
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"])
# Apache 2.0
py_library(
name = "batcher",
srcs = ["batcher.py"],
deps = [
":prefetcher",
":preprocessor",
":standard_fields",
"//tensorflow",
],
)
py_test(
name = "batcher_test",
srcs = ["batcher_test.py"],
deps = [
":batcher",
"//tensorflow",
],
)
py_library(
name = "box_list",
srcs = [
"box_list.py",
],
deps = [
"//tensorflow",
],
)
py_test(
name = "box_list_test",
srcs = ["box_list_test.py"],
deps = [
":box_list",
],
)
py_library(
name = "box_list_ops",
srcs = [
"box_list_ops.py",
],
deps = [
":box_list",
"//tensorflow",
"//tensorflow_models/object_detection/utils:shape_utils",
],
)
py_test(
name = "box_list_ops_test",
srcs = ["box_list_ops_test.py"],
deps = [
":box_list",
":box_list_ops",
],
)
py_library(
name = "box_coder",
srcs = [
"box_coder.py",
],
deps = [
"//tensorflow",
],
)
py_test(
name = "box_coder_test",
srcs = [
"box_coder_test.py",
],
deps = [
":box_coder",
":box_list",
"//tensorflow",
],
)
py_library(
name = "keypoint_ops",
srcs = [
"keypoint_ops.py",
],
deps = [
"//tensorflow",
],
)
py_test(
name = "keypoint_ops_test",
srcs = ["keypoint_ops_test.py"],
deps = [
":keypoint_ops",
],
)
py_library(
name = "losses",
srcs = ["losses.py"],
deps = [
":box_list",
":box_list_ops",
"//tensorflow",
"//tensorflow_models/object_detection/utils:ops",
],
)
py_library(
name = "matcher",
srcs = [
"matcher.py",
],
deps = [
],
)
py_library(
name = "model",
srcs = ["model.py"],
deps = [
":standard_fields",
],
)
py_test(
name = "matcher_test",
srcs = [
"matcher_test.py",
],
deps = [
":matcher",
"//tensorflow",
],
)
py_library(
name = "prefetcher",
srcs = ["prefetcher.py"],
deps = ["//tensorflow"],
)
py_library(
name = "preprocessor",
srcs = [
"preprocessor.py",
],
deps = [
":box_list",
":box_list_ops",
":keypoint_ops",
":standard_fields",
"//tensorflow",
],
)
py_test(
name = "preprocessor_test",
srcs = [
"preprocessor_test.py",
],
deps = [
":preprocessor",
"//tensorflow",
],
)
py_test(
name = "losses_test",
srcs = ["losses_test.py"],
deps = [
":box_list",
":losses",
":matcher",
"//tensorflow",
],
)
py_test(
name = "prefetcher_test",
srcs = ["prefetcher_test.py"],
deps = [
":prefetcher",
"//tensorflow",
],
)
py_library(
name = "standard_fields",
srcs = [
"standard_fields.py",
],
)
py_library(
name = "post_processing",
srcs = ["post_processing.py"],
deps = [
":box_list",
":box_list_ops",
":standard_fields",
"//tensorflow",
],
)
py_test(
name = "post_processing_test",
srcs = ["post_processing_test.py"],
deps = [
":box_list",
":box_list_ops",
":post_processing",
"//tensorflow",
],
)
py_library(
name = "target_assigner",
srcs = [
"target_assigner.py",
],
deps = [
":box_list",
":box_list_ops",
":matcher",
":region_similarity_calculator",
"//tensorflow",
"//tensorflow_models/object_detection/box_coders:faster_rcnn_box_coder",
"//tensorflow_models/object_detection/box_coders:mean_stddev_box_coder",
"//tensorflow_models/object_detection/core:box_coder",
"//tensorflow_models/object_detection/matchers:argmax_matcher",
"//tensorflow_models/object_detection/matchers:bipartite_matcher",
],
)
py_test(
name = "target_assigner_test",
size = "large",
timeout = "long",
srcs = ["target_assigner_test.py"],
deps = [
":box_list",
":region_similarity_calculator",
":target_assigner",
"//tensorflow",
"//tensorflow_models/object_detection/box_coders:mean_stddev_box_coder",
"//tensorflow_models/object_detection/matchers:bipartite_matcher",
],
)
py_library(
name = "data_decoder",
srcs = ["data_decoder.py"],
)
py_library(
name = "box_predictor",
srcs = ["box_predictor.py"],
deps = [
"//tensorflow",
"//tensorflow_models/object_detection/utils:ops",
"//tensorflow_models/object_detection/utils:static_shape",
],
)
py_test(
name = "box_predictor_test",
srcs = ["box_predictor_test.py"],
deps = [
":box_predictor",
"//tensorflow",
"//tensorflow_models/object_detection/builders:hyperparams_builder",
"//tensorflow_models/object_detection/protos:hyperparams_py_pb2",
],
)
py_library(
name = "region_similarity_calculator",
srcs = [
"region_similarity_calculator.py",
],
deps = [
"//tensorflow",
"//tensorflow_models/object_detection/core:box_list_ops",
],
)
py_test(
name = "region_similarity_calculator_test",
srcs = [
"region_similarity_calculator_test.py",
],
deps = [
":region_similarity_calculator",
"//tensorflow_models/object_detection/core:box_list",
],
)
py_library(
name = "anchor_generator",
srcs = [
"anchor_generator.py",
],
deps = [
"//tensorflow",
],
)
py_library(
name = "minibatch_sampler",
srcs = [
"minibatch_sampler.py",
],
deps = [
"//tensorflow",
"//tensorflow_models/object_detection/utils:ops",
],
)
py_test(
name = "minibatch_sampler_test",
srcs = [
"minibatch_sampler_test.py",
],
deps = [
":minibatch_sampler",
"//tensorflow",
],
)
py_library(
name = "balanced_positive_negative_sampler",
srcs = [
"balanced_positive_negative_sampler.py",
],
deps = [
":minibatch_sampler",
"//tensorflow",
],
)
py_test(
name = "balanced_positive_negative_sampler_test",
srcs = [
"balanced_positive_negative_sampler_test.py",
],
deps = [
":balanced_positive_negative_sampler",
"//tensorflow",
],
)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base anchor generator.
The job of the anchor generator is to create (or load) a collection
of bounding boxes to be used as anchors.
Generated anchors are assumed to match some convolutional grid or list of grid
shapes. For example, we might want to generate anchors matching an 8x8
feature map and a 4x4 feature map. If we place 3 anchors per grid location
on the first feature map and 6 anchors per grid location on the second feature
map, then 3*8*8 + 6*4*4 = 288 anchors are generated in total.
To support fully convolutional settings, feature map shapes are passed
dynamically at generation time. The number of anchors to place at each location
is static --- implementations of AnchorGenerator must always be able to return
the number of anchors that they use per location for each feature map.
"""
from abc import ABCMeta
from abc import abstractmethod
import tensorflow as tf
class AnchorGenerator(object):
"""Abstract base class for anchor generators."""
__metaclass__ = ABCMeta
@abstractmethod
def name_scope(self):
"""Name scope.
Must be defined by implementations.
Returns:
a string representing the name scope of the anchor generation operation.
"""
pass
@property
def check_num_anchors(self):
"""Whether to dynamically check the number of anchors generated.
Can be overridden by implementations that would like to disable this
behavior.
Returns:
a boolean controlling whether the Generate function should dynamically
check the number of anchors generated against the mathematically
expected number of anchors.
"""
return True
@abstractmethod
def num_anchors_per_location(self):
"""Returns the number of anchors per spatial location.
Returns:
a list of integers, one for each expected feature map to be passed to
the `generate` function.
"""
pass
def generate(self, feature_map_shape_list, **params):
"""Generates a collection of bounding boxes to be used as anchors.
TODO: remove **params from argument list and make stride and offsets (for
multiple_grid_anchor_generator) constructor arguments.
Args:
feature_map_shape_list: list of (height, width) pairs in the format
[(height_0, width_0), (height_1, width_1), ...] that the generated
anchors must align with. Pairs can be provided as 1-dimensional
integer tensors of length 2 or simply as tuples of integers.
**params: parameters for anchor generation op
Returns:
boxes: a BoxList holding a collection of N anchor boxes
Raises:
ValueError: if the number of feature map shapes does not match the length
of the `num_anchors_per_location` list.
"""
if self.check_num_anchors and (
len(feature_map_shape_list) != len(self.num_anchors_per_location())):
raise ValueError('Number of feature maps is expected to equal the length '
'of `num_anchors_per_location`.')
with tf.name_scope(self.name_scope()):
anchors = self._generate(feature_map_shape_list, **params)
if self.check_num_anchors:
with tf.control_dependencies([
self._assert_correct_number_of_anchors(
anchors, feature_map_shape_list)]):
anchors.set(tf.identity(anchors.get()))
return anchors
@abstractmethod
def _generate(self, feature_map_shape_list, **params):
"""To be overridden by implementations.
Args:
feature_map_shape_list: list of (height, width) pairs in the format
[(height_0, width_0), (height_1, width_1), ...] that the generated
anchors must align with.
**params: parameters for anchor generation op
Returns:
boxes: a BoxList holding a collection of N anchor boxes
"""
pass
def _assert_correct_number_of_anchors(self, anchors, feature_map_shape_list):
"""Assert that correct number of anchors was generated.
Args:
anchors: box_list.BoxList object holding anchors generated
feature_map_shape_list: list of (height, width) pairs in the format
[(height_0, width_0), (height_1, width_1), ...] that the generated
anchors must align with.
Returns:
Op that raises InvalidArgumentError if the number of anchors does not
match the number of expected anchors.
"""
expected_num_anchors = 0
for num_anchors_per_location, feature_map_shape in zip(
self.num_anchors_per_location(), feature_map_shape_list):
expected_num_anchors += (num_anchors_per_location
* feature_map_shape[0]
* feature_map_shape[1])
return tf.assert_equal(expected_num_anchors, anchors.num_boxes())
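
if __name__ == '__main__':
  # Minimal concrete subclass sketch showing the contract: a single grid with
  # one anchor per location, returned as a BoxList so that the count check in
  # `generate` passes. A real generator would tile centers, scales and aspect
  # ratios instead of repeating one unit box.
  from object_detection.core import box_list

  class UnitBoxAnchorGenerator(AnchorGenerator):
    """Places an identical unit box at every grid location (for illustration)."""

    def name_scope(self):
      return 'UnitBoxAnchorGenerator'

    def num_anchors_per_location(self):
      return [1]  # one feature map, one anchor per cell

    def _generate(self, feature_map_shape_list, **params):
      height, width = feature_map_shape_list[0]
      boxes = tf.tile(tf.constant([[0.0, 0.0, 1.0, 1.0]]),
                      [height * width, 1])
      return box_list.BoxList(boxes)

  anchors = UnitBoxAnchorGenerator().generate([(4, 4)])  # 1 * 4 * 4 = 16 boxes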
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class to subsample minibatches by balancing positives and negatives.
Subsamples minibatches based on a pre-specified positive fraction in range
[0,1]. The class presumes there are many more negatives than positive examples:
if the desired batch_size cannot be achieved with the pre-specified positive
fraction, it fills the rest with negative examples. If this is not sufficient
for obtaining the desired batch_size, it returns fewer examples.
The main function to call is subsample(self, indicator, batch_size, labels).
For convenience one can also use the subsample_indicator helper defined in the
minibatch_sampler base class.
"""
import tensorflow as tf
from object_detection.core import minibatch_sampler
class BalancedPositiveNegativeSampler(minibatch_sampler.MinibatchSampler):
"""Subsamples minibatches to a desired balance of positives and negatives."""
def __init__(self, positive_fraction=0.5):
"""Constructs a minibatch sampler.
Args:
positive_fraction: desired fraction of positive examples (scalar in [0,1])
Raises:
ValueError: if positive_fraction < 0, or positive_fraction > 1
"""
if positive_fraction < 0 or positive_fraction > 1:
raise ValueError('positive_fraction should be in range [0,1]. '
'Received: %s.' % positive_fraction)
self._positive_fraction = positive_fraction
def subsample(self, indicator, batch_size, labels):
"""Returns subsampled minibatch.
Args:
indicator: boolean tensor of shape [N] whose True entries can be sampled.
batch_size: desired batch size.
labels: boolean tensor of shape [N] denoting positive(=True) and negative
(=False) examples.
Returns:
is_sampled: boolean tensor of shape [N], True for entries which are
sampled.
Raises:
ValueError: if labels and indicator are not 1D boolean tensors.
"""
if len(indicator.get_shape().as_list()) != 1:
raise ValueError('indicator must be 1 dimensional, got a tensor of '
'shape %s' % indicator.get_shape())
if len(labels.get_shape().as_list()) != 1:
raise ValueError('labels must be 1 dimensional, got a tensor of '
'shape %s' % labels.get_shape())
if labels.dtype != tf.bool:
raise ValueError('labels should be of type bool. Received: %s' %
labels.dtype)
if indicator.dtype != tf.bool:
raise ValueError('indicator should be of type bool. Received: %s' %
indicator.dtype)
# Only sample from indicated samples
negative_idx = tf.logical_not(labels)
positive_idx = tf.logical_and(labels, indicator)
negative_idx = tf.logical_and(negative_idx, indicator)
# Sample positive and negative samples separately
max_num_pos = int(self._positive_fraction * batch_size)
sampled_pos_idx = self.subsample_indicator(positive_idx, max_num_pos)
max_num_neg = batch_size - tf.reduce_sum(tf.cast(sampled_pos_idx, tf.int32))
sampled_neg_idx = self.subsample_indicator(negative_idx, max_num_neg)
sampled_idx = tf.logical_or(sampled_pos_idx, sampled_neg_idx)
return sampled_idx
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for object_detection.core.balanced_positive_negative_sampler."""
import numpy as np
import tensorflow as tf
from object_detection.core import balanced_positive_negative_sampler
class BalancedPositiveNegativeSamplerTest(tf.test.TestCase):
def test_subsample_all_examples(self):
numpy_labels = np.random.permutation(300)
indicator = tf.constant(np.ones(300) == 1)
numpy_labels = (numpy_labels - 200) > 0
labels = tf.constant(numpy_labels)
sampler = (balanced_positive_negative_sampler.
BalancedPositiveNegativeSampler())
is_sampled = sampler.subsample(indicator, 64, labels)
with self.test_session() as sess:
is_sampled = sess.run(is_sampled)
self.assertTrue(sum(is_sampled) == 64)
self.assertTrue(sum(np.logical_and(numpy_labels, is_sampled)) == 32)
self.assertTrue(sum(np.logical_and(
np.logical_not(numpy_labels), is_sampled)) == 32)
def test_subsample_selection(self):
# Test random sampling when only some examples can be sampled:
# 100 samples, 20 positives, 10 positives cannot be sampled
numpy_labels = np.arange(100)
numpy_indicator = numpy_labels < 90
indicator = tf.constant(numpy_indicator)
numpy_labels = (numpy_labels - 80) >= 0
labels = tf.constant(numpy_labels)
sampler = (balanced_positive_negative_sampler.
BalancedPositiveNegativeSampler())
is_sampled = sampler.subsample(indicator, 64, labels)
with self.test_session() as sess:
is_sampled = sess.run(is_sampled)
self.assertTrue(sum(is_sampled) == 64)
self.assertTrue(sum(np.logical_and(numpy_labels, is_sampled)) == 10)
self.assertTrue(sum(np.logical_and(
np.logical_not(numpy_labels), is_sampled)) == 54)
self.assertAllEqual(is_sampled, np.logical_and(is_sampled,
numpy_indicator))
def test_raises_error_with_incorrect_label_shape(self):
labels = tf.constant([[True, False, False]])
indicator = tf.constant([True, False, True])
sampler = (balanced_positive_negative_sampler.
BalancedPositiveNegativeSampler())
with self.assertRaises(ValueError):
sampler.subsample(indicator, 64, labels)
def test_raises_error_with_incorrect_indicator_shape(self):
labels = tf.constant([True, False, False])
indicator = tf.constant([[True, False, True]])
sampler = (balanced_positive_negative_sampler.
BalancedPositiveNegativeSampler())
with self.assertRaises(ValueError):
sampler.subsample(indicator, 64, labels)
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides functions to batch a dictionary of input tensors."""
import collections
import tensorflow as tf
from object_detection.core import prefetcher
class BatchQueue(object):
"""BatchQueue class.
This class creates a batch queue to asynchronously enqueue tensors_dict.
It also adds a FIFO prefetcher so that the batches are readily available
  for the consumers. Dequeue ops for a BatchQueue object can be created via
  the dequeue method, which evaluates to a batch of tensor_dict.
Example input pipeline with batching:
------------------------------------
key, string_tensor = slim.parallel_reader.parallel_read(...)
tensor_dict = decoder.decode(string_tensor)
tensor_dict = preprocessor.preprocess(tensor_dict, ...)
batch_queue = batcher.BatchQueue(tensor_dict,
batch_size=32,
batch_queue_capacity=2000,
num_batch_queue_threads=8,
prefetch_queue_capacity=20)
tensor_dict = batch_queue.dequeue()
outputs = Model(tensor_dict)
...
-----------------------------------
Notes:
-----
This class batches tensors of unequal sizes by zero padding and unpadding
them after generating a batch. This can be computationally expensive when
batching tensors (such as images) that are of vastly different sizes. So it is
recommended that the shapes of such tensors be fully defined in tensor_dict
while other lightweight tensors such as bounding box corners and class labels
can be of varying sizes. Use either crop or resize operations to fully define
the shape of an image in tensor_dict.
It is also recommended to perform any preprocessing operations on tensors
before passing to BatchQueue and subsequently calling the Dequeue method.
Another caveat is that this class does not read the last batch if it is not
full. The current implementation makes it hard to support that use case. So,
for evaluation, when it is critical to run all the examples through your
network use the input pipeline example mentioned in core/prefetcher.py.
"""
def __init__(self, tensor_dict, batch_size, batch_queue_capacity,
num_batch_queue_threads, prefetch_queue_capacity):
"""Constructs a batch queue holding tensor_dict.
Args:
tensor_dict: dictionary of tensors to batch.
batch_size: batch size.
batch_queue_capacity: max capacity of the queue from which the tensors are
batched.
num_batch_queue_threads: number of threads to use for batching.
prefetch_queue_capacity: max capacity of the queue used to prefetch
assembled batches.
"""
# Remember static shapes to set shapes of batched tensors.
static_shapes = collections.OrderedDict(
        {key: tensor.get_shape() for key, tensor in tensor_dict.items()})
# Remember runtime shapes to unpad tensors after batching.
runtime_shapes = collections.OrderedDict(
{(key, 'runtime_shapes'): tf.shape(tensor)
         for key, tensor in tensor_dict.items()})
all_tensors = tensor_dict
all_tensors.update(runtime_shapes)
batched_tensors = tf.train.batch(
all_tensors,
capacity=batch_queue_capacity,
batch_size=batch_size,
dynamic_pad=True,
num_threads=num_batch_queue_threads)
self._queue = prefetcher.prefetch(batched_tensors,
prefetch_queue_capacity)
self._static_shapes = static_shapes
self._batch_size = batch_size
def dequeue(self):
"""Dequeues a batch of tensor_dict from the BatchQueue.
TODO: use allow_smaller_final_batch to allow running over the whole eval set
Returns:
A list of tensor_dicts of the requested batch_size.
"""
batched_tensors = self._queue.dequeue()
# Separate input tensors from tensors containing their runtime shapes.
tensors = {}
shapes = {}
    for key, batched_tensor in batched_tensors.items():
unbatched_tensor_list = tf.unstack(batched_tensor)
for i, unbatched_tensor in enumerate(unbatched_tensor_list):
if isinstance(key, tuple) and key[1] == 'runtime_shapes':
shapes[(key[0], i)] = unbatched_tensor
else:
tensors[(key, i)] = unbatched_tensor
# Undo that padding using shapes and create a list of size `batch_size` that
# contains tensor dictionaries.
tensor_dict_list = []
batch_size = self._batch_size
for batch_id in range(batch_size):
tensor_dict = {}
for key in self._static_shapes:
tensor_dict[key] = tf.slice(tensors[(key, batch_id)],
tf.zeros_like(shapes[(key, batch_id)]),
shapes[(key, batch_id)])
tensor_dict[key].set_shape(self._static_shapes[key])
tensor_dict_list.append(tensor_dict)
return tensor_dict_list
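# Illustrative sketch (not part of the original module): the slicing in
# dequeue() above undoes the zero padding that tf.train.batch(dynamic_pad=True)
# added, by slicing each padded tensor back to its recorded runtime shape.
# A minimal standalone version of that trick:
def _example_unpad_with_runtime_shape():
  original = tf.reshape(tf.range(6), [2, 3])
  padded = tf.pad(original, [[0, 2], [0, 2]])  # zero-pad out to [4, 5]
  runtime_shape = tf.constant([2, 3])          # shape recorded before padding
  # Slicing from the origin with the runtime shape recovers the original.
  return tf.slice(padded, tf.zeros_like(runtime_shape), runtime_shape)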
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for object_detection.core.batcher."""
import numpy as np
import tensorflow as tf
from object_detection.core import batcher
slim = tf.contrib.slim
class BatcherTest(tf.test.TestCase):
def test_batch_and_unpad_2d_tensors_of_different_sizes_in_1st_dimension(self):
with self.test_session() as sess:
batch_size = 3
num_batches = 2
examples = tf.Variable(tf.constant(2, dtype=tf.int32))
counter = examples.count_up_to(num_batches * batch_size + 2)
boxes = tf.tile(
tf.reshape(tf.range(4), [1, 4]), tf.stack([counter, tf.constant(1)]))
batch_queue = batcher.BatchQueue(
tensor_dict={'boxes': boxes},
batch_size=batch_size,
batch_queue_capacity=100,
num_batch_queue_threads=1,
prefetch_queue_capacity=100)
batch = batch_queue.dequeue()
for tensor_dict in batch:
for tensor in tensor_dict.values():
self.assertAllEqual([None, 4], tensor.get_shape().as_list())
      tf.global_variables_initializer().run()
with slim.queues.QueueRunners(sess):
i = 2
for _ in range(num_batches):
batch_np = sess.run(batch)
for tensor_dict in batch_np:
for tensor in tensor_dict.values():
self.assertAllEqual(tensor, np.tile(np.arange(4), (i, 1)))
i += 1
with self.assertRaises(tf.errors.OutOfRangeError):
sess.run(batch)
def test_batch_and_unpad_2d_tensors_of_different_sizes_in_all_dimensions(
self):
with self.test_session() as sess:
batch_size = 3
num_batches = 2
examples = tf.Variable(tf.constant(2, dtype=tf.int32))
counter = examples.count_up_to(num_batches * batch_size + 2)
image = tf.reshape(
tf.range(counter * counter), tf.stack([counter, counter]))
batch_queue = batcher.BatchQueue(
tensor_dict={'image': image},
batch_size=batch_size,
batch_queue_capacity=100,
num_batch_queue_threads=1,
prefetch_queue_capacity=100)
batch = batch_queue.dequeue()
for tensor_dict in batch:
for tensor in tensor_dict.values():
self.assertAllEqual([None, None], tensor.get_shape().as_list())
      tf.global_variables_initializer().run()
with slim.queues.QueueRunners(sess):
i = 2
for _ in range(num_batches):
batch_np = sess.run(batch)
for tensor_dict in batch_np:
for tensor in tensor_dict.values():
self.assertAllEqual(tensor, np.arange(i * i).reshape((i, i)))
i += 1
with self.assertRaises(tf.errors.OutOfRangeError):
sess.run(batch)
def test_batch_and_unpad_2d_tensors_of_same_size_in_all_dimensions(self):
with self.test_session() as sess:
batch_size = 3
num_batches = 2
examples = tf.Variable(tf.constant(1, dtype=tf.int32))
counter = examples.count_up_to(num_batches * batch_size + 1)
image = tf.reshape(tf.range(1, 13), [4, 3]) * counter
batch_queue = batcher.BatchQueue(
tensor_dict={'image': image},
batch_size=batch_size,
batch_queue_capacity=100,
num_batch_queue_threads=1,
prefetch_queue_capacity=100)
batch = batch_queue.dequeue()
for tensor_dict in batch:
for tensor in tensor_dict.values():
self.assertAllEqual([4, 3], tensor.get_shape().as_list())
      tf.global_variables_initializer().run()
with slim.queues.QueueRunners(sess):
i = 1
for _ in range(num_batches):
batch_np = sess.run(batch)
for tensor_dict in batch_np:
for tensor in tensor_dict.values():
self.assertAllEqual(tensor, np.arange(1, 13).reshape((4, 3)) * i)
i += 1
with self.assertRaises(tf.errors.OutOfRangeError):
sess.run(batch)
def test_batcher_when_batch_size_is_one(self):
with self.test_session() as sess:
batch_size = 1
num_batches = 2
examples = tf.Variable(tf.constant(2, dtype=tf.int32))
counter = examples.count_up_to(num_batches * batch_size + 2)
image = tf.reshape(
tf.range(counter * counter), tf.stack([counter, counter]))
batch_queue = batcher.BatchQueue(
tensor_dict={'image': image},
batch_size=batch_size,
batch_queue_capacity=100,
num_batch_queue_threads=1,
prefetch_queue_capacity=100)
batch = batch_queue.dequeue()
for tensor_dict in batch:
for tensor in tensor_dict.values():
self.assertAllEqual([None, None], tensor.get_shape().as_list())
      tf.global_variables_initializer().run()
with slim.queues.QueueRunners(sess):
i = 2
for _ in range(num_batches):
batch_np = sess.run(batch)
for tensor_dict in batch_np:
for tensor in tensor_dict.values():
self.assertAllEqual(tensor, np.arange(i * i).reshape((i, i)))
i += 1
with self.assertRaises(tf.errors.OutOfRangeError):
sess.run(batch)
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base box coder.
Box coders convert between coordinate frames, namely image-centric
(with (0,0) on the top left of image) and anchor-centric (with (0,0) being
defined by a specific anchor).
Users of a BoxCoder can call two methods:
encode: which encodes a box with respect to a given anchor
(or rather, a tensor of boxes wrt a corresponding tensor of anchors) and
decode: which inverts this encoding with a decode operation.
In both cases, the arguments are assumed to be in 1-1 correspondence already;
it is not the job of a BoxCoder to perform matching.
"""
from abc import ABCMeta
from abc import abstractmethod
from abc import abstractproperty
import tensorflow as tf
# Box coder types.
FASTER_RCNN = 'faster_rcnn'
KEYPOINT = 'keypoint'
MEAN_STDDEV = 'mean_stddev'
SQUARE = 'square'
class BoxCoder(object):
"""Abstract base class for box coder."""
__metaclass__ = ABCMeta
@abstractproperty
def code_size(self):
"""Return the size of each code.
This number is a constant and should agree with the output of the `encode`
op (e.g. if rel_codes is the output of self.encode(...), then it should have
shape [N, code_size()]). This abstractproperty should be overridden by
implementations.
Returns:
an integer constant
"""
pass
def encode(self, boxes, anchors):
"""Encode a box list relative to an anchor collection.
Args:
boxes: BoxList holding N boxes to be encoded
anchors: BoxList of N anchors
Returns:
a tensor representing N relative-encoded boxes
"""
with tf.name_scope('Encode'):
return self._encode(boxes, anchors)
def decode(self, rel_codes, anchors):
"""Decode boxes that are encoded relative to an anchor collection.
Args:
rel_codes: a tensor representing N relative-encoded boxes
anchors: BoxList of anchors
Returns:
boxlist: BoxList holding N boxes encoded in the ordinary way (i.e.,
with corners y_min, x_min, y_max, x_max)
"""
with tf.name_scope('Decode'):
return self._decode(rel_codes, anchors)
@abstractmethod
def _encode(self, boxes, anchors):
"""Method to be overriden by implementations.
Args:
boxes: BoxList holding N boxes to be encoded
anchors: BoxList of N anchors
Returns:
a tensor representing N relative-encoded boxes
"""
pass
@abstractmethod
def _decode(self, rel_codes, anchors):
"""Method to be overriden by implementations.
Args:
rel_codes: a tensor representing N relative-encoded boxes
anchors: BoxList of anchors
Returns:
boxlist: BoxList holding N boxes encoded in the ordinary way (i.e.,
with corners y_min, x_min, y_max, x_max)
"""
pass
def batch_decode(encoded_boxes, box_coder, anchors):
"""Decode a batch of encoded boxes.
This op takes a batch of encoded bounding boxes and transforms
them to a batch of bounding boxes specified by their corners in
the order of [y_min, x_min, y_max, x_max].
Args:
encoded_boxes: a float32 tensor of shape [batch_size, num_anchors,
code_size] representing the location of the objects.
box_coder: a BoxCoder object.
anchors: a BoxList of anchors used to encode `encoded_boxes`.
Returns:
    decoded_boxes: a float32 tensor of shape [batch_size, num_anchors,
      code_size] representing the corners of the objects in the order
      of [y_min, x_min, y_max, x_max].
  Raises:
    ValueError: if batch sizes of the inputs are inconsistent, or if
      the number of anchors inferred from encoded_boxes and anchors is
      inconsistent.
"""
encoded_boxes.get_shape().assert_has_rank(3)
if encoded_boxes.get_shape()[1].value != anchors.num_boxes_static():
raise ValueError('The number of anchors inferred from encoded_boxes'
' and anchors are inconsistent: shape[1] of encoded_boxes'
' %s should be equal to the number of anchors: %s.' %
(encoded_boxes.get_shape()[1].value,
anchors.num_boxes_static()))
decoded_boxes = tf.stack([
box_coder.decode(boxes, anchors).get()
for boxes in tf.unstack(encoded_boxes)
])
return decoded_boxes
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for object_detection.core.box_coder."""
import tensorflow as tf
from object_detection.core import box_coder
from object_detection.core import box_list
class MockBoxCoder(box_coder.BoxCoder):
"""Test BoxCoder that encodes/decodes using the multiply-by-two function."""
def code_size(self):
return 4
def _encode(self, boxes, anchors):
return 2.0 * boxes.get()
def _decode(self, rel_codes, anchors):
return box_list.BoxList(rel_codes / 2.0)
class BoxCoderTest(tf.test.TestCase):
def test_batch_decode(self):
mock_anchor_corners = tf.constant(
[[0, 0.1, 0.2, 0.3], [0.2, 0.4, 0.4, 0.6]], tf.float32)
mock_anchors = box_list.BoxList(mock_anchor_corners)
mock_box_coder = MockBoxCoder()
expected_boxes = [[[0.0, 0.1, 0.5, 0.6], [0.5, 0.6, 0.7, 0.8]],
[[0.1, 0.2, 0.3, 0.4], [0.7, 0.8, 0.9, 1.0]]]
encoded_boxes_list = [mock_box_coder.encode(
box_list.BoxList(tf.constant(boxes)), mock_anchors)
for boxes in expected_boxes]
encoded_boxes = tf.stack(encoded_boxes_list)
decoded_boxes = box_coder.batch_decode(
encoded_boxes, mock_box_coder, mock_anchors)
with self.test_session() as sess:
decoded_boxes_result = sess.run(decoded_boxes)
self.assertAllClose(expected_boxes, decoded_boxes_result)
if __name__ == '__main__':
tf.test.main()
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bounding Box List definition.
BoxList represents a list of bounding boxes as tensorflow
tensors, where each bounding box is represented as a row of 4 numbers,
[y_min, x_min, y_max, x_max]. It is assumed that all bounding boxes
within a given list correspond to a single image. See also
box_list_ops.py for common box related operations (such as area, iou, etc).
Optionally, users can add additional related fields (such as weights).
We assume the following things to be true about fields:
* they correspond to boxes in the box_list along the 0th dimension
* they have inferrable rank at graph construction time
* all dimensions except for possibly the 0th can be inferred
(i.e., not None) at graph construction time.
Some other notes:
* Following tensorflow conventions, we use height, width ordering,
and correspondingly, y,x (or ymin, xmin, ymax, xmax) ordering
* Tensors are always provided as (flat) [N, 4] tensors.
"""
import tensorflow as tf
class BoxList(object):
"""Box collection."""
def __init__(self, boxes):
"""Constructs box collection.
Args:
boxes: a tensor of shape [N, 4] representing box corners
Raises:
ValueError: if invalid dimensions for bbox data or if bbox data is not in
float32 format.
"""
if len(boxes.get_shape()) != 2 or boxes.get_shape()[-1] != 4:
raise ValueError('Invalid dimensions for box data.')
if boxes.dtype != tf.float32:
raise ValueError('Invalid tensor type: should be tf.float32')
self.data = {'boxes': boxes}
def num_boxes(self):
"""Returns number of boxes held in collection.
Returns:
a tensor representing the number of boxes held in the collection.
"""
return tf.shape(self.data['boxes'])[0]
def num_boxes_static(self):
"""Returns number of boxes held in collection.
This number is inferred at graph construction time rather than run-time.
Returns:
Number of boxes held in collection (integer) or None if this is not
inferrable at graph construction time.
"""
return self.data['boxes'].get_shape()[0].value
def get_all_fields(self):
"""Returns all fields."""
return self.data.keys()
def get_extra_fields(self):
"""Returns all non-box fields (i.e., everything not named 'boxes')."""
return [k for k in self.data.keys() if k != 'boxes']
def add_field(self, field, field_data):
"""Add field to box list.
This method can be used to add related box data such as
weights/labels, etc.
Args:
field: a string key to access the data via `get`
field_data: a tensor containing the data to store in the BoxList
"""
self.data[field] = field_data
  def has_field(self, field):
    """Returns whether the box collection contains the given field."""
    return field in self.data
def get(self):
"""Convenience function for accessing box coordinates.
Returns:
a tensor with shape [N, 4] representing box coordinates.
"""
return self.get_field('boxes')
def set(self, boxes):
"""Convenience function for setting box coordinates.
Args:
boxes: a tensor of shape [N, 4] representing box corners
Raises:
ValueError: if invalid dimensions for bbox data
"""
if len(boxes.get_shape()) != 2 or boxes.get_shape()[-1] != 4:
raise ValueError('Invalid dimensions for box data.')
self.data['boxes'] = boxes
def get_field(self, field):
"""Accesses a box collection and associated fields.
This function returns specified field with object; if no field is specified,
it returns the box coordinates.
Args:
field: this optional string parameter can be used to specify
a related field to be accessed.
Returns:
a tensor representing the box collection or an associated field.
Raises:
ValueError: if invalid field
"""
if not self.has_field(field):
raise ValueError('field ' + str(field) + ' does not exist')
return self.data[field]
def set_field(self, field, value):
"""Sets the value of a field.
Updates the field of a box_list with a given value.
Args:
field: (string) name of the field to set value.
value: the value to assign to the field.
Raises:
ValueError: if the box_list does not have specified field.
"""
if not self.has_field(field):
raise ValueError('field %s does not exist' % field)
self.data[field] = value
def get_center_coordinates_and_sizes(self, scope=None):
"""Computes the center coordinates, height and width of the boxes.
Args:
scope: name scope of the function.
Returns:
a list of 4 1-D tensors [ycenter, xcenter, height, width].
"""
with tf.name_scope(scope, 'get_center_coordinates_and_sizes'):
box_corners = self.get()
ymin, xmin, ymax, xmax = tf.unstack(tf.transpose(box_corners))
width = xmax - xmin
height = ymax - ymin
ycenter = ymin + height / 2.
xcenter = xmin + width / 2.
return [ycenter, xcenter, height, width]
def transpose_coordinates(self, scope=None):
"""Transpose the coordinate representation in a boxlist.
Args:
scope: name scope of the function.
"""
with tf.name_scope(scope, 'transpose_coordinates'):
y_min, x_min, y_max, x_max = tf.split(
value=self.get(), num_or_size_splits=4, axis=1)
self.set(tf.concat([x_min, y_min, x_max, y_max], 1))
def as_tensor_dict(self, fields=None):
"""Retrieves specified fields as a dictionary of tensors.
Args:
fields: (optional) list of fields to return in the dictionary.
If None (default), all fields are returned.
Returns:
tensor_dict: A dictionary of tensors specified by fields.
Raises:
ValueError: if specified field is not contained in boxlist.
"""
tensor_dict = {}
if fields is None:
fields = self.get_all_fields()
for field in fields:
if not self.has_field(field):
raise ValueError('boxlist must contain all specified fields')
tensor_dict[field] = self.get_field(field)
return tensor_dict
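# Illustrative usage sketch (not part of the original module): construct a
# BoxList, attach a 'scores' field, and read back center/size coordinates.
# Everything below uses only the class defined above.
def _example_box_list_usage():
  boxes = tf.constant([[0.0, 0.0, 0.5, 0.5],
                       [0.2, 0.2, 1.0, 0.8]], tf.float32)
  boxlist = BoxList(boxes)
  boxlist.add_field('scores', tf.constant([0.9, 0.4]))
  # For the first box: ycenter = 0.25, xcenter = 0.25, height = 0.5,
  # width = 0.5.
  return boxlist.get_center_coordinates_and_sizes()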
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bounding Box List operations.
Example box operations that are supported:
* areas: compute bounding box areas
* iou: pairwise intersection-over-union scores
* sq_dist: pairwise distances between bounding boxes
Whenever box_list_ops functions output a BoxList, the fields of the incoming
BoxList are retained unless documented otherwise.
"""
import tensorflow as tf
from object_detection.core import box_list
from object_detection.utils import shape_utils
class SortOrder(object):
"""Enum class for sort order.
Attributes:
ascend: ascend order.
descend: descend order.
"""
ascend = 1
descend = 2
def area(boxlist, scope=None):
"""Computes area of boxes.
Args:
boxlist: BoxList holding N boxes
scope: name scope.
Returns:
a tensor with shape [N] representing box areas.
"""
with tf.name_scope(scope, 'Area'):
y_min, x_min, y_max, x_max = tf.split(
value=boxlist.get(), num_or_size_splits=4, axis=1)
return tf.squeeze((y_max - y_min) * (x_max - x_min), [1])
def height_width(boxlist, scope=None):
"""Computes height and width of boxes in boxlist.
Args:
boxlist: BoxList holding N boxes
scope: name scope.
Returns:
Height: A tensor with shape [N] representing box heights.
Width: A tensor with shape [N] representing box widths.
"""
with tf.name_scope(scope, 'HeightWidth'):
y_min, x_min, y_max, x_max = tf.split(
value=boxlist.get(), num_or_size_splits=4, axis=1)
return tf.squeeze(y_max - y_min, [1]), tf.squeeze(x_max - x_min, [1])
def scale(boxlist, y_scale, x_scale, scope=None):
"""scale box coordinates in x and y dimensions.
Args:
boxlist: BoxList holding N boxes
y_scale: (float) scalar tensor
x_scale: (float) scalar tensor
scope: name scope.
Returns:
boxlist: BoxList holding N boxes
"""
with tf.name_scope(scope, 'Scale'):
y_scale = tf.cast(y_scale, tf.float32)
x_scale = tf.cast(x_scale, tf.float32)
y_min, x_min, y_max, x_max = tf.split(
value=boxlist.get(), num_or_size_splits=4, axis=1)
y_min = y_scale * y_min
y_max = y_scale * y_max
x_min = x_scale * x_min
x_max = x_scale * x_max
scaled_boxlist = box_list.BoxList(
tf.concat([y_min, x_min, y_max, x_max], 1))
return _copy_extra_fields(scaled_boxlist, boxlist)
def clip_to_window(boxlist, window, filter_nonoverlapping=True, scope=None):
"""Clip bounding boxes to a window.
This op clips any input bounding boxes (represented by bounding box
corners) to a window, optionally filtering out boxes that do not
overlap at all with the window.
Args:
boxlist: BoxList holding M_in boxes
window: a tensor of shape [4] representing the [y_min, x_min, y_max, x_max]
window to which the op should clip boxes.
filter_nonoverlapping: whether to filter out boxes that do not overlap at
all with the window.
scope: name scope.
Returns:
a BoxList holding M_out boxes where M_out <= M_in
"""
with tf.name_scope(scope, 'ClipToWindow'):
y_min, x_min, y_max, x_max = tf.split(
value=boxlist.get(), num_or_size_splits=4, axis=1)
win_y_min, win_x_min, win_y_max, win_x_max = tf.unstack(window)
y_min_clipped = tf.maximum(tf.minimum(y_min, win_y_max), win_y_min)
y_max_clipped = tf.maximum(tf.minimum(y_max, win_y_max), win_y_min)
x_min_clipped = tf.maximum(tf.minimum(x_min, win_x_max), win_x_min)
x_max_clipped = tf.maximum(tf.minimum(x_max, win_x_max), win_x_min)
clipped = box_list.BoxList(
tf.concat([y_min_clipped, x_min_clipped, y_max_clipped, x_max_clipped],
1))
clipped = _copy_extra_fields(clipped, boxlist)
if filter_nonoverlapping:
areas = area(clipped)
nonzero_area_indices = tf.cast(
tf.reshape(tf.where(tf.greater(areas, 0.0)), [-1]), tf.int32)
clipped = gather(clipped, nonzero_area_indices)
return clipped
def prune_outside_window(boxlist, window, scope=None):
"""Prunes bounding boxes that fall outside a given window.
This function prunes bounding boxes that even partially fall outside the given
window. See also clip_to_window which only prunes bounding boxes that fall
completely outside the window, and clips any bounding boxes that partially
overflow.
Args:
boxlist: a BoxList holding M_in boxes.
window: a float tensor of shape [4] representing [ymin, xmin, ymax, xmax]
of the window
scope: name scope.
Returns:
pruned_corners: a tensor with shape [M_out, 4] where M_out <= M_in
valid_indices: a tensor with shape [M_out] indexing the valid bounding boxes
in the input tensor.
"""
with tf.name_scope(scope, 'PruneOutsideWindow'):
y_min, x_min, y_max, x_max = tf.split(
value=boxlist.get(), num_or_size_splits=4, axis=1)
win_y_min, win_x_min, win_y_max, win_x_max = tf.unstack(window)
coordinate_violations = tf.concat([
tf.less(y_min, win_y_min), tf.less(x_min, win_x_min),
tf.greater(y_max, win_y_max), tf.greater(x_max, win_x_max)
], 1)
valid_indices = tf.reshape(
tf.where(tf.logical_not(tf.reduce_any(coordinate_violations, 1))), [-1])
return gather(boxlist, valid_indices), valid_indices
def prune_completely_outside_window(boxlist, window, scope=None):
"""Prunes bounding boxes that fall completely outside of the given window.
The function clip_to_window prunes bounding boxes that fall
completely outside the window, but also clips any bounding boxes that
partially overflow. This function does not clip partially overflowing boxes.
Args:
boxlist: a BoxList holding M_in boxes.
window: a float tensor of shape [4] representing [ymin, xmin, ymax, xmax]
of the window
scope: name scope.
Returns:
pruned_corners: a tensor with shape [M_out, 4] where M_out <= M_in
valid_indices: a tensor with shape [M_out] indexing the valid bounding boxes
in the input tensor.
"""
  with tf.name_scope(scope, 'PruneCompletelyOutsideWindow'):
y_min, x_min, y_max, x_max = tf.split(
value=boxlist.get(), num_or_size_splits=4, axis=1)
win_y_min, win_x_min, win_y_max, win_x_max = tf.unstack(window)
coordinate_violations = tf.concat([
tf.greater_equal(y_min, win_y_max), tf.greater_equal(x_min, win_x_max),
tf.less_equal(y_max, win_y_min), tf.less_equal(x_max, win_x_min)
], 1)
valid_indices = tf.reshape(
tf.where(tf.logical_not(tf.reduce_any(coordinate_violations, 1))), [-1])
return gather(boxlist, valid_indices), valid_indices
def intersection(boxlist1, boxlist2, scope=None):
"""Compute pairwise intersection areas between boxes.
Args:
boxlist1: BoxList holding N boxes
boxlist2: BoxList holding M boxes
scope: name scope.
Returns:
a tensor with shape [N, M] representing pairwise intersections
"""
with tf.name_scope(scope, 'Intersection'):
y_min1, x_min1, y_max1, x_max1 = tf.split(
value=boxlist1.get(), num_or_size_splits=4, axis=1)
y_min2, x_min2, y_max2, x_max2 = tf.split(
value=boxlist2.get(), num_or_size_splits=4, axis=1)
all_pairs_min_ymax = tf.minimum(y_max1, tf.transpose(y_max2))
all_pairs_max_ymin = tf.maximum(y_min1, tf.transpose(y_min2))
intersect_heights = tf.maximum(0.0, all_pairs_min_ymax - all_pairs_max_ymin)
all_pairs_min_xmax = tf.minimum(x_max1, tf.transpose(x_max2))
all_pairs_max_xmin = tf.maximum(x_min1, tf.transpose(x_min2))
intersect_widths = tf.maximum(0.0, all_pairs_min_xmax - all_pairs_max_xmin)
return intersect_heights * intersect_widths
def matched_intersection(boxlist1, boxlist2, scope=None):
"""Compute intersection areas between corresponding boxes in two boxlists.
Args:
boxlist1: BoxList holding N boxes
boxlist2: BoxList holding N boxes
scope: name scope.
Returns:
a tensor with shape [N] representing pairwise intersections
"""
with tf.name_scope(scope, 'MatchedIntersection'):
y_min1, x_min1, y_max1, x_max1 = tf.split(
value=boxlist1.get(), num_or_size_splits=4, axis=1)
y_min2, x_min2, y_max2, x_max2 = tf.split(
value=boxlist2.get(), num_or_size_splits=4, axis=1)
min_ymax = tf.minimum(y_max1, y_max2)
max_ymin = tf.maximum(y_min1, y_min2)
intersect_heights = tf.maximum(0.0, min_ymax - max_ymin)
min_xmax = tf.minimum(x_max1, x_max2)
max_xmin = tf.maximum(x_min1, x_min2)
intersect_widths = tf.maximum(0.0, min_xmax - max_xmin)
return tf.reshape(intersect_heights * intersect_widths, [-1])
def iou(boxlist1, boxlist2, scope=None):
"""Computes pairwise intersection-over-union between box collections.
Args:
boxlist1: BoxList holding N boxes
boxlist2: BoxList holding M boxes
scope: name scope.
Returns:
a tensor with shape [N, M] representing pairwise iou scores.
"""
with tf.name_scope(scope, 'IOU'):
intersections = intersection(boxlist1, boxlist2)
areas1 = area(boxlist1)
areas2 = area(boxlist2)
unions = (
tf.expand_dims(areas1, 1) + tf.expand_dims(areas2, 0) - intersections)
return tf.where(
tf.equal(intersections, 0.0),
tf.zeros_like(intersections), tf.truediv(intersections, unions))
def matched_iou(boxlist1, boxlist2, scope=None):
"""Compute intersection-over-union between corresponding boxes in boxlists.
Args:
boxlist1: BoxList holding N boxes
boxlist2: BoxList holding N boxes
scope: name scope.
Returns:
a tensor with shape [N] representing pairwise iou scores.
"""
with tf.name_scope(scope, 'MatchedIOU'):
intersections = matched_intersection(boxlist1, boxlist2)
areas1 = area(boxlist1)
areas2 = area(boxlist2)
unions = areas1 + areas2 - intersections
return tf.where(
tf.equal(intersections, 0.0),
tf.zeros_like(intersections), tf.truediv(intersections, unions))
def ioa(boxlist1, boxlist2, scope=None):
"""Computes pairwise intersection-over-area between box collections.
intersection-over-area (IOA) between two boxes box1 and box2 is defined as
their intersection area over box2's area. Note that ioa is not symmetric,
that is, ioa(box1, box2) != ioa(box2, box1).
Args:
boxlist1: BoxList holding N boxes
boxlist2: BoxList holding M boxes
scope: name scope.
Returns:
a tensor with shape [N, M] representing pairwise ioa scores.
"""
with tf.name_scope(scope, 'IOA'):
intersections = intersection(boxlist1, boxlist2)
areas = tf.expand_dims(area(boxlist2), 0)
return tf.truediv(intersections, areas)
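# Illustrative sketch (not part of the original module): a worked IOU/IOA
# example for two unit-area boxes overlapping by half a unit of width.
def _example_iou_and_ioa():
  boxlist1 = box_list.BoxList(tf.constant([[0.0, 0.0, 1.0, 1.0]]))
  boxlist2 = box_list.BoxList(tf.constant([[0.0, 0.5, 1.0, 1.5]]))
  # intersection = 0.5, union = 1.0 + 1.0 - 0.5 = 1.5, so IOU = 1/3.
  iou_scores = iou(boxlist1, boxlist2)
  # IOA divides by the second box's area instead: 0.5 / 1.0 = 0.5, and is
  # not symmetric in general.
  ioa_scores = ioa(boxlist1, boxlist2)
  return iou_scores, ioa_scores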
def prune_non_overlapping_boxes(
boxlist1, boxlist2, min_overlap=0.0, scope=None):
"""Prunes the boxes in boxlist1 that overlap less than thresh with boxlist2.
  For each box in boxlist1, we want its IOA to be more than min_overlap with
  at least one of the boxes in boxlist2. If it is not, we remove it.
Args:
boxlist1: BoxList holding N boxes.
boxlist2: BoxList holding M boxes.
min_overlap: Minimum required overlap between boxes, to count them as
overlapping.
scope: name scope.
Returns:
new_boxlist1: A pruned boxlist with size [N', 4].
keep_inds: A tensor with shape [N'] indexing kept bounding boxes in the
first input BoxList `boxlist1`.
"""
with tf.name_scope(scope, 'PruneNonOverlappingBoxes'):
ioa_ = ioa(boxlist2, boxlist1) # [M, N] tensor
ioa_ = tf.reduce_max(ioa_, reduction_indices=[0]) # [N] tensor
keep_bool = tf.greater_equal(ioa_, tf.constant(min_overlap))
keep_inds = tf.squeeze(tf.where(keep_bool), squeeze_dims=[1])
new_boxlist1 = gather(boxlist1, keep_inds)
return new_boxlist1, keep_inds
def prune_small_boxes(boxlist, min_side, scope=None):
"""Prunes small boxes in the boxlist which have a side smaller than min_side.
Args:
boxlist: BoxList holding N boxes.
min_side: Minimum width AND height of box to survive pruning.
scope: name scope.
Returns:
A pruned boxlist.
"""
with tf.name_scope(scope, 'PruneSmallBoxes'):
height, width = height_width(boxlist)
is_valid = tf.logical_and(tf.greater_equal(width, min_side),
tf.greater_equal(height, min_side))
return gather(boxlist, tf.reshape(tf.where(is_valid), [-1]))
def change_coordinate_frame(boxlist, window, scope=None):
"""Change coordinate frame of the boxlist to be relative to window's frame.
Given a window of the form [ymin, xmin, ymax, xmax],
changes bounding box coordinates from boxlist to be relative to this window
(e.g., the min corner maps to (0,0) and the max corner maps to (1,1)).
An example use case is data augmentation: where we are given groundtruth
boxes (boxlist) and would like to randomly crop the image to some
window (window). In this case we need to change the coordinate frame of
each groundtruth box to be relative to this new window.
Args:
boxlist: A BoxList object holding N boxes.
window: A rank 1 tensor [4].
scope: name scope.
Returns:
Returns a BoxList object with N boxes.
"""
with tf.name_scope(scope, 'ChangeCoordinateFrame'):
win_height = window[2] - window[0]
win_width = window[3] - window[1]
boxlist_new = scale(box_list.BoxList(
boxlist.get() - [window[0], window[1], window[0], window[1]]),
1.0 / win_height, 1.0 / win_width)
boxlist_new = _copy_extra_fields(boxlist_new, boxlist)
return boxlist_new
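# Illustrative sketch (not part of the original module): re-expressing a box
# relative to a crop window, as in the data-augmentation use case described
# in the docstring above.
def _example_change_coordinate_frame():
  boxlist = box_list.BoxList(tf.constant([[0.25, 0.25, 0.5, 0.5]]))
  window = tf.constant([0.25, 0.25, 0.75, 0.75])
  # The window's min corner maps to (0, 0) and its max corner to (1, 1),
  # so the box becomes [0.0, 0.0, 0.5, 0.5] in the new frame.
  return change_coordinate_frame(boxlist, window)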
def sq_dist(boxlist1, boxlist2, scope=None):
"""Computes the pairwise squared distances between box corners.
This op treats each box as if it were a point in a 4d Euclidean space and
computes pairwise squared distances.
Mathematically, we are given two matrices of box coordinates X and Y,
where X(i,:) is the i'th row of X, containing the 4 numbers defining the
corners of the i'th box in boxlist1. Similarly Y(j,:) corresponds to
boxlist2. We compute
Z(i,j) = ||X(i,:) - Y(j,:)||^2
= ||X(i,:)||^2 + ||Y(j,:)||^2 - 2 X(i,:)' * Y(j,:),
Args:
boxlist1: BoxList holding N boxes
boxlist2: BoxList holding M boxes
scope: name scope.
Returns:
a tensor with shape [N, M] representing pairwise distances
"""
with tf.name_scope(scope, 'SqDist'):
sqnorm1 = tf.reduce_sum(tf.square(boxlist1.get()), 1, keep_dims=True)
sqnorm2 = tf.reduce_sum(tf.square(boxlist2.get()), 1, keep_dims=True)
innerprod = tf.matmul(boxlist1.get(), boxlist2.get(),
transpose_a=False, transpose_b=True)
return sqnorm1 + tf.transpose(sqnorm2) - 2.0 * innerprod
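# Illustrative sketch (not part of the original module): numeric check of the
# expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y used by sq_dist above,
# for x = [0, 0, 1, 1] and y = [0, 0, 2, 2].
def _example_sq_dist():
  boxlist1 = box_list.BoxList(tf.constant([[0.0, 0.0, 1.0, 1.0]]))
  boxlist2 = box_list.BoxList(tf.constant([[0.0, 0.0, 2.0, 2.0]]))
  # ||x||^2 = 2, ||y||^2 = 8, x.y = 4, so the squared distance is
  # 2 + 8 - 2 * 4 = 2, matching (2 - 1)^2 + (2 - 1)^2.
  return sq_dist(boxlist1, boxlist2)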
def boolean_mask(boxlist, indicator, fields=None, scope=None):
"""Select boxes from BoxList according to indicator and return new BoxList.
`boolean_mask` returns the subset of boxes that are marked as "True" by the
indicator tensor. By default, `boolean_mask` returns boxes corresponding to
the input index list, as well as all additional fields stored in the boxlist
(indexing into the first dimension). However one can optionally only draw
from a subset of fields.
Args:
boxlist: BoxList holding N boxes
indicator: a rank-1 boolean tensor
fields: (optional) list of fields to also gather from. If None (default),
all fields are gathered from. Pass an empty fields list to only gather
the box coordinates.
scope: name scope.
Returns:
subboxlist: a BoxList corresponding to the subset of the input BoxList
specified by indicator
Raises:
ValueError: if `indicator` is not a rank-1 boolean tensor.
"""
with tf.name_scope(scope, 'BooleanMask'):
if indicator.shape.ndims != 1:
raise ValueError('indicator should have rank 1')
if indicator.dtype != tf.bool:
raise ValueError('indicator should be a boolean tensor')
subboxlist = box_list.BoxList(tf.boolean_mask(boxlist.get(), indicator))
if fields is None:
fields = boxlist.get_extra_fields()
for field in fields:
if not boxlist.has_field(field):
raise ValueError('boxlist must contain all specified fields')
subfieldlist = tf.boolean_mask(boxlist.get_field(field), indicator)
subboxlist.add_field(field, subfieldlist)
return subboxlist
def gather(boxlist, indices, fields=None, scope=None):
"""Gather boxes from BoxList according to indices and return new BoxList.
By default, `gather` returns boxes corresponding to the input index list, as
well as all additional fields stored in the boxlist (indexing into the
first dimension). However one can optionally only gather from a
subset of fields.
Args:
boxlist: BoxList holding N boxes
indices: a rank-1 tensor of type int32 / int64
fields: (optional) list of fields to also gather from. If None (default),
all fields are gathered from. Pass an empty fields list to only gather
the box coordinates.
scope: name scope.
Returns:
subboxlist: a BoxList corresponding to the subset of the input BoxList
specified by indices
Raises:
    ValueError: if specified field is not contained in boxlist or if the
      indices are not of type int32 or int64
"""
with tf.name_scope(scope, 'Gather'):
if len(indices.shape.as_list()) != 1:
raise ValueError('indices should have rank 1')
if indices.dtype != tf.int32 and indices.dtype != tf.int64:
raise ValueError('indices should be an int32 / int64 tensor')
subboxlist = box_list.BoxList(tf.gather(boxlist.get(), indices))
if fields is None:
fields = boxlist.get_extra_fields()
for field in fields:
if not boxlist.has_field(field):
raise ValueError('boxlist must contain all specified fields')
subfieldlist = tf.gather(boxlist.get_field(field), indices)
subboxlist.add_field(field, subfieldlist)
return subboxlist
def concatenate(boxlists, fields=None, scope=None):
"""Concatenate list of BoxLists.
This op concatenates a list of input BoxLists into a larger BoxList. It also
handles concatenation of BoxList fields as long as the field tensor shapes
are equal except for the first dimension.
Args:
boxlists: list of BoxList objects
fields: optional list of fields to also concatenate. By default, all
fields from the first BoxList in the list are included in the
concatenation.
scope: name scope.
Returns:
a BoxList with number of boxes equal to
      sum([boxlist.num_boxes() for boxlist in boxlists])
Raises:
ValueError: if boxlists is invalid (i.e., is not a list, is empty, or
contains non BoxList objects), or if requested fields are not contained in
all boxlists
"""
with tf.name_scope(scope, 'Concatenate'):
if not isinstance(boxlists, list):
raise ValueError('boxlists should be a list')
if not boxlists:
raise ValueError('boxlists should have nonzero length')
for boxlist in boxlists:
if not isinstance(boxlist, box_list.BoxList):
raise ValueError('all elements of boxlists should be BoxList objects')
concatenated = box_list.BoxList(
tf.concat([boxlist.get() for boxlist in boxlists], 0))
if fields is None:
fields = boxlists[0].get_extra_fields()
for field in fields:
first_field_shape = boxlists[0].get_field(field).get_shape().as_list()
first_field_shape[0] = -1
if None in first_field_shape:
raise ValueError('field %s must have fully defined shape except for the'
' 0th dimension.' % field)
for boxlist in boxlists:
if not boxlist.has_field(field):
raise ValueError('boxlist must contain all requested fields')
field_shape = boxlist.get_field(field).get_shape().as_list()
field_shape[0] = -1
if field_shape != first_field_shape:
raise ValueError('field %s must have same shape for all boxlists '
'except for the 0th dimension.' % field)
concatenated_field = tf.concat(
[boxlist.get_field(field) for boxlist in boxlists], 0)
concatenated.add_field(field, concatenated_field)
return concatenated
def sort_by_field(boxlist, field, order=SortOrder.descend, scope=None):
"""Sort boxes and associated fields according to a scalar field.
A common use case is reordering the boxes according to descending scores.
Args:
boxlist: BoxList holding N boxes.
field: A BoxList field for sorting and reordering the BoxList.
order: (Optional) descend or ascend. Default is descend.
scope: name scope.
Returns:
sorted_boxlist: A sorted BoxList with the field in the specified order.
Raises:
ValueError: if specified field does not exist
ValueError: if the order is not either descend or ascend
"""
with tf.name_scope(scope, 'SortByField'):
if order != SortOrder.descend and order != SortOrder.ascend:
raise ValueError('Invalid sort order')
field_to_sort = boxlist.get_field(field)
if len(field_to_sort.shape.as_list()) != 1:
raise ValueError('Field should have rank 1')
num_boxes = boxlist.num_boxes()
num_entries = tf.size(field_to_sort)
length_assert = tf.Assert(
tf.equal(num_boxes, num_entries),
['Incorrect field size: actual vs expected.', num_entries, num_boxes])
with tf.control_dependencies([length_assert]):
# TODO: Remove with tf.device when top_k operation runs correctly on GPU.
with tf.device('/cpu:0'):
_, sorted_indices = tf.nn.top_k(field_to_sort, num_boxes, sorted=True)
if order == SortOrder.ascend:
sorted_indices = tf.reverse_v2(sorted_indices, [0])
return gather(boxlist, sorted_indices)
def visualize_boxes_in_image(image, boxlist, normalized=False, scope=None):
"""Overlay bounding box list on image.
Currently this visualization plots a 1 pixel thick red bounding box on top
  of the image. Note that tf.image.draw_bounding_boxes is essentially
  1-indexed.
Args:
image: an image tensor with shape [height, width, 3]
boxlist: a BoxList
normalized: (boolean) specify whether corners are to be interpreted
as absolute coordinates in image space or normalized with respect to the
image size.
scope: name scope.
Returns:
image_and_boxes: an image tensor with shape [height, width, 3]
"""
with tf.name_scope(scope, 'VisualizeBoxesInImage'):
if not normalized:
height, width, _ = tf.unstack(tf.shape(image))
boxlist = scale(boxlist,
1.0 / tf.cast(height, tf.float32),
1.0 / tf.cast(width, tf.float32))
corners = tf.expand_dims(boxlist.get(), 0)
image = tf.expand_dims(image, 0)
return tf.squeeze(tf.image.draw_bounding_boxes(image, corners), [0])
def filter_field_value_equals(boxlist, field, value, scope=None):
"""Filter to keep only boxes with field entries equal to the given value.
Args:
boxlist: BoxList holding N boxes.
field: field name for filtering.
value: scalar value.
scope: name scope.
Returns:
a BoxList holding M boxes where M <= N
Raises:
ValueError: if boxlist not a BoxList object or if it does not have
the specified field.
"""
with tf.name_scope(scope, 'FilterFieldValueEquals'):
if not isinstance(boxlist, box_list.BoxList):
raise ValueError('boxlist must be a BoxList')
if not boxlist.has_field(field):
raise ValueError('boxlist must contain the specified field')
filter_field = boxlist.get_field(field)
gather_index = tf.reshape(tf.where(tf.equal(filter_field, value)), [-1])
return gather(boxlist, gather_index)
def filter_greater_than(boxlist, thresh, scope=None):
"""Filter to keep only boxes with score exceeding a given threshold.
This op keeps the collection of boxes whose corresponding scores are
greater than the input threshold.
TODO: Change function name to FilterScoresGreaterThan
Args:
boxlist: BoxList holding N boxes. Must contain a 'scores' field
representing detection scores.
thresh: scalar threshold
scope: name scope.
Returns:
a BoxList holding M boxes where M <= N
Raises:
ValueError: if boxlist not a BoxList object or if it does not
have a scores field
"""
with tf.name_scope(scope, 'FilterGreaterThan'):
if not isinstance(boxlist, box_list.BoxList):
raise ValueError('boxlist must be a BoxList')
if not boxlist.has_field('scores'):
raise ValueError('input boxlist must have \'scores\' field')
scores = boxlist.get_field('scores')
if len(scores.shape.as_list()) > 2:
raise ValueError('Scores should have rank 1 or 2')
if len(scores.shape.as_list()) == 2 and scores.shape.as_list()[1] != 1:
raise ValueError('Scores should have rank 1 or have shape '
'consistent with [None, 1]')
high_score_indices = tf.cast(tf.reshape(
tf.where(tf.greater(scores, thresh)),
[-1]), tf.int32)
return gather(boxlist, high_score_indices)
def non_max_suppression(boxlist, thresh, max_output_size, scope=None):
"""Non maximum suppression.
This op greedily selects a subset of detection bounding boxes, pruning
away boxes that have high IOU (intersection over union) overlap (> thresh)
with already selected boxes. Note that this only works for a single class ---
to apply NMS to multi-class predictions, use MultiClassNonMaxSuppression.
Args:
boxlist: BoxList holding N boxes. Must contain a 'scores' field
representing detection scores.
thresh: scalar threshold
max_output_size: maximum number of retained boxes
scope: name scope.
Returns:
a BoxList holding M boxes where M <= max_output_size
Raises:
ValueError: if thresh is not in [0, 1]
"""
with tf.name_scope(scope, 'NonMaxSuppression'):
if not 0 <= thresh <= 1.0:
raise ValueError('thresh must be between 0 and 1')
if not isinstance(boxlist, box_list.BoxList):
raise ValueError('boxlist must be a BoxList')
if not boxlist.has_field('scores'):
raise ValueError('input boxlist must have \'scores\' field')
selected_indices = tf.image.non_max_suppression(
boxlist.get(), boxlist.get_field('scores'),
max_output_size, iou_threshold=thresh)
return gather(boxlist, selected_indices)
def _copy_extra_fields(boxlist_to_copy_to, boxlist_to_copy_from):
"""Copies the extra fields of boxlist_to_copy_from to boxlist_to_copy_to.
Args:
boxlist_to_copy_to: BoxList to which extra fields are copied.
boxlist_to_copy_from: BoxList from which fields are copied.
Returns:
boxlist_to_copy_to with extra fields.
"""
for field in boxlist_to_copy_from.get_extra_fields():
boxlist_to_copy_to.add_field(field, boxlist_to_copy_from.get_field(field))
return boxlist_to_copy_to
def to_normalized_coordinates(boxlist, height, width,
check_range=True, scope=None):
"""Converts absolute box coordinates to normalized coordinates in [0, 1].
  Usually one uses the dynamic shape of the image or conv-layer tensor:
    boxlist = box_list_ops.to_normalized_coordinates(boxlist,
                                                     tf.shape(images)[1],
                                                     tf.shape(images)[2])
  This function raises an assertion-failed error at graph execution time when
  the maximum coordinate is smaller than 1.01 (which means that the
  coordinates are already normalized). The value 1.01 is to deal with small
  rounding errors.
Args:
boxlist: BoxList with coordinates in terms of pixel-locations.
height: Maximum value for height of absolute box coordinates.
width: Maximum value for width of absolute box coordinates.
check_range: If True, checks if the coordinates are normalized or not.
scope: name scope.
Returns:
boxlist with normalized coordinates in [0, 1].
"""
with tf.name_scope(scope, 'ToNormalizedCoordinates'):
height = tf.cast(height, tf.float32)
width = tf.cast(width, tf.float32)
if check_range:
max_val = tf.reduce_max(boxlist.get())
max_assert = tf.Assert(tf.greater(max_val, 1.01),
['max value is lower than 1.01: ', max_val])
with tf.control_dependencies([max_assert]):
width = tf.identity(width)
return scale(boxlist, 1 / height, 1 / width)
def to_absolute_coordinates(boxlist, height, width,
check_range=True, scope=None):
"""Converts normalized box coordinates to absolute pixel coordinates.
  This function raises an assertion-failed error when the maximum box
  coordinate value is larger than 1.01 (in which case the coordinates are
  likely already absolute).
Args:
boxlist: BoxList with coordinates in range [0, 1].
height: Maximum value for height of absolute box coordinates.
width: Maximum value for width of absolute box coordinates.
check_range: If True, checks if the coordinates are normalized or not.
scope: name scope.
Returns:
boxlist with absolute coordinates in terms of the image size.
"""
with tf.name_scope(scope, 'ToAbsoluteCoordinates'):
height = tf.cast(height, tf.float32)
width = tf.cast(width, tf.float32)
# Ensure range of input boxes is correct.
if check_range:
box_maximum = tf.reduce_max(boxlist.get())
max_assert = tf.Assert(tf.greater_equal(1.01, box_maximum),
['maximum box coordinate value is larger '
'than 1.01: ', box_maximum])
with tf.control_dependencies([max_assert]):
width = tf.identity(width)
return scale(boxlist, height, width)
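# Illustrative sketch (not part of the original module): round-tripping
# between absolute pixel coordinates and normalized coordinates for a
# 100 x 200 (height x width) image.
def _example_coordinate_round_trip():
  absolute = box_list.BoxList(tf.constant([[10.0, 20.0, 50.0, 100.0]]))
  normalized = to_normalized_coordinates(absolute, height=100, width=200)
  # normalized.get() evaluates to [[0.1, 0.1, 0.5, 0.5]].
  restored = to_absolute_coordinates(normalized, height=100, width=200)
  return restored  # evaluates back to [[10.0, 20.0, 50.0, 100.0]]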
def refine_boxes_multi_class(pool_boxes,
num_classes,
nms_iou_thresh,
nms_max_detections,
voting_iou_thresh=0.5):
"""Refines a pool of boxes using non max suppression and box voting.
Box refinement is done independently for each class.
Args:
pool_boxes: (BoxList) A collection of boxes to be refined. pool_boxes must
have a rank 1 'scores' field and a rank 1 'classes' field.
num_classes: (int scalar) Number of classes.
nms_iou_thresh: (float scalar) iou threshold for non max suppression (NMS).
nms_max_detections: (int scalar) maximum output size for NMS.
voting_iou_thresh: (float scalar) iou threshold for box voting.
Returns:
BoxList of refined boxes.
Raises:
ValueError: if
a) nms_iou_thresh or voting_iou_thresh is not in [0, 1].
b) pool_boxes is not a BoxList.
c) pool_boxes does not have a scores and classes field.
"""
if not 0.0 <= nms_iou_thresh <= 1.0:
raise ValueError('nms_iou_thresh must be between 0 and 1')
if not 0.0 <= voting_iou_thresh <= 1.0:
raise ValueError('voting_iou_thresh must be between 0 and 1')
if not isinstance(pool_boxes, box_list.BoxList):
raise ValueError('pool_boxes must be a BoxList')
if not pool_boxes.has_field('scores'):
raise ValueError('pool_boxes must have a \'scores\' field')
if not pool_boxes.has_field('classes'):
raise ValueError('pool_boxes must have a \'classes\' field')
refined_boxes = []
for i in range(num_classes):
boxes_class = filter_field_value_equals(pool_boxes, 'classes', i)
refined_boxes_class = refine_boxes(boxes_class, nms_iou_thresh,
nms_max_detections, voting_iou_thresh)
refined_boxes.append(refined_boxes_class)
return sort_by_field(concatenate(refined_boxes), 'scores')
def refine_boxes(pool_boxes,
nms_iou_thresh,
nms_max_detections,
voting_iou_thresh=0.5):
"""Refines a pool of boxes using non max suppression and box voting.
Args:
pool_boxes: (BoxList) A collection of boxes to be refined. pool_boxes must
have a rank 1 'scores' field.
nms_iou_thresh: (float scalar) iou threshold for non max suppression (NMS).
nms_max_detections: (int scalar) maximum output size for NMS.
voting_iou_thresh: (float scalar) iou threshold for box voting.
Returns:
BoxList of refined boxes.
Raises:
ValueError: if
a) nms_iou_thresh or voting_iou_thresh is not in [0, 1].
b) pool_boxes is not a BoxList.
c) pool_boxes does not have a scores field.
"""
if not 0.0 <= nms_iou_thresh <= 1.0:
raise ValueError('nms_iou_thresh must be between 0 and 1')
if not 0.0 <= voting_iou_thresh <= 1.0:
raise ValueError('voting_iou_thresh must be between 0 and 1')
if not isinstance(pool_boxes, box_list.BoxList):
raise ValueError('pool_boxes must be a BoxList')
if not pool_boxes.has_field('scores'):
raise ValueError('pool_boxes must have a \'scores\' field')
nms_boxes = non_max_suppression(
pool_boxes, nms_iou_thresh, nms_max_detections)
return box_voting(nms_boxes, pool_boxes, voting_iou_thresh)
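
# Example usage (an illustrative sketch): single-class refinement. With the
# hypothetical pool below, NMS keeps the highest-scoring box of the cluster
# and box voting then averages it with its overlapping neighbour:
#
#   pool = box_list.BoxList(tf.constant([[0.1, 0.1, 0.4, 0.4],
#                                        [0.1, 0.1, 0.5, 0.5]], tf.float32))
#   pool.add_field('scores', tf.constant([0.75, 0.25]))
#   refined = refine_boxes(pool, nms_iou_thresh=0.5, nms_max_detections=10)
#   # refined.get() -> [[0.1, 0.1, 0.425, 0.425]] with score 0.5
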
def box_voting(selected_boxes, pool_boxes, iou_thresh=0.5):
"""Performs box voting as described in S. Gidaris and N. Komodakis, ICCV 2015.
Performs box voting as described in 'Object detection via a multi-region &
semantic segmentation-aware CNN model', Gidaris and Komodakis, ICCV 2015. For
each box 'B' in selected_boxes, we find the set 'S' of boxes in pool_boxes
with iou overlap >= iou_thresh. The location of B is set to the weighted
average location of boxes in S (scores are used for weighting). And the score
of B is set to the average score of boxes in S.
Args:
selected_boxes: BoxList containing a subset of boxes in pool_boxes. These
boxes are usually selected from pool_boxes using non max suppression.
pool_boxes: BoxList containing a set of (possibly redundant) boxes.
iou_thresh: (float scalar) iou threshold for matching boxes in
selected_boxes and pool_boxes.
Returns:
BoxList containing averaged locations and scores for each box in
selected_boxes.
Raises:
ValueError: if
a) selected_boxes or pool_boxes is not a BoxList.
b) if iou_thresh is not in [0, 1].
c) pool_boxes does not have a scores field.
"""
if not 0.0 <= iou_thresh <= 1.0:
raise ValueError('iou_thresh must be between 0 and 1')
if not isinstance(selected_boxes, box_list.BoxList):
raise ValueError('selected_boxes must be a BoxList')
if not isinstance(pool_boxes, box_list.BoxList):
raise ValueError('pool_boxes must be a BoxList')
if not pool_boxes.has_field('scores'):
raise ValueError('pool_boxes must have a \'scores\' field')
iou_ = iou(selected_boxes, pool_boxes)
match_indicator = tf.to_float(tf.greater(iou_, iou_thresh))
num_matches = tf.reduce_sum(match_indicator, 1)
# TODO: Handle the case where some boxes in selected_boxes do not match to any
# boxes in pool_boxes. For such boxes without any matches, we should return
# the original boxes without voting.
match_assert = tf.Assert(
tf.reduce_all(tf.greater(num_matches, 0)),
['Each box in selected_boxes must match with at least one box '
'in pool_boxes.'])
scores = tf.expand_dims(pool_boxes.get_field('scores'), 1)
scores_assert = tf.Assert(
tf.reduce_all(tf.greater_equal(scores, 0)),
['Scores must be non negative.'])
with tf.control_dependencies([scores_assert, match_assert]):
sum_scores = tf.matmul(match_indicator, scores)
averaged_scores = tf.reshape(sum_scores, [-1]) / num_matches
box_locations = tf.matmul(match_indicator,
pool_boxes.get() * scores) / sum_scores
averaged_boxes = box_list.BoxList(box_locations)
_copy_extra_fields(averaged_boxes, selected_boxes)
averaged_boxes.add_field('scores', averaged_scores)
return averaged_boxes
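
# Worked example (an illustrative sketch): suppose pool_boxes holds
# [0.1, 0.1, 0.4, 0.4] with score 0.75 and [0.1, 0.1, 0.5, 0.5] with score
# 0.25, and selected_boxes holds just the first box. Both pool boxes match it
# (iou >= 0.5), so its voted location is the score-weighted average
#   (0.75 * [0.1, 0.1, 0.4, 0.4] + 0.25 * [0.1, 0.1, 0.5, 0.5]) / (0.75 + 0.25)
#   = [0.1, 0.1, 0.425, 0.425]
# and its voted score is the plain average (0.75 + 0.25) / 2 = 0.5.
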
def pad_or_clip_box_list(boxlist, num_boxes, scope=None):
"""Pads or clips all fields of a BoxList.
Args:
boxlist: A BoxList with arbitrary of number of boxes.
num_boxes: First num_boxes in boxlist are kept.
The fields are zero-padded if num_boxes is bigger than the
actual number of boxes.
scope: name scope.
Returns:
BoxList with all fields padded or clipped.
"""
with tf.name_scope(scope, 'PadOrClipBoxList'):
subboxlist = box_list.BoxList(shape_utils.pad_or_clip_tensor(
boxlist.get(), num_boxes))
for field in boxlist.get_extra_fields():
subfield = shape_utils.pad_or_clip_tensor(
boxlist.get_field(field), num_boxes)
subboxlist.add_field(field, subfield)
return subboxlist
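
# Example usage (an illustrative sketch): padding a two-box BoxList out to a
# fixed size of four; the extra rows (and any extra-field entries) are zeros:
#
#   boxlist = box_list.BoxList(tf.constant([[0.1, 0.1, 0.4, 0.4],
#                                           [0.1, 0.1, 0.5, 0.5]], tf.float32))
#   padded = pad_or_clip_box_list(boxlist, num_boxes=4)
#   # padded.get() has shape [4, 4]; with num_boxes=1 it would instead be
#   # clipped to the first row only.
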
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for object_detection.core.box_list_ops."""
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import errors
from object_detection.core import box_list
from object_detection.core import box_list_ops
class BoxListOpsTest(tf.test.TestCase):
"""Tests for common bounding box operations."""
def test_area(self):
corners = tf.constant([[0.0, 0.0, 10.0, 20.0], [1.0, 2.0, 3.0, 4.0]])
exp_output = [200.0, 4.0]
boxes = box_list.BoxList(corners)
areas = box_list_ops.area(boxes)
with self.test_session() as sess:
areas_output = sess.run(areas)
self.assertAllClose(areas_output, exp_output)
def test_height_width(self):
corners = tf.constant([[0.0, 0.0, 10.0, 20.0], [1.0, 2.0, 3.0, 4.0]])
exp_output_heights = [10., 2.]
exp_output_widths = [20., 2.]
boxes = box_list.BoxList(corners)
heights, widths = box_list_ops.height_width(boxes)
with self.test_session() as sess:
output_heights, output_widths = sess.run([heights, widths])
self.assertAllClose(output_heights, exp_output_heights)
self.assertAllClose(output_widths, exp_output_widths)
def test_scale(self):
corners = tf.constant([[0, 0, 100, 200], [50, 120, 100, 140]],
dtype=tf.float32)
boxes = box_list.BoxList(corners)
boxes.add_field('extra_data', tf.constant([[1], [2]]))
y_scale = tf.constant(1.0/100)
x_scale = tf.constant(1.0/200)
scaled_boxes = box_list_ops.scale(boxes, y_scale, x_scale)
exp_output = [[0, 0, 1, 1], [0.5, 0.6, 1.0, 0.7]]
with self.test_session() as sess:
scaled_corners_out = sess.run(scaled_boxes.get())
self.assertAllClose(scaled_corners_out, exp_output)
extra_data_out = sess.run(scaled_boxes.get_field('extra_data'))
self.assertAllEqual(extra_data_out, [[1], [2]])
def test_clip_to_window_filter_boxes_which_fall_outside_the_window(
self):
window = tf.constant([0, 0, 9, 14], tf.float32)
corners = tf.constant([[5.0, 5.0, 6.0, 6.0],
[-1.0, -2.0, 4.0, 5.0],
[2.0, 3.0, 5.0, 9.0],
[0.0, 0.0, 9.0, 14.0],
[-100.0, -100.0, 300.0, 600.0],
[-10.0, -10.0, -9.0, -9.0]])
boxes = box_list.BoxList(corners)
boxes.add_field('extra_data', tf.constant([[1], [2], [3], [4], [5], [6]]))
exp_output = [[5.0, 5.0, 6.0, 6.0], [0.0, 0.0, 4.0, 5.0],
[2.0, 3.0, 5.0, 9.0], [0.0, 0.0, 9.0, 14.0],
[0.0, 0.0, 9.0, 14.0]]
pruned = box_list_ops.clip_to_window(
boxes, window, filter_nonoverlapping=True)
with self.test_session() as sess:
pruned_output = sess.run(pruned.get())
self.assertAllClose(pruned_output, exp_output)
extra_data_out = sess.run(pruned.get_field('extra_data'))
self.assertAllEqual(extra_data_out, [[1], [2], [3], [4], [5]])
def test_clip_to_window_without_filtering_boxes_which_fall_outside_the_window(
self):
window = tf.constant([0, 0, 9, 14], tf.float32)
corners = tf.constant([[5.0, 5.0, 6.0, 6.0],
[-1.0, -2.0, 4.0, 5.0],
[2.0, 3.0, 5.0, 9.0],
[0.0, 0.0, 9.0, 14.0],
[-100.0, -100.0, 300.0, 600.0],
[-10.0, -10.0, -9.0, -9.0]])
boxes = box_list.BoxList(corners)
boxes.add_field('extra_data', tf.constant([[1], [2], [3], [4], [5], [6]]))
exp_output = [[5.0, 5.0, 6.0, 6.0], [0.0, 0.0, 4.0, 5.0],
[2.0, 3.0, 5.0, 9.0], [0.0, 0.0, 9.0, 14.0],
[0.0, 0.0, 9.0, 14.0], [0.0, 0.0, 0.0, 0.0]]
pruned = box_list_ops.clip_to_window(
boxes, window, filter_nonoverlapping=False)
with self.test_session() as sess:
pruned_output = sess.run(pruned.get())
self.assertAllClose(pruned_output, exp_output)
extra_data_out = sess.run(pruned.get_field('extra_data'))
self.assertAllEqual(extra_data_out, [[1], [2], [3], [4], [5], [6]])
def test_prune_outside_window_filters_boxes_which_fall_outside_the_window(
self):
window = tf.constant([0, 0, 9, 14], tf.float32)
corners = tf.constant([[5.0, 5.0, 6.0, 6.0],
[-1.0, -2.0, 4.0, 5.0],
[2.0, 3.0, 5.0, 9.0],
[0.0, 0.0, 9.0, 14.0],
[-10.0, -10.0, -9.0, -9.0],
[-100.0, -100.0, 300.0, 600.0]])
boxes = box_list.BoxList(corners)
boxes.add_field('extra_data', tf.constant([[1], [2], [3], [4], [5], [6]]))
exp_output = [[5.0, 5.0, 6.0, 6.0],
[2.0, 3.0, 5.0, 9.0],
[0.0, 0.0, 9.0, 14.0]]
pruned, keep_indices = box_list_ops.prune_outside_window(boxes, window)
with self.test_session() as sess:
pruned_output = sess.run(pruned.get())
self.assertAllClose(pruned_output, exp_output)
keep_indices_out = sess.run(keep_indices)
self.assertAllEqual(keep_indices_out, [0, 2, 3])
extra_data_out = sess.run(pruned.get_field('extra_data'))
self.assertAllEqual(extra_data_out, [[1], [3], [4]])
def test_prune_completely_outside_window(self):
window = tf.constant([0, 0, 9, 14], tf.float32)
corners = tf.constant([[5.0, 5.0, 6.0, 6.0],
[-1.0, -2.0, 4.0, 5.0],
[2.0, 3.0, 5.0, 9.0],
[0.0, 0.0, 9.0, 14.0],
[-10.0, -10.0, -9.0, -9.0],
[-100.0, -100.0, 300.0, 600.0]])
boxes = box_list.BoxList(corners)
boxes.add_field('extra_data', tf.constant([[1], [2], [3], [4], [5], [6]]))
exp_output = [[5.0, 5.0, 6.0, 6.0],
[-1.0, -2.0, 4.0, 5.0],
[2.0, 3.0, 5.0, 9.0],
[0.0, 0.0, 9.0, 14.0],
[-100.0, -100.0, 300.0, 600.0]]
pruned, keep_indices = box_list_ops.prune_completely_outside_window(boxes,
window)
with self.test_session() as sess:
pruned_output = sess.run(pruned.get())
self.assertAllClose(pruned_output, exp_output)
keep_indices_out = sess.run(keep_indices)
self.assertAllEqual(keep_indices_out, [0, 1, 2, 3, 5])
extra_data_out = sess.run(pruned.get_field('extra_data'))
self.assertAllEqual(extra_data_out, [[1], [2], [3], [4], [6]])
def test_intersection(self):
corners1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]])
corners2 = tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0],
[0.0, 0.0, 20.0, 20.0]])
exp_output = [[2.0, 0.0, 6.0], [1.0, 0.0, 5.0]]
boxes1 = box_list.BoxList(corners1)
boxes2 = box_list.BoxList(corners2)
intersect = box_list_ops.intersection(boxes1, boxes2)
with self.test_session() as sess:
intersect_output = sess.run(intersect)
self.assertAllClose(intersect_output, exp_output)
def test_matched_intersection(self):
corners1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]])
corners2 = tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0]])
exp_output = [2.0, 0.0]
boxes1 = box_list.BoxList(corners1)
boxes2 = box_list.BoxList(corners2)
intersect = box_list_ops.matched_intersection(boxes1, boxes2)
with self.test_session() as sess:
intersect_output = sess.run(intersect)
self.assertAllClose(intersect_output, exp_output)
def test_iou(self):
corners1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]])
corners2 = tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0],
[0.0, 0.0, 20.0, 20.0]])
exp_output = [[2.0 / 16.0, 0, 6.0 / 400.0], [1.0 / 16.0, 0.0, 5.0 / 400.0]]
boxes1 = box_list.BoxList(corners1)
boxes2 = box_list.BoxList(corners2)
iou = box_list_ops.iou(boxes1, boxes2)
with self.test_session() as sess:
iou_output = sess.run(iou)
self.assertAllClose(iou_output, exp_output)
def test_matched_iou(self):
corners1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]])
corners2 = tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0]])
exp_output = [2.0 / 16.0, 0]
boxes1 = box_list.BoxList(corners1)
boxes2 = box_list.BoxList(corners2)
iou = box_list_ops.matched_iou(boxes1, boxes2)
with self.test_session() as sess:
iou_output = sess.run(iou)
self.assertAllClose(iou_output, exp_output)
def test_iou_works_on_empty_inputs(self):
corners1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]])
corners2 = tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0],
[0.0, 0.0, 20.0, 20.0]])
boxes1 = box_list.BoxList(corners1)
boxes2 = box_list.BoxList(corners2)
boxes_empty = box_list.BoxList(tf.zeros((0, 4)))
iou_empty_1 = box_list_ops.iou(boxes1, boxes_empty)
iou_empty_2 = box_list_ops.iou(boxes_empty, boxes2)
iou_empty_3 = box_list_ops.iou(boxes_empty, boxes_empty)
with self.test_session() as sess:
iou_output_1, iou_output_2, iou_output_3 = sess.run(
[iou_empty_1, iou_empty_2, iou_empty_3])
self.assertAllEqual(iou_output_1.shape, (2, 0))
self.assertAllEqual(iou_output_2.shape, (0, 3))
self.assertAllEqual(iou_output_3.shape, (0, 0))
def test_ioa(self):
corners1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]])
corners2 = tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0],
[0.0, 0.0, 20.0, 20.0]])
exp_output_1 = [[2.0 / 12.0, 0, 6.0 / 400.0],
[1.0 / 12.0, 0.0, 5.0 / 400.0]]
exp_output_2 = [[2.0 / 6.0, 1.0 / 5.0],
[0, 0],
[6.0 / 6.0, 5.0 / 5.0]]
boxes1 = box_list.BoxList(corners1)
boxes2 = box_list.BoxList(corners2)
ioa_1 = box_list_ops.ioa(boxes1, boxes2)
ioa_2 = box_list_ops.ioa(boxes2, boxes1)
with self.test_session() as sess:
ioa_output_1, ioa_output_2 = sess.run([ioa_1, ioa_2])
self.assertAllClose(ioa_output_1, exp_output_1)
self.assertAllClose(ioa_output_2, exp_output_2)
def test_prune_non_overlapping_boxes(self):
corners1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]])
corners2 = tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0],
[0.0, 0.0, 20.0, 20.0]])
boxes1 = box_list.BoxList(corners1)
boxes2 = box_list.BoxList(corners2)
minoverlap = 0.5
exp_output_1 = boxes1
exp_output_2 = box_list.BoxList(tf.constant(0.0, shape=[0, 4]))
output_1, keep_indices_1 = box_list_ops.prune_non_overlapping_boxes(
boxes1, boxes2, min_overlap=minoverlap)
output_2, keep_indices_2 = box_list_ops.prune_non_overlapping_boxes(
boxes2, boxes1, min_overlap=minoverlap)
with self.test_session() as sess:
(output_1_, keep_indices_1_, output_2_, keep_indices_2_, exp_output_1_,
exp_output_2_) = sess.run(
[output_1.get(), keep_indices_1,
output_2.get(), keep_indices_2,
exp_output_1.get(), exp_output_2.get()])
self.assertAllClose(output_1_, exp_output_1_)
self.assertAllClose(output_2_, exp_output_2_)
self.assertAllEqual(keep_indices_1_, [0, 1])
self.assertAllEqual(keep_indices_2_, [])
def test_prune_small_boxes(self):
boxes = tf.constant([[4.0, 3.0, 7.0, 5.0],
[5.0, 6.0, 10.0, 7.0],
[3.0, 4.0, 6.0, 8.0],
[14.0, 14.0, 15.0, 15.0],
[0.0, 0.0, 20.0, 20.0]])
exp_boxes = [[3.0, 4.0, 6.0, 8.0],
[0.0, 0.0, 20.0, 20.0]]
boxes = box_list.BoxList(boxes)
pruned_boxes = box_list_ops.prune_small_boxes(boxes, 3)
with self.test_session() as sess:
pruned_boxes = sess.run(pruned_boxes.get())
self.assertAllEqual(pruned_boxes, exp_boxes)
def test_prune_small_boxes_prunes_boxes_with_negative_side(self):
boxes = tf.constant([[4.0, 3.0, 7.0, 5.0],
[5.0, 6.0, 10.0, 7.0],
[3.0, 4.0, 6.0, 8.0],
[14.0, 14.0, 15.0, 15.0],
[0.0, 0.0, 20.0, 20.0],
[2.0, 3.0, 1.5, 7.0], # negative height
[2.0, 3.0, 5.0, 1.7]]) # negative width
exp_boxes = [[3.0, 4.0, 6.0, 8.0],
[0.0, 0.0, 20.0, 20.0]]
boxes = box_list.BoxList(boxes)
pruned_boxes = box_list_ops.prune_small_boxes(boxes, 3)
with self.test_session() as sess:
pruned_boxes = sess.run(pruned_boxes.get())
self.assertAllEqual(pruned_boxes, exp_boxes)
def test_change_coordinate_frame(self):
corners = tf.constant([[0.25, 0.5, 0.75, 0.75], [0.5, 0.0, 1.0, 1.0]])
window = tf.constant([0.25, 0.25, 0.75, 0.75])
boxes = box_list.BoxList(corners)
expected_corners = tf.constant([[0, 0.5, 1.0, 1.0], [0.5, -0.5, 1.5, 1.5]])
expected_boxes = box_list.BoxList(expected_corners)
output = box_list_ops.change_coordinate_frame(boxes, window)
with self.test_session() as sess:
output_, expected_boxes_ = sess.run([output.get(), expected_boxes.get()])
self.assertAllClose(output_, expected_boxes_)
def test_ioa_works_on_empty_inputs(self):
corners1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]])
corners2 = tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0],
[0.0, 0.0, 20.0, 20.0]])
boxes1 = box_list.BoxList(corners1)
boxes2 = box_list.BoxList(corners2)
boxes_empty = box_list.BoxList(tf.zeros((0, 4)))
ioa_empty_1 = box_list_ops.ioa(boxes1, boxes_empty)
ioa_empty_2 = box_list_ops.ioa(boxes_empty, boxes2)
ioa_empty_3 = box_list_ops.ioa(boxes_empty, boxes_empty)
with self.test_session() as sess:
ioa_output_1, ioa_output_2, ioa_output_3 = sess.run(
[ioa_empty_1, ioa_empty_2, ioa_empty_3])
self.assertAllEqual(ioa_output_1.shape, (2, 0))
self.assertAllEqual(ioa_output_2.shape, (0, 3))
self.assertAllEqual(ioa_output_3.shape, (0, 0))
def test_pairwise_distances(self):
corners1 = tf.constant([[0.0, 0.0, 0.0, 0.0],
[1.0, 1.0, 0.0, 2.0]])
corners2 = tf.constant([[3.0, 4.0, 1.0, 0.0],
[-4.0, 0.0, 0.0, 3.0],
[0.0, 0.0, 0.0, 0.0]])
exp_output = [[26, 25, 0], [18, 27, 6]]
boxes1 = box_list.BoxList(corners1)
boxes2 = box_list.BoxList(corners2)
dist_matrix = box_list_ops.sq_dist(boxes1, boxes2)
with self.test_session() as sess:
dist_output = sess.run(dist_matrix)
self.assertAllClose(dist_output, exp_output)
def test_boolean_mask(self):
corners = tf.constant(
[4 * [0.0], 4 * [1.0], 4 * [2.0], 4 * [3.0], 4 * [4.0]])
indicator = tf.constant([True, False, True, False, True], tf.bool)
expected_subset = [4 * [0.0], 4 * [2.0], 4 * [4.0]]
boxes = box_list.BoxList(corners)
subset = box_list_ops.boolean_mask(boxes, indicator)
with self.test_session() as sess:
subset_output = sess.run(subset.get())
self.assertAllClose(subset_output, expected_subset)
def test_boolean_mask_with_field(self):
corners = tf.constant(
[4 * [0.0], 4 * [1.0], 4 * [2.0], 4 * [3.0], 4 * [4.0]])
indicator = tf.constant([True, False, True, False, True], tf.bool)
weights = tf.constant([[.1], [.3], [.5], [.7], [.9]], tf.float32)
expected_subset = [4 * [0.0], 4 * [2.0], 4 * [4.0]]
expected_weights = [[.1], [.5], [.9]]
boxes = box_list.BoxList(corners)
boxes.add_field('weights', weights)
subset = box_list_ops.boolean_mask(boxes, indicator, ['weights'])
with self.test_session() as sess:
subset_output, weights_output = sess.run(
[subset.get(), subset.get_field('weights')])
self.assertAllClose(subset_output, expected_subset)
self.assertAllClose(weights_output, expected_weights)
def test_gather(self):
corners = tf.constant(
[4 * [0.0], 4 * [1.0], 4 * [2.0], 4 * [3.0], 4 * [4.0]])
indices = tf.constant([0, 2, 4], tf.int32)
expected_subset = [4 * [0.0], 4 * [2.0], 4 * [4.0]]
boxes = box_list.BoxList(corners)
subset = box_list_ops.gather(boxes, indices)
with self.test_session() as sess:
subset_output = sess.run(subset.get())
self.assertAllClose(subset_output, expected_subset)
def test_gather_with_field(self):
corners = tf.constant([4*[0.0], 4*[1.0], 4*[2.0], 4*[3.0], 4*[4.0]])
indices = tf.constant([0, 2, 4], tf.int32)
weights = tf.constant([[.1], [.3], [.5], [.7], [.9]], tf.float32)
expected_subset = [4 * [0.0], 4 * [2.0], 4 * [4.0]]
expected_weights = [[.1], [.5], [.9]]
boxes = box_list.BoxList(corners)
boxes.add_field('weights', weights)
subset = box_list_ops.gather(boxes, indices, ['weights'])
with self.test_session() as sess:
subset_output, weights_output = sess.run(
[subset.get(), subset.get_field('weights')])
self.assertAllClose(subset_output, expected_subset)
self.assertAllClose(weights_output, expected_weights)
def test_gather_with_invalid_field(self):
corners = tf.constant([4 * [0.0], 4 * [1.0]])
indices = tf.constant([0, 1], tf.int32)
weights = tf.constant([[.1], [.3]], tf.float32)
boxes = box_list.BoxList(corners)
boxes.add_field('weights', weights)
with self.assertRaises(ValueError):
box_list_ops.gather(boxes, indices, ['foo', 'bar'])
def test_gather_with_invalid_inputs(self):
corners = tf.constant(
[4 * [0.0], 4 * [1.0], 4 * [2.0], 4 * [3.0], 4 * [4.0]])
indices_float32 = tf.constant([0, 2, 4], tf.float32)
boxes = box_list.BoxList(corners)
with self.assertRaises(ValueError):
_ = box_list_ops.gather(boxes, indices_float32)
indices_2d = tf.constant([[0, 2, 4]], tf.int32)
boxes = box_list.BoxList(corners)
with self.assertRaises(ValueError):
_ = box_list_ops.gather(boxes, indices_2d)
def test_gather_with_dynamic_indexing(self):
corners = tf.constant([4 * [0.0], 4 * [1.0], 4 * [2.0], 4 * [3.0], 4 * [4.0]
])
weights = tf.constant([.5, .3, .7, .1, .9], tf.float32)
indices = tf.reshape(tf.where(tf.greater(weights, 0.4)), [-1])
expected_subset = [4 * [0.0], 4 * [2.0], 4 * [4.0]]
expected_weights = [.5, .7, .9]
boxes = box_list.BoxList(corners)
boxes.add_field('weights', weights)
subset = box_list_ops.gather(boxes, indices, ['weights'])
with self.test_session() as sess:
subset_output, weights_output = sess.run([subset.get(), subset.get_field(
'weights')])
self.assertAllClose(subset_output, expected_subset)
self.assertAllClose(weights_output, expected_weights)
def test_sort_by_field_ascending_order(self):
exp_corners = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
[0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]]
exp_scores = [.95, .9, .75, .6, .5, .3]
exp_weights = [.2, .45, .6, .75, .8, .92]
shuffle = [2, 4, 0, 5, 1, 3]
corners = tf.constant([exp_corners[i] for i in shuffle], tf.float32)
boxes = box_list.BoxList(corners)
boxes.add_field('scores', tf.constant(
[exp_scores[i] for i in shuffle], tf.float32))
boxes.add_field('weights', tf.constant(
[exp_weights[i] for i in shuffle], tf.float32))
sort_by_weight = box_list_ops.sort_by_field(
boxes,
'weights',
order=box_list_ops.SortOrder.ascend)
with self.test_session() as sess:
corners_out, scores_out, weights_out = sess.run([
sort_by_weight.get(),
sort_by_weight.get_field('scores'),
sort_by_weight.get_field('weights')])
self.assertAllClose(corners_out, exp_corners)
self.assertAllClose(scores_out, exp_scores)
self.assertAllClose(weights_out, exp_weights)
def test_sort_by_field_descending_order(self):
exp_corners = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9],
[0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]]
exp_scores = [.95, .9, .75, .6, .5, .3]
exp_weights = [.2, .45, .6, .75, .8, .92]
shuffle = [2, 4, 0, 5, 1, 3]
corners = tf.constant([exp_corners[i] for i in shuffle], tf.float32)
boxes = box_list.BoxList(corners)
boxes.add_field('scores', tf.constant(
[exp_scores[i] for i in shuffle], tf.float32))
boxes.add_field('weights', tf.constant(
[exp_weights[i] for i in shuffle], tf.float32))
sort_by_score = box_list_ops.sort_by_field(boxes, 'scores')
with self.test_session() as sess:
      corners_out, scores_out, weights_out = sess.run(
          [sort_by_score.get(), sort_by_score.get_field('scores'),
           sort_by_score.get_field('weights')])
self.assertAllClose(corners_out, exp_corners)
self.assertAllClose(scores_out, exp_scores)
self.assertAllClose(weights_out, exp_weights)
def test_sort_by_field_invalid_inputs(self):
corners = tf.constant([4 * [0.0], 4 * [0.5], 4 * [1.0], 4 * [2.0], 4 *
[3.0], 4 * [4.0]])
misc = tf.constant([[.95, .9], [.5, .3]], tf.float32)
weights = tf.constant([.1, .2], tf.float32)
boxes = box_list.BoxList(corners)
boxes.add_field('misc', misc)
boxes.add_field('weights', weights)
with self.test_session() as sess:
with self.assertRaises(ValueError):
box_list_ops.sort_by_field(boxes, 'area')
with self.assertRaises(ValueError):
box_list_ops.sort_by_field(boxes, 'misc')
with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
'Incorrect field size'):
sess.run(box_list_ops.sort_by_field(boxes, 'weights').get())
def test_visualize_boxes_in_image(self):
image = tf.zeros((6, 4, 3))
corners = tf.constant([[0, 0, 5, 3],
[0, 0, 3, 2]], tf.float32)
boxes = box_list.BoxList(corners)
image_and_boxes = box_list_ops.visualize_boxes_in_image(image, boxes)
image_and_boxes_bw = tf.to_float(
tf.greater(tf.reduce_sum(image_and_boxes, 2), 0.0))
exp_result = [[1, 1, 1, 0],
[1, 1, 1, 0],
[1, 1, 1, 0],
[1, 0, 1, 0],
[1, 1, 1, 0],
[0, 0, 0, 0]]
with self.test_session() as sess:
output = sess.run(image_and_boxes_bw)
self.assertAllEqual(output.astype(int), exp_result)
def test_filter_field_value_equals(self):
corners = tf.constant([[0, 0, 1, 1],
[0, 0.1, 1, 1.1],
[0, -0.1, 1, 0.9],
[0, 10, 1, 11],
[0, 10.1, 1, 11.1],
[0, 100, 1, 101]], tf.float32)
boxes = box_list.BoxList(corners)
boxes.add_field('classes', tf.constant([1, 2, 1, 2, 2, 1]))
exp_output1 = [[0, 0, 1, 1], [0, -0.1, 1, 0.9], [0, 100, 1, 101]]
exp_output2 = [[0, 0.1, 1, 1.1], [0, 10, 1, 11], [0, 10.1, 1, 11.1]]
filtered_boxes1 = box_list_ops.filter_field_value_equals(
boxes, 'classes', 1)
filtered_boxes2 = box_list_ops.filter_field_value_equals(
boxes, 'classes', 2)
with self.test_session() as sess:
filtered_output1, filtered_output2 = sess.run([filtered_boxes1.get(),
filtered_boxes2.get()])
self.assertAllClose(filtered_output1, exp_output1)
self.assertAllClose(filtered_output2, exp_output2)
def test_filter_greater_than(self):
corners = tf.constant([[0, 0, 1, 1],
[0, 0.1, 1, 1.1],
[0, -0.1, 1, 0.9],
[0, 10, 1, 11],
[0, 10.1, 1, 11.1],
[0, 100, 1, 101]], tf.float32)
boxes = box_list.BoxList(corners)
boxes.add_field('scores', tf.constant([.1, .75, .9, .5, .5, .8]))
thresh = .6
exp_output = [[0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], [0, 100, 1, 101]]
filtered_boxes = box_list_ops.filter_greater_than(boxes, thresh)
with self.test_session() as sess:
filtered_output = sess.run(filtered_boxes.get())
self.assertAllClose(filtered_output, exp_output)
def test_clip_box_list(self):
boxlist = box_list.BoxList(
tf.constant([[0.1, 0.1, 0.4, 0.4], [0.1, 0.1, 0.5, 0.5],
[0.6, 0.6, 0.8, 0.8], [0.2, 0.2, 0.3, 0.3]], tf.float32))
boxlist.add_field('classes', tf.constant([0, 0, 1, 1]))
boxlist.add_field('scores', tf.constant([0.75, 0.65, 0.3, 0.2]))
num_boxes = 2
clipped_boxlist = box_list_ops.pad_or_clip_box_list(boxlist, num_boxes)
expected_boxes = [[0.1, 0.1, 0.4, 0.4], [0.1, 0.1, 0.5, 0.5]]
expected_classes = [0, 0]
expected_scores = [0.75, 0.65]
with self.test_session() as sess:
boxes_out, classes_out, scores_out = sess.run(
[clipped_boxlist.get(), clipped_boxlist.get_field('classes'),
clipped_boxlist.get_field('scores')])
self.assertAllClose(expected_boxes, boxes_out)
self.assertAllEqual(expected_classes, classes_out)
self.assertAllClose(expected_scores, scores_out)
def test_pad_box_list(self):
boxlist = box_list.BoxList(
tf.constant([[0.1, 0.1, 0.4, 0.4], [0.1, 0.1, 0.5, 0.5]], tf.float32))
boxlist.add_field('classes', tf.constant([0, 1]))
boxlist.add_field('scores', tf.constant([0.75, 0.2]))
num_boxes = 4
padded_boxlist = box_list_ops.pad_or_clip_box_list(boxlist, num_boxes)
expected_boxes = [[0.1, 0.1, 0.4, 0.4], [0.1, 0.1, 0.5, 0.5],
[0, 0, 0, 0], [0, 0, 0, 0]]
expected_classes = [0, 1, 0, 0]
expected_scores = [0.75, 0.2, 0, 0]
with self.test_session() as sess:
boxes_out, classes_out, scores_out = sess.run(
[padded_boxlist.get(), padded_boxlist.get_field('classes'),
padded_boxlist.get_field('scores')])
self.assertAllClose(expected_boxes, boxes_out)
self.assertAllEqual(expected_classes, classes_out)
self.assertAllClose(expected_scores, scores_out)
class ConcatenateTest(tf.test.TestCase):
def test_invalid_input_box_list_list(self):
with self.assertRaises(ValueError):
box_list_ops.concatenate(None)
with self.assertRaises(ValueError):
box_list_ops.concatenate([])
with self.assertRaises(ValueError):
corners = tf.constant([[0, 0, 0, 0]], tf.float32)
boxlist = box_list.BoxList(corners)
box_list_ops.concatenate([boxlist, 2])
def test_concatenate_with_missing_fields(self):
corners1 = tf.constant([[0, 0, 0, 0], [1, 2, 3, 4]], tf.float32)
scores1 = tf.constant([1.0, 2.1])
corners2 = tf.constant([[0, 3, 1, 6], [2, 4, 3, 8]], tf.float32)
boxlist1 = box_list.BoxList(corners1)
boxlist1.add_field('scores', scores1)
boxlist2 = box_list.BoxList(corners2)
with self.assertRaises(ValueError):
box_list_ops.concatenate([boxlist1, boxlist2])
def test_concatenate_with_incompatible_field_shapes(self):
corners1 = tf.constant([[0, 0, 0, 0], [1, 2, 3, 4]], tf.float32)
scores1 = tf.constant([1.0, 2.1])
corners2 = tf.constant([[0, 3, 1, 6], [2, 4, 3, 8]], tf.float32)
scores2 = tf.constant([[1.0, 1.0], [2.1, 3.2]])
boxlist1 = box_list.BoxList(corners1)
boxlist1.add_field('scores', scores1)
boxlist2 = box_list.BoxList(corners2)
boxlist2.add_field('scores', scores2)
with self.assertRaises(ValueError):
box_list_ops.concatenate([boxlist1, boxlist2])
def test_concatenate_is_correct(self):
corners1 = tf.constant([[0, 0, 0, 0], [1, 2, 3, 4]], tf.float32)
scores1 = tf.constant([1.0, 2.1])
corners2 = tf.constant([[0, 3, 1, 6], [2, 4, 3, 8], [1, 0, 5, 10]],
tf.float32)
scores2 = tf.constant([1.0, 2.1, 5.6])
exp_corners = [[0, 0, 0, 0],
[1, 2, 3, 4],
[0, 3, 1, 6],
[2, 4, 3, 8],
[1, 0, 5, 10]]
exp_scores = [1.0, 2.1, 1.0, 2.1, 5.6]
boxlist1 = box_list.BoxList(corners1)
boxlist1.add_field('scores', scores1)
boxlist2 = box_list.BoxList(corners2)
boxlist2.add_field('scores', scores2)
result = box_list_ops.concatenate([boxlist1, boxlist2])
with self.test_session() as sess:
corners_output, scores_output = sess.run(
[result.get(), result.get_field('scores')])
self.assertAllClose(corners_output, exp_corners)
self.assertAllClose(scores_output, exp_scores)
class NonMaxSuppressionTest(tf.test.TestCase):
def test_with_invalid_scores_field(self):
corners = tf.constant([[0, 0, 1, 1],
[0, 0.1, 1, 1.1],
[0, -0.1, 1, 0.9],
[0, 10, 1, 11],
[0, 10.1, 1, 11.1],
[0, 100, 1, 101]], tf.float32)
boxes = box_list.BoxList(corners)
boxes.add_field('scores', tf.constant([.9, .75, .6, .95, .5]))
iou_thresh = .5
max_output_size = 3
nms = box_list_ops.non_max_suppression(
boxes, iou_thresh, max_output_size)
with self.test_session() as sess:
with self.assertRaisesWithPredicateMatch(
errors.InvalidArgumentError, 'scores has incompatible shape'):
sess.run(nms.get())
def test_select_from_three_clusters(self):
corners = tf.constant([[0, 0, 1, 1],
[0, 0.1, 1, 1.1],
[0, -0.1, 1, 0.9],
[0, 10, 1, 11],
[0, 10.1, 1, 11.1],
[0, 100, 1, 101]], tf.float32)
boxes = box_list.BoxList(corners)
boxes.add_field('scores', tf.constant([.9, .75, .6, .95, .5, .3]))
iou_thresh = .5
max_output_size = 3
exp_nms = [[0, 10, 1, 11],
[0, 0, 1, 1],
[0, 100, 1, 101]]
nms = box_list_ops.non_max_suppression(
boxes, iou_thresh, max_output_size)
with self.test_session() as sess:
nms_output = sess.run(nms.get())
self.assertAllClose(nms_output, exp_nms)
def test_select_at_most_two_boxes_from_three_clusters(self):
corners = tf.constant([[0, 0, 1, 1],
[0, 0.1, 1, 1.1],
[0, -0.1, 1, 0.9],
[0, 10, 1, 11],
[0, 10.1, 1, 11.1],
[0, 100, 1, 101]], tf.float32)
boxes = box_list.BoxList(corners)
boxes.add_field('scores', tf.constant([.9, .75, .6, .95, .5, .3]))
iou_thresh = .5
max_output_size = 2
exp_nms = [[0, 10, 1, 11],
[0, 0, 1, 1]]
nms = box_list_ops.non_max_suppression(
boxes, iou_thresh, max_output_size)
with self.test_session() as sess:
nms_output = sess.run(nms.get())
self.assertAllClose(nms_output, exp_nms)
def test_select_at_most_thirty_boxes_from_three_clusters(self):
corners = tf.constant([[0, 0, 1, 1],
[0, 0.1, 1, 1.1],
[0, -0.1, 1, 0.9],
[0, 10, 1, 11],
[0, 10.1, 1, 11.1],
[0, 100, 1, 101]], tf.float32)
boxes = box_list.BoxList(corners)
boxes.add_field('scores', tf.constant([.9, .75, .6, .95, .5, .3]))
iou_thresh = .5
max_output_size = 30
exp_nms = [[0, 10, 1, 11],
[0, 0, 1, 1],
[0, 100, 1, 101]]
nms = box_list_ops.non_max_suppression(
boxes, iou_thresh, max_output_size)
with self.test_session() as sess:
nms_output = sess.run(nms.get())
self.assertAllClose(nms_output, exp_nms)
def test_select_single_box(self):
corners = tf.constant([[0, 0, 1, 1]], tf.float32)
boxes = box_list.BoxList(corners)
boxes.add_field('scores', tf.constant([.9]))
iou_thresh = .5
max_output_size = 3
exp_nms = [[0, 0, 1, 1]]
nms = box_list_ops.non_max_suppression(
boxes, iou_thresh, max_output_size)
with self.test_session() as sess:
nms_output = sess.run(nms.get())
self.assertAllClose(nms_output, exp_nms)
def test_select_from_ten_identical_boxes(self):
corners = tf.constant(10 * [[0, 0, 1, 1]], tf.float32)
boxes = box_list.BoxList(corners)
boxes.add_field('scores', tf.constant(10 * [.9]))
iou_thresh = .5
max_output_size = 3
exp_nms = [[0, 0, 1, 1]]
nms = box_list_ops.non_max_suppression(
boxes, iou_thresh, max_output_size)
with self.test_session() as sess:
nms_output = sess.run(nms.get())
self.assertAllClose(nms_output, exp_nms)
def test_copy_extra_fields(self):
corners = tf.constant([[0, 0, 1, 1],
[0, 0.1, 1, 1.1]], tf.float32)
boxes = box_list.BoxList(corners)
tensor1 = np.array([[1], [4]])
tensor2 = np.array([[1, 1], [2, 2]])
boxes.add_field('tensor1', tf.constant(tensor1))
boxes.add_field('tensor2', tf.constant(tensor2))
new_boxes = box_list.BoxList(tf.constant([[0, 0, 10, 10],
[1, 3, 5, 5]], tf.float32))
new_boxes = box_list_ops._copy_extra_fields(new_boxes, boxes)
with self.test_session() as sess:
self.assertAllClose(tensor1, sess.run(new_boxes.get_field('tensor1')))
self.assertAllClose(tensor2, sess.run(new_boxes.get_field('tensor2')))
class CoordinatesConversionTest(tf.test.TestCase):
def test_to_normalized_coordinates(self):
coordinates = tf.constant([[0, 0, 100, 100],
[25, 25, 75, 75]], tf.float32)
img = tf.ones((128, 100, 100, 3))
boxlist = box_list.BoxList(coordinates)
normalized_boxlist = box_list_ops.to_normalized_coordinates(
boxlist, tf.shape(img)[1], tf.shape(img)[2])
expected_boxes = [[0, 0, 1, 1],
[0.25, 0.25, 0.75, 0.75]]
with self.test_session() as sess:
normalized_boxes = sess.run(normalized_boxlist.get())
self.assertAllClose(normalized_boxes, expected_boxes)
def test_to_normalized_coordinates_already_normalized(self):
coordinates = tf.constant([[0, 0, 1, 1],
[0.25, 0.25, 0.75, 0.75]], tf.float32)
img = tf.ones((128, 100, 100, 3))
boxlist = box_list.BoxList(coordinates)
normalized_boxlist = box_list_ops.to_normalized_coordinates(
boxlist, tf.shape(img)[1], tf.shape(img)[2])
with self.test_session() as sess:
with self.assertRaisesOpError('assertion failed'):
sess.run(normalized_boxlist.get())
def test_to_absolute_coordinates(self):
coordinates = tf.constant([[0, 0, 1, 1],
[0.25, 0.25, 0.75, 0.75]], tf.float32)
img = tf.ones((128, 100, 100, 3))
boxlist = box_list.BoxList(coordinates)
absolute_boxlist = box_list_ops.to_absolute_coordinates(boxlist,
tf.shape(img)[1],
tf.shape(img)[2])
expected_boxes = [[0, 0, 100, 100],
[25, 25, 75, 75]]
with self.test_session() as sess:
absolute_boxes = sess.run(absolute_boxlist.get())
self.assertAllClose(absolute_boxes, expected_boxes)
def test_to_absolute_coordinates_already_absolute(self):
coordinates = tf.constant([[0, 0, 100, 100],
[25, 25, 75, 75]], tf.float32)
img = tf.ones((128, 100, 100, 3))
boxlist = box_list.BoxList(coordinates)
absolute_boxlist = box_list_ops.to_absolute_coordinates(boxlist,
tf.shape(img)[1],
tf.shape(img)[2])
with self.test_session() as sess:
with self.assertRaisesOpError('assertion failed'):
sess.run(absolute_boxlist.get())
def test_convert_to_normalized_and_back(self):
coordinates = np.random.uniform(size=(100, 4))
coordinates = np.round(np.sort(coordinates) * 200)
coordinates[:, 2:4] += 1
coordinates[99, :] = [0, 0, 201, 201]
img = tf.ones((128, 202, 202, 3))
boxlist = box_list.BoxList(tf.constant(coordinates, tf.float32))
boxlist = box_list_ops.to_normalized_coordinates(boxlist,
tf.shape(img)[1],
tf.shape(img)[2])
boxlist = box_list_ops.to_absolute_coordinates(boxlist,
tf.shape(img)[1],
tf.shape(img)[2])
with self.test_session() as sess:
out = sess.run(boxlist.get())
self.assertAllClose(out, coordinates)
def test_convert_to_absolute_and_back(self):
coordinates = np.random.uniform(size=(100, 4))
coordinates = np.sort(coordinates)
coordinates[99, :] = [0, 0, 1, 1]
img = tf.ones((128, 202, 202, 3))
boxlist = box_list.BoxList(tf.constant(coordinates, tf.float32))
boxlist = box_list_ops.to_absolute_coordinates(boxlist,
tf.shape(img)[1],
tf.shape(img)[2])
boxlist = box_list_ops.to_normalized_coordinates(boxlist,
tf.shape(img)[1],
tf.shape(img)[2])
with self.test_session() as sess:
out = sess.run(boxlist.get())
self.assertAllClose(out, coordinates)
class BoxRefinementTest(tf.test.TestCase):
def test_box_voting(self):
candidates = box_list.BoxList(
tf.constant([[0.1, 0.1, 0.4, 0.4], [0.6, 0.6, 0.8, 0.8]], tf.float32))
candidates.add_field('ExtraField', tf.constant([1, 2]))
pool = box_list.BoxList(
tf.constant([[0.1, 0.1, 0.4, 0.4], [0.1, 0.1, 0.5, 0.5],
[0.6, 0.6, 0.8, 0.8]], tf.float32))
pool.add_field('scores', tf.constant([0.75, 0.25, 0.3]))
averaged_boxes = box_list_ops.box_voting(candidates, pool)
expected_boxes = [[0.1, 0.1, 0.425, 0.425], [0.6, 0.6, 0.8, 0.8]]
expected_scores = [0.5, 0.3]
with self.test_session() as sess:
boxes_out, scores_out, extra_field_out = sess.run(
[averaged_boxes.get(), averaged_boxes.get_field('scores'),
averaged_boxes.get_field('ExtraField')])
self.assertAllClose(expected_boxes, boxes_out)
self.assertAllClose(expected_scores, scores_out)
self.assertAllEqual(extra_field_out, [1, 2])
def test_box_voting_fails_with_negative_scores(self):
candidates = box_list.BoxList(
tf.constant([[0.1, 0.1, 0.4, 0.4]], tf.float32))
pool = box_list.BoxList(tf.constant([[0.1, 0.1, 0.4, 0.4]], tf.float32))
pool.add_field('scores', tf.constant([-0.2]))
averaged_boxes = box_list_ops.box_voting(candidates, pool)
with self.test_session() as sess:
with self.assertRaisesOpError('Scores must be non negative'):
sess.run([averaged_boxes.get()])
def test_box_voting_fails_when_unmatched(self):
candidates = box_list.BoxList(
tf.constant([[0.1, 0.1, 0.4, 0.4]], tf.float32))
pool = box_list.BoxList(tf.constant([[0.6, 0.6, 0.8, 0.8]], tf.float32))
pool.add_field('scores', tf.constant([0.2]))
averaged_boxes = box_list_ops.box_voting(candidates, pool)
with self.test_session() as sess:
with self.assertRaisesOpError('Each box in selected_boxes must match '
'with at least one box in pool_boxes.'):
sess.run([averaged_boxes.get()])
def test_refine_boxes(self):
pool = box_list.BoxList(
tf.constant([[0.1, 0.1, 0.4, 0.4], [0.1, 0.1, 0.5, 0.5],
[0.6, 0.6, 0.8, 0.8]], tf.float32))
pool.add_field('ExtraField', tf.constant([1, 2, 3]))
pool.add_field('scores', tf.constant([0.75, 0.25, 0.3]))
refined_boxes = box_list_ops.refine_boxes(pool, 0.5, 10)
expected_boxes = [[0.1, 0.1, 0.425, 0.425], [0.6, 0.6, 0.8, 0.8]]
expected_scores = [0.5, 0.3]
with self.test_session() as sess:
boxes_out, scores_out, extra_field_out = sess.run(
[refined_boxes.get(), refined_boxes.get_field('scores'),
refined_boxes.get_field('ExtraField')])
self.assertAllClose(expected_boxes, boxes_out)
self.assertAllClose(expected_scores, scores_out)
self.assertAllEqual(extra_field_out, [1, 3])
def test_refine_boxes_multi_class(self):
pool = box_list.BoxList(
tf.constant([[0.1, 0.1, 0.4, 0.4], [0.1, 0.1, 0.5, 0.5],
[0.6, 0.6, 0.8, 0.8], [0.2, 0.2, 0.3, 0.3]], tf.float32))
pool.add_field('classes', tf.constant([0, 0, 1, 1]))
pool.add_field('scores', tf.constant([0.75, 0.25, 0.3, 0.2]))
refined_boxes = box_list_ops.refine_boxes_multi_class(pool, 3, 0.5, 10)
expected_boxes = [[0.1, 0.1, 0.425, 0.425], [0.6, 0.6, 0.8, 0.8],
[0.2, 0.2, 0.3, 0.3]]
expected_scores = [0.5, 0.3, 0.2]
with self.test_session() as sess:
boxes_out, scores_out, extra_field_out = sess.run(
[refined_boxes.get(), refined_boxes.get_field('scores'),
refined_boxes.get_field('classes')])
self.assertAllClose(expected_boxes, boxes_out)
self.assertAllClose(expected_scores, scores_out)
self.assertAllEqual(extra_field_out, [0, 1, 1])
if __name__ == '__main__':
tf.test.main()