Commit 2ce12046 authored by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 443199105
parent b4d128f1
......@@ -14,6 +14,7 @@
"""Contains common building blocks for neural networks."""
import enum
from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
import tensorflow as tf
......@@ -31,6 +32,14 @@ States = Dict[str, tf.Tensor]
Activation = Union[str, Callable]
# String constants.
@enum.unique
class FeatureFusion(str, enum.Enum):
  """Names of the feature fusion strategies a segmentation head accepts.

  Members also subclass `str`, so each one compares equal to (and hashes
  like) its plain string value. Configs that still carry the raw string keep
  working in `==` and `in` checks written against these members.
  `enum.unique` rejects two members that share the same string value.
  """
  # Multiscale features are resized and fused at the target level.
  PYRAMID_FUSION = 'pyramid_fusion'
  PANOPTIC_FPN_FUSION = 'panoptic_fpn_fusion'
  # Decoder features are fused with low level backbone features by
  # concatenation.
  DEEPLABV3PLUS = 'deeplabv3plus'
  # Same inputs as DEEPLABV3PLUS, but merged by element-wise addition.
  DEEPLABV3PLUS_SUM_TO_MERGE = 'deeplabv3plus_sum_to_merge'
@tf.keras.utils.register_keras_serializable(package='Vision')
class SqueezeExcitationQuantized(
helper.LayerQuantizerHelper,
......@@ -237,10 +246,11 @@ class SegmentationHeadQuantized(tf.keras.layers.Layer):
prediction layer.
upsample_factor: An `int` number to specify the upsampling factor to
generate finer mask. Default 1 means no upsampling is applied.
feature_fusion: One of `deeplabv3plus`, `pyramid_fusion`, or None. If
`deeplabv3plus`, features from decoder_features[level] will be fused
with low level feature maps from backbone. If `pyramid_fusion`,
multiscale features will be resized and fused at the target level.
feature_fusion: One of `deeplabv3plus`, `deeplabv3plus_sum_to_merge`,
`pyramid_fusion`, or None. If `deeplabv3plus`, features from
decoder_features[level] will be fused with low level feature maps from
backbone by concatenation. If `deeplabv3plus_sum_to_merge`, the same
features are merged by element-wise addition instead. If `pyramid_fusion`,
multiscale features will be resized and fused at the target level.
decoder_min_level: An `int` of minimum level from decoder to use in
feature fusion. It is only used when feature_fusion is set to
`panoptic_fpn_fusion`.
......@@ -327,7 +337,9 @@ class SegmentationHeadQuantized(tf.keras.layers.Layer):
'epsilon': self._config_dict['norm_epsilon'],
}
if self._config_dict['feature_fusion'] == 'deeplabv3plus':
if self._config_dict['feature_fusion'] in [
FeatureFusion.DEEPLABV3PLUS, FeatureFusion.DEEPLABV3PLUS_SUM_TO_MERGE
]:
# Deeplabv3+ feature fusion layers.
self._dlv3p_conv = helper.Conv2DQuantized(
kernel_size=1,
......@@ -388,6 +400,7 @@ class SegmentationHeadQuantized(tf.keras.layers.Layer):
backbone_shape[1], backbone_shape[2], interpolation='bilinear')
self._concat_layer = helper.ConcatenateQuantized(axis=self._bn_axis)
self._add_layer = tf.keras.layers.Add()
super().build(input_shape)
......@@ -412,14 +425,16 @@ class SegmentationHeadQuantized(tf.keras.layers.Layer):
segmentation prediction mask: A `tf.Tensor` of the segmentation mask
scores predicted from input features.
"""
if self._config_dict['feature_fusion'] in ('pyramid_fusion',
'panoptic_fpn_fusion'):
if self._config_dict['feature_fusion'] in (
FeatureFusion.PYRAMID_FUSION, FeatureFusion.PANOPTIC_FPN_FUSION):
raise ValueError(
'The feature fusion method `pyramid_fusion` is not supported in QAT.')
backbone_output = inputs[0]
decoder_output = inputs[1]
if self._config_dict['feature_fusion'] == 'deeplabv3plus':
if self._config_dict['feature_fusion'] in {
FeatureFusion.DEEPLABV3PLUS, FeatureFusion.DEEPLABV3PLUS_SUM_TO_MERGE
}:
# deeplabv3+ feature fusion.
x = decoder_output[str(self._config_dict['level'])] if isinstance(
decoder_output, dict) else decoder_output
......@@ -429,7 +444,10 @@ class SegmentationHeadQuantized(tf.keras.layers.Layer):
y = self._activation_layer(y)
x = self._resizing_layer(x)
x = tf.cast(x, dtype=y.dtype)
x = self._concat_layer([x, y])
if self._config_dict['feature_fusion'] == FeatureFusion.DEEPLABV3PLUS:
x = self._concat_layer([x, y])
else:
x = self._add_layer([x, y])
else:
x = decoder_output[str(self._config_dict['level'])] if isinstance(
decoder_output, dict) else decoder_output
......
......@@ -24,12 +24,15 @@ from official.projects.qat.vision.modeling.layers import nn_layers
class NNLayersTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
('deeplabv3plus', 1),
('deeplabv3plus', 2),
('deeplabv3', 1),
('deeplabv3', 2),
('deeplabv3plus', 1, 128, 128),
('deeplabv3plus', 2, 128, 128),
('deeplabv3', 1, 128, 64),
('deeplabv3', 2, 128, 64),
('deeplabv3plus_sum_to_merge', 1, 64, 128),
('deeplabv3plus_sum_to_merge', 2, 64, 128),
)
def test_segmentation_head_creation(self, feature_fusion, upsample_factor):
def test_segmentation_head_creation(self, feature_fusion, upsample_factor,
low_level_num_filters, expected_shape):
input_size = 128
decoder_outupt_size = input_size // 2
......@@ -42,14 +45,11 @@ class NNLayersTest(parameterized.TestCase, tf.test.TestCase):
level=4,
upsample_factor=upsample_factor,
low_level=2,
low_level_num_filters=128,
low_level_num_filters=low_level_num_filters,
feature_fusion=feature_fusion)
features = segmentation_head((backbone_output, decoder_output))
expected_shape = (
input_size
if feature_fusion == 'deeplabv3plus' else decoder_outupt_size)
self.assertAllEqual([
2, expected_shape * upsample_factor, expected_shape * upsample_factor, 5
], features.shape.as_list())
......
......@@ -233,8 +233,9 @@ class SegmentationHead(tf.keras.layers.Layer):
prediction layer.
upsample_factor: An `int` number to specify the upsampling factor to
generate finer mask. Default 1 means no upsampling is applied.
feature_fusion: One of `deeplabv3plus`, `pyramid_fusion`,
`panoptic_fpn_fusion`, or None. If `deeplabv3plus`, features from
feature_fusion: One of the constants in nn_layers.FeatureFusion, namely
`deeplabv3plus`, `pyramid_fusion`, `panoptic_fpn_fusion`,
`deeplabv3plus_sum_to_merge`, or None. If `deeplabv3plus`, features from
decoder_features[level] will be fused with low level feature maps from
backbone by concatenation. If `deeplabv3plus_sum_to_merge`, the same
features are merged by element-wise addition instead. If `pyramid_fusion`,
multiscale features will be resized and fused at the target level.
......@@ -245,10 +246,12 @@ class SegmentationHead(tf.keras.layers.Layer):
feature fusion. It is only used when feature_fusion is set to
`panoptic_fpn_fusion`.
low_level: An `int` of backbone level to be used for feature fusion. It is
used when feature_fusion is set to `deeplabv3plus`.
used when feature_fusion is set to `deeplabv3plus` or
`deeplabv3plus_sum_to_merge`.
low_level_num_filters: An `int` of reduced number of filters for the low
level features before fusing it with higher level features. It is only
used when feature_fusion is set to `deeplabv3plus`.
used when feature_fusion is set to `deeplabv3plus` or
`deeplabv3plus_sum_to_merge`.
num_decoder_filters: An `int` of number of filters in the decoder outputs.
It is only used when feature_fusion is set to `panoptic_fpn_fusion`.
activation: A `str` that indicates which activation is used, e.g. 'relu',
......@@ -312,7 +315,8 @@ class SegmentationHead(tf.keras.layers.Layer):
'epsilon': self._config_dict['norm_epsilon'],
}
if self._config_dict['feature_fusion'] == 'deeplabv3plus':
if self._config_dict['feature_fusion'] in {'deeplabv3plus',
'deeplabv3plus_sum_to_merge'}:
# Deeplabv3+ feature fusion layers.
self._dlv3p_conv = conv_op(
kernel_size=1,
......@@ -398,7 +402,8 @@ class SegmentationHead(tf.keras.layers.Layer):
backbone_output = inputs[0]
decoder_output = inputs[1]
if self._config_dict['feature_fusion'] == 'deeplabv3plus':
if self._config_dict['feature_fusion'] in {'deeplabv3plus',
'deeplabv3plus_sum_to_merge'}:
# deeplabv3+ feature fusion
x = decoder_output[str(self._config_dict['level'])] if isinstance(
decoder_output, dict) else decoder_output
......@@ -410,7 +415,10 @@ class SegmentationHead(tf.keras.layers.Layer):
x = tf.image.resize(
x, tf.shape(y)[1:3], method=tf.image.ResizeMethod.BILINEAR)
x = tf.cast(x, dtype=y.dtype)
x = tf.concat([x, y], axis=self._bn_axis)
if self._config_dict['feature_fusion'] == 'deeplabv3plus':
x = tf.concat([x, y], axis=self._bn_axis)
else:
x = tf.keras.layers.Add()([x, y])
elif self._config_dict['feature_fusion'] == 'pyramid_fusion':
if not isinstance(decoder_output, dict):
raise ValueError('Only support dictionary decoder_output.')
......
......@@ -30,7 +30,9 @@ class SegmentationHeadTest(parameterized.TestCase, tf.test.TestCase):
(2, 'panoptic_fpn_fusion', 2, 5),
(2, 'panoptic_fpn_fusion', 2, 6),
(3, 'panoptic_fpn_fusion', 3, 5),
(3, 'panoptic_fpn_fusion', 3, 6))
(3, 'panoptic_fpn_fusion', 3, 6),
(3, 'deeplabv3plus', 3, 6),
(3, 'deeplabv3plus_sum_to_merge', 3, 6))
def test_forward(self, level, feature_fusion,
decoder_min_level, decoder_max_level):
backbone_features = {
......@@ -52,6 +54,8 @@ class SegmentationHeadTest(parameterized.TestCase, tf.test.TestCase):
head = segmentation_heads.SegmentationHead(
num_classes=10,
level=level,
low_level=decoder_min_level,
low_level_num_filters=64,
feature_fusion=feature_fusion,
decoder_min_level=decoder_min_level,
decoder_max_level=decoder_max_level,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment