Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
af1a6c57
Commit
af1a6c57
authored
Mar 28, 2022
by
Vighnesh Birodkar
Committed by
TF Object Detection Team
Mar 28, 2022
Browse files
Add CenterNet option to expose predict() outputs in postprocess.
PiperOrigin-RevId: 437824318
parent
f8f4ab71
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
18 additions
and
4 deletions
+18
-4
research/object_detection/builders/model_builder.py
research/object_detection/builders/model_builder.py
+2
-1
research/object_detection/meta_architectures/center_net_meta_arch.py
...ject_detection/meta_architectures/center_net_meta_arch.py
+10
-2
research/object_detection/protos/center_net.proto
research/object_detection/protos/center_net.proto
+6
-1
No files found.
research/object_detection/builders/model_builder.py
View file @
af1a6c57
...
@@ -1171,7 +1171,8 @@ def _build_center_net_model(center_net_config, is_training, add_summaries):
...
@@ -1171,7 +1171,8 @@ def _build_center_net_model(center_net_config, is_training, add_summaries):
temporal_offset_params
=
temporal_offset_params
,
temporal_offset_params
=
temporal_offset_params
,
use_depthwise
=
center_net_config
.
use_depthwise
,
use_depthwise
=
center_net_config
.
use_depthwise
,
compute_heatmap_sparse
=
center_net_config
.
compute_heatmap_sparse
,
compute_heatmap_sparse
=
center_net_config
.
compute_heatmap_sparse
,
non_max_suppression_fn
=
non_max_suppression_fn
)
non_max_suppression_fn
=
non_max_suppression_fn
,
output_prediction_dict
=
center_net_config
.
output_prediction_dict
)
def
_build_center_net_feature_extractor
(
feature_extractor_config
,
is_training
):
def
_build_center_net_feature_extractor
(
feature_extractor_config
,
is_training
):
...
...
research/object_detection/meta_architectures/center_net_meta_arch.py
View file @
af1a6c57
...
@@ -2676,7 +2676,8 @@ class CenterNetMetaArch(model.DetectionModel):
...
@@ -2676,7 +2676,8 @@ class CenterNetMetaArch(model.DetectionModel):
use_depthwise
=
False
,
use_depthwise
=
False
,
compute_heatmap_sparse
=
False
,
compute_heatmap_sparse
=
False
,
non_max_suppression_fn
=
None
,
non_max_suppression_fn
=
None
,
unit_height_conv
=
False
):
unit_height_conv
=
False
,
output_prediction_dict
=
False
):
"""Initializes a CenterNet model.
"""Initializes a CenterNet model.
Args:
Args:
...
@@ -2722,6 +2723,8 @@ class CenterNetMetaArch(model.DetectionModel):
...
@@ -2722,6 +2723,8 @@ class CenterNetMetaArch(model.DetectionModel):
non_max_suppression_fn: Optional Non Max Suppression function to apply.
non_max_suppression_fn: Optional Non Max Suppression function to apply.
unit_height_conv: If True, Conv2Ds in prediction heads have asymmetric
unit_height_conv: If True, Conv2Ds in prediction heads have asymmetric
kernels with height=1.
kernels with height=1.
output_prediction_dict: If true, combines all items from the dictionary
returned by predict() function into the output of postprocess().
"""
"""
assert
object_detection_params
or
keypoint_params_dict
assert
object_detection_params
or
keypoint_params_dict
# Shorten the name for convenience and better formatting.
# Shorten the name for convenience and better formatting.
...
@@ -2747,6 +2750,7 @@ class CenterNetMetaArch(model.DetectionModel):
...
@@ -2747,6 +2750,7 @@ class CenterNetMetaArch(model.DetectionModel):
self
.
_use_depthwise
=
use_depthwise
self
.
_use_depthwise
=
use_depthwise
self
.
_compute_heatmap_sparse
=
compute_heatmap_sparse
self
.
_compute_heatmap_sparse
=
compute_heatmap_sparse
self
.
_output_prediction_dict
=
output_prediction_dict
# subclasses may not implement the unit_height_conv arg, so only provide it
# subclasses may not implement the unit_height_conv arg, so only provide it
# as a kwarg if it is True.
# as a kwarg if it is True.
...
@@ -4110,6 +4114,10 @@ class CenterNetMetaArch(model.DetectionModel):
...
@@ -4110,6 +4114,10 @@ class CenterNetMetaArch(model.DetectionModel):
fields
.
DetectionResultFields
.
num_detections
:
num_detections
,
fields
.
DetectionResultFields
.
num_detections
:
num_detections
,
}
}
if
self
.
_output_prediction_dict
:
postprocess_dict
.
update
(
prediction_dict
)
postprocess_dict
[
'true_image_shapes'
]
=
true_image_shapes
boxes_strided
=
None
boxes_strided
=
None
if
self
.
_od_params
:
if
self
.
_od_params
:
boxes_strided
=
(
boxes_strided
=
(
...
@@ -4122,7 +4130,7 @@ class CenterNetMetaArch(model.DetectionModel):
...
@@ -4122,7 +4130,7 @@ class CenterNetMetaArch(model.DetectionModel):
postprocess_dict
.
update
({
postprocess_dict
.
update
({
fields
.
DetectionResultFields
.
detection_boxes
:
boxes
,
fields
.
DetectionResultFields
.
detection_boxes
:
boxes
,
'detection_boxes_strided'
:
boxes_strided
'detection_boxes_strided'
:
boxes_strided
,
})
})
if
self
.
_kp_params_dict
:
if
self
.
_kp_params_dict
:
...
...
research/object_detection/protos/center_net.proto
View file @
af1a6c57
...
@@ -11,7 +11,7 @@ import "object_detection/protos/preprocessor.proto";
...
@@ -11,7 +11,7 @@ import "object_detection/protos/preprocessor.proto";
// Points" paper [1]
// Points" paper [1]
// [1]: https://arxiv.org/abs/1904.07850
// [1]: https://arxiv.org/abs/1904.07850
// Next Id =
1
6
// Next Id =
2
6
message
CenterNet
{
message
CenterNet
{
// Number of classes to predict.
// Number of classes to predict.
optional
int32
num_classes
=
1
;
optional
int32
num_classes
=
1
;
...
@@ -504,6 +504,7 @@ message CenterNet {
...
@@ -504,6 +504,7 @@ message CenterNet {
// within error bars.
// within error bars.
optional
bool
use_only_last_stage
=
24
[
default
=
false
];
optional
bool
use_only_last_stage
=
24
[
default
=
false
];
}
}
optional
DeepMACMaskEstimation
deepmac_mask_estimation
=
14
;
optional
DeepMACMaskEstimation
deepmac_mask_estimation
=
14
;
...
@@ -514,6 +515,10 @@ message CenterNet {
...
@@ -514,6 +515,10 @@ message CenterNet {
// from CenterNet. Use this optional parameter to apply traditional non max
// from CenterNet. Use this optional parameter to apply traditional non max
// suppression and score thresholding.
// suppression and score thresholding.
optional
PostProcessing
post_processing
=
24
;
optional
PostProcessing
post_processing
=
24
;
// If set, dictionary items returned by the predict() function
// are appended to the output of postprocess().
optional
bool
output_prediction_dict
=
25
[
default
=
false
];
}
}
enum
LossNormalize
{
enum
LossNormalize
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment