Commit e7c57743 authored by Yu-hui Chen's avatar Yu-hui Chen Committed by TF Object Detection Team
Browse files

Updated the model builder and feature extractor such that the upsampling

interpolation method is configurable.

PiperOrigin-RevId: 372198977
parent 28df3e1f
......@@ -1130,11 +1130,18 @@ def _build_center_net_feature_extractor(feature_extractor_config, is_training):
feature_extractor_config.use_separable_conv or
feature_extractor_config.type == 'mobilenet_v2_fpn_sep_conv')
kwargs = {
'channel_means': list(feature_extractor_config.channel_means),
'channel_stds': list(feature_extractor_config.channel_stds),
'bgr_ordering': feature_extractor_config.bgr_ordering,
'depth_multiplier': feature_extractor_config.depth_multiplier,
'use_separable_conv': use_separable_conv,
'channel_means':
list(feature_extractor_config.channel_means),
'channel_stds':
list(feature_extractor_config.channel_stds),
'bgr_ordering':
feature_extractor_config.bgr_ordering,
'depth_multiplier':
feature_extractor_config.depth_multiplier,
'use_separable_conv':
use_separable_conv,
'upsampling_interpolation':
feature_extractor_config.upsampling_interpolation,
}
......
......@@ -398,7 +398,7 @@ class ModelBuilderTF2Test(
}
"""
# Set up the configuration proto.
config = text_format.Merge(proto_txt, model_pb2.DetectionModel())
config = text_format.Parse(proto_txt, model_pb2.DetectionModel())
# Only add object center and keypoint estimation configs here.
config.center_net.object_center_params.CopyFrom(
self.get_fake_object_center_from_keypoints_proto())
......@@ -422,6 +422,50 @@ class ModelBuilderTF2Test(
self.assertEqual(kp_params.keypoint_labels,
['nose', 'left_shoulder', 'right_shoulder', 'hip'])
def test_create_center_net_model_mobilenet(self):
"""Test building a CenterNet model using bilinear interpolation."""
proto_txt = """
center_net {
num_classes: 10
feature_extractor {
type: "mobilenet_v2_fpn"
depth_multiplier: 1.0
use_separable_conv: true
upsampling_interpolation: "bilinear"
}
image_resizer {
keep_aspect_ratio_resizer {
min_dimension: 512
max_dimension: 512
pad_to_max_dimension: true
}
}
}
"""
# Set up the configuration proto.
config = text_format.Parse(proto_txt, model_pb2.DetectionModel())
# Only add object center and keypoint estimation configs here.
config.center_net.object_center_params.CopyFrom(
self.get_fake_object_center_from_keypoints_proto())
config.center_net.keypoint_estimation_task.append(
self.get_fake_keypoint_proto())
config.center_net.keypoint_label_map_path = (
self.get_fake_label_map_file_path())
# Build the model from the configuration.
model = model_builder.build(config, is_training=True)
feature_extractor = model._feature_extractor
# Verify the upsampling layers in the FPN use 'bilinear' interpolation.
fpn = feature_extractor.get_layer('model_1')
num_up_sampling2d_layers = 0
for layer in fpn.layers:
if 'up_sampling2d' in layer.name:
num_up_sampling2d_layers += 1
self.assertEqual('bilinear', layer.interpolation)
# Verify that there are up_sampling2d layers.
self.assertGreater(num_up_sampling2d_layers, 0)
if __name__ == '__main__':
tf.test.main()
......@@ -39,7 +39,8 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
channel_means=(0., 0., 0.),
channel_stds=(1., 1., 1.),
bgr_ordering=False,
use_separable_conv=False):
use_separable_conv=False,
upsampling_interpolation='nearest'):
"""Intializes the feature extractor.
Args:
......@@ -52,6 +53,9 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
[blue, red, green] order.
use_separable_conv: If set to True, all convolutional layers in the FPN
network will be replaced by separable convolutions.
upsampling_interpolation: A string (one of 'nearest' or 'bilinear')
indicating which interpolation method to use for the upsampling ops in
the FPN.
"""
super(CenterNetMobileNetV2FPNFeatureExtractor, self).__init__(
......@@ -84,7 +88,8 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
for i, num_filters in enumerate(num_filters_list):
level_ind = len(num_filters_list) - 1 - i
# Upsample.
upsample_op = tf.keras.layers.UpSampling2D(2, interpolation='nearest')
upsample_op = tf.keras.layers.UpSampling2D(
2, interpolation=upsampling_interpolation)
top_down = upsample_op(top_down)
# Residual (skip-connection) from bottom-up pathway.
......@@ -144,7 +149,8 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
def mobilenet_v2_fpn(channel_means, channel_stds, bgr_ordering,
use_separable_conv=False, depth_multiplier=1.0, **kwargs):
use_separable_conv=False, depth_multiplier=1.0,
upsampling_interpolation='nearest', **kwargs):
"""The MobileNetV2+FPN backbone for CenterNet."""
del kwargs
......@@ -159,4 +165,5 @@ def mobilenet_v2_fpn(channel_means, channel_stds, bgr_ordering,
channel_means=channel_means,
channel_stds=channel_stds,
bgr_ordering=bgr_ordering,
use_separable_conv=use_separable_conv)
use_separable_conv=use_separable_conv,
upsampling_interpolation=upsampling_interpolation)
......@@ -103,6 +103,30 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase):
# a depth multiplier of 2.
self.assertEqual(64, first_conv.filters)
def test_center_net_mobilenet_v2_fpn_feature_extractor_interpolation(self):
channel_means = (0., 0., 0.)
channel_stds = (1., 1., 1.)
bgr_ordering = False
model = (
center_net_mobilenet_v2_fpn_feature_extractor.mobilenet_v2_fpn(
channel_means, channel_stds, bgr_ordering, use_separable_conv=True,
upsampling_interpolation='bilinear'))
def graph_fn():
img = np.zeros((8, 224, 224, 3), dtype=np.float32)
processed_img = model.preprocess(img)
return model(processed_img)
outputs = self.execute(graph_fn, [])
self.assertEqual(outputs.shape, (8, 56, 56, 24))
# Verify the upsampling layers in the FPN use 'bilinear' interpolation.
fpn = model.get_layer('model_1')
for layer in fpn.layers:
if 'up_sampling2d' in layer.name:
self.assertEqual('bilinear', layer.interpolation)
if __name__ == '__main__':
tf.test.main()
......@@ -440,11 +440,17 @@ message CenterNetFeatureExtractor {
optional bool use_depthwise = 5 [default = false];
// Depth multiplier. Only valid for specific models (e.g. MobileNet). See subclasses of `CenterNetFeatureExtractor`.
// Depth multiplier. Only valid for specific models (e.g. MobileNet). See
// subclasses of `CenterNetFeatureExtractor`.
optional float depth_multiplier = 9 [default = 1.0];
// Whether to use separable convolutions. Only valid for specific
// models. See subclasses of `CenterNetFeatureExtractor`.
optional bool use_separable_conv = 10 [default = false];
// Which interpolation method to use for the upsampling ops in the FPN.
// Currently only valid for CenterNetMobileNetV2FPNFeatureExtractor. The value
// can be on of 'nearest' or 'bilinear'.
optional string upsampling_interpolation = 11 [default = 'nearest'];
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment