Commit 32229eae authored by Yeqing Li, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 274918820
parent 63e20270
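
The commit message says only "Internal change"; the diff itself makes two coordinated changes. The Mask-RCNN heads now store the `batch_norm_relu` class rather than a pre-built instance, so each call site constructs its own `BatchNormRelu` layer and can pass per-site options such as `fused=False`. Correspondingly, `BatchNormRelu` gains a `fused` constructor argument that is forwarded to `tf.keras.layers.BatchNormalization` instead of the previously hard-coded `fused=False`. Below is a minimal sketch of that pattern, assuming TF Keras semantics; `Head` and its body are hypothetical stand-ins for `FastrcnnHead`/`MaskrcnnHead`, not the repository code.

```python
# Minimal sketch of the deferred-instantiation pattern in this diff.
# BatchNormRelu mirrors the class being patched; `Head` is a hypothetical
# stand-in for FastrcnnHead/MaskrcnnHead.
import tensorflow as tf


class BatchNormRelu(tf.keras.layers.Layer):
  """Batch normalization optionally followed by a ReLU."""

  def __init__(self, relu=True, fused=True, name=None):
    super(BatchNormRelu, self).__init__(name=name)
    self._use_relu = relu
    # `fused` is now a constructor option forwarded to Keras, instead of
    # being hard-coded in the layer.
    self._batch_norm_op = tf.keras.layers.BatchNormalization(fused=fused)

  def call(self, inputs, is_training=None):
    net = self._batch_norm_op(inputs, training=is_training)
    return tf.nn.relu(net) if self._use_relu else net


class Head(object):
  def __init__(self, batch_norm_relu=BatchNormRelu):
    # After this commit: store the class (`batch_norm_relu`), not an
    # instance (`batch_norm_relu()`).
    self._batch_norm_relu = batch_norm_relu

  def __call__(self, features, is_training=None):
    # Each call site builds its own layer and picks its own options,
    # e.g. fused=False for the rank-2 dense branches in FastrcnnHead.
    return self._batch_norm_relu(fused=False)(features,
                                              is_training=is_training)
```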
@@ -124,7 +124,7 @@ class FastrcnnHead(object):
     """
     self._num_classes = num_classes
     self._mlp_head_dim = mlp_head_dim
-    self._batch_norm_relu = batch_norm_relu()
+    self._batch_norm_relu = batch_norm_relu

   def __call__(self, roi_features, is_training=None):
     """Box and class branches for the Mask-RCNN model.
@@ -151,11 +151,11 @@ class FastrcnnHead(object):
         units=self._mlp_head_dim, activation=None, name='fc6')(
             roi_features)
-    net = self._batch_norm_relu(net, is_training=is_training)
+    net = self._batch_norm_relu(fused=False)(net, is_training=is_training)
     net = tf.keras.layers.Dense(
         units=self._mlp_head_dim, activation=None, name='fc7')(
             net)
-    net = self._batch_norm_relu(net, is_training=is_training)
+    net = self._batch_norm_relu(fused=False)(net, is_training=is_training)

     class_outputs = tf.keras.layers.Dense(
         self._num_classes,
@@ -189,7 +189,7 @@ class MaskrcnnHead(object):
     """
     self._num_classes = num_classes
     self._mrcnn_resolution = mrcnn_resolution
-    self._batch_norm_relu = batch_norm_relu()
+    self._batch_norm_relu = batch_norm_relu

   def __call__(self, roi_features, class_indices, is_training=None):
     """Mask branch for the Mask-RCNN model.
@@ -240,7 +240,7 @@ class MaskrcnnHead(object):
           bias_initializer=tf.zeros_initializer(),
           name='mask-conv-l%d' % i)(
               net)
-      net = self._batch_norm_relu(net, is_training=is_training)
+      net = self._batch_norm_relu()(net, is_training=is_training)

     kernel_size = (2, 2)
     fan_out = 256
@@ -256,7 +256,7 @@ class MaskrcnnHead(object):
         bias_initializer=tf.zeros_initializer(),
         name='conv5-mask')(
             net)
-    net = self._batch_norm_relu(net, is_training=is_training)
+    net = self._batch_norm_relu()(net, is_training=is_training)

     kernel_size = (1, 1)
     fan_out = self._num_classes
...
@@ -31,6 +31,7 @@ class BatchNormRelu(tf.keras.layers.Layer):
                trainable=True,
                relu=True,
                init_zero=False,
+               fused=True,
                name=None):
     """A class to construct layers for a batch normalization followed by a ReLU.
@@ -43,6 +44,7 @@ class BatchNormRelu(tf.keras.layers.Layer):
       relu: `bool` if False, omits the ReLU operation.
       init_zero: `bool` if True, initializes scale parameter of batch
         normalization with 0. If False, initialize it with 1.
+      fused: `bool` whether to use the fused batch normalization implementation.
       name: `str` name for the operation.
     """
     self._use_relu = relu
@@ -51,14 +53,13 @@ class BatchNormRelu(tf.keras.layers.Layer):
       gamma_initializer = tf.keras.initializers.Zeros()
     else:
       gamma_initializer = tf.keras.initializers.Ones()
-    # TODO(yeqing): Check if we can change the fused=True again.
     self._batch_norm_op = tf.keras.layers.BatchNormalization(
         momentum=momentum,
         epsilon=epsilon,
         center=True,
         scale=True,
         trainable=trainable,
-        fused=False,
+        fused=fused,
         gamma_initializer=gamma_initializer,
         name=name)
...
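
A likely rationale for the per-site `fused` flag (inferred from the diff, not stated in the commit): TensorFlow's fused batch-norm kernel only accepts 4D inputs, so the convolutional mask branches can keep the faster `fused=True` default while the rank-2 outputs of the `fc6`/`fc7` dense layers in `FastrcnnHead` opt out with `fused=False`. A sketch of the two call styles, reusing the `BatchNormRelu` sketch above:

```python
# Hypothetical call sites mirroring the diff: conv feature maps are 4D and
# may use the fused kernel; Dense outputs are rank-2 and may not.
bn_conv = BatchNormRelu()            # fused=True default, 4D inputs only
bn_fc = BatchNormRelu(fused=False)   # for fc6/fc7 rank-2 activations
```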