Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
6559ed14
Commit
6559ed14
authored
Jan 29, 2021
by
Vivek Rathod
Committed by
TF Object Detection Team
Jan 29, 2021
Browse files
Add opional Non-Max Suppression in CenterNet
PiperOrigin-RevId: 354590046
parent
3063aeb3
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
123 additions
and
8 deletions
+123
-8
research/object_detection/builders/model_builder.py
research/object_detection/builders/model_builder.py
+6
-2
research/object_detection/meta_architectures/center_net_meta_arch.py
...ject_detection/meta_architectures/center_net_meta_arch.py
+32
-1
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py
...ction/meta_architectures/center_net_meta_arch_tf2_test.py
+77
-5
research/object_detection/protos/center_net.proto
research/object_detection/protos/center_net.proto
+8
-0
No files found.
research/object_detection/builders/model_builder.py
View file @
6559ed14
...
...
@@ -1039,7 +1039,10 @@ def _build_center_net_model(center_net_config, is_training, add_summaries):
if
center_net_config
.
HasField
(
'temporal_offset_task'
):
temporal_offset_params
=
temporal_offset_proto_to_params
(
center_net_config
.
temporal_offset_task
)
non_max_suppression_fn
=
None
if
center_net_config
.
HasField
(
'post_processing'
):
non_max_suppression_fn
,
_
=
post_processing_builder
.
build
(
center_net_config
.
post_processing
)
return
center_net_meta_arch
.
CenterNetMetaArch
(
is_training
=
is_training
,
add_summaries
=
add_summaries
,
...
...
@@ -1054,7 +1057,8 @@ def _build_center_net_model(center_net_config, is_training, add_summaries):
track_params
=
track_params
,
temporal_offset_params
=
temporal_offset_params
,
use_depthwise
=
center_net_config
.
use_depthwise
,
compute_heatmap_sparse
=
center_net_config
.
compute_heatmap_sparse
)
compute_heatmap_sparse
=
center_net_config
.
compute_heatmap_sparse
,
non_max_suppression_fn
=
non_max_suppression_fn
)
def
_build_center_net_feature_extractor
(
...
...
research/object_detection/meta_architectures/center_net_meta_arch.py
View file @
6559ed14
...
...
@@ -1896,7 +1896,8 @@ class CenterNetMetaArch(model.DetectionModel):
track_params
=
None
,
temporal_offset_params
=
None
,
use_depthwise
=
False
,
compute_heatmap_sparse
=
False
):
compute_heatmap_sparse
=
False
,
non_max_suppression_fn
=
None
):
"""Initializes a CenterNet model.
Args:
...
...
@@ -1939,6 +1940,7 @@ class CenterNetMetaArch(model.DetectionModel):
the Op that computes the center heatmaps. The sparse version scales
better with number of channels in the heatmap, but in some cases is
known to cause an OOM error. See b/170989061.
non_max_suppression_fn: Optional Non Max Suppression function to apply.
"""
assert
object_detection_params
or
keypoint_params_dict
# Shorten the name for convenience and better formatting.
...
...
@@ -1977,6 +1979,7 @@ class CenterNetMetaArch(model.DetectionModel):
# Will be used in VOD single_frame_meta_arch for tensor reshape.
self
.
_batched_prediction_tensor_names
=
[]
self
.
_non_max_suppression_fn
=
non_max_suppression_fn
super
(
CenterNetMetaArch
,
self
).
__init__
(
num_classes
)
...
...
@@ -3108,6 +3111,34 @@ class CenterNetMetaArch(model.DetectionModel):
prediction_dict
[
TEMPORAL_OFFSET
][
-
1
])
postprocess_dict
[
fields
.
DetectionResultFields
.
detection_offsets
]
=
offsets
if
self
.
_non_max_suppression_fn
:
boxes
=
tf
.
expand_dims
(
postprocess_dict
.
pop
(
fields
.
DetectionResultFields
.
detection_boxes
),
axis
=-
2
)
multiclass_scores
=
postprocess_dict
[
fields
.
DetectionResultFields
.
detection_multiclass_scores
]
num_valid_boxes
=
postprocess_dict
.
pop
(
fields
.
DetectionResultFields
.
num_detections
)
# Remove scores and classes as NMS will compute these form multiclass
# scores.
postprocess_dict
.
pop
(
fields
.
DetectionResultFields
.
detection_scores
)
postprocess_dict
.
pop
(
fields
.
DetectionResultFields
.
detection_classes
)
(
nmsed_boxes
,
nmsed_scores
,
nmsed_classes
,
_
,
nmsed_additional_fields
,
num_detections
)
=
self
.
_non_max_suppression_fn
(
boxes
,
multiclass_scores
,
additional_fields
=
postprocess_dict
,
num_valid_boxes
=
num_valid_boxes
)
postprocess_dict
=
nmsed_additional_fields
postprocess_dict
[
fields
.
DetectionResultFields
.
detection_boxes
]
=
nmsed_boxes
postprocess_dict
[
fields
.
DetectionResultFields
.
detection_scores
]
=
nmsed_scores
postprocess_dict
[
fields
.
DetectionResultFields
.
detection_classes
]
=
nmsed_classes
postprocess_dict
[
fields
.
DetectionResultFields
.
num_detections
]
=
num_detections
postprocess_dict
.
update
(
nmsed_additional_fields
)
return
postprocess_dict
def
postprocess_single_instance_keypoints
(
self
,
prediction_dict
,
...
...
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py
View file @
6559ed14
...
...
@@ -24,12 +24,14 @@ from absl.testing import parameterized
import
numpy
as
np
import
tensorflow.compat.v1
as
tf
from
object_detection.builders
import
post_processing_builder
from
object_detection.core
import
losses
from
object_detection.core
import
preprocessor
from
object_detection.core
import
standard_fields
as
fields
from
object_detection.core
import
target_assigner
as
cn_assigner
from
object_detection.meta_architectures
import
center_net_meta_arch
as
cnma
from
object_detection.models
import
center_net_resnet_feature_extractor
from
object_detection.protos
import
post_processing_pb2
from
object_detection.utils
import
test_case
from
object_detection.utils
import
tf_version
...
...
@@ -1349,7 +1351,9 @@ def get_fake_temporal_offset_params():
def
build_center_net_meta_arch
(
build_resnet
=
False
,
num_classes
=
_NUM_CLASSES
,
max_box_predictions
=
5
):
max_box_predictions
=
5
,
apply_non_max_suppression
=
False
,
detection_only
=
False
):
"""Builds the CenterNet meta architecture."""
if
build_resnet
:
feature_extractor
=
(
...
...
@@ -1368,7 +1372,31 @@ def build_center_net_meta_arch(build_resnet=False,
max_dimension
=
128
,
pad_to_max_dimesnion
=
True
)
if
num_classes
==
1
:
non_max_suppression_fn
=
None
if
apply_non_max_suppression
:
post_processing_proto
=
post_processing_pb2
.
PostProcessing
()
post_processing_proto
.
batch_non_max_suppression
.
iou_threshold
=
1.0
post_processing_proto
.
batch_non_max_suppression
.
score_threshold
=
0.6
(
post_processing_proto
.
batch_non_max_suppression
.
max_total_detections
)
=
max_box_predictions
(
post_processing_proto
.
batch_non_max_suppression
.
max_detections_per_class
)
=
max_box_predictions
(
post_processing_proto
.
batch_non_max_suppression
.
change_coordinate_frame
)
=
False
non_max_suppression_fn
,
_
=
post_processing_builder
.
build
(
post_processing_proto
)
if
detection_only
:
return
cnma
.
CenterNetMetaArch
(
is_training
=
True
,
add_summaries
=
False
,
num_classes
=
num_classes
,
feature_extractor
=
feature_extractor
,
image_resizer_fn
=
image_resizer_fn
,
object_center_params
=
get_fake_center_params
(
max_box_predictions
),
object_detection_params
=
get_fake_od_params
(),
non_max_suppression_fn
=
non_max_suppression_fn
)
elif
num_classes
==
1
:
num_candidates_per_keypoint
=
100
if
max_box_predictions
>
1
else
1
return
cnma
.
CenterNetMetaArch
(
is_training
=
True
,
...
...
@@ -1380,7 +1408,8 @@ def build_center_net_meta_arch(build_resnet=False,
object_detection_params
=
get_fake_od_params
(),
keypoint_params_dict
=
{
_TASK_NAME
:
get_fake_kp_params
(
num_candidates_per_keypoint
)
})
},
non_max_suppression_fn
=
non_max_suppression_fn
)
else
:
return
cnma
.
CenterNetMetaArch
(
is_training
=
True
,
...
...
@@ -1394,7 +1423,8 @@ def build_center_net_meta_arch(build_resnet=False,
mask_params
=
get_fake_mask_params
(),
densepose_params
=
get_fake_densepose_params
(),
track_params
=
get_fake_track_params
(),
temporal_offset_params
=
get_fake_temporal_offset_params
())
temporal_offset_params
=
get_fake_temporal_offset_params
(),
non_max_suppression_fn
=
non_max_suppression_fn
)
def
_logit
(
p
):
...
...
@@ -1728,7 +1758,6 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
return
detections
detections
=
self
.
execute_cpu
(
graph_fn
,
[])
self
.
assertAllClose
(
detections
[
'detection_boxes'
][
0
,
0
],
np
.
array
([
55
,
46
,
75
,
86
])
/
128.0
)
self
.
assertAllClose
(
detections
[
'detection_scores'
][
0
],
...
...
@@ -1801,6 +1830,49 @@ class CenterNetMetaArchTest(test_case.TestCase, parameterized.TestCase):
detections
[
'detection_surface_coords'
][
0
,
0
,
:,
:],
np
.
zeros_like
(
detections
[
'detection_surface_coords'
][
0
,
0
,
:,
:]))
def
test_non_max_suppression
(
self
):
"""Tests application of NMS on CenterNet detections."""
target_class_id
=
1
model
=
build_center_net_meta_arch
(
apply_non_max_suppression
=
True
,
detection_only
=
True
)
class_center
=
np
.
zeros
((
1
,
32
,
32
,
10
),
dtype
=
np
.
float32
)
height_width
=
np
.
zeros
((
1
,
32
,
32
,
2
),
dtype
=
np
.
float32
)
offset
=
np
.
zeros
((
1
,
32
,
32
,
2
),
dtype
=
np
.
float32
)
class_probs
=
np
.
ones
(
10
)
*
_logit
(
0.25
)
class_probs
[
target_class_id
]
=
_logit
(
0.75
)
class_center
[
0
,
16
,
16
]
=
class_probs
height_width
[
0
,
16
,
16
]
=
[
5
,
10
]
offset
[
0
,
16
,
16
]
=
[.
25
,
.
5
]
class_center
=
tf
.
constant
(
class_center
)
height_width
=
tf
.
constant
(
height_width
)
offset
=
tf
.
constant
(
offset
)
prediction_dict
=
{
cnma
.
OBJECT_CENTER
:
[
class_center
],
cnma
.
BOX_SCALE
:
[
height_width
],
cnma
.
BOX_OFFSET
:
[
offset
],
}
def
graph_fn
():
detections
=
model
.
postprocess
(
prediction_dict
,
tf
.
constant
([[
128
,
128
,
3
]]))
return
detections
detections
=
self
.
execute_cpu
(
graph_fn
,
[])
num_detections
=
int
(
detections
[
'num_detections'
])
self
.
assertEqual
(
num_detections
,
1
)
self
.
assertAllClose
(
detections
[
'detection_boxes'
][
0
,
0
],
np
.
array
([
55
,
46
,
75
,
86
])
/
128.0
)
self
.
assertAllClose
(
detections
[
'detection_scores'
][
0
][:
num_detections
],
[.
75
])
expected_multiclass_scores
=
[.
25
]
*
10
expected_multiclass_scores
[
target_class_id
]
=
.
75
self
.
assertAllClose
(
expected_multiclass_scores
,
detections
[
'detection_multiclass_scores'
][
0
][
0
])
def
test_postprocess_single_class
(
self
):
"""Test the postprocess function."""
model
=
build_center_net_meta_arch
(
num_classes
=
1
)
...
...
research/object_detection/protos/center_net.proto
View file @
6559ed14
...
...
@@ -4,6 +4,7 @@ package object_detection.protos;
import
"object_detection/protos/image_resizer.proto"
;
import
"object_detection/protos/losses.proto"
;
import
"object_detection/protos/post_processing.proto"
;
// Configuration for the CenterNet meta architecture from the "Objects as
// Points" paper [1]
...
...
@@ -271,6 +272,13 @@ message CenterNet {
optional
TemporalOffsetEstimation
temporal_offset_task
=
12
;
// CenterNet does not apply conventional post processing operations such as
// non max suppression as it applies a max-pool operator on box centers.
// However, in some cases we observe the need to remove duplicate predictions
// from CenterNet. Use this optional parameter to apply traditional non max
// suppression and score thresholding.
optional
PostProcessing
post_processing
=
24
;
}
message
CenterNetFeatureExtractor
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment