Commit 7271dc97 authored by Vighnesh Birodkar, committed by TF Object Detection Team

Implement sparse version of coordinates_to_heatmap for memory savings.

PiperOrigin-RevId: 338221628
parent 6e63dfee
@@ -1036,7 +1036,8 @@ def _build_center_net_model(center_net_config, is_training, add_summaries):
       densepose_params=densepose_params,
       track_params=track_params,
       temporal_offset_params=temporal_offset_params,
-      use_depthwise=center_net_config.use_depthwise)
+      use_depthwise=center_net_config.use_depthwise,
+      compute_heatmap_sparse=center_net_config.compute_heatmap_sparse)


 def _build_center_net_feature_extractor(
...
@@ -840,17 +840,22 @@ def _compute_std_dev_from_box_size(boxes_height, boxes_width, min_overlap):
 class CenterNetCenterHeatmapTargetAssigner(object):
   """Wrapper to compute the object center heatmap."""

-  def __init__(self, stride, min_overlap=0.7):
+  def __init__(self, stride, min_overlap=0.7, compute_heatmap_sparse=False):
     """Initializes the target assigner.

     Args:
       stride: int, the stride of the network in output pixels.
       min_overlap: The minimum IOU overlap that boxes need to have to not be
         penalized.
+      compute_heatmap_sparse: bool, indicating whether or not to use the sparse
+        version of the Op that computes the heatmap. The sparse version scales
+        better with number of classes, but in some cases is known to cause an
+        OOM error. See (b/170989061).
     """
     self._stride = stride
     self._min_overlap = min_overlap
+    self._compute_heatmap_sparse = compute_heatmap_sparse

   def assign_center_targets_from_boxes(self,
                                        height,
@@ -915,7 +920,8 @@ class CenterNetCenterHeatmapTargetAssigner(object):
           x_coordinates=x_center,
           sigma=sigma,
           channel_onehot=class_targets,
-          channel_weights=weights)
+          channel_weights=weights,
+          sparse=self._compute_heatmap_sparse)
       heatmaps.append(heatmap)

     # Return the stacked heatmaps over the batch.
@@ -1073,7 +1079,8 @@ class CenterNetKeypointTargetAssigner(object):
                keypoint_indices,
                keypoint_std_dev=None,
                per_keypoint_offset=False,
-               peak_radius=0):
+               peak_radius=0,
+               compute_heatmap_sparse=False):
     """Initializes a CenterNet keypoints target assigner.

     Args:
@@ -1100,6 +1107,10 @@ class CenterNetKeypointTargetAssigner(object):
         out_width, 2 * num_keypoints].
       peak_radius: int, the radius (in the unit of output pixel) around heatmap
         peak to assign the offset targets.
+      compute_heatmap_sparse: bool, indicating whether or not to use the sparse
+        version of the Op that computes the heatmap. The sparse version scales
+        better with number of keypoint types, but in some cases is known to
+        cause an OOM error. See (b/170989061).
     """
     self._stride = stride
@@ -1107,6 +1118,7 @@ class CenterNetKeypointTargetAssigner(object):
     self._keypoint_indices = keypoint_indices
     self._per_keypoint_offset = per_keypoint_offset
     self._peak_radius = peak_radius
+    self._compute_heatmap_sparse = compute_heatmap_sparse
     if keypoint_std_dev is None:
       self._keypoint_std_dev = ([_DEFAULT_KEYPOINT_OFFSET_STD_DEV] *
                                 len(keypoint_indices))
...
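For orientation, here is a minimal usage sketch of the new constructor argument; it is not part of the diff. The module path `object_detection.core.target_assigner` is assumed from the TF Object Detection API layout, and the assigner is otherwise used exactly as before (it simply forwards `sparse=self._compute_heatmap_sparse` to `coordinates_to_heatmap`, as shown above).

```python
# Hedged sketch (assumed module path): building the center heatmap target
# assigner with the sparse heatmap op enabled.
from object_detection.core import target_assigner as cn_assigner

center_assigner = cn_assigner.CenterNetCenterHeatmapTargetAssigner(
    stride=4, min_overlap=0.7, compute_heatmap_sparse=True)
```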
@@ -1787,7 +1787,8 @@ class CenterNetMetaArch(model.DetectionModel):
                densepose_params=None,
                track_params=None,
                temporal_offset_params=None,
-               use_depthwise=False):
+               use_depthwise=False,
+               compute_heatmap_sparse=False):
     """Initializes a CenterNet model.

     Args:
@@ -1826,6 +1827,10 @@ class CenterNetMetaArch(model.DetectionModel):
         holds the hyper-parameters for offset prediction based tracking.
       use_depthwise: If true, all task heads will be constructed using
         separable_conv. Otherwise, standard convolutions will be used.
+      compute_heatmap_sparse: bool, whether or not to use the sparse version of
+        the Op that computes the center heatmaps. The sparse version scales
+        better with number of channels in the heatmap, but in some cases is
+        known to cause an OOM error. See b/170989061.
     """
     assert object_detection_params or keypoint_params_dict

     # Shorten the name for convenience and better formatting.
@@ -1850,6 +1855,7 @@ class CenterNetMetaArch(model.DetectionModel):
     self._temporal_offset_params = temporal_offset_params
     self._use_depthwise = use_depthwise
+    self._compute_heatmap_sparse = compute_heatmap_sparse

     # Construct the prediction head nets.
     self._prediction_head_dict = self._construct_prediction_heads(
@@ -2003,7 +2009,7 @@ class CenterNetMetaArch(model.DetectionModel):
     target_assigners = {}
     target_assigners[OBJECT_CENTER] = (
         cn_assigner.CenterNetCenterHeatmapTargetAssigner(
-            stride, min_box_overlap_iou))
+            stride, min_box_overlap_iou, self._compute_heatmap_sparse))
     if self._od_params is not None:
       target_assigners[DETECTION_TASK] = (
           cn_assigner.CenterNetBoxTargetAssigner(stride))
@@ -2016,7 +2022,8 @@ class CenterNetMetaArch(model.DetectionModel):
             keypoint_indices=kp_params.keypoint_indices,
             keypoint_std_dev=kp_params.keypoint_std_dev,
             peak_radius=kp_params.offset_peak_radius,
-            per_keypoint_offset=kp_params.per_keypoint_offset))
+            per_keypoint_offset=kp_params.per_keypoint_offset,
+            compute_heatmap_sparse=self._compute_heatmap_sparse))
     if self._mask_params is not None:
       target_assigners[SEGMENTATION_TASK] = (
           cn_assigner.CenterNetMaskTargetAssigner(stride))
...
@@ -9,6 +9,7 @@ import "object_detection/protos/losses.proto";
 // Points" paper [1]
 // [1]: https://arxiv.org/abs/1904.07850
+// Next Id = 16
 message CenterNet {
   // Number of classes to predict.
   optional int32 num_classes = 1;
@@ -22,6 +23,12 @@ message CenterNet {
   // If set, all task heads will be constructed with separable convolutions.
   optional bool use_depthwise = 13 [default = false];

+  // Indicates whether or not to use the sparse version of the Op that computes
+  // the center heatmaps. The sparse version scales better with number of
+  // channels in the heatmap, but in some cases is known to cause an OOM error.
+  // TODO(b/170989061) When bug is fixed, make this the default behavior.
+  optional bool compute_heatmap_sparse = 15 [default = false];
+
   // Parameters which are related to object detection task.
   message ObjectDetection {
     // The original fields are moved to ObjectCenterParams or deleted.
...
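As a configuration illustration (not part of the commit), the new field can be set on the generated config message in Python; `center_net_pb2` is the standard protoc-generated module name for `center_net.proto` and is assumed here.

```python
# Hedged sketch: enabling the new field (15) on a CenterNet config message.
from object_detection.protos import center_net_pb2

center_net_config = center_net_pb2.CenterNet()
center_net_config.compute_heatmap_sparse = True  # defaults to false
# model_builder (first hunk above) forwards this value to CenterNetMetaArch.
```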
@@ -41,13 +41,88 @@ def image_shape_to_grids(height, width):
   return (y_grid, x_grid)


+def _coordinates_to_heatmap_dense(y_grid, x_grid, y_coordinates, x_coordinates,
+                                  sigma, channel_onehot, channel_weights=None):
+  """Dense version of coordinates to heatmap that uses an outer product."""
+  num_instances, num_channels = (
+      shape_utils.combined_static_and_dynamic_shape(channel_onehot))
+
+  x_grid = tf.expand_dims(x_grid, 2)
+  y_grid = tf.expand_dims(y_grid, 2)
+  # The raw center coordinates in the output space.
+  x_diff = x_grid - tf.math.floor(x_coordinates)
+  y_diff = y_grid - tf.math.floor(y_coordinates)
+  squared_distance = x_diff**2 + y_diff**2
+
+  gaussian_map = tf.exp(-squared_distance / (2 * sigma * sigma))
+
+  reshaped_gaussian_map = tf.expand_dims(gaussian_map, axis=-1)
+  reshaped_channel_onehot = tf.reshape(channel_onehot,
+                                       (1, 1, num_instances, num_channels))
+  gaussian_per_box_per_class_map = (
+      reshaped_gaussian_map * reshaped_channel_onehot)
+
+  if channel_weights is not None:
+    reshaped_weights = tf.reshape(channel_weights, (1, 1, num_instances, 1))
+    gaussian_per_box_per_class_map *= reshaped_weights
+
+  # Take maximum along the "instance" dimension so that all per-instance
+  # heatmaps of the same class are merged together.
+  heatmap = tf.reduce_max(gaussian_per_box_per_class_map, axis=2)
+
+  # Maximum of an empty tensor is -inf, the following is to avoid that.
+  heatmap = tf.maximum(heatmap, 0)
+  return tf.stop_gradient(heatmap)
+
+
+def _coordinates_to_heatmap_sparse(y_grid, x_grid, y_coordinates, x_coordinates,
+                                   sigma, channel_onehot, channel_weights=None):
+  """Sparse version of coordinates to heatmap using tf.scatter."""
+
+  if not hasattr(tf, 'tensor_scatter_nd_max'):
+    raise RuntimeError(
+        ('Please upgrade TensorFlow to use `tensor_scatter_nd_max` or set '
+         'compute_heatmap_sparse=False'))
+  _, num_channels = (
+      shape_utils.combined_static_and_dynamic_shape(channel_onehot))
+
+  height, width = shape_utils.combined_static_and_dynamic_shape(y_grid)
+  x_grid = tf.expand_dims(x_grid, 2)
+  y_grid = tf.expand_dims(y_grid, 2)
+  # The raw center coordinates in the output space.
+  x_diff = x_grid - tf.math.floor(x_coordinates)
+  y_diff = y_grid - tf.math.floor(y_coordinates)
+  squared_distance = x_diff**2 + y_diff**2
+
+  gaussian_map = tf.exp(-squared_distance / (2 * sigma * sigma))
+
+  if channel_weights is not None:
+    gaussian_map = gaussian_map * channel_weights[tf.newaxis, tf.newaxis, :]
+
+  channel_indices = tf.argmax(channel_onehot, axis=1)
+  channel_indices = channel_indices[:, tf.newaxis]
+  heatmap_init = tf.zeros((num_channels, height, width))
+
+  gaussian_map = tf.transpose(gaussian_map, (2, 0, 1))
+  heatmap = tf.tensor_scatter_nd_max(
+      heatmap_init, channel_indices, gaussian_map)
+
+  # Maximum of an empty tensor is -inf, the following is to avoid that.
+  heatmap = tf.maximum(heatmap, 0)
+  return tf.stop_gradient(tf.transpose(heatmap, (1, 2, 0)))
+
+
 def coordinates_to_heatmap(y_grid,
                            x_grid,
                            y_coordinates,
                            x_coordinates,
                            sigma,
                            channel_onehot,
-                           channel_weights=None):
+                           channel_weights=None,
+                           sparse=False):
   """Returns the heatmap targets from a set of point coordinates.

   This function maps a set of point coordinates to the output heatmap image
@@ -71,41 +146,23 @@ def coordinates_to_heatmap(y_grid,
       representing the one-hot encoded channel labels for each point.
     channel_weights: A 1D tensor with shape [num_instances] corresponding to the
       weight of each instance.
+    sparse: bool, indicating whether or not to use the sparse implementation
+      of the function. The sparse version scales better with number of
+      channels, but in some cases is known to cause an OOM error. See
+      (b/170989061).

   Returns:
     heatmap: A tensor of size [height, width, num_channels] representing the
       heatmap. Output (height, width) match the dimensions of the input grids.
   """
-  num_instances, num_channels = (
-      shape_utils.combined_static_and_dynamic_shape(channel_onehot))
-
-  x_grid = tf.expand_dims(x_grid, 2)
-  y_grid = tf.expand_dims(y_grid, 2)
-  # The raw center coordinates in the output space.
-  x_diff = x_grid - tf.math.floor(x_coordinates)
-  y_diff = y_grid - tf.math.floor(y_coordinates)
-  squared_distance = x_diff**2 + y_diff**2
-
-  gaussian_map = tf.exp(-squared_distance / (2 * sigma * sigma))
-
-  reshaped_gaussian_map = tf.expand_dims(gaussian_map, axis=-1)
-  reshaped_channel_onehot = tf.reshape(channel_onehot,
-                                       (1, 1, num_instances, num_channels))
-  gaussian_per_box_per_class_map = (
-      reshaped_gaussian_map * reshaped_channel_onehot)
-
-  if channel_weights is not None:
-    reshaped_weights = tf.reshape(channel_weights, (1, 1, num_instances, 1))
-    gaussian_per_box_per_class_map *= reshaped_weights
-
-  # Take maximum along the "instance" dimension so that all per-instance
-  # heatmaps of the same class are merged together.
-  heatmap = tf.reduce_max(gaussian_per_box_per_class_map, axis=2)
-
-  # Maximum of an empty tensor is -inf, the following is to avoid that.
-  heatmap = tf.maximum(heatmap, 0)
-  return heatmap
+  if sparse:
+    return _coordinates_to_heatmap_sparse(
+        y_grid, x_grid, y_coordinates, x_coordinates, sigma, channel_onehot,
+        channel_weights)
+  else:
+    return _coordinates_to_heatmap_dense(
+        y_grid, x_grid, y_coordinates, x_coordinates, sigma, channel_onehot,
+        channel_weights)


 def compute_floor_offsets_with_indices(y_source,
...
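To make the memory trade-off concrete, below is a self-contained toy comparison of the two strategies; it is illustrative only and not the library code. The dense path materializes a `[height, width, num_instances, num_channels]` tensor before reducing over the instance axis, while the sparse path scatters each per-instance Gaussian into a `[num_channels, height, width]` buffer with `tf.tensor_scatter_nd_max`, so peak memory no longer grows with the product of instances and channels. The helper names are hypothetical, `channel_weights` is omitted, and TF >= 2.3 (where `tensor_scatter_nd_max` exists) is assumed.

```python
# Toy dense-vs-sparse heatmap construction; assumes each one-hot row has
# exactly one nonzero entry.
import tensorflow as tf


def toy_dense(y_grid, x_grid, y_c, x_c, sigma, onehot):
  """Outer-product version: peak memory ~ height * width * instances * channels."""
  d2 = ((x_grid[..., tf.newaxis] - tf.floor(x_c))**2 +
        (y_grid[..., tf.newaxis] - tf.floor(y_c))**2)
  gaussians = tf.exp(-d2 / (2 * sigma * sigma))                     # [h, w, n]
  per_instance = (gaussians[..., tf.newaxis] *
                  onehot[tf.newaxis, tf.newaxis])                   # [h, w, n, c]
  return tf.reduce_max(per_instance, axis=2)                        # [h, w, c]


def toy_sparse(y_grid, x_grid, y_c, x_c, sigma, onehot, height, width, channels):
  """Scatter version: never builds the [h, w, n, c] intermediate."""
  d2 = ((x_grid[..., tf.newaxis] - tf.floor(x_c))**2 +
        (y_grid[..., tf.newaxis] - tf.floor(y_c))**2)
  gaussians = tf.transpose(tf.exp(-d2 / (2 * sigma * sigma)), (2, 0, 1))  # [n, h, w]
  channel_idx = tf.argmax(onehot, axis=1)[:, tf.newaxis]            # [n, 1]
  heatmap = tf.tensor_scatter_nd_max(
      tf.zeros((channels, height, width)), channel_idx, gaussians)  # [c, h, w]
  return tf.transpose(heatmap, (1, 2, 0))                           # [h, w, c]


height, width, channels = 3, 5, 4
y_grid, x_grid = tf.meshgrid(tf.range(height, dtype=tf.float32),
                             tf.range(width, dtype=tf.float32), indexing='ij')
y_c = tf.constant([1.0, 0.0])
x_c = tf.constant([1.0, 4.0])
onehot = tf.constant([[0., 1., 0., 0.], [0., 0., 0., 1.]])

dense = toy_dense(y_grid, x_grid, y_c, x_c, 1.0, onehot)
sparse = toy_sparse(y_grid, x_grid, y_c, x_c, 1.0, onehot, height, width, channels)
tf.debugging.assert_near(dense, sparse)  # both yield the same [3, 5, 4] heatmap
```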
@@ -14,6 +14,7 @@
 # ==============================================================================
 """Tests for utils.target_assigner_utils."""

+from absl.testing import parameterized
 import numpy as np
 import tensorflow.compat.v1 as tf

@@ -21,7 +22,7 @@ from object_detection.utils import target_assigner_utils as ta_utils
 from object_detection.utils import test_case


-class TargetUtilTest(test_case.TestCase):
+class TargetUtilTest(parameterized.TestCase, test_case.TestCase):

   def test_image_shape_to_grids(self):
     def graph_fn():
@@ -36,7 +37,11 @@ class TargetUtilTest(test_case.TestCase):
     np.testing.assert_array_equal(y_grid, expected_y_grid)
     np.testing.assert_array_equal(x_grid, expected_x_grid)

-  def test_coordinates_to_heatmap(self):
+  @parameterized.parameters((False,), (True,))
+  def test_coordinates_to_heatmap(self, sparse):
+    if not hasattr(tf, 'tensor_scatter_nd_max'):
+      self.skipTest('Cannot test function due to old TF version.')
+
     def graph_fn():
       (y_grid, x_grid) = ta_utils.image_shape_to_grids(height=3, width=5)
       y_coordinates = tf.constant([1.5, 0.5], dtype=tf.float32)
@@ -46,7 +51,8 @@ class TargetUtilTest(test_case.TestCase):
       channel_weights = tf.constant([1, 1], dtype=tf.float32)
       heatmap = ta_utils.coordinates_to_heatmap(y_grid, x_grid, y_coordinates,
                                                 x_coordinates, sigma,
-                                                channel_onehot, channel_weights)
+                                                channel_onehot,
+                                                channel_weights, sparse=sparse)
       return heatmap

     heatmap = self.execute(graph_fn, [])
...