Commit 2ebe7c3c authored by Liangzhe Yuan's avatar Liangzhe Yuan Committed by TF Object Detection Team
Browse files

Support to use separable_conv in CenterNet task head.

PiperOrigin-RevId: 333840074
parent 59888a74
......@@ -1035,7 +1035,8 @@ def _build_center_net_model(center_net_config, is_training, add_summaries):
mask_params=mask_params,
densepose_params=densepose_params,
track_params=track_params,
temporal_offset_params=temporal_offset_params)
temporal_offset_params=temporal_offset_params,
use_depthwise=center_net_config.use_depthwise)
def _build_center_net_feature_extractor(
......
......@@ -139,7 +139,7 @@ class CenterNetFeatureExtractor(tf.keras.Model):
def make_prediction_net(num_out_channels, kernel_size=3, num_filters=256,
bias_fill=None):
bias_fill=None, use_depthwise=False, name=None):
"""Creates a network to predict the given number of output channels.
This function is intended to make the prediction heads for the CenterNet
......@@ -151,12 +151,19 @@ def make_prediction_net(num_out_channels, kernel_size=3, num_filters=256,
num_filters: The number of filters in the intermediate conv layer.
bias_fill: If not None, is used to initialize the bias in the final conv
layer.
use_depthwise: If true, use SeparableConv2D to construct the Sequential
layers instead of Conv2D.
name: Optional name for the prediction net.
Returns:
net: A keras module which when called on an input tensor of size
[batch_size, height, width, num_in_channels] returns an output
of size [batch_size, height, width, num_out_channels]
"""
if use_depthwise:
conv_fn = tf.keras.layers.SeparableConv2D
else:
conv_fn = tf.keras.layers.Conv2D
out_conv = tf.keras.layers.Conv2D(num_out_channels, kernel_size=1)
......@@ -164,11 +171,10 @@ def make_prediction_net(num_out_channels, kernel_size=3, num_filters=256,
out_conv.bias_initializer = tf.keras.initializers.constant(bias_fill)
net = tf.keras.Sequential(
[tf.keras.layers.Conv2D(num_filters, kernel_size=kernel_size,
padding='same'),
[conv_fn(num_filters, kernel_size=kernel_size, padding='same'),
tf.keras.layers.ReLU(),
out_conv]
)
out_conv],
name=name)
return net
......@@ -1673,7 +1679,8 @@ class CenterNetMetaArch(model.DetectionModel):
mask_params=None,
densepose_params=None,
track_params=None,
temporal_offset_params=None):
temporal_offset_params=None,
use_depthwise=False):
"""Initializes a CenterNet model.
Args:
......@@ -1710,6 +1717,8 @@ class CenterNetMetaArch(model.DetectionModel):
definition for more details.
temporal_offset_params: A TemporalOffsetParams namedtuple. This object
holds the hyper-parameters for offset prediction based tracking.
use_depthwise: If true, all task heads will be constructed using
separable_conv. Otherwise, standard convolutions will be used.
"""
assert object_detection_params or keypoint_params_dict
# Shorten the name for convenience and better formatting.
......@@ -1732,6 +1741,8 @@ class CenterNetMetaArch(model.DetectionModel):
self._track_params = track_params
self._temporal_offset_params = temporal_offset_params
self._use_depthwise = use_depthwise
# Construct the prediction head nets.
self._prediction_head_dict = self._construct_prediction_heads(
num_classes,
......@@ -1775,58 +1786,75 @@ class CenterNetMetaArch(model.DetectionModel):
"""
prediction_heads = {}
prediction_heads[OBJECT_CENTER] = [
make_prediction_net(num_classes, bias_fill=class_prediction_bias_init)
make_prediction_net(num_classes, bias_fill=class_prediction_bias_init,
use_depthwise=self._use_depthwise)
for _ in range(num_feature_outputs)
]
if self._od_params is not None:
prediction_heads[BOX_SCALE] = [
make_prediction_net(NUM_SIZE_CHANNELS)
make_prediction_net(
NUM_SIZE_CHANNELS, use_depthwise=self._use_depthwise)
for _ in range(num_feature_outputs)
]
prediction_heads[BOX_OFFSET] = [
make_prediction_net(NUM_OFFSET_CHANNELS)
make_prediction_net(
NUM_OFFSET_CHANNELS, use_depthwise=self._use_depthwise)
for _ in range(num_feature_outputs)
]
if self._kp_params_dict is not None:
for task_name, kp_params in self._kp_params_dict.items():
num_keypoints = len(kp_params.keypoint_indices)
# pylint: disable=g-complex-comprehension
prediction_heads[get_keypoint_name(task_name, KEYPOINT_HEATMAP)] = [
make_prediction_net(
num_keypoints, bias_fill=kp_params.heatmap_bias_init)
num_keypoints,
bias_fill=kp_params.heatmap_bias_init,
use_depthwise=self._use_depthwise)
for _ in range(num_feature_outputs)
]
# pylint: enable=g-complex-comprehension
prediction_heads[get_keypoint_name(task_name, KEYPOINT_REGRESSION)] = [
make_prediction_net(NUM_OFFSET_CHANNELS * num_keypoints)
make_prediction_net(NUM_OFFSET_CHANNELS * num_keypoints,
use_depthwise=self._use_depthwise)
for _ in range(num_feature_outputs)
]
if kp_params.per_keypoint_offset:
prediction_heads[get_keypoint_name(task_name, KEYPOINT_OFFSET)] = [
make_prediction_net(NUM_OFFSET_CHANNELS * num_keypoints)
make_prediction_net(NUM_OFFSET_CHANNELS * num_keypoints,
use_depthwise=self._use_depthwise)
for _ in range(num_feature_outputs)
]
else:
prediction_heads[get_keypoint_name(task_name, KEYPOINT_OFFSET)] = [
make_prediction_net(NUM_OFFSET_CHANNELS)
make_prediction_net(NUM_OFFSET_CHANNELS,
use_depthwise=self._use_depthwise)
for _ in range(num_feature_outputs)
]
# pylint: disable=g-complex-comprehension
if self._mask_params is not None:
prediction_heads[SEGMENTATION_HEATMAP] = [
make_prediction_net(num_classes,
bias_fill=self._mask_params.heatmap_bias_init)
make_prediction_net(
num_classes,
bias_fill=self._mask_params.heatmap_bias_init,
use_depthwise=self._use_depthwise)
for _ in range(num_feature_outputs)]
if self._densepose_params is not None:
prediction_heads[DENSEPOSE_HEATMAP] = [
make_prediction_net( # pylint: disable=g-complex-comprehension
make_prediction_net(
self._densepose_params.num_parts,
bias_fill=self._densepose_params.heatmap_bias_init)
bias_fill=self._densepose_params.heatmap_bias_init,
use_depthwise=self._use_depthwise)
for _ in range(num_feature_outputs)]
prediction_heads[DENSEPOSE_REGRESSION] = [
make_prediction_net(2 * self._densepose_params.num_parts)
make_prediction_net(2 * self._densepose_params.num_parts,
use_depthwise=self._use_depthwise)
for _ in range(num_feature_outputs)
]
# pylint: enable=g-complex-comprehension
if self._track_params is not None:
prediction_heads[TRACK_REID] = [
make_prediction_net(self._track_params.reid_embed_size)
make_prediction_net(self._track_params.reid_embed_size,
use_depthwise=self._use_depthwise)
for _ in range(num_feature_outputs)]
# Creates a classification network to train object embeddings by learning
......@@ -1846,7 +1874,8 @@ class CenterNetMetaArch(model.DetectionModel):
self._track_params.reid_embed_size,)))
if self._temporal_offset_params is not None:
prediction_heads[TEMPORAL_OFFSET] = [
make_prediction_net(NUM_OFFSET_CHANNELS)
make_prediction_net(NUM_OFFSET_CHANNELS,
use_depthwise=self._use_depthwise)
for _ in range(num_feature_outputs)
]
return prediction_heads
......
......@@ -35,11 +35,14 @@ from object_detection.utils import tf_version
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class CenterNetMetaArchPredictionHeadTest(test_case.TestCase):
class CenterNetMetaArchPredictionHeadTest(
test_case.TestCase, parameterized.TestCase):
"""Test CenterNet meta architecture prediction head."""
def test_prediction_head(self):
head = cnma.make_prediction_net(num_out_channels=7)
@parameterized.parameters([True, False])
def test_prediction_head(self, use_depthwise):
head = cnma.make_prediction_net(num_out_channels=7,
use_depthwise=use_depthwise)
output = head(np.zeros((4, 128, 128, 8)))
self.assertEqual((4, 128, 128, 7), output.shape)
......
......@@ -19,6 +19,9 @@ message CenterNet {
// Image resizer for preprocessing the input image.
optional ImageResizer image_resizer = 3;
// If set, all task heads will be constructed with separable convolutions.
optional bool use_depthwise = 13 [default = false];
// Parameters which are related to object detection task.
message ObjectDetection {
// The original fields are moved to ObjectCenterParams or deleted.
......@@ -278,4 +281,9 @@ message CenterNetFeatureExtractor {
// If set, will change channel order to be [blue, green, red]. This can be
// useful to be compatible with some pre-trained feature extractors.
optional bool bgr_ordering = 4 [default = false];
// If set, the feature upsampling layers will be constructed with
// separable convolutions. This is typically applied to feature pyramid
// network if any.
optional bool use_depthwise = 5 [default = false];
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment