Commit f68d9962 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 388549139
parent be2a21a6
......@@ -60,6 +60,7 @@ class SegmentationHead(hyperparams.Config):
level: int = 3
num_convs: int = 2
num_filters: int = 256
use_depthwise_convolution: bool = False
prediction_kernel_size: int = 1
upsample_factor: int = 1
feature_fusion: Optional[str] = None # None, deeplabv3plus, or pyramid_fusion
......
......@@ -31,6 +31,7 @@ class SegmentationHead(tf.keras.layers.Layer):
level: Union[int, str],
num_convs: int = 2,
num_filters: int = 256,
use_depthwise_convolution: bool = False,
prediction_kernel_size: int = 1,
upsample_factor: int = 1,
feature_fusion: Optional[str] = None,
......@@ -53,6 +54,8 @@ class SegmentationHead(tf.keras.layers.Layer):
prediction layer.
num_filters: An `int` number to specify the number of filters used.
Default is 256.
use_depthwise_convolution: A bool to specify if use depthwise separable
convolutions.
prediction_kernel_size: An `int` number to specify the kernel size of the
prediction layer.
upsample_factor: An `int` number to specify the upsampling factor to
......@@ -84,6 +87,7 @@ class SegmentationHead(tf.keras.layers.Layer):
'level': level,
'num_convs': num_convs,
'num_filters': num_filters,
'use_depthwise_convolution': use_depthwise_convolution,
'prediction_kernel_size': prediction_kernel_size,
'upsample_factor': upsample_factor,
'feature_fusion': feature_fusion,
......@@ -104,12 +108,14 @@ class SegmentationHead(tf.keras.layers.Layer):
def build(self, input_shape: Union[tf.TensorShape, List[tf.TensorShape]]):
"""Creates the variables of the segmentation head."""
use_depthwise_convolution = self._config_dict['use_depthwise_convolution']
random_initializer = tf.keras.initializers.RandomNormal(stddev=0.01)
conv_op = tf.keras.layers.Conv2D
conv_kwargs = {
'kernel_size': 3,
'kernel_size': 3 if not use_depthwise_convolution else 1,
'padding': 'same',
'use_bias': False,
'kernel_initializer': tf.keras.initializers.RandomNormal(stddev=0.01),
'kernel_initializer': random_initializer,
'kernel_regularizer': self._config_dict['kernel_regularizer'],
}
bn_op = (tf.keras.layers.experimental.SyncBatchNormalization
......@@ -139,6 +145,16 @@ class SegmentationHead(tf.keras.layers.Layer):
self._convs = []
self._norms = []
for i in range(self._config_dict['num_convs']):
if use_depthwise_convolution:
self._convs.append(
tf.keras.layers.DepthwiseConv2D(
name='segmentation_head_depthwise_conv_{}'.format(i),
kernel_size=3,
padding='same',
use_bias=False,
depthwise_initializer=random_initializer,
depthwise_regularizer=self._config_dict['kernel_regularizer'],
depth_multiplier=1))
conv_name = 'segmentation_head_conv_{}'.format(i)
self._convs.append(
conv_op(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment