"tests/vscode:/vscode.git/clone" did not exist on "b3e792bc9da53389b3df4d8ed803e50d1644a0f3"
Commit a3a379a8 authored by Yu-hui Chen's avatar Yu-hui Chen Committed by TF Object Detection Team
Browse files

Added the option to build CenterNet MobileNetV2 model with separable convolution in the

FPN network.

PiperOrigin-RevId: 351255218
parent 6e01e1cd
......@@ -159,6 +159,9 @@ if tf_version.is_tf2():
center_net_mobilenet_v2_feature_extractor.mobilenet_v2,
'mobilenet_v2_fpn':
center_net_mobilenet_v2_fpn_feature_extractor.mobilenet_v2_fpn,
'mobilenet_v2_fpn_sep_conv':
center_net_mobilenet_v2_fpn_feature_extractor
.mobilenet_v2_fpn_sep_conv,
}
FEATURE_EXTRACTOR_MAPS = [
......
......@@ -38,7 +38,8 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
mobilenet_v2_net,
channel_means=(0., 0., 0.),
channel_stds=(1., 1., 1.),
bgr_ordering=False):
bgr_ordering=False,
fpn_separable_conv=False):
"""Intializes the feature extractor.
Args:
......@@ -49,6 +50,8 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
channel. Each channel will be divided by its standard deviation value.
bgr_ordering: bool, if set will change the channel ordering to be in the
[blue, red, green] order.
fpn_separable_conv: If set to True, all convolutional layers in the FPN
network will be replaced by separable convolutions.
"""
super(CenterNetMobileNetV2FPNFeatureExtractor, self).__init__(
......@@ -72,6 +75,7 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
# 7x7x1280, which we continually upsample, apply a residual on and merge.
# This results in a 56x56x24 output volume.
top_layer = fpn_outputs[-1]
# Use normal convolutional layer since the kernel_size is 1.
residual_op = tf.keras.layers.Conv2D(
filters=64, kernel_size=1, strides=1, padding='same')
top_down = residual_op(top_layer)
......@@ -84,6 +88,7 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
top_down = upsample_op(top_down)
# Residual (skip-connection) from bottom-up pathway.
# Use normal convolutional layer since the kernel_size is 1.
residual_op = tf.keras.layers.Conv2D(
filters=num_filters, kernel_size=1, strides=1, padding='same')
residual = residual_op(fpn_outputs[level_ind])
......@@ -91,6 +96,10 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
# Merge.
top_down = top_down + residual
next_num_filters = num_filters_list[i + 1] if i + 1 <= 2 else 24
if fpn_separable_conv:
conv = tf.keras.layers.SeparableConv2D(
filters=next_num_filters, kernel_size=3, strides=1, padding='same')
else:
conv = tf.keras.layers.Conv2D(
filters=next_num_filters, kernel_size=3, strides=1, padding='same')
top_down = conv(top_down)
......@@ -133,10 +142,27 @@ class CenterNetMobileNetV2FPNFeatureExtractor(
def mobilenet_v2_fpn(channel_means, channel_stds, bgr_ordering):
"""The MobileNetV2+FPN backbone for CenterNet."""
# Set to is_training to True for now.
network = mobilenetv2.mobilenet_v2(True, include_top=False)
# Set to batchnorm_training to True for now.
network = mobilenetv2.mobilenet_v2(batchnorm_training=True, include_top=False)
return CenterNetMobileNetV2FPNFeatureExtractor(
network,
channel_means=channel_means,
channel_stds=channel_stds,
bgr_ordering=bgr_ordering)
bgr_ordering=bgr_ordering,
fpn_separable_conv=False)
def mobilenet_v2_fpn_sep_conv(channel_means, channel_stds, bgr_ordering):
"""Same as mobilenet_v2_fpn except with separable convolution in FPN."""
# Setting batchnorm_training to True, which will use the correct
# BatchNormalization layer strategy based on the current Keras learning phase.
# TODO(yuhuic): expriment with True vs. False to understand it's effect in
# practice.
network = mobilenetv2.mobilenet_v2(batchnorm_training=True, include_top=False)
return CenterNetMobileNetV2FPNFeatureExtractor(
network,
channel_means=channel_means,
channel_stds=channel_stds,
bgr_ordering=bgr_ordering,
fpn_separable_conv=True)
......@@ -41,6 +41,36 @@ class CenterNetMobileNetV2FPNFeatureExtractorTest(test_case.TestCase):
outputs = self.execute(graph_fn, [])
self.assertEqual(outputs.shape, (8, 56, 56, 24))
# Pull out the FPN network.
output = model.get_layer('model_1')
for layer in output.layers:
# All convolution layers should be normal 2D convolutions.
if 'conv' in layer.name:
self.assertIsInstance(layer, tf.keras.layers.Conv2D)
def test_center_net_mobilenet_v2_fpn_feature_extractor_sep_conv(self):
net = mobilenet_v2.mobilenet_v2(True, include_top=False)
model = center_net_mobilenet_v2_fpn_feature_extractor.CenterNetMobileNetV2FPNFeatureExtractor(
net, fpn_separable_conv=True)
def graph_fn():
img = np.zeros((8, 224, 224, 3), dtype=np.float32)
processed_img = model.preprocess(img)
return model(processed_img)
outputs = self.execute(graph_fn, [])
self.assertEqual(outputs.shape, (8, 56, 56, 24))
# Pull out the FPN network.
output = model.get_layer('model_1')
for layer in output.layers:
# Convolution layers with kernel size not equal to (1, 1) should be
# separable 2D convolutions.
if 'conv' in layer.name and layer.kernel_size != (1, 1):
self.assertIsInstance(layer, tf.keras.layers.SeparableConv2D)
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment