"conda/git@developer.sourcefind.cn:OpenDAS/torch-sparce.git" did not exist on "f79c8e984e7544ee6f457951980e4e6656caeb94"
Commit 54d2baa4 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower Committed by TF Object Detection Team
Browse files

Add a temporal offset prediction head to improve tracking accuracy, similar to...

Add a temporal offset prediction head to improve tracking accuracy, similar to "Tracking Objects as Points"(https://arxiv.org/abs/2004.01177).
Notice that the original paper requires a change to the input pipeline, which is not included in this commit.

PiperOrigin-RevId: 330796987
parent fcff6f65
...@@ -936,6 +936,21 @@ def tracking_proto_to_params(tracking_config): ...@@ -936,6 +936,21 @@ def tracking_proto_to_params(tracking_config):
task_loss_weight=tracking_config.task_loss_weight) task_loss_weight=tracking_config.task_loss_weight)
def temporal_offset_proto_to_params(temporal_offset_config):
  """Converts CenterNet.TemporalOffsetEstimation proto to param-tuple."""
  loss_proto = losses_pb2.Loss()
  # The shared loss builder insists on a classification loss being present, so
  # install a dummy sigmoid classification loss alongside the real
  # localization loss.
  # TODO(yuhuic): update the loss builder to take the classification loss
  # directly.
  loss_proto.classification_loss.weighted_sigmoid.CopyFrom(
      losses_pb2.WeightedSigmoidClassificationLoss())
  loss_proto.localization_loss.CopyFrom(
      temporal_offset_config.localization_loss)
  # losses_builder.build returns a 7-tuple; only the localization loss
  # (second element) is needed here.
  localization_loss = losses_builder.build(loss_proto)[1]
  return center_net_meta_arch.TemporalOffsetParams(
      localization_loss=localization_loss,
      task_loss_weight=temporal_offset_config.task_loss_weight)
def _build_center_net_model(center_net_config, is_training, add_summaries): def _build_center_net_model(center_net_config, is_training, add_summaries):
"""Build a CenterNet detection model. """Build a CenterNet detection model.
...@@ -998,6 +1013,11 @@ def _build_center_net_model(center_net_config, is_training, add_summaries): ...@@ -998,6 +1013,11 @@ def _build_center_net_model(center_net_config, is_training, add_summaries):
track_params = tracking_proto_to_params( track_params = tracking_proto_to_params(
center_net_config.track_estimation_task) center_net_config.track_estimation_task)
temporal_offset_params = None
if center_net_config.HasField('temporal_offset_task'):
temporal_offset_params = temporal_offset_proto_to_params(
center_net_config.temporal_offset_task)
return center_net_meta_arch.CenterNetMetaArch( return center_net_meta_arch.CenterNetMetaArch(
is_training=is_training, is_training=is_training,
add_summaries=add_summaries, add_summaries=add_summaries,
...@@ -1009,7 +1029,8 @@ def _build_center_net_model(center_net_config, is_training, add_summaries): ...@@ -1009,7 +1029,8 @@ def _build_center_net_model(center_net_config, is_training, add_summaries):
keypoint_params_dict=keypoint_params_dict, keypoint_params_dict=keypoint_params_dict,
mask_params=mask_params, mask_params=mask_params,
densepose_params=densepose_params, densepose_params=densepose_params,
track_params=track_params) track_params=track_params,
temporal_offset_params=temporal_offset_params)
def _build_center_net_feature_extractor( def _build_center_net_feature_extractor(
......
...@@ -102,7 +102,8 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)): ...@@ -102,7 +102,8 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
Args: Args:
field: a string key, options are field: a string key, options are
fields.BoxListFields.{boxes,classes,masks,keypoints, fields.BoxListFields.{boxes,classes,masks,keypoints,
keypoint_visibilities, densepose_*, track_ids} keypoint_visibilities, densepose_*, track_ids,
temporal_offsets, track_match_flags}
fields.InputDataFields.is_annotated. fields.InputDataFields.is_annotated.
Returns: Returns:
...@@ -304,6 +305,8 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)): ...@@ -304,6 +305,8 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
groundtruth_dp_part_ids_list=None, groundtruth_dp_part_ids_list=None,
groundtruth_dp_surface_coords_list=None, groundtruth_dp_surface_coords_list=None,
groundtruth_track_ids_list=None, groundtruth_track_ids_list=None,
groundtruth_temporal_offsets_list=None,
groundtruth_track_match_flags_list=None,
groundtruth_weights_list=None, groundtruth_weights_list=None,
groundtruth_confidences_list=None, groundtruth_confidences_list=None,
groundtruth_is_crowd_list=None, groundtruth_is_crowd_list=None,
...@@ -345,6 +348,12 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)): ...@@ -345,6 +348,12 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
padding. padding.
groundtruth_track_ids_list: a list of 1-D tf.int32 tensors of shape groundtruth_track_ids_list: a list of 1-D tf.int32 tensors of shape
[num_boxes] containing the track IDs of groundtruth objects. [num_boxes] containing the track IDs of groundtruth objects.
groundtruth_temporal_offsets_list: a list of 2-D tf.float32 tensors
of shape [num_boxes, 2] containing the spatial offsets of objects'
centers compared with the previous frame.
groundtruth_track_match_flags_list: a list of 1-D tf.float32 tensors
of shape [num_boxes] containing 0-1 flags that indicate if an object
has existed in the previous frame.
groundtruth_weights_list: A list of 1-D tf.float32 tensors of shape groundtruth_weights_list: A list of 1-D tf.float32 tensors of shape
[num_boxes] containing weights for groundtruth boxes. [num_boxes] containing weights for groundtruth boxes.
groundtruth_confidences_list: A list of 2-D tf.float32 tensors of shape groundtruth_confidences_list: A list of 2-D tf.float32 tensors of shape
...@@ -397,6 +406,14 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)): ...@@ -397,6 +406,14 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
if groundtruth_track_ids_list: if groundtruth_track_ids_list:
self._groundtruth_lists[ self._groundtruth_lists[
fields.BoxListFields.track_ids] = groundtruth_track_ids_list fields.BoxListFields.track_ids] = groundtruth_track_ids_list
if groundtruth_temporal_offsets_list:
self._groundtruth_lists[
fields.BoxListFields.temporal_offsets] = (
groundtruth_temporal_offsets_list)
if groundtruth_track_match_flags_list:
self._groundtruth_lists[
fields.BoxListFields.track_match_flags] = (
groundtruth_track_match_flags_list)
if groundtruth_is_crowd_list: if groundtruth_is_crowd_list:
self._groundtruth_lists[ self._groundtruth_lists[
fields.BoxListFields.is_crowd] = groundtruth_is_crowd_list fields.BoxListFields.is_crowd] = groundtruth_is_crowd_list
......
...@@ -47,6 +47,10 @@ class InputDataFields(object): ...@@ -47,6 +47,10 @@ class InputDataFields(object):
groundtruth_boxes: coordinates of the ground truth boxes in the image. groundtruth_boxes: coordinates of the ground truth boxes in the image.
groundtruth_classes: box-level class labels. groundtruth_classes: box-level class labels.
groundtruth_track_ids: box-level track ID labels. groundtruth_track_ids: box-level track ID labels.
groundtruth_temporal_offset: box-level temporal offsets, i.e.,
movement of the box center in adjacent frames.
groundtruth_track_match_flags: box-level flags indicating if objects
exist in the previous frame.
groundtruth_confidences: box-level class confidences. The shape should be groundtruth_confidences: box-level class confidences. The shape should be
the same as the shape of groundtruth_classes. the same as the shape of groundtruth_classes.
groundtruth_label_types: box-level label types (e.g. explicit negative). groundtruth_label_types: box-level label types (e.g. explicit negative).
...@@ -99,6 +103,8 @@ class InputDataFields(object): ...@@ -99,6 +103,8 @@ class InputDataFields(object):
groundtruth_boxes = 'groundtruth_boxes' groundtruth_boxes = 'groundtruth_boxes'
groundtruth_classes = 'groundtruth_classes' groundtruth_classes = 'groundtruth_classes'
groundtruth_track_ids = 'groundtruth_track_ids' groundtruth_track_ids = 'groundtruth_track_ids'
groundtruth_temporal_offset = 'groundtruth_temporal_offset'
groundtruth_track_match_flags = 'groundtruth_track_match_flags'
groundtruth_confidences = 'groundtruth_confidences' groundtruth_confidences = 'groundtruth_confidences'
groundtruth_label_types = 'groundtruth_label_types' groundtruth_label_types = 'groundtruth_label_types'
groundtruth_is_crowd = 'groundtruth_is_crowd' groundtruth_is_crowd = 'groundtruth_is_crowd'
...@@ -170,6 +176,7 @@ class DetectionResultFields(object): ...@@ -170,6 +176,7 @@ class DetectionResultFields(object):
detection_keypoints = 'detection_keypoints' detection_keypoints = 'detection_keypoints'
detection_keypoint_scores = 'detection_keypoint_scores' detection_keypoint_scores = 'detection_keypoint_scores'
detection_embeddings = 'detection_embeddings' detection_embeddings = 'detection_embeddings'
detection_offsets = 'detection_temporal_offsets'
num_detections = 'num_detections' num_detections = 'num_detections'
raw_detection_boxes = 'raw_detection_boxes' raw_detection_boxes = 'raw_detection_boxes'
raw_detection_scores = 'raw_detection_scores' raw_detection_scores = 'raw_detection_scores'
...@@ -194,6 +201,8 @@ class BoxListFields(object): ...@@ -194,6 +201,8 @@ class BoxListFields(object):
densepose_part_ids: DensePose part ids per bounding box. densepose_part_ids: DensePose part ids per bounding box.
densepose_surface_coords: DensePose surface coordinates per bounding box. densepose_surface_coords: DensePose surface coordinates per bounding box.
is_crowd: is_crowd annotation per bounding box. is_crowd: is_crowd annotation per bounding box.
temporal_offsets: temporal center offsets per bounding box.
track_match_flags: match flags per bounding box.
""" """
boxes = 'boxes' boxes = 'boxes'
classes = 'classes' classes = 'classes'
...@@ -212,6 +221,8 @@ class BoxListFields(object): ...@@ -212,6 +221,8 @@ class BoxListFields(object):
is_crowd = 'is_crowd' is_crowd = 'is_crowd'
group_of = 'group_of' group_of = 'group_of'
track_ids = 'track_ids' track_ids = 'track_ids'
temporal_offsets = 'temporal_offsets'
track_match_flags = 'track_match_flags'
class PredictionFields(object): class PredictionFields(object):
......
...@@ -1980,3 +1980,105 @@ class CenterNetCornerOffsetTargetAssigner(object): ...@@ -1980,3 +1980,105 @@ class CenterNetCornerOffsetTargetAssigner(object):
return (tf.stack(corner_targets, axis=0), return (tf.stack(corner_targets, axis=0),
tf.stack(foreground_targets, axis=0)) tf.stack(foreground_targets, axis=0))
class CenterNetTemporalOffsetTargetAssigner(object):
  """Wrapper to compute target tensors for the temporal offset task.

  This class has methods that take as input a batch of ground truth tensors
  (in the form of a list) and returns the targets required to train the
  temporal offset task.
  """

  def __init__(self, stride):
    """Initializes the target assigner.

    Args:
      stride: int, the stride of the network in output pixels.
    """
    self._stride = stride

  def assign_temporal_offset_targets(self,
                                     height,
                                     width,
                                     gt_boxes_list,
                                     gt_offsets_list,
                                     gt_match_list,
                                     gt_weights_list=None):
    """Returns the temporal offset targets and their indices.

    For each ground truth box, this function assigns it the corresponding
    temporal offset to train the model.

    Args:
      height: int, height of input to the model. This is used to determine the
        height of the output.
      width: int, width of the input to the model. This is used to determine
        the width of the output.
      gt_boxes_list: A list of float tensors with shape [num_boxes, 4]
        representing the groundtruth detection bounding boxes for each sample
        in the batch. The coordinates are expected in normalized coordinates.
      gt_offsets_list: A list of 2-D tf.float32 tensors of shape
        [num_boxes, 2] containing the spatial offsets of objects' centers
        compared with the previous frame.
      gt_match_list: A list of 1-D tf.float32 tensors of shape [num_boxes]
        containing flags that indicate if an object has existed in the
        previous frame.
      gt_weights_list: A list of tensors with shape [num_boxes] corresponding
        to the weight of each groundtruth detection box.

    Returns:
      batch_indices: an integer tensor of shape [num_boxes, 3] holding the
        indices inside the predicted tensor which should be penalized. The
        first column indicates the index along the batch dimension and the
        second and third columns indicate the index along the y and x
        dimensions respectively.
      batch_temporal_offsets: a float tensor of shape [num_boxes, 2] of the
        expected y and x temporal offset of each object center in the
        output space.
      batch_weights: a float tensor of shape [num_boxes] indicating the
        weight of each prediction.
    """
    if gt_weights_list is None:
      gt_weights_list = [None] * len(gt_boxes_list)

    batch_indices = []
    batch_weights = []
    batch_temporal_offsets = []

    for i, (boxes, offsets, match_flags, weights) in enumerate(zip(
        gt_boxes_list, gt_offsets_list, gt_match_list, gt_weights_list)):
      # Convert normalized box coordinates to the (downsampled) output frame.
      boxes = box_list.BoxList(boxes)
      boxes = box_list_ops.to_absolute_coordinates(boxes,
                                                   height // self._stride,
                                                   width // self._stride)
      # Get the box center coordinates. Each returned tensors have the shape
      # of [num_boxes]
      (y_center, x_center, _, _) = boxes.get_center_coordinates_and_sizes()
      num_boxes = tf.shape(x_center)

      # Compute the offsets and indices of the box centers. Shape:
      #   offsets: [num_boxes, 2]
      #   indices: [num_boxes, 2]
      # Only the integer (y, x) cell indices are used here; the sub-pixel
      # floor offsets are discarded.
      (_, indices) = ta_utils.compute_floor_offsets_with_indices(
          y_source=y_center, x_source=x_center)

      # Assign ones if weights are not provided.
      # if an object is not matched, its weight becomes zero.
      if weights is None:
        weights = tf.ones(num_boxes, dtype=tf.float32)
      weights *= match_flags

      # Shape of [num_boxes, 1] integer tensor filled with current batch
      # index.
      batch_index = i * tf.ones_like(indices[:, 0:1], dtype=tf.int32)
      batch_indices.append(tf.concat([batch_index, indices], axis=1))
      batch_weights.append(weights)
      batch_temporal_offsets.append(offsets)

    # Flatten the per-image lists into single batch-wide tensors.
    batch_indices = tf.concat(batch_indices, axis=0)
    batch_weights = tf.concat(batch_weights, axis=0)
    batch_temporal_offsets = tf.concat(batch_temporal_offsets, axis=0)
    return (batch_indices, batch_temporal_offsets, batch_weights)
...@@ -2290,6 +2290,126 @@ class CornerOffsetTargetAssignerTest(test_case.TestCase): ...@@ -2290,6 +2290,126 @@ class CornerOffsetTargetAssignerTest(test_case.TestCase):
self.assertAllClose(foreground, np.zeros((1, 5, 5))) self.assertAllClose(foreground, np.zeros((1, 5, 5)))
# NOTE(review): this test class carries the same name as the assigner it
# tests and lacks the conventional `Test` suffix (cf.
# CornerOffsetTargetAssignerTest above) — consider renaming to
# CenterNetTemporalOffsetTargetAssignerTest.
class CenterNetTemporalOffsetTargetAssigner(test_case.TestCase):
  """Tests for CenterNetTemporalOffsetTargetAssigner."""

  def setUp(self):
    super(CenterNetTemporalOffsetTargetAssigner, self).setUp()
    # Fixture boxes in normalized [ymin, xmin, ymax, xmax] coordinates and
    # their paired per-box temporal offsets.
    self._box_center = [0.0, 0.0, 1.0, 1.0]
    self._box_center_small = [0.25, 0.25, 0.75, 0.75]
    self._box_lower_left = [0.5, 0.0, 1.0, 0.5]
    self._box_center_offset = [0.1, 0.05, 1.0, 1.0]
    self._box_odd_coordinates = [0.1625, 0.2125, 0.5625, 0.9625]
    self._offset_center = [0.5, 0.4]
    self._offset_center_small = [0.1, 0.1]
    self._offset_lower_left = [-0.1, 0.1]
    self._offset_center_offset = [0.4, 0.3]
    self._offset_odd_coord = [0.125, -0.125]

  def test_assign_empty_groundtruths(self):
    """Tests the assign_offset_targets function with empty inputs."""
    def graph_fn():
      # Zero-box batch: all targets should come back empty but well-shaped.
      box_batch = [
          tf.zeros((0, 4), dtype=tf.float32),
      ]
      offset_batch = [
          tf.zeros((0, 2), dtype=tf.float32),
      ]
      match_flag_batch = [
          tf.zeros((0), dtype=tf.float32),
      ]
      assigner = targetassigner.CenterNetTemporalOffsetTargetAssigner(4)
      indices, temporal_offset, weights = assigner.assign_temporal_offset_targets(
          80, 80, box_batch, offset_batch, match_flag_batch)
      return indices, temporal_offset, weights
    indices, temporal_offset, weights = self.execute(graph_fn, [])
    self.assertEqual(indices.shape, (0, 3))
    self.assertEqual(temporal_offset.shape, (0, 2))
    self.assertEqual(weights.shape, (0,))

  def test_assign_offset_targets(self):
    """Tests the assign_offset_targets function."""
    def graph_fn():
      # Three images with 2, 1, and 2 boxes; all boxes matched (flag 1.0).
      box_batch = [
          tf.constant([self._box_center, self._box_lower_left]),
          tf.constant([self._box_center_offset]),
          tf.constant([self._box_center_small, self._box_odd_coordinates]),
      ]
      offset_batch = [
          tf.constant([self._offset_center, self._offset_lower_left]),
          tf.constant([self._offset_center_offset]),
          tf.constant([self._offset_center_small, self._offset_odd_coord]),
      ]
      match_flag_batch = [
          tf.constant([1.0, 1.0]),
          tf.constant([1.0]),
          tf.constant([1.0, 1.0]),
      ]
      assigner = targetassigner.CenterNetTemporalOffsetTargetAssigner(4)
      indices, temporal_offset, weights = assigner.assign_temporal_offset_targets(
          80, 80, box_batch, offset_batch, match_flag_batch)
      return indices, temporal_offset, weights
    indices, temporal_offset, weights = self.execute(graph_fn, [])
    self.assertEqual(indices.shape, (5, 3))
    self.assertEqual(temporal_offset.shape, (5, 2))
    self.assertEqual(weights.shape, (5,))
    # Indices are [batch, y, x] of each box center in the stride-4 (80/4=20)
    # output frame; offsets pass through unchanged.
    np.testing.assert_array_equal(
        indices,
        [[0, 10, 10], [0, 15, 5], [1, 11, 10], [2, 10, 10], [2, 7, 11]])
    np.testing.assert_array_almost_equal(
        temporal_offset,
        [[0.5, 0.4], [-0.1, 0.1], [0.4, 0.3], [0.1, 0.1], [0.125, -0.125]])
    np.testing.assert_array_equal(weights, 1)

  def test_assign_offset_targets_with_match_flags(self):
    """Tests the assign_offset_targets function with match flags."""
    def graph_fn():
      box_batch = [
          tf.constant([self._box_center, self._box_lower_left]),
          tf.constant([self._box_center_offset]),
          tf.constant([self._box_center_small, self._box_odd_coordinates]),
      ]
      offset_batch = [
          tf.constant([self._offset_center, self._offset_lower_left]),
          tf.constant([self._offset_center_offset]),
          tf.constant([self._offset_center_small, self._offset_odd_coord]),
      ]
      # First box is unmatched (flag 0.0) and the second box has explicit
      # weight 0.0 — both should be zeroed in the returned weights.
      match_flag_batch = [
          tf.constant([0.0, 1.0]),
          tf.constant([1.0]),
          tf.constant([1.0, 1.0]),
      ]
      cn_assigner = targetassigner.CenterNetTemporalOffsetTargetAssigner(4)
      weights_batch = [
          tf.constant([1.0, 0.0]),
          tf.constant([1.0]),
          tf.constant([1.0, 1.0])
      ]
      indices, temporal_offset, weights = cn_assigner.assign_temporal_offset_targets(
          80, 80, box_batch, offset_batch, match_flag_batch, weights_batch)
      return indices, temporal_offset, weights
    indices, temporal_offset, weights = self.execute(graph_fn, [])
    self.assertEqual(indices.shape, (5, 3))
    self.assertEqual(temporal_offset.shape, (5, 2))
    self.assertEqual(weights.shape, (5,))
    np.testing.assert_array_equal(
        indices,
        [[0, 10, 10], [0, 15, 5], [1, 11, 10], [2, 10, 10], [2, 7, 11]])
    np.testing.assert_array_almost_equal(
        temporal_offset,
        [[0.5, 0.4], [-0.1, 0.1], [0.4, 0.3], [0.1, 0.1], [0.125, -0.125]])
    # weight = provided weight * match flag.
    np.testing.assert_array_equal(weights, [0, 0, 1, 1, 1])
if __name__ == '__main__': if __name__ == '__main__':
tf.enable_v2_behavior() tf.enable_v2_behavior()
tf.test.main() tf.test.main()
...@@ -329,6 +329,39 @@ def prediction_tensors_to_boxes(detection_scores, y_indices, x_indices, ...@@ -329,6 +329,39 @@ def prediction_tensors_to_boxes(detection_scores, y_indices, x_indices,
return boxes, detection_classes, detection_scores, num_detections return boxes, detection_classes, detection_scores, num_detections
def prediction_tensors_to_temporal_offsets(
    y_indices, x_indices, offset_predictions):
  """Converts CenterNet temporal offset map predictions to batched format.

  This function is similar to the box offset conversion function, as both
  temporal offsets and box offsets are size-2 vectors.

  Args:
    y_indices: A [batch, num_boxes] int32 tensor with y indices corresponding
      to object center locations (expressed in output coordinate frame).
    x_indices: A [batch, num_boxes] int32 tensor with x indices corresponding
      to object center locations (expressed in output coordinate frame).
    offset_predictions: A float tensor of shape [batch_size, height, width, 2]
      representing the y and x offsets of a box's center across adjacent
      frames.

  Returns:
    offsets: A tensor of shape [batch_size, num_boxes, 2] holding the object
      temporal offsets of (y, x) dimensions.
  """
  # Flatten the spatial (height, width) dimensions and gather the predicted
  # offset vector at each object-center peak. Unlike box decoding, the
  # gathered offsets are returned as-is, so no float conversion of the
  # indices is needed.
  _, _, width, _ = _get_shape(offset_predictions, 4)
  peak_spatial_indices = flattened_indices_from_row_col_indices(
      y_indices, x_indices, width)
  offsets_flat = _flatten_spatial_dimensions(offset_predictions)
  offsets = tf.gather(offsets_flat, peak_spatial_indices, batch_dims=1)
  return offsets
def prediction_tensors_to_keypoint_candidates( def prediction_tensors_to_keypoint_candidates(
keypoint_heatmap_predictions, keypoint_heatmap_predictions,
keypoint_heatmap_offsets, keypoint_heatmap_offsets,
...@@ -1534,6 +1567,32 @@ class TrackParams( ...@@ -1534,6 +1567,32 @@ class TrackParams(
num_fc_layers, classification_loss, num_fc_layers, classification_loss,
task_loss_weight) task_loss_weight)
class TemporalOffsetParams(
    collections.namedtuple('TemporalOffsetParams', [
        'localization_loss', 'task_loss_weight'
    ])):
  """Namedtuple to store temporal offset related parameters."""

  __slots__ = ()

  def __new__(cls,
              localization_loss,
              task_loss_weight=1.0):
    """Constructor with default values for TemporalOffsetParams.

    Args:
      localization_loss: an object_detection.core.losses.Loss object to
        compute the loss for the temporal offset in CenterNet.
      task_loss_weight: float, the loss weight for the temporal offset
        task.

    Returns:
      An initialized TemporalOffsetParams namedtuple.
    """
    return super(TemporalOffsetParams,
                 cls).__new__(cls, localization_loss, task_loss_weight)
# The following constants are used to generate the keys of the # The following constants are used to generate the keys of the
# (prediction, loss, target assigner,...) dictionaries used in CenterNetMetaArch # (prediction, loss, target assigner,...) dictionaries used in CenterNetMetaArch
# class. # class.
...@@ -1552,6 +1611,8 @@ DENSEPOSE_REGRESSION = 'densepose/regression' ...@@ -1552,6 +1611,8 @@ DENSEPOSE_REGRESSION = 'densepose/regression'
LOSS_KEY_PREFIX = 'Loss' LOSS_KEY_PREFIX = 'Loss'
TRACK_TASK = 'track_task' TRACK_TASK = 'track_task'
TRACK_REID = 'track/reid' TRACK_REID = 'track/reid'
TEMPORALOFFSET_TASK = 'temporal_offset_task'
TEMPORAL_OFFSET = 'track/offset'
def get_keypoint_name(task_name, head_name): def get_keypoint_name(task_name, head_name):
...@@ -1596,7 +1657,8 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -1596,7 +1657,8 @@ class CenterNetMetaArch(model.DetectionModel):
keypoint_params_dict=None, keypoint_params_dict=None,
mask_params=None, mask_params=None,
densepose_params=None, densepose_params=None,
track_params=None): track_params=None,
temporal_offset_params=None):
"""Initializes a CenterNet model. """Initializes a CenterNet model.
Args: Args:
...@@ -1631,6 +1693,8 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -1631,6 +1693,8 @@ class CenterNetMetaArch(model.DetectionModel):
track_params: A TrackParams namedtuple. This object track_params: A TrackParams namedtuple. This object
holds the hyper-parameters for tracking. Please see the class holds the hyper-parameters for tracking. Please see the class
definition for more details. definition for more details.
temporal_offset_params: A TemporalOffsetParams namedtuple. This object
holds the hyper-parameters for offset prediction based tracking.
""" """
assert object_detection_params or keypoint_params_dict assert object_detection_params or keypoint_params_dict
# Shorten the name for convenience and better formatting. # Shorten the name for convenience and better formatting.
...@@ -1651,6 +1715,7 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -1651,6 +1715,7 @@ class CenterNetMetaArch(model.DetectionModel):
'be supplied.') 'be supplied.')
self._densepose_params = densepose_params self._densepose_params = densepose_params
self._track_params = track_params self._track_params = track_params
self._temporal_offset_params = temporal_offset_params
# Construct the prediction head nets. # Construct the prediction head nets.
self._prediction_head_dict = self._construct_prediction_heads( self._prediction_head_dict = self._construct_prediction_heads(
...@@ -1764,6 +1829,11 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -1764,6 +1829,11 @@ class CenterNetMetaArch(model.DetectionModel):
tf.keras.layers.Dense(self._track_params.num_track_ids, tf.keras.layers.Dense(self._track_params.num_track_ids,
input_shape=( input_shape=(
self._track_params.reid_embed_size,))) self._track_params.reid_embed_size,)))
if self._temporal_offset_params is not None:
prediction_heads[TEMPORAL_OFFSET] = [
make_prediction_net(NUM_OFFSET_CHANNELS)
for _ in range(num_feature_outputs)
]
return prediction_heads return prediction_heads
def _initialize_target_assigners(self, stride, min_box_overlap_iou): def _initialize_target_assigners(self, stride, min_box_overlap_iou):
...@@ -1806,6 +1876,9 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -1806,6 +1876,9 @@ class CenterNetMetaArch(model.DetectionModel):
target_assigners[TRACK_TASK] = ( target_assigners[TRACK_TASK] = (
cn_assigner.CenterNetTrackTargetAssigner( cn_assigner.CenterNetTrackTargetAssigner(
stride, self._track_params.num_track_ids)) stride, self._track_params.num_track_ids))
if self._temporal_offset_params is not None:
target_assigners[TEMPORALOFFSET_TASK] = (
cn_assigner.CenterNetTemporalOffsetTargetAssigner(stride))
return target_assigners return target_assigners
...@@ -2394,6 +2467,54 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2394,6 +2467,54 @@ class CenterNetMetaArch(model.DetectionModel):
return loss_per_instance return loss_per_instance
def _compute_temporal_offset_loss(self, input_height,
                                  input_width, prediction_dict):
  """Computes the temporal offset loss for tracking.

  Args:
    input_height: An integer scalar tensor representing input image height.
    input_width: An integer scalar tensor representing input image width.
    prediction_dict: The dictionary returned from the predict() method.

  Returns:
    A dictionary with track/temporal_offset losses.
  """
  gt_boxes_list = self.groundtruth_lists(fields.BoxListFields.boxes)
  gt_offsets_list = self.groundtruth_lists(
      fields.BoxListFields.temporal_offsets)
  gt_match_list = self.groundtruth_lists(
      fields.BoxListFields.track_match_flags)
  gt_weights_list = self.groundtruth_lists(fields.BoxListFields.weights)
  # Normalizer: total number of (weighted) groundtruth instances in the
  # batch, used to average the loss across objects.
  num_boxes = tf.cast(
      get_num_instances_from_weights(gt_weights_list), tf.float32)

  # One prediction tensor per feature output (e.g. per hourglass stage); the
  # loss is averaged over all of them.
  offset_predictions = prediction_dict[TEMPORAL_OFFSET]
  num_predictions = float(len(offset_predictions))

  assigner = self._target_assigner_dict[TEMPORALOFFSET_TASK]
  (batch_indices, batch_offset_targets,
   batch_weights) = assigner.assign_temporal_offset_targets(
       height=input_height,
       width=input_width,
       gt_boxes_list=gt_boxes_list,
       gt_offsets_list=gt_offsets_list,
       gt_match_list=gt_match_list,
       gt_weights_list=gt_weights_list)
  # [num_boxes] -> [num_boxes, 1] so weights broadcast over the 2 offset
  # channels.
  batch_weights = tf.expand_dims(batch_weights, -1)

  offset_loss_fn = self._temporal_offset_params.localization_loss
  loss_dict = {}
  offset_loss = 0
  for offset_pred in offset_predictions:
    # Gather the predicted offsets at the assigned (batch, y, x) indices.
    offset_pred = cn_assigner.get_batch_predictions_from_indices(
        offset_pred, batch_indices)
    # The extra [:, None] axis matches the localization loss's expected
    # [batch, num_anchors, dims] layout.
    offset_loss += offset_loss_fn(offset_pred[:, None],
                                  batch_offset_targets[:, None],
                                  weights=batch_weights)
  offset_loss = tf.reduce_sum(offset_loss) / (num_predictions * num_boxes)

  loss_dict[TEMPORAL_OFFSET] = offset_loss
  return loss_dict
def preprocess(self, inputs): def preprocess(self, inputs):
outputs = shape_utils.resize_images_and_return_shapes( outputs = shape_utils.resize_images_and_return_shapes(
inputs, self._image_resizer_fn) inputs, self._image_resizer_fn)
...@@ -2490,6 +2611,7 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2490,6 +2611,7 @@ class CenterNetMetaArch(model.DetectionModel):
'Loss/densepose/heatmap', (optional) 'Loss/densepose/heatmap', (optional)
'Loss/densepose/regression', (optional) 'Loss/densepose/regression', (optional)
'Loss/track/reid'] (optional) 'Loss/track/reid'] (optional)
'Loss/track/offset'] (optional)
scalar tensors corresponding to the losses for different tasks. Note the scalar tensors corresponding to the losses for different tasks. Note the
$TASK_NAME is provided by the KeypointEstimation namedtuple used to $TASK_NAME is provided by the KeypointEstimation namedtuple used to
differentiate between different keypoint tasks. differentiate between different keypoint tasks.
...@@ -2567,6 +2689,16 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2567,6 +2689,16 @@ class CenterNetMetaArch(model.DetectionModel):
track_losses[key] * self._track_params.task_loss_weight) track_losses[key] * self._track_params.task_loss_weight)
losses.update(track_losses) losses.update(track_losses)
if self._temporal_offset_params is not None:
offset_losses = self._compute_temporal_offset_loss(
input_height=input_height,
input_width=input_width,
prediction_dict=prediction_dict)
for key in offset_losses:
offset_losses[key] = (
offset_losses[key] * self._temporal_offset_params.task_loss_weight)
losses.update(offset_losses)
# Prepend the LOSS_KEY_PREFIX to the keys in the dictionary such that the # Prepend the LOSS_KEY_PREFIX to the keys in the dictionary such that the
# losses will be grouped together in Tensorboard. # losses will be grouped together in Tensorboard.
return dict([('%s/%s' % (LOSS_KEY_PREFIX, key), val) return dict([('%s/%s' % (LOSS_KEY_PREFIX, key), val)
...@@ -2683,6 +2815,12 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -2683,6 +2815,12 @@ class CenterNetMetaArch(model.DetectionModel):
fields.DetectionResultFields.detection_embeddings: embeddings fields.DetectionResultFields.detection_embeddings: embeddings
}) })
if self._temporal_offset_params:
offsets = prediction_tensors_to_temporal_offsets(
y_indices, x_indices,
prediction_dict[TEMPORAL_OFFSET][-1])
postprocess_dict[fields.DetectionResultFields.detection_offsets] = offsets
return postprocess_dict return postprocess_dict
def _postprocess_embeddings(self, prediction_dict, y_indices, x_indices): def _postprocess_embeddings(self, prediction_dict, y_indices, x_indices):
......
...@@ -547,6 +547,53 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase): ...@@ -547,6 +547,53 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
np.testing.assert_allclose(scores[1][:1], [.9]) np.testing.assert_allclose(scores[1][:1], [.9])
np.testing.assert_allclose(scores[2], [1., .8]) np.testing.assert_allclose(scores[2], [1., .8])
def test_offset_prediction(self):
  """Tests gathering temporal offsets at top-k class-heatmap peaks."""
  class_pred = np.zeros((3, 128, 128, 5), dtype=np.float32)
  offset_pred = np.zeros((3, 128, 128, 2), dtype=np.float32)

  # Sample 1, 2 boxes
  class_pred[0, 10, 20] = [0.3, .7, 0.0, 0.0, 0.0]
  offset_pred[0, 10, 20] = [1, 2]
  class_pred[0, 50, 60] = [0.55, 0.0, 0.0, 0.0, 0.45]
  offset_pred[0, 50, 60] = [0, 0]

  # Sample 2, 2 boxes (at same location)
  class_pred[1, 100, 100] = [0.0, 0.1, 0.9, 0.0, 0.0]
  offset_pred[1, 100, 100] = [1, 3]

  # Sample 3, 3 boxes
  class_pred[2, 60, 90] = [0.0, 0.0, 0.0, 0.2, 0.8]
  offset_pred[2, 60, 90] = [0, 0]
  class_pred[2, 65, 95] = [0.0, 0.7, 0.3, 0.0, 0.0]
  offset_pred[2, 65, 95] = [1, 2]
  class_pred[2, 75, 85] = [1.0, 0.0, 0.0, 0.0, 0.0]
  offset_pred[2, 75, 85] = [5, 2]

  def graph_fn():
    class_pred_tensor = tf.constant(class_pred)
    offset_pred_tensor = tf.constant(offset_pred)

    # k=2 peaks per image; offsets are then gathered at those peaks.
    _, y_indices, x_indices, _ = (
        cnma.top_k_feature_map_locations(
            class_pred_tensor, max_pool_kernel_size=3, k=2))

    offsets = cnma.prediction_tensors_to_temporal_offsets(
        y_indices, x_indices, offset_pred_tensor)
    return offsets
  offsets = self.execute(graph_fn, [])

  np.testing.assert_allclose(
      [[1, 2], [0, 0]], offsets[0])
  # Sample 2 has a single peak, so both of the k=2 slots gather the same
  # offset.
  np.testing.assert_allclose(
      [[1, 3], [1, 3]], offsets[1])
  np.testing.assert_allclose(
      [[5, 2], [0, 0]], offsets[2])
def test_keypoint_candidate_prediction(self): def test_keypoint_candidate_prediction(self):
keypoint_heatmap_np = np.zeros((2, 3, 3, 2), dtype=np.float32) keypoint_heatmap_np = np.zeros((2, 3, 3, 2), dtype=np.float32)
keypoint_heatmap_np[0, 0, 0, 0] = 1.0 keypoint_heatmap_np[0, 0, 0, 0] = 1.0
...@@ -1156,6 +1203,13 @@ def get_fake_track_params(): ...@@ -1156,6 +1203,13 @@ def get_fake_track_params():
task_loss_weight=1.0) task_loss_weight=1.0)
def get_fake_temporal_offset_params():
  """Returns the fake temporal offset parameter namedtuple."""
  localization_loss = losses.WeightedSmoothL1LocalizationLoss()
  return cnma.TemporalOffsetParams(
      localization_loss=localization_loss,
      task_loss_weight=1.0)
def build_center_net_meta_arch(build_resnet=False): def build_center_net_meta_arch(build_resnet=False):
"""Builds the CenterNet meta architecture.""" """Builds the CenterNet meta architecture."""
if build_resnet: if build_resnet:
...@@ -1185,7 +1239,8 @@ def build_center_net_meta_arch(build_resnet=False): ...@@ -1185,7 +1239,8 @@ def build_center_net_meta_arch(build_resnet=False):
keypoint_params_dict={_TASK_NAME: get_fake_kp_params()}, keypoint_params_dict={_TASK_NAME: get_fake_kp_params()},
mask_params=get_fake_mask_params(), mask_params=get_fake_mask_params(),
densepose_params=get_fake_densepose_params(), densepose_params=get_fake_densepose_params(),
track_params=get_fake_track_params()) track_params=get_fake_track_params(),
temporal_offset_params=get_fake_temporal_offset_params())
def _logit(p): def _logit(p):
...@@ -1284,6 +1339,11 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -1284,6 +1339,11 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
fake_feature_map) fake_feature_map)
self.assertEqual((4, 128, 128, _REID_EMBED_SIZE), output.shape) self.assertEqual((4, 128, 128, _REID_EMBED_SIZE), output.shape)
# "temporal offset" head:
output = model._prediction_head_dict[cnma.TEMPORAL_OFFSET][-1](
fake_feature_map)
self.assertEqual((4, 128, 128, 2), output.shape)
def test_initialize_target_assigners(self): def test_initialize_target_assigners(self):
model = build_center_net_meta_arch() model = build_center_net_meta_arch()
assigner_dict = model._initialize_target_assigners( assigner_dict = model._initialize_target_assigners(
...@@ -1315,6 +1375,10 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -1315,6 +1375,10 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
self.assertIsInstance(assigner_dict[cnma.TRACK_TASK], self.assertIsInstance(assigner_dict[cnma.TRACK_TASK],
cn_assigner.CenterNetTrackTargetAssigner) cn_assigner.CenterNetTrackTargetAssigner)
# Temporal Offset target assigner:
self.assertIsInstance(assigner_dict[cnma.TEMPORALOFFSET_TASK],
cn_assigner.CenterNetTemporalOffsetTargetAssigner)
def test_predict(self): def test_predict(self):
"""Test the predict function.""" """Test the predict function."""
...@@ -1341,6 +1405,8 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -1341,6 +1405,8 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
(2, 32, 32, 2 * _DENSEPOSE_NUM_PARTS)) (2, 32, 32, 2 * _DENSEPOSE_NUM_PARTS))
self.assertEqual(prediction_dict[cnma.TRACK_REID][0].shape, self.assertEqual(prediction_dict[cnma.TRACK_REID][0].shape,
(2, 32, 32, _REID_EMBED_SIZE)) (2, 32, 32, _REID_EMBED_SIZE))
self.assertEqual(prediction_dict[cnma.TEMPORAL_OFFSET][0].shape,
(2, 32, 32, 2))
def test_loss(self): def test_loss(self):
"""Test the loss function.""" """Test the loss function."""
...@@ -1361,7 +1427,11 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -1361,7 +1427,11 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
groundtruth_dp_surface_coords_list=groundtruth_dict[ groundtruth_dp_surface_coords_list=groundtruth_dict[
fields.BoxListFields.densepose_surface_coords], fields.BoxListFields.densepose_surface_coords],
groundtruth_track_ids_list=groundtruth_dict[ groundtruth_track_ids_list=groundtruth_dict[
fields.BoxListFields.track_ids]) fields.BoxListFields.track_ids],
groundtruth_track_match_flags_list=groundtruth_dict[
fields.BoxListFields.track_match_flags],
groundtruth_temporal_offsets_list=groundtruth_dict[
fields.BoxListFields.temporal_offsets])
kernel_initializer = tf.constant_initializer( kernel_initializer = tf.constant_initializer(
[[1, 1, 0], [-1000000, -1000000, 1000000]]) [[1, 1, 0], [-1000000, -1000000, 1000000]])
...@@ -1413,6 +1483,9 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -1413,6 +1483,9 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
self.assertGreater( self.assertGreater(
0.01, loss_dict['%s/%s' % (cnma.LOSS_KEY_PREFIX, 0.01, loss_dict['%s/%s' % (cnma.LOSS_KEY_PREFIX,
cnma.TRACK_REID)]) cnma.TRACK_REID)])
self.assertGreater(
0.01, loss_dict['%s/%s' % (cnma.LOSS_KEY_PREFIX,
cnma.TEMPORAL_OFFSET)])
@parameterized.parameters( @parameterized.parameters(
{'target_class_id': 1}, {'target_class_id': 1},
...@@ -1463,6 +1536,9 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -1463,6 +1536,9 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
dtype=np.float32) dtype=np.float32)
track_reid_embedding[0, 16, 16, :] = np.ones(embedding_size) track_reid_embedding[0, 16, 16, :] = np.ones(embedding_size)
temporal_offsets = np.zeros((1, 32, 32, 2), dtype=np.float32)
temporal_offsets[..., 1] = 1
class_center = tf.constant(class_center) class_center = tf.constant(class_center)
height_width = tf.constant(height_width) height_width = tf.constant(height_width)
offset = tf.constant(offset) offset = tf.constant(offset)
...@@ -1473,6 +1549,7 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -1473,6 +1549,7 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
dp_part_heatmap = tf.constant(dp_part_heatmap, dtype=tf.float32) dp_part_heatmap = tf.constant(dp_part_heatmap, dtype=tf.float32)
dp_surf_coords = tf.constant(dp_surf_coords, dtype=tf.float32) dp_surf_coords = tf.constant(dp_surf_coords, dtype=tf.float32)
track_reid_embedding = tf.constant(track_reid_embedding, dtype=tf.float32) track_reid_embedding = tf.constant(track_reid_embedding, dtype=tf.float32)
temporal_offsets = tf.constant(temporal_offsets, dtype=tf.float32)
prediction_dict = { prediction_dict = {
cnma.OBJECT_CENTER: [class_center], cnma.OBJECT_CENTER: [class_center],
...@@ -1487,7 +1564,8 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -1487,7 +1564,8 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
cnma.SEGMENTATION_HEATMAP: [segmentation_heatmap], cnma.SEGMENTATION_HEATMAP: [segmentation_heatmap],
cnma.DENSEPOSE_HEATMAP: [dp_part_heatmap], cnma.DENSEPOSE_HEATMAP: [dp_part_heatmap],
cnma.DENSEPOSE_REGRESSION: [dp_surf_coords], cnma.DENSEPOSE_REGRESSION: [dp_surf_coords],
cnma.TRACK_REID: [track_reid_embedding] cnma.TRACK_REID: [track_reid_embedding],
cnma.TEMPORAL_OFFSET: [temporal_offsets],
} }
def graph_fn(): def graph_fn():
...@@ -1519,6 +1597,8 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -1519,6 +1597,8 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
detections['detection_masks'].shape) detections['detection_masks'].shape)
self.assertAllEqual([1, max_detection, embedding_size], self.assertAllEqual([1, max_detection, embedding_size],
detections['detection_embeddings'].shape) detections['detection_embeddings'].shape)
self.assertAllEqual([1, max_detection, 2],
detections['detection_temporal_offsets'].shape)
# Masks should be empty for everything but the first detection. # Masks should be empty for everything but the first detection.
self.assertAllEqual( self.assertAllEqual(
...@@ -1632,6 +1712,10 @@ def get_fake_prediction_dict(input_height, input_width, stride): ...@@ -1632,6 +1712,10 @@ def get_fake_prediction_dict(input_height, input_width, stride):
_REID_EMBED_SIZE), dtype=np.float32) _REID_EMBED_SIZE), dtype=np.float32)
track_reid_embedding[0, 2, 4, :] = np.arange(_REID_EMBED_SIZE) track_reid_embedding[0, 2, 4, :] = np.arange(_REID_EMBED_SIZE)
temporal_offsets = np.zeros((2, output_height, output_width, 2),
dtype=np.float32)
temporal_offsets[0, 2, 4, :] = 5
prediction_dict = { prediction_dict = {
'preprocessed_inputs': 'preprocessed_inputs':
tf.zeros((2, input_height, input_width, 3)), tf.zeros((2, input_height, input_width, 3)),
...@@ -1674,7 +1758,11 @@ def get_fake_prediction_dict(input_height, input_width, stride): ...@@ -1674,7 +1758,11 @@ def get_fake_prediction_dict(input_height, input_width, stride):
cnma.TRACK_REID: [ cnma.TRACK_REID: [
tf.constant(track_reid_embedding), tf.constant(track_reid_embedding),
tf.constant(track_reid_embedding), tf.constant(track_reid_embedding),
] ],
cnma.TEMPORAL_OFFSET: [
tf.constant(temporal_offsets),
tf.constant(temporal_offsets),
],
} }
return prediction_dict return prediction_dict
...@@ -1736,6 +1824,14 @@ def get_fake_groundtruth_dict(input_height, input_width, stride): ...@@ -1736,6 +1824,14 @@ def get_fake_groundtruth_dict(input_height, input_width, stride):
tf.constant([2], dtype=tf.int32), tf.constant([2], dtype=tf.int32),
tf.constant([1], dtype=tf.int32), tf.constant([1], dtype=tf.int32),
] ]
temporal_offsets = [
tf.constant([[5.0, 5.0]], dtype=tf.float32),
tf.constant([[2.0, 3.0]], dtype=tf.float32),
]
track_match_flags = [
tf.constant([1.0], dtype=tf.float32),
tf.constant([1.0], dtype=tf.float32),
]
groundtruth_dict = { groundtruth_dict = {
fields.BoxListFields.boxes: boxes, fields.BoxListFields.boxes: boxes,
fields.BoxListFields.weights: weights, fields.BoxListFields.weights: weights,
...@@ -1747,6 +1843,8 @@ def get_fake_groundtruth_dict(input_height, input_width, stride): ...@@ -1747,6 +1843,8 @@ def get_fake_groundtruth_dict(input_height, input_width, stride):
fields.BoxListFields.densepose_surface_coords: fields.BoxListFields.densepose_surface_coords:
densepose_surface_coords, densepose_surface_coords,
fields.BoxListFields.track_ids: track_ids, fields.BoxListFields.track_ids: track_ids,
fields.BoxListFields.temporal_offsets: temporal_offsets,
fields.BoxListFields.track_match_flags: track_match_flags,
fields.InputDataFields.groundtruth_labeled_classes: labeled_classes, fields.InputDataFields.groundtruth_labeled_classes: labeled_classes,
} }
return groundtruth_dict return groundtruth_dict
......
...@@ -245,6 +245,21 @@ message CenterNet { ...@@ -245,6 +245,21 @@ message CenterNet {
} }
optional TrackEstimation track_estimation_task = 10; optional TrackEstimation track_estimation_task = 10;
// Parameters for the temporal offset prediction head, similar to
// CenterTrack ("Tracking Objects as Points" [3]). The head regresses the
// 2-d offset of each object center relative to the previous frame, which
// is used to associate detections across frames for tracking.
// NOTE: this implementation differs from the original paper (see
// go/lstd-centernet for details — presumably the LSTM-based variant;
// confirm against the meta-architecture code).
// [3]: https://arxiv.org/abs/2004.01177
message TemporalOffsetEstimation {
  // Weight of the task loss. The total loss of the model is the
  // summation of all task losses weighted by their weights.
  optional float task_loss_weight = 1 [default = 1.0];

  // Localization loss configuration for the temporal offset regression.
  optional LocalizationLoss localization_loss = 2;
}
optional TemporalOffsetEstimation temporal_offset_task = 12;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment