"docs/vscode:/vscode.git/clone" did not exist on "14dd917ec60fa69ce3f7c6e3f2eaf520e67928b5"
Commit 72d37b69 authored by Ronny Votel's avatar Ronny Votel Committed by TF Object Detection Team
Browse files

Updating CenterNet Feature Extractors to plumb through model specific parameters.

PiperOrigin-RevId: 366134194
parent cee4b75e
......@@ -170,9 +170,6 @@ if tf_version.is_tf2():
center_net_mobilenet_v2_feature_extractor.mobilenet_v2,
'mobilenet_v2_fpn':
center_net_mobilenet_v2_fpn_feature_extractor.mobilenet_v2_fpn,
'mobilenet_v2_fpn_sep_conv':
center_net_mobilenet_v2_fpn_feature_extractor
.mobilenet_v2_fpn_sep_conv,
}
FEATURE_EXTRACTOR_MAPS = [
......@@ -1130,6 +1127,8 @@ def _build_center_net_feature_extractor(feature_extractor_config, is_training):
'channel_means': list(feature_extractor_config.channel_means),
'channel_stds': list(feature_extractor_config.channel_stds),
'bgr_ordering': feature_extractor_config.bgr_ordering,
'depth_multiplier': feature_extractor_config.depth_multiplier,
'use_separable_conv': feature_extractor_config.use_separable_conv,
}
......
......@@ -73,8 +73,9 @@ class CenterNetHourglassFeatureExtractor(
ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
def hourglass_10(channel_means, channel_stds, bgr_ordering):
def hourglass_10(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The Hourglass-10 backbone for CenterNet."""
del kwargs
network = hourglass_network.hourglass_10(num_channels=32)
return CenterNetHourglassFeatureExtractor(
......@@ -82,8 +83,9 @@ def hourglass_10(channel_means, channel_stds, bgr_ordering):
bgr_ordering=bgr_ordering)
def hourglass_20(channel_means, channel_stds, bgr_ordering):
def hourglass_20(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The Hourglass-20 backbone for CenterNet."""
del kwargs
network = hourglass_network.hourglass_20(num_channels=48)
return CenterNetHourglassFeatureExtractor(
......@@ -91,8 +93,9 @@ def hourglass_20(channel_means, channel_stds, bgr_ordering):
bgr_ordering=bgr_ordering)
def hourglass_32(channel_means, channel_stds, bgr_ordering):
def hourglass_32(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The Hourglass-32 backbone for CenterNet."""
del kwargs
network = hourglass_network.hourglass_32(num_channels=48)
return CenterNetHourglassFeatureExtractor(
......@@ -100,8 +103,9 @@ def hourglass_32(channel_means, channel_stds, bgr_ordering):
bgr_ordering=bgr_ordering)
def hourglass_52(channel_means, channel_stds, bgr_ordering):
def hourglass_52(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The Hourglass-52 backbone for CenterNet."""
del kwargs
network = hourglass_network.hourglass_52(num_channels=64)
return CenterNetHourglassFeatureExtractor(
......@@ -109,8 +113,9 @@ def hourglass_52(channel_means, channel_stds, bgr_ordering):
bgr_ordering=bgr_ordering)
def hourglass_104(channel_means, channel_stds, bgr_ordering):
def hourglass_104(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The Hourglass-104 backbone for CenterNet."""
del kwargs
# TODO(vighneshb): update hourglass_104 signature to match with other
# hourglass networks.
......
......@@ -110,11 +110,17 @@ class CenterNetMobileNetV2FeatureExtractor(
ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
def mobilenet_v2(channel_means, channel_stds, bgr_ordering):
def mobilenet_v2(channel_means, channel_stds, bgr_ordering,
depth_multiplier=1.0, **kwargs):
"""The MobileNetV2 backbone for CenterNet."""
del kwargs
# We set 'is_training' to True for now.
network = mobilenetv2.mobilenet_v2(True, include_top=False)
network = mobilenetv2.mobilenet_v2(
batchnorm_training=True,
alpha=depth_multiplier,
include_top=False,
weights='imagenet' if depth_multiplier == 1.0 else None)
return CenterNetMobileNetV2FeatureExtractor(
network,
channel_means=channel_means,
......
......@@ -39,7 +39,7 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
channel_means=(0., 0., 0.),
channel_stds=(1., 1., 1.),
bgr_ordering=False,
fpn_separable_conv=False):
use_separable_conv=False):
"""Initializes the feature extractor.
Args:
......@@ -50,7 +50,7 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
channel. Each channel will be divided by its standard deviation value.
bgr_ordering: bool, if set will change the channel ordering to be in the
[blue, green, red] order.
fpn_separable_conv: If set to True, all convolutional layers in the FPN
use_separable_conv: If set to True, all convolutional layers in the FPN
network will be replaced by separable convolutions.
"""
......@@ -96,7 +96,7 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
# Merge.
top_down = top_down + residual
next_num_filters = num_filters_list[i + 1] if i + 1 <= 2 else 24
if fpn_separable_conv:
if use_separable_conv:
conv = tf.keras.layers.SeparableConv2D(
filters=next_num_filters, kernel_size=3, strides=1, padding='same')
else:
......@@ -143,30 +143,20 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
return 1
def mobilenet_v2_fpn(channel_means, channel_stds, bgr_ordering):
def mobilenet_v2_fpn(channel_means, channel_stds, bgr_ordering,
use_separable_conv=False, depth_multiplier=1.0, **kwargs):
"""The MobileNetV2+FPN backbone for CenterNet."""
del kwargs
# Set batchnorm_training to True for now.
network = mobilenetv2.mobilenet_v2(batchnorm_training=True, include_top=False)
network = mobilenetv2.mobilenet_v2(
batchnorm_training=True,
alpha=depth_multiplier,
include_top=False,
weights='imagenet' if depth_multiplier == 1.0 else None)
return CenterNetMobileNetV2FPNFeatureExtractor(
network,
channel_means=channel_means,
channel_stds=channel_stds,
bgr_ordering=bgr_ordering,
fpn_separable_conv=False)
def mobilenet_v2_fpn_sep_conv(channel_means, channel_stds, bgr_ordering):
"""Same as mobilenet_v2_fpn except with separable convolution in FPN."""
# Setting batchnorm_training to True, which will use the correct
# BatchNormalization layer strategy based on the current Keras learning phase.
# TODO(yuhuic): experiment with True vs. False to understand its effect in
# practice.
network = mobilenetv2.mobilenet_v2(batchnorm_training=True, include_top=False)
return CenterNetMobileNetV2FPNFeatureExtractor(
network,
channel_means=channel_means,
channel_stds=channel_stds,
bgr_ordering=bgr_ordering,
fpn_separable_conv=True)
use_separable_conv=use_separable_conv)
......@@ -18,7 +18,6 @@ import numpy as np
import tensorflow.compat.v1 as tf
from object_detection.models import center_net_mobilenet_v2_fpn_feature_extractor
from object_detection.models.keras_models import mobilenet_v2
from object_detection.utils import test_case
from object_detection.utils import tf_version
......@@ -28,10 +27,13 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase):
def test_center_net_mobilenet_v2_fpn_feature_extractor(self):
net = mobilenet_v2.mobilenet_v2(True, include_top=False)
model = center_net_mobilenet_v2_fpn_feature_extractor.CenterNetMobileNetV2FPNFeatureExtractor(
net)
channel_means = (0., 0., 0.)
channel_stds = (1., 1., 1.)
bgr_ordering = False
model = (
center_net_mobilenet_v2_fpn_feature_extractor.mobilenet_v2_fpn(
channel_means, channel_stds, bgr_ordering,
use_separable_conv=False))
def graph_fn():
img = np.zeros((8, 224, 224, 3), dtype=np.float32)
......@@ -50,10 +52,12 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase):
def test_center_net_mobilenet_v2_fpn_feature_extractor_sep_conv(self):
net = mobilenet_v2.mobilenet_v2(True, include_top=False)
model = center_net_mobilenet_v2_fpn_feature_extractor.CenterNetMobileNetV2FPNFeatureExtractor(
net, fpn_separable_conv=True)
channel_means = (0., 0., 0.)
channel_stds = (1., 1., 1.)
bgr_ordering = False
model = (
center_net_mobilenet_v2_fpn_feature_extractor.mobilenet_v2_fpn(
channel_means, channel_stds, bgr_ordering, use_separable_conv=True))
def graph_fn():
img = np.zeros((8, 224, 224, 3), dtype=np.float32)
......@@ -62,6 +66,10 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase):
outputs = self.execute(graph_fn, [])
self.assertEqual(outputs.shape, (8, 56, 56, 24))
# Pull out the backbone network.
backbone = model.get_layer('model')
first_conv = backbone.get_layer('Conv1')
self.assertEqual(32, first_conv.filters)
# Pull out the FPN network.
output = model.get_layer('model_1')
......@@ -71,6 +79,30 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase):
if 'conv' in layer.name and layer.kernel_size != (1, 1):
self.assertIsInstance(layer, tf.keras.layers.SeparableConv2D)
def test_center_net_mobilenet_v2_fpn_feature_extractor_depth_multiplier(self):
  """Checks that `depth_multiplier` is plumbed through to the backbone.

  Builds the MobileNetV2+FPN extractor with `depth_multiplier=2.0`, runs a
  zero image batch through it, and verifies both the output shape and the
  filter count of the backbone's 'Conv1' layer.
  """
  extractor = center_net_mobilenet_v2_fpn_feature_extractor.mobilenet_v2_fpn(
      (0., 0., 0.),  # channel_means
      (1., 1., 1.),  # channel_stds
      False,  # bgr_ordering
      use_separable_conv=True,
      depth_multiplier=2.0)

  def graph_fn():
    zeros = np.zeros((8, 224, 224, 3), dtype=np.float32)
    return extractor(extractor.preprocess(zeros))

  feature_map = self.execute(graph_fn, [])
  self.assertEqual(feature_map.shape, (8, 56, 56, 24))

  # Pull out the backbone network (the 'model' layer) and its first conv.
  conv1 = extractor.get_layer('model').get_layer('Conv1')
  # Note that the first layer typically has 32 filters, but this model has
  # a depth multiplier of 2.
  self.assertEqual(64, conv1.filters)
if __name__ == '__main__':
tf.test.main()
......@@ -136,8 +136,9 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
def resnet_v2_101(channel_means, channel_stds, bgr_ordering):
def resnet_v2_101(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The ResNet v2 101 feature extractor."""
del kwargs
return CenterNetResnetFeatureExtractor(
resnet_type='resnet_v2_101',
......@@ -147,8 +148,9 @@ def resnet_v2_101(channel_means, channel_stds, bgr_ordering):
)
def resnet_v2_50(channel_means, channel_stds, bgr_ordering):
def resnet_v2_50(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The ResNet v2 50 feature extractor."""
del kwargs
return CenterNetResnetFeatureExtractor(
resnet_type='resnet_v2_50',
......
......@@ -172,8 +172,9 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
def resnet_v1_101_fpn(channel_means, channel_stds, bgr_ordering):
def resnet_v1_101_fpn(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The ResNet v1 101 FPN feature extractor."""
del kwargs
return CenterNetResnetV1FpnFeatureExtractor(
resnet_type='resnet_v1_101',
......@@ -183,8 +184,9 @@ def resnet_v1_101_fpn(channel_means, channel_stds, bgr_ordering):
)
def resnet_v1_50_fpn(channel_means, channel_stds, bgr_ordering):
def resnet_v1_50_fpn(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The ResNet v1 50 FPN feature extractor."""
del kwargs
return CenterNetResnetV1FpnFeatureExtractor(
resnet_type='resnet_v1_50',
......@@ -193,8 +195,9 @@ def resnet_v1_50_fpn(channel_means, channel_stds, bgr_ordering):
bgr_ordering=bgr_ordering)
def resnet_v1_34_fpn(channel_means, channel_stds, bgr_ordering):
def resnet_v1_34_fpn(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The ResNet v1 34 FPN feature extractor."""
del kwargs
return CenterNetResnetV1FpnFeatureExtractor(
resnet_type='resnet_v1_34',
......@@ -204,8 +207,9 @@ def resnet_v1_34_fpn(channel_means, channel_stds, bgr_ordering):
)
def resnet_v1_18_fpn(channel_means, channel_stds, bgr_ordering):
def resnet_v1_18_fpn(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The ResNet v1 18 FPN feature extractor."""
del kwargs
return CenterNetResnetV1FpnFeatureExtractor(
resnet_type='resnet_v1_18',
......
......@@ -376,5 +376,12 @@ message CenterNetFeatureExtractor {
// network if any.
optional bool use_depthwise = 5 [default = false];
// Depth multiplier. Only valid for specific models (e.g. MobileNet). See subclasses of `CenterNetFeatureExtractor`.
optional float depth_multiplier = 9 [default = 1.0];
// Whether to use separable convolutions. Only valid for specific
// models. See subclasses of `CenterNetFeatureExtractor`.
optional bool use_separable_conv = 10 [default = false];
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment