Commit 9f971b89 authored by Yu-hui Chen, committed by TF Object Detection Team

Added a few keypoint postprocessing functions to handle the one-class bounding box/keypoint use case, so that the model can be converted to tf.lite format and run on CPU/GPU.

PiperOrigin-RevId: 338078454
parent 27289f90
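Note: the commit message mentions tf.lite conversion, but the conversion itself is outside this change. For context, a minimal sketch of that flow, assuming a CenterNet SavedModel has already been exported (the paths below are hypothetical placeholders, not part of this commit):

```python
import tensorflow as tf

# Hypothetical export directory for the CenterNet SavedModel.
converter = tf.lite.TFLiteConverter.from_saved_model('/tmp/centernet_saved_model')
# Fall back to select TF ops for anything without a TFLite builtin kernel.
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
tflite_model = converter.convert()
with open('/tmp/centernet.tflite', 'wb') as f:
  f.write(tflite_model)
```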
@@ -229,7 +229,9 @@ def top_k_feature_map_locations(feature_map, max_pool_kernel_size=3, k=100,
   """Returns the top k scores and their locations in a feature map.

   Given a feature map, the top k values (based on activation) are returned. If
-  `per_channel` is True, the top k values **per channel** are returned.
+  `per_channel` is True, the top k values **per channel** are returned. Note
+  that when k equals 1, this function uses reduce_max and argmax instead of
+  top_k to make the logic more efficient.

   The `max_pool_kernel_size` argument allows for selecting local peaks in a
   region. This filtering is done per channel, so nothing prevents two values at
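The k == 1 shortcut described in the docstring works because reduce_max/argmax agree with top_k when only a single peak is requested. A quick standalone sanity check (my own illustration, not part of the diff):

```python
import tensorflow as tf

x = tf.constant([[0.1, 0.9, 0.4, 0.7]])
top_scores, top_indices = tf.math.top_k(x, k=1)            # [[0.9]], [[1]]
max_scores = tf.math.reduce_max(x, axis=1, keepdims=True)  # [[0.9]]
max_indices = tf.expand_dims(
    tf.math.argmax(x, axis=1, output_type=tf.dtypes.int32), axis=-1)  # [[1]]
assert bool(tf.reduce_all(top_scores == max_scores))
assert bool(tf.reduce_all(top_indices == max_indices))
```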
@@ -279,12 +281,21 @@ def top_k_feature_map_locations(feature_map, max_pool_kernel_size=3, k=100,
   batch_size, _, width, num_channels = _get_shape(feature_map, 4)
   if per_channel:
-    # Perform top k over batch and channels.
-    feature_map_peaks_transposed = tf.transpose(feature_map_peaks,
-                                                perm=[0, 3, 1, 2])
-    feature_map_peaks_transposed = tf.reshape(
-        feature_map_peaks_transposed, [batch_size, num_channels, -1])
-    scores, peak_flat_indices = tf.math.top_k(feature_map_peaks_transposed, k=k)
+    if k == 1:
+      feature_map_flattened = tf.reshape(
+          feature_map_peaks, [batch_size, -1, num_channels])
+      scores = tf.math.reduce_max(feature_map_flattened, axis=1)
+      peak_flat_indices = tf.math.argmax(
+          feature_map_flattened, axis=1, output_type=tf.dtypes.int32)
+      peak_flat_indices = tf.expand_dims(peak_flat_indices, axis=-1)
+    else:
+      # Perform top k over batch and channels.
+      feature_map_peaks_transposed = tf.transpose(feature_map_peaks,
+                                                  perm=[0, 3, 1, 2])
+      feature_map_peaks_transposed = tf.reshape(
+          feature_map_peaks_transposed, [batch_size, num_channels, -1])
+      scores, peak_flat_indices = tf.math.top_k(
+          feature_map_peaks_transposed, k=k)
     # Convert the indices such that they represent the location in the full
     # (flattened) feature map of size [batch, height * width * channels].
     channel_idx = tf.range(num_channels)[tf.newaxis, :, tf.newaxis]
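A tiny shape check (an illustration under assumed toy dimensions, not from the diff) of the new per-channel k == 1 branch: flattening to [batch, height * width, channels] and reducing over axis 1 yields one peak per channel.

```python
import tensorflow as tf

feature_map_peaks = tf.random.uniform([2, 3, 3, 4])    # [b, h, w, c]
flattened = tf.reshape(feature_map_peaks, [2, -1, 4])  # [b, h*w, c]
scores = tf.math.reduce_max(flattened, axis=1)         # [b, c]
indices = tf.math.argmax(flattened, axis=1, output_type=tf.dtypes.int32)
indices = tf.expand_dims(indices, axis=-1)             # [b, c, 1]
print(scores.shape, indices.shape)                     # (2, 4) (2, 4, 1)
```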
@@ -292,8 +303,14 @@ def top_k_feature_map_locations(feature_map, max_pool_kernel_size=3, k=100,
     scores = tf.reshape(scores, [batch_size, -1])
     peak_flat_indices = tf.reshape(peak_flat_indices, [batch_size, -1])
   else:
-    feature_map_peaks_flat = tf.reshape(feature_map_peaks, [batch_size, -1])
-    scores, peak_flat_indices = tf.math.top_k(feature_map_peaks_flat, k=k)
+    if k == 1:
+      feature_map_peaks_flat = tf.reshape(feature_map_peaks, [batch_size, -1])
+      scores = tf.math.reduce_max(
+          feature_map_peaks_flat, axis=1, keepdims=True)
+      peak_flat_indices = tf.expand_dims(tf.math.argmax(
+          feature_map_peaks_flat, axis=1, output_type=tf.dtypes.int32), axis=-1)
+    else:
+      feature_map_peaks_flat = tf.reshape(feature_map_peaks, [batch_size, -1])
+      scores, peak_flat_indices = tf.math.top_k(feature_map_peaks_flat, k=k)
   # Get x, y and channel indices corresponding to the top indices in the flat
   # array.
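The context lines above convert flat indices over the [height * width * channels] map back into x, y and channel coordinates. As a rough illustration of the arithmetic for a row-major NHWC layout (my own sketch; the real function operates on batched tensors):

```python
import tensorflow as tf

height, width, channels = 3, 3, 2
flat_index = tf.constant(12)              # corresponds to y=2, x=0, channel=0
y = flat_index // (width * channels)      # -> 2
x = (flat_index // channels) % width      # -> 0
channel = flat_index % channels           # -> 0
```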
@@ -1816,6 +1833,7 @@ class CenterNetMetaArch(model.DetectionModel):
     # The Objects as Points paper attaches loss functions to multiple
     # (`num_feature_outputs`) feature maps in the backbone. E.g.
     # for the hourglass backbone, `num_feature_outputs` is 2.
+    self._num_classes = num_classes
     self._feature_extractor = feature_extractor
     self._num_feature_outputs = feature_extractor.num_feature_outputs
     self._stride = self._feature_extractor.out_stride
@@ -2911,13 +2929,31 @@ class CenterNetMetaArch(model.DetectionModel):
     }
     if self._kp_params_dict:
-      keypoints, keypoint_scores = self._postprocess_keypoints(
-          prediction_dict, classes, y_indices, x_indices,
-          boxes_strided, num_detections)
-      keypoints, keypoint_scores = (
-          convert_strided_predictions_to_normalized_keypoints(
-              keypoints, keypoint_scores, self._stride, true_image_shapes,
-              clip_out_of_frame_keypoints=True))
+      # If the model is trained to predict only one class of object and its
+      # keypoints, fall back to a simpler postprocessing function which uses
+      # only ops that are supported by tf.lite on GPU.
+      if len(self._kp_params_dict) == 1 and self._num_classes == 1:
+        keypoints, keypoint_scores = self._postprocess_keypoints_simple(
+            prediction_dict, classes, y_indices, x_indices,
+            boxes_strided, num_detections)
+        # The map_fn used to clip out-of-frame keypoints creates issues when
+        # converting to a tf.lite model, so we disable it here and let users
+        # handle out-of-frame keypoints themselves.
+        keypoints, keypoint_scores = (
+            convert_strided_predictions_to_normalized_keypoints(
+                keypoints, keypoint_scores, self._stride, true_image_shapes,
+                clip_out_of_frame_keypoints=False))
+      else:
+        keypoints, keypoint_scores = self._postprocess_keypoints(
+            prediction_dict, classes, y_indices, x_indices,
+            boxes_strided, num_detections)
+        keypoints, keypoint_scores = (
+            convert_strided_predictions_to_normalized_keypoints(
+                keypoints, keypoint_scores, self._stride, true_image_shapes,
+                clip_out_of_frame_keypoints=True))
       postprocess_dict.update({
           fields.DetectionResultFields.detection_keypoints: keypoints,
           fields.DetectionResultFields.detection_keypoint_scores:
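Because the tf.lite path sets clip_out_of_frame_keypoints=False, callers receive unclipped normalized keypoints. A hedged sketch of one way to handle them after inference (clip_normalized_keypoints is my own hypothetical helper, not part of this API):

```python
import tensorflow as tf

def clip_normalized_keypoints(keypoints, keypoint_scores):
  """Clips keypoints to [0, 1] and zeroes scores of out-of-frame points.

  Assumes keypoints of shape [batch, max_detections, num_keypoints, 2]
  with normalized (y, x) coordinates.
  """
  in_frame = tf.reduce_all(
      (keypoints >= 0.0) & (keypoints <= 1.0), axis=-1)
  clipped = tf.clip_by_value(keypoints, 0.0, 1.0)
  scores = keypoint_scores * tf.cast(in_frame, keypoint_scores.dtype)
  return clipped, scores
```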
@@ -3107,6 +3143,72 @@ class CenterNetMetaArch(model.DetectionModel):
     return keypoints, keypoint_scores

+  def _postprocess_keypoints_simple(self, prediction_dict, classes, y_indices,
+                                    x_indices, boxes, num_detections):
+    """Performs postprocessing on keypoint predictions (one class only).
+
+    This function handles the special case of the keypoint task in which the
+    model predicts only one class of bounding box/keypoint (e.g. person).
+    Under this assumption, the function uses only tf.lite-supported ops and
+    should run faster.
+
+    Args:
+      prediction_dict: a dictionary holding predicted tensors, returned from
+        the predict() method. This dictionary should contain keypoint
+        prediction feature maps for each keypoint task.
+      classes: A [batch_size, max_detections] int tensor with class indices
+        for all detected objects.
+      y_indices: A [batch_size, max_detections] int tensor with y indices for
+        all object centers.
+      x_indices: A [batch_size, max_detections] int tensor with x indices for
+        all object centers.
+      boxes: A [batch_size, max_detections, 4] float32 tensor with bounding
+        boxes in (un-normalized) output space.
+      num_detections: A [batch_size] int tensor with the number of valid
+        detections for each image.
+
+    Returns:
+      A tuple of
+      keypoints: a [batch_size, max_detections, num_total_keypoints, 2]
+        float32 tensor with keypoints in the output (strided) coordinate
+        frame.
+      keypoint_scores: a [batch_size, max_detections, num_total_keypoints]
+        float32 tensor with keypoint scores.
+    """
+    # This function only works when there is exactly one keypoint task and
+    # the number of classes equals one. For more general use cases, please
+    # use _postprocess_keypoints instead.
+    assert len(self._kp_params_dict) == 1 and self._num_classes == 1
+    task_name, kp_params = next(iter(self._kp_params_dict.items()))
+    keypoint_heatmap = prediction_dict[
+        get_keypoint_name(task_name, KEYPOINT_HEATMAP)][-1]
+    keypoint_offsets = prediction_dict[
+        get_keypoint_name(task_name, KEYPOINT_OFFSET)][-1]
+    keypoint_regression = prediction_dict[
+        get_keypoint_name(task_name, KEYPOINT_REGRESSION)][-1]
+
+    batch_size, _, _ = _get_shape(boxes, 3)
+    kpt_coords_for_example_list = []
+    kpt_scores_for_example_list = []
+    for ex_ind in range(batch_size):
+      # Postprocess keypoints and scores for the single class and single
+      # image. Shapes are [1, max_detections, num_keypoints, 2] and
+      # [1, max_detections, num_keypoints], respectively.
+      kpt_coords_for_class, kpt_scores_for_class = (
+          self._postprocess_keypoints_for_class_and_image_simple(
+              keypoint_heatmap, keypoint_offsets, keypoint_regression,
+              classes, y_indices, x_indices, boxes, ex_ind, kp_params))
+      kpt_coords_for_example_list.append(kpt_coords_for_class)
+      kpt_scores_for_example_list.append(kpt_scores_for_class)
+
+    # Concatenate all keypoints and scores from all examples in the batch.
+    # Shapes are [batch_size, max_detections, num_keypoints, 2] and
+    # [batch_size, max_detections, num_keypoints], respectively.
+    keypoints = tf.concat(kpt_coords_for_example_list, axis=0)
+    keypoint_scores = tf.concat(kpt_scores_for_example_list, axis=0)
+
+    return keypoints, keypoint_scores
+
   def _get_instance_indices(self, classes, num_detections, batch_index,
                             class_id):
     """Gets the instance indices that match the target class ID.
@@ -3234,6 +3336,86 @@ class CenterNetMetaArch(model.DetectionModel):
     return refined_keypoints, refined_scores

+  def _postprocess_keypoints_for_class_and_image_simple(
+      self, keypoint_heatmap, keypoint_offsets, keypoint_regression, classes,
+      y_indices, x_indices, boxes, batch_index, kp_params):
+    """Postprocesses keypoints for a single image and class.
+
+    This function is similar to "_postprocess_keypoints_for_class_and_image"
+    except that it assumes there is only one class of bounding box/keypoint
+    to be handled. The function is tf.lite compatible.
+
+    Args:
+      keypoint_heatmap: A [batch_size, height, width, num_keypoints] float32
+        tensor with keypoint heatmaps.
+      keypoint_offsets: A [batch_size, height, width, 2] float32 tensor with
+        local offsets to keypoint centers.
+      keypoint_regression: A [batch_size, height, width, 2 * num_keypoints]
+        float32 tensor with regressed offsets to all keypoints.
+      classes: A [batch_size, max_detections] int tensor with class indices
+        for all detected objects.
+      y_indices: A [batch_size, max_detections] int tensor with y indices for
+        all object centers.
+      x_indices: A [batch_size, max_detections] int tensor with x indices for
+        all object centers.
+      boxes: A [batch_size, max_detections, 4] float32 tensor with detected
+        boxes in the output (strided) frame.
+      batch_index: An integer specifying the index for an example in the
+        batch.
+      kp_params: A `KeypointEstimationParams` object with parameters for a
+        single keypoint class.
+
+    Returns:
+      A tuple of
+      refined_keypoints: A [1, num_instances, num_keypoints, 2] float32
+        tensor with refined keypoints for a single class in a single image,
+        expressed in the output (strided) coordinate frame. Note that
+        `num_instances` is a dynamic dimension, and corresponds to the number
+        of valid detections for the specific class.
+      refined_scores: A [1, num_instances, num_keypoints] float32 tensor with
+        keypoint scores.
+    """
+    num_keypoints = len(kp_params.keypoint_indices)
+    keypoint_heatmap = tf.nn.sigmoid(
+        keypoint_heatmap[batch_index:batch_index + 1, ...])
+    keypoint_offsets = keypoint_offsets[batch_index:batch_index + 1, ...]
+    keypoint_regression = keypoint_regression[batch_index:batch_index + 1, ...]
+    y_indices = y_indices[batch_index:batch_index + 1, ...]
+    x_indices = x_indices[batch_index:batch_index + 1, ...]
+    boxes_slice = boxes[batch_index:batch_index + 1, ...]
+
+    # Gather the regressed keypoints. Final tensor has shape
+    # [1, num_instances, num_keypoints, 2].
+    regressed_keypoints_for_objects = regressed_keypoints_at_object_centers(
+        keypoint_regression, y_indices, x_indices)
+    regressed_keypoints_for_objects = tf.reshape(
+        regressed_keypoints_for_objects, [1, -1, num_keypoints, 2])
+
+    # Get the candidate keypoints and scores. The shapes of
+    # keypoint_candidates and keypoint_scores are
+    # [1, num_candidates_per_keypoint, num_keypoints, 2] and
+    # [1, num_candidates_per_keypoint, num_keypoints], respectively.
+    keypoint_candidates, keypoint_scores, num_keypoint_candidates = (
+        prediction_tensors_to_keypoint_candidates(
+            keypoint_heatmap, keypoint_offsets,
+            keypoint_score_threshold=(
+                kp_params.keypoint_candidate_score_threshold),
+            max_pool_kernel_size=kp_params.peak_max_pool_kernel_size,
+            max_candidates=kp_params.num_candidates_per_keypoint))
+
+    # Get the refined keypoints and scores, of shape
+    # [1, num_instances, num_keypoints, 2] and
+    # [1, num_instances, num_keypoints], respectively.
+    refined_keypoints, refined_scores = refine_keypoints(
+        regressed_keypoints_for_objects, keypoint_candidates, keypoint_scores,
+        num_keypoint_candidates, bboxes=boxes_slice,
+        unmatched_keypoint_score=kp_params.unmatched_keypoint_score,
+        box_scale=kp_params.box_scale,
+        candidate_search_scale=kp_params.candidate_search_scale,
+        candidate_ranking_mode=kp_params.candidate_ranking_mode)
+
+    return refined_keypoints, refined_scores
+
   def regularization_losses(self):
     return []
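One detail worth noting in the method above: slicing with `batch_index:batch_index + 1` keeps the leading batch dimension (unlike plain integer indexing), so the downstream keypoint ops still see batched rank-4 tensors. A quick illustration:

```python
import tensorflow as tf

t = tf.zeros([8, 32, 32, 17])
print(t[2].shape)    # (32, 32, 17)    -- rank dropped
print(t[2:3].shape)  # (1, 32, 32, 17) -- rank preserved
```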
@@ -484,6 +484,70 @@ class CenterNetMetaArchHelpersTest(test_case.TestCase, parameterized.TestCase):
     np.testing.assert_array_equal([1, 0, 0, 2], x_inds[1])
     np.testing.assert_array_equal([0, 0, 1, 1], channel_inds[1])

+  def test_top_k_feature_map_locations_k1(self):
+    feature_map_np = np.zeros((2, 3, 3, 2), dtype=np.float32)
+    feature_map_np[0, 2, 0, 0] = 1.0  # Selected.
+    feature_map_np[0, 2, 1, 0] = 0.9
+    feature_map_np[0, 0, 1, 0] = 0.7
+    feature_map_np[0, 2, 2, 1] = 0.5
+    feature_map_np[0, 0, 0, 1] = 0.3
+    feature_map_np[1, 2, 1, 0] = 0.7
+    feature_map_np[1, 1, 0, 0] = 0.4
+    feature_map_np[1, 1, 2, 0] = 0.3
+    feature_map_np[1, 1, 0, 1] = 0.8  # Selected.
+    feature_map_np[1, 1, 2, 1] = 0.3
+
+    def graph_fn():
+      feature_map = tf.constant(feature_map_np)
+      scores, y_inds, x_inds, channel_inds = (
+          cnma.top_k_feature_map_locations(
+              feature_map, max_pool_kernel_size=3, k=1, per_channel=False))
+      return scores, y_inds, x_inds, channel_inds
+
+    scores, y_inds, x_inds, channel_inds = self.execute(graph_fn, [])
+
+    np.testing.assert_allclose([1.0], scores[0])
+    np.testing.assert_array_equal([2], y_inds[0])
+    np.testing.assert_array_equal([0], x_inds[0])
+    np.testing.assert_array_equal([0], channel_inds[0])
+
+    np.testing.assert_allclose([0.8], scores[1])
+    np.testing.assert_array_equal([1], y_inds[1])
+    np.testing.assert_array_equal([0], x_inds[1])
+    np.testing.assert_array_equal([1], channel_inds[1])
+
+  def test_top_k_feature_map_locations_k1_per_channel(self):
+    feature_map_np = np.zeros((2, 3, 3, 2), dtype=np.float32)
+    feature_map_np[0, 2, 0, 0] = 1.0  # Selected.
+    feature_map_np[0, 2, 1, 0] = 0.9
+    feature_map_np[0, 0, 1, 0] = 0.7
+    feature_map_np[0, 2, 2, 1] = 0.5  # Selected.
+    feature_map_np[0, 0, 0, 1] = 0.3
+    feature_map_np[1, 2, 1, 0] = 0.7  # Selected.
+    feature_map_np[1, 1, 0, 0] = 0.4
+    feature_map_np[1, 1, 2, 0] = 0.3
+    feature_map_np[1, 1, 0, 1] = 0.8  # Selected.
+    feature_map_np[1, 1, 2, 1] = 0.3
+
+    def graph_fn():
+      feature_map = tf.constant(feature_map_np)
+      scores, y_inds, x_inds, channel_inds = (
+          cnma.top_k_feature_map_locations(
+              feature_map, max_pool_kernel_size=3, k=1, per_channel=True))
+      return scores, y_inds, x_inds, channel_inds
+
+    scores, y_inds, x_inds, channel_inds = self.execute(graph_fn, [])
+
+    np.testing.assert_allclose([1.0, 0.5], scores[0])
+    np.testing.assert_array_equal([2, 2], y_inds[0])
+    np.testing.assert_array_equal([0, 2], x_inds[0])
+    np.testing.assert_array_equal([0, 1], channel_inds[0])
+
+    np.testing.assert_allclose([0.7, 0.8], scores[1])
+    np.testing.assert_array_equal([2, 1], y_inds[1])
+    np.testing.assert_array_equal([1, 0], x_inds[1])
+    np.testing.assert_array_equal([0, 1], channel_inds[1])
+
   def test_box_prediction(self):
     class_pred = np.zeros((3, 128, 128, 5), dtype=np.float32)
@@ -1213,7 +1277,7 @@ def get_fake_temporal_offset_params():
       task_loss_weight=1.0)


-def build_center_net_meta_arch(build_resnet=False):
+def build_center_net_meta_arch(build_resnet=False, num_classes=_NUM_CLASSES):
   """Builds the CenterNet meta architecture."""
   if build_resnet:
     feature_extractor = (
@@ -1231,19 +1295,31 @@ def build_center_net_meta_arch(build_resnet=False, num_classes=_NUM_CLASSES):
       min_dimension=128,
       max_dimension=128,
       pad_to_max_dimesnion=True)
-  return cnma.CenterNetMetaArch(
-      is_training=True,
-      add_summaries=False,
-      num_classes=_NUM_CLASSES,
-      feature_extractor=feature_extractor,
-      image_resizer_fn=image_resizer_fn,
-      object_center_params=get_fake_center_params(),
-      object_detection_params=get_fake_od_params(),
-      keypoint_params_dict={_TASK_NAME: get_fake_kp_params()},
-      mask_params=get_fake_mask_params(),
-      densepose_params=get_fake_densepose_params(),
-      track_params=get_fake_track_params(),
-      temporal_offset_params=get_fake_temporal_offset_params())
+  if num_classes == 1:
+    return cnma.CenterNetMetaArch(
+        is_training=True,
+        add_summaries=False,
+        num_classes=num_classes,
+        feature_extractor=feature_extractor,
+        image_resizer_fn=image_resizer_fn,
+        object_center_params=get_fake_center_params(),
+        object_detection_params=get_fake_od_params(),
+        keypoint_params_dict={_TASK_NAME: get_fake_kp_params()})
+  else:
+    return cnma.CenterNetMetaArch(
+        is_training=True,
+        add_summaries=False,
+        num_classes=num_classes,
+        feature_extractor=feature_extractor,
+        image_resizer_fn=image_resizer_fn,
+        object_center_params=get_fake_center_params(),
+        object_detection_params=get_fake_od_params(),
+        keypoint_params_dict={_TASK_NAME: get_fake_kp_params()},
+        mask_params=get_fake_mask_params(),
+        densepose_params=get_fake_densepose_params(),
+        track_params=get_fake_track_params(),
+        temporal_offset_params=get_fake_temporal_offset_params())


 def _logit(p):
@@ -1650,6 +1726,72 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
         detections['detection_surface_coords'][0, 0, :, :],
         np.zeros_like(detections['detection_surface_coords'][0, 0, :, :]))

+  def test_postprocess_simple(self):
+    """Tests the simplified (one-class) postprocess function."""
+    model = build_center_net_meta_arch(num_classes=1)
+    max_detection = model._center_params.max_box_predictions
+    num_keypoints = len(model._kp_params_dict[_TASK_NAME].keypoint_indices)
+
+    class_center = np.zeros((1, 32, 32, 1), dtype=np.float32)
+    height_width = np.zeros((1, 32, 32, 2), dtype=np.float32)
+    offset = np.zeros((1, 32, 32, 2), dtype=np.float32)
+    keypoint_heatmaps = np.zeros((1, 32, 32, num_keypoints), dtype=np.float32)
+    keypoint_offsets = np.zeros((1, 32, 32, 2), dtype=np.float32)
+    keypoint_regression = np.random.randn(1, 32, 32, num_keypoints * 2)
+
+    class_probs = np.zeros(1)
+    class_probs[0] = _logit(0.75)
+    class_center[0, 16, 16] = class_probs
+    height_width[0, 16, 16] = [5, 10]
+    offset[0, 16, 16] = [.25, .5]
+    keypoint_regression[0, 16, 16] = [
+        -1., -1.,
+        -1., 1.,
+        1., -1.,
+        1., 1.]
+    keypoint_heatmaps[0, 14, 14, 0] = _logit(0.9)
+    keypoint_heatmaps[0, 14, 18, 1] = _logit(0.9)
+    keypoint_heatmaps[0, 18, 14, 2] = _logit(0.9)
+    keypoint_heatmaps[0, 18, 18, 3] = _logit(0.05)  # Note the low score.
+
+    class_center = tf.constant(class_center)
+    height_width = tf.constant(height_width)
+    offset = tf.constant(offset)
+    keypoint_heatmaps = tf.constant(keypoint_heatmaps, dtype=tf.float32)
+    keypoint_offsets = tf.constant(keypoint_offsets, dtype=tf.float32)
+    keypoint_regression = tf.constant(keypoint_regression, dtype=tf.float32)
+
+    prediction_dict = {
+        cnma.OBJECT_CENTER: [class_center],
+        cnma.BOX_SCALE: [height_width],
+        cnma.BOX_OFFSET: [offset],
+        cnma.get_keypoint_name(_TASK_NAME, cnma.KEYPOINT_HEATMAP):
+            [keypoint_heatmaps],
+        cnma.get_keypoint_name(_TASK_NAME, cnma.KEYPOINT_OFFSET):
+            [keypoint_offsets],
+        cnma.get_keypoint_name(_TASK_NAME, cnma.KEYPOINT_REGRESSION):
+            [keypoint_regression],
+    }
+
+    def graph_fn():
+      detections = model.postprocess(prediction_dict,
+                                     tf.constant([[128, 128, 3]]))
+      return detections
+
+    detections = self.execute_cpu(graph_fn, [])
+
+    self.assertAllClose(detections['detection_boxes'][0, 0],
+                        np.array([55, 46, 75, 86]) / 128.0)
+    self.assertAllClose(detections['detection_scores'][0],
+                        [.75, .5, .5, .5, .5])
+    self.assertEqual(detections['detection_classes'][0, 0], 0)
+    self.assertEqual(detections['num_detections'], [5])
+    self.assertAllEqual([1, max_detection, num_keypoints, 2],
+                        detections['detection_keypoints'].shape)
+    self.assertAllEqual([1, max_detection, num_keypoints],
+                        detections['detection_keypoint_scores'].shape)
+
   def test_get_instance_indices(self):
     classes = tf.constant([[0, 1, 2, 0], [2, 1, 2, 2]], dtype=tf.int32)
     num_detections = tf.constant([1, 3], dtype=tf.int32)