Commit ac739092 authored by Jaehong Kim's avatar Jaehong Kim Committed by A. Unique TensorFlower
Browse files

Add use_keras_layer flag for FPN.

PiperOrigin-RevId: 448849223
parent 23ac769f
...@@ -33,6 +33,7 @@ class FPN(hyperparams.Config): ...@@ -33,6 +33,7 @@ class FPN(hyperparams.Config):
num_filters: int = 256 num_filters: int = 256
fusion_type: str = 'sum' fusion_type: str = 'sum'
use_separable_conv: bool = False use_separable_conv: bool = False
use_keras_layer: bool = False
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -44,6 +44,7 @@ class FPN(tf.keras.Model): ...@@ -44,6 +44,7 @@ class FPN(tf.keras.Model):
num_filters: int = 256, num_filters: int = 256,
fusion_type: str = 'sum', fusion_type: str = 'sum',
use_separable_conv: bool = False, use_separable_conv: bool = False,
use_keras_layer: bool = False,
activation: str = 'relu', activation: str = 'relu',
use_sync_bn: bool = False, use_sync_bn: bool = False,
norm_momentum: float = 0.99, norm_momentum: float = 0.99,
...@@ -64,6 +65,7 @@ class FPN(tf.keras.Model): ...@@ -64,6 +65,7 @@ class FPN(tf.keras.Model):
concat for feature fusion. concat for feature fusion.
use_separable_conv: A `bool`. If True use separable convolution for use_separable_conv: A `bool`. If True use separable convolution for
convolution in FPN layers. convolution in FPN layers.
use_keras_layer: A `bool`. If True, use Keras layers wherever possible.
activation: A `str` name of the activation function. activation: A `str` name of the activation function.
use_sync_bn: A `bool`. If True, use synchronized batch normalization. use_sync_bn: A `bool`. If True, use synchronized batch normalization.
norm_momentum: A `float` of normalization momentum for the moving average. norm_momentum: A `float` of normalization momentum for the moving average.
...@@ -82,6 +84,7 @@ class FPN(tf.keras.Model): ...@@ -82,6 +84,7 @@ class FPN(tf.keras.Model):
'num_filters': num_filters, 'num_filters': num_filters,
'fusion_type': fusion_type, 'fusion_type': fusion_type,
'use_separable_conv': use_separable_conv, 'use_separable_conv': use_separable_conv,
'use_keras_layer': use_keras_layer,
'activation': activation, 'activation': activation,
'use_sync_bn': use_sync_bn, 'use_sync_bn': use_sync_bn,
'norm_momentum': norm_momentum, 'norm_momentum': norm_momentum,
...@@ -98,8 +101,7 @@ class FPN(tf.keras.Model): ...@@ -98,8 +101,7 @@ class FPN(tf.keras.Model):
norm = tf.keras.layers.experimental.SyncBatchNormalization norm = tf.keras.layers.experimental.SyncBatchNormalization
else: else:
norm = tf.keras.layers.BatchNormalization norm = tf.keras.layers.BatchNormalization
activation_fn = tf.keras.layers.Activation( activation_fn = tf_utils.get_activation(activation, use_keras_layer=True)
tf_utils.get_activation(activation))
# Build input feature pyramid. # Build input feature pyramid.
if tf.keras.backend.image_data_format() == 'channels_last': if tf.keras.backend.image_data_format() == 'channels_last':
...@@ -128,12 +130,19 @@ class FPN(tf.keras.Model): ...@@ -128,12 +130,19 @@ class FPN(tf.keras.Model):
feats = {str(backbone_max_level): feats_lateral[str(backbone_max_level)]} feats = {str(backbone_max_level): feats_lateral[str(backbone_max_level)]}
for level in range(backbone_max_level - 1, min_level - 1, -1): for level in range(backbone_max_level - 1, min_level - 1, -1):
feat_a = spatial_transform_ops.nearest_upsampling( feat_a = spatial_transform_ops.nearest_upsampling(
feats[str(level + 1)], 2) feats[str(level + 1)], 2, use_keras_layer=use_keras_layer)
feat_b = feats_lateral[str(level)] feat_b = feats_lateral[str(level)]
if fusion_type == 'sum': if fusion_type == 'sum':
if use_keras_layer:
feats[str(level)] = tf.keras.layers.Add()([feat_a, feat_b])
else:
feats[str(level)] = feat_a + feat_b feats[str(level)] = feat_a + feat_b
elif fusion_type == 'concat': elif fusion_type == 'concat':
if use_keras_layer:
feats[str(level)] = tf.keras.layers.Concatenate(axis=-1)(
[feat_a, feat_b])
else:
feats[str(level)] = tf.concat([feat_a, feat_b], axis=-1) feats[str(level)] = tf.concat([feat_a, feat_b], axis=-1)
else: else:
raise ValueError('Fusion type {} not supported.'.format(fusion_type)) raise ValueError('Fusion type {} not supported.'.format(fusion_type))
...@@ -239,6 +248,7 @@ def build_fpn_decoder( ...@@ -239,6 +248,7 @@ def build_fpn_decoder(
num_filters=decoder_cfg.num_filters, num_filters=decoder_cfg.num_filters,
fusion_type=decoder_cfg.fusion_type, fusion_type=decoder_cfg.fusion_type,
use_separable_conv=decoder_cfg.use_separable_conv, use_separable_conv=decoder_cfg.use_separable_conv,
use_keras_layer=decoder_cfg.use_keras_layer,
activation=norm_activation_config.activation, activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn, use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum, norm_momentum=norm_activation_config.norm_momentum,
......
...@@ -26,11 +26,13 @@ from official.vision.modeling.decoders import fpn ...@@ -26,11 +26,13 @@ from official.vision.modeling.decoders import fpn
class FPNTest(parameterized.TestCase, tf.test.TestCase): class FPNTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters( @parameterized.parameters(
(256, 3, 7, False, 'sum'), (256, 3, 7, False, False, 'sum'),
(256, 3, 7, True, 'concat'), (256, 3, 7, False, True, 'sum'),
(256, 3, 7, True, False, 'concat'),
(256, 3, 7, True, True, 'concat'),
) )
def test_network_creation(self, input_size, min_level, max_level, def test_network_creation(self, input_size, min_level, max_level,
use_separable_conv, fusion_type): use_separable_conv, use_keras_layer, fusion_type):
"""Test creation of FPN.""" """Test creation of FPN."""
tf.keras.backend.set_image_data_format('channels_last') tf.keras.backend.set_image_data_format('channels_last')
...@@ -42,7 +44,8 @@ class FPNTest(parameterized.TestCase, tf.test.TestCase): ...@@ -42,7 +44,8 @@ class FPNTest(parameterized.TestCase, tf.test.TestCase):
min_level=min_level, min_level=min_level,
max_level=max_level, max_level=max_level,
fusion_type=fusion_type, fusion_type=fusion_type,
use_separable_conv=use_separable_conv) use_separable_conv=use_separable_conv,
use_keras_layer=use_keras_layer)
endpoints = backbone(inputs) endpoints = backbone(inputs)
feats = network(endpoints) feats = network(endpoints)
...@@ -54,11 +57,14 @@ class FPNTest(parameterized.TestCase, tf.test.TestCase): ...@@ -54,11 +57,14 @@ class FPNTest(parameterized.TestCase, tf.test.TestCase):
feats[str(level)].shape.as_list()) feats[str(level)].shape.as_list())
@parameterized.parameters( @parameterized.parameters(
(256, 3, 7, False), (256, 3, 7, False, False),
(256, 3, 7, True), (256, 3, 7, False, True),
(256, 3, 7, True, False),
(256, 3, 7, True, True),
) )
def test_network_creation_with_mobilenet(self, input_size, min_level, def test_network_creation_with_mobilenet(self, input_size, min_level,
max_level, use_separable_conv): max_level, use_separable_conv,
use_keras_layer):
"""Test creation of FPN with mobilenet backbone.""" """Test creation of FPN with mobilenet backbone."""
tf.keras.backend.set_image_data_format('channels_last') tf.keras.backend.set_image_data_format('channels_last')
...@@ -69,7 +75,8 @@ class FPNTest(parameterized.TestCase, tf.test.TestCase): ...@@ -69,7 +75,8 @@ class FPNTest(parameterized.TestCase, tf.test.TestCase):
input_specs=backbone.output_specs, input_specs=backbone.output_specs,
min_level=min_level, min_level=min_level,
max_level=max_level, max_level=max_level,
use_separable_conv=use_separable_conv) use_separable_conv=use_separable_conv,
use_keras_layer=use_keras_layer)
endpoints = backbone(inputs) endpoints = backbone(inputs)
feats = network(endpoints) feats = network(endpoints)
...@@ -89,6 +96,7 @@ class FPNTest(parameterized.TestCase, tf.test.TestCase): ...@@ -89,6 +96,7 @@ class FPNTest(parameterized.TestCase, tf.test.TestCase):
num_filters=256, num_filters=256,
fusion_type='sum', fusion_type='sum',
use_separable_conv=False, use_separable_conv=False,
use_keras_layer=False,
use_sync_bn=False, use_sync_bn=False,
activation='relu', activation='relu',
norm_momentum=0.99, norm_momentum=0.99,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment