Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
af1a6c57
Commit
af1a6c57
authored
Mar 28, 2022
by
Vighnesh Birodkar
Committed by
TF Object Detection Team
Mar 28, 2022
Browse files
Add CenterNet option to expose predict() outputs in postprocess.
PiperOrigin-RevId: 437824318
parent
f8f4ab71
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
18 additions
and
4 deletions
+18
-4
research/object_detection/builders/model_builder.py
research/object_detection/builders/model_builder.py
+2
-1
research/object_detection/meta_architectures/center_net_meta_arch.py
...ject_detection/meta_architectures/center_net_meta_arch.py
+10
-2
research/object_detection/protos/center_net.proto
research/object_detection/protos/center_net.proto
+6
-1
No files found.
research/object_detection/builders/model_builder.py
View file @
af1a6c57
...
...
@@ -1171,7 +1171,8 @@ def _build_center_net_model(center_net_config, is_training, add_summaries):
temporal_offset_params
=
temporal_offset_params
,
use_depthwise
=
center_net_config
.
use_depthwise
,
compute_heatmap_sparse
=
center_net_config
.
compute_heatmap_sparse
,
non_max_suppression_fn
=
non_max_suppression_fn
)
non_max_suppression_fn
=
non_max_suppression_fn
,
output_prediction_dict
=
center_net_config
.
output_prediction_dict
)
def
_build_center_net_feature_extractor
(
feature_extractor_config
,
is_training
):
...
...
research/object_detection/meta_architectures/center_net_meta_arch.py
View file @
af1a6c57
...
...
@@ -2676,7 +2676,8 @@ class CenterNetMetaArch(model.DetectionModel):
use_depthwise
=
False
,
compute_heatmap_sparse
=
False
,
non_max_suppression_fn
=
None
,
unit_height_conv
=
False
):
unit_height_conv
=
False
,
output_prediction_dict
=
False
):
"""Initializes a CenterNet model.
Args:
...
...
@@ -2722,6 +2723,8 @@ class CenterNetMetaArch(model.DetectionModel):
non_max_suppression_fn: Optional Non Max Suppression function to apply.
unit_height_conv: If True, Conv2Ds in prediction heads have asymmetric
kernels with height=1.
output_prediction_dict: If true, combines all items from the dictionary
returned by predict() function into the output of postprocess().
"""
assert
object_detection_params
or
keypoint_params_dict
# Shorten the name for convenience and better formatting.
...
...
@@ -2747,6 +2750,7 @@ class CenterNetMetaArch(model.DetectionModel):
self
.
_use_depthwise
=
use_depthwise
self
.
_compute_heatmap_sparse
=
compute_heatmap_sparse
self
.
_output_prediction_dict
=
output_prediction_dict
# subclasses may not implement the unit_height_conv arg, so only provide it
# as a kwarg if it is True.
...
...
@@ -4110,6 +4114,10 @@ class CenterNetMetaArch(model.DetectionModel):
fields
.
DetectionResultFields
.
num_detections
:
num_detections
,
}
if
self
.
_output_prediction_dict
:
postprocess_dict
.
update
(
prediction_dict
)
postprocess_dict
[
'true_image_shapes'
]
=
true_image_shapes
boxes_strided
=
None
if
self
.
_od_params
:
boxes_strided
=
(
...
...
@@ -4122,7 +4130,7 @@ class CenterNetMetaArch(model.DetectionModel):
postprocess_dict
.
update
({
fields
.
DetectionResultFields
.
detection_boxes
:
boxes
,
'detection_boxes_strided'
:
boxes_strided
'detection_boxes_strided'
:
boxes_strided
,
})
if
self
.
_kp_params_dict
:
...
...
research/object_detection/protos/center_net.proto
View file @
af1a6c57
...
...
@@ -11,7 +11,7 @@ import "object_detection/protos/preprocessor.proto";
// Points" paper [1]
// [1]: https://arxiv.org/abs/1904.07850
// Next Id =
1
6
// Next Id =
2
6
message
CenterNet
{
// Number of classes to predict.
optional
int32
num_classes
=
1
;
...
...
@@ -504,6 +504,7 @@ message CenterNet {
// within error bars.
optional
bool
use_only_last_stage
=
24
[
default
=
false
];
}
optional
DeepMACMaskEstimation
deepmac_mask_estimation
=
14
;
...
...
@@ -514,6 +515,10 @@ message CenterNet {
// from CenterNet. Use this optional parameter to apply traditional non max
// suppression and score thresholding.
optional
PostProcessing
post_processing
=
24
;
// If set, dictionary items returned by the predict() function
// are appended to the output of postprocess().
optional
bool
output_prediction_dict
=
25
[
default
=
false
];
}
enum
LossNormalize
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment