Commit e8a80796 authored by A. Unique TensorFlower, committed by TF Object Detection Team

Support 1D detection for CenterNet with inputs of height == 1.

PiperOrigin-RevId: 371037598
parent 68411471
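The motivation for the change: with integer floor division, an input whose height is smaller than the stride maps to a zero-height output, which breaks heatmap and grid construction downstream. The patch therefore clamps every strided dimension to at least 1. A minimal standalone sketch of the failure mode and the fix (illustration only, not part of the commit):

```python
import tensorflow as tf

height, width, stride = 1, 32, 4  # a 1D input: height == 1

# Before this commit: floor division collapses the height to zero, so the
# target heatmaps and grids downstream would be empty.
print(height // stride)  # 0

# After this commit: every strided dimension is clamped to at least 1.
out_height = tf.maximum(height // stride, 1)
out_width = tf.maximum(width // stride, 1)
print(int(out_height), int(out_width))  # 1 8
```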
@@ -985,8 +985,8 @@ class CenterNetCenterHeatmapTargetAssigner(object):
       the stride specified during initialization.
     """
-    out_height = tf.cast(height // self._stride, tf.float32)
-    out_width = tf.cast(width // self._stride, tf.float32)
+    out_height = tf.cast(tf.maximum(height // self._stride, 1), tf.float32)
+    out_width = tf.cast(tf.maximum(width // self._stride, 1), tf.float32)
     # Compute the yx-grid to be used to generate the heatmap. Each returned
     # tensor has shape of [out_height, out_width]
     (y_grid, x_grid) = ta_utils.image_shape_to_grids(out_height, out_width)
@@ -999,9 +999,10 @@ class CenterNetCenterHeatmapTargetAssigner(object):
                                             gt_weights_list):
       boxes = box_list.BoxList(boxes)
       # Convert the box coordinates to absolute output image dimension space.
-      boxes = box_list_ops.to_absolute_coordinates(boxes,
-                                                   height // self._stride,
-                                                   width // self._stride)
+      boxes = box_list_ops.to_absolute_coordinates(
+          boxes,
+          tf.maximum(height // self._stride, 1),
+          tf.maximum(width // self._stride, 1))
       # Get the box center coordinates. Each returned tensor has the shape of
       # [num_instances]
       (y_center, x_center, boxes_height,
......@@ -1062,8 +1063,8 @@ class CenterNetCenterHeatmapTargetAssigner(object):
assert (self._keypoint_weights_for_center is not None and
self._keypoint_class_id is not None and
self._keypoint_indices is not None)
out_height = tf.cast(height // self._stride, tf.float32)
out_width = tf.cast(width // self._stride, tf.float32)
out_height = tf.cast(tf.maximum(height // self._stride, 1), tf.float32)
out_width = tf.cast(tf.maximum(width // self._stride, 1), tf.float32)
# Compute the yx-grid to be used to generate the heatmap. Each returned
# tensor has shape of [out_height, out_width]
(y_grid, x_grid) = ta_utils.image_shape_to_grids(out_height, out_width)
@@ -1230,9 +1231,10 @@ class CenterNetBoxTargetAssigner(object):
     for i, (boxes, weights) in enumerate(zip(gt_boxes_list, gt_weights_list)):
       boxes = box_list.BoxList(boxes)
-      boxes = box_list_ops.to_absolute_coordinates(boxes,
-                                                   height // self._stride,
-                                                   width // self._stride)
+      boxes = box_list_ops.to_absolute_coordinates(
+          boxes,
+          tf.maximum(height // self._stride, 1),
+          tf.maximum(width // self._stride, 1))
       # Get the box center coordinates. Each returned tensor has the shape of
       # [num_boxes]
       (y_center, x_center, boxes_height,
@@ -1410,8 +1412,8 @@ class CenterNetKeypointTargetAssigner(object):
         output_width] where all values within the regions of the blackout boxes
         are 0.0 and 1.0 elsewhere.
     """
-    out_width = tf.cast(width // self._stride, tf.float32)
-    out_height = tf.cast(height // self._stride, tf.float32)
+    out_width = tf.cast(tf.maximum(width // self._stride, 1), tf.float32)
+    out_height = tf.cast(tf.maximum(height // self._stride, 1), tf.float32)
     # Compute the yx-grid to be used to generate the heatmap. Each returned
     # tensor has shape of [out_height, out_width]
     y_grid, x_grid = ta_utils.image_shape_to_grids(out_height, out_width)
@@ -1464,9 +1466,10 @@ class CenterNetKeypointTargetAssigner(object):
     if boxes is not None:
       boxes = box_list.BoxList(boxes)
       # Convert the box coordinates to absolute output image dimension space.
-      boxes = box_list_ops.to_absolute_coordinates(boxes,
-                                                   height // self._stride,
-                                                   width // self._stride)
+      boxes = box_list_ops.to_absolute_coordinates(
+          boxes,
+          tf.maximum(height // self._stride, 1),
+          tf.maximum(width // self._stride, 1))
       # Get the box height and width. Each returned tensor has the shape
       # of [num_instances]
       (_, _, boxes_height,
@@ -1586,8 +1589,8 @@ class CenterNetKeypointTargetAssigner(object):
         zip(gt_keypoints_list, gt_classes_list, gt_keypoints_weights_list,
             gt_weights_list)):
       keypoints_absolute, kp_weights = _preprocess_keypoints_and_weights(
-          out_height=height // self._stride,
-          out_width=width // self._stride,
+          out_height=tf.maximum(height // self._stride, 1),
+          out_width=tf.maximum(width // self._stride, 1),
           keypoints=keypoints,
           class_onehot=classes,
           class_weights=weights,
@@ -1604,8 +1607,9 @@ class CenterNetKeypointTargetAssigner(object):
       # All keypoint coordinates and their neighbors:
       # [num_instance * num_keypoints, num_neighbors]
       (y_source_neighbors, x_source_neighbors,
-       valid_sources) = ta_utils.get_surrounding_grids(height // self._stride,
-                                                       width // self._stride,
+       valid_sources) = ta_utils.get_surrounding_grids(
+           tf.cast(tf.maximum(height // self._stride, 1), tf.float32),
+           tf.cast(tf.maximum(width // self._stride, 1), tf.float32),
            y_source, x_source,
            self._peak_radius)
       _, num_neighbors = shape_utils.combined_static_and_dynamic_shape(
@@ -1722,8 +1726,8 @@ class CenterNetKeypointTargetAssigner(object):
             gt_keypoints_weights_list, gt_weights_list,
             gt_keypoint_depths_list, gt_keypoint_depth_weights_list)):
       keypoints_absolute, kp_weights = _preprocess_keypoints_and_weights(
-          out_height=height // self._stride,
-          out_width=width // self._stride,
+          out_height=tf.maximum(height // self._stride, 1),
+          out_width=tf.maximum(width // self._stride, 1),
           keypoints=keypoints,
           class_onehot=classes,
           class_weights=weights,
@@ -1740,8 +1744,9 @@ class CenterNetKeypointTargetAssigner(object):
       # All keypoint coordinates and their neighbors:
       # [num_instance * num_keypoints, num_neighbors]
       (y_source_neighbors, x_source_neighbors,
-       valid_sources) = ta_utils.get_surrounding_grids(height // self._stride,
-                                                       width // self._stride,
+       valid_sources) = ta_utils.get_surrounding_grids(
+           tf.cast(tf.maximum(height // self._stride, 1), tf.float32),
+           tf.cast(tf.maximum(width // self._stride, 1), tf.float32),
            y_source, x_source,
            self._peak_radius)
       _, num_neighbors = shape_utils.combined_static_and_dynamic_shape(
@@ -1894,8 +1899,8 @@ class CenterNetKeypointTargetAssigner(object):
         zip(gt_keypoints_list, gt_classes_list,
             gt_boxes_list, gt_keypoints_weights_list, gt_weights_list)):
       keypoints_absolute, kp_weights = _preprocess_keypoints_and_weights(
-          out_height=height // self._stride,
-          out_width=width // self._stride,
+          out_height=tf.maximum(height // self._stride, 1),
+          out_width=tf.maximum(width // self._stride, 1),
           keypoints=keypoints,
           class_onehot=classes,
           class_weights=weights,
@@ -1909,9 +1914,10 @@ class CenterNetKeypointTargetAssigner(object):
       if boxes is not None:
         # Compute joint center from boxes.
         boxes = box_list.BoxList(boxes)
-        boxes = box_list_ops.to_absolute_coordinates(boxes,
-                                                     height // self._stride,
-                                                     width // self._stride)
+        boxes = box_list_ops.to_absolute_coordinates(
+            boxes,
+            tf.maximum(height // self._stride, 1),
+            tf.maximum(width // self._stride, 1))
         y_center, x_center, _, _ = boxes.get_center_coordinates_and_sizes()
       else:
         # TODO(yuhuic): Add the logic to generate object centers from keypoints.
@@ -1930,7 +1936,8 @@ class CenterNetKeypointTargetAssigner(object):
       # [num_instance * num_keypoints, num_neighbors]
       (y_source_neighbors, x_source_neighbors,
        valid_sources) = ta_utils.get_surrounding_grids(
-           height // self._stride, width // self._stride,
+           tf.cast(tf.maximum(height // self._stride, 1), tf.float32),
+           tf.cast(tf.maximum(width // self._stride, 1), tf.float32),
            tf.keras.backend.flatten(y_center_tiled),
            tf.keras.backend.flatten(x_center_tiled), self._peak_radius)
@@ -2023,8 +2030,8 @@ class CenterNetMaskTargetAssigner(object):
     _, input_height, input_width = (
         shape_utils.combined_static_and_dynamic_shape(gt_masks_list[0]))
-    output_height = input_height // self._stride
-    output_width = input_width // self._stride
+    output_height = tf.maximum(input_height // self._stride, 1)
+    output_width = tf.maximum(input_width // self._stride, 1)
     segmentation_targets_list = []
     for gt_masks, gt_classes in zip(gt_masks_list, gt_classes_list):
@@ -2114,7 +2121,9 @@ class CenterNetDensePoseTargetAssigner(object):
       part_ids_one_hot = tf.one_hot(part_ids_flattened, depth=self._num_parts)
       # Get DensePose coordinates in the output space.
       surface_coords_abs = densepose_ops.to_absolute_coordinates(
-          surface_coords, height // self._stride, width // self._stride)
+          surface_coords,
+          tf.maximum(height // self._stride, 1),
+          tf.maximum(width // self._stride, 1))
       surface_coords_abs = tf.reshape(surface_coords_abs, [-1, 4])
       # Each tensor has shape [num_boxes * max_sampled_points].
       yabs, xabs, v, u = tf.unstack(surface_coords_abs, axis=-1)
@@ -2213,9 +2222,10 @@ class CenterNetTrackTargetAssigner(object):
     for i, (boxes, weights) in enumerate(zip(gt_boxes_list, gt_weights_list)):
       boxes = box_list.BoxList(boxes)
-      boxes = box_list_ops.to_absolute_coordinates(boxes,
-                                                   height // self._stride,
-                                                   width // self._stride)
+      boxes = box_list_ops.to_absolute_coordinates(
+          boxes,
+          tf.maximum(height // self._stride, 1),
+          tf.maximum(width // self._stride, 1))
       # Get the box center coordinates. Each returned tensor has the shape of
       # [num_boxes]
       (y_center, x_center, _, _) = boxes.get_center_coordinates_and_sizes()
......@@ -2318,8 +2328,8 @@ class CenterNetCornerOffsetTargetAssigner(object):
"""
_, input_height, input_width = (
shape_utils.combined_static_and_dynamic_shape(gt_masks_list[0]))
output_height = input_height // self._stride
output_width = input_width // self._stride
output_height = tf.maximum(input_height // self._stride, 1)
output_width = tf.maximum(input_width // self._stride, 1)
y_grid, x_grid = tf.meshgrid(
tf.range(output_height), tf.range(output_width),
indexing='ij')
@@ -2332,6 +2342,8 @@ class CenterNetCornerOffsetTargetAssigner(object):
         method=ResizeMethod.NEAREST_NEIGHBOR)
     gt_masks = filter_mask_overlap(gt_masks, self._overlap_resolution)
 
+    output_height = tf.cast(output_height, tf.float32)
+    output_width = tf.cast(output_width, tf.float32)
     ymin, xmin, ymax, xmax = tf.unstack(gt_boxes, axis=1)
     ymin, ymax = ymin * output_height, ymax * output_height
     xmin, xmax = xmin * output_width, xmax * output_width
@@ -2427,9 +2439,10 @@ class CenterNetTemporalOffsetTargetAssigner(object):
     for i, (boxes, offsets, match_flags, weights) in enumerate(zip(
         gt_boxes_list, gt_offsets_list, gt_match_list, gt_weights_list)):
       boxes = box_list.BoxList(boxes)
-      boxes = box_list_ops.to_absolute_coordinates(boxes,
-                                                   height // self._stride,
-                                                   width // self._stride)
+      boxes = box_list_ops.to_absolute_coordinates(
+          boxes,
+          tf.maximum(height // self._stride, 1),
+          tf.maximum(width // self._stride, 1))
       # Get the box center coordinates. Each returned tensor has the shape of
       # [num_boxes]
       (y_center, x_center, _, _) = boxes.get_center_coordinates_and_sizes()
@@ -137,7 +137,8 @@ class CenterNetFeatureExtractor(tf.keras.Model):
 def make_prediction_net(num_out_channels, kernel_sizes=(3), num_filters=(256),
-                        bias_fill=None, use_depthwise=False, name=None):
+                        bias_fill=None, use_depthwise=False, name=None,
+                        unit_height_conv=True):
   """Creates a network to predict the given number of output channels.
 
   This function is intended to make the prediction heads for the CenterNet
@@ -157,6 +158,7 @@ def make_prediction_net(num_out_channels, kernel_sizes=(3), num_filters=(256),
     use_depthwise: If true, use SeparableConv2D to construct the Sequential
       layers instead of Conv2D.
     name: Optional name for the prediction net.
+    unit_height_conv: If True, Conv2Ds have asymmetric kernels with height=1.
 
   Returns:
     net: A keras module which when called on an input tensor of size
@@ -189,7 +191,7 @@ def make_prediction_net(num_out_channels, kernel_sizes=(3), num_filters=(256),
     layers.append(
         conv_fn(
             num_filter,
-            kernel_size=kernel_size,
+            kernel_size=[1, kernel_size] if unit_height_conv else kernel_size,
            padding='same',
            name='conv2_%d' % idx if tf_version.is_tf1() else None))
     layers.append(tf.keras.layers.ReLU())
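On a height-1 feature map a square kernel would pad and convolve over a dimension that carries no information, so `unit_height_conv` swaps in a `[1, kernel_size]` kernel and the head convolves along width only. A small sketch of the resulting layer behavior (hypothetical shapes, assuming TF2 eager mode; not part of the commit):

```python
import tensorflow as tf

# With unit_height_conv=True the head uses a [1, 3] kernel, so 'same'
# padding never pads the height-1 dimension with zeros.
conv = tf.keras.layers.Conv2D(256, kernel_size=[1, 3], padding='same')
features = tf.zeros([2, 1, 32, 16])  # [batch, height=1, width, channels]
print(conv(features).shape)          # (2, 1, 32, 256)
```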
@@ -2174,7 +2176,8 @@ class CenterNetMetaArch(model.DetectionModel):
                temporal_offset_params=None,
                use_depthwise=False,
                compute_heatmap_sparse=False,
-               non_max_suppression_fn=None):
+               non_max_suppression_fn=None,
+               unit_height_conv=False):
     """Initializes a CenterNet model.
 
     Args:
@@ -2218,6 +2221,8 @@ class CenterNetMetaArch(model.DetectionModel):
         better with number of channels in the heatmap, but in some cases is
         known to cause an OOM error. See b/170989061.
       non_max_suppression_fn: Optional Non Max Suppression function to apply.
+      unit_height_conv: If True, Conv2Ds in prediction heads have asymmetric
+        kernels with height=1.
     """
     assert object_detection_params or keypoint_params_dict
 
     # Shorten the name for convenience and better formatting.
@@ -2244,11 +2249,15 @@ class CenterNetMetaArch(model.DetectionModel):
     self._use_depthwise = use_depthwise
     self._compute_heatmap_sparse = compute_heatmap_sparse
 
+    # Subclasses may not implement the unit_height_conv arg, so only provide
+    # it as a kwarg if it is True.
+    kwargs = {'unit_height_conv': unit_height_conv} if unit_height_conv else {}
+
     # Construct the prediction head nets.
     self._prediction_head_dict = self._construct_prediction_heads(
         num_classes,
         self._num_feature_outputs,
-        class_prediction_bias_init=self._center_params.heatmap_bias_init)
+        class_prediction_bias_init=self._center_params.heatmap_bias_init,
+        **kwargs)
     # Initialize the target assigners.
     self._target_assigner_dict = self._initialize_target_assigners(
         stride=self._stride,
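The kwargs guard in the hunk above is a backwards-compatibility measure: subclasses that override `_construct_prediction_heads` with the old signature keep working as long as `unit_height_conv` stays False, because the new kwarg is only forwarded when it is True. A minimal sketch of the pattern (hypothetical names, not from the patch):

```python
def old_override(num_classes, num_feature_outputs,
                 class_prediction_bias_init):
  """An override written before unit_height_conv existed."""
  return {}

unit_height_conv = False
# Forward the new kwarg only when it is True, so the old signature above
# remains callable in the default case.
kwargs = {'unit_height_conv': unit_height_conv} if unit_height_conv else {}
heads = old_override(3, 1, class_prediction_bias_init=None, **kwargs)
```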
......@@ -2269,7 +2278,8 @@ class CenterNetMetaArch(model.DetectionModel):
def _make_prediction_net_list(self, num_feature_outputs, num_out_channels,
kernel_sizes=(3), num_filters=(256),
bias_fill=None, name=None):
bias_fill=None, name=None,
unit_height_conv=False):
prediction_net_list = []
for i in range(num_feature_outputs):
prediction_net_list.append(
@@ -2279,11 +2289,13 @@ class CenterNetMetaArch(model.DetectionModel):
               num_filters=num_filters,
               bias_fill=bias_fill,
               use_depthwise=self._use_depthwise,
-              name='{}_{}'.format(name, i) if name else name))
+              name='{}_{}'.format(name, i) if name else name,
+              unit_height_conv=unit_height_conv))
     return prediction_net_list
 
   def _construct_prediction_heads(self, num_classes, num_feature_outputs,
-                                  class_prediction_bias_init):
+                                  class_prediction_bias_init,
+                                  unit_height_conv=False):
     """Constructs the prediction heads based on the specific parameters.
 
     Args:
@@ -2295,6 +2307,7 @@ class CenterNetMetaArch(model.DetectionModel):
       class_prediction_bias_init: float, the initial value of bias in the
         convolutional kernel of the class prediction head. If set to None, the
         bias is initialized with zeros.
+      unit_height_conv: If True, Conv2Ds have asymmetric kernels with height=1.
 
     Returns:
       A dictionary of keras modules generated by calling make_prediction_net
@@ -2308,13 +2321,16 @@ class CenterNetMetaArch(model.DetectionModel):
         kernel_sizes=self._center_params.center_head_kernel_sizes,
         num_filters=self._center_params.center_head_num_filters,
         bias_fill=class_prediction_bias_init,
-        name='center')
+        name='center',
+        unit_height_conv=unit_height_conv)
     if self._od_params is not None:
       prediction_heads[BOX_SCALE] = self._make_prediction_net_list(
-          num_feature_outputs, NUM_SIZE_CHANNELS, name='box_scale')
+          num_feature_outputs, NUM_SIZE_CHANNELS, name='box_scale',
+          unit_height_conv=unit_height_conv)
 
       prediction_heads[BOX_OFFSET] = self._make_prediction_net_list(
-          num_feature_outputs, NUM_OFFSET_CHANNELS, name='box_offset')
+          num_feature_outputs, NUM_OFFSET_CHANNELS, name='box_offset',
+          unit_height_conv=unit_height_conv)
 
     if self._kp_params_dict is not None:
       for task_name, kp_params in self._kp_params_dict.items():
@@ -2326,14 +2342,16 @@ class CenterNetMetaArch(model.DetectionModel):
                 kernel_sizes=kp_params.heatmap_head_kernel_sizes,
                 num_filters=kp_params.heatmap_head_num_filters,
                 bias_fill=kp_params.heatmap_bias_init,
-                name='kpt_heatmap')
+                name='kpt_heatmap',
+                unit_height_conv=unit_height_conv)
         prediction_heads[get_keypoint_name(
             task_name, KEYPOINT_REGRESSION)] = self._make_prediction_net_list(
                 num_feature_outputs,
                 NUM_OFFSET_CHANNELS * num_keypoints,
                 kernel_sizes=kp_params.regress_head_kernel_sizes,
                 num_filters=kp_params.regress_head_num_filters,
-                name='kpt_regress')
+                name='kpt_regress',
+                unit_height_conv=unit_height_conv)
         if kp_params.per_keypoint_offset:
           prediction_heads[get_keypoint_name(
@@ -2342,7 +2360,8 @@ class CenterNetMetaArch(model.DetectionModel):
                 NUM_OFFSET_CHANNELS * num_keypoints,
                 kernel_sizes=kp_params.offset_head_kernel_sizes,
                 num_filters=kp_params.offset_head_num_filters,
-                name='kpt_offset')
+                name='kpt_offset',
+                unit_height_conv=unit_height_conv)
         else:
           prediction_heads[get_keypoint_name(
               task_name, KEYPOINT_OFFSET)] = self._make_prediction_net_list(
@@ -2350,38 +2369,44 @@ class CenterNetMetaArch(model.DetectionModel):
                 NUM_OFFSET_CHANNELS,
                 kernel_sizes=kp_params.offset_head_kernel_sizes,
                 num_filters=kp_params.offset_head_num_filters,
-                name='kpt_offset')
+                name='kpt_offset',
+                unit_height_conv=unit_height_conv)
         if kp_params.predict_depth:
           num_depth_channel = (
               num_keypoints if kp_params.per_keypoint_depth else 1)
           prediction_heads[get_keypoint_name(
               task_name, KEYPOINT_DEPTH)] = self._make_prediction_net_list(
-                  num_feature_outputs, num_depth_channel, name='kpt_depth')
+                  num_feature_outputs, num_depth_channel, name='kpt_depth',
+                  unit_height_conv=unit_height_conv)
 
     if self._mask_params is not None:
       prediction_heads[SEGMENTATION_HEATMAP] = self._make_prediction_net_list(
          num_feature_outputs,
          num_classes,
          bias_fill=self._mask_params.heatmap_bias_init,
-         name='seg_heatmap')
+         name='seg_heatmap',
+         unit_height_conv=unit_height_conv)
 
     if self._densepose_params is not None:
       prediction_heads[DENSEPOSE_HEATMAP] = self._make_prediction_net_list(
          num_feature_outputs,
          self._densepose_params.num_parts,
          bias_fill=self._densepose_params.heatmap_bias_init,
-         name='dense_pose_heatmap')
+         name='dense_pose_heatmap',
+         unit_height_conv=unit_height_conv)
       prediction_heads[DENSEPOSE_REGRESSION] = self._make_prediction_net_list(
          num_feature_outputs,
          2 * self._densepose_params.num_parts,
-         name='dense_pose_regress')
+         name='dense_pose_regress',
+         unit_height_conv=unit_height_conv)
 
     if self._track_params is not None:
       prediction_heads[TRACK_REID] = self._make_prediction_net_list(
          num_feature_outputs,
          self._track_params.reid_embed_size,
-         name='track_reid')
+         name='track_reid',
+         unit_height_conv=unit_height_conv)
 
       # Creates a classification network to train object embeddings by learning
       # a projection from embedding space to object track ID space.
@@ -2400,7 +2425,8 @@ class CenterNetMetaArch(model.DetectionModel):
                   self._track_params.reid_embed_size,)))
     if self._temporal_offset_params is not None:
       prediction_heads[TEMPORAL_OFFSET] = self._make_prediction_net_list(
-          num_feature_outputs, NUM_OFFSET_CHANNELS, name='temporal_offset')
+          num_feature_outputs, NUM_OFFSET_CHANNELS, name='temporal_offset',
+          unit_height_conv=unit_height_conv)
     return prediction_heads
 
   def _initialize_target_assigners(self, stride, min_box_overlap_iou):
@@ -3357,8 +3383,8 @@ class CenterNetMetaArch(model.DetectionModel):
     _, input_height, input_width, _ = _get_shape(
         prediction_dict['preprocessed_inputs'], 4)
-    output_height, output_width = (input_height // self._stride,
-                                   input_width // self._stride)
+    output_height, output_width = (tf.maximum(input_height // self._stride, 1),
+                                   tf.maximum(input_width // self._stride, 1))
 
     # TODO(vighneshb) Explore whether using floor here is safe.
     output_true_image_shapes = tf.ceil(
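The clamp also protects the postprocessing path, where the valid true-image extent is scaled onto the strided output map with a ceiling so that a height-1 region still covers at least one output row. A sketch of that arithmetic (illustrative values, not taken from the patch):

```python
import tensorflow as tf

stride = 4
true_heights = tf.constant([1., 24.])
true_widths = tf.constant([32., 16.])
# Ceiling division: a 1-pixel-tall valid region still occupies one row of
# the stride-4 output map.
print(tf.math.ceil(true_heights / stride).numpy())  # [1. 6.]
print(tf.math.ceil(true_widths / stride).numpy())   # [8. 4.]
```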
@@ -2995,6 +2995,162 @@ class CenterNetFeatureExtractorTest(test_case.TestCase):
     self.assertAllClose(output[..., 2], 3 * np.ones((2, 32, 32)))
 
 
+class Dummy1dFeatureExtractor(cnma.CenterNetFeatureExtractor):
+  """Returns a static tensor."""
+
+  def __init__(self, tensor, out_stride=1, channel_means=(0., 0., 0.),
+               channel_stds=(1., 1., 1.), bgr_ordering=False):
+    """Initializes the feature extractor.
+
+    Args:
+      tensor: The tensor to return as the processed feature.
+      out_stride: The out_stride to return if asked.
+      channel_means: Ignored, but provided for API compatibility.
+      channel_stds: Ignored, but provided for API compatibility.
+      bgr_ordering: Ignored, but provided for API compatibility.
+    """
+    super().__init__(
+        channel_means=channel_means, channel_stds=channel_stds,
+        bgr_ordering=bgr_ordering)
+    self._tensor = tensor
+    self._out_stride = out_stride
+
+  def call(self, inputs):
+    return [self._tensor]
+
+  @property
+  def out_stride(self):
+    """The stride in the output image of the network."""
+    return self._out_stride
+
+  @property
+  def num_feature_outputs(self):
+    """The number of feature outputs returned by the feature extractor."""
+    return 1
+
+  @property
+  def supported_sub_model_types(self):
+    return ['detection']
+
+  def get_sub_model(self, sub_model_type):
+    if sub_model_type == 'detection':
+      return self._network
+    else:
+      raise ValueError(
+          'Sub model type "{}" not supported.'.format(sub_model_type))
+
+
+@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
+class CenterNetMetaArch1dTest(test_case.TestCase, parameterized.TestCase):
+
+  @parameterized.parameters([1, 2])
+  def test_outputs_with_correct_shape(self, stride):
+    # The 1D case reuses code from the 2D cases. These tests only check that
+    # the output shapes are correct, and rely on other tests for correctness.
+    batch_size = 2
+    height = 1
+    width = 32
+    channels = 16
+    unstrided_inputs = np.random.randn(
+        batch_size, height, width, channels)
+    fixed_output_features = np.random.randn(
+        batch_size, height, width // stride, channels)
+    max_boxes = 10
+    num_classes = 3
+    feature_extractor = Dummy1dFeatureExtractor(fixed_output_features, stride)
+    arch = cnma.CenterNetMetaArch(
+        is_training=True,
+        add_summaries=True,
+        num_classes=num_classes,
+        feature_extractor=feature_extractor,
+        image_resizer_fn=None,
+        object_center_params=cnma.ObjectCenterParams(
+            classification_loss=losses.PenaltyReducedLogisticFocalLoss(),
+            object_center_loss_weight=1.0,
+            max_box_predictions=max_boxes,
+        ),
+        object_detection_params=cnma.ObjectDetectionParams(
+            localization_loss=losses.L1LocalizationLoss(),
+            scale_loss_weight=1.0,
+            offset_loss_weight=1.0,
+        ),
+        keypoint_params_dict=None,
+        mask_params=None,
+        densepose_params=None,
+        track_params=None,
+        temporal_offset_params=None,
+        use_depthwise=False,
+        compute_heatmap_sparse=False,
+        non_max_suppression_fn=None,
+        unit_height_conv=True)
+    arch.provide_groundtruth(
+        groundtruth_boxes_list=[
+            tf.constant([[0, 0.5, 1.0, 0.75],
+                         [0, 0.1, 1.0, 0.25]], tf.float32),
+            tf.constant([[0, 0, 1.0, 1.0],
+                         [0, 0, 0.0, 0.0]], tf.float32)
+        ],
+        groundtruth_classes_list=[
+            tf.constant([[0, 0, 1],
+                         [0, 1, 0]], tf.float32),
+            tf.constant([[1, 0, 0],
+                         [0, 0, 0]], tf.float32)
+        ],
+        groundtruth_weights_list=[
+            tf.constant([1.0, 1.0]),
+            tf.constant([1.0, 0.0])]
+    )
+
+    predictions = arch.predict(None, None)  # Input is hardcoded above.
+    predictions['preprocessed_inputs'] = tf.constant(unstrided_inputs)
+    true_shapes = tf.constant([[1, 32, 16], [1, 24, 16]], tf.int32)
+    postprocess_output = arch.postprocess(predictions, true_shapes)
+    losses_output = arch.loss(predictions, true_shapes)
+
+    self.assertIn('%s/%s' % (cnma.LOSS_KEY_PREFIX, cnma.OBJECT_CENTER),
+                  losses_output)
+    self.assertEqual((), losses_output['%s/%s' % (
+        cnma.LOSS_KEY_PREFIX, cnma.OBJECT_CENTER)].shape)
+    self.assertIn('%s/%s' % (cnma.LOSS_KEY_PREFIX, cnma.BOX_SCALE),
+                  losses_output)
+    self.assertEqual((), losses_output['%s/%s' % (
+        cnma.LOSS_KEY_PREFIX, cnma.BOX_SCALE)].shape)
+    self.assertIn('%s/%s' % (cnma.LOSS_KEY_PREFIX, cnma.BOX_OFFSET),
+                  losses_output)
+    self.assertEqual((), losses_output['%s/%s' % (
+        cnma.LOSS_KEY_PREFIX, cnma.BOX_OFFSET)].shape)
+
+    self.assertIn('detection_scores', postprocess_output)
+    self.assertEqual(postprocess_output['detection_scores'].shape,
+                     (batch_size, max_boxes))
+    self.assertIn('detection_multiclass_scores', postprocess_output)
+    self.assertEqual(postprocess_output['detection_multiclass_scores'].shape,
+                     (batch_size, max_boxes, num_classes))
+    self.assertIn('detection_classes', postprocess_output)
+    self.assertEqual(postprocess_output['detection_classes'].shape,
+                     (batch_size, max_boxes))
+    self.assertIn('num_detections', postprocess_output)
+    self.assertEqual(postprocess_output['num_detections'].shape,
+                     (batch_size,))
+    self.assertIn('detection_boxes', postprocess_output)
+    self.assertEqual(postprocess_output['detection_boxes'].shape,
+                     (batch_size, max_boxes, 4))
+    self.assertIn('detection_boxes_strided', postprocess_output)
+    self.assertEqual(postprocess_output['detection_boxes_strided'].shape,
+                     (batch_size, max_boxes, 4))
+
+    self.assertIn(cnma.OBJECT_CENTER, predictions)
+    self.assertEqual(predictions[cnma.OBJECT_CENTER][0].shape,
+                     (batch_size, height, width // stride, num_classes))
+    self.assertIn(cnma.BOX_SCALE, predictions)
+    self.assertEqual(predictions[cnma.BOX_SCALE][0].shape,
+                     (batch_size, height, width // stride, 2))
+    self.assertIn(cnma.BOX_OFFSET, predictions)
+    self.assertEqual(predictions[cnma.BOX_OFFSET][0].shape,
+                     (batch_size, height, width // stride, 2))
+    self.assertIn('preprocessed_inputs', predictions)
+
+
 if __name__ == '__main__':
   tf.enable_v2_behavior()
   tf.test.main()