"llama/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "03e40efa51a75a4e8385b64996af6468f42f6c06"
Commit 72d37b69 authored by Ronny Votel's avatar Ronny Votel Committed by TF Object Detection Team
Browse files

Updating CenterNet Feature Extractors to plumb through model specific parameters.

PiperOrigin-RevId: 366134194
parent cee4b75e
...@@ -170,9 +170,6 @@ if tf_version.is_tf2(): ...@@ -170,9 +170,6 @@ if tf_version.is_tf2():
center_net_mobilenet_v2_feature_extractor.mobilenet_v2, center_net_mobilenet_v2_feature_extractor.mobilenet_v2,
'mobilenet_v2_fpn': 'mobilenet_v2_fpn':
center_net_mobilenet_v2_fpn_feature_extractor.mobilenet_v2_fpn, center_net_mobilenet_v2_fpn_feature_extractor.mobilenet_v2_fpn,
'mobilenet_v2_fpn_sep_conv':
center_net_mobilenet_v2_fpn_feature_extractor
.mobilenet_v2_fpn_sep_conv,
} }
FEATURE_EXTRACTOR_MAPS = [ FEATURE_EXTRACTOR_MAPS = [
...@@ -1130,6 +1127,8 @@ def _build_center_net_feature_extractor(feature_extractor_config, is_training): ...@@ -1130,6 +1127,8 @@ def _build_center_net_feature_extractor(feature_extractor_config, is_training):
'channel_means': list(feature_extractor_config.channel_means), 'channel_means': list(feature_extractor_config.channel_means),
'channel_stds': list(feature_extractor_config.channel_stds), 'channel_stds': list(feature_extractor_config.channel_stds),
'bgr_ordering': feature_extractor_config.bgr_ordering, 'bgr_ordering': feature_extractor_config.bgr_ordering,
'depth_multiplier': feature_extractor_config.depth_multiplier,
'use_separable_conv': feature_extractor_config.use_separable_conv,
} }
......
...@@ -73,8 +73,9 @@ class CenterNetHourglassFeatureExtractor( ...@@ -73,8 +73,9 @@ class CenterNetHourglassFeatureExtractor(
ValueError('Sub model type "{}" not supported.'.format(sub_model_type)) ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
def hourglass_10(channel_means, channel_stds, bgr_ordering): def hourglass_10(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The Hourglass-10 backbone for CenterNet.""" """The Hourglass-10 backbone for CenterNet."""
del kwargs
network = hourglass_network.hourglass_10(num_channels=32) network = hourglass_network.hourglass_10(num_channels=32)
return CenterNetHourglassFeatureExtractor( return CenterNetHourglassFeatureExtractor(
...@@ -82,8 +83,9 @@ def hourglass_10(channel_means, channel_stds, bgr_ordering): ...@@ -82,8 +83,9 @@ def hourglass_10(channel_means, channel_stds, bgr_ordering):
bgr_ordering=bgr_ordering) bgr_ordering=bgr_ordering)
def hourglass_20(channel_means, channel_stds, bgr_ordering): def hourglass_20(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The Hourglass-20 backbone for CenterNet.""" """The Hourglass-20 backbone for CenterNet."""
del kwargs
network = hourglass_network.hourglass_20(num_channels=48) network = hourglass_network.hourglass_20(num_channels=48)
return CenterNetHourglassFeatureExtractor( return CenterNetHourglassFeatureExtractor(
...@@ -91,8 +93,9 @@ def hourglass_20(channel_means, channel_stds, bgr_ordering): ...@@ -91,8 +93,9 @@ def hourglass_20(channel_means, channel_stds, bgr_ordering):
bgr_ordering=bgr_ordering) bgr_ordering=bgr_ordering)
def hourglass_32(channel_means, channel_stds, bgr_ordering): def hourglass_32(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The Hourglass-32 backbone for CenterNet.""" """The Hourglass-32 backbone for CenterNet."""
del kwargs
network = hourglass_network.hourglass_32(num_channels=48) network = hourglass_network.hourglass_32(num_channels=48)
return CenterNetHourglassFeatureExtractor( return CenterNetHourglassFeatureExtractor(
...@@ -100,8 +103,9 @@ def hourglass_32(channel_means, channel_stds, bgr_ordering): ...@@ -100,8 +103,9 @@ def hourglass_32(channel_means, channel_stds, bgr_ordering):
bgr_ordering=bgr_ordering) bgr_ordering=bgr_ordering)
def hourglass_52(channel_means, channel_stds, bgr_ordering): def hourglass_52(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The Hourglass-52 backbone for CenterNet.""" """The Hourglass-52 backbone for CenterNet."""
del kwargs
network = hourglass_network.hourglass_52(num_channels=64) network = hourglass_network.hourglass_52(num_channels=64)
return CenterNetHourglassFeatureExtractor( return CenterNetHourglassFeatureExtractor(
...@@ -109,8 +113,9 @@ def hourglass_52(channel_means, channel_stds, bgr_ordering): ...@@ -109,8 +113,9 @@ def hourglass_52(channel_means, channel_stds, bgr_ordering):
bgr_ordering=bgr_ordering) bgr_ordering=bgr_ordering)
def hourglass_104(channel_means, channel_stds, bgr_ordering): def hourglass_104(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The Hourglass-104 backbone for CenterNet.""" """The Hourglass-104 backbone for CenterNet."""
del kwargs
# TODO(vighneshb): update hourglass_104 signature to match with other # TODO(vighneshb): update hourglass_104 signature to match with other
# hourglass networks. # hourglass networks.
......
...@@ -110,11 +110,17 @@ class CenterNetMobileNetV2FeatureExtractor( ...@@ -110,11 +110,17 @@ class CenterNetMobileNetV2FeatureExtractor(
ValueError('Sub model type "{}" not supported.'.format(sub_model_type)) ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
def mobilenet_v2(channel_means, channel_stds, bgr_ordering): def mobilenet_v2(channel_means, channel_stds, bgr_ordering,
depth_multiplier=1.0, **kwargs):
"""The MobileNetV2 backbone for CenterNet.""" """The MobileNetV2 backbone for CenterNet."""
del kwargs
# We set 'is_training' to True for now. # We set 'is_training' to True for now.
network = mobilenetv2.mobilenet_v2(True, include_top=False) network = mobilenetv2.mobilenet_v2(
batchnorm_training=True,
alpha=depth_multiplier,
include_top=False,
weights='imagenet' if depth_multiplier == 1.0 else None)
return CenterNetMobileNetV2FeatureExtractor( return CenterNetMobileNetV2FeatureExtractor(
network, network,
channel_means=channel_means, channel_means=channel_means,
......
...@@ -39,7 +39,7 @@ class CenterNetMobileNetV2FPNFeatureExtractor( ...@@ -39,7 +39,7 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
channel_means=(0., 0., 0.), channel_means=(0., 0., 0.),
channel_stds=(1., 1., 1.), channel_stds=(1., 1., 1.),
bgr_ordering=False, bgr_ordering=False,
fpn_separable_conv=False): use_separable_conv=False):
"""Intializes the feature extractor. """Intializes the feature extractor.
Args: Args:
...@@ -50,7 +50,7 @@ class CenterNetMobileNetV2FPNFeatureExtractor( ...@@ -50,7 +50,7 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
channel. Each channel will be divided by its standard deviation value. channel. Each channel will be divided by its standard deviation value.
bgr_ordering: bool, if set will change the channel ordering to be in the bgr_ordering: bool, if set will change the channel ordering to be in the
[blue, red, green] order. [blue, red, green] order.
fpn_separable_conv: If set to True, all convolutional layers in the FPN use_separable_conv: If set to True, all convolutional layers in the FPN
network will be replaced by separable convolutions. network will be replaced by separable convolutions.
""" """
...@@ -96,7 +96,7 @@ class CenterNetMobileNetV2FPNFeatureExtractor( ...@@ -96,7 +96,7 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
# Merge. # Merge.
top_down = top_down + residual top_down = top_down + residual
next_num_filters = num_filters_list[i + 1] if i + 1 <= 2 else 24 next_num_filters = num_filters_list[i + 1] if i + 1 <= 2 else 24
if fpn_separable_conv: if use_separable_conv:
conv = tf.keras.layers.SeparableConv2D( conv = tf.keras.layers.SeparableConv2D(
filters=next_num_filters, kernel_size=3, strides=1, padding='same') filters=next_num_filters, kernel_size=3, strides=1, padding='same')
else: else:
...@@ -143,30 +143,20 @@ class CenterNetMobileNetV2FPNFeatureExtractor( ...@@ -143,30 +143,20 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
return 1 return 1
def mobilenet_v2_fpn(channel_means, channel_stds, bgr_ordering): def mobilenet_v2_fpn(channel_means, channel_stds, bgr_ordering,
use_separable_conv=False, depth_multiplier=1.0, **kwargs):
"""The MobileNetV2+FPN backbone for CenterNet.""" """The MobileNetV2+FPN backbone for CenterNet."""
del kwargs
# Set to batchnorm_training to True for now. # Set to batchnorm_training to True for now.
network = mobilenetv2.mobilenet_v2(batchnorm_training=True, include_top=False) network = mobilenetv2.mobilenet_v2(
batchnorm_training=True,
alpha=depth_multiplier,
include_top=False,
weights='imagenet' if depth_multiplier == 1.0 else None)
return CenterNetMobileNetV2FPNFeatureExtractor( return CenterNetMobileNetV2FPNFeatureExtractor(
network, network,
channel_means=channel_means, channel_means=channel_means,
channel_stds=channel_stds, channel_stds=channel_stds,
bgr_ordering=bgr_ordering, bgr_ordering=bgr_ordering,
fpn_separable_conv=False) use_separable_conv=use_separable_conv)
def mobilenet_v2_fpn_sep_conv(channel_means, channel_stds, bgr_ordering):
"""Same as mobilenet_v2_fpn except with separable convolution in FPN."""
# Setting batchnorm_training to True, which will use the correct
# BatchNormalization layer strategy based on the current Keras learning phase.
# TODO(yuhuic): expriment with True vs. False to understand it's effect in
# practice.
network = mobilenetv2.mobilenet_v2(batchnorm_training=True, include_top=False)
return CenterNetMobileNetV2FPNFeatureExtractor(
network,
channel_means=channel_means,
channel_stds=channel_stds,
bgr_ordering=bgr_ordering,
fpn_separable_conv=True)
...@@ -18,7 +18,6 @@ import numpy as np ...@@ -18,7 +18,6 @@ import numpy as np
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from object_detection.models import center_net_mobilenet_v2_fpn_feature_extractor from object_detection.models import center_net_mobilenet_v2_fpn_feature_extractor
from object_detection.models.keras_models import mobilenet_v2
from object_detection.utils import test_case from object_detection.utils import test_case
from object_detection.utils import tf_version from object_detection.utils import tf_version
...@@ -28,10 +27,13 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase): ...@@ -28,10 +27,13 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase):
def test_center_net_mobilenet_v2_fpn_feature_extractor(self): def test_center_net_mobilenet_v2_fpn_feature_extractor(self):
net = mobilenet_v2.mobilenet_v2(True, include_top=False) channel_means = (0., 0., 0.)
channel_stds = (1., 1., 1.)
model = center_net_mobilenet_v2_fpn_feature_extractor.CenterNetMobileNetV2FPNFeatureExtractor( bgr_ordering = False
net) model = (
center_net_mobilenet_v2_fpn_feature_extractor.mobilenet_v2_fpn(
channel_means, channel_stds, bgr_ordering,
use_separable_conv=False))
def graph_fn(): def graph_fn():
img = np.zeros((8, 224, 224, 3), dtype=np.float32) img = np.zeros((8, 224, 224, 3), dtype=np.float32)
...@@ -50,10 +52,12 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase): ...@@ -50,10 +52,12 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase):
def test_center_net_mobilenet_v2_fpn_feature_extractor_sep_conv(self): def test_center_net_mobilenet_v2_fpn_feature_extractor_sep_conv(self):
net = mobilenet_v2.mobilenet_v2(True, include_top=False) channel_means = (0., 0., 0.)
channel_stds = (1., 1., 1.)
model = center_net_mobilenet_v2_fpn_feature_extractor.CenterNetMobileNetV2FPNFeatureExtractor( bgr_ordering = False
net, fpn_separable_conv=True) model = (
center_net_mobilenet_v2_fpn_feature_extractor.mobilenet_v2_fpn(
channel_means, channel_stds, bgr_ordering, use_separable_conv=True))
def graph_fn(): def graph_fn():
img = np.zeros((8, 224, 224, 3), dtype=np.float32) img = np.zeros((8, 224, 224, 3), dtype=np.float32)
...@@ -62,6 +66,10 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase): ...@@ -62,6 +66,10 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase):
outputs = self.execute(graph_fn, []) outputs = self.execute(graph_fn, [])
self.assertEqual(outputs.shape, (8, 56, 56, 24)) self.assertEqual(outputs.shape, (8, 56, 56, 24))
# Pull out the FPN network.
backbone = model.get_layer('model')
first_conv = backbone.get_layer('Conv1')
self.assertEqual(32, first_conv.filters)
# Pull out the FPN network. # Pull out the FPN network.
output = model.get_layer('model_1') output = model.get_layer('model_1')
...@@ -71,6 +79,30 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase): ...@@ -71,6 +79,30 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase):
if 'conv' in layer.name and layer.kernel_size != (1, 1): if 'conv' in layer.name and layer.kernel_size != (1, 1):
self.assertIsInstance(layer, tf.keras.layers.SeparableConv2D) self.assertIsInstance(layer, tf.keras.layers.SeparableConv2D)
def test_center_net_mobilenet_v2_fpn_feature_extractor_depth_multiplier(self):
channel_means = (0., 0., 0.)
channel_stds = (1., 1., 1.)
bgr_ordering = False
model = (
center_net_mobilenet_v2_fpn_feature_extractor.mobilenet_v2_fpn(
channel_means, channel_stds, bgr_ordering, use_separable_conv=True,
depth_multiplier=2.0))
def graph_fn():
img = np.zeros((8, 224, 224, 3), dtype=np.float32)
processed_img = model.preprocess(img)
return model(processed_img)
outputs = self.execute(graph_fn, [])
self.assertEqual(outputs.shape, (8, 56, 56, 24))
# Pull out the FPN network.
backbone = model.get_layer('model')
first_conv = backbone.get_layer('Conv1')
# Note that the first layer typically has 32 filters, but this model has
# a depth multiplier of 2.
self.assertEqual(64, first_conv.filters)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -136,8 +136,9 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor): ...@@ -136,8 +136,9 @@ class CenterNetResnetFeatureExtractor(CenterNetFeatureExtractor):
ValueError('Sub model type "{}" not supported.'.format(sub_model_type)) ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
def resnet_v2_101(channel_means, channel_stds, bgr_ordering): def resnet_v2_101(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The ResNet v2 101 feature extractor.""" """The ResNet v2 101 feature extractor."""
del kwargs
return CenterNetResnetFeatureExtractor( return CenterNetResnetFeatureExtractor(
resnet_type='resnet_v2_101', resnet_type='resnet_v2_101',
...@@ -147,8 +148,9 @@ def resnet_v2_101(channel_means, channel_stds, bgr_ordering): ...@@ -147,8 +148,9 @@ def resnet_v2_101(channel_means, channel_stds, bgr_ordering):
) )
def resnet_v2_50(channel_means, channel_stds, bgr_ordering): def resnet_v2_50(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The ResNet v2 50 feature extractor.""" """The ResNet v2 50 feature extractor."""
del kwargs
return CenterNetResnetFeatureExtractor( return CenterNetResnetFeatureExtractor(
resnet_type='resnet_v2_50', resnet_type='resnet_v2_50',
......
...@@ -172,8 +172,9 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor): ...@@ -172,8 +172,9 @@ class CenterNetResnetV1FpnFeatureExtractor(CenterNetFeatureExtractor):
ValueError('Sub model type "{}" not supported.'.format(sub_model_type)) ValueError('Sub model type "{}" not supported.'.format(sub_model_type))
def resnet_v1_101_fpn(channel_means, channel_stds, bgr_ordering): def resnet_v1_101_fpn(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The ResNet v1 101 FPN feature extractor.""" """The ResNet v1 101 FPN feature extractor."""
del kwargs
return CenterNetResnetV1FpnFeatureExtractor( return CenterNetResnetV1FpnFeatureExtractor(
resnet_type='resnet_v1_101', resnet_type='resnet_v1_101',
...@@ -183,8 +184,9 @@ def resnet_v1_101_fpn(channel_means, channel_stds, bgr_ordering): ...@@ -183,8 +184,9 @@ def resnet_v1_101_fpn(channel_means, channel_stds, bgr_ordering):
) )
def resnet_v1_50_fpn(channel_means, channel_stds, bgr_ordering): def resnet_v1_50_fpn(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The ResNet v1 50 FPN feature extractor.""" """The ResNet v1 50 FPN feature extractor."""
del kwargs
return CenterNetResnetV1FpnFeatureExtractor( return CenterNetResnetV1FpnFeatureExtractor(
resnet_type='resnet_v1_50', resnet_type='resnet_v1_50',
...@@ -193,8 +195,9 @@ def resnet_v1_50_fpn(channel_means, channel_stds, bgr_ordering): ...@@ -193,8 +195,9 @@ def resnet_v1_50_fpn(channel_means, channel_stds, bgr_ordering):
bgr_ordering=bgr_ordering) bgr_ordering=bgr_ordering)
def resnet_v1_34_fpn(channel_means, channel_stds, bgr_ordering): def resnet_v1_34_fpn(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The ResNet v1 34 FPN feature extractor.""" """The ResNet v1 34 FPN feature extractor."""
del kwargs
return CenterNetResnetV1FpnFeatureExtractor( return CenterNetResnetV1FpnFeatureExtractor(
resnet_type='resnet_v1_34', resnet_type='resnet_v1_34',
...@@ -204,8 +207,9 @@ def resnet_v1_34_fpn(channel_means, channel_stds, bgr_ordering): ...@@ -204,8 +207,9 @@ def resnet_v1_34_fpn(channel_means, channel_stds, bgr_ordering):
) )
def resnet_v1_18_fpn(channel_means, channel_stds, bgr_ordering): def resnet_v1_18_fpn(channel_means, channel_stds, bgr_ordering, **kwargs):
"""The ResNet v1 18 FPN feature extractor.""" """The ResNet v1 18 FPN feature extractor."""
del kwargs
return CenterNetResnetV1FpnFeatureExtractor( return CenterNetResnetV1FpnFeatureExtractor(
resnet_type='resnet_v1_18', resnet_type='resnet_v1_18',
......
...@@ -376,5 +376,12 @@ message CenterNetFeatureExtractor { ...@@ -376,5 +376,12 @@ message CenterNetFeatureExtractor {
// network if any. // network if any.
optional bool use_depthwise = 5 [default = false]; optional bool use_depthwise = 5 [default = false];
// Depth multiplier. Only valid for specific models (e.g. MobileNet). See subclasses of `CenterNetFeatureExtractor`.
optional float depth_multiplier = 9 [default = 1.0];
// Whether to use separable convolutions. Only valid for specific
// models. See subclasses of `CenterNetFeatureExtractor`.
optional bool use_separable_conv = 10 [default = false];
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment