Commit 2ebe7c3c authored by Liangzhe Yuan, committed by TF Object Detection Team

Support using separable_conv in the CenterNet task heads.

PiperOrigin-RevId: 333840074
parent 59888a74
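Background for reviewers, not part of the change itself: a depthwise-separable convolution factorizes a k×k convolution into a per-channel k×k depthwise pass followed by a 1×1 pointwise projection, cutting a layer's weight count from k²·Cin·Cout to k²·Cin + Cin·Cout. A quick back-of-envelope check for the 3×3 intermediate conv these heads use, assuming 256 input channels and the default 256 filters:

```python
# Rough weight counts for the 3x3 intermediate conv in a prediction head,
# assuming 256 input channels and the default 256 filters (biases ignored).
k, c_in, c_out = 3, 256, 256

standard = k * k * c_in * c_out           # 3*3*256*256 = 589,824 weights
separable = k * k * c_in + c_in * c_out   # 2,304 + 65,536 = 67,840 weights

print(standard / separable)  # ~8.7x fewer weights in the separable variant
```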
@@ -1035,7 +1035,8 @@ def _build_center_net_model(center_net_config, is_training, add_summaries):
       mask_params=mask_params,
       densepose_params=densepose_params,
       track_params=track_params,
-      temporal_offset_params=temporal_offset_params)
+      temporal_offset_params=temporal_offset_params,
+      use_depthwise=center_net_config.use_depthwise)
 
 def _build_center_net_feature_extractor(
...
@@ -139,7 +139,7 @@ class CenterNetFeatureExtractor(tf.keras.Model):
 def make_prediction_net(num_out_channels, kernel_size=3, num_filters=256,
-                        bias_fill=None):
+                        bias_fill=None, use_depthwise=False, name=None):
   """Creates a network to predict the given number of output channels.
 
   This function is intended to make the prediction heads for the CenterNet
@@ -151,12 +151,19 @@ def make_prediction_net(num_out_channels, kernel_size=3, num_filters=256,
     num_filters: The number of filters in the intermediate conv layer.
     bias_fill: If not None, is used to initialize the bias in the final conv
       layer.
+    use_depthwise: If true, use SeparableConv2D to construct the Sequential
+      layers instead of Conv2D.
+    name: Optional name for the prediction net.
 
   Returns:
     net: A keras module which when called on an input tensor of size
       [batch_size, height, width, num_in_channels] returns an output
       of size [batch_size, height, width, num_out_channels]
   """
+  if use_depthwise:
+    conv_fn = tf.keras.layers.SeparableConv2D
+  else:
+    conv_fn = tf.keras.layers.Conv2D
 
   out_conv = tf.keras.layers.Conv2D(num_out_channels, kernel_size=1)
@@ -164,11 +171,10 @@ def make_prediction_net(num_out_channels, kernel_size=3, num_filters=256,
     out_conv.bias_initializer = tf.keras.initializers.constant(bias_fill)
 
   net = tf.keras.Sequential(
-      [tf.keras.layers.Conv2D(num_filters, kernel_size=kernel_size,
-                              padding='same'),
+      [conv_fn(num_filters, kernel_size=kernel_size, padding='same'),
        tf.keras.layers.ReLU(),
-       out_conv]
-  )
+       out_conv],
+      name=name)
 
   return net
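As a sanity check on the new `use_depthwise` path, here is a self-contained Keras sketch that mirrors the updated helper; the `make_head` name is mine, and nothing here imports the Object Detection API:

```python
import numpy as np
import tensorflow as tf

def make_head(num_out_channels, use_depthwise=False,
              kernel_size=3, num_filters=256, name=None):
  """Conv (or separable conv) -> ReLU -> 1x1 conv, as in the diff above."""
  conv_fn = (tf.keras.layers.SeparableConv2D if use_depthwise
             else tf.keras.layers.Conv2D)
  return tf.keras.Sequential(
      [conv_fn(num_filters, kernel_size=kernel_size, padding='same'),
       tf.keras.layers.ReLU(),
       tf.keras.layers.Conv2D(num_out_channels, kernel_size=1)],
      name=name)

for flag in (False, True):
  head = make_head(num_out_channels=7, use_depthwise=flag)
  out = head(np.zeros((4, 128, 128, 8), dtype=np.float32))
  print(flag, out.shape, head.count_params())
```

With 8 input channels, both settings produce the same output shape, but the separable variant's intermediate layer carries roughly 8x fewer weights.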
@@ -1673,7 +1679,8 @@ class CenterNetMetaArch(model.DetectionModel):
                mask_params=None,
                densepose_params=None,
                track_params=None,
-               temporal_offset_params=None):
+               temporal_offset_params=None,
+               use_depthwise=False):
     """Initializes a CenterNet model.
 
     Args:
@@ -1710,6 +1717,8 @@ class CenterNetMetaArch(model.DetectionModel):
         definition for more details.
       temporal_offset_params: A TemporalOffsetParams namedtuple. This object
         holds the hyper-parameters for offset prediction based tracking.
+      use_depthwise: If true, all task heads will be constructed using
+        separable_conv. Otherwise, standard convolutions will be used.
     """
     assert object_detection_params or keypoint_params_dict
 
     # Shorten the name for convenience and better formatting.
@@ -1732,6 +1741,8 @@ class CenterNetMetaArch(model.DetectionModel):
     self._track_params = track_params
     self._temporal_offset_params = temporal_offset_params
+    self._use_depthwise = use_depthwise
+
     # Construct the prediction head nets.
     self._prediction_head_dict = self._construct_prediction_heads(
         num_classes,
@@ -1775,58 +1786,75 @@ class CenterNetMetaArch(model.DetectionModel):
     """
     prediction_heads = {}
     prediction_heads[OBJECT_CENTER] = [
-        make_prediction_net(num_classes, bias_fill=class_prediction_bias_init)
+        make_prediction_net(num_classes, bias_fill=class_prediction_bias_init,
+                            use_depthwise=self._use_depthwise)
         for _ in range(num_feature_outputs)
     ]
     if self._od_params is not None:
       prediction_heads[BOX_SCALE] = [
-          make_prediction_net(NUM_SIZE_CHANNELS)
+          make_prediction_net(
+              NUM_SIZE_CHANNELS, use_depthwise=self._use_depthwise)
           for _ in range(num_feature_outputs)
       ]
       prediction_heads[BOX_OFFSET] = [
-          make_prediction_net(NUM_OFFSET_CHANNELS)
+          make_prediction_net(
+              NUM_OFFSET_CHANNELS, use_depthwise=self._use_depthwise)
           for _ in range(num_feature_outputs)
       ]
     if self._kp_params_dict is not None:
       for task_name, kp_params in self._kp_params_dict.items():
         num_keypoints = len(kp_params.keypoint_indices)
+        # pylint: disable=g-complex-comprehension
         prediction_heads[get_keypoint_name(task_name, KEYPOINT_HEATMAP)] = [
             make_prediction_net(
-                num_keypoints, bias_fill=kp_params.heatmap_bias_init)
+                num_keypoints,
+                bias_fill=kp_params.heatmap_bias_init,
+                use_depthwise=self._use_depthwise)
             for _ in range(num_feature_outputs)
         ]
+        # pylint: enable=g-complex-comprehension
         prediction_heads[get_keypoint_name(task_name, KEYPOINT_REGRESSION)] = [
-            make_prediction_net(NUM_OFFSET_CHANNELS * num_keypoints)
+            make_prediction_net(NUM_OFFSET_CHANNELS * num_keypoints,
+                                use_depthwise=self._use_depthwise)
            for _ in range(num_feature_outputs)
         ]
         if kp_params.per_keypoint_offset:
           prediction_heads[get_keypoint_name(task_name, KEYPOINT_OFFSET)] = [
-              make_prediction_net(NUM_OFFSET_CHANNELS * num_keypoints)
+              make_prediction_net(NUM_OFFSET_CHANNELS * num_keypoints,
+                                  use_depthwise=self._use_depthwise)
              for _ in range(num_feature_outputs)
          ]
         else:
           prediction_heads[get_keypoint_name(task_name, KEYPOINT_OFFSET)] = [
-              make_prediction_net(NUM_OFFSET_CHANNELS)
+              make_prediction_net(NUM_OFFSET_CHANNELS,
+                                  use_depthwise=self._use_depthwise)
              for _ in range(num_feature_outputs)
          ]
+    # pylint: disable=g-complex-comprehension
     if self._mask_params is not None:
       prediction_heads[SEGMENTATION_HEATMAP] = [
-          make_prediction_net(num_classes,
-                              bias_fill=self._mask_params.heatmap_bias_init)
+          make_prediction_net(
+              num_classes,
+              bias_fill=self._mask_params.heatmap_bias_init,
+              use_depthwise=self._use_depthwise)
           for _ in range(num_feature_outputs)]
     if self._densepose_params is not None:
       prediction_heads[DENSEPOSE_HEATMAP] = [
-          make_prediction_net(  # pylint: disable=g-complex-comprehension
+          make_prediction_net(
              self._densepose_params.num_parts,
-              bias_fill=self._densepose_params.heatmap_bias_init)
+              bias_fill=self._densepose_params.heatmap_bias_init,
+              use_depthwise=self._use_depthwise)
           for _ in range(num_feature_outputs)]
       prediction_heads[DENSEPOSE_REGRESSION] = [
-          make_prediction_net(2 * self._densepose_params.num_parts)
+          make_prediction_net(2 * self._densepose_params.num_parts,
+                              use_depthwise=self._use_depthwise)
           for _ in range(num_feature_outputs)
       ]
+    # pylint: enable=g-complex-comprehension
     if self._track_params is not None:
       prediction_heads[TRACK_REID] = [
-          make_prediction_net(self._track_params.reid_embed_size)
+          make_prediction_net(self._track_params.reid_embed_size,
+                              use_depthwise=self._use_depthwise)
           for _ in range(num_feature_outputs)]
 
     # Creates a classification network to train object embeddings by learning
@@ -1846,7 +1874,8 @@ class CenterNetMetaArch(model.DetectionModel):
               self._track_params.reid_embed_size,)))
 
     if self._temporal_offset_params is not None:
       prediction_heads[TEMPORAL_OFFSET] = [
-          make_prediction_net(NUM_OFFSET_CHANNELS)
+          make_prediction_net(NUM_OFFSET_CHANNELS,
+                              use_depthwise=self._use_depthwise)
           for _ in range(num_feature_outputs)
       ]
     return prediction_heads
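The wiring in this hunk is mechanical: every call site simply forwards `self._use_depthwise` to `make_prediction_net`. Condensed to its shape, with a hypothetical `make_head` standing in for the real helper and illustrative head names and channel counts:

```python
import tensorflow as tf

def make_head(out_channels, use_depthwise=False):
  # Same conv -> ReLU -> 1x1 conv pattern as the earlier sketch.
  conv_fn = (tf.keras.layers.SeparableConv2D if use_depthwise
             else tf.keras.layers.Conv2D)
  return tf.keras.Sequential(
      [conv_fn(256, kernel_size=3, padding='same'),
       tf.keras.layers.ReLU(),
       tf.keras.layers.Conv2D(out_channels, kernel_size=1)])

# One list of heads per task, one head per backbone feature output,
# all sharing the same use_depthwise setting.
use_depthwise = True
num_feature_outputs = 2
head_channels = {'object_center': 90, 'box_scale': 2, 'box_offset': 2}

prediction_heads = {
    name: [make_head(ch, use_depthwise=use_depthwise)
           for _ in range(num_feature_outputs)]
    for name, ch in head_channels.items()
}
```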
...
@@ -35,11 +35,14 @@ from object_detection.utils import tf_version
 
 @unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
-class CenterNetMetaArchPredictionHeadTest(test_case.TestCase):
+class CenterNetMetaArchPredictionHeadTest(
+    test_case.TestCase, parameterized.TestCase):
   """Test CenterNet meta architecture prediction head."""
 
-  def test_prediction_head(self):
-    head = cnma.make_prediction_net(num_out_channels=7)
+  @parameterized.parameters([True, False])
+  def test_prediction_head(self, use_depthwise):
+    head = cnma.make_prediction_net(num_out_channels=7,
+                                    use_depthwise=use_depthwise)
     output = head(np.zeros((4, 128, 128, 8)))
     self.assertEqual((4, 128, 128, 7), output.shape)
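For readers unfamiliar with absl parameterized tests, here is a self-contained analogue of the updated test. Note that the hunk itself does not show the `parameterized` import, which presumably already exists in the test module, and the real test exercises `cnma.make_prediction_net` rather than the inline head built here:

```python
from absl.testing import parameterized
import numpy as np
import tensorflow as tf


class PredictionHeadTest(parameterized.TestCase):

  # Generates two test cases, one per flag value.
  @parameterized.parameters([True, False])
  def test_prediction_head(self, use_depthwise):
    conv_fn = (tf.keras.layers.SeparableConv2D if use_depthwise
               else tf.keras.layers.Conv2D)
    head = tf.keras.Sequential(
        [conv_fn(256, kernel_size=3, padding='same'),
         tf.keras.layers.ReLU(),
         tf.keras.layers.Conv2D(7, kernel_size=1)])
    output = head(np.zeros((4, 128, 128, 8), dtype=np.float32))
    self.assertEqual((4, 128, 128, 7), output.shape)
```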
...
@@ -19,6 +19,9 @@ message CenterNet {
   // Image resizer for preprocessing the input image.
   optional ImageResizer image_resizer = 3;
 
+  // If set, all task heads will be constructed with separable convolutions.
+  optional bool use_depthwise = 13 [default = false];
+
   // Parameters which are related to object detection task.
   message ObjectDetection {
     // The original fields are moved to ObjectCenterParams or deleted.
@@ -278,4 +281,9 @@ message CenterNetFeatureExtractor {
   // If set, will change channel order to be [blue, green, red]. This can be
   // useful to be compatible with some pre-trained feature extractors.
   optional bool bgr_ordering = 4 [default = false];
+
+  // If set, the feature upsampling layers will be constructed with
+  // separable convolutions. This is typically applied to the feature
+  // pyramid network, if any.
+  optional bool use_depthwise = 5 [default = false];
 }
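Taken together, the two new fields would be enabled in a pipeline config along these lines. This is a hedged sketch: the field placement follows the proto messages above, while `num_classes`, the feature extractor `type`, and every other value is purely illustrative.

```
model {
  center_net {
    num_classes: 90
    # New field (13): build every task head with separable convolutions.
    use_depthwise: true
    feature_extractor {
      type: "resnet_v1_50_fpn"
      # New field (5): separable convolutions in the upsampling/FPN layers.
      use_depthwise: true
    }
  }
}
```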