Commit 78c43ef1 authored by Gunho Park's avatar Gunho Park
Browse files

Merge branch 'master' of https://github.com/tensorflow/models

parents 67cfc95b e3c7e300
...@@ -393,8 +393,10 @@ class SpineNet(tf.keras.Model): ...@@ -393,8 +393,10 @@ class SpineNet(tf.keras.Model):
block_spec.level)) block_spec.level))
if (block_spec.level < self._min_level or if (block_spec.level < self._min_level or
block_spec.level > self._max_level): block_spec.level > self._max_level):
raise ValueError('Output level is out of range [{}, {}]'.format( logging.warning(
self._min_level, self._max_level)) 'SpineNet output level out of range [min_level, max_level] = '
'[%s, %s] will not be used for further processing.',
self._min_level, self._max_level)
endpoints[str(block_spec.level)] = x endpoints[str(block_spec.level)] = x
return endpoints return endpoints
......
...@@ -152,6 +152,7 @@ class SpineNetMobile(tf.keras.Model): ...@@ -152,6 +152,7 @@ class SpineNetMobile(tf.keras.Model):
use_sync_bn: bool = False, use_sync_bn: bool = False,
norm_momentum: float = 0.99, norm_momentum: float = 0.99,
norm_epsilon: float = 0.001, norm_epsilon: float = 0.001,
use_keras_upsampling_2d: bool = False,
**kwargs): **kwargs):
"""Initializes a Mobile SpineNet model. """Initializes a Mobile SpineNet model.
...@@ -181,6 +182,7 @@ class SpineNetMobile(tf.keras.Model): ...@@ -181,6 +182,7 @@ class SpineNetMobile(tf.keras.Model):
use_sync_bn: If True, use synchronized batch normalization. use_sync_bn: If True, use synchronized batch normalization.
norm_momentum: A `float` of normalization momentum for the moving average. norm_momentum: A `float` of normalization momentum for the moving average.
norm_epsilon: A small `float` added to variance to avoid dividing by zero. norm_epsilon: A small `float` added to variance to avoid dividing by zero.
use_keras_upsampling_2d: If True, use keras UpSampling2D layer.
**kwargs: Additional keyword arguments to be passed. **kwargs: Additional keyword arguments to be passed.
""" """
self._input_specs = input_specs self._input_specs = input_specs
...@@ -200,12 +202,7 @@ class SpineNetMobile(tf.keras.Model): ...@@ -200,12 +202,7 @@ class SpineNetMobile(tf.keras.Model):
self._use_sync_bn = use_sync_bn self._use_sync_bn = use_sync_bn
self._norm_momentum = norm_momentum self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon self._norm_epsilon = norm_epsilon
if activation == 'relu': self._use_keras_upsampling_2d = use_keras_upsampling_2d
self._activation_fn = tf.nn.relu
elif activation == 'swish':
self._activation_fn = tf.nn.swish
else:
raise ValueError('Activation {} not implemented.'.format(activation))
self._num_init_blocks = 2 self._num_init_blocks = 2
if use_sync_bn: if use_sync_bn:
...@@ -271,7 +268,7 @@ class SpineNetMobile(tf.keras.Model): ...@@ -271,7 +268,7 @@ class SpineNetMobile(tf.keras.Model):
norm_momentum=self._norm_momentum, norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon)( norm_epsilon=self._norm_epsilon)(
inputs) inputs)
return tf.identity(x, name=name) return tf.keras.layers.Activation('linear', name=name)(x)
def _build_stem(self, inputs): def _build_stem(self, inputs):
"""Builds SpineNet stem.""" """Builds SpineNet stem."""
...@@ -290,7 +287,7 @@ class SpineNetMobile(tf.keras.Model): ...@@ -290,7 +287,7 @@ class SpineNetMobile(tf.keras.Model):
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon)( epsilon=self._norm_epsilon)(
x) x)
x = tf_utils.get_activation(self._activation_fn)(x) x = tf_utils.get_activation(self._activation, use_keras_layer=True)(x)
net = [] net = []
stem_strides = [1, 2] stem_strides = [1, 2]
...@@ -365,14 +362,15 @@ class SpineNetMobile(tf.keras.Model): ...@@ -365,14 +362,15 @@ class SpineNetMobile(tf.keras.Model):
parent_weights = [ parent_weights = [
tf.nn.relu(tf.cast(tf.Variable(1.0, name='block{}_fusion{}'.format( tf.nn.relu(tf.cast(tf.Variable(1.0, name='block{}_fusion{}'.format(
i, j)), dtype=dtype)) for j in range(len(parents))] i, j)), dtype=dtype)) for j in range(len(parents))]
weights_sum = tf.add_n(parent_weights) weights_sum = layers.Add()(parent_weights)
parents = [ parents = [
parents[i] * parent_weights[i] / (weights_sum + 0.0001) parents[i] * parent_weights[i] / (weights_sum + 0.0001)
for i in range(len(parents)) for i in range(len(parents))
] ]
# Fuse all parent nodes then build a new block. # Fuse all parent nodes then build a new block.
x = tf_utils.get_activation(self._activation_fn)(tf.add_n(parents)) x = tf_utils.get_activation(
self._activation, use_keras_layer=True)(layers.Add()(parents))
x = self._block_group( x = self._block_group(
inputs=x, inputs=x,
in_filters=target_num_filters, in_filters=target_num_filters,
...@@ -421,7 +419,7 @@ class SpineNetMobile(tf.keras.Model): ...@@ -421,7 +419,7 @@ class SpineNetMobile(tf.keras.Model):
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon)( epsilon=self._norm_epsilon)(
x) x)
x = tf_utils.get_activation(self._activation_fn)(x) x = tf_utils.get_activation(self._activation, use_keras_layer=True)(x)
endpoints[str(level)] = x endpoints[str(level)] = x
return endpoints return endpoints
...@@ -446,11 +444,13 @@ class SpineNetMobile(tf.keras.Model): ...@@ -446,11 +444,13 @@ class SpineNetMobile(tf.keras.Model):
momentum=self._norm_momentum, momentum=self._norm_momentum,
epsilon=self._norm_epsilon)( epsilon=self._norm_epsilon)(
x) x)
x = tf_utils.get_activation(self._activation_fn)(x) x = tf_utils.get_activation(
self._activation, use_keras_layer=True)(x)
input_width /= 2 input_width /= 2
elif input_width < target_width: elif input_width < target_width:
scale = target_width // input_width scale = target_width // input_width
x = spatial_transform_ops.nearest_upsampling(x, scale=scale) x = spatial_transform_ops.nearest_upsampling(
x, scale=scale, use_keras_layer=self._use_keras_upsampling_2d)
# Last 1x1 conv to match filter size. # Last 1x1 conv to match filter size.
x = layers.Conv2D( x = layers.Conv2D(
...@@ -485,7 +485,8 @@ class SpineNetMobile(tf.keras.Model): ...@@ -485,7 +485,8 @@ class SpineNetMobile(tf.keras.Model):
'activation': self._activation, 'activation': self._activation,
'use_sync_bn': self._use_sync_bn, 'use_sync_bn': self._use_sync_bn,
'norm_momentum': self._norm_momentum, 'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon 'norm_epsilon': self._norm_epsilon,
'use_keras_upsampling_2d': self._use_keras_upsampling_2d,
} }
return config_dict return config_dict
...@@ -531,4 +532,5 @@ def build_spinenet_mobile( ...@@ -531,4 +532,5 @@ def build_spinenet_mobile(
activation=norm_activation_config.activation, activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn, use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum, norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon) norm_epsilon=norm_activation_config.norm_epsilon,
use_keras_upsampling_2d=backbone_cfg.use_keras_upsampling_2d)
...@@ -90,6 +90,7 @@ class SpineNetMobileTest(parameterized.TestCase, tf.test.TestCase): ...@@ -90,6 +90,7 @@ class SpineNetMobileTest(parameterized.TestCase, tf.test.TestCase):
kernel_initializer='VarianceScaling', kernel_initializer='VarianceScaling',
kernel_regularizer=None, kernel_regularizer=None,
bias_regularizer=None, bias_regularizer=None,
use_keras_upsampling_2d=False,
) )
network = spinenet_mobile.SpineNetMobile(**kwargs) network = spinenet_mobile.SpineNetMobile(**kwargs)
......
...@@ -24,17 +24,16 @@ from official.vision.beta.modeling.backbones import spinenet ...@@ -24,17 +24,16 @@ from official.vision.beta.modeling.backbones import spinenet
class SpineNetTest(parameterized.TestCase, tf.test.TestCase): class SpineNetTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters( @parameterized.parameters(
(128, 0.65, 1, 0.5, 128), (128, 0.65, 1, 0.5, 128, 4, 6),
(256, 1.0, 1, 0.5, 256), (256, 1.0, 1, 0.5, 256, 3, 6),
(384, 1.0, 2, 0.5, 256), (384, 1.0, 2, 0.5, 256, 4, 7),
(512, 1.0, 3, 1.0, 256), (512, 1.0, 3, 1.0, 256, 3, 7),
(640, 1.3, 4, 1.0, 384), (640, 1.3, 4, 1.0, 384, 3, 7),
) )
def test_network_creation(self, input_size, filter_size_scale, block_repeats, def test_network_creation(self, input_size, filter_size_scale, block_repeats,
resample_alpha, endpoints_num_filters): resample_alpha, endpoints_num_filters, min_level,
max_level):
"""Test creation of SpineNet models.""" """Test creation of SpineNet models."""
min_level = 3
max_level = 7
tf.keras.backend.set_image_data_format('channels_last') tf.keras.backend.set_image_data_format('channels_last')
......
...@@ -13,12 +13,15 @@ ...@@ -13,12 +13,15 @@
# limitations under the License. # limitations under the License.
"""Contains definitions of Atrous Spatial Pyramid Pooling (ASPP) decoder.""" """Contains definitions of Atrous Spatial Pyramid Pooling (ASPP) decoder."""
from typing import Any, List, Optional, Mapping from typing import Any, List, Mapping, Optional
# Import libraries # Import libraries
import tensorflow as tf import tensorflow as tf
from official.modeling import hyperparams
from official.vision import keras_cv from official.vision import keras_cv
from official.vision.beta.modeling.decoders import factory
@tf.keras.utils.register_keras_serializable(package='Vision') @tf.keras.utils.register_keras_serializable(package='Vision')
...@@ -128,3 +131,46 @@ class ASPP(tf.keras.layers.Layer): ...@@ -128,3 +131,46 @@ class ASPP(tf.keras.layers.Layer):
@classmethod @classmethod
def from_config(cls, config, custom_objects=None): def from_config(cls, config, custom_objects=None):
return cls(**config) return cls(**config)
@factory.register_decoder_builder('aspp')
def build_aspp_decoder(
input_specs: Mapping[str, tf.TensorShape],
model_config: hyperparams.Config,
l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
) -> tf.keras.Model:
"""Builds ASPP decoder from a config.
Args:
input_specs: A `dict` of input specifications. A dictionary consists of
{level: TensorShape} from a backbone. Note this is for consistent
interface, and is not used by ASPP decoder.
model_config: A OneOfConfig. Model config.
l2_regularizer: A `tf.keras.regularizers.Regularizer` instance. Default to
None.
Returns:
A `tf.keras.Model` instance of the ASPP decoder.
Raises:
ValueError: If the model_config.decoder.type is not `aspp`.
"""
del input_specs # input_specs is not used by ASPP decoder.
decoder_type = model_config.decoder.type
decoder_cfg = model_config.decoder.get()
if decoder_type != 'aspp':
raise ValueError(f'Inconsistent decoder type {decoder_type}. '
'Need to be `aspp`.')
norm_activation_config = model_config.norm_activation
return ASPP(
level=decoder_cfg.level,
dilation_rates=decoder_cfg.dilation_rates,
num_filters=decoder_cfg.num_filters,
pool_kernel_size=decoder_cfg.pool_kernel_size,
dropout_rate=decoder_cfg.dropout_rate,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
activation=norm_activation_config.activation,
kernel_regularizer=l2_regularizer)
...@@ -12,80 +12,124 @@ ...@@ -12,80 +12,124 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3 """Decoder registers and factory method.
"""Contains the factory method to create decoders."""
from typing import Mapping, Optional One can register a new decoder model by the following two steps:
1 Import the factory and register the build in the decoder file.
2 Import the decoder class and add a build in __init__.py.
```
# my_decoder.py
from modeling.decoders import factory
class MyDecoder():
...
@factory.register_decoder_builder('my_decoder')
def build_my_decoder():
return MyDecoder()
# decoders/__init__.py adds import
from modeling.decoders.my_decoder import MyDecoder
```
If one wants the MyDecoder class to be used only by a specific binary,
then don't import the decoder module in decoders/__init__.py, but import it
in the place that uses it.
"""
from typing import Any, Callable, Mapping, Optional, Union
# Import libraries # Import libraries
import tensorflow as tf import tensorflow as tf
from official.core import registry
from official.modeling import hyperparams from official.modeling import hyperparams
from official.vision.beta.modeling import decoders
_REGISTERED_DECODER_CLS = {}
def register_decoder_builder(key: str) -> Callable[..., Any]:
"""Decorates a builder of decoder class.
The builder should be a Callable (a class or a function).
This decorator supports registration of decoder builder as follows:
```
class MyDecoder(tf.keras.Model):
pass
@register_decoder_builder('mydecoder')
def builder(input_specs, config, l2_reg):
return MyDecoder(...)
# Builds a MyDecoder object.
  my_decoder = build_decoder(input_specs, config, l2_reg)
```
Args:
key: A `str` of key to look up the builder.
Returns:
  A callable for use as a class decorator that registers the decorated
  builder for creation under the given key.
"""
return registry.register(_REGISTERED_DECODER_CLS, key)
@register_decoder_builder('identity')
def build_identity(
input_specs: Optional[Mapping[str, tf.TensorShape]] = None,
model_config: Optional[hyperparams.Config] = None,
l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None) -> None:
"""Builds identity decoder from a config.
All the input arguments are not used by identity decoder but kept here to
ensure the interface is consistent.
Args:
input_specs: A `dict` of input specifications. A dictionary consists of
{level: TensorShape} from a backbone.
model_config: A `OneOfConfig` of model config.
l2_regularizer: A `tf.keras.regularizers.Regularizer` object. Default to
None.
Returns:
An instance of the identity decoder.
"""
del input_specs, model_config, l2_regularizer # Unused by identity decoder.
def build_decoder( def build_decoder(
input_specs: Mapping[str, tf.TensorShape], input_specs: Mapping[str, tf.TensorShape],
model_config: hyperparams.Config, model_config: hyperparams.Config,
l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None l2_regularizer: tf.keras.regularizers.Regularizer = None,
) -> tf.keras.Model: **kwargs) -> Union[None, tf.keras.Model, tf.keras.layers.Layer]:
"""Builds decoder from a config. """Builds decoder from a config.
A decoder can be a keras.Model, a keras.layers.Layer, or None. If it is not
None, the decoder will take features from the backbone as input and generate
decoded feature maps. If it is None, such as an identity decoder, the decoder
is skipped and features from the backbone are regarded as model output.
Args: Args:
input_specs: A `dict` of input specifications. A dictionary consists of input_specs: A `dict` of input specifications. A dictionary consists of
{level: TensorShape} from a backbone. {level: TensorShape} from a backbone.
model_config: A OneOfConfig. Model config. model_config: A `OneOfConfig` of model config.
l2_regularizer: A `tf.keras.regularizers.Regularizer` instance. Default to l2_regularizer: A `tf.keras.regularizers.Regularizer` object. Default to
None. None.
**kwargs: Additional keyword args to be passed to decoder builder.
Returns: Returns:
A `tf.keras.Model` instance of the decoder. An instance of the decoder.
""" """
decoder_type = model_config.decoder.type decoder_builder = registry.lookup(_REGISTERED_DECODER_CLS,
decoder_cfg = model_config.decoder.get() model_config.decoder.type)
norm_activation_config = model_config.norm_activation
return decoder_builder(
if decoder_type == 'identity':
decoder = None
elif decoder_type == 'fpn':
decoder = decoders.FPN(
input_specs=input_specs,
min_level=model_config.min_level,
max_level=model_config.max_level,
num_filters=decoder_cfg.num_filters,
use_separable_conv=decoder_cfg.use_separable_conv,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
elif decoder_type == 'nasfpn':
decoder = decoders.NASFPN(
input_specs=input_specs, input_specs=input_specs,
min_level=model_config.min_level, model_config=model_config,
max_level=model_config.max_level, l2_regularizer=l2_regularizer,
num_filters=decoder_cfg.num_filters, **kwargs)
num_repeats=decoder_cfg.num_repeats,
use_separable_conv=decoder_cfg.use_separable_conv,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
elif decoder_type == 'aspp':
decoder = decoders.ASPP(
level=decoder_cfg.level,
dilation_rates=decoder_cfg.dilation_rates,
num_filters=decoder_cfg.num_filters,
pool_kernel_size=decoder_cfg.pool_kernel_size,
dropout_rate=decoder_cfg.dropout_rate,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
activation=norm_activation_config.activation,
kernel_regularizer=l2_regularizer)
else:
raise ValueError('Decoder {!r} not implement'.format(decoder_type))
return decoder
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for decoder factory functions."""
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.distribute import combinations
from official.vision.beta import configs
from official.vision.beta.configs import decoders as decoders_cfg
from official.vision.beta.modeling import decoders
from official.vision.beta.modeling.decoders import factory
class FactoryTest(tf.test.TestCase, parameterized.TestCase):
@combinations.generate(
combinations.combine(
num_filters=[128, 256], use_separable_conv=[True, False]))
def test_fpn_decoder_creation(self, num_filters, use_separable_conv):
"""Test creation of FPN decoder."""
min_level = 3
max_level = 7
input_specs = {}
for level in range(min_level, max_level):
input_specs[str(level)] = tf.TensorShape(
[1, 128 // (2**level), 128 // (2**level), 3])
network = decoders.FPN(
input_specs=input_specs,
num_filters=num_filters,
use_separable_conv=use_separable_conv,
use_sync_bn=True)
model_config = configs.retinanet.RetinaNet()
model_config.min_level = min_level
model_config.max_level = max_level
model_config.num_classes = 10
model_config.input_size = [None, None, 3]
model_config.decoder = decoders_cfg.Decoder(
type='fpn',
fpn=decoders_cfg.FPN(
num_filters=num_filters, use_separable_conv=use_separable_conv))
factory_network = factory.build_decoder(
input_specs=input_specs, model_config=model_config)
network_config = network.get_config()
factory_network_config = factory_network.get_config()
self.assertEqual(network_config, factory_network_config)
@combinations.generate(
combinations.combine(
num_filters=[128, 256],
num_repeats=[3, 5],
use_separable_conv=[True, False]))
def test_nasfpn_decoder_creation(self, num_filters, num_repeats,
use_separable_conv):
"""Test creation of NASFPN decoder."""
min_level = 3
max_level = 7
input_specs = {}
for level in range(min_level, max_level):
input_specs[str(level)] = tf.TensorShape(
[1, 128 // (2**level), 128 // (2**level), 3])
network = decoders.NASFPN(
input_specs=input_specs,
num_filters=num_filters,
num_repeats=num_repeats,
use_separable_conv=use_separable_conv,
use_sync_bn=True)
model_config = configs.retinanet.RetinaNet()
model_config.min_level = min_level
model_config.max_level = max_level
model_config.num_classes = 10
model_config.input_size = [None, None, 3]
model_config.decoder = decoders_cfg.Decoder(
type='nasfpn',
nasfpn=decoders_cfg.NASFPN(
num_filters=num_filters,
num_repeats=num_repeats,
use_separable_conv=use_separable_conv))
factory_network = factory.build_decoder(
input_specs=input_specs, model_config=model_config)
network_config = network.get_config()
factory_network_config = factory_network.get_config()
self.assertEqual(network_config, factory_network_config)
@combinations.generate(
combinations.combine(
level=[3, 4],
dilation_rates=[[6, 12, 18], [6, 12]],
num_filters=[128, 256]))
def test_aspp_decoder_creation(self, level, dilation_rates, num_filters):
"""Test creation of ASPP decoder."""
input_specs = {'1': tf.TensorShape([1, 128, 128, 3])}
network = decoders.ASPP(
level=level,
dilation_rates=dilation_rates,
num_filters=num_filters,
use_sync_bn=True)
model_config = configs.semantic_segmentation.SemanticSegmentationModel()
model_config.num_classes = 10
model_config.input_size = [None, None, 3]
model_config.decoder = decoders_cfg.Decoder(
type='aspp',
aspp=decoders_cfg.ASPP(
level=level, dilation_rates=dilation_rates,
num_filters=num_filters))
factory_network = factory.build_decoder(
input_specs=input_specs, model_config=model_config)
network_config = network.get_config()
factory_network_config = factory_network.get_config()
self.assertEqual(network_config, factory_network_config)
def test_identity_decoder_creation(self):
"""Test creation of identity decoder."""
model_config = configs.retinanet.RetinaNet()
model_config.num_classes = 2
model_config.input_size = [None, None, 3]
model_config.decoder = decoders_cfg.Decoder(
type='identity', identity=decoders_cfg.Identity())
factory_network = factory.build_decoder(
input_specs=None, model_config=model_config)
self.assertIsNone(factory_network)
if __name__ == '__main__':
tf.test.main()
...@@ -16,9 +16,12 @@ ...@@ -16,9 +16,12 @@
from typing import Any, Mapping, Optional from typing import Any, Mapping, Optional
# Import libraries # Import libraries
import tensorflow as tf import tensorflow as tf
from official.modeling import hyperparams
from official.modeling import tf_utils from official.modeling import tf_utils
from official.vision.beta.modeling.decoders import factory
from official.vision.beta.ops import spatial_transform_ops from official.vision.beta.ops import spatial_transform_ops
...@@ -187,3 +190,43 @@ class FPN(tf.keras.Model): ...@@ -187,3 +190,43 @@ class FPN(tf.keras.Model):
def output_specs(self) -> Mapping[str, tf.TensorShape]: def output_specs(self) -> Mapping[str, tf.TensorShape]:
"""A dict of {level: TensorShape} pairs for the model output.""" """A dict of {level: TensorShape} pairs for the model output."""
return self._output_specs return self._output_specs
@factory.register_decoder_builder('fpn')
def build_fpn_decoder(
input_specs: Mapping[str, tf.TensorShape],
model_config: hyperparams.Config,
l2_regularizer: Optional[tf.keras.regularizers.Regularizer] = None
) -> tf.keras.Model:
"""Builds FPN decoder from a config.
Args:
input_specs: A `dict` of input specifications. A dictionary consists of
{level: TensorShape} from a backbone.
model_config: A OneOfConfig. Model config.
l2_regularizer: A `tf.keras.regularizers.Regularizer` instance. Default to
None.
Returns:
A `tf.keras.Model` instance of the FPN decoder.
Raises:
ValueError: If the model_config.decoder.type is not `fpn`.
"""
decoder_type = model_config.decoder.type
decoder_cfg = model_config.decoder.get()
if decoder_type != 'fpn':
raise ValueError(f'Inconsistent decoder type {decoder_type}. '
'Need to be `fpn`.')
norm_activation_config = model_config.norm_activation
return FPN(
input_specs=input_specs,
min_level=model_config.min_level,
max_level=model_config.max_level,
num_filters=decoder_cfg.num_filters,
use_separable_conv=decoder_cfg.use_separable_conv,
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
kernel_regularizer=l2_regularizer)
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
from absl.testing import parameterized from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from official.vision.beta.modeling.backbones import mobilenet
from official.vision.beta.modeling.backbones import resnet from official.vision.beta.modeling.backbones import resnet
from official.vision.beta.modeling.decoders import fpn from official.vision.beta.modeling.decoders import fpn
...@@ -52,6 +53,33 @@ class FPNTest(parameterized.TestCase, tf.test.TestCase): ...@@ -52,6 +53,33 @@ class FPNTest(parameterized.TestCase, tf.test.TestCase):
[1, input_size // 2**level, input_size // 2**level, 256], [1, input_size // 2**level, input_size // 2**level, 256],
feats[str(level)].shape.as_list()) feats[str(level)].shape.as_list())
@parameterized.parameters(
(256, 3, 7, False),
(256, 3, 7, True),
)
def test_network_creation_with_mobilenet(self, input_size, min_level,
max_level, use_separable_conv):
"""Test creation of FPN with mobilenet backbone."""
tf.keras.backend.set_image_data_format('channels_last')
inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
backbone = mobilenet.MobileNet(model_id='MobileNetV2')
network = fpn.FPN(
input_specs=backbone.output_specs,
min_level=min_level,
max_level=max_level,
use_separable_conv=use_separable_conv)
endpoints = backbone(inputs)
feats = network(endpoints)
for level in range(min_level, max_level + 1):
self.assertIn(str(level), feats)
self.assertAllEqual(
[1, input_size // 2**level, input_size // 2**level, 256],
feats[str(level)].shape.as_list())
def test_serialize_deserialize(self): def test_serialize_deserialize(self):
# Create a network object that sets all of its config options. # Create a network object that sets all of its config options.
kwargs = dict( kwargs = dict(
......
This diff is collapsed.
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
from typing import Any, List, Mapping, Optional, Union from typing import Any, List, Mapping, Optional, Union
# Import libraries
import tensorflow as tf import tensorflow as tf
from official.vision.beta.ops import anchor from official.vision.beta.ops import anchor
...@@ -147,14 +146,18 @@ class MaskRCNNModel(tf.keras.Model): ...@@ -147,14 +146,18 @@ class MaskRCNNModel(tf.keras.Model):
model_outputs = {} model_outputs = {}
# Feature extraction. # Feature extraction.
features = self.backbone(images) backbone_features = self.backbone(images)
if self.decoder: if self.decoder:
features = self.decoder(features) features = self.decoder(backbone_features)
else:
features = backbone_features
# Region proposal network. # Region proposal network.
rpn_scores, rpn_boxes = self.rpn_head(features) rpn_scores, rpn_boxes = self.rpn_head(features)
model_outputs.update({ model_outputs.update({
'backbone_features': backbone_features,
'decoder_features': features,
'rpn_boxes': rpn_boxes, 'rpn_boxes': rpn_boxes,
'rpn_scores': rpn_scores 'rpn_scores': rpn_scores
}) })
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment