Commit f5955367 authored by Vighnesh Birodkar's avatar Vighnesh Birodkar Committed by TF Object Detection Team
Browse files

Refactor random jitter boxes and support expansion only mode.

PiperOrigin-RevId: 370903487
parent 120953f5
...@@ -89,8 +89,6 @@ PREPROCESSING_FUNCTION_MAP = { ...@@ -89,8 +89,6 @@ PREPROCESSING_FUNCTION_MAP = {
preprocessor.random_adjust_saturation, preprocessor.random_adjust_saturation,
'random_distort_color': 'random_distort_color':
preprocessor.random_distort_color, preprocessor.random_distort_color,
'random_jitter_boxes':
preprocessor.random_jitter_boxes,
'random_crop_to_aspect_ratio': 'random_crop_to_aspect_ratio':
preprocessor.random_crop_to_aspect_ratio, preprocessor.random_crop_to_aspect_ratio,
'random_black_patches': 'random_black_patches':
...@@ -125,6 +123,16 @@ RESIZE_METHOD_MAP = { ...@@ -125,6 +123,16 @@ RESIZE_METHOD_MAP = {
} }
def get_random_jitter_kwargs(proto):
return {
'ratio':
proto.ratio,
'jitter_mode':
preprocessor_pb2.RandomJitterBoxes.JitterMode.Name(proto.jitter_mode
).lower()
}
def build(preprocessor_step_config): def build(preprocessor_step_config):
"""Builds preprocessing step based on the configuration. """Builds preprocessing step based on the configuration.
...@@ -427,4 +435,8 @@ def build(preprocessor_step_config): ...@@ -427,4 +435,8 @@ def build(preprocessor_step_config):
'output_size': config.output_size, 'output_size': config.output_size,
} }
if step_type == 'random_jitter_boxes':
config = preprocessor_step_config.random_jitter_boxes
kwargs = get_random_jitter_kwargs(config)
return preprocessor.random_jitter_boxes, kwargs
raise ValueError('Unknown preprocessing step.') raise ValueError('Unknown preprocessing step.')
...@@ -216,13 +216,14 @@ class PreprocessorBuilderTest(tf.test.TestCase): ...@@ -216,13 +216,14 @@ class PreprocessorBuilderTest(tf.test.TestCase):
preprocessor_text_proto = """ preprocessor_text_proto = """
random_jitter_boxes { random_jitter_boxes {
ratio: 0.1 ratio: 0.1
jitter_mode: SHRINK
} }
""" """
preprocessor_proto = preprocessor_pb2.PreprocessingStep() preprocessor_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(preprocessor_text_proto, preprocessor_proto) text_format.Merge(preprocessor_text_proto, preprocessor_proto)
function, args = preprocessor_builder.build(preprocessor_proto) function, args = preprocessor_builder.build(preprocessor_proto)
self.assertEqual(function, preprocessor.random_jitter_boxes) self.assertEqual(function, preprocessor.random_jitter_boxes)
self.assert_dictionary_close(args, {'ratio': 0.1}) self.assert_dictionary_close(args, {'ratio': 0.1, 'jitter_mode': 'shrink'})
def test_build_random_crop_image(self): def test_build_random_crop_image(self):
preprocessor_text_proto = """ preprocessor_text_proto = """
......
...@@ -1306,7 +1306,7 @@ def random_distort_color(image, color_ordering=0, preprocess_vars_cache=None): ...@@ -1306,7 +1306,7 @@ def random_distort_color(image, color_ordering=0, preprocess_vars_cache=None):
return image return image
def random_jitter_boxes(boxes, ratio=0.05, seed=None): def random_jitter_boxes(boxes, ratio=0.05, jitter_mode='random', seed=None):
"""Randomly jitter boxes in image. """Randomly jitter boxes in image.
Args: Args:
...@@ -1317,45 +1317,46 @@ def random_jitter_boxes(boxes, ratio=0.05, seed=None): ...@@ -1317,45 +1317,46 @@ def random_jitter_boxes(boxes, ratio=0.05, seed=None):
ratio: The ratio of the box width and height that the corners can jitter. ratio: The ratio of the box width and height that the corners can jitter.
For example if the width is 100 pixels and ratio is 0.05, For example if the width is 100 pixels and ratio is 0.05,
the corners can jitter up to 5 pixels in the x direction. the corners can jitter up to 5 pixels in the x direction.
jitter_mode: One of
shrink - Only shrinks boxes.
expand - Only expands boxes.
default - Randomly and independently perturbs each box boundary.
seed: random seed. seed: random seed.
Returns: Returns:
boxes: boxes which is the same shape as input boxes. boxes: boxes which is the same shape as input boxes.
""" """
def random_jitter_box(box, ratio, seed): with tf.name_scope('RandomJitterBoxes', values=[boxes]):
"""Randomly jitter box. ymin, xmin, ymax, xmax = (boxes[:, i] for i in range(4))
Args: height, width = ymax - ymin, xmax - xmin
box: bounding box [1, 1, 4]. ycenter, xcenter = (ymin + ymax) / 2.0, (xmin + xmax) / 2.0
ratio: max ratio between jittered box and original box,
a number between [0, 0.5].
seed: random seed.
Returns: height = tf.abs(height)
jittered_box: jittered box. width = tf.abs(width)
"""
rand_numbers = tf.random_uniform(
[1, 1, 4], minval=-ratio, maxval=ratio, dtype=tf.float32, seed=seed)
box_width = tf.subtract(box[0, 0, 3], box[0, 0, 1])
box_height = tf.subtract(box[0, 0, 2], box[0, 0, 0])
hw_coefs = tf.stack([box_height, box_width, box_height, box_width])
hw_rand_coefs = tf.multiply(hw_coefs, rand_numbers)
jittered_box = tf.add(box, hw_rand_coefs)
jittered_box = tf.clip_by_value(jittered_box, 0.0, 1.0)
return jittered_box
with tf.name_scope('RandomJitterBoxes', values=[boxes]): if jitter_mode == 'shrink':
# boxes are [N, 4]. Lets first make them [N, 1, 1, 4] min_ratio, max_ratio = -ratio, 0
boxes_shape = tf.shape(boxes) elif jitter_mode == 'expand':
boxes = tf.expand_dims(boxes, 1) min_ratio, max_ratio = 0, ratio
boxes = tf.expand_dims(boxes, 2) else:
min_ratio, max_ratio = -ratio, ratio
num_boxes = tf.shape(boxes)[0]
distortion = 1.0 + tf.random_uniform(
[num_boxes, 4], minval=min_ratio, maxval=max_ratio, dtype=tf.float32,
seed=seed)
distorted_boxes = tf.map_fn( ymin_jitter = height * distortion[:, 0]
lambda x: random_jitter_box(x, ratio, seed), boxes, dtype=tf.float32) xmin_jitter = width * distortion[:, 1]
ymax_jitter = height * distortion[:, 2]
xmax_jitter = width * distortion[:, 3]
distorted_boxes = tf.reshape(distorted_boxes, boxes_shape) ymin, ymax = ycenter - (ymin_jitter / 2.0), ycenter + (ymax_jitter / 2.0)
xmin, xmax = xcenter - (xmin_jitter / 2.0), xcenter + (xmax_jitter / 2.0)
return distorted_boxes boxes = tf.stack([ymin, xmin, ymax, xmax], axis=1)
return tf.clip_by_value(boxes, 0.0, 1.0)
def _strict_random_crop_image(image, def _strict_random_crop_image(image,
......
...@@ -1263,6 +1263,67 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase): ...@@ -1263,6 +1263,67 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
(boxes_shape_, distorted_boxes_shape_) = self.execute_cpu(graph_fn, []) (boxes_shape_, distorted_boxes_shape_) = self.execute_cpu(graph_fn, [])
self.assertAllEqual(boxes_shape_, distorted_boxes_shape_) self.assertAllEqual(boxes_shape_, distorted_boxes_shape_)
def testRandomJitterBoxesZeroRatio(self):
def graph_fn():
preprocessing_options = []
preprocessing_options.append((preprocessor.random_jitter_boxes,
{'ratio': 0.0}))
boxes = self.createTestBoxes()
tensor_dict = {fields.InputDataFields.groundtruth_boxes: boxes}
tensor_dict = preprocessor.preprocess(tensor_dict, preprocessing_options)
distorted_boxes = tensor_dict[fields.InputDataFields.groundtruth_boxes]
return [boxes, distorted_boxes]
(boxes, distorted_boxes) = self.execute_cpu(graph_fn, [])
self.assertAllEqual(boxes, distorted_boxes)
def testRandomJitterBoxesExpand(self):
def graph_fn():
preprocessing_options = []
preprocessing_options.append((preprocessor.random_jitter_boxes,
{'jitter_mode': 'expand'}))
boxes = self.createTestBoxes()
tensor_dict = {fields.InputDataFields.groundtruth_boxes: boxes}
tensor_dict = preprocessor.preprocess(tensor_dict, preprocessing_options)
distorted_boxes = tensor_dict[fields.InputDataFields.groundtruth_boxes]
return [boxes, distorted_boxes]
boxes, distorted_boxes = self.execute_cpu(graph_fn, [])
ymin, xmin, ymax, xmax = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
distorted_ymin, distorted_xmin, distorted_ymax, distorted_xmax = (
distorted_boxes[:, 0], distorted_boxes[:, 1], distorted_boxes[0, 2],
distorted_boxes[:, 3])
self.assertTrue(np.all(distorted_ymin <= ymin))
self.assertTrue(np.all(distorted_xmin <= xmin))
self.assertTrue(np.all(distorted_ymax >= ymax))
self.assertTrue(np.all(distorted_xmax >= xmax))
def testRandomJitterBoxesShrink(self):
def graph_fn():
preprocessing_options = []
preprocessing_options.append((preprocessor.random_jitter_boxes,
{'jitter_mode': 'shrink'}))
boxes = self.createTestBoxes()
tensor_dict = {fields.InputDataFields.groundtruth_boxes: boxes}
tensor_dict = preprocessor.preprocess(tensor_dict, preprocessing_options)
distorted_boxes = tensor_dict[fields.InputDataFields.groundtruth_boxes]
return [boxes, distorted_boxes]
boxes, distorted_boxes = self.execute_cpu(graph_fn, [])
ymin, xmin, ymax, xmax = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
distorted_ymin, distorted_xmin, distorted_ymax, distorted_xmax = (
distorted_boxes[:, 0], distorted_boxes[:, 1], distorted_boxes[0, 2],
distorted_boxes[:, 3])
self.assertTrue(np.all(distorted_ymin >= ymin))
self.assertTrue(np.all(distorted_xmin >= xmin))
self.assertTrue(np.all(distorted_ymax <= ymax))
self.assertTrue(np.all(distorted_xmax <= xmax))
def testRandomCropImage(self): def testRandomCropImage(self):
def graph_fn(): def graph_fn():
......
...@@ -12,11 +12,13 @@ from object_detection.builders import losses_builder ...@@ -12,11 +12,13 @@ from object_detection.builders import losses_builder
from object_detection.core import box_list from object_detection.core import box_list
from object_detection.core import box_list_ops from object_detection.core import box_list_ops
from object_detection.core import losses from object_detection.core import losses
from object_detection.core import preprocessor
from object_detection.core import standard_fields as fields from object_detection.core import standard_fields as fields
from object_detection.meta_architectures import center_net_meta_arch from object_detection.meta_architectures import center_net_meta_arch
from object_detection.models.keras_models import hourglass_network from object_detection.models.keras_models import hourglass_network
from object_detection.models.keras_models import resnet_v1 from object_detection.models.keras_models import resnet_v1
from object_detection.protos import losses_pb2 from object_detection.protos import losses_pb2
from object_detection.protos import preprocessor_pb2
from object_detection.utils import shape_utils from object_detection.utils import shape_utils
from object_detection.utils import spatial_transform_ops from object_detection.utils import spatial_transform_ops
...@@ -32,7 +34,8 @@ class DeepMACParams( ...@@ -32,7 +34,8 @@ class DeepMACParams(
'classification_loss', 'dim', 'task_loss_weight', 'pixel_embedding_dim', 'classification_loss', 'dim', 'task_loss_weight', 'pixel_embedding_dim',
'allowed_masked_classes_ids', 'mask_size', 'mask_num_subsamples', 'allowed_masked_classes_ids', 'mask_size', 'mask_num_subsamples',
'use_xy', 'network_type', 'use_instance_embedding', 'num_init_channels', 'use_xy', 'network_type', 'use_instance_embedding', 'num_init_channels',
'predict_full_resolution_masks', 'postprocess_crop_size' 'predict_full_resolution_masks', 'postprocess_crop_size',
'max_roi_jitter_ratio', 'roi_jitter_mode'
])): ])):
"""Class holding the DeepMAC network configutration.""" """Class holding the DeepMAC network configutration."""
...@@ -42,7 +45,8 @@ class DeepMACParams( ...@@ -42,7 +45,8 @@ class DeepMACParams(
pixel_embedding_dim, allowed_masked_classes_ids, mask_size, pixel_embedding_dim, allowed_masked_classes_ids, mask_size,
mask_num_subsamples, use_xy, network_type, use_instance_embedding, mask_num_subsamples, use_xy, network_type, use_instance_embedding,
num_init_channels, predict_full_resolution_masks, num_init_channels, predict_full_resolution_masks,
postprocess_crop_size): postprocess_crop_size, max_roi_jitter_ratio,
roi_jitter_mode):
return super(DeepMACParams, return super(DeepMACParams,
cls).__new__(cls, classification_loss, dim, cls).__new__(cls, classification_loss, dim,
task_loss_weight, pixel_embedding_dim, task_loss_weight, pixel_embedding_dim,
...@@ -50,7 +54,8 @@ class DeepMACParams( ...@@ -50,7 +54,8 @@ class DeepMACParams(
mask_num_subsamples, use_xy, network_type, mask_num_subsamples, use_xy, network_type,
use_instance_embedding, num_init_channels, use_instance_embedding, num_init_channels,
predict_full_resolution_masks, predict_full_resolution_masks,
postprocess_crop_size) postprocess_crop_size, max_roi_jitter_ratio,
roi_jitter_mode)
def subsample_instances(classes, weights, boxes, masks, num_subsamples): def subsample_instances(classes, weights, boxes, masks, num_subsamples):
...@@ -355,6 +360,9 @@ def deepmac_proto_to_params(deepmac_config): ...@@ -355,6 +360,9 @@ def deepmac_proto_to_params(deepmac_config):
loss.classification_loss.CopyFrom(deepmac_config.classification_loss) loss.classification_loss.CopyFrom(deepmac_config.classification_loss)
classification_loss, _, _, _, _, _, _ = (losses_builder.build(loss)) classification_loss, _, _, _, _, _, _ = (losses_builder.build(loss))
jitter_mode = preprocessor_pb2.RandomJitterBoxes.JitterMode.Name(
deepmac_config.jitter_mode).lower()
return DeepMACParams( return DeepMACParams(
dim=deepmac_config.dim, dim=deepmac_config.dim,
classification_loss=classification_loss, classification_loss=classification_loss,
...@@ -369,7 +377,9 @@ def deepmac_proto_to_params(deepmac_config): ...@@ -369,7 +377,9 @@ def deepmac_proto_to_params(deepmac_config):
num_init_channels=deepmac_config.num_init_channels, num_init_channels=deepmac_config.num_init_channels,
predict_full_resolution_masks= predict_full_resolution_masks=
deepmac_config.predict_full_resolution_masks, deepmac_config.predict_full_resolution_masks,
postprocess_crop_size=deepmac_config.postprocess_crop_size postprocess_crop_size=deepmac_config.postprocess_crop_size,
max_roi_jitter_ratio=deepmac_config.max_roi_jitter_ratio,
roi_jitter_mode=jitter_mode
) )
...@@ -553,6 +563,11 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch): ...@@ -553,6 +563,11 @@ class DeepMACMetaArch(center_net_meta_arch.CenterNetMetaArch):
""" """
num_instances = tf.shape(boxes)[0] num_instances = tf.shape(boxes)[0]
if tf.keras.backend.learning_phase():
boxes = preprocessor.random_jitter_boxes(
boxes, self._deepmac_params.max_roi_jitter_ratio,
jitter_mode=self._deepmac_params.roi_jitter_mode)
mask_input = self._get_mask_head_input( mask_input = self._get_mask_head_input(
boxes, pixel_embedding) boxes, pixel_embedding)
instance_embeddings = self._get_instance_embeddings( instance_embeddings = self._get_instance_embeddings(
......
...@@ -100,7 +100,9 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False): ...@@ -100,7 +100,9 @@ def build_meta_arch(predict_full_resolution_masks=False, use_dice_loss=False):
use_instance_embedding=True, use_instance_embedding=True,
num_init_channels=8, num_init_channels=8,
predict_full_resolution_masks=predict_full_resolution_masks, predict_full_resolution_masks=predict_full_resolution_masks,
postprocess_crop_size=128 postprocess_crop_size=128,
max_roi_jitter_ratio=0.0,
roi_jitter_mode='random'
) )
object_detection_params = center_net_meta_arch.ObjectDetectionParams( object_detection_params = center_net_meta_arch.ObjectDetectionParams(
......
...@@ -5,6 +5,7 @@ package object_detection.protos; ...@@ -5,6 +5,7 @@ package object_detection.protos;
import "object_detection/protos/image_resizer.proto"; import "object_detection/protos/image_resizer.proto";
import "object_detection/protos/losses.proto"; import "object_detection/protos/losses.proto";
import "object_detection/protos/post_processing.proto"; import "object_detection/protos/post_processing.proto";
import "object_detection/protos/preprocessor.proto";
// Configuration for the CenterNet meta architecture from the "Objects as // Configuration for the CenterNet meta architecture from the "Objects as
// Points" paper [1] // Points" paper [1]
...@@ -396,6 +397,15 @@ message CenterNet { ...@@ -396,6 +397,15 @@ message CenterNet {
// of the API, masks are always cropped and resized according to detected // of the API, masks are always cropped and resized according to detected
// boxes in postprocess. // boxes in postprocess.
optional int32 postprocess_crop_size = 13 [default=256]; optional int32 postprocess_crop_size = 13 [default=256];
// The maximum relative amount by which boxes will be jittered before
// RoI crop happens. The x and y coordinates of the box are jittered
// relative to width and height respectively.
optional float max_roi_jitter_ratio = 14 [default=0.0];
// The mode for jitterting box ROIs. See RandomJitterBoxes in
// preprocessor.proto for more details
optional RandomJitterBoxes.JitterMode jitter_mode = 15 [default=DEFAULT];
} }
optional DeepMACMaskEstimation deepmac_mask_estimation = 14; optional DeepMACMaskEstimation deepmac_mask_estimation = 14;
......
...@@ -165,6 +165,18 @@ message RandomDistortColor { ...@@ -165,6 +165,18 @@ message RandomDistortColor {
// ie. If a box is [100, 200] and ratio is 0.02, the corners can move by [1, 4]. // ie. If a box is [100, 200] and ratio is 0.02, the corners can move by [1, 4].
message RandomJitterBoxes { message RandomJitterBoxes {
optional float ratio = 1 [default=0.05]; optional float ratio = 1 [default=0.05];
enum JitterMode {
DEFAULT = 0;
EXPAND = 1;
SHRINK = 2;
}
// The mode of jittering
// EXPAND - Only expands boxes
// SHRINK - Only shrinks boxes
// DEFAULT - Jitters each box boundary independently
optional JitterMode jitter_mode = 2 [default=DEFAULT];
} }
// Randomly crops the image and bounding boxes. // Randomly crops the image and bounding boxes.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment