Commit 72d37b69 authored by Ronny Votel's avatar Ronny Votel Committed by TF Object Detection Team
Browse files

Updating CenterNet Feature Extractors to plumb through model specific parameters.

PiperOrigin-RevId: 366134194
parent cee4b75e
......@@ -170,9 +170,6 @@ if tf_version.is_tf2():
center_net_mobilenet_v2_feature_extractor.mobilenet_v2,
'mobilenet_v2_fpn':
center_net_mobilenet_v2_fpn_feature_extractor.mobilenet_v2_fpn,
'mobilenet_v2_fpn_sep_conv':
center_net_mobilenet_v2_fpn_feature_extractor
.mobilenet_v2_fpn_sep_conv,
}
FEATURE_EXTRACTOR_MAPS = [
......@@ -1130,6 +1127,8 @@ def _build_center_net_feature_extractor(feature_extractor_config, is_training):
'channel_means': list(feature_extractor_config.channel_means),
'channel_stds': list(feature_extractor_config.channel_stds),
'bgr_ordering': feature_extractor_config.bgr_ordering,
'depth_multiplier': feature_extractor_config.depth_multiplier,
'use_separable_conv': feature_extractor_config.use_separable_conv,
}
......
......@@ -73,8 +73,9 @@ class CenterNetHourglassFeatureExtractor(
ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
def hourglass_10(channel_means, channel_stds, bgr_ordering):
def hourglass_10(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The Hourglass-10 backbone for CenterNet."""
del kwargs
network = hourglass_network.hourglass_10(num_channels=32)
return CenterNetHourglassFeatureExtractor(
......@@ -82,8 +83,9 @@ def hourglass_10(channel_means, channel_stds, bgr_ordering):
bgr_ordering=bgr_ordering)
def hourglass_20(channel_means, channel_stds, bgr_ordering):
def hourglass_20(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The Hourglass-20 backbone for CenterNet."""
del kwargs
network = hourglass_network.hourglass_20(num_channels=48)
return CenterNetHourglassFeatureExtractor(
......@@ -91,8 +93,9 @@ def hourglass_20(channel_means, channel_stds, bgr_ordering):
bgr_ordering=bgr_ordering)
def hourglass_32(channel_means, channel_stds, bgr_ordering):
def hourglass_32(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The Hourglass-32 backbone for CenterNet."""
del kwargs
network = hourglass_network.hourglass_32(num_channels=48)
return CenterNetHourglassFeatureExtractor(
......@@ -100,8 +103,9 @@ def hourglass_32(channel_means, channel_stds, bgr_ordering):
bgr_ordering=bgr_ordering)
def hourglass_52(channel_means, channel_stds, bgr_ordering):
def hourglass_52(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The Hourglass-52 backbone for CenterNet."""
del kwargs
network = hourglass_network.hourglass_52(num_channels=64)
return CenterNetHourglassFeatureExtractor(
......@@ -109,8 +113,9 @@ def hourglass_52(channel_means, channel_stds, bgr_ordering):
bgr_ordering=bgr_ordering)
def hourglass_104(channel_means, channel_stds, bgr_ordering):
def hourglass_104(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The Hourglass-104 backbone for CenterNet."""
del kwargs
# TODO(vighneshb): update hourglass_104 signature to match with other
# hourglass networks.
......
......@@ -110,11 +110,17 @@ class CenterNetMobileNetV2FeatureExtractor(
ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
def mobilenet_v2(channel_means, channel_stds, bgr_ordering):
def mobilenet_v2(channel_means, channel_stds, bgr_ordering,
depth_multiplier=1.0, **kwargs):
"""The MobileNetV2 backbone for CenterNet."""
del kwargs
# We set 'is_training' to True for now.
network = mobilenetv2.mobilenet_v2(True, include_top=False)
network = mobilenetv2.mobilenet_v2(
batchnorm_training=True,
alpha=depth_multiplier,
include_top=False,
weights='imagenet' if depth_multiplier == 1.0 else None)
return CenterNetMobileNetV2FeatureExtractor(
network,
channel_means=channel_means,
......
......@@ -39,7 +39,7 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
channel_means=(0., 0., 0.),
channel_stds=(1., 1., 1.),
bgr_ordering=False,
fpn_separable_conv=False):
use_separable_conv=False):
"""Intializes the feature extractor.
Args:
......@@ -50,7 +50,7 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
channel. Each channel will be divided by its standard deviation value.
bgr_ordering: bool, if set will change the channel ordering to be in the
[blue, red, green] order.
fpn_separable_conv: If set to True, all convolutional layers in the FPN
use_separable_conv: If set to True, all convolutional layers in the FPN
network will be replaced by separable convolutions.
"""
......@@ -96,7 +96,7 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
# Merge.
top_down = top_down + residual
next_num_filters = num_filters_list[i + 1] if i + 1 <= 2 else 24
if fpn_separable_conv:
if use_separable_conv:
conv = tf.keras.layers.SeparableConv2D(
filters=next_num_filters, kernel_size=3, strides=1, padding='same')
else:
......@@ -143,30 +143,20 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
return 1
def mobilenet_v2_fpn(channel_means, channel_stds, bgr_ordering):
def mobilenet_v2_fpn(channel_means, channel_stds, bgr_ordering,
use_separable_conv=False, depth_multiplier=1.0, **kwargs):
"""The MobileNetV2+FPN backbone for CenterNet."""
del kwargs
# Set to batchnorm_training to True for now.
network = mobilenetv2.mobilenet_v2(batchnorm_training=True, include_top=False)
network = mobilenetv2.mobilenet_v2(
batchnorm_training=True,
alpha=depth_multiplier,
include_top=False,
weights='imagenet' if depth_multiplier == 1.0 else None)
return CenterNetMobileNetV2FPNFeatureExtractor(
network,
channel_means=channel_means,
channel_stds=channel_stds,
bgr_ordering=bgr_ordering,
fpn_separable_conv=False)
def mobilenet_v2_fpn_sep_conv(channel_means, channel_stds, bgr_ordering):
"""Same as mobilenet_v2_fpn except with separable convolution in FPN."""
# Setting batchnorm_training to True, which will use the correct
# BatchNormalization layer strategy based on the current Keras learning phase.
# TODO(yuhuic): expriment with True vs. False to understand it's effect in
# practice.
network = mobilenetv2.mobilenet_v2(batchnorm_training=True, include_top=False)
return CenterNetMobileNetV2FPNFeatureExtractor(
network,
channel_means=channel_means,
channel_stds=channel_stds,
bgr_ordering=bgr_ordering,
fpn_separable_conv=True)
use_separable_conv=use_separable_conv)
......@@ -18,7 +18,6 @@ import numpy as np
import tensorflow.compat.v1 as tf
from object_detection.models import center_net_mobilenet_v2_fpn_feature_extractor
from object_detection.models.keras_models import mobilenet_v2
from object_detection.utils import test_case
from object_detection.utils import tf_version
......@@ -28,10 +27,13 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase):
def test_center_net_mobilenet_v2_fpn_feature_extractor(self):
net = mobilenet_v2.mobilenet_v2(True, include_top=False)
model = center_net_mobilenet_v2_fpn_feature_extractor.CenterNetMobileNetV2FPNFeatureExtractor(
net)
channel_means = (0., 0., 0.)
channel_stds = (1., 1., 1.)
bgr_ordering = False
model = (
center_net_mobilenet_v2_fpn_feature_extractor.mobilenet_v2_fpn(
channel_means, channel_stds, bgr_ordering,
use_separable_conv=False))
def graph_fn():
img = np.zeros((8, 224, 224, 3), dtype=np.float32)
......@@ -50,10 +52,12 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase):
def test_center_net_mobilenet_v2_fpn_feature_extractor_sep_conv(self):
net = mobilenet_v2.mobilenet_v2(True, include_top=False)
model = center_net_mobilenet_v2_fpn_feature_extractor.CenterNetMobileNetV2FPNFeatureExtractor(
net, fpn_separable_conv=True)
channel_means = (0., 0., 0.)
channel_stds = (1., 1., 1.)
bgr_ordering = False
model = (
center_net_mobilenet_v2_fpn_feature_extractor.mobilenet_v2_fpn(
channel_means, channel_stds, bgr_ordering, use_separable_conv=True))
def graph_fn():
img = np.zeros((8, 224, 224, 3), dtype=np.float32)
......@@ -62,6 +66,10 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase):
outputs = self.execute(graph_fn, [])
self.assertEqual(outputs.shape, (8, 56, 56, 24))
# Pull out the FPN network.
backbone = model.get_layer('model')
first_conv = backbone.get_layer('Conv1')
self.assertEqual(32, first_conv.filters)
# Pull out the FPN network.
output = model.get_layer('model_1')
......@@ -71,6 +79,30 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase):
if 'conv' in layer.name and layer.kernel_size != (1, 1):
self.assertIsInstance(layer, tf.keras.layers.SeparableConv2D)
def test_center_net_mobilenet_v2_fpn_feature_extractor_depth_multiplier(self):
channel_means = (0., 0., 0.)
channel_stds = (1., 1., 1.)
bgr_ordering = False
model = (
center_net_mobilenet_v2_fpn_feature_extractor.mobilenet_v2_fpn(
channel_means, channel_stds, bgr_ordering, use_separable_conv=True,
depth_multiplier=2.0))
def graph_fn():
img = np.zeros((8, 224, 224, 3), dtype=np.float32)
processed_img = model.preprocess(img)
return model(processed_img)
outputs = self.execute(graph_fn, [])
self.assertEqual(outputs.shape, (8, 56, 56, 24))
# Pull out the FPN network.
backbone = model.get_layer('model')
first_conv = backbone.get_layer('Conv1')
# Note that the first layer typically has 32 filters, but this model has
# a depth multiplier of 2.
self.assertEqual(64, first_conv.filters)
if __name__ == '__main__':
tf.test.main()
......@@ -136,8 +136,9 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
def resnet_v2_101(channel_means, channel_stds, bgr_ordering):
def resnet_v2_101(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The ResNet v2 101 feature extractor."""
del kwargs
return CenterNetResnetFeatureExtractor(
resnet_type='resnet_v2_101',
......@@ -147,8 +148,9 @@ def resnet_v2_101(channel_means, channel_stds, bgr_ordering):
)
def resnet_v2_50(channel_means, channel_stds, bgr_ordering):
def resnet_v2_50(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The ResNet v2 50 feature extractor."""
del kwargs
return CenterNetResnetFeatureExtractor(
resnet_type='resnet_v2_50',
......
......@@ -172,8 +172,9 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
def resnet_v1_101_fpn(channel_means, channel_stds, bgr_ordering):
def resnet_v1_101_fpn(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The ResNet v1 101 FPN feature extractor."""
del kwargs
return CenterNetResnetV1FpnFeatureExtractor(
resnet_type='resnet_v1_101',
......@@ -183,8 +184,9 @@ def resnet_v1_101_fpn(channel_means, channel_stds, bgr_ordering):
)
def resnet_v1_50_fpn(channel_means, channel_stds, bgr_ordering):
def resnet_v1_50_fpn(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The ResNet v1 50 FPN feature extractor."""
del kwargs
return CenterNetResnetV1FpnFeatureExtractor(
resnet_type='resnet_v1_50',
......@@ -193,8 +195,9 @@ def resnet_v1_50_fpn(channel_means, channel_stds, bgr_ordering):
bgr_ordering=bgr_ordering)
def resnet_v1_34_fpn(channel_means, channel_stds, bgr_ordering):
def resnet_v1_34_fpn(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The ResNet v1 34 FPN feature extractor."""
del kwargs
return CenterNetResnetV1FpnFeatureExtractor(
resnet_type='resnet_v1_34',
......@@ -204,8 +207,9 @@ def resnet_v1_34_fpn(channel_means, channel_stds, bgr_ordering):
)
def resnet_v1_18_fpn(channel_means, channel_stds, bgr_ordering):
def resnet_v1_18_fpn(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The ResNet v1 18 FPN feature extractor."""
del kwargs
return CenterNetResnetV1FpnFeatureExtractor(
resnet_type='resnet_v1_18',
......
......@@ -376,5 +376,12 @@ message CenterNetFeatureExtractor {
// network if any.
optional bool use_depthwise = 5 [default = false];
// Depth multiplier. Only valid for specific models (e.g. MobileNet). See subclasses of `CenterNetFeatureExtractor`.
optional float depth_multiplier = 9 [default = 1.0];
// Whether to use separable convolutions. Only valid for specific
// models. See subclasses of `CenterNetFeatureExtractor`.
optional bool use_separable_conv = 10 [default = false];
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment