Commit 0bea104f authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Add control to turn-on/off BN for segmentation head 3D.

PiperOrigin-RevId: 384262412
parent 0375a63b
...@@ -56,6 +56,7 @@ class SegmentationHead3D(hyperparams.Config): ...@@ -56,6 +56,7 @@ class SegmentationHead3D(hyperparams.Config):
num_filters: int = 256 num_filters: int = 256
upsample_factor: int = 1 upsample_factor: int = 1
output_logits: bool = True output_logits: bool = True
use_batch_normalization: bool = True
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -54,6 +54,7 @@ def build_segmentation_model_3d( ...@@ -54,6 +54,7 @@ def build_segmentation_model_3d(
use_sync_bn=norm_activation_config.use_sync_bn, use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum, norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon, norm_epsilon=norm_activation_config.norm_epsilon,
use_batch_normalization=head_config.use_batch_normalization,
kernel_regularizer=l2_regularizer, kernel_regularizer=l2_regularizer,
output_logits=head_config.output_logits) output_logits=head_config.output_logits)
......
...@@ -26,12 +26,14 @@ from official.vision.beta.projects.volumetric_models.modeling import factory ...@@ -26,12 +26,14 @@ from official.vision.beta.projects.volumetric_models.modeling import factory
class SegmentationModelBuilderTest(parameterized.TestCase, tf.test.TestCase): class SegmentationModelBuilderTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(((128, 128, 128), 5e-5), ((64, 64, 64), None)) @parameterized.parameters(((128, 128, 128), 5e-5, True),
def test_unet3d_builder(self, input_size, weight_decay): ((64, 64, 64), None, False))
def test_unet3d_builder(self, input_size, weight_decay, use_bn):
num_classes = 3 num_classes = 3
input_specs = tf.keras.layers.InputSpec( input_specs = tf.keras.layers.InputSpec(
shape=[None, input_size[0], input_size[1], input_size[2], 3]) shape=[None, input_size[0], input_size[1], input_size[2], 3])
model_config = exp_cfg.SemanticSegmentationModel3D(num_classes=num_classes) model_config = exp_cfg.SemanticSegmentationModel3D(num_classes=num_classes)
model_config.head.use_batch_normalization = use_bn
l2_regularizer = ( l2_regularizer = (
tf.keras.regularizers.l2(weight_decay) if weight_decay else None) tf.keras.regularizers.l2(weight_decay) if weight_decay else None)
model = factory.build_segmentation_model_3d( model = factory.build_segmentation_model_3d(
......
...@@ -34,6 +34,7 @@ class SegmentationHead3D(tf.keras.layers.Layer): ...@@ -34,6 +34,7 @@ class SegmentationHead3D(tf.keras.layers.Layer):
use_sync_bn: bool = False, use_sync_bn: bool = False,
norm_momentum: float = 0.99, norm_momentum: float = 0.99,
norm_epsilon: float = 0.001, norm_epsilon: float = 0.001,
use_batch_normalization: bool = False,
kernel_regularizer: tf.keras.regularizers.Regularizer = None, kernel_regularizer: tf.keras.regularizers.Regularizer = None,
bias_regularizer: tf.keras.regularizers.Regularizer = None, bias_regularizer: tf.keras.regularizers.Regularizer = None,
output_logits: bool = True, output_logits: bool = True,
...@@ -57,6 +58,8 @@ class SegmentationHead3D(tf.keras.layers.Layer): ...@@ -57,6 +58,8 @@ class SegmentationHead3D(tf.keras.layers.Layer):
norm_momentum: `float`, the momentum parameter of the normalization norm_momentum: `float`, the momentum parameter of the normalization
layers. layers.
norm_epsilon: `float`, the epsilon parameter of the normalization layers. norm_epsilon: `float`, the epsilon parameter of the normalization layers.
use_batch_normalization: A bool of whether to use batch normalization or
not.
kernel_regularizer: `tf.keras.regularizers.Regularizer` object for layer kernel_regularizer: `tf.keras.regularizers.Regularizer` object for layer
kernel. kernel.
bias_regularizer: `tf.keras.regularizers.Regularizer` object for bias. bias_regularizer: `tf.keras.regularizers.Regularizer` object for bias.
...@@ -76,6 +79,7 @@ class SegmentationHead3D(tf.keras.layers.Layer): ...@@ -76,6 +79,7 @@ class SegmentationHead3D(tf.keras.layers.Layer):
'use_sync_bn': use_sync_bn, 'use_sync_bn': use_sync_bn,
'norm_momentum': norm_momentum, 'norm_momentum': norm_momentum,
'norm_epsilon': norm_epsilon, 'norm_epsilon': norm_epsilon,
'use_batch_normalization': use_batch_normalization,
'kernel_regularizer': kernel_regularizer, 'kernel_regularizer': kernel_regularizer,
'bias_regularizer': bias_regularizer, 'bias_regularizer': bias_regularizer,
'output_logits': output_logits 'output_logits': output_logits
...@@ -119,6 +123,7 @@ class SegmentationHead3D(tf.keras.layers.Layer): ...@@ -119,6 +123,7 @@ class SegmentationHead3D(tf.keras.layers.Layer):
filters=self._config_dict['num_filters'], filters=self._config_dict['num_filters'],
**conv_kwargs)) **conv_kwargs))
norm_name = 'segmentation_head_norm_{}'.format(i) norm_name = 'segmentation_head_norm_{}'.format(i)
if self._config_dict['use_batch_normalization']:
self._norms.append(bn_op(name=norm_name, **bn_kwargs)) self._norms.append(bn_op(name=norm_name, **bn_kwargs))
self._classifier = conv_op( self._classifier = conv_op(
...@@ -154,9 +159,10 @@ class SegmentationHead3D(tf.keras.layers.Layer): ...@@ -154,9 +159,10 @@ class SegmentationHead3D(tf.keras.layers.Layer):
""" """
x = decoder_output[str(self._config_dict['level'])] x = decoder_output[str(self._config_dict['level'])]
for conv, norm in zip(self._convs, self._norms): for i, conv in enumerate(self._convs):
x = conv(x) x = conv(x)
x = norm(x) if self._norms:
x = self._norms[i](x)
x = self._activation(x) x = self._activation(x)
x = tf.keras.layers.UpSampling3D(size=self._config_dict['upsample_factor'])( x = tf.keras.layers.UpSampling3D(size=self._config_dict['upsample_factor'])(
......
...@@ -25,12 +25,15 @@ from official.vision.beta.projects.volumetric_models.modeling.heads import segme ...@@ -25,12 +25,15 @@ from official.vision.beta.projects.volumetric_models.modeling.heads import segme
class SegmentationHead3DTest(parameterized.TestCase, tf.test.TestCase): class SegmentationHead3DTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters( @parameterized.parameters(
(1, 0), (1, 0, True),
(2, 1), (2, 1, False),
) )
def test_forward(self, level, num_convs): def test_forward(self, level, num_convs, use_bn):
head = segmentation_heads_3d.SegmentationHead3D( head = segmentation_heads_3d.SegmentationHead3D(
num_classes=10, level=level, num_convs=num_convs) num_classes=10,
level=level,
num_convs=num_convs,
use_batch_normalization=use_bn)
backbone_features = { backbone_features = {
'1': np.random.rand(2, 128, 128, 128, 16), '1': np.random.rand(2, 128, 128, 128, 16),
'2': np.random.rand(2, 64, 64, 64, 16), '2': np.random.rand(2, 64, 64, 64, 16),
......
...@@ -106,7 +106,11 @@ class BasicBlock3DVolume(tf.keras.layers.Layer): ...@@ -106,7 +106,11 @@ class BasicBlock3DVolume(tf.keras.layers.Layer):
padding='same', padding='same',
data_format=tf.keras.backend.image_data_format(), data_format=tf.keras.backend.image_data_format(),
activation=None)) activation=None))
self._norms.append(self._norm(axis=self._bn_axis)) self._norms.append(
self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon))
super(BasicBlock3DVolume, self).build(input_shape) super(BasicBlock3DVolume, self).build(input_shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment