Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
2ebe7c3c
Commit
2ebe7c3c
authored
Sep 25, 2020
by
Liangzhe Yuan
Committed by
TF Object Detection Team
Sep 25, 2020
Browse files
Support to use separable_conv in CenterNet task head.
PiperOrigin-RevId: 333840074
parent
59888a74
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
65 additions
and
24 deletions
+65
-24
research/object_detection/builders/model_builder.py
research/object_detection/builders/model_builder.py
+2
-1
research/object_detection/meta_architectures/center_net_meta_arch.py
...ject_detection/meta_architectures/center_net_meta_arch.py
+49
-20
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py
...ction/meta_architectures/center_net_meta_arch_tf2_test.py
+6
-3
research/object_detection/protos/center_net.proto
research/object_detection/protos/center_net.proto
+8
-0
No files found.
research/object_detection/builders/model_builder.py
View file @
2ebe7c3c
...
...
@@ -1035,7 +1035,8 @@ def _build_center_net_model(center_net_config, is_training, add_summaries):
mask_params
=
mask_params
,
densepose_params
=
densepose_params
,
track_params
=
track_params
,
temporal_offset_params
=
temporal_offset_params
)
temporal_offset_params
=
temporal_offset_params
,
use_depthwise
=
center_net_config
.
use_depthwise
)
def
_build_center_net_feature_extractor
(
...
...
research/object_detection/meta_architectures/center_net_meta_arch.py
View file @
2ebe7c3c
...
...
@@ -139,7 +139,7 @@ class CenterNetFeatureExtractor(tf.keras.Model):
def
make_prediction_net
(
num_out_channels
,
kernel_size
=
3
,
num_filters
=
256
,
bias_fill
=
None
):
bias_fill
=
None
,
use_depthwise
=
False
,
name
=
None
):
"""Creates a network to predict the given number of output channels.
This function is intended to make the prediction heads for the CenterNet
...
...
@@ -151,12 +151,19 @@ def make_prediction_net(num_out_channels, kernel_size=3, num_filters=256,
num_filters: The number of filters in the intermediate conv layer.
bias_fill: If not None, is used to initialize the bias in the final conv
layer.
use_depthwise: If true, use SeparableConv2D to construct the Sequential
layers instead of Conv2D.
name: Optional name for the prediction net.
Returns:
net: A keras module which when called on an input tensor of size
[batch_size, height, width, num_in_channels] returns an output
of size [batch_size, height, width, num_out_channels]
"""
if
use_depthwise
:
conv_fn
=
tf
.
keras
.
layers
.
SeparableConv2D
else
:
conv_fn
=
tf
.
keras
.
layers
.
Conv2D
out_conv
=
tf
.
keras
.
layers
.
Conv2D
(
num_out_channels
,
kernel_size
=
1
)
...
...
@@ -164,11 +171,10 @@ def make_prediction_net(num_out_channels, kernel_size=3, num_filters=256,
out_conv
.
bias_initializer
=
tf
.
keras
.
initializers
.
constant
(
bias_fill
)
net
=
tf
.
keras
.
Sequential
(
[
tf
.
keras
.
layers
.
Conv2D
(
num_filters
,
kernel_size
=
kernel_size
,
padding
=
'same'
),
[
conv_fn
(
num_filters
,
kernel_size
=
kernel_size
,
padding
=
'same'
),
tf
.
keras
.
layers
.
ReLU
(),
out_conv
]
)
out_conv
]
,
name
=
name
)
return
net
...
...
@@ -1673,7 +1679,8 @@ class CenterNetMetaArch(model.DetectionModel):
mask_params
=
None
,
densepose_params
=
None
,
track_params
=
None
,
temporal_offset_params
=
None
):
temporal_offset_params
=
None
,
use_depthwise
=
False
):
"""Initializes a CenterNet model.
Args:
...
...
@@ -1710,6 +1717,8 @@ class CenterNetMetaArch(model.DetectionModel):
definition for more details.
temporal_offset_params: A TemporalOffsetParams namedtuple. This object
holds the hyper-parameters for offset prediction based tracking.
use_depthwise: If true, all task heads will be constructed using
separable_conv. Otherwise, standard convoltuions will be used.
"""
assert
object_detection_params
or
keypoint_params_dict
# Shorten the name for convenience and better formatting.
...
...
@@ -1732,6 +1741,8 @@ class CenterNetMetaArch(model.DetectionModel):
self
.
_track_params
=
track_params
self
.
_temporal_offset_params
=
temporal_offset_params
self
.
_use_depthwise
=
use_depthwise
# Construct the prediction head nets.
self
.
_prediction_head_dict
=
self
.
_construct_prediction_heads
(
num_classes
,
...
...
@@ -1775,58 +1786,75 @@ class CenterNetMetaArch(model.DetectionModel):
"""
prediction_heads
=
{}
prediction_heads
[
OBJECT_CENTER
]
=
[
make_prediction_net
(
num_classes
,
bias_fill
=
class_prediction_bias_init
)
make_prediction_net
(
num_classes
,
bias_fill
=
class_prediction_bias_init
,
use_depthwise
=
self
.
_use_depthwise
)
for
_
in
range
(
num_feature_outputs
)
]
if
self
.
_od_params
is
not
None
:
prediction_heads
[
BOX_SCALE
]
=
[
make_prediction_net
(
NUM_SIZE_CHANNELS
)
make_prediction_net
(
NUM_SIZE_CHANNELS
,
use_depthwise
=
self
.
_use_depthwise
)
for
_
in
range
(
num_feature_outputs
)
]
prediction_heads
[
BOX_OFFSET
]
=
[
make_prediction_net
(
NUM_OFFSET_CHANNELS
)
make_prediction_net
(
NUM_OFFSET_CHANNELS
,
use_depthwise
=
self
.
_use_depthwise
)
for
_
in
range
(
num_feature_outputs
)
]
if
self
.
_kp_params_dict
is
not
None
:
for
task_name
,
kp_params
in
self
.
_kp_params_dict
.
items
():
num_keypoints
=
len
(
kp_params
.
keypoint_indices
)
# pylint: disable=g-complex-comprehension
prediction_heads
[
get_keypoint_name
(
task_name
,
KEYPOINT_HEATMAP
)]
=
[
make_prediction_net
(
num_keypoints
,
bias_fill
=
kp_params
.
heatmap_bias_init
)
num_keypoints
,
bias_fill
=
kp_params
.
heatmap_bias_init
,
use_depthwise
=
self
.
_use_depthwise
)
for
_
in
range
(
num_feature_outputs
)
]
# pylint: enable=g-complex-comprehension
prediction_heads
[
get_keypoint_name
(
task_name
,
KEYPOINT_REGRESSION
)]
=
[
make_prediction_net
(
NUM_OFFSET_CHANNELS
*
num_keypoints
)
make_prediction_net
(
NUM_OFFSET_CHANNELS
*
num_keypoints
,
use_depthwise
=
self
.
_use_depthwise
)
for
_
in
range
(
num_feature_outputs
)
]
if
kp_params
.
per_keypoint_offset
:
prediction_heads
[
get_keypoint_name
(
task_name
,
KEYPOINT_OFFSET
)]
=
[
make_prediction_net
(
NUM_OFFSET_CHANNELS
*
num_keypoints
)
make_prediction_net
(
NUM_OFFSET_CHANNELS
*
num_keypoints
,
use_depthwise
=
self
.
_use_depthwise
)
for
_
in
range
(
num_feature_outputs
)
]
else
:
prediction_heads
[
get_keypoint_name
(
task_name
,
KEYPOINT_OFFSET
)]
=
[
make_prediction_net
(
NUM_OFFSET_CHANNELS
)
make_prediction_net
(
NUM_OFFSET_CHANNELS
,
use_depthwise
=
self
.
_use_depthwise
)
for
_
in
range
(
num_feature_outputs
)
]
# pylint: disable=g-complex-comprehension
if
self
.
_mask_params
is
not
None
:
prediction_heads
[
SEGMENTATION_HEATMAP
]
=
[
make_prediction_net
(
num_classes
,
bias_fill
=
self
.
_mask_params
.
heatmap_bias_init
)
make_prediction_net
(
num_classes
,
bias_fill
=
self
.
_mask_params
.
heatmap_bias_init
,
use_depthwise
=
self
.
_use_depthwise
)
for
_
in
range
(
num_feature_outputs
)]
if
self
.
_densepose_params
is
not
None
:
prediction_heads
[
DENSEPOSE_HEATMAP
]
=
[
make_prediction_net
(
# pylint: disable=g-complex-comprehension
make_prediction_net
(
self
.
_densepose_params
.
num_parts
,
bias_fill
=
self
.
_densepose_params
.
heatmap_bias_init
)
bias_fill
=
self
.
_densepose_params
.
heatmap_bias_init
,
use_depthwise
=
self
.
_use_depthwise
)
for
_
in
range
(
num_feature_outputs
)]
prediction_heads
[
DENSEPOSE_REGRESSION
]
=
[
make_prediction_net
(
2
*
self
.
_densepose_params
.
num_parts
)
make_prediction_net
(
2
*
self
.
_densepose_params
.
num_parts
,
use_depthwise
=
self
.
_use_depthwise
)
for
_
in
range
(
num_feature_outputs
)
]
# pylint: enable=g-complex-comprehension
if
self
.
_track_params
is
not
None
:
prediction_heads
[
TRACK_REID
]
=
[
make_prediction_net
(
self
.
_track_params
.
reid_embed_size
)
make_prediction_net
(
self
.
_track_params
.
reid_embed_size
,
use_depthwise
=
self
.
_use_depthwise
)
for
_
in
range
(
num_feature_outputs
)]
# Creates a classification network to train object embeddings by learning
...
...
@@ -1846,7 +1874,8 @@ class CenterNetMetaArch(model.DetectionModel):
self
.
_track_params
.
reid_embed_size
,)))
if
self
.
_temporal_offset_params
is
not
None
:
prediction_heads
[
TEMPORAL_OFFSET
]
=
[
make_prediction_net
(
NUM_OFFSET_CHANNELS
)
make_prediction_net
(
NUM_OFFSET_CHANNELS
,
use_depthwise
=
self
.
_use_depthwise
)
for
_
in
range
(
num_feature_outputs
)
]
return
prediction_heads
...
...
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py
View file @
2ebe7c3c
...
...
@@ -35,11 +35,14 @@ from object_detection.utils import tf_version
@
unittest
.
skipIf
(
tf_version
.
is_tf1
(),
'Skipping TF2.X only test.'
)
class
CenterNetMetaArchPredictionHeadTest
(
test_case
.
TestCase
):
class
CenterNetMetaArchPredictionHeadTest
(
test_case
.
TestCase
,
parameterized
.
TestCase
):
"""Test CenterNet meta architecture prediction head."""
def
test_prediction_head
(
self
):
head
=
cnma
.
make_prediction_net
(
num_out_channels
=
7
)
@
parameterized
.
parameters
([
True
,
False
])
def
test_prediction_head
(
self
,
use_depthwise
):
head
=
cnma
.
make_prediction_net
(
num_out_channels
=
7
,
use_depthwise
=
use_depthwise
)
output
=
head
(
np
.
zeros
((
4
,
128
,
128
,
8
)))
self
.
assertEqual
((
4
,
128
,
128
,
7
),
output
.
shape
)
...
...
research/object_detection/protos/center_net.proto
View file @
2ebe7c3c
...
...
@@ -19,6 +19,9 @@ message CenterNet {
// Image resizer for preprocessing the input image.
optional
ImageResizer
image_resizer
=
3
;
// If set, all task heads will be constructed with separable convolutions.
optional
bool
use_depthwise
=
13
[
default
=
false
];
// Parameters which are related to object detection task.
message
ObjectDetection
{
// The original fields are moved to ObjectCenterParams or deleted.
...
...
@@ -278,4 +281,9 @@ message CenterNetFeatureExtractor {
// If set, will change channel order to be [blue, green, red]. This can be
// useful to be compatible with some pre-trained feature extractors.
optional
bool
bgr_ordering
=
4
[
default
=
false
];
// If set, the feature upsampling layers will be constructed with
// separable convolutions. This is typically applied to feature pyramid
// network if any.
optional
bool
use_depthwise
=
5
[
default
=
false
];
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment