Commit 32229eae authored by Yeqing Li, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 274918820
parent 63e20270
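
In summary (from the hunks below): the heads stop building a single `BatchNormRelu` instance in `__init__` and instead keep the constructor, building a fresh normalization layer at each call site. That makes room for a new `fused` option on `BatchNormRelu`: the rank-2 outputs of the `Dense` layers are normalized with `fused=False` (TensorFlow's fused batch-norm kernel requires 4-D input), while the rank-4 convolutional features keep the restored `fused=True` default.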
@@ -124,7 +124,7 @@ class FastrcnnHead(object):
"""
self._num_classes = num_classes
self._mlp_head_dim = mlp_head_dim
- self._batch_norm_relu = batch_norm_relu()
+ self._batch_norm_relu = batch_norm_relu
def __call__(self, roi_features, is_training=None):
"""Box and class branches for the Mask-RCNN model.
@@ -151,11 +151,11 @@ class FastrcnnHead(object):
units=self._mlp_head_dim, activation=None, name='fc6')(
roi_features)
- net = self._batch_norm_relu(net, is_training=is_training)
+ net = self._batch_norm_relu(fused=False)(net, is_training=is_training)
net = tf.keras.layers.Dense(
units=self._mlp_head_dim, activation=None, name='fc7')(
net)
- net = self._batch_norm_relu(net, is_training=is_training)
+ net = self._batch_norm_relu(fused=False)(net, is_training=is_training)
class_outputs = tf.keras.layers.Dense(
self._num_classes,
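
A minimal sketch of the head-side pattern above (the class `TinyMlpHead` and its defaults are hypothetical; only the `batch_norm_relu(fused=False)(...)` call shape comes from the diff): storing the constructor instead of an instance lets each call site choose its own `fused` setting, and the rank-2 `fc6`/`fc7` outputs must opt out of the fused kernel.

```python
import tensorflow as tf


class TinyMlpHead(object):
  """Hypothetical, trimmed-down FastrcnnHead showing the new pattern."""

  def __init__(self, batch_norm_relu, mlp_head_dim=1024):
    # Keep the constructor itself; an instance is built per call site.
    self._batch_norm_relu = batch_norm_relu
    self._mlp_head_dim = mlp_head_dim

  def __call__(self, roi_features, is_training=None):
    net = tf.keras.layers.Dense(
        units=self._mlp_head_dim, activation=None, name='fc6')(roi_features)
    # Dense output is rank 2 ([batch, dim]); the fused batch-norm kernel
    # only handles rank-4 input, so this instance disables it.
    net = self._batch_norm_relu(fused=False)(net, is_training=is_training)
    return net
```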
@@ -189,7 +189,7 @@ class MaskrcnnHead(object):
"""
self._num_classes = num_classes
self._mrcnn_resolution = mrcnn_resolution
- self._batch_norm_relu = batch_norm_relu()
+ self._batch_norm_relu = batch_norm_relu
def __call__(self, roi_features, class_indices, is_training=None):
"""Mask branch for the Mask-RCNN model.
@@ -240,7 +240,7 @@ class MaskrcnnHead(object):
bias_initializer=tf.zeros_initializer(),
name='mask-conv-l%d' % i)(
net)
- net = self._batch_norm_relu(net, is_training=is_training)
+ net = self._batch_norm_relu()(net, is_training=is_training)
kernel_size = (2, 2)
fan_out = 256
@@ -256,7 +256,7 @@ class MaskrcnnHead(object):
bias_initializer=tf.zeros_initializer(),
name='conv5-mask')(
net)
- net = self._batch_norm_relu(net, is_training=is_training)
+ net = self._batch_norm_relu()(net, is_training=is_training)
kernel_size = (1, 1)
fan_out = self._num_classes
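
The mask branch, by contrast, normalizes rank-4 convolutional feature maps, so its call sites invoke the constructor with no arguments and inherit the `fused=True` default restored in `BatchNormRelu` below.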
@@ -31,6 +31,7 @@ class BatchNormRelu(tf.keras.layers.Layer):
trainable=True,
relu=True,
init_zero=False,
+ fused=True,
name=None):
"""A class to construct layers for a batch normalization followed by a ReLU.
@@ -43,6 +44,7 @@ class BatchNormRelu(tf.keras.layers.Layer):
relu: `bool` if False, omits the ReLU operation.
init_zero: `bool` if True, initializes scale parameter of batch
normalization with 0. If False, initialize it with 1.
+ fused: `bool` the fused option in batch normalization.
name: `str` name for the operation.
"""
self._use_relu = relu
@@ -51,14 +53,13 @@ class BatchNormRelu(tf.keras.layers.Layer):
gamma_initializer = tf.keras.initializers.Zeros()
else:
gamma_initializer = tf.keras.initializers.Ones()
- # TODO(yeqing): Check if we can change the fused=True again.
self._batch_norm_op = tf.keras.layers.BatchNormalization(
momentum=momentum,
epsilon=epsilon,
center=True,
scale=True,
trainable=trainable,
- fused=False,
+ fused=fused,
gamma_initializer=gamma_initializer,
name=name)
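
Putting the pieces together, here is a minimal self-contained sketch of a `BatchNormRelu`-style layer with the new pass-through. The constructor arguments and the overall shape come from the hunks above; the default `momentum`/`epsilon` values and the `__call__` signature are assumptions for illustration.

```python
import tensorflow as tf


class BatchNormRelu(tf.keras.layers.Layer):
  """Batch normalization followed by an optional ReLU (sketch)."""

  def __init__(self, momentum=0.997, epsilon=1e-4, trainable=True,
               relu=True, init_zero=False, fused=True, name=None):
    super(BatchNormRelu, self).__init__(name=name)
    self._use_relu = relu
    if init_zero:
      gamma_initializer = tf.keras.initializers.Zeros()
    else:
      gamma_initializer = tf.keras.initializers.Ones()
    self._batch_norm_op = tf.keras.layers.BatchNormalization(
        momentum=momentum,
        epsilon=epsilon,
        center=True,
        scale=True,
        trainable=trainable,
        fused=fused,  # the new option, forwarded as-is
        gamma_initializer=gamma_initializer,
        name=name)

  def __call__(self, inputs, is_training=None):
    net = self._batch_norm_op(inputs, training=is_training)
    return tf.nn.relu(net) if self._use_relu else net


# Usage mirroring the heads: conv features (rank 4) keep the fused kernel,
# Dense features (rank 2) opt out.
conv_bn = BatchNormRelu()              # fused=True by default
dense_bn = BatchNormRelu(fused=False)  # fused kernel needs rank-4 input
```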