Commit 47bc1813 authored by syiming

Merge remote-tracking branch 'upstream/master' into add_multilevel_crop_and_resize

parents d8611151 b035a227
......@@ -24,10 +24,6 @@ from six.moves import range
import tensorflow.compat.v1 as tf
from object_detection.core import prefetcher
from object_detection.utils import tf_version
if not tf_version.is_tf1():
raise ValueError('`batcher.py` is only supported in Tensorflow 1.X')
rt_shape_str = '_runtime_shapes'
......
......@@ -19,14 +19,17 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import numpy as np
from six.moves import range
import tensorflow.compat.v1 as tf
import tf_slim as slim
from object_detection.core import batcher
from object_detection.utils import tf_version
@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
class BatcherTest(tf.test.TestCase):
def test_batch_and_unpad_2d_tensors_of_different_sizes_in_1st_dimension(self):
......
......@@ -17,15 +17,17 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import numpy as np
from six.moves import zip
import tensorflow.compat.v1 as tf
from object_detection.core import freezable_batch_norm
from object_detection.utils import tf_version
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class FreezableBatchNormTest(tf.test.TestCase):
"""Tests for FreezableBatchNorm operations."""
......
......@@ -217,7 +217,7 @@ def to_absolute_coordinates(keypoints, height, width,
return scale(keypoints, height, width)
def flip_horizontal(keypoints, flip_point, flip_permutation, scope=None):
def flip_horizontal(keypoints, flip_point, flip_permutation=None, scope=None):
"""Flips the keypoints horizontally around the flip_point.
This operation flips the x coordinate for each keypoint around the flip_point
......@@ -227,13 +227,14 @@ def flip_horizontal(keypoints, flip_point, flip_permutation, scope=None):
keypoints: a tensor of shape [num_instances, num_keypoints, 2]
flip_point: (float) scalar tensor representing the x coordinate to flip the
keypoints around.
flip_permutation: rank 1 int32 tensor containing the keypoint flip
permutation. This specifies the mapping from original keypoint indices
to the flipped keypoint indices. This is used primarily for keypoints
that are not reflection invariant. E.g. Suppose there are 3 keypoints
representing ['head', 'right_eye', 'left_eye'], then a logical choice for
flip_permutation might be [0, 2, 1] since we want to swap the 'left_eye'
and 'right_eye' after a horizontal flip.
flip_permutation: integer list or rank 1 int32 tensor containing the
keypoint flip permutation. This specifies the mapping from original
keypoint indices to the flipped keypoint indices. This is used primarily
for keypoints that are not reflection invariant. E.g. Suppose there are 3
keypoints representing ['head', 'right_eye', 'left_eye'], then a logical
choice for flip_permutation might be [0, 2, 1] since we want to swap the
'left_eye' and 'right_eye' after a horizontal flip.
Defaults to None or an empty list, which keeps the original keypoint order
after the flip.
scope: name scope.
Returns:
......@@ -241,7 +242,8 @@ def flip_horizontal(keypoints, flip_point, flip_permutation, scope=None):
"""
with tf.name_scope(scope, 'FlipHorizontal'):
keypoints = tf.transpose(keypoints, [1, 0, 2])
keypoints = tf.gather(keypoints, flip_permutation)
if flip_permutation:
keypoints = tf.gather(keypoints, flip_permutation)
v, u = tf.split(value=keypoints, num_or_size_splits=2, axis=2)
u = flip_point * 2.0 - u
new_keypoints = tf.concat([v, u], 2)
......@@ -249,7 +251,7 @@ def flip_horizontal(keypoints, flip_point, flip_permutation, scope=None):
return new_keypoints
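With the new optional argument, passing flip_permutation reorders the keypoints before mirroring, while the new default (None) only mirrors the x coordinates. A minimal usage sketch, illustrative only and not part of this diff:

# Illustrative sketch: flip around x = 0.5 while swapping keypoints 1 and 2
# via the permutation [0, 2, 1].
import tensorflow.compat.v1 as tf
from object_detection.core import keypoint_ops

keypoints = tf.constant([[[0.1, 0.1], [0.2, 0.2], [0.3, 0.3]]])
flipped = keypoint_ops.flip_horizontal(keypoints, 0.5,
                                       flip_permutation=[0, 2, 1])
# Expected result: [[[0.1, 0.9], [0.3, 0.7], [0.2, 0.8]]] -- keypoints 1 and 2
# are swapped and every x coordinate becomes 2 * 0.5 - x.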
def flip_vertical(keypoints, flip_point, flip_permutation, scope=None):
def flip_vertical(keypoints, flip_point, flip_permutation=None, scope=None):
"""Flips the keypoints vertically around the flip_point.
This operation flips the y coordinate for each keypoint around the flip_point
......@@ -259,13 +261,14 @@ def flip_vertical(keypoints, flip_point, flip_permutation, scope=None):
keypoints: a tensor of shape [num_instances, num_keypoints, 2]
flip_point: (float) scalar tensor representing the y coordinate to flip the
keypoints around.
flip_permutation: rank 1 int32 tensor containing the keypoint flip
permutation. This specifies the mapping from original keypoint indices
to the flipped keypoint indices. This is used primarily for keypoints
that are not reflection invariant. E.g. Suppose there are 3 keypoints
representing ['head', 'right_eye', 'left_eye'], then a logical choice for
flip_permutation might be [0, 2, 1] since we want to swap the 'left_eye'
and 'right_eye' after a horizontal flip.
flip_permutation: integer list or rank 1 int32 tensor containing the
keypoint flip permutation. This specifies the mapping from original
keypoint indices to the flipped keypoint indices. This is used primarily
for keypoints that are not reflection invariant. E.g. Suppose there are 3
keypoints representing ['head', 'right_eye', 'left_eye'], then a logical
choice for flip_permutation might be [0, 2, 1] since we want to swap the
'left_eye' and 'right_eye' after a horizontal flip.
Defaults to None or an empty list, which keeps the original keypoint order
after the flip.
scope: name scope.
Returns:
......@@ -273,7 +276,8 @@ def flip_vertical(keypoints, flip_point, flip_permutation, scope=None):
"""
with tf.name_scope(scope, 'FlipVertical'):
keypoints = tf.transpose(keypoints, [1, 0, 2])
keypoints = tf.gather(keypoints, flip_permutation)
if flip_permutation:
keypoints = tf.gather(keypoints, flip_permutation)
v, u = tf.split(value=keypoints, num_or_size_splits=2, axis=2)
v = flip_point * 2.0 - v
new_keypoints = tf.concat([v, u], 2)
......@@ -281,18 +285,24 @@ def flip_vertical(keypoints, flip_point, flip_permutation, scope=None):
return new_keypoints
def rot90(keypoints, scope=None):
def rot90(keypoints, rotation_permutation=None, scope=None):
"""Rotates the keypoints counter-clockwise by 90 degrees.
Args:
keypoints: a tensor of shape [num_instances, num_keypoints, 2]
rotation_permutation: integer list or rank 1 int32 tensor containing the
keypoint rotation permutation. This specifies the mapping from original
keypoint indices to the rotated keypoint indices. This is used primarily
for keypoints that are not rotation invariant.
Defaults to None or an empty list, which keeps the original keypoint order
after the rotation.
scope: name scope.
Returns:
new_keypoints: a tensor of shape [num_instances, num_keypoints, 2]
"""
with tf.name_scope(scope, 'Rot90'):
keypoints = tf.transpose(keypoints, [1, 0, 2])
if rotation_permutation:
keypoints = tf.gather(keypoints, rotation_permutation)
v, u = tf.split(value=keypoints[:, :, ::-1], num_or_size_splits=2, axis=2)
v = 1.0 - v
new_keypoints = tf.concat([v, u], 2)
......
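Analogously, rot90 now optionally reorders keypoints before rotating. A minimal sketch mirroring the new test_rot90_permutation below (illustrative only, not part of this diff):

# Illustrative sketch: rotate 90 degrees counter-clockwise while swapping
# keypoints 1 and 2.
import tensorflow.compat.v1 as tf
from object_detection.core import keypoint_ops

keypoints = tf.constant([[[0.1, 0.1], [0.2, 0.2], [0.3, 0.3]]])
rotated = keypoint_ops.rot90(keypoints, rotation_permutation=[0, 2, 1])
# Expected result: [[[0.9, 0.1], [0.7, 0.3], [0.8, 0.2]]] -- after the
# permutation, each (y, x) maps to (1 - x, y).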
......@@ -180,6 +180,21 @@ class KeypointOpsTest(test_case.TestCase):
[[0.1, 0.1], [0.2, 0.2], [0.3, 0.3]],
[[0.4, 0.4], [0.5, 0.5], [0.6, 0.6]]
])
expected_keypoints = tf.constant([
[[0.1, 0.9], [0.2, 0.8], [0.3, 0.7]],
[[0.4, 0.6], [0.5, 0.5], [0.6, 0.4]],
])
output = keypoint_ops.flip_horizontal(keypoints, 0.5)
return output, expected_keypoints
output, expected_keypoints = self.execute(graph_fn, [])
self.assertAllClose(output, expected_keypoints)
def test_flip_horizontal_permutation(self):
def graph_fn():
keypoints = tf.constant([[[0.1, 0.1], [0.2, 0.2], [0.3, 0.3]],
[[0.4, 0.4], [0.5, 0.5], [0.6, 0.6]]])
flip_permutation = [0, 2, 1]
expected_keypoints = tf.constant([
......@@ -197,6 +212,22 @@ class KeypointOpsTest(test_case.TestCase):
[[0.1, 0.1], [0.2, 0.2], [0.3, 0.3]],
[[0.4, 0.4], [0.5, 0.5], [0.6, 0.6]]
])
expected_keypoints = tf.constant([
[[0.9, 0.1], [0.8, 0.2], [0.7, 0.3]],
[[0.6, 0.4], [0.5, 0.5], [0.4, 0.6]],
])
output = keypoint_ops.flip_vertical(keypoints, 0.5)
return output, expected_keypoints
output, expected_keypoints = self.execute(graph_fn, [])
self.assertAllClose(output, expected_keypoints)
def test_flip_vertical_permutation(self):
def graph_fn():
keypoints = tf.constant([[[0.1, 0.1], [0.2, 0.2], [0.3, 0.3]],
[[0.4, 0.4], [0.5, 0.5], [0.6, 0.6]]])
flip_permutation = [0, 2, 1]
expected_keypoints = tf.constant([
......@@ -223,6 +254,23 @@ class KeypointOpsTest(test_case.TestCase):
output, expected_keypoints = self.execute(graph_fn, [])
self.assertAllClose(output, expected_keypoints)
def test_rot90_permutation(self):
def graph_fn():
keypoints = tf.constant([[[0.1, 0.1], [0.2, 0.2], [0.3, 0.3]],
[[0.4, 0.6], [0.5, 0.6], [0.6, 0.7]]])
rot_permutation = [0, 2, 1]
expected_keypoints = tf.constant([
[[0.9, 0.1], [0.7, 0.3], [0.8, 0.2]],
[[0.4, 0.4], [0.3, 0.6], [0.4, 0.5]],
])
output = keypoint_ops.rot90(keypoints,
rotation_permutation=rot_permutation)
return output, expected_keypoints
output, expected_keypoints = self.execute(graph_fn, [])
self.assertAllClose(output, expected_keypoints)
def test_keypoint_weights_from_visibilities(self):
def graph_fn():
keypoint_visibilities = tf.constant([
......
......@@ -681,3 +681,95 @@ class HardExampleMiner(object):
num_positives, num_negatives)
class PenaltyReducedLogisticFocalLoss(Loss):
"""Penalty-reduced pixelwise logistic regression with focal loss.
The loss is defined in Equation (1) of the Objects as Points[1] paper.
Although the loss is defined per-pixel in the output space, this class
assumes that each pixel is an anchor to be compatible with the base class.
[1]: https://arxiv.org/abs/1904.07850
"""
def __init__(self, alpha=2.0, beta=4.0, sigmoid_clip_value=1e-4):
"""Constructor.
Args:
alpha: Focussing parameter of the focal loss. Increasing this will
decrease the loss contribution of the well classified examples.
beta: The local penalty reduction factor. Increasing this will decrease
the contribution of loss due to negative pixels near the keypoint.
sigmoid_clip_value: The output of the internal sigmoid operation is clipped
to the range [sigmoid_clip_value, 1 - sigmoid_clip_value].
"""
self._alpha = alpha
self._beta = beta
self._sigmoid_clip_value = sigmoid_clip_value
super(PenaltyReducedLogisticFocalLoss, self).__init__()
def _compute_loss(self, prediction_tensor, target_tensor, weights):
"""Compute loss function.
In all input tensors, `num_anchors` is the total number of pixels in the
output space.
Args:
prediction_tensor: A float tensor of shape [batch_size, num_anchors,
num_classes] representing the predicted unscaled logits for each class.
The function will compute sigmoid on this tensor internally.
target_tensor: A float tensor of shape [batch_size, num_anchors,
num_classes] representing the 'splatted' keypoint targets, possibly
generated with a Gaussian kernel. This function assumes that the targets
lie in the range [0, 1].
weights: a float tensor of shape either [batch_size, num_anchors,
num_classes] or [batch_size, num_anchors, 1]. If the shape is
[batch_size, num_anchors, 1], all the classes are equally weighted.
Returns:
loss: a float tensor of shape [batch_size, num_anchors, num_classes]
representing the value of the loss function.
"""
is_present_tensor = tf.math.equal(target_tensor, 1.0)
prediction_tensor = tf.clip_by_value(tf.sigmoid(prediction_tensor),
self._sigmoid_clip_value,
1 - self._sigmoid_clip_value)
positive_loss = (tf.math.pow((1 - prediction_tensor), self._alpha)*
tf.math.log(prediction_tensor))
negative_loss = (tf.math.pow((1 - target_tensor), self._beta)*
tf.math.pow(prediction_tensor, self._alpha)*
tf.math.log(1 - prediction_tensor))
loss = -tf.where(is_present_tensor, positive_loss, negative_loss)
return loss * weights
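For reference, the per-pixel computation above can be written standalone. The following is an illustrative sketch of Equation (1) with the same clipping, not part of this diff and not the class's public API:

# Illustrative sketch: per-pixel penalty-reduced logistic focal loss with
# alpha=2, beta=4, mirroring _compute_loss above.
import tensorflow.compat.v1 as tf

def penalty_reduced_focal(logits, targets, alpha=2.0, beta=4.0, clip=1e-4):
  probs = tf.clip_by_value(tf.sigmoid(logits), clip, 1.0 - clip)
  positive = tf.pow(1.0 - probs, alpha) * tf.log(probs)
  negative = (tf.pow(1.0 - targets, beta) * tf.pow(probs, alpha) *
              tf.log(1.0 - probs))
  # Use the positive branch where the target is exactly 1, else the negative.
  return -tf.where(tf.equal(targets, 1.0), positive, negative)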
class L1LocalizationLoss(Loss):
"""L1 loss or absolute difference.
When used in a per-pixel manner, each pixel should be given as an anchor.
"""
def _compute_loss(self, prediction_tensor, target_tensor, weights):
"""Compute loss function.
Args:
prediction_tensor: A float tensor of shape [batch_size, num_anchors]
representing the (encoded) predicted locations of objects.
target_tensor: A float tensor of shape [batch_size, num_anchors]
representing the regression targets
weights: a float tensor of shape [batch_size, num_anchors]
Returns:
loss: a float tensor of shape [batch_size, num_anchors] representing the
value of the loss function.
"""
return tf.losses.absolute_difference(
target_tensor,
prediction_tensor,
weights=weights,
loss_collection=None,
reduction=tf.losses.Reduction.NONE
)
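A minimal sketch of the element-wise behavior (illustrative only; it calls the private _compute_loss purely for demonstration and assumes it runs in the scope of this losses module):

# Illustrative sketch: L1LocalizationLoss returns the element-wise absolute
# error with no reduction.
loss_fn = L1LocalizationLoss()
per_anchor_loss = loss_fn._compute_loss(
    prediction_tensor=tf.constant([[1.0, 2.0]]),
    target_tensor=tf.constant([[1.5, 2.0]]),
    weights=tf.constant([[1.0, 1.0]]))
# Expected result: [[0.5, 0.0]]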
......@@ -391,7 +391,9 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
pass
@abc.abstractmethod
def restore_map(self, fine_tune_checkpoint_type='detection'):
def restore_map(self,
fine_tune_checkpoint_type='detection',
load_all_detection_checkpoint_vars=False):
"""Returns a map of variables to load from a foreign checkpoint.
Returns a map of variable names to load from a checkpoint to variables in
......@@ -407,6 +409,9 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
checkpoint (with compatible variable names) or to restore from a
classification checkpoint for initialization prior to training.
Valid values: `detection`, `classification`. Default 'detection'.
load_all_detection_checkpoint_vars: whether to load all variables (when
`fine_tune_checkpoint_type` is `detection`). If False, only variables
within the feature extractor scope are included. Default False.
Returns:
A dict mapping variable names (to load from a checkpoint) to variables in
......@@ -414,6 +419,36 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
"""
pass
@abc.abstractmethod
def restore_from_objects(self, fine_tune_checkpoint_type='detection'):
"""Returns a map of variables to load from a foreign checkpoint.
Returns a dictionary of Tensorflow 2 Trackable objects (e.g. tf.Module
or Checkpoint). This enables the model to initialize based on weights from
another task. For example, the feature extractor variables from a
classification model can be used to bootstrap training of an object
detector. When loading from an object detection model, the checkpoint model
should have the same parameters as this detection model, with the exception
of the num_classes parameter.
Note that this function is intended to be used to restore Keras-based
models when running Tensorflow 2, whereas restore_map (above) is intended
to be used to restore Slim-based models when running Tensorflow 1.x.
TODO(jonathanhuang,rathodv): Check tf_version and raise unimplemented
error for both restore_map and restore_from_objects depending on version.
Args:
fine_tune_checkpoint_type: whether to restore from a full detection
checkpoint (with compatible variable names) or to restore from a
classification checkpoint for initialization prior to training.
Valid values: `detection`, `classification`. Default 'detection'.
Returns:
A dict mapping keys to Trackable objects (tf.Module or Checkpoint).
"""
pass
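A hypothetical sketch of how a concrete subclass might satisfy this contract (illustrative only; the attribute names are assumptions and not part of this interface):

# Illustrative, hypothetical implementation sketch (not part of this diff).
def restore_from_objects(self, fine_tune_checkpoint_type='detection'):
  if fine_tune_checkpoint_type == 'classification':
    # Restore only backbone weights from a classification checkpoint.
    return {'feature_extractor': self._feature_extractor}
  # Restore the full detection model from a detection checkpoint.
  return {'model': self}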
@abc.abstractmethod
def updates(self):
"""Returns a list of update operators for this model.
......
......@@ -57,6 +57,9 @@ class FakeModel(model.DetectionModel):
def restore_map(self):
return {}
def restore_from_objects(self, fine_tune_checkpoint_type):
pass
def regularization_losses(self):
return []
......
......@@ -16,10 +16,6 @@
"""Provides functions to prefetch tensors to feed into models."""
import tensorflow.compat.v1 as tf
from object_detection.utils import tf_version
if not tf_version.is_tf1():
raise ValueError('`prefetcher.py` is only supported in Tensorflow 1.X')
def prefetch(tensor_dict, capacity):
"""Creates a prefetch queue for tensors.
......
......@@ -18,16 +18,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
from six.moves import range
import tensorflow.compat.v1 as tf
# pylint: disable=g-bad-import-order,
from object_detection.core import prefetcher
import tf_slim as slim
# pylint: disable=g-bad-import-order
from object_detection.core import prefetcher
from object_detection.utils import tf_version
@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
class PrefetcherTest(tf.test.TestCase):
"""Test class for prefetcher."""
......
......@@ -569,12 +569,11 @@ def random_horizontal_flip(image,
keypoints=None,
keypoint_visibilities=None,
keypoint_flip_permutation=None,
probability=0.5,
seed=None,
preprocess_vars_cache=None):
"""Randomly flips the image and detections horizontally.
The probability of flipping the image is 50%.
Args:
image: rank 3 float32 tensor with shape [height, width, channels].
boxes: (optional) rank 2 float32 tensor with shape [N, 4]
......@@ -592,6 +591,7 @@ def random_horizontal_flip(image,
[num_instances, num_keypoints].
keypoint_flip_permutation: rank 1 int32 tensor containing the keypoint flip
permutation.
probability: the probability of performing this augmentation.
seed: random seed
preprocess_vars_cache: PreprocessorCache object that records previously
performed augmentations. Updated in-place. If this
......@@ -636,7 +636,7 @@ def random_horizontal_flip(image,
generator_func,
preprocessor_cache.PreprocessorCache.HORIZONTAL_FLIP,
preprocess_vars_cache)
do_a_flip_random = tf.greater(do_a_flip_random, 0.5)
do_a_flip_random = tf.less(do_a_flip_random, probability)
# flip image
image = tf.cond(do_a_flip_random, lambda: _flip_image(image), lambda: image)
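The new probability argument can be set through the usual preprocess options, for example (illustrative values only, mirroring the test style later in this diff):

# Illustrative sketch: a 20% chance of flipping, with a keypoint permutation
# that swaps keypoints 1 and 2.
preprocess_options = [(preprocessor.random_horizontal_flip, {
    'probability': 0.2,
    'keypoint_flip_permutation': [0, 2, 1],
})]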
......@@ -682,6 +682,7 @@ def random_vertical_flip(image,
masks=None,
keypoints=None,
keypoint_flip_permutation=None,
probability=0.5,
seed=None,
preprocess_vars_cache=None):
"""Randomly flips the image and detections vertically.
......@@ -703,6 +704,7 @@ def random_vertical_flip(image,
normalized coordinates.
keypoint_flip_permutation: rank 1 int32 tensor containing the keypoint flip
permutation.
probability: the probability of performing this augmentation.
seed: random seed
preprocess_vars_cache: PreprocessorCache object that records previously
performed augmentations. Updated in-place. If this
......@@ -743,7 +745,7 @@ def random_vertical_flip(image,
do_a_flip_random = _get_or_create_preprocess_rand_vars(
generator_func, preprocessor_cache.PreprocessorCache.VERTICAL_FLIP,
preprocess_vars_cache)
do_a_flip_random = tf.greater(do_a_flip_random, 0.5)
do_a_flip_random = tf.less(do_a_flip_random, probability)
# flip image
image = tf.cond(do_a_flip_random, lambda: _flip_image(image), lambda: image)
......@@ -777,6 +779,8 @@ def random_rotation90(image,
boxes=None,
masks=None,
keypoints=None,
keypoint_rot_permutation=None,
probability=0.5,
seed=None,
preprocess_vars_cache=None):
"""Randomly rotates the image and detections 90 degrees counter-clockwise.
......@@ -799,6 +803,9 @@ def random_rotation90(image,
keypoints: (optional) rank 3 float32 tensor with shape
[num_instances, num_keypoints, 2]. The keypoints are in y-x
normalized coordinates.
keypoint_rot_permutation: rank 1 int32 tensor containing the keypoint
rotation permutation.
probability: the probability of performing this augmentation.
seed: random seed
preprocess_vars_cache: PreprocessorCache object that records previously
performed augmentations. Updated in-place. If this
......@@ -833,7 +840,7 @@ def random_rotation90(image,
do_a_rot90_random = _get_or_create_preprocess_rand_vars(
generator_func, preprocessor_cache.PreprocessorCache.ROTATION90,
preprocess_vars_cache)
do_a_rot90_random = tf.greater(do_a_rot90_random, 0.5)
do_a_rot90_random = tf.less(do_a_rot90_random, probability)
# flip image
image = tf.cond(do_a_rot90_random, lambda: _rot90_image(image),
......@@ -856,7 +863,7 @@ def random_rotation90(image,
if keypoints is not None:
keypoints = tf.cond(
do_a_rot90_random,
lambda: keypoint_ops.rot90(keypoints),
lambda: keypoint_ops.rot90(keypoints, keypoint_rot_permutation),
lambda: keypoints)
result.append(keypoints)
......
......@@ -19,6 +19,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
from absl.testing import parameterized
import numpy as np
import six
......@@ -30,11 +31,12 @@ from object_detection.core import preprocessor
from object_detection.core import preprocessor_cache
from object_detection.core import standard_fields as fields
from object_detection.utils import test_case
from object_detection.utils import tf_version
if six.PY2:
import mock # pylint: disable=g-import-not-at-top
else:
from unittest import mock # pylint: disable=g-import-not-at-top
mock = unittest.mock # pylint: disable=g-import-not-at-top
class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
......@@ -118,7 +120,10 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
return tf.constant(keypoints, dtype=tf.float32)
def createKeypointFlipPermutation(self):
return np.array([0, 2, 1], dtype=np.int32)
return [0, 2, 1]
def createKeypointRotPermutation(self):
return [0, 2, 1]
def createTestLabels(self):
labels = tf.constant([1, 2], dtype=tf.int32)
......@@ -910,19 +915,22 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
test_keypoints=True)
def testRunRandomRotation90WithMaskAndKeypoints(self):
preprocess_options = [(preprocessor.random_rotation90, {})]
image_height = 3
image_width = 3
images = tf.random_uniform([1, image_height, image_width, 3])
boxes = self.createTestBoxes()
masks = self.createTestMasks()
keypoints, _ = self.createTestKeypoints()
keypoint_rot_permutation = self.createKeypointRotPermutation()
tensor_dict = {
fields.InputDataFields.image: images,
fields.InputDataFields.groundtruth_boxes: boxes,
fields.InputDataFields.groundtruth_instance_masks: masks,
fields.InputDataFields.groundtruth_keypoints: keypoints
}
preprocess_options = [(preprocessor.random_rotation90, {
'keypoint_rot_permutation': keypoint_rot_permutation
})]
preprocessor_arg_map = preprocessor.get_default_func_arg_map(
include_instance_masks=True, include_keypoints=True)
tensor_dict = preprocessor.preprocess(
......@@ -2819,6 +2827,7 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
self.assertAllEqual(images_shape, patched_images_shape)
self.assertAllEqual(images, patched_images)
@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
def testAutoAugmentImage(self):
def graph_fn():
preprocessing_options = []
......
......@@ -66,6 +66,11 @@ class InputDataFields(object):
groundtruth_keypoint_weights: groundtruth weight factor for keypoints.
groundtruth_label_weights: groundtruth label weights.
groundtruth_weights: groundtruth weight factor for bounding boxes.
groundtruth_dp_num_points: The number of DensePose sampled points for each
instance.
groundtruth_dp_part_ids: Part indices for DensePose points.
groundtruth_dp_surface_coords: Image locations and UV coordinates for
DensePose points.
num_groundtruth_boxes: number of groundtruth boxes.
is_annotated: whether an image has been labeled or not.
true_image_shapes: true shapes of images in the resized images, as resized
......@@ -108,6 +113,9 @@ class InputDataFields(object):
groundtruth_keypoint_weights = 'groundtruth_keypoint_weights'
groundtruth_label_weights = 'groundtruth_label_weights'
groundtruth_weights = 'groundtruth_weights'
groundtruth_dp_num_points = 'groundtruth_dp_num_points'
groundtruth_dp_part_ids = 'groundtruth_dp_part_ids'
groundtruth_dp_surface_coords = 'groundtruth_dp_surface_coords'
num_groundtruth_boxes = 'num_groundtruth_boxes'
is_annotated = 'is_annotated'
true_image_shape = 'true_image_shape'
......
......@@ -50,10 +50,12 @@ from object_detection.core import matcher as mat
from object_detection.core import region_similarity_calculator as sim_calc
from object_detection.core import standard_fields as fields
from object_detection.matchers import argmax_matcher
from object_detection.matchers import bipartite_matcher
from object_detection.utils import shape_utils
from object_detection.utils import target_assigner_utils as ta_utils
from object_detection.utils import tf_version
if tf_version.is_tf1():
from object_detection.matchers import bipartite_matcher # pylint: disable=g-import-not-at-top
ResizeMethod = tf2.image.ResizeMethod
......@@ -398,6 +400,8 @@ def create_target_assigner(reference, stage=None,
ValueError: if combination reference+stage is invalid.
"""
if reference == 'Multibox' and stage == 'proposal':
if tf_version.is_tf2():
raise ValueError('GreedyBipartiteMatcher is not supported in TF 2.X.')
similarity_calc = sim_calc.NegSqDistSimilarity()
matcher = bipartite_matcher.GreedyBipartiteMatcher()
box_coder_instance = mean_stddev_box_coder.MeanStddevBoxCoder()
......@@ -713,3 +717,943 @@ def batch_assign_confidences(target_assigner,
batch_reg_weights, batch_match)
def _smallest_positive_root(a, b, c):
"""Returns the smallest positive root of a quadratic equation."""
discriminant = tf.sqrt(b ** 2 - 4 * a * c)
# TODO(vighneshb) We are currently using the slightly incorrect
# CenterNet implementation. The commented lines implement the fixed version
# in https://github.com/princeton-vl/CornerNet. Change the implementation
# after verifying it has no negative impact.
# root1 = (-b - discriminant) / (2 * a)
# root2 = (-b + discriminant) / (2 * a)
# return tf.where(tf.less(root1, 0), root2, root1)
return (-b + discriminant) / (2.0)
def max_distance_for_overlap(height, width, min_iou):
"""Computes how far apart bbox corners can lie while maintaining the iou.
Given a bounding box size, this function returns a lower bound on how far
apart the corners of another box can lie while still maintaining the given
IoU. The implementation is based on the `gaussian_radius` function in the
Objects as Points github repo: https://github.com/xingyizhou/CenterNet
Args:
height: A 1-D float Tensor representing height of the ground truth boxes.
width: A 1-D float Tensor representing width of the ground truth boxes.
min_iou: A float representing the minimum IoU desired.
Returns:
distance: A 1-D Tensor of distances, of the same length as the input
height and width tensors.
"""
# Given that the detected box is displaced at a distance `d`, the exact
# IoU value will depend on the angle at which each corner is displaced.
# We simplify our computation by assuming that each corner is displaced by
# a distance `d` in both the x and y directions. This gives us a lower IoU than
# what is actually realizable and ensures that any box with corners less
# than `d` distance apart will always have an IoU greater than or equal
# to `min_iou`.
# The following 3 cases can be worked on geometrically and come down to
# solving a quadratic inequality. In each case, to ensure `min_iou` we use
# the smallest positive root of the equation.
# Case where detected box is offset from ground truth and no box completely
# contains the other.
distance_detection_offset = _smallest_positive_root(
a=1, b=-(height + width),
c=width * height * ((1 - min_iou) / (1 + min_iou))
)
# Case where detection is smaller than ground truth and completely contained
# in it.
distance_detection_in_gt = _smallest_positive_root(
a=4, b=-2 * (height + width),
c=(1 - min_iou) * width * height
)
# Case where ground truth is smaller than detection and completely contained
# in it.
distance_gt_in_detection = _smallest_positive_root(
a=4 * min_iou, b=(2 * min_iou) * (width + height),
c=(min_iou - 1) * width * height
)
return tf.reduce_min([distance_detection_offset,
distance_gt_in_detection,
distance_detection_in_gt], axis=0)
def get_batch_predictions_from_indices(batch_predictions, indices):
"""Gets the values of predictions in a batch at the given indices.
The indices are expected to come from the offset targets generation functions
in this library. The returned value is intended to be used inside a loss
function.
Args:
batch_predictions: A tensor of shape [batch_size, height, width, 2] for
single class offsets and [batch_size, height, width, class, 2] for
multiple classes offsets (e.g. keypoint joint offsets) representing the
(height, width) or (y_offset, x_offset) predictions over a batch.
indices: A tensor of shape [num_instances, 3] for single class offset and
[num_instances, 4] for multiple classes offsets representing the indices
in the batch to be penalized in a loss function.
Returns:
values: A tensor of shape [num_instances, 2] holding the predicted values
at the given indices.
"""
return tf.gather_nd(batch_predictions, indices)
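A small sketch of the lookup (illustrative only, not part of this diff):

# Illustrative sketch: gather the two offset vectors predicted at
# (batch 0, y 1, x 2) and (batch 1, y 3, x 0).
predictions = tf.zeros([2, 4, 4, 2])
indices = tf.constant([[0, 1, 2], [1, 3, 0]])
values = get_batch_predictions_from_indices(predictions, indices)
# values has shape [2, 2].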
def _compute_std_dev_from_box_size(boxes_height, boxes_width, min_overlap):
"""Computes the standard deviation of the Gaussian kernel from box size.
Args:
boxes_height: A 1D tensor with shape [num_instances] representing the height
of each box.
boxes_width: A 1D tensor with shape [num_instances] representing the width
of each box.
min_overlap: The minimum IOU overlap that boxes need to have to not be
penalized.
Returns:
A 1D tensor with shape [num_instances] representing the computed Gaussian
sigma for each box.
"""
# We are dividing by 3 so that points closer than the computed
# distance have a >99% CDF.
sigma = max_distance_for_overlap(boxes_height, boxes_width, min_overlap)
sigma = (2 * tf.math.maximum(tf.math.floor(sigma), 0.0) + 1) / 6.0
return sigma
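For example (illustrative values, not part of this diff), the Gaussian sigmas for two boxes already expressed in output-space pixels:

# Illustrative sketch: sigmas for two boxes at the default 0.7 minimum overlap.
sigmas = _compute_std_dev_from_box_size(
    boxes_height=tf.constant([10.0, 6.0]),
    boxes_width=tf.constant([20.0, 8.0]),
    min_overlap=0.7)
# sigmas is a 1-D tensor with one value per box.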
class CenterNetCenterHeatmapTargetAssigner(object):
"""Wrapper to compute the object center heatmap."""
def __init__(self, stride, min_overlap=0.7):
"""Initializes the target assigner.
Args:
stride: int, the stride of the network in output pixels.
min_overlap: The minimum IOU overlap that boxes need to have to not be
penalized.
"""
self._stride = stride
self._min_overlap = min_overlap
def assign_center_targets_from_boxes(self,
height,
width,
gt_boxes_list,
gt_classes_list,
gt_weights_list=None):
"""Computes the object center heatmap target.
Args:
height: int, height of input to the model. This is used to
determine the height of the output.
width: int, width of the input to the model. This is used to
determine the width of the output.
gt_boxes_list: A list of float tensors with shape [num_boxes, 4]
representing the groundtruth detection bounding boxes for each sample in
the batch. The box coordinates are expected in normalized coordinates.
gt_classes_list: A list of float tensors with shape [num_boxes,
num_classes] representing the one-hot encoded class labels for each box
in the gt_boxes_list.
gt_weights_list: A list of float tensors with shape [num_boxes]
representing the weight of each groundtruth detection box.
Returns:
heatmap: A Tensor of size [batch_size, output_height, output_width,
num_classes] representing the per class center heatmap. output_height
and output_width are computed by dividing the input height and width by
the stride specified during initialization.
"""
out_height = tf.cast(height // self._stride, tf.float32)
out_width = tf.cast(width // self._stride, tf.float32)
# Compute the yx-grid to be used to generate the heatmap. Each returned
# tensor has shape of [out_height, out_width]
(y_grid, x_grid) = ta_utils.image_shape_to_grids(out_height, out_width)
heatmaps = []
if gt_weights_list is None:
gt_weights_list = [None] * len(gt_boxes_list)
# TODO(vighneshb) Replace the for loop with a batch version.
for boxes, class_targets, weights in zip(gt_boxes_list, gt_classes_list,
gt_weights_list):
boxes = box_list.BoxList(boxes)
# Convert the box coordinates to absolute output image dimension space.
boxes = box_list_ops.to_absolute_coordinates(boxes,
height // self._stride,
width // self._stride)
# Get the box center coordinates. Each returned tensor has shape
# [num_instances].
(y_center, x_center, boxes_height,
boxes_width) = boxes.get_center_coordinates_and_sizes()
# Compute the sigma from box size. The tensor shape: [num_instances].
sigma = _compute_std_dev_from_box_size(boxes_height, boxes_width,
self._min_overlap)
# Apply the Gaussian kernel to the center coordinates. Returned heatmap
# has shape of [out_height, out_width, num_classes]
heatmap = ta_utils.coordinates_to_heatmap(
y_grid=y_grid,
x_grid=x_grid,
y_coordinates=y_center,
x_coordinates=x_center,
sigma=sigma,
channel_onehot=class_targets,
channel_weights=weights)
heatmaps.append(heatmap)
# Return the stacked heatmaps over the batch.
return tf.stack(heatmaps, axis=0)
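A minimal usage sketch with illustrative shapes and values (not part of this diff):

# Illustrative sketch: one image, one box of class 0 out of two classes,
# stride 4.
assigner = CenterNetCenterHeatmapTargetAssigner(stride=4, min_overlap=0.7)
heatmap = assigner.assign_center_targets_from_boxes(
    height=128,
    width=128,
    gt_boxes_list=[tf.constant([[0.1, 0.1, 0.5, 0.5]])],
    gt_classes_list=[tf.constant([[1.0, 0.0]])])
# heatmap has shape [1, 32, 32, 2].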
class CenterNetBoxTargetAssigner(object):
"""Wrapper to compute target tensors for the object detection task.
This class has methods that take as input a batch of ground truth tensors
(in the form of a list) and return the targets required to train the object
detection task.
"""
def __init__(self, stride):
"""Initializes the target assigner.
Args:
stride: int, the stride of the network in output pixels.
"""
self._stride = stride
def assign_size_and_offset_targets(self,
height,
width,
gt_boxes_list,
gt_weights_list=None):
"""Returns the box height/width and center offset targets and their indices.
The returned values are expected to be used with predicted tensors
of size (batch_size, height//self._stride, width//self._stride, 2). The
predicted values at the relevant indices can be retrieved with the
get_batch_predictions_from_indices function.
Args:
height: int, height of input to the model. This is used to determine the
height of the output.
width: int, width of the input to the model. This is used to determine the
width of the output.
gt_boxes_list: A list of float tensors with shape [num_boxes, 4]
representing the groundtruth detection bounding boxes for each sample in
the batch. The coordinates are expected in normalized coordinates.
gt_weights_list: A list of tensors with shape [num_boxes] corresponding to
the weight of each groundtruth detection box.
Returns:
batch_indices: an integer tensor of shape [num_boxes, 3] holding the
indices inside the predicted tensor which should be penalized. The
first column indicates the index along the batch dimension and the
second and third columns indicate the index along the y and x
dimensions respectively.
batch_box_height_width: a float tensor of shape [num_boxes, 2] holding
expected height and width of each box in the output space.
batch_offsets: a float tensor of shape [num_boxes, 2] holding the
expected y and x offset of each box in the output space.
batch_weights: a float tensor of shape [num_boxes] indicating the
weight of each prediction.
"""
if gt_weights_list is None:
gt_weights_list = [None] * len(gt_boxes_list)
batch_indices = []
batch_box_height_width = []
batch_weights = []
batch_offsets = []
for i, (boxes, weights) in enumerate(zip(gt_boxes_list, gt_weights_list)):
boxes = box_list.BoxList(boxes)
boxes = box_list_ops.to_absolute_coordinates(boxes,
height // self._stride,
width // self._stride)
# Get the box center coordinates. Each returned tensor has shape
# [num_boxes].
(y_center, x_center, boxes_height,
boxes_width) = boxes.get_center_coordinates_and_sizes()
num_boxes = tf.shape(x_center)
# Compute the offsets and indices of the box centers. Shape:
# offsets: [num_boxes, 2]
# indices: [num_boxes, 2]
(offsets, indices) = ta_utils.compute_floor_offsets_with_indices(
y_source=y_center, x_source=x_center)
# Assign ones if weights are not provided.
if weights is None:
weights = tf.ones(num_boxes, dtype=tf.float32)
# Shape of [num_boxes, 1] integer tensor filled with current batch index.
batch_index = i * tf.ones_like(indices[:, 0:1], dtype=tf.int32)
batch_indices.append(tf.concat([batch_index, indices], axis=1))
batch_box_height_width.append(
tf.stack([boxes_height, boxes_width], axis=1))
batch_weights.append(weights)
batch_offsets.append(offsets)
batch_indices = tf.concat(batch_indices, axis=0)
batch_box_height_width = tf.concat(batch_box_height_width, axis=0)
batch_weights = tf.concat(batch_weights, axis=0)
batch_offsets = tf.concat(batch_offsets, axis=0)
return (batch_indices, batch_box_height_width, batch_offsets, batch_weights)
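A minimal usage sketch showing how the returned indices pair with a predicted offset map (illustrative values, not part of this diff):

# Illustrative sketch: targets for one box, then a lookup into a prediction.
box_assigner = CenterNetBoxTargetAssigner(stride=4)
(batch_indices, batch_hw, batch_offsets, batch_weights) = (
    box_assigner.assign_size_and_offset_targets(
        height=128,
        width=128,
        gt_boxes_list=[tf.constant([[0.1, 0.1, 0.5, 0.5]])]))
# Pair the indices with a predicted offset map of shape [1, 32, 32, 2].
offset_map = tf.zeros([1, 32, 32, 2])
predicted_offsets = get_batch_predictions_from_indices(offset_map, batch_indices)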
# TODO(yuhuic): Update this class to handle the instance/keypoint weights.
# Currently those weights are used as a "mask" to indicate whether an
# instance/keypoint should be considered or not (only values of 0 or 1 are
# expected). In reality, the weights can be any value and this class should handle
# those values properly.
class CenterNetKeypointTargetAssigner(object):
"""Wrapper to compute target tensors for the CenterNet keypoint estimation.
This class has methods that take as input a batch of groundtruth tensors
(in the form of a list) and return the targets required to train the
CenterNet model for keypoint estimation. Specifically, the class methods
expect the groundtruth in the following formats (consistent with the
standard Object Detection API). Note that usually the groundtruth tensors are
packed with a list which represents the batch dimension:
gt_classes_list: [Required] a list of 2D tf.float32 one-hot
(or k-hot) tensors of shape [num_instances, num_classes] containing the
class targets with the 0th index assumed to map to the first non-background
class.
gt_keypoints_list: [Required] a list of 3D tf.float32 tensors of
shape [num_instances, num_total_keypoints, 2] containing keypoint
coordinates. Note that the "num_total_keypoints" should be the sum of the
num_keypoints over all possible keypoint types, e.g. human pose, face.
For example, if a dataset contains both 17 human pose keypoints and 5 face
keypoints, then num_total_keypoints = 17 + 5 = 22.
If an instance contains only a subset of keypoints (e.g. human pose keypoints
but not face keypoints), the face keypoints will be filled with zeros.
Also note that keypoints are assumed to be provided in normalized
coordinates and missing keypoints should be encoded as NaN.
gt_keypoints_weights_list: [Optional] a list of 2D tf.float32 tensors of
shape [num_instances, num_total_keypoints] representing the weight of each
keypoint. If not provided, all non-NaN keypoints will be equally
weighted.
gt_boxes_list: [Optional] a list of 2D tf.float32 tensors of shape
[num_instances, 4] containing coordinates of the groundtruth boxes.
Groundtruth boxes are provided in [y_min, x_min, y_max, x_max] format and
assumed to be normalized and clipped relative to the image window with
y_min <= y_max and x_min <= x_max.
Note that the boxes are only used to compute the center targets but are not
considered as required output of the keypoint task. If the boxes were not
provided, the center targets will be inferred from the keypoints
[not implemented yet].
gt_weights_list: [Optional] A list of 1D tf.float32 tensors of shape
[num_instances] containing weights for groundtruth boxes. Only useful when
gt_boxes_list is also provided.
"""
def __init__(self,
stride,
class_id,
keypoint_indices,
keypoint_std_dev=None,
per_keypoint_offset=False,
peak_radius=0):
"""Initializes a CenterNet keypoints target assigner.
Args:
stride: int, the stride of the network in output pixels.
class_id: int, the ID of the class (0-indexed) that contains the target
keypoints to consider in this task. For example, if the task is human
pose estimation, the class id should correspond to the "human" class.
keypoint_indices: A list of integers representing the indices of the
keypoints to be considered in this task. This is used to retrieve the
subset of the keypoints from gt_keypoints that should be considered in
this task.
keypoint_std_dev: A list of floats representing the standard deviations of
the Gaussian kernels used to generate the keypoint heatmap (in units of
output pixels). This provides the flexibility to use a different Gaussian
kernel size for each keypoint type. If not provided, all standard
deviations default to the same value (10.0 in the output pixel space). If
provided, the length of keypoint_std_dev must match the length of
keypoint_indices, giving the standard deviation of each keypoint type.
per_keypoint_offset: boolean, indicating whether to assign offset for
each keypoint channel. If set False, the output offset target will have
the shape [batch_size, out_height, out_width, 2]. If set True, the
output offset target will have the shape [batch_size, out_height,
out_width, 2 * num_keypoints].
peak_radius: int, the radius (in the unit of output pixel) around heatmap
peak to assign the offset targets.
"""
self._stride = stride
self._class_id = class_id
self._keypoint_indices = keypoint_indices
self._per_keypoint_offset = per_keypoint_offset
self._peak_radius = peak_radius
if keypoint_std_dev is None:
self._keypoint_std_dev = ([_DEFAULT_KEYPOINT_OFFSET_STD_DEV] *
len(keypoint_indices))
else:
assert len(keypoint_indices) == len(keypoint_std_dev)
self._keypoint_std_dev = keypoint_std_dev
def _preprocess_keypoints_and_weights(self, out_height, out_width, keypoints,
class_onehot, class_weights,
keypoint_weights):
"""Preprocesses the keypoints and the corresponding keypoint weights.
This function performs several common preprocessing steps on the keypoints
and keypoint weights, including:
1) Select the subset of keypoints based on the keypoint indices, fill the
NaN keypoint values with zeros and convert to absolute coordinates.
2) Generate the keypoint weights using the following information:
a. The class of the instance.
b. The NaN values of the keypoint coordinates.
c. The provided keypoint weights.
Args:
out_height: An integer or an integer tensor indicating the output height
of the model.
out_width: An integer or an integer tensor indicating the output width of
the model.
keypoints: A float tensor of shape [num_instances, num_total_keypoints, 2]
representing the original groundtruth keypoint coordinates.
class_onehot: A float tensor of shape [num_instances, num_classes]
containing the class targets with the 0th index assumed to map to the
first non-background class.
class_weights: A float tensor of shape [num_instances] containing weights
for groundtruth instances.
keypoint_weights: A float tensor of shape
[num_instances, num_total_keypoints] representing the weights of each
keypoint.
Returns:
A tuple of two tensors:
keypoints_absolute: A float tensor of shape
[num_instances, num_keypoints, 2] containing the selected and updated
keypoint coordinates.
keypoint_weights: A float tensor of shape [num_instances, num_keypoints]
representing the updated weight of each keypoint.
"""
# Select the target keypoints by their type ids and generate the mask
# of valid elements.
valid_mask, keypoints = ta_utils.get_valid_keypoint_mask_for_class(
keypoint_coordinates=keypoints,
class_id=self._class_id,
class_onehot=class_onehot,
class_weights=class_weights,
keypoint_indices=self._keypoint_indices)
# Keypoint coordinates in absolute coordinate system.
# The shape of the tensors: [num_instances, num_keypoints, 2].
keypoints_absolute = keypoint_ops.to_absolute_coordinates(
keypoints, out_height, out_width)
# Assign default weights for the keypoints.
if keypoint_weights is None:
keypoint_weights = tf.ones_like(keypoints[:, :, 0])
else:
keypoint_weights = tf.gather(
keypoint_weights, indices=self._keypoint_indices, axis=1)
keypoint_weights = keypoint_weights * valid_mask
return keypoints_absolute, keypoint_weights
def assign_keypoint_heatmap_targets(self,
height,
width,
gt_keypoints_list,
gt_classes_list,
gt_keypoints_weights_list=None,
gt_weights_list=None,
gt_boxes_list=None):
"""Returns the keypoint heatmap targets for the CenterNet model.
Args:
height: int, height of input to the CenterNet model. This is used to
determine the height of the output.
width: int, width of the input to the CenterNet model. This is used to
determine the width of the output.
gt_keypoints_list: A list of float tensors with shape [num_instances,
num_total_keypoints, 2]. See class-level description for more detail.
gt_classes_list: A list of float tensors with shape [num_instances,
num_classes]. See class-level description for more detail.
gt_keypoints_weights_list: A list of tensors with shape [num_instances,
num_total_keypoints] corresponding to the weight of each keypoint.
gt_weights_list: A list of float tensors with shape [num_instances]. See
class-level description for more detail.
gt_boxes_list: A list of float tensors with shape [num_instances, 4]. See
class-level description for more detail. If provided, the keypoint
standard deviations will be scaled based on the box sizes.
Returns:
heatmap: A float tensor of shape [batch_size, output_height, output_width,
num_keypoints] representing the per keypoint type center heatmap.
output_height and output_width are computed by dividing the input height
and width by the stride specified during initialization. Note that the
"num_keypoints" is defined by the length of keypoint_indices, which is
not necessarily equal to "num_total_keypoints".
num_instances_batch: A 2D int tensor of shape
[batch_size, num_keypoints] representing number of instances for each
keypoint type.
valid_mask: A float tensor with shape [batch_size, output_height,
output_width] where all values within the regions of the blackout boxes
are 0.0 and 1.0 elsewhere.
"""
out_width = tf.cast(width // self._stride, tf.float32)
out_height = tf.cast(height // self._stride, tf.float32)
# Compute the yx-grid to be used to generate the heatmap. Each returned
# tensor has shape of [out_height, out_width]
y_grid, x_grid = ta_utils.image_shape_to_grids(out_height, out_width)
if gt_keypoints_weights_list is None:
gt_keypoints_weights_list = [None] * len(gt_keypoints_list)
if gt_weights_list is None:
gt_weights_list = [None] * len(gt_classes_list)
if gt_boxes_list is None:
gt_boxes_list = [None] * len(gt_keypoints_list)
heatmaps = []
num_instances_list = []
valid_mask_list = []
for keypoints, classes, kp_weights, weights, boxes in zip(
gt_keypoints_list, gt_classes_list, gt_keypoints_weights_list,
gt_weights_list, gt_boxes_list):
keypoints_absolute, kp_weights = self._preprocess_keypoints_and_weights(
out_height=out_height,
out_width=out_width,
keypoints=keypoints,
class_onehot=classes,
class_weights=weights,
keypoint_weights=kp_weights)
num_instances, num_keypoints, _ = (
shape_utils.combined_static_and_dynamic_shape(keypoints_absolute))
# A tensor of shape [num_instances, num_keypoints] with
# each element representing the type dimension for each corresponding
# keypoint:
# [[0, 1, ..., k-1],
# [0, 1, ..., k-1],
# :
# [0, 1, ..., k-1]]
keypoint_types = tf.tile(
input=tf.expand_dims(tf.range(num_keypoints), axis=0),
multiples=[num_instances, 1])
# A tensor of shape [num_instances, num_keypoints] with
# each element representing the sigma of the Gaussian kernel for each
# keypoint.
keypoint_std_dev = tf.tile(
input=tf.expand_dims(tf.constant(self._keypoint_std_dev), axis=0),
multiples=[num_instances, 1])
# If boxes is not None, then scale the standard deviation based on the
# size of the object bounding boxes similar to object center heatmap.
if boxes is not None:
boxes = box_list.BoxList(boxes)
# Convert the box coordinates to absolute output image dimension space.
boxes = box_list_ops.to_absolute_coordinates(boxes,
height // self._stride,
width // self._stride)
# Get the box height and width. Each returned tensor has shape
# [num_instances].
(_, _, boxes_height,
boxes_width) = boxes.get_center_coordinates_and_sizes()
# Compute the sigma from box size. The tensor shape: [num_instances].
sigma = _compute_std_dev_from_box_size(boxes_height, boxes_width, 0.7)
keypoint_std_dev = keypoint_std_dev * tf.stack(
[sigma] * num_keypoints, axis=1)
# Generate the valid region mask to ignore regions with target class but
# no corresponding keypoints.
# Shape: [num_instances].
blackout = tf.logical_and(classes[:, self._class_id] > 0,
tf.reduce_max(kp_weights, axis=1) < 1e-3)
valid_mask = ta_utils.blackout_pixel_weights_by_box_regions(
out_height, out_width, boxes.get(), blackout)
valid_mask_list.append(valid_mask)
# Apply the Gaussian kernel to the keypoint coordinates. Returned heatmap
# has shape of [out_height, out_width, num_keypoints].
heatmap = ta_utils.coordinates_to_heatmap(
y_grid=y_grid,
x_grid=x_grid,
y_coordinates=tf.keras.backend.flatten(keypoints_absolute[:, :, 0]),
x_coordinates=tf.keras.backend.flatten(keypoints_absolute[:, :, 1]),
sigma=tf.keras.backend.flatten(keypoint_std_dev),
channel_onehot=tf.one_hot(
tf.keras.backend.flatten(keypoint_types), depth=num_keypoints),
channel_weights=tf.keras.backend.flatten(kp_weights))
num_instances_list.append(
tf.cast(tf.reduce_sum(kp_weights, axis=0), dtype=tf.int32))
heatmaps.append(heatmap)
return (tf.stack(heatmaps, axis=0), tf.stack(num_instances_list, axis=0),
tf.stack(valid_mask_list, axis=0))
def _get_keypoint_types(self, num_instances, num_keypoints, num_neighbors):
"""Gets keypoint type index tensor.
The function prepares the tensor of keypoint indices with shape
[num_instances, num_keypoints, num_neighbors]. Each element represents the
keypoint type index for each corresponding keypoint and is tiled along the 3rd
axis:
[[0, 1, ..., num_keypoints - 1],
[0, 1, ..., num_keypoints - 1],
:
[0, 1, ..., num_keypoints - 1]]
Args:
num_instances: int, the number of instances, used to define the 1st
dimension.
num_keypoints: int, the number of keypoint types, used to define the 2nd
dimension.
num_neighbors: int, the number of neighborhood pixels to consider for each
keypoint, used to define the 3rd dimension.
Returns:
An integer tensor of shape [num_instances, num_keypoints, num_neighbors].
"""
keypoint_types = tf.range(num_keypoints)[tf.newaxis, :, tf.newaxis]
tiled_keypoint_types = tf.tile(keypoint_types,
multiples=[num_instances, 1, num_neighbors])
return tiled_keypoint_types
def assign_keypoints_offset_targets(self,
height,
width,
gt_keypoints_list,
gt_classes_list,
gt_keypoints_weights_list=None,
gt_weights_list=None):
"""Returns the offsets and indices of the keypoints for location refinement.
The returned values are used to refine the location of each keypoint in the
heatmap. The predicted values at the relevant indices can be retrieved with
the get_batch_predictions_from_indices function.
Args:
height: int, height of input to the CenterNet model. This is used to
determine the height of the output.
width: int, width of the input to the CenterNet model. This is used to
determine the width of the output.
gt_keypoints_list: A list of tensors with shape [num_instances,
num_total_keypoints, 2]. See class-level description for more detail.
gt_classes_list: A list of tensors with shape [num_instances,
num_classes]. See class-level description for more detail.
gt_keypoints_weights_list: A list of tensors with shape [num_instances,
num_total_keypoints] corresponding to the weight of each keypoint.
gt_weights_list: A list of float tensors with shape [num_instances]. See
class-level description for more detail.
Returns:
batch_indices: an integer tensor of shape [num_total_instances, 3] (or
[num_total_instances, 4] if 'per_keypoint_offset' is set True) holding
the indices inside the predicted tensor which should be penalized. The
first column indicates the index along the batch dimension and the
second and third columns indicate the index along the y and x
dimensions respectively. The fourth column corresponds to the channel
dimension (if 'per_keypoint_offset' is set True).
batch_offsets: a float tensor of shape [num_total_instances, 2] holding
the expected y and x offset of each box in the output space.
batch_weights: a float tensor of shape [num_total_instances] indicating
the weight of each prediction.
Note that num_total_instances = batch_size * num_instances *
num_keypoints * num_neighbors
"""
batch_indices = []
batch_offsets = []
batch_weights = []
if gt_keypoints_weights_list is None:
gt_keypoints_weights_list = [None] * len(gt_keypoints_list)
if gt_weights_list is None:
gt_weights_list = [None] * len(gt_classes_list)
for i, (keypoints, classes, kp_weights, weights) in enumerate(
zip(gt_keypoints_list, gt_classes_list, gt_keypoints_weights_list,
gt_weights_list)):
keypoints_absolute, kp_weights = self._preprocess_keypoints_and_weights(
out_height=height // self._stride,
out_width=width // self._stride,
keypoints=keypoints,
class_onehot=classes,
class_weights=weights,
keypoint_weights=kp_weights)
num_instances, num_keypoints, _ = (
shape_utils.combined_static_and_dynamic_shape(keypoints_absolute))
# [num_instances * num_keypoints]
y_source = tf.keras.backend.flatten(keypoints_absolute[:, :, 0])
x_source = tf.keras.backend.flatten(keypoints_absolute[:, :, 1])
# All keypoint coordinates and their neighbors:
# [num_instance * num_keypoints, num_neighbors]
(y_source_neighbors, x_source_neighbors,
valid_sources) = ta_utils.get_surrounding_grids(height // self._stride,
width // self._stride,
y_source, x_source,
self._peak_radius)
_, num_neighbors = shape_utils.combined_static_and_dynamic_shape(
y_source_neighbors)
# Update the valid keypoint weights.
# [num_instance * num_keypoints, num_neighbors]
valid_keypoints = tf.cast(
valid_sources, dtype=tf.float32) * tf.stack(
[tf.keras.backend.flatten(kp_weights)] * num_neighbors, axis=-1)
# Compute the offsets and indices of the box centers. Shape:
# offsets: [num_instances * num_keypoints, num_neighbors, 2]
# indices: [num_instances * num_keypoints, num_neighbors, 2]
offsets, indices = ta_utils.compute_floor_offsets_with_indices(
y_source=y_source_neighbors,
x_source=x_source_neighbors,
y_target=y_source,
x_target=x_source)
# Reshape to:
# offsets: [num_instances * num_keypoints * num_neighbors, 2]
# indices: [num_instances * num_keypoints * num_neighbors, 2]
offsets = tf.reshape(offsets, [-1, 2])
indices = tf.reshape(indices, [-1, 2])
# Prepare the batch indices to be prepended.
batch_index = tf.fill(
[num_instances * num_keypoints * num_neighbors, 1], i)
if self._per_keypoint_offset:
tiled_keypoint_types = self._get_keypoint_types(
num_instances, num_keypoints, num_neighbors)
batch_indices.append(
tf.concat([batch_index, indices,
tf.reshape(tiled_keypoint_types, [-1, 1])], axis=1))
else:
batch_indices.append(tf.concat([batch_index, indices], axis=1))
batch_offsets.append(offsets)
batch_weights.append(tf.keras.backend.flatten(valid_keypoints))
# Concatenate the tensors in the batch in the first dimension:
# shape: [batch_size * num_instances * num_keypoints * num_neighbors, 3] or
# [batch_size * num_instances * num_keypoints * num_neighbors, 4] if
# 'per_keypoint_offset' is set to True.
batch_indices = tf.concat(batch_indices, axis=0)
# shape: [batch_size * num_instances * num_keypoints * num_neighbors]
batch_weights = tf.concat(batch_weights, axis=0)
# shape: [batch_size * num_instances * num_keypoints * num_neighbors, 2]
batch_offsets = tf.concat(batch_offsets, axis=0)
return (batch_indices, batch_offsets, batch_weights)
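A minimal usage sketch with illustrative inputs (not part of this diff; the groundtruth tensors follow the formats described in the class-level docstring):

# Illustrative sketch: 17 person keypoints, person class at index 0,
# offsets refined within a 1-pixel peak radius.
gt_keypoints_list = [tf.zeros([2, 17, 2], tf.float32)]  # 2 instances
gt_classes_list = [tf.one_hot([0, 0], depth=1)]         # both are "person"
kp_assigner = CenterNetKeypointTargetAssigner(
    stride=4, class_id=0, keypoint_indices=list(range(17)), peak_radius=1)
(batch_indices, batch_offsets, batch_weights) = (
    kp_assigner.assign_keypoints_offset_targets(
        height=128,
        width=128,
        gt_keypoints_list=gt_keypoints_list,
        gt_classes_list=gt_classes_list))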
def assign_joint_regression_targets(self,
height,
width,
gt_keypoints_list,
gt_classes_list,
gt_boxes_list=None,
gt_keypoints_weights_list=None,
gt_weights_list=None):
"""Returns the joint regression from center grid to keypoints.
The joint regression is used as the grouping cue from the estimated
keypoints to instance center. The offsets are the vectors from the floored
object center coordinates to the keypoint coordinates.
Args:
height: int, height of input to the CenterNet model. This is used to
determine the height of the output.
width: int, width of the input to the CenterNet model. This is used to
determine the width of the output.
gt_keypoints_list: A list of float tensors with shape [num_instances,
num_total_keypoints, 2]. See class-level description for more detail.
gt_classes_list: A list of float tensors with shape [num_instances,
num_classes]. See class-level description for more detail.
gt_boxes_list: A list of float tensors with shape [num_instances, 4]. See
class-level description for more detail. If provided, then the center
targets will be computed based on the center of the boxes.
gt_keypoints_weights_list: A list of float tensors with shape
[num_instances, num_total_keypoints] representing the weight of each
keypoint.
gt_weights_list: A list of float tensors with shape [num_instances]. See
class-level description for more detail.
Returns:
batch_indices: an integer tensor of shape [num_instances, 4] holding the
indices inside the predicted tensor which should be penalized. The
first column indicates the index along the batch dimension and the
second and third columns indicate the index along the y and x
dimensions respectively, the last dimension refers to the keypoint type
dimension.
batch_offsets: a float tensor of shape [num_instances, 2] holding the
expected y and x offset of each box in the output space.
batch_weights: a float tensor of shape [num_instances] indicating the
weight of each prediction.
Note that num_total_instances = batch_size * num_instances * num_keypoints
Raises:
NotImplementedError: currently the object center coordinates need to be
computed from groundtruth bounding boxes. The functionality of
generating the object center coordinates from keypoints is not
implemented yet.
"""
batch_indices = []
batch_offsets = []
batch_weights = []
batch_size = len(gt_keypoints_list)
if gt_keypoints_weights_list is None:
gt_keypoints_weights_list = [None] * batch_size
if gt_boxes_list is None:
gt_boxes_list = [None] * batch_size
if gt_weights_list is None:
gt_weights_list = [None] * len(gt_classes_list)
for i, (keypoints, classes, boxes, kp_weights, weights) in enumerate(
zip(gt_keypoints_list, gt_classes_list,
gt_boxes_list, gt_keypoints_weights_list, gt_weights_list)):
keypoints_absolute, kp_weights = self._preprocess_keypoints_and_weights(
out_height=height // self._stride,
out_width=width // self._stride,
keypoints=keypoints,
class_onehot=classes,
class_weights=weights,
keypoint_weights=kp_weights)
num_instances, num_keypoints, _ = (
shape_utils.combined_static_and_dynamic_shape(keypoints_absolute))
      # If boxes are provided, compute the object (joint) centers from them.
if boxes is not None:
# Compute joint center from boxes.
boxes = box_list.BoxList(boxes)
boxes = box_list_ops.to_absolute_coordinates(boxes,
height // self._stride,
width // self._stride)
y_center, x_center, _, _ = boxes.get_center_coordinates_and_sizes()
else:
# TODO(yuhuic): Add the logic to generate object centers from keypoints.
raise NotImplementedError((
'The functionality of generating object centers from keypoints is'
' not implemented yet. Please provide groundtruth bounding boxes.'
))
# Tile the yx center coordinates to be the same shape as keypoints.
y_center_tiled = tf.tile(
tf.reshape(y_center, shape=[num_instances, 1]),
multiples=[1, num_keypoints])
x_center_tiled = tf.tile(
tf.reshape(x_center, shape=[num_instances, 1]),
multiples=[1, num_keypoints])
# [num_instance * num_keypoints, num_neighbors]
(y_source_neighbors, x_source_neighbors,
valid_sources) = ta_utils.get_surrounding_grids(
height // self._stride, width // self._stride,
tf.keras.backend.flatten(y_center_tiled),
tf.keras.backend.flatten(x_center_tiled), self._peak_radius)
_, num_neighbors = shape_utils.combined_static_and_dynamic_shape(
y_source_neighbors)
valid_keypoints = tf.cast(
valid_sources, dtype=tf.float32) * tf.stack(
[tf.keras.backend.flatten(kp_weights)] * num_neighbors, axis=-1)
# Compute the offsets and indices of the box centers. Shape:
# offsets: [num_instances * num_keypoints, 2]
# indices: [num_instances * num_keypoints, 2]
(offsets, indices) = ta_utils.compute_floor_offsets_with_indices(
y_source=y_source_neighbors,
x_source=x_source_neighbors,
y_target=tf.keras.backend.flatten(keypoints_absolute[:, :, 0]),
x_target=tf.keras.backend.flatten(keypoints_absolute[:, :, 1]))
# Reshape to:
# offsets: [num_instances * num_keypoints * num_neighbors, 2]
# indices: [num_instances * num_keypoints * num_neighbors, 2]
offsets = tf.reshape(offsets, [-1, 2])
indices = tf.reshape(indices, [-1, 2])
# keypoint type tensor: [num_instances, num_keypoints, num_neighbors].
tiled_keypoint_types = self._get_keypoint_types(
num_instances, num_keypoints, num_neighbors)
batch_index = tf.fill(
[num_instances * num_keypoints * num_neighbors, 1], i)
batch_indices.append(
tf.concat([batch_index, indices,
tf.reshape(tiled_keypoint_types, [-1, 1])], axis=1))
batch_offsets.append(offsets)
batch_weights.append(tf.keras.backend.flatten(valid_keypoints))
# Concatenate the tensors in the batch in the first dimension:
    # shape: [batch_size * num_instances * num_keypoints * num_neighbors, 4]
    batch_indices = tf.concat(batch_indices, axis=0)
    # shape: [batch_size * num_instances * num_keypoints * num_neighbors]
    batch_weights = tf.concat(batch_weights, axis=0)
    # shape: [batch_size * num_instances * num_keypoints * num_neighbors, 2]
batch_offsets = tf.concat(batch_offsets, axis=0)
return (batch_indices, batch_offsets, batch_weights)
class CenterNetMaskTargetAssigner(object):
"""Wrapper to compute targets for segmentation masks."""
def __init__(self, stride):
self._stride = stride
def assign_segmentation_targets(
self, gt_masks_list, gt_classes_list,
mask_resize_method=ResizeMethod.BILINEAR):
"""Computes the segmentation targets.
This utility produces a semantic segmentation mask for each class, starting
with whole image instance segmentation masks. Effectively, each per-class
segmentation target is the union of all masks from that class.
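    For example, if an image contains two instances of class 0 and one instance
    of class 1, the class-0 channel of the target is the elementwise maximum of
    the two class-0 masks (resized by `stride`), and the class-1 channel is the
    single resized class-1 mask.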
Args:
gt_masks_list: A list of float tensors with shape [num_boxes,
input_height, input_width] with values in {0, 1} representing instance
masks for each object.
gt_classes_list: A list of float tensors with shape [num_boxes,
num_classes] representing the one-hot encoded class labels for each box
in the gt_boxes_list.
mask_resize_method: A `tf.compat.v2.image.ResizeMethod`. The method to use
when resizing masks from input resolution to output resolution.
Returns:
      segmentation_targets: A float32 tensor of size [batch_size, output_height,
        output_width, num_classes] representing the class of each location in
        the output space.
"""
# TODO(ronnyvotel): Handle groundtruth weights.
_, num_classes = shape_utils.combined_static_and_dynamic_shape(
gt_classes_list[0])
_, input_height, input_width = (
shape_utils.combined_static_and_dynamic_shape(gt_masks_list[0]))
output_height = input_height // self._stride
output_width = input_width // self._stride
segmentation_targets_list = []
for gt_masks, gt_classes in zip(gt_masks_list, gt_classes_list):
# Resize segmentation masks to conform to output dimensions. Use TF2
# image resize because TF1's version is buggy:
# https://yaqs.corp.google.com/eng/q/4970450458378240
gt_masks = tf2.image.resize(
gt_masks[:, :, :, tf.newaxis],
size=(output_height, output_width),
method=mask_resize_method)
gt_classes_reshaped = tf.reshape(gt_classes, [-1, 1, 1, num_classes])
# Shape: [h, w, num_classes].
segmentations_for_image = tf.reduce_max(
gt_masks * gt_classes_reshaped, axis=0)
segmentation_targets_list.append(segmentations_for_image)
segmentation_target = tf.stack(segmentation_targets_list, axis=0)
return segmentation_target
......@@ -24,9 +24,9 @@ from object_detection.core import region_similarity_calculator
from object_detection.core import standard_fields as fields
from object_detection.core import target_assigner as targetassigner
from object_detection.matchers import argmax_matcher
from object_detection.matchers import bipartite_matcher
from object_detection.utils import np_box_ops
from object_detection.utils import test_case
from object_detection.utils import tf_version
class TargetAssignerTest(test_case.TestCase):
......@@ -439,7 +439,7 @@ class TargetAssignerTest(test_case.TestCase):
def test_raises_error_on_incompatible_groundtruth_boxes_and_labels(self):
similarity_calc = region_similarity_calculator.NegSqDistSimilarity()
matcher = bipartite_matcher.GreedyBipartiteMatcher()
matcher = argmax_matcher.ArgMaxMatcher(0.5)
box_coder = mean_stddev_box_coder.MeanStddevBoxCoder()
unmatched_class_label = tf.constant([1, 0, 0, 0, 0, 0, 0], tf.float32)
target_assigner = targetassigner.TargetAssigner(
......@@ -469,7 +469,7 @@ class TargetAssignerTest(test_case.TestCase):
def test_raises_error_on_invalid_groundtruth_labels(self):
similarity_calc = region_similarity_calculator.NegSqDistSimilarity()
matcher = bipartite_matcher.GreedyBipartiteMatcher()
matcher = argmax_matcher.ArgMaxMatcher(0.5)
box_coder = mean_stddev_box_coder.MeanStddevBoxCoder(stddev=1.0)
unmatched_class_label = tf.constant([[0, 0], [0, 0], [0, 0]], tf.float32)
target_assigner = targetassigner.TargetAssigner(
......@@ -1191,7 +1191,7 @@ class BatchTargetAssignConfidencesTest(test_case.TestCase):
])
class CreateTargetAssignerTest(tf.test.TestCase):
class CreateTargetAssignerTest(test_case.TestCase):
def test_create_target_assigner(self):
"""Tests that named constructor gives working target assigners.
......@@ -1202,9 +1202,10 @@ class CreateTargetAssignerTest(tf.test.TestCase):
groundtruth = box_list.BoxList(tf.constant(corners))
priors = box_list.BoxList(tf.constant(corners))
multibox_ta = (targetassigner
.create_target_assigner('Multibox', stage='proposal'))
multibox_ta.assign(priors, groundtruth)
if tf_version.is_tf1():
multibox_ta = (targetassigner
.create_target_assigner('Multibox', stage='proposal'))
multibox_ta.assign(priors, groundtruth)
# No tests on output, as that may vary arbitrarily as new target assigners
# are added. As long as it is constructed correctly and runs without errors,
# tests on the individual assigners cover correctness of the assignments.
......@@ -1229,6 +1230,681 @@ class CreateTargetAssignerTest(tf.test.TestCase):
stage='invalid_stage')
def _array_argmax(array):
return np.unravel_index(np.argmax(array), array.shape)
class CenterNetCenterHeatmapTargetAssignerTest(test_case.TestCase):
def setUp(self):
super(CenterNetCenterHeatmapTargetAssignerTest, self).setUp()
self._box_center = [0.0, 0.0, 1.0, 1.0]
self._box_center_small = [0.25, 0.25, 0.75, 0.75]
self._box_lower_left = [0.5, 0.0, 1.0, 0.5]
self._box_center_offset = [0.1, 0.05, 1.0, 1.0]
self._box_odd_coordinates = [0.1625, 0.2125, 0.5625, 0.9625]
def test_center_location(self):
"""Test that the centers are at the correct location."""
def graph_fn():
box_batch = [tf.constant([self._box_center, self._box_lower_left])]
classes = [
tf.one_hot([0, 1], depth=4),
]
assigner = targetassigner.CenterNetCenterHeatmapTargetAssigner(4)
targets = assigner.assign_center_targets_from_boxes(80, 80, box_batch,
classes)
return targets
targets = self.execute(graph_fn, [])
self.assertEqual((10, 10), _array_argmax(targets[0, :, :, 0]))
self.assertAlmostEqual(1.0, targets[0, 10, 10, 0])
self.assertEqual((15, 5), _array_argmax(targets[0, :, :, 1]))
self.assertAlmostEqual(1.0, targets[0, 15, 5, 1])
def test_center_batch_shape(self):
"""Test that the shape of the target for a batch is correct."""
def graph_fn():
box_batch = [
tf.constant([self._box_center, self._box_lower_left]),
tf.constant([self._box_center]),
tf.constant([self._box_center_small]),
]
classes = [
tf.one_hot([0, 1], depth=4),
tf.one_hot([2], depth=4),
tf.one_hot([3], depth=4),
]
assigner = targetassigner.CenterNetCenterHeatmapTargetAssigner(4)
targets = assigner.assign_center_targets_from_boxes(80, 80, box_batch,
classes)
return targets
targets = self.execute(graph_fn, [])
self.assertEqual((3, 20, 20, 4), targets.shape)
def test_center_overlap_maximum(self):
"""Test that when boxes overlap we, are computing the maximum."""
def graph_fn():
box_batch = [
tf.constant([
self._box_center, self._box_center_offset, self._box_center,
self._box_center_offset
])
]
classes = [
tf.one_hot([0, 0, 1, 2], depth=4),
]
assigner = targetassigner.CenterNetCenterHeatmapTargetAssigner(4)
targets = assigner.assign_center_targets_from_boxes(80, 80, box_batch,
classes)
return targets
targets = self.execute(graph_fn, [])
class0_targets = targets[0, :, :, 0]
class1_targets = targets[0, :, :, 1]
class2_targets = targets[0, :, :, 2]
np.testing.assert_allclose(class0_targets,
np.maximum(class1_targets, class2_targets))
def test_size_blur(self):
"""Test that the heatmap of a larger box is more blurred."""
def graph_fn():
box_batch = [tf.constant([self._box_center, self._box_center_small])]
classes = [
tf.one_hot([0, 1], depth=4),
]
assigner = targetassigner.CenterNetCenterHeatmapTargetAssigner(4)
targets = assigner.assign_center_targets_from_boxes(80, 80, box_batch,
classes)
return targets
targets = self.execute(graph_fn, [])
self.assertGreater(
np.count_nonzero(targets[:, :, :, 0]),
np.count_nonzero(targets[:, :, :, 1]))
def test_weights(self):
"""Test that the weights correctly ignore ground truth."""
def graph1_fn():
box_batch = [
tf.constant([self._box_center, self._box_lower_left]),
tf.constant([self._box_center]),
tf.constant([self._box_center_small]),
]
classes = [
tf.one_hot([0, 1], depth=4),
tf.one_hot([2], depth=4),
tf.one_hot([3], depth=4),
]
assigner = targetassigner.CenterNetCenterHeatmapTargetAssigner(4)
targets = assigner.assign_center_targets_from_boxes(80, 80, box_batch,
classes)
return targets
targets = self.execute(graph1_fn, [])
self.assertAlmostEqual(1.0, targets[0, :, :, 0].max())
self.assertAlmostEqual(1.0, targets[0, :, :, 1].max())
self.assertAlmostEqual(1.0, targets[1, :, :, 2].max())
self.assertAlmostEqual(1.0, targets[2, :, :, 3].max())
self.assertAlmostEqual(0.0, targets[0, :, :, [2, 3]].max())
self.assertAlmostEqual(0.0, targets[1, :, :, [0, 1, 3]].max())
self.assertAlmostEqual(0.0, targets[2, :, :, :3].max())
def graph2_fn():
weights = [
tf.constant([0., 1.]),
tf.constant([1.]),
tf.constant([1.]),
]
box_batch = [
tf.constant([self._box_center, self._box_lower_left]),
tf.constant([self._box_center]),
tf.constant([self._box_center_small]),
]
classes = [
tf.one_hot([0, 1], depth=4),
tf.one_hot([2], depth=4),
tf.one_hot([3], depth=4),
]
assigner = targetassigner.CenterNetCenterHeatmapTargetAssigner(4)
targets = assigner.assign_center_targets_from_boxes(80, 80, box_batch,
classes,
weights)
return targets
targets = self.execute(graph2_fn, [])
self.assertAlmostEqual(1.0, targets[0, :, :, 1].max())
self.assertAlmostEqual(1.0, targets[1, :, :, 2].max())
self.assertAlmostEqual(1.0, targets[2, :, :, 3].max())
self.assertAlmostEqual(0.0, targets[0, :, :, [0, 2, 3]].max())
self.assertAlmostEqual(0.0, targets[1, :, :, [0, 1, 3]].max())
self.assertAlmostEqual(0.0, targets[2, :, :, :3].max())
def test_low_overlap(self):
def graph1_fn():
box_batch = [tf.constant([self._box_center])]
classes = [
tf.one_hot([0], depth=2),
]
assigner = targetassigner.CenterNetCenterHeatmapTargetAssigner(
4, min_overlap=0.1)
targets_low_overlap = assigner.assign_center_targets_from_boxes(
80, 80, box_batch, classes)
return targets_low_overlap
targets_low_overlap = self.execute(graph1_fn, [])
self.assertLess(1, np.count_nonzero(targets_low_overlap))
def graph2_fn():
box_batch = [tf.constant([self._box_center])]
classes = [
tf.one_hot([0], depth=2),
]
assigner = targetassigner.CenterNetCenterHeatmapTargetAssigner(
4, min_overlap=0.6)
targets_medium_overlap = assigner.assign_center_targets_from_boxes(
80, 80, box_batch, classes)
return targets_medium_overlap
targets_medium_overlap = self.execute(graph2_fn, [])
self.assertLess(1, np.count_nonzero(targets_medium_overlap))
def graph3_fn():
box_batch = [tf.constant([self._box_center])]
classes = [
tf.one_hot([0], depth=2),
]
assigner = targetassigner.CenterNetCenterHeatmapTargetAssigner(
4, min_overlap=0.99)
targets_high_overlap = assigner.assign_center_targets_from_boxes(
80, 80, box_batch, classes)
return targets_high_overlap
targets_high_overlap = self.execute(graph3_fn, [])
self.assertTrue(np.all(targets_low_overlap >= targets_medium_overlap))
self.assertTrue(np.all(targets_medium_overlap >= targets_high_overlap))
def test_empty_box_list(self):
"""Test that an empty box list gives an all 0 heatmap."""
def graph_fn():
box_batch = [
tf.zeros((0, 4), dtype=tf.float32),
]
classes = [
tf.zeros((0, 5), dtype=tf.float32),
]
assigner = targetassigner.CenterNetCenterHeatmapTargetAssigner(
4, min_overlap=0.1)
targets = assigner.assign_center_targets_from_boxes(
80, 80, box_batch, classes)
return targets
targets = self.execute(graph_fn, [])
np.testing.assert_allclose(targets, 0.)
class CenterNetBoxTargetAssignerTest(test_case.TestCase):
def setUp(self):
super(CenterNetBoxTargetAssignerTest, self).setUp()
self._box_center = [0.0, 0.0, 1.0, 1.0]
self._box_center_small = [0.25, 0.25, 0.75, 0.75]
self._box_lower_left = [0.5, 0.0, 1.0, 0.5]
self._box_center_offset = [0.1, 0.05, 1.0, 1.0]
self._box_odd_coordinates = [0.1625, 0.2125, 0.5625, 0.9625]
def test_max_distance_for_overlap(self):
"""Test that the distance ensures the IoU with random boxes."""
    # TODO(vighneshb) remove this after the `_smallest_positive_root`
    # function is fixed.
    self.skipTest(('Skipping test because we are using an incorrect version of'
                   ' the `max_distance_for_overlap` function to reproduce'
                   ' results.'))
rng = np.random.RandomState(0)
n_samples = 100
width = rng.uniform(1, 100, size=n_samples)
height = rng.uniform(1, 100, size=n_samples)
min_iou = rng.uniform(0.1, 1.0, size=n_samples)
def graph_fn():
max_dist = targetassigner.max_distance_for_overlap(height, width, min_iou)
return max_dist
max_dist = self.execute(graph_fn, [])
xmin1 = np.zeros(n_samples)
ymin1 = np.zeros(n_samples)
xmax1 = np.zeros(n_samples) + width
ymax1 = np.zeros(n_samples) + height
xmin2 = max_dist * np.cos(rng.uniform(0, 2 * np.pi))
ymin2 = max_dist * np.sin(rng.uniform(0, 2 * np.pi))
xmax2 = width + max_dist * np.cos(rng.uniform(0, 2 * np.pi))
ymax2 = height + max_dist * np.sin(rng.uniform(0, 2 * np.pi))
boxes1 = np.vstack([ymin1, xmin1, ymax1, xmax1]).T
boxes2 = np.vstack([ymin2, xmin2, ymax2, xmax2]).T
iou = np.diag(np_box_ops.iou(boxes1, boxes2))
self.assertTrue(np.all(iou >= min_iou))
def test_max_distance_for_overlap_centernet(self):
"""Test the version of the function used in the CenterNet paper."""
def graph_fn():
distance = targetassigner.max_distance_for_overlap(10, 5, 0.5)
return distance
distance = self.execute(graph_fn, [])
self.assertAlmostEqual(2.807764064, distance)
def test_assign_size_and_offset_targets(self):
"""Test the assign_size_and_offset_targets function."""
def graph_fn():
box_batch = [
tf.constant([self._box_center, self._box_lower_left]),
tf.constant([self._box_center_offset]),
tf.constant([self._box_center_small, self._box_odd_coordinates]),
]
assigner = targetassigner.CenterNetBoxTargetAssigner(4)
indices, hw, yx_offset, weights = assigner.assign_size_and_offset_targets(
80, 80, box_batch)
return indices, hw, yx_offset, weights
indices, hw, yx_offset, weights = self.execute(graph_fn, [])
self.assertEqual(indices.shape, (5, 3))
self.assertEqual(hw.shape, (5, 2))
self.assertEqual(yx_offset.shape, (5, 2))
self.assertEqual(weights.shape, (5,))
np.testing.assert_array_equal(
indices,
[[0, 10, 10], [0, 15, 5], [1, 11, 10], [2, 10, 10], [2, 7, 11]])
np.testing.assert_array_equal(
hw, [[20, 20], [10, 10], [18, 19], [10, 10], [8, 15]])
np.testing.assert_array_equal(
yx_offset, [[0, 0], [0, 0], [0, 0.5], [0, 0], [0.25, 0.75]])
np.testing.assert_array_equal(weights, 1)
def test_assign_size_and_offset_targets_weights(self):
"""Test the assign_size_and_offset_targets function with box weights."""
def graph_fn():
box_batch = [
tf.constant([self._box_center, self._box_lower_left]),
tf.constant([self._box_lower_left, self._box_center_small]),
tf.constant([self._box_center_small, self._box_odd_coordinates]),
]
cn_assigner = targetassigner.CenterNetBoxTargetAssigner(4)
weights_batch = [
tf.constant([0.0, 1.0]),
tf.constant([1.0, 1.0]),
tf.constant([0.0, 0.0])
]
indices, hw, yx_offset, weights = cn_assigner.assign_size_and_offset_targets(
80, 80, box_batch, weights_batch)
return indices, hw, yx_offset, weights
indices, hw, yx_offset, weights = self.execute(graph_fn, [])
self.assertEqual(indices.shape, (6, 3))
self.assertEqual(hw.shape, (6, 2))
self.assertEqual(yx_offset.shape, (6, 2))
self.assertEqual(weights.shape, (6,))
np.testing.assert_array_equal(indices,
[[0, 10, 10], [0, 15, 5], [1, 15, 5],
[1, 10, 10], [2, 10, 10], [2, 7, 11]])
np.testing.assert_array_equal(
hw, [[20, 20], [10, 10], [10, 10], [10, 10], [10, 10], [8, 15]])
np.testing.assert_array_equal(
yx_offset, [[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [0.25, 0.75]])
np.testing.assert_array_equal(weights, [0, 1, 1, 1, 0, 0])
def test_get_batch_predictions_from_indices(self):
"""Test the get_batch_predictions_from_indices function.
This test verifies that the indices returned by
assign_size_and_offset_targets function work as expected with a predicted
tensor.
"""
def graph_fn():
box_batch = [
tf.constant([self._box_center, self._box_lower_left]),
tf.constant([self._box_center_small, self._box_odd_coordinates]),
]
pred_array = np.ones((2, 40, 20, 2), dtype=np.int32) * -1000
pred_array[0, 20, 10] = [1, 2]
pred_array[0, 30, 5] = [3, 4]
pred_array[1, 20, 10] = [5, 6]
pred_array[1, 14, 11] = [7, 8]
pred_tensor = tf.constant(pred_array)
cn_assigner = targetassigner.CenterNetBoxTargetAssigner(4)
indices, _, _, _ = cn_assigner.assign_size_and_offset_targets(
160, 80, box_batch)
preds = targetassigner.get_batch_predictions_from_indices(
pred_tensor, indices)
return preds
preds = self.execute(graph_fn, [])
np.testing.assert_array_equal(preds, [[1, 2], [3, 4], [5, 6], [7, 8]])
class CenterNetKeypointTargetAssignerTest(test_case.TestCase):
def test_keypoint_heatmap_targets(self):
def graph_fn():
gt_classes_list = [
tf.one_hot([0, 1, 0, 1], depth=4),
]
coordinates = tf.expand_dims(
tf.constant(
np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
[float('nan'), 0.7, float('nan'), 0.9, 1.0],
[0.4, 0.1, 0.4, 0.2, 0.1],
[float('nan'), 0.1, 0.5, 0.7, 0.6]]),
dtype=tf.float32),
axis=2)
gt_keypoints_list = [tf.concat([coordinates, coordinates], axis=2)]
gt_boxes_list = [
tf.constant(
np.array([[0.0, 0.0, 0.3, 0.3],
[0.0, 0.0, 0.5, 0.5],
[0.0, 0.0, 0.5, 0.5],
[0.0, 0.0, 1.0, 1.0]]),
dtype=tf.float32)
]
cn_assigner = targetassigner.CenterNetKeypointTargetAssigner(
stride=4,
class_id=1,
keypoint_indices=[0, 2])
(targets, num_instances_batch,
valid_mask) = cn_assigner.assign_keypoint_heatmap_targets(
120,
80,
gt_keypoints_list,
gt_classes_list,
gt_boxes_list=gt_boxes_list)
return targets, num_instances_batch, valid_mask
targets, num_instances_batch, valid_mask = self.execute(graph_fn, [])
# keypoint (0.5, 0.5) is selected. The peak is expected to appear at the
# center of the image.
self.assertEqual((15, 10), _array_argmax(targets[0, :, :, 1]))
self.assertAlmostEqual(1.0, targets[0, 15, 10, 1])
# No peak for the first class since NaN is selected.
self.assertAlmostEqual(0.0, targets[0, 15, 10, 0])
# Verify the output heatmap shape.
self.assertAllEqual([1, 30, 20, 2], targets.shape)
# Verify the number of instances is correct.
np.testing.assert_array_almost_equal([[0, 1]],
num_instances_batch)
    # When calling the function, we specify the class id to be 1, which selects
    # the instances at indices 1 and 3, and the keypoint indices to be [0, 2].
    # The instance at index 1 belongs to the target class but has no valid
    # keypoints, so the region of its bounding box (0.0, 0.0, 0.5, 0.5), i.e.
    # (0, 0, 15, 10) in absolute output space, should be blacked out.
self.assertAlmostEqual(np.sum(valid_mask[:, 0:16, 0:11]), 0.0)
# All other values are 1.0 so the sum is: 30 * 20 - 16 * 11 = 424.
self.assertAlmostEqual(np.sum(valid_mask), 424.0)
def test_assign_keypoints_offset_targets(self):
def graph_fn():
gt_classes_list = [
tf.one_hot([0, 1, 0, 1], depth=4),
]
coordinates = tf.expand_dims(
tf.constant(
np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
[float('nan'), 0.7, float('nan'), 0.9, 0.4],
[0.4, 0.1, 0.4, 0.2, 0.0],
[float('nan'), 0.0, 0.12, 0.7, 0.4]]),
dtype=tf.float32),
axis=2)
gt_keypoints_list = [tf.concat([coordinates, coordinates], axis=2)]
cn_assigner = targetassigner.CenterNetKeypointTargetAssigner(
stride=4,
class_id=1,
keypoint_indices=[0, 2])
(indices, offsets, weights) = cn_assigner.assign_keypoints_offset_targets(
height=120,
width=80,
gt_keypoints_list=gt_keypoints_list,
gt_classes_list=gt_classes_list)
return indices, weights, offsets
indices, weights, offsets = self.execute(graph_fn, [])
# Only the last element has positive weight.
np.testing.assert_array_almost_equal(
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], weights)
# Validate the last element's indices and offsets.
np.testing.assert_array_equal([0, 3, 2], indices[7, :])
np.testing.assert_array_almost_equal([0.6, 0.4], offsets[7, :])
def test_assign_keypoints_offset_targets_radius(self):
def graph_fn():
gt_classes_list = [
tf.one_hot([0, 1, 0, 1], depth=4),
]
coordinates = tf.expand_dims(
tf.constant(
np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
[float('nan'), 0.7, float('nan'), 0.9, 0.4],
[0.4, 0.1, 0.4, 0.2, 0.0],
[float('nan'), 0.0, 0.12, 0.7, 0.4]]),
dtype=tf.float32),
axis=2)
gt_keypoints_list = [tf.concat([coordinates, coordinates], axis=2)]
cn_assigner = targetassigner.CenterNetKeypointTargetAssigner(
stride=4,
class_id=1,
keypoint_indices=[0, 2],
peak_radius=1,
per_keypoint_offset=True)
(indices, offsets, weights) = cn_assigner.assign_keypoints_offset_targets(
height=120,
width=80,
gt_keypoints_list=gt_keypoints_list,
gt_classes_list=gt_classes_list)
return indices, weights, offsets
indices, weights, offsets = self.execute(graph_fn, [])
# There are total 8 * 5 (neighbors) = 40 targets.
self.assertAllEqual(indices.shape, [40, 4])
self.assertAllEqual(offsets.shape, [40, 2])
self.assertAllEqual(weights.shape, [40])
    # Only the last 5 elements (radius 1 generates 5 valid points) have
    # positive weight.
np.testing.assert_array_almost_equal([
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0
], weights)
# Validate the last element's (with neighbors) indices and offsets.
np.testing.assert_array_equal([0, 2, 2, 1], indices[35, :])
np.testing.assert_array_equal([0, 3, 1, 1], indices[36, :])
np.testing.assert_array_equal([0, 3, 2, 1], indices[37, :])
np.testing.assert_array_equal([0, 3, 3, 1], indices[38, :])
np.testing.assert_array_equal([0, 4, 2, 1], indices[39, :])
np.testing.assert_array_almost_equal([1.6, 0.4], offsets[35, :])
np.testing.assert_array_almost_equal([0.6, 1.4], offsets[36, :])
np.testing.assert_array_almost_equal([0.6, 0.4], offsets[37, :])
np.testing.assert_array_almost_equal([0.6, -0.6], offsets[38, :])
np.testing.assert_array_almost_equal([-0.4, 0.4], offsets[39, :])
def test_assign_joint_regression_targets(self):
def graph_fn():
gt_boxes_list = [
tf.constant(
np.array([[0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 1.0]]),
dtype=tf.float32)
]
gt_classes_list = [
tf.one_hot([0, 1, 0, 1], depth=4),
]
coordinates = tf.expand_dims(
tf.constant(
np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
[float('nan'), 0.7, float('nan'), 0.9, 0.4],
[0.4, 0.1, 0.4, 0.2, 0.0],
[float('nan'), 0.0, 0.12, 0.7, 0.4]]),
dtype=tf.float32),
axis=2)
gt_keypoints_list = [tf.concat([coordinates, coordinates], axis=2)]
cn_assigner = targetassigner.CenterNetKeypointTargetAssigner(
stride=4,
class_id=1,
keypoint_indices=[0, 2])
(indices, offsets, weights) = cn_assigner.assign_joint_regression_targets(
height=120,
width=80,
gt_keypoints_list=gt_keypoints_list,
gt_classes_list=gt_classes_list,
gt_boxes_list=gt_boxes_list)
return indices, offsets, weights
indices, offsets, weights = self.execute(graph_fn, [])
np.testing.assert_array_almost_equal(
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], weights)
np.testing.assert_array_equal([0, 15, 10, 1], indices[7, :])
np.testing.assert_array_almost_equal([-11.4, -7.6], offsets[7, :])
def test_assign_joint_regression_targets_radius(self):
def graph_fn():
gt_boxes_list = [
tf.constant(
np.array([[0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 1.0]]),
dtype=tf.float32)
]
gt_classes_list = [
tf.one_hot([0, 1, 0, 1], depth=4),
]
coordinates = tf.expand_dims(
tf.constant(
np.array([[0.1, 0.2, 0.3, 0.4, 0.5],
[float('nan'), 0.7, float('nan'), 0.9, 0.4],
[0.4, 0.1, 0.4, 0.2, 0.0],
[float('nan'), 0.0, 0.12, 0.7, 0.4]]),
dtype=tf.float32),
axis=2)
gt_keypoints_list = [tf.concat([coordinates, coordinates], axis=2)]
cn_assigner = targetassigner.CenterNetKeypointTargetAssigner(
stride=4,
class_id=1,
keypoint_indices=[0, 2],
peak_radius=1)
(indices, offsets, weights) = cn_assigner.assign_joint_regression_targets(
height=120,
width=80,
gt_keypoints_list=gt_keypoints_list,
gt_classes_list=gt_classes_list,
gt_boxes_list=gt_boxes_list)
return indices, offsets, weights
indices, offsets, weights = self.execute(graph_fn, [])
# There are total 8 * 5 (neighbors) = 40 targets.
self.assertAllEqual(indices.shape, [40, 4])
self.assertAllEqual(offsets.shape, [40, 2])
self.assertAllEqual(weights.shape, [40])
    # Only the last 5 elements (radius 1 generates 5 valid points) have
    # positive weight.
np.testing.assert_array_almost_equal([
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0
], weights)
# Test the values of the indices and offsets of the last 5 elements.
np.testing.assert_array_equal([0, 14, 10, 1], indices[35, :])
np.testing.assert_array_equal([0, 15, 9, 1], indices[36, :])
np.testing.assert_array_equal([0, 15, 10, 1], indices[37, :])
np.testing.assert_array_equal([0, 15, 11, 1], indices[38, :])
np.testing.assert_array_equal([0, 16, 10, 1], indices[39, :])
np.testing.assert_array_almost_equal([-10.4, -7.6], offsets[35, :])
np.testing.assert_array_almost_equal([-11.4, -6.6], offsets[36, :])
np.testing.assert_array_almost_equal([-11.4, -7.6], offsets[37, :])
np.testing.assert_array_almost_equal([-11.4, -8.6], offsets[38, :])
np.testing.assert_array_almost_equal([-12.4, -7.6], offsets[39, :])
class CenterNetMaskTargetAssignerTest(test_case.TestCase):
def test_assign_segmentation_targets(self):
def graph_fn():
gt_masks_list = [
# Example 0.
tf.constant([
[
[1., 0., 0., 0.],
[1., 1., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
],
[
[0., 0., 0., 0.],
[0., 0., 0., 1.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
],
[
[1., 1., 0., 0.],
[1., 1., 0., 0.],
[0., 0., 1., 1.],
[0., 0., 1., 1.],
]
], dtype=tf.float32),
# Example 1.
tf.constant([
[
[1., 1., 0., 1.],
[1., 1., 1., 1.],
[0., 0., 1., 1.],
[0., 0., 0., 1.],
],
[
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[1., 1., 0., 0.],
[1., 1., 0., 0.],
],
], dtype=tf.float32),
]
gt_classes_list = [
# Example 0.
tf.constant([[1., 0., 0.],
[0., 1., 0.],
[1., 0., 0.]], dtype=tf.float32),
# Example 1.
tf.constant([[0., 1., 0.],
[0., 1., 0.]], dtype=tf.float32)
]
cn_assigner = targetassigner.CenterNetMaskTargetAssigner(stride=2)
segmentation_target = cn_assigner.assign_segmentation_targets(
gt_masks_list=gt_masks_list,
gt_classes_list=gt_classes_list,
mask_resize_method=targetassigner.ResizeMethod.NEAREST_NEIGHBOR)
return segmentation_target
segmentation_target = self.execute(graph_fn, [])
expected_seg_target = np.array([
# Example 0 [[class 0, class 1], [background, class 0]]
[[[1, 0, 0], [0, 1, 0]],
[[0, 0, 0], [1, 0, 0]]],
# Example 1 [[class 1, class 1], [class 1, class 1]]
[[[0, 1, 0], [0, 1, 0]],
[[0, 1, 0], [0, 1, 0]]],
], dtype=np.float32)
np.testing.assert_array_almost_equal(
expected_seg_target, segmentation_target)
if __name__ == '__main__':
tf.enable_v2_behavior()
......
......@@ -30,6 +30,7 @@ from object_detection.core import data_decoder
from object_detection.core import standard_fields as fields
from object_detection.protos import input_reader_pb2
from object_detection.utils import label_map_util
from object_detection.utils import shape_utils
# pylint: disable=g-import-not-at-top
try:
......@@ -170,7 +171,8 @@ class TfExampleDecoder(data_decoder.DataDecoder):
num_additional_channels=0,
load_multiclass_scores=False,
load_context_features=False,
expand_hierarchy_labels=False):
expand_hierarchy_labels=False,
load_dense_pose=False):
"""Constructor sets keys_to_features and items_to_handlers.
Args:
......@@ -201,6 +203,7 @@ class TfExampleDecoder(data_decoder.DataDecoder):
account the provided hierarchy in the label_map_proto_file. For positive
classes, the labels are extended to ancestor. For negative classes,
the labels are expanded to descendants.
load_dense_pose: Whether to load DensePose annotations.
Raises:
ValueError: If `instance_mask_type` option is not one of
......@@ -371,6 +374,34 @@ class TfExampleDecoder(data_decoder.DataDecoder):
self._decode_png_instance_masks))
else:
raise ValueError('Did not recognize the `instance_mask_type` option.')
if load_dense_pose:
self.keys_to_features['image/object/densepose/num'] = (
tf.VarLenFeature(tf.int64))
self.keys_to_features['image/object/densepose/part_index'] = (
tf.VarLenFeature(tf.int64))
self.keys_to_features['image/object/densepose/x'] = (
tf.VarLenFeature(tf.float32))
self.keys_to_features['image/object/densepose/y'] = (
tf.VarLenFeature(tf.float32))
self.keys_to_features['image/object/densepose/u'] = (
tf.VarLenFeature(tf.float32))
self.keys_to_features['image/object/densepose/v'] = (
tf.VarLenFeature(tf.float32))
self.items_to_handlers[
fields.InputDataFields.groundtruth_dp_num_points] = (
slim_example_decoder.Tensor('image/object/densepose/num'))
self.items_to_handlers[fields.InputDataFields.groundtruth_dp_part_ids] = (
slim_example_decoder.ItemHandlerCallback(
['image/object/densepose/part_index',
'image/object/densepose/num'], self._dense_pose_part_indices))
self.items_to_handlers[
fields.InputDataFields.groundtruth_dp_surface_coords] = (
slim_example_decoder.ItemHandlerCallback(
['image/object/densepose/x', 'image/object/densepose/y',
'image/object/densepose/u', 'image/object/densepose/v',
'image/object/densepose/num'],
self._dense_pose_surface_coordinates))
if label_map_proto_file:
# If the label_map_proto is provided, try to use it in conjunction with
# the class text, and fall back to a materialized ID.
......@@ -547,6 +578,14 @@ class TfExampleDecoder(data_decoder.DataDecoder):
group_of = fields.InputDataFields.groundtruth_group_of
tensor_dict[group_of] = tf.cast(tensor_dict[group_of], dtype=tf.bool)
if fields.InputDataFields.groundtruth_dp_num_points in tensor_dict:
tensor_dict[fields.InputDataFields.groundtruth_dp_num_points] = tf.cast(
tensor_dict[fields.InputDataFields.groundtruth_dp_num_points],
dtype=tf.int32)
tensor_dict[fields.InputDataFields.groundtruth_dp_part_ids] = tf.cast(
tensor_dict[fields.InputDataFields.groundtruth_dp_part_ids],
dtype=tf.int32)
return tensor_dict
def _reshape_keypoints(self, keys_to_tensors):
......@@ -697,6 +736,97 @@ class TfExampleDecoder(data_decoder.DataDecoder):
lambda: tf.map_fn(decode_png_mask, png_masks, dtype=tf.float32),
lambda: tf.zeros(tf.cast(tf.stack([0, height, width]), dtype=tf.int32)))
def _dense_pose_part_indices(self, keys_to_tensors):
"""Creates a tensor that contains part indices for each DensePose point.
Args:
keys_to_tensors: a dictionary from keys to tensors.
Returns:
A 2-D int32 tensor of shape [num_instances, num_points] where each element
contains the DensePose part index (0-23). The value `num_points`
corresponds to the maximum number of sampled points across all instances
      in the image. Note that instances with fewer sampled points will be
      padded with zeros in the last dimension.
"""
num_points_per_instances = keys_to_tensors['image/object/densepose/num']
part_index = keys_to_tensors['image/object/densepose/part_index']
if isinstance(num_points_per_instances, tf.SparseTensor):
num_points_per_instances = tf.sparse_tensor_to_dense(
num_points_per_instances)
if isinstance(part_index, tf.SparseTensor):
part_index = tf.sparse_tensor_to_dense(part_index)
part_index = tf.cast(part_index, dtype=tf.int32)
max_points_per_instance = tf.cast(
tf.math.reduce_max(num_points_per_instances), dtype=tf.int32)
num_points_cumulative = tf.concat([
[0], tf.math.cumsum(num_points_per_instances)], axis=0)
def pad_parts_tensor(instance_ind):
points_range_start = num_points_cumulative[instance_ind]
points_range_end = num_points_cumulative[instance_ind + 1]
part_inds = part_index[points_range_start:points_range_end]
return shape_utils.pad_or_clip_nd(part_inds,
output_shape=[max_points_per_instance])
return tf.map_fn(pad_parts_tensor,
tf.range(tf.size(num_points_per_instances)),
dtype=tf.int32)
def _dense_pose_surface_coordinates(self, keys_to_tensors):
"""Creates a tensor that contains surface coords for each DensePose point.
Args:
keys_to_tensors: a dictionary from keys to tensors.
Returns:
A 3-D float32 tensor of shape [num_instances, num_points, 4] where each
point contains (y, x, v, u) data for each sampled DensePose point. The
(y, x) coordinate has normalized image locations for the point, and (v, u)
contains the surface coordinate (also normalized) for the part. The value
`num_points` corresponds to the maximum number of sampled points across
      all instances in the image. Note that instances with fewer sampled points
      will be padded with zeros in dim=1.
"""
num_points_per_instances = keys_to_tensors['image/object/densepose/num']
dp_y = keys_to_tensors['image/object/densepose/y']
dp_x = keys_to_tensors['image/object/densepose/x']
dp_v = keys_to_tensors['image/object/densepose/v']
dp_u = keys_to_tensors['image/object/densepose/u']
if isinstance(num_points_per_instances, tf.SparseTensor):
num_points_per_instances = tf.sparse_tensor_to_dense(
num_points_per_instances)
if isinstance(dp_y, tf.SparseTensor):
dp_y = tf.sparse_tensor_to_dense(dp_y)
if isinstance(dp_x, tf.SparseTensor):
dp_x = tf.sparse_tensor_to_dense(dp_x)
if isinstance(dp_v, tf.SparseTensor):
dp_v = tf.sparse_tensor_to_dense(dp_v)
if isinstance(dp_u, tf.SparseTensor):
dp_u = tf.sparse_tensor_to_dense(dp_u)
max_points_per_instance = tf.cast(
tf.math.reduce_max(num_points_per_instances), dtype=tf.int32)
num_points_cumulative = tf.concat([
[0], tf.math.cumsum(num_points_per_instances)], axis=0)
def pad_surface_coordinates_tensor(instance_ind):
"""Pads DensePose surface coordinates for each instance."""
points_range_start = num_points_cumulative[instance_ind]
points_range_end = num_points_cumulative[instance_ind + 1]
y = dp_y[points_range_start:points_range_end]
x = dp_x[points_range_start:points_range_end]
v = dp_v[points_range_start:points_range_end]
u = dp_u[points_range_start:points_range_end]
# Create [num_points_i, 4] tensor, where num_points_i is the number of
# sampled points for instance i.
unpadded_tensor = tf.stack([y, x, v, u], axis=1)
return shape_utils.pad_or_clip_nd(
unpadded_tensor, output_shape=[max_points_per_instance, 4])
return tf.map_fn(pad_surface_coordinates_tensor,
tf.range(tf.size(num_points_per_instances)),
dtype=tf.float32)
def _expand_image_label_hierarchy(self, image_classes, image_confidences):
"""Expand image level labels according to the hierarchy.
......
......@@ -1096,8 +1096,8 @@ class TfExampleDecoderTest(test_case.TestCase):
return example_decoder.decode(tf.convert_to_tensor(example))
tensor_dict = self.execute_cpu(graph_fn, [])
self.assertTrue(
fields.InputDataFields.groundtruth_instance_masks not in tensor_dict)
self.assertNotIn(fields.InputDataFields.groundtruth_instance_masks,
tensor_dict)
def testDecodeImageLabels(self):
image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
......@@ -1116,8 +1116,7 @@ class TfExampleDecoderTest(test_case.TestCase):
return example_decoder.decode(tf.convert_to_tensor(example))
tensor_dict = self.execute_cpu(graph_fn_1, [])
self.assertTrue(
fields.InputDataFields.groundtruth_image_classes in tensor_dict)
self.assertIn(fields.InputDataFields.groundtruth_image_classes, tensor_dict)
self.assertAllEqual(
tensor_dict[fields.InputDataFields.groundtruth_image_classes],
np.array([1, 2]))
......@@ -1152,8 +1151,7 @@ class TfExampleDecoderTest(test_case.TestCase):
return example_decoder.decode(tf.convert_to_tensor(example))
tensor_dict = self.execute_cpu(graph_fn_2, [])
self.assertTrue(
fields.InputDataFields.groundtruth_image_classes in tensor_dict)
self.assertIn(fields.InputDataFields.groundtruth_image_classes, tensor_dict)
self.assertAllEqual(
tensor_dict[fields.InputDataFields.groundtruth_image_classes],
np.array([1, 3]))
......@@ -1345,6 +1343,93 @@ class TfExampleDecoderTest(test_case.TestCase):
expected_image_confidence,
tensor_dict[fields.InputDataFields.groundtruth_image_confidences])
def testDecodeDensePose(self):
image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
encoded_jpeg, _ = self._create_encoded_and_decoded_data(
image_tensor, 'jpeg')
bbox_ymins = [0.0, 4.0, 2.0]
bbox_xmins = [1.0, 5.0, 8.0]
bbox_ymaxs = [2.0, 6.0, 1.0]
bbox_xmaxs = [3.0, 7.0, 3.3]
densepose_num = [0, 4, 2]
densepose_part_index = [2, 2, 3, 4, 2, 9]
densepose_x = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
densepose_y = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4]
densepose_u = [0.01, 0.02, 0.03, 0.04, 0.05, 0.06]
densepose_v = [0.99, 0.98, 0.97, 0.96, 0.95, 0.94]
def graph_fn():
example = tf.train.Example(
features=tf.train.Features(
feature={
'image/encoded':
dataset_util.bytes_feature(encoded_jpeg),
'image/format':
dataset_util.bytes_feature(six.b('jpeg')),
'image/object/bbox/ymin':
dataset_util.float_list_feature(bbox_ymins),
'image/object/bbox/xmin':
dataset_util.float_list_feature(bbox_xmins),
'image/object/bbox/ymax':
dataset_util.float_list_feature(bbox_ymaxs),
'image/object/bbox/xmax':
dataset_util.float_list_feature(bbox_xmaxs),
'image/object/densepose/num':
dataset_util.int64_list_feature(densepose_num),
'image/object/densepose/part_index':
dataset_util.int64_list_feature(densepose_part_index),
'image/object/densepose/x':
dataset_util.float_list_feature(densepose_x),
'image/object/densepose/y':
dataset_util.float_list_feature(densepose_y),
'image/object/densepose/u':
dataset_util.float_list_feature(densepose_u),
'image/object/densepose/v':
dataset_util.float_list_feature(densepose_v),
})).SerializeToString()
example_decoder = tf_example_decoder.TfExampleDecoder(
load_dense_pose=True)
output = example_decoder.decode(tf.convert_to_tensor(example))
dp_num_points = output[fields.InputDataFields.groundtruth_dp_num_points]
dp_part_ids = output[fields.InputDataFields.groundtruth_dp_part_ids]
dp_surface_coords = output[
fields.InputDataFields.groundtruth_dp_surface_coords]
return dp_num_points, dp_part_ids, dp_surface_coords
dp_num_points, dp_part_ids, dp_surface_coords = self.execute_cpu(
graph_fn, [])
expected_dp_num_points = [0, 4, 2]
expected_dp_part_ids = [
[0, 0, 0, 0],
[2, 2, 3, 4],
[2, 9, 0, 0]
]
expected_dp_surface_coords = np.array(
[
# Instance 0 (no points).
[[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]],
# Instance 1 (4 points).
[[0.9, 0.1, 0.99, 0.01],
[0.8, 0.2, 0.98, 0.02],
[0.7, 0.3, 0.97, 0.03],
[0.6, 0.4, 0.96, 0.04]],
# Instance 2 (2 points).
[[0.5, 0.5, 0.95, 0.05],
[0.4, 0.6, 0.94, 0.06],
[0., 0., 0., 0.],
[0., 0., 0., 0.]],
], dtype=np.float32)
self.assertAllEqual(dp_num_points, expected_dp_num_points)
self.assertAllEqual(dp_part_ids, expected_dp_part_ids)
self.assertAllClose(dp_surface_coords, expected_dp_surface_coords)
if __name__ == '__main__':
tf.test.main()
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""A Beam job to add contextual memory banks to tf.Examples.
This tool groups images containing bounding boxes and embedded context features
by a key, either `image/location` or `image/seq_id`, and time horizon,
then uses these groups to build up a contextual memory bank from the embedded
context features from each image in the group and adds that context to the
output tf.Examples for each image in the group.
Steps to generate a dataset with context from one with bounding boxes and
embedded context features:
1. Use object_detection/export_inference_graph.py to get a `saved_model` for
inference. The input node must accept a tf.Example proto.
2. Run this tool with `saved_model` from step 1 and a TFRecord of tf.Example
protos containing images, bounding boxes, and embedded context features.
The context features can be added to tf.Examples using
generate_embedding_data.py.
Example Usage:
--------------
python add_context_to_examples.py \
--input_tfrecord path/to/input_tfrecords* \
--output_tfrecord path/to/output_tfrecords \
--sequence_key image/location \
--time_horizon month
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import datetime
import io
import itertools
import json
import os
from absl import app
from absl import flags
import apache_beam as beam
import numpy as np
import PIL.Image
import six
import tensorflow as tf
from apache_beam import runners
flags.DEFINE_string('input_tfrecord', None, 'TFRecord containing images in '
                    'tf.Example format for object detection, with bounding '
                    'boxes and contextual feature embeddings.')
flags.DEFINE_string('output_tfrecord', None,
'TFRecord containing images in tf.Example format, with '
'added contextual memory banks.')
flags.DEFINE_string('sequence_key', None, 'Key to use when grouping sequences: '
'so far supports `image/seq_id` and `image/location`.')
flags.DEFINE_string('time_horizon', None, 'What time horizon to use when '
                    'splitting the data, if any. Options are: `year`, `month`,'
                    ' `week`, `day`, `hour`, `minute`, `None`.')
flags.DEFINE_integer('subsample_context_features_rate', 0, 'Whether to '
'subsample the context_features, and if so how many to '
'sample. If the rate is set to X, it will sample context '
'from 1 out of every X images. Default is sampling from '
'every image, which is X=0.')
flags.DEFINE_boolean('reduce_image_size', True, 'downsamples images to '
                     'have longest side max_image_dimension, maintaining '
                     'aspect ratio')
flags.DEFINE_integer('max_image_dimension', 1024, 'sets max image dimension')
flags.DEFINE_boolean('add_context_features', True, 'adds a memory bank of '
                     'embeddings to each clip')
flags.DEFINE_boolean('sorted_image_ids', True, 'whether the image source_ids '
'are sortable to deal with date_captured tie-breaks')
flags.DEFINE_string('image_ids_to_keep', 'All', 'path to .json list of image '
                    'ids to keep, used for ground truth eval creation')
flags.DEFINE_boolean('keep_context_features_image_id_list', False, 'Whether or '
'not to keep a list of the image_ids corresponding to the '
'memory bank')
flags.DEFINE_boolean('keep_only_positives', False, 'Whether or not to '
'keep only positive boxes based on score')
flags.DEFINE_boolean('keep_only_positives_gt', False, 'Whether or not to '
'keep only positive boxes based on gt class')
flags.DEFINE_float('context_features_score_threshold', 0.7, 'What score '
'threshold to use for boxes in context_features')
flags.DEFINE_integer('max_num_elements_in_context_features', 2000, 'Sets max '
'num elements per memory bank')
flags.DEFINE_integer('num_shards', 0, 'Number of output shards.')
flags.DEFINE_string('output_type', 'tf_sequence_example', 'Output type, one of '
'`tf_example`, `tf_sequence_example`')
flags.DEFINE_integer('max_clip_length', None, 'Max length for sequence '
'example outputs.')
FLAGS = flags.FLAGS
DEFAULT_FEATURE_LENGTH = 2057
class ReKeyDataFn(beam.DoFn):
"""Re-keys tfrecords by sequence_key.
This Beam DoFn re-keys the tfrecords by a user-defined sequence_key
"""
def __init__(self, sequence_key, time_horizon,
reduce_image_size, max_image_dimension):
"""Initialization function.
Args:
sequence_key: A feature name to use as a key for grouping sequences.
Must point to a key of type bytes_list
time_horizon: What length of time to use to partition the data when
        building the memory banks. Options: `year`, `month`, `week`, `day`,
`hour`, `minute`, None
reduce_image_size: Whether to reduce the sizes of the stored images.
max_image_dimension: maximum dimension of reduced images
"""
self._sequence_key = sequence_key
if time_horizon is None or time_horizon in {'year', 'month', 'week', 'day',
'hour', 'minute'}:
self._time_horizon = time_horizon
else:
raise ValueError('Time horizon not supported.')
self._reduce_image_size = reduce_image_size
self._max_image_dimension = max_image_dimension
self._session = None
self._num_examples_processed = beam.metrics.Metrics.counter(
'data_rekey', 'num_tf_examples_processed')
self._num_images_resized = beam.metrics.Metrics.counter(
'data_rekey', 'num_images_resized')
self._num_images_read = beam.metrics.Metrics.counter(
'data_rekey', 'num_images_read')
self._num_images_found = beam.metrics.Metrics.counter(
        'data_rekey', 'num_images_found')
self._num_got_shape = beam.metrics.Metrics.counter(
'data_rekey', 'num_images_got_shape')
self._num_images_found_size = beam.metrics.Metrics.counter(
'data_rekey', 'num_images_found_size')
self._num_examples_cleared = beam.metrics.Metrics.counter(
'data_rekey', 'num_examples_cleared')
self._num_examples_updated = beam.metrics.Metrics.counter(
'data_rekey', 'num_examples_updated')
def process(self, tfrecord_entry):
return self._rekey_examples(tfrecord_entry)
def _largest_size_at_most(self, height, width, largest_side):
"""Computes new shape with the largest side equal to `largest_side`.
Args:
height: an int indicating the current height.
width: an int indicating the current width.
largest_side: A python integer indicating the size of
the largest side after resize.
Returns:
new_height: an int indicating the new height.
new_width: an int indicating the new width.
"""
x_scale = float(largest_side) / float(width)
y_scale = float(largest_side) / float(height)
scale = min(x_scale, y_scale)
new_width = int(width * scale)
new_height = int(height * scale)
return new_height, new_width
def _resize_image(self, input_example):
"""Resizes the image within input_example and updates the height and width.
Args:
input_example: A tf.Example that we want to update to contain a resized
image.
Returns:
input_example: Updated tf.Example.
"""
original_image = copy.deepcopy(
input_example.features.feature['image/encoded'].bytes_list.value[0])
self._num_images_read.inc(1)
height = copy.deepcopy(
input_example.features.feature['image/height'].int64_list.value[0])
width = copy.deepcopy(
input_example.features.feature['image/width'].int64_list.value[0])
self._num_got_shape.inc(1)
new_height, new_width = self._largest_size_at_most(
height, width, self._max_image_dimension)
self._num_images_found_size.inc(1)
encoded_jpg_io = io.BytesIO(original_image)
image = PIL.Image.open(encoded_jpg_io)
resized_image = image.resize((new_width, new_height))
with io.BytesIO() as output:
resized_image.save(output, format='JPEG')
encoded_resized_image = output.getvalue()
self._num_images_resized.inc(1)
del input_example.features.feature['image/encoded'].bytes_list.value[:]
del input_example.features.feature['image/height'].int64_list.value[:]
del input_example.features.feature['image/width'].int64_list.value[:]
self._num_examples_cleared.inc(1)
input_example.features.feature['image/encoded'].bytes_list.value.extend(
[encoded_resized_image])
input_example.features.feature['image/height'].int64_list.value.extend(
[new_height])
input_example.features.feature['image/width'].int64_list.value.extend(
[new_width])
self._num_examples_updated.inc(1)
return input_example
def _rekey_examples(self, tfrecord_entry):
serialized_example = copy.deepcopy(tfrecord_entry)
input_example = tf.train.Example.FromString(serialized_example)
self._num_images_found.inc(1)
if self._reduce_image_size:
input_example = self._resize_image(input_example)
self._num_images_resized.inc(1)
new_key = input_example.features.feature[
self._sequence_key].bytes_list.value[0]
if self._time_horizon:
date_captured = datetime.datetime.strptime(
six.ensure_str(input_example.features.feature[
'image/date_captured'].bytes_list.value[0]), '%Y-%m-%d %H:%M:%S')
year = date_captured.year
month = date_captured.month
day = date_captured.day
week = np.floor(float(day) / float(7))
hour = date_captured.hour
minute = date_captured.minute
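      # For example (hypothetical values), with sequence_key 'image/location'
      # and time_horizon 'month', an image captured on '2018-07-04 11:00:00'
      # at location b'loc_1' is re-keyed as b'loc_1/2018/7'.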
if self._time_horizon == 'year':
new_key = new_key + six.ensure_binary('/' + str(year))
elif self._time_horizon == 'month':
new_key = new_key + six.ensure_binary(
'/' + str(year) + '/' + str(month))
elif self._time_horizon == 'week':
new_key = new_key + six.ensure_binary(
'/' + str(year) + '/' + str(month) + '/' + str(week))
elif self._time_horizon == 'day':
new_key = new_key + six.ensure_binary(
'/' + str(year) + '/' + str(month) + '/' + str(day))
elif self._time_horizon == 'hour':
new_key = new_key + six.ensure_binary(
'/' + str(year) + '/' + str(month) + '/' + str(day) + '/' + (
str(hour)))
elif self._time_horizon == 'minute':
new_key = new_key + six.ensure_binary(
'/' + str(year) + '/' + str(month) + '/' + str(day) + '/' + (
str(hour) + '/' + str(minute)))
self._num_examples_processed.inc(1)
return [(new_key, input_example)]
class SortGroupedDataFn(beam.DoFn):
"""Sorts data within a keyed group.
This Beam DoFn sorts the grouped list of image examples by frame_num
"""
def __init__(self, sequence_key, sorted_image_ids,
max_num_elements_in_context_features):
"""Initialization function.
Args:
sequence_key: A feature name to use as a key for grouping sequences.
Must point to a key of type bytes_list
sorted_image_ids: Whether the image ids are sortable to use as sorting
tie-breakers
max_num_elements_in_context_features: The maximum number of elements
allowed in the memory bank
"""
self._session = None
self._num_examples_processed = beam.metrics.Metrics.counter(
'sort_group', 'num_groups_sorted')
self._too_many_elements = beam.metrics.Metrics.counter(
'sort_group', 'too_many_elements')
self._split_elements = beam.metrics.Metrics.counter(
'sort_group', 'split_elements')
self._sequence_key = six.ensure_binary(sequence_key)
self._sorted_image_ids = sorted_image_ids
self._max_num_elements_in_context_features = (
max_num_elements_in_context_features)
def process(self, grouped_entry):
return self._sort_image_examples(grouped_entry)
def _sort_image_examples(self, grouped_entry):
key, example_collection = grouped_entry
example_list = list(example_collection)
def get_frame_num(example):
return example.features.feature['image/seq_frame_num'].int64_list.value[0]
def get_date_captured(example):
return datetime.datetime.strptime(
six.ensure_str(
example.features.feature[
'image/date_captured'].bytes_list.value[0]),
'%Y-%m-%d %H:%M:%S')
def get_image_id(example):
return example.features.feature['image/source_id'].bytes_list.value[0]
if self._sequence_key == six.ensure_binary('image/seq_id'):
sorting_fn = get_frame_num
elif self._sequence_key == six.ensure_binary('image/location'):
if self._sorted_image_ids:
sorting_fn = get_image_id
else:
sorting_fn = get_date_captured
sorted_example_list = sorted(example_list, key=sorting_fn)
self._num_examples_processed.inc(1)
if len(sorted_example_list) > self._max_num_elements_in_context_features:
leftovers = sorted_example_list
output_list = []
count = 0
self._too_many_elements.inc(1)
while len(leftovers) > self._max_num_elements_in_context_features:
self._split_elements.inc(1)
new_key = key + six.ensure_binary('_' + str(count))
new_list = leftovers[:self._max_num_elements_in_context_features]
output_list.append((new_key, new_list))
        leftovers = leftovers[self._max_num_elements_in_context_features:]
        count += 1
      if leftovers:
        new_key = key + six.ensure_binary('_' + str(count))
        output_list.append((new_key, leftovers))
else:
output_list = [(key, sorted_example_list)]
return output_list
def get_sliding_window(example_list, max_clip_length, stride_length):
"""Yields a sliding window over data from example_list.
Sliding window has width max_clip_len (n) and stride stride_len (m).
s -> (s0,s1,...s[n-1]), (s[m],s[m+1],...,s[m+n]), ...
Args:
example_list: A list of examples.
max_clip_length: The maximum length of each clip.
stride_length: The stride between each clip.
Yields:
A list of lists of examples, each with length <= max_clip_length
"""
# check if the list is too short to slide over
if len(example_list) < max_clip_length:
yield example_list
else:
starting_values = [i*stride_length for i in
range(len(example_list)) if
len(example_list) > i*stride_length]
for start in starting_values:
result = tuple(itertools.islice(example_list, start,
min(start + max_clip_length,
len(example_list))))
yield result
class GenerateContextFn(beam.DoFn):
"""Generates context data for camera trap images.
This Beam DoFn builds up contextual memory banks from groups of images and
  stores them in the output tf.Example or tf.SequenceExample for each image.
"""
def __init__(self, sequence_key, add_context_features, image_ids_to_keep,
keep_context_features_image_id_list=False,
subsample_context_features_rate=0,
keep_only_positives=False,
context_features_score_threshold=0.7,
keep_only_positives_gt=False,
max_num_elements_in_context_features=5000,
pad_context_features=False,
output_type='tf_example', max_clip_length=None):
"""Initialization function.
Args:
sequence_key: A feature name to use as a key for grouping sequences.
add_context_features: Whether to keep and store the contextual memory
bank.
image_ids_to_keep: A list of image ids to save, to use to build data
subsets for evaluation.
keep_context_features_image_id_list: Whether to save an ordered list of
the ids of the images in the contextual memory bank.
subsample_context_features_rate: What rate to subsample images for the
contextual memory bank.
keep_only_positives: Whether to only keep high scoring
(>context_features_score_threshold) features in the contextual memory
bank.
context_features_score_threshold: What threshold to use for keeping
features.
keep_only_positives_gt: Whether to only keep features from images that
contain objects based on the ground truth (for training).
max_num_elements_in_context_features: the maximum number of elements in
the memory bank
pad_context_features: Whether to pad the context features to a fixed size.
      output_type: What type of output, tf_example or tf_sequence_example.
      max_clip_length: The maximum length of a sequence example, before
        splitting into multiple examples.
"""
self._session = None
self._num_examples_processed = beam.metrics.Metrics.counter(
'sequence_data_generation', 'num_seq_examples_processed')
self._num_keys_processed = beam.metrics.Metrics.counter(
'sequence_data_generation', 'num_keys_processed')
self._sequence_key = sequence_key
self._add_context_features = add_context_features
self._pad_context_features = pad_context_features
self._output_type = output_type
self._max_clip_length = max_clip_length
if six.ensure_str(image_ids_to_keep) == 'All':
self._image_ids_to_keep = None
else:
with tf.io.gfile.GFile(image_ids_to_keep) as f:
self._image_ids_to_keep = json.load(f)
self._keep_context_features_image_id_list = (
keep_context_features_image_id_list)
self._subsample_context_features_rate = subsample_context_features_rate
self._keep_only_positives = keep_only_positives
self._keep_only_positives_gt = keep_only_positives_gt
self._context_features_score_threshold = context_features_score_threshold
self._max_num_elements_in_context_features = (
max_num_elements_in_context_features)
self._images_kept = beam.metrics.Metrics.counter(
'sequence_data_generation', 'images_kept')
self._images_loaded = beam.metrics.Metrics.counter(
'sequence_data_generation', 'images_loaded')
def process(self, grouped_entry):
return self._add_context_to_example(copy.deepcopy(grouped_entry))
def _build_context_features(self, example_list):
context_features = []
context_features_image_id_list = []
count = 0
example_embedding = []
for idx, example in enumerate(example_list):
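      # Examples filtered out of the memory bank below (subsampled, low-scoring,
      # or missing ground-truth boxes) are tagged with a sentinel
      # context_features_idx of max_num_elements_in_context_features + 1,
      # presumably so downstream readers can tell that they contributed no
      # context feature.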
if self._subsample_context_features_rate > 0:
if (idx % self._subsample_context_features_rate) != 0:
example.features.feature[
'context_features_idx'].int64_list.value.append(
self._max_num_elements_in_context_features + 1)
continue
if self._keep_only_positives:
if example.features.feature[
'image/embedding_score'
].float_list.value[0] < self._context_features_score_threshold:
example.features.feature[
'context_features_idx'].int64_list.value.append(
self._max_num_elements_in_context_features + 1)
continue
if self._keep_only_positives_gt:
if len(example.features.feature[
'image/object/bbox/xmin'
].float_list.value) < 1:
example.features.feature[
'context_features_idx'].int64_list.value.append(
self._max_num_elements_in_context_features + 1)
continue
example_embedding = list(example.features.feature[
'image/embedding'].float_list.value)
context_features.extend(example_embedding)
example.features.feature[
'context_features_idx'].int64_list.value.append(count)
count += 1
example_image_id = example.features.feature[
'image/source_id'].bytes_list.value[0]
context_features_image_id_list.append(example_image_id)
if not example_embedding:
example_embedding.append(np.zeros(DEFAULT_FEATURE_LENGTH))
feature_length = DEFAULT_FEATURE_LENGTH
    # If the example_list is not empty and image/embedding_length is in the
    # feature dict, feature_length will be assigned to that value. Otherwise, it
    # will be kept at the default.
if example_list and (
'image/embedding_length' in example_list[0].features.feature):
feature_length = example_list[0].features.feature[
'image/embedding_length'].int64_list.value[0]
if self._pad_context_features:
while len(context_features_image_id_list) < (
self._max_num_elements_in_context_features):
context_features_image_id_list.append('')
return context_features, feature_length, context_features_image_id_list
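  # For reference, with the values used in the accompanying tests: two kept
  # examples with 3-dim embeddings [0.1, 0.2, 0.3] and [0.4, 0.5, 0.6] yield
  # context_features == [0.1, 0.2, 0.3, 0.4, 0.5, 0.6] and feature_length == 3.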
def _add_context_to_example(self, grouped_entry):
key, example_collection = grouped_entry
list_of_examples = []
example_list = list(example_collection)
if self._add_context_features:
context_features, feature_length, context_features_image_id_list = (
self._build_context_features(example_list))
if self._image_ids_to_keep is not None:
new_example_list = []
for example in example_list:
im_id = example.features.feature['image/source_id'].bytes_list.value[0]
self._images_loaded.inc(1)
if six.ensure_str(im_id) in self._image_ids_to_keep:
self._images_kept.inc(1)
new_example_list.append(example)
if new_example_list:
example_list = new_example_list
else:
return []
if self._output_type == 'tf_sequence_example':
if self._max_clip_length is not None:
# For now, no overlap
clips = get_sliding_window(
example_list, self._max_clip_length, self._max_clip_length)
else:
clips = [example_list]
for clip_num, clip_list in enumerate(clips):
# initialize sequence example
seq_example = tf.train.SequenceExample()
video_id = six.ensure_str(key)+'_'+ str(clip_num)
seq_example.context.feature['clip/media_id'].bytes_list.value.append(
video_id.encode('utf8'))
seq_example.context.feature['clip/frames'].int64_list.value.append(
len(clip_list))
seq_example.context.feature[
'clip/start/timestamp'].int64_list.value.append(0)
seq_example.context.feature[
'clip/end/timestamp'].int64_list.value.append(len(clip_list))
seq_example.context.feature['image/format'].bytes_list.value.append(
six.ensure_binary('JPG'))
seq_example.context.feature['image/channels'].int64_list.value.append(3)
context_example = clip_list[0]
seq_example.context.feature['image/height'].int64_list.value.append(
context_example.features.feature[
'image/height'].int64_list.value[0])
seq_example.context.feature['image/width'].int64_list.value.append(
context_example.features.feature['image/width'].int64_list.value[0])
seq_example.context.feature[
'image/context_feature_length'].int64_list.value.append(
feature_length)
seq_example.context.feature[
'image/context_features'].float_list.value.extend(
context_features)
if self._keep_context_features_image_id_list:
seq_example.context.feature[
'image/context_features_image_id_list'].bytes_list.value.extend(
context_features_image_id_list)
encoded_image_list = seq_example.feature_lists.feature_list[
'image/encoded']
timestamps_list = seq_example.feature_lists.feature_list[
'image/timestamp']
context_features_idx_list = seq_example.feature_lists.feature_list[
'image/context_features_idx']
date_captured_list = seq_example.feature_lists.feature_list[
'image/date_captured']
unix_time_list = seq_example.feature_lists.feature_list[
'image/unix_time']
location_list = seq_example.feature_lists.feature_list['image/location']
image_ids_list = seq_example.feature_lists.feature_list[
'image/source_id']
gt_xmin_list = seq_example.feature_lists.feature_list[
'region/bbox/xmin']
gt_xmax_list = seq_example.feature_lists.feature_list[
'region/bbox/xmax']
gt_ymin_list = seq_example.feature_lists.feature_list[
'region/bbox/ymin']
gt_ymax_list = seq_example.feature_lists.feature_list[
'region/bbox/ymax']
gt_type_list = seq_example.feature_lists.feature_list[
'region/label/index']
gt_type_string_list = seq_example.feature_lists.feature_list[
'region/label/string']
gt_is_annotated_list = seq_example.feature_lists.feature_list[
'region/is_annotated']
for idx, example in enumerate(clip_list):
encoded_image = encoded_image_list.feature.add()
encoded_image.bytes_list.value.extend(
example.features.feature['image/encoded'].bytes_list.value)
image_id = image_ids_list.feature.add()
image_id.bytes_list.value.append(
example.features.feature['image/source_id'].bytes_list.value[0])
timestamp = timestamps_list.feature.add()
          # The timestamp is currently just the frame's index in the clip.
timestamp.int64_list.value.extend([idx])
context_features_idx = context_features_idx_list.feature.add()
context_features_idx.int64_list.value.extend(
example.features.feature['context_features_idx'].int64_list.value)
date_captured = date_captured_list.feature.add()
date_captured.bytes_list.value.extend(
example.features.feature['image/date_captured'].bytes_list.value)
unix_time = unix_time_list.feature.add()
unix_time.float_list.value.extend(
example.features.feature['image/unix_time'].float_list.value)
location = location_list.feature.add()
location.bytes_list.value.extend(
example.features.feature['image/location'].bytes_list.value)
gt_xmin = gt_xmin_list.feature.add()
gt_xmax = gt_xmax_list.feature.add()
gt_ymin = gt_ymin_list.feature.add()
gt_ymax = gt_ymax_list.feature.add()
gt_type = gt_type_list.feature.add()
gt_type_str = gt_type_string_list.feature.add()
gt_is_annotated = gt_is_annotated_list.feature.add()
gt_is_annotated.int64_list.value.append(1)
gt_xmin.float_list.value.extend(
example.features.feature[
'image/object/bbox/xmin'].float_list.value)
gt_xmax.float_list.value.extend(
example.features.feature[
'image/object/bbox/xmax'].float_list.value)
gt_ymin.float_list.value.extend(
example.features.feature[
'image/object/bbox/ymin'].float_list.value)
gt_ymax.float_list.value.extend(
example.features.feature[
'image/object/bbox/ymax'].float_list.value)
gt_type.int64_list.value.extend(
example.features.feature[
'image/object/class/label'].int64_list.value)
gt_type_str.bytes_list.value.extend(
example.features.feature[
'image/object/class/text'].bytes_list.value)
self._num_examples_processed.inc(1)
list_of_examples.append(seq_example)
elif self._output_type == 'tf_example':
for example in example_list:
im_id = example.features.feature['image/source_id'].bytes_list.value[0]
if self._add_context_features:
example.features.feature[
'image/context_features'].float_list.value.extend(
context_features)
example.features.feature[
'image/context_feature_length'].int64_list.value.append(
feature_length)
if self._keep_context_features_image_id_list:
example.features.feature[
'image/context_features_image_id_list'].bytes_list.value.extend(
context_features_image_id_list)
self._num_examples_processed.inc(1)
list_of_examples.append(example)
return list_of_examples
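# Minimal usage sketch (hypothetical, mirroring the unit tests in the
# accompanying test file): the DoFn can also be exercised directly, outside a
# pipeline, e.g.
#   context_fn = GenerateContextFn('image/seq_id', True, 'All')
#   examples = context_fn.process((b'01', [example_a, example_b]))
# where example_a and example_b are tf.train.Example protos carrying the
# image/embedding, image/embedding_score and image/source_id features used
# above.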
def construct_pipeline(input_tfrecord,
output_tfrecord,
sequence_key,
time_horizon=None,
subsample_context_features_rate=0,
reduce_image_size=True,
max_image_dimension=1024,
add_context_features=True,
sorted_image_ids=True,
image_ids_to_keep='All',
keep_context_features_image_id_list=False,
keep_only_positives=False,
context_features_score_threshold=0.7,
keep_only_positives_gt=False,
max_num_elements_in_context_features=5000,
num_shards=0,
output_type='tf_example',
max_clip_length=None):
"""Returns a beam pipeline to run object detection inference.
Args:
input_tfrecord: An TFRecord of tf.train.Example protos containing images.
output_tfrecord: An TFRecord of tf.train.Example protos that contain images
in the input TFRecord and the detections from the model.
sequence_key: A feature name to use as a key for grouping sequences.
time_horizon: What length of time to use to partition the data when building
      the memory banks. Options: `year`, `month`, `week`, `day`, `hour`,
`minute`, None.
subsample_context_features_rate: What rate to subsample images for the
contextual memory bank.
reduce_image_size: Whether to reduce the size of the stored images.
max_image_dimension: The maximum image dimension to use for resizing.
add_context_features: Whether to keep and store the contextual memory bank.
sorted_image_ids: Whether the image ids are sortable, and can be used as
datetime tie-breakers when building memory banks.
image_ids_to_keep: A list of image ids to save, to use to build data subsets
for evaluation.
keep_context_features_image_id_list: Whether to save an ordered list of the
ids of the images in the contextual memory bank.
keep_only_positives: Whether to only keep high scoring
(>context_features_score_threshold) features in the contextual memory
bank.
context_features_score_threshold: What threshold to use for keeping
features.
keep_only_positives_gt: Whether to only keep features from images that
contain objects based on the ground truth (for training).
    max_num_elements_in_context_features: The maximum number of elements in the
      memory bank.
    num_shards: The number of output shards.
    output_type: What type of output, tf_example or tf_sequence_example.
    max_clip_length: The maximum length of a sequence example, before
      splitting it into multiple examples.
"""
def pipeline(root):
if output_type == 'tf_example':
coder = beam.coders.ProtoCoder(tf.train.Example)
elif output_type == 'tf_sequence_example':
coder = beam.coders.ProtoCoder(tf.train.SequenceExample)
else:
raise ValueError('Unsupported output type.')
input_collection = (
root | 'ReadInputTFRecord' >> beam.io.tfrecordio.ReadFromTFRecord(
input_tfrecord,
coder=beam.coders.BytesCoder()))
rekey_collection = input_collection | 'RekeyExamples' >> beam.ParDo(
ReKeyDataFn(sequence_key, time_horizon,
reduce_image_size, max_image_dimension))
grouped_collection = (
rekey_collection | 'GroupBySequenceKey' >> beam.GroupByKey())
grouped_collection = (
grouped_collection | 'ReshuffleGroups' >> beam.Reshuffle())
ordered_collection = (
grouped_collection | 'OrderByFrameNumber' >> beam.ParDo(
SortGroupedDataFn(sequence_key, sorted_image_ids,
max_num_elements_in_context_features)))
ordered_collection = (
ordered_collection | 'ReshuffleSortedGroups' >> beam.Reshuffle())
output_collection = (
ordered_collection | 'AddContextToExamples' >> beam.ParDo(
GenerateContextFn(
sequence_key, add_context_features, image_ids_to_keep,
keep_context_features_image_id_list=(
keep_context_features_image_id_list),
subsample_context_features_rate=subsample_context_features_rate,
keep_only_positives=keep_only_positives,
keep_only_positives_gt=keep_only_positives_gt,
context_features_score_threshold=(
context_features_score_threshold),
max_num_elements_in_context_features=(
max_num_elements_in_context_features),
output_type=output_type,
max_clip_length=max_clip_length)))
output_collection = (
output_collection | 'ReshuffleExamples' >> beam.Reshuffle())
_ = output_collection | 'WritetoDisk' >> beam.io.tfrecordio.WriteToTFRecord(
output_tfrecord,
num_shards=num_shards,
coder=coder)
return pipeline
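# Minimal sketch of running the pipeline outside of main() (hypothetical paths;
# mirrors test_beam_pipeline in the accompanying test file):
#   runner = runners.DirectRunner()
#   runner.run(construct_pipeline('/tmp/input.tfrecord', '/tmp/output.tfrecord',
#                                 six.ensure_binary('image/seq_id'),
#                                 num_shards=1))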
def main(_):
"""Runs the Beam pipeline that builds context features.
Args:
_: unused
"""
# must create before flags are used
runner = runners.DirectRunner()
dirname = os.path.dirname(FLAGS.output_tfrecord)
tf.io.gfile.makedirs(dirname)
runner.run(
construct_pipeline(FLAGS.input_tfrecord,
FLAGS.output_tfrecord,
FLAGS.sequence_key,
FLAGS.time_horizon,
FLAGS.subsample_context_features_rate,
FLAGS.reduce_image_size,
FLAGS.max_image_dimension,
FLAGS.add_context_features,
FLAGS.sorted_image_ids,
FLAGS.image_ids_to_keep,
FLAGS.keep_context_features_image_id_list,
FLAGS.keep_only_positives,
FLAGS.context_features_score_threshold,
FLAGS.keep_only_positives_gt,
FLAGS.max_num_elements_in_context_features,
FLAGS.num_shards,
FLAGS.output_type,
FLAGS.max_clip_length))
if __name__ == '__main__':
flags.mark_flags_as_required([
'input_tfrecord',
'output_tfrecord'
])
app.run(main)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for add_context_to_examples."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import datetime
import os
import tempfile
import unittest
import numpy as np
import six
import tensorflow.compat.v1 as tf
from object_detection.dataset_tools.context_rcnn import add_context_to_examples
from object_detection.utils import tf_version
from apache_beam import runners
@contextlib.contextmanager
def InMemoryTFRecord(entries):
temp = tempfile.NamedTemporaryFile(delete=False)
filename = temp.name
try:
with tf.python_io.TFRecordWriter(filename) as writer:
for value in entries:
writer.write(value)
yield filename
finally:
os.unlink(temp.name)
def BytesFeature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def BytesListFeature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
def Int64Feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def Int64ListFeature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def FloatListFeature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
class GenerateContextDataTest(tf.test.TestCase):
def _create_first_tf_example(self):
with self.test_session():
encoded_image = tf.image.encode_jpeg(
tf.constant(np.ones((4, 4, 3)).astype(np.uint8))).eval()
example = tf.train.Example(features=tf.train.Features(feature={
'image/encoded': BytesFeature(encoded_image),
'image/source_id': BytesFeature(six.ensure_binary('image_id_1')),
'image/height': Int64Feature(4),
'image/width': Int64Feature(4),
'image/object/class/label': Int64ListFeature([5, 5]),
'image/object/class/text': BytesListFeature([six.ensure_binary('hyena'),
six.ensure_binary('hyena')
]),
'image/object/bbox/xmin': FloatListFeature([0.0, 0.1]),
'image/object/bbox/xmax': FloatListFeature([0.2, 0.3]),
'image/object/bbox/ymin': FloatListFeature([0.4, 0.5]),
'image/object/bbox/ymax': FloatListFeature([0.6, 0.7]),
'image/seq_id': BytesFeature(six.ensure_binary('01')),
'image/seq_num_frames': Int64Feature(2),
'image/seq_frame_num': Int64Feature(0),
'image/date_captured': BytesFeature(
six.ensure_binary(str(datetime.datetime(2020, 1, 1, 1, 0, 0)))),
'image/embedding': FloatListFeature([0.1, 0.2, 0.3]),
'image/embedding_score': FloatListFeature([0.9]),
'image/embedding_length': Int64Feature(3)
}))
return example.SerializeToString()
def _create_second_tf_example(self):
with self.test_session():
encoded_image = tf.image.encode_jpeg(
tf.constant(np.ones((4, 4, 3)).astype(np.uint8))).eval()
example = tf.train.Example(features=tf.train.Features(feature={
'image/encoded': BytesFeature(encoded_image),
'image/source_id': BytesFeature(six.ensure_binary('image_id_2')),
'image/height': Int64Feature(4),
'image/width': Int64Feature(4),
'image/object/class/label': Int64ListFeature([5]),
'image/object/class/text': BytesListFeature([six.ensure_binary('hyena')
]),
'image/object/bbox/xmin': FloatListFeature([0.0]),
'image/object/bbox/xmax': FloatListFeature([0.1]),
'image/object/bbox/ymin': FloatListFeature([0.2]),
'image/object/bbox/ymax': FloatListFeature([0.3]),
'image/seq_id': BytesFeature(six.ensure_binary('01')),
'image/seq_num_frames': Int64Feature(2),
'image/seq_frame_num': Int64Feature(1),
'image/date_captured': BytesFeature(
six.ensure_binary(str(datetime.datetime(2020, 1, 1, 1, 1, 0)))),
'image/embedding': FloatListFeature([0.4, 0.5, 0.6]),
'image/embedding_score': FloatListFeature([0.9]),
'image/embedding_length': Int64Feature(3)
}))
return example.SerializeToString()
def assert_expected_examples(self, tf_example_list):
self.assertAllEqual(
{tf_example.features.feature['image/source_id'].bytes_list.value[0]
for tf_example in tf_example_list},
{six.ensure_binary('image_id_1'), six.ensure_binary('image_id_2')})
self.assertAllClose(
tf_example_list[0].features.feature[
'image/context_features'].float_list.value,
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
self.assertAllClose(
tf_example_list[1].features.feature[
'image/context_features'].float_list.value,
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
def assert_expected_sequence_example(self, tf_sequence_example_list):
tf_sequence_example = tf_sequence_example_list[0]
num_frames = 2
self.assertAllEqual(
tf_sequence_example.context.feature[
'clip/media_id'].bytes_list.value[0], six.ensure_binary(
'01_0'))
self.assertAllClose(
tf_sequence_example.context.feature[
'image/context_features'].float_list.value,
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
seq_feature_dict = tf_sequence_example.feature_lists.feature_list
self.assertLen(
seq_feature_dict['image/encoded'].feature[:],
num_frames)
actual_timestamps = [
feature.int64_list.value[0] for feature
in seq_feature_dict['image/timestamp'].feature]
timestamps = [0, 1]
self.assertAllEqual(timestamps, actual_timestamps)
# First image.
self.assertAllClose(
[0.4, 0.5],
seq_feature_dict['region/bbox/ymin'].feature[0].float_list.value[:])
self.assertAllClose(
[0.0, 0.1],
seq_feature_dict['region/bbox/xmin'].feature[0].float_list.value[:])
self.assertAllClose(
[0.6, 0.7],
seq_feature_dict['region/bbox/ymax'].feature[0].float_list.value[:])
self.assertAllClose(
[0.2, 0.3],
seq_feature_dict['region/bbox/xmax'].feature[0].float_list.value[:])
self.assertAllEqual(
[six.ensure_binary('hyena'), six.ensure_binary('hyena')],
seq_feature_dict['region/label/string'].feature[0].bytes_list.value[:])
# Second example.
self.assertAllClose(
[0.2],
seq_feature_dict['region/bbox/ymin'].feature[1].float_list.value[:])
self.assertAllClose(
[0.0],
seq_feature_dict['region/bbox/xmin'].feature[1].float_list.value[:])
self.assertAllClose(
[0.3],
seq_feature_dict['region/bbox/ymax'].feature[1].float_list.value[:])
self.assertAllClose(
[0.1],
seq_feature_dict['region/bbox/xmax'].feature[1].float_list.value[:])
self.assertAllEqual(
[six.ensure_binary('hyena')],
seq_feature_dict['region/label/string'].feature[1].bytes_list.value[:])
def assert_expected_key(self, key):
self.assertAllEqual(key, b'01')
def assert_sorted(self, example_collection):
example_list = list(example_collection)
counter = 0
for example in example_list:
frame_num = example.features.feature[
'image/seq_frame_num'].int64_list.value[0]
self.assertGreaterEqual(frame_num, counter)
counter = frame_num
def assert_context(self, example_collection):
example_list = list(example_collection)
for example in example_list:
context = example.features.feature[
'image/context_features'].float_list.value
self.assertAllClose([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], context)
def assert_resized(self, example):
width = example.features.feature['image/width'].int64_list.value[0]
self.assertAllEqual(width, 2)
height = example.features.feature['image/height'].int64_list.value[0]
self.assertAllEqual(height, 2)
def assert_size(self, example):
width = example.features.feature['image/width'].int64_list.value[0]
self.assertAllEqual(width, 4)
height = example.features.feature['image/height'].int64_list.value[0]
self.assertAllEqual(height, 4)
def test_sliding_window(self):
example_list = ['a', 'b', 'c', 'd', 'e', 'f', 'g']
max_clip_length = 3
stride_length = 3
out_list = [list(i) for i in add_context_to_examples.get_sliding_window(
example_list, max_clip_length, stride_length)]
self.assertAllEqual(out_list, [['a', 'b', 'c'],
['d', 'e', 'f'],
['g']])
def test_rekey_data_fn(self):
sequence_key = 'image/seq_id'
time_horizon = None
reduce_image_size = False
max_dim = None
rekey_fn = add_context_to_examples.ReKeyDataFn(
sequence_key, time_horizon,
reduce_image_size, max_dim)
output = rekey_fn.process(self._create_first_tf_example())
self.assert_expected_key(output[0][0])
self.assert_size(output[0][1])
def test_rekey_data_fn_w_resize(self):
sequence_key = 'image/seq_id'
time_horizon = None
reduce_image_size = True
max_dim = 2
rekey_fn = add_context_to_examples.ReKeyDataFn(
sequence_key, time_horizon,
reduce_image_size, max_dim)
output = rekey_fn.process(self._create_first_tf_example())
self.assert_expected_key(output[0][0])
self.assert_resized(output[0][1])
def test_sort_fn(self):
sequence_key = 'image/seq_id'
sorted_image_ids = False
max_num_elements_in_context_features = 10
sort_fn = add_context_to_examples.SortGroupedDataFn(
sequence_key, sorted_image_ids, max_num_elements_in_context_features)
output = sort_fn.process(
('dummy_key', [tf.train.Example.FromString(
self._create_second_tf_example()),
tf.train.Example.FromString(
self._create_first_tf_example())]))
self.assert_sorted(output[0][1])
def test_add_context_fn(self):
sequence_key = 'image/seq_id'
add_context_features = True
image_ids_to_keep = 'All'
context_fn = add_context_to_examples.GenerateContextFn(
sequence_key, add_context_features, image_ids_to_keep)
output = context_fn.process(
('dummy_key', [tf.train.Example.FromString(
self._create_first_tf_example()),
tf.train.Example.FromString(
self._create_second_tf_example())]))
self.assertEqual(len(output), 2)
self.assert_context(output)
def test_add_context_fn_output_sequence_example(self):
sequence_key = 'image/seq_id'
add_context_features = True
image_ids_to_keep = 'All'
context_fn = add_context_to_examples.GenerateContextFn(
sequence_key, add_context_features, image_ids_to_keep,
output_type='tf_sequence_example')
output = context_fn.process(
('01',
[tf.train.Example.FromString(self._create_first_tf_example()),
tf.train.Example.FromString(self._create_second_tf_example())]))
self.assertEqual(len(output), 1)
self.assert_expected_sequence_example(output)
def test_add_context_fn_output_sequence_example_cliplen(self):
sequence_key = 'image/seq_id'
add_context_features = True
image_ids_to_keep = 'All'
context_fn = add_context_to_examples.GenerateContextFn(
sequence_key, add_context_features, image_ids_to_keep,
output_type='tf_sequence_example', max_clip_length=1)
output = context_fn.process(
('01',
[tf.train.Example.FromString(self._create_first_tf_example()),
tf.train.Example.FromString(self._create_second_tf_example())]))
self.assertEqual(len(output), 2)
def test_beam_pipeline(self):
with InMemoryTFRecord(
[self._create_first_tf_example(),
self._create_second_tf_example()]) as input_tfrecord:
runner = runners.DirectRunner()
temp_dir = tempfile.mkdtemp(dir=os.environ.get('TEST_TMPDIR'))
output_tfrecord = os.path.join(temp_dir, 'output_tfrecord')
sequence_key = six.ensure_binary('image/seq_id')
max_num_elements = 10
num_shards = 1
pipeline = add_context_to_examples.construct_pipeline(
input_tfrecord,
output_tfrecord,
sequence_key,
max_num_elements_in_context_features=max_num_elements,
num_shards=num_shards)
runner.run(pipeline)
filenames = tf.io.gfile.glob(output_tfrecord + '-?????-of-?????')
actual_output = []
record_iterator = tf.python_io.tf_record_iterator(path=filenames[0])
for record in record_iterator:
actual_output.append(record)
self.assertEqual(len(actual_output), 2)
self.assert_expected_examples([tf.train.Example.FromString(
tf_example) for tf_example in actual_output])
def test_beam_pipeline_sequence_example(self):
with InMemoryTFRecord(
[self._create_first_tf_example(),
self._create_second_tf_example()]) as input_tfrecord:
runner = runners.DirectRunner()
temp_dir = tempfile.mkdtemp(dir=os.environ.get('TEST_TMPDIR'))
output_tfrecord = os.path.join(temp_dir, 'output_tfrecord')
sequence_key = six.ensure_binary('image/seq_id')
max_num_elements = 10
num_shards = 1
pipeline = add_context_to_examples.construct_pipeline(
input_tfrecord,
output_tfrecord,
sequence_key,
max_num_elements_in_context_features=max_num_elements,
num_shards=num_shards,
output_type='tf_sequence_example')
runner.run(pipeline)
filenames = tf.io.gfile.glob(output_tfrecord + '-?????-of-?????')
actual_output = []
record_iterator = tf.python_io.tf_record_iterator(
path=filenames[0])
for record in record_iterator:
actual_output.append(record)
self.assertEqual(len(actual_output), 1)
self.assert_expected_sequence_example(
[tf.train.SequenceExample.FromString(
tf_example) for tf_example in actual_output])
if __name__ == '__main__':
tf.test.main()