"docs/git@developer.sourcefind.cn:Wenxuan/LightX2V.git" did not exist on "1d166e211ceb3221cde9698f107e58596e6aeab8"
Commit e8a80796 authored by A. Unique TensorFlower, committed by TF Object Detection Team

Support 1D detection for CenterNet, i.e. inputs with height==1.

PiperOrigin-RevId: 371037598
parent 68411471
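
The recurring change below wraps every `dimension // stride` computation in `tf.maximum(..., 1)`. A minimal sketch (not part of the commit, plain TensorFlow only) of why the guard matters once the input height is 1:

```python
import tensorflow as tf

# With a height-1 input and any stride > 1, floor division collapses the
# output height to 0, so heatmap grids and absolute box coordinates would
# be built against an empty dimension.
height, width, stride = 1, 32, 4

print(height // stride, width // stride)  # 0 8  -> empty output grid

out_height = tf.maximum(height // stride, 1)  # clamps to 1
out_width = tf.maximum(width // stride, 1)    # unchanged when width > stride
print(int(out_height), int(out_width))        # 1 8
```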
@@ -985,8 +985,8 @@ class CenterNetCenterHeatmapTargetAssigner(object):
        the stride specified during initialization.
      """
-     out_height = tf.cast(height // self._stride, tf.float32)
-     out_width = tf.cast(width // self._stride, tf.float32)
+     out_height = tf.cast(tf.maximum(height // self._stride, 1), tf.float32)
+     out_width = tf.cast(tf.maximum(width // self._stride, 1), tf.float32)
      # Compute the yx-grid to be used to generate the heatmap. Each returned
      # tensor has shape of [out_height, out_width]
      (y_grid, x_grid) = ta_utils.image_shape_to_grids(out_height, out_width)
@@ -999,9 +999,10 @@ class CenterNetCenterHeatmapTargetAssigner(object):
          gt_weights_list):
        boxes = box_list.BoxList(boxes)
        # Convert the box coordinates to absolute output image dimension space.
-       boxes = box_list_ops.to_absolute_coordinates(boxes,
-                                                    height // self._stride,
-                                                    width // self._stride)
+       boxes = box_list_ops.to_absolute_coordinates(
+           boxes,
+           tf.maximum(height // self._stride, 1),
+           tf.maximum(width // self._stride, 1))
        # Get the box center coordinates. Each returned tensor has the shape of
        # [num_instances]
        (y_center, x_center, boxes_height,
@@ -1062,8 +1063,8 @@ class CenterNetCenterHeatmapTargetAssigner(object):
      assert (self._keypoint_weights_for_center is not None and
              self._keypoint_class_id is not None and
              self._keypoint_indices is not None)
-     out_height = tf.cast(height // self._stride, tf.float32)
-     out_width = tf.cast(width // self._stride, tf.float32)
+     out_height = tf.cast(tf.maximum(height // self._stride, 1), tf.float32)
+     out_width = tf.cast(tf.maximum(width // self._stride, 1), tf.float32)
      # Compute the yx-grid to be used to generate the heatmap. Each returned
      # tensor has shape of [out_height, out_width]
      (y_grid, x_grid) = ta_utils.image_shape_to_grids(out_height, out_width)
@@ -1230,9 +1231,10 @@ class CenterNetBoxTargetAssigner(object):
      for i, (boxes, weights) in enumerate(zip(gt_boxes_list, gt_weights_list)):
        boxes = box_list.BoxList(boxes)
-       boxes = box_list_ops.to_absolute_coordinates(boxes,
-                                                    height // self._stride,
-                                                    width // self._stride)
+       boxes = box_list_ops.to_absolute_coordinates(
+           boxes,
+           tf.maximum(height // self._stride, 1),
+           tf.maximum(width // self._stride, 1))
        # Get the box center coordinates. Each returned tensor has the shape of
        # [num_boxes]
        (y_center, x_center, boxes_height,
@@ -1410,8 +1412,8 @@ class CenterNetKeypointTargetAssigner(object):
        output_width] where all values within the regions of the blackout boxes
        are 0.0 and 1.0 elsewhere.
      """
-     out_width = tf.cast(width // self._stride, tf.float32)
-     out_height = tf.cast(height // self._stride, tf.float32)
+     out_width = tf.cast(tf.maximum(width // self._stride, 1), tf.float32)
+     out_height = tf.cast(tf.maximum(height // self._stride, 1), tf.float32)
      # Compute the yx-grid to be used to generate the heatmap. Each returned
      # tensor has shape of [out_height, out_width]
      y_grid, x_grid = ta_utils.image_shape_to_grids(out_height, out_width)
@@ -1464,9 +1466,10 @@ class CenterNetKeypointTargetAssigner(object):
      if boxes is not None:
        boxes = box_list.BoxList(boxes)
        # Convert the box coordinates to absolute output image dimension space.
-       boxes = box_list_ops.to_absolute_coordinates(boxes,
-                                                    height // self._stride,
-                                                    width // self._stride)
+       boxes = box_list_ops.to_absolute_coordinates(
+           boxes,
+           tf.maximum(height // self._stride, 1),
+           tf.maximum(width // self._stride, 1))
        # Get the box height and width. Each returned tensor has the shape
        # of [num_instances]
        (_, _, boxes_height,
@@ -1586,8 +1589,8 @@ class CenterNetKeypointTargetAssigner(object):
          zip(gt_keypoints_list, gt_classes_list, gt_keypoints_weights_list,
              gt_weights_list)):
        keypoints_absolute, kp_weights = _preprocess_keypoints_and_weights(
-           out_height=height // self._stride,
-           out_width=width // self._stride,
+           out_height=tf.maximum(height // self._stride, 1),
+           out_width=tf.maximum(width // self._stride, 1),
            keypoints=keypoints,
            class_onehot=classes,
            class_weights=weights,
@@ -1604,10 +1607,11 @@ class CenterNetKeypointTargetAssigner(object):
        # All keypoint coordinates and their neighbors:
        # [num_instance * num_keypoints, num_neighbors]
        (y_source_neighbors, x_source_neighbors,
-        valid_sources) = ta_utils.get_surrounding_grids(height // self._stride,
-                                                        width // self._stride,
-                                                        y_source, x_source,
-                                                        self._peak_radius)
+        valid_sources) = ta_utils.get_surrounding_grids(
+            tf.cast(tf.maximum(height // self._stride, 1), tf.float32),
+            tf.cast(tf.maximum(width // self._stride, 1), tf.float32),
+            y_source, x_source,
+            self._peak_radius)
        _, num_neighbors = shape_utils.combined_static_and_dynamic_shape(
            y_source_neighbors)
@@ -1722,8 +1726,8 @@ class CenterNetKeypointTargetAssigner(object):
              gt_keypoints_weights_list, gt_weights_list,
              gt_keypoint_depths_list, gt_keypoint_depth_weights_list)):
        keypoints_absolute, kp_weights = _preprocess_keypoints_and_weights(
-           out_height=height // self._stride,
-           out_width=width // self._stride,
+           out_height=tf.maximum(height // self._stride, 1),
+           out_width=tf.maximum(width // self._stride, 1),
            keypoints=keypoints,
            class_onehot=classes,
            class_weights=weights,
@@ -1740,10 +1744,11 @@ class CenterNetKeypointTargetAssigner(object):
        # All keypoint coordinates and their neighbors:
        # [num_instance * num_keypoints, num_neighbors]
        (y_source_neighbors, x_source_neighbors,
-        valid_sources) = ta_utils.get_surrounding_grids(height // self._stride,
-                                                        width // self._stride,
-                                                        y_source, x_source,
-                                                        self._peak_radius)
+        valid_sources) = ta_utils.get_surrounding_grids(
+            tf.cast(tf.maximum(height // self._stride, 1), tf.float32),
+            tf.cast(tf.maximum(width // self._stride, 1), tf.float32),
+            y_source, x_source,
+            self._peak_radius)
        _, num_neighbors = shape_utils.combined_static_and_dynamic_shape(
            y_source_neighbors)
@@ -1894,8 +1899,8 @@ class CenterNetKeypointTargetAssigner(object):
          zip(gt_keypoints_list, gt_classes_list,
              gt_boxes_list, gt_keypoints_weights_list, gt_weights_list)):
        keypoints_absolute, kp_weights = _preprocess_keypoints_and_weights(
-           out_height=height // self._stride,
-           out_width=width // self._stride,
+           out_height=tf.maximum(height // self._stride, 1),
+           out_width=tf.maximum(width // self._stride, 1),
            keypoints=keypoints,
            class_onehot=classes,
            class_weights=weights,
@@ -1909,9 +1914,10 @@ class CenterNetKeypointTargetAssigner(object):
      if boxes is not None:
        # Compute joint center from boxes.
        boxes = box_list.BoxList(boxes)
-       boxes = box_list_ops.to_absolute_coordinates(boxes,
-                                                    height // self._stride,
-                                                    width // self._stride)
+       boxes = box_list_ops.to_absolute_coordinates(
+           boxes,
+           tf.maximum(height // self._stride, 1),
+           tf.maximum(width // self._stride, 1))
        y_center, x_center, _, _ = boxes.get_center_coordinates_and_sizes()
      else:
        # TODO(yuhuic): Add the logic to generate object centers from keypoints.
@@ -1930,7 +1936,8 @@ class CenterNetKeypointTargetAssigner(object):
      # [num_instance * num_keypoints, num_neighbors]
      (y_source_neighbors, x_source_neighbors,
       valid_sources) = ta_utils.get_surrounding_grids(
-         height // self._stride, width // self._stride,
+         tf.cast(tf.maximum(height // self._stride, 1), tf.float32),
+         tf.cast(tf.maximum(width // self._stride, 1), tf.float32),
          tf.keras.backend.flatten(y_center_tiled),
          tf.keras.backend.flatten(x_center_tiled), self._peak_radius)
@@ -2023,8 +2030,8 @@ class CenterNetMaskTargetAssigner(object):
      _, input_height, input_width = (
          shape_utils.combined_static_and_dynamic_shape(gt_masks_list[0]))
-     output_height = input_height // self._stride
-     output_width = input_width // self._stride
+     output_height = tf.maximum(input_height // self._stride, 1)
+     output_width = tf.maximum(input_width // self._stride, 1)
      segmentation_targets_list = []
      for gt_masks, gt_classes in zip(gt_masks_list, gt_classes_list):
@@ -2114,7 +2121,9 @@ class CenterNetDensePoseTargetAssigner(object):
      part_ids_one_hot = tf.one_hot(part_ids_flattened, depth=self._num_parts)
      # Get DensePose coordinates in the output space.
      surface_coords_abs = densepose_ops.to_absolute_coordinates(
-         surface_coords, height // self._stride, width // self._stride)
+         surface_coords,
+         tf.maximum(height // self._stride, 1),
+         tf.maximum(width // self._stride, 1))
      surface_coords_abs = tf.reshape(surface_coords_abs, [-1, 4])
      # Each tensor has shape [num_boxes * max_sampled_points].
      yabs, xabs, v, u = tf.unstack(surface_coords_abs, axis=-1)
@@ -2213,9 +2222,10 @@ class CenterNetTrackTargetAssigner(object):
      for i, (boxes, weights) in enumerate(zip(gt_boxes_list, gt_weights_list)):
        boxes = box_list.BoxList(boxes)
-       boxes = box_list_ops.to_absolute_coordinates(boxes,
-                                                    height // self._stride,
-                                                    width // self._stride)
+       boxes = box_list_ops.to_absolute_coordinates(
+           boxes,
+           tf.maximum(height // self._stride, 1),
+           tf.maximum(width // self._stride, 1))
        # Get the box center coordinates. Each returned tensor has the shape of
        # [num_boxes]
        (y_center, x_center, _, _) = boxes.get_center_coordinates_and_sizes()
@@ -2318,8 +2328,8 @@ class CenterNetCornerOffsetTargetAssigner(object):
      """
      _, input_height, input_width = (
          shape_utils.combined_static_and_dynamic_shape(gt_masks_list[0]))
-     output_height = input_height // self._stride
-     output_width = input_width // self._stride
+     output_height = tf.maximum(input_height // self._stride, 1)
+     output_width = tf.maximum(input_width // self._stride, 1)
      y_grid, x_grid = tf.meshgrid(
          tf.range(output_height), tf.range(output_width),
          indexing='ij')
@@ -2332,6 +2342,8 @@ class CenterNetCornerOffsetTargetAssigner(object):
          method=ResizeMethod.NEAREST_NEIGHBOR)
      gt_masks = filter_mask_overlap(gt_masks, self._overlap_resolution)
+     output_height = tf.cast(output_height, tf.float32)
+     output_width = tf.cast(output_width, tf.float32)
      ymin, xmin, ymax, xmax = tf.unstack(gt_boxes, axis=1)
      ymin, ymax = ymin * output_height, ymax * output_height
      xmin, xmax = xmin * output_width, xmax * output_width
@@ -2427,9 +2439,10 @@ class CenterNetTemporalOffsetTargetAssigner(object):
      for i, (boxes, offsets, match_flags, weights) in enumerate(zip(
          gt_boxes_list, gt_offsets_list, gt_match_list, gt_weights_list)):
        boxes = box_list.BoxList(boxes)
-       boxes = box_list_ops.to_absolute_coordinates(boxes,
-                                                    height // self._stride,
-                                                    width // self._stride)
+       boxes = box_list_ops.to_absolute_coordinates(
+           boxes,
+           tf.maximum(height // self._stride, 1),
+           tf.maximum(width // self._stride, 1))
        # Get the box center coordinates. Each returned tensor has the shape of
        # [num_boxes]
        (y_center, x_center, _, _) = boxes.get_center_coordinates_and_sizes()
@@ -137,7 +137,8 @@ class CenterNetFeatureExtractor(tf.keras.Model):
  def make_prediction_net(num_out_channels, kernel_sizes=(3), num_filters=(256),
-                         bias_fill=None, use_depthwise=False, name=None):
+                         bias_fill=None, use_depthwise=False, name=None,
+                         unit_height_conv=True):
    """Creates a network to predict the given number of output channels.

    This function is intended to make the prediction heads for the CenterNet
@@ -157,6 +158,7 @@ def make_prediction_net(num_out_channels, kernel_sizes=(3), num_filters=(256),
      use_depthwise: If true, use SeparableConv2D to construct the Sequential
        layers instead of Conv2D.
      name: Optional name for the prediction net.
+     unit_height_conv: If True, Conv2Ds have asymmetric kernels with height=1.

    Returns:
      net: A keras module which when called on an input tensor of size
@@ -189,7 +191,7 @@ def make_prediction_net(num_out_channels, kernel_sizes=(3), num_filters=(256),
      layers.append(
          conv_fn(
              num_filter,
-             kernel_size=kernel_size,
+             kernel_size=[1, kernel_size] if unit_height_conv else kernel_size,
              padding='same',
              name='conv2_%d' % idx if tf_version.is_tf1() else None))
      layers.append(tf.keras.layers.ReLU())
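
With `unit_height_conv=True`, every head convolution uses a `[1, kernel_size]` kernel, so filters slide only along the width and a height-1 feature map is never mixed with padding rows. A short sketch of the effect, assuming nothing beyond standard Keras layers (shapes are illustrative, not from the commit):

```python
import tensorflow as tf

kernel_size = 3
# Asymmetric kernel: height 1, width kernel_size.
conv = tf.keras.layers.Conv2D(
    filters=256, kernel_size=[1, kernel_size], padding='same')

# [batch, height=1, width, channels], as in the 1D detection case.
features = tf.random.normal([2, 1, 32, 16])
print(conv(features).shape)  # (2, 1, 32, 256)
```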
@@ -2174,7 +2176,8 @@ class CenterNetMetaArch(model.DetectionModel):
                temporal_offset_params=None,
                use_depthwise=False,
                compute_heatmap_sparse=False,
-               non_max_suppression_fn=None):
+               non_max_suppression_fn=None,
+               unit_height_conv=False):
    """Initializes a CenterNet model.

    Args:
@@ -2218,6 +2221,8 @@ class CenterNetMetaArch(model.DetectionModel):
        better with number of channels in the heatmap, but in some cases is
        known to cause an OOM error. See b/170989061.
      non_max_suppression_fn: Optional Non Max Suppression function to apply.
+     unit_height_conv: If True, Conv2Ds in prediction heads have asymmetric
+       kernels with height=1.
    """
    assert object_detection_params or keypoint_params_dict

    # Shorten the name for convenience and better formatting.
@@ -2244,11 +2249,15 @@ class CenterNetMetaArch(model.DetectionModel):
    self._use_depthwise = use_depthwise
    self._compute_heatmap_sparse = compute_heatmap_sparse

+   # subclasses may not implement the unit_height_conv arg, so only provide it
+   # as a kwarg if it is True.
+   kwargs = {'unit_height_conv': unit_height_conv} if unit_height_conv else {}
    # Construct the prediction head nets.
    self._prediction_head_dict = self._construct_prediction_heads(
        num_classes,
        self._num_feature_outputs,
-       class_prediction_bias_init=self._center_params.heatmap_bias_init)
+       class_prediction_bias_init=self._center_params.heatmap_bias_init,
+       **kwargs)
    # Initialize the target assigners.
    self._target_assigner_dict = self._initialize_target_assigners(
        stride=self._stride,
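
The conditional kwargs dict above exists for backward compatibility: subclasses that override `_construct_prediction_heads` without the new argument would fail if `unit_height_conv` were always passed. A hypothetical sketch of the failure mode the pattern avoids (class names invented for illustration):

```python
class Base:
  def _construct_prediction_heads(self, num_classes, unit_height_conv=False):
    return {'classes': num_classes, 'unit_height_conv': unit_height_conv}

class LegacySubclass(Base):
  # An override written before unit_height_conv existed.
  def _construct_prediction_heads(self, num_classes):
    return {'classes': num_classes}

model = LegacySubclass()
unit_height_conv = False
# Passing the kwarg unconditionally would raise TypeError on the legacy
# override; passing it only when True keeps such subclasses working.
kwargs = {'unit_height_conv': unit_height_conv} if unit_height_conv else {}
print(model._construct_prediction_heads(3, **kwargs))  # {'classes': 3}
```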
@@ -2269,7 +2278,8 @@ class CenterNetMetaArch(model.DetectionModel):
  def _make_prediction_net_list(self, num_feature_outputs, num_out_channels,
                                kernel_sizes=(3), num_filters=(256),
-                               bias_fill=None, name=None):
+                               bias_fill=None, name=None,
+                               unit_height_conv=False):
    prediction_net_list = []
    for i in range(num_feature_outputs):
      prediction_net_list.append(
@@ -2279,11 +2289,13 @@ class CenterNetMetaArch(model.DetectionModel):
              num_filters=num_filters,
              bias_fill=bias_fill,
              use_depthwise=self._use_depthwise,
-             name='{}_{}'.format(name, i) if name else name))
+             name='{}_{}'.format(name, i) if name else name,
+             unit_height_conv=unit_height_conv))
    return prediction_net_list

  def _construct_prediction_heads(self, num_classes, num_feature_outputs,
-                                 class_prediction_bias_init):
+                                 class_prediction_bias_init,
+                                 unit_height_conv=False):
    """Constructs the prediction heads based on the specific parameters.

    Args:
@@ -2295,6 +2307,7 @@ class CenterNetMetaArch(model.DetectionModel):
      class_prediction_bias_init: float, the initial value of bias in the
        convolutional kernel of the class prediction head. If set to None, the
        bias is initialized with zeros.
+     unit_height_conv: If True, Conv2Ds have asymmetric kernels with height=1.

    Returns:
      A dictionary of keras modules generated by calling make_prediction_net
@@ -2308,13 +2321,16 @@ class CenterNetMetaArch(model.DetectionModel):
        kernel_sizes=self._center_params.center_head_kernel_sizes,
        num_filters=self._center_params.center_head_num_filters,
        bias_fill=class_prediction_bias_init,
-       name='center')
+       name='center',
+       unit_height_conv=unit_height_conv)
    if self._od_params is not None:
      prediction_heads[BOX_SCALE] = self._make_prediction_net_list(
-         num_feature_outputs, NUM_SIZE_CHANNELS, name='box_scale')
+         num_feature_outputs, NUM_SIZE_CHANNELS, name='box_scale',
+         unit_height_conv=unit_height_conv)
      prediction_heads[BOX_OFFSET] = self._make_prediction_net_list(
-         num_feature_outputs, NUM_OFFSET_CHANNELS, name='box_offset')
+         num_feature_outputs, NUM_OFFSET_CHANNELS, name='box_offset',
+         unit_height_conv=unit_height_conv)
    if self._kp_params_dict is not None:
      for task_name, kp_params in self._kp_params_dict.items():
@@ -2326,14 +2342,16 @@ class CenterNetMetaArch(model.DetectionModel):
          kernel_sizes=kp_params.heatmap_head_kernel_sizes,
          num_filters=kp_params.heatmap_head_num_filters,
          bias_fill=kp_params.heatmap_bias_init,
-         name='kpt_heatmap')
+         name='kpt_heatmap',
+         unit_height_conv=unit_height_conv)
      prediction_heads[get_keypoint_name(
          task_name, KEYPOINT_REGRESSION)] = self._make_prediction_net_list(
              num_feature_outputs,
              NUM_OFFSET_CHANNELS * num_keypoints,
              kernel_sizes=kp_params.regress_head_kernel_sizes,
              num_filters=kp_params.regress_head_num_filters,
-             name='kpt_regress')
+             name='kpt_regress',
+             unit_height_conv=unit_height_conv)
      if kp_params.per_keypoint_offset:
        prediction_heads[get_keypoint_name(
@@ -2342,7 +2360,8 @@ class CenterNetMetaArch(model.DetectionModel):
            NUM_OFFSET_CHANNELS * num_keypoints,
            kernel_sizes=kp_params.offset_head_kernel_sizes,
            num_filters=kp_params.offset_head_num_filters,
-           name='kpt_offset')
+           name='kpt_offset',
+           unit_height_conv=unit_height_conv)
      else:
        prediction_heads[get_keypoint_name(
            task_name, KEYPOINT_OFFSET)] = self._make_prediction_net_list(
@@ -2350,38 +2369,44 @@ class CenterNetMetaArch(model.DetectionModel):
                NUM_OFFSET_CHANNELS,
                kernel_sizes=kp_params.offset_head_kernel_sizes,
                num_filters=kp_params.offset_head_num_filters,
-               name='kpt_offset')
+               name='kpt_offset',
+               unit_height_conv=unit_height_conv)
      if kp_params.predict_depth:
        num_depth_channel = (
            num_keypoints if kp_params.per_keypoint_depth else 1)
        prediction_heads[get_keypoint_name(
            task_name, KEYPOINT_DEPTH)] = self._make_prediction_net_list(
-               num_feature_outputs, num_depth_channel, name='kpt_depth')
+               num_feature_outputs, num_depth_channel, name='kpt_depth',
+               unit_height_conv=unit_height_conv)
    if self._mask_params is not None:
      prediction_heads[SEGMENTATION_HEATMAP] = self._make_prediction_net_list(
          num_feature_outputs,
          num_classes,
          bias_fill=self._mask_params.heatmap_bias_init,
-         name='seg_heatmap')
+         name='seg_heatmap',
+         unit_height_conv=unit_height_conv)
    if self._densepose_params is not None:
      prediction_heads[DENSEPOSE_HEATMAP] = self._make_prediction_net_list(
          num_feature_outputs,
          self._densepose_params.num_parts,
          bias_fill=self._densepose_params.heatmap_bias_init,
-         name='dense_pose_heatmap')
+         name='dense_pose_heatmap',
+         unit_height_conv=unit_height_conv)
      prediction_heads[DENSEPOSE_REGRESSION] = self._make_prediction_net_list(
          num_feature_outputs,
          2 * self._densepose_params.num_parts,
-         name='dense_pose_regress')
+         name='dense_pose_regress',
+         unit_height_conv=unit_height_conv)
    if self._track_params is not None:
      prediction_heads[TRACK_REID] = self._make_prediction_net_list(
          num_feature_outputs,
          self._track_params.reid_embed_size,
-         name='track_reid')
+         name='track_reid',
+         unit_height_conv=unit_height_conv)
      # Creates a classification network to train object embeddings by learning
      # a projection from embedding space to object track ID space.
@@ -2400,7 +2425,8 @@ class CenterNetMetaArch(model.DetectionModel):
                self._track_params.reid_embed_size,)))
    if self._temporal_offset_params is not None:
      prediction_heads[TEMPORAL_OFFSET] = self._make_prediction_net_list(
-         num_feature_outputs, NUM_OFFSET_CHANNELS, name='temporal_offset')
+         num_feature_outputs, NUM_OFFSET_CHANNELS, name='temporal_offset',
+         unit_height_conv=unit_height_conv)
    return prediction_heads

  def _initialize_target_assigners(self, stride, min_box_overlap_iou):
@@ -3357,8 +3383,8 @@ class CenterNetMetaArch(model.DetectionModel):
    _, input_height, input_width, _ = _get_shape(
        prediction_dict['preprocessed_inputs'], 4)
-   output_height, output_width = (input_height // self._stride,
-                                  input_width // self._stride)
+   output_height, output_width = (tf.maximum(input_height // self._stride, 1),
+                                  tf.maximum(input_width // self._stride, 1))

    # TODO(vighneshb) Explore whether using floor here is safe.
    output_true_image_shapes = tf.ceil(
@@ -2995,6 +2995,162 @@ class CenterNetFeatureExtractorTest(test_case.TestCase):
    self.assertAllClose(output[..., 2], 3 * np.ones((2, 32, 32)))

+
+class Dummy1dFeatureExtractor(cnma.CenterNetFeatureExtractor):
+  """Returns a static tensor."""
+
+  def __init__(self, tensor, out_stride=1, channel_means=(0., 0., 0.),
+               channel_stds=(1., 1., 1.), bgr_ordering=False):
+    """Initializes the feature extractor.
+
+    Args:
+      tensor: The tensor to return as the processed feature.
+      out_stride: The out_stride to return if asked.
+      channel_means: Ignored, but provided for API compatibility.
+      channel_stds: Ignored, but provided for API compatibility.
+      bgr_ordering: Ignored, but provided for API compatibility.
+    """
+    super().__init__(
+        channel_means=channel_means, channel_stds=channel_stds,
+        bgr_ordering=bgr_ordering)
+    self._tensor = tensor
+    self._out_stride = out_stride
+
+  def call(self, inputs):
+    return [self._tensor]
+
+  @property
+  def out_stride(self):
+    """The stride in the output image of the network."""
+    return self._out_stride
+
+  @property
+  def num_feature_outputs(self):
+    """The number of feature outputs returned by the feature extractor."""
+    return 1
+
+  @property
+  def supported_sub_model_types(self):
+    return ['detection']
+
+  def get_sub_model(self, sub_model_type):
+    if sub_model_type == 'detection':
+      return self._network
+    else:
+      raise ValueError(
+          'Sub model type "{}" not supported.'.format(sub_model_type))
+
+
+@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
+class CenterNetMetaArch1dTest(test_case.TestCase, parameterized.TestCase):
+
+  @parameterized.parameters([1, 2])
+  def test_outputs_with_correct_shape(self, stride):
+    # The 1D case reuses code from the 2D cases. These tests only check that
+    # the output shapes are correct, and rely on other tests for correctness.
+    batch_size = 2
+    height = 1
+    width = 32
+    channels = 16
+    unstrided_inputs = np.random.randn(
+        batch_size, height, width, channels)
+    fixed_output_features = np.random.randn(
+        batch_size, height, width // stride, channels)
+    max_boxes = 10
+    num_classes = 3
+    feature_extractor = Dummy1dFeatureExtractor(fixed_output_features, stride)
+    arch = cnma.CenterNetMetaArch(
+        is_training=True,
+        add_summaries=True,
+        num_classes=num_classes,
+        feature_extractor=feature_extractor,
+        image_resizer_fn=None,
+        object_center_params=cnma.ObjectCenterParams(
+            classification_loss=losses.PenaltyReducedLogisticFocalLoss(),
+            object_center_loss_weight=1.0,
+            max_box_predictions=max_boxes,
+        ),
+        object_detection_params=cnma.ObjectDetectionParams(
+            localization_loss=losses.L1LocalizationLoss(),
+            scale_loss_weight=1.0,
+            offset_loss_weight=1.0,
+        ),
+        keypoint_params_dict=None,
+        mask_params=None,
+        densepose_params=None,
+        track_params=None,
+        temporal_offset_params=None,
+        use_depthwise=False,
+        compute_heatmap_sparse=False,
+        non_max_suppression_fn=None,
+        unit_height_conv=True)
+    arch.provide_groundtruth(
+        groundtruth_boxes_list=[
+            tf.constant([[0, 0.5, 1.0, 0.75],
+                         [0, 0.1, 1.0, 0.25]], tf.float32),
+            tf.constant([[0, 0, 1.0, 1.0],
+                         [0, 0, 0.0, 0.0]], tf.float32)
+        ],
+        groundtruth_classes_list=[
+            tf.constant([[0, 0, 1],
+                         [0, 1, 0]], tf.float32),
+            tf.constant([[1, 0, 0],
+                         [0, 0, 0]], tf.float32)
+        ],
+        groundtruth_weights_list=[
+            tf.constant([1.0, 1.0]),
+            tf.constant([1.0, 0.0])]
+    )
+
+    predictions = arch.predict(None, None)  # input is hardcoded above.
+    predictions['preprocessed_inputs'] = tf.constant(unstrided_inputs)
+    true_shapes = tf.constant([[1, 32, 16], [1, 24, 16]], tf.int32)
+    postprocess_output = arch.postprocess(predictions, true_shapes)
+    losses_output = arch.loss(predictions, true_shapes)
+
+    self.assertIn('%s/%s' % (cnma.LOSS_KEY_PREFIX, cnma.OBJECT_CENTER),
+                  losses_output)
+    self.assertEqual((), losses_output['%s/%s' % (
+        cnma.LOSS_KEY_PREFIX, cnma.OBJECT_CENTER)].shape)
+    self.assertIn('%s/%s' % (cnma.LOSS_KEY_PREFIX, cnma.BOX_SCALE),
+                  losses_output)
+    self.assertEqual((), losses_output['%s/%s' % (
+        cnma.LOSS_KEY_PREFIX, cnma.BOX_SCALE)].shape)
+    self.assertIn('%s/%s' % (cnma.LOSS_KEY_PREFIX, cnma.BOX_OFFSET),
+                  losses_output)
+    self.assertEqual((), losses_output['%s/%s' % (
+        cnma.LOSS_KEY_PREFIX, cnma.BOX_OFFSET)].shape)
+
+    self.assertIn('detection_scores', postprocess_output)
+    self.assertEqual(postprocess_output['detection_scores'].shape,
+                     (batch_size, max_boxes))
+    self.assertIn('detection_multiclass_scores', postprocess_output)
+    self.assertEqual(postprocess_output['detection_multiclass_scores'].shape,
+                     (batch_size, max_boxes, num_classes))
+    self.assertIn('detection_classes', postprocess_output)
+    self.assertEqual(postprocess_output['detection_classes'].shape,
+                     (batch_size, max_boxes))
+    self.assertIn('num_detections', postprocess_output)
+    self.assertEqual(postprocess_output['num_detections'].shape,
+                     (batch_size,))
+    self.assertIn('detection_boxes', postprocess_output)
+    self.assertEqual(postprocess_output['detection_boxes'].shape,
+                     (batch_size, max_boxes, 4))
+    self.assertIn('detection_boxes_strided', postprocess_output)
+    self.assertEqual(postprocess_output['detection_boxes_strided'].shape,
+                     (batch_size, max_boxes, 4))
+
+    self.assertIn(cnma.OBJECT_CENTER, predictions)
+    self.assertEqual(predictions[cnma.OBJECT_CENTER][0].shape,
+                     (batch_size, height, width // stride, num_classes))
+    self.assertIn(cnma.BOX_SCALE, predictions)
+    self.assertEqual(predictions[cnma.BOX_SCALE][0].shape,
+                     (batch_size, height, width // stride, 2))
+    self.assertIn(cnma.BOX_OFFSET, predictions)
+    self.assertEqual(predictions[cnma.BOX_OFFSET][0].shape,
+                     (batch_size, height, width // stride, 2))
+    self.assertIn('preprocessed_inputs', predictions)
+
+
 if __name__ == '__main__':
   tf.enable_v2_behavior()
   tf.test.main()