Commit 72d37b69 authored by Ronny Votel, committed by TF Object Detection Team
Browse files

Updating CenterNet Feature Extractors to plumb through model-specific parameters.

PiperOrigin-RevId: 366134194
parent cee4b75e
...@@ -170,9 +170,6 @@ if tf_version.is_tf2(): ...@@ -170,9 +170,6 @@ if tf_version.is_tf2():
center_net_mobilenet_v2_feature_extractor.mobilenet_v2, center_net_mobilenet_v2_feature_extractor.mobilenet_v2,
'mobilenet_v2_fpn': 'mobilenet_v2_fpn':
center_net_mobilenet_v2_fpn_feature_extractor.mobilenet_v2_fpn, center_net_mobilenet_v2_fpn_feature_extractor.mobilenet_v2_fpn,
'mobilenet_v2_fpn_sep_conv':
center_net_mobilenet_v2_fpn_feature_extractor
.mobilenet_v2_fpn_sep_conv,
} }
FEATURE_EXTRACTOR_MAPS = [ FEATURE_EXTRACTOR_MAPS = [
...@@ -1130,6 +1127,8 @@ def _build_center_net_feature_extractor(feature_extractor_config, is_training): ...@@ -1130,6 +1127,8 @@ def _build_center_net_feature_extractor(feature_extractor_config, is_training):
'channel_means': list(feature_extractor_config.channel_means), 'channel_means': list(feature_extractor_config.channel_means),
'channel_stds': list(feature_extractor_config.channel_stds), 'channel_stds': list(feature_extractor_config.channel_stds),
'bgr_ordering': feature_extractor_config.bgr_ordering, 'bgr_ordering': feature_extractor_config.bgr_ordering,
'depth_multiplier': feature_extractor_config.depth_multiplier,
'use_separable_conv': feature_extractor_config.use_separable_conv,
} }
......
...@@ -73,8 +73,9 @@ class CenterNetHourglassFeatureExtractor( ...@@ -73,8 +73,9 @@ class CenterNetHourglassFeatureExtractor(
ValueError('Sub model type "{}" not supported.'.format(sub_model_type)) ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
def hourglass_10(channel_means, channel_stds, bgr_ordering): def hourglass_10(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The Hourglass-10 backbone for CenterNet.""" """The Hourglass-10 backbone for CenterNet."""
del kwargs
network = hourglass_network.hourglass_10(num_channels=32) network = hourglass_network.hourglass_10(num_channels=32)
return CenterNetHourglassFeatureExtractor( return CenterNetHourglassFeatureExtractor(
...@@ -82,8 +83,9 @@ def hourglass_10(channel_means, channel_stds, bgr_ordering): ...@@ -82,8 +83,9 @@ def hourglass_10(channel_means, channel_stds, bgr_ordering):
bgr_ordering=bgr_ordering) bgr_ordering=bgr_ordering)
def hourglass_20(channel_means, channel_stds, bgr_ordering): def hourglass_20(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The Hourglass-20 backbone for CenterNet.""" """The Hourglass-20 backbone for CenterNet."""
del kwargs
network = hourglass_network.hourglass_20(num_channels=48) network = hourglass_network.hourglass_20(num_channels=48)
return CenterNetHourglassFeatureExtractor( return CenterNetHourglassFeatureExtractor(
...@@ -91,8 +93,9 @@ def hourglass_20(channel_means, channel_stds, bgr_ordering): ...@@ -91,8 +93,9 @@ def hourglass_20(channel_means, channel_stds, bgr_ordering):
bgr_ordering=bgr_ordering) bgr_ordering=bgr_ordering)
def hourglass_32(channel_means, channel_stds, bgr_ordering): def hourglass_32(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The Hourglass-32 backbone for CenterNet.""" """The Hourglass-32 backbone for CenterNet."""
del kwargs
network = hourglass_network.hourglass_32(num_channels=48) network = hourglass_network.hourglass_32(num_channels=48)
return CenterNetHourglassFeatureExtractor( return CenterNetHourglassFeatureExtractor(
...@@ -100,8 +103,9 @@ def hourglass_32(channel_means, channel_stds, bgr_ordering): ...@@ -100,8 +103,9 @@ def hourglass_32(channel_means, channel_stds, bgr_ordering):
bgr_ordering=bgr_ordering) bgr_ordering=bgr_ordering)
def hourglass_52(channel_means, channel_stds, bgr_ordering): def hourglass_52(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The Hourglass-52 backbone for CenterNet.""" """The Hourglass-52 backbone for CenterNet."""
del kwargs
network = hourglass_network.hourglass_52(num_channels=64) network = hourglass_network.hourglass_52(num_channels=64)
return CenterNetHourglassFeatureExtractor( return CenterNetHourglassFeatureExtractor(
...@@ -109,8 +113,9 @@ def hourglass_52(channel_means, channel_stds, bgr_ordering): ...@@ -109,8 +113,9 @@ def hourglass_52(channel_means, channel_stds, bgr_ordering):
bgr_ordering=bgr_ordering) bgr_ordering=bgr_ordering)
def hourglass_104(channel_means, channel_stds, bgr_ordering): def hourglass_104(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The Hourglass-104 backbone for CenterNet.""" """The Hourglass-104 backbone for CenterNet."""
del kwargs
# TODO(vighneshb): update hourglass_104 signature to match with other # TODO(vighneshb): update hourglass_104 signature to match with other
# hourglass networks. # hourglass networks.
......
...@@ -110,11 +110,17 @@ class CenterNetMobileNetV2FeatureExtractor( ...@@ -110,11 +110,17 @@ class CenterNetMobileNetV2FeatureExtractor(
ValueError('Sub model type "{}" not supported.'.format(sub_model_type)) ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
def mobilenet_v2(channel_means, channel_stds, bgr_ordering): def mobilenet_v2(channel_means, channel_stds, bgr_ordering,
depth_multiplier=1.0, **kwargs):
"""The MobileNetV2 backbone for CenterNet.""" """The MobileNetV2 backbone for CenterNet."""
del kwargs
# We set 'is_training' to True for now. # We set 'is_training' to True for now.
network = mobilenetv2.mobilenet_v2(True, include_top=False) network = mobilenetv2.mobilenet_v2(
batchnorm_training=True,
alpha=depth_multiplier,
include_top=False,
weights='imagenet' if depth_multiplier == 1.0 else None)
return CenterNetMobileNetV2FeatureExtractor( return CenterNetMobileNetV2FeatureExtractor(
network, network,
channel_means=channel_means, channel_means=channel_means,
......
...@@ -39,7 +39,7 @@ class CenterNetMobileNetV2FPNFeatureExtractor( ...@@ -39,7 +39,7 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
channel_means=(0., 0., 0.), channel_means=(0., 0., 0.),
channel_stds=(1., 1., 1.), channel_stds=(1., 1., 1.),
bgr_ordering=False, bgr_ordering=False,
fpn_separable_conv=False): use_separable_conv=False):
"""Intializes the feature extractor. """Intializes the feature extractor.
Args: Args:
...@@ -50,7 +50,7 @@ class CenterNetMobileNetV2FPNFeatureExtractor( ...@@ -50,7 +50,7 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
channel. Each channel will be divided by its standard deviation value. channel. Each channel will be divided by its standard deviation value.
bgr_ordering: bool, if set will change the channel ordering to be in the bgr_ordering: bool, if set will change the channel ordering to be in the
[blue, green, red] order. [blue, green, red] order.
fpn_separable_conv: If set to True, all convolutional layers in the FPN use_separable_conv: If set to True, all convolutional layers in the FPN
network will be replaced by separable convolutions. network will be replaced by separable convolutions.
""" """
...@@ -96,7 +96,7 @@ class CenterNetMobileNetV2FPNFeatureExtractor( ...@@ -96,7 +96,7 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
# Merge. # Merge.
top_down = top_down + residual top_down = top_down + residual
next_num_filters = num_filters_list[i + 1] if i + 1 <= 2 else 24 next_num_filters = num_filters_list[i + 1] if i + 1 <= 2 else 24
if fpn_separable_conv: if use_separable_conv:
conv = tf.keras.layers.SeparableConv2D( conv = tf.keras.layers.SeparableConv2D(
filters=next_num_filters, kernel_size=3, strides=1, padding='same') filters=next_num_filters, kernel_size=3, strides=1, padding='same')
else: else:
...@@ -143,30 +143,20 @@ class CenterNetMobileNetV2FPNFeatureExtractor( ...@@ -143,30 +143,20 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
return 1 return 1
def mobilenet_v2_fpn(channel_means, channel_stds, bgr_ordering): def mobilenet_v2_fpn(channel_means, channel_stds, bgr_ordering,
use_separable_conv=False, depth_multiplier=1.0, **kwargs):
"""The MobileNetV2+FPN backbone for CenterNet.""" """The MobileNetV2+FPN backbone for CenterNet."""
del kwargs
# Set to batchnorm_training to True for now. # Set to batchnorm_training to True for now.
network = mobilenetv2.mobilenet_v2(batchnorm_training=True, include_top=False) network = mobilenetv2.mobilenet_v2(
batchnorm_training=True,
alpha=depth_multiplier,
include_top=False,
weights='imagenet' if depth_multiplier == 1.0 else None)
return CenterNetMobileNetV2FPNFeatureExtractor( return CenterNetMobileNetV2FPNFeatureExtractor(
network, network,
channel_means=channel_means, channel_means=channel_means,
channel_stds=channel_stds, channel_stds=channel_stds,
bgr_ordering=bgr_ordering, bgr_ordering=bgr_ordering,
fpn_separable_conv=False) use_separable_conv=use_separable_conv)
def mobilenet_v2_fpn_sep_conv(channel_means, channel_stds, bgr_ordering):
  """Builds the MobileNetV2+FPN CenterNet backbone with a separable-conv FPN.

  Same as `mobilenet_v2_fpn`, except that the feature extractor is created
  with `fpn_separable_conv=True`.

  Args:
    channel_means: passed through unchanged to the feature extractor.
    channel_stds: passed through unchanged to the feature extractor.
    bgr_ordering: passed through unchanged to the feature extractor.

  Returns:
    A `CenterNetMobileNetV2FPNFeatureExtractor` wrapping a MobileNetV2 network
    built with `include_top=False`.
  """
  # batchnorm_training is set to True, which will use the correct
  # BatchNormalization layer strategy based on the current Keras learning
  # phase.
  # TODO(yuhuic): experiment with True vs. False to understand its effect in
  # practice.
  backbone = mobilenetv2.mobilenet_v2(
      batchnorm_training=True, include_top=False)
  extractor_kwargs = dict(
      channel_means=channel_means,
      channel_stds=channel_stds,
      bgr_ordering=bgr_ordering,
      fpn_separable_conv=True)
  return CenterNetMobileNetV2FPNFeatureExtractor(backbone, **extractor_kwargs)
...@@ -18,7 +18,6 @@ import numpy as np ...@@ -18,7 +18,6 @@ import numpy as np
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from object_detection.models import center_net_mobilenet_v2_fpn_feature_extractor from object_detection.models import center_net_mobilenet_v2_fpn_feature_extractor
from object_detection.models.keras_models import mobilenet_v2
from object_detection.utils import test_case from object_detection.utils import test_case
from object_detection.utils import tf_version from object_detection.utils import tf_version
...@@ -28,10 +27,13 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase): ...@@ -28,10 +27,13 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase):
def test_center_net_mobilenet_v2_fpn_feature_extractor(self): def test_center_net_mobilenet_v2_fpn_feature_extractor(self):
net = mobilenet_v2.mobilenet_v2(True, include_top=False) channel_means = (0., 0., 0.)
channel_stds = (1., 1., 1.)
model = center_net_mobilenet_v2_fpn_feature_extractor.CenterNetMobileNetV2FPNFeatureExtractor( bgr_ordering = False
net) model = (
center_net_mobilenet_v2_fpn_feature_extractor.mobilenet_v2_fpn(
channel_means, channel_stds, bgr_ordering,
use_separable_conv=False))
def graph_fn(): def graph_fn():
img = np.zeros((8, 224, 224, 3), dtype=np.float32) img = np.zeros((8, 224, 224, 3), dtype=np.float32)
...@@ -50,10 +52,12 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase): ...@@ -50,10 +52,12 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase):
def test_center_net_mobilenet_v2_fpn_feature_extractor_sep_conv(self): def test_center_net_mobilenet_v2_fpn_feature_extractor_sep_conv(self):
net = mobilenet_v2.mobilenet_v2(True, include_top=False) channel_means = (0., 0., 0.)
channel_stds = (1., 1., 1.)
model = center_net_mobilenet_v2_fpn_feature_extractor.CenterNetMobileNetV2FPNFeatureExtractor( bgr_ordering = False
net, fpn_separable_conv=True) model = (
center_net_mobilenet_v2_fpn_feature_extractor.mobilenet_v2_fpn(
channel_means, channel_stds, bgr_ordering, use_separable_conv=True))
def graph_fn(): def graph_fn():
img = np.zeros((8, 224, 224, 3), dtype=np.float32) img = np.zeros((8, 224, 224, 3), dtype=np.float32)
...@@ -62,6 +66,10 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase): ...@@ -62,6 +66,10 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase):
outputs = self.execute(graph_fn, []) outputs = self.execute(graph_fn, [])
self.assertEqual(outputs.shape, (8, 56, 56, 24)) self.assertEqual(outputs.shape, (8, 56, 56, 24))
# Pull out the FPN network.
backbone = model.get_layer('model')
first_conv = backbone.get_layer('Conv1')
self.assertEqual(32, first_conv.filters)
# Pull out the FPN network. # Pull out the FPN network.
output = model.get_layer('model_1') output = model.get_layer('model_1')
...@@ -71,6 +79,30 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase): ...@@ -71,6 +79,30 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase):
if 'conv' in layer.name and layer.kernel_size != (1, 1): if 'conv' in layer.name and layer.kernel_size != (1, 1):
self.assertIsInstance(layer, tf.keras.layers.SeparableConv2D) self.assertIsInstance(layer, tf.keras.layers.SeparableConv2D)
def test_center_net_mobilenet_v2_fpn_feature_extractor_depth_multiplier(self):
  """Checks that `depth_multiplier=2.0` widens the backbone's first conv."""
  extractor = center_net_mobilenet_v2_fpn_feature_extractor.mobilenet_v2_fpn(
      (0., 0., 0.),  # channel_means
      (1., 1., 1.),  # channel_stds
      False,  # bgr_ordering
      use_separable_conv=True,
      depth_multiplier=2.0)

  def graph_fn():
    blank_images = np.zeros((8, 224, 224, 3), dtype=np.float32)
    return extractor(extractor.preprocess(blank_images))

  features = self.execute(graph_fn, [])
  self.assertEqual(features.shape, (8, 56, 56, 24))

  # Fetch the 'Conv1' layer from the sub-model named 'model'. The first layer
  # typically has 32 filters, but this model has a depth multiplier of 2, so
  # 64 filters are expected.
  conv1 = extractor.get_layer('model').get_layer('Conv1')
  self.assertEqual(64, conv1.filters)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -136,8 +136,9 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor): ...@@ -136,8 +136,9 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
ValueError('Sub model type "{}" not supported.'.format(sub_model_type)) ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
def resnet_v2_101(channel_means, channel_stds, bgr_ordering): def resnet_v2_101(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The ResNet v2 101 feature extractor.""" """The ResNet v2 101 feature extractor."""
del kwargs
return CenterNetResnetFeatureExtractor( return CenterNetResnetFeatureExtractor(
resnet_type='resnet_v2_101', resnet_type='resnet_v2_101',
...@@ -147,8 +148,9 @@ def resnet_v2_101(channel_means, channel_stds, bgr_ordering): ...@@ -147,8 +148,9 @@ def resnet_v2_101(channel_means, channel_stds, bgr_ordering):
) )
def resnet_v2_50(channel_means, channel_stds, bgr_ordering): def resnet_v2_50(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The ResNet v2 50 feature extractor.""" """The ResNet v2 50 feature extractor."""
del kwargs
return CenterNetResnetFeatureExtractor( return CenterNetResnetFeatureExtractor(
resnet_type='resnet_v2_50', resnet_type='resnet_v2_50',
......
...@@ -172,8 +172,9 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor): ...@@ -172,8 +172,9 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
ValueError('Sub model type "{}" not supported.'.format(sub_model_type)) ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
def resnet_v1_101_fpn(channel_means, channel_stds, bgr_ordering): def resnet_v1_101_fpn(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The ResNet v1 101 FPN feature extractor.""" """The ResNet v1 101 FPN feature extractor."""
del kwargs
return CenterNetResnetV1FpnFeatureExtractor( return CenterNetResnetV1FpnFeatureExtractor(
resnet_type='resnet_v1_101', resnet_type='resnet_v1_101',
...@@ -183,8 +184,9 @@ def resnet_v1_101_fpn(channel_means, channel_stds, bgr_ordering): ...@@ -183,8 +184,9 @@ def resnet_v1_101_fpn(channel_means, channel_stds, bgr_ordering):
) )
def resnet_v1_50_fpn(channel_means, channel_stds, bgr_ordering): def resnet_v1_50_fpn(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The ResNet v1 50 FPN feature extractor.""" """The ResNet v1 50 FPN feature extractor."""
del kwargs
return CenterNetResnetV1FpnFeatureExtractor( return CenterNetResnetV1FpnFeatureExtractor(
resnet_type='resnet_v1_50', resnet_type='resnet_v1_50',
...@@ -193,8 +195,9 @@ def resnet_v1_50_fpn(channel_means, channel_stds, bgr_ordering): ...@@ -193,8 +195,9 @@ def resnet_v1_50_fpn(channel_means, channel_stds, bgr_ordering):
bgr_ordering=bgr_ordering) bgr_ordering=bgr_ordering)
def resnet_v1_34_fpn(channel_means, channel_stds, bgr_ordering): def resnet_v1_34_fpn(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The ResNet v1 34 FPN feature extractor.""" """The ResNet v1 34 FPN feature extractor."""
del kwargs
return CenterNetResnetV1FpnFeatureExtractor( return CenterNetResnetV1FpnFeatureExtractor(
resnet_type='resnet_v1_34', resnet_type='resnet_v1_34',
...@@ -204,8 +207,9 @@ def resnet_v1_34_fpn(channel_means, channel_stds, bgr_ordering): ...@@ -204,8 +207,9 @@ def resnet_v1_34_fpn(channel_means, channel_stds, bgr_ordering):
) )
def resnet_v1_18_fpn(channel_means, channel_stds, bgr_ordering): def resnet_v1_18_fpn(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The ResNet v1 18 FPN feature extractor.""" """The ResNet v1 18 FPN feature extractor."""
del kwargs
return CenterNetResnetV1FpnFeatureExtractor( return CenterNetResnetV1FpnFeatureExtractor(
resnet_type='resnet_v1_18', resnet_type='resnet_v1_18',
......
...@@ -376,5 +376,12 @@ message CenterNetFeatureExtractor { ...@@ -376,5 +376,12 @@ message CenterNetFeatureExtractor {
// network if any. // network if any.
optional bool use_depthwise = 5 [default = false]; optional bool use_depthwise = 5 [default = false];
// Depth multiplier. Only valid for specific models (e.g. MobileNet). See subclasses of `CenterNetFeatureExtractor`.
optional float depth_multiplier = 9 [default = 1.0];
// Whether to use separable convolutions. Only valid for specific
// models. See subclasses of `CenterNetFeatureExtractor`.
optional bool use_separable_conv = 10 [default = false];
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment