Commit 44e7092c authored by stephenwu

Merge branch 'master' of https://github.com/tensorflow/models into AXg

parents 431a9ca3 59434199
@@ -275,4 +275,11 @@ class VideoClassificationTask(base_task.Task):
       outputs = tf.math.sigmoid(outputs)
     else:
       outputs = tf.math.softmax(outputs)
+    num_test_clips = self.task_config.validation_data.num_test_clips
+    num_test_crops = self.task_config.validation_data.num_test_crops
+    num_test_views = num_test_clips * num_test_crops
+    if num_test_views > 1:
+      # Average output probabilities across multiple views.
+      outputs = tf.reshape(outputs, [-1, num_test_views, outputs.shape[-1]])
+      outputs = tf.reduce_mean(outputs, axis=1)
     return outputs
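
For readers scanning the hunk above: the reshape-and-average assumes the batch packs all test views of one video contiguously. A minimal standalone sketch of the same pattern, with invented shapes (two videos, 4 clips x 3 crops, 5 classes; not code from this commit):

import tensorflow as tf

num_test_views = 4 * 3  # clips x crops
outputs = tf.random.uniform([2 * num_test_views, 5])  # one probability row per view

# Group the rows belonging to the same video, then average across its views.
outputs = tf.reshape(outputs, [-1, num_test_views, outputs.shape[-1]])
outputs = tf.reduce_mean(outputs, axis=1)  # final shape: [2, 5]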
@@ -63,6 +63,8 @@ def main(_):
       params=params,
       model_dir=model_dir)
+  train_utils.save_gin_config(FLAGS.mode, model_dir)

 if __name__ == '__main__':
   tfm_flags.define_flags()
   app.run(main)
@@ -142,7 +142,10 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
                 epsilon=self.batchnorm_epsilon),
             tf.keras.layers.Activation(self.activation),
             tf.keras.layers.experimental.preprocessing.Resizing(
-                height, width, interpolation=self.interpolation)
+                height,
+                width,
+                interpolation=self.interpolation,
+                dtype=tf.float32)
         ]))
     self.aspp_layers.append(pool_sequential)
@@ -165,7 +168,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
       training = tf.keras.backend.learning_phase()
     result = []
     for layer in self.aspp_layers:
-      result.append(layer(inputs, training=training))
+      result.append(tf.cast(layer(inputs, training=training), inputs.dtype))
     result = tf.concat(result, axis=-1)
     result = self.projection(result, training=training)
     return result
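
Taken together, the two hunks above force the resize to run in float32 and cast each branch back to the input dtype before concatenation, which is what keeps the layer usable under mixed precision. A rough standalone equivalent, using tf.image.resize in place of the Resizing layer (illustrative only, not the commit's code):

import tensorflow as tf

inputs = tf.cast(tf.random.uniform([1, 8, 8, 4]), tf.float16)  # e.g. mixed precision

# Resize in float32, then cast back so the result concatenates cleanly
# with the other float16 branches.
resized = tf.image.resize(tf.cast(inputs, tf.float32), [16, 16], method='bilinear')
resized = tf.cast(resized, inputs.dtype)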
@@ -27,6 +27,7 @@ from __future__ import division
 from __future__ import print_function

 import functools
+import math

 import tensorflow.compat.v1 as tf

 from object_detection.builders import decoder_builder
@@ -52,6 +53,7 @@ def make_initializable_iterator(dataset):
 def _read_dataset_internal(file_read_func,
                            input_files,
+                           num_readers,
                            config,
                            filename_shard_fn=None):
   """Reads a dataset, and handles repetition and shuffling.
@@ -60,6 +62,7 @@ def _read_dataset_internal(file_read_func,
     file_read_func: Function to use in tf_data.parallel_interleave, to read
       every individual file into a tf.data.Dataset.
     input_files: A list of file paths to read.
+    num_readers: Number of readers to use.
     config: An input_reader_builder.InputReader object.
     filename_shard_fn: optional, A function used to shard filenames across
       replicas. This function takes as input a TF dataset of filenames and is
@@ -79,7 +82,6 @@ def _read_dataset_internal(file_read_func,
   if not filenames:
     raise RuntimeError('Did not find any input files matching the glob pattern '
                        '{}'.format(input_files))
-  num_readers = config.num_readers
   if num_readers > len(filenames):
     num_readers = len(filenames)
     tf.logging.warning('num_readers has been reduced to %d to match input file '
@@ -137,17 +139,30 @@ def read_dataset(file_read_func, input_files, config, filename_shard_fn=None):
     tf.logging.info('Sampling from datasets %s with weights %s' %
                     (input_files, config.sample_from_datasets_weights))
     records_datasets = []
-    for input_file in input_files:
+    dataset_weights = []
+    for i, input_file in enumerate(input_files):
+      weight = config.sample_from_datasets_weights[i]
+      num_readers = math.ceil(config.num_readers *
+                              weight /
+                              sum(config.sample_from_datasets_weights))
+      tf.logging.info(
+          'Num readers for dataset [%s]: %d', input_file, num_readers)
+      if num_readers == 0:
+        tf.logging.info('Skipping dataset due to zero weights: %s', input_file)
+        continue
       records_dataset = _read_dataset_internal(file_read_func, [input_file],
-                                               config, filename_shard_fn)
+                                               num_readers, config,
+                                               filename_shard_fn)
+      dataset_weights.append(weight)
       records_datasets.append(records_dataset)
-    dataset_weights = list(config.sample_from_datasets_weights)
     return tf.data.experimental.sample_from_datasets(records_datasets,
                                                      dataset_weights)
   else:
     tf.logging.info('Reading unweighted datasets: %s' % input_files)
-    return _read_dataset_internal(file_read_func, input_files, config,
-                                  filename_shard_fn)
+    return _read_dataset_internal(file_read_func, input_files,
+                                  config.num_readers, config, filename_shard_fn)

 def shard_function_for_context(input_context):
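
The allocation rule in the hunk above is easiest to check with concrete numbers. A hedged sketch (values invented): each dataset receives a ceil-rounded share of config.num_readers proportional to its sampling weight, so every nonzero-weight dataset keeps at least one reader; note the rounded shares can add up to slightly more than the configured total.

import math

total_readers = 8
weights = [0.5, 0.3, 0.2, 0.0]

# Same formula as the diff: proportional share of readers, rounded up.
shares = [math.ceil(total_readers * w / sum(weights)) for w in weights]
print(shares)  # [4, 3, 2, 0]; the zero-weight dataset is skipped entirely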
@@ -60,7 +60,9 @@ def build(input_reader_config):
         num_keypoints=input_reader_config.num_keypoints,
         expand_hierarchy_labels=input_reader_config.expand_labels_hierarchy,
         load_dense_pose=input_reader_config.load_dense_pose,
-        load_track_id=input_reader_config.load_track_id)
+        load_track_id=input_reader_config.load_track_id,
+        load_keypoint_depth_features=input_reader_config
+        .load_keypoint_depth_features)
     return decoder
   elif input_type == input_reader_pb2.InputType.Value('TF_SEQUENCE_EXAMPLE'):
     decoder = tf_sequence_example_decoder.TfSequenceExampleDecoder(
@@ -65,6 +65,8 @@ class DecoderBuilderTest(test_case.TestCase):
         'image/object/bbox/ymax': dataset_util.float_list_feature([1.0]),
         'image/object/class/label': dataset_util.int64_list_feature([2]),
         'image/object/mask': dataset_util.float_list_feature(flat_mask),
+        'image/object/keypoint/x': dataset_util.float_list_feature([1.0, 1.0]),
+        'image/object/keypoint/y': dataset_util.float_list_feature([1.0, 1.0])
     }
     if has_additional_channels:
       additional_channels_key = 'image/additional_channels/encoded'
@@ -188,6 +190,28 @@ class DecoderBuilderTest(test_case.TestCase):
     masks = self.execute_cpu(graph_fn, [])
     self.assertAllEqual((1, 4, 5), masks.shape)

+  def test_build_tf_record_input_reader_and_load_keypoint_depth(self):
+    input_reader_text_proto = """
+      load_keypoint_depth_features: true
+      num_keypoints: 2
+      tf_record_input_reader {}
+    """
+    input_reader_proto = input_reader_pb2.InputReader()
+    text_format.Parse(input_reader_text_proto, input_reader_proto)
+    decoder = decoder_builder.build(input_reader_proto)
+    serialized_example = self._make_serialized_tf_example()
+
+    def graph_fn():
+      tensor_dict = decoder.decode(serialized_example)
+      return (tensor_dict[fields.InputDataFields.groundtruth_keypoint_depths],
+              tensor_dict[
+                  fields.InputDataFields.groundtruth_keypoint_depth_weights])
+
+    (kpts_depths, kpts_depth_weights) = self.execute_cpu(graph_fn, [])
+    self.assertAllEqual((1, 2), kpts_depths.shape)
+    self.assertAllEqual((1, 2), kpts_depth_weights.shape)
+
 if __name__ == '__main__':
   tf.test.main()
@@ -20,7 +20,11 @@ import tf_slim as slim
 from object_detection.core import freezable_batch_norm
 from object_detection.protos import hyperparams_pb2
 from object_detection.utils import context_manager
+from object_detection.utils import tf_version
+
+# pylint: disable=g-import-not-at-top
+if tf_version.is_tf2():
+  from object_detection.core import freezable_sync_batch_norm
 # pylint: enable=g-import-not-at-top
@@ -60,9 +64,14 @@ class KerasLayerHyperparams(object):
           'hyperparams_pb.Hyperparams.')

     self._batch_norm_params = None
+    self._use_sync_batch_norm = False
     if hyperparams_config.HasField('batch_norm'):
       self._batch_norm_params = _build_keras_batch_norm_params(
           hyperparams_config.batch_norm)
+    elif hyperparams_config.HasField('sync_batch_norm'):
+      self._use_sync_batch_norm = True
+      self._batch_norm_params = _build_keras_batch_norm_params(
+          hyperparams_config.sync_batch_norm)

     self._force_use_bias = hyperparams_config.force_use_bias
     self._activation_fn = _build_activation_fn(hyperparams_config.activation)
@@ -133,10 +142,12 @@ class KerasLayerHyperparams(object):
         is False)
     """
     if self.use_batch_norm():
-      return freezable_batch_norm.FreezableBatchNorm(
-          training=training,
-          **self.batch_norm_params(**overrides)
-      )
+      if self._use_sync_batch_norm:
+        return freezable_sync_batch_norm.FreezableSyncBatchNorm(
+            training=training, **self.batch_norm_params(**overrides))
+      else:
+        return freezable_batch_norm.FreezableBatchNorm(
+            training=training, **self.batch_norm_params(**overrides))
     else:
       return tf.keras.layers.Lambda(tf.identity)
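
For orientation, a hedged sketch of how this path would be selected from a pipeline config; the sync_batch_norm message is assumed to mirror the batch_norm fields (decay, epsilon), and the regularizer/initializer entries are only there to make the proto valid:

from google.protobuf import text_format
from object_detection.protos import hyperparams_pb2

config = hyperparams_pb2.Hyperparams()
text_format.Parse("""
  regularizer { l2_regularizer { weight: 0.0004 } }
  initializer { truncated_normal_initializer {} }
  sync_batch_norm { decay: 0.997 epsilon: 0.001 }
""", config)

# With this config, KerasLayerHyperparams(config).build_batch_norm(...)
# should return a FreezableSyncBatchNorm rather than a FreezableBatchNorm.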
@@ -219,6 +230,10 @@ def build(hyperparams_config, is_training):
     raise ValueError('Hyperparams force_use_bias only supported by '
                      'KerasLayerHyperparams.')

+  if hyperparams_config.HasField('sync_batch_norm'):
+    raise ValueError('Hyperparams sync_batch_norm only supported by '
+                     'KerasLayerHyperparams.')
+
   normalizer_fn = None
   batch_norm_params = None
   if hyperparams_config.HasField('batch_norm'):
@@ -1039,7 +1039,10 @@ def _build_center_net_model(center_net_config, is_training, add_summaries):
   if center_net_config.HasField('temporal_offset_task'):
     temporal_offset_params = temporal_offset_proto_to_params(
         center_net_config.temporal_offset_task)
+  non_max_suppression_fn = None
+  if center_net_config.HasField('post_processing'):
+    non_max_suppression_fn, _ = post_processing_builder.build(
+        center_net_config.post_processing)

   return center_net_meta_arch.CenterNetMetaArch(
       is_training=is_training,
       add_summaries=add_summaries,

@@ -1054,7 +1057,8 @@ def _build_center_net_model(center_net_config, is_training, add_summaries):
       track_params=track_params,
       temporal_offset_params=temporal_offset_params,
       use_depthwise=center_net_config.use_depthwise,
-      compute_heatmap_sparse=center_net_config.compute_heatmap_sparse)
+      compute_heatmap_sparse=center_net_config.compute_heatmap_sparse,
+      non_max_suppression_fn=non_max_suppression_fn)

 def _build_center_net_feature_extractor(
@@ -17,25 +17,40 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

 import unittest
+from absl.testing import parameterized
 import numpy as np
 from six.moves import zip
-import tensorflow.compat.v1 as tf
+import tensorflow as tf

 from object_detection.core import freezable_batch_norm
 from object_detection.utils import tf_version
+
+# pylint: disable=g-import-not-at-top
+if tf_version.is_tf2():
+  from object_detection.core import freezable_sync_batch_norm
+# pylint: enable=g-import-not-at-top

 @unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
-class FreezableBatchNormTest(tf.test.TestCase):
+class FreezableBatchNormTest(tf.test.TestCase, parameterized.TestCase):
   """Tests for FreezableBatchNorm operations."""

-  def _build_model(self, training=None):
+  def _build_model(self, use_sync_batch_norm, training=None):
     model = tf.keras.models.Sequential()
+    norm = None
+    if use_sync_batch_norm:
+      norm = freezable_sync_batch_norm.FreezableSyncBatchNorm(training=training,
+                                                              input_shape=(10,),
+                                                              momentum=0.8)
+    else:
       norm = freezable_batch_norm.FreezableBatchNorm(training=training,
                                                      input_shape=(10,),
                                                      momentum=0.8)
     model.add(norm)
     return model, norm

@@ -43,8 +58,9 @@ class FreezableBatchNormTest(tf.test.TestCase):
     for source, target in zip(source_weights, target_weights):
       target.assign(source)

-  def _train_freezable_batch_norm(self, training_mean, training_var):
-    model, _ = self._build_model()
+  def _train_freezable_batch_norm(self, training_mean, training_var,
+                                  use_sync_batch_norm):
+    model, _ = self._build_model(use_sync_batch_norm=use_sync_batch_norm)
     model.compile(loss='mse', optimizer='sgd')

     # centered on training_mean, variance training_var

@@ -72,7 +88,8 @@ class FreezableBatchNormTest(tf.test.TestCase):
     np.testing.assert_allclose(out.numpy().mean(), 0.0, atol=1.5e-1)
     np.testing.assert_allclose(out.numpy().std(), 1.0, atol=1.5e-1)

-  def test_batchnorm_freezing_training_none(self):
+  @parameterized.parameters(True, False)
+  def test_batchnorm_freezing_training_none(self, use_sync_batch_norm):
     training_mean = 5.0
     training_var = 10.0

@@ -81,12 +98,13 @@ class FreezableBatchNormTest(tf.test.TestCase):
     # Initially train the batch norm, and save the weights
     trained_weights = self._train_freezable_batch_norm(training_mean,
-                                                       training_var)
+                                                       training_var,
+                                                       use_sync_batch_norm)

     # Load the batch norm weights, freezing training to True.
     # Apply the batch norm layer to testing data and ensure it is normalized
     # according to the batch statistics.
-    model, norm = self._build_model(training=True)
+    model, norm = self._build_model(use_sync_batch_norm, training=True)
     self._copy_weights(trained_weights, model.weights)

     # centered on testing_mean, variance testing_var

@@ -136,7 +154,8 @@ class FreezableBatchNormTest(tf.test.TestCase):
                                testing_mean, testing_var, training_arg,
                                training_mean, training_var)

-  def test_batchnorm_freezing_training_false(self):
+  @parameterized.parameters(True, False)
+  def test_batchnorm_freezing_training_false(self, use_sync_batch_norm):
     training_mean = 5.0
     training_var = 10.0

@@ -145,12 +164,13 @@ class FreezableBatchNormTest(tf.test.TestCase):
     # Initially train the batch norm, and save the weights
     trained_weights = self._train_freezable_batch_norm(training_mean,
-                                                       training_var)
+                                                       training_var,
+                                                       use_sync_batch_norm)

     # Load the batch norm back up, freezing training to False.
     # Apply the batch norm layer to testing data and ensure it is normalized
     # according to the training data's statistics.
-    model, norm = self._build_model(training=False)
+    model, norm = self._build_model(use_sync_batch_norm, training=False)
     self._copy_weights(trained_weights, model.weights)

     # centered on testing_mean, variance testing_var
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A freezable batch norm layer that uses Keras sync batch normalization."""
import tensorflow as tf


class FreezableSyncBatchNorm(
    tf.keras.layers.experimental.SyncBatchNormalization):
  """Sync Batch normalization layer (Ioffe and Szegedy, 2014).

  This is a `freezable` batch norm layer that supports setting the `training`
  parameter in the __init__ method rather than having to set it either via
  the Keras learning phase or via the `call` method parameter. This layer will
  forward all other parameters to the Keras `SyncBatchNormalization` layer.

  This class is necessary because Object Detection model training sometimes
  requires batch normalization layers to be `frozen` and used as if it were
  evaluation time, despite still training (and potentially using dropout
  layers).

  Like the default Keras SyncBatchNormalization layer, this will normalize the
  activations of the previous layer at each batch,
  i.e. applies a transformation that maintains the mean activation
  close to 0 and the activation standard deviation close to 1.

  Input shape:
    Arbitrary. Use the keyword argument `input_shape`
    (tuple of integers, does not include the samples axis)
    when using this layer as the first layer in a model.

  Output shape:
    Same shape as input.

  References:
    - [Batch Normalization: Accelerating Deep Network Training by Reducing
      Internal Covariate Shift](https://arxiv.org/abs/1502.03167)
  """

  def __init__(self, training=None, **kwargs):
    """Constructor.

    Args:
      training: If False, the layer will normalize using the moving average and
        std. dev, without updating the learned avg and std. dev.
        If None or True, the layer will follow the keras SyncBatchNormalization
        layer strategy of checking the Keras learning phase at `call` time to
        decide what to do.
      **kwargs: The keyword arguments to forward to the keras
        SyncBatchNormalization layer constructor.
    """
    super(FreezableSyncBatchNorm, self).__init__(**kwargs)
    self._training = training

  def call(self, inputs, training=None):
    # Override the call arg only if the batchnorm is frozen. (Ignore None.)
    if self._training is False:  # pylint: disable=g-bool-id-comparison
      training = self._training
    return super(FreezableSyncBatchNorm, self).call(inputs, training=training)
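
A small usage sketch for the new layer (assumes a TF2 eager context; not part of the commit): constructing it with training=False pins it to inference behavior even when the surrounding model runs a training step.

import tensorflow as tf

frozen_bn = FreezableSyncBatchNorm(training=False, momentum=0.8)
x = tf.random.uniform([4, 10])

# The call-time flag is overridden because the layer was frozen at
# construction: moving mean/variance are used and are not updated.
y = frozen_bn(x, training=True)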
@@ -315,7 +315,9 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
                           is_annotated_list=None,
                           groundtruth_labeled_classes=None,
                           groundtruth_verified_neg_classes=None,
-                          groundtruth_not_exhaustive_classes=None):
+                          groundtruth_not_exhaustive_classes=None,
+                          groundtruth_keypoint_depths_list=None,
+                          groundtruth_keypoint_depth_weights_list=None):
     """Provide groundtruth tensors.

     Args:

@@ -379,6 +381,11 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
       groundtruth_not_exhaustive_classes: A list of 1-D tf.float32 tensors of
         shape [num_classes], containing a K-hot representation of classes
         which don't have all of their instances marked exhaustively.
+      groundtruth_keypoint_depths_list: a list of 2-D tf.float32 tensors of
+        shape [num_boxes, num_keypoints] containing keypoint relative depths.
+      groundtruth_keypoint_depth_weights_list: a list of 2-D tf.float32
+        tensors of shape [num_boxes, num_keypoints] containing the weights of
+        the relative depths.
     """
     self._groundtruth_lists[fields.BoxListFields.boxes] = groundtruth_boxes_list
     self._groundtruth_lists[

@@ -399,6 +406,14 @@ class DetectionModel(six.with_metaclass(abc.ABCMeta, _BaseClass)):
     self._groundtruth_lists[
         fields.BoxListFields.keypoint_visibilities] = (
             groundtruth_keypoint_visibilities_list)
+    if groundtruth_keypoint_depths_list:
+      self._groundtruth_lists[
+          fields.BoxListFields.keypoint_depths] = (
+              groundtruth_keypoint_depths_list)
+    if groundtruth_keypoint_depth_weights_list:
+      self._groundtruth_lists[
+          fields.BoxListFields.keypoint_depth_weights] = (
+              groundtruth_keypoint_depth_weights_list)
     if groundtruth_dp_num_points_list:
       self._groundtruth_lists[
           fields.BoxListFields.densepose_num_points] = (
@@ -26,6 +26,7 @@ import tensorflow.compat.v1 as tf
 from object_detection.core import box_list
 from object_detection.core import box_list_ops
+from object_detection.core import keypoint_ops
 from object_detection.core import standard_fields as fields
 from object_detection.utils import shape_utils

@@ -379,6 +380,11 @@ def _clip_window_prune_boxes(sorted_boxes, clip_window, pad_to_max_output_size,
   if change_coordinate_frame:
     sorted_boxes = box_list_ops.change_coordinate_frame(sorted_boxes,
                                                         clip_window)
+    if sorted_boxes.has_field(fields.BoxListFields.keypoints):
+      sorted_keypoints = sorted_boxes.get_field(fields.BoxListFields.keypoints)
+      sorted_keypoints = keypoint_ops.change_coordinate_frame(sorted_keypoints,
+                                                              clip_window)
+      sorted_boxes.set_field(fields.BoxListFields.keypoints, sorted_keypoints)
   return sorted_boxes, num_valid_nms_boxes_cumulative
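
For context, keypoint_ops.change_coordinate_frame re-expresses normalized keypoints relative to the clip window, mirroring what box_list_ops.change_coordinate_frame does for the boxes one line earlier. A hedged sketch of the underlying arithmetic (illustrative, not the library implementation):

import tensorflow as tf

clip_window = tf.constant([0.25, 0.25, 0.75, 0.75])  # [y_min, x_min, y_max, x_max]
keypoints = tf.constant([[[0.5, 0.5]]])  # [num_instances, num_keypoints, 2]

win_height = clip_window[2] - clip_window[0]
win_width = clip_window[3] - clip_window[1]
new_keypoints = (keypoints - [clip_window[0], clip_window[1]]) / [win_height,
                                                                  win_width]
# The window center maps to the new frame's center: [[[0.5, 0.5]]].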
@@ -571,6 +571,8 @@ def random_horizontal_flip(image,
                            keypoint_visibilities=None,
                            densepose_part_ids=None,
                            densepose_surface_coords=None,
+                           keypoint_depths=None,
+                           keypoint_depth_weights=None,
                            keypoint_flip_permutation=None,
                            probability=0.5,
                            seed=None,

@@ -602,6 +604,12 @@ def random_horizontal_flip(image,
                              (y, x) are the normalized image coordinates for a
                              sampled point, and (v, u) is the surface
                              coordinate for the part.
+    keypoint_depths: (optional) rank 2 float32 tensor with shape
+                     [num_instances, num_keypoints] representing the relative
+                     depth of the keypoints.
+    keypoint_depth_weights: (optional) rank 2 float32 tensor with shape
+                            [num_instances, num_keypoints] representing the
+                            weights of the relative depth of the keypoints.
     keypoint_flip_permutation: rank 1 int32 tensor containing the keypoint flip
                                permutation.
     probability: the probability of performing this augmentation.

@@ -631,6 +639,10 @@ def random_horizontal_flip(image,
                    [num_instances, num_points].
     densepose_surface_coords: rank 3 float32 tensor with shape
                               [num_instances, num_points, 4].
+    keypoint_depths: rank 2 float32 tensor with shape
+                     [num_instances, num_keypoints].
+    keypoint_depth_weights: rank 2 float32 tensor with shape
+                            [num_instances, num_keypoints].

   Raises:
     ValueError: if keypoints are provided but keypoint_flip_permutation is not.

@@ -708,6 +720,21 @@ def random_horizontal_flip(image,
           lambda: (densepose_part_ids, densepose_surface_coords))
       result.extend(densepose_tensors)

+    # Flip keypoint depths and weights.
+    if (keypoint_depths is not None and
+        keypoint_flip_permutation is not None):
+      kpt_flip_perm = keypoint_flip_permutation
+      keypoint_depths = tf.cond(
+          do_a_flip_random,
+          lambda: tf.gather(keypoint_depths, kpt_flip_perm, axis=1),
+          lambda: keypoint_depths)
+      keypoint_depth_weights = tf.cond(
+          do_a_flip_random,
+          lambda: tf.gather(keypoint_depth_weights, kpt_flip_perm, axis=1),
+          lambda: keypoint_depth_weights)
+      result.append(keypoint_depths)
+      result.append(keypoint_depth_weights)
+
     return tuple(result)
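
The gather above mirrors what the flip already does to the keypoint coordinates: mirroring the image swaps left/right parts, so the depth (and weight) columns must be reordered with the same permutation. A minimal sketch with invented values:

import tensorflow as tf

keypoint_depths = tf.constant([[1.0, 0.9, 0.8]])  # [num_instances, num_keypoints]
flip_perm = tf.constant([0, 2, 1])  # keypoint 0 is its own mirror image

flipped = tf.gather(keypoint_depths, flip_perm, axis=1)
# flipped == [[1.0, 0.8, 0.9]], matching the expectations in the
# preprocessor test further down.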
@@ -4293,7 +4320,8 @@ def get_default_func_arg_map(include_label_weights=True,
                              include_instance_masks=False,
                              include_keypoints=False,
                              include_keypoint_visibilities=False,
-                             include_dense_pose=False):
+                             include_dense_pose=False,
+                             include_keypoint_depths=False):
   """Returns the default mapping from a preprocessor function to its args.

   Args:

@@ -4311,6 +4339,8 @@ def get_default_func_arg_map(include_label_weights=True,
       the keypoint visibilities, too.
     include_dense_pose: If True, preprocessing functions will modify the
       DensePose labels, too.
+    include_keypoint_depths: If True, preprocessing functions will modify the
+      keypoint depth labels, too.

   Returns:
     A map from preprocessing functions to the arguments they receive.

@@ -4353,6 +4383,13 @@ def get_default_func_arg_map(include_label_weights=True,
         fields.InputDataFields.groundtruth_dp_part_ids)
     groundtruth_dp_surface_coords = (
         fields.InputDataFields.groundtruth_dp_surface_coords)
+  groundtruth_keypoint_depths = None
+  groundtruth_keypoint_depth_weights = None
+  if include_keypoint_depths:
+    groundtruth_keypoint_depths = (
+        fields.InputDataFields.groundtruth_keypoint_depths)
+    groundtruth_keypoint_depth_weights = (
+        fields.InputDataFields.groundtruth_keypoint_depth_weights)

   prep_func_arg_map = {
       normalize_image: (fields.InputDataFields.image,),

@@ -4364,6 +4401,8 @@ def get_default_func_arg_map(include_label_weights=True,
           groundtruth_keypoint_visibilities,
           groundtruth_dp_part_ids,
           groundtruth_dp_surface_coords,
+          groundtruth_keypoint_depths,
+          groundtruth_keypoint_depth_weights,
       ),
       random_vertical_flip: (
           fields.InputDataFields.image,
@@ -105,6 +105,17 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
     ])
     return keypoints, keypoint_visibilities

+  def createTestKeypointDepths(self):
+    keypoint_depths = tf.constant([
+        [1.0, 0.9, 0.8],
+        [0.7, 0.6, 0.5]
+    ], dtype=tf.float32)
+    keypoint_depth_weights = tf.constant([
+        [0.5, 0.6, 0.7],
+        [0.8, 0.9, 1.0]
+    ], dtype=tf.float32)
+    return keypoint_depths, keypoint_depth_weights
+
   def createTestKeypointsInsideCrop(self):
     keypoints = np.array([
         [[0.4, 0.4], [0.5, 0.5], [0.6, 0.6]],

@@ -713,6 +724,59 @@ class PreprocessorTest(test_case.TestCase, parameterized.TestCase):
                           test_keypoints=True)

+  def testRunRandomHorizontalFlipWithKeypointDepth(self):
+
+    def graph_fn():
+      image_height = 3
+      image_width = 3
+      images = tf.random_uniform([1, image_height, image_width, 3])
+      boxes = self.createTestBoxes()
+      masks = self.createTestMasks()
+      keypoints, keypoint_visibilities = self.createTestKeypoints()
+      keypoint_depths, keypoint_depth_weights = self.createTestKeypointDepths()
+      keypoint_flip_permutation = self.createKeypointFlipPermutation()
+      tensor_dict = {
+          fields.InputDataFields.image:
+              images,
+          fields.InputDataFields.groundtruth_boxes:
+              boxes,
+          fields.InputDataFields.groundtruth_instance_masks:
+              masks,
+          fields.InputDataFields.groundtruth_keypoints:
+              keypoints,
+          fields.InputDataFields.groundtruth_keypoint_visibilities:
+              keypoint_visibilities,
+          fields.InputDataFields.groundtruth_keypoint_depths:
+              keypoint_depths,
+          fields.InputDataFields.groundtruth_keypoint_depth_weights:
+              keypoint_depth_weights,
+      }
+      preprocess_options = [(preprocessor.random_horizontal_flip, {
+          'keypoint_flip_permutation': keypoint_flip_permutation,
+          'probability': 1.0
+      })]
+      preprocessor_arg_map = preprocessor.get_default_func_arg_map(
+          include_instance_masks=True,
+          include_keypoints=True,
+          include_keypoint_visibilities=True,
+          include_dense_pose=False,
+          include_keypoint_depths=True)
+      tensor_dict = preprocessor.preprocess(
+          tensor_dict, preprocess_options, func_arg_map=preprocessor_arg_map)
+      keypoint_depths = tensor_dict[
+          fields.InputDataFields.groundtruth_keypoint_depths]
+      keypoint_depth_weights = tensor_dict[
+          fields.InputDataFields.groundtruth_keypoint_depth_weights]
+      output_tensors = [keypoint_depths, keypoint_depth_weights]
+      return output_tensors
+
+    output_tensors = self.execute_cpu(graph_fn, [])
+    expected_keypoint_depths = [[1.0, 0.8, 0.9], [0.7, 0.5, 0.6]]
+    expected_keypoint_depth_weights = [[0.5, 0.7, 0.6], [0.8, 1.0, 0.9]]
+    self.assertAllClose(expected_keypoint_depths, output_tensors[0])
+    self.assertAllClose(expected_keypoint_depth_weights, output_tensors[1])
+
   def testRandomVerticalFlip(self):

     def graph_fn():
@@ -67,6 +67,9 @@ class InputDataFields(object):
     groundtruth_instance_boundaries: ground truth instance boundaries.
     groundtruth_instance_classes: instance mask-level class labels.
     groundtruth_keypoints: ground truth keypoints.
+    groundtruth_keypoint_depths: relative depth of the keypoints.
+    groundtruth_keypoint_depth_weights: weights of the relative depth of the
+      keypoints.
     groundtruth_keypoint_visibilities: ground truth keypoint visibilities.
     groundtruth_keypoint_weights: groundtruth weight factor for keypoints.
     groundtruth_label_weights: groundtruth label weights.

@@ -122,6 +125,8 @@ class InputDataFields(object):
   groundtruth_instance_boundaries = 'groundtruth_instance_boundaries'
   groundtruth_instance_classes = 'groundtruth_instance_classes'
   groundtruth_keypoints = 'groundtruth_keypoints'
+  groundtruth_keypoint_depths = 'groundtruth_keypoint_depths'
+  groundtruth_keypoint_depth_weights = 'groundtruth_keypoint_depth_weights'
   groundtruth_keypoint_visibilities = 'groundtruth_keypoint_visibilities'
   groundtruth_keypoint_weights = 'groundtruth_keypoint_weights'
   groundtruth_label_weights = 'groundtruth_label_weights'

@@ -162,6 +167,7 @@ class DetectionResultFields(object):
     detection_boundaries: contains an object boundary for each detection box.
     detection_keypoints: contains detection keypoints for each detection box.
     detection_keypoint_scores: contains detection keypoint scores.
+    detection_keypoint_depths: contains detection keypoint depths.
     num_detections: number of detections in the batch.
     raw_detection_boxes: contains decoded detection boxes without Non-Max
       suppression.

@@ -183,6 +189,7 @@ class DetectionResultFields(object):
   detection_boundaries = 'detection_boundaries'
   detection_keypoints = 'detection_keypoints'
   detection_keypoint_scores = 'detection_keypoint_scores'
+  detection_keypoint_depths = 'detection_keypoint_depths'
   detection_embeddings = 'detection_embeddings'
   detection_offsets = 'detection_temporal_offsets'
   num_detections = 'num_detections'

@@ -205,6 +212,8 @@ class BoxListFields(object):
     keypoints: keypoints per bounding box.
     keypoint_visibilities: keypoint visibilities per bounding box.
     keypoint_heatmaps: keypoint heatmaps per bounding box.
+    keypoint_depths: keypoint depths per bounding box.
+    keypoint_depth_weights: keypoint depth weights per bounding box.
     densepose_num_points: number of DensePose points per bounding box.
     densepose_part_ids: DensePose part ids per bounding box.
     densepose_surface_coords: DensePose surface coordinates per bounding box.

@@ -223,6 +232,8 @@ class BoxListFields(object):
   keypoints = 'keypoints'
   keypoint_visibilities = 'keypoint_visibilities'
   keypoint_heatmaps = 'keypoint_heatmaps'
+  keypoint_depths = 'keypoint_depths'
+  keypoint_depth_weights = 'keypoint_depth_weights'
   densepose_num_points = 'densepose_num_points'
   densepose_part_ids = 'densepose_part_ids'
   densepose_surface_coords = 'densepose_surface_coords'
@@ -139,7 +139,8 @@ class TfExampleDecoder(data_decoder.DataDecoder):
                load_context_features=False,
                expand_hierarchy_labels=False,
                load_dense_pose=False,
-               load_track_id=False):
+               load_track_id=False,
+               load_keypoint_depth_features=False):
     """Constructor sets keys_to_features and items_to_handlers.

     Args:

@@ -172,6 +173,10 @@ class TfExampleDecoder(data_decoder.DataDecoder):
         the labels are expanded to descendants.
       load_dense_pose: Whether to load DensePose annotations.
       load_track_id: Whether to load tracking annotations.
+      load_keypoint_depth_features: Whether to load the keypoint depth features
+        including keypoint relative depths and weights. If this field is set to
+        True but no keypoint depth features are in the input tf.Example, then
+        default values will be populated.

     Raises:
       ValueError: If `instance_mask_type` option is not one of

@@ -180,6 +185,7 @@ class TfExampleDecoder(data_decoder.DataDecoder):
       ValueError: If `expand_labels_hierarchy` is True, but the
         `label_map_proto_file` is not provided.
     """
     # TODO(rathodv): delete unused `use_display_name` argument once we change
     # other decoders to handle label maps similarly.
     del use_display_name
@@ -331,6 +337,23 @@ class TfExampleDecoder(data_decoder.DataDecoder):
         slim_example_decoder.ItemHandlerCallback(
             ['image/object/keypoint/x', 'image/object/keypoint/visibility'],
             self._reshape_keypoint_visibilities))
+    if load_keypoint_depth_features:
+      self.keys_to_features['image/object/keypoint/z'] = (
+          tf.VarLenFeature(tf.float32))
+      self.keys_to_features['image/object/keypoint/z/weights'] = (
+          tf.VarLenFeature(tf.float32))
+      self.items_to_handlers[
+          fields.InputDataFields.groundtruth_keypoint_depths] = (
+              slim_example_decoder.ItemHandlerCallback(
+                  ['image/object/keypoint/x', 'image/object/keypoint/z'],
+                  self._reshape_keypoint_depths))
+      self.items_to_handlers[
+          fields.InputDataFields.groundtruth_keypoint_depth_weights] = (
+              slim_example_decoder.ItemHandlerCallback(
+                  ['image/object/keypoint/x',
+                   'image/object/keypoint/z/weights'],
+                  self._reshape_keypoint_depth_weights))
     if load_instance_masks:
       if instance_mask_type in (input_reader_pb2.DEFAULT,
                                 input_reader_pb2.NUMERICAL_MASKS):
@@ -601,6 +624,73 @@ class TfExampleDecoder(data_decoder.DataDecoder):
     keypoints = tf.reshape(keypoints, [-1, self._num_keypoints, 2])
     return keypoints

+  def _reshape_keypoint_depths(self, keys_to_tensors):
+    """Reshape keypoint depths.
+
+    The keypoint depths are reshaped to [num_instances, num_keypoints]. The
+    keypoint depth tensor is expected to have the same shape as the keypoint x
+    (or y) tensors. If not (usually because the example does not have the
+    depth groundtruth), then default depth values (zero) are provided.
+
+    Args:
+      keys_to_tensors: a dictionary from keys to tensors. Expected keys are:
+        'image/object/keypoint/x'
+        'image/object/keypoint/z'
+
+    Returns:
+      A 2-D float tensor of shape [num_instances, num_keypoints] with values
+      representing the keypoint depths.
+    """
+    x = keys_to_tensors['image/object/keypoint/x']
+    z = keys_to_tensors['image/object/keypoint/z']
+    if isinstance(z, tf.SparseTensor):
+      z = tf.sparse_tensor_to_dense(z)
+    if isinstance(x, tf.SparseTensor):
+      x = tf.sparse_tensor_to_dense(x)
+
+    default_z = tf.zeros_like(x)
+    # Use keypoint depth groundtruth if provided, otherwise use the default
+    # depth value.
+    z = tf.cond(tf.equal(tf.size(x), tf.size(z)),
+                true_fn=lambda: z,
+                false_fn=lambda: default_z)
+    z = tf.reshape(z, [-1, self._num_keypoints])
+    return z
+
+  def _reshape_keypoint_depth_weights(self, keys_to_tensors):
+    """Reshape keypoint depth weights.
+
+    The keypoint depth weights are reshaped to [num_instances, num_keypoints].
+    The keypoint depth weights tensor is expected to have the same shape as
+    the keypoint x (or y) tensors. If not (usually because the example does
+    not have the depth weights groundtruth), then default weight values (zero)
+    are provided.
+
+    Args:
+      keys_to_tensors: a dictionary from keys to tensors. Expected keys are:
+        'image/object/keypoint/x'
+        'image/object/keypoint/z/weights'
+
+    Returns:
+      A 2-D float tensor of shape [num_instances, num_keypoints] with values
+      representing the keypoint depth weights.
+    """
+    x = keys_to_tensors['image/object/keypoint/x']
+    z = keys_to_tensors['image/object/keypoint/z/weights']
+    if isinstance(z, tf.SparseTensor):
+      z = tf.sparse_tensor_to_dense(z)
+    if isinstance(x, tf.SparseTensor):
+      x = tf.sparse_tensor_to_dense(x)
+
+    default_z = tf.zeros_like(x)
+    # Use keypoint depth weights if provided, otherwise use the default
+    # values.
+    z = tf.cond(tf.equal(tf.size(x), tf.size(z)),
+                true_fn=lambda: z,
+                false_fn=lambda: default_z)
+    z = tf.reshape(z, [-1, self._num_keypoints])
+    return z
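
Both helpers rely on the same size guard: when the example carries no depth (or weight) values, the VarLen feature decodes to an empty tensor, so zeros shaped like the x coordinates are substituted before the reshape. A standalone sketch of that guard (invented values, 2 instances x 3 keypoints):

import tensorflow.compat.v1 as tf

x = tf.constant([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
z = tf.constant([], dtype=tf.float32)  # no depth groundtruth in the example

z = tf.cond(tf.equal(tf.size(x), tf.size(z)),
            true_fn=lambda: z,
            false_fn=lambda: tf.zeros_like(x))
z = tf.reshape(z, [-1, 3])  # [[0., 0., 0.], [0., 0., 0.]]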
   def _reshape_keypoint_visibilities(self, keys_to_tensors):
     """Reshape keypoint visibilities.
@@ -275,6 +275,124 @@ class TfExampleDecoderTest(test_case.TestCase):
     self.assertAllEqual(expected_boxes,
                         tensor_dict[fields.InputDataFields.groundtruth_boxes])

+  def testDecodeKeypointDepth(self):
+    image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
+    encoded_jpeg, _ = self._create_encoded_and_decoded_data(
+        image_tensor, 'jpeg')
+    bbox_ymins = [0.0, 4.0]
+    bbox_xmins = [1.0, 5.0]
+    bbox_ymaxs = [2.0, 6.0]
+    bbox_xmaxs = [3.0, 7.0]
+    keypoint_ys = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
+    keypoint_xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
+    keypoint_visibility = [1, 2, 0, 1, 0, 2]
+    keypoint_depths = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
+    keypoint_depth_weights = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5]
+
+    def graph_fn():
+      example = tf.train.Example(
+          features=tf.train.Features(
+              feature={
+                  'image/encoded':
+                      dataset_util.bytes_feature(encoded_jpeg),
+                  'image/format':
+                      dataset_util.bytes_feature(six.b('jpeg')),
+                  'image/object/bbox/ymin':
+                      dataset_util.float_list_feature(bbox_ymins),
+                  'image/object/bbox/xmin':
+                      dataset_util.float_list_feature(bbox_xmins),
+                  'image/object/bbox/ymax':
+                      dataset_util.float_list_feature(bbox_ymaxs),
+                  'image/object/bbox/xmax':
+                      dataset_util.float_list_feature(bbox_xmaxs),
+                  'image/object/keypoint/y':
+                      dataset_util.float_list_feature(keypoint_ys),
+                  'image/object/keypoint/x':
+                      dataset_util.float_list_feature(keypoint_xs),
+                  'image/object/keypoint/z':
+                      dataset_util.float_list_feature(keypoint_depths),
+                  'image/object/keypoint/z/weights':
+                      dataset_util.float_list_feature(keypoint_depth_weights),
+                  'image/object/keypoint/visibility':
+                      dataset_util.int64_list_feature(keypoint_visibility),
+              })).SerializeToString()
+
+      example_decoder = tf_example_decoder.TfExampleDecoder(
+          num_keypoints=3, load_keypoint_depth_features=True)
+      output = example_decoder.decode(tf.convert_to_tensor(example))
+      self.assertAllEqual(
+          (output[fields.InputDataFields.groundtruth_keypoint_depths]
+           .get_shape().as_list()), [2, 3])
+      self.assertAllEqual(
+          (output[fields.InputDataFields.groundtruth_keypoint_depth_weights]
+           .get_shape().as_list()), [2, 3])
+      return output
+
+    tensor_dict = self.execute_cpu(graph_fn, [])
+
+    expected_keypoint_depths = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
+    self.assertAllClose(
+        expected_keypoint_depths,
+        tensor_dict[fields.InputDataFields.groundtruth_keypoint_depths])
+
+    expected_keypoint_depth_weights = [[1.0, 0.9, 0.8], [0.7, 0.6, 0.5]]
+    self.assertAllClose(
+        expected_keypoint_depth_weights,
+        tensor_dict[fields.InputDataFields.groundtruth_keypoint_depth_weights])
+
+  def testDecodeKeypointDepthNoDepth(self):
+    image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
+    encoded_jpeg, _ = self._create_encoded_and_decoded_data(
+        image_tensor, 'jpeg')
+    bbox_ymins = [0.0, 4.0]
+    bbox_xmins = [1.0, 5.0]
+    bbox_ymaxs = [2.0, 6.0]
+    bbox_xmaxs = [3.0, 7.0]
+    keypoint_ys = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
+    keypoint_xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
+    keypoint_visibility = [1, 2, 0, 1, 0, 2]
+
+    def graph_fn():
+      example = tf.train.Example(
+          features=tf.train.Features(
+              feature={
+                  'image/encoded':
+                      dataset_util.bytes_feature(encoded_jpeg),
+                  'image/format':
+                      dataset_util.bytes_feature(six.b('jpeg')),
+                  'image/object/bbox/ymin':
+                      dataset_util.float_list_feature(bbox_ymins),
+                  'image/object/bbox/xmin':
+                      dataset_util.float_list_feature(bbox_xmins),
+                  'image/object/bbox/ymax':
+                      dataset_util.float_list_feature(bbox_ymaxs),
+                  'image/object/bbox/xmax':
+                      dataset_util.float_list_feature(bbox_xmaxs),
+                  'image/object/keypoint/y':
+                      dataset_util.float_list_feature(keypoint_ys),
+                  'image/object/keypoint/x':
+                      dataset_util.float_list_feature(keypoint_xs),
+                  'image/object/keypoint/visibility':
+                      dataset_util.int64_list_feature(keypoint_visibility),
+              })).SerializeToString()
+
+      example_decoder = tf_example_decoder.TfExampleDecoder(
+          num_keypoints=3, load_keypoint_depth_features=True)
+      output = example_decoder.decode(tf.convert_to_tensor(example))
+      return output
+
+    tensor_dict = self.execute_cpu(graph_fn, [])
+
+    expected_keypoints_depth_default = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
+    self.assertAllClose(
+        expected_keypoints_depth_default,
+        tensor_dict[fields.InputDataFields.groundtruth_keypoint_depths])
+    self.assertAllClose(
+        expected_keypoints_depth_default,
+        tensor_dict[fields.InputDataFields.groundtruth_keypoint_depth_weights])
+
   def testDecodeKeypoint(self):
     image_tensor = np.random.randint(256, size=(4, 5, 3)).astype(np.uint8)
     encoded_jpeg, _ = self._create_encoded_and_decoded_data(
@@ -56,8 +56,7 @@ class FakeModel(model.DetectionModel):
             value=conv_weight_scalar))

   def preprocess(self, inputs):
-    true_image_shapes = []  # Doesn't matter for the fake model.
-    return tf.identity(inputs), true_image_shapes
+    return tf.identity(inputs), exporter_lib_v2.get_true_shapes(inputs)

   def predict(self, preprocessed_inputs, true_image_shapes):
     return {'image': self._conv(preprocessed_inputs)}

@@ -54,8 +54,7 @@ class FakeModel(model.DetectionModel):
             value=conv_weight_scalar))

   def preprocess(self, inputs):
-    true_image_shapes = []  # Doesn't matter for the fake model.
-    return tf.identity(inputs), true_image_shapes
+    return tf.identity(inputs), exporter_lib_v2.get_true_shapes(inputs)

   def predict(self, preprocessed_inputs, true_image_shapes):
     return {'image': self._conv(preprocessed_inputs)}

@@ -51,8 +51,7 @@ class FakeModel(model.DetectionModel):
             value=conv_weight_scalar))

   def preprocess(self, inputs):
-    true_image_shapes = []  # Doesn't matter for the fake model.
-    return tf.identity(inputs), true_image_shapes
+    return tf.identity(inputs), exporter_lib_v2.get_true_shapes(inputs)

   def predict(self, preprocessed_inputs, true_image_shapes, **side_inputs):
     return_dict = {'image': self._conv(preprocessed_inputs)}