Commit 9db38a15 authored by Xianzhi Du's avatar Xianzhi Du Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 395783944
parent f8b7b77d
......@@ -32,6 +32,7 @@ class Identity(hyperparams.Config):
class FPN(hyperparams.Config):
"""FPN config."""
num_filters: int = 256
fusion_type: str = 'sum'
use_separable_conv: bool = False
......
......@@ -42,6 +42,7 @@ class FPN(tf.keras.Model):
min_level: int = 3,
max_level: int = 7,
num_filters: int = 256,
fusion_type: str = 'sum',
use_separable_conv: bool = False,
activation: str = 'relu',
use_sync_bn: bool = False,
......@@ -59,6 +60,8 @@ class FPN(tf.keras.Model):
min_level: An `int` of minimum level in FPN output feature maps.
max_level: An `int` of maximum level in FPN output feature maps.
num_filters: An `int` number of filters in FPN layers.
fusion_type: A `str` of `sum` or `concat`. Whether performing sum or
concat for feature fusion.
use_separable_conv: A `bool`. If True use separable convolution for
convolution in FPN layers.
activation: A `str` name of the activation function.
......@@ -77,6 +80,7 @@ class FPN(tf.keras.Model):
'min_level': min_level,
'max_level': max_level,
'num_filters': num_filters,
'fusion_type': fusion_type,
'use_separable_conv': use_separable_conv,
'activation': activation,
'use_sync_bn': use_sync_bn,
......@@ -122,8 +126,16 @@ class FPN(tf.keras.Model):
# Build top-down path.
feats = {str(backbone_max_level): feats_lateral[str(backbone_max_level)]}
for level in range(backbone_max_level - 1, min_level - 1, -1):
feats[str(level)] = spatial_transform_ops.nearest_upsampling(
feats[str(level + 1)], 2) + feats_lateral[str(level)]
feat_a = spatial_transform_ops.nearest_upsampling(
feats[str(level + 1)], 2)
feat_b = feats_lateral[str(level)]
if fusion_type == 'sum':
feats[str(level)] = feat_a + feat_b
elif fusion_type == 'concat':
feats[str(level)] = tf.concat([feat_a, feat_b], axis=-1)
else:
raise ValueError('Fusion type {} not supported.'.format(fusion_type))
# TODO(xianzhi): consider to remove bias in conv2d.
# Build post-hoc 3x3 convolution kernel.
......@@ -224,6 +236,7 @@ def build_fpn_decoder(
min_level=model_config.min_level,
max_level=model_config.max_level,
num_filters=decoder_cfg.num_filters,
fusion_type=decoder_cfg.fusion_type,
use_separable_conv=decoder_cfg.use_separable_conv,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
......
......@@ -27,11 +27,11 @@ from official.vision.beta.modeling.decoders import fpn
class FPNTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
(256, 3, 7, False),
(256, 3, 7, True),
(256, 3, 7, False, 'sum'),
(256, 3, 7, True, 'concat'),
)
def test_network_creation(self, input_size, min_level, max_level,
use_separable_conv):
use_separable_conv, fusion_type):
"""Test creation of FPN."""
tf.keras.backend.set_image_data_format('channels_last')
......@@ -42,6 +42,7 @@ class FPNTest(parameterized.TestCase, tf.test.TestCase):
input_specs=backbone.output_specs,
min_level=min_level,
max_level=max_level,
fusion_type=fusion_type,
use_separable_conv=use_separable_conv)
endpoints = backbone(inputs)
......@@ -87,6 +88,7 @@ class FPNTest(parameterized.TestCase, tf.test.TestCase):
min_level=3,
max_level=7,
num_filters=256,
fusion_type='sum',
use_separable_conv=False,
use_sync_bn=False,
activation='relu',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment