ModelZoo / ResNet50_tensorflow

Commit e8a80796, authored Apr 28, 2021 by A. Unique TensorFlower, committed by TF Object Detection Team on Apr 28, 2021.

Support 1D detection for CenterNet as inputs with height==1.

PiperOrigin-RevId: 371037598
Parent: 68411471
Showing 3 changed files with 261 additions and 66 deletions:

research/object_detection/core/target_assigner.py (+57, -44)
research/object_detection/meta_architectures/center_net_meta_arch.py (+48, -22)
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py (+156, -0)
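Every hunk in the target assigner applies the same guard: a strided dimension `dim // stride` is clamped with `tf.maximum(..., 1)` so that a height-1 input (a 1-D signal treated as an image of height 1) does not floor-divide to zero. A minimal sketch of the arithmetic, with illustrative numbers:

    import tensorflow as tf

    height, width, stride = 1, 32, 4             # height-1 input, e.g. a 1-D signal

    print(height // stride)                      # 0: the strided output collapses
    print(int(tf.maximum(height // stride, 1)))  # 1: the clamp keeps one row
    print(int(tf.maximum(width // stride, 1)))   # 8: unaffected when already >= 1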
research/object_detection/core/target_assigner.py

@@ -985,8 +985,8 @@ class CenterNetCenterHeatmapTargetAssigner(object):
       the stride specified during initialization.
     """
-    out_height = tf.cast(height // self._stride, tf.float32)
-    out_width = tf.cast(width // self._stride, tf.float32)
+    out_height = tf.cast(tf.maximum(height // self._stride, 1), tf.float32)
+    out_width = tf.cast(tf.maximum(width // self._stride, 1), tf.float32)
     # Compute the yx-grid to be used to generate the heatmap. Each returned
     # tensor has shape of [out_height, out_width]
     (y_grid, x_grid) = ta_utils.image_shape_to_grids(out_height, out_width)
@@ -999,9 +999,10 @@ class CenterNetCenterHeatmapTargetAssigner(object):
         gt_weights_list):
       boxes = box_list.BoxList(boxes)
       # Convert the box coordinates to absolute output image dimension space.
-      boxes = box_list_ops.to_absolute_coordinates(boxes,
-                                                   height // self._stride,
-                                                   width // self._stride)
+      boxes = box_list_ops.to_absolute_coordinates(
+          boxes,
+          tf.maximum(height // self._stride, 1),
+          tf.maximum(width // self._stride, 1))
       # Get the box center coordinates. Each returned tensors have the shape of
       # [num_instances]
       (y_center, x_center, boxes_height,
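The same clamp is applied wherever normalized coordinates are scaled into the strided output space. A sketch of that scaling with illustrative values (the real `box_list_ops.to_absolute_coordinates` additionally validates coordinate ranges):

    import tensorflow as tf

    # Normalized [ymin, xmin, ymax, xmax] boxes scaled to absolute output
    # coordinates, as to_absolute_coordinates does internally.
    boxes = tf.constant([[0.0, 0.25, 1.0, 0.75]])
    height, width, stride = 1, 32, 4

    scale_y = tf.cast(tf.maximum(height // stride, 1), tf.float32)  # 1.0, not 0.0
    scale_x = tf.cast(tf.maximum(width // stride, 1), tf.float32)   # 8.0

    abs_boxes = boxes * tf.stack([scale_y, scale_x, scale_y, scale_x])
    # Without the clamp, scale_y would be 0 and every y-coordinate (and every
    # box height) would collapse to zero.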
@@ -1062,8 +1063,8 @@ class CenterNetCenterHeatmapTargetAssigner(object):
     assert (self._keypoint_weights_for_center is not None and
             self._keypoint_class_id is not None and
             self._keypoint_indices is not None)
-    out_height = tf.cast(height // self._stride, tf.float32)
-    out_width = tf.cast(width // self._stride, tf.float32)
+    out_height = tf.cast(tf.maximum(height // self._stride, 1), tf.float32)
+    out_width = tf.cast(tf.maximum(width // self._stride, 1), tf.float32)
     # Compute the yx-grid to be used to generate the heatmap. Each returned
     # tensor has shape of [out_height, out_width]
     (y_grid, x_grid) = ta_utils.image_shape_to_grids(out_height, out_width)

@@ -1230,9 +1231,10 @@ class CenterNetBoxTargetAssigner(object):
     for i, (boxes, weights) in enumerate(zip(gt_boxes_list, gt_weights_list)):
       boxes = box_list.BoxList(boxes)
-      boxes = box_list_ops.to_absolute_coordinates(boxes,
-                                                   height // self._stride,
-                                                   width // self._stride)
+      boxes = box_list_ops.to_absolute_coordinates(
+          boxes,
+          tf.maximum(height // self._stride, 1),
+          tf.maximum(width // self._stride, 1))
       # Get the box center coordinates. Each returned tensors have the shape of
       # [num_boxes]
       (y_center, x_center, boxes_height,

@@ -1410,8 +1412,8 @@ class CenterNetKeypointTargetAssigner(object):
         output_width] where all values within the regions of the blackout boxes
         are 0.0 and 1.0 else where.
     """
-    out_width = tf.cast(width // self._stride, tf.float32)
-    out_height = tf.cast(height // self._stride, tf.float32)
+    out_width = tf.cast(tf.maximum(width // self._stride, 1), tf.float32)
+    out_height = tf.cast(tf.maximum(height // self._stride, 1), tf.float32)
     # Compute the yx-grid to be used to generate the heatmap. Each returned
     # tensor has shape of [out_height, out_width]
     y_grid, x_grid = ta_utils.image_shape_to_grids(out_height, out_width)

@@ -1464,9 +1466,10 @@ class CenterNetKeypointTargetAssigner(object):
     if boxes is not None:
       boxes = box_list.BoxList(boxes)
       # Convert the box coordinates to absolute output image dimension space.
-      boxes = box_list_ops.to_absolute_coordinates(boxes,
-                                                   height // self._stride,
-                                                   width // self._stride)
+      boxes = box_list_ops.to_absolute_coordinates(
+          boxes,
+          tf.maximum(height // self._stride, 1),
+          tf.maximum(width // self._stride, 1))
       # Get the box height and width. Each returned tensors have the shape
       # of [num_instances]
       (_, _, boxes_height,

@@ -1586,8 +1589,8 @@ class CenterNetKeypointTargetAssigner(object):
         zip(gt_keypoints_list, gt_classes_list, gt_keypoints_weights_list,
             gt_weights_list)):
       keypoints_absolute, kp_weights = _preprocess_keypoints_and_weights(
-          out_height=height // self._stride,
-          out_width=width // self._stride,
+          out_height=tf.maximum(height // self._stride, 1),
+          out_width=tf.maximum(width // self._stride, 1),
          keypoints=keypoints,
           class_onehot=classes,
           class_weights=weights,

@@ -1604,8 +1607,9 @@ class CenterNetKeypointTargetAssigner(object):
       # All keypoint coordinates and their neighbors:
       # [num_instance * num_keypoints, num_neighbors]
       (y_source_neighbors, x_source_neighbors,
-       valid_sources) = ta_utils.get_surrounding_grids(height // self._stride,
-                                                       width // self._stride,
-                                                       y_source, x_source,
-                                                       self._peak_radius)
+       valid_sources) = ta_utils.get_surrounding_grids(
+           tf.cast(tf.maximum(height // self._stride, 1), tf.float32),
+           tf.cast(tf.maximum(width // self._stride, 1), tf.float32),
+           y_source, x_source,
+           self._peak_radius)
       _, num_neighbors = shape_utils.combined_static_and_dynamic_shape(
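`ta_utils.get_surrounding_grids` expands each keypoint into the grid cells within `peak_radius` and flags which of those cells fall inside the output map; the clamped, float-cast height and width are what bound that validity check. A simplified one-dimensional sketch of the bounding logic (hypothetical helper, not the library code):

    import tensorflow as tf

    def valid_neighbors_1d(x_source, radius, map_width):
      """Sketch: offsets around each source column, flagged valid when in-bounds."""
      offsets = tf.range(-radius, radius + 1, dtype=tf.float32)
      neighbors = x_source[:, tf.newaxis] + offsets[tf.newaxis, :]
      valid = (neighbors >= 0.0) & (neighbors < map_width)
      return neighbors, valid

    # With the clamp, map_width is at least 1.0, so the source cell itself
    # remains a valid neighbor even for a degenerate strided map.
    neighbors, valid = valid_neighbors_1d(tf.constant([0.0, 5.0]), 2, 8.0)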
@@ -1722,8 +1726,8 @@ class CenterNetKeypointTargetAssigner(object):
             gt_keypoints_weights_list, gt_weights_list,
             gt_keypoint_depths_list, gt_keypoint_depth_weights_list)):
       keypoints_absolute, kp_weights = _preprocess_keypoints_and_weights(
-          out_height=height // self._stride,
-          out_width=width // self._stride,
+          out_height=tf.maximum(height // self._stride, 1),
+          out_width=tf.maximum(width // self._stride, 1),
           keypoints=keypoints,
           class_onehot=classes,
           class_weights=weights,

@@ -1740,8 +1744,9 @@ class CenterNetKeypointTargetAssigner(object):
       # All keypoint coordinates and their neighbors:
       # [num_instance * num_keypoints, num_neighbors]
       (y_source_neighbors, x_source_neighbors,
-       valid_sources) = ta_utils.get_surrounding_grids(height // self._stride,
-                                                       width // self._stride,
-                                                       y_source, x_source,
-                                                       self._peak_radius)
+       valid_sources) = ta_utils.get_surrounding_grids(
+           tf.cast(tf.maximum(height // self._stride, 1), tf.float32),
+           tf.cast(tf.maximum(width // self._stride, 1), tf.float32),
+           y_source, x_source,
+           self._peak_radius)
       _, num_neighbors = shape_utils.combined_static_and_dynamic_shape(

@@ -1894,8 +1899,8 @@ class CenterNetKeypointTargetAssigner(object):
         zip(gt_keypoints_list, gt_classes_list,
             gt_boxes_list, gt_keypoints_weights_list, gt_weights_list)):
       keypoints_absolute, kp_weights = _preprocess_keypoints_and_weights(
-          out_height=height // self._stride,
-          out_width=width // self._stride,
+          out_height=tf.maximum(height // self._stride, 1),
+          out_width=tf.maximum(width // self._stride, 1),
           keypoints=keypoints,
           class_onehot=classes,
           class_weights=weights,

@@ -1909,9 +1914,10 @@ class CenterNetKeypointTargetAssigner(object):
     if boxes is not None:
       # Compute joint center from boxes.
       boxes = box_list.BoxList(boxes)
-      boxes = box_list_ops.to_absolute_coordinates(boxes,
-                                                   height // self._stride,
-                                                   width // self._stride)
+      boxes = box_list_ops.to_absolute_coordinates(
+          boxes,
+          tf.maximum(height // self._stride, 1),
+          tf.maximum(width // self._stride, 1))
       y_center, x_center, _, _ = boxes.get_center_coordinates_and_sizes()
     else:
       # TODO(yuhuic): Add the logic to generate object centers from keypoints.

@@ -1930,7 +1936,8 @@ class CenterNetKeypointTargetAssigner(object):
     # [num_instance * num_keypoints, num_neighbors]
     (y_source_neighbors, x_source_neighbors,
      valid_sources) = ta_utils.get_surrounding_grids(
-        height // self._stride, width // self._stride,
+        tf.cast(tf.maximum(height // self._stride, 1), tf.float32),
+        tf.cast(tf.maximum(width // self._stride, 1), tf.float32),
         tf.keras.backend.flatten(y_center_tiled),
         tf.keras.backend.flatten(x_center_tiled), self._peak_radius)

@@ -2023,8 +2030,8 @@ class CenterNetMaskTargetAssigner(object):
     _, input_height, input_width = (
         shape_utils.combined_static_and_dynamic_shape(gt_masks_list[0]))
-    output_height = input_height // self._stride
-    output_width = input_width // self._stride
+    output_height = tf.maximum(input_height // self._stride, 1)
+    output_width = tf.maximum(input_width // self._stride, 1)
     segmentation_targets_list = []
     for gt_masks, gt_classes in zip(gt_masks_list, gt_classes_list):

@@ -2114,7 +2121,9 @@ class CenterNetDensePoseTargetAssigner(object):
       part_ids_one_hot = tf.one_hot(part_ids_flattened, depth=self._num_parts)
       # Get DensePose coordinates in the output space.
       surface_coords_abs = densepose_ops.to_absolute_coordinates(
-          surface_coords, height // self._stride, width // self._stride)
+          surface_coords,
+          tf.maximum(height // self._stride, 1),
+          tf.maximum(width // self._stride, 1))
       surface_coords_abs = tf.reshape(surface_coords_abs, [-1, 4])
       # Each tensor has shape [num_boxes * max_sampled_points].
       yabs, xabs, v, u = tf.unstack(surface_coords_abs, axis=-1)

@@ -2213,9 +2222,10 @@ class CenterNetTrackTargetAssigner(object):
     for i, (boxes, weights) in enumerate(zip(gt_boxes_list, gt_weights_list)):
       boxes = box_list.BoxList(boxes)
-      boxes = box_list_ops.to_absolute_coordinates(boxes,
-                                                   height // self._stride,
-                                                   width // self._stride)
+      boxes = box_list_ops.to_absolute_coordinates(
+          boxes,
+          tf.maximum(height // self._stride, 1),
+          tf.maximum(width // self._stride, 1))
       # Get the box center coordinates. Each returned tensors have the shape of
       # [num_boxes]
       (y_center, x_center, _, _) = boxes.get_center_coordinates_and_sizes()

@@ -2318,8 +2328,8 @@ class CenterNetCornerOffsetTargetAssigner(object):
     """
     _, input_height, input_width = (
         shape_utils.combined_static_and_dynamic_shape(gt_masks_list[0]))
-    output_height = input_height // self._stride
-    output_width = input_width // self._stride
+    output_height = tf.maximum(input_height // self._stride, 1)
+    output_width = tf.maximum(input_width // self._stride, 1)
     y_grid, x_grid = tf.meshgrid(
         tf.range(output_height), tf.range(output_width),
         indexing='ij')

@@ -2332,6 +2342,8 @@ class CenterNetCornerOffsetTargetAssigner(object):
           method=ResizeMethod.NEAREST_NEIGHBOR)
       gt_masks = filter_mask_overlap(gt_masks, self._overlap_resolution)
+      output_height = tf.cast(output_height, tf.float32)
+      output_width = tf.cast(output_width, tf.float32)
       ymin, xmin, ymax, xmax = tf.unstack(gt_boxes, axis=1)
       ymin, ymax = ymin * output_height, ymax * output_height
       xmin, xmax = xmin * output_width, xmax * output_width
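The two added `tf.cast` lines in the hunk above are a side effect of the clamp: `tf.maximum(input_height // self._stride, 1)` yields an integer tensor, and TensorFlow will not multiply float box coordinates by an integer tensor, so the dimensions must be cast before scaling. A short illustration with made-up values:

    import tensorflow as tf

    ymin = tf.constant([0.25, 0.5])         # normalized coordinates, float32
    output_height = tf.maximum(64 // 4, 1)  # int32 tensor after the clamp

    # ymin * output_height would raise InvalidArgumentError (float32 * int32),
    # hence the explicit cast before the scaling:
    ymin_abs = ymin * tf.cast(output_height, tf.float32)  # [4.0, 8.0]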
@@ -2427,9 +2439,10 @@ class CenterNetTemporalOffsetTargetAssigner(object):
     for i, (boxes, offsets, match_flags, weights) in enumerate(zip(
         gt_boxes_list, gt_offsets_list, gt_match_list, gt_weights_list)):
       boxes = box_list.BoxList(boxes)
-      boxes = box_list_ops.to_absolute_coordinates(boxes,
-                                                   height // self._stride,
-                                                   width // self._stride)
+      boxes = box_list_ops.to_absolute_coordinates(
+          boxes,
+          tf.maximum(height // self._stride, 1),
+          tf.maximum(width // self._stride, 1))
       # Get the box center coordinates. Each returned tensors have the shape of
       # [num_boxes]
       (y_center, x_center, _, _) = boxes.get_center_coordinates_and_sizes()
research/object_detection/meta_architectures/center_net_meta_arch.py

@@ -137,7 +137,8 @@ class CenterNetFeatureExtractor(tf.keras.Model):
 def make_prediction_net(num_out_channels, kernel_sizes=(3), num_filters=(256),
-                        bias_fill=None, use_depthwise=False, name=None):
+                        bias_fill=None, use_depthwise=False, name=None,
+                        unit_height_conv=True):
   """Creates a network to predict the given number of output channels.

   This function is intended to make the prediction heads for the CenterNet

@@ -157,6 +158,7 @@ def make_prediction_net(num_out_channels, kernel_sizes=(3), num_filters=(256),
     use_depthwise: If true, use SeparableConv2D to construct the Sequential
       layers instead of Conv2D.
     name: Optional name for the prediction net.
+    unit_height_conv: If True, Conv2Ds have asymmetric kernels with height=1.

   Returns:
     net: A keras module which when called on an input tensor of size

@@ -189,7 +191,7 @@ def make_prediction_net(num_out_channels, kernel_sizes=(3), num_filters=(256),
     layers.append(
         conv_fn(
             num_filter,
-            kernel_size=kernel_size,
+            kernel_size=[1, kernel_size] if unit_height_conv else kernel_size,
             padding='same',
             name='conv2_%d' % idx if tf_version.is_tf1() else None))
     layers.append(tf.keras.layers.ReLU())
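With `unit_height_conv`, every head convolution uses a `[1, k]` kernel so nothing is convolved across the (unit) height dimension. A small standalone check of the shape behavior (a sketch, not the prediction-head builder itself):

    import tensorflow as tf

    # A [1, 3] kernel with 'same' padding preserves a height-1 feature map,
    # mirroring kernel_size=[1, kernel_size] in make_prediction_net.
    conv = tf.keras.layers.Conv2D(filters=256, kernel_size=[1, 3], padding='same')
    features = tf.zeros([2, 1, 32, 16])  # [batch, height=1, width, channels]
    print(conv(features).shape)          # (2, 1, 32, 256)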
@@ -2174,7 +2176,8 @@ class CenterNetMetaArch(model.DetectionModel):
                temporal_offset_params=None,
                use_depthwise=False,
                compute_heatmap_sparse=False,
-               non_max_suppression_fn=None):
+               non_max_suppression_fn=None,
+               unit_height_conv=False):
     """Initializes a CenterNet model.

     Args:

@@ -2218,6 +2221,8 @@ class CenterNetMetaArch(model.DetectionModel):
         better with number of channels in the heatmap, but in some cases is
         known to cause an OOM error. See b/170989061.
       non_max_suppression_fn: Optional Non Max Suppression function to apply.
+      unit_height_conv: If True, Conv2Ds in prediction heads have asymmetric
+        kernels with height=1.
     """
     assert object_detection_params or keypoint_params_dict
     # Shorten the name for convenience and better formatting.

@@ -2244,11 +2249,15 @@ class CenterNetMetaArch(model.DetectionModel):
     self._use_depthwise = use_depthwise
     self._compute_heatmap_sparse = compute_heatmap_sparse

+    # subclasses may not implement the unit_height_conv arg, so only provide it
+    # as a kwarg if it is True.
+    kwargs = {'unit_height_conv': unit_height_conv} if unit_height_conv else {}
+
     # Construct the prediction head nets.
     self._prediction_head_dict = self._construct_prediction_heads(
         num_classes,
         self._num_feature_outputs,
-        class_prediction_bias_init=self._center_params.heatmap_bias_init)
+        class_prediction_bias_init=self._center_params.heatmap_bias_init,
+        **kwargs)
     # Initialize the target assigners.
     self._target_assigner_dict = self._initialize_target_assigners(
         stride=self._stride,
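The `**kwargs` construction above keeps the change backward compatible: subclasses that override `_construct_prediction_heads` with the old signature only ever see the new argument when it is actually enabled. The pattern in isolation, with hypothetical class and method names:

    class LegacyHeads:
      """Hypothetical subclass whose override predates unit_height_conv."""

      def construct(self, num_classes, bias_init):
        return 'legacy heads for %d classes' % num_classes

    unit_height_conv = False
    # Passing unit_height_conv=False explicitly would raise TypeError against
    # the old signature; omitting the kwarg entirely keeps legacy overrides
    # working unchanged.
    kwargs = {'unit_height_conv': unit_height_conv} if unit_height_conv else {}
    print(LegacyHeads().construct(3, 0.1, **kwargs))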
@@ -2269,7 +2278,8 @@ class CenterNetMetaArch(model.DetectionModel):
   def _make_prediction_net_list(self, num_feature_outputs, num_out_channels,
                                 kernel_sizes=(3), num_filters=(256),
-                                bias_fill=None, name=None):
+                                bias_fill=None, name=None,
+                                unit_height_conv=False):
     prediction_net_list = []
     for i in range(num_feature_outputs):
       prediction_net_list.append(

@@ -2279,11 +2289,13 @@ class CenterNetMetaArch(model.DetectionModel):
               num_filters=num_filters,
               bias_fill=bias_fill,
               use_depthwise=self._use_depthwise,
-              name='{}_{}'.format(name, i) if name else name))
+              name='{}_{}'.format(name, i) if name else name,
+              unit_height_conv=unit_height_conv))
     return prediction_net_list

   def _construct_prediction_heads(self, num_classes, num_feature_outputs,
-                                  class_prediction_bias_init):
+                                  class_prediction_bias_init,
+                                  unit_height_conv=False):
     """Constructs the prediction heads based on the specific parameters.

     Args:

@@ -2295,6 +2307,7 @@ class CenterNetMetaArch(model.DetectionModel):
       class_prediction_bias_init: float, the initial value of bias in the
         convolutional kernel of the class prediction head. If set to None, the
         bias is initialized with zeros.
+      unit_height_conv: If True, Conv2Ds have asymmetric kernels with height=1.

     Returns:
       A dictionary of keras modules generated by calling make_prediction_net

@@ -2308,13 +2321,16 @@ class CenterNetMetaArch(model.DetectionModel):
         kernel_sizes=self._center_params.center_head_kernel_sizes,
         num_filters=self._center_params.center_head_num_filters,
         bias_fill=class_prediction_bias_init,
-        name='center')
+        name='center',
+        unit_height_conv=unit_height_conv)
     if self._od_params is not None:
       prediction_heads[BOX_SCALE] = self._make_prediction_net_list(
-          num_feature_outputs, NUM_SIZE_CHANNELS, name='box_scale')
+          num_feature_outputs, NUM_SIZE_CHANNELS, name='box_scale',
+          unit_height_conv=unit_height_conv)
       prediction_heads[BOX_OFFSET] = self._make_prediction_net_list(
-          num_feature_outputs, NUM_OFFSET_CHANNELS, name='box_offset')
+          num_feature_outputs, NUM_OFFSET_CHANNELS, name='box_offset',
+          unit_height_conv=unit_height_conv)
     if self._kp_params_dict is not None:
       for task_name, kp_params in self._kp_params_dict.items():

@@ -2326,14 +2342,16 @@ class CenterNetMetaArch(model.DetectionModel):
             kernel_sizes=kp_params.heatmap_head_kernel_sizes,
             num_filters=kp_params.heatmap_head_num_filters,
             bias_fill=kp_params.heatmap_bias_init,
-            name='kpt_heatmap')
+            name='kpt_heatmap',
+            unit_height_conv=unit_height_conv)
         prediction_heads[get_keypoint_name(
             task_name, KEYPOINT_REGRESSION)] = self._make_prediction_net_list(
                 num_feature_outputs,
                 NUM_OFFSET_CHANNELS * num_keypoints,
                 kernel_sizes=kp_params.regress_head_kernel_sizes,
                 num_filters=kp_params.regress_head_num_filters,
-                name='kpt_regress')
+                name='kpt_regress',
+                unit_height_conv=unit_height_conv)
         if kp_params.per_keypoint_offset:
           prediction_heads[get_keypoint_name(

@@ -2342,7 +2360,8 @@ class CenterNetMetaArch(model.DetectionModel):
                   NUM_OFFSET_CHANNELS * num_keypoints,
                   kernel_sizes=kp_params.offset_head_kernel_sizes,
                   num_filters=kp_params.offset_head_num_filters,
-                  name='kpt_offset')
+                  name='kpt_offset',
+                  unit_height_conv=unit_height_conv)
         else:
           prediction_heads[get_keypoint_name(
               task_name, KEYPOINT_OFFSET)] = self._make_prediction_net_list(

@@ -2350,38 +2369,44 @@ class CenterNetMetaArch(model.DetectionModel):
                   NUM_OFFSET_CHANNELS,
                   kernel_sizes=kp_params.offset_head_kernel_sizes,
                   num_filters=kp_params.offset_head_num_filters,
-                  name='kpt_offset')
+                  name='kpt_offset',
+                  unit_height_conv=unit_height_conv)
         if kp_params.predict_depth:
           num_depth_channel = (
              num_keypoints if kp_params.per_keypoint_depth else 1)
           prediction_heads[get_keypoint_name(
               task_name, KEYPOINT_DEPTH)] = self._make_prediction_net_list(
-                  num_feature_outputs, num_depth_channel, name='kpt_depth')
+                  num_feature_outputs, num_depth_channel, name='kpt_depth',
+                  unit_height_conv=unit_height_conv)
     if self._mask_params is not None:
       prediction_heads[SEGMENTATION_HEATMAP] = self._make_prediction_net_list(
           num_feature_outputs,
           num_classes,
           bias_fill=self._mask_params.heatmap_bias_init,
-          name='seg_heatmap')
+          name='seg_heatmap',
+          unit_height_conv=unit_height_conv)
     if self._densepose_params is not None:
       prediction_heads[DENSEPOSE_HEATMAP] = self._make_prediction_net_list(
           num_feature_outputs,
           self._densepose_params.num_parts,
           bias_fill=self._densepose_params.heatmap_bias_init,
-          name='dense_pose_heatmap')
+          name='dense_pose_heatmap',
+          unit_height_conv=unit_height_conv)
       prediction_heads[DENSEPOSE_REGRESSION] = self._make_prediction_net_list(
           num_feature_outputs,
           2 * self._densepose_params.num_parts,
-          name='dense_pose_regress')
+          name='dense_pose_regress',
+          unit_height_conv=unit_height_conv)
     if self._track_params is not None:
       prediction_heads[TRACK_REID] = self._make_prediction_net_list(
           num_feature_outputs,
           self._track_params.reid_embed_size,
-          name='track_reid')
+          name='track_reid',
+          unit_height_conv=unit_height_conv)
       # Creates a classification network to train object embeddings by learning
       # a projection from embedding space to object track ID space.

@@ -2400,7 +2425,8 @@ class CenterNetMetaArch(model.DetectionModel):
                 self._track_params.reid_embed_size,)))
     if self._temporal_offset_params is not None:
       prediction_heads[TEMPORAL_OFFSET] = self._make_prediction_net_list(
-          num_feature_outputs, NUM_OFFSET_CHANNELS, name='temporal_offset')
+          num_feature_outputs, NUM_OFFSET_CHANNELS, name='temporal_offset',
+          unit_height_conv=unit_height_conv)
     return prediction_heads

   def _initialize_target_assigners(self, stride, min_box_overlap_iou):

@@ -3357,8 +3383,8 @@ class CenterNetMetaArch(model.DetectionModel):
     _, input_height, input_width, _ = _get_shape(
         prediction_dict['preprocessed_inputs'], 4)
-    output_height, output_width = (input_height // self._stride,
-                                   input_width // self._stride)
+    output_height, output_width = (
+        tf.maximum(input_height // self._stride, 1),
+        tf.maximum(input_width // self._stride, 1))
     # TODO(vighneshb) Explore whether using floor here is safe.
     output_true_image_shapes = tf.ceil(
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py

@@ -2995,6 +2995,162 @@ class CenterNetFeatureExtractorTest(test_case.TestCase):
     self.assertAllClose(output[..., 2], 3 * np.ones((2, 32, 32)))


+class Dummy1dFeatureExtractor(cnma.CenterNetFeatureExtractor):
+  """Returns a static tensor."""
+
+  def __init__(self, tensor, out_stride=1, channel_means=(0., 0., 0.),
+               channel_stds=(1., 1., 1.), bgr_ordering=False):
+    """Initializes the feature extractor.
+
+    Args:
+      tensor: The tensor to return as the processed feature.
+      out_stride: The out_stride to return if asked.
+      channel_means: Ignored, but provided for API compatibility.
+      channel_stds: Ignored, but provided for API compatibility.
+      bgr_ordering: Ignored, but provided for API compatibility.
+    """
+    super().__init__(channel_means=channel_means, channel_stds=channel_stds,
+                     bgr_ordering=bgr_ordering)
+    self._tensor = tensor
+    self._out_stride = out_stride
+
+  def call(self, inputs):
+    return [self._tensor]
+
+  @property
+  def out_stride(self):
+    """The stride in the output image of the network."""
+    return self._out_stride
+
+  @property
+  def num_feature_outputs(self):
+    """The number of feature outputs returned by the feature extractor."""
+    return 1
+
+  @property
+  def supported_sub_model_types(self):
+    return ['detection']
+
+  def get_sub_model(self, sub_model_type):
+    if sub_model_type == 'detection':
+      return self._network
+    else:
+      raise ValueError(
+          'Sub model type "{}" not supported.'.format(sub_model_type))
+
+
+@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
+class CenterNetMetaArch1dTest(test_case.TestCase, parameterized.TestCase):
+
+  @parameterized.parameters([1, 2])
+  def test_outputs_with_correct_shape(self, stride):
+    # The 1D case reuses code from the 2D cases. These tests only check that
+    # the output shapes are correct, and rely on other tests for correctness.
+    batch_size = 2
+    height = 1
+    width = 32
+    channels = 16
+    unstrided_inputs = np.random.randn(
+        batch_size, height, width, channels)
+    fixed_output_features = np.random.randn(
+        batch_size, height, width // stride, channels)
+    max_boxes = 10
+    num_classes = 3
+    feature_extractor = Dummy1dFeatureExtractor(fixed_output_features, stride)
+    arch = cnma.CenterNetMetaArch(
+        is_training=True,
+        add_summaries=True,
+        num_classes=num_classes,
+        feature_extractor=feature_extractor,
+        image_resizer_fn=None,
+        object_center_params=cnma.ObjectCenterParams(
+            classification_loss=losses.PenaltyReducedLogisticFocalLoss(),
+            object_center_loss_weight=1.0,
+            max_box_predictions=max_boxes,
+        ),
+        object_detection_params=cnma.ObjectDetectionParams(
+            localization_loss=losses.L1LocalizationLoss(),
+            scale_loss_weight=1.0,
+            offset_loss_weight=1.0,
+        ),
+        keypoint_params_dict=None,
+        mask_params=None,
+        densepose_params=None,
+        track_params=None,
+        temporal_offset_params=None,
+        use_depthwise=False,
+        compute_heatmap_sparse=False,
+        non_max_suppression_fn=None,
+        unit_height_conv=True)
+    arch.provide_groundtruth(
+        groundtruth_boxes_list=[
+            tf.constant([[0, 0.5, 1.0, 0.75],
+                         [0, 0.1, 1.0, 0.25]], tf.float32),
+            tf.constant([[0, 0, 1.0, 1.0],
+                         [0, 0, 0.0, 0.0]], tf.float32)],
+        groundtruth_classes_list=[
+            tf.constant([[0, 0, 1],
+                         [0, 1, 0]], tf.float32),
+            tf.constant([[1, 0, 0],
+                         [0, 0, 0]], tf.float32)],
+        groundtruth_weights_list=[
+            tf.constant([1.0, 1.0]),
+            tf.constant([1.0, 0.0])])
+    predictions = arch.predict(None, None)  # input is hardcoded above.
+    predictions['preprocessed_inputs'] = tf.constant(unstrided_inputs)
+    true_shapes = tf.constant([[1, 32, 16], [1, 24, 16]], tf.int32)
+    postprocess_output = arch.postprocess(predictions, true_shapes)
+    losses_output = arch.loss(predictions, true_shapes)
+
+    self.assertIn('%s/%s' % (cnma.LOSS_KEY_PREFIX, cnma.OBJECT_CENTER),
+                  losses_output)
+    self.assertEqual(
+        (),
+        losses_output['%s/%s' % (cnma.LOSS_KEY_PREFIX,
+                                 cnma.OBJECT_CENTER)].shape)
+    self.assertIn('%s/%s' % (cnma.LOSS_KEY_PREFIX, cnma.BOX_SCALE),
+                  losses_output)
+    self.assertEqual(
+        (),
+        losses_output['%s/%s' % (cnma.LOSS_KEY_PREFIX, cnma.BOX_SCALE)].shape)
+    self.assertIn('%s/%s' % (cnma.LOSS_KEY_PREFIX, cnma.BOX_OFFSET),
+                  losses_output)
+    self.assertEqual(
+        (),
+        losses_output['%s/%s' % (cnma.LOSS_KEY_PREFIX, cnma.BOX_OFFSET)].shape)
+
+    self.assertIn('detection_scores', postprocess_output)
+    self.assertEqual(postprocess_output['detection_scores'].shape,
+                     (batch_size, max_boxes))
+    self.assertIn('detection_multiclass_scores', postprocess_output)
+    self.assertEqual(postprocess_output['detection_multiclass_scores'].shape,
+                     (batch_size, max_boxes, num_classes))
+    self.assertIn('detection_classes', postprocess_output)
+    self.assertEqual(postprocess_output['detection_classes'].shape,
+                     (batch_size, max_boxes))
+    self.assertIn('num_detections', postprocess_output)
+    self.assertEqual(postprocess_output['num_detections'].shape,
+                     (batch_size,))
+    self.assertIn('detection_boxes', postprocess_output)
+    self.assertEqual(postprocess_output['detection_boxes'].shape,
+                     (batch_size, max_boxes, 4))
+    self.assertIn('detection_boxes_strided', postprocess_output)
+    self.assertEqual(postprocess_output['detection_boxes_strided'].shape,
+                     (batch_size, max_boxes, 4))
+
+    self.assertIn(cnma.OBJECT_CENTER, predictions)
+    self.assertEqual(predictions[cnma.OBJECT_CENTER][0].shape,
+                     (batch_size, height, width // stride, num_classes))
+    self.assertIn(cnma.BOX_SCALE, predictions)
+    self.assertEqual(predictions[cnma.BOX_SCALE][0].shape,
+                     (batch_size, height, width // stride, 2))
+    self.assertIn(cnma.BOX_OFFSET, predictions)
+    self.assertEqual(predictions[cnma.BOX_OFFSET][0].shape,
+                     (batch_size, height, width // stride, 2))
+    self.assertIn('preprocessed_inputs', predictions)
+
+
 if __name__ == '__main__':
   tf.enable_v2_behavior()
   tf.test.main()