Commit 00024735 authored by Jaehong Kim, committed by A. Unique TensorFlower
Browse files

Convert the activation, add, upsampling, and identity ops in spinenet_mobile to Keras layers.

Added a use_keras_upsampling_2d flag to use the Keras UpSampling2D layer instead of the optimized custom op in spatial_transform_ops.

PiperOrigin-RevId: 382626602
parent cbe77ec0
......@@ -108,6 +108,7 @@ def get_activation(identifier, use_keras_layer=False):
"linear": "linear",
"identity": "linear",
"swish": "swish",
"sigmoid": "sigmoid",
"relu6": tf.nn.relu6,
}
if identifier in keras_layer_allowlist:
......
......@@ -80,6 +80,11 @@ class SpineNetMobile(hyperparams.Config):
expand_ratio: int = 6
min_level: int = 3
max_level: int = 7
# If use_keras_upsampling_2d is True, the model uses the Keras UpSampling2D
# layer instead of the optimized custom TF op, which makes the model more
# Keras-native. Set this flag to True when applying QAT from the Model
# Optimization Toolkit, which requires the model to use Keras layers.
use_keras_upsampling_2d: bool = False
@dataclasses.dataclass
......
......@@ -346,7 +346,8 @@ def retinanet_spinenet_mobile_coco() -> cfg.ExperimentConfig:
model_id='49',
stochastic_depth_drop_rate=0.2,
min_level=3,
max_level=7)),
max_level=7,
use_keras_upsampling_2d=False)),
decoder=decoders.Decoder(
type='identity', identity=decoders.Identity()),
head=RetinaNetHead(num_filters=48, use_separable_conv=True),
......
......@@ -152,6 +152,7 @@ class SpineNetMobile(tf.keras.Model):
use_sync_bn: bool = False,
norm_momentum: float = 0.99,
norm_epsilon: float = 0.001,
use_keras_upsampling_2d: bool = False,
**kwargs):
"""Initializes a Mobile SpineNet model.
......@@ -181,6 +182,7 @@ class SpineNetMobile(tf.keras.Model):
use_sync_bn: If True, use synchronized batch normalization.
norm_momentum: A `float` of normalization momentum for the moving average.
norm_epsilon: A small `float` added to variance to avoid dividing by zero.
use_keras_upsampling_2d: If True, use keras UpSampling2D layer.
**kwargs: Additional keyword arguments to be passed.
"""
self._input_specs = input_specs
......@@ -200,12 +202,7 @@ class SpineNetMobile(tf.keras.Model):
self._use_sync_bn = use_sync_bn
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
if activation == 'relu':
self._activation_fn = tf.nn.relu
elif activation == 'swish':
self._activation_fn = tf.nn.swish
else:
raise ValueError('Activation {} not implemented.'.format(activation))
self._use_keras_upsampling_2d = use_keras_upsampling_2d
self._num_init_blocks = 2
if use_sync_bn:
......@@ -271,7 +268,7 @@ class SpineNetMobile(tf.keras.Model):
norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon)(
inputs)
return tf.identity(x, name=name)
return tf.keras.layers.Activation('linear', name=name)(x)
def _build_stem(self, inputs):
"""Builds SpineNet stem."""
......@@ -290,7 +287,7 @@ class SpineNetMobile(tf.keras.Model):
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(
x)
x = tf_utils.get_activation(self._activation_fn)(x)
x = tf_utils.get_activation(self._activation, use_keras_layer=True)(x)
net = []
stem_strides = [1, 2]
......@@ -365,14 +362,15 @@ class SpineNetMobile(tf.keras.Model):
parent_weights = [
tf.nn.relu(tf.cast(tf.Variable(1.0, name='block{}_fusion{}'.format(
i, j)), dtype=dtype)) for j in range(len(parents))]
weights_sum = tf.add_n(parent_weights)
weights_sum = layers.Add()(parent_weights)
parents = [
parents[i] * parent_weights[i] / (weights_sum + 0.0001)
for i in range(len(parents))
]
# Fuse all parent nodes then build a new block.
x = tf_utils.get_activation(self._activation_fn)(tf.add_n(parents))
x = tf_utils.get_activation(
self._activation, use_keras_layer=True)(layers.Add()(parents))
x = self._block_group(
inputs=x,
in_filters=target_num_filters,
......@@ -421,7 +419,7 @@ class SpineNetMobile(tf.keras.Model):
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(
x)
x = tf_utils.get_activation(self._activation_fn)(x)
x = tf_utils.get_activation(self._activation, use_keras_layer=True)(x)
endpoints[str(level)] = x
return endpoints
......@@ -446,11 +444,13 @@ class SpineNetMobile(tf.keras.Model):
momentum=self._norm_momentum,
epsilon=self._norm_epsilon)(
x)
x = tf_utils.get_activation(self._activation_fn)(x)
x = tf_utils.get_activation(
self._activation, use_keras_layer=True)(x)
input_width /= 2
elif input_width < target_width:
scale = target_width // input_width
x = spatial_transform_ops.nearest_upsampling(x, scale=scale)
x = spatial_transform_ops.nearest_upsampling(
x, scale=scale, use_keras_layer=self._use_keras_upsampling_2d)
# Last 1x1 conv to match filter size.
x = layers.Conv2D(
......@@ -485,7 +485,8 @@ class SpineNetMobile(tf.keras.Model):
'activation': self._activation,
'use_sync_bn': self._use_sync_bn,
'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon
'norm_epsilon': self._norm_epsilon,
'use_keras_upsampling_2d': self._use_keras_upsampling_2d,
}
return config_dict
......@@ -531,4 +532,5 @@ def build_spinenet_mobile(
activation=norm_activation_config.activation,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon)
norm_epsilon=norm_activation_config.norm_epsilon,
use_keras_upsampling_2d=backbone_cfg.use_keras_upsampling_2d)
......@@ -90,6 +90,7 @@ class SpineNetMobileTest(parameterized.TestCase, tf.test.TestCase):
kernel_initializer='VarianceScaling',
kernel_regularizer=None,
bias_regularizer=None,
use_keras_upsampling_2d=False,
)
network = spinenet_mobile.SpineNetMobile(**kwargs)
......
......@@ -511,18 +511,22 @@ def crop_mask_in_target_box(masks,
return cropped_masks
def nearest_upsampling(data, scale):
def nearest_upsampling(data, scale, use_keras_layer=False):
"""Nearest neighbor upsampling implementation.
Args:
data: A tensor with a shape of [batch, height_in, width_in, channels].
scale: An integer multiple to scale resolution of input data.
use_keras_layer: If True, use the Keras UpSampling2D layer.
Returns:
data_up: A tensor with a shape of
[batch, height_in*scale, width_in*scale, channels]. Same dtype as input
data.
"""
if use_keras_layer:
return tf.keras.layers.UpSampling2D(size=(scale, scale),
interpolation='nearest')(data)
with tf.name_scope('nearest_upsampling'):
bs, _, _, c = data.get_shape().as_list()
shape = tf.shape(input=data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment