Commit 02828808 authored by Yu-hui Chen's avatar Yu-hui Chen Committed by TF Object Detection Team
Browse files

Updated the single instance postprocessing logics such that it uses the score...

Updated the single-instance postprocessing logic so that it uses the score-to-distance-ratio scoring method. The new logic also avoids using expensive ops, e.g. reduce_max/maximum.

PiperOrigin-RevId: 364607976
parent 9d1a6927
...@@ -545,13 +545,135 @@ def prediction_tensors_to_keypoint_candidates(keypoint_heatmap_predictions, ...@@ -545,13 +545,135 @@ def prediction_tensors_to_keypoint_candidates(keypoint_heatmap_predictions,
return keypoint_candidates, keypoint_scores, num_candidates, depth_candidates return keypoint_candidates, keypoint_scores, num_candidates, depth_candidates
def argmax_feature_map_locations(feature_map):
  """Returns the per-channel peak (argmax) locations in the feature map.

  Args:
    feature_map: A float tensor of shape [batch_size, height, width,
      num_channels].

  Returns:
    y_indices: An int32 tensor of shape [batch_size, num_channels] with the
      row index of each channel's peak.
    x_indices: An int32 tensor of shape [batch_size, num_channels] with the
      column index of each channel's peak.
    channel_indices: An int32 tensor of shape [batch_size, num_channels]
      holding the channel index associated with each peak.
  """
  batch_size, _, width, num_channels = _get_shape(feature_map, 4)
  # Collapse the spatial dimensions so a single argmax per channel finds the
  # peak: shape [batch_size, height * width, num_channels].
  flattened = tf.reshape(feature_map, [batch_size, -1, num_channels])
  peak_indices = tf.math.argmax(
      flattened, axis=1, output_type=tf.dtypes.int32)
  # Re-express the indices in the fully flattened
  # [batch, height * width * channels] layout so the shared index-decoding
  # helper can be reused.
  channel_offsets = tf.range(num_channels)[tf.newaxis, :]
  peak_indices = peak_indices * num_channels + channel_offsets
  # Decode flat indices back into (row, col, channel) triples.
  return row_col_channel_indices_from_flattened_indices(
      peak_indices, width, num_channels)
def prediction_tensors_to_single_instance_kpts(
    keypoint_heatmap_predictions,
    keypoint_heatmap_offsets,
    keypoint_score_heatmap=None):
  """Convert keypoint heatmap predictions and offsets to keypoint candidates.

  Single-instance variant: exactly one candidate (the heatmap argmax) is
  produced per keypoint channel, which avoids the more expensive top-k /
  max-pool candidate extraction used in the multi-instance path.

  NOTE(review): the expand_dims-based reshapes below only yield the
  documented output shapes when batch_size == 1 — confirm that callers
  always pass a single-image batch.

  Args:
    keypoint_heatmap_predictions: A float tensor of shape [batch_size, height,
      width, num_keypoints] representing the per-keypoint heatmaps which is
      used for finding the best keypoint candidate locations.
    keypoint_heatmap_offsets: A float tensor of shape [batch_size, height,
      width, 2] (or [batch_size, height, width, 2 * num_keypoints] if
      'per_keypoint_offset' is set True) representing the per-keypoint offsets.
    keypoint_score_heatmap: (optional) A float tensor of shape [batch_size,
      height, width, num_keypoints] representing the heatmap which is used for
      reporting the confidence scores. If not provided, then the values in the
      keypoint_heatmap_predictions will be used.

  Returns:
    keypoint_candidates: A float tensor of shape
      [1, 1, num_keypoints, 2] holding the location of the single keypoint
      candidate per keypoint type in [y, x] format (expressed in absolute
      coordinates in the output coordinate frame).
    keypoint_scores: A float tensor of shape [1, 1, num_keypoints] with the
      score for each keypoint candidate, gathered from keypoint_score_heatmap
      if provided, otherwise from keypoint_heatmap_predictions.
  """
  batch_size, height, width, num_keypoints = _get_shape(
      keypoint_heatmap_predictions, 4)
  # Get x, y and channel indices corresponding to the top indices in the
  # keypoint heatmap predictions.
  y_indices, x_indices, channel_indices = argmax_feature_map_locations(
      keypoint_heatmap_predictions)
  # TF Lite does not support tf.gather with batch_dims > 0, so we need to use
  # tf_gather_nd instead and here we prepare the indices for that.
  # Re-derives num_keypoints from the argmax result (shadows the value
  # obtained above from the heatmap shape).
  _, num_keypoints = _get_shape(y_indices, 2)
  combined_indices = tf.stack([
      _multi_range(batch_size, value_repetitions=num_keypoints),
      tf.reshape(y_indices, [-1]),
      tf.reshape(x_indices, [-1]),
      tf.reshape(channel_indices, [-1])
  ], axis=1)
  # Reshape the offsets predictions to shape:
  # [batch_size, height, width, num_keypoints, 2]
  keypoint_heatmap_offsets = tf.reshape(
      keypoint_heatmap_offsets, [batch_size, height, width, num_keypoints, -1])
  # shape: [num_keypoints, 2]
  selected_offsets_flat = tf.gather_nd(keypoint_heatmap_offsets,
                                       combined_indices)
  y_offsets, x_offsets = tf.unstack(selected_offsets_flat, axis=1)
  # Refine the integer argmax locations with the sub-pixel offsets.
  keypoint_candidates = tf.stack([
      tf.cast(y_indices, dtype=tf.float32) + tf.expand_dims(y_offsets, axis=0),
      tf.cast(x_indices, dtype=tf.float32) + tf.expand_dims(x_offsets, axis=0)
  ], axis=2)
  keypoint_candidates = tf.expand_dims(keypoint_candidates, axis=0)
  # Report scores from the dedicated score heatmap when supplied (e.g. when
  # the prediction heatmap was rescored for candidate selection), otherwise
  # fall back to the prediction heatmap values.
  if keypoint_score_heatmap is None:
    keypoint_scores = tf.gather_nd(
        keypoint_heatmap_predictions, combined_indices)
  else:
    keypoint_scores = tf.gather_nd(keypoint_score_heatmap, combined_indices)
  keypoint_scores = tf.expand_dims(
      tf.expand_dims(keypoint_scores, axis=0), axis=0)
  return keypoint_candidates, keypoint_scores
def _score_to_distance_map(y_grid, x_grid, heatmap, points_y, points_x,
                           score_distance_offset):
  """Rescores a heatmap by proximity to a set of target points.

  Each pixel's score is divided by its Euclidean distance to the target point
  of the corresponding channel, i.e. the output is
  score / (d + score_distance_offset), so pixels close to the target are
  up-weighted.

  Args:
    y_grid: A float tensor with shape [height, width] representing the
      y-coordinate of each pixel grid.
    x_grid: A float tensor with shape [height, width] representing the
      x-coordinate of each pixel grid.
    heatmap: A float tensor with shape [1, height, width, channel]
      representing the heatmap to be rescored.
    points_y: A float tensor with shape [channel] representing the y
      coordinates of the target points for each channel.
    points_x: A float tensor with shape [channel] representing the x
      coordinates of the target points for each channel.
    score_distance_offset: A constant added to the distance in the
      denominator (also keeps the division finite at the target pixel).

  Returns:
    A float tensor with shape [1, height, width, channel] representing the
    rescored heatmap.
  """
  # Broadcast [height, width, 1] grids against [channel] points to obtain
  # per-channel coordinate differences of shape [height, width, channel].
  dy = y_grid[:, :, tf.newaxis] - points_y
  dx = x_grid[:, :, tf.newaxis] - points_x
  distance = tf.math.sqrt(dy * dy + dx * dx)
  return heatmap / (distance + score_distance_offset)
def prediction_to_single_instance_keypoints(
object_heatmap,
keypoint_heatmap, keypoint_heatmap,
keypoint_offset, keypoint_offset,
keypoint_regression, keypoint_regression,
stride,
object_center_std_dev,
keypoint_std_dev,
kp_params, kp_params,
keypoint_depths=None): keypoint_depths=None):
"""Postprocess function to predict single instance keypoints. """Postprocess function to predict single instance keypoints.
...@@ -560,8 +682,8 @@ def prediction_to_single_instance_keypoints(object_heatmap, ...@@ -560,8 +682,8 @@ def prediction_to_single_instance_keypoints(object_heatmap,
there is only one instance in the image. If there are multiple instances in there is only one instance in the image. If there are multiple instances in
the image, the model prefers to predict the one that is closest to the image the image, the model prefers to predict the one that is closest to the image
center. Here is a high-level description of what this function does: center. Here is a high-level description of what this function does:
1) Object heatmap re-weighted by image center Gaussian is used to determine 1) Object heatmap re-weighted by the distance between each pixel to the
the instance center. image center is used to determine the instance center.
2) Regressed keypoint locations are retrieved from the instance center. The 2) Regressed keypoint locations are retrieved from the instance center. The
Gaussian kernel is applied to the regressed keypoint locations to Gaussian kernel is applied to the regressed keypoint locations to
re-weight the keypoint heatmap. This is to select the keypoints that are re-weight the keypoint heatmap. This is to select the keypoints that are
...@@ -579,16 +701,6 @@ def prediction_to_single_instance_keypoints(object_heatmap, ...@@ -579,16 +701,6 @@ def prediction_to_single_instance_keypoints(object_heatmap,
representing the per-keypoint offsets. representing the per-keypoint offsets.
keypoint_regression: A float tensor of shape [1, height, width, 2 * keypoint_regression: A float tensor of shape [1, height, width, 2 *
num_keypoints] representing the joint regression prediction. num_keypoints] representing the joint regression prediction.
stride: The stride in the output space.
object_center_std_dev: The standard deviation of the Gaussian mask which is
applied to the object_heatmap. The goal is to upweight the instance that
is closer to the image center. Expressed in units of input image pixels.
keypoint_std_dev: The standard deviation of the Gaussian masks which are
applied to the keypoint_heatmap based on the regressed joint location. It
is used to upweight the keypoint joints that belongs to the targeted
instance. If keypoint_std_dev contains 1 element, all keypoint joints will
share the same value. Otherwise, it must contain num_keypoints elements,
representing the standard deviation corresponding to each joint.
kp_params: A `KeypointEstimationParams` object with parameters for a single kp_params: A `KeypointEstimationParams` object with parameters for a single
keypoint class. keypoint class.
keypoint_depths: (optional) A float tensor of shape [batch_size, height, keypoint_depths: (optional) A float tensor of shape [batch_size, height,
...@@ -602,33 +714,29 @@ def prediction_to_single_instance_keypoints(object_heatmap, ...@@ -602,33 +714,29 @@ def prediction_to_single_instance_keypoints(object_heatmap,
map space. map space.
keypoint_scores: A float tensor with shape [1, 1, num_keypoints] keypoint_scores: A float tensor with shape [1, 1, num_keypoints]
representing the keypoint prediction scores. representing the keypoint prediction scores.
keypoint_depths: A float tensor with shape [1, 1, num_keypoints]
representing the estimated keypoint depths. Return None if the input
keypoint_depths is None.
Raises: Raises:
ValueError: if the input keypoint_std_dev doesn't have valid number of ValueError: if the input keypoint_std_dev doesn't have valid number of
elements (1 or num_keypoints). elements (1 or num_keypoints).
""" """
# TODO(yuhuic): add the keypoint depth prediction logics in the browser
# postprocessing back.
del keypoint_depths
num_keypoints = len(kp_params.keypoint_std_dev) num_keypoints = len(kp_params.keypoint_std_dev)
batch_size, height, width, _ = _get_shape(keypoint_heatmap, 4) batch_size, height, width, _ = _get_shape(keypoint_heatmap, 4)
# Apply the Gaussian mask to the image center. # Create the image center location.
image_center_y = tf.convert_to_tensor([0.5 * height], dtype=tf.float32) image_center_y = tf.convert_to_tensor([0.5 * height], dtype=tf.float32)
image_center_x = tf.convert_to_tensor([0.5 * width], dtype=tf.float32) image_center_x = tf.convert_to_tensor([0.5 * width], dtype=tf.float32)
(y_grid, x_grid) = ta_utils.image_shape_to_grids(height, width) (y_grid, x_grid) = ta_utils.image_shape_to_grids(height, width)
  # Mask shape: [1, height, width, 1] # Rescore the object heatmap by the distance to the image center.
object_mask = tf.expand_dims( object_heatmap = _score_to_distance_map(
ta_utils.coordinates_to_heatmap(y_grid, x_grid, image_center_y, y_grid, x_grid, object_heatmap, image_center_y,
image_center_x, image_center_x, kp_params.score_distance_offset)
object_center_std_dev / stride,
tf.one_hot(tf.range(1), depth=1)), axis=0)
object_heatmap = tf.math.multiply(object_heatmap, object_mask)
# Pick the highest score and location of the weighted object heatmap. # Pick the highest score and location of the weighted object heatmap.
_, y_indices, x_indices, _ = ( y_indices, x_indices, _ = argmax_feature_map_locations(object_heatmap)
top_k_feature_map_locations(
object_heatmap, max_pool_kernel_size=1, k=1, per_channel=True))
_, num_indices = _get_shape(y_indices, 2) _, num_indices = _get_shape(y_indices, 2)
combined_indices = tf.stack([ combined_indices = tf.stack([
_multi_range(batch_size, value_repetitions=num_indices), _multi_range(batch_size, value_repetitions=num_indices),
...@@ -644,36 +752,24 @@ def prediction_to_single_instance_keypoints(object_heatmap, ...@@ -644,36 +752,24 @@ def prediction_to_single_instance_keypoints(object_heatmap,
y_regressed = tf.cast(y_indices, dtype=tf.float32) + y_reg y_regressed = tf.cast(y_indices, dtype=tf.float32) + y_reg
x_regressed = tf.cast(x_indices, dtype=tf.float32) + x_reg x_regressed = tf.cast(x_indices, dtype=tf.float32) + x_reg
# Prepare and apply the keypoint heatmap masks. if kp_params.candidate_ranking_mode == 'score_distance_ratio':
keypoint_std_dev = [x / stride for x in keypoint_std_dev] reweighted_keypoint_heatmap = _score_to_distance_map(
if len(keypoint_std_dev) == 1: y_grid, x_grid, keypoint_heatmap, y_regressed, x_regressed,
std_dev = tf.convert_to_tensor( kp_params.score_distance_offset)
keypoint_std_dev * num_keypoints, dtype=tf.float32)
elif len(keypoint_std_dev) == num_keypoints:
std_dev = tf.convert_to_tensor(
keypoint_std_dev, dtype=tf.float32)
else: else:
raise ValueError('keypoint_std_dev needs to have length either ' raise ValueError('Unsupported candidate_ranking_mode: %s' %
'equal to 1 or num_keypoints.') kp_params.candidate_ranking_mode)
channel_onehot = tf.one_hot(tf.range(num_keypoints), depth=num_keypoints)
keypoint_mask = tf.expand_dims(
ta_utils.coordinates_to_heatmap(y_grid, x_grid, y_regressed, x_regressed,
std_dev, channel_onehot), axis=0)
keypoint_predictions = tf.math.multiply(keypoint_heatmap, keypoint_mask)
# Get the keypoint locations/scores: # Get the keypoint locations/scores:
# keypoint_candidates: [1, 1, num_keypoints, 2] # keypoint_candidates: [1, 1, num_keypoints, 2]
# keypoint_scores: [1, 1, num_keypoints] # keypoint_scores: [1, 1, num_keypoints]
# depth_candidates: [1, 1, num_keypoints] # depth_candidates: [1, 1, num_keypoints]
(keypoint_candidates, keypoint_scores, _, (keypoint_candidates, keypoint_scores
depth_candidates) = prediction_tensors_to_keypoint_candidates( ) = prediction_tensors_to_single_instance_kpts(
keypoint_predictions, reweighted_keypoint_heatmap,
keypoint_offset, keypoint_offset,
keypoint_score_threshold=kp_params.keypoint_candidate_score_threshold, keypoint_score_heatmap=keypoint_heatmap)
max_pool_kernel_size=kp_params.peak_max_pool_kernel_size, return keypoint_candidates, keypoint_scores, None
max_candidates=1,
keypoint_depths=keypoint_depths)
return keypoint_candidates, keypoint_scores, depth_candidates
def regressed_keypoints_at_object_centers(regressed_keypoint_predictions, def regressed_keypoints_at_object_centers(regressed_keypoint_predictions,
...@@ -3455,10 +3551,10 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3455,10 +3551,10 @@ class CenterNetMetaArch(model.DetectionModel):
postprocess_dict.update(nmsed_additional_fields) postprocess_dict.update(nmsed_additional_fields)
return postprocess_dict return postprocess_dict
def postprocess_single_instance_keypoints(self, prediction_dict, def postprocess_single_instance_keypoints(
true_image_shapes, self,
object_center_std_dev, prediction_dict,
keypoint_std_dev): true_image_shapes):
"""Postprocess for predicting single instance keypoints. """Postprocess for predicting single instance keypoints.
This postprocess function is a special case of predicting the keypoint of This postprocess function is a special case of predicting the keypoint of
...@@ -3483,17 +3579,6 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3483,17 +3579,6 @@ class CenterNetMetaArch(model.DetectionModel):
true_image_shapes: int32 tensor of shape [batch, 3] where each row is of true_image_shapes: int32 tensor of shape [batch, 3] where each row is of
the form [height, width, channels] indicating the shapes of true images the form [height, width, channels] indicating the shapes of true images
in the resized images, as resized images can be padded with zeros. in the resized images, as resized images can be padded with zeros.
object_center_std_dev: The standard deviation of the Gaussian mask which
is applied to the object_heatmap. The goal is to upweight the instance
that is closer to the image center. Expressed in units of input image
pixels.
keypoint_std_dev: The standard deviation of the Gaussian masks which are
applied to the keypoint_heatmap based on the regressed joint location.
It is used to upweight the keypoint joints that belongs to the targeted
instance. If keypoint_std_dev contains one value, then we assume the
same value is applied to all keypoint joints. If keypoint_std_dev is a
list, it must contain num_keypoints elements, representing the standard
deviation corresponding to each joints.
Returns: Returns:
detections: a dictionary containing the following fields detections: a dictionary containing the following fields
...@@ -3524,9 +3609,6 @@ class CenterNetMetaArch(model.DetectionModel): ...@@ -3524,9 +3609,6 @@ class CenterNetMetaArch(model.DetectionModel):
keypoint_heatmap=keypoint_heatmap, keypoint_heatmap=keypoint_heatmap,
keypoint_offset=keypoint_offset, keypoint_offset=keypoint_offset,
keypoint_regression=keypoint_regression, keypoint_regression=keypoint_regression,
stride=self._stride,
object_center_std_dev=object_center_std_dev,
keypoint_std_dev=keypoint_std_dev,
kp_params=kp_params, kp_params=kp_params,
keypoint_depths=keypoint_depths)) keypoint_depths=keypoint_depths))
......
...@@ -750,18 +750,19 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase): ...@@ -750,18 +750,19 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
keypoint_heatmap_np[0, 7, 7, 3] = 0.9 keypoint_heatmap_np[0, 7, 7, 3] = 0.9
keypoint_heatmap_np[0, 4, 4, 3] = 1.0 keypoint_heatmap_np[0, 4, 4, 3] = 1.0
keypoint_offset_np = np.zeros((1, image_size[0], image_size[1], 2), keypoint_offset_np = np.zeros((1, image_size[0], image_size[1], 8),
dtype=np.float32) dtype=np.float32)
keypoint_offset_np[0, 1, 1] = [0.5, 0.5] keypoint_offset_np[0, 1, 1] = [0.5, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
keypoint_offset_np[0, 1, 7] = [0.5, -0.5] keypoint_offset_np[0, 1, 7] = [0.0, 0.0, 0.5, -0.5, 0.0, 0.0, 0.0, 0.0]
keypoint_offset_np[0, 7, 1] = [-0.5, 0.5] keypoint_offset_np[0, 7, 1] = [0.0, 0.0, 0.0, 0.0, -0.5, 0.5, 0.0, 0.0]
keypoint_offset_np[0, 7, 7] = [-0.5, -0.5] keypoint_offset_np[0, 7, 7] = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.5, -0.5]
keypoint_regression_np = np.zeros((1, image_size[0], image_size[1], 8), keypoint_regression_np = np.zeros((1, image_size[0], image_size[1], 8),
dtype=np.float32) dtype=np.float32)
keypoint_regression_np[0, 4, 4] = [-3, -3, -3, 3, 3, -3, 3, 3] keypoint_regression_np[0, 4, 4] = [-3, -3, -3, 3, 3, -3, 3, 3]
kp_params = get_fake_kp_params(num_candidates_per_keypoint=1) kp_params = get_fake_kp_params(
candidate_ranking_mode='score_distance_ratio')
def graph_fn(): def graph_fn():
object_heatmap = tf.constant(object_heatmap_np, dtype=tf.float32) object_heatmap = tf.constant(object_heatmap_np, dtype=tf.float32)
...@@ -776,9 +777,6 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase): ...@@ -776,9 +777,6 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
keypoint_heatmap, keypoint_heatmap,
keypoint_offset, keypoint_offset,
keypoint_regression, keypoint_regression,
stride=4,
object_center_std_dev=image_size[0] / 2,
keypoint_std_dev=[image_size[0] / 10],
kp_params=kp_params)) kp_params=kp_params))
return keypoint_cands, keypoint_scores return keypoint_cands, keypoint_scores
...@@ -1499,7 +1497,8 @@ def get_fake_kp_params(num_candidates_per_keypoint=100, ...@@ -1499,7 +1497,8 @@ def get_fake_kp_params(num_candidates_per_keypoint=100,
per_keypoint_offset=False, per_keypoint_offset=False,
predict_depth=False, predict_depth=False,
per_keypoint_depth=False, per_keypoint_depth=False,
peak_radius=0): peak_radius=0,
candidate_ranking_mode='min_distance'):
"""Returns the fake keypoint estimation parameter namedtuple.""" """Returns the fake keypoint estimation parameter namedtuple."""
return cnma.KeypointEstimationParams( return cnma.KeypointEstimationParams(
task_name=_TASK_NAME, task_name=_TASK_NAME,
...@@ -1514,7 +1513,8 @@ def get_fake_kp_params(num_candidates_per_keypoint=100, ...@@ -1514,7 +1513,8 @@ def get_fake_kp_params(num_candidates_per_keypoint=100,
per_keypoint_offset=per_keypoint_offset, per_keypoint_offset=per_keypoint_offset,
predict_depth=predict_depth, predict_depth=predict_depth,
per_keypoint_depth=per_keypoint_depth, per_keypoint_depth=per_keypoint_depth,
offset_peak_radius=peak_radius) offset_peak_radius=peak_radius,
candidate_ranking_mode=candidate_ranking_mode)
def get_fake_mask_params(): def get_fake_mask_params():
...@@ -1566,7 +1566,8 @@ def build_center_net_meta_arch(build_resnet=False, ...@@ -1566,7 +1566,8 @@ def build_center_net_meta_arch(build_resnet=False,
predict_depth=False, predict_depth=False,
per_keypoint_depth=False, per_keypoint_depth=False,
peak_radius=0, peak_radius=0,
keypoint_only=False): keypoint_only=False,
candidate_ranking_mode='min_distance'):
"""Builds the CenterNet meta architecture.""" """Builds the CenterNet meta architecture."""
if build_resnet: if build_resnet:
feature_extractor = ( feature_extractor = (
...@@ -1612,7 +1613,8 @@ def build_center_net_meta_arch(build_resnet=False, ...@@ -1612,7 +1613,8 @@ def build_center_net_meta_arch(build_resnet=False,
_TASK_NAME: _TASK_NAME:
get_fake_kp_params(num_candidates_per_keypoint, get_fake_kp_params(num_candidates_per_keypoint,
per_keypoint_offset, predict_depth, per_keypoint_offset, predict_depth,
per_keypoint_depth, peak_radius) per_keypoint_depth, peak_radius,
candidate_ranking_mode)
}, },
non_max_suppression_fn=non_max_suppression_fn) non_max_suppression_fn=non_max_suppression_fn)
elif detection_only: elif detection_only:
...@@ -1639,7 +1641,8 @@ def build_center_net_meta_arch(build_resnet=False, ...@@ -1639,7 +1641,8 @@ def build_center_net_meta_arch(build_resnet=False,
_TASK_NAME: _TASK_NAME:
get_fake_kp_params(num_candidates_per_keypoint, get_fake_kp_params(num_candidates_per_keypoint,
per_keypoint_offset, predict_depth, per_keypoint_offset, predict_depth,
per_keypoint_depth, peak_radius) per_keypoint_depth, peak_radius,
candidate_ranking_mode)
}, },
non_max_suppression_fn=non_max_suppression_fn) non_max_suppression_fn=non_max_suppression_fn)
else: else:
...@@ -1651,7 +1654,8 @@ def build_center_net_meta_arch(build_resnet=False, ...@@ -1651,7 +1654,8 @@ def build_center_net_meta_arch(build_resnet=False,
image_resizer_fn=image_resizer_fn, image_resizer_fn=image_resizer_fn,
object_center_params=get_fake_center_params(), object_center_params=get_fake_center_params(),
object_detection_params=get_fake_od_params(), object_detection_params=get_fake_od_params(),
keypoint_params_dict={_TASK_NAME: get_fake_kp_params()}, keypoint_params_dict={_TASK_NAME: get_fake_kp_params(
candidate_ranking_mode=candidate_ranking_mode)},
mask_params=get_fake_mask_params(), mask_params=get_fake_mask_params(),
densepose_params=get_fake_densepose_params(), densepose_params=get_fake_densepose_params(),
track_params=get_fake_track_params(), track_params=get_fake_track_params(),
...@@ -2236,12 +2240,14 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -2236,12 +2240,14 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
def test_postprocess_single_instance(self): def test_postprocess_single_instance(self):
"""Test the postprocess single instance function.""" """Test the postprocess single instance function."""
model = build_center_net_meta_arch(num_classes=1) model = build_center_net_meta_arch(
num_classes=1, candidate_ranking_mode='score_distance_ratio')
num_keypoints = len(model._kp_params_dict[_TASK_NAME].keypoint_indices) num_keypoints = len(model._kp_params_dict[_TASK_NAME].keypoint_indices)
class_center = np.zeros((1, 32, 32, 1), dtype=np.float32) class_center = np.zeros((1, 32, 32, 1), dtype=np.float32)
keypoint_heatmaps = np.zeros((1, 32, 32, num_keypoints), dtype=np.float32) keypoint_heatmaps = np.zeros((1, 32, 32, num_keypoints), dtype=np.float32)
keypoint_offsets = np.zeros((1, 32, 32, 2), dtype=np.float32) keypoint_offsets = np.zeros(
(1, 32, 32, num_keypoints * 2), dtype=np.float32)
keypoint_regression = np.random.randn(1, 32, 32, num_keypoints * 2) keypoint_regression = np.random.randn(1, 32, 32, num_keypoints * 2)
class_probs = np.zeros(1) class_probs = np.zeros(1)
...@@ -2275,9 +2281,7 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase): ...@@ -2275,9 +2281,7 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
def graph_fn(): def graph_fn():
detections = model.postprocess_single_instance_keypoints( detections = model.postprocess_single_instance_keypoints(
prediction_dict, prediction_dict,
tf.constant([[128, 128, 3]]), tf.constant([[128, 128, 3]]))
object_center_std_dev=32,
keypoint_std_dev=[32])
return detections return detections
detections = self.execute_cpu(graph_fn, []) detections = self.execute_cpu(graph_fn, [])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment