Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
2ebe7c3c
Commit
2ebe7c3c
authored
Sep 25, 2020
by
Liangzhe Yuan
Committed by
TF Object Detection Team
Sep 25, 2020
Browse files
Support to use separable_conv in CenterNet task head.
PiperOrigin-RevId: 333840074
parent
59888a74
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
65 additions
and
24 deletions
+65
-24
research/object_detection/builders/model_builder.py
research/object_detection/builders/model_builder.py
+2
-1
research/object_detection/meta_architectures/center_net_meta_arch.py
...ject_detection/meta_architectures/center_net_meta_arch.py
+49
-20
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py
...ction/meta_architectures/center_net_meta_arch_tf2_test.py
+6
-3
research/object_detection/protos/center_net.proto
research/object_detection/protos/center_net.proto
+8
-0
No files found.
research/object_detection/builders/model_builder.py
View file @
2ebe7c3c
...
@@ -1035,7 +1035,8 @@ def _build_center_net_model(center_net_config, is_training, add_summaries):
...
@@ -1035,7 +1035,8 @@ def _build_center_net_model(center_net_config, is_training, add_summaries):
mask_params
=
mask_params
,
mask_params
=
mask_params
,
densepose_params
=
densepose_params
,
densepose_params
=
densepose_params
,
track_params
=
track_params
,
track_params
=
track_params
,
temporal_offset_params
=
temporal_offset_params
)
temporal_offset_params
=
temporal_offset_params
,
use_depthwise
=
center_net_config
.
use_depthwise
)
def
_build_center_net_feature_extractor
(
def
_build_center_net_feature_extractor
(
...
...
research/object_detection/meta_architectures/center_net_meta_arch.py
View file @
2ebe7c3c
...
@@ -139,7 +139,7 @@ class CenterNetFeatureExtractor(tf.keras.Model):
...
@@ -139,7 +139,7 @@ class CenterNetFeatureExtractor(tf.keras.Model):
def
make_prediction_net
(
num_out_channels
,
kernel_size
=
3
,
num_filters
=
256
,
def
make_prediction_net
(
num_out_channels
,
kernel_size
=
3
,
num_filters
=
256
,
bias_fill
=
None
):
bias_fill
=
None
,
use_depthwise
=
False
,
name
=
None
):
"""Creates a network to predict the given number of output channels.
"""Creates a network to predict the given number of output channels.
This function is intended to make the prediction heads for the CenterNet
This function is intended to make the prediction heads for the CenterNet
...
@@ -151,12 +151,19 @@ def make_prediction_net(num_out_channels, kernel_size=3, num_filters=256,
...
@@ -151,12 +151,19 @@ def make_prediction_net(num_out_channels, kernel_size=3, num_filters=256,
num_filters: The number of filters in the intermediate conv layer.
num_filters: The number of filters in the intermediate conv layer.
bias_fill: If not None, is used to initialize the bias in the final conv
bias_fill: If not None, is used to initialize the bias in the final conv
layer.
layer.
use_depthwise: If true, use SeparableConv2D to construct the Sequential
layers instead of Conv2D.
name: Optional name for the prediction net.
Returns:
Returns:
net: A keras module which when called on an input tensor of size
net: A keras module which when called on an input tensor of size
[batch_size, height, width, num_in_channels] returns an output
[batch_size, height, width, num_in_channels] returns an output
of size [batch_size, height, width, num_out_channels]
of size [batch_size, height, width, num_out_channels]
"""
"""
if
use_depthwise
:
conv_fn
=
tf
.
keras
.
layers
.
SeparableConv2D
else
:
conv_fn
=
tf
.
keras
.
layers
.
Conv2D
out_conv
=
tf
.
keras
.
layers
.
Conv2D
(
num_out_channels
,
kernel_size
=
1
)
out_conv
=
tf
.
keras
.
layers
.
Conv2D
(
num_out_channels
,
kernel_size
=
1
)
...
@@ -164,11 +171,10 @@ def make_prediction_net(num_out_channels, kernel_size=3, num_filters=256,
...
@@ -164,11 +171,10 @@ def make_prediction_net(num_out_channels, kernel_size=3, num_filters=256,
out_conv
.
bias_initializer
=
tf
.
keras
.
initializers
.
constant
(
bias_fill
)
out_conv
.
bias_initializer
=
tf
.
keras
.
initializers
.
constant
(
bias_fill
)
net
=
tf
.
keras
.
Sequential
(
net
=
tf
.
keras
.
Sequential
(
[
tf
.
keras
.
layers
.
Conv2D
(
num_filters
,
kernel_size
=
kernel_size
,
[
conv_fn
(
num_filters
,
kernel_size
=
kernel_size
,
padding
=
'same'
),
padding
=
'same'
),
tf
.
keras
.
layers
.
ReLU
(),
tf
.
keras
.
layers
.
ReLU
(),
out_conv
]
out_conv
]
,
)
name
=
name
)
return
net
return
net
...
@@ -1673,7 +1679,8 @@ class CenterNetMetaArch(model.DetectionModel):
...
@@ -1673,7 +1679,8 @@ class CenterNetMetaArch(model.DetectionModel):
mask_params
=
None
,
mask_params
=
None
,
densepose_params
=
None
,
densepose_params
=
None
,
track_params
=
None
,
track_params
=
None
,
temporal_offset_params
=
None
):
temporal_offset_params
=
None
,
use_depthwise
=
False
):
"""Initializes a CenterNet model.
"""Initializes a CenterNet model.
Args:
Args:
...
@@ -1710,6 +1717,8 @@ class CenterNetMetaArch(model.DetectionModel):
...
@@ -1710,6 +1717,8 @@ class CenterNetMetaArch(model.DetectionModel):
definition for more details.
definition for more details.
temporal_offset_params: A TemporalOffsetParams namedtuple. This object
temporal_offset_params: A TemporalOffsetParams namedtuple. This object
holds the hyper-parameters for offset prediction based tracking.
holds the hyper-parameters for offset prediction based tracking.
use_depthwise: If true, all task heads will be constructed using
separable_conv. Otherwise, standard convoltuions will be used.
"""
"""
assert
object_detection_params
or
keypoint_params_dict
assert
object_detection_params
or
keypoint_params_dict
# Shorten the name for convenience and better formatting.
# Shorten the name for convenience and better formatting.
...
@@ -1732,6 +1741,8 @@ class CenterNetMetaArch(model.DetectionModel):
...
@@ -1732,6 +1741,8 @@ class CenterNetMetaArch(model.DetectionModel):
self
.
_track_params
=
track_params
self
.
_track_params
=
track_params
self
.
_temporal_offset_params
=
temporal_offset_params
self
.
_temporal_offset_params
=
temporal_offset_params
self
.
_use_depthwise
=
use_depthwise
# Construct the prediction head nets.
# Construct the prediction head nets.
self
.
_prediction_head_dict
=
self
.
_construct_prediction_heads
(
self
.
_prediction_head_dict
=
self
.
_construct_prediction_heads
(
num_classes
,
num_classes
,
...
@@ -1775,58 +1786,75 @@ class CenterNetMetaArch(model.DetectionModel):
...
@@ -1775,58 +1786,75 @@ class CenterNetMetaArch(model.DetectionModel):
"""
"""
prediction_heads
=
{}
prediction_heads
=
{}
prediction_heads
[
OBJECT_CENTER
]
=
[
prediction_heads
[
OBJECT_CENTER
]
=
[
make_prediction_net
(
num_classes
,
bias_fill
=
class_prediction_bias_init
)
make_prediction_net
(
num_classes
,
bias_fill
=
class_prediction_bias_init
,
use_depthwise
=
self
.
_use_depthwise
)
for
_
in
range
(
num_feature_outputs
)
for
_
in
range
(
num_feature_outputs
)
]
]
if
self
.
_od_params
is
not
None
:
if
self
.
_od_params
is
not
None
:
prediction_heads
[
BOX_SCALE
]
=
[
prediction_heads
[
BOX_SCALE
]
=
[
make_prediction_net
(
NUM_SIZE_CHANNELS
)
make_prediction_net
(
NUM_SIZE_CHANNELS
,
use_depthwise
=
self
.
_use_depthwise
)
for
_
in
range
(
num_feature_outputs
)
for
_
in
range
(
num_feature_outputs
)
]
]
prediction_heads
[
BOX_OFFSET
]
=
[
prediction_heads
[
BOX_OFFSET
]
=
[
make_prediction_net
(
NUM_OFFSET_CHANNELS
)
make_prediction_net
(
NUM_OFFSET_CHANNELS
,
use_depthwise
=
self
.
_use_depthwise
)
for
_
in
range
(
num_feature_outputs
)
for
_
in
range
(
num_feature_outputs
)
]
]
if
self
.
_kp_params_dict
is
not
None
:
if
self
.
_kp_params_dict
is
not
None
:
for
task_name
,
kp_params
in
self
.
_kp_params_dict
.
items
():
for
task_name
,
kp_params
in
self
.
_kp_params_dict
.
items
():
num_keypoints
=
len
(
kp_params
.
keypoint_indices
)
num_keypoints
=
len
(
kp_params
.
keypoint_indices
)
# pylint: disable=g-complex-comprehension
prediction_heads
[
get_keypoint_name
(
task_name
,
KEYPOINT_HEATMAP
)]
=
[
prediction_heads
[
get_keypoint_name
(
task_name
,
KEYPOINT_HEATMAP
)]
=
[
make_prediction_net
(
make_prediction_net
(
num_keypoints
,
bias_fill
=
kp_params
.
heatmap_bias_init
)
num_keypoints
,
bias_fill
=
kp_params
.
heatmap_bias_init
,
use_depthwise
=
self
.
_use_depthwise
)
for
_
in
range
(
num_feature_outputs
)
for
_
in
range
(
num_feature_outputs
)
]
]
# pylint: enable=g-complex-comprehension
prediction_heads
[
get_keypoint_name
(
task_name
,
KEYPOINT_REGRESSION
)]
=
[
prediction_heads
[
get_keypoint_name
(
task_name
,
KEYPOINT_REGRESSION
)]
=
[
make_prediction_net
(
NUM_OFFSET_CHANNELS
*
num_keypoints
)
make_prediction_net
(
NUM_OFFSET_CHANNELS
*
num_keypoints
,
use_depthwise
=
self
.
_use_depthwise
)
for
_
in
range
(
num_feature_outputs
)
for
_
in
range
(
num_feature_outputs
)
]
]
if
kp_params
.
per_keypoint_offset
:
if
kp_params
.
per_keypoint_offset
:
prediction_heads
[
get_keypoint_name
(
task_name
,
KEYPOINT_OFFSET
)]
=
[
prediction_heads
[
get_keypoint_name
(
task_name
,
KEYPOINT_OFFSET
)]
=
[
make_prediction_net
(
NUM_OFFSET_CHANNELS
*
num_keypoints
)
make_prediction_net
(
NUM_OFFSET_CHANNELS
*
num_keypoints
,
use_depthwise
=
self
.
_use_depthwise
)
for
_
in
range
(
num_feature_outputs
)
for
_
in
range
(
num_feature_outputs
)
]
]
else
:
else
:
prediction_heads
[
get_keypoint_name
(
task_name
,
KEYPOINT_OFFSET
)]
=
[
prediction_heads
[
get_keypoint_name
(
task_name
,
KEYPOINT_OFFSET
)]
=
[
make_prediction_net
(
NUM_OFFSET_CHANNELS
)
make_prediction_net
(
NUM_OFFSET_CHANNELS
,
use_depthwise
=
self
.
_use_depthwise
)
for
_
in
range
(
num_feature_outputs
)
for
_
in
range
(
num_feature_outputs
)
]
]
# pylint: disable=g-complex-comprehension
if
self
.
_mask_params
is
not
None
:
if
self
.
_mask_params
is
not
None
:
prediction_heads
[
SEGMENTATION_HEATMAP
]
=
[
prediction_heads
[
SEGMENTATION_HEATMAP
]
=
[
make_prediction_net
(
num_classes
,
make_prediction_net
(
bias_fill
=
self
.
_mask_params
.
heatmap_bias_init
)
num_classes
,
bias_fill
=
self
.
_mask_params
.
heatmap_bias_init
,
use_depthwise
=
self
.
_use_depthwise
)
for
_
in
range
(
num_feature_outputs
)]
for
_
in
range
(
num_feature_outputs
)]
if
self
.
_densepose_params
is
not
None
:
if
self
.
_densepose_params
is
not
None
:
prediction_heads
[
DENSEPOSE_HEATMAP
]
=
[
prediction_heads
[
DENSEPOSE_HEATMAP
]
=
[
make_prediction_net
(
# pylint: disable=g-complex-comprehension
make_prediction_net
(
self
.
_densepose_params
.
num_parts
,
self
.
_densepose_params
.
num_parts
,
bias_fill
=
self
.
_densepose_params
.
heatmap_bias_init
)
bias_fill
=
self
.
_densepose_params
.
heatmap_bias_init
,
use_depthwise
=
self
.
_use_depthwise
)
for
_
in
range
(
num_feature_outputs
)]
for
_
in
range
(
num_feature_outputs
)]
prediction_heads
[
DENSEPOSE_REGRESSION
]
=
[
prediction_heads
[
DENSEPOSE_REGRESSION
]
=
[
make_prediction_net
(
2
*
self
.
_densepose_params
.
num_parts
)
make_prediction_net
(
2
*
self
.
_densepose_params
.
num_parts
,
use_depthwise
=
self
.
_use_depthwise
)
for
_
in
range
(
num_feature_outputs
)
for
_
in
range
(
num_feature_outputs
)
]
]
# pylint: enable=g-complex-comprehension
if
self
.
_track_params
is
not
None
:
if
self
.
_track_params
is
not
None
:
prediction_heads
[
TRACK_REID
]
=
[
prediction_heads
[
TRACK_REID
]
=
[
make_prediction_net
(
self
.
_track_params
.
reid_embed_size
)
make_prediction_net
(
self
.
_track_params
.
reid_embed_size
,
use_depthwise
=
self
.
_use_depthwise
)
for
_
in
range
(
num_feature_outputs
)]
for
_
in
range
(
num_feature_outputs
)]
# Creates a classification network to train object embeddings by learning
# Creates a classification network to train object embeddings by learning
...
@@ -1846,7 +1874,8 @@ class CenterNetMetaArch(model.DetectionModel):
...
@@ -1846,7 +1874,8 @@ class CenterNetMetaArch(model.DetectionModel):
self
.
_track_params
.
reid_embed_size
,)))
self
.
_track_params
.
reid_embed_size
,)))
if
self
.
_temporal_offset_params
is
not
None
:
if
self
.
_temporal_offset_params
is
not
None
:
prediction_heads
[
TEMPORAL_OFFSET
]
=
[
prediction_heads
[
TEMPORAL_OFFSET
]
=
[
make_prediction_net
(
NUM_OFFSET_CHANNELS
)
make_prediction_net
(
NUM_OFFSET_CHANNELS
,
use_depthwise
=
self
.
_use_depthwise
)
for
_
in
range
(
num_feature_outputs
)
for
_
in
range
(
num_feature_outputs
)
]
]
return
prediction_heads
return
prediction_heads
...
...
research/object_detection/meta_architectures/center_net_meta_arch_tf2_test.py
View file @
2ebe7c3c
...
@@ -35,11 +35,14 @@ from object_detection.utils import tf_version
...
@@ -35,11 +35,14 @@ from object_detection.utils import tf_version
@
unittest
.
skipIf
(
tf_version
.
is_tf1
(),
'Skipping TF2.X only test.'
)
@
unittest
.
skipIf
(
tf_version
.
is_tf1
(),
'Skipping TF2.X only test.'
)
class
CenterNetMetaArchPredictionHeadTest
(
test_case
.
TestCase
):
class
CenterNetMetaArchPredictionHeadTest
(
test_case
.
TestCase
,
parameterized
.
TestCase
):
"""Test CenterNet meta architecture prediction head."""
"""Test CenterNet meta architecture prediction head."""
def
test_prediction_head
(
self
):
@
parameterized
.
parameters
([
True
,
False
])
head
=
cnma
.
make_prediction_net
(
num_out_channels
=
7
)
def
test_prediction_head
(
self
,
use_depthwise
):
head
=
cnma
.
make_prediction_net
(
num_out_channels
=
7
,
use_depthwise
=
use_depthwise
)
output
=
head
(
np
.
zeros
((
4
,
128
,
128
,
8
)))
output
=
head
(
np
.
zeros
((
4
,
128
,
128
,
8
)))
self
.
assertEqual
((
4
,
128
,
128
,
7
),
output
.
shape
)
self
.
assertEqual
((
4
,
128
,
128
,
7
),
output
.
shape
)
...
...
research/object_detection/protos/center_net.proto
View file @
2ebe7c3c
...
@@ -19,6 +19,9 @@ message CenterNet {
...
@@ -19,6 +19,9 @@ message CenterNet {
// Image resizer for preprocessing the input image.
// Image resizer for preprocessing the input image.
optional
ImageResizer
image_resizer
=
3
;
optional
ImageResizer
image_resizer
=
3
;
// If set, all task heads will be constructed with separable convolutions.
optional
bool
use_depthwise
=
13
[
default
=
false
];
// Parameters which are related to object detection task.
// Parameters which are related to object detection task.
message
ObjectDetection
{
message
ObjectDetection
{
// The original fields are moved to ObjectCenterParams or deleted.
// The original fields are moved to ObjectCenterParams or deleted.
...
@@ -278,4 +281,9 @@ message CenterNetFeatureExtractor {
...
@@ -278,4 +281,9 @@ message CenterNetFeatureExtractor {
// If set, will change channel order to be [blue, green, red]. This can be
// If set, will change channel order to be [blue, green, red]. This can be
// useful to be compatible with some pre-trained feature extractors.
// useful to be compatible with some pre-trained feature extractors.
optional
bool
bgr_ordering
=
4
[
default
=
false
];
optional
bool
bgr_ordering
=
4
[
default
=
false
];
// If set, the feature upsampling layers will be constructed with
// separable convolutions. This is typically applied to feature pyramid
// network if any.
optional
bool
use_depthwise
=
5
[
default
=
false
];
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment