Commit 0bea104f authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Add control to turn-on/off BN for segmentation head 3D.

PiperOrigin-RevId: 384262412
parent 0375a63b
......@@ -56,6 +56,7 @@ class SegmentationHead3D(hyperparams.Config):
num_filters: int = 256
upsample_factor: int = 1
output_logits: bool = True
use_batch_normalization: bool = True
@dataclasses.dataclass
......
......@@ -54,6 +54,7 @@ def build_segmentation_model_3d(
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon,
use_batch_normalization=head_config.use_batch_normalization,
kernel_regularizer=l2_regularizer,
output_logits=head_config.output_logits)
......
......@@ -26,12 +26,14 @@ from official.vision.beta.projects.volumetric_models.modeling import factory
class SegmentationModelBuilderTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(((128, 128, 128), 5e-5), ((64, 64, 64), None))
def test_unet3d_builder(self, input_size, weight_decay):
@parameterized.parameters(((128, 128, 128), 5e-5, True),
((64, 64, 64), None, False))
def test_unet3d_builder(self, input_size, weight_decay, use_bn):
num_classes = 3
input_specs = tf.keras.layers.InputSpec(
shape=[None, input_size[0], input_size[1], input_size[2], 3])
model_config = exp_cfg.SemanticSegmentationModel3D(num_classes=num_classes)
model_config.head.use_batch_normalization = use_bn
l2_regularizer = (
tf.keras.regularizers.l2(weight_decay) if weight_decay else None)
model = factory.build_segmentation_model_3d(
......
......@@ -34,6 +34,7 @@ class SegmentationHead3D(tf.keras.layers.Layer):
use_sync_bn: bool = False,
norm_momentum: float = 0.99,
norm_epsilon: float = 0.001,
use_batch_normalization: bool = False,
kernel_regularizer: tf.keras.regularizers.Regularizer = None,
bias_regularizer: tf.keras.regularizers.Regularizer = None,
output_logits: bool = True,
......@@ -57,6 +58,8 @@ class SegmentationHead3D(tf.keras.layers.Layer):
norm_momentum: `float`, the momentum parameter of the normalization
layers.
norm_epsilon: `float`, the epsilon parameter of the normalization layers.
use_batch_normalization: A bool of whether to use batch normalization or
not.
kernel_regularizer: `tf.keras.regularizers.Regularizer` object for layer
kernel.
bias_regularizer: `tf.keras.regularizers.Regularizer` object for bias.
......@@ -76,6 +79,7 @@ class SegmentationHead3D(tf.keras.layers.Layer):
'use_sync_bn': use_sync_bn,
'norm_momentum': norm_momentum,
'norm_epsilon': norm_epsilon,
'use_batch_normalization': use_batch_normalization,
'kernel_regularizer': kernel_regularizer,
'bias_regularizer': bias_regularizer,
'output_logits': output_logits
......@@ -119,7 +123,8 @@ class SegmentationHead3D(tf.keras.layers.Layer):
filters=self._config_dict['num_filters'],
**conv_kwargs))
norm_name = 'segmentation_head_norm_{}'.format(i)
self._norms.append(bn_op(name=norm_name, **bn_kwargs))
if self._config_dict['use_batch_normalization']:
self._norms.append(bn_op(name=norm_name, **bn_kwargs))
self._classifier = conv_op(
name='segmentation_output',
......@@ -154,9 +159,10 @@ class SegmentationHead3D(tf.keras.layers.Layer):
"""
x = decoder_output[str(self._config_dict['level'])]
for conv, norm in zip(self._convs, self._norms):
for i, conv in enumerate(self._convs):
x = conv(x)
x = norm(x)
if self._norms:
x = self._norms[i](x)
x = self._activation(x)
x = tf.keras.layers.UpSampling3D(size=self._config_dict['upsample_factor'])(
......
......@@ -25,12 +25,15 @@ from official.vision.beta.projects.volumetric_models.modeling.heads import segme
class SegmentationHead3DTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
(1, 0),
(2, 1),
(1, 0, True),
(2, 1, False),
)
def test_forward(self, level, num_convs):
def test_forward(self, level, num_convs, use_bn):
head = segmentation_heads_3d.SegmentationHead3D(
num_classes=10, level=level, num_convs=num_convs)
num_classes=10,
level=level,
num_convs=num_convs,
use_batch_normalization=use_bn)
backbone_features = {
'1': np.random.rand(2, 128, 128, 128, 16),
'2': np.random.rand(2, 64, 64, 64, 16),
......
......@@ -106,7 +106,11 @@ class BasicBlock3DVolume(tf.keras.layers.Layer):
padding='same',
data_format=tf.keras.backend.image_data_format(),
activation=None))
self._norms.append(self._norm(axis=self._bn_axis))
self._norms.append(
self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon))
super(BasicBlock3DVolume, self).build(input_shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment