Commit e4f2edf0 authored by Xianzhi Du's avatar Xianzhi Du Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 395783944
parent 2de035ca
...@@ -32,6 +32,7 @@ class Identity(hyperparams.Config): ...@@ -32,6 +32,7 @@ class Identity(hyperparams.Config):
class FPN(hyperparams.Config): class FPN(hyperparams.Config):
"""FPN config.""" """FPN config."""
num_filters: int = 256 num_filters: int = 256
fusion_type: str = 'sum'
use_separable_conv: bool = False use_separable_conv: bool = False
......
...@@ -42,6 +42,7 @@ class FPN(tf.keras.Model): ...@@ -42,6 +42,7 @@ class FPN(tf.keras.Model):
min_level: int = 3, min_level: int = 3,
max_level: int = 7, max_level: int = 7,
num_filters: int = 256, num_filters: int = 256,
fusion_type: str = 'sum',
use_separable_conv: bool = False, use_separable_conv: bool = False,
activation: str = 'relu', activation: str = 'relu',
use_sync_bn: bool = False, use_sync_bn: bool = False,
...@@ -59,6 +60,8 @@ class FPN(tf.keras.Model): ...@@ -59,6 +60,8 @@ class FPN(tf.keras.Model):
min_level: An `int` of minimum level in FPN output feature maps. min_level: An `int` of minimum level in FPN output feature maps.
max_level: An `int` of maximum level in FPN output feature maps. max_level: An `int` of maximum level in FPN output feature maps.
num_filters: An `int` number of filters in FPN layers. num_filters: An `int` number of filters in FPN layers.
fusion_type: A `str` of `sum` or `concat`. Whether performing sum or
concat for feature fusion.
use_separable_conv: A `bool`. If True use separable convolution for use_separable_conv: A `bool`. If True use separable convolution for
convolution in FPN layers. convolution in FPN layers.
activation: A `str` name of the activation function. activation: A `str` name of the activation function.
...@@ -77,6 +80,7 @@ class FPN(tf.keras.Model): ...@@ -77,6 +80,7 @@ class FPN(tf.keras.Model):
'min_level': min_level, 'min_level': min_level,
'max_level': max_level, 'max_level': max_level,
'num_filters': num_filters, 'num_filters': num_filters,
'fusion_type': fusion_type,
'use_separable_conv': use_separable_conv, 'use_separable_conv': use_separable_conv,
'activation': activation, 'activation': activation,
'use_sync_bn': use_sync_bn, 'use_sync_bn': use_sync_bn,
...@@ -122,8 +126,16 @@ class FPN(tf.keras.Model): ...@@ -122,8 +126,16 @@ class FPN(tf.keras.Model):
# Build top-down path. # Build top-down path.
feats = {str(backbone_max_level): feats_lateral[str(backbone_max_level)]} feats = {str(backbone_max_level): feats_lateral[str(backbone_max_level)]}
for level in range(backbone_max_level - 1, min_level - 1, -1): for level in range(backbone_max_level - 1, min_level - 1, -1):
feats[str(level)] = spatial_transform_ops.nearest_upsampling( feat_a = spatial_transform_ops.nearest_upsampling(
feats[str(level + 1)], 2) + feats_lateral[str(level)] feats[str(level + 1)], 2)
feat_b = feats_lateral[str(level)]
if fusion_type == 'sum':
feats[str(level)] = feat_a + feat_b
elif fusion_type == 'concat':
feats[str(level)] = tf.concat([feat_a, feat_b], axis=-1)
else:
raise ValueError('Fusion type {} not supported.'.format(fusion_type))
# TODO(xianzhi): consider to remove bias in conv2d. # TODO(xianzhi): consider to remove bias in conv2d.
# Build post-hoc 3x3 convolution kernel. # Build post-hoc 3x3 convolution kernel.
...@@ -224,6 +236,7 @@ def build_fpn_decoder( ...@@ -224,6 +236,7 @@ def build_fpn_decoder(
min_level=model_config.min_level, min_level=model_config.min_level,
max_level=model_config.max_level, max_level=model_config.max_level,
num_filters=decoder_cfg.num_filters, num_filters=decoder_cfg.num_filters,
fusion_type=decoder_cfg.fusion_type,
use_separable_conv=decoder_cfg.use_separable_conv, use_separable_conv=decoder_cfg.use_separable_conv,
activation=norm_activation_config.activation, activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn, use_sync_bn=norm_activation_config.use_sync_bn,
......
...@@ -27,11 +27,11 @@ from official.vision.beta.modeling.decoders import fpn ...@@ -27,11 +27,11 @@ from official.vision.beta.modeling.decoders import fpn
class FPNTest(parameterized.TestCase, tf.test.TestCase): class FPNTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters( @parameterized.parameters(
(256, 3, 7, False), (256, 3, 7, False, 'sum'),
(256, 3, 7, True), (256, 3, 7, True, 'concat'),
) )
def test_network_creation(self, input_size, min_level, max_level, def test_network_creation(self, input_size, min_level, max_level,
use_separable_conv): use_separable_conv, fusion_type):
"""Test creation of FPN.""" """Test creation of FPN."""
tf.keras.backend.set_image_data_format('channels_last') tf.keras.backend.set_image_data_format('channels_last')
...@@ -42,6 +42,7 @@ class FPNTest(parameterized.TestCase, tf.test.TestCase): ...@@ -42,6 +42,7 @@ class FPNTest(parameterized.TestCase, tf.test.TestCase):
input_specs=backbone.output_specs, input_specs=backbone.output_specs,
min_level=min_level, min_level=min_level,
max_level=max_level, max_level=max_level,
fusion_type=fusion_type,
use_separable_conv=use_separable_conv) use_separable_conv=use_separable_conv)
endpoints = backbone(inputs) endpoints = backbone(inputs)
...@@ -87,6 +88,7 @@ class FPNTest(parameterized.TestCase, tf.test.TestCase): ...@@ -87,6 +88,7 @@ class FPNTest(parameterized.TestCase, tf.test.TestCase):
min_level=3, min_level=3,
max_level=7, max_level=7,
num_filters=256, num_filters=256,
fusion_type='sum',
use_separable_conv=False, use_separable_conv=False,
use_sync_bn=False, use_sync_bn=False,
activation='relu', activation='relu',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment