Commit 2e9bb539 authored by stephenwu

Merge branch 'master' of https://github.com/tensorflow/models into RTESuperGLUE

parents 7bae5317 8fba84f8
@@ -1,12 +1,11 @@
-# Lint as: python3
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2021 The TensorFlow Authors All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -13,41 +12,41 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for the ResNet backbone."""
+"""Tests for pooling layers."""
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
 import tensorflow as tf
 
-from delf.python.training.model import resnet50
+from delf.python.pooling_layers import pooling
 
 
-class Resnet50Test(tf.test.TestCase):
+class PoolingsTest(tf.test.TestCase):
 
-  def test_gem_pooling_works(self):
-    # Input feature map: Batch size = 2, height = 1, width = 2, depth = 2.
-    feature_map = tf.constant([[[[.0, 2.0], [1.0, -1.0]]],
-                               [[[1.0, 100.0], [1.0, .0]]]],
-                              dtype=tf.float32)
-    power = 2.0
-    threshold = .0
-
+  def testMac(self):
+    x = tf.constant([[[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]]])
     # Run tested function.
-    pooled_feature_map = resnet50.gem_pooling(feature_map=feature_map,
-                                              axis=[1, 2],
-                                              power=power,
-                                              threshold=threshold)
+    result = pooling.mac(x)
     # Define expected result.
-    expected_pooled_feature_map = np.array([[0.707107, 1.414214],
-                                            [1.0, 70.710678]],
-                                           dtype=float)
+    exp_output = [[6., 7.]]
+    # Compare actual and expected.
+    self.assertAllClose(exp_output, result)
+
+  def testSpoc(self):
+    x = tf.constant([[[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]]])
+    # Run tested function.
+    result = pooling.spoc(x)
+    # Define expected result.
+    exp_output = [[3., 4.]]
+    # Compare actual and expected.
+    self.assertAllClose(exp_output, result)
+
+  def testGem(self):
+    x = tf.constant([[[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]]])
+    # Run tested function.
+    result = pooling.gem(x, power=3., eps=1e-6)
+    # Define expected result.
+    exp_output = [[4.1601677, 4.9866314]]
     # Compare actual and expected.
-    self.assertAllClose(pooled_feature_map, expected_pooled_feature_map)
+    self.assertAllClose(exp_output, result)
 
 if __name__ == '__main__':
...
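The pooling module under test is not shown in this commit view. As a hedged sketch only, implementations consistent with the expected outputs above would look like the following (the actual delf module may differ in signatures and validation):

import tensorflow as tf

def mac(x, axis=None):
  """Maximum activations of convolutions: spatial max per channel."""
  return tf.reduce_max(x, axis=[1, 2] if axis is None else axis)

def spoc(x, axis=None):
  """Sum-pooling of convolutions (mean variant): spatial average per channel."""
  return tf.reduce_mean(x, axis=[1, 2] if axis is None else axis)

def gem(x, axis=None, power=3., eps=1e-6):
  """Generalized mean pooling: clamp, exponentiate, average, take the root."""
  if axis is None:
    axis = [1, 2]
  return tf.pow(
      tf.reduce_mean(tf.pow(tf.maximum(x, eps), power), axis=axis),
      1. / power)

For testGem, channel 0 holds [0, 2, 4, 6]; clamping to eps, cubing and averaging gives 288 / 4 = 72, and 72^(1/3) ≈ 4.1601677, matching exp_output.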
# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
\ No newline at end of file
# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ranking loss definitions."""
import tensorflow as tf
class ContrastiveLoss(tf.keras.losses.Loss):
"""Contrastive Loss layer.
Contrastive Loss layer allows to compute contrastive loss for a batch of
images. Implementation based on: https://arxiv.org/abs/1604.02426.
"""
def __init__(self, margin=0.7, reduction=tf.keras.losses.Reduction.NONE):
"""Initialization of Contrastive Loss layer.
Args:
margin: Float contrastive loss margin.
reduction: Type of loss reduction.
"""
super(ContrastiveLoss, self).__init__(reduction)
self.margin = margin
# Parameter for numerical stability.
self.eps = 1e-6
def __call__(self, queries, positives, negatives):
"""Invokes the Contrastive Loss instance.
Args:
queries: [batch_size, dim] Anchor input tensor.
positives: [batch_size, dim] Positive sample input tensor.
negatives: [batch_size, num_neg, dim] Negative sample input tensor.
Returns:
loss: Scalar tensor.
"""
return contrastive_loss(
queries, positives, negatives, margin=self.margin, eps=self.eps)
class TripletLoss(tf.keras.losses.Loss):
"""Triplet Loss layer.
Triplet Loss layer computes triplet loss for a batch of images. Triplet
loss tries to keep all queries closer to positives than to any negatives.
Margin is used to specify when a triplet has become too "easy" and we no
longer want to adjust the weights from it. Differently from the Contrastive
Loss, Triplet Loss uses squared distances when computing the loss.
Implementation based on: https://arxiv.org/abs/1511.07247.
"""
def __init__(self, margin=0.1, reduction=tf.keras.losses.Reduction.NONE):
"""Initialization of Triplet Loss layer.
Args:
margin: Triplet loss margin.
reduction: Type of loss reduction.
"""
super(TripletLoss, self).__init__(reduction)
self.margin = margin
def __call__(self, queries, positives, negatives):
"""Invokes the Triplet Loss instance.
Args:
queries: [batch_size, dim] Anchor input tensor.
positives: [batch_size, dim] Positive sample input tensor.
negatives: [batch_size, num_neg, dim] Negative sample input tensor.
Returns:
loss: Scalar tensor.
"""
return triplet_loss(queries, positives, negatives, margin=self.margin)
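For orientation, a minimal usage sketch of the two wrapper classes above (batch size, negative count and embedding width are arbitrary; inputs are unit-normalized as the functions below expect):

queries = tf.math.l2_normalize(tf.random.normal([4, 128]), axis=1)
positives = tf.math.l2_normalize(tf.random.normal([4, 128]), axis=1)
negatives = tf.math.l2_normalize(tf.random.normal([4, 5, 128]), axis=2)
contrastive = ContrastiveLoss(margin=0.7)(queries, positives, negatives)
triplet = TripletLoss(margin=0.1)(queries, positives, negatives)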
def contrastive_loss(queries, positives, negatives, margin=0.7, eps=1e-6):
  """Calculates Contrastive Loss.

  We expect the `queries`, `positives` and `negatives` to be normalized to
  unit length for training stability. The contrastive loss directly optimizes
  the Euclidean distance between embeddings, encouraging all positive
  distances to approach 0 while keeping negative distances above the margin.

  Args:
    queries: [batch_size, dim] Anchor input tensor.
    positives: [batch_size, dim] Positive sample input tensor.
    negatives: [batch_size, num_neg, dim] Negative sample input tensor.
    margin: Float contrastive loss margin.
    eps: Float parameter for numerical stability.

  Returns:
    loss: Scalar tensor.
  """
  dim = tf.shape(queries)[1]
  # Number of `queries`.
  batch_size = tf.shape(queries)[0]
  # Number of `positives`.
  np = tf.shape(positives)[0]
  # Number of `negatives`.
  num_neg = tf.shape(negatives)[1]

  # Preparing negatives.
  stacked_negatives = tf.reshape(negatives, [num_neg * batch_size, dim])
  # Preparing queries for further loss calculation.
  stacked_queries = tf.repeat(queries, num_neg + 1, axis=0)
  positives_and_negatives = tf.concat([positives, stacked_negatives], axis=0)

  # Calculate the Euclidean norm for each pair of points. For any positive
  # pair of data points this distance should be small, and for a negative
  # pair it should be large.
  distances = tf.norm(stacked_queries - positives_and_negatives + eps, axis=1)

  positives_part = 0.5 * tf.pow(distances[:np], 2.0)
  negatives_part = 0.5 * tf.pow(
      tf.math.maximum(margin - distances[np:], 0), 2.0)

  # Final contrastive loss calculation.
  loss = tf.reduce_sum(tf.concat([positives_part, negatives_part], 0))
  return loss
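In equation form, with d(u, v) = ||u - v + eps|| the distance computed above, the returned value is

\[ L = \sum_i \tfrac{1}{2}\, d(q_i, p_i)^2 \;+\; \sum_{i,k} \tfrac{1}{2}\, \max\!\big(m - d(q_i, n_{ik}),\, 0\big)^2 , \]

where m is the margin: positive pairs are pulled together, and negatives contribute only while they sit closer than m.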
def triplet_loss(queries, positives, negatives, margin=0.1):
  """Calculates Triplet Loss.

  Triplet loss tries to keep all queries closer to positives than to any
  negatives. Differently from the Contrastive Loss, Triplet Loss uses squared
  distances when computing the loss.

  Args:
    queries: [batch_size, dim] Anchor input tensor.
    positives: [batch_size, dim] Positive sample input tensor.
    negatives: [batch_size, num_neg, dim] Negative sample input tensor.
    margin: Float triplet loss margin.

  Returns:
    loss: Scalar tensor.
  """
  dim = tf.shape(queries)[1]
  # Number of `queries`.
  batch_size = tf.shape(queries)[0]
  # Number of `negatives`.
  num_neg = tf.shape(negatives)[1]

  # Preparing negatives.
  stacked_negatives = tf.reshape(negatives, [num_neg * batch_size, dim])
  # Preparing queries for further loss calculation.
  stacked_queries = tf.repeat(queries, num_neg, axis=0)
  # Preparing positives for further loss calculation.
  stacked_positives = tf.repeat(positives, num_neg, axis=0)

  # Computes *squared* distances.
  distance_positives = tf.reduce_sum(
      tf.square(stacked_queries - stacked_positives), axis=1)
  distance_negatives = tf.reduce_sum(
      tf.square(stacked_queries - stacked_negatives), axis=1)
  # Final triplet loss calculation.
  loss = tf.reduce_sum(
      tf.maximum(distance_positives - distance_negatives + margin, 0.0))
  return loss
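Equivalently, using squared distances as noted in the docstring, the triplet loss sums the per-triplet hinge

\[ L = \sum_{i,k} \max\!\big(\lVert q_i - p_i \rVert^2 - \lVert q_i - n_{ik} \rVert^2 + m,\; 0\big) , \]

so a triplet stops contributing once the negative is farther than the positive by at least the margin m.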
# Copyright 2021 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Ranking losses."""
import tensorflow as tf
from delf.python.training.losses import ranking_losses
class RankingLossesTest(tf.test.TestCase):
def testContrastiveLoss(self):
# Testing the correct numeric value.
queries = tf.math.l2_normalize(tf.constant([[1.0, 2.0, -2.0]]))
positives = tf.math.l2_normalize(tf.constant([[-1.0, 2.0, 0.0]]))
negatives = tf.math.l2_normalize(tf.constant([[[-5.0, 0.0, 3.0]]]))
result = ranking_losses.contrastive_loss(queries, positives, negatives,
margin=0.7, eps=1e-6)
exp_output = 0.55278635
self.assertAllClose(exp_output, result)
def testTripletLossZeroLoss(self):
# Testing the correct numeric value in case if query-positive distance is
# smaller than the query-negative distance.
queries = tf.math.l2_normalize(tf.constant([[1.0, 2.0, -2.0]]))
positives = tf.math.l2_normalize(tf.constant([[-1.0, 2.0, 0.0]]))
negatives = tf.math.l2_normalize(tf.constant([[[-5.0, 0.0, 3.0]]]))
result = ranking_losses.triplet_loss(queries, positives, negatives,
margin=0.1)
exp_output = 0.0
self.assertAllClose(exp_output, result)
def testTripletLossNonZeroLoss(self):
# Testing the correct numeric value in case if query-positive distance is
# bigger than the query-negative distance.
queries = tf.math.l2_normalize(tf.constant([[1.0, 2.0, -2.0]]))
positives = tf.math.l2_normalize(tf.constant([[-5.0, 0.0, 3.0]]))
negatives = tf.math.l2_normalize(tf.constant([[[-1.0, 2.0, 0.0]]]))
result = ranking_losses.triplet_loss(queries, positives, negatives,
margin=0.1)
exp_output = 2.2520838
self.assertAllClose(exp_output, result)
if __name__ == '__main__':
tf.test.main()
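As a sanity check on testContrastiveLoss (arithmetic added here, not part of the source): after normalization q = (1, 2, -2)/3 and p = (-1, 2, 0)/sqrt(5), so ||q - p||^2 ≈ 1.10557 and the positive term is 0.5 * 1.10557 ≈ 0.5527864. The query-negative distance (≈ 1.805) exceeds the 0.7 margin, so the negative term vanishes, reproducing exp_output = 0.55278635.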
@@ -29,6 +29,7 @@ from absl import logging
 import h5py
 import tensorflow as tf
 
+from delf.python.pooling_layers import pooling as pooling_layers
 
 layers = tf.keras.layers
@@ -295,14 +296,14 @@ class ResNet50(tf.keras.Model):
     elif pooling == 'gem':
       logging.info('Adding GeMPooling layer with power %f', gem_power)
       self.global_pooling = functools.partial(
-          gem_pooling, axis=reduction_indices, power=gem_power)
+          pooling_layers.gem, axis=reduction_indices, power=gem_power)
     else:
       self.global_pooling = None
     if embedding_layer:
       logging.info('Adding embedding layer with dimension %d',
                    embedding_layer_dim)
-      self.embedding_layer = layers.Dense(embedding_layer_dim,
-                                          name='embedding_layer')
+      self.embedding_layer = layers.Dense(
+          embedding_layer_dim, name='embedding_layer')
     else:
       self.embedding_layer = None
@@ -404,6 +405,7 @@ class ResNet50(tf.keras.Model):
     Args:
       filepath: String, path to the .h5 file
+
     Raises:
       ValueError: if the file referenced by `filepath` does not exist.
     """
@@ -455,28 +457,4 @@ class ResNet50(tf.keras.Model):
           weights = inlayer.get_weights()
           logging.info(weights)
         else:
-          logging.info('Layer %s does not have inner layers.',
-                       layer.name)
+          logging.info('Layer %s does not have inner layers.', layer.name)
-
-
-def gem_pooling(feature_map, axis, power, threshold=1e-6):
-  """Performs GeM (Generalized Mean) pooling.
-
-  See https://arxiv.org/abs/1711.02512 for a reference.
-
-  Args:
-    feature_map: Tensor of shape [batch, height, width, channels] for
-      the "channels_last" format or [batch, channels, height, width] for the
-      "channels_first" format.
-    axis: Dimensions to reduce.
-    power: Float, GeM power parameter.
-    threshold: Optional float, threshold to use for activations.
-
-  Returns:
-    pooled_feature_map: Tensor of shape [batch, channels].
-  """
-  return tf.pow(
-      tf.reduce_mean(tf.pow(tf.maximum(feature_map, threshold), power),
-                     axis=axis,
-                     keepdims=False),
-      1.0 / power)
...
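In equation form, the removed gem_pooling helper computes, for each channel c over the reduced spatial positions Omega:

\[ \mathrm{GeM}(X)_c = \Big( \frac{1}{|\Omega|} \sum_{(i,j) \in \Omega} \max(x_{ijc}, \tau)^{p} \Big)^{1/p} \]

with power p and activation threshold \tau. The pooling_layers.gem replacement wired in above takes the same axis and power arguments, so it appears to compute the same quantity, with eps playing the role of the threshold.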
@@ -37,10 +37,10 @@ message ClientOptions {
   // The threshold on intersection-over-union used by non-maxima suppression.
   optional float iou_threshold = 5 [default = 0.3];
 
-  // Optional whitelist of class names. If non-empty, detections whose class
+  // Optional allowlist of class names. If non-empty, detections whose class
   // name is not in this set will be filtered out. Duplicate or unknown class
   // names are ignored.
-  repeated string class_name_whitelist = 6;
+  repeated string class_name_allowlist = 6;
 
   // SSD in single class agnostic model.
   optional bool agnostic_mode = 7 [default = false];
...
@@ -63,7 +63,7 @@ void NonMaxSuppressionMultiClassFast(
 
 // Similar to NonMaxSuppressionMultiClassFast, but restricts the results to
 // the provided list of class indices. This effectively filters out any class
-// whose index is not in this whitelist.
+// whose index is not in this allowlist.
 void NonMaxSuppressionMultiClassRestrict(
     std::vector<int> restricted_class_indices,
     const protos::BoxCornerEncoding& boxes, const std::vector<float>& scores,
...
@@ -868,7 +868,13 @@ def keypoint_proto_to_params(kp_config, keypoint_map_dict):
       candidate_search_scale=kp_config.candidate_search_scale,
       candidate_ranking_mode=kp_config.candidate_ranking_mode,
       offset_peak_radius=kp_config.offset_peak_radius,
-      per_keypoint_offset=kp_config.per_keypoint_offset)
+      per_keypoint_offset=kp_config.per_keypoint_offset,
+      predict_depth=kp_config.predict_depth,
+      per_keypoint_depth=kp_config.per_keypoint_depth,
+      keypoint_depth_loss_weight=kp_config.keypoint_depth_loss_weight,
+      score_distance_offset=kp_config.score_distance_offset,
+      clip_out_of_frame_keypoints=kp_config.clip_out_of_frame_keypoints,
+      rescore_instances=kp_config.rescore_instances)
 
 
 def object_detection_proto_to_params(od_config):
...
@@ -116,6 +116,9 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
       candidate_ranking_mode: "score_distance_ratio"
       offset_peak_radius: 3
       per_keypoint_offset: true
+      predict_depth: true
+      per_keypoint_depth: true
+      keypoint_depth_loss_weight: 0.3
     """
     config = text_format.Merge(task_proto_txt,
                                center_net_pb2.CenterNet.KeypointEstimation())
@@ -264,6 +267,9 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
     self.assertEqual(kp_params.candidate_ranking_mode, 'score_distance_ratio')
     self.assertEqual(kp_params.offset_peak_radius, 3)
     self.assertEqual(kp_params.per_keypoint_offset, True)
+    self.assertEqual(kp_params.predict_depth, True)
+    self.assertEqual(kp_params.per_keypoint_depth, True)
+    self.assertAlmostEqual(kp_params.keypoint_depth_loss_weight, 0.3)
 
     # Check mask related parameters.
     self.assertAlmostEqual(model._mask_params.task_loss_weight, 0.7)
...
@@ -18,6 +18,12 @@
 import tensorflow.compat.v1 as tf
 
 from object_detection.utils import learning_schedules
+from object_detection.utils import tf_version
+
+# pylint: disable=g-import-not-at-top
+if tf_version.is_tf2():
+  from official.modeling.optimization import ema_optimizer
+# pylint: enable=g-import-not-at-top
 
 try:
   from tensorflow.contrib import opt as tf_opt  # pylint: disable=g-import-not-at-top
@@ -130,7 +136,9 @@ def build_optimizers_tf_v2(optimizer_config, global_step=None):
     raise ValueError('Optimizer %s not supported.' % optimizer_type)
 
   if optimizer_config.use_moving_average:
-    raise ValueError('Moving average not supported in eager mode.')
+    optimizer = ema_optimizer.ExponentialMovingAverage(
+        optimizer=optimizer,
+        average_decay=optimizer_config.moving_average_decay)
 
   return optimizer, summary_vars
...
@@ -82,7 +82,7 @@ class OptimizerBuilderV2Test(tf.test.TestCase):
     optimizer, _ = optimizer_builder.build(optimizer_proto)
     self.assertIsInstance(optimizer, tf.keras.optimizers.Adam)
 
-  def testMovingAverageOptimizerUnsupported(self):
+  def testBuildMovingAverageOptimizer(self):
     optimizer_text_proto = """
       adam_optimizer: {
         learning_rate: {
@@ -95,8 +95,8 @@ class OptimizerBuilderV2Test(tf.test.TestCase):
     """
     optimizer_proto = optimizer_pb2.Optimizer()
     text_format.Merge(optimizer_text_proto, optimizer_proto)
-    with self.assertRaises(ValueError):
-      optimizer_builder.build(optimizer_proto)
+    optimizer, _ = optimizer_builder.build(optimizer_proto)
+    self.assertIsInstance(optimizer, tf.keras.optimizers.Optimizer)
 
 if __name__ == '__main__':
...
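For context, a minimal sketch of the optimizer this new code path builds; the wrapper call is taken from the diff above, while the Adam base and decay value here are arbitrary:

import tensorflow as tf
from official.modeling.optimization import ema_optimizer

base = tf.keras.optimizers.Adam(learning_rate=1e-3)
# Wraps the base optimizer: gradients update the base variables as usual,
# while an exponential moving average of the weights is maintained alongside.
optimizer = ema_optimizer.ExponentialMovingAverage(
    optimizer=base, average_decay=0.9999)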
@@ -570,15 +570,17 @@ def _resize_detection_masks(arg_tuple):
 
   Returns:
   """
   detection_boxes, detection_masks, image_shape, pad_shape = arg_tuple
+
   detection_masks_reframed = ops.reframe_box_masks_to_image_masks(
       detection_masks, detection_boxes, image_shape[0], image_shape[1])
-  paddings = tf.concat(
-      [tf.zeros([3, 1], dtype=tf.int32),
-       tf.expand_dims(
-           tf.concat([tf.zeros([1], dtype=tf.int32),
-                      pad_shape-image_shape], axis=0),
-           1)], axis=1)
+
+  pad_instance_dim = tf.zeros([3, 1], dtype=tf.int32)
+  pad_hw_dim = tf.concat([tf.zeros([1], dtype=tf.int32),
+                          pad_shape - image_shape], axis=0)
+  pad_hw_dim = tf.expand_dims(pad_hw_dim, 1)
+  paddings = tf.concat([pad_instance_dim, pad_hw_dim], axis=1)
   detection_masks_reframed = tf.pad(detection_masks_reframed, paddings)
 
   # If the masks are currently float, binarize them. Otherwise keep them as
...
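To make the refactored padding concrete, a worked example (shapes assumed, e.g. a 320x320 image letterboxed to 512x512):

import tensorflow as tf

image_shape = tf.constant([320, 320])  # true (height, width)
pad_shape = tf.constant([512, 512])    # padded (height, width)
pad_instance_dim = tf.zeros([3, 1], dtype=tf.int32)
pad_hw_dim = tf.expand_dims(
    tf.concat([tf.zeros([1], dtype=tf.int32), pad_shape - image_shape],
              axis=0), 1)
paddings = tf.concat([pad_instance_dim, pad_hw_dim], axis=1)
# paddings == [[0, 0], [0, 192], [0, 192]]: nothing on the instance axis,
# 192 rows/columns appended at the bottom/right of each mask.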
@@ -207,7 +207,8 @@ class CenterNetModule(tf.Module):
   both object detection and keypoint estimation task.
   """
 
-  def __init__(self, pipeline_config, max_detections, include_keypoints):
+  def __init__(self, pipeline_config, max_detections, include_keypoints,
+               label_map_path=''):
     """Initialization.
 
     Args:
@@ -215,10 +216,15 @@ class CenterNetModule(tf.Module):
       max_detections: Max detections desired from the TFLite model.
      include_keypoints: If set true, the output dictionary will include the
         keypoint coordinates and keypoint confidence scores.
+      label_map_path: Path to the label map which is used by CenterNet keypoint
+        estimation task. If provided, the label_map_path in the configuration
+        will be replaced by this one.
     """
     self._max_detections = max_detections
     self._include_keypoints = include_keypoints
     self._process_config(pipeline_config)
+    if include_keypoints and label_map_path:
+      pipeline_config.model.center_net.keypoint_label_map_path = label_map_path
     self._pipeline_config = pipeline_config
     self._model = model_builder.build(
         self._pipeline_config.model, is_training=False)
@@ -303,7 +309,7 @@ class CenterNetModule(tf.Module):
 
 def export_tflite_model(pipeline_config, trained_checkpoint_dir,
                         output_directory, max_detections, use_regular_nms,
-                        include_keypoints=False):
+                        include_keypoints=False, label_map_path=''):
   """Exports inference SavedModel for TFLite conversion.
 
   NOTE: Only supports SSD meta-architectures for now, and the output model will
@@ -322,6 +328,9 @@ def export_tflite_model(pipeline_config, trained_checkpoint_dir,
       Note that this argument is only used by the SSD model.
     include_keypoints: Decides whether to also output the keypoint predictions.
       Note that this argument is only used by the CenterNet model.
+    label_map_path: Path to the label map which is used by CenterNet keypoint
+      estimation task. If provided, the label_map_path in the configuration will
+      be replaced by this one.
 
   Raises:
     ValueError: if pipeline is invalid.
@@ -339,7 +348,8 @@ def export_tflite_model(pipeline_config, trained_checkpoint_dir,
         max_detections, use_regular_nms)
   elif pipeline_config.model.WhichOneof('model') == 'center_net':
     detection_module = CenterNetModule(
-        pipeline_config, max_detections, include_keypoints)
+        pipeline_config, max_detections, include_keypoints,
+        label_map_path=label_map_path)
     ckpt = tf.train.Checkpoint(model=detection_module.get_model())
   else:
     raise ValueError('Only ssd or center_net models are supported in tflite. '
...
@@ -53,7 +53,7 @@ certain fields in the provided pipeline_config_path. These are useful for
 making small changes to the inference graph that differ from the training or
 eval config.
 
-Example Usage (in which we change the NMS iou_threshold to be 0.5 and
+Example Usage 1 (in which we change the NMS iou_threshold to be 0.5 and
 NMS score_threshold to be 0.0):
 python object_detection/export_tflite_model_tf2.py \
     --pipeline_config_path path/to/ssd_model/pipeline.config \
@@ -71,6 +71,27 @@ python object_detection/export_tflite_model_tf2.py \
       } \
     } \
     "
+
+Example Usage 2 (export CenterNet model for keypoint estimation task with fixed
+shape resizer and customized input resolution):
+python object_detection/export_tflite_model_tf2.py \
+    --pipeline_config_path path/to/ssd_model/pipeline.config \
+    --trained_checkpoint_dir path/to/ssd_model/checkpoint \
+    --output_directory path/to/exported_model_directory \
+    --keypoint_label_map_path path/to/label_map.txt \
+    --max_detections 10 \
+    --centernet_include_keypoints true \
+    --config_override " \
+      model{ \
+        center_net { \
+          image_resizer { \
+            fixed_shape_resizer { \
+              height: 320 \
+              width: 320 \
+            } \
+          } \
+        } \
+      }" \
 """
 
 from absl import app
 from absl import flags
@@ -107,6 +128,13 @@ flags.DEFINE_bool(
     'Whether to export the predicted keypoint tensors. Only CenterNet model'
     ' supports this flag.'
 )
+flags.DEFINE_string(
+    'keypoint_label_map_path', None,
+    'Path of the label map used by CenterNet keypoint estimation task. If'
+    ' provided, the label map path in the pipeline config will be replaced by'
+    ' this one. Note that it is only used when exporting CenterNet model for'
+    ' keypoint estimation task.'
+)
 
 
 def main(argv):
@@ -119,12 +147,14 @@ def main(argv):
 
   with tf.io.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f:
     text_format.Parse(f.read(), pipeline_config)
-  text_format.Parse(FLAGS.config_override, pipeline_config)
+  override_config = pipeline_pb2.TrainEvalPipelineConfig()
+  text_format.Parse(FLAGS.config_override, override_config)
+  pipeline_config.MergeFrom(override_config)
 
   export_tflite_graph_lib_tf2.export_tflite_model(
       pipeline_config, FLAGS.trained_checkpoint_dir, FLAGS.output_directory,
      FLAGS.max_detections, FLAGS.ssd_use_regular_nms,
-      FLAGS.centernet_include_keypoints)
+      FLAGS.centernet_include_keypoints, FLAGS.keypoint_label_map_path)
 
 if __name__ == '__main__':
...
@@ -34,6 +34,8 @@ Model name
 [CenterNet Resnet101 V1 FPN 512x512](http://download.tensorflow.org/models/object_detection/tf2/20200711/centernet_resnet101_v1_fpn_512x512_coco17_tpu-8.tar.gz) | 34 | 34.2 | Boxes
 [CenterNet Resnet50 V2 512x512](http://download.tensorflow.org/models/object_detection/tf2/20200711/centernet_resnet50_v2_512x512_coco17_tpu-8.tar.gz) | 27 | 29.5 | Boxes
 [CenterNet Resnet50 V2 Keypoints 512x512](http://download.tensorflow.org/models/object_detection/tf2/20200711/centernet_resnet50_v2_512x512_kpts_coco17_tpu-8.tar.gz) | 30 | 27.6/48.2 | Boxes/Keypoints
+[CenterNet MobileNetV2 FPN 512x512](http://download.tensorflow.org/models/object_detection/tf2/20210210/centernet_mobilenetv2fpn_512x512_coco17_od.tar.gz) | 6 | 23.4 | Boxes
+[CenterNet MobileNetV2 FPN Keypoints 512x512](http://download.tensorflow.org/models/object_detection/tf2/20210210/centernet_mobilenetv2fpn_512x512_coco17_kpts.tar.gz) | 6 | 41.7 | Keypoints
 [EfficientDet D0 512x512](http://download.tensorflow.org/models/object_detection/tf2/20200711/efficientdet_d0_coco17_tpu-32.tar.gz) | 39 | 33.6 | Boxes
 [EfficientDet D1 640x640](http://download.tensorflow.org/models/object_detection/tf2/20200711/efficientdet_d1_coco17_tpu-32.tar.gz) | 54 | 38.4 | Boxes
 [EfficientDet D2 768x768](http://download.tensorflow.org/models/object_detection/tf2/20200711/efficientdet_d2_coco17_tpu-32.tar.gz) | 67 | 41.8 | Boxes
...
@@ -423,12 +423,12 @@ def prediction_tensors_to_temporal_offsets(
   return offsets
 
 
-def prediction_tensors_to_keypoint_candidates(
-    keypoint_heatmap_predictions,
-    keypoint_heatmap_offsets,
-    keypoint_score_threshold=0.1,
-    max_pool_kernel_size=1,
-    max_candidates=20):
+def prediction_tensors_to_keypoint_candidates(keypoint_heatmap_predictions,
+                                              keypoint_heatmap_offsets,
+                                              keypoint_score_threshold=0.1,
+                                              max_pool_kernel_size=1,
+                                              max_candidates=20,
+                                              keypoint_depths=None):
   """Convert keypoint heatmap predictions and offsets to keypoint candidates.
 
   Args:
@@ -437,14 +437,17 @@ def prediction_tensors_to_keypoint_candidates(
     keypoint_heatmap_offsets: A float tensor of shape [batch_size, height,
       width, 2] (or [batch_size, height, width, 2 * num_keypoints] if
       'per_keypoint_offset' is set True) representing the per-keypoint offsets.
-    keypoint_score_threshold: float, the threshold for considering a keypoint
-      a candidate.
+    keypoint_score_threshold: float, the threshold for considering a keypoint a
+      candidate.
     max_pool_kernel_size: integer, the max pool kernel size to use to pull off
       peak score locations in a neighborhood. For example, to make sure no two
       neighboring values for the same keypoint are returned, set
       max_pool_kernel_size=3. If None or 1, will not apply any local filtering.
-    max_candidates: integer, maximum number of keypoint candidates per
-      keypoint type.
+    max_candidates: integer, maximum number of keypoint candidates per keypoint
+      type.
+    keypoint_depths: (optional) A float tensor of shape [batch_size, height,
+      width, 1] (or [batch_size, height, width, num_keypoints] if
+      'per_keypoint_depth' is set True) representing the per-keypoint depths.
 
   Returns:
     keypoint_candidates: A tensor of shape
@@ -458,6 +461,9 @@ def prediction_tensors_to_keypoint_candidates(
     [batch_size, num_keypoints] with the number of candidates for each
     keypoint type, as it's possible to filter some candidates due to the score
     threshold.
+    depth_candidates: A tensor of shape [batch_size, max_candidates,
+      num_keypoints] representing the estimated depth of each keypoint
+      candidate. Return None if the input keypoint_depths is None.
   """
   batch_size, _, _, num_keypoints = _get_shape(keypoint_heatmap_predictions, 4)
   # Get x, y and channel indices corresponding to the top indices in the
@@ -499,13 +505,13 @@ def prediction_tensors_to_keypoint_candidates(
     # TF Lite does not support tf.gather with batch_dims > 0, so we need to use
     # tf_gather_nd instead and here we prepare the indices for that. In this
     # case, channel_indices indicates which keypoint to use the offset from.
-    combined_indices = tf.stack([
+    channel_combined_indices = tf.stack([
         _multi_range(batch_size, value_repetitions=num_indices),
         _multi_range(num_indices, range_repetitions=batch_size),
        tf.reshape(channel_indices, [-1])
     ], axis=1)
-    offsets = tf.gather_nd(reshaped_offsets, combined_indices)
+    offsets = tf.gather_nd(reshaped_offsets, channel_combined_indices)
     offsets = tf.reshape(offsets, [batch_size, num_indices, -1])
   else:
     offsets = selected_offsets
@@ -524,14 +530,38 @@ def prediction_tensors_to_keypoint_candidates(
   num_candidates = tf.reduce_sum(
       tf.to_int32(keypoint_scores >= keypoint_score_threshold), axis=1)
 
-  return keypoint_candidates, keypoint_scores, num_candidates
+  depth_candidates = None
+  if keypoint_depths is not None:
+    selected_depth_flat = tf.gather_nd(keypoint_depths, combined_indices)
+    selected_depth = tf.reshape(selected_depth_flat,
+                                [batch_size, num_indices, -1])
+    _, _, num_depth_channels = _get_shape(selected_depth, 3)
+    if num_depth_channels > 1:
+      combined_indices = tf.stack([
+          _multi_range(batch_size, value_repetitions=num_indices),
+          _multi_range(num_indices, range_repetitions=batch_size),
+          tf.reshape(channel_indices, [-1])
+      ], axis=1)
+      depth = tf.gather_nd(selected_depth, combined_indices)
+      depth = tf.reshape(depth, [batch_size, num_indices, -1])
+    else:
+      depth = selected_depth
+    depth_candidates = tf.reshape(depth,
+                                  [batch_size, num_keypoints, max_candidates])
+    depth_candidates = tf.transpose(depth_candidates, [0, 2, 1])
+
+  return keypoint_candidates, keypoint_scores, num_candidates, depth_candidates
 
 
-def prediction_to_single_instance_keypoints(object_heatmap, keypoint_heatmap,
+def prediction_to_single_instance_keypoints(object_heatmap,
+                                            keypoint_heatmap,
                                             keypoint_offset,
-                                            keypoint_regression, stride,
+                                            keypoint_regression,
+                                            stride,
                                             object_center_std_dev,
-                                            keypoint_std_dev, kp_params):
+                                            keypoint_std_dev,
+                                            kp_params,
+                                            keypoint_depths=None):
   """Postprocess function to predict single instance keypoints.
 
   This is a simplified postprocessing function based on the assumption that
@@ -569,6 +599,9 @@ def prediction_to_single_instance_keypoints(object_heatmap, keypoint_heatmap,
       representing the standard deviation corresponding to each joint.
     kp_params: A `KeypointEstimationParams` object with parameters for a single
       keypoint class.
+    keypoint_depths: (optional) A float tensor of shape [batch_size, height,
+      width, 1] (or [batch_size, height, width, num_keypoints] if
+      'per_keypoint_depth' is set True) representing the per-keypoint depths.
 
   Returns:
     A tuple of two tensors:
@@ -577,6 +610,9 @@ def prediction_to_single_instance_keypoints(object_heatmap, keypoint_heatmap,
       map space.
     keypoint_scores: A float tensor with shape [1, 1, num_keypoints]
       representing the keypoint prediction scores.
+    keypoint_depths: A float tensor with shape [1, 1, num_keypoints]
+      representing the estimated keypoint depths. Return None if the input
+      keypoint_depths is None.
 
   Raises:
     ValueError: if the input keypoint_std_dev doesn't have valid number of
@@ -636,14 +672,16 @@ def prediction_to_single_instance_keypoints(object_heatmap, keypoint_heatmap,
   # Get the keypoint locations/scores:
   # keypoint_candidates: [1, 1, num_keypoints, 2]
   # keypoint_scores: [1, 1, num_keypoints]
-  (keypoint_candidates, keypoint_scores,
-   _) = prediction_tensors_to_keypoint_candidates(
+  # depth_candidates: [1, 1, num_keypoints]
+  (keypoint_candidates, keypoint_scores, _,
+   depth_candidates) = prediction_tensors_to_keypoint_candidates(
       keypoint_predictions,
      keypoint_offset,
      keypoint_score_threshold=kp_params.keypoint_candidate_score_threshold,
      max_pool_kernel_size=kp_params.peak_max_pool_kernel_size,
-      max_candidates=1)
+      max_candidates=1,
+      keypoint_depths=keypoint_depths)
 
-  return keypoint_candidates, keypoint_scores
+  return keypoint_candidates, keypoint_scores, depth_candidates
 
 
 def regressed_keypoints_at_object_centers(regressed_keypoint_predictions,
@@ -697,11 +735,17 @@ def regressed_keypoints_at_object_centers(regressed_keypoint_predictions,
                     [batch_size, num_instances, -1])
 
 
-def refine_keypoints(regressed_keypoints, keypoint_candidates, keypoint_scores,
-                     num_keypoint_candidates, bboxes=None,
-                     unmatched_keypoint_score=0.1, box_scale=1.2,
+def refine_keypoints(regressed_keypoints,
+                     keypoint_candidates,
+                     keypoint_scores,
+                     num_keypoint_candidates,
+                     bboxes=None,
+                     unmatched_keypoint_score=0.1,
+                     box_scale=1.2,
                      candidate_search_scale=0.3,
-                     candidate_ranking_mode='min_distance'):
+                     candidate_ranking_mode='min_distance',
+                     score_distance_offset=1e-6,
+                     keypoint_depth_candidates=None):
   """Refines regressed keypoints by snapping to the nearest candidate keypoints.
 
   The initial regressed keypoints represent a full set of keypoints regressed
@@ -757,6 +801,14 @@ def refine_keypoints(regressed_keypoints, keypoint_candidates, keypoint_scores,
     candidate_ranking_mode: A string as one of ['min_distance',
       'score_distance_ratio'] indicating how to select the candidate. If an
      invalid value is provided, a ValueError will be raised.
+    score_distance_offset: The distance offset to apply in the denominator when
+      candidate_ranking_mode is 'score_distance_ratio'. The metric to maximize
+      in this scenario is score / (distance + score_distance_offset). Larger
+      values of score_distance_offset make the keypoint score gain more
+      relative importance.
+    keypoint_depth_candidates: (optional) A float tensor of shape
+      [batch_size, max_candidates, num_keypoints] indicating the depths for
+      keypoint candidates.
 
   Returns:
     A tuple with:
@@ -827,7 +879,7 @@ def refine_keypoints(regressed_keypoints, keypoint_candidates, keypoint_scores,
     tiled_keypoint_scores = tf.tile(
         tf.expand_dims(keypoint_scores, axis=1),
         multiples=[1, num_instances, 1, 1])
-    ranking_scores = tiled_keypoint_scores / (distances + 1e-6)
+    ranking_scores = tiled_keypoint_scores / (distances + score_distance_offset)
     nearby_candidate_inds = tf.math.argmax(ranking_scores, axis=2)
   else:
     raise ValueError('Not recognized candidate_ranking_mode: %s' %
@@ -836,9 +888,11 @@ def refine_keypoints(regressed_keypoints, keypoint_candidates, keypoint_scores,
   # Gather the coordinates and scores corresponding to the closest candidates.
   # Shape of tensors are [batch_size, num_instances, num_keypoints, 2] and
   # [batch_size, num_instances, num_keypoints], respectively.
-  nearby_candidate_coords, nearby_candidate_scores = (
-      _gather_candidates_at_indices(keypoint_candidates, keypoint_scores,
-                                    nearby_candidate_inds))
+  (nearby_candidate_coords, nearby_candidate_scores,
+   nearby_candidate_depths) = (
+       _gather_candidates_at_indices(keypoint_candidates, keypoint_scores,
+                                     nearby_candidate_inds,
+                                     keypoint_depth_candidates))
 
   if bboxes is None:
     # Create bboxes from regressed keypoints.
@@ -895,7 +949,12 @@ def refine_keypoints(regressed_keypoints, keypoint_candidates, keypoint_scores,
       unmatched_keypoint_score * tf.ones_like(nearby_candidate_scores),
       nearby_candidate_scores)
 
-  return refined_keypoints, refined_scores
+  refined_depths = None
+  if nearby_candidate_depths is not None:
+    refined_depths = tf.where(mask, tf.zeros_like(nearby_candidate_depths),
+                              nearby_candidate_depths)
+
+  return refined_keypoints, refined_scores, refined_depths
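An illustrative toy run of the 'score_distance_ratio' ranking described above (values invented):

import tensorflow as tf

scores = tf.constant([0.9, 0.6])      # two candidates for one keypoint
distances = tf.constant([10.0, 2.0])  # distance to the regressed keypoint
ranking_scores = scores / (distances + 1e-6)  # [0.09, 0.3]
best = tf.math.argmax(ranking_scores)         # 1: the nearer candidate wins
# A larger score_distance_offset damps the distance term, shifting the
# selection toward the higher-scoring candidate.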
 
 
 def _pad_to_full_keypoint_dim(keypoint_coords, keypoint_scores, keypoint_inds,
@@ -976,8 +1035,10 @@ def _pad_to_full_instance_dim(keypoint_coords, keypoint_scores, instance_inds,
   return keypoint_coords_padded, keypoint_scores_padded
 
 
-def _gather_candidates_at_indices(keypoint_candidates, keypoint_scores,
-                                  indices):
+def _gather_candidates_at_indices(keypoint_candidates,
+                                  keypoint_scores,
+                                  indices,
+                                  keypoint_depth_candidates=None):
   """Gathers keypoint candidate coordinates and scores at indices.
 
   Args:
@@ -987,13 +1048,18 @@ def _gather_candidates_at_indices(keypoint_candidates, keypoint_scores,
       num_keypoints] with keypoint scores.
     indices: an integer tensor of shape [batch_size, num_indices, num_keypoints]
       with indices.
+    keypoint_depth_candidates: (optional) a float tensor of shape [batch_size,
+      max_candidates, num_keypoints] with keypoint depths.
 
   Returns:
     A tuple with
     gathered_keypoint_candidates: a float tensor of shape [batch_size,
       num_indices, num_keypoints, 2] with gathered coordinates.
-    gathered_keypoint_scores: a float tensor of shape [batch_size,
-      num_indices, num_keypoints, 2].
+    gathered_keypoint_scores: a float tensor of shape [batch_size,
+      num_indices, num_keypoints].
+    gathered_keypoint_depths: a float tensor of shape [batch_size,
+      num_indices, num_keypoints]. Return None if the input
+      keypoint_depth_candidates is None.
   """
   batch_size, num_indices, num_keypoints = _get_shape(indices, 3)
 
@@ -1035,7 +1101,19 @@ def _gather_candidates_at_indices(keypoint_candidates, keypoint_scores,
   gathered_keypoint_scores = tf.transpose(nearby_candidate_scores_transposed,
                                           [0, 2, 1])
 
-  return gathered_keypoint_candidates, gathered_keypoint_scores
+  gathered_keypoint_depths = None
+  if keypoint_depth_candidates is not None:
+    keypoint_depths_transposed = tf.transpose(keypoint_depth_candidates,
+                                              [0, 2, 1])
+    nearby_candidate_depths_transposed = tf.gather_nd(
+        keypoint_depths_transposed, combined_indices)
+    nearby_candidate_depths_transposed = tf.reshape(
+        nearby_candidate_depths_transposed,
+        [batch_size, num_keypoints, num_indices])
+    gathered_keypoint_depths = tf.transpose(nearby_candidate_depths_transposed,
+                                            [0, 2, 1])
+  return (gathered_keypoint_candidates, gathered_keypoint_scores,
+          gathered_keypoint_depths)
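A toy trace of the transpose/gather_nd/transpose pattern used for depths above (values invented; combined_indices rows are (batch, keypoint, candidate) triples as prepared earlier in the function):

import tensorflow as tf

depths = tf.constant([[[10., 20.], [11., 21.], [12., 22.]]])  # [1, 3, 2]
depths_t = tf.transpose(depths, [0, 2, 1])                    # [1, 2, 3]
combined_indices = tf.constant([[0, 0, 2], [0, 1, 0]])
picked = tf.gather_nd(depths_t, combined_indices)             # [12., 20.]
result = tf.transpose(tf.reshape(picked, [1, 2, 1]), [0, 2, 1])
# result[0, 0] == [12., 20.]: keypoint 0 took candidate 2, keypoint 1 took 0.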
 
 
 def flattened_indices_from_row_col_indices(row_indices, col_indices, num_cols):
@@ -1517,7 +1595,10 @@ class KeypointEstimationParams(
         'heatmap_bias_init', 'num_candidates_per_keypoint', 'task_loss_weight',
         'peak_max_pool_kernel_size', 'unmatched_keypoint_score', 'box_scale',
         'candidate_search_scale', 'candidate_ranking_mode',
-        'offset_peak_radius', 'per_keypoint_offset'
+        'offset_peak_radius', 'per_keypoint_offset', 'predict_depth',
+        'per_keypoint_depth', 'keypoint_depth_loss_weight',
+        'score_distance_offset', 'clip_out_of_frame_keypoints',
+        'rescore_instances'
     ])):
   """Namedtuple to host object detection related parameters.
@@ -1550,7 +1631,13 @@ class KeypointEstimationParams(
               candidate_search_scale=0.3,
               candidate_ranking_mode='min_distance',
               offset_peak_radius=0,
-              per_keypoint_offset=False):
+              per_keypoint_offset=False,
+              predict_depth=False,
+              per_keypoint_depth=False,
+              keypoint_depth_loss_weight=1.0,
+              score_distance_offset=1e-6,
+              clip_out_of_frame_keypoints=False,
+              rescore_instances=False):
"""Constructor with default values for KeypointEstimationParams. """Constructor with default values for KeypointEstimationParams.
Args: Args:
...@@ -1614,6 +1701,22 @@ class KeypointEstimationParams( ...@@ -1614,6 +1701,22 @@ class KeypointEstimationParams(
original paper). If set True, the output offset target has the shape original paper). If set True, the output offset target has the shape
[batch_size, out_height, out_width, 2 * num_keypoints] (recommended when [batch_size, out_height, out_width, 2 * num_keypoints] (recommended when
the offset_peak_radius is not zero). the offset_peak_radius is not zero).
predict_depth: A bool indicates whether to predict the depth of each
keypoints.
per_keypoint_depth: A bool indicates whether the model predicts the depth
of each keypoints in independent channels. Similar to
per_keypoint_offset but for the keypoint depth.
keypoint_depth_loss_weight: The weight of the keypoint depth loss.
score_distance_offset: The distance offset to apply in the denominator
when candidate_ranking_mode is 'score_distance_ratio'. The metric to
maximize in this scenario is score / (distance + score_distance_offset).
Larger values of score_distance_offset make the keypoint score gain more
relative importance.
clip_out_of_frame_keypoints: Whether keypoints outside the image frame
should be clipped back to the image boundary. If True, the keypoints
that are clipped have scores set to 0.0.
rescore_instances: Whether to rescore instances based on a combination of
detection score and keypoint scores.
Returns: Returns:
An initialized KeypointEstimationParams namedtuple. An initialized KeypointEstimationParams namedtuple.
...@@ -1626,7 +1729,9 @@ class KeypointEstimationParams( ...@@ -1626,7 +1729,9 @@ class KeypointEstimationParams(
heatmap_bias_init, num_candidates_per_keypoint, task_loss_weight, heatmap_bias_init, num_candidates_per_keypoint, task_loss_weight,
peak_max_pool_kernel_size, unmatched_keypoint_score, box_scale, peak_max_pool_kernel_size, unmatched_keypoint_score, box_scale,
candidate_search_scale, candidate_ranking_mode, offset_peak_radius, candidate_search_scale, candidate_ranking_mode, offset_peak_radius,
per_keypoint_offset) per_keypoint_offset, predict_depth, per_keypoint_depth,
keypoint_depth_loss_weight, score_distance_offset,
clip_out_of_frame_keypoints, rescore_instances)
 
 
 class ObjectCenterParams(
@@ -1839,6 +1944,7 @@ BOX_OFFSET = 'box/offset'
 KEYPOINT_REGRESSION = 'keypoint/regression'
 KEYPOINT_HEATMAP = 'keypoint/heatmap'
 KEYPOINT_OFFSET = 'keypoint/offset'
+KEYPOINT_DEPTH = 'keypoint/depth'
 SEGMENTATION_TASK = 'segmentation_task'
 SEGMENTATION_HEATMAP = 'segmentation/heatmap'
 DENSEPOSE_TASK = 'densepose_task'
@@ -2055,6 +2161,15 @@ class CenterNetMetaArch(model.DetectionModel):
                 use_depthwise=self._use_depthwise)
             for _ in range(num_feature_outputs)
         ]
+        if kp_params.predict_depth:
+          num_depth_channel = (
+              num_keypoints if kp_params.per_keypoint_depth else 1)
+          prediction_heads[get_keypoint_name(task_name, KEYPOINT_DEPTH)] = [
+              make_prediction_net(
+                  num_depth_channel, use_depthwise=self._use_depthwise)
+              for _ in range(num_feature_outputs)
+          ]
      # pylint: disable=g-complex-comprehension
      if self._mask_params is not None:
        prediction_heads[SEGMENTATION_HEATMAP] = [
@@ -2305,6 +2420,7 @@ class CenterNetMetaArch(model.DetectionModel):
      heatmap_key = get_keypoint_name(task_name, KEYPOINT_HEATMAP)
      offset_key = get_keypoint_name(task_name, KEYPOINT_OFFSET)
      regression_key = get_keypoint_name(task_name, KEYPOINT_REGRESSION)
+      depth_key = get_keypoint_name(task_name, KEYPOINT_DEPTH)
      heatmap_loss = self._compute_kp_heatmap_loss(
          input_height=input_height,
          input_width=input_width,
@@ -2332,6 +2448,14 @@ class CenterNetMetaArch(model.DetectionModel):
          kp_params.keypoint_offset_loss_weight * offset_loss)
      loss_dict[regression_key] = (
          kp_params.keypoint_regression_loss_weight * reg_loss)
+      if kp_params.predict_depth:
+        depth_loss = self._compute_kp_depth_loss(
+            input_height=input_height,
+            input_width=input_width,
+            task_name=task_name,
+            depth_predictions=prediction_dict[depth_key],
+            localization_loss_fn=kp_params.localization_loss)
+        loss_dict[depth_key] = kp_params.keypoint_depth_loss_weight * depth_loss
     return loss_dict
 
   def _compute_kp_heatmap_loss(self, input_height, input_width, task_name,
@@ -2501,6 +2625,68 @@ class CenterNetMetaArch(model.DetectionModel):
         tf.maximum(tf.reduce_sum(batch_weights), 1.0))
     return loss
def _compute_kp_depth_loss(self, input_height, input_width, task_name,
depth_predictions, localization_loss_fn):
"""Computes the loss of the keypoint depth estimation.
Args:
input_height: An integer scalar tensor representing input image height.
input_width: An integer scalar tensor representing input image width.
task_name: A string representing the name of the keypoint task.
depth_predictions: A list of float tensors of shape [batch_size,
out_height, out_width, 1 (or num_keypoints)] representing the prediction
heads of the model for keypoint depth.
localization_loss_fn: An object_detection.core.losses.Loss object to
compute the loss for the keypoint offset predictions in CenterNet.
Returns:
loss: A float scalar tensor representing the keypoint depth loss
normalized by number of total keypoints.
"""
kp_params = self._kp_params_dict[task_name]
gt_keypoints_list = self.groundtruth_lists(fields.BoxListFields.keypoints)
gt_classes_list = self.groundtruth_lists(fields.BoxListFields.classes)
gt_weights_list = self.groundtruth_lists(fields.BoxListFields.weights)
gt_keypoint_depths_list = self.groundtruth_lists(
fields.BoxListFields.keypoint_depths)
gt_keypoint_depth_weights_list = self.groundtruth_lists(
fields.BoxListFields.keypoint_depth_weights)
assigner = self._target_assigner_dict[task_name]
(batch_indices, batch_depths,
batch_weights) = assigner.assign_keypoints_depth_targets(
height=input_height,
width=input_width,
gt_keypoints_list=gt_keypoints_list,
gt_weights_list=gt_weights_list,
gt_classes_list=gt_classes_list,
gt_keypoint_depths_list=gt_keypoint_depths_list,
gt_keypoint_depth_weights_list=gt_keypoint_depth_weights_list)
    if kp_params.per_keypoint_offset and not kp_params.per_keypoint_depth:
      # The depth prediction is shared across keypoint channels, so drop the
      # keypoint channel index from the target indices (keep only
      # [batch_index, y_index, x_index]).
      batch_indices = batch_indices[:, 0:3]

    # Keypoint depth loss.
loss = 0.0
for prediction in depth_predictions:
selected_depths = cn_assigner.get_batch_predictions_from_indices(
prediction, batch_indices)
if kp_params.per_keypoint_offset and kp_params.per_keypoint_depth:
selected_depths = tf.expand_dims(selected_depths, axis=-1)
      # Note: the tensor ranks passed here do not match the loss function's
      # docstring, but the loss still computes the correct values.
unweighted_loss = localization_loss_fn(
selected_depths,
batch_depths,
weights=tf.expand_dims(tf.ones_like(batch_weights), -1))
# Apply the weights after the loss function to have full control over it.
loss += batch_weights * tf.squeeze(unweighted_loss, axis=1)
loss = tf.reduce_sum(loss) / (
float(len(depth_predictions)) *
tf.maximum(tf.reduce_sum(batch_weights), 1.0))
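    # Illustrative arithmetic (hypothetical values): with a single prediction
    # head and an L1 localization loss, if a groundtruth depth of 3.0 is
    # assigned to a radius-1 disk of 5 pixels and only the center pixel
    # predicts 3.0 (the rest predict 0.0), the normalized loss is
    # (4 * |3.0 - 0.0|) / 5 = 2.4.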
return loss
  def _compute_segmentation_losses(self, prediction_dict, per_pixel_weights):
    """Computes all the losses associated with segmentation.
...@@ -2785,6 +2971,71 @@ class CenterNetMetaArch(model.DetectionModel):
      loss_dict[TEMPORAL_OFFSET] = offset_loss
    return loss_dict
def _should_clip_keypoints(self):
"""Returns a boolean indicating whether keypoint clipping should occur.
If there is only one keypoint task, clipping is controlled by the field
`clip_out_of_frame_keypoints`. If there are multiple keypoint tasks,
clipping logic is defined based on unanimous agreement of keypoint
parameters. If there is any ambiguity, clip_out_of_frame_keypoints is set
to False (default).
"""
kp_params_iterator = iter(self._kp_params_dict.values())
if len(self._kp_params_dict) == 1:
kp_params = next(kp_params_iterator)
return kp_params.clip_out_of_frame_keypoints
# Multi-task setting.
kp_params = next(kp_params_iterator)
should_clip = kp_params.clip_out_of_frame_keypoints
for kp_params in kp_params_iterator:
if kp_params.clip_out_of_frame_keypoints != should_clip:
return False
return should_clip
def _rescore_instances(self, classes, scores, keypoint_scores):
"""Rescores instances based on detection and keypoint scores.
Args:
classes: A [batch, max_detections] int32 tensor with detection classes.
scores: A [batch, max_detections] float32 tensor with detection scores.
keypoint_scores: A [batch, max_detections, total_num_keypoints] float32
tensor with keypoint scores.
Returns:
A [batch, max_detections] float32 tensor with possibly altered detection
scores.
"""
batch, max_detections, total_num_keypoints = (
shape_utils.combined_static_and_dynamic_shape(keypoint_scores))
classes_tiled = tf.tile(classes[:, :, tf.newaxis],
multiples=[1, 1, total_num_keypoints])
    # TODO(yuhuic): Investigate whether this function will create subgraphs in
    # tflite that will cause the model to run slower at inference.
for kp_params in self._kp_params_dict.values():
if not kp_params.rescore_instances:
continue
class_id = kp_params.class_id
keypoint_indices = kp_params.keypoint_indices
num_keypoints = len(keypoint_indices)
kpt_mask = tf.reduce_sum(
tf.one_hot(keypoint_indices, depth=total_num_keypoints), axis=0)
kpt_mask_tiled = tf.tile(kpt_mask[tf.newaxis, tf.newaxis, :],
multiples=[batch, max_detections, 1])
class_and_keypoint_mask = tf.math.logical_and(
classes_tiled == class_id,
kpt_mask_tiled == 1.0)
class_and_keypoint_mask_float = tf.cast(class_and_keypoint_mask,
dtype=tf.float32)
scores_for_class = (1./num_keypoints) * (
tf.reduce_sum(class_and_keypoint_mask_float *
scores[:, :, tf.newaxis] *
keypoint_scores, axis=-1))
scores = tf.where(classes == class_id,
scores_for_class,
scores)
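      # Worked example (hypothetical values): a detection of this class with
      # score 0.75 whose three keypoints score [0.1, 0.2, 0.3] is rescored to
      # 0.75 * (0.1 + 0.2 + 0.3) / 3 = 0.15; detections of other classes keep
      # their original scores.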
return scores
  def preprocess(self, inputs):
    outputs = shape_utils.resize_images_and_return_shapes(
        inputs, self._image_resizer_fn)
...@@ -3050,26 +3301,40 @@ class CenterNetMetaArch(model.DetectionModel):
      # If the model is trained to predict only one class of object and its
      # keypoint, we fall back to a simpler postprocessing function which uses
      # the ops that are supported by tf.lite on GPU.
      clip_keypoints = self._should_clip_keypoints()
      if len(self._kp_params_dict) == 1 and self._num_classes == 1:
        (keypoints, keypoint_scores,
         keypoint_depths) = self._postprocess_keypoints_single_class(
             prediction_dict, classes, y_indices, x_indices, boxes_strided,
             num_detections)
        keypoints, keypoint_scores = (
            convert_strided_predictions_to_normalized_keypoints(
                keypoints, keypoint_scores, self._stride, true_image_shapes,
                clip_out_of_frame_keypoints=clip_keypoints))
if keypoint_depths is not None:
postprocess_dict.update({
fields.DetectionResultFields.detection_keypoint_depths:
keypoint_depths
})
      else:
# Multi-class keypoint estimation task does not support depth
# estimation.
assert all([
not kp_dict.predict_depth
for kp_dict in self._kp_params_dict.values()
])
        keypoints, keypoint_scores = self._postprocess_keypoints_multi_class(
            prediction_dict, classes, y_indices, x_indices,
            boxes_strided, num_detections)
        keypoints, keypoint_scores = (
            convert_strided_predictions_to_normalized_keypoints(
                keypoints, keypoint_scores, self._stride, true_image_shapes,
                clip_out_of_frame_keypoints=clip_keypoints))
# Update instance scores based on keypoints.
scores = self._rescore_instances(classes, scores, keypoint_scores)
      postprocess_dict.update({
          fields.DetectionResultFields.detection_scores: scores,
          fields.DetectionResultFields.detection_keypoints: keypoints,
          fields.DetectionResultFields.detection_keypoint_scores:
              keypoint_scores
...@@ -3200,7 +3465,11 @@ class CenterNetMetaArch(model.DetectionModel):
        task_name, KEYPOINT_REGRESSION)][-1]
    object_heatmap = tf.nn.sigmoid(prediction_dict[OBJECT_CENTER][-1])
    keypoint_depths = None
    if kp_params.predict_depth:
      keypoint_depths = prediction_dict[get_keypoint_name(
          task_name, KEYPOINT_DEPTH)][-1]
    keypoints, keypoint_scores, keypoint_depths = (
        prediction_to_single_instance_keypoints(
            object_heatmap=object_heatmap,
            keypoint_heatmap=keypoint_heatmap,
...@@ -3209,7 +3478,8 @@ class CenterNetMetaArch(model.DetectionModel):
            stride=self._stride,
            object_center_std_dev=object_center_std_dev,
            keypoint_std_dev=keypoint_std_dev,
            kp_params=kp_params,
            keypoint_depths=keypoint_depths))
    keypoints, keypoint_scores = (
        convert_strided_predictions_to_normalized_keypoints(
...@@ -3222,6 +3492,12 @@ class CenterNetMetaArch(model.DetectionModel):
        fields.DetectionResultFields.detection_keypoints: keypoints,
        fields.DetectionResultFields.detection_keypoint_scores: keypoint_scores
    }
if kp_params.predict_depth:
postprocess_dict.update({
fields.DetectionResultFields.detection_keypoint_depths:
keypoint_depths
})
    return postprocess_dict

  def _postprocess_embeddings(self, prediction_dict, y_indices, x_indices):
...@@ -3316,7 +3592,7 @@ class CenterNetMetaArch(model.DetectionModel):
          # [1, num_instances_i, num_keypoints_i], respectively. Note that
          # num_instances_i and num_keypoints_i refer to the number of
          # instances and keypoints for class i, respectively.
          (kpt_coords_for_class, kpt_scores_for_class, _) = (
              self._postprocess_keypoints_for_class_and_image(
                  keypoint_heatmap, keypoint_offsets, keypoint_regression,
                  classes, y_indices_for_kpt_class, x_indices_for_kpt_class,
...@@ -3426,21 +3702,35 @@ class CenterNetMetaArch(model.DetectionModel):
        get_keypoint_name(task_name, KEYPOINT_OFFSET)][-1]
    keypoint_regression = prediction_dict[
        get_keypoint_name(task_name, KEYPOINT_REGRESSION)][-1]
    keypoint_depth_predictions = None
    if kp_params.predict_depth:
      keypoint_depth_predictions = prediction_dict[get_keypoint_name(
          task_name, KEYPOINT_DEPTH)][-1]
    batch_size, _, _ = _get_shape(boxes, 3)
    kpt_coords_for_example_list = []
    kpt_scores_for_example_list = []
kpt_depths_for_example_list = []
    for ex_ind in range(batch_size):
      # Postprocess keypoints and scores for class and single image. Shapes
      # are [1, max_detections, num_keypoints, 2] and
      # [1, max_detections, num_keypoints], respectively.
      (kpt_coords_for_class, kpt_scores_for_class, kpt_depths_for_class) = (
          self._postprocess_keypoints_for_class_and_image(
              keypoint_heatmap,
              keypoint_offsets,
keypoint_regression,
classes,
y_indices,
x_indices,
boxes,
ex_ind,
kp_params,
keypoint_depth_predictions=keypoint_depth_predictions))
      kpt_coords_for_example_list.append(kpt_coords_for_class)
      kpt_scores_for_example_list.append(kpt_scores_for_class)
kpt_depths_for_example_list.append(kpt_depths_for_class)
    # Concatenate all keypoints and scores from all examples in the batch.
    # Shapes are [batch_size, max_detections, num_keypoints, 2] and
...@@ -3448,7 +3738,11 @@ class CenterNetMetaArch(model.DetectionModel):
    keypoints = tf.concat(kpt_coords_for_example_list, axis=0)
    keypoint_scores = tf.concat(kpt_scores_for_example_list, axis=0)
    keypoint_depths = None
if kp_params.predict_depth:
keypoint_depths = tf.concat(kpt_depths_for_example_list, axis=0)
return keypoints, keypoint_scores, keypoint_depths
  def _get_instance_indices(self, classes, num_detections, batch_index,
                            class_id):
...@@ -3482,8 +3776,17 @@ class CenterNetMetaArch(model.DetectionModel):
    return tf.cast(instance_inds, tf.int32)

  def _postprocess_keypoints_for_class_and_image(
      self,
      keypoint_heatmap,
keypoint_offsets,
keypoint_regression,
classes,
y_indices,
x_indices,
boxes,
batch_index,
kp_params,
keypoint_depth_predictions=None):
"""Postprocess keypoints for a single image and class. """Postprocess keypoints for a single image and class.
Args: Args:
...@@ -3504,6 +3807,8 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3504,6 +3807,8 @@ class CenterNetMetaArch(model.DetectionModel):
batch_index: An integer specifying the index for an example in the batch. batch_index: An integer specifying the index for an example in the batch.
kp_params: A `KeypointEstimationParams` object with parameters for a kp_params: A `KeypointEstimationParams` object with parameters for a
single keypoint class. single keypoint class.
keypoint_depth_predictions: (optional) A [batch_size, height, width, 1]
float32 tensor representing the keypoint depth prediction.
    Returns:
      A tuple of
...@@ -3514,6 +3819,9 @@ class CenterNetMetaArch(model.DetectionModel):
        for the specific class.
      refined_scores: A [1, num_instances, num_keypoints] float32 tensor with
        keypoint scores.
      refined_depths: A [1, num_instances, num_keypoints] float32 tensor with
        keypoint depths. None if the input keypoint_depth_predictions is
        None.
""" """
num_keypoints = len(kp_params.keypoint_indices) num_keypoints = len(kp_params.keypoint_indices)
...@@ -3521,6 +3829,10 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3521,6 +3829,10 @@ class CenterNetMetaArch(model.DetectionModel):
keypoint_heatmap[batch_index:batch_index+1, ...]) keypoint_heatmap[batch_index:batch_index+1, ...])
keypoint_offsets = keypoint_offsets[batch_index:batch_index+1, ...] keypoint_offsets = keypoint_offsets[batch_index:batch_index+1, ...]
keypoint_regression = keypoint_regression[batch_index:batch_index+1, ...] keypoint_regression = keypoint_regression[batch_index:batch_index+1, ...]
keypoint_depths = None
if keypoint_depth_predictions is not None:
keypoint_depths = keypoint_depth_predictions[batch_index:batch_index + 1,
...]
    y_indices = y_indices[batch_index:batch_index+1, ...]
    x_indices = x_indices[batch_index:batch_index+1, ...]
    boxes_slice = boxes[batch_index:batch_index+1, ...]
...@@ -3536,26 +3848,34 @@ class CenterNetMetaArch(model.DetectionModel):
    # The shape of keypoint_candidates and keypoint_scores is:
    # [1, num_candidates_per_keypoint, num_keypoints, 2] and
    # [1, num_candidates_per_keypoint, num_keypoints], respectively.
    (keypoint_candidates, keypoint_scores, num_keypoint_candidates,
     keypoint_depth_candidates) = (
         prediction_tensors_to_keypoint_candidates(
             keypoint_heatmap,
             keypoint_offsets,
             keypoint_score_threshold=(
                 kp_params.keypoint_candidate_score_threshold),
max_pool_kernel_size=kp_params.peak_max_pool_kernel_size,
max_candidates=kp_params.num_candidates_per_keypoint,
keypoint_depths=keypoint_depths))
    # Get the refined keypoints and scores, of shape
    # [1, num_instances, num_keypoints, 2] and
    # [1, num_instances, num_keypoints], respectively.
    (refined_keypoints, refined_scores, refined_depths) = refine_keypoints(
        regressed_keypoints_for_objects,
        keypoint_candidates,
keypoint_scores,
num_keypoint_candidates,
bboxes=boxes_slice,
        unmatched_keypoint_score=kp_params.unmatched_keypoint_score,
        box_scale=kp_params.box_scale,
        candidate_search_scale=kp_params.candidate_search_scale,
        candidate_ranking_mode=kp_params.candidate_ranking_mode,
score_distance_offset=kp_params.score_distance_offset,
keypoint_depth_candidates=keypoint_depth_candidates)
    return refined_keypoints, refined_scores, refined_depths
  def regularization_losses(self):
    return []
...
...@@ -695,7 +695,7 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
    keypoint_heatmap_offsets = tf.constant(
        keypoint_heatmap_offsets_np, dtype=tf.float32)
    (keypoint_cands, keypoint_scores, num_keypoint_candidates, _) = (
        cnma.prediction_tensors_to_keypoint_candidates(
            keypoint_heatmap,
            keypoint_heatmap_offsets,
...@@ -780,7 +780,7 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
    keypoint_regression = tf.constant(
        keypoint_regression_np, dtype=tf.float32)
    (keypoint_cands, keypoint_scores, _) = (
        cnma.prediction_to_single_instance_keypoints(
            object_heatmap,
            keypoint_heatmap,
...@@ -839,7 +839,7 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
    keypoint_heatmap_offsets = tf.constant(
        keypoint_heatmap_offsets_np, dtype=tf.float32)
    (keypoint_cands, keypoint_scores, num_keypoint_candidates, _) = (
        cnma.prediction_tensors_to_keypoint_candidates(
            keypoint_heatmap,
            keypoint_heatmap_offsets,
...@@ -880,6 +880,89 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
    np.testing.assert_array_equal(expected_num_keypoint_candidates,
                                  num_keypoint_candidates)
@parameterized.parameters({'per_keypoint_depth': True},
{'per_keypoint_depth': False})
def test_keypoint_candidate_prediction_depth(self, per_keypoint_depth):
keypoint_heatmap_np = np.zeros((2, 3, 3, 2), dtype=np.float32)
keypoint_heatmap_np[0, 0, 0, 0] = 1.0
keypoint_heatmap_np[0, 2, 1, 0] = 0.7
keypoint_heatmap_np[0, 1, 1, 0] = 0.6
keypoint_heatmap_np[0, 0, 2, 1] = 0.7
keypoint_heatmap_np[0, 1, 1, 1] = 0.3 # Filtered by low score.
keypoint_heatmap_np[0, 2, 2, 1] = 0.2
keypoint_heatmap_np[1, 1, 0, 0] = 0.6
keypoint_heatmap_np[1, 2, 1, 0] = 0.5
keypoint_heatmap_np[1, 0, 0, 0] = 0.4
keypoint_heatmap_np[1, 0, 0, 1] = 1.0
keypoint_heatmap_np[1, 0, 1, 1] = 0.9
keypoint_heatmap_np[1, 2, 0, 1] = 0.8
if per_keypoint_depth:
keypoint_depths_np = np.zeros((2, 3, 3, 2), dtype=np.float32)
keypoint_depths_np[0, 0, 0, 0] = -1.5
keypoint_depths_np[0, 2, 1, 0] = -1.0
keypoint_depths_np[0, 0, 2, 1] = 1.5
else:
keypoint_depths_np = np.zeros((2, 3, 3, 1), dtype=np.float32)
keypoint_depths_np[0, 0, 0, 0] = -1.5
keypoint_depths_np[0, 2, 1, 0] = -1.0
keypoint_depths_np[0, 0, 2, 0] = 1.5
keypoint_heatmap_offsets_np = np.zeros((2, 3, 3, 2), dtype=np.float32)
keypoint_heatmap_offsets_np[0, 0, 0] = [0.5, 0.25]
keypoint_heatmap_offsets_np[0, 2, 1] = [-0.25, 0.5]
keypoint_heatmap_offsets_np[0, 1, 1] = [0.0, 0.0]
keypoint_heatmap_offsets_np[0, 0, 2] = [1.0, 0.0]
keypoint_heatmap_offsets_np[0, 2, 2] = [1.0, 1.0]
keypoint_heatmap_offsets_np[1, 1, 0] = [0.25, 0.5]
keypoint_heatmap_offsets_np[1, 2, 1] = [0.5, 0.0]
keypoint_heatmap_offsets_np[1, 0, 0] = [0.0, -0.5]
keypoint_heatmap_offsets_np[1, 0, 1] = [0.5, -0.5]
keypoint_heatmap_offsets_np[1, 2, 0] = [-1.0, -0.5]
def graph_fn():
keypoint_heatmap = tf.constant(keypoint_heatmap_np, dtype=tf.float32)
keypoint_heatmap_offsets = tf.constant(
keypoint_heatmap_offsets_np, dtype=tf.float32)
keypoint_depths = tf.constant(keypoint_depths_np, dtype=tf.float32)
(keypoint_cands, keypoint_scores, num_keypoint_candidates,
keypoint_depths) = (
cnma.prediction_tensors_to_keypoint_candidates(
keypoint_heatmap,
keypoint_heatmap_offsets,
keypoint_score_threshold=0.5,
max_pool_kernel_size=1,
max_candidates=2,
keypoint_depths=keypoint_depths))
return (keypoint_cands, keypoint_scores, num_keypoint_candidates,
keypoint_depths)
(_, keypoint_scores, _, keypoint_depths) = self.execute(graph_fn, [])
expected_keypoint_scores = [
[ # Example 0.
[1.0, 0.7], # Keypoint 1.
[0.7, 0.3], # Keypoint 2.
],
[ # Example 1.
[0.6, 1.0], # Keypoint 1.
[0.5, 0.9], # Keypoint 2.
],
]
expected_keypoint_depths = [
[
[-1.5, 1.5],
[-1.0, 0.0],
],
[
[0., 0.],
[0., 0.],
],
]
np.testing.assert_allclose(expected_keypoint_scores, keypoint_scores)
np.testing.assert_allclose(expected_keypoint_depths, keypoint_depths)
  def test_regressed_keypoints_at_object_centers(self):
    batch_size = 2
    num_keypoints = 5
...@@ -985,11 +1068,15 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
      keypoint_scores = tf.constant(keypoint_scores_np, dtype=tf.float32)
      num_keypoint_candidates = tf.constant(num_keypoints_candidates_np,
                                            dtype=tf.int32)
      (refined_keypoints, refined_scores, _) = cnma.refine_keypoints(
          regressed_keypoints,
          keypoint_candidates,
          keypoint_scores,
          num_keypoint_candidates,
          bboxes=None,
          unmatched_keypoint_score=unmatched_keypoint_score,
          box_scale=1.2,
          candidate_search_scale=0.3,
          candidate_ranking_mode=candidate_ranking_mode)
      return refined_keypoints, refined_scores
...@@ -1057,7 +1144,8 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
    np.testing.assert_allclose(expected_refined_keypoints, refined_keypoints)
    np.testing.assert_allclose(expected_refined_scores, refined_scores)

  @parameterized.parameters({'predict_depth': True}, {'predict_depth': False})
  def test_refine_keypoints_with_bboxes(self, predict_depth):
    regressed_keypoints_np = np.array(
        [
            # Example 0.
...@@ -1096,7 +1184,22 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
                [0.7, 0.4, 0.0],  # Candidate 0.
                [0.6, 0.1, 0.0],  # Candidate 1.
            ]
        ],
        dtype=np.float32)
keypoint_depths_np = np.array(
[
# Example 0.
[
[-0.8, -0.9, -1.0], # Candidate 0.
[-0.6, -0.1, -0.9], # Candidate 1.
],
# Example 1.
[
[-0.7, -0.4, -0.0], # Candidate 0.
[-0.6, -0.1, -0.0], # Candidate 1.
]
],
dtype=np.float32)
    num_keypoints_candidates_np = np.array(
        [
            # Example 0.
...@@ -1125,17 +1228,28 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
      keypoint_candidates = tf.constant(
          keypoint_candidates_np, dtype=tf.float32)
      keypoint_scores = tf.constant(keypoint_scores_np, dtype=tf.float32)
if predict_depth:
keypoint_depths = tf.constant(keypoint_depths_np, dtype=tf.float32)
else:
keypoint_depths = None
      num_keypoint_candidates = tf.constant(num_keypoints_candidates_np,
                                            dtype=tf.int32)
      bboxes = tf.constant(bboxes_np, dtype=tf.float32)
      (refined_keypoints, refined_scores,
       refined_depths) = cnma.refine_keypoints(
           regressed_keypoints,
           keypoint_candidates,
           keypoint_scores,
           num_keypoint_candidates,
           bboxes=bboxes,
           unmatched_keypoint_score=unmatched_keypoint_score,
           box_scale=1.0,
           candidate_search_scale=0.3,
           keypoint_depth_candidates=keypoint_depths)
      if predict_depth:
        return refined_keypoints, refined_scores, refined_depths
      else:
        return refined_keypoints, refined_scores
    expected_refined_keypoints = np.array(
        [
...@@ -1166,8 +1280,17 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
            ],
        ], dtype=np.float32)
    if predict_depth:
      refined_keypoints, refined_scores, refined_depths = self.execute(
          graph_fn, [])
      expected_refined_depths = np.array([[[-0.8, 0.0, 0.0], [0.0, 0.0, -1.0]],
                                          [[-0.7, -0.1, 0.0], [-0.7, -0.4,
                                                               0.0]]])
      np.testing.assert_allclose(expected_refined_depths, refined_depths)
    else:
      refined_keypoints, refined_scores = self.execute(graph_fn, [])
    np.testing.assert_allclose(expected_refined_keypoints, refined_keypoints)
    np.testing.assert_allclose(expected_refined_scores, refined_scores)
  def test_pad_to_full_keypoint_dim(self):
    batch_size = 4
...@@ -1296,7 +1419,11 @@ def get_fake_od_params():
      scale_loss_weight=0.1)


def get_fake_kp_params(num_candidates_per_keypoint=100,
per_keypoint_offset=False,
predict_depth=False,
per_keypoint_depth=False,
peak_radius=0):
"""Returns the fake keypoint estimation parameter namedtuple.""" """Returns the fake keypoint estimation parameter namedtuple."""
return cnma.KeypointEstimationParams( return cnma.KeypointEstimationParams(
task_name=_TASK_NAME, task_name=_TASK_NAME,
...@@ -1306,7 +1433,11 @@ def get_fake_kp_params(num_candidates_per_keypoint=100): ...@@ -1306,7 +1433,11 @@ def get_fake_kp_params(num_candidates_per_keypoint=100):
classification_loss=losses.WeightedSigmoidClassificationLoss(), classification_loss=losses.WeightedSigmoidClassificationLoss(),
localization_loss=losses.L1LocalizationLoss(), localization_loss=losses.L1LocalizationLoss(),
keypoint_candidate_score_threshold=0.1, keypoint_candidate_score_threshold=0.1,
num_candidates_per_keypoint=num_candidates_per_keypoint) num_candidates_per_keypoint=num_candidates_per_keypoint,
per_keypoint_offset=per_keypoint_offset,
predict_depth=predict_depth,
per_keypoint_depth=per_keypoint_depth,
offset_peak_radius=peak_radius)
def get_fake_mask_params():
...@@ -1353,7 +1484,11 @@ def build_center_net_meta_arch(build_resnet=False,
                               num_classes=_NUM_CLASSES,
                               max_box_predictions=5,
                               apply_non_max_suppression=False,
                               detection_only=False,
per_keypoint_offset=False,
predict_depth=False,
per_keypoint_depth=False,
peak_radius=0):
"""Builds the CenterNet meta architecture.""" """Builds the CenterNet meta architecture."""
if build_resnet: if build_resnet:
feature_extractor = ( feature_extractor = (
...@@ -1407,7 +1542,10 @@ def build_center_net_meta_arch(build_resnet=False, ...@@ -1407,7 +1542,10 @@ def build_center_net_meta_arch(build_resnet=False,
object_center_params=get_fake_center_params(max_box_predictions), object_center_params=get_fake_center_params(max_box_predictions),
object_detection_params=get_fake_od_params(), object_detection_params=get_fake_od_params(),
keypoint_params_dict={ keypoint_params_dict={
_TASK_NAME: get_fake_kp_params(num_candidates_per_keypoint) _TASK_NAME:
get_fake_kp_params(num_candidates_per_keypoint,
per_keypoint_offset, predict_depth,
per_keypoint_depth, peak_radius)
        },
        non_max_suppression_fn=non_max_suppression_fn)
  else:
...@@ -1992,6 +2130,84 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
    self.assertAllEqual([1, 1, num_keypoints],
                        detections['detection_keypoint_scores'].shape)
@parameterized.parameters(
{'per_keypoint_depth': False},
{'per_keypoint_depth': True},
)
def test_postprocess_single_class_depth(self, per_keypoint_depth):
"""Test the postprocess function."""
model = build_center_net_meta_arch(
num_classes=1,
per_keypoint_offset=per_keypoint_depth,
predict_depth=True,
per_keypoint_depth=per_keypoint_depth)
num_keypoints = len(model._kp_params_dict[_TASK_NAME].keypoint_indices)
class_center = np.zeros((1, 32, 32, 1), dtype=np.float32)
height_width = np.zeros((1, 32, 32, 2), dtype=np.float32)
offset = np.zeros((1, 32, 32, 2), dtype=np.float32)
keypoint_heatmaps = np.zeros((1, 32, 32, num_keypoints), dtype=np.float32)
keypoint_offsets = np.zeros((1, 32, 32, 2), dtype=np.float32)
keypoint_regression = np.random.randn(1, 32, 32, num_keypoints * 2)
class_probs = np.zeros(1)
class_probs[0] = _logit(0.75)
class_center[0, 16, 16] = class_probs
height_width[0, 16, 16] = [5, 10]
offset[0, 16, 16] = [.25, .5]
keypoint_regression[0, 16, 16] = [-1., -1., -1., 1., 1., -1., 1., 1.]
keypoint_heatmaps[0, 14, 14, 0] = _logit(0.9)
keypoint_heatmaps[0, 14, 18, 1] = _logit(0.9)
keypoint_heatmaps[0, 18, 14, 2] = _logit(0.9)
keypoint_heatmaps[0, 18, 18, 3] = _logit(0.05) # Note the low score.
if per_keypoint_depth:
keypoint_depth = np.zeros((1, 32, 32, num_keypoints), dtype=np.float32)
keypoint_depth[0, 14, 14, 0] = -1.0
keypoint_depth[0, 14, 18, 1] = -1.1
keypoint_depth[0, 18, 14, 2] = -1.2
keypoint_depth[0, 18, 18, 3] = -1.3
else:
keypoint_depth = np.zeros((1, 32, 32, 1), dtype=np.float32)
keypoint_depth[0, 14, 14, 0] = -1.0
keypoint_depth[0, 14, 18, 0] = -1.1
keypoint_depth[0, 18, 14, 0] = -1.2
keypoint_depth[0, 18, 18, 0] = -1.3
class_center = tf.constant(class_center)
height_width = tf.constant(height_width)
offset = tf.constant(offset)
keypoint_heatmaps = tf.constant(keypoint_heatmaps, dtype=tf.float32)
keypoint_offsets = tf.constant(keypoint_offsets, dtype=tf.float32)
keypoint_regression = tf.constant(keypoint_regression, dtype=tf.float32)
keypoint_depth = tf.constant(keypoint_depth, dtype=tf.float32)
prediction_dict = {
cnma.OBJECT_CENTER: [class_center],
cnma.BOX_SCALE: [height_width],
cnma.BOX_OFFSET: [offset],
cnma.get_keypoint_name(_TASK_NAME,
cnma.KEYPOINT_HEATMAP): [keypoint_heatmaps],
cnma.get_keypoint_name(_TASK_NAME,
cnma.KEYPOINT_OFFSET): [keypoint_offsets],
cnma.get_keypoint_name(_TASK_NAME,
cnma.KEYPOINT_REGRESSION): [keypoint_regression],
cnma.get_keypoint_name(_TASK_NAME,
cnma.KEYPOINT_DEPTH): [keypoint_depth]
}
def graph_fn():
detections = model.postprocess(prediction_dict,
tf.constant([[128, 128, 3]]))
return detections
detections = self.execute_cpu(graph_fn, [])
self.assertAllClose(detections['detection_keypoint_depths'][0, 0],
np.array([-1.0, -1.1, -1.2, 0.0]))
self.assertAllClose(detections['detection_keypoint_scores'][0, 0],
np.array([0.9, 0.9, 0.9, 0.1]))
  def test_get_instance_indices(self):
    classes = tf.constant([[0, 1, 2, 0], [2, 1, 2, 2]], dtype=tf.int32)
    num_detections = tf.constant([1, 3], dtype=tf.int32)
...@@ -2002,8 +2218,72 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
        classes, num_detections, batch_index, class_id)
    self.assertAllEqual(valid_indices.numpy(), [0, 2])
def test_rescore_instances(self):
feature_extractor = DummyFeatureExtractor(
channel_means=(1.0, 2.0, 3.0),
channel_stds=(10., 20., 30.),
bgr_ordering=False,
num_feature_outputs=2,
stride=4)
image_resizer_fn = functools.partial(
preprocessor.resize_to_range,
min_dimension=128,
max_dimension=128,
        pad_to_max_dimension=True)
kp_params_1 = cnma.KeypointEstimationParams(
task_name='kpt_task_1',
class_id=0,
keypoint_indices=[0, 1, 2],
keypoint_std_dev=[0.00001] * 3,
classification_loss=losses.WeightedSigmoidClassificationLoss(),
localization_loss=losses.L1LocalizationLoss(),
keypoint_candidate_score_threshold=0.1,
rescore_instances=True) # Note rescoring for class_id = 0.
kp_params_2 = cnma.KeypointEstimationParams(
task_name='kpt_task_2',
class_id=1,
keypoint_indices=[3, 4],
keypoint_std_dev=[0.00001] * 2,
classification_loss=losses.WeightedSigmoidClassificationLoss(),
localization_loss=losses.L1LocalizationLoss(),
keypoint_candidate_score_threshold=0.1,
rescore_instances=False)
model = cnma.CenterNetMetaArch(
is_training=True,
add_summaries=False,
num_classes=2,
feature_extractor=feature_extractor,
image_resizer_fn=image_resizer_fn,
object_center_params=get_fake_center_params(),
object_detection_params=get_fake_od_params(),
keypoint_params_dict={
'kpt_task_1': kp_params_1,
'kpt_task_2': kp_params_2,
})
    def graph_fn():
classes = tf.constant([[1, 0]], dtype=tf.int32)
scores = tf.constant([[0.5, 0.75]], dtype=tf.float32)
keypoint_scores = tf.constant(
[
[[0.1, 0.2, 0.3, 0.4, 0.5],
[0.1, 0.2, 0.3, 0.4, 0.5]],
])
new_scores = model._rescore_instances(classes, scores, keypoint_scores)
return new_scores
new_scores = self.execute_cpu(graph_fn, [])
expected_scores = np.array(
[[0.5, 0.75 * (0.1 + 0.2 + 0.3)/3]]
)
self.assertAllClose(expected_scores, new_scores)
def get_fake_prediction_dict(input_height,
input_width,
stride,
per_keypoint_depth=False):
"""Prepares the fake prediction dictionary.""" """Prepares the fake prediction dictionary."""
output_height = input_height // stride output_height = input_height // stride
output_width = input_width // stride output_width = input_width // stride
...@@ -2038,6 +2318,11 @@ def get_fake_prediction_dict(input_height, input_width, stride): ...@@ -2038,6 +2318,11 @@ def get_fake_prediction_dict(input_height, input_width, stride):
dtype=np.float32) dtype=np.float32)
keypoint_offset[0, 2, 4] = 0.2, 0.4 keypoint_offset[0, 2, 4] = 0.2, 0.4
keypoint_depth = np.zeros((2, output_height, output_width,
_NUM_KEYPOINTS if per_keypoint_depth else 1),
dtype=np.float32)
keypoint_depth[0, 2, 4] = 3.0
  keypoint_regression = np.zeros(
      (2, output_height, output_width, 2 * _NUM_KEYPOINTS), dtype=np.float32)
  keypoint_regression[0, 2, 4] = 0.0, 0.0, 0.2, 0.4, 0.0, 0.0, 0.2, 0.4
...@@ -2073,14 +2358,10 @@ def get_fake_prediction_dict(input_height, input_width, stride):
          tf.constant(object_center),
          tf.constant(object_center)
      ],
      cnma.BOX_SCALE: [tf.constant(object_scale),
                       tf.constant(object_scale)],
      cnma.BOX_OFFSET: [tf.constant(object_offset),
                        tf.constant(object_offset)],
      cnma.get_keypoint_name(_TASK_NAME, cnma.KEYPOINT_HEATMAP): [
          tf.constant(keypoint_heatmap),
          tf.constant(keypoint_heatmap)
...@@ -2093,6 +2374,10 @@ def get_fake_prediction_dict(input_height, input_width, stride):
          tf.constant(keypoint_regression),
          tf.constant(keypoint_regression)
      ],
cnma.get_keypoint_name(_TASK_NAME, cnma.KEYPOINT_DEPTH): [
tf.constant(keypoint_depth),
tf.constant(keypoint_depth)
],
      cnma.SEGMENTATION_HEATMAP: [
          tf.constant(mask_heatmap),
          tf.constant(mask_heatmap)
...@@ -2117,7 +2402,10 @@ def get_fake_prediction_dict(input_height, input_width, stride):
  return prediction_dict


def get_fake_groundtruth_dict(input_height,
                              input_width,
                              stride,
                              has_depth=False):
  """Prepares the fake groundtruth dictionary."""
  # A small box with center at (0.55, 0.55).
  boxes = [
...@@ -2146,6 +2434,26 @@ def get_fake_groundtruth_dict(input_height, input_width, stride):
          axis=2),
      multiples=[1, 1, 2]),
  ]
if has_depth:
keypoint_depths = [
tf.constant([[float('nan'), 3.0,
float('nan'), 3.0, 0.55, 0.0]]),
tf.constant([[float('nan'), 0.55,
float('nan'), 0.55, 0.55, 0.0]])
]
keypoint_depth_weights = [
tf.constant([[1.0, 1.0, 1.0, 1.0, 0.0, 0.0]]),
tf.constant([[1.0, 1.0, 1.0, 1.0, 0.0, 0.0]])
]
else:
keypoint_depths = [
tf.constant([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]),
tf.constant([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])
]
keypoint_depth_weights = [
tf.constant([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]),
tf.constant([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])
]
  labeled_classes = [
      tf.one_hot([1], depth=_NUM_CLASSES) + tf.one_hot([2], depth=_NUM_CLASSES),
      tf.one_hot([0], depth=_NUM_CLASSES) + tf.one_hot([1], depth=_NUM_CLASSES),
...@@ -2187,11 +2495,12 @@ def get_fake_groundtruth_dict(input_height, input_width, stride):
      fields.BoxListFields.weights: weights,
      fields.BoxListFields.classes: classes,
      fields.BoxListFields.keypoints: keypoints,
fields.BoxListFields.keypoint_depths: keypoint_depths,
fields.BoxListFields.keypoint_depth_weights: keypoint_depth_weights,
      fields.BoxListFields.masks: masks,
      fields.BoxListFields.densepose_num_points: densepose_num_points,
      fields.BoxListFields.densepose_part_ids: densepose_part_ids,
      fields.BoxListFields.densepose_surface_coords: densepose_surface_coords,
      fields.BoxListFields.track_ids: track_ids,
      fields.BoxListFields.temporal_offsets: temporal_offsets,
      fields.BoxListFields.track_match_flags: track_match_flags,
...@@ -2201,7 +2510,7 @@ def get_fake_groundtruth_dict(input_height, input_width, stride):
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class CenterNetMetaComputeLossTest(test_case.TestCase, parameterized.TestCase):
  """Test for CenterNet loss computation related functions."""

  def setUp(self):
...@@ -2328,6 +2637,45 @@ class CenterNetMetaComputeLossTest(test_case.TestCase):
    # The prediction and groundtruth are curated to produce very low loss.
    self.assertGreater(0.01, loss)
@parameterized.parameters(
{'per_keypoint_depth': False},
{'per_keypoint_depth': True},
)
def test_compute_kp_depth_loss(self, per_keypoint_depth):
prediction_dict = get_fake_prediction_dict(
self.input_height,
self.input_width,
self.stride,
per_keypoint_depth=per_keypoint_depth)
model = build_center_net_meta_arch(
num_classes=1,
per_keypoint_offset=per_keypoint_depth,
predict_depth=True,
per_keypoint_depth=per_keypoint_depth,
peak_radius=1 if per_keypoint_depth else 0)
model._groundtruth_lists = get_fake_groundtruth_dict(
self.input_height, self.input_width, self.stride, has_depth=True)
def graph_fn():
loss = model._compute_kp_depth_loss(
input_height=self.input_height,
input_width=self.input_width,
task_name=_TASK_NAME,
depth_predictions=prediction_dict[cnma.get_keypoint_name(
_TASK_NAME, cnma.KEYPOINT_DEPTH)],
localization_loss_fn=self.localization_loss_fn)
return loss
loss = self.execute(graph_fn, [])
if per_keypoint_depth:
# The loss is computed on a disk with radius 1 but only the center pixel
# has the accurate prediction. The final loss is (4 * |3-0|) / 5 = 2.4
self.assertAlmostEqual(2.4, loss, delta=1e-4)
else:
# The prediction and groundtruth are curated to produce very low loss.
self.assertGreater(0.01, loss)
  def test_compute_track_embedding_loss(self):
    default_fc = self.model.track_reid_classification_net
    # Initialize the kernel to extreme values so that the classification score
...
...@@ -514,6 +514,13 @@ def train_loop(
  with strategy.scope():
    detection_model = MODEL_BUILD_UTIL_MAP['detection_model_fn_base'](
        model_config=model_config, is_training=True)
# We run the detection_model on dummy inputs in order to ensure that the
# model and all its variables have been properly constructed. Specifically,
# this is currently necessary prior to (potentially) creating shadow copies
# of the model variables for the EMA optimizer.
dummy_image, dummy_shapes = detection_model.preprocess(
tf.zeros([1, 512, 512, 3], dtype=tf.float32))
dummy_prediction_dict = detection_model.predict(dummy_image, dummy_shapes)
  def train_dataset_fn(input_context):
    """Callable to create train input."""
...@@ -536,6 +543,8 @@ def train_loop(
        aggregation=tf.compat.v2.VariableAggregation.ONLY_FIRST_REPLICA)
    optimizer, (learning_rate,) = optimizer_builder.build(
        train_config.optimizer, global_step=global_step)
if train_config.optimizer.use_moving_average:
optimizer.shadow_copy(detection_model)
    if callable(learning_rate):
      learning_rate_fn = learning_rate
...@@ -1057,6 +1066,13 @@ def eval_continuously(
  with strategy.scope():
    detection_model = MODEL_BUILD_UTIL_MAP['detection_model_fn_base'](
        model_config=model_config, is_training=True)
# We run the detection_model on dummy inputs in order to ensure that the
# model and all its variables have been properly constructed. Specifically,
# this is currently necessary prior to (potentially) creating shadow copies
# of the model variables for the EMA optimizer.
dummy_image, dummy_shapes = detection_model.preprocess(
tf.zeros([1, 512, 512, 3], dtype=tf.float32))
dummy_prediction_dict = detection_model.predict(dummy_image, dummy_shapes)
  eval_input = strategy.experimental_distribute_dataset(
      inputs.eval_input(
...@@ -1068,13 +1084,22 @@ def eval_continuously(
  global_step = tf.compat.v2.Variable(
      0, trainable=False, dtype=tf.compat.v2.dtypes.int64)
optimizer, _ = optimizer_builder.build(
configs['train_config'].optimizer, global_step=global_step)
  for latest_checkpoint in tf.train.checkpoints_iterator(
      checkpoint_dir, timeout=timeout, min_interval_secs=wait_interval):
    ckpt = tf.compat.v2.train.Checkpoint(
        step=global_step, model=detection_model, optimizer=optimizer)
if eval_config.use_moving_averages:
optimizer.shadow_copy(detection_model)
    ckpt.restore(latest_checkpoint).expect_partial()
if eval_config.use_moving_averages:
optimizer.swap_weights()
    summary_writer = tf.compat.v2.summary.create_file_writer(
        os.path.join(model_dir, 'eval', eval_input_config.name))
    with summary_writer.as_default():
...
...@@ -153,6 +153,12 @@ message CenterNet {
    // the keypoint candidate.
    optional string candidate_ranking_mode = 16 [default = "min_distance"];
// The score distance ratio offset, only used if candidate_ranking_mode is
// 'score_distance_ratio'. The offset is used in the maximization of score
// distance ratio, defined as:
// keypoint_score / (distance + score_distance_offset)
optional float score_distance_offset = 22 [default = 1.0];
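    // For example, with the default offset of 1.0, a candidate with keypoint
    // score 0.9 at distance 2.0 from the regressed keypoint is ranked by
    // 0.9 / (2.0 + 1.0) = 0.3.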
    // The radius (in the unit of output pixel) around heatmap peak to assign
    // the offset targets. If set to 0, the offset target will only be
    // assigned to the heatmap peak (same behavior as the original paper).
...@@ -165,6 +171,34 @@ message CenterNet {
    // out_height, out_width, 2 * num_keypoints] (recommended when the
    // offset_peak_radius is not zero).
    optional bool per_keypoint_offset = 18 [default = false];
    // Indicates whether to predict the depth of each keypoint. Note that this
    // is only supported in the single-class keypoint task.
optional bool predict_depth = 19 [default = false];
// Indicates whether to predict depths for each keypoint channel
// separately. If set False, the output depth target has the shape
// [batch_size, out_height, out_width, 1]. If set True, the output depth
// target has the shape [batch_size, out_height, out_width,
    // num_keypoints]. It is recommended to set this field and
    // "per_keypoint_offset" to true at the same time.
optional bool per_keypoint_depth = 20 [default = false];
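    // For example, with 4 keypoints, setting this field to true yields a
    // depth head of shape [batch_size, out_height, out_width, 4] instead of
    // [batch_size, out_height, out_width, 1].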
// The weight of the keypoint depth loss.
optional float keypoint_depth_loss_weight = 21 [default = 1.0];
// Whether keypoints outside the image frame should be clipped back to the
// image boundary. If true, the keypoints that are clipped have scores set
// to 0.0.
optional bool clip_out_of_frame_keypoints = 23 [default = false];
// Whether instances should be rescored based on keypoint confidences. If
// False, will use the detection score (from the object center heatmap). If
// True, will compute new scores with:
// new_score = o * (1/k) sum {s_i}
// where o is the object score, s_i is the score for keypoint i, and k is
// the number of keypoints for that class.
optional bool rescore_instances = 24 [default = false];
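    // For example, an instance with object score o = 0.75 and keypoint
    // scores {0.1, 0.2, 0.3} is rescored to
    // 0.75 * (0.1 + 0.2 + 0.3) / 3 = 0.15.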
  }

  repeated KeypointEstimation keypoint_estimation_task = 7;
...@@ -278,7 +312,6 @@ message CenterNet {
  // from CenterNet. Use this optional parameter to apply traditional non max
  // suppression and score thresholding.
  optional PostProcessing post_processing = 24;
}

message CenterNetFeatureExtractor {
...
...@@ -19,9 +19,9 @@ from __future__ import division
from __future__ import print_function

import os

from google.protobuf import text_format
import tensorflow.compat.v1 as tf

from tensorflow.python.lib.io import file_io
...@@ -623,6 +623,20 @@ def _maybe_update_config_with_key_value(configs, key, value):
    _update_num_classes(configs["model"], value)
  elif field_name == "sample_from_datasets_weights":
    _update_sample_from_datasets_weights(configs["train_input_config"], value)
elif field_name == "peak_max_pool_kernel_size":
_update_peak_max_pool_kernel_size(configs["model"], value)
elif field_name == "candidate_search_scale":
_update_candidate_search_scale(configs["model"], value)
elif field_name == "candidate_ranking_mode":
_update_candidate_ranking_mode(configs["model"], value)
elif field_name == "score_distance_offset":
_update_score_distance_offset(configs["model"], value)
elif field_name == "box_scale":
_update_box_scale(configs["model"], value)
elif field_name == "keypoint_candidate_score_threshold":
_update_keypoint_candidate_score_threshold(configs["model"], value)
elif field_name == "rescore_instances":
_update_rescore_instances(configs["model"], value)
  else:
    return False
  return True
...@@ -1089,3 +1103,99 @@ def _update_sample_from_datasets_weights(input_reader_config, weights):
  del input_reader_config.sample_from_datasets_weights[:]
  input_reader_config.sample_from_datasets_weights.extend(weights)

def _update_peak_max_pool_kernel_size(model_config, kernel_size):
  """Updates the max pool kernel size (NMS) for keypoints in CenterNet."""
  meta_architecture = model_config.WhichOneof("model")
  if meta_architecture == "center_net":
    if len(model_config.center_net.keypoint_estimation_task) == 1:
      kpt_estimation_task = model_config.center_net.keypoint_estimation_task[0]
      kpt_estimation_task.peak_max_pool_kernel_size = kernel_size
    else:
      tf.logging.warning("Ignoring config override key for "
                         "peak_max_pool_kernel_size since there are multiple "
                         "keypoint estimation tasks")

def _update_candidate_search_scale(model_config, search_scale):
  """Updates the keypoint candidate search scale in CenterNet."""
  meta_architecture = model_config.WhichOneof("model")
  if meta_architecture == "center_net":
    if len(model_config.center_net.keypoint_estimation_task) == 1:
      kpt_estimation_task = model_config.center_net.keypoint_estimation_task[0]
      kpt_estimation_task.candidate_search_scale = search_scale
    else:
      tf.logging.warning("Ignoring config override key for "
                         "candidate_search_scale since there are multiple "
                         "keypoint estimation tasks")

def _update_candidate_ranking_mode(model_config, mode):
  """Updates how keypoints are snapped to candidates in CenterNet."""
  if mode not in ("min_distance", "score_distance_ratio"):
    raise ValueError("Attempting to set the keypoint candidate ranking mode "
                     "to {}, but the only options are 'min_distance' and "
                     "'score_distance_ratio'.".format(mode))
  meta_architecture = model_config.WhichOneof("model")
  if meta_architecture == "center_net":
    if len(model_config.center_net.keypoint_estimation_task) == 1:
      kpt_estimation_task = model_config.center_net.keypoint_estimation_task[0]
      kpt_estimation_task.candidate_ranking_mode = mode
    else:
      tf.logging.warning("Ignoring config override key for "
                         "candidate_ranking_mode since there are multiple "
                         "keypoint estimation tasks")

def _update_score_distance_offset(model_config, offset):
  """Updates the keypoint candidate selection metric. See CenterNet proto."""
  meta_architecture = model_config.WhichOneof("model")
  if meta_architecture == "center_net":
    if len(model_config.center_net.keypoint_estimation_task) == 1:
      kpt_estimation_task = model_config.center_net.keypoint_estimation_task[0]
      kpt_estimation_task.score_distance_offset = offset
    else:
      tf.logging.warning("Ignoring config override key for "
                         "score_distance_offset since there are multiple "
                         "keypoint estimation tasks")

def _update_box_scale(model_config, box_scale):
  """Updates the keypoint candidate search region. See CenterNet proto."""
  meta_architecture = model_config.WhichOneof("model")
  if meta_architecture == "center_net":
    if len(model_config.center_net.keypoint_estimation_task) == 1:
      kpt_estimation_task = model_config.center_net.keypoint_estimation_task[0]
      kpt_estimation_task.box_scale = box_scale
    else:
      tf.logging.warning("Ignoring config override key for box_scale since "
                         "there are multiple keypoint estimation tasks")

def _update_keypoint_candidate_score_threshold(model_config, threshold):
  """Updates the keypoint candidate score threshold. See CenterNet proto."""
  meta_architecture = model_config.WhichOneof("model")
  if meta_architecture == "center_net":
    if len(model_config.center_net.keypoint_estimation_task) == 1:
      kpt_estimation_task = model_config.center_net.keypoint_estimation_task[0]
      kpt_estimation_task.keypoint_candidate_score_threshold = threshold
    else:
      tf.logging.warning("Ignoring config override key for "
                         "keypoint_candidate_score_threshold since there are "
                         "multiple keypoint estimation tasks")

def _update_rescore_instances(model_config, should_rescore):
  """Updates whether boxes should be rescored based on keypoint confidences."""
  if isinstance(should_rescore, str):
    # Override values may arrive as strings; anything other than "True" is
    # treated as False.
    should_rescore = should_rescore == "True"
  meta_architecture = model_config.WhichOneof("model")
  if meta_architecture == "center_net":
    if len(model_config.center_net.keypoint_estimation_task) == 1:
      kpt_estimation_task = model_config.center_net.keypoint_estimation_task[0]
      kpt_estimation_task.rescore_instances = should_rescore
    else:
      tf.logging.warning("Ignoring config override key for "
                         "rescore_instances since there are multiple keypoint "
                         "estimation tasks")
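To make the single-task guard shared by the helpers above concrete, here is a hedged sketch (it assumes the object_detection pipeline_pb2 protos and runs inside this module; the scale values are illustrative):

# Sketch of the single-task guard used by the helpers above. Assumes the
# object_detection pipeline_pb2 protos are available; values are made up.
from object_detection.protos import pipeline_pb2

pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
center_net = pipeline_config.model.center_net
center_net.keypoint_estimation_task.add()
# Exactly one keypoint task: the override is applied.
_update_candidate_search_scale(pipeline_config.model, search_scale=0.5)
center_net.keypoint_estimation_task.add()
# Two keypoint tasks: the override is skipped with a warning.
_update_candidate_search_scale(pipeline_config.model, search_scale=0.75)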