Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
6e5c5a1b
Commit
6e5c5a1b
authored
Mar 23, 2021
by
Yu-hui Chen
Committed by
TF Object Detection Team
Mar 23, 2021
Browse files
Updated the logcis such that the CenterNet prediction head architectures are configurable.
PiperOrigin-RevId: 364731599
parent
5c3e08b7
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
173 additions
and
28 deletions
+173
-28
research/object_detection/builders/model_builder.py
research/object_detection/builders/model_builder.py
+36
-2
research/object_detection/builders/model_builder_tf2_test.py
research/object_detection/builders/model_builder_tf2_test.py
+20
-0
research/object_detection/meta_architectures/center_net_meta_arch.py
...ject_detection/meta_architectures/center_net_meta_arch.py
+81
-25
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py
...ction/meta_architectures/center_net_meta_arch_tf2_test.py
+3
-1
research/object_detection/protos/center_net.proto
research/object_detection/protos/center_net.proto
+33
-0
No files found.
research/object_detection/builders/model_builder.py
View file @
6e5c5a1b
...
@@ -860,6 +860,25 @@ def keypoint_proto_to_params(kp_config, keypoint_map_dict):
...
@@ -860,6 +860,25 @@ def keypoint_proto_to_params(kp_config, keypoint_map_dict):
for
label
,
value
in
kp_config
.
keypoint_label_to_std
.
items
():
for
label
,
value
in
kp_config
.
keypoint_label_to_std
.
items
():
keypoint_std_dev_dict
[
label
]
=
value
keypoint_std_dev_dict
[
label
]
=
value
keypoint_std_dev
=
[
keypoint_std_dev_dict
[
label
]
for
label
in
keypoint_labels
]
keypoint_std_dev
=
[
keypoint_std_dev_dict
[
label
]
for
label
in
keypoint_labels
]
if
kp_config
.
HasField
(
'heatmap_head_params'
):
heatmap_head_num_filters
=
list
(
kp_config
.
heatmap_head_params
.
num_filters
)
heatmap_head_kernel_sizes
=
list
(
kp_config
.
heatmap_head_params
.
kernel_sizes
)
else
:
heatmap_head_num_filters
=
[
256
]
heatmap_head_kernel_sizes
=
[
3
]
if
kp_config
.
HasField
(
'offset_head_params'
):
offset_head_num_filters
=
list
(
kp_config
.
offset_head_params
.
num_filters
)
offset_head_kernel_sizes
=
list
(
kp_config
.
offset_head_params
.
kernel_sizes
)
else
:
offset_head_num_filters
=
[
256
]
offset_head_kernel_sizes
=
[
3
]
if
kp_config
.
HasField
(
'regress_head_params'
):
regress_head_num_filters
=
list
(
kp_config
.
regress_head_params
.
num_filters
)
regress_head_kernel_sizes
=
list
(
kp_config
.
regress_head_params
.
kernel_sizes
)
else
:
regress_head_num_filters
=
[
256
]
regress_head_kernel_sizes
=
[
3
]
return
center_net_meta_arch
.
KeypointEstimationParams
(
return
center_net_meta_arch
.
KeypointEstimationParams
(
task_name
=
kp_config
.
task_name
,
task_name
=
kp_config
.
task_name
,
class_id
=
label_map_item
.
id
-
CLASS_ID_OFFSET
,
class_id
=
label_map_item
.
id
-
CLASS_ID_OFFSET
,
...
@@ -888,7 +907,13 @@ def keypoint_proto_to_params(kp_config, keypoint_map_dict):
...
@@ -888,7 +907,13 @@ def keypoint_proto_to_params(kp_config, keypoint_map_dict):
keypoint_depth_loss_weight
=
kp_config
.
keypoint_depth_loss_weight
,
keypoint_depth_loss_weight
=
kp_config
.
keypoint_depth_loss_weight
,
score_distance_offset
=
kp_config
.
score_distance_offset
,
score_distance_offset
=
kp_config
.
score_distance_offset
,
clip_out_of_frame_keypoints
=
kp_config
.
clip_out_of_frame_keypoints
,
clip_out_of_frame_keypoints
=
kp_config
.
clip_out_of_frame_keypoints
,
rescore_instances
=
kp_config
.
rescore_instances
)
rescore_instances
=
kp_config
.
rescore_instances
,
heatmap_head_num_filters
=
heatmap_head_num_filters
,
heatmap_head_kernel_sizes
=
heatmap_head_kernel_sizes
,
offset_head_num_filters
=
offset_head_num_filters
,
offset_head_kernel_sizes
=
offset_head_kernel_sizes
,
regress_head_num_filters
=
regress_head_num_filters
,
regress_head_kernel_sizes
=
regress_head_kernel_sizes
)
def
object_detection_proto_to_params
(
od_config
):
def
object_detection_proto_to_params
(
od_config
):
...
@@ -921,6 +946,13 @@ def object_center_proto_to_params(oc_config):
...
@@ -921,6 +946,13 @@ def object_center_proto_to_params(oc_config):
keypoint_weights_for_center
=
[]
keypoint_weights_for_center
=
[]
if
oc_config
.
keypoint_weights_for_center
:
if
oc_config
.
keypoint_weights_for_center
:
keypoint_weights_for_center
=
list
(
oc_config
.
keypoint_weights_for_center
)
keypoint_weights_for_center
=
list
(
oc_config
.
keypoint_weights_for_center
)
if
oc_config
.
center_head_params
:
center_head_num_filters
=
list
(
oc_config
.
center_head_params
.
num_filters
)
center_head_kernel_sizes
=
list
(
oc_config
.
center_head_params
.
kernel_sizes
)
else
:
center_head_num_filters
=
[
256
]
center_head_kernel_sizes
=
[
3
]
return
center_net_meta_arch
.
ObjectCenterParams
(
return
center_net_meta_arch
.
ObjectCenterParams
(
classification_loss
=
classification_loss
,
classification_loss
=
classification_loss
,
object_center_loss_weight
=
oc_config
.
object_center_loss_weight
,
object_center_loss_weight
=
oc_config
.
object_center_loss_weight
,
...
@@ -928,7 +960,9 @@ def object_center_proto_to_params(oc_config):
...
@@ -928,7 +960,9 @@ def object_center_proto_to_params(oc_config):
min_box_overlap_iou
=
oc_config
.
min_box_overlap_iou
,
min_box_overlap_iou
=
oc_config
.
min_box_overlap_iou
,
max_box_predictions
=
oc_config
.
max_box_predictions
,
max_box_predictions
=
oc_config
.
max_box_predictions
,
use_labeled_classes
=
oc_config
.
use_labeled_classes
,
use_labeled_classes
=
oc_config
.
use_labeled_classes
,
keypoint_weights_for_center
=
keypoint_weights_for_center
)
keypoint_weights_for_center
=
keypoint_weights_for_center
,
center_head_num_filters
=
center_head_num_filters
,
center_head_kernel_sizes
=
center_head_kernel_sizes
)
def
mask_proto_to_params
(
mask_config
):
def
mask_proto_to_params
(
mask_config
):
...
...
research/object_detection/builders/model_builder_tf2_test.py
View file @
6e5c5a1b
...
@@ -120,6 +120,12 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
...
@@ -120,6 +120,12 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
predict_depth: true
predict_depth: true
per_keypoint_depth: true
per_keypoint_depth: true
keypoint_depth_loss_weight: 0.3
keypoint_depth_loss_weight: 0.3
heatmap_head_params {
num_filters: 64
num_filters: 32
kernel_sizes: 5
kernel_sizes: 3
}
"""
"""
config
=
text_format
.
Merge
(
task_proto_txt
,
config
=
text_format
.
Merge
(
task_proto_txt
,
center_net_pb2
.
CenterNet
.
KeypointEstimation
())
center_net_pb2
.
CenterNet
.
KeypointEstimation
())
...
@@ -137,6 +143,12 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
...
@@ -137,6 +143,12 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
beta: 4.0
beta: 4.0
}
}
}
}
center_head_params {
num_filters: 64
num_filters: 32
kernel_sizes: 5
kernel_sizes: 3
}
"""
"""
return
text_format
.
Merge
(
proto_txt
,
return
text_format
.
Merge
(
proto_txt
,
center_net_pb2
.
CenterNet
.
ObjectCenterParams
())
center_net_pb2
.
CenterNet
.
ObjectCenterParams
())
...
@@ -257,6 +269,8 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
...
@@ -257,6 +269,8 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
self
.
assertAlmostEqual
(
self
.
assertAlmostEqual
(
model
.
_center_params
.
heatmap_bias_init
,
3.14
,
places
=
4
)
model
.
_center_params
.
heatmap_bias_init
,
3.14
,
places
=
4
)
self
.
assertEqual
(
model
.
_center_params
.
max_box_predictions
,
15
)
self
.
assertEqual
(
model
.
_center_params
.
max_box_predictions
,
15
)
self
.
assertEqual
(
model
.
_center_params
.
center_head_num_filters
,
[
64
,
32
])
self
.
assertEqual
(
model
.
_center_params
.
center_head_kernel_sizes
,
[
5
,
3
])
# Check object detection related parameters.
# Check object detection related parameters.
self
.
assertAlmostEqual
(
model
.
_od_params
.
offset_loss_weight
,
0.1
)
self
.
assertAlmostEqual
(
model
.
_od_params
.
offset_loss_weight
,
0.1
)
...
@@ -291,6 +305,12 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
...
@@ -291,6 +305,12 @@ class ModelBuilderTF2Test(model_builder_test.ModelBuilderTest):
self
.
assertEqual
(
kp_params
.
predict_depth
,
True
)
self
.
assertEqual
(
kp_params
.
predict_depth
,
True
)
self
.
assertEqual
(
kp_params
.
per_keypoint_depth
,
True
)
self
.
assertEqual
(
kp_params
.
per_keypoint_depth
,
True
)
self
.
assertAlmostEqual
(
kp_params
.
keypoint_depth_loss_weight
,
0.3
)
self
.
assertAlmostEqual
(
kp_params
.
keypoint_depth_loss_weight
,
0.3
)
# Set by the config.
self
.
assertEqual
(
kp_params
.
heatmap_head_num_filters
,
[
64
,
32
])
self
.
assertEqual
(
kp_params
.
heatmap_head_kernel_sizes
,
[
5
,
3
])
# Default values:
self
.
assertEqual
(
kp_params
.
offset_head_num_filters
,
[
256
])
self
.
assertEqual
(
kp_params
.
offset_head_kernel_sizes
,
[
3
])
# Check mask related parameters.
# Check mask related parameters.
self
.
assertAlmostEqual
(
model
.
_mask_params
.
task_loss_weight
,
0.7
)
self
.
assertAlmostEqual
(
model
.
_mask_params
.
task_loss_weight
,
0.7
)
...
...
research/object_detection/meta_architectures/center_net_meta_arch.py
View file @
6e5c5a1b
...
@@ -137,7 +137,7 @@ class CenterNetFeatureExtractor(tf.keras.Model):
...
@@ -137,7 +137,7 @@ class CenterNetFeatureExtractor(tf.keras.Model):
pass
pass
def
make_prediction_net
(
num_out_channels
,
kernel_size
=
3
,
num_filters
=
256
,
def
make_prediction_net
(
num_out_channels
,
kernel_size
s
=
(
3
)
,
num_filters
=
(
256
)
,
bias_fill
=
None
,
use_depthwise
=
False
,
name
=
None
):
bias_fill
=
None
,
use_depthwise
=
False
,
name
=
None
):
"""Creates a network to predict the given number of output channels.
"""Creates a network to predict the given number of output channels.
...
@@ -146,8 +146,13 @@ def make_prediction_net(num_out_channels, kernel_size=3, num_filters=256,
...
@@ -146,8 +146,13 @@ def make_prediction_net(num_out_channels, kernel_size=3, num_filters=256,
Args:
Args:
num_out_channels: Number of output channels.
num_out_channels: Number of output channels.
kernel_size: The size of the conv kernel in the intermediate layer
kernel_sizes: A list representing the sizes of the conv kernel in the
num_filters: The number of filters in the intermediate conv layer.
intermediate layer. Note that the length of the list indicates the number
of intermediate conv layers and it must be the same as the length of the
num_filters.
num_filters: A list representing the number of filters in the intermediate
conv layer. Note that the length of the list indicates the number of
intermediate conv layers.
bias_fill: If not None, is used to initialize the bias in the final conv
bias_fill: If not None, is used to initialize the bias in the final conv
layer.
layer.
use_depthwise: If true, use SeparableConv2D to construct the Sequential
use_depthwise: If true, use SeparableConv2D to construct the Sequential
...
@@ -159,6 +164,10 @@ def make_prediction_net(num_out_channels, kernel_size=3, num_filters=256,
...
@@ -159,6 +164,10 @@ def make_prediction_net(num_out_channels, kernel_size=3, num_filters=256,
[batch_size, height, width, num_in_channels] returns an output
[batch_size, height, width, num_in_channels] returns an output
of size [batch_size, height, width, num_out_channels]
of size [batch_size, height, width, num_out_channels]
"""
"""
if
isinstance
(
kernel_sizes
,
int
)
and
isinstance
(
num_filters
,
int
):
kernel_sizes
=
[
kernel_sizes
]
num_filters
=
[
num_filters
]
assert
len
(
kernel_sizes
)
==
len
(
num_filters
)
if
use_depthwise
:
if
use_depthwise
:
conv_fn
=
tf
.
keras
.
layers
.
SeparableConv2D
conv_fn
=
tf
.
keras
.
layers
.
SeparableConv2D
else
:
else
:
...
@@ -175,16 +184,18 @@ def make_prediction_net(num_out_channels, kernel_size=3, num_filters=256,
...
@@ -175,16 +184,18 @@ def make_prediction_net(num_out_channels, kernel_size=3, num_filters=256,
if
bias_fill
is
not
None
:
if
bias_fill
is
not
None
:
out_conv
.
bias_initializer
=
tf
.
keras
.
initializers
.
constant
(
bias_fill
)
out_conv
.
bias_initializer
=
tf
.
keras
.
initializers
.
constant
(
bias_fill
)
net
=
tf
.
keras
.
Sequential
([
layers
=
[]
for
idx
,
(
kernel_size
,
num_filter
)
in
enumerate
(
zip
(
kernel_sizes
,
num_filters
)):
layers
.
append
(
conv_fn
(
conv_fn
(
num_filter
s
,
num_filter
,
kernel_size
=
kernel_size
,
kernel_size
=
kernel_size
,
padding
=
'same'
,
padding
=
'same'
,
name
=
'conv2'
if
tf_version
.
is_tf1
()
else
None
),
name
=
'conv2_%d'
%
idx
if
tf_version
.
is_tf1
()
else
None
))
tf
.
keras
.
layers
.
ReLU
(),
out_conv
layers
.
append
(
tf
.
keras
.
layers
.
ReLU
())
],
layers
.
append
(
out_conv
)
name
=
name
)
net
=
tf
.
keras
.
Sequential
(
layers
,
name
=
name
)
return
net
return
net
...
@@ -1687,7 +1698,10 @@ class KeypointEstimationParams(
...
@@ -1687,7 +1698,10 @@ class KeypointEstimationParams(
'offset_peak_radius'
,
'per_keypoint_offset'
,
'predict_depth'
,
'offset_peak_radius'
,
'per_keypoint_offset'
,
'predict_depth'
,
'per_keypoint_depth'
,
'keypoint_depth_loss_weight'
,
'per_keypoint_depth'
,
'keypoint_depth_loss_weight'
,
'score_distance_offset'
,
'clip_out_of_frame_keypoints'
,
'score_distance_offset'
,
'clip_out_of_frame_keypoints'
,
'rescore_instances'
'rescore_instances'
,
'heatmap_head_num_filters'
,
'heatmap_head_kernel_sizes'
,
'offset_head_num_filters'
,
'offset_head_kernel_sizes'
,
'regress_head_num_filters'
,
'regress_head_kernel_sizes'
])):
])):
"""Namedtuple to host object detection related parameters.
"""Namedtuple to host object detection related parameters.
...
@@ -1726,7 +1740,13 @@ class KeypointEstimationParams(
...
@@ -1726,7 +1740,13 @@ class KeypointEstimationParams(
keypoint_depth_loss_weight
=
1.0
,
keypoint_depth_loss_weight
=
1.0
,
score_distance_offset
=
1e-6
,
score_distance_offset
=
1e-6
,
clip_out_of_frame_keypoints
=
False
,
clip_out_of_frame_keypoints
=
False
,
rescore_instances
=
False
):
rescore_instances
=
False
,
heatmap_head_num_filters
=
(
256
),
heatmap_head_kernel_sizes
=
(
3
),
offset_head_num_filters
=
(
256
),
offset_head_kernel_sizes
=
(
3
),
regress_head_num_filters
=
(
256
),
regress_head_kernel_sizes
=
(
3
)):
"""Constructor with default values for KeypointEstimationParams.
"""Constructor with default values for KeypointEstimationParams.
Args:
Args:
...
@@ -1806,6 +1826,18 @@ class KeypointEstimationParams(
...
@@ -1806,6 +1826,18 @@ class KeypointEstimationParams(
that are clipped have scores set to 0.0.
that are clipped have scores set to 0.0.
rescore_instances: Whether to rescore instances based on a combination of
rescore_instances: Whether to rescore instances based on a combination of
detection score and keypoint scores.
detection score and keypoint scores.
heatmap_head_num_filters: filter numbers of the convolutional layers used
by the keypoint heatmap prediction head.
heatmap_head_kernel_sizes: kernel size of the convolutional layers used
by the keypoint heatmap prediction head.
offset_head_num_filters: filter numbers of the convolutional layers used
by the keypoint offset prediction head.
offset_head_kernel_sizes: kernel size of the convolutional layers used
by the keypoint offset prediction head.
regress_head_num_filters: filter numbers of the convolutional layers used
by the keypoint regression prediction head.
regress_head_kernel_sizes: kernel size of the convolutional layers used
by the keypoint regression prediction head.
Returns:
Returns:
An initialized KeypointEstimationParams namedtuple.
An initialized KeypointEstimationParams namedtuple.
...
@@ -1820,14 +1852,18 @@ class KeypointEstimationParams(
...
@@ -1820,14 +1852,18 @@ class KeypointEstimationParams(
candidate_search_scale
,
candidate_ranking_mode
,
offset_peak_radius
,
candidate_search_scale
,
candidate_ranking_mode
,
offset_peak_radius
,
per_keypoint_offset
,
predict_depth
,
per_keypoint_depth
,
per_keypoint_offset
,
predict_depth
,
per_keypoint_depth
,
keypoint_depth_loss_weight
,
score_distance_offset
,
keypoint_depth_loss_weight
,
score_distance_offset
,
clip_out_of_frame_keypoints
,
rescore_instances
)
clip_out_of_frame_keypoints
,
rescore_instances
,
heatmap_head_num_filters
,
heatmap_head_kernel_sizes
,
offset_head_num_filters
,
offset_head_kernel_sizes
,
regress_head_num_filters
,
regress_head_kernel_sizes
)
class
ObjectCenterParams
(
class
ObjectCenterParams
(
collections
.
namedtuple
(
'ObjectCenterParams'
,
[
collections
.
namedtuple
(
'ObjectCenterParams'
,
[
'classification_loss'
,
'object_center_loss_weight'
,
'heatmap_bias_init'
,
'classification_loss'
,
'object_center_loss_weight'
,
'heatmap_bias_init'
,
'min_box_overlap_iou'
,
'max_box_predictions'
,
'use_labeled_classes'
,
'min_box_overlap_iou'
,
'max_box_predictions'
,
'use_labeled_classes'
,
'keypoint_weights_for_center'
'keypoint_weights_for_center'
,
'center_head_num_filters'
,
'center_head_kernel_sizes'
])):
])):
"""Namedtuple to store object center prediction related parameters."""
"""Namedtuple to store object center prediction related parameters."""
...
@@ -1840,7 +1876,9 @@ class ObjectCenterParams(
...
@@ -1840,7 +1876,9 @@ class ObjectCenterParams(
min_box_overlap_iou
=
0.7
,
min_box_overlap_iou
=
0.7
,
max_box_predictions
=
100
,
max_box_predictions
=
100
,
use_labeled_classes
=
False
,
use_labeled_classes
=
False
,
keypoint_weights_for_center
=
None
):
keypoint_weights_for_center
=
None
,
center_head_num_filters
=
(
256
),
center_head_kernel_sizes
=
(
3
)):
"""Constructor with default values for ObjectCenterParams.
"""Constructor with default values for ObjectCenterParams.
Args:
Args:
...
@@ -1861,7 +1899,10 @@ class ObjectCenterParams(
...
@@ -1861,7 +1899,10 @@ class ObjectCenterParams(
center is calculated by the weighted mean of the keypoint locations. If
center is calculated by the weighted mean of the keypoint locations. If
not provided, the object center is determined by the center of the
not provided, the object center is determined by the center of the
bounding box (default behavior).
bounding box (default behavior).
center_head_num_filters: filter numbers of the convolutional layers used
by the object center prediction head.
center_head_kernel_sizes: kernel size of the convolutional layers used
by the object center prediction head.
Returns:
Returns:
An initialized ObjectCenterParams namedtuple.
An initialized ObjectCenterParams namedtuple.
"""
"""
...
@@ -1869,7 +1910,8 @@ class ObjectCenterParams(
...
@@ -1869,7 +1910,8 @@ class ObjectCenterParams(
cls
).
__new__
(
cls
,
classification_loss
,
cls
).
__new__
(
cls
,
classification_loss
,
object_center_loss_weight
,
heatmap_bias_init
,
object_center_loss_weight
,
heatmap_bias_init
,
min_box_overlap_iou
,
max_box_predictions
,
min_box_overlap_iou
,
max_box_predictions
,
use_labeled_classes
,
keypoint_weights_for_center
)
use_labeled_classes
,
keypoint_weights_for_center
,
center_head_num_filters
,
center_head_kernel_sizes
)
class
MaskParams
(
class
MaskParams
(
...
@@ -2194,14 +2236,14 @@ class CenterNetMetaArch(model.DetectionModel):
...
@@ -2194,14 +2236,14 @@ class CenterNetMetaArch(model.DetectionModel):
return
self
.
_batched_prediction_tensor_names
return
self
.
_batched_prediction_tensor_names
def
_make_prediction_net_list
(
self
,
num_feature_outputs
,
num_out_channels
,
def
_make_prediction_net_list
(
self
,
num_feature_outputs
,
num_out_channels
,
kernel_size
=
3
,
num_filters
=
256
,
bias_fill
=
None
,
kernel_size
s
=
(
3
)
,
num_filters
=
(
256
)
,
name
=
None
):
bias_fill
=
None
,
name
=
None
):
prediction_net_list
=
[]
prediction_net_list
=
[]
for
i
in
range
(
num_feature_outputs
):
for
i
in
range
(
num_feature_outputs
):
prediction_net_list
.
append
(
prediction_net_list
.
append
(
make_prediction_net
(
make_prediction_net
(
num_out_channels
,
num_out_channels
,
kernel_size
=
kernel_size
,
kernel_size
s
=
kernel_size
s
,
num_filters
=
num_filters
,
num_filters
=
num_filters
,
bias_fill
=
bias_fill
,
bias_fill
=
bias_fill
,
use_depthwise
=
self
.
_use_depthwise
,
use_depthwise
=
self
.
_use_depthwise
,
...
@@ -2229,7 +2271,11 @@ class CenterNetMetaArch(model.DetectionModel):
...
@@ -2229,7 +2271,11 @@ class CenterNetMetaArch(model.DetectionModel):
"""
"""
prediction_heads
=
{}
prediction_heads
=
{}
prediction_heads
[
OBJECT_CENTER
]
=
self
.
_make_prediction_net_list
(
prediction_heads
[
OBJECT_CENTER
]
=
self
.
_make_prediction_net_list
(
num_feature_outputs
,
num_classes
,
bias_fill
=
class_prediction_bias_init
,
num_feature_outputs
,
num_classes
,
kernel_sizes
=
self
.
_center_params
.
center_head_kernel_sizes
,
num_filters
=
self
.
_center_params
.
center_head_num_filters
,
bias_fill
=
class_prediction_bias_init
,
name
=
'center'
)
name
=
'center'
)
if
self
.
_od_params
is
not
None
:
if
self
.
_od_params
is
not
None
:
...
@@ -2245,12 +2291,16 @@ class CenterNetMetaArch(model.DetectionModel):
...
@@ -2245,12 +2291,16 @@ class CenterNetMetaArch(model.DetectionModel):
task_name
,
KEYPOINT_HEATMAP
)]
=
self
.
_make_prediction_net_list
(
task_name
,
KEYPOINT_HEATMAP
)]
=
self
.
_make_prediction_net_list
(
num_feature_outputs
,
num_feature_outputs
,
num_keypoints
,
num_keypoints
,
kernel_sizes
=
kp_params
.
heatmap_head_kernel_sizes
,
num_filters
=
kp_params
.
heatmap_head_num_filters
,
bias_fill
=
kp_params
.
heatmap_bias_init
,
bias_fill
=
kp_params
.
heatmap_bias_init
,
name
=
'kpt_heatmap'
)
name
=
'kpt_heatmap'
)
prediction_heads
[
get_keypoint_name
(
prediction_heads
[
get_keypoint_name
(
task_name
,
KEYPOINT_REGRESSION
)]
=
self
.
_make_prediction_net_list
(
task_name
,
KEYPOINT_REGRESSION
)]
=
self
.
_make_prediction_net_list
(
num_feature_outputs
,
num_feature_outputs
,
NUM_OFFSET_CHANNELS
*
num_keypoints
,
NUM_OFFSET_CHANNELS
*
num_keypoints
,
kernel_sizes
=
kp_params
.
regress_head_kernel_sizes
,
num_filters
=
kp_params
.
regress_head_num_filters
,
name
=
'kpt_regress'
)
name
=
'kpt_regress'
)
if
kp_params
.
per_keypoint_offset
:
if
kp_params
.
per_keypoint_offset
:
...
@@ -2258,11 +2308,17 @@ class CenterNetMetaArch(model.DetectionModel):
...
@@ -2258,11 +2308,17 @@ class CenterNetMetaArch(model.DetectionModel):
task_name
,
KEYPOINT_OFFSET
)]
=
self
.
_make_prediction_net_list
(
task_name
,
KEYPOINT_OFFSET
)]
=
self
.
_make_prediction_net_list
(
num_feature_outputs
,
num_feature_outputs
,
NUM_OFFSET_CHANNELS
*
num_keypoints
,
NUM_OFFSET_CHANNELS
*
num_keypoints
,
kernel_sizes
=
kp_params
.
offset_head_kernel_sizes
,
num_filters
=
kp_params
.
offset_head_num_filters
,
name
=
'kpt_offset'
)
name
=
'kpt_offset'
)
else
:
else
:
prediction_heads
[
get_keypoint_name
(
prediction_heads
[
get_keypoint_name
(
task_name
,
KEYPOINT_OFFSET
)]
=
self
.
_make_prediction_net_list
(
task_name
,
KEYPOINT_OFFSET
)]
=
self
.
_make_prediction_net_list
(
num_feature_outputs
,
NUM_OFFSET_CHANNELS
,
name
=
'kpt_offset'
)
num_feature_outputs
,
NUM_OFFSET_CHANNELS
,
kernel_sizes
=
kp_params
.
offset_head_kernel_sizes
,
num_filters
=
kp_params
.
offset_head_num_filters
,
name
=
'kpt_offset'
)
if
kp_params
.
predict_depth
:
if
kp_params
.
predict_depth
:
num_depth_channel
=
(
num_depth_channel
=
(
...
...
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py
View file @
6e5c5a1b
...
@@ -1482,7 +1482,9 @@ def get_fake_center_params(max_box_predictions=5):
...
@@ -1482,7 +1482,9 @@ def get_fake_center_params(max_box_predictions=5):
object_center_loss_weight
=
1.0
,
object_center_loss_weight
=
1.0
,
min_box_overlap_iou
=
1.0
,
min_box_overlap_iou
=
1.0
,
max_box_predictions
=
max_box_predictions
,
max_box_predictions
=
max_box_predictions
,
use_labeled_classes
=
False
)
use_labeled_classes
=
False
,
center_head_num_filters
=
[
128
],
center_head_kernel_sizes
=
[
5
])
def
get_fake_od_params
():
def
get_fake_od_params
():
...
...
research/object_detection/protos/center_net.proto
View file @
6e5c5a1b
...
@@ -30,6 +30,23 @@ message CenterNet {
...
@@ -30,6 +30,23 @@ message CenterNet {
// TODO(b/170989061) When bug is fixed, make this the default behavior.
// TODO(b/170989061) When bug is fixed, make this the default behavior.
optional
bool
compute_heatmap_sparse
=
15
[
default
=
false
];
optional
bool
compute_heatmap_sparse
=
15
[
default
=
false
];
// Parameters to determine the model architecture/layers of the prediction
// heads.
message
PredictionHeadParams
{
// The two fields: num_filters, kernel_sizes correspond to the parameters of
// the convolutional layers used by the prediction head. If provided, the
// length of the two repeated fields need to be the same and represents the
// number of convolutional layers.
// Corresponds to the "filters" argument in tf.keras.layers.Conv2D. If not
// provided, the default value [256] will be used.
repeated
int32
num_filters
=
1
;
// Corresponds to the "kernel_size" argument in tf.keras.layers.Conv2D. If
// not provided, the default value [3] will be used.
repeated
int32
kernel_sizes
=
2
;
}
// Parameters which are related to object detection task.
// Parameters which are related to object detection task.
message
ObjectDetection
{
message
ObjectDetection
{
// The original fields are moved to ObjectCenterParams or deleted.
// The original fields are moved to ObjectCenterParams or deleted.
...
@@ -81,6 +98,10 @@ message CenterNet {
...
@@ -81,6 +98,10 @@ message CenterNet {
// object center is determined by the bounding box groundtruth annotations
// object center is determined by the bounding box groundtruth annotations
// (default behavior).
// (default behavior).
repeated
float
keypoint_weights_for_center
=
7
;
repeated
float
keypoint_weights_for_center
=
7
;
// Parameters to determine the architecture of the object center prediction
// head.
optional
PredictionHeadParams
center_head_params
=
8
;
}
}
optional
ObjectCenterParams
object_center_params
=
5
;
optional
ObjectCenterParams
object_center_params
=
5
;
...
@@ -207,6 +228,18 @@ message CenterNet {
...
@@ -207,6 +228,18 @@ message CenterNet {
// where o is the object score, s_i is the score for keypoint i, and k is
// where o is the object score, s_i is the score for keypoint i, and k is
// the number of keypoints for that class.
// the number of keypoints for that class.
optional
bool
rescore_instances
=
24
[
default
=
false
];
optional
bool
rescore_instances
=
24
[
default
=
false
];
// Parameters to determine the architecture of the keypoint heatmap
// prediction head.
optional
PredictionHeadParams
heatmap_head_params
=
25
;
// Parameters to determine the architecture of the keypoint offset
// prediction head.
optional
PredictionHeadParams
offset_head_params
=
26
;
// Parameters to determine the architecture of the keypoint regression
// prediction head.
optional
PredictionHeadParams
regress_head_params
=
27
;
}
}
repeated
KeypointEstimation
keypoint_estimation_task
=
7
;
repeated
KeypointEstimation
keypoint_estimation_task
=
7
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment